<a href="https://colab.research.google.com/github/gnoejh/ict1022/blob/main/Architectures/diffusion.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Diffusion Models

Diffusion models are a class of generative models that have recently achieved state-of-the-art results in generating high-quality images, audio, and other data types. These models work by gradually adding noise to data and then learning to reverse this process.

## Conceptual Overview

Diffusion models operate based on two processes:

1. **Forward Process (Diffusion)**: Gradually adds noise to the data until it becomes pure noise
2. **Reverse Process (Denoising)**: Learns to gradually remove noise to recover the data

![Diffusion Process](https://i.imgur.com/uHDRgUc.png)

The key insight is that destroying information (adding noise) is easy and requires no learning, so the model only has to learn the hard direction: restoring information from noise. Because each training step is a plain regression problem rather than a two-player game, training tends to be more stable than with adversarial methods like GANs.

## Mathematical Foundation

### Forward Process

The forward diffusion process is defined as a Markov chain that gradually adds Gaussian noise to the data:

$$q(x_t | x_{t-1}) = \mathcal{N}(x_t; \sqrt{1 - \beta_t} x_{t-1}, \beta_t I)$$

where $\beta_t$ is a variance schedule that controls the noise level at each step.

An important property is that we can sample $x_t$ directly from $x_0$ without going through all intermediate steps:

$$q(x_t | x_0) = \mathcal{N}(x_t; \sqrt{\bar{\alpha}_t} x_0, (1 - \bar{\alpha}_t) I)$$

where $\bar{\alpha}_t = \prod_{i=1}^{t} (1 - \beta_i)$.
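
To see why this closed form holds, one can compose the single-step Gaussians numerically. The sketch below is a minimal check, not part of the original notebook. It assumes a linear schedule from `1e-4` to `0.02` over 1000 steps (the defaults used later in this notebook) and tracks only the scalar coefficient on $x_0$ and the accumulated noise variance; the helper names are hypothetical.

```python
import math

def linear_betas(beta_start=1e-4, beta_end=0.02, timesteps=1000):
    """Linearly spaced variance schedule (an assumed, commonly used choice)."""
    step = (beta_end - beta_start) / (timesteps - 1)
    return [beta_start + i * step for i in range(timesteps)]

def compose_forward_steps(betas):
    """Apply q(x_t | x_{t-1}) step by step, tracking x_0's coefficient and the noise variance."""
    coeff, var = 1.0, 0.0
    for beta in betas:
        coeff = math.sqrt(1.0 - beta) * coeff   # the mean is scaled by sqrt(1 - beta_t)
        var = (1.0 - beta) * var + beta         # old noise shrinks, fresh noise is added
    return coeff, var

betas = linear_betas()
alpha_bar = math.prod(1.0 - b for b in betas)   # \bar{alpha}_T

coeff, var = compose_forward_steps(betas)
print(f"step-by-step: coeff={coeff:.6f}, var={var:.6f}")
print(f"closed form : coeff={math.sqrt(alpha_bar):.6f}, var={1.0 - alpha_bar:.6f}")
```

Both lines should agree up to floating-point error. With this schedule $\bar{\alpha}_T$ is close to zero, so $x_T$ is almost pure noise.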

### Reverse Process

The goal is to learn the reverse process, which gradually denoises the data:

$$p_\theta(x_{t-1} | x_t) = \mathcal{N}(x_{t-1}; \mu_\theta(x_t, t), \Sigma_\theta(x_t, t))$$

A neural network with parameters $\theta$ is trained to predict the parameters of this distribution: the mean $\mu_\theta$ and, optionally, the covariance $\Sigma_\theta$.
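
For reference, one common way to pin these parameters down is the DDPM choice from Ho et al. (2020): fix the covariance to a schedule-dependent constant and express the mean through a noise-prediction network $\epsilon_\theta$. Writing $\alpha_t = 1 - \beta_t$, this reads:

$$\mu_\theta(x_t, t) = \frac{1}{\sqrt{\alpha_t}} \left( x_t - \frac{\beta_t}{\sqrt{1 - \bar{\alpha}_t}} \, \epsilon_\theta(x_t, t) \right), \qquad \Sigma_\theta(x_t, t) = \tilde{\beta}_t I, \quad \tilde{\beta}_t = \frac{1 - \bar{\alpha}_{t-1}}{1 - \bar{\alpha}_t} \, \beta_t$$

These are the mean and `posterior_variance` computed in the `DiffusionModel` code later in this notebook. Other choices, such as $\Sigma_\theta = \beta_t I$ or a learned variance, are also used in practice.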

## Training Objective

The training objective is derived from variational inference (a bound on the data log-likelihood). In practice, the network is trained to predict one of the following:

1. The added noise $\epsilon$ (noise prediction)
2. The clean data $x_0$ (data prediction)
3. The mean of the reverse process distribution ($\mu$ prediction)

For noise prediction, the simplified objective is:

$$L_{\text{simple}} = \mathbb{E}_{t, x_0, \epsilon} \left[ \| \epsilon - \epsilon_\theta(x_t, t) \|^2 \right]$$

where $\epsilon \sim \mathcal{N}(0, I)$ is the noise added during the forward process, $x_t = \sqrt{\bar{\alpha}_t} x_0 + \sqrt{1 - \bar{\alpha}_t} \, \epsilon$ is the resulting noisy input, and $\epsilon_\theta$ is the model's prediction of that noise.
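
The three prediction targets listed above carry the same information, because $x_t$, $x_0$ and $\epsilon$ are tied together by the closed-form forward equation. The sketch below illustrates this with random tensors standing in for data. The helper names (`predict_x0_from_eps`, `predict_eps_from_x0`) and the value of $\bar{\alpha}_t$ are hypothetical and chosen only for illustration.

```python
import torch

def predict_x0_from_eps(x_t, eps, alpha_bar_t):
    """Invert x_t = sqrt(abar) * x_0 + sqrt(1 - abar) * eps for x_0."""
    return (x_t - torch.sqrt(1.0 - alpha_bar_t) * eps) / torch.sqrt(alpha_bar_t)

def predict_eps_from_x0(x_t, x_0, alpha_bar_t):
    """Invert the same equation for eps."""
    return (x_t - torch.sqrt(alpha_bar_t) * x_0) / torch.sqrt(1.0 - alpha_bar_t)

torch.manual_seed(0)
x_0 = torch.randn(2, 3, 8, 8)        # stand-in for a batch of clean images
eps = torch.randn_like(x_0)          # noise drawn in the forward process
alpha_bar_t = torch.tensor(0.3)      # illustrative value of \bar{alpha}_t

# Closed-form forward sample
x_t = torch.sqrt(alpha_bar_t) * x_0 + torch.sqrt(1.0 - alpha_bar_t) * eps

# Knowing any two of (x_t, x_0, eps) determines the third
x_0_rec = predict_x0_from_eps(x_t, eps, alpha_bar_t)
eps_rec = predict_eps_from_x0(x_t, x_0, alpha_bar_t)
print(torch.allclose(x_0_rec, x_0, atol=1e-5), torch.allclose(eps_rec, eps, atol=1e-5))
```

In training, the choice of target mainly changes how errors are weighted across noise levels, not what the network can represent.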

In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

# Define a simple U-Net style architecture for the denoising network
class SimpleUNet(nn.Module):
    def __init__(self, channels=3, time_emb_dim=100):
        super().__init__()
        
        # Time embedding
        self.time_mlp = nn.Sequential(
            nn.Linear(1, time_emb_dim),
            nn.SiLU(),
            nn.Linear(time_emb_dim, time_emb_dim)
        )
        
        # Encoder
        self.conv1 = nn.Conv2d(channels, 64, 3, padding=1)
        self.conv2 = nn.Conv2d(64, 128, 3, stride=2, padding=1)
        self.conv3 = nn.Conv2d(128, 256, 3, stride=2, padding=1)
        
        # Middle with time embedding
        self.time_proj = nn.Linear(time_emb_dim, 256)
        self.mid_conv1 = nn.Conv2d(256, 256, 3, padding=1)
        self.mid_conv2 = nn.Conv2d(256, 256, 3, padding=1)
        
        # Decoder
        self.up1 = nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1)
        self.up2 = nn.ConvTranspose2d(128 + 128, 64, 4, stride=2, padding=1)  # Skip connection
        self.final = nn.Conv2d(64 + 64, channels, 3, padding=1)  # Skip connection
        
    def forward(self, x, t):
        # Time embedding
        t_emb = self.time_mlp(t.unsqueeze(-1).float())
        
        # Encoder
        x1 = F.silu(self.conv1(x))  # Save for skip connection
        x = F.silu(self.conv2(x1))
        x2 = x  # Save for skip connection
        x = F.silu(self.conv3(x))
        
        # Middle
        time_proj = self.time_proj(t_emb).unsqueeze(-1).unsqueeze(-1)
        x = x + time_proj  # Add time information
        x = F.silu(self.mid_conv1(x))
        x = F.silu(self.mid_conv2(x))
        
        # Decoder with skip connections
        x = F.silu(self.up1(x))
        x = torch.cat([x, x2], dim=1)  # Skip connection
        x = F.silu(self.up2(x))
        x = torch.cat([x, x1], dim=1)  # Skip connection
        
        return self.final(x)
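
The `time_mlp` above receives the raw integer timestep as a single scalar. A widely used alternative, borrowed from Transformer positional encodings, is to map the timestep to a vector of sines and cosines first. The module below is a hedged sketch of that idea and is not part of the original notebook; `SinusoidalTimeEmbedding` is a hypothetical name, and `dim` is assumed to be even.

```python
import math
import torch
import torch.nn as nn

class SinusoidalTimeEmbedding(nn.Module):
    """Map integer timesteps of shape (B,) to embeddings of shape (B, dim)."""

    def __init__(self, dim, max_period=10000.0):
        super().__init__()
        assert dim % 2 == 0, "dim is assumed to be even"
        self.dim = dim
        self.max_period = max_period

    def forward(self, t):
        half = self.dim // 2
        # Geometrically spaced frequencies, starting at 1 and decaying towards 1 / max_period
        exponent = -math.log(self.max_period) * torch.arange(half, device=t.device) / half
        freqs = torch.exp(exponent)
        args = t.float().unsqueeze(-1) * freqs          # (B, half)
        # First half of the embedding holds sines, second half cosines
        return torch.cat([torch.sin(args), torch.cos(args)], dim=-1)

emb = SinusoidalTimeEmbedding(dim=100)
t = torch.tensor([0, 10, 999])
print(emb(t).shape)
```

To try it in `SimpleUNet`, the first `nn.Linear(1, time_emb_dim)` would become `nn.Linear(time_emb_dim, time_emb_dim)` and be fed this embedding instead of `t.unsqueeze(-1).float()`.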

## Diffusion Process Implementation

In [6]:
class DiffusionModel(nn.Module):
    def __init__(self, model, beta_start=1e-4, beta_end=0.02, timesteps=1000):
        super().__init__()
        self.model = model
        self.timesteps = timesteps
        
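        # Define beta schedule. Registering the tensors as buffers makes
        # .to(device) move them together with the model weights.
        betas = torch.linspace(beta_start, beta_end, timesteps)
        alphas = 1.0 - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
        self.register_buffer("betas", betas)
        self.register_buffer("alphas", alphas)
        self.register_buffer("alphas_cumprod", alphas_cumprod)
        self.register_buffer("alphas_cumprod_prev", alphas_cumprod_prev)
        
        # Pre-compute values for sampling
        self.register_buffer("sqrt_recip_alphas", torch.sqrt(1.0 / alphas))
        self.register_buffer("sqrt_alphas_cumprod", torch.sqrt(alphas_cumprod))
        self.register_buffer("sqrt_one_minus_alphas_cumprod", torch.sqrt(1.0 - alphas_cumprod))
        self.register_buffer("posterior_variance", betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod))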
    
    def q_sample(self, x_0, t, noise=None):
        """Forward diffusion process: add noise to the data"""
        if noise is None:
            noise = torch.randn_like(x_0)
            
        # Extract the corresponding alpha values
        sqrt_alphas_cumprod_t = self.sqrt_alphas_cumprod[t]
        sqrt_one_minus_alphas_cumprod_t = self.sqrt_one_minus_alphas_cumprod[t]
        
        # Reshape for proper broadcasting
        sqrt_alphas_cumprod_t = sqrt_alphas_cumprod_t.reshape(-1, 1, 1, 1)
        sqrt_one_minus_alphas_cumprod_t = sqrt_one_minus_alphas_cumprod_t.reshape(-1, 1, 1, 1)
        
        # Apply the diffusion formula
        return sqrt_alphas_cumprod_t * x_0 + sqrt_one_minus_alphas_cumprod_t * noise
    
    def p_losses(self, x_0, t, noise=None):
        """Calculate the loss for training"""
        if noise is None:
            noise = torch.randn_like(x_0)
            
        # Forward diffuse the data
        x_noisy = self.q_sample(x_0, t, noise)
        
        # Predict the noise
        predicted_noise = self.model(x_noisy, t)
        
        # Calculate the simple MSE loss
        loss = F.mse_loss(noise, predicted_noise)
        
        return loss
    
    @torch.no_grad()
    def p_sample(self, x_t, t):
        """Sample from p(x_{t-1} | x_t) - one step of the reverse process.

        `t` is a (batch,) tensor of timestep indices.
        """
        # Get the model prediction
        predicted_noise = self.model(x_t, t)
        
        # Extract parameters and reshape for broadcasting over (B, C, H, W)
        beta_t = self.betas[t].reshape(-1, 1, 1, 1)
        sqrt_recip_alpha_t = self.sqrt_recip_alphas[t].reshape(-1, 1, 1, 1)
        sqrt_one_minus_alpha_bar_t = self.sqrt_one_minus_alphas_cumprod[t].reshape(-1, 1, 1, 1)
        
        # Calculate the mean for the reverse process
        mean = sqrt_recip_alpha_t * (x_t - (beta_t / sqrt_one_minus_alpha_bar_t) * predicted_noise)
        
        # Add noise only where t > 0; the final step (t == 0) returns the mean.
        # A per-sample mask avoids evaluating `if` on a batched tensor.
        posterior_std = torch.sqrt(self.posterior_variance[t]).reshape(-1, 1, 1, 1)
        nonzero_mask = (t > 0).float().reshape(-1, 1, 1, 1)
        return mean + nonzero_mask * posterior_std * torch.randn_like(x_t)
    
    @torch.no_grad()
    def sample(self, batch_size, img_shape, device):
        """Generate samples by running the complete reverse process"""
        # Start from pure noise
        img = torch.randn(batch_size, *img_shape, device=device)
        
        # Iteratively denoise
        for t in reversed(range(self.timesteps)):
            t_batch = torch.full((batch_size,), t, device=device, dtype=torch.long)
            img = self.p_sample(img, t_batch)
            
        return img

## Training Loop Example

In [7]:
def train_diffusion(diffusion_model, dataloader, optimizer, device, epochs):
    """Train the diffusion model"""
    diffusion_model.train()
    
    for epoch in range(epochs):
        running_loss = 0.0
        
        for i, batch in enumerate(dataloader):
            # Get batch of images
            x = batch[0].to(device)
            batch_size = x.shape[0]
            
            # Reset gradients
            optimizer.zero_grad()
            
            # Sample random timesteps
            t = torch.randint(0, diffusion_model.timesteps, (batch_size,), device=device).long()
            
            # Calculate loss
            loss = diffusion_model.p_losses(x, t)
            
            # Backpropagation
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item()
            
            # Print progress
            if i % 100 == 99:
                print(f"Epoch {epoch+1}, Batch {i+1}, Loss: {running_loss / 100:.4f}")
                running_loss = 0.0
        
        # Generate samples after each epoch
        if (epoch + 1) % 5 == 0:
            with torch.no_grad():
                samples = diffusion_model.sample(4, x.shape[1:], device)
                # Here you would save or display the samples
                
    print("Training completed.")

## Setup Example

In [8]:
# Example setup (not executed)
def setup_example():
    import torch
    from torch.utils.data import DataLoader
    from torchvision import datasets, transforms
    
    # Set device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    # Setup data
    transform = transforms.Compose([
        transforms.Resize(32),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])
    
    dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
    dataloader = DataLoader(dataset, batch_size=64, shuffle=True)
    
    # Create model
    unet = SimpleUNet(channels=3).to(device)
    diffusion = DiffusionModel(unet).to(device)
    
    # Setup optimizer
    optimizer = torch.optim.Adam(diffusion.parameters(), lr=1e-4)
    
    # Train model
    train_diffusion(diffusion, dataloader, optimizer, device, epochs=30)
    
    # Generate samples
    samples = diffusion.sample(16, (3, 32, 32), device)
    
    return samples

# Note: This function is not executed in this notebook

## Key Variants of Diffusion Models

### DDPM (Denoising Diffusion Probabilistic Models)
- The formulation of Ho et al. (2020), with a fixed variance schedule
- Excellent quality but requires many sampling steps (typically 1000)

### DDIM (Denoising Diffusion Implicit Models)
- Extends DDPM with a non-Markovian process
- Enables faster sampling with fewer steps (10-50 steps often sufficient)
- Supports deterministic sampling
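
As a rough illustration of the deterministic case, the sketch below implements one DDIM update with the stochasticity parameter set to zero. It assumes a noise-prediction model and a 1-D tensor `alphas_cumprod` holding $\bar{\alpha}_t$; the function name `ddim_step` and the oracle noise used in the demo are hypothetical.

```python
import torch

def ddim_step(x_t, eps_pred, t, t_prev, alphas_cumprod):
    """One deterministic DDIM update from timestep t to an earlier timestep t_prev."""
    abar_t = alphas_cumprod[t]
    abar_prev = alphas_cumprod[t_prev]
    # Estimate the clean sample implied by the current noise prediction
    x0_pred = (x_t - torch.sqrt(1.0 - abar_t) * eps_pred) / torch.sqrt(abar_t)
    # Re-noise that estimate to the lower noise level of t_prev, reusing eps_pred
    return torch.sqrt(abar_prev) * x0_pred + torch.sqrt(1.0 - abar_prev) * eps_pred

# Demo with an "oracle" that knows the true noise, so the jump is exact.
torch.manual_seed(0)
alphas_cumprod = torch.cumprod(1.0 - torch.linspace(1e-4, 0.02, 1000), dim=0)
x_0 = torch.randn(1, 3, 8, 8)
eps = torch.randn_like(x_0)

t, t_prev = 800, 600   # skip 200 steps at once
x_t = torch.sqrt(alphas_cumprod[t]) * x_0 + torch.sqrt(1.0 - alphas_cumprod[t]) * eps
x_prev = ddim_step(x_t, eps, t, t_prev, alphas_cumprod)

# What the forward process would give at t_prev with the same x_0 and eps
expected = torch.sqrt(alphas_cumprod[t_prev]) * x_0 + torch.sqrt(1.0 - alphas_cumprod[t_prev]) * eps
print(torch.allclose(x_prev, expected, atol=1e-4))
```

In a real sampler, `eps_pred` would come from the trained network, and the loop would visit a short, decreasing subsequence of timesteps instead of all 1000.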

### Latent Diffusion Models (LDMs)
- Apply diffusion in a compressed latent space instead of pixel space
- Significantly more efficient for high-resolution images
- Used in Stable Diffusion for text-to-image generation
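
To get a feel for the savings, the short calculation below compares tensor sizes in pixel space and latent space. It assumes a 512×512 RGB image, a spatial downsampling factor of 8 and 4 latent channels, the configuration commonly associated with Stable Diffusion v1; other autoencoders will give different ratios.

```python
def num_elements(channels, height, width):
    """Number of values the denoiser has to process per sample."""
    return channels * height * width

downsample = 8                                   # assumed autoencoder downsampling factor
pixel = num_elements(3, 512, 512)                # RGB image in pixel space
latent = num_elements(4, 512 // downsample, 512 // downsample)  # assumed 4 latent channels

print(pixel, latent, pixel // latent)            # prints: 786432 16384 48
```

Every denoising step then operates on roughly 48 times fewer values, which is where most of the efficiency gain comes from.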

### Score-Based Generative Models
- Alternative mathematical formulation using score matching
- Shown to be equivalent to diffusion models under certain conditions
- Often uses stochastic differential equations (SDEs) for sampling
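
One concrete link between the two views: for the Gaussian $q(x_t | x_0)$ defined earlier, the score $\nabla_{x_t} \log q(x_t | x_0)$ equals $-\epsilon / \sqrt{1 - \bar{\alpha}_t}$, so a noise predictor can be read as a rescaled score estimate. The sketch below checks this identity with autograd on random tensors; the value of $\bar{\alpha}_t$ is arbitrary and used for illustration only.

```python
import torch

torch.manual_seed(0)
alpha_bar_t = torch.tensor(0.4)      # illustrative \bar{alpha}_t
x_0 = torch.randn(4, 5)
eps = torch.randn_like(x_0)

# Forward sample x_t, marked as requiring gradients so we can differentiate with respect to it
x_t = (torch.sqrt(alpha_bar_t) * x_0 + torch.sqrt(1.0 - alpha_bar_t) * eps).requires_grad_(True)

# log N(x_t; sqrt(abar) x_0, (1 - abar) I), dropping terms that do not depend on x_t
log_q = -((x_t - torch.sqrt(alpha_bar_t) * x_0) ** 2).sum() / (2.0 * (1.0 - alpha_bar_t))
(score,) = torch.autograd.grad(log_q, x_t)

# The score should match the rescaled, negated noise
print(torch.allclose(score, -eps / torch.sqrt(1.0 - alpha_bar_t), atol=1e-5))
```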

## Conditioning Techniques

Diffusion models can be conditioned on various inputs to control generation:

### Class Conditioning
```python
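class ConditionalUNet(SimpleUNet):
    """SimpleUNet (defined earlier) with an added class embedding."""
    def __init__(self, num_classes=10, time_emb_dim=100, **kwargs):
        super().__init__(time_emb_dim=time_emb_dim, **kwargs)
        # One learned vector per class, sized to match the time embedding
        self.class_embedding = nn.Embedding(num_classes, time_emb_dim)
        
    def forward(self, x, t, class_labels):
        # Get class embedding and combine with time embedding
        # (summation is one simple choice; concatenation is another)
        t_emb = self.time_mlp(t.unsqueeze(-1).float())
        cond = t_emb + self.class_embedding(class_labels)
        # Same encoder / middle / decoder path as SimpleUNet, driven by `cond`
        x1 = F.silu(self.conv1(x))
        x2 = F.silu(self.conv2(x1))
        h = F.silu(self.conv3(x2))
        h = h + self.time_proj(cond).unsqueeze(-1).unsqueeze(-1)
        h = F.silu(self.mid_conv2(F.silu(self.mid_conv1(h))))
        h = torch.cat([F.silu(self.up1(h)), x2], dim=1)
        h = torch.cat([F.silu(self.up2(h)), x1], dim=1)
        return self.final(h)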
```

### Text Conditioning (as in Stable Diffusion)
- Use a pre-trained text encoder (like CLIP)
- Extract text embeddings and inject them into the UNet
- Often uses cross-attention mechanisms
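
A minimal sketch of the cross-attention idea follows. Image features act as queries and text-token embeddings act as keys and values. All sizes here are made up, and random tensors stand in for the output of a text encoder such as CLIP; `CrossAttentionBlock` is a hypothetical module, not the Stable Diffusion implementation.

```python
import torch
import torch.nn as nn

class CrossAttentionBlock(nn.Module):
    """Let each spatial location of a feature map attend over text tokens."""

    def __init__(self, channels, text_dim, num_heads=4):
        super().__init__()
        self.norm = nn.LayerNorm(channels)
        self.attn = nn.MultiheadAttention(
            embed_dim=channels, num_heads=num_heads,
            kdim=text_dim, vdim=text_dim, batch_first=True,
        )

    def forward(self, feats, text_emb):
        b, c, h, w = feats.shape
        q = feats.flatten(2).transpose(1, 2)           # (B, H*W, C): one query per location
        out, _ = self.attn(self.norm(q), text_emb, text_emb)
        q = q + out                                    # residual connection
        return q.transpose(1, 2).reshape(b, c, h, w)   # back to (B, C, H, W)

block = CrossAttentionBlock(channels=64, text_dim=32)
feats = torch.randn(2, 64, 8, 8)     # stand-in for a UNet feature map
text_emb = torch.randn(2, 7, 32)     # stand-in for 7 text-token embeddings per sample
print(block(feats, text_emb).shape)
```

The output keeps the shape of the input feature map, so a block like this can be dropped between convolutional stages without changing the rest of the network.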

### Image Conditioning
- For image-to-image translation tasks
- Can use concatenation or cross-attention mechanisms
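
The concatenation variant is the simpler of the two: the conditioning image is stacked with the noisy image along the channel axis, so the first convolution must accept twice as many input channels. The snippet below is a shape-level sketch with made-up sizes and random tensors in place of real images.

```python
import torch
import torch.nn as nn

channels = 3
# First layer of a denoiser that sees [noisy image, conditioning image]
first_conv = nn.Conv2d(2 * channels, 64, kernel_size=3, padding=1)

x_noisy = torch.randn(4, channels, 32, 32)    # stand-in for x_t
cond_img = torch.randn(4, channels, 32, 32)   # stand-in for, e.g., a masked or degraded input

# Stack along the channel dimension before the first convolution
h = first_conv(torch.cat([x_noisy, cond_img], dim=1))
print(h.shape)
```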

## Applications of Diffusion Models

1. **Image Generation**
   - Text-to-image (Stable Diffusion, DALL-E 2, Midjourney)
   - High-resolution image synthesis
   - Image-to-image translation

2. **Audio Generation**
   - Speech synthesis (WaveGrad, DiffWave)
   - Music generation
   - Sound effects creation

3. **Video Generation**
   - Text-to-video models (Imagen Video, Make-A-Video)
   - Video prediction and interpolation

4. **3D Content Creation**
   - Text-to-3D (DreamFusion)
   - 3D asset generation

5. **Scientific Applications**
   - Protein structure generation
   - Molecule design
   - Material discovery

## Advantages and Challenges

### Advantages
- High-quality outputs that can match or exceed GANs on image synthesis benchmarks (Dhariwal & Nichol, 2021)
- Stable training without adversarial dynamics
- Greater diversity in generated samples
- Flexible conditioning mechanisms
- Well-understood probabilistic foundation

### Challenges
- Slow sampling process (though faster samplers such as DDIM reduce the number of steps)
- High computational requirements
- Mode coverage sometimes at the expense of sample fidelity
- Hyperparameter sensitivity

## References and Further Reading

- Ho, J., et al. (2020). [Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006.11239). NeurIPS.
- Song, J., et al. (2021). [Denoising Diffusion Implicit Models](https://arxiv.org/abs/2010.02502). ICLR.
- Rombach, R., et al. (2022). [High-Resolution Image Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752). CVPR.
- Dhariwal, P., & Nichol, A. (2021). [Diffusion Models Beat GANs on Image Synthesis](https://arxiv.org/abs/2105.05233). NeurIPS.
- Song, Y., & Ermon, S. (2019). [Generative Modeling by Estimating Gradients of the Data Distribution](https://arxiv.org/abs/1907.05600). NeurIPS.

### Implementations
- [HuggingFace Diffusers Library](https://github.com/huggingface/diffusers)
- [Stable Diffusion](https://github.com/CompVis/stable-diffusion)
- [Annotated Diffusion Model](https://huggingface.co/blog/annotated-diffusion)