# Energy Matching: 2D Tutorial

This notebook demonstrates training a scalar energy potential on a 2D dataset, mapping from an "8 Gaussians" distribution to a "Two Moons" target. The trained model is then used to perform:

- (a) **Unconditional sample generation**
- (b) **Conditional posterior sampling**
- (c) **Conditional posterior sampling with additional interaction energies**

In [None]:
import os
import random
import numpy as np
import torch
import torch.nn as nn

class PotentialModel(nn.Module):
    """Scalar energy potential V_theta(x[, t]) parameterized by an MLP.

    The network maps each point ``x`` (optionally concatenated with a time
    feature) to a single scalar.

    Args:
        dim: Dimensionality of the input points ``x``.
        w: Width of every hidden layer.
        time_varying: If True, ``forward`` requires a time input ``t`` that is
            appended to ``x`` as one extra input feature.
    """

    def __init__(self, dim=2, w=128, time_varying=True):
        super().__init__()
        self.time_varying = time_varying
        # NOTE(review): the first activation is ReLU while the others are SiLU;
        # kept unchanged so existing trained behavior is preserved.
        self.net = nn.Sequential(
            nn.Linear(dim + (1 if time_varying else 0), w),
            nn.ReLU(),
            nn.Linear(w, w),
            nn.SiLU(),
            nn.Linear(w, w),
            nn.SiLU(),
            nn.Linear(w, w),
            nn.SiLU(),
            nn.Linear(w, 1),
        )

    def forward(self, x: torch.Tensor, t=None) -> torch.Tensor:
        """Evaluate the potential for a batch of points.

        Args:
            x: Tensor of shape ``(batch, dim)``.
            t: Time input; required when ``time_varying`` is True, ignored
                otherwise. Accepts a Python number, a 0-d tensor, a tensor of
                shape ``(1,)`` / ``(batch,)``, or ``(1, 1)`` / ``(batch, 1)``.
                It is cast to ``x``'s dtype and device before use.

        Returns:
            Tensor of shape ``(batch, 1)`` with one potential value per point.

        Raises:
            ValueError: if ``time_varying`` is True and ``t`` is None.
        """
        if not self.time_varying:
            return self.net(x)
        if t is None:
            raise ValueError('time_varying=True but t is None.')
        # Normalize numbers / tensors on another device or dtype so that the
        # concatenation with x and the first Linear layer see matching inputs.
        t = torch.as_tensor(t, dtype=x.dtype, device=x.device)
        if t.dim() <= 1:
            t = t.reshape(-1, 1)
        if t.size(0) != x.size(0):
            # Broadcast a single time value across the whole batch.
            t = t.expand(x.size(0), -1)
        # clamp(max=0.0) maps every t >= 0 to 0, so for non-negative times the
        # time feature is constant and the output does not depend on t.
        # NOTE(review): confirm this time-independence is intended.
        t_clamped = torch.clamp(t, max=0.0)
        inp = torch.cat([x, t_clamped], dim=-1)
        return self.net(inp)


In [None]:
from utils_2D import train

# --- Reproducibility --------------------------------------------------------
SEED = 42
# NOTE: assigning PYTHONHASHSEED here does not change hash randomization of the
# already-running kernel; it only affects Python processes started afterwards.
os.environ['PYTHONHASHSEED'] = str(SEED)
random.seed(SEED)        # Python stdlib RNG
np.random.seed(SEED)     # legacy global NumPy RNG
torch.manual_seed(SEED)  # PyTorch RNG

# Train on the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# --- Hyperparameters --------------------------------------------------------
# NOTE(review): what each training phase does is defined in utils_2D.train;
# epochs_phase2 = 0 presumably skips the second phase entirely — confirm there.
epochs_phase1, epochs_phase2 = 200, 0
batch_size = 256
lr = 1e-4
flow_loss_weight = 1.0   # passed to train() as `flow_weight`
ebm_loss_weight = 1.0    # passed to train() as `ebm_weight`
sigma = 0.1
save_dir = "2D_toy"

# Fit the potential; train() receives the model class, not an instance.
model = train(
    PotentialModel,
    device=device,
    batch_size=batch_size,
    lr=lr,
    epochs_phase1=epochs_phase1,
    epochs_phase2=epochs_phase2,
    flow_weight=flow_loss_weight,
    ebm_weight=ebm_loss_weight,
    sigma=sigma,
    save_dir=save_dir,
)



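This notebook uses the Langevin sampling algorithm below. In case (a), only the learned potential $V_{\theta}(x)$ is used. The data-fidelity and interaction terms are dropped, so $U_\theta(x) = V_\theta(x)$. For case (a), the code below starts the chains from samples of the "8 Gaussians" source distribution instead of from $\mathcal{N}(0, I)$.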

\begin{algorithm}[H]
\small
\caption{Sampling for inverse problems (with optional interaction energy)}
\label{alg:sampling_inverse_interaction}
\begin{algorithmic}[1]
\For{\(m = 1\) to \(M\)}
  \State Initialize $x_m^{(0)} \sim \mathcal{N}(0, I)$ \Comment{Start each chain from Gaussian prior}
\EndFor
\State $N \gets \lfloor T_{\mathrm{s}} / \Delta t \rfloor$ \Comment{Number of Langevin steps for sampling time $T_{\mathrm{s}}$}
\For{\(n = 0\) to \(N - 1\)}
  \For{\(m = 1\) to \(M\)} \Comment{Data fidelity + prior + interaction energy}
    \State $U_\theta(x_m^{(n)}) \gets \frac{\varepsilon^{(n)}}{\zeta^2}\lVert y - A(x_m^{(n)})\rVert^2 + V_\theta(x_m^{(n)}) - \frac{\varepsilon^{(n)}}{\sigma^2} \sum_{k \neq m} W(x_m^{(n)}, x_k^{(n)})$
    \State $\eta \sim \mathcal{N}(0, I)$ \Comment{Gaussian noise for Langevin step}
    \State $x_m^{(n+1)} \gets x_m^{(n)} - \Delta t\,\nabla_x U_\theta(x_m^{(n)}) + \sqrt{2\varepsilon^{(n)}\Delta t}\;\eta$ \Comment{Langevin step}
  \EndFor
\EndFor
\end{algorithmic}
\end{algorithm}


In [None]:
from utils_2D import simulate_piecewise_length, plot_trajectories_custom
from torchcfm.utils import sample_8gaussians

# (a) Unconditional generation: start 1024 chains from the 8 Gaussians source
# distribution, placed on whichever device holds the model's parameters.
model_device = next(model.parameters()).device
x_init = sample_8gaussians(1024).to(model_device)
# NOTE(review): the exact meaning of `dt` / `max_length` and of the returned
# `times_np` is defined in utils_2D.simulate_piecewise_length — confirm there.
traj_np, times_np = simulate_piecewise_length(
    model,
    x_init,
    dt=0.01,
    max_length=400,
)
plot_trajectories_custom(traj_np)
