# MNIST WGAN GP
This is a improvement of the WGAN, that avoid the gradient clipping by using a regularizing term (Gradient Penalty).

#### Changes
* Discriminator won't have sigmoid
* Generator and Discriminator loss differences
* Clipping on Discriminator gradients during training.

The discriminator will induce the generator to produce samples similar to the real samples.

#### Some notes on interpreting losses
On WGAN you might have negative losses, what we want to observe is that both generator/discriminator converges near zero.

#### References
* [Paper](https://arxiv.org/pdf/1704.00028.pdf)
* [Code](https://github.com/eriklindernoren/PyTorch-GAN/blob/master/implementations/wgan_gp/wgan_gp.py)
* [Various GAN Implementation on Pytorch](https://github.com/eriklindernoren/PyTorch-GAN)

In [1]:
import mnist_data_pytorch as data
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torch.autograd as autograd
import torch.nn.functional as F
from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets
import numpy as np
from tqdm import tqdm
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print('Device:', device)
print('Pytorch version:', torch.__version__)
# Tensorboard
from torch.utils.tensorboard import SummaryWriter
!rm -rf ./runs
writer = SummaryWriter('./runs/train')

# Metaparameters
num_epochs = 100
num_classes = 10
latent_size = 64
gen_lr = 0.0002
disc_lr = 0.0002
# Loss weight for gradient penalty
lambda_gp = 10
b1 = 0.5
b2 = 0.999

Device: cuda:0
Pytorch version: 1.2.0


#### Define Encoder/Decoder/Discriminator

In [2]:
class Generator(nn.Module):
    def __init__(self, latent_size=100, output_size=784):
        super(Generator, self).__init__()

        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat, 0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(latent_size, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, output_size),
            nn.Tanh(),
            #nn.Sigmoid() #Sigmoid bit better for MNIST
        )

    def forward(self, z):
        x = self.model(z)
        return x


class Discriminator(nn.Module):
    def __init__(self, input_size=784):
        super(Discriminator, self).__init__()

        self.model = nn.Sequential(
            nn.Linear(input_size, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
        )

    def forward(self, x):
        real_or_fake = self.model(x)
        return real_or_fake

# Initialize Networks
G = Generator(output_size=784, latent_size=latent_size).to(device)
D = Discriminator(input_size=784).to(device)

#### Gradient Penalty

In [3]:
def compute_GP(D, real_samples, fake_samples):
    """Calculates the gradient penalty loss for WGAN GP"""
    batch_size = real_samples.size(0)
    # Random weight term for interpolation between real and fake samples
    alpha = torch.rand(batch_size, 1).to(device)
    # Get random interpolation between real and fake samples
    interpolates = (alpha * real_samples + ((1 - alpha) * fake_samples)).requires_grad_(True)
    d_interpolates = D(interpolates)
    fake = (torch.ones(batch_size, 1).fill_(0.0)).to(device)
    # Get gradient w.r.t. interpolates
    gradients = autograd.grad(
        outputs=d_interpolates,
        inputs=interpolates,
        grad_outputs=fake,
        create_graph=True,
        retain_graph=True,
        only_inputs=True,
    )[0]
    gradients = gradients.view(gradients.size(0), -1)
    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
    return gradient_penalty


#### Initialize Optimizers
On the paper they used RMSProp

In [4]:
optim_generator = torch.optim.Adam(G.parameters(), lr=gen_lr, betas=(b1, b2))
optim_discriminator = torch.optim.Adam(D.parameters(), lr=disc_lr, betas=(b1, b2))

#### Train loop

In [5]:
k = 1
for epoch in tqdm(range(num_epochs)):
    running_loss_G = 0.0
    running_loss_D = 0.0
    # Iterate over the data
    for idx_sample, (real_imgs, _) in enumerate(data.dataloaders['train']):
        real_imgs = real_imgs.to(device)
        real_imgs = torch.flatten(real_imgs, start_dim=1, end_dim=-1)
        batch_size = real_imgs.size()[0]
        
        # ---------------------
        #  Train Discriminator
        # ---------------------
        for _ in range(k):
            # Generate samples from random noise
            z_sample = torch.randn(batch_size, latent_size).to(device)
            gen_samples = G(z_sample)
            optim_discriminator.zero_grad()
            # Simple WGAN Loss for discriminator
            # Gradient penalty
            gradient_penalty = compute_GP(D, real_imgs.data, gen_samples.data)
            d_loss = -torch.mean(D(real_imgs)) + torch.mean(D(gen_samples.detach())) + lambda_gp * gradient_penalty

            d_loss.backward()
            optim_discriminator.step()
        
        # ---------------------
        #  Train Generator
        # ---------------------
        optim_generator.zero_grad()
        # Sample from distribution Z (z~Z)
        z_sample = torch.randn(batch_size, latent_size).to(device)

        # Loss measures generator's ability to fool the discriminator
        gen_samples = G(z_sample)
        # Simple WGAN Loss for generator
        g_loss = -torch.mean(D(gen_samples))

        g_loss.backward()
        optim_generator.step()
        
        # Update statistics
        running_loss_G += g_loss.item() * batch_size
        # Update statistics
        running_loss_D += d_loss.item() * batch_size
    
    # Epoch ends
    epoch_loss_generator = running_loss_G / len(data.dataloaders['train'].dataset)
    epoch_loss_discriminator = running_loss_D / len(data.dataloaders['train'].dataset)
    
    # Send results to tensorboard
    writer.add_scalar('train/loss_generator', epoch_loss_generator, epoch)
    writer.add_scalar('train/loss_discriminator', epoch_loss_discriminator, epoch)
    
    # Send images to tensorboard
    writer.add_images('train/gen_samples', gen_samples.view(batch_size,1,28,28), epoch)
    writer.add_images('train/input_images', real_imgs.view(batch_size,1,28,28), epoch)
    
    # Send latent to tensorboard
    writer.add_histogram('train/latent', z_sample, epoch)
    writer.add_histogram('train/X', real_imgs, epoch)
    writer.add_histogram('train/G(z)', gen_samples, epoch)

  0%|          | 0/100 [00:00<?, ?it/s]


RuntimeError: invalid gradient at index 0 - got [100, 1] but expected shape compatible with [100, 1, 100, 1]

#### Generate Samples (Unconditioned)
Observe that the generated samples are somehow a mix of all classes.

In [None]:
def generate_sample(num_idx=0):
    G.eval()
    
    z = torch.randn(1, latent_size).to(device)
    with torch.no_grad(): 
        generated_sample = G(z)

    plt.imshow(generated_sample.view(28,28).cpu().numpy())
    plt.title('Generated sample')
    plt.show()

In [None]:
interact(generate_sample, num_idx=widgets.IntSlider(min=0, max=100, step=1, value=0));