In [None]:
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import datetime, os
import matplotlib.pyplot as plt

In [None]:
# Pin the global TensorFlow and NumPy seeds for reproducibility of the random
# draws made further down (weight init, latent sampling, image generation).
tf.random.set_seed(1)
np.random.seed(1)

In [None]:
# define parameters
latent_dim = 50    # width of the z_mean / z_log_var heads and of the decoder input
EPOCHS = 30        # handed to fit() as `epochs`
BATCH_SIZE = 128   # handed to fit() as `batch_size`

In [None]:
# load dataset
# Labels are discarded: the VAE is trained on the images only.
(x_train, _), (x_test, _) = keras.datasets.mnist.load_data()
print(x_train.shape, x_test.shape)

# Stack the train and test images along the sample axis, then cast to float32
# and divide by 255 to rescale the pixel values.
mnist_digits = np.concatenate((x_train, x_test)).astype("float32") / 255
print(mnist_digits.shape)

In [None]:
class Sampling(layers.Layer):
    """Reparameterization layer: returns ``mean + exp(0.5 * log_var) * noise``.

    ``noise`` is drawn from a standard normal with the same leading two
    dimensions as the mean tensor, so the sample stays differentiable with
    respect to both inputs.
    """

    def call(self, inputs):
        """Sample one latent vector per row.

        Args:
            inputs: a pair ``(z_mean, z_log_var)``. Axes 0 and 1 of ``z_mean``
                are read to size the noise tensor.

        Returns:
            Tensor computed as ``z_mean + exp(0.5 * z_log_var) * epsilon``.
        """
        mean, log_var = inputs
        n_rows = tf.shape(mean)[0]
        n_cols = tf.shape(mean)[1]
        noise = tf.keras.backend.random_normal(shape=(n_rows, n_cols))
        # exp(0.5 * log-variance) turns the log-variance into a standard deviation.
        std_dev = tf.exp(0.5 * log_var)
        return mean + std_dev * noise

In [None]:
def create_encoder(summary=False):
  """Build the encoder: a (28, 28) input fed through an LSTM into two dense heads.

  Args:
    summary: when truthy, print the Keras summary of the built model.

  Returns:
    A ``keras.Model`` named "encoder" whose outputs are
    ``[z_mean, z_log_var, z]``, where ``z`` comes from the ``Sampling`` layer.
  """
  image_in = keras.Input(shape=(28, 28))
  hidden = layers.LSTM(128)(image_in)

  # Two parallel projections of the LSTM output, each `latent_dim` wide.
  mu = layers.Dense(latent_dim, name="z_mean")(hidden)
  log_var = layers.Dense(latent_dim, name="z_log_var")(hidden)
  sample = Sampling()([mu, log_var])

  model = keras.Model(inputs=image_in, outputs=[mu, log_var, sample], name="encoder")
  if summary:
    model.summary()
  return model

In [None]:
def create_decoder(summary=False):
  """Build the decoder: latent vector -> dense -> 7x7x32 map -> transposed convs -> (28, 28).

  Args:
    summary: when truthy, print the Keras summary of the built model.

  Returns:
    A ``keras.Model`` named "decoder" taking a ``(latent_dim,)`` input and
    producing a ``(28, 28)`` output through a final sigmoid layer.
  """
  latent_in = keras.Input(shape=(latent_dim,))
  feature_map = layers.Dense(7 * 7 * 32, activation="relu")(latent_in)
  feature_map = layers.Reshape((7, 7, 32))(feature_map)

  # (filters, activation, stride) per transposed convolution; the two stride-2
  # layers upsample 7 -> 14 -> 28, the last one collapses to a single channel.
  conv_spec = ((32, "relu", 2), (16, "relu", 2), (1, "sigmoid", 1))
  for n_filters, act, stride in conv_spec:
    feature_map = layers.Conv2DTranspose(
        n_filters, 3, activation=act, strides=stride, padding="same"
    )(feature_map)

  image_out = layers.Reshape((28, 28))(feature_map)
  model = keras.Model(inputs=latent_in, outputs=image_out, name="decoder")
  if summary:
    model.summary()
  return model


In [None]:
class VAE(keras.Model):
    """Variational autoencoder trained through a custom ``train_step``.

    The loss is ``reconstruction_loss + kl_loss_weight * kl_loss``; the three
    terms are tracked by running-mean metrics and reported during ``fit``.

    Args:
        encoder: model returning ``(z_mean, z_log_var, z)`` for a batch.
        decoder: model mapping ``z`` back to a reconstruction of the input.
        kl_loss_weight: multiplier applied to the KL term in the total loss.
            Defaults to 1.0, i.e. the unweighted sum.
        **kwargs: forwarded to ``keras.Model.__init__``.
    """

    def __init__(self, encoder, decoder, kl_loss_weight=1.0, **kwargs):
        super().__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder
        self.kl_loss_weight = kl_loss_weight
        self.total_loss_tracker = keras.metrics.Mean(name="total_loss")
        self.reconstruction_loss_tracker = keras.metrics.Mean(
            name="reconstruction_loss"
        )
        self.kl_loss_tracker = keras.metrics.Mean(name="kl_loss")

    @property
    def metrics(self):
        """Return the three loss trackers updated by ``train_step``."""
        return [
            self.total_loss_tracker,
            self.reconstruction_loss_tracker,
            self.kl_loss_tracker,
        ]

    def train_step(self, data):
        """Run one optimization step on a batch and return the running losses.

        Args:
            data: a batch of inputs, or a tuple whose first element is the
                batch (as produced when ``fit`` is given labels or a dataset
                of ``(x, y)`` pairs). Only the inputs are used.

        Returns:
            Dict with the running means of "loss", "reconstruction_loss" and
            "kl_loss".
        """
        # The VAE is unsupervised: when fit() supplies (x, y[, sample_weight]),
        # keep only x instead of passing the whole tuple to the encoder.
        if isinstance(data, tuple):
            data = data[0]

        with tf.GradientTape() as tape:
            z_mean, z_log_var, z = self.encoder(data)
            reconstruction = self.decoder(z)
            # binary_crossentropy averages over the last axis, so a trailing
            # axis of size 1 is added to get one value per pixel; those are
            # summed over axes 1 and 2 and then averaged over the batch.
            per_pixel_bce = keras.losses.binary_crossentropy(
                tf.expand_dims(data, -1), tf.expand_dims(reconstruction, -1)
            )
            reconstruction_loss = tf.reduce_mean(
                tf.reduce_sum(per_pixel_bce, axis=(1, 2))
            )
            # KL term per latent dimension, summed over the latent axis and
            # averaged over the batch.
            kl_loss = -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))
            kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))
            total_loss = reconstruction_loss + self.kl_loss_weight * kl_loss
        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.kl_loss_tracker.update_state(kl_loss)
        return {
            "loss": self.total_loss_tracker.result(),
            "reconstruction_loss": self.reconstruction_loss_tracker.result(),
            "kl_loss": self.kl_loss_tracker.result(),
        }

In [None]:
# Load the TensorBoard notebook extension so the %tensorboard magic can be used.
%load_ext tensorboard

In [None]:
# tensorboard to view loss graphs
# Each run logs into its own timestamped sub-directory of "logs".
run_stamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
logdir = os.path.join("logs", run_stamp)
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=logdir, histogram_freq=1)
print(logdir)

In [None]:
# run tensorboard
# Points TensorBoard at the top-level "logs" directory, i.e. the parent of the
# timestamped run directory that the callback writes to.
%tensorboard --logdir logs

In [None]:
def train_model():
  """Build a fresh encoder/decoder pair, wrap them in a VAE and fit it.

  Reads the notebook-level ``mnist_digits``, ``EPOCHS``, ``BATCH_SIZE`` and
  ``tensorboard_callback``. Both sub-model summaries are printed.

  Returns:
    The fitted ``VAE`` instance.
  """
  # The encoder is built (and summarized) before the decoder.
  model = VAE(create_encoder(True), create_decoder(True))
  model.compile(optimizer=keras.optimizers.Adam())
  model.fit(
      mnist_digits,
      epochs=EPOCHS,
      batch_size=BATCH_SIZE,
      callbacks=[tensorboard_callback],
  )
  return model


In [None]:
def generate_images(vae, num_row=10, num_col=10, seed=0):
  """Decode random latent vectors and show them as a grid of images.

  Args:
    vae: object exposing a ``decoder`` with a Keras-style ``predict`` method.
    num_row: number of grid rows.
    num_col: number of grid columns.
    seed: seed for the latent draws; the default reproduces the vectors
      obtained from ``np.random.seed(0)`` followed by ``np.random.normal``.

  Returns:
    None. The figure is displayed with ``plt.show()``.
  """
  num = num_row * num_col  # one latent vector per grid cell
  digit_size = 28

  # draw std. normal variables from latent space
  # A private RandomState yields the same stream as seeding the global
  # generator, without resetting the global NumPy seed set by the notebook.
  rng = np.random.RandomState(seed)
  rand_vars = rng.normal(size=(num, latent_dim))

  # Decode the whole batch in one predict() call rather than one call per image.
  decoded = vae.decoder.predict(rand_vars)

  # plot images
  # squeeze=False keeps `axes` two-dimensional even for a single row or column.
  fig, axes = plt.subplots(
      num_row, num_col, figsize=(1.5 * num_col, 2 * num_row), squeeze=False
  )
  for i in range(num):
    digit = decoded[i].reshape(digit_size, digit_size)
    ax = axes[i // num_col, i % num_col]
    ax.imshow(digit, cmap='gray')
  plt.tight_layout()
  plt.show()

In [None]:
# Build, compile and fit the VAE; the returned model is reused by the cells below.
vae = train_model()

In [None]:
# Decode random latent vectors with the trained decoder and plot them as a grid.
generate_images(vae)

In [None]:
# save models
# The encoder and decoder are saved separately, each under its own path.
for model_dir, submodel in (("VAE_encoder", vae.encoder), ("VAE_decoder", vae.decoder)):
  submodel.save(model_dir)

In [None]:
# zip saved model folders to download
# NOTE(review): the absolute /content/... paths presumably assume a Colab
# runtime whose working directory is /content, which is where the relative
# "VAE_encoder" / "VAE_decoder" save paths would land — confirm before running
# this cell in another environment.
!zip -r /content/VAE_encoder.zip /content/VAE_encoder
!zip -r /content/VAE_decoder.zip /content/VAE_decoder