In [1]:
import math
import os
import time

import matplotlib.pyplot as plt
import numpy as np
import ot as pot
import torch
import torchdyn
from torchdyn.core import NeuralODE
from torchdyn.datasets import generate_moons
from torchcfm.models.models import *
from torchcfm.utils import *

from src.model.cfm import ConditionalFlowMatcher, VariancePreservingConditionalFlowMatcher
from src.utils.dataset import generate_checkerboard_dataset

# Output folder for model checkpoints (relative to the working directory);
# created eagerly so the training cell's torch.save never hits a missing path.
savedir = "models/checkerboard_ot"
os.makedirs(name=savedir, exist_ok=True)

In [None]:
%%time
# --- Hyperparameters ---------------------------------------------------------
SEED = 0
sigma = 0.002            # passed to ConditionalFlowMatcher as its `sigma`
dim = 2                  # data dimensionality (2-D points)
batch_size = 256
N_ITERS = 300_000        # total optimizer steps
LOG_EVERY = 10_000       # steps between loss report / sample plot / checkpoint
N_EVAL_SAMPLES = 1024    # noise points pushed through the learned ODE per plot
N_EVAL_STEPS = 100       # time points in the evaluation t_span over [0, 1]

# Seed both RNGs: torch drives x0 (and presumably t) sampling; the checkerboard
# generator presumably uses numpy's global RNG (TODO confirm in src.utils.dataset).
torch.manual_seed(SEED)
np.random.seed(SEED)

model = MLP(dim=dim, time_varying=True)
optimizer = torch.optim.Adam(model.parameters())

cfm = ConditionalFlowMatcher(sigma=sigma)  # or VariancePreservingConditionalFlowMatcher

# torch_wrapper holds a reference to `model`, so a single NeuralODE built here
# always integrates the current (in-place updated) weights; no per-eval rebuild.
node = NeuralODE(
    torch_wrapper(model), solver="dopri5", sensitivity="adjoint", atol=1e-4, rtol=1e-4
)

running_loss = 0.0
start = time.perf_counter()
for k in range(N_ITERS):
    optimizer.zero_grad()

    # Source batch: standard Gaussian noise; target batch: checkerboard samples.
    x0 = torch.randn(batch_size, dim)
    x1 = torch.tensor(
        generate_checkerboard_dataset(batch_size, 1, 4),
        dtype=torch.float32,
    )
    t, xt, ut = cfm.sample_location_and_conditional_flow(x0, x1)

    # Regress the network's velocity at (xt, t) onto the conditional target ut.
    vt = model(torch.cat([xt, t[:, None]], dim=-1))
    loss = torch.mean((vt - ut) ** 2)

    loss.backward()
    optimizer.step()
    running_loss += loss.item()

    if (k + 1) % LOG_EVERY == 0:
        elapsed = time.perf_counter() - start
        # Window mean is far less noisy than the last batch's loss alone.
        print(f"{k+1}: loss {running_loss / LOG_EVERY:0.3f} time {elapsed:0.2f}")
        running_loss = 0.0

        # Integrate fresh Gaussian noise through the learned flow from t=0 to t=1.
        with torch.no_grad():
            traj = node.trajectory(
                torch.randn(N_EVAL_SAMPLES, dim),
                t_span=torch.linspace(0, 1, N_EVAL_STEPS),
            )
        gens = traj[-1]

        fig, ax = plt.subplots()
        ax.scatter(gens[:, 0], gens[:, 1])
        ax.set(title=f"Generated samples after {k+1} iterations", xlabel="x", ylabel="y")
        ax.set_aspect("equal")
        plt.show()

        torch.save(model.state_dict(), os.path.join(savedir, "checkpoint.pt"))

        # Restart the clock only after eval/plot/checkpoint so the next reported
        # time measures training steps alone.
        start = time.perf_counter()