In [1]:
import tensorflow as tf
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data

# Load MNIST (one-hot labels) and derive the sizes the model and plots need.
mnist = input_data.read_data_sets("MNIST_data", one_hot=True)
train_split = mnist.train
num_sample = train_split.num_examples
input_dim = len(train_split.images[0])   # pixels per (flattened) training image
w = h = int(np.sqrt(input_dim))          # side length used to reshape images for display

Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
Instructions for updating:
Please write your own downloading logic.
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting MNIST_data/train-images-idx3-ubyte.gz
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Instructions for updating:
Please use tf.one_hot on tensors.
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.


In [None]:
tf.reset_default_graph()

# Fix random seeds so that Restart & Run All is repeatable. The TF graph-level
# seed must be set after reset_default_graph(), because it belongs to the new
# default graph. It covers weight init and latent sampling. The numpy seed
# covers np.random.choice in the plotting cell and, presumably, the MNIST
# helper's batch shuffling (verify against the input_data implementation).
# GPU kernels may still be non-deterministic.
seed = 0
np.random.seed(seed)
tf.set_random_seed(seed)

# Training hyper-parameters
num_epochs = 50
batch_size = 64
learning_rate = 3e-4
# Model hyper-parameters
z_dim = 10   # latent dimensionality
beta = 4     # weight on the KL term of the loss (beta-VAE)

# Flattened input images, one row per example
input_images = tf.placeholder(tf.float32, (None, input_dim))

# Dense VAE model: hidden layers of both encoder and decoder use this activation
activation = tf.nn.leaky_relu

# Encoder
def encoder(images):
    """Map a batch of flattened images to the parameters of q(z|x).

    Two leaky-ReLU dense layers (512 -> 256 units) feed two linear heads.
    Variables live under the "encoder" scope with AUTO_REUSE, so repeated calls
    share weights.

    Returns:
        (hidden, mean, log_variance): the 256-unit hidden features, plus the
        posterior mean and log(sigma^2), each with last dimension z_dim.
    """
    with tf.variable_scope("encoder", reuse=tf.AUTO_REUSE):
        hidden = images
        for units, layer_name in ((512, "dense1"), (256, "dense2")):
            hidden = tf.layers.dense(hidden, units, activation=activation, name=layer_name)
        mean = tf.layers.dense(hidden, z_dim, activation=None, name="vae_mean")
        log_variance = tf.layers.dense(hidden, z_dim, activation=None, name="vae_logstd_sqare")
    return hidden, mean, log_variance

# Decoder
def decoder(z):
    """Map latent codes ``z`` to per-pixel logits of shape (batch, input_dim).

    Two leaky-ReLU dense layers (256 -> 512 units) are followed by a linear
    output layer. No sigmoid is applied here; callers apply it when they need
    pixel intensities. Variables are shared across calls via AUTO_REUSE under
    the "decoder" scope.
    """
    with tf.variable_scope("decoder", reuse=tf.AUTO_REUSE):
        hidden = z
        for units, layer_name in ((256, "dense1"), (512, "dense2")):
            hidden = tf.layers.dense(hidden, units, activation=activation, name=layer_name)
        logits = tf.layers.dense(hidden, input_dim, activation=None, name="dense3")
    return logits


# Posterior parameters for the input batch (the encoder's hidden features are unused here).
_, vae_mean, vae_logstd_sq = encoder(input_images)

# q(z|x) = N(mean, sigma^2) with sigma = exp(0.5 * log(sigma^2)).
vae_normal = tf.distributions.Normal(vae_mean, tf.exp(0.5 * vae_logstd_sq), validate_args=True)
# One sample per example; sample() with the default empty sample_shape already
# has shape (batch, z_dim), so no extra leading axis needs squeezing away.
vae_sample = vae_normal.sample()

# Decoder outputs are logits: the training loss consumes them directly, while
# the "mean" reconstruction is squashed to pixel intensities for display.
reconstructed_images = decoder(vae_sample)
reconstructed_images_mean = tf.nn.sigmoid(decoder(vae_mean))

# Images generated from user-supplied latent codes. The sigmoid is applied here
# as well, so these are pixel intensities in [0, 1] like reconstructed_images_mean
# rather than raw logits.
generative_z = tf.placeholder(tf.float32, (None, z_dim))
generated_images = tf.nn.sigmoid(decoder(generative_z))

def bce(t, y):
    """Per-example binary cross-entropy between targets ``t`` and probabilities ``y``.

    The sum runs over axis 1, so each row is one example. ``y`` must already be
    probabilities (e.g. sigmoid outputs), not logits. A small epsilon keeps both
    logarithms finite at the boundaries y == 0 and y == 1. Note: the loss below
    uses tf.nn.sigmoid_cross_entropy_with_logits instead of this helper.
    """
    epsilon = 1e-10
    # Add epsilon *after* forming (1 - y). Writing `epsilon + 1 - y` evaluates
    # 1e-10 + 1 first, which rounds to 1.0 once converted to float32. That leaves
    # log(0) = -inf when y == 1, and (1 - t) * -inf = NaN when t == 1.
    return -tf.reduce_sum(t * tf.log(y + epsilon) + (1 - t) * tf.log((1 - y) + epsilon), axis=1)

def kl_divergence(mean, logstd_sq):
    """Closed-form KL( N(mean, exp(logstd_sq)) || N(0, I) ), one value per example.

    ``logstd_sq`` is the log-variance log(sigma^2). Per-dimension terms are
    summed over axis 1.
    """
    per_dim = 1 + logstd_sq - tf.square(mean) - tf.exp(logstd_sq)
    return -0.5 * tf.reduce_sum(per_dim, axis=1)

# Binary cross-entropy reconstruction loss, computed from the decoder logits.
# Summed over pixels, then averaged over the batch.
per_pixel_xent = tf.nn.sigmoid_cross_entropy_with_logits(labels=input_images, logits=reconstructed_images)
reconstruction_loss = tf.reduce_mean(tf.reduce_sum(per_pixel_xent, axis=1))
# KL to the standard-normal prior, averaged over the batch
kl_loss = tf.reduce_mean(kl_divergence(vae_mean, vae_logstd_sq))

# Total loss: beta-weighted VAE objective
loss = reconstruction_loss + beta * kl_loss

# TensorBoard scalars for both loss terms, merged into one op run each step
tf.summary.scalar("kl_loss", kl_loss)
tf.summary.scalar("reconstruction_loss", reconstruction_loss)
merge_op = tf.summary.merge_all()

# Adam with beta1 lowered to 0.5
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate, beta1=0.5)
train_step = optimizer.minimize(loss)

# Session with GPU allow_growth enabled, then initialise all variables
config = tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True))
sess = tf.Session(config=config)
sess.run(tf.global_variables_initializer())

train_writer = tf.summary.FileWriter("./vae_logs/run_sigmoid_cross_entropy_with_logits_run4", sess.graph)

# Global step counter used as the x-axis for TensorBoard summaries
step_idx = 0
print("Training")
for epoch in range(num_epochs):
    if (epoch + 1) % 10 == 0:
        print(f"Epoch {epoch+1}/{num_epochs}")
    for batch_idx in range(num_sample // batch_size):
        # next_batch returns (images, labels); the VAE only needs the images
        batch_images, _batch_labels = mnist.train.next_batch(batch_size)
        _, summary = sess.run([train_step, merge_op], feed_dict={input_images: batch_images})
        train_writer.add_summary(summary, step_idx)
        step_idx += 1

# FileWriter only writes buffered events to disk periodically. Flush so the
# last training steps show up in TensorBoard without waiting for the kernel to exit.
train_writer.flush()

Training
Epoch 10/50
Epoch 20/50
Epoch 30/50
Epoch 40/50


In [None]:
%matplotlib notebook
import matplotlib.pyplot as plt
import ipywidgets as widgets

sample_idx = np.random.choice(num_sample, 16, replace=False)

orig_img = mnist.train.images[sample_idx]
recon_img = sess.run(reconstructed_images_mean, feed_dict={
    input_images: mnist.train.images[sample_idx]
})

fig, ax = plt.subplots(4, 4*2, figsize=(5, 5))
fig.subplots_adjust(hspace=0, wspace=0)

for i in range(4):
    for j in range(4):
        ax[i, j*2].xaxis.set_major_locator(plt.NullLocator())
        ax[i, j*2].yaxis.set_major_locator(plt.NullLocator())
        ax[i, j*2].imshow(orig_img[4 * i + j].reshape(w, h))
        ax[i, j*2+1].xaxis.set_major_locator(plt.NullLocator())
        ax[i, j*2+1].yaxis.set_major_locator(plt.NullLocator())
        ax[i, j*2+1].imshow(recon_img[4 * i + j].reshape(w, h))
plt.show()

In [None]:
from IPython.display import display

# Latent code shown in the figure; starts at zero and is edited in place by the sliders below.
curr_z = np.zeros((1, z_dim))
img = sess.run(generated_images, feed_dict={
    generative_z: curr_z
})
fig = plt.figure()
fig.suptitle("Generated Image")
ax = fig.add_subplot(1, 1, 1)
plot = ax.imshow(img.reshape(w, h))


def create_slider_event(dim_idx):
    """Build an ipywidgets observer for latent dimension ``dim_idx``.

    The returned callback writes the slider's new value into curr_z[0, dim_idx],
    re-runs the generator on the updated code and redraws the figure. The factory
    binds ``dim_idx`` per slider, instead of every callback sharing the final
    value of the loop variable.
    """
    def func(change):
        curr_z[0, dim_idx] = change["new"]
        img = sess.run(generated_images, feed_dict={
            generative_z: curr_z
        })
        plot.set_data(img.reshape(w, h))
        # set_data() keeps the colour limits computed for the first image, which
        # clips later images whose value range differs; recompute them from the
        # new data.
        plot.autoscale()
        fig.canvas.draw()
    return func


# One slider per latent dimension over [-3, 3], i.e. about +/-3 standard
# deviations of the standard-normal prior used by the KL term.
for dim_idx in range(z_dim):
    slider = widgets.FloatSlider(
        value=0.0,
        min=-3,
        max= 3,
        step=0.1,
        description=f"z_dim[{dim_idx}]",
        disabled=False,
        continuous_update=True,
        orientation="horizontal",
        readout=True,
        readout_format=".1f",
    )
    slider.observe(create_slider_event(dim_idx), names="value")
    display(slider)