# 🎨 Inference

This notebook shows how to use the trained model to perform digital-to-film style transfer.

## Setup

---

Let's import the necessary dependencies and set global variables.

In [None]:
# Pick up edits to project modules (e.g. under src/) without restarting the kernel
%reload_ext autoreload
%autoreload 2

In [None]:
import autorootcwd

In [None]:
# Imports
import os

import wandb
from matplotlib import pyplot as plt
from tqdm.auto import tqdm

from src.data.components import PairedDataset
from src.models import TranslationModule
from src.models.transforms import pil_to_plot

In [None]:
# Constants
# W&B public API client, used below to fetch the model artifact
api = wandb.Api()

# W&B coordinates of the checkpoint artifact: <USER>/<PROJECT>/model-<RUN_ID>:<VERSION>
USER, PROJECT = "sillystill", "sillystill"
RUN_ID, VERSION = "uf9p1ygx", "v0"

# Checkpoint on local disk, used when the W&B artifact cannot be downloaded
LOCAL_PATH = "logs/hydra/runs/2024-05-16_22-10-43/checkpoints/best.ckpt"

## Translation Module


In [None]:
# Download the checkpoint from W&B; fall back to LOCAL_PATH if that is disabled
# (RUN_ID/VERSION empty) or fails. `path` is always defined after this cell.
path = None
if RUN_ID and VERSION:
    CKPT = f"{USER}/{PROJECT}/model-{RUN_ID}:{VERSION}"
    try:
        artifact = api.artifact(CKPT)
        # download() returns the directory it wrote to; use it instead of
        # hard-coding wandb's default "artifacts/<name>" layout.
        artifact_dir = artifact.download()
        path = os.path.join(artifact_dir, "model.ckpt")
        print(f"✅ Successfully downloaded checkpoint from {CKPT} to {path}")
    except Exception as e:
        # Best-effort download: surface the reason instead of swallowing it.
        print(f"ℹ️ Could not download checkpoint from {CKPT} ({type(e).__name__}: {e})")

if path is None:
    path = LOCAL_PATH
    print(f"✅ Loaded local path {path}")

In [None]:
# Restore the TranslationModule from the checkpoint file selected above
model = TranslationModule.load_from_checkpoint(path)
print(f"✅ Loaded model from {path} (Device: {model.device})")

In [None]:
# Load the paired dataset; items come back in image_dirs order, i.e. (film, digital)
film_paired_dir, digital_paired_dir = (
    os.path.join("data", "paired", "processed", kind) for kind in ("film", "digital")
)
digital_film_data = PairedDataset(image_dirs=(film_paired_dir, digital_paired_dir))

print(f"✅ Loaded {len(digital_film_data)} image pairs")

In [None]:
# Show example image pair
idx = 0
film, digital = digital_film_data[idx]


def downsample(x, factor=4):
    """Return a copy of image ``x`` shrunk by the integer ``factor`` on each side.

    ``x`` is expected to expose ``width``, ``height`` and ``resize`` (PIL-style
    image — presumably what PairedDataset yields; confirm). Each side is clamped
    to at least 1 pixel so very small images never request a zero-sized resize.
    """
    return x.resize((max(1, x.width // factor), max(1, x.height // factor)))


# Downsample images (4x); the downsampled `digital` is also the inference input below
film, digital = downsample(film, 4), downsample(digital, 4)

fig, ax = plt.subplots(1, 2, figsize=(12, 4))
fig.suptitle(f"Digital-Film Image Pair (Index: {idx})")
ax[0].imshow(pil_to_plot(digital))
ax[1].imshow(pil_to_plot(film))
ax[0].set_xlabel("Digital")
ax[1].set_xlabel("Film Image");

In [None]:
# Run inference: translate the (downsampled) digital image into the film style
film_predicted = model.predict(digital)

In [None]:
# Plot digital input, real film, and predicted film side-by-side
fig, ax = plt.subplots(1, 3, figsize=(18, 4))
fig.suptitle(f"Digital-Film Image Pair (Index: {idx})")
ax[0].imshow(pil_to_plot(digital))
ax[1].imshow(pil_to_plot(film))
ax[2].imshow(pil_to_plot(film_predicted))
ax[0].set_xlabel("Digital")
ax[1].set_xlabel("Film Image")
ax[2].set_xlabel("Predicted Film Image")
# tight_layout sizes the padding from the artists present when it is called,
# so it must run after the images and labels have been added.
fig.tight_layout(pad=1.0)

In [None]:
# Run inference on all images and save (digital, film, predicted) triplets to disk.
# Loop variables use distinct names so this cell does not overwrite `idx`, `film`,
# `digital` and `film_predicted` from the single-example cells above (re-running
# those plot cells afterwards would otherwise show the last dataset item).
# NOTE(review): outputs are grouped by RUN_ID even when the local fallback
# checkpoint was loaded instead of the W&B artifact.
for pair_idx, (film_img, digital_img) in tqdm(
    enumerate(digital_film_data), total=len(digital_film_data), desc="inference"
):
    # Same 4x downsampling as the single-example cells
    film_img, digital_img = downsample(film_img, 4), downsample(digital_img, 4)
    film_img_predicted = model.predict(digital_img)

    # Save images
    save_dir = os.path.join("outputs", str(RUN_ID), str(pair_idx))
    os.makedirs(save_dir, exist_ok=True)
    digital_img.save(os.path.join(save_dir, "digital.png"))
    film_img.save(os.path.join(save_dir, "film.png"))
    film_img_predicted.save(os.path.join(save_dir, "film_predicted.png"))