In [None]:
import os

import config

import numpy as np
from tqdm import tqdm

import tensorflow as tf

import matplotlib.pyplot as plt
import matplotlib.gridspec as grid

# Set TensorFlow's logger to INFO verbosity for the rest of the notebook.
tf.logging.set_verbosity(tf.logging.INFO)

In [None]:
# Load the MNIST dataset object, using "mnist_data" as its data directory.
mnist = tf.contrib.learn.datasets.mnist.load_mnist(train_dir="mnist_data")

# NOTE: this is a reference to the dataset's own test-image array, not a copy,
# so any in-place change to `test_data` also changes `mnist.test.images`.
test_data = mnist.test.images

In [None]:
# Exploratory: display the second element of a 100-example training batch
# (presumably the labels; element [0] is what the training loop feeds as images).
# NOTE(review): next_batch presumably advances the dataset's internal position — confirm.
mnist.train.next_batch(100)[1]

In [None]:
# Exploratory: per-example sum over axis 1 of the first element of a 10-example
# batch (the images), evaluated before the binarisation cell below has run.
np.sum(mnist.train.next_batch(10)[0], axis=1)

In [None]:
# Binarise the training images in place: values below 0.5 become 0,
# every remaining positive value becomes 1.
train_images = mnist.train.images
train_images[train_images < 0.5] = 0
train_images[train_images > 0] = 1

In [None]:
# Short alias for the contrib distributions module (used by q_c).
ds = tf.contrib.distributions
# One Xavier initializer object, passed as kernel_initializer to every dense
# layer built by encoder() and decoder().
xav_init = tf.contrib.layers.xavier_initializer()

In [None]:
def encoder(X, hidden_size):
    """Build the encoder network and return its two output heads.

    A stack of ReLU dense layers (one per entry of ``hidden_size``, named
    ``encoder_hidden_layer_<i>``) is applied to ``X``. Two linear dense
    layers of width ``config.latent_dim`` are then attached to the last
    hidden layer.

    Args:
        X: input tensor fed to the first dense layer.
        hidden_size: non-empty sequence of hidden layer widths.

    Returns:
        Tuple ``(encoder_mean, encoder_log_var)`` -- the outputs of the
        ``encoder_mean`` and ``encoder_log_variance`` layers.
    """
    first_units, remaining_units = hidden_size[0], hidden_size[1:]

    hidden = tf.layers.dense(
        X,
        first_units,
        activation=tf.nn.relu,
        kernel_initializer=xav_init,
        name="encoder_hidden_layer_0"
    )

    # Remaining layers are numbered from 1 so names continue after layer 0.
    for depth, units in enumerate(remaining_units, 1):
        hidden = tf.layers.dense(
            hidden,
            units,
            activation=tf.nn.relu,
            kernel_initializer=xav_init,
            name="encoder_hidden_layer_" + str(depth)
        )

    # Both heads are linear (no activation) and share the last hidden layer.
    mean_head = tf.layers.dense(
        hidden,
        config.latent_dim,
        kernel_initializer=xav_init,
        name="encoder_mean"
    )
    log_var_head = tf.layers.dense(
        hidden,
        config.latent_dim,
        kernel_initializer=xav_init,
        name="encoder_log_variance"
    )

    return mean_head, log_var_head

In [None]:
def decoder(Z, hidden_size):
    """Build the decoder network on top of a latent tensor.

    A stack of ReLU dense layers (one per entry of ``hidden_size``, named
    ``decoder_hidden_layer_<i>``) is applied to ``Z``, followed by a linear
    dense layer ``decoder_X`` of width ``config.input_dim``.

    Args:
        Z: latent tensor fed to the first dense layer.
        hidden_size: non-empty sequence of hidden layer widths.

    Returns:
        Tuple ``(out_X, tf.nn.sigmoid(out_X))`` -- the raw output of the
        final linear layer and its element-wise sigmoid.
    """
    first_units, remaining_units = hidden_size[0], hidden_size[1:]

    hidden = tf.layers.dense(
        Z,
        first_units,
        activation=tf.nn.relu,
        kernel_initializer=xav_init,
        name="decoder_hidden_layer_0"
    )

    # Remaining layers are numbered from 1 so names continue after layer 0.
    for depth, units in enumerate(remaining_units, 1):
        hidden = tf.layers.dense(
            hidden,
            units,
            activation=tf.nn.relu,
            kernel_initializer=xav_init,
            name="decoder_hidden_layer_" + str(depth)
        )

    # Final layer is linear; the sigmoid is applied separately so callers
    # can use either the pre-activation values or the squashed ones.
    logits = tf.layers.dense(
        hidden,
        config.input_dim,
        kernel_initializer=xav_init,
        name="decoder_X"
    )
    squashed = tf.nn.sigmoid(logits)

    return logits, squashed

In [None]:
def sample_Z(encoder_mean, encoder_log_var, epsilon):
    """Return ``encoder_mean + exp(encoder_log_var / 2) * epsilon``.

    ``epsilon`` is supplied by the caller, so the sample is a deterministic
    function of the three inputs.
    """
    scale = tf.exp(encoder_log_var / 2)
    return encoder_mean + scale * epsilon

In [None]:
def init_prior():
    """Create the three variables that parameterise the mixture prior.

    Returns:
        Tuple ``(prior_means, prior_vars, prior_weights)``:
        * ``prior_means`` -- ``(config.n_clusters, config.latent_dim)``,
          initialised from a normal distribution with stddev 5.0;
        * ``prior_vars`` -- same shape, initialised to ones;
        * ``prior_weights`` -- length ``config.n_clusters``, initialised to
          ``1 / config.n_clusters`` and created with ``trainable=False``.
    """
    component_shape = (config.n_clusters, config.latent_dim)

    # Means and variances are left trainable (the trainable=False switch
    # was deliberately commented out for both in the original code).
    means = tf.Variable(
        tf.random_normal(component_shape, stddev=5.0),
        dtype=tf.float32,
        name="prior_means"
    )
    variances = tf.Variable(
        tf.ones(component_shape),
        dtype=tf.float32,
        name="prior_vars"
    )

    # Uniform mixing weights, excluded from gradient updates.
    uniform_weights = tf.ones(config.n_clusters) / config.n_clusters
    weights = tf.Variable(
        uniform_weights,
        dtype=tf.float32,
        trainable=False,
        name="prior_weights"
    )

    return means, variances, weights

In [None]:
def q_c(Z):
    """Compute normalised per-cluster weights for every row of ``Z``.

    For each cluster ``k`` the value
    ``prior_weights[k] * N(Z; prior_means[k], diag(prior_vars[k])) + 1e-10``
    is evaluated, and the results are normalised so each row sums to one.

    Reads the module-level ``prior_means``, ``prior_vars``, ``prior_weights``
    and ``config``. The batch dimension is fixed to ``config.batch_size``.

    Args:
        Z: latent tensor with ``config.batch_size`` rows.

    Returns:
        Tensor of shape ``(config.batch_size, config.n_clusters)``.
    """
    def fn_cluster(_, k):
        # `prior_vars` holds variances (the loss uses log(prior_vars) and
        # divides squared distances by it), whereas `scale_diag` is a
        # standard deviation -- hence the square root.
        component = ds.MultivariateNormalDiag(
            loc=prior_means[k],
            scale_diag=tf.sqrt(prior_vars[k])
        )
        # The 1e-10 floor keeps the normalisation below away from 0 / 0.
        q = prior_weights[k] * component.prob(Z) + 1e-10
        return tf.reshape(q, [config.batch_size])

    # A plain constant range is enough as the scan index; wrapping it in a
    # tf.Variable would only add a useless int32 variable to the graph.
    clusters = tf.range(config.n_clusters)
    probs = tf.scan(fn_cluster, clusters, initializer=tf.ones([config.batch_size]))
    # scan stacks along the cluster axis; put the batch axis first.
    probs = tf.transpose(probs)
    probs = probs / tf.reshape(tf.reduce_sum(probs, 1), (-1, 1))
    return probs

In [None]:
def vae_loss(cluster_weights, decoded_X_mean, encoder_mean, encoder_log_var):
    """Build the scalar training loss (mean over the batch).

    Per example the loss is the sum of:
      1. ``config.regularizer`` times the summed sigmoid cross-entropy
         between ``X`` and the decoder's pre-sigmoid output;
      2. ``-sum_k cluster_weights * log(prior_weights)``;
      3. ``+sum_k cluster_weights * log(cluster_weights)``;
      4. ``-0.5 * sum_d (1 + encoder_log_var)``;
      5. ``0.5 * sum_k cluster_weights[:, k] * sum_d (log(prior_vars[k])
         + (exp(encoder_log_var) + (encoder_mean - prior_means[k])**2)
         / prior_vars[k])``.

    NOTE: ``decoded_X_mean`` is accepted for interface compatibility but is
    not used; the reconstruction term reads the module-level ``X`` and
    ``decoded_exp_X_mean`` instead. ``prior_means``, ``prior_vars``,
    ``prior_weights`` and ``config`` are module-level names as well.

    Args:
        cluster_weights: ``(batch, n_clusters)`` per-cluster weights.
        decoded_X_mean: unused (see note above).
        encoder_mean: ``(batch, latent_dim)`` encoder mean output.
        encoder_log_var: ``(batch, latent_dim)`` encoder log-variance output.

    Returns:
        Scalar tensor: the batch mean of the per-example loss.
    """
    J = 0.0
    J += config.regularizer * tf.reduce_sum(
        tf.nn.sigmoid_cross_entropy_with_logits(labels=X, logits=decoded_exp_X_mean),
        axis=1
    )
    J -= tf.reduce_sum(cluster_weights * tf.log(prior_weights), axis=1)
    J += tf.reduce_sum(cluster_weights * tf.log(cluster_weights), axis=1)
    J -= 0.5 * tf.reduce_sum(1 + encoder_log_var, axis=1)

    # Term 5, computed for all clusters at once by broadcasting instead of a
    # tf.scan over a tf.Variable index. This avoids creating a spurious
    # variable and removes the hard-coded config.batch_size from the loss.
    posterior_var = tf.expand_dims(tf.exp(encoder_log_var), 1)           # (batch, 1, latent)
    squared_dist = tf.square(tf.expand_dims(encoder_mean, 1) - prior_means)  # (batch, K, latent)
    per_cluster = tf.reduce_sum(
        tf.log(prior_vars) + (posterior_var + squared_dist) / prior_vars,
        axis=2
    )                                                                    # (batch, K)
    J += 0.5 * tf.reduce_sum(cluster_weights * per_cluster, axis=1)

    return tf.reduce_mean(J)

In [None]:
def update_prior_weights_f(prior_weights, cluster_weights):
    """Return the column sums of ``cluster_weights`` divided by ``config.batch_size``.

    ``prior_weights`` is accepted but not used.
    """
    column_totals = tf.reduce_sum(cluster_weights, axis=0)
    return column_totals / float(config.batch_size)

def update_prior_means_f(prior_means, encoder_mean, cluster_weights):
    """Return the ``cluster_weights``-weighted average of ``encoder_mean`` per cluster.

    Row ``k`` of the result is
    ``sum_n cluster_weights[n, k] * encoder_mean[n] / sum_n cluster_weights[n, k]``.

    Computed with a single transposed matmul rather than a ``tf.scan`` over a
    ``tf.Variable`` index, so no extra variable is created and the batch size
    is not hard-coded.

    Args:
        prior_means: accepted but not used.
        encoder_mean: ``(batch, latent_dim)`` tensor.
        cluster_weights: ``(batch, n_clusters)`` tensor.

    Returns:
        Tensor of shape ``(n_clusters, latent_dim)``.
    """
    # (K, batch) @ (batch, latent) -> (K, latent)
    weighted_sums = tf.matmul(cluster_weights, encoder_mean, transpose_a=True)
    # Column vector of per-cluster total weight so the division broadcasts per row.
    totals = tf.reshape(tf.reduce_sum(cluster_weights, axis=0), (-1, 1))
    return weighted_sums / totals

def update_prior_vars_f(prior_weights, encoder_mean, encoder_log_var, cluster_weights):
    """Return per-cluster weighted averages of ``exp(log_var) + (mean - prior_mean_k)**2``.

    For each cluster ``k`` the ``cluster_weights[:, k]``-weighted sum of
    ``exp(encoder_log_var) + (encoder_mean - prior_means[k])**2`` is divided
    by the total weight of that cluster; the resulting ``(1, latent_dim)``
    rows are concatenated along axis 0.

    ``prior_weights`` is accepted but not used; ``prior_means`` is read from
    module scope. The batch dimension is fixed to ``config.batch_size``.
    """
    rows = []
    for k in range(config.n_clusters):
        spread = tf.exp(encoder_log_var) + tf.square(encoder_mean - prior_means[k])
        weights_row = tf.reshape(cluster_weights[:, k], [1, config.batch_size])
        total_weight = tf.reduce_sum(cluster_weights[:, k], axis=0)
        rows.append(tf.matmul(weights_row, spread) / total_weight)

    return tf.concat(rows, 0)

In [None]:
# Number of whole batches per pass over the training images.
n_train_images = len(mnist.train.images)
epoch_len = int(n_train_images / config.batch_size)
# Step count spanning ten epochs (only referenced by the disabled
# learning-rate schedule).
decay_steps = 10 * epoch_len

In [None]:
# Graph inputs: a batch of flattened images and externally drawn noise.
X = tf.placeholder(tf.float32, shape=[None, config.input_dim])
epsilon = tf.placeholder(tf.float32, shape=[None, config.latent_dim])

# Mixture prior parameters.
prior_means, prior_vars, prior_weights = init_prior()

# Encoder -> latent sample -> decoder.
encoder_mean, encoder_log_var = encoder(X, hidden_size=config.encoder_hidden_size)
Z = sample_Z(encoder_mean, encoder_log_var, epsilon)
# decoder() returns (pre-sigmoid output, sigmoid output), in that order.
decoded_exp_X_mean, decoded_X_mean = decoder(Z, hidden_size=config.decoder_hidden_size)

# Normalised per-cluster weights of each latent sample.
cluster_weights = q_c(Z)

In [None]:
# Scalar training objective built from the tensors defined above.
loss = vae_loss(
    cluster_weights=cluster_weights,
    decoded_X_mean=decoded_X_mean,
    encoder_mean=encoder_mean,
    encoder_log_var=encoder_log_var,
)

In [None]:
# update_prior_weights = prior_weights.assign(update_prior_weights_f(prior_weights, cluster_weights))
# update_prior_means = prior_means.assign(update_prior_means_f(prior_means, encoder_mean, cluster_weights))
# update_prior_vars = prior_vars.assign(update_prior_vars_f(prior_vars, encoder_mean, encoder_log_var, cluster_weights))

In [None]:
# learning_rate = tf.train.exponential_decay(.002, 0, decay_steps, 0.9, staircase=True)
# Adam with learning rate and epsilon taken from the config module.
optimizer = tf.train.AdamOptimizer(
    config.adam_nn_learning_rate,
    epsilon=config.adam_nn_epsilon
)
train_step = optimizer.minimize(loss)

In [None]:
# Open the session used by the rest of the notebook and initialise every variable.
sess = tf.Session()
sess.run(tf.global_variables_initializer())

In [None]:
def _tile_images(images, rows=10, cols=10, side=28):
    """Lay out ``rows * cols`` flattened ``side x side`` images as one 2-D mosaic, scaled by 255."""
    canvas = np.zeros((rows * side, cols * side))
    for r in range(rows):
        for c in range(cols):
            tile = images[cols * r + c].reshape((side, side)) * 255
            canvas[r * side : (r + 1) * side, c * side : (c + 1) * side] = tile
    return canvas


def regeneration_plot(epoch):
    """Save a side-by-side figure of 100 reconstructions and their originals.

    100 random test images are passed through the network; the left panel
    shows the decoder's sigmoid output, the right panel the inputs. The
    figure is written to ``plots/regenerated/<epoch>.png``.

    Reads the module-level ``test_data``, ``sess``, ``X``, ``epsilon``,
    ``decoded_X_mean`` and ``config``.

    Args:
        epoch: value used (via ``str``) as the output file name.
    """
    if not os.path.exists("plots/regenerated"):
        os.makedirs("plots/regenerated")

    # Draw random indices instead of shuffling `test_data` in place:
    # `test_data` is the dataset's own image array, so an in-place shuffle
    # would permanently reorder mnist.test.images and break its alignment
    # with the test labels.
    chosen = np.random.permutation(len(test_data))[:100]
    originals = test_data[chosen]

    decoded_image = sess.run(
        decoded_X_mean,
        feed_dict={
            X: originals,
            epsilon: np.random.randn(100, config.latent_dim)
        }
    )
    decoded_image = np.asarray(decoded_image).reshape((100, 784))

    fig = plt.figure(figsize=(12, 6))
    gs = grid.GridSpec(1, 2)

    ax1 = plt.subplot(gs[0])
    ax2 = plt.subplot(gs[1])

    # Left: reconstructions. Right: the matching inputs, in the same order.
    ax1.imshow(_tile_images(decoded_image), cmap="Greys_r")
    ax2.imshow(_tile_images(originals), cmap="Greys_r")

    plt.tight_layout()
    plt.savefig("plots/regenerated/" + str(epoch) + ".png")
    # Close this figure explicitly so repeated calls do not accumulate figures.
    plt.close(fig)

In [None]:
def sample_plot(epoch):
    """Save a grid of images decoded from samples of each prior component.

    Row ``k`` of the grid holds 10 images decoded from latent points drawn as
    ``mean_k + sqrt(var_k) * noise``. The figure is written to
    ``plots/sampled/<epoch>.png``.

    Reads the module-level ``sess``, ``prior_means``, ``prior_vars``, ``Z``,
    ``decoded_X_mean`` and ``config``.

    Args:
        epoch: value used (via ``str``) as the output file name.
    """
    if not os.path.exists("plots/sampled"):
        os.makedirs("plots/sampled")

    mus, variances = sess.run([prior_means, prior_vars])
    # `prior_vars` stores variances (the loss takes log(prior_vars) and
    # divides by it), so the sampling scale is its square root.
    sigmas = np.sqrt(variances)

    # One row per prior component rather than a hard-coded 10, so the plot
    # follows config.n_clusters.
    n_rows = mus.shape[0]
    n_cols = 10

    figure = np.zeros((n_rows * 28, n_cols * 28))
    fig = plt.figure(figsize=(6, 6))

    for k in range(n_rows):
        # Decode all samples of this component in a single session run.
        eps = np.random.randn(n_cols, config.latent_dim)
        decoded_images = sess.run(
            decoded_X_mean,
            feed_dict={
                Z: eps * sigmas[k] + mus[k]
            }
        )

        for i in range(n_cols):
            figure[k * 28 : (k + 1) * 28, i * 28 : (i + 1) * 28] = (
                decoded_images[i].reshape((28, 28)) * 255
            )

    plt.imshow(figure, cmap="Greys_r")

    plt.tight_layout()
    plt.savefig("plots/sampled/" + str(epoch) + ".png")
    plt.close(fig)

In [None]:
for epoch in range(config.n_epochs):
    # J accumulates the mean per-batch loss over this epoch.
    J = 0.0
    for i in tqdm(range(epoch_len)):
        X_batch = mnist.train.next_batch(config.batch_size)[0]
        # Fresh noise for the latent sample of every batch.
        noise_batch = np.random.randn(config.batch_size, config.latent_dim)
        out = sess.run(
            [loss, train_step],
            feed_dict={X: X_batch, epsilon: noise_batch}
        )
        J += out[0] / epoch_len

    print(J)

    # Write this epoch's diagnostic figures to disk.
    sample_plot(epoch)
    regeneration_plot(epoch)