# MM-Reg: CelebA Experiment (SD VAE + Diffusion)

Does MM-Reg regularization improve diffusion generation quality on CelebA faces?

## Design:
1. Fine-tune TWO SD VAEs on full CelebA (~160k): **Baseline** vs **MM-Reg**
2. Encode all images with both VAEs (cached as 4x16x16 latents)
3. Train diffusion models on each set of latents
4. Compare diffusion val loss, generate samples, and evaluate

**Key**: Same architecture and same data; the only difference is MM-Reg during VAE training.

## 1. Setup

In [None]:
# Clone repository
# Delete any checkout left over from an earlier run (error output is discarded),
# fetch a fresh copy, then make the repository root the working directory so the
# `src` package imported in later cells is found through the '.' sys.path entry.
!rm -rf MMReg_diffusion_generative 2>/dev/null
!git clone https://github.com/laurent-cheret/MMReg_diffusion_generative.git
%cd MMReg_diffusion_generative

In [None]:
# Install dependencies
# `-q` keeps pip output quiet. The three commands install, in order:
#   1. the deep-learning stack (torch, torchvision, diffusers, transformers, accelerate)
#   2. general utilities (pyyaml, tqdm, scipy, scikit-learn, matplotlib)
#   3. dataset tooling (datasets, gdown)
!pip install -q torch torchvision diffusers transformers accelerate
!pip install -q pyyaml tqdm scipy scikit-learn matplotlib
!pip install -q datasets gdown

In [None]:
import sys
sys.path.insert(0, '.')

import torch
import torch.nn as nn
import os
import json
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
from torch.utils.data import DataLoader

# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"PyTorch: {torch.__version__}")
print(f"Device: {device}")
if device == 'cuda':
    # The device-properties field holding the capacity in bytes is
    # `total_memory`; `total_mem` does not exist and raised AttributeError.
    gpu_props = torch.cuda.get_device_properties(0)
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {gpu_props.total_memory / 1e9:.1f} GB")

In [None]:
# Experiment configuration: a single dict read by every later cell.
CONFIG = dict(
    # Data
    data_root='./data',
    image_size=128,
    batch_size=64,

    # PCA
    pca_components=256,

    # VAE training (on full dataset); lambda_mm is the weight handed to
    # VAELoss for the MM-Reg run.
    vae_epochs=5,
    vae_lr=1e-5,
    lambda_mm=1.0,
    beta=1e-6,

    # Diffusion training
    diffusion_epochs=50,
    diffusion_lr=1e-4,
    diffusion_timesteps=1000,
)

# Echo the settings so they are captured in the cell output.
print("Configuration:")
print("\n".join(f"  {name}: {value}" for name, value in CONFIG.items()))

## 2. Load CelebA & Pre-compute PCA

In [None]:
from src.data.dataset import (
    get_celeba_dataset,
    compute_pca_embeddings_celeba,
    get_dataset_and_loader
)

# === CONFIGURATION: Choose your data source ===
# Option 1: HuggingFace (recommended)
# NOTE: this flag is not read below; the source is decided by USE_DRIVE alone.
USE_HUGGINGFACE = True

# Option 2: Google Drive
USE_DRIVE = False
DRIVE_PATHS = {
    'images_zip': '/content/drive/MyDrive/DATASETS/CelebA/img_align_celeba.zip',
    'attr_file': '/content/drive/MyDrive/DATASETS/CelebA/list_attr_celeba.txt',
    'partition_file': '/content/drive/MyDrive/DATASETS/CelebA/list_eval_partition.txt'
}

if USE_DRIVE:
    # Drive must be mounted before the paths above can be read.
    from google.colab import drive
    drive.mount('/content/drive')
    source, drive_paths = 'drive', DRIVE_PATHS
else:
    source, drive_paths = 'huggingface', None

# Load full CelebA with fixed transforms for PCA.
# Both splits share every argument except `split`.
print(f"Loading CelebA from {source}...")
celeba_kwargs = dict(
    root=CONFIG['data_root'],
    image_size=CONFIG['image_size'],
    fixed_transform=True,
    source=source,
    drive_paths=drive_paths,
)
train_dataset_fixed = get_celeba_dataset(split='train', **celeba_kwargs)
val_dataset_fixed = get_celeba_dataset(split='val', **celeba_kwargs)

total_train = len(train_dataset_fixed)
print(f"Train: {total_train}, Val: {len(val_dataset_fixed)}")

In [None]:
# Compute PCA embeddings for both splits and cache them on disk; the VAE
# training cells load them again through `pca_embeddings_path`.
os.makedirs('./embeddings', exist_ok=True)

# Settings shared by the train and validation computations.
pca_kwargs = {'n_components': CONFIG['pca_components'], 'batch_size': 256}

print("Computing PCA embeddings for training set...")
train_pca = compute_pca_embeddings_celeba(train_dataset_fixed, **pca_kwargs)
torch.save(train_pca, './embeddings/celeba_train_pca.pt')

print("Computing PCA embeddings for validation set...")
val_pca = compute_pca_embeddings_celeba(val_dataset_fixed, **pca_kwargs)
torch.save(val_pca, './embeddings/celeba_val_pca.pt')

print(f"Train PCA: {train_pca.shape}, Val PCA: {val_pca.shape}")

## 3. Fine-tune VAEs on Full CelebA

Train two SD VAEs on the full dataset:
- **Baseline**: reconstruction + KL only (no MM-Reg)
- **MM-Reg**: reconstruction + KL + MM-Reg regularization

### 3.1 Baseline VAE (no MM-Reg)

In [None]:
from src.models.vae_wrapper import load_vae
from src.models.losses import VAELoss
from src.trainer import MMRegTrainer

banner_base = "=" * 60
print(banner_base)
print("TRAINING BASELINE VAE (no MM-Reg) on full CelebA")
print(banner_base)

vae_baseline = load_vae(device=device)

# Baseline objective: the MM-Reg weight is set to zero.
loss_baseline = VAELoss(lambda_mm=0.0, beta=CONFIG['beta'])

# Arguments common to the train and validation loaders.
base_loader_kwargs = dict(
    dataset_name='celeba',
    root=CONFIG['data_root'],
    image_size=CONFIG['image_size'],
    batch_size=CONFIG['batch_size'],
    num_workers=2,
    celeba_source=source,
    celeba_drive_paths=drive_paths,
)

train_dataset_base, train_loader_base = get_dataset_and_loader(
    split='train',
    pca_embeddings_path='./embeddings/celeba_train_pca.pt',
    **base_loader_kwargs,
)

val_dataset_base, val_loader_base = get_dataset_and_loader(
    split='val',
    pca_embeddings_path='./embeddings/celeba_val_pca.pt',
    **base_loader_kwargs,
)

optimizer_baseline = torch.optim.AdamW(
    vae_baseline.parameters(),
    lr=CONFIG['vae_lr'],
)

# The training-curves cell later reads history.json from this save_dir.
trainer_baseline = MMRegTrainer(
    vae=vae_baseline,
    loss_fn=loss_baseline,
    optimizer=optimizer_baseline,
    train_loader=train_loader_base,
    val_loader=val_loader_base,
    device=device,
    save_dir='./checkpoints/scaling_baseline_vae',
)

trainer_baseline.train(num_epochs=CONFIG['vae_epochs'])

### 3.2 MM-Reg VAE

In [None]:
banner_mm = "=" * 60
print(banner_mm)
print("TRAINING MM-REG VAE on full CelebA")
print(banner_mm)

vae_mmreg = load_vae(device=device)

# Same loss class as the baseline, but with the MM-Reg weight taken from
# CONFIG and the 'correlation' variant selected.
loss_mmreg = VAELoss(
    lambda_mm=CONFIG['lambda_mm'],
    beta=CONFIG['beta'],
    mm_variant='correlation',
)

# Arguments common to the train and validation loaders.
mm_loader_kwargs = dict(
    dataset_name='celeba',
    root=CONFIG['data_root'],
    image_size=CONFIG['image_size'],
    batch_size=CONFIG['batch_size'],
    num_workers=2,
    celeba_source=source,
    celeba_drive_paths=drive_paths,
)

train_dataset_mm, train_loader_mm = get_dataset_and_loader(
    split='train',
    pca_embeddings_path='./embeddings/celeba_train_pca.pt',
    **mm_loader_kwargs,
)

val_dataset_mm, val_loader_mm = get_dataset_and_loader(
    split='val',
    pca_embeddings_path='./embeddings/celeba_val_pca.pt',
    **mm_loader_kwargs,
)

optimizer_mmreg = torch.optim.AdamW(
    vae_mmreg.parameters(),
    lr=CONFIG['vae_lr'],
)

# The training-curves cell later reads history.json from this save_dir.
trainer_mmreg = MMRegTrainer(
    vae=vae_mmreg,
    loss_fn=loss_mmreg,
    optimizer=optimizer_mmreg,
    train_loader=train_loader_mm,
    val_loader=val_loader_mm,
    device=device,
    save_dir='./checkpoints/scaling_mmreg_vae',
)

trainer_mmreg.train(num_epochs=CONFIG['vae_epochs'])

## 4. Evaluate VAEs & Encode Full Dataset

In [None]:
from src.models.losses import pairwise_distances, get_upper_triangular
from scipy.stats import pearsonr

def quick_evaluate(vae, val_loader, val_pca, name):
    """Report reconstruction MSE and latent-vs-PCA distance correlation.

    Runs ``vae(images, sample=False)`` over every batch of ``val_loader``
    without gradients, accumulating a sample-weighted mean squared
    reconstruction error and collecting the ``'latent_flat'`` outputs.
    Pairwise distances among the first ``min(500, N)`` latents are then
    correlated (Pearson) with pairwise distances among the same number of
    leading rows of ``val_pca``.

    Args:
        vae: Model with ``eval()``; calling it returns a mapping containing
            ``'x_recon'`` and ``'latent_flat'``.
        val_loader: Iterable of batches whose first element is the image
            tensor.
        val_pca: PCA embeddings, sliced by row.
            NOTE(review): rows are assumed to line up with the order in which
            ``val_loader`` yields samples - confirm the loader is unshuffled.
        name: Label used in the progress bar and in the printed summary.

    Returns:
        dict with keys ``'recon_mse'`` and ``'pearson'``.
    """
    vae.eval()
    latent_chunks = []
    weighted_error = 0
    seen = 0

    with torch.no_grad():
        for batch in tqdm(val_loader, desc=f"Eval {name}"):
            images = batch[0].to(device)
            batch_count = images.shape[0]
            outputs = vae(images, sample=False)
            # Per-batch mean error, re-weighted by batch size so a smaller
            # final batch does not skew the overall average.
            batch_mse = ((outputs['x_recon'] - images) ** 2).mean().item()
            weighted_error += batch_mse * batch_count
            seen += batch_count
            latent_chunks.append(outputs['latent_flat'].cpu())

    latents = torch.cat(latent_chunks, dim=0)

    # Cap the subset so the pairwise-distance matrices stay small.
    subset = min(500, len(latents))
    latent_dists = get_upper_triangular(pairwise_distances(latents[:subset])).numpy()
    pca_dists = get_upper_triangular(pairwise_distances(val_pca[:subset])).numpy()
    pearson, _pvalue = pearsonr(latent_dists, pca_dists)

    recon_mse = weighted_error / seen
    print(f"{name}: Recon MSE={recon_mse:.6f}, Pearson={pearson:.4f}")
    return {'recon_mse': recon_mse, 'pearson': pearson}

print("\nVAE Evaluation (on full val set):")
# Evaluate both VAEs against the same validation PCA embeddings, baseline first.
eval_jobs = (
    (vae_baseline, val_loader_base, "Baseline"),
    (vae_mmreg, val_loader_mm, "MM-Reg"),
)
vae_results_baseline, vae_results_mmreg = [
    quick_evaluate(model, loader, val_pca, label)
    for model, loader, label in eval_jobs
]

In [None]:
from src.diffusion_trainer import encode_dataset

# Plain, unshuffled dataloaders over the fixed-transform datasets
# (no PCA wrapper), used only for encoding.
encode_loader_kwargs = dict(
    batch_size=CONFIG['batch_size'],
    shuffle=False,
    num_workers=2,
)
train_loader_simple = DataLoader(train_dataset_fixed, **encode_loader_kwargs)
val_loader_simple = DataLoader(val_dataset_fixed, **encode_loader_kwargs)

# Encode the full dataset once per VAE.
print("Encoding full dataset with Baseline VAE...")
train_latents_baseline = encode_dataset(vae_baseline, train_loader_simple, device)
val_latents_baseline = encode_dataset(vae_baseline, val_loader_simple, device)
print(f"Baseline latents - Train: {train_latents_baseline.shape}, Val: {val_latents_baseline.shape}")

print("\nEncoding full dataset with MM-Reg VAE...")
train_latents_mmreg = encode_dataset(vae_mmreg, train_loader_simple, device)
val_latents_mmreg = encode_dataset(vae_mmreg, val_loader_simple, device)
print(f"MM-Reg latents - Train: {train_latents_mmreg.shape}, Val: {val_latents_mmreg.shape}")

# Persist all four latent sets for potential reuse.
latent_cache = (
    (train_latents_baseline, './embeddings/celeba_train_latents_baseline.pt'),
    (val_latents_baseline, './embeddings/celeba_val_latents_baseline.pt'),
    (train_latents_mmreg, './embeddings/celeba_train_latents_mmreg.pt'),
    (val_latents_mmreg, './embeddings/celeba_val_latents_mmreg.pt'),
)
for latents_to_save, cache_path in latent_cache:
    torch.save(latents_to_save, cache_path)
print("\nLatents cached to ./embeddings/")

## 5. Train Diffusion Models

Train two diffusion models on the full set of cached latents:
- **Baseline**: trained on latents from baseline VAE
- **MM-Reg**: trained on latents from MM-Reg VAE

In [None]:
from src.models.diffusion import DiT, GaussianDiffusion
from src.diffusion_trainer import DiffusionTrainer

latent_size = train_latents_baseline.shape[2]  # 16 for 128x128 images
print(f"Latent size: {latent_size}x{latent_size}")
print(f"Training samples: {len(train_latents_baseline):,}")

# One DiT-S recipe, used unchanged for both runs so that the latents are the
# only thing that differs between them.
dit_s_kwargs = dict(
    in_channels=4,
    patch_size=2,
    hidden_size=384,
    depth=12,
    num_heads=6,
    mlp_ratio=4.0,
    input_size=latent_size,
)
separator = '=' * 60

# --- Baseline Diffusion (DiT-S) ---
print(f"\n{separator}")
print("TRAINING BASELINE DIFFUSION (DiT-S)")
print(separator)

diffusion_base = GaussianDiffusion(
    num_timesteps=CONFIG['diffusion_timesteps'],
    device=device,
)

dit_base = DiT(**dit_s_kwargs).to(device)

n_dit_params = sum(p.numel() for p in dit_base.parameters())
print(f"DiT-S params: {n_dit_params:,}")

opt_base = torch.optim.AdamW(dit_base.parameters(), lr=CONFIG['diffusion_lr'])

trainer_diff_baseline = DiffusionTrainer(
    model=dit_base,
    diffusion=diffusion_base,
    optimizer=opt_base,
    train_latents=train_latents_baseline,
    val_latents=val_latents_baseline,
    batch_size=CONFIG['batch_size'],
    device=device,
    save_dir='./checkpoints/celeba_diff_baseline',
)

trainer_diff_baseline.train(num_epochs=CONFIG['diffusion_epochs'])

# --- MM-Reg Diffusion (DiT-S) ---
print(f"\n{separator}")
print("TRAINING MM-REG DIFFUSION (DiT-S)")
print(separator)

diffusion_mm = GaussianDiffusion(
    num_timesteps=CONFIG['diffusion_timesteps'],
    device=device,
)

dit_mm = DiT(**dit_s_kwargs).to(device)

opt_mm = torch.optim.AdamW(dit_mm.parameters(), lr=CONFIG['diffusion_lr'])

trainer_diff_mmreg = DiffusionTrainer(
    model=dit_mm,
    diffusion=diffusion_mm,
    optimizer=opt_mm,
    train_latents=train_latents_mmreg,
    val_latents=val_latents_mmreg,
    batch_size=CONFIG['batch_size'],
    device=device,
    save_dir='./checkpoints/celeba_diff_mmreg',
)

trainer_diff_mmreg.train(num_epochs=CONFIG['diffusion_epochs'])

# Record results: compare the validation loss of the last recorded epoch.
base_val = trainer_diff_baseline.val_history[-1]['loss']
mm_val = trainer_diff_mmreg.val_history[-1]['loss']
# Relative reduction in percent; positive means MM-Reg has the lower loss.
improvement = (base_val - mm_val) / base_val * 100

print(f"\n{separator}")
print(f"Baseline val loss: {base_val:.6f}")
print(f"MM-Reg val loss:   {mm_val:.6f}")
print(f"Improvement:       {improvement:.1f}%")
print(separator)

## 6. Results & Visualization

In [None]:
# Training curves comparison: VAE losses on the left, diffusion validation
# losses on the right.
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
vae_ax, diff_ax = axes

# VAE histories are read back from the files in the trainers' save dirs.
with open('./checkpoints/scaling_baseline_vae/history.json') as hist_file:
    base_vae_hist = json.load(hist_file)
with open('./checkpoints/scaling_mmreg_vae/history.json') as hist_file:
    mm_vae_hist = json.load(hist_file)

# (records, line format, legend label) - blue = baseline, red = MM-Reg,
# solid = train, dashed = validation.
vae_series = (
    (base_vae_hist['train'], 'b-', 'Baseline Train'),
    (base_vae_hist['val'], 'b--', 'Baseline Val'),
    (mm_vae_hist['train'], 'r-', 'MM-Reg Train'),
    (mm_vae_hist['val'], 'r--', 'MM-Reg Val'),
)
for records, line_fmt, legend_label in vae_series:
    vae_ax.plot([rec['loss'] for rec in records], line_fmt, label=legend_label)

# Diffusion curves come straight from the in-memory trainer histories.
base_diff_curve = [rec['loss'] for rec in trainer_diff_baseline.val_history]
mm_diff_curve = [rec['loss'] for rec in trainer_diff_mmreg.val_history]

diff_ax.plot(base_diff_curve, 'b-', label='Baseline Val')
diff_ax.plot(mm_diff_curve, 'r-', label='MM-Reg Val')

# Identical decoration for both panels; only the title differs.
for panel, heading in (
    (vae_ax, 'VAE Training'),
    (diff_ax, 'Diffusion Training (Val Loss)'),
):
    panel.set_xlabel('Epoch')
    panel.set_ylabel('Loss')
    panel.set_title(heading)
    panel.legend()
    panel.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('./checkpoints/celeba_training_curves.png', dpi=150, bbox_inches='tight')
plt.show()

In [None]:
# Generate sample faces for visual comparison
print("Loading best diffusion models...")

diffusion_gen = GaussianDiffusion(
    num_timesteps=CONFIG['diffusion_timesteps'],
    device=device
)

# Rebuild the networks with exactly the DiT-S hyperparameters used for
# training (including mlp_ratio) so the checkpoint state dicts match the
# module layout regardless of the constructor's defaults.
dit_gen_kwargs = dict(
    in_channels=4,
    patch_size=2,
    hidden_size=384,
    depth=12,
    num_heads=6,
    mlp_ratio=4.0,
    input_size=latent_size,
)

# Baseline DiT-S
dit_gen_base = DiT(**dit_gen_kwargs).to(device)
ckpt_base = torch.load('./checkpoints/celeba_diff_baseline/best.pt', map_location=device)
dit_gen_base.load_state_dict(ckpt_base['model_state_dict'])
# Freshly constructed modules start in training mode; switch to inference
# mode before sampling.
dit_gen_base.eval()

# MM-Reg DiT-S
dit_gen_mm = DiT(**dit_gen_kwargs).to(device)
ckpt_mm = torch.load('./checkpoints/celeba_diff_mmreg/best.pt', map_location=device)
dit_gen_mm.load_state_dict(ckpt_mm['model_state_dict'])
dit_gen_mm.eval()

# Generate 16 latents per model (shape: batch, channels, height, width).
print("Generating baseline samples...")
latent_shape = (16, 4, latent_size, latent_size)
gen_latents_base = diffusion_gen.sample(dit_gen_base, latent_shape, progress=True)

print("Generating MM-Reg samples...")
gen_latents_mm = diffusion_gen.sample(dit_gen_mm, latent_shape, progress=True)

# Decode to images
def decode_and_plot(vae, latents, title, save_path):
    """Decode latents with ``vae`` and show them on a 2x8 image grid.

    The latents are moved to the module-level ``device`` and passed to
    ``vae.decode`` under ``torch.no_grad()``. At most 16 decoded images are
    drawn in row-major order; if fewer are available the remaining cells are
    left blank instead of raising ``IndexError``. The figure is saved to
    ``save_path`` and then shown.

    Args:
        vae: Model exposing ``eval()`` and ``decode(latents)``.
        latents: Batch of latents to decode.
        title: Figure title.
        save_path: Destination file for the saved figure.
    """
    vae.eval()
    with torch.no_grad():
        images = vae.decode(latents.to(device))

    fig, axes = plt.subplots(2, 8, figsize=(16, 4))
    fig.suptitle(title, fontsize=14)

    # zip stops at the shorter of (grid cells, decoded images), so short
    # batches are safe and long batches are truncated to the first 16.
    for ax, image in zip(axes.flat, images):
        img = image.cpu().permute(1, 2, 0).numpy()
        # Map from [-1, 1] to [0, 1] for display, clipping any overshoot.
        img = ((img + 1) / 2).clip(0, 1)
        ax.imshow(img)

    # Hide ticks and frames on every cell, including unused ones.
    for ax in axes.flat:
        ax.axis('off')

    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.show()

# Each latent set is decoded with the VAE that produced its training latents.
generated_panels = (
    (vae_baseline, gen_latents_base,
     "Generated Faces - Baseline (DiT-S)",
     './checkpoints/celeba_generated_baseline.png'),
    (vae_mmreg, gen_latents_mm,
     "Generated Faces - MM-Reg (DiT-S)",
     './checkpoints/celeba_generated_mmreg.png'),
)
for decoder, sampled_latents, panel_title, panel_path in generated_panels:
    decode_and_plot(decoder, sampled_latents, panel_title, panel_path)

## 7. Summary

In [None]:
def to_python(val):
    """Convert a scalar-like value into a plain Python number.

    Objects that expose an ``item`` attribute (for example NumPy or tensor
    scalars) are unwrapped through ``val.item()``; anything else is passed
    to ``float``.
    """
    return val.item() if hasattr(val, 'item') else float(val)

# Assemble the JSON-serialisable summary piece by piece; every metric goes
# through to_python so NumPy/tensor scalars become plain numbers.
vae_summary = {
    label: {metric: to_python(value) for metric, value in results.items()}
    for label, results in (
        ('baseline', vae_results_baseline),
        ('mmreg', vae_results_mmreg),
    )
}

diffusion_summary = {
    'baseline_train': to_python(trainer_diff_baseline.train_history[-1]['loss']),
    'baseline_val': to_python(base_val),
    'mmreg_train': to_python(trainer_diff_mmreg.train_history[-1]['loss']),
    'mmreg_val': to_python(mm_val),
    'improvement_pct': to_python(improvement),
}

summary = {
    'config': CONFIG,
    'vae_results': vae_summary,
    'diffusion_results': diffusion_summary,
}

with open('./checkpoints/celeba_experiment_summary.json', 'w') as summary_file:
    json.dump(summary, summary_file, indent=2)

# Human-readable recap of the same numbers.
summary_rule = "=" * 60
print(summary_rule)
print("CELEBA EXPERIMENT SUMMARY")
print(summary_rule)

print(f"\nVAE (trained on full {total_train:,} images):")
print(f"  Baseline: Recon MSE={vae_results_baseline['recon_mse']:.6f}, Pearson={vae_results_baseline['pearson']:.4f}")
print(f"  MM-Reg:   Recon MSE={vae_results_mmreg['recon_mse']:.6f}, Pearson={vae_results_mmreg['pearson']:.4f}")

print(f"\nDiffusion ({CONFIG['diffusion_epochs']} epochs):")
print(f"  Baseline val loss: {base_val:.6f}")
print(f"  MM-Reg val loss:   {mm_val:.6f}")
print(f"  MM-Reg improvement: {improvement:.1f}%")

print(f"\nResults saved to ./checkpoints/celeba_experiment_summary.json")