# Wasserstein Generative Adversarial Network
Implementation of the Wasserstein GAN with gradient penalty (WGAN-GP) from [Improved Training of Wasserstein GANs](https://arxiv.org/pdf/1704.00028.pdf), trained on the [MNIST](http://yann.lecun.com/exdb/mnist/) dataset.
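
The paper's central change is to replace the original WGAN's weight clipping with a penalty on the norm of the critic's gradient, evaluated at random points interpolated between real and generated samples. Below is a minimal PyTorch sketch of the resulting critic and generator losses. The function names are hypothetical and are not part of this notebook's modules. The default `lam=10.0` is the penalty coefficient λ used in the paper.

```python
import torch


def gradient_penalty(critic, real, fake, lam=10.0):
    """lam * E[(||grad_x critic(x_hat)||_2 - 1)^2] at interpolated points x_hat."""
    batch_size = real.size(0)
    # one interpolation coefficient per sample, broadcast over (C, H, W)
    eps = torch.rand(batch_size, 1, 1, 1, device=real.device)
    x_hat = (eps * real + (1 - eps) * fake).requires_grad_(True)
    scores = critic(x_hat)
    # gradient of the critic scores with respect to the interpolated inputs
    grads = torch.autograd.grad(
        outputs=scores,
        inputs=x_hat,
        grad_outputs=torch.ones_like(scores),
        create_graph=True,  # keep the graph so the penalty can be backpropagated
    )[0]
    grad_norm = grads.reshape(batch_size, -1).norm(2, dim=1)
    return lam * ((grad_norm - 1) ** 2).mean()


def critic_loss(critic, real, fake, lam=10.0):
    # the critic minimizes E[D(fake)] - E[D(real)] + gradient penalty;
    # fake is detached so this loss does not update the generator
    fake = fake.detach()
    return (critic(fake).mean() - critic(real).mean()
            + gradient_penalty(critic, real, fake, lam))


def generator_loss(critic, fake):
    # the generator minimizes -E[D(fake)]
    return -critic(fake).mean()
```

In a training loop, `fake` would be `generator(z)` for a latent batch `z` of shape `(BATCH_SIZE, Z_DIM)`. The paper takes several critic steps per generator step.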

### Setup

```python
import torch
import numpy as np

# local module imports
from cnn import CNN
from transpose_cnn import TransposeCNN
from dataloader_utils import make_mnist_dataloaders

# constants and hyperparameters
LEARN_RATE = 1e-4
BATCH_SIZE = 64
NUM_TRAIN = 60000
NUM_TEST = 10000
NUM_WORKERS = 4
NUM_CHAN = 1
NUM_EPOCHS = 20
Z_DIM = 128

# use the first GPU if available, otherwise fall back to the CPU
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

print('[INFO] using \'{}\' device'.format(device))
```

```
[INFO] using 'cpu' device
```
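
These hyperparameters determine how many batches each epoch contains. The sketch below assumes that `make_mnist_dataloaders` keeps PyTorch's default `drop_last=False`, although its source is not shown here. Under that assumption the final partial batch is kept, so the batch count rounds up. The last training batch then holds 60000 - 937 × 64 = 32 images.

```python
import math

BATCH_SIZE = 64
NUM_TRAIN = 60000
NUM_TEST = 10000
NUM_EPOCHS = 20

# drop_last=False keeps the final partial batch, so round up
train_batches = math.ceil(NUM_TRAIN / BATCH_SIZE)  # 938
test_batches = math.ceil(NUM_TEST / BATCH_SIZE)    # 157
total_train_batches = train_batches * NUM_EPOCHS   # 18760 over all epochs

print(train_batches, test_batches, total_train_batches)  # 938 157 18760
```

With `drop_last=True`, the counts would instead be 937 training and 156 test batches.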


### Initialize Models

```python
# initialize critic (CNN) model; the Identity output is used because a WGAN
# critic produces an unbounded score rather than a probability
critic = CNN(
    in_chan=NUM_CHAN,
    out_dim=1,
    hid_act=torch.nn.ReLU(),
    out_act=torch.nn.Identity(),
    layer_norm=False
)

# initialize generator (TransposeCNN) model; Tanh bounds outputs to (-1, 1)
generator = TransposeCNN(
    in_dim=Z_DIM,
    out_chan=NUM_CHAN,
    hid_act=torch.nn.ReLU(),
    out_act=torch.nn.Tanh(),
    layer_norm=True
)

print('[INFO] critic structure \n{}'.format(critic))
print('[INFO] generator structure \n{}'.format(generator))
```

```
[INFO] critic structure 
CNN(
  (hid_act): ReLU()
  (out_act): Identity()
  (conv_1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv_2): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  (conv_3): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  (conv_4): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  (conv_5): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  (norm_1): Identity()
  (norm_2): Identity()
  (norm_3): Identity()
  (norm_4): Identity()
  (norm_5): Identity()
  (fc_1): Linear(in_features=2048, out_features=1, bias=True)
)
[INFO] generator structure 
TransposeCNN(
  (hid_act): ReLU()
  (out_act): Tanh()
  (fc_1): Linear(in_features=128, out_features=2048, bias=True)
  (conv_1): ConvTranspose2d(512, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  (conv_2): ConvTranspose2d(256, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  (conv_3): ConvTranspose2d(1
```
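
The critic's `fc_1` layer takes 2048 input features. That number follows from the `Conv2d` output-size formula applied to a 28 × 28 MNIST image. The sketch below assumes dilation 1 and assumes `CNN` flattens the output of `conv_5` directly into `fc_1`, with no pooling in between. Neither detail is shown in the printed structure, but both are consistent with it.

```python
def conv2d_out(size, kernel=3, stride=1, padding=1):
    # Conv2d output size with dilation 1:
    # floor((size + 2 * padding - kernel) / stride) + 1
    return (size + 2 * padding - kernel) // stride + 1


# strides of conv_1 .. conv_5, as printed in the critic structure
strides = [1, 2, 2, 2, 2]

size = 28  # MNIST images are 28 x 28
sizes = [size]
for s in strides:
    size = conv2d_out(size, stride=s)
    sizes.append(size)

print(sizes)              # [28, 28, 14, 7, 4, 2]
print(512 * size * size)  # 2048, matching fc_1's in_features
```

The generator's `fc_1` mirrors this: it maps the 128-dimensional latent vector to 2048 = 512 × 2 × 2 values, which `conv_1` then reads as 512 channels.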

### Initialize Dataloaders
Below, we use the custom function `make_mnist_dataloaders()`, defined in `dataloader_utils.py`, to create iterable PyTorch dataloaders for the MNIST training and test sets. On the first run of the code block below, you may get an error such as `FloatProgress not found. Please update jupyter and ipywidgets`. This happens when the installed Jupyter packages, in particular `ipywidgets`, are out of date. The error message should include a link to update instructions. If you use *conda* for Python package management, you can fix the problem from the command line with:

`conda install -c conda-forge ipywidgets`

Once that error is fixed, the first run of the code block below will show progress bars while the MNIST datasets download. By default, the data is saved to `/tmp/mnist_data`. After the initial download, running the code block should produce no output.

```python
train_loader, test_loader = make_mnist_dataloaders(
    batch_size=BATCH_SIZE, num_workers=NUM_WORKERS
)
```
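
Because the generator ends in `Tanh`, its samples lie in (-1, 1). Real images passed to the critic should therefore use the same range. This notebook does not show whether `make_mnist_dataloaders` already rescales the pixels. The hypothetical helpers below sketch two things: a rescaling from [0, 1], which is the range produced by torchvision's `ToTensor`, and a quick range and shape check on one batch.

```python
import torch


def to_tanh_range(x: torch.Tensor) -> torch.Tensor:
    # map pixel intensities from [0, 1] to [-1, 1]
    return x * 2.0 - 1.0


def check_batch(images: torch.Tensor, low: float = -1.0, high: float = 1.0):
    # expect a 4-D image batch: (batch, channels, height, width)
    if images.dim() != 4:
        raise ValueError(f"expected a 4-D batch, got shape {tuple(images.shape)}")
    # every pixel should fall inside the generator's output range
    lo, hi = images.min().item(), images.max().item()
    if lo < low or hi > high:
        raise ValueError(f"pixel range [{lo:.3f}, {hi:.3f}] is outside [{low}, {high}]")
    return tuple(images.shape)
```

For example, `check_batch(next(iter(train_loader))[0])` would inspect the first training batch. This assumes the loader yields `(images, labels)` pairs, as torchvision's MNIST dataset does. It should return `(64, 1, 28, 28)` if the images are already scaled to [-1, 1].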