In [1]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

In [2]:
# Import data
from tensorflow.examples.tutorials.mnist import input_data

In [3]:
import tensorflow as tf
import os

In [4]:
flags = tf.app.flags
FLAGS = flags.FLAGS
flags.DEFINE_string('data_dir', '/tmp/data/', 'Directory for storing data')

In [5]:
mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True)

Extracting /tmp/data/train-images-idx3-ubyte.gz
Extracting /tmp/data/train-labels-idx1-ubyte.gz
Extracting /tmp/data/t10k-images-idx3-ubyte.gz
Extracting /tmp/data/t10k-labels-idx1-ubyte.gz


In [6]:
IMAGE_WH_SIZE = 28
INPUT_SIZE = IMAGE_WH_SIZE * IMAGE_WH_SIZE
HIDDEN_ENCODER_SIZE = 400
HIDDEN_DECODER_SIZE = 400
LATENT_SPACE_SIZE = 20

ADAGRAD_LR = 0.01 # Try with {0.01, 0.02, 0.1}
MINIBATCH_SIZE = 100
NUMBER_ITERATIONS = 1000

INIT_STD_DEV = 0.01

In [7]:
# Helpers

def create_W(shape):
    return tf.Variable(tf.truncated_normal(shape, stddev=INIT_STD_DEV))

def create_b(shape):
    return tf.Variable(tf.zeros(shape))


In [8]:
# Define Layers

# Input
x = tf.placeholder(tf.float32, [None, INPUT_SIZE])

# Encoder
W_x_h_enc = create_W([INPUT_SIZE, HIDDEN_ENCODER_SIZE])
b_x_h_enc = create_b([HIDDEN_ENCODER_SIZE])
h_enc = tf.sigmoid(tf.add(tf.matmul(x, W_x_h_enc), b_x_h_enc))

W_h_mu_enc = create_W([HIDDEN_ENCODER_SIZE, LATENT_SPACE_SIZE])
b_h_mu_enc = create_b([LATENT_SPACE_SIZE])
mu_enc = tf.sigmoid(tf.add(tf.matmul(h_enc, W_h_mu_enc), b_h_mu_enc))

W_h_logsigma2_enc = create_W([HIDDEN_ENCODER_SIZE, LATENT_SPACE_SIZE])
b_h_logsigma2_enc = create_b([LATENT_SPACE_SIZE])
logsigma2_enc = tf.sigmoid(tf.add(tf.matmul(h_enc, W_h_logsigma2_enc), b_h_logsigma2_enc))

# Sampler
eps_enc = tf.random_normal(shape=tf.shape(mu_enc))
sigma_enc = tf.exp(0.5 * logsigma2_enc)
z = tf.mul(sigma_enc, eps_enc)

# Decoder
W_z_h_dec = create_W([LATENT_SPACE_SIZE, HIDDEN_DECODER_SIZE])
b_z_h_dec = create_b([HIDDEN_DECODER_SIZE])
h_dec = tf.sigmoid(tf.add(tf.matmul(z, W_z_h_dec), b_z_h_dec))

W_h_x_dec = create_W([HIDDEN_DECODER_SIZE, INPUT_SIZE])
b_h_x_dec = create_b([INPUT_SIZE])
x_dec = tf.add(tf.matmul(h_dec, W_h_x_dec), b_h_x_dec)

log_p_x_z = tf.reduce_sum(-tf.nn.sigmoid_cross_entropy_with_logits(x_dec, x), reduction_indices=1)
KL_q_z_x_vs_p_z = - 0.5 * tf.reduce_sum(1 + logsigma2_enc - tf.square(mu_enc) - tf.square(sigma_enc) , reduction_indices=1)


In [9]:
lower_bound = - KL_q_z_x_vs_p_z + log_p_x_z
loss = - tf.reduce_mean(lower_bound)

In [10]:
train_it = tf.train.AdagradOptimizer(learning_rate=ADAGRAD_LR).minimize(loss)

In [11]:
# Summaries
loss_summ = tf.scalar_summary("loss", loss)
summary = tf.merge_all_summaries()


In [12]:
# Add ops to save and restore all the variables.
saver = tf.train.Saver()

In [13]:
# Training (this code should be updated to follow the use of FLAGS from here: http://stackoverflow.com/questions/33759623/tensorflow-how-to-restore-a-previously-saved-model-python )

with tf.Session() as sess:
    summary_writer = tf.train.SummaryWriter('logs', graph=sess.graph)

    if os.path.isfile("models/model.ckpt"):
        saver.restore(sess, "models/model.ckpt")
        print("Model restored.")
    else:
        sess.run(tf.initialize_all_variables())
        print("Initialize parameters.")
    
    for it in xrange(NUMBER_ITERATIONS):
        minibatch = mnist.train.next_batch(MINIBATCH_SIZE)
        cur_train_it, cur_summary, cur_loss = sess.run([train_it, summary, loss], feed_dict={x: minibatch[0]})
        summary_writer.add_summary(cur_summary, it)
        
        if it % 50 == 0:
            save_path = saver.save(sess, "models/model.ckpt")
            print("Iteration {0} | Loss: {1}".format(it, cur_loss))

Initialize parameters.
Iteration 0 | Loss: 548.909545898
Iteration 50 | Loss: 211.465820312
Iteration 100 | Loss: 201.980422974
Iteration 150 | Loss: 218.528625488
Iteration 200 | Loss: 203.234771729
Iteration 250 | Loss: 207.564849854
Iteration 300 | Loss: 206.415664673
Iteration 350 | Loss: 203.245681763
Iteration 400 | Loss: 203.622299194
Iteration 450 | Loss: 192.473358154
Iteration 500 | Loss: 211.761322021
Iteration 550 | Loss: 205.576217651
Iteration 600 | Loss: 211.391799927
Iteration 650 | Loss: 210.300827026
Iteration 700 | Loss: 209.532806396
Iteration 750 | Loss: 206.712600708
Iteration 800 | Loss: 203.39730835
Iteration 850 | Loss: 207.924591064
Iteration 900 | Loss: 201.052230835
Iteration 950 | Loss: 203.578552246
