In [None]:
import torch
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

In [None]:
# TODO:
# 1- Try giving mu and sigma directly.
# 2- Try a CNN-based VAE.
# Once convinced of the best approach:
# 3- Explore the label-coded latent space (and evaluate it for different beta values).
# 4- Plot the 2D manifold of digits (latent space = 2).

In [None]:
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

# Rescale pixel values from [0, 255] to [0, 1].
x_train = x_train / 255.
x_test = x_test / 255.

# Standardization statistics are computed on the training split only and
# reused for the test split, so both splits share the same transform.
dataset_mean = np.mean(x_train)
dataset_std = np.std(x_train)

# Standardize, then flatten each 28x28 image into a single vector.
x_train = tf.keras.layers.Flatten()((x_train - dataset_mean) / (dataset_std))
x_test = tf.keras.layers.Flatten()((x_test - dataset_mean) / (dataset_std))

# Input pipeline: shuffle with a 1000-element buffer, batch by 64, prefetch.
train_ds = (
    tf.data.Dataset.from_tensor_slices(x_train)
    .shuffle(1000)
    .batch(64)
    .prefetch(tf.data.AUTOTUNE)
)

In [None]:
# Preview the first 16 test images: reshape to 28x28, then undo the
# standardization and the [0, 1] scaling so the values are back in [0, 255].
preds = tf.reshape(x_test[:16], [-1, 28, 28])
preds = ((preds * dataset_std) + dataset_mean) * 255.

# Draw the images on a 4x4 grid with the axes hidden.
fig = plt.figure(figsize=(5, 5))
for i, digit in enumerate(preds):
    plt.subplot(4, 4, i + 1)
    plt.imshow(digit, cmap='gray')
    plt.axis('off')

In [None]:
latent_dim = 2

# ---- Encoder: 784 -> 500 -> 120 -> (mu, rho) ----
inputs = tf.keras.Input(shape=(784,))
x = tf.keras.layers.Dense(500, activation='relu')(inputs)
x = tf.keras.layers.Dense(120, activation='relu')(x)
# Two parallel linear heads on the same hidden features, each of size latent_dim.
mu = tf.keras.layers.Dense(latent_dim)(x)
rho = tf.keras.layers.Dense(latent_dim)(x)
Encoder = tf.keras.Model(inputs=inputs, outputs=[mu, rho])

# ---- Decoder: latent_dim -> 120 -> 500 -> 784 ----
z = tf.keras.Input(shape=(latent_dim,))
x = tf.keras.layers.Dense(120, activation='relu')(z)
x = tf.keras.layers.Dense(500, activation='relu')(x)
# Linear output layer (no activation) producing a flattened 784-value image.
decoded_img = tf.keras.layers.Dense(784)(x)
Decoder = tf.keras.Model(inputs=z, outputs=[decoded_img])

class VAE(tf.keras.Model):
    """Variational autoencoder wrapping the module-level ``Encoder`` and ``Decoder``.

    The encoder produces ``(mu, rho)`` for each input; the posterior standard
    deviation used for sampling is ``softplus(rho)``.
    """

    def __init__(self):
        super().__init__()
        # NOTE: these are the shared module-level model objects, so every VAE
        # instance created from this class uses the same underlying weights.
        self.encoder_block = Encoder
        self.decoder_block = Decoder

    def call(self,img):
        """Encode ``img``, sample a latent code, and decode it.

        Args:
            img: batch of inputs accepted by ``self.encoder_block``.

        Returns:
            Tuple ``(z_mu, z_rho, decoded_img)``: the two encoder outputs and
            the decoder output for the sampled latent code.
        """
        z_mu,z_rho = self.encoder_block(img)

        # Reparameterization: z = mu + sigma * eps with eps ~ N(0, 1).
        # tf.shape() gives the runtime shape, so this also works when the
        # static batch dimension is unknown (e.g. inside tf.function), where
        # z_mu.shape would contain None. The noise is drawn in z_mu's dtype so
        # the addition below never mixes dtypes.
        epsilon = tf.random.normal(
            shape=tf.shape(z_mu), mean=0.0, stddev=1.0, dtype=z_mu.dtype
        )
        z = z_mu + tf.math.softplus(z_rho) * epsilon

        decoded_img = self.decoder_block(z)

        return z_mu,z_rho,decoded_img

In [None]:
def kl_loss(z_mu,z_rho):
    """Closed-form KL divergence between q(z|x) and a unit Gaussian prior p(z).

    The posterior standard deviation is ``sigma = softplus(z_rho)``. Per latent
    dimension the divergence is ``-0.5 * (1 + log(sigma^2) - mu^2 - sigma^2)``.

    Args:
        z_mu: posterior means, indexed as (batch, latent).
        z_rho: pre-softplus posterior scale parameters, same shape as ``z_mu``.

    Returns:
        Scalar: KL summed over axis 1 and averaged over the batch.
    """
    sigma = tf.math.softplus(z_rho)

    # log(sigma^2) is evaluated as 2*log(sigma): squaring first can underflow
    # to 0 (giving log(0) = -inf and an infinite loss) while sigma itself is
    # still representable.
    log_sigma_squared = 2.0 * tf.math.log(sigma)
    kl_1d = -0.5 * (1.0 + log_sigma_squared - tf.square(z_mu) - tf.square(sigma))

    # sum over sample dim, average over batch dim
    kl_batch = tf.reduce_mean(tf.reduce_sum(kl_1d,axis=1))

    return kl_batch

def elbo(z_mu,z_rho,decoded_img,original_img):
    """Compute the two loss terms: reconstruction error and KL divergence.

    Args:
        z_mu: posterior means, forwarded to ``kl_loss``.
        z_rho: pre-softplus posterior scales, forwarded to ``kl_loss``.
        decoded_img: reconstructed images.
        original_img: target images, same shape as ``decoded_img``.

    Returns:
        Tuple ``(mse, kl)``. ``mse`` is the squared error summed over axis 1
        and averaged over the batch; ``kl`` is the value of ``kl_loss``.
        The terms are returned separately and are not combined here.
    """
    # Reconstruction term: per-sample sum of squared residuals, then batch mean.
    residual = original_img - decoded_img
    per_sample_error = tf.reduce_sum(tf.square(residual), axis=1)
    mse = tf.reduce_mean(per_sample_error)

    # Regularization term.
    kl = kl_loss(z_mu, z_rho)

    return mse, kl

In [None]:
def generate_images(model, epoch, step, num_samples=16):
  """Decode latent codes sampled from the prior and save them as an image grid.

  Args:
    model: object exposing ``encoder_block`` and ``decoder_block``.
    epoch: not used by the current implementation; kept so existing calls work.
    step: not used by the current implementation; kept so existing calls work.
    num_samples: number of latent codes to sample and decode (default 16).

  Side effects:
    Overwrites ``generated_samples.png`` in the working directory.
  """
  # The latent size is read from the encoder's first output (mu).
  latent_size = model.encoder_block.output[0].shape[1]

  # during prediction, sample from prior directly
  z = tf.random.normal(shape=(num_samples, latent_size), mean=0.0, stddev=1.0)
  preds = model.decoder_block(z)
  preds = tf.reshape(preds,[-1,28,28])
  # destandardization
  preds = (preds * dataset_std) + dataset_mean
  preds = preds * 255.

  # Smallest square grid that fits every sample (4x4 for the default 16).
  grid = int(np.ceil(np.sqrt(num_samples)))

  fig = plt.figure(figsize=(5, 5))
  for i in range(preds.shape[0]):
    plt.subplot(grid, grid, i + 1)
    plt.imshow(preds[i, :, :], cmap='gray')
    plt.axis('off')

  plt.savefig('generated_samples.png')
  # Release the figure: this function is called once per epoch, and figures
  # that are never closed accumulate in memory.
  plt.close(fig)

In [None]:
def generate_training_images(model, temp_x_train):
  """Reconstruct a batch of images with ``model`` and save them as an image grid.

  Args:
    model: callable returning a 3-tuple whose last element is the batch of
      reconstructions (784 values per image).
    temp_x_train: batch of inputs passed to ``model`` unchanged.

  Side effects:
    Overwrites ``generated_test_samples.png`` in the working directory.
  """
  _,_,preds = model(temp_x_train)
  preds = tf.reshape(preds,[-1,28,28])
  # destandardization
  preds = (preds * dataset_std) + dataset_mean
  preds = preds * 255.

  # Size the grid from the batch so more than 16 inputs do not overflow a
  # fixed 4x4 layout (16 inputs still yield 4x4).
  num_images = preds.shape[0]
  grid = int(np.ceil(np.sqrt(num_images)))

  fig = plt.figure(figsize=(5, 5))
  for i in range(num_images):
    plt.subplot(grid, grid, i + 1)
    plt.imshow(preds[i, :, :], cmap='gray')
    plt.axis('off')

  plt.savefig('generated_test_samples.png')
  # Release the figure: this function is called once per epoch, and figures
  # that are never closed accumulate in memory.
  plt.close(fig)

In [None]:
model = VAE()

optimizer = tf.keras.optimizers.Adam(learning_rate = 0.001)

# Running means of each loss term, reset at the end of every epoch.
kl_loss_tracker = tf.keras.metrics.Mean(name='kl_loss')
mse_loss_tracker = tf.keras.metrics.Mean(name='mse_loss')


for epoch in range(100):

    for step,imgs in train_ds.enumerate():

        # training loop
        with tf.GradientTape() as tape:
            # forward pass
            z_mu,z_rho,decoded_imgs = model(imgs)

            # compute loss; 15. is the weight applied to the KL term
            mse,kl = elbo(z_mu,z_rho,decoded_imgs,imgs)
            loss = mse + 15. * kl

        # Differentiate and update only the trainable weights. model.variables
        # would also include any non-trainable state, which has no gradient.
        trainable_vars = model.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)

        # update weights
        optimizer.apply_gradients(zip(gradients, trainable_vars))

        # update metrics
        kl_loss_tracker.update_state(kl)
        mse_loss_tracker.update_state(mse)


    # generate 16 samples every epoch.
    generate_images(model,epoch,0)
    generate_training_images(model, x_test[:16])

    # display metrics at the end of each epoch.
    epoch_kl,epoch_mse = kl_loss_tracker.result(),mse_loss_tracker.result()
    print(f'epoch: {epoch}, mse: {epoch_mse:.4f}, kl_div: {epoch_kl:.4f}')

    # reset metric states
    kl_loss_tracker.reset_state()
    mse_loss_tracker.reset_state()