In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import jax
import jax.numpy as jnp
import random
import numpy as np
from einops import rearrange
from jax import jit, vmap
from hoam.plot import scatter_movie, imshow_movie
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

In [3]:
# Run JAX in float64 (its default is float32). This is set right after the imports,
# before any JAX arrays are created below.
jax.config.update("jax_enable_x64", True)

In [4]:
from hoam.vlasov import run_vlasov


n_samples = 25_000  # sample count passed to run_vlasov for every simulation
dt = 1e-2           # spacing of the stored time grid t_eval
t_end = 40

# round() rather than int(): t_end / dt can land just below an integer in floating
# point (e.g. 0.3 / 0.1 == 2.9999999999999996), and int() would silently drop the
# final time step.
t_eval = np.linspace(0.0, t_end, round(t_end / dt) + 1)
train_mus = np.asarray([1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9])
test_mus = np.asarray([1.25, 1.85])
mus = np.concatenate([train_mus, test_mus])


def simulate_two_stream(mu_values):
    """Run the two-stream Vlasov simulation once per mu on t_eval and stack the results.

    Returns a numpy array whose leading axis follows the order of ``mu_values``.
    """
    return np.asarray([
        run_vlasov(n_samples, t_eval, mu, mode='two-stream')
        for mu in tqdm(mu_values)
    ])


train_sols = simulate_two_stream(train_mus)
test_sols = simulate_two_stream(test_mus)

  0%|          | 0/8 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

In [5]:
from hoam.utils import normalize_data

# Rescale trajectories ('01' over axes 0-2) and training parameters ('std').
# Each call returns the normalized array plus an `unnorm_*` object, presumably the
# inverse transform -- TODO confirm against hoam.utils.normalize_data.
# NOTE(review): test_sols is normalized with its *own* statistics instead of the
# training statistics, so train and test data may end up on different scales.
# NOTE(review): test_mus is never normalized here; the SDE cells further down
# condition on train_mus[...] while taking the initial condition from test_sols.
train_sols, unnorm_train_sol = normalize_data(train_sols, axis=(0, 1, 2), method='01')
test_sols, unnorm_test_sol = normalize_data(test_sols, axis=(0, 1, 2), method='01')
train_mus, unnorm_train_mu = normalize_data(train_mus, axis=(0,), method='std')

In [21]:
from hoam.utils import get_hist_over_time
# Reference movie: the simulated trajectory for the first training parameter.
H = get_hist_over_time(train_sols[0])
imshow_movie(
    H,
    title='True',
    t=t_eval,
    save_to='./img/vlasov_true.gif',
)

In [7]:
from hoam.colora import build_colora

# Root PRNG key for the notebook. Parameter initialisation gets its own subkey so it
# never draws from the same key that is later passed to adam_opt (batch sampling)
# and solve_test_sde (JAX keys must not be reused for independent draws).
key, init_key = jax.random.split(jax.random.PRNGKey(1))

x_dim = 2      # dimension of the x input of s_fn
mu_t_dim = 2   # conditioning inputs: one mu value and one time value
u_dim = 1      # output dimension of s_fn

u_hat_config = {'width': 64, 'layers': ['C']*7}
h_config = {'width': 15, 'layers': ['D']*3}

s_fn, params_init = build_colora(u_hat_config, h_config, x_dim, mu_t_dim, u_dim, rank=3, key=init_key)

In [8]:
from hoam.utils import get_rand_idx, interplate_in_t
from hoam.quad import get_simpson_quadrature, get_gauss_quadrature
def get_sample_fn(X_data, t_data, mu_data, bs_t=256, bs_n=256, quad='simp', jit_fn=True):
    """Build a sampler that draws random training batches over (mu, t, x).

    Args:
        X_data: trajectories for every parameter value, indexed below as
            ``(M, T, N, D)`` = (parameters, time steps, samples, dims).
        t_data: the ``T`` time stamps. They are divided by ``t_data[-1]`` so the
            final time maps to 1 (assumes ``t_data[0] == 0`` -- TODO confirm).
        mu_data: the ``M`` parameter values; reshaped to ``(M, 1)``.
        bs_t: number of interior time points per batch. With ``quad='simp'`` an
            even value is bumped to the next odd one (Simpson needs odd nodes).
        bs_n: number of samples drawn, per time point, from the ``N`` available.
        quad: ``'mc'`` draws ``bs_t`` distinct stored time steps at random with
            uniform weights ``1 / bs_t``; ``'simp'`` / ``'gauss'`` interpolate the
            data once, up front, onto fixed quadrature nodes in [0, 1].
        jit_fn: if True, the batch computation is compiled with ``jax.jit``.

    Returns:
        ``sample_fn(key) -> (mu_batch, X_batch, t_batch, quad_weights)``. Along the
        time axis, the first and last entries of ``X_batch`` / ``t_batch`` are the
        boundary times (t=0 and t=1) and the ``bs_t`` entries in between are the
        interior points weighted by ``quad_weights``.

    Raises:
        ValueError: if ``quad`` is unknown, or if ``quad='mc'`` and ``bs_t``
            exceeds the number of interior time steps ``T - 2``.
    """
    if quad not in ('mc', 'simp', 'gauss'):
        raise ValueError(f"unknown quad {quad!r}; expected 'mc', 'simp' or 'gauss'")

    t_data = t_data / t_data[-1]  # normalize to [0, 1]

    if quad == 'mc':
        # Monte Carlo samples the stored time steps directly, so no interpolation.
        X_data = jnp.asarray(X_data)
        t_data = jnp.asarray(t_data)
        n_interior = X_data.shape[1] - 2
        if bs_t > n_interior:
            raise ValueError(
                f"bs_t={bs_t} exceeds the {n_interior} interior time steps available"
            )
        quad_weights = jnp.ones(bs_t) / bs_t
    else:
        if quad == 'simp':
            if bs_t % 2 == 0:  # Simpson's rule needs an odd number of nodes
                bs_t += 1
            t_quad, quad_weights = get_simpson_quadrature(bs_t)
        else:
            t_quad, quad_weights = get_gauss_quadrature(bs_t)

        # Add t=0 and t=1 so the loss can evaluate its boundary term.
        t_quad = jnp.concatenate([jnp.asarray([0.0]), t_quad, jnp.asarray([1.0])])
        X_data = jnp.asarray([interplate_in_t(X_mu, t_data, t_quad) for X_mu in X_data])
        t_data = jnp.asarray(t_quad)

    t_data = t_data.reshape(-1, 1)
    mu_data = jnp.asarray(mu_data).reshape(-1, 1)

    def _sample(in_key, X_data, t_data, mu_data):
        # Draw one batch. The data arrays are explicit arguments so that jax.jit
        # traces them as inputs instead of embedding them as compile-time constants.
        M, T, N, _ = X_data.shape

        # One independent subkey per random draw. jax.random.split is deterministic,
        # so splitting the same key twice would hand out identical subkeys.
        mu_key, t_key, x_key = jax.random.split(in_key, 3)

        # sample in mu
        mu_idx = jax.random.randint(mu_key, minval=0, maxval=M, shape=())
        mu_batch = mu_data[mu_idx]
        X_batch = X_data[mu_idx]

        # sample in time
        if quad == 'mc':
            # bs_t distinct interior steps from 1..T-2, framed by the two endpoints,
            # so no interior point duplicates a boundary.
            interior = 1 + jax.random.choice(t_key, T - 2, shape=(bs_t,), replace=False)
            t_idx = jnp.concatenate([jnp.asarray([0]), jnp.sort(interior), jnp.asarray([T - 1])])
            t_batch = t_data[t_idx]
            X_batch = X_batch[t_idx]
        else:
            t_batch = t_data

        # sample in space: bs_n of the N samples, independently at every time point
        n_times = X_batch.shape[0]
        x_keys = jax.random.split(x_key, num=n_times)
        sample_idx = vmap(get_rand_idx, (0, None, None))(x_keys, N, bs_n)
        X_batch = X_batch[jnp.arange(n_times)[:, jnp.newaxis], sample_idx]

        return mu_batch, X_batch, t_batch, quad_weights

    if jit_fn:
        _sample = jax.jit(_sample)

    def sample_fn(in_key):
        return _sample(in_key, X_data, t_data, mu_data)

    return sample_fn



In [9]:
from hoam.utils import meanvmap, tracewrap
from jax import jacrev, jacfwd

# s_fn is called as s_fn(mu, x, t, params). Each *_Ex helper maps over a batch of x
# at fixed (mu, t) via meanvmap (presumably vmap followed by a mean -- TODO confirm
# in hoam.utils); each *_Ex_Vt additionally vmaps over paired (x-batch, t) rows.
s_Ex = meanvmap(s_fn, in_axes=(None, 0, None, None))
s_Ex_Vt = vmap(s_Ex, in_axes=(None, 0, 0, None))

# Squared norm of the derivative of s with respect to x (argument 1).
s_dx = jacrev(s_fn, 1)


def s_dx_norm(*args):
    return jnp.sum(s_dx(*args)**2)


s_dx_norm_Ex = meanvmap(s_dx_norm, in_axes=(None, 0, None, None))
s_dx_norm_Ex_Vt = vmap(s_dx_norm_Ex, in_axes=(None, 0, 0, None))

# Derivative of s with respect to t (argument 2).
s_dt = jacrev(s_fn, 2)
dt_Ex = meanvmap(s_dt, in_axes=(None, 0, None, None))
dt_Ex_Vt = vmap(dt_Ex, in_axes=(None, 0, 0, None))

# Laplacian in x: tracewrap applied to the x-Jacobian of s_dx (presumably its trace).
laplace = tracewrap(jacfwd(s_dx, 1))
laplace_Ex = meanvmap(laplace, in_axes=(None, 0, None, None))
laplace_Ex_Vt = vmap(laplace_Ex, in_axes=(None, 0, 0, None))


epsilon = 5e-2


def loss_fn(params, mu, X_batch, t_batch, quad_weights):
    """Action-matching style loss for one (mu, time-batch) sample.

    The first and last time entries of X_batch / t_batch form the boundary term
    E[s(t_first)] - E[s(t_last)]; the interior entries are combined with
    quad_weights to integrate 0.5*|ds/dx|^2 + ds/dt + 0.5*epsilon**2 * laplacian(s)
    over time (the Laplacian is skipped when the global epsilon is 0).
    """
    x_first, x_last = X_batch[0], X_batch[-1]
    t_first, t_last = t_batch[0], t_batch[-1]
    boundary_term = s_Ex(mu, x_first, t_first, params) - s_Ex(mu, x_last, t_last, params)

    x_inner = X_batch[1:-1]
    t_inner = t_batch[1:-1]

    grad_sq = s_dx_norm_Ex_Vt(mu, x_inner, t_inner, params)
    # named ds_dt so it does not shadow the global time step `dt`
    ds_dt = jnp.squeeze(dt_Ex_Vt(mu, x_inner, t_inner, params))
    lap = laplace_Ex_Vt(mu, x_inner, t_inner, params) if epsilon > 0.0 else 0.0

    integrand = 0.5*grad_sq + ds_dt + epsilon**2*0.5*lap
    return (integrand * quad_weights).sum() + boundary_term

In [10]:
from hoam.adam import adam_opt

# AM baseline: Monte Carlo sampling of the stored time steps.
sample_fn_am = get_sample_fn(train_sols, t_eval, train_mus, quad='mc')
opt_params_am, loss_history_am = adam_opt(
    params_init,
    loss_fn,
    sample_fn_am,
    steps=25_000,
    learning_rate=2e-3,
    verbose=True,
    key=key,
)


  0%|          | 0/25000 [00:00<?, ?it/s]

In [16]:
from hoam.sde import solve_test_sde

# Roll the AM-trained model forward from a test initial condition with solve_test_sde.
test_mu_idx = 0
# NOTE(review): the conditioning value comes from train_mus (i.e. normalized mu=1.2),
# but the initial condition below comes from test_sols (mu=1.25). test_mus was never
# normalized, so switching to it requires applying the training normalization first.
# Confirm which pairing is intended.
test_mu = train_mus[test_mu_idx]
ic = test_sols[test_mu_idx, 0]
# Time grid on [0, 1], matching the t / t[-1] rescaling used by get_sample_fn.
t_int =  jnp.linspace(0.0, 1.0, len(t_eval))
dt_test = 1e-3
# NOTE(review): `key` is also the key passed to adam_opt for training; a fresh
# subkey for the SDE noise would avoid reusing it.
test_sol = solve_test_sde(s_fn, opt_params_am, ic, t_int, dt_test, epsilon, test_mu.reshape(1), key)


In [17]:
# Movie of the AM rollout, for comparison with the 'True' movie.
H = get_hist_over_time(test_sol)
imshow_movie(
    H,
    t=t_eval,
    title='AM',
    save_to='./img/vlasov_am.gif',
)

In [13]:

# HOAM: the same loss, integrated in time with Simpson quadrature nodes.
sample_fn_hoam = get_sample_fn(train_sols, t_eval, train_mus, quad='simp')
opt_params_hoam, loss_history_hoam = adam_opt(
    params_init,
    loss_fn,
    sample_fn_hoam,
    steps=25_000,
    learning_rate=2e-3,
    verbose=True,
    key=key,
)


  0%|          | 0/25000 [00:00<?, ?it/s]

In [18]:

# Roll the HOAM-trained model forward from the same test initial condition.
test_mu_idx = 0
# NOTE(review): the conditioning value comes from train_mus (i.e. normalized mu=1.2),
# but the initial condition below comes from test_sols (mu=1.25). test_mus was never
# normalized, so switching to it requires applying the training normalization first.
# Confirm which pairing is intended.
test_mu = train_mus[test_mu_idx]
ic = test_sols[test_mu_idx, 0]
# Time grid on [0, 1], matching the t / t[-1] rescaling used by get_sample_fn.
t_int =  jnp.linspace(0.0, 1.0, len(t_eval))
dt_test = 1e-3
# NOTE(review): `key` is also the key passed to adam_opt for training; a fresh
# subkey for the SDE noise would avoid reusing it.
test_sol = solve_test_sde(s_fn, opt_params_hoam, ic, t_int, dt_test, epsilon, test_mu.reshape(1), key)


In [20]:
# Movie of the HOAM rollout, for comparison with the 'True' and 'AM' movies.
H = get_hist_over_time(test_sol)
imshow_movie(
    H,
    t=t_eval,
    title='HOAM-S [Ours]',
    save_to='./img/vlasov_hoam.gif',
)