In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib.animation import FFMpegWriter

from lieflow.groups import Rn
from lieflow.models import (
    get_model_FM,
    LogarithmicDistance
)
# Render all figure text with LaTeX in a Computer Modern serif font. The preamble loads
# lmodern plus the AMS packages that provide symbols such as \mathfrak used in plot titles.
plt.rc("text", usetex=True)
plt.rc("font", family="serif", serif=["Computer Modern"], size=10.0)
plt.rc("text.latex", preamble=r"\usepackage{lmodern} \usepackage{amssymb} \usepackage{amsmath}")
from tqdm.notebook import tqdm

In [None]:
# Number of points in every sampled point cloud.
M = 10
# The group R^2 from lieflow; it is what the flow-matching model is built over.
r2 = Rn(2)

In [None]:
generate_videos: bool = True  # set to False to skip writing the mp4 animations

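# Distributions

Source $\mathfrak{X}_0$: $M$ i.i.d. standard-normal points in $\mathbb{R}^2$. Target $\mathfrak{X}_1$: $M$ points equally spaced on the unit circle, rotated by a uniformly random angle and jittered by Gaussian noise of scale $\varepsilon$.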

In [None]:
def generate_normals(N, m):
    """Sample N point clouds of m i.i.d. standard-normal points in R^2.

    Draws from numpy's global RNG and returns a ``torch.Tensor`` of shape (N, m, 2).
    """
    samples = np.random.randn(N, m, 2)
    return torch.Tensor(samples)

def generate_uniforms_on_circle(N, m, ε=0.04):
    """Sample N point clouds of m points equally spaced on the unit circle, plus jitter.

    Each cloud is rotated by its own uniform random angle in [0, 2π) drawn from numpy's
    global RNG; consecutive points within a cloud are 2π/m apart. Gaussian noise scaled
    by ε (torch's global RNG) is then added to every coordinate.

    Returns a ``torch.Tensor`` of shape (N, m, 2).
    """
    offsets = np.random.rand(N) * 2 * np.pi                  # (N,) rotation per cloud
    spacing = 2 * np.pi/m
    phase_grid = np.arange(m)[:, None] * spacing + offsets    # (m, N)
    phases = torch.Tensor(phase_grid).unsqueeze(-1)           # (m, N, 1)
    on_circle = torch.cat((phases.cos(), phases.sin()), dim=-1).permute(1, 0, 2)  # (N, m, 2)
    return on_circle + ε * torch.randn(on_circle.shape)


In [None]:
# Three target point clouds (M points each) for a quick visual sanity check.
t = generate_uniforms_on_circle(N=3, m=M)

In [None]:
# Overlay the sampled clouds; matplotlib's colour cycle gives each cloud its own colour.
fig, ax = plt.subplots(figsize=(5, 5))
for x in t:
    ax.scatter(x[:, 0], x[:, 1])
ax.set(xlim=(-3, 3), ylim=(-3, 3));

In [None]:
# Transport problems to train and evaluate; each name must be a case in `data_generator`.
tests = (
    "normals_to_circle",
)

In [None]:
def data_generator(test, ε=0.04):
    match test:
        case "normals_to_circle":
            generate_x_0 = lambda n: generate_normals(n, M)
            generate_x_1 = lambda n: generate_uniforms_on_circle(n, M, ε=ε)
    return generate_x_0, generate_x_1

# Models

## Training

In [None]:
# Data / training / model hyper-parameters.
EPSILON = 0.03                 # scale of the Gaussian jitter on the target circle points
N = 2**12                      # number of training pairs generated per test
BATCH_SIZE = 2**8
EPOCHS = len(tests) * [100]    # one epoch budget per entry of `tests`
WEIGHT_DECAY = 0.
LEARNING_RATE = 5e-4
EMBED_DIM = 64                 # forwarded to get_model_FM
NUM_HEADS = 8                  # forwarded to get_model_FM
EXPANSION = 4                  # forwarded to get_model_FM
L = 4                          # forwarded to get_model_FM as L (presumably depth — TODO confirm)
# Fall back to CPU so the notebook still runs on machines without a CUDA device.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Seed both RNGs the samplers use (numpy for angles/normals, torch for jitter and
# training) so training data and results are reproducible across runs.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
def train_model(x_0s, x_1s, epochs, batch_size=BATCH_SIZE, learning_rate=LEARNING_RATE, weight_decay=WEIGHT_DECAY):
    """Train a flow-matching model on R^2 point clouds that maps ``x_0s`` to ``x_1s``.

    Parameters
    ----------
    x_0s, x_1s : torch.Tensor
        Source and target samples. ``TensorDataset`` pairs them row by row and the
        loader shuffles the pairs together (presumably shape (N, M, 2) as produced by
        the samplers — TODO confirm).
    epochs : int
        Number of passes over the data.
    batch_size, learning_rate, weight_decay
        DataLoader batch size and Adam hyper-parameters.

    Returns
    -------
    model_FM
        The trained model, on ``device``.
    losses_FM : np.ndarray
        One value per epoch, as returned by ``model_FM.train_network``.
    """
    train_loader = torch.utils.data.DataLoader(
        torch.utils.data.TensorDataset(x_0s, x_1s), batch_size=batch_size, shuffle=True
    )

    model_FM = get_model_FM(r2, L=L, power_group=True, embed_dim=EMBED_DIM, num_heads=NUM_HEADS, expansion=EXPANSION).to(device)
    print("Number of parameters: ", model_FM.parameter_count)
    optimizer_FM = torch.optim.Adam(model_FM.parameters(), learning_rate, weight_decay=weight_decay)
    loss = LogarithmicDistance(torch.Tensor([1., 1.]).to(device))

    losses_FM = np.zeros(epochs)
    for i in tqdm(range(epochs)):
        losses_FM[i] = model_FM.train_network(device, train_loader, optimizer_FM, loss)

    fig, ax = plt.subplots(1, 1, figsize=(5, 5))
    # Plot against 1-based epoch numbers: on a log x-axis, x = 0 is nonpositive and the
    # first epoch's loss would be clipped off the plot.
    ax.plot(np.arange(1, epochs + 1), losses_FM)
    ax.set_title("Batch Loss Flow Matching")
    ax.set_xlabel("epoch")
    ax.set_ylabel("loss")
    ax.set_xscale("log")
    ax.set_yscale("log")

    return model_FM, losses_FM

In [None]:
# Train one model per test, previewing the first 32 training samples of each side.
models_FM = {}
losses_FM = {}
for i, test in enumerate(tests):
    print(test)
    generate_x_0, generate_x_1 = data_generator(test, ε=EPSILON)
    x_0s, x_1s = generate_x_0(N), generate_x_1(N)

    fig, ax = plt.subplots(1, 2, figsize=(10, 5))
    panels = (
        (ax[0], x_0s, r"$\mathfrak{X}_0$"),
        (ax[1], x_1s, r"$\mathfrak{X}_1$"),
    )
    for panel, samples, label in panels:
        preview = samples[:32]
        for j in range(M):
            panel.scatter(preview[:, j, 0], preview[:, j, 1])
        panel.set_xlim(-3, 3)
        panel.set_ylim(-3, 3)
        panel.set_title(label)

    models_FM[test], losses_FM[test] = train_model(
        x_0s,
        x_1s,
        epochs=EPOCHS[i],
        batch_size=BATCH_SIZE,
        learning_rate=LEARNING_RATE,
        weight_decay=WEIGHT_DECAY,
    )

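# Testing

Integrate each trained flow forward (from $\mathfrak{X}_0$) and backward (from $\mathfrak{X}_1$), optionally rendering both as videos. Then compare a forward-integrated sample against a target sample.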

In [None]:
def _flow_colors(m):
    """Return one RGB colour per point index, ramping from blue (index 0) to red (index m - 1)."""
    Δc = 1 / max(m - 1, 1)  # max() avoids ZeroDivisionError when m == 1
    return [(j * Δc, 0.1, 1 - j * Δc) for j in range(m)]


def _draw_point_cloud(ax, xs, colors, t):
    """Redraw `ax` with the point cloud `xs` (one scatter per point index) titled with time `t`."""
    xs_cpu = xs.to("cpu")
    ax.clear()
    for i, color in enumerate(colors):
        ax.scatter(xs_cpu[:, i, 0], xs_cpu[:, i, 1], color=color)
    ax.set_title(fr"$\mathfrak{{X}}_{{{t:.2f}}}$")
    ax.set_xlim(-3, 3)
    ax.set_ylim(-3, 3)


def _render_flow(tests, models_FM, x_starts, out_path, backward, n_steps, fps, dpi):
    """Integrate every model for `n_steps` steps and write one video frame per step.

    Forward integration calls ``model.step`` starting at t = 0; backward integration calls
    ``model.step_back`` starting at t = 1. Assumes ``step(x, t, Δt)`` advances x from t to
    t + Δt and ``step_back`` from t to t - Δt (TODO confirm against lieflow). Writes
    n_steps + 1 frames: the initial samples, then the state after each step.
    """
    Δt = 1. / n_steps
    metadata = {'title': 'Flow Matching R^2', 'artist': 'Matplotlib'}
    writer = FFMpegWriter(fps=fps, metadata=metadata)

    n_models = len(tests)
    # squeeze=False keeps `axes` 2-D so each test gets its own panel for any len(tests).
    fig, axes = plt.subplots(1, n_models, figsize=(5 * n_models, 5), squeeze=False)
    colors = _flow_colors(M)
    x_ts = {test: x_starts[test].detach().clone() for test in tests}
    t_start = 1. if backward else 0.

    with writer.saving(fig, out_path, dpi=dpi):
        # First frame: the unintegrated samples at the starting time.
        for ax, test in zip(axes[0], tests):
            _draw_point_cloud(ax, x_ts[test], colors, t_start)
        writer.grab_frame()

        for frame in tqdm(range(n_steps)):
            # Time handed to the model: the start of the current step.
            t = 1. - frame * Δt if backward else frame * Δt
            # Time reached after this step; from the integer count so the last frame is exactly 0 or 1.
            t_after = 1. - (frame + 1) / n_steps if backward else (frame + 1) / n_steps
            for ax, test in zip(axes[0], tests):
                model = models_FM[test]
                with torch.no_grad():
                    if backward:
                        x_ts[test] = model.step_back(x_ts[test], torch.Tensor([t]).to(device), Δt)
                    else:
                        x_ts[test] = model.step(x_ts[test], torch.Tensor([t]).to(device), Δt)
                _draw_point_cloud(ax, x_ts[test], colors, t_after)
            writer.grab_frame()
    # The video is the artifact; close the figure so it is not kept open or re-displayed.
    plt.close(fig)


def create_animations(tests, models_FM, N=2**5, n_steps=120, fps=30, dpi=150, output_dir="output"):
    """Render forward and backward flow-matching videos, one panel per test.

    Forward: ``N`` fresh source samples per test are integrated from t = 0 to t = 1 and
    written to ``<output_dir>/flow_matching_R2_n.mp4``. Backward: ``N`` fresh target
    samples per test are integrated from t = 1 to t = 0 and written to
    ``<output_dir>/flow_matching_R2_n_backwards.mp4``. Requires ffmpeg (FFMpegWriter).

    Parameters
    ----------
    tests : sequence of str
        Problem names understood by ``data_generator``.
    models_FM : dict
        Maps each test name to a trained model exposing ``eval``, ``step`` and ``step_back``.
    N : int
        Number of samples per test.
    n_steps, fps, dpi : int
        Integration steps per video, frame rate and output resolution.
    output_dir : str
        Directory for the videos; created if missing.
    """
    os.makedirs(output_dir, exist_ok=True)

    # Forward flow
    x_0s_test = {}
    for test in tests:
        generate_x_0, _ = data_generator(test, ε=EPSILON)
        x_0s_test[test] = generate_x_0(N).to(device)
        models_FM[test].eval()
    _render_flow(
        tests, models_FM, x_0s_test,
        os.path.join(output_dir, "flow_matching_R2_n.mp4"),
        backward=False, n_steps=n_steps, fps=fps, dpi=dpi,
    )

    # Backward flow (target samples drawn after the forward video, as before)
    x_1s_test = {}
    for test in tests:
        _, generate_x_1 = data_generator(test, ε=EPSILON)
        x_1s_test[test] = generate_x_1(N).to(device)
    _render_flow(
        tests, models_FM, x_1s_test,
        os.path.join(output_dir, "flow_matching_R2_n_backwards.mp4"),
        backward=True, n_steps=n_steps, fps=fps, dpi=dpi,
    )

In [None]:
# Video rendering is gated by the `generate_videos` flag; one sample per test is animated.
if generate_videos:
    create_animations(tests=tests, models_FM=models_FM, N=1)

In [None]:
# For each test, integrate one fresh source sample all the way to t = 1 and plot it next to
# an independent target sample. The two are not paired; the comparison is visual only.
N_steps = 120
Δt = 1. / N_steps
Δc = 1 / max(M - 1, 1)  # max() avoids ZeroDivisionError when M == 1
colors = [(j * Δc, 0.1, 1 - j * Δc) for j in range(M)]

for test in tests:
    # Generators and model both come from this loop's `test`, so every figure matches its test.
    generate_x_0, generate_x_1 = data_generator(test, ε=EPSILON)
    x_0s_test = generate_x_0(1).to(device)
    x_ts_test = x_0s_test.detach().clone()
    x_1s_test = generate_x_1(1).to("cpu")

    model_FM = models_FM[test]
    model_FM.eval()

    # Exactly one model step per time step.
    with torch.no_grad():
        for frame in tqdm(range(N_steps)):
            t = frame * Δt
            x_ts_test = model_FM.step(x_ts_test, torch.Tensor([t]).to(device), Δt)
    # Time reached after the final step (assumes step() advances t -> t + Δt — TODO confirm).
    t = N_steps * Δt
    x_ts_test_plot = x_ts_test.to("cpu")

    fig, ax = plt.subplots(1, 2, figsize=(5 * 2, 5))
    for i in range(M):
        ax[0].scatter(
            x_ts_test_plot[:, i, 0], x_ts_test_plot[:, i, 1], color=colors[i]
        )
        ax[1].scatter(
            x_1s_test[:, i, 0], x_1s_test[:, i, 1], color=colors[i]
        )
    ax[0].set_title(fr"$\mathfrak{{X}}_{{{t:.2f}}}$")
    ax[1].set_title(r"$\mathfrak{X}_{1}$")
    for axis in ax:
        axis.set_xlim(-3, 3)
        axis.set_ylim(-3, 3)