In [None]:
import torch
from lieflow.groups import Rn
from lieflow.models import (
    get_model_FM,
    LogarithmicDistance
)
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
from functools import partial
from math import floor
# Typeset all figure text with LaTeX in a 10pt serif face; the preamble loads
# the packages needed for the math symbols used in labels.
plt.rcParams["text.usetex"] = True
plt.rcParams["font.family"] = "serif"
plt.rcParams["font.serif"] = ["Computer Modern"]
plt.rcParams["font.size"] = 10.0
plt.rcParams["text.latex.preamble"] = (
    r"\usepackage{lmodern} \usepackage{amssymb} \usepackage{amsmath}"
)

In [None]:
# Group object handed to get_model_FM when a network is built.
# NOTE(review): Rn(2) is presumably the translation group on the plane — confirm
# against lieflow.groups.
r2 = Rn(2)

In [None]:
def generate_normals(N, μ):
    """Sample `N` planar points from a two-component Gaussian mixture.

    Each point is standard normal noise shifted horizontally by `μ` and
    vertically by either `+2μ` or `-2μ`; the vertical sign is picked
    independently per point with equal probability.

    Args:
        N: number of points to draw.
        μ: horizontal offset of both components (vertical offsets are ±2μ).

    Returns:
        Tensor of shape (N, 2).
    """
    noise = torch.randn(N, 2)
    # The sign of an independent standard normal draw acts as a fair coin flip.
    coin = torch.randn(N, 1) > 0
    sign = 2. * coin - 1.
    horizontal_shift = torch.tensor([μ, 0.])
    vertical_shift = torch.tensor([0., 2. * μ])
    return noise + horizontal_shift + sign * vertical_shift

In [None]:
# Horizontal offset of the target mixture; the source mixture uses the negated value.
μ = 8.
# partial freezes the offset at this point, so reassigning the global μ later
# does not change either sampler. Both are called as sampler(N).
generate_x_0 = partial(generate_normals, μ=-μ)
generate_x_1 = partial(generate_normals, μ=μ)

In [None]:
# Hyperparameters shared by every training run in this notebook.
EPSILON = 0.03  # NOTE(review): not referenced in the cells shown here — confirm before removing
N = 2**14  # Number of (x_0, x_1) training pairs per training run
BATCH_SIZE = 2**10
EPOCHS = 20
WEIGHT_DECAY = 0.001
LEARNING_RATE = 1e-2
H = 64 # Width
L = 3 # Number of layers is L + 2

# Sampling, weight initialisation and batch shuffling all draw from torch's
# global generator; seeding it once here makes Restart & Run All reproducible.
SEED = 0
torch.manual_seed(SEED)

device = "cpu"

In [None]:
def train_model(x_0s, x_1s, epochs, batch_size=BATCH_SIZE, learning_rate=LEARNING_RATE,
                weight_decay=WEIGHT_DECAY, visualise=False):
    """Train a freshly constructed flow matching network on paired samples.

    A new network is built on every call from the notebook globals `r2`, `H`
    and `L` and placed on `device`; nothing is reused between calls.

    Args:
        x_0s: source samples; row i is paired with row i of `x_1s`.
        x_1s: target samples, same number of rows as `x_0s`.
        epochs: number of passes over the paired dataset.
        batch_size: mini-batch size used by the shuffling data loader.
        learning_rate: Adam learning rate.
        weight_decay: Adam weight decay.
        visualise: when True, plot the per-epoch loss on log-log axes.

    Returns:
        The trained network. This function does not call `.eval()`; callers
        switch the mode themselves.
    """
    train_loader = torch.utils.data.DataLoader(
        torch.utils.data.TensorDataset(x_0s, x_1s), batch_size=batch_size, shuffle=True
    )

    model_FM = get_model_FM(r2, H=H, L=L).to(device)
    print("Number of parameters: ", model_FM.parameter_count)
    optimizer_FM = torch.optim.Adam(model_FM.parameters(), learning_rate, weight_decay=weight_decay)
    loss = LogarithmicDistance(torch.Tensor([1., 1.]))

    # One loss value is recorded per epoch (whatever train_network returns).
    losses_FM = torch.zeros(epochs)
    for i in tqdm(range(epochs)):
        losses_FM[i] = model_FM.train_network(device, train_loader, optimizer_FM, loss)

    if visualise:
        fig, ax = plt.subplots(1, 1, figsize=(5, 5))
        # Number epochs from 1: x = 0 has no position on a logarithmic axis,
        # so plotting against the raw index would lose the first epoch.
        ax.plot(range(1, epochs + 1), losses_FM)
        ax.set_title("Batch Loss Flow Matching")
        ax.set_xlabel("Epoch")
        ax.set_ylabel("Loss")
        ax.set_xscale("log")
        ax.set_yscale("log")

    return model_FM

In [None]:
# Source and target samples are drawn independently of each other; the
# training helper pairs them up row by row.
x_0s, x_1s = generate_x_0(N), generate_x_1(N)

model_FM = train_model(
    x_0s,
    x_1s,
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
    learning_rate=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
    visualise=True,
)
# train_model leaves the mode switch to the caller.
model_FM.eval()

In [None]:
def visualise_model(model_FM, N_samples=2**7, N_steps=120):
    """Plot the paths that freshly drawn source samples follow under the model.

    Draws `N_samples` points with `generate_x_0`, advances them from t = 0 to
    t = 1 in `N_steps` equal steps via `model_FM.step`, and plots the start
    points, the trajectories and the end points in one figure.

    Args:
        model_FM: trained model exposing `step(xs, t, Δt)`.
        N_samples: number of source points to push through the flow.
        N_steps: number of integration steps between t = 0 and t = 1.
    """
    Δt = 1. / N_steps
    with torch.no_grad():
        trajectories = torch.zeros(N_steps + 1, N_samples, 2)
        x_0s = generate_x_0(N_samples)
        trajectories[0] = x_0s
        xs = x_0s
        for i in range(N_steps):
            # Compute the time as a float from the step index: this keeps the
            # time tensor's dtype the same on every step (an integer 0 would
            # make the first one int64) and avoids accumulating rounding error.
            t = torch.tensor([i * Δt])
            xs = model_FM.step(xs, t, Δt)
            trajectories[i + 1] = xs

    fig, ax = plt.subplots(1, 1, figsize=(10, 10))
    # Labels are in math mode so they render correctly with text.usetex enabled.
    ax.scatter(trajectories[0, :, 0], trajectories[0, :, 1], color="tab:green", label=r"$X_0$")
    for i in range(N_samples):
        # Label a single trajectory so the legend shows one entry for all paths.
        ax.plot(trajectories[:, i, 0], trajectories[:, i, 1], color="tab:blue",
                label=r"$X_t$" if i == 0 else None)
    ax.scatter(trajectories[-1, :, 0], trajectories[-1, :, 1], color="tab:red", label=r"$X_1$")
    ax.set_aspect("equal")
    ax.legend()

In [None]:
# Push fresh source samples through the initially trained model and plot their paths.
visualise_model(model_FM)

In [None]:
def reflow(model_FM, prod_frac=0.):
    """Train a new model on couplings generated by an existing model.

    The training set of (about) `N` pairs is assembled from three parts:

    * `N_prod` pairs whose source and target are sampled independently,
    * source samples paired with their image under the forward flow of
      `model_FM` (integrated from t = 0 to t = 1),
    * target samples paired with their image under the backward flow of
      `model_FM` (integrated from t = 1 to t = 0).

    The last two parts have equal size and share what remains after `N_prod`.

    Args:
        model_FM: trained model exposing `step` and `step_back`.
        prod_frac: fraction of `N` taken from the independent coupling, in
            [0, 1]; the resulting count is rounded down to an even number.

    Returns:
        A newly trained model, already switched to eval mode.

    Raises:
        ValueError: if `prod_frac` lies outside [0, 1].
    """
    if not 0. <= prod_frac <= 1.:
        raise ValueError(f"prod_frac must lie in [0, 1], got {prod_frac}")

    N_prod = floor(N * prod_frac / 2) * 2
    x_0s_prod = generate_x_0(N_prod)
    x_1s_prod = generate_x_1(N_prod)

    x_0s_init = generate_x_0((N - N_prod)//2)
    x_1s_init = generate_x_1((N - N_prod)//2)

    N_steps = 120
    Δt = 1. / N_steps
    with torch.no_grad():
        # Forward pass: carry source samples to t = 1. Time is computed as a
        # float from the step index so every step sees the same tensor dtype
        # and no rounding error accumulates.
        xs = x_0s_init
        for i in range(N_steps):
            xs = model_FM.step(xs, torch.tensor([i * Δt]), Δt)
        # Row order must mirror x_0s below: product part, forward part, backward part.
        x_1s = torch.concatenate((x_1s_prod, xs, x_1s_init), dim=0)

        # Backward pass: carry target samples from t = 1 back to t = 0.
        xs = x_1s_init
        for i in range(N_steps):
            xs = model_FM.step_back(xs, torch.tensor([1. - i * Δt]), Δt)
        x_0s = torch.concatenate((x_0s_prod, x_0s_init, xs), dim=0)

    model_FM = train_model(x_0s, x_1s, epochs=EPOCHS, batch_size=BATCH_SIZE, learning_rate=LEARNING_RATE,
                           weight_decay=WEIGHT_DECAY, visualise=True)
    model_FM.eval()

    return model_FM

In [None]:
# Two successive reflow rounds: the first takes half of its pairs from the
# independent coupling, the second uses model-generated pairs only.
prod_fracs = [0.5, 0.]
for prod_frac in prod_fracs:
    # model_FM is rebound each round, so every round starts from the previous
    # round's model; re-running this cell continues from the latest model.
    model_FM = reflow(model_FM, prod_frac=prod_frac)
    visualise_model(model_FM)

In [None]:
# Plot the final model once more. visualise_model draws new start points on
# every call, so this shows different samples than the figure from the last
# reflow round.
visualise_model(model_FM)