# Wasserstein Generative Adversarial Network
In this notebook, we train the [Improved Wasserstein GAN](https://arxiv.org/pdf/1704.00028.pdf) on the [MNIST](http://yann.lecun.com/exdb/mnist/) dataset and investigate some of the results of training the model.

### Generative Adversarial Networks (GANs):
A GAN is a neural network model consisting of two parts: a generator and a critic (also called a discriminator, depending on the specific GAN model). The generator is trained to produce novel instances of data that look real compared to some reference training dataset. To "sample" a batch of data from the generator, the inputs provided to the model are sampled from a prior distribution (typically Gaussian), and the generator maps them to a "sampled" batch of output data. The critic is trained to discern between data produced by the generator ("fake" data) and data selected from the training dataset ("real" data). These two models are trained against each other with competing objectives, but as one improves, so does the other, and at the end of training you get a generator that can produce realistic-looking data that "fools" the critic. GANs are often trained on image datasets in order to produce novel images that look like realistic samples from the training dataset.

### GANs and Probability Distributions:

One way to interpret a dataset is that it consists of observations of a random variable drawn from some complex probability distribution. Different observations will have different probabilities of occurrence. If we have some training dataset consisting of "real" data (e.g. images of dogs), we can say this data follows some probability distribution in which zero probability is assigned to data outside the distribution (e.g. images of lamps) and non-zero probability is assigned to data within it (e.g. images of golden retrievers).

In a GAN, the critic learns to detect whether input data is sampled from the "real" or the "fake" probability distribution, and the generator learns a function approximation of the "real" data probability distribution using the critic's output as a proxy supervision signal.

### GAN Implementations:

In the [original GAN implementation](https://arxiv.org/pdf/1406.2661.pdf), the critic model is actually referred to as the *discriminator* model, as it is trained to classify whether an input image is "fake" (from the generator) or "real" (from the training dataset). This approach is successful, but the sigmoid activation at the output of the discriminator can saturate when training on some complex data distributions, leading to instability and difficulty training in some cases.
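The saturation effect mentioned above can be illustrated numerically. The sketch below is not taken from the training code; it only evaluates the slope of the sigmoid, $\sigma'(x) = \sigma(x)(1 - \sigma(x))$, at a few hypothetical discriminator logits to show how quickly the slope shrinks once the discriminator becomes confident.

```python
import math

def sigmoid(x: float) -> float:
    """Logistic sigmoid."""
    return 1.0 / (1.0 + math.exp(-x))

def sigmoid_grad(x: float) -> float:
    """Derivative of the sigmoid with respect to its input."""
    s = sigmoid(x)
    return s * (1.0 - s)

# hypothetical discriminator logits, from uncertain (0) to very confident (10)
logits = [0.0, 2.0, 5.0, 10.0]
grads = [sigmoid_grad(x) for x in logits]

for x, g in zip(logits, grads):
    print(f"logit={x:5.1f}  sigmoid={sigmoid(x):.5f}  slope={g:.2e}")
```

The slope is largest at a logit of zero and falls by several orders of magnitude by a logit of 10, which is one way to picture why a saturated sigmoid output passes very little signal back through the network.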

The Improved Wasserstein GAN was introduced to remedy some of the instability issues faced by the [original GAN](https://arxiv.org/pdf/1406.2661.pdf) by defining a different objective for the critic. This new objective involves training the critic such that its output can be used directly to compute an estimate of the divergence between the probability distributions from which two batches of input data are sampled. This is why the name "critic" is used instead of the name "discriminator", as this model now outputs an unbounded real value more analogous to "how real a batch of data is" rather than the probability that a batch of data is "real". In this setting, the critic and an identical copy of the critic are given a batch of "real" data sampled from the training dataset and a batch of "fake" data sampled from the generator model, respectively. The difference between the output of the critic given "real" data and the output of the critic copy given "fake" data is used as an estimate of the [Wasserstein Distance](https://en.wikipedia.org/wiki/Wasserstein_metric) between the probability distributions from which the two batches of input data were sampled. The Wasserstein distance is a probability distribution divergence metric based on the [Earth Mover's Distance](https://en.wikipedia.org/wiki/Earth_mover%27s_distance), and can be interpreted loosely as the amount of "work" it would take to convert one probability distribution into another by physically moving "probability density" between the two. If two batches of data are sampled from the same distribution, the Wasserstein distance between the underlying distributions is zero (an estimate computed from finite batches will be close to, but not exactly, zero). In contrast, a clearly non-zero Wasserstein distance between two samples of data indicates that these batches of data were sampled from different probability distributions.
In the Wasserstein GAN, this property is used to define both the critic and generator objectives, in which the critic learns to maximize a Wasserstein distance estimate (to better differentiate between "real" and "fake" data), and the generator learns to minimize the same Wasserstein distance estimate (to better fool the critic with the generated "fake" data). The magic behind this implementation is that we can simply optimize these objectives without worrying about whether or not the estimated Wasserstein distance has any actual meaningful value! The Wasserstein distance "units" learned by the critic and the generator simply cancel each other out! In the following section we present the mathematical formulations of these objectives, in which we can see in more detail why this is the case.
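The "work" interpretation of the Wasserstein distance described above can be made concrete in one dimension, where the optimal way to move probability mass between two equal-sized samples is to match their sorted values. The sketch below is only an illustration of that special case using the Python standard library; the helper name `wasserstein_1d` and the Gaussian batches are assumptions made for this example and are unrelated to the critic-based estimate used in the model.

```python
import random

def wasserstein_1d(a, b):
    """W1 distance between two equal-sized 1-D samples.

    In one dimension the optimal transport plan matches sorted values, so
    the distance is the mean absolute difference of the sorted samples.
    """
    if len(a) != len(b):
        raise ValueError("this sketch assumes equal sample counts")
    return sum(abs(x - y) for x, y in zip(sorted(a), sorted(b))) / len(a)

rng = random.Random(0)
n = 2000
batch_a = [rng.gauss(0.0, 1.0) for _ in range(n)]   # N(0, 1)
batch_b = [rng.gauss(0.0, 1.0) for _ in range(n)]   # N(0, 1) again
batch_c = [rng.gauss(3.0, 1.0) for _ in range(n)]   # N(3, 1), shifted

same = wasserstein_1d(batch_a, batch_b)      # small: same distribution
shifted = wasserstein_1d(batch_a, batch_c)   # close to the shift of 3
print(f"same distribution:    {same:.3f}")
print(f"shifted distribution: {shifted:.3f}")
```

Two batches from the same distribution give a small value that shrinks as the batch grows, while shifting one distribution by 3 gives a value near 3, the distance every unit of probability mass has to travel.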

### Defining the Objective Functions:

For a critic model, $C_\theta$, with parameters $\theta$, "real" input data batch, $x_r$, and "fake" input data batch, $x_f$, the Wasserstein distance estimate, $W$, between the distributions from which the two input batches were sampled is expressed by:

\begin{align}
W(x_r, x_f) = C_\theta(x_r) - C_\theta(x_f).
\tag{1}
\label{eq:wass_dist}
\end{align}

A nice way to remember the Wasserstein distance estimate in terms of critic output is "real minus fake".

Since $x_f$ is sampled from the generator network, we can substitute $x_f = G_\phi(z)$ for generator network, $G$, with parameters $\phi$, and inputs, $z$, sampled from a prior distribution, changing eq. \ref{eq:wass_dist} to:

\begin{align}
W(x_r, z) = C_\theta(x_r) - C_\theta(G_\phi(z)).
\tag{2}
\label{eq:wass_dist_full}
\end{align}

At this point, we can see that $\frac{\partial W}{\partial \theta}$ and $\frac{\partial W}{\partial \phi}$ can both be computed from eq. \ref{eq:wass_dist_full}. Remembering also from above that we want the critic to maximize the Wasserstein distance while we want the generator to minimize Wasserstein distance, this sets us up nicely to define the critic and generator objectives both in terms of eq. \ref{eq:wass_dist_full}. 
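As a quick check of the claim that eq. 2 yields gradients for both sets of parameters, the sketch below evaluates the estimate with two tiny linear layers and calls `backward()` once. The layers are hypothetical stand-ins for $C_\theta$ and $G_\phi$ (not the `CNN` and `TransposeCNN` models used later), and taking the mean of the critic output over the batch is an assumption of this example.

```python
import torch

torch.manual_seed(0)

# hypothetical toy stand-ins for the critic C_theta and generator G_phi
critic = torch.nn.Linear(4, 1)      # maps a 4-d "image" to one unbounded score
generator = torch.nn.Linear(2, 4)   # maps a 2-d latent z to a 4-d "image"

x_r = torch.randn(8, 4)             # stand-in batch of "real" data
z = torch.randn(8, 2)               # latent samples from a Gaussian prior

# eq. 2, averaged over the batch: "real minus fake"
w_est = critic(x_r).mean() - critic(generator(z)).mean()
w_est.backward()

# both sets of parameters receive gradients from the same scalar
print(critic.weight.grad.shape)     # torch.Size([1, 4])
print(generator.weight.grad.shape)  # torch.Size([4, 2])
```

In an actual training loop the two gradients would be consumed by separate optimizers in separate update phases, with the critic stepping on the negated estimate.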

1. **Critic Objective:** We define the critic objective, $J_c$, as the maximization of eq. \ref{eq:wass_dist_full}, written as the minimization of its negation:

\begin{align}
J_c = \min_{\theta}\left( -W(x_r, z) \right)
\tag{3}
\label{eq:crit_obj}
\end{align}

2. **Generator Objective:** We define the generator objective, $J_g$, as the minimization of eq. \ref{eq:wass_dist_full}:

\begin{align}
J_g = \min_{\phi}\left( W(x_r, z) \right)
\tag{4}
\label{eq:gen_obj}
\end{align}

*NOTE:* Since only the $-C_\theta(G_\phi(z))$ term of eq. \ref{eq:wass_dist_full} depends on the generator parameters, $\phi$, we should technically define the generator objective as $J_g = \min_{\phi}\left(-C_\theta(G_\phi(z))\right)$. We leave the objective as eq. \ref{eq:gen_obj} in the code for readability, since PyTorch can be configured to automatically ignore specific variables during different update phases.

3. **Critic Gradient Regularizer:** The "Improved" part of the Improved Wasserstein GAN involves a gradient penalty that promotes a [1-Lipschitz constraint](https://en.wikipedia.org/wiki/Lipschitz_continuity) on the critic output function. This gradient penalty aims to enforce that the gradient of the critic output w.r.t. its inputs has a norm of 1 everywhere. In practice, computing this gradient for every possible input is clearly intractable, so the gradient is computed for a carefully selected batch of input samples. This sample batch is constructed by randomly interpolating between the batch of "real" data and the batch of "fake" data from the generator, i.e. $x_i = \epsilon x_r + (1 - \epsilon) G_\phi(z)$ with $\epsilon$ drawn uniformly from $[0, 1]$. The norm of the "sampled gradient" at this interpolated batch of data is pushed toward one with a Lagrangian multiplier regularization term added to the critic objective function. With the introduction of a randomly interpolated data batch, $x_i$, the "Improved" critic objective function becomes,

\begin{align}
J_c = \min_{\theta}\left( -W(x_r, z) + \lambda \left( \left\| \nabla_{x_i} C_\theta(x_i) \right\|_2 - 1 \right)^2 \right)
\tag{5}
\label{eq:crit_obj_reg}
\end{align}

where $\nabla_{x_i}$ is the gradient operator w.r.t. the interpolated inputs $x_i$ (the penalty itself is still minimized w.r.t. the critic parameters $\theta$) and $\lambda$ is the scalar regularization strength, usually set to 10.
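The gradient penalty described above can be sketched in PyTorch with `torch.autograd.grad`, taking the gradient of the critic output w.r.t. the interpolated inputs. This is a minimal sketch under stated assumptions and not the code in `train.py`: the function name `gradient_penalty`, the toy linear critic, and the random stand-in batches are all hypothetical.

```python
import torch

def gradient_penalty(critic, x_real, x_fake, lam=10.0):
    """Penalty term of the regularized critic objective on interpolates x_i."""
    # one interpolation weight per sample, broadcast over the remaining dims
    eps_shape = [x_real.size(0)] + [1] * (x_real.dim() - 1)
    eps = torch.rand(eps_shape, device=x_real.device)
    x_i = (eps * x_real + (1.0 - eps) * x_fake).detach().requires_grad_(True)

    scores = critic(x_i)
    # gradient of the critic output w.r.t. the interpolated inputs;
    # create_graph=True keeps the penalty differentiable w.r.t. theta
    grads, = torch.autograd.grad(
        outputs=scores.sum(), inputs=x_i, create_graph=True
    )
    grad_norm = grads.flatten(start_dim=1).norm(2, dim=1)
    return lam * ((grad_norm - 1.0) ** 2).mean()

torch.manual_seed(0)
critic = torch.nn.Sequential(          # hypothetical toy critic
    torch.nn.Flatten(), torch.nn.Linear(32 * 32, 1)
)
x_real = torch.randn(4, 1, 32, 32)     # stand-in for a batch of real images
x_fake = torch.randn(4, 1, 32, 32)     # stand-in for G(z)

# "real minus fake" estimate, negated and regularized for the critic update
w_est = critic(x_real).mean() - critic(x_fake.detach()).mean()
critic_loss = -w_est + gradient_penalty(critic, x_real, x_fake)
critic_loss.backward()
```

For a purely linear critic the input gradient equals the weight vector for every sample, so the penalty reduces to $\lambda(\|w\|_2 - 1)^2$, which makes this toy case easy to verify by hand.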

### Run the Model Training Scripts:
Training is performed separately from this notebook on a CUDA-enabled machine using an NVIDIA GeForce GTX 1060 GPU for acceleration. Overall, it takes a little over 2 hours to complete 100 training epochs over the MNIST training dataset, a large improvement over the projected ~33 hours it would have taken on an Intel Core i7. Trained model weights are saved and then loaded into the corresponding critic and generator models used in this notebook.

To perform the training, run the `train.py` script which parses a number of training options from the `config.yaml` configuration file and runs the training loop on an instance of the `WassersteinGAN` class.

### Load Trained Critic and Generator Models:
If re-running this notebook locally, you will have to edit `DATA_DIR`, `CRITIC_MODEL_PATH`, and `GENERATOR_MODEL_PATH` to match where you have saved these files locally on your machine.

In [1]:
import torch
import numpy as np
import plotly.express as px
from plotly.subplots import make_subplots

from model.cnn import CNN
from model.transpose_cnn import TransposeCNN
from util.data_utils import generate_df_from_image_dataset, tile_images
from util.pytorch_utils import build_image_dataset

DATA_DIR = '/home/dylan/datasets/mnist_png/'
IMG_CHAN = 1
Z_DIM = 128
BATCH_SIZE = 64
CRITIC_MODEL_PATH = '/home/dylan/trained_model_files/pytorch/glados/glados_gan/glados_wgan_critic.pt'
GENERATOR_MODEL_PATH = '/home/dylan/trained_model_files/pytorch/glados/glados_gan/glados_wgan_generator.pt'

# initialize critic (CNN) model
critic = CNN(
    in_chan=IMG_CHAN,
    out_dim=1,
    out_act=None
)

# initialize generator (TransposeCNN) model
generator = TransposeCNN(
    in_dim=Z_DIM,
    out_chan=IMG_CHAN,
    out_act=torch.nn.Tanh()
)

# load trained weights files
critic.load_state_dict(torch.load(CRITIC_MODEL_PATH, map_location=torch.device('cpu')))
generator.load_state_dict(torch.load(GENERATOR_MODEL_PATH, map_location=torch.device('cpu')))

<All keys matched successfully>

### Initialize A Dataloader for the MNIST Testing Set:
On the first run of the below block of code, you may get an error saying something like `FloatProgress not found. Please update jupyter and ipywidgets`, caused by the current install of jupyterlab being out of date. A link should be provided in the error message with instructions to update, but if using *conda* for Python package management this problem can be fixed by installing the update via command line with:

`conda install -c conda-forge ipywidgets`

In [2]:
# generate filenames/labels df from image data directory
data_dict = generate_df_from_image_dataset(DATA_DIR)

# build testing dataloader
test_set, test_loader = build_image_dataset(
    data_dict['test'],
    image_size=(32, 32),
    batch_size=BATCH_SIZE,
    num_workers=1
)

### Generate Fake and Real Images:
We start by defining a normal distribution from which to sample our 128-dimensional inputs for the generator. Then we generate 64 fake images using a sample from our normal distribution as inputs to the generator, and we collect 64 real images directly from the MNIST testing set dataloader. As these images are currently in the form of PyTorch tensors, we must call `detach()` and `numpy()` to remove them from the PyTorch computation graph and convert them to numpy arrays before reorganizing them into a nicely tiled single image.

In [13]:
# initialize Z_DIM-dimensional standard normal distribution to sample generator inputs
z_dist = torch.distributions.normal.Normal(
    torch.zeros(BATCH_SIZE, Z_DIM),
    torch.ones(BATCH_SIZE, Z_DIM)
)

# feed a batch of z samples to the generator
fake_images = generator(z_dist.sample())

# get a batch of real images from the MNIST testing set
real_images = next(iter(test_loader))['image']

# detach and convert image batches to numpy arrays
fake_images = fake_images.detach().numpy()
real_images = real_images.detach().numpy()

# transpose from (N, C, H, W) to (N, H, W, C) so the channel dimension is last
fake_images = np.transpose(fake_images, [0, 2, 3, 1])
real_images = np.transpose(real_images, [0, 2, 3, 1])

# tile each image batch into a single image
fake_images_tiled = tile_images(fake_images)
real_images_tiled = tile_images(real_images)

# squeeze the singleton channel dimension
fake_images_tiled = np.squeeze(fake_images_tiled)
real_images_tiled = np.squeeze(real_images_tiled)

real_fig = px.imshow(real_images_tiled, title='Real Images')
real_fig.update_layout(width=400, height=400, margin=dict(l=10, r=10, b=10, t=80), coloraxis_showscale=False)
real_fig.update_xaxes(showticklabels=False).update_yaxes(showticklabels=False)

fake_fig = px.imshow(fake_images_tiled, title='Generated Images')
fake_fig.update_layout(width=400, height=400, margin=dict(l=10, r=10, b=10, t=80), coloraxis_showscale=False)
fake_fig.update_xaxes(showticklabels=False).update_yaxes(showticklabels=False)

real_fig.show()
fake_fig.show()
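The `tile_images` helper used above is imported from the project's `util.data_utils` and its source is not shown in this notebook. As a rough picture of what such a helper might do, the sketch below arranges a `(N, H, W, C)` batch into a square grid with numpy. The name `tile_images_sketch` and the square-grid layout are assumptions for illustration only and may differ from the actual implementation.

```python
import numpy as np

def tile_images_sketch(images: np.ndarray) -> np.ndarray:
    """Arrange a (N, H, W, C) batch into a (rows*H, cols*W, C) grid.

    Hypothetical stand-in for the notebook's `tile_images`; assumes N is a
    perfect square so that rows == cols.
    """
    n, h, w, c = images.shape
    side = int(round(np.sqrt(n)))
    if side * side != n:
        raise ValueError("this sketch only handles square batch sizes")
    # (side, side, H, W, C) -> (side, H, side, W, C) -> (side*H, side*W, C)
    grid = images.reshape(side, side, h, w, c).transpose(0, 2, 1, 3, 4)
    return grid.reshape(side * h, side * w, c)

# stand-in batch shaped like the notebook's images after the transpose step
batch = np.arange(64 * 32 * 32, dtype=np.float32).reshape(64, 32, 32, 1)
tiled = tile_images_sketch(batch)
print(tiled.shape)  # (256, 256, 1)
```

With a batch of 64 images of size 32x32 this produces an 8x8 grid, filled row by row, which `np.squeeze` then reduces to a 2-D array suitable for `px.imshow`.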