# Variational Autoencoders 


### References

[1] Diederik P. Kingma, Danilo J. Rezende, Shakir Mohamed, Max Welling, https://arxiv.org/pdf/1406.5298.pdf

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf

from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.datasets import mnist

from tensorflow.keras import backend as K

from IPython.display import clear_output

In [2]:
img_dim = 28
latent_dim = 3
dropout_rate = 0.3
learn_rate_init = 0.0001

def apply_bn_and_dropout(x):
    return layers.Dropout(dropout_rate)(layers.BatchNormalization()(x))

def make_encoder(input_dim: int, latent_dim: int):
    encoder_inputs = keras.Input(shape=(input_dim, input_dim, 1))
    x = layers.Flatten()(encoder_inputs)
    x = layers.Dense(256, activation='relu')(x)
    x = apply_bn_and_dropout(x)
    x = layers.Dense(128, activation='relu')(x)
    x = apply_bn_and_dropout(x)
    # x = layers.Conv2D(32, 3, activation="relu", strides=2, padding="same")(encoder_inputs)
    # x = layers.Conv2D(64, 3, activation="relu", strides=2, padding="same")(x)
    # x = layers.Flatten()(x)
    # x = layers.Dense(16, activation="relu")(x)
    z_mean = layers.Dense(latent_dim, name="z_mean")(x)
    z_log_var = layers.Dense(latent_dim, name="z_log_var")(x)
    z = Sampling()([z_mean, z_log_var])
    encoder = keras.Model(encoder_inputs, [z_mean, z_log_var, z], name="encoder")
    encoder.summary()
    return encoder

def make_decoder(output_dim: int, latent_dim: int):
    latent_inputs = keras.Input(shape=(latent_dim,))
    x = layers.Dense(128, activation='relu')(latent_inputs)
    x = apply_bn_and_dropout(x)
    x = layers.Dense(256, activation='relu')(x)
    x = apply_bn_and_dropout(x)
    x = layers.Dense(28 * 28, activation='sigmoid')(x)
    decoder_outputs = layers.Reshape((output_dim, output_dim, 1))(x)
    # x = layers.Dense(7 * 7 * 64, activation="relu")(latent_inputs)
    # x = layers.Reshape((7, 7, 64))(x)
    # x = layers.Conv2DTranspose(64, 3, activation="relu", strides=2, padding="same")(x)
    # x = layers.Conv2DTranspose(32, 3, activation="relu", strides=2, padding="same")(x)
    # decoder_outputs = layers.Conv2DTranspose(1, 3, activation="sigmoid", padding="same")(x)
    decoder = keras.Model(latent_inputs, decoder_outputs, name="decoder")
    decoder.summary()
    return decoder

class Sampling(layers.Layer):
    """Uses (z_mean, z_log_var) to sample z, the vector encoding a digit."""

    def call(self, inputs):
        z_mean, z_log_var = inputs
        batch = tf.shape(z_mean)[0]
        dim = tf.shape(z_mean)[1]
        epsilon = tf.random.normal(shape=(batch, dim))
        return z_mean + tf.exp(0.5 * z_log_var) * epsilon

class VAE(keras.Model):
    def __init__(self, encoder, decoder, **kwargs):
        super().__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder
        self.total_loss_tracker = keras.metrics.Mean(name="total_loss")
        self.reconstruction_loss_tracker = keras.metrics.Mean(
            name="reconstruction_loss"
        )
        self.kl_loss_tracker = keras.metrics.Mean(name="kl_loss")

    @property
    def metrics(self):
        return [
            self.total_loss_tracker,
            self.reconstruction_loss_tracker,
            self.kl_loss_tracker,
        ]

    def train_step(self, data):
        with tf.GradientTape() as tape:
            z_mean, z_log_var, z = self.encoder(data)
            reconstruction = self.decoder(z)
            reconstruction_loss = tf.reduce_mean(
                tf.reduce_sum(
                    keras.losses.binary_crossentropy(data, reconstruction), axis=(1, 2)
                )
            )
            kl_loss = -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))
            kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))
            total_loss = reconstruction_loss + kl_loss
        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.kl_loss_tracker.update_state(kl_loss)

        return {
            "loss": self.total_loss_tracker.result(),
            "reconstruction_loss": self.reconstruction_loss_tracker.result(),
            "kl_loss": self.kl_loss_tracker.result(),
        }



In [3]:
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

mnist_digits = np.concatenate([x_train, x_test], axis=0)
mnist_digits = np.expand_dims(mnist_digits, -1).astype("float32") / 255

x_train = np.expand_dims(x_train, -1).astype("float32") / 255
x_test  = np.expand_dims(x_test, -1).astype("float32") / 255

encoder = make_encoder(img_dim, latent_dim)
decoder = make_decoder(img_dim, latent_dim)
vae = VAE(encoder, decoder)
vae.compile(optimizer=keras.optimizers.Adam(learn_rate_init))

Model: "encoder"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_1 (InputLayer)           [(None, 28, 28, 1)]  0           []                               
                                                                                                  
 flatten (Flatten)              (None, 784)          0           ['input_1[0][0]']                
                                                                                                  
 dense (Dense)                  (None, 256)          200960      ['flatten[0][0]']                
                                                                                                  
 batch_normalization (BatchNorm  (None, 256)         1024        ['dense[0][0]']                  
 alization)                                                                                 

In [4]:
def plot_label_clusters_2d(vae, data, labels, epoch: int = 1):
    # display a 2D plot of the digit classes in the latent space
    z_mean, _, _ = vae.encoder.predict(data)
    plt.figure(figsize=(12, 10))
    plt.scatter(z_mean[:, 0], z_mean[:, 1], c=labels)
    plt.colorbar()
    plt.xlabel("z0")
    plt.ylabel("z1")
    plt.title(f'Epoch: {epoch}')
    plt.xlim(-5.5, 5.5)
    plt.ylim(-5.5, 5.5)
    plt.savefig(f'figs/vae/2d/{epoch}.png')
    plt.close()

def plot_label_clusters_3d(vae, data, labels, epoch):
    # display a 3D plot of the digit classes in the latent space
    z_mean, _, _ = vae.encoder.predict(data)
    fig = plt.figure(figsize=(12, 10))
    ax = fig.add_subplot(projection='3d')
    
    scatter = ax.scatter(z_mean[:, 0], z_mean[:, 1], z_mean[:, 2], c=labels)
    cbar = plt.colorbar(scatter)
    
    ax.set_xlabel("z0")
    ax.set_ylabel("z1")
    ax.set_zlabel("z2")
    ax.set_title(f'Epoch: {epoch}')
    ax.set_xlim(-5.5, 5.5)
    ax.set_ylim(-5.5, 5.5)
    ax.set_zlim(-5.5, 5.5)
    plt.savefig(f'figs/vae/3d/{epoch}.png')
    plt.close()


def on_epoch_end(epoch, logs):
   
    plot_label_clusters_2d(vae, x_train, y_train, epoch)
    plot_label_clusters_3d(vae, x_train, y_train, epoch)
    clear_output()


save_fig = keras.callbacks.LambdaCallback(on_epoch_end=on_epoch_end)
tb       = keras.callbacks.TensorBoard(log_dir='logs_vae')

In [5]:

vae.fit(
    mnist_digits, 
    epochs=50, 
    batch_size=500, 
    callbacks=[save_fig, tb]
)




<keras.callbacks.History at 0x7f5accdffa50>

In [6]:
import glob
import imageio

figs2d = sorted(glob.glob('figs/vae/2d/*.png'), key=lambda s: int(s.rstrip('.png').split('/')[-1]))
figs3d = sorted(glob.glob('figs/vae/3d/*.png'), key=lambda s: int(s.rstrip('.png').split('/')[-1]))

with imageio.get_writer('figs/vae/2d/movie.gif', mode='I', fps=1.2) as writer:
    for fig in figs2d:
        image = imageio.imread(fig)
        writer.append_data(image)

with imageio.get_writer('figs/vae/3d/movie.gif', mode='I', fps=1.2) as writer:
    for fig in figs3d:
        image = imageio.imread(fig)
        writer.append_data(image)

  if __name__ == '__main__':
  
