# 🧠 VAE vs. VQ-VAE on CIFAR-10  
### A Comparative Study in Latent Representation Learning

---

## 🎯 Objective

The goal of this project is to **train and compare** two generative autoencoding models,  
a **standard Variational Autoencoder (VAE)** and a **Vector Quantized Variational Autoencoder (VQ-VAE)**,  
using the **CIFAR-10** image dataset.  

Both models aim to **learn compact latent representations** of natural images,  
but they differ fundamentally in how the latent space is structured and optimized:

- **VAE** uses a *continuous* latent space, optimized via the reparameterization trick and a KL-divergence regularization term.  
- **VQ-VAE** introduces a *discrete* latent space using **vector quantization**, enabling the model to represent images through a finite learned codebook.
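A minimal sketch of the two ideas side by side, assuming toy sizes (a batch of 4, 8-dimensional latents, a 16-entry codebook) that are illustrative only and not the notebook's settings: the VAE draws a continuous latent with the reparameterization trick, while the VQ-VAE snaps each encoder vector to its nearest codebook entry and keeps only the integer index.

```python
import torch

torch.manual_seed(0)

# Continuous latent (VAE): z = mu + sigma * eps, eps ~ N(0, I)
mu = torch.zeros(4, 8)            # toy batch of 4, latent dim 8
logvar = torch.zeros(4, 8)        # log-variance 0 -> sigma = 1
eps = torch.randn_like(mu)
z_cont = mu + torch.exp(0.5 * logvar) * eps

# Discrete latent (VQ-VAE): replace each vector by its nearest codebook entry
codebook = torch.randn(16, 8)     # toy codebook: 16 codes of dim 8
z_e = torch.randn(4, 8)           # toy encoder outputs
idx = torch.cdist(z_e, codebook).argmin(dim=1)  # one integer code per vector
z_q = codebook[idx]

print(z_cont.shape, idx.tolist())
```

The VAE latent can take any real value, whereas the VQ-VAE latent is fully described by `idx`, a handful of integers in `[0, 16)`.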

By the end of this notebook, we will:
1. Train both models on CIFAR-10 images.  
2. Visualize reconstructions and generated samples.  
3. Analyze qualitative differences between continuous and discrete latent representations.

---

## 🧩 Why Compare VAE and VQ-VAE?

- **VAE** is a classical approach that provides smooth latent interpolation but often suffers from blurry reconstructions.  
- **VQ-VAE**, on the other hand, replaces the continuous latent space with a learned dictionary of embeddings, which can lead to sharper image reconstructions and more interpretable discrete codes.

This comparison helps demonstrate how quantization can influence both *image quality* and *latent structure*,  
providing insights that are directly relevant to downstream tasks like **compression**, **generation**, and **representation learning**.

---

## 📦 Dataset: CIFAR-10

- **Dataset**: CIFAR-10 (60,000 color images, 10 classes, 32×32 pixels)  
- **Training set**: 50,000 images  
- **Test set**: 10,000 images  
- **Normalization**: Pixel values scaled to [0, 1]  

CIFAR-10 is a balanced, compact dataset commonly used to benchmark generative models, making it ideal for comparing different VAE variants under similar training conditions.

---

## 🧰 Frameworks Used

This notebook is implemented with **PyTorch** and **TorchVision**, leveraging:
- `torch.nn` for defining deep neural networks  
- `torchvision.datasets` for loading CIFAR-10  
- `torch.utils.data.DataLoader` for efficient batching  
- `torchvision.utils` and `matplotlib` for visualization  
- `tqdm` for progress tracking during training  

---

## 📈 Expected Outcomes

By the end of this notebook, you will obtain:
- A trained **VAE** and **VQ-VAE** model on CIFAR-10  
- Visual comparisons of **reconstructed** and **sampled** images  
- A side-by-side view of how quantization affects reconstruction and sample quality  

---



## 🧩 Environment Setup and Imports

Before building and training our models, we begin by importing all the necessary Python libraries and modules.

- **Core Libraries:**  
  - `os`, `math`, `random`, and `pathlib` provide essential utilities for file management, mathematical operations, and reproducibility.  
  - `tqdm` adds progress bars for cleaner, real-time feedback during training.  
  - `numpy` and `matplotlib` are used for numerical computations and data visualization.  

- **PyTorch Framework:**  
  - `torch` and `torch.nn` form the foundation for defining and training deep learning models.  
  - `torch.nn.functional` provides functional interfaces for activation functions and loss computations.  
  - `torch.utils.data.DataLoader` handles batching and shuffling of training data efficiently.  

- **TorchVision Utilities:**  
  - `torchvision` and `torchvision.transforms` are used to load and preprocess the **CIFAR-10 dataset**.  
  - `torchvision.utils` provides convenient functions such as `make_grid` and `save_image` to visualize model outputs during and after training.

The `%matplotlib inline` magic command ensures that all plots generated using `matplotlib` are displayed directly within the notebook cells.

This setup provides a clean, modular foundation for developing and comparing our **VAE** and **VQ-VAE** models.


In [None]:
import os
import math
import random
from pathlib import Path
from tqdm.auto import tqdm
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision
from torchvision import transforms
from torchvision.utils import make_grid, save_image

## ⚙️ Hardware Setup, Reproducibility, and Precision Control

Before training our models, we configure key runtime settings to ensure consistent results and optimal GPU utilization.

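- **Mixed Precision Training:**  
  We use PyTorch's `GradScaler` and `autocast` utilities to enable **automatic mixed precision (AMP)**.  
  Under `autocast`, eligible operations such as convolutions and matrix multiplications run in float16, while precision-sensitive operations such as loss functions run in float32 and the model weights stay in float32. `GradScaler` scales the loss before backpropagation so that small float16 gradients do not underflow. This reduces GPU memory use and can speed up training; the **NVIDIA P100** has fast float16 arithmetic, although it lacks the Tensor Cores that give newer GPUs their largest AMP speedups.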

- **Reproducibility:**  
  To make results consistent across runs, we fix the random seeds for Python, NumPy, and PyTorch to `42`.  
  This makes weight initialization and data shuffling repeatable; bitwise-identical results on GPU would additionally require deterministic cuDNN settings, because some cuDNN kernels are nondeterministic.

- **Device Configuration:**  
  The training automatically uses **GPU (CUDA)** if available; otherwise, it falls back to CPU.  
  Using GPU acceleration is crucial for deep generative models like VAE and VQ-VAE, as it significantly speeds up both forward and backward passes.

- **Output Directory:**  
  Finally, an output directory named `./outputs` is created to store generated images, checkpoints, and training logs.  
  This helps maintain a clean and organized project structure throughout the experiment.
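If fully repeatable GPU runs matter, the seeding can be wrapped in a helper that also pins cuDNN's kernel selection. This is a sketch: `seed_everything` and its `deterministic` flag are hypothetical names, not part of this notebook, and even with these flags a few CUDA operations can remain nondeterministic.

```python
import random

import numpy as np
import torch


def seed_everything(seed: int = 42, deterministic: bool = False) -> None:
    """Seed Python, NumPy and PyTorch RNGs (hypothetical helper)."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)           # seeds the CPU generator and all CUDA devices
    torch.cuda.manual_seed_all(seed)  # explicit; silently ignored without CUDA
    if deterministic:
        # Trade speed for repeatable cuDNN kernel choices
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False


seed_everything(42)
a = torch.rand(3)
seed_everything(42)
b = torch.rand(3)
print(torch.equal(a, b))  # True: same seed, same random numbers
```

Disabling `cudnn.benchmark` usually costs some speed, which is why `deterministic` defaults to `False` in this sketch.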


In [None]:
# Mixed precision: the device-agnostic torch.amp API
# (torch.cuda.amp.GradScaler/autocast are deprecated in recent PyTorch releases)
from torch.amp import GradScaler, autocast

# Reproducibility
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

# Output directory
OUTDIR = Path("./outputs")
OUTDIR.mkdir(parents=True, exist_ok=True)

## 🎛️ Hyperparameter Configuration

This section defines the core hyperparameters that govern model training, optimization, and experiment behavior.  
The defaults below target training **VAE** and **VQ-VAE** on **CIFAR-10** on a single GPU such as an NVIDIA P100 (16 GB).

- **EPOCHS:**  
  Controls the total number of full passes through the dataset.  
  Higher values (150–300) generally improve convergence and reconstruction quality, at a proportional cost in training time.

- **BATCH_SIZE:**  
  Determines how many samples are processed per GPU iteration.  
  A batch size of `128` provides a good balance between GPU memory usage and stable gradient estimation on a 16GB P100.

- **LR (Learning Rate):**  
  The step size used by the optimizer.  
  A rate of `2e-4` is a common, stable choice for training VAEs and VQ-VAEs with Adam or AdamW optimizers.

- **VAE_ZDIM:**  
  Dimensionality of the latent space in the **standard VAE**.  
  This controls the capacity of the model to encode image features.

- **VQ_EMBED_DIM** and **VQ_N_EMBED:**  
  Define the embedding dimensionality and vocabulary size of the **VQ-VAE codebook**.  
  These parameters directly affect the compression level and the representational diversity of the learned discrete latent space.

- **BETA:**  
  The KL-divergence weight in the VAE loss function.  
  Adjusting `BETA` changes the trade-off between reconstruction fidelity and latent regularization.

- **NUM_SAMPLES:**  
  Number of generated images to sample at the end of training for qualitative comparison between models.

- **SAVE_EVERY:**  
  Specifies how frequently (in epochs) model checkpoints and image grids are saved to disk.

- **USE_AMP:**  
  Enables **Automatic Mixed Precision (AMP)** training for faster computation and reduced GPU memory load.
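To make the claim that `VQ_EMBED_DIM` and `VQ_N_EMBED` set the compression level concrete, here is a back-of-the-envelope calculation. It assumes 8-bit RGB input pixels and the 8×8 code grid produced by the two stride-2 convolutions in `VQVAE`; note that the number of codes, not the embedding dimension, determines the bits per position.

```python
import math

# Assumptions: 32x32 RGB input with 8-bit pixels, 8x8 VQ code grid
img_size = 32
channels = 3
n_codes = 512                                   # VQ_N_EMBED
raw_bits = img_size * img_size * channels * 8   # bits in the raw image

grid = img_size // 4                            # 32 -> 16 -> 8
vq_bits = grid * grid * math.log2(n_codes)      # one code index per latent position

print(raw_bits, vq_bits, raw_bits / vq_bits)    # 24576, 576.0 and a ratio of about 42.7
```

A VAE latent of `VAE_ZDIM = 128` float32 values occupies 4096 bits in memory, but because it is continuous this number is not directly comparable to the VQ-VAE's code budget.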

> 💡 *Tuning these parameters allows you to trade off between training speed, memory efficiency, and model quality.*


In [None]:
# For "full quality" on a P100: epochs ~ 150-300 recommended; set below and adjust to your run-time budget.
EPOCHS = 200                  # set to 200 for full-quality; reduce for quicker runs
BATCH_SIZE = 128              # P100 16GB: 128 should be fine; reduce if OOM
LR = 2e-4
VAE_ZDIM = 128
VQ_EMBED_DIM = 64
VQ_N_EMBED = 512
BETA = 1.0                    # VAE KL weight
NUM_SAMPLES = 40              # number of samples to generate at the end of training
SAVE_EVERY = 10               # epochs between saving checkpoints and sample grids
USE_AMP = True                # mixed precision training (recommended for speed on GPU)

print("Hyperparameters:")
print(f"EPOCHS={EPOCHS}, BATCH_SIZE={BATCH_SIZE}, LR={LR}, VAE_ZDIM={VAE_ZDIM}, VQ_EMBED_DIM={VQ_EMBED_DIM}, VQ_N_EMBED={VQ_N_EMBED}")


## 🧠 Dataset Preparation: CIFAR-10

This section sets up the **CIFAR-10 dataset**, which serves as the benchmark for both the **VAE** and **VQ-VAE** models.  
CIFAR-10 is a widely used computer vision dataset containing **60,000 color images** of size **32×32** across **10 object categories** (e.g., airplane, automobile, dog).  
It is divided into **50,000 training** and **10,000 test** images.

### 🔧 Transformations
The preprocessing pipeline is intentionally minimal:
- `transforms.ToTensor()` converts each PIL image (H×W×C, `uint8`) into a float tensor of shape C×H×W and rescales pixel values from `[0, 255]` to `[0, 1]`.

> 🧩 *No mean/std normalization or augmentation is applied.* Keeping pixels in `[0, 1]` matches the decoders' final `Sigmoid` and the MSE reconstruction target, and the goal is to compare latent representations rather than to maximize classification accuracy.

### ⚙️ DataLoader Configuration
- **`train_loader`** and **`test_loader`** deliver batches that the training loop then moves to the GPU.  
- `batch_size` matches the previously defined `BATCH_SIZE` hyperparameter.  
- `shuffle=True` (training set only) reshuffles the training order every epoch; the test set keeps a fixed order.  
- `num_workers=4` loads batches in 4 parallel worker processes.  
- `pin_memory=True` allocates batches in page-locked host memory, which speeds up host-to-GPU copies.

Overall, this setup ensures a **lightweight, stable, and reproducible data pipeline** suitable for unsupervised generative model training.


In [None]:
transform = transforms.Compose([
    transforms.ToTensor(),  # [0,1]
])

trainset = torchvision.datasets.CIFAR10(root="./data", train=True, download=True, transform=transform)
testset  = torchvision.datasets.CIFAR10(root="./data", train=False, download=True, transform=transform)

train_loader = DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4, pin_memory=True)
test_loader  = DataLoader(testset,  batch_size=BATCH_SIZE, shuffle=False, num_workers=4, pin_memory=True)

In [None]:
# helper: show grid inline
def show_grid(tensor_imgs, nrow=8, title=None, figsize=(10,4), savepath=None):
    grid = make_grid(tensor_imgs.cpu(), nrow=nrow, padding=2)
    npimg = grid.numpy().transpose(1,2,0)
    plt.figure(figsize=figsize)
    plt.axis('off')
    if title: plt.title(title)
    plt.imshow(np.clip(npimg, 0, 1))
    if savepath:
        plt.savefig(savepath, bbox_inches='tight', dpi=150)
    plt.show()

# show some real images
batch_vis = next(iter(train_loader))[0][:16]
show_grid(batch_vis, nrow=8, title="Sample real CIFAR-10 images", figsize=(8,2))

# Convolutional Variational Autoencoder (VAE) for Image Generation

Variational Autoencoders (VAEs) are a class of **probabilistic generative models** that learn to represent complex data distributions in a structured latent space. Unlike standard autoencoders, VAEs impose a **probabilistic structure** on the latent space, allowing for smooth interpolation, sample generation, and principled uncertainty estimation.

This section presents a **convolutional VAE (ConvVAE)** architecture designed for small RGB images such as those in CIFAR-10. Convolutional layers in both the encoder and decoder capture hierarchical spatial features, while the KL term keeps the latent space regularized; as noted above, reconstructions from a standard VAE tend to be blurrier than those of a VQ-VAE.

The key objectives of this notebook are:

1. **Implementation:** Build a convolutional VAE from scratch using PyTorch.
2. **Latent Representation:** Learn a compact, structured latent space for image data.
3. **Image Reconstruction:** Demonstrate the model's ability to reconstruct input images with fidelity.
4. **Generative Capability:** Sample from the latent space to generate novel images.
5. **Analysis:** Examine reconstruction quality and latent space properties for interpretability.

This work serves as both a **teaching resource** for understanding VAEs and a **foundation for research applications** in generative modeling, representation learning, and image synthesis.
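The convolutional encoder and decoder implemented next rely on the standard output-size formulas for `Conv2d` and `ConvTranspose2d` with kernel 4, stride 2, padding 1: each encoder layer halves the spatial size and each decoder layer doubles it. The sketch below (the helpers `conv_out` and `deconv_out` are hypothetical) computes the sizes and cross-checks them against real layers, which also confirms the `256*4*4` input size of `fc_mu` and `fc_logvar`.

```python
import torch
import torch.nn as nn


def conv_out(size, k=4, s=2, p=1):
    """Spatial output size of Conv2d: floor((size + 2p - k) / s) + 1."""
    return (size + 2 * p - k) // s + 1


def deconv_out(size, k=4, s=2, p=1):
    """Spatial output size of ConvTranspose2d (no output_padding): (size - 1) * s - 2p + k."""
    return (size - 1) * s - 2 * p + k


# Encoder: three stride-2 convs
enc_sizes = [32]
for _ in range(3):
    enc_sizes.append(conv_out(enc_sizes[-1]))

# Decoder: three stride-2 transposed convs
dec_sizes = [4]
for _ in range(3):
    dec_sizes.append(deconv_out(dec_sizes[-1]))

print(enc_sizes, dec_sizes)  # [32, 16, 8, 4] [4, 8, 16, 32]

# Cross-check with real layers using the same channel widths as the VAE encoder
enc = nn.Sequential(
    nn.Conv2d(3, 64, 4, 2, 1), nn.ReLU(True),
    nn.Conv2d(64, 128, 4, 2, 1), nn.ReLU(True),
    nn.Conv2d(128, 256, 4, 2, 1), nn.ReLU(True),
)
h = enc(torch.zeros(2, 3, 32, 32))
print(h.shape, h.flatten(1).shape)  # torch.Size([2, 256, 4, 4]) torch.Size([2, 4096])
```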


In [None]:
# Cell 4: Standard Conv VAE implementation
class ConvEncoderVAE(nn.Module):
    def __init__(self, z_dim=128):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(3, 64, 4, 2, 1),   # 64 x 16 x 16
            nn.ReLU(True),
            nn.Conv2d(64, 128, 4, 2, 1), # 128 x 8 x 8
            nn.ReLU(True),
            nn.Conv2d(128, 256, 4, 2, 1),# 256 x 4 x 4
            nn.ReLU(True),
        )
        self.fc_mu = nn.Linear(256*4*4, z_dim)
        self.fc_logvar = nn.Linear(256*4*4, z_dim)
    def forward(self, x):
        h = self.conv(x)
        h = h.view(h.size(0), -1)
        mu = self.fc_mu(h)
        logvar = self.fc_logvar(h)
        return mu, logvar

class ConvDecoderVAE(nn.Module):
    def __init__(self, z_dim=128):
        super().__init__()
        self.fc = nn.Linear(z_dim, 256*4*4)
        self.deconv = nn.Sequential(
            nn.ConvTranspose2d(256, 128, 4, 2, 1), # 128 x 8 x 8
            nn.ReLU(True),
            nn.ConvTranspose2d(128, 64, 4, 2, 1),  # 64 x 16 x 16
            nn.ReLU(True),
            nn.ConvTranspose2d(64, 3, 4, 2, 1),    # 3 x 32 x 32
            nn.Sigmoid(),
        )
    def forward(self, z):
        h = self.fc(z).view(-1, 256, 4, 4)
        xrec = self.deconv(h)
        return xrec

class VAE(nn.Module):
    def __init__(self, z_dim=128):
        super().__init__()
        self.enc = ConvEncoderVAE(z_dim=z_dim)
        self.dec = ConvDecoderVAE(z_dim=z_dim)
    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std
    def forward(self, x):
        mu, logvar = self.enc(x)
        z = self.reparameterize(mu, logvar)
        xr = self.dec(z)
        return xr, mu, logvar

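def vae_loss(recon_x, x, mu, logvar, beta=1.0):
    """VAE objective for a batch: summed squared error + beta * KL(q(z|x) || N(0, I))."""
    # Squared error summed over pixels and batch; train_epoch_vae divides by dataset size
    recon_loss = F.mse_loss(recon_x, x, reduction='sum')
    # Closed-form KL between N(mu, exp(logvar)) and the standard normal prior
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())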
    return recon_loss + beta * kld, recon_loss, kld
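The closed-form KL term in `vae_loss` can be sanity-checked against PyTorch's own `torch.distributions.kl_divergence` for two `Normal` distributions. A small sketch with random `mu` and `logvar` values (illustrative, not model outputs):

```python
import torch
from torch.distributions import Normal, kl_divergence

torch.manual_seed(0)
mu = torch.randn(5, 128)              # toy batch of 5, latent dim 128
logvar = torch.randn(5, 128) * 0.5

# Closed form, as in vae_loss
kld_closed = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

# Reference: KL(N(mu, sigma^2) || N(0, 1)) summed over batch and latent dimensions
q = Normal(mu, torch.exp(0.5 * logvar))
p = Normal(torch.zeros_like(mu), torch.ones_like(mu))
kld_ref = kl_divergence(q, p).sum()

print(kld_closed.item(), kld_ref.item())  # the two values agree up to float32 rounding
```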

# Vector-Quantized Variational Autoencoder (VQ-VAE)

This section implements a **Vector-Quantized Variational Autoencoder (VQ-VAE)**, a generative model with discrete latents that combines convolutional neural networks with vector quantization. Unlike standard VAEs, VQ-VAEs learn a **discrete latent codebook**, which suits compression and typically yields sharper reconstructions.

The model architecture consists of three main components:

1. **Encoder:** Maps the input image to a continuous latent feature map.
2. **Vector Quantizer (VQ) with EMA:** Discretizes the latent features using a learned codebook whose embeddings are updated with **Exponential Moving Averages (EMA)** of the assigned encoder outputs rather than by gradient descent, which tends to make codebook learning more stable.
3. **Decoder:** Reconstructs images from quantized latent codes.
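The nearest-codebook lookup in the quantizer avoids materializing all pairwise differences by expanding \( \|a-b\|^2 = \|a\|^2 - 2a^\top b + \|b\|^2 \) into a single matrix product. A sketch with random vectors (sizes mirror `VQ_EMBED_DIM=64` and `VQ_N_EMBED=512`; the data is synthetic) compares it with `torch.cdist`:

```python
import torch

torch.manual_seed(0)
flat = torch.randn(100, 64)       # 100 synthetic encoder vectors of dim 64
codebook = torch.randn(512, 64)   # 512 synthetic codes

# ||a - b||^2 = ||a||^2 - 2 a.b + ||b||^2, for all (vector, code) pairs at once
dist = (
    flat.pow(2).sum(dim=1, keepdim=True)
    - 2 * flat @ codebook.t()
    + codebook.pow(2).sum(dim=1).unsqueeze(0)
)
idx = dist.argmin(dim=1)

# Reference: explicit pairwise Euclidean distances
ref = torch.cdist(flat, codebook)
idx_ref = ref.argmin(dim=1)
print((idx == idx_ref).float().mean().item())  # fraction of matching assignments
```

Since `argmin` of the squared distance equals `argmin` of the distance, the square root can be skipped entirely.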

---

In [None]:
class VQEmbeddingEMA(nn.Module):
    """
    Vector quantizer whose codebook is updated with exponential moving averages (EMA)
    instead of gradients. Distances and EMA statistics are computed in float32 so the
    quantizer stays numerically stable inside a float16 autocast region.
    """
    def __init__(self, num_embeddings, embedding_dim, commitment_cost=0.25, decay=0.99, eps=1e-5):
        super().__init__()
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.commitment_cost = commitment_cost
        self.decay = decay
        self.eps = eps

        # Codebook and EMA statistics are buffers: saved in state_dict, not updated by the optimizer
        embed = torch.randn(num_embeddings, embedding_dim)
        self.register_buffer('embedding', embed)
        self.register_buffer('ema_cluster_size', torch.zeros(num_embeddings))
        self.register_buffer('ema_w', embed.clone())
        # A buffer (not a Python bool) so a loaded checkpoint is not re-initialized on first use
        self.register_buffer('initialized', torch.tensor(False))

    @torch.no_grad()
    def _initialize_embeddings(self, flat):
        """Initialize the codebook from random encoder outputs of the first training batch."""
        N = flat.size(0)
        n = min(self.num_embeddings, N)
        perm = torch.randperm(N, device=flat.device)
        init = flat[perm[:n]]
        if n < self.num_embeddings:
            pad = torch.randn(self.num_embeddings - n, flat.size(1), device=flat.device)
            init = torch.cat([init, pad], dim=0)
        self.embedding.copy_(init)
        self.ema_w.copy_(init)
        self.ema_cluster_size.zero_()
        self.initialized.fill_(True)

    def forward(self, z_e):
        """
        Args:
            z_e: Encoder output (B, C, H, W)
        Returns:
            quantized: Quantized latent (B, C, H, W) with straight-through gradients
            vq_loss: Codebook + commitment loss
            indices: Code indices (B, H, W)
        """
        b, c, h, w = z_e.shape
        # Run the quantizer in float32 even when called under float16 autocast
        with torch.autocast(device_type=z_e.device.type, enabled=False):
            z_e = z_e.float()
            flat = z_e.permute(0, 2, 3, 1).reshape(-1, c)

            # Lazy data-dependent initialization (training only)
            if self.training and not bool(self.initialized):
                self._initialize_embeddings(flat.detach())

            # Squared L2 distances: ||x||^2 - 2 x.e + ||e||^2
            distances = (
                flat.pow(2).sum(dim=1, keepdim=True)
                - 2 * flat @ self.embedding.t()
                + self.embedding.pow(2).sum(dim=1).unsqueeze(0)
            )
            encoding_indices = torch.argmin(distances, dim=1)
            encodings = F.one_hot(encoding_indices, self.num_embeddings).type(flat.dtype)

            # Quantized output
            quantized = (encodings @ self.embedding).view(b, h, w, c).permute(0, 3, 1, 2)

            # EMA codebook update (training only)
            if self.training:
                with torch.no_grad():
                    enc_sum = encodings.sum(0)
                    dw = encodings.t() @ flat
                    self.ema_cluster_size.mul_(self.decay).add_(enc_sum, alpha=1 - self.decay)
                    self.ema_w.mul_(self.decay).add_(dw, alpha=1 - self.decay)

                    # Laplace smoothing avoids dividing by zero for unused codes
                    n = self.ema_cluster_size.sum()
                    cluster_size = ((self.ema_cluster_size + self.eps) /
                                    (n + self.num_embeddings * self.eps)) * n
                    self.embedding.copy_(self.ema_w / cluster_size.unsqueeze(1))

            # Codebook term: no gradient here, since the codebook is a buffer updated by EMA
            q_latent_loss = F.mse_loss(quantized, z_e.detach())
            # Commitment term: pulls encoder outputs toward their assigned codes
            e_latent_loss = F.mse_loss(quantized.detach(), z_e)
            vq_loss = q_latent_loss + self.commitment_cost * e_latent_loss

            # Straight-through estimator: forward uses the codes, backward copies gradients to z_e
            quantized = z_e + (quantized - z_e).detach()

        indices = encoding_indices.view(b, h, w)
        return quantized, vq_loss, indices

# -----------------------------
# VQ-VAE model
# -----------------------------
class VQVAE(nn.Module):
    def __init__(self, embedding_dim=64, n_embeddings=512, encoder_base_channels=(128,256)):
        super().__init__()
        ch1, ch2 = encoder_base_channels

        # Encoder network
        self.enc_body = nn.Sequential(
            nn.Conv2d(3, ch1, 4, 2, 1),  # 32 → 16
            nn.ReLU(True),
            nn.Conv2d(ch1, ch2, 4, 2, 1),  # 16 → 8
            nn.ReLU(True),
        )

        # Project to embedding dimension
        self.enc_to_vq = nn.Conv2d(ch2, embedding_dim, 1)

        # Vector Quantizer
        self.quantizer = VQEmbeddingEMA(n_embeddings, embedding_dim, commitment_cost=0.25)

        # Decoder network
        self.dec = nn.Sequential(
            nn.ConvTranspose2d(embedding_dim, 256, 4, 2, 1),
            nn.ReLU(True),
            nn.ConvTranspose2d(256, 128, 4, 2, 1),
            nn.ReLU(True),
            nn.Conv2d(128, 3, 1),
            nn.Sigmoid()
        )
    def enc(self, x):
        """Encodes input to latent feature map before quantization."""
        return self.enc_to_vq(self.enc_body(x))

    def forward(self, x):
        z = self.enc_body(x)
        z_e = self.enc_to_vq(z)
        z_q, vq_loss, indices = self.quantizer(z_e)
        x_recon = self.dec(z_q)
        return x_recon, vq_loss, indices
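To see what the EMA update in `VQEmbeddingEMA` converges to, the toy sketch below feeds one code a stream of noisy vectors around a fixed target and applies the same update rule (decayed counts, decayed sums, Laplace-smoothed division). Assumptions: the assignment is forced to code 0 instead of being computed by nearest neighbour, and `ema_cluster_size` starts at ones, whereas the class above starts it at zeros.

```python
import torch

torch.manual_seed(0)
K, D = 4, 2                      # toy codebook: 4 codes of dim 2
decay, eps = 0.99, 1e-5

embedding = torch.randn(K, D)
ema_w = embedding.clone()
ema_cluster_size = torch.ones(K)  # assumption: start at 1, not 0

target = torch.tensor([3.0, -1.0])
for _ in range(2000):
    # 32 noisy vectors around `target`, all assigned to code 0 (fixed for illustration)
    flat = target + 0.1 * torch.randn(32, D)
    encodings = torch.zeros(32, K)
    encodings[:, 0] = 1.0

    ema_cluster_size.mul_(decay).add_(encodings.sum(0), alpha=1 - decay)
    ema_w.mul_(decay).add_(encodings.t() @ flat, alpha=1 - decay)

    n = ema_cluster_size.sum()
    cluster_size = (ema_cluster_size + eps) / (n + K * eps) * n   # Laplace smoothing
    embedding = ema_w / cluster_size.unsqueeze(1)

print(embedding[0])  # close to target
```

The assigned code settles at the running mean of the vectors routed to it, an online k-means-style update. Codes that never receive assignments are never pulled toward the data, which is how dead codes arise.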



# Generative Modeling with VAE and VQ-VAE

This notebook demonstrates the implementation and training of two powerful generative models for image reconstruction and synthesis: a **Convolutional Variational Autoencoder (VAE)** and a **Vector-Quantized VAE (VQ-VAE)**. Both models leverage convolutional neural networks to capture hierarchical spatial features in images, while employing distinct latent representations.

---

## 1. Convolutional Variational Autoencoder (VAE)

The VAE is a probabilistic generative model that maps input images to a **continuous latent space** and reconstructs them using a decoder network.

- **Encoder:** Convolutional layers reduce spatial resolution and output the **mean (`μ`)** and **log-variance (`log σ²`)** of a latent Gaussian distribution.
- **Reparameterization Trick:** Samples the latent vector \( z = \mu + \sigma \odot \epsilon \) with \( \epsilon \sim \mathcal{N}(0, I) \) and \( \sigma = \exp(\tfrac{1}{2}\log\sigma^2) \), so that gradients can flow to \( \mu \) and \( \sigma \) through the sampling step.
- **Decoder:** Transpose convolutions reconstruct the image from latent samples.
- **Loss Function:** Combines a reconstruction loss (squared error summed over pixels, `F.mse_loss(..., reduction='sum')`) and the Kullback-Leibler divergence:
\[
\mathcal{L}_{VAE} = \text{MSE}(\hat{x}, x) + \beta \cdot \text{KL}(q(z|x)\|p(z))
\]
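For a diagonal Gaussian posterior \( q(z|x) = \mathcal{N}(\mu, \operatorname{diag}\sigma^2) \) and a standard normal prior \( p(z) = \mathcal{N}(0, I) \), the KL term has the closed form below, with \( d \) the latent dimension (`VAE_ZDIM`). It is the expression `vae_loss` evaluates from `mu` and `logvar`:
\[
\text{KL}(q(z|x)\|p(z)) = -\frac{1}{2}\sum_{j=1}^{d}\left(1 + \log\sigma_j^2 - \mu_j^2 - \sigma_j^2\right)
\]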

---

## 2. Vector-Quantized VAE (VQ-VAE)

The VQ-VAE uses a **discrete latent space** with a learned codebook of embeddings, enabling high-quality image reconstruction and efficient compression.

- **Encoder:** Maps images to a latent feature map.
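- **Vector Quantizer with EMA:** Discretizes latent features by assigning each feature vector to its nearest codebook vector. The codebook is updated with Exponential Moving Averages (EMA) of the assigned encoder outputs instead of by gradient descent, which tends to make codebook learning more stable.
- **Decoder:** Reconstructs images from quantized latent codes using transpose convolutions.
- **Loss Function:** Combines the reconstruction loss with the codebook and commitment terms, where \( \operatorname{sg}[\cdot] \) denotes stop-gradient (`.detach()`) and \( \beta_c \) is the commitment cost (`commitment_cost=0.25`, distinct from the VAE's KL weight `BETA`):
\[
\mathcal{L}_{VQ\text{-}VAE} = \text{MSE}(\hat{x}, x) + \text{MSE}\big(z_q, \operatorname{sg}[z_e]\big) + \beta_c \cdot \text{MSE}\big(\operatorname{sg}[z_q], z_e\big)
\]
  With EMA updates the codebook is a buffer rather than a trainable parameter, so the middle term is reported but contributes no gradient; only the commitment term shapes the encoder.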
- **Straight-Through Estimator:** Ensures gradients flow through the non-differentiable quantization operation.
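The straight-through estimator can be checked directly: the forward value of `z_e + (z_q - z_e).detach()` equals the quantized codes, while the gradient with respect to `z_e` is exactly the upstream gradient, as if quantization were the identity. A toy sketch with random tensors:

```python
import torch

torch.manual_seed(0)
z_e = torch.randn(2, 4, requires_grad=True)   # toy encoder output
codebook = torch.randn(8, 4)                   # toy codebook

idx = torch.cdist(z_e.detach(), codebook).argmin(dim=1)
z_q = codebook[idx]

# Straight-through: forward value is z_q, backward treats quantization as identity
z_st = z_e + (z_q - z_e).detach()

loss = (z_st * torch.arange(4.0)).sum()   # any downstream function of z_st
loss.backward()

print(torch.allclose(z_st, z_q), z_e.grad)  # True, and each row of the gradient is [0, 1, 2, 3]
```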

---

## 3. Training Setup

- **Optimizers:** `Adam` is used separately for VAE and VQ-VAE.
- **Mixed-Precision Training:** Optional AMP via `GradScaler` for faster computation and lower memory usage.
- **Metric Tracking:** History dictionaries record reconstruction loss, KL divergence, and quantization losses per epoch.
- **Visualization Utility:** `save_grid()` function saves reconstructed or generated images in a grid format for inspection.

---

## 4. Summary

This notebook establishes a **complete generative modeling pipeline**:

1. **VAE**: Continuous latent space for smooth interpolation and probabilistic reconstruction.
2. **VQ-VAE**: Discrete latent space with EMA-updated codebook for high-fidelity image generation.
3. **Optimized Training**: Separate optimizers and optional AMP scalers for efficiency.
4. **Evaluation**: Metric tracking and image grid visualization enable monitoring of reconstruction quality and training progress.

Together, these components provide a robust framework for **image representation learning, generative modeling, and latent-space exploration**.


In [None]:
# Cell 6: Instantiate models, optimizers, scalers
vae = VAE(z_dim=VAE_ZDIM).to(device)
vqvae = VQVAE(embedding_dim=VQ_EMBED_DIM, n_embeddings=VQ_N_EMBED).to(device)

opt_vae = torch.optim.Adam(vae.parameters(), lr=LR)
opt_vq  = torch.optim.Adam(vqvae.parameters(), lr=LR)

# torch.amp.GradScaler takes the device type as its first argument
scaler_vae = torch.amp.GradScaler('cuda', enabled=USE_AMP)
scaler_vq  = torch.amp.GradScaler('cuda', enabled=USE_AMP)

# trackers
history = {
    'vae_loss': [], 'vae_recon': [], 'vae_kld': [],
    'vq_loss': [], 'vq_recon': [], 'vq_vqterm': []
}

# helper: save samples grid
def save_grid(tensor, path, nrow=8):
    save_image(tensor.cpu(), path, nrow=nrow)


In [None]:
# Cell 7: Training loops with visualization
def train_epoch_vae(model, dataloader, optimizer, scaler, epoch):
    model.train()
    running_loss = 0.0
    running_recon = 0.0
    running_kld = 0.0
    pbar = tqdm(dataloader, desc=f"VAE Train E{epoch}")
    for x, _ in pbar:
        x = x.to(device)
        optimizer.zero_grad()
        with autocast('cuda', enabled=USE_AMP):
            xr, mu, logvar = model(x)
            loss, rec, kld = vae_loss(xr, x, mu, logvar, beta=BETA)
            loss_val = loss
        scaler.scale(loss_val).backward()
        scaler.step(optimizer)
        scaler.update()
        running_loss += loss_val.item()
        running_recon += rec.item()
        running_kld += kld.item()
        pbar.set_postfix({'loss': f"{running_loss/((pbar.n+1)*x.size(0)):.4f}"})
    n = len(dataloader.dataset)
    return running_loss/n, running_recon/n, running_kld/n

def train_epoch_vqvae(model, dataloader, optimizer, scaler, epoch, device):
    model.train()
    total_loss, recon_loss, vq_loss = 0.0, 0.0, 0.0
    n_steps = 0

    for x, _ in tqdm(dataloader, desc=f"VQ-VAE Train E{epoch}"):
        x = x.to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)
        # Respect the USE_AMP switch instead of always forcing float16
        with torch.amp.autocast(device_type='cuda', dtype=torch.float16, enabled=USE_AMP):
            x_recon, vq_loss_item, _ = model(x)
            recon = F.mse_loss(x_recon, x)   # per-value mean (the VAE uses a per-batch sum)
            loss = recon + vq_loss_item

        # Skip batches with a non-finite loss (NaN or Inf)
        if not torch.isfinite(loss):
            continue

        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)   # unscale first so clipping sees the true gradient norm
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()

        total_loss += loss.item()
        recon_loss += recon.item()
        vq_loss += vq_loss_item.item()
        n_steps += 1

    n = max(n_steps, 1)   # average over the batches that were actually used
    return total_loss/n, recon_loss/n, vq_loss/n
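The two training functions report reconstruction losses in different units: `train_epoch_vae` returns the squared error summed per image, while `train_epoch_vqvae` returns the per-value MSE averaged over batches. A hedged sketch for putting both on a common scale, assuming 3×32×32 images in `[0, 1]` (the helper names are hypothetical and the input numbers are purely illustrative):

```python
import math

VALUES_PER_IMAGE = 3 * 32 * 32   # 3072 pixel values per CIFAR-10 image


def per_image_sse_to_mse(sse_per_image: float) -> float:
    """Convert a per-image summed squared error to a per-value MSE."""
    return sse_per_image / VALUES_PER_IMAGE


def psnr(mse: float, max_val: float = 1.0) -> float:
    """Peak signal-to-noise ratio in dB for images with values in [0, max_val]."""
    return 10.0 * math.log10(max_val ** 2 / mse)


# Illustrative inputs only, not results from this notebook
vae_recon_per_image = 30.72   # units of rec_epoch from train_epoch_vae (per-image SSE)
vq_recon_mse = 0.004          # units of rec_epoch from train_epoch_vqvae (per-value MSE)

vae_mse = per_image_sse_to_mse(vae_recon_per_image)
print(vae_mse, psnr(vae_mse), psnr(vq_recon_mse))
```

PSNR depends only on the per-value MSE, which makes it a convenient common yardstick. The VQ-VAE figure is a mean of batch means, so it is close to, but not exactly, the dataset-level MSE when the last batch is smaller.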


In [None]:
# small helper to show reconstructions
@torch.no_grad()
def visualize_reconstructions(model, mode='vae', n=8, savepath=None, step_label=""):
    model.eval()
    x = next(iter(test_loader))[0][:n].to(device)
    if mode == 'vae':
        xr, mu, logvar = model(x)
    else:  # vqvae
        xr, vq_loss, _ = model(x)
    combined = torch.cat([x.cpu(), xr.cpu()], dim=0)
    title = f"{mode.upper()} Reconstructions {step_label}"
    if savepath:
        save_grid(combined, savepath, nrow=n)
    show_grid(combined, nrow=n, title=title)

# collect code-index histogram (for sampling later)
@torch.no_grad()
def collect_vq_code_histogram(vq_model, loader):
    vq_model.eval()
    K = vq_model.quantizer.num_embeddings
    counts = torch.zeros(K, dtype=torch.float64)
    for x, _ in tqdm(loader, desc="Collect VQ code histogram"):
        x = x.to(device)
        _, _, indices = vq_model(x)
        # Vectorized count of every code index in the batch (no Python loop per element)
        counts += torch.bincount(indices.view(-1).cpu(), minlength=K).double()
    probs = (counts + 1e-6) / (counts.sum() + 1e-6 * K)  # small additive smoothing
    return probs.numpy()

# sampling for VAE
@torch.no_grad()
def sample_vae(model, n, z_dim):
    model.eval()
    z = torch.randn(n, z_dim, device=device)
    xr = model.dec(z)
    return xr.clamp(0,1)

# sampling for VQ-VAE using empirical code distribution
@torch.no_grad()
def sample_vqvae_empirical(vq_model, n, code_probs=None):
    vq_model.eval()
    K = vq_model.quantizer.num_embeddings

    # Infer the latent grid size (8x8 for 32x32 inputs) from a dummy encoder pass
    dummy = torch.zeros(1, 3, 32, 32, device=device)
    _, c, h, w = vq_model.enc(dummy).shape

    if code_probs is None:
        code_probs = np.ones(K) / K   # uniform over the codebook

    # Sample each grid position independently from the code distribution
    idx = np.random.choice(K, size=(n, h, w), p=code_probs)
    idx = torch.as_tensor(idx, dtype=torch.long, device=device)

    # Look up embeddings: (n, h, w, C) -> (n, C, h, w)
    emb = vq_model.quantizer.embedding.to(device)   # [K, C]
    emb_t = emb[idx].permute(0, 3, 1, 2).contiguous()

    xr = vq_model.dec(emb_t)
    return xr.clamp(0,1)
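Beyond the raw histogram, two summary numbers are commonly used to judge codebook health: the number of dead codes and the codebook perplexity \( \exp(-\sum_k p_k \log p_k) \), which equals \( K \) when all \( K \) codes are used equally often and 1 when only a single code is used. A sketch (`codebook_usage` is a hypothetical helper; the index grid is synthetic):

```python
import torch


def codebook_usage(indices: torch.Tensor, num_embeddings: int):
    """Return (probs, perplexity, n_dead) for a tensor of code indices."""
    counts = torch.bincount(indices.flatten(), minlength=num_embeddings).float()
    probs = counts / counts.sum()
    nz = probs[probs > 0]                      # 0 * log 0 is treated as 0
    perplexity = torch.exp(-(nz * nz.log()).sum()).item()
    n_dead = int((counts == 0).sum())
    return probs, perplexity, n_dead


# Synthetic example: two 8x8 index grids that use only 4 of 512 codes, uniformly
idx = torch.arange(4).repeat(2 * 8 * 8 // 4).view(2, 8, 8)
probs, ppl, dead = codebook_usage(idx, 512)
print(round(ppl, 3), dead)   # 4.0 508
```

Applied to the `indices` returned by `vqvae`, a perplexity far below `VQ_N_EMBED` would indicate that most of the codebook goes unused.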


In [None]:
import logging

# Setup output directory
OUTDIR.mkdir(parents=True, exist_ok=True)

# Configure logging to a file plus the notebook output.
# force=True (Python 3.8+) replaces existing root handlers, so re-running this
# cell does not duplicate log lines.
log_file = OUTDIR / "training.log"
logging.basicConfig(
    filename=str(log_file),
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
    filemode='w',
    force=True,
)
console = logging.StreamHandler()
console.setLevel(logging.INFO)
formatter = logging.Formatter("%(asctime)s [%(levelname)s] %(message)s")
console.setFormatter(formatter)
logging.getLogger().addHandler(console)

logging.info("🚀 Starting training")


# VAE Training Loop

This cell implements the **training procedure for the convolutional VAE**, including mixed-precision support, metric tracking, checkpointing, and sample visualization.

---

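## 1. GradScaler for Mixed-Precision Training

```python
scaler_vae = torch.amp.GradScaler("cuda", enabled=USE_AMP)
```

`GradScaler` multiplies the loss before `backward()` so that small float16 gradients do not underflow, unscales the gradients inside `scaler.step()`, and skips the optimizer step if any gradient is `inf` or `NaN`. With `enabled=False` it is a pass-through.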


In [None]:
# --- GradScaler
scaler_vae = torch.amp.GradScaler("cuda", enabled=USE_AMP)

history_vae = {"loss": [], "recon": [], "kld": []}
best_vaeloss = float("inf")

vae.train()
for epoch in range(1, EPOCHS + 1):
    loss_epoch, rec_epoch, kld_epoch = train_epoch_vae(
        vae, train_loader, opt_vae, scaler_vae, epoch
    )

    history_vae["loss"].append(loss_epoch)
    history_vae["recon"].append(rec_epoch)
    history_vae["kld"].append(kld_epoch)

    logging.info(f"VAE Epoch {epoch}: total={loss_epoch:.4f}, recon={rec_epoch:.4f}, kld={kld_epoch:.4f}")

    # Periodic reconstructions
    if epoch % SAVE_EVERY == 0 or epoch == 1 or epoch == EPOCHS:
        vpath = OUTDIR / f"vae_recon_e{epoch}.png"
        visualize_reconstructions(vae, mode='vae', n=8, savepath=str(vpath), step_label=f"epoch{epoch}")

        vae_samples = sample_vae(vae, n=min(36, NUM_SAMPLES), z_dim=VAE_ZDIM)
        save_grid(vae_samples, OUTDIR / f"vae_samples_e{epoch}.png", nrow=6)

    # Save checkpoint
    if epoch % SAVE_EVERY == 0 or epoch == EPOCHS:
        torch.save(vae.state_dict(), OUTDIR / f"vae_epoch{epoch}.pt")
        logging.info(f"💾 VAE checkpoint saved at epoch {epoch}")

# Final save
torch.save(vae.state_dict(), OUTDIR / "vae_final.pt")
logging.info("✅ VAE training complete")


# VQ-VAE Training Loop

This cell implements the **training procedure for the Vector-Quantized Variational Autoencoder (VQ-VAE)**, including mixed-precision support, metric tracking, checkpointing, and sample visualization.

---

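## 1. GradScaler for Mixed-Precision Training

```python
scaler_vq = torch.amp.GradScaler(enabled=USE_AMP)
```

This plays the same role as `scaler_vae`; without an explicit device argument, `torch.amp.GradScaler` defaults to `"cuda"`.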


In [None]:
# ------------------------
# GradScaler
# ------------------------
scaler_vq = torch.amp.GradScaler(enabled=USE_AMP)

# History tracking
history_vq = {"loss": [], "recon": [], "vqterm": []}
best_vqloss = float("inf")

# Ensure output folder exists
(VQDIR := OUTDIR / "vq_vae").mkdir(parents=True, exist_ok=True)

# ------------------------
# Training loop
# ------------------------
vqvae.train()
for epoch in range(1, EPOCHS + 1):
    loss_epoch, rec_epoch, vq_epoch = train_epoch_vqvae(
        vqvae, train_loader, opt_vq, scaler_vq, epoch, device
    )

    history_vq["loss"].append(loss_epoch)
    history_vq["recon"].append(rec_epoch)
    history_vq["vqterm"].append(vq_epoch)

    logging.info(
        f"VQ-VAE Epoch {epoch}: total={loss_epoch:.4f}, recon={rec_epoch:.4f}, vq={vq_epoch:.4f}"
    )

    # ------------------------
    # Periodic reconstructions
    # ------------------------
    if epoch % SAVE_EVERY == 0 or epoch == 1 or epoch == EPOCHS:
        vqpath = VQDIR / f"vq_recon_e{epoch}.png"
        visualize_reconstructions(
            vqvae, mode='vq', n=8, savepath=str(vqpath), step_label=f"epoch{epoch}"
        )

        # Uniform codebook sampling (the empirical code histogram is only collected after training)
        with torch.no_grad():
            vqvae.eval()
            dummy = torch.zeros(1, 3, 32, 32, device=device)
            ze = vqvae.enc_to_vq(vqvae.enc_body(dummy))
            _, C, H, W = ze.shape

            # Draw every latent position's code index uniformly at random
            idx = torch.randint(
                0, vqvae.quantizer.num_embeddings, (min(36, NUM_SAMPLES), H, W), device=device
            )
            emb = vqvae.quantizer.embedding.to(device)  # [E, C]
            emb_t = emb[idx].permute(0,3,1,2)           # n, C, H, W
            xr = vqvae.dec(emb_t).clamp(0,1)

        save_image(xr, VQDIR / f"vq_samples_e{epoch}.png", nrow=6)

    # ------------------------
    # Save checkpoint
    # ------------------------
    if epoch % SAVE_EVERY == 0 or epoch == EPOCHS:
        torch.save(vqvae.state_dict(), VQDIR / f"vqvae_epoch{epoch}.pt")
        logging.info(f"üíæ VQ-VAE checkpoint saved at epoch {epoch}")

# ------------------------
# Final save
# ------------------------
torch.save(vqvae.state_dict(), VQDIR / "vqvae_final.pt")
logging.info("✅ VQ-VAE training complete")

# Collecting VQ-VAE Code Histogram

This cell computes and visualizes the **empirical distribution of VQ-VAE codebook assignments** across the training dataset. Knowing how often each code is actually used lets us sample latent codes in proportion to that usage, rather than uniformly, when generating images from the discrete latent space.

---

## 1. Motivation

- The **VQ-VAE latent space** consists of a finite, learned set of discrete embeddings (codebook vectors).  
- During training, each latent feature is assigned to its **nearest codebook vector**.
- Some embeddings may be underrepresented or never used, leading to **poor sampling quality** if codebook indices are sampled uniformly.  
- By collecting a **histogram of code assignments**, we can:
  1. Identify rarely-used codes.
  2. Compute **empirical probabilities** for more realistic latent sampling.
  3. Improve the fidelity of generated images.
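`collect_vq_code_histogram` is defined in an earlier cell and is not shown here. Assuming it follows the steps listed above, an equivalent NumPy sketch is shown below. `code_histogram` is a hypothetical name, and it takes batches of encoder outputs directly rather than a `DataLoader`.

```python
import numpy as np

def code_histogram(z_e_batches, codebook):
    """Empirical probability of each codebook index (hypothetical helper).

    z_e_batches: iterable of encoder outputs, each shaped (B, C, H, W)
    codebook:    array of shape (K, C)
    """
    K, C = codebook.shape
    counts = np.zeros(K, dtype=np.int64)
    for z_e in z_e_batches:
        # Channels last, then one row per spatial position: (B*H*W, C)
        flat = z_e.transpose(0, 2, 3, 1).reshape(-1, C)
        # Nearest codebook vector for each row (squared Euclidean distance)
        d = ((flat[:, None, :] - codebook[None, :, :]) ** 2).sum(axis=-1)
        # Count how often each code index wins; minlength keeps unused codes at 0
        counts += np.bincount(d.argmin(axis=1), minlength=K)
    return counts / counts.sum()
```

Broadcasting builds an `(N, K, C)` intermediate, which is fine for illustration. For large batches, expanding the distance as `||z||^2 - 2 z·e + ||e||^2` avoids that memory cost.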

---

## 2. Code Histogram Collection

```python
print("Collecting VQ code histogram from training set (this helps VQ-VAE sampling quality)...")
code_probs = collect_vq_code_histogram(vqvae, train_loader)  # numpy array
```


In [None]:
# Cell 10: Collect VQ code histogram (for better sampling)
print("Collecting VQ code histogram from training set (this helps VQ-VAE sampling quality)...")
code_probs = collect_vq_code_histogram(vqvae, train_loader)  # numpy array
# Save histogram plot
plt.figure(figsize=(6, 3))
plt.bar(np.arange(len(code_probs)), code_probs, width=1.0)
plt.title("VQ code empirical distribution (first 200 shown)")
plt.xlabel("codebook index")
plt.ylabel("probability")
plt.xlim(0, min(200, len(code_probs)))
plt.tight_layout()
plt.savefig(OUTDIR / "vq_code_histogram.png", dpi=150)
plt.show()
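
The histogram can be summarized in two numbers:

- The codebook perplexity `exp(-Σ p log p)`, which ranges from 1 (a single code used) to K (all K codes used equally).
- The count of codes that were never assigned.

A minimal sketch over the `code_probs` array computed above follows. `codebook_usage` and `dead_threshold` are hypothetical names, not part of the notebook.

```python
import numpy as np

def codebook_usage(code_probs, dead_threshold=0.0):
    """Return (perplexity, number of unused codes) for a code distribution."""
    p = np.asarray(code_probs, dtype=np.float64)
    nz = p[p > 0]  # 0 * log 0 is treated as 0, so drop zero-probability codes
    perplexity = float(np.exp(-(nz * np.log(nz)).sum()))
    dead = int((p <= dead_threshold).sum())  # codes at or below the threshold
    return perplexity, dead
```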


# Final Sampling and Comparison: VAE vs VQ-VAE

This cell generates **final samples** from both the trained **VAE** and **VQ-VAE** models and visualizes them side-by-side for qualitative comparison.
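`sample_vae` and `sample_vqvae_empirical` are defined in earlier cells and are not shown here. Assuming they follow the usual recipe, the latent-sampling step of each can be sketched in NumPy as below, with the decoder omitted. `sample_latents` and its parameters are hypothetical.

The embedding lookup mirrors the `emb[idx].permute(0,3,1,2)` line in the training loop. `grid_hw` stands in for the `(H, W)` that the notebook reads from a dummy encoder pass.

```python
import numpy as np

def sample_latents(n, z_dim, code_probs, codebook, grid_hw, seed=0):
    """Draw VAE and VQ-VAE latents for n samples (hypothetical helper)."""
    rng = np.random.default_rng(seed)
    # VAE: sample from the standard-normal prior N(0, I); decode(z) would follow
    z = rng.standard_normal((n, z_dim))
    # VQ-VAE: draw each position's code index independently from the histogram
    H, W = grid_hw
    p = np.asarray(code_probs, dtype=np.float64)
    idx = rng.choice(len(p), size=(n, H, W), p=p / p.sum())
    # Look up embeddings and move channels first: (n, H, W, C) -> (n, C, H, W)
    z_q = codebook[idx].transpose(0, 3, 1, 2)
    return z, idx, z_q
```

Drawing each position independently from the marginal histogram ignores spatial correlations between codes. This is why the original VQ-VAE work fits an autoregressive prior (PixelCNN) over the code grid instead. Samples drawn this way are therefore expected to look less coherent than reconstructions.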

---

## 1. Sample Generation

```python
n = NUM_SAMPLES
vae_samples = sample_vae(vae, n=n, z_dim=VAE_ZDIM)
vq_samples  = sample_vqvae_empirical(vqvae, n=n, code_probs=code_probs)
```


In [None]:
n = NUM_SAMPLES
print(f"Generating {n} samples from each model...")

vae_samples = sample_vae(vae, n=n, z_dim=VAE_ZDIM)
vq_samples  = sample_vqvae_empirical(vqvae, n=n, code_probs=code_probs)

# Save grids separately
save_grid(vae_samples, OUTDIR / "vae_samples_final.png", nrow=min(10, n))
save_grid(vq_samples,  OUTDIR / "vq_samples_final.png",  nrow=min(10, n))

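# Combine both in one figure: VAE samples fill the top rows, VQ-VAE samples the bottom rows
# (the split lands exactly on a row boundary only when n is a multiple of nrow)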
combined = torch.cat([vae_samples, vq_samples], dim=0)
show_grid(combined, nrow=min(10,n), title=f"Top: VAE ({n} samples). Bottom: VQ-VAE ({n} samples)", figsize=(16,6),
          savepath=OUTDIR / "comparison_final.png")

print("Saved final sample grids to", OUTDIR)

In [None]:
# Cell 14: Summary printout and where outputs are saved
print("All done. Outputs saved in:", OUTDIR)
print("- Model checkpoints: vae_final.pt, vq_vae/vqvae_final.pt (plus per-epoch checkpoints)")
print("- Reconstructions: vae_recon_*.png, vq_vae/vq_recon_*.png")
print("- Sample grids: vae_samples_final.png, vq_samples_final.png, comparison_final.png")
print("- Training loss plot: train_loss_curves.png")
print("- VQ code histogram: vq_code_histogram.png")
print("- Interpolation/morph visuals: vae_interpolation.png, vq_morph.png")