In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from jax import vmap
import jax.numpy as jnp
from pathlib import Path
import numpy as np
from einops import rearrange
import jax
import matplotlib.pyplot as plt
from jax.experimental.ode import odeint
from colora.data import load_all_hdf5, split_data_by_mu, prepare_coordinate_data
# Location of the Burgers dataset, relative to the notebook's working directory.
data_dir = Path('./data')
data_path = data_dir / 'burgers'
# load_all_hdf5 is unpacked into parameter values, solutions and grid spacing information.
mus, sols, spacing = load_all_hdf5(data_path)

# Parameter values used for training; the two test values below are not in this list.
train_mus = np.asarray([0.001, 0.00199, 0.00298, 0.00496, 0.00595, 0.00694, 0.00892, 0.01])
test_mus = np.asarray([0.00397, 0.00793])
train_sols, test_sols = split_data_by_mu(mus, sols, train_mus, test_mus) # mus X variables X time X space_x X space_y
# Both solution arrays are rank 5. The second unpack rebinds n_q, n_t, n_x1, n_x2
# from the test array; presumably they match the train values -- TODO confirm.
n_mu_train, n_q, n_t, n_x1, n_x2 = train_sols.shape
n_mu_test, n_q, n_t, n_x1, n_x2 = test_sols.shape
# Entries 1, 2 and 3 of `spacing` are used as the time, x and y axes respectively.
time = spacing[1]
x_space = spacing[2]
y_space = spacing[3]

# Quick summary of the dataset dimensions and the raw coordinate axes.
print(f'n_mu train: {n_mu_train}, n_mu test: {n_mu_test}, n_variables: {n_q}, n_time samples: {n_t}, n_x samples: {n_x1}, n_x2 samples: {n_x2}')
print(time)
print(x_space)
print(y_space)

# NOTE(review): X_grid is assigned by both calls, so the value from the test call
# overwrites the one from the train call. Presumably they are identical because
# both calls receive the same `spacing` -- confirm.
y_train, mu_t_train, X_grid =  prepare_coordinate_data(spacing, train_mus, train_sols)
y_test, mu_t_test, X_grid =  prepare_coordinate_data(spacing, test_mus, test_sols)

def normalize(x, mean, std):
    """Standardize ``x``: subtract ``mean``, then divide by ``std``.

    No guard is applied for ``std == 0``; the division is performed as-is.
    """
    centered = x - mean
    return centered / std

# Statistics are computed on the training inputs only (along axis 0) and then
# reused for the test inputs, so both sets share the same scaling.
mean = jnp.mean(mu_t_train, axis=0)
std = jnp.std(mu_t_train, axis=0)
mu_t_train = normalize(mu_t_train, mean, std)
mu_t_test = normalize(mu_t_test, mean, std)

In [None]:
from colora.build import build_colora

key = jax.random.PRNGKey(1)

# Dimensions handed to build_colora.
x_dim = 2
mu_t_dim = 2
u_dim = 1

# Extent of the spatial axes (last minus first grid value); passed to build_colora as `period`.
d_len_x = x_space[-1] - x_space[0]
d_len_y = y_space[-1] - y_space[0]


# One 'P' layer followed by seven 'C' layers. Per the original note: seven colora
# layers with one alpha each means a latent dimension of 7.
u_layers = ['P'] + ['C'] * 7
# Per the original note: g takes in phi(t, mu) and psi and outputs the time derivative of phi.
g_layers = ['D'] * 3 + ['END']
rank = 3

u_hat_config = dict(width=25, layers=u_layers)
g_config = dict(width=15, layers=g_layers)

u_hat_fn, g_fn, theta_init, psi_init = build_colora(
    u_hat_config,
    g_config,
    x_dim,
    mu_t_dim,
    u_dim,
    lora_filter=['alpha'],
    period=[d_len_x, d_len_y],
    rank=rank,
    key=key,
)

# Check that the neural ODE right-hand side is defined as expected.

print(g_fn)


In [None]:
u_hat_v_x =  vmap(u_hat_fn, in_axes=(None, None, 0)) # vmapped over the third argument (x) to generate the solution field over space points
u_hat_v_x_phi =  vmap(u_hat_v_x, in_axes=(None, 0, None)) # additionally vmapped over the second argument (phi): one solution field per phi


In [None]:
# Latent size: one entry per element of u_layers after the first.
n_phi = len(u_layers)-1
# Initial latent state for the time integration (all zeros here).
# NOTE(review): a later cell rebinds phi_init to ones; decide which initial state is intended.
phi_init = jnp.zeros(n_phi)

# Time integrator for g, batched over mu.
def integrate_g_forall_mu(g_fn, phi_init, time, mu_t, X_grid):
    """Run ``odeint`` once per leading-axis entry of ``mu_t``.

    ``g_fn``, ``phi_init``, ``time`` and ``X_grid`` are shared by every
    integration; only ``mu_t`` is mapped (axis 0). ``mu_t`` and ``X_grid`` are
    forwarded to ``odeint`` as extra positional arguments, so ``g_fn`` is
    called as ``g_fn(phi, t, mu, X_grid)``.

    Returns the stacked ``odeint`` results, one per entry of ``mu_t``.
    """
    return vmap(odeint, in_axes=(None, None, None, 0, None))(
        g_fn, phi_init, time, mu_t, X_grid
    )

# we will now define the loss in terms of the params of u_hat and g
def predict(psi_theta,mu_t, X_grid):
    """Integrate phi for every entry of ``mu_t`` and evaluate u_hat on ``X_grid``.

    ``psi_theta`` is unpacked as ``(psi, theta)``. ``phi_init``, ``time`` and
    ``u_hat_v_x_phi`` are read from notebook scope.

    NOTE(review): ``g_fn_vmap`` is not defined in any visible cell, so calling
    this raises NameError on a fresh kernel unless it is defined elsewhere.
    NOTE(review): ``psi`` is unpacked but never used, so as written the result
    cannot depend on psi -- confirm how g is meant to receive its parameters.
    NOTE(review): a later cell redefines ``predict`` with a different signature,
    which shadows this definition.
    """
    psi, theta = psi_theta
    phi = integrate_g_forall_mu(g_fn_vmap, phi_init, time, mu_t, X_grid)
    u_hat = u_hat_v_x_phi(theta, phi, X_grid)
    return u_hat

def loss_NODE(psi_theta, sols, mu_t, X_grid):
    """Mean relative error between ``sols`` and ``predict(psi_theta, mu_t, X_grid)``.

    The norm of the residual and of ``sols`` is taken over axes (1, 2); the
    ratio is then averaged over the remaining leading axis.
    """
    residual = sols - predict(psi_theta, mu_t, X_grid)
    residual_norm = jnp.linalg.norm(residual, axis=(1, 2))
    reference_norm = jnp.linalg.norm(sols, axis=(1, 2))
    return jnp.mean(residual_norm / reference_norm)


In [120]:
#### Given by GPT
# Latent size: one entry per element of u_layers after the first.
n_phi = len(u_layers) - 1
# NOTE(review): this rebinds phi_init to ones, whereas an earlier cell sets it to
# zeros; whichever cell ran last determines the initial latent state.
phi_init = jnp.ones(n_phi)

# def integrate_g_vmap_mu(g_fn, phi_init, time, mu_t): # integrate g over all time for all mu
#     batched_integrator = vmap(odeint, in_axes=(None, None, None, 0))
#     phi_forall_mu = batched_integrator(g_fn, phi_init, time, mu_t)
#     return phi_forall_mu

# def predict(psi_theta, phi_init, time, mu_t, X_grid):
#     psi, theta = psi_theta
#     phi = integrate_g_vmap_mu(g_fn, phi_init, time, mu_t)
#     pred = u_hat_v_x_phi(theta, phi, X_grid)
#     return pred

def predict(g_fn, psi_theta, phi_init, time, mu_t, X_grid):
    """Integrate the latent state for each entry of ``mu_t`` and decode it on ``X_grid``.

    Args:
        g_fn: right-hand side handed to ``odeint``; called as ``g_fn(phi, t, mu)``.
        psi_theta: pair ``(psi, theta)``; only ``theta`` is used here.
        phi_init: initial latent state passed to ``odeint``.
        time: time points passed to ``odeint``.
        mu_t: array whose leading axis is mapped over, one integration per entry.
        X_grid: spatial points; ``u_hat_fn`` is mapped over its leading axis.

    Returns:
        ``u_hat_fn`` evaluated for every integrated phi (leading axis) and
        every point of ``X_grid`` (second axis).

    NOTE(review): ``psi`` is unpacked but never passed to ``g_fn``, so the
    output cannot depend on psi unless ``g_fn`` already closes over it --
    confirm against the signature of ``g_fn``.
    """
    _psi, theta = psi_theta

    # Only the fourth positional argument (mu_t) is batched; the rest are shared.
    integrate_over_mu = vmap(odeint, in_axes=(None, None, None, 0))
    phi_per_mu = integrate_over_mu(g_fn, phi_init, time, mu_t)

    # Inner map: over spatial points (third argument of u_hat_fn).
    # Outer map: over the leading axis of the integrated phi (second argument).
    decode_over_x = vmap(u_hat_fn, in_axes=(None, None, 0))
    decode_over_phi_and_x = vmap(decode_over_x, in_axes=(None, 0, None))
    return decode_over_phi_and_x(theta, phi_per_mu, X_grid)


def loss_NODE(psi_theta, mu_t, sols, X_grid):
    """Mean relative error between ``sols`` and the model prediction.

    Args:
        psi_theta: pair ``(psi, theta)`` of trainable parameters.
        mu_t: inputs forwarded to ``predict``.
        sols: reference solutions compared against the prediction.
        X_grid: spatial points forwarded to ``predict``.

    Returns:
        Scalar: norm of the residual over axes (1, 2) divided by the norm of
        ``sols`` over the same axes, averaged over the leading axis.
    """
    # predict takes (g_fn, psi_theta, phi_init, time, mu_t, X_grid). The previous
    # call passed only four arguments and raised TypeError. g_fn, phi_init and
    # time are read from notebook scope.
    pred = predict(g_fn, psi_theta, phi_init, time, mu_t, X_grid)
    # Shape diagnostic; if the optimizer jit-compiles this function, it prints
    # only while tracing.
    print("Shapes - sols: {}, pred: {}".format(sols.shape, pred.shape))
    loss = jnp.linalg.norm(
        sols - pred, axis=(1, 2)) / jnp.linalg.norm(sols, axis=(1, 2))
    return jnp.mean(loss)


In [123]:
from colora.data import Dataset

# The dataset is only responsible for batching the data over the mu_t tensor,
# in order to avoid memory overflow on the GPU.
dataset = Dataset(mu_t_train, X_grid, y_train, n_batches=15, key=key)
dataset = iter(dataset)


def args_fn():
    """Return the next batch from ``dataset``, wrapped in a one-element tuple.

    NOTE(review): the batch is wrapped in a 1-tuple rather than returned
    directly -- confirm this matches how adam_opt passes the result on to the
    loss function.
    """
    batch = next(dataset)
    return (batch,)


In [124]:
from colora.adam import adam_opt

# Jointly optimise the parameters of g (psi) and u_hat (theta) with Adam.
psi_theta = (psi_init, theta_init)
opt_psi_theta, loss_history = adam_opt(psi_theta, loss_NODE, args_fn, steps=100, learning_rate=5e-3, verbose=True)
opt_psi, opt_theta = opt_psi_theta

print(opt_psi_theta)
print(opt_psi)
print(opt_theta)

# predict takes (g_fn, psi_theta, phi_init, time, mu_t, X_grid). The previous
# three-argument call did not match that signature and raised TypeError.
pred = predict(g_fn, opt_psi_theta, phi_init, time, mu_t_test, X_grid)
# Unflatten the (mu, time) and (x1, x2) axes to match the layout of test_sols.
# NOTE(review): this pattern assumes pred is ((M T), (N1 N2), Q) -- confirm
# against the shape actually returned by predict.
pred = rearrange(pred, '(M T) (N1 N2) Q -> M Q T N1 N2', Q=n_q, T=n_t, N1=n_x1, N2=n_x2)

# Relative L2 error per test parameter, computed on the fully flattened fields.
test_vec =  rearrange(test_sols, 'M Q T N1 N2 -> M (Q T N1 N2)')
pred_vec =  rearrange(pred, 'M Q T N1 N2 -> M (Q T N1 N2)')
rel_err = np.linalg.norm(test_vec - pred_vec, axis=1) / np.linalg.norm(test_vec, axis=1)
mean_rel_err = rel_err.mean()
print(f'Test mean relative error: {mean_rel_err:.2E}')
from colora.plot import imshow_movie

# Animate the first variable of the first test parameter.
imshow_movie(pred[0][0], save_to='./img/burgers.gif', t=spacing[1], title='Burgers', tight=True, live_cbar=True, frames=85)

  0%|          | 0/100 [00:00<?, ?it/s]


TypeError: loss_NODE() missing 3 required positional arguments: 'sols', 'mu_t', and 'X_grid'

In [None]:
# Show the optimised parameters: the combined pair first, then psi and theta.
for params in (opt_psi_theta, opt_psi, opt_theta):
    print(params)

In [None]:
# predict takes (g_fn, psi_theta, phi_init, time, mu_t, X_grid). The previous
# three-argument call did not match that signature and raised TypeError.
pred = predict(g_fn, opt_psi_theta, phi_init, time, mu_t_test, X_grid)
# Unflatten the (mu, time) and (x1, x2) axes to match the layout of test_sols.
# NOTE(review): this pattern assumes pred is ((M T), (N1 N2), Q) -- confirm
# against the shape actually returned by predict.
pred = rearrange(pred, '(M T) (N1 N2) Q -> M Q T N1 N2', Q=n_q, T=n_t, N1=n_x1, N2=n_x2)

# Relative L2 error per test parameter, computed on the fully flattened fields.
test_vec =  rearrange(test_sols, 'M Q T N1 N2 -> M (Q T N1 N2)')
pred_vec =  rearrange(pred, 'M Q T N1 N2 -> M (Q T N1 N2)')
rel_err = np.linalg.norm(test_vec- pred_vec, axis=1)/np.linalg.norm(test_vec, axis=1)
mean_rel_err = rel_err.mean()
print(f'Test mean relative error: {mean_rel_err:.2E}')

In [None]:
from colora.plot import imshow_movie

# Animate pred[0][0]: index 0 along M and along Q of the 'M Q T N1 N2' layout.
imshow_movie(pred[0][0], save_to='./img/burgers.gif', t=spacing[1], title='Burgers', tight=True, live_cbar=True, frames=85)

# NOTE(review): h_v_mu_t is not defined in any visible cell, so this raises
# NameError on a fresh kernel unless it is defined elsewhere. With the neural-ODE
# formulation, the latent trajectories presumably need to come from integrating
# g instead -- confirm.
phis = h_v_mu_t(opt_psi, mu_t_test)
# Split the flattened leading axis into (M, T), with T = n_t.
phis = rearrange(phis, '(M T) D -> M T D', T=n_t)


In [None]:
from colora.plot import trajectory_movie
# One legend label per latent coordinate (last axis of phis).
leg = [rf'$\phi_{i}$' for i in range(phis.shape[-1])]
# Animate the latent trajectories of the first entry of phis against time.
trajectory_movie(
    phis[0],
    x=time,
    title='Burgers',
    ylabel=r'$\phi(t;\mu)$',
    legend=leg,
    save_to='./img/burgers_dynamics',
    frames=85,
)