<a href="https://colab.research.google.com/github/mlelarge/dataflowr/blob/master/WGAN_empty_colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Wasserstein GAN

Wasserstein GAN is a development on the famous GAN which avoids vanishing gradient issues by comparing the generated and real data distribution with Earth Mover distance.

For a brief overview, refer to [this](https://paper.dropbox.com/doc/Wasserstein-GAN--AnSk4vkryFmJgICMb_fybpwHAg-GvU0p2V9ThzdwY3BbhoP7).

For a more in-depth discussion, refer to this [blogpost](https://www.alexirpan.com/2017/02/22/wasserstein-gan.html) or the [paper](https://arxiv.org/abs/1701.07875).

Here we will implement a WGAN for MNIST.


In practice, there are only really two things that change in WGAN compared to regular GAN :

1.   Instead of having a discriminator, which outputs a probability, we have a critic, which outputs a score. Hence there is no need for a sigmoid at the output, and no log in the loss.
2.   To ensure that the function represented by the critic is Lipschitz, we clip the weights of the critic.



In [0]:
import torch
from torch import nn
from torch.autograd import Variable
from torch.optim import RMSprop
from torchvision import transforms
from torchvision.utils import make_grid
from torchvision.datasets import CIFAR10, MNIST
from pylab import plt
from tqdm import tqdm_notebook
%matplotlib inline

In [0]:
class Config:
    """Hyper-parameters and run-time settings shared by the whole notebook."""

    # --- data ---
    image_size = 64  # every image is resized to 64x64 with torchvision.transforms
    nc = 1  # image channels: 1 for MNIST, 3 for CIFAR10
    batch_size = 32

    # --- model ---
    nz = 100  # dimension of the noise vector
    ngf = 64  # base channel count of the generator
    ndf = 64  # base channel count of the critic

    # --- optimisation ---
    lr = 5e-5
    max_epoch = 50  # set to 1 for a quick debugging run
    clamp_num = 0.01  # WGAN weight clipping bound (critic weights, not gradients)

    # --- hardware ---
    gpu = torch.cuda.is_available()  # True when a CUDA device is available


opt = Config()
print('Using gpu : ', opt.gpu)

Using gpu :  True


In [0]:
# Data pipeline: resize to the configured size, convert to a tensor in [0, 1],
# then shift/scale to [-1, 1] so real images match the generator's Tanh range.
transform = transforms.Compose([
    transforms.Resize(size=opt.image_size),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),
])

# MNIST is downloaded into the current directory on first use.
dataset = MNIST(root='.', transform=transform, download=True)
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=opt.batch_size,
    shuffle=True,
)

In [0]:
''' REMINDER : nn.Conv2d takes as arguments 
nn.Conv2d(input channels, output channels, kernel_size, stride, padding)

spatial dimension after the convolution is given by  
output_dim = (input_dim + 2*padding - kernel_size) / stride + 1

nn.ConvTranspose2d does the exact opposite (upsampling rather than downsampling)
output_dim = (input_dim - 1) * stride - 2*padding + kernel_size
'''

# Critic network: a stack of five convolutions that shrinks the spatial
# dimension step by step (see the per-layer comments) down to a single score.
# Every `?` is a blank to fill in as part of the exercise; use the output_dim
# formula from the reminder above to pick kernel_size / stride / padding.
netd = nn.Sequential(
            # layer 1 : spatial dimension 64 -> 32
            nn.Conv2d(?, opt.ndf, ?, ?, ?,bias=False), 
            nn.LeakyReLU(0.2,inplace=True), 
            
            # layer 2 : spatial dimension 32 -> 16
            nn.Conv2d(opt.ndf,opt.ndf*2, ?, ?, ?, bias=False), 
            nn.BatchNorm2d(opt.ndf*2),
            nn.LeakyReLU(0.2,inplace=True),
            
            # layer 3 : spatial dimension 16 -> 8
            nn.Conv2d(opt.ndf*2,opt.ndf*4, ?, ?, ?, bias=False), 
            nn.BatchNorm2d(opt.ndf*4),
            nn.LeakyReLU(0.2,inplace=True),
            
            # layer 4 : spatial dimension 8 -> 4
            nn.Conv2d(opt.ndf*4,opt.ndf*8, ?, ?, ?, bias=False), 
            nn.BatchNorm2d(opt.ndf*8),
            nn.LeakyReLU(0.2,inplace=True),
            
            # layer 5 : spatial dimension 4 -> 1
            nn.Conv2d(opt.ndf*8, ?, ?, ?, ?,bias=False), 
            # This is a "critic", not a discriminator, so no need of a sigmoid !
        )

# Generator network: five transposed convolutions, halving the channel count
# (ngf*8 -> ngf*4 -> ngf*2 -> ngf) before the final image layer.
# Every `?` is a blank to fill in as part of the exercise.
# NOTE(review): presumably the input is noise of shape (batch, nz, 1, 1) and the
# spatial dimension mirrors the critic (1 -> 4 -> 8 -> 16 -> 32 -> 64) — confirm.
netg = nn.Sequential(
            nn.ConvTranspose2d(?,opt.ngf*8, ?, ?, ?,bias=False),
            nn.BatchNorm2d(opt.ngf*8),
            nn.ReLU(True),
            
            nn.ConvTranspose2d(opt.ngf*8,opt.ngf*4, ?, ?, ?, bias=False),
            nn.BatchNorm2d(opt.ngf*4),
            nn.ReLU(True),
            
            nn.ConvTranspose2d(opt.ngf*4,opt.ngf*2, ?, ?, ?, bias=False),
            nn.BatchNorm2d(opt.ngf*2),
            nn.ReLU(True),
            
            nn.ConvTranspose2d(opt.ngf*2,opt.ngf, ?, ?, ?, bias=False),
            nn.BatchNorm2d(opt.ngf),
            nn.ReLU(True),
            
            nn.ConvTranspose2d(opt.ngf, ?, ?, ?, ?, bias=False),
            nn.Tanh() 
            # We are outputting images so the output should be in [-1,1] !
        )


def weight_init(m):
    """Re-initialise a single layer in place (weight initialisation matters for WGAN).

    Intended to be passed to ``module.apply`` so it is called on every
    sub-module. Dispatch is done on the class name:

    * names containing ``Conv`` (``Conv2d``, ``ConvTranspose2d``, ...):
      weights drawn from N(0, 0.02);
    * names containing ``Norm`` (``BatchNorm2d``, ...): scale drawn from
      N(1, 0.02) and shift set to 0, so that calling this on an already
      trained network really resets the layer;
    * anything else is left untouched.

    Args:
        m: the module to initialise.
    """
    class_name = m.__class__.__name__
    if 'Conv' in class_name:
        nn.init.normal_(m.weight, 0.0, 0.02)
    elif 'Norm' in class_name:
        # Norm layers built with affine=False have weight/bias set to None.
        if getattr(m, 'weight', None) is not None:
            nn.init.normal_(m.weight, 1.0, 0.02)
        if getattr(m, 'bias', None) is not None:
            nn.init.zeros_(m.bias)

# Move both networks to the GPU when one is available.
if opt.gpu:
    for net in (netd, netg):
        net.cuda()

In [0]:
# optimizer : Use torch.optim.RMSprop instead of torch.optim.Adam
# (RMSprop is already imported at the top of the notebook and the learning
# rate is stored in opt.lr)
# presumably optimizerD updates netd and optimizerG updates netg — fill in the blanks
optimizerD = ?
optimizerG = ?

In [0]:
# Re-initialise both networks (critic first, then generator) and start
# with empty loss logs.
for net in (netd, netg):
    net.apply(weight_init)
lossD_history, lossG_history = [], []

In [0]:
'''
The training is very long, you can stop it and resume whenever you want !
Just look at how the generated images evolve, and stop when you think it's not getting any better.
You can resume training afterwards if you wish.
'''
# begin training

# number of iterations between two progress reports (losses + generated samples)
log_every = 100

def train():
    """Run the WGAN training loop over ``dataloader`` for ``opt.max_epoch`` epochs.

    The critic and generator update steps are left as an exercise; they must
    define ``lossD`` and ``lossG`` before the logging code below runs.

    Every ``log_every`` iterations the current losses are printed and the
    images generated from one fixed noise batch are displayed, so successive
    grids can be compared directly.
    """
    print('Start training')

    # The same noise batch is reused for every progress report.
    fix_noise = torch.randn(opt.batch_size, opt.nz, 1, 1)
    if opt.gpu:
        fix_noise = fix_noise.cuda()

    for epoch in range(opt.max_epoch):
        for ii, data in enumerate(dataloader):

            # ----- train netd -----
            # your code here
            # remember to clip the parameters using param.clamp_(min, max) !
            # (do it on param.data or inside torch.no_grad(): in-place ops on
            # parameters that require grad are otherwise rejected by autograd)
            #

            # ------ train netg -------
            # your code here
            #

            # log every `log_every` steps
            if ii % log_every == 0:
                print('LossD = {}, LossG = {}'.format(lossD, lossG))
                # Monitoring only: no autograd graph is needed for this pass.
                with torch.no_grad():
                    fake = netg(fix_noise)
                # Undo the [-1, 1] normalisation so the images are in [0, 1].
                imgs = make_grid(fake * 0.5 + 0.5).cpu()
                plt.imshow(imgs.permute(1, 2, 0).numpy())
                plt.show()


train()

### Visualization

In [0]:
# Draw 64 fresh noise vectors and display the images the generator produces.
noise = torch.randn(64, opt.nz, 1, 1)
if opt.gpu:
    noise = noise.cuda()
# Inference only: skip building an autograd graph for the forward pass.
with torch.no_grad():
    fake = netg(noise)
# Map the generator output from [-1, 1] back to [0, 1] for display.
imgs = make_grid(fake * 0.5 + 0.5).cpu()
plt.figure(figsize=(10, 10))
plt.imshow(imgs.permute(1, 2, 0).numpy())
plt.show()

### Plot the losses of discriminator and generator

In [0]:
# Overlay both loss histories on a single labelled figure.
for history, label in (
    (lossG_history, 'Generator'),
    (lossD_history, 'Discriminator'),
):
    plt.plot(history, label=label)
plt.legend()

### Please comment on the loss curves

Your answer here

The training involves a competition between the critic and the generator.

If one of the two is winning the fight too easily (loss curve going down much quicker), we can weaken it by updating its weights less often.



### Suggest a modification to your previous code to fix this and see how results are affected

In [0]:
# Remember to reset the weights !
# apply() calls weight_init on every sub-module of each network.
netd.apply(weight_init)
netg.apply(weight_init)
# start from empty loss logs so the new curves are not mixed with the old ones
lossD_history = []
lossG_history = []

# NOTE(review): optimizerD / optimizerG are not touched here; presumably any
# internal optimizer state carries over from the previous run — confirm whether
# they should be recreated as well.
def train_modified():
  # your code here
  
train_modified()

### Now give the WGAN a quick try on CIFAR10 !

In [0]:
# Same preprocessing as for MNIST, but normalising the three colour channels.
transform = transforms.Compose([
    transforms.Resize(opt.image_size),
    transforms.ToTensor(),
    transforms.Normalize([0.5] * 3, [0.5] * 3),
])

# The class imported at the top of the notebook is CIFAR10 (there is no `CIFAR`).
dataset = CIFAR10(root='.', transform=transform, download=True)
dataloader = torch.utils.data.DataLoader(dataset, opt.batch_size, shuffle=True)

# Fill in the blanks: the channel count for CIFAR10 and freshly built
# critic / generator networks that use it.
# NOTE(review): the new networks are presumably still on the CPU and the
# optimizers defined earlier were presumably built for the previous networks —
# confirm whether .cuda() and new optimizers are needed before training.
opt.nc = ?
netd = ?
netg = ?

# re-initialise the weights and start from empty loss logs
netd.apply(weight_init)
netg.apply(weight_init)
lossD_history = []
lossG_history = []

train_modified()