In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from jax import vmap
import jax.numpy as jnp
from pathlib import Path
import numpy as np
from einops import rearrange
import jax
import matplotlib.pyplot as plt

In [3]:
from colora.data import load_all_hdf5, split_data_by_mu, prepare_coordinate_data
# Load the parameter values, solutions and grid spacing for the chosen problem
# from ./data/<problem>.
problem = 'rde'
data_dir = Path('./data')
data_path = data_dir / problem
mus, sols, spacing = load_all_hdf5(data_path)

# Two parameter values are held out for testing; the remaining ten are used for training.
train_mus = np.asarray([2.0, 2.1, 2.2, 2.4, 2.5, 2.6, 2.8, 2.9, 3.0,3.1])
test_mus = np.asarray([2.3, 2.7])
train_sols, test_sols = split_data_by_mu(mus, sols, train_mus, test_mus) # mus X variables X time X space_x (4 axes: the shapes below unpack into exactly four sizes)
# The second unpacking rebinds n_q, n_t and n_x1 from the test split.
n_mu_train, n_q, n_t, n_x1 = train_sols.shape
n_mu_test, n_q, n_t, n_x1 = test_sols.shape
# spacing[1] is used below as the time axis and spacing[2] as the spatial axis
# (presumably spacing follows the variables/time/space axis order of sols; confirm in colora.data).
time = spacing[1]
x_space = spacing[2]

print(f'n_mu train: {n_mu_train}, n_mu test: {n_mu_test}, n_variables: {n_q}, n_time samples: {n_t}, n_x samples: {n_x1}')


n_mu train: 10, n_mu test: 2, n_variables: 2, n_time samples: 101, n_x samples: 513


In [4]:
# Build the training and test targets together with their (mu, t) inputs.
# NOTE: X_grid is assigned by both calls; the value from the test call is the one kept.
# Presumably each row of mu_t_* is one (mu, t) pair and y_* holds the matching
# solution snapshot; confirm against colora.data.prepare_coordinate_data.
y_train, mu_t_train, X_grid =  prepare_coordinate_data(spacing, train_mus, train_sols)
y_test, mu_t_test, X_grid =  prepare_coordinate_data(spacing, test_mus, test_sols)

In [5]:
def normalize(x, mean, std):
    """Standardise ``x`` by subtracting ``mean`` and dividing by ``std``.

    The arithmetic is applied as-is, so the usual broadcasting rules of the
    array types passed in decide the shape of the result. No guard against
    ``std == 0`` is applied.
    """
    centered = x - mean
    return centered / std

# Column-wise statistics are taken from the training inputs only and reused for the test inputs.
mean, std = jnp.mean(mu_t_train, axis=0), jnp.std(mu_t_train, axis=0)
# NOTE(review): both arrays are rebound in place, so re-running this cell normalises
# already-normalised data. Re-run the cell that creates mu_t_train / mu_t_test first.
mu_t_train = normalize(mu_t_train, mean, std)
mu_t_test = normalize(mu_t_test, mean, std)


In [6]:
from colora.build import build_colora

# PRNG key handed to build_colora; the same key object is reused later for the Dataset.
key = jax.random.PRNGKey(1)

# Length of the spatial domain, passed to build_colora as the period.
domain_len = x_space[-1] - x_space[0]

x_dim = 1 # 1 spatial dim
mu_t_dim = 2
u_dim = 2 # we have two variables here

u_layers = ['P', 'C', 'C', 'C', 'C', 'C', 'C', 'C'] # one 'P' layer followed by seven 'C' (CoLoRA) layers; with 1 alpha each this would give a latent dim of 7
h_layers = ['D', 'D', 'D']
rank = 3
full = True # here we allow 3 alphas per colora layer, resulting in a larger latent dimension
# NOTE(review): the two comments above disagree about the latent dimension (7 vs. larger
# with full=True); the actual value is the last axis of the hypernetwork output. Confirm.

u_hat_config = {'width': 25, 'layers': u_layers}
h_config = {'width': 15, 'layers': h_layers}

# Returns the decoder u_hat_fn, the hypernetwork h_fn and their initial parameters.
u_hat_fn, h_fn, theta_init, psi_init = build_colora(
    u_hat_config, h_config, x_dim, mu_t_dim, u_dim, lora_filter=['alpha'], period=[domain_len], rank=rank, full=full, key=key)

In [7]:

h_v_mu_t = vmap(h_fn, in_axes=(None, 0)) # vmap over mu_t array to generate array of phis
u_hat_v_x =  vmap(u_hat_fn, in_axes=(None, None, 0)) # vmapped over x to generate the solution field over space points
u_hat_v_x_phi =  vmap(u_hat_v_x, in_axes=(None, 0, None)) # additionally vmapped over phi (second argument): one solution field per latent vector

In [8]:
def predict(psi_theta, mu_t, X_grid):
    """Evaluate the reduced model for every row of ``mu_t`` on ``X_grid``.

    Args:
        psi_theta: pair ``(psi, theta)`` holding the hypernetwork parameters
            and the decoder parameters, in that order.
        mu_t: batch of hypernetwork inputs, mapped over its leading axis.
        X_grid: spatial points, mapped over its leading axis by the decoder.

    Returns:
        The decoder output with one solution field per row of ``mu_t``.
    """
    hyper_params, decoder_params = psi_theta
    latents = h_v_mu_t(hyper_params, mu_t)
    return u_hat_v_x_phi(decoder_params, latents, X_grid)

def relative_loss_fn(psi_theta, mu_t, sols, X_grid):
    """Mean relative error of the reduced model against ``sols``.

    For each entry along the leading axis the norm of the residual over axes
    (1, 2) is divided by the norm of the reference over the same axes; the
    ratios are then averaged.
    """
    residual = sols - predict(psi_theta, mu_t, X_grid)
    error_norms = jnp.linalg.norm(residual, axis=(1, 2))
    reference_norms = jnp.linalg.norm(sols, axis=(1, 2))
    return jnp.mean(error_norms / reference_norms)



In [9]:
from colora.data import Dataset

# the dataset is just responsible for batching the data over the mu_t tensor
# in order to avoid memory overflow on the GPU
dataset = Dataset(mu_t_train, X_grid, y_train, n_batches=15, key=key)
# The name is rebound to an iterator so that args_fn can pull batches with next().
dataset = iter(dataset)
def args_fn():
    """Return the next batch from the dataset iterator (passed to adam_opt as the argument source)."""
    return next(dataset)

In [10]:
from colora.adam import adam_opt

# Hypernetwork (psi) and decoder (theta) parameters are optimised jointly as one tuple.
psi_theta = (psi_init, theta_init)
# Only 100 Adam steps are run here; the recorded output below shows a loss of about 8.3e-2.
opt_psi_theta, loss_history = adam_opt(psi_theta, relative_loss_fn, args_fn, steps=100, learning_rate=5e-3, verbose=True)
opt_psi, opt_theta = opt_psi_theta


  from .autonotebook import tqdm as notebook_tqdm
100%|██████████| 100/100 [00:05<00:00, 17.90it/s, loss=8.305E-02]


In [13]:

# Predict on the held-out (mu, t) inputs, then reshape to the layout of test_sols.
pred = predict(opt_psi_theta, mu_t_test, X_grid)
print(pred.shape)
# The leading axis is split into (mu, time) and the variable axis is moved forward:
# (M*T, N1, Q) -> (M, Q, T, N1).
pred = rearrange(pred, '(M T) (N1) Q -> M Q T N1', Q=n_q, T=n_t, N1=n_x1)
print(pred.shape)


(202, 513, 2)
(2, 2, 101, 513)


In [14]:
# Flatten every test trajectory into a single vector per parameter value so the
# relative error is measured over variables, time and space at once.
flatten_per_mu = 'M Q T N1 -> M (Q T N1)'
test_vec = rearrange(test_sols, flatten_per_mu)
pred_vec = rearrange(pred, flatten_per_mu)
print(test_vec.shape, pred_vec.shape)

# Per-parameter relative L2 error, then its mean over the test parameters.
error_norm = np.linalg.norm(test_vec - pred_vec, axis=1)
reference_norm = np.linalg.norm(test_vec, axis=1)
rel_err = error_norm / reference_norm
mean_rel_err = rel_err.mean()
print(f'Test mean relative error: {mean_rel_err:.2E}')

(2, 103626) (2, 103626)
Test mean relative error: 8.88E-02


In [None]:
from colora.plot import line_movie

# Animate the first variable of the first test parameter: prediction vs. reference.
line_movie([pred[0][0], test_sols[0][0]], save_to='./img/rde.gif', t=spacing[1], title='RDE', frames=100, x=x_space, legend=['CoLoRA', 'True'])

In [None]:
# Latent vectors produced by the trained hypernetwork on the test inputs,
# regrouped so that axis 0 indexes the parameter value and axis 1 the time sample.
phis = h_v_mu_t(opt_psi, mu_t_test)
phis = rearrange(phis, '(M T) D -> M T D', T=n_t)

In [None]:
from colora.plot import trajectory_movie

# Animate the latent trajectory of the first test parameter over time.
trajectory_movie(phis[0], x=time, title='RDE', ylabel=r'$\phi(t;\mu)$', save_to='./img/rde_dynamics', frames=100)

**NeuralODE part**

In [None]:
# Initial latent state phi(0, mu) for each parameter value, from the trained hypernetwork
def get_all_phi_0(psi, mu):
    """Evaluate the hypernetwork at t = 0 for every entry of ``mu``.

    Each parameter value is paired with a zero time coordinate, the pairs are
    standardised with the module-level ``mean`` / ``std`` and then passed
    through ``h_v_mu_t`` with parameters ``psi``.
    """
    t_zero = jnp.zeros(mu.shape[0])
    mu_t0 = normalize(jnp.column_stack((mu, t_zero)), mean, std)
    return h_v_mu_t(psi, mu_t0)


# Initial latent states for the training parameters; axis 0 indexes the parameter
# value (the hypernetwork is vmapped over the rows built from train_mus).
phi_0_mus_train = get_all_phi_0(opt_psi, train_mus)

print(phi_0_mus_train.shape)

# Proposed approach: learn the latent dynamics with a neural ODE,
#     \partial_t \phi(t, \mu) = g(\phi(t, \mu), \mu; \omega)
# TODO: change the activation of the MLP?

import equinox as eqx
import jax
import jax.numpy as jnp
from jax.random import PRNGKey
import diffrax

# Define the ODE function approximator using an MLP
class ODEFunc(eqx.Module):
    """Right-hand side of the latent ODE, parameterised by an MLP.

    The MLP maps the concatenated vector ``[phi, mu]`` to the time derivative
    of ``phi``.

    Args:
        input_dim: size of the concatenated ``[phi, mu]`` input.
        hidden_dim: width of every hidden layer.
        output_dim: size of the returned derivative.
        depth: number of hidden layers passed to ``eqx.nn.MLP``.
        key: PRNG key used to initialise the MLP weights.
    """
    mlp: eqx.nn.MLP

    def __init__(self, input_dim, hidden_dim, output_dim, depth, key):
        super().__init__()
        self.mlp = eqx.nn.MLP(input_dim, output_dim, hidden_dim, depth, activation=jax.nn.relu, key=key)

    def __call__(self, t, phi_mu):
        """Return d(phi)/dt for the stacked input ``phi_mu``.

        ``t`` is accepted to match the ODE-term signature but is not used, so
        the learned vector field is autonomous.
        """
        # No printing or host transfer here: this function is traced by the
        # ODE solver, so such calls would run at trace time only and show
        # tracers rather than values.
        return self.mlp(phi_mu)

# Define the Neural ODE class

class NeuralODE(eqx.Module):
    """Latent dynamics model: integrates d(phi)/dt = g(phi, mu) with Tsit5.

    Args:
        phi_dim: size of the latent state.
        mu_dim: number of parameter entries appended to the state.
        hidden_dim: width of the hidden layers of the vector-field MLP.
        depth: number of hidden layers of the vector-field MLP.
        key: PRNG key used to initialise the MLP.
    """
    ode_func: ODEFunc

    def __init__(self, phi_dim, mu_dim, hidden_dim, depth, key):
        super().__init__()
        # The vector field consumes [phi, mu] and returns a vector of size phi_dim.
        self.ode_func = ODEFunc(phi_dim + mu_dim, hidden_dim, phi_dim, depth, key)

    def __call__(self, phi0, mu, t_span):
        """Integrate from ``phi0`` and return the state at every entry of ``t_span``.

        Args:
            phi0: initial latent state (1-D).
            mu: parameter value; a scalar or a 1-D array of length ``mu_dim``.
            t_span: 1-D array of save times with at least two entries; the
                first spacing is used as the initial step size.

        Returns:
            ``sol.ys`` from diffrax: the state saved at each time in ``t_span``.
        """
        # atleast_1d accepts both a scalar mu and a vector mu, whereas wrapping
        # in a list only produced the right shape for scalars.
        mu = jnp.atleast_1d(jnp.asarray(mu))

        def vector_field(t, phi, args):
            # mu reaches the vector field through diffrax's `args`.
            phi_mu = jnp.concatenate((jnp.asarray(phi), args), axis=0)
            return self.ode_func(t, phi_mu)

        sol = diffrax.diffeqsolve(
            diffrax.ODETerm(vector_field),
            diffrax.Tsit5(),
            t0=t_span[0],
            t1=t_span[-1],
            dt0=t_span[1] - t_span[0],
            y0=phi0,
            args=mu,
            saveat=diffrax.SaveAt(ts=t_span),
        )
        return sol.ys

# Build the latent-dynamics model g and record its pytree structure.
key = PRNGKey(13131)  # NOTE: rebinds the `key` used earlier for build_colora / Dataset
# The latent size is taken from the hypernetwork output instead of being
# hard-coded, so g always matches the initial conditions it is integrated from.
phi_dim = phi_0_mus_train.shape[-1]
mu_dim = 1
hidden_dim = 10
depth = 3

# g takes in phi and mu and outputs the time derivative of phi
g = NeuralODE(phi_dim, mu_dim, hidden_dim, depth, key)

# omega: flat list of the leaves of g; omega_def: the tree structure needed to rebuild g
omega, omega_def = jax.tree_util.tree_flatten(g)

# Show array shapes (and any non-array leaves) rather than dumping every weight.
print([getattr(leaf, 'shape', leaf) for leaf in omega])
# Prediction and Loss

#@eqx.filter_grad
# import functools as ft
# @ft.partial(jax.jit,static_argnums=1)

def predictNODE(omega_flat_theta, g_minus_omega, omega_def, t_span, mus, X_grid, phi_0):
    """Decode the NODE-integrated latent trajectories ("flattened" variant).

    NOTE: a second ``predictNODE`` with a different signature is defined
    further down in this notebook and replaces this one once it has run.

    Args:
        omega_flat_theta: pair ``(omega_flat, theta)``: the flat leaves of
            the dynamics model and the decoder parameters.
        g_minus_omega: the part of the dynamics model that is combined with
            the rebuilt leaves via ``eqx.combine``.
        omega_def: tree definition used to unflatten ``omega_flat``.
        t_span: save times shared by all trajectories.
        mus: parameter values, mapped over the leading axis.
        X_grid: spatial points passed to the decoder.
        phi_0: initial latent states, mapped over the leading axis together
            with ``mus``.

    Returns:
        Decoder output with one solution field per (mu, time) pair.
    """
    omega_flat, theta = omega_flat_theta
    omega_tree = jax.tree_util.tree_unflatten(omega_def, omega_flat)
    g = eqx.combine(g_minus_omega, omega_tree)

    # One trajectory per (phi0_i, mu_i) pair, all on the same time grid.
    phis = vmap(g, in_axes=(0, 0, None))(phi_0, mus, t_span)
    # Merge the (mu, time) axes. The latent size is read from the array itself
    # rather than from a module-level constant.
    phis = phis.reshape(-1, phis.shape[-1])

    return u_hat_v_x_phi(theta, phis, X_grid)

def lossNODE(omega_flat_theta, g_minus_omega, omega_def, t_span, mus, sols, X_grid, phi_0):
    """Mean relative error of the NODE prediction against ``sols`` ("flattened" variant).

    The prediction comes from ``predictNODE``; for each entry along the
    leading axis the residual norm over axes (1, 2) is divided by the
    reference norm over the same axes, and the ratios are averaged.
    """
    residual = sols - predictNODE(
        omega_flat_theta, g_minus_omega, omega_def, t_span, mus, X_grid, phi_0)
    error_norms = jnp.linalg.norm(residual, axis=(1, 2))
    reference_norms = jnp.linalg.norm(sols, axis=(1, 2))
    return jnp.mean(error_norms / reference_norms)

# One save time per data snapshot, so the decoded trajectory has as many rows per
# mu as the targets (n_t). Time is rescaled to [0, 1]; this assumes the data
# time grid is uniform.
t_span = jnp.linspace(0.0, 1.0, n_t)

# Pieces needed by the flattened variant: the flat leaves of g and the
# non-differentiable remainder that eqx.combine merges back in.
omega_flat = omega
g_minus_omega = eqx.filter(g, eqx.is_inexact_array, inverse=True)
omega_flat_theta = (omega_flat, opt_theta)

ys = y_train
mus = train_mus
# Smoke test on the first training parameter only: its targets are the first n_t
# rows of ys (rows are assumed ordered parameter-major, time-minor, as in the
# '(M T) ...' rearrange patterns used elsewhere in this notebook).
loss = lossNODE(omega_flat_theta, g_minus_omega, omega_def, t_span, mus[0:1], ys[:n_t], X_grid, phi_0_mus_train[0:1])
print(loss)
eqxgrad_lossNODE = eqx.filter_value_and_grad(lossNODE)

#not flattened pred and loss (see equinox faq)

#@eqx.filter_grad
# import functools as ft
# @ft.partial(jax.jit,static_argnums=1)

def predictNODE(omega_theta, omega_def, t_span, mus, X_grid, phi_0):
    """Decode the NODE-integrated latent trajectories ("not flattened" variant).

    Args:
        omega_theta: pair ``(omega, theta)``: the leaves of the dynamics
            model and the decoder parameters.
        omega_def: tree definition used to rebuild the dynamics model from
            ``omega``.
        t_span: save times shared by all trajectories.
        mus: parameter values, mapped over the leading axis.
        X_grid: spatial points passed to the decoder.
        phi_0: initial latent states, mapped over the leading axis together
            with ``mus``.

    Returns:
        Decoder output with one solution field per (mu, time) pair.
    """
    omega, theta = omega_theta
    g = jax.tree_util.tree_unflatten(omega_def, omega)

    # One trajectory per (phi0_i, mu_i) pair, all on the same time grid.
    phis = vmap(g, in_axes=(0, 0, None))(phi_0, mus, t_span)
    # Merge the (mu, time) axes. The latent size is read from the array itself
    # rather than from a module-level constant.
    phis = phis.reshape(-1, phis.shape[-1])

    return u_hat_v_x_phi(theta, phis, X_grid)

def lossNODE(omega_theta, omega_def, t_span, mus, sols, X_grid, phi_0):
    """Mean relative error of the NODE prediction against ``sols`` ("not flattened" variant).

    The prediction comes from ``predictNODE``; for each entry along the
    leading axis the residual norm over axes (1, 2) is divided by the
    reference norm over the same axes, and the ratios are averaged.
    """
    residual = sols - predictNODE(omega_theta, omega_def, t_span, mus, X_grid, phi_0)
    error_norms = jnp.linalg.norm(residual, axis=(1, 2))
    reference_norms = jnp.linalg.norm(sols, axis=(1, 2))
    return jnp.mean(error_norms / reference_norms)

# One save time per data snapshot: the prediction then has n_mu * n_t rows, the
# same as y_train. Time is rescaled to [0, 1]; this assumes a uniform data grid.
t_span = jnp.linspace(0.0, 1.0, n_t)

omega_theta = (omega, opt_theta)

# Loss of the untrained dynamics model on the full training set.
loss = lossNODE(omega_theta, omega_def, t_span, train_mus, y_train, X_grid, phi_0_mus_train)
print(loss)
eqxgrad_lossNODE = eqx.filter_value_and_grad(lossNODE)

# Time one value-and-gradient evaluation of the NODE loss on the whole training set.
# The comparison against plain jax.value_and_grad was already disabled in the
# original ("problematic here"), so only the equinox version is timed.
eqxgrad_lossNODE = eqx.filter_value_and_grad(lossNODE)

import time as timer

# One save time per data snapshot so prediction and targets have matching rows.
t_span = jnp.linspace(0.0, 1.0, n_t)

start = timer.time()
eqx_loss, eqx_grad = eqxgrad_lossNODE(omega_theta, omega_def, t_span, train_mus, y_train, X_grid, phi_0_mus_train)
# JAX dispatches asynchronously: wait for the result before stopping the clock.
eqx_loss.block_until_ready()
end = timer.time()

print(eqx_loss)
print(end - start)
# Training loop for the "flattened" variant.
# NOTE(review): as written this cell cannot run in notebook order:
#   * omega_flat_theta and g_minus_omega are not assigned anywhere in the visible source;
#   * eqxgrad_lossNODE was last rebuilt from the 7-parameter lossNODE, so the
#     8-argument call inside update() does not match its signature;
#   * update() is defined again further down with a different signature.
# It looks superseded by the "not flattened training" cell below; confirm and remove.
ys = y_train
mus = train_mus
import optax
optimizer=optax.adam(1e-3)
opt_state=optimizer.init(omega_flat_theta)
#@eqx.filter_jit
def update(omega_flat_theta, g_minus_omega, omega_def, t_span, mus, sols, X_grid, phi_0, opt_state):
    """Run one Adam step on (omega_flat, theta) and return the new parameters, optimiser state and loss."""
    loss,grad = eqxgrad_lossNODE(omega_flat_theta, g_minus_omega, omega_def, t_span, mus, sols, X_grid, phi_0)

    updates, new_opt_state=optimizer.update(grad,opt_state)
    
    new_omega_flat_theta=optax.apply_updates(omega_flat_theta,updates)

    return new_omega_flat_theta, new_opt_state, loss

n_steps=2
losses=[]
import time as timer
for i in range(n_steps):
    start=timer.time()
    omega_flat_theta, opt_state, loss=update(omega_flat_theta, g_minus_omega, omega_def, t_span, mus, ys, X_grid, phi_0_mus_train, opt_state)
    losses.append(loss)
    end=timer.time()
    print(f"{i}th iter, loss={loss}, time={end-start}")
print(losses)


# Training of the "not flattened" variant: Adam on (omega, theta).
import optax
import time as timer

ys = y_train
mus = train_mus

# One save time per data snapshot so prediction and targets have matching rows
# (time rescaled to [0, 1]; assumes a uniform data time grid).
t_span = jnp.linspace(0.0, 1.0, n_t)

optimizer = optax.adam(1e-3)
# Only inexact-array leaves get optimiser state; all other leaves are filtered out.
opt_state = optimizer.init(eqx.filter(omega_theta, eqx.is_inexact_array))

@eqx.filter_jit
def update(omega_theta, omega_def, t_span, mus, sols, X_grid, phi_0, opt_state):
    """Run one Adam step on ``omega_theta``.

    Returns:
        ``(new_omega_theta, new_opt_state, loss)`` where ``loss`` is the value
        before the step.
    """
    loss, grad = eqxgrad_lossNODE(omega_theta, omega_def, t_span, mus, sols, X_grid, phi_0)
    updates, new_opt_state = optimizer.update(grad, opt_state)
    new_omega_theta = eqx.apply_updates(omega_theta, updates)
    return new_omega_theta, new_opt_state, loss

n_steps = 1
losses = []

for i in range(n_steps):
    start = timer.time()
    omega_theta, opt_state, loss = update(omega_theta, omega_def, t_span, mus, ys, X_grid, phi_0_mus_train, opt_state)
    # JAX dispatches asynchronously: wait for the step to finish before timing it.
    loss.block_until_ready()
    losses.append(loss)
    end = timer.time()
    print(f"{i}th iter, loss={loss}, time={end-start}")
print(losses)


# Evaluate the trained (omega, theta) on the held-out parameters.
ys = y_test
mus = test_mus
# One save time per data snapshot so prediction and targets have matching rows.
t_span = jnp.linspace(0.0, 1.0, n_t)

phi_0_mus_test = get_all_phi_0(opt_psi, test_mus)
loss_on_test = lossNODE(
    omega_theta, omega_def, t_span, mus, ys, X_grid, phi_0_mus_test)

print(loss_on_test)

# Same relative error as for the hypernetwork model above. This problem has a
# single spatial axis, so the 4-D patterns from that cell are used (no N2 axis).
pred = predictNODE(omega_theta, omega_def, t_span, test_mus, X_grid, phi_0_mus_test)
pred = rearrange(pred, '(M T) (N1) Q -> M Q T N1', Q=n_q, T=n_t, N1=n_x1)

test_vec = rearrange(test_sols, 'M Q T N1 -> M (Q T N1)')
pred_vec = rearrange(pred, 'M Q T N1 -> M (Q T N1)')
rel_err = np.linalg.norm(test_vec - pred_vec, axis=1) / np.linalg.norm(test_vec, axis=1)
mean_rel_err = rel_err.mean()
print(f'Test mean relative error: {mean_rel_err:.2E}')
**Plotting**
# Plot the prediction given by the trained NODE model. The data are 1-D in
# space, so the same line animation as for the hypernetwork model is used
# (first variable of the first test parameter, prediction vs. reference).
# `pred` is expected in the (M, Q, T, N1) layout.
from colora.plot import line_movie, trajectory_movie

line_movie([pred[0][0], test_sols[0][0]], save_to='./img/rde_node.gif', t=spacing[1], title='RDE (NeuralODE)', frames=100, x=x_space, legend=['CoLoRA-NODE', 'True'])
print(phi_0_mus_test.shape)

# and the dynamics of phi_i's by integrating g with the trained omega
opt_omega, reopt_theta = omega_theta
g = jax.tree_util.tree_unflatten(omega_def, opt_omega)
g_forall_phi_mu = vmap(g, in_axes=(0, 0, None))

# Latent trajectory of the first test parameter on the current time grid.
phis = g(phi_0_mus_test[0], test_mus[0], t_span)
print(phis.shape)

# Braces around the index keep multi-digit subscripts intact in mathtext.
leg = [rf'$\phi_{{{i}}}$' for i in range(phis.shape[-1])]
trajectory_movie(phis, x=t_span, title='RDE (NeuralODE)', ylabel=r'$\phi(t;\mu)$', legend=leg, save_to='./img/rde_node_dynamics', frames=100)
