# Vector Quantized VAE (VQ-VAE)

## Core Idea

VQ-VAE replaces the continuous latent space of a standard VAE with a discrete codebook of learned embeddings.
Each encoder output vector is quantized to its nearest codebook vector. This gives discrete latent representations
that are well suited to autoregressive priors and avoid posterior collapse.

## Mathematical Foundation

### Architecture

$$x \xrightarrow{\text{Encoder}} z_e \xrightarrow{\text{Quantize}} z_q \xrightarrow{\text{Decoder}} \hat{x}$$
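To see what each arrow does to the tensor shape, we can trace the spatial size through the encoder and decoder used later in this notebook, assuming 28×28 MNIST input. The helpers `conv_out` and `conv_transpose_out` are illustrative names. They implement the standard output-size formulas for `nn.Conv2d` and `nn.ConvTranspose2d` with dilation 1 and no output padding.

```python
def conv_out(size: int, kernel: int, stride: int, padding: int) -> int:
    """Spatial output size of a Conv2d layer (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


def conv_transpose_out(size: int, kernel: int, stride: int, padding: int) -> int:
    """Spatial output size of a ConvTranspose2d layer (dilation 1, output_padding 0)."""
    return (size - 1) * stride - 2 * padding + kernel


# (kernel, stride, padding) values copied from the notebook's Encoder/Decoder
encoder_layers = [(4, 2, 1), (4, 2, 1), (3, 1, 1)]
decoder_conv = (3, 1, 1)
decoder_transposed = [(4, 2, 1), (4, 2, 1)]

h = 28  # MNIST height (= width)
for k, s, p in encoder_layers:
    h = conv_out(h, k, s, p)
latent_hw = h  # spatial size of z_e, the index grid, and z_q

h = conv_out(h, *decoder_conv)
for k, s, p in decoder_transposed:
    h = conv_transpose_out(h, k, s, p)
recon_hw = h

codes_per_image = latent_hw * latent_hw
print(f"x: 28x28 -> z_e/z_q: {latent_hw}x{latent_hw} -> x_hat: {recon_hw}x{recon_hw}")
print(f"discrete codes per image: {codes_per_image}")
```

With these settings, each image is summarized by a 7×7 grid of codebook indices, which is 49 discrete symbols.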

### Codebook and Quantization

Codebook: $\mathcal{E} = \{e_k\}_{k=1}^K$ where $e_k \in \mathbb{R}^D$

Quantization (nearest-neighbor lookup, applied independently to each $D$-dimensional vector of $z_e$, i.e. at every spatial position):
$$z_q = e_k \quad \text{where} \quad k = \arg\min_j \|z_e - e_j\|_2$$
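The lookup can be sketched in NumPy using the expansion $\|z_e - e_j\|_2^2 = \|z_e\|^2 - 2 z_e^\top e_j + \|e_j\|^2$, which the `VectorQuantizer` code also uses. Squaring does not change the argmin, so minimizing squared distances selects the same code. The small codebook and random inputs are illustrative. `quantize` is a hypothetical helper and is not part of the notebook.

```python
from typing import Tuple

import numpy as np


def quantize(z_e: np.ndarray, codebook: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """Return (indices, z_q) for each row of z_e [N, D] given a codebook [K, D]."""
    # Squared distances ||z||^2 - 2 z.e + ||e||^2, shape [N, K]
    dist = (
        (z_e ** 2).sum(axis=1, keepdims=True)
        - 2 * z_e @ codebook.T
        + (codebook ** 2).sum(axis=1)
    )
    indices = dist.argmin(axis=1)       # k = argmin_j ||z_e - e_j||
    return indices, codebook[indices]   # z_q = e_k


rng = np.random.default_rng(0)
K, D = 8, 4
codebook = rng.normal(size=(K, D))   # rows play the role of e_1..e_K
z_e = rng.normal(size=(5, D))        # 5 encoder vectors, e.g. 5 spatial positions

indices, z_q = quantize(z_e, codebook)

# Brute-force check with explicit (non-squared) Euclidean norms
brute = np.array([np.argmin(np.linalg.norm(codebook - z, axis=1)) for z in z_e])
print(indices, (indices == brute).all())
```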

### Loss Function

$$\mathcal{L} = \underbrace{\|x - \hat{x}\|_2^2}_{\text{Reconstruction}} + \underbrace{\|\text{sg}[z_e] - e\|_2^2}_{\text{Codebook}} + \underbrace{\beta\|z_e - \text{sg}[e]\|_2^2}_{\text{Commitment}}$$

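where $\text{sg}[\cdot]$ is the stop-gradient operator (identity in the forward pass, zero gradient in the backward pass), $e$ is the codebook vector selected for $z_e$ (so $e = z_q$), and $\beta$ is the commitment cost.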
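Because $\text{sg}[\cdot]$ only changes where gradients flow, the codebook and commitment terms have the same value and differ only in which tensor they update. The sketch below assumes stop-gradient is implemented with `Tensor.detach()`, as in the notebook code. It uses `F.mse_loss`, which averages rather than sums, so each gradient carries a $1/N$ factor. Random tensors stand in for real encoder outputs and selected codes.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
beta = 0.25                                   # commitment_cost in VQVAEConfig
z_e = torch.randn(3, 4, requires_grad=True)   # stand-in encoder outputs
e = torch.randn(3, 4, requires_grad=True)     # stand-in selected codebook vectors

# sg[.] -> .detach()
codebook_term = F.mse_loss(e, z_e.detach())     # ||sg[z_e] - e||^2: updates e only
commitment_term = F.mse_loss(z_e, e.detach())   # ||z_e - sg[e]||^2: updates z_e only

codebook_term.backward()
z_e_grad_from_codebook = z_e.grad   # stays None because z_e was detached
e_grad = e.grad.clone()

(beta * commitment_term).backward()
e_grad_after_commitment = e.grad    # unchanged because e was detached
z_e_grad = z_e.grad.clone()

n = z_e.numel()
print(torch.allclose(e_grad, 2 * (e - z_e).detach() / n))
print(torch.allclose(z_e_grad, 2 * beta * (z_e - e).detach() / n))
```

If the two `detach()` calls are swapped, $\beta$ scales the codebook update instead of the commitment term.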

**Straight-Through Estimator:** The argmin in the quantizer has no useful gradient, so the gradient with respect to $z_q$ is copied unchanged to $z_e$:
$$\frac{\partial \mathcal{L}}{\partial z_e} \approx \frac{\partial \mathcal{L}}{\partial z_q}$$
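The expression `z_q = z_e + (z_q - z_e).detach()` in the notebook's `VectorQuantizer` produces this behavior. The forward value is the quantized vector, and the backward pass treats quantization as the identity. The sketch below checks this in PyTorch on random data. The squared-error loss against a random `target` is a stand-in for the decoder and reconstruction loss.

```python
import torch

torch.manual_seed(0)
codebook = torch.randn(6, 2)                  # K=6 codes, D=2 (illustrative)
z_e = torch.randn(4, 2, requires_grad=True)   # 4 encoder vectors
target = torch.randn(4, 2)                    # stand-in for what the decoder needs

idx = torch.cdist(z_e.detach(), codebook).argmin(dim=1)  # nearest code per vector
z_q_hard = codebook[idx]                  # plain lookup: no gradient path to z_e
z_q = z_e + (z_q_hard - z_e).detach()     # same value as z_q_hard, identity gradient

loss = ((z_q - target) ** 2).sum()        # stand-in for decoder + reconstruction loss
loss.backward()

# dL/dz_e equals dL/dz_q = 2 (z_q - target)
print(torch.allclose(z_e.grad, 2 * (z_q_hard - target)))
```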

## Problem Statement

VAE limitations:
- Posterior collapse with powerful decoders: the decoder learns to ignore the latent code
- Reconstructions are often blurry
- Continuous latents are a poor match for autoregressive priors such as PixelCNN, which predict discrete symbols

VQ-VAE addresses these with discrete latents and deterministic encoding.
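Deterministic encoding also explains why the loss above has no KL term. In the original VQ-VAE formulation, the posterior $q(z = k \mid x)$ is one-hot on the selected code and the prior used during training is uniform over the $K$ codes. This gives $\mathrm{KL}(q \,\|\, p) = \log K$ for every input. Because the term is constant, the encoder cannot reduce it by ignoring the input, so it is dropped from the objective. The helper below is a hypothetical illustration of that arithmetic.

```python
import math


def kl_one_hot_vs_uniform(k_selected: int, K: int) -> float:
    """KL(q || p) for a one-hot posterior q on code k_selected and a uniform prior p."""
    q = [1.0 if j == k_selected else 0.0 for j in range(K)]
    p = 1.0 / K
    # Terms with q_j = 0 contribute nothing (convention 0 * log 0 = 0)
    return sum(qj * math.log(qj / p) for qj in q if qj > 0)


K = 512  # num_embeddings in the notebook's default config
print(kl_one_hot_vs_uniform(3, K), math.log(K))
```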

## Algorithm Comparison

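| Aspect | VAE | VQ-VAE |
|--------|-----|--------|
| Latent | Continuous | Discrete (codebook indices) |
| Encoding | Stochastic (reparameterization) | Deterministic (nearest neighbor) |
| Prior | Gaussian $\mathcal{N}(0, I)$ | Uniform during training; autoregressive (e.g. PixelCNN) fit afterwards |
| Posterior collapse | Common with powerful decoders | Avoided |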
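Two toy encoders illustrate how the models differ. A VAE draws a fresh sample through the reparameterization $z = \mu + \sigma \epsilon$. VQ-VAE maps the same encoder output to the same integer index every time. Both functions and all values here are made up for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)


def vae_encode(mu: np.ndarray, log_var: np.ndarray, rng: np.random.Generator) -> np.ndarray:
    # Reparameterization: z = mu + sigma * eps, eps ~ N(0, I)
    eps = rng.standard_normal(mu.shape)
    return mu + np.exp(0.5 * log_var) * eps


def vq_encode(z_e: np.ndarray, codebook: np.ndarray) -> int:
    # Nearest-neighbor lookup: an integer index, no noise involved
    return int(np.argmin(((codebook - z_e) ** 2).sum(axis=1)))


mu, log_var = np.zeros(2), np.zeros(2)
codebook = np.array([[0.0, 0.0], [1.0, 1.0], [-1.0, 2.0]])
z_e = np.array([0.9, 1.2])

vae_samples = [vae_encode(mu, log_var, rng) for _ in range(2)]  # differ between calls
vq_codes = [vq_encode(z_e, codebook) for _ in range(2)]         # identical between calls
print(vae_samples, vq_codes)
```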

## Complexity Analysis

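- **Quantization:** $O(B \times H \times W \times K \times D)$ time per batch for the codebook lookup, where $B$ is the batch size and $H \times W$ the latent grid; the distance matrix holds $B \times H \times W \times K$ entries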
- **Codebook:** $O(K \times D)$ parameters
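Plugging in the defaults from `VQVAEConfig` (batch size B = 128, K = 512, D = 64) and the 7×7 latent grid the encoder produces from 28×28 MNIST images gives a sense of scale for these costs. This is back-of-the-envelope arithmetic. It assumes float32 distances and counts only the matrix product in the distance computation.

```python
B, H, W = 128, 7, 7   # batch size and latent grid for 28x28 inputs
K, D = 512, 64        # codebook size and embedding dimension

num_vectors = B * H * W                   # encoder vectors quantized per batch
distance_entries = num_vectors * K        # (B*H*W) x K distance matrix
distance_bytes_fp32 = distance_entries * 4
lookup_macs = num_vectors * K * D         # multiply-adds in the z_e @ E^T product
codebook_params = K * D

print(f"vectors={num_vectors:,} distances={distance_entries:,} "
      f"({distance_bytes_fp32 / 1e6:.1f} MB fp32) MACs={lookup_macs:,} "
      f"codebook params={codebook_params:,}")
```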

In [None]:
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Dict, List, Tuple

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

In [None]:
@dataclass
class VQVAEConfig:
    """Configuration for VQ-VAE.
    
    Core Idea:
        num_embeddings (K) and embedding_dim (D) define codebook capacity.
        Larger K gives more expressiveness, but more codes tend to go unused (dead codes).
    """
    in_channels: int = 1
    hidden_dim: int = 128
    num_embeddings: int = 512
    embedding_dim: int = 64
    commitment_cost: float = 0.25
    
    lr: float = 1e-3
    batch_size: int = 128
    num_epochs: int = 20
    
    device: str = field(default_factory=lambda: "cuda" if torch.cuda.is_available() else "cpu")
    seed: int = 42

In [None]:
class VectorQuantizer(nn.Module):
    """Vector quantization layer trained with codebook and commitment losses (no EMA update).
    
    Core Idea:
        Maps continuous encoder output to nearest discrete codebook vector.
        Uses straight-through estimator for gradient flow.
    
    Mathematical Theory:
        Quantization: $z_q = e_{\arg\min_k \|z_e - e_k\|}$
        Straight-through: $\nabla_{z_e} \mathcal{L} = \nabla_{z_q} \mathcal{L}$
    """
    
    def __init__(self, num_embeddings: int, embedding_dim: int, commitment_cost: float = 0.25) -> None:
        super().__init__()
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.commitment_cost = commitment_cost
        
        self.embedding = nn.Embedding(num_embeddings, embedding_dim)
        self.embedding.weight.data.uniform_(-1/num_embeddings, 1/num_embeddings)
    
    def forward(self, z_e: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
        # z_e: [B, D, H, W] -> [B, H, W, D]
        z_e = z_e.permute(0, 2, 3, 1).contiguous()
        z_e_flat = z_e.view(-1, self.embedding_dim)
        
        # Compute distances to codebook
        distances = (z_e_flat.pow(2).sum(1, keepdim=True)
                    - 2 * z_e_flat @ self.embedding.weight.t()
                    + self.embedding.weight.pow(2).sum(1))
        
        # Nearest neighbor lookup
        indices = distances.argmin(dim=1)
        z_q = self.embedding(indices).view(z_e.shape)
        
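        # Losses (sg[.] is implemented with .detach())
        # Codebook loss ||sg[z_e] - e||^2: pulls codebook vectors toward encoder outputs
        codebook_loss = F.mse_loss(z_q, z_e.detach())
        # Commitment loss ||z_e - sg[e]||^2: keeps encoder outputs close to their codes
        commitment_loss = F.mse_loss(z_e, z_q.detach())
        vq_loss = codebook_loss + self.commitment_cost * commitment_loss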
        
        # Straight-through estimator
        z_q = z_e + (z_q - z_e).detach()
        
        # [B, H, W, D] -> [B, D, H, W]
        z_q = z_q.permute(0, 3, 1, 2).contiguous()
        
        return z_q, vq_loss, indices.view(z_e.shape[:-1])

In [None]:
class Encoder(nn.Module):
    """Convolutional encoder for VQ-VAE."""
    
    def __init__(self, config: VQVAEConfig) -> None:
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(config.in_channels, config.hidden_dim // 2, 4, 2, 1),
            nn.ReLU(True),
            nn.Conv2d(config.hidden_dim // 2, config.hidden_dim, 4, 2, 1),
            nn.ReLU(True),
            nn.Conv2d(config.hidden_dim, config.embedding_dim, 3, 1, 1),
        )
    
    def forward(self, x: Tensor) -> Tensor:
        return self.net(x)


class Decoder(nn.Module):
    """Convolutional decoder for VQ-VAE."""
    
    def __init__(self, config: VQVAEConfig) -> None:
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(config.embedding_dim, config.hidden_dim, 3, 1, 1),
            nn.ReLU(True),
            nn.ConvTranspose2d(config.hidden_dim, config.hidden_dim // 2, 4, 2, 1),
            nn.ReLU(True),
            nn.ConvTranspose2d(config.hidden_dim // 2, config.in_channels, 4, 2, 1),
            nn.Sigmoid(),
        )
    
    def forward(self, z_q: Tensor) -> Tensor:
        return self.net(z_q)

In [None]:
class VQVAE(nn.Module):
    """Complete VQ-VAE model."""
    
    def __init__(self, config: VQVAEConfig) -> None:
        super().__init__()
        self.encoder = Encoder(config)
        self.vq = VectorQuantizer(config.num_embeddings, config.embedding_dim, config.commitment_cost)
        self.decoder = Decoder(config)
    
    def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
        z_e = self.encoder(x)
        z_q, vq_loss, indices = self.vq(z_e)
        x_recon = self.decoder(z_q)
        return x_recon, vq_loss, indices

In [None]:
class VQVAETrainer:
    """Training orchestrator for VQ-VAE."""
    
    def __init__(self, config: VQVAEConfig) -> None:
        self.config = config
        self.device = torch.device(config.device)
        torch.manual_seed(config.seed)
        
        self.model = VQVAE(config).to(self.device)
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=config.lr)
        self.history: Dict[str, List[float]] = {"recon_loss": [], "vq_loss": [], "total_loss": []}
    
    def train_epoch(self, dataloader: DataLoader) -> Dict[str, float]:
        self.model.train()
        metrics = {k: 0.0 for k in self.history.keys()}
        
        for x, _ in dataloader:
            x = x.to(self.device)
            x_recon, vq_loss, _ = self.model(x)
            
            recon_loss = F.mse_loss(x_recon, x)
            loss = recon_loss + vq_loss
            
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            
            metrics["recon_loss"] += recon_loss.item()
            metrics["vq_loss"] += vq_loss.item()
            metrics["total_loss"] += loss.item()
        
        for k in metrics:
            metrics[k] /= len(dataloader)
            self.history[k].append(metrics[k])
        return metrics
    
    @torch.no_grad()
    def reconstruct(self, x: Tensor) -> Tuple[Tensor, Tensor]:
        self.model.eval()
        x_recon, _, indices = self.model(x.to(self.device))
        return x_recon.cpu(), indices.cpu()

In [None]:
def create_dataloader(config: VQVAEConfig) -> DataLoader:
    transform = transforms.ToTensor()
    dataset = datasets.MNIST(root="./data", train=True, download=True, transform=transform)
    return DataLoader(dataset, batch_size=config.batch_size, shuffle=True, drop_last=True, num_workers=2)

In [None]:
def visualize_reconstruction(trainer: VQVAETrainer, dataset, n: int = 8) -> None:
    x = torch.stack([dataset[i][0] for i in range(n)])
    x_recon, _ = trainer.reconstruct(x)
    
    fig, axes = plt.subplots(2, n, figsize=(n * 1.5, 3))
    for i in range(n):
        axes[0, i].imshow(x[i].squeeze(), cmap="gray")
        axes[1, i].imshow(x_recon[i].squeeze(), cmap="gray")
    # Hide ticks but keep the axes; axis("off") would also hide the row labels below
    for ax in axes.flat:
        ax.set_xticks([])
        ax.set_yticks([])
    axes[0, 0].set_ylabel("Original")
    axes[1, 0].set_ylabel("Recon")
    plt.tight_layout()
    plt.show()

In [None]:
def visualize_codebook_usage(trainer: VQVAETrainer, dataloader: DataLoader) -> None:
    trainer.model.eval()
    all_indices = []
    
    with torch.no_grad():
        for x, _ in dataloader:
            _, _, indices = trainer.model(x.to(trainer.device))
            all_indices.append(indices.cpu().flatten())
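            if len(all_indices) >= 10:  # sample 10 batches to estimate usage
                break
    
    all_indices = torch.cat(all_indices).numpy()
    # Count each code explicitly: plt.hist(bins=K) would only span the observed
    # min..max index, so its bins would not line up with codebook entries
    counts = np.bincount(all_indices, minlength=trainer.config.num_embeddings)
    
    plt.figure(figsize=(12, 4))
    plt.bar(np.arange(trainer.config.num_embeddings), counts, width=1.0, alpha=0.7)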
    plt.xlabel("Codebook Index")
    plt.ylabel("Frequency")
    plt.title(f"Codebook Usage (K={trainer.config.num_embeddings})")
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()
    
    unique = len(np.unique(all_indices))
    print(f"Active codes: {unique}/{trainer.config.num_embeddings} ({100*unique/trainer.config.num_embeddings:.1f}%)")

In [None]:
if __name__ == "__main__":
    config = VQVAEConfig(num_epochs=20)
    dataloader = create_dataloader(config)
    trainer = VQVAETrainer(config)
    
    print(f"Parameters: {sum(p.numel() for p in trainer.model.parameters()):,}")
    print(f"Codebook: {config.num_embeddings} x {config.embedding_dim}")
    
    for epoch in range(config.num_epochs):
        metrics = trainer.train_epoch(dataloader)
        if (epoch + 1) % 5 == 0:
            print(f"Epoch [{epoch+1}/{config.num_epochs}] Recon: {metrics['recon_loss']:.4f} VQ: {metrics['vq_loss']:.4f}")
    
    visualize_reconstruction(trainer, dataloader.dataset)
    visualize_codebook_usage(trainer, dataloader)

## Summary

VQ-VAE uses discrete latent codes via vector quantization:

1. **Codebook:** Learned dictionary of embedding vectors
2. **Quantization:** Nearest neighbor lookup (non-differentiable)
3. **Straight-through:** Copy the gradient with respect to $z_q$ directly to $z_e$, bypassing the non-differentiable lookup
4. **Losses:** Reconstruction + codebook + commitment

**Advantages over VAE:**
- Avoids posterior collapse (deterministic encoder, constant KL term)
- Discrete latents can be modeled with autoregressive priors such as PixelCNN
- Often sharper reconstructions