### Normalization Tests

This notebook explores how to apply the same transformation pipeline built with `transforms.Compose` to whole batches of images, since that pipeline (which starts with `transforms.ToPILImage`) only accepts a single image as input.

In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
import warnings

In [2]:

class DaganDataset(Dataset):
    """Paired dataset of original images and their augmentations (DAGAN).

    Item ``idx`` is the tuple ``(originals[idx], augmentations[idx])``, with
    ``transform`` applied to each element independently.

    Args:
        originals: Indexable sequence of original images (e.g. a NumPy array).
        augmentations: Indexable sequence with the same length as ``originals``.
        transform: Optional callable applied to every image. When ``None``,
            the raw items are returned unchanged.
    """

    def __init__(self, originals, augmentations, transform=None):
        assert len(originals) == len(augmentations), (
            "originals and augmentations must have the same length"
        )
        self.originals = originals
        self.augmentations = augmentations
        self.transform = transform

    def __len__(self):
        return len(self.originals)

    def __getitem__(self, idx):
        original = self.originals[idx]
        augmentation = self.augmentations[idx]
        # The default transform=None must not crash: return the raw pair.
        if self.transform is None:
            return original, augmentation
        # NOTE(review): presumably silences UserWarnings raised inside the
        # torchvision transforms (e.g. about non-writable NumPy arrays);
        # confirm which warning this is meant to hide.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", category=UserWarning)
            return self.transform(original), self.transform(augmentation)


def create_dagan_dataloader(
    originals, augmentations, transform, batch_size, shuffle=True, num_workers=1
):
    """Build a DataLoader over paired (original, augmentation) images.

    Args:
        originals: Indexable sequence of original images.
        augmentations: Indexable sequence of augmented images, same length.
        transform: Callable applied to each image of every pair.
        batch_size: Number of pairs per batch.
        shuffle: Whether to reshuffle the pairs every epoch (default True).
        num_workers: Number of DataLoader worker processes (default 1).

    Returns:
        A ``DataLoader`` yielding ``(originals_batch, augmentations_batch)``.
    """
    train_dataset = DaganDataset(originals, augmentations, transform)
    return DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
    )


In [4]:
train_dataset_path = '../../data/dagan/train.npz'
# `np.load` on an .npz returns a lazily-loading NpzFile that keeps the archive
# open and re-reads (and decompresses) an array on every `dataset[key]` access.
# Later cells index `dataset['orig']` several times, so load every array once
# into a plain dict and close the archive immediately.
with np.load(train_dataset_path) as npz_file:
    dataset = {key: npz_file[key] for key in npz_file.files}

In [11]:
# Read the array shape once: if `dataset` is a lazy NpzFile, every
# `dataset['orig']` access reloads the whole array from disk.
orig_shape = dataset['orig'].shape
# Channels-last layout: the last axis holds the colour channels.
in_channels = orig_shape[-1]
# NOTE(review): a single axis (2) is used as the image size, which assumes
# square images; confirm for non-square data.
img_size = orig_shape[2]
batch_size = 32
# ToTensor() scales 8-bit images to [0, 1]; Normalize with mean = std = 0.5
# then maps each channel to [-1, 1].
max_pixel_value = 1.0
mid_pixel_value = max_pixel_value / 2

In [12]:
import torchvision.transforms as transforms

# Per-image pipeline: HWC array -> PIL image -> resized -> CHW float tensor
# -> per-channel normalisation with mean = std = mid_pixel_value.
channel_stats = (mid_pixel_value,) * in_channels
train_steps = [
    transforms.ToPILImage(),
    transforms.Resize(img_size),
    transforms.ToTensor(),
    transforms.Normalize(channel_stats, channel_stats),  # (mean, std)
]
train_transform = transforms.Compose(train_steps)

In [13]:
# Shuffled loader yielding (originals, augmentations) batches of transformed images.
dl = create_dagan_dataloader(
    originals=dataset['orig'],
    augmentations=dataset['aug'],
    transform=train_transform,
    batch_size=batch_size,
)

In [15]:
# First raw original image, as stored in the dataset; `train_transform` is
# only applied inside `__getitem__`, so this is the untransformed array.
a = dl.dataset.originals[0]

In [19]:
# A single image in channels-last (H, W, C) layout.
a.shape

(84, 84, 3)

In [22]:
# Add a leading batch axis: (H, W, C) -> (1, H, W, C). The result is still
# channels-last, whereas PyTorch conv layers expect (N, C, H, W).
b = np.expand_dims(a, axis=0)
b.shape

(1, 84, 84, 3)