# Diffusion Process Implementation

This notebook implements the **forward diffusion process** for DDPM (Denoising Diffusion Probabilistic Models).

## Key Equations

**Forward process:** $q(x_t | x_0) = \mathcal{N}(x_t; \sqrt{\bar{\alpha}_t} x_0, (1 - \bar{\alpha}_t) \mathbf{I})$

This allows us to sample $x_t$ directly from $x_0$ without iterating through all steps.
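
Why does the shortcut hold? One way to check it is to follow the mean and variance of a single scalar pixel through the standard one-step transition $q(x_t | x_{t-1}) = \mathcal{N}(x_t; \sqrt{\alpha_t} x_{t-1}, \beta_t \mathbf{I})$, then compare them with the closed form above. The sketch below does this in plain Python. The `betas` list and `x0` value are illustrative assumptions, not the notebook's configuration.

```python
import math

# Illustrative schedule and pixel value (assumptions, not the notebook's config)
betas = [0.1, 0.2, 0.3, 0.4]
x0 = 0.8

# Iterate x_t = sqrt(alpha_t) * x_{t-1} + sqrt(beta_t) * eps_t with independent eps_t ~ N(0, 1).
# Mean and variance then evolve as:
#   mean_t = sqrt(alpha_t) * mean_{t-1}
#   var_t  = alpha_t * var_{t-1} + beta_t
mean, var = x0, 0.0
alpha_bar = 1.0
for beta in betas:
    alpha = 1.0 - beta
    mean = math.sqrt(alpha) * mean
    var = alpha * var + beta
    alpha_bar *= alpha

    # Closed form: q(x_t | x_0) = N(sqrt(alpha_bar_t) * x0, 1 - alpha_bar_t)
    assert math.isclose(mean, math.sqrt(alpha_bar) * x0)
    assert math.isclose(var, 1.0 - alpha_bar)

print(f"mean={mean:.6f}, var={var:.6f}, alpha_bar={alpha_bar:.6f}")
```

At every step the mean shrinks by $\sqrt{\alpha_t}$ and the variance moves toward 1. This is why $x_T$ ends up close to a standard Gaussian.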

## Notation

| Symbol | Name | Description |
|--------|------|-------------|
| $\beta_t$ | Beta | Variance schedule, controls noise added at each step |
| $\alpha_t$ | Alpha | $1 - \beta_t$ |
| $\bar{\alpha}_t$ | Alpha bar | $\prod_{s=1}^{t} \alpha_s$ (cumulative product) |

**Reference:** "Denoising Diffusion Probabilistic Models" (Ho et al., 2020)
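
The paper indexes steps from 1 to $T$. Tensors in code are 0-indexed, so $\beta_1, \dots, \beta_T$ are stored at indices `0` to `T-1`. Here is a minimal sketch of the three table quantities under that convention. It uses Python's `itertools.accumulate` for the cumulative product, and the schedule values are made up for illustration.

```python
from itertools import accumulate
import operator

# Illustrative beta_1..beta_4 stored 0-based (list index i <-> paper step t = i + 1)
betas = [0.01, 0.02, 0.03, 0.04]

# alpha_t = 1 - beta_t
alphas = [1.0 - b for b in betas]

# alpha_bar_t = prod_{s=1}^{t} alpha_s, i.e. a running product over alphas
alphas_cumprod = list(accumulate(alphas, operator.mul))

for i, (b, a, ab) in enumerate(zip(betas, alphas, alphas_cumprod)):
    print(f"index {i} (step t={i + 1}): beta={b:.2f}, alpha={a:.2f}, alpha_bar={ab:.6f}")
```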


In [4]:
%load_ext autoreload
%autoreload 2
import torch
import matplotlib.pyplot as plt
from config import T  # Total number of diffusion timesteps (e.g., 1000)
from dataset import MNIST

print(f"Total diffusion timesteps T = {T}")

## 1. Diffusion Schedule Parameters

These parameters define the noise schedule. They are precomputed once as length-`T` tensors, so the values for any timestep are an index lookup rather than a recomputation.

In [5]:
# Beta schedule: linear increase from 0.0001 to 0.02 over T steps
# β_t controls how much noise is added at each step
# Small values at start (preserve structure) -> larger values at end (more noise)
betas = torch.linspace(0.0001, 0.02, T)  # Shape: (T,)

print(f"betas shape: {betas.shape}")
print(f"betas[0] = {betas[0]:.6f} (start)")
print(f"betas[-1] = {betas[-1]:.6f} (end)")
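
`torch.linspace(start, end, steps)` returns `steps` evenly spaced values and includes both endpoints. For this schedule, the entry at index $i$ is therefore $0.0001 + i \cdot (0.02 - 0.0001)/(T - 1)$. The plain-Python sketch below reproduces that formula. It assumes $T = 1000$, the value used by Ho et al. (2020); the notebook itself reads `T` from `config`.

```python
# Assumption: T = 1000; the notebook imports the real value from config
T = 1000
beta_start, beta_end = 0.0001, 0.02

# Evenly spaced values with both endpoints included, like torch.linspace
step = (beta_end - beta_start) / (T - 1)
betas = [beta_start + i * step for i in range(T)]

print(f"increment per step: {step:.3e}")
print(f"first: {betas[0]}, last: {betas[-1]:.6f}")
```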

In [6]:
# Alpha: the "keep" ratio at each step
# α_t = 1 - β_t; one forward step scales the signal by sqrt(α_t)
alphas = 1 - betas  # Shape: (T,)

# Alpha cumulative product: product of all alphas up to and including index t
# α̅_t = α_1 * α_2 * ... * α_t (1-based in the paper; code index t is paper step t+1)
# sqrt(α̅_t) is the fraction of the original signal amplitude that remains at step t
alphas_cumprod = torch.cumprod(alphas, dim=-1)  # Shape: (T,)

# Alpha cumulative product for the previous timestep: alphas_cumprod_prev[t] = α̅_{t-1}
# Prepend 1.0 for the first index, which has no previous step (no noise added yet, full signal)
alphas_cumprod_prev = torch.cat(
    (torch.tensor([1.0]), alphas_cumprod[:-1]), dim=-1
)  # Shape: (T,)

print(f"alphas_cumprod[0] = {alphas_cumprod[0]:.4f} (almost all signal)")
print(f"alphas_cumprod[-1] = {alphas_cumprod[-1]:.6f} (almost no signal)")
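
Why is $\bar{\alpha}_T$ so small? Taking logs turns the product into a sum: $\log \bar{\alpha}_T = \sum_{t=1}^{T} \log(1 - \beta_t) \approx -\sum_{t=1}^{T} \beta_t$ for small $\beta_t$. For a linear schedule the betas form an arithmetic series. With $T = 1000$, the sum is $T(0.0001 + 0.02)/2 = 10.05$, so $\bar{\alpha}_T \approx e^{-10}$. The sketch below compares the exact product with that estimate, again assuming $T = 1000$.

```python
import math

# Assumption: T = 1000 and the same linear schedule as above
T = 1000
beta_start, beta_end = 0.0001, 0.02
betas = [beta_start + i * (beta_end - beta_start) / (T - 1) for i in range(T)]

# Exact: alpha_bar_T = prod_t (1 - beta_t)
alpha_bar_T = math.prod(1.0 - b for b in betas)

# First-order estimate: log(1 - beta) ~= -beta for small beta
approx_log = -sum(betas)

print(f"exact alpha_bar_T       = {alpha_bar_T:.3e}")
print(f"estimate exp(-sum beta) = {math.exp(approx_log):.3e}")
```

The estimate comes out slightly larger than the exact value because $\log(1 - \beta) < -\beta$. Either way, only about $\sqrt{\bar{\alpha}_T} \approx 0.006$ of the original signal amplitude survives.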

In [7]:
# Posterior variance: used in the reverse (denoising) process
# This is the variance of q(x_{t-1} | x_t, x_0)
# Formula: σ²_t = β_t * (1 - α̅_{t-1}) / (1 - α̅_t), where (1 - alphas) equals β_t
# variance[0] is 0 because alphas_cumprod_prev[0] = 1.0
variance = (1 - alphas) * (1 - alphas_cumprod_prev) / (1 - alphas_cumprod)  # Shape: (T,)

print(f"variance shape: {variance.shape}")
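
Two properties of the posterior variance are easy to check numerically. First, because 1.0 is prepended for $\bar{\alpha}_{t-1}$, the first entry is exactly zero. Second, $\bar{\alpha}$ decreases, so $1 - \bar{\alpha}_{t-1} < 1 - \bar{\alpha}_t$, and every entry is therefore strictly smaller than the corresponding $\beta_t$. A plain-Python sketch, assuming $T = 1000$:

```python
# Assumption: T = 1000 with the linear schedule used above
T = 1000
beta_start, beta_end = 0.0001, 0.02
betas = [beta_start + i * (beta_end - beta_start) / (T - 1) for i in range(T)]

# Running product alpha_bar_t and its one-step shift (1.0 prepended for the first index)
alphas_cumprod = []
running = 1.0
for b in betas:
    running *= 1.0 - b
    alphas_cumprod.append(running)
alphas_cumprod_prev = [1.0] + alphas_cumprod[:-1]

# Posterior variance: beta_t * (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t)
variance = [
    b * (1.0 - ab_prev) / (1.0 - ab)
    for b, ab, ab_prev in zip(betas, alphas_cumprod, alphas_cumprod_prev)
]

print(f"variance[0]  = {variance[0]}")
print(f"variance[-1] = {variance[-1]:.6f} vs beta[-1] = {betas[-1]:.6f}")
```

At large $t$, both $\bar{\alpha}_{t-1}$ and $\bar{\alpha}_t$ are close to 0, so the posterior variance approaches $\beta_t$.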

In [9]:
# Visualize the diffusion schedule
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

axes[0].plot(betas.numpy())
axes[0].set_title("Beta Schedule (β_t)")
axes[0].set_xlabel("Timestep t")
axes[0].set_ylabel("β_t")

axes[1].plot(alphas_cumprod.numpy())
axes[1].set_title("Alpha Cumulative Product (α̅_t)")
axes[1].set_xlabel("Timestep t")
axes[1].set_ylabel("α̅_t")

axes[2].plot(variance.numpy())
axes[2].set_title("Posterior Variance (σ²_t)")
axes[2].set_xlabel("Timestep t")
axes[2].set_ylabel("σ²_t")

plt.tight_layout()
plt.show()

## 2. Forward Diffusion Process

Each forward step adds independent Gaussian noise, so composing $t$ steps yields another Gaussian. We can therefore sample $x_t$ directly from $x_0$ without iterating through all intermediate steps:

$$x_t = \sqrt{\bar{\alpha}_t} \cdot x_0 + \sqrt{1 - \bar{\alpha}_t} \cdot \epsilon$$

where $\epsilon \sim \mathcal{N}(0, \mathbf{I})$ is standard Gaussian noise.
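
A quick Monte Carlo sanity check of this formula: fix a scalar $x_0$ and a value of $\bar{\alpha}_t$, then draw many $\epsilon$. The empirical mean should land near $\sqrt{\bar{\alpha}_t}\, x_0$ and the empirical variance near $1 - \bar{\alpha}_t$. The `x0` and `alpha_bar` values below are arbitrary choices for illustration.

```python
import math
import random

random.seed(0)  # reproducible draws

x0 = 0.5          # illustrative clean pixel value in [-1, 1]
alpha_bar = 0.36  # illustrative alpha_bar_t, so sqrt(alpha_bar) = 0.6

# x_t = sqrt(alpha_bar) * x0 + sqrt(1 - alpha_bar) * eps, with eps ~ N(0, 1)
samples = [
    math.sqrt(alpha_bar) * x0 + math.sqrt(1 - alpha_bar) * random.gauss(0.0, 1.0)
    for _ in range(100_000)
]

emp_mean = sum(samples) / len(samples)
emp_var = sum((s - emp_mean) ** 2 for s in samples) / len(samples)
print(f"empirical mean {emp_mean:.4f} vs expected {math.sqrt(alpha_bar) * x0:.4f}")
print(f"empirical var  {emp_var:.4f} vs expected {1 - alpha_bar:.4f}")
```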

In [10]:
def forward_add_noise(x, t):
    """
    Add noise to images according to the forward diffusion process.
    
    Args:
        x: Clean images, shape (batch, channel, height, width), range [-1, 1]
        t: Integer (long) timestep indices for each image, shape (batch,), values in [0, T-1]
    
    Returns:
        x_noisy: Noisy images, same shape as x
        noise: The Gaussian noise that was added (training target)
    """
    # Sample random Gaussian noise with same shape as input
    noise = torch.randn_like(x)
    
    # Look up α̅_t for each image's timestep and reshape to (batch, 1, 1, 1)
    # so it broadcasts over the channel, height, and width dimensions
    batch_alphas_cumprod = alphas_cumprod[t].view(x.size(0), 1, 1, 1)
    
    # Apply the forward diffusion formula:
    # x_t = sqrt(α̅_t) * x_0 + sqrt(1 - α̅_t) * ε
    x_noisy = (
        torch.sqrt(batch_alphas_cumprod) * x
        + torch.sqrt(1 - batch_alphas_cumprod) * noise
    )
    
    return x_noisy, noise
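
The forward formula is linear in $x_0$ and $\epsilon$. Given the added noise and $\bar{\alpha}_t$, the clean input is therefore recoverable exactly: $x_0 = (x_t - \sqrt{1 - \bar{\alpha}_t}\,\epsilon) / \sqrt{\bar{\alpha}_t}$. This is one way to see why the returned `noise` makes a useful training target: a model that predicts $\epsilon$ well can also be used to estimate $x_0$. Below is a scalar sketch. `recover_x0` is a hypothetical helper, not part of this notebook.

```python
import math
import random


def recover_x0(x_t: float, eps: float, alpha_bar: float) -> float:
    """Hypothetical helper: solve x_t = sqrt(a) * x0 + sqrt(1 - a) * eps for x0."""
    return (x_t - math.sqrt(1.0 - alpha_bar) * eps) / math.sqrt(alpha_bar)


random.seed(1)
x0 = -0.25       # illustrative clean value
alpha_bar = 0.5  # illustrative alpha_bar_t
eps = random.gauss(0.0, 1.0)

# Forward step: same formula as forward_add_noise, applied to one scalar
x_t = math.sqrt(alpha_bar) * x0 + math.sqrt(1.0 - alpha_bar) * eps

print(f"x_t = {x_t:.4f}, recovered x0 = {recover_x0(x_t, eps, alpha_bar):.4f}")
```

Suppose the noise estimate is off by $\delta$. The recovered $x_0$ is then off by $\delta \sqrt{1 - \bar{\alpha}_t} / \sqrt{\bar{\alpha}_t}$, which grows without bound as $\bar{\alpha}_t \to 0$.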

## 3. Visualization

Let's load some MNIST images and see how the forward diffusion process adds noise.

In [16]:
# Load MNIST dataset
dataset = MNIST()
print(f"Dataset size: {len(dataset)}")

# Stack 2 images into a batch: (2, 1, 28, 28)
x = torch.stack((dataset[0][0], dataset[1][0]), dim=0)
print(f"Batch shape: {x.shape}")
print(f"Image range is {x.min().item()} to {x.max().item()}")

In [17]:
# Display original images
plt.figure(figsize=(8, 4))
plt.subplot(1, 2, 1)
plt.title("Original Image 1")
plt.imshow(x[0].permute(1, 2, 0), cmap='gray')
plt.axis('off')
plt.subplot(1, 2, 2)
plt.title("Original Image 2")
plt.imshow(x[1].permute(1, 2, 0), cmap='gray')
plt.axis('off')
plt.tight_layout()
plt.show()

In [18]:
# Rescale pixel values from [0, 1] to [-1, 1]
# This centers the data at 0, on a scale comparable to the standard Gaussian noise N(0, 1)
x_normalized = x * 2 - 1

# Sample random timesteps for each image
t = torch.randint(0, T, size=(x.size(0),))
print(f"Timesteps: {t.tolist()}")

# Apply forward diffusion (add noise)
x_noisy, noise = forward_add_noise(x_normalized, t)
print(f"Noisy image shape: {x_noisy.shape}")
print(f"Noise shape: {noise.shape}")
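
The noisy tensor is no longer confined to $[-1, 1]$. The added noise is unbounded, so mapping back with $(x + 1)/2$ can produce values outside $[0, 1]$. Here is a minimal sketch of the two range conversions, with a clamp for display. `to_model_range` and `to_display_range` are hypothetical names, and the scalar inputs stand in for pixel values.

```python
def to_model_range(p: float) -> float:
    """Map a pixel value from [0, 1] to [-1, 1]."""
    return p * 2.0 - 1.0


def to_display_range(x: float) -> float:
    """Map from [-1, 1] back to [0, 1], clamping values that noise pushed outside."""
    return min(max((x + 1.0) / 2.0, 0.0), 1.0)


pixels = [0.0, 0.5, 1.0]
print([to_model_range(p) for p in pixels])                    # [-1.0, 0.0, 1.0]
print([to_display_range(to_model_range(p)) for p in pixels])  # [0.0, 0.5, 1.0]
print(to_display_range(1.8), to_display_range(-2.3))          # 1.0 0.0
```

On a PyTorch tensor, the same clamp is `tensor.clamp(0, 1)`.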

In [28]:
# Display noisy images
# Convert back from [-1, 1] to [0, 1] for visualization; the added noise can push
# values outside that range, so clamp and fix the color scale with vmin/vmax
plt.figure(figsize=(8, 4))
plt.subplot(1, 2, 1)
plt.title(f"Noisy Image 1 (t={t[0].item()})")
plt.imshow(((x_noisy[0] + 1) / 2).clamp(0, 1).permute(1, 2, 0), cmap='gray', vmin=0, vmax=1)
plt.axis('off')
plt.subplot(1, 2, 2)
plt.title(f"Noisy Image 2 (t={t[1].item()})")
plt.imshow(((x_noisy[1] + 1) / 2).clamp(0, 1).permute(1, 2, 0), cmap='gray', vmin=0, vmax=1)
plt.axis('off')
plt.tight_layout()
plt.show()

## 4. Noise Progression

Let's visualize how a single image gets progressively noisier at different timesteps.

In [29]:
# Show noise progression for a single image at different timesteps
single_img = x_normalized[0:1]  # Shape: (1, 1, 28, 28)

# Select timesteps to visualize
timesteps_to_show = [0, T//4, T//2, 3*T//4, T-1]

plt.figure(figsize=(15, 3))
for i, t_val in enumerate(timesteps_to_show):
    t_tensor = torch.tensor([t_val])
    noisy_img, _ = forward_add_noise(single_img, t_tensor)
    
    plt.subplot(1, len(timesteps_to_show), i + 1)
    plt.title(f"t = {t_val}")
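    # Clamp to [0, 1] and fix vmin/vmax so brightness is comparable across panels
    plt.imshow(((noisy_img[0] + 1) / 2).clamp(0, 1).permute(1, 2, 0), cmap='gray', vmin=0, vmax=1)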
    plt.axis('off')

plt.suptitle("Forward Diffusion: Adding Noise Over Time", fontsize=14)
plt.tight_layout()
plt.show()
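
The progression panels can also be read quantitatively, using the two coefficients of the forward formula at the same timesteps: $\sqrt{\bar{\alpha}_t}$ weights the signal and $\sqrt{1 - \bar{\alpha}_t}$ weights the noise. A sketch, assuming $T = 1000$ and the linear schedule above:

```python
import math

# Assumption: T = 1000 with the notebook's linear beta schedule
T = 1000
beta_start, beta_end = 0.0001, 0.02
betas = [beta_start + i * (beta_end - beta_start) / (T - 1) for i in range(T)]

# alpha_bar at every 0-based index
alphas_cumprod = []
running = 1.0
for b in betas:
    running *= 1.0 - b
    alphas_cumprod.append(running)

timesteps_to_show = [0, T // 4, T // 2, 3 * T // 4, T - 1]
coefs = []
for t in timesteps_to_show:
    signal = math.sqrt(alphas_cumprod[t])      # weight on x_0
    noise = math.sqrt(1.0 - alphas_cumprod[t])  # weight on epsilon
    coefs.append((signal, noise))
    print(f"t = {t:4d}: signal coef = {signal:.4f}, noise coef = {noise:.4f}")
```

By $t = 3T/4$ the signal coefficient is already well below 0.1. This is why the later panels look like pure noise.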