## Section 1: Setup and Imports

Let's start by importing the libraries we need.

In [None]:
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import FancyBboxPatch
import math

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"üñ•Ô∏è Using device: {device}")

# Set random seed for reproducibility
torch.manual_seed(42)
np.random.seed(42)

# Visualization settings
plt.style.use('default')
%matplotlib inline

## Section 2: Understanding Gaussian Noise

Diffusion models are built on **Gaussian (normal) distributions**. Let's understand why.

### Why Gaussian?

1. **Central Limit Theorem**: Sum of many small independent effects → approximately Gaussian
2. **Closed-form operations**: Sum of independent Gaussians is Gaussian
3. **Maximum entropy**: For fixed mean/variance, Gaussian is most "random"
4. **Mathematical convenience**: Many operations have closed-form solutions

### The Gaussian Distribution

A Gaussian with mean $\mu$ and variance $\sigma^2$:

$$p(x) = \frac{1}{\sqrt{2\pi\sigma^2}} \exp\left(-\frac{(x-\mu)^2}{2\sigma^2}\right)$$

In diffusion models, we always use **standard Gaussian noise**:
- Mean: $\mu = 0$
- Variance: $\sigma^2 = 1$

$$\epsilon \sim \mathcal{N}(0, I)$$

In [None]:
# Three views of standard Gaussian noise: analytic density, empirical
# histogram, and a 2D noise field rendered as an image.
fig, axes = plt.subplots(1, 3, figsize=(15, 4))
pdf_ax, hist_ax, noise_ax = axes

# Analytic N(0, 1) density on [-4, 4]; reused as the reference curve on the
# histogram panel.
x = np.linspace(-4, 4, 1000)
gaussian = np.exp(-x**2 / 2) / np.sqrt(2 * np.pi)

# Panel 1: the density itself, with the area under the curve shaded.
pdf_ax.plot(x, gaussian, 'b-', linewidth=2)
pdf_ax.fill_between(x, gaussian, alpha=0.3)
pdf_ax.set(
    xlabel='x',
    ylabel='p(x)',
    title='Standard Gaussian Distribution\n$\\mathcal{N}(0, 1)$',
)
pdf_ax.grid(True, alpha=0.3)

# Panel 2: 10,000 draws from torch.randn, histogrammed against the true PDF.
samples = torch.randn(10000).numpy()
hist_ax.hist(samples, bins=50, density=True, alpha=0.7, color='green')
hist_ax.plot(x, gaussian, 'r-', linewidth=2, label='True PDF')
hist_ax.set(
    xlabel='Sample value',
    ylabel='Density',
    title=f'10,000 Samples\nMean: {samples.mean():.3f}, Std: {samples.std():.3f}',
)
hist_ax.legend()
hist_ax.grid(True, alpha=0.3)

# Panel 3: a 64x64 grid of independent draws; colour range pinned to [-3, 3].
noise_2d = torch.randn(64, 64).numpy()
im = noise_ax.imshow(noise_2d, cmap='RdBu', vmin=-3, vmax=3)
noise_ax.set_title('2D Gaussian Noise\n(What random images look like)')
noise_ax.axis('off')
fig.colorbar(im, ax=noise_ax, fraction=0.046)

fig.tight_layout()
plt.show()

for _line in (
    "üí° Key property: Gaussian noise has zero mean and unit variance.",
    "   This makes it the perfect 'baseline randomness' for diffusion.",
):
    print(_line)

## Section 3: The Noise Schedule

The **noise schedule** defines how quickly we add noise to data over $T$ timesteps.

### Key Variables

| Symbol | Name | Formula | Meaning |
|--------|------|---------|--------|
| $\beta_t$ | Beta | Schedule parameter | Noise variance at step $t$ |
| $\alpha_t$ | Alpha | $1 - \beta_t$ | Signal retention at step $t$ |
| $\bar{\alpha}_t$ | Alpha-bar | $\prod_{s=1}^{t} \alpha_s$ | **Cumulative** signal remaining |

### Why $\bar{\alpha}_t$ is Most Important

- $\bar{\alpha}_t \approx 1.0$ → Almost all signal preserved (nearly clean)
- $\bar{\alpha}_t \approx 0.5$ → Half signal, half noise
- $\bar{\alpha}_t \approx 0.0$ → Almost pure noise

### Common Schedules

1. **Linear** (original DDPM): $\beta_t$ increases linearly
2. **Cosine** (Improved DDPM): Smoother decay of $\bar{\alpha}_t$

In [None]:
def linear_beta_schedule(timesteps, beta_start=1e-4, beta_end=0.02):
    """
    Build the linear variance schedule of Ho et al. (2020), the DDPM paper.

    Args:
        timesteps: Number of diffusion steps (length of the returned tensor).
        beta_start: Value of beta at the first step.
        beta_end: Value of beta at the last step.

    Returns:
        1D tensor of `timesteps` evenly spaced values running from
        beta_start to beta_end (both endpoints included).
    """
    betas = torch.linspace(start=beta_start, end=beta_end, steps=timesteps)
    return betas


def cosine_beta_schedule(timesteps, s=0.008):
    """
    Build the cosine variance schedule of Nichol & Dhariwal (2021).

    The cumulative signal level alpha-bar is defined directly by a squared
    cosine curve, and the per-step betas are recovered from the ratio of
    consecutive alpha-bar values.

    Args:
        timesteps: Number of diffusion steps (length of the returned tensor).
        s: Small offset added inside the cosine so that beta is not too
           small at the start of the schedule.

    Returns:
        1D tensor of `timesteps` betas, clamped to [0.0001, 0.9999].
    """
    # timesteps + 1 grid points (0 .. timesteps) give `timesteps` consecutive ratios.
    grid = torch.linspace(0, timesteps, timesteps + 1)
    angle = ((grid / timesteps) + s) / (1 + s) * math.pi * 0.5
    alpha_bar = torch.cos(angle) ** 2
    # Rescale so the curve starts at exactly 1 (no noise before the first step).
    alpha_bar = alpha_bar / alpha_bar[0]
    # alpha_t = alpha_bar_t / alpha_bar_{t-1}, and beta_t = 1 - alpha_t.
    step_ratio = alpha_bar[1:] / alpha_bar[:-1]
    betas = 1 - step_ratio
    return betas.clamp(0.0001, 0.9999)


# Diffusion horizon: 1000 steps is the standard DDPM setting.
T = 1000

# One beta tensor per schedule family, for side-by-side comparison.
betas_linear, betas_cosine = linear_beta_schedule(T), cosine_beta_schedule(T)

print(f"üìä Created noise schedules with T = {T} timesteps")
# Report the first and last beta of each schedule under its own header.
for _header, _betas in (
    (f"\nLinear schedule:", betas_linear),
    (f"\nCosine schedule:", betas_cosine),
):
    print(_header)
    print(f"   Œ≤ ranges from {_betas[0]:.6f} to {_betas[-1]:.4f}")

In [None]:
# Derived schedule quantities
def compute_schedule_quantities(betas):
    """
    Derive the per-step and cumulative coefficients implied by a beta schedule.

    Args:
        betas: 1D tensor of per-step noise variances beta_t.

    Returns:
        Dict with the following tensors, each the same length as `betas`:
            'betas':                          the input, unchanged
            'alphas':                         1 - beta_t
            'alphas_cumprod':                 running product of the alphas
            'sqrt_alphas_cumprod':            square root of the running product
            'sqrt_one_minus_alphas_cumprod':  square root of (1 - running product)
    """
    keep = 1.0 - betas
    keep_cumulative = keep.cumprod(dim=0)

    derived = dict(betas=betas, alphas=keep, alphas_cumprod=keep_cumulative)
    derived['sqrt_alphas_cumprod'] = keep_cumulative.sqrt()
    derived['sqrt_one_minus_alphas_cumprod'] = (1.0 - keep_cumulative).sqrt()
    return derived

# Derived coefficient tables for both schedule families.
linear_schedule = compute_schedule_quantities(betas_linear)
cosine_schedule = compute_schedule_quantities(betas_cosine)

# Use linear schedule as default (original DDPM)
schedule = linear_schedule

# Report the signal weight at the first, middle and last timestep. The indices
# are derived from the schedule length rather than hard-coded (500 / 999), so
# this cell keeps working when the number of timesteps is changed.
_num_steps = schedule['sqrt_alphas_cumprod'].shape[0]
print("‚úÖ Computed derived quantities:")
for _t in (0, _num_steps // 2, _num_steps - 1):
    print(f"   ‚àö·æ±_{_t} = {schedule['sqrt_alphas_cumprod'][_t]:.4f} (signal weight at t={_t})")

In [None]:
# Four-panel comparison of the linear and cosine schedules.
fig, axes = plt.subplots(2, 2, figsize=(14, 10))
beta_ax, alpha_ax, alpha_bar_ax, coef_ax = axes.flat

# Each schedule keeps the same legend label and colour in every panel.
schedule_styles = (
    ('Linear', 'blue', linear_schedule),
    ('Cosine', 'orange', cosine_schedule),
)

# Panels 1-3 share one layout: plot one schedule quantity for both schedules.
# Each row is (axes, schedule key, y-axis label, title).
comparison_panels = (
    (beta_ax, 'betas', '$\\beta_t$',
     'Beta Schedule: Noise Variance per Step'),
    (alpha_ax, 'alphas', '$\\alpha_t = 1 - \\beta_t$',
     'Alpha Schedule: Signal Retention per Step'),
    (alpha_bar_ax, 'alphas_cumprod', '$\\bar{\\alpha}_t$',
     '‚≠ê Cumulative Alpha: Total Signal Remaining\n(Most important quantity!)'),
)
for panel, quantity, y_label, panel_title in comparison_panels:
    for sched_name, sched_colour, sched in schedule_styles:
        panel.plot(sched[quantity].numpy(), label=sched_name, linewidth=2, color=sched_colour)
    if panel is alpha_bar_ax:
        # Reference line marking the point where half of the signal variance remains.
        panel.axhline(y=0.5, color='gray', linestyle='--', alpha=0.5, label='50% signal')
    panel.set_xlabel('Timestep $t$')
    panel.set_ylabel(y_label)
    panel.set_title(panel_title)
    panel.legend()
    panel.grid(True, alpha=0.3)

# Panel 4: the two coefficients of the closed-form forward equation (linear schedule only).
t_vals = np.arange(T)
signal_coef = linear_schedule['sqrt_alphas_cumprod'].numpy()
noise_coef = linear_schedule['sqrt_one_minus_alphas_cumprod'].numpy()
coefficient_curves = (
    (signal_coef, '$\\sqrt{\\bar{\\alpha}_t}$ (signal)', 'green'),
    (noise_coef, '$\\sqrt{1-\\bar{\\alpha}_t}$ (noise)', 'red'),
)
for curve, curve_label, curve_colour in coefficient_curves:
    coef_ax.plot(t_vals, curve, label=curve_label, linewidth=2, color=curve_colour)
for curve, curve_label, curve_colour in coefficient_curves:
    coef_ax.fill_between(t_vals, curve, alpha=0.2, color=curve_colour)
coef_ax.set_xlabel('Timestep $t$')
coef_ax.set_ylabel('Coefficient')
coef_ax.set_title('Signal vs Noise Coefficients\n$x_t = \\sqrt{\\bar{\\alpha}_t} x_0 + \\sqrt{1-\\bar{\\alpha}_t} \\epsilon$')
coef_ax.legend()
coef_ax.grid(True, alpha=0.3)

fig.tight_layout()
plt.show()

for _line in (
    "\nüîç Key Observations:",
    "   1. Linear schedule: Signal drops quickly at the beginning",
    "   2. Cosine schedule: Smoother decay throughout",
    "   3. Both reach near-zero signal by t=1000",
):
    print(_line)

## Section 4: The Forward Process Equation

### The Key Insight 💡

We don't need to add noise step-by-step! We can jump **directly** to any timestep.

### Mathematical Derivation

**Step-by-step process** (how you might think it works):
$$x_1 = \sqrt{\alpha_1} x_0 + \sqrt{\beta_1} \epsilon_1$$
$$x_2 = \sqrt{\alpha_2} x_1 + \sqrt{\beta_2} \epsilon_2$$
$$\vdots$$
$$x_t = \sqrt{\alpha_t} x_{t-1} + \sqrt{\beta_t} \epsilon_t$$

**Closed-form** (the magic!):
$$\boxed{x_t = \sqrt{\bar{\alpha}_t} \cdot x_0 + \sqrt{1 - \bar{\alpha}_t} \cdot \epsilon}$$

where $\epsilon \sim \mathcal{N}(0, I)$ is a single noise sample.

### Why This Works (Proof Sketch)

Sum of independent Gaussians is Gaussian:
- If $X \sim \mathcal{N}(0, \sigma_1^2)$ and $Y \sim \mathcal{N}(0, \sigma_2^2)$
- Then $X + Y \sim \mathcal{N}(0, \sigma_1^2 + \sigma_2^2)$

The sequential noise terms combine into a single equivalent noise term!

### Interpretation

$x_t$ is a **weighted sum** of signal and noise:

| Component | Weight | At $t=0$ | At $t=T$ |
|-----------|--------|----------|----------|
| Signal ($x_0$) | $\sqrt{\bar{\alpha}_t}$ | ≈ 1.0 | ≈ 0.0 |
| Noise ($\epsilon$) | $\sqrt{1-\bar{\alpha}_t}$ | ≈ 0.0 | ≈ 1.0 |

In [None]:
def extract(a, t, x_shape):
    """
    Look up per-sample schedule values and reshape them for broadcasting.

    For every timestep index in `t`, picks the matching entry of the 1D
    schedule tensor `a`, then reshapes the result so it broadcasts against a
    data tensor of shape `x_shape`.

    Args:
        a: 1D tensor of shape (T,) - e.g., sqrt_alphas_cumprod
        t: Batch of integer timestep indices of shape (B,). May be on a
           different device than `a` and may use any integer dtype.
        x_shape: Shape of the data tensor, e.g. (B, C, H, W)

    Returns:
        Tensor of shape (B, 1, ..., 1) with len(x_shape) dimensions, on the
        same device as `t`. For 4D image batches this is (B, 1, 1, 1).

    Example:
        >>> a = torch.linspace(1, 0, 1000)  # Shape: (1000,)
        >>> t = torch.tensor([0, 500, 999]) # Shape: (3,)
        >>> x_shape = (3, 1, 32, 32)        # Batch of 3 images
        >>> result = extract(a, t, x_shape) # Shape: (3, 1, 1, 1)
    """
    batch_size = t.shape[0]
    # gather() needs int64 indices on the same device as the source tensor.
    # Normalising here lets the schedule stay on the CPU while the timesteps
    # live on the GPU with the data batch.
    indices = t.to(device=a.device, dtype=torch.long)
    out = a.gather(-1, indices)
    # Hand the coefficients back on the caller's device so they can be
    # multiplied with the data batch directly.
    out = out.to(t.device)
    # One leading batch axis followed by singleton axes for every remaining
    # data dimension.
    return out.reshape(batch_size, *((1,) * (len(x_shape) - 1)))


# Worked example of extract(): look up the signal weight at four timesteps.
print("üìù Demonstrating extract() function:")
print()

sqrt_alpha_bar = schedule['sqrt_alphas_cumprod']
t_example = torch.tensor([0, 100, 500, 999])
x_shape_example = (4, 1, 32, 32)  # shape of a batch of 4 single-channel 32x32 images
extracted = extract(sqrt_alpha_bar, t_example, x_shape_example)

print(f"Input shape: sqrt_alphas_cumprod = {sqrt_alpha_bar.shape}")
print(f"Timesteps: t = {t_example.tolist()}")
print(f"Output shape: {extracted.shape}")
print(f"\nExtracted values (‚àö·æ±_t):")
# Drop the three singleton axes so each timestep pairs with one scalar.
for step, coefficient in zip(t_example.tolist(), extracted[:, 0, 0, 0]):
    print(f"   t={step:4d}: ‚àö·æ±_t = {coefficient:.4f}")

In [None]:
def q_sample(x_0, t, noise=None):
    """
    Sample from the forward diffusion distribution q(x_t | x_0) in one step.

    Implements the closed-form forward equation
        x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps
    using the coefficient tables of the module-level `schedule`.

    Args:
        x_0: Clean images, shape (B, C, H, W), values in [-1, 1]
        t: Timesteps, shape (B,), values in [0, T)
        noise: Optional pre-sampled noise eps ~ N(0, I); drawn with
               torch.randn_like(x_0) when omitted.

    Returns:
        x_t: Noisy images at timestep t
        noise: The noise that was added (for training)
    """
    # Only draw fresh noise when the caller did not provide any.
    eps = torch.randn_like(x_0) if noise is None else noise

    # Per-sample weights, shaped to broadcast over the non-batch dimensions.
    signal_weight = extract(schedule['sqrt_alphas_cumprod'], t, x_0.shape)
    noise_weight = extract(schedule['sqrt_one_minus_alphas_cumprod'], t, x_0.shape)

    noisy = signal_weight * x_0 + noise_weight * eps
    return noisy, eps


# Confirmation banner plus a reminder of the equation q_sample implements.
for _line in (
    "‚úÖ Forward process function defined!",
    "\nüìê The q_sample function implements:",
    "   x_t = ‚àö·æ±_t ¬∑ x_0 + ‚àö(1 - ·æ±_t) ¬∑ Œµ",
):
    print(_line)

## Section 5: Visualizing the Forward Process

Let's see the forward process in action on real images!

In [None]:
# Preprocessing applied to every MNIST image, in order:
# resize to 32x32, convert to a [0, 1] tensor, then map linearly to [-1, 1].
preprocessing_steps = [
    transforms.Resize(32),
    transforms.ToTensor(),
    transforms.Lambda(lambda x: x * 2 - 1),
]
transform = transforms.Compose(preprocessing_steps)

# Training split, downloaded into ./data on first use.
dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
dataloader = DataLoader(dataset, batch_size=16, shuffle=True)

# Pull one shuffled mini-batch to serve as the running example below.
batch_iterator = iter(dataloader)
sample_batch, labels = next(batch_iterator)
batch_min, batch_max = sample_batch.min(), sample_batch.max()

print(f"üì∑ Loaded MNIST batch:")
print(f"   Shape: {sample_batch.shape}")
print(f"   Value range: [{batch_min:.2f}, {batch_max:.2f}]")
print(f"   Labels: {labels[:8].tolist()}")

In [None]:
def show_forward_diffusion(image, timesteps_to_show):
    """
    Visualize how a single image gets progressively noisier.

    Draws a two-row figure: the top row shows the noised image at each
    requested timestep, the bottom row lists the cumulative alpha and the
    signal / noise coefficients for that timestep. The same noise sample is
    reused at every timestep so the columns show one consistent trajectory.

    Args:
        image: Single image tensor of shape (C, H, W); only channel 0 is displayed
        timesteps_to_show: List of timesteps to visualize (one column each)

    Raises:
        ValueError: if `timesteps_to_show` is empty.
    """
    n_steps = len(timesteps_to_show)
    if n_steps == 0:
        raise ValueError("timesteps_to_show must contain at least one timestep")

    # squeeze=False keeps `axes` two-dimensional even when a single timestep
    # is requested, so the axes[row, col] indexing below is always valid.
    fig, axes = plt.subplots(2, n_steps, figsize=(2.5 * n_steps, 5), squeeze=False)

    # q_sample expects a batch dimension, so wrap the image in a batch of one.
    image_batch = image.unsqueeze(0)

    # Use the same noise for all timesteps (to show the same "trajectory")
    noise = torch.randn_like(image_batch)

    for idx, t in enumerate(timesteps_to_show):
        t_tensor = torch.tensor([t])
        x_t, _ = q_sample(image_batch, t_tensor, noise=noise)

        # Coefficients of the forward equation at this timestep
        signal_coef = schedule['sqrt_alphas_cumprod'][t].item()
        noise_coef = schedule['sqrt_one_minus_alphas_cumprod'][t].item()

        # Map [-1, 1] -> [0, 1] for display; clip because added noise can
        # push values outside the original range.
        img = x_t[0, 0].numpy()
        img_display = np.clip((img + 1) / 2, 0, 1)

        # Top row: noisy image
        axes[0, idx].imshow(img_display, cmap='gray', vmin=0, vmax=1)
        axes[0, idx].set_title(f't = {t}', fontsize=11, fontweight='bold')
        axes[0, idx].axis('off')

        # Bottom row: text summary of the schedule at this timestep
        axes[1, idx].text(0.5, 0.8, f'$\\bar{{\\alpha}}_t$ = {schedule["alphas_cumprod"][t]:.4f}',
                         ha='center', fontsize=10, transform=axes[1, idx].transAxes)
        axes[1, idx].text(0.5, 0.5, f'Signal: {signal_coef:.3f}',
                         ha='center', fontsize=10, color='green', transform=axes[1, idx].transAxes)
        axes[1, idx].text(0.5, 0.2, f'Noise: {noise_coef:.3f}',
                         ha='center', fontsize=10, color='red', transform=axes[1, idx].transAxes)
        axes[1, idx].axis('off')

    plt.suptitle('Forward Diffusion: Clean Image ‚Üí Pure Noise', fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.show()


# Run the single-image visualisation on the first digit of the sample batch.
single_image = sample_batch[0]
timesteps_to_show = [0, 100, 250, 500, 750, 900, 999]
show_forward_diffusion(single_image, timesteps_to_show)

for _line in (
    "\nüîç Observations:",
    "   ‚Ä¢ t=0: Original clean image",
    "   ‚Ä¢ t=250: Structure visible but noisy",
    "   ‚Ä¢ t=500: Hard to see structure",
    "   ‚Ä¢ t=999: Nearly pure noise",
):
    print(_line)

In [None]:
# Show multiple images at the same timestep
def show_batch_at_timestep(images, t):
    """
    Show a batch of images next to their noised versions at timestep `t`.

    The top row shows the original images and the bottom row the same images
    after q_sample at timestep `t`. At most the first 8 images are displayed;
    smaller batches are shown in full.

    Args:
        images: Batch of images of shape (B, C, H, W); only channel 0 is displayed
        t: Integer timestep applied to every image in the batch

    Raises:
        ValueError: if the batch is empty.
    """
    batch_size = images.shape[0]
    if batch_size == 0:
        raise ValueError("images must contain at least one image")

    # Never index past the end of a batch that has fewer than 8 images.
    n_show = min(8, batch_size)
    t_tensor = torch.full((batch_size,), t, dtype=torch.long)

    # Add noise (the returned noise sample itself is not needed here)
    noisy_images, _ = q_sample(images, t_tensor)

    # squeeze=False keeps `axes` two-dimensional even for a single column.
    # Width scales with the column count (16 inches for the full 8 columns).
    fig, axes = plt.subplots(2, n_show, figsize=(2 * n_show, 4), squeeze=False)

    for i in range(n_show):
        # Original, mapped from [-1, 1] to [0, 1] for display
        orig = (images[i, 0].numpy() + 1) / 2
        axes[0, i].imshow(np.clip(orig, 0, 1), cmap='gray')
        axes[0, i].axis('off')
        if i == 0:
            axes[0, i].set_title('Original', fontsize=10)

        # Noisy, same mapping; clipping hides values pushed out of range by noise
        noisy = (noisy_images[i, 0].numpy() + 1) / 2
        axes[1, i].imshow(np.clip(noisy, 0, 1), cmap='gray')
        axes[1, i].axis('off')
        if i == 0:
            axes[1, i].set_title(f'Noisy (t={t})', fontsize=10)

    plt.suptitle(f'Batch at timestep t={t} (·æ±_t = {schedule["alphas_cumprod"][t]:.4f})', 
                 fontsize=12, fontweight='bold')
    plt.tight_layout()
    plt.show()


# Render the same sample batch at four timesteps, starting from t=0.
noise_levels = (0, 200, 500, 800)
for noise_level in noise_levels:
    show_batch_at_timestep(sample_batch, noise_level)

## Section 6: Verifying the Forward Process

Let's verify that our implementation is correct by checking key properties.

In [None]:
def verify_forward_process():
    """
    Verify that the forward process has the expected properties.

    A constant all-ones batch is noised at three timesteps (first, last and
    middle). Because x_0 is constant, the empirical mean of x_t estimates the
    signal coefficient sqrt(alpha_bar_t) and the empirical std estimates the
    noise coefficient sqrt(1 - alpha_bar_t). Both are printed and then
    asserted against the values stored in the module-level `schedule`.

    Key checks:
    1. At t=0, x_t is close to x_0 (almost no noise)
    2. At t=T-1, x_t is close to N(0, I) (almost pure noise)
    3. Mean and variance follow the schedule

    Raises:
        AssertionError: if an empirical mean or std differs from the
            schedule's value by more than the tolerance.
    """
    print("üî¨ Verifying Forward Process Properties")
    print("=" * 50)

    n_samples = 1000   # large batch so the statistics are tight
    tolerance = 0.01   # absolute tolerance; about 1M values per estimate keeps sampling error well below this

    # Create a test image (all ones, normalized to [-1, 1] means 1.0)
    x_0 = torch.ones(n_samples, 1, 32, 32)

    def noised_stats(t_index):
        """Noise the whole batch at timestep t_index and return (mean, std) of the result."""
        t_batch = torch.full((n_samples,), t_index, dtype=torch.long)
        x_t, _ = q_sample(x_0, t_batch)
        return x_t.mean().item(), x_t.std().item()

    def check_against_schedule(t_index, mean, std):
        """Assert that the measured mean/std match the schedule coefficients at t_index."""
        expected_mean = schedule['sqrt_alphas_cumprod'][t_index].item()  # x_0 = 1 everywhere
        expected_std = schedule['sqrt_one_minus_alphas_cumprod'][t_index].item()
        assert abs(mean - expected_mean) < tolerance, (
            f"t={t_index}: mean {mean:.4f} differs from expected {expected_mean:.4f}")
        assert abs(std - expected_std) < tolerance, (
            f"t={t_index}: std {std:.4f} differs from expected {expected_std:.4f}")
        return expected_mean, expected_std

    # Test 1: At t=0
    mean_first, std_first = noised_stats(0)
    print(f"\nüìä Test 1: At t=0")
    print(f"   Expected: x_t ‚âà x_0 (·æ±_0 = {schedule['alphas_cumprod'][0]:.6f})")
    print(f"   Mean of x_t: {mean_first:.4f} (expected: 1.0)")
    print(f"   Std of x_t: {std_first:.4f} (expected: ~0.01)")
    check_against_schedule(0, mean_first, std_first)

    # Test 2: At t=T-1
    mean_last, std_last = noised_stats(T-1)
    print(f"\nüìä Test 2: At t={T-1}")
    print(f"   Expected: x_t ‚âà N(0, I) (·æ±_{T-1} = {schedule['alphas_cumprod'][T-1]:.6f})")
    print(f"   Mean of x_t: {mean_last:.4f} (expected: ~0.0)")
    print(f"   Std of x_t: {std_last:.4f} (expected: ~1.0)")
    check_against_schedule(T-1, mean_last, std_last)

    # Test 3: Intermediate timestep (derived from T so shorter schedules still work)
    t_middle = T // 2
    mean_mid, std_mid = noised_stats(t_middle)
    expected_mean = schedule['sqrt_alphas_cumprod'][t_middle].item() * 1.0  # x_0 = 1
    expected_std = schedule['sqrt_one_minus_alphas_cumprod'][t_middle].item()
    print(f"\nüìä Test 3: At t={t_middle}")
    print(f"   Mean of x_t: {mean_mid:.4f} (expected: {expected_mean:.4f})")
    print(f"   Std of x_t: {std_mid:.4f} (expected: {expected_std:.4f})")
    check_against_schedule(t_middle, mean_mid, std_mid)

    # Only reached when every assertion above held.
    print("\n‚úÖ All tests passed! Forward process is working correctly.")


verify_forward_process()

In [None]:
# Boundary and sanity checks on the active schedule, each reported as a
# pass/fail glyph.
alpha_bar = schedule['alphas_cumprod']
glyph = {True: '‚úÖ', False: '‚ùå'}

starts_clean = bool(alpha_bar[0] > 0.99)        # first step keeps almost all signal
ends_noisy = bool(alpha_bar[T-1] < 0.01)         # last step keeps almost none
alphas_positive = bool((schedule['alphas'] > 0).all())
never_increases = bool((alpha_bar[1:] <= alpha_bar[:-1]).all())

print("üîç Checking schedule boundary conditions:")
print(f"\n   ·æ±_0 = {alpha_bar[0]:.6f}")
print(f"   Should be close to 1.0: {glyph[starts_clean]}")

print(f"\n   ·æ±_{T-1} = {alpha_bar[T-1]:.6f}")
print(f"   Should be close to 0.0: {glyph[ends_noisy]}")

print(f"\n   All alphas positive: {glyph[alphas_positive]}")
print(f"   alphas_cumprod monotonically decreasing: {glyph[never_increases]}")

## Section 7: Signal-to-Noise Ratio (SNR)

The **Signal-to-Noise Ratio** is another way to understand the forward process.

### Definition

$$\text{SNR}(t) = \frac{\bar{\alpha}_t}{1 - \bar{\alpha}_t}$$

This tells us the ratio of signal power to noise power at timestep $t$.

### In log scale (often used in papers)

$$\log \text{SNR}(t) = \log \bar{\alpha}_t - \log(1 - \bar{\alpha}_t)$$

In [None]:
# Compute and visualize SNR
# SNR(t) = alpha_bar_t / (1 - alpha_bar_t): ratio of signal variance to noise variance.
alphas_cumprod = schedule['alphas_cumprod']
snr = alphas_cumprod / (1 - alphas_cumprod + 1e-8)  # Add epsilon for numerical stability
log_snr = torch.log(snr + 1e-8)                     # epsilon guards against log(0)

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Left panel: raw SNR values drawn on a logarithmic y-axis. The title says so
# explicitly; it previously claimed "Linear Scale" although set_yscale('log')
# is applied to this axis.
axes[0].plot(snr.numpy(), linewidth=2, color='purple')
axes[0].set_xlabel('Timestep $t$')
axes[0].set_ylabel('SNR')
axes[0].set_title('Signal-to-Noise Ratio (Log-Scaled Axis)')
axes[0].set_yscale('log')
axes[0].grid(True, alpha=0.3)
axes[0].axhline(y=1, color='gray', linestyle='--', alpha=0.5, label='SNR = 1')
axes[0].legend()

# Right panel: log SNR values on a linear y-axis
axes[1].plot(log_snr.numpy(), linewidth=2, color='teal')
axes[1].set_xlabel('Timestep $t$')
axes[1].set_ylabel('log SNR')
axes[1].set_title('Log Signal-to-Noise Ratio')
axes[1].grid(True, alpha=0.3)
axes[1].axhline(y=0, color='gray', linestyle='--', alpha=0.5, label='log SNR = 0')
axes[1].legend()

plt.tight_layout()
plt.show()

# Find the timestep whose SNR is closest to 1 (signal and noise variance equal)
snr_equals_1_idx = (snr - 1).abs().argmin().item()
print(f"\nüìä SNR Analysis:")
print(f"   SNR = 1 (equal signal and noise) at t ‚âà {snr_equals_1_idx}")
print(f"   SNR at t=0: {snr[0]:.2f} (mostly signal)")
print(f"   SNR at t={T-1}: {snr[-1]:.6f} (mostly noise)")

## Section 8: Key Takeaways

### What We Learned

1. **Gaussian Noise** is the foundation - zero mean, unit variance, maximum entropy

2. **The Noise Schedule** defines how we corrupt data:
   - $\beta_t$: Noise variance at step $t$
   - $\alpha_t = 1 - \beta_t$: Signal retention
   - $\bar{\alpha}_t = \prod \alpha_s$: Cumulative signal (most important!)

3. **The Forward Process** has a beautiful closed form:
   $$x_t = \sqrt{\bar{\alpha}_t} \cdot x_0 + \sqrt{1 - \bar{\alpha}_t} \cdot \epsilon$$

4. **The `extract()` function** is crucial - it handles batched indexing into schedules

5. **Boundary Conditions**:
   - $t=0$: $\bar{\alpha}_0 \approx 1$ → almost clean
   - $t=T-1$: $\bar{\alpha}_{T-1} \approx 0$ → almost pure noise

### What's Next?

In **Module 2: DDPM**, we'll learn:
- The reverse process (denoising)
- Why we predict noise instead of the image
- The training objective (surprisingly simple!)
- How to sample from the model

In [None]:
# Summary diagram
def create_summary_diagram():
    """
    Draw a one-figure recap of the forward process.

    The figure is laid out on a 14 x 6 canvas with hidden axes: a title, a
    boxed copy of the closed-form forward equation, a row of three labelled
    stages (x_0, x_t, x_T) joined by arrows, and a boxed key insight.
    """
    fig, ax = plt.subplots(figsize=(14, 6))
    ax.set(xlim=(0, 14), ylim=(0, 6))
    ax.axis('off')

    def add_rounded_box(bottom, face, edge):
        """Add a 10 x 1.2 rounded rectangle whose lower-left corner is (2, bottom)."""
        ax.add_patch(FancyBboxPatch((2, bottom), 10, 1.2,
                                    boxstyle="round,pad=0.05,rounding_size=0.2",
                                    facecolor=face, edgecolor=edge, linewidth=2))

    # Title
    ax.text(7, 5.5, 'Forward Diffusion Process Summary',
            ha='center', fontsize=16, fontweight='bold')

    # Boxed forward equation
    add_rounded_box(3.8, '#E8F4FD', '#2196F3')
    ax.text(7, 4.4, r'$x_t = \sqrt{\bar{\alpha}_t} \cdot x_0 + \sqrt{1 - \bar{\alpha}_t} \cdot \epsilon$',
            ha='center', va='center', fontsize=14)

    # The three stages, left to right: (x position, symbol, symbol colour, caption)
    stages = (
        (1, '$x_0$', 'green', 'Clean\nImage'),
        (7, '$x_t$', 'purple', 'Noisy\nImage'),
        (13, '$x_T$', 'red', 'Pure\nNoise'),
    )
    for x_pos, symbol, symbol_colour, caption in stages:
        ax.text(x_pos, 2.5, symbol, ha='center', fontsize=14, fontweight='bold', color=symbol_colour)
        ax.text(x_pos, 2.0, caption, ha='center', fontsize=10)

    # Arrows between the stages, each with its label centred above it
    transitions = ((2, 5, 'Add Noise'), (9, 12, 'More Noise'))
    for x_start, x_end, arrow_label in transitions:
        ax.annotate('', xy=(x_end, 2.3), xytext=(x_start, 2.3),
                    arrowprops=dict(arrowstyle='->', color='blue', lw=2))
        ax.text((x_start + x_end) / 2, 2.8, arrow_label, ha='center', fontsize=10, color='blue')

    # Boxed key insight
    add_rounded_box(0.3, '#FFF3E0', '#FF9800')
    ax.text(7, 0.9, 'üí° Key Insight: We can jump directly to any $t$ in O(1) time!',
            ha='center', va='center', fontsize=11)

    plt.tight_layout()
    plt.show()


# Render the recap figure, then print the closing message and the next-module pointer.
create_summary_diagram()

for _line in (
    "\nüéâ Congratulations! You've completed Module 1: Foundations",
    "\nüìö Continue to Module 2 (02_ddpm.ipynb) to learn about:",
    "   ‚Ä¢ The reverse process (denoising)",
    "   ‚Ä¢ Training objective and loss function",
    "   ‚Ä¢ How to generate images from noise",
):
    print(_line)