In [1]:
# Reload edited project modules automatically before each cell execution,
# so changes to imported modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [49]:
import jax
import jax.numpy as jnp
import random
import numpy as np
from einops import rearrange
from jax import jit, vmap
from cone.utils.plot import scatter_movie, imshow_movie
import matplotlib.pyplot as plt
from tqdm.autonotebook import tqdm
from cone.utils.misc import pshape, get_lims


In [4]:
# NOTE: disabled earlier setup (4-D state, omega = 6); kept commented out for reference only.
# omega = 6
    
# def get_inital_condition(key):
#     mu_0 = jnp.asarray([1, 1, 0, 0])
#     ic = jax.random.normal(key, (4,))
#     ic = (ic*0.1) - mu_0
#     return ic

# def drift(t, y, *args):

#     x1, x2, x3, x4 = y
#     x1_dot = x3
#     x2_dot = x4
#     x3_dot = -omega**2*x1
#     x4_dot = -omega**2*x2
#     return jnp.asarray([x1_dot, x2_dot, x3_dot, x4_dot])

# def diffusion(t, y, *args):
#     return jnp.asarray([0, 0, 1, 1])*5e-2

# seed = 1
# key = jax.random.key(seed)
# dt = 1e-3
# t_end = 1
# t_eval = np.linspace(0.0, t_end, int(t_end/dt)+1)
   
# n_samples = 10_000

In [10]:
def get_inital_condition(key):
    """Draw one random 2-D initial state.

    A standard-normal sample is scaled by 0.5 and then shifted by
    ``-[0, 4]``, so the result is centred on ``[0, -4]``.

    Args:
        key: JAX PRNG key consumed by the draw.

    Returns:
        Array of shape ``(2,)`` with the sampled state.
    """
    shift = jnp.asarray([0, 4])
    scale = 0.5
    z = jax.random.normal(key, (2,))
    return scale * z - shift

# Coefficient of the cubic term in `drift`.
mu = 0.1


def drift(t, y, *args):
    """Deterministic part of the SDE for the 2-D state ``y = (x1, x2)``.

    Computes
        x1' = x2
        x2' = -2*xi*w*x2 + w**2*x1 - w**2*mu*x1**3
    with ``xi = 0.20`` and ``w = 1.0``; the cubic coefficient is the
    module-level ``mu``.

    Args:
        t: Time; not used by this drift.
        y: State with two components.
        *args: Ignored.

    Returns:
        Array ``[x1', x2']``.
    """
    xi, w = 0.20, 1.0
    cubic_coeff = mu
    pos, vel = y

    damping = -2*xi*w*vel
    linear = w**2*pos
    cubic = w**2*cubic_coeff*pos**3

    return jnp.asarray([vel, damping + linear - cubic])

def diffusion(t, y, *args):
    """Constant diffusion coefficients: 0 for the first state component, 1 for the second.

    Args:
        t: Time; not used.
        y: State; not used.
        *args: Ignored.

    Returns:
        Array ``[0, 1]``.
    """
    coefficients = [0, 1]
    return jnp.asarray(coefficients)



# --- Simulation configuration -------------------------------------------
seed = 1
key = jax.random.key(seed)  # root PRNG key built from the seed

dt = 1e-3   # spacing of the evaluation grid
t_end = 4   # final time of the evaluation grid

# round() instead of bare int(): a quotient such as 0.3/0.1 evaluates to
# 2.999... in floating point and plain truncation would drop a grid point.
n_steps = int(round(t_end / dt))
t_eval = np.linspace(0.0, t_end, n_steps + 1)

n_samples = 10_000

In [56]:
from cone.integrate.sde import solve_sde

# Simulate the SDE ensemble on the t_eval grid. The configured `dt` is passed
# through (instead of repeating the literal 1e-3) so the solver step and the
# evaluation grid cannot silently diverge when `dt` is edited.
sol = solve_sde(drift, diffusion, t_eval, get_inital_condition, n_samples, dt=dt, key=key)
# Keep the first two state components. For the 2-D system defined in this
# notebook that is the full state, so this slice changes nothing here.
sol = sol[..., :2]
# Reorder axes from (sample, time, dim) to (time, sample, dim).
sol = rearrange(sol, 'N T D -> T N D')
xlim, ylim = get_lims(sol)

In [50]:
# Animate the simulated ensemble; no explicit axis limits are passed
# (a manual xlim/ylim override is left commented out at the end of the line).
scatter_movie(sol, alpha=0.2, t=t_eval)#, ylim=[-1.5,1.5], xlim=[-1.5,1.5])

In [14]:
from cone.utils.misc import get_rand_idx, interplate_in_t
from cone.integrate.quad import get_simpson_quadrature, get_gauss_quadrature

def get_sample_fn(X_data, t_data, mu_data, bs_tau = 16, bs_t=256, bs_n=256, quad='simp', jit_fn=True):
    """Build a function that draws one training batch per call.

    Args:
        X_data: Trajectory data indexed as (time, sample, dim).
        t_data: Time stamps matching the first axis of ``X_data``; rescaled
            here by its last entry.
        mu_data: Accepted for interface compatibility; not read.
        bs_tau: Number of ``tau`` values drawn per batch.
        bs_t: Number of time points per batch. For ``quad='simp'`` it is
            bumped by one when needed so that it is odd.
        bs_n: Batch-size argument forwarded to ``get_rand_idx`` for every
            time slice (presumably the number of samples drawn — confirm).
        quad: ``'simp'`` or ``'gauss'`` use fixed quadrature nodes and
            interpolate ``X_data`` onto them once; ``'mc'`` keeps the raw
            data and draws random time indices on every call.
        jit_fn: Wrap the returned function in ``jax.jit``.

    Returns:
        ``sample_fn(key) -> (tau_batch, X_batch, t_batch)`` where
        ``tau_batch`` is uniform in [0, 1) with shape ``(bs_tau, 1)``,
        ``t_batch`` has shape ``(n_t, 1)`` and ``X_batch`` holds the sampled
        states for each of the ``n_t`` times.

    Raises:
        ValueError: If ``quad`` is not one of the supported modes.
    """
    if quad not in ('simp', 'gauss', 'mc'):
        # Previously an unknown mode fell through to an unbound-variable error.
        raise ValueError(f"unknown quad {quad!r}; expected 'simp', 'gauss' or 'mc'")

    t_data = t_data / t_data[-1] # normalize in [0,1]

    if quad == 'mc':
        # Monte Carlo over time: no interpolation, indices are drawn per call.
        X_all = jnp.asarray(X_data)
        t_all = jnp.asarray(t_data)
    else:
        # The quadrature weights are returned by the helpers but are not
        # consumed anywhere in this function, so they are discarded.
        if quad == 'simp':
            # odd number of points necessary for simpsons
            if (bs_t - 1) % 2 != 0:
                bs_t += 1
            t_nodes, _ = get_simpson_quadrature(bs_t)
        else:
            t_nodes, _ = get_gauss_quadrature(bs_t)

        # add start and end points for boundary term
        start, end = jnp.asarray([0]), jnp.asarray([1.0])
        t_nodes = jnp.concatenate([start, t_nodes, end])
        X_all = jnp.asarray(interplate_in_t(X_data, t_data, t_nodes))
        t_all = jnp.asarray(t_nodes)

    t_all = t_all.reshape(-1, 1)

    def sample_fn(in_key):
        # Work on local names only. Rebinding the enclosing variables (the
        # former `nonlocal` approach) made every un-jitted call subsample the
        # previous call's subsample, and leaked tracers into the closure
        # when the function was jitted.
        X_batch, t_batch = X_all, t_all

        T, N, D = X_batch.shape

        in_key, tau_key = jax.random.split(in_key)
        tau_batch = jax.random.uniform(tau_key, (bs_tau, 1))

        if quad == 'mc':
            in_key, key_t = jax.random.split(in_key)
            # bs_t distinct indices from range(T-1), plus the first and last
            # index, sorted so that time is increasing.
            t_idx = jax.random.choice(key_t, T-1, shape=(bs_t,), replace=False)
            start, end = jnp.asarray([0]), jnp.asarray([T-1])
            t_idx = jnp.sort(jnp.concatenate([start, t_idx, end]))
            t_batch = t_batch[t_idx]
            X_batch = X_batch[t_idx]

        # One key per time slice, so each slice gets its own sample indices.
        n_t = X_batch.shape[0]
        keys = jax.random.split(in_key, num=n_t)
        sample_idx = vmap(get_rand_idx, (0, None, None))(keys, N, bs_n)
        rows = jnp.arange(n_t)[:, jnp.newaxis]
        X_batch = X_batch[rows, sample_idx]

        return tau_batch, X_batch, t_batch

    if jit_fn:
        sample_fn = jit(sample_fn)

    return sample_fn

# Placeholder value: get_sample_fn accepts a mu_data argument but does not read it.
mu_data = jnp.asarray([1.0]) # dummy mu for this example

In [30]:
from cone.net.mlp import MLP
from cone.net.adam import adam_opt

# Layer widths handed to the MLP: D entries of 64 followed by a dim-wide output.
dim = 2
D = 6
features = [64] * D + [dim]

s_net = MLP(features=features, squeeze=True)

# Dummy input used only to initialise the parameters; its length is dim + 2,
# matching the concatenation [tau, x, t] that s_fn feeds to the network.
pt = jnp.ones(dim + 2)
params_init = s_net.init(key, pt)

In [39]:
from cone.utils.misc import meanvmap, tracewrap, sqwrap
from jax import jacrev, jacfwd, grad


def s_fn(tau, x, t, params):
    """Evaluate the network on the concatenated input ``[tau, x, t]``.

    Args:
        tau, x, t: 1-D arrays joined along their only axis.
        params: Parameters passed to ``s_net.apply``.

    Returns:
        Whatever ``s_net.apply`` returns for that input.
    """
    net_input = jnp.concatenate((tau, x, t))
    return s_net.apply(params, net_input)
@sqwrap
def alpha(tau):
    """Weight multiplying ``xt`` in ``interpolant``.

    Equals ``cos(pi*tau)**2`` while ``tau <= 0.5`` and zero afterwards.
    """
    def active(s):
        return jnp.cos(jnp.pi*s)**2

    def inactive(s):
        return s*0.0

    return jax.lax.cond(tau <= 0.5, active, inactive, tau)
    
@sqwrap
def beta(tau):
    """Weight multiplying ``xt_m1`` in ``interpolant``.

    Equals ``cos(pi*tau)**2`` while ``tau > 0.5`` and zero otherwise.
    """
    def active(s):
        return jnp.cos(jnp.pi*s)**2

    def inactive(s):
        return s*0.0

    return jax.lax.cond(tau > 0.5, active, inactive, tau)

@sqwrap
def gamma(tau):
    """Scale of the noise added to the interpolant: ``0.1 * sin(pi*tau)**2``."""
    amplitude = 0.1
    return jnp.sin(jnp.pi*tau)**2 * amplitude
    # alternative schedule kept for reference:
    # return jnp.sqrt(2*tau*(1-tau))

@sqwrap
def interpolant(tau, xt, xt_m1):
    """Blend of the two states: ``alpha(tau) * xt + beta(tau) * xt_m1``.

    With the ``alpha``/``beta`` defined in this notebook the result is ``xt``
    at ``tau = 0`` and ``xt_m1`` at ``tau = 1``.

    NOTE(review): that orientation runs from ``xt`` to ``xt_m1`` as tau goes
    0 -> 1 — confirm it matches the direction in which the learned field is
    integrated at test time.
    """
    current_part = alpha(tau) * xt
    previous_part = beta(tau) * xt_m1
    return current_part + previous_part

# Derivative of the interpolant with respect to its first argument, tau
# (jacrev differentiates argument 0 by default).
interpolant_dt = jacrev(interpolant)
# Derivative of the noise scale gamma with respect to tau.
gamma_dt = grad(gamma)

def flow_match(xt, xt_m1, t, tau, params, key):
    """Per-sample regression loss for the network ``s_fn``.

    Builds the noisy interpolant point
    ``x_tau = interpolant(tau, xt, xt_m1) + gamma(tau) * noise_i``, evaluates
    the network there and returns
    ``|s|^2 - 2 * s . (d/dtau interpolant + d/dtau gamma * noise_l)``.

    NOTE(review): the target uses ``noise_l``, a draw independent of the
    ``noise_i`` that perturbs ``x_tau`` — confirm this is intended rather
    than reusing ``noise_i``.

    Args:
        xt, xt_m1: 1-D state vectors of equal length.
        t: Time fed to the network (squeezed to a scalar here).
        tau: Interpolation variable (squeezed to a scalar here).
        params: Network parameters.
        key: PRNG key for the two noise vectors.

    Returns:
        Scalar loss value.
    """
    tau, t = jnp.squeeze(tau), jnp.squeeze(t)

    # Two standard-normal vectors from a single key: one perturbs the
    # interpolant, the other enters the regression target.
    n_dim = xt.shape[0]
    noise_i, noise_l = jax.random.normal(key, shape=(2, n_dim))

    x_tau = interpolant(tau, xt, xt_m1) + gamma(tau)*noise_i
    s = s_fn(tau.reshape(1), x_tau, t.reshape(1), params)

    dt_i = jnp.squeeze(interpolant_dt(tau, xt, xt_m1))
    g_dt = gamma_dt(tau)

    # Shape report; the argument names are kept as-is because the helper's
    # output shows them.
    pshape(s, dt_i, g_dt, noise_i)

    target = dt_i + g_dt*noise_l
    loss = jnp.dot(s, s) - 2*jnp.dot(s, target)

    return jnp.squeeze(loss)
    
# Nested batching of flow_match via the project helper meanvmap
# (presumably vmap followed by a mean over the mapped axis — confirm).
# params and key are never mapped (in_axes None), so every evaluation inside
# one call shares the same key.
# Innermost: map over axis 0 of xt and xt_m1.
fm_Vx = meanvmap(flow_match, in_axes=(0, 0, None, None, None, None))
# Middle: map over axis 0 of tau.
fm_Vx_Vtau = meanvmap(fm_Vx, in_axes=(None, None, None, 0, None, None))
# Outermost: map over axis 0 of xt, xt_m1 and t.
fm_Vx_Vtau_Vt = meanvmap(fm_Vx_Vtau, in_axes=(0, 0, 0, None, None, None))

def loss_fn(params, tau_batch, X_batch, t_batch, key):
    """Batched loss over consecutive time slices.

    The first and last time slices of the batch are dropped; every remaining
    slice is then paired with its predecessor and passed, together with the
    later slice's time, to ``fm_Vx_Vtau_Vt``.

    Args:
        params: Network parameters.
        tau_batch: Batch of tau values.
        X_batch: States indexed by time along axis 0.
        t_batch: Times aligned with axis 0 of ``X_batch``.
        key: PRNG key forwarded to the loss.

    Returns:
        The value returned by ``fm_Vx_Vtau_Vt``.
    """
    # Strip the two boundary slices.
    X_inner, t_inner = X_batch[1:-1], t_batch[1:-1]

    # Consecutive pairs: slice i together with slice i-1.
    X_curr, X_prev = X_inner[1:], X_inner[:-1]
    t_curr = t_inner[1:]

    return fm_Vx_Vtau_Vt(X_curr, X_prev, t_curr, tau_batch, params, key)

In [53]:
from cone.net.adam import adam_opt
# Batch sampler with random time indices ('mc') and 64 tau draws per batch.
sample_fn = get_sample_fn(sol, t_eval, mu_data, bs_tau=64, quad='mc')

# Optimise the network parameters.
opt_params, loss_history = adam_opt(
    params_init,
    loss_fn,
    sample_fn,
    steps=5000,
    learning_rate=5e-3,
    verbose=True,
    key=key,
    loss_key=True,
)


  0%|          | 0/5000 [00:00<?, ?it/s]

s: (2,) | dt_i: (2,) | g_dt: () | noise_i: (2,) | 


In [57]:
from cone.integrate.sde import odeint_rk4

def solve_test_cfm(s_fn, params, ics, t_int, T):
    """Roll an ensemble forward by chaining one tau-integration per physical time.

    For every entry of ``t_int`` the field ``s_fn(tau, y, physical_t, params)``
    is integrated with ``odeint_rk4`` over ``T`` tau values in [0, 1]; the last
    state of each integration becomes the initial condition of the next one.

    Args:
        s_fn: Callable ``(tau, x, t, params)`` for a single state; it is
            vmapped over axis 0 of the state here.
        params: Parameters forwarded to ``s_fn``.
        ics: Initial ensemble, batched along axis 0.
        t_int: 1-D array of physical times to step through.
        T: Number of tau grid points per physical time.

    Returns:
        All integration outputs concatenated along axis 0 in the order of
        ``t_int`` and then squeezed (assumes ``odeint_rk4`` returns an array
        whose leading axis indexes the tau grid — confirm).
    """
    s_Vx = vmap(s_fn, (None, 0, None, None))

    # Defined before `integrate` so the values its closure captures are
    # visible above their use.
    taus = jnp.linspace(0, 1, T).reshape(-1, 1)
    t_int = t_int.reshape(-1, 1)

    @jit
    def integrate(ics, physical_t, params):

        def fn(tau, y):
            return s_Vx(tau, y, physical_t, params)

        return odeint_rk4(fn, ics, taus)

    # Collect whole chunks and concatenate once. Re-spreading the running
    # list on every iteration was quadratic and unpacked each device array
    # row by row.
    chunks = []
    for physical_t in tqdm(t_int, desc='CFM'):
        sol = jnp.asarray(integrate(ics, physical_t, params))
        ics = sol[-1]
        chunks.append(sol)

    if not chunks:
        # Empty t_int: same empty result the list-based version produced.
        return jnp.squeeze(jnp.asarray(chunks))

    return jnp.squeeze(jnp.concatenate(chunks, axis=0))

In [58]:
# Rollout settings.
T_tau = 32      # tau grid points per physical time
n_plot = 1000   # number of ensemble members to roll out
t_int = jnp.linspace(0.0, 1.0, 256)

# Evenly spaced subset of the ensemble at the first stored time.
ic = sol[0]
idx_sample = np.linspace(0, ic.shape[0] - 1, num=n_plot, dtype=np.uint32)

sol_cfm = solve_test_cfm(s_fn, opt_params, ic[idx_sample], t_int, T_tau)

CFM:   0%|          | 0/256 [00:00<?, ?it/s]

In [60]:

# Animate every T_tau-th frame of the rollout (presumably the first tau step
# of each physical time — confirm against odeint_rk4's output layout),
# using the axis limits computed from the reference simulation.
scatter_movie(sol_cfm[::T_tau], c='b', alpha=0.1, xlim=xlim, ylim=ylim)

In [52]:
# Animate one block of T_tau consecutive frames of the rollout; ii selects the block.
ii= 0
scatter_movie(sol_cfm[T_tau*ii:T_tau*(ii+1)], c='b', alpha=0.1, xlim=xlim, ylim=ylim)