In [5]:
import jax
import jax.numpy as jnp
from jax import jit, random
from jax.random import PRNGKey
from jax.experimental.ode import odeint
from flax import linen as nn
from flax.training import train_state
import optax

# -------------- helper libraries -------------- #
import sys
import os
import time
import pickle
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np

# The project's helper modules (collocation, interpolation, data_generation, ...)
# live in ../utils relative to the notebook's working directory. Register that
# directory on the import path exactly once: the membership test keeps repeated
# runs of this cell from stacking duplicate entries onto sys.path.
collocation2_path = os.path.abspath(
    os.path.join(os.pardir, "utils")
)
if not (collocation2_path in sys.path):
    sys.path.append(collocation2_path)

from collocation import compute_weights, lagrange_derivative
from interpolation import BarycentricInterpolation
from data_generation import generate_ode_data
from non_parametric_collocation import collocate_data
from optimization_pyomo import ODEOptimizationModel as ODEOptimizationModel
from pickle_func import pickle_data, unpickle_data

# -------------- regular neural nets --------------
from neural_net import NeuralODE

In [2]:
# Development helper: re-execute the local helper modules so that edits saved
# on disk since the first cell ran take effect without a kernel restart.
# `from module import X` bindings made in the first cell keep pointing at the
# objects from the old module, so each name is re-bound from the reloaded module.
# NOTE(review): collocation, interpolation, optimization_pyomo and pickle_func
# are not reloaded here; `%autoreload 2` would cover all of them.
import importlib

import neural_net
NeuralODE = importlib.reload(neural_net).NeuralODE

import data_generation
generate_ode_data = importlib.reload(data_generation).generate_ode_data

import non_parametric_collocation
collocate_data = importlib.reload(non_parametric_collocation).collocate_data

In [3]:
#---------------------------------------------DATA PARAMS---------------------------------------------#
N = 200
noise_level = 0.1
ode_type, params = "van_der_pol", {"mu": 1, "omega": 1}
start_time, end_time = 0, 15
spacing_type = "equally_spaced" # "equally_spaced" or "chebyshev"
initial_state = jnp.array([0.0, 1.0])

# Extrapolation test set: covers `test_horizon_factor` times the training
# interval (starting at start_time) with `test_points_factor` times as many points.
test_horizon_factor = 2
test_points_factor = 2
# NOTE(review): "uniform" is not one of the options listed for spacing_type
# above -- confirm how generate_ode_data interprets it (assumed evenly spaced).
test_spacing_type = "uniform"

# Non-parametric collocation smoother settings.
collocation_kernel = "EpanechnikovKernel"
collocation_bandwidth = 0.5

#--------------------------------------------GENERATE DATA--------------------------------------------#
t_vdp, y_vdp, y_noisy_vdp, true_derivatives_vdp = generate_ode_data(N, noise_level, ode_type, params, start_time, end_time, spacing_type, initial_state)

# numpy array is required for pyomo
y_noisy_vdp = np.array(jnp.squeeze(y_noisy_vdp))
t_vdp = np.array(jnp.squeeze(t_vdp))

#-----------------------------------------COLLOCATION MATRIX PREPARATION-------------------------------------------#
weights = compute_weights(t_vdp)
D_vdp = np.array(lagrange_derivative(t_vdp, weights))

#---------------------------------------------------TEST DATA--------------------------------------------#
# Measured from start_time so the horizon stays `test_horizon_factor` x the
# training span even when start_time != 0 (identical to end_time*2 when start_time == 0).
test_end_time = start_time + test_horizon_factor * (end_time - start_time)
t_test_vdp, y_test_vdp, _, _ = generate_ode_data(N * test_points_factor, noise_level, ode_type, params, start_time, test_end_time, test_spacing_type, initial_state)
# Flatten to a 1-D time vector, matching how t_vdp is squeezed above, so that
# index slicing of the test grid yields plain time vectors.
t_test_vdp = jnp.squeeze(t_test_vdp)

#--------------------------------------------NON-PARAMETRIC COLLOCATION--------------------------------------------#
estimated_derivative_vdp, estimated_solution_vdp = collocate_data(y_noisy_vdp, t_vdp, collocation_kernel, bandwidth=collocation_bandwidth)



In [4]:
y_pred_results = {}

# First layer width 3 is presumably 2 state components + time, since
# time_invariant=False -- confirm against neural_net.NeuralODE.
layer_widths = [3, 64, 64, 2]
learning_rate = 1e-3
num_epochs = 1000
rng = PRNGKey(0)

neural_ode_model = NeuralODE(layer_widths=layer_widths, time_invariant=False)
state = neural_ode_model.create_train_state(rng, learning_rate)

# Fit the model to the noisy training trajectory.
# perf_counter is monotonic and high-resolution, unlike time.time(), so it is
# the right clock for measuring elapsed durations.
start_timer = time.perf_counter()
trained_state = neural_ode_model.train(state, t_vdp, y_noisy_vdp, initial_state, num_epochs=num_epochs)
end_timer = time.perf_counter()
timer = end_timer - start_timer
print(f"Time elapsed: {timer}")

# Trajectory on the training grid, integrated from the known initial condition.
y_train_pred = neural_ode_model.neural_ode(trained_state.params, initial_state, t_vdp, trained_state)

#---------------------------------------------SAVE RESULT---------------------------------------------#
# Extrapolate from the last fitted state. neural_ode is called above with
# initial_state paired with t_vdp, so y0 is presumably taken as the state at the
# FIRST time of the grid it receives. The extrapolation grid must therefore
# start exactly at t_vdp[-1]. Slicing the test grid at index N-1 does not
# guarantee that: with 2N evenly spaced points over twice the horizon, index N-1
# falls slightly before end_time.
y1 = y_train_pred[-1, :]
t_train_end = float(t_vdp[-1])
t_test_flat = np.asarray(t_test_vdp).ravel()
# Keep test times strictly after the training end. The isclose guard avoids
# duplicating the anchor when the test grid happens to contain t_train_end.
after_train = (t_test_flat > t_train_end) & ~np.isclose(t_test_flat, t_train_end)
t_test_pred = np.concatenate([[t_train_end], t_test_flat[after_train]])
y_test_pred = neural_ode_model.neural_ode(trained_state.params, y1, t_test_pred, trained_state)
y_pred_results[1] = {"y_train_pred": y_train_pred, "y_test_pred": y_test_pred,
                     "t_test_pred": t_test_pred,
                     "time_elapsed": timer, "y_noisy": y_noisy_vdp}


Epoch 0, Loss: 16.379037209220407
Epoch 100, Loss: 2.881042494026358
Epoch 200, Loss: 2.854533803414634


KeyboardInterrupt: 

In [None]:
# Data generation function
def generate_data(ts, key):
    """Simulate one trajectory of the harmonic oscillator y0' = y1, y1' = -y0.

    The initial state is drawn uniformly from [-0.6, 1) in each component.

    Args:
        ts: 1-D array of increasing times. ts[0] is the time of the initial
            state, and the trajectory is reported at every entry of ts.
        key: JAX PRNG key used to sample the initial state.

    Returns:
        Array of shape (len(ts), 2) holding the state at each time in ts.
    """
    y0 = jax.random.uniform(key, (2,), minval=-0.6, maxval=1)

    # Integrate with jax.experimental.ode.odeint, which the notebook already
    # imports. diffrax is not imported anywhere in the notebook's import cell,
    # so the previous diffrax-based version raised NameError when called.
    # odeint expects the right-hand side as f(y, t), not diffrax's f(t, y, args).
    def f(y, t):
        return jnp.array([y[1], -y[0]])

    return odeint(f, y0, ts)

# Main function to train the model
def main():
    key = random.PRNGKey(0)
    ts = jnp.linspace(0, 10, 100)
    key, subkey = random.split(key)
    observed_data = generate_data(ts, subkey)
    y0 = observed_data[0]

    layer_widths = [2, 64, 64, 2]