# ⚡ Transforms

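This notebook demonstrates the image transforms in `src.models.transforms`: converting PIL images into model-input tensors (`ToModelInput`), converting them back (`FromModelInput`), and applying data augmentation (`Augment`) to a film/digital image pair.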

## Setup

---

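Configure the notebook: enable module autoreloading, switch the working directory to the project root (via `autorootcwd`) so the relative `data/...` paths below resolve, and import the required modules.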

In [None]:
# Reload imported modules (e.g. edits under src/) automatically before each cell executes
%reload_ext autoreload
%autoreload 2

In [None]:
# Autoroot
import autorootcwd

In [None]:
# Imports
import os
import torch
import numpy as np
from matplotlib import pyplot as plt
from torchvision.transforms.v2.functional import to_pil_image

from src.data.components import PairedDataset
from src.models.transforms import ToModelInput, FromModelInput, Augment, tensor_to_plot, pil_to_plot

In [None]:
# Build the paired film/digital dataset from the processed image directories
paired_root = os.path.join("data", "paired", "processed")
film_paired_dir = os.path.join(paired_root, "film")
digital_paired_dir = os.path.join(paired_root, "digital")
digital_film_data = PairedDataset(image_dirs=(film_paired_dir, digital_paired_dir))

num_pairs = len(digital_film_data)
print(f"✅ Loaded {num_pairs} paired samples")

In [None]:
# Pick one sample pair and take the same top-left crop from both images
SAMPLE_IDX = 13
CROP_BOX = (0, 0, 256, 256)  # PIL-style crop box: (left, upper, right, lower)

film, digital = digital_film_data[SAMPLE_IDX]
film = film.crop(CROP_BOX)
digital = digital.crop(CROP_BOX)

# Show the cropped pair side by side
fig, ax = plt.subplots(1, 2, figsize=(10, 5))
for panel, image, label in zip(ax, (film, digital), ("Film", "Digital")):
    panel.imshow(pil_to_plot(image))
    panel.set_title(label)

In [None]:
# Inspect the type and size of the cropped images
print(*(type(image) for image in (film, digital)))
print(*(image.size for image in (film, digital)))

### Transform to Model Input

`ToModelInput` converts the PIL images into tensors normalised to [0, 1], which is the format the model expects.

In [None]:
# Convert the PIL crops into model-input tensors (documented above as normalised to [0, 1])
to_model_input = ToModelInput()
film_in, digital_in = to_model_input(film), to_model_input(digital)

# Check the documented [0, 1] contract here so a regression fails loudly instead of producing odd plots later
for tensor_name, image_tensor in (("film_in", film_in), ("digital_in", digital_in)):
    assert image_tensor.min() >= 0 and image_tensor.max() <= 1, f"{tensor_name} is not normalised to [0, 1]"

print(film_in.dtype, digital_in.dtype)
print(film_in.shape, digital_in.shape)

In [None]:
# Show the model-input tensors side by side
fig, ax = plt.subplots(1, 2, figsize=(10, 5))
for panel, image_tensor, label in zip(ax, (film_in, digital_in), ("Film", "Digital")):
    panel.imshow(tensor_to_plot(image_tensor))
    panel.set_title(label)

### Transform from Model Input

`FromModelInput` reverses `ToModelInput`: it converts the normalised [0, 1] tensors back into images that can be displayed with `pil_to_plot`.

In [None]:
# Convert the model-input tensors back into images with FromModelInput
from_model_input = FromModelInput()
film_out, digital_out = from_model_input(film_in), from_model_input(digital_in)

# Inspect the converted outputs (not the original `film`/`digital` crops) to verify the round trip
print(type(film_out), type(digital_out))
print(film_out.size, digital_out.size)

In [None]:
# Show the images recovered from the model-input tensors
fig, ax = plt.subplots(1, 2, figsize=(10, 5))
for panel, image, label in zip(ax, (film_out, digital_out), ("Film", "Digital")):
    panel.imshow(pil_to_plot(image))
    panel.set_title(label)

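### Augment

Apply `Augment` to the film/digital pair, stacked into a single batch, and show the augmented images.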

In [None]:
# Seed torch's RNG so the random augmentation (and the figure below) is reproducible on Restart & Run All
# NOTE(review): assumes Augment draws its randomness from torch's global RNG — confirm
AUGMENT_SEED = 0
torch.manual_seed(AUGMENT_SEED)

# NOTE(review): the meaning of 0.2 (probability? strength?) is not visible here — confirm against Augment
augment = Augment(0.2)

# Stack the pair into one batch along a new leading dim, presumably so both images get the same
# random augmentation, then unpack the batch back into the two augmented images
film_augment, digital_augment = augment(torch.stack([film_in, digital_in], dim=0))

# Show transform
fig, ax = plt.subplots(1, 2, figsize=(10, 5))
ax[0].imshow(tensor_to_plot(film_augment)); ax[1].imshow(tensor_to_plot(digital_augment));
ax[0].set_title("Film"); ax[1].set_title("Digital");