In [None]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules are picked up without restarting the kernel.
%reload_ext autoreload
%autoreload 2

In [None]:
import autorootcwd

In [None]:

import os

import numpy as np
import pandas as pd
import torch
from torch.utils.data import Subset
from torchmetrics import MetricCollection
from torchmetrics.image import (
    StructuralSimilarityIndexMeasure as SSIM,
    PeakSignalNoiseRatio as PSNR,
)
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity as LPIPS
from tqdm import tqdm

from src.data.components import PairedDataset
from src.eval import PieAPP
from src.models import transforms as CT

In [None]:
# Image-quality metrics evaluated together on every (prediction, target) pair.
# SSIM, PSNR and LPIPS come from torchmetrics; PieAPP is the project-local
# implementation imported from src.eval. All are built with default arguments.
# NOTE(review): torchmetrics' LPIPS defaults to normalize=False, which expects
# inputs scaled to [-1, 1] — confirm the range of the tensors passed in.
metrics = MetricCollection(
    dict(
        ssim=SSIM(),
        psnr=PSNR(),
        lpips=LPIPS(),
        pieapp=PieAPP(),
    )
)

In [None]:
# Sanity-check the metric collection on 2 synthetic pairs at different sizes.
# Each pair is an all-ones image and the same image plus Gaussian noise
# (std 0.1). A dedicated seeded generator makes the printed numbers
# reproducible across runs without touching the global RNG state.
noise_rng = torch.Generator().manual_seed(0)

x = torch.ones(1, 3, 64, 64)
y = x + 0.1 * torch.randn(x.shape, generator=noise_rng)
x2 = torch.ones(1, 3, 256, 256)
y2 = x2 + 0.1 * torch.randn(x2.shape, generator=noise_rng)

# Evaluation only: no gradients are needed for the metric forward passes.
with torch.no_grad():
    # Print smaller image metrics
    print("Metrics for smaller images")
    print(metrics(x, y))

    # Print larger image metrics
    print("Metrics for larger images")
    print(metrics(x2, y2))

# Calling the collection also accumulates into its running state; clear it so
# this synthetic check does not leak into later use of `metrics`.
metrics.reset()

In [None]:
# Constants
# Paths are resolved against the current working directory (presumably the
# project root after `import autorootcwd` — confirm).
RAW_DIR = os.getcwd()
DATA_DIR = os.path.join(RAW_DIR, "data")

# <DATA_DIR>/paired/processed/film and <DATA_DIR>/paired/processed/digital
film_paired_dir, digital_paired_dir = (
    os.path.join(DATA_DIR, "paired", "processed", domain)
    for domain in ("film", "digital")
)

# The dataset receives the two directories in (film, digital) order.
digital_film_data = PairedDataset(
    image_dirs=(film_paired_dir, digital_paired_dir),
)

# First dataset item, unpacked into its two elements.
film_0, digital_0 = digital_film_data[0]

# Subset restricted to index 0 only.
digital_film_subset = Subset(digital_film_data, range(1))

In [None]:
def get_valid_dim(dim: int, downsample: int = 1, multiple: int = 8) -> int:
    """
    Returns the largest multiple of ``multiple`` (8 by default) that is less
    than or equal to the downsampled input dimension. This is required
    because of the network architecture.

    The dimension is first integer-divided by ``downsample`` and the result
    is then rounded down to a multiple of ``multiple``. The result can be 0
    when the downsampled dimension is smaller than ``multiple``.

    Note: the evaluation loop in this notebook calls ``CT.get_valid_dim``
    rather than this local definition.

    Args:
        dim (int): The input dimension
        downsample (int): Factor the dimension is integer-divided by before
            rounding. Must be positive. Defaults to 1 (no downsampling).
        multiple (int): The value the result must be a multiple of. Must be
            positive. Defaults to 8.

    Returns:
        int: The largest multiple of ``multiple`` that is less than or equal
            to ``dim // downsample``

    Raises:
        ValueError: If ``downsample`` or ``multiple`` is not positive.
    """
    # Fail with a clear message instead of a bare ZeroDivisionError (0) or a
    # silently meaningless result (negative values).
    if downsample <= 0:
        raise ValueError(f"downsample must be a positive integer, got {downsample}")
    if multiple <= 0:
        raise ValueError(f"multiple must be a positive integer, got {multiple}")

    adjusted_dim = dim // downsample
    valid_dim = (adjusted_dim // multiple) * multiple
    return valid_dim

In [None]:
# Downsampling factors to evaluate; each is passed to CT.get_valid_dim as its
# `downsample` argument when computing the target size.
downsample = [12, 16, 32]

# all_metrics[metric_name][downsample_factor] -> list of per-pair scores
all_metrics = {}

# Evaluation only: with autograd disabled the stored scores carry no graph.
with torch.no_grad():
    for film, digital in tqdm(digital_film_data):
        for sample in downsample:
            # `.size` is indexed as (width, height) — presumably PIL images; confirm.
            height = CT.get_valid_dim(film.size[1], downsample=sample)
            width = CT.get_valid_dim(film.size[0], downsample=sample)

            # Both images of the pair go through the same transform configuration,
            # then get a leading batch axis for the metric collection.
            transform = CT.TestTransforms(dim=(height, width))
            film_transform = transform(film).unsqueeze(0)
            digital_transform = transform(digital).unsqueeze(0)

            results = metrics(film_transform, digital_transform)
            for metric, score in results.items():
                # Store detached CPU copies so averaging works regardless of the
                # device or grad status of a metric's output.
                all_metrics.setdefault(metric, {}).setdefault(sample, []).append(
                    torch.as_tensor(score).detach().cpu()
                )


print("Mean scores for each down sample level as a pandas DataFrame")
# mean_scores[metric_name][downsample_factor] -> mean score as a Python float
mean_scores = {
    metric: {
        sample: torch.stack(sample_scores).double().mean().item()
        for sample, sample_scores in scores.items()
    }
    for metric, scores in all_metrics.items()
}

# Rows are downsample factors, columns are metric names.
df = pd.DataFrame(mean_scores, index=downsample, columns=list(all_metrics))
print(df)


