In [None]:
%matplotlib inline

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import matplotlib.mlab as mlab
import math

In [None]:
# Scratch plot: a 1-D mixture of four weighted Gaussians (means 1, 4, 5, 7 with
# weights 0.1, 0.4, 0.3, 0.2), redrawn for a shrinking shared variance so the
# components can be seen separating into distinct peaks.
x = np.linspace(-1, 9, 100)
for variance in [1, 0.5, 0.1, 0.01]:
    sigma = math.sqrt(variance)
    # A fresh figure per variance, closed after display, so curves from
    # different variances never end up overlaid on one figure.
    fig = plt.figure()
    for mu, prob in [(1, 0.1), (4, 0.4), (5, 0.3), (7, 0.2)]:
        # Normal pdf written out explicitly: matplotlib.mlab.normpdf was
        # removed in matplotlib 3.1.
        pdf = np.exp(-0.5 * ((x - mu) / sigma) ** 2) / (sigma * math.sqrt(2 * math.pi))
        plt.plot(x, prob * pdf)
    plt.show()
    plt.close(fig)

# Variational Deep Embedding (VaDE)

Paper: [Variational Deep Embedding: An Unsupervised and Generative Approach to Clustering](https://arxiv.org/abs/1611.05148)

## Importing Libraries and Config File

In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [3]:
import os
import numpy as np

from tqdm import tqdm
from time import sleep

from includes.config import Config
from includes.utils import get_data, Dataset

import tensorflow as tf

from sklearn import mixture
from sklearn.manifold import TSNE

ImportError: No module named tensorflow

In [None]:
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.gridspec as grid

In [None]:
# Only surface TensorFlow log messages at ERROR level or above (hides INFO and WARNING output).
tf.logging.set_verbosity(tf.logging.ERROR)

In [None]:
# Short alias for the TF 1.x contrib distributions module.
ds = tf.contrib.distributions
# One shared Xavier/Glorot initializer, passed as kernel_initializer to the
# dense layers built in the encoder and decoder below.
xav_init = tf.contrib.layers.xavier_initializer()

## Loading Data and Setting Parameters

In [None]:
# Dataset to run. The cells below branch on "spiral" (points plotted in 2-D) and
# "mnist" (28x28 images); "mmnist" is only handled by the visualization cell.
dataset = "spiral"

In [None]:
# Per-dataset hyper-parameters. Fields read by later cells include batch_size,
# input_dim, latent_dim, n_clusters, encoder/decoder_hidden_size, regularizer,
# the learning-rate schedule (learning_rate, decay_steps, decay_rate, epsilon)
# and the epoch/iteration counts for pretraining and training.
conf = Config(data=dataset)

In [None]:
raw_train, raw_test = get_data(dataset)

# Wrap each split so the loops below can iterate it in mini-batches of conf.batch_size.
train_data = Dataset(raw_train, batch_size=conf.batch_size)
test_data = Dataset(raw_test, batch_size=conf.batch_size)

## Visualizing the Data

In [None]:
if dataset == "spiral":
    # Point cloud: scatter the first two coordinates of every training row.
    plt.scatter(train_data.data[:, 0], train_data.data[:, 1], s=0.5)
elif dataset in ("mnist", "mmnist"):
    first_hundred = train_data.data[:100]
    if dataset == "mmnist":
        # Drop the last column so each row is a 784-pixel image
        # (presumably the extra column is a label -- TODO confirm).
        first_hundred = first_hundred[:, :-1]

    # Lay the 100 digits out as a 10x10 mosaic: image 10*i + j fills the
    # 28x28 block at row i, column j ...
    mosaic = (
        first_hundred.reshape((10, 10, 28, 28))
        .transpose(0, 2, 1, 3)
        .reshape((280, 280))
        .astype(float)
    )
    # ... then add a 10-pixel black border on every side (300x300 total).
    images = np.pad(mosaic, 10, mode="constant")

    plt.imshow(images, cmap='gray')
    plt.xticks([])
    plt.yticks([])
plt.title("Test")
plt.savefig("test.png")
plt.show()

## Defining the Encoder

In [None]:
def encoder():
    """Build the inference network q(z | x) on top of the global placeholder ``X``.

    One ReLU dense layer per entry of ``conf.encoder_hidden_size`` (named
    ``encoder_hidden_layer_<i>``) feeds two linear heads of width
    ``conf.latent_dim``.

    Returns:
        (encoder_mean, encoder_log_var): posterior mean and log-variance tensors.
    """
    widths = conf.encoder_hidden_size

    hidden = tf.layers.dense(
        X,
        widths[0],
        activation=tf.nn.relu,
        kernel_initializer=xav_init,
        name="encoder_hidden_layer_0"
    )
    # Each further layer consumes the previous one's output.
    for depth in range(1, len(widths)):
        hidden = tf.layers.dense(
            hidden,
            widths[depth],
            activation=tf.nn.relu,
            kernel_initializer=xav_init,
            name="encoder_hidden_layer_" + str(depth)
        )

    mean_head = tf.layers.dense(
        hidden,
        conf.latent_dim,
        kernel_initializer=xav_init,
        name="encoder_mean"
    )
    log_var_head = tf.layers.dense(
        hidden,
        conf.latent_dim,
        kernel_initializer=xav_init,
        name="encoder_log_variance"
    )
    return mean_head, log_var_head

## Defining the Decoder

In [None]:
def decoder():
    """Build the generative network p(x | z) on top of the global latent tensor ``Z``.

    One ReLU dense layer per entry of ``conf.decoder_hidden_size`` (named
    ``decoder_hidden_layer_<i>``) feeds a linear output layer of width
    ``conf.input_dim``.

    Returns:
        (raw_output, sigmoid(raw_output)). The raw output is what the losses
        consume: as logits for the mnist cross-entropy, and directly for the
        spiral squared error.
    """
    widths = conf.decoder_hidden_size

    hidden = tf.layers.dense(
        Z,
        widths[0],
        activation=tf.nn.relu,
        kernel_initializer=xav_init,
        name="decoder_hidden_layer_0"
    )
    for depth in range(1, len(widths)):
        hidden = tf.layers.dense(
            hidden,
            widths[depth],
            activation=tf.nn.relu,
            kernel_initializer=xav_init,
            name="decoder_hidden_layer_" + str(depth)
        )

    logits = tf.layers.dense(
        hidden,
        conf.input_dim,
        kernel_initializer=xav_init,
        name="decoder_X"
    )
    squashed = tf.nn.sigmoid(logits)
    return logits, squashed

## Sampling Z Using the Reparameterization Trick

In [None]:
def sample_Z():
    """Draw a latent sample with the reparameterization trick.

    Returns ``encoder_mean + exp(encoder_log_var / 2) * epsilon``, where
    ``epsilon`` is the noise placeholder (fed with standard-normal draws).
    Gradients therefore flow through the mean and log-variance.
    """
    std = tf.exp(encoder_log_var / 2)
    return encoder_mean + std * epsilon

## Initializing and Learning the GMM Priors (Pretraining)

In [None]:
def init_prior():
    """Create the trainable GMM prior parameters.

    Returns:
        (prior_means, prior_vars, prior_weights): float32 variables.
        - prior_means has shape (n_clusters, latent_dim), drawn from N(0, 5^2).
        - prior_vars has shape (n_clusters, latent_dim), initialised to ones.
        - prior_weights has shape (n_clusters,), initialised uniform.
    """
    component_shape = (conf.n_clusters, conf.latent_dim)

    means = tf.Variable(
        tf.random_normal(component_shape, stddev=5.0),
        dtype=tf.float32,
        name="prior_means"
    )
    variances = tf.Variable(
        tf.ones(component_shape),
        dtype=tf.float32,
        name="prior_vars"
    )
    weights = tf.Variable(
        tf.ones((conf.n_clusters)) / conf.n_clusters,
        dtype=tf.float32,
        name="prior_weights"
    )
    return means, variances, weights

In [None]:
def init_gmm_priors(Z=None, train=True):
    """Return the module-level ``init_gmm_model``, fitting it on ``Z`` first if ``train``.

    Args:
        Z: latent codes to fit the mixture on; required when ``train`` is truthy.
        train: when truthy, call ``init_gmm_model.fit(Z)`` before returning.
            Otherwise the already-fitted model is returned unchanged.

    Returns:
        The ``init_gmm_model`` object.

    Raises:
        ValueError: if ``train`` is truthy but ``Z`` is None. This is reported
            here instead of surfacing as an opaque error from inside ``fit``.
    """
    if train:
        if Z is None:
            raise ValueError("init_gmm_priors: Z is required when train is True")
        init_gmm_model.fit(Z)
    return init_gmm_model

## Defining the Posterior of Cluster Assignments

In [None]:
def q_c():
    """Cluster responsibilities gamma = q(c | x) for the current batch.

    gamma_k is proportional to pi_k * N(Z | mu_k, diag(var_k)) + 1e-10, then
    normalised over k. Reads ``Z`` and the prior variables ``prior_means``,
    ``prior_vars`` and ``prior_weights`` from module scope.

    ``prior_vars`` holds *variances*: vae_loss divides by it, sample_plot takes
    its square root, and it is initialised from a diagonal GMM's covariances_.
    The previous version passed it as ``scale_diag``, i.e. as a standard
    deviation. The density is now written in closed form with the variance.

    Returns:
        Tensor of shape (batch, n_clusters) whose rows sum to 1. Works for any
        batch size (previously fixed to conf.batch_size).
    """
    # (batch, 1, latent_dim) broadcasts against (n_clusters, latent_dim).
    z = tf.expand_dims(Z, 1)
    log_pdf = -0.5 * tf.reduce_sum(
        float(np.log(2.0 * np.pi))
        + tf.log(prior_vars)
        + tf.square(z - prior_means) / prior_vars,
        axis=2
    )  # (batch, n_clusters)

    # The small floor keeps every responsibility strictly positive, because
    # vae_loss takes log(cluster_weights).
    probs = prior_weights * tf.exp(log_pdf) + 1e-10
    return probs / tf.reshape(tf.reduce_sum(probs, 1), (-1, 1))

## Defining the Loss Function

In [None]:
def vae_loss(dataset="mnist"):
    """Build the VaDE training objective (negative ELBO), averaged over the batch.

    Reads from module scope:
    - ``X`` and ``decoded_exp_X_mean`` (the raw decoder output);
    - ``encoder_mean`` and ``encoder_log_var``;
    - ``cluster_weights`` (gamma = q(c | x));
    - the prior variables ``prior_means``, ``prior_vars`` (variances) and
      ``prior_weights``.

    The per-example loss is the sum of these terms:
    - regularizer * reconstruction;
    - minus sum_k gamma_k log pi_k;
    - plus sum_k gamma_k log gamma_k;
    - minus 0.5 * sum_j (1 + log s_j^2);
    - plus 0.5 * sum_k gamma_k * sum_j (log var_kj + (s_j^2 + (m_j - mu_kj)^2) / var_kj).

    Here (m, s^2) is the encoder posterior and (mu_k, var_k) is GMM component k.

    Args:
        dataset: "mnist" uses sigmoid cross-entropy with the decoder output as
            logits; "spiral" uses squared error.

    Returns:
        Scalar tensor: batch mean of the per-example loss.

    Raises:
        ValueError: for any other dataset. Previously the reconstruction term
            was silently dropped in that case.
    """
    if dataset == "mnist":
        reconstruction = tf.reduce_sum(
            tf.nn.sigmoid_cross_entropy_with_logits(
                labels=X, logits=decoded_exp_X_mean
            ),
            axis=1
        )
    elif dataset == "spiral":
        reconstruction = tf.reduce_sum(
            tf.square(decoded_exp_X_mean - X),
            axis=1
        )
    else:
        raise ValueError(
            "vae_loss: unsupported dataset %r (expected 'mnist' or 'spiral')" % (dataset,)
        )

    J = conf.regularizer * reconstruction
    J -= tf.reduce_sum(cluster_weights * tf.log(prior_weights), axis=1)
    J += tf.reduce_sum(cluster_weights * tf.log(cluster_weights), axis=1)
    J -= 0.5 * tf.reduce_sum(1 + encoder_log_var, axis=1)

    # Expected negative log p(z | c), vectorised over clusters instead of a
    # tf.scan over a tf.Variable index. Shapes (batch, 1, latent_dim) broadcast
    # against (n_clusters, latent_dim), so no batch size is baked into the graph.
    post_mean = tf.expand_dims(encoder_mean, 1)
    post_var = tf.expand_dims(tf.exp(encoder_log_var), 1)
    per_cluster = 0.5 * tf.reduce_sum(
        tf.log(prior_vars)
        + (post_var + tf.square(post_mean - prior_means)) / prior_vars,
        axis=2
    )  # (batch, n_clusters)
    J += tf.reduce_sum(cluster_weights * per_cluster, axis=1)

    return tf.reduce_mean(J)

## Defining the TensorFlow Graph

In [None]:
# Inputs: a batch of data rows, plus the matching noise for the
# reparameterization trick (the training cells feed np.random.randn draws).
X = tf.placeholder(tf.float32, [None, conf.input_dim])
epsilon = tf.placeholder(tf.float32, [None, conf.latent_dim])

# GMM prior parameters. They are overwritten by the sklearn GMM fit in the
# "Pretraining for GMM parameters" cell.
prior_means, prior_vars, prior_weights = init_prior()

# The builder functions read these names as globals, so the call order below matters.
encoder_mean, encoder_log_var = encoder()

# Latent sample. Z is an ordinary tensor, so sample_plot can feed chosen latent
# values into it directly to drive the decoder.
Z = sample_Z()

# Raw decoder output and its sigmoid.
decoded_exp_X_mean, decoded_X_mean = decoder()

# Per-example cluster responsibilities q(c | x), computed from Z.
cluster_weights = q_c()

In [None]:
loss = vae_loss(dataset)

# Step counter for the learning-rate schedule. It must be a variable that the
# optimizer increments. With the previous constant `global_step=0` the decay
# never advanced, and the rate stayed at conf.learning_rate for the whole run.
vade_global_step = tf.Variable(0, trainable=False, name="vade_global_step")

# Decays by conf.decay_rate every conf.decay_steps epochs (in batch units),
# assuming train_data.epoch_len is the number of batches per epoch -- TODO confirm.
learning_rate = tf.train.exponential_decay(
    learning_rate=conf.learning_rate,
    global_step=vade_global_step,
    decay_steps=train_data.epoch_len * conf.decay_steps,
    decay_rate=conf.decay_rate
)

train_step = tf.train.AdamOptimizer(
    learning_rate=learning_rate,
    epsilon=conf.epsilon
).minimize(loss, global_step=vade_global_step)

## Defining the Pretraining Loss

In [None]:
# Pretraining objective: a plain VAE (standard-normal prior). The autoencoder
# is fitted with it before the GMM prior is initialised.
if dataset == "mnist":
    pretrain_reconstruction = tf.reduce_sum(
        tf.nn.sigmoid_cross_entropy_with_logits(
            labels=X,
            logits=decoded_exp_X_mean
        ), axis=1
    )
elif dataset == "spiral":
    pretrain_reconstruction = tf.reduce_sum(
        tf.square(decoded_exp_X_mean - X), axis=1
    )
else:
    # Previously an unsupported dataset left `pretrain_loss` undefined and
    # failed later with a NameError.
    raise ValueError(
        "pretrain_loss: unsupported dataset %r (expected 'mnist' or 'spiral')" % (dataset,)
    )

# KL(q(z | x) || N(0, I)), summed over latent dimensions.
pretrain_kl = 0.5 * tf.reduce_sum(
    tf.exp(encoder_log_var) + encoder_mean ** 2 - 1. - encoder_log_var,
    axis=1
)
pretrain_loss = tf.reduce_mean(pretrain_reconstruction + pretrain_kl)

# Pretraining uses a fixed learning rate. The decaying schedule defined for the
# VaDE optimizer is not used here, so the unused duplicate schedule was dropped.
PRETRAIN_LEARNING_RATE = 0.0001
pretrain_step = tf.train.AdamOptimizer(
    learning_rate=PRETRAIN_LEARNING_RATE
).minimize(pretrain_loss)

## Defining functions for Plotting

In [None]:
def regeneration_plot(epoch, data="mnist"):
    """Plot test inputs (left) next to their reconstructions (right), then save and show.

    Args:
        epoch: used only in the output filename.
        data: selects both the plot and the output directory.
            - "mnist": 10x10 mosaics of the first 100 test images and of their
              sigmoid decoder outputs.
            - "spiral": scatter of all test points and of their raw decoder outputs.

    The figure is written to plots/regenerated/<data>/<epoch>.png.
    """
    # Directory keyed on the `data` argument (previously the global `dataset`),
    # so the path always matches the branch that was drawn.
    out_dir = os.path.join("plots", "regenerated", data)
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    fig = plt.figure()
    gs = grid.GridSpec(1, 2)
    ax1 = fig.add_subplot(gs[0])
    ax2 = fig.add_subplot(gs[1])

    def _tile(flat_images):
        # (100, 784) -> (280, 280): image 10*i + j fills the 28x28 block at
        # row i, column j; pixel values are scaled by 255.
        return (
            np.asarray(flat_images)
            .reshape((10, 10, 28, 28))
            .transpose(0, 2, 1, 3)
            .reshape((280, 280))
            * 255
        )

    if data == "mnist":
        originals = test_data.data[:100]
        reconstructed = sess.run(
            decoded_X_mean,
            feed_dict={
                X: originals,
                epsilon: np.random.randn(len(originals), conf.latent_dim)
            }
        )
        ax1.imshow(_tile(originals), cmap="Greys_r")
        ax2.imshow(_tile(reconstructed), cmap="Greys_r")

    elif data == "spiral":
        decoded_X = sess.run(
            decoded_exp_X_mean,
            feed_dict={
                X: test_data.data,
                epsilon: np.random.randn(len(test_data.data), conf.latent_dim)
            }
        )
        ax1.scatter(test_data.data[:, 0], test_data.data[:, 1], s=0.5)
        ax2.scatter(decoded_X[:, 0], decoded_X[:, 1], s=0.5)

    ax1.spines['left'].set_visible(False)
    ax1.spines['bottom'].set_visible(False)
    ax2.spines['bottom'].set_visible(False)
    for ax in (ax1, ax2):
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)

    fig.tight_layout()
    # Save BEFORE showing: with the inline backend, plt.show() closes the
    # figure, so a savefig afterwards wrote a blank image.
    fig.savefig(os.path.join(out_dir, str(epoch) + ".png"), transparent=True)
    plt.show()
    # Release the figure; this is called repeatedly from the training loops.
    plt.close(fig)

In [None]:
def sample_plot(epoch, dataset="mnist"):
    """Sample from every GMM prior component, decode, then save and show the result.

    For each cluster k, 1000 latent codes mu_k + sqrt(var_k) * N(0, I) are drawn
    from the current ``prior_means`` / ``prior_vars`` (variances). They are fed
    straight into the graph's ``Z`` tensor to obtain decoder outputs.

    Panels:
        - left: decoded samples. For mnist, one row per cluster showing its first
          10 decoded images. For spiral, decoded points coloured by cluster.
        - right: the latent samples, projected to 2-D with t-SNE when
          ``conf.latent_dim > 2``.

    Args:
        epoch: used only in the output filename.
        dataset: "mnist" or "spiral"; also names the output directory.

    The figure is written to plots/sampled/<dataset>/<epoch>.png.
    """
    out_dir = os.path.join("plots", "sampled", dataset)
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    fig = plt.figure()
    gs = grid.GridSpec(1, 2)
    ax1 = fig.add_subplot(gs[0])
    ax2 = fig.add_subplot(gs[1])

    mus, variances = sess.run([prior_means, prior_vars])
    sigmas = np.sqrt(variances)
    n_per_cluster = 1000

    if dataset == "mnist":
        # Named z_samples so it does not shadow the module-level sample_Z().
        z_samples = []
        decoded_X = []
        for k in range(conf.n_clusters):
            s_Z = mus[k] + sigmas[k] * np.random.randn(n_per_cluster, conf.latent_dim)
            z_samples.append(s_Z)
            decoded_X.append(sess.run(decoded_X_mean, feed_dict={Z: s_Z}))

        z_samples = np.concatenate(z_samples, axis=0)
        if conf.latent_dim > 2:
            z_samples = TSNE(n_components=2).fit_transform(z_samples)
        # Integer division: "/" yields a float under Python 3, which reshape rejects.
        z_samples = z_samples.reshape(
            (conf.n_clusters, z_samples.shape[0] // conf.n_clusters, 2)
        )

        # (n_clusters, 10, 28, 28) -> mosaic with one row per cluster. This was
        # previously hard-coded to exactly 10 clusters.
        image = 1 - np.concatenate(
            np.concatenate(
                np.array(decoded_X)[:, :10].reshape((conf.n_clusters, 10, 28, 28)),
                axis=1
            ),
            axis=1
        )
        ax1.imshow(image, cmap="Greys")

        for k in range(conf.n_clusters):
            ax2.scatter(z_samples[k][:, 0], z_samples[k][:, 1], s=0.5)

    elif dataset == "spiral":
        eps = np.random.randn(conf.n_clusters, n_per_cluster, conf.latent_dim)
        z_samples = np.concatenate(
            [eps[k] * sigmas[k] + mus[k] for k in range(conf.n_clusters)]
        )

        decoded_X = [
            sess.run(
                decoded_exp_X_mean,
                feed_dict={Z: z_samples[n_per_cluster * k:n_per_cluster * (k + 1)]}
            )
            for k in range(conf.n_clusters)
        ]

        if conf.latent_dim > 2:
            z_samples = TSNE(n_components=2).fit_transform(z_samples)

        for k in range(conf.n_clusters):
            rows = slice(n_per_cluster * k, n_per_cluster * (k + 1))
            ax1.scatter(decoded_X[k][:, 0], decoded_X[k][:, 1], s=0.5)
            ax2.scatter(z_samples[rows, 0], z_samples[rows, 1], s=0.5)

    ax1.spines['left'].set_visible(False)
    ax1.spines['bottom'].set_visible(False)
    ax2.spines['bottom'].set_visible(False)
    for ax in (ax1, ax2):
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)

    fig.tight_layout()
    # Save BEFORE showing: with the inline backend, plt.show() closes the
    # figure, so a savefig afterwards wrote a blank image.
    fig.savefig(os.path.join(out_dir, str(epoch) + ".png"), transparent=True)
    plt.show()
    plt.close(fig)

## Starting the Session

In [None]:
# A single session shared by every later cell. The plotting helpers and the
# training loops reference `sess` as a global.
sess = tf.Session()
# Initialise every variable created so far: network weights, GMM prior
# parameters, and the Adam optimizers' internal variables.
tf.global_variables_initializer().run(session=sess)

## Pretraining for VAE parameters

In [None]:
# Draw reconstructions every this many pretraining epochs.
PRETRAIN_PLOT_EVERY = 100

with tqdm(range(conf.pretrain_vae_n_epochs), postfix={"loss": float("inf")}) as bar:
    for epoch in bar:
        # Mean batch loss over the epoch, assuming train_data.epoch_len is the
        # number of batches per epoch -- TODO confirm.
        J = 0.0
        for batch in train_data.get_batches():
            batch_loss, _ = sess.run(
                [pretrain_loss, pretrain_step],
                feed_dict={
                    X: batch,
                    # One noise row per row actually in the batch, so a short
                    # final batch no longer mismatches the placeholder feed.
                    epsilon: np.random.randn(len(batch), conf.latent_dim)
                }
            )
            J += batch_loss / train_data.epoch_len

        if epoch % PRETRAIN_PLOT_EVERY == 0:
            regeneration_plot(epoch, dataset)

        bar.set_postfix({"loss": J})

## Pretraining for GMM parameters

In [None]:
# Encode the whole training set once, with a single noise draw per row, and
# fit the sklearn GMM to the resulting latent codes.
lv = sess.run(Z, feed_dict={
    X: train_data.data,
    epsilon: np.random.randn(len(train_data.data), conf.latent_dim)
})

init_gmm_model = mixture.GaussianMixture(
    n_components=conf.n_clusters,
    covariance_type="diag",
    max_iter=conf.pretrain_gmm_n_iters,
    n_init=conf.pretrain_gmm_n_inits,
    weights_init=np.ones(conf.n_clusters) / conf.n_clusters,
)

# Fit once; the helper returns the same (now fitted) init_gmm_model object.
fitted_gmm = init_gmm_priors(Z=lv)

# Copy the fitted parameters into the TensorFlow prior variables.
# With covariance_type="diag", covariances_ holds per-dimension variances.
gmm_assign_ops = [
    tf.assign(prior_means, fitted_gmm.means_),
    tf.assign(prior_vars, fitted_gmm.covariances_),
    tf.assign(prior_weights, fitted_gmm.weights_),
]
_ = sess.run(gmm_assign_ops, feed_dict={})

sample_plot(0, dataset)

## Training the VaDE Model

In [None]:
# Draw samples and reconstructions every this many VaDE training epochs.
TRAIN_PLOT_EVERY = 200

with tqdm(range(conf.n_epochs), postfix={"loss": float("inf")}) as bar:
    for epoch in bar:
        # Mean batch loss over the epoch, assuming train_data.epoch_len is the
        # number of batches per epoch -- TODO confirm.
        J = 0.0
        for batch in train_data.get_batches():
            batch_loss, _ = sess.run(
                [loss, train_step],
                feed_dict={
                    X: batch,
                    # Noise sized to the actual batch, not conf.batch_size.
                    epsilon: np.random.randn(len(batch), conf.latent_dim)
                }
            )
            J += batch_loss / train_data.epoch_len

        bar.set_postfix({"loss": J})

        completed = epoch + 1
        if completed % TRAIN_PLOT_EVERY == 0:
            # Both figures are labelled with the number of completed epochs.
            # regeneration_plot previously received `epoch`, one behind sample_plot.
            sample_plot(completed, dataset)
            regeneration_plot(completed, dataset)