# Variational AutoEncoder (CNN) in TensorFlow
Train a variational autoencoder on the MNIST dataset.

#### References:
* http://kvfrans.com/variational-autoencoders-explained/
* https://github.com/kvfrans/variational-autoencoder
* https://github.com/int8/VAE_tensorflow
* http://int8.io/variational-autoencoder-in-tensorflow/
* http://blog.fastforwardlabs.com/2016/08/22/under-the-hood-of-the-variational-autoencoder-in.html
* http://blog.fastforwardlabs.com/2016/08/12/introducing-variational-autoencoders-in-prose-and.html
* https://blog.keras.io/building-autoencoders-in-keras.html
* https://jaan.io/what-is-variational-autoencoder-vae-tutorial/
* https://arxiv.org/pdf/1606.05908.pdf
* https://arxiv.org/pdf/1312.6114.pdf
* http://wiseodd.github.io/techblog/2016/12/10/variational-autoencoder/
* https://www.tensorflow.org/get_started/embedding_viz
* https://www.youtube.com/watch?v=eBbEDRsCmv4
* https://www.youtube.com/watch?v=bbOFvxbMIV0
* https://www.youtube.com/watch?v=P78QYjWh5sM
* https://github.com/normanheckscher/mnist-tensorboard-embeddings
* http://projector.tensorflow.org/

In [1]:
import os
import shutil

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
# Make only GPU 0 visible to this process
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Folder that receives the TensorBoard logs and the model checkpoints
SAVE_FOLDER = '/tmp/vae_cnn'

# Start from a clean folder: remove whatever a previous run left behind.
# shutil/os are used instead of shelling out to "rm -rf", so no command string
# is built from a path and the cell also works where "rm" is unavailable.
if os.path.isdir(SAVE_FOLDER):
    shutil.rmtree(SAVE_FOLDER)
elif os.path.exists(SAVE_FOLDER):
    # A plain file is squatting on the path; "rm -rf" removed that too
    os.remove(SAVE_FOLDER)

from tensorflow.examples.tutorials.mnist import input_data
# Training hyper-parameters
start_lr = 0.001    # initial learning rate handed to the solver
num_epoch = 200     # number of outer iterations of the training loop
batch_size = 100    # mini-batch size

# Read the MNIST data sets from ./MNIST_data/ with one-hot encoded labels
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz


In [2]:
from models import VAE_CNN

# Build the VAE graph with a 20-dimensional latent code
model = VAE_CNN(latent_size=20)

# Short aliases for the model tensors used by the loss and the training loop
# (presumably: input placeholder, reconstructed image, flattened reconstruction
# -- confirm against models.VAE_CNN)
model_in, model_out, model_out_flat = model.input, model.output, model.output_flat

# Tensors describing the latent distribution; they feed the KL term of the loss
z_mean, z_stddev = model.z_mean, model.z_stddev

### Define loss

In [3]:
with tf.name_scope("VAE_LOSS"):
    # Small constant that keeps every tf.log argument strictly positive
    eps = 1e-8

    # Binary cross entropy between the input and its reconstruction, summed over axis 1.
    # The epsilon is added to the tensor (1 - model_out_flat), not to the constant 1:
    # "1e-8 + 1" collapses to exactly 1 once cast to float32 (assuming float32 tensors),
    # which would leave log(0) reachable when the reconstruction saturates at 1.
    generation_loss = -tf.reduce_sum(
        model_in * tf.log(eps + model_out_flat)
        + (1 - model_in) * tf.log(eps + (1 - model_out_flat)), 1)

    # L2/L1 Loss (Works, but sometimes you may get a NaN)
    #generation_loss = tf.norm(model_in-model_out_flat, ord=1)
    #generation_loss = tf.nn.l2_loss(model_in-model_out_flat)

    # KL Loss: 0.5 * sum(mean^2 + stddev^2 - log(stddev^2) - 1) over axis 1.
    # The epsilon keeps the log (and its gradient) finite if z_stddev reaches 0.
    latent_loss = 0.5 * tf.reduce_sum(
        tf.square(z_mean) + tf.square(z_stddev) - tf.log(eps + tf.square(z_stddev)) - 1, 1)

    # Merge the losses: mean over axis 0 of (reconstruction + KL)
    loss = tf.reduce_mean(generation_loss + latent_loss)

### Define the Solver

In [4]:
# Solver configuration
# Get ops to update moving_mean and moving_variance from batch_norm.
# These ops are not part of the gradient computation, so they only run when the
# training op explicitly depends on them (done below with control_dependencies).
# Reference: https://www.tensorflow.org/api_docs/python/tf/contrib/layers/batch_norm
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.name_scope("Solver"):
    # Stuff for learning rate decay
    global_step = tf.Variable(0, trainable=False)
    starter_learning_rate = start_lr
    # Multiply the learning rate by 0.9 every 1000 steps (staircase decay)
    learning_rate = tf.train.exponential_decay(starter_learning_rate, global_step,
                                               1000, 0.9, staircase=True)
    # Optimizer. The control dependency makes every training step also run the
    # collected update ops; with an empty collection it is a no-op.
    with tf.control_dependencies(update_ops):
        train_step = tf.train.AdamOptimizer(learning_rate).minimize(loss, global_step=global_step)

### Build Graph

In [5]:
# Op that assigns every global variable its initial value
init = tf.global_variables_initializer()

# Cap this process at roughly a third of the GPU memory instead of
# allocating the whole memory
gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.333)
session_config = tf.ConfigProto(gpu_options=gpu_options)
sess = tf.Session(config=session_config)

# Initialize the variables inside the new session
sess.run(init)

### Add some tensors to observe on Tensorboard

In [6]:
# Image summaries: up to 4 images each for the input and the reconstruction
tf.summary.image("input_image", model.image_in, max_outputs=4)
tf.summary.image("output_image", model_out, max_outputs=4)

# Scalar summaries tracking the optimisation
for tag, tensor in (("global_step", global_step),
                    ("learning_rate", learning_rate),
                    ("loss", loss)):
    tf.summary.scalar(tag, tensor)

# One op that evaluates every summary registered so far
merged_summary = tf.summary.merge_all()

# Writer that stores the summaries (and the graph definition) under SAVE_FOLDER
summary_writer = tf.summary.FileWriter(SAVE_FOLDER, graph=tf.get_default_graph())

# Saver used to write checkpoints during training
saver = tf.train.Saver()

### Train

In [7]:
# Number of mini-batches that make up one pass over the training set
batches_per_epoch = mnist.train.num_examples // batch_size

# For each epoch
for epoch in range(num_epoch):
    for i in range(batches_per_epoch):
        # Get a batch of batch_size samples, so that batches_per_epoch iterations
        # really cover the whole training set. Only batch[0] is fed to the model.
        # NOTE(review): assumes model.input accepts a variable batch dimension -- confirm.
        batch = mnist.train.next_batch(batch_size)

        # Fires when i is a multiple of 5000, i.e. at least at the first
        # iteration of every epoch
        if i % 5000 == 0:
            # Save embedding (for PCA, TSNE)
            sess.run(model.assignment, feed_dict={model_in: batch[0]})

        # Train actually here (Also get loss value)
        _, val_loss = sess.run((train_step, loss), feed_dict={model_in: batch[0]})

        # Write logs at every iteration. The step is the running iteration count
        # across epochs, so every summary gets a unique, increasing step.
        summary = sess.run(merged_summary, feed_dict={model_in: batch[0]})
        summary_writer.add_summary(summary, epoch * batches_per_epoch + i)

    # Report the loss of the last mini-batch of the epoch
    print('Epoch: %d/%d loss:%d' % (epoch, num_epoch, val_loss))

    # Save checkpoint after each epoch
    if not os.path.exists(SAVE_FOLDER):
        os.makedirs(SAVE_FOLDER)
    checkpoint_path = os.path.join(SAVE_FOLDER, "model")
    filename = saver.save(sess, checkpoint_path, global_step=epoch)
    print("Model saved in file: %s" % filename)

Epoch: 0/200 loss:149
Model saved in file: /tmp/vae_cnn/model-0
Epoch: 1/200 loss:119
Model saved in file: /tmp/vae_cnn/model-1
Epoch: 2/200 loss:115
Model saved in file: /tmp/vae_cnn/model-2
Epoch: 3/200 loss:110
Model saved in file: /tmp/vae_cnn/model-3
Epoch: 4/200 loss:102
Model saved in file: /tmp/vae_cnn/model-4
Epoch: 5/200 loss:111
Model saved in file: /tmp/vae_cnn/model-5
Epoch: 6/200 loss:107
Model saved in file: /tmp/vae_cnn/model-6
Epoch: 7/200 loss:113
Model saved in file: /tmp/vae_cnn/model-7
Epoch: 8/200 loss:105
Model saved in file: /tmp/vae_cnn/model-8
Epoch: 9/200 loss:105
Model saved in file: /tmp/vae_cnn/model-9
Epoch: 10/200 loss:109
Model saved in file: /tmp/vae_cnn/model-10
Epoch: 11/200 loss:103
Model saved in file: /tmp/vae_cnn/model-11
Epoch: 12/200 loss:106
Model saved in file: /tmp/vae_cnn/model-12
Epoch: 13/200 loss:102
Model saved in file: /tmp/vae_cnn/model-13
Epoch: 14/200 loss:100
Model saved in file: /tmp/vae_cnn/model-14
Epoch: 15/200 loss:103
Model s