# Build a Diffusion Model

**Module 6.2, Lesson 5 (Capstone)** | CourseAI

You have spent four lessons building up every piece of the diffusion pipeline: the intuition for why it works, the math of noise destruction, the training objective, and the sampling algorithm. What you have not yet done is assemble those pieces into a working diffusion model in code. This notebook closes that gap.

**What you will do:**
- Implement the noise schedule and compute alpha-bar from scratch
- Implement the forward process as a `q_sample()` function
- Read and understand a minimal U-Net architecture (provided with annotations)
- Fill in the DDPM training loop and train on MNIST
- Implement the sampling algorithm and generate images from pure noise
- Measure the computational cost of 1,000-step pixel-space sampling
- Compare diffusion generation to VAE generation in quality and speed

**For each exercise, PREDICT the output before running the cell.**

Every line of code in this notebook comes from the last four lessons. The forward process formula is from *The Forward Process*. The training algorithm is from *Learning to Denoise*. The sampling loop is from *Sampling and Generation*. No new theory — just translation from math to PyTorch.

**Estimated time:** 60–90 minutes (training takes 20–30 minutes on a Colab GPU).

---

## Setup

Run this cell to install dependencies, import everything, and configure the environment.

**Important:** Set the runtime to GPU before running. In Colab: Runtime → Change runtime type → T4 GPU.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
import time

# Reproducible results
torch.manual_seed(42)

# Nice plots
plt.style.use('dark_background')
plt.rcParams['figure.figsize'] = [10, 4]

# Device setup
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')
if device.type == 'cuda':
    print(f'GPU: {torch.cuda.get_device_name(0)}')
else:
    print('WARNING: No GPU detected. Training will be very slow.')
    print('In Colab: Runtime → Change runtime type → T4 GPU')

print('\nSetup complete.')

### Load MNIST

We normalize images to [-1, 1]. This connects to the **variance-preserving formulation** from *The Forward Process*: pixel values bounded in [-1, 1] have variance of at most 1, so mixing them with unit-variance Gaussian noise keeps the total variance of $x_t$ at or below 1 at every timestep.

In [None]:
# Load MNIST, normalized to [-1, 1]
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))  # maps [0,1] → [-1, 1]
])

dataset = torchvision.datasets.MNIST(
    root='./data', train=True, download=True, transform=transform
)

dataloader = DataLoader(dataset, batch_size=128, shuffle=True, drop_last=True)

print(f'Dataset size: {len(dataset)}')
print(f'Image shape: {dataset[0][0].shape}')
print(f'Pixel range: [{dataset[0][0].min():.1f}, {dataset[0][0].max():.1f}]')
print(f'Batches per epoch: {len(dataloader)}')

# Show a few examples
fig, axes = plt.subplots(1, 8, figsize=(12, 2))
for i in range(8):
    img, label = dataset[i]
    axes[i].imshow(img.squeeze(), cmap='gray', vmin=-1, vmax=1)
    axes[i].set_title(str(label), fontsize=10)
    axes[i].axis('off')
plt.suptitle('MNIST samples (normalized to [-1, 1])', y=1.02)
plt.tight_layout()
plt.show()

---

## Part 1: Noise Schedule (Guided)

The noise schedule defines how quickly noise is added during the forward process. From *The Forward Process*, you know the chain:

$$\beta_t \rightarrow \alpha_t = 1 - \beta_t \rightarrow \bar{\alpha}_t = \prod_{s=1}^{t} \alpha_s$$

$\bar{\alpha}_t$ is the **signal-to-noise dial** — the fraction of original signal remaining at timestep $t$. We use a **linear schedule** with $T = 1000$, $\beta_{\min} = 0.0001$, $\beta_{\max} = 0.02$, exactly as defined in the original DDPM paper.

**Before running, predict:**
- What will `alpha_bars[0]` be? (Hint: the product of one alpha near 1.)
- What will `alpha_bars[999]` be? (Hint: the cumulative product of 1000 values, each slightly less than 1.)
- What shape will the `alpha_bars` curve have?

In [None]:
# Diffusion hyperparameters
T = 1000           # Total timesteps
beta_min = 0.0001  # Starting noise level
beta_max = 0.02    # Ending noise level

# Linear noise schedule: beta increases linearly from beta_min to beta_max
betas = torch.linspace(beta_min, beta_max, T)

# alpha_t = 1 - beta_t (fraction of signal preserved at each step)
alphas = 1.0 - betas

# alpha_bar_t = cumulative product of alphas (total signal remaining at step t)
# This is the key quantity from The Forward Process
alpha_bars = torch.cumprod(alphas, dim=0)

print(f'beta range:       [{betas[0]:.4f}, {betas[-1]:.4f}]')
print(f'alpha range:      [{alphas[0]:.4f}, {alphas[-1]:.4f}]')
print(f'alpha_bar[0]:     {alpha_bars[0]:.4f}  (nearly all signal)')
print(f'alpha_bar[499]:   {alpha_bars[499]:.4f}  (midpoint)')
print(f'alpha_bar[999]:   {alpha_bars[999]:.6f}  (nearly no signal)')

In [None]:
# Plot the alpha_bar curve — the signal-to-noise dial
# Compare this to the widget from The Forward Process
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

# Alpha-bar curve
ax1.plot(alpha_bars.numpy(), color='#c084fc', linewidth=2)
ax1.set_xlabel('Timestep t', fontsize=12)
ax1.set_ylabel('$\\bar{\\alpha}_t$', fontsize=14)
ax1.set_title('Alpha-bar: signal remaining at each timestep', fontsize=12)
ax1.axhline(y=0.5, color='gray', linestyle='--', alpha=0.4, label='50% signal')
ax1.legend(fontsize=10)

# Beta schedule
ax2.plot(betas.numpy(), color='#86efac', linewidth=2)
ax2.set_xlabel('Timestep t', fontsize=12)
ax2.set_ylabel('$\\beta_t$', fontsize=14)
ax2.set_title('Beta schedule (noise added per step)', fontsize=12)

plt.tight_layout()
plt.show()

print('Does your alpha_bar curve look like the one from the widget in The Forward Process?')
print('It should start near 1.0 and drop to near 0.0.')
print('With a linear beta schedule, the curve is not linear: it starts almost flat, falls steeply, then levels off near 0.')

### What Just Happened

You computed the complete noise schedule from scratch:
- **betas**: linearly increasing noise level at each step
- **alphas**: fraction of signal preserved at each step ($1 - \beta_t$)
- **alpha_bars**: cumulative product — total signal remaining at timestep $t$

This is the same chain of definitions from *The Forward Process*. Alpha-bar starts near 1.0 (all signal) and drops to near 0.0 (all noise). The linear schedule is not a linear curve in alpha-bar: because the first betas are tiny, the cumulative product starts out almost flat, falls most steeply around $t \approx 200$, and flattens out again as it approaches 0.

We will use `alpha_bars` everywhere in this notebook: in the forward process, in the training loop, and in the sampling algorithm.
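
As a quick check on the shape of the curve, the sketch below rebuilds the same linear schedule and looks at two things: where $\bar{\alpha}_t$ first falls below 0.5, and where it loses the most signal in a single step. The names `half_signal_t` and `steepest_t` are illustrative and not used elsewhere in the notebook.

```python
import torch

# Rebuild the schedule from Part 1 so this cell is self-contained
T = 1000
betas = torch.linspace(0.0001, 0.02, T)
alpha_bars = torch.cumprod(1.0 - betas, dim=0)

# First timestep at which less than half of the signal remains
half_signal_t = int((alpha_bars < 0.5).nonzero()[0])

# Per-step drop in alpha_bar; the largest drop marks the steepest part of the curve
drops = alpha_bars[:-1] - alpha_bars[1:]
steepest_t = int(drops.argmax())

print(f'alpha_bar falls below 0.5 at t = {half_signal_t}')
print(f'largest single-step drop near t = {steepest_t}')
print(f'drop at t=0: {drops[0]:.6f}, drop at steepest point: {drops[steepest_t]:.6f}')
```

Comparing the two printed drops shows how slowly the schedule destroys signal in the first few steps compared with its steepest stretch.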

---

## Part 2: Forward Process (Guided)

Now implement the closed-form formula as a function. From *The Forward Process*:

$$x_t = \sqrt{\bar{\alpha}_t} \cdot x_0 + \sqrt{1 - \bar{\alpha}_t} \cdot \epsilon$$

This is the formula you derived step by step, verified with the 1D pixel walkthrough, and used in the *Learning to Denoise* notebook. Now it becomes a reusable function.

**Before running, predict:**
- At $t = 0$, what will the image look like? (What are the coefficients?)
- At $t = 999$, what will the image look like? (Recall: $\bar{\alpha}_{999} \approx 0$.)

In [None]:
def q_sample(x_0, t, alpha_bars, noise=None):
    """Forward process: create a noisy image at timestep t.

    This is the closed-form formula from The Forward Process:
    x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * epsilon

    Args:
        x_0: clean images [batch_size, 1, 28, 28]
        t: timesteps [batch_size] (integer indices into alpha_bars)
        alpha_bars: precomputed alpha_bar schedule [T]
        noise: optional pre-sampled noise (if None, sample fresh noise)

    Returns:
        x_t: noisy images [batch_size, 1, 28, 28]
        noise: the noise that was added (needed for training loss)
    """
    if noise is None:
        noise = torch.randn_like(x_0)

    # Gather alpha_bar for each image's timestep
    alpha_bar_t = alpha_bars[t]  # [batch_size]

    # Reshape for broadcasting: [batch_size] → [batch_size, 1, 1, 1]
    alpha_bar_t = alpha_bar_t.view(-1, 1, 1, 1)

    # The closed-form formula
    signal_coeff = torch.sqrt(alpha_bar_t)
    noise_coeff = torch.sqrt(1.0 - alpha_bar_t)
    x_t = signal_coeff * x_0 + noise_coeff * noise

    return x_t, noise

print('q_sample() defined. This is the formula from The Forward Process.')

In [None]:
# Visualize the noise progression on a real MNIST digit
# Compare this to the DiffusionNoiseWidget from The Diffusion Idea

x_0, label = dataset[3]  # grab a digit
x_0_batch = x_0.unsqueeze(0)  # [1, 1, 28, 28]

# Use the same noise across all timesteps so we can see the progression clearly
torch.manual_seed(42)
fixed_noise = torch.randn_like(x_0_batch)

timesteps_to_show = [0, 100, 250, 500, 750, 999]

fig, axes = plt.subplots(1, len(timesteps_to_show), figsize=(15, 3))

for i, t_val in enumerate(timesteps_to_show):
    t_tensor = torch.tensor([t_val])
    x_t, _ = q_sample(x_0_batch, t_tensor, alpha_bars, noise=fixed_noise)

    axes[i].imshow(x_t.squeeze(), cmap='gray', vmin=-2, vmax=2)
    axes[i].set_title(f't={t_val}\n$\\bar{{\\alpha}}$={alpha_bars[t_val]:.3f}', fontsize=10)
    axes[i].axis('off')

plt.suptitle(f'Forward process: digit {label} at different noise levels', y=1.05, fontsize=13)
plt.tight_layout()
plt.show()

print('At t=0:   nearly all signal; the digit is clear.')
print('At t=250: about half the signal remains; the digit is noisy but still recognizable.')
print('At t=500: mostly noise (alpha_bar is below 0.1); the digit is faint.')
print('At t=750: almost all noise; the digit is barely visible, if at all.')
print('At t=999: essentially pure noise; the digit is gone.')
print()
print('This should match what you saw in the DiffusionNoiseWidget from The Diffusion Idea.')

### What Just Happened

You implemented `q_sample()` — the closed-form formula from *The Forward Process* as a reusable Python function. The formula lets you jump directly to any noise level without iterating through intermediate steps.

The noise progression should match what you saw in the interactive widgets:
- At $t=250$, the digit is still recognizable ($\bar{\alpha} \approx 0.5$, signal and noise in roughly equal parts)
- At $t=750$, it is almost entirely noise ($\bar{\alpha}$ below 0.01)
- At $t=999$, it is indistinguishable from pure Gaussian noise

This function will be used in the training loop (Part 4) to create noisy training examples on the fly.
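
The claim that the closed form skips the intermediate steps can be checked without any images. A minimal sketch, assuming the single-step rule $x_t = \sqrt{\alpha_t}\,x_{t-1} + \sqrt{\beta_t}\,\epsilon_t$ matches the one in *The Forward Process*: track the multiplier on $x_0$ and the accumulated noise variance one step at a time, then compare them with $\sqrt{\bar{\alpha}_t}$ and $1 - \bar{\alpha}_t$. The variable names are illustrative.

```python
import torch

T = 1000
betas = torch.linspace(0.0001, 0.02, T, dtype=torch.float64)
alphas = 1.0 - betas
alpha_bars = torch.cumprod(alphas, dim=0)

signal_coeff = 1.0   # multiplier on x_0 after t steps
noise_var = 0.0      # variance of the accumulated Gaussian noise
max_gap = 0.0        # largest mismatch against the closed form

for t in range(T):
    # One forward step: scale everything by sqrt(alpha_t), add beta_t of fresh variance
    signal_coeff = signal_coeff * alphas[t].sqrt().item()
    noise_var = alphas[t].item() * noise_var + betas[t].item()

    gap_signal = abs(signal_coeff - alpha_bars[t].sqrt().item())
    gap_noise = abs(noise_var - (1.0 - alpha_bars[t].item()))
    max_gap = max(max_gap, gap_signal, gap_noise)

print(f'largest mismatch over {T} steps: {max_gap:.2e}')
```

The mismatch should be at the level of floating-point rounding, which is what makes sampling a random timestep per image cheap during training.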

---

## Part 3: Simple U-Net (Guided)

Now we need a denoising network — the model that takes a noisy image $x_t$ and a timestep $t$ and predicts the noise $\epsilon$ that was added.

The architecture is a **minimal U-Net**: an encoder-decoder with skip connections and timestep conditioning. Think of it as the autoencoder from *Autoencoders* (Module 6.1) with two additions:

1. **Skip connections** — the autoencoder's bottleneck forces compression, but for denoising we want the decoder to have access to the encoder's high-resolution features. Skip connections pass them through directly, like giving the decoder a cheat sheet.

2. **Timestep embedding** — the network needs to know which noise level it is working at. At $t = 50$, remove a tiny amount of noise. At $t = 950$, hallucinate structure from near-pure static. We embed $t$ into a vector and add it to the features.

**This section is the most scaffolded.** The architecture is provided and annotated — your job is to read and understand it, not write it from scratch. Architecture sophistication is not the focus of this lesson. The full U-Net with attention and sinusoidal embeddings comes in Module 6.3.

**Before running, predict:**
- How many parameters will this network have? (Hint: think small. This is MNIST at 28x28.)
- What will the output shape be? (Hint: same as the input — the model predicts noise at every pixel.)

In [None]:
class ConvBlock(nn.Module):
    """Two convolutions with BatchNorm and ReLU.
    The same Conv-BN-ReLU pattern from your CNNs in Series 3."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
        )

    def forward(self, x):
        return self.conv(x)


class SimpleUNet(nn.Module):
    """A minimal U-Net for denoising MNIST.

    Architecture:
    - 3 encoder levels: 1->32->64->128 channels, spatial 28->14->7
    - Bottleneck at 7x7 with 256 channels
    - 2 decoder levels with skip connections: enc1->dec2, enc2->dec3
      (enc3 feeds into the bottleneck without an explicit skip connection)
    - Timestep embedding: learned linear projection added to bottleneck features
    - Output: 1 channel (predicted noise, same shape as input)

    This is NOT the full U-Net from real diffusion systems.
    No attention, no group norm, no sinusoidal embeddings.
    Just enough to prove diffusion works. Module 6.3 builds the real thing.
    """

    def __init__(self):
        super().__init__()

        # ---- Timestep embedding ----
        # The network needs to know which noise level it is working at.
        # Simplest approach: embed the integer timestep into a vector
        # using a learned linear layer, then add it to the features.
        # (The full version uses sinusoidal positional encoding — Module 6.3.)
        self.time_embed = nn.Sequential(
            nn.Linear(1, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
        )

        # ---- Encoder (downsampling path) ----
        # Same structure as the autoencoder encoder from Module 6.1.
        # Conv blocks extract features, MaxPool reduces spatial resolution.
        self.enc1 = ConvBlock(1, 32)        # 28x28 → 28x28, 32 channels
        self.pool1 = nn.MaxPool2d(2)        # 28x28 → 14x14
        self.enc2 = ConvBlock(32, 64)       # 14x14 → 14x14, 64 channels
        self.pool2 = nn.MaxPool2d(2)        # 14x14 → 7x7
        self.enc3 = ConvBlock(64, 128)      # 7x7 → 7x7, 128 channels

        # ---- Bottleneck ----
        self.bottleneck = ConvBlock(128, 256)  # 7x7, 256 channels

        # ---- Decoder (upsampling path) ----
        # ConvTranspose2d upsamples (like the autoencoder decoder).
        # But unlike the autoencoder, the decoder receives SKIP CONNECTIONS
        # from the encoder — features are concatenated along the channel dimension.
        # That is why the input channels include both upsampled and skip features.
        # Note: we have skip connections at 2 of 3 resolution levels.
        # enc3's features flow through the bottleneck rather than being skipped directly.
        self.up3 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)  # 7x7 → 14x14
        self.dec3 = ConvBlock(128 + 64, 128)   # 128 from upsample + 64 from enc2 skip
        self.up2 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)   # 14x14 → 28x28
        self.dec2 = ConvBlock(64 + 32, 64)     # 64 from upsample + 32 from enc1 skip

        # Final conv: map to 1 output channel (predicted noise)
        self.final = nn.Sequential(
            nn.Conv2d(64, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 1, kernel_size=1),  # 1x1 conv to get 1 output channel
        )

    def forward(self, x, t):
        """
        Args:
            x: noisy image [batch_size, 1, 28, 28]
            t: timestep [batch_size] (integer)

        Returns:
            predicted noise [batch_size, 1, 28, 28]
        """
        # Embed the timestep
        # Normalize t to [0, 1] so the network sees a consistent range
        t_normalized = t.float().unsqueeze(1) / T  # [batch_size, 1]
        t_emb = self.time_embed(t_normalized)      # [batch_size, 128]

        # ---- Encoder ----
        e1 = self.enc1(x)               # [B, 32, 28, 28]
        e2 = self.enc2(self.pool1(e1))   # [B, 64, 14, 14]
        e3 = self.enc3(self.pool2(e2))   # [B, 128, 7, 7]

        # ---- Bottleneck ----
        # enc3's features flow into the bottleneck directly (no skip connection here)
        b = self.bottleneck(e3)          # [B, 256, 7, 7]

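        # Add timestep embedding to bottleneck features.
        # Reshape: [B, 128] → [B, 128, 1, 1] and broadcast across spatial dims.
        # We only add to the first 128 channels (matching the embedding size).
        # Build a new tensor with torch.cat instead of assigning into a slice of b:
        # an in-place edit of a ReLU output breaks autograd during loss.backward().
        t_emb_spatial = t_emb.view(-1, 128, 1, 1)  # [B, 128, 1, 1]
        b = torch.cat([b[:, :128] + t_emb_spatial, b[:, 128:]], dim=1)  # [B, 256, 7, 7]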

        # ---- Decoder with skip connections ----
        # Skip connections at 2 of 3 levels: enc2→dec3 and enc1→dec2.
        # This is the key difference from the autoencoder: the decoder gets a
        # "cheat sheet" of high-resolution features via concatenation.
        d3 = self.up3(b)                             # [B, 128, 14, 14]
        d3 = self.dec3(torch.cat([d3, e2], dim=1))   # concat enc2 skip → [B, 128, 14, 14]

        d2 = self.up2(d3)                            # [B, 64, 28, 28]
        d2 = self.dec2(torch.cat([d2, e1], dim=1))   # concat enc1 skip → [B, 64, 28, 28]

        # Final conv: map from 64 channels to 1 (predicted noise)
        out = self.final(d2)                         # [B, 1, 28, 28]

        return out

print('SimpleUNet defined.')

In [None]:
# Verify the architecture: check shapes and count parameters
model = SimpleUNet().to(device)

# Test with a dummy input
dummy_x = torch.randn(4, 1, 28, 28).to(device)
dummy_t = torch.randint(0, T, (4,)).to(device)
dummy_out = model(dummy_x, dummy_t)

print(f'Input shape:  {dummy_x.shape}')    # [4, 1, 28, 28]
print(f'Output shape: {dummy_out.shape}')   # [4, 1, 28, 28] — same as input!

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f'\nTotal parameters:     {total_params:,}')
print(f'Trainable parameters: {trainable_params:,}')
print(f'\nThis is a small model: roughly 1.8M parameters.')
print(f'Real diffusion models (Stable Diffusion) have ~860M parameters.')
print(f'But this is enough to generate recognizable MNIST digits.')

### What Just Happened

You read through a minimal U-Net architecture with clear annotations. The key pieces:

1. **Encoder-decoder structure** — the same hourglass shape from your autoencoder in Module 6.1. Convolutions extract features, pooling reduces resolution, transposed convolutions upsample.

2. **Skip connections** — encoder features at two of the three resolution levels (enc1 at 28x28, enc2 at 14x14) are concatenated with the corresponding decoder level. This is the upgrade over the autoencoder: the decoder does not have to reconstruct fine details from a compressed bottleneck alone.

3. **Timestep embedding** — the integer timestep $t$ is normalized and passed through a small MLP, producing a 128-dimensional vector. This vector is added to the first 128 channels of the bottleneck features, telling the network which noise level it is working at.

4. **Output shape = input shape** — the network takes a noisy image and outputs a noise prediction of the same size. Every pixel gets a noise estimate.

The architecture is deliberately minimal. No attention layers, no group normalization, no sinusoidal positional encoding. These improvements come in Module 6.3 when you build the full U-Net.
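
The two additions can be tried in isolation on dummy tensors. The sketch below uses made-up tensors with the shapes annotated in `SimpleUNet.forward`, and it adds the embedding out of place (by building a new tensor with `torch.cat`). It illustrates the shape arithmetic only and is not a replacement for the model code.

```python
import torch

B = 4  # batch size (arbitrary)

# Skip connection: concatenate along the channel dimension (dim=1)
upsampled = torch.randn(B, 128, 14, 14)   # stands in for up3(b)
skip_e2 = torch.randn(B, 64, 14, 14)      # stands in for the enc2 features
dec3_input = torch.cat([upsampled, skip_e2], dim=1)
print(dec3_input.shape)  # channels: 128 + 64 = 192

# Timestep embedding: [B, 128] -> [B, 128, 1, 1], broadcast over the 7x7 grid
bottleneck = torch.randn(B, 256, 7, 7)
t_emb = torch.randn(B, 128)
t_emb_spatial = t_emb.view(-1, 128, 1, 1)
conditioned = torch.cat(
    [bottleneck[:, :128] + t_emb_spatial, bottleneck[:, 128:]], dim=1
)
print(conditioned.shape)  # shape is unchanged: [B, 256, 7, 7]
```

This is why `dec3` is declared with `128 + 64` input channels, and why the embedding changes values but never shapes.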

---

## Part 4: Training Loop (Supported)

Now you assemble the DDPM training loop. From *Learning to Denoise*, the algorithm has 7 steps:

1. **Sample** a training image $x_0$ from the dataset
2. **Sample** a random timestep $t \sim \text{Uniform}(0, T-1)$
3. **Sample** noise $\epsilon \sim \mathcal{N}(0, I)$
4. **Create** the noisy image using `q_sample()`
5. **Predict**: $\hat{\epsilon} = \text{model}(x_t, t)$
6. **Compute loss**: $L = \text{MSE}(\hat{\epsilon}, \epsilon)$
7. **Backpropagate** and update weights

The loop skeleton is provided. You fill in the **diffusion-specific parts** (steps 2-4 and step 6). The surrounding code handles the standard training mechanics you already know.

**Task:** Fill in the four `# YOUR CODE HERE` markers. Each is 1-2 lines of code.

In [None]:
def train_epoch(model, dataloader, optimizer, alpha_bars, T, device):
    """Train for one epoch. Returns the average loss."""
    model.train()
    total_loss = 0.0
    num_batches = 0

    for batch_images, _ in dataloader:
        # Step 1: x_0 is the batch of clean images
        x_0 = batch_images.to(device)
        batch_size = x_0.shape[0]

        # Step 2: Sample a random timestep for each image in the batch
        # Each image gets its own random t from Uniform(0, T-1)
        # Hint: use torch.randint. Range should be [0, T).
        # YOUR CODE HERE


        # Step 3: Sample noise — the "answer key" for this training step
        # Fresh Gaussian noise, same shape as x_0
        # Hint: use torch.randn_like
        # YOUR CODE HERE


        # Step 4: Create the noisy image using the forward process
        # Hint: use the q_sample() function you implemented in Part 2
        # YOUR CODE HERE


        # Step 5: Predict the noise (forward pass through the network)
        epsilon_hat = model(x_t, t)

        # Step 6: Compute MSE loss between predicted and actual noise
        # This is the same nn.MSELoss from Series 1 and the Learning to Denoise notebook
        # YOUR CODE HERE


        # Step 7: Backpropagate and update weights (standard training loop)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1

    return total_loss / num_batches

print('train_epoch() defined. Fill in the YOUR CODE HERE markers before running.')

<details>
<summary>Solution</summary>

The four fills are the diffusion-specific data preparation steps. Everything else is the standard training loop you have used since Series 2.

**Step 2 — Sample random timesteps:**
```python
t = torch.randint(0, T, (batch_size,), device=device)
```
Each image in the batch gets its own random timestep. This is why the closed-form formula matters — it lets you create a noisy image at any timestep without iterating.

**Step 3 — Sample noise:**
```python
noise = torch.randn_like(x_0)
```
Fresh Gaussian noise, same shape as the images. This is the "answer key" — what the model will try to predict.

**Step 4 — Create noisy image:**
```python
x_t, noise = q_sample(x_0, t, alpha_bars, noise=noise)
```
Uses the `q_sample()` function you built in Part 2. The `alpha_bars` parameter is already on the GPU (we moved it to the device before training), so no `.to(device)` call is needed here.

**Step 6 — Compute loss:**
```python
loss = F.mse_loss(epsilon_hat, noise)
```
MSE between predicted noise and actual noise. The same loss formula from Series 1. Different question, same math.

</details>

### Phase 1: Train for 5 Epochs

Start with a short training run to see if the loss decreases. This is a sanity check, not full training.

**Before running, predict:**
- Will 5 epochs be enough to generate recognizable digits? (Recall the common misconception: diffusion models do not train as quickly as autoencoders; they need more patience.)
- What will the initial loss be? (Hint: an untrained model predicts random noise.)

In [None]:
# Fresh model and optimizer
model = SimpleUNet().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)

# Move alpha_bars to device once (avoids repeated CPU→GPU transfers per batch)
alpha_bars_device = alpha_bars.to(device)

# Train for 5 epochs
losses_phase1 = []
print('Phase 1: Training for 5 epochs...')
print('-' * 40)

for epoch in range(1, 6):
    start_time = time.time()
    avg_loss = train_epoch(model, dataloader, optimizer, alpha_bars_device, T, device)
    elapsed = time.time() - start_time
    losses_phase1.append(avg_loss)
    print(f'Epoch {epoch:>2d} | Loss: {avg_loss:.4f} | Time: {elapsed:.1f}s')

print('\nPhase 1 complete.')
print(f'Loss dropped from {losses_phase1[0]:.4f} to {losses_phase1[-1]:.4f}')

In [None]:
# Plot the Phase 1 loss curve
plt.figure(figsize=(8, 4))
plt.plot(range(1, len(losses_phase1) + 1), losses_phase1, 'o-', color='#c084fc', linewidth=2, markersize=6)
plt.xlabel('Epoch', fontsize=12)
plt.ylabel('Average Loss', fontsize=12)
plt.title('Phase 1: First 5 epochs', fontsize=13)
plt.axhline(y=2.0, color='gray', linestyle='--', alpha=0.4, label='Random guess (MSE~2)')
plt.legend(fontsize=10)
plt.tight_layout()
plt.show()

print('The loss should be decreasing — the model is learning to predict noise.')
print('But 5 epochs is not enough for good generation. The model must learn to')
print('denoise at ALL 1000 noise levels simultaneously, and random timestep sampling')
print('means it sees each noise level sparsely.')

### Is the Model Learning? A Denoising Diagnostic

Before training further, let's check whether 5 epochs was enough to learn *anything*. We haven't built the sampling loop yet (that's Part 5), so we can't generate from scratch. But we can test the model's denoising ability directly.

**The idea:** Take a real MNIST digit, add noise at several levels, then ask the model: "what do you think the clean image looks like?" The model predicts $\hat{\epsilon}$, and we recover the clean image estimate:

$$\hat{x}_0 = \frac{x_t - \sqrt{1 - \bar{\alpha}_t} \cdot \hat{\epsilon}}{\sqrt{\bar{\alpha}_t}}$$

This is just the closed-form formula from *The Forward Process*, rearranged to solve for $x_0$.
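
A quick way to trust the rearrangement: if the noise estimate is perfect, the formula must return $x_0$ exactly, and any error in $\hat{\epsilon}$ is scaled by $\sqrt{(1-\bar{\alpha}_t)/\bar{\alpha}_t}$. The sketch below checks both on a random stand-in image. `recover_x0` is a hypothetical helper, not a function from this notebook, and no trained model is involved.

```python
import torch

torch.manual_seed(0)
T = 1000
betas = torch.linspace(0.0001, 0.02, T, dtype=torch.float64)
alpha_bars = torch.cumprod(1.0 - betas, dim=0)

def recover_x0(x_t, eps_hat, alpha_bar_t):
    """Solve x_t = sqrt(ab) * x_0 + sqrt(1 - ab) * eps for x_0."""
    return (x_t - torch.sqrt(1.0 - alpha_bar_t) * eps_hat) / torch.sqrt(alpha_bar_t)

x_0 = torch.rand(1, 1, 28, 28, dtype=torch.float64) * 2 - 1   # stand-in image in [-1, 1]
eps = torch.randn_like(x_0)

for t in [100, 300, 600, 900]:
    ab = alpha_bars[t]
    x_t = torch.sqrt(ab) * x_0 + torch.sqrt(1.0 - ab) * eps

    # Perfect noise estimate: recovery is exact up to rounding
    exact = recover_x0(x_t, eps, ab)
    # Noise estimate that is off by 0.01 everywhere: the error grows with t
    off = recover_x0(x_t, eps + 0.01, ab)

    amplification = torch.sqrt((1.0 - ab) / ab).item()
    print(f't={t}: exact-recovery error {(exact - x_0).abs().max().item():.1e}, '
          f'error from 0.01 offset {(off - x_0).abs().max().item():.4f}, '
          f'amplification {amplification:.1f}x')
```

The growing amplification factor is one reason to expect single-step estimates of $x_0$ to be much rougher at large $t$, even for a well-trained model.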

**Before running, predict:** At which noise levels will the model's prediction be best — low noise (small $t$) or high noise (large $t$)?

In [None]:
# Single-step denoising diagnostic: what does the model think the clean image looks like?
# This does NOT use the sampling loop (that's Part 5). Instead, we ask the model
# to predict epsilon at a single timestep, then recover the clean image estimate.

@torch.no_grad()
def denoise_diagnostic(model, x_0, timesteps, alpha_bars, device):
    """Show the model's one-step denoising prediction at several noise levels.

    For each timestep t:
      1. Create noisy image: x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * epsilon
      2. Predict noise: epsilon_hat = model(x_t, t)
      3. Recover clean estimate: x_0_hat = (x_t - sqrt(1 - alpha_bar_t) * epsilon_hat) / sqrt(alpha_bar_t)

    This is the closed-form formula rearranged to solve for x_0.
    """
    model.eval()
    x_0 = x_0.to(device)
    results = []

    for t_val in timesteps:
        t_tensor = torch.tensor([t_val], device=device)
        noise = torch.randn_like(x_0)

        # Forward process: add noise
        alpha_bar_t = alpha_bars[t_val]
        x_t = torch.sqrt(alpha_bar_t) * x_0 + torch.sqrt(1.0 - alpha_bar_t) * noise

        # Model predicts the noise
        epsilon_hat = model(x_t, t_tensor)

        # Recover clean image estimate (rearranged closed-form formula)
        x_0_hat = (x_t - torch.sqrt(1.0 - alpha_bar_t) * epsilon_hat) / torch.sqrt(alpha_bar_t)

        results.append({
            't': t_val,
            'x_t': x_t.cpu().squeeze(),
            'x_0_hat': x_0_hat.cpu().squeeze().clamp(-1, 1),
        })

    return results


# Pick a real digit and run the diagnostic
x_0_sample, label = dataset[3]
x_0_batch = x_0_sample.unsqueeze(0)  # [1, 1, 28, 28]

diagnostic_timesteps = [100, 300, 600, 900]
results = denoise_diagnostic(model, x_0_batch, diagnostic_timesteps, alpha_bars_device, device)

# Display rows: noisy input | model's prediction | original, one column per noise level
n_cols = len(diagnostic_timesteps)
fig, axes = plt.subplots(3, n_cols, figsize=(3 * n_cols, 8))

for i, r in enumerate(results):
    # Row 1: noisy input
    axes[0, i].imshow(r['x_t'], cmap='gray', vmin=-2, vmax=2)
    axes[0, i].set_title(f't={r["t"]}\n$\\bar{{\\alpha}}$={alpha_bars[r["t"]]:.3f}', fontsize=10)
    axes[0, i].axis('off')

    # Row 2: model's prediction of clean image
    axes[1, i].imshow(r['x_0_hat'], cmap='gray', vmin=-1, vmax=1)
    axes[1, i].set_title('Model prediction', fontsize=10)
    axes[1, i].axis('off')

    # Row 3: original clean image (same in every column)
    axes[2, i].imshow(x_0_sample.squeeze(), cmap='gray', vmin=-1, vmax=1)
    axes[2, i].set_title('Original', fontsize=10)
    axes[2, i].axis('off')

# Row labels, drawn as text because axis('off') hides anything set with set_ylabel
for row, row_label in enumerate(['Noisy input', 'Model predicts', 'Ground truth']):
    axes[row, 0].text(-0.15, 0.5, row_label, transform=axes[row, 0].transAxes,
                      fontsize=11, ha='right', va='center')

plt.suptitle(f'Denoising diagnostic after 5 epochs (digit {label})', y=1.02, fontsize=13)
plt.tight_layout()
plt.show()

print('Look at the model\'s predictions:')
print('- At low noise (t=100): the model should do reasonably well — most signal is intact.')
print('- At high noise (t=900): the prediction will be rough/blobby — the model must')
print('  hallucinate structure from almost pure noise. After only 5 epochs, it is still learning.')
print()
print('The model IS learning, but 5 epochs is not enough for clean predictions at high noise.')
print('Let\'s keep training and see how much more helps.')

### Phase 2: Train for 15 More Epochs (Total 20)

Now train longer: allow roughly 15-25 minutes on a Colab GPU.

The model is learning a much harder task than your autoencoder: denoising at 1,000 different noise levels with one network. Training loss decreases gradually rather than dropping sharply.

In [None]:
# Continue training for 15 more epochs (total 20)
losses_phase2 = list(losses_phase1)  # keep phase 1 losses

print('Phase 2: Training for 15 more epochs (total 20)...')
print('-' * 40)

for epoch in range(6, 21):
    start_time = time.time()
    avg_loss = train_epoch(model, dataloader, optimizer, alpha_bars_device, T, device)
    elapsed = time.time() - start_time
    losses_phase2.append(avg_loss)
    print(f'Epoch {epoch:>2d} | Loss: {avg_loss:.4f} | Time: {elapsed:.1f}s')

print('\nPhase 2 complete.')
print(f'Total loss trajectory: {losses_phase2[0]:.4f} → {losses_phase2[-1]:.4f}')

In [None]:
# Plot the full training loss curve
plt.figure(figsize=(10, 4))
plt.plot(range(1, len(losses_phase2) + 1), losses_phase2, 'o-', color='#c084fc', linewidth=2, markersize=4)
plt.axvline(x=5, color='#86efac', linestyle='--', alpha=0.5, label='End of Phase 1')
plt.axhline(y=2.0, color='gray', linestyle='--', alpha=0.3, label='Random guess (MSE~2)')
plt.xlabel('Epoch', fontsize=12)
plt.ylabel('Average Loss', fontsize=12)
plt.title('Full training curve: 20 epochs', fontsize=13)
plt.legend(fontsize=10)
plt.tight_layout()
plt.show()

print('Notice: the loss decreases gradually, not the sharp drop you might expect.')
print('The model is learning to denoise at 1000 different noise levels simultaneously.')
print('This is a harder task than autoencoder reconstruction.')

### What Just Happened

You filled in the diffusion-specific parts of the training loop and trained a real denoising model. The four fills were:

1. **Sample timesteps** — `torch.randint` to get a random $t$ per image
2. **Sample noise** — `torch.randn_like` to get the "answer key"
3. **Create noisy images** — `q_sample()` to apply the forward process
4. **Compute loss** — `F.mse_loss` between predicted and actual noise

Everything else was the standard training loop from Series 2. Same heartbeat: forward → loss → backward → update.

The loss curve shows gradual improvement. This is expected — the model must learn to denoise at 1,000 different noise levels with one set of weights. Each batch only covers a sparse sampling of those levels.
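
The dashed "random guess" line at MSE ≈ 2 in the loss plots can be reproduced directly. A small sketch with two hypothetical baseline predictors, neither of which involves the model: one outputs fresh, independent Gaussian noise, and the other always outputs zeros.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
true_noise = torch.randn(128, 1, 28, 28)   # what the model is asked to predict

# Baseline 1: guess fresh, independent Gaussian noise
random_guess = torch.randn_like(true_noise)
mse_random = F.mse_loss(random_guess, true_noise).item()

# Baseline 2: always predict zero
zero_guess = torch.zeros_like(true_noise)
mse_zero = F.mse_loss(zero_guess, true_noise).item()

print(f'independent-noise guess: MSE = {mse_random:.3f}')  # close to 2 = Var(guess) + Var(noise)
print(f'all-zeros guess:         MSE = {mse_zero:.3f}')    # close to 1 = Var(noise)
```

Read against these two numbers, a training loss below 1 means the model is doing better on average than the trivial all-zeros predictor.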

---

### Does Low Loss Mean Good Generation?

Your training loss has decreased steadily over 20 epochs. In classification, low loss meant high accuracy. In autoencoders, low loss meant good reconstructions. Does low diffusion loss guarantee good generations?

Not necessarily. The training loss is an **average** across all 1,000 timesteps. The model can have decent average MSE while still making correlated errors at certain noise levels. And in sampling, errors **accumulate** across 1,000 sequential steps — a small per-step error compounds into a large deviation by the end.

You already saw a hint of this in the denoising diagnostic: your 5-epoch model had a decreasing loss, but its predictions at high noise levels were rough and blobby. The loss was better than random, but the model had not yet learned enough to produce clean results.
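
One way to look behind the average is to report the loss separately for different timestep ranges. The sketch below is a minimal version of that diagnostic under stated assumptions: `q_sample_sketch` is a stand-in for the notebook's `q_sample()`, the linear schedule and random "images" are placeholders, and `zero_model` is a dummy that always predicts zero noise. In the notebook you would pass your trained `model`, a real batch, and `alpha_bars_device` instead.

```python
import torch
import torch.nn.functional as F

def q_sample_sketch(x0, t, alpha_bars):
    """Stand-in for the notebook's q_sample(): returns (x_t, noise)."""
    noise = torch.randn_like(x0)
    sqrt_ab = alpha_bars[t].sqrt().view(-1, 1, 1, 1)
    sqrt_one_minus_ab = (1.0 - alpha_bars[t]).sqrt().view(-1, 1, 1, 1)
    return sqrt_ab * x0 + sqrt_one_minus_ab * noise, noise

@torch.no_grad()
def loss_by_timestep_bucket(model, x0, alpha_bars, n_buckets=4):
    """Noise-prediction MSE computed separately for each timestep range."""
    T = alpha_bars.shape[0]
    edges = torch.linspace(0, T, n_buckets + 1).long().tolist()
    report = []
    for lo, hi in zip(edges[:-1], edges[1:]):
        # Draw timesteps only from this bucket: lo <= t < hi
        t = torch.randint(lo, hi, (x0.shape[0],), device=x0.device)
        x_t, noise = q_sample_sketch(x0, t, alpha_bars)
        report.append((lo, hi, F.mse_loss(model(x_t, t), noise).item()))
    return report

# Demo with placeholders (assumptions, not the notebook's objects)
torch.manual_seed(0)
T = 1000
betas = torch.linspace(1e-4, 0.02, T)              # assumed linear schedule
alpha_bars = torch.cumprod(1.0 - betas, dim=0)
x0 = torch.rand(64, 1, 28, 28) * 2 - 1             # fake "images" in [-1, 1]
zero_model = lambda x_t, t: torch.zeros_like(x_t)  # always predicts zero noise

report = loss_by_timestep_bucket(zero_model, x0, alpha_bars)
for lo, hi, loss in report:
    print(f't in [{lo:>4d}, {hi:>4d}): MSE = {loss:.3f}')
```

The zero-predicting dummy scores close to 1.0 in every bucket because the target noise has unit variance. For a trained model, an uneven profile across buckets is exactly the kind of detail a single averaged loss hides.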

**The takeaway:** Loss is a necessary but not sufficient signal. The real test is always the generated samples. Let's see what 20 epochs of training actually produces.

## Part 5: Sampling (Supported)

This is the moment of truth. From *Sampling and Generation*, the DDPM sampling algorithm:

1. **Sample** $x_T \sim \mathcal{N}(0, I)$ — start from pure noise
2. **Loop** from $t = T$ down to $t = 1$ (the code below is 0-indexed, so `t_val` runs from `T-1` down to `0`):
   - Predict the noise: $\hat{\epsilon} = \text{model}(x_t, t)$
   - Compute the denoised estimate (the reverse step formula):
     $$x_{t-1} = \frac{1}{\sqrt{\alpha_t}} \left( x_t - \frac{\beta_t}{\sqrt{1 - \bar{\alpha}_t}} \hat{\epsilon} \right) + \sigma_t z$$
   - where $z \sim \mathcal{N}(0, I)$ for $t > 1$, and $z = 0$ for $t = 1$ (the last step commits without noise)
3. **Return** $x_0$

The loop skeleton is provided. You fill in the **reverse step computation** inside the loop.

**Task:** Fill in the two `# YOUR CODE HERE` markers.

In [None]:
@torch.no_grad()
def sample(model, n_samples, T, betas, alphas, alpha_bars, device):
    """Generate images using the DDPM sampling algorithm.

    This is the algorithm from Sampling and Generation:
    start from pure noise, iteratively denoise, return the generated image.

    Args:
        model: trained denoising network
        n_samples: number of images to generate
        T: total timesteps
        betas: noise schedule [T]
        alphas: 1 - betas [T]
        alpha_bars: cumulative product of alphas [T]
        device: cuda or cpu

    Returns:
        generated images [n_samples, 1, 28, 28]
    """
    model.eval()

    # Step 1: Start from pure noise
    x = torch.randn(n_samples, 1, 28, 28, device=device)

    # Precompute sigma_t = sqrt(beta_t) for the noise injection
    # (Schedule tensors stay on CPU. Indexing them with a Python int below
    #  returns 0-dim tensors, which PyTorch lets you combine with CUDA tensors.)
    sigmas = torch.sqrt(betas)

    # Step 2: Loop from t = T-1 down to t = 0
    for t_val in reversed(range(T)):
        t_batch = torch.full((n_samples,), t_val, device=device, dtype=torch.long)

        # Predict the noise at this timestep
        epsilon_hat = model(x, t_batch)

        # Get schedule values for this timestep
        # (Indexing with an int returns 0-dim CPU tensors. These combine with
        #  CUDA tensors without a device mismatch.)
        alpha_t = alphas[t_val]
        alpha_bar_t = alpha_bars[t_val]
        beta_t = betas[t_val]
        sigma_t = sigmas[t_val]

        # The reverse step formula from Sampling and Generation:
        # x_{t-1} = (1/sqrt(alpha_t)) * (x_t - (beta_t / sqrt(1 - alpha_bar_t)) * epsilon_hat) + sigma_t * z
        #
        # YOUR CODE HERE: compute the denoised estimate (the mean of the reverse step)
        # Hint: this is one line implementing the formula above, minus the sigma_t * z part


        # Add noise for all steps except the last (t=0)
        # At t=0, z = 0 (the last step commits without noise)
        # This is the t=1 special case from Sampling and Generation
        # (Note: t_val=0 corresponds to the final step in our 0-indexed loop)
        #
        # YOUR CODE HERE: sample z and add sigma_t * z when t_val > 0
        # Hint: sample z = torch.randn_like(x), then only add it if t_val > 0


    # Step 3: Return the generated images
    # Clamp to [-1, 1] for display
    return x.clamp(-1, 1)

print('sample() defined. Fill in the YOUR CODE HERE markers before running.')

<details>
<summary>Solution</summary>

The two fills implement the reverse step formula from *Sampling and Generation*.

**Reverse step mean:**
```python
x = (1.0 / torch.sqrt(alpha_t)) * (x - (beta_t / torch.sqrt(1.0 - alpha_bar_t)) * epsilon_hat)
```
This is the formula you traced with numbers at t=500 in the lesson. The leading factor $1/\sqrt{\alpha_t}$ scales the result back up. The term inside the parentheses subtracts the model's estimated noise contribution.

**Noise injection (with final-step special case):**
```python
if t_val > 0:
    z = torch.randn_like(x)
    x = x + sigma_t * z
```
For all steps except the last, we add a small amount of fresh noise. This is the stochastic part — like the temperature in language model sampling. At the final step, $z = 0$: the last step commits to a specific image without adding noise.

**Index convention note:** Our loop uses 0-indexed `t_val` (Python convention), so `t_val=0` is the final step. In the DDPM paper and in *Sampling and Generation*, this corresponds to $t=1$ (1-indexed). The condition `t_val > 0` is equivalent to the paper's "$z = 0$ when $t = 1$."

Remember the analogy from *Sampling and Generation*: "The last step commits. Every step before it explores."
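
As a sanity check on the formula and the index convention, the sketch below applies the reverse step mean to a single scalar "pixel" at the final step. It assumes a linear schedule from 1e-4 to 0.02 with T = 1000 and pretends the model predicts the noise exactly. `reverse_mean` is a hypothetical helper written for this check, not the notebook's `sample()`.

```python
import math

# Assumed linear schedule with T = 1000 (plain Python lists, no PyTorch)
T = 1000
betas = [1e-4 + (0.02 - 1e-4) * i / (T - 1) for i in range(T)]
alphas = [1.0 - b for b in betas]
alpha_bars = []
running = 1.0
for a in alphas:
    running *= a
    alpha_bars.append(running)

def reverse_mean(x_t, eps_hat, t_val):
    """Mean of the reverse step for one scalar value (0-indexed t_val)."""
    coef = betas[t_val] / math.sqrt(1.0 - alpha_bars[t_val])
    return (x_t - coef * eps_hat) / math.sqrt(alphas[t_val])

# Final step (t_val = 0): build x_t from a known x_0 and noise, then invert.
x_0, eps = 0.5, -1.3
t_val = 0
x_t = math.sqrt(alpha_bars[t_val]) * x_0 + math.sqrt(1.0 - alpha_bars[t_val]) * eps
recovered = reverse_mean(x_t, eps, t_val)  # pretend the model predicts eps exactly
print(f'x_t = {x_t:.4f}, recovered x_0 = {recovered:.4f}')
```

At `t_val = 0` the cumulative product equals the single-step value ($\bar{\alpha} = \alpha$), so with a perfect noise prediction the mean lands exactly on $x_0$. That only holds at the final step. At earlier steps the mean is an intermediate estimate, and the added noise keeps the trajectory stochastic.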

</details>

### The Timing Experiment

Before generating, make a prediction.

**Think about this:** One training step takes one forward pass through the model (for one random timestep). How long did each training epoch take? Roughly how long is one forward pass?

Sampling requires **T = 1,000 sequential forward passes.** Generating 64 images as one batch still takes 1,000 sequential passes, and each pass processes all 64 images at once. In total that is 64 × 1,000 = 64,000 per-image denoising steps.

**Before running, predict:** How long will it take to generate 64 images?

In [None]:
# YOUR PREDICTION
# Before running the next cell, write your prediction here:
#
# How long will it take to generate 64 images?
# Your prediction: _____ seconds
#
# How does that compare to one training epoch?
# Your reasoning: ...

In [None]:
# First: time a single training step for comparison
model.train()
sample_batch, _ = next(iter(dataloader))
sample_batch = sample_batch.to(device)

# Warm up the GPU
for _ in range(5):
    t_dummy = torch.randint(0, T, (sample_batch.shape[0],), device=device)
    x_t_dummy, noise_dummy = q_sample(sample_batch, t_dummy, alpha_bars_device)
    _ = model(x_t_dummy, t_dummy)

# Time one training step
if device.type == 'cuda':
    torch.cuda.synchronize()
start = time.time()

t_dummy = torch.randint(0, T, (sample_batch.shape[0],), device=device)
x_t_dummy, noise_dummy = q_sample(sample_batch, t_dummy, alpha_bars_device)
pred = model(x_t_dummy, t_dummy)
loss = F.mse_loss(pred, noise_dummy)
optimizer.zero_grad()
loss.backward()
optimizer.step()

if device.type == 'cuda':
    torch.cuda.synchronize()
train_step_time = time.time() - start

print(f'One training step: {train_step_time*1000:.1f} ms')
print(f'  (1 forward pass + 1 backward pass for 1 random timestep)')
print()
print(f'Now generating 64 images...')
print(f'  (1000 sequential forward passes, each over a batch of 64 images)')
print()

In [None]:
# Generate 64 images and TIME it
if device.type == 'cuda':
    torch.cuda.synchronize()
start = time.time()

generated = sample(model, 64, T, betas, alphas, alpha_bars, device)

if device.type == 'cuda':
    torch.cuda.synchronize()
sample_time = time.time() - start

print(f'Generated 64 images in {sample_time:.1f} seconds.')
print(f'That is {sample_time/64:.2f} seconds per image.')
print(f'\nCompare:')
print(f'  One training step:       {train_step_time*1000:.1f} ms  (1 forward + 1 backward pass on one batch)')
print(f'  Generating one image:    {sample_time/64*1000:.0f} ms  (1000 forward passes, amortized over the batch of 64)')
print(f'  Full sampling run:       {sample_time/train_step_time:.0f}x the time of one training step')

In [None]:
# Display the generated images as an 8x8 grid
def show_grid(images, nrow=8, title='Generated Images'):
    """Display a batch of images as a grid."""
    # images: [n, 1, 28, 28], range [-1, 1]
    images = images.cpu()
    # Rescale to [0, 1] for display
    images = (images + 1) / 2
    images = images.clamp(0, 1)

    grid = torchvision.utils.make_grid(images, nrow=nrow, padding=2)
    plt.figure(figsize=(10, 10))
    plt.imshow(grid.permute(1, 2, 0).squeeze(), cmap='gray')
    plt.title(title, fontsize=14)
    plt.axis('off')
    plt.tight_layout()
    plt.show()

show_grid(generated, title=f'Generated MNIST digits (20 epochs, {sample_time:.0f}s to generate)')

print('Look at your generated digits:')
print('- Are they recognizable? (They should be — if imperfect.)')
print('- Are they varied? (Different digits in different positions.)')
print('- Are they perfect? (They should NOT be — this is a tiny model on 28x28.)')
print()
print('These images never existed in the training set.')
print('They were generated from pure Gaussian noise by a model you built from scratch.')

### The Cost of Pixel-Space Diffusion

Let that sampling time sink in. This is the **deliberate pain** of this lesson.

In [None]:
# Calculate: how long would it take at scale?
time_per_image = sample_time / 64

print('=== The Cost of Pixel-Space DDPM ===')
print()
print(f'Your model: {sum(p.numel() for p in model.parameters()):,} parameters')
print(f'Image size: 28×28 (784 pixels)')
print(f'Timesteps:  {T}')
print()
print(f'Time to generate 1 image:      {time_per_image:.2f}s')
print(f'Time to generate 64 images:    {sample_time:.1f}s')
print(f'Time to generate 1000 images:  {time_per_image * 1000 / 60:.1f} minutes')
print()
print('Now imagine scaling up:')
print(f'  Stable Diffusion: 512×512 image = {512*512//784:.0f}x more pixels')
print(f'  Stable Diffusion: ~860M parameters vs your {sum(p.numel() for p in model.parameters())//1000}K')
print(f'  Each forward pass would be MUCH slower.')
print(f'  1000 steps × slower forward pass = minutes per image.')
print()
print('⚠ This slowness is NOT a bug in your code.')
print('  It is a fundamental property of pixel-space DDPM.')
print('  Every image requires T=1000 sequential forward passes.')
print('  This is why latent diffusion was invented.')

### What Just Happened

You implemented the DDPM sampling algorithm — the reverse process from *Sampling and Generation* — and generated real images from pure noise. The two fills were:

1. **The reverse step formula**: compute the denoised estimate by subtracting the model's predicted noise contribution, scaled by the schedule parameters.
2. **The noise injection**: add fresh noise at every step except the last. The last step (`t_val = 0` in the code, $t = 1$ in the lesson's notation) commits without noise.

You also experienced the computational cost firsthand. One training step is fast: one forward pass and one backward pass on a single batch. One sampling run needs 1,000 sequential forward passes, so it takes far longer, by the ratio you just measured. This is the fundamental bottleneck of pixel-space DDPM.

The timing is not a failure of your implementation. It is a property of the algorithm. And it is exactly what motivates latent diffusion in Module 6.3.
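
The arithmetic behind that bottleneck fits in a few lines. The sketch below is a back-of-the-envelope estimate only: `estimate_sampling_seconds` is a hypothetical helper, and the forward-pass times are made-up inputs rather than measurements from this notebook.

```python
def estimate_sampling_seconds(forward_ms, T=1000):
    """Sequential cost of one sampling run: T forward passes back to back.

    forward_ms is the time of ONE forward pass over the whole batch, so the
    batch size only matters through its effect on forward_ms.
    """
    return T * forward_ms / 1000.0

# Hypothetical forward-pass times in milliseconds (not measurements)
for forward_ms in (5.0, 20.0, 100.0):
    seconds = estimate_sampling_seconds(forward_ms)
    print(f'{forward_ms:>6.1f} ms per pass -> {seconds:>6.1f} s per sampling run')
```

The estimate scales linearly in both inputs, which is why there are two levers for faster sampling: make each forward pass cheaper, or take fewer steps.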

---

## Part 6: Reflection and Bridge (Independent)

You have now built two generative models in this series:
1. A **VAE** (Module 6.1) that generates images with one decoder forward pass
2. A **DDPM** (this lesson) that generates images with 1,000 iterative denoising steps

### Reflection Questions

Think about these questions before reading the discussion below.

**Question 1:** Compare your VAE samples from *Exploring Latent Spaces* to your diffusion samples. Which produced sharper, more detailed images? Which was faster?

**Question 2:** The VAE generates in one forward pass (instant). Diffusion takes 1,000 steps (slow). Why does diffusion produce better quality despite using the same basic building blocks (conv layers, MSE loss, backprop)?

**Question 3:** Is there a way to get the quality of diffusion with the speed of a VAE?

In [None]:
# YOUR REFLECTION
# Before reading the discussion below, write your answers here:
#
# Question 1 — Quality vs Speed:
#   VAE quality: ...
#   VAE speed: ...
#   Diffusion quality: ...
#   Diffusion speed: ...
#
# Question 2 — Why is diffusion better quality?
#   Your reasoning: ...
#
# Question 3 — Can we get both?
#   Your idea: ...

### Discussion

**VAE vs Diffusion — a tale of two tradeoffs:**

| | VAE | Diffusion (DDPM) |
|---|---|---|
| **Speed** | One decoder forward pass — instant | 1,000 sequential forward passes — slow |
| **Quality** | Blurry, averaged features | Sharper, more detailed |
| **Why** | Must compress everything through a bottleneck | Iterative refinement — each step only needs to make a small correction |
| **Loss** | Reconstruction + KL divergence | Simple MSE on noise prediction |

**Why does diffusion produce better quality?** The VAE must compress an image into a small latent vector and reconstruct everything in one shot — details that do not survive the bottleneck are lost. Diffusion breaks the problem into 1,000 tiny steps. Each step only needs to make a small correction. The cumulative effect of 1,000 small corrections produces more detail than one big reconstruction.

**The cost:** Those 1,000 steps are sequential. You cannot parallelize them — step $t-1$ depends on step $t$. This is the fundamental bottleneck.

**Can we get both quality AND speed?** Yes — and the insight is elegant:

> *What if you ran diffusion in the VAE's latent space instead of pixel space?*

The VAE's encoder compresses a 28×28 image into a small latent representation. The latent space is much smaller than pixel space. If you run the diffusion process in that compressed space:
- Each forward pass through the denoising network is faster (smaller input)
- The 1,000 steps happen in a compressed space
- The VAE's decoder maps the result back to pixel space at the end

This is **latent diffusion** — the core idea behind Stable Diffusion. That is Module 6.3.
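
To get a feel for the size difference, the sketch below compares the number of values the denoising network must process per step in pixel space and in latent space. The shapes are the publicly documented ones for Stable Diffusion v1 at 512×512 resolution, not values from this notebook, and `num_elements` is a hypothetical helper.

```python
def num_elements(shape):
    """Total number of values in a tensor of the given shape."""
    total = 1
    for dim in shape:
        total *= dim
    return total

# Stable Diffusion v1 shapes (assumed here for illustration)
pixel_shape = (3, 512, 512)   # RGB image
latent_shape = (4, 64, 64)    # VAE latent for a 512x512 image

pixels = num_elements(pixel_shape)
latents = num_elements(latent_shape)
print(f'Pixel space:  {pixels:,} values')
print(f'Latent space: {latents:,} values')
print(f'Reduction:    {pixels / latents:.0f}x fewer values per denoising step')
```

Fewer values per step does not translate one-to-one into wall-clock speedup, but it shows why every one of the sequential steps becomes much cheaper in latent space.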

---

## Key Takeaways

1. **Every formula from the module became a line of PyTorch code.** The closed-form formula became `q_sample()`. The 7-step training algorithm became `train_epoch()`. The reverse step formula became the core of `sample()`. No new theory — just translation.

2. **A working diffusion model generates recognizable images from pure noise.** Even a minimal architecture (<1M parameters) trained on MNIST for 20 epochs produces varied, recognizable digits. Architecture sophistication improves quality; it does not gate whether diffusion works at all.

3. **The training loop is the standard loop with diffusion-specific data preparation.** Sample a random timestep, sample noise, create the noisy image with `q_sample()`, predict the noise, compute MSE loss, backpropagate. The heartbeat has not changed since Series 2.

4. **Sampling needs 1,000x as many forward passes as a training step.** One training step: one forward pass (plus a backward pass). One sampling run: 1,000 sequential forward passes. This is a fundamental property of pixel-space DDPM, not a bug. You measured it yourself.

5. **The sampling cost motivates latent diffusion.** Running diffusion in a compressed latent space (instead of pixel space) makes each step faster. This is the core idea behind Stable Diffusion — Module 6.3.

**What you built:** A complete DDPM pipeline — forward process, denoising network, training loop, sampling algorithm — from scratch. Digits emerged from pure noise, generated by a model you built and trained yourself. That is the payoff of four lessons of theory.

**What comes next:** Module 6.3 answers the question that the sampling wait raises: *what if you ran diffusion in a compressed latent space?*