###### based on https://github.com/atong01/conditional-flow-matching/blob/main/notebooks/training-8gaussians-to-moons.ipynb

# Conditional Flow Matching

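This notebook is an example of conditional flow matching on sets of points, adapted from a self-contained upstream notebook (this version relies on the `EPiC_generator` model from `src.models.components`). Conditional flow matching covers a number of different simulation-free methods for learning flow models; they differ in the interpolant and in the loss function used to train them.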

The upstream notebook implements 5 models that can map from a source distribution $q_0$ to a target distribution $q_1$ (only the basic CFM model is trained in this adaptation):
* Conditional Flow Matching (CFM)
    * This is equivalent to the basic (non-rectified) formulation of "Flow Straight and Fast: Learning to Generate and Transfer Data with Rectified Flow" [(Liu et al. 2023)](https://openreview.net/forum?id=XVjTT1nw5z)
    * Is similar to "Stochastic Interpolants" [(Albergo et al. 2023)](https://openreview.net/forum?id=li7qeBbCR1t) with a non-variance preserving interpolant.
    * Is similar to "Flow Matching" [(Lipman et al. 2023)](https://openreview.net/forum?id=PqvMRDCJT9t) but conditions on both source and target.
* Optimal Transport CFM (OT-CFM), which directly optimizes for dynamic optimal transport
* Schrödinger Bridge CFM (SB-CFM), which optimizes for Schrödinger Bridge probability paths
* "Building Normalizing Flows with Stochastic Interpolants" [(Albergo et al. 2023)](https://openreview.net/forum?id=li7qeBbCR1t); this corresponds to "VP-CFM" in our README, a name referring to its variance-preserving properties.
* "Action Matching: Learning Stochastic Dynamics From Samples" [(Neklyudov et al. 2022)](https://arxiv.org/abs/2210.06662)

Note that this flow matching is different from the flow-matching losses used to train Generative Flow Networks (GFlowNets). Here we specifically regress against continuous flows, rather than matching inflows and outflows.

In [None]:
import math
import os
import time

import matplotlib.pyplot as plt
import numpy as np

# import ot as pot
import torch
from sklearn.datasets import make_moons

# import torchdyn
from torchdyn.core import NeuralODE
from torchdyn.datasets import generate_moons

# Directory that the training cells write model checkpoints into; created up
# front (and tolerated if it already exists) so torch.save never fails on it.
savedir = "models/8gaussian-moons"
os.makedirs(name=savedir, exist_ok=True)

In [None]:
import sys

sys.path.append("../")
from src.models.components import EPiC_generator

In [None]:
# Implement some helper functions


def eight_normal_sample(n, dim, scale=1, var=1):
    """Sample ``n`` points from a uniform mixture of 8 Gaussians on a circle.

    The component means are the 8 unit vectors at multiples of 45 degrees,
    multiplied by ``scale``; each point picks its component uniformly at
    random (with replacement) and adds isotropic Gaussian noise.

    Args:
        n: number of points to draw.
        dim: dimensionality of the noise. The hard-coded centres are 2-D, so
            this must be 2 for ``centers + noise`` to broadcast.
        scale: radius of the circle the component means lie on.
        var: per-coordinate variance of each mixture component.

    Returns:
        Tensor of shape (n, 2).
    """
    # ``scale_tril`` is the Cholesky factor of the covariance, i.e. the standard
    # deviation for an isotropic Gaussian, so the covariance is ``var * I``.
    # The previous code passed ``sqrt(var) * I`` as the *covariance* matrix,
    # which made the actual variance sqrt(var) instead of ``var``.
    m = torch.distributions.multivariate_normal.MultivariateNormal(
        torch.zeros(dim), scale_tril=math.sqrt(var) * torch.eye(dim)
    )
    centers = [
        (1, 0),
        (-1, 0),
        (0, 1),
        (0, -1),
        (1.0 / np.sqrt(2), 1.0 / np.sqrt(2)),
        (1.0 / np.sqrt(2), -1.0 / np.sqrt(2)),
        (-1.0 / np.sqrt(2), 1.0 / np.sqrt(2)),
        (-1.0 / np.sqrt(2), -1.0 / np.sqrt(2)),
    ]
    centers = torch.tensor(centers) * scale
    noise = m.sample((n,))
    multi = torch.multinomial(torch.ones(8), n, replacement=True)
    # Vectorised gather: row i is centers[multi[i]] + noise[i], exactly what the
    # old per-point Python loop + torch.stack produced.
    return centers[multi] + noise


def sample_moons(n, num_points=30):
    """Draw ``n`` independent two-moons point clouds.

    Each cloud is one call to sklearn's ``make_moons`` with ``num_points``
    samples (noise=0.05, shuffled), then mapped affinely by ``x * 3 - 1``.

    Returns:
        float32 tensor of shape (n, num_points, 2) for n >= 1.
    """
    clouds = [
        make_moons(n_samples=num_points, noise=0.05, shuffle=True)[0]
        for _ in range(n)
    ]
    stacked = torch.tensor(np.array(clouds), dtype=torch.float32)
    return stacked * 3 - 1


def sample_8gaussians(n, num_points=30):
    """Draw ``n`` independent point clouds from the 8-Gaussians mixture.

    Every point from ``eight_normal_sample`` is drawn i.i.d. (uniform
    component choice plus independent noise), so one batched call reshaped
    into sets has the same distribution as ``n`` separate per-set calls. It
    also avoids the per-set Python loop and the torch -> numpy -> torch round
    trip.

    Args:
        n: number of point clouds (sets).
        num_points: points per cloud.

    Returns:
        float32 tensor of shape (n, num_points, 2).
    """
    total = n * num_points
    if total == 0:
        # Nothing to sample; return a correctly-shaped empty tensor.
        return torch.empty(n, num_points, 2, dtype=torch.float32)
    points = eight_normal_sample(total, 2, scale=5, var=0.1)
    return points.float().reshape(n, num_points, 2)


class MLP(torch.nn.Module):
    """Fully connected SELU network with three hidden layers of width ``w``.

    Args:
        dim: input feature size (excluding the optional time input).
        out_dim: output size; defaults to ``dim`` when None.
        w: hidden-layer width.
        time_varying: if true, the input has one extra feature (time).
    """

    def __init__(self, dim, out_dim=None, w=64, time_varying=False):
        super().__init__()
        self.time_varying = time_varying
        n_in = dim + 1 if time_varying else dim
        n_out = dim if out_dim is None else out_dim
        # Build Linear/SELU pairs in the same order as before so parameter
        # initialisation and state_dict keys (net.0, net.2, ...) are unchanged.
        layers = [torch.nn.Linear(n_in, w), torch.nn.SELU()]
        for _ in range(2):
            layers.extend([torch.nn.Linear(w, w), torch.nn.SELU()])
        layers.append(torch.nn.Linear(w, n_out))
        self.net = torch.nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)


class GradModel(torch.nn.Module):
    """Vector field given by the input-gradient of a scalar ``action`` model.

    The gradient's last input coordinate is dropped (presumably the time
    feature appended by a wrapper such as torch_wrapper; TODO confirm).
    """

    def __init__(self, action):
        super().__init__()
        self.action = action

    def forward(self, x):
        # In-place flag on the caller's tensor, as before.
        x.requires_grad_(True)
        total_action = self.action(x).sum()
        # create_graph=True keeps the result differentiable for training.
        (grad,) = torch.autograd.grad(total_action, x, create_graph=True)
        return grad[:, :-1]


class torch_wrapper(torch.nn.Module):
    """Adapt a ``model(x_with_time)`` network to torchdyn's ``f(t, x)`` API.

    The scalar time ``t`` is broadcast to every leading index of ``x`` and
    appended as one extra trailing feature before calling ``model``.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, t, x):
        time_column = t.repeat(x.shape[:-1]).unsqueeze(-1)
        augmented = torch.cat((x, time_column), dim=-1)
        return self.model(augmented)


class torch_wrapper_epic(torch.nn.Module):
    """Adapt an EPiC-style ``model(t, x_global, x_local)`` to torchdyn's ``f(t, x)``.

    ``x_local`` is ``x`` with the scalar time appended as an extra trailing
    feature. ``x_global`` is a standard-normal latent of shape
    (batch, latent_global).

    The global latent is drawn once per batch and reused on later calls. An
    ODE solver (dopri5 with adjoint sensitivity in this notebook) evaluates
    the vector field many times per trajectory and needs it to be a
    deterministic function of (t, x). The previous code drew a fresh latent on
    every evaluation, which made the integrated field random. A new latent is
    drawn when the batch size, device or dtype changes, or after
    ``reset_global()``.

    Args:
        model: callable taking (t, x_global, x_local).
        latent_global: width of the global latent (16, as before).
        resample_each_call: if True, restore the old behaviour of drawing a
            fresh latent on every call.
    """

    def __init__(self, model, latent_global=16, resample_each_call=False):
        super().__init__()
        self.model = model
        self.latent_global = latent_global
        self.resample_each_call = resample_each_call
        self._x_global = None

    def reset_global(self):
        """Discard the cached global latent so the next call draws a new one."""
        self._x_global = None

    def _global_latent(self, like):
        """Return the (batch, latent_global) latent matching ``like``'s batch/device/dtype."""
        cached = self._x_global
        stale = (
            self.resample_each_call
            or cached is None
            or cached.shape[0] != like.shape[0]
            or cached.device != like.device
            or cached.dtype != like.dtype
        )
        if stale:
            cached = torch.randn(
                like.shape[0], self.latent_global, device=like.device, dtype=like.dtype
            )
            if not self.resample_each_call:
                self._x_global = cached
        return cached

    def forward(self, t, x):
        time_column = t.repeat(x.shape[:-1])[..., None]
        x_local = torch.cat([x, time_column], dim=-1)
        return self.model(t, self._global_latent(x_local), x_local)


def plot_trajectories(traj):
    """Scatter-plot flow trajectories for up to the first 2000 sets.

    ``traj`` is indexed as (time, set, point, coord). Black marks the first
    time step, olive every time step, blue the last time step.
    """
    limit = 2000
    first, path, last = traj[0, :limit], traj[:, :limit], traj[-1, :limit]
    plt.figure(figsize=(6, 6))
    plt.scatter(first[..., 0], first[..., 1], s=10, alpha=0.8, c="black")
    plt.scatter(path[..., 0], path[..., 1], s=0.2, alpha=0.2, c="olive")
    plt.scatter(last[..., 0], last[..., 1], s=4, alpha=1, c="blue")
    plt.legend(["Prior sample z(S)", "Flow", "z(0)"])
    for hide_ticks in (plt.xticks, plt.yticks):
        hide_ticks([])
    plt.show()

In [None]:
# Draw 100 sets from each distribution and report their shapes.
moons = sample_moons(100)
gaussians = sample_8gaussians(100)
print(moons.shape, gaussians.shape, sep="\n")

In [None]:
# Sanity-check the time-column concatenation used by the torchdyn wrappers:
# broadcast a scalar t over every (set, point) and append it as a feature.
t = torch.tensor(0.0)
rep = t.repeat(moons.shape[:-1]).unsqueeze(-1)
cat = torch.cat((moons, rep), dim=-1)
print(cat.shape)

In [None]:
# Overlay every point of the first 10 moons sets.
plt.scatter(moons[:10, ..., 0], moons[:10, ..., 1])
plt.show()

In [None]:
# Overlay every point of the first 10 eight-Gaussians sets.
plt.scatter(gaussians[:10, ..., 0], gaussians[:10, ..., 1])
plt.show()

### Conditional Flow Matching

First we implement the basic conditional flow matching. As in the paper, we have
$$
\begin{align}
z &= (x_0, x_1) \\
q(z) &= q(x_0)q(x_1) \\
p_t(x | z) &= \mathcal{N}(x | t * x_1 + (1 - t) * x_0, \sigma^2) \\
u_t(x | z) &= x_1 - x_0
\end{align}
$$
When $\sigma = 0$, this is equivalent to rectified flow with zero reflow steps. We find that a small $\sigma$ helps regularize the problem, though your mileage may vary.

In [None]:
from tqdm.notebook import tqdm

In [None]:
def calc_set_feats(tensor: torch.Tensor) -> torch.Tensor:
    """Summarise each set with permutation-invariant statistics.

    For input of shape (batch, set_size, feats) the result has shape
    (batch, 2 * feats + 2), concatenating:
      * the per-feature mean over the set,
      * the per-feature sum over the set,
      * the mean over the set of each point's feature-mean,
      * the sum over the set of each point's feature-mean.

    Raises:
        ValueError: if ``tensor`` is not 3-dimensional.
    """
    if tensor.dim() != 3:
        raise ValueError("Input tensor must be of shape (batch, set_size, feats)")
    per_point_mean = tensor.mean(dim=-1)
    pieces = (
        tensor.mean(dim=-2),
        tensor.sum(dim=-2),
        per_point_mean.mean(dim=-1, keepdim=True),
        per_point_mean.sum(dim=-1, keepdim=True),
    )
    return torch.cat(pieces, dim=-1)

In [None]:
%%time

# Train EPiC_generator with a *set-level* CFM objective: only the
# permutation-invariant summaries (calc_set_feats) of the predicted (vt) and
# target (ut) velocities are matched, not the per-point velocities.
sigma = 0.1  # std of the Gaussian noise added around the straight-line path
dim = 2
batch_size = 256  # point sets per optimisation step
n_steps = 500
model = EPiC_generator(feats=dim, latent_local=dim + 1)  # or MLP(dim=dim, time_varying=True)
optimizer = torch.optim.Adam(model.parameters())

start = time.time()
for k in tqdm(range(n_steps)):
    optimizer.zero_grad()
    x0 = sample_8gaussians(batch_size)  # source sets, (batch, points, 2)
    x1 = sample_moons(batch_size)  # target sets, (batch, points, 2)
    # Draw ONE time per set and broadcast it over that set's points. At
    # inference torch_wrapper_epic feeds a single t to every point, so sampling
    # t independently per point (as before) trained the set model on sets
    # whose points sat at different times.
    t = torch.rand(x0.shape[0], 1, 1, dtype=x0.dtype, device=x0.device)
    t = t.expand(-1, x0.shape[1], -1)  # (batch, points, 1), same shape as before
    mu_t = t * x1 + (1 - t) * x0
    x = mu_t + sigma * torch.randn_like(x0)
    ut = x1 - x0
    x_local = torch.cat([x, t], dim=-1)  # time appended as an extra point feature
    # Fresh global latent per batch; 16 is the width torch_wrapper_epic uses.
    x_global = torch.randn(x_local.shape[0], 16, device=x_local.device)
    vt = model(t, x_global, x_local)
    loss = torch.mean((calc_set_feats(vt) - calc_set_feats(ut)) ** 2)

    loss.backward()
    optimizer.step()
    if (k + 1) % 30 == 0:
        print(f"k: {k+1}")
    # Evaluate after the final step; the old `% 499` check fired one step early.
    if k + 1 == n_steps:
        end = time.time()
        print(f"{k+1}: loss {loss.item():0.3f} time {(end - start):0.2f}")
        start = end
        node = NeuralODE(
            torch_wrapper_epic(model), solver="dopri5", sensitivity="adjoint", atol=1e-4, rtol=1e-4
        )
        with torch.no_grad():
            traj = node.trajectory(
                sample_8gaussians(100),
                t_span=torch.linspace(0, 1, 50),
            )
            plot_trajectories(traj)
torch.save(model, f"{savedir}/cfm_v1.pt")

In [None]:
%%time

# Train EPiC_generator with per-point CFM regression PLUS the set-level
# summary loss (calc_set_feats) used in the previous cell.
sigma = 0.1  # std of the Gaussian noise added around the straight-line path
dim = 2
batch_size = 256  # point sets per optimisation step
n_steps = 500
model = EPiC_generator(feats=dim, latent_local=dim + 1)  # or MLP(dim=dim, time_varying=True)
optimizer = torch.optim.Adam(model.parameters())

start = time.time()
for k in tqdm(range(n_steps)):
    optimizer.zero_grad()
    x0 = sample_8gaussians(batch_size)  # source sets, (batch, points, 2)
    x1 = sample_moons(batch_size)  # target sets, (batch, points, 2)
    # Draw ONE time per set and broadcast it over that set's points. At
    # inference torch_wrapper_epic feeds a single t to every point, so sampling
    # t independently per point (as before) trained the set model on sets
    # whose points sat at different times.
    t = torch.rand(x0.shape[0], 1, 1, dtype=x0.dtype, device=x0.device)
    t = t.expand(-1, x0.shape[1], -1)  # (batch, points, 1), same shape as before
    mu_t = t * x1 + (1 - t) * x0
    x = mu_t + sigma * torch.randn_like(x0)
    ut = x1 - x0
    x_local = torch.cat([x, t], dim=-1)  # time appended as an extra point feature
    # Fresh global latent per batch; 16 is the width torch_wrapper_epic uses.
    x_global = torch.randn(x_local.shape[0], 16, device=x_local.device)
    vt = model(t, x_global, x_local)
    point_loss = torch.mean((vt - ut) ** 2)
    set_loss = torch.mean((calc_set_feats(vt) - calc_set_feats(ut)) ** 2)
    loss = point_loss + set_loss

    loss.backward()
    optimizer.step()
    if (k + 1) % 30 == 0:
        print(f"k: {k+1}")
    # Evaluate after the final step; the old `% 499` check fired one step early.
    if k + 1 == n_steps:
        end = time.time()
        print(f"{k+1}: loss {loss.item():0.3f} time {(end - start):0.2f}")
        start = end
        node = NeuralODE(
            torch_wrapper_epic(model), solver="dopri5", sensitivity="adjoint", atol=1e-4, rtol=1e-4
        )
        with torch.no_grad():
            traj = node.trajectory(
                sample_8gaussians(100),
                t_span=torch.linspace(0, 1, 50),
            )
            plot_trajectories(traj)
# Distinct file name: this model uses a different loss from the previous cell,
# which already writes cfm_v1.pt; saving here under the same name overwrote it.
torch.save(model, f"{savedir}/cfm_v2.pt")

In [None]:
print(traj.size())  # (time steps, sets, points, coords)

In [None]:
def plot_trajectories1(traj, n=1, m=2):
    """Plot the trajectories of two groups of sets in contrasting colours.

    ``traj`` is indexed as (time, set, point, coord). The first ``n`` sets are
    drawn black (start) / olive (all steps) / blue (end). The single set at
    index ``m`` is drawn green (start) / orange (all steps) / red (end).

    Args:
        traj: trajectory array or tensor, e.g. the output of node.trajectory.
        n: number of leading sets in the first group (default 1, as before).
        m: index of the set forming the second group (default 2, as before).
    """
    plt.figure(figsize=(6, 6))
    plt.scatter(
        traj[0, :n, :, 0], traj[0, :n, :, 1], s=10, alpha=0.8, c="black", label="Prior sample z(S)"
    )
    plt.scatter(traj[:, :n, :, 0], traj[:, :n, :, 1], s=0.2, alpha=0.2, c="olive", label="Flow")
    plt.scatter(traj[-1, :n, :, 0], traj[-1, :n, :, 1], s=4, alpha=1, c="blue", label="z(0)")
    # Label the second group as well. Previously plt.legend received only three
    # names for six artists, so the green/orange/red set had no legend entries.
    plt.scatter(
        traj[0, m, :, 0], traj[0, m, :, 1], s=10, alpha=0.8, c="green",
        label=f"Prior sample z(S), set {m}",
    )
    plt.scatter(
        traj[:, m, :, 0], traj[:, m, :, 1], s=0.2, alpha=0.2, c="orange", label=f"Flow, set {m}"
    )
    plt.scatter(traj[-1, m, :, 0], traj[-1, m, :, 1], s=4, alpha=1, c="red", label=f"z(0), set {m}")
    plt.legend()
    plt.xticks([])
    plt.yticks([])
    plt.show()

In [None]:
plot_trajectories1(traj=traj)  # trajectories from the last training cell

In [None]:
# Two uniform-random (10, 5, 3) tensors for exercising calc_set_feats below.
tensor1, tensor2 = torch.rand(10, 5, 3), torch.rand(10, 5, 3)
print(tensor1.shape)

In [None]:
def calc_set_feats(tensor: torch.Tensor) -> torch.Tensor:
    """Return per-set summary features of a (batch, set_size, feats) tensor.

    Output shape is (batch, 2 * feats + 2): set-mean per feature, set-sum per
    feature, then the set-mean and set-sum of each point's feature-mean.

    Raises:
        ValueError: if ``tensor`` is not 3-dimensional.
    """
    if tensor.ndim == 3:
        point_avg = torch.mean(tensor, dim=2 - 3)  # average over feats (dim -1)
        return torch.cat(
            [
                torch.mean(tensor, dim=-2),
                torch.sum(tensor, dim=-2),
                torch.mean(point_avg, dim=-1, keepdim=True),
                torch.sum(point_avg, dim=-1, keepdim=True),
            ],
            dim=-1,
        )
    raise ValueError("Input tensor must be of shape (batch, set_size, feats)")

In [None]:
print("function: " + str(calc_set_feats(tensor1).shape))

In [None]:
# Step-by-step reproduction of calc_set_feats on tensor1, printing each shape.
mean_set = tensor1.mean(dim=-2)
sum_set = tensor1.sum(dim=-2)
print(mean_set.shape, sum_set.shape, f"tensor1: {tensor1.shape}", sep="\n")
mean_feat = tensor1.mean(dim=-1)
print(mean_feat.shape)
mean_set_feat = mean_feat.mean(dim=-1, keepdim=True)
sum_set_feat = mean_feat.sum(dim=-1, keepdim=True)
print(mean_set_feat.shape, sum_set_feat.shape, sep="\n")
cat = torch.cat((mean_set, sum_set, mean_set_feat, sum_set_feat), dim=-1)
print(cat.shape)

In [None]:
# Scratch check of the per-point MSE reduction used in training,
# torch.mean((vt - ut) ** 2): the final mean collapses to a 0-d tensor.
diff = torch.sub(tensor1, tensor2)
print(diff.shape)
sqr = diff.pow(2)
print(sqr.shape)
mn = sqr.mean()
print(mn.shape)