# Wasserstein Generative Adversarial Network
Exploration of the [Improved Wasserstein GAN](https://arxiv.org/pdf/1704.00028.pdf) trained on the [MNIST](http://yann.lecun.com/exdb/mnist/) dataset.

### Training
Training was performed separately from this notebook on a CUDA-enabled machine, using an NVIDIA GeForce GTX 1060 GPU for acceleration. Completing 100 training epochs over the MNIST training set took a little over 2 hours, a large improvement over the roughly 33 hours projected for an Intel Core i7. The trained model weights were saved, and they are loaded into the critic and generator models used in this notebook.
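As a rough check on these figures, the snippet below works out the implied speedup and per-epoch times. It uses only the timings quoted above (no new measurements) and treats the GPU run as exactly 2 hours, so the speedup it reports is an upper bound.

```python
# Back-of-the-envelope comparison of the quoted training times.
gpu_hours = 2.0    # "a little over 2 hours" on the GTX 1060, rounded down
cpu_hours = 33.0   # projected time on the Intel Core i7
epochs = 100

# Rounding the GPU time down makes this an upper bound on the speedup.
speedup = cpu_hours / gpu_hours
gpu_minutes_per_epoch = gpu_hours * 60 / epochs
cpu_minutes_per_epoch = cpu_hours * 60 / epochs

print(f"speedup <= {speedup:.1f}x")
print(f"GPU: ~{gpu_minutes_per_epoch:.1f} min/epoch, CPU: ~{cpu_minutes_per_epoch:.1f} min/epoch")
```

In other words, the GPU brought each epoch down from about 20 minutes to a little over one minute.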

Training is done by running the `train.py` script, which accepts a number of training options as arguments and runs the training loop on an instance of the `WassersteinGAN` class. The example below trains a Wasserstein GAN model named 'glados_wgan' for 100 epochs with a batch size of 128.

`python train.py --name glados_wgan --ne 100 --bs 128`
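In this command, `--name` sets the model name, `--ne` the number of epochs, and `--bs` the batch size. The sketch below shows one way a script could parse these three flags with Python's standard `argparse` module. It is an assumption for illustration, not the actual contents of `train.py`, which defines additional options of its own; the function name `build_parser` is hypothetical.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    # Hypothetical parser covering only the three flags used in the example command;
    # the real train.py defines more options and may set defaults for them.
    parser = argparse.ArgumentParser(description="Train a Wasserstein GAN on MNIST (sketch).")
    parser.add_argument("--name", type=str, required=True, help="model name")
    parser.add_argument("--ne", type=int, required=True, help="number of training epochs")
    parser.add_argument("--bs", type=int, required=True, help="training batch size")
    return parser


# Parse the example command's arguments explicitly instead of reading sys.argv.
args = build_parser().parse_args(["--name", "glados_wgan", "--ne", "100", "--bs", "128"])
print(args.name, args.ne, args.bs)
```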

Additional configurable training options are listed in the `train.py` file. The overall structure of the Wasserstein GAN, its loss functions, and the training loop are defined in the `wasserstein_gan.py` model class file.
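For reference, the linked paper trains the critic $D$ and generator $G$ with the objectives below, where $\mathbb{P}_r$ is the real data distribution, $\mathbb{P}_g$ is the generator's distribution ($\tilde{x} = G(z)$), $\hat{x}$ is a random interpolate between a real and a generated sample, and $\lambda$ weights the gradient penalty (the paper uses $\lambda = 10$). Whether `wasserstein_gan.py` uses exactly this form and coefficient is not shown in this notebook.

$$
\begin{aligned}
L_D &= \mathbb{E}_{\tilde{x} \sim \mathbb{P}_g}\big[D(\tilde{x})\big] - \mathbb{E}_{x \sim \mathbb{P}_r}\big[D(x)\big] + \lambda\, \mathbb{E}_{\hat{x} \sim \mathbb{P}_{\hat{x}}}\Big[\big(\lVert \nabla_{\hat{x}} D(\hat{x}) \rVert_2 - 1\big)^2\Big], \\
\hat{x} &= \epsilon\, x + (1 - \epsilon)\,\tilde{x}, \qquad \epsilon \sim U[0, 1], \\
L_G &= -\,\mathbb{E}_{z \sim p(z)}\big[D(G(z))\big].
\end{aligned}
$$

The penalty term pushes the norm of the critic's gradient toward 1 at points between real and generated samples, which is how the improved WGAN encourages the 1-Lipschitz constraint in place of the weight clipping used by the original WGAN.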

### Setup.
Package imports, relative imports, and some parameter definitions.

```python
import torch
import numpy as np
import plotly

# relative imports
from cnn import CNN
from transpose_cnn import TransposeCNN
from dataloader_utils import make_mnist_dataloaders
from data_utils import tile_images

BATCH_SIZE = 64
NUM_WORKERS = 1
DATA_DIR = '/tmp/mnist_data/'
NUM_CHAN = 1
Z_DIM = 128
CRITIC_MODEL_FILE = '/tmp/glados_wgan_critic.pt'
GENERATOR_MODEL_FILE = '/tmp/glados_wgan_generator.pt'
```

### Initialize dataloaders.
Below, we use the custom function `make_mnist_dataloaders()`, defined in `dataloader_utils.py`, to return iterable PyTorch dataloaders for the MNIST dataset. On the first run of the code block below, you may get an error such as `FloatProgress not found. Please update jupyter and ipywidgets`, caused by an out-of-date Jupyter/ipywidgets installation. The error message should include a link with update instructions; if you use *conda* for Python package management, the problem can be fixed by installing the update from the command line with:

`conda install -c conda-forge ipywidgets`

Additionally, on the first run of the code block below (once the above error is fixed), progress bars will show the initial download of the MNIST datasets to `DATA_DIR`, which is set to `/tmp/mnist_data/` above. Running the block again after the initial download should produce no output.

```python
# initialize MNIST dataloaders
train_loader, test_loader = make_mnist_dataloaders(
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    data_dir=DATA_DIR
)
```
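With `BATCH_SIZE = 64`, the number of batches per pass over each loader follows from the standard MNIST split sizes (60,000 training and 10,000 test images), as computed below. This assumes `make_mnist_dataloaders()` wraps the standard splits and keeps the final partial batch (PyTorch's `DataLoader` default of `drop_last=False`); neither detail is shown in this notebook.

```python
import math

BATCH_SIZE = 64
TRAIN_SIZE, TEST_SIZE = 60_000, 10_000  # standard MNIST split sizes

# With drop_last=False the last, smaller batch is still yielded, hence the ceiling.
train_batches = math.ceil(TRAIN_SIZE / BATCH_SIZE)  # last training batch holds 32 images
test_batches = math.ceil(TEST_SIZE / BATCH_SIZE)    # last test batch holds 16 images
print(train_batches, test_batches)
```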

### Initialize and load pre-trained critic and generator models.

```python
# initialize critic (CNN) model
critic = CNN(
    in_chan=NUM_CHAN,
    out_dim=1,
    out_act=None
)

# initialize generator (TransposeCNN) model
generator = TransposeCNN(
    in_dim=Z_DIM,
    out_chan=NUM_CHAN,
    out_act=torch.nn.Tanh()
)

# load trained weights files
critic.load_state_dict(torch.load(CRITIC_MODEL_FILE, map_location=torch.device('cpu')))
generator.load_state_dict(torch.load(GENERATOR_MODEL_FILE, map_location=torch.device('cpu')))
```

```
<All keys matched successfully>
```
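Because the generator ends in `torch.nn.Tanh()`, its outputs lie in the range (-1, 1). For real and fake images to be comparable, the real images are presumably scaled to the same range by the dataloaders (for example with torchvision's `ToTensor()` followed by `Normalize((0.5,), (0.5,))`), but that is an assumption since `dataloader_utils.py` is not shown. The NumPy sketch below, with hypothetical helper names, shows that mapping and its inverse, which is handy for displaying images as [0, 1] intensities.

```python
import numpy as np


def to_tanh_range(pixels_uint8: np.ndarray) -> np.ndarray:
    # Hypothetical helper: [0, 255] uint8 pixels -> [-1, 1] floats, matching
    # ToTensor() (divide by 255) followed by Normalize((0.5,), (0.5,)).
    return (pixels_uint8.astype(np.float32) / 255.0 - 0.5) / 0.5


def to_unit_range(images: np.ndarray) -> np.ndarray:
    # Hypothetical helper: [-1, 1] images (e.g. Tanh outputs) -> [0, 1] for display.
    return np.clip((images + 1.0) / 2.0, 0.0, 1.0)


pixels = np.array([0, 128, 255], dtype=np.uint8)
scaled = to_tanh_range(pixels)
restored = to_unit_range(scaled)
print(scaled, restored)
```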

### Let's start by generating some fake images and comparing them to some real images. 
We start by defining a standard normal distribution from which to sample the `Z_DIM`-dimensional generator inputs. We then generate `BATCH_SIZE` fake images by passing a sample from this distribution through the generator, and we collect `BATCH_SIZE` real images directly from the MNIST training set dataloader. Both batches are PyTorch tensors, so we convert them to NumPy arrays before tiling each batch into a single image. The fake images are the output of the generator and are still attached to the autograd graph, so they must be detached with `detach()` before `numpy()` will convert them; calling `detach()` on the real images is harmless.

```python
# initialize a Z_DIM-dimensional standard normal distribution to sample generator inputs
z_dist = torch.distributions.normal.Normal(
    torch.zeros(BATCH_SIZE, Z_DIM),
    torch.ones(BATCH_SIZE, Z_DIM)
)

# feed a batch of z samples to the generator
fake_images = generator(z_dist.sample())

# get a batch of real images from the MNIST training set
real_images = next(iter(train_loader))[0]

# detach and convert image batches to numpy arrays and tile them into one image
fake_images_tiled = tile_images(fake_images.detach().numpy())
real_images_tiled = tile_images(real_images.detach().numpy())
```
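The `tile_images()` helper from `data_utils.py` is not shown in this notebook. As an illustration only, the sketch below shows one plausible way to tile a batch shaped `(N, C, H, W)` into a near-square grid with NumPy. The name `tile_images_sketch` and its layout choices (row-major order, no padding, single-channel batches squeezed to 2-D) are assumptions and may not match the real helper.

```python
import math

import numpy as np


def tile_images_sketch(images: np.ndarray) -> np.ndarray:
    """Hypothetical tiler: (N, C, H, W) batch -> one (rows*H, cols*W[, C]) image."""
    n, c, h, w = images.shape
    cols = math.ceil(math.sqrt(n))  # near-square grid
    rows = math.ceil(n / cols)
    grid = np.zeros((rows * h, cols * w, c), dtype=images.dtype)  # unused cells stay 0
    for idx in range(n):
        r, col = divmod(idx, cols)  # fill the grid row by row
        grid[r * h:(r + 1) * h, col * w:(col + 1) * w, :] = images[idx].transpose(1, 2, 0)
    # Drop the channel axis for single-channel (e.g. MNIST) batches.
    return grid.squeeze(-1) if c == 1 else grid


# Stand-in for a batch of 64 single-channel 28x28 images.
batch = np.random.rand(64, 1, 28, 28).astype(np.float32)
tiled = tile_images_sketch(batch)
print(tiled.shape)
```

With 64 images this produces an 8 x 8 grid of 28 x 28 tiles, i.e. a single 224 x 224 image.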