# MM-Reg: Complete Experiment

End-to-end experiment comparing **Baseline VAE** vs **MM-Reg VAE** for latent diffusion.

## Pipeline:
1. **Setup**: Install dependencies, load data
2. **Pre-compute PCA**: Reference embeddings for MM-Reg
3. **Train VAEs**: Baseline (no MM-Reg) and MM-Reg versions
4. **Evaluate VAEs**: Reconstruction quality, distance correlation
5. **Train Diffusion**: On latents from both VAEs
6. **Evaluate Diffusion**: Generate samples, compare quality

**Hypothesis**: MM-Reg VAE creates a smoother latent space → diffusion learns faster/better.

## 1. Setup

In [None]:
# Clone repository and make it the working directory.
# `rm -rf` runs first so re-executing this cell starts from a fresh clone
# (this discards any local edits made inside that directory).
!rm -rf MMReg_diffusion_generative 2>/dev/null
!git clone https://github.com/laurent-cheret/MMReg_diffusion_generative.git
%cd MMReg_diffusion_generative

In [None]:
# Install dependencies quietly (-q). No versions are pinned, so the
# installed library versions (and possibly results) can vary between runs.
!pip install -q torch torchvision diffusers transformers accelerate
!pip install -q pyyaml tqdm scipy scikit-learn matplotlib

In [None]:
import sys
sys.path.insert(0, '.')  # make the cloned repo's `src` package importable

import torch
import torch.nn as nn
import os
import json
import matplotlib.pyplot as plt
from tqdm import tqdm

# Pick the compute device once; every later cell reads the global `device`.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'

print(f"PyTorch: {torch.__version__}")
print(f"Device: {device}")

# Only report GPU details when running on CUDA.
if device == 'cuda':
    gpu_props = torch.cuda.get_device_properties(0)
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {gpu_props.total_memory / 1e9:.1f} GB")

In [None]:
# Experiment configuration: hyperparameters read by the cells below.
CONFIG = dict(
    # Data
    data_root='./data',
    image_size=256,
    batch_size=32,
    # PCA (number of components of the reference embeddings)
    pca_components=256,
    # VAE fine-tuning (lambda_mm / beta are passed to VAELoss)
    vae_epochs=5,
    vae_lr=1e-5,
    lambda_mm=1.0,
    beta=1e-6,
    # Latent diffusion training
    diffusion_epochs=20,
    diffusion_lr=1e-4,
    diffusion_timesteps=1000,
)

print("Configuration:")
print("\n".join(f"  {key}: {value}" for key, value in CONFIG.items()))

## 2. Load Data & Pre-compute PCA

In [None]:
from src.data.dataset import (
    get_imagenette_dataset,
    compute_pca_embeddings,
    get_dataset_and_loader
)

# Imagenette train/val with `fixed_transform=True` (presumably deterministic,
# non-augmented preprocessing) — these feed the PCA step below.
print("Loading Imagenette...")
_fixed_dataset_kwargs = dict(
    root=CONFIG['data_root'],
    image_size=CONFIG['image_size'],
    fixed_transform=True,
)
train_dataset_fixed = get_imagenette_dataset(split='train', **_fixed_dataset_kwargs)
val_dataset_fixed = get_imagenette_dataset(split='val', **_fixed_dataset_kwargs)

print(f"Train: {len(train_dataset_fixed)}, Val: {len(val_dataset_fixed)}")

In [None]:
# Compute PCA embeddings (reference geometry for the MM-Reg loss) and cache
# them on disk so the data loaders below can load them by path.
os.makedirs('./embeddings', exist_ok=True)

print("Computing PCA embeddings...")


def _pca_embed_and_save(dataset, out_path):
    """Compute PCA embeddings for ``dataset``, save them to ``out_path``, return them."""
    embeddings = compute_pca_embeddings(
        dataset,
        n_components=CONFIG['pca_components'],
        batch_size=64,
    )
    torch.save(embeddings, out_path)
    return embeddings


train_pca = _pca_embed_and_save(train_dataset_fixed, './embeddings/train_pca.pt')
# NOTE(review): val embeddings come from a separate call on the val set;
# whether that reuses the train PCA basis depends on compute_pca_embeddings.
val_pca = _pca_embed_and_save(val_dataset_fixed, './embeddings/val_pca.pt')

print(f"Train PCA: {train_pca.shape}, Val PCA: {val_pca.shape}")

## 3. Train VAEs

### 3.1 Baseline VAE (No MM-Reg)

In [None]:
from src.models.vae_wrapper import load_vae
from src.models.losses import VAELoss
from src.trainer import MMRegTrainer

# Control arm: fine-tune a freshly loaded VAE with the MM-Reg term disabled
# (lambda_mm=0.0). Results go to ./checkpoints/baseline_vae.
# NOTE(review): no random seed is set, so data order (and any stochastic
# augmentation) differs from the MM-Reg run below; consider torch.manual_seed
# for a tighter controlled comparison.
print("="*60)
print("TRAINING BASELINE VAE (no MM-Reg)")
print("="*60)

# Load fresh VAE
vae_baseline = load_vae(device=device)

# Loss without MM-Reg (lambda_mm=0)
loss_baseline = VAELoss(lambda_mm=0.0, beta=CONFIG['beta'])

# Data loaders. PCA embeddings are still loaded even though the MM term is
# off: the trainer/evaluation code consumes (image, label, pca_embedding)
# batches (see the unpacking in evaluate_vae), so the baseline needs them too.
train_dataset_base, train_loader_base = get_dataset_and_loader(
    dataset_name='imagenette',
    root=CONFIG['data_root'],
    split='train',
    image_size=CONFIG['image_size'],
    batch_size=CONFIG['batch_size'],
    num_workers=2,
    pca_embeddings_path='./embeddings/train_pca.pt'  # Still needed for trainer compatibility (unused by the loss when lambda_mm=0)
)

val_dataset_base, val_loader_base = get_dataset_and_loader(
    dataset_name='imagenette',
    root=CONFIG['data_root'],
    split='val',
    image_size=CONFIG['image_size'],
    batch_size=CONFIG['batch_size'],
    num_workers=2,
    pca_embeddings_path='./embeddings/val_pca.pt'
)

# Optimizer over all VAE parameters (encoder and decoder are both trainable)
optimizer_baseline = torch.optim.AdamW(vae_baseline.parameters(), lr=CONFIG['vae_lr'])

# Trainer; per-epoch history is later read from save_dir/history.json
trainer_baseline = MMRegTrainer(
    vae=vae_baseline,
    loss_fn=loss_baseline,
    optimizer=optimizer_baseline,
    train_loader=train_loader_base,
    val_loader=val_loader_base,
    device=device,
    save_dir='./checkpoints/baseline_vae'
)

# Train
trainer_baseline.train(num_epochs=CONFIG['vae_epochs'])

### 3.2 MM-Reg VAE

In [None]:
# Treatment arm: fine-tune another freshly loaded VAE with the MM-Reg term
# enabled (weight CONFIG['lambda_mm'], variant 'correlation'). Everything
# else mirrors the baseline cell. Results go to ./checkpoints/mmreg_vae.
print("="*60)
print("TRAINING MM-REG VAE")
print("="*60)

# Load fresh VAE
vae_mmreg = load_vae(device=device)

# Loss with MM-Reg
loss_mmreg = VAELoss(
    lambda_mm=CONFIG['lambda_mm'],
    beta=CONFIG['beta'],
    mm_variant='correlation'
)

# Build new loaders with the same settings as the baseline run (including the
# PCA embeddings, which the MM-Reg loss term uses).
train_dataset_mm, train_loader_mm = get_dataset_and_loader(
    dataset_name='imagenette',
    root=CONFIG['data_root'],
    split='train',
    image_size=CONFIG['image_size'],
    batch_size=CONFIG['batch_size'],
    num_workers=2,
    pca_embeddings_path='./embeddings/train_pca.pt'
)

val_dataset_mm, val_loader_mm = get_dataset_and_loader(
    dataset_name='imagenette',
    root=CONFIG['data_root'],
    split='val',
    image_size=CONFIG['image_size'],
    batch_size=CONFIG['batch_size'],
    num_workers=2,
    pca_embeddings_path='./embeddings/val_pca.pt'
)

# Optimizer (same learning rate as the baseline)
optimizer_mmreg = torch.optim.AdamW(vae_mmreg.parameters(), lr=CONFIG['vae_lr'])

# Trainer; per-epoch history is later read from save_dir/history.json
trainer_mmreg = MMRegTrainer(
    vae=vae_mmreg,
    loss_fn=loss_mmreg,
    optimizer=optimizer_mmreg,
    train_loader=train_loader_mm,
    val_loader=val_loader_mm,
    device=device,
    save_dir='./checkpoints/mmreg_vae'
)

# Train
trainer_mmreg.train(num_epochs=CONFIG['vae_epochs'])

## 4. Evaluate VAEs

In [None]:
from src.models.losses import pairwise_distances, get_upper_triangular
from scipy.stats import pearsonr, spearmanr

def evaluate_vae(vae, val_loader, name, max_corr_samples=500):
    """Evaluate a VAE: reconstruction MSE and latent/PCA distance correlation.

    The reconstruction error is the per-element squared error, averaged over
    every validation image (each batch is weighted by its size). The distance
    correlation is the Pearson and Spearman correlation between the pairwise
    distances of the flattened latents and those of the reference PCA
    embeddings. It uses only the first ``max_corr_samples`` samples, because
    pairwise distances are quadratic in n.

    Args:
        vae: model where ``vae(images, sample=False)`` returns a dict with
            ``'x_recon'`` and ``'latent_flat'``.
        val_loader: iterable of ``(images, labels, pca_embeddings)`` batches.
        name: label used in the printed summary (and in error messages).
        max_corr_samples: cap on samples used for the distance correlation.

    Returns:
        dict with ``'recon_mse'``, ``'pearson_corr'`` and ``'spearman_corr'``
        as plain Python floats, so the dict is JSON-serializable. The
        correlations are NaN when fewer than 3 samples are available.

    Raises:
        ValueError: if ``val_loader`` yields no samples.
    """
    vae.eval()

    latent_chunks = []
    pca_chunks = []
    total_sq_error = 0.0
    num_samples = 0

    with torch.no_grad():
        for images, _, pca_emb in val_loader:
            images = images.to(device)
            outputs = vae(images, sample=False)

            batch_size = images.shape[0]
            # .mean() is the per-element MSE of this batch. Weighting it by the
            # batch size makes the final value a per-image average even when
            # the last batch is smaller.
            batch_mse = ((outputs['x_recon'] - images) ** 2).mean().item()
            total_sq_error += batch_mse * batch_size
            num_samples += batch_size

            latent_chunks.append(outputs['latent_flat'].cpu())
            pca_chunks.append(pca_emb)

    if num_samples == 0:
        raise ValueError(f"{name}: validation loader yielded no samples")

    all_latents = torch.cat(latent_chunks, dim=0)
    all_pca = torch.cat(pca_chunks, dim=0)

    n = min(max_corr_samples, len(all_latents))
    if n >= 3:
        d_latent = get_upper_triangular(pairwise_distances(all_latents[:n])).numpy()
        d_pca = get_upper_triangular(pairwise_distances(all_pca[:n])).numpy()

        pearson, _ = pearsonr(d_latent, d_pca)
        spearman, _ = spearmanr(d_latent, d_pca)
        # scipy can return NumPy scalars here (e.g. np.float32 for float32
        # input under NumPy 2 promotion rules), which json.dump rejects.
        pearson = float(pearson)
        spearman = float(spearman)
    else:
        # Fewer than 3 points give fewer than 3 distance pairs, too few for a
        # meaningful correlation (pearsonr raises on a single pair).
        pearson = spearman = float('nan')

    results = {
        'recon_mse': total_sq_error / num_samples,
        'pearson_corr': pearson,
        'spearman_corr': spearman
    }

    print(f"\n{name} Results:")
    print(f"  Reconstruction MSE: {results['recon_mse']:.6f}")
    print(f"  Distance Pearson:   {results['pearson_corr']:.4f}")
    print(f"  Distance Spearman:  {results['spearman_corr']:.4f}")

    return results

# Evaluate both VAEs on their own validation loaders (baseline first).
results_baseline = evaluate_vae(
    vae=vae_baseline,
    val_loader=val_loader_base,
    name="Baseline VAE",
)
results_mmreg = evaluate_vae(
    vae=vae_mmreg,
    val_loader=val_loader_mm,
    name="MM-Reg VAE",
)

In [None]:
# Visualize reconstructions
def plot_reconstructions(vae, val_loader, title, save_path, num_images=8):
    """Plot validation images (top row) against their reconstructions (bottom row).

    Uses the first batch of ``val_loader``. At most ``num_images`` columns are
    shown, fewer if that batch is smaller. Pixels are mapped from [-1, 1] to
    [0, 1] for display. NOTE(review): this assumes the dataset normalizes
    images to [-1, 1]; confirm against the transform. The figure is saved
    to ``save_path`` (creating its directory if needed) and then shown.

    Args:
        vae: model where ``vae(images, sample=False)`` returns a dict with ``'x_recon'``.
        val_loader: loader whose batches hold the images at index 0.
        title: figure title.
        save_path: output image path.
        num_images: maximum number of image columns.

    Raises:
        ValueError: if the first batch contains no images.
    """
    vae.eval()
    images = next(iter(val_loader))[0][:num_images].to(device)
    n_cols = images.shape[0]
    if n_cols == 0:
        raise ValueError(f"{title}: first validation batch is empty")

    with torch.no_grad():
        recon = vae(images, sample=False)['x_recon']

    # squeeze=False keeps `axes` 2-D even for a single column.
    fig, axes = plt.subplots(2, n_cols, figsize=(2 * n_cols, 4), squeeze=False)
    fig.suptitle(title, fontsize=14)

    for row, tensors in enumerate((images, recon)):
        for col in range(n_cols):
            ax = axes[row, col]
            img = tensors[col].cpu().permute(1, 2, 0).numpy()
            ax.imshow(((img + 1) / 2).clip(0, 1))
            # Hide ticks and frame only. ax.axis('off') would also hide the
            # y-axis labels, so the row labels set below would not render.
            ax.set_xticks([])
            ax.set_yticks([])
            for spine in ax.spines.values():
                spine.set_visible(False)

    axes[0, 0].set_ylabel('Original', fontsize=12)
    axes[1, 0].set_ylabel('Recon', fontsize=12)

    plt.tight_layout()
    os.makedirs(os.path.dirname(save_path) or '.', exist_ok=True)
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.show()

# Save and show original-vs-reconstruction grids for both VAEs.
plot_reconstructions(
    vae=vae_baseline,
    val_loader=val_loader_base,
    title="Baseline VAE Reconstructions",
    save_path="./checkpoints/baseline_vae/reconstructions.png",
)
plot_reconstructions(
    vae=vae_mmreg,
    val_loader=val_loader_mm,
    title="MM-Reg VAE Reconstructions",
    save_path="./checkpoints/mmreg_vae/reconstructions.png",
)

## 5. Train Diffusion Models

### 5.1 Encode Datasets to Latents

In [None]:
from src.diffusion_trainer import encode_dataset
from torch.utils.data import DataLoader


def _sequential_loader(dataset):
    """Return a non-shuffled DataLoader (experiment batch size, 2 workers) over ``dataset``."""
    return DataLoader(
        dataset,
        batch_size=CONFIG['batch_size'],
        shuffle=False,
        num_workers=2,
    )


# Plain loaders over the fixed-transform datasets (no PCA embeddings needed);
# shuffle=False keeps latents in dataset order.
train_loader_simple = _sequential_loader(train_dataset_fixed)
val_loader_simple = _sequential_loader(val_dataset_fixed)


def _encode_splits(vae):
    """Encode the train split, then the val split, with ``vae``."""
    train_latents = encode_dataset(vae, train_loader_simple, device)
    val_latents = encode_dataset(vae, val_loader_simple, device)
    return train_latents, val_latents


print("Encoding dataset with Baseline VAE...")
train_latents_baseline, val_latents_baseline = _encode_splits(vae_baseline)
print(f"Baseline latents - Train: {train_latents_baseline.shape}, Val: {val_latents_baseline.shape}")

print("\nEncoding dataset with MM-Reg VAE...")
train_latents_mmreg, val_latents_mmreg = _encode_splits(vae_mmreg)
print(f"MM-Reg latents - Train: {train_latents_mmreg.shape}, Val: {val_latents_mmreg.shape}")

### 5.2 Train Diffusion on Baseline Latents

In [None]:
from src.models.diffusion import SimpleUNet, GaussianDiffusion
from src.diffusion_trainer import DiffusionTrainer

# Train a latent diffusion model (noise schedule + UNet) on the latents
# produced by the baseline VAE. The MM-Reg cell below uses an identical
# architecture and hyperparameters, so only the latents differ.
print("="*60)
print("TRAINING DIFFUSION ON BASELINE LATENTS")
print("="*60)

# Create diffusion model (noise process over CONFIG['diffusion_timesteps'] steps)
diffusion_baseline = GaussianDiffusion(
    num_timesteps=CONFIG['diffusion_timesteps'],
    device=device
)

# NOTE(review): in_channels=4 assumes the VAE latents have 4 channels; confirm
# against the latent shape printed by the encoding cell. Also confirm whether
# encode_dataset applies the VAE latent scaling factor; unscaled latents may
# not match the diffusion's unit-variance noise assumptions.
unet_baseline = SimpleUNet(
    in_channels=4,
    base_channels=128,
    channel_mult=(1, 2, 4),
    num_res_blocks=2
).to(device)

optimizer_diff_base = torch.optim.AdamW(unet_baseline.parameters(), lr=CONFIG['diffusion_lr'])

# Per-epoch history is later read from save_dir/history.json
trainer_diff_baseline = DiffusionTrainer(
    model=unet_baseline,
    diffusion=diffusion_baseline,
    optimizer=optimizer_diff_base,
    train_latents=train_latents_baseline,
    val_latents=val_latents_baseline,
    batch_size=CONFIG['batch_size'],
    device=device,
    save_dir='./checkpoints/diffusion_baseline'
)

trainer_diff_baseline.train(num_epochs=CONFIG['diffusion_epochs'])

### 5.3 Train Diffusion on MM-Reg Latents

In [None]:
# Same diffusion setup as the baseline cell (architecture, schedule,
# optimizer, epochs), trained on latents from the MM-Reg VAE.
# NOTE(review): no seed is reset here, so UNet initialization and batch order
# differ from the baseline run.
print("="*60)
print("TRAINING DIFFUSION ON MM-REG LATENTS")
print("="*60)

# Create diffusion model
diffusion_mmreg = GaussianDiffusion(
    num_timesteps=CONFIG['diffusion_timesteps'],
    device=device
)

# in_channels=4 assumes 4-channel VAE latents (same assumption as the baseline cell)
unet_mmreg = SimpleUNet(
    in_channels=4,
    base_channels=128,
    channel_mult=(1, 2, 4),
    num_res_blocks=2
).to(device)

optimizer_diff_mm = torch.optim.AdamW(unet_mmreg.parameters(), lr=CONFIG['diffusion_lr'])

# Per-epoch history is later read from save_dir/history.json
trainer_diff_mmreg = DiffusionTrainer(
    model=unet_mmreg,
    diffusion=diffusion_mmreg,
    optimizer=optimizer_diff_mm,
    train_latents=train_latents_mmreg,
    val_latents=val_latents_mmreg,
    batch_size=CONFIG['batch_size'],
    device=device,
    save_dir='./checkpoints/diffusion_mmreg'
)

trainer_diff_mmreg.train(num_epochs=CONFIG['diffusion_epochs'])

## 6. Generate Samples & Compare

In [None]:
# Draw the same number of latent samples from each trained diffusion model.
# Baseline is sampled first, in case sampling draws from a shared global RNG.
_NUM_GENERATED = 16

print("Generating samples from Baseline Diffusion...")
samples_baseline = trainer_diff_baseline.generate_samples(num_samples=_NUM_GENERATED)

print("\nGenerating samples from MM-Reg Diffusion...")
samples_mmreg = trainer_diff_mmreg.generate_samples(num_samples=_NUM_GENERATED)

In [None]:
# Decode latents to images
def decode_and_plot(vae, latents, title, save_path, max_images=16, ncols=8):
    """Decode latents with ``vae.decode`` and save/show them as an image grid.

    Only the first ``max_images`` latents are decoded and shown, ``ncols``
    per row. With the defaults and 16 latents this is the original 2x8 grid.
    Grid cells without an image are blanked rather than left as empty tick
    frames. Pixels are mapped from [-1, 1] to [0, 1] for display.
    NOTE(review): this assumes the decoder outputs roughly [-1, 1]; confirm.

    Args:
        vae: model exposing ``decode(latents) -> images`` (N, C, H, W).
        latents: tensor of latents; moved to ``device`` before decoding.
        title: figure title.
        save_path: output image path (its directory is created if missing).
        max_images: maximum number of images to decode and show.
        ncols: images per grid row.

    Raises:
        ValueError: if there is nothing to plot.
    """
    vae.eval()
    with torch.no_grad():
        # Decode only what will be shown; in eval mode each sample decodes independently.
        images = vae.decode(latents[:max_images].to(device))

    n = min(max_images, images.shape[0])
    if n == 0:
        raise ValueError(f"{title}: no latents to decode")

    cols = min(ncols, n)
    rows = -(-n // cols)  # ceiling division
    fig, axes = plt.subplots(rows, cols, figsize=(2 * cols, 2 * rows), squeeze=False)
    fig.suptitle(title, fontsize=14)

    for idx, ax in enumerate(axes.flat):
        ax.axis('off')  # applied to every cell, including unused trailing ones
        if idx < n:
            img = images[idx].cpu().permute(1, 2, 0).numpy()
            ax.imshow(((img + 1) / 2).clip(0, 1))

    plt.tight_layout()
    os.makedirs(os.path.dirname(save_path) or '.', exist_ok=True)
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.show()

# Decode each model's samples with the VAE whose latent space it was trained on.
decode_and_plot(
    vae=vae_baseline,
    latents=samples_baseline,
    title="Generated Samples (Baseline VAE + Diffusion)",
    save_path="./checkpoints/diffusion_baseline/generated_samples.png",
)

decode_and_plot(
    vae=vae_mmreg,
    latents=samples_mmreg,
    title="Generated Samples (MM-Reg VAE + Diffusion)",
    save_path="./checkpoints/diffusion_mmreg/generated_samples.png",
)

In [None]:
# Plot training curves comparison: VAE (left) and diffusion (right).


def _plot_loss_curves(ax, base_hist, mm_hist, title):
    """Draw per-epoch train (solid) / val (dashed) 'loss' for baseline (blue) and MM-Reg (red)."""
    curves = (
        (base_hist, 'train', 'b-', 'Baseline Train'),
        (base_hist, 'val', 'b--', 'Baseline Val'),
        (mm_hist, 'train', 'r-', 'MM-Reg Train'),
        (mm_hist, 'val', 'r--', 'MM-Reg Val'),
    )
    for hist, split, fmt, label in curves:
        ax.plot([entry['loss'] for entry in hist[split]], fmt, label=label)
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    ax.set_title(title)
    ax.legend()
    ax.grid(True, alpha=0.3)


fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# VAE losses. NOTE(review): the MM-Reg run uses lambda_mm > 0, so its 'loss'
# presumably includes the MM term. The two VAE curves are then not directly
# comparable; compare reconstruction terms if the history records them.
with open('./checkpoints/baseline_vae/history.json') as fh:
    baseline_vae_hist = json.load(fh)
with open('./checkpoints/mmreg_vae/history.json') as fh:
    mmreg_vae_hist = json.load(fh)
_plot_loss_curves(axes[0], baseline_vae_hist, mmreg_vae_hist, 'VAE Training')

# Diffusion losses
with open('./checkpoints/diffusion_baseline/history.json') as fh:
    baseline_diff_hist = json.load(fh)
with open('./checkpoints/diffusion_mmreg/history.json') as fh:
    mmreg_diff_hist = json.load(fh)
_plot_loss_curves(axes[1], baseline_diff_hist, mmreg_diff_hist, 'Diffusion Training')

plt.tight_layout()
plt.savefig('./checkpoints/training_comparison.png', dpi=150)
plt.show()

In [None]:
# Final summary: collect config, VAE metrics and final diffusion losses,
# print the key numbers and write everything to a JSON file.
print("="*60)
print("EXPERIMENT SUMMARY")
print("="*60)


def _json_default(obj):
    """json.dump fallback that unwraps NumPy/PyTorch scalars via ``.item()``.

    Metrics computed with scipy/NumPy (e.g. an np.float32 correlation) are
    not JSON-serializable by default. Objects without ``.item()`` are
    rejected with the same TypeError json itself would raise.
    """
    item = getattr(obj, 'item', None)
    if callable(item):
        return item()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


summary = {
    'config': CONFIG,
    'vae_results': {
        'baseline': results_baseline,
        'mmreg': results_mmreg
    },
    'diffusion_final_loss': {
        'baseline_train': baseline_diff_hist['train'][-1]['loss'],
        'baseline_val': baseline_diff_hist['val'][-1]['loss'],
        'mmreg_train': mmreg_diff_hist['train'][-1]['loss'],
        'mmreg_val': mmreg_diff_hist['val'][-1]['loss']
    }
}

print("\nVAE Comparison:")
print(f"  Baseline - Recon MSE: {results_baseline['recon_mse']:.6f}, Pearson: {results_baseline['pearson_corr']:.4f}")
print(f"  MM-Reg   - Recon MSE: {results_mmreg['recon_mse']:.6f}, Pearson: {results_mmreg['pearson_corr']:.4f}")

print("\nDiffusion Final Val Loss:")
print(f"  Baseline: {summary['diffusion_final_loss']['baseline_val']:.6f}")
print(f"  MM-Reg:   {summary['diffusion_final_loss']['mmreg_val']:.6f}")

# Save summary
os.makedirs('./checkpoints', exist_ok=True)
with open('./checkpoints/experiment_summary.json', 'w', encoding='utf-8') as f:
    json.dump(summary, f, indent=2, default=_json_default)

print("\nResults saved to ./checkpoints/experiment_summary.json")