# Seminar 11: Flow Matching

**Deep Learning Course 2025**

**Author:** Nikita Kiselev

In [None]:
import torch
import numpy as np
from tqdm import tqdm
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import torch.nn.functional as F
from torch.utils.data import DataLoader

# Global figure styling for every plot in this notebook: thicker axes frame, larger text.
plt.rcParams.update({
    "axes.linewidth": 1.5,
    "font.size": 16,
})

In [None]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

**2D-Spiral distribution**

In this section we will train a Flow Matching model to generate data from a 2D-Spiral distribution.

In [None]:
def create_spiral_data(num_samples=100_000, scale=1):
    """Sample noisy 2D points along a three-turn Archimedean spiral.

    Angles are drawn uniformly from [0, 6*pi); the radius grows linearly with
    the angle (one unit of ``scale`` per full turn), and independent Gaussian
    jitter with standard deviation ``0.1 * scale`` is added to each coordinate.

    Args:
        num_samples: Number of points to draw.
        scale: Multiplier applied to both the spiral radius and the jitter.

    Returns:
        Tensor of shape ``(num_samples, 2)`` holding the (x, y) coordinates.
    """
    jitter_std = 0.1 * scale
    # Random draws happen in a fixed order (angles, x-jitter, y-jitter) so that
    # seeded runs stay reproducible.
    angle = 6 * torch.pi * torch.rand(num_samples)
    radius = angle / (2 * torch.pi) * scale
    x_jitter = jitter_std * torch.randn(num_samples)
    y_jitter = jitter_std * torch.randn(num_samples)
    xs = radius * torch.cos(angle) + x_jitter
    ys = radius * torch.sin(angle) + y_jitter
    return torch.stack((xs, ys), dim=1)

In [None]:
# Draw the training set once with the default size and scale; the tensor itself
# serves as the dataset, so each batch from the loader is a (batch, 2) slice of it.
train_dataset = create_spiral_data()
# Reshuffle every epoch and serve mini-batches of 1024 points.
train_loader = DataLoader(dataset=train_dataset, batch_size=1024, shuffle=True)

In [None]:
# Show the first 1000 training points to visualise the target distribution.
preview = train_dataset[:1000]
plt.scatter(preview[:, 0], preview[:, 1], s=50, alpha=0.5, color="green")

# Style the axes through the Axes object rather than the pyplot state machine.
data_ax = plt.gca()
data_ax.set_title("Data Distribution")
data_ax.grid(alpha=0.1)
data_ax.set_aspect("equal")
data_ax.set_xlim(-3.5, 3.5)
data_ax.set_ylim(-3.5, 3.5)

plt.tight_layout()
plt.show()

**Flow Matching (Linear Interpolation)**

In conditional flow matching, our objective is to learn a vector field $\mathbf{f}_{\boldsymbol{\theta}}(\mathbf{x}, t)$, parameterized by a neural network, that aligns with a known target vector field $\mathbf{f}(\mathbf{x}, \mathbf{z}, t)$ at each point along a path connecting the data distribution and a base distribution. In this task, we consider the **linear interpolation conditional vector field**, defined by:
$$
\mathbf{f}(\mathbf{x}, \mathbf{z}, t) = \frac{d\mathbf{x}}{dt} = \mathbf{x}_1 - \mathbf{x}_0
$$
which means that $\mathbf{x}_t$ interpolates linearly between pure noise $\mathbf{x}_0$ (at $t = 0$) and data $\mathbf{x}_1$ (at $t = 1$):
$$
\mathbf{x}_t = t \mathbf{x}_1 + (1 - t) \mathbf{x}_0
$$

So, the training objective is defined as:
$$
\mathbb{E}_{t \sim \mathcal{U}[0, 1]}\, \mathbb{E}_{\mathbf{x}_1 \sim \pi(\mathbf{x})} \mathbb{E}_{\mathbf{x}_0 \sim \mathcal{N}(\mathbf{0}, \mathbf{I})} \left[ \left\| (\mathbf{x}_1 - \mathbf{x}_0) - \mathbf{f}_{\boldsymbol{\theta}}(\mathbf{x}_t, t) \right\|^2 \right] \to \min_{\boldsymbol{\theta}}
$$

**Building the model class**

Our Flow Matching model must take two arguments: noisy vector $\mathbf{x}_t$ and timestep $t$.

Here we map both into a common hidden space and add them together.

In [None]:
class FlowMatchingModel(nn.Module):
    """MLP that predicts a velocity from a point ``x`` and a timestep ``t``.

    The point and the timestep are each mapped to ``hidden_dim`` features,
    summed, and passed through a GELU MLP that maps back to ``input_dim``.

    Args:
        input_dim: Dimensionality of the data points (and of the predicted velocity).
        hidden_dim: Width of the shared hidden space.
    """

    def __init__(self, input_dim=2, hidden_dim=512):
        super().__init__()
        self.x_proj = nn.Linear(input_dim, hidden_dim)
        # TimeEmbedding is defined in a later cell; it only has to exist by the
        # time the model is instantiated.
        self.t_proj = TimeEmbedding(hidden_dim)
        mlp_layers = [
            nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, input_dim),
        ]
        self.net = nn.Sequential(*mlp_layers)

    def forward(self, x, t):
        """Return the predicted velocity for points ``x`` at times ``t``.

        The training loop in this notebook passes ``x`` as ``(batch, input_dim)``
        and ``t`` as ``(batch, 1)``.
        """
        # Additive conditioning: point features and time features share one space.
        hidden = self.x_proj(x) + self.t_proj(t)
        return self.net(hidden)

To condition on a timestep we need to implement a time embedding module.

In [None]:
class TimeEmbedding(nn.Module):
    """Sinusoidal embedding of a scalar timestep into ``dim`` features.

    The frequencies are ``pi, 2*pi, ..., (dim // 2) * pi``. The output is the
    cosines of ``freq * t`` followed by the sines, plus one trailing zero
    feature when ``dim`` is odd so the width is always exactly ``dim``.

    Args:
        dim: Number of output features.
    """

    def __init__(self, dim):
        super().__init__()
        # Registered as a buffer (not a plain attribute) so the frequencies
        # follow the module through .to(device) and are stored in state_dict.
        self.register_buffer("freqs", torch.arange(1, dim // 2 + 1) * torch.pi)
        # cos/sin pairs only provide 2 * (dim // 2) features; an odd `dim`
        # needs one extra column to match the requested width.
        self.pad = dim % 2

    def forward(self, t):
        """Embed timesteps.

        Args:
            t: Tensor of timesteps. A trailing axis of size 1 (e.g. ``(batch, 1)``)
                is used as is; any other shape (e.g. ``(batch,)``) gets a trailing
                axis added so it broadcasts against the frequencies per timestep.

        Returns:
            Tensor whose last axis has ``dim`` features.
        """
        if t.dim() == 0 or t.shape[-1] != 1:
            # Without this, a (batch,) input would be multiplied element-wise
            # with the frequency vector instead of per timestep.
            t = t.unsqueeze(-1)
        emb = self.freqs * t
        emb = torch.cat([emb.cos(), emb.sin()], dim=-1)
        if self.pad:
            emb = torch.cat([emb, emb.new_zeros(*emb.shape[:-1], 1)], dim=-1)
        return emb

**Training and evaluating the Flow Matching model**

Now that we’ve defined the Flow Matching model, it’s time to train it.

We’ll run the training loop for several epochs, calculate the loss, and visualize how well the model learns to generate data.

Our training loop must follow these steps:
1. Sample $\mathbf{x}_1 \sim \pi(\mathbf{x})$
2. Sample time $t \sim \mathcal{U}[0, 1]$ and $\mathbf{x}_0 \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$
3. Obtain the noisy vector $\mathbf{x}_t = t \mathbf{x}_1 + (1 - t) \mathbf{x}_0$
4. Compute the loss $\mathcal{L} = \left\| (\mathbf{x}_1 - \mathbf{x}_0) - \mathbf{f}_{\boldsymbol{\theta}}(\mathbf{x}_t, t) \right\|_2^2$

In [None]:
epochs = 100
learning_rate = 1e-3

model = FlowMatchingModel().to(device)
optimizer = optim.AdamW(model.parameters(), lr=learning_rate)

model.train()

# Mean of the per-batch losses for every epoch, kept so convergence can be
# inspected or plotted after training.
epoch_losses = []

progress = tqdm(range(epochs), desc="Epoch")
for epoch in progress:
    # Accumulate on the device to avoid a host sync on every batch.
    running_loss = torch.zeros((), device=device)
    for x_1 in train_loader:
        x_1 = x_1.to(device)
        # Noise and time are created directly on the training device instead
        # of being sampled on the CPU and copied over each iteration.
        x_0 = torch.randn_like(x_1)
        t = torch.rand(x_1.shape[0], 1, device=device)
        # Point on the straight path from noise (t = 0) to data (t = 1).
        x_t = t * x_1 + (1 - t) * x_0
        # Target velocity of that path; it is constant in t.
        velocity = x_1 - x_0
        pred_velocity = model(x_t, t)
        loss = F.mse_loss(pred_velocity, velocity)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.detach()

    epoch_losses.append(running_loss.item() / len(train_loader))
    progress.set_postfix(loss=f"{epoch_losses[-1]:.4f}")

**Sample from trained model**

To sample from a pretrained Flow Matching model, we need to implement an $\texttt{ODESolver}$, which can be a simple Euler method:

$$
\mathbf{x}(t + h) = \mathbf{x}(t) + h \cdot \mathbf{f}_{\boldsymbol{\theta}}(\mathbf{x}(t), t)
$$

It must take a pretrained model, the number of points to sample, and the number of Euler steps.

In [None]:
@torch.no_grad()
def sample(model, num_samples, num_steps=100):
    """Generate points by Euler-integrating the model's velocity field from t=0 to t=1.

    Starts from standard-normal 2D noise and applies exactly ``num_steps``
    updates ``x <- x + dt * model(x, t)`` on a uniform time grid.

    Args:
        model: Module called as ``model(x, t)`` with ``x`` of shape
            ``(num_samples, 2)`` and ``t`` of shape ``(num_samples, 1)``.
            It is switched to eval mode.
        num_samples: Number of points to generate.
        num_steps: Number of Euler steps; 0 returns the initial noise unchanged.

    Returns:
        Tensor of shape ``(num_samples, 2)`` on the model's device.
    """
    model.eval()
    # Work on the device the model lives on; fall back to the notebook-level
    # `device` only for a model that has no parameters.
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else device
    x = torch.randn(num_samples, 2, device=run_device)
    # num_steps + 1 grid points give exactly num_steps intervals.
    ts = torch.linspace(0, 1, num_steps + 1, device=run_device)
    for t, dt in zip(ts[:-1], torch.diff(ts)):
        # Broadcast the scalar time to one entry per sample without copying.
        t_batch = t.expand(num_samples, 1)
        x = x + dt * model(x, t_batch)
    return x

In [None]:
# Draw 1000 points from the trained model and move them to the CPU for plotting.
samples = sample(model, num_samples=1000).cpu()
plt.scatter(samples[:, 0], samples[:, 1], s=50, alpha=0.5, color="orange")

# Same axes styling and limits as the data-distribution plot, for easy comparison.
learned_ax = plt.gca()
learned_ax.set_title("Learned Distribution")
learned_ax.grid(alpha=0.1)
learned_ax.set_aspect("equal")
learned_ax.set_xlim(-3.5, 3.5)
learned_ax.set_ylim(-3.5, 3.5)

plt.tight_layout()
plt.show()

**Plot transition of the points**

Below is code, similar to the sampler above, that plots the transition of the points from the base distribution to the learned one.

In [None]:
def interpolate_color(t, start="blue", end="orange"):
    """Blend linearly between two matplotlib colours.

    Args:
        t: Blend weight; 0 gives ``start`` and 1 gives ``end``. Anything that
            ``float()`` accepts works, including a single-element tensor.
        start: Colour spec understood by matplotlib for weight 0.
        end: Colour spec understood by matplotlib for weight 1.

    Returns:
        NumPy array with the three blended RGB components.
    """
    # Import from the public module instead of reaching through
    # `plt.cm.colors`, which is an incidental attribute of the colormap module
    # rather than a documented access path.
    from matplotlib.colors import to_rgb

    # Coerce to a plain float so the result is a NumPy array whatever type
    # the caller passes for the weight.
    weight = float(t)
    return (1 - weight) * np.array(to_rgb(start)) + weight * np.array(to_rgb(end))

def _draw_snapshot(ax, points, t_value):
    """Scatter one panel of the transition figure for the state at time ``t_value``."""
    points = points.cpu()
    color = interpolate_color(t_value)
    ax.scatter(points[:, 0], points[:, 1], s=50, alpha=0.5, c=[color])
    ax.set_title(f"t = {t_value:.2f}")
    ax.grid(alpha=0.1)
    ax.set_aspect("equal")
    ax.set_xlim(-3.5, 3.5)
    ax.set_ylim(-3.5, 3.5)


@torch.no_grad()
def plot_transition(model, num_samples, num_steps=100, num_plots=5):
    """Plot snapshots of the points as they are integrated from noise to the learned distribution.

    Runs the same Euler integration as sampling (``num_steps`` updates on a
    uniform grid over [0, 1]) and draws ``num_plots`` panels side by side at
    step counts spread evenly from 0 to ``num_steps``. Each panel is titled
    and coloured by the time it shows.

    Args:
        model: Module called as ``model(x, t)`` with ``x`` of shape
            ``(num_samples, 2)`` and ``t`` of shape ``(num_samples, 1)``.
            It is switched to eval mode.
        num_samples: Number of points to track.
        num_steps: Number of Euler steps.
        num_plots: Number of panels; with a single panel only the initial
            noise is shown.

    Raises:
        ValueError: If ``num_plots`` is smaller than 1.
    """
    if num_plots < 1:
        raise ValueError("num_plots must be at least 1")

    model.eval()
    # Work on the device the model lives on; fall back to the notebook-level
    # `device` only for a model that has no parameters.
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else device
    x = torch.randn(num_samples, 2, device=run_device)
    ts = torch.linspace(0, 1, num_steps + 1, device=run_device)

    # squeeze=False keeps `axes` indexable even when there is a single panel.
    _, axes = plt.subplots(1, num_plots, figsize=(4 * num_plots, 4), squeeze=False)
    axes = axes[0]

    # Number of completed steps at which each panel is drawn. Rounding an even
    # spread keeps this valid when num_steps is not a multiple of
    # num_plots - 1, and it may repeat a value when there are fewer steps than
    # panels.
    snapshot_steps = np.linspace(0, num_steps, num_plots).round().astype(int)

    panel = 0
    for steps_done in range(num_steps + 1):
        if steps_done > 0:
            t = ts[steps_done - 1]
            dt = ts[steps_done] - t
            x = x + dt * model(x, t.expand(num_samples, 1))
        # A while-loop so that repeated snapshot steps each fill their own panel.
        while panel < num_plots and snapshot_steps[panel] == steps_done:
            _draw_snapshot(axes[panel], x, float(ts[steps_done]))
            panel += 1

    plt.tight_layout()
    plt.show()

In [None]:
# Visualise how 1000 noise points move towards the learned distribution,
# using the default number of steps and panels.
plot_transition(model, num_samples=1000)