In [1]:
%load_ext autoreload
%autoreload 2

In [None]:
import jax
import jax.numpy as jnp
import random
import numpy as np
from einops import rearrange
from jax import jit, vmap
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
from hflow.misc.plot import scatter_movie
from jax import jacrev


In [3]:
# jax.config.update("jax_enable_x64", True)

In [162]:
# from hflow.data.particles import get_ic_vfp, get_2d_vfp
from hflow.data.sde import solve_sde
from jax import grad

def get_ic_vfp(key):
    """Draw one initial state ``(x1, x2)`` for the 2-D VFP system.

    The sample is an isotropic Gaussian with per-component standard deviation
    5e-2, shifted by ``-(0.4, 0)``.

    Args:
        key: JAX PRNG key.

    Returns:
        Array of shape (2,).
    """
    # 5e-2 multiplies a unit normal, so it is a standard deviation, not a variance.
    std = 5e-2
    noise = jax.random.normal(key, (2,))
    return std * noise - jnp.asarray([0.4, 0])


def get_2d_vfp():
    """Build the ``(drift, diffusion)`` pair for the 2-D VFP SDE.

    With state ``y = (x1, x2)`` the drift is ``(x2, -phi'(x1))``, where
    ``phi(x) = -(0.2 + 0.2 cos(pi x^4) + 0.1 sin(pi x))``. The diffusion is a
    constant 1e-2 on both components. Both callables accept ``(t, y, *args)``
    and ignore ``t`` and ``args``.

    Returns:
        Tuple ``(drift, diffusion)``.
    """

    def potential(x):
        return -(0.2 + 0.2*jnp.cos(jnp.pi*x**4)+0.1*jnp.sin(jnp.pi*x))

    # Scalar derivative phi'(x), evaluated on the first state component.
    dpotential = grad(potential)

    def drift(t, y, *args):
        pos, vel = y
        return jnp.asarray([vel, -dpotential(pos)])

    def diffusion(t, y, *args):
        return 1e-2 * jnp.asarray([1, 1])

    return drift, diffusion



# Ground-truth particle ensemble for the 2-D VFP system.
drift, diffusion = get_2d_vfp()
n_samples = 2500
dt = 1e-3        # spacing of the saved time grid t_eval
dt_solve = 1e-2  # passed to solve_sde as `dt` (presumably its integration step); coarser than the save grid
t_end = 7.0
# round() rather than int(): a float quotient can land just below an integer
# (e.g. 0.3 / 0.1 == 2.9999999999999996), and int() would drop a grid point.
t_eval = np.linspace(0.0, t_end, round(t_end / dt) + 1)
key = jax.random.key(1)
sol = solve_sde(
    drift, diffusion, t_eval, get_ic_vfp, n_samples, dt=dt_solve, key=key
)
sol = rearrange(sol, "N T D -> T N D")
# Layout after the rearrange: (time, particle, state-dim).
assert sol.shape == (len(t_eval), n_samples, 2), sol.shape



In [163]:

# Animate the simulated ground-truth particle cloud; `sol` is laid out as (T, N, D) after the rearrange in the data cell.
scatter_movie(sol)

  sct = ax.scatter(x=pts[0, 0], y=pts[0, 1],


In [164]:
# Train on the full simulated ensemble; `train_sols` is an alias of `sol` (layout (T, N, D)), not a copy.
train_sols = sol
train_sols.shape

(7001, 2500, 2)

In [165]:
import jax.numpy as jnp
import numpy as np

from hflow.config import Config, Network
from hflow.net.build import build_colora, build_mlp


# Network s(mu_t, x): an MLP applied to the concatenation [mu_t, x].
key = jax.random.PRNGKey(1)
x_dim = 2

# ['D'] * 5 -> five layers of kind 'D' (layer semantics defined by build_mlp), width 64.
u_config = dict(width=64, layers=['D'] * 5)

# Input = 2 conditioning entries (mu and t, as the learned drift passes them) + x.
u_fn, params_init = build_mlp(u_config, in_dim=2 + x_dim, out_dim=1, key=key)


def s_fn(t, x, params):
    """Evaluate the network on the concatenation ``[t, x]`` and squeeze the result.

    Despite its name, ``t`` is the whole conditioning vector that is prepended
    to ``x``. The learned drift in this notebook passes ``concat(mu, t)``,
    which has length 2 and matches ``in_dim = x_dim + 2``. With ``out_dim=1``
    the squeezed output is presumably a scalar.
    """
    net_in = jnp.concatenate((t, x))
    net_out = u_fn(params, net_in)
    return jnp.squeeze(net_out)


In [166]:
# key = jax.random.PRNGKey(2)
# x_dim = 2

# v_config = {
#     "width": 64,
#     "layers": ['D']*6,
# }

# dnn_fn, params_init = build_mlp(
#     v_config, in_dim=x_dim, out_dim=1, key=key
# )

# def v_fn(x, params):
#     return jnp.squeeze(dnn_fn(params, x))


In [170]:
from hflow.misc.jax import batchmap, get_rand_idx, meanvmap, tracewrap
from hflow.train.quad import get_simpsons
from hflow.train.sample import interplate_in_t

def get_sample_fn(X_data, t_data, bs_t=315, bs_n=256):
    """Build a minibatch sampler on a fixed Simpson's-rule time grid.

    The trajectories are interpolated once onto the Simpson's nodes plus the
    endpoints 0 and 1 (used for the loss's boundary term). Each call of the
    returned ``sample_fn`` keeps every time node. It draws a particle subset
    of size ``bs_n`` independently for each node.

    Args:
        X_data: trajectories; after interpolation they are indexed as
            (time, particle, dim).
        t_data: times corresponding to the first axis of ``X_data``.
        bs_t: number of Simpson's nodes. It is incremented by one if even,
            because composite Simpson's rule needs an odd node count.
        bs_n: number of particles drawn per time node on each call.

    Returns:
        ``sample_fn(in_key) -> (X_batch, mu, t_batch, quad_weights)``, where:
        - ``t_batch`` is the node column of shape (n_nodes, 1), including both endpoints.
        - ``mu`` is ``[0.0]``.
        - ``quad_weights`` come straight from ``get_simpsons``.
        - ``X_batch`` is indexed as (node, sampled particle, dim), assuming
          ``get_rand_idx`` returns ``bs_n`` indices.
    """
    # Composite Simpson's rule needs an odd number of nodes.
    if bs_t % 2 == 0:
        bs_t += 1
    t_quad, quad_weights = get_simpsons(bs_t)

    # NOTE(review): the endpoints added here are 0 and 1, so get_simpsons is
    # presumably returning nodes in [0, 1]. In this notebook t_data is t_eval,
    # which spans [0, t_end]. Confirm that interplate_in_t rescales t_data onto
    # [0, 1]; otherwise only the start of each trajectory is sampled.
    t_quad = jnp.concatenate([jnp.asarray([0]), t_quad, jnp.asarray([1.0])])

    X_quad = interplate_in_t(X_data, t_data, t_quad)
    t_batch = jnp.asarray(t_quad).reshape(-1, 1)

    # Everything below is fixed across calls, so compute it once here.
    n_times, n_particles, _ = X_quad.shape
    # Column of row indices; it broadcasts against the (n_times, bs_n) index
    # matrix so that each time node gathers its own particle subset.
    rows = jnp.arange(n_times)[:, jnp.newaxis]
    mu = jnp.asarray([0.0])
    batched_rand_idx = vmap(get_rand_idx, (0, None, None))

    def sample_fn(in_key):
        # One independent key per time node.
        keys = jax.random.split(in_key, num=n_times)
        sample_idx = batched_rand_idx(keys, n_particles, bs_n)
        return X_quad[rows, sample_idx], mu, t_batch, quad_weights

    return sample_fn



In [171]:
# from jax import jacrev, jacfwd

# s_Ex = meanvmap(s_fn, in_axes=(None, 0, None, None))
# s_Ex_Vt = vmap(s_Ex, in_axes=(None, 0, 0, None))

# s_dx = jacrev(s_fn, 1)
# s_dx_norm = lambda *args: jnp.sum(s_dx(*args)**2)
# s_dx_norm_Ex = meanvmap(s_dx_norm, in_axes=(None, 0, None, None))
# s_dx_norm_Ex_Vt = vmap(s_dx_norm_Ex, in_axes=(None, 0, 0, None))

# s_dt = jacrev(s_fn, 2)
# dt_Ex = meanvmap(s_dt, in_axes=(None, 0, None, None))
# dt_Ex_Vt = vmap(dt_Ex, in_axes=(None, 0, 0, None))

# laplace = tracewrap(jacfwd(s_dx, 1))
# laplace_Ex = meanvmap(laplace, in_axes=(None, 0, None, None))
# laplace_Ex_Vt = vmap(laplace_Ex, in_axes=(None, 0, 0, None))

# epsilon = 0 # 1e-2

# def loss_fn(params, X_batch, t_batch, quad_weights, key):

#     mu = jnp.asarray([0.0])
#     boundary_term = s_Ex(mu, X_batch[0], t_batch[0], params) - s_Ex(mu,  X_batch[-1], t_batch[-1], params)

#     X_batch = X_batch[1:-1]
#     t_batch = t_batch[1:-1]
    
#     grad = s_dx_norm_Ex_Vt(mu, X_batch, t_batch, params)
#     dt = dt_Ex_Vt(mu, X_batch, t_batch, params)
#     dt = jnp.squeeze(dt)
    
#     if epsilon > 0.0:
#         lap = laplace_Ex_Vt(mu, X_batch, t_batch, params)
#     else:
#         lap = 0.0
        
#     interior = 0.5*grad + dt + epsilon**2*0.5*lap

#     interior_loss = (interior * quad_weights).sum()
    
#     loss = interior_loss + boundary_term

#     return loss

In [192]:
from hflow.train.loss import DICE_Loss, OV_Loss
sigma = 1e-2  # noise level handed to the training loss

# Select the training objective by name rather than constructing one loss and
# immediately overwriting it with another.
LOSS_NAME = "ov"  # "ov" -> OV_Loss, "dice" -> DICE_Loss
LOSS_CLASSES = {"dice": DICE_Loss, "ov": OV_Loss}
loss_fn = LOSS_CLASSES[LOSS_NAME](s_fn, sigma=sigma)
        

In [193]:
from hflow.train.adam import adam_opt

# `key` was already consumed by build_mlp to initialise the weights; derive an
# independent key for training so that minibatch sampling is not correlated
# with the initialisation (JAX keys must not be reused).
train_key = jax.random.fold_in(key, 1)

sample_fn_am = get_sample_fn(train_sols, t_eval)
params, opt_params, loss_history = adam_opt(
    params_init,
    loss_fn,
    sample_fn_am,
    steps=15000,
    learning_rate=1e-4,
    verbose=True,
    key=train_key,
)


adam:   0%|          | 0/15000 [00:00<?, ?it/s]

2025-03-19 15:04:52.458244: W external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1058] Compiling 21 configs for gemm_fusion_dot.29 on a single thread.


In [194]:
from hflow.test.test import solve_test_sde

# Settings for rolling out the learned SDE.
ics = train_sols[0]                              # particle cloud at the first saved time
t_int = jnp.linspace(0.0, 1.0, num=len(t_eval))  # as many points as t_eval, spanning [0, 1]
dt_test = 1e-3                                   # step size for the learned-SDE rollout
mu = jnp.asarray([0.0])                          # conditioning value prepended to t


In [195]:
from hflow.data.sde import solve_sde_ic

# Gradient of s_fn with respect to x; used below as the learned drift.
s_dx = jacrev(s_fn, 1)


def drift(t, y, *args):
    """Learned drift: the x-gradient of ``s_fn`` at condition ``concat(mu, t)``.

    NOTE: this rebinds the notebook-level ``drift`` that the data cell obtained
    from get_2d_vfp(). Re-run that cell before regenerating ground-truth data.
    """
    t_vec = jnp.asarray([t]).reshape(1)
    mu_t = jnp.concatenate([mu, t_vec])
    return jnp.squeeze(s_dx(mu_t, y, params))


def diffusion(t, y, *args):
    """Constant noise ``sigma`` (the value given to the training loss) on every component."""
    return sigma * jnp.ones_like(y)


# NOTE(review): the data were simulated on [0, t_end] with noise 1e-2, while this
# rollout runs on t_int in [0, 1]. Under the time change s = t / t_end, the
# equivalent noise is 1e-2 * sqrt(t_end). Confirm that sigma is intended as-is.

# Rollout noise comes from a key distinct from the one used for weight init.
test_key = jax.random.fold_in(key, 2)
keys = jax.random.split(test_key, num=len(ics))
# Use dt_test (this rollout's step), not the data-generation `dt`.
test_sol = vmap(solve_sde_ic, (0, 0, None, None, None, None))(
    ics, keys, t_int, dt_test, drift, diffusion
)
test_sol = rearrange(test_sol, "N T D -> T N D")

In [196]:
# The learned-model rollout must have the same (T, N, D) layout as the ground truth.
assert test_sol.shape == sol.shape, (test_sol.shape, sol.shape)
test_sol.shape, sol.shape

((7001, 2500, 2), (7001, 2500, 2))

In [197]:

# Animate the ground-truth (`sol`) and learned-model (`test_sol`) trajectories together for visual comparison; both are (T, N, D).
scatter_movie([sol, test_sol])

  sct = ax.scatter(x=pts[0, 0], y=pts[0, 1],
