### Vanilla VAE

In [None]:
import random

import numpy as np
import tensorflow as tf

import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
# Seed every RNG this notebook draws from so a fresh Restart & Run All is
# repeatable:
#   * `random`  - picks the mini-batch row indices in the training loop
#   * NumPy     - draws the latent vector in the "Sample latent space" cell
#   * TensorFlow - graph-level seed of the default graph that exists right now
random.seed(0)
np.random.seed(0)
tf.set_random_seed(0)

In [None]:
# Size of the latent code z; read by the encoder and by the sampling cell.
n_z = 100

# Keep only the first (training) split returned by Keras; the second split is
# discarded into `_`.
(X, y), _ = tf.keras.datasets.fashion_mnist.load_data()

# NOTE(review): assuming 8-bit pixels (0..255), dividing by 128 scales them to
# roughly [0, 2) rather than [0, 1] - confirm that /255 was not intended.
X = np.true_divide(X, 128.0)

# Append a single channel axis so the images match the conv layers' NHWC input.
fmnist = np.reshape(X, (-1, 28, 28, 1))

In [None]:
def encoder(x, reuse=None):
    """Encode a batch of images into a sampled latent code and its parameters.

    Two stride-2 convolutions shrink the input, the result is flattened, and
    two linear heads produce the parameters of a diagonal Gaussian from which
    the latent code is sampled with the reparameterisation trick.

    Args:
        x: float image tensor; the graph-building cell passes a
            (batch, 28, 28, 1) placeholder.
        reuse: forwarded to ``tf.variable_scope`` so the "encoder" weights can
            be shared between several calls.

    Returns:
        Tuple ``(z, mn, sd)``:
            z:  sampled latent code, shape (batch, n_z).
            mn: mean of the approximate posterior, shape (batch, n_z).
            sd: log-variance of the approximate posterior, shape (batch, n_z)
                (the name is kept for existing callers).
    """
    with tf.variable_scope("encoder", reuse=reuse):
        print(x.shape, "input x")

        x = tf.layers.conv2d(x, filters=28, kernel_size=2,
                             strides=2, activation=tf.nn.relu)
        print(x.shape, "conv 1")

        x = tf.layers.conv2d(x, filters=14, kernel_size=2,
                             strides=2, activation=tf.nn.relu)
        print(x.shape, "conv 2")

        x = tf.contrib.layers.flatten(x)
        print(x.shape, "flatten")

        # Both heads are linear. A sigmoid here would confine the mean to
        # (0, 1) and the log-variance to (0, 1), i.e. force the posterior
        # variance above 1 so it could never be tighter than the N(0, I) prior.
        mn = tf.layers.dense(x, units=n_z)
        sd = tf.layers.dense(x, units=n_z)

        # One standard-normal draw per latent dimension; the batch size is only
        # known at run time, hence tf.shape rather than the static shape.
        epsilon = tf.random_normal(
            tf.stack([tf.shape(x)[0], n_z]))

        # The KL term of the loss (1 + sd - mn^2 - exp(sd)) reads `sd` as
        # log-variance, so the matching standard deviation is exp(sd / 2).
        z = mn + tf.multiply(epsilon, tf.exp(0.5 * sd))
        print(z.shape, "output z")

        return z, mn, sd

In [None]:
def decoder(z, reuse=None):
    """Decode latent codes back into 28x28 single-channel images.

    A dense layer expands ``z`` to 7*7*14 features, which are reshaped into a
    7x7 map and upsampled by two stride-2 transposed convolutions.

    Args:
        z: latent tensor to decode.
        reuse: forwarded to ``tf.variable_scope`` so the "decoder" weights can
            be shared between several calls.

    Returns:
        Tensor reshaped to (-1, 28, 28, 1).
    """
    with tf.variable_scope("decoder", reuse=reuse):
        print(z.shape, "input z")

        hidden = tf.layers.dense(z, units=7 * 7 * 14,
                                 activation=tf.nn.relu)
        print(hidden.shape, "fully")

        hidden = tf.reshape(hidden, (-1, 7, 7, 14))
        print(hidden.shape, "reshape")

        # Two identical upsampling stages that differ only in channel count;
        # the labels mirror the encoder's conv numbering in reverse.
        for n_filters, label in ((28, "deconv 2"), (1, "deconv 1")):
            hidden = tf.layers.conv2d_transpose(
                hidden, filters=n_filters, kernel_size=2,
                strides=2, activation=tf.nn.relu)
            print(hidden.shape, label)

        return tf.reshape(hidden, shape=(-1, 28, 28, 1))

In [None]:
# Start from an empty default graph so re-running this cell does not stack
# duplicate ops and variables on top of an earlier build.
tf.reset_default_graph()

# The graph-level seed is stored on the graph object, so the reset above throws
# away any seed set earlier. Set it again before any op is created, otherwise
# weight initialisation and the epsilon sampling are not reproducible.
tf.set_random_seed(0)

# Batch of single-channel 28x28 images; the batch dimension is left open.
input_batch = tf.placeholder(dtype=tf.float32,
                             shape=(None, 28, 28, 1),
                             name="input_batch")

# Wire the autoencoder: image -> (z, mean, log-variance) -> reconstruction.
z, mn, sd = encoder(input_batch)
out = decoder(z)

In [None]:
# Both loss terms are computed per example on flattened pixel vectors.
n_pixels = 28 * 28 * 1
input_flat = tf.reshape(input_batch, (-1, n_pixels))
out_flat = tf.reshape(out, (-1, n_pixels))

# Reconstruction term: squared error summed over the pixels of each image.
pixel_err = tf.squared_difference(out_flat, input_flat)
img_loss = tf.reduce_sum(pixel_err, axis=1)

# KL divergence to a standard normal, summed over latent dimensions. The
# formula uses `sd` in the log-variance position: 1 + sd - mn^2 - exp(sd).
kl_terms = 1 + sd - tf.square(mn) - tf.exp(sd)
kl_loss = -0.5 * tf.reduce_sum(kl_terms, axis=1)

# Batch average of the summed per-example loss, minimised with Adam.
cost = tf.reduce_mean(img_loss + kl_loss)
opt = tf.train.AdamOptimizer(learning_rate=0.0005).minimize(cost)

In [None]:
# The session is kept open on purpose: the plotting and sampling cells reuse it.
sess = tf.Session()
epochs = 10
batch_size = 200
sess.run(tf.global_variables_initializer())

n_samples = len(fmnist)
steps_per_epoch = n_samples // batch_size

try:
    for epoch_no in range(1, epochs + 1):
        print(f"Epoch: {epoch_no}/{epochs}")
        for step in range(steps_per_epoch):
            # Each step draws batch_size distinct rows; draws are independent
            # across steps, so one "epoch" is not a strict pass over the data.
            picked = random.sample(range(n_samples), batch_size)
            batch = fmnist[picked]
            loss, _ = sess.run([cost, opt], feed_dict={input_batch: batch})
            if step % 20 != 0:
                continue
            print(f"iter: {step}, loss: {str(round(loss, 2))}")
except KeyboardInterrupt:
    # Allow stopping training by hand while keeping the weights learned so far.
    print("Stopped")

### See the results

In [None]:
# Compare the first ten images (top row) with their reconstructions (bottom row).
imgs = fmnist[:10]

fig, axes = plt.subplots(
    nrows=2, ncols=10, sharex=True,
    sharey=True, figsize=(20, 4))
fig.tight_layout(pad=0.1)

# Fetch only the reconstruction. Including `opt` in the fetches would execute
# an Adam update on these ten images, silently changing the trained weights
# every time this visualisation cell is run.
rec = sess.run(out, feed_dict={input_batch: imgs})

for images, row in zip([imgs, rec], axes):
    for img, axis in zip(images, row):
        axis.get_xaxis().set_visible(False)
        axis.get_yaxis().set_visible(False)
        # Drop the channel axis: imshow expects a 2-D array for grayscale.
        axis.imshow(img.reshape((28, 28)), cmap='gray')

### Sample latent space

In [None]:
# Draw one latent vector from a standard normal and give it a batch axis of 1.
prior_draw = np.random.normal(loc=0, scale=1, size=n_z)
latent = prior_draw[np.newaxis, :]

# Feeding `z` directly overrides the encoder's sample, so only the decoder
# part of the graph determines `out` here.
decoded = sess.run(out, feed_dict={z: latent})
plt.imshow(np.reshape(decoded, (28, 28)), cmap='gray')

In [None]:
# Release the session's resources. Any cell that calls sess.run must be
# executed before this one; afterwards a new session has to be created.
sess.close()