In [None]:
import torch
import numpy as np
import SE2
import flowfield
import matplotlib.pyplot as plt
from matplotlib.animation import FFMpegWriter
import matplotlib.colors as mcolors
from matplotlib.cm import ScalarMappable
from tqdm import tqdm

In [None]:
# SE(2) group object from the local SE2 module; both flow-field models below are built on it.
se2 = SE2.SE2()

In [None]:
# Render the mp4 animation in the Testing section (matplotlib's FFMpegWriter needs ffmpeg).
generate_videos = True

# Seed numpy (sample generators) and torch (noise, DataLoader shuffling and, presumably, the
# model initialisation) so that Restart & Run All reproduces the same data, losses and figures.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED);

# Distributions

In [None]:
N = 1 << 12  # 4096 samples drawn from each of G_0 and G_1 for training

In [None]:
def generate_normals(N):
    """Sample N SE(2) poses with standard-normal translations and uniform headings.

    Returns a float32 tensor of shape (N, 3) whose columns are x, y and an angle drawn
    uniformly from [0, 2π). Randomness comes from numpy's global RNG.
    """
    xy = np.random.randn(N, 2)
    heading = 2 * np.pi * np.random.rand(N)
    return torch.Tensor(np.column_stack((xy, heading)))

def generate_uniforms_on_circle(N, ε=0.05):
    """Sample up to N noisy SE(2) poses on the unit circle, each heading radially outward.

    Positions are normalised standard-normal draws, which are uniform on S^1. The angle is the
    polar angle of the position, passed through SE2._mod_offset(·, 2π, 0.) (presumably a wrap
    into [0, 2π) — confirm in SE2). Draws whose norm is exactly zero are discarded, so in
    principle fewer than N rows can be returned. Finally, Gaussian noise with standard deviation
    ε is added to all three coordinates.
    """
    draws = torch.Tensor(np.random.randn(N, 2))
    norms = draws.pow(2).sum(dim=-1).sqrt()
    keep = norms > 0.
    on_circle = draws[keep] / norms[keep].unsqueeze(-1)
    polar = torch.atan2(on_circle[..., 1], on_circle[..., 0])
    headings = SE2._mod_offset(polar, 2 * np.pi, 0.)
    poses = torch.cat((on_circle, headings.unsqueeze(-1)), dim=-1)
    return poses + ε * torch.randn(poses.shape)

def generate_line(N, d=0., w=1., horizontal=True, ε=0.05):
    """Sample N noisy SE(2) poses on a straight segment, heading perpendicular to it.

    horizontal=True: points (x, d) with x uniform in [-w, w) and angle π/2.
    horizontal=False: points (d, y) with y uniform in [-w, w) and angle 0.
    Gaussian noise with standard deviation ε is added to every coordinate. Returns a
    float32 tensor of shape (N, 3).
    """
    along = 2 * w * (np.random.rand(N) - 0.5)  # coordinate along the segment
    offset = np.ones(N) * d                    # constant coordinate across the segment
    if not horizontal:
        xs, ys, angles = offset, along, np.zeros(N)
    else:
        xs, ys, angles = along, offset, np.ones(N) * np.pi / 2.
    gs = torch.Tensor(np.stack((xs, ys, angles), axis=-1))
    noise = ε * torch.randn(gs.shape)
    return gs + noise

## Choice of $G_0$ and $G_1$: two line segments (`"line_to_line"`), or $G_0 \sim \mathcal{N}(0, I)$ to $G_1 \sim \operatorname{Uniform}(S^1)$ (`"normal_to_circle"`)

In [None]:
# "line_to_line" "normal_to_circle"
test = "line_to_line"

match test:
    case "line_to_line":
        generate_g_0 = lambda n: generate_line(n, d=-2., w=1., horizontal=True)
        generate_g_1 = lambda n: generate_line(n, d=2, w=1., horizontal=False) 
    case "normal_to_circle":
        generate_g_0 = generate_normals
        generate_g_1 = generate_uniforms_on_circle

g_0s = generate_g_0(N)
g_1s = generate_g_1(N)

In [None]:
# Show the first 32 poses of each distribution as arrows: position (x, y), heading θ.
fig, ax = plt.subplots(1, 2, figsize=(10, 5))
for panel, samples, label in zip(ax, (g_0s, g_1s), ("$G_0$", "$G_1$")):
    shown = samples[:32]
    panel.quiver(shown[:, 0], shown[:, 1], torch.cos(shown[:, 2]), torch.sin(shown[:, 2]))
    panel.set_xlim(-3, 3)
    panel.set_ylim(-3, 3)
    panel.set_title(label)

# Models

## Training

In [None]:
# Training hyper-parameters shared by both models.
BATCH_SIZE = 1 << 10   # 1024 (g_0, g_1) pairs per mini-batch
EPOCHS = 100
WEIGHT_DECAY = 0.01    # Adam weight decay for both optimizers
device = "cpu"

In [None]:
# Pair the i-th G_0 sample with the i-th G_1 sample; batches are reshuffled every epoch.
paired_dataset = torch.utils.data.TensorDataset(g_0s, g_1s)
train_loader = torch.utils.data.DataLoader(paired_dataset, batch_size=BATCH_SIZE, shuffle=True)

In [None]:
# Flow-matching model on SE(2), its optimizer and the training loss.
model_FM = flowfield.FlowField(se2).to(device)
lr_FM = 1e-2
optimizer_FM = torch.optim.Adam(model_FM.parameters(), lr=lr_FM, weight_decay=WEIGHT_DECAY)
# Weighted logarithmic distance; the weights [1, 3, 1] are presumably per (x, y, θ)
# component — TODO confirm in flowfield. A plain torch.nn.MSELoss() was previously left
# here, commented out, as an alternative.
loss = flowfield.LogarithmicDistance(torch.Tensor([1., 3., 1.]))

In [None]:
# Train the flow-matching model; losses_FM[epoch] holds the value returned by
# train_network for that epoch (presumably the mean batch loss — TODO confirm in flowfield).
# The Testing section calls .eval(); switch back so re-running this cell afterwards
# trains in training mode again.
model_FM.train()
losses_FM = np.zeros(EPOCHS)
for epoch in tqdm(range(EPOCHS), desc="Flow Matching"):
    losses_FM[epoch] = model_FM.train_network(device, train_loader, optimizer_FM, loss)

In [None]:
# Training curve of the flow-matching model, one point per epoch.
fig, ax = plt.subplots(1, 1, figsize=(5, 5))
ax.plot(losses_FM)
ax.set_xlabel("epoch")
ax.set_ylabel("loss")
ax.set_title("Batch Loss Flow Matching");

In [None]:
# Shortcut flow-matching model on SE(2), optimised with the same Adam settings as the FM model.
model_SCFM = flowfield.ShortCutField(se2)
model_SCFM = model_SCFM.to(device)
optimizer_SCFM = torch.optim.Adam(
    model_SCFM.parameters(), lr=1e-2, weight_decay=WEIGHT_DECAY
)

In [None]:
# Train the shortcut model; losses_SCFM[epoch] holds the value returned by train_network
# for that epoch (presumably the mean batch loss — TODO confirm in flowfield).
# The Testing section calls .eval(); switch back so re-running this cell afterwards
# trains in training mode again.
model_SCFM.train()
losses_SCFM = np.zeros(EPOCHS)
for epoch in tqdm(range(EPOCHS), desc="Short Cut Flow Matching"):
    losses_SCFM[epoch] = model_SCFM.train_network(device, train_loader, optimizer_SCFM, loss)

In [None]:
# Training curve of the shortcut model, one point per epoch.
fig, ax = plt.subplots(1, 1, figsize=(5, 5))
ax.plot(losses_SCFM)
ax.set_xlabel("epoch")
ax.set_ylabel("loss")
ax.set_title("Batch Loss Short Cut Flow Matching");

## Testing

In [None]:
# Put both trained models in inference mode and draw 32 fresh starting poses from G_0.
for trained_model in (model_FM, model_SCFM):
    trained_model.eval()
g_0s_test = generate_g_0(1 << 5).to(device)

In [None]:
# Regular 10 x 10 x 10 grid of poses: x, y in [-2.5, 2.5], θ in {0, 2π/10, ..., 2π·9/10}.
# NOTE(review): g_0s_uniform and ts are not used by any cell visible here — confirm they
# are still needed (ts has the same (10, 10, 10, 3) shape as g_0s_uniform).
axis_xy = np.linspace(-2.5, 2.5, 10)
θ_values = 2 * np.pi * np.arange(10) / 10.
xs_uniform, ys_uniform, θs_uniform = np.meshgrid(axis_xy, axis_xy, θ_values)
g_0s_uniform = torch.Tensor(np.stack((xs_uniform, ys_uniform, θs_uniform), axis=-1))
print(g_0s_uniform.shape)
ts = 0.5 * torch.ones_like(g_0s_uniform)

In [None]:
# Integrate the flow-matching model from G_0 in N_steps explicit steps of size Δt and show
# snapshots on an N_rows x N_cols grid: the first panel is G_0, the last is the state at t = 1.
N_steps = 21
N_rows = 2
N_cols = 2
N_plots = N_rows * N_cols
assert N_plots >= 2, "need at least one panel after G_0"
Δt = 1 / N_steps
# Integration step at which each panel is drawn. Spreading all N_steps over the panels keeps
# the last panel at t = 1 even when N_steps is not a multiple of N_plots - 1 (plain integer
# division of the step count would silently drop the remaining steps).
panel_steps = np.linspace(0, N_steps, N_plots).round().astype(int)
g_ts_test = g_0s_test.detach().clone()
t = 0
fig, ax = plt.subplots(N_rows, N_cols, figsize=(N_cols * 5, N_rows * 5))
with torch.no_grad():
    step = 0
    for panel, panel_step in enumerate(panel_steps):
        while step < panel_step:
            # Time is derived from the step count rather than accumulated, avoiding float drift.
            g_ts_test = model_FM.step(g_ts_test, torch.Tensor([step * Δt]), Δt)
            step += 1
        t = step * Δt
        g_ts_test_plot = g_ts_test.to("cpu")
        # ax.flat walks the grid in row-major order and also works for a single row/column.
        panel_ax = ax.flat[panel]
        panel_ax.quiver(
            g_ts_test_plot[:, 0], g_ts_test_plot[:, 1],
            torch.cos(g_ts_test_plot[:, 2]), torch.sin(g_ts_test_plot[:, 2])
        )
        panel_ax.set_xlim(-3, 3)
        panel_ax.set_ylim(-3, 3)
        panel_ax.set_title("$G_0$" if panel == 0 else f"$G_{{{t:.2f}}}$")
    g_1s_test = g_ts_test

In [None]:
# NOTE(review): this cell repeats the flow-matching panel cell above verbatim (same model,
# inputs and settings); consider deleting it or pointing it at a different model / N_steps.
# Integrate the flow-matching model from G_0 in N_steps explicit steps of size Δt and show
# snapshots on an N_rows x N_cols grid: the first panel is G_0, the last is the state at t = 1.
N_steps = 21
N_rows = 2
N_cols = 2
N_plots = N_rows * N_cols
assert N_plots >= 2, "need at least one panel after G_0"
Δt = 1 / N_steps
# Integration step at which each panel is drawn. Spreading all N_steps over the panels keeps
# the last panel at t = 1 even when N_steps is not a multiple of N_plots - 1 (plain integer
# division of the step count would silently drop the remaining steps).
panel_steps = np.linspace(0, N_steps, N_plots).round().astype(int)
g_ts_test = g_0s_test.detach().clone()
t = 0
fig, ax = plt.subplots(N_rows, N_cols, figsize=(N_cols * 5, N_rows * 5))
with torch.no_grad():
    step = 0
    for panel, panel_step in enumerate(panel_steps):
        while step < panel_step:
            # Time is derived from the step count rather than accumulated, avoiding float drift.
            g_ts_test = model_FM.step(g_ts_test, torch.Tensor([step * Δt]), Δt)
            step += 1
        t = step * Δt
        g_ts_test_plot = g_ts_test.to("cpu")
        # ax.flat walks the grid in row-major order and also works for a single row/column.
        panel_ax = ax.flat[panel]
        panel_ax.quiver(
            g_ts_test_plot[:, 0], g_ts_test_plot[:, 1],
            torch.cos(g_ts_test_plot[:, 2]), torch.sin(g_ts_test_plot[:, 2])
        )
        panel_ax.set_xlim(-3, 3)
        panel_ax.set_ylim(-3, 3)
        panel_ax.set_title("$G_0$" if panel == 0 else f"$G_{{{t:.2f}}}$")
    g_1s_test = g_ts_test

In [None]:
# NOTE(review): this cell repeats the flow-matching panel cell above verbatim (same model,
# inputs and settings); consider deleting it or pointing it at a different model / N_steps.
# Integrate the flow-matching model from G_0 in N_steps explicit steps of size Δt and show
# snapshots on an N_rows x N_cols grid: the first panel is G_0, the last is the state at t = 1.
N_steps = 21
N_rows = 2
N_cols = 2
N_plots = N_rows * N_cols
assert N_plots >= 2, "need at least one panel after G_0"
Δt = 1 / N_steps
# Integration step at which each panel is drawn. Spreading all N_steps over the panels keeps
# the last panel at t = 1 even when N_steps is not a multiple of N_plots - 1 (plain integer
# division of the step count would silently drop the remaining steps).
panel_steps = np.linspace(0, N_steps, N_plots).round().astype(int)
g_ts_test = g_0s_test.detach().clone()
t = 0
fig, ax = plt.subplots(N_rows, N_cols, figsize=(N_cols * 5, N_rows * 5))
with torch.no_grad():
    step = 0
    for panel, panel_step in enumerate(panel_steps):
        while step < panel_step:
            # Time is derived from the step count rather than accumulated, avoiding float drift.
            g_ts_test = model_FM.step(g_ts_test, torch.Tensor([step * Δt]), Δt)
            step += 1
        t = step * Δt
        g_ts_test_plot = g_ts_test.to("cpu")
        # ax.flat walks the grid in row-major order and also works for a single row/column.
        panel_ax = ax.flat[panel]
        panel_ax.quiver(
            g_ts_test_plot[:, 0], g_ts_test_plot[:, 1],
            torch.cos(g_ts_test_plot[:, 2]), torch.sin(g_ts_test_plot[:, 2])
        )
        panel_ax.set_xlim(-3, 3)
        panel_ax.set_ylim(-3, 3)
        panel_ax.set_title("$G_0$" if panel == 0 else f"$G_{{{t:.2f}}}$")
    g_1s_test = g_ts_test

In [None]:
# Integrate the shortcut model from G_0 in N_steps steps of size Δt (passed to the model as a
# tensor) and show snapshots on an N_rows x N_cols grid: first panel G_0, last panel t = 1.
N_steps = 21
N_rows = 2
N_cols = 2
N_plots = N_rows * N_cols
assert N_plots >= 2, "need at least one panel after G_0"
Δt = 1 / N_steps
# Integration step at which each panel is drawn. Spreading all N_steps over the panels keeps
# the last panel at t = 1 even when N_steps is not a multiple of N_plots - 1 (plain integer
# division of the step count would silently drop the remaining steps).
panel_steps = np.linspace(0, N_steps, N_plots).round().astype(int)
g_ts_test = g_0s_test.detach().clone()
t = 0
fig, ax = plt.subplots(N_rows, N_cols, figsize=(N_cols * 5, N_rows * 5))
with torch.no_grad():
    step = 0
    for panel, panel_step in enumerate(panel_steps):
        while step < panel_step:
            # Time is derived from the step count rather than accumulated, avoiding float drift.
            g_ts_test = model_SCFM.step(g_ts_test, torch.Tensor([step * Δt]), torch.Tensor([Δt]))
            step += 1
        t = step * Δt
        g_ts_test_plot = g_ts_test.to("cpu")
        # ax.flat walks the grid in row-major order and also works for a single row/column.
        panel_ax = ax.flat[panel]
        panel_ax.quiver(
            g_ts_test_plot[:, 0], g_ts_test_plot[:, 1],
            torch.cos(g_ts_test_plot[:, 2]), torch.sin(g_ts_test_plot[:, 2])
        )
        panel_ax.set_xlim(-3, 3)
        panel_ax.set_ylim(-3, 3)
        panel_ax.set_title("$G_0$" if panel == 0 else f"$G_{{{t:.2f}}}$")
    g_1s_test = g_ts_test

In [None]:
if generate_videos:
    # Animate the flow-matching model from G_0 to t = 1, one video frame per integration step.
    N_steps = 240
    Δt = 1. / N_steps
    metadata = {'title': 'Flow Matching SE(2)', 'artist': 'Matplotlib'}
    writer = FFMpegWriter(fps=60, metadata=metadata)
    # Name the file after the loss class: str(loss) may be a multi-line module repr or an
    # "<... object at 0x...>" string, which is neither a stable nor a portable file name.
    video_path = f"flow_matching_SE2_{type(loss).__name__}.mp4"

    g_ts_test = g_0s_test.detach().clone()
    fig, ax = plt.subplots(1, 1, figsize=(5, 5))

    def draw_frame(g, t):
        """Redraw `ax` with the poses g (columns x, y, θ) titled with time t, then grab a frame."""
        ax.clear()
        g_plot = g.to("cpu")
        ax.quiver(
            g_plot[:, 0], g_plot[:, 1],
            torch.cos(g_plot[:, 2]), torch.sin(g_plot[:, 2])
        )
        ax.set_title(f"$G_{{{t:.2f}}}$")
        ax.set_xlim(-3, 3)
        ax.set_ylim(-3, 3)
        writer.grab_frame()

    with writer.saving(fig, video_path, dpi=300):
        t = 0
        draw_frame(g_ts_test, t)
        for frame in range(N_steps):
            with torch.no_grad():
                g_ts_test = model_FM.step(g_ts_test, torch.Tensor([frame * Δt]), Δt)
            # After stepping from frame * Δt the poses are at (frame + 1) * Δt; label them so.
            t = (frame + 1) * Δt
            draw_frame(g_ts_test, t)
        # Repeat the final state for one extra frame, as the original cell did.
        writer.grab_frame()
    # The figure only exists to feed the video; don't also render its last frame inline.
    plt.close(fig)

In [None]:
# Overlay N_show snapshots of the flow-matching poses on one axis, colored by t from
# blue (t = 0, G_0) to red (t = 1).
t = 0
N_steps = 240
Δt = 1. / N_steps
N_show = 5
# Integration step at which each snapshot is drawn. Plotting before stepping, and stopping at
# N_steps, puts the first snapshot exactly at G_0 and the last exactly at t = 1, and always
# yields exactly N_show snapshots (one per color).
snapshot_steps = np.linspace(0, N_steps, N_show).round().astype(int)

Δc = 1 / (N_show - 1)
colors = [(j * Δc, 0.1, 1 - j * Δc) for j in range(N_show)]
cmap = mcolors.ListedColormap(colors)

g_ts_test = g_0s_test.detach().clone()

fig, ax = plt.subplots(1, 1, figsize=(6, 5))
fig.colorbar(ScalarMappable(cmap=cmap), ax=ax, ticks=np.linspace(0, 1, N_show), label="t")
ax.set_xlim(-3, 3)
ax.set_ylim(-3, 3)
ax.set_title("Flow Matching $G_t$")
step = 0
with torch.no_grad():
    for j, snapshot_step in enumerate(snapshot_steps):
        while step < snapshot_step:
            g_ts_test = model_FM.step(g_ts_test, torch.Tensor([step * Δt]), Δt)
            step += 1
        t = step * Δt
        g_ts_test_plot = g_ts_test.to("cpu")
        im = ax.quiver(
            g_ts_test_plot[:, 0], g_ts_test_plot[:, 1],
            torch.cos(g_ts_test_plot[:, 2]), torch.sin(g_ts_test_plot[:, 2]),
            color=colors[j]
        )