In [2]:
import tensorflow as tf
import numpy as np
import datetime
import matplotlib.pyplot as plt
%matplotlib inline

from tensorflow.examples.tutorials.mnist import input_data
# Load MNIST through the TF tutorial helper; "MNIST_data/" is the directory argument
# handed to the loader. Later cells draw batches with `mnist.train.next_batch(...)`.
mnist = input_data.read_data_sets("MNIST_data/")

In [None]:
# Pull one training example (element 0 of the batch tuple is the pixel data) and
# view it as a 28 x 28 grid so it can be drawn as a grayscale picture.
sample_image = mnist.train.next_batch(1)[0].reshape([28, 28])
plt.imshow(sample_image, cmap='Greys')

In [None]:
# Images coming in are 28 x 28 x 1
def discriminator(images, reuse=None):
    """Build the discriminator: two conv + avg-pool stages followed by two dense layers.

    Args:
        images: float tensor of shape [batch, 28, 28, 1].
        reuse: forwarded to `tf.variable_scope`. Pass True to share the weights
            created by an earlier call instead of creating new variables.

    Returns:
        Tensor of shape [batch, 1] with unscaled logits (no sigmoid applied).

    All variables are named with a 'd_' prefix so they can be selected by name
    when building the discriminator's optimizer.
    """
    with tf.variable_scope(tf.get_variable_scope(), reuse=reuse):
        # First convolution and pool layers: 32 feature maps from 5 x 5 filters.
        d_w1 = tf.get_variable('d_w1', [5, 5, 1, 32],
                               initializer=tf.truncated_normal_initializer(stddev=0.02))
        d_b1 = tf.get_variable('d_b1', [32], initializer=tf.constant_initializer(0))
        d1 = tf.nn.conv2d(input=images, filter=d_w1, strides=[1, 1, 1, 1], padding='SAME')
        # d1 is 28 x 28 x 32 (SAME padding: out_height = ceil(in_height / stride))
        d1 = d1 + d_b1
        d1 = tf.nn.relu(d1)
        d1 = tf.nn.avg_pool(d1, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
        # d1 is now 14 x 14 x 32

        # Second convolution and pool layers: 64 filters of 5 x 5 x 32.
        d_w2 = tf.get_variable('d_w2', [5, 5, 32, 64],
                               initializer=tf.truncated_normal_initializer(stddev=0.02))
        d_b2 = tf.get_variable('d_b2', [64], initializer=tf.constant_initializer(0))
        # Consumes the first stage's output (32 channels), not the raw images.
        d2 = tf.nn.conv2d(input=d1, filter=d_w2, strides=[1, 1, 1, 1], padding='SAME')
        # d2 is 14 x 14 x 64
        d2 = d2 + d_b2
        d2 = tf.nn.relu(d2)
        d2 = tf.nn.avg_pool(d2, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
        # d2 is now 7 x 7 x 64

        # First fully connected layer: flatten to 7*7*64 features, project to 1024 units.
        d_w3 = tf.get_variable('d_w3', [7 * 7 * 64, 1024],
                               initializer=tf.truncated_normal_initializer(stddev=0.02))
        # Bias size must match the 1024 output units of this layer.
        d_b3 = tf.get_variable('d_b3', [1024], initializer=tf.constant_initializer(0))
        d3 = tf.reshape(d2, [-1, 7 * 7 * 64])
        d3 = tf.matmul(d3, d_w3)
        d3 = d3 + d_b3
        d3 = tf.nn.relu(d3)

        # Second fully connected layer: one logit per image.
        d_w4 = tf.get_variable('d_w4', [1024, 1],
                               initializer=tf.truncated_normal_initializer(stddev=0.02))
        d_b4 = tf.get_variable('d_b4', [1], initializer=tf.constant_initializer(0))
        d4 = tf.matmul(d3, d_w4) + d_b4

        # d4 contains the unscaled value (logit)
        return d4

In [3]:
# z is placeholder, z_dim is 100
def generator(z, batch_size, z_dim):
    """Build the generator: map latent vectors to 28 x 28 x 1 images with values in (0, 1).

    Args:
        z: float tensor of shape [batch, z_dim] holding the latent noise.
        batch_size: not used by the body; kept so existing call sites keep working.
        z_dim: size of the latent vector. It also sets the channel counts of the
            intermediate convolutions (z_dim // 2 and z_dim // 4).

    Returns:
        Tensor of shape [batch, 28, 28, 1], squashed by a sigmoid.

    Variables are created with plain `tf.get_variable`, so calling this a second
    time in the same graph requires variable reuse to be enabled on the current scope.
    """
    # Integer channel counts: shape entries must be ints, and `/` yields a float in Python 3.
    half_dim = z_dim // 2
    quarter_dim = z_dim // 4

    g_w1 = tf.get_variable('g_w1', [z_dim, 3136], dtype=tf.float32,
                           initializer=tf.truncated_normal_initializer(stddev=0.02))
    g_b1 = tf.get_variable('g_b1', [3136], initializer=tf.truncated_normal_initializer(stddev=0.02))
    g1 = tf.matmul(z, g_w1) + g_b1
    # Reshape to [-1, 56, 56, 1]: -1 lets TF infer the batch size, 56 * 56 = 3136, 1 channel.
    g1 = tf.reshape(g1, [-1, 56, 56, 1])
    # NOTE(review): the batch-norm variables live under scopes 'bn1'/'bn2'/'bn3', whose
    # names contain no 'g_', so a name filter on 'g_' will not select them — confirm intended.
    g1 = tf.contrib.layers.batch_norm(g1, epsilon=1e-5, scope='bn1')
    g1 = tf.nn.relu(g1)

    # z_dim // 2 filters of 3 x 3 x 1; stride 2 halves 56 x 56 to 28 x 28.
    g_w2 = tf.get_variable('g_w2', [3, 3, 1, half_dim], dtype=tf.float32,
                           initializer=tf.truncated_normal_initializer(stddev=0.02))
    g_b2 = tf.get_variable('g_b2', [half_dim], initializer=tf.truncated_normal_initializer(stddev=0.02))
    g2 = tf.nn.conv2d(g1, g_w2, strides=[1, 2, 2, 1], padding='SAME')
    g2 = g2 + g_b2
    g2 = tf.contrib.layers.batch_norm(g2, epsilon=1e-5, scope='bn2')
    g2 = tf.nn.relu(g2)
    # resize_images only changes height and width: 28 x 28 -> 56 x 56, leaving the
    # batch and channel dimensions untouched.
    g2 = tf.image.resize_images(g2, [56, 56])

    # z_dim // 4 filters of 3 x 3 x (z_dim // 2); stride 2 gives 28 x 28 again.
    g_w3 = tf.get_variable('g_w3', [3, 3, half_dim, quarter_dim], dtype=tf.float32,
                           initializer=tf.truncated_normal_initializer(stddev=0.02))
    # Must have its own name: reusing 'g_b2' here collides with the bias created above.
    g_b3 = tf.get_variable('g_b3', [quarter_dim], initializer=tf.truncated_normal_initializer(stddev=0.02))
    g3 = tf.nn.conv2d(g2, g_w3, strides=[1, 2, 2, 1], padding='SAME')
    g3 = g3 + g_b3
    g3 = tf.contrib.layers.batch_norm(g3, epsilon=1e-5, scope='bn3')
    g3 = tf.nn.relu(g3)
    g3 = tf.image.resize_images(g3, [56, 56])

    # Final 1 x 1 convolution down to a single-channel image; stride 2 gives 28 x 28.
    g_w4 = tf.get_variable('g_w4', [1, 1, quarter_dim, 1], dtype=tf.float32,
                           initializer=tf.truncated_normal_initializer(stddev=0.02))
    g_b4 = tf.get_variable('g_b4', [1], initializer=tf.truncated_normal_initializer(stddev=0.02))
    g4 = tf.nn.conv2d(g3, g_w4, strides=[1, 2, 2, 1], padding='SAME')
    g4 = g4 + g_b4
    g4 = tf.sigmoid(g4)

    # Dimensions of g4: batch_size x 28 x 28 x 1
    return g4

In [None]:
## Generate sample image using our generator

In [None]:
z_dimensions = 100
# dtype goes first positionally; the batch dimension is None so any batch size can be fed.
z_placeholder = tf.placeholder(tf.float32, [None, z_dimensions])

generated_image_output = generator(z_placeholder, 1, z_dimensions)
# np.random.normal(mean, std, shape): one latent vector
z_batch = np.random.normal(0, 1, [1, z_dimensions])

# Run the untrained generator once and display what it produces.
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    generated_image = sess.run(generated_image_output, feed_dict={z_placeholder: z_batch})
    generated_image = generated_image.reshape([28, 28])
    plt.imshow(generated_image, cmap='Greys')

In [None]:
## Training the GAN

In [None]:
# Start from an empty graph so the variables created by the earlier sample run
# do not clash with the ones built here.
tf.reset_default_graph()
batch_size = 50

# Latent noise fed to the generator.
z_placeholder = tf.placeholder(tf.float32, [None, z_dimensions], name='z_placeholder')

# Real MNIST images fed to the discriminator.
x_placeholder = tf.placeholder(tf.float32, [None, 28, 28, 1], name='x_placeholder')

# Generated images.
Gz = generator(z_placeholder, batch_size, z_dimensions)

# Discriminator logits for the real images.
Dx = discriminator(x_placeholder)

# Discriminator logits for the generated images, sharing weights with Dx.
Dg = discriminator(Gz, reuse=True)

In [None]:
## Calculate loss

In [None]:
# Discriminator objective, split in two parts:
#   real images (logits Dx) should be labelled 1,
#   generated images (logits Dg) should be labelled 0.
d_loss_real = tf.reduce_mean(
    tf.nn.sigmoid_cross_entropy_with_logits(logits=Dx, labels=tf.ones_like(Dx)))
d_loss_fake = tf.reduce_mean(
    tf.nn.sigmoid_cross_entropy_with_logits(logits=Dg, labels=tf.zeros_like(Dg)))

In [None]:
# Calculate generator loss

In [None]:
# Generator objective: the discriminator's logits on generated images (Dg) are
# scored against all-ones labels, i.e. the generator is rewarded when its output
# is classified as real.
g_loss = tf.reduce_mean(
    tf.nn.sigmoid_cross_entropy_with_logits(
        labels=tf.ones_like(Dg),
        logits=Dg,
    )
)

In [None]:
## Trainable weights for generators and discriminators

In [None]:
# Partition the trainable variables by name: anything containing 'd_' goes to the
# discriminator list, anything containing 'g_' to the generator list.
tvars = tf.trainable_variables()

d_vars = list(filter(lambda t: 'd_' in t.name, tvars))
g_vars = list(filter(lambda t: 'g_' in t.name, tvars))

# Print both groups so the split can be checked by eye.
print(list(map(lambda t: t.name, d_vars)))
print(list(map(lambda t: t.name, g_vars)))

In [None]:
## Optimizers

In [None]:
# One Adam optimizer per objective. `var_list` restricts each update to its own
# network: the two discriminator steps touch only d_vars, the generator step only g_vars.
d_trainer_fake = tf.train.AdamOptimizer(learning_rate=0.0003).minimize(
    d_loss_fake, var_list=d_vars)
d_trainer_real = tf.train.AdamOptimizer(learning_rate=0.0003).minimize(
    d_loss_real, var_list=d_vars)

g_trainer = tf.train.AdamOptimizer(learning_rate=0.0001).minimize(
    g_loss, var_list=g_vars)

In [None]:
## Add summary

In [None]:
# From here on, get_variable() calls in the root scope reuse existing variables,
# which lets generator() be called again below without creating duplicates.
tf.get_variable_scope().reuse_variables()

# Scalar summaries for the three losses. Tags avoid spaces so they are used verbatim
# and match each other's naming.
tf.summary.scalar('Generator_loss', g_loss)
tf.summary.scalar('Discriminator_loss_real', d_loss_real)
tf.summary.scalar('Discriminator_loss_fake', d_loss_fake)

# Log up to 5 generated images per summary step.
images_for_tensorboard = generator(z_placeholder, batch_size, z_dimensions)
tf.summary.image('Generated_images', images_for_tensorboard, 5)
merged = tf.summary.merge_all()

# One timestamped log directory per run.
logdir = "tensorboard/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S") + "/"
# Write the current default graph: no session for this graph exists yet at this point
# (any earlier `sess` belongs to the graph that was discarded by reset_default_graph()).
writer = tf.summary.FileWriter(logdir, tf.get_default_graph())

In [None]:
# Training

In [None]:
sess = tf.Session()
sess.run(tf.global_variables_initializer())

# Pretrain discriminator with real images as well as fake noise
for i in range(300):
    z_batch = np.random.normal(0, 1, size=[batch_size, z_dimensions])
    real_image_batch = mnist.train.next_batch(batch_size)[0].reshape([batch_size, 28, 28, 1])
    _, __, dLossReal, dLossFake = sess.run([d_trainer_real, d_trainer_fake, d_loss_real, d_loss_fake],
                                           {x_placeholder: real_image_batch, z_placeholder: z_batch})

    if i % 100 == 0:
        print('dLossReal:', dLossReal, 'dLossFake:', dLossFake)

# Train generator and discriminator together
for i in range(100000):
    real_image_batch = mnist.train.next_batch(batch_size)[0].reshape([batch_size, 28, 28, 1])
    z_batch = np.random.normal(0, 1, size=[batch_size, z_dimensions])

    # Train discriminator on both real and fake images
    _, __, dLossReal, dLossFake = sess.run([d_trainer_real, d_trainer_fake, d_loss_real, d_loss_fake],
                                           {x_placeholder: real_image_batch, z_placeholder: z_batch})

    # Train generator on a fresh noise batch. feed_dict values must be concrete arrays,
    # so the noise is sampled with numpy rather than built as a TF tensor.
    z_batch = np.random.normal(0, 1, size=[batch_size, z_dimensions])
    sess.run([g_trainer, g_loss], {z_placeholder: z_batch})

    if i % 10 == 0:
        # Update Tensorboard with summary statistics
        z_batch = np.random.normal(0, 1, size=[batch_size, z_dimensions])
        summary = sess.run(merged, {z_placeholder: z_batch, x_placeholder: real_image_batch})
        writer.add_summary(summary, i)

    if i % 100 == 0:
        # Every 100 iterations, show a generated image
        print("Iteration:", i, "at", datetime.datetime.now())
        z_batch = np.random.normal(0, 1, size=[1, z_dimensions])
        # Evaluate the existing generator output Gz instead of calling generator()
        # again: rebuilding ops inside the loop would add nodes to the graph on every pass.
        images = sess.run(Gz, {z_placeholder: z_batch})
        plt.imshow(images[0].reshape([28, 28]), cmap='Greys')
        plt.show()

        # Show discriminator's estimate, using the existing logits tensor Dx
        # for the same reason.
        im = images[0].reshape([1, 28, 28, 1])
        estimate = sess.run(Dx, {x_placeholder: im})
        print("Estimate:", estimate)