# Lab 5 - Conditional models

Plan for today:
* we will recall how [GANs](https://arxiv.org/abs/1406.2661) work
* we will implement a GAN and then a [Conditional Generative Adversarial Network](https://arxiv.org/abs/1411.1784)


In [None]:
import torch
from torch import nn
from torch.optim import SGD, Adam
from torch.utils.data import DataLoader, Subset
from torchvision.datasets import MNIST
from torchvision import transforms as tv
from torchvision.models import vgg16, vgg16_bn, resnet50, resnet18
from typing import List, Tuple
from collections import Counter
from time import sleep
from tqdm import tqdm
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import accuracy_score
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from sklearn.cluster import KMeans

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

In [None]:
ds_train = MNIST(root="data", train=True, download=True, transform=tv.ToTensor())

ds_test = MNIST(root="data", train=False, download=True, transform=tv.ToTensor())

In [None]:
batch_size = 128
dl_train = DataLoader(ds_train, batch_size, shuffle=True, drop_last=False)  # dataloader over the full training set

dl_test = DataLoader(ds_test, batch_size, shuffle=False)

## 1. A quick recap on GANs

When training a GAN, we actually train two networks:
* generator $G$ - transforms noise vectors $z$ into samples $x$
* discriminator $D$ - a binary classifier which classifies samples as *real* or *fake*.


These two networks play an **adversarial game**:
* the generator tries to produce samples that fool the discriminator
* the discriminator tries to learn to distinguish real samples from fake ones


Let's discuss how the training procedure should look.
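
For reference, the GAN paper linked above formalizes this game as a minimax problem over a value function $V(D, G)$, where $D(x)$ is the probability the discriminator assigns to $x$ being real and $p_z$ is the noise distribution:

$$\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}}\left[\log D(x)\right] + \mathbb{E}_{z \sim p_z}\left[\log\left(1 - D(G(z))\right)\right]$$

In practice the two networks are updated in alternation. The discriminator step minimizes the binary cross-entropy with target 1 for real samples and 0 for generated ones, which is exactly $-V(D, G)$:

$$\mathcal{L}_D = -\mathbb{E}_{x \sim p_{\text{data}}}\left[\log D(x)\right] - \mathbb{E}_{z \sim p_z}\left[\log\left(1 - D(G(z))\right)\right]$$

For the generator, the paper suggests maximizing $\log D(G(z))$ instead of minimizing $\log(1 - D(G(z)))$, because the latter saturates early in training when $D$ rejects generated samples with high confidence. This "non-saturating" loss is the binary cross-entropy with target 1 on generated samples, which is what the training loop below asks for:

$$\mathcal{L}_G = -\mathbb{E}_{z \sim p_z}\left[\log D(G(z))\right]$$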

### Task 1 - train a vanilla GAN on MNIST


In [None]:
class Generator(nn.Module):
  def __init__(self, noise_dim: int, ...):
    """
    noise_dim - size of the noise vector
    
    """
    super().__init__()
    
    ### YOUR CODE HERE
    # initialize necessary submodules  
    ###

  def forward(self, noise_vectors):
    """
    noise_vectors: tensor of shape [b, `noise_dim`]

    returns:
      samples of shape [b, c, h, w] 
    """

    ### YOUR CODE HERE
    # * process noise vectors into generated samples
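    # hint - since we're dealing with MNIST, it may be useful to apply a sigmoid activation
    # to the network output to squash it into (0, 1), matching the range produced by ToTensor()
    # (tanh squashes into (-1, 1), so use it only if you also rescale the real images to [-1, 1])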
    ###


class Discriminator(nn.Module):
    def __init__(self, ...):
        super().__init__()
        ### YOUR CODE HERE
        # initialize necessary submodules  
        ###
    

    def forward(self, images):
        """
        images: samples of shape [b, c, h, w] 

        returns:
          binary classification of shape [b] - whether each image is real (1) or fake (0)
        """
        
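
One possible fully connected design, as a minimal sketch rather than the reference solution: the generator maps noise to 784 pixels passed through a sigmoid and reshaped to [b, 1, 28, 28], and the discriminator flattens the image and outputs one raw logit per sample. The class names `MLPGenerator`/`MLPDiscriminator`, the hidden width 256, the LeakyReLU slope 0.2 and the dropout rate are illustrative assumptions. Returning logits (no final sigmoid) pairs with `nn.BCEWithLogitsLoss`, which is numerically more stable than a sigmoid followed by `nn.BCELoss`.

```python
import torch
from torch import nn


class MLPGenerator(nn.Module):
    """Hypothetical fully connected generator: noise [b, noise_dim] -> images [b, 1, 28, 28]."""

    def __init__(self, noise_dim: int, hidden_dim: int = 256, img_shape=(1, 28, 28)):
        super().__init__()
        self.img_shape = img_shape
        n_pixels = img_shape[0] * img_shape[1] * img_shape[2]
        self.net = nn.Sequential(
            nn.Linear(noise_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim, hidden_dim * 2),
            nn.BatchNorm1d(hidden_dim * 2),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim * 2, n_pixels),
            nn.Sigmoid(),  # squash into (0, 1) to match the ToTensor() pixel range
        )

    def forward(self, noise_vectors: torch.Tensor) -> torch.Tensor:
        # reshape the flat pixel vector back into an image tensor
        return self.net(noise_vectors).view(-1, *self.img_shape)


class MLPDiscriminator(nn.Module):
    """Hypothetical fully connected discriminator: images -> one real/fake logit per sample."""

    def __init__(self, hidden_dim: int = 256, img_shape=(1, 28, 28)):
        super().__init__()
        n_pixels = img_shape[0] * img_shape[1] * img_shape[2]
        self.net = nn.Sequential(
            nn.Flatten(),
            nn.Linear(n_pixels, hidden_dim * 2),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim, 1),  # raw logit, no sigmoid
        )

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        # drop the trailing singleton dim so the output has shape [b]
        return self.net(images).squeeze(1)
```

For example, `MLPGenerator(noise_dim=64)(torch.randn(8, 64))` returns a [8, 1, 28, 28] batch that `MLPDiscriminator()` maps to 8 logits.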

In [None]:
def train(gen: Generator, disc: Discriminator, train_loader, num_epochs: int):
  for e in range(num_epochs):
    for (real_images, _) in train_loader:
      real_images = real_images.to(device)

      ### Training discriminator (to discriminate between real / fake images)
      # * sample from the generator the same number of fake images as real ones
      # * minimize the binary cross-entropy loss of the discriminator classifying real and fake images as 1s and 0s, respectively
      #   * remember to .detach() the generated images before passing them to the discriminator - we don't want the gradient of this loss to flow through the generator!
      ###

      ### Training generator (to fool the discriminator into classifying fake images as real)
      # * minimize the binary cross-entropy loss of the discriminator classifying the fake images as 1s (the reverse of what we did in the discriminator training step)
      ###

      # hint - it is useful to keep separate optimizers for the generator and discriminator, so as not to mix up gradients
      # it is also useful to plot example generator outputs every few epochs throughout the training
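
The steps listed in the comments can be condensed into a single training iteration. This is a minimal sketch, assuming `gen` maps noise of shape [b, noise_dim] to images and `disc` returns raw logits of shape [b] (hence `nn.BCEWithLogitsLoss`); the helper name `gan_train_step` is hypothetical, and the two optimizers are created by the caller, one per network.

```python
import torch
from torch import nn


def gan_train_step(gen, disc, real_images, opt_g, opt_d, noise_dim, device="cpu"):
    """One hypothetical GAN iteration; returns (discriminator loss, generator loss)."""
    bce = nn.BCEWithLogitsLoss()
    real_images = real_images.to(device)
    b = real_images.size(0)
    ones = torch.ones(b, device=device)
    zeros = torch.zeros(b, device=device)

    # Discriminator step: real -> 1, fake -> 0.
    # detach() keeps this loss from backpropagating into the generator.
    fake_images = gen(torch.randn(b, noise_dim, device=device))
    d_loss = bce(disc(real_images), ones) + bce(disc(fake_images.detach()), zeros)
    opt_d.zero_grad()
    d_loss.backward()
    opt_d.step()

    # Generator step (non-saturating loss): push D's output on fakes towards 1.
    # Gradients that land in D here are cleared by opt_d.zero_grad() next iteration.
    g_loss = bce(disc(fake_images), ones)
    opt_g.zero_grad()
    g_loss.backward()
    opt_g.step()
    return d_loss.item(), g_loss.item()
```

Inside `train`, this would be called once per batch, e.g. with `opt_g = Adam(gen.parameters(), lr=2e-4)` and `opt_d = Adam(disc.parameters(), lr=2e-4)`; the learning rate here is only a common starting point, not a value from the lab.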

## 2 - Making the GANs conditional


## How is this related to conditional models?

When training a classical GAN, we sample the noise vectors from a distribution of our choice, so we have little control over what the generator will produce. To help with that, we can inject a **condition**, e.g. the class of the sample, into the model.

This is simple to achieve. Both the discriminator and the generator get an additional module that maps the condition (e.g. a one-hot class vector) to a vector of the desired shape, which is then concatenated with the model input:
* for $G$, concatenate the processed label with the noise vector $z$
* for $D$, concatenate the processed label with the image being classified (e.g. reshaped into an additional channel)
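
A minimal sketch of both concatenations, assuming a batch of one-hot labels, a single linear layer as the label-processing module, and MNIST-sized 28×28 images. The embedding size 16 on the generator side is an arbitrary illustrative choice; on the discriminator side the label is projected to 28 * 28 values so it can be reshaped into one extra image channel.

```python
import torch
from torch import nn

b, n_classes, noise_dim = 4, 10, 64
labels = torch.tensor([0, 3, 7, 9])
one_hot = nn.functional.one_hot(labels, num_classes=n_classes).float()  # [b, 10]

# Generator side: embed the label and append it to the noise vector.
g_label_proj = nn.Linear(n_classes, 16)
z = torch.randn(b, noise_dim)
g_input = torch.cat([z, g_label_proj(one_hot)], dim=1)  # [b, noise_dim + 16]

# Discriminator side: map the label to a 28x28 map and stack it as a second channel.
d_label_proj = nn.Linear(n_classes, 28 * 28)
images = torch.rand(b, 1, 28, 28)
label_channel = d_label_proj(one_hot).view(b, 1, 28, 28)
d_input = torch.cat([images, label_channel], dim=1)  # [b, 2, 28, 28]

print(g_input.shape, d_input.shape)  # torch.Size([4, 80]) torch.Size([4, 2, 28, 28])
```

Note that the first convolution of a conditional discriminator built this way must accept 2 input channels instead of 1.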

## Task for you - implement a cGAN
* implement two variants of generators / discriminators
  * built with linear layers (+ activations, batchnorm, dropout, etc)
  * built with convolutional layers (+ activations, batchnorm, dropout, etc), except for the final linear layer of the discriminator

You can use the code below as a start:


In [None]:
class CGenerator(nn.Module):
  def __init__(self, n_classes: int, noise_dim: int, ...):
    """
    n_classes - size of the one-hot condition vector
    noise_dim - size of the noise vector
    
    """
    super().__init__()
    
    ### YOUR CODE HERE
    # initialize necessary submodules  
    ###

  def forward(self, noise_vectors, label_vectors):
    """
    noise_vectors: tensor of shape [b, `noise_dim`]
    label_vectors: one-hot tensor of shape [b, `n_classes`]

    returns:
      samples of shape [b, c, h, w] 
    """

    ### YOUR CODE HERE
    # * concatenate noise and label vectors
    # * process them into generated samples
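    # hint - since we're dealing with MNIST, it may be useful to apply a sigmoid activation
    # to the network output to squash it into (0, 1), matching the range produced by ToTensor()
    # (tanh squashes into (-1, 1), so use it only if you also rescale the real images to [-1, 1])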
    ###


class CDiscriminator(nn.Module):
    def __init__(self, n_classes: int, ...):
        super().__init__()
        ### YOUR CODE HERE
        # initialize necessary submodules  
        ###
    

    def forward(self, images, label_vectors):
        """
        images: samples of shape [b, c, h, w] 
        label_vectors: one-hot tensor of shape [b, n_classes]

        returns:
          binary classification of shape [b] - whether each (image, label) pair is real (1) or fake (0)
        """
        
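
For the convolutional variant, one possible layout (a sketch, not the reference solution) is a DCGAN-style generator that treats the concatenated noise and label embedding as a 1×1 feature map and upsamples it to 28×28 with `nn.ConvTranspose2d`, and a discriminator that receives the label as an extra input channel and ends in a single linear layer. The class names `ConvCGenerator`/`ConvCDiscriminator`, the channel counts and the 16-dim label embedding are illustrative assumptions; the linear `label_proj` layers play the role of the condition-processing modules described above.

```python
import torch
from torch import nn


class ConvCGenerator(nn.Module):
    """Hypothetical conv generator: (noise, one-hot label) -> [b, 1, 28, 28] in (0, 1)."""

    def __init__(self, n_classes: int, noise_dim: int, label_dim: int = 16):
        super().__init__()
        self.label_proj = nn.Linear(n_classes, label_dim)  # condition-processing module
        self.deconv = nn.Sequential(
            nn.ConvTranspose2d(noise_dim + label_dim, 128, kernel_size=7),  # 1x1 -> 7x7
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),  # 7 -> 14
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 1, kernel_size=4, stride=2, padding=1),  # 14 -> 28
            nn.Sigmoid(),
        )

    def forward(self, noise_vectors, label_vectors):
        x = torch.cat([noise_vectors, self.label_proj(label_vectors)], dim=1)
        return self.deconv(x.view(x.size(0), -1, 1, 1))  # [b, c] -> [b, c, 1, 1]


class ConvCDiscriminator(nn.Module):
    """Hypothetical conv discriminator: (image, one-hot label) -> one logit per sample."""

    def __init__(self, n_classes: int, img_size: int = 28):
        super().__init__()
        self.img_size = img_size
        self.label_proj = nn.Linear(n_classes, img_size * img_size)  # label -> extra channel
        self.conv = nn.Sequential(
            nn.Conv2d(2, 64, kernel_size=4, stride=2, padding=1),  # 28 -> 14
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),  # 14 -> 7
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
        )
        self.fc = nn.Linear(128 * 7 * 7, 1)  # the final linear layer (assumes 28x28 input)

    def forward(self, images, label_vectors):
        b = images.size(0)
        label_map = self.label_proj(label_vectors).view(b, 1, self.img_size, self.img_size)
        h = self.conv(torch.cat([images, label_map], dim=1))
        return self.fc(h.flatten(1)).squeeze(1)  # raw logits of shape [b]
```

The transposed-convolution sizes follow $(H_{in} - 1) \cdot \text{stride} - 2 \cdot \text{padding} + \text{kernel}$, e.g. $(7 - 1) \cdot 2 - 2 + 4 = 14$.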

In [None]:
def train(gen: CGenerator, disc: CDiscriminator, train_loader, num_epochs: int, n_classes: int = 10):
  for e in range(num_epochs):
    for (real_images, labels) in train_loader:
      real_images = real_images.to(device)
      labels_one_hot = nn.functional.one_hot(labels, num_classes=n_classes).to(device).float()

      ### Training discriminator (to discriminate between real / fake images)
      # * sample from the generator the same number of fake images as real ones
      # * minimize the binary cross-entropy loss of the discriminator classifying real and fake images as 1s and 0s, respectively
      #   * remember to .detach() the generated images before passing them to the discriminator - we don't want the gradient of this loss to flow through the generator!
      ###

      ### Training generator (to fool the discriminator into classifying fake images as real)
      # * minimize the binary cross-entropy loss of the discriminator classifying the fake images as 1s (the reverse of what we did in the discriminator training step)
      ###

      # hint - it is useful to keep separate optimizers for the generator and discriminator, so as not to mix up gradients
      # it is also useful to plot example generator outputs every few epochs throughout the training
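
Compared with the vanilla GAN, the only change in a training iteration is that both networks also receive the label vectors. A minimal sketch, assuming `gen(noise, labels)` and `disc(images, labels)` as in the skeleton above with `disc` returning raw logits of shape [b]; the helper name `cgan_train_step` is hypothetical. Conditioning the fakes on the real batch's labels is one common choice (sampling random labels is another), and it keeps the class distribution of real and fake pairs identical.

```python
import torch
from torch import nn


def cgan_train_step(gen, disc, real_images, labels_one_hot, opt_g, opt_d, noise_dim, device="cpu"):
    """One hypothetical cGAN iteration; fakes are conditioned on the batch's real labels."""
    bce = nn.BCEWithLogitsLoss()
    real_images, y = real_images.to(device), labels_one_hot.to(device)
    b = real_images.size(0)
    ones, zeros = torch.ones(b, device=device), torch.zeros(b, device=device)

    # Discriminator: (real image, its label) -> 1, (fake image, same label) -> 0
    fake_images = gen(torch.randn(b, noise_dim, device=device), y)
    d_loss = bce(disc(real_images, y), ones) + bce(disc(fake_images.detach(), y), zeros)
    opt_d.zero_grad()
    d_loss.backward()
    opt_d.step()

    # Generator: make the discriminator accept (fake image, label) pairs as real
    g_loss = bce(disc(fake_images, y), ones)
    opt_g.zero_grad()
    g_loss.backward()
    opt_g.step()
    return d_loss.item(), g_loss.item()
```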

## Task - draw example samples from the trained models
* for each digit class, sample an image
* for each image sample, add a title showing the digit that served as the conditioning label
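
A minimal plotting sketch, assuming a trained conditional generator `gen` with the `forward(noise_vectors, label_vectors)` interface above that returns images of shape [b, 1, h, w]; `plot_class_samples` is a hypothetical helper name. The generator is switched to `eval()` while sampling so batch-norm layers use their running statistics, and its previous mode is restored afterwards.

```python
import matplotlib.pyplot as plt
import torch
from torch import nn


@torch.no_grad()
def plot_class_samples(gen: nn.Module, noise_dim: int, n_classes: int = 10, device="cpu"):
    """Draw one generated sample per digit class, titled with its conditioning label."""
    was_training = gen.training
    gen.eval()  # batch-norm layers use running statistics while sampling
    labels = torch.arange(n_classes, device=device)
    one_hot = nn.functional.one_hot(labels, num_classes=n_classes).float()
    noise = torch.randn(n_classes, noise_dim, device=device)
    samples = gen(noise, one_hot).cpu()  # [n_classes, 1, h, w]
    gen.train(was_training)  # restore the previous mode

    fig, axes = plt.subplots(1, n_classes, figsize=(1.5 * n_classes, 2))
    for digit, ax in enumerate(axes):
        ax.imshow(samples[digit, 0].numpy(), cmap="gray")
        ax.set_title(str(digit))
        ax.axis("off")
    return fig
```

Usage: `plot_class_samples(gen.to(device), noise_dim, device=device)`, where `noise_dim` is whatever size the generator was built with.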


## Task - draw a couple of examples of interpolation between classes
For a fixed noise vector, choose two classes and draw samples generated from this noise vector combined with label vectors that gradually transition from the one-hot vector of the first class to the one-hot vector of the second.

Please print or otherwise indicate which two classes you are transitioning between.
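
One way to build the transition is to blend the two one-hot vectors linearly, $y_\alpha = (1 - \alpha)\, y_a + \alpha\, y_b$ for $\alpha$ from 0 to 1, while keeping the noise vector fixed. The sketch below assumes the same `gen(noise_vectors, label_vectors)` interface; `interpolate_classes` and `n_steps` are hypothetical names. The intermediate label vectors are not one-hot, so the generator only ever sees them at inference time, and how smooth the transition looks depends on the trained model.

```python
import matplotlib.pyplot as plt
import torch
from torch import nn


@torch.no_grad()
def interpolate_classes(gen, noise_dim, class_a, class_b, n_classes=10, n_steps=8, device="cpu"):
    """Generate images for label vectors blended linearly from class_a to class_b."""
    was_training = gen.training
    gen.eval()
    y_a = nn.functional.one_hot(torch.tensor([class_a]), n_classes).float()  # [1, n_classes]
    y_b = nn.functional.one_hot(torch.tensor([class_b]), n_classes).float()
    alphas = torch.linspace(0, 1, n_steps).unsqueeze(1)  # [n_steps, 1]
    labels = ((1 - alphas) * y_a + alphas * y_b).to(device)  # [n_steps, n_classes]
    z = torch.randn(1, noise_dim, device=device).expand(n_steps, -1)  # same noise at every step
    samples = gen(z, labels).cpu()
    gen.train(was_training)

    fig, axes = plt.subplots(1, n_steps, figsize=(1.5 * n_steps, 2))
    fig.suptitle(f"class {class_a} -> class {class_b}")  # indicates the two classes
    for ax, alpha, img in zip(axes, alphas.squeeze(1).tolist(), samples):
        ax.imshow(img[0].numpy(), cmap="gray")
        ax.set_title(f"{alpha:.2f}")
        ax.axis("off")
    return labels, fig
```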