In [None]:
import torch
import torch.nn as nn
import numpy as np
import Rn
import flowfield
import matplotlib.pyplot as plt
from matplotlib.animation import FFMpegWriter
import matplotlib.cm as cm
from matplotlib.colors import Normalize
from tqdm import tqdm

In [None]:
# Object from the project-local `Rn` module, built with argument 2; it is passed
# to both flow-field model constructors further down.
# Presumably this represents the Euclidean plane ℝ² — TODO confirm against `Rn`.
r2 = Rn.Rn(2)

In [None]:
# Switch for the video-export cells: they only write their .mp4 files when this is True.
generate_videos = False

# Distributions

In [None]:
# Number of samples requested from each of the two distributions (2**12 = 4096).
N = 2**12

## $X_0 \sim \mathcal{N}(0, I)$

In [None]:
def generate_normals(N, dim=2):
    """Draw `N` independent standard-normal points in `dim` dimensions.

    The numbers come from NumPy's global generator (`np.random.randn`), so the
    result is governed by `np.random.seed`, not by `torch.manual_seed`.

    Args:
        N: Number of points (rows) to draw.
        dim: Number of coordinates per point.

    Returns:
        A `torch.Tensor` of shape `(N, dim)` in torch's default floating dtype
        (float32 unless the default has been changed).
    """
    gaussian_draws = np.random.randn(N, dim)
    # NumPy samples in float64; convert to torch's default floating dtype.
    return torch.from_numpy(gaussian_draws).to(torch.get_default_dtype())

In [None]:
# Source samples: N standard-normal draws with the default dimension of 2.
x_0s = generate_normals(N)

## $X_1 \sim \operatorname{Uniform}(S^1)$, perturbed by Gaussian noise of scale $\varepsilon$

In [None]:
def generate_uniforms_on_circle(N, dim=2, ε=0.05):
    """Sample `N` noisy points around the unit sphere in `dim` dimensions.

    Standard-normal draws are scaled to unit length, which gives uniformly
    distributed directions (points on the circle S^1 for the default `dim=2`).
    Isotropic Gaussian noise with standard deviation `ε` is then added.

    Exactly `N` rows are always returned. A draw of length zero cannot be
    normalised, so it is redrawn instead of being dropped; dropping it would
    leave the result shorter than the matching `generate_normals(N)` sample.

    Args:
        N: Number of samples to return.
        dim: Ambient dimension; must be at least 1.
        ε: Standard deviation of the additive noise.

    Returns:
        A tensor of shape `(N, dim)`.

    Raises:
        ValueError: If `dim` is smaller than 1 (no direction can be normalised).
    """
    if dim < 1:
        raise ValueError(f"dim must be at least 1, got {dim}")

    normals = generate_normals(N, dim)
    lengths = (normals**2).sum(dim=-1).sqrt()

    # Redraw the (practically never occurring) zero-length rows until every
    # row can be projected onto the unit sphere.
    degenerate = ~(lengths > 0.)
    while degenerate.any():
        normals[degenerate] = generate_normals(int(degenerate.sum()), dim)
        lengths = (normals**2).sum(dim=-1).sqrt()
        degenerate = ~(lengths > 0.)

    circle_samples = normals / lengths[:, None]
    return circle_samples + ε * torch.randn(circle_samples.shape)

In [None]:
# Target samples: noisy points around the unit circle, using the default
# dimension (2) and default noise scale (ε = 0.05).
x_1s = generate_uniforms_on_circle(N)

In [None]:
# Side-by-side scatter plots of the source samples (left) and the target
# samples (right), both on the same [-3, 3] × [-3, 3] window.
fig, ax = plt.subplots(1, 2, figsize=(10, 5))
for panel, samples, label in zip(ax, (x_0s, x_1s), ("$X_0$", "$X_1$")):
    panel.scatter(samples[:, 0], samples[:, 1], marker=".")
    panel.set_xlim(-3, 3)
    panel.set_ylim(-3, 3)
    panel.set_title(label)

# Models

## Training

In [None]:
# Mini-batch size handed to the DataLoader (2**10 = 1024).
BATCH_SIZE = 2**10
# Number of `train_network` calls made for each model.
EPOCHS = 100
# Torch device the models and the test tensors are moved to.
device = "cpu"

In [None]:
# Pair the i-th source sample with the i-th target sample and serve the pairs
# in shuffled mini-batches of BATCH_SIZE.
paired_samples = torch.utils.data.TensorDataset(x_0s, x_1s)
train_loader = torch.utils.data.DataLoader(paired_samples, batch_size=BATCH_SIZE, shuffle=True)

In [None]:
# Flow-matching model built on r2, trained with Adam at learning rate 1e-2.
model_FM = flowfield.FlowField(r2).to(device)
optimizer_FM = torch.optim.Adam(params=model_FM.parameters(), lr=1e-2)

# Training criterion; the short-cut model's training loop reuses the same object.
# Alternative kept for reference from the original cell: loss = nn.MSELoss()
unit_weights = torch.Tensor([1., 1.])
loss = flowfield.LogarithmicDistance(unit_weights)

In [None]:
# Call `train_network` once per epoch and record the value it returns so the
# loss curve can be plotted afterwards.
losses_FM = np.zeros(EPOCHS)
for epoch in tqdm(range(EPOCHS)):
    epoch_loss = model_FM.train_network(device, train_loader, optimizer_FM, loss)
    losses_FM[epoch] = epoch_loss

In [None]:
# Loss recorded per epoch for the flow-matching model.
fig, ax = plt.subplots(figsize=(5, 5))
ax.plot(losses_FM)
ax.set_title("Batch Loss Flow Matching");

In [None]:
# Short-cut flow-matching model built on r2, with the same Adam settings as the
# plain model (weight decay explicitly disabled).
model_SCFM = flowfield.ShortCutField(r2).to(device)
optimizer_SCFM = torch.optim.Adam(
    params=model_SCFM.parameters(),
    lr=1e-2,
    weight_decay=0,
)

In [None]:
# Same training schedule as the plain model: one `train_network` call per
# epoch, reusing the `loss` criterion defined for the flow-matching model.
losses_SCFM = np.zeros(EPOCHS)
for epoch in tqdm(range(EPOCHS)):
    epoch_loss = model_SCFM.train_network(device, train_loader, optimizer_SCFM, loss)
    losses_SCFM[epoch] = epoch_loss

In [None]:
# Loss recorded per epoch for the short-cut flow-matching model.
fig, ax = plt.subplots(figsize=(5, 5))
ax.plot(losses_SCFM)
ax.set_title("Batch Loss Short Cut Flow Matching");

# Testing

In [None]:
# Put both trained models into evaluation mode, then draw 2**6 = 64 fresh
# Gaussian start points for the rollouts below.
for trained_model in (model_FM, model_SCFM):
    trained_model.eval()
x_0s_test = generate_normals(64).to(device)

In [None]:
# Roll the flow-matching model forward from the test samples with repeated
# `model_FM.step` calls and show the result as an N_rows × N_cols grid of
# snapshots: the start samples first, then the state after every N_skip steps.
N_steps = 3
N_rows = 2
N_cols = 2
N_plots = N_rows * N_cols
# The first panel shows X_0, so N_plots - 1 panels share the N_steps steps.
# If N_steps were not a multiple of that, the rollout would silently stop
# before t = 1 while its result is still stored as `x_1s_test`.
assert N_steps % (N_plots - 1) == 0, "N_steps must be a multiple of N_plots - 1"
N_skip = N_steps // (N_plots - 1)
x_ts_test = x_0s_test.detach().clone()
t = 0
Δt = 1 / N_steps
fig, ax = plt.subplots(N_rows, N_cols, figsize=(N_cols * 5, N_rows * 5))
x_0s_test_plot = x_0s_test.to("cpu")
index = np.unravel_index(0, (N_rows, N_cols))
ax[index].scatter(x_0s_test_plot[:, 0], x_0s_test_plot[:, 1], marker=".")
ax[index].set_xlim(-3, 3)
ax[index].set_ylim(-3, 3)
ax[index].set_title("$X_0$")
with torch.no_grad():
    for i in range(N_plots-1):
        for _ in range(N_skip):
            # Keep the time tensor on the same device as the samples so the
            # cell keeps working when `device` is not the CPU.
            t_now = torch.Tensor([t]).to(x_ts_test.device)
            x_ts_test = model_FM.step(x_ts_test, t_now, Δt)
            t += Δt
        # Matplotlib needs host memory, hence the copy to the CPU.
        x_ts_test_plot = x_ts_test.to("cpu")
        index = np.unravel_index(i+1, (N_rows, N_cols))
        ax[index].scatter(x_ts_test_plot[:, 0], x_ts_test_plot[:, 1], marker=".")
        ax[index].set_xlim(-3, 3)
        ax[index].set_ylim(-3, 3)
        ax[index].set_title(f"$X_{{{t:.2f}}}$")
    # State after all N_steps steps, i.e. the samples at t = 1.
    x_1s_test = x_ts_test

In [None]:
# Roll the short-cut model forward from the test samples with repeated
# `model_SCFM.step` calls and show the result as an N_rows × N_cols grid of
# snapshots: the start samples first, then the state after every N_skip steps.
N_steps = 3
N_rows = 2
N_cols = 2
N_plots = N_rows * N_cols
# The first panel shows X_0, so N_plots - 1 panels share the N_steps steps.
# If N_steps were not a multiple of that, the rollout would silently stop
# before t = 1 while its result is still stored as `x_1s_test`.
assert N_steps % (N_plots - 1) == 0, "N_steps must be a multiple of N_plots - 1"
N_skip = N_steps // (N_plots - 1)
x_ts_test = x_0s_test.detach().clone()
t = 0
Δt = 1 / N_steps
# Unlike `model_FM.step`, this call receives the step size as a tensor. It is
# constant, so build it once, on the same device as the samples.
step_size = torch.Tensor([Δt]).to(x_ts_test.device)
fig, ax = plt.subplots(N_rows, N_cols, figsize=(N_cols * 5, N_rows * 5))
x_0s_test_plot = x_0s_test.to("cpu")
index = np.unravel_index(0, (N_rows, N_cols))
ax[index].scatter(x_0s_test_plot[:, 0], x_0s_test_plot[:, 1], marker=".")
ax[index].set_xlim(-3, 3)
ax[index].set_ylim(-3, 3)
ax[index].set_title("$X_0$")
with torch.no_grad():
    for i in range(N_plots-1):
        for _ in range(N_skip):
            # Keep the time tensor on the same device as the samples so the
            # cell keeps working when `device` is not the CPU.
            t_now = torch.Tensor([t]).to(x_ts_test.device)
            x_ts_test = model_SCFM.step(x_ts_test, t_now, step_size)
            t += Δt
        # Matplotlib needs host memory, hence the copy to the CPU.
        x_ts_test_plot = x_ts_test.to("cpu")
        index = np.unravel_index(i+1, (N_rows, N_cols))
        ax[index].scatter(x_ts_test_plot[:, 0], x_ts_test_plot[:, 1], marker=".")
        ax[index].set_xlim(-3, 3)
        ax[index].set_ylim(-3, 3)
        ax[index].set_title(f"$X_{{{t:.2f}}}$")
    # State after all N_steps steps, i.e. the samples at t = 1.
    x_1s_test = x_ts_test

In [None]:
# Optional export of the flow-matching rollout as "flow_matching.mp4".
# `N_steps` and `Δt` are not set here; they are the values left by the rollout
# cells above.
if generate_videos:
    metadata = {'title': 'Flow Matching', 'artist': 'Matplotlib'}
    writer = FFMpegWriter(fps=60, metadata=metadata)

    # Restart from the test samples at t = 0. Resetting `t` here matters: the
    # rollout cells above leave it at the end time, which would mislabel the
    # first frame.
    x_ts_test = x_0s_test.detach().clone()
    t = 0

    fig, ax = plt.subplots(1, 1, figsize=(5, 5))

    with writer.saving(fig, "flow_matching.mp4", dpi=300):
        # First frame: the untouched start samples.
        x_ts_test_plot = x_ts_test.to("cpu")
        ax.scatter(x_ts_test_plot[:, 0], x_ts_test_plot[:, 1], marker=".")
        ax.set_title(f"$X_{{{t:.2f}}}$")
        ax.set_xlim(-3, 3)
        ax.set_ylim(-3, 3)
        writer.grab_frame()
        for frame in range(N_steps):
            with torch.no_grad():
                # The step starts at time frame * Δt; keep the time tensor on
                # the same device as the samples.
                t_start = torch.Tensor([frame * Δt]).to(x_ts_test.device)
                x_ts_test = model_FM.step(x_ts_test, t_start, Δt)
            # The frame shows the state after the step, so label it with the
            # time that has been reached.
            t = (frame + 1) * Δt
            x_ts_test_plot = x_ts_test.to("cpu")
            ax.clear()
            ax.scatter(x_ts_test_plot[:, 0], x_ts_test_plot[:, 1], marker=".")
            ax.set_title(f"$X_{{{t:.2f}}}$")
            ax.set_xlim(-3, 3)
            ax.set_ylim(-3, 3)
            writer.grab_frame()

In [None]:
# Regular N_grid × N_grid lattice of evaluation points strictly inside
# (-3, 3)²: the two end points of the linspace are dropped.
N_grid = 9
ticks = torch.linspace(-3, 3, N_grid + 2)[1:-1]
xs, ys = (coords.flatten() for coords in torch.meshgrid(ticks, ticks, indexing="xy"))
# One row per lattice point, columns (x, y); moved to the model's device.
grid = torch.vstack((xs, ys)).T.contiguous().to(device)

In [None]:
# Evaluate the flow-matching model on the lattice at N_plots equally spaced
# times from t = 0 to t = 1 and draw one quiver plot per time.
N_rows = 3
N_cols = 3
N_plots = N_rows * N_cols
N_steps = N_plots - 1

t = 0
Δt = 1 / N_steps

# Arrow colour encodes direction. The fixed [-π, π] range (the full range of
# arctan2) makes a given colour mean the same angle in every panel; rescaling
# per panel would make the panels incomparable.
norm = Normalize(vmin=-np.pi, vmax=np.pi)
colormap = cm.inferno

fig, ax = plt.subplots(N_rows, N_cols, figsize=(N_cols * 5, N_rows * 5))
for i in range(N_plots):
    with torch.no_grad():
        # One time value per lattice point, shape (len(grid), 1), placed on the
        # same device as `grid` so this works when `device` is not the CPU.
        times = (t * torch.ones(len(grid))[..., None]).to(grid.device)
        vectors = model_FM(grid, times).to("cpu")

    # NOTE(review): arctan2 is given (x-component, y-component), so the angle
    # is measured from the +y axis. Swap the arguments if the conventional
    # polar angle is intended.
    colors = torch.arctan2(vectors[:, 0], vectors[:, 1])

    index = np.unravel_index(i, (N_rows, N_cols))
    ax[index].quiver(xs, ys, vectors[:, 0], vectors[:, 1], color=colormap(norm(colors)))
    ax[index].set_xlim(-3, 3)
    ax[index].set_ylim(-3, 3)
    ax[index].set_title(f"$u_{{{t:.2f}}}$")
    t += Δt

In [None]:
# Optional export of the time-dependent flow field as "flow_field.mp4".
if generate_videos:
    metadata = {'title': 'Flow Field', 'artist': 'Matplotlib'}
    writer = FFMpegWriter(fps=60, metadata=metadata)

    N_steps = 500
    t = 0
    Δt = 1 / N_steps

    # Arrow colour encodes direction. The fixed [-π, π] range (the full range
    # of arctan2) keeps the colour ↔ angle mapping stable from frame to frame.
    norm = Normalize(vmin=-np.pi, vmax=np.pi)
    colormap = cm.inferno

    fig, ax = plt.subplots(1, 1, figsize=(5, 5))

    with writer.saving(fig, "flow_field.mp4", dpi=300):
        for frame in range(N_steps):
            t = frame * Δt
            with torch.no_grad():
                # One time value per lattice point, shape (len(grid), 1), on
                # the same device as `grid`.
                times = (t * torch.ones(len(grid))[..., None]).to(grid.device)
                vectors = model_FM(grid, times).to("cpu")

            # NOTE(review): arctan2 is given (x-component, y-component), so the
            # angle is measured from the +y axis. Swap the arguments if the
            # conventional polar angle is intended.
            colors = torch.arctan2(vectors[:, 0], vectors[:, 1])

            ax.clear()
            ax.quiver(xs, ys, vectors[:, 0], vectors[:, 1], color=colormap(norm(colors)))
            # The frame shows the vector field, so label it u_t as the static
            # flow-field panels do (X_t is used for sample clouds).
            ax.set_title(f"$u_{{{t:.2f}}}$")
            ax.set_xlim(-3, 3)
            ax.set_ylim(-3, 3)
            writer.grab_frame()
        # Grab the unchanged figure once more, repeating the last frame.
        writer.grab_frame()