
# Generating FashionMNIST Images with a VAE in Keras

In this tutorial, we train a variational autoencoder (VAE) in Keras to generate images, using the FashionMNIST dataset as an example.

In [None]:
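import numpy as np
import matplotlib.pyplot as plt

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import backend as K
from tensorflow.keras.layers import Input, Flatten, Dense, Lambda, Reshape, Conv2D, Conv2DTranspose

# The custom loss defined later uses the encoder's z_mean and z_logvar tensors
# directly. In TF 2.x eager mode, Keras losses cannot consume such symbolic
# tensors, so we switch to graph mode (this must run before building the model).
tf.compat.v1.disable_eager_execution()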

## Load and Process the Dataset

In [None]:
(X_train, y_train), (X_test, y_test) = keras.datasets.fashion_mnist.load_data()

print(X_train.shape, '\t', y_train.shape)
print(X_test.shape, '\t', y_test.shape)

The pixel values lie between 0 and 255, so we normalise them to the range [0, 1]:

In [None]:
X_train = X_train.astype('float32') / 255
X_test = X_test.astype('float32') / 255

Next, we use `.reshape()` to bring the data into the format expected by TensorFlow's convolutional layers, i.e. (number of samples, height, width, number of channels):

In [None]:
X_train = X_train.reshape((-1,28,28,1))
X_test = X_test.reshape((-1,28,28,1))
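As a quick check of the target layout, the snippet below applies the same `reshape` to a dummy array (`images` is a placeholder standing in for `X_train`, not the real data). The `-1` lets NumPy infer the number of samples, and the trailing `1` adds a channel axis, which is equivalent to `np.expand_dims(..., axis=-1)`.

```python
import numpy as np

# Placeholder for the FashionMNIST arrays: 5 grayscale images of 28x28 pixels
images = np.zeros((5, 28, 28), dtype='float32')

# -1 lets NumPy infer the number of samples; the trailing 1 is the channel axis
reshaped = images.reshape((-1, 28, 28, 1))
print(reshaped.shape)  # (5, 28, 28, 1)

# np.expand_dims produces the same array by appending a new last axis
print(np.array_equal(reshaped, np.expand_dims(images, axis=-1)))  # True
```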

## Visualization of Samples
We plot a few samples from the training set:

In [None]:
plt.figure(1)
plt.subplot(221)
plt.imshow(X_train[20][:,:,0])

plt.subplot(222)
plt.imshow(X_train[300][:,:,0])

plt.subplot(223)
plt.imshow(X_train[4000][:,:,0])

plt.subplot(224)
plt.imshow(X_train[5000][:,:,0])
plt.show()

## Model Architecture
We now design the VAE model, which consists of an encoder, the latent space and a decoder. In terms of implementation, the latent space can be treated as part of the encoder.

### Encoder

In [None]:
enc_input = Input(shape=(28,28,1), name='encoder-input')  # layer names must not contain spaces
x = Conv2D(128, 5, padding='same', activation='relu')(enc_input)
x = Conv2D(64, 3, padding='same', strides=2, activation='relu')(x)
x = Conv2D(64, 3, padding='same', activation='relu')(x)
x = Conv2D(64, 3, padding='same', activation='relu')(x)

enc_shape = K.int_shape(x)
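`enc_shape` is used later to size the decoder, so it helps to know in advance what it will be. The sketch below assumes Keras's rule for `padding='same'` (output size is `ceil(input_size / stride)`, independent of the kernel size); the helper `conv_same_output_size` is hypothetical and only traces the spatial size through the four `Conv2D` layers above.

```python
import math

def conv_same_output_size(size, stride):
    # Conv2D with padding='same': output size is ceil(input_size / stride)
    return math.ceil(size / stride)

# Strides of the four Conv2D layers in the encoder
encoder_strides = [1, 2, 1, 1]

size = 28
for s in encoder_strides:
    size = conv_same_output_size(size, s)
print(size)  # 14
```

Since the last layer has 64 filters, `enc_shape` should therefore be `(None, 14, 14, 64)`.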

In [None]:
x = Flatten()(x)
x = Dense(32)(x)

#### Latent Space

In [None]:
latent_dim = 2 #2D space

z_mean = Dense(latent_dim, name='Z-mean')(x)
z_logvar = Dense(latent_dim, name='Z-logvariance')(x)

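We need a function that takes the mean and log-variance parameters and returns a random sample from the corresponding Gaussian distribution. Sampling as z = μ + σ·ε with ε ~ N(0, I) (the reparameterization trick) keeps the sample differentiable with respect to μ and log σ², so the encoder can be trained by backpropagation.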

In [None]:
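def sampling(args):
    # Reparameterization trick: z = mean + sigma * eps, with eps ~ N(0, I).
    # eps must have the same (batch, latent_dim) shape as mean, so that every
    # sample in the batch gets its own noise.
    mean, logvar = args
    eps = K.random_normal(shape=K.shape(mean))
    rnd_sam = mean + K.exp(logvar / 2) * eps
    return rnd_sam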
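The same sampling step can be mirrored in plain NumPy to check that it produces the intended distribution. This is a sketch outside the model: `sample_latent` and the chosen means and variances are illustrative. With `eps` drawn in the same shape as `mean`, the empirical mean and standard deviation of `z` should approach `mean` and `exp(logvar / 2)`.

```python
import numpy as np

rng = np.random.default_rng(0)

def sample_latent(mean, logvar, rng):
    # z = mean + sigma * eps with eps ~ N(0, I); eps has the same shape as mean,
    # so each row (sample) receives independent noise
    eps = rng.standard_normal(mean.shape)
    return mean + np.exp(logvar / 2) * eps

batch, latent_dim = 100_000, 2
mean = np.tile([1.0, -2.0], (batch, 1))
logvar = np.tile([np.log(4.0), 0.0], (batch, 1))  # variances 4 and 1

z = sample_latent(mean, logvar, rng)
print(z.mean(axis=0))  # close to [1, -2]
print(z.std(axis=0))   # close to [2, 1]
```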

By wrapping this function in a Lambda layer, we can define the latent space as follows:

In [None]:
z = Lambda(sampling, output_shape=(latent_dim,), name='latent-space')([z_mean, z_logvar])

In [None]:
encoder = keras.Model(enc_input, z, name='encoder')
encoder.summary()

### Decoder
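The decoder takes a sampled 2D latent vector and maps it back to the original image format, i.e. 28×28 pixels with a single channel. It mirrors the encoder: a Dense layer and a Reshape recover the feature-map shape stored in `enc_shape`, and transposed convolutions upsample it back to full resolution.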

In [None]:
dec_input = Input(shape=(latent_dim,), name='decoder-input')

true_shape = enc_shape[1:]

y = Dense(np.prod(true_shape))(dec_input)
y = Reshape(target_shape=true_shape)(y)
y = Conv2DTranspose(64, 3, padding='same', activation='relu')(y)
y = Conv2DTranspose(64, 3, padding='same', activation='relu')(y)
y = Conv2DTranspose(64, 3, strides=2, padding='same', activation='relu')(y)
y = Conv2DTranspose(128, 5, padding='same', activation='relu')(y)
# sigmoid keeps the output pixels in [0, 1], matching the normalised inputs
y = Conv2DTranspose(1, 5, padding='same', activation='sigmoid')(y)

In [None]:
decoder = keras.Model(dec_input, y, name='decoder')
decoder.summary()
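The decoder has to undo the encoder's downsampling exactly, or its output will not match the 28×28 input. Assuming Keras's rule for `Conv2DTranspose` with `padding='same'` and no `output_padding` (output size is `input_size * stride`), the hypothetical helper below traces the spatial size through the five transposed convolutions:

```python
def conv_transpose_same_output_size(size, stride):
    # Conv2DTranspose with padding='same' (no output_padding): input_size * stride
    return size * stride

# Strides of the five Conv2DTranspose layers in the decoder
decoder_strides = [1, 1, 2, 1, 1]

size = 14  # spatial size after Reshape(target_shape=true_shape)
for s in decoder_strides:
    size = conv_transpose_same_output_size(size, s)
print(size)  # 28
```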

### Connecting All Components

In [None]:
enc_output = encoder(enc_input)
dec_output = decoder(enc_output)


vae = keras.Model(enc_input, dec_output, name='VAE')
vae.summary()

## Training
First, we define a custom loss function: the sum of a reconstruction loss (a scaled mean squared error between input and output images) and the Kullback-Leibler (KL) divergence between the encoder's latent distribution and a standard normal prior.
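Written out per sample, with $\mu$ = `z_mean`, $\log\sigma^2$ = `z_logvar`, $d$ = `latent_dim` and the mean taken over the $28 \times 28 \times 1 = 784$ pixels, the intended loss is given below. The second term is the closed-form KL divergence between $\mathcal{N}(\mu, \operatorname{diag}\sigma^2)$ and $\mathcal{N}(0, I)$, and the factor 100 is the `reconstruction_loss_factor` used in the code.

$$
\mathcal{L}(x) = 100 \cdot \frac{1}{784} \sum_{i=1}^{784} \left(x_i - \hat{x}_i\right)^2 \; - \; \frac{1}{2} \sum_{j=1}^{d} \left(1 + \log\sigma_j^2 - \mu_j^2 - \sigma_j^2\right)
$$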

In [None]:
def loss_func(z_mean, z_logvar):
    # z_mean and z_logvar are the encoder's latent tensors; they are captured by
    # closure because Keras loss functions only receive (y_true, y_predict).

    def vae_reconstruction_loss(y_true, y_predict):
        # Per-sample mean squared error over height, width and channels,
        # weighted by a tunable factor relative to the KL term
        reconstruction_loss_factor = 100
        reconstruction_loss = K.mean(K.square(y_true - y_predict), axis=[1, 2, 3])
        return reconstruction_loss_factor * reconstruction_loss

    def vae_kl_loss(z_mean, z_logvar):
        # Closed-form KL(N(z_mean, exp(z_logvar)) || N(0, I)), summed over latent dims
        kl_loss = -0.5 * K.sum(1.0 + z_logvar - K.square(z_mean) - K.exp(z_logvar), axis=1)
        return kl_loss

    def vae_loss(y_true, y_predict):
        reconstruction_loss = vae_reconstruction_loss(y_true, y_predict)
        # The KL term depends on the latent parameters, not on the images
        kl_loss = vae_kl_loss(z_mean, z_logvar)

        loss = reconstruction_loss + kl_loss
        return loss

    return vae_loss
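To sanity-check the KL expression used in `vae_kl_loss`, the sketch below compares it with a Monte Carlo estimate of $\mathrm{KL}(q \,\|\, p) = \mathbb{E}_{z \sim q}[\log q(z) - \log p(z)]$ for one arbitrary choice of $\mu$ and $\log\sigma^2$. It uses NumPy only; the function names and test values are illustrative, not part of the model.

```python
import numpy as np

def kl_closed_form(mu, logvar):
    # Same expression as vae_kl_loss, summed over the latent dimensions
    return -0.5 * np.sum(1.0 + logvar - mu**2 - np.exp(logvar), axis=-1)

def kl_monte_carlo(mu, logvar, n, rng):
    # Average of log q(z) - log p(z) over z ~ q = N(mu, diag(exp(logvar))),
    # with p = N(0, I); both log-densities are full Gaussian log-pdfs
    sigma = np.exp(logvar / 2)
    z = mu + sigma * rng.standard_normal((n, mu.size))
    log_q = -0.5 * np.sum(np.log(2 * np.pi) + logvar + ((z - mu) / sigma) ** 2, axis=1)
    log_p = -0.5 * np.sum(np.log(2 * np.pi) + z ** 2, axis=1)
    return np.mean(log_q - log_p)

rng = np.random.default_rng(0)
mu = np.array([0.5, -1.0])
logvar = np.array([np.log(0.25), np.log(2.0)])

print(kl_closed_form(mu, logvar))                # about 1.0966
print(kl_monte_carlo(mu, logvar, 200_000, rng))  # close to the closed form
print(kl_closed_form(np.zeros(2), np.zeros(2)) == 0)  # True: q equals the prior
```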

We can now compile and train the model.

In [None]:
opt = keras.optimizers.Adam(learning_rate=0.0001)
vae.compile(optimizer=opt, loss=loss_func(z_mean, z_logvar))

In [None]:
history = vae.fit(X_train, X_train, epochs=20, batch_size=32, validation_data=(X_test, X_test))

The hyperparameters above (learning rate, batch size, number of epochs and the reconstruction-loss factor of 100) were chosen with only naive tuning. The loss value by itself does not tell us whether the reconstructions are good enough; the most direct check is to compare a few test images with their reconstructions.

## Visualization of Test samples

In [None]:
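# index*5 is also plotted, so index must stay below 2000 (X_test has 10000 images)
index = int(input('Index of a test image (0-1999): '))

# Reconstruct the whole test set; z is sampled, so outputs vary slightly between runs
y_pred = vae.predict(X_test)

plt.figure(1)
plt.subplot(221)
plt.imshow(X_test[index].reshape(28,28))
plt.title('Original')

plt.subplot(222)
plt.imshow(y_pred[index].reshape(28,28))
plt.title('Reconstruction')

plt.subplot(223)
plt.imshow(X_test[index*5].reshape(28,28))

plt.subplot(224)
plt.imshow(y_pred[index*5].reshape(28,28))
plt.show()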

As seen above, the model reconstructs the general shape of each clothing item, but finer details such as text or patterns are lost. For our purposes, this is satisfactory.
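Because the KL term pulls the latent codes towards the N(0, I) prior, the decoder on its own can be fed new latent vectors to generate images. A minimal sketch, assuming the 2D latent space used here: the helper `latent_grid` and the span of ±2 are illustrative choices, not part of the original notebook, and the decoder calls appear only as comments.

```python
import numpy as np

def latent_grid(n, span=2.0):
    # n x n evenly spaced points in [-span, span]^2: a square patch of the
    # 2D latent space centred on the mean of the N(0, I) prior
    coords = np.linspace(-span, span, n)
    zx, zy = np.meshgrid(coords, coords)
    return np.stack([zx.ravel(), zy.ravel()], axis=1)

grid = latent_grid(10)
print(grid.shape)  # (100, 2)

# With the trained decoder from this notebook, the grid could be decoded as:
#   images = decoder.predict(grid)   # shape (100, 28, 28, 1)
# Random draws from the prior work the same way:
#   images = decoder.predict(np.random.normal(size=(16, 2)))
```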