In [1]:
import os

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

%matplotlib inline

In [2]:
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data')
# Width of one flattened training image, read from the data instead of being
# hard-coded (presumably 784 = 28*28 for MNIST; view_samples still assumes 28x28).
input_size = mnist.train.images.shape[1]

Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz


In [3]:
params = dict(
    # Passed positionally as the 2nd argument of tf.nn.dropout in
    # stochastic_net, i.e. the TF1 *keep* probability, not the drop rate.
    dropout=0.8,
    e_layers=[128],          # hidden-layer widths of the encoder MLP
    d_layers=[128],          # hidden-layer widths of the decoder MLP
    d_std_cold_start=0.3,    # fixed decoder std used while the cold start lasts
    cold_start_ends=20,      # cold-start length in epochs (scaled by batches/epoch)
    z_dim=100,               # latent code dimensionality
    activation='sigmoid',    # hidden activation: 'sigmoid' or 'tanh'
    batch_size=100,
    epochs=50,
    learning_rate=0.0002,    # Adam step size
)

In [4]:
def stochastic_net(x, layers, output_dim, keep_prob=None):
    """Build an MLP whose two linear heads give a diagonal Gaussian's mean and variance.

    Args:
        x: input tensor fed to the first dense layer.
        layers: iterable of hidden-layer widths. Each hidden layer is a dense
            layer using the activation named by params['activation'], followed
            by dropout.
        output_dim: number of units in each of the two output heads.
        keep_prob: keep probability passed to tf.nn.dropout after every hidden
            layer. Defaults to params['dropout'] (previous behavior).

    Returns:
        (mu, var): mu is an unconstrained linear head; var is
        1e-5 + exp(linear head), so it is strictly positive and never below 1e-5.

    Raises:
        KeyError: if params['activation'] is not 'tanh' or 'sigmoid'.
    """
    # Resolve configuration once rather than on every hidden layer.
    activation_fn = {'tanh': tf.nn.tanh, 'sigmoid': tf.nn.sigmoid}[params['activation']]
    if keep_prob is None:
        keep_prob = params['dropout']

    for units in layers:
        x = tf.layers.dense(x, units, activation=activation_fn)
        x = tf.nn.dropout(x, keep_prob)

    mu = tf.layers.dense(x, output_dim)
    # exp keeps the variance positive; the 1e-5 floor keeps log(var) and
    # divisions by var away from -inf / inf.
    var = 1e-5 + tf.exp(tf.layers.dense(x, output_dim))
    return mu, var

In [5]:
# Input placeholder: a batch of flattened images.
images = tf.placeholder(tf.float32, (None, input_size), name='image')

# Encoder q(z|x): diagonal Gaussian over the latent code.
e_mu, e_var = stochastic_net(images, params['e_layers'], params['z_dim'])

# Reparameterization: z = mu + sqrt(var) * eps with eps ~ N(0, I), so the
# sample is a differentiable function of e_mu and e_var.
eps = tf.random_normal([tf.shape(images)[0], params['z_dim']], mean=0.0, stddev=1.0)
z = e_mu + tf.sqrt(e_var) * eps

# Decoder p(x|z): the mean is squashed into (0, 1) by a sigmoid.
d_mu, d_var = stochastic_net(z, params['d_layers'], input_size)
d_mu = tf.nn.sigmoid(d_mu)

# Decoder std cold start: for the first `cold_start_ends` epochs (counted in
# optimizer steps) the std is held at a constant; afterwards the learned
# variance is used. The constant is broadcast to d_var's shape so both cond
# branches agree and d_std keeps a static (None, input_size) shape.
# NOTE(review): once the learned variance takes over, the Gaussian
# log-likelihood is unbounded above as d_var approaches its 1e-5 floor on
# near-constant pixels; the logged loss collapsing to about -1800 right after
# epoch 20 is consistent with that. TODO consider a Bernoulli likelihood, or
# a larger variance floor.
num_batches = mnist.train.num_examples // params['batch_size']
global_step = tf.train.get_or_create_global_step()
d_std = tf.cond(global_step < params['cold_start_ends'] * num_batches,
                lambda: params['d_std_cold_start'] * tf.ones_like(d_var),
                lambda: tf.sqrt(d_var))

In [6]:
# Encoder term: closed-form KL(N(e_mu, e_var) || N(0, I)), summed over the
# latent dimensions of each example.
kl_terms = 1 + tf.log(e_var) - e_mu ** 2 - e_var
loss_encode = -0.5 * tf.reduce_sum(kl_terms, axis=1)

# Decoder term: negative Gaussian log-likelihood of each input pixel under
# N(d_mu, d_std), summed over pixels.
decoder_dist = tf.contrib.distributions.Normal(d_mu, d_std)
loss_decode = -tf.reduce_sum(decoder_dist.log_prob(images), axis=1)

# Per-example negative ELBO averaged over the batch.
loss = tf.reduce_mean(loss_encode + loss_decode, axis=0)

In [None]:
# Adam on the negative ELBO. Passing global_step makes every training step
# increment it, and that counter decides when the decoder-std cold start ends.
optimizer = tf.train.AdamOptimizer(learning_rate=params['learning_rate'])
train_op = optimizer.minimize(loss, global_step=global_step)

In [None]:
samples = []
losses = []
CHECKPOINT_DIR = './checkpoints'

# Saves every global variable (encoder, decoder, optimizer slots, global_step).
saver = tf.train.Saver()
# Saver.save needs the parent directory to exist; create it up front.
if not os.path.isdir(CHECKPOINT_DIR):
    os.makedirs(CHECKPOINT_DIR)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for e in range(params['epochs']):
        for _step in range(num_batches):
            # next_batch returns (images, labels); only the images are used.
            batch_images = mnist.train.next_batch(params['batch_size'])[0]
            batch_images = batch_images.reshape((params['batch_size'], input_size))
            sess.run(train_op, feed_dict={images: batch_images})

        # Loss on the epoch's last minibatch only (noisy: dropout and eps are live).
        train_loss = sess.run(loss, {images: batch_images})
        # Single-argument print renders the same on Python 2 and 3.
        print("Epoch {}/{}... Loss: {:.4f}".format(e + 1, params['epochs'], train_loss))
        losses.append(train_loss)

        # Stop before sampling/saving so a diverged model never overwrites
        # the last good checkpoint.
        if np.isnan(train_loss):
            print('loss is NaN!')
            break

        # Decode 16 draws from the N(0, I) prior by feeding z directly.
        # NOTE(review): stochastic_net always applies dropout, so these
        # samples are decoded with dropout active.
        sample_z = np.random.randn(16, params['z_dim'])
        gen_samples = sess.run(d_mu, feed_dict={z: sample_z})
        samples.append(gen_samples)
        saver.save(sess, os.path.join(CHECKPOINT_DIR, 'generator.ckpt'))

('Epoch 1/50...', 'Loss: 67.0471')
('Epoch 2/50...', 'Loss: 65.0508')
('Epoch 3/50...', 'Loss: 61.8645')
('Epoch 4/50...', 'Loss: 43.5620')
('Epoch 5/50...', 'Loss: 24.5086')
('Epoch 6/50...', 'Loss: 0.0049')
('Epoch 7/50...', 'Loss: -13.7173')
('Epoch 8/50...', 'Loss: -17.1002')
('Epoch 9/50...', 'Loss: -0.8542')
('Epoch 10/50...', 'Loss: -23.8613')
('Epoch 11/50...', 'Loss: -20.4175')
('Epoch 12/50...', 'Loss: -26.0050')
('Epoch 13/50...', 'Loss: -25.8658')
('Epoch 14/50...', 'Loss: -29.8176')
('Epoch 15/50...', 'Loss: -30.6528')
('Epoch 16/50...', 'Loss: -34.4893')
('Epoch 17/50...', 'Loss: -43.9566')
('Epoch 18/50...', 'Loss: -46.2164')
('Epoch 19/50...', 'Loss: -38.4015')
('Epoch 20/50...', 'Loss: 755.1227')
('Epoch 21/50...', 'Loss: -1442.5038')
('Epoch 22/50...', 'Loss: -1548.2920')
('Epoch 23/50...', 'Loss: -1566.0380')
('Epoch 24/50...', 'Loss: -1642.9843')
('Epoch 25/50...', 'Loss: -1852.0916')
('Epoch 26/50...', 'Loss: -1838.0902')
('Epoch 27/50...', 'Loss: -1767.4000')
('Ep

In [None]:
# One point per completed epoch (see the training loop); the plain legend()
# call previously had no labelled artists, so it only emitted a warning.
fig, ax = plt.subplots()
ax.plot(losses, label='Training loss (last minibatch)')
ax.set(title='Training Loss', xlabel='Epoch', ylabel='Loss')
ax.legend();

In [None]:
def view_samples(epoch, samples):
    """Draw the generated images recorded at one epoch on a 4x4 grid.

    Args:
        epoch: index into `samples` (e.g. -1 for the most recent epoch).
        samples: sequence whose items are batches of flattened images; each
            image is reshaped to 28x28. At most 16 images are drawn.

    Returns:
        The (fig, axes) pair created by plt.subplots.
    """
    fig, axes = plt.subplots(nrows=4, ncols=4, figsize=(7, 7),
                             sharex=True, sharey=True)
    epoch_images = samples[epoch]
    for axis, flat_img in zip(axes.flat, epoch_images):
        axis.xaxis.set_visible(False)
        axis.yaxis.set_visible(False)
        axis.imshow(flat_img.reshape((28, 28)), cmap='Greys_r')
    return fig, axes

In [None]:
# Trailing ';' hides the returned (fig, axes) repr without assigning to `_`,
# which would stop IPython from updating its last-output variable.
view_samples(-1, samples);

In [None]:
# Evolution of generated samples: up to 20 evenly strided epochs (rows),
# 6 evenly strided images from each epoch's batch (columns).
rows, cols = min(len(samples), 20), 6
if rows == 0:
    print('No samples recorded - run the training cell first.')
else:
    # squeeze=False keeps `axes` 2-D even when rows == 1, so each row is iterable.
    fig, axes = plt.subplots(figsize=(7, 12), nrows=rows, ncols=cols,
                             sharex=True, sharey=True, squeeze=False)
    row_step = len(samples) // rows  # >= 1 because rows <= len(samples)
    for sample, ax_row in zip(samples[::row_step], axes):
        # max(1, ...) avoids a zero slice step when a batch has < cols images.
        col_step = max(1, len(sample) // cols)
        for img, ax in zip(sample[::col_step], ax_row):
            ax.imshow(img.reshape((28, 28)), cmap='Greys_r')
            ax.xaxis.set_visible(False)
            ax.yaxis.set_visible(False)