# Variational Deep Embedding (VaDE)

Paper Link - https://arxiv.org/abs/1611.05148

## Importing Libraries and Config File

In [None]:
import os

import config

import numpy as np
from tqdm import tqdm

import tensorflow as tf

from sklearn import mixture

import matplotlib.pyplot as plt
import matplotlib.gridspec as grid

# Only let TensorFlow log messages of severity ERROR through.
tf.logging.set_verbosity(tf.logging.ERROR)

In [None]:
# Short alias for the TensorFlow distributions module (used for the
# per-cluster Gaussian densities in q_c).
ds = tf.contrib.distributions
# One Xavier initializer instance, passed as kernel_initializer to every
# dense layer of the encoder and decoder.
xav_init = tf.contrib.layers.xavier_initializer()

## Loading and Binarizing MNIST Data

In [None]:
# Load the MNIST dataset object, using "mnist_data" as its data directory.
mnist = tf.contrib.learn.datasets.mnist.load_mnist(train_dir="mnist_data")

# Test images are only used by regeneration_plot. NOTE(review): unlike the
# training images, they are not binarized anywhere in this notebook.
test_data = mnist.test.images

In [None]:
# Binarize the training images in place: pixels below 0.5 become 0 and all
# remaining pixels become 1.
mnist.train.images[:] = np.where(mnist.train.images < 0.5, 0.0, 1.0)

## Defining the Encoder

In [None]:
def encoder():
    """Build the encoder network on top of the module-level input ``X``.

    One ReLU dense layer is created per entry of
    ``config.encoder_hidden_size`` (named ``encoder_hidden_layer_<i>``),
    followed by two parallel linear layers of width ``config.latent_dim``
    fed by the last hidden layer.

    Returns:
        Tuple ``(encoder_mean, encoder_log_var)``: the outputs of the
        ``encoder_mean`` and ``encoder_log_variance`` layers.
    """
    hidden_sizes = config.encoder_hidden_size

    def relu_layer(inputs, units, position):
        # Hidden layers differ only in input, width and name suffix.
        return tf.layers.dense(
            inputs,
            units,
            activation=tf.nn.relu,
            kernel_initializer=xav_init,
            name="encoder_hidden_layer_" + str(position)
        )

    def latent_head(name):
        # Linear (no activation) projection of the last hidden layer.
        return tf.layers.dense(
            hidden,
            config.latent_dim,
            kernel_initializer=xav_init,
            name=name
        )

    hidden = relu_layer(X, hidden_sizes[0], 0)
    for position, units in enumerate(hidden_sizes[1:], 1):
        hidden = relu_layer(hidden, units, position)

    mean = latent_head("encoder_mean")
    log_var = latent_head("encoder_log_variance")

    return mean, log_var

## Defining the Decoder

In [None]:
def decoder():
    """Build the decoder network on top of the module-level latent ``Z``.

    One ReLU dense layer is created per entry of
    ``config.decoder_hidden_size`` (named ``decoder_hidden_layer_<i>``),
    followed by a linear output layer ``decoder_X`` of width
    ``config.input_dim``.

    Returns:
        Tuple ``(logits, probabilities)``: the raw output of ``decoder_X``
        and its element-wise sigmoid.
    """
    hidden_sizes = config.decoder_hidden_size

    def relu_layer(inputs, units, position):
        # Hidden layers differ only in input, width and name suffix.
        return tf.layers.dense(
            inputs,
            units,
            activation=tf.nn.relu,
            kernel_initializer=xav_init,
            name="decoder_hidden_layer_" + str(position)
        )

    hidden = relu_layer(Z, hidden_sizes[0], 0)
    for position, units in enumerate(hidden_sizes[1:], 1):
        hidden = relu_layer(hidden, units, position)

    logits = tf.layers.dense(
        hidden,
        config.input_dim,
        kernel_initializer=xav_init,
        name="decoder_X"
    )
    probabilities = tf.nn.sigmoid(logits)

    return logits, probabilities

## Sampling Z using the reparametrization trick

In [None]:
def sample_Z():
    """Draw the latent code with the reparameterization trick.

    Reads the module-level ``encoder_mean``, ``encoder_log_var`` and
    ``epsilon`` tensors.

    Returns:
        ``encoder_mean + exp(encoder_log_var / 2) * epsilon``.
    """
    # exp(log_var / 2) turns a log-variance into a standard deviation.
    std_dev = tf.exp(encoder_log_var / 2)
    scaled_noise = std_dev * epsilon
    return encoder_mean + scaled_noise

## Initializing and Learning the GMM Priors (Pretraining)

In [None]:
def init_prior():
    """Create the variables holding the Gaussian-mixture prior.

    Returns:
        Tuple ``(prior_means, prior_vars, prior_weights)`` of float32
        variables: means of shape ``(n_clusters, latent_dim)`` drawn from a
        normal with standard deviation 5, variances of the same shape set to
        one, and uniform mixing weights of length ``n_clusters``.
    """
    n_components = config.n_clusters
    component_shape = (n_components, config.latent_dim)

    means = tf.Variable(
        tf.random_normal(component_shape, stddev=5.0),
        dtype=tf.float32,
        name="prior_means"
    )
    variances = tf.Variable(
        tf.ones(component_shape),
        dtype=tf.float32,
        name="prior_vars"
    )
    uniform_weights = tf.ones([n_components]) / n_components
    weights = tf.Variable(
        uniform_weights,
        dtype=tf.float32,
        name="prior_weights"
    )

    return means, variances, weights

In [None]:
def init_gmm_priors(Z=None, train=True):
    """Return the module-level ``init_gmm_model``, optionally fitting it first.

    Args:
        Z: Data passed to ``init_gmm_model.fit``; required when ``train`` is
            truthy, ignored otherwise.
        train: When truthy, fit the model on ``Z`` before returning it.

    Returns:
        The shared ``init_gmm_model`` object (the same instance on every
        call).

    Raises:
        ValueError: If ``train`` is truthy and ``Z`` is ``None``.
    """
    if train:
        # Fail early with a clear message instead of handing None to fit().
        if Z is None:
            raise ValueError("Z must be provided when train is True")
        init_gmm_model.fit(Z)

    return init_gmm_model

## Defining the Posterior of Cluster Assignments

In [None]:
def q_c():
    """Build the posterior over cluster assignments for the latent ``Z``.

    For every cluster ``k`` the unnormalized responsibility is
    ``prior_weights[k] * N(Z; prior_means[k], diag(prior_vars[k])) + 1e-10``;
    the responsibilities are then normalized to sum to one per example.

    Reads the module-level ``Z``, ``prior_means``, ``prior_vars`` and
    ``prior_weights``.

    Returns:
        Tensor with one row per example in ``Z`` and one column per cluster
        (``config.n_clusters`` columns); each row sums to one.
    """
    unnormalized = []
    for k in range(config.n_clusters):
        component = ds.MultivariateNormalDiag(
            loc=prior_means[k],
            # scale_diag is a standard deviation, while prior_vars stores
            # variances (see vae_loss and sample_plot), hence the sqrt.
            scale_diag=tf.sqrt(prior_vars[k])
        )
        # The small constant keeps every responsibility strictly positive so
        # that the log taken in the loss stays finite.
        unnormalized.append(prior_weights[k] * component.prob(Z) + 1e-10)

    # Stacking along axis 1 yields (batch, n_clusters) directly, for any
    # batch size.
    probs = tf.stack(unnormalized, axis=1)
    probs = probs / tf.reshape(tf.reduce_sum(probs, 1), (-1, 1))
    return probs

## Defining the Loss Function

In [None]:
def vae_loss():
    """Build the VaDE training loss (negative ELBO averaged over the batch).

    Per example the loss is the sum of:
      * ``config.regularizer`` times the sigmoid cross-entropy between ``X``
        and the decoder logits (``decoded_exp_X_mean`` holds the logits),
      * ``-sum_k gamma_k * log(prior_weights_k)``,
      * ``+sum_k gamma_k * log(gamma_k)``,
      * ``-0.5 * sum_j (1 + encoder_log_var_j)``,
      * ``0.5 * sum_k gamma_k * sum_j (log(var_kj)
        + (exp(encoder_log_var_j) + (encoder_mean_j - mean_kj)^2) / var_kj)``,
    where ``gamma`` is ``cluster_weights``.

    Reads the module-level ``X``, ``decoded_exp_X_mean``, ``encoder_mean``,
    ``encoder_log_var``, ``cluster_weights``, ``prior_means``, ``prior_vars``
    and ``prior_weights``.

    Returns:
        Scalar tensor: the mean of the per-example loss.
    """
    # NOTE(review): prior_weights and prior_vars are plain variables; if
    # training drives them to non-positive values the logs below are invalid.
    J = 0.0
    J += config.regularizer * tf.reduce_sum(
        tf.nn.sigmoid_cross_entropy_with_logits(labels=X, logits=decoded_exp_X_mean),
        axis=1
    )
    J -= tf.reduce_sum(cluster_weights * tf.log(prior_weights), axis=1)
    J += tf.reduce_sum(cluster_weights * tf.log(cluster_weights), axis=1)
    J -= 0.5 * tf.reduce_sum(1 + encoder_log_var, axis=1)

    # Loop-invariant: the encoder variance does not depend on the cluster.
    encoder_var = tf.exp(encoder_log_var)

    # n_clusters is a Python int, so a plain loop replaces the former
    # tf.scan over a trainable index variable and works for any batch size.
    mixture_term = 0.0
    for k in range(config.n_clusters):
        mixture_term += 0.5 * cluster_weights[:, k] * tf.reduce_sum(
            tf.log(prior_vars[k])
            + (encoder_var + tf.square(encoder_mean - prior_means[k])) / prior_vars[k],
            axis=1
        )

    J += mixture_term

    return tf.reduce_mean(J)

## Defining the TensorFlow Graph

In [None]:
# Number of whole batches per pass over the training images.
epoch_len = len(mnist.train.images) // config.batch_size

In [None]:
# Graph inputs: a batch of inputs of width config.input_dim, and the noise
# used by sample_Z (fed explicitly on every sess.run).
X = tf.placeholder(tf.float32, [None, config.input_dim])
epsilon = tf.placeholder(tf.float32, [None, config.latent_dim])

# Gaussian-mixture prior variables (means, variances, mixing weights).
prior_means, prior_vars, prior_weights = init_prior()

# The builders below read the module-level tensors defined above them, so
# the order of these statements matters.
encoder_mean, encoder_log_var = encoder()

Z = sample_Z()

# decoder() returns (logits, sigmoid(logits)): despite the names,
# decoded_exp_X_mean holds the logits and decoded_X_mean the probabilities.
decoded_exp_X_mean, decoded_X_mean = decoder()

# Per-example posterior over clusters, used by vae_loss.
cluster_weights = q_c()

In [None]:
loss = vae_loss()

# Step counter incremented by the optimizer on every training step. The decay
# schedule must read this variable: with a constant global_step of 0 the
# exponent is always zero and the learning rate never decays.
global_step = tf.Variable(0, trainable=False, name="global_step")

# Learning rate multiplied by adam_decay_rate every
# (epoch_len * adam_decay_steps) optimizer steps.
learning_rate = tf.train.exponential_decay(
    learning_rate=config.adam_learning_rate,
    global_step=global_step,
    decay_steps=epoch_len * config.adam_decay_steps,
    decay_rate=config.adam_decay_rate
)
train_step = tf.train.AdamOptimizer(
    learning_rate=learning_rate,
    epsilon=config.adam_epsilon
).minimize(loss, global_step=global_step)

## Defining the Pretraining Loss

In [None]:
# Per-example reconstruction term: sigmoid cross-entropy between the input
# and the decoder logits, summed over the input dimensions.
pretrain_reconstruction = tf.reduce_sum(
    tf.nn.sigmoid_cross_entropy_with_logits(labels=X, logits=decoded_exp_X_mean),
    axis=1
)
# Per-example KL divergence between the encoder Gaussian and a standard
# normal, summed over the latent dimensions.
pretrain_kl = 0.5 * tf.reduce_sum(
    tf.exp(encoder_log_var) + encoder_mean ** 2 - 1. - encoder_log_var,
    axis=1
)
pretrain_loss = tf.reduce_mean(pretrain_reconstruction + pretrain_kl)
pretrain_step = tf.train.AdamOptimizer(learning_rate=1e-4).minimize(pretrain_loss)

## Defining functions for Plotting

In [None]:
def regeneration_plot(epoch):
    """Save a side-by-side mosaic of reconstructions and their inputs.

    Shuffles the module-level ``test_data`` in place, runs its first 100 rows
    through the model, and writes ``plots/regenerated/<epoch>.png`` with the
    reconstructions on the left and the inputs on the right, each as a 10x10
    grid of 28x28 tiles.

    Args:
        epoch: Value used (via ``str``) as the output file name.
    """
    if not os.path.exists("plots/regenerated"):
        os.makedirs("plots/regenerated")

    def as_mosaic(images):
        # (100, 784) -> 280x280 grid; image 10*i + j lands at row i, column j.
        grid_of_tiles = images.reshape((10, 10, 28, 28))
        mosaic = grid_of_tiles.transpose((0, 2, 1, 3)).reshape((280, 280))
        return (mosaic * 255).astype(np.float64)

    np.random.shuffle(test_data)
    originals = test_data[:100]

    reconstructions = sess.run(
        decoded_X_mean,
        feed_dict={
            X: originals,
            epsilon: np.random.randn(100, config.latent_dim)
        }
    )
    reconstructions = np.asarray(reconstructions).reshape((100, 784))

    plt.figure(figsize=(12, 6))
    layout = grid.GridSpec(1, 2)
    left_axis = plt.subplot(layout[0])
    right_axis = plt.subplot(layout[1])

    left_axis.imshow(as_mosaic(reconstructions), cmap="Greys_r")
    right_axis.imshow(as_mosaic(np.array(originals)), cmap="Greys_r")

    plt.tight_layout()
    plt.savefig("plots/regenerated/" + str(epoch) + ".png")
    plt.close()

In [None]:
def sample_plot(epoch):
    """Save a mosaic of images decoded from samples of every prior cluster.

    Row ``k`` of the mosaic holds 10 images decoded from latent codes drawn
    from cluster ``k`` of the Gaussian-mixture prior. There is one row per
    row of ``prior_means``, so any number of clusters is supported. The
    figure is written to ``plots/sampled/<epoch>.png``.

    Args:
        epoch: Value used (via ``str``) as the output file name.
    """
    if not os.path.exists("plots/sampled"):
        os.makedirs("plots/sampled")

    mus, variances = sess.run([prior_means, prior_vars])

    # prior_vars stores variances; sampling needs standard deviations.
    sigmas = np.sqrt(variances)

    n_rows = mus.shape[0]
    samples_per_cluster = 10
    tile = 28

    figure = np.zeros((n_rows * tile, samples_per_cluster * tile))
    # Keeps the original 6x6 inch figure when there are 10 clusters.
    plt.figure(figsize=(6, 6.0 * n_rows / samples_per_cluster))

    for k in range(n_rows):
        eps = np.random.randn(samples_per_cluster, config.latent_dim)
        latent_codes = eps * sigmas[k] + mus[k]

        # Feed Z directly to bypass the encoder, and decode the whole row in
        # one session call instead of one call per sample.
        decoded = sess.run(decoded_X_mean, feed_dict={Z: latent_codes})

        row = np.hstack([image.reshape((tile, tile)) for image in decoded])
        figure[k * tile : (k + 1) * tile, :] = row * 255

    plt.imshow(figure, cmap="Greys_r")

    plt.tight_layout()
    plt.savefig("plots/sampled/" + str(epoch) + ".png")
    plt.close()

## Starting the Session

In [None]:
# Open the session used by every training and plotting cell, then initialize
# all variables created so far.
sess = tf.Session()
sess.run(tf.global_variables_initializer())

## Pretraining for VAE parameters

In [None]:
for epoch in range(config.pretrain_vae_n_epochs):
    # J accumulates the mean per-batch pretraining loss of this epoch.
    J = 0.0
    for i in tqdm(range(epoch_len)):
        X_batch = mnist.train.next_batch(config.batch_size)[0]
        noise = np.random.randn(config.batch_size, config.latent_dim)

        # Fetch the loss and run one optimizer step in the same call.
        batch_loss, _ = sess.run(
            [pretrain_loss, pretrain_step],
            feed_dict={X: X_batch, epsilon: noise}
        )
        J += batch_loss / epoch_len

    regeneration_plot(epoch)

    print("Pretrain Epoch %d: %.3f" % (epoch + 1, J))

## Pretraining for GMM parameters

In [None]:
# Encode one large batch of training images into sampled latent codes.
gmm_train_size = config.pretrain_gmm_train_size
lv = sess.run(
    Z,
    feed_dict={
        X: mnist.train.next_batch(gmm_train_size)[0],
        epsilon: np.random.randn(gmm_train_size, config.latent_dim)
    }
)

# Diagonal-covariance mixture with one component per cluster, started from
# uniform mixing weights.
uniform_weights = np.ones(config.n_clusters) / config.n_clusters
init_gmm_model = mixture.GaussianMixture(
    n_components=config.n_clusters,
    covariance_type="diag",
    max_iter=config.pretrain_gmm_n_iters,
    n_init=config.pretrain_gmm_n_inits,
    weights_init=uniform_weights,
)

# Fit once; init_gmm_priors returns the shared model object, so all three
# parameter sets can be read from the same fitted instance.
fitted_gmm = init_gmm_priors(Z=lv)

init_gmm_means = tf.assign(prior_means, fitted_gmm.means_)
init_gmm_vars = tf.assign(prior_vars, fitted_gmm.covariances_)
init_gmm_weights = tf.assign(prior_weights, fitted_gmm.weights_)

# Copy the fitted parameters into the prior variables.
_ = sess.run([init_gmm_means, init_gmm_vars, init_gmm_weights])

## Training the VaDE Model

In [None]:
for epoch in range(config.n_epochs):
    # J accumulates the mean per-batch training loss of this epoch.
    J = 0.0
    for i in tqdm(range(epoch_len)):
        X_batch = mnist.train.next_batch(config.batch_size)[0]
        noise = np.random.randn(config.batch_size, config.latent_dim)

        # Fetch the loss and run one optimizer step in the same call.
        batch_loss, _ = sess.run(
            [loss, train_step],
            feed_dict={X: X_batch, epsilon: noise}
        )
        J += batch_loss / epoch_len

    print("Epoch %d: %.3f" % (epoch + 1, J))

    # Write the diagnostic figures for this epoch.
    sample_plot(epoch)
    regeneration_plot(epoch)