# Diffusion Transformers (DiT)

**Module 7.4, Lesson 2** | CourseAI

DiT replaces the U-Net with a standard vision transformer operating on latent patches. Two threads you already know well converge here: transformers from Series 4 and latent diffusion from Series 6. Every component of the DiT block is familiar except the conditioning mechanism (adaLN-Zero).

**What you will do:**
- Implement patchify and unpatchify from scratch, verify tensor shapes at every step, and confirm the round-trip recovers the original latent
- Build one adaLN-Zero conditioning step, verify the identity property at alpha=0, and watch the block's contribution grow as alpha increases
- Load a pretrained DiT model and a U-Net, compare their architectures side-by-side: parameter counts, layer types, and the adaLN-Zero components
- Generate class-conditional ImageNet images with DiT-XL/2 at different classifier-free guidance scales

**For each exercise, PREDICT the output before running the cell.**

Every concept in this notebook comes from the lesson: patchify as "tokenize the image," adaLN-Zero as adaptive norm with a zero-initialized gate, and the two-knob scaling recipe. There is no new theory here, only hands-on verification of what you just read.

**Estimated time:** 45–60 minutes. Exercises 1–2 are pure PyTorch (no GPU needed). Exercises 3–4 use pretrained models and benefit from a GPU runtime (~3 GB VRAM for DiT-XL/2 in float16).

## Setup

Run this cell to install dependencies and configure the environment.

**Important:** For Exercises 3–4, switch to a GPU runtime in Colab (Runtime > Change runtime type > T4 GPU). Exercises 1–2 work on CPU.

In [None]:
!pip install -q diffusers transformers accelerate safetensors timm

In [None]:
import torch
import torch.nn as nn
import math
import gc
import time
import matplotlib.pyplot as plt
from IPython.display import display

# Reproducible results
torch.manual_seed(42)

# Device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
dtype = torch.float16 if torch.cuda.is_available() else torch.float32

# Nice plots
plt.style.use('dark_background')
plt.rcParams['figure.figsize'] = [14, 5]
plt.rcParams['figure.dpi'] = 100

print(f'Device: {device}')
print(f'Dtype: {dtype}')
if device.type == 'cpu':
    print('Note: No GPU detected. Exercises 1-2 work fine on CPU.')
    print('For Exercises 3-4, switch to GPU: Runtime > Change runtime type > T4 GPU')
print()
print('Setup complete.')

## Shared Helpers

Utility functions used across multiple exercises. Run this cell now.

In [None]:
def count_parameters(model):
    """Count total parameters in a model."""
    return sum(p.numel() for p in model.parameters())


def count_parameters_by_type(model):
    """Count parameters grouped by module type (e.g., Linear, LayerNorm)."""
    counts = {}
    for name, module in model.named_modules():
        module_type = module.__class__.__name__
        if module is model:
            continue  # skip the top-level model itself (not submodules sharing its class name)
        n_params = sum(p.numel() for p in module.parameters(recurse=False))
        if n_params > 0:
            counts[module_type] = counts.get(module_type, 0) + n_params
    return dict(sorted(counts.items(), key=lambda x: -x[1]))


def show_image_row(images, titles, suptitle=None, figsize=None):
    """Display a row of PIL images with titles."""
    n = len(images)
    fig_w = figsize[0] if figsize else max(5 * n, 12)
    fig_h = figsize[1] if figsize else 5
    fig, axes = plt.subplots(1, n, figsize=(fig_w, fig_h))
    if n == 1:
        axes = [axes]
    for ax, img, title in zip(axes, images, titles):
        ax.imshow(img)
        ax.set_title(title, fontsize=10)
        ax.axis('off')
    if suptitle:
        plt.suptitle(suptitle, fontsize=13, y=1.02)
    plt.tight_layout()
    plt.show()


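def free_memory(*objects):
    """Run garbage collection and release cached GPU memory.

    Note: `del` inside this function would only remove its own local names,
    not the caller's variables, so delete your references first
    (e.g. `del model`). Any objects passed in are ignored.
    """
    del objects  # drop this function's references to anything passed in
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    print('Memory freed.')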


print('Helpers defined: count_parameters, count_parameters_by_type, show_image_row, free_memory')

---

## Exercise 1: Patchify and Unpatchify `[Guided]`

The lesson taught that DiT "tokenizes the image": take the noisy latent `[C, H, W]`, split it into non-overlapping patches of size `p x p`, flatten each patch, and project to `d_model` dimensions. This is the image equivalent of text tokenization.

The reverse operation—unpatchify—takes the transformer output sequence back to a spatial tensor. Together, patchify and unpatchify are the bridge between the latent diffusion world (spatial tensors) and the transformer world (token sequences).

In this exercise, you will implement both operations from scratch and verify:
1. The tensor shapes match the lesson's trace at every step
2. The round-trip patchify → unpatchify recovers the original tensor (before projection)

**Before running, predict:**
- A latent of shape `[4, 32, 32]` with patch size `p=2`: how many patches? What is each patch's raw dimension before projection?
- A latent of shape `[4, 64, 64]` with patch size `p=4`: how many patches?

In [None]:
# ============================================================
# Exercise 1: Implement patchify and unpatchify from scratch
# ============================================================

# --- Step 1: Define dimensions ---
C = 4        # latent channels (VAE output)
H = 32       # latent height (e.g., 256x256 image with 8x VAE downsampling)
W = 32       # latent width
p = 2        # patch size
d_model = 1152  # DiT-XL hidden dimension

# Create a random "noisy latent" tensor
z = torch.randn(1, C, H, W)  # batch=1 for clarity
print(f'Input noisy latent: {list(z.shape)}')
print()

In [None]:
# --- Step 2: Patchify ---
# Split the spatial dimensions into patches of size p x p.
# Each patch is a [C, p, p] volume, flattened to C*p*p dimensions.

num_patches_h = H // p   # patches along height
num_patches_w = W // p   # patches along width
num_patches = num_patches_h * num_patches_w  # total patch tokens
patch_dim = C * p * p    # raw dimension per patch before projection

print(f'Patches along height: {num_patches_h}')
print(f'Patches along width:  {num_patches_w}')
print(f'Total patches (L):    {num_patches}')
print(f'Raw patch dim:        C * p * p = {C} * {p} * {p} = {patch_dim}')
print()

# Reshape: [1, C, H, W] -> [1, C, num_patches_h, p, num_patches_w, p]
patches = z.reshape(1, C, num_patches_h, p, num_patches_w, p)

# Permute to group patch spatial dims together:
# [1, C, num_patches_h, p, num_patches_w, p]
# -> [1, num_patches_h, num_patches_w, C, p, p]
patches = patches.permute(0, 2, 4, 1, 3, 5)

# Flatten the patch grid into a sequence, and flatten each patch:
# [1, num_patches_h, num_patches_w, C, p, p] -> [1, L, C*p*p]
patches = patches.reshape(1, num_patches, patch_dim)

print(f'After patchify: {list(patches.shape)}')
print(f'  = [batch, {num_patches} tokens, {patch_dim} dims]')
print()
print('This matches the lesson\'s trace:')
print(f'  [4, 32, 32] -> [{num_patches} tokens, {patch_dim} dims]')

In [None]:
# --- Step 3: Linear projection to d_model ---
# This is the equivalent of the token embedding lookup in text transformers.
# A learned nn.Linear maps from patch_dim to d_model.

proj = nn.Linear(patch_dim, d_model)

with torch.no_grad():
    tokens = proj(patches)  # [1, L, d_model]

print(f'After linear projection: {list(tokens.shape)}')
print(f'  = [batch, {num_patches} tokens, {d_model} dims]')
print()
print(f'The transformer now has a sequence of {num_patches} tokens, each with {d_model} dimensions.')
print('It does not know or care that these tokens came from image patches.')

In [None]:
# --- Step 4: Add positional embeddings ---
# Same concept as positional encoding from Series 4: one vector per patch position.
# The reference DiT uses fixed (non-learned) 2D sine-cosine embeddings.
# Here a small random tensor stands in, since only the shape matters for this trace.

pos_embed = torch.randn(1, num_patches, d_model) * 0.02

tokens_with_pos = tokens + pos_embed

print(f'Positional embeddings shape: {list(pos_embed.shape)}')
print(f'Tokens + position:           {list(tokens_with_pos.shape)}')
print()
print(f'Final input to transformer:  [{num_patches}, {d_model}]')
print('(Same as the lesson\'s trace: [256, 1152])')

In [None]:
# --- Step 5: Unpatchify (reverse the operation) ---
# After N transformer blocks, the model has [L, d_model] tokens.
# Project back to patch dimensions, then reshape to spatial grid.

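# Stand-in for the transformer output: reuse the pre-projection patches so the
# round-trip tests only the reshape logic. (The learned input projection and the
# separate output projection are independent layers, not inverses of each other.)
transformer_output = patches  # [1, L, patch_dim]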

# In the real DiT: nn.Linear(d_model, patch_dim) projects back to patch dims.
# We skip projection here to verify exact round-trip.

# Reshape: [1, L, C*p*p] -> [1, num_patches_h, num_patches_w, C, p, p]
spatial = transformer_output.reshape(1, num_patches_h, num_patches_w, C, p, p)

# Permute back: [1, num_patches_h, num_patches_w, C, p, p]
# -> [1, C, num_patches_h, p, num_patches_w, p]
spatial = spatial.permute(0, 3, 1, 4, 2, 5)

# Reshape to original spatial dimensions: -> [1, C, H, W]
reconstructed = spatial.reshape(1, C, H, W)

print(f'Unpatchified output: {list(reconstructed.shape)}')
print()

# Verify round-trip: patchify -> unpatchify should recover the original
match = torch.equal(z, reconstructed)  # reshape/permute only move values, so equality is exact
max_diff = (z - reconstructed).abs().max().item()
print(f'Round-trip exact match: {match}')
print(f'Max absolute difference: {max_diff}')
print()
print('Patchify -> unpatchify is lossless (ignoring the learned projections).')
print('The spatial information is preserved through the reshape operations.')

In [None]:
# --- Step 6: Try different latent sizes and patch sizes ---
# The lesson showed: L = (H/p) * (W/p)
# Smaller patch size -> more tokens -> finer spatial detail, but attention cost grows as O(L^2).

print('Sequence length L = (H/p) * (W/p):')
print()
print(f'{"Latent":<14} {"Patch":<8} {"Tokens (L)":<12} {"Raw patch dim":<16} {"Attention cost (L^2)"}')
print('-' * 72)

configs = [
    (4, 32, 32, 2),   # DiT-XL/2 on 256x256
    (4, 32, 32, 4),   # DiT-XL/4 on 256x256
    (4, 32, 32, 8),   # DiT-XL/8 on 256x256
    (4, 64, 64, 2),   # DiT on 512x512
    (4, 64, 64, 4),   # DiT on 512x512
]

for c, h, w, ps in configs:
    L = (h // ps) * (w // ps)
    raw_dim = c * ps * ps
    attn_cost = L * L  # attention scales as O(L^2)
    latent_str = f'[{c},{h},{w}]'
    patch_str = f'p={ps}'
    dim_str = f'{c}*{ps}*{ps}={raw_dim}'
    print(f'{latent_str:<14} {patch_str:<8} {L:<12} {dim_str:<16} {attn_cost:,}')

print()
print('Key insight: halving patch size quadruples tokens AND increases')
print('attention cost by 16x (because attention is O(L^2)).')
print('This is the same tradeoff from SDXL, now controlled by patch size.')

### What Just Happened

You implemented the patchify and unpatchify operations from scratch and verified the full shape trace from the lesson:

```
Noisy latent:      [4, 32, 32]
After patchify:    [256, 16]     (256 tokens, each 4*2*2=16 dims)
After projection:  [256, 1152]   (projected to d_model)
After pos embed:   [256, 1152]   (same shape, position info added)
... N transformer blocks ...
After unpatchify:  [4, 32, 32]   (back to spatial tensor)
```
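
The trace above uses batch size 1 and inline reshapes. Below is a minimal sketch of the same two operations as reusable, batched functions. The names `patchify`/`unpatchify` and their signatures are illustrative, not from a library. Packaging them this way makes the round-trip easy to test on any latent size, including non-square ones:

```python
import torch


def patchify(x: torch.Tensor, p: int) -> torch.Tensor:
    """[B, C, H, W] -> [B, L, C*p*p] with L = (H/p) * (W/p), patches in row-major order."""
    B, C, H, W = x.shape
    assert H % p == 0 and W % p == 0, 'H and W must be divisible by the patch size'
    x = x.reshape(B, C, H // p, p, W // p, p)
    x = x.permute(0, 2, 4, 1, 3, 5)            # [B, H/p, W/p, C, p, p]
    return x.reshape(B, (H // p) * (W // p), C * p * p)


def unpatchify(seq: torch.Tensor, p: int, C: int, H: int, W: int) -> torch.Tensor:
    """[B, L, C*p*p] -> [B, C, H, W]; exact inverse of patchify."""
    B = seq.shape[0]
    x = seq.reshape(B, H // p, W // p, C, p, p)
    x = x.permute(0, 3, 1, 4, 2, 5)            # [B, C, H/p, p, W/p, p]
    return x.reshape(B, C, H, W)


latents = torch.randn(2, 4, 32, 32)
seq = patchify(latents, p=2)
print(seq.shape)                                              # torch.Size([2, 256, 16])
print(torch.equal(unpatchify(seq, 2, 4, 32, 32), latents))    # True

# Non-square latent: (32/4) * (64/4) = 128 tokens of 4*4*4 = 64 dims
print(patchify(torch.randn(1, 4, 32, 64), p=4).shape)         # torch.Size([1, 128, 64])
```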

Key observations:

- **Patchify is reshaping, not convolution.** No learned parameters in the reshape itself. The learning happens in the linear projection (`nn.Linear(16, 1152)`).

- **The round-trip is exact.** Patchify followed by unpatchify (without projection) perfectly recovers the original tensor. No spatial information is lost—it is just rearranged.

- **Patch size is the resolution knob.** `p=2` on a `[4, 32, 32]` latent gives 256 tokens. `p=4` gives 64 tokens (4x fewer). `p=8` gives 16 tokens (16x fewer). Halving the patch size quadruples the tokens and increases attention cost by 16x. Same quadratic tradeoff from SDXL.

- **"Tokenize the image" is literal.** The patchify + linear projection is structurally identical to tokenization + embedding lookup in text transformers. The transformer processes both kinds of sequences the same way.
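
A note on the positional-embedding step. The reference DiT implementation does not use a learned table. It uses fixed 2D sine-cosine embeddings in the ViT style. The sketch below builds such a table under stated assumptions:

- Half of the channels encode the patch's row and half encode its column.
- Each half uses the standard 1D sin/cos frequency scheme.
- Rows follow the same row-major patch order as patchify.

The exact channel ordering in the reference code may differ, and the helper names `sincos_1d`/`sincos_2d` are hypothetical.

```python
import torch


def sincos_1d(dim: int, positions: torch.Tensor) -> torch.Tensor:
    """Standard 1D sine-cosine embedding: [N] integer positions -> [N, dim]."""
    assert dim % 2 == 0
    freqs = 1.0 / (10000 ** (torch.arange(dim // 2, dtype=torch.float64) / (dim // 2)))
    angles = positions.to(torch.float64)[:, None] * freqs[None, :]   # [N, dim/2]
    return torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)


def sincos_2d(dim: int, grid_h: int, grid_w: int) -> torch.Tensor:
    """Fixed embedding for a grid_h x grid_w patch grid -> [grid_h * grid_w, dim].
    First half of the channels encodes the row index, second half the column index."""
    assert dim % 4 == 0
    rows = torch.arange(grid_h).repeat_interleave(grid_w)   # 0,0,...,1,1,... (row-major)
    cols = torch.arange(grid_w).repeat(grid_h)              # 0,1,...,0,1,...
    emb = torch.cat([sincos_1d(dim // 2, rows), sincos_1d(dim // 2, cols)], dim=1)
    return emb.float()


# 32x32 latent with p=2 -> 16x16 patch grid -> 256 positions
fixed_pos = sincos_2d(1152, 16, 16)
print(fixed_pos.shape)   # torch.Size([256, 1152])
```

Because nothing here is learned, the table can be rebuilt for any grid size. The model only has to learn how to *use* position information.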

---

## Exercise 2: adaLN-Zero Forward Pass `[Guided]`

The lesson taught that DiT conditions each transformer block via adaLN-Zero: the conditioning vector `c` (timestep + class embedding) is projected through an MLP to produce six parameters per block:

```
c -> MLP -> (γ₁, β₁, α₁, γ₂, β₂, α₂)

MHA sub-layer:  x' = x + α₁ * MHA(γ₁ * LayerNorm(x) + β₁)
FFN sub-layer:  output = x' + α₂ * FFN(γ₂ * LayerNorm(x') + β₂)
```

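The critical design choice: the MLP's output layer is **initialized to zero**, so all six parameters, including both alpha gates, start at zero. This means every DiT block starts as an identity function: input in, same input out. The model gradually learns what each block should contribute. (In the reference DiT code the scale is applied as `(1 + γ)` rather than `γ`, so a zero γ leaves the LayerNorm output unchanged instead of zeroing it.)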
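
Why can a block that starts as an exact identity learn anything? The sketch below checks the gradients at initialization. It uses toy sizes, a single `nn.Linear` as a stand-in for MHA or the FFN, and the `(1 + γ)` scale convention of the reference DiT code. Both the sub-layer and the γ/β rows of the adaLN layer get zero gradient, because they sit behind a zero gate. The α rows do get gradient, so the gates move first and then open the path for everything else.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
d = 8
x_toy = torch.randn(2, 4, d)      # [batch, tokens, d]
c_toy = torch.randn(2, d)         # one conditioning vector per sample

adaln = nn.Linear(d, 3 * d)       # -> (gamma, beta, alpha) for one sub-layer
nn.init.zeros_(adaln.weight)
nn.init.zeros_(adaln.bias)
sublayer = nn.Linear(d, d)        # stand-in for MHA or FFN
norm = nn.LayerNorm(d, elementwise_affine=False)

gamma, beta, alpha = adaln(F.silu(c_toy)).unsqueeze(1).chunk(3, dim=-1)
h = (1 + gamma) * norm(x_toy) + beta
out = x_toy + alpha * sublayer(h)
out.pow(2).sum().backward()       # any loss works for this check

print(torch.equal(out, x_toy))                              # True: identity at init
print(sublayer.weight.grad.abs().max().item())              # 0.0: nothing flows past a zero gate
print(adaln.weight.grad[2 * d:].abs().max().item() > 0)     # True: alpha rows get gradient
print(adaln.weight.grad[:2 * d].abs().max().item())         # 0.0: gamma/beta rows wait
```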

In this exercise, you will:
1. Build the adaLN-Zero MLP that produces the six parameters
2. Apply adaptive layer norm (scale + shift) to a LayerNorm output
3. Apply the gated residual connection
4. Verify the identity property at alpha=0
5. Watch the block's contribution grow as alpha increases

**Before running, predict:**
- When alpha=0, what will the output of the block be? (Think about the residual: `x + alpha * f(x)`.)
- When alpha=1, will the output be exactly `x + f(x)` (the standard residual connection)?

In [None]:
# ============================================================
# Exercise 2: Build adaLN-Zero from scratch
# ============================================================

# --- Step 1: Define dimensions ---
d_model = 384    # Use DiT-S dimension for speed
seq_len = 16     # Small sequence for clarity (16 patch tokens)
cond_dim = 384   # Conditioning vector dimension (same as d_model in DiT)

print(f'd_model:  {d_model}')
print(f'seq_len:  {seq_len} (patch tokens)')
print(f'cond_dim: {cond_dim}')
print()

In [None]:
# --- Step 2: Build the adaLN-Zero MLP ---
# This MLP takes the conditioning vector c and produces 6 * d_model parameters:
# (gamma_1, beta_1, alpha_1, gamma_2, beta_2, alpha_2)
# Each is a vector of d_model dimensions.

adaln_mlp = nn.Sequential(
    nn.SiLU(),
    nn.Linear(cond_dim, 6 * d_model),
)

# Critical: initialize the final linear layer to output zeros.
# This ensures all six parameter vectors start at zero.
nn.init.zeros_(adaln_mlp[1].weight)
nn.init.zeros_(adaln_mlp[1].bias)

print(f'adaLN MLP output dimension: {6 * d_model} = 6 * {d_model}')
print(f'  = gamma_1[{d_model}] + beta_1[{d_model}] + alpha_1[{d_model}]')
print(f'  + gamma_2[{d_model}] + beta_2[{d_model}] + alpha_2[{d_model}]')
print()
print(f'MLP parameters: {count_parameters(adaln_mlp):,}')

In [None]:
# --- Step 3: Produce the six conditioning parameters ---
# Feed a random conditioning vector through the MLP.

# Simulate conditioning: c = timestep_embedding + class_embedding
c = torch.randn(1, cond_dim)

with torch.no_grad():
    params = adaln_mlp(c)  # [1, 6 * d_model]

# Split into the six parameter vectors
gamma_1, beta_1, alpha_1, gamma_2, beta_2, alpha_2 = params.chunk(6, dim=-1)

print('Six adaLN-Zero parameters (all initialized to ~zero):')
print(f'  gamma_1: mean={gamma_1.mean().item():.6f}, max_abs={gamma_1.abs().max().item():.6f}')
print(f'  beta_1:  mean={beta_1.mean().item():.6f}, max_abs={beta_1.abs().max().item():.6f}')
print(f'  alpha_1: mean={alpha_1.mean().item():.6f}, max_abs={alpha_1.abs().max().item():.6f}')
print(f'  gamma_2: mean={gamma_2.mean().item():.6f}, max_abs={gamma_2.abs().max().item():.6f}')
print(f'  beta_2:  mean={beta_2.mean().item():.6f}, max_abs={beta_2.abs().max().item():.6f}')
print(f'  alpha_2: mean={alpha_2.mean().item():.6f}, max_abs={alpha_2.abs().max().item():.6f}')
print()
print('All values are near zero because we initialized the MLP output layer to zeros.')
print('This is the "Zero" in adaLN-Zero.')

In [None]:
# --- Step 4: Apply adaptive layer norm + gated residual ---
# Trace one MHA sub-layer step by step.

# Create a random input (simulating patch token sequence)
x = torch.randn(1, seq_len, d_model)

# Standard LayerNorm (no learnable parameters needed;
# adaLN replaces the learned gamma/beta with conditioning-dependent ones)
layer_norm = nn.LayerNorm(d_model, elementwise_affine=False)

# Step 4a: Adaptive layer norm
# Standard LN: normalize, then gamma * x + beta (learned constants)
# adaLN: normalize, then scale and shift with conditioning-dependent values.
# Like the reference DiT code, we apply the scale as (1 + gamma_1), so the
# zero-initialized gamma_1 leaves the normalized input unchanged.
# gamma_1 and beta_1 are [1, d_model] -> unsqueeze(1) broadcasts over [1, seq_len, d_model]
normed = layer_norm(x)
adaln_output = (1 + gamma_1.unsqueeze(1)) * normed + beta_1.unsqueeze(1)

print('Step 4a: Adaptive Layer Norm')
print(f'  x shape:          {list(x.shape)}')
print(f'  LayerNorm(x):     {list(normed.shape)}')
print(f'  (1 + gamma_1) * LN + beta_1: {list(adaln_output.shape)}')
print()

# Step 4b: Simulate MHA output (just use random for demonstration)
# In a real DiT block, this would be MultiHeadAttention(adaln_output)
mha_output = torch.randn_like(x) * 0.1  # small random "attention output"

# Step 4c: Gated residual connection
# x' = x + alpha_1 * MHA_output
x_prime = x + alpha_1.unsqueeze(1) * mha_output

print('Step 4c: Gated Residual')
print(f'  x\' = x + alpha_1 * MHA(...)  shape: {list(x_prime.shape)}')
print(f'  alpha_1 max_abs: {alpha_1.abs().max().item():.6f}')
print()

In [None]:
# --- Step 5: Verify the identity property at alpha=0 ---
# With the zero-initialized MLP, alpha is ~0.
# So: x' = x + 0 * MHA(...) = x

# Check: is x' approximately equal to x?
diff = (x_prime - x).abs().max().item()
print(f'Max |x\' - x|: {diff:.8f}')
print(f'Identity property holds: {diff < 1e-5}')
print()
print('At initialization (alpha=0), the gated residual connection')
print('means the MHA output contributes NOTHING. The input passes')
print('through unchanged. The block is an identity function.')
print()
print('This is the same principle as:')
print('  - ControlNet zero convolution: start contributing nothing')
print('  - LoRA B matrix at zero: bypass starts at zero')
print('In all three, a new component starts out contributing nothing,')
print('so it cannot disrupt the signal passing through.')

In [None]:
# --- Step 6: Watch the block's contribution grow ---
# Manually vary alpha from 0 to 1 and measure how much the block
# output deviates from the input.

x_test = torch.randn(1, seq_len, d_model)
mha_test = torch.randn_like(x_test) * 0.1  # simulated MHA output

alphas = [0.0, 0.01, 0.05, 0.1, 0.25, 0.5, 1.0]
deviations = []

print(f'{"alpha":<10} {"mean |output - input|":<25} {"max |output - input|"}')
print('-' * 60)

for a in alphas:
    output = x_test + a * mha_test
    mean_dev = (output - x_test).abs().mean().item()
    max_dev = (output - x_test).abs().max().item()
    deviations.append(mean_dev)
    print(f'{a:<10.2f} {mean_dev:<25.6f} {max_dev:.6f}')

print()
print('The block\'s contribution grows linearly with alpha.')
print('At alpha=0: identity. At alpha=1: full standard residual connection.')
print('During training, the model learns the MLP weights that produce alpha, so each')
print('block\'s gate is a per-channel vector that depends on the timestep and class.')

In [None]:
# --- Step 7: Visualize alpha's effect ---

fig, ax = plt.subplots(1, 1, figsize=(8, 4))
ax.plot(alphas, deviations, 'o-', color='#22d3ee', linewidth=2, markersize=8)
ax.set_xlabel('alpha (gate value)', fontsize=12)
ax.set_ylabel('Mean |output - input|', fontsize=12)
ax.set_title('adaLN-Zero: Block Contribution vs Gate Value', fontsize=13)
ax.axhline(y=0, color='gray', linestyle='--', alpha=0.5)
ax.annotate('Identity\n(alpha=0)', xy=(0, 0), xytext=(0.15, deviations[-1]*0.15),
            fontsize=10, color='#fbbf24',
            arrowprops=dict(arrowstyle='->', color='#fbbf24', lw=1.5))
ax.annotate('Full residual\n(alpha=1)', xy=(1, deviations[-1]),
            xytext=(0.6, deviations[-1]*0.85),
            fontsize=10, color='#34d399',
            arrowprops=dict(arrowstyle='->', color='#34d399', lw=1.5))
plt.tight_layout()
plt.show()

print('At initialization, every DiT block sits at the left (alpha=0, identity).')
print('Through training, each block\'s gate (a per-channel vector computed from')
print('the conditioning) moves away from zero as the block learns what to contribute.')

### What Just Happened

You built the adaLN-Zero conditioning mechanism from scratch and verified its key properties:

- **Six parameters per block.** The conditioning MLP produces `(γ₁, β₁, α₁, γ₂, β₂, α₂)` from the conditioning vector. Three per sub-layer: scale (γ), shift (β), gate (α).

- **Zero initialization means identity.** With the output layer initialized to zeros, all six parameters start at zero. The alpha gates are zero, so `x' = x + 0 * f(x) = x`. The entire block is an identity function at initialization.

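- **The gate controls contribution magnitude.** As alpha increases from 0 to 1, the block's output deviates more from its input. At alpha=1, you get the standard residual connection `x + f(x)`. Alpha is not a fixed per-block constant. The adaLN MLP computes it from the conditioning vector, so the learned gate can differ by channel, timestep, and class, and it is not restricted to the range [0, 1].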

- **This extends adaptive group norm.** The U-Net's AdaGN has gamma and beta (2 parameters). adaLN-Zero adds alpha (3 parameters per sub-layer). The alpha + zero initialization is the key difference—it is the same safety principle as ControlNet's zero convolution and LoRA's zero-initialized B matrix.

- **Compare to AdaGN:** adaptive group norm starts with gamma=1, beta=0 (standard normalization behavior). adaLN-Zero starts with alpha=0 (identity function). The U-Net block immediately contributes from the first training step. The DiT block starts contributing nothing and gradually learns to contribute.
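
Putting the pieces together, here is a minimal sketch of one DiT-style block with adaLN-Zero conditioning. It makes several assumptions not taken from this notebook:

- `nn.MultiheadAttention` for attention.
- A 4x FFN with tanh-approximated GELU.
- LayerNorms without affine weights.
- A conditioning vector with the same width as `d_model`.
- The scale applied as `(1 + γ)`, as in the reference DiT code.

The class name `AdaLNZeroBlock` is hypothetical.

```python
import torch
import torch.nn as nn


class AdaLNZeroBlock(nn.Module):
    """Minimal DiT-style block: MHA + FFN, each wrapped in adaLN and a zero-initialized gate."""

    def __init__(self, d_model: int, n_heads: int, mlp_ratio: int = 4):
        super().__init__()
        self.norm1 = nn.LayerNorm(d_model, elementwise_affine=False, eps=1e-6)
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(d_model, elementwise_affine=False, eps=1e-6)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, mlp_ratio * d_model),
            nn.GELU(approximate='tanh'),
            nn.Linear(mlp_ratio * d_model, d_model),
        )
        # c -> (gamma1, beta1, alpha1, gamma2, beta2, alpha2); zero init => identity block
        self.adaln = nn.Sequential(nn.SiLU(), nn.Linear(d_model, 6 * d_model))
        nn.init.zeros_(self.adaln[1].weight)
        nn.init.zeros_(self.adaln[1].bias)

    def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
        # x: [B, L, d_model]; c: [B, d_model]. unsqueeze(1) broadcasts each vector over tokens.
        g1, b1, a1, g2, b2, a2 = self.adaln(c).unsqueeze(1).chunk(6, dim=-1)
        h = (1 + g1) * self.norm1(x) + b1
        x = x + a1 * self.attn(h, h, h, need_weights=False)[0]   # gated MHA residual
        h = (1 + g2) * self.norm2(x) + b2
        return x + a2 * self.ffn(h)                               # gated FFN residual


block = AdaLNZeroBlock(d_model=384, n_heads=6)
x_demo = torch.randn(2, 16, 384)
c_demo = torch.randn(2, 384)
print(torch.equal(block(x_demo, c_demo), x_demo))   # True at initialization

# Open only the attention gate: alpha_1 occupies output slots [2*d, 3*d) of the adaLN layer
with torch.no_grad():
    block.adaln[1].bias[2 * 384:3 * 384] = 0.1
print(torch.equal(block(x_demo, c_demo), x_demo))   # False: the MHA branch now contributes
```

Setting only the alpha_1 slots of the adaLN bias makes the attention branch contribute while the FFN branch stays gated off. This is a quick way to see each gate act independently.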

---

## Exercise 3: Architecture Inspection `[Supported]`

The lesson compared the U-Net and DiT architectures side by side:

| | U-Net | DiT |
|---|---|---|
| Basic unit | Conv residual block + optional attention | Standard transformer block (MHA + FFN) |
| Convolutions | Every layer | None (only patchify/unpatchify) |
| Skip connections | Encoder-to-decoder | None (only within-block residual) |
| Scaling recipe | Ad hoc | d_model and N |

Now you will verify this by loading real models and comparing their architectures. You will:
1. Load a pretrained DiT model
2. Load a U-Net from a diffusion pipeline
3. Compare parameter counts and layer types
4. Find the adaLN-Zero components inside the DiT
5. Observe how DiT parameter counts scale across model sizes

Fill in the TODO markers to complete the comparison.

In [None]:
# ============================================================
# Exercise 3: Load and inspect DiT vs U-Net
# ============================================================

# --- Step 1: Load a pretrained DiT-XL/2 model ---
# The DiT authors released pretrained models on HuggingFace.
# We load via DiTPipeline from diffusers, then access pipe.transformer
# as a proper model object—this lets us use the same inspection tools
# (named_modules, count_parameters_by_type) on both DiT and U-Net.

from diffusers import DiTPipeline

print('Loading DiT-XL/2 (256x256 ImageNet) via DiTPipeline...')
dit_pipe = DiTPipeline.from_pretrained(
    'facebook/DiT-XL-2-256',
    torch_dtype=torch.float32,  # float32 for inspection (no GPU needed)
)
dit_model = dit_pipe.transformer  # the DiT backbone, a standard nn.Module (class name varies by diffusers version)

print(f'DiT model type: {type(dit_model).__name__}')
print('DiT-XL/2 loaded as a model object.')

In [None]:
# --- Step 2: Analyze DiT-XL/2 parameter structure ---
# Now that we have a proper model object, we can use standard PyTorch inspection.

total_params = count_parameters(dit_model)
print(f'DiT-XL/2 total parameters: {total_params:,} ({total_params / 1e6:.1f}M)')
print()

# Group parameters by top-level module
groups = {}
for name, module in dit_model.named_children():
    n_params = count_parameters(module)
    if n_params > 0:
        groups[name] = n_params

print('Parameter breakdown by top-level component:')
for group, count in sorted(groups.items(), key=lambda x: -x[1]):
    pct = count / total_params * 100
    print(f'  {group:<35} {count:>12,} ({pct:.1f}%)')

In [None]:
# --- Step 3: Inspect one DiT block's parameters ---
# Look at block 0 to find all the components the lesson described:
# MHA (Q/K/V/O projections), FFN, LayerNorm, and the adaLN-Zero MLP.

block_0 = dit_model.transformer_blocks[0]
print(f'DiT block 0 type: {type(block_0).__name__}')
print()
print('Sub-modules in DiT block 0:')
for name, module in block_0.named_modules():
    if name == '':
        continue
    n_params = sum(p.numel() for p in module.parameters(recurse=False))
    if n_params > 0:
        print(f'  {name:<40} {type(module).__name__:<20} ({n_params:,} params)')

block_0_params = count_parameters(block_0)
print(f'\n  Block 0 total: {block_0_params:,} parameters')
print()
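print('Look for:')
print('  - attn1: the self-attention module (Q/K/V/O projections)')
print('  - ff: the FFN (feed-forward network)')
print('  - norm1: in the diffusers port this wraps the adaLN-Zero modulation;')
print('    its Linear layer should output 6 * 1152 = 6912 values (the six parameters).')
print('    It may also contain a timestep/class embedder.')
print('  - norm2: a LayerNorm; LayerNorms without learnable affine weights hold')
print('    no parameters, so they may not appear in the list above.')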

In [None]:
# --- Step 4: Load SD v1.5 U-Net for comparison ---
from diffusers import UNet2DConditionModel

print('Loading SD v1.5 U-Net...')
unet = UNet2DConditionModel.from_pretrained(
    'stable-diffusion-v1-5/stable-diffusion-v1-5',
    subfolder='unet',
    torch_dtype=torch.float32,
)

unet_params = count_parameters(unet)
print(f'SD v1.5 U-Net parameters: {unet_params:,} ({unet_params / 1e6:.1f}M)')
print(f'DiT-XL/2 parameters:      {total_params:,} ({total_params / 1e6:.1f}M)')
print()

In [None]:
# --- Step 5: Compare layer types ---
# The lesson's key claim: DiT has NO convolutions in the transformer blocks.
# The U-Net has convolutions everywhere.

# TODO: Use count_parameters_by_type() on BOTH models.
# Print the U-Net breakdown first, then the DiT breakdown.
# Compare: which module types dominate each architecture?
# Hint: The function returns a dict of {module_type: param_count}.
raise NotImplementedError(
    "TODO: Call count_parameters_by_type(unet) and count_parameters_by_type(dit_model),\n"
    "then print both results side by side."
)

In [None]:
# --- Step 6: Verify no convolutions in DiT transformer blocks ---
# Scan the DiT model's named_modules for any Conv2d layers.

# TODO: Iterate through dit_model.named_modules() and collect any module
# that is an instance of nn.Conv2d. Print the results.
# Expected: zero Conv2d layers in the transformer blocks themselves.
# (There may be a conv-like projection in the patchify/unpatchify layer,
#  but the transformer blocks should be convolution-free.)
raise NotImplementedError(
    "TODO: Search for Conv2d modules in dit_model.named_modules(). See hint above."
)

In [None]:
# --- Step 7: DiT scaling across model sizes ---
# The lesson showed DiT's systematic scaling recipe:
#   DiT-S:  d_model=384,  N=12, ~33M
#   DiT-B:  d_model=768,  N=12, ~130M
#   DiT-L:  d_model=1024, N=24, ~458M
#   DiT-XL: d_model=1152, N=28, ~675M
#
# We can estimate parameter counts from the architecture formula.
# A transformer block has roughly:
#   MHA: 4 * d_model^2 (Q, K, V, O projections)
#   FFN: 2 * d_model * 4*d_model = 8 * d_model^2 (two linear layers)
#   adaLN: ~6 * d_model * d_model (MLP producing 6*d_model outputs)
#   Total per block: ~18 * d_model^2 (approximate)

# TODO: For each DiT variant (S, B, L, XL), compute the estimated
# total parameter count using: N_blocks * 18 * d_model^2.
# Print a table comparing the estimate to the published values.
# Also print the actual count for DiT-XL from the model we loaded.

dit_variants = [
    ('DiT-S',  384,  12,  33),
    ('DiT-B',  768,  12, 130),
    ('DiT-L',  1024, 24, 458),
    ('DiT-XL', 1152, 28, 675),
]

raise NotImplementedError(
    "TODO: Estimate parameters for each variant and print a comparison table.\n"
    f"Hint: DiT-XL actual count from loaded model = {total_params:,}"
)

<details>
<summary>Solution</summary>

The key insight is that the U-Net is dominated by `Conv2d` layers (convolutions at every resolution), while DiT has no convolutions in its transformer blocks—only `Linear` layers (for Q/K/V projections, FFN, and adaLN-Zero MLP). Because we loaded both as proper model objects, `count_parameters_by_type()` works symmetrically on both.

**Step 5: Compare layer types (both models)**
```python
unet_by_type = count_parameters_by_type(unet)
dit_by_type = count_parameters_by_type(dit_model)

print('SD v1.5 U-Net parameters by module type:')
for module_type, count in unet_by_type.items():
    pct = count / unet_params * 100
    print(f'  {module_type:<30} {count:>12,} ({pct:.1f}%)')

print()
print('DiT-XL/2 parameters by module type:')
for module_type, count in dit_by_type.items():
    pct = count / total_params * 100
    print(f'  {module_type:<30} {count:>12,} ({pct:.1f}%)')
```

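You should see Conv2d as the dominant module type in the U-Net (often 60%+ of parameters), with Linear (from its attention layers), GroupNorm, and other types making up the rest. In the DiT, Linear should dominate. Any Conv2d you see there is the small patch-embedding projection, not part of the transformer blocks. The DiT's LayerNorms may not appear at all, because they have no learnable affine parameters.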

**Step 6: No convolutions in DiT transformer blocks**
```python
conv_modules = [
    (name, module)
    for name, module in dit_model.named_modules()
    if isinstance(module, nn.Conv2d)
]
print(f'Conv2d modules in DiT: {len(conv_modules)}')
for name, module in conv_modules:
    n_params = sum(p.numel() for p in module.parameters())
    print(f'  {name:<50} {n_params:,} params')
if len(conv_modules) == 0:
    print('  None found. DiT uses no convolutions at all.')
else:
    # Check if any are inside transformer_blocks
    block_convs = [n for n, _ in conv_modules if 'transformer_blocks' in n]
    print(f'\n  Conv2d inside transformer_blocks: {len(block_convs)}')
    if len(block_convs) == 0:
        print('  No convolutions in the transformer blocks themselves.')
        print('  Any Conv2d found is in the patchify/unpatchify projection.')
```

**Step 7: DiT scaling estimates**
```python
print(f'{"Model":<10} {"d_model":<10} {"N":<6} {"Est (M)":<12} {"Published (M)":<15} {"Ratio"}')
print('-' * 65)
for name, d, n, published in dit_variants:
    est = n * 18 * d * d / 1e6  # rough estimate
    ratio = est / published
    print(f'{name:<10} {d:<10} {n:<6} {est:<12.1f} {published:<15} {ratio:.2f}x')
print()
print(f'DiT-XL actual (from loaded model): {total_params:,} ({total_params / 1e6:.1f}M)')
print()
print('The estimates are rough (they ignore embeddings, final layer, biases),')
print('but they show the scaling pattern: doubling d_model quadruples parameters')
print('per block (because of d_model^2 terms). Doubling N doubles total parameters.')
print('Two knobs, predictable scaling.')
```

The estimates will not match exactly because they ignore the patch embedder, positional embeddings, final layer norm, and output projection. But the ratios between model sizes should be approximately correct, demonstrating the systematic scaling recipe.
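
The rough `18 * d_model^2` estimate can be tightened into an exact count under explicit assumptions about the reference DiT layout:

- Fused QKV and an output projection, both with biases.
- A 4x FFN.
- A `d_model -> 6 * d_model` adaLN layer per block.
- A 256-dim sinusoidal timestep input feeding a two-layer MLP.
- A class table with 1000 classes plus one null class for classifier-free guidance.
- A final adaLN layer (`d_model -> 2 * d_model`) and an output head.
- 8 output channels (4 latent channels × 2, because DiT also predicts variance).

The fixed sine-cosine positional table is not counted. `dit_param_count` is a hypothetical helper. The diffusers port you loaded may organize, and therefore count, these parameters differently.

```python
def dit_param_count(d: int, n_blocks: int, patch: int = 2, in_ch: int = 4,
                    out_ch: int = 8, freq_dim: int = 256, n_classes: int = 1000) -> int:
    """Trainable parameters of a reference-style DiT (assumed layout, see lead-in)."""

    def linear(i: int, o: int) -> int:
        return i * o + o                       # weight + bias

    per_block = (linear(d, 3 * d)              # fused Q, K, V
                 + linear(d, d)                # attention output projection
                 + linear(d, 4 * d)            # FFN up
                 + linear(4 * d, d)            # FFN down
                 + linear(d, 6 * d))           # adaLN-Zero modulation
    patch_embed = in_ch * patch * patch * d + d            # patch projection (kernel = stride = p)
    t_embed = linear(freq_dim, d) + linear(d, d)           # timestep MLP
    y_embed = (n_classes + 1) * d                          # +1 null class for CFG
    final = linear(d, 2 * d) + linear(d, patch * patch * out_ch)   # final adaLN + output head
    return n_blocks * per_block + patch_embed + t_embed + y_embed + final


variants = [('DiT-S', 384, 12, 33), ('DiT-B', 768, 12, 130),
            ('DiT-L', 1024, 24, 458), ('DiT-XL', 1152, 28, 675)]
for name, d, n, published in variants:
    est = dit_param_count(d, n) / 1e6
    print(f'{name:<8} {est:8.1f}M  (published ~{published}M)')
# DiT-S 32.9M, DiT-B 130.3M, DiT-L 457.8M, DiT-XL 674.8M
```

The per-block term is exactly `18 * d^2 + 15 * d`, so the rough formula was missing only biases and the non-block components. Those add under 1% at every size.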

**Common mistakes:**
- Expecting the parameter estimates to be exact. The `18 * d_model^2` formula is a rough approximation. The actual count includes biases, embedding layers, and the final output head.
- Confusing the patchify projection (which may use Conv2d with kernel_size=patch_size and stride=patch_size) with convolutions inside the transformer blocks. The patchify Conv2d is equivalent to the flattened-patch linear projection—it is just the input embedding, not a spatial operation in the denoising backbone.
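
The equivalence in the second point can be checked directly. Here is a minimal sketch with toy sizes: copy the weights of a `Conv2d(kernel_size=p, stride=p)` layer into an `nn.Linear` that acts on flattened patches, then compare the outputs token by token.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
C, p, d = 4, 2, 32
latent = torch.randn(1, C, 8, 8)

conv = nn.Conv2d(C, d, kernel_size=p, stride=p)
lin = nn.Linear(C * p * p, d)
with torch.no_grad():
    # Conv weight [d, C, p, p] flattens in the same (C, p, p) order as a patch
    lin.weight.copy_(conv.weight.reshape(d, C * p * p))
    lin.bias.copy_(conv.bias)

    # Conv path: [1, d, 4, 4] -> [1, 16, d] (row-major over the patch grid)
    conv_tokens = conv(latent).flatten(2).transpose(1, 2)

    # Reshape + Linear path, using the same patchify as Exercise 1
    B, _, H, W = latent.shape
    flat_patches = (latent.reshape(B, C, H // p, p, W // p, p)
                          .permute(0, 2, 4, 1, 3, 5)
                          .reshape(B, (H // p) * (W // p), C * p * p))
    lin_tokens = lin(flat_patches)

print(torch.allclose(conv_tokens, lin_tokens, atol=1e-5))   # True
```

Because the stride equals the kernel size, the convolution windows never overlap. Each output position sees exactly one patch, which is why this "convolution" is just the input embedding.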

</details>

In [None]:
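# --- Cleanup before Exercise 4 ---
# Drop every global reference into the loaded models first: `del` inside a
# helper function cannot remove the notebook's variables. block_0, the loop
# variable `module`, and solution lists such as conv_modules also point into
# the models, so they go too. pop(..., None) skips names that don't exist.
for _name in ['unet', 'dit_model', 'dit_pipe', 'block_0', 'module', 'conv_modules']:
    globals().pop(_name, None)
free_memory()
print('Ready for Exercise 4.')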

### What Just Happened

You loaded a real DiT model and a real U-Net as proper model objects and compared them with the same inspection tools. Key observations:

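- **DiT-XL/2 has ~675M parameters (published figure), the SD v1.5 U-Net ~860M.** The two are in the same size range but solve different tasks (class-conditional ImageNet 256×256 vs. text-to-image at 512×512), so this comparison is about architecture, not output quality. The count printed for the diffusers DiT may come out above 675M, since that port can store a copy of the timestep/class embedder inside each block.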

- **The U-Net is dominated by Conv2d layers.** `count_parameters_by_type()` shows convolutions at every resolution level as the primary spatial processing operation. GroupNorm normalizes within each conv block.

- **DiT is dominated by Linear layers.** The same `count_parameters_by_type()` function on the DiT model shows Linear layers for Q/K/V projections, FFN, and adaLN-Zero MLP. No Conv2d inside the transformer blocks.

- **The adaLN modulation layer is clearly visible** in each DiT block. It is the MLP that takes the conditioning vector and produces the six modulation parameters (γ₁, β₁, α₁, γ₂, β₂, α₂).

- **DiT scaling is systematic.** From DiT-S to DiT-XL, parameter count grows predictably with `d_model^2 * N`. Two knobs, not twenty.

- **Symmetric comparison matters.** Loading both models as `nn.Module` objects means the same inspection functions (`count_parameters_by_type`, `named_modules`) work on both. This makes the architectural differences concrete and verifiable.
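
The sketch below shows how one conditioning vector becomes the six modulation vectors. It assumes the layout of the DiT reference implementation: SiLU followed by a zero-initialized Linear with output size 6 × d_model, split in the order shift, scale, gate for attention and then shift, scale, gate for the MLP. Here shift = β, scale = γ (applied as `1 + γ`), and gate = α. The `modulate` helper and the Linear standing in for attention are illustrative only and are not part of the diffusers API.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, batch, tokens = 16, 2, 4

# Conditioning vector (timestep + class embedding) -> 6 * d_model modulation values
adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(d_model, 6 * d_model))
nn.init.zeros_(adaLN_modulation[1].weight)  # the "Zero" in adaLN-Zero
nn.init.zeros_(adaLN_modulation[1].bias)

c = torch.randn(batch, d_model)
# beta1, gamma1, alpha1 (attention) and beta2, gamma2, alpha2 (MLP)
shift1, scale1, gate1, shift2, scale2, gate2 = adaLN_modulation(c).chunk(6, dim=-1)
print([tuple(t.shape) for t in (shift1, scale1, gate1)])  # one vector per sample each

def modulate(x, shift, scale):
    # Per-sample scale and shift, broadcast over the token dimension
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)

norm = nn.LayerNorm(d_model, elementwise_affine=False)
attn = nn.Linear(d_model, d_model)  # stand-in for the attention sub-layer

x = torch.randn(batch, tokens, d_model)
out = x + gate1.unsqueeze(1) * attn(modulate(norm(x), shift1, scale1))
print(torch.equal(out, x))  # zero gate: the sub-layer adds nothing, so the block is the identity
```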

---

## Exercise 4: Generate with DiT `[Independent]`

From the lesson: DiT-XL/2 (patch size 2, 675M params) achieved FID 2.27 on ImageNet 256×256 class-conditional generation (with classifier-free guidance scale 1.5), which was state-of-the-art at the time. The scaling argument is not just theory. You can see it in the generated images.

DiT is class-conditional on ImageNet. It takes class labels (integers 0–999, one per ImageNet class), not text prompts. Index 1000 is reserved for the null label used by the unconditional branch. Classifier-free guidance works the same way as before: amplify the difference between the class-conditioned and unconditional predictions.
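
The guidance rule itself is a single line. The minimal sketch below uses dummy tensors in place of the two noise predictions. It assumes the DiT convention that the unconditional branch is the same network fed a reserved null label (index 1000). `cfg_combine` is a hypothetical helper and is not a diffusers function. The real pipeline differs in details, such as batching and which output channels are guided.

```python
import torch

NUM_CLASSES = 1000          # ImageNet class indices 0..999
NULL_LABEL = NUM_CLASSES    # assumed reserved index for the unconditional ("no class") branch

def cfg_combine(eps_cond, eps_uncond, guidance_scale):
    # Start from the unconditional prediction and move along (cond - uncond)
    return eps_uncond + guidance_scale * (eps_cond - eps_uncond)

# One real label plus its null counterpart, so both branches can run in one batch
labels = torch.tensor([207])  # golden retriever
batched_labels = torch.cat([labels, torch.full_like(labels, NULL_LABEL)])
print(batched_labels.tolist())

# Dummy noise predictions for one 4x32x32 latent (stand-ins for the network outputs)
torch.manual_seed(0)
eps_cond, eps_uncond = torch.randn(2, 1, 4, 32, 32)

for scale in [1.0, 1.5, 4.0, 7.5]:
    guided = cfg_combine(eps_cond, eps_uncond, scale)
    # Distance moved away from the unconditional prediction, relative to the cond-uncond gap
    ratio = ((guided - eps_uncond).norm() / (eps_cond - eps_uncond).norm()).item()
    print(f'scale={scale}: moved {ratio:.2f}x the cond-uncond gap')
```

At scale 1.0 the result is just the conditional prediction. Larger scales extrapolate past it.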

### Your Task

1. **Load DiT-XL/2** using the HuggingFace `DiTPipeline` from `diffusers`
2. **Generate class-conditional images** for at least two different ImageNet classes (e.g., golden retriever=207, macaw=88, volcano=980, castle=483)
3. **Vary the classifier-free guidance scale** (e.g., 1.0, 1.5, 4.0, 7.5) for the same class and seed
4. **Compare the results**: observe how guidance scale affects quality, detail, and diversity
5. **Bonus**: vary the number of sampling steps and observe the quality-speed tradeoff

### Hints

- The pipeline class is `DiTPipeline` from `diffusers`
- The model ID is `"facebook/DiT-XL-2-256"`
- Generation uses `pipe(class_labels=[207], num_inference_steps=50, guidance_scale=4.0)`
- `class_labels` takes a list of ImageNet class indices
- Use `generator=torch.Generator(device='cpu').manual_seed(42)` for reproducibility
- DiT-XL/2 needs ~3 GB VRAM in float16. Use `torch_dtype=torch.float16` when loading
- Some useful ImageNet classes: 207 (golden retriever), 88 (macaw), 980 (volcano), 483 (castle), 388 (panda), 279 (arctic fox), 33 (loggerhead turtle), 250 (husky)
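
Why create a new generator for every generation? A generator's state advances each time noise is drawn, so reusing one across calls changes the starting noise. This sketch uses plain `torch.randn` to show the effect. The pipeline draws its initial latents (and any per-step noise) from the generator you pass, so the same logic applies. The 4×32×32 latent shape assumes DiT-XL/2 at 256×256 with an 8× downsampling VAE.

```python
import torch

latent_shape = (1, 4, 32, 32)  # 256 / 8 = 32, with 4 latent channels

g = torch.Generator(device='cpu').manual_seed(42)
first = torch.randn(latent_shape, generator=g)
second = torch.randn(latent_shape, generator=g)    # same generator, advanced state

g_fresh = torch.Generator(device='cpu').manual_seed(42)
again = torch.randn(latent_shape, generator=g_fresh)

print(torch.equal(first, again))   # fresh generator, same seed: identical starting noise
print(torch.equal(first, second))  # reused generator: different noise
```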

In [None]:
# ============================================================
# Exercise 4: Generate class-conditional images with DiT-XL/2
# ============================================================
#
# Load the DiT pipeline and generate images.
# Compare different class labels and guidance scales.
#
# Your code here:



In [None]:
# --- Your guidance scale comparison ---
#
# Pick one class, keep the seed fixed, vary guidance_scale.
# Use show_image_row() to display the results.
#
# Your code here:



In [None]:
# --- (Bonus) Your step count comparison ---
#
# Pick one class and guidance scale, vary num_inference_steps.
# Try steps: 10, 25, 50, 100
#
# Your code here:



<details>
<summary>Solution</summary>

The key insight is that DiT generation looks just like any other diffusion pipeline—the only difference is the denoising backbone. The pipeline handles patchify/unpatchify, adaLN-Zero conditioning, and sampling internally. You provide a class label instead of a text prompt.

```python
from diffusers import DiTPipeline

# --- Load DiT-XL/2 ---
print('Loading DiT-XL/2...')
pipe = DiTPipeline.from_pretrained(
    'facebook/DiT-XL-2-256',
    torch_dtype=dtype,  # float16 on GPU (from setup); float32 on CPU, where float16 is slow or unsupported
)
pipe = pipe.to(device)
print('DiT-XL/2 loaded.')

# --- Generate multiple classes ---
classes = {
    207: 'Golden Retriever',
    88: 'Macaw',
    980: 'Volcano',
    279: 'Arctic Fox',
}

class_images = []
class_titles = []
for class_id, class_name in classes.items():
    generator = torch.Generator(device='cpu').manual_seed(42)
    result = pipe(
        class_labels=[class_id],
        num_inference_steps=50,
        guidance_scale=4.0,
        generator=generator,
    )
    class_images.append(result.images[0])
    class_titles.append(f'Class {class_id}\n{class_name}')
    print(f'  Generated class {class_id} ({class_name})')

show_image_row(
    class_images, class_titles,
    suptitle='DiT-XL/2: Class-Conditional ImageNet Generation (guidance=4.0, 50 steps)',
    figsize=(20, 5),
)

# --- Guidance scale comparison ---
guidance_scales = [1.0, 1.5, 4.0, 7.5]
cfg_images = []
cfg_titles = []

for gs in guidance_scales:
    generator = torch.Generator(device='cpu').manual_seed(42)
    result = pipe(
        class_labels=[207],  # golden retriever
        num_inference_steps=50,
        guidance_scale=gs,
        generator=generator,
    )
    cfg_images.append(result.images[0])
    cfg_titles.append(f'cfg={gs}')
    print(f'  Generated guidance_scale={gs}')

show_image_row(
    cfg_images, cfg_titles,
    suptitle='DiT-XL/2: Golden Retriever (class 207) at Different Guidance Scales',
    figsize=(20, 5),
)

print('Observations:')
print('  cfg=1.0: No guidance. Diverse but potentially lower quality.')
print('  cfg=1.5: Mild guidance. Good balance of diversity and quality.')
print('  cfg=4.0: Standard DiT guidance. Sharp, detailed, class-typical.')
print('  cfg=7.5: Strong guidance. Very class-typical but may oversaturate.')

# --- (Bonus) Step count comparison ---
step_counts = [10, 25, 50, 100]
step_images = []
step_titles = []

for steps in step_counts:
    generator = torch.Generator(device='cpu').manual_seed(42)
    start = time.time()
    result = pipe(
        class_labels=[207],
        num_inference_steps=steps,
        guidance_scale=4.0,
        generator=generator,
    )
    elapsed = time.time() - start
    step_images.append(result.images[0])
    step_titles.append(f'{steps} steps\n{elapsed:.1f}s')
    print(f'  {steps} steps: {elapsed:.1f}s')

show_image_row(
    step_images, step_titles,
    suptitle='DiT-XL/2: Quality vs Steps (Golden Retriever, cfg=4.0)',
    figsize=(20, 5),
)
```
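
The step-count timing above uses wall-clock `time.time()`. CUDA kernels run asynchronously, so a more careful measurement synchronizes the device before reading the clock. The pipeline's final conversion to PIL images probably forces a sync already. The very first call also includes one-time warm-up cost, so consider running one throwaway generation before the loop. `timed_call` is a hypothetical helper written for this sketch.

```python
import time
import torch

def timed_call(fn, *args, **kwargs):
    """Run fn(*args, **kwargs) and return (result, seconds), syncing CUDA so queued GPU work is counted."""
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    start = time.perf_counter()
    result = fn(*args, **kwargs)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return result, time.perf_counter() - start

# Usage with the pipeline (not run here):
# result, elapsed = timed_call(pipe, class_labels=[207], num_inference_steps=steps,
#                              guidance_scale=4.0, generator=generator)
```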

**Key observations:**
- DiT generates high-quality class-conditional images. The golden retriever, macaw, and volcano should be clearly recognizable and well-composed.
- Guidance scale has a significant effect. At cfg=1.0, images are diverse but may lack coherence. At cfg=4.0, images are sharp and class-typical. This is the `DiTPipeline` default and the scale used for the DiT paper's sample figures; the paper's FID of 2.27 was measured at 1.5. At cfg=7.5, oversaturation may appear, which is the same phenomenon seen with U-Net models.
- More steps generally means better quality, but with diminishing returns. The jump from 10 to 25 steps is dramatic; from 50 to 100 is subtle.
- The pipeline API looks nearly identical to Stable Diffusion—`class_labels` instead of `prompt` is the main difference. This confirms the lesson: DiT replaces only the denoising network. The sampling loop, scheduler, and output processing are the same.

**Common mistakes:**
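- Forgetting `torch_dtype=torch.float16` when loading on GPU. Float32 roughly doubles memory use (~6 GB) and runs slower. It still fits on Colab's 16 GB T4, but it can run out of memory if models from earlier exercises were not freed.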
- Using text prompts instead of class labels. DiT is class-conditional, not text-conditional. Text conditioning comes in the next lesson (SD3/Flux).
- Not setting a manual seed for reproducibility. Without a fixed seed, each generation is different, making guidance scale comparisons meaningless.

</details>

---

## Key Takeaways

1. **Patchify is reshaping + linear projection, not magic.** Split the latent `[C, H, W]` into `(H/p) * (W/p)` patches of size `[C, p, p]`, flatten each patch to `C*p*p` dimensions, and project to `d_model`. The reshape round-trip (patchify, then unpatchify) is lossless. This is the image equivalent of text tokenization: the transformer processes both kinds of sequences identically.

2. **adaLN-Zero = adaptive norm + zero-initialized gate.** Each block gets six modulation vectors from one conditioning vector: scale (γ), shift (β), and gate (α) for each sub-layer. At initialization, alpha=0 makes every block an identity function. The model learns what each block should contribute. This is the same zero-initialization safety pattern as ControlNet and LoRA.

3. **DiT has no convolutions in its transformer blocks.** Its blocks contain only Linear layers: the Q/K/V projections, the FFN, and the adaLN-Zero MLP. The U-Net, by contrast, is dominated by Conv2d. DiT drops the convolutional backbone, the encoder-decoder hierarchy, and the U-Net's long-range skip connections (each block still has ordinary residual connections), and gets better results at sufficient scale.

4. **Two knobs, predictable scaling.** From DiT-S (33M) to DiT-XL (675M): increase d_model and N. Parameter count scales as `d_model^2 * N`. No ad hoc decisions about channels, resolutions, or where to add attention. The same recipe that scaled GPT-2 to GPT-3.

5. **Same pipeline, different denoising network.** DiT-XL/2 generates through the same sampling loop as every diffusion model you have seen. Class labels instead of text prompts. Guidance scale works the same way. The VAE, scheduler, and output processing are unchanged. Only the middle box—the denoising network—is different.
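
Takeaway 1 can be checked in a few lines. This is a minimal sketch of the reshape part of patchify and unpatchify, followed by a separate learned projection. Two assumptions apply. First, patches are taken in row-major order and flattened channel-first within each patch; other implementations (including the DiT reference code) may order the flattened dimensions differently. Second, the 1152 width is DiT-XL's `d_model`.

```python
import torch
import torch.nn as nn

def patchify(z: torch.Tensor, p: int) -> torch.Tensor:
    # [B, C, H, W] -> [B, (H/p)*(W/p), C*p*p], patches in row-major order
    B, C, H, W = z.shape
    assert H % p == 0 and W % p == 0, 'H and W must be divisible by p'
    z = z.reshape(B, C, H // p, p, W // p, p)
    z = z.permute(0, 2, 4, 1, 3, 5)  # [B, H/p, W/p, C, p, p]
    return z.reshape(B, (H // p) * (W // p), C * p * p)

def unpatchify(tokens: torch.Tensor, C: int, H: int, W: int, p: int) -> torch.Tensor:
    # Exact inverse of patchify: [B, (H/p)*(W/p), C*p*p] -> [B, C, H, W]
    B = tokens.shape[0]
    z = tokens.reshape(B, H // p, W // p, C, p, p)
    z = z.permute(0, 3, 1, 4, 2, 5)  # [B, C, H/p, p, W/p, p]
    return z.reshape(B, C, H, W)

z = torch.randn(1, 4, 32, 32)   # latent for a 256x256 image
tokens = patchify(z, p=2)
print(tokens.shape)             # (32/2) * (32/2) = 256 tokens, each of size 4*2*2 = 16
print(torch.equal(unpatchify(tokens, 4, 32, 32, p=2), z))  # lossless reshape round-trip

# The learned part is separate: a Linear from C*p*p to d_model
embed = nn.Linear(4 * 2 * 2, 1152)
print(embed(tokens).shape)
```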