## Train CAE on people Dataset
#### References
* https://jmetzen.github.io/2015-11-27/vae.html
* http://int8.io/variational-autoencoder-in-tensorflow/
* https://github.com/int8/VAE_tensorflow
* https://github.com/dagcilibili/variational-autoencoder/blob/master/vae.py
* https://arxiv.org/pdf/1312.6114.pdf
* https://www.tensorflow.org/api_docs/python/tf/clip_by_value
* https://arxiv.org/pdf/1609.04468.pdf
* http://int8.io/variational-autoencoder-in-tensorflow/
* http://blog.fastforwardlabs.com/2016/08/12/introducing-variational-autoencoders-in-prose-and.html
* http://blog.fastforwardlabs.com/2016/08/22/under-the-hood-of-the-variational-autoencoder-in.html
* http://kvfrans.com/variational-autoencoders-explained/
* http://torch.ch/blog/2015/11/13/gan.html
* https://www.slideshare.net/ShaiHarel/variational-autoencoder-talk

In [1]:
import os
import shutil

import tensorflow as tf

import model_util as util
import models
from handle_data import HandleData

# --- Training configuration -------------------------------------------------
# Initial learning rate; used as the starting point of the decay schedule
start_lr = 0.0000001
# Number of samples requested per training batch
batch_size = 500
# Number of passes of the outer training loop
epochs = 600
# Location of the training LMDB
input_train_lmdb = '/home/leoara01/work/PPSS_LMDB_Train'
# Output directories: TensorBoard event files and model checkpoints
logs_path = './logs'
save_dir = './save'
# Fraction of GPU memory this process is allowed to allocate
gpu_fraction = 0.4
# NOTE(review): not referenced by any cell visible in this notebook section
LATENT_SIZE = 100
# Number of global steps between learning-rate decays
DECAY_CONTROL = 10000
# Make only GPU 0 visible to this process
os.environ["CUDA_VISIBLE_DEVICES"] = str(0)

# Start every run from empty log and checkpoint directories.
# shutil.rmtree replaces the former shell call ("rm -rf <path>"): no shell is
# spawned and the path is never interpreted as a command line.
# ignore_errors=True keeps the cleanup best-effort, as it was before.
for stale_dir in (logs_path, save_dir):
    if os.path.exists(stale_dir):
        shutil.rmtree(stale_dir, ignore_errors=True)
    

### Define a Model

In [2]:
# Instantiate the autoencoder. models.CAE_AutoEncoderSegnet() is the
# alternative architecture that can be swapped in here.
vae_model = models.CAE_AutoEncoderFE()

# Pull out the model's input/output tensors and their flattened counterparts.
# model_in is the tensor that gets fed through feed_dict during training.
model_in, model_out = vae_model.input, vae_model.output
model_out_flat, model_in_flat = vae_model.output_flat, vae_model.input_flat

# Collect every trainable variable once, then report the parameter count
train_vars = tf.trainable_variables()
print('Number of parameters:', util.get_paremeter_size(train_vars))

Number of parameters: 126249


### Define the CAE Reconstruction Loss

In [3]:
# Reconstruction loss between the model input and its reconstruction
with tf.name_scope("CAE_LOSS"):
    # Difference between what went in and what came out
    reconstruction_diff = model_in - model_out

    # 2-norm (ord=2) of that difference. No axis is passed, so per the tf.norm
    # documentation one norm is taken over all elements of the tensor.
    generation_loss = tf.norm(reconstruction_diff, ord=2)

    # Final objective handed to the optimizer
    loss = tf.reduce_mean(generation_loss)

### Define a Solver

In [4]:
# Ops registered under UPDATE_OPS (e.g. batch_norm moving mean/variance updates)
# Reference: https://www.tensorflow.org/api_docs/python/tf/contrib/layers/batch_norm
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.name_scope("Solver"):
    # Step counter; incremented by the optimizer, excluded from training
    global_step = tf.Variable(initial_value=0, trainable=False)
    starter_learning_rate = start_lr

    # Staircase schedule: the learning rate is multiplied by 0.1 every
    # DECAY_CONTROL global steps.
    learning_rate = tf.train.exponential_decay(
        learning_rate=starter_learning_rate,
        global_step=global_step,
        decay_steps=DECAY_CONTROL,
        decay_rate=0.1,
        staircase=True,
    )

    # Force the collected update ops to run before each optimisation step
    # http://ruishu.io/2016/12/27/batchnorm/
    with tf.control_dependencies(update_ops):
        adam = tf.train.AdamOptimizer(learning_rate)
        train_step = adam.minimize(loss, global_step=global_step)

### Create the Session and Initialize Variables

In [5]:
# Limit this process to a fraction of the GPU memory instead of grabbing it all
gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=gpu_fraction)
session_config = tf.ConfigProto(gpu_options=gpu_options)
sess = tf.InteractiveSession(config=session_config)

# Give every global variable (weights, biases, step counter) its initial value
init_op = tf.global_variables_initializer()
sess.run(init_op)

# Checkpoint writer; max_to_keep=None means old checkpoints are never deleted
saver = tf.train.Saver(max_to_keep=None)

### Add some variables to watch on tensorboard

In [6]:
# Scalars tracked on TensorBoard: training loss, current learning rate, step counter
for scalar_tag, scalar_tensor in (
    ("loss_train", loss),
    ("learning_rate", learning_rate),
    ("global_step", global_step),
):
    tf.summary.scalar(scalar_tag, scalar_tensor)

# Up to 4 input images and 4 reconstructions per summary
for image_tag, image_tensor in (
    ("input_image", model_in),
    ("output_image", model_out),
):
    tf.summary.image(image_tag, image_tensor, max_outputs=4)

# Single op that evaluates every summary registered above
merged_summary_op = tf.summary.merge_all()

# Event-file writer for TensorBoard; the graph definition is stored as well
summary_writer = tf.summary.FileWriter(logs_path, graph=tf.get_default_graph())

### Load LMDB Data

In [None]:
# Open the training LMDB; 1% of it is held out for validation (val_perc=0.01)
data = HandleData(path=input_train_lmdb, path_val='', val_perc=0.01)

# Number of full batches that fit in the reported image count
total_images = data.get_num_images()
num_images_epoch = int(total_images / batch_size)
print('Num samples', total_images, 'Iterations per epoch:', num_images_epoch, 'batch size:', batch_size)

Loading training data
LMDB file
Spliting training and validation
Number training images: 3168
Number validation images: 39
Num samples 3961 Iterations per epoch: 7 batch size: 500


### Train Loop

In [None]:
# Training loop: for every epoch, run all full batches, checkpoint, reshuffle.

# Checkpoint file prefix shared by every epoch ("<save_dir>/model-<epoch>")
checkpoint_path = os.path.join(save_dir, "model")

# Number of full batches per epoch; does not depend on the loop variables,
# so it is computed once instead of on every epoch.
iters_per_epoch = int(data.get_num_images() / batch_size)

# Running count of logged iterations, used as the TensorBoard x-axis.
# The previous index (epoch * batch_size + i) multiplied by the batch size
# rather than the iterations per epoch, leaving gaps between epochs.
summary_step = 0

for epoch in range(epochs):
    for i in range(iters_per_epoch):
        # Fetch the next training batch (second return value is not used here)
        xs_train, ys_train = data.LoadTrainBatch(batch_size, should_augment=False, do_resize=True)

        # Run the optimisation step and collect the summaries in ONE session
        # call: the batch goes through the graph once, and the logged values
        # come from the same pass that produced the weight update.
        _train_result, summary = sess.run(
            [train_step, merged_summary_op],
            feed_dict={model_in: xs_train},
        )

        # Write logs at every iteration
        summary_writer.add_summary(summary, summary_step)
        summary_step += 1

    # Save a checkpoint after each epoch (directory is created on demand)
    os.makedirs(save_dir, exist_ok=True)
    filename = saver.save(sess, checkpoint_path, global_step=epoch)
    print("Model saved in file: %s" % filename)

    # Shuffle data at each epoch end
    print("Shuffle data")
    data.shuffleData()

INFO:tensorflow:./save/model-0 is not in all_model_checkpoint_paths. Manually adding it.
Model saved in file: ./save/model-0
Shuffle data
INFO:tensorflow:./save/model-1 is not in all_model_checkpoint_paths. Manually adding it.
Model saved in file: ./save/model-1
Shuffle data
INFO:tensorflow:./save/model-2 is not in all_model_checkpoint_paths. Manually adding it.
Model saved in file: ./save/model-2
Shuffle data
INFO:tensorflow:./save/model-3 is not in all_model_checkpoint_paths. Manually adding it.
Model saved in file: ./save/model-3
Shuffle data
INFO:tensorflow:./save/model-4 is not in all_model_checkpoint_paths. Manually adding it.
Model saved in file: ./save/model-4
Shuffle data
INFO:tensorflow:./save/model-5 is not in all_model_checkpoint_paths. Manually adding it.
Model saved in file: ./save/model-5
Shuffle data
INFO:tensorflow:./save/model-6 is not in all_model_checkpoint_paths. Manually adding it.
Model saved in file: ./save/model-6
Shuffle data
INFO:tensorflow:./save/model-7 is 