### General Settings

Update the settings below (directory paths, GPU usage, CPU count) to match your machine before running the notebook.

Use `limit_train_batches`, `limit_val_batches`, `limit_test_batches` as required

In [1]:
# Machine-specific configuration: every path below is an absolute location on the author's machine and must be
# edited before running elsewhere. All directory strings end with '/' because they are joined by plain concatenation.
project_dir = '/Users/rajjain/PycharmProjects/ADRL-Course-Work/'
data_dir = project_dir + 'data/'  # 'data/' folder inside the project directory
# Dataset roots. CelebA and Bitmoji are read by CBCycleWGAN's dataloaders.
celeba_data_dir = '/Users/rajjain/Desktop/CourseWork/CelebA/'
bitmoji_data_dir = '/Users/rajjain/Desktop/CourseWork/Bitmoji/'
# NOTE(review): presumably read by the SVHN - MNIST model's dataloaders - confirm against the rest of the notebook
mnist_data_dir = '/Users/rajjain/Desktop/CourseWork/MNIST/'
svhn_data_dir = '/Users/rajjain/Desktop/CourseWork/SVHN/'
# NOTE(review): presumably consumed when the Trainer is configured later in the notebook - confirm
use_gpu = False
num_cpus = 2

### Imports

In [2]:
from torch.nn import init, Linear, Sequential, Conv2d, PReLU, Module, BatchNorm2d, ConvTranspose2d, Hardtanh, \
    Flatten, MaxPool2d, L1Loss
from torchvision.datasets import CelebA, ImageFolder, SVHN, MNIST
from pytorch_lightning.trainer.supporters import CombinedLoader
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning import LightningModule
from torch.utils.data import DataLoader
from pytorch_lightning import Trainer
from torchvision import transforms
from torch.optim import RMSprop
from datetime import datetime
from torchinfo import summary
from glob import glob
import itertools
import pandas
import torch
import gc
import os

import matplotlib.pyplot as plt

### Common functions

In [3]:
def custom_collate_fn(batch):
    """
    Collate samples into a one-element list that holds the stacked first items (the images).

    Anything else a sample carries after its first item (e.g. a label or target) is dropped.
    """
    first_items = tuple(sample[0] for sample in batch)
    stacked = torch.stack(first_items)
    return [stacked]


# Cycle WGAN for CelebA - Bitmoji

## Model

In [4]:
class C2BGenerator(Module):
    """
    Generator that maps a CelebA image in [-1, 1] to a "fake" Bitmoji image in [-1, 1]
    """
    def __init__(self):
        super().__init__()

        def stage(conv_cls, c_in, c_out, kernel, stride, **extra):
            # One (transposed) convolution followed by a per-channel PReLU and a batch norm
            return [
                conv_cls(in_channels=c_in, out_channels=c_out, kernel_size=kernel, stride=stride, **extra),
                PReLU(num_parameters=c_out, init=0.25),
                BatchNorm2d(num_features=c_out),
            ]

        # Layers are created strictly in network order so that the default weight initialisation consumes the
        # random number generator in the same sequence as an explicit layer-by-layer listing would
        layers = [
            # Downsampling part
            *stage(Conv2d, 3, 6, (3, 3), (1, 1), padding='same'),
            *stage(Conv2d, 6, 8, (4, 3), (1, 1)),
            *stage(Conv2d, 8, 10, (6, 4), (1, 1)),
            *stage(Conv2d, 10, 12, (10, 4), (2, 2)),
            *stage(Conv2d, 12, 14, (15, 6), (3, 3)),

            # Upsampling part
            *stage(ConvTranspose2d, 14, 12, (4, 8), (2, 2)),
            *stage(ConvTranspose2d, 12, 10, (4, 4), (2, 2)),
            *stage(ConvTranspose2d, 10, 8, (3, 3), (1, 1)),
            *stage(ConvTranspose2d, 8, 6, (3, 3), (1, 1)),

            # Output layer: no PReLU / batch norm, values are clipped to [-1, 1]
            ConvTranspose2d(in_channels=6, out_channels=3, kernel_size=(3, 3), stride=(1, 1)),
            Hardtanh(min_val=-1, max_val=1),
        ]
        self.model = Sequential(*layers)

    def initialise(self):
        """
        Kaiming-normal init for every convolution that feeds a PReLU, Xavier-normal init for the output convolution
        """
        conv_layers = [layer for layer in self.model if isinstance(layer, (Conv2d, ConvTranspose2d))]
        *hidden_convs, output_conv = conv_layers
        for conv in hidden_convs:
            # a=0.25 is the same value the PReLU slopes are initialised with
            init.kaiming_normal_(conv.weight, a=0.25, nonlinearity='leaky_relu')
        init.xavier_normal_(output_conv.weight)

    def forward(self, x):
        """Run ``x`` through the sequential model and return its output"""
        generated = self.model(x)
        return generated


class B2CGenerator(Module):
    """
    Generator that maps a Bitmoji image in [-1, 1] to a "fake" CelebA image in [-1, 1]
    """
    def __init__(self):
        super().__init__()

        def stage(conv_cls, c_in, c_out, kernel, stride, **extra):
            # One (transposed) convolution followed by a per-channel PReLU and a batch norm
            return [
                conv_cls(in_channels=c_in, out_channels=c_out, kernel_size=kernel, stride=stride, **extra),
                PReLU(num_parameters=c_out, init=0.25),
                BatchNorm2d(num_features=c_out),
            ]

        # Layers are created strictly in network order so that the default weight initialisation consumes the
        # random number generator in the same sequence as an explicit layer-by-layer listing would
        layers = [
            # Downsampling part
            *stage(Conv2d, 3, 6, (3, 3), (1, 1), padding='same'),
            *stage(Conv2d, 6, 8, (3, 3), (1, 1)),
            *stage(Conv2d, 8, 10, (3, 3), (1, 1)),
            *stage(Conv2d, 10, 12, (4, 4), (2, 2)),
            *stage(Conv2d, 12, 14, (4, 8), (2, 2)),

            # Upsampling part
            *stage(ConvTranspose2d, 14, 12, (15, 6), (3, 3), output_padding=(1, )),
            *stage(ConvTranspose2d, 12, 10, (10, 4), (2, 2)),
            *stage(ConvTranspose2d, 10, 8, (6, 4), (1, 1)),
            *stage(ConvTranspose2d, 8, 6, (4, 3), (1, 1)),

            # Output layer: no PReLU / batch norm, values are clipped to [-1, 1]
            ConvTranspose2d(in_channels=6, out_channels=3, kernel_size=(3, 2), stride=(1, 1)),
            Hardtanh(min_val=-1, max_val=1),
        ]
        self.model = Sequential(*layers)

    def initialise(self):
        """
        Kaiming-normal init for every convolution that feeds a PReLU, Xavier-normal init for the output convolution
        """
        conv_layers = [layer for layer in self.model if isinstance(layer, (Conv2d, ConvTranspose2d))]
        *hidden_convs, output_conv = conv_layers
        for conv in hidden_convs:
            # a=0.25 is the same value the PReLU slopes are initialised with
            init.kaiming_normal_(conv.weight, a=0.25, nonlinearity='leaky_relu')
        init.xavier_normal_(output_conv.weight)

    def forward(self, x):
        """Run ``x`` through the sequential model and return its output"""
        generated = self.model(x)
        return generated


class CriticB(Module):
    """
    Critic for Bitmoji images - scores whether a bitmoji image comes from the "real" bitmoji data distribution or
    from the "fake" C2B generator
    """
    def __init__(self):
        super().__init__()

        def block(c_in, c_out, kernel, stride, **extra):
            # Convolution -> per-channel PReLU -> 3x3 max pooling with stride 2 -> batch norm
            return [
                Conv2d(in_channels=c_in, out_channels=c_out, kernel_size=kernel, stride=stride, **extra),
                PReLU(num_parameters=c_out, init=0.25),
                MaxPool2d(kernel_size=(3, 3), stride=(2, 2)),
                BatchNorm2d(num_features=c_out),
            ]

        self.model = Sequential(
            *block(3, 6, (3, 3), (1, 1), padding='same'),
            *block(6, 9, (3, 3), (1, 1)),
            *block(9, 12, (4, 4), (1, 1)),
            *block(12, 15, (4, 4), (2, 2)),

            Flatten(),

            # A single unbounded score per image
            Linear(in_features=60, out_features=1),
        )

    def initialise(self):
        """Kaiming-normal init for the convolutions, Xavier-normal init for the final linear layer"""
        for layer in self.model:
            if isinstance(layer, Conv2d):
                # a=0.25 is the same value the PReLU slopes are initialised with
                init.kaiming_normal_(layer.weight, a=0.25, nonlinearity='leaky_relu')
            elif isinstance(layer, Linear):
                init.xavier_normal_(layer.weight)

    def forward(self, x):
        """Run ``x`` through the sequential model and return the resulting scores"""
        scores = self.model(x)
        return scores


class CriticC(Module):
    """
    Critic for CelebA images - scores whether a celeba image comes from the "real" celeba data distribution or
    from the "fake" B2C generator
    """
    def __init__(self):
        super().__init__()

        def block(c_in, c_out, kernel, stride, **extra):
            # Convolution -> per-channel PReLU -> 3x3 max pooling with stride 2 -> batch norm
            return [
                Conv2d(in_channels=c_in, out_channels=c_out, kernel_size=kernel, stride=stride, **extra),
                PReLU(num_parameters=c_out, init=0.25),
                MaxPool2d(kernel_size=(3, 3), stride=(2, 2)),
                BatchNorm2d(num_features=c_out),
            ]

        self.model = Sequential(
            *block(3, 6, (3, 3), (1, 1), padding='same'),
            *block(6, 9, (3, 3), (1, 1)),
            *block(9, 12, (4, 4), (1, 1)),
            *block(12, 15, (4, 4), (2, 2)),

            Flatten(),

            # A single unbounded score per image
            Linear(in_features=225, out_features=1),
        )

    def initialise(self):
        """Kaiming-normal init for the convolutions, Xavier-normal init for the final linear layer"""
        for layer in self.model:
            if isinstance(layer, Conv2d):
                # a=0.25 is the same value the PReLU slopes are initialised with
                init.kaiming_normal_(layer.weight, a=0.25, nonlinearity='leaky_relu')
            elif isinstance(layer, Linear):
                init.xavier_normal_(layer.weight)

    def forward(self, x):
        """Run ``x`` through the sequential model and return the resulting scores"""
        scores = self.model(x)
        return scores


class CBCycleWGAN(LightningModule):
    """
    Cycle-consistent Wasserstein GAN translating between CelebA and Bitmoji images.

    Two generators (CelebA -> Bitmoji and Bitmoji -> CelebA) are trained against two critics with the WGAN
    objective (RMSprop optimizers, critic weight clamping) plus an L1 cycle-consistency term.

    Args:
        ncritic: "frequency" assigned to the critic optimizer in ``configure_optimizers``.
        ngen: "frequency" assigned to the generator optimizer in ``configure_optimizers``.
        bs: batch size used by every dataloader.
        cycle_weight: factor applied to the sum of the two cycle-consistency L1 losses.
    """

    def __init__(self, ncritic, ngen, bs, cycle_weight):
        super(CBCycleWGAN, self).__init__()
        self.save_hyperparameters()
        self.ncritic = ncritic
        self.ngen = ngen
        self.bs = bs
        self.cycle_weight = cycle_weight

        # CelebA to Bitmoji
        self.genC2B = C2BGenerator()
        self.criticB = CriticB()

        # Bitmoji to CelebA
        self.genB2C = B2CGenerator()
        self.criticC = CriticC()

        # Initialisations
        self.genC2B.initialise()
        self.criticB.initialise()
        self.genB2C.initialise()
        self.criticC.initialise()

        # CycleGAN authors use image pool to update the discriminator. That was required because GAN training (on the
        # usual objective) was unstable. We are using WGAN and hopefully won't get into such issues. Hence, skipping
        # keeping the pool of images

        self.float()

    def forward(self, real_celebas, real_bitmojis):
        """
        Score the real images with their critics and translate them with the generators.

        Returns:
            Tuple ``(real_celeba_scores, real_bitmoji_scores, fake_celebas, fake_bitmojis)``.
        """
        real_bitmoji_scores = self.criticB(real_bitmojis)
        real_celeba_scores = self.criticC(real_celebas)
        fake_bitmojis = self.genC2B(real_celebas)
        fake_celebas = self.genB2C(real_bitmojis)
        return real_celeba_scores, real_bitmoji_scores, fake_celebas, fake_bitmojis

    def _log_epoch_mean(self, btype, name, value):
        """Log ``value`` under ``'{btype}/{name}'`` as an epoch-level mean; synced across processes outside training."""
        self.log(f'{btype}/{name}', value, on_step=False, on_epoch=True, reduce_fx=torch.mean,
                 sync_dist=btype != 'train')

    def _critic_losses(self, real_celebas, real_bitmojis, btype):
        """
        Compute and log the Wasserstein critic loss of both critics.

        Args:
            real_celebas: batch of real CelebA images.
            real_bitmojis: batch of real Bitmoji images.
            btype: prefix of the logged metric names ('train', 'val' or 'test').

        Returns:
            Tuple ``(criticB_loss, criticC_loss)``.
        """
        # WGAN: Critic gets updated from the fake and real data
        # CycleGAN: Need to do this for both critics!

        # Bitmoji Critic
        real_bitmojis_score = self.criticB(real_bitmojis).mean()
        fake_bitmojis = self.genC2B(real_celebas)
        fake_bitmojis_score = self.criticB(fake_bitmojis).mean()
        criticB_loss = fake_bitmojis_score - real_bitmojis_score  # minimise this!

        # CelebA Critic
        real_celebas_score = self.criticC(real_celebas).mean()
        fake_celebas = self.genB2C(real_bitmojis)
        fake_celebas_score = self.criticC(fake_celebas).mean()
        criticC_loss = fake_celebas_score - real_celebas_score

        self._log_epoch_mean(btype, 'criticB_loss', criticB_loss)
        self._log_epoch_mean(btype, 'criticC_loss', criticC_loss)
        return criticB_loss, criticC_loss

    def _generator_losses(self, real_celebas, real_bitmojis, btype):
        """
        Compute and log the Wasserstein generator losses and the weighted cycle-consistency loss.

        Args:
            real_celebas: batch of real CelebA images.
            real_bitmojis: batch of real Bitmoji images.
            btype: prefix of the logged metric names ('train', 'val' or 'test').

        Returns:
            Tuple ``(genC2B_loss, genB2C_loss, cycle_loss)`` where ``cycle_loss`` already includes ``cycle_weight``.
        """
        # WGAN: Generator gets updates from the fake data
        # CycleGAN: Do this for both generators and additionally put cycle consistency loss

        # CelebA to Bitmoji
        fake_bitmojis = self.genC2B(real_celebas)
        fake_bitmojis_score = self.criticB(fake_bitmojis).mean()
        genC2B_loss = -fake_bitmojis_score  # minimise this!

        # Bitmoji to CelebA
        fake_celebas = self.genB2C(real_bitmojis)
        fake_celebas_score = self.criticC(fake_celebas).mean()
        genB2C_loss = -fake_celebas_score

        # Cycle Consistency
        # Side Note: *Ideally* L1 norm should be added across dimensions and mean-ed across samples. In their
        # implementation, authors have mean-ed across dimensions too, keeping same implementation as them
        celeba_identity_loss = L1Loss()(real_celebas, self.genB2C(fake_bitmojis))
        bitmoji_identity_loss = L1Loss()(real_bitmojis, self.genC2B(fake_celebas))
        cycle_loss = self.cycle_weight * (celeba_identity_loss + bitmoji_identity_loss)

        self._log_epoch_mean(btype, 'genC2B_loss', genC2B_loss)
        self._log_epoch_mean(btype, 'genB2C_loss', genB2C_loss)
        self._log_epoch_mean(btype, 'cycle_loss', cycle_loss)
        self._log_epoch_mean(btype, 'celeba_identity_loss', celeba_identity_loss)
        self._log_epoch_mean(btype, 'bitmoji_identity_loss', bitmoji_identity_loss)

        return genC2B_loss, genB2C_loss, cycle_loss

    def training_step(self, batch, batch_idx, optimizer_idx):
        """
        Return the loss for the optimizer that is currently being stepped.

        ``optimizer_idx`` follows the order of ``configure_optimizers``: 0 is the generator optimizer, 1 is the
        critic optimizer.

        Raises:
            ValueError: if ``optimizer_idx`` is neither 0 nor 1.
        """
        real_celebas, real_bitmojis = batch['celeba'][0], batch['bitmoji'][0]

        if optimizer_idx == 0:  # Generator optimizer - only update Generator weights
            genC2B_loss, genB2C_loss, cycle_loss = self._generator_losses(real_celebas, real_bitmojis, 'train')
            return genC2B_loss + genB2C_loss + cycle_loss

        if optimizer_idx == 1:  # Critic optimizer - only update Critic weights
            criticB_loss, criticC_loss = self._critic_losses(real_celebas, real_bitmojis, 'train')
            # CycleGAN authors divide this loss by 2 to slow down the rate of critic learning. Here, as the loss is
            # wasserstein loss, I believe it may not be required
            return criticB_loss + criticC_loss

        # Is there a way to consolidate the losses and return one per epoch?
        raise ValueError(f'Unknown optimizer index: {optimizer_idx}')

    def _shared_eval(self, batch, btype):
        """Compute every critic and generator loss for ``batch`` and log their sum as ``'{btype}/loss'``."""
        real_celebas, real_bitmojis = batch['celeba'][0], batch['bitmoji'][0]
        criticB_loss, criticC_loss = self._critic_losses(real_celebas, real_bitmojis, btype)
        genC2B_loss, genB2C_loss, cycle_loss = self._generator_losses(real_celebas, real_bitmojis, btype)
        total_loss = criticB_loss + criticC_loss + genC2B_loss + genB2C_loss + cycle_loss
        self._log_epoch_mean(btype, 'loss', total_loss)

    def validation_step(self, batch, batch_idx):
        """Log all losses for one validation batch under the 'val' prefix."""
        self._shared_eval(batch, 'val')

    def test_step(self, batch, batch_idx):
        """Log all losses for one test batch under the 'test' prefix."""
        self._shared_eval(batch, 'test')

    def configure_optimizers(self):
        """Using the strategy from WGAN paper instead of CycleGAN paper!"""
        generator_opt = RMSprop(params=itertools.chain(self.genC2B.parameters(), self.genB2C.parameters()),
                                lr=0.00005)
        critic_opt = RMSprop(params=itertools.chain(self.criticC.parameters(), self.criticB.parameters()),
                             lr=0.00005)
        # Order matters: training_step treats index 0 as the generator optimizer and index 1 as the critic optimizer
        return (
            {"optimizer": generator_opt, "frequency": self.ngen},
            {"optimizer": critic_opt, "frequency": self.ncritic},
        )

    def on_train_batch_end(self, *args):
        """After weights updating by the optimisers, clamp the weights"""
        # Every critic parameter is clamped to [-0.01, 0.01]; generator parameters are left untouched
        for weight in itertools.chain(self.criticC.parameters(), self.criticB.parameters()):
            weight.data.clamp_(-0.01, 0.01)

    @staticmethod
    def _build_transform(augment):
        """
        Build the image preprocessing pipeline.

        Images are converted to tensors and normalised with mean 0.5 and std 0.5 per channel. When ``augment`` is
        true a random horizontal flip (p=0.2) is applied first.
        """
        steps = [
            transforms.ToTensor(),  # Gives a scaled version i.e., 0 to 1
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
        if augment:
            steps.insert(0, transforms.RandomHorizontalFlip(p=0.2))
        return transforms.Compose(steps)

    def _build_loaders(self, celeba_split, bitmoji_subdir, train):
        """
        Create the CelebA and Bitmoji dataloaders for one data split.

        Args:
            celeba_split: value passed as ``split`` to the ``CelebA`` dataset.
            bitmoji_subdir: sub-folder of ``bitmoji_data_dir`` handed to ``ImageFolder`` (with trailing '/').
            train: when true, the loaders shuffle and the flip augmentation is enabled.

        Returns:
            Dict with the keys 'celeba' and 'bitmoji' mapping to the respective ``DataLoader``.
        """
        celeba_dataset = CelebA(celeba_data_dir, split=celeba_split, target_type=[],
                                transform=self._build_transform(train))
        celeba_dataloader = DataLoader(celeba_dataset, self.bs, shuffle=train, num_workers=0,
                                       collate_fn=custom_collate_fn)
        bitmoji_dataset = ImageFolder(bitmoji_data_dir + bitmoji_subdir,
                                      transform=self._build_transform(train))
        bitmoji_dataloader = DataLoader(bitmoji_dataset, self.bs, shuffle=train, num_workers=0,
                                        collate_fn=custom_collate_fn)
        return {
            'celeba': celeba_dataloader,
            'bitmoji': bitmoji_dataloader,
        }

    def train_dataloader(self):
        """Return the shuffled, augmented training loaders as a dict keyed by 'celeba' and 'bitmoji'."""
        return self._build_loaders('train', 'train/', train=True)

    def val_dataloader(self):
        """Return the validation loaders wrapped in a ``CombinedLoader`` with mode 'max_size_cycle'."""
        return CombinedLoader(self._build_loaders('valid', 'val/', train=False), mode='max_size_cycle')

    def test_dataloader(self):
        """Return the test loaders wrapped in a ``CombinedLoader`` with mode 'max_size_cycle'."""
        return CombinedLoader(self._build_loaders('test', 'test/', train=False), mode='max_size_cycle')


In [5]:
# Options shared by every torchinfo summary printed in this cell (columns shown, nesting depth, CPU device)
summary_kwargs = dict(dtypes=[torch.float], depth=3, col_names=['input_size', 'output_size', 'num_params'],
                      row_settings=['depth', 'var_names'], verbose=0, device=torch.device('cpu'))
ns = 10  # Number of samples in each dummy batch

# Random dummy batch with the 3 x 218 x 178 shape used for CelebA inputs
celeba = torch.randn((ns, 3, 218, 178), dtype=torch.float)
summary_string = str(summary(model=C2BGenerator(), input_data=celeba, **summary_kwargs))
print(summary_string)

summary_string = str(summary(model=CriticC(), input_data=celeba, **summary_kwargs))
print(summary_string)

# Random dummy batch with the 3 x 128 x 128 shape used for Bitmoji inputs
bitmoji = torch.randn((ns, 3, 128, 128), dtype=torch.float)
summary_string = str(summary(model=B2CGenerator(), input_data=bitmoji, **summary_kwargs))
print(summary_string)

summary_string = str(summary(model=CriticB(), input_data=bitmoji, **summary_kwargs))
print(summary_string)


Layer (type (var_name):depth-idx)        Input Shape               Output Shape              Param #
C2BGenerator (C2BGenerator)              [10, 3, 218, 178]         [10, 3, 128, 128]         --
├─Sequential (model): 1-1                [10, 3, 218, 178]         [10, 3, 128, 128]         --
│    └─Conv2d (0): 2-1                   [10, 3, 218, 178]         [10, 6, 218, 178]         168
│    └─PReLU (1): 2-2                    [10, 6, 218, 178]         [10, 6, 218, 178]         6
│    └─BatchNorm2d (2): 2-3              [10, 6, 218, 178]         [10, 6, 218, 178]         12
│    └─Conv2d (3): 2-4                   [10, 6, 218, 178]         [10, 8, 215, 176]         584
│    └─PReLU (4): 2-5                    [10, 8, 215, 176]         [10, 8, 215, 176]         8
│    └─BatchNorm2d (5): 2-6              [10, 8, 215, 176]         [10, 8, 215, 176]         16
│    └─Conv2d (6): 2-7                   [10, 8, 215, 176]         [10, 10, 210, 173]        1,930
│    └─PReLU (7): 2-8           

# Cycle WGAN for SVHN - MNIST

## Model

In [6]:
class S2MGenerator(Module):
    """
    Generator that maps an SVHN image in [-1, 1] to a "fake" MNIST image in [-1, 1]
    """
    def __init__(self):
        super().__init__()

        def stage(conv_cls, c_in, c_out, kernel, stride, **extra):
            # One (transposed) convolution followed by a per-channel PReLU and a batch norm
            return [
                conv_cls(in_channels=c_in, out_channels=c_out, kernel_size=kernel, stride=stride, **extra),
                PReLU(num_parameters=c_out, init=0.25),
                BatchNorm2d(num_features=c_out),
            ]

        # Layers are created strictly in network order so that the default weight initialisation consumes the
        # random number generator in the same sequence as an explicit layer-by-layer listing would
        layers = [
            # Downsampling part
            *stage(Conv2d, 3, 6, (3, 3), (1, 1), padding='same'),
            *stage(Conv2d, 6, 8, (3, 3), (1, 1)),
            *stage(Conv2d, 8, 10, (4, 4), (1, 1)),
            *stage(Conv2d, 10, 12, (4, 4), (2, 2)),

            # Upsampling part
            *stage(ConvTranspose2d, 12, 10, (3, 3), (1, 1)),
            *stage(ConvTranspose2d, 10, 8, (3, 3), (1, 1)),
            *stage(ConvTranspose2d, 8, 6, (4, 4), (1, 1)),
            *stage(ConvTranspose2d, 6, 5, (4, 4), (1, 1)),
            *stage(ConvTranspose2d, 5, 3, (4, 4), (1, 1)),

            # Output layer: single channel, no PReLU / batch norm, values are clipped to [-1, 1]
            ConvTranspose2d(in_channels=3, out_channels=1, kernel_size=(4, 4), stride=(1, 1)),
            Hardtanh(min_val=-1, max_val=1),
        ]
        self.model = Sequential(*layers)

    def initialise(self):
        """
        Kaiming-normal init for every convolution that feeds a PReLU, Xavier-normal init for the output convolution
        """
        conv_layers = [layer for layer in self.model if isinstance(layer, (Conv2d, ConvTranspose2d))]
        *hidden_convs, output_conv = conv_layers
        for conv in hidden_convs:
            # a=0.25 is the same value the PReLU slopes are initialised with
            init.kaiming_normal_(conv.weight, a=0.25, nonlinearity='leaky_relu')
        init.xavier_normal_(output_conv.weight)

    def forward(self, x):
        """Run ``x`` through the sequential model and return its output"""
        generated = self.model(x)
        return generated


class M2SGenerator(Module):
    """
    Generator that maps an MNIST image in [-1, 1] to a "fake" SVHN image in [-1, 1]
    """
    def __init__(self):
        super().__init__()

        def stage(conv_cls, c_in, c_out, kernel, stride, **extra):
            # One (transposed) convolution followed by a per-channel PReLU and a batch norm
            return [
                conv_cls(in_channels=c_in, out_channels=c_out, kernel_size=kernel, stride=stride, **extra),
                PReLU(num_parameters=c_out, init=0.25),
                BatchNorm2d(num_features=c_out),
            ]

        # Layers are created strictly in network order so that the default weight initialisation consumes the
        # random number generator in the same sequence as an explicit layer-by-layer listing would
        layers = [
            # Downsampling part
            *stage(Conv2d, 1, 3, (3, 3), (1, 1), padding='same'),
            *stage(Conv2d, 3, 6, (3, 3), (1, 1)),
            *stage(Conv2d, 6, 8, (3, 3), (1, 1)),
            *stage(Conv2d, 8, 10, (3, 3), (1, 1)),
            *stage(Conv2d, 10, 12, (4, 4), (2, 2)),

            # Upsampling part
            *stage(ConvTranspose2d, 12, 10, (4, 4), (2, 2), output_padding=1),
            *stage(ConvTranspose2d, 10, 10, (4, 4), (1, 1)),
            *stage(ConvTranspose2d, 10, 8, (3, 3), (1, 1)),
            *stage(ConvTranspose2d, 8, 6, (3, 3), (1, 1)),

            # Output layer: three channels, no PReLU / batch norm, values are clipped to [-1, 1]
            ConvTranspose2d(in_channels=6, out_channels=3, kernel_size=(3, 3), stride=(1, 1)),
            Hardtanh(min_val=-1, max_val=1),
        ]
        self.model = Sequential(*layers)

    def initialise(self):
        """
        Kaiming-normal init for every convolution that feeds a PReLU, Xavier-normal init for the output convolution
        """
        conv_layers = [layer for layer in self.model if isinstance(layer, (Conv2d, ConvTranspose2d))]
        *hidden_convs, output_conv = conv_layers
        for conv in hidden_convs:
            # a=0.25 is the same value the PReLU slopes are initialised with
            init.kaiming_normal_(conv.weight, a=0.25, nonlinearity='leaky_relu')
        init.xavier_normal_(output_conv.weight)

    def forward(self, x):
        """Run ``x`` through the sequential model and return its output"""
        generated = self.model(x)
        return generated


class CriticM(Module):
    """
    Critic for MNIST images - scores whether an mnist image comes from the "real" mnist data distribution or from
    the "fake" S2M generator
    """
    def __init__(self):
        super().__init__()

        def block(c_in, c_out, pooled=True, **extra):
            # 3x3 convolution -> per-channel PReLU -> (optional 3x3 max pooling with stride 2) -> batch norm
            conv = Conv2d(in_channels=c_in, out_channels=c_out, kernel_size=(3, 3), stride=(1, 1), **extra)
            activation = PReLU(num_parameters=c_out, init=0.25)
            norm = BatchNorm2d(num_features=c_out)
            if not pooled:
                return [conv, activation, norm]
            return [conv, activation, MaxPool2d(kernel_size=(3, 3), stride=(2, 2)), norm]

        self.model = Sequential(
            *block(1, 3, padding='same'),
            *block(3, 6),
            *block(6, 9, pooled=False),

            Flatten(),

            # A single unbounded score per image
            Linear(in_features=81, out_features=1),
        )

    def initialise(self):
        """Kaiming-normal init for the convolutions, Xavier-normal init for the final linear layer"""
        for layer in self.model:
            if isinstance(layer, Conv2d):
                # a=0.25 is the same value the PReLU slopes are initialised with
                init.kaiming_normal_(layer.weight, a=0.25, nonlinearity='leaky_relu')
            elif isinstance(layer, Linear):
                init.xavier_normal_(layer.weight)

    def forward(self, x):
        """Run ``x`` through the sequential model and return the resulting scores"""
        scores = self.model(x)
        return scores


class CriticS(Module):
    """
    Critic for SVHN images - scores whether an svhn image comes from the "real" svhn data distribution or from
    the "fake" M2S generator
    """
    def __init__(self):
        super().__init__()

        def block(c_in, c_out, pooled=True, **extra):
            # 3x3 convolution -> per-channel PReLU -> (optional 3x3 max pooling with stride 2) -> batch norm
            conv = Conv2d(in_channels=c_in, out_channels=c_out, kernel_size=(3, 3), stride=(1, 1), **extra)
            activation = PReLU(num_parameters=c_out, init=0.25)
            norm = BatchNorm2d(num_features=c_out)
            if not pooled:
                return [conv, activation, norm]
            return [conv, activation, MaxPool2d(kernel_size=(3, 3), stride=(2, 2)), norm]

        self.model = Sequential(
            *block(3, 6, padding='same'),
            *block(6, 8),
            *block(8, 10, pooled=False),

            Flatten(),

            # A single unbounded score per image
            Linear(in_features=160, out_features=1),
        )

    def initialise(self):
        """Kaiming-normal init for the convolutions, Xavier-normal init for the final linear layer"""
        for layer in self.model:
            if isinstance(layer, Conv2d):
                # a=0.25 is the same value the PReLU slopes are initialised with
                init.kaiming_normal_(layer.weight, a=0.25, nonlinearity='leaky_relu')
            elif isinstance(layer, Linear):
                init.xavier_normal_(layer.weight)

    def forward(self, x):
        """Run ``x`` through the sequential model and return the resulting scores"""
        scores = self.model(x)
        return scores


class SMCycleWGAN(LightningModule):
    """
    Cycle WGAN between SVHN and MNIST: one generator per direction (``genS2M``, ``genM2S``) and one critic per
    domain (``criticM``, ``criticS``).

    ``ncritic`` and ``ngen`` are handed to the critic / generator optimisers as their ``frequency``, ``bs`` is the
    batch size of every dataloader and ``cycle_weight`` multiplies the cycle consistency loss.
    """

    def __init__(self, ncritic, ngen, bs, cycle_weight):
        super(SMCycleWGAN, self).__init__()
        self.save_hyperparameters()
        self.ncritic = ncritic
        self.ngen = ngen
        self.bs = bs
        self.cycle_weight = cycle_weight

        # SVHN to MNIST
        self.genS2M = S2MGenerator()
        self.criticM = CriticM()

        # MNIST to SVHN
        self.genM2S = M2SGenerator()
        self.criticS = CriticS()

        # Initialisations
        self.genS2M.initialise()
        self.criticM.initialise()
        self.genM2S.initialise()
        self.criticS.initialise()

        # CycleGAN authors use image pool to update the discriminator. That was required because GAN training (on the
        # usual objective) was unstable. We are using WGAN and hopefully won't get into such issues. Hence, skipping
        # keeping the pool of images

        self.float()

    def forward(self, real_svhns, real_mnists):
        """Score the real images with their critics and translate each batch to the other domain"""
        real_mnist_scores = self.criticM(real_mnists)
        real_svhn_scores = self.criticS(real_svhns)
        fake_mnists = self.genS2M(real_svhns)
        fake_svhns = self.genM2S(real_mnists)
        return real_svhn_scores, real_mnist_scores, fake_svhns, fake_mnists

    def _log_epoch_metrics(self, btype, sync_dist, **metrics):
        """Log every ``name=value`` pair as ``{btype}/{name}``, mean-reduced over the epoch (never per step)"""
        for name, value in metrics.items():
            self.log(f'{btype}/{name}', value, on_step=False, on_epoch=True, reduce_fx=torch.mean,
                     sync_dist=sync_dist)

    def _critic_losses(self, real_svhns, real_mnists, btype):
        """Return (criticM_loss, criticS_loss), each being mean fake score - mean real score, and log both"""
        # WGAN: Critic gets updated from the fake and real data
        # CycleGAN: Need to do this for both critics!

        # MNIST Critic
        real_mnists_score = self.criticM(real_mnists).mean()
        fake_mnists = self.genS2M(real_svhns)
        fake_mnists_score = self.criticM(fake_mnists).mean()
        criticM_loss = fake_mnists_score - real_mnists_score  # minimise this!

        # SVHN Critic
        real_svhns_score = self.criticS(real_svhns).mean()
        fake_svhns = self.genM2S(real_mnists)
        fake_svhns_score = self.criticS(fake_svhns).mean()
        criticS_loss = fake_svhns_score - real_svhns_score

        self._log_epoch_metrics(btype, btype != 'train', criticM_loss=criticM_loss, criticS_loss=criticS_loss)
        return criticM_loss, criticS_loss

    def _generator_losses(self, real_svhns, real_mnists, btype):
        """Return (genS2M_loss, genM2S_loss, cycle_loss) and log them together with both identity losses"""
        # WGAN: Generator gets updates from the fake data
        # CycleGAN: Do this for both generators and additionally put cycle consistency loss

        # SVHN to MNIST
        fake_mnists = self.genS2M(real_svhns)
        fake_mnists_score = self.criticM(fake_mnists).mean()
        genS2M_loss = -fake_mnists_score  # minimise this!

        # MNIST to SVHN
        fake_svhns = self.genM2S(real_mnists)
        fake_svhns_score = self.criticS(fake_svhns).mean()
        genM2S_loss = -fake_svhns_score

        # Cycle Consistency
        # Side Note: *Ideally* L1 norm should be added across dimensions and mean-ed across samples. In their
        # implementation, authors have mean-ed across dimensions too, keeping same implementation as them
        svhn_identity_loss = L1Loss()(real_svhns, self.genM2S(fake_mnists))
        mnist_identity_loss = L1Loss()(real_mnists, self.genS2M(fake_svhns))
        cycle_loss = self.cycle_weight * (svhn_identity_loss + mnist_identity_loss)

        self._log_epoch_metrics(btype, btype != 'train',
                                genS2M_loss=genS2M_loss, genM2S_loss=genM2S_loss, cycle_loss=cycle_loss,
                                svhn_identity_loss=svhn_identity_loss, mnist_identity_loss=mnist_identity_loss)

        return genS2M_loss, genM2S_loss, cycle_loss

    def training_step(self, batch, batch_idx, optimizer_idx):
        """
        Return the loss of the optimiser being stepped: index 0 is the generator optimiser, index 1 the critic
        optimiser (the order in which ``configure_optimizers`` returns them).

        Raises ValueError for any other ``optimizer_idx``.
        """
        real_svhns, real_mnists = batch['svhn'][0], batch['mnist'][0]

        if optimizer_idx == 0:  # Generator optimizer - only update Generator weights
            genS2M_loss, genM2S_loss, cycle_loss = self._generator_losses(real_svhns, real_mnists, 'train')
            return genS2M_loss + genM2S_loss + cycle_loss

        if optimizer_idx == 1:  # Critic optimizer - only update Critic weights
            criticM_loss, criticS_loss = self._critic_losses(real_svhns, real_mnists, 'train')
            # CycleGAN authors divide this loss by 2 to slow down the rate of critic learning. Here, as the loss is
            # wasserstein loss, I believe it may not be required
            return criticM_loss + criticS_loss

        # Is there a way to consolidate the losses and return one per epoch?
        raise ValueError(f'Unknown optimizer index: {optimizer_idx}')

    def _shared_eval(self, batch, btype):
        """Compute critic and generator losses on ``batch`` and log their sum as ``{btype}/loss``"""
        real_svhns, real_mnists = batch['svhn'][0], batch['mnist'][0]
        criticM_loss, criticS_loss = self._critic_losses(real_svhns, real_mnists, btype)
        genS2M_loss, genM2S_loss, cycle_loss = self._generator_losses(real_svhns, real_mnists, btype)
        total_loss = criticM_loss + criticS_loss + genS2M_loss + genM2S_loss + cycle_loss
        self._log_epoch_metrics(btype, True, loss=total_loss)

    def validation_step(self, batch, batch_idx):
        self._shared_eval(batch, 'val')

    def test_step(self, batch, batch_idx):
        self._shared_eval(batch, 'test')

    def configure_optimizers(self):
        """Using the strategy from WGAN paper instead of CycleGAN paper!"""
        generator_opt = RMSprop(params=itertools.chain(self.genS2M.parameters(), self.genM2S.parameters()), lr=0.00005)
        critic_opt = RMSprop(params=itertools.chain(self.criticS.parameters(), self.criticM.parameters()), lr=0.00005)
        return (
            {"optimizer": generator_opt, "frequency": self.ngen},
            {"optimizer": critic_opt, "frequency": self.ncritic},
        )

    def on_train_batch_end(self, *args):
        """After weights updating by the optimisers, clamp every critic parameter to [-0.01, 0.01]"""
        for weight in itertools.chain(self.criticS.parameters(), self.criticM.parameters()):
            weight.data.clamp_(-0.01, 0.01)

    @staticmethod
    def _scaled_transform(num_channels):
        """ToTensor followed by a per-channel (x - 0.5) / 0.5 normalisation, i.e. 0 to 1 becomes -1 to 1"""
        half = (0.5, ) * num_channels
        return transforms.Compose([
            transforms.ToTensor(),  # Gives a scaled version i.e., 0 to 1
            transforms.Normalize(half, half),
        ])

    def _paired_loaders(self, train):
        """
        Build the ``{'svhn': ..., 'mnist': ...}`` dataloaders. ``train=True`` reads the training splits and
        shuffles, ``train=False`` reads the SVHN 'test' split / MNIST ``train=False`` data without shuffling.
        """
        svhn_dataset = SVHN(svhn_data_dir, split='train' if train else 'test',
                            transform=self._scaled_transform(3))
        mnist_dataset = MNIST(mnist_data_dir, train=train,
                              transform=self._scaled_transform(1))
        loader_kwargs = dict(shuffle=train, num_workers=0, collate_fn=custom_collate_fn)
        return {
            'svhn': DataLoader(svhn_dataset, self.bs, **loader_kwargs),
            'mnist': DataLoader(mnist_dataset, self.bs, **loader_kwargs),
        }

    def train_dataloader(self):
        # Returned as a plain dict; how the two loaders are combined is left to the Trainer
        return self._paired_loaders(train=True)

    def val_dataloader(self):
        # Validation and test both read the same held-out data
        return CombinedLoader(self._paired_loaders(train=False), mode='max_size_cycle')

    def test_dataloader(self):
        return CombinedLoader(self._paired_loaders(train=False), mode='max_size_cycle')


In [7]:
# Options shared by every torchinfo summary printed in this cell
summary_kwargs = {
    'dtypes': [torch.float],
    'depth': 3,
    'col_names': ['input_size', 'output_size', 'num_params'],
    'row_settings': ['depth', 'var_names'],
    'verbose': 0,
    'device': torch.device('cpu'),
}
ns = 10

# Models that take SVHN-shaped input: ns random 3 x 32 x 32 images
svhn = torch.randn((ns, 3, 32, 32), dtype=torch.float)
for model_class in (S2MGenerator, CriticS):
    summary_string = str(summary(model=model_class(), input_data=svhn, **summary_kwargs))
    print(summary_string)

# Models that take MNIST-shaped input: ns random 1 x 28 x 28 images
mnist = torch.randn((ns, 1, 28, 28), dtype=torch.float)
for model_class in (M2SGenerator, CriticM):
    summary_string = str(summary(model=model_class(), input_data=mnist, **summary_kwargs))
    print(summary_string)

Layer (type (var_name):depth-idx)        Input Shape               Output Shape              Param #
S2MGenerator (S2MGenerator)              [10, 3, 32, 32]           [10, 1, 28, 28]           --
├─Sequential (model): 1-1                [10, 3, 32, 32]           [10, 1, 28, 28]           --
│    └─Conv2d (0): 2-1                   [10, 3, 32, 32]           [10, 6, 32, 32]           168
│    └─PReLU (1): 2-2                    [10, 6, 32, 32]           [10, 6, 32, 32]           6
│    └─BatchNorm2d (2): 2-3              [10, 6, 32, 32]           [10, 6, 32, 32]           12
│    └─Conv2d (3): 2-4                   [10, 6, 32, 32]           [10, 8, 30, 30]           440
│    └─PReLU (4): 2-5                    [10, 8, 30, 30]           [10, 8, 30, 30]           8
│    └─BatchNorm2d (5): 2-6              [10, 8, 30, 30]           [10, 8, 30, 30]           16
│    └─Conv2d (6): 2-7                   [10, 8, 30, 30]           [10, 10, 27, 27]          1,290
│    └─PReLU (7): 2-8           

# Common Training Utilities

In [12]:
def train_and_test(max_epochs: int,
                   tags: list[str], gpu_num: list[int],
                   model_class, model_kwargs: dict,
                   loss_desc: str, input_shape: list[tuple],
                   limit_train_batches=5, limit_val_batches=1, limit_test_batches=1):
    """
    Train ``model_class(**model_kwargs)``, test it and write a description of the model into a fresh run folder.

    The run folder is ``gan/results/run_<UTC timestamp>/`` under ``project_dir``; it receives the checkpoints
    (last + best 10 by 'val/loss'), the TensorBoard logs and ``model_desc.md``.

    Args:
        max_epochs: number of epochs given to the Trainer
        tags: not used by this function at present
        gpu_num: passed to the Trainer as ``devices``; only read when the global ``use_gpu`` is true
        model_class: class of the model to build, called with ``model_kwargs``
        model_kwargs: keyword arguments for ``model_class``
        loss_desc: free text appended after the model summary in ``model_desc.md``
        input_shape: two per-sample shapes (without the batch dimension), one per model input; the summary is
            computed with a batch of 10 for each
        limit_train_batches, limit_val_batches, limit_test_batches: forwarded unchanged to the Trainer arguments
            of the same names; the defaults keep the short smoke-run setting (5 / 1 / 1)
    """
    seed_everything(0, workers=True)

    folder_name = datetime.utcnow().isoformat(sep="T", timespec="microseconds")
    results_dir = project_dir + f'gan/results/run_{folder_name}/'
    # exist_ok=False: two runs must never share (and overwrite) a results folder
    os.makedirs(results_dir, exist_ok=False)

    checkpoint_callback = ModelCheckpoint(monitor='val/loss', mode='min', dirpath=results_dir,
                                          save_last=True, save_top_k=10, auto_insert_metric_name=False,
                                          filename='epoch={epoch}-val_loss={val/loss:.4f}')

    if use_gpu:
        trainer_kwargs = dict(accelerator="gpu", devices=gpu_num)
    else:
        trainer_kwargs = dict()

    tf_logger = TensorBoardLogger(save_dir=results_dir, version='tf_logs',
                                  default_hp_metric=False)

    model = model_class(**model_kwargs)

    trainer = Trainer(default_root_dir=results_dir, max_epochs=max_epochs,
                      callbacks=[checkpoint_callback], logger=[tf_logger],
                      log_every_n_steps=1, num_sanity_val_steps=0, multiple_trainloader_mode='max_size_cycle',
                      limit_train_batches=limit_train_batches, limit_val_batches=limit_val_batches,
                      limit_test_batches=limit_test_batches,
                      deterministic=True, **trainer_kwargs)
    trainer.fit(model)
    trainer.test(model)

    summary_string = str(summary(model=model, input_size=[(10, *input_shape[0]), (10, *input_shape[1])],
                                 dtypes=[torch.float, torch.float], depth=3, verbose=0,
                                 col_names=['input_size', 'output_size', 'num_params'],
                                 row_settings=['depth', 'var_names'], device=torch.device('cpu'))) + '\n' + loss_desc
    # The summary contains box-drawing characters, so do not depend on the platform's default encoding
    with open(results_dir + 'model_desc.md', 'w', encoding='utf-8') as f:
        f.write(summary_string)
    gc.collect()


# Train & Test

## CelebA - Bitmoji

In [14]:
# Short CelebA - Bitmoji run: 2 epochs, batch size 4
train_and_test(
    max_epochs=2,
    tags=[],
    gpu_num=[],
    model_class=CBCycleWGAN,
    model_kwargs=dict(ncritic=2, ngen=1, cycle_weight=1, bs=4),
    loss_desc='General loss',
    input_shape=[(3, 218, 178), (3, 128, 128)],
)

Global seed set to 0
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
`Trainer(limit_val_batches=1)` was configured so 1 batch will be used.
`Trainer(limit_test_batches=1)` was configured so 1 batch will be used.

  | Name    | Type         | Params
-----------------------------------------
0 | genC2B  | C2BGenerator | 31.5 K
1 | criticB | CriticB      | 5.5 K 
2 | genB2C  | B2CGenerator | 31.5 K
3 | criticC | CriticC      | 5.7 K 
-----------------------------------------
74.2 K    Trainable params
0         Non-trainable params
74.2 K    Total params
0.297     Total estimated model params size (MB)


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Testing: 0it [00:00, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric               DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
test/bitmoji_identity_loss    0.8739818930625916
test/celeba_identity_loss     0.7569977045059204
    test/criticB_loss                0.0
    test/criticC_loss                0.0
     test/cycle_loss          1.6309795379638672
     test/genB2C_loss        0.009965265169739723
     test/genC2B_loss       -0.010036039166152477
        test/loss             1.630908727645874
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


## SVHN - MNIST

In [15]:
# Short SVHN - MNIST run: 2 epochs, batch size 4, cycle loss weighted by 10
train_and_test(
    max_epochs=2,
    tags=[],
    gpu_num=[],
    model_class=SMCycleWGAN,
    model_kwargs=dict(ncritic=2, ngen=1, cycle_weight=10, bs=4),
    loss_desc='General loss',
    input_shape=[(3, 32, 32), (1, 28, 28)],  # (SVHN, MNIST) per-image shapes
)

Global seed set to 0
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
`Trainer(limit_val_batches=1)` was configured so 1 batch will be used.
`Trainer(limit_test_batches=1)` was configured so 1 batch will be used.

  | Name    | Type         | Params
-----------------------------------------
0 | genS2M  | S2MGenerator | 7.4 K 
1 | criticM | CriticM      | 829   
2 | genM2S  | M2SGenerator | 8.4 K 
3 | criticS | CriticS      | 1.6 K 
-----------------------------------------
18.2 K    Trainable params
0         Non-trainable params
18.2 K    Total params
0.073     Total estimated model params size (MB)


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Testing: 0it [00:00, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
    test/criticM_loss               0.0
    test/criticS_loss               0.0
     test/cycle_loss        15.491276741027832
    test/genM2S_loss        0.01000344567000866
    test/genS2M_loss       -0.01000091340392828
        test/loss           15.491279602050781
test/mnist_identity_loss    1.1419328451156616
 test/svhn_identity_loss    0.4071948528289795
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


# Plots & Analysis

See GitHub code for examples of usage

## CelebA - Bitmoji

In [None]:
def plot_side_by_side(title, reals, gens, fname):
    """
    Draw real / generated image pairs next to each other on a 10 x 10 grid and save the figure.

    Pair ``i`` uses the flat grid positions ``2 * i`` (real) and ``2 * i + 1`` (generated), so the grid holds at
    most 50 pairs. The figure is saved as ``fname`` under ``gan/celeba_bitmoji/img_results/`` in ``project_dir``.
    """
    fig, grid = plt.subplots(10, 10, figsize=(8, 8))
    fig.subplots_adjust(wspace=0.01, hspace=0.01, left=0, bottom=0, right=1, top=0.95)
    cells = grid.flat

    def show(cell, image):
        # Bare image: no ticks, no frame
        cell.set_axis_off()
        cell.imshow(image)

    for pair_idx in range(reals.shape[0]):
        show(cells[2 * pair_idx], reals[pair_idx])
        show(cells[2 * pair_idx + 1], gens[pair_idx])

    fig.suptitle(title)
    plt.savefig(project_dir + f'gan/celeba_bitmoji/img_results/{fname}')


def convert_to_image(ndarray):  # -1 to 1
    """
    Map an array with values in -1 to 1 onto integer pixel values in 0 to 255.

    The values are rescaled with ``(x * 0.5 + 0.5) * 255``, rounded with ``numpy.round`` and cast to ``int``.
    A new array is returned; the input is not modified.
    """
    import numpy  # local import: numpy is not among the notebook's top-level imports

    scaled = (ndarray * 0.5 + 0.5) * 255  # -1..1 -> 0..1 -> 0..255
    return numpy.round(scaled, decimals=0).astype(int)


def see_some_translations(model, celeba, bitmoji, model_type):
    """
    Translate ``celeba`` with ``model.genC2B`` and ``bitmoji`` with ``model.genB2C`` and save one real / generated
    comparison grid per direction through ``plot_side_by_side``.

    ``model_type`` is a label (the caller uses 'best' and 'last') that ends up in the plot title and file name.
    """
    import numpy  # local import: numpy is not among the notebook's top-level imports

    def as_images(batch):
        # Move the channel axis last, then rescale -1..1 to integer 0..255
        return convert_to_image(numpy.transpose(batch, (0, 2, 3, 1)))

    # NOTE(review): the generators run in whatever train / eval mode `model` is in - confirm that is intended
    gen_bitmoji = model.genC2B(celeba).detach().numpy()
    gen_celeba = model.genB2C(bitmoji).detach().numpy()

    celeba = as_images(celeba.numpy())
    bitmoji = as_images(bitmoji.numpy())
    gen_bitmoji = as_images(gen_bitmoji)
    gen_celeba = as_images(gen_celeba)

    settings = f'NCritic = {model.ncritic}, NGen = {model.ngen}, Cycle Weight = {model.cycle_weight} - {model_type}'
    stem = f'gan_trans_ncritic={model.ncritic}_ngen={model.ngen}_cycleweight={model.cycle_weight}_{model_type}'

    # The '.png' is spelled out: without it a fractional cycle weight (e.g. 0.5) puts a '.' into the name and
    # matplotlib reads everything after it as the file format
    plot_side_by_side(f'CelebA to Bitmoji - {settings}', celeba, gen_bitmoji, f'{stem}_c2b.png')
    plot_side_by_side(f'Bitmoji to CelebA - {settings}', bitmoji, gen_celeba, f'{stem}_b2c.png')


def get_data(num_images):
    """
    Draw one shuffled batch of ``num_images`` CelebA test images and one of Bitmoji test images.

    Returns the two image tensors (CelebA first), each normalised with mean 0.5 and std 0.5 per channel.
    """
    def rgb_transform():
        return transforms.Compose([
            transforms.ToTensor(),  # Gives a scaled version i.e., 0 to 1
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])

    loader_kwargs = dict(shuffle=True, num_workers=num_cpus, collate_fn=custom_collate_fn)

    # target_type=[]: only the images are needed
    celeba_loader = DataLoader(
        CelebA(celeba_data_dir, split='test', target_type=[], transform=rgb_transform()),
        num_images, **loader_kwargs)
    celeba_batch = next(iter(celeba_loader))

    bitmoji_loader = DataLoader(
        ImageFolder(bitmoji_data_dir + 'test/', transform=rgb_transform()),
        num_images, **loader_kwargs)
    bitmoji_batch = next(iter(bitmoji_loader))

    # custom_collate_fn wraps the stacked images in a one-element list
    return celeba_batch[0], bitmoji_batch[0]


def give_best_fname(direc):
    """
    Return the path of the ``epoch*.ckpt`` file in ``direc`` whose file name records the lowest validation loss.

    File names are expected to look like ``epoch=<n>-val_loss=<value>.ckpt``: the loss is the text after the
    second '=' with '.ckpt' removed. ``direc`` must end with a path separator because the glob pattern is
    appended to it directly. Checkpoints whose recorded loss is NaN are ignored.

    Raises:
        FileNotFoundError: if ``direc`` holds no ``epoch*.ckpt`` file with a usable loss
    """
    import math  # local import: math is not among the file's top-level imports

    candidates = []
    for fname in glob(direc + 'epoch*.ckpt'):
        # Parse the bare file name so that nothing in the directory part can shift the '=' fields
        val_loss = float(os.path.basename(fname).split('=')[2].replace('.ckpt', ''))
        if not math.isnan(val_loss):
            candidates.append((val_loss, fname))

    if not candidates:
        raise FileNotFoundError(f'No epoch*.ckpt checkpoint with a validation loss found in {direc}')

    # min() keeps the first of several equally good checkpoints
    return min(candidates, key=lambda candidate: candidate[0])[1]


def plot_translations(result_dir):
    """
    Plot translations of one batch of 50 test images per domain with the best checkpoint (lowest validation
    loss in its file name) and with ``last.ckpt`` of the run folder ``result_dir``.
    """
    celeba_batch, bitmoji_batch = get_data(50)
    direc = project_dir + 'gan/celeba_bitmoji/results/' + result_dir + '/'

    # Both checkpoints are shown on the same images so that the plots are comparable
    checkpoints = (
        ('best', give_best_fname(direc)),
        ('last', direc + 'last.ckpt'),
    )
    for model_type, ckpt_path in checkpoints:
        loaded = CBCycleWGAN.load_from_checkpoint(ckpt_path)
        see_some_translations(loaded, celeba_batch, bitmoji_batch, model_type)


## SVHN - MNIST

In [None]:
def plot_side_by_side(title, reals, gens, gray, fname):
    """
    Draw real / generated image pairs next to each other on a 10 x 10 grid and save the figure.

    Pair ``i`` uses the flat grid positions ``2 * i`` (real) and ``2 * i + 1`` (generated), so the grid holds at
    most 50 pairs. ``gray`` names the side drawn with a gray colormap fixed to 0..255: 'real' or 'gen'; the
    other side is drawn with the imshow defaults. The figure is saved as ``fname`` under
    ``gan/svhn_mnist/img_results/`` in ``project_dir``.
    """
    fig, grid = plt.subplots(10, 10, figsize=(8, 8))
    fig.subplots_adjust(wspace=0.01, hspace=0.01, left=0, bottom=0, right=1, top=0.95)
    cells = grid.flat

    def show(cell, image, as_gray):
        # Bare image: no ticks, no frame
        cell.set_axis_off()
        if as_gray:
            cell.imshow(image, cmap='gray', vmin=0, vmax=255)
        else:
            cell.imshow(image)

    for pair_idx in range(reals.shape[0]):
        show(cells[2 * pair_idx], reals[pair_idx], gray == 'real')
        show(cells[2 * pair_idx + 1], gens[pair_idx], gray == 'gen')

    fig.suptitle(title)
    plt.savefig(project_dir + f'gan/svhn_mnist/img_results/{fname}')


def convert_to_image(ndarray):  # -1 to 1
    """
    Map an array with values in -1 to 1 onto integer pixel values in 0 to 255.

    The values are rescaled with ``(x * 0.5 + 0.5) * 255``, rounded with ``numpy.round`` and cast to ``int``.
    A new array is returned; the input is not modified.
    """
    import numpy  # local import: numpy is not among the notebook's top-level imports

    scaled = (ndarray * 0.5 + 0.5) * 255  # -1..1 -> 0..1 -> 0..255
    return numpy.round(scaled, decimals=0).astype(int)


def see_some_translations(model, svhn, mnist, model_type):
    """
    Translate ``svhn`` with ``model.genS2M`` and ``mnist`` with ``model.genM2S`` and save one real / generated
    comparison grid per direction through ``plot_side_by_side``.

    ``model_type`` is a label (the caller uses 'best' and 'last') that ends up in the plot title and file name.
    """
    import numpy  # local import: numpy is not among the notebook's top-level imports

    def as_images(batch):
        # Move the channel axis last, then rescale -1..1 to integer 0..255
        return convert_to_image(numpy.transpose(batch, (0, 2, 3, 1)))

    # NOTE(review): the generators run in whatever train / eval mode `model` is in - confirm that is intended
    gen_mnist = model.genS2M(svhn).detach().numpy()
    gen_svhn = model.genM2S(mnist).detach().numpy()

    svhn = as_images(svhn.numpy())
    mnist = as_images(mnist.numpy())
    gen_mnist = as_images(gen_mnist)
    gen_svhn = as_images(gen_svhn)

    settings = f'NCritic = {model.ncritic}, NGen = {model.ngen}, Cycle Weight = {model.cycle_weight} - {model_type}'
    stem = f'gan_trans_ncritic={model.ncritic}_ngen={model.ngen}_cycleweight={model.cycle_weight}_{model_type}'

    # The '.png' is spelled out: without it a fractional cycle weight (e.g. 0.5) puts a '.' into the name and
    # matplotlib reads everything after it as the file format.
    # The MNIST-like side of each plot is the gray one: generated images first, real images second
    plot_side_by_side(f'SVHN to MNIST - {settings}', svhn, gen_mnist, 'gen', f'{stem}_s2m.png')
    plot_side_by_side(f'MNIST to SVHN - {settings}', mnist, gen_svhn, 'real', f'{stem}_m2s.png')


def get_data(num_images):
    """
    Draw one shuffled batch of ``num_images`` SVHN test images and one of MNIST test images.

    Returns the two image tensors (SVHN first), each normalised with mean 0.5 and std 0.5 per channel.
    """
    def scaled_transform(num_channels):
        half = (0.5,) * num_channels
        return transforms.Compose([
            transforms.ToTensor(),  # Gives a scaled version i.e., 0 to 1
            transforms.Normalize(half, half),
        ])

    loader_kwargs = dict(shuffle=True, num_workers=num_cpus, collate_fn=custom_collate_fn)

    svhn_loader = DataLoader(
        SVHN(svhn_data_dir, split='test', transform=scaled_transform(3)),
        num_images, **loader_kwargs)
    mnist_loader = DataLoader(
        MNIST(mnist_data_dir, train=False, transform=scaled_transform(1)),
        num_images, **loader_kwargs)

    svhn_batch = next(iter(svhn_loader))
    mnist_batch = next(iter(mnist_loader))

    # custom_collate_fn wraps the stacked images in a one-element list
    return svhn_batch[0], mnist_batch[0]


def give_best_fname(direc):
    """
    Return the path of the ``epoch*.ckpt`` file in ``direc`` whose file name records the lowest validation loss.

    File names are expected to look like ``epoch=<n>-val_loss=<value>.ckpt``: the loss is the text after the
    second '=' with '.ckpt' removed. ``direc`` must end with a path separator because the glob pattern is
    appended to it directly. Checkpoints whose recorded loss is NaN are ignored.

    Raises:
        FileNotFoundError: if ``direc`` holds no ``epoch*.ckpt`` file with a usable loss
    """
    import math  # local import: math is not among the file's top-level imports

    candidates = []
    for fname in glob(direc + 'epoch*.ckpt'):
        # Parse the bare file name so that nothing in the directory part can shift the '=' fields
        val_loss = float(os.path.basename(fname).split('=')[2].replace('.ckpt', ''))
        if not math.isnan(val_loss):
            candidates.append((val_loss, fname))

    if not candidates:
        raise FileNotFoundError(f'No epoch*.ckpt checkpoint with a validation loss found in {direc}')

    # min() keeps the first of several equally good checkpoints
    return min(candidates, key=lambda candidate: candidate[0])[1]


def plot_translations(result_dir):
    """
    Plot translations of one batch of 50 test images per domain with the best checkpoint (lowest validation
    loss in its file name) and with ``last.ckpt`` of the run folder ``result_dir``.
    """
    svhn_batch, mnist_batch = get_data(50)
    direc = project_dir + 'gan/svhn_mnist/results/' + result_dir + '/'

    # Both checkpoints are shown on the same images so that the plots are comparable
    checkpoints = (
        ('best', give_best_fname(direc)),
        ('last', direc + 'last.ckpt'),
    )
    for model_type, ckpt_path in checkpoints:
        loaded = SMCycleWGAN.load_from_checkpoint(ckpt_path)
        see_some_translations(loaded, svhn_batch, mnist_batch, model_type)
