# Vanilla GAN Tutorial

Authors: [Artem Zolkin](https://github.com/arquestro)

[![Catalyst logo](https://raw.githubusercontent.com/catalyst-team/catalyst-pics/master/pics/catalyst_logo.png)](https://github.com/catalyst-team/catalyst)

In this tutorial we train a simple GAN on the MNIST dataset with the help of Catalyst.
Our goal is to obtain a generator that can create images of handwritten digits from a noise vector.

For training we can use GanRunner, which is provided in the catalyst.dl part of the library.

## Requirements
For this tutorial we need Albumentations and Catalyst version 20.2.1 or higher installed:

In [1]:
# Install the Catalyst version used in this tutorial
!pip install -U catalyst==20.2.1

# Install the latest version of Albumentations
!pip install -U albumentations

In [5]:
import torch
import catalyst
import os
from catalyst.dl import utils

print(f"torch: {torch.__version__}, catalyst: {catalyst.__version__}")
os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # "" - CPU, "0" - 1 GPU
SEED = 42
utils.set_global_seed(SEED)
utils.prepare_cudnn(deterministic=True)

## Models For GanRunner (made with PyTorch)
First, let's define our discriminator and generator models. We'll just need PyTorch and numpy for that:

In [0]:
import numpy
import numpy as np
import torch

### Generator Model
For the generator we'll use a simple fully connected architecture:

In [0]:
class SimpleGenerator(torch.nn.Module):
    def __init__(
        self,
        noise_dim=10,
        hidden_dim=256,
        image_resolution=(28, 28),
        channels=1
    ):
        super().__init__()
        self.noise_dim = noise_dim
        self.image_resolution = image_resolution
        self.channels = channels

        self.net = torch.nn.Sequential(
            torch.nn.Linear(noise_dim, hidden_dim), torch.nn.LeakyReLU(0.05),
            torch.nn.Linear(hidden_dim, hidden_dim), torch.nn.LeakyReLU(0.05),
            torch.nn.Linear(hidden_dim, hidden_dim), torch.nn.LeakyReLU(0.05),
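            # Output size must account for channels, since forward() reshapes to (channels, H, W)
            torch.nn.Linear(hidden_dim, channels * int(numpy.prod(image_resolution))), torch.nn.Tanh()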
        )

    def forward(self, x):
        x = self.net(x)
        x = x.reshape(x.size(0), self.channels, *self.image_resolution)
        return x

### Discriminator Model

And for the discriminator we'll use a similar fully connected architecture:

In [0]:
class SimpleDiscriminator(torch.nn.Module):
    def __init__(self, image_resolution=(28, 28), channels=1, hidden_dim=100):
        super().__init__()
        self.image_resolution = image_resolution
        self.channels = channels
        self.net = torch.nn.Sequential(
            torch.nn.Linear(channels * numpy.prod(image_resolution), hidden_dim),
            torch.nn.LeakyReLU(0.05), torch.nn.Linear(hidden_dim, hidden_dim),
            torch.nn.LeakyReLU(0.05), torch.nn.Linear(hidden_dim, hidden_dim),
            torch.nn.LeakyReLU(0.05), torch.nn.Linear(hidden_dim, hidden_dim),
            torch.nn.LeakyReLU(0.05), torch.nn.Linear(hidden_dim, 1)
        )

    def forward(self, x):
        x = self.net(x.reshape(x.size(0), -1))
        return x

### Initializing And Combining models

Now we need to initialize our models and combine them into a dictionary.
Pay attention to the discriminator_key and generator_key variables; we'll need them later. Don't forget about the noise dimension parameter, which must match the size of the noise tensors produced by the data transforms below.

In [0]:
noise_dim = 16
discriminator_key = 'discriminator'
discriminator = SimpleDiscriminator()
generator_key = 'generator'
generator = SimpleGenerator(noise_dim=noise_dim)
model = torch.nn.ModuleDict({
    discriminator_key: discriminator,
    generator_key: generator,
})
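
Before wiring the models into the runner, it can help to confirm that the shapes line up. The sketch below uses compact stand-in networks (hypothetical `TinyGenerator` and `TinyDiscriminator`, not the classes defined above) with the same input and output shapes: a noise vector of size `noise_dim` goes in, a `(batch, 1, 28, 28)` image comes out, and the discriminator returns one raw logit per image.

```python
import torch

noise_dim = 16  # same value as in the cell above


class TinyGenerator(torch.nn.Module):
    """Stand-in generator: noise vector -> (channels, H, W) image in [-1, 1]."""

    def __init__(self, noise_dim, channels=1, image_resolution=(28, 28)):
        super().__init__()
        self.channels = channels
        self.image_resolution = image_resolution
        out_features = channels * image_resolution[0] * image_resolution[1]
        self.net = torch.nn.Sequential(
            torch.nn.Linear(noise_dim, out_features), torch.nn.Tanh()
        )

    def forward(self, x):
        x = self.net(x)
        return x.reshape(x.size(0), self.channels, *self.image_resolution)


class TinyDiscriminator(torch.nn.Module):
    """Stand-in discriminator: image -> one raw logit (no sigmoid)."""

    def __init__(self, channels=1, image_resolution=(28, 28)):
        super().__init__()
        in_features = channels * image_resolution[0] * image_resolution[1]
        self.net = torch.nn.Linear(in_features, 1)

    def forward(self, x):
        return self.net(x.reshape(x.size(0), -1))


model = torch.nn.ModuleDict({
    "discriminator": TinyDiscriminator(),
    "generator": TinyGenerator(noise_dim),
})

# Push a small batch of noise through both models without tracking gradients
with torch.no_grad():
    noise = torch.randn(4, noise_dim)
    fake_images = model["generator"](noise)
    fake_logits = model["discriminator"](fake_images)

print(fake_images.shape)
print(fake_logits.shape)
```

The same check should work with SimpleGenerator and SimpleDiscriminator in place of the stand-ins.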

## Datasets and Data Loaders

Second, we define the datasets and data loaders that GanRunner will use for training and validation. We'll use the torchvision library to retrieve the MNIST dataset.

In [0]:
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
from collections import OrderedDict

### Datasets

#### Data Transforms

We need to provide our datasets with proper transforms. To deliver proper data to our GanRunner, the transforms have to add extra scalars (the real and fake targets) and a noise tensor to every sample. For this purpose, let's use Albumentations transforms and write a few custom ones.

In [0]:
# Custom transforms are from "mnist-gans" example, transforms.py
from typing import Tuple, Union
from albumentations.core.transforms_interface import BasicTransform, ImageOnlyTransform

class AsImage(ImageOnlyTransform):
    def __init__(self, always_apply=False, p=1.0):
        super().__init__(always_apply, p)

    def apply(self, img, **params):
        if img.ndim == 2:
            return np.expand_dims(img, axis=-1)
        return img

    def get_transform_init_args_names(self):
        return []


class AdditionalValue(BasicTransform):
    def __init__(self, output_key: str = None, **kwargs):
        self.name = "AdditionalValue"
        self.output_key = output_key

    def __call__(self, force_apply=False, **dict_):
        assert self.output_key not in dict_, \
            "Output key is supposed not to be present in dict"
        dict_[self.output_key] = self._compute_output(dict_)
        return dict_
    
    def _compute_output(self, dict_):
        raise NotImplementedError()


class AdditionalNoiseTensor(AdditionalValue):
    def __init__(self, tensor_size: Tuple[int, ...], output_key: str = None):
        super().__init__(output_key)
        self.tensor_size = tensor_size

    def _compute_output(self, dict_):
        return torch.randn(self.tensor_size)

    
class AdditionalScalar(AdditionalValue):
    def __init__(self, value: Union[int, float], output_key: str = None):
        super().__init__(output_key)
        self.value = value

    def _compute_output(self, dict_):
        return torch.tensor([self.value])

In [0]:
from albumentations import Compose
from albumentations.augmentations.transforms import Normalize
from albumentations.pytorch.transforms import ToTensorV2

noise_input = "noise"
real_targets = "real_targets"
fake_targets = "fake_targets"
data_transforms = Compose(transforms=[
    AsImage(),
    Normalize(mean=0.5, std=0.5),
    ToTensorV2(),
    AdditionalNoiseTensor(tensor_size=[noise_dim], output_key=noise_input),
    AdditionalScalar(value=1.0, output_key=real_targets),
    AdditionalScalar(value=0.0, output_key=fake_targets),
])
# Workaround "wrapper" to deliver transforms to torchvision dataset superclass
def transform(dict_):
    return data_transforms(**dict_)
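
The `Normalize(mean=0.5, std=0.5)` step matters because the generator ends in `Tanh`, so real images should live in the same `[-1, 1]` range as generated ones. Assuming the Albumentations default `max_pixel_value=255.0`, the transform computes `(pixel - mean * 255) / (std * 255)`. The hypothetical helper `normalize_pixel` below only reproduces that arithmetic in plain Python:

```python
def normalize_pixel(pixel, mean=0.5, std=0.5, max_pixel_value=255.0):
    """Per-pixel arithmetic of Normalize (assumed formula, see lead-in)."""
    return (pixel - mean * max_pixel_value) / (std * max_pixel_value)


black = normalize_pixel(0)      # darkest uint8 value
gray = normalize_pixel(127.5)   # midpoint of the uint8 range
white = normalize_pixel(255)    # brightest uint8 value

print(black, gray, white)  # -1.0 0.0 1.0
```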

#### Initialize Datasets
Let's wrap up the torchvision MNIST dataset and define datasets for future use in data loaders.

In [13]:
class MNIST(torchvision.datasets.MNIST):
    """
    MNIST dataset whose __getitem__ returns a key-value dict
    """
    def __init__(
        self,
        root,
        train=True,
        transform=None,
        target_transform=None,
        download=False,
        image_key="image",
        target_key="target"
    ):
        super().__init__(root, train, transform, target_transform, download)
        self.image_key = image_key
        self.target_key = target_key

    def __getitem__(self, index: int):
        """Get dataset element"""
        image, target = self.data[index], self.targets[index]
        dict_ = {
            self.image_key: image,
            self.target_key: target,
        }
        if self.transform is not None:
            dict_ = self.transform(dict_)
        return dict_


train_dataset_real = MNIST(
    root="./data",
    train=True,
    download=True,
    transform=transform,
)
valid_dataset = MNIST(
    root="./data",
    train=False,
    download=True,
    transform=transform,
)
datasets = {
    "train": train_dataset_real,
    "valid": valid_dataset,
}

### Defining loaders

Note that GanRunner only needs the data loaders; the datasets are not passed separately because the loaders already wrap them.

Also note that the name of a loader with training data must begin with "train" so that the Runner picks it up properly. The same rule applies to the validation loader, whose name must begin with "valid".

In [0]:
batch_size = 64
train_loader_real = DataLoader(train_dataset_real, batch_size=batch_size)
valid_loader_real = DataLoader(valid_dataset, batch_size=batch_size)
loaders = {
    "train": train_loader_real,
    "valid": valid_loader_real,
}
loaders = OrderedDict(loaders)
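
The naming rule can be checked before training starts. The helper below is a hypothetical illustration and not part of Catalyst; it only verifies that every loader name starts with `train` or `valid`, which is the convention described above. Placeholder objects stand in for the real data loaders.

```python
from collections import OrderedDict


def check_loader_names(loaders, prefixes=("train", "valid")):
    """Raise ValueError if any loader name lacks an expected prefix."""
    bad = [name for name in loaders if not name.startswith(prefixes)]
    if bad:
        raise ValueError(f"Unexpected loader names: {bad}")
    return True


# Placeholders instead of DataLoader instances; only the keys matter here
loaders = OrderedDict([("train", object()), ("valid", object())])
print(check_loader_names(loaders))  # True
```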

## Criterion
Let's define the criterion for our GanRunner.

In [0]:
from torch.nn.modules.loss import BCEWithLogitsLoss

criterion = BCEWithLogitsLoss()
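
The discriminator defined earlier ends in a plain `Linear` layer, so it outputs raw logits rather than probabilities; `BCEWithLogitsLoss` applies the sigmoid internally. For a single logit `x` and target `y`, the loss is `-(y * log(sigmoid(x)) + (1 - y) * log(1 - sigmoid(x)))`. As a rough illustration, the hypothetical function `bce_with_logits` below evaluates an algebraically equivalent, numerically stable form in plain Python:

```python
import math


def bce_with_logits(logit, target):
    """Binary cross-entropy on a raw logit, in a numerically stable form."""
    # Equivalent to -(y*log(sigmoid(x)) + (1-y)*log(1-sigmoid(x)))
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))


undecided = bce_with_logits(0.0, 1.0)        # logit 0 means probability 0.5
confident_right = bce_with_logits(5.0, 1.0)  # strongly "real", target is real
confident_wrong = bce_with_logits(5.0, 0.0)  # strongly "real", target is fake

print(undecided, confident_right, confident_wrong)
```

An undecided discriminator pays `log(2)`, a confident correct prediction pays almost nothing, and a confident wrong one is penalized roughly in proportion to the logit.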

## Optimizers
Next we define the optimizers for each of our models, discriminator and generator.

In [0]:
from torch.optim import Adam

lr = 0.0002
discriminator_optimizer = Adam(discriminator.parameters(), lr=lr)
generator_optimizer = Adam(generator.parameters(), lr=lr)
optimizer = {
    discriminator_key: discriminator_optimizer,
    generator_key: generator_optimizer,
}

## Runner state key-value arguments
The dictionary items ending in "phase" and "num" are needed for PhaseManagerCallback to work properly in the GanRunner. We also include the noise dimension so that the runner has access to it.

In [0]:
state_kwargs = {
    "discriminator_train_phase": "discriminator_train",
    "discriminator_train_num": 1,
    "generator_train_phase": "generator_train",
    "generator_train_num": 1,
    "noise_dim": noise_dim,
}
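
One way to picture these settings: training cycles through the named phases, spending `discriminator_train_num` consecutive batches in the discriminator phase and then `generator_train_num` batches in the generator phase. The generator function below (hypothetical `phase_schedule`) is a plain-Python sketch of that alternation under this assumption; it is not Catalyst's PhaseManagerCallback implementation. The counts are set to 2 and 1 here only to make the pattern visible.

```python
from itertools import cycle, islice


def phase_schedule(state_kwargs):
    """Yield the active phase name for each successive batch, forever."""
    phases = [
        (state_kwargs["discriminator_train_phase"],
         state_kwargs["discriminator_train_num"]),
        (state_kwargs["generator_train_phase"],
         state_kwargs["generator_train_num"]),
    ]
    for name, num_batches in cycle(phases):
        for _ in range(num_batches):
            yield name


example_kwargs = {
    "discriminator_train_phase": "discriminator_train",
    "discriminator_train_num": 2,  # illustrative value, the tutorial uses 1
    "generator_train_phase": "generator_train",
    "generator_train_num": 1,
}

first_batches = list(islice(phase_schedule(example_kwargs), 6))
print(first_batches)
```

With both counts set to 1, as in this tutorial, the two phases simply alternate batch by batch.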

## Callbacks
The following callbacks cover loss definition, optimization, and logging. For GanRunner to work properly, we also have to define which callbacks should later be converted to PhaseBatchWrapperCallback in GanExperiment. GanExperiment is run by GanRunner during training.

In [0]:
from catalyst.dl import CriterionCallback, OptimizerCallback, TensorboardLogger
from catalyst.dl.callbacks import CriterionAggregatorCallback

discriminator_loss_key = "loss_d"
generator_loss_key = "loss_g"
callbacks = OrderedDict({
    "loss_d_real": CriterionCallback(
        input_key="real_targets", 
        output_key="real_logits", 
        prefix="loss_d_real"
    ),
    "loss_d_fake": CriterionCallback(
        input_key="fake_targets", 
        output_key="fake_logits", 
        prefix="loss_d_fake"
    ),
    discriminator_loss_key: CriterionAggregatorCallback(
        loss_keys=["loss_d_real", "loss_d_fake"],
        loss_aggregate_fn="mean", prefix=discriminator_loss_key
    ),
    generator_loss_key: CriterionCallback(
        input_key="real_targets", 
        output_key="fake_logits",
        prefix=generator_loss_key
    ),
    "optim_d": OptimizerCallback(
        loss_key=discriminator_loss_key,
        optimizer_key=discriminator_key
    ),
    "optim_g": OptimizerCallback(
        loss_key=generator_loss_key,
        optimizer_key=generator_key
    ),
    "tensorboard": TensorboardLogger(),
})
# Preparation for callback wrapping. Wrapping is done later in GanExperiment, which is run by GanRunner
discriminator_phase_callbacks = ["loss_d_real", "loss_d_fake", discriminator_loss_key, "optim_d"]
generator_phase_callbacks = [generator_loss_key, "optim_g"]
phase2callbacks = {
    state_kwargs["discriminator_train_phase"]: discriminator_phase_callbacks,
    state_kwargs["generator_train_phase"]: generator_phase_callbacks,
}
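
To make the wiring concrete, the callbacks above amount to the following arithmetic on one batch. This is a sketch with dummy logits, assuming `real_logits` and `fake_logits` are the discriminator outputs for real and generated images (the names come from the `output_key` arguments); it calls the criterion directly rather than going through Catalyst.

```python
import math

import torch
from torch.nn import BCEWithLogitsLoss

criterion = BCEWithLogitsLoss()
batch_size = 4

# Dummy discriminator outputs: logit 0 means "undecided" (probability 0.5)
real_logits = torch.zeros(batch_size, 1)
fake_logits = torch.zeros(batch_size, 1)
real_targets = torch.ones(batch_size, 1)   # label 1.0 for real images
fake_targets = torch.zeros(batch_size, 1)  # label 0.0 for generated images

# Discriminator phase: real images should score 1, generated images 0
loss_d_real = criterion(real_logits, real_targets)
loss_d_fake = criterion(fake_logits, fake_targets)
loss_d = torch.stack([loss_d_real, loss_d_fake]).mean()

# Generator phase: generated images are scored against the *real* targets
loss_g = criterion(fake_logits, real_targets)

print(loss_d.item(), loss_g.item(), math.log(2))
```

Scoring generated images against the real targets is the common non-saturating form of the generator loss: the generator is rewarded when the discriminator labels its samples as real.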

## GanRunner
Define other training parameters and initialize GanRunner.

In [0]:
logdir = "./logs"
num_epochs = 100
main_metric = generator_loss_key
verbose = True
check = False

In [0]:
from catalyst.dl.runner import GanRunner

real_data_key = "image"
runner = GanRunner(
    data_input_key=real_data_key,
    discriminator_model_key=discriminator_key,
    generator_model_key=generator_key,
)

## Start the GAN training
Finally, let's run the training!

In [21]:
runner.train(
    model=model,
    loaders=loaders,
    callbacks=callbacks,
    logdir=logdir,
    criterion=criterion,
    optimizer=optimizer,
    num_epochs=num_epochs,
    main_metric=main_metric,
    state_kwargs=state_kwargs,
    phase2callbacks=phase2callbacks,
    verbose=verbose,
    check=check,
)

## Visualize the results

In [28]:
# Courtesy of Caffe's filter visualization example
# http://nbviewer.jupyter.org/github/BVLC/caffe/blob/master/examples/00-classification.ipynb
import matplotlib.pyplot as plt

def imshow_grid(data, height=None, width=None, normalize=False, padsize=1, padval=0):
    '''
    Take an array of shape (N, H, W) or (N, H, W, C)
    and visualize each (H, W) image in a grid style (height x width).
    '''
    if normalize:
        data -= data.min()
        data /= data.max()

    N = data.shape[0]
    if height is None:
        if width is None:
            height = int(np.ceil(np.sqrt(N)))
        else:
            height = int(np.ceil( N / float(width) ))

    if width is None:
        width = int(np.ceil( N / float(height) ))

    assert height * width >= N

    # append padding
    padding = ((0, (width*height) - data.shape[0]), (0, padsize), (0, padsize)) + ((0, 0),) * (data.ndim - 3)
    data = np.pad(data, padding, mode='constant', constant_values=(padval, padval))

    # tile the filters into an image
    data = data.reshape((height, width) + data.shape[1:]).transpose((0, 2, 1, 3) + tuple(range(4, data.ndim + 1)))
    data = data.reshape((height * data.shape[1], width * data.shape[3]) + data.shape[4:])

    plt.imshow(data)


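# Sample noise on whichever device the trained generator lives on (CPU or GPU)
trained_generator = runner.model[generator_key]
device = next(trained_generator.parameters()).device
noise_sample = torch.randn((25, noise_dim)).to(device)
result = trained_generator(noise_sample)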
print(result.shape)
result = result.permute((0, 2, 3, 1)).cpu().detach().numpy()
result = numpy.repeat(result, repeats=3,axis=3)
imshow_grid(result)
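
Because the generator ends in `Tanh`, its outputs lie in `[-1, 1]`, while `plt.imshow` expects float RGB data in `[0, 1]` and clips anything outside that range. A possible way to handle this, sketched here with random stand-in data instead of real generator output, is to undo the `Normalize(mean=0.5, std=0.5)` step before plotting (hypothetical helper `to_display_range`):

```python
import numpy as np


def to_display_range(images):
    """Map values from [-1, 1] to [0, 1] (inverse of mean=0.5, std=0.5)."""
    return np.clip(images * 0.5 + 0.5, 0.0, 1.0)


# Stand-in for generator output: 25 RGB images with values in (-1, 1)
rng = np.random.default_rng(0)
fake_batch = np.tanh(rng.normal(size=(25, 28, 28, 3))).astype(np.float32)

display_batch = to_display_range(fake_batch)
print(display_batch.shape, display_batch.dtype)
```

The rescaled array could then be passed to `imshow_grid` in place of `result`. Alternatively, `imshow_grid(result, normalize=True)` rescales by the minimum and maximum of the batch.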