# Variational AutoEncoder (CNN) in TensorFlow
Train a variational autoencoder on the MNIST dataset.

#### References:
* http://kvfrans.com/variational-autoencoders-explained/
* https://github.com/kvfrans/variational-autoencoder
* https://github.com/int8/VAE_tensorflow
* http://int8.io/variational-autoencoder-in-tensorflow/
* http://blog.fastforwardlabs.com/2016/08/22/under-the-hood-of-the-variational-autoencoder-in.html
* http://blog.fastforwardlabs.com/2016/08/12/introducing-variational-autoencoders-in-prose-and.html
* https://blog.keras.io/building-autoencoders-in-keras.html
* https://jaan.io/what-is-variational-autoencoder-vae-tutorial/
* https://arxiv.org/pdf/1606.05908.pdf
* https://arxiv.org/pdf/1312.6114.pdf
* http://wiseodd.github.io/techblog/2016/12/10/variational-autoencoder/

In [None]:
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data
# Load MNIST from the local "MNIST_data/" directory with one-hot labels.
mnist = input_data.read_data_sets(
    "MNIST_data/",
    one_hot=True,
)

In [None]:
from models import VAE_CNN
# Instantiate the VAE_CNN model and pull out the tensors that the later
# cells reference by these names.
model = VAE_CNN()
model_in, model_out, model_out_flat = model.input, model.output, model.output_flat
z_mean, z_stddev = model.z_mean, model.z_stddev

### Define loss

In [None]:
with tf.name_scope("VAE_LOSS"):
    # Small constant that keeps every tf.log argument strictly positive.
    eps = 1e-8

    # Reconstruction term: binary cross-entropy between the input and the
    # flattened reconstruction, summed over axis 1 (one value per example).
    generation_loss = -tf.reduce_sum(
        model_in * tf.log(eps + model_out_flat)
        + (1 - model_in) * tf.log(eps + 1 - model_out_flat),
        1)

    # Regularisation term: closed-form KL divergence between the encoder's
    # diagonal Gaussian N(z_mean, z_stddev^2) and the unit Gaussian prior.
    # The log is guarded with eps, like the reconstruction term, so a
    # collapsed (zero) stddev cannot yield -inf / NaN.
    z_var = tf.square(z_stddev)
    latent_loss = 0.5 * tf.reduce_sum(
        tf.square(z_mean) + z_var - tf.log(eps + z_var) - 1,
        1)

    # Scalar objective: batch mean of the per-example total.
    loss = tf.reduce_mean(generation_loss + latent_loss)

    # Register the loss terms so tf.summary.merge_all() picks them up and
    # they appear in TensorBoard.
    tf.summary.scalar("generation_loss", tf.reduce_mean(generation_loss))
    tf.summary.scalar("latent_loss", tf.reduce_mean(latent_loss))
    tf.summary.scalar("loss", loss)

# Solver configuration: Adam with a learning rate of 1e-3.
with tf.name_scope("Solver"):
    train_step = tf.train.AdamOptimizer(learning_rate=0.001).minimize(loss)

### Create the session and TensorBoard writer

In [None]:
# Op that initializes all global variables defined in the graph so far.
init = tf.global_variables_initializer()

# Cap this process at about a third of the GPU memory instead of letting
# TensorFlow reserve all of it.
gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.333)
sess = tf.Session(
    config=tf.ConfigProto(gpu_options=gpu_options),
)
sess.run(init)

# Merge every registered summary into a single op and open the event file
# writer; handing it the graph records the graph for TensorBoard.
merged_summary = tf.summary.merge_all()
writer = tf.summary.FileWriter("/tmp/vae_cnn/1", graph=sess.graph)

### Train

In [None]:
# Training hyper-parameters.
NUM_STEPS = 2000       # number of optimisation steps
BATCH_SIZE = 50        # images per step
SUMMARY_EVERY = 5      # write a TensorBoard summary every N steps

for i in range(NUM_STEPS):
    # next_batch returns an (images, labels) pair; only the images are fed.
    batch = mnist.train.next_batch(BATCH_SIZE)
    feed = {model_in: batch[0]}

    # Dump summary. tf.summary.merge_all() returns None when no summaries
    # were registered, and sess.run(None) would raise, so skip in that case.
    if merged_summary is not None and i % SUMMARY_EVERY == 0:
        s = sess.run(merged_summary, feed_dict=feed)
        writer.add_summary(s, i)

    # One optimisation step on the current batch.
    sess.run(train_step, feed_dict=feed)

# Push any buffered events to disk so TensorBoard sees the full run.
writer.flush()