# Variational Autoencoders (VAEs)

## Overview

**Variational Autoencoders (VAEs)** are generative models that combine ideas from autoencoders and variational inference. They learn a **probabilistic latent space** that can be sampled to generate new data resembling the training data. VAEs are commonly used for image generation, representation learning, and other unsupervised tasks.

VAEs provide a principled way of performing both dimensionality reduction and data generation by learning a distribution over the latent variables. Standard autoencoders, by contrast, map each input directly to a single latent code.

---

## Architecture of VAEs

The basic architecture of a VAE consists of:

1. **Encoder**: Maps input data $ x $ to a probability distribution over the latent variable $ z $. This distribution is typically Gaussian.
2. **Latent Space**: Encodes the data in a low-dimensional, continuous latent variable $ z $, sampled from the distribution learned by the encoder.
3. **Decoder**: Maps latent variable $ z $ back to a distribution over the original data space $ x $.

In contrast to standard autoencoders, VAEs encode the input as a distribution, rather than a single point. They aim to minimize the reconstruction error **and** regularize the latent space using a term based on **Kullback-Leibler (KL) divergence**.

---

## Mathematical Foundations of VAEs

### 1. The Encoder Network

The encoder maps the input $ x $ to a distribution over the latent variable $ z $. The distribution is often assumed to be Gaussian, so the encoder outputs the mean $ \mu(x) $ and the standard deviation $ \sigma(x) $ of the Gaussian distribution for each input:

$$
q(z|x) = \mathcal{N}(z; \mu(x), \sigma(x)^2)
$$

Here:

- $ \mu(x) $ is the mean of the Gaussian.
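- $ \sigma(x) $ is the standard deviation. In practice the network usually outputs an unconstrained quantity such as $ \log \sigma(x) $ or $ \log \sigma(x)^2 $ and exponentiates it, which guarantees that $ \sigma(x) $ is positive.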

### 2. Sampling Latent Variable $ z $

To make backpropagation work, we use the **reparameterization trick**. Instead of directly sampling $ z $ from the Gaussian, we reparameterize the sampling process as:

$$
z = \mu(x) + \sigma(x) \odot \epsilon
$$

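where $ \epsilon $ is sampled from a standard normal distribution $ \mathcal{N}(0, I) $ and $ \odot $ denotes element-wise multiplication. Because the randomness now enters only through $ \epsilon $, $ z $ is a deterministic, differentiable function of $ \mu(x) $ and $ \sigma(x) $, so both can be learned by gradient descent.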
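
The following is a minimal sketch of the reparameterization step, not a full model. The tensors `mu` and `log_sigma` are hypothetical stand-ins for encoder outputs, chosen only to show that gradients reach both parameters once the noise is drawn separately.

```python
import torch

torch.manual_seed(0)

# Hypothetical stand-ins for the encoder outputs of a single input.
mu = torch.tensor([0.5, -1.0], requires_grad=True)
log_sigma = torch.tensor([0.0, -0.5], requires_grad=True)

sigma = torch.exp(log_sigma)   # positive by construction
eps = torch.randn_like(mu)     # noise drawn independently of the parameters
z = mu + sigma * eps           # reparameterized sample

# Any scalar function of z can now be differentiated w.r.t. mu and log_sigma.
z.sum().backward()

# Analytically: dz_i/dmu_i = 1 and dz_i/dlog_sigma_i = sigma_i * eps_i.
grad_mu = mu.grad
grad_log_sigma = log_sigma.grad
```

Sampling `z` with a non-reparameterized sampler would leave `mu.grad` and `log_sigma.grad` unset, which is the problem the trick is meant to avoid.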

### 3. The Decoder Network

The decoder reconstructs the data from the latent variable $ z $. It aims to maximize the likelihood of the data $ p(x|z) $, which can also be modeled as a Gaussian distribution:

$$
p(x|z) = \mathcal{N}(x; \hat{x}(z), \sigma^2)
$$

where $ \hat{x}(z) $ is the reconstructed output from the decoder network and $ \sigma^2 $ is a fixed variance, typically treated as a hyperparameter.
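
With a fixed variance, the Gaussian negative log-likelihood differs from the squared reconstruction error only by a scale factor and an additive constant. The sketch below checks this numerically in plain Python; the helper `gaussian_nll`, the example vectors, and the value `sigma = 1.0` are illustrative assumptions, not values from the text.

```python
import math

def gaussian_nll(x, x_hat, sigma):
    """Negative log-density of x under independent N(x_hat_i, sigma^2) per dimension."""
    d = len(x)
    sq_err = sum((a - b) ** 2 for a, b in zip(x, x_hat))
    return sq_err / (2 * sigma ** 2) + 0.5 * d * math.log(2 * math.pi * sigma ** 2)

x = [0.2, 0.7, 0.1]
x_hat_good = [0.25, 0.65, 0.1]   # close reconstruction
x_hat_bad = [0.9, 0.1, 0.8]      # poor reconstruction
sigma = 1.0                      # assumed fixed decoder standard deviation

nll_good = gaussian_nll(x, x_hat_good, sigma)
nll_bad = gaussian_nll(x, x_hat_bad, sigma)

# The constant term cancels, so the gap is the gap in squared error / (2 * sigma^2).
sse_good = sum((a - b) ** 2 for a, b in zip(x, x_hat_good))
sse_bad = sum((a - b) ** 2 for a, b in zip(x, x_hat_bad))
gap = nll_bad - nll_good
```

This is one way to see why mean squared error is a common reconstruction loss for real-valued data.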

### 4. The Loss Function

The loss function for a VAE consists of two terms:

1. **Reconstruction Loss**: Measures how well the decoder reconstructs the input data $ x $. This can be computed using binary cross-entropy or mean squared error depending on the data type.
   
   $$
   \mathcal{L}_{\text{reconstruction}} = - \mathbb{E}_{q(z|x)}[\log p(x|z)]
   $$

2. **KL Divergence (Regularization Term)**: Encourages the distribution $ q(z|x) $ to be close to the prior distribution $ p(z) $, which is typically a standard Gaussian $ \mathcal{N}(0, 1) $. This regularizes the latent space to ensure smoothness and allows for meaningful sampling from it.

   $$
   \mathcal{L}_{\text{KL}} = D_{\text{KL}}(q(z|x) \parallel p(z))
   $$

   The KL divergence between the approximate posterior $ q(z|x) = \mathcal{N}(z; \mu(x), \sigma(x)^2) $ and the prior $ p(z) = \mathcal{N}(0, 1) $ is given by:

   $$
   D_{\text{KL}}(q(z|x) \parallel p(z)) = \frac{1}{2} \sum_{i=1}^{d} \left( \mu_i^2 + \sigma_i^2 - \log(\sigma_i^2) - 1 \right)
   $$
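
As a sanity check, the closed-form expression above can be compared with the KL divergence computed by `torch.distributions`. This is only a sketch: the values of `mu` and `sigma` are arbitrary examples.

```python
import torch
from torch.distributions import Normal, kl_divergence

# Arbitrary example parameters of a 3-dimensional diagonal Gaussian q(z|x).
mu = torch.tensor([0.5, -1.0, 0.0])
sigma = torch.tensor([1.0, 0.5, 2.0])

# Closed form from the text: 0.5 * sum(mu^2 + sigma^2 - log(sigma^2) - 1).
closed_form = 0.5 * torch.sum(mu**2 + sigma**2 - torch.log(sigma**2) - 1)

# Reference: per-dimension KL between N(mu_i, sigma_i^2) and N(0, 1), summed.
q = Normal(mu, sigma)
p = Normal(torch.zeros_like(mu), torch.ones_like(sigma))
reference = kl_divergence(q, p).sum()
```

The two values should agree up to floating-point error, and the divergence is zero exactly when `mu` is zero and `sigma` is one in every dimension.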

Thus, the total loss for the VAE is:

$$
\mathcal{L}_{\text{VAE}} = \mathcal{L}_{\text{reconstruction}} + \mathcal{L}_{\text{KL}}
$$

---

## How VAEs Work in Practice

1. **Training**: The VAE is trained to optimize the total loss function that balances the reconstruction accuracy and the regularization (KL divergence). The encoder learns to map the input data to a Gaussian distribution in the latent space, while the decoder learns to generate realistic outputs from sampled latent variables.

2. **Generation**: After training, we can generate new data by sampling from the prior distribution $ p(z) = \mathcal{N}(0, 1) $ in the latent space and feeding these samples into the decoder.
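
The generation step can be sketched as follows. The `decoder` here is an untrained stand-in with hypothetical sizes (`latent_dim = 8`, `data_dim = 784`); in practice it would be the decoder of a trained VAE.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

latent_dim, data_dim = 8, 784  # hypothetical sizes, not taken from the text

# Stand-in decoder; a trained VAE decoder would be used here instead.
decoder = nn.Sequential(
    nn.Linear(latent_dim, 64),
    nn.ReLU(),
    nn.Linear(64, data_dim),
    nn.Sigmoid(),  # keeps outputs in [0, 1], e.g. for pixel intensities
)

decoder.eval()
with torch.no_grad():
    z = torch.randn(16, latent_dim)  # 16 draws from the prior N(0, I)
    samples = decoder(z)             # one generated data point per latent draw
```

No encoder is involved at generation time, which is why the KL term matters during training: it keeps the encoder's distributions close to the prior the samples are drawn from.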

---

## Common Use Cases of VAEs

1. **Data Generation**: VAEs can generate new data that resembles the training data by sampling from the latent space. For example, they are used for generating images, speech, and other forms of data.

2. **Dimensionality Reduction**: VAEs can learn a low-dimensional representation of the data without labels. This is similar in spirit to PCA, but the mapping is nonlinear and has a probabilistic interpretation.

3. **Anomaly Detection**: Since VAEs model the distribution of the data, they can detect anomalies by observing how well the model reconstructs new data points. Poor reconstruction suggests that the data point is unusual or anomalous.

4. **Semi-supervised Learning**: VAEs can be used in semi-supervised learning settings, where only a small portion of the data is labeled. The VAE can help by learning meaningful latent representations from the unlabeled data.
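
To illustrate the anomaly detection use case, the sketch below scores inputs by their reconstruction error. `ToyAutoencoder` is a hypothetical stand-in for a trained model, and the `threshold` value is an assumption; in practice it would be chosen from held-out data.

```python
import torch
import torch.nn as nn

def reconstruction_scores(model, x):
    """Per-sample mean squared reconstruction error; higher means more unusual."""
    model.eval()
    with torch.no_grad():
        x_hat = model(x)
    return ((x - x_hat) ** 2).flatten(start_dim=1).mean(dim=1)

class ToyAutoencoder(nn.Module):
    """Hypothetical stand-in that can only reproduce the first coordinate."""
    def forward(self, x):
        out = torch.zeros_like(x)
        out[:, 0] = x[:, 0]
        return out

model = ToyAutoencoder()
x = torch.tensor([[1.0, 0.0, 0.0],    # representable by the stand-in model
                  [1.0, 2.0, -2.0]])  # not representable, so poorly reconstructed
scores = reconstruction_scores(model, x)

threshold = 0.5                # assumed; normally tuned on validation data
is_anomaly = scores > threshold
```

For a VAE whose forward pass returns a tuple such as `(reconstruction, mu, logvar)`, the scoring function would use the first element of that tuple.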

---

## Summary

Variational Autoencoders (VAEs) are a powerful class of generative models that learn to map data into a probabilistic latent space. They are trained by optimizing a loss function that balances the reconstruction error and the KL divergence between the learned latent distribution and a prior. VAEs are used for tasks like data generation, dimensionality reduction, and anomaly detection.


## Details on the Inner Workings of VAEs

### How the Encoder Network Maps $ x $ to a Distribution over $ z $ in a Variational Autoencoder (VAE)

In a **Variational Autoencoder (VAE)**, the encoder network maps each input data point $ x $ to a **probability distribution** over the latent variable $ z $, rather than to a single point as in traditional autoencoders. This probabilistic mapping is fundamental to the VAE’s ability to generate new data and to learn meaningful latent representations.

### 1. Probabilistic Modeling of Latent Variables
- **Goal**: We want to model the underlying data distribution and generate new data points. To achieve this, we consider the latent variables $ z $ to be random variables drawn from a probability distribution conditioned on the input $ x $.
- **Approximate Posterior**: The true posterior distribution $ p(z|x) $ is typically intractable. Therefore, VAEs introduce an **approximate posterior** $ q_{\phi}(z|x) $, parameterized by $ \phi $, to approximate $ p(z|x) $.

### 2. Encoder Network Outputs Distribution Parameters
- **Mean and Variance**: The encoder network processes the input $ x $ and outputs two vectors:
  - The **mean vector** $ \boldsymbol{\mu}_{\phi}(x) $.
  - The **log-variance vector** $ \boldsymbol{\log\sigma^2}_{\phi}(x) $.
  
- **Latent Distribution**: These outputs parameterize a multivariate Gaussian distribution over $ z $:
  $$
  q_{\phi}(z|x) = \mathcal{N}\left(z;\, \boldsymbol{\mu}_{\phi}(x),\, \text{diag}(\boldsymbol{\sigma}^2_{\phi}(x))\right)
  $$
  where $ \boldsymbol{\sigma}^2_{\phi}(x) = \exp(\boldsymbol{\log\sigma^2}_{\phi}(x)) $ to ensure positivity.

### 3. Reparameterization Trick
- **Challenge**: Sampling $ z $ directly from $ q_{\phi}(z|x) $ is a stochastic operation, and gradients cannot be backpropagated through it to the encoder parameters $ \phi $, which blocks gradient-based optimization.
- **Solution**: The **reparameterization trick** rewrites the sampling process in a differentiable way:
  $$
  z = \boldsymbol{\mu}_{\phi}(x) + \boldsymbol{\sigma}_{\phi}(x) \odot \epsilon
  $$
  where:
  - $ \epsilon $ is sampled from $ \mathcal{N}(0, I) $.
  - $ \odot $ denotes element-wise multiplication.
  
- **Benefit**: This separates the stochastic part ($ \epsilon $) from the deterministic parameters ($ \boldsymbol{\mu}_{\phi}(x) $, $ \boldsymbol{\sigma}_{\phi}(x) $), allowing gradients to flow through the encoder during training.

### 4. Loss Function Components
The VAE is trained by maximizing the **evidence lower bound (ELBO)**. In practice this is implemented as minimizing the negative ELBO, which splits into two loss terms:

#### a. Reconstruction Loss
- **Purpose**: Measures how well the decoder $ p_{\theta}(x|z) $ can reconstruct the input $ x $ from $ z $.
- **Computation**: The expected negative log-likelihood of the input under the decoder:
  $$
  \mathcal{L}_{\text{recon}} = -\mathbb{E}_{q_{\phi}(z|x)}[\log p_{\theta}(x|z)]
  $$
- **Interpretation**: Encourages the decoder to produce outputs close to the original input.

#### b. KL Divergence Loss
- **Purpose**: Regularizes the latent space by measuring how close $ q_{\phi}(z|x) $ is to the prior $ p(z) $, typically $ \mathcal{N}(0, I) $.
- **Computation**: For Gaussian distributions, the KL divergence has a closed-form expression:
  $$
  \mathcal{L}_{\text{KL}} = D_{\text{KL}}(q_{\phi}(z|x) \parallel p(z)) = \frac{1}{2} \sum_{i=1}^{d} \left( \sigma^2_{\phi,i}(x) + \mu_{\phi,i}(x)^2 - 1 - \log \sigma^2_{\phi,i}(x) \right)
  $$
- **Interpretation**: Encourages the approximate posterior to be close to the prior, promoting a smooth and continuous latent space.
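
The closed form above follows from a short calculation. For a single latent dimension with $ q(z) = \mathcal{N}(z; \mu, \sigma^2) $ and $ p(z) = \mathcal{N}(z; 0, 1) $, and using $ \mathbb{E}_{q}[(z - \mu)^2] = \sigma^2 $ and $ \mathbb{E}_{q}[z^2] = \mu^2 + \sigma^2 $:

$$
\begin{aligned}
D_{\text{KL}}(q \parallel p)
&= \mathbb{E}_{q}\left[\log q(z) - \log p(z)\right] \\
&= \mathbb{E}_{q}\left[-\tfrac{1}{2}\log(2\pi\sigma^2) - \frac{(z - \mu)^2}{2\sigma^2} + \tfrac{1}{2}\log(2\pi) + \frac{z^2}{2}\right] \\
&= -\tfrac{1}{2}\log \sigma^2 - \tfrac{1}{2} + \tfrac{1}{2}\left(\mu^2 + \sigma^2\right) \\
&= \tfrac{1}{2}\left(\sigma^2 + \mu^2 - 1 - \log \sigma^2\right)
\end{aligned}
$$

Because the covariance is diagonal and the prior factorizes, the $ d $ latent dimensions are independent and their contributions add, which gives the sum over $ i = 1, \dots, d $.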

### 5. Total Loss and Training Objective
- **Total Loss**: The sum of the reconstruction and KL divergence losses:
  $$
  \mathcal{L} = \mathcal{L}_{\text{recon}} + \mathcal{L}_{\text{KL}}
  $$
- **Optimization**: The model is trained by minimizing $ \mathcal{L} $ with respect to $ \phi $ and $ \theta $, the parameters of the encoder and decoder networks, respectively.
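
The link between this loss and the evidence lower bound can be made explicit. For any choice of $ q_{\phi}(z|x) $, the log-evidence decomposes as:

$$
\log p_{\theta}(x) = \underbrace{\mathbb{E}_{q_{\phi}(z|x)}[\log p_{\theta}(x|z)] - D_{\text{KL}}(q_{\phi}(z|x) \parallel p(z))}_{\text{ELBO} \,=\, -\mathcal{L}} + D_{\text{KL}}(q_{\phi}(z|x) \parallel p_{\theta}(z|x))
$$

The last term is a KL divergence and therefore non-negative, so $ \log p_{\theta}(x) \geq -\mathcal{L} $. Minimizing $ \mathcal{L} $ thus pushes up a lower bound on the log-likelihood of the data, and the bound is tight when the approximate posterior matches the true posterior $ p_{\theta}(z|x) $.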

### 6. Encoder Network Architecture
- **Input Layer**: Receives the data point $ x $.
- **Hidden Layers**: Processes $ x $ through layers (e.g., convolutional or fully connected), extracting features.
- **Output Layers**: Splits into two heads:
  - One outputs $ \boldsymbol{\mu}_{\phi}(x) $.
  - The other outputs $ \boldsymbol{\log\sigma^2}_{\phi}(x) $.
  
- **Activation Functions**: The mean and log-variance heads typically use no nonlinearity (a linear output), since both quantities can take any real value.
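
A minimal sketch of this two-headed layout, using fully connected layers, is shown below. The class name `MLPEncoder` and the layer sizes are illustrative assumptions; the document's own model uses convolutional layers.

```python
import torch
import torch.nn as nn

class MLPEncoder(nn.Module):
    """Hypothetical fully connected encoder with separate mean and log-variance heads."""

    def __init__(self, in_dim=784, hidden_dim=256, latent_dim=16):
        super().__init__()
        # Shared feature extractor (the "hidden layers").
        self.body = nn.Sequential(nn.Linear(in_dim, hidden_dim), nn.ReLU())
        # Two linear heads with no activation: both outputs are unconstrained.
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)

    def forward(self, x):
        h = self.body(x)
        return self.fc_mu(h), self.fc_logvar(h)

encoder = MLPEncoder()
x = torch.rand(4, 784)     # a random stand-in batch of 4 flattened inputs
mu, logvar = encoder(x)
sigma2 = logvar.exp()      # variance recovered by exponentiation, always positive
```

Both heads read the same hidden features, so the network shares most of its computation between the two sets of distribution parameters.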

### 7. Why Map to a Distribution?
- **Generative Modeling**: Allows the VAE to generate new, diverse samples by sampling $ z $ from the prior distribution.
- **Latent Space Structure**: Encourages similar inputs to have similar latent representations, creating a continuous and smooth latent space.
- **Handling Uncertainty**: Represents the model’s uncertainty about latent variables, which can be useful in downstream tasks.

### 8. Mathematical Summary
- **Encoder Outputs**:
  $$
  \boldsymbol{\mu}_{\phi}(x),\ \boldsymbol{\log\sigma^2}_{\phi}(x)
  $$
- **Sampling Latent Variable**:
  $$
  z = \boldsymbol{\mu}_{\phi}(x) + \boldsymbol{\sigma}_{\phi}(x) \odot \epsilon,\ \epsilon \sim \mathcal{N}(0, I)
  $$
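- **Decoder Outputs**: The decoder network maps $ z $ to a reconstruction $ \hat{x} $, which parameterizes the likelihood $ p_{\theta}(x|z) $:
  $$
  \hat{x} = \text{Decoder}_{\theta}(z)
  $$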
- **Loss Function**:
  $$
  \mathcal{L} = -\mathbb{E}_{q_{\phi}(z|x)}[\log p_{\theta}(x|z)] + D_{\text{KL}}(q_{\phi}(z|x) \parallel p(z))
  $$

### 9. Training Procedure
1. **Forward Pass**:
   - Input $ x $ is passed through the encoder to obtain $ \boldsymbol{\mu}_{\phi}(x) $ and $ \boldsymbol{\log\sigma^2}_{\phi}(x) $.
   - Sample $ \epsilon $ from $ \mathcal{N}(0, I) $.
   - Compute $ z $ using the reparameterization trick.
   - Pass $ z $ through the decoder to obtain $ \hat{x} $.

2. **Compute Loss**:
   - Calculate the reconstruction loss between $ x $ and $ \hat{x} $.
   - Compute the KL divergence between $ q_{\phi}(z|x) $ and $ p(z) $.

3. **Backward Pass**:
   - Backpropagate the gradients through the decoder and encoder networks.
   - Update $ \phi $ and $ \theta $ using an optimizer like Adam.
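
The three steps above can be sketched as a single optimization step. `TinyVAE`, its layer sizes, and the random batch are hypothetical and serve only to keep the example small; binary cross-entropy is used as the reconstruction loss, which assumes inputs in $ [0, 1] $.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

class TinyVAE(nn.Module):
    """Hypothetical small VAE used only to illustrate one training step."""

    def __init__(self, in_dim=20, hidden_dim=32, latent_dim=4):
        super().__init__()
        self.enc = nn.Sequential(nn.Linear(in_dim, hidden_dim), nn.ReLU())
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)
        self.dec = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, in_dim), nn.Sigmoid(),
        )

    def forward(self, x):
        h = self.enc(x)
        mu, logvar = self.fc_mu(h), self.fc_logvar(h)
        eps = torch.randn_like(mu)                # sample noise
        z = mu + torch.exp(0.5 * logvar) * eps    # reparameterization trick
        return self.dec(z), mu, logvar

model = TinyVAE()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
x = torch.rand(8, 20)  # random stand-in batch with values in [0, 1)

weights_before = model.fc_mu.weight.detach().clone()

# 1. Forward pass: encode, sample z, decode.
x_hat, mu, logvar = model(x)

# 2. Compute loss: reconstruction term plus closed-form KL term.
recon = F.binary_cross_entropy(x_hat, x, reduction="sum")
kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
loss = recon + kl

# 3. Backward pass and parameter update.
optimizer.zero_grad()
loss.backward()
optimizer.step()
```

A real training loop would repeat this step over minibatches drawn from a dataset for several epochs.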

### 10. Practical Considerations
- **Stability**: Outputting $ \boldsymbol{\log\sigma^2}_{\phi}(x) $ instead of $ \boldsymbol{\sigma}_{\phi}(x) $ helps with numerical stability and ensures that $ \boldsymbol{\sigma}^2_{\phi}(x) $ is positive after exponentiation.
- **Batch Training**: The expectation over the data is estimated by averaging the loss over minibatches, and the expectation over $ q_{\phi}(z|x) $ is usually approximated with a single sample of $ z $ per input.
- **Exploration of Latent Space**: After training, varying $ z $ smoothly can produce meaningful variations in $ \hat{x} $.
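
One common way to explore the latent space is linear interpolation between two latent codes. The sketch below uses an untrained stand-in `decoder` with hypothetical sizes; with a trained decoder, the decoded frames would be expected to change gradually from one endpoint to the other.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

latent_dim, data_dim = 8, 32  # hypothetical sizes

# Stand-in decoder; a trained VAE decoder would be used in practice.
decoder = nn.Sequential(nn.Linear(latent_dim, data_dim), nn.Sigmoid())

# Two latent codes, e.g. the encoder means of two different inputs.
z_start = torch.randn(latent_dim)
z_end = torch.randn(latent_dim)

# Five evenly spaced mixing weights between 0 and 1, shaped for broadcasting.
alphas = torch.linspace(0.0, 1.0, steps=5).unsqueeze(1)   # shape (5, 1)
z_path = (1 - alphas) * z_start + alphas * z_end          # shape (5, latent_dim)

with torch.no_grad():
    frames = decoder(z_path)  # one decoded output per interpolation step
```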

## Summary
In a Variational Autoencoder, the encoder network maps input data $ x $ to the parameters of a probability distribution over the latent variable $ z $ by outputting the mean and log-variance of a Gaussian distribution. The reparameterization trick allows for differentiable sampling of $ z $, enabling backpropagation during training. The encoder learns to produce distributions that capture the essential features of $ x $ while regularizing $ z $ to follow a prior distribution. This probabilistic mapping facilitates the generation of new data, interpolation between data points, and the learning of meaningful latent representations.


```python
import torch
import torch.nn as nn
import torch.nn.functional as F  # provides binary_cross_entropy used in vae_loss


class Reshape(nn.Module):
    """Reshape the input to a fixed shape so it can sit inside nn.Sequential."""

    def __init__(self, *shape):
        super().__init__()
        self.shape = shape

    def forward(self, x):
        return x.view(*self.shape)


class ConvVAE(nn.Module):
    def __init__(self, latent_dim=128):
        super().__init__()
        self.latent_dim = latent_dim

        # Encoder. The shape comments assume a 3 x 250 x 250 input.
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1),    # (64, 125, 125)
            nn.ReLU(True),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),  # (128, 63, 63)
            nn.ReLU(True),
            nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1), # (256, 32, 32)
            nn.ReLU(True),
            nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1), # (512, 16, 16)
            nn.ReLU(True),
            nn.Flatten()
        )

        # Latent space.
        # fc_mu maps the flattened encoder features to the mean of the latent
        # Gaussian, one value per latent dimension. This vector is the center of
        # the distribution from which the latent vector z is sampled.
        self.fc_mu = nn.Linear(512*16*16, latent_dim)
        # fc_logvar maps the same features to the log-variance of the Gaussian.
        # Predicting log-variance rather than variance is numerically more stable
        # and guarantees a positive variance, since exp(logvar) > 0.
        self.fc_logvar = nn.Linear(512*16*16, latent_dim)

        # Decoder
        self.decoder_input = nn.Linear(latent_dim, 512*16*16)

        self.decoder = nn.Sequential(
            Reshape(-1, 512, 16, 16),
            nn.ConvTranspose2d(512, 256, kernel_size=3, stride=2, padding=1, output_padding=1), # (256, 32, 32)
            nn.ReLU(True),
            nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, output_padding=0), # (128, 63, 63)
            nn.ReLU(True),
            nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=0),  # (64, 125, 125)
            nn.ReLU(True),
            nn.ConvTranspose2d(64, 3, kernel_size=3, stride=2, padding=1, output_padding=1),    # (3, 250, 250)
            nn.Sigmoid()  # Squash outputs into [0, 1]
        )

    def encode(self, x):
        # Returns mu and logvar, the parameters of the Gaussian distribution
        # from which the latent variable z will be sampled.
        x = self.encoder(x)
        mu = self.fc_mu(x)
        logvar = self.fc_logvar(x)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        # std = sqrt(var) = sqrt(exp(logvar)) = exp(0.5 * logvar)
        std = torch.exp(0.5 * logvar)
        # eps is noise from a standard normal distribution (mean 0, variance 1)
        # with the same shape as std. Drawing the noise separately injects
        # randomness while keeping the path from mu and logvar differentiable.
        eps = torch.randn_like(std)
        # Reparameterization trick: shift and scale the noise.
        return mu + eps * std

    def decode(self, z):
        x = self.decoder_input(z)
        x = self.decoder(x)
        return x

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar


# Loss function: summed reconstruction term plus closed-form KL term.
# binary_cross_entropy expects both recon_x and x to lie in [0, 1].
def vae_loss(recon_x, x, mu, logvar):
    BCE = F.binary_cross_entropy(recon_x, x, reduction='sum')
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return BCE + KLD
```
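
The shape comments in the model can be checked with the standard output-size formulas for convolutions and transposed convolutions. The sketch below assumes a 250 x 250 input, which is what the comments imply, and a dilation of 1; the helper names `conv_out` and `deconv_out` are hypothetical.

```python
def conv_out(size, kernel=3, stride=2, padding=1):
    """Spatial output size of a convolution (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1

def deconv_out(size, kernel=3, stride=2, padding=1, output_padding=0):
    """Spatial output size of a transposed convolution (dilation 1)."""
    return (size - 1) * stride - 2 * padding + kernel + output_padding

size = 250  # assumed input height and width

# Four stride-2 convolutions in the encoder.
encoder_sizes = []
for _ in range(4):
    size = conv_out(size)
    encoder_sizes.append(size)

# Four transposed convolutions in the decoder, with the output_padding
# values used in the model, in order.
decoder_sizes = []
for output_padding in (1, 0, 0, 1):
    size = deconv_out(size, output_padding=output_padding)
    decoder_sizes.append(size)

print(encoder_sizes)  # [125, 63, 32, 16]
print(decoder_sizes)  # [32, 63, 125, 250]
```

The alternating `output_padding` values are what allow the decoder to undo the odd intermediate sizes (63 and 125) and return to the original resolution. A different input size would change the flattened feature dimension, so the `512*16*16` in the linear layers would need to be adjusted accordingly.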