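## Guided flow matching on 2D toy pairs (circle → s_curve, uniform → 8-Gaussian, 8-Gaussian → moon): W2 of MC guidance, FK steering and unguided baselines for CFM and OT-CFM (learned, CEG and gradient guidance are also supported by the sampling helper)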

In [1]:
%load_ext autoreload
%autoreload 2

from functools import partial
from typing import List, Tuple
from guided_flow.backbone.mlp import MLP
from guided_flow.backbone.wrapper import ExpEnergyMLPWrapper, GuidedMLPWrapper, MLPWrapper
from guided_flow.config.sampling import GuideFnConfig
from guided_flow.distributions.base import BaseDistribution, get_distribution
from guided_flow.distributions.gaussian import GaussianDistribution
from guided_flow.flow.optimal_transport import OTPlanSampler
from guided_flow.guidance.gradient_guidance import wrap_grad_fn
from guided_flow.utils.misc import deterministic
from guided_flow.utils.metrics import compute_w2 as w2
import torch
from torchdyn.core import NeuralODE
import numpy as np
from torch.distributions import Normal, Independent
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import os
from tqdm import tqdm
from guided_flow.config.sampling import ODEConfig
from fk_steering import evaluate_fk, create_fk_guide_cfg


# from guided_flow.utils.kl_divergence import compute_kl_divergence
# Hidden width of every MLP (flow model, guidance and CEG networks) built below.
# TRAINING_B is the OT-CFM training batch size; the sampling code in this
# notebook does not reference it.
MLP_WIDTH, TRAINING_B = 256, 256


def sample_x1_frompeJ(x1_sampler, x1_dist, device, B):
    """Draw ``B`` samples from the tilted target ``p(x) * exp(-J(x)) / Z``.

    Rejection sampling: proposals come from ``x1_sampler`` (samples of ``p``)
    and each is accepted with probability ``exp(-J) / max exp(-J)``.

    NOTE: the max is taken over the current proposal batch rather than a
    fixed global bound, so this is an approximation to exact rejection
    sampling.

    Args:
        x1_sampler: callable ``n -> Tensor`` returning ``n`` proposal samples.
        x1_dist: object exposing ``get_J(x)`` with one energy per sample.
        device: device the proposals (and the result) are moved to.
        B: number of samples to return.

    Returns:
        Tensor holding the first ``B`` accepted samples.
    """
    chunks = []
    n_kept = 0
    while True:
        proposals = x1_sampler(B).to(device)
        energy = x1_dist.get_J(proposals)
        # exp(-J) / max(exp(-J)) == exp(min(J) - J): the same acceptance ratio,
        # but computed without exp(-J) underflowing to 0 (0 / 0 = NaN would
        # reject every proposal and the loop would never terminate).
        acc_prob = torch.exp(energy.min() - energy)
        keep = torch.rand(proposals.shape[0], device=device) < acc_prob
        chunks.append(proposals[keep])
        n_kept += int(keep.sum())
        if n_kept >= B:
            break
    # Concatenate once instead of re-concatenating on every iteration.
    return torch.cat(chunks, 0)[:B]


def compute_w2(trajs, cfgs: List[GuideFnConfig]):
    """W2 distance between each trajectory's final samples and the tilted target.

    For every (trajectory, config) pair, ``cfg.ode_cfg.batch_size`` reference
    samples are drawn from ``p(x1) * exp(-J(x1))`` with ``sample_x1_frompeJ``
    and compared with ``traj[-1]``, the samples at the last time step.

    Args:
        trajs: trajectories, e.g. as returned by ``sample_and_compute_w2``.
        cfgs: configs aligned with ``trajs``; ``dist_pair[1]`` names the target.

    Returns:
        List with one W2 value per (trajectory, config) pair; extra items in
        the longer input are ignored, as with ``zip``.
    """
    w2s = []
    for traj, cfg in zip(trajs, cfgs):
        x1_dist = get_distribution(cfg.dist_pair[1])
        # NOTE(review): the reference uses exp(-J) with no cfg.scale factor,
        # even for configs whose guidance uses scale != 1 -- confirm intended.
        x1 = sample_x1_frompeJ(x1_dist.sample, x1_dist, cfg.ode_cfg.device, cfg.ode_cfg.batch_size)
        w2s.append(w2(traj[-1], x1))
    return w2s

def compute_unweighted_w2(trajs, cfgs: List[GuideFnConfig]):
    """W2 distance between each trajectory's final samples and the untilted target.

    Reference samples come straight from ``dist_pair[1]`` with no
    ``exp(-J)`` reweighting; this notebook uses it for the unguided baselines.

    Args:
        trajs: trajectories, e.g. as returned by ``sample_and_compute_w2``.
        cfgs: configs aligned with ``trajs``.

    Returns:
        List with one W2 value per (trajectory, config) pair.
    """
    w2s = []
    for traj, cfg in zip(trajs, cfgs):
        x1_dist = get_distribution(cfg.dist_pair[1])
        x1 = x1_dist.sample(cfg.ode_cfg.batch_size, cfg.ode_cfg.device)
        w2s.append(w2(traj[-1], x1))
    return w2s


def get_mc_guide_fn(x0_dist: BaseDistribution, x1_dist: BaseDistribution, mc_cfg: GuideFnConfig, cfm: str):
    """Build a Monte Carlo estimator of the guidance term for a flow model.

    A fixed set of ``mc_cfg.mc_batch_size`` samples ``z`` is drawn once here
    and reused at every solver step. For each current sample ``x`` the
    returned function estimates ``E_z[w * (e^{-scale * J(x1)} / Z - 1) * u]``.
    Here ``w = p_t(x | z) / p_t(x)`` is estimated from the samples, and ``Z``
    is the matching estimate of ``E_{z|x}[e^{-scale * J(x1)}]``. ``u`` is the
    conditional velocity of the straight path.

    Args:
        x0_dist: source distribution. ``'cfm'`` uses ``prob``; ``'ot_cfm'``
            uses ``sample``.
        x1_dist: target distribution; provides ``sample`` and ``get_J``.
        mc_cfg: reads ``mc_batch_size``, ``ep``, ``scale``,
            ``ode_cfg.device``, ``ode_cfg.batch_size`` and, for ``'ot_cfm'``,
            ``ot_std``.
        cfm: ``'cfm'`` or ``'ot_cfm'``; selects the conditional path model.

    Returns:
        ``guide_fn(t, x, dx_dt, model)`` returning a tensor shaped like ``x``.
        ``x`` must have exactly ``mc_cfg.ode_cfg.batch_size`` rows, because
        the MC samples are pre-tiled to that batch size.

    Raises:
        ValueError: if ``cfm`` is neither ``'cfm'`` nor ``'ot_cfm'``.
    """
    if cfm not in ('cfm', 'ot_cfm'):
        # Fail at construction instead of silently returning None.
        raise ValueError(f"Unknown cfm type: {cfm!r}")

    device = mc_cfg.ode_cfg.device
    ode_b = mc_cfg.ode_cfg.batch_size

    def log_cfm_p_t1(x1, xt, t):
        """log p_t(xt | x1) for the path xt = t * x1 + (1 - t) * x0, x0 ~ x0_dist."""
        # Invert the path for x0; ep keeps the division finite as t -> 1.
        x0 = xt / (1 - t + mc_cfg.ep) - (t + mc_cfg.ep) / (1 - t + mc_cfg.ep) * x1  # (N, dim)
        # Change of variables contributes a (1 - t)^-dim Jacobian factor;
        # every row shares the same time, so t[0] is used.
        p1t = x0_dist.prob(x0).clamp(1e-8) / (1 - t[0] + mc_cfg.ep) ** xt.shape[-1]  # (N,)
        return p1t.log()

    def ot_cfm_log_p_tz(x0, x1, xt, t, std=None):
        """log N(xt; t * x1 + (1 - t) * x0, std^2 I), summed over the data dims."""
        mean = t * x1 + (1 - t) * x0  # (N, dim)
        # Ground-truth std is 0. Too small requires a large mc_batch_size;
        # too large makes the estimate inaccurate.
        distribution = Independent(Normal(loc=mean, scale=std), 1)
        return distribution.log_prob(xt)  # (N,)

    def guide_fn(t, x, dx_dt, model, x0=None, x1=None, Jx1=None):
        """Estimate the guidance term at ``(t, x)``.

        Args:
            t: scalar time tensor (0-d or one element). It is tiled with
                ``t.repeat(B * b, 1)``, so a per-sample ``(b, 1)`` t would not fit.
            x: current samples, shape (b, dim).
            dx_dt: unused; kept for the common guide signature.
            model: unused; kept for the common guide signature.
            x0, x1, Jx1: pre-tiled MC samples and ``e^{-scale * J(x1)}``,
                bound with ``functools.partial`` below.
        """
        b, dim = x.shape
        B = mc_cfg.mc_batch_size
        log_B = torch.log(torch.tensor(B, device=x.device))
        x_ = x.repeat(B, 1)  # (B * b, dim); row i * b + j holds x[j]
        t_ = t.repeat(B * b, 1)  # (B * b, 1)

        if cfm == 'cfm':
            log_p_tz_x = log_cfm_p_t1(x1, x_, t_)  # (B * b,)
            u = (x1 - x_) / (1 - t_ + mc_cfg.ep)  # (B * b, dim)
        else:  # 'ot_cfm'
            log_p_tz_x = ot_cfm_log_p_tz(x0, x1, x_, t_, std=mc_cfg.ot_std)  # (B * b,)
            u = x1 - x0  # (B * b, dim)

        log_p_tz_x = log_p_tz_x.reshape(B, b, 1)
        # log p_t(x): MC average of the conditional densities over z.
        log_p_t_x = log_p_tz_x.logsumexp(0) - log_B  # (b, 1)
        # log Z = log E_{z|x}[e^{-J}], computed in log space.
        logZ = torch.logsumexp(log_p_tz_x + torch.log(Jx1).reshape(B, b, 1), 0) - log_B - log_p_t_x  # (b, 1)
        Z = torch.exp(logZ)

        # Posterior weight p_t(x | z) / p_t(x). The 1e-8 guards Z == 0 in both
        # branches (e.g. when every e^{-J} underflows).
        weight = (log_p_tz_x - log_p_t_x.unsqueeze(0)).exp()  # (B, b, 1)
        g = weight * (Jx1.reshape(B, b, 1) / (Z + 1e-8).unsqueeze(0) - 1) * u.reshape(B, b, dim)  # (B, b, dim)
        return g.mean(0)

    def tile(samples):
        """(MC_B, dim) -> (MC_B * ode_b, dim), where row i * ode_b + j is samples[i]."""
        return (
            samples.to(device)
            .unsqueeze(0)
            .repeat(ode_b, 1, 1)
            .permute(1, 0, 2)
            .reshape(-1, samples.shape[-1])
        )

    if cfm == 'cfm':
        x1 = tile(x1_dist.sample(mc_cfg.mc_batch_size))
        Jx1 = torch.exp(-mc_cfg.scale * x1_dist.get_J(x1))  # (MC_B * b,)
        return partial(guide_fn, x1=x1, Jx1=Jx1)

    # 'ot_cfm': x0 and x1 are sampled separately, x0 first.
    x0 = x0_dist.sample(mc_cfg.mc_batch_size)  # (MC_B, dim)
    x1 = x1_dist.sample(mc_cfg.mc_batch_size)  # (MC_B, dim)
    x0_ = tile(x0)
    x1_ = tile(x1)
    J_ = torch.exp(-mc_cfg.scale * x1_dist.get_J(x1_))  # (MC_B * b,)
    return partial(guide_fn, x0=x0_, x1=x1_, Jx1=J_)

def get_guide_fn(dist: BaseDistribution, cfg: GuideFnConfig):
    """Build a gradient-of-energy guidance function.

    ``cfg.guide_type`` selects how ``-grad J`` is taken:

    * ``'g_cov_A'``: w.r.t. the one-step prediction ``x1_pred = x + dx_dt * (1 - t)``.
    * ``'g_cov_G'``: w.r.t. ``x``, back-propagating through ``model``'s prediction.

    If autograd fails, the guide returns a zero field for that step (best effort).

    Raises:
        ValueError: for any other ``cfg.guide_type``.
    """
    if cfg.guide_type not in ('g_cov_A', 'g_cov_G'):
        raise ValueError(f"Unknown guide function: {cfg.guide_type}")

    def guide_fn(t, x, dx_dt, model):
        """Return ``-grad J`` for the current state, or zeros if autograd fails.

        Args:
            t: current flow time (tensor; broadcast against ``x``).
            x: current samples, shape (b, dim).
            dx_dt: model velocity at ``(x, t)``, same shape as ``x``.
            model: flow network; only called for ``'g_cov_G'``.
        """
        if cfg.guide_type == 'g_cov_A':
            # Euler extrapolation of the current state to t = 1.
            x1_pred = x + dx_dt * (1 - t)
            try:
                with torch.enable_grad():
                    x1_pred = x1_pred.requires_grad_(True)
                    J = dist.get_J(x1_pred)
                    # Gradient w.r.t. the prediction itself, not through the model.
                    return -torch.autograd.grad(J.sum(), x1_pred, create_graph=True)[0]
            except Exception:
                # Best effort: any failure (presumably J not differentiable at
                # these points) disables guidance for this step.
                return torch.zeros_like(x)

        if cfg.guide_type == 'g_cov_G':
            with torch.enable_grad():
                x = x.requires_grad_(True)
                # Same extrapolation, recomputed from x so the gradient flows
                # through the model's velocity prediction.
                x1_pred = x + model(torch.cat([x, t.repeat(x.shape[0])[:, None]], 1)) * (1 - t)
                J = dist.get_J(x1_pred)
                try:
                    return -torch.autograd.grad(J.sum(), x, create_graph=True)[0]
                except Exception:
                    return torch.zeros_like(x)

        raise ValueError(f"Unknown guide function: {cfg.guide_type}")

    # wrap_grad_fn (project helper) applies the guidance scale and schedule.
    return wrap_grad_fn(cfg.guide_scale, cfg.guide_schedule, guide_fn)

def get_sim_mc_guide_fn(x1_dist: BaseDistribution, cfg: GuideFnConfig):
    """Build the simulated-MC guidance of Eq. 12, scaled/scheduled by ``wrap_grad_fn``.

    At each step the predicted endpoint ``x1_pred = x + dx_dt * (1 - t)`` is
    perturbed with ``cfg.sim_mc_n`` Gaussian draws of std ``cfg.sim_mc_std``.
    Those draws play the role of ``z`` in Eq. 12.
    """
    def guide_fn(t, x, dx_dt, model):
        """
        Implements guidance following Eq. 12
        Args:
            t: flow time. float
            x: current sample x_t. Tensor, shape (b, dim)
            dx_dt: current predicted VF. Tensor, shape (b, dim)
            model: flow model (unused here).
        """
        n = cfg.sim_mc_n
        dim = x.shape[-1]
        x1_pred = x + dx_dt * (1 - t)  # (b, dim)
        x1 = torch.randn_like(x1_pred.unsqueeze(0).repeat(n, 1, 1)) * cfg.sim_mc_std + x1_pred  # (n, b, dim)
        log_w = (-cfg.scale * x1_dist.get_J(x1.reshape(-1, dim))).reshape(n, -1)  # (n, b)
        # e^{-J} / Z with Z = mean over draws of e^{-J}, i.e. n * softmax.
        # Working in log space keeps the ratio finite when every e^{-J}
        # underflows, where e^{-J} / (mean + 1e-8) would collapse all weights to 0.
        ratio = torch.softmax(log_w, dim=0) * n  # (n, b)
        v = (x1 - x) / (1 - t + cfg.ep)  # conditional VF v_{t|z} in Eq. 12, (n, b, dim)
        g = (ratio - 1).unsqueeze(2) * v  # g in Eq. 12, (n, b, dim)
        return g.mean(0)
    return wrap_grad_fn(cfg.guide_scale, cfg.guide_schedule, guide_fn)

def evaluate(x0_sampler, x1_sampler, model, guide_fn, cfg: ODEConfig):
    """Integrate the guided flow from fresh source samples and return the trajectory.

    Args:
        x0_sampler: callable ``n -> Tensor`` drawing source samples.
        x1_sampler: unused; kept so every caller shares one signature.
        model: flow vector-field network.
        guide_fn: guidance term passed to ``GuidedMLPWrapper``.
        cfg: reads ``batch_size``, ``device``, ``t_end`` and ``num_steps``.

    Returns:
        The output of ``NeuralODE.trajectory`` over ``num_steps`` evenly
        spaced times in ``[0, t_end]``.
    """
    guided_field = GuidedMLPWrapper(model, guide_fn=guide_fn, scheduler=lambda t: 1)
    node = NeuralODE(guided_field, solver="euler", sensitivity="adjoint", atol=1e-4, rtol=1e-4)

    with torch.no_grad():
        start = x0_sampler(cfg.batch_size).to(cfg.device)
        times = torch.linspace(0, cfg.t_end, cfg.num_steps)
        result = node.trajectory(start, t_span=times)
    return result


def _ckpt_dir(cfg: GuideFnConfig) -> str:
    """Directory holding the checkpoints trained for ``cfg.cfm`` on ``cfg.dist_pair``."""
    src, tgt = cfg.dist_pair[0], cfg.dist_pair[1]
    return f'../logs/{src}-{tgt}/{cfg.cfm}_{src}_{tgt}'


def _load_weights(module, path, device):
    """Load the state dict stored at ``path`` into ``module`` and return ``module``."""
    # The files hold state dicts, so weights_only=True suffices; it refuses
    # arbitrary pickled objects. map_location avoids device mismatches.
    module.load_state_dict(torch.load(path, map_location=device, weights_only=True))
    return module


def _plot_energy_slice(model_Z, device, t=0.9, n=100):
    """Debug plot of ``model_Z`` on an ``n`` x ``n`` grid over [0, 1]^2 at time ``t``."""
    axis = torch.linspace(0, 1, n)
    XX, YY = torch.meshgrid(axis, axis, indexing='ij')
    xy = torch.stack([XX.flatten(), YY.flatten()], 1)
    tt = torch.zeros(n * n, 1) + t
    _, ax = plt.subplots()
    with torch.no_grad():
        values = model_Z(torch.cat([xy, tt], 1).to(device))
    ax.imshow(values.detach().cpu().numpy().reshape(n, n))


def sample_and_compute_w2(guide_cfgs: List[GuideFnConfig]):
    """Sample one trajectory per config with the requested guidance method.

    Despite the name, no W2 is computed here. Callers pass the returned
    trajectories to ``compute_w2`` / ``compute_unweighted_w2``.

    Supported ``cfg.guide_type`` values: ``'mc'``, ``'learned'``, ``'ceg'``,
    ``'g_cov_A'``, ``'g_cov_G'``, ``'g_sim_MC'``, ``'fk'``, ``'plain'`` / ``'cfm'``.

    Returns:
        ``(trajs, None)``: one trajectory per config, in order. The second
        element is always ``None``.

    Raises:
        ValueError: for an unrecognised ``cfg.guide_type``.
    """
    if guide_cfgs:
        print("Monte Carlo batch size:", guide_cfgs[0].mc_batch_size)

    trajs = []
    for cfg in guide_cfgs:
        device = cfg.ode_cfg.device
        src, tgt = cfg.dist_pair[0], cfg.dist_pair[1]
        ckpt_dir = _ckpt_dir(cfg)

        # Initialize samplers and the pretrained flow model.
        x0_dist = get_distribution(src)
        x1_dist = get_distribution(tgt)
        x0_sampler = x0_dist.sample
        x1_sampler = x1_dist.sample

        model = MLP(dim=2, w=MLP_WIDTH, time_varying=True).to(device)
        _load_weights(model, f'{ckpt_dir}/{cfg.cfm}_{src}_{tgt}.pth', device)

        if cfg.guide_type == 'mc':
            traj = evaluate(x0_sampler, x1_sampler, model, get_mc_guide_fn(x0_dist, x1_dist, cfg, cfg.cfm), cfg.ode_cfg)

        elif cfg.guide_type == 'learned':
            model_G = MLP(dim=2, out_dim=2, w=MLP_WIDTH, time_varying=True).to(device)
            _load_weights(model_G, f'{ckpt_dir}/guidance_matching_{cfg.gm_type}_scale_{cfg.scale}_{src}_{tgt}.pth', device)
            traj = evaluate(x0_sampler, x1_sampler, model, MLPWrapper(model_G, scheduler=lambda t: 1., clamp=0), cfg.ode_cfg)

        elif cfg.guide_type == 'ceg':
            model_Z = MLP(dim=2, out_dim=1, w=MLP_WIDTH, time_varying=True, exp_final=False).to(device)
            _load_weights(model_Z, f'{ckpt_dir}/ceg_scale_{cfg.scale}_{src}_{tgt}.pth', device)
            _plot_energy_slice(model_Z, device)
            traj = evaluate(x0_sampler, x1_sampler, model, ExpEnergyMLPWrapper(model_Z, scheduler=lambda t: 1, clamp=1), cfg.ode_cfg)

        elif cfg.guide_type in ('g_cov_A', 'g_cov_G'):
            traj = evaluate(x0_sampler, x1_sampler, model, get_guide_fn(x1_dist, cfg), cfg.ode_cfg)

        elif cfg.guide_type == 'g_sim_MC':
            traj = evaluate(x0_sampler, x1_sampler, model, get_sim_mc_guide_fn(x1_dist, cfg), cfg.ode_cfg)

        elif cfg.guide_type == 'fk':
            traj = evaluate_fk(x0_sampler, x1_sampler, model, x1_dist, cfg.fk_params, cfg.ode_cfg)

        elif cfg.guide_type in ('plain', 'cfm'):
            # Unguided baseline: the flow model alone is the vector field.
            node = NeuralODE(
                MLPWrapper(model, scheduler=lambda t: 1, clamp=0),
                solver="euler", sensitivity="adjoint", atol=1e-4, rtol=1e-4
            )
            with torch.no_grad():
                traj = node.trajectory(
                    x0_sampler(cfg.ode_cfg.batch_size).to(device),
                    t_span=torch.linspace(0, cfg.ode_cfg.t_end, cfg.ode_cfg.num_steps)
                )

        else:
            # Without this, `traj` from the previous config would be appended.
            raise ValueError(f"Unknown guide type: {cfg.guide_type}")

        trajs.append(traj)
    return trajs, None

# Presumably seeds the RNGs (torch/numpy/...) with 0 for reproducibility;
# see guided_flow.utils.misc.deterministic.
deterministic(0)

In [15]:
# Shared sampling settings (several are re-declared in later cells).
num_steps, ode_batch_size = 40, 1024  # ODE solver steps; samples per ODE batch
mc_batch_size = 10240  # Monte Carlo samples per guidance estimate
disp_traj_batch = 128  # passed to GuideFnConfig as disp_traj_batch
cfm = 'ot_cfm'  # flow-matching variant used by the OT-CFM, FK and plain cells

In [16]:

num_steps, mc_batch_size = 40, 10240
disp_traj_batch, ode_batch_size = 128, 1024

# MC guidance on the CFM checkpoints: one config per (dist_pair, ep, t_end, extra kwargs).
guide_cfgs_mc_cfm = [
    GuideFnConfig(
        cfm="cfm",
        dist_pair=pair,
        mc_batch_size=mc_batch_size,
        ep=ep,
        ode_cfg=ODEConfig(t_end=t_end, num_steps=num_steps, device='cuda:0', batch_size=ode_batch_size),
        disp_traj_batch=disp_traj_batch,
        **extra,
    )
    for pair, ep, t_end, extra in (
        (('circle', 's_curve'), 5e-2, 1.0, {'scale': 1}),
        (('uniform', '8gaussian'), 1e-3, 1, {}),  # no scale: GuideFnConfig default
        (('8gaussian', 'moon'), 1e-2, 1.0, {'scale': 1}),
    )
]

# 10 sampling trials following a single reseed.
deterministic(0)
w_2_mc_cfm = []
for _trial in range(10):
    trajs_mc_cfm, _ = sample_and_compute_w2(guide_cfgs_mc_cfm)
    w_2_mc_cfm.append(compute_w2(trajs_mc_cfm, guide_cfgs_mc_cfm))

w_2_mc_cfm

Monte Carlo batch size: 10240


  model.load_state_dict(torch.load(f'../logs/{cfg.dist_pair[0]}-{cfg.dist_pair[1]}/{cfg.cfm}_{cfg.dist_pair[0]}_{cfg.dist_pair[1]}/{cfg.cfm}_{cfg.dist_pair[0]}_{cfg.dist_pair[1]}.pth'))


Monte Carlo batch size: 10240
Monte Carlo batch size: 10240
Monte Carlo batch size: 10240
Monte Carlo batch size: 10240
Monte Carlo batch size: 10240
Monte Carlo batch size: 10240
Monte Carlo batch size: 10240
Monte Carlo batch size: 10240
Monte Carlo batch size: 10240


[[0.3460798559184829, 0.20576434674641925, 0.10966435191182855],
 [0.3448578052581991, 0.16637437622492687, 0.11015870516418107],
 [0.3440059241586379, 0.14976293112519604, 0.14200447072864333],
 [0.3244642793909614, 0.16452719079974892, 0.1362307332987635],
 [0.2822554850349835, 0.15478536881468344, 0.11442039089484589],
 [0.33951585375635973, 0.16739490762591525, 0.13027120088743638],
 [0.33048232509887676, 0.2415759199582516, 0.12880329751762848],
 [0.35526426272108613, 0.20571899010862407, 0.11119113909698641],
 [0.3448609345231764, 0.16501166166307896, 0.11536080346503441],
 [0.3565914847923446, 0.1600057165067462, 0.12127416283577756]]

In [18]:
# Per-pair W2 mean and std over the trials (axis 0 indexes the trial).
w_2_mc_cfm_mean, w_2_mc_cfm_std = np.mean(w_2_mc_cfm, axis=0), np.std(w_2_mc_cfm, axis=0)
print(w_2_mc_cfm_mean, w_2_mc_cfm_std, sep='\n')

[0.33683782 0.17809214 0.12193793]
[0.02043313 0.02799803 0.01108433]


In [17]:

num_steps, ode_batch_size, mc_batch_size = 40, 1024, 10240
cfm = 'ot_cfm'

# MC guidance on the OT-CFM checkpoints; every run shares ep=0.05 and ot_std=0.3.
guide_cfgs_mc_ot_cfm = [
    GuideFnConfig(
        cfm=cfm,
        dist_pair=pair,
        mc_batch_size=mc_batch_size,
        ep=0.05,
        ot_std=0.3,
        scale=scale,
        ode_cfg=ODEConfig(t_end=t_end, num_steps=num_steps, device='cuda:0', batch_size=ode_batch_size),
        disp_traj_batch=disp_traj_batch,
    )
    for pair, scale, t_end in (
        (('circle', 's_curve'), 1.5, 1.0),
        (('uniform', '8gaussian'), 2, 1),
        (('8gaussian', 'moon'), 1.5, 1.0),
    )
]

# 10 sampling trials following a single reseed.
deterministic(0)
w_2_mc_ot_cfm = []
for _trial in range(10):
    trajs_mc_ot_cfm, _ = sample_and_compute_w2(guide_cfgs_mc_ot_cfm)
    w_2_mc_ot_cfm.append(compute_w2(trajs_mc_ot_cfm, guide_cfgs_mc_ot_cfm))

w_2_mc_ot_cfm

Monte Carlo batch size: 10240


  model.load_state_dict(torch.load(f'../logs/{cfg.dist_pair[0]}-{cfg.dist_pair[1]}/{cfg.cfm}_{cfg.dist_pair[0]}_{cfg.dist_pair[1]}/{cfg.cfm}_{cfg.dist_pair[0]}_{cfg.dist_pair[1]}.pth'))


Monte Carlo batch size: 10240
Monte Carlo batch size: 10240
Monte Carlo batch size: 10240
Monte Carlo batch size: 10240
Monte Carlo batch size: 10240
Monte Carlo batch size: 10240
Monte Carlo batch size: 10240
Monte Carlo batch size: 10240
Monte Carlo batch size: 10240


[[0.22817266275197415, 0.35572792732769526, 0.23363914126445148],
 [0.2321455058815066, 0.34102956150330666, 0.22832436719583352],
 [0.2300536040104048, 0.34544109775856835, 0.213040883054331],
 [0.262859070630868, 0.3376929216902852, 0.22018501862557827],
 [0.23093576848866196, 0.33428825630124015, 0.24293387694326773],
 [0.2726666177453744, 0.35065835867042466, 0.22368135970361785],
 [0.23124001267907576, 0.34102390660723786, 0.22876696080072315],
 [0.2587846497327946, 0.35123710064540464, 0.22417950437163864],
 [0.2746422271186173, 0.3352578834639605, 0.23988095549245136],
 [0.22147808917793294, 0.3352055758863617, 0.21528866545030392]]

In [19]:
# Per-pair W2 mean and std over the trials (axis 0 indexes the trial).
w_2_mc_ot_cfm_mean, w_2_mc_ot_cfm_std = np.mean(w_2_mc_ot_cfm, axis=0), np.std(w_2_mc_ot_cfm, axis=0)
print(w_2_mc_ot_cfm_mean, w_2_mc_ot_cfm_std, sep='\n')

[0.24429782 0.34275626 0.22699207]
[0.01939162 0.00725538 0.00930691]


In [20]:
num_samples, fk_steps = 128, 40

# FK steering on the OT-CFM checkpoints. Settings are shared except that the
# first pair also sets fk_steering_temperature.
guide_cfgs_fk_ot_cfm = [
    create_fk_guide_cfg(
        cfm=cfm,
        dist_pair=pair,
        scale=1.0,
        num_samples=num_samples,
        fk_potential_scheduler='harmonic_sum',
        resample_method='residual',
        resample_freq=10,
        ode_cfg=ODEConfig(t_end=1.0, num_steps=fk_steps, batch_size=ode_batch_size),
        **extra,
    )
    for pair, extra in (
        (('circle', 's_curve'), {'fk_steering_temperature': 1.0}),
        (('uniform', '8gaussian'), {}),
        (('8gaussian', 'moon'), {}),
    )
]

# 10 sampling trials following a single reseed.
deterministic(0)
w2s_fk_ot_cfm = []
for _trial in range(10):
    trajs_fk_ot_cfm, _ = sample_and_compute_w2(guide_cfgs_fk_ot_cfm)
    w2s_fk_ot_cfm.append(compute_w2(trajs_fk_ot_cfm, guide_cfgs_fk_ot_cfm))

w2s_fk_ot_cfm


Monte Carlo batch size: 1024


  model.load_state_dict(torch.load(f'../logs/{cfg.dist_pair[0]}-{cfg.dist_pair[1]}/{cfg.cfm}_{cfg.dist_pair[0]}_{cfg.dist_pair[1]}/{cfg.cfm}_{cfg.dist_pair[0]}_{cfg.dist_pair[1]}.pth'))


Monte Carlo batch size: 1024
Monte Carlo batch size: 1024
Monte Carlo batch size: 1024
Monte Carlo batch size: 1024
Monte Carlo batch size: 1024
Monte Carlo batch size: 1024
Monte Carlo batch size: 1024
Monte Carlo batch size: 1024
Monte Carlo batch size: 1024


[[0.24362046900326023, 0.07230156491300228, 0.07852409385546293],
 [0.20053407007942295, 0.11892456916550699, 0.06866183667646909],
 [0.3007502143289958, 0.12201327052124956, 0.05634350195398302],
 [0.2610194832097039, 0.20546555961962185, 0.0736289949409307],
 [0.2772547361363153, 0.10184224138614084, 0.0625664751664994],
 [0.23237408618076733, 0.0991061503557292, 0.06694504832151671],
 [0.1460313478144786, 0.11818549975503813, 0.07421850631169574],
 [0.23729990429175313, 0.13492513413225898, 0.0853091638251044],
 [0.14647192813321008, 0.2083862172761897, 0.06148638140322131],
 [0.11330024476641773, 0.1419187616566589, 0.05335496162129663]]

In [28]:
# Per-pair W2 mean and std over the trials (axis 0 indexes the trial).
w2s_fk_ot_cfm_mean, w2s_fk_ot_cfm_std = np.mean(w2s_fk_ot_cfm, axis=0), np.std(w2s_fk_ot_cfm, axis=0)
print(w2s_fk_ot_cfm_mean, w2s_fk_ot_cfm_std, sep='\n')

[0.21586565 0.1323069  0.0681039 ]
[0.05912884 0.04165931 0.00951033]


In [26]:
num_samples, fk_steps = 128, 40

# FK steering on the CFM checkpoints. Settings are shared except that the
# first pair also sets fk_steering_temperature.
guide_cfgs_fk_cfm = [
    create_fk_guide_cfg(
        cfm="cfm",
        dist_pair=pair,
        scale=1.0,
        num_samples=num_samples,
        fk_potential_scheduler='harmonic_sum',
        resample_freq=10,
        resample_method='residual',
        ode_cfg=ODEConfig(t_end=1.0, num_steps=fk_steps, batch_size=ode_batch_size),
        **extra,
    )
    for pair, extra in (
        (('circle', 's_curve'), {'fk_steering_temperature': 1.0}),
        (('uniform', '8gaussian'), {}),
        (('8gaussian', 'moon'), {}),
    )
]
# 10 sampling trials following a single reseed.
deterministic(0)
w_2_fk_cfm = []
for _trial in range(10):
    trajs_fk_cfm, _ = sample_and_compute_w2(guide_cfgs_fk_cfm)
    w_2_fk_cfm.append(compute_w2(trajs_fk_cfm, guide_cfgs_fk_cfm))

w_2_fk_cfm

Monte Carlo batch size: 1024


  model.load_state_dict(torch.load(f'../logs/{cfg.dist_pair[0]}-{cfg.dist_pair[1]}/{cfg.cfm}_{cfg.dist_pair[0]}_{cfg.dist_pair[1]}/{cfg.cfm}_{cfg.dist_pair[0]}_{cfg.dist_pair[1]}.pth'))


Monte Carlo batch size: 1024
Monte Carlo batch size: 1024
Monte Carlo batch size: 1024
Monte Carlo batch size: 1024
Monte Carlo batch size: 1024
Monte Carlo batch size: 1024
Monte Carlo batch size: 1024
Monte Carlo batch size: 1024
Monte Carlo batch size: 1024


[[0.40087017768182953, 0.09087170984673199, 0.1400530436343675],
 [0.3015677014565128, 0.08415276742517853, 0.08504342233826034],
 [0.2982333932800692, 0.12658303113682112, 0.10226961609310374],
 [0.14082881048252585, 0.13580652382942332, 0.11286135574981801],
 [0.24138424998973648, 0.17839464811410496, 0.08228535626737513],
 [0.1832051156059351, 0.14084591583183628, 0.09271955952978851],
 [0.34441032122005316, 0.12953478200643787, 0.07785058339274747],
 [0.22474099001279724, 0.17283068883798944, 0.07689874194723346],
 [0.2826700369365647, 0.18640703894961524, 0.08338107821706077],
 [0.32143995748576426, 0.12009340895050144, 0.12157328000372469]]

In [27]:
# Per-pair W2 mean and std over the trials (axis 0 indexes the trial).
w2s_fk_cfm_mean, w2s_fk_cfm_std = np.mean(w_2_fk_cfm, axis=0), np.std(w_2_fk_cfm, axis=0)
print(w2s_fk_cfm_mean, w2s_fk_cfm_std, sep='\n')

[0.27393508 0.13655205 0.0974936 ]
[0.07361731 0.03289885 0.02017019]


In [11]:
# Unguided baselines on the OT-CFM checkpoints (100 solver steps each).
guide_cfgs_plain_ot_cfm = [
    GuideFnConfig(
        cfm=cfm,
        dist_pair=pair,
        guide_type='plain',
        ode_cfg=ODEConfig(t_end=1.0, num_steps=100, batch_size=ode_batch_size),
        **extra,
    )
    for pair, extra in (
        (('circle', 's_curve'), {'scale': 1}),
        (('uniform', '8gaussian'), {'scale': 1}),
        (('8gaussian', 'moon'), {}),
    )
]

# 10 sampling trials following a single reseed; W2 against the untilted target.
deterministic(0)
w2s_plain_ot_cfm = []
for _trial in range(10):
    trajs_plain_ot_cfm, _ = sample_and_compute_w2(guide_cfgs_plain_ot_cfm)
    w2s_plain_ot_cfm.append(compute_unweighted_w2(trajs_plain_ot_cfm, guide_cfgs_plain_ot_cfm))

w2s_plain_ot_cfm

Monte Carlo batch size: 1024


  model.load_state_dict(torch.load(f'../logs/{cfg.dist_pair[0]}-{cfg.dist_pair[1]}/{cfg.cfm}_{cfg.dist_pair[0]}_{cfg.dist_pair[1]}/{cfg.cfm}_{cfg.dist_pair[0]}_{cfg.dist_pair[1]}.pth'))


Monte Carlo batch size: 1024
Monte Carlo batch size: 1024
Monte Carlo batch size: 1024
Monte Carlo batch size: 1024
Monte Carlo batch size: 1024
Monte Carlo batch size: 1024
Monte Carlo batch size: 1024
Monte Carlo batch size: 1024
Monte Carlo batch size: 1024


[[0.1332946272797528, 0.1037292227951248, 0.08026145424648462],
 [0.13385909465190096, 0.11279275045745335, 0.05883949770256176],
 [0.10095108802602015, 0.1497689558026552, 0.06302064403391057],
 [0.15157546768029265, 0.11918498726893402, 0.08620107305203592],
 [0.0862614457078605, 0.15660184235344418, 0.06128817738296346],
 [0.13309688161928535, 0.10437037930818178, 0.046668999565763855],
 [0.1167532046886388, 0.1205464729909164, 0.055002711368196545],
 [0.12472876987364688, 0.19931695673215657, 0.05247170643191792],
 [0.08016938458809365, 0.16944716700471735, 0.056117664453033685],
 [0.08323520711676544, 0.09374889342703024, 0.07757159770733393]]

In [12]:
# Per-pair W2 mean and std over the trials (axis 0 indexes the trial).
w2s_plain_ot_cfm_mean, w2s_plain_ot_cfm_std = np.mean(w2s_plain_ot_cfm, axis=0), np.std(w2s_plain_ot_cfm, axis=0)
print(w2s_plain_ot_cfm_mean, w2s_plain_ot_cfm_std, sep='\n')

[0.11439252 0.13295076 0.06374435]
[0.02386442 0.03245908 0.01245853]


In [13]:
# Unguided baselines on the CFM checkpoints (100 solver steps, 1024 samples each).
guide_cfgs_plain_cfm = [
    GuideFnConfig(
        cfm="cfm",
        dist_pair=pair,
        guide_type='plain',
        ode_cfg=ODEConfig(t_end=1.0, num_steps=100, batch_size=1024),
        **extra,
    )
    for pair, extra in (
        (('circle', 's_curve'), {'scale': 1}),
        (('uniform', '8gaussian'), {'scale': 1}),
        (('8gaussian', 'moon'), {}),
    )
]

# 10 sampling trials following a single reseed; W2 against the untilted target.
deterministic(0)
w2s_plain_cfm = []
for _trial in range(10):
    trajs_plain_cfm, _ = sample_and_compute_w2(guide_cfgs_plain_cfm)
    w2s_plain_cfm.append(compute_unweighted_w2(trajs_plain_cfm, guide_cfgs_plain_cfm))

w2s_plain_cfm

Monte Carlo batch size: 1024


  model.load_state_dict(torch.load(f'../logs/{cfg.dist_pair[0]}-{cfg.dist_pair[1]}/{cfg.cfm}_{cfg.dist_pair[0]}_{cfg.dist_pair[1]}/{cfg.cfm}_{cfg.dist_pair[0]}_{cfg.dist_pair[1]}.pth'))


Monte Carlo batch size: 1024
Monte Carlo batch size: 1024
Monte Carlo batch size: 1024
Monte Carlo batch size: 1024
Monte Carlo batch size: 1024
Monte Carlo batch size: 1024
Monte Carlo batch size: 1024
Monte Carlo batch size: 1024
Monte Carlo batch size: 1024


[[0.21076680856771404, 0.09270063982362946, 0.07825950258873694],
 [0.1967778672357857, 0.11839058317260423, 0.08503302750804916],
 [0.18737991564258183, 0.11449741393163038, 0.09660926688619002],
 [0.20068317384703638, 0.1214364816265766, 0.0988927008898253],
 [0.18505342613678893, 0.1553370027766992, 0.09148961575622169],
 [0.17966549368152743, 0.12395935803802197, 0.07219065128990627],
 [0.1425727673880871, 0.12344463734612343, 0.08090414180151269],
 [0.15225909225521192, 0.16950508809332032, 0.08167584471871521],
 [0.15651903070850776, 0.15717479632183362, 0.07439944696694016],
 [0.13151833653286352, 0.07815638764748746, 0.07549243267164539]]

In [14]:
# Per-pair W2 mean and std over the trials (axis 0 indexes the trial).
w2s_plain_cfm_mean, w2s_plain_cfm_std = np.mean(w2s_plain_cfm, axis=0), np.std(w2s_plain_cfm, axis=0)
print(w2s_plain_cfm_mean, w2s_plain_cfm_std, sep='\n')

[0.17431959 0.12546024 0.08349466]
[0.02546667 0.02709594 0.00887324]


In [29]:
import pandas as pd

# Summary of means and stds for mc_cfm, mc_ot_cfm, fk_cfm and fk_ot_cfm, one row
# per (method, distribution pair) holding plain floats. Storing a whole numpy
# array in one DataFrame cell would make to_csv write its truncated string
# repr ("[0.3368 0.1781 ...]"), which is lossy and hard to parse back.
# All config lists above use the same pair order, so any of them gives the labels.
_pair_labels = ['-'.join(cfg.dist_pair) for cfg in guide_cfgs_mc_cfm]
_w2_stats = [
    ('mc_cfm', w_2_mc_cfm_mean, w_2_mc_cfm_std),
    ('mc_ot_cfm', w_2_mc_ot_cfm_mean, w_2_mc_ot_cfm_std),
    ('fk_cfm', w2s_fk_cfm_mean, w2s_fk_cfm_std),
    ('fk_ot_cfm', w2s_fk_ot_cfm_mean, w2s_fk_ot_cfm_std),
]
df = pd.DataFrame(
    [
        {'method': method, 'dist_pair': pair, 'mean': float(mean), 'std': float(std)}
        for method, means, stds in _w2_stats
        for pair, mean, std in zip(_pair_labels, means, stds)
    ],
    columns=['method', 'dist_pair', 'mean', 'std'],
)

In [30]:
# Write the W2 summary table to means_stds.csv in the current working
# directory, without the DataFrame index column.
df.to_csv(path_or_buf='means_stds.csv', index=False)