## preperation

In [1]:
import matplotlib
matplotlib.use('TkAgg')
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

In [2]:
# Import MNIST data
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

Extracting MNIST_data\train-images-idx3-ubyte.gz
Extracting MNIST_data\train-labels-idx1-ubyte.gz
Extracting MNIST_data\t10k-images-idx3-ubyte.gz
Extracting MNIST_data\t10k-labels-idx1-ubyte.gz


In [3]:
# Network Params
image_dim = 784 # 28*28 pixels
gen_hidden_dim = 256
disc_hidden_dim = 256
noise_dim = 128 # Noise data points

## WGAN_GP

In [4]:
tf.reset_default_graph() # Clearing all tensors before this

In [5]:
# trainning param
Batch_Size = 50
Critic_Iters = 5 # for WGAN and WGAN-GP, number of critic iters per gen iter
Lambda = 10 # gradient penalty lambda hyperparameter
Iters = 100001 # number of generator iterations to train for

In [6]:
with tf.name_scope('g_h'):
    g_W1 = tf.Variable(tf.random_normal([noise_dim,gen_hidden_dim]),name='g_W1')
    g_b1 = tf.Variable(tf.random_normal([gen_hidden_dim]),name='g_b1')
with tf.name_scope('g_o'):
    g_W2 = tf.Variable(tf.random_normal([gen_hidden_dim,image_dim]),name='g_W2')
    g_b2 = tf.Variable(tf.random_normal([image_dim]),name='g_b2')

In [7]:
with tf.name_scope('d_h'):
    d_W1 = tf.Variable(tf.random_normal([image_dim,disc_hidden_dim]),name='d_W1')
    d_b1 = tf.Variable(tf.random_normal([disc_hidden_dim]),name='d_b1')
with tf.name_scope('g_o'):
    d_W2 = tf.Variable(tf.random_normal([disc_hidden_dim,1]),name='d_W2')
    d_b2 = tf.Variable(tf.random_normal([1]),name='d_b2')

In [8]:
# Generator
with tf.name_scope('Generator'):
    def generator(noises, reuse=False):
        with tf.variable_scope('generator') as scope:
            if (reuse):
                tf.get_variable_scope().reuse_variables()
            # hidden layer with name "g_hidden"
            hidden = tf.nn.relu(noises @ g_W1 + g_b1, name='gen_hidden')
            # out layer with name "g_out"
            out_images = tf.nn.sigmoid(hidden @ g_W2 + g_b2, name='gen_out')
        return out_images

# Discriminator
with tf.name_scope('Discriminator'):
    def discriminator(images, reuse=False):
        with tf.variable_scope('discriminator') as scope:
            if (reuse):
                tf.get_variable_scope().reuse_variables()            
            # hidden layer with name "d_hidden"
            hidden = tf.nn.relu(images @ d_W1 + d_b1, name='disc_hidden')
            # out layer with name "d_out"
            out = tf.add(hidden @ d_W2, d_b2,name = 'disc_out')
        return out

In [9]:
gen_input = tf.placeholder(tf.float32, shape=[None, noise_dim], name='input_noise')

In [10]:
fake_data = generator(gen_input)
real_data = tf.placeholder(tf.float32, shape=[None, image_dim], name='real_data')

In [11]:
disc_real = discriminator(real_data)
disc_fake = discriminator(fake_data, reuse=True)

In [12]:
gen_cost = -tf.reduce_mean(disc_fake)
disc_cost = tf.reduce_mean(disc_fake) - tf.reduce_mean(disc_real)

In [13]:
tvars = tf.trainable_variables()
disc_vars = [var for var in tvars if 'd_' in var.name]
gen_vars = [var for var in tvars if 'g_' in var.name]

In [14]:
alpha = tf.random_uniform(shape=[Batch_Size,1],minval=0.,maxval=1.)
differences = fake_data-real_data
interpolates = real_data + (alpha*differences)
gradients = tf.gradients(discriminator(interpolates, reuse=True),[interpolates])[0]
slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients),reduction_indices=[1]))
gradient_penalty = tf.reduce_mean((slopes-1.)**2)
disc_cost += Lambda*gradient_penalty

In [15]:
train_gen = tf.train.AdamOptimizer(
        learning_rate=1e-4, 
        beta1=0.5,
        beta2=0.9
    ).minimize(gen_cost, var_list=gen_vars)

train_disc = tf.train.AdamOptimizer(
        learning_rate=1e-4, 
        beta1=0.5, 
        beta2=0.9
    ).minimize(disc_cost, var_list=disc_vars)

In [16]:
with tf.name_scope('summaries'):
    loss = tf.summary.scalar('loss',-disc_cost)
    
    gen_W1 = tf.summary.scalar('g_W1',tf.reduce_mean(g_W1))
    gen_b1 = tf.summary.scalar('g_b1',tf.reduce_mean(g_b1))
    gen_W2 = tf.summary.scalar('g_W2',tf.reduce_mean(g_W2))
    gen_b2 = tf.summary.scalar('g_b2',tf.reduce_mean(g_b2))
    
    disc_W1 = tf.summary.scalar('d_W1',tf.reduce_mean(d_W1))
    disc_b1 = tf.summary.scalar('d_b1',tf.reduce_mean(d_b1))
    disc_W2 = tf.summary.scalar('d_W2',tf.reduce_mean(d_W2))
    disc_bb = tf.summary.scalar('d_b2',tf.reduce_mean(d_b2))
    
    summary_op = tf.summary.merge_all()

In [17]:
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    writer = tf.summary.FileWriter('WGAN_GP_log',sess.graph,flush_secs = 30)
    for step in range(Iters):

        batch_x, _ = mnist.train.next_batch(Batch_Size)
        # Generate noise to feed to the generator
        z = np.random.uniform(-1., 1., size=[Batch_Size, noise_dim])
        
        # train discriminator
        for i in range(Critic_Iters):
            _,dl = sess.run([train_disc,disc_cost],
                                   feed_dict={real_data:batch_x,gen_input:z})
        
        # train generator
        _,gl=sess.run([train_gen,gen_cost],
                      feed_dict={gen_input:z})
        
        # keep log
        if step % 1000 == 0:
            print('Step %i: Generator Loss: %f, Discriminator Loss: %f' % (step, gl, dl))
        if step % 100 == 0:
            summary = sess.run(summary_op,feed_dict={gen_input:z, real_data:batch_x})
            writer.add_summary(summary,global_step = step)
            
        # Generate images from noise, using the generator network.
        if step % 10000 == 0:
            f, a = plt.subplots(4, 10, figsize=(10, 4))
            for i in range(10):
                # Noise input.
                z = np.random.uniform(-1., 1., size=[4, noise_dim])
                g = sess.run([fake_data], feed_dict={gen_input: z})
                g = np.reshape(g, newshape=(4, 28, 28, 1))
                # Reverse colours for better display
                g = -1 * (g - 1)
                
                for j in range(4):
                    # Generate image from noise. Extend to 3 channels for matplot figure.
                    img = np.reshape(np.repeat(g[j][:, :, np.newaxis], 3, axis=2),
                                     newshape=(28, 28, 3))
                    a[j][i].imshow(img)

            plt.draw()
            print('wgan_gp'+str(step)+'.png')
            plt.savefig('wgan_gp'+str(step)+'.png')
    print('Done!')
    writer.close()

Step 0: Generator Loss: 114.657654, Discriminator Loss: 1105188.625000
wgan_gp0.png
Step 1000: Generator Loss: -35.852932, Discriminator Loss: 111672.882812
Step 2000: Generator Loss: -28.763277, Discriminator Loss: 22722.531250
Step 3000: Generator Loss: -5.601259, Discriminator Loss: 2372.725342
Step 4000: Generator Loss: 2.792879, Discriminator Loss: 230.540619
Step 5000: Generator Loss: 10.454459, Discriminator Loss: 13.533514
Step 6000: Generator Loss: 6.770102, Discriminator Loss: -5.961549
Step 7000: Generator Loss: 1.877961, Discriminator Loss: 2.025961
Step 8000: Generator Loss: -1.159017, Discriminator Loss: 12.308200
Step 9000: Generator Loss: -1.267867, Discriminator Loss: 3.602905
Step 10000: Generator Loss: -1.310349, Discriminator Loss: 12.007820
wgan_gp10000.png
Step 11000: Generator Loss: -0.566052, Discriminator Loss: -4.430213
Step 12000: Generator Loss: -0.634516, Discriminator Loss: -4.527105
Step 13000: Generator Loss: -0.878309, Discriminator Loss: -4.867041
Step