# Neural ODE

System identification on a possibly sparse dataset of autonomous ODEs

In [None]:
import time
from tqdm import tqdm
import diffrax
import jax
import jax.numpy as jnp
import jax.random as jr
import matplotlib.pyplot as plt
#from interpolation import ZOHInterpolation as Interpolation
from diffrax import LinearInterpolation as Interpolation


In [None]:
# Use matplotlib's interactive widget backend for the figures in this notebook.
%matplotlib widget

In [None]:
# Fixed seed so that the randomly generated data are reproducible.
seed = 1234
# Root PRNG key derived from the seed; all later randomness is split from it.
key = jr.PRNGKey(seed)

In [None]:
def generate_batch(key, seq_len=1_000, batch_size=32, K=10):
    """Simulate a meta-batch of point-mass systems driven by random inputs.

    ``batch_size`` systems are drawn by scaling each nominal parameter
    ``(M, b) = (1, 0.1)`` by a uniform factor in ``[0.9, 1.1]``. Every system
    is simulated ``K`` times, each time from a random initial state and under
    a random input sequence (both uniform in ``[-1, 1]``).

    Args:
        key: JAX PRNG key; split internally into independent keys for the
            initial states, the inputs and the system parameters.
        seq_len: Number of samples per trajectory on a grid with sampling
            time 0.1. Must be at least 2, since the step size is read from
            the first two grid points.
        batch_size: Number of distinct systems in the batch.
        K: Number of simulated repetitions per system.

    Returns:
        Tuple ``(batch_x, batch_u, params)`` where ``batch_x`` holds the
        simulated states, shape ``(batch_size, K, seq_len, 2)``, ``batch_u``
        the applied inputs, shape ``(batch_size, K, seq_len, 1)``, and
        ``params`` the per-system ``(M, b)`` values, shape ``(batch_size, 2)``.
    """
    nx = 2  # number of states
    nu = 1  # number of inputs
    dt0 = 1e-1
    x0key, ukey, pkey = jr.split(key, 3)  # initial state, input, system
    ts = jnp.arange(seq_len) * dt0
    dt0 = ts[1] - ts[0]

    batch_x0 = jr.uniform(x0key, (batch_size, K, nx), minval=-1, maxval=1)
    batch_u = jr.uniform(ukey, (batch_size, K, seq_len, nu), minval=-1, maxval=1)
    params_nominal = jnp.array([1, 0.1])
    # Each parameter is perturbed by +/- 10% around its nominal value.
    params = params_nominal * jr.uniform(pkey, (batch_size, 2), minval=0.9, maxval=1.1)

    def f_xu(x, u, args):
        """Toy system: point mass with friction and force."""
        p, v = x  # position, velocity
        F, = u  # or F = u[0]
        M, b = args
        dp = v
        dv = -b/M * v + 1/M * F
        dx = jnp.array([dp, dv])
        return dx

    def solve(ts, x0, u, params):
        """Integrate one trajectory from ``x0`` under the sampled input ``u``."""
        # The sampled input is interpolated so it can be evaluated at any
        # time requested by the ODE solver.
        u_fun = Interpolation(ts=ts, ys=u)

        def vector_field(t, y, args):
            x = y  # state rename...
            ut = u_fun.evaluate(t)
            dx = f_xu(x, ut, args)
            return dx

        sol = diffrax.diffeqsolve(
            terms=diffrax.ODETerm(vector_field),
            solver=diffrax.Tsit5(),
            t0=ts[0],
            t1=ts[-1],
            dt0=dt0,
            y0=x0,
            saveat=diffrax.SaveAt(ts=ts),
            args=params,
            max_steps=seq_len,
        )
        return sol.ys

    # Inner vmap: K repetitions of one system (shared time grid and parameters).
    solve_reps = jax.vmap(solve, in_axes=(None, 0, 0, None))
    # Outer vmap: batch_size systems, K repetitions each.
    solve_meta = jax.vmap(solve_reps, in_axes=(None, 0, 0, 0))
    batch_x = solve_meta(ts, batch_x0, batch_u, params)

    return batch_x, batch_u, params


# seq_len, batch_size and K determine array shapes and the solver step budget,
# so they have to be static (compile-time) arguments. With a bare ``jax.jit``
# they would be traced, and overriding any default would fail at trace time.
generate_batch = jax.jit(
    generate_batch, static_argnames=("seq_len", "batch_size", "K")
)

def generate_batches(key):
    """Yield an endless stream of batches produced by ``generate_batch``.

    The incoming key is split on every iteration: one half is carried over to
    the next iteration, the other half is consumed to generate the batch.
    """
    carry_key = key
    while True:
        carry_key, batch_key = jr.split(carry_key)
        yield generate_batch(batch_key)

In [None]:
# Pull batches from the endless generator, with tqdm reporting progress.
# The loop stops once index 50 has been fetched, so that batch stays bound to
# batch_x / batch_u / batch_params for the cells below.
train_dl = generate_batches(key)
batch_stream = tqdm(enumerate(train_dl))
for idx, batch in batch_stream:
    batch_x, batch_u, batch_params = batch
    if idx == 50:
        break

batch_u.shape

In [None]:
# Import the submodule explicitly: a bare ``import scipy`` is not guaranteed
# to make ``scipy.integrate`` available on older SciPy releases.
import scipy.integrate

# Pick one simulated trajectory (system ``batch_idx``, repetition ``rep_idx``)
# from the last generated batch to cross-check against SciPy.
batch_idx = 10
rep_idx = 3
x = batch_x[batch_idx][rep_idx]
u = batch_u[batch_idx][rep_idx]
params = batch_params[batch_idx]
x0 = x[0]  # initial state of the selected trajectory

# Rebuild the time grid of the selected trajectory.
# NOTE(review): these values duplicate the defaults of ``generate_batch`` and
# must be kept in sync with them by hand.
dt0 = 1e-1
seq_len = 1_000
batch_size = 32
K = 10
ts = jnp.arange(seq_len) * dt0


def f_xu(x, u, args):
    """Right-hand side of the toy system: a point mass with friction and force.

    Args:
        x: State with two entries, position and velocity.
        u: Input with a single entry, the applied force.
        args: Parameters with two entries, mass ``M`` and friction ``b``.

    Returns:
        Array ``[dp, dv]`` holding the time derivatives of position and velocity.
    """
    _, velocity = x  # the position does not enter the dynamics
    (force,) = u
    mass, friction = args
    acceleration = -friction/mass * velocity + 1/mass * force
    return jnp.array([velocity, acceleration])

# Interpolant of the sampled input ``u`` over the time grid ``ts``, so the
# input can be evaluated at whatever time the integrator asks for.
u_fun = Interpolation(ts=ts, ys=u)


def fun(t, y, args):
    """Vector field for the integrator: state derivative at time ``t``.

    ``y`` is the current state and ``args`` the system parameters; the input
    is obtained by evaluating the interpolant ``u_fun`` at ``t``.
    """
    return f_xu(y, u_fun.evaluate(t), args)
        
# Re-simulate the selected trajectory with SciPy's RK45 as a reference,
# sampling on the same grid ``ts`` and capping the step at the sampling time.
out = scipy.integrate.solve_ivp(
    fun,
    t_span=(ts[0], ts[-1]),
    y0=x0,
    method='RK45',
    t_eval=ts,
    dense_output=False,
    args=(params,),
    max_step=dt0,
)
x_scipy = out.y.T  # transpose so that rows are time samples, like ``x``

# One panel per state; each shows the original trajectory, the SciPy
# trajectory and their difference.
# NOTE(review): ``fix`` is presumably a typo for ``fig``; the name is kept so
# that no notebook-level variable is renamed.
fix, ax = plt.subplots(1, 2, figsize=(12, 4))
for state_idx, panel in enumerate(ax):
    panel.plot(x[:, state_idx])
    panel.plot(x_scipy[:, state_idx])
    panel.plot(x[:, state_idx] - x_scipy[:, state_idx])