In [33]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [34]:
from jax import vmap
import jax.numpy as jnp
from pathlib import Path
import numpy as np
from einops import rearrange
import jax
import matplotlib.pyplot as plt
print(jax.devices())

[cuda(id=0)]


In [35]:
from colora.data import load_all_hdf5, split_data_by_mu, prepare_coordinate_data
# Location of the Burgers snapshots, relative to the notebook's working directory.
data_dir = Path('./data')
data_path = data_dir / 'burgers'
# `spacing` is indexed below as [1] -> time, [2] -> x, [3] -> y.
mus, sols, spacing = load_all_hdf5(data_path)

# Parameter values used for training; the two held-out values are used for testing only.
train_mus = np.asarray([0.001, 0.00199, 0.00298, 0.00496, 0.00595, 0.00694, 0.00892, 0.01])
test_mus = np.asarray([0.00397, 0.00793])
# Solution tensors are laid out as: mus X variables X time X space_x X space_y
train_sols, test_sols = split_data_by_mu(mus, sols, train_mus, test_mus) # mus X variables X time X space_x X space_y
n_mu_train, n_q, n_t, n_x1, n_x2 = train_sols.shape
# The second unpacking rebinds n_q, n_t, n_x1, n_x2 from the test tensor; only
# n_mu_test is new (assumes train and test share the trailing dimensions).
n_mu_test, n_q, n_t, n_x1, n_x2 = test_sols.shape
time = spacing[1]
x_space = spacing[2]
y_space = spacing[3]

print(f'n_mu train: {n_mu_train}, n_mu test: {n_mu_test}, n_variables: {n_q}, n_time samples: {n_t}, n_x samples: {n_x1}, n_x2 samples: {n_x2}')


n_mu train: 8, n_mu test: 2, n_variables: 1, n_time samples: 51, n_x samples: 129, n_x2 samples: 129


In [36]:
# Flatten the gridded solutions into (targets, (mu, t) inputs, spatial grid) for
# training and testing. X_grid is assigned by both calls; the value from the
# second call is the one kept (presumably identical, since `spacing` is shared).
y_train, mu_t_train, X_grid =  prepare_coordinate_data(spacing, train_mus, train_sols)
y_test, mu_t_test, X_grid =  prepare_coordinate_data(spacing, test_mus, test_sols)

In [37]:
def normalize(x, mean, std):
    """Standardize `x` by subtracting `mean` and dividing by `std`.

    Works on scalars and on arrays (with the usual broadcasting rules).
    """
    centered = x - mean
    return centered / std

# Per-column statistics of the training (mu, t) inputs; the same statistics are
# applied to the test inputs so both live in the same normalized space.
# NOTE: this cell rebinds mu_t_train / mu_t_test in place, so re-running it
# without re-running the data-preparation cell normalizes the inputs twice.
mean, std = jnp.mean(mu_t_train, axis=0), jnp.std(mu_t_train, axis=0)
mu_t_train = normalize(mu_t_train, mean, std)
mu_t_test = normalize(mu_t_test, mean, std)


In [38]:
from colora.build import build_colora

# Seed for model initialisation. NOTE(review): the same key is later passed to
# Dataset and to the h' model as well; consider jax.random.split if independent
# randomness is intended.
key = jax.random.PRNGKey(1)

x_dim = 2     # passed to build_colora as the spatial input size (presumably x and y)
mu_t_dim = 2  # passed to build_colora as the hypernetwork input size (presumably mu and t)
u_dim = 1     # passed to build_colora as the output size (presumably one solution variable)

# Extent of the domain in each direction, handed to build_colora as `period`.
d_len_x = x_space[-1] - x_space[0]
d_len_y = y_space[-1] - y_space[0]

"""modify for experiments
"""
u_layers = ['P', 'C', 'C', 'C', 'C', 'C', 'C', 'C'] # seven colora layers with 1 alpha each means we will have latent dim of 7
h_layers = ['D', 'D', 'D']
rank = 3
# made the nn smaller for training purposes
u_hat_config = {'width': 25, 'layers': u_layers}
h_config = {'width': 15, 'layers': h_layers}

# Returns the reduced-model function u_hat_fn, the hypernetwork h_fn, and their
# initial parameters theta_init / psi_init.
u_hat_fn, h_fn, theta_init, psi_init = build_colora(
    u_hat_config, h_config, x_dim, mu_t_dim, u_dim, lora_filter=['alpha'],period=[d_len_x, d_len_y], rank=rank, key=key)

In [39]:

h_v_mu_t = vmap(h_fn, in_axes=(None, 0)) # vmap over mu_t array to generate array of phis
u_hat_v_x =  vmap(u_hat_fn, in_axes=(None, None, 0)) # vmapped over x to generate solution field over space points
u_hat_v_x_phi =  vmap(u_hat_v_x, in_axes=(None, 0, None)) # additionally vmapped over phi: one solution field per latent vector, on a shared set of space points

In [40]:
def predict(psi_theta, mu_t, X_grid):
    """Evaluate the reduced model for a batch of (mu, t) inputs on the spatial grid.

    Args:
        psi_theta: pair ``(psi, theta)`` of hypernetwork and reduced-model parameters.
        mu_t: batch of hypernetwork inputs (mapped over its leading axis).
        X_grid: spatial points shared by every sample.

    Returns:
        The output of ``u_hat_v_x_phi``: one solution field per row of ``mu_t``.
    """
    hyper_params, decoder_params = psi_theta
    latent_codes = h_v_mu_t(hyper_params, mu_t)
    return u_hat_v_x_phi(decoder_params, latent_codes, X_grid)

def relative_loss_fn(psi_theta, mu_t, sols, X_grid):
    """Mean relative error between `sols` and the model prediction.

    For each sample along the leading axis, the norm of the residual over
    axes (1, 2) is divided by the norm of the reference over the same axes;
    the per-sample ratios are then averaged.
    """
    residual = sols - predict(psi_theta, mu_t, X_grid)
    residual_norm = jnp.linalg.norm(residual, axis=(1, 2))
    reference_norm = jnp.linalg.norm(sols, axis=(1, 2))
    return jnp.mean(residual_norm / reference_norm)


In [41]:
from colora.data import Dataset

# The dataset is only responsible for batching the data over the mu_t tensor
# in order to avoid memory overflow on the GPU.
dataset = Dataset(mu_t_train, X_grid, y_train, n_batches=25, key=key)
# Rebinds `dataset` to its iterator; re-running this cell rebuilds the Dataset first.
dataset = iter(dataset)
def args_fn():
    """Return the next batch from the dataset iterator.

    Passed to the optimizer as the source of loss-function arguments.
    NOTE(review): assumes the Dataset iterator never raises StopIteration
    within the requested number of steps -- confirm against colora.data.
    """
    return next(dataset)

from colora.adam import adam_opt

# Joint parameter tuple in the (psi, theta) order that `predict` unpacks.
psi_theta = (psi_init, theta_init)
# Long-running cell (50k Adam steps). The recorded output below shows the run
# was interrupted, so later cells need opt_psi / opt_theta from a completed run.
opt_psi_theta, loss_history = adam_opt(psi_theta, relative_loss_fn, args_fn, steps=50_000, learning_rate=5e-3, verbose=True)
opt_psi, opt_theta = opt_psi_theta


  0%|          | 0/50000 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
# # we first use pickle to save the psi and theta

# # trained 50k steps

# import pickle
# with open('opt_psi.pkl', 'wb') as f:
#     pickle.dump(opt_psi, f)
# with open('opt_theta.pkl', 'wb') as f:
#     pickle.dump(opt_theta, f)


In [None]:
# # load optimal psi and theta
# import pickle
# with open('opt_psi.pkl', 'rb') as f:
#     opt_psi = pickle.load(f)
# with open('opt_theta.pkl', 'rb') as f:
#     opt_theta = pickle.load(f)

# opt_psi_theta=(opt_psi,opt_theta)

In [None]:
# Test relative error of the trained hypernetwork + reduced model.
pred = predict(opt_psi_theta, mu_t_test, X_grid)
# Regroup the flat (mu*t, space, variable) prediction into the layout of test_sols.
pred = rearrange(pred, '(M T) (N1 N2) Q -> M Q T N1 N2', Q=n_q, T=n_t, N1=n_x1, N2=n_x2)
# One long vector per test mu, so the error is measured over the whole trajectory.
test_vec =  rearrange(test_sols, 'M Q T N1 N2 -> M (Q T N1 N2)')
pred_vec =  rearrange(pred, 'M Q T N1 N2 -> M (Q T N1 N2)')
rel_err = np.linalg.norm(test_vec- pred_vec, axis=1)/np.linalg.norm(test_vec, axis=1)
mean_rel_err = rel_err.mean()
print(f'Test mean relative error: {mean_rel_err:.2E}')

In [None]:
from colora.plot import imshow_movie

# Movie of the predicted field for the first test mu and the first variable.
# NOTE(review): `t` has 251 samples on [0, 5] while pred[0][0] holds n_t frames
# along its first axis -- confirm imshow_movie accepts the mismatch, or pass `time`.
imshow_movie(pred[0][0], save_to='./img/burgers_test/burgers.gif', t=jnp.linspace(0,5,251), title='Burgers', tight=True, live_cbar=True, frames=85)

In [None]:
# Latent coordinates phi(t; mu) produced by the trained hypernetwork for every
# test (mu, t) pair, regrouped so that the leading axis indexes the test mu.
phis = h_v_mu_t(opt_psi, mu_t_test)

phis = rearrange(phis, '(M T) D -> M T D', T=n_t)

from colora.plot import trajectory_movie
# One legend entry per latent coordinate. The subscript is wrapped in braces so
# that indices with more than one digit (latent dimension > 10) still render as
# a single subscript instead of "\phi_1" followed by a stray "0".
leg = [rf'$\phi_{{{i}}}$' for i in range(phis.shape[-1])]
# plot latent trajectories for the 1st test mu.
trajectory_movie(phis[0], x=time, title='Burgers', ylabel=r'$\phi(t;\mu)$', legend=leg, save_to='./img/burgers_test/burgers_dynamics', frames=85)

**make prediction of the dynamics beyond the trained time frame**

In [None]:
# The hypernetwork h was trained on *normalized* (mu, t) inputs, so every query
# must go through the same `normalize(..., mean, std)` transform (as
# get_all_phi_0 does). An earlier version of this cell fed raw (mu, t) pairs,
# which evaluates h at shifted parameters/times and is the likely reason the
# movie did not start at the initial condition.

# Extend the time span for prediction.
# 4 * len(time) - 3 == 4 * (len(time) - 1) + 1 samples: four times the horizon
# with the same step as `time` (assuming `time` is uniformly spaced).
extended_time = jnp.linspace(0, 4 * time[-1], 4 * len(time)-3)  # Example: 4 times time span
n_t_extended = len(extended_time)

# Build the (mu, t) query table in mu-major order (all times for the first mu,
# then all times for the second, ...), matching the '(M T)' layout used below.
mu_column = jnp.repeat(jnp.asarray(test_mus), n_t_extended)
t_column = jnp.tile(extended_time, len(test_mus))
extended_mu_t_test = normalize(jnp.column_stack((mu_column, t_column)), mean, std)


# Predict extended time

def predict_extended(psi_theta, mu_t, X_grid):
    """Evaluate the reduced model for (already normalized) `mu_t` rows on `X_grid`.

    Args:
        psi_theta: pair ``(psi, theta)`` of hypernetwork and reduced-model parameters.
        mu_t: normalized (mu, t) inputs, one row per requested snapshot.
        X_grid: spatial points shared by every snapshot.

    Returns:
        The output of ``u_hat_v_x_phi``: one solution field per row of ``mu_t``.
    """
    psi, theta = psi_theta
    phis = h_v_mu_t(psi, mu_t)
    return u_hat_v_x_phi(theta, phis, X_grid)

pred_extended = predict_extended(opt_psi_theta, extended_mu_t_test, X_grid)
pred_extended = rearrange(pred_extended, '(M T) (N1 N2) Q -> M Q T N1 N2', Q=n_q, T=n_t_extended, N1=n_x1, N2=n_x2)

imshow_movie(pred_extended[0][0], frames=201, t=extended_time, save_to='./img/burgers_test/burgers_extended.gif', title='Burgers Extended', tight=True, live_cbar=True)



**Now relearn the dynamics with a neural ODE (NODE), using the initial condition** $\phi_0(\mu)=h(\mu, t=0;\ \psi_{opt})$ **given by the trained hypernetwork**

In [None]:
# Get the initial condition phi(0, mu) from the trained hypernetwork
def get_all_phi_0(psi, mu):
    """Latent initial conditions phi(t=0; mu) from the hypernetwork.

    Args:
        psi: hypernetwork parameters.
        mu: 1-D array of (un-normalized) parameter values.

    Returns:
        The hypernetwork output for each ``(mu_i, 0)`` pair, evaluated after
        applying the same normalization as the training inputs.
    """
    t_zero = jnp.zeros(mu.shape[0])
    # Pair every mu with t = 0, then map into the normalized input space.
    mu_t_initial = normalize(jnp.column_stack((mu, t_zero)), mean, std)
    return h_v_mu_t(psi, mu_t_initial)


# calculate phi0 for train mus
# Latent initial condition for every training mu (one row per mu).
phi_0_mus_train = get_all_phi_0(opt_psi, train_mus)

# Sanity check of the resulting array shape.
print(phi_0_mus_train.shape)


In [None]:
# Raw string: "\p" is not a valid escape sequence, and a plain string literal
# containing it triggers a SyntaxWarning/DeprecationWarning on recent Python.
r'''Proposed approach: \partial_t phi(t,mu ) = g(phi(t,mu), mu ;omega)
CHANGE ACTIVATION OF MLP ?
'''
from colora.NODE import NODE
# NOTE(review): passed as the last NODE argument; confirm whether NODE expects
# an integer seed or a jax.random.PRNGKey.
keygen = 123
phi_dim = 7  #change this if needed
mu_dim = 1
hidden_dim = 20
depth = 4

# build g
g = NODE(phi_dim, mu_dim, hidden_dim, depth, keygen) # g takes in phi and mu and outputs the time derivative of phi

# Record the tree structure: `omega` holds the leaves of g and `omega_def` the
# treedef needed to rebuild g from (possibly updated) leaves.
omega,omega_def = jax.tree_util.tree_flatten(g)


In [None]:
#not flattened pred and loss (see equinox faq)

#@eqx.filter_grad
# import functools as ft
# @ft.partial(jax.jit,static_argnums=1)
import equinox as eqx
def predictNODE(omega_theta, omega_def, t_span, mus, X_grid, phi_0):
    """Predict solution fields by integrating the latent NODE and decoding.

    Args:
        omega_theta: pair ``(omega, theta)``; ``omega`` are the flattened leaves
            of the NODE model and ``theta`` the reduced-model parameters.
        omega_def: treedef used to rebuild the NODE model from ``omega``.
        t_span: time points passed to the NODE model (shared by all mus).
        mus: parameter values, paired element-wise with ``phi_0``.
        X_grid: spatial points shared by every snapshot.
        phi_0: latent initial conditions, one per entry of ``mus``.

    Returns:
        The output of ``u_hat_v_x_phi``: one solution field per (mu, t) pair,
        with the mu index varying slowest.
    """
    omega, theta = omega_theta
    g = jax.tree_util.tree_unflatten(omega_def, omega)
    # Pair up (phi0_i, mu_i); the time grid is shared across the batch.
    g_forall_phi_mu = vmap(g, in_axes=(0, 0, None))

    phis = g_forall_phi_mu(phi_0, mus, t_span)
    # Merge the (mu, time) axes so the rows line up with the flattened targets.
    # The latent size is read from the array itself rather than from the
    # notebook-global `phi_dim`, so a stale global cannot mis-shape the result.
    phis = phis.reshape(-1, phis.shape[-1])

    return u_hat_v_x_phi(theta, phis, X_grid)

def lossNODE(omega_theta, omega_def, t_span, mus, sols, X_grid, phi_0):
    """Mean relative error between `sols` and the NODE-based prediction.

    The residual norm and the reference norm are both taken over axes (1, 2)
    of each sample, and the per-sample ratios are averaged.
    """
    residual = sols - predictNODE(omega_theta, omega_def, t_span, mus, X_grid, phi_0)
    residual_norm = jnp.linalg.norm(residual, axis=(1, 2))
    reference_norm = jnp.linalg.norm(sols, axis=(1, 2))
    return jnp.mean(residual_norm / reference_norm)

# Time grid handed to the NODE model. NOTE(review): hard-coded to 51 points on
# [0, 1]; presumably meant to coincide with `time` -- confirm.
t_span = jnp.linspace(0.0, 1.0, 51)

# Trainable state: NODE leaves plus the reduced-model parameters from the first
# training stage (theta is fine-tuned together with omega).
omega_theta = (omega, opt_theta)

#g_forall_phi_mu(phi_0_mus_train, train_mus, t_span)
# Loss of the untrained NODE on the full training set, as a sanity check.
loss=lossNODE(omega_theta, omega_def, t_span, train_mus, y_train, X_grid, phi_0_mus_train)
print(loss)
# Value-and-gradient of the loss; Equinox differentiates with respect to the
# first positional argument (omega_theta).
eqxgrad_lossNODE = eqx.filter_value_and_grad(lossNODE)


In [None]:
# not flattened training

# Training targets and parameters. `ys` and `mus` are rebound by later cells
# (test evaluation), so re-run this cell before re-running the training loop.
ys = y_train
mus = train_mus

import optax
optimizer=optax.adam(1e-3)
# Optimizer state is created only for the inexact (floating-point) array leaves.
opt_state=optimizer.init(eqx.filter(omega_theta, eqx.is_inexact_array))

@eqx.filter_jit
def update(omega_theta, omega_def, t_span, mus, sols, X_grid, phi_0, opt_state):
    """Perform one optimizer step on the NODE loss.

    Returns:
        Tuple ``(new_omega_theta, new_opt_state, loss)`` where ``loss`` is the
        loss value evaluated at the parameters *before* the step.
    """
    loss_value, grads = eqxgrad_lossNODE(
        omega_theta, omega_def, t_span, mus, sols, X_grid, phi_0
    )
    param_updates, next_opt_state = optimizer.update(grads, opt_state)
    return eqx.apply_updates(omega_theta, param_updates), next_opt_state, loss_value

n_steps=5000
losses=[]
import time as timer

for i in range(n_steps):
    start = timer.perf_counter()
    omega_theta, opt_state, loss = update(omega_theta, omega_def, t_span, mus, ys, X_grid, phi_0_mus_train, opt_state)
    # JAX dispatches work asynchronously: `update` can return before the device
    # has finished. Converting to a Python float blocks until the value is
    # ready, so the elapsed time covers the real computation, and the history
    # stores plain floats instead of keeping device buffers alive.
    loss_value = float(loss)
    elapsed = timer.perf_counter() - start
    losses.append(loss_value)

    if i % 100 == 0:
        # The i == 0 timing also includes JIT compilation of `update`.
        print(f"{i}th iter, loss={loss_value}, time={elapsed}")
plt.plot(losses)
plt.xlabel('iteration')
plt.yscale('log')
plt.ylabel('loss')
plt.title('NODE loss')
plt.savefig('NODE_loss.png')
plt.show()



In [None]:
# Now that omega_theta has been trained, first evaluate the same loss on the test set.
# NOTE: this rebinds the globals `ys` and `mus` that the training loop reads.
ys = y_test
mus = test_mus
# omega_theta holds the trained (omega, theta) at this point.
phi_0_mus_test = get_all_phi_0(opt_psi, test_mus)
loss_on_test=lossNODE(
    omega_theta, omega_def, t_span, mus, ys, X_grid, phi_0_mus_test)

print(loss_on_test)


In [None]:
# Compute the same trajectory-wise relative error as for the hypernetwork model,
# now with the NODE-driven latent dynamics.
pred = predictNODE(omega_theta, omega_def, t_span, test_mus, X_grid, phi_0_mus_test)
# Regroup the flat (mu*t, space, variable) prediction into the layout of test_sols.
pred = rearrange(pred, '(M T) (N1 N2) Q -> M Q T N1 N2', Q=n_q, T=n_t, N1=n_x1, N2=n_x2)

# One long vector per test mu, so the error is measured over the whole trajectory.
test_vec =  rearrange(test_sols, 'M Q T N1 N2 -> M (Q T N1 N2)')
pred_vec =  rearrange(pred, 'M Q T N1 N2 -> M (Q T N1 N2)')
rel_err = np.linalg.norm(test_vec- pred_vec, axis=1)/np.linalg.norm(test_vec, axis=1)
mean_rel_err = rel_err.mean()
print(f'Test mean relative error: {mean_rel_err:.2E}')

**Plotting**

In [None]:
# Plot the prediction given by the trained NODE model: first test mu, first variable.
from colora.plot import imshow_movie

# `t` repeats the values used for t_span in the NODE cells (51 points on [0, 1]).
imshow_movie(pred[0][0], save_to='./img/burgers_test/burgers_NODE.gif', t=jnp.linspace(0,1,51), title='Burgers NODE', tight=True, live_cbar=True, frames=171)


In [None]:
# Latent dynamics phi_i(t) obtained by integrating g with the trained omega.
opt_omega,reopt_theta = omega_theta

phi_0_mus_test = get_all_phi_0(opt_psi, test_mus)

# Rebuild the NODE model from its trained leaves.
g=jax.tree_util.tree_unflatten(omega_def,opt_omega)
g_forall_phi_mu = vmap(g, in_axes=(0, 0, None))

# Trajectory for the first test mu only.
phis = g(phi_0_mus_test[0], test_mus[0], t_span)
print(phis.shape)


from colora.plot import trajectory_movie
# One legend entry per latent coordinate. The subscript is wrapped in braces so
# that indices with more than one digit still render as a single subscript.
leg = [rf'$\phi_{{{i}}}$' for i in range(phis.shape[-1])]
trajectory_movie(phis, x=t_span, title='Burgers NODE', ylabel=r'$\phi(t;\mu)$', legend=leg, save_to='./img/burgers_test/burgers_NODE_dynamics', frames=85)


**Prediction beyond window with NODE**

In [None]:
# Predict over a longer time frame by integrating g over five times the
# training horizon. 5 * (len(time) - 1) + 1 samples keep the same step as
# `time` (assuming it is uniformly spaced), so the first len(time) samples
# coincide with the training time points.
extended_time = jnp.linspace(0, 5 * time[-1], 5 * (len(time) - 1) + 1)
phi_0_mus_test = get_all_phi_0(opt_psi, test_mus)
# The extended grid is passed explicitly instead of rebinding the global
# `t_span`, so re-running the earlier training/evaluation cells still uses the
# training time grid.
pred_extended = predictNODE(omega_theta, omega_def, extended_time, test_mus, X_grid, phi_0_mus_test)
n_t_extended = len(extended_time)
pred_extended = rearrange(pred_extended, '(M T) (N1 N2) Q -> M Q T N1 N2', Q=n_q, T=n_t_extended, N1=n_x1, N2=n_x2)


imshow_movie(pred_extended[0][0], frames=100, t=extended_time, save_to='./img/burgers_test/burgers_NODE_extended.gif', title='Burgers NODE Extended', tight=True, live_cbar=True)



In [None]:
# Latent dynamics phi_i(t) on the extended time grid, obtained by integrating g
# with the trained omega.
opt_omega,reopt_theta = omega_theta
phi_0_mus_test = get_all_phi_0(opt_psi, test_mus)

# Rebuild the NODE model from its trained leaves.
g=jax.tree_util.tree_unflatten(omega_def,opt_omega)
g_forall_phi_mu = vmap(g, in_axes=(0, 0, None))

# Trajectory for the first test mu only.
phis_extended = g(phi_0_mus_test[0], test_mus[0], extended_time)
# Report and size the legend from the array computed in THIS cell; the previous
# version read the stale `phis` left over from an earlier cell.
print(phis_extended.shape)


from colora.plot import trajectory_movie
# Braced subscript so that multi-digit indices render as a single subscript.
leg = [rf'$\phi_{{{i}}}$' for i in range(phis_extended.shape[-1])]
trajectory_movie(phis_extended, x=extended_time, title='Burgers NODE', ylabel=r'$\phi(t;\mu)$', legend=leg, save_to='./img/burgers_test/burgers_NODE_dynamics_extended', frames=85)


**Learn latent grid directly (not done)**

In [None]:
'''
This learns the time derivative of phi by constructing h_prime(mu, t; omega) = dphi/dt
'''
import equinox as eqx
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree
import diffrax
import optax
#import numpy as np

class ODEFunc_NODE(eqx.Module):
    """Vector field h'(mu, t; omega) = dphi/dt, parameterised by an MLP.

    The right-hand side depends on time and on the parameter mu only, as
    described at the top of this cell; the current state is accepted for
    compatibility with an ODE vector-field signature but is not fed to the MLP.
    This keeps the MLP input size equal to ``input_dim`` (built with
    ``mu_t_dim = 2``: one t and one mu).
    """
    mlp: eqx.nn.MLP

    def __init__(self, input_dim, hidden_dim, output_dim, depth, key):
        """Build the MLP.

        Args:
            input_dim: size of the concatenated (t, mu) input.
            hidden_dim: width of every hidden layer.
            output_dim: size of the returned derivative (latent dimension).
            depth: number of hidden layers.
            key: PRNG key used to initialise the MLP weights.
        """
        super().__init__()
        self.mlp = eqx.nn.MLP(input_dim, output_dim, hidden_dim, depth, key=key)

    def __call__(self, t, y, mu):
        """Return dphi/dt at time `t` for parameter `mu` (`y` is unused)."""
        # atleast_1d lets both t and mu be scalars (as they are when the model is
        # vmapped over a 1-D array of mus) or 1-D vectors.
        features = jnp.concatenate([jnp.atleast_1d(t), jnp.atleast_1d(mu)])
        return self.mlp(features)

class NeuralODE_h_prime(eqx.Module):
    """Latent trajectory model: integrates ``ODEFunc_NODE`` from an initial state."""
    ode_func: ODEFunc_NODE

    def __init__(self, mu_t_dim, hidden_dim, phi_dim, depth, key):
        """Create the vector field.

        Args:
            mu_t_dim: input size of the vector-field MLP.
            hidden_dim: width of the MLP hidden layers.
            phi_dim: latent dimension (size of the state and of its derivative).
            depth: number of MLP hidden layers.
            key: PRNG key used to initialise the MLP.
        """
        super().__init__()
        # The class defined in this cell is ODEFunc_NODE; the former reference
        # to `ODEFunc` raised a NameError at construction time.
        self.ode_func = ODEFunc_NODE(mu_t_dim, hidden_dim, phi_dim, depth, key)

    def __call__(self, y0, mu, t_span):
        """Integrate from ``y0`` and return the states saved at ``t_span``.

        Args:
            y0: initial latent state.
            mu: parameter value forwarded to the vector field.
            t_span: increasing time points; integration runs from the first to
                the last entry and the solution is saved at every entry.

        Returns:
            ``sol.ys``: the saved states, one per entry of ``t_span``.
        """
        def func(t, y, args):
            # `mu` is captured by closure, so the solver's `args` is unused.
            return self.ode_func(t, y, mu)

        solver = diffrax.Tsit5()
        saveat = diffrax.SaveAt(ts=t_span)
        sol = diffrax.diffeqsolve(
            diffrax.ODETerm(func),
            solver,
            t0=t_span[0],
            t1=t_span[-1],
            dt0=t_span[1] - t_span[0],
            y0=y0,
            saveat=saveat
        )
        return sol.ys

# Hyper-parameters of the h' model.
# NOTE: mu_t_dim, hidden_dim and depth rebind globals that earlier cells also
# define (model-building and NODE cells); re-running those cells after this one
# picks up whichever value was assigned last.
n_phi = 7
mu_t_dim=2 # 1 mu and 1 t
hidden_dim=10
output_dim=n_phi
depth=3

# NOTE(review): `key` is the same PRNG key already used to build the CoLoRA model.
hp=NeuralODE_h_prime(mu_t_dim, hidden_dim, output_dim, depth, key)

# Leaves (trainable arrays) and treedef needed to rebuild the model from them.
params_hp, hp_def = jax.tree_util.tree_flatten(hp)


In [None]:
# Regression targets for the h' model: latent vectors produced by the trained
# hypernetwork for every (normalized) training / test (mu, t) input.
phis_train=h_v_mu_t(opt_psi, mu_t_train)
phis_test=h_v_mu_t(opt_psi, mu_t_test)


In [None]:
def predict_hp(params_hp, hp_def, t_span,mus, phi_0):
    """Latent trajectories predicted by the h' model for every (mu, t) pair.

    Args:
        params_hp: flattened leaves of the model.
        hp_def: treedef used to rebuild the model from ``params_hp``.
        t_span: time points shared by all mus.
        mus: parameter values, paired element-wise with ``phi_0``.
        phi_0: latent initial conditions, one per entry of ``mus``.

    Returns:
        2-D array with one row per (mu, t) pair (mu index varying slowest) and
        one column per latent coordinate.
    """
    hp = jax.tree_util.tree_unflatten(hp_def, params_hp)

    # Pair up (phi0_i, mu_i); the time grid is shared across the batch.
    hp_forall_phi_mu = vmap(hp, in_axes=(0, 0, None))

    phis = hp_forall_phi_mu(phi_0, mus, t_span)
    # Merge the (mu, time) axes so the rows line up with the flattened targets.
    # The latent size comes from the array itself instead of the global `n_phi`.
    return phis.reshape(-1, phis.shape[-1])
def loss_hp(params_hp,hp_def, t_span, mus, phi_target, phi_0):
    """Mean relative error between target and predicted latent vectors.

    Args:
        params_hp: flattened leaves of the h' model.
        hp_def: treedef used to rebuild the model.
        t_span: time points shared by all mus.
        mus: parameter values, paired element-wise with ``phi_0``.
        phi_target: 2-D array of target latent vectors, one row per (mu, t)
            pair in the same order as the output of ``predict_hp``.
        phi_0: latent initial conditions, one per entry of ``mus``.

    Returns:
        Scalar: average over rows of ||target - pred|| / ||target||.
    """
    pred = predict_hp(params_hp, hp_def, t_span, mus, phi_0)
    # Both arrays are 2-D (rows = (mu, t) samples), so the norm is taken over
    # the last axis only; the previous axis=(1, 2) is invalid for 2-D input.
    residual_norm = jnp.linalg.norm(phi_target - pred, axis=-1)
    # Normalise by the target passed in, not by an unrelated notebook global.
    target_norm = jnp.linalg.norm(phi_target, axis=-1)
    return (residual_norm / target_norm).mean()


# Time grid for the h' model; rebinds the global `t_span` (same values as the
# grid used for NODE training).
t_span = jnp.linspace(0.0, 1.0, 51)

# Targets: latent vectors from the trained hypernetwork on the training inputs.
phi_target=phis_train

mus = train_mus

# Loss of the untrained h' model, as a sanity check.
loss=loss_hp(params_hp, hp_def, t_span, mus, phi_target, phi_0_mus_train)
