<a href="https://colab.research.google.com/github/RMoulla/SSL/blob/main/TP_VAE_MNIST.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **Variational Autoencoder (VAE) with Latent Space Interpolation on MNIST**

## Project Description

In this practical assignment, we explore the concepts of **Variational Autoencoders (VAEs)** by implementing and training a convolutional VAE on the MNIST dataset. The primary objectives are to learn how VAEs encode data into a structured latent space and to investigate how this latent space can be leveraged for generating new data and understanding data features.

### Key Learning Goals:
1. **Data Loading and Preprocessing**: Load the MNIST dataset and prepare it for training a convolutional neural network.
2. **VAE Architecture**: Define a VAE model with convolutional layers for encoding and decoding images, including the reparameterization trick, which keeps sampling from the latent distribution differentiable.
3. **VAE Loss Function**: Understand and implement the VAE loss, which combines reconstruction loss and KL-divergence to balance accurate reconstructions with a regularized latent space.
4. **Model Training**: Train the VAE model on MNIST, observe loss trends, and understand the impact of balancing reconstruction quality and latent space regularization.
5. **Latent Space Visualization**: Perform latent space interpolation by linearly blending between two points in the latent space, generating smooth transitions between two digit classes.

## Step 1: Loading the MNIST Dataset

In this first step, we load the MNIST dataset, which contains 28x28 grayscale images of handwritten digits (0-9). Each image represents a single digit.

We use PyTorch's `torchvision.datasets` and `DataLoader` to load and preprocess the data efficiently.

### Code Breakdown:
- **Transforms**: `transforms.Compose([transforms.ToTensor()])` converts each image into a float tensor of shape 1x28x28 with pixel values scaled to the range [0, 1]. This allows the data to be used in PyTorch models.
- **Dataset**: `datasets.MNIST` loads the MNIST dataset and applies the specified transform. We set `download=True` to ensure the data is downloaded if it hasn't been already.
- **DataLoader**: The `DataLoader` wraps the dataset and enables batching, shuffling, and parallel loading.

This setup provides a `train_loader`, which we will use to feed batches of images into our model during training.
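
As a quick sanity check on the loader settings above, the arithmetic can be sketched in plain Python. The figure of 60,000 training images is the standard size of the MNIST training split, and `to_unit_range` is a hypothetical helper that only mimics how `ToTensor` rescales 8-bit pixels; it is not part of torchvision.

```python
import math

# Assumptions: 60,000 training images (standard MNIST split), batch_size=128 as above.
num_images = 60_000
batch_size = 128

# DataLoader keeps the final partial batch by default (drop_last=False),
# so the number of batches is rounded up.
num_batches = math.ceil(num_images / batch_size)
last_batch_size = num_images - (num_batches - 1) * batch_size


def to_unit_range(pixel: int) -> float:
    """Hypothetical helper: rescale an 8-bit pixel value to [0.0, 1.0]."""
    return pixel / 255.0


print(num_batches, last_batch_size)          # 469 96
print(to_unit_range(0), to_unit_range(255))  # 0.0 1.0
```

The last batch being smaller than the others matters later when averaging the loss over an epoch.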

In [None]:
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Loading MNIST data
transform = transforms.Compose([transforms.ToTensor()])
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)

## Step 2: Defining the Convolutional VAE Model

In this section, we define the architecture of a **Variational Autoencoder (VAE)** with convolutional layers, designed to process the MNIST dataset.

### Code Breakdown:

1. **Initialization and Latent Dimension**:
   - `latent_dim=2` is defined as the dimensionality of the latent space. Setting it to a low number (e.g., 2) makes the latent space easy to visualize.

2. **Encoder**:
   - The encoder consists of three convolutional layers:
     - The first layer transforms the input image of size 1x28x28 to a feature map of size 32x14x14.
     - The second layer further reduces it to a size of 64x7x7.
     - The third layer compresses this to a size of 128x1x1.
   - The resulting 128x1x1 feature map is flattened into a 128-dimensional vector, which is then fed into two fully connected layers (`fc_mu` and `fc_logvar`) to produce the **mean** (`mu`) and **log-variance** (`logvar`) of the latent distribution.

3. **Reparameterization Trick**:
   - In `reparameterize`, we use the mean and log-variance vectors to sample from a Gaussian distribution. This is done by computing:
\[
     z = \mu + \sigma \cdot \epsilon, \qquad \sigma = \exp\left(\tfrac{1}{2} \log \sigma^2\right)
\]
   - Here, `epsilon` is random noise sampled from a standard normal distribution, and the standard deviation is recovered from `logvar`. Because the randomness is isolated in `epsilon`, gradients can backpropagate through `mu` and `logvar`.

4. **Decoder**:
   - The decoder starts with a fully connected layer that expands the latent vector (`z`) to 128 values, reshaped to 128x1x1 so the convolutional layers can process it.
   - It then passes through three transposed convolutional layers to reconstruct the original image shape:
     - The first layer upsamples 128x1x1 to 64x7x7.
     - The second layer outputs 32x14x14.
     - The final layer produces the original shape, 1x28x28, with pixel values normalized between 0 and 1 using a Sigmoid activation.

5. **Forward Pass**:
   - `encode`: Passes the input through the encoder to obtain `mu` and `logvar`.
   - `reparameterize`: Samples from the latent distribution using `mu` and `logvar`.
   - `decode`: Reconstructs the image from the latent vector `z`.
   - Returns the reconstructed image along with `mu` and `logvar` for further use in the loss calculation.
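
The reparameterization step from item 3 can be illustrated without any deep learning framework. The following is a minimal scalar sketch, assuming a fixed seed and hypothetical values for `mu` and `logvar`; it only checks that samples drawn as `mu + std * eps` have roughly the intended mean and standard deviation.

```python
import math
import random


def reparameterize(mu: float, logvar: float, rng: random.Random) -> float:
    """Draw z ~ N(mu, sigma^2) as mu + sigma * eps with eps ~ N(0, 1)."""
    std = math.exp(0.5 * logvar)  # sigma = exp(logvar / 2)
    eps = rng.gauss(0.0, 1.0)     # the only source of randomness
    return mu + std * eps


rng = random.Random(0)            # fixed seed (illustrative choice)
mu, logvar = 1.5, math.log(0.25)  # hypothetical values; sigma = 0.5

samples = [reparameterize(mu, logvar, rng) for _ in range(20_000)]
sample_mean = sum(samples) / len(samples)
sample_std = math.sqrt(sum((s - sample_mean) ** 2 for s in samples) / len(samples))

# Both values should be close to 1.5 and 0.5 respectively.
print(sample_mean, sample_std)
```

Because `mu` and `std` enter through ordinary arithmetic, an autograd framework can differentiate `z` with respect to both, which is the point of the trick.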

This architecture allows the VAE to encode images to a low-dimensional latent space and then decode them back to their original shape.
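
The shapes quoted above can be checked with the standard output-size formulas for convolutions and transposed convolutions (dilation 1). The kernel, stride, and padding values in this sketch are assumptions chosen to reproduce the stated sizes; other combinations would work as well.

```python
def conv_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Spatial output size of a convolution (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


def conv_transpose_out(size: int, kernel: int, stride: int = 1,
                       padding: int = 0, output_padding: int = 0) -> int:
    """Spatial output size of a transposed convolution (dilation 1)."""
    return (size - 1) * stride - 2 * padding + kernel + output_padding


# Hypothetical encoder settings: (kernel, stride, padding)
size = 28
encoder_sizes = [size]
for kernel, stride, padding in [(3, 2, 1), (3, 2, 1), (7, 1, 0)]:
    size = conv_out(size, kernel, stride, padding)
    encoder_sizes.append(size)

# Hypothetical decoder settings: (kernel, stride, padding, output_padding)
decoder_sizes = [size]
for kernel, stride, padding, out_pad in [(7, 1, 0, 0), (3, 2, 1, 1), (3, 2, 1, 1)]:
    size = conv_transpose_out(size, kernel, stride, padding, out_pad)
    decoder_sizes.append(size)

print(encoder_sizes)  # [28, 14, 7, 1]
print(decoder_sizes)  # [1, 7, 14, 28]
```

The `output_padding=1` in the two strided transposed convolutions resolves the ambiguity that stride 2 introduces, so that 7 maps to 14 rather than 13.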

In [None]:
import torch.nn as nn
import torch.nn.functional as F

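class ConvVAE(nn.Module):
    def __init__(self, latent_dim=2):
        super(ConvVAE, self).__init__()

        # Encoder with convolutional layers
        # (kernel, stride, and padding values are one choice that yields the stated shapes)
        self.encoder = nn.Sequential(
            # Input 1x28x28 -> Output 32x14x14
            nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            # Input 32x14x14 -> Output 64x7x7
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            # Input 64x7x7 -> Output 128x1x1
            nn.Conv2d(64, 128, kernel_size=7),
            nn.ReLU(),
        )

        # Layers to produce mean and variance vectors
        # Mean vector of the latent space
        self.fc_mu = nn.Linear(128, latent_dim)
        # Log-variance vector of the latent space
        self.fc_logvar = nn.Linear(128, latent_dim)

        # Decoder starting with a fully connected layer
        # Transform latent vector back to decoder shape
        self.fc_decode = nn.Linear(latent_dim, 128)

        # Decoder with transposed convolutional layers to reconstruct the image
        self.decoder = nn.Sequential(
            # Input 128x1x1 -> Output 64x7x7
            nn.ConvTranspose2d(128, 64, kernel_size=7),
            nn.ReLU(),
            # Input 64x7x7 -> Output 32x14x14
            nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            # Input 32x14x14 -> Output 1x28x28
            nn.ConvTranspose2d(32, 1, kernel_size=3, stride=2, padding=1, output_padding=1),
            # Normalize output pixels between 0 and 1
            nn.Sigmoid(),
        )

    def encode(self, x):
        # Apply the encoder to extract features and flatten the result
        h = self.encoder(x).flatten(start_dim=1)
        # Compute mean and log-variance for the latent space distribution
        mu = self.fc_mu(h)
        logvar = self.fc_logvar(h)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        # Compute standard deviation from log-variance
        std = torch.exp(0.5 * logvar)
        # Sample epsilon from a normal distribution
        eps = torch.randn_like(std)
        # Reparameterization trick to sample z from N(mu, sigma^2)
        return mu + eps * std

    def decode(self, z):
        # Apply fully connected layer to expand latent vector to decoder's initial shape
        h = self.fc_decode(z).view(-1, 128, 1, 1)
        # Apply the decoder to reconstruct the image
        return self.decoder(h)

    def forward(self, x):
        # Forward pass through encoder
        mu, logvar = self.encode(x)
        # Sample from latent distribution
        z = self.reparameterize(mu, logvar)
        # Forward pass through decoder
        return self.decode(z), mu, logvar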

## Step 3: Defining the VAE Loss Function

The Variational Autoencoder (VAE) loss function combines two key components: **Reconstruction Loss** and **KL-Divergence Loss**. Together, these terms encourage the VAE to produce high-quality reconstructions while also regularizing the latent space.

### Code Breakdown:

1. **Flattening**:
   - Both `recon_x` (the reconstructed image) and `x` (the original image) are flattened to a shape of `[batch_size, 784]` to match the expected input shape for the binary cross-entropy function.

2. **Reconstruction Loss (Binary Cross-Entropy)**:
   - We use Binary Cross-Entropy (BCE) as the reconstruction loss. This term measures the pixel-wise difference between the original and reconstructed images.
   - `BCE = F.binary_cross_entropy(recon_x, x, reduction='sum')` sums the error over all pixels in the batch, promoting accurate reconstructions.

3. **KL-Divergence Loss**:
   - The KL-Divergence term measures the difference between the learned latent distribution `q(z|x)` and the standard normal prior `p(z) = N(0, I)`.
   - The calculation:
\[
     KLD = -0.5 \sum (1 + \log(\sigma^2) - \mu^2 - \sigma^2)
\]
   - This term penalizes deviations from the standard normal distribution, helping to structure the latent space and ensure continuity.

4. **Total VAE Loss**:
   - The final loss is the sum of the reconstruction loss (BCE) and the KL-Divergence loss (KLD). Minimizing this total loss encourages the model to reconstruct images accurately while keeping the latent space organized.

By using this combined loss, the VAE learns to generate images that resemble the input data and maintain a well-structured latent space that facilitates tasks like image generation and interpolation.
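
To make the two loss terms concrete, here is a small scalar sketch in plain Python. The function names `bce_pixel` and `kld_gaussian` are hypothetical and cover a single pixel and a single latent dimension; the full loss sums these quantities over all pixels and latent dimensions in the batch.

```python
import math


def bce_pixel(p: float, x: float) -> float:
    """Binary cross-entropy between a predicted pixel p in (0, 1) and a target x in [0, 1]."""
    return -(x * math.log(p) + (1 - x) * math.log(1 - p))


def kld_gaussian(mu: float, logvar: float) -> float:
    """KL divergence from N(mu, exp(logvar)) to N(0, 1) for one latent dimension."""
    return -0.5 * (1 + logvar - mu ** 2 - math.exp(logvar))


print(kld_gaussian(0.0, 0.0) == 0)    # True: the posterior already matches the prior
print(kld_gaussian(1.0, 0.0))         # 0.5: shifting the mean by 1 costs 0.5 nats
print(round(bce_pixel(0.9, 1.0), 4))  # 0.1054: a confident, correct pixel is cheap
```

The KL term is zero only when `mu = 0` and `logvar = 0`, so any information the encoder stores in the latent code has a cost that the reconstruction term must justify.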

In [None]:

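def loss_function(recon_x, x, mu, logvar):
    # Flatten the output to match the shape of x
    recon_x = recon_x.view(-1, 784)

    # Also flatten x for consistency
    x = x.view(-1, 784)

    # Reconstruction loss: Binary Cross Entropy
    BCE = F.binary_cross_entropy(recon_x, x, reduction='sum')

    # KL-Divergence loss
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    # Total VAE loss
    return BCE + KLD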

## Step 4: Training the VAE Model

In this step, we set up and execute the training loop for the Variational Autoencoder (VAE) model. The training process optimizes the model’s parameters to minimize the combined loss function over multiple epochs.

### Code Breakdown:

1. **Setting up the Device**:
   - The code detects if a GPU is available using `torch.cuda.is_available()`. If so, it uses the GPU for faster training; otherwise, it defaults to the CPU.

2. **Model and Optimizer Initialization**:
   - We create an instance of the `ConvVAE` model with a specified `latent_dim`. A small latent dimension (like 2) makes it easier to visualize the latent space later.
   - The model is sent to the chosen device (CPU or GPU).
   - The optimizer is set up using Adam with a learning rate of 0.001, a common default that works well as a starting point for this model.

3. **Training Loop**:
   - `num_epochs` defines the number of times the entire dataset is processed during training.
   - For each epoch:
     - The model is set to training mode (`model.train()`), which activates features like dropout (if used).
     - `train_loss` is initialized to accumulate the total loss over the epoch.
     - For each batch in `train_loader`:
       - **Data Transfer**: The batch of images is transferred to the chosen device.
       - **Gradient Reset**: `optimizer.zero_grad()` resets gradients from the previous batch to prevent accumulation.
       - **Forward Pass**: The data is passed through the model, which outputs `recon_batch` (reconstructed images), `mu` (mean), and `logvar` (log-variance) of the latent distribution.
       - **Loss Calculation**: `loss_function` computes the VAE loss by combining reconstruction and KL-divergence losses.
       - **Backpropagation**: `loss.backward()` computes the gradients for all model parameters.
       - **Optimization Step**: `optimizer.step()` updates the parameters using the calculated gradients.
       - The batch loss is added to `train_loss` to track the total loss for the epoch.
   
   - After each epoch, the accumulated loss is divided by the number of training images and printed, giving the average loss per image. A steadily decreasing value indicates that the model is converging.

The training loop optimizes the VAE’s parameters to produce accurate reconstructions and a well-structured latent space, which can be visualized or used for generative tasks after training.
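
The choice of divisor for the reported loss is worth a small illustration. Because the batch loss is a sum over samples, dividing the epoch total by the number of samples gives a per-image average, whereas dividing by the number of batches does not. The sketch below uses made-up per-sample losses and a tiny hypothetical dataset.

```python
# Hypothetical per-sample losses for a "dataset" of 5 samples, split into batches of 2.
per_sample_losses = [4.0, 2.0, 3.0, 1.0, 5.0]
batch_size = 2
batches = [per_sample_losses[i:i + batch_size]
           for i in range(0, len(per_sample_losses), batch_size)]

train_loss = 0.0
for batch in batches:
    train_loss += sum(batch)  # mirrors a batch loss computed with reduction='sum'

avg_per_sample = train_loss / len(per_sample_losses)  # analogous to len(train_loader.dataset)
avg_per_batch = train_loss / len(batches)             # analogous to len(train_loader)

print(avg_per_sample)  # 3.0
print(avg_per_batch)   # 5.0
```

The per-sample figure stays comparable across different batch sizes and is unaffected by a smaller final batch.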

In [None]:
from torch.optim import Adam


# Set device (use GPU if available)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Initialize model, optimizer, and send model to device
latent_dim = 2
model = ConvVAE(latent_dim=latent_dim).to(device)
optimizer = Adam(model.parameters(), lr=1e-3)

# Training loop
num_epochs = 10
for epoch in range(num_epochs):
    model.train()  # Set model to training mode
    train_loss = 0
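    for batch_idx, (data, _) in enumerate(train_loader):
        # Move data to device
        data = data.to(device)

        # Zero the gradients
        optimizer.zero_grad()

        # Forward pass through the VAE
        recon_batch, mu, logvar = model(data)

        # Calculate the VAE loss
        loss = loss_function(recon_batch, data, mu, logvar)

        # Backpropagation and optimization
        loss.backward()
        optimizer.step()

        # Accumulate the batch loss for the epoch average
        train_loss += loss.item()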

    # Print the average loss for each epoch
    print(f'Epoch {epoch + 1}/{num_epochs}, Loss: {train_loss / len(train_loader.dataset):.4f}')

## Step 5: Interpolation in the Latent Space

In this final step, we perform **latent space interpolation** to explore the continuity of the VAE's learned latent space. By interpolating between two points in the latent space (representing different digits), we can generate a smooth transition between them.

### Code Breakdown:

1. **Defining the Interpolation Function**:
   - `interpolate_and_generate`: This function takes two latent vectors (`z_start` and `z_end`) and generates intermediate points between them.
   - **Interpolation**: We create a series of interpolated points by linearly blending `z_start` and `z_end` with interpolation weights `t` ranging from 0 to 1.
   - **Decoding**: Each interpolated vector is decoded by the VAE’s decoder to produce an image, which is stored in the `images` list.

2. **Selecting Points to Interpolate**:
   - We randomly select two samples from the dataset (e.g., images of "1" and "7").
   - These images are passed through the encoder to obtain their latent representations, `z_start` and `z_end`.

3. **Generating and Visualizing the Interpolation**:
   - The interpolated latent points are decoded back into images.
   - We then plot each decoded image side-by-side to visualize the transformation from the starting digit to the ending digit.
   
   Each intermediate image represents a gradual change in the latent space between `z_start` and `z_end`, showing how the model "morphs" one digit into another.

### Visualization Explanation:

The visualization showcases the VAE’s ability to generate new images by sampling from the latent space. The smooth transition between digits demonstrates that the VAE has learned a well-structured, continuous latent space where similar concepts (like digits) are located close to each other. This capability can be useful for generating synthetic data or exploring variations in features.

By interpolating in this manner, we can observe how the VAE understands and generates the fundamental features of the digits in the MNIST dataset.
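
The linear blending itself is simple enough to sketch with plain Python lists. The helper names `lerp` and `interpolation_path` and the two 2-D endpoints are hypothetical; in the notebook the endpoints would be latent vectors produced by the encoder, and each point on the path would be passed to the decoder.

```python
def lerp(z_start, z_end, t):
    """Blend two latent vectors: (1 - t) * z_start + t * z_end."""
    return [(1 - t) * a + t * b for a, b in zip(z_start, z_end)]


def interpolation_path(z_start, z_end, steps=5):
    """Return `steps` evenly spaced points from z_start to z_end (steps >= 2)."""
    return [lerp(z_start, z_end, i / (steps - 1)) for i in range(steps)]


# Hypothetical 2-D latent codes for two digits
path = interpolation_path([0.0, 2.0], [4.0, -2.0], steps=5)
for z in path:
    print(z)
# [0.0, 2.0]
# [1.0, 1.0]
# [2.0, 0.0]
# [3.0, -1.0]
# [4.0, -2.0]
```

With `t = 0` the path starts exactly at `z_start` and with `t = 1` it ends exactly at `z_end`, so the first and last decoded images correspond to the two chosen digits.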

In [None]:

import matplotlib.pyplot as plt


# Define the interpolation function
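def interpolate_and_generate(model, z_start, z_end, steps=10):
    # Create interpolation steps between two latent vectors
    weights = torch.linspace(0, 1, steps, device=z_start.device)

    # Decode each interpolated point and store the images
    images = []
    for t in weights:
        z = (1 - t) * z_start + t * z_end
        images.append(model.decode(z).squeeze().detach().cpu().numpy())
    return images

# Select two random digits to interpolate between
model.eval()
with torch.no_grad():
    # Add a batch dimension so each input has shape [1, 1, 28, 28], as the model expects
    idx_start, idx_end = torch.randint(len(train_dataset), (2,)).tolist()
    x_start = train_dataset[idx_start][0].unsqueeze(0).to(device)
    x_end = train_dataset[idx_end][0].unsqueeze(0).to(device)
    # Use the encoder means as the latent endpoints
    z_start, _ = model.encode(x_start)
    z_end, _ = model.encode(x_end)
    images = interpolate_and_generate(model, z_start, z_end)
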

# Plot interpolated images
fig, axes = plt.subplots(1, len(images), figsize=(15, 3))
for i, img in enumerate(images):
    axes[i].imshow(img, cmap="gray")
    axes[i].axis("off")
plt.show()