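# Testing the data module

Visual smoke test for `pytorch_3T27T.data`: build the horse2zebra CycleGAN transforms, datasets and data loaders from a configuration, then plot sample images from the standard, k-fold and one-fold loaders.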

In [None]:
# Reload edited project modules (e.g. pytorch_3T27T.data) before each cell runs,
# so library changes are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline, below the cell that creates them.
%matplotlib inline

In [None]:
from itertools import islice

import matplotlib.pyplot as plt
import numpy as np
import torchvision.transforms as T
from PIL import Image

import pytorch_3T27T.data as DataModule
from pytorch_3T27T.data import OneFoldDataLoader, KFoldDataLoader

In [None]:
from pytorch_3T27T.utils import Configuration, init_and_config
# Component sections looked up by name via ``init_and_config`` in the cells
# below. Each section gives a component ``type`` (presumably a class in
# ``pytorch_3T27T.data``) and the ``options`` used to build it — TODO confirm
# against init_and_config.
options = dict(
    train_transforms=dict(
        type='AugmentationFactory',
        options=dict(
            train=True,
            augmentations=['resize', 'patch'],
            load_size=(128, 128),
            patch_size=(64, 64),
        ),
    ),
    test_transforms=dict(
        type='AugmentationFactory',
        options=dict(
            train=False,
            augmentations=['resize'],
            load_size=(256, 256),
        ),
    ),
    train_dataset=dict(
        type='CycleGANDataset',
        options=dict(download=True, dataset="horse2zebra", train=True),
    ),
    test_dataset=dict(
        type='CycleGANDataset',
        options=dict(download=True, dataset="horse2zebra", train=False),
    ),
    train_dataloader=dict(type='BaseDataLoader', options=dict(batch_size=10)),
    test_dataloader=dict(type='BaseDataLoader', options=dict(batch_size=1)),
)

cfg = Configuration(**options)

In [None]:
def ToPILImage(img):
    """Convert an image tensor, or a batch of them, to a list of PIL images.

    Parameters
    ----------
    img : torch.Tensor
        Either a batch (``img.ndim > 3``), where every ``img[i]`` is converted,
        or a single image (``img.ndim <= 3``). The single-image layout is
        presumably ``(C, H, W)`` or ``(H, W)`` as accepted by
        ``torchvision.transforms.ToPILImage`` — TODO confirm.

    Returns
    -------
    list of PIL.Image.Image
        One image per batch entry. A single input image yields a one-element
        list, so callers can always take ``[0]`` of the result.
    """
    to_pil = T.ToPILImage()
    # Move to host memory once, rather than once per batch entry.
    img = img.to('cpu')
    if img.ndim <= 3:
        # Single image: wrap it so the return type matches the batch case.
        return [to_pil(img)]
    # Iterating a tensor walks its first (batch) dimension.
    return [to_pil(single) for single in img]

In [None]:
def plot_CycleGAN_sample(imgs):
    """Plot images side by side in one row, with ticks and tick labels hidden.

    Parameters
    ----------
    imgs : sequence
        Images convertible with ``np.asarray`` (e.g. PIL images) and
        displayable by ``Axes.imshow``. Must be non-empty.

    Returns
    -------
    matplotlib.figure.Figure
        The new figure. ``fig.axes`` are in the same order as ``imgs``.

    Raises
    ------
    ValueError
        If ``imgs`` is empty.
    """
    if len(imgs) == 0:
        raise ValueError("plot_CycleGAN_sample needs at least one image")
    # squeeze=False always returns a 2-D array of Axes. With squeeze=True a
    # single image would yield a bare Axes object that cannot be indexed.
    fig, axs = plt.subplots(nrows=1, ncols=len(imgs), squeeze=False)
    for ax, img in zip(axs[0], imgs):
        ax.imshow(np.asarray(img))
        ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
    fig.tight_layout()
    return fig

In [None]:
# Instantiate the augmentation factories configured under the
# 'train_transforms' and 'test_transforms' sections of cfg (train first).
train_transform, test_transform = [
    init_and_config(DataModule, section, cfg)
    for section in ('train_transforms', 'test_transforms')
]

In [None]:
# Build each split's dataset, handing it the transform produced by that
# split's augmentation factory.
train_tf = train_transform.get_transform()
train_dataset = init_and_config(DataModule, 'train_dataset', cfg, transform=train_tf)

test_tf = test_transform.get_transform()
test_dataset = init_and_config(DataModule, 'test_dataset', cfg, transform=test_tf)

In [None]:
# Wrap each dataset in the data loader configured for its split in cfg.
train_dataloader, test_dataloader = [
    init_and_config(DataModule, section, cfg, dataset)
    for section, dataset in (('train_dataloader', train_dataset),
                             ('test_dataloader', test_dataset))
]

In [None]:
# Preview element [0] of the first few train/test batches for domains A and B.
# islice stops after N_PREVIEW_BATCHES without loading an extra, unused batch.
N_PREVIEW_BATCHES = 5
for train_sample, test_sample in islice(zip(train_dataloader, test_dataloader), N_PREVIEW_BATCHES):
    train_A, train_B = ToPILImage(train_sample['A'][0]), ToPILImage(train_sample['B'][0])
    test_A, test_B = ToPILImage(test_sample['A'][0]), ToPILImage(test_sample['B'][0])
    fig = plot_CycleGAN_sample([train_A[0], test_A[0], train_B[0], test_B[0]])
    # Label panels: the column order here interleaves splits, unlike later cells.
    for ax, title in zip(fig.axes, ['train A', 'test A', 'train B', 'test B']):
        ax.set_title(title)
    plt.show()
    

In [None]:
# K-fold view of the training set. Iterating it yields (fold, train_set, val_set)
# triples, as unpacked in the next cell. The number of folds comes from
# KFoldDataLoader's defaults — TODO confirm.
kloader = KFoldDataLoader(train_dataset)

In [None]:
# For every fold, preview element [0] of the first few train/validation batch
# pairs. islice avoids fetching an extra batch pair per fold just to break.
N_PREVIEW_PER_FOLD = 2
for fold, train_set, val_set in kloader:
    print(f"Fold: {fold}")
    for train_sample, val_sample in islice(zip(train_set, val_set), N_PREVIEW_PER_FOLD):
        train_A = ToPILImage(train_sample['A'][0])[0]
        train_B = ToPILImage(train_sample['B'][0])[0]
        val_A = ToPILImage(val_sample['A'][0])[0]
        val_B = ToPILImage(val_sample['B'][0])[0]
        fig = plot_CycleGAN_sample([train_A, train_B, val_A, val_B])
        for ax, title in zip(fig.axes, ['train A', 'train B', 'val A', 'val B']):
            ax.set_title(title)
        plt.show()

In [None]:
# Single train/validation split of the training set. Calling the loader returns
# two iterables, presumably the (train, validation) data loaders. The split
# settings come from OneFoldDataLoader's defaults — TODO confirm.
loader = OneFoldDataLoader(train_dataset)
train_dl, val_dl = loader()

In [None]:
# Preview element [0] of the first few batches of the one-fold training loader.
# islice stops without loading an extra, unused batch.
N_ONEFOLD_PREVIEW = 5
for sample in islice(train_dl, N_ONEFOLD_PREVIEW):
    A, B = ToPILImage(sample['A'][0])[0], ToPILImage(sample['B'][0])[0]
    fig = plot_CycleGAN_sample([A, B])
    for ax, title in zip(fig.axes, ['train A', 'train B']):
        ax.set_title(title)
    plt.show()