# Flow Matching on $\mathbb{M}_3$
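In this notebook, we perform flow matching on the position-orientation space $\mathbb{M}_3$, and compare it with flow matching on $\operatorname{SE}(3)$ applied to the same data lifted with random rotations.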

In [None]:
import torch
import numpy as np
from lieflow.groups import M3
from lieflow.models import (
    get_model_FM,
    LogarithmicDistance
)
import matplotlib.pyplot as plt
from matplotlib.animation import FFMpegWriter
import matplotlib.colors as mcolors
import matplotlib.gridspec as gridspec
from matplotlib.cm import ScalarMappable
from tqdm.notebook import tqdm
from functools import partial
# Typeset all figure text with LaTeX in a Computer Modern serif font; the
# preamble loads amssymb/amsmath for the \mathbb, \mathfrak and \operatorname
# macros used in labels and titles. Because "text.usetex" is True, a working
# LaTeX installation is required to render any figure.
plt.rcParams.update({
    "text.usetex": True,
    "font.family": "serif",
    "font.serif": ["Computer Modern"],
    "font.size": 10.0,
    "text.latex.preamble": r"\usepackage{lmodern} \usepackage{amssymb} \usepackage{amsmath}"
})
%matplotlib widget

In [None]:
# The position-orientation space M3; its `se3` attribute provides the SE(3)
# group used for the lifted computations throughout the notebook.
m3 = M3()

In [None]:
# Whether to render the MP4 animations of the learned flows (see the
# `create_animations` cell); rendering needs FFmpeg.
generate_videos = True

# Distributions

In [None]:
def generate_line(N, g, w=1., ε=0.05):
    """Sample ``N`` noisy oriented points along a line segment and move them by ``g``.

    Positions are drawn uniformly on the segment ``[-w, w]`` of the x-axis and
    orientations point along the y-axis. Both are perturbed by isotropic
    Gaussian noise with standard deviation ``ε``, and the orientations are
    renormalized to unit length before packing.

    Args:
        N: number of samples.
        g: SE(3) element applied to the samples via ``m3.act``.
        w: half-length of the segment.
        ε: standard deviation of the Gaussian noise.

    Returns:
        The M3 elements ``m3.act(g, p)``.
    """
    x = torch.zeros(N, 3)
    x[:, 0] = 2. * w * (torch.rand(N) - 0.5)
    x += ε * torch.randn(N, 3)
    n = torch.zeros(N, 3)
    n[:, 1] = 1.
    n += ε * torch.randn(N, 3)
    # Divide by the Euclidean norm (not its square) so that n is a unit
    # vector, matching the normalization in the other generators.
    n = n / (n**2).sum(-1, keepdim=True).sqrt()
    p = m3.pack_position_orientation(x, n)
    return m3.act(g, p)

def generate_normals(N, g):
    """Sample M3 elements with Gaussian positions and uniform orientations, then act with ``g``.

    Positions are standard normal in R^3. Orientations are standard normal
    vectors normalized to unit length, i.e. uniform on the unit sphere. Draws
    whose orientation vector has zero length are discarded, so fewer than
    ``N`` elements may be returned.
    """
    positions = torch.randn(N, 3)
    directions = torch.randn(N, 3)
    norms = torch.sqrt(torch.sum(directions**2, dim=-1))
    keep = norms > 0.
    unit_directions = directions[keep] / norms[keep].unsqueeze(-1)
    packed = m3.pack_position_orientation(positions[keep], unit_directions)
    return m3.act(g, packed)

def generate_uniforms_on_circle(N, g, ε=0.05):
    """Sample points on the unit sphere with roughly outward orientations, then act with ``g``.

    Despite the name, positions are normalized 3D Gaussian vectors and hence
    uniform on the unit sphere S^2, not on a circle. Each orientation is its
    position plus Gaussian noise of standard deviation ``ε``, renormalized to
    unit length. Zero-length draws are discarded.
    """
    raw = torch.randn(N, 3)
    radii = torch.sqrt(torch.sum(raw**2, dim=-1))
    keep = radii > 0.
    positions = raw[keep] / radii[keep].unsqueeze(-1)
    orientations = positions + ε * torch.randn(positions.shape[0], 3)
    orientations = orientations / torch.sqrt(torch.sum(orientations**2, dim=-1, keepdim=True))
    return m3.act(g, m3.pack_position_orientation(positions, orientations))

# Reference position: the origin of R^3.
x_ref = torch.zeros(3)
# Reference orientations along the x-, y- and z-axes (independent copies of
# the rows of the 3x3 identity matrix).
n_ref_x, n_ref_y, n_ref_z = (axis.clone() for axis in torch.eye(3))
# Reference M3 elements at the origin. The z-oriented one, `p_ref`, is the
# base point used for lifting to and projecting from SE(3).
p_ref_x, p_ref_y, p_ref = (
    m3.pack_position_orientation(x_ref, n_ref_axis)
    for n_ref_axis in (n_ref_x, n_ref_y, n_ref_z)
)

def g_to_frame(g):
    """Return the images of the x-, y- and z-oriented reference elements under ``g``.

    Used to visualize an SE(3) element as a frame of three arrows. The third
    entry equals ``project_to_M3(g)``.
    """
    return tuple(m3.act(g, reference) for reference in (p_ref_x, p_ref_y, p_ref))

def lift_to_SE3(p):
    """Lift M3 elements ``p`` to SE(3) elements with an independent random residual rotation.

    ``exp(m3.get_generator(p_ref, p))`` is composed via ``m3.se3.L`` with
    ``exp(θ B)``, where ``B`` is the last SE(3) Lie-algebra basis element and
    ``θ`` is drawn uniformly from ``[0, 2π)`` for each sample.

    NOTE(review): presumably ``B`` generates rotations about the reference
    orientation of ``p_ref``, so that ``project_to_M3`` maps the lift back to
    ``p`` -- confirm against the group definition.
    """
    generator = m3.get_generator(p_ref, p)
    batch_shape = p.shape[:-2]
    angles = 2 * torch.pi * torch.rand(*batch_shape, 1, 1)
    residual_rotation = m3.se3.exp(angles * m3.se3.lie_algebra_basis[-1])
    base_lift = m3.se3.exp(generator)
    return m3.se3.L(base_lift, residual_rotation)
    
def project_to_M3(g):
    """Project SE(3) elements ``g`` to M3 by acting on the reference element ``p_ref``."""
    projected = m3.act(g, p_ref)
    return projected

def plot_p(ax, p, **kwargs):
    """Draw M3 elements ``p`` as arrows on the 3D axis ``ax`` and return the quiver artist.

    Arrow tails are the positions of ``p`` and arrow directions its
    orientations. Extra keyword arguments are forwarded to ``ax.quiver``.
    """
    positions, orientations = m3.get_position_orientation(p)
    xs, ys, zs = (positions[..., k] for k in range(3))
    us, vs, ws = (orientations[..., k] for k in range(3))
    return ax.quiver(xs, ys, zs, us, vs, ws, **kwargs)

In [None]:
# Names of the benchmark problems; `data_generator` maps each name to a pair
# of samplers for the source and target distributions.
tests = ("line_to_line", "normals_to_circle")

In [None]:
def data_generator(test, ε=0.04):
    match test:
        case "line_to_line":
            g_0 = m3.se3.exp((torch.tensor([0., -2., 0., 0., 0., 0.])[..., None, None] * m3.se3.lie_algebra_basis).sum(0))
            g_1 = m3.se3.exp((torch.tensor([0., 2., 0., 0., 0.25 * torch.pi, 0.25 * torch.pi])[..., None, None] * m3.se3.lie_algebra_basis).sum(0))
            generate_p_0 = partial(generate_line, g=g_0, ε=ε)
            generate_p_1 = partial(generate_line, g=g_1, ε=ε)
        case "normals_to_circle":
            g_0 = m3.se3.exp((torch.tensor([0., -1., 0., 0., 0., 0.])[..., None, None] * m3.se3.lie_algebra_basis).sum(0))
            g_1 = m3.se3.exp((torch.tensor([0., 1., 0., 0., 0., 0.])[..., None, None] * m3.se3.lie_algebra_basis).sum(0))
            generate_p_0 = partial(generate_normals, g=g_0)
            generate_p_1 = partial(generate_uniforms_on_circle, g=g_1, ε=ε)
    return generate_p_0, generate_p_1

In [None]:
generate_p_0, generate_p_1 = data_generator(tests[0])


def generate_g_0(N):
    """Sample ``N`` SE(3) elements by lifting draws from the current global ``generate_p_0``."""
    return lift_to_SE3(generate_p_0(N))


def generate_g_1(N):
    """Sample ``N`` SE(3) elements by lifting draws from the current global ``generate_p_1``."""
    return lift_to_SE3(generate_p_1(N))

In [None]:
# Arrow lengths and opacities: "big" arrows show the M3 samples, "small"
# arrows show the frames of their SE(3) lifts.
l_big = 0.4
l_small = 0.2
alpha_big = 0.8
alpha_small = 0.3

p_0s = generate_p_0(32)
g_0s = lift_to_SE3(p_0s)
p_0_x, p_0_y, p_0_z = g_to_frame(g_0s)
p_1s = generate_p_1(32)
g_1s = lift_to_SE3(p_1s)

# Alternative: sample directly in SE(3) and project down to M3.
# g_0s = generate_g_0(32)
# p_0s = project_to_M3(g_0s)
# p_0_x, p_0_y, p_0_z = g_to_frame(g_0s)
# g_1s = generate_g_1(32)
# p_1s = project_to_M3(g_1s)

p_1_x, p_1_y, p_1_z = g_to_frame(g_1s)
fig = plt.figure(figsize=(10, 10))
ax = fig.add_subplot(111, projection="3d")
# Source samples in red, target samples in green.
for samples, colour in ((p_0s, "red"), (p_1s, "green")):
    plot_p(ax, samples, alpha=alpha_big, color=colour, length=l_big)
# Lifted frames: x-axis arrow red, y-axis green, z-axis blue.
for axes_images in ((p_0_x, p_0_y, p_0_z), (p_1_x, p_1_y, p_1_z)):
    for samples, colour in zip(axes_images, ("red", "green", "blue")):
        plot_p(ax, samples, alpha=alpha_small, color=colour, length=l_small)
ax.set_xlim(-3, 3)
ax.set_xlabel("$x$")
ax.set_ylim(-3, 3)
ax.set_ylabel("$y$")
ax.set_zlim(-3, 3)
ax.set_zlabel("$z$");

# Models

## Training

In [None]:
EPSILON = 0.01  # Noise level passed as ε to data_generator for training and test data
N = 2**14  # Number of training samples drawn per endpoint distribution, per test
BATCH_SIZE = 2**10
EPOCHS = len(tests) * [50]  # Training epochs, one entry per test
WEIGHT_DECAY = 0.
LEARNING_RATE = 1e-3
H = 64 # Width
L = 3 # Number of layers is L + 2
device = "cpu"  # Device for models and data (tensors are moved back to "cpu" for plotting)

In [None]:
def _fit_flow_model(group, x_0s, x_1s, epochs, batch_size, learning_rate, weight_decay):
    """Train one flow-matching model on ``group`` from paired samples; return ``(model, losses)``."""
    train_loader = torch.utils.data.DataLoader(
        torch.utils.data.TensorDataset(x_0s, x_1s), batch_size=batch_size, shuffle=True
    )

    model = get_model_FM(group, H=H, L=L).to(device)
    print("Number of parameters: ", model.parameter_count)
    optimizer = torch.optim.Adam(model.parameters(), learning_rate, weight_decay=weight_decay)
    loss = LogarithmicDistance(torch.tensor([1., 1., 1., 1., 1., 1.]))

    # One entry per epoch: the value returned by `train_network`.
    losses = np.zeros(epochs)
    for i in tqdm(range(epochs)):
        losses[i] = model.train_network(device, train_loader, optimizer, loss)
    return model, losses


def train_models(p_0s, p_1s, epochs, batch_size=BATCH_SIZE, learning_rate=LEARNING_RATE, weight_decay=WEIGHT_DECAY):
    """Train flow-matching models on M3 and on SE(3) for one pair of endpoint samples.

    The M3 model is trained on ``(p_0s, p_1s)`` directly. The SE(3) model is
    trained on their lifts from ``lift_to_SE3``. Both models are built with
    ``get_model_FM`` using the global ``H`` and ``L`` and moved to the global
    ``device``. The per-epoch losses of both are plotted on log-log axes.

    Args:
        p_0s, p_1s: paired source and target M3 samples.
        epochs: number of training epochs for each model.
        batch_size, learning_rate, weight_decay: optimization settings.

    Returns:
        ``(model_M3, model_SE3)``.
    """
    model_M3, losses_M3 = _fit_flow_model(
        m3, p_0s, p_1s, epochs, batch_size, learning_rate, weight_decay
    )

    # Lift both endpoint samples to SE(3), each with random residual rotations.
    g_0s = lift_to_SE3(p_0s)
    g_1s = lift_to_SE3(p_1s)
    model_SE3, losses_SE3 = _fit_flow_model(
        m3.se3, g_0s, g_1s, epochs, batch_size, learning_rate, weight_decay
    )

    # Number epochs from 1: with a log-scaled x-axis, a point at x = 0 cannot
    # be drawn, which previously hid the first epoch's loss.
    epoch_numbers = np.arange(1, epochs + 1)
    fig, ax = plt.subplots(1, 1, figsize=(5, 5))
    ax.plot(epoch_numbers, losses_M3, label=r"$\mathbb{M}_3$")
    ax.plot(epoch_numbers, losses_SE3, label=r"$\operatorname{SE}(3)$")
    ax.set_title("Batch Loss Flow Matching")
    ax.legend()
    ax.set_xscale("log")
    ax.set_yscale("log")

    return model_M3, model_SE3

In [None]:
models_M3 = {}
models_SE3 = {}
for i, test in enumerate(tests):
    print(test)
    generate_p_0, generate_p_1 = data_generator(test, ε=EPSILON)

    # Training pairs, plus the frames of their SE(3) lifts for the preview.
    p_0s = generate_p_0(N)
    g_0s = lift_to_SE3(p_0s)
    p_1s = generate_p_1(N)
    g_1s = lift_to_SE3(p_1s)
    p_0_x, p_0_y, p_0_z = g_to_frame(g_0s)
    p_1_x, p_1_y, p_1_z = g_to_frame(g_1s)

    # Preview the first 32 samples: source red, target green; lifted frames
    # drawn with small red/green/blue arrows for their x/y/z axes.
    fig = plt.figure(figsize=(10, 10))
    ax = fig.add_subplot(111, projection='3d')
    for samples, colour in ((p_0s, "red"), (p_1s, "green")):
        plot_p(ax, samples[:32], alpha=alpha_big, color=colour, length=l_big)
    for axes_images in ((p_0_x, p_0_y, p_0_z), (p_1_x, p_1_y, p_1_z)):
        for samples, colour in zip(axes_images, ("red", "green", "blue")):
            plot_p(ax, samples[:32], alpha=alpha_small, color=colour, length=l_small)
    ax.set_xlim(-3, 3)
    ax.set_xlabel("$x$")
    ax.set_ylim(-3, 3)
    ax.set_ylabel("$y$")
    ax.set_zlim(-3, 3)
    ax.set_zlabel("$z$")

    models_M3[test], models_SE3[test] = train_models(
        p_0s,
        p_1s,
        epochs=EPOCHS[i],
        batch_size=BATCH_SIZE,
        learning_rate=LEARNING_RATE,
        weight_decay=WEIGHT_DECAY,
    )

## Testing

In [None]:
def _new_video_axes():
    """Create a 10x10-inch figure with one 3D axis showing the box [-3, 3]^3."""
    fig = plt.figure(figsize=(10, 10))
    ax = fig.add_subplot(111, projection='3d')
    ax.set_xlim(-3, 3)
    ax.set_xlabel("$x$")
    ax.set_ylim(-3, 3)
    ax.set_ylabel("$y$")
    ax.set_zlim(-3, 3)
    ax.set_zlabel("$z$")
    return fig, ax


def create_animations(tests, models_M3, models_SE3, N=2**5, N_steps=120, fps=30, output_dir="output"):
    """Render forward-flow animations of the trained models to MP4 files.

    For every test, ``N`` source samples are drawn and integrated from t = 0
    to t = 1 in ``N_steps`` steps of the models' ``step`` method. Frame ``k``
    shows the state at time ``k / N_steps``, giving ``N_steps + 1`` frames.
    Two files are written per test:
    ``{output_dir}/flow_matching_M3_{test}.mp4`` shows the M3 samples as
    arrows, and ``{output_dir}/flow_matching_SE3_{test}.mp4`` shows the
    SE(3) samples as frames.

    Args:
        tests: test names, as accepted by ``data_generator``.
        models_M3, models_SE3: dicts mapping test names to trained models.
        N: number of samples per animation.
        N_steps: number of integration steps from t = 0 to t = 1.
        fps: frames per second of the videos.
        output_dir: directory for the videos; created if missing.
    """
    import os

    os.makedirs(output_dir, exist_ok=True)
    Δt = 1. / N_steps

    # Draw the initial samples for all tests up front and switch the models
    # to evaluation mode.
    p_ts_test = {}
    g_ts_test = {}
    for test in tests:
        generate_p_0, _ = data_generator(test, ε=EPSILON)
        p_0s = generate_p_0(N).to(device)
        p_ts_test[test] = p_0s.detach().clone()
        g_ts_test[test] = lift_to_SE3(p_0s).detach().clone()

        models_M3[test].eval()
        models_SE3[test].eval()

    for test in tests:
        # M3 flow: one arrow per sample.
        metadata = {'title': f'Flow Matching M3 {test}', 'artist': 'Matplotlib'}
        writer = FFMpegWriter(fps=fps, metadata=metadata)
        fig, ax = _new_video_axes()
        with writer.saving(fig, os.path.join(output_dir, f"flow_matching_M3_{test}.mp4"), dpi=150):
            artists = []
            for frame in tqdm(range(N_steps + 1)):
                for artist in artists:
                    artist.remove()
                artists = [plot_p(ax, p_ts_test[test].to("cpu"), length=l_big)]
                writer.grab_frame()
                # The last frame already shows t = 1; do not integrate past it.
                if frame < N_steps:
                    t = torch.full((1, 1, 1), frame * Δt, device=device)
                    with torch.no_grad():
                        p_ts_test[test] = models_M3[test].step(p_ts_test[test], t, Δt)

        # SE(3) flow: each sample drawn as a frame (x red, y green, z blue).
        metadata = {'title': f'Flow Matching SE3 {test}', 'artist': 'Matplotlib'}
        writer = FFMpegWriter(fps=fps, metadata=metadata)
        fig, ax = _new_video_axes()
        with writer.saving(fig, os.path.join(output_dir, f"flow_matching_SE3_{test}.mp4"), dpi=150):
            artists = []
            for frame in tqdm(range(N_steps + 1)):
                for artist in artists:
                    artist.remove()
                frame_axes = g_to_frame(g_ts_test[test].to("cpu"))
                artists = [
                    plot_p(ax, axis_images, length=l_small, alpha=alpha_small, color=colour)
                    for axis_images, colour in zip(frame_axes, ("red", "green", "blue"))
                ]
                writer.grab_frame()
                if frame < N_steps:
                    t = torch.full((1, 1, 1), frame * Δt, device=device)
                    with torch.no_grad():
                        g_ts_test[test] = models_SE3[test].step(g_ts_test[test], t, Δt)


In [None]:
# Render the MP4 animations of both flows for every test (needs FFmpeg);
# skipped when `generate_videos` is False.
if generate_videos:
    create_animations(tests, models_M3, models_SE3, N=2**5)

In [None]:
import os

# Snapshot figure of the M3 flow. For each test, the columns show the source
# samples X_0, the flowed samples at N_show equally spaced times in [0, 1]
# (coloured by t), and the target samples X_1.
t = 0
N_steps = 240
Δt = 1. / N_steps
N_show = 5
# Steps between snapshots; assumes N_steps is divisible by N_show - 1, so that
# the snapshots land exactly on the colorbar ticks linspace(0, 1, N_show).
N_skip = int(N_steps / (N_show-1))

N_models = len(tests)
N_samples = 32

p_0s_test = {}
p_1s_test = {}
p_ts_test = {}
for test in tests:
    generate_p_0, generate_p_1 = data_generator(test, ε=EPSILON)
    p_0s_test[test] = generate_p_0(N_samples).to(device)
    p_1s_test[test] = generate_p_1(N_samples).to(device)
    p_ts_test[test] = p_0s_test[test].detach().clone()

    # The trained M3 models from the training cell.
    models_M3[test].eval()

fig = plt.figure(figsize=(4.8, 1.6 * N_models * 3/3.1))
gs = gridspec.GridSpec(N_models, 4, width_ratios=[1, 1, 1, 0.1], height_ratios=N_models * [1.], wspace=0.1, hspace=0.1)
cax = fig.add_subplot(gs[:, 3])

ax = []
for i in range(N_models):
    ax.append([])
    for j in range(3):
        a = fig.add_subplot(gs[i, j], projection="3d")

        a.set_xlim(-3, 3)
        a.set_xticks([])
        a.set_ylim(-3, 3)
        a.set_yticks([])
        a.set_zlim(-3, 3)
        a.set_zticks([])
        ax[i].append(a)

ax[0][0].set_title(r"$\mathfrak{X}_0$")
ax[0][1].set_title(r"$\mathfrak{X}_t$")
ax[0][2].set_title(r"$\mathfrak{X}_1$")
Δc = 1 / (N_show - 1)
colors = [(j * Δc, 0.1, 1 - j * Δc) for j in range(N_show)]
cmap = mcolors.ListedColormap(colors)
for i, test in enumerate(tests):
    x_0s, n_0s = m3.get_position_orientation(p_0s_test[test].to("cpu"))
    x_1s, n_1s = m3.get_position_orientation(p_1s_test[test].to("cpu"))
    ax[i][0].quiver(
        x_0s[:, 0], x_0s[:, 1], x_0s[:, 2],
        n_0s[:, 0], n_0s[:, 1], n_0s[:, 2],
        # length=0.1, linewidths=0.05
    )
    ax[i][2].quiver(
        x_1s[:, 0], x_1s[:, 1], x_1s[:, 2],
        n_1s[:, 0], n_1s[:, 1], n_1s[:, 2],
        # length=0.1, linewidths=0.05
    )

    k = 0
    for j in range(N_steps + 1):
        t = j * Δt
        # Draw the current state (time t) BEFORE stepping, so the snapshots
        # sit at t = 0, N_skip * Δt, ..., 1 as the colorbar indicates.
        if j % N_skip == 0:
            # The endpoint snapshots are drawn opaque with all samples. The
            # intermediate ones are translucent and show a subset to limit clutter.
            is_endpoint = j in (0, N_steps)
            alpha = 1 if is_endpoint else 0.5
            N_samples_shown = N_samples if is_endpoint else N_samples // 4
            x_ts, n_ts = m3.get_position_orientation(p_ts_test[test].to("cpu"))
            ax[i][1].quiver(
                x_ts[:N_samples_shown, 0], x_ts[:N_samples_shown, 1], x_ts[:N_samples_shown, 2],
                n_ts[:N_samples_shown, 0], n_ts[:N_samples_shown, 1], n_ts[:N_samples_shown, 2],
                color=colors[k], alpha=alpha
            )
            k += 1
        # Integrate up to t = 1 only; the last iteration just draws.
        if j < N_steps:
            with torch.no_grad():
                # Same `step` call convention as the animation code.
                p_ts_test[test] = models_M3[test].step(
                    p_ts_test[test], torch.full((1, 1, 1), t, device=device), Δt
                )
fig.colorbar(ScalarMappable(cmap=cmap), cax=cax, ticks=np.linspace(0, 1, N_show), label="$t$")
os.makedirs("output", exist_ok=True)
fig.savefig("output/interpolation_M3.pdf", bbox_inches="tight")