In [12]:
import sys

sys.path.append("../")
from pathlib import Path

import torch
from matplotlib import pyplot as plt
from script.datasets.synthetic import CircleMixturedFilledGaussianDataset
from script.process_based_generative_model import ProcessBasedGenerativeModel
import numpy as np

In [None]:
# Directories that CONTAIN the trained checkpoints. The cells below search each
# directory for "*.ckpt" files, so point these at the checkpoint folders, not at
# individual .ckpt files. They can also be supplied through environment
# variables of the same name.
import os

A2AFM_PATH = os.environ.get("A2AFM_PATH", "")  # TODO: A2A-FM checkpoint directory
DIFFUSION_CFG_PATH = os.environ.get("DIFFUSION_CFG_PATH", "")  # TODO: diffusion (CFG) checkpoint directory

# Seed the torch and NumPy global RNGs so "Restart & Run All" is reproducible.
# NOTE(review): assumes dataset.sample draws from one of these RNGs — confirm.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

# Load model

In [None]:
# Pick the checkpoint deterministically (sorted by name) and fail with a clear
# message when the directory holds none; raw glob order is filesystem-dependent.
a2afm_ckpts = sorted(Path(A2AFM_PATH).glob("*.ckpt"))
if not a2afm_ckpts:
    raise FileNotFoundError(
        f"No *.ckpt file found in {A2AFM_PATH!r}; set A2AFM_PATH to the checkpoint directory."
    )
path = a2afm_ckpts[0]
print(path)
model = ProcessBasedGenerativeModel.load_from_checkpoint(path)
model.eval();  # trailing ';' suppresses the (long) module repr

# Prepare Evaluation Dataset

In [None]:
# Evaluation dataset built from the loaded model's data config; its `xdata`
# is drawn as the gray background in the plots below.
dataset = CircleMixturedFilledGaussianDataset(10000, data_config=model.config.data_config)
background = dataset.xdata
plt.scatter(background[:, 0], background[:, 1])

In [None]:
import ot

# Source/target conditions (angles) for the qualitative comparison.
target_theta = torch.pi/2
initial_theta = torch.pi/4
n_points = 100

initc = torch.ones(n_points, 1, device="cuda") * initial_theta
xdata = dataset.sample(initial_theta, n_points)
targc = torch.ones(n_points, 1, device="cuda") * target_theta
ydata = dataset.sample(target_theta, n_points)

# Reference: exact OT coupling between the two samples (uniform weights),
# drawn as straight-line paths from each initial point to its OT partner.
cost = ot.dist(xdata, ydata)
uniform_src = torch.ones(len(xdata)) / len(xdata)
uniform_tgt = torch.ones(len(ydata)) / len(ydata)
plan = ot.emd(uniform_src, uniform_tgt, cost)
matched = ydata[torch.argmax(plan, dim=1)]
traj = torch.stack(
    [(1 - t) * xdata + t * matched for t in torch.linspace(0, 1, 100)], dim=0
)

ticks = [0.0, 0.5, 1.0, 1.5, 2.0]
plt.grid()
plt.scatter(dataset.xdata[:, 0], dataset.xdata[:, 1], alpha=0.1, s=1, c="gray")
plt.scatter(traj[:, :, 0], traj[:, :, 1], alpha=1, s=0.1, c="tan")
plt.scatter(ydata[:, 0], ydata[:, 1], label="target points")
plt.scatter(xdata[:, 0], xdata[:, 1], label="initial points")
plt.gca().set_aspect("equal")
plt.xticks(ticks)
plt.yticks(ticks)
plt.legend()

# Evaluate A2A-FM

In [122]:
# Integrate the A2A-FM flow from the initial sample towards the target condition;
# inputs are cloned (presumably to guard them against in-place changes).
x_start = xdata.cuda().clone()
traj = model.evaluator.integrate(x_start, targc.clone(), initc=initc.clone()).cpu()

In [None]:
# A2A-FM result: gray = dataset, tan = integration paths, plus start/target/generated points.
plt.scatter(dataset.xdata[:, 0], dataset.xdata[:, 1], alpha=0.1, s=1, c="gray")
plt.scatter(ydata[:, 0], ydata[:, 1], label="target points")
plt.scatter(traj[:, :, 0], traj[:, :, 1], alpha=1, s=0.1, c="tan")
plt.scatter(xdata[:, 0], xdata[:, 1], label="initial points")
plt.scatter(traj[-1, :, 0], traj[-1, :, 1], alpha=1, label="generated points")
plt.title("A2A-FM")
plt.xlabel("$x_1$")
plt.ylabel("$x_2$")
plt.legend()
plt.gca().set_aspect("equal")
plt.grid()

# Evaluate partial diffusion

In [None]:
# Load the diffusion (CFG) baseline under its own name so the A2A-FM `model`
# loaded above is not replaced (re-running the A2A-FM cells stays valid).
diffusion_ckpts = sorted(Path(DIFFUSION_CFG_PATH).glob("*.ckpt"))
if not diffusion_ckpts:
    raise FileNotFoundError(
        f"No *.ckpt file found in {DIFFUSION_CFG_PATH!r}; set DIFFUSION_CFG_PATH to the checkpoint directory."
    )
path = diffusion_ckpts[0]
print(path)
diffusion_model = ProcessBasedGenerativeModel.load_from_checkpoint(path)
diffusion_model.eval()

# NOTE(review): the quantitative evaluation below passes w=0.3 explicitly;
# this call relies on the evaluator's default `w` — confirm they match.
pd_out = diffusion_model.evaluator.partial_diffusion(xdata.cuda(), targc, timesteps=0.3).cpu()
generated = pd_out[-1]
# Only the final state is used: the plotted path is a straight line from each
# initial point to its generated point (a visual aid, not the sampling path).
traj = torch.stack(
    [generated * t + xdata * (1 - t) for t in torch.linspace(0, 1, 100)], dim=0
)

In [None]:
# Partial diffusion result: gray = dataset, tan = straight-line display paths.
plt.scatter(dataset.xdata[:, 0], dataset.xdata[:, 1], alpha=0.1, s=1, c="gray")
plt.scatter(ydata[:, 0], ydata[:, 1], label="target points")
plt.scatter(traj[:, :, 0], traj[:, :, 1], alpha=1, s=0.1, c="tan")
plt.scatter(xdata[:, 0], xdata[:, 1], label="initial points")
plt.scatter(traj[-1, :, 0], traj[-1, :, 1], alpha=1, label="generated points")
plt.title("Partial diffusion")
plt.xlabel("$x_1$")
plt.ylabel("$x_2$")
plt.legend()
plt.gca().set_aspect("equal")
plt.grid()

# Evaluate Multi-marginal Stochastic Interpolant (SI)

In [None]:
from itertools import combinations
from pathlib import Path

import hydra
from hydra.utils import instantiate
from lightning.pytorch.callbacks import ModelCheckpoint
from stochastic_interpolant.path_optimizer import PathOptimizer
from stochastic_interpolant.utils import Velocity
from stochastic_interpolant.vector_field import VelocityField


def get_filename(logdir):
    """Return the next unused ``version_<n>`` directory path under ``logdir``.

    Existing ``version_<n>`` entries are compared by their integer suffix, so
    ``version_10`` counts as newer than ``version_9`` (a plain string sort would
    place it before ``version_2`` and hand back an existing directory).
    Entries whose suffix is not an integer are ignored.

    NOTE(review): when ``logdir`` has no versions this returns ``version_1``
    (not ``version_0``), preserving the original behavior.
    """
    prefix = "version_"
    numbers = [
        int(p.name[len(prefix):])
        for p in logdir.glob(prefix + "*")
        if p.name[len(prefix):].isdecimal()
    ]
    latest_number = max(numbers, default=0)
    return logdir / f"version_{latest_number+1}"


def _resolve_version_dir(logs_dir, v):
    """Return ``logs_dir/version_<v>``; a negative ``v`` indexes existing versions from the newest (-1 = latest)."""
    if v >= 0:
        return logs_dir / f"version_{v}"
    prefix = "version_"
    existing = sorted(
        (p for p in logs_dir.glob(prefix + "*") if p.name[len(prefix):].isdecimal()),
        key=lambda p: int(p.name[len(prefix):]),
    )
    if len(existing) < -v:
        raise FileNotFoundError(
            f"Cannot resolve version {v} under {logs_dir}: found {len(existing)} version(s)."
        )
    return existing[v]


with hydra.initialize(
    version_base=None, config_path="../stochastic_interpolants_config"
):
    config = hydra.compose(config_name="synthetic")
val_datasets = instantiate(config.val_datasets, _recursive_=False)
K = len(val_datasets)  # number of marginals
root_dir = Path(config.savedir)
ckpt_dirs_vf = [root_dir / f"g{k}" for k in range(K)]
c_list = config.c_list
assert len(c_list) == K
gs = []
# Default -1 = "latest version" for every marginal; a literal "version_-1"
# directory never exists, so it must be resolved against what is on disk.
gt_versions = [-1] * len(ckpt_dirs_vf)
if hasattr(config, "ckpt_versions_gt"):
    gt_versions = config.ckpt_versions_gt
    print(gt_versions)

for ckpt_dir, v in zip(ckpt_dirs_vf, gt_versions):
    ckpt_folder = _resolve_version_dir(ckpt_dir / "lightning_logs", v) / "checkpoints"
    candidates = sorted(ckpt_folder.glob("*.ckpt"))  # sorted: deterministic pick
    if not candidates:
        raise FileNotFoundError(f"No *.ckpt file found in {ckpt_folder}")
    ckpt_path = candidates[0]
    print(f"using check point {ckpt_path}")
    gs.append(
        VelocityField.load_from_checkpoint(ckpt_path, myconfig=config.vector_field)
    )
logdir = root_dir / "total_logs"
logdir = get_filename(logdir)


def prepare_vel(i, j):
    """Build the velocity field used to move points from class ``i`` to class ``j``.

    Checkpoints are looked up under ``root_dir / f"alpha_{min(i, j)}_{max(i, j)}"``.
    When ``j < i`` the indices are swapped for loading and the returned
    ``Velocity`` is built with ``reverse=True``.

    The newest checkpoint is selected with a natural sort, so ``version_10`` /
    ``epoch=10`` rank above ``version_9`` / ``epoch=9`` (a plain string sort
    ranks them lower).

    Raises:
        FileNotFoundError: if the ``alpha_*`` directory contains no ``*.ckpt`` file.
    """
    import re

    def natural_key(p):
        # re.split with a capture group alternates text / digit runs, so odd
        # positions are always digits and compare numerically.
        return [int(tok) if k % 2 else tok for k, tok in enumerate(re.split(r"(\d+)", str(p)))]

    reverse = j < i
    if reverse:
        i, j = j, i
    ckpt_dir = root_dir / f"alpha_{i}_{j}"
    candidates = sorted(ckpt_dir.rglob("*.ckpt"), key=natural_key)
    if not candidates:
        raise FileNotFoundError(f"No *.ckpt file found under {ckpt_dir}")
    ckpt_path = candidates[-1]
    print(f"using check point {ckpt_path}")
    alpha = PathOptimizer.load_from_checkpoint(
        ckpt_path, myconfig=config.alpha, gs=gs, i=i, j=j, datasets=val_datasets
    )
    return Velocity(gs, alpha, reverse=reverse)


def prepare_vels():
    """Return a ``config.K x config.K`` table of velocity fields.

    Entry ``[i][j]`` is ``prepare_vel(i, j)``; diagonal entries are ``None``.
    Entries are built row by row (``i`` outer, ``j`` inner).
    """
    n_marginals = config.K
    return [
        [None if src == dst else prepare_vel(src, dst) for dst in range(n_marginals)]
        for src in range(n_marginals)
    ]


def get_index(c: torch.Tensor, num_classes=None):
    """Map condition angles ``c`` to class indices ``0 .. num_classes - 1``.

    ``[0, pi/2]`` is split into ``num_classes`` equal-width bins (default: the
    global ``K``). Values at or below 0 fall into bin 0 and values above
    ``pi/2`` fall into the last bin, so the result is always a valid index.

    Returns:
        Integer tensor with the same shape as ``c``.
    """
    k = K if num_classes is None else num_classes
    edges = torch.linspace(0, torch.pi / 2, k + 1, device=c.device)
    # bucketize (right=False) returns i with edges[i-1] < c <= edges[i]; shift to 0-based bins.
    classes = torch.bucketize(c, edges) - 1
    return torch.clip(classes, min=0, max=k - 1)


# Table of pairwise velocity fields: integrate() below uses vels[i][j] to move
# points whose source class is i to target class j (diagonal entries are None).
vels = prepare_vels()

In [33]:
from torchdiffeq import odeint


@torch.no_grad()
def integrate(initx, initc, targc, verbose=False):
    """Transport ``initx`` to the target conditions with the pairwise SI velocity fields.

    Each point gets a source class ``get_index(initc)`` and a target class
    ``get_index(targc)``. Every group of points sharing a (source, target) pair
    is integrated with ``vels[source][target]`` via ``odeint`` on 100 uniform
    time steps in ``[0, 1]``. Points whose source and target class coincide
    keep a constant trajectory equal to their initial position.

    Args:
        initx: ``(b, d)`` initial points.
        initc: source conditions, one per point or a single one for all points.
        targc: target conditions, one per point or a single one for all points.
        verbose: print each (source, target) class pair being integrated.

    Returns:
        Tensor of shape ``(100, b, d)`` holding every point's trajectory.
    """
    b, d = initx.shape
    ts = torch.linspace(0, 1, 100, device=initx.device)
    traj = initx.expand(len(ts), b, d).clone()

    def classes_of(cond):
        # Flatten so a (b, 1) condition yields a (b,) mask even when b == 1;
        # a single condition is broadcast to all points.
        idx = get_index(cond).reshape(-1).to(initx.device)
        return idx.expand(b) if idx.numel() == 1 else idx

    src = classes_of(initc)
    dst = classes_of(targc)
    # Iterate over the class pairs actually present (not a hard-coded range).
    pairs = torch.unique(torch.stack([src, dst], dim=1), dim=0).tolist()
    for i, j in pairs:
        if i == j:
            continue
        mask = (src == i) & (dst == j)
        if verbose:
            print(i, j)
        # odeint returns (len(ts), n, d); no squeeze, which would break n == 1.
        traj[:, mask, :] = odeint(vels[i][j], initx[mask], ts)
    return traj

In [None]:
# Transport the initial sample with the pairwise multi-marginal SI velocities.
traj = integrate(initx=xdata.cuda(), initc=initc, targc=targc).cpu()

In [None]:
# Multi-marginal SI result: gray = dataset, tan = ODE integration paths.
plt.scatter(dataset.xdata[:, 0], dataset.xdata[:, 1], alpha=0.1, s=1, c="gray")
plt.scatter(ydata[:, 0], ydata[:, 1], label="target points")
plt.scatter(traj[:, :, 0], traj[:, :, 1], alpha=1, s=0.1, c="tan")
plt.scatter(xdata[:, 0], xdata[:, 1], label="initial points")
plt.scatter(traj[-1, :, 0], traj[-1, :, 1], alpha=1, label="generated points")
plt.title("Multi-marginal stochastic interpolant")
plt.xlabel("$x_1$")
plt.ylabel("$x_2$")
plt.legend()
plt.gca().set_aspect("equal")
plt.grid()

# Quantitative Evaluations

In [None]:
from tqdm.notebook import tqdm
def _sample_pair(initial_theta, target_theta, n=100):
    """Draw ``n`` points at each angle plus the matching ``(n, 1)`` CUDA condition tensors.

    Source points are sampled before target points, as in the qualitative cells.
    """
    initc = torch.ones(n, 1, device="cuda") * initial_theta
    xdata = dataset.sample(initial_theta, n)
    targc = torch.ones(n, 1, device="cuda") * target_theta
    ydata = dataset.sample(target_theta, n)
    return initc, xdata, targc, ydata


def _ot_matched_mse(x0, x1, generated):
    """MSE between ``generated`` and the OT partner in ``x1`` of each point of ``x0``.

    The partner is the argmax of the exact OT plan (``ot.emd``, uniform weights,
    cost from ``ot.dist``) between ``x0`` and ``x1``. Returns a Python float.
    """
    M = ot.dist(x0, x1)
    pi = ot.emd(torch.ones(len(x0)) / len(x0), torch.ones(len(x1)) / len(x1), M)
    idx = torch.argmax(pi, dim=1)
    return torch.mean((x1[idx] - generated) ** 2).item()


def _evaluate(generate, targthetas, initthetas):
    """Average OT-matched MSE of ``generate`` over paired (initial, target) angles.

    ``generate(xdata, initc, targc)`` must return a CPU trajectory whose last
    entry (``[-1]``) holds the generated points.
    """
    errors = []
    for initial_theta, target_theta in tqdm(
        zip(initthetas, targthetas), total=len(targthetas)
    ):
        initc, xdata, targc, ydata = _sample_pair(initial_theta, target_theta)
        traj = generate(xdata, initc, targc)
        errors.append(_ot_matched_mse(xdata, ydata, traj[-1]))
    return np.mean(errors)


def eval_model(model, targthetas, initthetas):
    """OT-matched MSE of ``model.evaluator.integrate`` (A2A-FM)."""

    def generate(xdata, initc, targc):
        return model.evaluator.integrate(
            xdata.cuda().clone(), targc.clone(), initc=initc.clone()
        ).cpu()

    return _evaluate(generate, targthetas, initthetas)


def eval_model_partial_diffusion(model, targthetas, initthetas, t=0.5, w=0.3):
    """OT-matched MSE of ``model.evaluator.partial_diffusion`` with ``timesteps=t`` and ``w=w``."""

    def generate(xdata, initc, targc):
        # partial_diffusion is only given targc; initc is unused here.
        return model.evaluator.partial_diffusion(
            xdata.cuda(), targc, timesteps=t, w=w
        ).cpu()

    return _evaluate(generate, targthetas, initthetas)


def eval_model_sto_interp(targthetas, initthetas):
    """OT-matched MSE of the multi-marginal SI ``integrate`` function."""

    def generate(xdata, initc, targc):
        return integrate(xdata.cuda(), initc, targc).cpu()

    return _evaluate(generate, targthetas, initthetas)
    

In [None]:
# Random (initial, target) angle pairs in [0, pi/2), shared by all three methods.
# A dedicated seeded generator keeps the pairs identical on every run,
# regardless of how many random draws earlier cells made.
theta_rng = np.random.default_rng(0)
targthetas = theta_rng.random(100) * np.pi / 2
initthetas = theta_rng.random(100) * np.pi / 2

a2afm_ckpts = sorted(Path(A2AFM_PATH).glob("*.ckpt"))  # sorted: deterministic pick
if not a2afm_ckpts:
    raise FileNotFoundError(f"No *.ckpt file found in {A2AFM_PATH!r}")
path = a2afm_ckpts[0]
print(path)
model = ProcessBasedGenerativeModel.load_from_checkpoint(path)
model.eval()
ours_loss = eval_model(model, targthetas, initthetas)

In [None]:
# Partial diffusion baseline (t=0.3, w=0.3), loaded under its own name so the
# A2A-FM `model` from the previous cell is not overwritten.
diffusion_ckpts = sorted(Path(DIFFUSION_CFG_PATH).glob("*.ckpt"))  # sorted: deterministic pick
if not diffusion_ckpts:
    raise FileNotFoundError(f"No *.ckpt file found in {DIFFUSION_CFG_PATH!r}")
path = diffusion_ckpts[0]
print(path)
diffusion_model = ProcessBasedGenerativeModel.load_from_checkpoint(path)
diffusion_model.eval()
# Variable name kept as-is: the summary cell reads `partial_diffuson_loss`.
partial_diffuson_loss = eval_model_partial_diffusion(
    diffusion_model, targthetas, initthetas, w=0.3, t=0.3
)

In [None]:
# Multi-marginal SI on the same (initial, target) angle pairs.
sto_interp_loss = eval_model_sto_interp(targthetas=targthetas, initthetas=initthetas)

In [None]:
# Summary: mean squared error to the pairwise-OT partner, one line per method.
results = (
    ("MSE from pairwise OT (A2A-FM): ", ours_loss),
    ("MSE from pairwise OT (Partial Diffusion): ", partial_diffuson_loss),
    ("MSE from pairwise OT (Stochastic Interpolant): ", sto_interp_loss),
)
for caption, value in results:
    print(caption, value)