## Setup

---

Additional Notes and Setup (framework info, python imports etc)


In [None]:
!pip install torchinfo

In [None]:
import torch
import torch.nn.functional as F

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import multivariate_normal
import matplotlib

In [None]:
from torchinfo import summary
from typing import Tuple

### Dataset


In [None]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_moons
import matplotlib

# Two-moons dataset configuration
n_samples = 50000
noise = 0.005
random_state = 42

# Draw the points X and their moon labels y
X, y = make_moons(n_samples=n_samples, noise=noise, random_state=random_state)

# Standardise each coordinate to zero mean and unit standard deviation
mu, std = X.mean(axis=0), X.std(axis=0)
X = (X - mu) / std


def plot_dataset():
    """Scatter-plot the module-level two-moons data ``X``, coloured by label ``y``.

    Each class gets one of two discrete viridis colours, axes are padded by
    0.5 around the data range, and a legend titled "Label" is added.
    """
    fig, ax = plt.subplots()

    # Two evenly spaced colours sampled from viridis, one per class
    palette = matplotlib.colormaps["viridis"].resampled(2)

    for label in (0, 1):
        mask = y == label
        ax.scatter(X[mask, 0], X[mask, 1], c=[palette(label)], s=12,
                   label=str(label))

    ax.set_title("Two Moons")
    ax.set_xlabel("x")
    ax.set_ylabel("y")

    pad = 0.5
    ax.set_xlim(X[:, 0].min() - pad, X[:, 0].max() + pad)
    ax.set_ylim(X[:, 1].min() - pad, X[:, 1].max() + pad)

    ax.legend(title="Label")

In [None]:
import torch
import matplotlib.pyplot as plt


def plot_forward(X, y, alpha, sigma):
    """
    Show how the forward (noising) process transforms the data over time.

    Each panel draws x_t = alpha(t) * x_0 + sigma(t) * eps with fresh
    eps ~ N(0, I), for a fixed list of times t. A final reference panel shows
    pure standard-normal samples.

    Args:
        X (np.ndarray): Clean data points, shape (N, 2).
        y (np.ndarray): Per-point labels used to colour the scatter plots.
        alpha (callable): Maps a time tensor t to alpha(t).
        sigma (callable): Maps a time tensor t to sigma(t).
    """
    x0 = torch.from_numpy(X).float()
    times = torch.tensor(
        [0.0, 0.01, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.5, 0.75, 1.0])

    n_panels = len(times) + 1  # one extra panel for the N(0, I) reference
    n_cols = 4
    n_rows = -(-n_panels // n_cols)  # ceiling division

    fig, axes = plt.subplots(
        n_rows, n_cols, figsize=(5 * n_cols, 5 * n_rows))
    fig.suptitle("Forward Process", fontsize=16)
    axes = axes.flatten()

    def _decorate(ax, title):
        # Shared title/labels/limits/grid for every populated panel
        ax.set_title(title)
        ax.set_xlabel("x")
        ax.set_ylabel("y")
        ax.set_xlim(-3, 3)
        ax.set_ylim(-3, 3)
        ax.grid(True)

    for ax, t in zip(axes, times):
        a_t = alpha(t)
        s_t = sigma(t)
        noise = torch.randn_like(x0)
        x_t = a_t * x0 + s_t * noise
        ax.scatter(x_t[:, 0].numpy(), x_t[:, 1].numpy(),
                   s=10, c=y, cmap='viridis')
        _decorate(ax, f"t = {t.item():.2f}")

    reference = torch.randn_like(x0)
    last_ax = axes[n_panels - 1]
    last_ax.scatter(reference[:, 0].numpy(), reference[:, 1].numpy(), s=10)
    _decorate(last_ax, "Normal (0,1)")

    # Hide any grid cells left over after the last panel
    for ax in axes[n_panels:]:
        ax.axis('off')

    plt.tight_layout(rect=[0, 0.03, 1, 0.95])
    plt.show()


def plot_schedule(alpha_fn, beta_fn, sigma_fn, t=None):
    """
    Draw beta(t), alpha(t) and sigma(t) in three side-by-side panels.

    Args:
        alpha_fn: callable mapping a time tensor to alpha(t)
        beta_fn: callable mapping a time tensor to beta(t)
        sigma_fn: callable mapping a time tensor to sigma(t)
        t: optional torch.Tensor of times; defaults to 500 points on [0, 1]
    """
    ts = torch.linspace(0, 1, 500) if t is None else t

    viridis = plt.get_cmap("viridis")
    # (title, y-label, values, colour); evaluated in beta, alpha, sigma order
    panels = [
        ("beta(t)", "beta", beta_fn(ts), viridis(0.2)),
        ("alpha(t)", "alpha", alpha_fn(ts), viridis(0.5)),
        ("sigma(t)", "sigma", sigma_fn(ts), viridis(0.8)),
    ]

    fig, axs = plt.subplots(1, 3, figsize=(15, 4))
    for ax, (title, ylabel, values, colour) in zip(axs, panels):
        ax.plot(ts, values, color=colour)
        ax.set_title(title)
        ax.set_xlabel("t")
        ax.set_ylabel(ylabel)

    plt.tight_layout()
    plt.show()

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation, PillowWriter
from IPython.display import Image, display
import matplotlib.cm as cm


def animate_trajectories(traj, downsample_points=2, downsample_frames=5, gif_path="denoising.gif"):
    """
    Render sample trajectories as an animated scatter plot and save it as a GIF.

    Args:
        traj (np.ndarray): Trajectories of shape (num_points, num_timesteps, 2)
        downsample_points (int): Keep every `downsample_points`-th trajectory
            (faster rendering, smaller file)
        downsample_frames (int): Keep every `downsample_frames`-th time step
        gif_path (str): Path the GIF is written to

    The figure is closed after saving. Otherwise every call would leave an open
    figure behind and the notebook would also render it as a static image.
    Display the GIF from `gif_path` instead.
    """
    # Downsample for speed
    traj_ds = traj[::downsample_points, ::downsample_frames, :]
    num_points, num_timesteps, _ = traj_ds.shape

    # One viridis colour per (downsampled) trajectory
    colors = cm.viridis(np.linspace(0, 1, num_points))

    # Initial frame; axis limits cover every position in every frame
    fig, ax = plt.subplots(figsize=(6, 6))
    scat = ax.scatter(traj_ds[:, 0, 0], traj_ds[:, 0, 1], s=5, c=colors)
    ax.set_xlim(np.min(traj_ds[:, :, 0]), np.max(traj_ds[:, :, 0]))
    ax.set_ylim(np.min(traj_ds[:, :, 1]), np.max(traj_ds[:, :, 1]))
    ax.set_title("Timestep: 0")

    def update(frame):
        # Move every point to its position at this frame; the title shows the
        # original (non-downsampled) timestep index
        scat.set_offsets(traj_ds[:, frame, :])
        ax.set_title(f"Timestep: {frame * downsample_frames}")
        return scat,

    # Create animation (does NOT repeat automatically)
    anim = FuncAnimation(
        fig,
        update,
        frames=num_timesteps,
        interval=20,
        blit=True,
        repeat=False
    )

    try:
        anim.save(gif_path, writer=PillowWriter(fps=30))
    finally:
        # Release the figure even if saving fails
        plt.close(fig)

In [None]:
def beta(t, beta_min=0.1, beta_max=20):
    """
    Linear noise-rate schedule of the VP-SDE: beta(t) = beta_min + (beta_max - beta_min) * t.

    Args:
        t: torch.Tensor of shape (...), values in [0,1]
        beta_min: value of beta at t = 0
        beta_max: value of beta at t = 1

    Returns:
        beta(t): torch.Tensor of same shape as t
    """
    span = beta_max - beta_min
    return beta_min + span * t


def alpha(t, beta_min=0.1, beta_max=20):
    """
    Compute alpha(t) for VP-SDE with linear beta schedule.

    alpha(t) = exp(-1/2 * int_0^t beta(s) ds), where for the linear schedule
    int_0^t beta(s) ds = beta_min * t + (beta_max - beta_min) * t**2 / 2.

    Args:
        t: torch.Tensor of shape (...) or a Python number, values in [0,1]
        beta_min: value of beta at t = 0
        beta_max: value of beta at t = 1

    Returns:
        alpha: torch.Tensor same shape as t
    """
    # Accept plain Python numbers too (torch.exp requires a tensor);
    # tensors pass through unchanged (same dtype/device, no copy).
    t = torch.as_tensor(t)
    integral = beta_min * t + 0.5 * (beta_max - beta_min) * t**2
    return torch.exp(-0.5 * integral)


def sigma(t, beta_min=0.1, beta_max=20):
    """
    Compute sigma(t) for VP-SDE with linear beta schedule.

    sigma(t) = sqrt(1 - alpha(t)**2) = sqrt(1 - exp(-int_0^t beta(s) ds)).

    Args:
        t: torch.Tensor of shape (...) or a Python number, values in [0,1]
        beta_min: value of beta at t = 0
        beta_max: value of beta at t = 1

    Returns:
        sigma: torch.Tensor same shape as t
    """
    t = torch.as_tensor(t)
    integral = beta_min * t + 0.5 * (beta_max - beta_min) * t**2
    # 1 - alpha**2 == -expm1(-integral). Writing it with expm1 keeps full
    # relative precision as t -> 0, where `1 - exp(...)` cancels catastrophically.
    # This matters because the samplers divide by sigma(t)**2 near t = eps.
    return torch.sqrt(-torch.expm1(-integral))

In [None]:
# Optimisation hyper-parameters: mini-batch size, number of passes over the
# dataset, and learning rate.
batch_size, num_epochs, lr = 2048, 1024, 1e-3

In [None]:
# Use the first CUDA GPU when one is present; otherwise fall back to the CPU so
# the notebook still runs (more slowly) on machines without a GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader

# Wrap the standardized points in a Dataset; each batch is a 1-tuple (x,).
dataset = TensorDataset(torch.from_numpy(X).float())
loader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    # Pinned host memory only speeds up host->GPU copies; without CUDA it
    # brings no benefit and PyTorch warns about it.
    pin_memory=torch.cuda.is_available(),
)

In [None]:
import torch
import torch.nn as nn


class ScoreModel(nn.Sequential):
    """
    Time-conditioned MLP mapping (x, t) to an output with the same width as x.

    The time is appended to x as one extra input feature. Layer layout:
    Linear(dim + 1, hidden) -> SiLU
    -> num_blocks x [Linear(hidden, hidden) -> SiLU]
    -> Linear(hidden, dim).
    """

    def __init__(self, dim: int, num_blocks: int = 2, hidden_size: int = 64):
        # Input layer sees the `dim` coordinates plus one time feature
        layers = [nn.Linear(dim + 1, hidden_size), nn.SiLU()]
        for _ in range(num_blocks):
            layers += [nn.Linear(hidden_size, hidden_size), nn.SiLU()]
        layers.append(nn.Linear(hidden_size, dim))
        super().__init__(*layers)

    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: inputs of shape (..., dim).
            t: times, either one per input (shape x.shape[:-1], e.g. (b,)) or a
               scalar / 0-dim tensor shared by all inputs.

        Returns:
            Tensor of shape (..., dim).
        """
        # Match x's dtype/device (a no-op for the usual float32 (b,) tensor)
        # and broadcast a shared scalar time to every input before appending it
        # as an extra feature.
        t = torch.as_tensor(t, dtype=x.dtype, device=x.device)
        t = t.expand(x.shape[:-1]).unsqueeze(-1)
        h = torch.cat([x, t], dim=-1)
        return super().forward(h)

In [None]:
def training_step(model: ScoreModel, x: torch.Tensor) -> torch.Tensor:
    """
    Denoising score-matching loss for one batch under the VP-SDE.

    Args:
        model: network called as model(x_t, t). It is trained to output the
            scaled score sigma(t)**2 * grad_x log p(x_t | x_0).
        x: clean data batch of shape (b, 2) [points].

    Returns:
        Scalar MSE loss.
    """
    b, *_ = x.shape
    # Sample time
    t = torch.rand((b, ), device=x.device)
    # Sample noise
    z = torch.randn_like(x)
    x_0 = x
    # Perturb: x_t ~ N(alpha(t) x_0, sigma(t)^2 I)
    a = alpha(t).view((-1, 1))
    s = sigma(t).view((-1, 1))
    x_t = a * x_0 + s * z
    # The conditional score is -(x_t - a*x_0) / s**2 = -z / s. The target is
    # s**2 times that, i.e. exactly -s*z. Computing it directly avoids the
    # cancellation in x_t - a*x_0, and the target stays bounded as s -> 0.
    # The VP-SDE sampler divides the model output by sigma(t)**2 to recover the
    # score. Relative to plain score matching, this weights the loss by
    # sigma(t)**4; relative to predicting the noise z, the weight is sigma(t)**2.
    scaled_score = -s * z
    pred_score = model(x_t, t)
    return F.mse_loss(pred_score, scaled_score)

In [None]:
from typing import List, Tuple
import torch


@torch.inference_mode()
def sample_vp_sde_euler_maruyama(
    model: ScoreModel,
    num_samples: int,
    sample_shape: Tuple[int, ...] = (2,),
    T=1.0,
    num_steps=1000,
    eps=1e-5,
    device='cpu',
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Sample by integrating the reverse-time VP-SDE with Euler-Maruyama.

    Reverse SDE: dx = [-1/2 beta(t) x - beta(t) score(x, t)] dt + sqrt(beta(t)) dw,
    integrated from t = T down to t = eps on a uniform grid. The score is
    model(x, t) / sigma(t)**2, because the model predicts the scaled score.

    Args:
        model: scaled-score network, called as model(x, t_batch).
        num_samples: number of samples to draw.
        sample_shape: shape of one sample (any sequence of ints; default (2,)).
        T: start time (pure noise).
        num_steps: number of Euler-Maruyama steps.
        eps: final time; stops short of 0, where sigma(t) = 0.
        device: computation device.

    Returns:
        (final_samples, trajectory) of shapes (num_samples, *sample_shape) and
        (num_samples, num_steps + 1, *sample_shape).

    NOTE: uses the module-level VP-SDE `beta` and `sigma`. A later cell in this
    notebook redefines `sigma` for flow matching, so call this before that cell runs.
    """
    # Time grid from T down to eps (eps > 0 for numerical stability)
    t_steps = torch.linspace(T, eps, num_steps + 1, device=device)
    # Start from standard Gaussian noise
    x = torch.randn((num_samples, *sample_shape), device=device)
    # `x` is rebound to a fresh tensor every step (no in-place ops), so the
    # stored states never alias each other and need no clone.
    trajectory = [x]
    for i in range(num_steps):
        t = t_steps[i]
        dt = t_steps[i] - t_steps[i + 1]  # positive step size
        dw = torch.randn_like(x)
        # Same time for every sample, shape (num_samples,)
        ts = t.expand(num_samples)
        beta_t = beta(t)
        sigma_t = sigma(t)
        # Undo the sigma**2 scaling the model was trained with
        score = model(x, ts) / (sigma_t**2)
        drift = -0.5 * beta_t * x - beta_t * score
        diffusion = torch.sqrt(beta_t)
        # Step backward in time: x_{t-dt} = x_t - drift*dt + g*sqrt(dt)*z
        x = x - drift * dt + diffusion * torch.sqrt(dt) * dw
        trajectory.append(x)
    # (num_samples, num_steps + 1, *sample_shape) for any sample rank
    trajectory = torch.stack(trajectory, dim=1)
    return trajectory[:, -1], trajectory

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import multivariate_normal


def plot_denoising_trajectories(ax, traj: np.ndarray, ground_truth_samples: np.ndarray, num_samples_to_plot: int = 4):
    """
    Draw sampling trajectories over the ground-truth data on ``ax``.

    For each plotted trajectory this draws the full path (faint gray line), an
    arrow from its first to its last point, the starting point (x marker) and
    the end point (o marker), coloured along the viridis colormap. Dotted
    contours of a standard 2-D Gaussian mark the noise distribution p1.

    Args:
        ax (matplotlib.axes.Axes): The axes object to plot on.
        traj (np.ndarray): Trajectories of shape (num_trajectories, num_steps, 2).
        ground_truth_samples (np.ndarray): Data points of shape (num_samples, 2),
            drawn as black background points (p0).
        num_samples_to_plot (int, optional): Number of trajectories to draw;
            clamped to the number available in ``traj``. Defaults to 4.
    """
    # Never index past the available trajectories
    n_plot = max(0, min(num_samples_to_plot, len(traj)))

    # 1. Ground truth (black background points, p0)
    ax.scatter(ground_truth_samples[:, 0], ground_truth_samples[:, 1],
               c="black", s=3, label=r"$p_0$ (Ground truth)")

    # One viridis colour per plotted trajectory
    cmap = matplotlib.colormaps["viridis"]
    colors = [cmap(i / max(1, n_plot - 1)) for i in range(n_plot)]

    # 2. Plot trajectories, noises, and samples
    for i, color in enumerate(colors):
        # Trajectory (light gray line)
        ax.plot(traj[i, :, 0], traj[i, :, 1],
                color="gray", alpha=0.2, linewidth=1)

        # Arrow from start to end (stronger gray)
        x0, y0 = traj[i, 0]
        x1, y1 = traj[i, -1]
        ax.annotate("",
                    xy=(x1, y1), xycoords='data',
                    xytext=(x0, y0), textcoords='data',
                    arrowprops=dict(arrowstyle="->", color="gray", lw=1.2, alpha=0.6))

        # Only the first trajectory contributes legend entries (an empty label
        # is skipped by the legend)
        ax.scatter(x0, y0, c=[color], s=60, marker="x", linewidths=2,
                   label="Initial noise" if i == 0 else "")
        ax.scatter(x1, y1, c=[color], s=60, marker="o", linewidths=1.2,
                   label="Final sample" if i == 0 else "")

    # 3. Gaussian contours (p1)
    xx, yy = np.meshgrid(
        np.linspace(-3, 3, 200),
        np.linspace(-3, 3, 200)
    )
    pos = np.dstack((xx, yy))
    rv = multivariate_normal([0, 0], [[1, 0], [0, 1]])
    z = rv.pdf(pos)
    ax.contour(xx, yy, z, colors="gray", linestyles="dotted")

    # Legend entries for elements that do not add their own (contour, paths)
    ax.plot([], [], color="gray", linestyle="dotted",
            label=r"$p_1$ (Standard Gaussian)")
    ax.plot([], [], color="gray", alpha=0.2,
            linewidth=1, label="Sampling trajectory")

    ax.axis("equal")
    ax.set_title("Denoising using Reverse Process")
    ax.legend()

In [None]:
import ipywidgets as widgets
from IPython.display import display, Image

# show gif inline


def display_ode_sde_comparison(ode_gif_path, sde_gif_path):
    """
    Show the ODE and SDE GIFs next to each other, each captioned below.

    Args:
        ode_gif_path (str): Path of the GIF for the ODE sampler (captioned "ODE").
        sde_gif_path (str): Path of the GIF for the SDE sampler (captioned "SDE").

    A missing file is reported with a printed message instead of raising.
    """
    try:
        # Read both files before building any widgets
        gif_bytes = []
        for path in (ode_gif_path, sde_gif_path):
            with open(path, "rb") as handle:
                gif_bytes.append(handle.read())

        columns = [
            widgets.VBox(
                [widgets.Image(value=data, format='gif', width=500),
                 widgets.Label(caption)],
                layout=widgets.Layout(align_items='center'),
            )
            for data, caption in zip(gif_bytes, ("ODE", "SDE"))
        ]
        display(widgets.HBox(columns))
    except FileNotFoundError as e:
        print(
            f"Error: One or more files not found. Please check the paths. {e}")

## Part III: Optional

---


### Dataset


In [None]:
# Scatter-plot the standardized two-moons training data, coloured by label.
plot_dataset()

### Re-train the model


In [None]:
# Fresh score network for 2-D points (the time is appended as an extra input
# feature), moved to the training device.
model = ScoreModel(dim=2, hidden_size=256).to(device)
# Print the torchinfo layer/parameter summary.
summary(model)

In [None]:
from torch.optim import AdamW

# AdamW over all parameters of the score model, using the shared learning rate.
optimizer = AdamW(model.parameters(), lr=lr)

In [None]:
from tqdm import tqdm

log_every_steps = 128
total_epochs = 0  # NOTE(review): not updated by the loop below

# Training loop with tqdm
for epoch in range(num_epochs):
    step = 0
    # Per-epoch progress bar, cleared when the epoch ends
    progress_loader = tqdm(
        loader,
        desc=f'epoch [{epoch+1}]',
        leave=False
    )
    for (x,) in progress_loader:
        # Move the batch to the training device
        x = x.to(device)
        # Clear gradients before the forward pass, so gradients left by an
        # interrupted earlier run of this cell cannot leak into this update
        optimizer.zero_grad()
        # Forward pass and backprop
        loss = training_step(model, x)
        loss.backward()
        # Clip the global gradient norm to stabilise training
        torch.nn.utils.clip_grad_norm_(
            model.parameters(),
            max_norm=1.0
        )
        optimizer.step()
        # Only sync the loss to the host every `log_every_steps` steps
        if step % log_every_steps == 0:
            progress_loader.set_postfix(loss=loss.item())
        step += 1

### Deterministic Sampling with Probability Flow ODE


Perhaps surprisingly, there exists an ODE whose solutions have the same marginal distributions $p_t(x)$ at every time $t$ as the VP-SDE:

$$dx = \left[-\frac{1}{2}\beta(t)x - \frac{1}{2}\beta(t)s_\theta(x,t)\right]dt$$

This ODE, also known as Probability Flow ODE, defines a continuous-time process where the state $x$ evolves backward from noise to data without any stochastic (random) component.
Unlike the SDE, this path is deterministic, meaning that starting from the same initial noisy state will always lead to the same final data sample.

We evolve the ODE using **Euler method**,

$$x_{t-\Delta t} = x_t - \left[-\frac{1}{2}\beta(t)x_t - \frac{1}{2}\beta(t)s_\theta(x_t,t)\right]\Delta t$$

where:

- $t$ is the current time step.
- $\Delta t = t_i - t_{i+1}$ is the positive step size.
- $x_{t-\Delta t}$ is the state at the next (earlier) time step.
- $s_\theta(x_t,t)$ is the estimated score function


#### Task: Sampling using Probability Flow ODE

Your goal is to implement a sampler that evolves the **Probability Flow ODE** using the **Euler method**.

**Inputs**:

- **model** – the trained score model. Given inputs `(x, t_batch)`, it outputs `pred_scaled_score`.
- **num_samples** – number of samples to generate.
- **T** – starting time (default 1.0).
- **num_steps** – number of Euler integration steps (default 1000).
- **eps** – final time, close to zero (default 1e-5).
- **device** – computation device, e.g. `"cpu"` or `"cuda"`.

**Outputs**:

The method `sample_pf_ode_euler` should return:

- **final_samples** – tensor of shape `(num_samples, 2)`.
  These are the samples at the final time step after evolving the ODE from noise.

- **trajectory** – tensor of shape `(num_samples, num_steps+1, 2)`.
  This stores the entire trajectory of each sample, starting from Gaussian noise at time $T$ and ending at the data distribution at time $\epsilon$.

**Hint:**

- Use `torch.full` to create a time tensor for the batch. Unlike the SDE sampler, the ODE update is deterministic, so no Brownian noise (`torch.randn_like`) is added inside the loop.
- Remember to **divide the model output by $\sigma^2(t)$** to obtain the true score.


In [None]:
@torch.inference_mode()
def sample_pf_ode_euler(
    model: ScoreModel,
    num_samples: int,
    T=1.0,
    num_steps=1000,
    eps=1e-5,
    device='cpu'
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Sample deterministically by integrating the VP-SDE probability flow ODE with Euler.

    dx = [-1/2 beta(t) x - 1/2 beta(t) score(x, t)] dt, integrated from t = T
    down to t = eps. The score is model(x, t) / sigma(t)**2.

    Args:
        model: scaled-score network, called as model(x, t_batch).
        num_samples: number of samples to generate.
        T: start time (pure noise).
        num_steps: number of Euler steps.
        eps: final time, close to zero (sigma(0) = 0 would divide by zero).
        device: computation device.

    Returns:
        (final_samples, trajectory) of shapes (num_samples, 2) and
        (num_samples, num_steps + 1, 2).

    NOTE: uses the module-level VP-SDE `beta` and `sigma`. A later cell in this
    notebook redefines `sigma` for flow matching, so call this before that cell runs.
    """
    # Create time steps from T down to eps
    t_steps = torch.linspace(T, eps, num_steps + 1, device=device)
    # Start with a random noise vector for all samples
    x = torch.randn(num_samples, 2, device=device)
    # `x` is rebound each step (no in-place ops), so stored states never alias
    trajectory = [x]
    for i in range(num_steps):
        t = t_steps[i]
        dt = t_steps[i] - t_steps[i + 1]  # positive step size
        ts = t.expand(num_samples)        # same time for every sample
        beta_t = beta(t)
        sigma_t = sigma(t)
        # Undo the sigma**2 scaling the model was trained with
        score = model(x, ts) / (sigma_t**2)
        drift = -0.5 * beta_t * x - 0.5 * beta_t * score
        # Deterministic Euler step backward in time (no Brownian term)
        x = x - drift * dt
        trajectory.append(x)

    trajectory = torch.stack(trajectory, dim=1)  # (num_samples, num_steps+1, 2)

    return (
        trajectory[:, -1, ...],
        trajectory
    )

**Stochastic Sampling**


In [None]:
# Draw 1000 stochastic samples and convert both outputs to NumPy arrays
x_sde, traj_sde = (
    out.detach().cpu().numpy()
    for out in sample_vp_sde_euler_maruyama(
        model, num_samples=1000, num_steps=1000, device=device
    )
)

(x_sde.shape, traj_sde.shape)

In [None]:
# Save the SDE sampling trajectories as an animation to sde.gif.
animate_trajectories(traj_sde, gif_path="sde.gif")

**Deterministic Sampling**


In [None]:
# Draw 1000 deterministic samples and convert both outputs to NumPy arrays
x_ode, traj_ode = (
    out.detach().cpu().numpy()
    for out in sample_pf_ode_euler(
        model, num_samples=1000, num_steps=1000, device=device
    )
)

(x_ode.shape, traj_ode.shape)

In [None]:
# Save the probability-flow-ODE sampling trajectories as an animation to ode.gif.
animate_trajectories(traj_ode, gif_path="ode.gif")

**Sampling Comparison**


In [None]:
# Side-by-side comparison: 32 deterministic (ODE, left) and 32 stochastic
# (SDE, right) sampling trajectories drawn over the ground-truth data X.
fig, (ax_ode, ax_sde) = plt.subplots(1, 2, figsize=(14, 7))

plot_denoising_trajectories(ax_ode, traj_ode, X, num_samples_to_plot=32)
ax_ode.set_title("ODE")

plot_denoising_trajectories(ax_sde, traj_sde, X, num_samples_to_plot=32)
ax_sde.set_title("SDE")

plt.tight_layout()
plt.show()

In [None]:
# Play ode.gif and sde.gif side by side (prints an error if either is missing).
display_ode_sde_comparison("ode.gif", "sde.gif")

## Flow Matching as a Diffusion Model

Recall from the Theory notebook that, within the diffusion-model framework, the forward SDE below has the flow-matching interpolation as its perturbation kernel, i.e. $\alpha_t = 1-t$, $\sigma_t = t$:

$$
\mathrm{d}\mathbf{x}_t = -\frac{1}{1-t}x_t\,\mathrm{d}t + \sqrt{2\frac{t}{(1-t)}}\,\mathrm{d}\mathbf{w}_t,
$$

Indeed, the probability flow ODE of the corresponding reverse (generative) process has the same velocity field as the learned flow-matching vector field!


## Learning the Probability Flow ODE directly with Flow Matching


In [None]:
# The velocity network reuses the time-conditioned MLP architecture of the
# score model; only the training target changes.
FlowModel = ScoreModel

In [None]:
# Fresh (untrained) velocity network on the training device; this replaces the
# score model bound to `model` above.
model = FlowModel(dim=2, hidden_size=256).to(device)
# Print the torchinfo layer/parameter summary.
summary(model)

#### Task: Flow Matching Schedule

In Flow-Matching, we define interpolation schedules using the simple linear formulas

$$
\alpha(t) = 1 - t, \quad \sigma(t) = t
$$

where $t \in [0,1]$.

Implement these formulas in PyTorch inside the functions `alpha(t)` and `sigma(t)`. The input `t` will be a PyTorch tensor of any shape with values in $[0,1]$. The output should be a tensor of the same shape.


In [None]:
def alpha(t):
    """
    Compute alpha(t) = 1 - t for Flow-Matching (linear interpolation schedule).

    Args:
        t: torch.Tensor of shape (...), values in [0,1]

    Returns:
        alpha: torch.Tensor same shape as t
    """
    return 1 - t


def sigma(t):
    """
    Compute sigma(t) = t for Flow-Matching (linear interpolation schedule).

    Args:
        t: torch.Tensor of shape (...), values in [0,1]

    Returns:
        sigma: torch.Tensor same shape as t
    """
    return t

#### Task: Learning Flow using Flow Matching

In Flow-Matching, the training objective compares the model’s predicted velocity with the true velocity along the interpolation path.

The loss is defined as

$$
\mathcal{L} = \mathbb{E}_{t \sim U(0,1),\, x_0 \sim p_{\mathrm{data}},\, x_1 \sim \mathcal{N}(0,I)} \big[ \|\, v_\theta(x_t, t) - (x_1 - x_0)\,\|^2 \big]
$$

where

$$
x_t = \alpha(t) \, x_0 + \sigma(t) \, x_1
$$

- $x_1$ = random Gaussian noise sample
- $x_0$ = data point from the dataset
- $v_\theta(x_t, t)$ = model’s velocity prediction

Complete the function `training_step(model, x)` by implementing this objective in PyTorch.

1. Sample a batch of times $t \sim U(0,1)$.
2. Sample Gaussian noise $x_1$.
3. Compute perturbed samples $x_t$ using the provided `alpha(t)` and `sigma(t)`.
4. Compute the true velocity $x_1 - x_0$.
5. Pass $(x_t, t)$ to the model to get the predicted velocity.
6. Return the MSE loss between predicted and true velocity.


In [None]:
def training_step(model: ScoreModel, x: torch.Tensor) -> torch.Tensor:
    """
    Flow-matching loss for one batch.

    Samples t ~ U(0, 1) and noise x_1 ~ N(0, I) per data point, forms
    x_t = alpha(t) * x_0 + sigma(t) * x_1, and regresses the model's output
    on the conditional velocity x_1 - x_0.

    Args:
        model: velocity network, called as model(x_t, t).
        x: clean data batch of shape (b, 2) [points].

    Returns:
        Scalar MSE loss.
    """
    b, *_ = x.shape
    # 1. One time per data point
    t = torch.rand((b,), device=x.device)
    # 2. Gaussian noise endpoint
    x_0 = x
    x_1 = torch.randn_like(x_0)
    # 3. Interpolate; reshape (b,) -> (b, 1, ..., 1) to broadcast over x
    coef_shape = (b,) + (1,) * (x.dim() - 1)
    a = alpha(t).view(coef_shape)
    s = sigma(t).view(coef_shape)
    x_t = a * x_0 + s * x_1
    # 4. Target velocity d x_t / d t for alpha = 1 - t, sigma = t
    velocity = x_1 - x_0
    # 5. Predicted velocity
    pred_velocity = model(x_t, t)
    # 6. MSE between predicted and true velocity
    return F.mse_loss(
        pred_velocity,
        velocity
    )

Train the model.


In [None]:
from torch.optim import AdamW

# New AdamW optimizer for the flow model's parameters (shared learning rate).
optimizer = AdamW(model.parameters(), lr=lr)

In [None]:
from tqdm import tqdm

log_every_steps = 128
total_epochs = 0  # NOTE(review): not updated by the loop below

# Training loop with tqdm
for epoch in range(num_epochs):
    step = 0
    # Per-epoch progress bar, cleared when the epoch ends
    progress_loader = tqdm(
        loader,
        desc=f'epoch [{epoch+1}]',
        leave=False
    )
    for (x,) in progress_loader:
        # Move the batch to the training device
        x = x.to(device)
        # Clear gradients before the forward pass, so gradients left by an
        # interrupted earlier run of this cell cannot leak into this update
        optimizer.zero_grad()
        # Forward pass and backprop
        loss = training_step(model, x)
        loss.backward()
        # Clip the global gradient norm to stabilise training
        torch.nn.utils.clip_grad_norm_(
            model.parameters(),
            max_norm=1.0
        )
        optimizer.step()
        # Only sync the loss to the host every `log_every_steps` steps
        if step % log_every_steps == 0:
            progress_loader.set_postfix(loss=loss.item())
        step += 1

#### Task: Flow ODE Sampling

We can generate new samples from a Flow-Matching model by integrating the learned ODE backward in time, from $t = 1$ to $t = 0$:

$$
\frac{dx_t}{dt} = v_\theta(x_t, t)
$$

starting from Gaussian noise $x_1 \sim \mathcal{N}(0, I)$. Using **Euler’s method**, we discretize time into steps and update

$$
x_{t- \Delta t} = x_t - v_\theta(x_t, t)\, \Delta t
$$

Complete the function `sample_flow_ode_euler(model, num_samples, T, num_steps, device)` to implement this sampler.

**Steps to implement:**

1. Create evenly spaced time steps from $T \to 0$.
2. Initialize samples $x_1 \sim \mathcal{N}(0,I)$.
3. For each time step $t$:

   - Prepare a batch tensor of times.
   - Use the model to predict velocity $v_\theta(x_t, t)$.
   - Update samples with the Euler step.
   - Save intermediate samples for the trajectory.

4. Stack the trajectory so the final output has shape `(num_samples, num_steps+1, 2)`.
5. Return both the final samples and the full trajectory.

**Requirements:**

- Ensure shapes are correct:

  - `final_samples`: `(num_samples, 2)`
  - `trajectory`: `(num_samples, num_steps+1, 2)`


In [None]:
@torch.inference_mode()
def sample_flow_ode_euler(
    model: FlowModel,
    num_samples: int,
    T: float = 1.0,
    num_steps: int = 1000,
    device: str = 'cpu'
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Sample from the flow matching model using Euler integration.

    Integrates dx/dt = v_theta(x, t) backward from t = T (noise) to t = 0 (data)
    with x_{t - dt} = x_t - v_theta(x_t, t) * dt.

    Args:
        model: trained ScoreModel (predicts velocity).
        num_samples: number of samples to generate.
        T: terminal time (default 1.0).
        num_steps: number of Euler steps.
        device: computation device.

    Returns:
        final_samples: Tensor of shape (num_samples, 2).
        trajectory: Tensor of shape (num_samples, num_steps+1, 2).
    """
    # Time discretization from T → 0
    t_steps = torch.linspace(T, 0, num_steps + 1, device=device)

    # Start from pure Gaussian noise (p1)
    x = torch.randn(num_samples, 2, device=device)
    # `x` is rebound each step (no in-place ops), so stored states never alias
    trajectory = [x]

    for i in range(num_steps):
        t = t_steps[i]
        dt = t_steps[i] - t_steps[i + 1]  # positive step size (≈ T / num_steps)
        t_batch = t.expand(num_samples)   # same time for every sample
        velocity = model(x, t_batch)
        x = x - velocity * dt
        trajectory.append(x)

    trajectory = torch.stack(trajectory, dim=1)  # (num_samples, num_steps+1, 2)
    return trajectory[:, -1, ...], trajectory

In [None]:
# Draw 1000 samples from the flow model and convert both outputs to NumPy arrays
x_ode, traj_ode = (
    out.detach().cpu().numpy()
    for out in sample_flow_ode_euler(
        model, num_samples=1000, num_steps=1000, device=device
    )
)

(x_ode.shape, traj_ode.shape)

**Sampling animation**


In [None]:
# Save the flow-ODE sampling trajectories as an animation to flow.gif.
animate_trajectories(traj_ode, gif_path="flow.gif")

In [None]:
# Display the saved GIF inline (last expression of the cell).
Image(filename="flow.gif")

In [None]:
fig, ax_ode = plt.subplots(1, 1, figsize=(7, 7))

plot_denoising_trajectories(ax_ode, traj_ode, X, num_samples_to_plot=32)
ax_ode.set_title("Flow ODE")

# Lay out and render explicitly, like the ODE/SDE comparison cell. This also
# stops the cell from echoing the Text object returned by set_title.
plt.tight_layout()
plt.show()

We directly learned the flow ODE and sampled deterministically!
