# Flow Matching on $\mathbb{R}^2$
In this notebook, we perform flow matching on the two-dimensional translation group $\mathbb{R}^2$, which is isometric to two-dimensional Euclidean space.

In [None]:
import torch
import numpy as np
from lieflow.groups import Rn
from lieflow.models import (
    get_model_FM,
    LogarithmicDistance
)
import matplotlib.pyplot as plt
from matplotlib.animation import FFMpegWriter
import matplotlib.colors as mcolors
import matplotlib.gridspec as gridspec
from matplotlib.cm import ScalarMappable
# Typeset all figure text with LaTeX (serif, Computer Modern, 10 pt). The preamble
# loads lmodern, amssymb and amsmath; the figure titles in this notebook use \mathfrak.
plt.rcParams.update({
    "text.usetex": True,
    "font.family": "serif",
    "font.serif": ["Computer Modern"],
    "font.size": 10.0,
    "text.latex.preamble": r"\usepackage{lmodern} \usepackage{amssymb} \usepackage{amsmath}"
})
from tqdm.notebook import tqdm

In [None]:
# Group object passed to get_model_FM when building the models; per the notebook
# introduction, Rn(2) stands for the translation group R^2.
r2 = Rn(2)

In [None]:
# Switch for the MP4 animations: create_animations is only called when this is True.
generate_videos = True

# Distributions

In [None]:
def generate_normals(N):
    """Draw `N` points from the standard normal distribution on the plane.

    The draws come from NumPy's global random state (so `np.random.seed`
    controls them) and are returned as a torch tensor of shape `(N, 2)` in
    torch's default floating-point dtype.
    """
    gaussian_draws = np.random.randn(N, 2)
    return torch.as_tensor(gaussian_draws, dtype=torch.get_default_dtype())

def generate_uniforms_on_circle(N, ε=0.04):
    """Sample noisy points around the unit circle.

    Standard normal draws are projected radially onto the unit circle, after
    which Gaussian noise scaled by `ε` is added to every coordinate. Draws of
    exactly zero length cannot be projected and are dropped, so the result has
    shape `(M, 2)` with `M <= N`.
    """
    gaussians = generate_normals(N)
    radii = torch.sqrt(gaussians.pow(2).sum(dim=-1))
    keep = radii > 0.
    on_circle = gaussians[keep] / radii[keep].unsqueeze(-1)
    jitter = ε * torch.randn(on_circle.shape)
    return on_circle + jitter

In [None]:
# Names of the experiments to run; each name is passed to data_generator to
# obtain the source and target samplers.
tests = ("normals_to_circle",)

In [None]:
def data_generator(test, ε=0.04):
    """Return the pair of samplers `(generate_x_0, generate_x_1)` for an experiment.

    Args:
        test: Name of the experiment. Only "normals_to_circle" is supported:
            source samples come from `generate_normals` and target samples
            from `generate_uniforms_on_circle`.
        ε: Noise scale forwarded to `generate_uniforms_on_circle`.

    Returns:
        Two callables, each mapping a sample count `n` to a tensor of samples.

    Raises:
        ValueError: If `test` is not a known experiment name. Without this
            check an unknown name would only fail at the return statement,
            with an UnboundLocalError that does not mention the bad name.
    """
    if test == "normals_to_circle":
        # The underlying samplers are resolved when the returned callables
        # are invoked, not when this function is called.
        def generate_x_0(n):
            return generate_normals(n)

        def generate_x_1(n):
            return generate_uniforms_on_circle(n, ε=ε)

        return generate_x_0, generate_x_1

    raise ValueError(f"Unknown test {test!r}; expected 'normals_to_circle'.")

# Models

## Training

In [None]:
# Noise scale ε handed to data_generator (scales the Gaussian jitter added to the circle samples).
EPSILON = 0.03
# Number of samples requested from each sampler for training.
N = 2**14
# Default mini-batch size of the DataLoader built in train_model.
BATCH_SIZE = 2**10
# Number of training epochs, one entry per experiment in `tests`.
EPOCHS = len(tests) * [20]
# Adam optimizer settings used by train_model.
WEIGHT_DECAY = 0.
LEARNING_RATE = 1e-2
# Network size arguments forwarded to get_model_FM.
H = 64 # Width
L = 3 # Number of layers is L + 2
# Device the models and test samples are moved to.
device = "cpu"

In [None]:
def train_model(x_0s, x_1s, epochs, batch_size=BATCH_SIZE, learning_rate=LEARNING_RATE, weight_decay=WEIGHT_DECAY):
    """Train a flow matching model on paired samples and plot its loss curve.

    Args:
        x_0s: Tensor of source samples; paired row by row with `x_1s` in a
            `TensorDataset`, so both must have the same number of rows.
        x_1s: Tensor of target samples.
        epochs: Number of passes over the data (calls to `train_network`).
        batch_size: Mini-batch size of the shuffled `DataLoader`.
        learning_rate: Learning rate of the Adam optimizer.
        weight_decay: Weight decay of the Adam optimizer.

    Returns:
        The trained model created by `get_model_FM(r2, H=H, L=L)`.

    Side effects:
        Prints the model's parameter count and creates a matplotlib figure
        with the per-epoch loss on log-log axes.
    """
    train_loader = torch.utils.data.DataLoader(
        torch.utils.data.TensorDataset(x_0s, x_1s), batch_size=batch_size, shuffle=True
    )

    model_FM = get_model_FM(r2, H=H, L=L).to(device)
    print("Number of parameters: ", model_FM.parameter_count)
    optimizer_FM = torch.optim.Adam(model_FM.parameters(), lr=learning_rate, weight_decay=weight_decay)
    # NOTE(review): the tensor is presumably a per-coordinate weight for the
    # distance -- confirm against lieflow's LogarithmicDistance.
    loss = LogarithmicDistance(torch.Tensor([1., 1.]))

    # One entry per epoch: whatever train_network returns for that epoch.
    losses_FM = np.zeros(epochs)
    for i in tqdm(range(epochs)):
        losses_FM[i] = model_FM.train_network(device, train_loader, optimizer_FM, loss)

    fig, ax = plt.subplots(1, 1, figsize=(5, 5))
    # Number the epochs from 1: on a logarithmic x-axis a point at x = 0 has
    # no valid position, so 0-based indices would hide the first epoch.
    ax.plot(np.arange(1, epochs + 1), losses_FM)
    ax.set_title("Batch Loss Flow Matching")
    ax.set_xlabel("Epoch")
    ax.set_xscale("log")
    ax.set_yscale("log")

    return model_FM

In [None]:
# Train one flow matching model per experiment, after previewing its data.
models_FM = {}
for i, test in enumerate(tests):
    print(test)
    generate_x_0, generate_x_1 = data_generator(test, ε=EPSILON)

    x_0s = generate_x_0(N)
    x_1s = generate_x_1(N)

    # Side-by-side preview of the first 32 source (left) and target (right) samples.
    fig, ax = plt.subplots(1, 2, figsize=(10, 5))
    previews = (
        (x_0s, r"$\mathfrak{X}_0$"),
        (x_1s, r"$\mathfrak{X}_1$"),
    )
    for panel, (samples, title) in zip(ax, previews):
        head = samples[:32]
        panel.scatter(head[:, 0], head[:, 1])
        panel.set_xlim(-3, 3)
        panel.set_ylim(-3, 3)
        panel.set_title(title)

    models_FM[test] = train_model(
        x_0s,
        x_1s,
        epochs=EPOCHS[i],
        batch_size=BATCH_SIZE,
        learning_rate=LEARNING_RATE,
        weight_decay=WEIGHT_DECAY,
    )

# Testing

In [None]:
def create_animations(tests, models_FM, N=2**5):
    """Render MP4 animations of the learned forward and backward flows.

    Two videos are written with matplotlib's `FFMpegWriter`:
    `output/flow_matching_R2.mp4` starts from fresh source samples and applies
    each model's `step`, while `output/flow_matching_R2_backwards.mp4` starts
    from fresh target samples and applies `step_back`. Each experiment gets
    its own panel, and each frame is titled with the time of the samples it
    shows. The `output` directory is created if it does not exist.

    Args:
        tests: Sequence of experiment names understood by `data_generator`.
        models_FM: Mapping from experiment name to its trained model.
        N: Number of samples animated per experiment.
    """
    if not tests:
        # Nothing to draw; plt.subplots cannot build a grid with zero columns.
        return

    # Forward flow
    x_0s_test = {}
    for test in tests:
        generate_x_0, _ = data_generator(test, ε=EPSILON)
        x_0s_test[test] = generate_x_0(N).to(device)
    _animate_flow(tests, models_FM, x_0s_test, "output/flow_matching_R2.mp4", backward=False)

    # Backward flow
    x_1s_test = {}
    for test in tests:
        _, generate_x_1 = data_generator(test, ε=EPSILON)
        x_1s_test[test] = generate_x_1(N).to(device)
    _animate_flow(tests, models_FM, x_1s_test, "output/flow_matching_R2_backwards.mp4", backward=True)


def _draw_frame(axis, samples, t):
    """Redraw one panel: scatter the 2D `samples` and title it with the time `t`."""
    points = samples.to("cpu")
    axis.clear()
    axis.scatter(points[:, 0], points[:, 1])
    axis.set_title(fr"$\mathfrak{{X}}_{{{t:.2f}}}$")
    axis.set_xlim(-3, 3)
    axis.set_ylim(-3, 3)


def _animate_flow(tests, models_FM, x_starts, path, backward, N_steps=120):
    """Step every model `N_steps` times from `x_starts` and record the frames at `path`."""
    from pathlib import Path

    Δt = 1. / N_steps
    N_models = len(tests)
    for test in tests:
        models_FM[test].eval()
    # Work on copies so the caller's starting samples stay untouched.
    x_ts_test = {test: x_starts[test].detach().clone() for test in tests}

    # The writer cannot create missing directories itself.
    Path(path).parent.mkdir(parents=True, exist_ok=True)
    metadata = {'title': 'Flow Matching R^2', 'artist': 'Matplotlib'}
    writer = FFMpegWriter(fps=30, metadata=metadata)

    # squeeze=False always yields a 2D grid of axes, so every experiment gets
    # its own panel regardless of how many experiments there are.
    fig, axes = plt.subplots(1, N_models, figsize=(5 * N_models, 5), squeeze=False)

    with writer.saving(fig, path, dpi=150):
        # Opening frame: the untouched starting samples, at t = 1 when
        # running backwards and t = 0 otherwise.
        t_start = 1. if backward else 0.
        for i, test in enumerate(tests):
            _draw_frame(axes[0, i], x_ts_test[test], t_start)
        writer.grab_frame()

        for frame in tqdm(range(N_steps)):
            if backward:
                t = 1. - frame * Δt
                # Time reached after the step; integer arithmetic keeps the
                # final label at exactly 0.
                t_reached = (N_steps - frame - 1) / N_steps
            else:
                t = frame * Δt
                t_reached = (frame + 1) / N_steps
            for i, test in enumerate(tests):
                model = models_FM[test]
                advance = model.step_back if backward else model.step
                with torch.no_grad():
                    x_ts_test[test] = advance(x_ts_test[test], torch.Tensor([t]), Δt)
                _draw_frame(axes[0, i], x_ts_test[test], t_reached)
            writer.grab_frame()
        # Hold the final state for one extra frame.
        writer.grab_frame()

In [None]:
# Rendering the animations is optional; it is skipped unless generate_videos is True.
if generate_videos:
    create_animations(tests, models_FM, N=2**5)

In [None]:
# Static figure: per experiment, the source samples (left), snapshots of the
# learned flow at N_show evenly spaced times (middle) and the target samples (right).
t = 0
N_steps = 240
Δt = 1. / N_steps
N_show = 5
# Number of integration steps between two consecutive snapshots.
N_skip = int(N_steps / (N_show-1))

N_models = len(tests)
N_samples = 32

x_0s_test = {}
x_1s_test = {}
x_ts_test = {}
for test in tests:
    generate_x_0, generate_x_1 = data_generator(test, ε=EPSILON)
    x_0s_test[test] = generate_x_0(N_samples).to(device)
    x_1s_test[test] = generate_x_1(N_samples).to(device)
    x_ts_test[test] = x_0s_test[test].detach().clone()

    models_FM[test].eval()

fig = plt.figure(figsize=(4.8, 1.6 * N_models * 3/3.1))
# Three sample panels per row plus a narrow fourth column for the shared colorbar.
gs = gridspec.GridSpec(N_models, 4, width_ratios=[1, 1, 1, 0.1], height_ratios=N_models * [1.], wspace=0.1, hspace=0.1)
cax = fig.add_subplot(gs[:, 3])

ax = []
for i in range(N_models):
    ax.append([])
    for j in range(3):
        a = fig.add_subplot(gs[i, j])
        a.set_xlim(-3, 3)
        a.set_ylim(-3, 3)
        a.set_xticks([])
        a.set_yticks([])
        a.set_aspect("equal")
        ax[i].append(a)

ax[0][0].set_title(r"$\mathfrak{X}_0$")
ax[0][1].set_title(r"$\mathfrak{X}_t$")
ax[0][2].set_title(r"$\mathfrak{X}_1$")
# One colour per snapshot, blending from blue (first) to red (last).
Δc = 1 / (N_show - 1)
colors = [(j * Δc, 0.1, 1 - j * Δc) for j in range(N_show)]
cmap = mcolors.ListedColormap(colors)
for i, test in enumerate(tests):
    k = 0

    # Move the samples to the CPU before plotting, as is done for the flow snapshots.
    x_0s_test_plot = x_0s_test[test].to("cpu")
    x_1s_test_plot = x_1s_test[test].to("cpu")
    ax[i][0].scatter(
        x_0s_test_plot[:N_samples, 0], x_0s_test_plot[:N_samples, 1], marker="."
    )
    ax[i][2].scatter(
        x_1s_test_plot[:N_samples, 0], x_1s_test_plot[:N_samples, 1], marker="."
    )

    for j in range(N_steps+1):
        t = j * Δt

        # Plot BEFORE stepping so that the snapshot taken at index j shows the
        # samples at time t = j * Δt (the first snapshot is exactly X_0).
        if j % N_skip == 0:
            # The first and last snapshots are drawn in full; intermediate
            # ones are faded and thinned to keep the panel readable.
            is_endpoint = j == 0 or j == N_steps
            alpha = 1 if is_endpoint else 0.5
            N_samples_shown = N_samples if is_endpoint else N_samples // 4
            x_ts_test_plot = x_ts_test[test].to("cpu")
            im = ax[i][1].scatter(
                x_ts_test_plot[:N_samples_shown, 0], x_ts_test_plot[:N_samples_shown, 1],
                color=colors[k], marker=".", alpha=alpha
            )
            k += 1

        # Take exactly N_steps steps: no step is taken from t = 1, which
        # would push the samples past the end of the flow.
        if j < N_steps:
            with torch.no_grad():
                x_ts_test[test] = models_FM[test].step(x_ts_test[test], torch.Tensor([t]), torch.Tensor([Δt]))
fig.colorbar(ScalarMappable(cmap=cmap), cax=cax, ticks=np.linspace(0, 1, N_show), label="$t$");
fig.savefig(f"output/interpolation_R2.pdf", bbox_inches="tight")