# Diffusion Model Capacity Investigation

Does a larger UNet improve diffusion quality? Test 3 model sizes on 20k CelebA latents.

## Pipeline
1. Setup + install dependencies
2. Load CelebA + compute PCA
3. Train Baseline + MM-Reg VAEs on full CelebA
4. Encode all images to latents (cached)
5. Train 6 diffusion models (3 sizes x 2 VAE variants) on 20k subset
6. Compare results + generate samples

## Training improvements
- Gradient clipping (max_norm=1.0)
- EMA weights (decay=0.9999)
- Cosine LR schedule

## UNet Variants
| Variant | base_channels | channel_mult | num_res_blocks |
|---------|--------------|--------------|----------------|
| Small   | 128          | (1, 2, 4)    | 2              |
| Medium  | 192          | (1, 2, 4)    | 2              |
| Large   | 256          | (1, 2, 4, 4) | 2              |

## 1. Setup

In [None]:
# Clone repository
!rm -rf MMReg_diffusion_generative 2>/dev/null
!git clone https://github.com/laurent-cheret/MMReg_diffusion_generative.git
%cd MMReg_diffusion_generative

In [None]:
# Install dependencies
!pip install -q torch torchvision diffusers transformers accelerate
!pip install -q pyyaml tqdm scipy scikit-learn matplotlib
!pip install -q datasets gdown

In [None]:
import sys
sys.path.insert(0, '.')

import torch
import torch.nn as nn
import copy
import os
import json
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
from pathlib import Path
from torch.utils.data import DataLoader

# Pick the compute device once; later cells read the module-level `device`.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"PyTorch: {torch.__version__}")
print(f"Device: {device}")
if device == 'cuda':
    gpu_props = torch.cuda.get_device_properties(0)
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    # The properties field is `total_memory` (bytes); `total_mem` does not
    # exist and raised AttributeError on every CUDA machine.
    print(f"Memory: {gpu_props.total_memory / 1e9:.1f} GB")

In [None]:
# Experiment configuration
CONFIG = {
    # Data
    'data_root': './data',
    'image_size': 128,
    'batch_size': 64,

    # PCA
    'pca_components': 256,

    # VAE Training (on full dataset, once)
    'vae_epochs': 5,
    'vae_lr': 1e-5,
    'lambda_mm': 1.0,
    'beta': 1e-6,

    # Diffusion Training
    'n_train': 20000,
    'diffusion_epochs': 50,
    'diffusion_lr': 1e-4,
    'diffusion_timesteps': 1000,
    'grad_clip': 1.0,
    'ema_decay': 0.9999,
}

# UNet variants
UNET_CONFIGS = {
    'small': {
        'base_channels': 128,
        'channel_mult': (1, 2, 4),
        'num_res_blocks': 2,
    },
    'medium': {
        'base_channels': 192,
        'channel_mult': (1, 2, 4),
        'num_res_blocks': 2,
    },
    'large': {
        'base_channels': 256,
        'channel_mult': (1, 2, 4, 4),
        'num_res_blocks': 2,
    },
}

print("Configuration:")
for k, v in CONFIG.items():
    print(f"  {k}: {v}")
print(f"\nUNet variants: {list(UNET_CONFIGS.keys())}")

## 2. Load CelebA & Compute PCA

In [None]:
from src.data.dataset import (
    get_celeba_dataset,
    compute_pca_embeddings_celeba,
    get_dataset_and_loader
)

# === CONFIGURATION: Choose your data source ===
USE_HUGGINGFACE = True
USE_DRIVE = False
DRIVE_PATHS = {
    'images_zip': '/content/drive/MyDrive/DATASETS/CelebA/img_align_celeba.zip',
    'attr_file': '/content/drive/MyDrive/DATASETS/CelebA/list_attr_celeba.txt',
    'partition_file': '/content/drive/MyDrive/DATASETS/CelebA/list_eval_partition.txt'
}

if USE_DRIVE:
    from google.colab import drive
    drive.mount('/content/drive')

source = 'drive' if USE_DRIVE else 'huggingface'
drive_paths = DRIVE_PATHS if USE_DRIVE else None

# Load CelebA
print(f"Loading CelebA from {source}...")
train_dataset_fixed = get_celeba_dataset(
    root=CONFIG['data_root'], split='train',
    image_size=CONFIG['image_size'], fixed_transform=True,
    source=source, drive_paths=drive_paths
)
val_dataset_fixed = get_celeba_dataset(
    root=CONFIG['data_root'], split='val',
    image_size=CONFIG['image_size'], fixed_transform=True,
    source=source, drive_paths=drive_paths
)

total_train = len(train_dataset_fixed)
print(f"Train: {total_train}, Val: {len(val_dataset_fixed)}")

In [None]:
# Compute PCA embeddings
os.makedirs('./embeddings', exist_ok=True)

print("Computing PCA embeddings for training set...")
train_pca = compute_pca_embeddings_celeba(
    train_dataset_fixed, n_components=CONFIG['pca_components'], batch_size=256
)
torch.save(train_pca, './embeddings/celeba_train_pca.pt')

print("Computing PCA embeddings for validation set...")
val_pca = compute_pca_embeddings_celeba(
    val_dataset_fixed, n_components=CONFIG['pca_components'], batch_size=256
)
torch.save(val_pca, './embeddings/celeba_val_pca.pt')

print(f"Train PCA: {train_pca.shape}, Val PCA: {val_pca.shape}")

## 3. Train VAEs on Full CelebA

### 3.1 Baseline VAE (no MM-Reg)

In [None]:
from src.models.vae_wrapper import load_vae
from src.models.losses import VAELoss
from src.trainer import MMRegTrainer

print("="*60)
print("TRAINING BASELINE VAE (no MM-Reg) on full CelebA")
print("="*60)

vae_baseline = load_vae(device=device)
loss_baseline = VAELoss(lambda_mm=0.0, beta=CONFIG['beta'])

train_dataset_base, train_loader_base = get_dataset_and_loader(
    dataset_name='celeba', root=CONFIG['data_root'], split='train',
    image_size=CONFIG['image_size'], batch_size=CONFIG['batch_size'],
    num_workers=2, pca_embeddings_path='./embeddings/celeba_train_pca.pt',
    celeba_source=source, celeba_drive_paths=drive_paths
)
val_dataset_base, val_loader_base = get_dataset_and_loader(
    dataset_name='celeba', root=CONFIG['data_root'], split='val',
    image_size=CONFIG['image_size'], batch_size=CONFIG['batch_size'],
    num_workers=2, pca_embeddings_path='./embeddings/celeba_val_pca.pt',
    celeba_source=source, celeba_drive_paths=drive_paths
)

optimizer_baseline = torch.optim.AdamW(vae_baseline.parameters(), lr=CONFIG['vae_lr'])

trainer_baseline = MMRegTrainer(
    vae=vae_baseline, loss_fn=loss_baseline, optimizer=optimizer_baseline,
    train_loader=train_loader_base, val_loader=val_loader_base,
    device=device, save_dir='./checkpoints/capacity_baseline_vae'
)
trainer_baseline.train(num_epochs=CONFIG['vae_epochs'])

### 3.2 MM-Reg VAE

In [None]:
print("="*60)
print("TRAINING MM-REG VAE on full CelebA")
print("="*60)

vae_mmreg = load_vae(device=device)
loss_mmreg = VAELoss(
    lambda_mm=CONFIG['lambda_mm'], beta=CONFIG['beta'], mm_variant='correlation'
)

train_dataset_mm, train_loader_mm = get_dataset_and_loader(
    dataset_name='celeba', root=CONFIG['data_root'], split='train',
    image_size=CONFIG['image_size'], batch_size=CONFIG['batch_size'],
    num_workers=2, pca_embeddings_path='./embeddings/celeba_train_pca.pt',
    celeba_source=source, celeba_drive_paths=drive_paths
)
val_dataset_mm, val_loader_mm = get_dataset_and_loader(
    dataset_name='celeba', root=CONFIG['data_root'], split='val',
    image_size=CONFIG['image_size'], batch_size=CONFIG['batch_size'],
    num_workers=2, pca_embeddings_path='./embeddings/celeba_val_pca.pt',
    celeba_source=source, celeba_drive_paths=drive_paths
)

optimizer_mmreg = torch.optim.AdamW(vae_mmreg.parameters(), lr=CONFIG['vae_lr'])

trainer_mmreg = MMRegTrainer(
    vae=vae_mmreg, loss_fn=loss_mmreg, optimizer=optimizer_mmreg,
    train_loader=train_loader_mm, val_loader=val_loader_mm,
    device=device, save_dir='./checkpoints/capacity_mmreg_vae'
)
trainer_mmreg.train(num_epochs=CONFIG['vae_epochs'])

## 4. Evaluate VAEs & Encode Dataset

In [None]:
from src.models.losses import pairwise_distances, get_upper_triangular
from scipy.stats import pearsonr

def quick_evaluate(vae, val_loader, val_pca, name):
    """Score a VAE on reconstruction error and latent/PCA distance agreement.

    Runs the VAE deterministically (``sample=False``) over ``val_loader``,
    accumulating a sample-weighted reconstruction MSE and collecting the
    flattened latents. Pairwise distances among the first (up to) 500
    latents are then correlated (Pearson) with the pairwise distances among
    the first rows of ``val_pca``.

    NOTE(review): the correlation is only meaningful if ``val_loader``
    yields samples in the same order as the rows of ``val_pca`` — confirm
    the loader is not shuffled.

    Args:
        vae: model called as ``vae(images, sample=False)`` returning a dict
            with ``'x_recon'`` and ``'latent_flat'`` entries.
        val_loader: iterable of batches whose first element is the images.
        val_pca: per-sample PCA embeddings, sliced as ``val_pca[:n]``.
        name: label used in the progress bar and the printed summary.

    Returns:
        dict with keys ``'recon_mse'`` and ``'pearson'``.
    """
    vae.eval()
    latent_chunks = []
    weighted_err = 0
    seen = 0

    with torch.no_grad():
        for batch in tqdm(val_loader, desc=f"Eval {name}"):
            images = batch[0].to(device)
            batch_n = images.shape[0]
            outputs = vae(images, sample=False)
            # Per-batch mean error, re-weighted by batch size so a short
            # final batch does not skew the dataset-level average.
            batch_mse = ((outputs['x_recon'] - images) ** 2).mean().item()
            weighted_err += batch_mse * batch_n
            seen += batch_n
            latent_chunks.append(outputs['latent_flat'].cpu())

    latents = torch.cat(latent_chunks, dim=0)
    subset = min(500, len(latents))
    latent_dist = pairwise_distances(latents[:subset])
    pca_dist = pairwise_distances(val_pca[:subset])
    flat_latent = get_upper_triangular(latent_dist).numpy()
    flat_pca = get_upper_triangular(pca_dist).numpy()
    pearson_val = pearsonr(flat_latent, flat_pca)[0]

    recon_mse = weighted_err / seen
    print(f"{name}: Recon MSE={recon_mse:.6f}, Pearson={pearson_val:.4f}")
    return {'recon_mse': recon_mse, 'pearson': pearson_val}

print("\nVAE Evaluation (on full val set):")
vae_results_baseline = quick_evaluate(vae_baseline, val_loader_base, val_pca, "Baseline")
vae_results_mmreg = quick_evaluate(vae_mmreg, val_loader_mm, val_pca, "MM-Reg")

In [None]:
from src.diffusion_trainer import encode_dataset, LatentDataset

# Simple dataloaders for encoding (no PCA wrapper)
train_loader_simple = DataLoader(
    train_dataset_fixed, batch_size=CONFIG['batch_size'],
    shuffle=False, num_workers=2
)
val_loader_simple = DataLoader(
    val_dataset_fixed, batch_size=CONFIG['batch_size'],
    shuffle=False, num_workers=2
)

print("Encoding full dataset with Baseline VAE...")
train_latents_baseline = encode_dataset(vae_baseline, train_loader_simple, device)
val_latents_baseline = encode_dataset(vae_baseline, val_loader_simple, device)
print(f"Baseline latents - Train: {train_latents_baseline.shape}, Val: {val_latents_baseline.shape}")

print("\nEncoding full dataset with MM-Reg VAE...")
train_latents_mmreg = encode_dataset(vae_mmreg, train_loader_simple, device)
val_latents_mmreg = encode_dataset(vae_mmreg, val_loader_simple, device)
print(f"MM-Reg latents - Train: {train_latents_mmreg.shape}, Val: {val_latents_mmreg.shape}")

# Save for potential reuse
torch.save(train_latents_baseline, './embeddings/train_latents_baseline.pt')
torch.save(val_latents_baseline, './embeddings/val_latents_baseline.pt')
torch.save(train_latents_mmreg, './embeddings/train_latents_mmreg.pt')
torch.save(val_latents_mmreg, './embeddings/val_latents_mmreg.pt')
print("\nLatents cached to ./embeddings/")

## 5. Diffusion Capacity Experiment

Train 3 UNet sizes x 2 VAE variants = 6 runs on 20k samples.

In [None]:
from src.models.diffusion import SimpleUNet, GaussianDiffusion

# Take the first CONFIG['n_train'] training latents (a head slice in dataset
# order, not a random sample). If fewer latents exist, the slice silently
# returns all of them while the message below still reports N_TRAIN.
N_TRAIN = CONFIG['n_train']
sub_train_baseline = train_latents_baseline[:N_TRAIN]
sub_train_mmreg = train_latents_mmreg[:N_TRAIN]
# NOTE(review): assumes latents are laid out (N, C, H, W) with H == W, so
# dim 2 is the spatial size — confirm against encode_dataset.
latent_size = sub_train_baseline.shape[2]

print(f"Training on {N_TRAIN:,} samples, {CONFIG['diffusion_epochs']} epochs")
print(f"Latent size: {latent_size}x{latent_size}")
print(f"Training improvements: grad_clip={CONFIG['grad_clip']}, EMA={CONFIG['ema_decay']}, cosine LR")

# Count params for each config by instantiating a throwaway UNet (never moved
# to `device` here) and dropping it right after.
print(f"\nUNet variants:")
for name, cfg in UNET_CONFIGS.items():
    tmp = SimpleUNet(in_channels=4, **cfg)
    n_params = sum(p.numel() for p in tmp.parameters())
    print(f"  {name:>6}: {n_params/1e6:.1f}M params  (base={cfg['base_channels']}, mult={cfg['channel_mult']})")
    del tmp

In [None]:
def update_ema(ema_model, model, decay=0.9999):
    """Update EMA model weights in place from the live model.

    Every EMA parameter becomes ``decay * ema + (1 - decay) * param``.
    Buffers (for example normalisation running statistics and step counters)
    are copied verbatim rather than averaged, so the EMA copy never keeps
    the stale buffer values it was created with.

    Args:
        ema_model: module holding the averaged weights; modified in place.
        model: module being trained; must have the same architecture so
            parameters and buffers line up one-to-one.
        decay: fraction of the previous EMA value kept on each update.
    """
    with torch.no_grad():
        for ema_p, p in zip(ema_model.parameters(), model.parameters()):
            ema_p.mul_(decay).add_(p.detach(), alpha=1 - decay)
        # Buffers are not returned by .parameters(); without this the EMA
        # model would evaluate with its initial buffer contents.
        for ema_b, b in zip(ema_model.buffers(), model.buffers()):
            ema_b.copy_(b)


def train_diffusion(train_latents, val_latents, unet_config, name):
    """
    Train a diffusion model with EMA, gradient clipping, cosine LR.

    Hyper-parameters (epochs, learning rate, batch size, clipping norm, EMA
    decay, number of timesteps) are read from the module-level ``CONFIG``.
    The EMA copy of the UNet is the model that is validated and checkpointed.

    Args:
        train_latents: training latents, wrapped in ``LatentDataset``.
        val_latents: validation latents, wrapped in ``LatentDataset``.
        unet_config: keyword arguments forwarded to ``SimpleUNet``.
        name: run name; the best EMA weights are written to
            ``./checkpoints/capacity_{name}/best.pt``.

    Returns: dict with results: ``diffusion``, ``ema_model`` (EMA weights
        after the final epoch, not necessarily the best checkpoint),
        ``train_history``, ``val_history``, ``best_val_loss``, ``n_params``,
        ``config`` and ``save_dir``.

    Raises:
        ValueError: if ``train_latents`` or ``val_latents`` is empty.
    """
    if len(train_latents) == 0 or len(val_latents) == 0:
        raise ValueError("train_latents and val_latents must both be non-empty")

    save_dir = Path(f'./checkpoints/capacity_{name}')
    save_dir.mkdir(parents=True, exist_ok=True)

    epochs = CONFIG['diffusion_epochs']
    lr = CONFIG['diffusion_lr']
    batch_size = CONFIG['batch_size']
    grad_clip = CONFIG['grad_clip']
    ema_decay = CONFIG['ema_decay']
    timesteps = CONFIG['diffusion_timesteps']
    # Optional key; defaults keep existing CONFIG dicts working unchanged.
    val_seed = CONFIG.get('val_seed', 0)

    # Create model
    diffusion = GaussianDiffusion(num_timesteps=timesteps, device=device)
    unet = SimpleUNet(in_channels=4, **unet_config).to(device)
    ema = copy.deepcopy(unet)
    for p in ema.parameters():
        p.requires_grad = False

    n_params = sum(p.numel() for p in unet.parameters())
    print(f"\n{'='*60}")
    print(f"Training: {name} ({n_params/1e6:.1f}M params)")
    print(f"{'='*60}")

    # Optimizer + scheduler (the scheduler is stepped once per epoch)
    optimizer = torch.optim.AdamW(unet.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

    # Dataloaders; pinned host memory only helps when copying to a GPU
    use_cuda = device == 'cuda'
    train_loader = DataLoader(
        LatentDataset(train_latents), batch_size=batch_size,
        shuffle=True, num_workers=0, pin_memory=use_cuda
    )
    val_loader = DataLoader(
        LatentDataset(val_latents), batch_size=batch_size,
        shuffle=False, num_workers=0
    )

    train_history = []
    val_history = []
    best_val_loss = float('inf')
    global_step = 0

    for epoch in range(epochs):
        # --- Train ---
        unet.train()
        total_loss = 0
        num_batches = 0

        for latents in train_loader:
            latents = latents.to(device, non_blocking=use_cuda)
            bs = latents.shape[0]
            t = torch.randint(0, timesteps, (bs,), device=device)

            optimizer.zero_grad()
            loss = diffusion.p_losses(unet, latents, t)
            loss.backward()

            torch.nn.utils.clip_grad_norm_(unet.parameters(), max_norm=grad_clip)
            optimizer.step()

            # Warm up the EMA decay, capped at the configured value. With a
            # constant decay close to 1 and only a few thousand updates, the
            # EMA weights would still contain a large share of the random
            # initialisation, which biases every validation number.
            global_step += 1
            step_decay = min(ema_decay, (1 + global_step) / (10 + global_step))
            update_ema(ema, unet, step_decay)

            total_loss += loss.item()
            num_batches += 1

        # Read the LR that was actually used this epoch before stepping.
        lr_now = scheduler.get_last_lr()[0]
        scheduler.step()
        train_loss = total_loss / num_batches
        train_history.append(train_loss)

        # --- Validate (using EMA model) ---
        ema.eval()
        val_total = 0
        val_batches = 0
        # Fork the RNG and reseed inside it so every epoch (and every run)
        # draws the same timesteps and noise for validation. That makes the
        # best-checkpoint choice and cross-model comparison free of sampling
        # noise, while leaving the training RNG stream untouched.
        with torch.no_grad(), torch.random.fork_rng():
            torch.manual_seed(val_seed)
            for latents in val_loader:
                latents = latents.to(device)
                bs = latents.shape[0]
                t = torch.randint(0, timesteps, (bs,), device=device)
                loss = diffusion.p_losses(ema, latents, t)
                val_total += loss.item()
                val_batches += 1

        val_loss = val_total / val_batches
        val_history.append(val_loss)

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save({'model_state_dict': ema.state_dict()}, save_dir / 'best.pt')

        if epoch % 10 == 0 or epoch == epochs - 1:
            print(f"  Epoch {epoch:3d}: train={train_loss:.6f}  val={val_loss:.6f}  lr={lr_now:.2e}")

    print(f"  Best val loss: {best_val_loss:.6f}")

    # Free training model, keep EMA
    del unet, optimizer, scheduler
    torch.cuda.empty_cache()

    return {
        'diffusion': diffusion,
        'ema_model': ema,
        'train_history': train_history,
        'val_history': val_history,
        'best_val_loss': best_val_loss,
        'n_params': n_params,
        'config': unet_config,
        'save_dir': save_dir,
    }

In [None]:
# Run all 6 experiments: for every UNet size, train once on the Baseline-VAE
# latents and once on the MM-Reg-VAE latents. Results are keyed
# '<size>_baseline' / '<size>_mmreg'; each entry keeps the EMA model returned
# by train_diffusion in memory.
results = {}

for size_name, unet_cfg in UNET_CONFIGS.items():
    # Baseline: 20k training subset, validation on the full val latents
    key_base = f'{size_name}_baseline'
    results[key_base] = train_diffusion(
        sub_train_baseline, val_latents_baseline,
        unet_cfg, key_base
    )

    # MM-Reg: same UNet config, latents from the MM-Reg VAE
    key_mm = f'{size_name}_mmreg'
    results[key_mm] = train_diffusion(
        sub_train_mmreg, val_latents_mmreg,
        unet_cfg, key_mm
    )

    # Release cached GPU blocks before moving on to the next model size
    torch.cuda.empty_cache()

print("\n" + "="*60)
print("ALL TRAINING COMPLETE")
print("="*60)

## 6. Results

In [None]:
# Results table
print(f"{'Model':>16} | {'Params':>8} | {'Base Val':>10} | {'MMReg Val':>10} | {'Improv.':>8}")
print("-" * 65)

for size_name in UNET_CONFIGS:
    base_val = results[f'{size_name}_baseline']['best_val_loss']
    mm_val = results[f'{size_name}_mmreg']['best_val_loss']
    n_params = results[f'{size_name}_baseline']['n_params']
    improvement = (base_val - mm_val) / base_val * 100
    print(f"{size_name:>16} | {n_params/1e6:>7.1f}M | {base_val:>10.6f} | {mm_val:>10.6f} | {improvement:>7.1f}%")

print(f"\nBaseline improvement (small -> large): "
      f"{results['small_baseline']['best_val_loss']:.6f} -> {results['large_baseline']['best_val_loss']:.6f} "
      f"({(results['small_baseline']['best_val_loss'] - results['large_baseline']['best_val_loss']) / results['small_baseline']['best_val_loss'] * 100:.1f}%)")
print(f"MM-Reg improvement (small -> large): "
      f"{results['small_mmreg']['best_val_loss']:.6f} -> {results['large_mmreg']['best_val_loss']:.6f} "
      f"({(results['small_mmreg']['best_val_loss'] - results['large_mmreg']['best_val_loss']) / results['small_mmreg']['best_val_loss'] * 100:.1f}%)")

In [None]:
# Plot 1: Validation-loss curves (EMA model) - all 6 runs, one panel per
# UNet size with Baseline and MM-Reg overlaid. Train histories are not drawn.
fig, axes = plt.subplots(1, 3, figsize=(15, 5), sharey=True)

for i, size_name in enumerate(UNET_CONFIGS):
    ax = axes[i]
    # Both VAE variants share the UNet config, so the baseline entry's
    # parameter count is used for the panel title.
    n_params = results[f'{size_name}_baseline']['n_params']

    ax.plot(results[f'{size_name}_baseline']['val_history'], 'b-', label='Baseline', alpha=0.8, linewidth=1.5)
    ax.plot(results[f'{size_name}_mmreg']['val_history'], 'r-', label='MM-Reg', alpha=0.8, linewidth=1.5)

    ax.set_title(f'{size_name.capitalize()} ({n_params/1e6:.1f}M)', fontsize=13)
    ax.set_xlabel('Epoch')
    ax.grid(True, alpha=0.3)
    ax.legend(fontsize=10)
    # The y-axis is shared, so only the left-most panel gets a label
    if i == 0:
        ax.set_ylabel('Val Loss (EMA)')

plt.suptitle(f'Diffusion Val Loss by Model Size ({N_TRAIN//1000}k samples)', fontsize=14, y=1.02)
plt.tight_layout()
plt.savefig('./checkpoints/capacity_val_curves.png', dpi=150, bbox_inches='tight')
plt.show()

In [None]:
# Plot 2: Best val loss vs model size
sizes = list(UNET_CONFIGS.keys())
params = [results[f'{s}_baseline']['n_params'] / 1e6 for s in sizes]
base_vals = [results[f'{s}_baseline']['best_val_loss'] for s in sizes]
mm_vals = [results[f'{s}_mmreg']['best_val_loss'] for s in sizes]
improvements = [(b - m) / b * 100 for b, m in zip(base_vals, mm_vals)]

fig, axes = plt.subplots(1, 2, figsize=(13, 5))

# Left: val loss vs params
axes[0].plot(params, base_vals, 'bo-', label='Baseline', markersize=10, linewidth=2)
axes[0].plot(params, mm_vals, 'rs-', label='MM-Reg', markersize=10, linewidth=2)
axes[0].set_xlabel('Model Parameters (M)', fontsize=12)
axes[0].set_ylabel('Best Val Loss', fontsize=12)
axes[0].set_title('Diffusion Quality vs Model Capacity', fontsize=13)
axes[0].legend(fontsize=11)
axes[0].grid(True, alpha=0.3)
for j, s in enumerate(sizes):
    axes[0].annotate(s, (params[j], base_vals[j]), textcoords='offset points',
                     xytext=(0, 10), ha='center', fontsize=9)

# Right: MM-Reg improvement % by size
x = range(len(sizes))
axes[1].bar(x, improvements, color=['#2ecc71', '#27ae60', '#1e8449'], alpha=0.8)
axes[1].set_xlabel('Model Size', fontsize=12)
axes[1].set_ylabel('MM-Reg Improvement (%)', fontsize=12)
axes[1].set_title('MM-Reg Advantage by Model Capacity', fontsize=13)
axes[1].set_xticks(x)
axes[1].set_xticklabels([f'{s}\n({p:.0f}M)' for s, p in zip(sizes, params)])
axes[1].grid(True, alpha=0.3, axis='y')
for j, v in enumerate(improvements):
    axes[1].text(j, v + 0.5, f'{v:.1f}%', ha='center', fontsize=11, fontweight='bold')

plt.tight_layout()
plt.savefig('./checkpoints/capacity_comparison.png', dpi=150, bbox_inches='tight')
plt.show()

## 7. Generate Samples

In [None]:
def decode_and_plot(vae, latents, title, save_path):
    """Decode latents with the VAE and show/save them as a 2x8 image grid.

    At most 16 images are drawn. When fewer are supplied the remaining grid
    cells are left blank instead of raising ``IndexError``.

    Args:
        vae: model exposing ``decode(latents)``; switched to eval mode here.
        latents: batch of latents; moved to the module-level ``device``.
        title: figure title.
        save_path: where the PNG is written.
    """
    vae.eval()
    with torch.no_grad():
        images = vae.decode(latents.to(device))

    fig, axes = plt.subplots(2, 8, figsize=(16, 4))
    fig.suptitle(title, fontsize=14)
    # axes.flat walks the grid row by row; zip stops at whichever of
    # (16 cells, available images) runs out first.
    for ax, image in zip(axes.flat, images):
        img = image.cpu().permute(1, 2, 0).numpy()
        # Map from [-1, 1] to [0, 1] for imshow and clamp any overshoot
        img = ((img + 1) / 2).clip(0, 1)
        ax.imshow(img)
    # Hide ticks and frames on every cell, including unused ones
    for ax in axes.flat:
        ax.axis('off')
    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.show()

# Generate 16 samples with the MM-Reg diffusion model of each size for
# comparison (baseline runs are not sampled here).
latent_shape = (16, 4, latent_size, latent_size)

for size_name in UNET_CONFIGS:
    print(f"\n--- Generating with {size_name} model ---")

    diff = results[f'{size_name}_mmreg']['diffusion']

    # Rebuild the UNet and load the best EMA checkpoint written during training
    ema_mm = SimpleUNet(in_channels=4, **UNET_CONFIGS[size_name]).to(device)
    ckpt = torch.load(f'./checkpoints/capacity_{size_name}_mmreg/best.pt', map_location=device)
    ema_mm.load_state_dict(ckpt['model_state_dict'])
    # A freshly constructed module starts in train mode; switch to eval so
    # any dropout/normalisation layers behave deterministically while sampling.
    ema_mm.eval()

    gen_latents = diff.sample(ema_mm, latent_shape, progress=True)

    n_params = results[f'{size_name}_mmreg']['n_params']
    decode_and_plot(
        vae_mmreg, gen_latents,
        f'MM-Reg {size_name.capitalize()} ({n_params/1e6:.0f}M params)',
        f'./checkpoints/capacity_samples_{size_name}_mmreg.png'
    )

    # Drop the model and samples before building the next (larger) UNet
    del ema_mm, gen_latents
    torch.cuda.empty_cache()

## 8. Summary

In [None]:
def to_python(val):
    """Return ``val`` as a plain Python number suitable for ``json.dump``.

    Objects exposing ``.item()`` (NumPy / torch scalars) are unwrapped with
    it; anything else is coerced with ``float``.
    """
    return val.item() if hasattr(val, 'item') else float(val)

summary = {
    'experiment': 'diffusion_capacity_investigation',
    'n_train': CONFIG['n_train'],
    'epochs': CONFIG['diffusion_epochs'],
    'improvements': ['gradient_clipping', 'ema', 'cosine_lr'],
    'vae_results': {
        'baseline': {k: to_python(v) for k, v in vae_results_baseline.items()},
        'mmreg': {k: to_python(v) for k, v in vae_results_mmreg.items()}
    },
    'diffusion_results': {}
}

for size_name in UNET_CONFIGS:
    base_val = results[f'{size_name}_baseline']['best_val_loss']
    mm_val = results[f'{size_name}_mmreg']['best_val_loss']
    summary['diffusion_results'][size_name] = {
        'n_params': results[f'{size_name}_baseline']['n_params'],
        'config': {k: list(v) if isinstance(v, tuple) else v for k, v in UNET_CONFIGS[size_name].items()},
        'baseline_best_val': to_python(base_val),
        'mmreg_best_val': to_python(mm_val),
        'improvement_pct': to_python((base_val - mm_val) / base_val * 100),
    }

with open('./checkpoints/capacity_investigation_summary.json', 'w') as f:
    json.dump(summary, f, indent=2)

print("="*60)
print("CAPACITY INVESTIGATION SUMMARY")
print("="*60)

print(f"\nVAE (trained on full dataset):")
print(f"  Baseline: Recon MSE={vae_results_baseline['recon_mse']:.6f}, Pearson={vae_results_baseline['pearson']:.4f}")
print(f"  MM-Reg:   Recon MSE={vae_results_mmreg['recon_mse']:.6f}, Pearson={vae_results_mmreg['pearson']:.4f}")

print(f"\nDiffusion ({CONFIG['n_train']//1000}k samples, {CONFIG['diffusion_epochs']} epochs):")
print(f"Training: grad_clip={CONFIG['grad_clip']}, EMA={CONFIG['ema_decay']}, cosine LR")
print(f"\n{'Model':>8} | {'Params':>8} | {'Base Val':>10} | {'MMReg Val':>10} | {'Improv':>7} | {'Base Drop':>10}")
print("-" * 72)

small_base = results['small_baseline']['best_val_loss']

for size_name in UNET_CONFIGS:
    base_val = results[f'{size_name}_baseline']['best_val_loss']
    mm_val = results[f'{size_name}_mmreg']['best_val_loss']
    n_params = results[f'{size_name}_baseline']['n_params']
    improv = (base_val - mm_val) / base_val * 100
    base_drop = (small_base - base_val) / small_base * 100
    print(f"{size_name:>8} | {n_params/1e6:>7.1f}M | {base_val:>10.6f} | {mm_val:>10.6f} | {improv:>6.1f}% | {base_drop:>+9.1f}%")

print(f"\nResults saved to ./checkpoints/capacity_investigation_summary.json")