# Wasserstein Generative Adversarial Network
In this notebook, we train the [Improved Wasserstein GAN](https://arxiv.org/pdf/1704.00028.pdf) on the [MNIST](http://yann.lecun.com/exdb/mnist/) dataset.

### What even is a Generative Adversarial Network (GAN)?
A GAN is a neural network model consisting of two parts: a generator and a critic (also called a discriminator, depending on the specific GAN model). The generator is trained to generate (*wow*) novel instances of data that look real compared to some reference training dataset. To "sample" instances of data from the generator, its inputs are sampled from a prior distribution (typically Gaussian), and the generator maps them to the "sampled" batch of output data. The critic is trained to discern between data produced by the generator and data selected from the reference training dataset. The two models are pitted against each other: as one improves, so does the other, and at the end of training you get a generator that can produce realistic-looking data. For example, GANs are often trained on a specific image dataset in order to produce novel images that look like realistic samples from that dataset.
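As a rough illustration of the two roles, here is a minimal sketch with tiny fully connected stand-ins for the generator and critic. The toy sizes and the `toy_generator`/`toy_critic` names are hypothetical; the notebook's actual models are the `TransposeCNN` and `CNN` classes loaded further down. The sketch only shows the data flow: prior samples go in, a batch of fake data comes out, and the critic assigns one real-valued score to each sample.

```python
import torch

torch.manual_seed(0)

BATCH_SIZE, Z_DIM, IMG_DIM = 8, 16, 28 * 28  # toy sizes (assumed, not the notebook's models)

# hypothetical toy generator: maps a prior sample z to a flattened "image" in [-1, 1]
toy_generator = torch.nn.Sequential(
    torch.nn.Linear(Z_DIM, 64),
    torch.nn.ReLU(),
    torch.nn.Linear(64, IMG_DIM),
    torch.nn.Tanh(),
)

# hypothetical toy critic: maps a flattened image to one unbounded real-valued score
toy_critic = torch.nn.Sequential(
    torch.nn.Linear(IMG_DIM, 64),
    torch.nn.ReLU(),
    torch.nn.Linear(64, 1),
)

# "sampling" from the generator means sampling its inputs from a Gaussian prior
z = torch.randn(BATCH_SIZE, Z_DIM)
fake_batch = toy_generator(z)    # shape (BATCH_SIZE, IMG_DIM)
scores = toy_critic(fake_batch)  # shape (BATCH_SIZE, 1): one score per sample
print(fake_batch.shape, scores.shape)
```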

More specifically, if we have some training dataset consisting of "real" data (e.g. images of dogs), we say this data is drawn from some probability distribution that assigns zero probability to data outside it (e.g. images of lamps) and non-zero probability to data within it (e.g. images of golden retrievers). Given some metric that measures the divergence between two probability distributions, the goal of the critic is to produce outputs that maximize this divergence between real and generated data, and the goal of the generator is the exact opposite: to produce outputs that minimize it. This creates a min-max game between the critic and the generator, so successful training requires a sort of "balance of power" between the two models, preventing one from moving too far along its objective before the other is updated again.
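The alternating "balance of power" can be sketched as a loop that takes a few critic steps for every generator step. This is a toy 1-D illustration under several assumptions, not the notebook's training loop (which lives in `wasserstein_gan.py`): the real data is assumed to be N(3, 1), the critic and generator are tiny stand-ins, the values of `N_CRITIC`, the learning rates and the step counts are arbitrary, and weight clipping from the original Wasserstein GAN paper stands in for the Improved Wasserstein GAN's regularization term.

```python
import torch

torch.manual_seed(0)

N_CRITIC = 5   # critic updates per generator update (assumed value)
N_STEPS = 200  # number of generator updates (toy value)
BATCH = 64

# toy 1-D problem: "real" data ~ N(3, 1); the generator is an affine map of z ~ N(0, 1)
generator = torch.nn.Linear(1, 1)
critic = torch.nn.Sequential(torch.nn.Linear(1, 32), torch.nn.ReLU(), torch.nn.Linear(32, 1))
opt_g = torch.optim.RMSprop(generator.parameters(), lr=5e-3)
opt_c = torch.optim.RMSprop(critic.parameters(), lr=5e-3)

def sample_real(n):
    return 3.0 + torch.randn(n, 1)

for step in range(N_STEPS):
    # 1) several critic steps: maximize C(real) - C(fake) by minimizing its negative
    for _ in range(N_CRITIC):
        fake = generator(torch.randn(BATCH, 1)).detach()  # generator is frozen here
        c_loss = -(critic(sample_real(BATCH)).mean() - critic(fake).mean())
        opt_c.zero_grad()
        c_loss.backward()
        opt_c.step()
        # weight clipping (original WGAN) keeps the critic roughly Lipschitz;
        # the Improved WGAN replaces this with a regularization term
        with torch.no_grad():
            for p in critic.parameters():
                p.clamp_(-0.1, 0.1)
    # 2) one generator step: raise the critic's score for fake samples
    g_loss = -critic(generator(torch.randn(BATCH, 1))).mean()
    opt_g.zero_grad()
    g_loss.backward()
    opt_g.step()

print(f"mean of generated samples: {generator(torch.randn(1000, 1)).mean().item():.3f}")
```

Taking several critic steps per generator step is a common design choice for Wasserstein GANs, since the generator's gradient is only as useful as the critic's current estimate.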

In the [original GAN implementation](https://arxiv.org/pdf/1406.2661.pdf), the critic model is referred to as the *discriminator*, as it is trained to classify whether an input image is "fake" (from the generator) or "real" (from the training dataset). This approach works, but the sigmoid at the output of the discriminator can saturate during training on some complex data distributions, which shrinks the gradients passed back to the generator and leads to instability and difficulty training in some cases.
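A quick way to see the saturation problem is to look at the gradient of the sigmoid, $\sigma'(x) = \sigma(x)(1 - \sigma(x))$, which collapses toward zero as the discriminator's pre-sigmoid output (logit) grows in magnitude. The sketch below evaluates it with PyTorch autograd for a few arbitrary logit values; at a logit of 20, $\sigma$ rounds to 1 in float32 and the gradient vanishes entirely.

```python
import torch

# arbitrary pre-sigmoid discriminator outputs (logits), from uncertain to very confident
logits = torch.tensor([0.0, 2.0, 5.0, 10.0, 20.0], requires_grad=True)
probs = torch.sigmoid(logits)
probs.sum().backward()  # logits.grad[i] = sigmoid'(logits[i]) = p * (1 - p)

for x, p, g in zip(logits.detach().tolist(), probs.detach().tolist(), logits.grad.tolist()):
    print(f"logit = {x:5.1f}   sigmoid = {p:.6f}   gradient = {g:.2e}")
```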

The Wasserstein GAN (and the Improved Wasserstein GAN that builds on it) was introduced to remedy some of the instability issues faced by the original GAN by defining a different objective for the discriminator. This new objective trains the discriminator to directly estimate the divergence between the probability distributions from which two batches of data are sampled, so the name "discriminator" is changed to "critic", as this model now outputs an unbounded real value instead of a probability used for classification. In this setting, the critic is given a batch of "real" data sampled from the training dataset and a batch of "fake" data sampled from the generator model, and its job is to estimate the [Wasserstein distance](https://en.wikipedia.org/wiki/Wasserstein_metric) between the probability distributions from which these two batches are sampled. If both batches come from the same distribution, the estimated distance should be (close to) zero, since there is no divergence between two identical probability distributions. This property is used to construct the objective functions for the Improved Wasserstein GAN model. Below, we outline the components that make up this objective function.

### Objective Function
The Improved Wasserstein GAN objective function is mainly two-fold, with an added regularization term for stability (hence the "improved" status). We briefly describe each part of the objective function below, then dive right into training.

1. **[Wasserstein Distance](https://en.wikipedia.org/wiki/Wasserstein_metric) ([Earth Mover's Distance](https://en.wikipedia.org/wiki/Earth_mover%27s_distance)):** This is a metric used to measure the divergence between two probability distributions. Its value is non-zero for two probability distributions that differ even slightly, and zero for two identical probability distributions. In practice, computing this distance directly between two distributions is often intractable, so a Wasserstein GAN is cleverly set up such that the critic learns to compute an estimate of the Wasserstein distance between the distributions that produce two batches of samples. In reality, the critic learns a divergence estimate that is useful for training but whose units of measure (its overall scale) do not really matter. We show why this is the case below by defining the objectives for the critic and generator in terms of this shared, arbitrarily scaled divergence estimate.
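For 1-D samples of equal size, the empirical Wasserstein-1 distance has a simple closed form: sort both samples and average the absolute differences, because the cheapest way to "move the earth" is to match the i-th smallest point of one sample with the i-th smallest point of the other. The helper `wasserstein_1d` below is a hypothetical illustration of the metric itself (GANs work in far higher dimensions, which is why a learned critic is needed); `scipy.stats.wasserstein_distance` computes the same quantity for 1-D samples.

```python
import numpy as np

def wasserstein_1d(u, v):
    """Empirical 1-Wasserstein (Earth Mover's) distance between two equal-size 1-D samples.

    In 1-D, the optimal transport plan pairs the sorted values of u and v, so the
    distance is the mean absolute difference of the sorted samples.
    """
    u = np.sort(np.asarray(u, dtype=float))
    v = np.sort(np.asarray(v, dtype=float))
    if u.shape != v.shape:
        raise ValueError("this shortcut assumes samples of equal size")
    return np.mean(np.abs(u - v))

samples = np.array([0.0, 1.0, 2.0, 3.0])
print(wasserstein_1d(samples, samples))        # identical samples: 0.0
print(wasserstein_1d(samples, samples + 0.5))  # every point shifted by 0.5: 0.5
print(wasserstein_1d(samples, samples[::-1]))  # same values, different order: 0.0
```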

For a critic $C_\theta$ with parameters $\theta$, a "real data" input batch $x_r$, and a "fake data" input batch $x_f$, the estimated Wasserstein distance $W$ between the distributions from which the two input batches were sampled is expressed by the following, where $C_\theta(\cdot)$ applied to a batch denotes the critic output averaged over that batch:

\begin{align}
W(x_r, x_f) = C_\theta(x_r) - C_\theta(x_f).
\tag{1}
\label{eq:wass_dist}
\end{align}

A nice way to remember the Wasserstein distance estimate in terms of critic output is "real minus fake".
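Eq. \ref{eq:wass_dist} translates almost directly into code. In this sketch the critic is a hypothetical single linear layer standing in for the notebook's `CNN`, each term is averaged over its batch (the usual minibatch estimate), and the batch sizes and the +1.0 shift of the fake batch are arbitrary.

```python
import torch

def wasserstein_estimate(critic, x_real, x_fake):
    """Eq. (1), "real minus fake": mean critic score on real minus mean score on fake."""
    return critic(x_real).mean() - critic(x_fake).mean()

torch.manual_seed(0)
critic = torch.nn.Linear(4, 1)   # hypothetical stand-in for the CNN critic
x_r = torch.randn(16, 4)         # "real" batch
x_f = torch.randn(16, 4) + 1.0   # "fake" batch from a shifted distribution

w = wasserstein_estimate(critic, x_r, x_f)
print(w.item())
print(wasserstein_estimate(critic, x_r, x_r).item())  # the same batch twice gives exactly 0.0
```

For a linear critic $C(x) = a^\top x + b$, the bias cancels and the estimate reduces to $a^\top(\bar{x}_r - \bar{x}_f)$, which makes a handy sanity check.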

Since $x_f$ is sampled from the generator network, we can substitute $x_f = G_\phi(z)$ for a generator network $G_\phi$ with parameters $\phi$ and inputs $z$ sampled from a prior distribution. Both the critic and generator objectives are defined in terms of eq. \ref{eq:wass_dist}, so, as mentioned above, the units of measure for the estimate in eq. \ref{eq:wass_dist} do not matter to us.
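Two consequences of eq. \ref{eq:wass_dist} can be checked numerically: multiplying the critic's output by a constant $K$ multiplies the estimate by $K$ without changing which batch looks "more real" (so the units do not matter), and because $x_f = G_\phi(z)$, the estimate is differentiable with respect to the generator parameters $\phi$. The linear stand-ins and `K = 10.0` below are hypothetical.

```python
import torch

torch.manual_seed(0)
Z_DIM, X_DIM, K = 8, 4, 10.0  # toy sizes and an arbitrary scale factor

generator = torch.nn.Linear(Z_DIM, X_DIM)  # hypothetical stand-in for G_phi
critic = torch.nn.Linear(X_DIM, 1)         # hypothetical stand-in for C_theta

x_r = torch.randn(32, X_DIM)  # "real" batch
z = torch.randn(32, Z_DIM)    # samples from the prior
x_f = generator(z)            # x_f = G_phi(z)

# eq. (1), "real minus fake", with batch means
w = critic(x_r).mean() - critic(x_f).mean()

# a critic whose outputs are K times larger gives a K times larger estimate
with torch.no_grad():
    w_scaled = (K * critic(x_r)).mean() - (K * critic(x_f)).mean()
print(f"W = {w.item():.4f}, scaled W = {w_scaled.item():.4f}")

# W depends on phi through x_f, so gradients reach the generator's parameters
w.backward()
print(generator.weight.grad.shape)  # torch.Size([4, 8])
```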

### Training
Training was performed separately from this notebook on a CUDA-enabled machine using an NVIDIA GeForce GTX 1060 GPU for acceleration. Overall, it took a little over 2 hours to complete 100 training epochs over the MNIST training dataset, a large improvement over the projected ~33 hours it would have taken on an Intel Core i7 CPU. The trained model weights were saved and are loaded into the critic and generator models used in this notebook.

Training is done by running the `train.py` script which accepts a number of training options as arguments and runs the training loop on an instance of the `WassersteinGAN` class. An example usage of the `train.py` script is shown below, which trains a Wasserstein GAN model named 'glados_wgan' for 100 epochs with a batch size of 128.

`python train.py --name glados_wgan --ne 100 --bs 128`

A number of additional configurable training options are listed in the `train.py` file. The overall structure of the Wasserstein GAN, its loss functions, and the training loop can be seen in the `wasserstein_gan.py` model class file.

### Setup.
Package imports, relative imports, and some parameter definitions.

```python
import torch
import numpy as np
import plotly

# relative imports
from cnn import CNN
from transpose_cnn import TransposeCNN
from dataloader_utils import make_mnist_dataloaders
from data_utils import tile_images

BATCH_SIZE = 64
NUM_WORKERS = 1
DATA_DIR = '/tmp/mnist_data/'
NUM_CHAN = 1
Z_DIM = 128
CRITIC_MODEL_FILE = '/home/dylan/trained_model_files/pytorch/glados/glados_gan/glados_wgan_critic.pt'
GENERATOR_MODEL_FILE = '/home/dylan/trained_model_files/pytorch/glados/glados_gan/glados_wgan_generator.pt'
```

### Initialize dataloaders.
Below, we use the custom function `make_mnist_dataloaders()` defined in `dataloader_utils.py` to return iterable PyTorch dataloaders for the MNIST dataset. On the first run of the code block below, you may get an error like `FloatProgress not found. Please update jupyter and ipywidgets`, caused by an out-of-date Jupyter/ipywidgets installation. The error message should include a link with update instructions, but if you use *conda* for Python package management, this problem can be fixed by installing the update from the command line with:

`conda install -c conda-forge ipywidgets`

Additionally, on the first run of the code block below (after the above error is fixed), you will see progress bars indicating the initial download of the MNIST dataset to `DATA_DIR`, which is set to `/tmp/mnist_data/` above. Running the code block after the initial download should produce no output.

```python
# initialize MNIST dataloaders
train_loader, test_loader = make_mnist_dataloaders(
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    data_dir=DATA_DIR
)
```

### Initialize and load pre-trained critic and generator models.

```python
# initialize critic (CNN) model
critic = CNN(
    in_chan=NUM_CHAN,
    out_dim=1,
    out_act=None
)

# initialize generator (TransposeCNN) model
generator = TransposeCNN(
    in_dim=Z_DIM,
    out_chan=NUM_CHAN,
    out_act=torch.nn.Tanh()
)

# load trained weights files (map_location lets GPU-trained weights load on a CPU)
critic.load_state_dict(torch.load(CRITIC_MODEL_FILE, map_location=torch.device('cpu')))
generator.load_state_dict(torch.load(GENERATOR_MODEL_FILE, map_location=torch.device('cpu')))
```

```
<All keys matched successfully>
```
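The `map_location` argument is what lets weights saved on the GPU training machine load on a CPU-only machine. Here is a minimal save/load round trip with a hypothetical stand-in model and a temporary file path (not the notebook's model files); the `load_state_dict` call is the same one that prints the `<All keys matched successfully>` message above.

```python
import os
import tempfile

import torch

model = torch.nn.Linear(4, 1)  # hypothetical stand-in for the critic or generator
path = os.path.join(tempfile.mkdtemp(), "toy_model.pt")  # hypothetical file location

# on the training machine: save only the parameters (the state dict), not the module
torch.save(model.state_dict(), path)

# on the inference machine: rebuild the same architecture, then load the weights,
# remapping any GPU tensors onto the CPU
restored = torch.nn.Linear(4, 1)
state = torch.load(path, map_location=torch.device("cpu"))
print(restored.load_state_dict(state))  # prints <All keys matched successfully>
```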

### Let's start by generating some fake images and comparing them to some real images. 
We start by defining a normal distribution from which to sample the `Z_DIM`-dimensional generator inputs. We then generate `BATCH_SIZE` fake images by feeding a sample from this distribution to the generator, and we collect `BATCH_SIZE` real images directly from the MNIST training set dataloader. Since these images are PyTorch tensors, we call `detach()` to remove them from the autograd graph (the generator output tracks gradients; for the real images this call is a harmless no-op) and `numpy()` to convert them to NumPy arrays before arranging them into a single tiled image.
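The need for `detach()` comes from autograd: the generator's output records how it was computed so that gradients can flow back to the generator's parameters, and PyTorch refuses to convert such a tensor to NumPy directly. This small sketch uses a hypothetical linear layer in place of the generator and also shows `torch.no_grad()`, an alternative that skips building the graph when, as here, no gradients are needed.

```python
import torch

gen = torch.nn.Linear(4, 4)    # hypothetical stand-in for the generator
fake = gen(torch.randn(2, 4))  # the output is part of the autograd graph
print(fake.requires_grad)      # True

try:
    fake.numpy()               # not allowed while the tensor tracks gradients
except RuntimeError as err:
    print("RuntimeError:", err)

arr = fake.detach().numpy()    # detach from the graph first, then convert

with torch.no_grad():          # alternative: never build the graph in the first place
    arr2 = gen(torch.randn(2, 4)).numpy()
```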

```python
# initialize a Z_DIM-dimensional standard normal distribution to sample generator inputs
z_dist = torch.distributions.normal.Normal(
    torch.zeros(BATCH_SIZE, Z_DIM),
    torch.ones(BATCH_SIZE, Z_DIM)
)

# feed a batch of z samples to the generator
fake_images = generator(z_dist.sample())

# get a batch of real images from the MNIST training set
# (use the built-in next(); recent PyTorch DataLoader iterators have no .next() method)
real_images = next(iter(train_loader))[0]

# detach and convert image batches to numpy arrays and tile them into one image
fake_images_tiled = tile_images(fake_images.detach().numpy())
real_images_tiled = tile_images(real_images.detach().numpy())
```
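`tile_images()` lives in `data_utils.py`, which is not shown here, so its exact behaviour is unknown. As a hypothetical stand-in, `tile_images_sketch` below arranges a batch of shape `(N, C, H, W)` into a roughly square grid, padding with blank tiles when `N` does not fill the grid; the default column count and the zero padding are assumptions.

```python
import numpy as np

def tile_images_sketch(images, n_cols=None):
    """Hypothetical stand-in for data_utils.tile_images (its real signature is not shown).

    images: array of shape (N, C, H, W), e.g. a generator output or dataloader batch.
    Returns a (rows*H, n_cols*W) image when C == 1, else (rows*H, n_cols*W, C).
    """
    n, c, h, w = images.shape
    n_cols = n_cols or int(np.ceil(np.sqrt(n)))
    n_rows = int(np.ceil(n / n_cols))
    # pad with blank (zero) tiles so the batch fills a complete grid
    pad = n_rows * n_cols - n
    if pad:
        images = np.concatenate([images, np.zeros((pad, c, h, w), dtype=images.dtype)])
    # (rows, cols, C, H, W) -> (rows, H, cols, W, C) -> (rows*H, cols*W, C)
    grid = images.reshape(n_rows, n_cols, c, h, w)
    grid = grid.transpose(0, 3, 1, 4, 2).reshape(n_rows * h, n_cols * w, c)
    return grid[..., 0] if c == 1 else grid

demo = np.random.rand(64, 1, 28, 28)
print(tile_images_sketch(demo).shape)  # (224, 224)
```

For the `BATCH_SIZE = 64` MNIST batches above, this yields an 8 x 8 grid of 28 x 28 digits in a single 224 x 224 image.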