# Energy Matching 2D Tutorial
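This notebook trains a scalar potential network $V_\theta(x)$ on a 2D toy problem, using its negative gradient $-\nabla_x V_\theta$ as the velocity field. Training runs in two phases: first, optimal-transport conditional flow matching from an 8-Gaussians prior to the two-moons dataset; second, the same flow-matching loss combined with a contrastive energy-based (EBM) loss whose negative samples are drawn with noisy gradient (Langevin-type) dynamics. Finally, trajectories simulated from the prior are plotted.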

In [None]:
# Setup
import os
import time
import random
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torchcfm.conditional_flow_matching import ExactOptimalTransportConditionalFlowMatcher
from torchcfm.utils import sample_8gaussians, sample_moons

# Reproducibility: seed every random number generator this notebook draws from.
SEED = 42


def _seed_everything(seed):
    """Seed Python's `random`, NumPy and PyTorch (CPU and, if available, all CUDA devices).

    NOTE: assigning PYTHONHASHSEED here only influences child processes launched
    afterwards; the running interpreter's hash seed is fixed at start-up.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)
    # Prefer deterministic cuDNN kernels over autotuned (possibly nondeterministic) ones.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


_seed_everything(SEED)

# Hyperparameters (read as globals by train()).
FLOW_EPOCHS_PHASE1 = 20000   # optimizer steps with the flow-matching loss only
FLOW_EPOCHS_PHASE2 = 10000   # optimizer steps with flow-matching + contrastive EBM loss
BATCH_SIZE = 256
LEARNING_RATE = 1e-4         # Adam learning rate
FLOW_LOSS_WEIGHT = 1.0       # weight of the flow-matching term in phase 2
EBM_LOSS_WEIGHT = 1.0        # weight of the contrastive energy term in phase 2
SIGMA = 0.1                  # sigma passed to ExactOptimalTransportConditionalFlowMatcher
SAVE_DIR = 'EM_good'         # train() writes final_V_model.pth here


In [None]:
# Model and utilities
class PotentialModel(nn.Module):
    """MLP scalar potential V(x[, t]) producing one value per sample, shape ``(batch, 1)``.

    Args:
        dim: number of input features of ``x``.
        w: hidden width of every hidden layer.
        time_varying: if True, a time column is appended to ``x`` before the MLP.
    """

    def __init__(self, dim=2, w=128, time_varying=True):
        super().__init__()
        self.time_varying = time_varying
        # Layer order/indices are kept fixed so saved state_dicts ("net.0.weight", ...) still load.
        self.net = nn.Sequential(
            nn.Linear(dim + (1 if time_varying else 0), w),
            nn.ReLU(),
            nn.Linear(w, w),
            nn.SiLU(),
            nn.Linear(w, w),
            nn.SiLU(),
            nn.Linear(w, w),
            nn.SiLU(),
            nn.Linear(w, 1),
        )

    @staticmethod
    def _time_column(x, t):
        """Return ``t`` as a ``(batch, 1)`` column on ``x``'s device and dtype.

        Accepts a Python number, any single-element tensor (broadcast to the whole
        batch), or a per-sample tensor of shape ``(batch,)`` or ``(batch, 1)``.
        """
        t = torch.as_tensor(t, dtype=x.dtype, device=x.device)
        batch = x.size(0)
        if t.numel() == 1:
            return t.reshape(1, 1).expand(batch, 1)
        if t.shape in ((batch,), (batch, 1)):
            return t.reshape(batch, 1)
        raise ValueError(
            f'Cannot broadcast t of shape {tuple(t.shape)} to batch size {batch}.'
        )

    def forward(self, x, t=None):
        """Evaluate the potential.

        Args:
            x: input batch; its first dimension is the batch size.
            t: ignored when ``time_varying`` is False. Otherwise a Python number,
                a single-element tensor, or a per-sample tensor of shape
                ``(batch,)`` / ``(batch, 1)``.

        Returns:
            Tensor of shape ``(batch, 1)``.

        Raises:
            ValueError: if ``time_varying`` is True and ``t`` is None or its shape
                cannot be matched to the batch.
        """
        if not self.time_varying:
            return self.net(x)
        if t is None:
            raise ValueError('time_varying=True but t is None.')
        t_col = self._time_column(x, t)
        # clamp(max=0) maps every t >= 0 to 0.0, so for the t in [0, 1] used in this
        # notebook the network always sees a constant time input and V is effectively
        # independent of t. NOTE(review): presumably deliberate (a single static
        # potential); confirm before relying on time conditioning.
        t_clamped = torch.clamp(t_col, max=0.0)
        return self.net(torch.cat([x, t_clamped], dim=-1))

def temperature(t, *, t_switch=0.8, ramp_peak=0.2, eps_max=0.15):
    """Piecewise noise temperature eps(t) used for the noisy sampling steps.

        eps(t) = 0                                              for t < t_switch
        eps(t) = ramp_peak * (t - t_switch) / (1 - t_switch)    for t_switch <= t < 1
        eps(t) = eps_max                                        for t >= 1

    Args:
        t: time(s) as a Python number or tensor. A 2-D tensor with a single column
            (``(N, 1)``) is squeezed to ``(N,)``. Integer inputs are converted to the
            default floating dtype.
        t_switch: time at which the linear ramp starts.
        ramp_peak: value the ramp approaches as t -> 1 from below.
        eps_max: constant temperature for t >= 1.

    Returns:
        Floating tensor of eps values with the (possibly squeezed) shape of ``t``.
        NaN inputs map to 0.

    NOTE(review): with the defaults the ramp approaches 0.2 just below t = 1 while
    the plateau is 0.15, so eps(t) drops discontinuously at t = 1. The defaults are
    kept to preserve existing behaviour; confirm whether ramp_peak should equal eps_max.
    """
    t = torch.as_tensor(t)
    if not t.is_floating_point():
        # zeros_like on an integer tensor would truncate eps (0.15 -> 0).
        t = t.to(torch.get_default_dtype())
    if t.dim() == 2 and t.size(1) == 1:
        t = t.squeeze(-1)
    ramp = ramp_peak * (t - t_switch) / (1.0 - t_switch)
    eps = torch.where((t >= t_switch) & (t < 1.0), ramp, torch.zeros_like(t))
    return torch.where(t >= 1.0, torch.full_like(t, eps_max), eps)

def velocity_training(model, x, t):
    """Return the velocity -grad_x V(x, t) with a differentiable graph, for use in losses.

    ``x`` is detached and re-marked as requiring grad, so the result depends on the
    model parameters but not on any upstream graph of ``x``. ``create_graph=True``
    keeps the gradient itself differentiable, so a loss on the returned velocity
    can be back-propagated into the model parameters.
    """
    leaf = x.detach()
    leaf.requires_grad_(True)
    potential = model(leaf, t)
    (grad_x,) = torch.autograd.grad(potential.sum(), leaf, create_graph=True)
    return grad_x.neg()

def velocity_inference(model, x, t):
    """Return the velocity -grad_x V(x, t) for sampling, without a higher-order graph.

    Gradient tracking is switched on locally, so this also works inside
    ``torch.no_grad()``. A tensor that already requires grad is differentiated
    as-is; otherwise a detached copy marked as requiring grad is used.
    """
    with torch.enable_grad():
        inp = x if x.requires_grad else x.detach().requires_grad_(True)
        energy = model(inp, t)
        grad_x = torch.autograd.grad(energy.sum(), inp, create_graph=False)[0]
    return -grad_x

def gibbs_sampler(model, x_init, t_start, steps=10, dt=0.01, *, temperature_fn=None):
    """Draw samples with ``steps`` noisy gradient-descent (Langevin-type) updates on V.

    Each step applies
        x <- x - dt * grad_x V(x, 1) + sqrt(2 * eps(t) * dt) * N(0, I),
    where V is always evaluated at t = 1. The time t only drives the noise
    schedule, moving linearly from ``t_start`` towards 1 and reaching 1 on the last step.

    Args:
        model: potential, called as ``model(x, t)``; its output is summed before
            differentiation.
        x_init: starting points, first dimension = batch. The tensor itself is left
            untouched; a detached copy is updated.
        t_start: start time(s), either per-sample (e.g. shape ``(batch,)``) or a
            scalar/Python number. It is cast to ``x_init``'s dtype and device.
        steps: number of updates.
        dt: step size.
        temperature_fn: callable mapping t to eps(t); defaults to the module-level
            ``temperature``.

    Returns:
        Detached tensor with the shape of ``x_init``.
    """
    if temperature_fn is None:
        temperature_fn = temperature
    device = x_init.device
    t_start = torch.as_tensor(t_start, dtype=x_init.dtype, device=device)
    t_model = torch.tensor(1.0, device=device)
    x = x_init.detach()  # never flip requires_grad on the caller's tensor
    # autograd.grad needs graph construction even if the caller is under no_grad.
    with torch.enable_grad():
        for step in range(steps):
            t_current = t_start + (1 - t_start) * ((step + 1) / steps)
            x.requires_grad_(True)
            V = model(x, t_model)
            g = torch.autograd.grad(V.sum(), x, create_graph=False)[0]
            eps = temperature_fn(t_current)
            # Per-sample eps of shape (batch,) becomes (batch, 1) to broadcast over features.
            noise_scale = torch.sqrt(2.0 * eps * dt).unsqueeze(-1)
            noise = noise_scale * torch.randn_like(x)
            x = (x - g * dt + noise).detach()
    return x

def simulate_piecewise_length(model, x0, dt=0.01, max_length=4.0):
    """Integrate the learned dynamics from ``x0`` until the accumulated path length reaches ``max_length``.

    While t < 0.8 (the threshold at which ``temperature`` starts ramping up) each step
    is a deterministic Euler step ``x += v * dt`` with ``v = velocity_inference(model, x, t)``.
    From t >= 0.8 on, Gaussian noise with scale ``sqrt(2 * eps(t) * dt)`` is added.
    The accumulated "length" is the norm of the whole batch's displacement per step,
    summed over steps. It is not a per-particle length.

    Returns:
        ``(traj, times)`` as NumPy arrays: ``traj`` stacks ``x0`` followed by every
        subsequent state, and ``times`` holds the matching t values (starting at 0.0).
    """
    device = x0.device
    x = x0
    states = [x0.cpu().numpy()]
    times = [0.0]
    t = 0.0
    travelled = 0.0
    while travelled < max_length:
        t_vec = torch.tensor([t], dtype=x0.dtype, device=device)
        drift = velocity_inference(model, x, t_vec)
        eps_t = temperature(t_vec).item()
        step = drift * dt
        if not t < 0.8:
            # Noisy phase: one randn_like draw per step, only once t >= 0.8.
            step = step + torch.sqrt(torch.tensor(2.0 * eps_t * dt, device=device)) * torch.randn_like(x)
        x = (x + step).detach()
        travelled += torch.norm(step).item()
        t += dt
        states.append(x.cpu().numpy())
        times.append(t)
    return np.array(states), np.array(times)

def plot_trajectories_custom(traj, *, n=2000, n_paths=10):
    """Plot particle trajectories: start points, all visited states, end points and a few full paths.

    Args:
        traj: array indexed as ``traj[step, particle, coord]`` with at least two
            coordinates (e.g. the first output of ``simulate_piecewise_length`` for 2-D x0).
        n: maximum number of particles drawn in the scatter layers.
        n_paths: number of individual trajectories drawn as red lines. It is capped at
            the number of particles, so small batches no longer raise IndexError.
    """
    traj = np.asarray(traj)
    starts = traj[0, :n]
    ends = traj[-1, :n]
    plt.figure(figsize=(6, 6))
    plt.scatter(starts[:, 0], starts[:, 1], s=3, alpha=0.8, c='black', marker='s')
    plt.scatter(traj[:, :n, 0], traj[:, :n, 1], s=0.2, alpha=0.1, c='olive')
    plt.scatter(ends[:, 0], ends[:, 1], s=4, alpha=1.0, c='blue', marker='*')
    for i in range(min(n_paths, traj.shape[1])):
        plt.plot(traj[:, i, 0], traj[:, i, 1], c='red', linewidth=1.2, alpha=1.0)
    plt.xticks([])
    plt.yticks([])
    plt.show()


In [None]:
# Training function
def _flow_matching_loss(model, fm, x0, x1):
    """Mean squared error between the potential's velocity -grad V and the conditional-flow target u_t."""
    t, x_t, u_t = fm.sample_location_and_conditional_flow(x0, x1)
    # fm returns t with one entry per sample; the model takes a (batch, 1) time column.
    v_pred = velocity_training(model, x_t, t.unsqueeze(-1))
    return (v_pred - u_t).pow(2).mean()


def _contrastive_energy_loss(model, x0_dist, x1_dist, device):
    """Contrastive EBM loss mean V(data) - mean V(negatives), both evaluated at t = 1.

    Negatives come from ``gibbs_sampler`` (200 steps, dt = 0.01). Half of them start
    at fresh data samples with t_start = 1 and half at prior samples with t_start = 0.
    An odd BATCH_SIZE therefore yields BATCH_SIZE - 1 negatives.
    """
    t_one = torch.tensor(1.0, device=device)
    x_data = x1_dist(BATCH_SIZE)
    e_pos = model(x_data, t_one).mean()
    half_bs = BATCH_SIZE // 2
    # Draw order (data, then prior) matches the original RNG consumption.
    x_init_neg = torch.cat([x1_dist(half_bs), x0_dist(half_bs)], dim=0)
    t_start = torch.cat(
        [torch.ones(half_bs, device=device), torch.zeros(half_bs, device=device)], dim=0
    )
    x_neg = gibbs_sampler(model, x_init_neg, t_start, steps=200, dt=0.01)
    e_neg = model(x_neg, t_one).mean()
    return e_pos - e_neg


def train(*, log_every=None):
    """Train a PotentialModel in two phases and save its weights.

    Phase 1 runs FLOW_EPOCHS_PHASE1 Adam steps on the flow-matching loss only. The
    loss maps 8-Gaussians prior samples to two-moons samples with
    ExactOptimalTransportConditionalFlowMatcher(sigma=SIGMA).
    Phase 2 runs FLOW_EPOCHS_PHASE2 steps on
    ``FLOW_LOSS_WEIGHT * flow loss + EBM_LOSS_WEIGHT * contrastive energy loss``.
    The final weights are saved to ``SAVE_DIR/final_V_model.pth``.

    Args:
        log_every: if a positive int, print the current loss every ``log_every``
            steps of each phase. None (default) keeps training silent.

    Returns:
        The trained PotentialModel, on CUDA if available, else CPU.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    os.makedirs(SAVE_DIR, exist_ok=True)
    model = PotentialModel(dim=2, w=128, time_varying=True).to(device)
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
    FM = ExactOptimalTransportConditionalFlowMatcher(sigma=SIGMA)

    def x0_dist(n):
        return sample_8gaussians(n).to(device)

    def x1_dist(n):
        return sample_moons(n).to(device)

    for epoch in range(FLOW_EPOCHS_PHASE1):
        optimizer.zero_grad()
        loss_flow = _flow_matching_loss(model, FM, x0_dist(BATCH_SIZE), x1_dist(BATCH_SIZE))
        loss_flow.backward()
        optimizer.step()
        if log_every and (epoch + 1) % log_every == 0:
            print(f'phase 1 step {epoch + 1}/{FLOW_EPOCHS_PHASE1}: flow {loss_flow.item():.4f}')

    for epoch in range(FLOW_EPOCHS_PHASE2):
        optimizer.zero_grad()
        loss_flow = _flow_matching_loss(model, FM, x0_dist(BATCH_SIZE), x1_dist(BATCH_SIZE))
        loss_ebm = _contrastive_energy_loss(model, x0_dist, x1_dist, device)
        loss = FLOW_LOSS_WEIGHT * loss_flow + EBM_LOSS_WEIGHT * loss_ebm
        loss.backward()
        optimizer.step()
        if log_every and (epoch + 1) % log_every == 0:
            print(
                f'phase 2 step {epoch + 1}/{FLOW_EPOCHS_PHASE2}: '
                f'flow {loss_flow.item():.4f} ebm {loss_ebm.item():.4f}'
            )

    torch.save(model.state_dict(), os.path.join(SAVE_DIR, 'final_V_model.pth'))
    return model


In [None]:
# Run training and display results
model = train()
# Draw the simulation's starting points on whatever device training used.
sample_device = next(model.parameters()).device
x_init = sample_8gaussians(1024).to(sample_device)  # 1024 prior samples
# Simulate until the summed per-step batch displacement norms reach 400.
traj_np, times_np = simulate_piecewise_length(
    model,
    x_init,
    dt=0.01,
    max_length=400,
)
plot_trajectories_custom(traj_np)
