### General Settings

Change the respective settings to run appropriately

Use `limit_train_batches`, `limit_val_batches`, `limit_test_batches` as required

In [1]:
project_dir = '/Users/rajjain/PycharmProjects/ADRL-Course-Work/'
data_dir = project_dir + 'data/'
mnist_data_dir = '/Users/rajjain/Desktop/CourseWork/MNIST/'
usps_data_dir = '/Users/rajjain/Desktop/CourseWork/USPS/'
use_gpu = False
num_cpus = 2

In [2]:
from torch.nn import init, Linear, Sequential, Conv2d, PReLU, Module, BatchNorm2d, ConvTranspose2d, Hardtanh, \
    Flatten, L1Loss, BatchNorm1d, Unflatten, CrossEntropyLoss
from pytorch_lightning.trainer.supporters import CombinedLoader
from pytorch_lightning.utilities.seed import seed_everything
from torchmetrics.functional.classification import accuracy
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning import LightningModule
from torchvision.datasets import USPS, MNIST
from torch.utils.data import DataLoader
from pytorch_lightning import Trainer
from torchvision import transforms
from torch.optim import Adam
from torchinfo import summary
from datetime import datetime
from glob import glob
from tqdm import tqdm
import itertools
import shutil
import torch
import gc
import os

In [3]:
def custom_collate_fn(batch):
    imgs = torch.stack([elem[0] for elem in batch])
    return [imgs]

# Models

In [4]:
class U2MGenerator(Module):
    """
    Take USPS image [-1, 1] and generate "fake" MNIST image [-1, 1]
    """

    def __init__(self):
        super(U2MGenerator, self).__init__()
        self.model = Sequential(
            # Downsampling part
            Conv2d(in_channels=1, out_channels=3, kernel_size=(3, 3), stride=(1, 1)),
            PReLU(num_parameters=3, init=0.25),
            BatchNorm2d(num_features=3),

            Conv2d(in_channels=3, out_channels=6, kernel_size=(3, 3), stride=(1, 1)),
            PReLU(num_parameters=6, init=0.25),
            BatchNorm2d(num_features=6),

            Conv2d(in_channels=6, out_channels=8, kernel_size=(4, 4), stride=(2, 2)),
            PReLU(num_parameters=8, init=0.25),
            BatchNorm2d(num_features=8),

            Flatten(),

            Linear(in_features=200, out_features=200),
            PReLU(num_parameters=200, init=0.25),
            BatchNorm1d(num_features=200),

            Unflatten(1, (8, 5, 5)),

            # Upsampling part
            ConvTranspose2d(in_channels=8, out_channels=6, kernel_size=(4, 4), stride=(2, 2)),
            PReLU(num_parameters=6, init=0.25),
            BatchNorm2d(num_features=6),

            ConvTranspose2d(in_channels=6, out_channels=3, kernel_size=(4, 4), stride=(2, 2)),
            PReLU(num_parameters=3, init=0.25),
            BatchNorm2d(num_features=3),

            ConvTranspose2d(in_channels=3, out_channels=1, kernel_size=(3, 3), stride=(1, 1)),
            Hardtanh(min_val=-1, max_val=1),
        )

    def initialise(self):
        for i in [0, 3, 6, 10, 14, 17]:
            init.kaiming_normal_(self.model._modules[str(i)].weight, a=0.25, nonlinearity='leaky_relu')
        init.xavier_normal_(self.model._modules['20'].weight)

    def forward(self, x):
        return self.model(x)


class M2UGenerator(Module):
    """
    Take MNIST image [-1, 1] and generate "fake" USPS image [-1, 1]
    """

    def __init__(self):
        super(M2UGenerator, self).__init__()

        self.model = Sequential(
            # Downsampling part
            Conv2d(in_channels=1, out_channels=3, kernel_size=(3, 3), stride=(1, 1)),
            PReLU(num_parameters=3, init=0.25),
            BatchNorm2d(num_features=3),

            Conv2d(in_channels=3, out_channels=6, kernel_size=(4, 4), stride=(2, 2)),
            PReLU(num_parameters=6, init=0.25),
            BatchNorm2d(num_features=6),

            Conv2d(in_channels=6, out_channels=8, kernel_size=(4, 4), stride=(2, 2)),
            PReLU(num_parameters=8, init=0.25),
            BatchNorm2d(num_features=8),

            Flatten(),

            Linear(in_features=200, out_features=200),
            PReLU(num_parameters=200, init=0.25),
            BatchNorm1d(num_features=200),

            Unflatten(1, (8, 5, 5)),

            # Upsampling part
            ConvTranspose2d(in_channels=8, out_channels=6, kernel_size=(4, 4), stride=(2, 2)),
            PReLU(num_parameters=6, init=0.25),
            BatchNorm2d(num_features=6),

            ConvTranspose2d(in_channels=6, out_channels=3, kernel_size=(3, 3), stride=(1, 1)),
            PReLU(num_parameters=3, init=0.25),
            BatchNorm2d(num_features=3),

            ConvTranspose2d(in_channels=3, out_channels=1, kernel_size=(3, 3), stride=(1, 1)),
            Hardtanh(min_val=-1, max_val=1),
        )

    def initialise(self):
        for i in [0, 3, 6, 10, 14, 17]:
            init.kaiming_normal_(self.model._modules[str(i)].weight, a=0.25, nonlinearity='leaky_relu')
        init.xavier_normal_(self.model._modules['20'].weight)

    def forward(self, x):
        return self.model(x)


class CriticM(Module):
    """
    Critic for MNIST images - Checks if the mnist image passed is from mnist "real" data distribution or from "fake" U2M generator
    """

    def __init__(self):
        super(CriticM, self).__init__()

        self.model = Sequential(
            Conv2d(in_channels=1, out_channels=3, kernel_size=(3, 3), stride=(1, 1)),
            PReLU(num_parameters=3, init=0.25),

            Conv2d(in_channels=3, out_channels=6, kernel_size=(4, 4), stride=(2, 2)),
            PReLU(num_parameters=6, init=0.25),

            Conv2d(in_channels=6, out_channels=8, kernel_size=(4, 4), stride=(2, 2)),
            PReLU(num_parameters=8, init=0.25),

            Flatten(),

            Linear(in_features=200, out_features=1),
        )

    def initialise(self):
        for i in [0, 2, 4]:
            init.kaiming_normal_(self.model._modules[str(i)].weight, a=0.25, nonlinearity='leaky_relu')
        init.xavier_normal_(self.model._modules['7'].weight)

    def forward(self, x):
        return self.model(x)


class CriticU(Module):
    """
    Critic for USPS images - Checks if the usps image passed is from usps "real" data distribution or from "fake" M2U generator
    """

    def __init__(self):
        super(CriticU, self).__init__()

        self.model = Sequential(
            Conv2d(in_channels=1, out_channels=3, kernel_size=(3, 3), stride=(1, 1)),
            PReLU(num_parameters=3, init=0.25),

            Conv2d(in_channels=3, out_channels=6, kernel_size=(3, 3), stride=(1, 1)),
            PReLU(num_parameters=6, init=0.25),

            Conv2d(in_channels=6, out_channels=8, kernel_size=(4, 4), stride=(2, 2)),
            PReLU(num_parameters=8, init=0.25),

            Flatten(),

            Linear(in_features=200, out_features=1),
        )

    def initialise(self):
        for i in [0, 2, 4]:
            init.kaiming_normal_(self.model._modules[str(i)].weight, a=0.25, nonlinearity='leaky_relu')
        init.xavier_normal_(self.model._modules['7'].weight)

    def forward(self, x):
        return self.model(x)


class MUCycleGANGP(LightningModule):

    def __init__(self, ncritic, ngen, bs, cycle_weight, penalty_weight: float):
        super(MUCycleGANGP, self).__init__()
        self.save_hyperparameters()
        self.ncritic = ncritic
        self.ngen = ngen
        self.bs = bs
        self.cycle_weight = cycle_weight
        self.penalty_weight = penalty_weight

        # USPS to MNIST
        self.genU2M = U2MGenerator()
        self.criticM = CriticM()

        # MNIST to USPS
        self.genM2U = M2UGenerator()
        self.criticU = CriticU()

        # Initialisations
        self.genU2M.initialise()
        self.criticM.initialise()
        self.genM2U.initialise()
        self.criticU.initialise()

        # CycleGAN authors use image pool to update the discriminator. That was required because GAN training (on the
        # usual objective) was unstable. We are using WGAN and hopefully won't get into such issues. Hence, skipping
        # keeping the pool of images

        self.float()

    def _get_gradient_penalty(self, critic, reals, fakes):
        batch_size = reals.shape[0]
        eps = torch.rand(batch_size, 1, 1, 1, device=self.device)
        eps = eps.expand_as(reals)
        interpolated = eps * reals + (1 - eps) * fakes
        interpolated.requires_grad_(True)
        interpolated_scores = critic(interpolated)
        gradients = torch.autograd.grad(
            outputs=interpolated_scores,
            inputs=interpolated,
            grad_outputs=torch.ones_like(interpolated_scores),
            create_graph=True,
            retain_graph=True,
        )[0]
        gradients = gradients.view(batch_size, -1)
        gradients_norm = gradients.norm(2, 1)  # norm of gradient each of the samples
        penalty = (gradients_norm - 1) ** 2  # penalty for each sample
        return self.penalty_weight * penalty.mean()  # mean across samples

    def _critic_losses(self, real_uspss, real_mnists, btype):
        # WGAN: Critic gets updated from the fake and real data
        # CycleGAN: Need to do this for both critics!

        # MNIST Critic
        real_mnists_score = self.criticM(real_mnists).mean()
        fake_mnists = self.genU2M(real_uspss)
        fake_mnists_score = self.criticM(fake_mnists).mean()
        criticM_loss = fake_mnists_score - real_mnists_score  # minimise this!

        criticM_gp = self._get_gradient_penalty(self.criticM, real_mnists, fake_mnists)

        # USPS Critic
        real_uspss_score = self.criticU(real_uspss).mean()
        fake_uspss = self.genM2U(real_mnists)
        fake_uspss_score = self.criticU(fake_uspss).mean()
        criticU_loss = fake_uspss_score - real_uspss_score

        criticU_gp = self._get_gradient_penalty(self.criticU, real_uspss, fake_uspss)

        not_training = btype != 'train'
        self.log(f'{btype}/criticM_loss', criticM_loss, on_step=False, on_epoch=True, sync_dist=not_training)
        self.log(f'{btype}/criticU_loss', criticU_loss, on_step=False, on_epoch=True, sync_dist=not_training)
        self.log(f'{btype}/criticM_gp', criticM_gp, on_step=False, on_epoch=True, sync_dist=not_training)
        self.log(f'{btype}/criticU_gp', criticU_gp, on_step=False, on_epoch=True, sync_dist=not_training)

        self.log(f'{btype}/real_mnists_score', real_mnists_score, on_step=False, on_epoch=True, sync_dist=not_training)
        self.log(f'{btype}/fake_mnists_score', fake_mnists_score, on_step=False, on_epoch=True, sync_dist=not_training)
        self.log(f'{btype}/real_uspss_score', real_uspss_score, on_step=False, on_epoch=True, sync_dist=not_training)
        self.log(f'{btype}/fake_uspss_score', fake_uspss_score, on_step=False, on_epoch=True, sync_dist=not_training)
        return criticM_loss, criticU_loss, criticM_gp, criticU_gp

    def _generator_losses(self, real_uspss, real_mnists, btype):
        # WGAN: Generator gets updates from the fake data
        # CycleGAN: Do this for both generators and additionally put cycle consistency loss

        # USPS to MNIST
        fake_mnists = self.genU2M(real_uspss)
        fake_mnists_score = self.criticM(fake_mnists).mean()
        genU2M_loss = -fake_mnists_score  # minimise this!

        # MNIST to USPS
        fake_uspss = self.genM2U(real_mnists)
        fake_uspss_score = self.criticU(fake_uspss).mean()
        genM2U_loss = -fake_uspss_score

        # Cycle Consistency
        # Side Note: *Ideally* L1 norm should be added across dimensions and mean-ed across samples. In their
        # implementation, authors have mean-ed across dimensions too, keeping same implementation as them
        usps_identity_loss = L1Loss()(real_uspss, self.genM2U(fake_mnists))
        mnist_identity_loss = L1Loss()(real_mnists, self.genU2M(fake_uspss))
        cycle_loss = self.cycle_weight * (usps_identity_loss + mnist_identity_loss)

        not_training = btype != 'train'
        self.log(f'{btype}/genU2M_loss', genU2M_loss, on_step=False, on_epoch=True, reduce_fx=torch.mean, sync_dist=not_training)
        self.log(f'{btype}/genM2U_loss', genM2U_loss, on_step=False, on_epoch=True, reduce_fx=torch.mean, sync_dist=not_training)
        self.log(f'{btype}/cycle_loss', cycle_loss, on_step=False, on_epoch=True, reduce_fx=torch.mean, sync_dist=not_training)
        self.log(f'{btype}/usps_identity_loss', usps_identity_loss, on_step=False, on_epoch=True, reduce_fx=torch.mean, sync_dist=not_training)
        self.log(f'{btype}/mnist_identity_loss', mnist_identity_loss, on_step=False, on_epoch=True, reduce_fx=torch.mean, sync_dist=not_training)

        return genU2M_loss, genM2U_loss, cycle_loss

    def training_step(self, batch, batch_idx, optimizer_idx):
        real_uspss, real_mnists = batch['usps'][0], batch['mnist'][0]

        if optimizer_idx == 1:  # Critic optimizer - only update Critic weights
            criticM_loss, criticU_loss, criticM_gp, criticU_gp = self._critic_losses(real_uspss, real_mnists, 'train')
            # CycleGAN authors divide this loss by 2 to slow down the rate of critic learning. Here, as the loss is
            # wasserstein loss, I believe it may not be required
            return criticM_loss + criticU_loss + criticM_gp + criticU_gp

        if optimizer_idx == 0:  # Generator optimizer - only update Generator weights
            genU2M_loss, genM2U_loss, cycle_loss = self._generator_losses(real_uspss, real_mnists, 'train')
            return genU2M_loss + genM2U_loss + cycle_loss

        # Is there a way to consolidate the losses and return one per epoch?
        raise Exception(f'Unknown optimizer index: {optimizer_idx}')

    def _shared_eval(self, batch, btype):
        real_uspss, real_mnists = batch['usps'][0], batch['mnist'][0]
        criticM_loss, criticU_loss, _, _ = self._critic_losses(real_uspss, real_mnists, btype)
        _, _, cycle_loss = self._generator_losses(real_uspss, real_mnists, btype)

        mnist_emd = -criticM_loss
        usps_emd = -criticU_loss
        overall_loss = mnist_emd + usps_emd + cycle_loss

        self.log(f'{btype}/mnist_emd', mnist_emd, on_step=False, on_epoch=True, reduce_fx=torch.mean, sync_dist=True)
        self.log(f'{btype}/usps_emd', usps_emd, on_step=False, on_epoch=True, reduce_fx=torch.mean, sync_dist=True)
        self.log(f'{btype}/overall_loss', overall_loss, on_step=False, on_epoch=True, reduce_fx=torch.mean, sync_dist=True)

    def validation_step(self, batch, batch_idx):
        torch.set_grad_enabled(True)
        self._shared_eval(batch, 'val')

    def test_step(self, batch, batch_idx):
        torch.set_grad_enabled(True)
        self._shared_eval(batch, 'test')

    def configure_optimizers(self):
        """Using the strategy from WGAN paper instead of CycleGAN paper!"""
        generator_opt = Adam(params=itertools.chain(self.genU2M.parameters(), self.genM2U.parameters()), lr=0.0001, betas=(0, 0.9))
        critic_opt = Adam(params=itertools.chain(self.criticU.parameters(), self.criticM.parameters()), lr=0.0001, betas=(0, 0.9))
        return (
            {"optimizer": generator_opt, "frequency": self.ngen},
            {"optimizer": critic_opt, "frequency": self.ncritic},
        )

    def train_dataloader(self):
        usps_dataset = USPS(usps_data_dir, train=True,
                            transform=transforms.Compose([
                                transforms.ToTensor(),  # Gives a scaled version i.e., 0 to 1
                                transforms.Normalize((0.5,), (0.5,)),
                            ]))
        usps_dataloader = DataLoader(usps_dataset, self.bs, shuffle=True, num_workers=0,
                                     collate_fn=custom_collate_fn, drop_last=True)
        mnist_dataset = MNIST(mnist_data_dir, train=True,
                              transform=transforms.Compose([
                                  transforms.ToTensor(),  # Gives a scaled version i.e., 0 to 1
                                  transforms.Normalize((0.5,), (0.5,)),
                              ]))
        mnist_dataloader = DataLoader(mnist_dataset, self.bs, shuffle=True, num_workers=0,
                                      collate_fn=custom_collate_fn, drop_last=True)
        return {
            'usps': usps_dataloader,
            'mnist': mnist_dataloader,
        }

    def val_dataloader(self):
        usps_dataset = USPS(usps_data_dir, train=False,
                            transform=transforms.Compose([
                                transforms.ToTensor(),  # Gives a scaled version i.e., 0 to 1
                                transforms.Normalize((0.5,), (0.5,)),
                            ]))
        usps_dataloader = DataLoader(usps_dataset, self.bs, shuffle=False, num_workers=0,
                                     collate_fn=custom_collate_fn, drop_last=True)
        mnist_dataset = MNIST(mnist_data_dir, train=False,
                              transform=transforms.Compose([
                                  transforms.ToTensor(),  # Gives a scaled version i.e., 0 to 1
                                  transforms.Normalize((0.5,), (0.5,)),
                              ]))
        mnist_dataloader = DataLoader(mnist_dataset, self.bs, shuffle=False, num_workers=0,
                                      collate_fn=custom_collate_fn, drop_last=True)
        return CombinedLoader({
            'usps': usps_dataloader,
            'mnist': mnist_dataloader,
        }, mode='max_size_cycle')

    def test_dataloader(self):
        usps_dataset = USPS(usps_data_dir, train=False,
                            transform=transforms.Compose([
                                transforms.ToTensor(),  # Gives a scaled version i.e., 0 to 1
                                transforms.Normalize((0.5,), (0.5,)),
                            ]))
        usps_dataloader = DataLoader(usps_dataset, self.bs, shuffle=False, num_workers=0,
                                     collate_fn=custom_collate_fn, drop_last=True)
        mnist_dataset = MNIST(mnist_data_dir, train=False,
                              transform=transforms.Compose([
                                  transforms.ToTensor(),  # Gives a scaled version i.e., 0 to 1
                                  transforms.Normalize((0.5,), (0.5,)),
                              ]))
        mnist_dataloader = DataLoader(mnist_dataset, self.bs, shuffle=False, num_workers=0,
                                      collate_fn=custom_collate_fn, drop_last=True)
        return CombinedLoader({
            'usps': usps_dataloader,
            'mnist': mnist_dataloader,
        }, mode='max_size_cycle')

    def summary(self) -> str:
        _summary_kwargs = dict(dtypes=[torch.float], depth=3, col_names=['input_size', 'output_size', 'num_params'],
                               row_settings=['depth', 'var_names'], verbose=0, device=torch.device('cpu'))
        _usps = torch.randn((10, 1, 16, 16), dtype=torch.float)
        _mnist = torch.randn((10, 1, 28, 28), dtype=torch.float)
        _summary_string = str(summary(model=self.genU2M, input_data=_usps, **_summary_kwargs)) + '\n' + \
                          str(summary(model=self.criticU, input_data=_usps, **_summary_kwargs)) + '\n' + \
                          str(summary(model=self.genM2U, input_data=_mnist, **_summary_kwargs)) + '\n' + \
                          str(summary(model=self.criticM, input_data=_mnist, **_summary_kwargs))
        return _summary_string


class USPSClassifier(LightningModule):
    """
    Classify USPS image [-1, 1]
    """

    def __init__(self, bs):
        super(USPSClassifier, self).__init__()
        self.save_hyperparameters()
        self.bs = bs
        self.model = Sequential(
            # Downsampling part
            Conv2d(in_channels=1, out_channels=3, kernel_size=(3, 3), stride=(1, 1)),
            PReLU(num_parameters=3, init=0.25),
            BatchNorm2d(num_features=3),

            Conv2d(in_channels=3, out_channels=6, kernel_size=(3, 3), stride=(1, 1)),
            PReLU(num_parameters=6, init=0.25),
            BatchNorm2d(num_features=6),

            Conv2d(in_channels=6, out_channels=8, kernel_size=(4, 4), stride=(2, 2)),
            PReLU(num_parameters=8, init=0.25),
            BatchNorm2d(num_features=8),

            Flatten(),

            Linear(in_features=200, out_features=10),
        )
        self.float()

    def initialise(self):
        for i in [0, 3, 6]:
            init.kaiming_normal_(self.model[i].weight, a=0.25, nonlinearity='leaky_relu')
        init.xavier_normal_(self.model[10].weight)

    def forward(self, x):
        return self.model(x)

    def _common_step(self, batch, btype):
        not_training = btype != 'train'
        x, y = batch
        y_hat = self(x)
        loss = CrossEntropyLoss()(y_hat, y)
        acc = accuracy(y_hat, y, average='macro', num_classes=10, multiclass=True)
        self.log(f'{btype}/usps_classifier_loss', loss, on_step=False, on_epoch=True, sync_dist=not_training)
        self.log(f'{btype}/usps_classifier_acc', acc, on_step=False, on_epoch=True, sync_dist=not_training)
        return loss

    def training_step(self, batch, batch_idx):
        return self._common_step(batch, 'train')

    def validation_step(self, batch, batch_idx):
        self._common_step(batch, 'val')

    def test_step(self, batch, batch_idx):
        self._common_step(batch, 'test')

    def configure_optimizers(self):
        return Adam(self.parameters(), lr=1e-3)

    def train_dataloader(self):
        usps_dataset = USPS(usps_data_dir, train=True,
                            transform=transforms.Compose([
                                transforms.ToTensor(),  # Gives a scaled version i.e., 0 to 1
                                transforms.Normalize((0.5,), (0.5,)),
                            ]))
        usps_dataloader = DataLoader(usps_dataset, self.bs, shuffle=True, num_workers=0)
        return usps_dataloader

    def eval_dataloader(self):
        usps_dataset = USPS(usps_data_dir, train=False,
                            transform=transforms.Compose([
                                transforms.ToTensor(),  # Gives a scaled version i.e., 0 to 1
                                transforms.Normalize((0.5,), (0.5,)),
                            ]))
        usps_dataloader = DataLoader(usps_dataset, self.bs, shuffle=False, num_workers=0)
        return usps_dataloader

    def val_dataloader(self):
        return self.eval_dataloader()

    def test_dataloader(self):
        return self.eval_dataloader()

    def summary(self) -> str:
        summary_kwargs = dict(dtypes=[torch.float], depth=3, col_names=['input_size', 'output_size', 'num_params'],
                              row_settings=['depth', 'var_names'], verbose=0, device=torch.device('cpu'))
        usps_imgs = torch.randn((10, 1, 16, 16), dtype=torch.float)
        summary_string = str(summary(model=self, input_data=usps_imgs, **summary_kwargs))
        return summary_string


class MNISTClassifier(LightningModule):
    """
    Classify MNIST image [-1, 1]
    """

    def __init__(self, bs):
        super(MNISTClassifier, self).__init__()
        self.save_hyperparameters()
        self.bs = bs
        self.model = Sequential(
            # Downsampling part
            Conv2d(in_channels=1, out_channels=3, kernel_size=(3, 3), stride=(1, 1)),
            PReLU(num_parameters=3, init=0.25),
            BatchNorm2d(num_features=3),

            Conv2d(in_channels=3, out_channels=6, kernel_size=(4, 4), stride=(2, 2)),
            PReLU(num_parameters=6, init=0.25),
            BatchNorm2d(num_features=6),

            Conv2d(in_channels=6, out_channels=8, kernel_size=(4, 4), stride=(2, 2)),
            PReLU(num_parameters=8, init=0.25),
            BatchNorm2d(num_features=8),

            Flatten(),

            Linear(in_features=200, out_features=10),
        )
        self.float()

    def initialise(self):
        for i in [0, 3, 6]:
            init.kaiming_normal_(self.model[i].weight, a=0.25, nonlinearity='leaky_relu')
        init.xavier_normal_(self.model[10].weight)

    def forward(self, x):
        return self.model(x)

    def _common_step(self, batch, btype):
        not_training = btype != 'train'
        x, y = batch
        y_hat = self(x)
        loss = CrossEntropyLoss()(y_hat, y)
        acc = accuracy(y_hat, y, average='macro', num_classes=10, multiclass=True)
        self.log(f'{btype}/mnist_classifier_loss', loss, on_step=False, on_epoch=True, sync_dist=not_training)
        self.log(f'{btype}/mnist_classifier_acc', acc, on_step=False, on_epoch=True, sync_dist=not_training)
        return loss

    def training_step(self, batch, batch_idx):
        return self._common_step(batch, 'train')

    def validation_step(self, batch, batch_idx):
        self._common_step(batch, 'val')

    def test_step(self, batch, batch_idx):
        self._common_step(batch, 'test')

    def configure_optimizers(self):
        return Adam(self.parameters(), lr=1e-3)

    def train_dataloader(self):
        mnist_dataset = MNIST(mnist_data_dir, train=True,
                              transform=transforms.Compose([
                                  transforms.ToTensor(),  # Gives a scaled version i.e., 0 to 1
                                  transforms.Normalize((0.5,), (0.5,)),
                              ]))
        mnist_dataloader = DataLoader(mnist_dataset, self.bs, shuffle=True, num_workers=0)
        return mnist_dataloader

    def eval_dataloader(self):
        mnist_dataset = MNIST(mnist_data_dir, train=False,
                              transform=transforms.Compose([
                                  transforms.ToTensor(),  # Gives a scaled version i.e., 0 to 1
                                  transforms.Normalize((0.5,), (0.5,)),
                              ]))
        mnist_dataloader = DataLoader(mnist_dataset, self.bs, shuffle=False, num_workers=0)
        return mnist_dataloader

    def val_dataloader(self):
        return self.eval_dataloader()

    def test_dataloader(self):
        return self.eval_dataloader()

    def summary(self) -> str:
        summary_kwargs = dict(dtypes=[torch.float], depth=3, col_names=['input_size', 'output_size', 'num_params'],
                              row_settings=['depth', 'var_names'], verbose=0, device=torch.device('cpu'))
        usps_imgs = torch.randn((10, 1, 28, 28), dtype=torch.float)
        summary_string = str(summary(model=self, input_data=usps_imgs, **summary_kwargs))
        return summary_string


In [5]:
print(MUCycleGANGP(1, 1, 1, 1, 1).summary())
print(USPSClassifier(1).summary())
print(MNISTClassifier(1).summary())

Layer (type (var_name):depth-idx)        Input Shape               Output Shape              Param #
U2MGenerator (U2MGenerator)              [10, 1, 16, 16]           [10, 1, 28, 28]           --
├─Sequential (model): 1-1                [10, 1, 16, 16]           [10, 1, 28, 28]           --
│    └─Conv2d (0): 2-1                   [10, 1, 16, 16]           [10, 3, 14, 14]           30
│    └─PReLU (1): 2-2                    [10, 3, 14, 14]           [10, 3, 14, 14]           3
│    └─BatchNorm2d (2): 2-3              [10, 3, 14, 14]           [10, 3, 14, 14]           6
│    └─Conv2d (3): 2-4                   [10, 3, 14, 14]           [10, 6, 12, 12]           168
│    └─PReLU (4): 2-5                    [10, 6, 12, 12]           [10, 6, 12, 12]           6
│    └─BatchNorm2d (5): 2-6              [10, 6, 12, 12]           [10, 6, 12, 12]           12
│    └─Conv2d (6): 2-7                   [10, 6, 12, 12]           [10, 8, 5, 5]             776
│    └─PReLU (7): 2-8               

# Train & Test

In [6]:
def train_and_test(max_epochs: int, tags: list[str], gpu_num: list[int],
                   model_class, model_kwargs: dict, model_desc: str):
    seed_everything(0, workers=True)

    folder_name = datetime.utcnow().isoformat(sep="T", timespec="microseconds")
    results_dir = project_dir + f'domain_adap/cycle_gan/results/run_{folder_name}/'
    os.makedirs(results_dir, exist_ok=False)

    checkpoint_callback = ModelCheckpoint(monitor='val/overall_loss', mode='min', dirpath=results_dir,
                                          save_last=True, save_top_k=1, auto_insert_metric_name=False,
                                          filename='epoch={epoch}-val_overall_loss={val/overall_loss:.4f}')

    if use_gpu:
        trainer_kwargs = dict(accelerator="gpu", devices=gpu_num)
    else:
        trainer_kwargs = dict()

    tf_logger = TensorBoardLogger(save_dir=results_dir, version=f'tf_logs',
                                  default_hp_metric=False)

    model = model_class(**model_kwargs)

    trainer = Trainer(default_root_dir=results_dir, max_epochs=max_epochs,
                      callbacks=[checkpoint_callback], logger=[tf_logger],
                      log_every_n_steps=1, num_sanity_val_steps=0, multiple_trainloader_mode='max_size_cycle',
                      limit_train_batches=6, limit_val_batches=6, limit_test_batches=6,
                      deterministic=True, **trainer_kwargs)
    trainer.fit(model)
    trainer.test(model)

    summary = model.summary() + '\n' + model_desc
    with open(results_dir + 'model_desc.md', 'w') as f:
        f.write(summary)

    gc.collect()
    return f'run_{folder_name}'


def train_and_test_classifier(max_epochs: int, tags: list[str], gpu_num: list[int],
                              model_class, model_kwargs: dict, model_desc: str, dataset: str):
    seed_everything(0, workers=True)

    folder_name = datetime.utcnow().isoformat(sep="T", timespec="microseconds")
    results_dir = project_dir + f'domain_adap/cycle_gan/results/run_{folder_name}/'
    os.makedirs(results_dir, exist_ok=False)

    checkpoint_callback = ModelCheckpoint(monitor=f'val/{dataset}_classifier_loss', mode='min', dirpath=results_dir,
                                          filename='best')

    if use_gpu:
        trainer_kwargs = dict(accelerator="gpu", devices=gpu_num)
    else:
        trainer_kwargs = dict()

    tf_logger = TensorBoardLogger(save_dir=results_dir, version=f'tf_logs',
                                  default_hp_metric=False)

    model = model_class(**model_kwargs)

    trainer = Trainer(default_root_dir=results_dir, max_epochs=max_epochs,
                      callbacks=[checkpoint_callback], logger=[tf_logger],
                      log_every_n_steps=1, num_sanity_val_steps=0,
                      limit_train_batches=20, limit_val_batches=20, limit_test_batches=20,
                      deterministic=True, **trainer_kwargs)
    trainer.fit(model)
    trainer.test(ckpt_path='best')

    summary = model.summary() + '\n' + model_desc
    with open(results_dir + 'model_desc.md', 'w') as f:
        f.write(summary)

    gc.collect()
    return f'run_{folder_name}'


def test_target(test_dl, generator, classifier):
    from torch.nn import CrossEntropyLoss
    from torchmetrics.functional.classification import accuracy
    ys, y_hats = [], []
    for x, y in tqdm(test_dl):
        gen_x = generator(x)
        y_hat = classifier(gen_x)
        ys.append(y)
        y_hats.append(y_hat)
    all_y = torch.concat(ys)
    all_y_hat = torch.concat(y_hats)
    loss = CrossEntropyLoss()(all_y_hat, all_y).item()
    acc = accuracy(all_y_hat, all_y, average='macro', num_classes=10, multiclass=True).item()
    return loss, acc


def test_target_accuracies(mnist_classifier_folder, usps_classifier_folder, cycle_gan_folder):
    mnist_classifier = MNISTClassifier.load_from_checkpoint(project_dir + 'domain_adap/cycle_gan/results/' +
                                                            mnist_classifier_folder + '/best.ckpt', bs=20)
    usps_classifier = USPSClassifier.load_from_checkpoint(project_dir + 'domain_adap/cycle_gan/results/' +
                                                          usps_classifier_folder + '/best.ckpt', bs=20)

    cycle_gan_direc = project_dir + 'domain_adap/cycle_gan/results/' + cycle_gan_folder + '/'
    best_fname = glob(cycle_gan_direc + 'epoch*.ckpt')[0]
    cycle_gan = MUCycleGANGP.load_from_checkpoint(best_fname)

    mnist_test_dl = mnist_classifier.test_dataloader()
    usps_test_dl = usps_classifier.test_dataloader()

    mnist_loss, mnist_acc = test_target(mnist_test_dl, cycle_gan.genM2U, usps_classifier)
    usps_loss, usps_acc = test_target(usps_test_dl, cycle_gan.genU2M, mnist_classifier)

    print(f'Target: MNIST -- CE Loss = {mnist_loss:2.4f}, Acc = {mnist_acc * 100:.2f}%')
    print(f'Target:  USPS -- CE Loss = {usps_loss:2.4f}, Acc = {usps_acc * 100:.2f}%')

In [7]:
mnist_classifier_folder = train_and_test_classifier(2, tags=[], gpu_num=[], model_class=MNISTClassifier,
                                                    model_kwargs=dict(bs=10), model_desc='MNIST Classifier', 
                                                    dataset='mnist')

usps_classifier_folder = train_and_test_classifier(2, tags=[], gpu_num=[], model_class=USPSClassifier,
                                                   model_kwargs=dict(bs=10), model_desc='USPS Classifier', 
                                                   dataset='usps')

cycle_gan_folder = train_and_test(max_epochs=2, tags=[], gpu_num=[], model_class=MUCycleGANGP,
                                  model_kwargs=dict(ncritic=2, ngen=1, cycle_weight=10, bs=10, penalty_weight=1),
                                  model_desc='WGAN with Gradient Penalty; EMD estimates + Cycle Loss as a overall loss')

test_target_accuracies(mnist_classifier_folder, usps_classifier_folder, cycle_gan_folder)


Global seed set to 0
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name  | Type       | Params
-------------------------------------
0 | model | Sequential | 3.2 K 
-------------------------------------
3.2 K     Trainable params
0         Non-trainable params
3.2 K     Total params
0.013     Total estimated model params size (MB)
  rank_zero_warn(
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Restoring states from the checkpoint path at /Users/rajjain/PycharmProjects/ADRL-Course-Work/domain_adap/cycle_gan/results/run_2022-11-10T08:08:09.835223/best.ckpt
Loaded model weights from checkpoint at /Users/rajjain/PycharmProjects/ADRL-Course-Work/domain_adap/cycle_gan/results/run_2022-11-10T08:08:09.835223/best.ckpt
  rank_zero_warn(


Testing: 0it [00:00, ?it/s]

Global seed set to 0
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name  | Type       | Params
-------------------------------------
0 | model | Sequential | 3.0 K 
-------------------------------------
3.0 K     Trainable params
0         Non-trainable params
3.0 K     Total params
0.012     Total estimated model params size (MB)


────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric               DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
test/mnist_classifier_acc     0.3514385223388672
test/mnist_classifier_loss    1.8667727708816528
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Restoring states from the checkpoint path at /Users/rajjain/PycharmProjects/ADRL-Course-Work/domain_adap/cycle_gan/results/run_2022-11-10T08:08:10.426572/best.ckpt
Loaded model weights from checkpoint at /Users/rajjain/PycharmProjects/ADRL-Course-Work/domain_adap/cycle_gan/results/run_2022-11-10T08:08:10.426572/best.ckpt


Testing: 0it [00:00, ?it/s]

Global seed set to 0
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name    | Type         | Params
-----------------------------------------
0 | genU2M  | U2MGenerator | 42.9 K
1 | criticM | CriticM      | 1.3 K 
2 | genM2U  | M2UGenerator | 42.9 K
3 | criticU | CriticU      | 1.2 K 
-----------------------------------------
88.4 K    Trainable params
0         Non-trainable params
88.4 K    Total params
0.354     Total estimated model params size (MB)


────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
test/usps_classifier_acc    0.5545138716697693
test/usps_classifier_loss    1.342002272605896
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Testing: 0it [00:00, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
     test/criticM_gp        0.30373769998550415
    test/criticM_loss       1.0020878314971924
     test/criticU_gp        0.07823963463306427
    test/criticU_loss       0.1398886740207672
     test/cycle_loss        18.734392166137695
 test/fake_mnists_score    0.008811501786112785
  test/fake_uspss_score     0.4236476719379425
    test/genM2U_loss        -0.4236476719379425
    test/genU2M_loss       -0.008811501786112785
     test/mnist_emd         -1.0020878314971924
test/mnist_identity_loss    1.1771318912506104
    test/overall_loss        17.59241485595703
 test/real_mnists_score     -0.9932763576507568
  test/real_uspss_score     0.2837590277194977
      test/usps_emd         -0.1398886740207672


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 500/500 [00:01<00:00, 306.13it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 101/101 [00:00<00:00, 252.33it/s]


Target: MNIST -- CE Loss = 2.5223, Acc = 10.58%
Target:  USPS -- CE Loss = 2.4421, Acc = 9.68%


# Plots & Analysis

In [None]:
def plot_side_by_side(title, reals, gens, fname):
    fig, axes = plt.subplots(10, 10, figsize=(8, 8))
    fig.subplots_adjust(wspace=0.01, hspace=0.01, left=0, bottom=0, right=1, top=0.95)
    axes = axes.flat

    for i in range(reals.shape[0]):
        target_idx = 2 * i
        ax = axes[target_idx]
        ax.set_axis_off()
        ax.imshow(reals[i], cmap='gray', vmin=0, vmax=255)

        pred_idx = target_idx + 1
        ax = axes[pred_idx]
        ax.set_axis_off()
        ax.imshow(gens[i], cmap='gray', vmin=0, vmax=255)

    fig.suptitle(title)
    plt.savefig(project_dir + f'domain_adap/cycle_gan/img_results/{fname}.png')
    plt.close('all')
    
def convert_to_image(ndarray):  # -1 to 1
    ndarray = ndarray * 0.5 + 0.5  # 0 to 1
    ndarray *= 255  # 0 to 255
    ndarray = numpy.round(ndarray, decimals=0)  # rounded off
    return ndarray.astype(int)


def see_some_translations(model: MUCycleGANGP, mnist, usps, model_type, epochs):
    gen_usps = model.genM2U(mnist)
    gen_mnist = model.genU2M(usps)

    id_usps = model.genM2U(gen_mnist)
    id_mnist = model.genU2M(gen_usps)

    mnist = convert_to_image(numpy.transpose(mnist.numpy(), (0, 2, 3, 1)))
    usps = convert_to_image(numpy.transpose(usps.numpy(), (0, 2, 3, 1)))
    gen_usps = convert_to_image(numpy.transpose(gen_usps.detach().numpy(), (0, 2, 3, 1)))
    gen_mnist = convert_to_image(numpy.transpose(gen_mnist.detach().numpy(), (0, 2, 3, 1)))
    id_usps = convert_to_image(numpy.transpose(id_usps.detach().numpy(), (0, 2, 3, 1)))
    id_mnist = convert_to_image(numpy.transpose(id_mnist.detach().numpy(), (0, 2, 3, 1)))

    plot_side_by_side(
        f'GP - MNIST to USPS - NCritic = {model.ncritic}, NGen = {model.ngen}, Cycle Weight = {model.cycle_weight}, Penalty = {model.penalty_weight}, Epochs = {epochs} - {model_type}',
        mnist, gen_usps,
        f'gan_gp_trans_ncritic={model.ncritic}_ngen={model.ngen}_cycleweight={model.cycle_weight}_penalty={model.penalty_weight}_epochs={epochs}_{model_type}_m2u')
    plot_side_by_side(
        f'GP - USPS to MNIST - NCritic = {model.ncritic}, NGen = {model.ngen}, Cycle Weight = {model.cycle_weight}, Penalty = {model.penalty_weight}, Epochs = {epochs} - {model_type}',
        usps, gen_mnist,
        f'gan_gp_trans_ncritic={model.ncritic}_ngen={model.ngen}_cycleweight={model.cycle_weight}_penalty={model.penalty_weight}_epochs={epochs}_{model_type}_u2m')

    plot_side_by_side(
        f'GP - MNIST to MNIST - NCritic = {model.ncritic}, NGen = {model.ngen}, Cycle Weight = {model.cycle_weight}, Penalty = {model.penalty_weight}, Epochs = {epochs} - {model_type}',
        mnist, id_mnist,
        f'gan_gp_id_ncritic={model.ncritic}_ngen={model.ngen}_cycleweight={model.cycle_weight}_penalty={model.penalty_weight}_epochs={epochs}_{model_type}_m2m')
    plot_side_by_side(
        f'GP - USPS to USPS - NCritic = {model.ncritic}, NGen = {model.ngen}, Cycle Weight = {model.cycle_weight}, Penalty = {model.penalty_weight}, Epochs = {epochs} - {model_type}',
        usps, id_usps,
        f'gan_gp_id_ncritic={model.ncritic}_ngen={model.ngen}_cycleweight={model.cycle_weight}_penalty={model.penalty_weight}_epochs={epochs}_{model_type}_u2u')


def get_data(num_images):
    mnist_dataset = MNIST(mnist_data_dir, train=False,
                          transform=transforms.Compose([
                              transforms.ToTensor(),  # Gives a scaled version i.e., 0 to 1
                              transforms.Normalize((0.5,), (0.5,)),
                          ]))
    mnist_dataloader = DataLoader(mnist_dataset, num_images, shuffle=False, num_workers=num_cpus,
                                  collate_fn=custom_collate_fn)
    mnist = next(iter(mnist_dataloader))

    usps_dataset = USPS(usps_data_dir, train=False,
                        transform=transforms.Compose([
                            transforms.ToTensor(),  # Gives a scaled version i.e., 0 to 1
                            transforms.Normalize((0.5,), (0.5,)),
                        ]))
    usps_dataloader = DataLoader(usps_dataset, num_images, shuffle=False, num_workers=num_cpus,
                                 collate_fn=custom_collate_fn)
    usps = next(iter(usps_dataloader))

    return mnist[0], usps[0]


def plot_translations(result_dir):
    num_imgs = 50
    epochs = 200
    _mnist, _usps = get_data(num_imgs)
    direc = project_dir + 'domain_adap/cycle_gan/results/' + result_dir + '/'
    best_fname = glob(direc + 'epoch*.ckpt')[0]
    best_mod = MUCycleGANGP.load_from_checkpoint(best_fname)
    see_some_translations(best_mod, _mnist, _usps, 'best', epochs)
    mod = MUCycleGANGP.load_from_checkpoint(project_dir + 'domain_adap/cycle_gan/results/' + result_dir + '/last.ckpt')
    see_some_translations(mod, _mnist, _usps, 'last', epochs)


### CycleGAN with Gradient Clipping

In [None]:
class USPSDownSampler(Module):
    """Common class for down sampling USPS data"""

    def __init__(self):
        super(USPSDownSampler, self).__init__()
        self.model = Sequential(
            Conv2d(in_channels=1, out_channels=3, kernel_size=(3, 3), stride=(1, 1)),
            PReLU(num_parameters=3, init=0.25),
            BatchNorm2d(num_features=3),

            Conv2d(in_channels=3, out_channels=6, kernel_size=(3, 3), stride=(1, 1)),
            PReLU(num_parameters=6, init=0.25),
            BatchNorm2d(num_features=6),

            Conv2d(in_channels=6, out_channels=8, kernel_size=(4, 4), stride=(2, 2)),
            PReLU(num_parameters=8, init=0.25),
            BatchNorm2d(num_features=8),
        )

    def initialise(self):
        for i in range(0, 7, 3):
            init.kaiming_normal_(self.model._modules[str(i)].weight, a=0.25, nonlinearity='leaky_relu')

    def forward(self, x):
        return self.model(x)


class U2MGenerator(Module):
    """
    Take USPS image [-1, 1] and generate "fake" MNIST image [-1, 1]
    """

    def __init__(self, downsampler: USPSDownSampler):
        super(U2MGenerator, self).__init__()
        self.model = Sequential(
            # Downsampling part
            downsampler,

            Flatten(),

            Linear(in_features=200, out_features=200),
            PReLU(num_parameters=200, init=0.25),
            BatchNorm1d(num_features=200),

            Unflatten(1, (8, 5, 5)),

            # Upsampling part
            ConvTranspose2d(in_channels=8, out_channels=6, kernel_size=(4, 4), stride=(2, 2)),
            PReLU(num_parameters=6, init=0.25),
            BatchNorm2d(num_features=6),

            ConvTranspose2d(in_channels=6, out_channels=3, kernel_size=(4, 4), stride=(2, 2)),
            PReLU(num_parameters=3, init=0.25),
            BatchNorm2d(num_features=3),

            ConvTranspose2d(in_channels=3, out_channels=1, kernel_size=(3, 3), stride=(1, 1)),
            Hardtanh(min_val=-1, max_val=1),
        )

    def initialise(self):
        for i in [2, 6, 9]:
            init.kaiming_normal_(self.model._modules[str(i)].weight, a=0.25, nonlinearity='leaky_relu')
        init.xavier_normal_(self.model._modules['12'].weight)

    def forward(self, x):
        return self.model(x)


class MNISTDownSampler(Module):
    """Common class for down sampling MNIST data"""

    def __init__(self):
        super(MNISTDownSampler, self).__init__()
        self.model = Sequential(
            Conv2d(in_channels=1, out_channels=3, kernel_size=(3, 3), stride=(1, 1)),
            PReLU(num_parameters=3, init=0.25),
            BatchNorm2d(num_features=3),

            Conv2d(in_channels=3, out_channels=6, kernel_size=(4, 4), stride=(2, 2)),
            PReLU(num_parameters=6, init=0.25),
            BatchNorm2d(num_features=6),

            Conv2d(in_channels=6, out_channels=8, kernel_size=(4, 4), stride=(2, 2)),
            PReLU(num_parameters=8, init=0.25),
            BatchNorm2d(num_features=8),
        )

    def initialise(self):
        for i in range(0, 7, 3):
            init.kaiming_normal_(self.model._modules[str(i)].weight, a=0.25, nonlinearity='leaky_relu')

    def forward(self, x):
        return self.model(x)


class M2UGenerator(Module):
    """
    Take MNIST image [-1, 1] and generate "fake" USPS image [-1, 1]
    """

    def __init__(self, downsampler: MNISTDownSampler):
        super(M2UGenerator, self).__init__()

        self.model = Sequential(
            # Downsampling part
            downsampler,

            Flatten(),

            Linear(in_features=200, out_features=200),
            PReLU(num_parameters=200, init=0.25),
            BatchNorm1d(num_features=200),

            Unflatten(1, (8, 5, 5)),

            # Upsampling part
            ConvTranspose2d(in_channels=8, out_channels=6, kernel_size=(4, 4), stride=(2, 2)),
            PReLU(num_parameters=6, init=0.25),
            BatchNorm2d(num_features=6),

            ConvTranspose2d(in_channels=6, out_channels=3, kernel_size=(3, 3), stride=(1, 1)),
            PReLU(num_parameters=3, init=0.25),
            BatchNorm2d(num_features=3),

            ConvTranspose2d(in_channels=3, out_channels=1, kernel_size=(3, 3), stride=(1, 1)),
            Hardtanh(min_val=-1, max_val=1),
        )

    def initialise(self):
        for i in [2, 6, 9]:
            init.kaiming_normal_(self.model._modules[str(i)].weight, a=0.25, nonlinearity='leaky_relu')
        init.xavier_normal_(self.model._modules['12'].weight)

    def forward(self, x):
        return self.model(x)


class CriticM(Module):
    """
    Critic for MNIST images - Checks if the mnist image passed is from mnist "real" data distribution or from "fake" U2M generator
    """

    def __init__(self, downsampler: MNISTDownSampler):
        super(CriticM, self).__init__()

        self.model = Sequential(
            downsampler,
            Flatten(),
            Linear(in_features=200, out_features=1),
        )

    def initialise(self):
        init.xavier_normal_(self.model._modules['2'].weight)

    def forward(self, x):
        return self.model(x)


class CriticU(Module):
    """
    Critic for USPS images - Checks if the usps image passed is from usps "real" data distribution or from "fake" M2U generator
    """

    def __init__(self, downsampler: USPSDownSampler):
        super(CriticU, self).__init__()

        self.model = Sequential(
            downsampler,
            Flatten(),
            Linear(in_features=200, out_features=1),
        )

    def initialise(self):
        init.xavier_normal_(self.model._modules['2'].weight)

    def forward(self, x):
        return self.model(x)


class MUCycleGAN(LightningModule):

    def __init__(self, ncritic, ngen, bs, cycle_weight, clamp_value: float):
        super(MUCycleGAN, self).__init__()
        self.save_hyperparameters()
        self.ncritic = ncritic
        self.ngen = ngen
        self.bs = bs
        self.cycle_weight = cycle_weight
        self.clamp_value = clamp_value

        # DownSamplers - Share them between respective generators and critics
        self.mnist_downsampler = MNISTDownSampler()
        self.usps_downsampler = USPSDownSampler()
        self.mnist_downsampler.initialise()
        self.usps_downsampler.initialise()

        # USPS to MNIST
        self.genU2M = U2MGenerator(self.usps_downsampler)
        self.criticM = CriticM(self.mnist_downsampler)
        self.genU2M.initialise()
        self.criticM.initialise()

        # MNIST to USPS
        self.genM2U = M2UGenerator(self.mnist_downsampler)
        self.criticU = CriticU(self.usps_downsampler)
        self.genM2U.initialise()
        self.criticU.initialise()

        # CycleGAN authors use image pool to update the discriminator. That was required because GAN training (on the
        # usual objective) was unstable. We are using WGAN and hopefully won't get into such issues. Hence, skipping
        # keeping the pool of images

        self.float()

    def _critic_losses(self, real_uspss, real_mnists, btype):
        # WGAN: Critic gets updated from the fake and real data
        # CycleGAN: Need to do this for both critics!

        # MNIST Critic
        real_mnists_score = self.criticM(real_mnists).mean()
        fake_mnists = self.genU2M(real_uspss)
        fake_mnists_score = self.criticM(fake_mnists).mean()
        criticM_loss = fake_mnists_score - real_mnists_score  # minimise this!

        # USPS Critic
        real_uspss_score = self.criticU(real_uspss).mean()
        fake_uspss = self.genM2U(real_mnists)
        fake_uspss_score = self.criticU(fake_uspss).mean()
        criticU_loss = fake_uspss_score - real_uspss_score

        not_training = btype != 'train'
        self.log(f'{btype}/criticM_loss', criticM_loss, on_step=False, on_epoch=True, reduce_fx=torch.mean, sync_dist=not_training)
        self.log(f'{btype}/criticU_loss', criticU_loss, on_step=False, on_epoch=True, reduce_fx=torch.mean, sync_dist=not_training)
        return criticM_loss, criticU_loss

    def _generator_losses(self, real_uspss, real_mnists, btype):
        # WGAN: Generator gets updates from the fake data
        # CycleGAN: Do this for both generators and additionally put cycle consistency loss

        # USPS to MNIST
        fake_mnists = self.genU2M(real_uspss)
        fake_mnists_score = self.criticM(fake_mnists).mean()
        genU2M_loss = -fake_mnists_score  # minimise this!

        # MNIST to USPS
        fake_uspss = self.genM2U(real_mnists)
        fake_uspss_score = self.criticU(fake_uspss).mean()
        genM2U_loss = -fake_uspss_score

        # Cycle Consistency
        # Side Note: *Ideally* L1 norm should be added across dimensions and mean-ed across samples. In their
        # implementation, authors have mean-ed across dimensions too, keeping same implementation as them
        usps_identity_loss = L1Loss()(real_uspss, self.genM2U(fake_mnists))
        mnist_identity_loss = L1Loss()(real_mnists, self.genU2M(fake_uspss))
        cycle_loss = self.cycle_weight * (usps_identity_loss + mnist_identity_loss)

        not_training = btype != 'train'
        self.log(f'{btype}/genU2M_loss', genU2M_loss, on_step=False, on_epoch=True, reduce_fx=torch.mean, sync_dist=not_training)
        self.log(f'{btype}/genM2U_loss', genM2U_loss, on_step=False, on_epoch=True, reduce_fx=torch.mean, sync_dist=not_training)
        self.log(f'{btype}/cycle_loss', cycle_loss, on_step=False, on_epoch=True, reduce_fx=torch.mean, sync_dist=not_training)
        self.log(f'{btype}/usps_identity_loss', usps_identity_loss, on_step=False, on_epoch=True, reduce_fx=torch.mean, sync_dist=not_training)
        self.log(f'{btype}/mnist_identity_loss', mnist_identity_loss, on_step=False, on_epoch=True, reduce_fx=torch.mean, sync_dist=not_training)

        return genU2M_loss, genM2U_loss, cycle_loss

    def training_step(self, batch, batch_idx, optimizer_idx):
        real_uspss, real_mnists = batch['usps'][0], batch['mnist'][0]

        if optimizer_idx == 1:  # Critic optimizer - only update Critic weights
            criticM_loss, criticU_loss = self._critic_losses(real_uspss, real_mnists, 'train')
            # CycleGAN authors divide this loss by 2 to slow down the rate of critic learning. Here, as the loss is
            # wasserstein loss, I believe it may not be required
            return criticM_loss + criticU_loss

        if optimizer_idx == 0:  # Generator optimizer - only update Generator weights
            genU2M_loss, genM2U_loss, cycle_loss = self._generator_losses(real_uspss, real_mnists, 'train')
            return genU2M_loss + genM2U_loss + cycle_loss

        # Is there a way to consolidate the losses and return one per epoch?
        raise Exception(f'Unknown optimizer index: {optimizer_idx}')

    def _shared_eval(self, batch, btype):
        real_uspss, real_mnists = batch['usps'][0], batch['mnist'][0]
        criticM_loss, criticU_loss = self._critic_losses(real_uspss, real_mnists, btype)
        _, _, cycle_loss = self._generator_losses(real_uspss, real_mnists, btype)

        mnist_emd = -criticM_loss
        usps_emd = -criticU_loss
        overall_loss = mnist_emd + usps_emd + cycle_loss

        self.log(f'{btype}/mnist_emd', mnist_emd, on_step=False, on_epoch=True, reduce_fx=torch.mean, sync_dist=True)
        self.log(f'{btype}/usps_emd', usps_emd, on_step=False, on_epoch=True, reduce_fx=torch.mean, sync_dist=True)
        self.log(f'{btype}/overall_loss', overall_loss, on_step=False, on_epoch=True, reduce_fx=torch.mean, sync_dist=True)

    def validation_step(self, batch, batch_idx):
        self._shared_eval(batch, 'val')

    def test_step(self, batch, batch_idx):
        self._shared_eval(batch, 'test')

    def configure_optimizers(self):
        """Using the strategy from WGAN paper instead of CycleGAN paper!"""
        generator_opt = RMSprop(params=itertools.chain(self.genU2M.parameters(), self.genM2U.parameters()), lr=0.00005)
        critic_opt = RMSprop(params=itertools.chain(self.criticU.parameters(), self.criticM.parameters()), lr=0.00005)
        return (
            {"optimizer": generator_opt, "frequency": self.ngen},
            {"optimizer": critic_opt, "frequency": self.ncritic},
        )

    def on_train_batch_end(self, *args):
        """After weights updating by the optimisers, clamp the weights"""
        for weight in self.criticU.parameters():
            weight.data.clamp_(-self.clamp_value, self.clamp_value)
        for weight in self.criticM.parameters():
            weight.data.clamp_(-self.clamp_value, self.clamp_value)

    def train_dataloader(self):
        usps_dataset = USPS(usps_data_dir, train=True,
                            transform=transforms.Compose([
                                transforms.ToTensor(),  # Gives a scaled version i.e., 0 to 1
                                transforms.Normalize((0.5,), (0.5,)),
                            ]))
        usps_dataloader = DataLoader(usps_dataset, self.bs, shuffle=True, num_workers=num_cpus,
                                     collate_fn=custom_collate_fn)
        mnist_dataset = MNIST(mnist_data_dir, train=True,
                              transform=transforms.Compose([
                                  transforms.ToTensor(),  # Gives a scaled version i.e., 0 to 1
                                  transforms.Normalize((0.5,), (0.5,)),
                              ]))
        mnist_dataloader = DataLoader(mnist_dataset, self.bs, shuffle=True, num_workers=num_cpus,
                                      collate_fn=custom_collate_fn)
        return {
            'usps': usps_dataloader,
            'mnist': mnist_dataloader,
        }

    def val_dataloader(self):
        usps_dataset = USPS(usps_data_dir, train=False,
                            transform=transforms.Compose([
                                transforms.ToTensor(),  # Gives a scaled version i.e., 0 to 1
                                transforms.Normalize((0.5,), (0.5,)),
                            ]))
        usps_dataloader = DataLoader(usps_dataset, self.bs, shuffle=False, num_workers=num_cpus,
                                     collate_fn=custom_collate_fn)
        mnist_dataset = MNIST(mnist_data_dir, train=False,
                              transform=transforms.Compose([
                                  transforms.ToTensor(),  # Gives a scaled version i.e., 0 to 1
                                  transforms.Normalize((0.5,), (0.5,)),
                              ]))
        mnist_dataloader = DataLoader(mnist_dataset, self.bs, shuffle=False, num_workers=num_cpus,
                                      collate_fn=custom_collate_fn)
        return CombinedLoader({
            'usps': usps_dataloader,
            'mnist': mnist_dataloader,
        }, mode='max_size_cycle')

    def test_dataloader(self):
        usps_dataset = USPS(usps_data_dir, train=False,
                            transform=transforms.Compose([
                                transforms.ToTensor(),  # Gives a scaled version i.e., 0 to 1
                                transforms.Normalize((0.5,), (0.5,)),
                            ]))
        usps_dataloader = DataLoader(usps_dataset, self.bs, shuffle=False, num_workers=num_cpus,
                                     collate_fn=custom_collate_fn)
        mnist_dataset = MNIST(mnist_data_dir, train=False,
                              transform=transforms.Compose([
                                  transforms.ToTensor(),  # Gives a scaled version i.e., 0 to 1
                                  transforms.Normalize((0.5,), (0.5,)),
                              ]))
        mnist_dataloader = DataLoader(mnist_dataset, self.bs, shuffle=False, num_workers=num_cpus,
                                      collate_fn=custom_collate_fn)
        return CombinedLoader({
            'usps': usps_dataloader,
            'mnist': mnist_dataloader,
        }, mode='max_size_cycle')

    def summary(self) -> str:
        _summary_kwargs = dict(dtypes=[torch.float], depth=4, col_names=['input_size', 'output_size', 'num_params'],
                               row_settings=['depth', 'var_names'], verbose=0, device=torch.device('cpu'))
        _usps = torch.randn((10, 1, 16, 16), dtype=torch.float)
        _mnist = torch.randn((10, 1, 28, 28), dtype=torch.float)
        _summary_string = str(summary(model=self.genU2M, input_data=_usps, **_summary_kwargs)) + '\n' + \
                          str(summary(model=self.criticU, input_data=_usps, **_summary_kwargs)) + '\n' + \
                          str(summary(model=self.genM2U, input_data=_mnist, **_summary_kwargs)) + '\n' + \
                          str(summary(model=self.criticM, input_data=_mnist, **_summary_kwargs))
        return _summary_string
