# Evaluation on All Samples

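We load a trained translation model, run it on the paired digital/film dataset, and compute image-quality metrics (SSIM, PSNR, LPIPS, PieAPP) for every evaluated sample.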

## Setup

---

Let's import the necessary dependencies and set the global variables.

In [None]:
# Reload edited project modules automatically before each cell runs,
# so changes under `src/` are picked up without restarting the kernel.
%reload_ext autoreload
%autoreload 2

In [None]:
import autorootcwd

In [None]:
# Imports
import os
import json

import torch
import wandb
from matplotlib import pyplot as plt
from torch.utils.data import Subset
from torchmetrics import MetricCollection
from torchmetrics.image import (
    StructuralSimilarityIndexMeasure as SSIM,
    PeakSignalNoiseRatio as PSNR,
)
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity as LPIPS
from tqdm import tqdm

from src.data.components import PairedDataset
from src.eval import PieAPP
from src.models import TranslationModule
from src.utils import process_pair

In [None]:
# Constants
# W&B API client; used below to fetch the model checkpoint artifact
api = wandb.Api()

# Define W&B Run ID
# USER, PROJECT, RUN_ID and VERSION are combined into the artifact name
# "{USER}/{PROJECT}/model-{RUN_ID}:{VERSION}" when the checkpoint is downloaded.
# RUN_ID also names the output directory "outputs/{RUN_ID}".
USER = "sillystill"
PROJECT = "sillystill"
RUN_ID = "8edwcgyg"
VERSION = "v0"
# When True, the inference loop also writes images under outputs/{RUN_ID}/{idx}/
SAVE_IMAGES = False

# Define local path
# Checkpoint used as a fallback when the W&B download raises an exception
LOCAL_PATH = "logs/hydra/runs/2024-05-16_22-10-43/checkpoints/best.ckpt"

# Hyperparameters
# SUBSET / SUBSET_SIZE: when SUBSET is True only the first SUBSET_SIZE pairs are
# evaluated; set SUBSET = False to evaluate every sample in the dataset.
SUBSET = True
SUBSET_SIZE = 1
# Passed as `downsample` to both model.predict and process_pair
DOWNSAMPLE = 16

## Translation Module


In [None]:
# Resolve the checkpoint path.
# Preference order: the W&B artifact (when RUN_ID and VERSION are set), otherwise
# LOCAL_PATH. `path` is always assigned so the next cell never hits a NameError.
path = LOCAL_PATH
if RUN_ID and VERSION:
    CKPT = f"{USER}/{PROJECT}/model-{RUN_ID}:{VERSION}"
    try:
        artifact = api.artifact(CKPT)
        # download() returns the directory it actually wrote to; use it instead of
        # rebuilding the location by hand, which can differ from the real one
        # (custom artifact dir, ":" being replaced on some platforms).
        artifact_dir = artifact.download()
        path = os.path.join(artifact_dir, "model.ckpt")
        print(f"✅ Successfully downloaded checkpoint from {CKPT} to {path}")
    except Exception as e:
        # Best-effort download: report why it failed, then fall back to the local file
        path = LOCAL_PATH
        print(f"ℹ️ Could not download checkpoint from {CKPT} ({type(e).__name__}: {e})")
        print(f"✅ Loaded local path {path}")
else:
    print(f"ℹ️ RUN_ID/VERSION not set, using local path {path}")

In [None]:
# Load the checkpoint
model = TranslationModule.load_from_checkpoint(path)

# This notebook only evaluates: switch to eval mode so layers such as dropout or
# batch norm (if the model has any) behave deterministically during inference.
model.eval()

print(f"✅ Loaded model from {path} (Device: {model.device})")

In [None]:
# Load the paired film/digital dataset
film_paired_dir = os.path.join("data", "paired", "processed", "film")
digital_paired_dir = os.path.join("data", "paired", "processed", "digital")
digital_film_data = PairedDataset(image_dirs=(film_paired_dir, digital_paired_dir))
if SUBSET:
    # Clamp to the dataset length: Subset does not validate indices up front, so an
    # oversized SUBSET_SIZE would otherwise fail with an IndexError mid-evaluation.
    subset_size = min(SUBSET_SIZE, len(digital_film_data))
    digital_film_data = Subset(digital_film_data, range(subset_size))

print(f"✅ Loaded {len(digital_film_data)} image pairs")

In [None]:
# Image-quality metrics, keyed by the name under which each score is reported.
# NOTE(review): SSIM and PSNR are constructed with their default arguments, so no
# explicit `data_range` is given — confirm the defaults suit the value range of the
# tensors returned by process_pair.
metrics = MetricCollection(
    dict(
        ssim=SSIM(),
        psnr=PSNR(),
        lpips=LPIPS(),
        pieapp=PieAPP(),
    )
)

In [None]:
# Run inference on all images and collect one score per metric per sample.
# `tqdm` comes from the imports cell at the top of the notebook.
all_metrics = {}

# Evaluation only: disable autograd so no graph is built during inference or scoring
with torch.no_grad():
    for idx, (film, digital) in tqdm(enumerate(digital_film_data), total=len(digital_film_data)):

        # Run inference
        film_predicted = model.predict(digital, downsample=DOWNSAMPLE)

        # Process images to be in the same format as test images. The results get
        # their own names so `film` / `film_predicted` still refer to the unprocessed
        # images when they are saved below (previously they were overwritten by the
        # processed tensors, which have no `.save` method).
        film_processed, film_predicted_processed = process_pair(
            film, film_predicted, downsample=DOWNSAMPLE
        )

        # torchmetrics metrics are called as metric(preds, target): the model output
        # is the prediction and the real film image is the reference. Add a batch
        # dimension to both.
        preds = film_predicted_processed.unsqueeze(0)
        target = film_processed.unsqueeze(0)

        for metric in metrics:
            score = metrics[metric](preds, target)

            # Store plain Python numbers so the results are JSON-serialisable
            if isinstance(score, torch.Tensor):
                score = score.item()

            all_metrics.setdefault(metric, []).append(score)

        if SAVE_IMAGES:
            # Save images
            # NOTE(review): assumes the dataset items and the output of model.predict
            # expose `.save(path)` (e.g. PIL images) — confirm.
            save_dir = f"outputs/{RUN_ID}/{idx}"
            os.makedirs(save_dir, exist_ok=True)
            digital.save(f"{save_dir}/digital.png")
            film.save(f"{save_dir}/film.png")
            film_predicted.save(f"{save_dir}/film_predicted.png")

# Print metrics (mean over all evaluated samples)
means = {metric: sum(scores) / len(scores) for metric, scores in all_metrics.items()}
for metric, score in means.items():
    print(f"{metric}: {score}")

# Plot metric histograms (distribution of per-sample scores)
for metric, scores in all_metrics.items():
    plt.hist(scores, bins=20)
    plt.title(metric)
    plt.xlabel("score")
    plt.ylabel("number of samples")
    plt.show()

# Save the mean metrics as JSON next to any saved images
save_dir = f"outputs/{RUN_ID}"
os.makedirs(save_dir, exist_ok=True)
metrics_path = f"{save_dir}/metrics.json"
with open(metrics_path, "w") as f:
    json.dump(means, f)