**Experiments**

Vary the latent dimension and the augmented dimension.

In [None]:
# Reload edited modules (e.g. colora.*) automatically before each cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
from jax import vmap
import jax.numpy as jnp
from pathlib import Path
import numpy as np
from einops import rearrange
import jax
import matplotlib.pyplot as plt

In [None]:
from colora.data import load_all_hdf5, split_data_by_mu, prepare_coordinate_data
data_dir = Path('./data')
data_path = data_dir / 'burgers'
# Load every stored parameter instance; `spacing` holds the coordinate grids
# (index 1 is used as time, 2 and 3 as the x and y axes below).
mus, sols, spacing = load_all_hdf5(data_path)


# Hold out two interior viscosities for testing (interpolation in mu).
train_mus = np.asarray([0.001, 0.00199, 0.00298, 0.00496, 0.00595, 0.00694, 0.00892, 0.01])
test_mus = np.asarray([0.00397, 0.00793])
train_sols, test_sols = split_data_by_mu(mus, sols, train_mus, test_mus) # mus X variables X time X space_x X space_y


# Split both sets along the time axis (axis=2) into first/second halves.
train_sols_first,train_sols_second = jnp.array_split(train_sols, 2, axis=2)
test_sols_first,test_sols_second = jnp.array_split(test_sols, 2, axis=2)

n_mu_train, n_q, n_t, n_x1, n_x2 = train_sols.shape
# NOTE: this overwrites n_q, n_t, n_x1, n_x2 with the test values; train and test
# are assumed to share them.
n_mu_test, n_q, n_t, n_x1, n_x2 = test_sols.shape
# NOTE(review): `time` is overwritten by a normalized grid in the next cell.
time = spacing[1]
x_space = spacing[2]
y_space = spacing[3]

print(f'n_mu train: {n_mu_train}, n_mu test: {n_mu_test}, n_variables: {n_q}, n_time samples: {n_t}, n_x samples: {n_x1}, n_x2 samples: {n_x2}')


In [None]:
# Normalized time grid on [0, 1] with one point per stored snapshot. The count
# comes from n_t (from the data cell) rather than a hard-coded 51, so the grid
# cannot silently disagree with the data's number of time samples.
time = jnp.linspace(0, 1.0, n_t)

In [None]:
y_train, mu_t_train, X_grid =  prepare_coordinate_data(spacing, train_mus, train_sols)
y_test, mu_t_test, X_grid =  prepare_coordinate_data(spacing, test_mus, test_sols)
# we check the shapes of the data
print(f'y_train: {y_train.shape}, mu_t_train: {mu_t_train.shape}, X_grid: {X_grid.shape}')
print(f'y_test: {y_test.shape}, mu_t_test: {mu_t_test.shape}, X_grid: {X_grid.shape}')

# Cut everything in half along time. y_* is expected to be laid out as
# ((mu t), (x1 x2), q), so restore separate (mu, t) axes first. Use n_q rather
# than a literal 1 so the reshape stays valid for multi-variable data.
y_train_matrix = y_train.reshape(n_mu_train, n_t, n_x1 * n_x2, n_q)
y_test_matrix = y_test.reshape(n_mu_test, n_t, n_x1 * n_x2, n_q)

y_train_first, y_train_second = jnp.array_split(y_train_matrix, 2, axis=1)
y_test_first, y_test_second = jnp.array_split(y_test_matrix, 2, axis=1)

# flatten the (mu, t) axes back into one sample axis for each half
y_train_first = rearrange(y_train_first, 'mu t x q -> (mu t) x q')
y_train_second = rearrange(y_train_second, 'mu t x q -> (mu t) x q')

y_test_first = rearrange(y_test_first, 'mu t x q -> (mu t) x q')
y_test_second = rearrange(y_test_second, 'mu t x q -> (mu t) x q')

# matching halves of the time grid (array_split gives the first half the extra point)
t_span_first, t_span_second = jnp.array_split(time, 2, axis=0)
n_t_first = len(t_span_first)
n_t_second = len(t_span_second)



In [None]:
def normalize(x, mean, std):
    """Standardize `x` with the given (broadcastable) mean and standard deviation."""
    centered = x - mean
    return centered / std

# Standardize the (mu, t) inputs with training statistics. The same mean/std are
# reused for the test set and later by get_all_phi_0.
# NOTE: this cell is not idempotent. Re-running it without re-running the data cell
# recomputes mean/std from already-normalized data (~0 and ~1), which corrupts the
# statistics that later cells reuse.
mean, std = jnp.mean(mu_t_train, axis=0), jnp.std(mu_t_train, axis=0)
mu_t_train = normalize(mu_t_train, mean, std)
mu_t_test = normalize(mu_t_test, mean, std)

# reshape the flattened (mu t) rows back to (n_mu, n_t, 2) so time can be split
mu_t_train_matrix = mu_t_train.reshape(n_mu_train, n_t,2)
mu_t_test_matrix = mu_t_test.reshape(n_mu_test, n_t,2)

mu_t_train_first,mu_t_train_second = jnp.array_split(mu_t_train_matrix,2,axis=1)
mu_t_train_first = rearrange(mu_t_train_first, 'mu t x -> (mu t) x')
mu_t_train_second = rearrange(mu_t_train_second, 'mu t x -> (mu t) x')

mu_t_test_first,mu_t_test_second = jnp.array_split(mu_t_test_matrix,2,axis=1)
mu_t_test_first = rearrange(mu_t_test_first, 'mu t x -> (mu t) x')
mu_t_test_second = rearrange(mu_t_test_second, 'mu t x -> (mu t) x')


In [None]:
from colora.build import build_colora

key = jax.random.PRNGKey(1)


# Domain lengths in x and y, passed as `period` below.
d_len_x = x_space[-1] - x_space[0]
d_len_y = y_space[-1] - y_space[0]

x_dim = 2
mu_t_dim = 2
u_dim = 1

u_layers = ['P', 'C', 'C', 'C', 'C', 'C', 'C', 'C'] # 1 'P' layer (presumably the periodic input layer, cf. `period=`) + 7 CoLoRA ('C') layers; with 1 alpha each the latent dim is 7
h_layers = ['D', 'D', 'D']
rank = 3

u_hat_config = {'width': 25, 'layers': u_layers}
h_config = {'width': 15, 'layers': h_layers}

# u_hat_fn(theta, phi, x): decoder; h_fn(psi, mu_t) -> phi: hypernetwork.
u_hat_fn, h_fn, theta_init, psi_init = build_colora(
    u_hat_config, h_config, x_dim, mu_t_dim, u_dim, lora_filter=['alpha'], period=[d_len_x, d_len_y], rank=rank, key=key)

In [None]:

h_v_mu_t = vmap(h_fn, in_axes=(None, 0)) # vmap over mu_t array to generate array of phis
u_hat_v_x =  vmap(u_hat_fn, in_axes=(None, None, 0)) # vmapped over x to generate the solution field over space points
u_hat_v_x_phi =  vmap(u_hat_v_x, in_axes=(None, 0, None)) # vmapped over phi: one solution field per latent code

In [None]:
def predict(psi_theta, mu_t, X_grid):
    """Evaluate the CoLoRA surrogate for every (mu, t) row of `mu_t` on `X_grid`.

    `psi_theta` is a (psi, theta) pair: psi parameterizes the hypernetwork h,
    theta the decoder u_hat. Returns one predicted field per row of `mu_t`.
    """
    hyper_params, decoder_params = psi_theta
    latent_codes = h_v_mu_t(hyper_params, mu_t)
    return u_hat_v_x_phi(decoder_params, latent_codes, X_grid)

def relative_loss_fn(psi_theta, mu_t, sols, X_grid):
    """Mean over samples of the relative L2 error of `predict` against `sols`.

    The per-sample norm is taken jointly over axes 1 and 2 (space and variables).
    """
    residual = sols - predict(psi_theta, mu_t, X_grid)
    per_sample = jnp.linalg.norm(residual, axis=(1, 2)) / jnp.linalg.norm(sols, axis=(1, 2))
    return jnp.mean(per_sample)


In [None]:
from colora.data import Dataset

# the dataset is just responsible for batching the data over the mu_t tensor 
# in order to avoid memory overflow on the GPU
dataset = Dataset(mu_t_train, X_grid, y_train, n_batches=15, key=key)
dataset = iter(dataset)
def args_fn():
    # Next batch; presumably the (mu_t, sols, X_grid) arguments that adam_opt
    # passes to relative_loss_fn after the parameters. Confirm against colora.data.
    return next(dataset)

In [None]:
from colora.adam import adam_opt

# Optimize hypernetwork (psi) and decoder (theta) together on the full time horizon.
psi_theta = (psi_init, theta_init)
opt_psi_theta, loss_history = adam_opt(psi_theta, relative_loss_fn, args_fn, steps=100, learning_rate=5e-3, verbose=True)
opt_psi, opt_theta = opt_psi_theta


In [None]:
# Persist the trained CoLoRA parameters with pickle.
# The saved checkpoints are noted as trained for 100k steps, while the cell above
# runs only a short training. Writing unconditionally would clobber them on
# "Run All", so existing files are kept unless OVERWRITE_CHECKPOINTS is True.
import pickle

OVERWRITE_CHECKPOINTS = False
_checkpoints = {'burgers_opt_psi.pkl': opt_psi, 'burgers_opt_theta.pkl': opt_theta}
for _ckpt_path, _ckpt_params in _checkpoints.items():
    if OVERWRITE_CHECKPOINTS or not Path(_ckpt_path).exists():
        with open(_ckpt_path, 'wb') as f:
            pickle.dump(_ckpt_params, f)
    else:
        print(f'{_ckpt_path} exists; not overwriting (set OVERWRITE_CHECKPOINTS = True to save)')


In [None]:
# # # load optimal psi and theta
# NOTE: this replaces the opt_psi/opt_theta just trained with the saved checkpoints.
# pickle.load can execute arbitrary code; only load trusted files.
import pickle
with open('burgers_opt_psi.pkl', 'rb') as f:
    opt_psi = pickle.load(f)
with open('burgers_opt_theta.pkl', 'rb') as f:
    opt_theta = pickle.load(f)

opt_psi_theta=(opt_psi,opt_theta)

In [None]:
#relative loss
# Evaluate the trained CoLoRA model on the held-out test mus over the full time grid.
n_t=len(time)
pred = predict(opt_psi_theta, mu_t_test, X_grid)
# back to (mu, variable, time, x1, x2), the layout of test_sols
pred = rearrange(pred, '(M T) (N1 N2) Q -> M Q T N1 N2', Q=n_q, T=n_t, N1=n_x1, N2=n_x2)

# per-mu flattening (for the mean relative error) and per-time flattening
# (for the error-over-time curve)
pred_vec =  rearrange(pred, 'M Q T N1 N2 -> M (Q T N1 N2)') 
pred_vec_for_time = rearrange(pred, 'M Q T N1 N2 -> T (M Q N1 N2)')

test_vec =  rearrange(test_sols, 'M Q T N1 N2 -> M (Q T N1 N2)')
test_vec_for_time = rearrange(test_sols, 'M Q T N1 N2 -> T (M Q N1 N2)')

rel_err = np.linalg.norm(test_vec- pred_vec, axis=1)/np.linalg.norm(test_vec, axis=1)
mean_rel_err = rel_err.mean()


rel_err_over_time = np.linalg.norm(test_vec_for_time - pred_vec_for_time, axis=1) / np.linalg.norm(test_vec_for_time, axis=1)


print(f'Test mean relative error : {mean_rel_err:.2E}')

In [None]:
from colora.plot import imshow_movie

# Animate variable 0 of the first test mu over time.
imshow_movie(pred[0][0], save_to='./img/burgers_test/burgers.gif', t=time, title='burgers', tight=True, frames=85)

In [None]:
# Latent trajectories phi(t; mu) produced by the hypernetwork for the test mus.
phis = h_v_mu_t(opt_psi, mu_t_test)
phis = rearrange(phis, '(M T) D -> M T D', T=n_t)
# phis_all collects the first test mu's trajectory from each experiment for the
# comparison plot at the end. Later cells append to it, so re-running those cells
# without this one adds duplicates.
phis_all = []
phis_all.append(phis[0])

from colora.plot import trajectory_movie
leg= []
for i in range(phis.shape[-1]):
    lstr =rf'$\phi_{i}$'
    leg.append(lstr)
trajectory_movie(phis[0], x=time, title='burgers', ylabel=r'$\phi(t;\mu)$', legend=leg, save_to='./img/burgers_test/burgers_dynamics', ylim =[-5,3] , frames=85)


In [None]:
# Generate additional (noisy) data by querying the trained CoLoRA model on a
# denser/wider set of viscosities.
# NOTE(review): these mu values extend to 0.1, well outside the trained range
# [0.001, 0.01], so predictions there are extrapolations. Confirm this is intended.
mu_gen_train = jnp.asarray([0.001, 0.0065, 0.012, 0.0175, 0.023, 0.0285, 0.0395, 0.045, 0.0505, 0.056, 0.0615, 0.072, 0.078, 0.0835, 0.089, 0.0945, 0.1])
mu_gen_test = jnp.asarray([0.034, 0.067])

t = jnp.linspace(0, 1.0, 51)
t_first, t_second = jnp.array_split(t, 2, axis=0)


def _mu_t_grid(mu_values, t_values):
    """Return all (mu, t) pairs as an (n_mu * n_t, 2) array, mu-major / t-minor."""
    mu_col = jnp.repeat(mu_values, len(t_values))
    t_col = jnp.tile(t_values, len(mu_values))
    return jnp.stack([mu_col, t_col], axis=1)


# Raw (unnormalized) (mu, t) grids for the full horizon and for each half, each
# paired with its own time samples.
mu_t_gen_train = _mu_t_grid(mu_gen_train, t)
mu_t_gen_train_first = _mu_t_grid(mu_gen_train, t_first)
mu_t_gen_train_second = _mu_t_grid(mu_gen_train, t_second)

mu_t_gen_test = _mu_t_grid(mu_gen_test, t)
mu_t_gen_test_first = _mu_t_grid(mu_gen_test, t_first)
mu_t_gen_test_second = _mu_t_grid(mu_gen_test, t_second)

# The hypernetwork was trained on standardized (mu, t), so apply the training
# mean/std before querying the model.
y_gen_train = predict(opt_psi_theta, normalize(mu_t_gen_train, mean, std), X_grid)
y_gen_train_first = predict(opt_psi_theta, normalize(mu_t_gen_train_first, mean, std), X_grid)

y_gen_test = predict(opt_psi_theta, normalize(mu_t_gen_test, mean, std), X_grid)

y_gen_test_first = predict(opt_psi_theta, normalize(mu_t_gen_test_first, mean, std), X_grid)

# small additive Gaussian noise on the generated training data
y_gen_train += 5e-3*jax.random.normal(jax.random.PRNGKey(123),y_gen_train.shape)
y_gen_train_first += 5e-3*jax.random.normal(jax.random.PRNGKey(123),y_gen_train_first.shape)

# these data are used to train NODE



**We now build a new hypernetwork, train it only on the first half of the time interval, and show that it does not extrapolate.**

In [None]:
# Build h_new
# Same configs and PRNG key as the original build. Only the new hypernetwork
# (h_fn_new, psi_init_new) is kept; the decoder stays the trained opt_theta.
_, h_fn_new, _, psi_init_new = build_colora(
    u_hat_config, h_config, x_dim, mu_t_dim, u_dim, lora_filter=['alpha'],period=[d_len_x, d_len_y], rank=rank, key=key)
print(h_fn_new)
h_new_v_mu_t = vmap(h_fn_new, in_axes=(None, 0)) 
# then we train h_fn_new on the first half of the time samples using the opt_theta
def relative_loss_fn_new(psi, mu_t, sols, X_grid):
    """Mean relative L2 error of the new hypernetwork, with the decoder fixed to opt_theta.

    Only `psi` is an optimization argument; `opt_theta` is read from notebook scope.
    """
    latent_codes = h_new_v_mu_t(psi, mu_t)
    residual = sols - u_hat_v_x_phi(opt_theta, latent_codes, X_grid)
    rel = jnp.linalg.norm(residual, axis=(1, 2)) / jnp.linalg.norm(sols, axis=(1, 2))
    return rel.mean()


In [None]:
# process only the first half
# Batches of first-half (mu, t) samples and fields for training h_new.

from colora.data import Dataset
dataset_first = Dataset(mu_t_train_first, X_grid, y_train_first, n_batches=15, key=key)

dataset_first = iter(dataset_first)
def args_fn_new():
    # next batch; presumably the (mu_t, sols, X_grid) arguments of relative_loss_fn_new
    return next(dataset_first)


In [None]:

# start training
# Only the new hypernetwork parameters are optimized; relative_loss_fn_new
# keeps the decoder fixed at opt_theta.
from colora.adam import adam_opt
psi_new = psi_init_new
opt_psi_new, loss_history_new = adam_opt(psi_new, relative_loss_fn_new, args_fn_new, steps=100, learning_rate=5e-3, verbose=True)


In [None]:
#store opt_psi_new
# import pickle
# with open('burgers_opt_psi_new.pkl', 'wb') as f:
#     pickle.dump(opt_psi_new, f)
#read it
# NOTE: this replaces the opt_psi_new trained above with the saved checkpoint.
# pickle.load can execute arbitrary code; only load trusted files.
import pickle
with open('burgers_opt_psi_new.pkl', 'rb') as f:
    opt_psi_new = pickle.load(f)

In [None]:
# plot the prediction
# test relative loss 
# NOTE(review): `predict` evaluates phis with h_v_mu_t (built from h_fn), not
# h_new_v_mu_t. This relies on h_fn and h_fn_new sharing one architecture (both
# come from build_colora with the same h_config).
n_t=len(time)
opt_psi_new_opt_theta=(opt_psi_new,opt_theta)

pred = predict(opt_psi_new_opt_theta, mu_t_test, X_grid)

pred = rearrange(pred, '(M T) (N1 N2) Q -> M Q T N1 N2', Q=n_q, T=n_t, N1=n_x1, N2=n_x2)

pred_vec =  rearrange(pred, 'M Q T N1 N2 -> M (Q T N1 N2)') 
pred_vec_for_time = rearrange(pred, 'M Q T N1 N2 -> T (M Q N1 N2)')

test_vec =  rearrange(test_sols, 'M Q T N1 N2 -> M (Q T N1 N2)')
test_vec_for_time = rearrange(test_sols, 'M Q T N1 N2 -> T (M Q N1 N2)')

rel_err = np.linalg.norm(test_vec- pred_vec, axis=1)/np.linalg.norm(test_vec, axis=1)
mean_rel_err = rel_err.mean()


rel_err_over_time_h_new = np.linalg.norm(test_vec_for_time - pred_vec_for_time, axis=1) / np.linalg.norm(test_vec_for_time, axis=1)


print(f'Test mean relative error : {mean_rel_err:.2E}')


In [None]:


from colora.plot import imshow_movie

# NOTE(review): the title says [0,2.5] but `time` is a [0, 1] grid here; presumably
# the label uses physical time units. Confirm.
imshow_movie(pred[0][0], save_to='./img/burgers_test/burgers_new_h.gif', t=time, title='burgers: new h trained on [0,2.5]', tight=True, live_cbar=True, frames=85)


In [None]:
#also the trajectories
# NOTE(review): uses h_v_mu_t (from h_fn) with opt_psi_new; this assumes h_fn and
# h_fn_new share one architecture (same h_config).
phis = h_v_mu_t(opt_psi_new, mu_t_test)

n_t=len(time)

phis = rearrange(phis, '(M T) D -> M T D', T=n_t)

# store the first test mu's trajectory for the final comparison plot
phis_all.append(phis[0])

from colora.plot import trajectory_movie
leg= []
for i in range(phis.shape[-1]):
    lstr =rf'$\phi_{i}$'
    leg.append(lstr)
#plot latent trajectories for the 1st test mu.
trajectory_movie(phis[0], x=time, title='burgers: new h trained on [0,2.5]', ylabel=r'$\phi(t;\mu)$', legend=leg, save_to='./img/burgers_test/burgers_new_h_dynamics', frames=85) 


**Learn the dynamics using Neural ODE**

In [None]:
# Get the initial condition phi(0, mu) from the trained hypernetwork
def get_all_phi_0(psi, mu):
    """Latent initial conditions phi(t=0; mu) for each entry of the 1-D array `mu`.

    Builds (mu, 0) rows, standardizes them with the training mean/std (the
    hypernetwork was trained on normalized inputs) and evaluates h on each row.
    """
    n_mu = mu.shape[0]
    mu_t0 = normalize(jnp.column_stack((mu, jnp.zeros(n_mu))), mean, std)
    return h_v_mu_t(psi, mu_t0)


# Latent initial conditions phi(0; mu) for the train/test viscosities.
# These must line up row-for-row with the `mus` and data the NODE is trained and
# evaluated on (train_mus with y_train_first, test_mus with test_sols). The denser
# mu_gen_* grids have a different length, which breaks the vmap over (phi_0, mu)
# and the zero-padding to n_mu_train rows in the augmented NODE.
phi_0_mus_train = get_all_phi_0(opt_psi, train_mus)
phi_0_mus_test = get_all_phi_0(opt_psi, test_mus)

assert phi_0_mus_train.shape[0] == n_mu_train, phi_0_mus_train.shape
assert phi_0_mus_test.shape[0] == n_mu_test, phi_0_mus_test.shape

print(phi_0_mus_train.shape)
print(phi_0_mus_test.shape)


In [None]:
# Proposed approach: \partial_t phi(t, mu) = g(phi(t, mu), mu; omega)
# (Kept as a comment: in a plain string literal "\p" is an invalid escape sequence,
# which raises a SyntaxWarning on recent Python versions.)
# TODO: try a different activation for the MLP inside g?
from colora.NODE import NODE


# Quick example to check the shape
keygen = 123
phi_dim = 7  # must match the CoLoRA latent dimension (presumably one per 'C' layer in u_layers)
# TODO(review): "check this". mu is a scalar viscosity here.
mu_dim = 1
hidden_dim = 10
depth = 2

# g(phi_0, mu, t_span) integrates the learned vector field from phi_0 and returns
# the latent trajectory over t_span (see how predictNODE_freeze uses it).
g = NODE(phi_dim, mu_dim, hidden_dim, depth, keygen,'quadratic')
# record the tree structure so omega (flat leaves) can be optimized and g rebuilt
omega,omega_def = jax.tree_util.tree_flatten(g)



In [None]:
# # NOT AUGMENTED definitions: train omega_theta together


# import equinox as eqx
# def predictNODE(omega_theta, omega_def, t_span, mus, X_grid, phi_0):
#     omega, theta = omega_theta
#     g=jax.tree_util.tree_unflatten(omega_def, omega)
#     g_forall_phi_mu = vmap(g, in_axes=(0, 0, None)) # pack (phi0_i,mu_i)

#     # get phis
#     phis=g_forall_phi_mu(phi_0, mus, t_span)
#     # reshape phis to match the shape of sols later
#     phis=phis.reshape(-1,phi_dim)

#     pred = u_hat_v_x_phi(theta, phis, X_grid)
#     return pred

# def lossNODE(omega_theta, omega_def, t_span, mus, sols, X_grid, phi_0):
#     #omega_flat, theta = omega_flat_theta
#     pred = predictNODE(omega_theta, omega_def, t_span, mus, X_grid,phi_0)

#     loss = jnp.linalg.norm(
#         sols - pred, axis=(1,2)) / jnp.linalg.norm(sols, axis=(1,2))
#     return loss.mean()

# t_span = t_span_first

# omega_theta = (omega, opt_theta)

# #g_forall_phi_mu(phi_0_mus_train, train_mus, t_span)
# loss=lossNODE(omega_theta, omega_def, t_span, train_mus, y_train_first, X_grid, phi_0_mus_train)
# print(loss)

# eqxgrad_lossNODE = eqx.filter_value_and_grad(lossNODE)


In [None]:
# train omega only: define predict and loss

import equinox as eqx
def predictNODE_freeze(omega, theta, omega_def, t_span, mus, X_grid, phi_0):
    """Roll out the NODE from phi_0 for every mu and decode with a fixed decoder.

    `omega`/`omega_def` are the flattened leaves/treedef of the NODE module and
    `theta` holds the decoder parameters. Returns one predicted field per
    (mu, t) pair, with mu as the outer (slowest) index.
    """
    node = jax.tree_util.tree_unflatten(omega_def, omega)
    # one trajectory per (phi_0[i], mus[i]); t_span is shared by all of them
    trajectories = vmap(node, in_axes=(0, 0, None))(phi_0, mus, t_span)
    # flatten to ((mu t), phi_dim) so the rows line up with the flattened sols
    flat_phis = trajectories.reshape(-1, phi_dim)
    return u_hat_v_x_phi(theta, flat_phis, X_grid)


def regularized_lossNODE_freeze(omega, theta, omega_def, t_span, mus, sols, X_grid, phi_0):
    """Mean relative L2 error of the NODE rollout plus a 1e-3-weighted norm penalty.

    omega/omega_def: flattened NODE leaves and treedef; theta: frozen decoder params.
    phi_0: latent initial conditions, one row per entry of `mus`.
    sols: target fields laid out like the decoded rollout ((mu t), x, q).

    NOTE(review): g(phi_0, mu, t_span) returns the whole rollout (it is reshaped to
    (-1, phi_dim) for decoding), so the penalty is the norm of each latent
    trajectory, not of the vector field g(phi, mu). Confirm this is intended.
    """
    g = jax.tree_util.tree_unflatten(omega_def, omega)
    # Integrate once and reuse the rollout for both loss terms, instead of solving
    # the ODE a second time for the penalty.
    trajectories = vmap(g, in_axes=(0, 0, None))(phi_0, mus, t_span)

    pred = u_hat_v_x_phi(theta, trajectories.reshape(-1, phi_dim), X_grid)
    main_loss = jnp.linalg.norm(sols - pred, axis=(1, 2)) / jnp.linalg.norm(sols, axis=(1, 2))

    # per-mu norm of the rollout, averaged over mu
    reg_loss = vmap(jnp.linalg.norm)(trajectories).mean()

    return main_loss.mean() + 1e-3 * reg_loss

# eqxgrad_lossNODE_freeze = eqx.filter_value_and_grad(lossNODE_freeze)
# Value and gradient w.r.t. the first argument (omega) only; theta stays frozen.
eqxgrad_regularized_lossNODE = eqx.filter_value_and_grad(regularized_lossNODE_freeze)


In [None]:
# # test the time for taking regularized loss and normal oss
# import time as timer
# start=timer.time()
# regularized_lossNODE_freeze(omega, opt_theta, omega_def, t_span, train_mus, y_train_first, X_grid, phi_0_mus_train)
# end=timer.time()
# print(f'Time for regularized loss: {end-start}')
# start=timer.time()
# lossNODE_freeze(omega, opt_theta, omega_def, t_span, train_mus, y_train_first, X_grid, phi_0_mus_train)
# end=timer.time()
# print(f'Time for normal loss: {end-start}')


In [None]:
t_span = t_span_first

ys = y_train_first
mus = train_mus

n_steps=2  # NOTE: tiny smoke-run value; increase for real training


losses=[]
learning_rates=[]

theta = opt_theta

import optax
#define cosine decay
init_lr = 1e-3
decay_steps = n_steps  # Total number of training steps
lr_schedule = optax.cosine_decay_schedule(init_lr, 2*decay_steps)
# Clip the raw gradients before Adam. Chained after Adam, the clip would act on
# the already lr-scaled updates (roughly lr per coordinate), whose global norm
# practically never reaches 1.0, so it would have no effect.
optimizer=optax.chain(optax.clip_by_global_norm(1.0),
                      optax.adam(learning_rate=lr_schedule))

opt_state=optimizer.init(eqx.filter(omega, eqx.is_inexact_array))

@eqx.filter_jit
def update(omega,theta, omega_def, t_span, mus, sols, X_grid, phi_0, opt_state):
    """One optimizer step on the NODE parameters omega; theta is passed through frozen."""
    # the gradient is taken w.r.t. omega only (first argument of the loss)
    loss,grad = eqxgrad_regularized_lossNODE(omega,theta, omega_def, t_span, mus, sols, X_grid, phi_0)

    updates, new_opt_state=optimizer.update(grad,opt_state)
    new_omega=eqx.apply_updates(omega,updates)

    return new_omega, new_opt_state, loss


import time as timer
#check test loss
test_losses=[]
test_vec =  rearrange(test_sols, 'M Q T N1 N2 -> M (Q T N1 N2)')


for i in range(n_steps):
    start=timer.time()
    omega, opt_state, loss = update(omega, opt_theta, omega_def, t_span, mus, ys, X_grid, phi_0_mus_train, opt_state)
    losses.append(loss)
    lr = lr_schedule(i)
    learning_rates.append(lr)
    if i % 50 ==0:
        end=timer.time()
        # test loss over the full time grid (second half is extrapolation in time)
        pred = predictNODE_freeze(omega,opt_theta, omega_def, time, test_mus, X_grid, phi_0_mus_test) 
        pred = rearrange(pred, '(M T) (N1 N2) Q -> M Q T N1 N2', Q=n_q, T=len(time), N1=n_x1, N2=n_x2)
        
        pred_vec =  rearrange(pred, 'M Q T N1 N2 -> M (Q T N1 N2)') 
        rel_err = np.linalg.norm(test_vec- pred_vec, axis=1)/np.linalg.norm(test_vec, axis=1)
        mean_rel_err = rel_err.mean()
        test_losses.append(mean_rel_err)
        print(f"{i}th iter, train loss={loss}, test loss = {mean_rel_err} lr={lr} time={end-start}")

plt.figure(figsize=(12, 5))

# Plot training and testing losses
plt.subplot(1, 2, 1)
plt.plot(losses, label='Training loss')
plt.plot(np.arange(0, len(losses), 50), test_losses, label='Testing loss')
plt.xlabel('iteration')
plt.yscale('log')
plt.ylabel('loss')
plt.title('Burgers NODE losses')
plt.legend()

# Plot learning rate
plt.subplot(1, 2, 2)
plt.plot(learning_rates)
plt.xlabel('iteration')
plt.ylabel('learning rate')
plt.title('Learning Rate Schedule')

plt.tight_layout()
plt.savefig('burgers_NODE_losses_and_lr.png')
plt.show()



**Plotting**

In [None]:
#now we can plot the prediction given by the trained model
from colora.plot import imshow_movie
# we check the same relative loss on the full interval
# (omega was trained on the first half only, so the second half is extrapolation)
t_span = time
n_t = len(t_span)

# switch original and augmented NODE here
#--------------

# modified for training omega only
pred = predictNODE_freeze(omega,opt_theta, omega_def, t_span, test_mus, X_grid, phi_0_mus_test)
#--------------

# # modified for augmented
# phi_0_mus_test_aug = jnp.concatenate([phi_0_mus_test,jnp.zeros((n_mu_test,aug_dim))],axis=1)
# print(phi_0_mus_test_aug)
# pred=predictNODE_aug(omega_theta, omega_def, t_span, test_mus, X_grid, phi_0_mus_aug_test)
#--------------

print(pred.shape)
pred = rearrange(pred, '(M T) (N1 N2) Q -> M Q T N1 N2', Q=n_q, T=n_t, N1=n_x1, N2=n_x2)

test_vec =  rearrange(test_sols, 'M Q T N1 N2 -> M (Q T N1 N2)')
test_vec_for_time = rearrange(test_sols, 'M Q T N1 N2 -> T (M Q N1 N2)')


pred_vec =  rearrange(pred, 'M Q T N1 N2 -> M (Q T N1 N2)') 
pred_vec_for_time = rearrange(pred, 'M Q T N1 N2 -> T (M Q N1 N2)')
rel_err_over_time_NODE = np.linalg.norm(test_vec_for_time - pred_vec_for_time, axis=1) / np.linalg.norm(test_vec_for_time, axis=1)


rel_err = np.linalg.norm(test_vec- pred_vec, axis=1)/np.linalg.norm(test_vec, axis=1)
mean_rel_err = rel_err.mean()


print(f'Test mean relative error : {mean_rel_err:.2E}')

# define check_loss function?


In [None]:

from colora.plot import imshow_movie

# Animate variable 0 of the first test mu as predicted by the (non-augmented) NODE.
imshow_movie(pred[0][0], save_to='./img/burgers_test/burgers_NODE.gif', t=time, title='burgers NODE', tight=True, live_cbar=True, frames=85)


In [None]:
# and the dynamics of phi_i's by integrating g with the trained omega

opt_omega = omega

# latent initial conditions for the held-out test viscosities
phi_0_mus_test = get_all_phi_0(opt_psi, test_mus)

g=jax.tree_util.tree_unflatten(omega_def,opt_omega)
g_forall_phi_mu = vmap(g, in_axes=(0, 0, None))

phis = g_forall_phi_mu(phi_0_mus_test, test_mus, t_span)
print(phis.shape)
# store the first test mu's trajectory for the final comparison plot
phis_all.append(phis[0])

from colora.plot import trajectory_movie
leg= []
for i in range(phis.shape[-1]):
    lstr =rf'$\phi_{i}$'
    leg.append(lstr)
trajectory_movie(phis[0], x=t_span, title='burgers NODE', ylabel=r'$\phi(t;\mu)$', legend=leg, save_to='./img/burgers_test/burgers_NODE_dynamics', frames=85)


**Augmented NODE**

In [None]:
'''Augmented NODE: Definitions '''
# we try to add aux dimensions to the phi_0
# The NODE state is [phi, a]: aug_dim extra coordinates start at zero, and only the
# first phi_dim coordinates are decoded (see predictNODE_aug).

from colora.NODE import NODE
keygen = 123
phi_dim = 7  #change this if needed 
mu_dim = 1   
hidden_dim = 10

aug_dim=7

depth = 1


# zero-pad the latent initial conditions with aug_dim augmented coordinates
phi_0_mus_aug_train=jnp.concatenate([phi_0_mus_train,jnp.zeros((n_mu_train,aug_dim))],axis=1)
phi_0_mus_aug_test=jnp.concatenate([phi_0_mus_test,jnp.zeros((n_mu_test,aug_dim))],axis=1)

g = NODE(phi_dim+aug_dim, mu_dim, hidden_dim, depth, keygen,'quadratic') # g is a neural ODE

omega,omega_def = jax.tree_util.tree_flatten(g)




In [None]:
'''Augmented NODE: Training 2: train omega  only '''
'''We train on the first half of time'''

import equinox as eqx


def predictNODE_aug(omega, theta, omega_def, t_span, mus, X_grid, phi_0_aug):
    """Roll out the augmented NODE and decode only the physical latent coordinates.

    `phi_0_aug` rows are [phi_0, zeros(aug_dim)]. The augmented coordinates are
    dropped after integration, and the first phi_dim are passed to the decoder.
    """
    node = jax.tree_util.tree_unflatten(omega_def, omega)
    rollouts = vmap(node, in_axes=(0, 0, None))(phi_0_aug, mus, t_span)
    physical = rollouts[:, :, :phi_dim]
    # ((mu t), phi_dim) so the rows line up with the flattened sols
    return u_hat_v_x_phi(theta, physical.reshape(-1, phi_dim), X_grid)

def lossNODE_aug(omega, theta, omega_def, t_span, mus, sols, X_grid, phi_0_aug):
    """Unregularized mean relative L2 error of the augmented-NODE prediction."""
    residual = sols - predictNODE_aug(omega, theta, omega_def, t_span, mus, X_grid, phi_0_aug)
    ratios = jnp.linalg.norm(residual, axis=(1, 2)) / jnp.linalg.norm(sols, axis=(1, 2))
    return ratios.mean()

# Unregularized value-and-grad; the training cell below uses the regularized variant.
eqxgrad_lossNODE = eqx.filter_value_and_grad(lossNODE_aug)


#similarly we regularize it
def regularized_lossNODE_aug(omega, theta, omega_def, t_span, mus, sols, X_grid, phi_0_aug):
    """Augmented-NODE relative L2 error plus a 1e-3-weighted rollout-norm penalty.

    The data term decodes only the first phi_dim latent coordinates. The penalty
    uses the full augmented rollout, as before.

    NOTE(review): g(phi, mu, t_span) returns the whole rollout, so the penalty is
    the norm of each augmented latent trajectory, not of the vector field.
    Confirm this is intended.
    """
    g = jax.tree_util.tree_unflatten(omega_def, omega)
    # Integrate once and reuse the rollout for both loss terms, instead of solving
    # the ODE a second time for the penalty.
    rollouts = vmap(g, in_axes=(0, 0, None))(phi_0_aug, mus, t_span)

    phis = rollouts[:, :, :phi_dim].reshape(-1, phi_dim)
    pred = u_hat_v_x_phi(theta, phis, X_grid)
    main_loss = jnp.linalg.norm(sols - pred, axis=(1, 2)) / jnp.linalg.norm(sols, axis=(1, 2))

    # per-mu norm of the full (augmented) rollout, averaged over mu
    reg_loss = vmap(jnp.linalg.norm)(rollouts).mean()

    return main_loss.mean() + 1e-3 * reg_loss

# eqxgrad_lossNODE_freeze = eqx.filter_value_and_grad(lossNODE_freeze)
# Value and gradient w.r.t. the first argument (omega) only; theta stays frozen.
eqxgrad_regularized_lossNODE_aug = eqx.filter_value_and_grad(regularized_lossNODE_aug)


In [None]:
t_span = t_span_first

ys = y_train_first
mus = train_mus

n_steps=2  # NOTE: tiny smoke-run value; increase for real training


losses=[]
lrs=[]

theta = opt_theta

import optax
#define cosine decay
init_lr = 1e-3
decay_steps = n_steps  # Total number of training steps
lr_schedule = optax.cosine_decay_schedule(init_lr, 2*decay_steps+1)
optimizer=optax.adam(learning_rate=lr_schedule)

opt_state=optimizer.init(eqx.filter(omega, eqx.is_inexact_array))

@eqx.filter_jit
def update(omega,theta, omega_def, t_span, mus, sols, X_grid, phi_0, opt_state):
    """One Adam step on the augmented-NODE parameters omega; theta is passed through frozen."""
    # the gradient is taken w.r.t. omega only (first argument of the loss)
    loss,grad = eqxgrad_regularized_lossNODE_aug(omega,theta, omega_def, t_span, mus, sols, X_grid, phi_0)

    updates, new_opt_state=optimizer.update(grad,opt_state)
    new_omega=eqx.apply_updates(omega,updates)

    return new_omega, new_opt_state, loss


import time as timer

test_vec= rearrange(test_sols, 'M Q T N1 N2 -> M (Q T N1 N2)')
test_losses=[]

for i in range(n_steps):
    start=timer.time()
    omega, opt_state, loss = update(omega, opt_theta, omega_def, t_span, mus, ys, X_grid, phi_0_mus_aug_train, opt_state)
    losses.append(loss)
    lr = lr_schedule(i)
    lrs.append(lr)
    if i % 50 ==0:
        end=timer.time()
        pred=predictNODE_aug(omega,theta, omega_def, time, test_mus, X_grid, phi_0_mus_aug_test)
        pred = rearrange(pred, '(M T) (N1 N2) Q -> M Q T N1 N2', Q=n_q, T=len(time), N1=n_x1, N2=n_x2)
        # we do the same as in NODE
        pred_vec =  rearrange(pred, 'M Q T N1 N2 -> M (Q T N1 N2)')
        rel_err = np.linalg.norm(test_vec- pred_vec, axis=1)/np.linalg.norm(test_vec, axis=1)
        mean_rel_err = rel_err.mean()
        test_losses.append(mean_rel_err)
        print(f"{i}th iter, train loss={loss}, test loss ={mean_rel_err} lr={lr} time={end-start}")


plt.figure(figsize=(12, 5))

# Plot training and testing losses
plt.subplot(1, 2, 1)
plt.plot(losses, label='Training loss')
plt.plot(np.arange(0, len(losses), 50), test_losses, label='Testing loss')
plt.xlabel('iteration')
plt.yscale('log')
plt.ylabel('loss')
plt.title('burgers NODE augmented losses')
plt.legend()

# Plot learning rate (recorded in `lrs` by this cell's loop, not the previous
# cell's `learning_rates`)
plt.subplot(1, 2, 2)
plt.plot(lrs)
plt.xlabel('iteration')
plt.ylabel('learning rate')
plt.title('Learning Rate Schedule')

plt.tight_layout()
plt.savefig('burgers_NODE_aug_losses_and_lr.png')
plt.show()



In [None]:
# we check the same relative loss on the full interval
t_span = time
n_t = len(t_span)

#--------------
# Augmented initial condition: pad phi(0; mu) with aug_dim zeros. It is built here
# and used directly below, so the prediction does not depend on a padded array
# left over from another cell.
phi_0_mus_test_aug = jnp.concatenate([phi_0_mus_test,jnp.zeros((n_mu_test,aug_dim))],axis=1)
print(phi_0_mus_test_aug.shape)

# if train omega and theta together
#pred=predictNODE(omega_theta, omega_def, t_span, test_mus, X_grid, phi_0_mus_test_aug)
# if only omega is trained:
pred=predictNODE_aug(omega, theta, omega_def, t_span, test_mus, X_grid, phi_0_mus_test_aug)
#--------------

print(pred.shape)
pred = rearrange(pred, '(M T) (N1 N2) Q -> M Q T N1 N2', Q=n_q, T=n_t, N1=n_x1, N2=n_x2)

pred_vec =  rearrange(pred, 'M Q T N1 N2 -> M (Q T N1 N2)') 
pred_vec_for_time = rearrange(pred, 'M Q T N1 N2 -> T (M Q N1 N2)')

test_vec =  rearrange(test_sols, 'M Q T N1 N2 -> M (Q T N1 N2)')
test_vec_for_time = rearrange(test_sols, 'M Q T N1 N2 -> T (M Q N1 N2)')

rel_err = np.linalg.norm(test_vec- pred_vec, axis=1)/np.linalg.norm(test_vec, axis=1)
mean_rel_err = rel_err.mean()


rel_err_over_time_NODE_aug = np.linalg.norm(test_vec_for_time - pred_vec_for_time, axis=1) / np.linalg.norm(test_vec_for_time, axis=1)


print(f'Test mean relative error : {mean_rel_err:.2E}')

from colora.plot import imshow_movie

imshow_movie(pred[0][0], save_to='./img/burgers_test/burgers_NODE_aug.gif', t=time, title='burgers NODE augmented', tight=True, live_cbar=True, frames=85)




In [None]:
# and the dynamics of phi_i's by integrating g with the trained omega

# #for trained together 
# opt_omega,reopt_theta = omega_theta

# # for only omega
opt_omega = omega
phi_0_mus_test = get_all_phi_0(opt_psi, test_mus)
phi_0_mus_aug_test = jnp.concatenate([phi_0_mus_test,jnp.zeros((n_mu_test,aug_dim))],axis=1)

g=jax.tree_util.tree_unflatten(omega_def,opt_omega)
g_forall_phi_mu = vmap(g, in_axes=(0, 0, None))

phis = g_forall_phi_mu(phi_0_mus_aug_test, test_mus, t_span)
print(phis.shape)
# keep only the physical latent coordinates of the first test mu (drop aug dims)
phis_phys = phis[0][:, :phi_dim]
phis_all.append(phis_phys)

from colora.plot import trajectory_movie
# One legend entry per plotted coordinate: phi_dim entries, not phi_dim + aug_dim.
leg= []
for i in range(phi_dim):
    lstr =rf'$\phi_{i}$'
    leg.append(lstr)
trajectory_movie(phis_phys, x=t_span, title='burgers NODE augmented', ylabel=r'$\phi(t;\mu)$', legend=leg, save_to='./img/burgers_test/burgers_NODE_aug_dynamics', frames=85)



In [None]:

#plot relative errors over time with NODE and augmented NODE together
# Per-time-step relative test error of each approach, all evaluated on `time`.
plt.figure()
plt.yscale('log')
plt.plot(time, rel_err_over_time,label='Colora')
plt.plot(time, rel_err_over_time_h_new,label='Colora with new h')
plt.plot(time, rel_err_over_time_NODE,label='NODE')
plt.plot(time, rel_err_over_time_NODE_aug,label='NODE augmented')
plt.legend()
plt.xlabel('Time')
plt.ylabel('Relative error')
plt.title('Relative error over time after training')
plt.savefig('burgers_relative_error_over_time.png')


In [None]:
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from matplotlib.gridspec import GridSpec

def multi_experiment_trajectory_snapshot(all_phis, times=None, title='', ylabel='', xlabel='Time', 
                                         param_names=None, exp_names=None, colors=None, linestyles=None, 
                                         ylim=None, save_to=None):
    """Overlay latent trajectories from several experiments on one static figure.

    Each experiment gets its own line style and each latent coordinate (column)
    its own color. A single combined legend, in a side panel, lists both.

    Args:
        all_phis: sequence of arrays, one per experiment, each (n_times, n_params);
            all must have the same number of columns.
        times: None (use sample indices), one time array shared by all experiments,
            or a list with one time array per experiment.
        title, ylabel, xlabel: axis labels.
        param_names: legend labels for the columns (default phi_0, phi_1, ... in LaTeX).
        exp_names: legend labels for the experiments (default 'Exp 1', 'Exp 2', ...).
        colors: one color per column (default: rainbow colormap).
        linestyles: line styles, reused cyclically if there are more experiments.
        ylim: y-limits; by default spans the min/max over all experiments.
        save_to: optional path; saved as PNG, creating parent directories if needed.
    """
    n_experiments = len(all_phis)
    n_params = all_phis[0].shape[1]

    if times is None:
        times = [np.arange(len(phis)) for phis in all_phis]
    elif not isinstance(times, list):
        times = [times] * n_experiments

    fig = plt.figure(figsize=(16, 8))
    gs = GridSpec(1, 2, width_ratios=[3, 1])  # main axes | legend panel
    ax = fig.add_subplot(gs[0])
    leg_ax = fig.add_subplot(gs[1])
    leg_ax.axis('off')

    if ylim is None:
        ylim = np.array([min(phis.min() for phis in all_phis), max(phis.max() for phis in all_phis)])
    xlim = [min(t.min() for t in times), max(t.max() for t in times)]

    if colors is None:
        colors = plt.cm.rainbow(np.linspace(0, 1, n_params))
    if linestyles is None:
        linestyles = ['-', '--', ':', '-.']
    # cycle styles so more experiments than styles does not raise IndexError
    styles = [linestyles[k % len(linestyles)] for k in range(n_experiments)]
    # list() so tuples are accepted and the two label groups can be concatenated
    param_names = list(param_names) if param_names is not None else [f'$\\phi_{{{j}}}$' for j in range(n_params)]
    exp_names = list(exp_names) if exp_names is not None else [f'Exp {k+1}' for k in range(n_experiments)]

    for phis, t, style in zip(all_phis, times, styles):
        for j in range(n_params):
            ax.plot(t, phis[:, j], color=colors[j], linestyle=style)

    ax.set_xlim(xlim)
    ax.set_ylim(ylim)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)

    # One combined legend: colors identify parameters, gray line styles experiments.
    param_handles = [plt.Line2D([0], [0], color=colors[j], lw=2) for j in range(n_params)]
    exp_handles = [plt.Line2D([0], [0], color='gray', linestyle=style, lw=2) for style in styles]
    leg_ax.legend(param_handles + exp_handles, param_names + exp_names, loc='center', title=None)

    fig.tight_layout()

    if save_to is not None:
        p = Path(save_to).with_suffix('.png')
        p.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(p, dpi=300, bbox_inches='tight')

    plt.show()

In [None]:
# we plot the trajectories of phis on the same plot
# phis_all holds the first test mu's latent trajectory from each experiment, in the
# order the cells above appended them.
exp_names = ['original', 'new hyperNN', 'NODE', 'NODE augmented']
assert len(phis_all) == len(exp_names), (
    f'phis_all has {len(phis_all)} entries; re-run the notebook from the top '
    '(several cells append to it)')

times = [time] * len(phis_all)
ys = phis_all

param_names = [f'$\\phi_{i}$' for i in range(phi_dim)]
colors = plt.cm.rainbow(np.linspace(0, 1, phi_dim))  # one distinct color per latent coordinate
linestyles = ['-', '--', ':', '-.']

multi_experiment_trajectory_snapshot(ys, times=times, 
                                     title='Burgers latent dynamics comparison', 
                                     ylabel=r'$\phi(t;\mu)$', 
                                     param_names=param_names,
                                     exp_names=exp_names,
                                     colors=colors,
                                     linestyles=linestyles,
                                     save_to='./img/burgers_test/burgers_comparison_dynamics')