# MM-Reg: Manifold-Matching Regularization for VAE

This notebook tests and trains MM-Reg VAE on Colab.

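**Steps:**
1. Setup & install dependencies
2. Test components (loss, reference model, VAE, data loading)
3. Quick training test (a few batches)
4. Full training (baseline VAE or MM-Reg VAE, selected via `CONFIG['experiment']`)
5. Evaluate and visualize
6. Save results for comparison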

## 1. Setup

In [None]:
# Clone the repository (replace with your GitHub URL)
# NOTE: YOUR_USERNAME below is a placeholder and must be edited before running this cell.
!git clone https://github.com/YOUR_USERNAME/MMReg_diffusion_generative.git
# Switch into the checkout: later cells import from `src` and use the relative
# `./data` and `./checkpoints` paths.
%cd MMReg_diffusion_generative

In [None]:
# Install dependencies
!pip install -q torch torchvision diffusers transformers accelerate
!pip install -q pyyaml tqdm scipy

In [None]:
import sys
sys.path.insert(0, '.')

import torch
# Report the PyTorch build and whether a CUDA device is visible.
cuda_ok = torch.cuda.is_available()
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {cuda_ok}")
if cuda_ok:
    # Name of the first visible GPU
    print(f"GPU: {torch.cuda.get_device_name(0)}")

## 2. Test Components

In [None]:
# Test MM-Reg Loss
from src.models.losses import MMRegLoss, pairwise_distances

# Random stand-ins for a batch of flattened latents and reference features
num_samples = 16
latents = torch.randn(num_samples, 4096)   # Simulated latents
ref_feats = torch.randn(num_samples, 768)  # Simulated DINOv2 features

# Correlation variant
corr_value = MMRegLoss(variant='correlation')(latents, ref_feats)
print(f"Correlation loss: {corr_value.item():.4f}")

# SI-MSE variant
si_mse_value = MMRegLoss(variant='si_mse')(latents, ref_feats)
print(f"SI-MSE loss: {si_mse_value.item():.4f}")

print("\n✓ Loss functions working!")

In [None]:
# Test Reference Model (DINOv2)
from src.models.reference import DINOv2Reference

# Prefer the GPU when one is visible; later cells reuse `device`.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'
print(f"Using device: {device}")

# Load DINOv2
print("Loading DINOv2 (this may take a moment)...")
dino = DINOv2Reference(device=device)

# Random noise rescaled by its own min/max so the values span [0, 1]
noise = torch.randn(4, 3, 256, 256).to(device)
lowest, highest = noise.min(), noise.max()
test_images = (noise - lowest) / (highest - lowest)

# Inference only: no autograd graph needed
with torch.no_grad():
    features = dino(test_images)

print(f"Input shape: {test_images.shape}")
print(f"Output shape: {features.shape}")
print("\n✓ DINOv2 reference model working!")

In [None]:
# Test VAE Wrapper
from src.models.vae_wrapper import load_vae

print("Loading VAE (this may take a moment)...")
vae = load_vae(device=device, use_gradient_checkpointing=True)

# Test forward pass.
# torch.rand is uniform on [0, 1), so `* 2 - 1` really lands in [-1, 1);
# torch.randn (Gaussian) would produce unbounded values and not the intended range.
test_images_vae = torch.rand(4, 3, 256, 256).to(device) * 2 - 1  # Range [-1, 1]

# Shape check only, so skip building the autograd graph
with torch.no_grad():
    outputs = vae(test_images_vae)

print(f"Input shape: {test_images_vae.shape}")
print(f"Latent shape: {outputs['latent'].shape}")
print(f"Latent flat shape: {outputs['latent_flat'].shape}")
print(f"Reconstruction shape: {outputs['x_recon'].shape}")
print("\n✓ VAE wrapper working!")

In [None]:
# Test Full Loss Computation
from src.models.losses import VAELoss

# Create full loss function
full_loss = VAELoss(lambda_mm=0.1, beta=1.0, mm_variant='correlation')

# Reference features are a fixed target, so compute them without an autograd graph
# (same pattern as the training loop).
with torch.no_grad():
    ref_features = dino(test_images_vae)

# The VAE forward pass runs once, with gradients enabled, so backward() can be tested.
vae.train()
outputs = vae(test_images_vae)

losses = full_loss(
    x=test_images_vae,
    x_recon=outputs['x_recon'],
    z=outputs['latent_flat'],
    r=ref_features,
    posterior=outputs['posterior']
)

print(f"Total loss: {losses['loss'].item():.4f}")
print(f"Recon loss: {losses['recon_loss'].item():.4f}")
print(f"KL loss: {losses['kl_loss'].item():.4f}")
print(f"MM loss: {losses['mm_loss'].item():.4f}")

# Test backward
losses['loss'].backward()
print("\n✓ Full loss and backward pass working!")

In [None]:
# Test Data Loading (Imagenette)
from src.data.dataset import get_dataset_and_loader

print("Loading Imagenette dataset...")
# Small batches are enough for a smoke test
loader_options = dict(
    dataset_name='imagenette',
    root='./data',
    split='train',
    image_size=256,
    batch_size=8,
    num_workers=2,
)
train_dataset, train_loader = get_dataset_and_loader(**loader_options)

print(f"Dataset size: {len(train_dataset)}")

# Pull the first batch and report its basic statistics
first_batch = next(iter(train_loader))
images, labels = first_batch
print(f"Batch shape: {images.shape}")
print(f"Labels: {labels}")
print(f"Image range: [{images.min():.2f}, {images.max():.2f}]")
print("\n✓ Data loading working!")

In [None]:
# Visualize some samples
import matplotlib.pyplot as plt

fig, axes = plt.subplots(2, 4, figsize=(12, 6))

# Walk the 2x4 grid in row-major order, one batch image per panel
panels = axes.ravel()
for idx in range(panels.size):
    panel = panels[idx]
    # Channels-last array for imshow, converted from [-1,1] to [0,1] and clamped
    picture = ((images[idx].permute(1, 2, 0).numpy() + 1) / 2).clip(0, 1)
    panel.imshow(picture)
    panel.axis('off')
    panel.set_title(f'Label: {labels[idx].item()}')

plt.tight_layout()
plt.show()

## 3. Quick Training Test

Run a few iterations to make sure training works.

In [None]:
# Quick training test (few batches)
from src.models.vae_wrapper import load_vae
from src.models.reference import get_reference_model
from src.models.losses import VAELoss

# Reload models fresh
vae = load_vae(device=device)
ref_model = get_reference_model('dinov2', device=device)
loss_fn = VAELoss(lambda_mm=0.1)

optimizer = torch.optim.AdamW(vae.parameters(), lr=1e-5)

# Mixed precision is only meaningful on a CUDA device. On a CPU runtime both the
# autocast context and the gradient scaler are disabled and act as no-ops, so the
# same loop runs in full precision. Uses the torch.amp / torch.autocast API, which
# replaces the deprecated torch.cuda.amp entry points.
use_amp = device == 'cuda'
scaler = torch.amp.GradScaler('cuda', enabled=use_amp)

# Train for a few batches
vae.train()
num_test_batches = 5

print("Running quick training test...")
for batch_idx, (images, _) in enumerate(train_loader):
    if batch_idx >= num_test_batches:
        break

    images = images.to(device)

    # Get reference (fixed target, no gradients)
    with torch.no_grad():
        ref_features = ref_model(images)

    # Forward
    optimizer.zero_grad()
    with torch.autocast(device_type='cuda', enabled=use_amp):
        outputs = vae(images)
        losses = loss_fn(
            x=images,
            x_recon=outputs['x_recon'],
            z=outputs['latent_flat'],
            r=ref_features,
            posterior=outputs['posterior']
        )

    # Backward: scale the loss, step through the scaler, then update its scale factor
    scaler.scale(losses['loss']).backward()
    scaler.step(optimizer)
    scaler.update()

    print(f"Batch {batch_idx+1}: loss={losses['loss'].item():.4f}, "
          f"recon={losses['recon_loss'].item():.4f}, mm={losses['mm_loss'].item():.4f}")

print("\n✓ Training loop working!")

## 4. Full Training

Now run the actual training. Choose between the baseline VAE and MM-Reg by setting `CONFIG['experiment']` to `'baseline'` or `'mmreg'` in the next cell.

In [None]:
# Configuration: every value a reader is expected to tune before a full run
CONFIG = dict(
    experiment='mmreg',   # 'baseline' or 'mmreg'
    lambda_mm=0.1,        # Only used for mmreg
    epochs=5,
    batch_size=32,        # Reduce if OOM
    learning_rate=1e-5,
    image_size=256,
    num_workers=2,
)

print(f"Configuration: {CONFIG}")

In [None]:
# Setup training
import warnings

from src.models.vae_wrapper import load_vae
from src.models.reference import get_reference_model
from src.models.losses import VAELoss
from src.data.dataset import get_dataset_and_loader
from src.trainer import MMRegTrainer

# Load models
print("Loading models...")
vae = load_vae(device=device)
ref_model = get_reference_model('dinov2', device=device)

# Loss function.
# Only 'mmreg' enables the MM-Reg term; every other value trains with lambda_mm = 0.
# Warn on unrecognised names so a typo such as 'mm_reg' does not silently turn the
# run into a baseline.
experiment = CONFIG['experiment']
if experiment not in ('baseline', 'mmreg'):
    warnings.warn(
        f"Unrecognised experiment {experiment!r}: expected 'baseline' or 'mmreg'; "
        "training without MM-Reg (lambda_mm=0.0)."
    )
lambda_mm = CONFIG['lambda_mm'] if experiment == 'mmreg' else 0.0
loss_fn = VAELoss(lambda_mm=lambda_mm)

# Data: both splits share every loader option except the split name
print("Loading data...")
loader_options = dict(
    dataset_name='imagenette',
    root='./data',
    image_size=CONFIG['image_size'],
    batch_size=CONFIG['batch_size'],
    num_workers=CONFIG['num_workers'],
)
train_dataset, train_loader = get_dataset_and_loader(split='train', **loader_options)
val_dataset, val_loader = get_dataset_and_loader(split='val', **loader_options)

print(f"Train: {len(train_dataset)}, Val: {len(val_dataset)}")

# Optimizer
optimizer = torch.optim.AdamW(vae.parameters(), lr=CONFIG['learning_rate'])

# Trainer: the save directory is named after the experiment
save_dir = f"./checkpoints/{CONFIG['experiment']}"
trainer = MMRegTrainer(
    vae=vae,
    reference_model=ref_model,
    loss_fn=loss_fn,
    optimizer=optimizer,
    train_loader=train_loader,
    val_loader=val_loader,
    device=device,
    save_dir=save_dir
)

print(f"\nReady to train {CONFIG['experiment']} for {CONFIG['epochs']} epochs")

In [None]:
# Run training for the number of epochs set in CONFIG
num_epochs = CONFIG['epochs']
trainer.train(num_epochs=num_epochs)

## 5. Evaluation

In [None]:
# Evaluate geometric metrics
from src.analysis.evaluate_geometry import full_evaluation

# Load best checkpoint.
# map_location remaps the stored tensors onto the current device, so a checkpoint
# written on a GPU runtime still loads after switching to a CPU runtime.
checkpoint = torch.load(f"{save_dir}/best.pt", map_location=device)
vae.load_state_dict(checkpoint['vae_state_dict'])

# Run evaluation
results = full_evaluation(
    vae=vae,
    reference_model=ref_model,
    train_loader=train_loader,
    val_loader=val_loader,
    num_classes=10,  # Imagenette has 10 classes
    device=device
)

print("\n" + "="*50)
print("FINAL RESULTS")
print("="*50)
# NOTE(review): the `.4f` format assumes every metric is a scalar number; confirm
# against full_evaluation's return value.
for key, value in results.items():
    print(f"{key}: {value:.4f}")

In [None]:
# Visualize reconstructions
import matplotlib.pyplot as plt

vae.eval()
images, _ = next(iter(val_loader))
# Show up to 8 images. A validation batch can hold fewer (for example when
# CONFIG['batch_size'] was reduced to avoid OOM), and indexing past the end of the
# batch would raise an IndexError.
num_show = min(8, images.shape[0])
images = images[:num_show].to(device)

with torch.no_grad():
    outputs = vae(images, sample=False)
    recon = outputs['x_recon']


def to_display(image_tensor):
    """Convert one CHW image tensor in [-1, 1] to an HWC NumPy array clipped to [0, 1]."""
    return ((image_tensor.cpu().permute(1, 2, 0).numpy() + 1) / 2).clip(0, 1)


# Plot: top row originals, bottom row reconstructions.
# squeeze=False keeps `axes` two-dimensional even when only one column is drawn.
fig, axes = plt.subplots(2, num_show, figsize=(2 * num_show, 4), squeeze=False)

for i in range(num_show):
    # Original
    axes[0, i].imshow(to_display(images[i]))
    axes[0, i].axis('off')
    if i == 0:
        axes[0, i].set_title('Original')

    # Reconstruction
    axes[1, i].imshow(to_display(recon[i]))
    axes[1, i].axis('off')
    if i == 0:
        axes[1, i].set_title('Reconstruction')

plt.tight_layout()
plt.savefig(f"{save_dir}/reconstructions.png", dpi=150)
plt.show()

In [None]:
# Plot training curves
import json

history_path = f"{save_dir}/history.json"
with open(history_path, 'r') as f:
    history = json.load(f)

fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# One panel per metric: (history key, y-axis label, panel title)
panels = [
    ('loss', 'Total Loss', 'Total Loss'),
    ('recon_loss', 'Recon Loss', 'Reconstruction Loss'),
    ('mm_loss', 'MM Loss', 'MM-Reg Loss'),
]

for ax, (metric, y_label, panel_title) in zip(axes, panels):
    # Train curve first, then validation, one point per history entry
    for split, legend_name in (('train', 'Train'), ('val', 'Val')):
        ax.plot([entry[metric] for entry in history[split]], label=legend_name)
    ax.set_xlabel('Epoch')
    ax.set_ylabel(y_label)
    ax.legend()
    ax.set_title(panel_title)

plt.tight_layout()
plt.savefig(f"{save_dir}/training_curves.png", dpi=150)
plt.show()

## 6. Save Results

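Save the configuration and evaluation metrics to the experiment's checkpoint directory so that baseline and MM-Reg runs can be compared later.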

In [None]:
# Save final results
import json

# Bundle the run configuration with the evaluation metrics
final_results = {
    'config': CONFIG,
    'metrics': results
}

results_path = f"{save_dir}/final_results.json"
with open(results_path, 'w') as f:
    # `default=float` is only consulted for values json cannot encode natively,
    # e.g. NumPy or tensor scalars; plain Python numbers are written unchanged.
    # NOTE(review): assumes any such metric is a scalar convertible with float().
    json.dump(final_results, f, indent=2, default=float)

print(f"Results saved to {results_path}")