In [None]:
# Expose only GPU index 3 to this kernel (it then appears to torch as cuda:0).
# This only takes effect if set before CUDA is initialized, which is why it is the first cell.
%env CUDA_VISIBLE_DEVICES = 3

In [None]:
# Re-import edited modules (e.g. the local read_ks_data helper) automatically before each cell runs.
%load_ext autoreload
%autoreload 2
import os

import matplotlib.pyplot as plt
import numpy as np
import scanpy as sc
import torch
import torchsde

from torchdyn.core import NeuralODE
from tqdm import tqdm
from torchcfm.conditional_flow_matching import *
from torchcfm.models import MLP
from torchcfm.utils import plot_trajectories, torch_wrapper
from einops import rearrange
from read_ks_data import get_batch

# Directory reserved for this notebook's outputs; created (with any missing parents) if absent.
savedir = "models/KS-Equation"
os.makedirs(name=savedir, exist_ok=True)

In [None]:
# Load the KS (presumably Kuramoto–Sivashinsky) trajectories and standardize them.
# The path can be overridden with the KS_DATA_PATH environment variable so the notebook
# runs outside the original author's machine; the default is the original location.
KS_DATA_PATH = os.environ.get(
    "KS_DATA_PATH",
    "/home/meet/FlowMatchingTests/conditional-flow-matching/examples/data/ks_data.npy",
)
data = np.load(KS_DATA_PATH)
# Merge every 4 consecutive entries of the leading axis into one longer series along time:
# (tr*4, t, s) -> (tr, 4*t, s). Presumably each long trajectory was stored as 4 consecutive
# chunks — TODO confirm against the data-generation script.
data = rearrange(data, "(tr num) t s -> tr (num t) s", num=4)
# Global (scalar) z-score over all trajectories, time steps and positions.
# The statistics are kept so model outputs can be mapped back to physical units.
data_mean, data_std = data.mean(), data.std()
X = (data - data_mean) / data_std

In [None]:
# Model / optimizer / flow-matcher configuration.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
sigma = 0.1  # sigma of ConditionalFlowMatcher's Gaussian probability path
# The ODE state is a single snapshot (the rollout cell feeds rows X[k, j] straight into the
# model), so the model width must equal X's last axis. Derived instead of hard-coding 32.
dim = X.shape[-1]
# Seed both torch and NumPy; get_batch presumably samples with NumPy's global RNG — TODO confirm.
torch.manual_seed(42)
np.random.seed(42)
# time_varying=True: the model input is the state concatenated with t (see the training cell).
ot_cfm_model = MLP(dim=dim, time_varying=True, w=64).to(device)
ot_cfm_optimizer = torch.optim.Adam(ot_cfm_model.parameters(), 1e-4)
# NOTE(review): despite the `ot_` prefix on the model names, this is the plain
# ConditionalFlowMatcher, not torchcfm's exact-OT variant.
FM = ConditionalFlowMatcher(sigma=sigma)

In [None]:
# Train the vector field by regressing v_theta([x_t, t]) onto the conditional target u_t (MSE).
N_TRAIN = 900        # trajectories X[:N_TRAIN] are used for training; later ones are held out
N_ITERS = 100_000
LOG_EVERY = 10_000
X_train = X[:N_TRAIN]
n_times = X.shape[1]  # length of each (concatenated) trajectory; loop-invariant, so hoisted
for i in tqdm(range(N_ITERS)):
    ot_cfm_optimizer.zero_grad()
    # NOTE(review): the meaning of the two `16` arguments is defined by read_ks_data.get_batch;
    # give them named constants once confirmed.
    t, xt, ut = get_batch(FM, X_train, 16, 16, n_times, device)
    vt = ot_cfm_model(torch.cat([xt, t[:, None]], dim=-1))
    loss = torch.mean((vt - ut) ** 2)
    loss.backward()
    ot_cfm_optimizer.step()
    if i % LOG_EVERY == 0:
        # tqdm.write prints above the progress bar instead of breaking it like print() does.
        tqdm.write(f"iter {i}: loss {loss.item():.6f}")
# Persist the trained weights into the (previously unused) output directory so the
# rollout can be re-run without retraining.
torch.save(ot_cfm_model.state_dict(), os.path.join(savedir, "ot_cfm_model.pt"))

In [None]:
# Autoregressive rollout on held-out trajectory TEST_IDX: starting from its first snapshot,
# integrate the learned ODE from t=0 to t=1 to get the next frame (the plot cell compares
# frame k with X[TEST_IDX, k]), then feed that prediction back in, N_STEPS times.
TEST_IDX = 983
N_STEPS = 32
node = NeuralODE(torch_wrapper(ot_cfm_model), solver="dopri5", sensitivity="adjoint")
t_span = torch.linspace(0, 1, 2)  # only the two endpoints are needed
with torch.no_grad():
    out_lst = [X[TEST_IDX:TEST_IDX + 1, 0]]
    # The state stays on the device between steps instead of round-tripping through NumPy.
    state = torch.from_numpy(X[TEST_IDX:TEST_IDX + 1, 0]).float().to(device)
    for _ in range(N_STEPS):
        # trajectory() returns the state at every time in t_span; keep the final one.
        state = node.trajectory(state, t_span=t_span)[-1]
        out_lst.append(state.cpu().numpy())

In [None]:
# Compare the rollout with the ground truth of the same trajectory (index 983, as in the rollout cell).
pred = np.concatenate(out_lst)        # (n_frames, s): initial snapshot followed by predicted frames
truth = X[983, : pred.shape[0]]       # same number of frames as the rollout (was hard-coded 0:33)
abs_err = np.abs(truth - pred)
# Prediction and truth share one color scale so the two panels are directly comparable.
vmin = min(pred.min(), truth.min())
vmax = max(pred.max(), truth.max())
fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(12, 4))
panels = [
    (pred, "Prediction", {"vmin": vmin, "vmax": vmax}),
    (truth, "Ground truth", {"vmin": vmin, "vmax": vmax}),
    (abs_err, "|truth - prediction|", {}),
]
for ax, (img, title, scale) in zip(axes, panels):
    im = ax.imshow(img, **scale)
    fig.colorbar(im, ax=ax)
    ax.set(title=title, xlabel="s (spatial index)", ylabel="frame")
fig.tight_layout()
plt.show()