In [1]:
import torch
import torch.nn as nn
import numpy as np
import SE2
import matplotlib.pyplot as plt
from matplotlib.animation import FFMpegWriter
from tqdm import tqdm

In [2]:
# Set to True to render the flow as an MP4 with matplotlib's FFMpegWriter.
generate_videos = False

# Reproducibility: the notebook draws from NumPy's global RNG (sample
# generation) and torch's RNG (circle noise, training times t, weight
# initialisation, DataLoader shuffling). Seed both so Restart & Run All
# gives the same results.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

# Distributions

In [3]:
# Number of samples drawn from each distribution (2**12 = 4096).
N = 1 << 12

## $G_0$: translations $(x, y) \sim \mathcal{N}(0, I)$, rotations $\theta \sim \operatorname{Uniform}[0, 2\pi)$

In [4]:
def generate_normals(N):
    """Draw N random SE(2) poses as a float32 tensor of shape (N, 3).

    Each row is (x, y, θ): the translation (x, y) is standard normal and the
    angle θ is 2π times a uniform draw on [0, 1). NumPy's global RNG is used,
    translations first, then angles.
    """
    xy = np.random.randn(N, 2)
    theta = np.random.rand(N) * (2 * np.pi)
    poses = np.column_stack((xy, theta))
    # torch.Tensor(...) converts the float64 NumPy array to a float32 tensor.
    return torch.Tensor(poses)

In [5]:
# Source samples from G_0: float32 tensor of shape (N, 3), one (x, y, θ) pose per row.
g_0s = generate_normals(N)

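## $G_1$: positions $\sim \operatorname{Uniform}(S^1)$ with heading $\theta$ equal to the polar angle, plus $\mathcal{N}(0, \varepsilon^2 I)$ noise ($\varepsilon = 0.05$)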

In [6]:
def generate_uniforms_on_circle(N, ε=0.05):
    """Sample noisy SE(2) poses whose position lies on the unit circle.

    Directions come from normalising NumPy standard-normal 2-vectors; rows of
    exactly zero length are dropped, so fewer than N rows are possible in
    principle. The heading of each pose is the polar angle of its position,
    wrapped with ``SE2.mod_offset(angle, 2π, 0.)``. Finally, Gaussian noise of
    scale ε (drawn from torch's RNG) is added to x, y and θ. The angle is not
    re-wrapped after the noise.

    Returns:
        float32 tensor of shape (M, 3) with M <= N.
    """
    draws = torch.Tensor(np.random.randn(N, 2))
    norms = draws.pow(2).sum(dim=-1).sqrt()
    keep = norms > 0.
    unit = draws[keep] / norms[keep].unsqueeze(-1)
    polar = torch.arctan2(unit[:, 1], unit[:, 0])
    heading = SE2.mod_offset(polar, 2 * np.pi, 0.)
    poses = torch.cat((unit, heading.unsqueeze(-1)), dim=-1)
    noise = torch.randn(poses.shape)
    return poses + ε * noise

In [7]:
# Target samples from G_1, paired row-by-row with g_0s in the TensorDataset of the
# training loader. That pairing needs equal row counts: N here unless one of the
# normal draws in generate_uniforms_on_circle had exactly zero length.
g_1s = generate_uniforms_on_circle(N)

In [None]:
fig, ax = plt.subplots(1, 2, figsize=(10, 5))
# Draw the first 32 poses of each distribution as arrows pointing along θ.
for panel, (poses, label) in zip(ax, ((g_0s, "$G_0$"), (g_1s, "$G_1$"))):
    head = poses[:32]
    panel.quiver(head[:, 0], head[:, 1], torch.cos(head[:, 2]), torch.sin(head[:, 2]))
    panel.set_xlim(-3, 3)
    panel.set_ylim(-3, 3)
    panel.set_title(label)

# Models

In [9]:
class FlowField(nn.Module):
    """MLP vector field on SE(2) used for flow matching.

    The network maps a pose ``g_t`` (3 coordinates) concatenated with a time
    ``t`` (1 coordinate) to a 3-vector. ``step`` feeds that vector to
    ``SE2.exp`` to move the poses.
    """

    def __init__(self, H=64):
        """Build the MLP.

        Args:
            H: width of each of the three hidden layers.
        """
        super().__init__()
        self.network = nn.Sequential(
            nn.Linear(4, H), nn.ReLU(),
            nn.Linear(H, H), nn.ReLU(),
            nn.Linear(H, H), nn.ReLU(),
            nn.Linear(H, 3)
        )

    def forward(self, g_t, t):
        """Evaluate the field.

        Args:
            g_t: poses with 3 entries in the last dimension.
            t: times with 1 entry in the last dimension; leading dimensions
                must match ``g_t`` because the two are concatenated.

        Returns:
            Tensor with 3 entries in the last dimension.
        """
        return self.network(torch.cat((g_t, t), dim=-1))

    def step(self, g_t, t, Δt):
        """Advance a batch of poses by one explicit step of size Δt.

        Computes ``SE2.L(g_t, SE2.exp(Δt * self(g_t, t)))``. This is presumably
        the group product g_t · exp(Δt · v); see the SE2 module.

        Args:
            g_t: batch of poses, shape (B, 3).
            t: current time, either a Python number or a one-element tensor.
               It is broadcast to shape (B, 1).
            Δt: step size.

        Returns:
            The updated poses.
        """
        # Put t on g_t's device/dtype so the torch.cat in forward also works
        # when the model runs off-CPU (callers build t with torch.Tensor([t])).
        t = torch.as_tensor(t, dtype=g_t.dtype, device=g_t.device)
        t = t.reshape(1, 1).expand(g_t.shape[0], 1)
        return SE2.L(g_t, SE2.exp(Δt * self(g_t, t)))

## Training

In [10]:
# Training configuration.
BATCH_SIZE = 1 << 8   # 2**8 = 256 pose pairs per batch
EPOCHS = 100
device = "cpu"        # model and batches are moved here with .to(device)

In [11]:
# Pair source sample g_0s[k] with target g_1s[k]; the pairs are reshuffled on every pass.
train_loader = torch.utils.data.DataLoader(
    dataset=torch.utils.data.TensorDataset(g_0s, g_1s),
    batch_size=BATCH_SIZE,
    shuffle=True,
)

In [12]:
def train_FM(model, device, train_loader, optimizer, loss):
    """Run one epoch of SE(2) flow-matching training.

    For every batch of paired samples ``(g_0, g_1)`` this:
      * draws one time ``t ~ U[0, 1)`` per pair;
      * forms the regression target ``A_t = SE2.log(SE2.L_inv(g_0, g_1))``,
        which does not depend on t (presumably log(g_0^{-1} g_1));
      * builds the interpolant ``g_t = SE2.L(g_0, SE2.exp(t * A_t))``;
      * takes one optimiser step on ``loss(model(g_t, t), A_t)``.

    Args:
        model: module called as ``model(g_t, t)``; switched to train mode here.
        device: device the batch tensors are moved to.
        train_loader: sized iterable of ``(g_0, g_1)`` batches.
        optimizer: optimiser over ``model``'s parameters.
        loss: criterion called as ``loss(prediction, target)``.

    Returns:
        Mean of the per-batch losses. This is NaN, with a NumPy warning, if the
        loader is empty.
    """
    model.train()
    N_batches = len(train_loader)
    losses = np.zeros(N_batches)
    batches = tqdm(
        enumerate(train_loader),
        total=N_batches,
        desc="Training",
        dynamic_ncols=True,
        unit="batch"
    )
    for i, (g_0, g_1) in batches:
        g_0, g_1 = g_0.to(device), g_1.to(device)
        # Sample t on the batch's device/dtype so that t * A_t and the torch.cat
        # inside the model also work when device is not the CPU.
        t = torch.rand(len(g_1), 1, device=g_1.device, dtype=g_1.dtype)
        A_t = SE2.log(SE2.L_inv(g_0, g_1))
        g_t = SE2.L(g_0, SE2.exp(t * A_t))
        optimizer.zero_grad()
        batch_loss = loss(model(g_t, t), A_t)
        losses[i] = batch_loss.item()
        batch_loss.backward()
        optimizer.step()
    return losses.mean()

In [13]:
# Mean-squared-error objective, the vector-field model, and Adam (lr 1e-2).
loss = nn.MSELoss()
model_FM = FlowField().to(device)
optimizer_FM = torch.optim.Adam(params=model_FM.parameters(), lr=1e-2)

In [None]:
# One entry per epoch: the mean batch loss returned by train_FM.
losses_FM = np.zeros(EPOCHS)
for epoch in tqdm(range(EPOCHS)):
    losses_FM[epoch] = train_FM(
        model_FM, device, train_loader, optimizer_FM, loss
    )

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(5, 5))
# losses_FM[i] is train_FM's return value for epoch i: the mean of that epoch's batch losses.
ax.plot(losses_FM)
ax.set_xlabel("Epoch")
ax.set_ylabel("Mean batch loss")
ax.set_title("Mean Batch Loss per Epoch (Flow Matching)");

## Testing

In [16]:
# Switch to eval mode (the MLP has no dropout/batch-norm, so outputs are unchanged)
# and draw a fresh batch of BATCH_SIZE source poses from G_0 for testing.
model_FM.eval()
g_0s_test = generate_normals(BATCH_SIZE).to(device)

In [None]:
N_steps = 21
N_rows = 2
N_cols = 2
N_plots = N_rows * N_cols
Δt = 1 / N_steps
# Integration step after which each panel is drawn, evenly spaced over
# [0, N_steps]. Ending exactly at N_steps guarantees the last panel (and
# g_1s_test) is at t = 1 even when N_steps is not a multiple of N_plots - 1.
# A fixed int(N_steps / (N_plots - 1)) steps per panel would drop the remainder.
plot_steps = np.linspace(0, N_steps, N_plots).round().astype(int)

# squeeze=False keeps `ax` 2-D so the unravel_index lookup works for any grid.
fig, ax = plt.subplots(N_rows, N_cols, figsize=(N_cols * 5, N_rows * 5), squeeze=False)
g_ts_test = g_0s_test.detach().clone()
step = 0
with torch.no_grad():
    for panel, target_step in enumerate(plot_steps):
        # Explicit steps from the previous panel's time up to this panel's time.
        while step < target_step:
            g_ts_test = model_FM.step(g_ts_test, torch.Tensor([step * Δt]), Δt)
            step += 1
        # step * Δt rather than repeated `t += Δt` avoids accumulating rounding error.
        t = step * Δt
        g_ts_test_plot = g_ts_test.to("cpu")
        index = np.unravel_index(panel, (N_rows, N_cols))
        ax[index].quiver(g_ts_test_plot[:, 0], g_ts_test_plot[:, 1], torch.cos(g_ts_test_plot[:, 2]), torch.sin(g_ts_test_plot[:, 2]))
        ax[index].set_xlim(-3, 3)
        ax[index].set_ylim(-3, 3)
        ax[index].set_title("$G_0$" if panel == 0 else f"$G_{{{t:.2f}}}$")
    g_1s_test = g_ts_test

In [18]:
# NOTE(review): this overrides the `generate_videos = False` set in the config cell,
# so Restart & Run All always renders the MP4 in the next cell. FFMpegWriter needs
# an ffmpeg binary. Consider toggling the flag in the config cell instead.
generate_videos = True

In [None]:
if generate_videos:
    N_steps = 240
    Δt = 1. / N_steps
    metadata = {'title': 'Flow Matching SE(2)', 'artist': 'Matplotlib'}
    writer = FFMpegWriter(fps=60, metadata=metadata)

    g_ts_test = g_0s_test.detach().clone()

    fig, ax = plt.subplots(1, 1, figsize=(5, 5))

    with writer.saving(fig, "flow_matching_SE2.mp4", dpi=300):
        # Frame k shows the state at t = k * Δt: frame 0 is G_0 and frame
        # N_steps is the final state at t = 1. The title is computed after the
        # step, so it matches the state actually drawn.
        for frame in range(N_steps + 1):
            if frame > 0:
                # One explicit step from t = (frame - 1) * Δt to t = frame * Δt.
                with torch.no_grad():
                    g_ts_test = model_FM.step(g_ts_test, torch.Tensor([(frame - 1) * Δt]), Δt)
            t = frame * Δt
            ax.clear()
            g_ts_test_plot = g_ts_test.to("cpu")
            ax.quiver(g_ts_test_plot[:, 0], g_ts_test_plot[:, 1], torch.cos(g_ts_test_plot[:, 2]), torch.sin(g_ts_test_plot[:, 2]))
            ax.set_title(f"$G_{{{t:.2f}}}$")
            ax.set_xlim(-3, 3)
            ax.set_ylim(-3, 3)
            writer.grab_frame()