# Model Plotting (JAX Version)

Here we compare different interpolants side by side on the same dataset, using previously saved models.

**Note**: This is a JAX/Flax conversion. Models are loaded from pickle files containing JAX parameters.

In [None]:
import math
import os
import sys
import time
import pickle

# Add project root to Python path
project_root = os.path.abspath(os.path.join(os.getcwd(), '..', '..'))
if project_root not in sys.path:
    sys.path.insert(0, project_root)

import imageio
import matplotlib.pyplot as plt
import numpy as np
import jax
import jax.numpy as jnp
from jax import random
from diffrax import diffeqsolve, ODETerm, Euler, SaveAt

from jaxcfm.models.models import MLP
from jaxcfm.utils import sample_8gaussians


In [None]:
# Implement some helper functions

def sample_normal(key, n, dim=2):
    """Draw ``n`` points from a ``dim``-dimensional standard normal.

    Args:
        key: JAX PRNG key consumed by the sampler.
        n: Number of points to draw.
        dim: Dimensionality of each point.

    Returns:
        Array of shape ``[n, dim]`` sampled with zero mean and identity covariance.
    """
    mean = jnp.zeros(dim)
    cov = jnp.eye(dim)
    return random.multivariate_normal(key, mean, cov, shape=(n,))


def log_normal_density(x):
    """Compute the log density of a standard normal at ``x``.

    Args:
        x: Array whose last axis holds the coordinates of each point; any
            leading axes are treated as batch axes.

    Returns:
        Array with the last axis reduced away, containing
        ``-0.5 * dim * log(2 * pi) - 0.5 * ||x||^2`` for every point.
    """
    dim = x.shape[-1]
    # The covariance is the identity, so its log-determinant term is exactly
    # zero and is omitted instead of being recomputed on every call.
    log_norm = -0.5 * dim * jnp.log(2 * jnp.pi)
    quad_form = -0.5 * jnp.sum(x ** 2, axis=-1)
    return log_norm + quad_form


def log_8gaussian_density(x, scale=5, var=0.1):
    """Evaluate the (unnormalised) log density of the 8-Gaussian mixture.

    Args:
        x: Points of shape ``[batch, 2]``.
        scale: Multiplier applied to the eight unit-circle centers.
        var: Each offset from a center is divided by ``sqrt(var)`` before the
            standard-normal log density is evaluated.

    Returns:
        Array of shape ``[batch]``: the log-sum-exp over the eight centers.

    NOTE(review): neither the ``-log(8)`` mixture weight nor the
    ``-(dim / 2) * log(var)`` term from the rescaling is included, so the
    value is not a normalised log density. Confirm whether that is intended.
    """
    diag = 1.0 / math.sqrt(2)
    unit_centers = jnp.array([
        (1, 0),
        (-1, 0),
        (0, 1),
        (0, -1),
        (diag, diag),
        (diag, -diag),
        (-diag, diag),
        (-diag, -diag),
    ])
    centers = unit_centers * scale  # [8, dim]

    # offsets[b, k, :] = x[b] - centers[k]
    offsets = x[:, None, :] - centers[None, :, :]  # [batch, 8, dim]

    # log_normal_density reduces over the last axis, so all eight centers are
    # scored in one call.
    per_center = log_normal_density(offsets / math.sqrt(var))  # [batch, 8]
    return jax.scipy.special.logsumexp(per_center, axis=-1)


def jacobian_trace(f, x):
    """Return the trace of the Jacobian of ``f`` evaluated at ``x``.

    Args:
        f: Function whose forward-mode Jacobian is taken with ``jax.jacfwd``.
        x: Point at which the Jacobian is evaluated.

    Returns:
        ``jnp.trace`` of the Jacobian ``jax.jacfwd(f)(x)``.
    """
    return jnp.trace(jax.jacfwd(f)(x))


def cnf_vector_field(params, model, t, y):
    """Augmented CNF vector field: velocity plus negative Jacobian trace.

    Args:
        params: Parameters passed to ``model.apply``.
        model: Object exposing ``apply(params, inputs)``; it receives the state
            concatenated with a time column and returns the velocity.
        t: Scalar time, broadcast to every sample.
        y: Augmented state of shape ``[batch, dim + 1]``. Column 0 is the slot
            for the accumulated divergence and is not fed to the model;
            columns ``1:`` are the state ``x``.

    Returns:
        Array of shape ``[batch, dim + 1]``: column 0 is ``-trace(d v / d x)``
        per sample, the remaining columns are the velocity ``v(x, t)``.
    """
    x_in = y[:, 1:]  # [batch, dim]

    def velocity(x):
        # Append the time as an extra input column for every row of x.
        t_col = jnp.full((x.shape[0], 1), t)
        return model.apply(params, jnp.concatenate([x, t_col], axis=-1))

    x_out = velocity(x_in)  # [batch, dim]

    def single_trace(x_single):
        # Jacobian of the velocity for one sample, evaluated as a batch of one.
        jac = jax.jacfwd(lambda x: velocity(x[None, :])[0])(x_single)
        return jnp.trace(jac)

    # vmap replaces the per-sample Python loop and its repeated .at[i].set updates.
    trJ = jax.vmap(single_trace)(x_in)  # [batch]

    return jnp.concatenate([-trJ[:, None], x_out], axis=1)

In [None]:
savedir = "models/8gaussian-moons"
# Each checkpoint is a pickle of JAX parameters. All of them are applied through
# the single MLP below, so its architecture must match the saved parameters.
dim = 2
model_arch = MLP(dim=dim + 1, out_dim=dim, w=64, time_varying=False)

models, model_params = {}, {}
key = random.PRNGKey(42)

# Display name -> checkpoint path. Checkpoints that are missing on disk are
# skipped with a warning; nothing is initialized in their place.
model_files = {
    "CFM": f"{savedir}/cfm_v1.pkl",
    "OT-CFM (ours)": f"{savedir}/otcfm_v1.pkl",
    "SB-CFM (ours)": f"{savedir}/sbcfm_v1.pkl",
    "VP-CFM": f"{savedir}/stochastic_interpolant_v1.pkl",
    "Action-Matching": f"{savedir}/action_matching_v1.pkl",
    "Action-Matching (Swish)": f"{savedir}/action_matching_swish_v1.pkl",
}

for name, filepath in model_files.items():
    if not os.path.exists(filepath):
        print(f"Warning: {filepath} not found. Skipping {name}.")
        continue
    # NOTE: pickle.load can execute arbitrary code; only load trusted checkpoints.
    with open(filepath, "rb") as f:
        model_params[name] = pickle.load(f)
    models[name] = model_arch
    print(f"Loaded {name} from {filepath}")

In [None]:
# Plot window is [-w, w] x [-w, w].
w = 7

# Dense grid (100 x 100) used for the density panel.
points_real = 100
points = 100j
Y, X = np.mgrid[-w:w:points, -w:w:points]
gridpoints = jnp.asarray(np.column_stack([X.ravel(), Y.ravel()]), dtype=jnp.float32)

# Coarse grid (20 x 20) used for the quiver panel.
points_real_small = 20
points_small = 20j
Y_small, X_small = np.mgrid[-w:w:points_small, -w:w:points_small]
gridpoints_small = jnp.asarray(
    np.column_stack([X_small.ravel(), Y_small.ravel()]), dtype=jnp.float32
)

# Generate trajectories: fixed seed, 1024 source samples, 101 save times on [0, 1].
key, subkey = random.split(random.PRNGKey(42))
sample = sample_8gaussians(subkey, 1024)
ts = jnp.linspace(0, 1, 101)
trajs = {}

# Define vector field for ODE
def vector_field(t, y, args):
    """Evaluate the learned velocity at time ``t`` for a batch of states.

    Args:
        t: Scalar time, broadcast to every row of ``y``.
        y: States of shape ``[batch, dim]``.
        args: ``(params, model)`` pair; ``model.apply(params, inputs)`` is
            called on ``y`` with the time appended as a final column.

    Returns:
        Whatever ``model.apply`` returns for the ``[batch, dim + 1]`` input.
    """
    params, model = args
    t_col = jnp.full((y.shape[0], 1), t)
    return model.apply(params, jnp.concatenate([y, t_col], axis=-1))

for name, model in models.items():
    if name not in model_params:
        continue
    params = model_params[name]

    term = ODETerm(vector_field)
    solver = Euler()
    saveat = SaveAt(ts=ts)

    # Integrate from t=0 to t=1 in chunks of 256 samples, keeping the states
    # saved at every entry of `ts`.
    batch_size = 256
    traj_list = [
        diffeqsolve(
            term,
            solver,
            t0=0.0,
            t1=1.0,
            dt0=0.01,
            y0=sample[start:start + batch_size],
            saveat=saveat,
            args=(params, model),
        ).ys
        for start in range(0, sample.shape[0], batch_size)
    ]
    # Chunks are stitched back together along the batch axis: [time, batch, dim]
    trajs[name] = jnp.concatenate(traj_list, axis=1)

# Visualization loop: one PNG per time step with three rows per model
# (density, vector field, trajectories) and one column per loaded model.
names = [n for n in ["CFM", "Action-Matching", "Action-Matching (Swish)",
                     "VP-CFM", "SB-CFM (ours)", "OT-CFM (ours)"] if n in models]
if not names:
    # A figure with zero columns cannot be built; fail with a clear message.
    raise ValueError("No models were loaded; nothing to visualize.")

os.makedirs("figures/trajectory/v3/", exist_ok=True)

# The density panel is a simplification: it always shows the 8-Gaussian source
# density (a time-dependent density would require solving the CNF backward).
# It is independent of t and of the model, so it is evaluated once.
log_probs = log_8gaussian_density(gridpoints)
log_probs_np = np.array(log_probs).reshape(Y.shape)
density = np.exp(log_probs_np)

# Convert each trajectory to NumPy once instead of once per frame.
trajs_np = {name: np.array(trajs[name]) for name in names if name in trajs}

for i, t in enumerate(ts):
    # squeeze=False keeps `axes` two-dimensional even when a single model is
    # loaded, so `axes.T` always yields one column of three axes per model.
    fig, axes = plt.subplots(
        3, len(names), squeeze=False, figsize=(6 * len(names), 6 * 3)
    )
    for axis, name in zip(axes.T, names):
        if name not in model_params:
            continue
        params = model_params[name]
        model = models[name]

        # Density plot
        ax = axis[0]
        ax.pcolormesh(X, Y, density, vmax=1)
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_xlim(-w, w)
        ax.set_ylim(-w, w)
        ax.set_title(f"{name}", fontsize=30)

        # Quiver plot of the learned velocity on the coarse grid at time t
        t_batch = jnp.full((gridpoints_small.shape[0],), t)
        model_input = jnp.concatenate([gridpoints_small, t_batch[:, None]], axis=-1)
        out = model.apply(params, model_input)
        out_np = np.array(out).reshape([points_real_small, points_real_small, 2])

        ax = axis[1]
        ax.quiver(
            X_small,
            Y_small,
            out_np[:, :, 0],
            out_np[:, :, 1],
            np.sqrt(np.sum(out_np**2, axis=-1)),
            cmap="coolwarm",
            scale=50.0,
            width=0.015,
            pivot="mid",
        )
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_xlim(-w, w)
        ax.set_ylim(-w, w)  # same window as the other two panels

        # Trajectory plot: start points (black), path so far (olive), current (blue)
        ax = axis[2]
        sample_traj = trajs_np[name]
        ax.scatter(sample_traj[0, :, 0], sample_traj[0, :, 1], s=10, alpha=0.8, c="black")
        ax.scatter(sample_traj[:i, :, 0], sample_traj[:i, :, 1], s=0.2, alpha=0.2, c="olive")
        ax.scatter(sample_traj[i, :, 0], sample_traj[i, :, 1], s=4, alpha=1, c="blue")
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_xlim(-w, w)
        ax.set_ylim(-w, w)
    fig.suptitle(f"8gaussians to Moons T={float(t):0.2f}", fontsize=40)
    fig.savefig(f"figures/trajectory/v3/{float(t):0.2f}.png", dpi=40)
    plt.close(fig)

In [None]:
gif_name = "8gaussians-to-moons"
import imageio.v2 as imageio

# One frame per saved time step, then the final frame repeated 10 more times
# so the animation pauses on the end state.
frame_files = [f"figures/trajectory/v3/{float(t):0.2f}.png" for t in ts]
frame_files += [f"figures/trajectory/v3/{float(ts[-1]):0.2f}.png"] * 10

with imageio.get_writer(f"{gif_name}.gif", mode="I") as writer:
    for filename in frame_files:
        image = imageio.imread(filename)
        writer.append_data(image)

  image = imageio.imread(filename)


In [None]:
# No-op placeholder kept so the cell layout matches the original notebook.
# Nothing is reloaded here: the cells below reuse the `models` and
# `model_params` dictionaries populated by the model-loading cell above.
pass

In [None]:
# Gaussian to Moons visualization (same layout as above, but the source samples
# come from a standard normal distribution).
w = 7  # plot window is [-w, w] x [-w, w]

# Dense 100 x 100 grid for the density panel.
points_real = 100
points = 100j
Y, X = np.mgrid[-w:w:points, -w:w:points]
gridpoints = jnp.asarray(np.column_stack([X.ravel(), Y.ravel()]), dtype=jnp.float32)

# Coarse 20 x 20 grid for the quiver panel.
points_real_small = 20
points_small = 20j
Y_small, X_small = np.mgrid[-w:w:points_small, -w:w:points_small]
gridpoints_small = jnp.asarray(
    np.column_stack([X_small.ravel(), Y_small.ravel()]), dtype=jnp.float32
)

# Generate trajectories from the normal distribution: fixed seed, 1024 source
# samples, 101 save times on [0, 1].
key, subkey = random.split(random.PRNGKey(42))
sample = sample_normal(subkey, 1024)
ts = jnp.linspace(0, 1, 101)
trajs = {}

def vector_field(t, y, args):
    """Velocity of the learned flow at time ``t`` for a batch of states.

    Args:
        t: Scalar time shared by every row of ``y``.
        y: States of shape ``[batch, dim]``.
        args: ``(params, model)`` pair used as ``model.apply(params, inputs)``.

    Returns:
        The output of ``model.apply`` on ``y`` with ``t`` appended as the last
        column.
    """
    params, model = args
    time_column = jnp.full((y.shape[0], 1), t)
    inputs = jnp.concatenate([y, time_column], axis=-1)
    return model.apply(params, inputs)

for name, model in models.items():
    if name not in model_params:
        continue
    params = model_params[name]

    term = ODETerm(vector_field)
    solver = Euler()
    saveat = SaveAt(ts=ts)

    # Integrate from t=0 to t=1 in chunks of 256 samples, keeping the states
    # saved at every entry of `ts`.
    batch_size = 256
    traj_list = [
        diffeqsolve(
            term,
            solver,
            t0=0.0,
            t1=1.0,
            dt0=0.01,
            y0=sample[start:start + batch_size],
            saveat=saveat,
            args=(params, model),
        ).ys
        for start in range(0, sample.shape[0], batch_size)
    ]
    # Chunks are joined along axis 1 (the batch axis of the saved states).
    trajs[name] = jnp.concatenate(traj_list, axis=1)

# One PNG per time step with three rows per model (density, vector field,
# trajectories) and one column per loaded model.
names = [n for n in ["CFM", "Action-Matching", "Action-Matching (Swish)",
                     "VP-CFM", "SB-CFM (ours)", "OT-CFM (ours)"] if n in models]
if not names:
    # A figure with zero columns cannot be built; fail with a clear message.
    raise ValueError("No models were loaded; nothing to visualize.")

os.makedirs("figures/trajectory2/v3/", exist_ok=True)

# The density panel always shows the standard-normal source density; it does
# not depend on t or on the model, so it is evaluated once.
log_probs = log_normal_density(gridpoints)
log_probs_np = np.array(log_probs).reshape(Y.shape)
density = np.exp(log_probs_np)

# Convert each trajectory to NumPy once instead of once per frame.
trajs_np = {name: np.array(trajs[name]) for name in names if name in trajs}

for i, t in enumerate(ts):
    # Three rows of 6-inch panels: the height is fixed at 6 * 3 and must not
    # scale with the number of models. squeeze=False keeps `axes` 2-D even
    # when only one model is loaded.
    fig, axes = plt.subplots(
        3, len(names), squeeze=False, figsize=(6 * len(names), 6 * 3)
    )
    for axis, name in zip(axes.T, names):
        if name not in model_params:
            continue
        params = model_params[name]
        model = models[name]

        # Density plot
        ax = axis[0]
        ax.pcolormesh(X, Y, density)
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_xlim(-w, w)
        ax.set_ylim(-w, w)
        ax.set_title(f"{name}", fontsize=30)

        # Quiver plot of the learned velocity on the coarse grid at time t
        t_batch = jnp.full((gridpoints_small.shape[0],), t)
        model_input = jnp.concatenate([gridpoints_small, t_batch[:, None]], axis=-1)
        out = model.apply(params, model_input)
        out_np = np.array(out).reshape([points_real_small, points_real_small, 2])

        ax = axis[1]
        ax.quiver(
            X_small, Y_small, out_np[:, :, 0], out_np[:, :, 1],
            np.sqrt(np.sum(out_np**2, axis=-1)), cmap="coolwarm",
            scale=50.0, width=0.015, pivot="mid",
        )
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_xlim(-w, w)
        ax.set_ylim(-w, w)  # same window as the other two panels

        # Trajectory plot: start points (black), path so far (olive), current (blue)
        ax = axis[2]
        sample_traj = trajs_np[name]
        ax.scatter(sample_traj[0, :, 0], sample_traj[0, :, 1], s=10, alpha=0.8, c="black")
        ax.scatter(sample_traj[:i, :, 0], sample_traj[:i, :, 1], s=0.2, alpha=0.2, c="olive")
        ax.scatter(sample_traj[i, :, 0], sample_traj[i, :, 1], s=4, alpha=1, c="blue")
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_xlim(-w, w)
        ax.set_ylim(-w, w)
    fig.suptitle(f"Gaussian to Moons T={float(t):0.2f}", fontsize=40)
    fig.savefig(f"figures/trajectory2/v3/{float(t):0.2f}.png", dpi=40)
    plt.close(fig)

In [None]:
gif_name = "gaussians-to-moons"
import imageio.v2 as imageio
ts = jnp.linspace(0, 1, 101)

# One frame per saved time step, then the final frame repeated 10 more times
# so the animation pauses on the end state.
frame_files = [f"figures/trajectory2/v3/{float(t):0.2f}.png" for t in ts]
frame_files += [f"figures/trajectory2/v3/{float(ts[-1]):0.2f}.png"] * 10

with imageio.get_writer(f"{gif_name}.gif", mode="I") as writer:
    for filename in frame_files:
        image = imageio.imread(filename)
        writer.append_data(image)

  image = imageio.imread(filename)
