# MM-Reg: CelebA Experiment with Bottleneck Architecture

Two-stage latent architecture for sharp reconstruction AND stable interpolation.

## Architecture:
```
Image → SD VAE encoder → 4x16x16 (spatial) → BottleneckVAE → 256-d (semantic) → BottleneckVAE decoder → 4x16x16 → SD VAE decoder → Image
```
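
The shape bookkeeping behind this diagram can be checked with a few lines of arithmetic. This is a minimal sketch: it assumes the SD VAE's usual 8x spatial downsampling and 4 latent channels, plus the 128x128 input size and 256-d bottleneck configured later in this notebook. `latent_shape` is a hypothetical helper for illustration, not part of the repository.

```python
def latent_shape(image_size, downsample=8, latent_channels=4):
    """Return the (C, H, W) shape of the SD VAE latent for a square image.

    Assumes the standard SD VAE layout: 8x spatial downsampling, 4 channels.
    """
    if image_size % downsample != 0:
        raise ValueError("image_size must be divisible by the downsampling factor")
    side = image_size // downsample
    return (latent_channels, side, side)


image_size = 128       # CONFIG['image_size'] in this notebook
bottleneck_dim = 256   # CONFIG['bottleneck_dim'] in this notebook

c, h, w = latent_shape(image_size)
pixel_values = 3 * image_size * image_size   # RGB image
spatial_values = c * h * w                   # flattened SD latent

print(f"SD latent shape: {c}x{h}x{w} ({spatial_values} values)")
print(f"pixels -> SD latent: {pixel_values // spatial_values}x fewer values")
print(f"SD latent -> bottleneck: {spatial_values // bottleneck_dim}x fewer values")
```

Under these assumptions the bottleneck only has to compress the 1024-value spatial latent by a further factor of 4, which is why it can be a small network trained on cached latents.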

## Pipeline:
1. **Setup**: Install dependencies, load CelebA
2. **Pre-compute PCA**: Reference embeddings for MM-Reg
3. **Fine-tune SD VAE**: One shared VAE on CelebA (reconstruction only)
4. **Encode to SD latents**: Cache all images as 4x16x16 latents
5. **Train Bottleneck VAEs**: Baseline (no MM-Reg) vs MM-Reg on 4x16x16 → 256-d
6. **Evaluate Bottlenecks**: End-to-end reconstruction and distance correlation in 256-d
7. **Attribute Directions**: Find directions in 256-d bottleneck space
8. **Interpolation**: Test attribute manipulation in bottleneck space
9. **Train Diffusion**: On 256-d bottleneck vectors
10. **Generate + Interpolate**: Apply attributes to diffusion-generated faces

**Key insight**: The 256-d bottleneck forces semantic compression (like the old flat VAE),
while the fine-tuned SD decoder, kept frozen during bottleneck training, provides sharp reconstruction quality.

## 1. Setup

In [None]:
# Clone repository
!rm -rf MMReg_diffusion_generative 2>/dev/null
!git clone https://github.com/laurent-cheret/MMReg_diffusion_generative.git
%cd MMReg_diffusion_generative

In [None]:
# Install dependencies
!pip install -q torch torchvision diffusers transformers accelerate
!pip install -q pyyaml tqdm scipy scikit-learn matplotlib
!pip install -q datasets  # For HuggingFace CelebA
!pip install -q gdown  # For Google Drive fallback

In [None]:
import sys
sys.path.insert(0, '.')

import torch
import torch.nn as nn
import os
import json
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"PyTorch: {torch.__version__}")
print(f"Device: {device}")
if device == 'cuda':
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

In [None]:
# Experiment configuration
CONFIG = {
    # Data
    'data_root': './data',
    'image_size': 128,  # CelebA at 128x128
    'batch_size': 64,
    
    # PCA
    'pca_components': 256,
    
    # SD VAE fine-tuning (shared, reconstruction only)
    'sd_vae_epochs': 5,
    'sd_vae_lr': 1e-5,
    
    # Bottleneck VAE
    'bottleneck_dim': 256,
    'bottleneck_hidden': 512,
    'bottleneck_epochs': 100,
    'bottleneck_lr': 1e-3,
    'bottleneck_batch_size': 256,
    'lambda_mm': 1.0,
    'beta': 0.001,
    
    # Diffusion Training (on 256-d bottleneck vectors)
    'diffusion_epochs': 200,
    'diffusion_lr': 1e-4,
    'diffusion_timesteps': 1000,
    
    # Attributes to test for interpolation
    'test_attributes': ['Smiling', 'Eyeglasses', 'Male', 'Young', 'Blond_Hair'],
}

# CelebA attribute names
CELEBA_ATTRIBUTES = [
    '5_o_Clock_Shadow', 'Arched_Eyebrows', 'Attractive', 'Bags_Under_Eyes',
    'Bald', 'Bangs', 'Big_Lips', 'Big_Nose', 'Black_Hair', 'Blond_Hair',
    'Blurry', 'Brown_Hair', 'Bushy_Eyebrows', 'Chubby', 'Double_Chin',
    'Eyeglasses', 'Goatee', 'Gray_Hair', 'Heavy_Makeup', 'High_Cheekbones',
    'Male', 'Mouth_Slightly_Open', 'Mustache', 'Narrow_Eyes', 'No_Beard',
    'Oval_Face', 'Pale_Skin', 'Pointy_Nose', 'Receding_Hairline', 'Rosy_Cheeks',
    'Sideburns', 'Smiling', 'Straight_Hair', 'Wavy_Hair', 'Wearing_Earrings',
    'Wearing_Hat', 'Wearing_Lipstick', 'Wearing_Necklace', 'Wearing_Necktie', 'Young'
]

def get_attr_idx(name):
    """Get index of attribute by name."""
    return CELEBA_ATTRIBUTES.index(name)

print("Configuration:")
for k, v in CONFIG.items():
    print(f"  {k}: {v}")

## 2. Load CelebA & Pre-compute PCA
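
The PCA reference embeddings can be pictured as a projection of flattened, mean-centred images onto their top principal components. The sketch below is illustrative only: the internals of `compute_pca_embeddings_celeba` are not shown in this notebook, and `pca_embed` is a hypothetical NumPy helper run on random stand-in data. A full SVD like this would not scale to all of CelebA at 128x128, so a batched or incremental variant would be needed in practice.

```python
import numpy as np


def pca_embed(images, n_components):
    """Project flattened images onto their top principal components.

    images: array of shape (N, C, H, W); returns an (N, n_components) array.
    """
    flat = images.reshape(len(images), -1).astype(np.float64)
    centered = flat - flat.mean(axis=0, keepdims=True)
    # Rows of vt are principal directions, sorted by decreasing variance.
    _, _, vt = np.linalg.svd(centered, full_matrices=False)
    return centered @ vt[:n_components].T


rng = np.random.default_rng(0)
toy_images = rng.normal(size=(32, 3, 8, 8))   # stand-in for CelebA images
embeddings = pca_embed(toy_images, n_components=5)
print(embeddings.shape)  # (32, 5)
```

Pairwise distances between rows of such an embedding are what MM-Reg later uses as its reference geometry.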

In [None]:
from src.data.dataset import (
    get_celeba_dataset,
    compute_pca_embeddings_celeba,
    get_dataset_and_loader,
    CELEBA_ATTRIBUTES
)

# === CONFIGURATION: Choose your data source ===
# Option 1: HuggingFace (recommended, no rate limits)
USE_HUGGINGFACE = True

# Option 2: Google Drive (if you have CelebA saved there)
# Set these paths if you have CelebA in your Drive
USE_DRIVE = False
DRIVE_PATHS = {
    'images_zip': '/content/drive/MyDrive/DATASETS/CelebA/img_align_celeba.zip',
    'attr_file': '/content/drive/MyDrive/DATASETS/CelebA/list_attr_celeba.txt',
    'partition_file': '/content/drive/MyDrive/DATASETS/CelebA/list_eval_partition.txt'
}

# Mount Google Drive if using Drive source
if USE_DRIVE:
    from google.colab import drive
    drive.mount('/content/drive')

# Determine source (USE_DRIVE takes precedence; USE_HUGGINGFACE is not read below)
if USE_DRIVE:
    source = 'drive'
    drive_paths = DRIVE_PATHS
else:
    source = 'huggingface'
    drive_paths = None

# Load CelebA with fixed transforms for PCA
print(f"Loading CelebA from {source}...")
train_dataset_fixed = get_celeba_dataset(
    root=CONFIG['data_root'],
    split='train',
    image_size=CONFIG['image_size'],
    fixed_transform=True,
    source=source,
    drive_paths=drive_paths
)

val_dataset_fixed = get_celeba_dataset(
    root=CONFIG['data_root'],
    split='val',
    image_size=CONFIG['image_size'],
    fixed_transform=True,
    source=source,
    drive_paths=drive_paths
)

print(f"Train: {len(train_dataset_fixed)}, Val: {len(val_dataset_fixed)}")

In [None]:
# Compute PCA embeddings
os.makedirs('./embeddings', exist_ok=True)

print("Computing PCA embeddings for training set...")
train_pca = compute_pca_embeddings_celeba(
    train_dataset_fixed,
    n_components=CONFIG['pca_components'],
    batch_size=64
)
torch.save(train_pca, './embeddings/celeba_train_pca.pt')

print("\nComputing PCA embeddings for validation set...")
val_pca = compute_pca_embeddings_celeba(
    val_dataset_fixed,
    n_components=CONFIG['pca_components'],
    batch_size=64
)
torch.save(val_pca, './embeddings/celeba_val_pca.pt')

print(f"\nTrain PCA: {train_pca.shape}, Val PCA: {val_pca.shape}")

## 3. Fine-tune Shared SD VAE

Fine-tune a single SD VAE on CelebA for reconstruction only (no MM-Reg at this stage).
This yields a CelebA-adapted encoder/decoder that both bottleneck VAEs share; MM-Reg is applied later, in the bottleneck stage.

In [None]:
from src.models.vae_wrapper import load_vae
from src.models.losses import VAELoss
from src.trainer import MMRegTrainer

print("="*60)
print("FINE-TUNING SHARED SD VAE (reconstruction only)")
print("="*60)

# Load pretrained SD VAE
sd_vae = load_vae(device=device)

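# Loss: reconstruction + small KL, NO MM-Reg
# NOTE: CONFIG defines 'beta' (0.001, shared with the bottleneck), so the 1e-6 fallback is not used here
loss_sdvae = VAELoss(lambda_mm=0.0, beta=CONFIG.get('beta', 1e-6))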

# Data loaders with PCA (needed by trainer interface, but lambda_mm=0 ignores it)
train_dataset_pca, train_loader_pca = get_dataset_and_loader(
    dataset_name='celeba',
    root=CONFIG['data_root'],
    split='train',
    image_size=CONFIG['image_size'],
    batch_size=CONFIG['batch_size'],
    num_workers=2,
    pca_embeddings_path='./embeddings/celeba_train_pca.pt',
    celeba_source=source,
    celeba_drive_paths=drive_paths
)

val_dataset_pca, val_loader_pca = get_dataset_and_loader(
    dataset_name='celeba',
    root=CONFIG['data_root'],
    split='val',
    image_size=CONFIG['image_size'],
    batch_size=CONFIG['batch_size'],
    num_workers=2,
    pca_embeddings_path='./embeddings/celeba_val_pca.pt',
    celeba_source=source,
    celeba_drive_paths=drive_paths
)

# Optimizer
optimizer_sdvae = torch.optim.AdamW(sd_vae.parameters(), lr=CONFIG['sd_vae_lr'])

# Trainer
trainer_sdvae = MMRegTrainer(
    vae=sd_vae,
    loss_fn=loss_sdvae,
    optimizer=optimizer_sdvae,
    train_loader=train_loader_pca,
    val_loader=val_loader_pca,
    device=device,
    save_dir='./checkpoints/celeba_sd_vae'
)

# Train
trainer_sdvae.train(num_epochs=CONFIG['sd_vae_epochs'])

## 4. Encode Dataset to SD Latents

Encode all CelebA images to 4x16x16 SD latents (cached) and collect attributes.

In [None]:
from src.bottleneck_trainer import encode_dataset_to_sd_latents
from torch.utils.data import DataLoader

os.makedirs('./embeddings', exist_ok=True)

# Create simple dataloaders for encoding (without PCA wrapper)
train_loader_simple = DataLoader(
    train_dataset_fixed,
    batch_size=CONFIG['batch_size'],
    shuffle=False,
    num_workers=2
)
val_loader_simple = DataLoader(
    val_dataset_fixed,
    batch_size=CONFIG['batch_size'],
    shuffle=False,
    num_workers=2
)

# Encode all images to SD latents
print("Encoding training set to SD latents...")
train_sd_latents = encode_dataset_to_sd_latents(sd_vae, train_loader_simple, device)
torch.save(train_sd_latents, './embeddings/celeba_train_sd_latents.pt')

print("Encoding validation set to SD latents...")
val_sd_latents = encode_dataset_to_sd_latents(sd_vae, val_loader_simple, device)
torch.save(val_sd_latents, './embeddings/celeba_val_sd_latents.pt')

print(f"\nTrain SD latents: {train_sd_latents.shape}")  # (N, 4, 16, 16)
print(f"Val SD latents:   {val_sd_latents.shape}")

# Collect attributes in the same order
print("\nCollecting attributes...")
train_attrs_list = []
for batch in tqdm(train_loader_simple, desc="Train attrs"):
    _, attrs = batch
    train_attrs_list.append(attrs)
train_attrs = torch.cat(train_attrs_list, dim=0)

val_attrs_list = []
for batch in tqdm(val_loader_simple, desc="Val attrs"):
    _, attrs = batch
    val_attrs_list.append(attrs)
val_attrs = torch.cat(val_attrs_list, dim=0)

torch.save(train_attrs, './embeddings/celeba_train_attrs.pt')
torch.save(val_attrs, './embeddings/celeba_val_attrs.pt')

print(f"Train attrs: {train_attrs.shape}")  # (N, 40)
print(f"Val attrs:   {val_attrs.shape}")

## 5. Train Bottleneck VAEs

Train two bottleneck VAEs on the cached SD latents:
- **Baseline**: reconstruction + KL only (lambda_mm = 0)
- **MM-Reg**: reconstruction + KL + MM-Reg on 256-d bottleneck
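
The two objectives differ only in the weight on the MM-Reg term. The sketch below shows one plausible form of the combined loss, assuming the `'correlation'` variant means "one minus the Pearson correlation between pairwise distances in the bottleneck and in PCA space". That reading is an assumption, not the confirmed behaviour of `BottleneckLoss`, and all function names here are hypothetical NumPy stand-ins.

```python
import numpy as np


def pairwise_dist(x):
    """Euclidean distance matrix for the rows of x, shape (N, N)."""
    sq = (x ** 2).sum(axis=1)
    d2 = sq[:, None] + sq[None, :] - 2.0 * x @ x.T
    return np.sqrt(np.clip(d2, 0.0, None))   # clip guards against tiny negatives


def mm_reg_correlation(z, ref):
    """1 - Pearson correlation between pairwise distances in z and in ref."""
    iu = np.triu_indices(len(z), k=1)         # each unordered pair once
    dz, dr = pairwise_dist(z)[iu], pairwise_dist(ref)[iu]
    dz, dr = dz - dz.mean(), dr - dr.mean()
    corr = (dz * dr).sum() / (np.linalg.norm(dz) * np.linalg.norm(dr) + 1e-12)
    return 1.0 - corr


def kl_to_standard_normal(mu, logvar):
    """Batch mean of KL(N(mu, exp(logvar)) || N(0, I))."""
    return (-0.5 * (1.0 + logvar - mu ** 2 - np.exp(logvar)).sum(axis=1)).mean()


def bottleneck_loss(x, x_recon, mu, logvar, ref, lambda_mm, beta):
    """Reconstruction MSE + beta * KL + lambda_mm * MM-Reg (assumed form)."""
    recon = ((x_recon - x) ** 2).mean()
    kl = kl_to_standard_normal(mu, logvar)
    mm = mm_reg_correlation(mu, ref)
    return recon + beta * kl + lambda_mm * mm


rng = np.random.default_rng(0)
x = rng.normal(size=(16, 1024))               # flattened 4x16x16 latents
x_recon = x + 0.1 * rng.normal(size=x.shape)
mu, logvar = rng.normal(size=(16, 256)), np.zeros((16, 256))
ref = rng.normal(size=(16, 256))              # stand-in PCA embeddings

baseline = bottleneck_loss(x, x_recon, mu, logvar, ref, lambda_mm=0.0, beta=0.001)
mmreg = bottleneck_loss(x, x_recon, mu, logvar, ref, lambda_mm=1.0, beta=0.001)
print(f"baseline: {baseline:.4f}, with MM-Reg: {mmreg:.4f}")
```

Because a correlation is invariant to rescaling, this form of the penalty constrains the relative geometry of the bottleneck without fixing its overall scale.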

### 5.1 Baseline Bottleneck (no MM-Reg)

In [None]:
from src.models.bottleneck import BottleneckVAE, BottleneckLoss
from src.bottleneck_trainer import BottleneckTrainer

# Load cached data
train_sd_latents = torch.load('./embeddings/celeba_train_sd_latents.pt')
val_sd_latents = torch.load('./embeddings/celeba_val_sd_latents.pt')
train_pca = torch.load('./embeddings/celeba_train_pca.pt')
val_pca = torch.load('./embeddings/celeba_val_pca.pt')

print(f"Train SD latents: {train_sd_latents.shape}")
print(f"Train PCA: {train_pca.shape}")

# Spatial shape from actual latents
spatial_shape = tuple(train_sd_latents.shape[1:])  # (4, 16, 16)
print(f"Spatial shape: {spatial_shape}")

print("\n" + "="*60)
print("TRAINING BASELINE BOTTLENECK (no MM-Reg)")
print("="*60)

# Baseline bottleneck
bottleneck_baseline = BottleneckVAE(
    spatial_shape=spatial_shape,
    latent_dim=CONFIG['bottleneck_dim'],
    hidden_dim=CONFIG['bottleneck_hidden']
).to(device)

# Loss without MM-Reg
loss_baseline = BottleneckLoss(
    lambda_mm=0.0,
    beta=CONFIG['beta']
)

optimizer_baseline = torch.optim.Adam(bottleneck_baseline.parameters(), lr=CONFIG['bottleneck_lr'])

trainer_baseline = BottleneckTrainer(
    bottleneck=bottleneck_baseline,
    loss_fn=loss_baseline,
    optimizer=optimizer_baseline,
    train_latents=train_sd_latents,
    train_pca=train_pca,
    val_latents=val_sd_latents,
    val_pca=val_pca,
    batch_size=CONFIG['bottleneck_batch_size'],
    device=device,
    save_dir='./checkpoints/celeba_bottleneck_baseline'
)

trainer_baseline.train(num_epochs=CONFIG['bottleneck_epochs'])

### 5.2 MM-Reg Bottleneck

In [ ]:
print("="*60)
print("TRAINING MM-REG BOTTLENECK")
print("="*60)

# MM-Reg bottleneck
bottleneck_mmreg = BottleneckVAE(
    spatial_shape=spatial_shape,
    latent_dim=CONFIG['bottleneck_dim'],
    hidden_dim=CONFIG['bottleneck_hidden']
).to(device)

# Loss with MM-Reg
loss_mmreg = BottleneckLoss(
    lambda_mm=CONFIG['lambda_mm'],
    beta=CONFIG['beta'],
    mm_variant='correlation'
)

optimizer_mmreg = torch.optim.Adam(bottleneck_mmreg.parameters(), lr=CONFIG['bottleneck_lr'])

trainer_mmreg = BottleneckTrainer(
    bottleneck=bottleneck_mmreg,
    loss_fn=loss_mmreg,
    optimizer=optimizer_mmreg,
    train_latents=train_sd_latents,
    train_pca=train_pca,
    val_latents=val_sd_latents,
    val_pca=val_pca,
    batch_size=CONFIG['bottleneck_batch_size'],
    device=device,
    save_dir='./checkpoints/celeba_bottleneck_mmreg'
)

trainer_mmreg.train(num_epochs=CONFIG['bottleneck_epochs'])

## 6. Evaluate Bottleneck VAEs

Evaluate end-to-end reconstruction (bottleneck decode → SD decode) and distance correlation in 256-d.
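
Both Pearson and Spearman correlations are reported because they answer slightly different questions: Pearson rewards a linear relationship between the two sets of distances, while Spearman only checks that their ordering agrees. The toy example below uses hypothetical NumPy helpers and synthetic distances rather than the notebook's data, and illustrates the difference with a monotone but non-linear distortion.

```python
import numpy as np


def pearson(a, b):
    """Pearson correlation of two 1-D arrays."""
    a, b = a - a.mean(), b - b.mean()
    return float((a * b).sum() / (np.linalg.norm(a) * np.linalg.norm(b)))


def ranks(a):
    """Rank of each element (0 = smallest); assumes no ties."""
    order = np.argsort(a)
    r = np.empty(len(a))
    r[order] = np.arange(len(a))
    return r


def spearman(a, b):
    """Spearman correlation = Pearson correlation of the ranks."""
    return pearson(ranks(a), ranks(b))


rng = np.random.default_rng(0)
d_ref = rng.uniform(0.1, 10.0, size=1000)   # stand-in for PCA-space distances
d_lat = d_ref ** 3                          # same ordering, distorted scale

print(f"Pearson:  {pearson(d_lat, d_ref):.3f}")   # below 1
print(f"Spearman: {spearman(d_lat, d_ref):.3f}")  # 1.000
```

A latent space with high Spearman but lower Pearson correlation would therefore preserve neighbourhood ordering while stretching or compressing distances.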

In [ ]:
from src.models.losses import pairwise_distances, get_upper_triangular
from scipy.stats import pearsonr, spearmanr

val_sd_latents = torch.load('./embeddings/celeba_val_sd_latents.pt')
val_pca = torch.load('./embeddings/celeba_val_pca.pt')
val_attrs = torch.load('./embeddings/celeba_val_attrs.pt')

def evaluate_bottleneck(bottleneck, sd_vae, val_latents, val_pca, val_attrs, name, device):
    """Evaluate bottleneck: latent recon, end-to-end recon, distance correlation."""
    bottleneck.eval()
    sd_vae.eval()
    
    all_z = []
    total_latent_mse = 0
    total_pixel_mse = 0
    num_samples = 0
    
    with torch.no_grad():
        for i in tqdm(range(0, len(val_latents), 64), desc=f"Evaluating {name}"):
            batch_latents = val_latents[i:i+64].to(device)
            
            # Bottleneck forward
            outputs = bottleneck(batch_latents, sample=False)
            
            # Latent-space reconstruction MSE
            latent_mse = ((outputs['x_recon'] - batch_latents) ** 2).mean().item()
            total_latent_mse += latent_mse * batch_latents.shape[0]
            
            # End-to-end pixel reconstruction: bottleneck recon → SD decode,
            # compared against the SD decode of the original latents (not the raw images)
            pixel_recon = sd_vae.decode(outputs['x_recon'])
            pixel_orig = sd_vae.decode(batch_latents)
            pixel_mse = ((pixel_recon - pixel_orig) ** 2).mean().item()
            total_pixel_mse += pixel_mse * batch_latents.shape[0]
            
            all_z.append(outputs['mu'].cpu())
            num_samples += batch_latents.shape[0]
    
    all_z = torch.cat(all_z, dim=0)
    
    # Distance correlation in 256-d bottleneck space
    n = min(500, len(all_z))
    D_z = pairwise_distances(all_z[:n])
    D_pca = pairwise_distances(val_pca[:n])
    
    d_z = get_upper_triangular(D_z).numpy()
    d_pca = get_upper_triangular(D_pca).numpy()
    
    pearson, _ = pearsonr(d_z, d_pca)
    spearman, _ = spearmanr(d_z, d_pca)
    
    results = {
        'latent_mse': total_latent_mse / num_samples,
        'pixel_mse': total_pixel_mse / num_samples,
        'pearson_corr': pearson,
        'spearman_corr': spearman,
        'bottleneck_z': all_z,
        'attrs': val_attrs
    }
    
    print(f"\n{name} Results:")
    print(f"  Latent Recon MSE:   {results['latent_mse']:.6f}")
    print(f"  Pixel Recon MSE:    {results['pixel_mse']:.6f}")
    print(f"  Pearson (256-d):    {results['pearson_corr']:.4f}")
    print(f"  Spearman (256-d):   {results['spearman_corr']:.4f}")
    
    return results

results_baseline = evaluate_bottleneck(
    bottleneck_baseline, sd_vae, val_sd_latents, val_pca, val_attrs, "Baseline Bottleneck", device
)
results_mmreg = evaluate_bottleneck(
    bottleneck_mmreg, sd_vae, val_sd_latents, val_pca, val_attrs, "MM-Reg Bottleneck", device
)

In [None]:
# Visualize end-to-end reconstructions: original → SD encode → bottleneck → SD decode
def plot_bottleneck_reconstructions(bottleneck, sd_vae, val_latents, val_loader_simple, title, save_path):
    """Show original images and their bottleneck reconstructions."""
    bottleneck.eval()
    sd_vae.eval()
    
    # Get original images and their SD latents
    batch = next(iter(val_loader_simple))
    images = batch[0][:8].to(device)
    
    with torch.no_grad():
        # Encode to SD latents
        enc = sd_vae.encode(images, sample=False)
        sd_latents = enc['latent']
        
        # Through bottleneck
        bn_out = bottleneck(sd_latents, sample=False)
        sd_recon = bn_out['x_recon']
        
        # Decode both through SD decoder
        orig_decoded = sd_vae.decode(sd_latents)
        bn_decoded = sd_vae.decode(sd_recon)
    
    fig, axes = plt.subplots(3, 8, figsize=(16, 6))
    fig.suptitle(title, fontsize=14)
    
    for i in range(8):
        # Original
        img = images[i].cpu().permute(1, 2, 0).numpy()
        img = ((img + 1) / 2).clip(0, 1)
        axes[0, i].imshow(img)
        axes[0, i].axis('off')
        
        # SD VAE reconstruction (no bottleneck)
        sd_rec = orig_decoded[i].cpu().permute(1, 2, 0).numpy()
        sd_rec = ((sd_rec + 1) / 2).clip(0, 1)
        axes[1, i].imshow(sd_rec)
        axes[1, i].axis('off')
        
        # Bottleneck reconstruction (SD → bottleneck → SD)
        bn_rec = bn_decoded[i].cpu().permute(1, 2, 0).numpy()
        bn_rec = ((bn_rec + 1) / 2).clip(0, 1)
        axes[2, i].imshow(bn_rec)
        axes[2, i].axis('off')
    
    axes[0, 0].set_ylabel('Original', fontsize=11)
    axes[1, 0].set_ylabel('SD VAE only', fontsize=11)
    axes[2, 0].set_ylabel('+ Bottleneck', fontsize=11)
    
    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.show()

plot_bottleneck_reconstructions(
    bottleneck_baseline, sd_vae, val_sd_latents, val_loader_simple,
    "Baseline Bottleneck Reconstructions",
    "./checkpoints/celeba_bottleneck_baseline/reconstructions.png"
)
plot_bottleneck_reconstructions(
    bottleneck_mmreg, sd_vae, val_sd_latents, val_loader_simple,
    "MM-Reg Bottleneck Reconstructions",
    "./checkpoints/celeba_bottleneck_mmreg/reconstructions.png"
)

## 7. Compute Attribute Directions in Bottleneck Space

For each attribute, compute: `d = mean(z_256d[attr=1]) - mean(z_256d[attr=0])`

These directions are expected to be more semantically meaningful in the 256-d bottleneck than in the 4x16x16 spatial space, where each coordinate is tied to an image location.
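
As a sanity check of the difference-of-means formula, the sketch below plants a known shift in synthetic latent vectors and verifies that the estimated direction recovers it. The data, dimensions, and the `attribute_direction` helper are hypothetical and use NumPy rather than the notebook's tensors.

```python
import numpy as np


def attribute_direction(z, has_attr):
    """Difference of class means: mean(z[attr=1]) - mean(z[attr=0])."""
    return z[has_attr].mean(axis=0) - z[~has_attr].mean(axis=0)


rng = np.random.default_rng(0)
dim, n = 16, 4000
true_shift = np.zeros(dim)
true_shift[3] = 2.0                              # the attribute only moves axis 3

has_attr = rng.random(n) < 0.5                   # synthetic binary labels
z = rng.normal(size=(n, dim)) + has_attr[:, None] * true_shift

d = attribute_direction(z, has_attr)
cosine = d @ true_shift / (np.linalg.norm(d) * np.linalg.norm(true_shift))
print(f"cosine(d, true shift) = {cosine:.3f}")
```

With real attributes the two groups also differ in correlated factors (for example, `Male` and `No_Beard`), so a difference-of-means direction can mix several attributes.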

In [ ]:
# Encode training set through bottlenecks to get 256-d vectors + attrs
train_attrs = torch.load('./embeddings/celeba_train_attrs.pt')

print("Encoding training set through baseline bottleneck...")
train_z_baseline = trainer_baseline.encode_all(train_sd_latents)
print(f"Baseline z: {train_z_baseline.shape}")  # (N, 256)

print("Encoding training set through MM-Reg bottleneck...")
train_z_mmreg = trainer_mmreg.encode_all(train_sd_latents)
print(f"MM-Reg z: {train_z_mmreg.shape}")

def compute_attribute_directions(z, attrs, attr_names):
    """Compute attribute directions in bottleneck space."""
    directions = {}
    for attr_name in attr_names:
        attr_idx = get_attr_idx(attr_name)
        has_attr = attrs[:, attr_idx] > 0.5
        no_attr = attrs[:, attr_idx] <= 0.5
        
        mean_with = z[has_attr].mean(dim=0)
        mean_without = z[no_attr].mean(dim=0)
        direction = mean_with - mean_without
        directions[attr_name] = direction
        
        print(f"  {attr_name}: {has_attr.sum().item()} with, {no_attr.sum().item()} without, ||d||={direction.norm():.2f}")
    return directions

print("\nBaseline attribute directions:")
directions_baseline = compute_attribute_directions(train_z_baseline, train_attrs, CONFIG['test_attributes'])

print("\nMM-Reg attribute directions:")
directions_mmreg = compute_attribute_directions(train_z_mmreg, train_attrs, CONFIG['test_attributes'])

## 8. Attribute Interpolation in Bottleneck Space

Manipulate attributes in the 256-d bottleneck space, then decode through:
`z_256d + alpha*d → bottleneck decode → 4x16x16 → SD decode → image`

In [ ]:
def interpolate_bottleneck(bottleneck, sd_vae, sd_latent, direction, alphas, device):
    """
    Interpolate attribute in bottleneck space, decode through full pipeline.
    
    Args:
        bottleneck: BottleneckVAE
        sd_vae: SD VAE (frozen decoder)
        sd_latent: (1, 4, 16, 16) SD VAE latent
        direction: (256,) attribute direction in bottleneck space
        alphas: interpolation strengths
        device: device
    
    Returns:
        List of decoded images
    """
    bottleneck.eval()
    sd_vae.eval()
    sd_latent = sd_latent.to(device)
    direction = direction.to(device)
    
    with torch.no_grad():
        # Encode to bottleneck
        enc = bottleneck.encode(sd_latent, sample=False)
        z = enc['mu']  # (1, 256)
        
        results = []
        for alpha in alphas:
            # Add direction in 256-d space
            z_new = z + alpha * direction.unsqueeze(0)
            
            # Decode through bottleneck → 4x16x16
            sd_recon = bottleneck.decode(z_new)
            
            # Decode through SD VAE → image
            image = sd_vae.decode(sd_recon)
            results.append(image.cpu())
    
    return results


def plot_bottleneck_interpolation(bottleneck, sd_vae, val_sd_latents, val_attrs,
                                   directions, attr_name, title_prefix, save_path):
    """Plot attribute interpolation in bottleneck space."""
    attr_idx = get_attr_idx(attr_name)
    
    # Find samples without the attribute
    no_attr_mask = val_attrs[:, attr_idx] <= 0.5
    no_attr_indices = torch.where(no_attr_mask)[0][:4]
    
    if len(no_attr_indices) == 0:
        print(f"No validation images without {attr_name}")
        return
    
    alphas = [-1.0, -0.5, 0.0, 0.5, 1.0, 1.5, 2.0]
    direction = directions[attr_name]
    n_images = len(no_attr_indices)
    
    fig, axes = plt.subplots(n_images, len(alphas), figsize=(2*len(alphas), 2*n_images))
    fig.suptitle(f"{title_prefix}: {attr_name} Interpolation (bottleneck space)", fontsize=14)
    
    for i, idx in enumerate(no_attr_indices):
        sd_latent = val_sd_latents[idx:idx+1]
        results = interpolate_bottleneck(bottleneck, sd_vae, sd_latent, direction, alphas, device)
        
        for j, (alpha, result) in enumerate(zip(alphas, results)):
            img_np = result[0].permute(1, 2, 0).numpy()
            img_np = ((img_np + 1) / 2).clip(0, 1)
            
            ax = axes[i, j] if n_images > 1 else axes[j]
            ax.imshow(img_np)
            ax.axis('off')
            if i == 0:
                ax.set_title(f"a={alpha}")
    
    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.show()

In [None]:
# Test interpolation for each attribute
os.makedirs('./checkpoints/interpolations', exist_ok=True)

for attr_name in CONFIG['test_attributes']:
    print(f"\nTesting {attr_name} interpolation...")
    
    plot_bottleneck_interpolation(
        bottleneck_baseline, sd_vae, val_sd_latents, val_attrs,
        directions_baseline, attr_name,
        "Baseline", f"./checkpoints/interpolations/baseline_{attr_name}.png"
    )
    
    plot_bottleneck_interpolation(
        bottleneck_mmreg, sd_vae, val_sd_latents, val_attrs,
        directions_mmreg, attr_name,
        "MM-Reg", f"./checkpoints/interpolations/mmreg_{attr_name}.png"
    )

## 9. Train Diffusion on 256-d Bottleneck Vectors

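The diffusion model now learns to generate flat 256-d vectors rather than 4x16x16 spatial latents.
Because the data are plain vectors, an MLP denoiser suffices, which should make training simpler and faster than spatial diffusion.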
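
For intuition, the forward (noising) process on flat vectors can be written in a few lines. The sketch below assumes the standard DDPM formulation with a linear beta schedule and an epsilon-prediction target. Whether `GaussianDiffusion` in this repository uses exactly that schedule is an assumption, and the helper names are hypothetical.

```python
import numpy as np


def linear_beta_schedule(num_timesteps, beta_start=1e-4, beta_end=0.02):
    """Linear noise schedule as in the original DDPM setup (assumed here)."""
    return np.linspace(beta_start, beta_end, num_timesteps)


def q_sample(x0, t, alpha_bar, noise):
    """Sample x_t ~ q(x_t | x_0) = N(sqrt(a_bar_t) x_0, (1 - a_bar_t) I)."""
    a = alpha_bar[t][:, None]                   # (B, 1), broadcasts over features
    return np.sqrt(a) * x0 + np.sqrt(1.0 - a) * noise


rng = np.random.default_rng(0)
betas = linear_beta_schedule(1000)              # CONFIG['diffusion_timesteps']
alpha_bar = np.cumprod(1.0 - betas)             # cumulative signal retention

x0 = rng.normal(size=(8, 256))                  # stand-in bottleneck vectors
t = rng.integers(0, 1000, size=8)               # one timestep per sample
noise = rng.normal(size=x0.shape)
x_t = q_sample(x0, t, alpha_bar, noise)

# An epsilon-prediction denoiser is trained to recover `noise` from (x_t, t),
# typically with an MSE loss between its output and `noise`.
print(x_t.shape, f"alpha_bar[-1] = {alpha_bar[-1]:.2e}")
```

Since each sample is a single 256-d row, no convolutional or attention structure is required; the denoiser only needs the vector and a timestep embedding as input.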

### 9.1 Encode to Bottleneck Vectors

In [None]:
# Encode all SD latents to 256-d bottleneck vectors
print("Encoding to baseline bottleneck vectors...")
train_z_baseline = trainer_baseline.encode_all(train_sd_latents)
val_z_baseline = trainer_baseline.encode_all(val_sd_latents)
print(f"Baseline z - Train: {train_z_baseline.shape}, Val: {val_z_baseline.shape}")

print("\nEncoding to MM-Reg bottleneck vectors...")
train_z_mmreg = trainer_mmreg.encode_all(train_sd_latents)
val_z_mmreg = trainer_mmreg.encode_all(val_sd_latents)
print(f"MM-Reg z - Train: {train_z_mmreg.shape}, Val: {val_z_mmreg.shape}")

### 9.2 Train Diffusion on Baseline Bottleneck Vectors

In [None]:
from src.models.diffusion import MLPDenoiser, GaussianDiffusion
from src.diffusion_trainer import DiffusionTrainer

print("="*60)
print("TRAINING DIFFUSION ON BASELINE BOTTLENECK (256-d)")
print("="*60)

diffusion_baseline = GaussianDiffusion(
    num_timesteps=CONFIG['diffusion_timesteps'],
    device=device
)

# MLP denoiser for flat 256-d vectors
mlp_baseline = MLPDenoiser(
    input_dim=CONFIG['bottleneck_dim'],
    hidden_dim=1024,
    num_layers=6,
    time_emb_dim=256
).to(device)

print(f"MLP params: {sum(p.numel() for p in mlp_baseline.parameters()):,}")

optimizer_diff_base = torch.optim.AdamW(mlp_baseline.parameters(), lr=CONFIG['diffusion_lr'])

trainer_diff_baseline = DiffusionTrainer(
    model=mlp_baseline,
    diffusion=diffusion_baseline,
    optimizer=optimizer_diff_base,
    train_latents=train_z_baseline,
    val_latents=val_z_baseline,
    batch_size=CONFIG['batch_size'],
    device=device,
    save_dir='./checkpoints/celeba_diffusion_baseline'
)

trainer_diff_baseline.train(num_epochs=CONFIG['diffusion_epochs'])

### 9.3 Train Diffusion on MM-Reg Bottleneck Vectors

In [None]:
print("="*60)
print("TRAINING DIFFUSION ON MM-REG BOTTLENECK (256-d)")
print("="*60)

diffusion_mmreg = GaussianDiffusion(
    num_timesteps=CONFIG['diffusion_timesteps'],
    device=device
)

mlp_mmreg = MLPDenoiser(
    input_dim=CONFIG['bottleneck_dim'],
    hidden_dim=1024,
    num_layers=6,
    time_emb_dim=256
).to(device)

optimizer_diff_mm = torch.optim.AdamW(mlp_mmreg.parameters(), lr=CONFIG['diffusion_lr'])

trainer_diff_mmreg = DiffusionTrainer(
    model=mlp_mmreg,
    diffusion=diffusion_mmreg,
    optimizer=optimizer_diff_mm,
    train_latents=train_z_mmreg,
    val_latents=val_z_mmreg,
    batch_size=CONFIG['batch_size'],
    device=device,
    save_dir='./checkpoints/celeba_diffusion_mmreg'
)

trainer_diff_mmreg.train(num_epochs=CONFIG['diffusion_epochs'])

## 10. Generate Samples & Apply Attribute Directions

Generate 256-d vectors with diffusion, then decode through:
`z_256d → bottleneck decode → 4x16x16 → SD decode → image`

In [ ]:
# Generate 256-d samples from diffusion
print("Generating samples from Baseline Diffusion...")
samples_z_baseline = trainer_diff_baseline.generate_samples(num_samples=16)
print(f"Generated baseline z: {samples_z_baseline.shape}")  # (16, 256)

print("\nGenerating samples from MM-Reg Diffusion...")
samples_z_mmreg = trainer_diff_mmreg.generate_samples(num_samples=16)
print(f"Generated MM-Reg z: {samples_z_mmreg.shape}")


def decode_bottleneck_to_images(bottleneck, sd_vae, z_vectors, device):
    """Decode 256-d vectors through bottleneck → SD VAE → images."""
    bottleneck.eval()
    sd_vae.eval()
    with torch.no_grad():
        z = z_vectors.to(device)
        # Bottleneck decode: 256-d → 4x16x16
        sd_latents = bottleneck.decode(z)
        # SD VAE decode: 4x16x16 → image
        images = sd_vae.decode(sd_latents)
    return images.cpu()


def plot_generated_faces(bottleneck, sd_vae, z_vectors, title, save_path):
    """Decode and plot generated faces."""
    images = decode_bottleneck_to_images(bottleneck, sd_vae, z_vectors, device)
    
    n = min(16, images.shape[0])
    fig, axes = plt.subplots(2, 8, figsize=(16, 4))
    fig.suptitle(title, fontsize=14)
    
    for ax in axes.flat:
        ax.axis('off')  # also hides unused panels when fewer than 16 samples are given

    for i in range(n):
        img = images[i].permute(1, 2, 0).numpy()
        img = ((img + 1) / 2).clip(0, 1)
        axes[i // 8, i % 8].imshow(img)
    
    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.show()

plot_generated_faces(
    bottleneck_baseline, sd_vae, samples_z_baseline,
    "Generated Faces (Baseline Bottleneck + Diffusion)",
    "./checkpoints/celeba_diffusion_baseline/generated_samples.png"
)

plot_generated_faces(
    bottleneck_mmreg, sd_vae, samples_z_mmreg,
    "Generated Faces (MM-Reg Bottleneck + Diffusion)",
    "./checkpoints/celeba_diffusion_mmreg/generated_samples.png"
)

In [None]:
# Apply attribute directions to diffusion-generated 256-d vectors
def plot_generated_attr_interpolation(bottleneck, sd_vae, z_samples, directions,
                                       attr_name, title_prefix, save_path):
    """Apply attribute direction to diffusion-generated bottleneck vectors."""
    bottleneck.eval()
    sd_vae.eval()
    
    alphas = [-1.0, 0.0, 1.0, 2.0]
    direction = directions[attr_name].to(device)
    test_z = z_samples[:4].to(device)
    
    fig, axes = plt.subplots(4, len(alphas), figsize=(2*len(alphas), 8))
    fig.suptitle(f"{title_prefix}: {attr_name} on Generated Faces", fontsize=14)
    
    with torch.no_grad():
        for j, alpha in enumerate(alphas):
            z_new = test_z + alpha * direction.unsqueeze(0)
            
            # Decode: 256-d → 4x16x16 → image
            sd_latents = bottleneck.decode(z_new)
            images = sd_vae.decode(sd_latents)
            
            for i in range(4):
                img = images[i].cpu().permute(1, 2, 0).numpy()
                img = ((img + 1) / 2).clip(0, 1)
                axes[i, j].imshow(img)
                axes[i, j].axis('off')
                if i == 0:
                    axes[i, j].set_title(f"a={alpha}")
    
    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.show()


# Test attribute manipulation on generated samples
for attr_name in CONFIG['test_attributes'][:3]:
    print(f"\nApplying {attr_name} to generated samples...")
    
    plot_generated_attr_interpolation(
        bottleneck_baseline, sd_vae, samples_z_baseline, directions_baseline,
        attr_name, "Baseline",
        f"./checkpoints/interpolations/gen_baseline_{attr_name}.png"
    )
    
    plot_generated_attr_interpolation(
        bottleneck_mmreg, sd_vae, samples_z_mmreg, directions_mmreg,
        attr_name, "MM-Reg",
        f"./checkpoints/interpolations/gen_mmreg_{attr_name}.png"
    )

In [None]:
# Plot training curves
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Bottleneck losses
with open('./checkpoints/celeba_bottleneck_baseline/history.json') as f:
    baseline_bn_hist = json.load(f)
with open('./checkpoints/celeba_bottleneck_mmreg/history.json') as f:
    mmreg_bn_hist = json.load(f)

axes[0].plot([h['loss'] for h in baseline_bn_hist['train']], 'b-', label='Baseline Train')
axes[0].plot([h['loss'] for h in baseline_bn_hist['val']], 'b--', label='Baseline Val')
axes[0].plot([h['loss'] for h in mmreg_bn_hist['train']], 'r-', label='MM-Reg Train')
axes[0].plot([h['loss'] for h in mmreg_bn_hist['val']], 'r--', label='MM-Reg Val')
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Loss')
axes[0].set_title('Bottleneck VAE Training')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# Diffusion losses
with open('./checkpoints/celeba_diffusion_baseline/history.json') as f:
    baseline_diff_hist = json.load(f)
with open('./checkpoints/celeba_diffusion_mmreg/history.json') as f:
    mmreg_diff_hist = json.load(f)

axes[1].plot([h['loss'] for h in baseline_diff_hist['train']], 'b-', label='Baseline Train')
axes[1].plot([h['loss'] for h in baseline_diff_hist['val']], 'b--', label='Baseline Val')
axes[1].plot([h['loss'] for h in mmreg_diff_hist['train']], 'r-', label='MM-Reg Train')
axes[1].plot([h['loss'] for h in mmreg_diff_hist['val']], 'r--', label='MM-Reg Val')
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('Loss')
axes[1].set_title('Diffusion Training (on 256-d)')
axes[1].legend()
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('./checkpoints/celeba_training_comparison.png', dpi=150)
plt.show()

# Final summary
print("="*60)
print("CELEBA BOTTLENECK EXPERIMENT SUMMARY")
print("="*60)

def to_python(val):
    """Convert NumPy/PyTorch scalars to plain Python numbers for JSON serialization."""
    if hasattr(val, 'item'):
        return val.item()
    return float(val)

summary = {
    'config': CONFIG,
    'architecture': 'Image -> SD VAE -> 4x16x16 -> BottleneckVAE -> 256-d -> Diffusion',
    'bottleneck_results': {
        'baseline': {
            'latent_mse': to_python(results_baseline['latent_mse']),
            'pixel_mse': to_python(results_baseline['pixel_mse']),
            'pearson_corr': to_python(results_baseline['pearson_corr']),
            'spearman_corr': to_python(results_baseline['spearman_corr'])
        },
        'mmreg': {
            'latent_mse': to_python(results_mmreg['latent_mse']),
            'pixel_mse': to_python(results_mmreg['pixel_mse']),
            'pearson_corr': to_python(results_mmreg['pearson_corr']),
            'spearman_corr': to_python(results_mmreg['spearman_corr'])
        }
    },
    'diffusion_final_loss': {
        'baseline_train': to_python(baseline_diff_hist['train'][-1]['loss']),
        'baseline_val': to_python(baseline_diff_hist['val'][-1]['loss']),
        'mmreg_train': to_python(mmreg_diff_hist['train'][-1]['loss']),
        'mmreg_val': to_python(mmreg_diff_hist['val'][-1]['loss'])
    }
}

print("\nBottleneck VAE Comparison (256-d space):")
print(f"  Baseline - Latent MSE: {results_baseline['latent_mse']:.6f}, Pearson: {results_baseline['pearson_corr']:.4f}")
print(f"  MM-Reg   - Latent MSE: {results_mmreg['latent_mse']:.6f}, Pearson: {results_mmreg['pearson_corr']:.4f}")

print("\nDiffusion Final Val Loss (256-d):")
print(f"  Baseline: {summary['diffusion_final_loss']['baseline_val']:.6f}")
print(f"  MM-Reg:   {summary['diffusion_final_loss']['mmreg_val']:.6f}")

# Relative reduction in final diffusion validation loss (MM-Reg vs. baseline)
base_val = summary['diffusion_final_loss']['baseline_val']
mm_val = summary['diffusion_final_loss']['mmreg_val']
improvement = (base_val - mm_val) / base_val * 100
print(f"\nMM-Reg diffusion improvement: {improvement:.1f}%")

with open('./checkpoints/celeba_experiment_summary.json', 'w') as f:
    json.dump(summary, f, indent=2)

print("\nResults saved to ./checkpoints/celeba_experiment_summary.json")