<a href="https://colab.research.google.com/github/luiscunhacsc/udemy-ai-en/blob/main/part4_generative_ai/VAE_Gradio/VAE_with_Gradio_and_MNIST.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
%pip install -q tensorflow gradio numpy matplotlib


In [None]:
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Input, Dense, Lambda, Conv2D, Flatten, Conv2DTranspose, Reshape, Layer
from tensorflow.keras.models import Model
from tensorflow.keras.losses import BinaryCrossentropy
from tensorflow.keras.datasets import mnist
import gradio as gr
import matplotlib.pyplot as plt

# Custom Sampling Layer
class Sampling(Layer):
    """Reparameterization-trick layer that samples a latent vector.

    Called with ``[z_mean, z_log_var]``, it returns
    ``z_mean + exp(0.5 * z_log_var) * epsilon`` where ``epsilon`` is
    standard-normal noise with the same shape as ``z_mean``.
    """

    def call(self, inputs):
        z_mean, z_log_var = inputs
        # Draw the noise with z_mean's dynamic shape and dtype so the layer
        # does not assume a rank-2 latent tensor or a float32 compute dtype.
        # tf.random.normal is used instead of the Keras backend helper so the
        # layer does not depend on the legacy backend API.
        epsilon = tf.random.normal(shape=tf.shape(z_mean), dtype=z_mean.dtype)
        # exp(0.5 * log_var) is the standard deviation.
        return z_mean + tf.exp(0.5 * z_log_var) * epsilon

# Custom VAE Loss Layer
class VAELoss(Layer):
    """Pass-through layer that registers the VAE objective via ``add_loss``.

    Called with ``[x, x_decoded, z_mean, z_log_var]``. It adds
    ``mean(reconstruction + KL)`` over the batch to the layer's losses and
    returns ``x_decoded`` unchanged, so it can sit at the end of a functional
    model that is compiled without an explicit loss.
    """

    def call(self, inputs):
        x, x_decoded, z_mean, z_log_var = inputs

        # Flatten every sample to a vector so the reconstruction term is
        # computed per sample and the feature count comes from the data
        # rather than a hard-coded 28 * 28.
        batch = tf.shape(x)[0]
        x_flat = tf.reshape(x, (batch, -1))
        decoded_flat = tf.reshape(x_decoded, (batch, -1))

        # binary_crossentropy averages over the last axis; multiplying by the
        # number of features turns that mean into a per-sample sum.
        per_sample_bce = tf.keras.losses.binary_crossentropy(x_flat, decoded_flat)
        n_features = tf.cast(tf.shape(x_flat)[-1], per_sample_bce.dtype)
        reconstruction_loss = per_sample_bce * n_features

        # KL divergence between N(z_mean, exp(z_log_var)) and N(0, I),
        # summed over the latent dimensions (one value per sample).
        kl_loss = -0.5 * tf.reduce_sum(
            1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var),
            axis=-1,
        )

        # Both terms have one value per sample; average them over the batch.
        total_loss = tf.reduce_mean(reconstruction_loss + kl_loss)
        self.add_loss(total_loss)
        return x_decoded

# Load MNIST (labels are discarded), convert to float32, divide by 255 and
# append a trailing channel axis so samples match the (28, 28, 1) model input.
(x_train, _), (x_test, _) = mnist.load_data()
x_train = (x_train.astype('float32') / 255.0)[..., np.newaxis]
x_test = (x_test.astype('float32') / 255.0)[..., np.newaxis]

# Size of the latent space. generate_sample and the two Gradio sliders below
# work with exactly two latent coordinates.
latent_dim = 2

# Encoder: two stride-2 'same' convolutions (28 -> 14 -> 7 spatially), a small
# dense layer, then two linear heads for the latent mean and log-variance.
inputs = Input(shape=(28, 28, 1))
conv_1 = Conv2D(filters=32, kernel_size=3, strides=2, padding='same', activation='relu')(inputs)
conv_2 = Conv2D(filters=64, kernel_size=3, strides=2, padding='same', activation='relu')(conv_1)
flat = Flatten()(conv_2)
x = Dense(units=16, activation='relu')(flat)
z_mean = Dense(latent_dim)(x)
z_log_var = Dense(latent_dim)(x)
# Sample z from the predicted distribution.
z = Sampling()([z_mean, z_log_var])

# The encoder exposes the distribution parameters and the sampled z.
encoder = Model(inputs, [z_mean, z_log_var, z], name='encoder')

# Decoder: project the latent vector to a 7x7x64 feature map, upsample twice
# with stride-2 transposed convolutions (7 -> 14 -> 28), then map to a single
# sigmoid channel.
latent_inputs = Input(shape=(latent_dim,))
decoder_hidden = (
    Dense(7 * 7 * 64, activation='relu'),
    Reshape((7, 7, 64)),
    Conv2DTranspose(64, 3, activation='relu', strides=2, padding='same'),
    Conv2DTranspose(32, 3, activation='relu', strides=2, padding='same'),
)
x = latent_inputs
for layer in decoder_hidden:
    x = layer(x)
outputs = Conv2DTranspose(1, 3, activation='sigmoid', padding='same')(x)

decoder = Model(latent_inputs, outputs, name='decoder')

# VAE: encoder -> decoder, with VAELoss attaching the training objective.
# The encoder is called exactly once and its three outputs are unpacked.
# Calling encoder(inputs) separately for each output would add three encoder
# nodes to the graph and run the encoder (and the sampling) three times.
vae_z_mean, vae_z_log_var, vae_z = encoder(inputs)
outputs = decoder(vae_z)
vae_outputs = VAELoss()([inputs, outputs, vae_z_mean, vae_z_log_var])
vae = Model(inputs, vae_outputs, name='vae')

# No loss argument: the objective is the one VAELoss registers via add_loss.
vae.compile(optimizer='adam')
vae.summary()

# Train the VAE. No targets are passed (the loss is attached inside the
# model), so the validation tuple carries None in the target position.
vae.fit(
    x=x_train,
    batch_size=128,
    epochs=30,
    validation_data=(x_test, None),
)

# Function to visualize the latent space (placeholder, not implemented yet; a 2D plot is feasible here because latent_dim is 2)
def plot_latent_space():
    """Placeholder for a latent-space visualization; currently does nothing."""
    return None

# Function to generate new samples
def generate_sample(z1, z2):
    """Decode a single latent point into a grayscale image.

    Args:
        z1: First latent coordinate.
        z2: Second latent coordinate.

    Returns:
        A 2-D NumPy array (height, width) holding the decoder output for the
        latent point ``(z1, z2)``.
    """
    # Batch of one latent vector, with an explicit float32 dtype so integer
    # slider values do not produce an integer array.
    z_sample = np.array([[z1, z2]], dtype=np.float32)
    # verbose=0: this runs on every slider movement (the interface is live),
    # so the per-call progress bar would flood the output.
    x_decoded = decoder.predict(z_sample, verbose=0)
    # Drop the batch axis and the trailing channel axis instead of
    # hard-coding a 28x28 reshape.
    return x_decoded[0, :, :, 0]

# Gradio interface: two sliders feed generate_sample and the decoded digit is
# shown as a grayscale image. live=True re-runs the function on every change.
latent_sliders = [
    gr.Slider(minimum=-4, maximum=4, value=0, label="Latent Variable 1"),
    gr.Slider(minimum=-4, maximum=4, value=0, label="Latent Variable 2"),
]
image_output = gr.Image(image_mode='L')

iface = gr.Interface(
    fn=generate_sample,
    inputs=latent_sliders,
    outputs=image_output,
    live=True,
    description="Move the sliders to change the latent variables and generate different digit images.",
)

iface.launch()


Model: "vae"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 input_5 (InputLayer)        [(None, 28, 28, 1)]          0         []                            
                                                                                                  
 encoder (Functional)        [(None, 2),                  69076     ['input_5[0][0]',             
                              (None, 2),                             'input_5[0][0]',             
                              (None, 2)]                             'input_5[0][0]']             
                                                                                                  
 decoder (Functional)        (None, 28, 28, 1)            65089     ['encoder[0][2]']             
                                                                                                

