# ⚡ Lightning Data Module

This notebook shows how to use the various data modules defined in this project.

## Setup

---

Let's load the required IPython extensions and modules, and set global variables.

In [None]:
# (Re)load the autoreload extension — `%reload_ext` is safe to re-run — and switch it
# to mode 2, which reloads imported modules before executing code so edits to project
# modules are picked up without restarting the kernel.
%reload_ext autoreload
%autoreload 2

In [None]:
# Autoroot: `autorootcwd` is imported only for its import-time side effect; presumably
# it changes the working directory to the project root (confirm against the package docs).
# The `src.*` imports and the `os.getcwd()`-based paths below appear to rely on this.
import autorootcwd

In [None]:
# Imports
import os

import numpy as np
from matplotlib import pyplot as plt
from torch.utils.data import DataLoader
from hydra import compose, initialize
from hydra.utils import instantiate
from torchvision.transforms.v2.functional import to_pil_image

# Local modules
from src.data.components import PairedDataset, UnpairedDataset, CombinedDataset
from src.data import PairedDigitalFilmDataModule, CombinedDigitalFilmDataModule

In [None]:
# Constants
# Every data path in this notebook is anchored at the current working directory.
RAW_DIR = os.getcwd()
DATA_DIR = os.path.join(RAW_DIR, "data")  # <cwd>/data

## Datasets

### Paired Data


Initialises a `PairedDataset` instance. This dataset is used to load
image pairs from two data directories. The dataset assumes that the
filenames in both directories match for corresponding image pairs
and are in the same format. Data augmentation can be applied to the
images when loading. The dataset can be truncated to a maximum number
of samples, if required.

In [None]:
# Build the paired dataset from two sibling folders whose filenames correspond
# one-to-one (film image <-> digital image).
film_paired_dir, digital_paired_dir = (
    os.path.join(DATA_DIR, "paired", "processed", kind) for kind in ("film", "digital")
)
digital_film_data = PairedDataset(image_dirs=(film_paired_dir, digital_paired_dir))

print(f"✅ Loaded {len(digital_film_data)} paired samples")

In [None]:
# Inspect the first paired sample. The items are presumably PIL images, so `.size`
# reports (width, height) — confirm via the printed types.
film, digital = digital_film_data[0]
print(f"Film type: {type(film)}, Digital type: {type(digital)}")
print(f"Film image: {film.size}, Digital image: {digital.size}")

# Show the pair side by side
fig, axs = plt.subplots(ncols=2, figsize=(16, 8))
axs[0].imshow(np.array(film))
axs[0].set_title("Film")
axs[1].imshow(np.array(digital))
axs[1].set_title("Digital");

In [None]:
# Create a dataloader; the dataset's own `collate` turns a list of samples into batched tensors
batch_size = 4
dataloader = DataLoader(digital_film_data, batch_size=batch_size, collate_fn=digital_film_data.collate)

film_batch, digital_batch = next(iter(dataloader))
print(f"Film batch type: {type(film_batch)}, Digital batch type: {type(digital_batch)}")
print(f"Film Batch: {film_batch.shape}, Digital batch: {digital_batch.shape}")

# Show the batch: film on the top row, digital on the bottom row.
# Plot only the samples actually present (a tiny dataset can yield fewer than `batch_size`).
num_shown = len(film_batch)
# squeeze=False keeps `axs` 2-D, so `axs[row, col]` also works with a single column
fig, axs = plt.subplots(nrows=2, ncols=num_shown, figsize=(4 * num_shown, 6), squeeze=False)
fig.suptitle("Film-Digital Batch")
for i in range(num_shown):
    axs[0, i].imshow(np.array(to_pil_image(film_batch[i])))
    axs[1, i].imshow(np.array(to_pil_image(digital_batch[i])))
axs[0, 0].set_ylabel("Film")
axs[1, 0].set_ylabel("Digital")
# Lay out only after images and labels exist; before that there is nothing to fit
fig.tight_layout(pad=1.0)

### Unpaired Data

Initialises an `UnpairedDataset` instance. This dataset is used to load images from
a single image directory and optionally apply data augmentation. The dataset can be
truncated to a maximum number of samples, if required.

In [None]:
# Unpaired digital dataset: every image in a single folder of digital photos
digital_unpaired_dir = os.path.join(DATA_DIR, *("unpaired", "digital"))
digital_dataset = UnpairedDataset(image_dir=digital_unpaired_dir)
print(f"✅ Loaded {len(digital_dataset)} digital samples")

In [None]:
# Peek at the first digital sample; `.size` is presumably PIL's (width, height)
digital = digital_dataset[0]
print("Image type: {}, Image shape: {}".format(type(digital), digital.size))

In [None]:
# Batch the digital samples with the dataset's own collate function
digital_loader = DataLoader(digital_dataset, batch_size=batch_size, collate_fn=digital_dataset.collate)
digital_batch = next(iter(digital_loader))
print(f"Digital batch size: {digital_batch.shape}")

# One panel per sample actually in the batch (can be < batch_size for tiny datasets);
# squeeze=False keeps `axs` 2-D so indexing also works for a single panel.
num_shown = len(digital_batch)
fig, axs = plt.subplots(ncols=num_shown, figsize=(4 * num_shown, 3), squeeze=False)
fig.suptitle("Digital samples")
for i in range(num_shown):
    axs[0, i].imshow(np.array(to_pil_image(digital_batch[i])))

In [None]:
# Unpaired film dataset: every image in a single folder of film photos
film_unpaired_dir = os.path.join(DATA_DIR, *("unpaired", "film"))
film_dataset = UnpairedDataset(image_dir=film_unpaired_dir)
print(f"✅ Loaded {len(film_dataset)} film samples")

In [None]:
# Batch the film samples with the dataset's own collate function
film_loader = DataLoader(film_dataset, batch_size=batch_size, collate_fn=film_dataset.collate)
film_batch = next(iter(film_loader))
print(f"Film batch size: {film_batch.shape}")

# One panel per sample actually in the batch (can be < batch_size for tiny datasets);
# squeeze=False keeps `axs` 2-D so indexing also works for a single panel.
num_shown = len(film_batch)
fig, axs = plt.subplots(ncols=num_shown, figsize=(4 * num_shown, 3), squeeze=False)
fig.suptitle("Film samples")
for i in range(num_shown):
    axs[0, i].imshow(np.array(to_pil_image(film_batch[i])))

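### Combined Data

Initialises a `CombinedDataset` instance from the unpaired digital, unpaired film and paired datasets above. Each item unpacks into unpaired film images, unpaired digital images and a `(film_paired, digital_paired)` pair. The `num_unpaired_per_batch` and `num_paired_per_batch` arguments control how many unpaired and paired samples each item contains.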

In [None]:
# Merge the paired dataset with the two unpaired ones; the per-item counts of
# paired and unpaired samples are set explicitly to 1 each.
combined_dataset = CombinedDataset(
    paired_dataset=digital_film_data,
    digital_dataset=digital_dataset,
    film_dataset=film_dataset,
    num_unpaired_per_batch=1,
    num_paired_per_batch=1,
)
print(f"✅ Loaded {len(combined_dataset)} combined batches")

In [None]:
# One combined item unpacks into unpaired film, unpaired digital and a
# (film_paired, digital_paired) pair; each part is indexed by a leading sample axis.
film, digital, (film_paired, digital_paired) = combined_dataset[0]
print(f"Digital: {digital.shape}, Film: {film.shape}")
print(f"Digital paired: {digital_paired.shape}, Film paired: {film_paired.shape}")

# Plot the first sample of each component in a 2x2 grid
fig, axs = plt.subplots(nrows=2, ncols=2, figsize=(10, 10))
for (row, col), (title, images) in {
    (0, 0): ("Digital", digital),
    (0, 1): ("Film", film),
    (1, 0): ("Digital paired", digital_paired),
    (1, 1): ("Film paired", film_paired),
}.items():
    axs[row, col].imshow(np.array(to_pil_image(images[0])))
    axs[row, col].set_title(title)

## Datamodules

### Paired Digital Film Data

Initialise a `PairedDigitalFilmDataModule`, a Lightning wrapper around
the paired digital-film dataset. The dataset is split into
train, validation and test sets.

In [None]:
# Instantiate the paired digital-film data module
batch_size = 4
digital_film_data_module = PairedDigitalFilmDataModule(batch_size=batch_size)

# Run the data hooks manually, in order: prepare_data() first, then setup()
digital_film_data_module.prepare_data()
digital_film_data_module.setup()

# Plain string: the message has no placeholders, so no f-prefix is needed
print("✅ Loaded and prepared data module")

In [None]:
# Get loader
train_loader = digital_film_data_module.train_dataloader()
val_loader = digital_film_data_module.val_dataloader() # Batch size: 1
test_loader = digital_film_data_module.test_dataloader() # Batch size: 1

print(f"Train loader: {len(train_loader)}, Val loader: {len(val_loader)}, Test loader: {len(test_loader)}")

In [None]:
# Fetch one training batch
film_batch, digital_batch = next(iter(train_loader))
print(f"Film Batch: {film_batch.shape}, Digital batch: {digital_batch.shape}")

# Film on the top row, digital on the bottom; plot only the samples present.
# squeeze=False keeps `axs` 2-D even when the batch holds a single sample.
num_shown = len(film_batch)
fig, axs = plt.subplots(nrows=2, ncols=num_shown, figsize=(4 * num_shown, 8), squeeze=False)
fig.suptitle("Digital-Film Batch")
for i in range(num_shown):
    axs[0, i].imshow(np.array(to_pil_image(film_batch[i])))
    axs[1, i].imshow(np.array(to_pil_image(digital_batch[i])))
axs[0, 0].set_ylabel("Film")
axs[1, 0].set_ylabel("Digital")
# Lay out after all content is drawn so spacing accounts for images and labels
fig.tight_layout(pad=1.0)

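### Combined Digital-Film Data Module

Initialise a `CombinedDigitalFilmDataModule`, the data-module counterpart of the combined dataset above. It exposes train, validation and test dataloaders. Its training batches unpack into unpaired film, unpaired digital and paired samples.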


In [None]:
# Instantiate the combined data module (reuses `batch_size` from the paired data module cell)
combined_digital_film_data_module = CombinedDigitalFilmDataModule(
    batch_size=batch_size,
    num_paired_per_batch=1,
    num_unpaired_per_batch=1,
)

# Run the data hooks manually, in order: prepare_data() first, then setup()
combined_digital_film_data_module.prepare_data()
combined_digital_film_data_module.setup()

# Plain string: the message has no placeholders, so no f-prefix is needed
print("✅ Loaded and prepared data module")

In [None]:
# Build the combined train/val/test loaders (called in that order)
train_loader, val_loader, test_loader = (
    combined_digital_film_data_module.train_dataloader(),
    combined_digital_film_data_module.val_dataloader(),
    combined_digital_film_data_module.test_dataloader(),
)

print(f"Train loader: {len(train_loader)}, Val loader: {len(val_loader)}, Test loader: {len(test_loader)}")

In [None]:
# Inspect the raw structure of one combined training batch. It is ordered
# (film, digital, paired); show every component's shape by name instead of a magic index.
batch = next(iter(train_loader))
{name: part.shape for name, part in zip(("film", "digital", "paired"), batch)}

In [None]:
# Fetch one combined training batch and drop the leading loader dimension
# (`squeeze(0)` is a no-op unless that dimension has size 1).
film_batch, digital_batch, paired_batch = next(iter(train_loader))
film_batch, digital_batch = film_batch.squeeze(0), digital_batch.squeeze(0)
# The paired component splits along its first axis into film and digital halves
film_paired_batch, digital_paired_batch = paired_batch.squeeze(0)

print(
    f"Film Batch: {film_batch.shape}, Digital batch: {digital_batch.shape}, "
    f"Film paired batch: {film_paired_batch.shape}, Digital paired batch: {digital_paired_batch.shape}"
)

# Show the first image of each component, titled so the panels are identifiable
fig, axs = plt.subplots(nrows=2, ncols=2, figsize=(8, 8))
fig.suptitle("Digital-Film Batch")
axs[0, 0].imshow(np.array(to_pil_image(film_batch[0])))
axs[0, 0].set_title("Film")
axs[0, 1].imshow(np.array(to_pil_image(digital_batch[0])))
axs[0, 1].set_title("Digital")
axs[1, 0].imshow(np.array(to_pil_image(film_paired_batch[0])))
axs[1, 0].set_title("Film paired")
axs[1, 1].imshow(np.array(to_pil_image(digital_paired_batch[0])))
axs[1, 1].set_title("Digital paired")
# Lay out after all content is drawn so spacing accounts for images and titles
fig.tight_layout(pad=1.0)

## Hydra

Both data modules above can also be instantiated from Hydra configuration files (`paired` and `combined`), located in `configs/data`.

In [None]:
# Instantiate the paired data module from the `paired` Hydra config.
# NOTE(review): Hydra resolves a relative `config_path` against the caller's location;
# confirm "../configs/data" still resolves after the autoroot working-directory change.
with initialize(version_base=None, config_path="../configs/data", job_name="data"):
    cfg = compose(config_name="paired")

    # Instantiate data module
    datamodule = instantiate(cfg)

    # Setup data module: prepare_data() first, then setup()
    datamodule.prepare_data()
    datamodule.setup()

    print("✅ Loaded and prepared data module")

In [None]:
# Instantiate the combined data module from the `combined` Hydra config.
# NOTE(review): Hydra resolves a relative `config_path` against the caller's location;
# confirm "../configs/data" still resolves after the autoroot working-directory change.
with initialize(version_base=None, config_path="../configs/data", job_name="data"):
    cfg = compose(config_name="combined")

    # Instantiate data module (rebinds `datamodule`, replacing the paired one)
    datamodule = instantiate(cfg)

    # Setup data module: prepare_data() first, then setup()
    datamodule.prepare_data()
    datamodule.setup()

    print("✅ Loaded and prepared data module")