<h1>Generative Adversarial Network</h1>
In this notebook we will explore a popular method for data generation, the Generative Adversarial Network (GAN). A GAN is in fact made up of (at least) two networks: a Generator and a Discriminator. The Generator takes a random sample from a distribution (usually a standard normal distribution) and produces an image. The Discriminator takes an image and tries to classify it as coming either from the Generator or from the dataset of real images. The twist is that we optimize the Generator to produce images that the Discriminator classifies as "real" (coming from the dataset). <br>

<img src="../data/MNIST_GAN_DEMO.gif" width="700" align="center">

We do this by generating an image, passing it through the Discriminator, and calculating the loss as if the fake image were real. We then backpropagate the gradients through the Discriminator to the Generator and update only the Generator. By doing this the Generator "sees" how to update itself so that the Discriminator will classify its images as "real". As this is happening at the same time as the Discriminator is learning to distinguish the Generator's "fake" images from the dataset's "real" ones, we get a "bootstrapping" effect. As a result the Generator gets better at generating "fake" images and the Discriminator gets better at telling fake from real, until (ideally) the fake images are indistinguishable from real ones.
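
As a minimal sketch of this update, assume a toy scalar GAN (not the networks used in this notebook) where the Generator is `G(z) = w_g * z` and the Discriminator is `D(x) = w_d * x`, returning a logit. The names `generator_step`, `w_g` and `w_d` are made up for illustration. The gradient is computed by hand with the chain rule, passing through the Discriminator, but only the Generator weight is changed.

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def generator_step(w_g, w_d, z, lr=0.1):
    """One Generator update for a toy scalar GAN (hypothetical helper)."""
    fake = w_g * z          # Generator output
    logit = w_d * fake      # Discriminator score for the fake sample
    # Binary cross entropy against the "real" label (target 1)
    loss = -math.log(sigmoid(logit))
    # Chain rule: dloss/dlogit * dlogit/dfake * dfake/dw_g
    grad_w_g = (sigmoid(logit) - 1.0) * w_d * z
    # Only the Generator weight moves; w_d is left untouched
    return w_g - lr * grad_w_g, loss

w_g, w_d, z = 0.5, 1.0, 2.0
losses = []
for _ in range(5):
    w_g, loss = generator_step(w_g, w_d, z)
    losses.append(loss)
```

With the Discriminator frozen, each step raises the score it gives to the fake sample, so the Generator's loss falls from one step to the next.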

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms
import torchvision.utils as vutils

import os
import matplotlib.pyplot as plt
import numpy as np
import imageio

from tqdm.notebook import trange, tqdm

In [None]:
# training parameters
batch_size = 512
dlr = 1e-4
glr = 1e-4

train_epoch = 100

# data_loader
img_size = 32

data_set_root = "../../datasets"

<b> Use a GPU if available </b>

In [None]:
use_cuda = torch.cuda.is_available()
gpu_indx  = 0
device = torch.device(gpu_indx if use_cuda else "cpu")

In [None]:
transform = transforms.Compose([
                                transforms.Resize(img_size),
                                transforms.ToTensor(),
                                transforms.Normalize(mean=([0.5]), std=([0.5]))
                                ])

trainset = datasets.MNIST(data_set_root, train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=4)

<h2>Adversarial Training</h2>
<img src="../data/GAN.jpg" width="700" align="center">


<h3> Generator </h3>
Our Generator is a simple convolutional network. It takes a latent vector as input, expands it to a 4x4 feature map with a transpose convolution, and then upsamples it multiple times until it reaches the desired size.<br>
NOTE: We give each latent vector the shape (z, 1, 1), a 3D tensor instead of a 1D vector, so that it can be passed straight into the first transpose convolution layer without having to reshape it.
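
The spatial sizes can be checked by hand. The sketch below is plain Python and assumes the layer settings of the Generator defined in the next cell (a 4x4 transpose convolution with stride 1 on a 1x1 input, then three nearest-neighbour upsamples). The helper names are made up for illustration.

```python
def conv_transpose_out(size, kernel, stride=1, padding=0):
    # Output size of a transpose convolution (dilation 1, no output_padding)
    return (size - 1) * stride - 2 * padding + kernel

def generator_output_size(num_upsamples=3):
    # 1x1 latent input -> 4x4 feature map
    size = conv_transpose_out(1, kernel=4, stride=1)
    for _ in range(num_upsamples):
        # The 3x3 convolutions with padding 1 keep the size;
        # each upsample doubles the height and width
        size *= 2
    return size

print(generator_output_size())  # 32
```

The result matches `img_size = 32`, which is why the MNIST images are resized from 28x28 to 32x32.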

In [None]:
class Generator(nn.Module):
    def __init__(self, z=64, ch=16):
        super(Generator, self).__init__()
        self.conv1 = nn.ConvTranspose2d(z, ch * 4, 4, 1)

        self.conv2 = nn.Conv2d(ch * 4, ch * 4, 3, 1, 1)
        self.bn1 = nn.BatchNorm2d(ch * 4)
        
        self.conv3 = nn.Conv2d(ch * 4, ch * 2, 3, 1, 1)
        self.bn2 = nn.BatchNorm2d(ch * 2)
        
        self.conv4 = nn.Conv2d(ch * 2, ch * 2, 3, 1, 1)
        self.bn3 = nn.BatchNorm2d(ch * 2)
        
        self.conv5 = nn.Conv2d(ch * 2, ch * 2, 3, 1, 1)
        self.bn4 = nn.BatchNorm2d(ch * 2)

        self.conv_out = nn.Conv2d(ch * 2, 1, 3, 1, 1)
        self.up_nn = nn.Upsample(scale_factor=2, mode="nearest")

    # Forward method
    def forward(self, x):        
        x = F.elu(self.conv1(x))
        x = self.up_nn(F.elu(self.bn1(self.conv2(x))))
        x = self.up_nn(F.elu(self.bn2(self.conv3(x))))
        x = self.up_nn(F.elu(self.bn3(self.conv4(x))))
        x = F.elu(self.bn4(self.conv5(x)))

        return torch.tanh(self.conv_out(x))

<h3> Discriminator </h3>

The Discriminator is a simple convolutional classifier network that has a single output.<br>
NOTE: The final convolution also produces a 3D tensor (1, 1, 1) for that single output, so we reshape it to a single value per image before calculating the loss.
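
To see why the final output is 1x1, the sizes can be traced layer by layer. This is a plain Python sketch that assumes the kernel, stride and padding values of the Discriminator defined in the next cell. `conv_out` is a made-up helper.

```python
def conv_out(size, kernel, stride=1, padding=0):
    # Output size of a standard convolution (dilation 1)
    return (size + 2 * padding - kernel) // stride + 1

# (name, kernel, stride, padding), copied from the Discriminator layers
layers = [
    ("conv_in", 3, 1, 1),
    ("conv1", 3, 2, 1),
    ("conv2", 3, 2, 1),
    ("conv3", 3, 2, 1),
    ("conv4", 3, 2, 1),
    ("conv5", 2, 1, 0),
]

size = 32  # input images are 32x32
sizes = []
for name, kernel, stride, padding in layers:
    size = conv_out(size, kernel, stride, padding)
    sizes.append(size)

print(sizes)  # [32, 16, 8, 4, 2, 1]
```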

In [None]:
class Discriminator(nn.Module):
    # initializers
    def __init__(self, ch=16):
        super(Discriminator, self).__init__()
        self.conv_in = nn.Conv2d(1, ch, 3, 1, 1)
        
        self.conv1 = nn.Conv2d(ch, ch*2, 3, 2, 1)        
        self.conv2 = nn.Conv2d(ch*2, ch*2, 3, 2, 1)        
        self.conv3 = nn.Conv2d(ch*2, ch*4, 3, 2, 1)        
        self.conv4 = nn.Conv2d(ch*4, ch*4, 3, 2, 1)
        self.bn = nn.BatchNorm2d(ch*4)
        
        self.do = nn.Dropout()
        self.conv5 = nn.Conv2d(ch*4, 1, 2, 1)

    # forward method
    def forward(self, x):
        x = F.elu(self.conv_in(x))
        x = F.elu(self.conv1(x))
        x = F.elu(self.conv2(x))
        x = F.elu(self.conv3(x))
        x = self.do(F.elu(self.bn(self.conv4(x))))
        x = self.conv5(x).reshape(x.shape[0], 1)
        return x

In [None]:
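def gan_bce_loss(output, real_label=True):
    # Binary cross entropy on the raw Discriminator outputs (logits),
    # with a target of 1 for "real" and 0 for "fake"
    if real_label:
        return F.binary_cross_entropy_with_logits(output, torch.ones_like(output))
    else:
        return F.binary_cross_entropy_with_logits(output, torch.zeros_like(output))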

In [None]:
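def gan_w_loss(output, real_label=True):
    # Wasserstein-style alternative (not used in the training loop below):
    # push the mean score up for "real" labels and down for "fake" labels
    if real_label:
        return -output.mean()
    else:
        return output.mean()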

<h2> Basic GAN Model</h2>

![alt text](https://miro.medium.com/max/1600/1*M_YipQF_oC6owsU1VVrfhg.jpeg)

A more formal way of describing what is going on is to characterise our dataset as a discrete sample from some high dimensional distribution $d_r$ (where every pixel value is a degree of freedom). Our Generator maps from some input distribution $N(0,1)$ to $d_f$ and the Discriminator tries to learn to distinguish $d_r$ from $d_f$ while the Generator tries to make $d_f = d_r$ at which point the Discriminator cannot tell the difference. 

<h3>Define Hyperparameters, Network and Optimizer</h3>
GANs are very unstable and sensitive to hyperparameters. See how good you can make the generated output by tuning them.

In [None]:
# network
latent_noise_dem = 128

g_net = Generator(latent_noise_dem, ch=32).to(device)
d_net = Discriminator(ch=32).to(device)

#A fixed latent noise vector so we can see the improvement over the epochs
fixed_latent_noise = torch.randn(16, latent_noise_dem, 1, 1).to(device)

# Adam optimizer
g_optimizer = optim.Adam(g_net.parameters(), lr=glr)
d_optimizer = optim.Adam(d_net.parameters(), lr=dlr)

In [None]:
# If you are using weight clipping (the commented-out clamp_ calls in the training loop),
# initialise the params to be smaller
# with torch.no_grad():
#     for param in d_net.parameters():
#         param.data *= 0.01

In [None]:
test_images_log = []
d_losses = []
g_losses = []

d_out_fake = []
d_out_real = []

g_loss = 0
d_loss = 0

<h3>The main training loop</h3>


The overall objective function is: ![alt text](https://miro.medium.com/max/1500/1*l9se1koH_eQdZesko5eQpw.jpeg)

This is also formalised as a "minimax" game, where the Generator is trying to minimise the above objective while the Discriminator is trying to maximise it.
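
In case the image does not load, the standard form of this objective can be written with the symbols used earlier in this notebook, where $D(x)$ is the probability the Discriminator assigns to $x$ being real and $G(z)$ is the Generator output for latent sample $z$:

$$\min_G \max_D V(D, G) = \mathbb{E}_{x \sim d_r}[\log D(x)] + \mathbb{E}_{z \sim N(0,1)}[\log(1 - D(G(z)))]$$

Note that the training loop below scores the generated images against the "real" label, so the Generator is trained to maximise $\log D(G(z))$ rather than to minimise $\log(1 - D(G(z)))$. This is commonly called the non-saturating form of the Generator loss.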

The loss function we are using is the binary cross entropy loss, because what the Discriminator does is essentially binary classification: real or fake. At the point where $d_f = d_r$ the output of our Discriminator (after a sigmoid) should be 0.5, halfway between 0 and 1 (note that this rarely actually happens).<br>
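
As a small worked example, the sketch below computes the binary cross entropy by hand for a single logit. It uses plain Python rather than PyTorch, and `bce_with_logits` is a made-up stand-in for illustration. A Discriminator that outputs a probability of 0.5 (a logit of 0) gets a loss of $\log 2 \approx 0.693$ for both labels, which is a useful reference value when reading the loss plots.

```python
import math

def bce_with_logits(logit, target):
    # Sigmoid turns the raw score into a probability of "real"
    p = 1.0 / (1.0 + math.exp(-logit))
    # Binary cross entropy for a single sample
    return -(target * math.log(p) + (1 - target) * math.log(1 - p))

real_loss = bce_with_logits(0.0, 1.0)  # undecided about a real image
fake_loss = bce_with_logits(0.0, 0.0)  # undecided about a fake image
d_loss_at_equilibrium = (real_loss + fake_loss) / 2

print(round(d_loss_at_equilibrium, 4))  # 0.6931
```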


GANs can be difficult to train. Some common reasons are:
1. At the beginning, the distributions of real and fake samples are far away from each other, so it can be easy for the Discriminator to tell the fake from the real

2. The Discriminator becomes too good and stops providing useful gradients back to the Generator

3. The Discriminator becomes bad and cannot tell the difference between the real and fake samples, even though the output from the Generator is bad

4. The Generator overpowers the Discriminator and the Discriminator cannot tell real from fake (even though $d_f  \neq d_r$)

5. The Generator starts outputting only a very few images (low variation in the images) and, as the Discriminator is only looking for fake images, it is not corrected

In [None]:
pbar = trange(train_epoch, leave=False, desc="Epoch")    
for epoch in pbar:
    pbar.set_postfix_str('G Loss: %.4f, D Loss: %.4f' % (g_loss/len(train_loader), 
                                                         d_loss/len(train_loader)))
    g_loss = 0
    d_loss = 0

    for num_iter, (images, label) in enumerate(tqdm(train_loader, leave=False)):

#         with torch.no_grad():
#             for param in d_net.parameters():
#                 param.clamp_(-0.05, 0.05)

        images = images.to(device)
        
        #the size of the current minibatch
        bs = images.shape[0]

        ########### Train Generator G ##############
        #Step1: Sample a latent vector from a normal distribution and pass it through the generator
        #to get a batch of fake images
        latent_noise = torch.randn(bs, latent_noise_dem, 1, 1).to(device)
        g_output = g_net(latent_noise)
        
        #Step2: Pass the minibatch of fake images (from the Generator) through the Discriminator and calculate
        #the loss against the "real" label - the Generator wants the Discriminator to think its outputs are real
        d_result = d_net(g_output)
        g_train_loss = gan_bce_loss(d_result, True)
        
        #Step3: Backpropagate the loss through the Discriminator and into the Generator and take a training step
        g_net.zero_grad()
        g_train_loss.backward()
        g_optimizer.step()
        
        #log the generator training loss
        g_losses.append(g_train_loss.item())
        g_loss += g_train_loss.item()

#         with torch.no_grad():
#             for param in g_net.parameters():
#                 param.clamp_(-0.01, 0.01)
        
        ########### Train Discriminator D! ############
        
        #Step1: Pass the minibatch of real images through the Discriminator and calculate
        #the loss against the "real" label
        d_real_out = d_net(images)
        d_real_loss = gan_bce_loss(d_real_out, True)
        d_out_real.append(d_real_out.mean().item())
        
        #Step2: Pass the minibatch of fake images (from the Generator) through the Discriminator and calculate
        #the loss against the "fake" label
        #We "detach()" the output of the Generator here as we don't need it to backpropagate through the
        #Generator in this step
        d_fake_out = d_net(g_output.detach())
        d_fake_loss = gan_bce_loss(d_fake_out, False)
        d_out_fake.append(d_fake_out.mean().item())

        #Step3: Average the two losses, backpropagate through the Discriminator and take a training step
        d_train_loss = (d_real_loss + d_fake_loss)/2

        d_net.zero_grad()
        d_train_loss.backward()
        d_optimizer.step()
        #log the discriminator training loss
        d_losses.append(d_train_loss.item())
        d_loss += d_train_loss.item()

    with torch.no_grad():
        g_net.eval()
        #log the output of the generator given the fixed latent noise vector
        test_fake = g_net(fixed_latent_noise)
        imgs = torchvision.utils.make_grid(test_fake.cpu().detach(), 4, pad_value=1, normalize=True)
        imgs_np = (imgs.numpy().transpose((1, 2, 0)) * 255).astype(np.uint8)
        test_images_log.append(imgs_np)
        g_net.train()


<h3>Plot out the losses and visualize the generated images</h3>
Notice how noisy the losses for both the Generator and Discriminator are?
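
Because the raw curves are so noisy, it can help to look at a smoothed version as well. The sketch below is plain Python, and `moving_average` and its `window` parameter are made-up names for illustration, not part of the original notebook.

```python
def moving_average(values, window=50):
    """Trailing mean over the most recent `window` values (hypothetical helper)."""
    smoothed = []
    running_sum = 0.0
    for i, v in enumerate(values):
        running_sum += v
        if i >= window:
            # Drop the value that has just left the window
            running_sum -= values[i - window]
        smoothed.append(running_sum / min(i + 1, window))
    return smoothed

print(moving_average([1, 2, 3, 4], window=2))  # [1.0, 1.5, 2.5, 3.5]
```

After training, something like `plt.plot(moving_average(d_losses))` would draw the smoothed Discriminator loss.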

In [None]:
plt.plot(d_losses)

In [None]:
plt.plot(g_losses)

In [None]:
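# Use eval mode and no_grad, as when logging images during training,
# so BatchNorm uses its running statistics and no gradients are tracked
g_net.eval()
with torch.no_grad():
    test_fake = g_net(fixed_latent_noise)
plt.figure(figsize = (20,10))
out = vutils.make_grid(test_fake.detach().cpu(), 4, normalize=True)
plt.imshow(out.numpy().transpose((1, 2, 0)))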

<b>Let's create a gif of our generated images throughout training</b>

In [None]:
imageio.mimsave('MNIST_GAN.gif', test_images_log)