# DDPM Training on CelebA with Matern32 GP Noise

This notebook trains a diffusion model on celebrity faces using **Gaussian Process structured noise** with a Matern 3/2 kernel.

## GP Noise for Faces

Using GP-structured noise for face generation has several potential advantages:
- **Spatial coherence**: Face features (eyes, nose, mouth) are spatially organized
- **Smooth textures**: Skin and hair have smooth, spatially correlated structure
- **Natural variation**: GP lengthscale controls the scale of noise features
- **Efficient**: Random Fourier features make sampling GP noise computationally tractable at image resolution

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import save_image, make_grid

import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import os

from realistica import (
    NoiseScheduler,
    UNet,
    DDPMTrainer,
    ImageGPNoiseSampler,
    sample_gp_noise_for_images,
    EMA,
    count_parameters
)

## Configuration

In [None]:
# Device configuration: prefer CUDA, then Apple Silicon MPS, and fall back to CPU.
# The device is chosen once, so the message printed below always names the device
# that is actually used (previously CPU could be reported and then replaced by MPS).
if torch.cuda.is_available():
    device = torch.device('cuda')
elif torch.backends.mps.is_available():
    device = torch.device('mps')
    print("Using Apple Silicon MPS acceleration")
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

# Hyperparameters
config = {
    'image_size': 64,
    'batch_size': 64,
    'num_epochs': 100,
    'learning_rate': 2e-4,
    'num_timesteps': 1000,
    'beta_start': 0.0001,
    'beta_end': 0.02,
    'schedule_type': 'linear',
    'save_interval': 10,
    'sample_interval': 10,
    'use_ema': True,
    'ema_decay': 0.9999,
    # GP noise parameters
    'use_gp_noise': True,
    'gp_lengthscale': 0.15,  # Larger for smoother noise (0.1-0.3 recommended)
    'gp_num_features': 1024,  # More features for better approximation
}

# Create output directories
os.makedirs('outputs/celeba_matern32/samples', exist_ok=True)
os.makedirs('outputs/celeba_matern32/checkpoints', exist_ok=True)

## Load CelebA Dataset

In [None]:
# Preprocessing: match the shorter side to image_size, center-crop to a square,
# convert to a [0, 1] tensor, then rescale to [-1, 1] via (x - 0.5) / 0.5.
transform = transforms.Compose([
    transforms.Resize(config['image_size']),
    transforms.CenterCrop(config['image_size']),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

train_dataset = datasets.CelebA(root='./data', split='train', download=True, transform=transform)

# Page-locked host memory only helps host-to-GPU copies, so request it for CUDA only.
use_pinned_memory = device.type == 'cuda'
train_loader = DataLoader(
    train_dataset,
    batch_size=config['batch_size'],
    shuffle=True,
    num_workers=4,
    pin_memory=use_pinned_memory,
    drop_last=True,  # discard the final partial batch so every batch has batch_size items
)

print(f"Training samples: {len(train_dataset)}")
print(f"Batches per epoch: {len(train_loader)}")

## Visualize GP Noise Patterns

In [None]:
# Compare regular i.i.d. noise vs Matern32 GP noise for RGB images.
num_noise_examples = 5
noise_img_size = config['image_size']
fig, axes = plt.subplots(2, num_noise_examples, figsize=(15, 6))

# Regular i.i.d. noise
regular_noise = torch.randn(num_noise_examples, 3, noise_img_size, noise_img_size)

# GP noise with Matern32
gp_noise = sample_gp_noise_for_images(
    batch_size=num_noise_examples,
    channels=3,
    height=noise_img_size,
    width=noise_img_size,
    lengthscale=config['gp_lengthscale'],
    num_features=config['gp_num_features'],
    kernel_type='matern32',
    device=device
)

noise_rows = [
    (regular_noise, 'Regular', 'Regular Noise'),
    (gp_noise, 'Matern32 GP', 'GP Noise (Matern32)'),
]
for row, (noise_batch, title_prefix, row_label) in enumerate(noise_rows):
    for i in range(num_noise_examples):
        ax = axes[row, i]
        img = noise_batch[i].permute(1, 2, 0).cpu().numpy()
        # Per-image min-max rescale to [0, 1] so imshow can display it as RGB.
        img = (img - img.min()) / (img.max() - img.min())
        ax.imshow(img)
        ax.set_title(f'{title_prefix} {i+1}')
        # Hide ticks but keep the axes itself: ax.axis('off') would also hide
        # the row label set with set_ylabel below.
        ax.set_xticks([])
        ax.set_yticks([])
    axes[row, 0].set_ylabel(row_label, fontsize=12, fontweight='bold')

plt.suptitle(f'RGB Noise Comparison (lengthscale={config["gp_lengthscale"]})', fontsize=14)
plt.tight_layout()
plt.show()

## Initialize Model and GP Sampler

In [None]:
# Forward-process noise schedule (betas go from beta_start to beta_end over num_timesteps)
noise_scheduler = NoiseScheduler(
    num_timesteps=config['num_timesteps'],
    beta_start=config['beta_start'],
    beta_end=config['beta_end'],
    schedule_type=config['schedule_type'],
    device=device,
)

# Noise-prediction U-Net for RGB images; attention_levels has one flag per
# entry of channel_multipliers.
model = UNet(
    in_channels=3,
    out_channels=3,
    base_channels=128,
    channel_multipliers=(1, 2, 3, 4),
    num_res_blocks=2,
    time_emb_dim=512,
    attention_levels=(False, False, True, True),
    dropout=0.1,
).to(device)

# Matern-3/2 GP noise sampler, reused for training and for sampling later on
gp_noise_sampler = ImageGPNoiseSampler(
    height=config['image_size'],
    width=config['image_size'],
    lengthscale=config['gp_lengthscale'],
    variance=1.0,
    num_features=config['gp_num_features'],
    kernel_type='matern32',
    device=device,
)

num_params = count_parameters(model)
fp32_size_mb = num_params * 4 / 1024**2  # 4 bytes per FP32 parameter
print(f"Model parameters: {num_params:,}")
print(f"Model size: ~{fp32_size_mb:.2f} MB (FP32)")
print(f"Using Matern32 GP noise with lengthscale={config['gp_lengthscale']}")

## Visualize Forward Diffusion with GP Noise

In [None]:
# Take the first image of one training batch and show it at increasing noise levels.
sample_batch, _ = next(iter(train_loader))
sample_image = sample_batch[:1].to(device)
timesteps_to_show = [0, 50, 100, 250, 500, 750, 999]

fig, axes = plt.subplots(1, len(timesteps_to_show), figsize=(16, 3))
for ax, t in zip(axes, timesteps_to_show):
    # Each timestep draws its own GP noise sample via add_gp_noise
    noisy_image, _ = noise_scheduler.add_gp_noise(
        sample_image,
        torch.tensor([t], device=device),
        gp_noise_sampler=gp_noise_sampler,
    )

    # Undo the [-1, 1] normalization for display and clamp out-of-range values
    img = np.clip(noisy_image[0].permute(1, 2, 0).cpu().numpy() * 0.5 + 0.5, 0, 1)
    ax.imshow(img)
    ax.set_title(f't={t}')
    ax.axis('off')

plt.suptitle('Forward Diffusion with Matern32 GP Noise', fontsize=14)
plt.tight_layout()
plt.show()

## Custom Trainer with GP Noise

In [None]:
class GPDDPMTrainer(DDPMTrainer):
    """
    DDPM trainer that corrupts images with GP-structured noise instead of i.i.d. noise.

    Only ``train_epoch`` is overridden. Attributes read here but not set here
    (``criterion``, ``losses``, ``global_step``, ``epoch``) are expected to be
    initialized by ``DDPMTrainer.__init__``.
    """

    def __init__(self, model, noise_scheduler, optimizer, device, gp_noise_sampler, criterion=None):
        """
        Args:
            model: network called as ``model(noisy_images, t)`` to predict the added noise.
            noise_scheduler: object providing ``sample_timesteps()`` and ``add_gp_noise()``.
            optimizer: optimizer stepped once per batch.
            device: device each batch of images is moved to.
            gp_noise_sampler: sampler forwarded to ``noise_scheduler.add_gp_noise``.
            criterion: loss between predicted and true noise; ``None`` is passed
                through to ``DDPMTrainer`` unchanged.
        """
        super().__init__(model, noise_scheduler, optimizer, device, criterion)
        self.gp_noise_sampler = gp_noise_sampler

    def train_epoch(self, dataloader):
        """
        Train for one epoch using GP noise.

        Args:
            dataloader: iterable of batches; each batch is either an image tensor
                or a list/tuple whose first element is the image tensor.

        Returns:
            Mean loss over the batches actually seen this epoch, or ``nan`` if the
            dataloader produced no batches.
        """
        self.model.train()
        epoch_loss = 0.0
        num_batches = 0
        progress_bar = tqdm(dataloader, desc=f"Epoch {self.epoch+1}")

        for batch in progress_bar:
            # (images, labels) batches keep only the images
            images = batch[0] if isinstance(batch, (list, tuple)) else batch
            images = images.to(self.device)
            batch_size = images.shape[0]

            # Sample timesteps
            t = self.noise_scheduler.sample_timesteps(batch_size)

            # GP-structured noise instead of i.i.d. Gaussian noise (KEY CHANGE)
            noisy_images, noise = self.noise_scheduler.add_gp_noise(
                images, t, gp_noise_sampler=self.gp_noise_sampler
            )

            # Predict the noise and compare against the noise actually added
            predicted_noise = self.model(noisy_images, t)
            loss = self.criterion(predicted_noise, noise)

            # Backward
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            # .item() synchronizes with the device; read the value once per step.
            loss_value = loss.item()
            epoch_loss += loss_value
            self.losses.append(loss_value)
            self.global_step += 1
            num_batches += 1

            progress_bar.set_postfix({'loss': f'{loss_value:.6f}'})

        self.epoch += 1
        # Count batches instead of len(dataloader): avoids ZeroDivisionError on an
        # empty loader and also works for iterables without __len__.
        return epoch_loss / num_batches if num_batches else float('nan')

## Setup Training

In [None]:
# Optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=config['learning_rate'])

# Learning rate scheduler: cosine decay to eta_min over T_max = num_epochs scheduler steps
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=config['num_epochs'],
    eta_min=1e-6
)

# Initialize GP trainer
trainer = GPDDPMTrainer(
    model=model,
    noise_scheduler=noise_scheduler,
    optimizer=optimizer,
    device=device,
    gp_noise_sampler=gp_noise_sampler
)

# EMA
# The trainer is never given `ema`, so nothing in the training loop updates it.
# Without the hook below, the shadow weights would stay at their initial
# (random) values, and every EMA-based sample/checkpoint would be noise.
ema = None
ema_hook_handle = None
if config['use_ema']:
    ema = EMA(model, decay=config['ema_decay'])
    # Update the shadow weights after every optimizer.step() (PyTorch >= 2.0 API).
    # NOTE(review): assumes realistica's EMA exposes a no-argument update(); confirm.
    ema_hook_handle = optimizer.register_step_post_hook(
        lambda opt, args, kwargs: ema.update()
    )
    print("Using EMA")

## Sampling Function with GP Noise

In [None]:
@torch.no_grad()
def sample_with_gp(model, noise_scheduler, gp_sampler, num_samples=64, device='cpu'):
    """
    Generate images by DDPM ancestral sampling driven by GP noise.

    GP noise is used both for the initial state and for the stochastic term of
    every reverse step except the last (t = 0), which adds no noise.

    Args:
        model: noise-prediction network called as ``model(x, t_batch)``.
        noise_scheduler: provides ``num_timesteps`` and indexable ``alphas``,
            ``alphas_cumprod`` and ``betas``.
        gp_sampler: object whose ``sample(num_samples, 3)`` returns a noise batch
            (presumably shaped (num_samples, 3, H, W) — confirm against the sampler).
        num_samples: number of images to generate.
        device: device for the per-step timestep tensor; should match the model.

    Returns:
        The final denoised batch ``x`` (same shape as the sampler's output).

    The model's previous train/eval mode is restored on exit, even on error.
    """
    # Remember the caller's mode rather than forcing train mode afterwards.
    was_training = model.training
    model.eval()
    try:
        # Start from GP noise
        x = gp_sampler.sample(num_samples, 3)

        num_steps = noise_scheduler.num_timesteps
        # reversed(range(...)) has no len(), so give tqdm the total explicitly.
        for t in tqdm(reversed(range(num_steps)), total=num_steps, desc='Sampling'):
            t_batch = torch.full((num_samples,), t, device=device, dtype=torch.long)

            predicted_noise = model(x, t_batch)

            alpha = noise_scheduler.alphas[t]
            alpha_cumprod = noise_scheduler.alphas_cumprod[t]
            beta = noise_scheduler.betas[t]

            # Fresh GP noise for every stochastic step; the final step is deterministic.
            if t > 0:
                noise = gp_sampler.sample(num_samples, 3)
            else:
                noise = torch.zeros_like(x)

            # x_{t-1} = 1/sqrt(alpha_t) * (x_t - beta_t / sqrt(1 - alpha_bar_t) * eps) + sqrt(beta_t) * z
            x = (
                1 / torch.sqrt(alpha) * (
                    x - (beta / torch.sqrt(1 - alpha_cumprod)) * predicted_noise
                ) + torch.sqrt(beta) * noise
            )
    finally:
        model.train(was_training)
    return x

def sample_and_save(trainer):
    """
    Generate 64 GP-noise samples, save them as an 8x8 grid PNG, and display the grid.

    Passed to ``trainer.train`` as ``sample_fn``. When EMA is enabled, the EMA
    weights are swapped in for sampling. The training weights are restored in a
    ``finally`` block, so an error or interrupt during sampling never leaves
    training running on the EMA copy.

    Args:
        trainer: trainer whose ``model``, ``noise_scheduler`` and ``epoch`` are used.
    """
    print("\nGenerating samples with GP noise...")

    if ema is not None:
        ema.apply_shadow()
    try:
        samples = sample_with_gp(
            model=trainer.model,
            noise_scheduler=trainer.noise_scheduler,
            gp_sampler=gp_noise_sampler,
            num_samples=64,
            device=device
        )
    finally:
        if ema is not None:
            ema.restore()

    # Map [-1, 1] samples to [0, 1] and tile them 8 per row
    grid = make_grid(samples, nrow=8, normalize=True, value_range=(-1, 1))
    save_path = f'outputs/celeba_matern32/samples/epoch_{trainer.epoch:04d}.png'
    save_image(grid, save_path)

    plt.figure(figsize=(12, 12))
    plt.imshow(grid.permute(1, 2, 0).cpu().numpy())
    plt.axis('off')
    plt.title(f'Generated Faces (GP Noise) - Epoch {trainer.epoch}')
    plt.tight_layout()
    plt.show()

    print(f"Samples saved to {save_path}")

## Training Loop

In [None]:
# Train: trainer.train (inherited from DDPMTrainer) receives the data, epoch count,
# LR scheduler, checkpoint settings and the periodic sampling callback.
metrics = trainer.train(
    dataloader=train_loader,
    num_epochs=config['num_epochs'],
    scheduler=lr_scheduler,
    save_dir='outputs/celeba_matern32/checkpoints',
    save_interval=config['save_interval'],
    sample_fn=sample_and_save,
    sample_interval=config['sample_interval'],
)

print("\nTraining completed!")

## Plot Training Metrics

In [None]:
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 5))

# Per-epoch loss, numbered from 1 like the "Epoch N" progress-bar labels
epoch_losses = metrics['epoch_losses']
ax1.plot(range(1, len(epoch_losses) + 1), epoch_losses, linewidth=2)
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Loss')
ax1.set_title('Training Loss per Epoch (GP Noise)')
ax1.grid(True, alpha=0.3)

step_losses = np.asarray(metrics['step_losses'], dtype=float)
ax2.plot(step_losses, alpha=0.3, label='Raw')

# Moving average over up to 100 steps. Capping the window at the number of
# recorded steps matters: np.convolve(..., mode='valid') swaps its arguments
# when the kernel is longer than the data and returns a meaningless curve,
# and a zero-length kernel raises.
window = min(100, len(step_losses))
if window > 0:
    smoothed = np.convolve(step_losses, np.ones(window) / window, mode='valid')
    # Each 'valid' average ends at step (window - 1 + i); align it with the raw curve.
    ax2.plot(np.arange(window - 1, len(step_losses)), smoothed, linewidth=2, label='Smoothed')
ax2.set_xlabel('Steps')
ax2.set_ylabel('Loss')
ax2.set_title('Training Loss per Step')
ax2.legend()
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('outputs/celeba_matern32/training_curves.png', dpi=150)
plt.show()

## Generate Final Samples

In [None]:
# Sample with the EMA weights when available. The training weights are restored
# in `finally`, so an interrupted sampling run cannot leave EMA weights in `model`
# (which the save cell would then store as model_state_dict).
if ema is not None:
    ema.apply_shadow()
try:
    final_samples = sample_with_gp(
        model=model,
        noise_scheduler=noise_scheduler,
        gp_sampler=gp_noise_sampler,
        num_samples=64,
        device=device
    )
finally:
    if ema is not None:
        ema.restore()

# Map [-1, 1] samples to [0, 1] and tile them 8 per row
grid = make_grid(final_samples, nrow=8, normalize=True, value_range=(-1, 1))
plt.figure(figsize=(14, 14))
plt.imshow(grid.permute(1, 2, 0).cpu().numpy())
plt.axis('off')
plt.title('Final Generated Faces with Matern32 GP Noise', fontsize=16)
plt.tight_layout()
plt.savefig('outputs/celeba_matern32/samples/final_samples.png', dpi=200)
plt.show()

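## Compare Different GP Lengthscales

The model was trained with a single lengthscale (`config['gp_lengthscale']` = 0.15), so sampling with any other lengthscale is a train/sample mismatch: differences between rows reflect both the noise structure and that distribution shift.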

In [None]:
# Generate samples with different lengthscales to see the effect.
# NOTE: the model was trained only with lengthscale=config['gp_lengthscale'], so
# every other value is a train/sample mismatch; differences between rows reflect
# both the noise structure and that distribution shift.
lengthscales = [0.05, 0.1, 0.15, 0.2, 0.3]
samples_per_row = 8
# squeeze=False keeps `axes` 2-D even if only one lengthscale is listed.
fig, axes = plt.subplots(
    len(lengthscales), samples_per_row, figsize=(16, 2 * len(lengthscales)), squeeze=False
)

if ema is not None:
    ema.apply_shadow()
try:
    for i, ls in enumerate(lengthscales):
        print(f"Generating with lengthscale {ls}...")

        # Same settings as gp_noise_sampler (including variance) except the lengthscale
        ls_sampler = ImageGPNoiseSampler(
            height=config['image_size'],
            width=config['image_size'],
            lengthscale=ls,
            variance=1.0,
            num_features=config['gp_num_features'],
            kernel_type='matern32',
            device=device
        )

        samples = sample_with_gp(
            model=model,
            noise_scheduler=noise_scheduler,
            gp_sampler=ls_sampler,
            num_samples=samples_per_row,
            device=device
        )

        for j in range(samples_per_row):
            ax = axes[i, j]
            # Undo the [-1, 1] normalization for display
            img = samples[j].permute(1, 2, 0).cpu().numpy() * 0.5 + 0.5
            ax.imshow(np.clip(img, 0, 1))
            # Hide ticks but keep the axes: axis('off') would also hide the row label.
            ax.set_xticks([])
            ax.set_yticks([])
        axes[i, 0].set_ylabel(f'ℓ={ls}', fontsize=12, fontweight='bold')
finally:
    # Always put the training weights back, even if this long loop is interrupted.
    if ema is not None:
        ema.restore()

plt.suptitle('Generated Faces with Different GP Lengthscales', fontsize=14)
plt.tight_layout()
plt.savefig('outputs/celeba_matern32/samples/lengthscale_comparison.png', dpi=150)
plt.show()

## Save Final Model

In [None]:
save_dict = {
    'model_state_dict': model.state_dict(),
    'config': config,
}

if ema is not None:
    ema.apply_shadow()
    try:
        # Take an independent copy of the EMA weights. Tensor.cpu() returns the
        # same tensor when it is already on the CPU, so without clone() the
        # snapshot could share storage that restore() may overwrite in place,
        # depending on how EMA.restore is implemented.
        save_dict['ema_state_dict'] = {
            k: v.detach().cpu().clone() for k, v in model.state_dict().items()
        }
    finally:
        ema.restore()

torch.save(save_dict, 'outputs/celeba_matern32/checkpoints/final_model_gp.pt')
print("Final model with GP noise saved!")

## Key Insights

### GP Noise for Face Generation:

1. **Spatial Structure**: Matern32 GP creates spatially correlated noise patterns that may align better with face structure

2. **Lengthscale Control**:
   - Small (0.05-0.1): Fine-grained noise, more texture detail
   - Medium (0.15-0.2): Balanced smoothness
   - Large (0.25-0.3): Very smooth, large-scale features

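3. **Computational Efficiency**: Random Fourier Features make GP sampling $O(NM)$ for $N$ pixels and $M$ features, instead of the $O(N^3)$ time and $O(N^2)$ memory of an exact Cholesky-based sample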

4. **Training Dynamics**: Model learns to denoise spatially correlated patterns

5. **Potential Benefits**:
   - May produce more realistic skin textures
   - Better preservation of spatial coherence
   - Could lead to smoother generated faces

### Experiment Ideas:
- Compare FID scores between regular and GP noise
- Try different kernels (RBF, Matern52)
- Vary lengthscale during training
- Use different lengthscales for different channels