<h1>Generative Adversarial Network</h1>
In this lab we will explore a popular method for data generation, the Generative Adversarial Network (GAN). A GAN is in fact a pair of networks (at least two): a Generator and a Discriminator. The Generator takes a random sample from a distribution (usually a standard normal distribution) and produces an image. The Discriminator takes an image and tries to classify it as coming either from the Generator or from the dataset of real images. The twist is that we optimize the Generator to produce images that the Discriminator classifies as "real" (coming from the dataset). <br>

<img src="MNIST_GAN_DEMO.gif" width="700" align="center">

We do this by generating an image, passing it through the Discriminator, and calculating the loss as if the fake image were real. We then backpropagate the gradients through the Discriminator into the Generator and update only the Generator. By doing this the Generator "sees" how to update itself so that the Discriminator will classify its images as "real". Because the Discriminator is at the same time learning to distinguish the Generator's "fake" images from the dataset's "real" ones, we get a "bootstrapping" effect: the Generator gets better at generating "fake" images and the Discriminator gets better at telling fake from real, until (ideally) the fake images are indistinguishable from real ones.
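The key mechanism described above (gradients flow through the Discriminator, but only the Generator is updated) can be seen in isolation with a minimal sketch. It assumes two hypothetical toy networks, `toy_G` and `toy_D`, built from single `nn.Linear` layers instead of the convolutional networks used later, and plain SGD instead of Adam:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy networks standing in for the real Generator and Discriminator
toy_G = nn.Linear(4, 8)   # latent vector (size 4) -> flattened "image" (size 8)
toy_D = nn.Linear(8, 1)   # "image" -> single raw logit
g_optimizer = torch.optim.SGD(toy_G.parameters(), lr=0.1)  # optimizer over the Generator ONLY
loss_fn = nn.BCEWithLogitsLoss()

# Snapshot the parameters so we can check which network changed
d_before = [p.detach().clone() for p in toy_D.parameters()]
g_before = [p.detach().clone() for p in toy_G.parameters()]

z = torch.randn(16, 4)
logits = toy_D(toy_G(z)).squeeze(1)       # fake images -> Discriminator logits
loss = loss_fn(logits, torch.ones(16))    # pretend the fakes are "real" (label 1)

toy_G.zero_grad()
loss.backward()      # gradients flow back through toy_D into toy_G
g_optimizer.step()   # but only toy_G's parameters are stepped

d_unchanged = all(torch.equal(b, p) for b, p in zip(d_before, toy_D.parameters()))
g_changed = any(not torch.equal(b, p) for b, p in zip(g_before, toy_G.parameters()))
print(d_unchanged, g_changed)  # True True
```

Note that `toy_D` still ends up holding gradients from this step, which is why the Discriminator's gradients have to be zeroed before its own update.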

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms
import torchvision.utils as vutils

import os
import matplotlib.pyplot as plt
import numpy as np
from IPython.display import clear_output
import imageio

<h2> Basic GAN Model</h2>

![alt text](https://miro.medium.com/max/1600/1*M_YipQF_oC6owsU1VVrfhg.jpeg)

A more formal way of describing what is going on is to characterise our dataset as a finite sample from some high-dimensional distribution $d_r$ (where every pixel value is a degree of freedom). Our Generator maps samples from an input distribution $N(0,1)$ to a distribution of generated images $d_f$. The Discriminator tries to learn to distinguish $d_r$ from $d_f$, while the Generator tries to make $d_f = d_r$, at which point the Discriminator cannot tell the difference.

<h3> Generator </h3>
Our Generator is a simple transpose convolution network: it takes a vector as input and upsamples it several times until it reaches the desired size.<br>
NOTE: We give each latent vector the shape of a 3D tensor (channels = latent size, height = width = 1) instead of a 1D vector, because then we can simply use transpose convolution layers for everything instead of having to reshape the input.

In [None]:
class generator(nn.Module):
    # initializers
    def __init__(self, z = 64, ch = 16):
        super(generator, self).__init__()
        self.deconv1 = nn.ConvTranspose2d(z, ch*8, 2, 2)
        self.deconv2 = nn.ConvTranspose2d(ch*8, ch*4, 4, 2, 1)
        self.deconv2_bn = nn.BatchNorm2d(ch*4)
        self.deconv3 = nn.ConvTranspose2d(ch*4, ch*2, 4, 2, 1)
        self.deconv3_bn = nn.BatchNorm2d(ch*2)
        self.deconv4 = nn.ConvTranspose2d(ch*2, ch, 4, 2, 1)
        self.deconv4_bn = nn.BatchNorm2d(ch)
        self.deconv5 = nn.ConvTranspose2d(ch, 1, 4, 2, 1)

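    # forward method: (B, z, 1, 1) -> (B, 1, 32, 32)
    def forward(self, x):
        x = F.relu(self.deconv1(x))                    # (B, ch*8, 2, 2)
        x = F.relu(self.deconv2_bn(self.deconv2(x)))   # (B, ch*4, 4, 4)
        x = F.relu(self.deconv3_bn(self.deconv3(x)))   # (B, ch*2, 8, 8)
        x = F.relu(self.deconv4_bn(self.deconv4(x)))   # (B, ch, 16, 16)
        x = torch.tanh(self.deconv5(x))                # (B, 1, 32, 32), values in (-1, 1)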

        return x
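To check that this stack of layers really turns a 1×1 input into a 32×32 image, we can trace the spatial size through each layer. This is a small sketch using the output-size formula from the PyTorch `ConvTranspose2d` documentation; the helper names are hypothetical, and the (kernel, stride, padding) values are copied from `deconv1` to `deconv5` above:

```python
def conv_transpose2d_out(size, kernel, stride=1, padding=0, output_padding=0, dilation=1):
    """Output height/width of nn.ConvTranspose2d (formula from the PyTorch docs)."""
    return (size - 1) * stride - 2 * padding + dilation * (kernel - 1) + output_padding + 1

# (kernel, stride, padding) for deconv1 ... deconv5 in the generator above
GENERATOR_LAYERS = [(2, 2, 0), (4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 2, 1)]

def trace_generator(size=1):
    """Return the spatial size before the first layer and after every layer."""
    sizes = [size]
    for kernel, stride, padding in GENERATOR_LAYERS:
        size = conv_transpose2d_out(size, kernel, stride, padding)
        sizes.append(size)
    return sizes

print(trace_generator())  # [1, 2, 4, 8, 16, 32]
```

Each `(4, 2, 1)` layer exactly doubles the size ((s - 1)·2 - 2 + 4 = 2s), which is why this kernel/stride/padding combination is a common choice for upsampling.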

<h3> Discriminator </h3>

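The Discriminator is a simple convolutional classifier network with a single output.<br>
NOTE: The output for each image is still a 3D tensor (of shape 1×1×1) holding that single value, so we squeeze it before calculating the loss. It is a raw logit (no sigmoid), because the loss function we use, `BCEWithLogitsLoss`, applies the sigmoid itself.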

In [None]:
class discriminator(nn.Module):
    # initializers
    def __init__(self, ch = 16):
        super(discriminator, self).__init__()
        self.conv1 = nn.Conv2d(1, ch, 4, 2, 1)
        self.conv1_bn = nn.BatchNorm2d(ch)
        self.conv2 = nn.Conv2d(ch, ch*2, 4, 2, 1)
        self.conv2_bn = nn.BatchNorm2d(ch*2)
        self.conv3 = nn.Conv2d(ch*2, ch*4, 4, 2, 1)
        self.conv3_bn = nn.BatchNorm2d(ch*4)
        self.conv4 = nn.Conv2d(ch*4, ch*8, 4, 2, 1)
        self.conv4_bn = nn.BatchNorm2d(ch*8)
        self.conv5 = nn.Conv2d(ch*8, 1, 2, 2)

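    # forward method: (B, 1, 32, 32) -> (B, 1, 1, 1) raw logit
    # NOTE: conv1_bn, conv2_bn and conv3_bn are defined above but not used here
    def forward(self, x):
        x = F.leaky_relu(self.conv1(x), 0.2)                   # (B, ch, 16, 16)
        x = F.leaky_relu(self.conv2(x), 0.2)                   # (B, ch*2, 8, 8)
        x = F.leaky_relu(self.conv3(x), 0.2)                   # (B, ch*4, 4, 4)
        x = F.leaky_relu(self.conv4_bn(self.conv4(x)), 0.2)    # (B, ch*8, 2, 2)
        x = self.conv5(x)                                      # (B, 1, 1, 1); no sigmoid, BCEWithLogitsLoss applies it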

        return x
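The same kind of trace works in the other direction for the Discriminator. This sketch uses the `Conv2d` output-size formula from the PyTorch documentation; the helper names are hypothetical, and the (kernel, stride, padding) values are copied from `conv1` to `conv5` above:

```python
def conv2d_out(size, kernel, stride=1, padding=0, dilation=1):
    """Output height/width of nn.Conv2d (formula from the PyTorch docs)."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1

# (kernel, stride, padding) for conv1 ... conv5 in the discriminator above
DISCRIMINATOR_LAYERS = [(4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 2, 1), (2, 2, 0)]

def trace_discriminator(size=32):
    """Return the spatial size before the first layer and after every layer."""
    sizes = [size]
    for kernel, stride, padding in DISCRIMINATOR_LAYERS:
        size = conv2d_out(size, kernel, stride, padding)
        sizes.append(size)
    return sizes

print(trace_discriminator())    # [32, 16, 8, 4, 2, 1]
print(trace_discriminator(28))  # [28, 14, 7, 3, 1, 0]
```

With raw 28×28 MNIST digits the input shrinks to 1×1 one layer early, and the final 2×2 kernel no longer fits (PyTorch would raise an error), which is consistent with the images being resized to `img_size = 32` in the data pipeline.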

<b> Use a GPU if available </b>

In [None]:
use_cuda = torch.cuda.is_available()
GPU_indx  = 0
device = torch.device(GPU_indx if use_cuda else "cpu")

<h3>Define Hyperparameters, Network and Optimizer</h3>
GANs are notoriously unstable and sensitive to hyperparameters; see how much you can improve the generated output by tuning them.

In [None]:
# training parameters
batch_size = 512
dlr = 1e-4
glr = 2e-4

train_epoch = 100

# data_loader
img_size = 32

latent_noise_dem = 128

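#Resize MNIST's 28x28 digits to 32x32 and map pixel values from [0, 1] to [-1, 1]
#((x - 0.5) / 0.5), matching the range of the Generator's tanh output
transform = transforms.Compose([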
                                transforms.Resize(img_size),
                                transforms.ToTensor(),
                                transforms.Normalize(mean=([0.5]), std=([0.5]))
                                ])

trainset = datasets.MNIST('data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers = 4)

# network
G = generator(latent_noise_dem, 16).to(device)
D = discriminator(16).to(device)

#A fixed latent noise vector so we can see the improvement over the epochs
fixed_latent_noise = torch.randn(16, latent_noise_dem, 1, 1).to(device)
# Binary Cross Entropy loss
BCE_loss = nn.BCEWithLogitsLoss()

# Adam optimizer
G_optimizer = optim.Adam(G.parameters(), lr=glr, betas=(0.5, 0.999))
D_optimizer = optim.Adam(D.parameters(), lr=dlr, betas=(0.5, 0.999))

<b>Create a save folder</b>

In [None]:
if not os.path.isdir('MNIST_DCGAN_results'):
    os.mkdir('MNIST_DCGAN_results')
    
test_images_log = []
D_losses = []
G_losses = []

D_out_fake = []
D_out_real = []

<h3>The main training loop</h3>


The overall objective function is: ![alt text](https://miro.medium.com/max/1500/1*l9se1koH_eQdZesko5eQpw.jpeg)

This is also formalised as a "minimax" game, in which the Generator tries to minimise the above objective against a Discriminator that tries to maximise it.
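For reference, here is the standard GAN objective (as introduced by Goodfellow et al., 2014) written in this lab's notation; we assume it matches the image above. $D(\cdot)$ denotes the Discriminator's output after a sigmoid $\sigma$, $G(z)$ a generated image, and $\mathcal{L}_D$, $\mathcal{L}_G$ the losses actually minimised by the two networks:

$$
\min_G \max_D \; V(D, G) = \mathbb{E}_{x \sim d_r}\big[\log D(x)\big] + \mathbb{E}_{z \sim N(0,1)}\big[\log\big(1 - D(G(z))\big)\big]
$$

$$
\mathcal{L}_D = -\mathbb{E}_{x \sim d_r}\big[\log D(x)\big] - \mathbb{E}_{z \sim N(0,1)}\big[\log\big(1 - D(G(z))\big)\big] = -V(D, G),
\qquad
\mathcal{L}_G = -\mathbb{E}_{z \sim N(0,1)}\big[\log D(G(z))\big]
$$

In the training loop, `D_train_loss` is a batch estimate of $\mathcal{L}_D$, so minimising it is the same as maximising $V$. `G_train_loss` scores the fake images against the "real" label, which gives the non-saturating loss $\mathcal{L}_G$ instead of directly minimising $\log(1 - D(G(z)))$; this variant, also suggested in the original GAN paper, gives the Generator stronger gradients early in training, when the Discriminator easily rejects its samples.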

We use the binary cross entropy loss because what the Discriminator does is essentially binary classification: real or fake. At the point where $d_f = d_r$ the output of our Discriminator (after the sigmoid) should be 0.5, halfway between 0 and 1, which corresponds to a raw logit of 0 (note that this rarely actually happens in practice).<br>
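To make "the output should be 0.5" concrete in terms of the loss values printed during training, here is a small pure-Python sketch of binary cross entropy on a raw logit. The helper `bce_with_logit` is hypothetical; it computes, for a single element, the same quantity as `BCEWithLogitsLoss`, written in a numerically stable form:

```python
import math

def bce_with_logit(logit, target):
    """Binary cross entropy of sigmoid(logit) against a 0/1 target (numerically stable form)."""
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))

# A Discriminator that outputs logit 0 (probability 0.5) for every image:
d_loss = bce_with_logit(0.0, 1.0) + bce_with_logit(0.0, 0.0)  # real term + fake term, like D_train_loss
g_loss = bce_with_logit(0.0, 1.0)                             # fake scored as "real", like G_train_loss

print(round(d_loss, 4), round(g_loss, 4))  # 1.3863 0.6931
```

So the idealised equilibrium would show a `D_loss` of about 1.39 (that is, 2 log 2) and a `G_loss` of about 0.69 (log 2); a `D_loss` far below that means the Discriminator is confidently winning.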


GANs can be difficult to train. Some common reasons are:
1. At the beginning, the distributions of real and fake samples are far apart, so it can be easy for the Discriminator to tell fake from real.

2. The Discriminator becomes too good and stops providing useful gradients back to the Generator.

3. The Discriminator becomes bad and cannot tell the difference between real and fake samples, even though the output from the Generator is bad.

4. The Generator overpowers the Discriminator, and the Discriminator cannot tell real from fake (even though $d_f \neq d_r$).

5. The Generator starts outputting only a very few distinct images (low variation), a failure known as mode collapse. Because the Discriminator only judges whether each individual image looks real, nothing corrects the lack of variety.
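Some of these failure modes can be spotted from the mean Discriminator outputs that the training loop logs in `D_out_real` and `D_out_fake` (mean raw logits per batch). The sketch below is a rough, hypothetical heuristic, not a standard tool: the function name, the window and both thresholds are arbitrary assumptions, and applying a sigmoid to a mean logit only approximates the mean probability. It cannot detect the Generator producing only a few distinct images; for that you still need to look at the samples.

```python
import math

def sigmoid(x):
    # Numerically stable logistic function
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)

def diagnose(real_logits, fake_logits, window=100, confident=0.95, close=0.1):
    """Hypothetical heuristic over the last `window` logged mean logits."""
    recent_real = real_logits[-window:]
    recent_fake = fake_logits[-window:]
    p_real = sigmoid(sum(recent_real) / len(recent_real))
    p_fake = sigmoid(sum(recent_fake) / len(recent_fake))
    if p_real > confident and p_fake < 1.0 - confident:
        return "discriminator dominating"
    if abs(p_real - p_fake) < close:
        # Either the Discriminator is too weak or the networks are near equilibrium
        return "discriminator not separating real from fake"
    return "no obvious imbalance"

print(diagnose([6.0] * 10, [-6.0] * 10))  # discriminator dominating
print(diagnose([0.1] * 10, [-0.1] * 10))  # discriminator not separating real from fake
```

After training, it could be called as `diagnose(D_out_real, D_out_fake)`.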

In [None]:
for epoch in range(train_epoch):
    for num_iter, data in enumerate(train_loader):
        images, _ = data
        images = images.to(device)
        
        #the size of the current minibatch
        mini_batch = images.size()[0]
        
        #Create the "real" and "fake" labels
        #this is what the discriminator should ideally produce
        label_real = torch.ones(mini_batch).to(device)
        label_fake = torch.zeros(mini_batch).to(device)        
        
        ########### Train Discriminator D! ############
        
        #Step1: Sample a latent vector from a normal distribution and pass it through the generator
        #to get a batch of fake images
        latent_noise = torch.randn(mini_batch, latent_noise_dem, 1, 1).to(device)
        G_output = G(latent_noise)
        
        #Step2: Pass the minibatch of real images through the Discriminator and calculate
        #the loss against the "real" label
        #Add a little Gaussian noise so the Discriminator cannot tell real images apart
        #by their discrete pixel values
        input_noise = 0.01*torch.randn_like(images)
        D_real_out = D(images+input_noise).squeeze()
        D_real_loss = BCE_loss(D_real_out, label_real)
        D_out_real.append(D_real_out.mean().item())
        
        #Step3: Pass the minibatch of fake images (from the Generator) through the Discriminator and calculate
        #the loss against the "fake" label
        #We "detach()" the output of the Generator here as we don't need it to backpropagate through the
        #Generator in this step
        #Add the same kind of small noise to the fake images so real and fake inputs
        #are treated identically
        input_noise = 0.01*torch.randn_like(images)
        D_fake_out = D(G_output.detach() + input_noise).squeeze()
        D_fake_loss = BCE_loss(D_fake_out, label_fake)
        D_out_fake.append(D_fake_out.mean().item())

        #Step4: Add the two losses together, backpropagate through the Discriminator and take a training step
        D_train_loss = D_real_loss + D_fake_loss

        D.zero_grad()
        D_train_loss.backward()
        D_optimizer.step()
        #log the discriminator training loss
        D_losses.append(D_train_loss.item())
                
        ########### Train Generator G ##############
        
        #Step1: Sample a latent vector from a normal distribution and pass it through the generator
        #to get a batch of fake images
        latent_noise = torch.randn(mini_batch, latent_noise_dem, 1, 1).to(device)
        G_output = G(latent_noise)
        
        #Step2: Pass the minibatch of fake images (from the Generator) through the Discriminator and calculate
        #the loss against the "real" label: the Generator wants the Discriminator to think its outputs are real
        D_result = D(G_output).squeeze()
        G_train_loss = BCE_loss(D_result, label_real)
        
        #Step3: Backpropagate the loss through the Discriminator into the Generator and take a training step
        #(only G_optimizer steps, so the Discriminator's weights are not changed here)
        G.zero_grad()
        G_train_loss.backward()
        G_optimizer.step()
        
        #log the generator training loss
        G_losses.append(G_train_loss.item())
        
        clear_output(True)
        #Print out the training status
        print('Epoch [%d/%d], Step [%d/%d], D_loss: %.4f, G_loss: %.4f'
              % (epoch+1, train_epoch, num_iter+1, len(train_loader), D_train_loss.item(), G_train_loss.item()))
                
    #save both networks
    torch.save(G.state_dict(), "MNIST_DCGAN_results/generator_param.pt")
    torch.save(D.state_dict(), "MNIST_DCGAN_results/discriminator_param.pt")
    
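    #log the output of the generator given the fixed latent noise vector
    #(no gradients are needed; rescale the tanh output from [-1, 1] to [0, 1] before converting to uint8)
    with torch.no_grad():
        test_fake = G(fixed_latent_noise)
    test_fake = ((test_fake.cpu() + 1) / 2).clamp(0, 1)
    imgs_np = (torchvision.utils.make_grid(test_fake, 4, pad_value = 0.5).numpy().transpose((1, 2, 0))*255).round().astype(np.uint8)
    test_images_log.append(imgs_np)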


<h3>Plot out the losses and visualize the generated images</h3>
Notice how noisy the losses for both the Generator and Discriminator are?
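One way to see the trend under this noise is to plot a moving average alongside the raw curve. A minimal pure-Python sketch (the helper name `moving_average` and the window size are arbitrary choices):

```python
def moving_average(values, window=100):
    """Mean of each run of `window` consecutive values (a simple running-sum implementation)."""
    if window < 1:
        raise ValueError("window must be at least 1")
    out = []
    running = 0.0
    for i, v in enumerate(values):
        running += v
        if i >= window:
            running -= values[i - window]  # drop the value that left the window
        if i >= window - 1:
            out.append(running / window)
    return out

print(moving_average([1, 2, 3, 4], window=2))  # [1.5, 2.5, 3.5]
```

For example, `plt.plot(D_losses)` followed by `plt.plot(moving_average(D_losses, 100))`. The smoothed curve has `window - 1` fewer points, so it is shifted slightly to the left relative to the raw one.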

In [None]:
plt.plot(D_losses)

In [None]:
plt.plot(G_losses)

In [None]:
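with torch.no_grad():
    test_fake = G(fixed_latent_noise)
plt.figure(figsize = (20,10))
#rescale from the Generator's tanh range [-1, 1] to [0, 1] so imshow does not clip the values
out = vutils.make_grid((test_fake.cpu() + 1) / 2, 4)
plt.imshow(out.numpy().transpose((1, 2, 0)))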

<b>Let's create a GIF of our generated images throughout training</b>

In [None]:
imageio.mimsave('MNIST_GAN.gif', test_images_log)