### Comparison to Baselines

In this notebook, we compute baseline evaluation metrics by comparing each digital image directly against its paired film image.


In [None]:
# Automatically reload edited project modules (e.g. `src.*`) before executing each cell.
%reload_ext autoreload
%autoreload 2

In [None]:
import autorootcwd

In [None]:
import os
from tqdm import tqdm
import torch
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Subset

from torchmetrics import MetricCollection
from torchmetrics.image import (
    StructuralSimilarityIndexMeasure as SSIM,
    PeakSignalNoiseRatio as PSNR,
)

from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity as LPIPS
from src.eval import PieAPP
from src.data.components import PairedDataset
from src.models import transforms as CT

In [None]:
# Metrics used for every (prediction, target) pair. Calling the collection as
# `infer_metrics(preds, target)` returns a dict of per-call scores keyed by name.
# NOTE(review): `to_infer` clamps images into [1e-5, 1 - 1e-5], but torchmetrics'
# LPIPS defaults to normalize=False (inputs expected in [-1, 1]), and SSIM/PSNR
# infer `data_range` from the data. Confirm whether LPIPS(normalize=True) and
# data_range=1.0 are intended, and that model evaluation uses the same settings.
metric_factories = {
    "ssim": SSIM,
    "psnr": PSNR,
    "lpips": LPIPS,
    "pieapp": PieAPP,
}
infer_metrics = MetricCollection(
    {name: make_metric() for name, make_metric in metric_factories.items()}
)

In [None]:
# Constants
RAW_DIR = os.getcwd()  # presumably the project root, via the `autorootcwd` import
DATA_DIR = os.path.join(RAW_DIR, 'data')
SUBSET = False  # True: evaluate only the first image pair (quick smoke test)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# `to_infer` moves images to `device`; the metric modules (e.g. LPIPS's network)
# must live on the same device or their forward pass fails with a device mismatch.
infer_metrics = infer_metrics.to(device)

## No Alteration

We compute metrics for the trivial baseline that predicts the film image to be the unaltered digital image. That is, we score each (digital, film) pair with the digital image as the prediction and the film image as the target.

In [None]:
# Paired dataset; items unpack as (film, digital), matching the order of `image_dirs`.
film_paired_dir = os.path.join(DATA_DIR, "paired", "processed", "film")
digital_paired_dir = os.path.join(DATA_DIR, "paired", "processed", "digital")
digital_film_data = PairedDataset(image_dirs=(film_paired_dir, digital_paired_dir))
if SUBSET:
    # Smoke-test mode: restrict evaluation to the first pair only.
    digital_film_data = Subset(digital_film_data, range(1))

Let's look at the first example and compute its metrics.

In [None]:
def to_infer(img):
    """Turn one PIL-style image into a single-item batch for the metrics.

    The target dimensions come from ``CT.get_valid_dim`` (downsample factor 4)
    applied to the image's height and width. The image is passed through
    ``CT.TestTransforms`` configured with those dimensions, then given a
    leading batch axis, clamped to [1e-5, 1 - 1e-5], and moved to ``device``.
    """
    # `img.size` is indexed as (width, height), as for a PIL image.
    valid_dims = (
        CT.get_valid_dim(img.size[1], downsample=4),  # height
        CT.get_valid_dim(img.size[0], downsample=4),  # width
    )
    batch = CT.TestTransforms(dim=valid_dims)(img).unsqueeze(0)
    # Keep values strictly inside (0, 1).
    # NOTE(review): presumably required by one of the metrics (PieAPP?); confirm.
    batch = batch.clamp(1e-5, 1 - 1e-5)
    return batch.to(device)

In [None]:
# Score the first pair: the digital image is the prediction, the film image the target.
film_0, digital_0 = digital_film_data[0]
first_pair_scores = infer_metrics(to_infer(digital_0), to_infer(film_0))
metrics = {name: float(score) for name, score in first_pair_scores.items()}


In [None]:
# Scores for the first (digital, film) pair, as plain floats.
metrics

Now let's iterate over all image pairs in the dataset and collect per-image scores.

In [None]:
# Score every pair with the digital image as the prediction and the film image as
# the target, consistent with the first-example cell. Order matters: PSNR takes its
# data range from the target, and PieAPP is not symmetric.
all_metrics = {}
with torch.no_grad():  # evaluation only; no autograd graph needed
    for film, digital in tqdm(digital_film_data):
        film, digital = to_infer(film), to_infer(digital)
        # Use a distinct name so the notebook-level `metrics` (first example) is not clobbered.
        pair_scores = infer_metrics(digital, film)  # (preds, target)
        for metric_name, score in pair_scores.items():
            if isinstance(score, torch.Tensor):
                score = score.item()
            all_metrics.setdefault(metric_name, []).append(score)

In [None]:
# Average the metrics: mean per-image score for each metric over all evaluated pairs.
mean_metrics = {}
for metric_name, scores in all_metrics.items():
    mean_metrics[metric_name] = sum(scores) / len(scores)
mean_metrics