# MM-Reg: Manifold-Matching Regularization for VAE

This notebook implements MM-Reg with **pre-computed PCA embeddings**.

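**Key insight**: Pre-compute PCA projections for ALL training samples. During training, each batch looks up its corresponding PCA embeddings, and the pairwise distances between samples in the VAE latent space are compared with the pairwise distances between the same samples in PCA space.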

**Steps:**
1. Setup & Install
2. Download data & Pre-compute PCA embeddings
3. Train MM-Reg VAE
4. Evaluate

## 1. Setup

In [None]:
# Clone the project repository and make it the working directory, so that the
# `src.*` imports and the relative paths used in later cells resolve.
# NOTE(review): run this once per fresh runtime — re-running it after the `%cd`
# clones a second copy inside the repository and descends into that one.
!git clone https://github.com/laurent-cheret/MMReg_diffusion_generative.git
%cd MMReg_diffusion_generative

In [None]:
# Install dependencies
!pip install -q torch torchvision diffusers transformers accelerate
!pip install -q pyyaml tqdm scipy scikit-learn

In [None]:
import os
import random
import sys

# Put the current directory first on the import path so `src.*` is importable.
sys.path.insert(0, '.')

import torch

# Seed the Python and PyTorch global RNGs up front so that a
# Restart & Run All starts from the same random state every time.
# NOTE(review): project code may draw from other RNGs (e.g. NumPy) — seed
# those as well if it does.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"PyTorch: {torch.__version__}")
print(f"Device: {device}")
if device == 'cuda':
    print(f"GPU: {torch.cuda.get_device_name(0)}")

## 2. Pre-compute PCA Embeddings

This is the key step: compute PCA projections for ALL training samples using **fixed transforms** (no augmentation). These embeddings define the reference manifold structure.

In [None]:
from src.data.dataset import (
    get_imagenette_dataset,
    compute_pca_embeddings,
    get_dataset_and_loader
)

# Settings for the reference space; later cells read these names as well.
DATA_ROOT = './data'
IMAGE_SIZE = 256
PCA_COMPONENTS = 256  # number of PCA components kept for the reference space

print("Loading Imagenette with fixed transforms...")

# Both splits are requested with fixed_transform=True, i.e. without random
# augmentation, so every sample always maps to the same reference embedding.
_fixed_kwargs = dict(root=DATA_ROOT, image_size=IMAGE_SIZE, fixed_transform=True)
train_dataset_fixed = get_imagenette_dataset(split='train', **_fixed_kwargs)
val_dataset_fixed = get_imagenette_dataset(split='val', **_fixed_kwargs)

for _split_name, _split_ds in (("Train", train_dataset_fixed), ("Val", val_dataset_fixed)):
    print(f"{_split_name} samples: {len(_split_ds)}")

In [None]:
# Compute PCA embeddings for both splits and save them under ./embeddings
# (this takes a few minutes).
os.makedirs('./embeddings', exist_ok=True)

print("\nComputing PCA embeddings for training set...")
train_pca = compute_pca_embeddings(
    train_dataset_fixed,
    n_components=PCA_COMPONENTS,
    batch_size=64
)
torch.save(train_pca, './embeddings/train_pca.pt')

# NOTE(review): the validation split goes through its own
# compute_pca_embeddings call. If that function fits a fresh PCA on whatever
# dataset it receives, train and val embeddings are expressed in different
# bases — confirm this is intended.
print("\nComputing PCA embeddings for validation set...")
val_pca = compute_pca_embeddings(
    val_dataset_fixed,
    n_components=PCA_COMPONENTS,
    batch_size=64
)
torch.save(val_pca, './embeddings/val_pca.pt')

# The check mark was mis-encoded as "âœ“" in the original message.
print("\n✓ PCA embeddings saved!")
print(f"  Train: {train_pca.shape}")
print(f"  Val: {val_pca.shape}")

## 3. Train MM-Reg VAE

Now train with pre-computed PCA embeddings. The dataloader returns `(image, label, pca_embedding)` tuples.

In [None]:
# Hyper-parameters for this run, gathered in one place so later cells can
# read them by key (the whole dict is also written to results.json).
CONFIG = dict(
    experiment='mmreg_pca',
    lambda_mm=1.0,        # weight of the MM-Reg term (losses are now properly scaled)
    beta=1e-6,            # weight of the KL term (very small, KL is ~70k unscaled)
    epochs=5,
    batch_size=32,
    learning_rate=1e-5,
    image_size=256,       # NOTE(review): keep equal to IMAGE_SIZE used for the PCA embeddings
)

print("Configuration:")
for _name in CONFIG:
    print(f"  {_name}: {CONFIG[_name]}")

In [None]:
from src.models.vae_wrapper import load_vae
from src.models.losses import VAELoss
from src.trainer import MMRegTrainer

# Instantiate the VAE on the selected device.
print("Loading VAE...")
vae = load_vae(device=device)

# Build the training objective: both weights come from CONFIG and the
# MM-Reg term uses the 'correlation' variant.
_loss_weights = {'lambda_mm': CONFIG['lambda_mm'], 'beta': CONFIG['beta']}
loss_fn = VAELoss(mm_variant='correlation', **_loss_weights)

print(f"Loss weights: lambda_mm={loss_fn.lambda_mm}, beta={loss_fn.beta}")

In [None]:
# Build train/val loaders that also serve the pre-computed PCA embeddings.
print("Loading data with PCA embeddings...")

# Arguments shared by both splits; only the split name and the embeddings
# file differ between the two calls.
_loader_kwargs = dict(
    dataset_name='imagenette',
    root=DATA_ROOT,
    image_size=CONFIG['image_size'],
    batch_size=CONFIG['batch_size'],
    num_workers=2,
)

train_dataset, train_loader = get_dataset_and_loader(
    split='train',
    pca_embeddings_path='./embeddings/train_pca.pt',
    **_loader_kwargs
)
val_dataset, val_loader = get_dataset_and_loader(
    split='val',
    pca_embeddings_path='./embeddings/val_pca.pt',
    **_loader_kwargs
)

# Sanity check: pull one training batch and report the shape of its first
# three entries (expected: images, labels, PCA embeddings).
batch = next(iter(train_loader))
print(f"Batch contents: {len(batch)} items")
for _label, _position in (("Images", 0), ("Labels", 1), ("PCA embeddings", 2)):
    print(f"  {_label}: {batch[_position].shape}")

In [None]:
# Optimiser over every VAE parameter, using the configured learning rate.
_learning_rate = CONFIG['learning_rate']
optimizer = torch.optim.AdamW(vae.parameters(), lr=_learning_rate)

# Output directory for this experiment; the evaluation cells reuse `save_dir`.
save_dir = "./checkpoints/{}".format(CONFIG['experiment'])

# reference_model is left as None: the reference embeddings are delivered by
# the dataloaders (pre-computed PCA), so no reference model is passed in.
_trainer_kwargs = dict(
    vae=vae,
    loss_fn=loss_fn,
    optimizer=optimizer,
    train_loader=train_loader,
    val_loader=val_loader,
    reference_model=None,
    device=device,
    save_dir=save_dir,
)
trainer = MMRegTrainer(**_trainer_kwargs)

print(f"\nReady to train for {CONFIG['epochs']} epochs")
print(f"Checkpoints will be saved to: {save_dir}")

In [None]:
# Run the training loop for the configured number of epochs.
_epochs_to_run = CONFIG['epochs']
trainer.train(num_epochs=_epochs_to_run)

## 4. Evaluation

In [None]:
# Plot the per-epoch training history, one panel per loss term.
import matplotlib.pyplot as plt
import json

# The history file is read from the experiment's checkpoint directory.
_history_path = f"{save_dir}/history.json"
with open(_history_path, 'r') as _fh:
    history = json.load(_fh)

metrics = ['loss', 'recon_loss', 'kl_loss', 'mm_loss']
titles = ['Total Loss', 'Reconstruction', 'KL Divergence', 'MM-Reg Loss']

fig, axes = plt.subplots(1, 4, figsize=(16, 4))

# Each panel overlays the train and val series of one metric.
for _col, ax in enumerate(axes):
    metric = metrics[_col]
    title = titles[_col]
    ax.plot([epoch[metric] for epoch in history['train']], label='Train', marker='o')
    ax.plot([epoch[metric] for epoch in history['val']], label='Val', marker='s')
    ax.set_xlabel('Epoch')
    ax.set_ylabel(metric)
    ax.set_title(title)
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(f"{save_dir}/training_curves.png", dpi=150)
plt.show()

# Report the last recorded epoch for both splits.
print("\nFinal metrics:")
_last_train = history['train'][-1]
print(f"  Train - loss: {_last_train['loss']:.4f}, mm_loss: {_last_train['mm_loss']:.4f}")
_last_val = history['val'][-1]
print(f"  Val   - loss: {_last_val['loss']:.4f}, mm_loss: {_last_val['mm_loss']:.4f}")

In [None]:
# Visualize reconstructions: originals on the top row, the VAE's
# reconstructions of the same images on the bottom row.
vae.eval()
batch = next(iter(val_loader))

# Show at most 8 images. The validation batch can hold fewer than 8 (small
# batch size or a short dataset), so size everything from the actual slice
# instead of assuming 8.
images = batch[0][:8].to(device)
n_show = images.shape[0]

with torch.no_grad():
    outputs = vae(images, sample=False)
    recon = outputs['x_recon']

# squeeze=False keeps `axes` two-dimensional even when a single column is
# drawn, so the axes[row, col] indexing below always works.
fig, axes = plt.subplots(2, n_show, figsize=(2 * n_show, 4), squeeze=False)

for i in range(n_show):
    # Original: channels-first -> channels-last for imshow, then (x + 1) / 2
    # clipped to [0, 1].
    # NOTE(review): assumes images are normalised to [-1, 1] — confirm
    # against the dataset transforms.
    img = images[i].cpu().permute(1, 2, 0).numpy()
    img = ((img + 1) / 2).clip(0, 1)
    axes[0, i].imshow(img)
    axes[0, i].axis('off')

    # Reconstruction: same conversion as the original above.
    rec = recon[i].cpu().permute(1, 2, 0).numpy()
    rec = ((rec + 1) / 2).clip(0, 1)
    axes[1, i].imshow(rec)
    axes[1, i].axis('off')

axes[0, 0].set_title('Original', fontsize=12)
axes[1, 0].set_title('Reconstruction', fontsize=12)

plt.tight_layout()
plt.savefig(f"{save_dir}/reconstructions.png", dpi=150)
plt.show()

In [None]:
# Evaluate distance correlation: how well pairwise distances between
# validation samples in the VAE latent space track the pairwise distances
# between the same samples in the pre-computed PCA space.
from src.analysis.evaluate_geometry import compute_distance_correlation
from src.models.reference import PCAReference
# NOTE(review): the two imports above are not used anywhere in this cell.

print("Computing distance correlation...")

# Encode the whole validation set, keeping each latent next to the PCA
# embedding that the loader returned for the same sample.
latent_chunks = []
pca_chunks = []

vae.eval()
with torch.no_grad():
    for batch in val_loader:
        images, _, pca_emb = batch
        images = images.to(device)

        outputs = vae(images, sample=False)
        latent_chunks.append(outputs['latent_flat'].cpu())
        pca_chunks.append(pca_emb)

all_latents = torch.cat(latent_chunks, dim=0)
all_pca = torch.cat(pca_chunks, dim=0)

print(f"Latents shape: {all_latents.shape}")
print(f"PCA shape: {all_pca.shape}")

# Compute pairwise distances and correlation
from src.models.losses import pairwise_distances, get_upper_triangular
from scipy.stats import pearsonr, spearmanr

# Use a subset for speed. It is drawn at random with a fixed seed instead of
# taking the first n rows: if the loader iterates in dataset order, a prefix
# may cover only a few classes and give an unrepresentative correlation.
n_samples = min(500, len(all_latents))
subset_rng = torch.Generator().manual_seed(0)
subset_idx = torch.randperm(len(all_latents), generator=subset_rng)[:n_samples]

# The same indices are applied to both tensors so the pairs stay aligned.
D_latent = pairwise_distances(all_latents[subset_idx])
D_pca = pairwise_distances(all_pca[subset_idx])

d_latent = get_upper_triangular(D_latent).numpy()
d_pca = get_upper_triangular(D_pca).numpy()

pearson, _ = pearsonr(d_latent, d_pca)
spearman, _ = spearmanr(d_latent, d_pca)

print("\n=== Distance Correlation ===")
print(f"Pearson:  {pearson:.4f}")
print(f"Spearman: {spearman:.4f}")
print("\nTarget: > 0.5 indicates geometry is being preserved")

In [None]:
# Collect the run's headline numbers, write them next to the checkpoints,
# and print a short summary.
_final_train = history['train'][-1]
_final_val = history['val'][-1]

results = {
    'config': CONFIG,
    'final_train_loss': _final_train['loss'],
    'final_val_loss': _final_val['loss'],
    'final_mm_loss': _final_val['mm_loss'],  # taken from the validation history
    'pearson_correlation': float(pearson),
    'spearman_correlation': float(spearman)
}

_results_path = f"{save_dir}/results.json"
with open(_results_path, 'w') as _out:
    json.dump(results, _out, indent=2)

print(f"Results saved to {save_dir}/results.json")

_banner = "=" * 50
print("\n" + _banner)
print("SUMMARY")
print(_banner)

# Floats are shown with four decimals; everything else is printed as-is.
for _key, _value in results.items():
    _line = f"{_key}: {_value:.4f}" if isinstance(_value, float) else f"{_key}: {_value}"
    print(_line)