# SD3 & Flux

**Module 7.4, Lesson 3** | CourseAI

SD3 and Flux combine the DiT architecture with joint text-image attention (MMDiT), T5-XXL text encoding, and flow matching. Every component traces to a lesson you have already completed. This notebook makes those components concrete and inspectable.

**What you will do:**
- Load SD3 Medium and inspect the triple text encoder setup: verify parameter counts, embedding shapes, and the scale difference between CLIP and T5-XXL
- Inspect the attention module of an MMDiT block and visualize the four quadrants of joint attention (using synthetic weights): text-to-text, text-to-image, image-to-text, image-to-image
- Generate images with SD3 at different step counts (10, 20, 30, 50) and observe the flow matching payoff: good results in 20–30 steps
- Trace the full SD3 pipeline end-to-end, capturing intermediate outputs and annotating each step with the lesson that covered it

**For each exercise, PREDICT the output before running the cell.**

Every concept in this notebook comes from the lesson. MMDiT as "one room, one conversation," T5-XXL as a language model complementing CLIP, flow matching as the same objective you trained with in **Flow Matching**. No new theory—just hands-on verification of what you just read.

**Estimated time:** 60–90 minutes. All exercises require a GPU runtime with at least 16 GB VRAM (Colab A100 recommended, T4 with float16 may work for some exercises). SD3 Medium is a large model (~12 GB in float16).

## Setup

Run this cell to install dependencies and configure the environment.

**Important:** Switch to a GPU runtime in Colab (Runtime > Change runtime type > A100 GPU). SD3 Medium requires ~12 GB VRAM in float16. A T4 (16 GB) is tight but may work; an A100 (40 GB) is comfortable.

In [None]:
!pip install -q diffusers transformers accelerate safetensors sentencepiece protobuf

In [None]:
import torch
import torch.nn as nn
import gc
import time
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import numpy as np
from IPython.display import display

# Reproducible results
torch.manual_seed(42)

# Device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
dtype = torch.float16 if torch.cuda.is_available() else torch.float32

# Nice plots
plt.style.use('dark_background')
plt.rcParams['figure.figsize'] = [14, 5]
plt.rcParams['figure.dpi'] = 100

print(f'Device: {device}')
print(f'Dtype: {dtype}')
if device.type == 'cpu':
    print('WARNING: No GPU detected. All exercises in this notebook require a GPU.')
    print('Switch to GPU: Runtime > Change runtime type > A100 GPU')
print()
print('Setup complete.')

## Shared Helpers

Utility functions used across multiple exercises. Run this cell now.

In [None]:
def count_parameters(model):
    """Count total parameters in a model."""
    return sum(p.numel() for p in model.parameters())


def count_parameters_by_type(model):
    """Count parameters grouped by module type (e.g., Linear, LayerNorm)."""
    counts = {}
    for name, module in model.named_modules():
        module_type = module.__class__.__name__
        if module_type == type(model).__name__:
            continue
        n_params = sum(p.numel() for p in module.parameters(recurse=False))
        if n_params > 0:
            counts[module_type] = counts.get(module_type, 0) + n_params
    return dict(sorted(counts.items(), key=lambda x: -x[1]))


def show_image_row(images, titles, suptitle=None, figsize=None):
    """Display a row of PIL images with titles."""
    n = len(images)
    fig_w = figsize[0] if figsize else max(5 * n, 12)
    fig_h = figsize[1] if figsize else 5
    fig, axes = plt.subplots(1, n, figsize=(fig_w, fig_h))
    if n == 1:
        axes = [axes]
    for ax, img, title in zip(axes, images, titles):
        ax.imshow(img)
        ax.set_title(title, fontsize=10)
        ax.axis('off')
    if suptitle:
        plt.suptitle(suptitle, fontsize=13, y=1.02)
    plt.tight_layout()
    plt.show()


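def free_memory(*objects):
    """Run garbage collection and release cached GPU memory.

    Note: `del obj` below only drops this function's local references.
    The caller must also delete its own names (e.g. `del pipe`) before
    calling, or the objects stay alive and their memory is not released.
    """
    for obj in objects:
        del obj
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    print('Garbage collected and CUDA cache emptied.')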


print('Helpers defined: count_parameters, count_parameters_by_type, show_image_row, free_memory')

---

## Exercise 1: SD3 Pipeline Inspection `[Guided]`

The lesson taught that SD3 uses **three text encoders simultaneously**: CLIP ViT-L `[77, 768]`, OpenCLIP ViT-bigG `[77, 1280]`, and T5-XXL `[77, 4096]`. Each provides different information—CLIP gives image-aligned text understanding, T5 gives deep linguistic understanding. The denoising network is an MMDiT (Multimodal Diffusion Transformer) with modality-specific Q/K/V projections.

In this exercise, you will load SD3 Medium and verify all of this concretely:
1. Print the text encoder class names and parameter counts
2. Verify the embedding output shapes match the lesson's tensor shape trace
3. Inspect the MMDiT transformer: count parameters, identify modality-specific projections
4. Compare text encoder parameters vs denoising network parameters

**Before running, predict:**
- SD3 uses T5-XXL with 4.7B parameters and a denoising network (MMDiT) with ~2B parameters. Which component has more parameters?
- The lesson showed CLIP ViT-L produces `[77, 768]` embeddings and T5-XXL produces `[77, 4096]`. How many times wider are T5's embeddings?

In [None]:
# ============================================================
# Exercise 1: Load SD3 Medium and inspect its components
# ============================================================

# --- Step 1: Load SD3 Medium ---
# SD3 Medium is available via the diffusers StableDiffusion3Pipeline.
# We load in float16 to fit in GPU memory.
# NOTE: You may need to accept the model license on HuggingFace and
# log in with `huggingface-cli login` or pass a token.

from diffusers import StableDiffusion3Pipeline

print('Loading SD3 Medium... (this may take a minute)')
pipe = StableDiffusion3Pipeline.from_pretrained(
    'stabilityai/stable-diffusion-3-medium-diffusers',
    torch_dtype=torch.float16,
)
pipe = pipe.to(device)
print('SD3 Medium loaded.')
print()

In [None]:
# --- Step 2: Inspect the three text encoders ---
# SD3 has three text encoders: text_encoder, text_encoder_2, text_encoder_3.
# The lesson predicted: CLIP ViT-L, OpenCLIP ViT-bigG, T5-XXL.

encoders = [
    ('text_encoder (CLIP ViT-L)', pipe.text_encoder),
    ('text_encoder_2 (OpenCLIP ViT-bigG)', pipe.text_encoder_2),
    ('text_encoder_3 (T5-XXL)', pipe.text_encoder_3),
]

print('SD3 Text Encoders:')
print('=' * 75)
total_encoder_params = 0
for name, encoder in encoders:
    if encoder is None:
        print(f'  {name}: NOT LOADED (may be optional)')
        continue
    n_params = count_parameters(encoder)
    total_encoder_params += n_params
    print(f'  {name}')
    print(f'    Class: {type(encoder).__name__}')
    print(f'    Parameters: {n_params:,} ({n_params / 1e9:.2f}B)')
    print()

print(f'Total text encoder parameters: {total_encoder_params:,} ({total_encoder_params / 1e9:.2f}B)')

In [None]:
# --- Step 3: Verify embedding output shapes ---
# Encode a test prompt through each encoder and check the output shapes.
# Expected from the lesson:
#   CLIP ViT-L:       [1, 77, 768]
#   OpenCLIP ViT-bigG: [1, 77, 1280]
#   T5-XXL:           [1, 77, 4096]

test_prompt = 'a cat sitting on a beach at sunset'

# Use the pipeline's encode_prompt method to get all embeddings at once
with torch.no_grad():
    (
        prompt_embeds,       # Combined text embeddings for joint attention
        negative_prompt_embeds,
        pooled_prompt_embeds,  # Pooled CLIP embeddings for adaLN-Zero
        negative_pooled_prompt_embeds,
    ) = pipe.encode_prompt(
        prompt=test_prompt,
        prompt_2=test_prompt,
        prompt_3=test_prompt,
    )

print(f'Prompt: "{test_prompt}"')
print()
print('Combined embeddings (what enters joint attention):')
print(f'  prompt_embeds shape:         {list(prompt_embeds.shape)}')
print(f'  pooled_prompt_embeds shape:  {list(pooled_prompt_embeds.shape)}')
print()
print('The prompt_embeds are the per-token text embeddings that become')
print('text tokens in the joint attention sequence.')
print('The pooled_prompt_embeds provide global conditioning via adaLN-Zero')
print('(same dual-path pattern as SDXL).')
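
The global conditioning path mentioned in the cell above can be made concrete with a toy version of adaLN-Zero. This is a minimal NumPy sketch under stated assumptions, not the diffusers implementation: `adaln_zero_block`, `w_mod`, and `b_mod` are hypothetical names, the conditioning vector stands in for "timestep embedding plus pooled text embedding", and the modulation layer is a single linear map.

```python
import numpy as np

def layer_norm(x, eps=1e-6):
    """LayerNorm over the last axis, without learned affine parameters."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def adaln_zero_block(x, cond, w_mod, b_mod, sublayer):
    """One residual sub-block with adaLN-Zero style modulation (illustrative).

    x:     [tokens, d]  token features
    cond:  [c]          global conditioning vector
    w_mod: [c, 3 * d]   linear map from cond to (shift, scale, gate)
    b_mod: [3 * d]
    """
    shift, scale, gate = np.split(cond @ w_mod + b_mod, 3)
    h = layer_norm(x) * (1.0 + scale) + shift  # conditioning-dependent normalization
    return x + gate * sublayer(h)              # gated residual update

rng = np.random.default_rng(0)
d, c, tokens = 8, 4, 5
x = rng.normal(size=(tokens, d))
cond = rng.normal(size=c)

# "Zero" initialization: the modulation layer starts at zero, so gate == 0
# and the block is the identity function before any training.
w_zero, b_zero = np.zeros((c, 3 * d)), np.zeros(3 * d)
out_init = adaln_zero_block(x, cond, w_zero, b_zero, sublayer=np.tanh)

# With non-zero modulation weights the conditioning changes the output.
w_rand = rng.normal(size=(c, 3 * d))
out_rand = adaln_zero_block(x, cond, w_rand, b_zero, sublayer=np.tanh)

print('identity at init:', np.allclose(out_init, x))
print('changed after "training":', not np.allclose(out_rand, x))
```

The point of the sketch is the data flow: one global vector produces a shift, a scale, and a gate that are applied to every token alike, in contrast to the per-token embeddings that enter attention.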

In [None]:
# --- Step 4: Encode through individual text encoders to verify shapes ---
# We also encode individually to confirm the per-encoder shapes.

tokenizer_1 = pipe.tokenizer
tokenizer_2 = pipe.tokenizer_2
tokenizer_3 = pipe.tokenizer_3

# CLIP ViT-L
if pipe.text_encoder is not None:
    tokens_1 = tokenizer_1(
        test_prompt, return_tensors='pt', padding='max_length',
        max_length=77, truncation=True
    ).input_ids.to(device)
    with torch.no_grad():
        out_1 = pipe.text_encoder(tokens_1, output_hidden_states=True)
    clip_hidden = out_1.hidden_states[-2]  # penultimate layer (standard for SD)
    clip_pooled = out_1.text_embeds  # pooled, projected output (index 1 would be last_hidden_state)
    print(f'CLIP ViT-L:')
    print(f'  Hidden states shape:  {list(clip_hidden.shape)}  (expected [1, 77, 768])')
    print(f'  Pooled output shape:  {list(clip_pooled.shape)}')
    print()

# OpenCLIP ViT-bigG
if pipe.text_encoder_2 is not None:
    tokens_2 = tokenizer_2(
        test_prompt, return_tensors='pt', padding='max_length',
        max_length=77, truncation=True
    ).input_ids.to(device)
    with torch.no_grad():
        out_2 = pipe.text_encoder_2(tokens_2, output_hidden_states=True)
    clip2_hidden = out_2.hidden_states[-2]
    clip2_pooled = out_2.text_embeds  # pooled, projected output (index 1 would be last_hidden_state)
    print(f'OpenCLIP ViT-bigG:')
    print(f'  Hidden states shape:  {list(clip2_hidden.shape)}  (expected [1, 77, 1280])')
    print(f'  Pooled output shape:  {list(clip2_pooled.shape)}')
    print()

# T5-XXL
if pipe.text_encoder_3 is not None:
    tokens_3 = tokenizer_3(
        test_prompt, return_tensors='pt', padding='max_length',
        max_length=77, truncation=True
    ).input_ids.to(device)
    with torch.no_grad():
        out_3 = pipe.text_encoder_3(tokens_3)
    t5_hidden = out_3.last_hidden_state
    print(f'T5-XXL:')
    print(f'  Hidden states shape:  {list(t5_hidden.shape)}  (expected [1, 77, 4096])')
    print(f'  No pooled output (T5 is encoder-only for embeddings)')
    print()

print('Shape comparison:')
print(f'  CLIP ViT-L:       [1, 77, 768]    ->  768 dims per token')
print(f'  OpenCLIP ViT-bigG: [1, 77, 1280]  -> 1280 dims per token')
print(f'  T5-XXL:           [1, 77, 4096]   -> 4096 dims per token')
print()
print(f'T5 embeddings are {4096 / 768:.1f}x wider than CLIP ViT-L.')
print('More dimensions = more information per token = richer text understanding.')

In [None]:
# --- Step 5: Inspect the MMDiT transformer ---
# The denoising network is pipe.transformer (an SD3Transformer2DModel).

transformer = pipe.transformer
transformer_params = count_parameters(transformer)

print('MMDiT Transformer (Denoising Network):')
print(f'  Class: {type(transformer).__name__}')
print(f'  Parameters: {transformer_params:,} ({transformer_params / 1e9:.2f}B)')
print()

# Count transformer blocks
n_blocks = len(transformer.transformer_blocks)
print(f'  Number of MMDiT blocks: {n_blocks}')
print()

# Parameter breakdown by top-level component
print('Parameter breakdown:')
groups = {}
for name, module in transformer.named_children():
    n_params = count_parameters(module)
    if n_params > 0:
        groups[name] = n_params

for group, count in sorted(groups.items(), key=lambda x: -x[1]):
    pct = count / transformer_params * 100
    print(f'    {group:<35} {count:>14,} ({pct:.1f}%)')

In [None]:
# --- Step 6: Inspect one MMDiT block for modality-specific projections ---
# The lesson explained that MMDiT has SEPARATE Q/K/V projections for
# text tokens and image tokens. This is the key architectural detail.

block_0 = transformer.transformer_blocks[0]
print(f'MMDiT Block 0: {type(block_0).__name__}')
print()
print('Sub-modules (looking for modality-specific projections):')
print('-' * 80)
for name, module in block_0.named_modules():
    if name == '':
        continue
    n_params = sum(p.numel() for p in module.parameters(recurse=False))
    if n_params > 0:
        print(f'  {name:<50} {type(module).__name__:<15} ({n_params:,})')

block_params = count_parameters(block_0)
print(f'\nBlock 0 total: {block_params:,} parameters')
print(f'All {n_blocks} blocks: ~{block_params * n_blocks:,} parameters (estimate)')
print()
print('Look for modality-specific components:')
print('  - Separate attention projections for text vs image (Q/K/V per modality)')
print('  - Separate FFN layers for text vs image')
print('  - Separate norm layers per modality')
print('  - adaLN modulation MLP for timestep conditioning')

In [None]:
# --- Step 7: Compare text encoder vs denoising network parameters ---

vae_params = count_parameters(pipe.vae)

print('SD3 Medium: Parameter Breakdown')
print('=' * 55)
for label, enc in [
    ('CLIP ViT-L (text_encoder):', pipe.text_encoder),
    ('OpenCLIP ViT-bigG (text_enc_2):', pipe.text_encoder_2),
    ('T5-XXL (text_encoder_3):', pipe.text_encoder_3),
]:
    if enc is not None:  # skip encoders that were not loaded
        print(f'  {label:<31}{count_parameters(enc):>14,}')
print(f'  MMDiT transformer:             {transformer_params:>14,}')
print(f'  VAE:                           {vae_params:>14,}')
print('-' * 55)
total = total_encoder_params + transformer_params + vae_params
print(f'  Total:                         {total:>14,} ({total / 1e9:.2f}B)')
print()
print(f'Text encoders:    {total_encoder_params / 1e9:.2f}B ({total_encoder_params / total * 100:.0f}% of total)')
print(f'Denoising (MMDiT): {transformer_params / 1e9:.2f}B ({transformer_params / total * 100:.0f}% of total)')
print(f'VAE:              {vae_params / 1e6:.0f}M ({vae_params / total * 100:.0f}% of total)')
print()
print('The text encoders (especially T5-XXL) are a huge fraction of the total model.')
print('The field has recognized that text understanding is a bottleneck worth')
print('investing parameters in. Understanding the prompt is as important as')
print('generating the image.')

### What Just Happened

You loaded SD3 Medium and verified the architecture the lesson described:

- **Three text encoders, as predicted.** CLIP ViT-L `[77, 768]`, OpenCLIP ViT-bigG `[77, 1280]`, and T5-XXL `[77, 4096]`. T5's embeddings are about 5.3x wider than CLIP ViT-L's (4096 vs 768): more information per token, capturing linguistic structure rather than visual alignment.

- **The text encoders are enormous.** T5-XXL alone has ~4.7B parameters. The total text encoder parameter count exceeds the denoising network. This reflects the field's recognition that text understanding is a bottleneck worth investing in.

- **Modality-specific components in each MMDiT block.** Inside each block you can see separate attention projections and FFN layers for text vs image tokens. This is the "shared listening, separate thinking" pattern from the lesson: shared attention computation, modality-specific projections and processing.

- **Pooled embeddings for global conditioning.** The CLIP pooled outputs provide global conditioning via adaLN-Zero (same dual-path pattern as SDXL). T5 provides per-token embeddings only—no pooled summary vector.

- **The progression is visible in the numbers.** SD v1.5: one encoder (123M). SDXL: two encoders (~817M). SD3: three encoders (~5B+). Each generation adds richer text understanding.
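
To see how three encoder outputs of different widths can end up as one `prompt_embeds` tensor and one pooled vector, here is a shape-only NumPy sketch. It reflects my understanding of how the diffusers pipeline assembles them (CLIP features concatenated, zero-padded to T5's width, then stacked with the T5 tokens along the sequence axis); treat that recipe as an assumption and compare it with the shapes printed in Step 3. The arrays are random stand-ins, and the T5 sequence length of 77 is chosen to match the lesson's shape trace, while a real pipeline may use a longer one.

```python
import numpy as np

rng = np.random.default_rng(0)
seq_clip, seq_t5 = 77, 77  # seq_t5 is an assumption; pipelines may use a longer T5 sequence

# Random stand-ins for the per-token hidden states of each encoder
clip_l = rng.normal(size=(1, seq_clip, 768))    # CLIP ViT-L
clip_g = rng.normal(size=(1, seq_clip, 1280))   # OpenCLIP ViT-bigG
t5 = rng.normal(size=(1, seq_t5, 4096))         # T5-XXL

# 1) Concatenate the two CLIP outputs along the feature axis: 768 + 1280 = 2048
clip_joint = np.concatenate([clip_l, clip_g], axis=-1)

# 2) Zero-pad the CLIP features up to the T5 width (4096)
pad = t5.shape[-1] - clip_joint.shape[-1]
clip_padded = np.pad(clip_joint, ((0, 0), (0, 0), (0, pad)))

# 3) Stack CLIP tokens and T5 tokens along the sequence axis
prompt_embeds = np.concatenate([clip_padded, t5], axis=1)

# Pooled path: one vector per CLIP encoder, concatenated along features
pooled = np.concatenate(
    [rng.normal(size=(1, 768)), rng.normal(size=(1, 1280))], axis=-1
)

print('prompt_embeds:', prompt_embeds.shape)  # per-token path into joint attention
print('pooled:       ', pooled.shape)         # global path into adaLN-Zero
```

Under this recipe the text sequence entering joint attention is longer than 77 tokens, because the CLIP and T5 tokens sit side by side in the sequence rather than being merged per token.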

---

## Exercise 2: Joint Attention Structure `[Guided]`

The lesson taught that MMDiT replaces cross-attention with **joint self-attention**: concatenate text tokens and image patch tokens into one sequence, then run standard self-attention on the combined sequence. This provides four types of attention simultaneously:

- **Image-to-text** (image tokens read text tokens)—equivalent to cross-attention in the U-Net
- **Text-to-image** (text tokens read image tokens)—NEW: text representations update based on image content
- **Image-to-image** (image tokens read image tokens)—equivalent to self-attention in DiT
- **Text-to-text** (text tokens read text tokens)—NEW: text representations refine each other

In this exercise, you will:
1. Encode a prompt and prepare a noisy latent
2. Verify the concatenated sequence length (text tokens + image tokens)
3. Inspect the MMDiT block's attention architecture to confirm modality-specific projections
4. Visualize the **four-quadrant structure** of the joint attention matrix using synthetic data

**Why synthetic data?** The diffusers library does not expose per-head attention weights from the internal `Attention` module during a forward pass. Rather than fighting the implementation, we construct synthetic Q/K vectors matching the model's dimensions to demonstrate what each quadrant of the joint attention matrix *represents*. The quadrant structure is the key insight—it holds regardless of the specific trained weight values.
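
Before the full-size version, the mechanism fits in a few lines of NumPy. This is a toy single-head sketch with random, untrained projection matrices; the names `Wq_t`, `Wk_t`, `Wq_i`, and `Wk_i` are hypothetical and the sizes are deliberately tiny. It shows modality-specific projections feeding one shared softmax, and how the four quadrants are just slices of the resulting matrix.

```python
import numpy as np

def softmax(x, axis=-1):
    """Numerically stable softmax."""
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(0)
n_text, n_image, d = 4, 6, 16  # toy sizes, not the model's real dimensions

text = rng.normal(size=(n_text, d))
image = rng.normal(size=(n_image, d))

# Separate ("modality-specific") query/key projections for text and image
Wq_t, Wk_t = rng.normal(size=(d, d)), rng.normal(size=(d, d))
Wq_i, Wk_i = rng.normal(size=(d, d)), rng.normal(size=(d, d))

# Project each modality with its own weights, then concatenate into one sequence
q = np.concatenate([text @ Wq_t, image @ Wq_i], axis=0)  # [n_text + n_image, d]
k = np.concatenate([text @ Wk_t, image @ Wk_i], axis=0)

# One shared attention computation over the joint sequence
weights = softmax(q @ k.T / np.sqrt(d))  # each row sums to 1

# The four quadrants are slices of the same matrix (rows = queries, cols = keys)
quadrants = {
    'text->text': weights[:n_text, :n_text],
    'text->image': weights[:n_text, n_text:],
    'image->text': weights[n_text:, :n_text],
    'image->image': weights[n_text:, n_text:],
}
for name, block in quadrants.items():
    print(f'{name:<13} shape {block.shape}')
```

Because the softmax runs over the whole row, a query token splits a single attention budget between text keys and image keys, which is what separates joint attention from running self-attention and cross-attention as two independent operations.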

**Before running, predict:**
- If the text sequence has 77 tokens and the image has 1024 patch tokens (64x64 latent, patch size 2), what is the joint attention matrix shape?
- Which quadrant will have the highest attention weights? (Hint: the primary task is spatial coherence in the image.)

In [None]:
# ============================================================
# Exercise 2: Visualize joint attention in an MMDiT block
# ============================================================

# --- Step 1: Prepare inputs for one denoising step ---
# We need: text embeddings (from the encoders), a noisy latent,
# and a timestep. Later cells run a short generation and inspect
# what the attention module of one block receives.

prompt = 'a cat sitting on a beach at sunset'

# Encode the prompt (reuse from Exercise 1 if still in memory)
with torch.no_grad():
    (
        prompt_embeds,
        negative_prompt_embeds,
        pooled_prompt_embeds,
        negative_pooled_prompt_embeds,
    ) = pipe.encode_prompt(
        prompt=prompt,
        prompt_2=prompt,
        prompt_3=prompt,
    )

print(f'Prompt embeddings shape: {list(prompt_embeds.shape)}')
print(f'  This is the text token sequence that enters joint attention.')

# Get the expected latent shape from the pipeline
# SD3 Medium generates at 1024x1024 by default, but we can use
# a smaller size for faster inspection.
height, width = 512, 512
latent_h = height // pipe.vae_scale_factor
latent_w = width // pipe.vae_scale_factor
latent_channels = pipe.transformer.config.in_channels

print(f'\nLatent shape: [{latent_channels}, {latent_h}, {latent_w}]')

# Compute expected patch count
patch_size = pipe.transformer.config.patch_size
n_image_tokens = (latent_h // patch_size) * (latent_w // patch_size)
n_text_tokens = prompt_embeds.shape[1]
n_total_tokens = n_text_tokens + n_image_tokens

print(f'Patch size: {patch_size}')
print(f'Image patch tokens: ({latent_h}//{patch_size}) * ({latent_w}//{patch_size}) = {n_image_tokens}')
print(f'Text tokens: {n_text_tokens}')
print(f'Total joint sequence: {n_text_tokens} + {n_image_tokens} = {n_total_tokens}')
print(f'Expected attention matrix: [{n_total_tokens}, {n_total_tokens}]')

In [None]:
# --- Step 2: Hook into an MMDiT block's attention module ---
# We register a forward hook on one block's attention module to
# capture what it receives and returns during a forward pass.
# NOTE: the diffusers Attention module does not return attention
# weights, so this hook captures inputs and outputs, not weights.

block_index = 0  # Inspect the first MMDiT block
block = pipe.transformer.transformer_blocks[block_index]

# Storage for the captured tensors
captured_qk = {}

def capture_qk_hook(module, args, kwargs, output):
    """Capture the attention module's inputs and outputs.

    The block may pass hidden states as keyword arguments, which a
    plain forward hook does not see, so we register the hook with
    with_kwargs=True and collect positional and keyword inputs.
    """
    captured_qk['input'] = tuple(args) + tuple(kwargs.values())
    captured_qk['output'] = output

hook = block.attn.register_forward_hook(capture_qk_hook, with_kwargs=True)
print(f'Hook registered on block {block_index} attention module.')
print(f'Block type: {type(block).__name__}')
print(f'Attention type: {type(block.attn).__name__}')

In [None]:
# --- Step 3: Run a short generation to trigger the hook ---
# We generate with just 2 steps; the hook keeps the data from the last call.

generator = torch.Generator(device='cpu').manual_seed(42)

with torch.no_grad():
    # Run just 2 steps (minimum for the scheduler to work)
    result = pipe(
        prompt=prompt,
        prompt_2=prompt,
        prompt_3=prompt,
        height=height,
        width=width,
        num_inference_steps=2,
        guidance_scale=7.0,
        generator=generator,
        output_type='pil',
    )

print(f'Generation complete. Captured data from block {block_index}.')
print(f'Captured keys: {list(captured_qk.keys())}')

# Remove hook
hook.remove()
print('Hook removed.')

In [None]:
# --- Step 4: Inspect what the attention module received ---
# The attention module in diffusers computes joint attention
# internally and does not return the weight matrix, so we cannot
# read attention weights from the hook directly.
#
# The MMDiT block has modality-specific Q/K/V projections:
#   text tokens  -> text_W_Q, text_W_K
#   image tokens -> image_W_Q, image_W_K
# Q and K are then concatenated before attention.
#
# Here we check the shapes of the captured inputs and list the
# projection layers inside the attention module.

# Check input shapes
if 'input' in captured_qk:
    for i, inp in enumerate(captured_qk['input']):
        if isinstance(inp, torch.Tensor):
            print(f'  Input {i} shape: {list(inp.shape)}')

# Computing real attention weights would mean re-running the Q/K
# projections ourselves on the captured inputs. We do not do that
# here; instead we list the projection layers so the text-specific
# and image-specific ones are visible by name.

# Get the attention module's key components
attn = block.attn
print(f'\nAttention module components:')
for name, mod in attn.named_modules():
    if name == '':
        continue
    n = sum(p.numel() for p in mod.parameters(recurse=False))
    if n > 0:
        print(f'  {name:<30} {type(mod).__name__:<15} ({n:,})')

In [None]:
# --- Step 5: Prepare to visualize the four-quadrant structure ---
# The diffusers Attention module computes attention internally and
# does not return per-head attention weight matrices. Extracting
# real weights would require monkey-patching the module's forward
# method, which is fragile across library versions.
#
# Instead, we demonstrate the STRUCTURE of joint attention: what
# the four quadrants represent, how the matrix dimensions work,
# and how the attention cost compares to cross-attention. The
# quadrant structure is a property of the architecture, not of
# any specific trained weights.

# Read the attention dimensions from the transformer config.
# NOTE: joint_attention_dim is the width of the incoming text embeddings,
# not the transformer's hidden size, so d_model is derived from the heads.
num_heads = pipe.transformer.config.num_attention_heads
d_head = pipe.transformer.config.attention_head_dim
d_model = num_heads * d_head

print(f'MMDiT attention config:')
print(f'  d_model (num_heads * d_head): {d_model}')
print(f'  num_heads: {num_heads}')
print(f'  d_head: {d_head}')
print(f'  text tokens: {n_text_tokens}')
print(f'  image tokens: {n_image_tokens}')
print(f'  total tokens: {n_total_tokens}')
print()

# For visualization, we use a MUCH smaller token count to make
# the attention matrix viewable. With 1024 image tokens, the full
# matrix has more than a thousand rows and columns, which is too
# dense to see patterns in. The quadrant structure is identical at any size.

n_text_vis = 10  # Show first 10 text tokens for clarity
n_image_vis = 32  # Show 32 image tokens (e.g., a 4x8 patch grid)
n_total_vis = n_text_vis + n_image_vis

print(f'Visualization dimensions (downsampled for clarity):')
print(f'  Text tokens shown:  {n_text_vis}')
print(f'  Image tokens shown: {n_image_vis}')
print(f'  Total: {n_total_vis}')
print(f'  Attention matrix: [{n_total_vis}, {n_total_vis}]')

In [None]:
# --- Step 6: Construct a synthetic attention matrix showing quadrant structure ---
# These are NOT trained attention weights from the model. We construct
# synthetic Q and K vectors to produce a plausible attention pattern
# that illustrates the four-quadrant structure of joint attention.
#
# The pedagogical point: regardless of what specific values the trained
# model learns, the joint attention matrix ALWAYS has these four quadrants.
# The quadrant structure is an architectural property, not a learned one.

torch.manual_seed(42)

# Create Q and K with separate "modality" structure
# Text tokens: similar to each other (same prompt, coherent meaning)
text_base = torch.randn(1, 1, d_head) * 0.5
text_Q = text_base + torch.randn(1, n_text_vis, d_head) * 0.3
text_K = text_base + torch.randn(1, n_text_vis, d_head) * 0.3

# Image tokens: independent random vectors (spatial locality is not modeled here)
image_Q = torch.randn(1, n_image_vis, d_head) * 0.5
image_K = torch.randn(1, n_image_vis, d_head) * 0.5

# Concatenate (joint attention)
joint_Q = torch.cat([text_Q, image_Q], dim=1)  # [1, 42, d_head]
joint_K = torch.cat([text_K, image_K], dim=1)  # [1, 42, d_head]

# Compute attention weights: softmax(Q @ K^T / sqrt(d_head))
import math
attn_logits = torch.bmm(joint_Q, joint_K.transpose(1, 2)) / math.sqrt(d_head)
attn_weights = torch.softmax(attn_logits, dim=-1)  # [1, 42, 42]

print(f'Joint Q shape: {list(joint_Q.shape)}')
print(f'Joint K shape: {list(joint_K.shape)}')
print(f'Attention weights shape: {list(attn_weights.shape)}')
print(f'  = [{n_total_vis}, {n_total_vis}] (every token attends to every other)')
print()
print('NOTE: These are synthetic weights for structural demonstration,')
print('not trained attention patterns from the model.')

In [None]:
# --- Step 7: Visualize the four-quadrant structure ---

attn_np = attn_weights[0].detach().numpy()

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 7),
                                gridspec_kw={'width_ratios': [1, 1]})

# Left: full attention matrix with quadrant labels
im = ax1.imshow(attn_np, cmap='viridis', aspect='auto')
ax1.set_title(f'Joint Attention Structure [{n_total_vis} x {n_total_vis}]\n(synthetic weights, real quadrant structure)', fontsize=11)
ax1.set_xlabel('Key (attending TO)', fontsize=10)
ax1.set_ylabel('Query (attending FROM)', fontsize=10)

# Draw quadrant boundaries
ax1.axhline(y=n_text_vis - 0.5, color='red', linewidth=2, linestyle='--')
ax1.axvline(x=n_text_vis - 0.5, color='red', linewidth=2, linestyle='--')

# Label quadrants
# Top-left: text-to-text
ax1.text(n_text_vis / 2, n_text_vis / 2, 'Text\u2192Text',
         ha='center', va='center', fontsize=11, fontweight='bold',
         color='white', bbox=dict(boxstyle='round', facecolor='black', alpha=0.6))
# Top-right: text-to-image
ax1.text(n_text_vis + n_image_vis / 2, n_text_vis / 2, 'Text\u2192Image\n(NEW)',
         ha='center', va='center', fontsize=11, fontweight='bold',
         color='#22d3ee', bbox=dict(boxstyle='round', facecolor='black', alpha=0.6))
# Bottom-left: image-to-text
ax1.text(n_text_vis / 2, n_text_vis + n_image_vis / 2, 'Image\u2192Text\n(was cross-attn)',
         ha='center', va='center', fontsize=11, fontweight='bold',
         color='#fbbf24', bbox=dict(boxstyle='round', facecolor='black', alpha=0.6))
# Bottom-right: image-to-image
ax1.text(n_text_vis + n_image_vis / 2, n_text_vis + n_image_vis / 2, 'Image\u2192Image\n(was self-attn)',
         ha='center', va='center', fontsize=11, fontweight='bold',
         color='#34d399', bbox=dict(boxstyle='round', facecolor='black', alpha=0.6))

plt.colorbar(im, ax=ax1, label='Attention weight', shrink=0.8)

# Right: comparison diagram showing cross-attention vs joint attention
ax2.set_xlim(0, 10)
ax2.set_ylim(0, 10)
ax2.set_aspect('equal')
ax2.axis('off')
ax2.set_title('Cross-Attention vs Joint Attention', fontsize=12)

# Cross-attention (left side)
ax2.text(2.5, 9.2, 'Cross-Attention (U-Net)', ha='center', fontsize=10, fontweight='bold')
# Self-attention box
rect1 = mpatches.FancyBboxPatch((0.5, 6.5), 4, 2, boxstyle='round,pad=0.2',
                                  facecolor='#34d399', alpha=0.3, edgecolor='#34d399')
ax2.add_patch(rect1)
ax2.text(2.5, 7.5, 'Self-Attn\n[256, 256]\nImage\u2192Image', ha='center', fontsize=8)
# Cross-attention box
rect2 = mpatches.FancyBboxPatch((0.5, 4.0), 4, 2, boxstyle='round,pad=0.2',
                                  facecolor='#fbbf24', alpha=0.3, edgecolor='#fbbf24')
ax2.add_patch(rect2)
ax2.text(2.5, 5.0, 'Cross-Attn\n[256, 77]\nImage\u2192Text', ha='center', fontsize=8)
ax2.text(2.5, 3.5, 'Two operations', ha='center', fontsize=9, color='#ef4444')

# Joint attention (right side)
ax2.text(7.5, 9.2, 'Joint Attention (MMDiT)', ha='center', fontsize=10, fontweight='bold')
rect3 = mpatches.FancyBboxPatch((5.5, 4.0), 4, 4.5, boxstyle='round,pad=0.2',
                                  facecolor='#a78bfa', alpha=0.3, edgecolor='#a78bfa')
ax2.add_patch(rect3)
ax2.text(7.5, 7.2, 'Joint Self-Attn', ha='center', fontsize=9, fontweight='bold')
ax2.text(7.5, 6.3, f'[{n_total_tokens}, {n_total_tokens}]', ha='center', fontsize=8)
ax2.text(7.5, 5.4, 'All four directions:', ha='center', fontsize=8)
ax2.text(7.5, 4.7, 'Img\u2192Img, Img\u2192Txt\nTxt\u2192Img, Txt\u2192Txt', ha='center', fontsize=7)
ax2.text(7.5, 3.5, 'One operation', ha='center', fontsize=9, color='#22d3ee')

# Arrow showing simplification
ax2.annotate('', xy=(5.3, 6.0), xytext=(4.7, 6.0),
             arrowprops=dict(arrowstyle='->', color='white', lw=2))

plt.tight_layout()
plt.show()

print('The four quadrants of joint attention (architectural structure):')
print(f'  Text\u2192Text  (top-left):     text tokens refine each other')
print(f'  Text\u2192Image (top-right):    text representations update based on image (NEW)')
print(f'  Image\u2192Text (bottom-left):  image reads text (was cross-attention)')
print(f'  Image\u2192Image (bottom-right): image reads image (was self-attention)')
print()
print('Cross-attention provided only Image\u2192Text.')
print('Joint attention provides all four in one operation.')
print('One room, one conversation.')
print()
print('The specific weight values above are synthetic, but the quadrant')
print('structure is real: it follows directly from concatenating text and')
print('image tokens into one sequence before self-attention.')

In [None]:
# --- Step 8: Compute attention cost comparison ---
# The lesson compared the computational cost:
#   Cross-attention: self-attn [256,256] + cross-attn [256,77] = 85,248
#   Joint attention: [333,333] = 110,889
#
# Let's compute this for SD3's actual dimensions.

print('Attention cost comparison (SD3 Medium dimensions):')
print('=' * 60)
print()

# Cross-attention approach (hypothetical)
self_attn_cost = n_image_tokens * n_image_tokens
cross_attn_cost = n_image_tokens * n_text_tokens
total_cross = self_attn_cost + cross_attn_cost
print(f'Cross-attention approach (two operations):')
print(f'  Self-attention:   {n_image_tokens} x {n_image_tokens} = {self_attn_cost:,}')
print(f'  Cross-attention:  {n_image_tokens} x {n_text_tokens} = {cross_attn_cost:,}')
print(f'  Total:            {total_cross:,}')
print()

# Joint attention approach (actual)
joint_cost = n_total_tokens * n_total_tokens
print(f'Joint attention approach (one operation):')
print(f'  Joint:            {n_total_tokens} x {n_total_tokens} = {joint_cost:,}')
print()

ratio = joint_cost / total_cross
print(f'Joint / Cross ratio: {ratio:.2f}x')
print(f'Joint attention is ~{(ratio - 1) * 100:.0f}% more expensive in attention compute.')
print(f'But it replaces TWO operations with ONE, reducing other overhead.')
print(f'And it provides bidirectional text-image interaction, which cross-attention cannot.')

### What Just Happened

You explored the structure of joint attention in an MMDiT block:

- **The joint sequence is text + image tokens concatenated.** With 77 text tokens and 1024 image tokens (for 512x512), the joint sequence is 1101 tokens. The attention matrix is [1101, 1101]—every token attends to every other token.

- **Four quadrants, four types of attention.** The attention matrix naturally divides into four quadrants: text-to-text, text-to-image, image-to-text, and image-to-image. Cross-attention in the U-Net only provided image-to-text. Joint attention provides all four. This quadrant structure is an architectural property—it holds for any joint attention model, regardless of trained weights.

- **Text-to-image is the key new capability.** In cross-attention, text embeddings are fixed—they never change in response to the image. In joint attention, text tokens can read image tokens and update their representations. "A crane near a river" can be disambiguated based on what the image actually contains.

- **Modality-specific projections are visible in the architecture.** Each MMDiT block has separate Q/K/V projection layers for text and image tokens. They "speak their own language" but share the attention computation—"hear each other" in the same room.

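- **Cost is comparable.** Joint attention computes more attention entries per block than self-attention plus cross-attention: about 30% more in the lesson's example (110,889 vs 85,248 for 256 image tokens and 77 text tokens), and about 8% more for 1024 image tokens and 77 text tokens, because the image-to-image quadrant dominates as the image grows. In exchange, it replaces two operations with one and provides richer interaction.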
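
To make the "separate projections, shared attention" idea concrete, here is a minimal pure-Python sketch at toy scale. It is an illustration under stated assumptions, not SD3's implementation: the projection weights are random, the sizes (3 text tokens, 4 image tokens, width 8) are arbitrary, and helper names such as `joint_attention_weights` and `split_quadrants` are hypothetical.

```python
import math
import random


def random_matrix(rows, cols, rng):
    """Matrix of Gaussian entries, stored as a list of rows."""
    return [[rng.gauss(0.0, 1.0) for _ in range(cols)] for _ in range(rows)]


def project(weight, token):
    """Apply a linear projection (weight @ token) to one token vector."""
    return [sum(w * x for w, x in zip(row, token)) for row in weight]


def softmax(scores):
    """Numerically stable softmax over one row of attention scores."""
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def joint_attention_weights(text_tokens, image_tokens, rng):
    """Attention weights over the concatenated [text; image] sequence."""
    d = len(text_tokens[0])
    # Each modality gets its own Q and K projections (random stand-ins here).
    wq_text, wk_text = random_matrix(d, d, rng), random_matrix(d, d, rng)
    wq_image, wk_image = random_matrix(d, d, rng), random_matrix(d, d, rng)
    # Project per modality, then concatenate: text rows first, image rows second.
    queries = [project(wq_text, t) for t in text_tokens] + [project(wq_image, x) for x in image_tokens]
    keys = [project(wk_text, t) for t in text_tokens] + [project(wk_image, x) for x in image_tokens]
    scale = 1.0 / math.sqrt(d)
    # One softmax per query row over ALL keys: this is the shared attention.
    return [softmax([scale * sum(a * b for a, b in zip(q, k)) for k in keys]) for q in queries]


def split_quadrants(weights, n_text):
    """Slice the joint matrix into its four quadrants (rows = queries, columns = keys)."""
    return {
        'text_to_text': [row[:n_text] for row in weights[:n_text]],
        'text_to_image': [row[n_text:] for row in weights[:n_text]],
        'image_to_text': [row[:n_text] for row in weights[n_text:]],
        'image_to_image': [row[n_text:] for row in weights[n_text:]],
    }


rng = random.Random(0)
n_text, n_image, d = 3, 4, 8
text_tokens = random_matrix(n_text, d, rng)
image_tokens = random_matrix(n_image, d, rng)
weights = joint_attention_weights(text_tokens, image_tokens, rng)
quadrants = split_quadrants(weights, n_text)
for name, block in quadrants.items():
    print(f'{name}: {len(block)} x {len(block[0])}')
```

Each row of `weights` sums to 1 across both modalities, which is why the quadrants are slices of one attention operation rather than four separate ones.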

---

## Exercise 3: SD3 Generation and Flow Matching Steps `[Supported]`

The lesson taught that SD3 uses **flow matching** as its training objective—the same straight-line interpolation and velocity prediction you trained with in **Flow Matching** (7.2.2). The practical payoff: good results in 20–30 steps, compared to 50+ for DDPM-based models.

In this exercise, you will:
1. Generate "a cat sitting on a beach at sunset" with SD3 at different step counts
2. Compare quality across step counts to observe the flow matching payoff
3. Measure generation time at each step count
4. Compare the quality-vs-steps tradeoff to what you would expect from a DDPM model

Fill in the TODO markers to complete the comparison.

In [None]:
# ============================================================
# Exercise 3: Generate with SD3 and compare step counts
# ============================================================

# --- Step 1: Generate at different step counts ---
# The flow matching payoff: SD3 should produce good images at
# 20-30 steps. DDPM-based models typically need 50+.

prompt = 'a cat sitting on a beach at sunset'
height, width = 512, 512  # Use 512x512 for faster generation
guidance_scale = 7.0

step_counts = [10, 20, 30, 50]
images = []
titles = []

for steps in step_counts:
    generator = torch.Generator(device='cpu').manual_seed(42)
    start = time.time()

    # TODO: Generate an image using pipe() with the parameters above.
    # Use: prompt, height, width, num_inference_steps=steps,
    #      guidance_scale, generator, output_type='pil'
    # Store the result in a variable called 'result'.
    raise NotImplementedError(
        "TODO: Call pipe() with the correct arguments to generate an image."
    )

    elapsed = time.time() - start
    images.append(result.images[0])
    titles.append(f'{steps} steps\n{elapsed:.1f}s')
    print(f'  {steps} steps: {elapsed:.1f}s')

print('\nGeneration complete.')

In [None]:
# --- Step 2: Display the comparison ---

# TODO: Use show_image_row() to display the images with titles.
# Add a suptitle like 'SD3 Medium: Quality vs Step Count (flow matching)'
# and figsize=(20, 5).
raise NotImplementedError(
    "TODO: Call show_image_row(images, titles, suptitle=..., figsize=...)"
)

In [None]:
# --- Step 3: Analyze the quality-vs-steps tradeoff ---

print('Quality vs Steps Analysis:')
print('=' * 50)
print()
print('What you should observe:')
print('  10 steps:  Recognizable but may lack fine detail')
print('  20 steps:  Good quality - the flow matching sweet spot')
print('  30 steps:  High quality - diminishing returns begin')
print('  50 steps:  Marginal improvement over 30 steps')
print()
print('Compare to DDPM-based models (SD v1.5, SDXL):')
print('  10 steps:  Poor quality (curved trajectories need more steps)')
print('  20 steps:  Mediocre quality')
print('  50 steps:  Good quality (the DDPM sweet spot)')
print()
print('The difference: flow matching produces straight trajectories.')
print('Straight paths need fewer ODE solver steps to follow accurately.')
print('This is the practical payoff of "curved vs straight" from Flow Matching.')

In [None]:
# --- Step 4: Try a complex compositional prompt ---
# The lesson taught that T5-XXL provides richer text understanding
# than CLIP, especially for compositional descriptions.

complex_prompt = 'a red ball to the left of a blue cube, on a wooden table, with a green plant in the background'
simple_prompt = 'a sunset over the ocean'

prompts_to_try = [
    (simple_prompt, 'Simple prompt'),
    (complex_prompt, 'Complex compositional'),
]

comp_images = []
comp_titles = []

for p, label in prompts_to_try:
    generator = torch.Generator(device='cpu').manual_seed(42)

    # TODO: Generate an image with 28 steps (SD3's default).
    # Use the prompt 'p', height, width, guidance_scale, generator.
    raise NotImplementedError(
        "TODO: Generate an image for each prompt with 28 steps."
    )

    comp_images.append(result.images[0])
    comp_titles.append(f'{label}\n"{p[:40]}..."' if len(p) > 40 else f'{label}\n"{p}"')

show_image_row(
    comp_images, comp_titles,
    suptitle='SD3 Medium: Simple vs Compositional Prompts (T5-XXL text understanding)',
    figsize=(12, 5),
)

print('T5-XXL helps with compositional prompts that require understanding')
print('spatial relationships ("left of"), counting, and multiple objects.')
print('CLIP alone would struggle with "a red ball to the left of a blue cube"')
print('because contrastive training does not teach spatial reasoning.')

<details>
<summary>Solution</summary>

The key insight is that SD3's flow matching training objective produces straight trajectories that need fewer ODE solver steps. The pipeline API is straightforward: the only differences from SD v1.5 are the model itself and the fact that 20–30 steps produce good results instead of 50+.

**Step 1: Generate at different step counts**
```python
result = pipe(
    prompt=prompt,
    height=height,
    width=width,
    num_inference_steps=steps,
    guidance_scale=guidance_scale,
    generator=generator,
    output_type='pil',
)
```

**Step 2: Display the comparison**
```python
show_image_row(
    images, titles,
    suptitle='SD3 Medium: Quality vs Step Count (flow matching)',
    figsize=(20, 5),
)
```

**Step 4: Complex compositional prompt**
```python
result = pipe(
    prompt=p,
    height=height,
    width=width,
    num_inference_steps=28,
    guidance_scale=guidance_scale,
    generator=generator,
    output_type='pil',
)
```

**What to observe:**
- At 20–30 steps, SD3 produces images comparable in quality to DDPM models at 50+ steps. This is the practical payoff of flow matching: straight trajectories need fewer solver steps.
- The jump from 10 to 20 steps is significant; from 30 to 50 is marginal. This matches the lesson's claim that 20–30 steps is the sweet spot.
- Complex compositional prompts benefit from T5-XXL's linguistic understanding. Spatial relationships like "to the left of" are handled better than CLIP-only models would manage.

**Common mistakes:**
- Forgetting to reset the generator seed for each step count. Without a fixed seed, you cannot compare quality at different step counts because the initial noise is different.
- Using 1024x1024 resolution, which is slower and may cause OOM on smaller GPUs. 512x512 is sufficient to observe the quality-vs-steps tradeoff.

</details>

### What Just Happened

You generated images with SD3 at different step counts and observed the flow matching payoff:

- **20–30 steps is the sweet spot.** SD3 produces good quality at 20 steps and high quality at 30. This is a practical consequence of flow matching's straight trajectories—fewer ODE solver steps are needed to follow a straight path than a curved one.

- **Diminishing returns after 30 steps.** Going from 30 to 50 steps provides marginal improvement. The straight trajectories are already well-approximated at 30 steps. Compare to DDPM models where 50 steps is often the minimum for good quality.

- **Same training objective you trained with.** SD3's flow matching is the same straight-line interpolation and velocity prediction from your flow matching notebook. The concept is identical; the scale is different.

- **T5-XXL helps with complex prompts.** Compositional descriptions with spatial relationships benefit from T5's deep linguistic understanding. CLIP's contrastive training does not teach spatial reasoning or counting.
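
The "straight paths need fewer solver steps" claim can be checked on a one-dimensional toy problem. The sketch below is an idealization, not SD3's sampler: the straight case uses an exactly constant velocity, and the curved case uses the stand-in ODE `dx/dt = -3x`, chosen only because its exact solution is known. It is not the actual DDPM probability-flow ODE, and all names are hypothetical.

```python
import math


def euler_integrate(velocity, x0, n_steps, t0=0.0, t1=1.0):
    """Integrate dx/dt = velocity(x, t) from t0 to t1 with fixed-size Euler steps."""
    x, t = x0, t0
    dt = (t1 - t0) / n_steps
    for _ in range(n_steps):
        x = x + dt * velocity(x, t)
        t += dt
    return x


x_start, x_target = 2.0, -1.0


def straight_velocity(x, t):
    """Constant velocity: the path from x_start to x_target is a straight line in t."""
    return x_target - x_start


def curved_velocity(x, t):
    """Toy curved path: velocity depends on the current position."""
    return -3.0 * x


curved_exact = x_start * math.exp(-3.0)  # closed-form solution at t = 1

for n in (10, 20, 30, 50):
    straight_err = abs(euler_integrate(straight_velocity, x_start, n) - x_target)
    curved_err = abs(euler_integrate(curved_velocity, x_start, n) - curved_exact)
    print(f'{n:>2} steps  straight error: {straight_err:.2e}  curved error: {curved_err:.2e}')
```

With a constant velocity, Euler lands on the target up to floating-point rounding at any step count, while the curved toy path only improves gradually as steps are added. A trained model's velocity field is only approximately straight, so real samplers still benefit from more than a handful of steps.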

---

In [None]:
# --- Cleanup before Exercise 4 ---
# Keep the pipeline loaded for Exercise 4.
# Free any intermediate tensors.
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()
print('Ready for Exercise 4.')

## Exercise 4: The Convergence Pipeline Trace `[Independent]`

The lesson concluded with the full SD3 pipeline traced end-to-end, every step annotated with the lesson that covered it:

```
 1. Prompt -> CLIP ViT-L text encoder [77, 768]           (CLIP)
 2. Prompt -> OpenCLIP ViT-bigG text encoder [77, 1280]    (SDXL)
 3. Prompt -> T5-XXL text encoder [77, 4096]               (this lesson)
 4. CLIP pooled embeddings + timestep -> c for adaLN-Zero   (Diffusion Transformers)
 5. Per-token embeddings projected -> text tokens [77, d]    (this lesson)
 6. Noisy latent z_t [16, 64, 64] -> patchify -> [1024, d] (Diffusion Transformers)
 7. Concatenate: [77 + 1024, d] = [1101, d]                 (this lesson: MMDiT)
 8. N MMDiT blocks with joint attention + adaLN-Zero         (this lesson + Diffusion Transformers)
 9. Split output -> image tokens [1024, d]                   (this lesson)
10. Unpatchify -> [16, 64, 64]                              (Diffusion Transformers)
11. Flow matching sampling step                              (Flow Matching)
12. Repeat steps 6-11 for ~28 steps                         (Flow Matching: fewer steps needed)
13. VAE decode z_0 -> [3, 512, 512]                         (From Pixels to Latents)
```
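
The shape arithmetic in this trace can be reproduced with a few lines of integer math. This is a sketch under assumptions: `vae_scale=8` and `patch_size=2` are the values you would expect to read from `pipe.vae_scale_factor` and `pipe.transformer.config.patch_size`, `n_text_tokens=77` follows the trace above (the real value is `prompt_embeds.shape[1]`), and the function name is hypothetical.

```python
def sd3_token_counts(height, width, n_text_tokens=77, vae_scale=8, patch_size=2):
    """Token counts for one denoising step, from pixel size down to the joint sequence."""
    # VAE downsampling: pixels -> latent grid.
    latent_h, latent_w = height // vae_scale, width // vae_scale
    # Patchify: each patch_size x patch_size latent patch becomes one image token.
    n_image_tokens = (latent_h // patch_size) * (latent_w // patch_size)
    # Joint sequence: text tokens concatenated with image tokens.
    n_total = n_text_tokens + n_image_tokens
    return {
        'latent_hw': (latent_h, latent_w),
        'image_tokens': n_image_tokens,
        'joint_tokens': n_total,
        'attention_entries': n_total * n_total,
    }


for size in (512, 1024):
    print(size, sd3_token_counts(size, size))
```

Doubling the image side quadruples the image token count, so the attention matrix grows roughly sixteenfold. That is one reason the exercises run at 512x512.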

### Your Task

1. **Generate an image with SD3**, capturing intermediate outputs at each stage
2. **Trace the pipeline step by step**: text encoding shapes, patchify token count, denoising step count, VAE decode output shape
3. **For each step, print which lesson covered the relevant concept**
4. **Compare the SD3 pipeline to what you know about the SD v1.5 pipeline**: what changed and what is preserved?

### Hints

- Text encoding: use `pipe.encode_prompt()` to get prompt embeddings and pooled embeddings
- VAE encoding: use `pipe.vae.encode()` on a dummy image to verify latent shapes
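- Transformer config: `pipe.transformer.config` has `patch_size`, `in_channels`, `num_layers`, `joint_attention_dim` (the width of the incoming text embeddings), and `caption_projection_dim` (the model width the text tokens are projected to)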
- VAE config: `pipe.vae.config` has `scaling_factor`, `latent_channels`
- The pipeline's scheduler: `pipe.scheduler` (check `type(pipe.scheduler).__name__` for the flow matching scheduler)
- Use `pipe.vae_scale_factor` for the VAE downsampling factor
- To capture the latents at each denoising step, pass a function via the `callback_on_step_end` parameter (decode the captured latents with the VAE if you want to view them as images)
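
As a starting point for the last hint, here is a minimal sketch of a step-end callback that records latent shapes. It assumes the diffusers convention that the callback receives `(pipeline, step, timestep, callback_kwargs)` and must return `callback_kwargs`. The names `make_latent_recorder` and `FakeLatents` are hypothetical, and `FakeLatents` only stands in for a tensor so the block runs without a GPU.

```python
def make_latent_recorder():
    """Return (callback, records); the callback logs the latent shape at each step."""
    records = []

    def on_step_end(pipeline, step, timestep, callback_kwargs):
        latents = callback_kwargs['latents']
        records.append({
            'step': step,
            'timestep': float(timestep),
            'shape': tuple(latents.shape),
        })
        # The pipeline expects the (possibly modified) kwargs dict back.
        return callback_kwargs

    return on_step_end, records


class FakeLatents:
    """Stand-in for a latent tensor; only the .shape attribute is used."""
    shape = (1, 16, 64, 64)


on_step_end, records = make_latent_recorder()
for step, timestep in enumerate([1000.0, 500.0, 0.0]):
    on_step_end(None, step, timestep, {'latents': FakeLatents()})
print(records)
```

With the notebook's loaded pipeline, the same callback would be passed as `pipe(..., callback_on_step_end=on_step_end, callback_on_step_end_tensor_inputs=['latents'])`, after which `records` holds one entry per denoising step.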

In [None]:
# ============================================================
# Exercise 4: Trace the full SD3 pipeline end-to-end
# ============================================================
#
# Trace each stage of the SD3 pipeline, printing shapes and
# annotating which lesson covered each concept.
#
# Your code here:



In [None]:
# --- Print the annotated pipeline trace ---
#
# After capturing the shapes and configs above, print
# the full pipeline trace with lesson annotations.
#
# Your code here:



In [None]:
# --- Compare SD3 pipeline to SD v1.5 pipeline ---
#
# Print a comparison showing what changed and what
# is preserved between the two architectures.
#
# Your code here:



<details>
<summary>Solution</summary>

The convergence theme is the key insight: every component of SD3 traces to a lesson you completed. The pipeline structure is preserved from SD v1.5—only the components inside each stage have evolved.

```python
# ============================================================
# Part 1: Trace each stage with shapes and annotations
# ============================================================

prompt = 'a cat sitting on a beach at sunset'
height, width = 512, 512

print('SD3 Pipeline Trace')
print('=' * 70)
print()

# Stage 1-3: Text Encoding
print('STAGE 1-3: TEXT ENCODING')
print('-' * 70)
with torch.no_grad():
    prompt_embeds, neg_embeds, pooled_embeds, neg_pooled = pipe.encode_prompt(
        prompt=prompt, prompt_2=prompt, prompt_3=prompt,
    )

print(f'  1. CLIP ViT-L:       [77, 768]   -> Lesson: CLIP (6.3.3)')
print(f'  2. OpenCLIP ViT-bigG: [77, 1280] -> Lesson: SDXL (7.4.1)')
print(f'  3. T5-XXL:           [77, 4096]  -> Lesson: SD3 & Flux (this lesson)')
print(f'  Combined prompt_embeds: {list(prompt_embeds.shape)}')
print(f'  Pooled for adaLN-Zero:  {list(pooled_embeds.shape)}')
print()

# Stage 4: Global conditioning
print('STAGE 4: GLOBAL CONDITIONING')
print('-' * 70)
print(f'  4. Pooled CLIP + timestep -> adaLN-Zero conditioning')
print(f'     Pooled shape: {list(pooled_embeds.shape)}')
print(f'     -> Lesson: Diffusion Transformers (7.4.2)')
print()

# Stage 5: Text token projection
print('STAGE 5: TEXT TOKEN PROJECTION')
print('-' * 70)
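# joint_attention_dim is the width of the incoming text embeddings;
# caption_projection_dim is the model width they are projected to.
d_model = pipe.transformer.config.caption_projection_dim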
n_text = prompt_embeds.shape[1]
print(f'  5. Per-token embeddings -> projected to d={d_model}')
print(f'     Text tokens: [{n_text}, {d_model}]')
print(f'     -> Lesson: SD3 & Flux (this lesson)')
print()

# Stage 6: Patchify
print('STAGE 6: PATCHIFY')
print('-' * 70)
vae_scale = pipe.vae_scale_factor
latent_h = height // vae_scale
latent_w = width // vae_scale
latent_c = pipe.transformer.config.in_channels
patch_size = pipe.transformer.config.patch_size
n_patches = (latent_h // patch_size) * (latent_w // patch_size)
print(f'  6. Noisy latent [{latent_c}, {latent_h}, {latent_w}] -> patchify -> [{n_patches}, {d_model}]')
print(f'     Patch size: {patch_size}, Patches: ({latent_h}//{patch_size}) x ({latent_w}//{patch_size}) = {n_patches}')
print(f'     -> Lesson: Diffusion Transformers (7.4.2)')
print()

# Stage 7: Concatenate
print('STAGE 7: CONCATENATE (JOINT SEQUENCE)')
print('-' * 70)
n_total = n_text + n_patches
print(f'  7. Concatenate: [{n_text} + {n_patches}, {d_model}] = [{n_total}, {d_model}]')
print(f'     -> Lesson: SD3 & Flux (this lesson: MMDiT)')
print()

# Stage 8: MMDiT blocks
print('STAGE 8: MMDiT BLOCKS')
print('-' * 70)
n_blocks = pipe.transformer.config.num_layers
print(f'  8. {n_blocks} MMDiT blocks: joint attention + adaLN-Zero')
print(f'     Attention matrix: [{n_total}, {n_total}] per block')
print(f'     -> Lesson: SD3 & Flux + Diffusion Transformers (7.4.2)')
print()

# Stage 9-10: Split and unpatchify
print('STAGE 9-10: SPLIT AND UNPATCHIFY')
print('-' * 70)
print(f'  9. Split: image tokens [{n_patches}, {d_model}]')
print(f'     -> Lesson: SD3 & Flux (this lesson)')
print(f' 10. Unpatchify: [{n_patches}, {d_model}] -> [{latent_c}, {latent_h}, {latent_w}]')
print(f'     -> Lesson: Diffusion Transformers (7.4.2)')
print()

# Stage 11-12: Flow matching sampling
print('STAGE 11-12: FLOW MATCHING SAMPLING')
print('-' * 70)
scheduler_name = type(pipe.scheduler).__name__
print(f' 11. Flow matching sampling step (scheduler: {scheduler_name})')
print(f'     -> Lesson: Flow Matching (7.2.2)')
print(f' 12. Repeat for ~28 steps (fewer steps due to straight trajectories)')
print(f'     -> Lesson: Flow Matching (7.2.2)')
print()

# Stage 13: VAE decode
print('STAGE 13: VAE DECODE')
print('-' * 70)
print(f' 13. VAE decode: [{latent_c}, {latent_h}, {latent_w}] -> [3, {height}, {width}]')
print(f'     -> Lesson: From Pixels to Latents (6.3.5)')
print()
print('=' * 70)
print('Every step traces to a lesson you completed.')
print('Nothing in this pipeline is unexplained.')
```

```python
# ============================================================
# Part 2: Compare SD3 vs SD v1.5 pipeline
# ============================================================

print('SD3 vs SD v1.5 Pipeline Comparison')
print('=' * 70)
print()
print(f'{"Component":<25} {"SD v1.5":<25} {"SD3":<25}')
print('-' * 70)
print(f'{"Text encoders":<25} {"1 (CLIP ViT-L)":<25} {"3 (CLIP + OpenCLIP + T5)":<25}')
print(f'{"Text embedding dims":<25} {"[77, 768]":<25} {"[tokens, 4096]":<25}')
print(f'{"Denoising backbone":<25} {"U-Net":<25} {"MMDiT (transformer)":<25}')
print(f'{"Text conditioning":<25} {"Cross-attention":<25} {"Joint attention":<25}')
print(f'{"Text flow direction":<25} {"Image reads text":<25} {"Bidirectional":<25}')
print(f'{"Timestep conditioning":<25} {"AdaGN":<25} {"adaLN-Zero":<25}')
print(f'{"Training objective":<25} {"DDPM noise prediction":<25} {"Flow matching velocity":<25}')
print(f'{"Inference steps":<25} {"50+":<25} {"20-30":<25}')
print(f'{"VAE":<25} {"4 latent channels":<25} {"16 latent channels":<25}')
print(f'{"Sampling loop":<25} {"Same structure":<25} {"Same structure":<25}')
print(f'{"Output resolution":<25} {"512x512":<25} {"1024x1024":<25}')
print()
print('What CHANGED:')
print('  - Architecture: U-Net -> MMDiT (transformer on patch tokens)')
print('  - Text conditioning: cross-attention -> joint self-attention')
print('  - Text encoders: 1 CLIP -> 3 encoders (CLIP + OpenCLIP + T5-XXL)')
print('  - Training objective: DDPM -> flow matching')
print()
print('What is PRESERVED:')
print('  - Latent space generation (VAE encode/decode)')
print('  - Iterative denoising loop (still predict and step)')
print('  - Classifier-free guidance (still works the same way)')
print('  - Pipeline structure: encode text -> denoise latent -> decode to pixels')
print()
print('The lesson called this "same pipeline, different denoising network."')
print('The pipeline structure survived every architectural change.')
```

**Key observations:**
- The pipeline structure is preserved from SD v1.5 to SD3. Text encode, denoise in latent space, VAE decode. The components inside each stage evolved, but the pipeline did not.
- Every component of SD3 traces to a specific lesson. Transformers from Series 4. Latent diffusion from Series 6. Flow matching from Module 7.2. DiT/patchify/adaLN-Zero from the previous lesson. T5-XXL and MMDiT from this lesson.
- The convergence is literal: SD3 is the combination of concepts taught across 50+ lessons. No single component is new in isolation—the innovation is the combination.

**Common mistakes:**
- Not distinguishing between what changed (architecture, conditioning mechanism, training objective, text encoders) and what is preserved (VAE, sampling loop structure, guidance). The lesson emphasizes that the pipeline structure survived.
- Forgetting that the pooled CLIP embedding provides adaLN-Zero conditioning (global path) while the per-token embeddings provide joint attention conditioning (per-token path). Both paths are needed.

</details>

---

## Key Takeaways

1. **Three encoders, three kinds of understanding.** CLIP ViT-L (123M, visual alignment), OpenCLIP ViT-bigG (~694M, richer visual alignment), T5-XXL (4.7B, deep linguistic understanding). T5's embeddings are about 5x wider than CLIP ViT-L's (4096 vs 768). The text encoders collectively have more parameters than the denoising network, a sign that text understanding is worth investing in.

2. **Joint attention is one room, one conversation.** Concatenate text tokens and image tokens, run standard self-attention on the combined sequence. Four types of attention in one operation: image-to-text (was cross-attention), text-to-image (NEW), image-to-image (was self-attention), text-to-text (NEW). Simpler than cross-attention (one operation instead of two) and richer (bidirectional).

3. **Modality-specific projections, shared attention.** Text and image tokens have separate Q/K/V projections and separate FFN layers. They "speak their own language" but "hear each other" through shared attention. This is not naive concatenation—each modality maintains its representational identity.

4. **Flow matching delivers in practice.** SD3 produces good results at 20–30 steps, compared to 50+ for DDPM-based models. Same straight-line interpolation and velocity prediction you trained with. The concept is identical; the scale is different.

5. **Convergence, not revolution.** Every component of the SD3 pipeline traces to a lesson you completed: transformers (Series 4), latent diffusion (Series 6), CLIP (6.3.3), flow matching (7.2.2), patchify and adaLN-Zero (7.4.2), MMDiT and T5-XXL (this lesson). The frontier is not beyond your understanding—it IS your understanding, combined.