# Lab 02 : Variational AutoEncoders (VAE) for MNIST Images -- exercise


In [None]:
# For Google Colaboratory
import sys, os
if 'google.colab' in sys.modules:
    # mount google drive
    from google.colab import drive
    drive.mount('/content/gdrive')
    path_to_file = '/content/gdrive/My Drive/CS5242_2025_codes/labs_lecture08/lab02_vae_image'
    print(path_to_file)
    # move to Google Drive directory
    os.chdir(path_to_file)
    !pwd

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import utils
import time
import random

# Libraries
import matplotlib.pyplot as plt
import logging
logging.getLogger().setLevel(logging.CRITICAL) # remove warnings

# PyTorch version and GPU
print(torch.__version__)
has_cuda = torch.cuda.is_available()
if has_cuda:
    print(torch.cuda.get_device_name(0))
# Run on the GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda" if has_cuda else "cpu")
print(device)


### MNIST dataset 

In [None]:
from utils import check_mnist_dataset_exists
data_path = check_mnist_dataset_exists()

# Training images and labels are stored as serialized PyTorch tensors
mnist_dir = data_path + 'mnist/'
train_data = torch.load(mnist_dir + 'train_data.pt')
train_label = torch.load(mnist_dir + 'train_label.pt')
print(train_data.size())

### VAE with Convolutions and Transposed Convolutions

The VAE encoder is designed as follows:
* It begins with a convolutional layer that reduces the input image size from 1 x n x n (grayscale) to d x n/2 x n/2, where d is the hidden dimension. 
* A second convolutional layer further downsamples the feature map from d x n/2 x n/2 to d x n/4 x n/4.
* The resulting feature map is then flattened and passed through a linear layer that maps it to a vector of size d.
* Two additional linear layers are used to produce dz-dimensional vectors for the mean and the log-variance of the Gaussian distribution q(z|x). 
* 2D batch normalization (after each convolutional layer) and layer normalization (after the first linear layer) are applied before the ReLU activation functions. 
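* The final output of the encoder is a latent vector z, which is sampled from the learned Gaussian distribution with the reparameterization trick: $z = \mu + \sigma \odot \epsilon$ with $\epsilon \sim \mathcal{N}(0, I)$ and $\sigma = \exp(\tfrac{1}{2}\log\sigma^2)$, so that gradients can flow through the sampling step.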

The VAE decoder reconstructs the image from the latent representation using a symmetric process:
* It starts with two linear layers that map the latent vector z to a feature map of size d x n/4 x n/4. 
* The decoder then applies two transposed convolutional layers to upsample the feature map from d x n/4 x n/4 to d x n/2 x n/2 and finally to the original image size of 1 x n x n. 
* The reconstructed image, x_hat, is obtained by applying a sigmoid activation function, ensuring the pixel values are constrained to the range [0,1].

Hints: You may use PyTorch modules `nn.Conv2d`, `nn.ConvTranspose2d`, `nn.BatchNorm2d`, `nn.LayerNorm`, `torch.randn`, and `torch.sigmoid`.


In [None]:
# Global constants
n = train_data.size(1) # n : nb of pixels along each spatial dimension
dz = 36 # dz : latent dimension
d = 256 # d : hidden dimension
b = 250 # b : batch size

# The VAE halves the spatial size twice (n -> n/2 -> n/4) and the decoder mirrors it,
# so the images must be square and n must be divisible by 4 for the shapes to round-trip.
assert train_data.size(1) == train_data.size(2), 'expected square n x n images'
assert n % 4 == 0, 'n must be divisible by 4 for the two stride-2 (de)convolutions'


In [None]:
# Define  VAE architecture
class VAE(nn.Module):
    """Convolutional variational autoencoder for n x n grayscale images.

    The sizes are read from the notebook globals at construction time:
    ``n`` (image side, must be divisible by 4), ``d`` (hidden dimension)
    and ``dz`` (latent dimension).

    forward(x, train=True)
        x     : [b, n, n] images (also accepts [b, n*n] or [b, 1, n, n]); values
                presumably in [0, 1] since the notebook trains with BCELoss.
        train : True  -> encode x and sample z ~ q(z|x) via the reparameterization trick.
                False -> ignore the content of x (only x.size(0) is used) and sample
                         z ~ N(0, I) to generate new images.
        returns (x_hat [b, n, n] in (0, 1),
                 q_mu [b, dz] or None,
                 q_logvar [b, dz] or None)   # None when train=False
    """

    def __init__(self):
        super().__init__()
        # Snapshot the global sizes so forward() does not depend on them afterwards.
        self.n, self.d, self.dz = n, d, dz
        self.m = n // 4  # spatial side after two stride-2 downsamplings (28 -> 7)

        # Encoder x => z
        # kernel 3 / stride 2 / padding 1 halves an even spatial size exactly
        self.conv1_enc = nn.Conv2d(1, d, kernel_size=3, stride=2, padding=1)  #  1 x 28 x 28 --> d x 14 x 14
        self.bn1_enc = nn.BatchNorm2d(d)
        self.conv2_enc = nn.Conv2d(d, d, kernel_size=3, stride=2, padding=1)  #  d x 14 x 14 --> d x 7 x 7
        self.bn2_enc = nn.BatchNorm2d(d)
        self.linear_q = nn.Linear(d * self.m * self.m, d)
        self.ln = nn.LayerNorm(d)
        self.linear_q_mu     = nn.Linear(d, dz)
        self.linear_q_logvar = nn.Linear(d, dz)

        # Decoder z => x
        self.linear1_dec = nn.Linear(dz, d)
        self.ln_dec = nn.LayerNorm(d)
        self.linear2_dec = nn.Linear(d, d * self.m * self.m)
        # kernel 4 / stride 2 / padding 1 doubles the spatial size: (s-1)*2 - 2 + 4 = 2s
        self.conv1_dec = nn.ConvTranspose2d(d, d, kernel_size=4, stride=2, padding=1)  #  d x 7 x 7 --> d x 14 x 14
        self.bn1_dec = nn.BatchNorm2d(d)
        self.conv2_dec = nn.ConvTranspose2d(d, 1, kernel_size=4, stride=2, padding=1)  #  d x 14 x 14 --> 1 x 28 x 28

    def forward(self, x, train=True):

        if train:
            # Encoder x => z
            h = x.reshape(-1, 1, self.n, self.n)                       # [b, 1, n, n]
            h = torch.relu(self.bn1_enc(self.conv1_enc(h)))            # [b, d, n/2, n/2]
            h = torch.relu(self.bn2_enc(self.conv2_enc(h)))            # [b, d, n/4, n/4]
            h = torch.relu(self.ln(self.linear_q(h.flatten(1))))       # [b, d]
            q_mu = self.linear_q_mu(h)                                 # [b, dz]
            q_logvar = self.linear_q_logvar(h)                         # [b, dz]
            q_std = torch.exp(0.5 * q_logvar)                          # sigma = exp(log(sigma^2) / 2)
            eps = torch.randn_like(q_std)                              # eps ~ N(0, I)
            z = q_mu + q_std * eps                                     # reparameterization trick, [b, dz]
        else:
            # Sample unit Normal distribution; place z where the decoder weights live
            z = torch.randn(x.size(0), self.dz, device=self.linear1_dec.weight.device)  # [b, dz]
            q_mu, q_logvar = None, None  # no posterior when generating from the prior

        # Decoder z => x
        h = torch.relu(self.ln_dec(self.linear1_dec(z)))               # [b, d]
        h = torch.relu(self.linear2_dec(h))                            # [b, d * (n/4)^2]
        h = h.view(-1, self.d, self.m, self.m)                         # [b, d, n/4, n/4]
        h = torch.relu(self.bn1_dec(self.conv1_dec(h)))                # [b, d, n/2, n/2]
        # squeeze only the channel dim so a batch of size 1 keeps its batch dim
        x_hat = torch.sigmoid(self.conv2_dec(h)).squeeze(1)            # [b, n, n]

        return x_hat, q_mu, q_logvar
    
# Instantiate the network
net = VAE()
net = net.to(device)
print(net)
utils.display_num_param(net)

# Test the forward pass, backward pass and gradient update with a single batch
init_lr = 0.001
optimizer = torch.optim.Adam(net.parameters(), lr=init_lr)
num_train = train_data.size(0)  # number of training images (60000 for MNIST)
idx = torch.randint(0, num_train, (b,))  # b random indices in [0, num_train), drawn with replacement
batch_images = train_data[idx,:,:].to(device) # [b, n, n]
print(batch_images.size())
optimizer.zero_grad()
x_hat, q_mu, q_logvar = net(batch_images) # [b, n, n], [b, dz], [b, dz]
print(x_hat.size())
assert x_hat.shape == batch_images.shape, 'decoder output must have the same shape as the input images'
# loss
p_x = batch_images # we assume that images are Bernoulli distribution
p_xz = x_hat       # we do not perform Bernoulli sampling
loss_data = nn.BCELoss()(p_xz, p_x)
# KL( q(z|x) || N(0, I) ) for a diagonal Gaussian, averaged over batch and latent dims
loss_KL = -0.5 * torch.mean(1.0 + q_logvar - q_mu.pow(2.0) - q_logvar.exp())
# same weighting as the training loop, so the smoke test exercises the real objective
loss = loss_data + 1/4 * loss_KL
assert torch.isfinite(loss), 'non-finite loss in the single-batch smoke test'
loss.backward()
optimizer.step()


In [None]:
# Training loop
net = VAE()
net = net.to(device)
utils.display_num_param(net)

# Optimizer
init_lr = 0.0003
optimizer = torch.optim.AdamW(net.parameters(), lr=init_lr)

nb_epochs = 100  # number of full passes over the training set
b = 200  # Batch size (overrides the global b defined earlier)
num_train = train_data.size(0)  # number of training images (60000 for MNIST)
grid_size = 4  # samples are shown as a grid_size x grid_size figure
num_generated_images = grid_size * grid_size
bce_loss = nn.BCELoss()  # stateless, so build it once instead of at every step

start = time.time()
for epoch in range(nb_epochs):

    net.train()  # BatchNorm uses batch statistics while training
    running_loss = 0.0
    num_batches = 0

    shuffled_indices = torch.randperm(num_train)

    for count in range(0, num_train, b):

        # FORWARD AND BACKWARD PASS
        idx = shuffled_indices[count : count+b]
        batch_images = train_data[idx,:,:].to(device)
        optimizer.zero_grad()
        x_hat, q_mu, q_logvar = net(batch_images)
        # loss
        p_x = batch_images # we assume that images are Bernoulli distribution
        p_xz = x_hat       # we do not perform Bernoulli sampling
        loss_data = bce_loss(p_xz, p_x)
        # KL( q(z|x) || N(0, I) ) for a diagonal Gaussian, averaged over batch and latent dims
        loss_KL = -0.5 * torch.mean(1.0 + q_logvar - q_mu.pow(2.0) - q_logvar.exp())
        loss = loss_data + 1/4*loss_KL
        loss.backward()
        optimizer.step()

        # COMPUTE STATS
        running_loss += loss.item()
        num_batches += 1

    # AVERAGE STATS THEN DISPLAY
    total_loss = running_loss/num_batches
    elapsed = (time.time()-start)/60
    print('epoch=',epoch, '\t time=', elapsed,'min', '\t lr=', init_lr  ,'\t loss=', total_loss )

    # PLOT
    if epoch>0 and not epoch%20:
        # eval(): generate with BatchNorm running statistics instead of the statistics of
        # this small generated batch, and do not update the running stats with fake data.
        net.eval()
        with torch.no_grad():
            x = torch.zeros(num_generated_images, n**2).to(device)  # only x.size(0) is used
            x_hat = net(x, False)[0]
            x_hat = x_hat.reshape(-1, n, n).to('cpu')
        figure, axis = plt.subplots(grid_size, grid_size, squeeze=False)
        figure.set_size_inches(10,10)
        for cpt in range(num_generated_images):
            ax = axis[cpt % grid_size, cpt // grid_size]  # fill the grid column by column
            ax.imshow(x_hat[cpt,:,:], cmap='gray')
            ax.set_title("Generated w/ VAE")
            ax.axis('off')
        plt.show()

    

In [None]:
# Generated images with VAE

grid_size = 4  # images are shown as a grid_size x grid_size figure
num_generated_images = grid_size * grid_size

# eval(): BatchNorm uses its running statistics instead of the generated batch's
net.eval()
with torch.no_grad():
    x = torch.zeros(num_generated_images, n**2).to(device)  # only x.size(0) is used
    x_hat = net(x, False)[0]
    print('x_hat',x_hat.size())
    # reshape (not squeeze) keeps the batch dim even for a single image
    x_hat = x_hat.reshape(-1, n, n).to('cpu')

figure, axis = plt.subplots(grid_size, grid_size, squeeze=False)
figure.set_size_inches(10,10)

for cpt in range(num_generated_images):
    ax = axis[cpt % grid_size, cpt // grid_size]  # fill the grid column by column
    ax.imshow(x_hat[cpt,:,:], cmap='gray')
    ax.set_title("Generated w/ VAE")
    ax.axis('off')

plt.show()
