# Consistency Models

**Module 7.3, Lesson 1** | CourseAI

You know the theory from the lesson: consistency models use the self-consistency property of deterministic ODE trajectories—any point on the same trajectory maps to the same clean endpoint—as a training objective, bypassing multi-step ODE solving entirely. This notebook makes that concrete.

**What you will do:**
- Visualize ODE trajectories from a pretrained 2D diffusion model and verify the self-consistency property with real numbers
- Compare one-step Euler, one-step DDIM, and the true ODE endpoint—see why single-step ODE methods fail on curved trajectories
- Train a toy consistency model via distillation on 2D two-moons data and generate one-step samples
- Compare multi-step consistency generation (1, 2, 4, 8 steps) to the teacher's ODE solving (1, 5, 10, 20, 50 steps)

**For each exercise, PREDICT the output before running the cell.**

Every concept in this notebook comes from the lesson: the self-consistency property, the consistency function, consistency distillation, and multi-step consistency. There is no new theory here, just hands-on practice with the math and models.

**Estimated time:** 35–50 minutes. Exercises 1–2 use a pretrained model (no training). Exercises 3–4 train small MLPs on 2D data (~2–3 minutes on CPU).

## Setup

Run this cell to import everything and configure the environment.

No GPU required for this notebook. Everything runs on CPU. The models are tiny MLPs trained on 2D point distributions.

In [None]:
!pip install -q torch numpy matplotlib scikit-learn

In [None]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_moons

# Reproducible results
torch.manual_seed(42)
np.random.seed(42)

# Nice plots
plt.style.use('dark_background')
plt.rcParams['figure.figsize'] = [10, 4]
plt.rcParams['figure.dpi'] = 100

print('Setup complete. No GPU needed for this notebook.')

## Shared Helpers

A small MLP for velocity prediction (the pretrained teacher), data generation utilities, and ODE solving helpers. Run this cell now—it defines everything needed for all four exercises.

In [None]:
# ============================================================
# Shared: MLP for 2D generative models
# ============================================================

class ToyModel(nn.Module):
    """MLP that takes (x_t, t) and outputs a 2D vector.
    
    For flow matching: output = predicted velocity v_theta(x_t, t)
    For consistency models: output = predicted clean point f_theta(x_t, t)
    Same architecture, different training target.
    """
    def __init__(self, hidden_dim=128):
        super().__init__()
        self.time_mlp = nn.Sequential(
            nn.Linear(1, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )
        self.net = nn.Sequential(
            nn.Linear(2 + hidden_dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, 2),
        )

    def forward(self, x_t, t):
        if t.dim() == 1:
            t = t.unsqueeze(-1)
        t_emb = self.time_mlp(t)
        inp = torch.cat([x_t, t_emb], dim=-1)
        return self.net(inp)


class ConsistencyModel(nn.Module):
    """Consistency model with skip connection enforcing the boundary condition.
    
    f(x_t, t) = c_skip(t) * x_t + c_out(t) * F_theta(x_t, t)
    
    At t = epsilon (near 0): c_skip -> 1, c_out -> 0, so f(x, epsilon) = x.
    At t = T (high noise): c_skip -> 0, c_out -> 1, so the network has full control.
    
    This enforces the boundary condition f(x, epsilon) = x by construction.
    """
    def __init__(self, hidden_dim=128, sigma_min=0.002):
        super().__init__()
        self.sigma_min = sigma_min
        self.backbone = ToyModel(hidden_dim=hidden_dim)
    
    def forward(self, x_t, t):
        """Predict the clean endpoint f(x_t, t).
        
        x_t: (batch, 2) -- noisy point
        t: (batch,) or (batch, 1) -- noise level in [sigma_min, 1]
        """
        if t.dim() == 2:
            t_scalar = t.squeeze(-1)
        else:
            t_scalar = t
        
        # Skip connection coefficients
        # c_skip(t) = sigma_min / t -- goes to 1 as t -> sigma_min
        # c_out(t) = (t - sigma_min) / t -- goes to 0 as t -> sigma_min
        c_skip = (self.sigma_min / t_scalar).unsqueeze(-1)  # (batch, 1)
        c_out = ((t_scalar - self.sigma_min) / t_scalar).unsqueeze(-1)  # (batch, 1)
        
        # Backbone prediction
        F = self.backbone(x_t, t_scalar)
        
        # f(x_t, t) = c_skip * x_t + c_out * F_theta(x_t, t)
        return c_skip * x_t + c_out * F


def sample_two_moons(n, noise=0.06):
    """Sample from the two-moons distribution."""
    data, _ = make_moons(n_samples=n, noise=noise)
    data = (data - data.mean(axis=0)) * 2.0
    return torch.tensor(data, dtype=torch.float32)


@torch.no_grad()
def solve_ode(model, x_start, t_start, t_end, n_steps):
    """Solve the flow matching ODE from t_start to t_end using Euler's method.
    
    Returns the full trajectory as a list of (x, t) pairs.
    The model predicts velocity v(x_t, t), and we step:
        x_{t-dt} = x_t - dt * v(x_t, t)
    """
    dt = (t_start - t_end) / n_steps
    x = x_start.clone()
    trajectory = [(x.clone(), t_start)]
    
    for i in range(n_steps):
        t = t_start - i * dt
        t_tensor = torch.full((x.shape[0], 1), t)
        v = model(x, t_tensor)
        x = x - dt * v
        trajectory.append((x.clone(), t - dt))
    
    return trajectory


@torch.no_grad()
def ode_endpoint(model, x_start, t_start, t_end, n_steps):
    """Solve the ODE and return only the final point."""
    traj = solve_ode(model, x_start, t_start, t_end, n_steps)
    return traj[-1][0]


print('Shared helpers defined.')
print('- ToyModel: MLP that takes (x_t, t) -> 2D vector')
print('- ConsistencyModel: MLP with skip connection enforcing f(x, eps) = x')
print('- sample_two_moons: 2D two-moons distribution')
print('- solve_ode / ode_endpoint: Euler ODE solver for flow matching models')

## Pretrained Teacher Model

We need a pretrained 2D diffusion model (flow matching) that defines ODE trajectories. We train it here—same as Exercise 3 from the flow matching notebook. This takes ~30 seconds.

This teacher model is used in all four exercises. Think of it as the "pretrained diffusion model" from the lesson—it has already learned the ODE trajectory structure.

In [None]:
# ============================================================
# Train the teacher: a flow matching model on two-moons
# ============================================================
# This is identical to what you built in the flow matching notebook.
# The teacher defines the ODE trajectories that consistency models
# will learn to bypass.

torch.manual_seed(42)

teacher = ToyModel(hidden_dim=128)
optimizer = torch.optim.Adam(teacher.parameters(), lr=3e-4)

print('Training teacher (flow matching) on two-moons...')
print('(~30 seconds on CPU)')

for epoch in range(500):
    x_0 = sample_two_moons(512)
    epsilon = torch.randn_like(x_0)
    t = torch.rand(512, 1)
    
    # Flow matching interpolation
    x_t = (1 - t) * x_0 + t * epsilon
    
    # Target velocity
    target_v = epsilon - x_0
    
    # Predict and optimize
    pred_v = teacher(x_t, t)
    loss = nn.functional.mse_loss(pred_v, target_v)
    
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    if (epoch + 1) % 100 == 0:
        print(f'  Epoch {epoch+1}/500, loss: {loss.item():.4f}')

teacher.eval()
print(f'\nTeacher trained. Final loss: {loss.item():.4f}')
print('This model defines the ODE trajectories used in all exercises.')

In [None]:
# Quick sanity check: generate samples from the teacher
torch.manual_seed(0)
noise = torch.randn(500, 2)
teacher_samples = ode_endpoint(teacher, noise, t_start=1.0, t_end=0.0, n_steps=50)

real_data = sample_two_moons(500)

fig, axes = plt.subplots(1, 2, figsize=(10, 4))
axes[0].scatter(real_data[:, 0], real_data[:, 1], s=3, alpha=0.5, c='#60a5fa')
axes[0].set_title('Real Data (two-moons)', fontsize=11)
axes[0].set_xlim(-4, 4); axes[0].set_ylim(-4, 4)
axes[0].set_aspect('equal'); axes[0].grid(alpha=0.15)

axes[1].scatter(teacher_samples[:, 0], teacher_samples[:, 1], s=3, alpha=0.5, c='#34d399')
axes[1].set_title('Teacher (50-step ODE)', fontsize=11)
axes[1].set_xlim(-4, 4); axes[1].set_ylim(-4, 4)
axes[1].set_aspect('equal'); axes[1].grid(alpha=0.15)

plt.suptitle('Teacher Model Sanity Check', fontsize=13, y=1.02)
plt.tight_layout()
plt.show()
print('The teacher generates clean two-moons samples with 50 ODE steps.')
print('This teacher defines the ODE trajectories for all exercises.')

---

## Exercise 1: Visualize Self-Consistency on ODE Trajectories `[Guided]`

From the lesson: the self-consistency property says that any point on the same ODE trajectory maps to the same clean endpoint. This is a trivial consequence of the ODE being deterministic—but it becomes a powerful training objective.

We will use the pretrained teacher to:
1. Generate several ODE trajectories by starting from different noise points and solving to completion
2. Pick one trajectory and highlight 5 points at different noise levels
3. Run the ODE from each of these 5 points and verify they all reach the same endpoint

**Before running, predict:**
- If you start from the midpoint of a trajectory (t=0.5) and run the ODE to completion, will you reach the same endpoint as starting from the beginning (t=1.0)?
- How close will the endpoints be? Exactly the same, or approximately the same? Why?

In [None]:
# ============================================================
# Exercise 1: Visualize ODE trajectories and self-consistency
# ============================================================

torch.manual_seed(17)

# Step 1: Generate 8 ODE trajectories from random noise to clean data
n_trajectories = 8
n_ode_steps = 200  # Dense steps for smooth trajectory visualization

noise_starts = torch.randn(n_trajectories, 2)
all_trajectories = []

for i in range(n_trajectories):
    x_i = noise_starts[i:i+1]  # (1, 2)
    traj = solve_ode(teacher, x_i, t_start=1.0, t_end=0.0, n_steps=n_ode_steps)
    # Extract just the points
    points = torch.stack([pt[0].squeeze(0) for pt in traj])  # (n_steps+1, 2)
    times = [pt[1] for pt in traj]  # list of t values
    all_trajectories.append((points, times))

# Step 2: Pick one trajectory and highlight 5 points
chosen_idx = 2  # Pick the 3rd trajectory
chosen_points, chosen_times = all_trajectories[chosen_idx]

# Select 5 points at t = 0.9, 0.7, 0.5, 0.3, 0.1
highlight_t_values = [0.9, 0.7, 0.5, 0.3, 0.1]
highlight_indices = []
for t_val in highlight_t_values:
    # Find the index closest to this t value
    idx = min(range(len(chosen_times)), key=lambda i: abs(chosen_times[i] - t_val))
    highlight_indices.append(idx)

highlight_points = chosen_points[highlight_indices]  # (5, 2)
highlight_actual_t = [chosen_times[i] for i in highlight_indices]

# The true endpoint of this trajectory (t=0)
true_endpoint = chosen_points[-1]  # (2,)

print(f'Chosen trajectory: starts at noise ({chosen_points[0][0]:.3f}, {chosen_points[0][1]:.3f})')
print(f'True endpoint: ({true_endpoint[0]:.3f}, {true_endpoint[1]:.3f})')
print()
print('Highlighted points on this trajectory:')
for i, (pt, t_val) in enumerate(zip(highlight_points, highlight_actual_t)):
    print(f'  Point {i+1}: t={t_val:.3f}, position=({pt[0]:.3f}, {pt[1]:.3f})')

In [None]:
# Step 3: Run the ODE from each highlighted point to t=0
# If the self-consistency property holds, all 5 should reach the same endpoint.

endpoints_from_midpoints = []

print('Running ODE from each highlighted point to t=0...')
print()

for i, (pt, t_val) in enumerate(zip(highlight_points, highlight_actual_t)):
    x_start = pt.unsqueeze(0)  # (1, 2)
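    # Match the step size of the original 200-step solve (dt = 1/200).
    # round() avoids losing a step to floating-point truncation (e.g. 0.9 * 200 -> 179.99...).
    n_steps_from_here = max(round(t_val * 200), 10)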
    endpoint = ode_endpoint(teacher, x_start, t_start=t_val, t_end=0.0, n_steps=n_steps_from_here)
    endpoint = endpoint.squeeze(0)  # (2,)
    endpoints_from_midpoints.append(endpoint)
    
    distance = torch.norm(endpoint - true_endpoint).item()
    print(f'  From t={t_val:.3f}: endpoint = ({endpoint[0]:.4f}, {endpoint[1]:.4f}), '
          f'distance from true endpoint = {distance:.6f}')

# Compute overall statistics
all_endpoints = torch.stack(endpoints_from_midpoints)
distances = torch.norm(all_endpoints - true_endpoint.unsqueeze(0), dim=1)
print(f'\nMax distance from true endpoint: {distances.max().item():.6f}')
print(f'Mean distance from true endpoint: {distances.mean().item():.6f}')
print()
print('All 5 points on the same trajectory reach (approximately) the same endpoint.')
print('The small differences are due to numerical ODE solver precision, not a failure')
print('of the self-consistency property. With infinite precision, they would be identical.')

In [None]:
# Step 4: Visualize all trajectories and the self-consistency verification

fig, axes = plt.subplots(1, 2, figsize=(14, 6))

# --- Left panel: All 8 ODE trajectories ---
ax = axes[0]
colors = plt.cm.cool(np.linspace(0.2, 0.9, n_trajectories))

for i, (pts, times) in enumerate(all_trajectories):
    alpha = 0.8 if i == chosen_idx else 0.25
    width = 2.5 if i == chosen_idx else 1.0
    ax.plot(pts[:, 0], pts[:, 1], color=colors[i], linewidth=width, alpha=alpha)
    # Mark start (noise) and end (data)
    ax.plot(pts[0, 0], pts[0, 1], 'o', color=colors[i], markersize=4, alpha=alpha)
    ax.plot(pts[-1, 0], pts[-1, 1], '*', color=colors[i], markersize=8 if i == chosen_idx else 5, alpha=alpha)

# Highlight the 5 points on the chosen trajectory
for i, (pt, t_val) in enumerate(zip(highlight_points, highlight_actual_t)):
    ax.plot(pt[0], pt[1], 's', color='#f59e0b', markersize=10, zorder=10,
            markeredgecolor='white', markeredgewidth=1.5)
    ax.annotate(f't={t_val:.1f}', (pt[0], pt[1]), textcoords='offset points',
                xytext=(8, 5), fontsize=8, color='#f59e0b')

ax.set_title('8 ODE Trajectories (noise to data)\nHighlighted: chosen trajectory with 5 points', fontsize=11)
ax.set_xlabel('$x_1$'); ax.set_ylabel('$x_2$')
ax.grid(alpha=0.15)
ax.set_aspect('equal')

# --- Right panel: Self-consistency verification ---
ax = axes[1]

# Plot the chosen trajectory
pts = all_trajectories[chosen_idx][0]
ax.plot(pts[:, 0], pts[:, 1], color='#60a5fa', linewidth=1.5, alpha=0.4, label='ODE trajectory')

# Plot the 5 starting points
for i, (pt, t_val) in enumerate(zip(highlight_points, highlight_actual_t)):
    ax.plot(pt[0], pt[1], 's', color='#f59e0b', markersize=10, zorder=10,
            markeredgecolor='white', markeredgewidth=1.5)
    ax.annotate(f't={t_val:.1f}', (pt[0], pt[1]), textcoords='offset points',
                xytext=(8, 5), fontsize=8, color='#f59e0b')

# Plot the endpoints reached from each starting point
for i, ep in enumerate(endpoints_from_midpoints):
    label = 'Endpoints from midpoints' if i == 0 else None
    ax.plot(ep[0], ep[1], 'o', color='#34d399', markersize=8, zorder=10, label=label)
    # Draw arrow from starting point to endpoint
    ax.annotate('', xy=(ep[0], ep[1]), xytext=(highlight_points[i][0], highlight_points[i][1]),
                arrowprops=dict(arrowstyle='->', color='#34d399', alpha=0.4, lw=1.5))

# Mark the true endpoint
ax.plot(true_endpoint[0], true_endpoint[1], '*', color='#ef4444', markersize=18, zorder=11,
        markeredgecolor='white', markeredgewidth=1.5, label='True endpoint')

ax.set_title('Self-Consistency Verification\n5 starting points → same endpoint', fontsize=11)
ax.set_xlabel('$x_1$'); ax.set_ylabel('$x_2$')
ax.legend(fontsize=9, loc='upper left')
ax.grid(alpha=0.15)
ax.set_aspect('equal')

plt.tight_layout()
plt.show()

print('Left: 8 ODE trajectories from random noise to the two-moons distribution.')
print('The highlighted trajectory has 5 marked points at different noise levels.')
print()
print('Right: Running the ODE from each of the 5 points reaches the same endpoint.')
print('The green circles (endpoints from midpoints) cluster tightly around the')
print('red star (true endpoint). This IS the self-consistency property:')
print('f(x_t, t) = f(x_{t\'}, t\') for any t, t\' on the same trajectory.')

### What Just Happened

You verified the self-consistency property with real numbers on real ODE trajectories:

- **The self-consistency property is real.** Starting the ODE from t=0.9, t=0.7, t=0.5, t=0.3, or t=0.1—all on the same trajectory—produces (nearly) the same endpoint. The small differences are numerical precision from the ODE solver, not a violation of the property.

- **This is just what "deterministic" means.** The ODE is deterministic: same starting point, same path, same endpoint. If two points share a trajectory, they share an endpoint. This is the trivially true fact that the lesson calls "not a new mathematical result."

- **The insight is what we can do with it.** If a neural network could learn the function f(x_t, t) = endpoint for ANY x_t and t, we would not need to run the ODE at all. We would just evaluate f once and get the clean data point. That is the consistency model idea.
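
To see the same property in a setting where the endpoint is known in closed form, here is a minimal sketch using a hypothetical 1D ODE, dx/dt = a·x, rather than the teacher's learned ODE. Its trajectories are x(t) = x(0)·e^(at), so the consistency function can be written down exactly as f(x, t) = x·e^(-at). The names `a`, `trajectory_point`, and `consistency_fn` are illustrative assumptions, not part of the notebook.

```python
import math

# Hypothetical 1D ODE dx/dt = a * x, chosen because its solution is known exactly.
a = 1.5

def trajectory_point(x0, t):
    """Exact position at time t on the trajectory that ends at x0 when t = 0."""
    return x0 * math.exp(a * t)

def consistency_fn(x_t, t):
    """Map any point (x_t, t) straight to its trajectory's endpoint at t = 0."""
    return x_t * math.exp(-a * t)

x0 = 0.8  # the clean endpoint
times = [1.0, 0.9, 0.7, 0.5, 0.3, 0.1]
endpoints = [consistency_fn(trajectory_point(x0, t), t) for t in times]

for t, e in zip(times, endpoints):
    print(f"t={t:.1f}: x_t={trajectory_point(x0, t):.4f} -> f(x_t, t)={e:.6f}")
```

Every row reports the same endpoint even though x_t differs from row to row. A consistency model tries to learn such an f for the teacher's ODE, where no closed form is available.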

---

## Exercise 2: One-Step ODE vs Consistency Model Prediction `[Guided]`

From the lesson: a consistency model is NOT the same as an ODE solver with 1 step. An ODE solver with 1 step computes a direction at x_T and takes a single (massive, inaccurate) step. The consistency model maps x_T directly to x_0—no direction, no step, a direct mapping.

We will compare three ways of estimating x_0 from x_T (one multi-step reference and two single-step estimates):
1. **True endpoint**—run the ODE with many steps (ground truth)
2. **One Euler step**—compute velocity at x_T, take one step to t=0
3. **DDIM-style 1-step prediction**—use the velocity to estimate x_0 directly

For the flow matching ODE, the DDIM-style 1-step prediction computes:

$$\hat{x}_0 = x_T - T \cdot v_\theta(x_T, T) = x_T - 1.0 \cdot v_\theta(x_T, 1.0)$$

(Since $x_t = (1-t) \cdot x_0 + t \cdot \epsilon$ implies $x_0 = x_t - t \cdot v$ where $v = \epsilon - x_0$.)

Note: with this flow matching parameterization, a single Euler step from t=1 to t=0 and the DDIM-style x_0 prediction are the same computation, no matter how straight the trajectories are. Both differ from the true endpoint because the learned velocity field is not perfectly straight: individual conditional paths are straight, but the aggregate field has curvature that a single step cannot follow.
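
The identity x_0 = x_t - t·v can be checked numerically. The sketch below uses one made-up 2D data point and one noise draw (illustrative values, not the notebook's data) together with the exact conditional velocity v = ε - x_0 for that pair. The helper names `interpolate` and `predict_x0` are assumptions introduced for this example.

```python
import random

random.seed(0)

# One 2D data point and one noise draw (illustrative values only).
x_0 = [0.5, -1.2]
eps = [random.gauss(0.0, 1.0) for _ in range(2)]

# Exact conditional velocity for this (x_0, eps) pair: v = eps - x_0.
v = [e - x for e, x in zip(eps, x_0)]

def interpolate(t):
    """Flow matching interpolation: x_t = (1 - t) * x_0 + t * eps."""
    return [(1 - t) * x + t * e for x, e in zip(x_0, eps)]

def predict_x0(x_t, t):
    """DDIM-style estimate of x_0 from a point and a velocity: x_t - t * v."""
    return [xt - t * vi for xt, vi in zip(x_t, v)]

for t in (1.0, 0.6, 0.2):
    x_hat = predict_x0(interpolate(t), t)
    print(f"t={t}: recovered x_0 = {[round(c, 6) for c in x_hat]}")

# A single Euler step from t=1 to t=0 has step size dt = 1: the same formula.
x_T = interpolate(1.0)
euler_1step = [xt - 1.0 * vi for xt, vi in zip(x_T, v)]
```

With the exact velocity of a single pair, the recovery works at every t. The teacher only has a learned, aggregate velocity, which is where the one-step error in this exercise comes from.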

**Before running, predict:**
- How far off will the 1-step Euler estimate be from the true endpoint?
- Will the 1-step predictions land anywhere near the two-moons distribution, or will they be scattered randomly?

In [None]:
# ============================================================
# Exercise 2: Compare true ODE endpoint vs 1-step approximations
# ============================================================

torch.manual_seed(42)
n_samples = 300

# Start from pure noise
x_T = torch.randn(n_samples, 2)

# --- Ground truth: 200-step ODE ---
true_endpoints = ode_endpoint(teacher, x_T, t_start=1.0, t_end=0.0, n_steps=200)

# --- 1 Euler step ---
# From t=1.0 to t=0.0 in one step:
# x_0_euler = x_T - 1.0 * v_theta(x_T, 1.0)
with torch.no_grad():
    t_one = torch.ones(n_samples, 1)
    v_at_T = teacher(x_T, t_one)
    euler_1step = x_T - 1.0 * v_at_T

# --- DDIM-style 1-step prediction ---
# For flow matching: x_0_hat = x_T - T * v_theta(x_T, T)
# This is actually the same formula as Euler 1-step for flow matching
# (confirming that for flow matching, Euler from t=1 to t=0 IS the x_0 prediction)
ddim_1step = x_T - 1.0 * v_at_T

# --- Also compare: 5-step ODE (a reasonable middle ground) ---
ode_5step = ode_endpoint(teacher, x_T, t_start=1.0, t_end=0.0, n_steps=5)

# Compute errors
euler_errors = torch.norm(euler_1step - true_endpoints, dim=1)
ode5_errors = torch.norm(ode_5step - true_endpoints, dim=1)

print('Endpoint errors (distance from true 200-step ODE endpoint):')
print(f'  1 Euler step:  mean={euler_errors.mean():.4f}, max={euler_errors.max():.4f}')
print(f'  5 ODE steps:   mean={ode5_errors.mean():.4f}, max={ode5_errors.max():.4f}')
print()
print('Note: for flow matching, 1 Euler step and DDIM 1-step x_0 prediction')
print('are the same computation: x_0 = x_T - v_theta(x_T, 1.0).')
print('The error comes from curvature in the LEARNED aggregate velocity field.')

In [None]:
# Visualize the comparison

fig, axes = plt.subplots(1, 4, figsize=(18, 4.5))

# Real data
real = sample_two_moons(500)
axes[0].scatter(real[:, 0], real[:, 1], s=3, alpha=0.3, c='#60a5fa')
axes[0].set_title('Real Data\n(two-moons)', fontsize=11)

# True endpoints (200-step ODE)
axes[1].scatter(true_endpoints[:, 0], true_endpoints[:, 1], s=3, alpha=0.5, c='#34d399')
axes[1].set_title('True Endpoint\n(200-step ODE)', fontsize=11)

# 1-step prediction
axes[2].scatter(euler_1step[:, 0], euler_1step[:, 1], s=3, alpha=0.5, c='#f59e0b')
axes[2].set_title(f'1-Step Prediction\n(mean error: {euler_errors.mean():.3f})', fontsize=11)

# 5-step ODE
axes[3].scatter(ode_5step[:, 0], ode_5step[:, 1], s=3, alpha=0.5, c='#a78bfa')
axes[3].set_title(f'5-Step ODE\n(mean error: {ode5_errors.mean():.3f})', fontsize=11)

for ax in axes:
    ax.set_xlim(-5, 5); ax.set_ylim(-5, 5)
    ax.set_aspect('equal'); ax.grid(alpha=0.15)

plt.suptitle(
    'Why We Need Consistency Models: Single-Step ODE Methods Are Inaccurate',
    fontsize=13, y=1.02
)
plt.tight_layout()
plt.show()

print('Observations:')
print('- The true endpoint (200-step ODE) produces clean two-moons samples.')
print('- The 1-step prediction is recognizable but noticeably noisier/messier.')
print('  Points scatter away from the clean crescents because the aggregate')
print('  velocity field has curvature that one step cannot account for.')
print('- The 5-step ODE is much closer to the true endpoint.')
print()
print('THIS is the problem consistency models solve. A trained consistency model')
print('would map x_T directly to x_0 in one step WITHOUT the curvature error.')
print('It does not compute a direction and step—it predicts the endpoint directly.')
print('The self-consistency training objective teaches it to be accurate at this.')

### What Just Happened

You saw concretely why consistency models are needed—and why they are NOT just an ODE solver with 1 step:

- **One-step ODE prediction is inaccurate.** Even with flow matching's nearly-straight trajectories, the aggregate learned velocity field has enough curvature that a single step from pure noise produces noticeably worse samples. The points are roughly in the right area but scatter away from the clean crescent shapes.

- **More steps fix it, but cost more.** Using 5 steps dramatically improves the result, and the 200-step solve is accurate enough to serve as ground truth. Every additional step costs one model evaluation. This is the core tension: accuracy costs compute.

- **A consistency model takes a different approach.** Instead of computing a direction and stepping (which requires multiple steps for accuracy), it learns a direct mapping: f(x_T, T) = x_0. No direction, no steps, no curvature error. But this requires special training—the self-consistency objective—which is what Exercise 3 implements.
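
The accuracy-versus-compute tradeoff can be reproduced on a curved trajectory whose exact endpoint is known. This is a minimal sketch using a hypothetical rotation ODE (dx/dt = -y, dy/dt = x), not the teacher's velocity field; `euler_rotate` is a name made up for the example.

```python
import math

def euler_rotate(n_steps, total_time=1.0):
    """Integrate dx/dt = -y, dy/dt = x from (1, 0) with n_steps Euler steps."""
    x, y = 1.0, 0.0
    dt = total_time / n_steps
    for _ in range(n_steps):
        # Tuple assignment evaluates the right side with the old (x, y).
        x, y = x - dt * y, y + dt * x
    return x, y

# The true trajectory is a circular arc, so the exact endpoint is known.
exact = (math.cos(1.0), math.sin(1.0))

errors = {}
for n in (1, 5, 20, 200):
    x, y = euler_rotate(n)
    errors[n] = math.hypot(x - exact[0], y - exact[1])
    print(f"{n:>3} steps ({n} velocity evaluations): error = {errors[n]:.4f}")
```

The error shrinks as the step count grows, and each step costs one evaluation of the velocity. A single straight step cannot follow the arc, which is the same failure mode as the one-step teacher prediction.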

---

## Exercise 3: Train a Toy Consistency Model on 2D Data `[Supported]`

From the lesson: consistency distillation uses a pretrained teacher to provide ODE trajectory estimates. The training procedure:

1. Sample a data point $x_0$, add noise to get $x_{t_{n+1}}$ at a higher noise level
2. Use the teacher to take one ODE step from $x_{t_{n+1}}$ to estimate $\hat{x}_{t_n}$ at the next lower noise level
3. Train the consistency model so that $f_\theta(x_{t_{n+1}}, t_{n+1})$ matches $f_{\theta^-}(\hat{x}_{t_n}, t_n)$
4. $\theta^-$ is an EMA of $\theta$ (prevents collapse)

The loss: $\mathcal{L} = \| f_\theta(x_{t_{n+1}}, t_{n+1}) - f_{\theta^-}(\hat{x}_{t_n}, t_n) \|^2$
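
Both sides of this loss are outputs of the `ConsistencyModel` from the shared helpers, whose skip connection enforces the boundary condition f(x, σ_min) = x. The sketch below evaluates the same coefficients, c_skip(t) = σ_min / t and c_out(t) = (t - σ_min) / t, on plain scalars. The function names and the stand-in backbone output `123.0` are assumptions for illustration.

```python
sigma_min = 0.002  # same value as the ConsistencyModel default in the shared helpers

def skip_coefficients(t, sigma_min=sigma_min):
    """Return (c_skip, c_out) = (sigma_min / t, (t - sigma_min) / t)."""
    return sigma_min / t, (t - sigma_min) / t

def consistency_output(x_t, backbone_out, t):
    """Scalar version of f(x_t, t) = c_skip(t) * x_t + c_out(t) * F(x_t, t)."""
    c_skip, c_out = skip_coefficients(t)
    return c_skip * x_t + c_out * backbone_out

for t in (0.002, 0.01, 0.1, 1.0):
    c_skip, c_out = skip_coefficients(t)
    print(f"t={t:<5}: c_skip={c_skip:.3f}, c_out={c_out:.3f}")

# At t = sigma_min the backbone output is ignored, whatever value it takes.
boundary = consistency_output(x_t=0.7, backbone_out=123.0, t=sigma_min)
print(f"f(0.7, sigma_min) = {boundary}")
```

Because the boundary condition holds by construction, the training loss only has to teach the backbone what to output at the higher noise levels.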

Your task: fill in the TODO markers to implement the consistency distillation training loop.

**Before running, predict:**
- After training, will the consistency model's 1-step samples be better or worse than the 1-step ODE prediction from Exercise 2?
- The consistency model and the teacher have the same architecture size. How can the consistency model produce better 1-step results than the teacher with 1 ODE step?

In [None]:
# ============================================================
# Exercise 3: Consistency distillation training
# ============================================================
# NOTE: Fill in ALL four TODOs before running this cell.

torch.manual_seed(42)

# Hyperparameters
sigma_min = 0.002    # Minimum noise level (boundary condition point)
sigma_max = 1.0      # Maximum noise level
n_timesteps = 20     # Number of discretization steps for the noise schedule
ema_decay = 0.999    # EMA decay rate for target network
n_epochs = 1000
batch_size = 512
lr = 3e-4

# Create the noise level schedule: evenly spaced from sigma_min to sigma_max
# These are the discrete noise levels t_1 < t_2 < ... < t_N
sigmas = torch.linspace(sigma_min, sigma_max, n_timesteps)

# Create the consistency model (online) and EMA target
cm_model = ConsistencyModel(hidden_dim=128, sigma_min=sigma_min)
cm_target = ConsistencyModel(hidden_dim=128, sigma_min=sigma_min)

# Initialize target as a copy of the model
cm_target.load_state_dict(cm_model.state_dict())

# The target does NOT receive gradients
for param in cm_target.parameters():
    param.requires_grad_(False)

optimizer = torch.optim.Adam(cm_model.parameters(), lr=lr)

losses = []

print('Training consistency model via distillation...')
print(f'Teacher: pretrained flow matching model')
print(f'Noise schedule: {n_timesteps} levels from {sigma_min} to {sigma_max}')
print(f'EMA decay: {ema_decay}')
print('(~1-2 minutes on CPU)')
print()

for epoch in range(n_epochs):
    # Step 1: Sample clean data
    x_0 = sample_two_moons(batch_size)
    
    # Step 2: Sample random adjacent timestep pairs (t_n, t_{n+1})
    # Pick random indices into the sigma schedule
    # n ranges from 0 to n_timesteps-2 (so n+1 is valid)
    n_idx = torch.randint(0, n_timesteps - 1, (batch_size,))
    t_n = sigmas[n_idx]          # Lower noise level
    t_n1 = sigmas[n_idx + 1]     # Higher noise level (one step noisier)
    
    # Step 3: Create the noisy sample at the higher noise level t_{n+1}
    # Using flow matching interpolation: x_t = (1-t) * x_0 + t * epsilon
    epsilon = torch.randn_like(x_0)
    
    # TODO: Compute x at the higher noise level t_{n+1}
    # x_tn1 = (1 - t_{n+1}) * x_0 + t_{n+1} * epsilon
    # Remember: t_n1 has shape (batch_size,), so unsqueeze to (batch_size, 1)
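    x_tn1 = None  # <-- Replace this line
    if x_tn1 is None:  # fail early with a clear message, before the teacher call below
        raise NotImplementedError('Fill in the x_tn1 TODO before running this cell.')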
    
    # Step 4: Teacher takes one ODE step from t_{n+1} to t_n
    # This estimates where x_{t_{n+1}} would be at noise level t_n
    # Euler step: x_hat_tn = x_tn1 - (t_{n+1} - t_n) * v_teacher(x_tn1, t_{n+1})
    with torch.no_grad():
        v_teacher = teacher(x_tn1, t_n1.unsqueeze(-1))
        dt = (t_n1 - t_n).unsqueeze(-1)  # (batch, 1)
        
        # TODO: Compute the teacher's one-step ODE estimate at noise level t_n
        # x_hat_tn = x_tn1 - dt * v_teacher
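        x_hat_tn = None  # <-- Replace this line
        if x_hat_tn is None:  # fail early, before the target network call below
            raise NotImplementedError('Fill in the x_hat_tn TODO before running this cell.')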
    
    # Step 5: Consistency model predictions
    # Online model: f_theta(x_{t_{n+1}}, t_{n+1})
    # Target model: f_theta_minus(x_hat_{t_n}, t_n) [no gradients!]
    
    pred_online = cm_model(x_tn1, t_n1)
    
    with torch.no_grad():
        pred_target = cm_target(x_hat_tn, t_n)
    
    # TODO: Compute the consistency distillation loss
    # L = MSE between the online prediction and the target prediction
    # The online model's prediction at the higher noise level should match
    # the target model's prediction at the teacher-estimated lower noise level.
    loss = None  # <-- Replace this line
    
    # Guard: make sure TODOs are filled in
    if x_tn1 is None or x_hat_tn is None or loss is None:
        raise NotImplementedError(
            'Fill in the TODOs (x_tn1, x_hat_tn, loss) before running this cell.'
        )
    
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    # TODO: Update the EMA target network
    # For each parameter pair (online_param, target_param):
    #   target_param = ema_decay * target_param + (1 - ema_decay) * online_param
    # Use torch.no_grad() and .data to avoid tracking gradients.
    with torch.no_grad():
        for p_online, p_target in zip(cm_model.parameters(), cm_target.parameters()):
            pass  # <-- Replace this line with the EMA update
    
    losses.append(loss.item())
    if (epoch + 1) % 200 == 0:
        print(f'  Epoch {epoch+1}/{n_epochs}, loss: {loss.item():.6f}')

print(f'\nTraining complete. Final loss: {losses[-1]:.6f}')

In [None]:
# Plot training loss
fig, ax = plt.subplots(1, 1, figsize=(10, 3))
ax.plot(losses, color='#a78bfa', linewidth=1, alpha=0.7)
ax.set_xlabel('Epoch')
ax.set_ylabel('Consistency Loss')
ax.set_title('Consistency Distillation Training Loss')
ax.set_yscale('log')
plt.tight_layout()
plt.show()

In [None]:
# ============================================================
# Generate 1-step samples from the consistency model
# ============================================================
# Single-step generation: f(x_T, T) -> x_0
# Sample noise, feed through the consistency model at t=sigma_max, done.

cm_model.eval()

torch.manual_seed(42)
n_gen = 500

# Start from pure noise
x_noise = torch.randn(n_gen, 2)
t_max = torch.full((n_gen,), sigma_max)

# One-step consistency model generation
with torch.no_grad():
    cm_samples = cm_model(x_noise, t_max)

# Compare to: teacher 1-step (from Exercise 2) and teacher 50-step
with torch.no_grad():
    v_teacher_at_T = teacher(x_noise, torch.ones(n_gen, 1))
    teacher_1step = x_noise - 1.0 * v_teacher_at_T

teacher_50step = ode_endpoint(teacher, x_noise, t_start=1.0, t_end=0.0, n_steps=50)

# Plot comparison
fig, axes = plt.subplots(1, 4, figsize=(18, 4.5))

real = sample_two_moons(500)
axes[0].scatter(real[:, 0], real[:, 1], s=3, alpha=0.3, c='#60a5fa')
axes[0].set_title('Real Data', fontsize=11)

axes[1].scatter(teacher_50step[:, 0], teacher_50step[:, 1], s=3, alpha=0.5, c='#34d399')
axes[1].set_title('Teacher (50-step ODE)', fontsize=11)

axes[2].scatter(teacher_1step[:, 0], teacher_1step[:, 1], s=3, alpha=0.5, c='#f59e0b')
axes[2].set_title('Teacher (1-step ODE)', fontsize=11)

axes[3].scatter(cm_samples[:, 0], cm_samples[:, 1], s=3, alpha=0.5, c='#a78bfa')
axes[3].set_title('Consistency Model (1 step)', fontsize=11)

for ax in axes:
    ax.set_xlim(-5, 5); ax.set_ylim(-5, 5)
    ax.set_aspect('equal'); ax.grid(alpha=0.15)

plt.suptitle(
    'Consistency Model vs Teacher: One-Step Generation',
    fontsize=13, y=1.02
)
plt.tight_layout()
plt.show()

print('Observations:')
print('- The consistency model\'s 1-step samples should be closer to the real')
print('  distribution than the teacher\'s 1-step ODE prediction.')
print('- The consistency model was TRAINED to map noise to clean data in one step.')
print('  The teacher was trained to predict velocity, not to be accurate in 1 step.')
print('- This is the payoff of the self-consistency training objective: the model')
print('  learns that f(x_T, T) should be the trajectory endpoint, not just the')
print('  direction at x_T.')

<details>
<summary>Solution</summary>

The key insight is that consistency distillation enforces a LOCAL constraint (adjacent timesteps agree) that implies GLOBAL consistency (any point maps to the endpoint). The teacher provides the trajectory information, and the EMA target prevents the model from collapsing to a constant function.

**Training TODOs:**
```python
# Noisy sample at higher noise level
x_tn1 = (1 - t_n1.unsqueeze(-1)) * x_0 + t_n1.unsqueeze(-1) * epsilon

# Teacher's one-step ODE estimate
x_hat_tn = x_tn1 - dt * v_teacher

# Consistency distillation loss
loss = nn.functional.mse_loss(pred_online, pred_target)

# EMA update
with torch.no_grad():
    for p_online, p_target in zip(cm_model.parameters(), cm_target.parameters()):
        p_target.data.mul_(ema_decay).add_(p_online.data, alpha=1 - ema_decay)
```

**Why the EMA target matters:** Without it, both sides of the loss are computed by the same model. The model could trivially minimize the loss by outputting a constant—same output for any input means the two predictions always match. The EMA target breaks this symmetry: the target network changes slowly, so the online network must actually learn the correct mapping to match the target's predictions.
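
How slowly the target moves can be seen in isolation. The following is a minimal sketch, not the notebook's training loop: `online` and `target` are hypothetical stand-ins for `cm_model` and `cm_target`, the decay of 0.95 is an assumed value chosen for illustration, and the "optimizer step" is faked by shifting every online weight by 1.0.

```python
import copy
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins for the notebook's cm_model / cm_target.
online = nn.Linear(2, 2)
target = copy.deepcopy(online)
for p in target.parameters():
    p.requires_grad_(False)  # the target is never trained by gradients

ema_decay = 0.95  # assumed value, for illustration only


@torch.no_grad()
def ema_update(online_net, target_net, decay):
    """target <- decay * target + (1 - decay) * online, parameter by parameter."""
    for p_online, p_target in zip(online_net.parameters(), target_net.parameters()):
        p_target.mul_(decay).add_(p_online, alpha=1 - decay)


# Fake one optimizer step: shift every online weight by exactly 1.0.
with torch.no_grad():
    for p in online.parameters():
        p.add_(1.0)

ema_update(online, target, ema_decay)

# The target moved only 5% of the way toward the online weights.
gap = (online.weight - target.weight).abs().mean().item()
print(f"mean |online - target| after one EMA update: {gap:.3f}")
```

With a decay of 0.95 the target covers 5% of the distance per update, so the remaining gap after one update is 0.95 of the shift.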

**Why the teacher takes only one ODE step:** Running the full ODE would require 10-50 model evaluations per training step—prohibitively expensive. One step between adjacent noise levels provides a reasonable estimate (the trajectory changes little between nearby noise levels). The consistency model learns to chain these local constraints into global consistency.
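
Written as equations, the training TODOs above can be read as the objective below. This is a sketch under assumptions: it takes `pred_online` to be the online model evaluated at the noisier point and `pred_target` to be the EMA copy evaluated at the teacher's estimate, and it takes `dt` to be the gap between the two adjacent noise levels. The symbols $\theta$ (online weights), $\theta^-$ (EMA weights), and $\mu$ (for `ema_decay`) are notation introduced here.

$$
\hat{x}_{t_n} = x_{t_{n+1}} - (t_{n+1} - t_n)\, v_{\text{teacher}}(x_{t_{n+1}}, t_{n+1})
$$

$$
\mathcal{L}(\theta) = \mathbb{E}\left[ \left\| f_\theta(x_{t_{n+1}}, t_{n+1}) - f_{\theta^-}(\hat{x}_{t_n}, t_n) \right\|_2^2 \right], \qquad \theta^- \leftarrow \mu\,\theta^- + (1 - \mu)\,\theta
$$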

**Common mistakes:**
- Forgetting to unsqueeze `t_n1` from shape `(batch,)` to `(batch, 1)` for broadcasting with `x_0` of shape `(batch, 2)`
- Using `cm_model` instead of `cm_target` for the target prediction (the EMA copy must be used for stability)
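- Forgetting `torch.no_grad()` around the EMA update (the target weights should be updated outside autograd, never through gradients)
- Using `.copy_()` instead of the weighted combination for the EMA update (the target then equals the online model after every step, so it no longer moves slowly)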
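
The first mistake in the list is easy to reproduce in isolation. A minimal sketch, assuming the shapes named above (`t_n1` of shape `(batch,)`, `x_0` of shape `(batch, 2)`); the batch size of 4 and the `raised` flag are arbitrary choices made for this example:

```python
import torch

torch.manual_seed(0)

batch = 4
x_0 = torch.randn(batch, 2)      # clean data, shape (batch, 2)
epsilon = torch.randn(batch, 2)  # noise, shape (batch, 2)
t_n1 = torch.rand(batch)         # one noise level per sample, shape (batch,)

# Correct: add a trailing axis so each row is scaled by its own t.
t = t_n1.unsqueeze(-1)           # shape (batch, 1)
x_tn1 = (1 - t) * x_0 + t * epsilon
print("with unsqueeze:", tuple(x_tn1.shape))

# Mistake: shapes (4,) and (4, 2) cannot be broadcast together.
raised = False
try:
    _ = (1 - t_n1) * x_0 + t_n1 * epsilon
except RuntimeError:
    raised = True
print("without unsqueeze raised an error:", raised)
```

The error is the lucky case. With a batch size of 2, shapes `(2,)` and `(2, 2)` broadcast without complaint, and each column (rather than each row) gets scaled by its own noise level.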

</details>

### What Just Happened

You trained a consistency model via distillation and generated one-step samples:

- **The consistency model produces better 1-step samples than the teacher's 1-step ODE.** The teacher was trained to predict velocity (direction at each point). The consistency model was trained to predict the endpoint directly. Same architecture, different training objective, different result.

- **The training loop enforces local consistency.** At each training step, two adjacent points on the teacher's trajectory should produce the same output. Over training, these local constraints compose into global consistency: any point maps to the endpoint.

- **The EMA target prevents collapse.** Without the slowly-moving target network, the model could cheat by outputting a constant. The EMA target provides a stable reference that the online model must actually match.

- **One-step quality is decent but not perfect.** Mapping from pure noise to clean data in one step is a hard problem. The consistency model does it better than the teacher's 1-step ODE, but probably not as well as the teacher's 50-step ODE. This motivates multi-step consistency in Exercise 4.

---

## Exercise 4: Multi-Step Consistency and Quality Comparison `[Independent]`

From the lesson: multi-step consistency is a middle ground between one-step consistency generation and multi-step diffusion sampling. The procedure:

1. Start at $x_T$ (pure noise)
2. Apply $f(x_T, T)$ to get a clean estimate $\hat{x}_0$
3. Add noise to $\hat{x}_0$ to get $x_{t_2}$ at a lower noise level $t_2$
4. Apply $f(x_{t_2}, t_2)$ to get a better clean estimate
5. Repeat

This is NOT ODE solving—each step restarts by jumping to $x_0$ and re-noising. There is no trajectory being followed between steps.
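
In symbols, one round of the procedure above might be written as follows. The step index $k$ and the notation $\hat{x}_0^{(k)}$ are introduced here for illustration, and the re-noising rule assumes the same linear interpolation between data and noise that the notebook uses when building training samples.

$$
\hat{x}_0^{(k)} = f(x_{t_k}, t_k), \qquad x_{t_{k+1}} = (1 - t_{k+1})\,\hat{x}_0^{(k)} + t_{k+1}\,\epsilon_k, \quad \epsilon_k \sim \mathcal{N}(0, I)
$$

with $t_1 = T > t_2 > \dots > t_K$, and the output is $\hat{x}_0^{(K)}$.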

**Your task:**

1. **Implement multi-step consistency sampling** using the trained consistency model from Exercise 3
2. **Generate samples at 1, 2, 4, and 8 consistency steps**
3. **Generate teacher ODE samples at 1, 5, 10, 20, and 50 steps** for comparison
4. **Plot both progressions side by side** to visualize the quality-speed tradeoff

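For multi-step consistency with $K$ steps, use $K$ evenly spaced noise levels that start at $\sigma_{\text{max}}$ and decrease toward $\sigma_{\text{min}}$. In this notebook the noise level is the interpolation time $t$, so $\sigma_{\text{max}} = 1$ corresponds to pure noise.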

**Before running, predict:**
- At how many consistency steps will the samples look comparable to the teacher's 50-step ODE?
- Will the 1-step consistency sample be better or worse than the teacher's 5-step ODE?
- Will 8 consistency steps look much better than 4?

In [None]:
# Your code here.
#
# Suggested structure:
#
# 1. Implement multi-step consistency sampling:
#    def sample_consistency_multistep(model, n_samples, n_steps, sigma_min, sigma_max):
#        Create evenly-spaced noise levels from sigma_max down to sigma_min
#        Start from pure Gaussian noise
#        For each step:
#            a. Apply the consistency model to get a clean estimate: x_0_hat = f(x_t, t)
#            b. Re-noise the clean estimate to the next lower noise level:
#               x_next = (1 - t_next) * x_0_hat + t_next * new_noise
#            c. (Skip re-noising on the final step -- just return x_0_hat)
#        Return the final clean estimate
#
# 2. Generate samples:
#    - Consistency model at 1, 2, 4, 8 steps (same starting noise for all)
#    - Teacher ODE at 1, 5, 10, 20, 50 steps (same starting noise)
#
# 3. Plot a comparison grid:
#    Row 1: Teacher ODE at 1, 5, 10, 20, 50 steps
#    Row 2: Consistency model at 1, 2, 4, 8 steps + real data
#    Use the same axis limits and styling for fair visual comparison.
#
# Remember:
# - cm_model is the trained consistency model from Exercise 3
# - teacher is the pretrained flow matching model
# - Use torch.no_grad() for all sampling
# - For 1-step consistency, just apply f(x_T, T) directly (same as Exercise 3)
# - For multi-step, each re-noising uses FRESH random noise


<details>
<summary>Solution</summary>

The core insight is that multi-step consistency is NOT ODE solving. Each step independently teleports to $x_0$ and then re-noises to a lower noise level. There is no trajectory being followed. Each consistency function evaluation is an independent jump to the endpoint, starting from progressively less noisy inputs.

```python
@torch.no_grad()
def sample_consistency_multistep(model, n_samples, n_steps, sigma_min=0.002, sigma_max=1.0):
    """Multi-step consistency sampling.
    
    Each step: apply f to get x_0 estimate, re-noise to next lower level.
    Final step: return the x_0 estimate without re-noising.
    """
    # K evenly spaced noise levels starting at sigma_max.
    # linspace produces K + 1 points ending at sigma_min; the last point
    # (sigma_min itself) is dropped, which leaves K levels.
    # For 1 step:  [sigma_max]
    # For 2 steps: [sigma_max, (sigma_max + sigma_min) / 2]
    noise_levels = torch.linspace(sigma_max, sigma_min, n_steps + 1)[:-1]  # K levels
    
    # Start from pure Gaussian noise
    x = torch.randn(n_samples, 2)
    
    for i, t_val in enumerate(noise_levels):
        t = torch.full((n_samples,), t_val.item())
        
        # Apply consistency model: get clean estimate
        x_0_hat = model(x, t)
        
        # If this is the last step, return the clean estimate
        if i == len(noise_levels) - 1:
            return x_0_hat
        
        # Otherwise, re-noise to the next lower noise level
        t_next = noise_levels[i + 1].item()
        fresh_noise = torch.randn_like(x_0_hat)
        x = (1 - t_next) * x_0_hat + t_next * fresh_noise
    
    return x_0_hat


# Generate samples
cm_model.eval()
n_gen = 500

consistency_steps = [1, 2, 4, 8]
teacher_steps = [1, 5, 10, 20, 50]

cm_samples_dict = {}
teacher_samples_dict = {}

for k in consistency_steps:
    torch.manual_seed(42)
    cm_samples_dict[k] = sample_consistency_multistep(
        cm_model, n_gen, k, sigma_min=0.002, sigma_max=1.0
    )

for k in teacher_steps:
    torch.manual_seed(42)
    noise = torch.randn(n_gen, 2)
    teacher_samples_dict[k] = ode_endpoint(teacher, noise, t_start=1.0, t_end=0.0, n_steps=k)

real = sample_two_moons(n_gen)

# Plot comparison
fig, axes = plt.subplots(2, 5, figsize=(22, 9))

# Row 1: Teacher ODE at varying step counts
for col, k in enumerate(teacher_steps):
    ax = axes[0, col]
    s = teacher_samples_dict[k]
    ax.scatter(s[:, 0], s[:, 1], s=3, alpha=0.5, c='#34d399')
    ax.set_title(f'Teacher ODE\n{k} step{"s" if k > 1 else ""}', fontsize=10)
    ax.set_xlim(-5, 5); ax.set_ylim(-5, 5)
    ax.set_aspect('equal'); ax.grid(alpha=0.15)

axes[0, 0].set_ylabel('Teacher (ODE)', fontsize=12, fontweight='bold')

# Row 2: Consistency model at varying step counts + real data
for col, k in enumerate(consistency_steps):
    ax = axes[1, col]
    s = cm_samples_dict[k]
    ax.scatter(s[:, 0], s[:, 1], s=3, alpha=0.5, c='#a78bfa')
    ax.set_title(f'Consistency Model\n{k} step{"s" if k > 1 else ""}', fontsize=10)
    ax.set_xlim(-5, 5); ax.set_ylim(-5, 5)
    ax.set_aspect('equal'); ax.grid(alpha=0.15)

# Real data in the last cell
axes[1, 4].scatter(real[:, 0], real[:, 1], s=3, alpha=0.5, c='#60a5fa')
axes[1, 4].set_title('Real Data\n(reference)', fontsize=10)
axes[1, 4].set_xlim(-5, 5); axes[1, 4].set_ylim(-5, 5)
axes[1, 4].set_aspect('equal'); axes[1, 4].grid(alpha=0.15)

axes[1, 0].set_ylabel('Consistency (CM)', fontsize=12, fontweight='bold')

plt.suptitle(
    'Multi-Step Consistency vs Teacher ODE: Quality-Speed Tradeoff\n'
    'Top: Teacher ODE solving (more steps = more model evaluations)\n'
    'Bottom: Consistency model (each step = one teleport + re-noise)',
    fontsize=13, y=1.02
)
plt.tight_layout()
plt.show()

print('Expected observations:')
print('- Consistency model at 2-4 steps should approach the teacher at 20-50 steps.')
print('- Each consistency step is an independent teleportation, not a trajectory step.')
print('  Re-noising gives the model a second (third, fourth...) chance to refine.')
print('- The teacher at 1 step is poor (curvature error). The consistency model at 1 step')
print('  is trained to be accurate at 1 step, so it should be better.')
print('- Diminishing returns beyond 4 consistency steps for this simple 2D distribution.')
```

**Key design decisions:**
- Evenly-spaced noise levels work well for this toy example. In practice, the spacing matters (more levels near high noise where the prediction is hardest).
- Fresh noise at each re-noising step introduces diversity. Using the same noise would limit the model's ability to refine.
- The final step does NOT re-noise—it returns the clean estimate directly.
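
To see what the evenly spaced schedule looks like, the `linspace` call from the solution can be run on its own. A small sketch; `noise_levels_for` is a hypothetical helper name used only here, and the defaults mirror the solution's `sigma_min=0.002` and `sigma_max=1.0`:

```python
import torch


def noise_levels_for(n_steps, sigma_min=0.002, sigma_max=1.0):
    """Return the K noise levels used by a K-step sampler (hypothetical helper)."""
    # K + 1 evenly spaced points from sigma_max to sigma_min, last one dropped.
    return torch.linspace(sigma_max, sigma_min, n_steps + 1)[:-1]


for k in [1, 2, 4, 8]:
    levels = noise_levels_for(k)
    formatted = ", ".join(f"{v:.3f}" for v in levels.tolist())
    print(f"{k} step(s): [{formatted}]")
```

With these defaults, one step uses only the level 1.0, and two steps use 1.0 followed by roughly 0.501. The value `sigma_min` itself is never passed to the model, because the last `linspace` point is dropped.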

**Common mistakes:**
- Re-noising on the final step (produces a noisy output instead of clean)
- Using the same noise for re-noising as the initial noise (limits diversity)
- Confusing multi-step consistency with ODE solving: there is no trajectory between steps, and each step is a fresh jump to the endpoint from a re-noised input

</details>

---

## Key Takeaways

1. **The self-consistency property is real and verifiable.** Starting the ODE from any point on a trajectory reaches the same endpoint. You verified this with concrete numbers—not an abstraction, but a measurable property of the pretrained model's ODE.

2. **Single-step ODE methods fail because of curvature.** Even with flow matching's nearly-straight trajectories, one Euler step from pure noise produces inaccurate samples. The aggregate velocity field has enough curvature that a single linear extrapolation misses the target. This is why consistency models exist: they are trained to be accurate in one step, rather than relying on a solver that needs multiple steps.

3. **Consistency distillation uses local constraints for global consistency.** The training loss only compares adjacent timesteps: f(x at higher noise) should match f(x at slightly lower noise). Over training, these local constraints compose into global consistency: every point on a trajectory, at any noise level, maps to that trajectory's endpoint. The teacher provides the trajectory information; the EMA target prevents collapse.

4. **Multi-step consistency is independent teleportation, not trajectory-following.** Each step jumps to the clean estimate and re-noises, rather than continuing along a trajectory. 2-4 consistency steps can approach the quality of 20-50 ODE steps. The speedup comes from needing far fewer model evaluations, not from cheaper ones.

5. **The training objective, not the architecture, makes the difference.** The consistency model and teacher use the same MLP architecture. The consistency model is better at 1-step generation because it was trained for it—the self-consistency objective specifically optimizes for accurate single-step predictions.
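
The cost side of takeaway 4 can be checked by counting forward passes. The sketch below uses untrained stub networks, so it says nothing about sample quality; it only counts model evaluations. `CountingMLP`, `euler_sample`, and `consistency_sample` are hypothetical names written for this example, and it assumes the teacher's ODE is solved with plain Euler steps (one model evaluation per step).

```python
import torch
import torch.nn as nn


class CountingMLP(nn.Module):
    """Tiny untrained MLP on (x, t) that counts its forward passes."""

    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(3, 16), nn.SiLU(), nn.Linear(16, 2))
        self.calls = 0

    def forward(self, x, t):
        self.calls += 1
        return self.net(torch.cat([x, t.unsqueeze(-1)], dim=-1))


@torch.no_grad()
def euler_sample(velocity_model, x, n_steps):
    # Integrate from t=1 (noise) to t=0 (data), one model call per step.
    dt = 1.0 / n_steps
    for i in range(n_steps):
        t = torch.full((x.shape[0],), 1.0 - i * dt)
        x = x - dt * velocity_model(x, t)
    return x


@torch.no_grad()
def consistency_sample(cm, x, n_steps, sigma_min=0.002, sigma_max=1.0):
    # One model call per step; re-noise between steps but not after the last.
    levels = torch.linspace(sigma_max, sigma_min, n_steps + 1)[:-1].tolist()
    for i, t_val in enumerate(levels):
        t = torch.full((x.shape[0],), t_val)
        x_0_hat = cm(x, t)
        if i + 1 < len(levels):
            t_next = levels[i + 1]
            x = (1 - t_next) * x_0_hat + t_next * torch.randn_like(x_0_hat)
    return x_0_hat


torch.manual_seed(0)
noise = torch.randn(8, 2)
teacher_stub, cm_stub = CountingMLP(), CountingMLP()

euler_sample(teacher_stub, noise, n_steps=50)
consistency_sample(cm_stub, noise, n_steps=4)

print(f"teacher evaluations: {teacher_stub.calls}")
print(f"consistency evaluations: {cm_stub.calls}")
```

Both networks have the same size, so the ratio of call counts is a reasonable proxy for the ratio of sampling cost in this toy setting.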