# MM-Reg: Dataset Scaling Experiment on CelebA

Does MM-Reg's diffusion advantage grow, shrink, or stay constant as dataset size increases?

## Design:
1. Fine-tune two SD VAEs once on the full CelebA training split (~160k images): **Baseline** (no MM-Reg) vs **MM-Reg**
2. Encode all images with both VAEs (cached)
3. For each dataset scale (10k, 20k, 50k, 100k, 160k):
   - Subsample cached latents
   - Train diffusion on baseline latents (50 epochs)
   - Train diffusion on MM-Reg latents (50 epochs)
4. Compare diffusion val loss across scales

**Validation set is fixed** across all scales for fair comparison.

## 1. Setup

In [None]:
# Clone repository (start from a fresh copy: remove any previous checkout first)
!rm -rf MMReg_diffusion_generative 2>/dev/null
!git clone https://github.com/laurent-cheret/MMReg_diffusion_generative.git
# %cd is an IPython magic that changes the kernel's working directory, so the
# relative paths used below ('./data', './embeddings', './checkpoints') and the
# `src.*` imports (sys.path gets '.') resolve inside the cloned repo.
%cd MMReg_diffusion_generative

In [None]:
# Install dependencies (-q keeps pip output quiet)
# Deep-learning stack: PyTorch/torchvision plus the HuggingFace diffusers/transformers/accelerate libraries.
!pip install -q torch torchvision diffusers transformers accelerate
# Config parsing, progress bars, scipy (scipy.stats.pearsonr is used in the VAE evaluation), and plotting.
!pip install -q pyyaml tqdm scipy scikit-learn matplotlib
# Dataset download helpers; presumably required by src.data.dataset for the
# 'huggingface' / Google-Drive CelebA sources — confirm.
!pip install -q datasets gdown

In [None]:
import sys
sys.path.insert(0, '.')

import torch
import torch.nn as nn
import os
import json
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
from torch.utils.data import DataLoader

# Select the compute device and report the runtime environment.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"PyTorch: {torch.__version__}")
print(f"Device: {device}")
if device == 'cuda':
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    # The device-properties field is `total_memory` (bytes); there is no
    # `total_mem` attribute, so the old line raised AttributeError on GPUs.
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

In [None]:
# Experiment configuration
CONFIG = {
    # Data
    'data_root': './data',
    'image_size': 128,
    'batch_size': 64,

    # PCA
    'pca_components': 256,

    # VAE Training (on full dataset, once)
    'vae_epochs': 5,
    'vae_lr': 1e-5,
    'lambda_mm': 1.0,
    'beta': 1e-6,

    # Diffusion Training (per scale)
    'diffusion_epochs': 50,
    'diffusion_lr': 1e-4,
    'diffusion_timesteps': 1000,

    # Dataset scales to test
    'scales': [10000, 20000, 50000, 100000, 160000],
}

print("Configuration:")
for k, v in CONFIG.items():
    print(f"  {k}: {v}")

## 2. Load CelebA & Pre-compute PCA

In [None]:
from src.data.dataset import (
    get_celeba_dataset,
    compute_pca_embeddings_celeba,
    get_dataset_and_loader
)

# === CONFIGURATION: Choose your data source ===
# Option 1: HuggingFace (recommended). Note: the source actually used is decided
# solely by USE_DRIVE below; any value other than USE_DRIVE=True means HuggingFace.
USE_HUGGINGFACE = True

# Option 2: Google Drive (set USE_DRIVE = True and point DRIVE_PATHS at your files)
USE_DRIVE = False
DRIVE_PATHS = {
    'images_zip': '/content/drive/MyDrive/DATASETS/CelebA/img_align_celeba.zip',
    'attr_file': '/content/drive/MyDrive/DATASETS/CelebA/list_attr_celeba.txt',
    'partition_file': '/content/drive/MyDrive/DATASETS/CelebA/list_eval_partition.txt'
}

if USE_DRIVE:
    from google.colab import drive
    drive.mount('/content/drive')
    source, drive_paths = 'drive', DRIVE_PATHS
else:
    source, drive_paths = 'huggingface', None

# Both splits share every loader argument except `split`. fixed_transform=True is
# requested so PCA and latent encoding see the same view of each image
# (presumably no random augmentation — confirm in get_celeba_dataset).
_celeba_common = dict(
    root=CONFIG['data_root'],
    image_size=CONFIG['image_size'],
    fixed_transform=True,
    source=source,
    drive_paths=drive_paths,
)

print(f"Loading CelebA from {source}...")
train_dataset_fixed = get_celeba_dataset(split='train', **_celeba_common)
val_dataset_fixed = get_celeba_dataset(split='val', **_celeba_common)

total_train = len(train_dataset_fixed)
print(f"Train: {total_train}, Val: {len(val_dataset_fixed)}")

# Cap scales at the actual training-set size and always include the full set.
# A requested scale within NEAR_FULL_FRACTION of the full size (e.g. 160,000 vs.
# a ~162k train split) is folded into the full-size run; keeping both would
# train the two most expensive pairs of diffusion models on near-identical data.
# The result is sorted and free of duplicates.
NEAR_FULL_FRACTION = 0.95
CONFIG['scales'] = sorted(
    {s for s in CONFIG['scales'] if s < NEAR_FULL_FRACTION * total_train} | {total_train}
)
print(f"Scales to test: {CONFIG['scales']}")

In [None]:
# Compute PCA embeddings for both splits and cache them to ./embeddings; the VAE
# data loaders in section 3 read them back via `pca_embeddings_path`.
# NOTE(review): each split is passed to compute_pca_embeddings_celeba separately;
# confirm whether the val embeddings use their own PCA fit or the train basis.
os.makedirs('./embeddings', exist_ok=True)


def _pca_embeddings(dataset, message, out_path):
    """Announce `message`, compute PCA embeddings of `dataset`, save them to `out_path`, and return them."""
    print(message)
    embeddings = compute_pca_embeddings_celeba(
        dataset, n_components=CONFIG['pca_components'], batch_size=256
    )
    torch.save(embeddings, out_path)
    return embeddings


train_pca = _pca_embeddings(
    train_dataset_fixed,
    "Computing PCA embeddings for training set...",
    './embeddings/celeba_train_pca.pt',
)
val_pca = _pca_embeddings(
    val_dataset_fixed,
    "Computing PCA embeddings for validation set...",
    './embeddings/celeba_val_pca.pt',
)

print(f"Train PCA: {train_pca.shape}, Val PCA: {val_pca.shape}")

## 3. Fine-tune VAEs on Full CelebA

Train two SD VAEs once on the full dataset. These are reused across all scales.

### 3.1 Baseline VAE (no MM-Reg)

In [None]:
from src.models.vae_wrapper import load_vae
from src.models.losses import VAELoss
from src.trainer import MMRegTrainer

print("="*60)
print("TRAINING BASELINE VAE (no MM-Reg) on full CelebA")
print("="*60)

# Fresh SD VAE. lambda_mm=0.0 gives the MM-Reg term zero weight, so this arm
# trains without MM-Reg (assuming VAELoss scales that term by lambda_mm).
vae_baseline = load_vae(device=device)
loss_baseline = VAELoss(lambda_mm=0.0, beta=CONFIG['beta'])

# Loader arguments shared by the train and val splits.
_vae_loader_kwargs = dict(
    dataset_name='celeba',
    root=CONFIG['data_root'],
    image_size=CONFIG['image_size'],
    batch_size=CONFIG['batch_size'],
    num_workers=2,
    celeba_source=source,
    celeba_drive_paths=drive_paths,
)
# Each split is paired with its own cached PCA embeddings.
train_dataset_base, train_loader_base = get_dataset_and_loader(
    split='train',
    pca_embeddings_path='./embeddings/celeba_train_pca.pt',
    **_vae_loader_kwargs,
)
val_dataset_base, val_loader_base = get_dataset_and_loader(
    split='val',
    pca_embeddings_path='./embeddings/celeba_val_pca.pt',
    **_vae_loader_kwargs,
)

optimizer_baseline = torch.optim.AdamW(vae_baseline.parameters(), lr=CONFIG['vae_lr'])

trainer_baseline = MMRegTrainer(
    vae=vae_baseline,
    loss_fn=loss_baseline,
    optimizer=optimizer_baseline,
    train_loader=train_loader_base,
    val_loader=val_loader_base,
    device=device,
    save_dir='./checkpoints/scaling_baseline_vae',
)
trainer_baseline.train(num_epochs=CONFIG['vae_epochs'])

### 3.2 MM-Reg VAE

In [None]:
print("="*60)
print("TRAINING MM-REG VAE on full CelebA")
print("="*60)

# Fresh SD VAE trained with the MM-Reg term active ('correlation' variant,
# weighted by CONFIG['lambda_mm']).
vae_mmreg = load_vae(device=device)
loss_mmreg = VAELoss(
    lambda_mm=CONFIG['lambda_mm'],
    beta=CONFIG['beta'],
    mm_variant='correlation',
)

# Loader arguments shared by the train and val splits.
_vae_loader_kwargs = dict(
    dataset_name='celeba',
    root=CONFIG['data_root'],
    image_size=CONFIG['image_size'],
    batch_size=CONFIG['batch_size'],
    num_workers=2,
    celeba_source=source,
    celeba_drive_paths=drive_paths,
)
# Each split is paired with its own cached PCA embeddings.
train_dataset_mm, train_loader_mm = get_dataset_and_loader(
    split='train',
    pca_embeddings_path='./embeddings/celeba_train_pca.pt',
    **_vae_loader_kwargs,
)
val_dataset_mm, val_loader_mm = get_dataset_and_loader(
    split='val',
    pca_embeddings_path='./embeddings/celeba_val_pca.pt',
    **_vae_loader_kwargs,
)

optimizer_mmreg = torch.optim.AdamW(vae_mmreg.parameters(), lr=CONFIG['vae_lr'])

trainer_mmreg = MMRegTrainer(
    vae=vae_mmreg,
    loss_fn=loss_mmreg,
    optimizer=optimizer_mmreg,
    train_loader=train_loader_mm,
    val_loader=val_loader_mm,
    device=device,
    save_dir='./checkpoints/scaling_mmreg_vae',
)
trainer_mmreg.train(num_epochs=CONFIG['vae_epochs'])

## 4. Evaluate VAEs & Encode Full Dataset

In [None]:
from src.models.losses import pairwise_distances, get_upper_triangular
from scipy.stats import pearsonr

def quick_evaluate(vae, val_loader, val_pca, name):
    """Evaluate a VAE on a validation loader.

    Returns a dict with:
      - 'recon_mse': mean squared reconstruction error, weighted by batch size
        over the whole loader (deterministic encoding, `sample=False`);
      - 'pearson': Pearson correlation between the pairwise distances of the
        first (up to) 500 flattened latents and those of the first 500 rows
        of `val_pca`.

    NOTE(review): the Pearson pairing assumes `val_loader` yields samples in
    the same order as `val_pca` (no shuffling) — confirm in get_dataset_and_loader.
    """
    vae.eval()
    latent_chunks = []
    weighted_sq_err = 0
    seen = 0

    with torch.no_grad():
        for batch in tqdm(val_loader, desc=f"Eval {name}"):
            x = batch[0].to(device)
            out = vae(x, sample=False)
            bsz = x.shape[0]
            weighted_sq_err += ((out['x_recon'] - x) ** 2).mean().item() * bsz
            seen += bsz
            latent_chunks.append(out['latent_flat'].cpu())

    latents = torch.cat(latent_chunks, dim=0)

    # Distance-structure agreement on a 500-sample subset (pairwise is O(k^2)).
    k = min(500, len(latents))
    lat_dists = get_upper_triangular(pairwise_distances(latents[:k])).numpy()
    pca_dists = get_upper_triangular(pairwise_distances(val_pca[:k])).numpy()
    pearson, _ = pearsonr(lat_dists, pca_dists)

    recon_mse = weighted_sq_err / seen
    print(f"{name}: Recon MSE={recon_mse:.6f}, Pearson={pearson:.4f}")
    return {'recon_mse': recon_mse, 'pearson': pearson}

print("\nVAE Evaluation (on full val set):")
vae_results_baseline = quick_evaluate(vae_baseline, val_loader_base, val_pca, "Baseline")
vae_results_mmreg = quick_evaluate(vae_mmreg, val_loader_mm, val_pca, "MM-Reg")

In [None]:
from src.diffusion_trainer import encode_dataset

# Plain, unshuffled loaders over the fixed-transform datasets (no PCA wrapper),
# so cached latent row i should match dataset index i (assuming encode_dataset
# preserves loader order).
def _ordered_loader(dataset):
    """Build a non-shuffling DataLoader with the notebook's batch size."""
    return DataLoader(dataset, batch_size=CONFIG['batch_size'], shuffle=False, num_workers=2)


train_loader_simple = _ordered_loader(train_dataset_fixed)
val_loader_simple = _ordered_loader(val_dataset_fixed)

# Encode the full train and val splits once per VAE; later scales just slice these.
print("Encoding full dataset with Baseline VAE...")
train_latents_baseline = encode_dataset(vae_baseline, train_loader_simple, device)
val_latents_baseline = encode_dataset(vae_baseline, val_loader_simple, device)
print(f"Baseline latents - Train: {train_latents_baseline.shape}, Val: {val_latents_baseline.shape}")

print("\nEncoding full dataset with MM-Reg VAE...")
train_latents_mmreg = encode_dataset(vae_mmreg, train_loader_simple, device)
val_latents_mmreg = encode_dataset(vae_mmreg, val_loader_simple, device)
print(f"MM-Reg latents - Train: {train_latents_mmreg.shape}, Val: {val_latents_mmreg.shape}")

# Cache to disk so the diffusion sweep can be rerun without re-encoding.
for _tensor, _path in (
    (train_latents_baseline, './embeddings/scaling_train_latents_baseline.pt'),
    (val_latents_baseline, './embeddings/scaling_val_latents_baseline.pt'),
    (train_latents_mmreg, './embeddings/scaling_train_latents_mmreg.pt'),
    (val_latents_mmreg, './embeddings/scaling_val_latents_mmreg.pt'),
):
    torch.save(_tensor, _path)
print("\nLatents cached to ./embeddings/")

## 5. Scaling Experiment: Diffusion Training at Each Scale

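For each dataset size N, take the first N cached training latents (a deterministic prefix, so smaller subsets are nested inside larger ones) and train fresh baseline and MM-Reg diffusion models.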
Validation set stays fixed across all scales.

In [None]:
from src.models.diffusion import SimpleUNet, GaussianDiffusion
from src.diffusion_trainer import DiffusionTrainer
import copy

# Spatial size of the cached latents (third dim); the UNet below takes 4-channel
# latents of this height/width.
latent_size = train_latents_baseline.shape[2]  # 16 for 128x128 images
print(f"Latent size: {latent_size}x{latent_size}")

# === NORMALIZE LATENTS ===
# Diffusion assumes data is roughly N(0,1); raw VAE latents are not. Each VAE's
# latents are standardized with one scalar mean/std taken over its FULL training
# set, and the same stats are reused for every scale and for validation.
print("\nNormalizing latents...")
baseline_mean, baseline_std = train_latents_baseline.mean(), train_latents_baseline.std()
mmreg_mean, mmreg_std = train_latents_mmreg.mean(), train_latents_mmreg.std()

print(f"Baseline latents - mean: {baseline_mean:.4f}, std: {baseline_std:.4f}")
print(f"MM-Reg latents   - mean: {mmreg_mean:.4f}, std: {mmreg_std:.4f}")


def _standardize(x, mu, sigma):
    """Return (x - mu) / sigma."""
    return (x - mu) / sigma


train_latents_baseline_norm = _standardize(train_latents_baseline, baseline_mean, baseline_std)
val_latents_baseline_norm = _standardize(val_latents_baseline, baseline_mean, baseline_std)
train_latents_mmreg_norm = _standardize(train_latents_mmreg, mmreg_mean, mmreg_std)
val_latents_mmreg_norm = _standardize(val_latents_mmreg, mmreg_mean, mmreg_std)

# Persist the stats so generated latents can be mapped back before decoding.
latent_stats = {
    arm: {'mean': mu.item(), 'std': sigma.item()}
    for arm, (mu, sigma) in (
        ('baseline', (baseline_mean, baseline_std)),
        ('mmreg', (mmreg_mean, mmreg_std)),
    )
}
torch.save(latent_stats, './embeddings/latent_stats.pt')
print("Latent stats saved for denormalization during generation")

# === EMA HELPER ===
def update_ema(ema_model, model, decay=0.9999, step=None):
    """Blend `model`'s parameters into `ema_model` in place (exponential moving average).

    For every parameter pair (matched by registration order, so both modules
    must share an architecture): ema_p <- decay * ema_p + (1 - decay) * p.
    Buffers (e.g. normalization running statistics) are copied verbatim rather
    than averaged, so the EMA model never keeps stale initial buffers.

    Args:
        ema_model: module holding the averaged weights; updated in place.
        model: module being trained; only read.
        decay: maximum EMA decay.
        step: optional 0-based optimizer-step counter. When given, the decay is
            warmed up as min(decay, (1 + step) / (10 + step)). Without warm-up,
            decay=0.9999 still leaves ~half of the EMA weight on the random
            initialization after ~7k steps.
    """
    if step is not None:
        decay = min(decay, (1 + step) / (10 + step))
    with torch.no_grad():
        for ema_p, p in zip(ema_model.parameters(), model.parameters()):
            ema_p.data.mul_(decay).add_(p.data, alpha=1 - decay)
        for ema_b, b in zip(ema_model.buffers(), model.buffers()):
            ema_b.copy_(b)

# Store results
scaling_results = {}

# Seed reset before building each arm's model, so at a given scale the baseline
# and MM-Reg runs start from identical UNet weights and (to the extent training
# draws from the global torch RNG: shuffling, timesteps, noise) the same random
# stream. This gives a paired comparison.
DIFFUSION_SEED = CONFIG.get('diffusion_seed', 0)
EMA_MAX_DECAY = 0.9999


def _train_diffusion_arm(tag, train_latents, val_latents, n):
    """Train a fresh latent-diffusion UNet (with EMA) for one arm at one scale.

    Args:
        tag: 'baseline' or 'mmreg'; selects ./checkpoints/scaling_diff_{tag}_{n}.
        train_latents: normalized training latents, already subsampled to n.
        val_latents: normalized validation latents (fixed across scales).
        n: number of training samples (used for naming).

    Returns:
        (train_history, val_history): the trainer's per-epoch lists of
        {'loss': float}. Val loss comes from DiffusionTrainer.validate(), which is
        presumably evaluated on the raw (non-EMA) model; the EMA copy is never
        passed to the trainer.

    Side effect: writes the final-epoch EMA weights to <save_dir>/best.pt. These
    are the last EMA weights, not a validation-selected checkpoint.
    """
    save_dir = f'./checkpoints/scaling_diff_{tag}_{n}'
    os.makedirs(save_dir, exist_ok=True)
    torch.manual_seed(DIFFUSION_SEED)

    diffusion = GaussianDiffusion(
        num_timesteps=CONFIG['diffusion_timesteps'],
        device=device
    )
    unet = SimpleUNet(
        in_channels=4,
        base_channels=128,
        channel_mult=(1, 2, 4),
        num_res_blocks=2
    ).to(device)

    # EMA model: frozen copy updated after every optimizer step.
    ema = copy.deepcopy(unet)
    for p in ema.parameters():
        p.requires_grad = False

    opt = torch.optim.AdamW(unet.parameters(), lr=CONFIG['diffusion_lr'])
    # Cosine LR schedule, stepped once per epoch.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        opt, T_max=CONFIG['diffusion_epochs']
    )

    trainer = DiffusionTrainer(
        model=unet,
        diffusion=diffusion,
        optimizer=opt,
        train_latents=train_latents,
        val_latents=val_latents,
        batch_size=CONFIG['batch_size'],
        device=device,
        save_dir=save_dir,
        scheduler=scheduler
    )

    num_epochs = CONFIG['diffusion_epochs']
    print(f"Training diffusion for {num_epochs} epochs...")
    step = 0
    for epoch in range(num_epochs):
        trainer.model.train()
        total_loss = 0.0
        num_batches = 0

        for latents in trainer.train_loader:
            latents = latents.to(device)
            t = torch.randint(0, diffusion.num_timesteps, (latents.shape[0],), device=device)

            opt.zero_grad()
            loss = diffusion.p_losses(unet, latents, t)
            loss.backward()
            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(unet.parameters(), max_norm=1.0)
            opt.step()

            # EMA with warm-up: min(max_decay, (1+step)/(10+step)). A fixed 0.9999
            # would leave ~46% of the EMA weight on the random init after the
            # ~7.8k steps of a 10k-sample run (10k/64 batches x 50 epochs).
            update_ema(ema, unet, decay=min(EMA_MAX_DECAY, (1 + step) / (10 + step)))
            step += 1

            total_loss += loss.item()
            num_batches += 1

        scheduler.step()

        train_loss = total_loss / max(num_batches, 1)
        val_loss = trainer.validate()['loss']
        trainer.train_history.append({'loss': train_loss})
        trainer.val_history.append({'loss': val_loss})

        if epoch % 10 == 0 or epoch == num_epochs - 1:
            print(f"  Epoch {epoch}: train={train_loss:.6f}, val={val_loss:.6f}")

    # Save final EMA weights (reloaded later for sample generation).
    torch.save({'model_state_dict': ema.state_dict()}, os.path.join(save_dir, 'best.pt'))

    # Only the histories escape; the UNet, EMA copy, optimizer, and trainer are
    # released when this function returns.
    return trainer.train_history, trainer.val_history


for scale in CONFIG['scales']:
    # Cap at actual dataset size
    n = min(scale, len(train_latents_baseline_norm))
    print(f"\n{'='*60}")
    print(f"SCALE: {n:,} training samples")
    print(f"{'='*60}")

    # Subsample training latents: the first n rows (deterministic, so smaller
    # subsets are nested inside larger ones). Validation is never subsampled.
    print(f"\n--- Baseline Diffusion ({n:,} samples) ---")
    base_train_hist, base_val_hist = _train_diffusion_arm(
        'baseline', train_latents_baseline_norm[:n], val_latents_baseline_norm, n
    )
    torch.cuda.empty_cache()

    print(f"\n--- MM-Reg Diffusion ({n:,} samples) ---")
    mm_train_hist, mm_val_hist = _train_diffusion_arm(
        'mmreg', train_latents_mmreg_norm[:n], val_latents_mmreg_norm, n
    )
    torch.cuda.empty_cache()

    # Record results (final-epoch losses; improvement > 0 means MM-Reg is lower).
    base_val = base_val_hist[-1]['loss']
    mm_val = mm_val_hist[-1]['loss']
    improvement = (base_val - mm_val) / base_val * 100

    scaling_results[n] = {
        'baseline_train': base_train_hist[-1]['loss'],
        'baseline_val': base_val,
        'mmreg_train': mm_train_hist[-1]['loss'],
        'mmreg_val': mm_val,
        'improvement_pct': improvement,
        'baseline_history': base_val_hist,
        'mmreg_history': mm_val_hist,
    }

    print(f"\nScale {n:,}: Baseline val={base_val:.6f}, MM-Reg val={mm_val:.6f}, Improvement={improvement:.1f}%")

print("\n" + "="*60)
print("ALL SCALES COMPLETE")
print("="*60)

## 6. Results

In [None]:
# Results table: one row per training-set size, in ascending order.
# Improvement > 0 means MM-Reg reached a lower diffusion val loss than baseline.
print(f"{'Scale':>10} | {'Baseline Val':>14} | {'MM-Reg Val':>14} | {'Improvement':>12}")
print("-" * 60)
for n, r in sorted(scaling_results.items()):
    print(f"{n:>10,} | {r['baseline_val']:>14.6f} | {r['mmreg_val']:>14.6f} | {r['improvement_pct']:>11.1f}%")

In [None]:
# Plot 1: diffusion val loss vs dataset size (left) and MM-Reg's relative
# improvement over baseline at each size (right).
scales = sorted(scaling_results.keys())
base_vals = [scaling_results[n]['baseline_val'] for n in scales]
mm_vals = [scaling_results[n]['mmreg_val'] for n in scales]
improvements = [scaling_results[n]['improvement_pct'] for n in scales]
scale_labels = [f'{s//1000}k' for s in scales]

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Left: absolute val loss on a log-x axis
axes[0].plot(scales, base_vals, 'bo-', label='Baseline', markersize=8)
axes[0].plot(scales, mm_vals, 'rs-', label='MM-Reg', markersize=8)
axes[0].set_xlabel('Training Set Size', fontsize=12)
axes[0].set_ylabel('Diffusion Val Loss', fontsize=12)
axes[0].set_title('Diffusion Val Loss vs Dataset Size', fontsize=13)
axes[0].legend(fontsize=11)
axes[0].grid(True, alpha=0.3)
axes[0].set_xscale('log')
axes[0].set_xticks(scales)
axes[0].set_xticklabels(scale_labels)
# On a narrow log range matplotlib labels minor ticks too; those would clash
# with the custom "10k"-style labels above.
axes[0].minorticks_off()

# Right: improvement %. Green = MM-Reg has lower val loss, red = MM-Reg is
# worse; a uniformly green bar would hide regressions.
bar_colors = ['green' if v >= 0 else 'red' for v in improvements]
axes[1].bar(range(len(scales)), improvements, color=bar_colors, alpha=0.7)
axes[1].axhline(0, color='black', linewidth=0.8)
axes[1].set_xlabel('Training Set Size', fontsize=12)
axes[1].set_ylabel('MM-Reg Improvement (%)', fontsize=12)
axes[1].set_title('MM-Reg Advantage vs Dataset Size', fontsize=13)
axes[1].set_xticks(range(len(scales)))
axes[1].set_xticklabels(scale_labels)
axes[1].grid(True, alpha=0.3, axis='y')
# Put each value label just past the end of its bar: above positive bars and
# below negative ones, so it never overlaps the bar itself.
for i, v in enumerate(improvements):
    offset, valign = (0.5, 'bottom') if v >= 0 else (-0.5, 'top')
    axes[1].text(i, v + offset, f'{v:.1f}%', ha='center', va=valign,
                 fontsize=10, fontweight='bold')

plt.tight_layout()
plt.savefig('./checkpoints/scaling_results.png', dpi=150, bbox_inches='tight')
plt.show()

In [None]:
# Plot 2: Training curves at each scale
n_scales = len(scales)
fig, axes = plt.subplots(1, n_scales, figsize=(4 * n_scales, 4), sharey=True)

if n_scales == 1:
    axes = [axes]

for i, n in enumerate(scales):
    r = scaling_results[n]
    base_curve = [h['loss'] for h in r['baseline_history']]
    mm_curve = [h['loss'] for h in r['mmreg_history']]

    axes[i].plot(base_curve, 'b-', label='Baseline', alpha=0.8)
    axes[i].plot(mm_curve, 'r-', label='MM-Reg', alpha=0.8)
    axes[i].set_title(f'{n//1000}k samples', fontsize=12)
    axes[i].set_xlabel('Epoch')
    axes[i].grid(True, alpha=0.3)
    if i == 0:
        axes[i].set_ylabel('Val Loss')
    axes[i].legend(fontsize=9)

plt.suptitle('Diffusion Val Loss Curves Across Scales', fontsize=14, y=1.02)
plt.tight_layout()
plt.savefig('./checkpoints/scaling_curves.png', dpi=150, bbox_inches='tight')
plt.show()

In [None]:
# Generate sample faces at largest scale for visual comparison
largest_scale = max(scales)

# Load latent normalization stats
latent_stats = torch.load('./embeddings/latent_stats.pt')
print(f"Latent stats loaded:")
print(f"  Baseline: mean={latent_stats['baseline']['mean']:.4f}, std={latent_stats['baseline']['std']:.4f}")
print(f"  MM-Reg:   mean={latent_stats['mmreg']['mean']:.4f}, std={latent_stats['mmreg']['std']:.4f}")

# Reload the saved diffusion weights ("best.pt") from the largest scale
print(f"\nLoading best models from {largest_scale//1000}k scale...")

diffusion_gen = GaussianDiffusion(
    num_timesteps=CONFIG['diffusion_timesteps'],
    device=device
)


def _load_unet(ckpt_path):
    """Build the experiment's UNet, load `model_state_dict` from `ckpt_path`, and switch it to eval mode."""
    unet = SimpleUNet(in_channels=4, base_channels=128, channel_mult=(1, 2, 4), num_res_blocks=2).to(device)
    ckpt = torch.load(ckpt_path, map_location=device)
    unet.load_state_dict(ckpt['model_state_dict'])
    # A freshly built nn.Module is in training mode; sampling must run in eval
    # mode so train-only behavior (e.g. dropout, if SimpleUNet has any) is off.
    unet.eval()
    return unet


unet_gen_base = _load_unet(f'./checkpoints/scaling_diff_baseline_{largest_scale}/best.pt')
unet_gen_mm = _load_unet(f'./checkpoints/scaling_diff_mmreg_{largest_scale}/best.pt')

# Generate (in normalized space)
NUM_SAMPLES = 16
SAMPLE_SEED = 0
latent_shape = (NUM_SAMPLES, 4, latent_size, latent_size)

with torch.no_grad():
    # Re-seeding before each run gives both models the same random stream, so
    # (assuming sample() draws its noise from the global torch RNG) the two grids
    # start from identical noise and are comparable sample-for-sample.
    print("Generating baseline samples...")
    torch.manual_seed(SAMPLE_SEED)
    gen_latents_base_norm = diffusion_gen.sample(unet_gen_base, latent_shape, progress=True)

    print("Generating MM-Reg samples...")
    torch.manual_seed(SAMPLE_SEED)
    gen_latents_mm_norm = diffusion_gen.sample(unet_gen_mm, latent_shape, progress=True)

# === DENORMALIZE before decoding === (inverse of the standardization used for training)
gen_latents_base = gen_latents_base_norm * latent_stats['baseline']['std'] + latent_stats['baseline']['mean']
gen_latents_mm = gen_latents_mm_norm * latent_stats['mmreg']['std'] + latent_stats['mmreg']['mean']

print(f"\nGenerated latent stats (after denorm):")
print(f"  Baseline: mean={gen_latents_base.mean():.4f}, std={gen_latents_base.std():.4f}")
print(f"  MM-Reg:   mean={gen_latents_mm.mean():.4f}, std={gen_latents_mm.std():.4f}")


# Decode to images
def decode_and_plot(vae, latents, title, save_path, ncols=8):
    """Decode `latents` with `vae`, then show and save them as an image grid.

    Args:
        vae: VAE whose `decode` maps latents to images.
        latents: batch of latents (any batch size).
        title: figure title.
        save_path: where the figure is written.
        ncols: grid columns (rows are derived from the batch size).

    Pixels are mapped from [-1, 1] to [0, 1] for display (assumes the decoder
    outputs values in [-1, 1] — TODO confirm), then clipped.
    """
    vae.eval()
    with torch.no_grad():
        images = vae.decode(latents.to(device))

    num_images = images.shape[0]
    nrows = max(1, -(-num_images // ncols))  # ceil division
    fig, axes = plt.subplots(nrows, ncols, figsize=(2 * ncols, 2 * nrows), squeeze=False)
    fig.suptitle(title, fontsize=14)
    for idx, ax in enumerate(axes.flat):
        ax.axis('off')
        if idx < num_images:
            # .float() first: NumPy cannot represent bfloat16, and fp16 decoders are common.
            img = images[idx].float().cpu().permute(1, 2, 0).numpy()
            ax.imshow(((img + 1) / 2).clip(0, 1))
    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.show()


decode_and_plot(vae_baseline, gen_latents_base,
               f"Generated Faces - Baseline ({largest_scale//1000}k)",
               './checkpoints/scaling_generated_baseline.png')

decode_and_plot(vae_mmreg, gen_latents_mm,
               f"Generated Faces - MM-Reg ({largest_scale//1000}k)",
               './checkpoints/scaling_generated_mmreg.png')

## 7. Summary

In [None]:
def to_python(val):
    """Return a JSON-serializable Python number.

    Objects exposing `.item()` (0-d tensors, NumPy scalars) are unwrapped;
    anything else goes through float().
    """
    if hasattr(val, 'item'):
        return val.item()
    return float(val)

summary = {
    'config': CONFIG,
    'vae_results': {
        'baseline': {k: to_python(v) for k, v in vae_results_baseline.items()},
        'mmreg': {k: to_python(v) for k, v in vae_results_mmreg.items()}
    },
    'scaling_results': {
        str(n): {
            'baseline_val': to_python(r['baseline_val']),
            'mmreg_val': to_python(r['mmreg_val']),
            'improvement_pct': to_python(r['improvement_pct'])
        }
        for n, r in scaling_results.items()
    }
}

os.makedirs('./checkpoints', exist_ok=True)
with open('./checkpoints/scaling_experiment_summary.json', 'w') as f:
    json.dump(summary, f, indent=2)

print("="*60)
print("SCALING EXPERIMENT SUMMARY")
print("="*60)

print(f"\nVAE (trained on full {total_train:,}):")
print(f"  Baseline: Recon MSE={vae_results_baseline['recon_mse']:.6f}, Pearson={vae_results_baseline['pearson']:.4f}")
print(f"  MM-Reg:   Recon MSE={vae_results_mmreg['recon_mse']:.6f}, Pearson={vae_results_mmreg['pearson']:.4f}")

print(f"\nDiffusion scaling ({CONFIG['diffusion_epochs']} epochs each):")
print(f"{'Scale':>10} | {'Baseline':>10} | {'MM-Reg':>10} | {'Improvement':>12}")
print("-" * 50)
for n in sorted(scaling_results.keys()):
    r = scaling_results[n]
    print(f"{n:>10,} | {r['baseline_val']:>10.6f} | {r['mmreg_val']:>10.6f} | {r['improvement_pct']:>11.1f}%")

# Trend analysis: compare MM-Reg's improvement at the largest vs. smallest scale.
# The tolerance is symmetric, so a change smaller than TREND_TOL_PCT percentage
# points in either direction counts as STABLE rather than as a real trend.
TREND_TOL_PCT = 2.0
imps = [scaling_results[n]['improvement_pct'] for n in sorted(scaling_results.keys())]
if len(imps) < 2:
    trend = "UNDETERMINED - fewer than two scales, no trend to assess"
elif imps[-1] > imps[0] + TREND_TOL_PCT:
    trend = "WIDENING - MM-Reg advantage grows with data"
elif imps[-1] < imps[0] - TREND_TOL_PCT:
    trend = "NARROWING - MM-Reg advantage decreases with data"
else:
    trend = "STABLE - MM-Reg advantage is consistent across scales"

print(f"\nTrend: {trend}")
print(f"\nResults saved to ./checkpoints/scaling_experiment_summary.json")