# Schrödinger Bridge Flow Matching Tutorial (JAX Version)

This notebook demonstrates simulation-free score and flow matching (SF2M, also written [SF]²M) for learning a Schrödinger bridge, using JAX/Flax.

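SF2M combines flow matching with score matching: one network learns the deterministic flow (drift) and a second network learns the score. Together with a fixed diffusion coefficient `sigma`, they define both the probability-flow ODE and the stochastic dynamics of a Schrödinger bridge.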


In [None]:
%load_ext autoreload
%autoreload 2
import os
import sys

# Add project root to Python path
project_root = os.path.abspath(os.path.join(os.getcwd(), '..', '..'))
if project_root not in sys.path:
    sys.path.insert(0, project_root)

import matplotlib.pyplot as plt
import jax
import jax.numpy as jnp
from jax import random
import optax
from diffrax import diffeqsolve, ODETerm, Euler, SaveAt
from tqdm import tqdm

from jaxcfm.conditional_flow_matching import SchrodingerBridgeConditionalFlowMatcher
from jaxcfm.models.models import MLP
from jaxcfm.utils import sample_8gaussians, sample_moons

# Output directory (relative to the current working directory) for the
# parameter files written by the save cell; created up front if missing.
savedir = "models/2d"
os.makedirs(savedir, exist_ok=True)


In [None]:
def plot_trajectories_sb(traj, legend=True):
    """Scatter-plot a batch of 2-D trajectories and overlay a few sample paths.

    Args:
        traj: Array indexed as ``traj[step, sample, coord]``; only
            coordinates 0 and 1 are drawn.
        legend: When true, attach the p_0 / p_t / p_1 / path legend.
    """
    # Cap the number of scattered samples at 2000.
    shown = min(2000, traj.shape[1])
    xs = traj[:, :shown, 0]
    ys = traj[:, :shown, 1]

    plt.figure(figsize=(10, 10))
    # Draw order: first step, every step, then last step on top.
    plt.scatter(xs[0], ys[0], s=10, alpha=0.8, c="black")
    plt.scatter(xs, ys, s=0.4, alpha=0.1, c="olive")
    plt.scatter(xs[-1], ys[-1], s=4, alpha=1, c="blue")

    # Trace samples 5..14 (fewer when the batch is smaller) as individual paths.
    path_stop = min(15, traj.shape[1])
    for sample in range(5, path_stop):
        plt.plot(traj[:, sample, 0], traj[:, sample, 1], alpha=0.9, c="red")

    if legend:
        plt.legend([r"$p_0$", r"$p_t$", r"$p_1$", r"$X_t \mid X_0$"])

    # Hide ticks and the axes frame entirely.
    plt.xticks([])
    plt.yticks([])
    plt.axis("off")
    plt.show()


In [None]:
# Hyperparameters shared by the training and sampling cells.
batch_size = 256
# Passed to the flow matcher below and reused as the diffusion coefficient
# in the Euler-Maruyama sampler.
sigma = 1.0
dim = 2
learning_rate = 0.01

# Initialize models
key = random.PRNGKey(42)
# Both networks receive the state concatenated with a time column, hence
# dim + 1 inputs.
# NOTE(review): time_varying=False presumably means the MLP does not append
# its own time channel — confirm against jaxcfm.models.models.MLP.
model = MLP(dim=dim + 1, out_dim=dim, w=64, time_varying=False)  # +1 for time
score_model = MLP(dim=dim + 1, out_dim=dim, w=64, time_varying=False)

# Initialize model parameters
# Separate subkeys give the two networks independent initial parameters.
key, subkey1, subkey2 = random.split(key, 3)
dummy_input = jnp.ones((batch_size, dim + 1))
model_params = model.init(subkey1, dummy_input)
score_params = score_model.init(subkey2, dummy_input)

# Initialize optimizer
# A single Adam state covers both parameter trees, keyed "model" and "score".
optimizer = optax.adam(learning_rate)
opt_state = optimizer.init({"model": model_params, "score": score_params})

# Initialize flow matcher
FM = SchrodingerBridgeConditionalFlowMatcher(sigma=sigma)


In [None]:
# Training step function
@jax.jit
def train_step(params_dict, opt_state, x0, x1, key):
    """Run one jitted optimisation step on the flow and score networks.

    Args:
        params_dict: Mapping holding the "model" and "score" parameter trees.
        opt_state: Optimizer state matching ``params_dict``.
        x0: Batch passed to the flow matcher as the first endpoint.
        x1: Batch passed to the flow matcher as the second endpoint.
        key: PRNG key; it is split here and the first piece is handed back.

    Returns:
        Tuple ``(params_dict, opt_state, loss, key)``: updated parameters,
        updated optimizer state, the summed flow + score loss, and the PRNG
        key left over after splitting.
    """
    # Three-way split kept as-is; only the second piece is consumed.
    key, sample_key, _ = random.split(key, 3)

    # Draw times, interpolated points, target velocities and the noise used.
    t, xt, ut, eps = FM.sample_location_and_conditional_flow(
        sample_key, x0, x1, return_noise=True
    )
    lambda_t = FM.compute_lambda(t)

    # Both networks see the point with its time appended as the last feature.
    net_input = jnp.concatenate([xt, t[:, None]], axis=-1)
    score_weight = lambda_t[:, None]

    def loss_fn(params_dict):
        flow_residual = model.apply(params_dict["model"], net_input) - ut
        score_out = score_model.apply(params_dict["score"], net_input)
        score_residual = score_weight * score_out + eps
        return jnp.mean(flow_residual ** 2) + jnp.mean(score_residual ** 2)

    loss, grads = jax.value_and_grad(loss_fn)(params_dict)
    updates, opt_state = optimizer.update(grads, opt_state, params_dict)
    new_params = optax.apply_updates(params_dict, updates)

    return new_params, opt_state, loss, key

# Training loop
print("Training SF2M model...")
print("Note: First iteration will be slow due to JIT compilation...")
# Bundle both parameter trees so one optimizer update covers the two networks.
params_dict = {"model": model_params, "score": score_params}

for i in tqdm(range(10000)):
    # subkey1 seeds the x0 batch; subkey2 is handed to train_step.
    key, subkey1, subkey2 = random.split(key, 3)
    x0 = sample_8gaussians(subkey1, batch_size)
    # NOTE(review): sample_moons is called without a PRNG key, unlike
    # sample_8gaussians — confirm how it is seeded in jaxcfm.utils.
    x1 = sample_moons(batch_size)
    # The key returned by train_step (derived from subkey2) replaces `key`
    # for the next iteration.
    params_dict, opt_state, loss, key = train_step(params_dict, opt_state, x0, x1, subkey2)
    
    # Report the most recent batch loss every 1000 steps.
    if (i + 1) % 1000 == 0:
        print(f"Step {i+1}, loss: {loss:.4f}")


In [None]:
# Save models (using pickle for JAX parameters)
# Pickle files can execute code when loaded, so only load checkpoints that
# come from a trusted source.
import pickle

# (destination file, key in params_dict) for each network's parameters.
checkpoints = (
    (f"{savedir}/sf2m_v1_model.pkl", "model"),
    (f"{savedir}/sf2m_v1_score.pkl", "score"),
)
for ckpt_path, params_key in checkpoints:
    with open(ckpt_path, "wb") as f:
        pickle.dump(params_dict[params_key], f)

print("Models saved!")


In [None]:
# Generate trajectories using ODE solver
key, subkey = random.split(key)
x0 = sample_8gaussians(subkey, 1024)


# Define ODE vector field (drift only)
def vector_field_ode(t, y, args):
    """Evaluate the flow network at time ``t`` for every row of ``y``.

    ``args`` is accepted to match the solver's calling convention and is
    not used.
    """
    # Broadcast the scalar time into one trailing column per sample.
    time_column = jnp.full((y.shape[0], 1), t)
    net_input = jnp.concatenate([y, time_column], axis=-1)
    return model.apply(params_dict["model"], net_input)


term = ODETerm(vector_field_ode)
solver = Euler()
saveat = SaveAt(ts=jnp.linspace(0, 1, 100))

# Solve the ODE chunk by chunk, then join the chunks along the sample axis.
batch_size_ode = 256
traj_list = [
    diffeqsolve(
        term,
        solver,
        t0=0.0,
        t1=1.0,
        dt0=0.01,
        y0=x0[start:start + batch_size_ode],
        saveat=saveat,
    ).ys
    for start in range(0, x0.shape[0], batch_size_ode)
]

traj = jnp.concatenate(traj_list, axis=1)  # Shape: (100, 1024, 2)
print(f"ODE trajectory shape: {traj.shape}")


In [None]:
# Generate trajectories using SDE solver (Euler-Maruyama)
# Note: diffrax does ship SDE support (e.g. ControlTerm / MultiTerm driven by a
# Brownian path); a hand-written Euler-Maruyama loop is used here only to keep
# the integration scheme explicit.
def euler_maruyama_sde(key, x0, num_steps=100, dt=0.01):
    """Integrate dx = f(t, x) dt + sigma dW with the Euler-Maruyama scheme.

    The drift f is the sum of the flow network and the score network outputs,
    both evaluated on the state with the current time appended as the last
    feature. The diffusion coefficient is the module-level ``sigma``.

    Integration starts at t = 0 and takes ``num_steps`` steps of size ``dt``,
    so it ends at t = num_steps * dt. The time fed to the networks at step i
    is i * dt, which keeps it consistent with the step size actually applied
    to the state (the defaults cover the interval [0, 1]).

    Args:
        key: PRNG key; split once per step to draw the Brownian increment.
        x0: Initial states, one sample per row.
        num_steps: Number of Euler-Maruyama steps to take.
        dt: Step size.

    Returns:
        Array stacking the initial state and the state after every step
        along a new leading axis, i.e. ``num_steps + 1`` entries.
    """
    # Standard deviation of one Brownian increment; constant across steps.
    sqrt_dt = jnp.sqrt(dt)
    traj = [x0]
    x = x0

    for i in range(num_steps):
        # Time at the start of this step, derived from dt rather than from a
        # fixed [0, 1] grid so that it matches the update below.
        t = i * dt
        key, subkey = random.split(key)

        # Drift: f(t, x) = model(x, t) + score(x, t)
        time_column = jnp.full((x.shape[0], 1), t)
        model_input = jnp.concatenate([x, time_column], axis=-1)
        drift = (
            model.apply(params_dict["model"], model_input)
            + score_model.apply(params_dict["score"], model_input)
        )

        # Diffusion: g(t, x) = sigma, with dW ~ N(0, dt)
        dW = random.normal(subkey, x.shape) * sqrt_dt
        diffusion = sigma * dW

        # Euler-Maruyama step
        x = x + drift * dt + diffusion
        traj.append(x)

    return jnp.stack(traj, axis=0)  # Shape: (num_steps+1, batch_size, dim)

# Draw a fresh subkey for the stochastic rollout; x0 is the same batch of
# initial points used for the ODE trajectories.
key, subkey = random.split(key)
sde_traj = euler_maruyama_sde(subkey, x0, num_steps=100, dt=0.01)
print(f"SDE trajectory shape: {sde_traj.shape}")


In [None]:
# Plot trajectories
import numpy as np

# Convert each JAX result to a NumPy array before plotting:
# the ODE rollout first, then the SDE rollout.
for rollout in (traj, sde_traj):
    plot_trajectories_sb(np.array(rollout), legend=False)
