# Hello World

In [None]:
# Smoke-test cell: prints a fixed greeting.
print('Hello World')

## Energy Matching Toy Example

In [None]:

import math
import os
import time

import matplotlib.pyplot as plt
import numpy as np
import ot as pot
import torch
import torchdyn
from torchdyn.core import NeuralODE
from torchdyn.datasets import generate_moons

from torchcfm.conditional_flow_matching import ExactOptimalTransportConditionalFlowMatcher
from torchcfm.models.models import MLP, GradModel
from torchcfm.utils import *

# Directory that receives the files written further down in this file.
# It is created up front; exist_ok=True makes re-running this cell harmless.
savedir = 'energy_matching'
os.makedirs(savedir, exist_ok=True)

def plot_trajectories_custom(traj, filename=None):
    """Scatter-plot the first two coordinates of a batch of trajectories.

    Args:
        traj: Array indexed as ``[step, sample, coordinate]``. Only the first
            2000 samples and coordinates 0 and 1 are drawn.
        filename: If truthy, the figure is saved to this path and closed;
            otherwise it is shown interactively.
    """
    max_samples = 2000
    shown = traj[:, :max_samples]
    first_step, last_step = shown[0], shown[-1]

    plt.figure(figsize=(6, 6))
    # Draw order matters: the legend labels below are matched to the scatter
    # calls in the order they are issued.
    plt.scatter(first_step[:, 0], first_step[:, 1], s=10, alpha=0.8, c='black')
    plt.scatter(shown[:, :, 0], shown[:, :, 1], s=0.2, alpha=0.2, c='olive')
    plt.scatter(last_step[:, 0], last_step[:, 1], s=4, alpha=1, c='blue')
    plt.legend(['Prior sample z(S)', 'Flow', 'z(0)'])
    plt.xticks([])
    plt.yticks([])

    if not filename:
        plt.show()
        return
    plt.savefig(filename)
    plt.close()

class MLP2(torch.nn.Module):
    """Fully connected network with three hidden layers of width ``w``.

    The first hidden layer is followed by ReLU, the next two by SiLU, and the
    output layer is linear.

    Args:
        dim: Number of input features (before the optional extra column).
        out_dim: Number of output features; defaults to ``dim`` when ``None``.
        w: Width of every hidden layer.
        time_varying: When true, the first layer expects ``dim + 1`` input
            features instead of ``dim``.
    """

    def __init__(self, dim, out_dim=None, w=64, time_varying=False):
        super().__init__()
        self.time_varying = time_varying
        in_features = dim + 1 if time_varying else dim
        out_features = dim if out_dim is None else out_dim
        # Layers are created in the same order as they are applied, so the
        # Sequential indices (and therefore state_dict keys) are 0..6.
        layers = [
            torch.nn.Linear(in_features, w),
            torch.nn.ReLU(),
            torch.nn.Linear(w, w),
            torch.nn.SiLU(),
            torch.nn.Linear(w, w),
            torch.nn.SiLU(),
            torch.nn.Linear(w, out_features),
        ]
        self.net = torch.nn.Sequential(*layers)

    def forward(self, x):
        """Return the network output for input ``x``."""
        return self.net(x)

class StaticWrapper(torch.nn.Module):
    """Expose a state-only model through a ``forward(t, x)`` signature.

    The time argument ``t`` (and any extra positional/keyword arguments) is
    accepted but ignored; only ``x`` is forwarded to the wrapped model.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, t, x, *args, **kwargs):
        """Call the wrapped model on ``x``, making sure ``x`` tracks gradients.

        If ``x`` does not already require grad, a clone that does is passed
        on instead, so the caller's tensor is left untouched.
        """
        state = x if x.requires_grad else x.clone().requires_grad_()
        return self.model(state)

# --- Hyperparameters ---------------------------------------------------------
sigma = 0.1        # passed to the flow matcher as its `sigma`
dim = 2            # input dimension of the potential network
batch_size = 256   # samples drawn from each distribution per update
n_steps = 20000    # number of optimizer updates
plot_every = 100   # a trajectory figure is saved every `plot_every` updates


class NegativePotentialGradient(torch.nn.Module):
    """Vector field ``v(x) = -dV/dx`` induced by a scalar potential network ``V``.

    The same module is used for the training loss and for sampling, so the
    ODE solver integrates exactly the field that was fitted.

    Args:
        potential: Module mapping a batch of states to one scalar per sample.
    """

    def __init__(self, potential):
        super().__init__()
        self.potential = potential

    def forward(self, x):
        """Return ``-dV/dx`` evaluated at ``x`` (same shape as ``x``)."""
        # Only keep the second-order graph when the caller is tracking
        # gradients (training); during sampling it would just waste memory.
        build_graph = torch.is_grad_enabled()
        # Autograd must be on to differentiate V, even when the caller runs
        # under torch.no_grad() as the sampling code below does.
        with torch.enable_grad():
            if not x.requires_grad:
                x = x.detach().requires_grad_(True)
            energy = self.potential(x).sum()
            grad = torch.autograd.grad(energy, x, create_graph=build_graph)[0]
        return -grad


potential_net = MLP2(dim=dim, out_dim=1, time_varying=False, w=256)
energy_field = NegativePotentialGradient(potential_net)
model = StaticWrapper(energy_field)
optimizer = torch.optim.Adam(potential_net.parameters())
FM = ExactOptimalTransportConditionalFlowMatcher(sigma=sigma)

# Built once: the solver wrapper and the time grid do not change between plots,
# and `model` shares its parameters with `potential_net`, so later evaluations
# always see the current weights.
# NOTE(review): `model` is handed to NeuralODE directly because StaticWrapper
# already exposes forward(t, x). As far as I can tell torchcfm's torch_wrapper
# calls its model with a single time-augmented tensor, which would not match
# that signature or the `dim`-input potential -- confirm against the installed
# torchcfm version.
node = NeuralODE(model, solver='dopri5', sensitivity='adjoint', atol=1e-4, rtol=1e-4)
t_span = torch.linspace(0, 1, 100)

for k in range(n_steps):
    optimizer.zero_grad()
    x0 = sample_8gaussians(batch_size)
    x1 = sample_moons(batch_size)
    t, xt, ut = FM.sample_location_and_conditional_flow(x0, x1)
    # Regress the negative potential gradient onto the conditional flow target.
    vt = energy_field(xt)
    loss = torch.mean((vt - ut) ** 2)
    loss.backward()
    optimizer.step()
    if (k + 1) % plot_every == 0:
        # no_grad keeps the solver from recording a graph through every step;
        # the field re-enables grad internally just for dV/dx.
        with torch.no_grad():
            traj = node.trajectory(sample_8gaussians(1024), t_span=t_span).detach()
        filename = os.path.join(savedir, f'trajectory_{k+1:05d}.png')
        plot_trajectories_custom(traj.cpu().numpy(), filename)

# Pickles the whole module object (not just a state_dict); loading it requires
# the class definitions used above to be importable.
torch.save(model, f"{savedir}/energy_matching_v1.pt")
