# MNIST WGAN
Implementation of the original (weight-clipping) WGAN on the MNIST dataset.

#### Losses
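* Discriminator (critic) loss: $L_D = -\mathbb{E}_{x \sim P_r}[D(x)] + \mathbb{E}_{z \sim p(z)}[D(G(z))]$
* Generator loss: $L_G = -\mathbb{E}_{z \sim p(z)}[D(G(z))]$

The discriminator (critic) learns to score real samples higher than generated ones. The generator is trained to raise the critic's score on its own samples, which pushes it to produce samples similar to the real samples. After every critic update, the critic's weights are clipped to $[-c, c]$ to keep it (approximately) Lipschitz.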

#### References
* [Paper](https://arxiv.org/abs/1701.07875)
* [Code](https://github.com/eriklindernoren/PyTorch-GAN/blob/master/implementations/wgan/wgan.py)
* [Various GAN implementations in PyTorch](https://github.com/eriklindernoren/PyTorch-GAN)

In [1]:
import mnist_data_pytorch as data
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets
import numpy as np
from tqdm import tqdm
# Prefer the first CUDA GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print('Device:', device)
print('Pytorch version:', torch.__version__)
# Tensorboard
from torch.utils.tensorboard import SummaryWriter
# Start every run with an empty TensorBoard log directory.
# shutil.rmtree replaces the IPython-only shell escape `!rm -rf ./runs`,
# so this cell is also valid plain Python and works on any OS.
import shutil
shutil.rmtree('./runs', ignore_errors=True)  # like `rm -f`: no error if ./runs is missing
writer = SummaryWriter('./runs/train')

# Metaparameters
num_epochs = 100
num_classes = 10    # number of MNIST digit classes; not used by the unconditioned WGAN below
latent_size = 64    # length of the generator's input noise vector z
# RMSProp learning rates (5e-5 is the default alpha in Algorithm 1 of the WGAN paper)
gen_lr = 0.00005
disc_lr = 0.00005
EPS = 1e-15         # not referenced by the training loop below
clip_value = 0.01   # critic weights are clamped to [-clip_value, clip_value] after each step

Device: cuda:0
Pytorch version: 1.2.0


#### Define Generator/Discriminator

In [2]:
class Generator(nn.Module):
    """MLP generator that maps a noise vector ``z`` to a flattened image.

    Args:
        latent_size: Length of the input noise vector ``z``.
        output_size: Length of the flattened output (784 = 28 * 28 for MNIST).

    For an input of shape ``(batch, latent_size)`` the output has shape
    ``(batch, output_size)``. The final Sigmoid squashes every value into
    ``[0, 1]``.
    """

    def __init__(self, latent_size=100, output_size=784):
        super(Generator, self).__init__()

        def block(in_feat, out_feat, normalize=True):
            """Return ``[Linear, (BatchNorm1d), LeakyReLU]`` as a list of layers."""
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                # Use the BatchNorm1d defaults. Its second positional argument is
                # `eps`, so the reference code's `BatchNorm1d(out_feat, 0.8)` set
                # eps=0.8 (not momentum). That is far above the 1e-5 default and
                # weakens the normalization.
                layers.append(nn.BatchNorm1d(out_feat))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(latent_size, 256, normalize=False),
            *block(256, 256),
            nn.Linear(256, output_size),
            # Sigmoid rather than Tanh: outputs lie in [0, 1]; the author found
            # this works a bit better for MNIST (presumably [0, 1]-scaled pixels).
            nn.Sigmoid(),
        )

    def forward(self, z):
        """Map noise ``z`` of shape ``(batch, latent_size)`` to flattened images."""
        return self.model(z)


class Discriminator(nn.Module):
    """WGAN critic: an MLP that assigns a real-valued score to each input.

    There is no final Sigmoid, so the output is an unbounded score rather
    than a probability. For an input of shape ``(batch, input_size)`` the
    output has shape ``(batch, 1)``.

    Args:
        input_size: Length of each flattened input sample (784 = 28 * 28 for MNIST).
    """

    def __init__(self, input_size=784):
        super().__init__()
        hidden = 256
        layers = [
            nn.Linear(input_size, hidden),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(hidden, hidden),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(hidden, 1),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return the critic score for each row of ``x``."""
        return self.model(x)

# Initialize Networks and move them to the selected device.
# Both operate on flattened 28 x 28 = 784-pixel images. G is built before D,
# so the random weight initialization happens in the same order as before.
G = Generator(latent_size=latent_size, output_size=28 * 28).to(device)
D = Discriminator(input_size=28 * 28).to(device)

#### Initialize Optimizers
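The paper uses RMSProp (in PyTorch: `torch.optim.RMSprop`) rather than a momentum-based optimizer such as Adam, which the authors found made training unstable at times.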

In [3]:
# RMSProp, as in the WGAN paper. The PyTorch class is spelled `RMSprop`;
# `torch.optim.RMSProp` does not exist and raises AttributeError.
optim_generator = torch.optim.RMSprop(G.parameters(), lr=gen_lr)
optim_discriminator = torch.optim.RMSprop(D.parameters(), lr=disc_lr)

#### Train loop

In [4]:
# Number of critic (discriminator) updates per generator update.
# NOTE: Algorithm 1 of the WGAN paper uses n_critic = 5 with a fresh real batch
# for every critic step; here the k critic steps reuse the current real batch.
k = 1
for epoch in tqdm(range(num_epochs)):
    # generate_sample() switches G to eval mode; put both networks back into
    # training mode (BatchNorm uses batch statistics) before every epoch.
    G.train()
    D.train()
    running_loss_G = 0.0
    running_loss_D = 0.0
    # Iterate over the data (labels are not needed for an unconditioned GAN)
    for real_imgs, _ in data.dataloaders['train']:
        # Flatten each image to one row per sample for the MLPs
        # (assumes 1x28x28 MNIST images, as the logging .view() below requires).
        real_imgs = torch.flatten(real_imgs.to(device), start_dim=1)
        batch_size = real_imgs.size(0)

        # ---------------------
        #  Train Discriminator (critic)
        # ---------------------
        for _ in range(k):
            # Generate samples from random noise
            z_sample = torch.randn(batch_size, latent_size, device=device)
            gen_samples = G(z_sample)
            optim_discriminator.zero_grad()
            # WGAN critic loss: maximize D(real) - D(fake) by minimizing its negation.
            # detach() keeps this step from back-propagating into G.
            d_loss = -torch.mean(D(real_imgs)) + torch.mean(D(gen_samples.detach()))

            d_loss.backward()
            optim_discriminator.step()

            # Clip the critic's weights to [-clip_value, clip_value]
            # (the paper's way of enforcing the Lipschitz constraint).
            with torch.no_grad():
                for p in D.parameters():
                    p.clamp_(-clip_value, clip_value)

        # ---------------------
        #  Train Generator
        # ---------------------
        optim_generator.zero_grad()
        # Sample from distribution Z (z~Z)
        z_sample = torch.randn(batch_size, latent_size, device=device)
        gen_samples = G(z_sample)
        # WGAN generator loss: raise the critic's score on generated samples.
        # This backward pass also fills D's .grad, which is harmless because
        # optim_discriminator.zero_grad() runs before every critic step.
        g_loss = -torch.mean(D(gen_samples))

        g_loss.backward()
        optim_generator.step()

        # Update statistics (d_loss is the loss of the last critic step)
        running_loss_G += g_loss.item() * batch_size
        running_loss_D += d_loss.item() * batch_size

    # Epoch ends: average the batch-size-weighted losses over the training set
    num_train = len(data.dataloaders['train'].dataset)
    epoch_loss_generator = running_loss_G / num_train
    epoch_loss_discriminator = running_loss_D / num_train

    # Send results to tensorboard
    writer.add_scalar('train/loss_generator', epoch_loss_generator, epoch)
    writer.add_scalar('train/loss_discriminator', epoch_loss_discriminator, epoch)

    # Send images from the last batch to tensorboard (detached: logging needs no graph)
    writer.add_images('train/gen_samples', gen_samples.detach().view(batch_size, 1, 28, 28), epoch)
    writer.add_images('train/input_images', real_imgs.view(batch_size, 1, 28, 28), epoch)

    # Send latent to tensorboard
    writer.add_histogram('train/latent', z_sample, epoch)
    writer.add_histogram('train/X', real_imgs, epoch)
    writer.add_histogram('train/G(z)', gen_samples.detach(), epoch)

100%|██████████| 100/100 [09:13<00:00,  5.53s/it]


#### Generate Samples (Unconditioned)
Observe that the generated samples come from a mix of all classes: the generator is unconditioned, so there is no control over which digit it produces.

In [5]:
def generate_sample(num_idx=0):
    """Draw one sample from the generator ``G`` and plot it as a 28x28 image.

    Args:
        num_idx: Ignored by the body. Presumably it exists only so that the
            ipywidgets slider passed to ``interact`` re-invokes this function,
            drawing a fresh random sample each time the slider moves.

    The generator's previous train/eval mode is restored afterwards. Calling
    this between training runs therefore does not leave BatchNorm in eval mode.
    """
    was_training = G.training
    # Eval mode: BatchNorm1d cannot use batch statistics for a batch of one
    # sample, so use its running statistics instead.
    G.eval()
    try:
        z = torch.randn(1, latent_size, device=device)
        with torch.no_grad():
            generated_sample = G(z)
    finally:
        G.train(was_training)

    plt.imshow(generated_sample.view(28, 28).cpu().numpy())
    plt.title('Generated sample')
    plt.show()

In [6]:
# The slider value itself is ignored by generate_sample; moving the slider just
# re-runs it and draws a new random sample. The trailing ';' keeps Jupyter from
# echoing the return value.
interact(generate_sample, num_idx=widgets.IntSlider(value=0, min=0, max=100, step=1));

interactive(children=(IntSlider(value=0, description='num_idx'), Output()), _dom_classes=('widget-interact',))