# Conditioning the U-Net

**Module 6.3, Lesson 2** | CourseAI

The U-Net from the previous lesson was the skeleton -- the spatial architecture. This notebook adds the nervous system: the timestep signal that tells each bone how to move. You will implement both mechanisms that make this work:

**What you will do:**
- Implement sinusoidal timestep embedding from the formula (the same one from positional encoding)
- Visualize how the embedding encodes different noise levels as distinct patterns
- Build the 2-layer MLP that refines the raw sinusoidal encoding into a timestep embedding
- Implement adaptive group normalization -- making gamma and beta depend on the timestep
- Compare the capstone's simple timestep conditioning (normalize t, then an MLP) to sinusoidal + MLP conditioning in a real training run

**For each exercise, PREDICT the output before running the cell.**

Everything in this notebook connects to concepts you already know. The sinusoidal encoding is the same formula from *Embeddings and Position*. The adaptive normalization is standard group norm with one change: gamma and beta come from the timestep instead of a fixed parameter table. No new theory -- just practice.

**Estimated time:** 45--60 minutes (Exercise 4 training takes ~15 minutes on a Colab GPU).

---

## Setup

Run this cell to import everything and configure the environment.

**Important:** If you plan to run Exercise 4 (the training comparison), set the runtime to GPU. In Colab: Runtime > Change runtime type > T4 GPU. Exercises 1--3 work fine on CPU.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
import math

# Reproducible results
torch.manual_seed(42)

# Nice plots
plt.style.use('dark_background')
plt.rcParams['figure.figsize'] = [10, 4]

# Device setup
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')
if device.type == 'cuda':
    print(f'GPU: {torch.cuda.get_device_name(0)}')

print('\nSetup complete.')

---

## Exercise 1: Sinusoidal Timestep Embedding [Guided]

In *Embeddings and Position*, you implemented sinusoidal positional encoding -- the "clock with many hands" that gives each position a unique, smooth pattern. The formula was:

$$\text{PE}(\text{pos}, 2i) = \sin\!\left(\frac{\text{pos}}{10000^{2i/d}}\right), \quad \text{PE}(\text{pos}, 2i+1) = \cos\!\left(\frac{\text{pos}}{10000^{2i/d}}\right)$$

For timestep embedding, the formula is **identical** -- just replace "pos" with "t":

$$\text{TE}(t, 2i) = \sin\!\left(\frac{t}{10000^{2i/d}}\right), \quad \text{TE}(t, 2i+1) = \cos\!\left(\frac{t}{10000^{2i/d}}\right)$$

**Same formula. Different input. Different question.** In transformers, the clock encodes *where* a token sits. In diffusion, the same clock encodes *how noisy* the input is.

The four requirements still hold: **unique** (each timestep gets a distinct pattern), **smooth** (nearby timesteps produce similar embeddings), **any range** (works for 1000 timesteps or 10,000), and **deterministic** (no learned parameters).
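The formula assigns sine to even dimensions ($2i$) and cosine to odd dimensions ($2i+1$). The `sinusoidal_embedding` function you will run below instead concatenates all sines followed by all cosines, a layout many diffusion implementations use. The sketch below (helper names are illustrative, not from the lesson) checks that the two layouts hold the same values in a different column order. Any property that ignores column order, such as cosine similarity between two embeddings, is therefore identical under either layout.

```python
import math
import torch


def _angles(t, d_emb):
    # t * f_i for f_i = 1 / 10000^(2i/d_emb), i = 0 .. d_emb/2 - 1
    half_d = d_emb // 2
    i = torch.arange(half_d, dtype=torch.float32)
    freq = torch.exp(-i * (math.log(10000.0) / half_d))
    return t.float().unsqueeze(1) * freq.unsqueeze(0)


def embed_concat(t, d_emb=256):
    """Notebook layout: [sin(f_0 t) ... sin(f_{h-1} t), cos(f_0 t) ... cos(f_{h-1} t)]."""
    a = _angles(t, d_emb)
    return torch.cat([torch.sin(a), torch.cos(a)], dim=-1)


def embed_interleaved(t, d_emb=256):
    """Formula layout: dim 2i holds sin(f_i t), dim 2i+1 holds cos(f_i t)."""
    a = _angles(t, d_emb)
    emb = torch.empty(t.shape[0], d_emb)
    emb[:, 0::2] = torch.sin(a)
    emb[:, 1::2] = torch.cos(a)
    return emb


t = torch.tensor([0, 50, 500, 999])
concat = embed_concat(t)
inter = embed_interleaved(t)

# perm = [0, 128, 1, 129, 2, 130, ...] reorders concatenated columns into interleaved order
half_d = concat.shape[1] // 2
perm = torch.stack([torch.arange(half_d), torch.arange(half_d) + half_d], dim=1).flatten()
print(torch.equal(concat[:, perm], inter))  # True
```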

**Before running, predict:**
- Will the heatmap for t=500 and t=501 look similar or different?
- Will the heatmap for t=500 and t=50 look similar or different?
- In the heatmap, which dimensions will change rapidly across timesteps -- the low-index or high-index dimensions?

In [None]:
def sinusoidal_embedding(t, d_emb=256):
    """Compute sinusoidal timestep embedding.

    Same formula as positional encoding from Embeddings and Position:
    TE(t, 2i)   = sin(t / 10000^(2i/d))
    TE(t, 2i+1) = cos(t / 10000^(2i/d))

    Args:
        t: timesteps [batch_size] (integer or float tensor)
        d_emb: embedding dimension (default 256)

    Returns:
        embeddings [batch_size, d_emb]
    """
    # Ensure t is a float tensor with shape [batch_size]
    if not isinstance(t, torch.Tensor):
        t = torch.tensor([t], dtype=torch.float32)
    t = t.float()
    if t.dim() == 0:
        t = t.unsqueeze(0)

    # Half the dimensions get sin, half get cos
    half_d = d_emb // 2

    # Compute the frequencies: 1 / 10000^(2i/d) for i = 0, 1, ..., half_d-1
    # Equivalently: exp(-2i/d * log(10000))
    i = torch.arange(half_d, dtype=torch.float32, device=t.device)
    freq = torch.exp(-i * (math.log(10000.0) / half_d))

    # Outer product: [batch_size, 1] * [1, half_d] -> [batch_size, half_d]
    # Each row is one timestep, each column is one frequency
    angles = t.unsqueeze(1) * freq.unsqueeze(0)

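    # Concatenate sin and cos: [batch_size, d_emb]
    # (all sines first, then all cosines: a column permutation of the
    #  interleaved even/odd layout written in the formula above)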
    emb = torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)

    return emb


# Compute embeddings for several timesteps
timesteps = [0, 50, 500, 501, 950, 999]
embeddings = sinusoidal_embedding(torch.tensor(timesteps), d_emb=256)

print(f'Embedding shape: {embeddings.shape}')  # [6, 256]
print(f'Timesteps: {timesteps}')
print()

# Show the first 8 values of each embedding
print('First 8 dimensions of each embedding:')
for i, t_val in enumerate(timesteps):
    vals = embeddings[i, :8].numpy()
    print(f'  t={t_val:>4d}: [{"  ".join(f"{v:+.3f}" for v in vals)}  ...]')

In [None]:
# Visualize as a heatmap -- like the PositionalEncodingHeatmap from Embeddings and Position
# Each row is one timestep. Each column is one dimension of the embedding.
# You should see, within each half (sines in dims 0-127, cosines in dims 128-255):
# low-index dimensions oscillate rapidly (the "second hand"),
# high-index dimensions change slowly (the "hour hand").

# Compute embeddings for many timesteps (dense sampling)
all_timesteps = torch.arange(0, 1000)
all_embeddings = sinusoidal_embedding(all_timesteps, d_emb=256)

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

# Full heatmap
im1 = ax1.imshow(all_embeddings.numpy(), aspect='auto', cmap='RdBu_r',
                  vmin=-1, vmax=1, interpolation='nearest')
ax1.set_xlabel('Embedding Dimension', fontsize=11)
ax1.set_ylabel('Timestep t', fontsize=11)
ax1.set_title('Sinusoidal Timestep Embedding (all 1000 timesteps)', fontsize=12)
plt.colorbar(im1, ax=ax1, shrink=0.8)

# Zoomed heatmap: t=495 to t=505 (to check smoothness)
zoom_range = range(495, 506)
zoom_embeddings = sinusoidal_embedding(torch.tensor(list(zoom_range)), d_emb=256)

im2 = ax2.imshow(zoom_embeddings.numpy(), aspect='auto', cmap='RdBu_r',
                  vmin=-1, vmax=1, interpolation='nearest')
ax2.set_xlabel('Embedding Dimension', fontsize=11)
ax2.set_ylabel('Timestep t', fontsize=11)
ax2.set_yticks(range(len(list(zoom_range))))
ax2.set_yticklabels(list(zoom_range))
ax2.set_title('Zoomed: t=495 to t=505 (smoothness check)', fontsize=12)
plt.colorbar(im2, ax=ax2, shrink=0.8)

plt.tight_layout()
plt.show()

print('Left: full heatmap. Dims 0-127 hold the sines, dims 128-255 the cosines.')
print('      Within each half, low-index dims oscillate rapidly and high-index dims change slowly.')
print('      This is the "clock with many hands": the second hands (dims 0 and 128) capture')
print('      fine differences; the hour hands (dims 127 and 255) capture broad noise-level changes.')
print()
print('Right: zoomed view of t=495 to t=505. Adjacent timesteps produce nearly identical')
print('       patterns. The embedding is SMOOTH -- small change in t, small change in embedding.')

In [None]:
# Compare sinusoidal embedding to a RANDOM embedding
# This is the negative example from the lesson: if embeddings were random,
# t=500 and t=501 would have completely unrelated patterns.

# Random embedding: each timestep gets a random vector
torch.manual_seed(0)
random_embeddings = torch.randn(1000, 256)

# Compute cosine similarity between consecutive timesteps
def cosine_sim(a, b):
    return F.cosine_similarity(a.unsqueeze(0), b.unsqueeze(0)).item()

# Sinusoidal: adjacent timesteps
sin_sim_adjacent = cosine_sim(all_embeddings[500], all_embeddings[501])
sin_sim_distant = cosine_sim(all_embeddings[500], all_embeddings[50])

# Random: adjacent timesteps
rand_sim_adjacent = cosine_sim(random_embeddings[500], random_embeddings[501])
rand_sim_distant = cosine_sim(random_embeddings[500], random_embeddings[50])

print('Cosine similarity comparison:')
print(f'                        t=500 vs t=501    t=500 vs t=50')
print(f'  Sinusoidal:           {sin_sim_adjacent:+.4f}            {sin_sim_distant:+.4f}')
print(f'  Random:               {rand_sim_adjacent:+.4f}            {rand_sim_distant:+.4f}')
print()
print('With sinusoidal encoding:')
print('  - Adjacent timesteps (500 vs 501) are very similar (high cosine similarity).')
print('  - Distant timesteps (500 vs 50) are noticeably less similar.')
print('  This is the smoothness property in action.')
print()
print('With random encoding:')
print('  - Adjacent and distant timesteps both have similarity near zero.')
print('  - The network would treat t=500 and t=501 as unrelated tasks.')
print('  - No smoothness, no generalization across nearby timesteps.')

### What Just Happened

You implemented the sinusoidal timestep embedding -- the **same formula** from positional encoding in transformers, now encoding noise level instead of sequence position.

The heatmap shows the "clock with many hands":
- **Low-index dimensions within each half** (starting at dim 0 for the sines and dim 128 for the cosines) oscillate rapidly, capturing fine differences between adjacent timesteps. These are the "second hand."
- **High-index dimensions within each half** (ending at dim 127 and dim 255) change slowly, capturing the broad distinction between "lots of noise" and "almost clean." These are the "hour hand."
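The hands can be quantified by their periods. Frequency $f_i = 1/10000^{2i/d}$ completes one full cycle every $2\pi / f_i$ timesteps. A small sketch with $d = 256$ (the helper name is illustrative):

```python
import math


def sinusoid_periods(d_emb=256, base=10000.0):
    """Period, in timesteps, of each sin/cos frequency pair.

    Pair i uses frequency f_i = 1 / base**(2i / d_emb), so its period is 2*pi / f_i.
    """
    half_d = d_emb // 2
    return [2 * math.pi * base ** (2 * i / d_emb) for i in range(half_d)]


periods = sinusoid_periods(256)
print(f"fastest hand: period {periods[0]:.2f} timesteps")
print(f"slowest hand: period {periods[-1]:,.0f} timesteps")

# How many pairs complete less than one full cycle over t = 0 .. 999?
slow_hands = sum(p > 1000 for p in periods)
print(f"{slow_hands} of {len(periods)} pairs never complete a cycle within 1000 timesteps")
```

On its own, the fastest pair is ambiguous, since its values repeat with period $2\pi$. The slow pairs, which never complete a cycle within 1000 steps, resolve that ambiguity; together the pairs give each timestep a distinct pattern.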

The smoothness comparison confirms why sinusoidal encoding is superior to random embeddings: adjacent timesteps (500 vs 501) get nearly identical patterns, so the network can generalize across nearby noise levels. A random embedding destroys this structure.
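The smoothness has a closed form. Because $\sin a \sin b + \cos a \cos b = \cos(a - b)$ and every embedding has squared norm $d/2$, the cosine similarity between the embeddings of $t$ and $s$ is $\frac{2}{d}\sum_i \cos\big((t - s) f_i\big)$. It depends only on the gap $t - s$, not on where the pair sits. A sketch that checks this numerically (helper names are illustrative):

```python
import math
import torch
import torch.nn.functional as F


def freqs(d_emb=256):
    half_d = d_emb // 2
    i = torch.arange(half_d, dtype=torch.float64)
    return torch.exp(-i * (math.log(10000.0) / half_d))


def embed(t, d_emb=256):
    # Concatenated [sin, cos] layout, in float64 for a tight numerical comparison
    angles = float(t) * freqs(d_emb)
    return torch.cat([torch.sin(angles), torch.cos(angles)]).unsqueeze(0)


def cos_sim_closed_form(delta, d_emb=256):
    # Mean over frequencies of cos(delta * f_i)
    return torch.cos(delta * freqs(d_emb)).mean().item()


results = {}
for t_a, t_b in [(500, 501), (0, 1), (500, 50), (949, 499)]:
    direct = F.cosine_similarity(embed(t_a), embed(t_b)).item()
    closed = cos_sim_closed_form(t_a - t_b)
    results[(t_a, t_b)] = (direct, closed)
    print(f"t={t_a:>3d} vs t={t_b:>3d}: direct {direct:+.6f}   closed form {closed:+.6f}")
```

Pairs with the same gap, (500, 501) and (0, 1), or (500, 50) and (949, 499), get the same similarity.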

This encoding has no learned parameters. The learning happens in the MLP that comes next.

---

## Exercise 2: Timestep MLP [Guided]

The sinusoidal encoding provides a structured, multi-frequency input. But it is not the final embedding -- it is the *input* to a 2-layer MLP:

```
t -> sinusoidal_encoding(t) -> Linear(256, 512) -> GELU -> Linear(512, 512) -> timestep embedding
```

The MLP learns to combine and transform the raw frequency components into features that are useful for denoising. This is a standard pattern: provide a structured input (the sinusoidal encoding), let the network refine it (the MLP). The sinusoidal part has no learned parameters -- all the learning happens in the MLP.

**Before running, predict:**
- After the MLP transforms the embeddings, will t=500 and t=501 still be similar? (The MLP is a smooth function -- does it preserve the smoothness of its input?)
- In the cosine similarity matrix, what pattern do you expect along the diagonal?

In [None]:
class TimestepEmbedding(nn.Module):
    """Full timestep embedding: sinusoidal encoding + 2-layer MLP.

    Pipeline:
      t (integer) -> sinusoidal encoding (256-dim) -> MLP -> embedding (512-dim)

    The sinusoidal encoding provides a structured, multi-frequency input.
    The MLP refines it into features useful for denoising.
    """

    def __init__(self, d_sinusoidal=256, d_emb=512):
        super().__init__()
        self.d_sinusoidal = d_sinusoidal
        self.mlp = nn.Sequential(
            nn.Linear(d_sinusoidal, d_emb),
            nn.GELU(),
            nn.Linear(d_emb, d_emb),
        )

    def forward(self, t):
        """Compute timestep embedding.

        Args:
            t: timesteps [batch_size] (integer tensor)

        Returns:
            embeddings [batch_size, d_emb]
        """
        # Step 1: sinusoidal encoding (no learned parameters)
        sin_emb = sinusoidal_embedding(t, self.d_sinusoidal)
        # Step 2: MLP refinement (learned parameters)
        return self.mlp(sin_emb)


# Create the timestep embedding module
torch.manual_seed(42)
t_embed = TimestepEmbedding(d_sinusoidal=256, d_emb=512)

# Count parameters
n_params = sum(p.numel() for p in t_embed.parameters())
print(f'TimestepEmbedding parameters: {n_params:,}')
print(f'  Linear(256, 512): {256 * 512 + 512:,}  (weight + bias)')
print(f'  Linear(512, 512): {512 * 512 + 512:,}  (weight + bias)')
print(f'  Total:            {256*512 + 512 + 512*512 + 512:,}')
print()
print('Note: the sinusoidal encoding has ZERO parameters.')
print('All learning happens in the MLP.')

In [None]:
# Compute embeddings for a range of timesteps and visualize the similarity structure
# We use the UNTRAINED MLP to see how much of the sinusoidal structure survives
# even a random (untrained) network.

sample_timesteps = torch.arange(0, 1000, 10)  # every 10th timestep

with torch.no_grad():
    # Sinusoidal only (no MLP)
    sin_embeds = sinusoidal_embedding(sample_timesteps, d_emb=256)
    # Sinusoidal + MLP
    mlp_embeds = t_embed(sample_timesteps)

# Compute cosine similarity matrices
sin_normed = F.normalize(sin_embeds, dim=1)
mlp_normed = F.normalize(mlp_embeds, dim=1)

sin_sim_matrix = sin_normed @ sin_normed.T
mlp_sim_matrix = mlp_normed @ mlp_normed.T

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

im1 = ax1.imshow(sin_sim_matrix.numpy(), cmap='RdBu_r', vmin=-1, vmax=1,
                  extent=[0, 1000, 1000, 0])
ax1.set_xlabel('Timestep', fontsize=11)
ax1.set_ylabel('Timestep', fontsize=11)
ax1.set_title('Cosine Similarity: Sinusoidal Only', fontsize=12)
plt.colorbar(im1, ax=ax1, shrink=0.8)

im2 = ax2.imshow(mlp_sim_matrix.numpy(), cmap='RdBu_r', vmin=-1, vmax=1,
                  extent=[0, 1000, 1000, 0])
ax2.set_xlabel('Timestep', fontsize=11)
ax2.set_ylabel('Timestep', fontsize=11)
ax2.set_title('Cosine Similarity: Sinusoidal + MLP (untrained)', fontsize=12)
plt.colorbar(im2, ax=ax2, shrink=0.8)

plt.tight_layout()
plt.show()

print('Left: sinusoidal encoding only. Nearby timesteps are similar (red diagonal band).')
print('      Distant timesteps are less similar (paler off-diagonal regions).')
print()
print('Right: after the MLP (even untrained!). Nearby timesteps STILL look similar.')
print('       Linear layers and GELU are Lipschitz continuous, so inputs that are close')
print('       map to outputs that are close. (The MLP can still compress distances, so')
print('       distant timesteps may look more alike than before.)')
print()
print('After training, the MLP will refine WHICH aspects of the timestep are emphasized,')
print('but the smooth neighborhood structure is baked in by the sinusoidal encoding.')

In [None]:
# Verify specific pairs -- quantitative smoothness check
with torch.no_grad():
    pairs = [(500, 501), (500, 510), (500, 50), (500, 0)]

    print('Cosine similarity between timestep pairs:')
    print(f'{"Pair":>20s}    {"Sinusoidal":>12s}    {"Sin + MLP":>12s}')
    print('-' * 52)

    for t_a, t_b in pairs:
        sin_a = sinusoidal_embedding(torch.tensor([t_a]), d_emb=256)
        sin_b = sinusoidal_embedding(torch.tensor([t_b]), d_emb=256)
        mlp_a = t_embed(torch.tensor([t_a]))
        mlp_b = t_embed(torch.tensor([t_b]))

        sin_s = F.cosine_similarity(sin_a, sin_b).item()
        mlp_s = F.cosine_similarity(mlp_a, mlp_b).item()
        print(f'  t={t_a} vs t={t_b:>4d}    {sin_s:+.4f}         {mlp_s:+.4f}')

print()
print('Key observation: the similarity generally DECREASES as timesteps get further apart.')
print('You should see this trend both before and after the MLP.')
print('The network gets a rich signal about how different two noise levels are.')

### What Just Happened

You built the complete timestep embedding pipeline: sinusoidal encoding followed by a 2-layer MLP.

The cosine similarity matrices confirm two things:
1. **The sinusoidal encoding provides smooth structure.** Nearby timesteps have similar embeddings, distant timesteps have different embeddings. This is the same smoothness property from positional encoding.
2. **The MLP preserves this structure locally.** Even an untrained MLP keeps nearby timesteps nearby: linear layers and GELU are Lipschitz continuous, so a small change in the input produces a bounded change in the output. After training, the MLP will refine *which aspects* of the timestep are emphasized for denoising, but the smooth foundation is baked in by the sinusoidal encoding.

The sinusoidal encoding has zero learned parameters. The MLP has ~394K parameters (394,240). Together they produce a 512-dimensional embedding that the U-Net uses at every residual block.
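The claim that the MLP keeps nearby timesteps nearby can be made precise with a Lipschitz bound. For $\text{MLP}(x) = W_2\,\text{GELU}(W_1 x + b_1) + b_2$, we have $\|\text{MLP}(a) - \text{MLP}(b)\| \le \|W_2\|_2 \cdot L_{\text{GELU}} \cdot \|W_1\|_2 \cdot \|a - b\|$, where $\|\cdot\|_2$ is the largest singular value and $L_{\text{GELU}} \approx 1.13$ bounds the magnitude of GELU's derivative. Here is a sketch on a freshly initialized MLP with the same shapes as `TimestepEmbedding`. Note that the bound only caps how far apart outputs can move. It does not stop the MLP from pulling distant inputs closer together.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Same shapes as TimestepEmbedding's MLP: 256 -> 512 -> 512
mlp = nn.Sequential(nn.Linear(256, 512), nn.GELU(), nn.Linear(512, 512))

# |GELU'(x)| peaks at about 1.129 (at x = sqrt(2)), so 1.13 upper-bounds GELU's Lipschitz constant
GELU_LIP = 1.13

with torch.no_grad():
    spec1 = torch.linalg.matrix_norm(mlp[0].weight, ord=2)  # largest singular value of W1
    spec2 = torch.linalg.matrix_norm(mlp[2].weight, ord=2)  # largest singular value of W2
    lip_bound = (spec2 * GELU_LIP * spec1).item()

    a = torch.randn(1000, 256)
    b = a + 0.01 * torch.randn(1000, 256)                  # small perturbations of a
    ratio = (mlp(a) - mlp(b)).norm(dim=1) / (a - b).norm(dim=1)

print(f"Lipschitz upper bound:  {lip_bound:.3f}")
print(f"largest observed ratio: {ratio.max().item():.3f}")
```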

---

## Exercise 3: Adaptive Group Normalization [Supported]

You know how standard group normalization works: normalize the features, then apply a learned scale $\gamma$ and shift $\beta$. Those are **fixed parameters** -- the same $\gamma$ and $\beta$ for every input, every timestep.

Adaptive group normalization makes one change: $\gamma$ and $\beta$ **depend on the timestep**.

$$\text{AdaGN}(x, t) = \gamma(t) \cdot \text{GroupNorm}(x) + \beta(t)$$

where $[\gamma(t), \beta(t)] = \text{Linear}(\text{emb}_t)$

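The normalization step is standard. The scale and shift come from a linear projection of the timestep embedding rather than a fixed parameter table. In the implementation below, the projection predicts an offset from 1, so the layer computes $(1 + \gamma(t)) \cdot \text{GroupNorm}(x) + \beta(t)$; then $\gamma(t) = 0$ means "no rescaling."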
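To see that the normalization step really is ordinary group norm, here is a sketch that computes it by hand and compares it to `F.group_norm`. It splits the channels into groups, subtracts each group's mean, and divides by its standard deviation. It then applies the scale and shift as in the formula above, with random tensors standing in for $\gamma(t)$ and $\beta(t)$. The helper name is illustrative.

```python
import torch
import torch.nn.functional as F


def group_norm_manual(x, num_groups, eps=1e-5):
    """Group norm without affine parameters, written out step by step."""
    B, C, H, W = x.shape
    g = x.view(B, num_groups, C // num_groups, H, W)          # split channels into groups
    mean = g.mean(dim=(2, 3, 4), keepdim=True)                  # one mean per (sample, group)
    var = ((g - mean) ** 2).mean(dim=(2, 3, 4), keepdim=True)   # biased variance, as GroupNorm uses
    return ((g - mean) / torch.sqrt(var + eps)).view(B, C, H, W)


torch.manual_seed(0)
B, C, groups = 2, 64, 32
x = torch.randn(B, C, 8, 8)
h = group_norm_manual(x, groups)
print(torch.allclose(h, F.group_norm(x, groups), atol=1e-5))  # True

# Stand-ins for gamma(t) and beta(t): one value per (sample, channel)
gamma = torch.randn(B, C)
beta = torch.randn(B, C)
out = gamma[:, :, None, None] * h + beta[:, :, None, None]    # broadcast over H and W
print(out.shape)  # torch.Size([2, 64, 8, 8])
```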

**Your task:** Implement the `AdaGN` module. The structure is provided with `# TODO` markers. Each TODO is 1--2 lines.

**Hints:**
- `nn.GroupNorm(num_groups, num_channels, affine=False)` gives you group norm WITHOUT the learned gamma/beta (we replace those with timestep-dependent ones).
- The linear projection maps from `d_emb` to `2 * num_channels` -- half for gamma, half for beta.
- Use `.chunk(2, dim=-1)` to split the projection output into gamma and beta.
- Remember to reshape gamma and beta for broadcasting: `[batch, channels]` -> `[batch, channels, 1, 1]`.

In [None]:
class AdaGN(nn.Module):
    """Adaptive Group Normalization.

    Standard group norm + timestep-dependent scale (gamma) and shift (beta).
    The normalization is standard. The adaptation is in where gamma and beta
    come from: not from a fixed parameter table, but from a linear projection
    of the timestep embedding.

    Args:
        num_channels: number of feature map channels
        d_emb: dimension of the timestep embedding
        num_groups: number of groups for group norm (default 32)
    """

    def __init__(self, num_channels, d_emb, num_groups=32):
        super().__init__()

        # Group norm WITHOUT learned affine parameters (affine=False)
        # We replace gamma/beta with timestep-dependent versions
        # TODO: Create self.norm using nn.GroupNorm with affine=False
        # Hint: nn.GroupNorm(num_groups, num_channels, affine=False)


        # Linear projection: timestep embedding -> [gamma, beta]
        # Maps from d_emb to 2 * num_channels (half for gamma, half for beta)
        # TODO: Create self.proj using nn.Linear
        # Hint: nn.Linear(d_emb, 2 * num_channels)


    def forward(self, x, t_emb):
        """Apply adaptive group normalization.

        Args:
            x: feature maps [batch, channels, height, width]
            t_emb: timestep embedding [batch, d_emb]

        Returns:
            modulated features [batch, channels, height, width]
        """
        # Step 1: standard group norm (normalize only, no affine)
        h = self.norm(x)

        # Step 2: project timestep embedding to gamma and beta
        # TODO: Use self.proj to get gamma_beta from t_emb, then split
        # Hint: gamma, beta = self.proj(t_emb).chunk(2, dim=-1)


        # Step 3: reshape for broadcasting [batch, channels] -> [batch, channels, 1, 1]
        gamma = gamma.unsqueeze(-1).unsqueeze(-1)
        beta = beta.unsqueeze(-1).unsqueeze(-1)

        # Step 4: apply scale and shift
        # Standard norm uses: gamma * norm(x) + beta
        # We add 1 to gamma so that when the projection outputs are near zero,
        # the layer reduces to plain group norm: (1 + 0) * norm(x) + 0 = norm(x)
        # TODO: Return (1 + gamma) * h + beta


print('AdaGN class defined. Fill in the TODO markers before running the next cell.')

<details>
<summary>Solution</summary>

The three fills implement the adaptive group normalization mechanism. The key insight: the normalization is standard group norm. The only new part is where gamma and beta come from.

**`__init__` -- Group norm and projection:**
```python
self.norm = nn.GroupNorm(num_groups, num_channels, affine=False)
self.proj = nn.Linear(d_emb, 2 * num_channels)
```
`affine=False` tells PyTorch not to learn gamma/beta parameters -- we are replacing them with timestep-dependent versions. The linear projection maps the 512-dim timestep embedding to `2 * num_channels` values: half become gamma (scale), half become beta (shift).

**`forward` -- Project and split:**
```python
gamma, beta = self.proj(t_emb).chunk(2, dim=-1)
```
One linear layer produces both gamma and beta. `.chunk(2, dim=-1)` splits the output in half along the last dimension. This is the same pattern as producing Q and K from one projection.

**`forward` -- Apply modulation:**
```python
return (1 + gamma) * h + beta
```
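Adding 1 to gamma means that when gamma is near 0, the layer behaves like plain group norm, an identity-like default. Without the `+ 1`, a near-zero gamma would shrink the normalized features toward zero at the start of training, weakening the signal that reaches later layers and making training harder.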

</details>
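The identity-like behavior of `(1 + gamma) * h + beta` holds exactly only when the projection outputs zero. With PyTorch's default `nn.Linear` initialization, the outputs are small but nonzero. A common refinement, similar in spirit to the "adaLN-Zero" initialization described for Diffusion Transformers (DiT), is to zero-initialize the projection so the layer starts out as plain group norm. Here is a sketch, with a hypothetical class name:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ZeroInitAdaGN(nn.Module):
    """Hypothetical AdaGN variant whose gamma/beta projection starts at exactly zero."""

    def __init__(self, num_channels, d_emb, num_groups=32):
        super().__init__()
        self.norm = nn.GroupNorm(num_groups, num_channels, affine=False)
        self.proj = nn.Linear(d_emb, 2 * num_channels)
        nn.init.zeros_(self.proj.weight)  # gamma(t) = beta(t) = 0 for every t at init
        nn.init.zeros_(self.proj.bias)

    def forward(self, x, t_emb):
        gamma, beta = self.proj(t_emb).chunk(2, dim=-1)
        h = self.norm(x)
        return (1 + gamma[:, :, None, None]) * h + beta[:, :, None, None]


torch.manual_seed(0)
layer = ZeroInitAdaGN(num_channels=64, d_emb=512)
x = torch.randn(2, 64, 8, 8)
t_emb = torch.randn(2, 512)
out = layer(x, t_emb)
print(torch.allclose(out, F.group_norm(x, 32)))  # True: exactly plain group norm at init
```

Gradients still reach `proj.weight`, because they depend on the normalized features and the embedding, not on the current weights. Training can therefore move gamma and beta away from zero as needed.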

In [None]:
# Test the AdaGN module: same feature map, different timesteps -> different outputs

torch.manual_seed(42)

# Create the modules
adagn = AdaGN(num_channels=64, d_emb=512, num_groups=32)
t_embed_test = TimestepEmbedding(d_sinusoidal=256, d_emb=512)

# Create a dummy feature map (same for both timesteps)
feature_map = torch.randn(1, 64, 16, 16)  # [batch=1, channels=64, h=16, w=16]

# Compute timestep embeddings for two different timesteps
with torch.no_grad():
    emb_500 = t_embed_test(torch.tensor([500]))  # [1, 512]
    emb_50 = t_embed_test(torch.tensor([50]))     # [1, 512]

    # Apply AdaGN with different timesteps
    out_500 = adagn(feature_map, emb_500)  # same feature map, t=500
    out_50 = adagn(feature_map, emb_50)    # same feature map, t=50

print(f'Input feature map shape:  {feature_map.shape}')
print(f'Output shape (t=500):     {out_500.shape}')
print(f'Output shape (t=50):      {out_50.shape}')
print()

# Are the outputs different?
diff = (out_500 - out_50).abs()
print(f'Are outputs identical?    {torch.allclose(out_500, out_50)}')
print(f'Mean absolute difference: {diff.mean():.4f}')
print(f'Max absolute difference:  {diff.max():.4f}')
print()
print('Same feature map, same normalization, same projection weights.')
print('Different timestep -> different gamma(t) and beta(t) -> different output.')
print('This is the core mechanism: the network\'s behavior changes with the noise level.')

In [None]:
# Visualize what gamma(t) and beta(t) look like for different timesteps
# This shows the "conductor's score" -- different dynamics for different timesteps

with torch.no_grad():
    test_timesteps = [0, 250, 500, 750, 999]
    gammas = []
    betas = []

    for t_val in test_timesteps:
        emb = t_embed_test(torch.tensor([t_val]))
        gamma_beta = adagn.proj(emb)  # [1, 2 * 64]
        gamma, beta = gamma_beta.chunk(2, dim=-1)  # each [1, 64]
        gammas.append(gamma.squeeze().numpy())
        betas.append(beta.squeeze().numpy())

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 4))

for i, t_val in enumerate(test_timesteps):
    ax1.plot(gammas[i][:32], label=f't={t_val}', alpha=0.8, linewidth=1.5)
    ax2.plot(betas[i][:32], label=f't={t_val}', alpha=0.8, linewidth=1.5)

ax1.set_xlabel('Channel index (first 32)', fontsize=11)
ax1.set_ylabel('gamma(t)', fontsize=11)
ax1.set_title('Scale parameters per channel (different timesteps)', fontsize=12)
ax1.legend(fontsize=9)
ax1.axhline(y=0, color='gray', linestyle='--', alpha=0.3)

ax2.set_xlabel('Channel index (first 32)', fontsize=11)
ax2.set_ylabel('beta(t)', fontsize=11)
ax2.set_title('Shift parameters per channel (different timesteps)', fontsize=12)
ax2.legend(fontsize=9)
ax2.axhline(y=0, color='gray', linestyle='--', alpha=0.3)

plt.tight_layout()
plt.show()

print('Each timestep produces DIFFERENT gamma and beta values for each channel.')
print('This is the conductor\'s score: different measures (timesteps) call for')
print('different dynamics (scale) and different keys (shift).')
print()
print('Note: these are from an UNTRAINED network, so the patterns are random.')
print('After training, the network could learn meaningful gamma/beta patterns, for example:')
print('  - At high t: amplify channels that capture coarse structure')
print('  - At low t: amplify channels that capture fine detail')

### What Just Happened

You implemented adaptive group normalization -- the mechanism that lets the U-Net change its behavior based on the noise level. The key points:

1. **The normalization is standard group norm.** Nothing changes about how features are normalized (zero mean, unit variance within each group).
2. **The only change is where gamma and beta come from.** Instead of a fixed parameter table, they come from a linear projection of the timestep embedding. One line of conceptual change.
3. **Same feature map + different timestep = different output.** The conv weights and the architecture are the same for every timestep. What changes is the scale and shift applied after normalization, and that changes the effective behavior of the network.

In the full U-Net, every residual block has its own AdaGN module with its own linear projection. Same timestep embedding, different "learned lens" per block -- the same pattern from Q/K/V in attention.
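Here is a sketch of how this looks in code: several residual blocks, each with its own AdaGN projections, all fed the same timestep embedding. The block layout (conv, AdaGN, SiLU, twice, plus a skip connection) is a simplified assumption, and the class names are hypothetical. Real U-Nets differ in ordering and details.

```python
import torch
import torch.nn as nn


class AdaGNSketch(nn.Module):
    """AdaGN as in the exercise: group norm, then (1 + gamma(t)) scale and beta(t) shift."""

    def __init__(self, channels, d_emb, num_groups=32):
        super().__init__()
        self.norm = nn.GroupNorm(num_groups, channels, affine=False)
        self.proj = nn.Linear(d_emb, 2 * channels)  # this layer's own "lens" on the embedding

    def forward(self, x, t_emb):
        gamma, beta = self.proj(t_emb).chunk(2, dim=-1)
        return (1 + gamma[:, :, None, None]) * self.norm(x) + beta[:, :, None, None]


class ResBlockSketch(nn.Module):
    """Hypothetical residual block: (conv -> AdaGN -> SiLU) twice, plus a skip connection."""

    def __init__(self, channels, d_emb):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.norm1 = AdaGNSketch(channels, d_emb)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.norm2 = AdaGNSketch(channels, d_emb)
        self.act = nn.SiLU()

    def forward(self, x, t_emb):
        h = self.act(self.norm1(self.conv1(x), t_emb))
        h = self.act(self.norm2(self.conv2(h), t_emb))
        return x + h


torch.manual_seed(0)
blocks = nn.ModuleList([ResBlockSketch(64, d_emb=512) for _ in range(2)])
x = torch.randn(1, 64, 16, 16)
t_emb = torch.randn(1, 512)  # stand-in for the output of TimestepEmbedding

h = x
for block in blocks:
    h = block(h, t_emb)      # every block receives the SAME embedding
print(h.shape)  # torch.Size([1, 64, 16, 16])
```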

---

## Exercise 4: Compare Simple vs Sinusoidal Conditioning [Independent]

In *Build a Diffusion Model*, the capstone used a simple approach: normalize t to [0, 1], pass through a 2-layer MLP, add to the bottleneck. That "worked" for MNIST. This exercise tests whether the sinusoidal approach is actually better.
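A quick way to compare the two starting points is to look at the raw inputs each embedding MLP receives for adjacent timesteps. Here is a sketch; the helper names are illustrative, and the frequency formula matches Exercise 1.

```python
import math
import torch

NUM_STEPS = 1000


def simple_input(t):
    # Capstone-style input: the single scalar t / T
    return torch.tensor([[t / NUM_STEPS]], dtype=torch.float32)


def sinusoidal_input(t, d_emb=256):
    # Same frequencies as sinusoidal_embedding: f_i = 1 / 10000^(2i/d_emb)
    half_d = d_emb // 2
    i = torch.arange(half_d, dtype=torch.float32)
    freq = torch.exp(-i * (math.log(10000.0) / half_d))
    angles = float(t) * freq
    return torch.cat([torch.sin(angles), torch.cos(angles)]).unsqueeze(0)


results = {}
for a, b in [(500, 501), (10, 11), (990, 991)]:
    d_simple = (simple_input(a) - simple_input(b)).norm().item()
    d_sin = (sinusoidal_input(a) - sinusoidal_input(b)).norm().item()
    results[(a, b)] = (d_simple, d_sin)
    print(f"t={a} vs t={b}: scalar input distance {d_simple:.4f}, sinusoidal input distance {d_sin:.4f}")
```

Raw distances alone do not decide the outcome, since a learned layer can rescale its input. The more important difference is structural. Before its ReLU, the simple model's first `Linear(1, d)` places every timestep on one line $w \cdot (t/T) + b$. The sinusoidal input, by contrast, varies along many directions at many frequencies.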

**Your task:**
1. A minimal U-Net with the capstone's **simple** timestep embedding (normalize t to [0, 1], then a 2-layer MLP) is provided below.
2. Create a second version that replaces the simple embedding with **sinusoidal + MLP** conditioning.
3. Train both for 10 epochs on MNIST and compare loss curves.

**What to change:** Replace the `SimpleTimestepEmbed` module with the `TimestepEmbedding` (sinusoidal + MLP) you built in Exercise 2, and adjust the embedding injection accordingly.
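When you adjust the injection, build a new tensor instead of writing into a slice of the bottleneck in place (for example `b[:, :128] = b[:, :128] + t_emb_spatial`). The bottleneck ends in a ReLU, and PyTorch's ReLU saves its output for the backward pass. Modifying that output in place therefore makes `loss.backward()` raise a `RuntimeError`. Here is a toy reproduction with an out-of-place alternative:

```python
import torch

x = torch.randn(2, 4, requires_grad=True)

# In-place: overwrite part of a ReLU output. ReLU saves its output for the backward pass,
# so modifying it in place invalidates that saved tensor.
b = torch.relu(x)
b[:, :2] = b[:, :2] + 1.0
try:
    b.sum().backward()
    inplace_ok = True
except RuntimeError as err:
    inplace_ok = False
    print("in-place version failed:", str(err).splitlines()[0])

# Out-of-place: build a new tensor from slices instead of modifying b
b = torch.relu(x)
b_new = torch.cat([b[:, :2] + 1.0, b[:, 2:]], dim=1)
b_new.sum().backward()
print("out-of-place version OK, grad shape:", tuple(x.grad.shape))
```

Adding the embedding to *all* channels with `b = b + t_spatial` is also out-of-place and avoids the issue.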

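**Expected result:** The sinusoidal version should converge faster or reach a lower loss, although the size of the gap can vary from run to run. The sinusoidal encoding gives the network a richer, multi-frequency starting point. The simple version, by contrast, starts from a single scalar that its first layer can only stretch along one learned direction.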

**Note:** This exercise requires GPU for reasonable training times. If you are on CPU, read the solution and discussion instead.

In [None]:
# =====================================================================
# Shared training infrastructure
# =====================================================================

import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import time

# Diffusion hyperparameters (same as capstone)
T = 1000
beta_min = 0.0001
beta_max = 0.02
betas = torch.linspace(beta_min, beta_max, T)
alphas = 1.0 - betas
alpha_bars = torch.cumprod(alphas, dim=0)


def q_sample(x_0, t, alpha_bars, noise=None):
    """Forward process: x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps"""
    if noise is None:
        noise = torch.randn_like(x_0)
    alpha_bar_t = alpha_bars[t].view(-1, 1, 1, 1)
    return torch.sqrt(alpha_bar_t) * x_0 + torch.sqrt(1.0 - alpha_bar_t) * noise, noise


def train_epoch(model, dataloader, optimizer, alpha_bars, T, device):
    """Train for one epoch. Returns average loss."""
    model.train()
    total_loss = 0.0
    num_batches = 0
    for batch_images, _ in dataloader:
        x_0 = batch_images.to(device)
        batch_size = x_0.shape[0]
        t = torch.randint(0, T, (batch_size,), device=device)
        noise = torch.randn_like(x_0)
        x_t, noise = q_sample(x_0, t, alpha_bars, noise=noise)
        epsilon_hat = model(x_t, t)
        loss = F.mse_loss(epsilon_hat, noise)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        num_batches += 1
    return total_loss / num_batches


# Load MNIST
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])
dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
dataloader = DataLoader(dataset, batch_size=128, shuffle=True, drop_last=True)

alpha_bars_device = alpha_bars.to(device)

print(f'Dataset: {len(dataset)} images, {len(dataloader)} batches per epoch')
print(f'Device: {device}')
print('Training infrastructure ready.')

In [None]:
# =====================================================================
# VERSION A: Simple timestep embedding (the capstone approach: normalize + MLP)
# =====================================================================

class ConvBlock(nn.Module):
    """Two convolutions with BatchNorm and ReLU."""
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
        )
    def forward(self, x):
        return self.conv(x)


class SimpleTimestepEmbed(nn.Module):
    """The capstone's simple approach: normalize t to [0,1], 2-layer MLP."""
    def __init__(self, d_emb=128):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(1, d_emb),
            nn.ReLU(),
            nn.Linear(d_emb, d_emb),
        )
    def forward(self, t):
        t_normalized = t.float().unsqueeze(1) / T  # [batch, 1]
        return self.mlp(t_normalized)  # [batch, d_emb]


class UNetSimple(nn.Module):
    """Minimal U-Net with SIMPLE timestep embedding (normalize + MLP).
    Same architecture as the capstone: timestep added to bottleneck only."""

    def __init__(self):
        super().__init__()
        self.time_embed = SimpleTimestepEmbed(d_emb=128)

        self.enc1 = ConvBlock(1, 32)
        self.pool1 = nn.MaxPool2d(2)
        self.enc2 = ConvBlock(32, 64)
        self.pool2 = nn.MaxPool2d(2)
        self.enc3 = ConvBlock(64, 128)
        self.bottleneck = ConvBlock(128, 256)

        self.up3 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.dec3 = ConvBlock(128 + 64, 128)
        self.up2 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.dec2 = ConvBlock(64 + 32, 64)

        self.final = nn.Sequential(
            nn.Conv2d(64, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32), nn.ReLU(),
            nn.Conv2d(32, 1, kernel_size=1),
        )

    def forward(self, x, t):
        t_emb = self.time_embed(t)  # [B, 128]

        e1 = self.enc1(x)
        e2 = self.enc2(self.pool1(e1))
        e3 = self.enc3(self.pool2(e2))
        b = self.bottleneck(e3)

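        # Simple approach: add timestep to the first 128 bottleneck channels only.
        # Build a new tensor rather than writing into b in place: b is a ReLU
        # output, which autograd saves for the backward pass.
        t_emb_spatial = t_emb.view(-1, 128, 1, 1)
        b = torch.cat([b[:, :128] + t_emb_spatial, b[:, 128:]], dim=1)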

        d3 = self.dec3(torch.cat([self.up3(b), e2], dim=1))
        d2 = self.dec2(torch.cat([self.up2(d3), e1], dim=1))
        return self.final(d2)


print('UNetSimple defined (capstone-style simple timestep embedding).')
print(f'Parameters: {sum(p.numel() for p in UNetSimple().parameters()):,}')

In [None]:
# =====================================================================
# VERSION B: Your task -- sinusoidal timestep embedding
# =====================================================================
#
# Create UNetSinusoidal: same architecture as UNetSimple, but replace
# SimpleTimestepEmbed with the sinusoidal + MLP pipeline from Exercise 2.
#
# Changes needed:
# 1. Use TimestepEmbedding (sinusoidal + MLP) instead of SimpleTimestepEmbed
# 2. The embedding dimension changes from 128 to 512
# 3. Update the bottleneck injection to match the new dimension
#
# Write the full class below. Use UNetSimple as your starting point.
# The only changes are in the timestep embedding and how it is injected.

# YOUR CODE HERE: define class UNetSinusoidal


print('Define UNetSinusoidal above this line (in this cell) before continuing.')

<details>
<summary>Solution</summary>

The key insight: the only changes are (1) swapping the timestep embedding module and (2) adjusting the injection dimensions. The spatial architecture (encoder, decoder, skip connections) is completely unchanged.

```python
class UNetSinusoidal(nn.Module):
    """Minimal U-Net with SINUSOIDAL timestep embedding.
    Same spatial architecture as UNetSimple; only the timestep
    embedding and its injection change."""

    def __init__(self):
        super().__init__()
        # Sinusoidal + MLP instead of simple linear
        self.time_embed = TimestepEmbedding(d_sinusoidal=256, d_emb=512)
        # Project 512-dim embedding to 256 channels to match bottleneck
        self.time_proj = nn.Linear(512, 256)

        self.enc1 = ConvBlock(1, 32)
        self.pool1 = nn.MaxPool2d(2)
        self.enc2 = ConvBlock(32, 64)
        self.pool2 = nn.MaxPool2d(2)
        self.enc3 = ConvBlock(64, 128)
        self.bottleneck = ConvBlock(128, 256)

        self.up3 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.dec3 = ConvBlock(128 + 64, 128)
        self.up2 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.dec2 = ConvBlock(64 + 32, 64)

        self.final = nn.Sequential(
            nn.Conv2d(64, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32), nn.ReLU(),
            nn.Conv2d(32, 1, kernel_size=1),
        )

    def forward(self, x, t):
        t_emb = self.time_embed(t)          # [B, 512]  (sinusoidal + MLP)
        t_proj = self.time_proj(t_emb)      # [B, 256]  (project to bottleneck size)

        e1 = self.enc1(x)
        e2 = self.enc2(self.pool1(e1))
        e3 = self.enc3(self.pool2(e2))
        b = self.bottleneck(e3)

        # Add timestep to ALL bottleneck channels (256-dim projection)
        t_spatial = t_proj.view(-1, 256, 1, 1)
        b = b + t_spatial

        d3 = self.dec3(torch.cat([self.up3(b), e2], dim=1))
        d2 = self.dec2(torch.cat([self.up2(d3), e1], dim=1))
        return self.final(d2)
```

The spatial architecture is identical. The changes:
1. `TimestepEmbedding` replaces `SimpleTimestepEmbed` -- sinusoidal encoding + MLP instead of normalize + MLP
2. A `time_proj` linear layer maps the 512-dim embedding to 256 channels (matching the bottleneck width)
3. The embedding is added to ALL 256 bottleneck channels instead of just the first 128
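
A shape-level sketch of the two injection strategies, using random tensors in place of real activations. The batch size is arbitrary, and the 7x7 bottleneck assumes 28x28 MNIST inputs after two 2x poolings. It also shows one way to reach only the first 128 channels without an in-place slice write, which keeps autograd happy when the bottleneck output is saved for backward.

```python
import torch

B, H, W = 4, 7, 7                       # illustrative batch size; 7x7 = 28x28 after two 2x pools
b = torch.randn(B, 256, H, W)           # stand-in for the bottleneck activation

# UNetSimple style: a 128-dim embedding reaches only the first 128 channels.
t_emb = torch.randn(B, 128)
first = b[:, :128] + t_emb.view(-1, 128, 1, 1)      # broadcast over H and W
b_simple = torch.cat([first, b[:, 128:]], dim=1)    # new tensor, no in-place write

# UNetSinusoidal style: a 256-dim projection reaches every bottleneck channel.
t_proj = torch.randn(B, 256)
b_sin = b + t_proj.view(-1, 256, 1, 1)

print(b_simple.shape, b_sin.shape)
# Channels 128..255 carry no timestep information in the simple version
print(torch.equal(b_simple[:, 128:], b[:, 128:]))
```

In both cases every spatial position in a channel receives the same offset; the difference is only how many channels see the timestep.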

The sinusoidal encoding gives the network a richer starting representation for the timestep. Instead of a single normalized scalar that the MLP must stretch into 128 dimensions, the sinusoidal encoding provides 256 features (a sine and a cosine at each of 128 frequencies) that spread the timestep information across many dimensions from the start.
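
A minimal sketch of the two starting representations. It assumes the simple embedding starts from the normalized scalar `t / T` with an assumed `T = 1000`, and it uses a hypothetical `sinusoidal_embedding` helper in the transformer layout (all sines, then all cosines, base 10000). The notebook's own implementation may order the dimensions differently, which does not change the point.

```python
import math
import torch

def sinusoidal_embedding(t, dim=256, max_period=10000.0):
    """Hypothetical helper: transformer-style sinusoidal encoding of integer timesteps.
    Returns [B, dim]; the first dim//2 columns are sines, the rest cosines."""
    half = dim // 2
    # 128 frequencies decaying geometrically from 1 toward 1/max_period
    freqs = torch.exp(-math.log(max_period) * torch.arange(half, dtype=torch.float32) / half)
    args = t.float()[:, None] * freqs[None, :]                     # [B, half]
    return torch.cat([torch.sin(args), torch.cos(args)], dim=1)    # [B, dim]

T = 1000                                   # assumed number of diffusion steps
t = torch.tensor([0, 50, 500, 999])

simple_input = (t.float() / T)[:, None]    # assumed input of the normalize-then-MLP embedding
sin_input = sinusoidal_embedding(t)        # input of the sinusoidal + MLP embedding

print(simple_input.shape)   # torch.Size([4, 1])
print(sin_input.shape)      # torch.Size([4, 256])

# Periods (in timesteps) of the 128 frequencies: from 2*pi up to almost 2*pi*10000
periods = 2 * math.pi * 10000.0 ** (torch.arange(128, dtype=torch.float32) / 128)
print(f'shortest period: {periods.min().item():.2f} steps, longest: {periods.max().item():.0f} steps')
```

The MLP on the simple path sees one number per sample; the MLP on the sinusoidal path sees 256 bounded values that already vary on many time scales.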

</details>

In [None]:
# =====================================================================
# Train both versions and compare
# =====================================================================

import time  # for per-epoch wall-clock timing

N_EPOCHS = 10

# --- Version A: Simple ---
torch.manual_seed(42)
model_simple = UNetSimple().to(device)
opt_simple = torch.optim.Adam(model_simple.parameters(), lr=2e-4)

print(f'UNetSimple parameters:     {sum(p.numel() for p in model_simple.parameters()):,}')

# --- Version B: Sinusoidal ---
torch.manual_seed(42)
model_sin = UNetSinusoidal().to(device)
opt_sin = torch.optim.Adam(model_sin.parameters(), lr=2e-4)

print(f'UNetSinusoidal parameters: {sum(p.numel() for p in model_sin.parameters()):,}')
print()

losses_simple = []
losses_sin = []

print(f'Training both models for {N_EPOCHS} epochs...')
print(f'{"Epoch":>6s}  {"Simple":>10s}  {"Sinusoidal":>12s}  {"Time":>8s}')
print('-' * 42)

for epoch in range(1, N_EPOCHS + 1):
    start = time.time()

    loss_a = train_epoch(model_simple, dataloader, opt_simple, alpha_bars_device, T, device)
    loss_b = train_epoch(model_sin, dataloader, opt_sin, alpha_bars_device, T, device)

    elapsed = time.time() - start
    losses_simple.append(loss_a)
    losses_sin.append(loss_b)

    print(f'{epoch:>6d}  {loss_a:>10.4f}  {loss_b:>12.4f}  {elapsed:>7.1f}s')

print('\nTraining complete.')

In [None]:
# Plot the comparison
plt.figure(figsize=(10, 5))
plt.plot(range(1, N_EPOCHS + 1), losses_simple, 'o-', color='#f87171',
         linewidth=2, markersize=6, label='Simple (normalize + MLP)')
plt.plot(range(1, N_EPOCHS + 1), losses_sin, 's-', color='#86efac',
         linewidth=2, markersize=6, label='Sinusoidal + MLP')
plt.xlabel('Epoch', fontsize=12)
plt.ylabel('Average Loss', fontsize=12)
plt.title('Timestep Embedding Comparison: Simple vs Sinusoidal', fontsize=13)
plt.legend(fontsize=11)
plt.tight_layout()
plt.show()

# Summary statistics
print(f'Final loss (Simple):      {losses_simple[-1]:.4f}')
print(f'Final loss (Sinusoidal):  {losses_sin[-1]:.4f}')
print(f'Difference:               {losses_simple[-1] - losses_sin[-1]:+.4f}')
print()

if losses_sin[-1] < losses_simple[-1]:
    print('The sinusoidal version achieves lower loss.')
    print('Why? The sinusoidal encoding spreads timestep information across 256 features')
    print('(sine and cosine at 128 frequencies) from the start. The simple approach feeds a')
    print('single scalar to an MLP, which must learn ALL the useful representations from scratch.')
    print()
    print('Analogy from the lesson: describing a GPS position with latitude + longitude + altitude')
    print('(sinusoidal) vs a single distance number (simple). The richer input gives the')
    print('network much more to work with.')
else:
    print('Interesting -- the simple version matched or beat sinusoidal on this run.')
    print('This can happen on MNIST: the data is easy enough that the simple embedding suffices,')
    print(f'and a single {N_EPOCHS}-epoch run has run-to-run variation. The advantage of sinusoidal')
    print('encoding becomes clearer at scale (higher-resolution images, more complex distributions).')
    print()
    print('Even so, the sinusoidal approach has structural advantages: guaranteed smoothness')
    print('between adjacent timesteps, no dependence on learned representations for the')
    print('basic encoding, and a richer input for the MLP to refine.')

### What Just Happened

You compared the capstone's simple timestep embedding to the sinusoidal + MLP approach on the same architecture and data. The sinusoidal version typically converges faster or to a lower loss because:

1. **Richer input representation.** The sinusoidal encoding provides 256 features (a sine and a cosine at each of 128 frequencies) from the start. The simple approach starts from a single scalar, so the MLP must learn all structure from scratch.
2. **Built-in smoothness.** Adjacent timesteps automatically get similar embeddings. The simple approach must learn this smoothness.
3. **Multi-frequency discrimination.** High-frequency components distinguish t=500 from t=501. Low-frequency components distinguish t=500 from t=50. The simple approach has a single number to encode both fine and coarse distinctions.
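
A quick numerical check of points 2 and 3, assuming a transformer-style encoding with 256 dimensions and base 10000 (the helper name is hypothetical; 256 matches `d_sinusoidal=256` above). For this encoding, the dot product of two embeddings is the sum of cos(frequency x gap) over frequencies, so it depends only on the gap between timesteps, and a one-step change lands almost entirely in the high-frequency dimensions.

```python
import math
import torch
import torch.nn.functional as F

def sinusoidal_embedding(t, dim=256, max_period=10000.0):
    """Hypothetical helper: sines for dim//2 frequencies, then cosines for the same frequencies."""
    half = dim // 2
    freqs = torch.exp(-math.log(max_period) * torch.arange(half, dtype=torch.float32) / half)
    args = t.float()[:, None] * freqs[None, :]
    return torch.cat([torch.sin(args), torch.cos(args)], dim=1)

emb = sinusoidal_embedding(torch.tensor([500, 501, 50]))
e500, e501, e50 = emb[0], emb[1], emb[2]

# Smoothness: adjacent timesteps are far more similar than distant ones
sim_near = F.cosine_similarity(e500, e501, dim=0).item()
sim_far = F.cosine_similarity(e500, e50, dim=0).item()
print(f'cos(500, 501) = {sim_near:.3f}   cos(500, 50) = {sim_far:.3f}')

# Multi-frequency: where does a one-step change land?
half = 128
diff = e501 - e500
# Combine the sine and cosine columns that share a frequency (index 0 = highest frequency)
per_freq = torch.sqrt(diff[:half] ** 2 + diff[half:] ** 2)
print(f'change in 64 highest frequencies: {per_freq[:64].sum().item():.3f}')
print(f'change in 64 lowest frequencies:  {per_freq[64:].sum().item():.3f}')
```

Going from t=500 to t=50 instead moves the low-frequency dimensions substantially, while the high-frequency ones have wrapped around many times; together the two bands separate both fine and coarse differences.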

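On MNIST (a simple dataset), the gap may be small. The comparison is also not perfectly controlled: UNetSinusoidal conditions all 256 bottleneck channels instead of 128, and the two models may differ in parameter count (compare the counts printed by the training cell). On larger diffusion tasks (for example 256x256 or 512x512 images), sinusoidal + MLP timestep embeddings are the standard choice; widely used diffusion U-Nets such as DDPM and Stable Diffusion use them.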

---

## Key Takeaways

1. **Sinusoidal timestep embedding is the same formula as positional encoding.** Replace "position" with "timestep" and you have the diffusion timestep embedding. The multi-frequency encoding provides smooth, unique patterns for each noise level. No learned parameters in the encoding itself.

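2. **The MLP refines but preserves structure.** The 2-layer MLP transforms the raw frequencies into features useful for denoising. Because the MLP is a continuous (indeed Lipschitz) function, its output changes by at most a bounded multiple of the input change, so adjacent timesteps, whose sinusoidal inputs are already close, still get similar embeddings after the MLP.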

3. **Adaptive group normalization makes gamma and beta depend on the timestep.** The normalization is standard. The scale and shift are timestep-dependent, computed by a per-block linear projection. Same architecture, same weights -- different behavior at different noise levels.

4. **Sinusoidal encoding usually outperforms the capstone's simple embedding.** The richer, multi-frequency input gives the network more to work with, although on an easy dataset like MNIST the gap can be small. The built-in smoothness means adjacent timesteps are similar by construction, not by learning.

5. **The mental model: same orchestra, different conductor's score.** The U-Net weights are the instruments. The timestep embedding is the conductor's score. Adaptive normalization is how the conductor communicates dynamics (scale) and key (shift) to each section. Different measures (timesteps) produce different performances from the same orchestra.
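
A minimal sketch of takeaway 3. The class name `AdaGroupNorm`, the 8 groups, the 512-dim embedding, and the `(1 + gamma)` scale are illustrative assumptions; the version implemented earlier in this notebook may differ in details such as where the projection sits or whether gamma is offset by 1.

```python
import torch
import torch.nn as nn

class AdaGroupNorm(nn.Module):
    """Hypothetical sketch: group norm whose scale and shift come from the timestep embedding."""

    def __init__(self, channels, emb_dim, num_groups=8):
        super().__init__()
        # Standard normalization, with no fixed per-channel gamma/beta of its own
        self.norm = nn.GroupNorm(num_groups, channels, affine=False)
        # Per-block projection: timestep embedding -> gamma and beta for every channel
        self.proj = nn.Linear(emb_dim, 2 * channels)

    def forward(self, x, t_emb):
        gamma, beta = self.proj(t_emb).chunk(2, dim=1)   # each [B, C]
        gamma = gamma[:, :, None, None]                  # [B, C, 1, 1], broadcast over H, W
        beta = beta[:, :, None, None]
        return self.norm(x) * (1 + gamma) + beta

torch.manual_seed(0)
block = AdaGroupNorm(channels=64, emb_dim=512)
x = torch.randn(2, 64, 14, 14)
emb_a, emb_b = torch.randn(1, 512), torch.randn(1, 512)  # stand-ins for two timesteps' embeddings

# Same input, same weights, different timestep embedding -> different output
out_a = block(x, emb_a.expand(2, -1))
out_b = block(x, emb_b.expand(2, -1))
print(out_a.shape, (out_a - out_b).abs().mean().item())
```

With the `(1 + gamma)` form, zeroing `proj` reduces the block to plain group norm, a convenient starting point before the projection learns timestep-specific scales and shifts.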