# MM-Reg: CelebA Experiment with Attribute Interpolation

End-to-end experiment on CelebA (200k faces, 40 attributes) comparing **Baseline VAE** vs **MM-Reg VAE**.

## Pipeline:
1. **Setup**: Install dependencies, load CelebA
2. **Pre-compute PCA**: Reference embeddings for MM-Reg
3. **Train VAEs**: Baseline and MM-Reg versions
4. **Evaluate VAEs**: Reconstruction, distance correlation
5. **Attribute Directions**: Find semantic directions in latent space
6. **VAE Interpolation**: Test attribute manipulation on reconstructions
7. **Train Diffusion**: On latents from both VAEs
8. **Diffusion + Interpolation**: Apply attributes to generated samples

**Hypothesis**: MM-Reg preserves manifold structure → semantic directions are more meaningful → smoother interpolations.

## 1. Setup

In [None]:
# Clone repository
!rm -rf MMReg_diffusion_generative 2>/dev/null
!git clone https://github.com/laurent-cheret/MMReg_diffusion_generative.git
%cd MMReg_diffusion_generative

In [None]:
# Install dependencies
!pip install -q torch torchvision diffusers transformers accelerate
!pip install -q pyyaml tqdm scipy scikit-learn matplotlib
!pip install -q datasets  # For HuggingFace CelebA

In [None]:
import sys
sys.path.insert(0, '.')

import torch
import torch.nn as nn
import os
import json
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm

# Pick the compute device: CUDA when PyTorch reports it available, otherwise CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Record the runtime environment in the notebook output.
print(f"PyTorch: {torch.__version__}")
print(f"Device: {device}")

if device == 'cuda':
    # Name and total memory (bytes / 1e9) of GPU index 0.
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

In [None]:
# Experiment configuration: a single dict that the later cells read their
# hyper-parameters from.
CONFIG = {
    # Data
    'data_root': './data',  # passed as `root` to the dataset / loader helpers
    'image_size': 128,  # CelebA at 128x128
    'batch_size': 64,  # used by the VAE loaders, the encoding loaders and DiffusionTrainer
    
    # PCA
    'pca_components': 256,  # passed as n_components to compute_pca_embeddings_celeba
    
    # VAE Training
    'vae_epochs': 5,  # num_epochs for both VAE trainers
    'vae_lr': 1e-5,  # AdamW learning rate for both VAEs
    'lambda_mm': 1.0,  # lambda_mm given to VAELoss in the MM-Reg run (the baseline passes 0.0)
    'beta': 1e-6,  # beta given to VAELoss in both runs
    
    # Diffusion Training
    'diffusion_epochs': 200,  # num_epochs for both diffusion trainers
    'diffusion_lr': 1e-4,  # AdamW learning rate for both UNets
    'diffusion_timesteps': 1000,  # num_timesteps given to GaussianDiffusion
    
    # Attributes to test for interpolation
    'test_attributes': ['Smiling', 'Eyeglasses', 'Male', 'Young', 'Blond_Hair'],
}

# CelebA attribute names, in the column order used to index the attribute tensors.
# NOTE: a later cell re-imports CELEBA_ATTRIBUTES from src.data.dataset, which
# rebinds this global; get_attr_idx always uses whichever binding is current.
CELEBA_ATTRIBUTES = [
    '5_o_Clock_Shadow', 'Arched_Eyebrows', 'Attractive', 'Bags_Under_Eyes',
    'Bald', 'Bangs', 'Big_Lips', 'Big_Nose', 'Black_Hair', 'Blond_Hair',
    'Blurry', 'Brown_Hair', 'Bushy_Eyebrows', 'Chubby', 'Double_Chin',
    'Eyeglasses', 'Goatee', 'Gray_Hair', 'Heavy_Makeup', 'High_Cheekbones',
    'Male', 'Mouth_Slightly_Open', 'Mustache', 'Narrow_Eyes', 'No_Beard',
    'Oval_Face', 'Pale_Skin', 'Pointy_Nose', 'Receding_Hairline', 'Rosy_Cheeks',
    'Sideburns', 'Smiling', 'Straight_Hair', 'Wavy_Hair', 'Wearing_Earrings',
    'Wearing_Hat', 'Wearing_Lipstick', 'Wearing_Necklace', 'Wearing_Necktie', 'Young'
]

def get_attr_idx(name):
    """Return the position of attribute ``name`` in ``CELEBA_ATTRIBUTES``.

    Args:
        name: Attribute name, e.g. ``'Smiling'`` (case-sensitive).

    Returns:
        Zero-based index of the attribute.

    Raises:
        ValueError: If ``name`` is not a known attribute. The message names the
            offending attribute instead of the bare ``list.index`` error.
    """
    try:
        # Look the list up at call time so a later rebinding of the global is honoured.
        return CELEBA_ATTRIBUTES.index(name)
    except ValueError:
        raise ValueError(
            f"Unknown CelebA attribute {name!r}; expected one of the "
            f"{len(CELEBA_ATTRIBUTES)} names in CELEBA_ATTRIBUTES"
        ) from None

# Echo every configuration entry so the notebook output documents the run.
print("Configuration:")
for k in CONFIG:
    v = CONFIG[k]
    print(f"  {k}: {v}")

## 2. Load CelebA & Pre-compute PCA

In [None]:
from src.data.dataset import (
    get_celeba_dataset,
    compute_pca_embeddings_celeba,
    get_dataset_and_loader,
    CELEBA_ATTRIBUTES
)

# === CONFIGURATION: Choose your data source ===
# Option 1: HuggingFace (recommended, no rate limits).
# NOTE: this flag is not read anywhere below; the source is derived from
# USE_DRIVE alone (HuggingFace is simply the fallback when USE_DRIVE is False).
USE_HUGGINGFACE = True

# Option 2: Google Drive (if you have CelebA saved there).
# Point these paths at your own Drive copy before enabling USE_DRIVE.
USE_DRIVE = False
DRIVE_PATHS = {
    'images_zip': '/content/drive/MyDrive/DATASETS/CelebA/img_align_celeba.zip',
    'attr_file': '/content/drive/MyDrive/DATASETS/CelebA/list_attr_celeba.txt',
    'partition_file': '/content/drive/MyDrive/DATASETS/CelebA/list_eval_partition.txt'
}

if USE_DRIVE:
    # Mount Google Drive first (Colab only), then select the Drive-backed source.
    from google.colab import drive
    drive.mount('/content/drive')
    source, drive_paths = 'drive', DRIVE_PATHS
else:
    source, drive_paths = 'huggingface', None

# Load both splits with fixed transforms for PCA; only `split` differs between them.
print(f"Loading CelebA from {source}...")
fixed_dataset_kwargs = dict(
    root=CONFIG['data_root'],
    image_size=CONFIG['image_size'],
    fixed_transform=True,
    source=source,
    drive_paths=drive_paths,
)
train_dataset_fixed = get_celeba_dataset(split='train', **fixed_dataset_kwargs)
val_dataset_fixed = get_celeba_dataset(split='val', **fixed_dataset_kwargs)

print(f"Train: {len(train_dataset_fixed)}, Val: {len(val_dataset_fixed)}")

In [None]:
# Compute the PCA reference embeddings for each split and cache them on disk,
# where the training loaders later pick them up via pca_embeddings_path.
# NOTE(review): each split is embedded by its own call; whether the validation
# call fits a separate PCA basis depends on compute_pca_embeddings_celeba — confirm.
os.makedirs('./embeddings', exist_ok=True)

pca_kwargs = {'n_components': CONFIG['pca_components'], 'batch_size': 64}

print("Computing PCA embeddings for training set...")
train_pca = compute_pca_embeddings_celeba(train_dataset_fixed, **pca_kwargs)
torch.save(train_pca, './embeddings/celeba_train_pca.pt')

print("\nComputing PCA embeddings for validation set...")
val_pca = compute_pca_embeddings_celeba(val_dataset_fixed, **pca_kwargs)
torch.save(val_pca, './embeddings/celeba_val_pca.pt')

print(f"\nTrain PCA: {train_pca.shape}, Val PCA: {val_pca.shape}")

## 3. Train VAEs

### 3.1 Baseline VAE (No MM-Reg)

In [None]:
from src.models.vae_wrapper import load_vae
from src.models.losses import VAELoss
from src.trainer import MMRegTrainer

print("=" * 60)
print("TRAINING BASELINE VAE (no MM-Reg)")
print("=" * 60)

# Fresh VAE plus a loss with lambda_mm=0.0, i.e. without the MM-Reg term.
vae_baseline = load_vae(device=device)
loss_baseline = VAELoss(lambda_mm=0.0, beta=CONFIG['beta'])

# Train and val loaders share everything except the split and the PCA file
# (same data source as the earlier dataset-loading cell).
baseline_loader_kwargs = dict(
    dataset_name='celeba',
    root=CONFIG['data_root'],
    image_size=CONFIG['image_size'],
    batch_size=CONFIG['batch_size'],
    num_workers=2,
    celeba_source=source,
    celeba_drive_paths=drive_paths,
)
train_dataset_base, train_loader_base = get_dataset_and_loader(
    split='train',
    pca_embeddings_path='./embeddings/celeba_train_pca.pt',
    **baseline_loader_kwargs,
)
val_dataset_base, val_loader_base = get_dataset_and_loader(
    split='val',
    pca_embeddings_path='./embeddings/celeba_val_pca.pt',
    **baseline_loader_kwargs,
)

# Optimizer over all VAE parameters, then the trainer that writes to save_dir.
optimizer_baseline = torch.optim.AdamW(
    vae_baseline.parameters(),
    lr=CONFIG['vae_lr'],
)
trainer_baseline = MMRegTrainer(
    vae=vae_baseline,
    loss_fn=loss_baseline,
    optimizer=optimizer_baseline,
    train_loader=train_loader_base,
    val_loader=val_loader_base,
    device=device,
    save_dir='./checkpoints/celeba_baseline_vae',
)

# Run the training loop.
trainer_baseline.train(num_epochs=CONFIG['vae_epochs'])

### 3.2 MM-Reg VAE

In [None]:
print("=" * 60)
print("TRAINING MM-REG VAE")
print("=" * 60)

# Fresh VAE plus a loss with the MM-Reg term enabled ('correlation' variant).
vae_mmreg = load_vae(device=device)
loss_mmreg = VAELoss(
    lambda_mm=CONFIG['lambda_mm'],
    beta=CONFIG['beta'],
    mm_variant='correlation',
)

# Train and val loaders share everything except the split and the PCA file
# (same data source as the earlier dataset-loading cell).
mmreg_loader_kwargs = dict(
    dataset_name='celeba',
    root=CONFIG['data_root'],
    image_size=CONFIG['image_size'],
    batch_size=CONFIG['batch_size'],
    num_workers=2,
    celeba_source=source,
    celeba_drive_paths=drive_paths,
)
train_dataset_mm, train_loader_mm = get_dataset_and_loader(
    split='train',
    pca_embeddings_path='./embeddings/celeba_train_pca.pt',
    **mmreg_loader_kwargs,
)
val_dataset_mm, val_loader_mm = get_dataset_and_loader(
    split='val',
    pca_embeddings_path='./embeddings/celeba_val_pca.pt',
    **mmreg_loader_kwargs,
)

# Optimizer over all VAE parameters, then the trainer that writes to save_dir.
optimizer_mmreg = torch.optim.AdamW(
    vae_mmreg.parameters(),
    lr=CONFIG['vae_lr'],
)
trainer_mmreg = MMRegTrainer(
    vae=vae_mmreg,
    loss_fn=loss_mmreg,
    optimizer=optimizer_mmreg,
    train_loader=train_loader_mm,
    val_loader=val_loader_mm,
    device=device,
    save_dir='./checkpoints/celeba_mmreg_vae',
)

# Run the training loop.
trainer_mmreg.train(num_epochs=CONFIG['vae_epochs'])

## 4. Evaluate VAEs

In [None]:
from src.models.losses import pairwise_distances, get_upper_triangular
from scipy.stats import pearsonr, spearmanr

def evaluate_vae_celeba(vae, val_loader, name):
    """Measure reconstruction error and latent/PCA distance correlation.

    Runs the VAE deterministically (``sample=False``) over ``val_loader``,
    whose batches unpack as ``(images, attrs, pca_emb)``.

    Args:
        vae: Model called as ``vae(images, sample=False)``; its output must
            provide ``'x_recon'`` and ``'latent_flat'``.
        val_loader: Iterable of validation batches.
        name: Label used in the progress bar and the printed report.

    Returns:
        Dict with ``'recon_mse'`` (batch-size-weighted mean of the per-batch
        MSE), ``'pearson_corr'`` and ``'spearman_corr'`` (between pairwise
        latent and PCA distances of the first 500 samples at most), plus the
        concatenated ``'latents'`` and ``'attrs'``.
    """
    vae.eval()

    latent_chunks, pca_chunks, attr_chunks = [], [], []
    weighted_error = 0
    seen = 0

    with torch.no_grad():
        for images, attrs, pca_emb in tqdm(val_loader, desc=f"Evaluating {name}"):
            images = images.to(device)
            outputs = vae(images, sample=False)

            # Per-batch mean error, re-weighted by batch size so a smaller
            # final batch does not skew the overall average.
            batch_size = images.shape[0]
            batch_mse = ((outputs['x_recon'] - images) ** 2).mean().item()
            weighted_error += batch_mse * batch_size
            seen += batch_size

            latent_chunks.append(outputs['latent_flat'].cpu())
            pca_chunks.append(pca_emb)
            attr_chunks.append(attrs)

    all_latents = torch.cat(latent_chunks, dim=0)
    all_pca = torch.cat(pca_chunks, dim=0)
    all_attrs = torch.cat(attr_chunks, dim=0)

    # Distance correlation on a leading subset only (for speed).
    n = min(500, len(all_latents))
    d_latent = get_upper_triangular(pairwise_distances(all_latents[:n])).numpy()
    d_pca = get_upper_triangular(pairwise_distances(all_pca[:n])).numpy()

    pearson, _pearson_p = pearsonr(d_latent, d_pca)
    spearman, _spearman_p = spearmanr(d_latent, d_pca)

    results = {
        'recon_mse': weighted_error / seen,
        'pearson_corr': pearson,
        'spearman_corr': spearman,
        'latents': all_latents,
        'attrs': all_attrs
    }

    print(f"\n{name} Results:")
    print(f"  Reconstruction MSE: {results['recon_mse']:.6f}")
    print(f"  Distance Pearson:   {results['pearson_corr']:.4f}")
    print(f"  Distance Spearman:  {results['spearman_corr']:.4f}")

    return results

# Evaluate both VAEs on their own validation loaders. Each result dict also
# carries the encoded latents and attributes that the attribute-direction
# cells read later.
results_baseline = evaluate_vae_celeba(vae_baseline, val_loader_base, "Baseline VAE")
results_mmreg = evaluate_vae_celeba(vae_mmreg, val_loader_mm, "MM-Reg VAE")

In [None]:
# Visualize reconstructions
def plot_reconstructions_celeba(vae, val_loader, title, save_path):
    """Plot original images (top row) above their VAE reconstructions (bottom row).

    Uses up to eight images from the first batch of ``val_loader`` and
    reconstructs them deterministically (``sample=False``). The figure is
    saved to ``save_path`` and then shown.

    Args:
        vae: Model called as ``vae(images, sample=False)`` returning a dict
            with an ``'x_recon'`` entry.
        val_loader: Loader whose batches have the images as first element.
        title: Figure title.
        save_path: Destination file for the saved figure.
    """
    vae.eval()
    batch = next(iter(val_loader))
    images = batch[0][:8].to(device)

    with torch.no_grad():
        outputs = vae(images, sample=False)
        recon = outputs['x_recon']

    # Size the grid from the images actually available, so a first batch with
    # fewer than eight images no longer raises IndexError.
    n_show = images.shape[0]
    fig, axes = plt.subplots(2, n_show, figsize=(2 * n_show, 4), squeeze=False)
    fig.suptitle(title, fontsize=14)

    def _show(ax, tensor):
        """Draw one CHW image, mapped from [-1, 1] to [0, 1], without ticks or frame."""
        img = tensor.cpu().permute(1, 2, 0).numpy()
        ax.imshow(((img + 1) / 2).clip(0, 1))
        # Hide ticks and spines instead of calling axis('off'): turning the
        # axis off also suppresses the y-labels set below.
        ax.set_xticks([])
        ax.set_yticks([])
        for spine in ax.spines.values():
            spine.set_visible(False)

    for i in range(n_show):
        _show(axes[0, i], images[i])
        _show(axes[1, i], recon[i])

    axes[0, 0].set_ylabel('Original', fontsize=12)
    axes[1, 0].set_ylabel('Recon', fontsize=12)

    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.show()

# Save one reconstruction figure per model, each next to that model's
# checkpoints (the save_dir passed to its trainer).
plot_reconstructions_celeba(vae_baseline, val_loader_base, "Baseline VAE Reconstructions", 
                            "./checkpoints/celeba_baseline_vae/reconstructions.png")
plot_reconstructions_celeba(vae_mmreg, val_loader_mm, "MM-Reg VAE Reconstructions",
                            "./checkpoints/celeba_mmreg_vae/reconstructions.png")

## 5. Compute Attribute Directions

For each attribute, compute the direction in latent space:
```
d_attr = mean(z[attr=1]) - mean(z[attr=0])
```

In [None]:
def compute_attribute_directions(latents, attrs, attr_names):
    """
    Derive one latent-space direction per requested attribute.

    The direction is the difference between the mean latent of samples that
    have the attribute (value > 0.5) and the mean latent of those that do not
    (value <= 0.5). A summary line is printed for every attribute.

    Args:
        latents: (N, D) tensor of latent vectors
        attrs: (N, 40) tensor of attribute values, thresholded at 0.5
        attr_names: List of attribute names to compute directions for

    Returns:
        Dictionary mapping attribute name to direction vector
    """
    directions = {}

    for attr_name in attr_names:
        column = attrs[:, get_attr_idx(attr_name)]

        # Boolean masks for the two groups.
        has_attr = column > 0.5
        no_attr = column <= 0.5

        # Adding this vector to a latent should move it towards the attribute.
        direction = latents[has_attr].mean(dim=0) - latents[no_attr].mean(dim=0)
        directions[attr_name] = direction

        print(f"{attr_name}: {has_attr.sum().item()} with, {no_attr.sum().item()} without, ||d||={direction.norm():.2f}")

    return directions

# Attribute directions are computed separately for each VAE, from the latents
# and attributes collected during evaluation, for the attributes listed in
# CONFIG['test_attributes'].
print("Computing attribute directions for Baseline VAE...")
directions_baseline = compute_attribute_directions(
    results_baseline['latents'],
    results_baseline['attrs'],
    CONFIG['test_attributes']
)

print("\nComputing attribute directions for MM-Reg VAE...")
directions_mmreg = compute_attribute_directions(
    results_mmreg['latents'],
    results_mmreg['attrs'],
    CONFIG['test_attributes']
)

## 6. VAE Attribute Interpolation

Test whether adding an attribute direction to an encoded latent produces the expected change in the decoded image.

In [None]:
def interpolate_attribute(vae, image, direction, alphas, device):
    """
    Decode an image's latent after shifting it along an attribute direction.

    The image is encoded once (``sample=False``); for every strength in
    ``alphas`` the flat latent is offset by ``alpha * direction``, viewed
    back in the encoder's latent shape and decoded.

    Args:
        vae: VAE model (callable, with a ``decode`` method)
        image: (1, C, H, W) input image
        direction: (D,) attribute direction vector
        alphas: List of interpolation strengths
        device: Device

    Returns:
        List of decoded images (on CPU), one per alpha
    """
    vae.eval()
    image = image.to(device)
    step = direction.to(device).unsqueeze(0)  # (1, D), broadcasts over the batch

    with torch.no_grad():
        encoded = vae(image, sample=False)
        z = encoded['latent_flat']  # (1, D)
        # The decoder expects the encoder's latent layout, not the flat vector.
        latent_shape = encoded['latent'].shape

        return [
            vae.decode((z + alpha * step).view(latent_shape)).cpu()
            for alpha in alphas
        ]


def plot_interpolation(vae, val_loader, directions, attr_name, title_prefix, save_path):
    """
    Show how up to four validation images change along an attribute direction.

    Images are taken from the first batch of ``val_loader`` and restricted to
    those that do NOT have the attribute. Each row is one image, each column
    one interpolation strength. The figure is saved to ``save_path`` and shown.
    Prints a message and returns early when the batch has no such image.
    """
    attr_idx = get_attr_idx(attr_name)

    images, attrs, _ = next(iter(val_loader))

    # Keep at most four images without the attribute.
    test_images = images[attrs[:, attr_idx] <= 0.5][:4]
    if len(test_images) == 0:
        print(f"No test images without {attr_name}")
        return

    alphas = [-1.0, -0.5, 0.0, 0.5, 1.0, 1.5, 2.0]
    direction = directions[attr_name]

    n_rows, n_cols = len(test_images), len(alphas)
    # squeeze=False keeps `axes` two-dimensional even for a single row.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(2*n_cols, 2*n_rows), squeeze=False)
    fig.suptitle(f"{title_prefix}: {attr_name} Interpolation", fontsize=14)

    for row, img in enumerate(test_images):
        decoded = interpolate_attribute(vae, img.unsqueeze(0), direction, alphas, device)

        for col, (alpha, result) in enumerate(zip(alphas, decoded)):
            # CHW in [-1, 1] -> HWC in [0, 1] for imshow.
            img_np = result[0].permute(1, 2, 0).numpy()
            img_np = ((img_np + 1) / 2).clip(0, 1)

            ax = axes[row, col]
            ax.imshow(img_np)
            ax.axis('off')
            if row == 0:
                ax.set_title(f"α={alpha}")

    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.show()

In [None]:
# Test interpolation for each attribute.
# Create the output folder up front because plot_interpolation saves straight
# into it.
os.makedirs('./checkpoints/interpolations', exist_ok=True)

for attr_name in CONFIG['test_attributes']:
    print(f"\nTesting {attr_name} interpolation...")
    
    # Each model is paired with its own validation loader and direction set.
    plot_interpolation(
        vae_baseline, val_loader_base, directions_baseline, attr_name,
        "Baseline VAE", f"./checkpoints/interpolations/baseline_{attr_name}.png"
    )
    
    plot_interpolation(
        vae_mmreg, val_loader_mm, directions_mmreg, attr_name,
        "MM-Reg VAE", f"./checkpoints/interpolations/mmreg_{attr_name}.png"
    )

## 7. Train Diffusion Models

### 7.1 Encode Datasets to Latents

In [None]:
from src.diffusion_trainer import encode_dataset
from torch.utils.data import DataLoader

# Plain, unshuffled loaders over the fixed-transform datasets (no PCA wrapper),
# used only to encode the images into latents.
simple_loader_kwargs = dict(
    batch_size=CONFIG['batch_size'],
    shuffle=False,
    num_workers=2,
)
train_loader_simple = DataLoader(train_dataset_fixed, **simple_loader_kwargs)
val_loader_simple = DataLoader(val_dataset_fixed, **simple_loader_kwargs)

# Encode both splits with the baseline VAE.
print("Encoding dataset with Baseline VAE...")
train_latents_baseline = encode_dataset(vae_baseline, train_loader_simple, device)
val_latents_baseline = encode_dataset(vae_baseline, val_loader_simple, device)
print(f"Baseline latents - Train: {train_latents_baseline.shape}, Val: {val_latents_baseline.shape}")

# Encode the same splits with the MM-Reg VAE.
print("\nEncoding dataset with MM-Reg VAE...")
train_latents_mmreg = encode_dataset(vae_mmreg, train_loader_simple, device)
val_latents_mmreg = encode_dataset(vae_mmreg, val_loader_simple, device)
print(f"MM-Reg latents - Train: {train_latents_mmreg.shape}, Val: {val_latents_mmreg.shape}")

### 7.2 Train Diffusion on Baseline Latents

In [None]:
from src.models.diffusion import SimpleUNet, GaussianDiffusion
from src.diffusion_trainer import DiffusionTrainer

# Spatial size of the encoded latents, read from dimension 2 of the latent
# tensor (expected to be 16 for 128x128 images).
latent_size = train_latents_baseline.shape[2]
print(f"Latent size: {latent_size}x{latent_size}")

print("=" * 60)
print("TRAINING DIFFUSION ON BASELINE LATENTS")
print("=" * 60)

diffusion_baseline = GaussianDiffusion(
    num_timesteps=CONFIG['diffusion_timesteps'],
    device=device,
)

# UNet over 4-channel latents with channel_mult=(1, 2, 4).
# NOTE(review): the earlier notes say this multiplier set still works for the
# smaller 16x16 latents (16->8->4); the actual downsampling schedule lives in
# SimpleUNet — confirm there.
unet_baseline_config = dict(
    in_channels=4,
    base_channels=128,
    channel_mult=(1, 2, 4),
    num_res_blocks=2,
)
unet_baseline = SimpleUNet(**unet_baseline_config).to(device)

optimizer_diff_base = torch.optim.AdamW(
    unet_baseline.parameters(),
    lr=CONFIG['diffusion_lr'],
)

trainer_diff_baseline = DiffusionTrainer(
    model=unet_baseline,
    diffusion=diffusion_baseline,
    optimizer=optimizer_diff_base,
    train_latents=train_latents_baseline,
    val_latents=val_latents_baseline,
    batch_size=CONFIG['batch_size'],
    device=device,
    save_dir='./checkpoints/celeba_diffusion_baseline',
)

trainer_diff_baseline.train(num_epochs=CONFIG['diffusion_epochs'])

### 7.3 Train Diffusion on MM-Reg Latents

In [None]:
print("=" * 60)
print("TRAINING DIFFUSION ON MM-REG LATENTS")
print("=" * 60)

diffusion_mmreg = GaussianDiffusion(
    num_timesteps=CONFIG['diffusion_timesteps'],
    device=device,
)

# Same UNet hyper-parameters as the baseline diffusion run, so the two runs
# differ only in the latents they are trained on.
unet_mmreg_config = dict(
    in_channels=4,
    base_channels=128,
    channel_mult=(1, 2, 4),
    num_res_blocks=2,
)
unet_mmreg = SimpleUNet(**unet_mmreg_config).to(device)

optimizer_diff_mm = torch.optim.AdamW(
    unet_mmreg.parameters(),
    lr=CONFIG['diffusion_lr'],
)

trainer_diff_mmreg = DiffusionTrainer(
    model=unet_mmreg,
    diffusion=diffusion_mmreg,
    optimizer=optimizer_diff_mm,
    train_latents=train_latents_mmreg,
    val_latents=val_latents_mmreg,
    batch_size=CONFIG['batch_size'],
    device=device,
    save_dir='./checkpoints/celeba_diffusion_mmreg',
)

trainer_diff_mmreg.train(num_epochs=CONFIG['diffusion_epochs'])

## 8. Generate Samples & Apply Attribute Directions

In [None]:
# Generate 16 latent samples from each diffusion model; they are decoded with
# the matching VAE in the following cells.
print("Generating samples from Baseline Diffusion...")
# Shape for 128x128 images: (N, 4, 16, 16)
# NOTE(review): latent_shape is assigned here but not passed to
# generate_samples or read by any other visible cell — it looks unused.
latent_shape = (16, 4, latent_size, latent_size)
samples_baseline = trainer_diff_baseline.generate_samples(num_samples=16)

print("\nGenerating samples from MM-Reg Diffusion...")
samples_mmreg = trainer_diff_mmreg.generate_samples(num_samples=16)

In [None]:
# Decode and visualize generated samples
def decode_and_plot_celeba(vae, latents, title, save_path):
    """Decode latents with ``vae`` and plot up to 16 images on a 2x8 grid.

    The figure is saved to ``save_path`` and then shown.

    Args:
        vae: Model exposing ``decode(latents)``.
        latents: Batch of latents; moved to the global ``device`` before decoding.
        title: Figure title.
        save_path: Destination file for the saved figure.
    """
    vae.eval()
    with torch.no_grad():
        images = vae.decode(latents.to(device))

    rows, cols = 2, 8
    n = min(rows * cols, images.shape[0])
    fig, axes = plt.subplots(rows, cols, figsize=(16, 4))
    fig.suptitle(title, fontsize=14)

    # Blank every cell first so that cells left unused when fewer than 16
    # images are available do not show empty framed axes.
    for ax in axes.flat:
        ax.axis('off')

    for i in range(n):
        # CHW in [-1, 1] -> HWC in [0, 1] for imshow.
        img = images[i].cpu().permute(1, 2, 0).numpy()
        img = ((img + 1) / 2).clip(0, 1)
        axes[i // cols, i % cols].imshow(img)

    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.show()

# Each set of generated latents is decoded with the VAE whose latent space the
# corresponding diffusion model was trained on.
decode_and_plot_celeba(vae_baseline, samples_baseline, 
                       "Generated Faces (Baseline VAE + Diffusion)",
                       "./checkpoints/celeba_diffusion_baseline/generated_samples.png")

decode_and_plot_celeba(vae_mmreg, samples_mmreg,
                       "Generated Faces (MM-Reg VAE + Diffusion)",
                       "./checkpoints/celeba_diffusion_mmreg/generated_samples.png")

In [None]:
# Apply attribute directions to diffusion-generated samples
def apply_direction_to_generated(vae, latent_samples, direction, alphas, device):
    """
    Apply attribute direction to diffusion-generated latents.

    Every latent is flattened, offset by ``alpha * direction`` and restored to
    its original shape before decoding.

    Args:
        vae: Model exposing ``decode(latents)``.
        latent_samples: Batch of latents; the first dimension is the batch.
        direction: Flat direction vector with one entry per latent element
            of a single sample.
        alphas: Iterable of strengths to apply.
        device: Device the computation runs on.

    Returns:
        List with one decoded batch (on CPU) per alpha, in order.
    """
    vae.eval()
    latent_samples = latent_samples.to(device)
    direction = direction.to(device)

    # Flatten once, outside the loop. reshape (unlike view) also accepts
    # non-contiguous inputs such as permuted or strided latents.
    z_flat = latent_samples.reshape(latent_samples.shape[0], -1)
    step = direction.unsqueeze(0)  # broadcasts over the batch dimension

    results = []
    with torch.no_grad():
        for alpha in alphas:
            z_new = (z_flat + alpha * step).reshape(latent_samples.shape)

            # Decode
            images = vae.decode(z_new)
            results.append(images.cpu())

    return results


def plot_generated_interpolation(vae, latent_samples, directions, attr_name, title_prefix, save_path):
    """
    Plot attribute interpolation on diffusion-generated samples.

    Uses at most the first four latents. Each row is one sample, each column
    one strength from a fixed list. The figure is saved to ``save_path`` and
    shown. Prints a message and returns early if there are no samples.

    Args:
        vae: Model used to decode the shifted latents.
        latent_samples: Batch of generated latents.
        directions: Mapping from attribute name to direction vector.
        attr_name: Attribute whose direction is applied.
        title_prefix: Text placed before the attribute name in the title.
        save_path: Destination file for the saved figure.
    """
    alphas = [-1.0, 0.0, 1.0, 2.0]
    direction = directions[attr_name]

    # Use the first 4 samples, or fewer when fewer were generated; the grid is
    # sized from the actual count instead of assuming four rows.
    test_latents = latent_samples[:4]
    n_rows = len(test_latents)
    if n_rows == 0:
        print(f"No generated samples to plot for {attr_name}")
        return

    results = apply_direction_to_generated(vae, test_latents, direction, alphas, device)

    # squeeze=False keeps `axes` two-dimensional even for a single row.
    fig, axes = plt.subplots(
        n_rows, len(alphas), figsize=(2*len(alphas), 2*n_rows), squeeze=False
    )
    fig.suptitle(f"{title_prefix}: {attr_name} on Generated Samples", fontsize=14)

    for i in range(n_rows):
        for j, (alpha, imgs) in enumerate(zip(alphas, results)):
            # CHW in [-1, 1] -> HWC in [0, 1] for imshow.
            img_np = imgs[i].permute(1, 2, 0).numpy()
            img_np = ((img_np + 1) / 2).clip(0, 1)

            axes[i, j].imshow(img_np)
            axes[i, j].axis('off')
            if i == 0:
                axes[i, j].set_title(f"α={alpha}")

    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.show()

In [None]:
# Test attribute manipulation on generated samples.
# The directions applied here are the ones computed earlier from each VAE's
# encoded validation latents; figures go to ./checkpoints/interpolations.
for attr_name in CONFIG['test_attributes'][:3]:  # Test first 3 attributes
    print(f"\nApplying {attr_name} to generated samples...")
    
    plot_generated_interpolation(
        vae_baseline, samples_baseline, directions_baseline, attr_name,
        "Baseline", f"./checkpoints/interpolations/gen_baseline_{attr_name}.png"
    )
    
    plot_generated_interpolation(
        vae_mmreg, samples_mmreg, directions_mmreg, attr_name,
        "MM-Reg", f"./checkpoints/interpolations/gen_mmreg_{attr_name}.png"
    )

## 9. Training Curves & Summary

In [None]:
# Plot training curves: VAE losses on the left panel, diffusion losses on the right.
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
vae_ax, diffusion_ax = axes[0], axes[1]

# VAE losses, read from the history files in each trainer's save_dir.
with open('./checkpoints/celeba_baseline_vae/history.json') as f:
    baseline_vae_hist = json.load(f)
with open('./checkpoints/celeba_mmreg_vae/history.json') as f:
    mmreg_vae_hist = json.load(f)

# One curve per (history, split): blue = baseline, red = MM-Reg, dashed = validation.
for history, split, line_style, label in (
    (baseline_vae_hist, 'train', 'b-', 'Baseline Train'),
    (baseline_vae_hist, 'val', 'b--', 'Baseline Val'),
    (mmreg_vae_hist, 'train', 'r-', 'MM-Reg Train'),
    (mmreg_vae_hist, 'val', 'r--', 'MM-Reg Val'),
):
    vae_ax.plot([h['loss'] for h in history[split]], line_style, label=label)
vae_ax.set_xlabel('Epoch')
vae_ax.set_ylabel('Loss')
vae_ax.set_title('VAE Training')
vae_ax.legend()
vae_ax.grid(True, alpha=0.3)

# Diffusion losses
with open('./checkpoints/celeba_diffusion_baseline/history.json') as f:
    baseline_diff_hist = json.load(f)
with open('./checkpoints/celeba_diffusion_mmreg/history.json') as f:
    mmreg_diff_hist = json.load(f)

for history, split, line_style, label in (
    (baseline_diff_hist, 'train', 'b-', 'Baseline Train'),
    (baseline_diff_hist, 'val', 'b--', 'Baseline Val'),
    (mmreg_diff_hist, 'train', 'r-', 'MM-Reg Train'),
    (mmreg_diff_hist, 'val', 'r--', 'MM-Reg Val'),
):
    diffusion_ax.plot([h['loss'] for h in history[split]], line_style, label=label)
diffusion_ax.set_xlabel('Epoch')
diffusion_ax.set_ylabel('Loss')
diffusion_ax.set_title('Diffusion Training')
diffusion_ax.legend()
diffusion_ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('./checkpoints/celeba_training_comparison.png', dpi=150)
plt.show()

In [None]:
# Final summary: banner printed ahead of the metrics collected below.
print("="*60)
print("CELEBA EXPERIMENT SUMMARY")
print("="*60)

def to_python(val):
    """Convert ``val`` to a plain Python number for JSON serialisation.

    Objects exposing ``.item()`` are unwrapped through it; anything else is
    passed to ``float``.
    """
    return val.item() if hasattr(val, 'item') else float(val)

# Collect the headline metrics of both pipelines into one JSON-serialisable dict.
metric_keys = ('recon_mse', 'pearson_corr', 'spearman_corr')

summary = {
    'config': CONFIG,
    'vae_results': {
        'baseline': {key: to_python(results_baseline[key]) for key in metric_keys},
        'mmreg': {key: to_python(results_mmreg[key]) for key in metric_keys},
    },
    # Loss of the last recorded epoch for each diffusion run.
    'diffusion_final_loss': {
        'baseline_train': to_python(baseline_diff_hist['train'][-1]['loss']),
        'baseline_val': to_python(baseline_diff_hist['val'][-1]['loss']),
        'mmreg_train': to_python(mmreg_diff_hist['train'][-1]['loss']),
        'mmreg_val': to_python(mmreg_diff_hist['val'][-1]['loss']),
    },
}

print("\nVAE Comparison:")
print(f"  Baseline - Recon MSE: {results_baseline['recon_mse']:.6f}, Pearson: {results_baseline['pearson_corr']:.4f}")
print(f"  MM-Reg   - Recon MSE: {results_mmreg['recon_mse']:.6f}, Pearson: {results_mmreg['pearson_corr']:.4f}")

print("\nDiffusion Final Val Loss:")
print(f"  Baseline: {summary['diffusion_final_loss']['baseline_val']:.6f}")
print(f"  MM-Reg:   {summary['diffusion_final_loss']['mmreg_val']:.6f}")

# Relative drop in final validation loss, as a percentage of the baseline's.
final_losses = summary['diffusion_final_loss']
baseline_val = final_losses['baseline_val']
mmreg_val = final_losses['mmreg_val']
improvement = (baseline_val - mmreg_val) / baseline_val * 100
print(f"\nMM-Reg diffusion improvement: {improvement:.1f}%")

# Save summary
with open('./checkpoints/celeba_experiment_summary.json', 'w') as f:
    json.dump(summary, f, indent=2)

print("\nResults saved to ./checkpoints/celeba_experiment_summary.json")