# VAEs & GANs

Introduction to image generation with VAEs and GANs

| Date | User | Change Type | Remarks |  
| ---- | ---- | ----------- | ------- |
| 24/02/2026   | Martin | Created   | VAE model with celeb faces data | 

# Content

* [Variational Autoencoders](#variational-autoencoder-vae)
* [Generative Adversarial Networks](#generative-adversarial-networks-gans)

# Variational Autoencoder (VAE)

Dataset: [Celeb Faces](https://www.kaggle.com/datasets/vishesh1412/celebrity-face-image-dataset)

<u>SOTA Architectures</u>

- __VQ-VAE / VQ-VAE-2__ - Replaces the continuous latent space with a discrete codebook: each encoder output vector is quantised to its nearest embedding vector in the codebook.
- __Hierarchical VAEs__ - Stack multiple layers of latent variables. The top latent captures global structure, while lower latents capture finer details.
- __⭐ Latent Diffusion Models__ - A VAE compresses the image into a latent space, and a diffusion model then operates in that latent space (the approach used by Stable Diffusion). The VAE is typically a CNN with residual blocks and attention.
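The VQ-VAE quantisation step can be sketched as a nearest-neighbour lookup. This is a minimal illustration, assuming an encoder output of shape `(B, D, H, W)` and a codebook of `K` vectors of size `D`; `vector_quantize` is a hypothetical helper, and the codebook/commitment losses used to train VQ-VAE are omitted.

```python
import torch

def vector_quantize(z_e: torch.Tensor, codebook: torch.Tensor):
  """
  Hypothetical helper: nearest-neighbour codebook lookup (VQ-VAE style sketch).
  z_e:      encoder output, shape (B, D, H, W)
  codebook: K embedding vectors, shape (K, D)
  Returns the quantised tensor (same shape as z_e) and code indices (B, H, W).
  """
  B, D, H, W = z_e.shape
  flat = z_e.permute(0, 2, 3, 1).reshape(-1, D)      # (B*H*W, D), one vector per position
  dists = torch.cdist(flat, codebook)                # (B*H*W, K) Euclidean distances
  idx = dists.argmin(dim=1)                          # index of the nearest code per position
  z_q = codebook[idx].reshape(B, H, W, D).permute(0, 3, 1, 2)
  # Straight-through estimator: the forward pass uses z_q,
  # the backward pass copies gradients directly to z_e
  z_q_st = z_e + (z_q - z_e).detach()
  return z_q_st, idx.reshape(B, H, W)
```

The straight-through trick is needed because `argmin` has no useful gradient; without it the encoder would receive no learning signal through the quantiser.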

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torchvision.utils import save_image
from PIL import Image
import os
import math
import numpy as np
from pathlib import Path

In [3]:
# ==================== Components ====================
class ResBlock(nn.Module):
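  """
  Residual block with GroupNorm. GroupNorm is preferred over BatchNorm here because
  it does not depend on batch statistics, which are noisy at the small batch sizes
  typical of high-resolution VAE training.
  num_groups=32 is standard; the channel count must be divisible by num_groups,
  so reduce num_groups for narrow layers.
  """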
  def __init__(self, in_channels: int, out_channels: int, dropout: float=0.0):
    super().__init__()
    self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6)
    self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
    self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6)
    self.dropout = nn.Dropout(dropout)
    self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
    # Residual connection
    self.skip = nn.Conv2d(in_channels, out_channels, kernel_size=1) if in_channels != out_channels else nn.Identity()

  def forward(self, x):
    h = self.conv1(F.silu(self.norm1(x)))
    h = self.conv2(self.dropout(F.silu(self.norm2(h))))
    return h + self.skip(x)


class SelfAttention(nn.Module):
  """
  Single-head self-attention over spatial positions. The attention matrix is
  (HW x HW), so it is applied only at low-resolution feature maps (e.g., 8x8, 16x16)
  to keep compute tractable. Multi-head attention may work better but is more
  expensive to run in ablations.
  """
  def __init__(self, channels: int):
    super().__init__()
    self.norm = nn.GroupNorm(32, channels, eps=1e-6)
    self.q = nn.Conv2d(channels, channels, 1)
    self.k = nn.Conv2d(channels, channels, 1)
    self.v = nn.Conv2d(channels, channels, 1)
    self.proj = nn.Conv2d(channels, channels, 1)
    self.scale = channels ** -0.5
  
  def forward(self, x):
    B, C, H, W = x.shape
    h = self.norm(x)
    q = self.q(h).reshape(B, C, -1) # (B, C, HW)
    k = self.k(h).reshape(B, C, -1)
    v = self.v(h).reshape(B, C, -1)
    attn = torch.softmax(torch.bmm(q.transpose(1, 2), k) * self.scale, dim=-1) # (B, HW, HW)
    out = torch.bmm(v, attn.transpose(1, 2)).reshape(B, C, H, W)

    return x + self.proj(out)


class DownSample(nn.Module):
  """Strided conv instead of pooling, so the downsampling itself is learned."""
  def __init__(self, channels):
    super().__init__()
    self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, stride=2, padding=1)
  
  def forward(self, x):
    return self.conv1(x)


class Upsample(nn.Module):
  """Nearest-neighbor + conv avoids checkerboard artifacts from transposed conv."""
  def __init__(self, channels):
    super().__init__()
    self.conv = nn.Conv2d(channels, channels, 3, padding=1)

  def forward(self, x):
    return self.conv(F.interpolate(x, scale_factor=2, mode='nearest'))
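The attention math in `SelfAttention.forward` is standard single-head scaled dot-product attention with the `HW` positions as the sequence. The sketch below checks that equivalence against `F.scaled_dot_product_attention` (available in PyTorch 2.0+, scaling by `1/sqrt(E)` by default). The two helper functions are hypothetical and operate on already-projected `q`, `k`, `v` of shape `(B, C, HW)`.

```python
import torch
import torch.nn.functional as F

def spatial_attention_manual(q, k, v):
  """Same computation as SelfAttention.forward, on (B, C, HW) tensors."""
  C = q.shape[1]
  attn = torch.softmax(torch.bmm(q.transpose(1, 2), k) * C ** -0.5, dim=-1)  # (B, HW, HW)
  return torch.bmm(v, attn.transpose(1, 2))                                    # (B, C, HW)

def spatial_attention_sdpa(q, k, v):
  """Equivalent call via F.scaled_dot_product_attention, which expects (B, L, E)."""
  out = F.scaled_dot_product_attention(
    q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
  )                                  # (B, HW, C)
  return out.transpose(1, 2)         # back to (B, C, HW)
```

The `(HW x HW)` score matrix is why resolution matters: at 32x32 it already holds 1024 x 1024 ≈ 1M entries per image. The SDPA call may dispatch to a fused, more memory-efficient kernel depending on device and PyTorch version.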

In [4]:
# ==================== Encoder ====================
class Encoder(nn.Module):
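  """
  - channel_multipliers: controls width at each resolution level.
    (1, 2, 4, 8) means 4 resolution levels with 3 downsampling steps
    (8x spatial compression). More levels = more compression.
  - base_channels: 64 for lightweight, 128 for quality, 256 for SOTA.
  - latent_dim: 4 is standard for image VAEs (LDM uses 4).
    Higher = more expressivity but harder KL regularization.
  - attn_resolutions: which spatial sizes get attention. Attention is global at any
    resolution; lower resolutions keep the (HW x HW) attention matrix affordable.
    Note: with image_size=256 the levels run at 256/128/64/32, so the default (16,)
    adds no level attention; only the mid-block attention (at 32x32) is active.
  """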
  def __init__(
    self,
    in_channels: int=3,
    base_channels: int=128,
    channel_multipliers: tuple=(1, 2, 4, 8),
    latent_dim: int=4,
    num_res_blocks: int=2,
    attn_resolutions: tuple=(16,),
    dropout: float=0.0,
    image_size: int=256
  ):
    super().__init__()
    self.conv_in = nn.Conv2d(in_channels, base_channels, 3, padding=1)

    channels = [base_channels * m for m in channel_multipliers]
    current_res = image_size
    in_ch = base_channels

    self.down_blocks = nn.ModuleList()
    for i, out_ch in enumerate(channels):
      block = nn.ModuleList()
      for _ in range(num_res_blocks):
        block.append(ResBlock(in_ch, out_ch, dropout))
        if current_res in attn_resolutions:
          block.append(SelfAttention(out_ch))
        in_ch = out_ch
      
      # Downsample except at last level
      if i < len(channels) - 1:
        block.append(DownSample(in_ch))
        current_res //= 2
      self.down_blocks.append(block)
    
    self.mid_block1 = ResBlock(in_ch, in_ch, dropout)
    self.mid_attn = SelfAttention(in_ch)
    self.mid_block2 = ResBlock(in_ch, in_ch, dropout)

    self.norm_out = nn.GroupNorm(32, in_ch, eps=1e-6)
    self.conv_out = nn.Conv2d(in_ch, 2 * latent_dim, 3, padding=1)

  def forward(self, x):
    h = self.conv_in(x)
    for block in self.down_blocks:
      for layer in block:
        h = layer(h)
    h = self.mid_block2(self.mid_attn(self.mid_block1(h)))
    h = self.conv_out(F.silu(self.norm_out(h)))
    # Split into mean and log variance
    mu, log_var = h.chunk(2, dim=1)
    return mu, log_var
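To see which levels actually receive attention, the resolution bookkeeping in `Encoder.__init__` can be replayed in plain Python. `encoder_levels` is a hypothetical helper that mirrors that loop (attention is checked before each level's downsample, as in the class).

```python
def encoder_levels(image_size: int, channel_multipliers: tuple, base_channels: int,
                   attn_resolutions: tuple):
  """Hypothetical mirror of the Encoder loop: (resolution, channels, has_attention) per level."""
  levels = []
  res = image_size
  for i, m in enumerate(channel_multipliers):
    levels.append((res, base_channels * m, res in attn_resolutions))
    if i < len(channel_multipliers) - 1:   # no downsample after the last level
      res //= 2
  return levels

print(encoder_levels(256, (1, 2, 4, 8), 64, (16,)))
```

With the notebook defaults this prints `[(256, 64, False), (128, 128, False), (64, 256, False), (32, 512, False)]`: no level runs at 16x16, so `attn_resolutions=(32,)` (or a 128x128 input) would be needed for level attention to appear.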

In [5]:
# ==================== Decoder ====================
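class Decoder(nn.Module):
  """
  Mirror of the Encoder: latent -> conv_in -> mid blocks -> upsampling levels -> image.
  Each level has num_res_blocks + 1 ResBlocks; nearest-neighbour upsampling
  follows every level except the last.
  """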
  def __init__(
    self,
    out_channels: int=3,
    base_channels: int=128,
    channel_multipliers: tuple=(1, 2, 4, 8),
    latent_dim: int=4,
    num_res_blocks: int=2,
    attn_resolutions: tuple=(16,),
    dropout: float=0.0,
    image_size: int=256
  ):
    super().__init__()
    channels = [base_channels * m for m in reversed(channel_multipliers)]
    in_ch = channels[0]
    self.conv_in = nn.Conv2d(latent_dim, in_ch, 3, padding=1)

    self.mid_block1 = ResBlock(in_ch, in_ch, dropout)
    self.mid_attn   = SelfAttention(in_ch)
    self.mid_block2 = ResBlock(in_ch, in_ch, dropout)

    current_res = image_size // (2 ** (len(channel_multipliers) - 1))
    self.up_blocks = nn.ModuleList()
    for i, out_ch in enumerate(channels[1:] + [base_channels]):
      block = nn.ModuleList()
      for j in range(num_res_blocks + 1):   # +1 vs encoder is standard
        block.append(ResBlock(in_ch, out_ch, dropout))
        if current_res in attn_resolutions:
          block.append(SelfAttention(out_ch))
        in_ch = out_ch
      if i < len(channels) - 1:
        block.append(Upsample(in_ch))
        current_res *= 2
      self.up_blocks.append(block)

    self.norm_out = nn.GroupNorm(32, in_ch, eps=1e-6)
    self.conv_out = nn.Conv2d(in_ch, out_channels, 3, padding=1)
  
  def forward(self, z):
    h = self.conv_in(z)
    h = self.mid_block2(self.mid_attn(self.mid_block1(h)))
    for block in self.up_blocks:
      for layer in block:
        h = layer(h)
    return self.conv_out(F.silu(self.norm_out(h)))
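The decoder's channel and resolution schedule can be traced the same way. `decoder_levels` is a hypothetical helper replaying the `Decoder.__init__` loop, returning `(resolution, in_channels, out_channels, num_resblocks)` per up-level.

```python
def decoder_levels(image_size: int, channel_multipliers: tuple, base_channels: int,
                   num_res_blocks: int = 2):
  """Hypothetical mirror of the Decoder loop: (resolution, in_ch, out_ch, n_resblocks) per level."""
  channels = [base_channels * m for m in reversed(channel_multipliers)]
  res = image_size // (2 ** (len(channel_multipliers) - 1))   # latent resolution
  in_ch = channels[0]
  levels = []
  for i, out_ch in enumerate(channels[1:] + [base_channels]):
    levels.append((res, in_ch, out_ch, num_res_blocks + 1))
    in_ch = out_ch
    if i < len(channels) - 1:   # upsample after every level except the last
      res *= 2
  return levels

print(decoder_levels(256, (1, 2, 4, 8), 64))
```

For the notebook config this gives `[(32, 512, 256, 3), (64, 256, 128, 3), (128, 128, 64, 3), (256, 64, 64, 3)]`: the decoder returns to 256x256, but each level's output width is one step narrower than the encoder level at the same resolution.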

In [6]:
# Utility: compute the latent spatial size
def compute_latent_size(image_size: int, channel_multipliers: tuple) -> int:
  """
  Each stage (except the last) halves the spatial resolution.
  e.g., image_size=256, 4 multipliers → 3 downsamples → latent is 32x32
  """
  num_downsamples = len(channel_multipliers) - 1
  latent_size = image_size // (2 ** num_downsamples)
  assert image_size == latent_size * (2 ** num_downsamples), (
    f"image_size {image_size} is not evenly divisible by 2^{num_downsamples}. "
    f"Use a power-of-2 image size."
  )
  return latent_size
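A quick worked example of how much the VAE compresses: the number of values per image divided by the number of latent values. `compression_ratio` is a hypothetical helper that repeats the `compute_latent_size` arithmetic so the snippet stands alone.

```python
def compression_ratio(image_size: int, in_channels: int, channel_multipliers: tuple,
                      latent_dim: int) -> float:
  """Hypothetical helper: (# image values) / (# latent values)."""
  num_downsamples = len(channel_multipliers) - 1
  latent_size = image_size // (2 ** num_downsamples)
  pixels = image_size * image_size * in_channels     # 256*256*3 = 196,608
  latents = latent_size * latent_size * latent_dim   # 32*32*4   =   4,096
  return pixels / latents

print(compression_ratio(256, 3, (1, 2, 4, 8), 4))   # → 48.0
```

So the default configuration stores each 256x256 RGB image in 48x fewer numbers; raising `latent_dim` to 8 halves that ratio.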

In [7]:
# ==================== VAE ====================
class VAE(nn.Module):
  def __init__(
    self,
    in_channels: int = 3,
    latent_dim: int = 4,
    base_channels: int = 128,
    channel_multipliers: tuple = (1, 2, 4, 8),
    num_res_blocks: int = 2,
    attn_resolutions: tuple = (16,),
    dropout: float = 0.0,
    image_size: int = 256,
    # ── KEY TUNING PARAMETER ──
    # β > 1  → stronger disentanglement, worse reconstruction
    # β < 1  → better reconstruction, less structured latent space
    # β = 1  → standard VAE
    # For LDM-style (encoding for a downstream diffusion model), use β ≈ 1e-6
    # to almost ignore KL and maximize reconstruction fidelity.
    beta: float = 1.0,
  ):
    super().__init__()
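    self.beta = beta
    self.latent_dim = latent_dim
    # Stored so sample() can compute the latent spatial size at inference time
    self.image_size = image_size
    self.channel_multipliers = channel_multipliers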

    enc_dec_kwargs = dict(
      base_channels=base_channels,
      channel_multipliers=channel_multipliers,
      latent_dim=latent_dim,
      num_res_blocks=num_res_blocks,
      attn_resolutions=attn_resolutions,
      dropout=dropout,
      image_size=image_size,
    )
    self.encoder = Encoder(in_channels=in_channels, **enc_dec_kwargs)
    self.decoder = Decoder(out_channels=in_channels, **enc_dec_kwargs)
  
  def reparameterize(self, mu, log_var):
    """
    WATCH OUT: log_var should be clamped during training. Unclamped log_var
    can explode to very large or very small values, causing NaN losses.
    """
    log_var = log_var.clamp(-30.0, 20.0)
    std = torch.exp(0.5 * log_var)
    eps = torch.randn_like(std)
    return mu + eps * std
  
  def encode(self, x):
    mu, log_var = self.encoder(x)
    return mu, log_var  
  
  def decode(self, z):
    return self.decoder(z)
  
  def forward(self, x):
    mu, log_var = self.encode(x)
    z = self.reparameterize(mu, log_var)
    x_hat = self.decode(z)
    return x_hat, mu, log_var
  
  def compute_loss(self, x, x_hat, mu, log_var):
    B = x.size(0)

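    # Reconstruction: sum of squared errors over all pixels and channels,
    # averaged over the batch (per-image SSE, not a per-pixel mean)
    recon_loss = F.mse_loss(x_hat, x, reduction='sum') / B

    # KL divergence to the standard normal prior N(0, I), closed form,
    # summed over all latent elements and averaged over the batch.
    # log_var is clamped with the same bounds as reparameterize() to avoid overflow in exp().
    log_var = log_var.clamp(-30.0, 20.0)
    kl_loss = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp()) / B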

    total_loss = recon_loss + self.beta * kl_loss
    return total_loss, recon_loss, kl_loss
  
  @torch.no_grad()
  def sample(self, num_samples: int, device: str='cuda'):
    """Sample from the prior p(z) = N(0, I)"""
    # Need to know the latent spatial size at inference time,
    # e.g. a 256 x 256 image with 4 levels is downsampled to 32 x 32
    latent_size = compute_latent_size(self.image_size, self.channel_multipliers)
    z = torch.randn(num_samples, self.latent_dim, latent_size, latent_size, device=device)
    return self.decode(z).clamp(-1, 1)
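The KL term in `compute_loss` is the closed form $\mathrm{KL}(\mathcal{N}(\mu,\sigma^2)\,\|\,\mathcal{N}(0,1)) = -\tfrac{1}{2}(1 + \log\sigma^2 - \mu^2 - \sigma^2)$, applied per latent element. A minimal sanity check against `torch.distributions.kl_divergence`, using random `mu`/`log_var` tensors shaped like a small latent batch (the shapes are arbitrary):

```python
import torch
from torch.distributions import Normal, kl_divergence

torch.manual_seed(0)
mu = torch.randn(2, 4, 8, 8)        # (B, latent_dim, H, W), arbitrary example shape
log_var = torch.randn(2, 4, 8, 8)
B = mu.shape[0]

# Closed form, exactly as written in VAE.compute_loss
closed_form = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp()) / B

# Reference: elementwise KL between diagonal Gaussians, summed and batch-averaged
reference = kl_divergence(
  Normal(mu, (0.5 * log_var).exp()),
  Normal(torch.zeros_like(mu), torch.ones_like(mu)),
).sum() / B

print(closed_form.item(), reference.item())
```

Both values should agree to floating-point precision, confirming the sign convention (the KL is always non-negative).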


<u>Key Training Parameters</u>

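- `lr`: 1e-4 is a safe default. Higher → faster training but a less stable KL term. Consider a linear LR warmup over the first ~1000 steps (`VAETrainer` below does not implement one).
- `beta_warmup_steps`: Start $\beta$ at 0 and linearly anneal it to the target. Without annealing, the KL term can dominate early training and push $q(z|x)$ onto the prior before the decoder learns to use $z$, so the model never learns a useful latent space (posterior collapse).
- `grad_clip`: Clip the global gradient norm to 1.0. The $\exp(\log\sigma^2)$ term in the KL can produce very large gradients.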
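One way to add the LR warmup mentioned above is a `LambdaLR` multiplier that ramps linearly and then decays with a cosine. This is a sketch under assumptions: `warmup_cosine` is a hypothetical helper, the step counts are illustrative, and the `nn.Linear` is a stand-in for the VAE. (The notebook imports `CosineAnnealingLR` but never uses it; this is an alternative that combines both phases.)

```python
import math
import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR

def warmup_cosine(warmup_steps: int, total_steps: int):
  """Hypothetical helper: returns an LR multiplier function for LambdaLR."""
  def lr_lambda(step: int) -> float:
    if step < warmup_steps:
      return (step + 1) / warmup_steps                    # linear ramp up to 1.0
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return 0.5 * (1.0 + math.cos(math.pi * min(1.0, progress)))  # cosine decay to 0
  return lr_lambda

model = torch.nn.Linear(4, 4)   # stand-in for the VAE
optimizer = AdamW(model.parameters(), lr=1e-4)
scheduler = LambdaLR(optimizer, lr_lambda=warmup_cosine(warmup_steps=1000, total_steps=10_000))
```

In a training step, `scheduler.step()` would be called once right after `optimizer.step()`.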

<u>WATCH OUT!</u>

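- KL loss should settle at a clearly non-zero value. If it drops to ~0 and stays there → posterior collapse. (While $\beta$ is still warming up, a large KL that falls as $\beta$ rises is expected.)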
- Reconstruction loss should decrease steadily.
- If recon loss plateaus very early → $\beta$ is too high.
- If generated samples have no structure → $\beta$ is too low, latent space is not regularized.
- Monitor $\mu$ values: if ||$\mu$|| >> 1, the encoder is ignoring the prior.
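- Monitor $\sigma$ values: if $\sigma$ → 1 (and $\mu$ → 0) for every input, the posterior has collapsed onto the prior (posterior collapse). If $\sigma$ → 0 everywhere, the encoder has become near-deterministic and the model behaves like a plain autoencoder ($\beta$ is likely too low).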
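Averaged metrics such as `mu_mean` can hide collapse in individual latent channels. A per-channel KL is a finer diagnostic: a channel whose KL is near zero carries no information about the input. This is a sketch; `per_channel_kl`, `collapsed_channels` and the 0.01-nat cutoff are hypothetical choices, not part of the notebook.

```python
import torch

def per_channel_kl(mu: torch.Tensor, log_var: torch.Tensor) -> torch.Tensor:
  """KL(q(z|x) || N(0, I)) per latent channel, summed over spatial positions
  and averaged over the batch. mu, log_var: (B, C, H, W) -> (C,)"""
  kl = -0.5 * (1 + log_var - mu.pow(2) - log_var.exp())   # elementwise KL, (B, C, H, W)
  return kl.sum(dim=(2, 3)).mean(dim=0)

def collapsed_channels(mu: torch.Tensor, log_var: torch.Tensor, threshold: float = 0.01):
  """Channels whose average per-position KL is below `threshold` nats (hypothetical cutoff)."""
  H, W = mu.shape[-2:]
  per_pos = per_channel_kl(mu, log_var) / (H * W)
  return (per_pos < threshold).nonzero(as_tuple=True)[0].tolist()
```

In practice, `mu, log_var = model.encode(batch)` on a held-out batch would feed these functions; a list that grows over training suggests channels are being switched off.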

In [8]:
class VAETrainer:
  def __init__(
    self,
    model: VAE,
    lr: float=1e-4,
    beta_warmup_steps: int=10_000,
    grad_clip: float=1.0,
    device: str='cuda'
  ):
    self.model = model.to(device)
    self.device = device
    self.beta_warmup_steps = beta_warmup_steps
    self.target_beta = model.beta
    self.grad_clip = grad_clip
    self.step = 0

    self.optimizer = AdamW(model.parameters(), lr=lr, betas=(0.9, 0.999), weight_decay=1e-4)
  
  def _current_beta(self):
    """Linear KL annealing — one of the most impactful tricks."""
    if self.beta_warmup_steps <= 0:
      return self.target_beta
    return self.target_beta * min(1.0, self.step / self.beta_warmup_steps)

  def train_step(self, x: torch.Tensor):
    self.model.train()
    x = x.to(self.device)

    self.model.beta = self._current_beta()

    x_hat, mu, log_var = self.model(x)
    total_loss, recon_loss, kl_loss = self.model.compute_loss(x, x_hat, mu, log_var)

    self.optimizer.zero_grad()
    total_loss.backward()

    # Gradient clipping — essential for stable VAE training
    nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clip)

    self.optimizer.step()
    self.step += 1

    return {
      'total_loss': total_loss.item(),
      'recon_loss': recon_loss.item(),
      'kl_loss': kl_loss.item(),
      'beta': self.model.beta,
      'mu_mean': mu.mean().item(),
      'mu_std': mu.std().item(),
      'sigma_mean': (0.5 * log_var).exp().mean().item(),
    }
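The annealing rule in `_current_beta` can be checked in isolation. `linear_beta` is a hypothetical standalone copy of it. Because `train_step` computes $\beta$ *before* incrementing `self.step`, the 50th update uses `step=49`, which explains the `β 0.4900` printed at `Step 50` in the training output further down (with `beta_warmup_steps=100`).

```python
def linear_beta(step: int, target_beta: float, warmup_steps: int) -> float:
  """Hypothetical copy of VAETrainer._current_beta; `step` counts completed updates."""
  if warmup_steps <= 0:
    return target_beta
  return target_beta * min(1.0, step / warmup_steps)

print(linear_beta(49, 1.0, 100))   # → 0.49, the beta used by the 50th update
```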

In [9]:
# device = 'cuda' if torch.cuda.is_available() else 'cpu'

# model = VAE(
#   in_channels=3,
#   latent_dim=4,           # try 8 or 16 for more expressivity
#   base_channels=64,       # 64 for fast experiments, 256 for SOTA
#   channel_multipliers=(1, 2, 4, 8),  # 4 levels, 3 downsamples of 2x each (256 → 32)
#   num_res_blocks=2,
#   attn_resolutions=(16,),  # no level is 16x16 at image_size=256; only mid-block attention (32x32) runs
#   dropout=0.0,             # dropout hurts VAE quality; only use if overfitting
#   image_size=256,
#   beta=1.0,                # tune this first
# )

# trainer = VAETrainer(
#   model,
#   lr=1e-4,
#   beta_warmup_steps=10_000,   # increase for more complex datasets
#   grad_clip=1.0,
#   device=device,
# )

# # Simulate a training step
# dummy_batch = torch.rand(4, 3, 256, 256) * 2 - 1  # (B, C, H, W), uniform in [-1, 1]
# metrics = trainer.train_step(dummy_batch)
# print(metrics)

# # Generation
# samples = model.sample(4, device=device)
# print(f"Generated samples shape: {samples.shape}")

# # Count parameters
# total_params = sum(p.numel() for p in model.parameters())
# print(f"Total parameters: {total_params / 1e6:.1f}M")

In [10]:
# ==================== Data Loader ====================
class ImageFolderDataset(Dataset):
  """
  Loads all images from a folder (and optionally its subfolders).
  Resizes the shorter side to image_size, center-crops to image_size x image_size
  and normalises to [-1, 1].
  """
  # Assumed set of supported file types; extend if your dataset uses others
  IMAGE_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.bmp', '.webp'}

  def __init__(
    self,
    folder: str,
    image_size: int = 256,
    recursive: bool = True
  ):
    folder = Path(folder)
    glob = folder.rglob('*') if recursive else folder.glob('*')
    # Keep only image files: rglob('*') also yields directories and non-image files
    self.paths = sorted(
      p for p in glob
      if p.is_file() and p.suffix.lower() in self.IMAGE_EXTENSIONS
    )
    
    if len(self.paths) == 0:
      raise ValueError(f"No images found in {folder}")
    
    print(f"Found {len(self.paths)} images in {folder}")

    self.image_size = image_size
    self.transform = transforms.Compose([
      transforms.Resize(image_size, interpolation=transforms.InterpolationMode.LANCZOS),
      transforms.CenterCrop(image_size),                      # Ensures square image
      transforms.ToTensor(),                                  # [0, 255] → [0.0, 1.0]
      transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])  # → [-1.0, 1.0]
    ])
  
  def __len__(self):
    return len(self.paths)
  
  def __getitem__(self, idx):
    try:
      img = Image.open(self.paths[idx]).convert('RGB')
      return self.transform(img)
    except Exception as e:
      # Fall back to a mid-grey image (0 in the normalised [-1, 1] space)
      # so a single corrupt file does not stop training
      print(f"Warning: failed to load {self.paths[idx]}: {e}. Returning zeros.")
      return torch.zeros(3, self.image_size, self.image_size)


# ==================== Generator Helper ====================
def save_samples(model, num_samples: int, save_path: str, device: str):
  model.eval()
  with torch.no_grad():
    samples = model.sample(num_samples, device=device)     # [-1, 1]
    samples = (samples + 1) / 2                            # → [0, 1]
    samples = samples.clamp(0, 1)
  save_image(samples, save_path, nrow=int(math.sqrt(num_samples)))
  print(f"Saved {num_samples} samples → {save_path}")

def save_reconstructions(model, batch: torch.Tensor, save_path: str, device: str):
  """Save original vs reconstruction side-by-side for visual quality checks."""
  model.eval()
  with torch.no_grad():
    batch = batch.to(device)
    x_hat, _, _ = model(batch)                     # decodes a sampled z, so output varies slightly per call
    comparison = torch.cat([batch, x_hat], dim=0)  # with nrow=len(batch): row 1 originals, row 2 reconstructions
    comparison = (comparison + 1) / 2
    comparison = comparison.clamp(0, 1)
  save_image(comparison, save_path, nrow=len(batch))
  print(f"Saved reconstructions → {save_path}")
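`save_reconstructions` calls `model(batch)`, which decodes a *sampled* latent, so the saved grid includes sampling noise. For a deterministic visual check, one option is to decode the posterior mean instead. This is a sketch assuming the `encode`/`decode` methods of the `VAE` class above; `reconstruct_from_mean` is a hypothetical helper.

```python
import torch

@torch.no_grad()
def reconstruct_from_mean(model, batch: torch.Tensor, device: str = 'cpu') -> torch.Tensor:
  """Hypothetical helper: decode mu instead of a sampled z (deterministic reconstruction)."""
  model.eval()
  mu, _ = model.encode(batch.to(device))   # ignore log_var: no sampling
  return model.decode(mu)
```

The output can then go through the same `(x + 1) / 2` and `clamp(0, 1)` steps before `save_image`.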

In [11]:
# ----------------------------------------
# Training Loop
# ----------------------------------------
def train(
  image_folder: str = "./data/Celebrity Faces Dataset",
  output_dir: str = "./data/outputs",
  # Image config
  image_size: int = 256,

  # Model config
  latent_dim: int = 4,
  base_channels: int = 64,
  channel_multipliers: tuple = (1, 2, 4, 8),
  num_res_blocks: int = 2,
  attn_resolutions: tuple = (16,),
  beta: float = 1.0,

  # Training config
  batch_size: int = 8,
  num_epochs: int = 10,
  lr: float = 1e-4,
  beta_warmup_steps: int = 5000,
  grad_clip: float = 1.0,

  # Logging config
  log_every: int = 50,          # print metrics every N steps
  sample_every: int = 500,      # generate samples every N steps
  save_every: int = 2000,       # save checkpoint every N steps
  num_samples: int = 4,         # how many images to generate

  device: str = 'cuda' if torch.cuda.is_available() else 'cpu'
):
  # ---- Setup --------------------------------
  os.makedirs(output_dir, exist_ok=True)
  samples_dir = os.path.join(output_dir, 'samples')
  recon_dir   = os.path.join(output_dir, 'reconstructions')
  ckpt_dir    = os.path.join(output_dir, 'checkpoints')
  for d in [samples_dir, recon_dir, ckpt_dir]:
    os.makedirs(d, exist_ok=True)

  latent_size = compute_latent_size(image_size, channel_multipliers)
  print(f"Image size: {image_size}x{image_size} → Latent size: {latent_size}x{latent_size}")

  # ---- Data ----------------------------------------
  dataset = ImageFolderDataset(image_folder, image_size=image_size)
  loader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=0,
    drop_last=True,
  )

  # ---- Model ----------------------------------------
  model = VAE(
    in_channels=3,
    latent_dim=latent_dim,
    base_channels=base_channels,
    channel_multipliers=channel_multipliers,
    num_res_blocks=num_res_blocks,
    attn_resolutions=attn_resolutions,
    image_size=image_size,
    beta=beta,
  ).to(device)

  model.image_size = image_size
  model.channel_multipliers = channel_multipliers

  total_params = sum(p.numel() for p in model.parameters()) / 1e6
  print(f"Model parameters: {total_params:.1f}M")

  trainer = VAETrainer(
    model,
    lr=lr,
    beta_warmup_steps=beta_warmup_steps,
    grad_clip=grad_clip,
    device=device,
  )

  val_batch = next(iter(loader))[:6]  # fixed batch of 6 training images for visual checks (not a held-out set)

  global_step = 0
  for epoch in range(num_epochs):
    epoch_recon, epoch_kl, epoch_total = 0.0, 0.0, 0.0

    for batch in loader:
      metrics = trainer.train_step(batch)
      global_step += 1

      epoch_recon += metrics['recon_loss']
      epoch_kl    += metrics['kl_loss']
      epoch_total += metrics['total_loss']

      # ---- Logging ------------------------------
      if global_step % log_every == 0:
        print(
          f"Epoch {epoch+1:3d} | Step {global_step:6d} | "
          f"Total {metrics['total_loss']:8.2f} | "
          f"Recon {metrics['recon_loss']:8.2f} | "
          f"KL {metrics['kl_loss']:7.3f} | "
          f"β {metrics['beta']:.4f} | "
          f"μ={metrics['mu_mean']:.3f} σ={metrics['sigma_mean']:.3f}"
        )
      
      # ---- Generate samples + reconstruction ------------------------------
      if global_step % sample_every == 0:
        save_samples(
          model, num_samples,
          save_path=os.path.join(samples_dir, f'step_{global_step:06d}.png'),
          device=device,
        )
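        save_reconstructions(
          model, val_batch,
          save_path=os.path.join(recon_dir, f'step_{global_step:06d}.png'),
          device=device,
        )

      # ---- Checkpoint ------------------------------
      # save_every was previously declared but unused
      if global_step % save_every == 0:
        torch.save(
          {
            'step': global_step,
            'model': model.state_dict(),
            'optimizer': trainer.optimizer.state_dict(),
          },
          os.path.join(ckpt_dir, f'step_{global_step:06d}.pt'),
        )
    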
    n_batches = len(loader)
    print(
      f"\n── Epoch {epoch+1} Summary ──\n"
      f"   Avg Total: {epoch_total/n_batches:.2f} | "
      f"   Avg Recon: {epoch_recon/n_batches:.2f} | "
      f"   Avg KL:    {epoch_kl/n_batches:.3f}\n"
    )

  print("Training complete")
 

In [12]:
train(
  image_folder="./data/Celebrity Faces Dataset/Brad Pitt",
  output_dir="./data/outputs",
  image_size=256,
  batch_size=8,
  num_epochs=100,
  beta=1.0,
  beta_warmup_steps=100,
  sample_every=200
)

Image size: 256x256 → Latent size: 32x32
Found 100 images in data/Celebrity Faces Dataset/Brad Pitt
Model parameters: 39.6M

── Epoch 1 Summary ──
   Avg Total: 73261.36 |    Avg Recon: 72665.52 |    Avg KL:    31473.871


── Epoch 2 Summary ──
   Avg Total: 35096.31 |    Avg Recon: 34204.82 |    Avg KL:    5074.362


── Epoch 3 Summary ──
   Avg Total: 23495.90 |    Avg Recon: 22312.51 |    Avg KL:    4049.126


── Epoch 4 Summary ──
   Avg Total: 17062.72 |    Avg Recon: 15758.25 |    Avg KL:    3140.560

Epoch   5 | Step     50 | Total 13889.32 | Recon 12415.05 | KL 3008.719 | β 0.4900 | μ=-0.043 σ=0.659

── Epoch 5 Summary ──
   Avg Total: 15359.81 |    Avg Recon: 13839.02 |    Avg KL:    2847.669


── Epoch 6 Summary ──
   Avg Total: 14337.84 |    Avg Recon: 12696.50 |    Avg KL:    2511.028


── Epoch 7 Summary ──
   Avg Total: 13512.02 |    Avg Recon: 11799.77 |    Avg KL:    2211.212


── Epoch 8 Summary ──
   Avg Total: 12980.11 |    Avg Recon: 11099.27 |    Avg KL:    2105.55
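Because `recon_loss` is a per-image sum of squared errors, its magnitude is hard to read directly. A rough conversion to PSNR, assuming data in [-1, 1] (peak-to-peak range 2) and 256 x 256 x 3 = 196,608 values per image, puts the logged epoch averages on a familiar scale. `psnr_from_sse` is a hypothetical helper; since these are epoch-averaged training losses, the result is only an approximate indicator, not a held-out metric.

```python
import math

def psnr_from_sse(sse_per_image: float, num_values: int, data_range: float = 2.0) -> float:
  """Hypothetical helper: PSNR (dB) from a per-image sum of squared errors."""
  mse = sse_per_image / num_values
  return 10.0 * math.log10(data_range ** 2 / mse)

num_values = 256 * 256 * 3   # 196,608 values per image
print(f"{psnr_from_sse(11099.27, num_values):.1f} dB")  # epoch 8 avg recon → 18.5 dB
```

By the same formula, the epoch 1 average (72665.52) corresponds to roughly 10.3 dB.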

# Generative Adversarial Networks (GANs)

Dataset: [Pokemon Images](https://www.kaggle.com/datasets/kvpratama/pokemon-images-dataset)

In [None]:
%load_ext watermark
%watermark