<a href="https://colab.research.google.com/github/martinpius/keras_Functional_API_architecture/blob/main/Variational_AutoEncoder's_From_scratch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from google.colab import drive
try:
  drive.mount('/content/drive/', force_remount = True)
  COLAB = True
  import tensorflow as tf
  print(f"You are using google colab with tensoflow version: {tf.__version__}")
except Exception as e:
  COLAB = False
  print(f"{type(e)}: {e}\n...load your drive...")


Mounted at /content/drive/
You are using google colab with tensoflow version: 2.3.0


In [2]:
#We are going to train the mnist dataset on a Variational autoencoder model
#This model simply try to represent the data in a more compact form such that
#through sampling we can reconstruct back the original data
#Example we can take input as an image and convolves to obtain the pixels (decoding)
#Then from the pixels we can sample (assuming Gaussian distribution to reconstruct the original image)
#The autoencoder model is trainee in three few steps
#First we build an encoder model (to encode the original information)
#Second we build a decoder network to reconstruct the original information
#Finally we combine the two netwoks to train it frromnehnd to end


In [3]:
import tensorflow as tf
import numpy as np

In [4]:
#The sampling mechanism (layer's subclassing)
class Sampling(tf.keras.layers.Layer):
  def call(self, inputs):
    '''This method will return the inputs to a decoder network'''
    z_mean, z_log_var = inputs
    batch = tf.shape(z_mean)[0]
    dim = tf.shape(z_mean)[1]
    eps = tf.keras.backend.random_normal(shape = (batch, dim))
    return z_mean + tf.exp(0.5 * z_log_var) * eps



In [5]:
#The encoder's network (layer's subclassing)

In [6]:
class Encoder(tf.keras.layers.Layer):
  def __init__(self, latent = 32, intermediate = 64, name = 'encoder', **kwargs):
    super(Encoder, self).__init__(name = name, **kwargs)
    self.dense_1 = tf.keras.layers.Dense(units = intermediate, activation = 'relu', kernel_initializer = 'random_normal')
    self.z_mean = tf.keras.layers.Dense(units = latent, activation = 'relu', kernel_initializer = 'random_normal')
    self.z_log_var = tf.keras.layers.Dense(units = latent, activation = 'relu', kernel_initializer = 'random_normal')
    self.sampling = Sampling()
  
  def call(self, inputs):
    x = self.dense_1(inputs)
    z_mean = self.z_mean(x)
    z_log_var = self.z_log_var(x)
    z = self.sampling((z_mean, z_log_var))
    return z_mean, z_log_var, z
    

In [7]:
#The decoder's network (layer's subclassing)
class Decoder(tf.keras.layers.Layer):
  def __init__(self, original, intermediate = 64, name = 'decoder', **kwargs):
    super(Decoder, self).__init__(name = name, **kwargs)
    self.dense_1 = tf.keras.layers.Dense(units = intermediate, activation = 'relu', kernel_initializer = 'random_normal')
    self.out_decoder = tf.keras.layers.Dense(units = original, activation = 'sigmoid')
  
  def call(self, inputs):
    x = self.dense_1(inputs)
    return self.out_decoder(x)
  

In [8]:
#Combining both of the above
class VariationalAutoEncoder(tf.keras.Model):
  def __init__(self, original, latent = 32, intermediate = 64, name = 'vae',**kwargs):
    super(VariationalAutoEncoder, self).__init__(name = name, **kwargs)
    self.original = original
    self.encoder = Encoder(latent = latent, intermediate = intermediate)
    self.decoder = Decoder(original = original, intermediate = intermediate)
  
  def call(self, inputs):
    z_mean, z_log_var, z = self.encoder(inputs)
    reconstructed = self.decoder(z)
    kl_loss = -0.5 * tf.reduce_mean(z_log_var - tf.square(z_mean) - tf.exp(z_log_var) + 1)
    self.add_loss(kl_loss)
    return reconstructed

In [9]:
original = 784
vae = VariationalAutoEncoder(original, 32, 64)

In [10]:
#Get the train data from keras-mnist digits
(x_train, _), (_, _) = tf.keras.datasets.mnist.load_data()

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz


In [11]:
display(x_train.shape)

(60000, 28, 28)

In [12]:
#reshaping,preprocessing, and converting into a tensorflow dataset
x_train = x_train.reshape(60000, 784).astype('float32')/255.0

In [13]:
train_dfm = tf.data.Dataset.from_tensor_slices(x_train)
train_dfm = train_dfm.shuffle(buffer_size = 1024).batch(64)

In [14]:
#We can now preparing the training loop for our autoencoder
epochs = 10
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)
loss_fn = tf.keras.losses.MeanSquaredError()
metric_fn = tf.keras.metrics.Mean()

In [15]:
#Iterate over the epochs
for epoch in range(epochs):
  print(f"The start of epoch {epoch}.....please wait....")
  #iterate over the training batches
  for step, train_batch in enumerate(train_dfm):
    with tf.GradientTape() as tape:
      reconstructed = vae(train_batch)
      loss = loss_fn(train_batch, reconstructed)
      loss+=sum(vae.losses) #Combining the kl-loss to the main loss.(Note that we can train the model with either of the loss)
    grads = tape.gradient(loss, vae.trainable_weights)
    optimizer.apply_gradients(zip(grads, vae.trainable_weights))
    metric_fn(loss)
    if step % 50 ==0:
      print(f"Step: {step} mean-loss: {metric_fn.result():.4f}")
  

The start of epoch 0.....please wait....
Step: 0 mean-loss: 0.2374
Step: 50 mean-loss: 0.1872
Step: 100 mean-loss: 0.1450
Step: 150 mean-loss: 0.1221
Step: 200 mean-loss: 0.1096
Step: 250 mean-loss: 0.1017
Step: 300 mean-loss: 0.0961
Step: 350 mean-loss: 0.0924
Step: 400 mean-loss: 0.0894
Step: 450 mean-loss: 0.0871
Step: 500 mean-loss: 0.0850
Step: 550 mean-loss: 0.0834
Step: 600 mean-loss: 0.0822
Step: 650 mean-loss: 0.0811
Step: 700 mean-loss: 0.0800
Step: 750 mean-loss: 0.0792
Step: 800 mean-loss: 0.0785
Step: 850 mean-loss: 0.0778
Step: 900 mean-loss: 0.0772
The start of epoch 1.....please wait....
Step: 0 mean-loss: 0.0768
Step: 50 mean-loss: 0.0763
Step: 100 mean-loss: 0.0759
Step: 150 mean-loss: 0.0756
Step: 200 mean-loss: 0.0753
Step: 250 mean-loss: 0.0750
Step: 300 mean-loss: 0.0746
Step: 350 mean-loss: 0.0744
Step: 400 mean-loss: 0.0742
Step: 450 mean-loss: 0.0740
Step: 500 mean-loss: 0.0737
Step: 550 mean-loss: 0.0734
Step: 600 mean-loss: 0.0733
Step: 650 mean-loss: 0.0731
