# How to turn a vanilla GAN into a vanilla WGAN
This notebook aims to deepen your understanding of the WGAN and WGAN-GP algorithms by turning a vanilla GAN into a WGAN, and then into a WGAN-GP.<br/>
The GAN code is based on Kojin's GAN workshop

## Different Loss Functions 
### GAN
- $D$: $\max_D V(D) = E_{x\sim p_{data}(x)}[\log D(x)] + E_{z\sim p_z(z)}[\log(1-D(G(z)))]$
- $G$: $\min_G V(G) = E_{z\sim p_z(z)}[\log(1-D(G(z)))]$

### WGAN
- $D$: $\max_D V(D) = E_{x\sim p_{data}}[D(x)] - E_{z\sim p_z(z)}[D(G(z))]$ (with $D$ kept Lipschitz by weight clipping)
- $G$: $\max_G V(G) = E_{z\sim p_z(z)}[D(G(z))]$

### WGAN-GP
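- $D$: $\max_D V(D) = E_{x\sim p_{data}}[D(x)] - E_{z\sim p_z(z)}[D(G(z))] - \lambda E_{\hat{x}\sim p_{\hat{x}}}[(\lVert \nabla_{\hat{x}} D(\hat{x}) \rVert_{2}-1)^2]$, where $\hat{x} = x + \alpha(G(z) - x)$ with $\alpha \sim U[0,1]$
- $G$: $\max_G V(G) = E_{z\sim p_z(z)}[D(G(z))]$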

| Features                      | GAN     | WGAN    | WGAN-GP |
| ----------------------------- |:-------:|:-------:|:-------:|
| Output layer of discriminator | Sigmoid | Linear  | Linear  |
| Optimizer                     | Adam    | RMSProp | Adam    |
| Weight clipping               | False   | True    | False   |
| Batch normalization           | False   | True    | False   |

##### * Batch Normalization (not covered here)
Batch normalization normalizes the output of each layer by subtracting the batch mean and dividing by the batch standard deviation, which helps stabilize training. <br/>

## Preparation

In [None]:
import matplotlib
matplotlib.use('TkAgg')
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

In [None]:
# Load MNIST from the local 'MNIST_data' directory with one-hot labels.
# Only the images are used for training below; the labels are discarded.
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets(train_dir='MNIST_data', one_hot=True)

In [None]:
# Network params
image_dim = 28 * 28                     # flattened 28x28 MNIST image (784 pixels)
gen_hidden_dim = disc_hidden_dim = 256  # width of the single hidden layer in G and D
noise_dim = 128                         # length of each generator input noise vector

In [None]:
# Clear every op/variable from the default graph so the WGAN graph built in the
# cells below starts from scratch (e.g. when the notebook cells are re-run).
tf.reset_default_graph()

## GAN to WGAN 

In [None]:
# <change> add hyperparameter Critic_Iters and c for WGAN

In [None]:
# training params
#   Batch_Size   - examples per minibatch
#   Critic_Iters - critic (discriminator) updates per generator update (WGAN / WGAN-GP)
#   c            - weight-clipping threshold: critic weights are kept in [-c, c]
#   Iters        - number of generator iterations to train for
Batch_Size, Critic_Iters, c, Iters = 50, 5, 0.01, 20001

In [None]:
# Generator
def generator(noises, reuse=False):
    """Build the generator network on top of ``noises``.

    Creates (or, when ``reuse`` is truthy, reuses) the variables under the
    'generator' variable scope: a ReLU hidden layer named 'g_hidden' with
    ``gen_hidden_dim`` units, then a sigmoid output layer named 'g_out' with
    ``image_dim`` units, so every output pixel lies in (0, 1).

    Returns the output tensor of the 'g_out' layer.
    """
    # A falsy ``reuse`` maps to None, i.e. inherit the enclosing scope's reuse
    # flag, which is exactly what happens when reuse_variables() is not called.
    with tf.variable_scope('generator', reuse=True if reuse else None):
        g_hidden = tf.layers.dense(noises, gen_hidden_dim, activation=tf.nn.relu, name='g_hidden')
        return tf.layers.dense(g_hidden, image_dim, activation=tf.nn.sigmoid, name='g_out')

# Discriminator
def discriminator(images, reuse=False):
    """Build the critic (discriminator) network on top of ``images``.

    Creates (or, when ``reuse`` is truthy, reuses) the variables under the
    'discriminator' variable scope: a ReLU hidden layer named 'd_hidden' with
    ``disc_hidden_dim`` units, then a single-unit output layer named 'd_out'.

    <change> The output layer has no activation (it is linear), as WGAN
    requires: the critic emits an unbounded score, not a probability.

    Returns the (batch, 1) score tensor of the 'd_out' layer.
    """
    # A falsy ``reuse`` maps to None, i.e. inherit the enclosing scope's reuse
    # flag, which is exactly what happens when reuse_variables() is not called.
    with tf.variable_scope('discriminator', reuse=True if reuse else None):
        d_hidden = tf.layers.dense(images, disc_hidden_dim, activation=tf.nn.relu, name='d_hidden')
        return tf.layers.dense(d_hidden, 1, activation=None, name='d_out')

In [None]:
# Noise fed to the generator: any number of rows, each of length noise_dim.
gen_input = tf.placeholder(dtype=tf.float32, shape=(None, noise_dim), name='input_noise')

In [None]:
# Generated images for the noise placeholder, plus a placeholder for real
# flattened MNIST images of the same width.
fake_data = generator(noises=gen_input)
real_data = tf.placeholder(dtype=tf.float32, shape=(None, image_dim), name='real_data')

In [None]:
# Critic scores for real and for generated images; the second call reuses the
# discriminator variables created by the first.
disc_real = discriminator(images=real_data)
disc_fake = discriminator(images=fake_data, reuse=True)

In [None]:
# <change> cost function

In [None]:
# WGAN losses, both written for minimisation:
#   generator: -mean(D(G(z)))
#   critic:     mean(D(G(z))) - mean(D(x))
gen_cost = tf.negative(tf.reduce_mean(disc_fake))
disc_cost = tf.subtract(tf.reduce_mean(disc_fake), tf.reduce_mean(disc_real))

In [None]:
# Split the trainable variables between the two networks by the variable scope
# they were created in ('generator/...' vs 'discriminator/...'). Matching the
# scope prefix is robust, whereas a substring test such as `'d_' in name`
# silently depends on no layer or variable name ever containing those characters.
tvars = tf.trainable_variables()
disc_vars = [var for var in tvars if var.name.startswith('discriminator/')]
gen_vars = [var for var in tvars if var.name.startswith('generator/')]

In [None]:
# <change> optimizer to RMS

In [None]:
# <change> Both networks are trained with RMSProp at learning rate 5e-5; each
# optimizer only updates its own network's variables.
train_gen = tf.train.RMSPropOptimizer(learning_rate=5e-5).minimize(
    loss=gen_cost, var_list=gen_vars)
train_disc = tf.train.RMSPropOptimizer(learning_rate=5e-5).minimize(
    loss=disc_cost, var_list=disc_vars)

In [None]:
# <change> add weight clipping

In [None]:
# One assign op per critic variable, clamping it into [-c, c] (WGAN weight
# clipping). The training loop runs these after every critic update.
clip_D = [weight.assign(tf.clip_by_value(weight, clip_value_min=-c, clip_value_max=c))
          for weight in disc_vars]

In [None]:
# Generated samples kept for later inspection, keyed by "WGAN_<step>".
# (This dict was previously used below without ever being defined, which
# raised a NameError at step 0.)
img_inventory = {}

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for step in range(Iters):

        # Train the critic Critic_Iters times per generator step. Each critic
        # update draws a fresh real batch and fresh noise (WGAN Algorithm 1)
        # rather than reusing one batch for every inner iteration, and is
        # followed by weight clipping.
        for critic_iter in range(Critic_Iters):
            batch_x, _ = mnist.train.next_batch(Batch_Size)
            z = np.random.uniform(-1., 1., size=[Batch_Size, noise_dim])
            _, dl = sess.run([train_disc, disc_cost],
                             feed_dict={real_data: batch_x, gen_input: z})
            sess.run(clip_D)

        # Train the generator on a fresh noise batch.
        z = np.random.uniform(-1., 1., size=[Batch_Size, noise_dim])
        _, gl = sess.run([train_gen, gen_cost], feed_dict={gen_input: z})

        if step % 1000 == 0 or step == 1:
            print('Step %i: Generator Loss: %f, Discriminator Loss: %f' % (step, gl, dl))

        # Periodically plot a 4x10 grid of generated digits and save it.
        if step % 10000 == 0 or step == 1:
            f, a = plt.subplots(4, 10, figsize=(10, 4))
            samples = []
            for i in range(10):
                # Noise input.
                z = np.random.uniform(-1., 1., size=[4, noise_dim])
                g = sess.run(fake_data, feed_dict={gen_input: z})
                g = np.reshape(g, newshape=(4, 28, 28, 1))
                # Reverse colours for better display
                g = -1 * (g - 1)
                samples.append(g)

                for j in range(4):
                    # Extend to 3 channels for the matplotlib figure.
                    img = np.reshape(np.repeat(g[j][:, :, np.newaxis], 3, axis=2),
                                     newshape=(28, 28, 3))
                    a[j][i].imshow(img)

            # Store all 40 displayed images. Previously the entry was assigned
            # inside the column loop, so only the last 4 images survived.
            img_inventory["WGAN_" + str(step)] = np.concatenate(samples)

            plt.draw()
            print('wgan' + str(step) + '.png')
            plt.savefig('wgan' + str(step) + '.png')
            # Release the figure so repeated plotting does not accumulate open figures.
            plt.close(f)

## WGAN to WGAN-GP 

WGAN-GP uses a loss function similar to WGAN's, but adds a gradient-penalty regularization term. <br/>
The two methods are also optimized differently:<br/>
WGAN-GP uses AdamOptimizer, whereas WGAN uses RMSPropOptimizer. <br/>
WGAN-GP does not require weight clipping. <br/>
WGAN-GP requires hand-picking the penalty coefficient $\lambda$.

In [None]:
# Clear the previously built graph (ops and variables) so the WGAN-GP graph
# constructed in the cells below starts from an empty default graph.
tf.reset_default_graph()

In [None]:
# <change> add hyperparameter Critic_Iters and Lambda for WGAN-GP

In [None]:
# training params
#   Batch_Size   - examples per minibatch
#   Critic_Iters - critic (discriminator) updates per generator update (WGAN / WGAN-GP)
#   Lambda       - gradient-penalty coefficient
#   Iters        - number of generator iterations to train for
Batch_Size, Critic_Iters, Lambda, Iters = 50, 5, 10, 200000

In [None]:
# Generator
def generator(noises, reuse=False):
    """Build the generator network on top of ``noises``.

    Creates (or, when ``reuse`` is truthy, reuses) the variables under the
    'generator' variable scope: a ReLU hidden layer named 'g_hidden' with
    ``gen_hidden_dim`` units, then a sigmoid output layer named 'g_out' with
    ``image_dim`` units, so every output pixel lies in (0, 1).

    Returns the output tensor of the 'g_out' layer.
    """
    # A falsy ``reuse`` maps to None, i.e. inherit the enclosing scope's reuse
    # flag, which is exactly what happens when reuse_variables() is not called.
    with tf.variable_scope('generator', reuse=True if reuse else None):
        g_hidden = tf.layers.dense(noises, gen_hidden_dim, activation=tf.nn.relu, name='g_hidden')
        return tf.layers.dense(g_hidden, image_dim, activation=tf.nn.sigmoid, name='g_out')

# Discriminator
def discriminator(images, reuse=False):
    """Build the critic (discriminator) network on top of ``images``.

    Creates (or, when ``reuse`` is truthy, reuses) the variables under the
    'discriminator' variable scope: a ReLU hidden layer named 'd_hidden' with
    ``disc_hidden_dim`` units, then a single-unit output layer named 'd_out'.

    <change> The output layer is linear (no activation), as in WGAN.

    Returns the (batch, 1) score tensor of the 'd_out' layer.
    """
    # A falsy ``reuse`` maps to None, i.e. inherit the enclosing scope's reuse
    # flag, which is exactly what happens when reuse_variables() is not called.
    with tf.variable_scope('discriminator', reuse=True if reuse else None):
        d_hidden = tf.layers.dense(images, disc_hidden_dim, activation=tf.nn.relu, name='d_hidden')
        return tf.layers.dense(d_hidden, 1, activation=None, name='d_out')

In [None]:
# Noise fed to the generator: any number of rows, each of length noise_dim.
gen_input = tf.placeholder(dtype=tf.float32, shape=(None, noise_dim), name='input_noise')

In [None]:
# Generated images for the noise placeholder, plus a placeholder for real
# flattened MNIST images of the same width.
fake_data = generator(noises=gen_input)
real_data = tf.placeholder(dtype=tf.float32, shape=(None, image_dim), name='real_data')

In [None]:
# Critic scores for real and for generated images; the second call reuses the
# discriminator variables created by the first.
disc_real = discriminator(images=real_data)
disc_fake = discriminator(images=fake_data, reuse=True)

In [None]:
# Base WGAN losses, both written for minimisation (the gradient penalty is
# added to disc_cost separately):
#   generator: -mean(D(G(z)))
#   critic:     mean(D(G(z))) - mean(D(x))
gen_cost = tf.negative(tf.reduce_mean(disc_fake))
disc_cost = tf.subtract(tf.reduce_mean(disc_fake), tf.reduce_mean(disc_real))

In [None]:
# Split the trainable variables between the two networks by the variable scope
# they were created in ('generator/...' vs 'discriminator/...'). Matching the
# scope prefix is robust, whereas a substring test such as `'d_' in name`
# silently depends on no layer or variable name ever containing those characters.
tvars = tf.trainable_variables()
disc_vars = [var for var in tvars if var.name.startswith('discriminator/')]
gen_vars = [var for var in tvars if var.name.startswith('generator/')]

In [None]:
# <change> add regularization component

In [None]:
# WGAN-GP gradient penalty: evaluate the critic at random points on the
# straight lines between real and generated samples, and penalise the squared
# deviation of its input-gradient norm from 1.
#
# One interpolation coefficient per example, broadcast across all pixels.
# The runtime batch size is used (instead of the fixed Batch_Size), so the
# graph accepts batches of any size.
alpha = tf.random_uniform(shape=[tf.shape(real_data)[0], 1], minval=0., maxval=1.)
differences = fake_data - real_data
interpolates = real_data + (alpha * differences)
gradients = tf.gradients(discriminator(interpolates, reuse=True), [interpolates])[0]
# The small epsilon keeps sqrt differentiable when a gradient is exactly zero
# (possible with ReLU hidden units); otherwise backprop through sqrt(0) gives NaN.
slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), axis=[1]) + 1e-10)
gradient_penalty = tf.reduce_mean((slopes - 1.) ** 2)
disc_cost += Lambda * gradient_penalty

In [None]:
# <change> optimizer back to Adam (the WGAN section above used RMSProp)

In [None]:
# Both networks use Adam (learning rate 1e-4, beta1=0.5, beta2=0.9); each
# optimizer only updates its own network's variables.
train_gen = tf.train.AdamOptimizer(learning_rate=1e-4, beta1=0.5, beta2=0.9).minimize(
    loss=gen_cost, var_list=gen_vars)
train_disc = tf.train.AdamOptimizer(learning_rate=1e-4, beta1=0.5, beta2=0.9).minimize(
    loss=disc_cost, var_list=disc_vars)

In [None]:
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    for step in range(Iters):

        # Train the critic Critic_Iters times per generator step. Each critic
        # update draws a fresh real batch and fresh noise (WGAN-GP Algorithm 1)
        # rather than reusing one batch for every inner iteration. There is no
        # weight clipping here; the gradient penalty in disc_cost replaces it.
        for critic_iter in range(Critic_Iters):
            batch_x, _ = mnist.train.next_batch(Batch_Size)
            z = np.random.uniform(-1., 1., size=[Batch_Size, noise_dim])
            _, dl = sess.run([train_disc, disc_cost],
                             feed_dict={real_data: batch_x, gen_input: z})

        # Train the generator on a fresh noise batch.
        z = np.random.uniform(-1., 1., size=[Batch_Size, noise_dim])
        _, gl = sess.run([train_gen, gen_cost], feed_dict={gen_input: z})

        if step % 1000 == 0 or step == 1:
            print('Step %i: Generator Loss: %f, Discriminator Loss: %f' % (step, gl, dl))

        # Periodically plot a 4x10 grid of generated digits and save it.
        if step % 10000 == 0 or step == 1:
            f, a = plt.subplots(4, 10, figsize=(10, 4))
            for i in range(10):
                # Noise input.
                z = np.random.uniform(-1., 1., size=[4, noise_dim])
                g = sess.run(fake_data, feed_dict={gen_input: z})
                g = np.reshape(g, newshape=(4, 28, 28, 1))
                # Reverse colours for better display
                g = -1 * (g - 1)
                for j in range(4):
                    # Extend to 3 channels for the matplotlib figure.
                    img = np.reshape(np.repeat(g[j][:, :, np.newaxis], 3, axis=2),
                                     newshape=(28, 28, 3))
                    a[j][i].imshow(img)

            plt.draw()
            print('wgan_gp' + str(step) + '.png')
            plt.savefig('wgan_gp' + str(step) + '.png')
            # Release the figure: with Iters=200000 this branch runs 21 times,
            # and unclosed figures would keep accumulating.
            plt.close(f)