In [28]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [29]:
from jax import vmap
import jax.numpy as jnp
from pathlib import Path
import numpy as np
from einops import rearrange

In [30]:
from colora.data import load_all_hdf5, split_data_by_mu, prepare_coordinate_data
# Locate the Vlasov dataset below the local data directory and load it.
data_dir = Path('./data')
data_path = data_dir.joinpath('vlasov')
mus, sols, spacing = load_all_hdf5(data_path)

# Parameter values used for fitting, and the ones held out for evaluation.
train_mus = np.array([0.2, 0.3, 0.4])
test_mus = np.array([0.25, 0.35])

# Axis layout of each solution array: mus X variables X time X space_x X space_y
train_sols, test_sols = split_data_by_mu(
    mus,
    sols,
    train_mus,
    test_mus,
)


In [31]:
# Downsample the solutions and their coordinate vectors by fixed strides.
# NOTE: this cell rebinds train_sols/test_sols and overwrites spacing[1..3],
# so re-running it without re-running the loading cell strides the data again.
n_space = 4  # keep every 4th sample along each spatial axis
n_time = 6   # keep every 6th sample along the time axis

train_sols = train_sols[:, :, ::n_time, ::n_space, ::n_space]
test_sols = test_sols[:, :, ::n_time, ::n_space, ::n_space]

# Apply the same strides to the coordinate entries (1: time, 2 and 3: space).
for _axis, _stride in ((1, n_time), (2, n_space), (3, n_space)):
    spacing[_axis] = spacing[_axis][::_stride]

# Record the array sizes; the shared names end up holding the test-set values.
n_mu_train, n_q, n_t, n_x1, n_x2 = train_sols.shape
n_mu_test, n_q, n_t, n_x1, n_x2 = test_sols.shape

In [32]:
# Convert the training solutions into the three arrays used for fitting.
y_train, mu_t_train, X_grid = prepare_coordinate_data(
    spacing, train_mus, train_sols
)
# Same conversion for the held-out solutions; X_grid is rebound to the value
# returned by this second call.
y_test, mu_t_test, X_grid = prepare_coordinate_data(
    spacing, test_mus, test_sols
)

In [33]:
from colora.build import build_colora

# Dimensions handed to build_colora.
x_dim = 2
mu_t_dim = 2
u_dim = 1

# Layer-type codes for the two networks (interpreted by build_colora).
u_layers = ['P'] + ['C'] * 5
h_layers = ['D'] * 3

u_hat_config = dict(width=10, layers=u_layers)
h_config = dict(width=5, layers=h_layers)

# Build both network functions together with their initial parameters.
u_hat_fn, h_fn, theta_init, psi_init = build_colora(
    u_hat_config,
    h_config,
    x_dim,
    mu_t_dim,
    u_dim,
    lora_filter='alpha',
    period=[2.0, 2.0],
)

In [34]:

# Map h_fn over the leading axis of its second argument (the mu_t array),
# producing one phi per (mu, t) row.
h_v_mu_t = vmap(h_fn, in_axes=(None, 0))

# Map u_hat_fn over the leading axis of its third argument (the space points),
# producing the solution field for a single phi.
u_hat_v_x = vmap(u_hat_fn, in_axes=(None, None, 0))

# Additionally map over the leading axis of the second argument (the phis),
# producing one solution field per phi on the same set of space points.
u_hat_v_x_phi = vmap(u_hat_v_x, in_axes=(None, 0, None))

In [35]:
def predict(psi_theta, mu_t, X_grid):
    """Evaluate the model for every row of ``mu_t`` on the points in ``X_grid``.

    Args:
        psi_theta: pair ``(psi, theta)``; ``psi`` is passed to ``h_v_mu_t`` and
            ``theta`` to ``u_hat_v_x_phi``.
        mu_t: array that ``h_v_mu_t`` maps over along its leading axis.
        X_grid: space points forwarded unchanged to ``u_hat_v_x_phi``.

    Returns:
        The output of ``u_hat_v_x_phi`` for the phis computed from ``mu_t``.
    """
    psi, theta = psi_theta
    return u_hat_v_x_phi(theta, h_v_mu_t(psi, mu_t), X_grid)

def relative_loss_fn(psi_theta, mu_t, X_gird, sols):
    """Mean relative 2-norm error between the prediction and ``sols``.

    Args:
        psi_theta: pair ``(psi, theta)`` forwarded to ``predict``.
        mu_t: array forwarded to ``predict``.
        X_gird: space points forwarded to ``predict`` (the spelling is kept
            because it is part of the existing signature).
        sols: reference values; the prediction is reshaped to ``sols.shape``.

    Returns:
        Mean over all remaining entries of
        ``norm(sols - pred, axis=1) / norm(sols, axis=1)``. There is no guard
        against a zero reference norm.
    """
    pred = predict(psi_theta, mu_t, X_gird).reshape(*sols.shape)
    err_norm = jnp.linalg.norm(sols - pred, axis=1)
    ref_norm = jnp.linalg.norm(sols, axis=1)
    return (err_norm / ref_norm).mean()



In [36]:
from colora.adam import adam_opt


# Initial parameters, in the (psi, theta) order that predict unpacks.
psi_theta = (psi_init, theta_init)
# Extra arguments passed to the loss: (mu_t, X_grid, sols).
args = (mu_t_train, X_grid, y_train)

# Fit both parameter sets with Adam on the relative loss.
opt_psi_theta, loss_history = adam_opt(
    psi_theta,
    relative_loss_fn,
    args=args,
    steps=5_000,
    learning_rate=5e-3,
    verbose=True,
)


  0%|          | 0/5000 [00:00<?, ?it/s]

In [37]:

# Evaluate the trained model on the held-out inputs, then unflatten the result
# into the solution layout: M Q T N1 N2 (M is inferred by rearrange).
pred = rearrange(
    predict(opt_psi_theta, mu_t_test, X_grid),
    '(M T) (N1 N2) Q -> M Q T N1 N2',
    Q=n_q,
    T=n_t,
    N1=n_x1,
    N2=n_x2,
)

In [38]:
from colora.plot import imshow_movie

# Show the first variable of the last held-out parameter value.
# The original index was 2, but test_mus has only two values, so (assuming pred
# has one leading entry per test_mus value) index 2 is out of range. JAX clamps
# out-of-bounds reads instead of raising, which silently selected the last
# entry; -1 selects that same entry explicitly.
imshow_movie(pred[-1, 0])