In [10]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
from tensorflow.examples.tutorials.mnist import input_data

## Setting up the basics

In [11]:
mnist = input_data.read_data_sets('./inputs/mnist')
# Resetting default graph, starting from scratch
tf.reset_default_graph()

epochs = 1
batch_size = 64
n_noise = 200
learning_rate=0.00015

real_images = tf.placeholder(dtype=tf.float32, shape=[None, 28, 28, 1], name='real_images')
noise = tf.placeholder(dtype=tf.float32, shape=[None, n_noise])

# The keep_prob variable will be used by our dropout layers, which we introduce for more stable learning outcome
keep_prob = tf.placeholder(dtype=tf.float32, name='keep_prob')
is_training = tf.placeholder(dtype=tf.bool, name='is_training')

# Leaky Relu activation
# https://en.wikipedia.org/wiki/Rectifier_%28neural_networks%29#Potential_problems
def lrelu(x):
    return tf.maximum(x, tf.multiply(x, 0.2))

# Binary cross entropy for descriminators
def binary_cross_entropy(x, z):
    eps = 1e-12
    return (-(x * tf.log(z + eps) + (1. - x) * tf.log(1. - z + eps)))



Extracting ./inputs/mnist/train-images-idx3-ubyte.gz
Extracting ./inputs/mnist/train-labels-idx1-ubyte.gz
Extracting ./inputs/mnist/t10k-images-idx3-ubyte.gz
Extracting ./inputs/mnist/t10k-labels-idx1-ubyte.gz


# The descriminator

In [12]:
# It takes either real or fake MNIST image 28 x 28 in grayscale
# we use a sigmoid to make sure our output can be interpreted 
# as the probability the input image is a real MNIST character.
def discriminator(real_images, reuse=None, keep_prob=keep_prob):
    activation=lrelu
    with tf.variable_scope('disc', reuse=reuse):
        x = tf.reshape(real_images, shape=[-1, 28, 28, 1])
        x = tf.layers.conv2d(x, kernel_size=5, filters=64, strides=2, padding='same', activation=activation)
        x = tf.layers.dropout(x, keep_prob)
        x = tf.layers.conv2d(x, kernel_size=5, filters=64, strides=1, padding='same', activation=activation)
        x = tf.layers.dropout(x, keep_prob)
        x = tf.layers.conv2d(x, kernel_size=5, filters=64, strides=1, padding='same', activation=activation)
        x = tf.layers.dropout(x, keep_prob)
        x = tf.layers.flatten(x)
        x = tf.layers.dense(x, units=128, activation=activation)
        x = tf.layers.dense(x, units=1, activation=tf.nn.sigmoid)
        return x

# The generator

In [13]:
# z => noise
def generator(z, keep_prob=keep_prob, is_training=is_training):
    activation = lrelu
    momentum = 0.99
    with tf.variable_scope('gen', reuse=None):
        x = z
        d1 = 4
        d2 = 1
        print('checking units ', (d1 * d1 * d2))
        x = tf.layers.dense(x, units=d1 * d1 * d2, activation=activation)
        x = tf.layers.dropout(x, keep_prob)
        # https://www.tensorflow.org/api_docs/python/tf/contrib/layers/batch_norm
        x = tf.contrib.layers.batch_norm(x, is_training=is_training, decay=momentum)
        x = tf.reshape(x, shape=[-1, d1, d1, d2])
        x = tf.image.resize_images(x, size=[7, 7])
        x = tf.layers.conv2d_transpose(x, kernel_size=5, filters=64, strides=2, padding='same', activation=activation)
        x = tf.layers.dropout(x, keep_prob)
        x = tf.contrib.layers.batch_norm(x, is_training=is_training, decay=momentum)
        x = tf.layers.conv2d_transpose(x, kernel_size=5, filters=64, strides=2, padding='same', activation=activation)
        x = tf.layers.dropout(x, keep_prob)
        x = tf.contrib.layers.batch_norm(x, is_training=is_training, decay=momentum)
        x = tf.layers.conv2d_transpose(x, kernel_size=5, filters=64, strides=1, padding='same', activation=activation)
        x = tf.layers.dropout(x, keep_prob)
        x = tf.contrib.layers.batch_norm(x, is_training=is_training, decay=momentum)
        x = tf.layers.conv2d_transpose(x, kernel_size=5, filters=1, strides=1, padding='same', activation=tf.nn.sigmoid)
        # x = tf.layers.dense(x, units=784, activation=tf.nn.tanh)
        return x

# Loss functions and optimizers

In [14]:
g = generator(noise, keep_prob, is_training)
d_real = discriminator(real_images)
d_fake = discriminator(g, reuse=True)

vars_g = [var for var in tf.trainable_variables() if 'gen' in var.name]
vars_d = [var for var in tf.trainable_variables() if 'disc' in var.name]

# Applying regularizers
d_reg = tf.contrib.layers.apply_regularization(tf.contrib.layers.l2_regularizer(1e-6), vars_d)
g_reg = tf.contrib.layers.apply_regularization(tf.contrib.layers.l2_regularizer(1e-6), vars_g)

loss_d_real = binary_cross_entropy(tf.ones_like(d_real), d_real)
loss_d_fake = binary_cross_entropy(tf.zeros_like(d_fake), d_fake)
loss_g = tf.reduce_mean(binary_cross_entropy(tf.ones_like(d_fake), d_fake))
loss_d = tf.reduce_mean(0.5 * (loss_d_real + loss_d_fake))

# optimizer_d = tf.train.AdamOptimizer(learning_rate).minimize(loss_d, var_list=vars_d)
# optimizer_g = tf.train.AdamOptimizer(learning_rate).minimize(loss_g, var_list=vars_g)
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    optimizer_d = tf.train.RMSPropOptimizer(learning_rate=0.00015).minimize(loss_d + d_reg, var_list=vars_d)
    optimizer_g = tf.train.RMSPropOptimizer(learning_rate=0.00015).minimize(loss_g + g_reg, var_list=vars_g)


sess = tf.Session()
sess.run(tf.global_variables_initializer())

('checking units ', 16)


# Training GAN

In [15]:
samples = []
for i in range(epochs):
    train_d = True
    train_g = True
    keep_prob_train = 0.6

    # Creating noise
    n = np.random.uniform(0.0, 1.0, [batch_size, n_noise]).astype(np.float32)
    # batch = [np.reshape(b, [28, 28]) for b in mnist.train.next_batch(batch_size=batch_size)[0]]
    batch_x, _ = mnist.train.next_batch(batch_size)
    batch_x = np.reshape(batch_x, newshape=[-1, 28, 28, 1])
    print(batch_x.shape)
    d_real_ls, d_fake_ls, g_ls, d_ls = sess.run(
        [loss_d_real, loss_d_fake, loss_g, loss_d],
        feed_dict={real_images: batch_x, noise: n, keep_prob: keep_prob_train, is_training: True}
    )
    d_real_ls = np.mean(d_real_ls)
    d_fake_ls = np.mean(d_fake_ls)

    g_ls = g_ls
    d_ls = d_ls

    if g_ls * 1.5 < d_ls:
        train_g = False
        pass

    if d_ls * 2 < g_ls:
        train_d = False
        pass
    if train_d:
        sess.run(optimizer_d, feed_dict={noise: n, real_images: batch_x, keep_prob: keep_prob_train, is_training:True})
    if train_g:
        sess.run(optimizer_g, feed_dict={noise: n, keep_prob: keep_prob_train, is_training:True})
        
    # Showing sample image
    if not i % 50:
        gen_sample = sess.run(g, feed_dict={noise: n, keep_prob: 1.0, is_training:False})
        print(gen_sample)

(64, 28, 28, 1)
[[[[ 0.49791801]
   [ 0.49748281]
   [ 0.49700615]
   ..., 
   [ 0.49885151]
   [ 0.49746954]
   [ 0.49890342]]

  [[ 0.49799049]
   [ 0.49680182]
   [ 0.49604493]
   ..., 
   [ 0.49735928]
   [ 0.49867314]
   [ 0.49852818]]

  [[ 0.49831033]
   [ 0.49669075]
   [ 0.49577042]
   ..., 
   [ 0.49795324]
   [ 0.49842829]
   [ 0.49782127]]

  ..., 
  [[ 0.49652216]
   [ 0.49823493]
   [ 0.49667582]
   ..., 
   [ 0.49679911]
   [ 0.4973315 ]
   [ 0.4979662 ]]

  [[ 0.49943677]
   [ 0.49992043]
   [ 0.49754795]
   ..., 
   [ 0.49680981]
   [ 0.49833816]
   [ 0.49838635]]

  [[ 0.49963781]
   [ 0.49893215]
   [ 0.49836332]
   ..., 
   [ 0.49892265]
   [ 0.49985045]
   [ 0.4989261 ]]]


 [[[ 0.49824616]
   [ 0.49758381]
   [ 0.49718449]
   ..., 
   [ 0.4991987 ]
   [ 0.49748242]
   [ 0.49930865]]

  [[ 0.49809363]
   [ 0.49667346]
   [ 0.49577123]
   ..., 
   [ 0.49757129]
   [ 0.49883398]
   [ 0.49897665]]

  [[ 0.49858898]
   [ 0.49684972]
   [ 0.49592471]
   ..., 
   [ 0.498

In [16]:
# Still pretty noisy
plt.imshow(samples[0].reshape(28, 28), cmap='gray')

IndexError: list index out of range