# 2. Architecture Evolution: From UNet to MMDiT
**How Stable Diffusion's Brain Got an Upgrade**

---

*This notebook traces the architectural evolution of diffusion model denoisers from Stable Diffusion 1.5 through FLUX.1. No GPU required --- this is a theory and visualization notebook.*

**What you'll learn:**
- How the latent diffusion framework works across all SD variants
- Why the UNet was eventually replaced by Transformers
- What makes MMDiT (Multimodal Diffusion Transformer) a breakthrough
- How FLUX.1 combines every major advance into a single model

## The Latent Diffusion Framework

Every model covered here, from SD 1.5 to FLUX.1, shares the same high-level framework: **Latent Diffusion**. The key insight is that we never operate on raw pixels. Instead, we compress images into a compact **latent space** and do all the heavy lifting there.

### The Three Core Components

| Component | Role | Details |
|-----------|------|---------|
| **VAE (Encoder/Decoder)** | Compress & reconstruct images | 512x512x3 image $\rightarrow$ 64x64x4 latent (8x compression along each spatial axis) |
| **Denoiser (UNet or Transformer)** | Predict and remove noise from latents | The "brain" --- this is what evolves across SD versions |
| **Text Encoder (CLIP, T5)** | Convert text prompts to conditioning vectors | Provides semantic guidance to the denoiser |

### The Generation Pipeline

```
"A cat wearing a top hat"
         |
         v
  [Text Encoder]  (CLIP / T5)
         |
         v
  conditioning vectors
         |
         v
  [Denoiser]  <--- operates on 64x64x4 latent (NOT 512x512x3 pixels!)
   (UNet or    
   Transformer)   Pure noise --> Structured latent (iterative denoising)
         |
         v
  [VAE Decoder]  
         |
         v
  512x512x3 image
```
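
The data flow in the diagram can be traced with a shapes-only sketch. Every function here is a hypothetical stand-in rather than a real model: `encode_text`, `denoise_step`, and `decode_latent` only return arrays of plausible shape, and the `(77, 768)` conditioning shape is an illustrative assumption. Arrays are channel-first, so the final image is `(3, 512, 512)` rather than 512x512x3.

```python
import numpy as np

rng = np.random.default_rng(0)

def encode_text(prompt):
    """Hypothetical text-encoder stand-in: one 768-dim vector per token slot."""
    return rng.normal(size=(77, 768))

def denoise_step(latent, conditioning):
    """Hypothetical denoiser stand-in: a real model predicts the noise to remove."""
    return 0.9 * latent

def decode_latent(latent):
    """Hypothetical VAE-decoder stand-in: 8x nearest-neighbor upsampling."""
    rgb = latent[:3]                                  # keep 3 of the 4 channels
    return rgb.repeat(8, axis=1).repeat(8, axis=2)    # (3, 512, 512)

def generate(prompt, num_steps=50):
    conditioning = encode_text(prompt)                # text -> conditioning vectors
    latent = rng.normal(size=(4, 64, 64))             # start from pure noise
    for _ in range(num_steps):                        # iterative denoising in latent space
        latent = denoise_step(latent, conditioning)
    return decode_latent(latent)                      # latent -> pixels

image = generate("A cat wearing a top hat")
print(image.shape)  # (3, 512, 512)
```

The point of the sketch is the loop: the expensive, repeated step only ever touches the small `(4, 64, 64)` array.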

### Why Latent Space?

Working in latent space is far cheaper computationally than working in pixel space:

- **Pixel space**: 512 x 512 x 3 = **786,432** values per image
- **Latent space**: 64 x 64 x 4 = **16,384** values per latent
- **Reduction factor**: 48x fewer values and 64x fewer spatial positions (8x along each axis). Because self-attention cost grows quadratically with the number of positions, the savings inside attention layers are larger still

This is what made high-resolution diffusion practical on consumer hardware.
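
The counts above can be checked with a few lines of arithmetic. The variable names are only for illustration; the last ratio assumes full self-attention over every spatial position, which is a simplification of what a real denoiser does.

```python
# Compare pixel-space and latent-space sizes for one 512x512 RGB image.
pixel_h, pixel_w, pixel_c = 512, 512, 3
latent_h, latent_w, latent_c = 64, 64, 4

pixel_values = pixel_h * pixel_w * pixel_c        # 786,432
latent_values = latent_h * latent_w * latent_c    # 16,384

pixel_positions = pixel_h * pixel_w               # 262,144
latent_positions = latent_h * latent_w            # 4,096

value_ratio = pixel_values / latent_values             # 48.0
position_ratio = pixel_positions / latent_positions    # 64.0

# Self-attention compares every position with every other position,
# so the number of pairs grows with the square of the position count.
attention_pair_ratio = position_ratio ** 2             # 4096.0

print(f"values:          {value_ratio:.0f}x fewer")
print(f"positions:       {position_ratio:.0f}x fewer")
print(f"attention pairs: {attention_pair_ratio:.0f}x fewer")
```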

> **Key insight**: The denoiser architecture (UNet vs. Transformer) is the component that has evolved most dramatically. The VAE and the overall latent diffusion framework remain largely unchanged.

In [None]:
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import numpy as np

# ============================================================
# Architecture Comparison Table
# ============================================================

fig, ax = plt.subplots(figsize=(14, 5))
ax.axis('off')

columns = ['Model', 'Year', 'Denoiser', 'Parameters', 'Text Encoder(s)',
           'Training Objective', 'Typical Steps']

data = [
    ['SD 1.5',     '2022', 'UNet',  '860M',  'CLIP-L',                  'ε-prediction',  '50'],
    ['SD 2.1',     '2022', 'UNet',  '865M',  'OpenCLIP-H',             'v-prediction',  '50'],
    ['SDXL',       '2023', 'UNet',  '2.6B',  'CLIP-L + CLIP-G',       'ε-prediction',  '50'],
    ['SD3 Medium', '2024', 'MMDiT', '2B',    'CLIP-L + CLIP-G + T5-XXL', 'Flow Matching', '28'],
    ['FLUX.1',     '2024', 'MMDiT', '12B',   'CLIP-L + T5-XXL',       'Flow Matching', '4 (schnell)'],
]

# Row colors
row_colors = [
    '#E8F5E9',  # light green
    '#E8F5E9',  # light green
    '#FFF3E0',  # light orange
    '#FFEBEE',  # light red
    '#FFCDD2',  # medium red
]

header_color = '#37474F'
header_text_color = 'white'

table = ax.table(
    cellText=data,
    colLabels=columns,
    cellLoc='center',
    loc='center',
)

table.auto_set_font_size(False)
table.set_fontsize(10)
table.scale(1.0, 1.8)

# Style header row
for j in range(len(columns)):
    cell = table[0, j]
    cell.set_facecolor(header_color)
    cell.set_text_props(color=header_text_color, fontweight='bold', fontsize=10)
    cell.set_edgecolor('white')
    cell.set_linewidth(1.5)

# Style data rows
for i in range(len(data)):
    for j in range(len(columns)):
        cell = table[i + 1, j]
        cell.set_facecolor(row_colors[i])
        cell.set_edgecolor('white')
        cell.set_linewidth(1.5)
        # Bold the model name column
        if j == 0:
            cell.set_text_props(fontweight='bold')
        # Highlight denoiser column for MMDiT models
        if j == 2 and data[i][2] == 'MMDiT':
            cell.set_text_props(fontweight='bold', color='#C62828')

ax.set_title('Diffusion Model Architecture Comparison',
             fontsize=16, fontweight='bold', pad=20, color='#212121')

plt.tight_layout()
plt.show()

## SD 1.x --- The Original UNet (2022)

Stable Diffusion 1.x uses a **UNet** as the denoiser backbone. This architecture, originally developed for medical image segmentation and already used in earlier diffusion models such as DDPM, turned out to be remarkably effective for iterative denoising.

### UNet Architecture Overview

The UNet follows an **encoder-decoder** structure with **skip connections**:

1. **Encoder (Downsampling)**: Progressively reduces spatial resolution while increasing channel depth
   - 64x64 $\rightarrow$ 32x32 $\rightarrow$ 16x16 $\rightarrow$ 8x8
   - Each level: ResNet blocks + Self-Attention + Cross-Attention

2. **Bottleneck**: Processes the most compressed representation (8x8)

3. **Decoder (Upsampling)**: Mirrors the encoder, progressively restoring resolution
   - 8x8 $\rightarrow$ 16x16 $\rightarrow$ 32x32 $\rightarrow$ 64x64
   - **Skip connections** carry fine-grained spatial details from encoder to decoder
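
The resolution ladder and the role of skip connections can be traced with a toy sketch. This is not the SD 1.5 UNet: average pooling and nearest-neighbor repetition stand in for learned convolutions, there are no ResNet or attention blocks, and the channel counts are arbitrary. It only shows how encoder features are stored and then concatenated into the decoder at the matching resolution.

```python
import numpy as np

def downsample(x):
    """Halve height and width with 2x2 average pooling. x has shape (C, H, W)."""
    c, h, w = x.shape
    return x.reshape(c, h // 2, 2, w // 2, 2).mean(axis=(2, 4))

def upsample(x):
    """Double height and width with nearest-neighbor repetition."""
    return x.repeat(2, axis=1).repeat(2, axis=2)

x = np.random.default_rng(0).normal(size=(4, 64, 64))  # toy latent

# Encoder: remember each level's features for the skip connections.
skips = []
for _ in range(3):                 # 64 -> 32 -> 16 -> 8
    skips.append(x)
    x = downsample(x)
bottleneck_shape = x.shape         # (4, 8, 8)

# Decoder: upsample, then concatenate the encoder features of the same size.
decoder_shapes = []
for skip in reversed(skips):       # 8 -> 16 -> 32 -> 64
    x = upsample(x)
    x = np.concatenate([x, skip], axis=0)   # skip connection (channel concat)
    decoder_shapes.append(x.shape)

print(bottleneck_shape)
print(decoder_shapes)
```

In a real UNet a convolution after each concatenation would mix the two sources and bring the channel count back down.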

### Cross-Attention: How Text Guides Generation

At each resolution level, **cross-attention** layers inject text conditioning:

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{Q K^T}{\sqrt{d_k}}\right) V$$

Where:
- $Q$ (Query) comes from the **image latents**
- $K$ (Key) and $V$ (Value) come from the **text embeddings**

This is **one-way conditioning**: text informs the image, but the image cannot inform the text representation. This asymmetry becomes important when we discuss MMDiT.
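
A minimal single-head sketch of the formula above, using NumPy. The token counts, widths, and random projection matrices are toy assumptions chosen for readability, not SD 1.5's actual sizes. Note that only the image tokens receive an output; the text tokens are read but never updated.

```python
import numpy as np

def softmax(x, axis=-1):
    """Numerically stable softmax."""
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def cross_attention(image_tokens, text_tokens, w_q, w_k, w_v):
    """Queries come from the image; keys and values come from the text."""
    q = image_tokens @ w_q
    k = text_tokens @ w_k
    v = text_tokens @ w_v
    weights = softmax(q @ k.T / np.sqrt(q.shape[-1]))   # (n_img, n_txt)
    return weights @ v, weights

rng = np.random.default_rng(0)
n_img, n_txt, d_img, d_txt, d_k = 16, 5, 8, 12, 8       # toy sizes
image_tokens = rng.normal(size=(n_img, d_img))
text_tokens = rng.normal(size=(n_txt, d_txt))
w_q = rng.normal(size=(d_img, d_k))
w_k = rng.normal(size=(d_txt, d_k))
w_v = rng.normal(size=(d_txt, d_k))

out, weights = cross_attention(image_tokens, text_tokens, w_q, w_k, w_v)
print(out.shape)      # one updated vector per image token
print(weights.shape)  # each image token's distribution over text tokens
```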

### Limitations of SD 1.x

- **77-token CLIP limit**: Prompts longer than CLIP's 77-token context window are truncated
- **Limited spatial understanding**: Struggles with complex compositions ("a red cube on top of a blue sphere")
- **One-way conditioning**: Text cannot adapt based on what the image looks like
- **860M parameters**: Relatively small model capacity

In [None]:
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
# Import the patch classes by name (aliasing the whole module as FancyBboxPatch was a bug)
from matplotlib.patches import FancyArrowPatch, FancyBboxPatch

# ============================================================
# UNet Architecture Diagram
# ============================================================

fig, ax = plt.subplots(figsize=(14, 9))
ax.set_xlim(0, 14)
ax.set_ylim(0, 10)
ax.axis('off')

def draw_block(ax, x, y, w, h, label, color, fontsize=8, text_color='white'):
    """Draw a rounded rectangle block with a centered label."""
    rect = mpatches.FancyBboxPatch(
        (x, y), w, h,
        boxstyle="round,pad=0.1",
        facecolor=color, edgecolor='white', linewidth=1.5
    )
    ax.add_patch(rect)
    ax.text(x + w / 2, y + h / 2, label, ha='center', va='center',
            fontsize=fontsize, fontweight='bold', color=text_color)

def draw_arrow(ax, start, end, color='#455A64', style='->'):
    """Draw an arrow from start to end."""
    ax.annotate('', xy=end, xytext=start,
                arrowprops=dict(arrowstyle=style, color=color, lw=2))

# --- Title ---
ax.text(7, 9.5, 'UNet Architecture (SD 1.5)', ha='center',
        fontsize=16, fontweight='bold', color='#212121')

# --- Encoder (left side, going down) ---
enc_color = '#1565C0'
enc_x = 1.5
enc_positions = [
    (enc_x, 7.8, 2.0, 0.9, '64x64\nResBlock + Attn'),
    (enc_x, 5.8, 2.0, 0.9, '32x32\nResBlock + Attn'),
    (enc_x, 3.8, 2.0, 0.9, '16x16\nResBlock + Attn'),
]

for (x, y, w, h, label) in enc_positions:
    draw_block(ax, x, y, w, h, label, enc_color, fontsize=8)

# Encoder label
ax.text(enc_x + 1.0, 9.0, 'ENCODER', ha='center', fontsize=11,
        fontweight='bold', color=enc_color)

# Encoder arrows (down); the 16x16 block feeds the bottleneck via the arrow drawn below
draw_arrow(ax, (enc_x + 1.0, 7.8), (enc_x + 1.0, 6.7), color=enc_color)
draw_arrow(ax, (enc_x + 1.0, 5.8), (enc_x + 1.0, 4.7), color=enc_color)

# --- Bottleneck ---
bneck_color = '#4A148C'
draw_block(ax, 5.5, 1.8, 3.0, 1.0, 'BOTTLENECK\n8x8 (most compressed)', bneck_color, fontsize=9)

# Encoder to bottleneck
draw_arrow(ax, (enc_x + 1.0, 3.8), (5.5, 2.3), color='#455A64')

# --- Decoder (right side, going up) ---
dec_color = '#C62828'
dec_x = 10.5
dec_positions = [
    (dec_x, 3.8, 2.0, 0.9, '16x16\nResBlock + Attn'),
    (dec_x, 5.8, 2.0, 0.9, '32x32\nResBlock + Attn'),
    (dec_x, 7.8, 2.0, 0.9, '64x64\nResBlock + Attn'),
]

for (x, y, w, h, label) in dec_positions:
    draw_block(ax, x, y, w, h, label, dec_color, fontsize=8)

# Decoder label
ax.text(dec_x + 1.0, 9.0, 'DECODER', ha='center', fontsize=11,
        fontweight='bold', color=dec_color)

# Bottleneck to decoder
draw_arrow(ax, (8.5, 2.3), (dec_x + 1.0, 3.8), color='#455A64')

# Decoder arrows (up)
draw_arrow(ax, (dec_x + 1.0, 4.7), (dec_x + 1.0, 5.8), color=dec_color)
draw_arrow(ax, (dec_x + 1.0, 6.7), (dec_x + 1.0, 7.8), color=dec_color)

# --- Skip Connections ---
skip_color = '#FF8F00'
# Each skip links the encoder and decoder blocks at the same resolution (same y)
for y in [7.8, 5.8, 3.8]:
    ax.annotate('',
                xy=(dec_x, y + 0.45), xytext=(enc_x + 2.0, y + 0.45),
                arrowprops=dict(arrowstyle='->', color=skip_color, lw=2,
                                linestyle='dashed',
                                connectionstyle='arc3,rad=0.0'))

ax.text(7, 8.6, 'Skip Connections', ha='center', fontsize=9,
        fontweight='bold', color=skip_color, style='italic')

# --- CLIP Text Encoder (side panel) ---
clip_color = '#00695C'
draw_block(ax, 5.8, 5.5, 2.4, 0.7, 'CLIP Text\nEncoder', clip_color, fontsize=9)

# Cross-attention labels
ax.text(7.0, 5.0, 'Cross-Attention\n(text -> image)', ha='center',
        fontsize=8, color=clip_color, style='italic')

# Arrows from CLIP to encoder/decoder cross-attention points
# Left arrows (to encoder blocks)
draw_arrow(ax, (5.8, 5.85), (3.5, 6.25), color=clip_color)
draw_arrow(ax, (5.8, 5.85), (3.5, 4.25), color=clip_color)

# Right arrows (to decoder blocks)
draw_arrow(ax, (8.2, 5.85), (10.5, 6.25), color=clip_color)
draw_arrow(ax, (8.2, 5.85), (10.5, 4.25), color=clip_color)

# --- Downsampling / upsampling labels ---
ax.text(enc_x + 1.0, 9.0 - 0.3, '(downsampling)', ha='center',
        fontsize=8, color=enc_color, style='italic')
ax.text(dec_x + 1.0, 9.0 - 0.3, '(upsampling)', ha='center',
        fontsize=8, color=dec_color, style='italic')

# --- Legend ---
legend_items = [
    mpatches.Patch(color=enc_color, label='Encoder blocks (downsample)'),
    mpatches.Patch(color=dec_color, label='Decoder blocks (upsample)'),
    mpatches.Patch(color=bneck_color, label='Bottleneck'),
    mpatches.Patch(color=clip_color, label='Text conditioning (cross-attention)'),
    mpatches.Patch(color=skip_color, label='Skip connections'),
]
ax.legend(handles=legend_items, loc='lower center', ncol=3,
          fontsize=8, framealpha=0.9, edgecolor='#BDBDBD')

plt.tight_layout()
plt.show()

## SDXL --- Scaling the UNet (2023)

SDXL represented the peak of UNet-based diffusion models, demonstrating that significant quality gains were still possible within the existing architecture.

### Key Innovations

| Feature | SD 1.5 | SDXL |
|---------|--------|------|
| Parameters | 860M | **2.6B** (3x larger) |
| Text Encoders | CLIP-L only | **CLIP-L + CLIP-G** (dual encoder) |
| Base Resolution | 512x512 | **1024x1024** |
| Conditioning | Text only | Text + **micro-conditioning** |

### Dual CLIP Encoders

SDXL concatenates embeddings from two different CLIP models:
- **CLIP-L** (ViT-L/14): The smaller encoder, generally credited with capturing concepts and objects
- **CLIP-G** (OpenCLIP ViT-bigG/14): A much larger encoder, generally credited with capturing style, composition, and nuance

The concatenated embedding provides a richer, more nuanced text representation.
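
The concatenation itself is a one-liner. In this sketch random arrays stand in for real encoder outputs, and the hidden widths (768 for CLIP-L, 1280 for CLIP-G) and the 77-token length are stated as assumptions rather than taken from this notebook.

```python
import numpy as np

rng = np.random.default_rng(0)

seq_len = 77                           # assumed CLIP context length
clip_l_dim, clip_g_dim = 768, 1280     # assumed hidden widths of the two encoders

# Random stand-ins for the per-token outputs of each encoder.
clip_l_tokens = rng.normal(size=(seq_len, clip_l_dim))
clip_g_tokens = rng.normal(size=(seq_len, clip_g_dim))

# Concatenate along the feature axis: same tokens, wider vectors.
context = np.concatenate([clip_l_tokens, clip_g_tokens], axis=-1)
print(context.shape)  # (77, 2048)
```

The sequence length stays the same; each token simply carries both encoders' features side by side.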

### Micro-Conditioning

SDXL introduced conditioning on **image metadata** during training:
- **Original resolution** of the training image
- **Crop coordinates** (top, left)
- **Target resolution** for generation

This eliminated common artifacts like awkward cropping and helped the model understand image composition better.
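
One way to picture this conditioning: each of the six scalars is turned into a sinusoidal (Fourier-feature) vector and the results are concatenated. The sketch below is a rough illustration under assumptions; the helper names, the 256-dim embedding size, and the example values are hypothetical, and how the vector is then fed to the denoiser is left out.

```python
import numpy as np

def sinusoidal_embedding(value, dim=256, max_period=10000.0):
    """Embed one scalar as cos/sin features at geometrically spaced frequencies."""
    half = dim // 2
    freqs = np.exp(-np.log(max_period) * np.arange(half) / half)
    angles = value * freqs
    return np.concatenate([np.cos(angles), np.sin(angles)])

def micro_conditioning(original_size, crop_top_left, target_size, dim=256):
    """Concatenate embeddings of the six metadata scalars."""
    scalars = [*original_size, *crop_top_left, *target_size]
    return np.concatenate([sinusoidal_embedding(s, dim) for s in scalars])

# Hypothetical training image: 768x512 original, cropped 64 px from the left.
cond = micro_conditioning(
    original_size=(768, 512),
    crop_top_left=(0, 64),
    target_size=(1024, 1024),
)
print(cond.shape)  # (1536,)
```

At inference time, passing zero crop coordinates and a large original size asks the model for an uncropped, high-resolution-looking image.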

### Optional Refiner Model

SDXL can optionally use a **refiner** model that takes the base model's output and enhances fine details. This two-stage approach improves texture quality and small details at the cost of additional inference time.

> **SDXL proved that UNets could still scale, but also revealed their limits** --- the architecture was becoming increasingly complex and difficult to scale further.

## SD3 --- The MMDiT Revolution (2024)

Stable Diffusion 3 represents the most significant architectural shift in the SD lineage: **replacing the UNet entirely with a Transformer**. This is the KEY innovation that separates the "old" and "new" eras of diffusion models.

---

### From UNet to Transformer: Why?

The DiT paper (Peebles & Xie, 2023) demonstrated that Vision Transformers could replace UNets as diffusion denoisers. The advantages:

1. **Better scaling**: Transformers scale more predictably with parameters (well-studied scaling laws)
2. **Simpler architecture**: A uniform stack of identical blocks, with no encoder/decoder hierarchy and no long-range skip connections between resolution levels
3. **Unified attention**: All information flows through the same attention mechanism
4. **Hardware efficiency**: Transformers are better optimized on modern GPUs/TPUs

---

### MMDiT: Multimodal Diffusion Transformer

SD3's specific innovation is the **MMDiT (Multimodal Diffusion Transformer)** block. This is not just "a Transformer for images" --- it fundamentally changes how text and image interact.

#### The Core Idea: Two Streams, Joint Attention

MMDiT maintains **two separate processing streams**:

| Stream | Input | Purpose |
|--------|-------|---------|
| **Image stream** | Noisy image latents (patchified) | Processes visual information |
| **Text stream** | Text encoder outputs | Processes linguistic information |
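
"Patchified" means the latent grid is cut into small square patches, each flattened into one token. A minimal NumPy sketch follows; the patch size of 2 and the 4-channel 64x64 latent are assumptions for illustration, and the learned linear projection that a real model applies to each patch is omitted.

```python
import numpy as np

def patchify(latent, patch_size=2):
    """(C, H, W) latent -> (num_patches, patch_size * patch_size * C) tokens."""
    c, h, w = latent.shape
    p = patch_size
    x = latent.reshape(c, h // p, p, w // p, p)
    x = x.transpose(1, 3, 2, 4, 0)                  # (H/p, W/p, p, p, C)
    return x.reshape((h // p) * (w // p), p * p * c)

latent = np.arange(4 * 64 * 64, dtype=np.float32).reshape(4, 64, 64)
tokens = patchify(latent)
print(tokens.shape)  # (1024, 16)
```

With these sizes the image stream sees 1,024 tokens, which is the sequence length the attention layers operate on.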

Each stream has its own:
- LayerNorm
- QKV (Query, Key, Value) projections
- Feed-Forward Network (MLP)

But they **share a single attention operation** --- this is the critical difference.

#### Joint Attention: The Breakthrough

In the UNet's cross-attention:
```
Q = image,  K = text,  V = text     (text -> image, ONE-WAY)
```

In MMDiT's joint attention:
```
Q = [image_Q ; text_Q]              (concatenated)
K = [image_K ; text_K]              (concatenated)  
V = [image_V ; text_V]              (concatenated)

Attention = softmax(QK^T / sqrt(d)) * V    (TWO-WAY!)
```

This means:
- **Image tokens attend to text tokens** (as before)
- **Text tokens attend to image tokens** (NEW!)
- **Image tokens attend to other image tokens** (self-attention)
- **Text tokens attend to other text tokens** (self-attention)

All four types of attention happen in a **single attention matrix**.
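
The four blocks can be made visible in a small single-head NumPy sketch. This is a simplification under assumptions: toy sizes, random weights, no multi-head split, and none of the normalization or timestep modulation a real MMDiT block would include. Each stream keeps its own projections, but one softmax runs over the concatenated sequence.

```python
import numpy as np

def softmax(x, axis=-1):
    """Numerically stable softmax."""
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def joint_attention(img, txt, img_proj, txt_proj):
    """img: (n_img, d), txt: (n_txt, d). Each *_proj is a (w_q, w_k, w_v) tuple."""
    q = np.concatenate([img @ img_proj[0], txt @ txt_proj[0]])
    k = np.concatenate([img @ img_proj[1], txt @ txt_proj[1]])
    v = np.concatenate([img @ img_proj[2], txt @ txt_proj[2]])
    weights = softmax(q @ k.T / np.sqrt(q.shape[-1]))
    out = weights @ v
    n_img = img.shape[0]
    return out[:n_img], out[n_img:], weights       # split back into two streams

rng = np.random.default_rng(0)
n_img, n_txt, d = 6, 3, 8                           # toy sizes
img = rng.normal(size=(n_img, d))
txt = rng.normal(size=(n_txt, d))
img_proj = tuple(rng.normal(size=(d, d)) for _ in range(3))
txt_proj = tuple(rng.normal(size=(d, d)) for _ in range(3))

img_out, txt_out, weights = joint_attention(img, txt, img_proj, txt_proj)

# The (n_img + n_txt) x (n_img + n_txt) weight matrix contains all four blocks.
img_to_img = weights[:n_img, :n_img]
img_to_txt = weights[:n_img, n_img:]
txt_to_img = weights[n_img:, :n_img]   # the block one-way cross-attention lacks
txt_to_txt = weights[n_img:, n_img:]
print(weights.shape, txt_to_img.shape)
```

Because `txt_out` depends on `txt_to_img`, the text tokens passed to the next block already reflect the current image state.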

#### Why Two-Way Attention Matters

With one-way cross-attention (UNet), the text representation is **frozen** --- it cannot adapt based on the current state of the image. With joint attention (MMDiT), the text representation is **dynamically refined** at every layer based on the evolving image.

This enables:
- Better spatial reasoning ("the cat is ON TOP OF the box")
- More accurate multi-object compositions
- Better text rendering in images
- More faithful prompt following overall

---

### Triple Text Encoders

SD3 uses **three** text encoders simultaneously:

| Encoder | Type | Token Limit | Strength |
|---------|------|-------------|----------|
| CLIP-L | Contrastive | 77 | Object/concept understanding |
| CLIP-G | Contrastive | 77 | Style/composition understanding |
| T5-XXL | Text-to-text model (encoder half only) | 512 | Long, detailed text understanding |

The T5-XXL encoder is particularly important: it breaks the 77-token barrier, enabling detailed prompts that describe complex scenes.

---

### Flow Matching Replaces DDPM

SD3 also replaces the traditional DDPM noise schedule with **Rectified Flow Matching** (covered in depth in the next notebook). The key benefit: straighter denoising trajectories that require **fewer sampling steps**.
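
As a preview of why straight trajectories help, here is a minimal sketch of the rectified-flow setup. It assumes the convention in which `t = 0` is data and `t = 1` is pure noise, and it uses the exact velocity for a single sample pair, which a trained model can only approximate.

```python
import numpy as np

rng = np.random.default_rng(0)
x0 = rng.normal(size=(4, 8, 8))       # stand-in for a clean latent
noise = rng.normal(size=x0.shape)     # Gaussian noise sample

def interpolate(x0, noise, t):
    """Straight-line path between data (t=0) and noise (t=1)."""
    return (1.0 - t) * x0 + t * noise

# Training target: the velocity along the path. For a straight line it is
# the same at every t.
velocity = noise - x0

# With the exact velocity, one Euler step from t=1 to t=0 recovers the data.
x_t = interpolate(x0, noise, 1.0)
x_recovered = x_t + (0.0 - 1.0) * velocity

print(np.allclose(x_recovered, x0))  # True
```

In practice the learned velocity field is averaged over many data and noise pairs, so it is not perfectly straight and several steps are still used, but fewer than with a curved trajectory.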

> **SD3 is where everything changed.** MMDiT + Flow Matching + triple encoders = a completely new generation of diffusion models.

In [None]:
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches

# ============================================================
# MMDiT Block Diagram
# ============================================================

fig, ax = plt.subplots(figsize=(14, 10))
ax.set_xlim(0, 14)
ax.set_ylim(0, 11)
ax.axis('off')

def draw_box(ax, x, y, w, h, label, color, fontsize=9, text_color='white',
             alpha=1.0, linestyle='-'):
    """Draw a rounded box with label."""
    rect = mpatches.FancyBboxPatch(
        (x, y), w, h,
        boxstyle="round,pad=0.12",
        facecolor=color, edgecolor='white', linewidth=1.5,
        alpha=alpha
    )
    ax.add_patch(rect)
    ax.text(x + w / 2, y + h / 2, label, ha='center', va='center',
            fontsize=fontsize, fontweight='bold', color=text_color)

def arrow(ax, start, end, color='#455A64'):
    ax.annotate('', xy=end, xytext=start,
                arrowprops=dict(arrowstyle='->', color=color, lw=2))

# --- Title ---
ax.text(7, 10.5, 'MMDiT Block (Multimodal Diffusion Transformer)',
        ha='center', fontsize=16, fontweight='bold', color='#212121')
ax.text(7, 10.0, 'Two-way information flow between image and text',
        ha='center', fontsize=11, color='#616161', style='italic')

# --- Colors ---
img_color = '#1565C0'      # blue for image stream
txt_color = '#C62828'      # red for text stream
joint_color = '#6A1B9A'    # purple for joint attention
ffn_img = '#1976D2'        # lighter blue
ffn_txt = '#E53935'        # lighter red
norm_color = '#546E7A'     # gray for norms

# ===================== LEFT STREAM (Image) =====================
left_x = 1.5
stream_w = 3.0

# Input
draw_box(ax, left_x, 8.5, stream_w, 0.7, 'Image Latents (patchified)',
         img_color, fontsize=9)

# LayerNorm
draw_box(ax, left_x, 7.3, stream_w, 0.6, 'LayerNorm', norm_color, fontsize=9)
arrow(ax, (left_x + stream_w / 2, 8.5), (left_x + stream_w / 2, 7.9))

# QKV Projection
draw_box(ax, left_x, 6.1, stream_w, 0.7, 'Q_img, K_img, V_img\n(Linear Projections)',
         img_color, fontsize=8)
arrow(ax, (left_x + stream_w / 2, 7.3), (left_x + stream_w / 2, 6.8))

# ===================== RIGHT STREAM (Text) =====================
right_x = 9.5

# Input
draw_box(ax, right_x, 8.5, stream_w, 0.7, 'Text Tokens (from CLIP/T5)',
         txt_color, fontsize=9)

# LayerNorm
draw_box(ax, right_x, 7.3, stream_w, 0.6, 'LayerNorm', norm_color, fontsize=9)
arrow(ax, (right_x + stream_w / 2, 8.5), (right_x + stream_w / 2, 7.9))

# QKV Projection
draw_box(ax, right_x, 6.1, stream_w, 0.7, 'Q_txt, K_txt, V_txt\n(Linear Projections)',
         txt_color, fontsize=8)
arrow(ax, (right_x + stream_w / 2, 7.3), (right_x + stream_w / 2, 6.8))

# ===================== JOINT ATTENTION (Center) =====================
joint_x = 4.0
joint_w = 6.0

draw_box(ax, joint_x, 4.3, joint_w, 1.2,
         'JOINT ATTENTION\n'
         'Q = [Q_img ; Q_txt]   K = [K_img ; K_txt]   V = [V_img ; V_txt]',
         joint_color, fontsize=9)

# Arrows from both QKV blocks into joint attention
arrow(ax, (left_x + stream_w / 2, 6.1), (joint_x + 1.5, 5.5), color=img_color)
arrow(ax, (right_x + stream_w / 2, 6.1), (joint_x + joint_w - 1.5, 5.5),
      color=txt_color)

# Label the two-way flow
ax.text(7, 3.8, 'Image attends to Text  &  Text attends to Image',
        ha='center', fontsize=10, fontweight='bold', color=joint_color,
        style='italic')

# ===================== OUTPUTS: Split back =====================
# FFN Image
draw_box(ax, left_x, 2.2, stream_w, 0.7, 'Feed-Forward Net\n(Image)', ffn_img, fontsize=9)
arrow(ax, (joint_x + 1.5, 4.3), (left_x + stream_w / 2, 2.9), color=img_color)

# FFN Text
draw_box(ax, right_x, 2.2, stream_w, 0.7, 'Feed-Forward Net\n(Text)', ffn_txt, fontsize=9)
arrow(ax, (joint_x + joint_w - 1.5, 4.3), (right_x + stream_w / 2, 2.9),
      color=txt_color)

# Output labels
draw_box(ax, left_x, 1.0, stream_w, 0.6, 'Updated Image Latents',
         img_color, fontsize=9, alpha=0.7)
arrow(ax, (left_x + stream_w / 2, 2.2), (left_x + stream_w / 2, 1.6),
      color=img_color)

draw_box(ax, right_x, 1.0, stream_w, 0.6, 'Updated Text Tokens',
         txt_color, fontsize=9, alpha=0.7)
arrow(ax, (right_x + stream_w / 2, 2.2), (right_x + stream_w / 2, 1.6),
      color=txt_color)

# --- Comparison annotation ---
ax.text(7, 0.3,
        'UNet cross-attention: text \u2192 image (one-way)  |  '
        'MMDiT joint attention: text \u2194 image (two-way)',
        ha='center', fontsize=10, fontweight='bold',
        color='#37474F',
        bbox=dict(boxstyle='round,pad=0.4', facecolor='#FFF9C4',
                  edgecolor='#F9A825', linewidth=1.5))

# --- Stream labels ---
ax.text(left_x + stream_w / 2, 9.5, 'IMAGE STREAM',
        ha='center', fontsize=12, fontweight='bold', color=img_color)
ax.text(right_x + stream_w / 2, 9.5, 'TEXT STREAM',
        ha='center', fontsize=12, fontweight='bold', color=txt_color)

plt.tight_layout()
plt.show()

## FLUX.1 --- Pushing the Frontier (August 2024)

FLUX.1, released by **Black Forest Labs** in August 2024, was among the strongest openly available image generators at its release. The team behind it includes several of the original creators of Stable Diffusion, who left Stability AI to found their own company.

### What Makes FLUX.1 Special

| Feature | Details |
|---------|--------|
| **Scale** | 12 billion parameters --- 6x larger than SD3 Medium |
| **Architecture** | MMDiT (inherited from SD3) |
| **Text Encoders** | CLIP-L + T5-XXL (dual encoder) |
| **Training** | Flow Matching + Guidance Distillation |
| **License** | Apache 2.0 (schnell) --- fully open source |

### Three Variants

| Variant | Steps | Speed | Use Case | License |
|---------|-------|-------|----------|--------|
| **FLUX.1-schnell** | 4 | Very fast | Real-time / interactive | Apache 2.0 |
| **FLUX.1-dev** | ~50 | Moderate | High-quality generation | Non-commercial |
| **FLUX.1-pro** | ~50 | Moderate | Commercial applications | Commercial API |

### Guidance Distillation: The Speed Secret

Traditional diffusion models use **classifier-free guidance (CFG)** at inference time, which requires **two forward passes** per step:
1. One pass with the text prompt (conditional)
2. One pass without text (unconditional)

The two predictions are then combined: final output = unconditional + scale * (conditional - unconditional).

**Guidance distillation** trains the model to internalize this guidance behavior, so it only needs a **single forward pass** per step. FLUX.1-schnell is additionally distilled for few-step sampling (its model card cites latent adversarial diffusion distillation). Combined with flow matching's straighter trajectories, this enables FLUX.1-schnell to generate quality images in just **4 steps**.
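
The cost of classifier-free guidance is easy to see in a toy sketch. `toy_denoiser` is a hypothetical stand-in that only counts how often it is called; the arithmetic inside it is arbitrary and the guidance scale of 7.5 is just an example value.

```python
import numpy as np

calls = {"count": 0}

def toy_denoiser(latent, conditioning):
    """Hypothetical stand-in for a denoiser; counts forward passes."""
    calls["count"] += 1
    if conditioning is None:
        return 0.1 * latent                  # unconditional prediction
    return 0.1 * latent + conditioning       # conditional prediction

def cfg_prediction(latent, conditioning, scale):
    """Classifier-free guidance: two forward passes, then extrapolate."""
    cond = toy_denoiser(latent, conditioning)
    uncond = toy_denoiser(latent, None)
    return uncond + scale * (cond - uncond)

latent = np.ones(4)
conditioning = np.full(4, 0.5)

guided = cfg_prediction(latent, conditioning, scale=7.5)
passes_per_step = calls["count"]
print(passes_per_step)  # 2
```

A guidance-distilled model would be trained to reproduce `guided` from a single call, typically with the scale supplied as an extra input, halving the per-step cost.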

### The Significance of FLUX.1

FLUX.1 is significant not just for its quality, but for what it proves:
- **MMDiT scales**: Going from 2B (SD3) to 12B yields clear quality improvements
- **Open source can compete**: FLUX.1-schnell is competitive with many closed models
- **Speed and quality need not be a strict trade-off**: With the right training (distillation + flow matching), few-step sampling can stay close to full-quality sampling
- **The SD lineage continues**: Despite the architectural revolution, the latent diffusion framework remains the foundation

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# ============================================================
# Parameter Scaling Visualization
# ============================================================

models = ['SD 1.5\n(2022)', 'SD 2.1\n(2022)', 'SDXL\n(2023)', 'SD3 Med\n(2024)', 'FLUX.1\n(2024)']
params = [0.86, 0.865, 2.6, 2.0, 12.0]
colors = ['#4CAF50', '#66BB6A', '#FFA726', '#FF7043', '#E53935']
denoiser = ['UNet', 'UNet', 'UNet', 'MMDiT', 'MMDiT']

fig, ax = plt.subplots(figsize=(12, 6))
bars = ax.bar(models, params, color=colors, edgecolor='white', linewidth=2)

for bar, param, arch in zip(bars, params, denoiser):
    ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.3,
            f'{param}B\n({arch})', ha='center', va='bottom',
            fontsize=11, fontweight='bold')

ax.set_ylabel('Parameters (Billions)', fontsize=12)
ax.set_title('Diffusion Model Parameter Scaling: 2022-2024',
             fontsize=14, fontweight='bold')
ax.set_ylim(0, 15)
ax.grid(axis='y', alpha=0.3)
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)

# Add annotation for architecture shift
ax.axvline(x=2.5, color='gray', linestyle='--', alpha=0.5)
ax.text(1.0, 14, 'UNet Era', ha='center', fontsize=12, style='italic', color='gray')
ax.text(3.5, 14, 'Transformer Era', ha='center', fontsize=12, style='italic', color='gray')

plt.tight_layout()
plt.show()

In [None]:
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import numpy as np

# ============================================================
# Key Innovations Timeline
# ============================================================

fig, ax = plt.subplots(figsize=(14, 6))
ax.set_xlim(2019.5, 2025)
ax.set_ylim(-2, 5)
ax.axis('off')

# Title
ax.text(2022.25, 4.7, 'Timeline of Diffusion Model Innovations',
        ha='center', fontsize=16, fontweight='bold', color='#212121')

# Main timeline axis
ax.plot([2019.8, 2024.9], [0, 0], color='#37474F', linewidth=3, zorder=1)

# Year markers
for year in [2020, 2021, 2022, 2023, 2024]:
    ax.plot(year, 0, 'o', color='#37474F', markersize=10, zorder=2)
    ax.text(year, -0.5, str(year), ha='center', fontsize=11,
            fontweight='bold', color='#37474F')

# Events
events = [
    (2020, 1.5, 'DDPM\n(Ho et al.)', '#4CAF50'),
    (2021, 2.5, 'Improved DDPM\n+\nLatent Diffusion\n(Rombach et al.)', '#66BB6A'),
    (2022, 3.5, 'SD 1.x / SD 2.x\n(Public Release)', '#2196F3'),
    (2023, 2.5, 'SDXL (2.6B UNet)\n+\nDiT Paper\n(Peebles & Xie)', '#FFA726'),
    (2024.15, 3.5, 'SD3 (Feb)\nMMDiT + Flow\nMatching', '#FF7043'),
    (2024.6, 1.5, 'FLUX.1 (Aug)\n12B MMDiT\n+ Guidance\nDistillation', '#E53935'),
]

for (x, y, label, color) in events:
    # Vertical connector line
    ax.plot([x, x], [0.15, y - 0.15], color=color, linewidth=2, zorder=1)
    # Event dot on timeline
    ax.plot(x, 0, 'o', color=color, markersize=14, zorder=3)
    # Label box
    ax.text(x, y, label, ha='center', va='bottom', fontsize=8.5,
            fontweight='bold', color='white',
            bbox=dict(boxstyle='round,pad=0.4', facecolor=color,
                      edgecolor='white', linewidth=1.5, alpha=0.95))

# Architecture era annotation
ax.annotate('', xy=(2024.0, -1.2), xytext=(2020.0, -1.2),
            arrowprops=dict(arrowstyle='<->', color='#1565C0', lw=2))
ax.text(2022.0, -1.5, 'UNet-based denoiser', ha='center', fontsize=10,
        color='#1565C0', fontweight='bold')

ax.annotate('', xy=(2024.9, -1.2), xytext=(2024.0, -1.2),
            arrowprops=dict(arrowstyle='<->', color='#C62828', lw=2))
ax.text(2024.45, -1.5, 'Transformer\n(MMDiT)', ha='center', fontsize=10,
        color='#C62828', fontweight='bold')

plt.tight_layout()
plt.show()

## Key Takeaways

---

### 1. UNet to Transformer: A Paradigm Shift
The shift from UNet to Transformer-based denoisers (DiT/MMDiT) enables better scaling, simpler architectures, and more effective use of compute. Transformers have well-studied scaling behavior: given enough data and compute, adding parameters tends to improve quality.

### 2. MMDiT's Two-Way Attention Is the Critical Innovation
The UNet's cross-attention only allowed text to influence image generation (one-way). MMDiT's joint attention enables **bidirectional information flow**: text and image representations refine each other at every layer. This is a major reason SD3 and FLUX.1 follow prompts more faithfully.

### 3. Flow Matching Enables Fewer Steps
By training with rectified flow matching instead of DDPM, the denoising trajectories become straighter and require fewer steps to traverse. SD3 needs ~28 steps vs. SD 1.5's 50 steps for comparable quality.

### 4. Guidance Distillation Enables Real-Time Generation
FLUX.1-schnell combines flow matching with distillation to generate quality images in just 4 steps, with a single forward pass per step (no separate unconditional pass for CFG). This brings diffusion generation close to interactive speeds.

### 5. The Latent Diffusion Framework Endures
Despite all the architectural changes in the denoiser, the overall framework remains the same: encode to latent space, denoise, decode. This stability means many tools and techniques (LoRA, ControlNet, IP-Adapter) can be adapted across model generations.

---

| Era | Models | Key Advance |
|-----|--------|-------------|
| **UNet Era** | SD 1.5, SD 2.1, SDXL | Cross-attention conditioning, scaling parameters |
| **Transformer Era** | SD3, FLUX.1 | MMDiT joint attention, flow matching, guidance distillation |

---

*Next notebook: We dive deep into Flow Matching --- the training paradigm that replaced DDPM and enabled the efficiency gains of SD3 and FLUX.1.*

## References

1. **Rombach et al.** (2022). *High-Resolution Image Synthesis with Latent Diffusion Models.* CVPR 2022. [arXiv:2112.10752](https://arxiv.org/abs/2112.10752)

2. **Peebles & Xie** (2023). *Scalable Diffusion Models with Transformers (DiT).* ICCV 2023. [arXiv:2212.09748](https://arxiv.org/abs/2212.09748)

3. **Esser et al.** (2024). *Scaling Rectified Flow Transformers for High-Resolution Image Synthesis (SD3).* [arXiv:2403.03206](https://arxiv.org/abs/2403.03206)

4. **Black Forest Labs** (2024). *FLUX.1: An open-source text-to-image model.* [https://blackforestlabs.ai](https://blackforestlabs.ai)

5. **Ho et al.** (2020). *Denoising Diffusion Probabilistic Models.* NeurIPS 2020. [arXiv:2006.11239](https://arxiv.org/abs/2006.11239)

6. **Podell et al.** (2023). *SDXL: Improving Latent Diffusion Models for High-Resolution Image Synthesis.* [arXiv:2307.01952](https://arxiv.org/abs/2307.01952)