### Comparison to Baselines

In this notebook, we compute baseline evaluation metrics by comparing the digital images directly to their film counterparts.


In [None]:
# Enable IPython's autoreload extension in mode 2, so imported modules are reloaded before each cell runs.
%reload_ext autoreload
%autoreload 2

In [None]:
import autorootcwd

In [None]:
import os
import torch
import pandas as pd
from tqdm import tqdm
from matplotlib import pyplot as plt

from torchmetrics import MetricCollection
from torchmetrics.image import (
    StructuralSimilarityIndexMeasure as SSIM,
    PeakSignalNoiseRatio as PSNR,
)
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity as LPIPS

from src.eval import PieAPP
from src.models import transforms as CT
from src.data.components import PairedDataset

In [None]:
# Pick the compute backend: Apple MPS first, then CUDA, otherwise the CPU.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
device

## Random Noise Checks

First, we check how calling a metric collection works.

In [None]:
# Image-quality metrics that are evaluated together on every image pair.
metric_modules = {
    "ssim": SSIM(),
    "psnr": PSNR(),
    "lpips": LPIPS(net_type="squeeze", normalize=True),
    "pieapp": PieAPP(),
}
compute_metrics = MetricCollection(metric_modules)

# Move the collection to the selected device; `metrics` keeps whatever `.to(device)` returns
# (for a torch module that is the module itself, i.e. an alias of `compute_metrics`).
metrics = compute_metrics.to(device)

In [None]:
def c(x):
    """Clamp ``x`` element-wise into ``[1e-5, 1 - 1e-5]``.

    Keeps values strictly inside the open interval (0, 1) before they are
    handed to the metrics.

    Args:
        x: tensor (or any object exposing ``.clamp(min, max)``).

    Returns:
        The clamped result of ``x.clamp(1e-5, 1 - 1e-5)``.
    """
    return x.clamp(1e-5, 1 - 1e-5)

In [None]:
# Draw the test image and the noise from a locally seeded generator so the printed
# numbers are reproducible across kernel restarts without touching the global RNG.
rng = torch.Generator().manual_seed(0)
x = torch.rand(1, 3, 256, 256, generator=rng).to(device)
noise = torch.rand(x.shape, generator=rng).to(device)

low_noise = c(x + noise * 0.1)
high_noise = c(x + noise * 0.5)

print(compute_metrics(c(x), low_noise)) # less noise
print(compute_metrics(c(x), high_noise)) # more noise

# Determinism check: two evaluations of the same pair must match exactly, metric by metric.
# Comparing per key with torch.equal avoids relying on the truthiness of `tensor == tensor`.
first_run = compute_metrics(c(x), low_noise)
second_run = compute_metrics(c(x), low_noise)
assert all(torch.equal(first_run[name], second_run[name]) for name in first_run), "Non-deterministic behavior"

As expected, the SSIM and PSNR values are higher when adding less noise, and the other metrics (LPIPS and PieAPP) are lower.

We also observe that re-computing with the same inputs gives the same results, i.e. the results are deterministic.

In [None]:
# Check commutativity
values_a = compute_metrics(c(x), c(x + noise * 0.1))
values_b = compute_metrics(c(x + noise * 0.1), c(x))
print("Is commutative")
for metric in metrics:
    val_a = values_a[metric]
    val_b = values_b[metric]
    is_commutative = torch.allclose(val_a, val_b, atol=1e-6)

    print(f"{metric}: {is_commutative}")

## Single Image

We compute metrics on a single image pair.

In [None]:
# Paired film/digital images of the "grain" set; the film directory is passed first.
grain_root = os.path.join("data", "paired", "grain")
film_paired_dir = os.path.join(grain_root, "film")
digital_paired_dir = os.path.join(grain_root, "digital")
digital_film_data = PairedDataset(image_dirs=(film_paired_dir, digital_paired_dir))

print(len(digital_film_data))

Let's look at the first example and compute its metrics.

In [None]:
from PIL import Image

In [None]:
# First pair of the dataset, plus an all-black image of the same size as a reference.
film, digital = digital_film_data[0]
black = Image.new('RGB', film.size)

# Title every panel so the figure can be read without consulting the code
# (the panel order here is film, digital, black).
fig, ax = plt.subplots(1, 3, figsize=(10, 5))
for panel, (title, image) in zip(ax, [("film", film), ("digital", digital), ("black", black)]):
    panel.imshow(image)
    panel.set_title(title)

We have to apply some transforms (to model input + resize) to be able to feed through the metrics (these are the same transforms that we do before passing the data to the model).

In [None]:
def to_pil(img):
    """Convert model-input data back into something ``imshow`` can draw.

    Takes the first element of ``img`` (``img[0]``), moves it to the CPU and
    passes it through ``CT.FromModelInput()``.

    NOTE(review): presumably ``img`` is a batched tensor and the result is a
    PIL image — confirm against ``src.models.transforms``.
    """
    from_model_input = CT.FromModelInput()
    first_on_cpu = img[0].cpu()
    return from_model_input(first_on_cpu)

Let's now compute some baseline scores.

In [None]:
# Bring all three images into model-input form at downsampling factor 2.
trans_digital = CT.to_infer(digital, device=device, downsample=2)
trans_film = CT.to_infer(film, device=device, downsample=2)
trans_black = CT.to_infer(black, device=device, downsample=2)

# Title every panel: the order here is digital, film, black.
fig, ax = plt.subplots(1, 3, figsize=(10, 5))
for panel, (title, tensor) in zip(ax, [("digital", trans_digital), ("film", trans_film), ("black", trans_black)]):
    panel.imshow(to_pil(tensor))
    panel.set_title(title)

In [None]:
# Baseline: the transformed digital image scored against its film counterpart.
compute_metrics(trans_digital, trans_film)

In [None]:
# Reference point: the all-black image scored against the same film image.
compute_metrics(trans_black, trans_film)

Nice, these look like sensible numbers.

In [None]:
# Check sensitivity to downsampling
# Results are collected under their own names: rebinding `metrics` here would replace
# the MetricCollection alias that the noise-check cells call.
factor_rows = []
for factor in [2, 4, 8]:
    # NOTE: these two names are rebound on every pass, so after the loop they hold the
    # factor-8 tensors, which is what the crop check further down receives.
    trans_digital = CT.to_infer(digital, downsample=factor, device=device)
    trans_film = CT.to_infer(film, downsample=factor, device=device)
    factor_scores = compute_metrics(trans_digital, trans_film)

    factor_rows.append({"factor": factor, **{metric: value.item() for metric, value in factor_scores.items()}})

pd.DataFrame(factor_rows)

Also, let's check for sensitivity to crops.

In [None]:
# Check sensitivity to crop
# NOTE(review): presumably `TrainTransforms(256, augment=0)` yields a 256-pixel crop with
# augmentation switched off — confirm against `src.models.transforms`.
crop = CT.TrainTransforms(256, augment=0)

# NOTE(review): the two tensors are cropped by separate calls; if the crop position is
# random, the windows will not line up between digital and film — verify.
# `trans_digital` / `trans_film` are whatever the preceding cells bound last
# (the final pass of the downsampling loop when cells run top to bottom).
crop_digital = crop(trans_digital)
crop_film = crop(trans_film)

compute_metrics(crop_digital, crop_film)

Notes on runtime:

* Runtime CPU (912x1360): 23s
* Runtime MPS (912x1360): 10s
* Runtime MPS (1824x2728): 1.30m

## Full Dataset

In [None]:
# Switch to the "processed" set of paired images; the film directory is again passed first.
processed_root = os.path.join("data", "paired", "processed")
film_paired_dir = os.path.join(processed_root, "film")
digital_paired_dir = os.path.join(processed_root, "digital")
digital_film_data = PairedDataset(image_dirs=(film_paired_dir, digital_paired_dir))

print(len(digital_film_data))

Let's now iterate over all the images in the dataset, once for each downsampling factor (2, 4, and 8).

In [None]:
# Score every digital/film pair of the dataset at each downsampling factor.
data = []
for downsample in [2, 4, 8]:
    progress = tqdm(
        enumerate(digital_film_data, start=1),  # image ids are 1-based
        total=len(digital_film_data),
        desc=f"Downsample: {downsample}",
    )
    for image_id, (film, digital) in progress:
        trans_digital = CT.to_infer(digital, downsample=downsample, device=device)
        trans_film = CT.to_infer(film, downsample=downsample, device=device)

        # Kept under its own name so the `metrics` alias of the MetricCollection is not overwritten.
        image_scores = compute_metrics(trans_digital, trans_film)
        data.append({
            "image_id": image_id,
            "downsample": downsample,
            **{metric: value.item() for metric, value in image_scores.items()}
        })

# One row per (image, downsampling factor) with one column per metric.
baselines = pd.DataFrame(data)

In [None]:
# Average of every metric per downsampling factor.
scores = list(compute_metrics.keys())
baselines.groupby("downsample")[scores].mean()

In [None]:
# Variation based on downsample: per-image standard deviation over the factors,
# then averaged over all images.
print("Mean standard deviation of metrics across downsampling factors")
per_image_std = baselines.groupby(["image_id"])[scores].std()
per_image_std.mean()

In [None]:
# Per-image scores at downsampling factor 2.
baselines.loc[baselines["downsample"] == 2]