# Diffusion Model Capacity Investigation

Does a larger UNet improve diffusion quality? Test 3 model sizes on 20k CelebA latents.

**Assumes:** VAEs already trained, latents encoded (from scaling experiment notebook).

**Improvements over previous training:**
- Gradient clipping (max_norm=1.0)
- EMA weights (decay=0.9999)
- Cosine LR schedule

## UNet Variants
| Variant | base_channels | channel_mult | num_res_blocks |
|---------|--------------|--------------|----------------|
| Small   | 128          | (1, 2, 4)    | 2              |
| Medium  | 192          | (1, 2, 4)    | 2              |
| Large   | 256          | (1, 2, 4, 4) | 2              |

## 1. Setup

Run the next cell only if you are starting fresh. Skip it if you are continuing from the scaling experiment notebook and the latents and VAEs are still in memory.

In [None]:
# === ONLY run this cell if starting fresh (not continuing from scaling notebook) ===
# If you already have latents in memory, skip to cell 2.

import sys, os
sys.path.insert(0, '.')

import torch
import torch.nn as nn
import copy
import json
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
from pathlib import Path

# Pick the compute device once; later cells read this module-level name.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Device: {device}")
if device == 'cuda':
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    # The properties object exposes the capacity as `total_memory` (bytes);
    # `total_mem` does not exist and raised AttributeError on CUDA machines.
    gpu_props = torch.cuda.get_device_properties(0)
    print(f"Memory: {gpu_props.total_memory / 1e9:.1f} GB")

# Cached latents written by the scaling experiment:
# train/val splits for the baseline VAE and for the MM-Reg VAE.
print("\nLoading cached latents...")
(
    train_latents_baseline,
    val_latents_baseline,
    train_latents_mmreg,
    val_latents_mmreg,
) = (
    torch.load(latent_file)
    for latent_file in (
        './embeddings/scaling_train_latents_baseline.pt',
        './embeddings/scaling_val_latents_baseline.pt',
        './embeddings/scaling_train_latents_mmreg.pt',
        './embeddings/scaling_val_latents_mmreg.pt',
    )
)
# Quick sanity check of what was loaded
print(f"Train baseline: {train_latents_baseline.shape}")
print(f"Train MM-Reg: {train_latents_mmreg.shape}")

# The VAEs are only needed later, to decode generated latents into images.
from src.models.vae_wrapper import load_vae


def _restore_vae(checkpoint_path):
    """Build a VAE on `device`, load its checkpointed weights, switch to eval mode."""
    vae = load_vae(device=device)
    checkpoint = torch.load(checkpoint_path, map_location=device)
    vae.load_state_dict(checkpoint['model_state_dict'])
    vae.eval()
    return vae


vae_baseline = _restore_vae('./checkpoints/scaling_baseline_vae/best.pt')
vae_mmreg = _restore_vae('./checkpoints/scaling_mmreg_vae/best.pt')
print("VAEs loaded.")

## 2. Configuration

In [None]:
import copy
from src.models.diffusion import SimpleUNet, GaussianDiffusion
from src.diffusion_trainer import DiffusionTrainer, LatentDataset
from torch.utils.data import DataLoader

# --- Experiment hyperparameters (defaults for every training run below) ---
N_TRAIN = 20000             # training latents kept per VAE variant
DIFFUSION_EPOCHS = 50
DIFFUSION_LR = 1e-4         # initial learning rate handed to the optimizer
DIFFUSION_TIMESTEPS = 1000
BATCH_SIZE = 64
GRAD_CLIP = 1.0             # max_norm used for gradient clipping
EMA_DECAY = 0.9999

# --- UNet variants ---
# Maps a size name to the keyword arguments forwarded to SimpleUNet(**cfg).
UNET_CONFIGS = {
    'small': dict(base_channels=128, channel_mult=(1, 2, 4), num_res_blocks=2),
    'medium': dict(base_channels=192, channel_mult=(1, 2, 4), num_res_blocks=2),
    'large': dict(base_channels=256, channel_mult=(1, 2, 4, 4), num_res_blocks=2),
}

# Keep the first N_TRAIN latents of each variant. Slicing never raises, so a
# shorter cached tensor simply yields fewer samples.
sub_train_baseline, sub_train_mmreg = (
    train_latents_baseline[:N_TRAIN],
    train_latents_mmreg[:N_TRAIN],
)
# Spatial size is taken from dim 2 of the baseline latents.
latent_size = sub_train_baseline.shape[2]

# Report the run setup and the parameter count of every UNet variant.
print(f"Training on {N_TRAIN:,} samples, {DIFFUSION_EPOCHS} epochs")
print(f"Latent size: {latent_size}x{latent_size}")
print(f"\nUNet variants:")
for name, cfg in UNET_CONFIGS.items():
    # Throwaway instance (never moved to `device`), built only to be counted.
    probe = SimpleUNet(in_channels=4, **cfg)
    n_params = 0
    for weight in probe.parameters():
        n_params += weight.numel()
    print(f"  {name:>6}: {n_params/1e6:.1f}M params  (base={cfg['base_channels']}, mult={cfg['channel_mult']})")
    del probe

## 3. Training Loop with EMA + Gradient Clipping + Cosine LR

In [None]:
def update_ema(ema_model, model, decay=0.9999):
    """Blend the live model's weights into the EMA model, in place.

    Each EMA parameter becomes ``decay * ema + (1 - decay) * param``.
    Buffers (for example BatchNorm running statistics) are not averaged:
    they are copied verbatim, so the EMA model does not keep the values its
    buffers had when it was cloned.

    Args:
        ema_model: Module holding the averaged weights; mutated in place.
        model: Module being trained; only read. Parameters and buffers are
            paired positionally, so both modules must share an architecture.
        decay: Fraction of the previous EMA value kept on each update.
    """
    with torch.no_grad():
        for ema_p, p in zip(ema_model.parameters(), model.parameters()):
            ema_p.mul_(decay).add_(p, alpha=1 - decay)
        # Parameters alone miss registered buffers; sync those too.
        for ema_b, b in zip(ema_model.buffers(), model.buffers()):
            ema_b.copy_(b)


def train_diffusion(
    train_latents, val_latents, unet_config, name,
    epochs=DIFFUSION_EPOCHS, lr=DIFFUSION_LR,
    batch_size=BATCH_SIZE, grad_clip=GRAD_CLIP, ema_decay=EMA_DECAY
):
    """
    Train one diffusion UNet on latents with EMA, grad clipping and cosine LR.

    The raw UNet is optimized with AdamW. An EMA copy is updated after every
    optimizer step, is the model used for validation, and is written to
    ``./checkpoints/capacity_{name}/best.pt`` whenever its validation loss
    improves.

    Args:
        train_latents: Latents wrapped in ``LatentDataset`` for training.
        val_latents: Latents wrapped in ``LatentDataset`` for validation.
        unet_config: Keyword arguments forwarded to ``SimpleUNet``.
        name: Run name; selects the checkpoint directory.
        epochs: Number of epochs; also ``T_max`` of the cosine schedule.
        lr: Initial AdamW learning rate.
        batch_size: Batch size for both loaders.
        grad_clip: ``max_norm`` for gradient-norm clipping.
        ema_decay: Decay passed to ``update_ema``.

    Returns:
        dict with keys ``diffusion``, ``ema_model`` (EMA weights after the
        final epoch, not necessarily the checkpointed best), ``train_history``
        and ``val_history`` (one mean loss per epoch), ``best_val_loss``,
        ``n_params``, ``config`` and ``save_dir``.

    Raises:
        ValueError: If the training or validation loader yields no batches.
    """
    save_dir = Path(f'./checkpoints/capacity_{name}')
    save_dir.mkdir(parents=True, exist_ok=True)

    # Create model. The EMA copy starts identical to the UNet and is only
    # ever changed by update_ema, never by the optimizer.
    diffusion = GaussianDiffusion(num_timesteps=DIFFUSION_TIMESTEPS, device=device)
    unet = SimpleUNet(in_channels=4, **unet_config).to(device)
    ema = copy.deepcopy(unet)
    for p in ema.parameters():
        p.requires_grad = False

    n_params = sum(p.numel() for p in unet.parameters())
    print(f"\n{'='*60}")
    print(f"Training: {name} ({n_params/1e6:.1f}M params)")
    print(f"{'='*60}")

    # Optimizer + scheduler (the scheduler is stepped once per epoch)
    optimizer = torch.optim.AdamW(unet.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

    # Dataloaders. Pinned host memory is only useful when batches are copied
    # to a CUDA device; requesting it on a CPU-only run just emits a warning.
    use_pinned_memory = str(device).startswith('cuda')
    train_loader = DataLoader(
        LatentDataset(train_latents), batch_size=batch_size,
        shuffle=True, num_workers=0, pin_memory=use_pinned_memory
    )
    val_loader = DataLoader(
        LatentDataset(val_latents), batch_size=batch_size,
        shuffle=False, num_workers=0
    )

    train_history = []
    val_history = []
    best_val_loss = float('inf')

    for epoch in range(epochs):
        # Record the LR this epoch actually trains with. Reading it after
        # scheduler.step() would report the next epoch's value instead.
        epoch_lr = optimizer.param_groups[0]['lr']

        # --- Train ---
        unet.train()
        total_loss = 0
        num_batches = 0

        for latents in train_loader:
            latents = latents.to(device)
            bs = latents.shape[0]
            # One timestep index per sample, drawn from [0, DIFFUSION_TIMESTEPS)
            t = torch.randint(0, DIFFUSION_TIMESTEPS, (bs,), device=device)

            optimizer.zero_grad()
            loss = diffusion.p_losses(unet, latents, t)
            loss.backward()

            torch.nn.utils.clip_grad_norm_(unet.parameters(), max_norm=grad_clip)
            optimizer.step()
            update_ema(ema, unet, ema_decay)

            total_loss += loss.item()
            num_batches += 1

        if num_batches == 0:
            raise ValueError(f"train_diffusion({name!r}): training loader produced no batches")

        scheduler.step()
        train_loss = total_loss / num_batches
        train_history.append(train_loss)

        # --- Validate (using EMA model) ---
        # NOTE(review): timesteps are re-drawn at random every epoch, so
        # val losses (and the "best" checkpoint choice) carry sampling noise.
        ema.eval()
        val_total = 0
        val_batches = 0
        with torch.no_grad():
            for latents in val_loader:
                latents = latents.to(device)
                bs = latents.shape[0]
                t = torch.randint(0, DIFFUSION_TIMESTEPS, (bs,), device=device)
                loss = diffusion.p_losses(ema, latents, t)
                val_total += loss.item()
                val_batches += 1

        if val_batches == 0:
            raise ValueError(f"train_diffusion({name!r}): validation loader produced no batches")

        val_loss = val_total / val_batches
        val_history.append(val_loss)

        # Checkpoint the EMA weights (not the raw UNet) on improvement
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save({'model_state_dict': ema.state_dict()}, save_dir / 'best.pt')

        if epoch % 10 == 0 or epoch == epochs - 1:
            print(f"  Epoch {epoch:3d}: train={train_loss:.6f}  val={val_loss:.6f}  lr={epoch_lr:.2e}")

    print(f"  Best val loss: {best_val_loss:.6f}")

    return {
        'diffusion': diffusion,
        'ema_model': ema,
        'train_history': train_history,
        'val_history': val_history,
        'best_val_loss': best_val_loss,
        'n_params': n_params,
        'config': unet_config,
        'save_dir': save_dir,
    }

## 4. Run All 6 Experiments

In [None]:
# Train every (UNet size, VAE variant) pair; results are keyed '<size>_<variant>'.
results = {}

# (variant suffix, training latents, validation latents), baseline first
latent_variants = (
    ('baseline', sub_train_baseline, val_latents_baseline),
    ('mmreg', sub_train_mmreg, val_latents_mmreg),
)

for size_name, unet_cfg in UNET_CONFIGS.items():
    for variant, train_set, val_set in latent_variants:
        run_key = f'{size_name}_{variant}'
        results[run_key] = train_diffusion(train_set, val_set, unet_cfg, run_key)

    # Release cached CUDA memory once both runs for this size are done.
    # The EMA models returned by train_diffusion stay referenced in `results`.
    torch.cuda.empty_cache()

banner = "=" * 60
print("\n" + banner)
print("ALL TRAINING COMPLETE")
print(banner)

## 5. Results

In [None]:
# Results table: one row per UNet size, baseline vs MM-Reg best val loss
print(f"{'Model':>16} | {'Params':>8} | {'Base Val':>10} | {'MMReg Val':>10} | {'Improv.':>8}")
print("-" * 65)

for size_name in UNET_CONFIGS:
    baseline_run = results[f'{size_name}_baseline']
    mmreg_run = results[f'{size_name}_mmreg']
    base_val, mm_val = baseline_run['best_val_loss'], mmreg_run['best_val_loss']
    n_params = baseline_run['n_params']
    # Relative reduction of MM-Reg's val loss versus the baseline, in percent
    improvement = (base_val - mm_val) / base_val * 100
    print(f"{size_name:>16} | {n_params/1e6:>7.1f}M | {base_val:>10.6f} | {mm_val:>10.6f} | {improvement:>7.1f}%")

# Relative val-loss change when going from the small to the large UNet
small_base_loss = results['small_baseline']['best_val_loss']
large_base_loss = results['large_baseline']['best_val_loss']
base_gain_pct = (small_base_loss - large_base_loss) / small_base_loss * 100
print(f"\nBaseline improvement (small -> large): "
      f"{small_base_loss:.6f} -> {large_base_loss:.6f} "
      f"({base_gain_pct:.1f}%)")

small_mm_loss = results['small_mmreg']['best_val_loss']
large_mm_loss = results['large_mmreg']['best_val_loss']
mm_gain_pct = (small_mm_loss - large_mm_loss) / small_mm_loss * 100
print(f"MM-Reg improvement (small -> large): "
      f"{small_mm_loss:.6f} -> {large_mm_loss:.6f} "
      f"({mm_gain_pct:.1f}%)")

In [None]:
# Plot 1: per-epoch EMA validation loss, one panel per UNet size,
# baseline and MM-Reg curves overlaid (all 6 runs).
fig, axes = plt.subplots(1, 3, figsize=(15, 5), sharey=True)

# (results-key suffix, matplotlib format string, legend label)
curve_specs = (
    ('baseline', 'b-', 'Baseline'),
    ('mmreg', 'r-', 'MM-Reg'),
)

for i, size_name in enumerate(UNET_CONFIGS):
    panel = axes[i]
    for variant, line_fmt, legend_label in curve_specs:
        panel.plot(results[f'{size_name}_{variant}']['val_history'], line_fmt,
                   label=legend_label, alpha=0.8, linewidth=1.5)

    n_params = results[f'{size_name}_baseline']['n_params']
    panel.set_title(f'{size_name.capitalize()} ({n_params/1e6:.1f}M)', fontsize=13)
    panel.set_xlabel('Epoch')
    panel.grid(True, alpha=0.3)
    panel.legend(fontsize=10)
    # The y axis is shared, so only the left-most panel gets a label
    if i == 0:
        panel.set_ylabel('Val Loss (EMA)')

plt.suptitle(f'Diffusion Val Loss by Model Size ({N_TRAIN//1000}k samples)', fontsize=14, y=1.02)
plt.tight_layout()
plt.savefig('./checkpoints/capacity_val_curves.png', dpi=150, bbox_inches='tight')
plt.show()

In [None]:
# Plot 2: best val loss vs parameter count (left) and MM-Reg gain per size (right)
sizes = list(UNET_CONFIGS)
params, base_vals, mm_vals = [], [], []
for s in sizes:
    params.append(results[f'{s}_baseline']['n_params'] / 1e6)  # in millions
    base_vals.append(results[f'{s}_baseline']['best_val_loss'])
    mm_vals.append(results[f'{s}_mmreg']['best_val_loss'])
# Percent reduction of the MM-Reg loss relative to the baseline loss
improvements = [(b - m) / b * 100 for b, m in zip(base_vals, mm_vals)]

fig, axes = plt.subplots(1, 2, figsize=(13, 5))
loss_ax, gain_ax = axes[0], axes[1]

# Left: val loss vs params
loss_ax.plot(params, base_vals, 'bo-', label='Baseline', markersize=10, linewidth=2)
loss_ax.plot(params, mm_vals, 'rs-', label='MM-Reg', markersize=10, linewidth=2)
loss_ax.set_xlabel('Model Parameters (M)', fontsize=12)
loss_ax.set_ylabel('Best Val Loss', fontsize=12)
loss_ax.set_title('Diffusion Quality vs Model Capacity', fontsize=13)
loss_ax.legend(fontsize=11)
loss_ax.grid(True, alpha=0.3)
# Tag each baseline point with its size name, slightly above the marker
for s, p_millions, b in zip(sizes, params, base_vals):
    loss_ax.annotate(s, (p_millions, b), textcoords='offset points',
                     xytext=(0, 10), ha='center', fontsize=9)

# Right: MM-Reg improvement % by size
x = range(len(sizes))
tick_labels = [f'{s}\n({p:.0f}M)' for s, p in zip(sizes, params)]
gain_ax.bar(x, improvements, color=['#2ecc71', '#27ae60', '#1e8449'], alpha=0.8)
gain_ax.set_xlabel('Model Size', fontsize=12)
gain_ax.set_ylabel('MM-Reg Improvement (%)', fontsize=12)
gain_ax.set_title('MM-Reg Advantage by Model Capacity', fontsize=13)
gain_ax.set_xticks(x)
gain_ax.set_xticklabels(tick_labels)
gain_ax.grid(True, alpha=0.3, axis='y')
# Print the percentage just above each bar
for j, v in enumerate(improvements):
    gain_ax.text(j, v + 0.5, f'{v:.1f}%', ha='center', fontsize=11, fontweight='bold')

plt.tight_layout()
plt.savefig('./checkpoints/capacity_comparison.png', dpi=150, bbox_inches='tight')
plt.show()

## 6. Generate Samples (MM-Reg, All Model Sizes)

In [None]:
# Pick the UNet size whose MM-Reg run reached the lowest best val loss.
# Only the MM-Reg runs are compared here; baseline runs are not considered.
mmreg_best_by_size = {s: results[f'{s}_mmreg']['best_val_loss'] for s in UNET_CONFIGS}
best_size = min(mmreg_best_by_size, key=mmreg_best_by_size.get)
print(f"Best model size: {best_size}")

def decode_and_plot(vae, latents, title, save_path):
    """Decode latents with `vae` and show/save the first 16 images as a 2x8 grid.

    Args:
        vae: Model exposing ``eval()`` and ``decode(latents)``.
        latents: Batch of latents; at least 16 entries are indexed.
        title: Figure super-title.
        save_path: Where the PNG is written before the figure is shown.
    """
    vae.eval()
    with torch.no_grad():
        decoded = vae.decode(latents.to(device))

    fig, grid = plt.subplots(2, 8, figsize=(16, 4))
    fig.suptitle(title, fontsize=14)
    for idx in range(16):
        row, col = divmod(idx, 8)
        cell = grid[row, col]
        # Channels-first -> channels-last for imshow
        pixels = decoded[idx].cpu().permute(1, 2, 0).numpy()
        # Shift/scale by (x + 1) / 2 and clip into [0, 1]
        # (assumes decoder output is roughly in [-1, 1] -- TODO confirm)
        pixels = ((pixels + 1) / 2).clip(0, 1)
        cell.imshow(pixels)
        cell.axis('off')
    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.show()

# Generate 16 latents with the MM-Reg model of every size, for side-by-side comparison
latent_shape = (16, 4, latent_size, latent_size)

for size_name in UNET_CONFIGS:
    print(f"\n--- Generating with {size_name} model ---")

    # Reuse the diffusion process object from the matching training run
    diff = results[f'{size_name}_mmreg']['diffusion']

    # Rebuild the UNet and load the best EMA checkpoint written during training
    ema_mm = SimpleUNet(in_channels=4, **UNET_CONFIGS[size_name]).to(device)
    ckpt = torch.load(f'./checkpoints/capacity_{size_name}_mmreg/best.pt', map_location=device)
    ema_mm.load_state_dict(ckpt['model_state_dict'])
    # A freshly constructed module is in training mode and load_state_dict
    # does not change that. Switch to eval so train-only layers (e.g. dropout)
    # are disabled while sampling.
    ema_mm.eval()

    gen_latents = diff.sample(ema_mm, latent_shape, progress=True)

    n_params = results[f'{size_name}_mmreg']['n_params']
    decode_and_plot(
        vae_mmreg, gen_latents,
        f'MM-Reg {size_name.capitalize()} ({n_params/1e6:.0f}M params)',
        f'./checkpoints/capacity_samples_{size_name}_mmreg.png'
    )

    # Drop this size's model and samples before loading the next one
    del ema_mm, gen_latents
    torch.cuda.empty_cache()

## 7. Summary

In [None]:
def to_python(val):
    """Convert a scalar-like value to a plain Python number for JSON output.

    Objects exposing ``.item()`` are unwrapped through it; anything else is
    passed through ``float()``.
    """
    if not hasattr(val, 'item'):
        return float(val)
    return val.item()

# Collect one JSON-serializable record per UNet size.
per_size_results = {}
for size_name in UNET_CONFIGS:
    baseline_run = results[f'{size_name}_baseline']
    base_val = baseline_run['best_val_loss']
    mm_val = results[f'{size_name}_mmreg']['best_val_loss']
    # Tuples (channel_mult) are stored as lists in the summary
    json_safe_cfg = {
        field: (list(value) if isinstance(value, tuple) else value)
        for field, value in UNET_CONFIGS[size_name].items()
    }
    per_size_results[size_name] = {
        'n_params': baseline_run['n_params'],
        'config': json_safe_cfg,
        'baseline_best_val': to_python(base_val),
        'mmreg_best_val': to_python(mm_val),
        'improvement_pct': to_python((base_val - mm_val) / base_val * 100),
    }

summary = {
    'experiment': 'diffusion_capacity_investigation',
    'n_train': N_TRAIN,
    'epochs': DIFFUSION_EPOCHS,
    'improvements': ['gradient_clipping', 'ema', 'cosine_lr'],
    'results': per_size_results,
}

# Persist the summary next to the checkpoints
with open('./checkpoints/capacity_investigation_summary.json', 'w') as f:
    json.dump(summary, f, indent=2)

# Human-readable recap of the whole investigation
rule = "=" * 60
print(rule)
print("CAPACITY INVESTIGATION SUMMARY")
print(rule)
print(f"\nTraining: {N_TRAIN:,} samples, {DIFFUSION_EPOCHS} epochs")
print(f"Improvements: grad clip={GRAD_CLIP}, EMA={EMA_DECAY}, cosine LR")
print(f"\n{'Model':>8} | {'Params':>8} | {'Base Val':>10} | {'MMReg Val':>10} | {'Improv':>7} | {'Base Drop':>10}")
print("-" * 72)

# Reference losses of the small model; 'Base Drop' is measured against small_base
small_base = results['small_baseline']['best_val_loss']
small_mm = results['small_mmreg']['best_val_loss']  # looked up but not used below

for size_name in UNET_CONFIGS:
    baseline_run = results[f'{size_name}_baseline']
    base_val = baseline_run['best_val_loss']
    mm_val = results[f'{size_name}_mmreg']['best_val_loss']
    n_params = baseline_run['n_params']
    # MM-Reg gain over the baseline at this size, in percent
    improv = (base_val - mm_val) / base_val * 100
    # Baseline loss reduction relative to the small baseline, in percent
    base_drop = (small_base - base_val) / small_base * 100
    print(f"{size_name:>8} | {n_params/1e6:>7.1f}M | {base_val:>10.6f} | {mm_val:>10.6f} | {improv:>6.1f}% | {base_drop:>+9.1f}%")

print(f"\nResults saved to ./checkpoints/capacity_investigation_summary.json")