# ⚡ Lightning Modules

This notebook shows how to use the various model modules defined in this project.

## Setup

---

Let's import the necessary dependencies and set global variables.

In [None]:
%reload_ext autoreload
%autoreload 2

In [None]:
# Autoroot
import autorootcwd

In [None]:
# Imports
import os
import numpy as np
import torch
import torch.nn as nn
from matplotlib import pyplot as plt
from torchvision.transforms.v2.functional import to_pil_image
from torch.utils.data import DataLoader

from src.data import PairedDigitalFilmDataModule
from src.models.net import FFNet, ConvNet, UNet
from src.models import TranslationModule, AutoTranslationModule

In [None]:
# Path of the `data` directory located under the current working directory
DATA_DIR = os.path.join(os.getcwd(), "data")

### BaseModule

The `BaseModule` defines some helpful methods for transformations and logging.

In [None]:
# `BaseModule` is imported from its own submodule (`src.models.base_module`)
from src.models.base_module import BaseModule

# Create a `BaseModule` with augmentation set to 0.0 and 128-pixel training patches
base = BaseModule(
    augment=0.0,
    training_patch_size=128,
)

In [None]:
# Build the `PairedDigitalFilmDataModule` and fetch its training dataloader
batch_size = 4
data = PairedDigitalFilmDataModule(batch_size=batch_size)

# Both calls must run before a dataloader is requested
data.prepare_data()
data.setup()

loader = data.train_dataloader()

In [None]:
# Pull the first batch from the training loader and split it into its two parts
batch = next(iter(loader))
film_batch, digital_batch = batch

# Row 0 shows the film images, row 1 the digital images; one column per sample
fig, axs = plt.subplots(nrows=2, ncols=batch_size, figsize=(16, 6))
for col in range(batch_size):
    for row, images in enumerate((film_batch, digital_batch)):
        axs[row, col].imshow(np.array(to_pil_image(images[col])))
axs[0, 0].set_ylabel('Film')
axs[1, 0].set_ylabel('Digital');

In [None]:
# Apply the training transform to the batch, then revert it with `undo_transform`
batch_train_transformed = base.train_transform(batch)
original_batch = base.undo_transform(batch_train_transformed)

# Only the second element of each pair (the digital images) is plotted
_, digital_train_transformed = batch_train_transformed
_, digital_original = original_batch

# Row 0: transformed images, row 1: the same images after undoing the transform
fig, axs = plt.subplots(nrows=2, ncols=batch_size, figsize=(16, 6))
for col in range(batch_size):
    for row, images in enumerate((digital_train_transformed, digital_original)):
        axs[row, col].imshow(np.array(to_pil_image(images[col])))
axs[0, 0].set_ylabel('Train Transformation')
axs[1, 0].set_ylabel("Original");

In [None]:
# Test transforms: apply the test-time transform to the batch, then revert it
# with `undo_transform`. The results get their own `*_test_transformed` names so
# they are not mistaken for (and do not overwrite) the train-transform results.
batch_test_transformed = base.test_transform(batch)
original_batch = base.undo_transform(batch_test_transformed)

# Only the second element of each pair (the digital images) is plotted
_, digital_test_transformed = batch_test_transformed
_, digital_original = original_batch

# Row 0: transformed images, row 1: the same images after undoing the transform
fig, axs = plt.subplots(nrows=2, ncols=batch_size, figsize=(16, 6))
for i in range(batch_size):
    axs[0, i].imshow(np.array(to_pil_image(digital_test_transformed[i])))
    axs[1, i].imshow(np.array(to_pil_image(digital_original[i])))
axs[0, 0].set_ylabel("Test Transformation"); axs[1, 0].set_ylabel("Original");

### TranslationModule

The `TranslationModule` is a direct image-to-image translation module. It can be built on various encoder-decoder networks (such as `FFNet`, `ConvNet` or `UNet`).

In [None]:
# `TranslationModule` backed by an `FFNet` whose input/output size is 3*256*256
ffnet = FFNet(input_output_size=3 * 256 * 256)
module = TranslationModule(net=ffnet, loss=nn.MSELoss())

# Sanity check on a random 1x3x256x256 input: the output shape must match
x = torch.randn(1, 3, 256, 256)
y = module.forward(x)
assert y.shape == x.shape

In [None]:
# `TranslationModule` backed by a default-constructed `ConvNet`
convnet = ConvNet()
module = TranslationModule(net=convnet, loss=nn.MSELoss())

# Sanity check on a random 1x3x256x256 input: the output shape must match
x = torch.randn(1, 3, 256, 256)
y = module.forward(x)
assert y.shape == x.shape

In [None]:
# `TranslationModule` backed by a default-constructed `UNet`
unet = UNet()
module = TranslationModule(net=unet, loss=nn.MSELoss())

# Sanity check on a random 1x3x256x256 input: the output shape must match
x = torch.randn(1, 3, 256, 256)
y = module.forward(x)
assert y.shape == x.shape

### AutoTranslationModule

The `AutoTranslationModule` is a base module for auto-translation models, as described in [Semi-Supervised Raw-to-Raw Mapping](https://arxiv.org/pdf/2106.13883).


In [None]:
# Create an `AutoTranslationModule` with the Adam optimizer class and no scheduler
module = AutoTranslationModule(optimizer=torch.optim.Adam, scheduler=None)

# Random stand-in batch: two 4-D tensors followed by one 5-D tensor.
# NOTE(review): the meaning of each tuple element is defined by
# `AutoTranslationModule.step` and is not visible here — confirm against it.
batch = (
    torch.randn(3, 3, 256, 256),
    torch.randn(3, 3, 256, 256),
    torch.randn(2, 2, 3, 256, 256),
)

# Run a single step; it returns a loss and two tensors
loss, film_paired, digital_to_film = module.step(batch)
assert loss is not None

# Display the shapes of the two returned tensors
film_paired.shape, digital_to_film.shape