In [None]:
import os

import numpy as np
import torch
import matplotlib.pyplot as plt
from matplotlib.animation import FFMpegWriter
import matplotlib.colors as mcolors
import matplotlib.gridspec as gridspec
from matplotlib.cm import ScalarMappable

from lieflow.groups import SE2
from lieflow.models import (
    get_model_SCFM,
    LogarithmicDistance
)
# Render all figure text through LaTeX: serif Computer Modern at 10 pt, with the
# lmodern/amssymb/amsmath packages loaded (needed for e.g. \mathfrak in titles).
plt.rc("text", usetex=True)
plt.rc("font", family="serif", serif=["Computer Modern"], size=10.0)
plt.rcParams["text.latex.preamble"] = r"\usepackage{lmodern} \usepackage{amssymb} \usepackage{amsmath}"
from tqdm.notebook import tqdm

In [None]:
# Shared SE(2) group instance; every model is built from it via `get_model_SCFM(se2)`.
se2 = SE2()

In [None]:
# When True, the Testing section renders an MP4 animation with `create_animations`
# (uses matplotlib's FFMpegWriter). Set to False to skip it.
generate_videos = True

# Distributions

In [None]:
def generate_normals(N, ε=0.):
    """Sample N poses in SE(2) with standard-normal translations and uniform rotations.

    Parameters
    ----------
    N : int
        Number of samples.
    ε : float, optional
        Standard deviation of additional Gaussian noise added to every coordinate
        (x, y, θ), matching the `ε` argument of the other generators. The default 0
        adds no noise (and draws no extra random numbers), preserving the original
        behaviour of ``generate_normals(N)``.

    Returns
    -------
    torch.Tensor
        Float32 tensor of shape (N, 3) with rows (x, y, θ); before noise, x and y are
        N(0, 1) and θ is uniform on [0, 2π).
    """
    translations = np.random.randn(N, 2)
    rotations = 2 * np.pi * np.random.rand(N)
    g = torch.Tensor(np.hstack((translations, rotations[..., None])))
    # `data_generator("normal_to_circle", ...)` passes ε; accepting it here fixes the
    # TypeError that case used to raise.
    if ε != 0:
        g = g + ε * torch.randn(g.shape)
    return g

def generate_uniforms_on_circle(N, centre=np.array((0., 0.)), ε=0.05):
    """Sample N noisy poses on the unit circle around `centre`, each heading radially outward.

    A uniform angle θ ∈ [0, 2π) is drawn per sample. The noiseless pose is
    (centre[0] + cos θ, centre[1] + sin θ, θ). I.i.d. Gaussian noise with standard
    deviation ε is then added to all three coordinates.

    Returns a tensor of shape (N, 3) with rows (x, y, θ).
    """
    θ = torch.Tensor(np.random.rand(N) * 2 * torch.pi)[:, None]
    x = torch.cos(θ) + centre[0]
    y = torch.sin(θ) + centre[1]
    clean = torch.hstack((x, y, θ))
    return clean + ε * torch.randn(clean.shape)

def generate_line(N, d=0., w=1., horizontal=True, ε=0.05):
    """Sample N noisy poses spread uniformly along a straight segment.

    horizontal=True : x ~ U[-w, w), y = d, heading θ = π/2 (perpendicular to the line).
    horizontal=False: x = d, y ~ U[-w, w), heading θ = 0 (perpendicular to the line).
    Gaussian noise with standard deviation ε is then added to every coordinate.

    Returns a float32 tensor of shape (N, 3) with rows (x, y, θ).
    """
    along = 2 * w * (np.random.rand(N) - 0.5)
    offset = np.ones(N) * d
    if not horizontal:
        columns = (offset, along, np.zeros(N))
    else:
        columns = (along, offset, np.ones(N) * np.pi / 2.)
    clean = torch.Tensor(np.hstack([c[:, None] for c in columns]))
    return clean + ε * torch.randn(clean.shape)

In [None]:
# Transport problems trained and evaluated below. `data_generator` also defines
# "normal_to_circle" and "vertical_line_to_vertical_line", which can be added here.
tests = (
    "horizontal_line_to_vertical_line",
    "vertical_line_to_circle",
)

In [None]:
def data_generator(test, ε=0.04):
    match test:
        case "horizontal_line_to_vertical_line":
            generate_g_0 = lambda n: generate_line(n, d=-2.5, w=1.5, horizontal=True, ε=ε)
            generate_g_1 = lambda n: generate_line(n, d=2.5, w=1.5, horizontal=False, ε=ε)
        case "vertical_line_to_vertical_line":
            generate_g_0 = lambda n: generate_line(n, d=-2.5, w=1.5, horizontal=False, ε=ε)
            generate_g_1 = lambda n: generate_line(n, d=2.5, w=1.5, horizontal=False, ε=ε) 
        case "normal_to_circle":
            generate_g_0 = lambda n: generate_normals(n, ε=ε)
            generate_g_1 = lambda n: generate_uniforms_on_circle(n, ε=ε)
        case "vertical_line_to_circle":
            generate_g_0 = lambda n: generate_line(n, d=-2.5, w=2., horizontal=False, ε=ε)
            generate_g_1 = lambda n: generate_uniforms_on_circle(n, centre=np.array((1.25, 0.)), ε=ε)
    return generate_g_0, generate_g_1

# Models

## Training

In [None]:
# Fixed seed so Restart & Run All reproduces the sampled data, the model
# initialisation and the figures (the generators use numpy's and torch's global RNGs).
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

EPSILON = 0.03  # noise level passed to `data_generator` for training and test data
N = 2**14  # number of (g_0, g_1) training pairs per test
BATCH_SIZE = 2**10
EPOCHS = len(tests) * [50]  # epochs per test, indexed like `tests`
WEIGHT_DECAY = 0.
LEARNING_RATE = 1e-2
# NOTE(review): H and L are not passed to `get_model_SCFM(se2)` in this notebook, so
# they do not appear to affect the model — confirm before tuning them.
H = 64 # Width
L = 3 # Number of layers is L + 2
device = "cpu"

In [None]:
def train_model_SCFM(g_0s, g_1s, epochs, batch_size=BATCH_SIZE, learning_rate=LEARNING_RATE, weight_decay=WEIGHT_DECAY):
    """Train a shortcut flow-matching model on SE(2) mapping samples `g_0s` to `g_1s`.

    Pairs (g_0s[i], g_1s[i]) are shuffled into mini-batches. The model is trained with
    Adam under a `LogarithmicDistance` loss with unit weights. The value returned by
    `train_network` each epoch is plotted on log-log axes.

    Parameters
    ----------
    g_0s, g_1s : torch.Tensor
        Source and target samples with equal first dimension (required by TensorDataset).
    epochs : int
        Number of passes over the data.
    batch_size, learning_rate, weight_decay
        DataLoader / Adam settings; default to the notebook constants.

    Returns
    -------
    The trained model (on `device`).
    """
    train_loader = torch.utils.data.DataLoader(
        torch.utils.data.TensorDataset(g_0s, g_1s), batch_size=batch_size, shuffle=True
    )

    model_SCFM = get_model_SCFM(se2).to(device)
    print("Number of parameters: ", model_SCFM.parameter_count)
    optimizer_SCFM = torch.optim.Adam(model_SCFM.parameters(), learning_rate, weight_decay=weight_decay)
    loss_fn = LogarithmicDistance(torch.Tensor([1., 1., 1.]))

    losses_SCFM = np.zeros(epochs)
    for epoch in tqdm(range(epochs)):
        losses_SCFM[epoch] = model_SCFM.train_network(device, train_loader, optimizer_SCFM, loss_fn)

    fig, ax = plt.subplots(1, 1, figsize=(5, 5))
    # Number epochs from 1: x = 0 cannot be placed on a log-scaled axis, so plotting
    # against the 0-based index would hide the first epoch's loss.
    ax.plot(np.arange(1, epochs + 1), losses_SCFM)
    ax.set_title("Batch Loss Short Cut Flow Matching")
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Loss")
    ax.set_xscale("log")
    ax.set_yscale("log")

    return model_SCFM

In [None]:
# Train one model per test; before each run, preview 32 samples of X_0 and X_1.
models_SCFM = {}
for i, test in enumerate(tests):
    print(test)
    generate_g_0, generate_g_1 = data_generator(test, ε=EPSILON)
    g_0s, g_1s = generate_g_0(N), generate_g_1(N)

    fig, ax = plt.subplots(1, 2, figsize=(10, 5))
    for panel, samples, label in zip(ax, (g_0s, g_1s), (r"$\mathfrak{X}_0$", r"$\mathfrak{X}_1$")):
        preview = samples[:32]
        panel.quiver(preview[:, 0], preview[:, 1], torch.cos(preview[:, 2]), torch.sin(preview[:, 2]))
        panel.set_xlim(-3, 3)
        panel.set_ylim(-3, 3)
        panel.set_title(label)

    models_SCFM[test] = train_model_SCFM(
        g_0s,
        g_1s,
        epochs=EPOCHS[i],
        batch_size=BATCH_SIZE,
        learning_rate=LEARNING_RATE,
        weight_decay=WEIGHT_DECAY,
    )

## Testing

In [None]:
def create_animations(tests, models_SCFM, N=2**5, filename="output/shortcut_SE2.mp4"):
    """Render an MP4 of each test's model transporting N samples of X_0 towards X_1.

    One panel per test. Every panel advances together in 120 steps of size 1/120,
    written at 30 fps with FFMpegWriter.

    Parameters
    ----------
    tests : sequence of str
        Test names understood by `data_generator`; also keys into `models_SCFM`.
    models_SCFM : dict
        Trained models keyed by test name. Each is put into eval mode.
    N : int
        Number of samples drawn from X_0 per test.
    filename : str
        Output video path. Its parent directory is created if it does not exist.
    """
    N_models = len(tests)

    g_ts_test = {}
    for test in tests:
        generate_g_0, _ = data_generator(test, ε=EPSILON)
        g_ts_test[test] = generate_g_0(N).to(device)
        models_SCFM[test].eval()

    N_steps = 120
    Δt = 1. / N_steps
    metadata = {'title': 'Shortcut Modelling SE(2)', 'artist': 'Matplotlib'}
    writer = FFMpegWriter(fps=30, metadata=metadata)

    # squeeze=False keeps `axs` 2-D even for a single test, so axs[0, i] always works.
    fig, axs = plt.subplots(1, N_models, figsize=(5 * N_models, 5), squeeze=False)

    def draw(ax, g, t):
        """Redraw panel `ax` with poses `g` (rows x, y, θ) labelled as time `t`."""
        g = g.to("cpu")
        ax.clear()
        ax.quiver(g[:, 0], g[:, 1], torch.cos(g[:, 2]), torch.sin(g[:, 2]))
        ax.set_title(fr"$\mathfrak{{X}}_{{{t:.2f}}}$")
        ax.set_xlim(-3, 3)
        ax.set_ylim(-3, 3)

    output_dir = os.path.dirname(filename)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    with writer.saving(fig, filename, dpi=150):
        for i, test in enumerate(tests):
            draw(axs[0, i], g_ts_test[test], 0.)
        writer.grab_frame()
        for frame in tqdm(range(N_steps)):
            t = frame * Δt
            for i, test in enumerate(tests):
                with torch.no_grad():
                    g_ts_test[test] = models_SCFM[test].step(g_ts_test[test], torch.Tensor([t]), torch.Tensor([Δt]))
                # After stepping from t the samples are at time t + Δt; label them so
                # (previously the start time was shown, so the last frame read 0.99).
                draw(axs[0, i], g_ts_test[test], t + Δt)
            writer.grab_frame()
        # Grab the final frame a second time, presumably so the video holds on X_1.
        writer.grab_frame()

In [None]:
# Writing the MP4 (via FFMpegWriter) is optional and controlled by `generate_videos`.
if generate_videos:
    create_animations(tests=tests, models_SCFM=models_SCFM, N=32)

In [None]:
# Snapshot figure. One row per test:
#   - left column: samples of X_0;
#   - middle column: the learned flow at N_show evenly spaced times, coloured by t;
#   - right column: samples of X_1.
# A colour bar maps colour to t.
N_steps = 24
Δt = 1. / N_steps
N_show = 5
# Step indices whose state is drawn, evenly spaced from 0 to N_steps inclusive, so
# t = 0 and t = 1 are always shown (even when N_steps is not divisible by N_show - 1).
show_steps = [round(k * N_steps / (N_show - 1)) for k in range(N_show)]

N_models = len(tests)
N_samples = 32

g_0s_test = {}
g_1s_test = {}
g_ts_test = {}
for test in tests:
    generate_g_0, generate_g_1 = data_generator(test, ε=EPSILON)
    g_0s_test[test] = generate_g_0(N_samples).to(device)
    g_1s_test[test] = generate_g_1(N_samples).to(device)
    g_ts_test[test] = g_0s_test[test].detach().clone()

    models_SCFM[test].eval()


def plot_poses(ax, g, **quiver_kwargs):
    """Draw poses `g` (rows x, y, θ) on `ax` as arrows at (x, y) pointing along θ."""
    g = g.detach().cpu()
    return ax.quiver(g[:, 0], g[:, 1], torch.cos(g[:, 2]), torch.sin(g[:, 2]), **quiver_kwargs)


fig = plt.figure(figsize=(4.8, 1.6 * N_models * 3/3.1))
gs = gridspec.GridSpec(N_models, 4, width_ratios=[1, 1, 1, 0.1], height_ratios=N_models * [1.], wspace=0.1, hspace=0.1)
cax = fig.add_subplot(gs[:, 3])

ax = []
for i in range(N_models):
    ax.append([])
    for j in range(3):
        a = fig.add_subplot(gs[i, j])
        a.set_xlim(-3, 3)
        a.set_ylim(-3, 3)
        a.set_xticks([])
        a.set_yticks([])
        a.set_aspect("equal")
        ax[i].append(a)

ax[0][0].set_title(r"$\mathfrak{X}_0$")
ax[0][1].set_title(r"$\mathfrak{X}_t$")
ax[0][2].set_title(r"$\mathfrak{X}_1$")
Δc = 1 / (N_show - 1)
colors = [(j * Δc, 0.1, 1 - j * Δc) for j in range(N_show)]
cmap = mcolors.ListedColormap(colors)
for i, test in enumerate(tests):
    plot_poses(ax[i][0], g_0s_test[test], width=0.01)
    plot_poses(ax[i][2], g_1s_test[test], width=0.01)

    # Integrate exactly N_steps steps (t: 0 → 1). Each snapshot is drawn *before* taking
    # further steps, so the state shown matches its colour/time. Previously N_steps + 1
    # steps were taken (overshooting to t = 25/24) and every snapshot was one step late.
    g_t = g_ts_test[test]
    n_taken = 0  # g_t is the state at t = n_taken * Δt
    for k, step_shown in enumerate(show_steps):
        while n_taken < step_shown:
            with torch.no_grad():
                g_t = models_SCFM[test].step(g_t, torch.Tensor([n_taken * Δt]), torch.Tensor([Δt]))
            n_taken += 1
        # Endpoints: all samples, opaque. Intermediate times: a quarter of the samples,
        # semi-transparent, to keep the overlaid snapshots readable.
        is_endpoint = step_shown in (0, N_steps)
        n_shown = N_samples if is_endpoint else N_samples // 4
        plot_poses(
            ax[i][1], g_t[:n_shown],
            color=colors[k], width=0.01, alpha=1 if is_endpoint else 0.5
        )
    g_ts_test[test] = g_t
fig.colorbar(ScalarMappable(cmap=cmap), cax=cax, ticks=np.linspace(0, 1, N_show), label="$t$");
os.makedirs("output", exist_ok=True)
fig.savefig("output/interpolation_SCFM_SE2.pdf", bbox_inches="tight")