In [None]:
"""
Class that implements a VAE
"""
from ipywidgets import *
import tensorflow as tf
import vae_tests.tfutils as ut
import vae_tests.config as config
import numpy as np
import vae_tests.augmentation as aug

from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_nn_ops

@ops.RegisterGradient("MaxPoolWithArgmax")
def _MaxPoolWithArgmaxGrad(op, grad, some_other_arg):
  """Gradient function registered for the ``MaxPoolWithArgmax`` op.

  Delegates to the plain max-pool gradient kernel, feeding it the op's
  first input and first output together with the incoming gradient.

  Args:
    op: the forward ``MaxPoolWithArgmax`` operation.
    grad: gradient flowing into the op's first (pooled) output.
    some_other_arg: unused; presumably the gradient for the op's second
      (argmax) output -- TODO confirm against the TensorFlow version in use.

  Returns:
    Whatever ``gen_nn_ops._max_pool_grad`` returns for the op's input.
  """
  forward_input = op.inputs[0]
  forward_output = op.outputs[0]
  # Pooling geometry is read back from the forward op so both passes agree.
  window = op.get_attr("ksize")
  step = op.get_attr("strides")
  pad_mode = op.get_attr("padding")
  return gen_nn_ops._max_pool_grad(
      forward_input, forward_output, grad, window, step,
      padding=pad_mode, data_format='NHWC')

class VAE():
    """Convolutional variational autoencoder graph for 128x128x1 inputs.

    Usage order matters: call ``buildEncoder()`` first (it creates the input
    placeholders and the latent tensors) and ``buildDecoder()`` afterwards.
    Every layer is stored as an attribute so it can be inspected or fed to a
    loss / optimizer defined elsewhere.
    """

    def __init__(self):
        """Set the training hyper-parameters and open an interactive session."""
        # Training Parameters
        self.learning_rate = 0.001
        self.training_epochs = 1000
        self.batch_size = 10
        self.display_step = 1
        self.latent_variables = 2

        # Variable scope shared by the encoder and the decoder.
        self.scope = "vaeengines"
        self.sess = tf.InteractiveSession()

    @staticmethod
    def _print_shape(label, tensor):
        """Print ``label`` followed by the static shape of ``tensor``."""
        print(label)
        print(tensor.get_shape())

    def buildDecoder(self):
        """Build the decoder p(x|z) on top of the encoder's latent tensors.

        Draws a latent sample from ``self.z_mean`` / ``self.z_log_sigma``,
        expands it through two fully connected layers, reshapes the result to
        ``[batch, 8, 8, 256]`` and applies four stride-2 deconvolutions whose
        requested output shapes end at ``[batch, 128, 128, 1]``.  All layers
        are stored on ``self``; the final one is ``self.x_reconstructed``.

        Raises:
            RuntimeError: if ``buildEncoder()`` has not been called yet.
        """
        required = ("z_mean", "z_log_sigma", "phase")
        if not all(hasattr(self, attr) for attr in required):
            raise RuntimeError(
                "buildEncoder() must be called before buildDecoder()")

        # -----------------------
        # decoder p(x|z)
        # -----------------------
        with tf.variable_scope(self.scope):
            # First sample from a unit Gaussian and scale the result
            # (reparametrization trick).
            self.gaussian_sampled = self.sampleGaussian(
                mu=self.z_mean, log_sigma=self.z_log_sigma)
            self._print_shape('gaussian_sampled', self.gaussian_sampled)

            # The batch size is only known at run time, so read it from the
            # dynamic shape of the sample.
            in_batch = tf.shape(self.gaussian_sampled)[0]

            self.deconv_dense = ut.full_layer_with_bn(
                self.gaussian_sampled, [self.latent_variables, 256],
                self.phase, name="deconv4")
            self._print_shape('deconv_dense', self.deconv_dense)

            self.full_deconv = ut.full_layer_with_bn(
                self.deconv_dense, [256, 8 * 8 * 256],
                self.phase, name="full_deconv")

            # NOTE: tf.pack is the pre-1.0 name of tf.stack; it is kept so the
            # TensorFlow version this file targets is not changed.
            self.deconv4 = tf.reshape(
                self.full_deconv, tf.pack([in_batch, 8, 8, 256]))
            self._print_shape('deconv4', self.deconv4)

            ksize = 4
            self.deconv3 = ut.deconv_layer_koo(
                self.deconv4, [ksize, ksize, 256, 256],
                tf.pack([in_batch, 16, 16, 256]),
                stride=2, name="deconv3")
            self._print_shape('deconv3', self.deconv3)

            self.deconv2 = ut.deconv_layer_koo(
                self.deconv3, [ksize, ksize, 128, 256],
                tf.pack([in_batch, 32, 32, 128]),
                stride=2, name="deconv2")
            self._print_shape('deconv2', self.deconv2)

            self.deconv1 = ut.deconv_layer_koo(
                self.deconv2, [ksize, ksize, 64, 128],
                tf.pack([in_batch, 64, 64, 64]),
                stride=2, name="deconv1")
            self._print_shape('deconv1', self.deconv1)

            self.x_reconstructed = ut.deconv_layer_koo(
                self.deconv1, [ksize, ksize, 1, 64],
                tf.pack([in_batch, 128, 128, 1]),
                stride=2, name="x_reconstructed")

    def buildEncoder(self):
        """Build the encoder q(z|x), the latent tensors and the KL loss.

        Creates the placeholders ``self.x_in_depth`` (``[None, 128, 128, 1]``
        float32) and ``self.phase`` (bool, passed to every batch-norm layer),
        a stack of 3x3 conv layers interleaved with four 2x2 max-pools, a
        fully connected layer, and finally ``self.z_mean``,
        ``self.z_log_sigma`` and ``self.kl_loss``.
        """
        with tf.variable_scope(self.scope):
            self.x_in_depth = tf.placeholder(tf.float32, [None, 128, 128, 1])
            self.phase = tf.placeholder(tf.bool, name='phase_train')

            # -----------------------
            # encoder q(z|x)
            # -----------------------
            self.conv1 = ut.conv_layer_with_bn(
                self.x_in_depth, [3, 3, 1, 32], self.phase, name="conv1")
            # The batch dimension is dynamic; see
            # https://groups.google.com/a/tensorflow.org/forum/#!topic/discuss/vf8eH9YMwVA
            dyn_input_shape = tf.shape(self.conv1)
            in_batch = dyn_input_shape[0]
            self.pool1, self.pool1_indices = tf.nn.max_pool_with_argmax(
                self.conv1, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1],
                padding='SAME', name='pool1')
            self.conv2 = ut.conv_layer_with_bn(
                self.pool1, [3, 3, 32, 64], self.phase, name="conv2")
            self.pool2, self.pool2_indices = tf.nn.max_pool_with_argmax(
                self.conv2, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1],
                padding='SAME', name='pool2')
            self.conv31 = ut.conv_layer_with_bn(
                self.pool2, [3, 3, 64, 128], self.phase, name="conv31")
            self.conv32 = ut.conv_layer_with_bn(
                self.conv31, [3, 3, 128, 128], self.phase, name="conv32")
            self.pool3, self.pool3_indices = tf.nn.max_pool_with_argmax(
                self.conv32, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1],
                padding='SAME', name='pool3')
            self.conv41 = ut.conv_layer_with_bn(
                self.pool3, [3, 3, 128, 256], self.phase, name="conv41")
            self.conv42 = ut.conv_layer_with_bn(
                self.conv41, [3, 3, 256, 256], self.phase, name="conv42")
            self.pool4, self.pool4_indices = tf.nn.max_pool_with_argmax(
                self.conv42, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1],
                padding='SAME', name='pool4')
            self._print_shape('pool4', self.pool4)

            # Flatten the last pooled feature map to [batch, H * W * 256].
            pool4_dims = self.pool4.get_shape().as_list()
            self.dense = tf.reshape(
                self.pool4, [in_batch, pool4_dims[1] * pool4_dims[2] * 256])
            self._print_shape('dense', self.dense)

            self.full = ut.full_layer_with_bn(
                self.dense,
                [self.dense.get_shape().as_list()[1], self.latent_variables],
                self.phase, name="full")
            self._print_shape('full', self.full)

            # -----------------------
            # the latent variables
            full_width = self.full.get_shape().as_list()[1]
            self.z_mean = ut.full_layer_with_bn(
                self.full, [full_width, self.latent_variables],
                self.phase, name="z_mean")
            self.z_log_sigma = ut.full_layer_with_bn(
                self.full, [full_width, self.latent_variables],
                self.phase, name="z_log_sigma")
            self._print_shape('z_mean', self.z_mean)

            # ---------------------------------
            # at this point define the KL loss
            # ---------------------------------
            self.kl_loss = self.kullbackLeibler(self.z_mean, self.z_log_sigma)

    # Utility functions from
    # https://github.com/fastforwardlabs/vae-tf/blob/master/vae.py
    # They take no instance state, so they are static methods and can be
    # called either as ``VAE.name(...)`` or ``self.name(...)``.

    @staticmethod
    def sampleGaussian(mu, log_sigma):
        """(Differentiably!) draw a sample from a Gaussian.

        Uses the reparameterization trick: noise ``epsilon`` with the shape of
        ``log_sigma`` is drawn from a standard normal, then scaled and shifted.

        Returns:
            ``mu + epsilon * exp(log_sigma)``.
        """
        with tf.name_scope("sample_gaussian"):
            # reparameterization trick
            epsilon = tf.random_normal(tf.shape(log_sigma), name="epsilon")
            return mu + epsilon * tf.exp(log_sigma)  # N(mu, I * sigma**2)

    @staticmethod
    def crossEntropy(obs, actual, offset=1e-7):
        """Binary cross-entropy, summed over axis 1 (per training example).

        ``obs`` is clipped to ``[offset, 1 - offset]`` before the logarithms
        so they never evaluate to NaN.
        """
        # (tf.Tensor, tf.Tensor, float) -> tf.Tensor
        with tf.name_scope("cross_entropy"):
            # bound by clipping to avoid nan
            obs_ = tf.clip_by_value(obs, offset, 1 - offset)
            return -tf.reduce_sum(actual * tf.log(obs_) +
                                  (1 - actual) * tf.log(1 - obs_), 1)

    @staticmethod
    def l1_loss(obs, actual):
        """L1 loss (a.k.a. LAD), summed over axis 1 (per training example)."""
        # (tf.Tensor, tf.Tensor) -> tf.Tensor
        with tf.name_scope("l1_loss"):
            return tf.reduce_sum(tf.abs(obs - actual), 1)

    @staticmethod
    def l2_loss(obs, actual):
        """L2 loss (a.k.a. Euclidean / LSE), summed over axis 1."""
        # (tf.Tensor, tf.Tensor) -> tf.Tensor
        with tf.name_scope("l2_loss"):
            return tf.reduce_sum(tf.square(obs - actual), 1)

    @staticmethod
    def kullbackLeibler(mu, log_sigma):
        """(Gaussian) Kullback-Leibler divergence KL(q||p), summed over axis 1."""
        # (tf.Tensor, tf.Tensor) -> tf.Tensor
        with tf.name_scope("KL_divergence"):
            # = -0.5 * (1 + log(sigma**2) - mu**2 - sigma**2)
            return -0.5 * tf.reduce_sum(
                1 + 2 * log_sigma - mu**2 - tf.exp(2 * log_sigma), 1)