# Generative Adversarial Networks (GANs)

In this lab, we will develop several basic GANs and experiment with them.

Some of the information in this lab is based on material from the Coursera course Building Basic Generative Adversarial Networks (GANs).

## Generative Models vs Discriminative Models

Discriminative models are typically used for classification in machine learning.
Discriminative classifiers take a set of features $x$, such as having a nose or wheels,
and from these features determine a category $y$, 
meaning that they try to model the probability of class $y$ given the set of features $x$.
Assuming $X$ is a random variable over sets of features and $Y$ is a random variable over the possible classes,
a discriminative model estimates

$$D(x) = P(Y=y \mid X=x).$$
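As a concrete illustration, here is a minimal sketch of a discriminative model for two classes, assuming a logistic (sigmoid) form $P(Y=1 \mid X=x) = \sigma(w \cdot x + b)$. The features, the weights, and the function name `discriminative_prob` are hypothetical; a real model would learn $w$ and $b$ from labeled data.

```python
import math

def discriminative_prob(x, w, b):
    """Estimate P(Y=1 | X=x) with a logistic model: sigmoid(w . x + b)."""
    score = sum(wi * xi for wi, xi in zip(w, x)) + b
    return 1.0 / (1.0 + math.exp(-score))

# Hypothetical binary features: [has_wheels, has_nose]; class 1 = "car", class 0 = "dog"
w = [4.0, -4.0]
b = 0.0

p_car = discriminative_prob([1.0, 0.0], w, b)   # wheels, no nose: high P(car | x)
p_dog = 1.0 - p_car                              # P(Y=0 | x) is the complement
```

The model never describes what a car or a dog looks like; it only maps features to a class probability, which is exactly what a GAN discriminator does with the classes "real" and "fake".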

Generative models, however, model $P(x)$ or $P(x \mid y)$.
Generative models based on sampling take a random input and sometimes also a class $y$ such as a "dog."
From these inputs, a generative sampler will attempt to generate a set of features $x$ that are
representative of the class "dog." The random noise input ensures that we don't generate the same
dog each time.

Assuming a random noise distribution $N$, a conditional sample-based generative model works
as follows:

1. Input class $y$
2. $z \sim N$
3. $x = G(z,y)$

The goal is for generated samples to follow the conditional data distribution: $P_{z\sim N}(G(z,y)=x) = P(x \mid y)$.
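The three steps above can be sketched in PyTorch as follows. This is an untrained toy, assuming $N$ is a standard normal and that the class $y$ is fed to $G$ as a one-hot vector concatenated with $z$ (one common conditioning scheme, not the only one); `ConditionalGenerator`, `Z_DIM`, and `NUM_CLASSES` are hypothetical names.

```python
import torch
import torch.nn.functional as F
from torch import nn

torch.manual_seed(0)

NUM_CLASSES = 10   # hypothetical: e.g. digit classes 0-9
Z_DIM = 16         # hypothetical noise dimension
X_DIM = 784        # e.g. a flattened 28x28 image

class ConditionalGenerator(nn.Module):
    """Toy G(z, y): concatenates noise z with a one-hot class vector (untrained)."""
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(Z_DIM + NUM_CLASSES, 128),
            nn.LeakyReLU(0.2),
            nn.Linear(128, X_DIM),
            nn.Tanh(),
        )

    def forward(self, z, y):
        y_onehot = F.one_hot(y, NUM_CLASSES).float()
        return self.net(torch.cat([z, y_onehot], dim=1))

G = ConditionalGenerator()

# 1. Input class y (the same class for every sample in this batch)
y = torch.full((4,), 3, dtype=torch.long)
# 2. Sample z ~ N, here a standard normal N(0, I)
z = torch.randn(4, Z_DIM)
# 3. x = G(z, y)
with torch.no_grad():
    x = G(z, y)
```

Because each row of `z` is different, the four samples differ even though they share the class $y$.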


## Generative Adversarial Networks (GANs)

GANs for images are composed of two models, a generator that generates images
and a discriminator that is a discriminative classifier.
The generator takes in a random noise input and an optional class and
deterministically transforms the input into an image. The discriminator attempts to 
determine which of its inputs are real samples from the data distribution and which ones
are fake samples generated by the generator. Over time, the models compete. If training is
set up well, by the end the generator can take any random noise input and produce a realistic result.
In summary, $G$ learns to produce realistic examples, like an artist creating paintings that look like photographs,
while $D$ learns to tell those paintings apart from real photographs.
The basic GAN model described by Goodfellow et al. (2014) looks like this:

<img src="figures/gan_architecture-1.png" title="GAN Framework" style="width: 640px;" />

After this lab, you may be interested in reading about [6 GAN Architectures You Really Should Know](https://neptune.ai/blog/6-gan-architectures).

## GAN Setup

Let's build our first GAN.

There are about a million tutorials on coding GANs with PyTorch available online. We'll use
[code from GitHub user diegoalejogm](https://github.com/diegoalejogm/gans).

To run this code you'll need some dependencies such as tensorboardX:

In [None]:
!pip install tensorboardX

Here is a `Logger` class with useful tricks to indicate training progress and visualize results.

In [None]:
import os
import numpy as np
import errno
import torchvision.utils as vutils
from tensorboardX import SummaryWriter
from IPython import display
from matplotlib import pyplot as plt
import torch

# TensorBoard data will be stored under the './runs' directory


class Logger:

    def __init__(self, model_name, data_name):
        self.model_name = model_name
        self.data_name = data_name

        self.comment = '{}_{}'.format(model_name, data_name)
        self.data_subdir = '{}/{}'.format(model_name, data_name)

        # TensorBoard
        self.writer = SummaryWriter(comment=self.comment)

    def log(self, d_error, g_error, epoch, n_batch, num_batches):

        # Convert 0-dim loss tensors to plain Python floats for TensorBoard
        if isinstance(d_error, torch.Tensor):
            d_error = d_error.item()
        if isinstance(g_error, torch.Tensor):
            g_error = g_error.item()

        step = Logger._step(epoch, n_batch, num_batches)
        self.writer.add_scalar(
            '{}/D_error'.format(self.comment), d_error, step)
        self.writer.add_scalar(
            '{}/G_error'.format(self.comment), g_error, step)

    def log_images(self, images, num_images, epoch, n_batch, num_batches, format='NCHW', normalize=True):
        '''
        input images are expected in format (NCHW)
        '''
        if isinstance(images, np.ndarray):
            images = torch.from_numpy(images)

        if format == 'NHWC':
            # Reorder (N, H, W, C) to (N, C, H, W); transpose(1, 3) would also swap H and W
            images = images.permute(0, 3, 1, 2)

        step = Logger._step(epoch, n_batch, num_batches)
        img_name = '{}/images{}'.format(self.comment, '')

        # Make a wide grid (make_grid's default of 8 images per row)
        horizontal_grid = vutils.make_grid(
            images, normalize=normalize, scale_each=True)
        # Make a roughly square grid with sqrt(num_images) images per row
        nrows = int(np.sqrt(num_images))
        grid = vutils.make_grid(
            images, nrow=nrows, normalize=normalize, scale_each=True)

        # Add horizontal images to tensorboard
        self.writer.add_image(img_name, horizontal_grid, step)

        # Save plots
        self.save_torch_images(horizontal_grid, grid, epoch, n_batch)

    def save_torch_images(self, horizontal_grid, grid, epoch, n_batch, plot_horizontal=True):
        out_dir = './data/images/{}'.format(self.data_subdir)
        Logger._make_dir(out_dir)

        # Plot and save horizontal
        fig = plt.figure(figsize=(16, 16))
        plt.imshow(np.moveaxis(horizontal_grid.numpy(), 0, -1))
        plt.axis('off')
        if plot_horizontal:
            display.display(plt.gcf())
        self._save_images(fig, epoch, n_batch, 'hori')
        plt.close()

        # Save squared
        fig = plt.figure()
        plt.imshow(np.moveaxis(grid.numpy(), 0, -1))
        plt.axis('off')
        self._save_images(fig, epoch, n_batch)
        plt.close()

    def _save_images(self, fig, epoch, n_batch, comment=''):
        out_dir = './data/images/{}'.format(self.data_subdir)
        Logger._make_dir(out_dir)
        fig.savefig('{}/{}_epoch_{}_batch_{}.png'.format(out_dir,
                                                         comment, epoch, n_batch))

    def display_status(self, epoch, num_epochs, n_batch, num_batches, d_error, g_error, d_pred_real, d_pred_fake):
        
        # Convert 0-dim loss tensors to Python floats; detach the prediction tensors
        if isinstance(d_error, torch.Tensor):
            d_error = d_error.item()
        if isinstance(g_error, torch.Tensor):
            g_error = g_error.item()
        if isinstance(d_pred_real, torch.Tensor):
            d_pred_real = d_pred_real.detach()
        if isinstance(d_pred_fake, torch.Tensor):
            d_pred_fake = d_pred_fake.detach()

        print('Epoch: [{}/{}], Batch Num: [{}/{}]'.format(
            epoch,num_epochs, n_batch, num_batches)
             )
        print('Discriminator Loss: {:.4f}, Generator Loss: {:.4f}'.format(d_error, g_error))
        print('D(x): {:.4f}, D(G(z)): {:.4f}'.format(d_pred_real.mean(), d_pred_fake.mean()))

    def save_models(self, generator, discriminator, epoch):
        out_dir = './data/models/{}'.format(self.data_subdir)
        Logger._make_dir(out_dir)
        torch.save(generator.state_dict(),
                   '{}/G_epoch_{}'.format(out_dir, epoch))
        torch.save(discriminator.state_dict(),
                   '{}/D_epoch_{}'.format(out_dir, epoch))

    def close(self):
        self.writer.close()

    # Private Functionality

    @staticmethod
    def _step(epoch, n_batch, num_batches):
        return epoch * num_batches + n_batch

    @staticmethod
    def _make_dir(directory):
        # Create the directory (and parents) if it does not already exist
        os.makedirs(directory, exist_ok=True)
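`save_models` stores only each network's `state_dict`, so reloading a checkpoint means rebuilding the architecture first and then loading the weights into it. Here is a minimal sketch of that round trip, assuming a hypothetical `make_generator` in place of `GeneratorNet` and a temporary directory in place of `./data/models/...`.

```python
import os
import tempfile

import torch
from torch import nn

# Hypothetical stand-in for GeneratorNet: any nn.Module is saved and loaded the same way
def make_generator():
    return nn.Sequential(nn.Linear(100, 784), nn.Tanh())

generator = make_generator()

with tempfile.TemporaryDirectory() as out_dir:
    path = os.path.join(out_dir, 'G_epoch_0')      # same naming scheme as save_models
    torch.save(generator.state_dict(), path)

    restored = make_generator()                     # rebuild the architecture first...
    restored.load_state_dict(torch.load(path))      # ...then load the saved weights
    restored.eval()

    with torch.no_grad():
        z = torch.randn(16, 100)
        same_output = torch.allclose(generator(z), restored(z))
```

Saving the `state_dict` rather than the whole module keeps checkpoints independent of where the class happens to be defined, at the cost of needing the class to reload them.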

## Vanilla GAN for MNIST dataset

Next we'll download MNIST, a small dataset on which we can get things running quickly:

In [None]:
import torch
from torch import nn, optim
from torchvision import transforms, datasets

DATA_FOLDER = './torch_data/VGAN/MNIST'
def mnist_data():
    compose = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize([0.5], [0.5])
        ])
    out_dir = '{}/dataset'.format(DATA_FOLDER)
    return datasets.MNIST(root=out_dir, train=True, transform=compose, download=True)

# Load Dataset and attach a DataLoader

data = mnist_data()
data_loader = torch.utils.data.DataLoader(data, batch_size=100, shuffle=True)
num_batches = len(data_loader)
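`num_batches` is $\lceil N / \text{batch\_size} \rceil$, because `DataLoader` keeps a final partial batch by default (`drop_last=False`); for MNIST's 60,000 training images and `batch_size=100` it is exactly 600. The sketch below checks this on a hypothetical random stand-in dataset (no download needed) and shows how `Logger._step` turns an epoch and batch index into one global step.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical stand-in for MNIST: 1,050 random "images" with labels
images = torch.rand(1050, 1, 28, 28)
labels = torch.randint(0, 10, (1050,))
loader = DataLoader(TensorDataset(images, labels), batch_size=100, shuffle=True)

num_batches = len(loader)                   # ceil(1050 / 100) = 11; the last batch holds 50
first_images, first_labels = next(iter(loader))

# Logger._step: epoch * num_batches + n_batch, e.g. epoch 2, batch 5
step = 2 * num_batches + 5
```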

### Generator

The generator is the model we ultimately want to perform well:
after training, it is the part we keep and use to create new samples.
Different noise samples make the generator produce different outputs,
and small changes to the noise should produce correspondingly
small changes in the output. The generator takes a noise vector
sampled from a latent space (the domain of $p_z$) and transforms that
noise sample into an element of the domain of
$p_{data}$.
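To make the smoothness claim concrete, the sketch below uses an untrained stand-in generator (a hypothetical `nn.Sequential`, not the `GeneratorNet` defined later) and compares a tiny nudge to the noise with a jump to an unrelated noise vector. It also walks a straight line between two noise vectors, a common way to inspect what a trained generator has learned.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical stand-in generator: a continuous map from R^100 to R^784
G = nn.Sequential(nn.Linear(100, 256), nn.LeakyReLU(0.2), nn.Linear(256, 784), nn.Tanh())

z0 = torch.randn(1, 100)
z1 = torch.randn(1, 100)

with torch.no_grad():
    # A small nudge to the noise vector...
    small_step = (G(z0 + 0.01 * torch.randn(1, 100)) - G(z0)).norm().item()
    # ...versus jumping to an unrelated noise vector
    big_step = (G(z1) - G(z0)).norm().item()

    # Linear interpolation from z0 to z1 gives a smooth path of outputs
    ts = torch.linspace(0, 1, steps=8).unsqueeze(1)   # shape (8, 1)
    z_path = (1 - ts) * z0 + ts * z1                  # shape (8, 100)
    x_path = G(z_path)                                # shape (8, 784)
```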

<img src="figures/Generator.jpg" title="Generator" style="width: 640px;" />

The generator model can be practically anything that has the right input
and output tensor shapes. The "vanilla" GAN is the simplest GAN network architecture.
Here is the structure of a simple vanilla GAN generator using only
fully connected layers:

<img src="figures/VanillaGAN-Gen.png" title="Generator model" style="width: 640px;" />

And here is sample code for the model's PyTorch Module. Because we
normalize the real images to the range [-1, 1], we use a hyperbolic tangent activation
at the output to restrict the generator to the same range:
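A minimal sketch of the range bookkeeping, assuming single-channel images. It spells out by hand the arithmetic that `transforms.Normalize([0.5], [0.5])` applies, shows that `tanh` stays inside the same range, and undoes the mapping for display. The tensor values are made up.

```python
import torch

pixels = torch.tensor([0.0, 0.25, 0.5, 1.0])     # ToTensor() output lies in [0, 1]

# Per-channel arithmetic of transforms.Normalize([0.5], [0.5]): (x - mean) / std
normalized = (pixels - 0.5) / 0.5                 # maps [0, 1] onto [-1, 1]

# tanh squashes any pre-activation into [-1, 1], matching the normalized data
pre_activations = torch.tensor([-10.0, -0.5, 0.0, 3.0])
generated = torch.tanh(pre_activations)

# To view generated images, invert the normalization: pixel = (x + 1) / 2
as_pixels = (generated + 1) / 2
```

If the generator ended in a sigmoid instead, it could only produce values in [0, 1], so it could never match real images whose darkest pixels sit at -1.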

In [None]:
class GeneratorNet(torch.nn.Module):
    """
    A three hidden-layer generative neural network
    """
    def __init__(self):
        super(GeneratorNet, self).__init__()
        n_features = 100
        n_out = 784
        
        self.hidden0 = nn.Sequential(
            nn.Linear(n_features, 256),
            nn.LeakyReLU(0.2)
        )
        self.hidden1 = nn.Sequential(            
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2)
        )
        self.hidden2 = nn.Sequential(
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2)
        )
        
        self.out = nn.Sequential(
            nn.Linear(1024, n_out),
            nn.Tanh()
        )

    def forward(self, x):
        x = self.hidden0(x)
        x = self.hidden1(x)
        x = self.hidden2(x)
        x = self.out(x)
        return x

# Function to create noise samples for the generator's input

def noise(size):
    n = torch.randn(size, 100)
    if torch.cuda.is_available(): return n.cuda() 
    return n

### Discriminator

The discriminator is a binary classifier: its only job is to classify its input as
real or fake.
When given a fake sample from the generator, it should output 0 for fake:

<img src="figures/DiscriminatorFake.png" title="Discriminator-1" style="width: 640px;" />

On the other hand, if the input is real, it should output 1 for real:

<img src="figures/DiscriminatorReal.jpg" title="Discriminator-2" style="width: 640px;" />

We'll use the following simple discriminator structure:

<img src="figures/VanillaGAN-Dis.png" title="VanillaGAN Discriminator" style="width: 640px;" />

Here is the Module:

In [None]:
class DiscriminatorNet(torch.nn.Module):
    """
    A three hidden-layer discriminative neural network
    """
    def __init__(self):
        super(DiscriminatorNet, self).__init__()
        n_features = 784
        n_out = 1
        
        self.hidden0 = nn.Sequential( 
            nn.Linear(n_features, 1024),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3)
        )
        self.hidden1 = nn.Sequential(
            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3)
        )
        self.hidden2 = nn.Sequential(
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3)
        )
        self.out = nn.Sequential(
            torch.nn.Linear(256, n_out),
            torch.nn.Sigmoid()
        )

    def forward(self, x):
        x = self.hidden0(x)
        x = self.hidden1(x)
        x = self.hidden2(x)
        x = self.out(x)
        return x
    
def images_to_vectors(images):
    return images.view(images.size(0), 784)

def vectors_to_images(vectors):
    return vectors.view(vectors.size(0), 1, 28, 28)
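The fully connected discriminator expects flat 784-element vectors, while image utilities such as `make_grid` expect `(N, C, H, W)` tensors. A quick shape check on a hypothetical random batch, repeating the two helpers so the snippet runs on its own:

```python
import torch

def images_to_vectors(images):
    return images.view(images.size(0), 784)

def vectors_to_images(vectors):
    return vectors.view(vectors.size(0), 1, 28, 28)

batch = torch.rand(100, 1, 28, 28)      # hypothetical MNIST-shaped batch (N, C, H, W)
flat = images_to_vectors(batch)          # (100, 784): what the discriminator consumes
restored = vectors_to_images(flat)       # (100, 1, 28, 28): what make_grid consumes
```

`view` only reinterprets the memory layout, so the round trip returns exactly the original values.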

### Create the modules

Let's create an instance of the generator and discriminator:

In [None]:
discriminator = DiscriminatorNet()
generator = GeneratorNet()

if torch.cuda.is_available():
    discriminator.cuda()
    generator.cuda()

### Set up the optimizers

The optimization is a min-max game.
The generator wants to minimize the objective function, whereas the discriminator wants to maximize the same objective function.
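The discriminator's objective is the negative of the binary cross-entropy (BCE) loss with target 1 for real samples and 0 for fake samples, so maximizing it is equivalent to minimizing BCE:

$$\mathcal{L}_D = \max_D\mathcal{L}(D;G)=\mathbb{E}_{x \sim p_{data}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log (1-D(G(z)))]$$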

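The generator doesn't affect the first term of $\mathcal{L}_D$, so its goal
is a bit simpler: minimize the second term of $\mathcal{L}_D$:

$$\mathcal{L}_G = \min_G\mathcal{L}(G;D)=\mathbb{E}_{z \sim p_z(z)}[\log (1-D(G(z)))]$$

In practice, Goodfellow et al. (2014) note that this term saturates early in training, when $D$ confidently rejects the generator's samples, so they suggest training $G$ to maximize $\log D(G(z))$ instead. That non-saturating version, minimizing $-\mathbb{E}_{z \sim p_z(z)}[\log D(G(z))]$, is what `train_generator` below implements by pairing the fake samples with the real-data target.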

Putting these together we have

$$\min_G\max_D\mathcal{L}(D;G)=\mathbb{E}_{x \sim p_{data}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log (1-D(G(z)))]$$
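A numeric sanity check of the link between this objective and `nn.BCELoss`, using hypothetical discriminator outputs $D(x)=0.9$ and $D(G(z))=0.2$ for a single real and a single fake sample. It also shows that BCE against a target of 1, the generator loss used in the code, equals $-\log D(G(z))$.

```python
import torch
from torch import nn

bce = nn.BCELoss()

# Hypothetical discriminator outputs for one real and one fake sample
d_real = torch.tensor([0.9])   # D(x)
d_fake = torch.tensor([0.2])   # D(G(z))

# Value of the min-max objective for this pair of samples
objective = torch.log(d_real) + torch.log(1 - d_fake)

# Discriminator loss: BCE with target 1 for real and target 0 for fake
d_loss = bce(d_real, torch.ones(1)) + bce(d_fake, torch.zeros(1))   # equals -objective

# Generator losses: saturating log(1 - D(G(z))) vs non-saturating -log D(G(z))
g_loss_saturating = torch.log(1 - d_fake)
g_loss_non_saturating = bce(d_fake, torch.ones(1))
```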

Here is a diagram of the objective function:

<img src="figures/GanObjectivefunction.png" title="min-max optimization" style="width: 640px;" />


### Aside: why not select a very strong discriminator?

If we have a discriminator that is far superior to the generator, it will quickly classify every fake example as fake with near-100% confidence.
That will not be very useful for the generator, which needs a signal telling it how to make its samples look less fake.
On the other hand, if we have a generator that is far superior to the discriminator, we will get predictions indicating that both the real and
generated samples are equally likely to be real or fake. This is actually the end goal: a perfect generator. But we are unlikely to obtain a
perfect generator, and if the generator is not yet perfect, it is important to keep the generator and discriminator both improving together,
with similar "skill levels" from the beginning. Since the discriminator has an "easier" job than the generator, it will be difficult to keep
the competition balanced. We will talk about some solutions to this problem, such as WGANs, later.
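To see why an overconfident discriminator starves the generator of signal, the sketch below takes a hypothetical discriminator logit of $-6$ for a fake sample ($D(G(z)) \approx 0.0025$) and compares the gradients of the saturating generator loss $\log(1-D(G(z)))$ and the non-saturating alternative $-\log D(G(z))$ suggested by Goodfellow et al. (2014).

```python
import torch

# Hypothetical discriminator logit for a fake sample that D confidently rejects
logit = torch.tensor(-6.0, requires_grad=True)
d_fake = torch.sigmoid(logit)                     # D(G(z)), about 0.0025

# Saturating generator loss from the min-max objective: minimize log(1 - D(G(z)))
torch.log(1 - d_fake).backward()
grad_saturating = logit.grad.item()               # -sigmoid(logit), about -0.0025

# Reset the gradient and rebuild the graph for the second loss
logit.grad = None
d_fake = torch.sigmoid(logit)

# Non-saturating alternative: minimize -log D(G(z))
(-torch.log(d_fake)).backward()
grad_non_saturating = logit.grad.item()           # -(1 - sigmoid(logit)), about -0.9975
```

Analytically, the gradients with respect to the logit $a$ are $-\sigma(a)$ and $-(1-\sigma(a))$, so the first vanishes as $D$ grows more confident while the second does not.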

OK, back to the optimizers.
Following Goodfellow et al. (2014), the two networks are trained in alternation,
so it is natural to give each model its own optimizer.

In [None]:
# Optimizers

d_optimizer = optim.Adam(discriminator.parameters(), lr=0.0002)
g_optimizer = optim.Adam(generator.parameters(), lr=0.0002)

# Loss function

loss = nn.BCELoss()

# How many epochs to train for

num_epochs = 200

# Number of discriminator steps per generator step (k = 1 in Goodfellow et al.); the training loop below always takes exactly one

d_steps = 1
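Goodfellow et al. (2014) describe taking $k$ discriminator steps for every generator step and use $k=1$ in their experiments, which is what the training loop below does without reading `d_steps`. A minimal sketch of the general case, on a hypothetical 1-D toy problem (real data drawn from $\mathcal{N}(4, 1.25^2)$, tiny made-up networks, `training_iteration` is a hypothetical helper):

```python
import torch
from torch import nn, optim

torch.manual_seed(0)

# Hypothetical toy problem: real data are 1-D samples from N(4, 1.25^2)
def sample_real(n):
    return 4.0 + 1.25 * torch.randn(n, 1)

def sample_noise(n):
    return torch.randn(n, 8)

G = nn.Sequential(nn.Linear(8, 32), nn.LeakyReLU(0.2), nn.Linear(32, 1))
D = nn.Sequential(nn.Linear(1, 32), nn.LeakyReLU(0.2), nn.Linear(32, 1), nn.Sigmoid())

d_opt = optim.Adam(D.parameters(), lr=0.0002)
g_opt = optim.Adam(G.parameters(), lr=0.0002)
bce = nn.BCELoss()
d_steps = 3          # k discriminator updates per generator update (hypothetical value)
batch_size = 64

def training_iteration():
    # k steps on the discriminator; fake data is detached so G is not updated
    for _ in range(d_steps):
        real = sample_real(batch_size)
        fake = G(sample_noise(batch_size)).detach()
        d_opt.zero_grad()
        d_loss = (bce(D(real), torch.ones(batch_size, 1))
                  + bce(D(fake), torch.zeros(batch_size, 1)))
        d_loss.backward()
        d_opt.step()

    # One step on the generator with the non-saturating loss (target 1 for fakes)
    g_opt.zero_grad()
    g_loss = bce(D(G(sample_noise(batch_size))), torch.ones(batch_size, 1))
    g_loss.backward()
    g_opt.step()     # updates only G; D's gradients from this pass are zeroed next iteration
    return d_loss.item(), g_loss.item()

history = [training_iteration() for _ in range(200)]
```

Larger $k$ keeps the discriminator closer to optimal for the current generator, but as the aside above warns, a discriminator that gets too far ahead can leave the generator with little useful signal.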

### Training

The discriminator's target is 1 for real data and 0 for fake data:

In [None]:
def real_data_target(size):
    '''
    Tensor containing ones, with shape = size
    '''
    data = torch.ones(size, 1)
    if torch.cuda.is_available(): return data.cuda()
    return data

def fake_data_target(size):
    '''
    Tensor containing zeros, with shape = size
    '''
    data = torch.zeros(size, 1)
    if torch.cuda.is_available(): return data.cuda()
    return data
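The helpers above check `torch.cuda.is_available()` on every call. A common alternative, sketched here as an option rather than a required change, is to choose a `torch.device` once and allocate tensors directly on it; the names mirror the helpers above.

```python
import torch

# Pick the device once instead of checking torch.cuda.is_available() in every helper
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def real_data_target(size):
    """Tensor of ones with shape (size, 1): the label for real samples."""
    return torch.ones(size, 1, device=device)

def fake_data_target(size):
    """Tensor of zeros with shape (size, 1): the label for fake samples."""
    return torch.zeros(size, 1, device=device)
```

With this pattern, models and batches would likewise be moved with `.to(device)`, so the same code runs unchanged on CPU and GPU.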

Here's a function for a single step of the discriminator:

In [None]:
def train_discriminator(optimizer, real_data, fake_data):
    # Reset gradients
    optimizer.zero_grad()
    
    # Propagate real data
    prediction_real = discriminator(real_data)
    error_real = loss(prediction_real, real_data_target(real_data.size(0)))
    error_real.backward()

    # Propagate fake data
    prediction_fake = discriminator(fake_data)
    error_fake = loss(prediction_fake, fake_data_target(fake_data.size(0)))
    error_fake.backward()
    
    # Take a step
    optimizer.step()
    
    # Return error
    return error_real + error_fake, prediction_real, prediction_fake

And here's a function for a single step of the generator:

In [None]:
def train_generator(optimizer, fake_data):
    # Reset gradients
    optimizer.zero_grad()

    # Propagate the fake data through the discriminator and backpropagate.
    # Note that since we want the generator to output something that gets
    # the discriminator to output a 1, we use the real data target here.
    prediction = discriminator(fake_data)
    error = loss(prediction, real_data_target(prediction.size(0)))
    error.backward()
    
    # Update weights with gradients
    optimizer.step()
    
    # Return error
    return error
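Two details make this alternation work: the training loop passes the discriminator a `.detach()`ed fake batch, and `train_generator` backpropagates through the discriminator but only steps the generator's optimizer. The sketch below, using hypothetical tiny stand-ins for $G$ and $D$, checks where gradients land in each case.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical tiny G and D, just to inspect where gradients flow
G = nn.Linear(4, 3)
D = nn.Sequential(nn.Linear(3, 1), nn.Sigmoid())
bce = nn.BCELoss()
z = torch.randn(5, 4)

# Discriminator step: the fake batch is detached, so no gradient reaches G
fake = G(z).detach()
bce(D(fake), torch.zeros(5, 1)).backward()
g_grad_after_d_step = G.weight.grad          # None: G is not part of this graph

# Generator step: no detach, so gradients flow back through D into G
D.zero_grad()
bce(D(G(z)), torch.ones(5, 1)).backward()
g_has_grad = G.weight.grad is not None       # G receives gradients
d_has_grad = D[0].weight.grad is not None    # D's parameters receive gradients too
```

The discriminator's gradients from the generator step are harmless in this lab, because `train_discriminator` calls `optimizer.zero_grad()` before computing new ones.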

### Generate test noise samples

Let's generate some noise vectors to use as inputs to the generator.
We'll use these samples repeatedly to see the evolution of the generator
over time.

In [None]:
num_test_samples = 16
test_noise = noise(num_test_samples)

### Start training

Now let's train the model:

In [None]:
logger = Logger(model_name='VGAN', data_name='MNIST')

for epoch in range(num_epochs):
    for n_batch, (real_batch,_) in enumerate(data_loader):

        # Train discriminator on a real batch and a fake batch
        
        real_data = images_to_vectors(real_batch)
        if torch.cuda.is_available(): real_data = real_data.cuda()
        fake_data = generator(noise(real_data.size(0))).detach()
        d_error, d_pred_real, d_pred_fake = train_discriminator(d_optimizer,
                                                                real_data, fake_data)
        
        # Train generator

        fake_data = generator(noise(real_batch.size(0)))
        g_error = train_generator(g_optimizer, fake_data)
        
        # Log errors and display progress

        logger.log(d_error, g_error, epoch, n_batch, num_batches)
        if n_batch % 100 == 0:
            display.clear_output(True)
            # Display images generated from the fixed test noise (no gradients needed)
            with torch.no_grad():
                test_images = vectors_to_images(generator(test_noise)).cpu()
            logger.log_images(test_images, num_test_samples, epoch, n_batch, num_batches)
            # Display status logs
            logger.display_status(
                epoch, num_epochs, n_batch, num_batches,
                d_error, g_error, d_pred_real, d_pred_fake
            )

    # Save model checkpoints at the end of each epoch
    logger.save_models(generator, discriminator, epoch)

logger.close()