# Flow Matching on EightMoons

This notebook is a simple example of how to use different Flow Matching loss objectives. The following Flow Matching models are implemented:
    
* Conditional Flow Matching from [2302.00482](https://arxiv.org/abs/2302.00482)
    
* Optimal Transport CFM from [2302.00482](https://arxiv.org/abs/2302.00482)

* Self conditioned Flow Matching as in [2310.05764](https://arxiv.org/abs/2310.05764)

As this repository mainly focuses on the generation of point clouds, the examples also consider point cloud structured data. Note that this data structure complicates the problem, and training might take a few minutes to converge.

In [None]:
import math
import os
import time

import matplotlib.pyplot as plt
import numpy as np
import ot as pot
import torch
from sklearn.datasets import make_moons

# import torchdyn
from torchdyn.core import NeuralODE
from torchdyn.datasets import generate_moons
from tqdm.notebook import tqdm

savedir = "models/8gaussian-moons"  # checkpoint folder; the torch.save calls below are commented out
os.makedirs(name=savedir, exist_ok=True)

%matplotlib inline
%load_ext autoreload
%autoreload 2

In [None]:
import sys

sys.path.append("../")
from src.models.components.epic import EPiC_encoder

In [None]:
# Implement some helper functions


def eight_normal_sample(n, dim, scale=1, var=1):
    """Draw ``n`` points from a uniform mixture of eight isotropic Gaussians.

    The component means are the eight compass directions on the unit circle,
    multiplied by ``scale``. Each point picks its component uniformly at random.

    Args:
        n: Number of points to draw.
        dim: Dimensionality of the noise. The centers are 2-D, so ``dim`` is
            expected to be 2.
        scale: Radius of the circle on which the centers lie.
        var: Per-coordinate variance of every component.

    Returns:
        Tensor of shape ``(n, 2)``: center of the chosen component plus noise.
    """
    # ``scale_tril`` is the Cholesky factor of the covariance. Passing
    # sqrt(var) * I therefore yields covariance var * I. Passing the same
    # matrix as the covariance (positional argument) would give variance
    # sqrt(var) instead of var.
    m = torch.distributions.multivariate_normal.MultivariateNormal(
        torch.zeros(dim), scale_tril=math.sqrt(var) * torch.eye(dim)
    )
    centers = [
        (1, 0),
        (-1, 0),
        (0, 1),
        (0, -1),
        (1.0 / np.sqrt(2), 1.0 / np.sqrt(2)),
        (1.0 / np.sqrt(2), -1.0 / np.sqrt(2)),
        (-1.0 / np.sqrt(2), 1.0 / np.sqrt(2)),
        (-1.0 / np.sqrt(2), -1.0 / np.sqrt(2)),
    ]
    centers = torch.tensor(centers) * scale
    noise = m.sample((n,))
    # Component index for every point, uniform over the eight centers.
    multi = torch.multinomial(torch.ones(8), n, replacement=True)
    # One vectorised gather replaces the former per-point Python loop.
    return centers[multi] + noise


def sample_moons(n, num_points=30):
    """Return ``n`` independent two-moons point clouds stacked into one tensor.

    Each cloud is ``make_moons(n_samples=num_points, noise=0.05, shuffle=True)``.
    The stacked float32 tensor is then mapped affinely by ``x * 3 - 1``.

    Returns:
        float32 tensor of shape ``(n, num_points, 2)``.
    """
    clouds = np.array(
        [make_moons(n_samples=num_points, noise=0.05, shuffle=True)[0] for _ in range(n)]
    )
    return torch.tensor(clouds, dtype=torch.float32) * 3 - 1


def sample_8gaussians(n, num_points=30):
    """Return ``n`` independent eight-Gaussian point clouds stacked into one tensor.

    Each cloud is ``eight_normal_sample(num_points, 2, scale=5, var=0.1)``.

    Returns:
        float32 tensor of shape ``(n, num_points, 2)``.
    """
    clouds = [
        np.array(eight_normal_sample(num_points, 2, scale=5, var=0.1).float())
        for _ in range(n)
    ]
    return torch.tensor(np.array(clouds), dtype=torch.float32)


class MLP(torch.nn.Module):
    """Four-layer SELU perceptron that can take time as an extra input feature.

    Args:
        dim: Number of input features, not counting time.
        out_dim: Number of output features; ``None`` means ``dim``.
        w: Width of the three hidden layers.
        time_varying: When True, the first layer expects ``dim + 1`` inputs so
            the caller can append a time column.
    """

    def __init__(self, dim, out_dim=None, w=64, time_varying=False):
        super().__init__()
        self.time_varying = time_varying
        in_features = dim + 1 if time_varying else dim
        out_features = dim if out_dim is None else out_dim
        # Layers are created in the same order as a flat Sequential literal,
        # so the state_dict keys are net.0, net.2, net.4 and net.6.
        layers = [torch.nn.Linear(in_features, w), torch.nn.SELU()]
        for _ in range(2):
            layers += [torch.nn.Linear(w, w), torch.nn.SELU()]
        layers.append(torch.nn.Linear(w, out_features))
        self.net = torch.nn.Sequential(*layers)

    def forward(self, x):
        """Apply the network along the last dimension of ``x``."""
        return self.net(x)


class GradModel(torch.nn.Module):
    """Turn a scalar-valued ``action`` into a vector field via autograd.

    ``forward`` differentiates ``sum(action(x))`` with respect to ``x`` and
    drops the last feature column of the gradient. NOTE(review): the dropped
    column is presumably the appended time coordinate; confirm at call sites.
    """

    def __init__(self, action):
        super().__init__()
        self.action = action

    def forward(self, x):
        x = x.requires_grad_(True)
        total = self.action(x).sum()
        # create_graph=True keeps the result differentiable for training.
        (grad,) = torch.autograd.grad(total, x, create_graph=True)
        return grad[:, :-1]


class torch_wrapper(torch.nn.Module):
    """Adapt a network taking ``[x, t]`` features to torchdyn's ``f(t, x)`` call."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, t, x):
        # Broadcast t over every leading position of x and append it as one
        # extra trailing feature.
        time_column = t.repeat(x.shape[:-1]).unsqueeze(-1)
        features = torch.cat((x, time_column), dim=-1)
        return self.model(features)


class torch_wrapper_epic(torch.nn.Module):
    """Adapt a ``model(t, x_local)`` network to torchdyn's ``f(t, x)`` call.

    The time value is appended to every point of ``x`` as one extra trailing
    feature. The result is passed to ``model`` together with ``t``.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, t, x, *args, **kwargs):
        """Return ``model(t, [x, t])``; extra arguments are accepted and ignored."""
        x_local = torch.cat([x, t.repeat(x.shape[:-1])[..., None]], dim=-1)
        # No random ``x_global`` tensor is drawn here. It was never used and
        # only advanced the global RNG on every solver step.
        return self.model(t, x_local)

        # "#B6BFC3",
        # "#3B515B",
        # "#0271BB",
        # "#E2001A",


def plot_trajectories(traj, n=20, save_name="test"):
    """Plot prior samples, flow paths and generated samples, then save and show.

    Args:
        traj: Trajectory indexed as ``traj[time, set, point, coord]``. The first
            time slice is drawn as prior samples, all slices as the flow, and the
            last slice as generated samples. Callers in this notebook pass
            ``NeuralODE.trajectory`` outputs and the numpy array returned by
            ``esampler_selfcond``.
        n: Number of sets (second axis) to draw.
        save_name: Output path without extension; ``.pdf`` is appended. Missing
            parent directories are created first.
    """
    parent = os.path.dirname(save_name)
    if parent:
        # e.g. save_name="pictures/50" must not fail when pictures/ is missing.
        os.makedirs(parent, exist_ok=True)

    plt.figure(figsize=(6, 6))
    plt.xlim(-7, 7)
    plt.ylim(-7, 8)
    plt.scatter(
        traj[0, :n, :, 0],
        traj[0, :n, :, 1],
        s=10,
        alpha=0.8,
        c="#0271BB",
        label="Prior samples $x_1$",
    )
    plt.scatter(
        traj[:, :n, :, 0], traj[:, :n, :, 1], s=0.2, alpha=0.1, c="#3B515B", label="Flow $x_t$"
    )
    plt.scatter(
        traj[-1, :n, :, 0],
        traj[-1, :n, :, 1],
        s=4,
        alpha=1,
        c="#E2001A",
        label="Gen. samples $x_0$",
    )
    legend = plt.legend(frameon=False, loc="upper left")
    # Enlarge the legend markers through the public set_sizes API instead of
    # writing the private ``_sizes`` attribute.
    for handle in legend.legend_handles:
        handle.set_sizes([30])

    # NOTE(review): the PDF is written before axis("off"), so the saved figure
    # keeps its axes while the displayed one hides them. Confirm this is intended.
    plt.savefig(f"{save_name}.pdf", bbox_inches="tight")
    plt.axis("off")
    plt.show()

In [None]:
# Draw 100 sets from each distribution and report their (sets, points, 2) shapes.
moons = sample_moons(100)
gaussians = sample_8gaussians(100)
print(moons.shape)
print(gaussians.shape)

In [None]:
# All points of the first 10 moon sets in one scatter (x and y split off the last axis).
plt.scatter(*moons[:10].unbind(-1))
plt.show()

In [None]:
# All points of the first 10 eight-Gaussian sets in one scatter.
plt.scatter(*gaussians[:10].unbind(-1))
plt.show()

## Conditional Flow Matching
The Conditional Flow Matching objective allows the use of non-Gaussian priors, because it only needs samples from the prior (here, the eight-Gaussian mixture) and from the data.

In [None]:
%%time

steps = 200
sigma = 0.1
dim = 2
batch_size = 256
# model = MLP(dim=dim, time_varying=True)
# Each point's input is its 2-D position plus the time coordinate.
model = EPiC_encoder(feats=dim, input_dim=dim + 1)
optimizer = torch.optim.Adam(model.parameters())

start = time.time()
for k in tqdm(range(steps)):
    optimizer.zero_grad()
    # Convention in this cell: data (moons, x1) at t=0, prior (eight Gaussians, x0) at t=1.
    x0 = sample_8gaussians(batch_size)
    x1 = sample_moons(batch_size)

    # One time value per set, shared by all of its points -> (batch, points, 1).
    t = torch.rand(x0.shape[0])
    t = t.unsqueeze(-1).repeat_interleave(x0.shape[1], dim=1).unsqueeze(-1)
    t = t.type_as(x0)

    # Sample x from a Gaussian of width sigma around the straight-line interpolant.
    mu_t = (1 - t) * x1 + t * x0
    x = mu_t + sigma * torch.randn_like(x0)

    # Target conditional velocity d(mu_t)/dt.
    ut = x0 - x1

    # The unused random x_global tensor is no longer drawn; it only consumed RNG draws.
    x_local = torch.cat([x, t], dim=-1)
    vt = model(t, x_local)

    loss = torch.mean((vt - ut) ** 2)

    loss.backward()
    optimizer.step()

    if (k + 1) % 50 == 0:
        end = time.time()
        print(f"{k+1}: loss {loss.item():0.3f} time {(end - start):0.2f}")
        start = end
        node = NeuralODE(
            torch_wrapper_epic(model),
            solver="midpoint",
            sensitivity="adjoint",
        )
        with torch.no_grad():
            # Integrate from the prior (t=1) back to the data (t=0).
            traj = node.trajectory(
                sample_8gaussians(100),
                t_span=torch.linspace(1.0, 0.0, 50),
            )
            plot_trajectories(traj)
# torch.save(model, f"{savedir}/cfm_v1.pt")

In [None]:
# this shows how easy one can implement an euler solver
def esampler(func, noise, steps=50):
    """Integrate ``dx/dt = func(t, x)`` backwards from t=1 to t=0 with explicit Euler.

    This shows how easily an Euler solver can be written by hand.

    Args:
        func: Velocity field, called as ``func(t, x)`` with ``t`` a 1-element tensor.
        noise: State at t=1. It is not modified.
        steps: Number of equal Euler steps; must be positive.

    Returns:
        The state after ``steps`` steps, i.e. at (approximately) t=0.

    Raises:
        ValueError: If ``steps`` is not positive.
    """
    if steps <= 0:
        raise ValueError(f"steps must be positive, got {steps}")
    t = torch.ones(1)
    delta_t = 1 / steps
    # Work on a copy. Updating ``noise`` in place would overwrite the caller's tensor.
    x = noise.clone()
    for _ in range(steps):
        # Time runs backwards, so the velocity is subtracted.
        x = x - func(t, x) * delta_t
        t = t - delta_t
    return x

In [None]:
# 100 prior sets, pushed to t=0 by the hand-written Euler sampler.
zss = sample_8gaussians(100)
with torch.no_grad():
    xss = esampler(torch_wrapper_epic(model), zss)

In [None]:
# Every generated point from every set in one scatter.
plt.scatter(*xss.unbind(-1))
plt.show()

In [None]:
# Re-plot the most recent `traj` (from the CFM training cell when run in order) as cfm.pdf.
plot_trajectories(traj, save_name="cfm")

### Optimal Transport Conditional Flow Matching

In this flow matching variant, prior and data points are paired according to an optimal transport plan, so the learned marginal probability paths approximate optimal transport paths. The original paper computes this coupling over mini-batches; in this implementation for sets, it is computed within each set. This produces very straight paths, but computing the optimal transport plan makes training really slow.

In [None]:
%%time

sigma = 0.1
dim = 2
batch_size = 256
# model = MLP(dim=dim, time_varying=True)
# Each point's input is its 2-D position plus time. input_dim is passed explicitly,
# as in the CFM cell (it was previously left to EPiC_encoder's default).
model = EPiC_encoder(feats=dim, input_dim=dim + 1)
optimizer = torch.optim.Adam(model.parameters())


def ot_pair_within_sets(x0, x1):
    """Re-pair the points of every set in place by sampling from an exact OT plan.

    For each set ``k``, the squared Euclidean cost between the points of
    ``x0[k]`` and ``x1[k]`` is divided by its maximum and solved with
    ``pot.emd`` under uniform weights. Then ``x0.shape[1]`` index pairs are
    drawn from the plan with replacement, and both sets are reindexed.
    """
    # The uniform marginals are identical for every set, so build them once.
    a, b = pot.unif(x0.shape[1]), pot.unif(x1.shape[1])
    cost = torch.cdist(x0, x1) ** 2
    for k in range(cost.shape[0]):
        cost_k = cost[k] / cost[k].max()
        pi = pot.emd(a, b, cost_k.detach().cpu().numpy())
        p = pi.flatten()
        p = p / p.sum()
        choices = np.random.choice(pi.shape[0] * pi.shape[1], p=p, size=pi.shape[0])
        i, j = np.divmod(choices, pi.shape[1])
        x0[k] = x0[k, i]
        x1[k] = x1[k, j]


start = time.time()
for z in tqdm(range(2000)):
    optimizer.zero_grad()
    x0 = sample_8gaussians(batch_size)  # prior sets, at t = 1
    x1 = sample_moons(batch_size)  # data sets, at t = 0

    # One time value per set, shared by all of its points -> (batch, points, 1).
    t = torch.rand(x0.shape[0])
    t = t.unsqueeze(-1).repeat_interleave(x0.shape[1], dim=1).unsqueeze(-1)
    t = t.type_as(x0)

    ot_pair_within_sets(x0, x1)

    mu_t = x0 * t + x1 * (1 - t)
    x = mu_t + sigma * torch.randn_like(x0)
    ut = x0 - x1
    # The unused random x_global tensor is no longer drawn; it only consumed RNG draws.
    x_local = torch.cat([x, t], dim=-1)
    vt = model(t, x_local)
    loss = torch.mean((vt - ut) ** 2)

    loss.backward()
    optimizer.step()

    if (z + 1) % 500 == 0:
        end = time.time()
        print(f"{z+1}: loss {loss.item():0.3f} time {(end - start):0.2f}")
        start = end
        node0 = NeuralODE(
            torch_wrapper_epic(model),
            solver="midpoint",
            sensitivity="adjoint",
        )
        with torch.no_grad():
            traj0 = node0.trajectory(
                sample_8gaussians(100),
                t_span=torch.linspace(1, 0, 50),
            )
            plot_trajectories(traj0)
# torch.save(model, f"{savedir}/cfm_v1.pt")

In [None]:
# Plot the OT-CFM trajectory. The OT cell stores it in `traj0`; `traj` still holds the plain CFM result.
plot_trajectories(traj0, n=20, save_name="cfm_ot")

# Self Conditioned Model
The CFM loss is combined with an EPiC architecture and self-conditioning. This also requires a special solver: each step feeds the model's previous prediction back in as an extra input, so standard ODE solvers cannot be used directly.

In [None]:
# modified euler solver
def esampler_selfcond(func, noise, steps=100):
    """Self-conditioned explicit Euler sampler from t=1 to t=0.

    ``func`` is called as ``func(t, x, cond)``. On the first step ``cond`` is
    ``None``, so the wrapper falls back to its own placeholder. On later steps
    it is the velocity ``func`` returned on the previous step. This matches the
    training cell, which feeds the raw (un-negated) model output back in.

    Args:
        func: Self-conditioned velocity field.
        noise: State at t=1. It is not modified.
        steps: Number of equal Euler steps; must be positive.

    Returns:
        ``(x, traj)``: the final state, and a numpy array stacking the state
        after every step, shape ``(steps, *noise.shape)``. The initial state is
        not included.

    Raises:
        ValueError: If ``steps`` is not positive.
    """
    if steps <= 0:
        raise ValueError(f"steps must be positive, got {steps}")
    traj_arr = []
    t = torch.ones(1)
    delta_t = 1 / steps
    # Work on a copy. Updating ``noise`` in place would overwrite the caller's tensor.
    x = noise.clone()
    prev_v = None  # no self-conditioning estimate before the first step
    for _ in range(steps):
        v = func(t, x, prev_v)
        # Time runs backwards, so the velocity is subtracted.
        x = x - v * delta_t
        # Feed back the prediction with the same sign the model saw in training.
        prev_v = v
        t = t - delta_t
        traj_arr.append(x.detach().clone().numpy())
    return x, np.array(traj_arr)

In [None]:
class torch_wrapper_epic_selfcond(torch.nn.Module):
    """Adapt a self-conditioned ``model(t, x_local)`` network to a ``f(t, x, cond)`` call.

    For every point, the model input is ``[x, t, c]``. ``c`` is ``cond`` when it
    is given, and otherwise the placeholder ``self.cond`` supplied at construction.
    """

    def __init__(self, model, cond):
        super().__init__()
        self.model = model
        self.cond = cond  # used whenever forward() receives cond=None

    def forward(self, t, x, cond, *args, **kwargs):
        """Return ``model(t, [x, t, c])``; extra arguments are accepted and ignored."""
        time_column = t.repeat(x.shape[:-1])[..., None]
        selfcond = self.cond if cond is None else cond
        # No random ``x_global`` tensor is drawn here. It was never used and
        # only advanced the global RNG on every call.
        x_local = torch.cat([x, time_column, selfcond], dim=-1)
        return self.model(t, x_local)

In [None]:
%%time

steps = 1000
sigma = 0.1
dim = 2
batch_size = 256
# model = MLP(dim=dim, time_varying=True)
# Per-point input: position (2) + time (1) + self-conditioning input (2).
model = EPiC_encoder(feats=dim, input_dim=dim + 1 + 2)
optimizer = torch.optim.Adam(model.parameters())
# Intermediate plots are saved as pictures/<step>.pdf, so the folder must exist.
os.makedirs("pictures", exist_ok=True)

start = time.time()
for k in tqdm(range(steps)):
    optimizer.zero_grad()
    x0 = sample_8gaussians(batch_size)  # prior sets, at t = 1
    x1 = sample_moons(batch_size)  # data sets, at t = 0

    # One time value per set, shared by all of its points -> (batch, points, 1).
    t = torch.rand(x0.shape[0])
    t = t.unsqueeze(-1).repeat_interleave(x0.shape[1], dim=1).unsqueeze(-1)
    t = t.type_as(x0)

    # Coin flip: above 0.5, this step trains with a self-conditioning estimate.
    s = torch.rand(1)

    mu_t = (1 - t) * x1 + t * x0
    x = mu_t + sigma * torch.randn_like(x0)

    ut = x0 - x1

    x_local = torch.cat([x, t], dim=-1)

    # Placeholder conditioning input: a fresh prior sample.
    x_1_tilde = sample_8gaussians(batch_size)
    if s > 0.5:
        x_local2 = torch.cat([x_local, x_1_tilde], dim=-1)
        # Detached: the first-pass estimate is treated as a fixed input
        # (stop-gradient self-conditioning), so the loss only trains the second pass.
        # NOTE(review): this call passes a third argument, whereas the sampling
        # wrapper calls model(t, x_local) with two; confirm EPiC_encoder's signature.
        x_1_tilde = model(t, x_local2, x_1_tilde).detach()

    # The unused random x_global tensor is no longer drawn; it only consumed RNG draws.
    x_local = torch.cat([x_local, x_1_tilde], dim=-1)
    vt = model(t, x_local)

    loss = torch.mean((vt - ut) ** 2)

    loss.backward()
    optimizer.step()

    if (k + 1) % 50 == 0 or k == 0:
        end = time.time()
        print(f"{k+1}: loss {loss.item():0.3f} time {(end - start):0.2f}")
        start = end
        # fork_rng restores the global torch RNG on exit. The fixed evaluation
        # seed therefore no longer resets the random stream used for training.
        with torch.no_grad(), torch.random.fork_rng(devices=[]):
            torch.manual_seed(0)
            zss = sample_8gaussians(20)
            xss, traj = esampler_selfcond(torch_wrapper_epic_selfcond(model, zss), zss)
        plot_trajectories(traj, n=20, save_name=f"pictures/{k+1}")

In [None]:
# 100 prior sets, which also serve as the first-step conditioning placeholder.
zss = sample_8gaussians(100)
with torch.no_grad():
    xss, traj = esampler_selfcond(torch_wrapper_epic_selfcond(model, zss), zss)

In [None]:
# Plot the self-conditioned trajectory (defaults spelled out: first 20 sets, saved as test.pdf).
plot_trajectories(traj, n=20, save_name="test")

### Optimal Transport Conditional Flow Matching

Next we implement optimal transport conditional flow matching. As in the paper, here we have
$$
\begin{align}
z &= (x_0, x_1) \\
q(z) &= \pi(x_0, x_1) \\
p_t(x | z) &= \mathcal{N}(x | t * x_1 + (1 - t) * x_0, \sigma^2) \\
u_t(x | z) &= x_1 - x_0
\end{align}
$$
where $\pi$ is the joint distribution given by an exact optimal transport plan. We first sample random $x_0, x_1$, then resample them according to the optimal transport plan computed with the Python Optimal Transport (POT) package. We use the 2-Wasserstein distance, i.e. a squared Euclidean ground cost, for equivalence with dynamic optimal transport.

In [None]:
%%time
sigma = 0.1
dim = 2
batch_size = 256
model = MLP(dim=dim, time_varying=True)
optimizer = torch.optim.Adam(model.parameters())

start = time.time()
for k in range(20000):
    optimizer.zero_grad()
    t = torch.rand(batch_size, 1)
    # This cell trains a point-wise MLP, so it needs (batch_size, 2) batches of
    # individual points. The samplers return (n_sets, num_points, 2). Draw one
    # set of batch_size points and drop the set axis. Using the 3-D tensors
    # directly gives a 3-D cost matrix that pot.emd cannot solve.
    x0 = sample_8gaussians(1, num_points=batch_size)[0]
    x1 = sample_moons(1, num_points=batch_size)[0]

    # Resample x0, x1 according to the exact OT plan between the two batches.
    # Uniform marginals: batch_size weights, all equal to 1 / batch_size.
    a, b = pot.unif(x0.size()[0]), pot.unif(x1.size()[0])
    # Squared Euclidean cost, rescaled to [0, 1].
    M = torch.cdist(x0, x1) ** 2
    M = M / M.max()
    pi = pot.emd(a, b, M.detach().cpu().numpy())

    # Sample random interpolation pairs (i, j) with probability pi[i, j].
    p = pi.flatten()
    p = p / p.sum()
    choices = np.random.choice(pi.shape[0] * pi.shape[1], p=p, size=batch_size)
    i, j = np.divmod(choices, pi.shape[1])
    x0 = x0[i]
    x1 = x1[j]

    # Regression loss on the Gaussian path from x0 (t=0) to x1 (t=1).
    mu_t = x0 * (1 - t) + x1 * t
    x = mu_t + sigma * torch.randn(batch_size, dim)
    ut = x1 - x0
    vt = model(torch.cat([x, t], dim=-1))
    loss = torch.mean((vt - ut) ** 2)
    loss.backward()
    optimizer.step()
    if (k + 1) % 5000 == 0:
        end = time.time()
        print(f"{k+1}: loss {loss.item():0.3f} time {(end - start):0.2f}")
        start = end
        node = NeuralODE(
            torch_wrapper(model), solver="dopri5", sensitivity="adjoint", atol=1e-4, rtol=1e-4
        )
        with torch.no_grad():
            # The MLP acts on the last axis, so whole sets can be integrated at once;
            # plot_trajectories expects (time, set, point, coord).
            traj = node.trajectory(
                sample_8gaussians(1024),
                t_span=torch.linspace(0, 1, 100),
            )
            plot_trajectories(traj)
# torch.save(model, f"{savedir}/otcfm_v1.pt")