# Text Conditioning & Guidance

**Module 6.3, Lesson 4** | CourseAI

In the lesson, you learned how cross-attention injects CLIP text embeddings into the U-Net, and how classifier-free guidance amplifies the text signal at inference time. Now you will build and explore both mechanisms hands-on.

**What you will do:**
- Modify a working self-attention implementation into cross-attention by changing where K and V come from
- Visualize cross-attention weights as a heatmap showing different spatial locations attending to different text tokens
- Implement the CFG formula and observe how guidance scale affects the magnitude of noise predictions
- Generate images from a real text-conditioned diffusion model at different guidance scales and identify the quality/fidelity tradeoff

**For each exercise, PREDICT the output before running the cell.**

Everything builds on what you already know. Cross-attention is the same QKV formula from Module 4.2 — the only change is where K and V come from. CFG is one line of arithmetic. No new theory — just practice.

**Estimated time:** 30–45 minutes.

---

## Setup

Run this cell to install dependencies and import everything.

In [None]:
!pip install -q diffusers transformers accelerate

import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np

# Reproducible results
torch.manual_seed(42)
np.random.seed(42)

# Nice plots
plt.style.use('dark_background')
plt.rcParams['figure.figsize'] = [10, 4]

# Device setup
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

print('\nSetup complete.')

---

## Exercise 1: Cross-Attention from Self-Attention [Guided]

You built the self-attention formula across three lessons in Module 4.2:

$$\text{output} = \text{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right) V$$

In self-attention, Q, K, and V are all projected from the **same input**. Cross-attention changes one thing: **K and V come from a different input** (the text embeddings instead of the spatial features).

Below is a working self-attention implementation. We will modify it to perform cross-attention by changing where K and V come from. The code demonstrates both side by side with dummy tensors: a 4×4 spatial feature map (16 locations, dimension 32) and 6 text token embeddings (dimension 32).

**Before running, predict:**
- In self-attention over 16 spatial locations, what is the shape of the attention weight matrix? (`16 × 16`: each location attends to all 16 locations, itself included.)
- After changing K and V to come from 6 text tokens instead, what shape will the attention weight matrix be? (No longer square — `16 × 6`, because 16 queries attend over 6 keys.)
- Will the output shape change? (No — the output is always `(num_queries, d_v)`. The number of queries stays 16 and d_v stays 32.)

In [None]:
# ---- Self-attention implementation ----

d_model = 32   # embedding dimension
H, W = 4, 4    # spatial feature map size
n_spatial = H * W  # 16 spatial locations

# Dummy spatial features: (16, 32) -- 16 locations, each with a 32-dim feature vector
# In the real U-Net, these come from a residual block at 16x16 or 32x32 resolution.
x_spatial = torch.randn(n_spatial, d_model)

# Learned projection matrices
W_Q = nn.Linear(d_model, d_model, bias=False)
W_K = nn.Linear(d_model, d_model, bias=False)
W_V = nn.Linear(d_model, d_model, bias=False)

def self_attention(x, W_Q, W_K, W_V):
    """Standard self-attention: Q, K, V all come from the same input x."""
    Q = W_Q(x)    # (n, d_model)
    K = W_K(x)    # (n, d_model)  <-- same input x
    V = W_V(x)    # (n, d_model)  <-- same input x
    
    d_k = Q.shape[-1]
    attn_weights = torch.softmax(Q @ K.T / (d_k ** 0.5), dim=-1)
    output = attn_weights @ V
    return output, attn_weights

# Run self-attention
with torch.no_grad():
    sa_output, sa_weights = self_attention(x_spatial, W_Q, W_K, W_V)

print('=== Self-Attention ===')
print(f'Input shape:             {x_spatial.shape}')      # (16, 32)
print(f'Q shape:                 ({n_spatial}, {d_model})')
print(f'K shape:                 ({n_spatial}, {d_model})  <-- from same input')
print(f'V shape:                 ({n_spatial}, {d_model})  <-- from same input')
print(f'Attention weights shape: {sa_weights.shape}')     # (16, 16) -- SQUARE
print(f'Output shape:            {sa_output.shape}')      # (16, 32)
print()
print(f'The attention matrix is {sa_weights.shape[0]}×{sa_weights.shape[1]} -- SQUARE.')
print(f'Each of {n_spatial} spatial locations attends to all {n_spatial} spatial locations (itself included).')

In [None]:
# ---- Cross-attention: the one change ----
#
# The ONLY difference: K and V are projected from the text embeddings,
# not from the spatial features. Q still comes from spatial features.
# The formula is IDENTICAL: output = softmax(QK^T / sqrt(d_k)) V

T = 6  # number of text tokens (e.g., "a cat sitting in a sunset")

# Dummy text embeddings: (6, 32) -- 6 tokens, each 32-dim
# In the real U-Net, these come from CLIP's text encoder.
x_text = torch.randn(T, d_model)

def cross_attention(x_spatial, x_text, W_Q, W_K, W_V):
    """Cross-attention: Q from spatial features, K and V from text embeddings."""
    Q = W_Q(x_spatial)  # (n_spatial, d_model)
    K = W_K(x_text)     # (T, d_model)  <-- DIFFERENT input!
    V = W_V(x_text)     # (T, d_model)  <-- DIFFERENT input!
    
    d_k = Q.shape[-1]
    attn_weights = torch.softmax(Q @ K.T / (d_k ** 0.5), dim=-1)
    output = attn_weights @ V
    return output, attn_weights

# Run cross-attention
with torch.no_grad():
    ca_output, ca_weights = cross_attention(x_spatial, x_text, W_Q, W_K, W_V)

print('=== Cross-Attention ===')
print(f'Spatial input shape:     {x_spatial.shape}')      # (16, 32)
print(f'Text input shape:        {x_text.shape}')         # (6, 32)
print(f'Q shape:                 ({n_spatial}, {d_model})  <-- from spatial features')
print(f'K shape:                 ({T}, {d_model})           <-- from TEXT')
print(f'V shape:                 ({T}, {d_model})           <-- from TEXT')
print(f'Attention weights shape: {ca_weights.shape}')     # (16, 6) -- NOT SQUARE
print(f'Output shape:            {ca_output.shape}')      # (16, 32)
print()
print(f'The attention matrix is {ca_weights.shape[0]}×{ca_weights.shape[1]} -- RECTANGULAR.')
print(f'Each of {n_spatial} spatial locations attends over {T} text tokens.')
print(f'Different dimensions because Q and K come from different inputs.')
print()
print(f'Output shape is still ({n_spatial}, {d_model}) -- same as self-attention.')
print(f'The number of queries determines the output size, not the number of keys.')
print()
print('--- Comparison ---')
print(f'Self-attention weights: {sa_weights.shape}  (square: spatial × spatial)')
print(f'Cross-attention weights: {ca_weights.shape}  (rectangular: spatial × text)')

In [None]:
# Verify: each row of attention weights sums to 1 (softmax over text tokens)
row_sums = ca_weights.sum(dim=-1)
print('Row sums of cross-attention weights (should all be 1.0):')
print(f'  Min: {row_sums.min():.6f}')
print(f'  Max: {row_sums.max():.6f}')
print()
print('Each spatial location has its own probability distribution over text tokens.')
print('This is what makes text conditioning SPATIALLY VARYING:')
print('different locations can attend to different words.')

### What Just Happened

You modified self-attention into cross-attention with a **single change**: K and V now come from the text embeddings instead of the spatial features (two lines of code, one idea). Everything else (the dot-product scoring, the softmax, the weighted average) is identical.

The key shape change:
- **Self-attention:** `(16 × 16)` — square, because Q and K come from the same 16 spatial locations.
- **Cross-attention:** `(16 × 6)` — rectangular, because 16 spatial queries attend over 6 text keys.

The output shape stays `(16, 32)` in both cases — one enriched vector per spatial location. In self-attention, each location is enriched by other spatial locations. In cross-attention, each location is enriched by the text.
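
In the toy example both inputs share dimension 32, so the same projection matrices accept either one. In a real text-conditioned U-Net the text embedding width often differs from the spatial channel width, so the K and V projections take the text width as their input size. A minimal sketch of that case, assuming illustrative sizes (`d_spatial`, `d_text`, `d_attn`, and a batch dimension) that are not taken from any particular model:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Illustrative sizes (assumptions, not from a specific model)
batch, n_spatial, n_text = 2, 16, 6
d_spatial, d_text, d_attn = 32, 48, 32

x_spatial = torch.randn(batch, n_spatial, d_spatial)  # spatial features
x_text = torch.randn(batch, n_text, d_text)           # text embeddings

# Q projects from the spatial width; K and V project from the text width
to_q = nn.Linear(d_spatial, d_attn, bias=False)
to_k = nn.Linear(d_text, d_attn, bias=False)
to_v = nn.Linear(d_text, d_attn, bias=False)

with torch.no_grad():
    Q, K, V = to_q(x_spatial), to_k(x_text), to_v(x_text)
    scores = Q @ K.transpose(-2, -1) / (d_attn ** 0.5)  # (batch, 16, 6)
    attn = torch.softmax(scores, dim=-1)                # softmax over text tokens
    out = attn @ V                                      # (batch, 16, d_attn)

print(attn.shape)  # torch.Size([2, 16, 6])
print(out.shape)   # torch.Size([2, 16, 32])
```

The attention matrix is still `spatial × text` per batch element, and the output still has one row per query. Only the input size of the K and V projections changed.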

---

## Exercise 2: Visualize Cross-Attention Weights [Guided]

The lesson showed a table where different spatial regions (cat's face, sky, cat's body) attend to different text tokens with different weights. Now you will compute and visualize this yourself.

We will use the cross-attention function from Exercise 1 with the same 4×4 spatial features and 6 text tokens. The resulting attention weight matrix has shape `(16, 6)` — 16 spatial locations, each with a distribution over 6 text tokens. We will visualize this as a heatmap.

**Before running, predict:**
- Since these are random embeddings (not a trained model), will different spatial locations show different attention patterns? (Yes — different random query vectors will dot-product differently against the key vectors. The patterns will not be meaningful, but they will be different.)
- Which text token will get the most attention on average across all spatial locations? (Cannot predict with random weights — it depends on the random initialization. But the mean attention per token will NOT be uniform, because the random queries will happen to align more with some keys than others.)

In [None]:
# Use the cross-attention weights computed in Exercise 1
# ca_weights has shape (16, 6): 16 spatial locations x 6 text tokens

# We will give the text tokens meaningful labels for the visualization
text_token_labels = ['a', 'cat', 'sitting', 'in', 'a', 'sunset']

# Create spatial position labels (row, col in 4x4 grid)
spatial_labels = [f'({r},{c})' for r in range(H) for c in range(W)]

# ---- Heatmap visualization ----
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6), 
                                gridspec_kw={'width_ratios': [3, 1]})

# Main heatmap
weights_np = ca_weights.numpy()
im = ax1.imshow(weights_np, cmap='magma', aspect='auto')

# Annotate each cell
for i in range(n_spatial):
    for j in range(T):
        val = weights_np[i, j]
        color = 'black' if val > 0.25 else 'white'
        ax1.text(j, i, f'{val:.2f}', ha='center', va='center',
                fontsize=7, color=color)

ax1.set_xticks(range(T))
ax1.set_xticklabels(text_token_labels, fontsize=10)
ax1.set_yticks(range(n_spatial))
ax1.set_yticklabels(spatial_labels, fontsize=8)
ax1.set_xlabel('Text Tokens', fontsize=11)
ax1.set_ylabel('Spatial Location (row, col)', fontsize=11)
ax1.set_title('Cross-Attention Weights: 16 spatial locations \u00d7 6 text tokens', fontsize=12)
plt.colorbar(im, ax=ax1, shrink=0.8)

# Mean attention per text token (bar chart)
mean_attn = weights_np.mean(axis=0)  # average across spatial locations
bars = ax2.barh(range(T), mean_attn, color='#c084fc')
ax2.set_yticks(range(T))
ax2.set_yticklabels(text_token_labels, fontsize=10)
ax2.set_xlabel('Mean Attention', fontsize=10)
ax2.set_title('Avg Attention\nper Token', fontsize=11)
ax2.invert_yaxis()

# Annotate bars
for i, v in enumerate(mean_attn):
    ax2.text(v + 0.005, i, f'{v:.3f}', va='center', fontsize=9)

plt.tight_layout()
plt.show()

print('Observations:')
print('  1. Each row sums to 1.0 (softmax) -- a probability distribution over text tokens.')
print('  2. Different spatial locations have DIFFERENT attention patterns.')
print('     This is what "spatially varying conditioning" means.')
print('  3. With random weights, the patterns are not meaningful.')
print('     In a trained model, spatial locations near the cat would attend')
print('     strongly to "cat" and locations in the sky would attend to "sunset."')
print()

# Which token gets the most and least attention on average?
most_attended = text_token_labels[mean_attn.argmax()]
least_attended = text_token_labels[mean_attn.argmin()]
print(f'  Most attended token (on average):  "{most_attended}" ({mean_attn.max():.3f})')
print(f'  Least attended token (on average): "{least_attended}" ({mean_attn.min():.3f})')
print()
print('  With random embeddings, this is not meaningful. In a trained model, content')
print('  words ("cat", "sunset") would typically receive more attention than function words ("a", "in").')

In [None]:
# Reshape and visualize as a 4x4 spatial grid for each text token
# This shows WHERE in the spatial feature map each text token gets attention.

fig, axes = plt.subplots(1, T, figsize=(15, 3))
for j in range(T):
    # Attention from all spatial locations to text token j
    attn_for_token = weights_np[:, j].reshape(H, W)  # (4, 4)
    im = axes[j].imshow(attn_for_token, cmap='magma', vmin=0, vmax=weights_np.max())
    axes[j].set_title(f'"{text_token_labels[j]}"', fontsize=11)
    axes[j].set_xticks([])
    axes[j].set_yticks([])

plt.suptitle('Cross-Attention: Where each text token is attended to in the 4\u00d74 spatial grid',
             fontsize=12)
plt.tight_layout()
plt.show()

print('Each subplot shows one text token\'s attention across the 4\u00d74 spatial grid.')
print('Brighter = more attention from that spatial location to that text token.')
print()
print('In a trained model, you would see:')
print('  - "cat" is bright in the region where the cat is')
print('  - "sunset" is bright in the sky region')
print('  - Function words ("a", "in") are dim everywhere')
print()
print('This is the spatially-varying nature of text conditioning:')
print('each spatial location extracts different information from the text.')

### What Just Happened

You visualized the cross-attention weight matrix two ways:

1. **Heatmap (16 rows × 6 columns):** Each row is one spatial location's distribution over text tokens. Different rows have different patterns — this is spatially-varying conditioning.
2. **Spatial grid per token:** For each text token, you saw where in the 4×4 feature map it receives the most attention. In a trained model, content words like "cat" and "sunset" would light up in the corresponding spatial regions.

The key insight: **cross-attention gives each spatial location its own text signal.** The cat region gets "cat" information. The sky gets "sunset" information. This is fundamentally different from timestep conditioning, which injects the same signal everywhere.
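
The contrast with a global signal can be checked numerically. A minimal sketch with random tensors, where `global_cond` is a hypothetical stand-in for a timestep-style embedding that is added identically at every location, and the Q/K/V projections are omitted for brevity:

```python
import torch

torch.manual_seed(0)

n_spatial, n_text, d = 16, 6, 32
features = torch.randn(n_spatial, d)
text = torch.randn(n_text, d)

# Global conditioning: one vector broadcast to every spatial location
global_cond = torch.randn(d)                      # hypothetical timestep-style embedding
global_update = global_cond.expand(n_spatial, d)  # the same row repeated 16 times

# Cross-attention: each location mixes the text tokens with its own weights
# (projections omitted, so Q = features and K = V = text)
attn = torch.softmax(features @ text.T / d ** 0.5, dim=-1)  # (16, 6)
cross_update = attn @ text                                   # (16, 32)

# Does every location receive the same update?
global_same = all(torch.equal(row, global_update[0]) for row in global_update)
cross_same = torch.allclose(cross_update, cross_update[0].expand_as(cross_update))

print(global_same)  # True
print(cross_same)   # False
```

Every row of `global_update` is the same vector, while each row of `cross_update` is a different weighted mix of the text tokens.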

---

## Exercise 3: Implement Classifier-Free Guidance [Supported]

CFG is a simple formula applied at inference time:

$$\epsilon_{\text{cfg}} = \epsilon_{\text{uncond}} + w \cdot (\epsilon_{\text{cond}} - \epsilon_{\text{uncond}})$$

Where:
- $\epsilon_{\text{uncond}}$ = model's prediction without text (null embedding)
- $\epsilon_{\text{cond}}$ = model's prediction with text
- $w$ = guidance scale (typically 7.5 for Stable Diffusion)

The difference $(\epsilon_{\text{cond}} - \epsilon_{\text{uncond}})$ is the **text direction** — the effect the text has on the model's prediction. CFG amplifies this direction by $w$.

Below is a dummy model that produces different noise predictions depending on whether it receives a text embedding or a null embedding. **Your task:** implement the `apply_cfg` function and plot how the L2 norm of the CFG prediction changes with guidance scale.

**Hints:**
- The CFG formula is one line: `noise_uncond + w * (noise_cond - noise_uncond)`
- `torch.norm(tensor)` computes the L2 norm
- At large guidance scales, the prediction magnitude should grow with `w`

In [None]:
# A dummy "model" that simulates conditional and unconditional noise predictions.
# In a real diffusion model, this would be two forward passes through the U-Net:
#   noise_uncond = unet(x_t, t, null_embedding)
#   noise_cond   = unet(x_t, t, text_embedding)

# Simulate noise predictions for a 4x4 spatial feature map (flattened to 16 dims)
torch.manual_seed(42)

# Unconditional prediction: what the model predicts without any text
noise_uncond = torch.randn(16) * 0.5  # moderate magnitude

# Conditional prediction: slightly different -- the text nudges the prediction
# In reality, the difference is small but meaningful.
text_direction = torch.tensor([
    0.08, -0.05, 0.12, -0.03, 0.07, -0.09, 0.04, 0.11,
    -0.06, 0.10, -0.04, 0.08, -0.07, 0.05, 0.09, -0.06
])
noise_cond = noise_uncond + text_direction  # conditional = unconditional + text effect

print('Noise predictions (first 6 dimensions):')
print(f'  Unconditional: {noise_uncond[:6].tolist()}')
print(f'  Conditional:   {noise_cond[:6].tolist()}')
print(f'  Difference:    {text_direction[:6].tolist()}')
print()
print(f'The difference (text direction) is small: L2 norm = {text_direction.norm():.4f}')
print(f'Unconditional prediction L2 norm: {noise_uncond.norm():.4f}')
print(f'Conditional prediction L2 norm:   {noise_cond.norm():.4f}')
print()
print('The text only nudges the prediction slightly. CFG will amplify this nudge.')

In [None]:
def apply_cfg(noise_uncond, noise_cond, guidance_scale):
    """Apply classifier-free guidance.
    
    Args:
        noise_uncond: model's noise prediction without text conditioning
        noise_cond: model's noise prediction with text conditioning
        guidance_scale: the weight w that amplifies the text direction
    
    Returns:
        The guided noise prediction.
    """
    # TODO: Implement the CFG formula.
    # noise_cfg = noise_uncond + w * (noise_cond - noise_uncond)
    # Hint: it is literally one line.
    pass


# Test at specific guidance scales
test_scales = [0, 1, 3, 7.5, 15]

print('CFG predictions at different guidance scales:')
print('=' * 65)
print(f'{"w":>6s}  {"L2 Norm":>10s}  First 4 dimensions')
print('-' * 65)

for w in test_scales:
    result = apply_cfg(noise_uncond, noise_cond, w)
    norm = result.norm().item()
    dims = result[:4].tolist()
    dims_str = ', '.join(f'{d:+.3f}' for d in dims)
    print(f'{w:>6.1f}  {norm:>10.4f}  [{dims_str}]')

print()
print('Observations:')
print('  w=0:   Unconditional prediction (text ignored entirely)')
print('  w=1:   Conditional prediction (no amplification)')
print('  w=7.5: Typical Stable Diffusion -- text effect amplified 7.5x')
print('  w=15:  Aggressive -- text dominates, large magnitude prediction')

In [None]:
# TODO: Plot L2 norm of CFG prediction vs guidance scale.
#
# Create a range of guidance scales from 0 to 20 (use torch.linspace or np.linspace).
# For each scale, apply CFG and compute the L2 norm of the result.
# Plot guidance_scale (x-axis) vs L2_norm (y-axis).
#
# Structure is provided -- fill in the TODOs.

guidance_scales = np.linspace(0, 20, 100)
norms = []

for w in guidance_scales:
    # TODO: Apply CFG at this guidance scale and compute the L2 norm.
    # Hint: result = apply_cfg(noise_uncond, noise_cond, w)
    #        norms.append(result.norm().item())
    pass

# Plot (only if norms were computed -- fill in the loop above first)
fig, ax = plt.subplots(figsize=(10, 5))

if norms:
    ax.plot(guidance_scales, norms, color='#c084fc', linewidth=2)

    # Mark the specific guidance scales from the test above
    for w in test_scales:
        result = apply_cfg(noise_uncond, noise_cond, w)
        if result is not None:
            norm_val = result.norm().item()
            ax.plot(w, norm_val, 'o', color='#f59e0b', markersize=8, zorder=5)
            ax.annotate(f'w={w}', (w, norm_val), textcoords='offset points',
                        xytext=(8, 8), fontsize=9, color='#f59e0b')

    ax.set_xlabel('Guidance Scale (w)', fontsize=12)
    ax.set_ylabel('L2 Norm of CFG Prediction', fontsize=12)
    ax.set_title('CFG Prediction Magnitude vs Guidance Scale', fontsize=13)
    ax.axvline(x=7.5, color='#3b82f6', linestyle='--', alpha=0.5, label='Typical SD default (w=7.5)')
    ax.axvline(x=1.0, color='#22c55e', linestyle='--', alpha=0.5, label='No guidance (w=1)')
    ax.legend(fontsize=10)
    ax.grid(alpha=0.2)
    plt.tight_layout()
    plt.show()

    print('For large w, the L2 norm grows roughly linearly with guidance scale.')
    print('Higher w = more aggressive denoising in the text direction.')
    print('At extreme scales, the model over-commits to the text signal,')
    print('producing oversaturated, artifact-heavy images in practice.')
else:
    plt.close(fig)
    print('⚠ norms list is empty. Fill in the TODO loop above to compute')
    print('  the L2 norm at each guidance scale, then re-run this cell.')

<details>
<summary>Solution</summary>

The CFG formula is one line of vector arithmetic. The key insight: `(noise_cond - noise_uncond)` is the **text direction** — the vector that captures how the text changes the model's prediction. Multiplying by `w` amplifies this direction.

**`apply_cfg` function:**
```python
def apply_cfg(noise_uncond, noise_cond, guidance_scale):
    return noise_uncond + guidance_scale * (noise_cond - noise_uncond)
```

**Norm computation loop:**
```python
for w in guidance_scales:
    result = apply_cfg(noise_uncond, noise_cond, w)
    norms.append(result.norm().item())
```

At `w=0`, you get the unconditional prediction (text ignored). At `w=1`, you get the conditional prediction (no amplification). At `w > 1`, you **extrapolate** beyond the conditional prediction in the text direction. Once `w` is large, the amplified term `w * (noise_cond - noise_uncond)` dominates the prediction vector, which is why the L2 norm keeps growing.

The equivalent form is: `noise_cfg = (1 - w) * noise_uncond + w * noise_cond`. At `w > 1`, the coefficient on `noise_uncond` is **negative** — you are actively subtracting the unconditional component and adding more than 100% of the conditional component. This is extrapolation, not interpolation.

Common mistake: writing `noise_cond + w * (noise_cond - noise_uncond)` (starting from conditional instead of unconditional). This doubles the text effect at `w=1` instead of recovering the conditional prediction.
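
Both claims above can be verified numerically. A minimal sketch using made-up tensors; the function names `cfg`, `cfg_weighted`, and `cfg_wrong` are illustrative and not part of the exercise:

```python
import torch

torch.manual_seed(0)
noise_uncond = torch.randn(16)
noise_cond = noise_uncond + 0.1 * torch.randn(16)  # small, made-up text effect

def cfg(u, c, w):
    return u + w * (c - u)

def cfg_weighted(u, c, w):
    return (1 - w) * u + w * c

def cfg_wrong(u, c, w):
    return c + w * (c - u)  # the common mistake: starts from the conditional

# The two correct forms agree at every guidance scale
for w in [0.0, 1.0, 7.5]:
    assert torch.allclose(cfg(noise_uncond, noise_cond, w),
                          cfg_weighted(noise_uncond, noise_cond, w), atol=1e-5)

text_direction = noise_cond - noise_uncond

# Correct form at w=1 recovers the conditional prediction
recovers_cond = torch.allclose(cfg(noise_uncond, noise_cond, 1.0), noise_cond, atol=1e-5)
# Mistaken form at w=1 lands at unconditional + 2 * text direction
doubles_text = torch.allclose(cfg_wrong(noise_uncond, noise_cond, 1.0),
                              noise_uncond + 2 * text_direction, atol=1e-5)

print(recovers_cond)  # True
print(doubles_text)   # True
```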

</details>

### What Just Happened

You implemented the CFG formula and observed:

1. **The formula is one line** of vector arithmetic: `noise_uncond + w * (noise_cond - noise_uncond)`.
2. **`w=0` recovers unconditional, `w=1` recovers conditional.** Values in between interpolate; values above 1 extrapolate.
3. **Higher guidance scale = larger prediction magnitude.** The model commits more aggressively to the text direction. In practice, this means more vivid, text-faithful images — up to a point, after which it causes artifacts.

The text direction `(noise_cond - noise_uncond)` is a small vector — the text only nudges the prediction. CFG amplifies this nudge by `w`. At `w=7.5`, a nudge of 0.08 becomes a push of 0.60. At `w=15`, it becomes 1.20. The model goes from "gently considering the text" to "aggressively optimizing for the text."
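
The shape of the norm curve follows from expanding the squared norm. Writing $d = \epsilon_{\text{cond}} - \epsilon_{\text{uncond}}$ for the text direction (a shorthand introduced here):

$$\|\epsilon_{\text{cfg}}\|^2 = \|\epsilon_{\text{uncond}} + w\,d\|^2 = \|\epsilon_{\text{uncond}}\|^2 + 2w\,(\epsilon_{\text{uncond}} \cdot d) + w^2\,\|d\|^2$$

For large $w$ the last term dominates, so $\|\epsilon_{\text{cfg}}\| \approx w\,\|d\|$ and the curve approaches a straight line with slope $\|d\|$. For small $w$ the constant term dominates, so the curve is flatter there, and if $\epsilon_{\text{uncond}} \cdot d < 0$ the norm can even dip slightly before it starts rising.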

---

## Exercise 4: CFG with a Real Diffusion Model [Independent]

Now for the payoff: use a real text-conditioned diffusion model to generate images at different guidance scales and see the quality/fidelity tradeoff with your own eyes.

**Your task:**
1. Load a small pretrained text-to-image diffusion pipeline
2. Choose a text prompt (something with clear visual content, e.g., "a lighthouse on a cliff at sunset")
3. Generate images at guidance scales: `w = 1, 3, 7.5, 12, 20`
4. Display them side by side
5. Observe: which guidance scale looks best? At what point do artifacts appear?

**No skeleton code is provided.** Use the `diffusers` library, which was installed in the setup cell.

**Key tips:**
- `StableDiffusionPipeline.from_pretrained("nota-ai/bk-sdm-small", torch_dtype=torch.float16)` loads a small distilled version of SD v1.5 (fast, low VRAM)
- `.to(device)` moves it to GPU
- `pipe(prompt, guidance_scale=w, num_inference_steps=30).images[0]` generates one image
- Use the **same random seed** for each guidance scale (via `generator=torch.Generator(device).manual_seed(42)`) so the only variable is the guidance scale
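- **A GPU is strongly recommended.** On CPU, load the pipeline with `torch_dtype=torch.float32` (half precision may fail or run very slowly on CPU), reduce `num_inference_steps` to 10, and expect slow generation.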
- For the full-size model, use `"stable-diffusion-v1-5/stable-diffusion-v1-5"` instead (requires more VRAM)

**Reflection questions** (answer after generating):
- At `w=1`, does the image match the prompt at all?
- What is the sweet spot? Where do you see the best balance of quality and text fidelity?
- At `w=20`, what specific artifacts do you notice? (oversaturation, distortion, repetitive patterns?)
- Why does extreme guidance produce artifacts? (Think about what the formula does: it extrapolates far beyond the conditional prediction in the text direction, pushing the noise prediction into a region the model was never trained to operate in.)

In [None]:
# YOUR CODE HERE
#
# Steps:
# 1. Load the pipeline
# 2. Choose a prompt
# 3. Generate images at w = 1, 3, 7.5, 12, 20 (same seed each time)
# 4. Display side by side with matplotlib
# 5. Answer the reflection questions


<details>
<summary>Solution</summary>

The experiment reveals the guidance scale tradeoff visually. The key insight: CFG is not "free quality" — it is a dial that trades diversity for text fidelity, with artifacts appearing at the extreme end.

```python
from diffusers import StableDiffusionPipeline

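# Load a small distilled model for speed (same architecture as SD v1.5, fewer parameters)
# If you have enough VRAM, you can use "stable-diffusion-v1-5/stable-diffusion-v1-5" instead.
# float16 halves memory on GPU; fall back to float32 on CPU, where half precision
# may fail or run very slowly.
dtype = torch.float16 if device.type == 'cuda' else torch.float32
pipe = StableDiffusionPipeline.from_pretrained(
    "nota-ai/bk-sdm-small",
    torch_dtype=dtype
)
pipe = pipe.to(device)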

prompt = "a lighthouse on a cliff at sunset, oil painting"
guidance_scales = [1.0, 3.0, 7.5, 12.0, 20.0]
images = []

for w in guidance_scales:
    # Same seed for every generation so the only variable is guidance scale
    generator = torch.Generator(device=device).manual_seed(42)
    image = pipe(
        prompt,
        guidance_scale=w,
        num_inference_steps=30,
        generator=generator
    ).images[0]
    images.append(image)
    print(f'  Generated at w={w}')

# Display side by side
fig, axes = plt.subplots(1, len(guidance_scales), figsize=(20, 4))
for ax, img, w in zip(axes, images, guidance_scales):
    ax.imshow(img)
    ax.set_title(f'w = {w}', fontsize=12)
    ax.axis('off')

plt.suptitle(f'"{prompt}" at different guidance scales', fontsize=13)
plt.tight_layout()
plt.show()

print('Observations:')
print('  w=1:   Diverse but may not follow the prompt well. Muted colors.')
print('  w=3:   Starting to follow the prompt. Better composition.')
print('  w=7.5: Strong text adherence + good image quality. The sweet spot.')
print('  w=12:  Very saturated. Starting to see artifacts.')
print('  w=20:  Oversaturated, distorted, artifact-heavy.')
print()
print('Why artifacts at high w?')
print('  CFG extrapolates far beyond the conditional prediction.')
print('  The noise prediction lands in a region the model was never')
print('  trained to operate in. The denoising process breaks down.')
```

**If you do not have a GPU:** You can still run this on CPU by loading the pipeline with `torch_dtype=torch.float32` and using `num_inference_steps=10`, but it will be slow (~2-5 minutes per image). Alternatively, use the plot from Exercise 3 and the lesson's GradientCards to understand the tradeoff conceptually.

**What to notice:** The sweet spot is typically around `w=7.5`. Below that, images are diverse but may not follow the prompt. Above that, the model over-optimizes for the text at the expense of image coherence. This is the tradeoff the lesson describes: CFG is a contrast slider, and too much contrast destroys the image.

</details>
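
Inside a pipeline call, guidance is applied at every denoising step, not once at the end. A common implementation pattern is to stack the null and text embeddings into one batch so that a single forward pass yields both predictions. A minimal sketch of that pattern, where `toy_denoiser` is a hypothetical stand-in for the U-Net and the update rule is deliberately simplified (it is not a real scheduler):

```python
import torch

torch.manual_seed(0)

def toy_denoiser(x, t, emb):
    """Hypothetical stand-in for unet(x_t, t, embedding); returns a fake noise prediction."""
    return 0.1 * x + 0.01 * t + emb.mean(dim=-1, keepdim=True)

guidance_scale = 7.5
x = torch.randn(1, 16)        # current noisy sample (flattened toy "latent")
null_emb = torch.zeros(1, 8)  # toy null embedding
text_emb = torch.randn(1, 8)  # toy text embedding

for t in [3, 2, 1]:
    # One batched forward pass: row 0 is unconditional, row 1 is conditional
    x_in = torch.cat([x, x], dim=0)
    emb_in = torch.cat([null_emb, text_emb], dim=0)
    noise_uncond, noise_cond = toy_denoiser(x_in, t, emb_in).chunk(2)

    # CFG combine, same formula as Exercise 3
    noise_cfg = noise_uncond + guidance_scale * (noise_cond - noise_uncond)

    # Simplified update (a real scheduler would use the noise schedule here)
    x = x - 0.5 * noise_cfg

print(x.shape)  # torch.Size([1, 16])
```

Batching trades memory for speed: the batch doubles in size, but each step needs only one forward pass instead of two.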

---

## Key Takeaways

1. **Cross-attention is the same QKV formula as self-attention — the only change is where K and V come from.** Q comes from the U-Net's spatial features, K and V come from CLIP text embeddings. Same softmax, same weighted average, same "learned lens" pattern. The attention weight matrix changes from square (spatial × spatial) to rectangular (spatial × text).

2. **Cross-attention creates spatially-varying text conditioning.** Each spatial location generates its own query and gets its own distribution over text tokens. The cat region attends to "cat," the sky attends to "sunset." This is fundamentally different from timestep conditioning, which injects the same signal everywhere.

3. **CFG is one line of arithmetic that amplifies the text signal.** `noise_cfg = noise_uncond + w * (noise_cond - noise_uncond)`. The difference `(noise_cond - noise_uncond)` is the text direction — the effect the text has on the prediction. The guidance scale `w` controls how aggressively to follow the text.

4. **The guidance scale is a tradeoff, not a free improvement.** Low `w` = diverse but unfaithful to text. Medium `w` (7.5) = the sweet spot. High `w` = oversaturated, distorted, artifact-heavy. The model is extrapolating beyond its training distribution.

5. **The full conditioning pipeline is now clear.** Timestep via adaptive norm (global, every resolution) tells the network WHEN. Text via cross-attention (spatially varying, middle resolutions) tells it WHAT. CFG turns up the volume on the WHAT.