# Neural ODE

System identification on a possibly sparse dataset of autonomous ODEs

In [None]:
import time
from tqdm import tqdm
import diffrax
import jax
import jax.numpy as jnp
import jax.random as jr
import matplotlib.pyplot as plt
from interpolation import ZOHInterpolation


In [None]:
# IPython line magic (not plain Python): selects matplotlib's "widget" backend for this notebook.
%matplotlib widget

In [None]:
# Fixed seed so that every run of the notebook draws the same data.
seed = 1234
key = jr.PRNGKey(seed)
# One sub-key per random quantity: initial states, inputs, system parameters.
x0key, ukey, pkey = jr.split(key, num=3)

In [None]:
# data
nx = 2 # number of states (position and velocity of the toy system)
nu = 1 # number of inputs (the force applied to the mass)
meta_batch_size = 32 # number of systems in the dataset, each with its own parameter vector
K = 10 # number of repetitions (initial state + input sequence) simulated per system
seq_len = 1024 # length of each sequence (number of time steps; ts below is a uniform linspace grid)

In [None]:
# Simulation horizon and the time grid on which inputs are defined and states are saved.
t0 = 0
t1 = 140
ts = jnp.linspace(t0, t1, num=seq_len)
dt0 = 0.1  # initial step size handed to the ODE solver

# K random initial states and K random input sequences for each system.
x0 = jr.uniform(x0key, shape=(meta_batch_size, K, nx))
u = jr.uniform(ukey, shape=(meta_batch_size, K, seq_len, nu), minval=-1, maxval=1)

# Nominal parameters (unpacked as M, b by f_xu) and one perturbed copy per system,
# each entry scaled by a random factor drawn between 0.9 and 1.1.
params_nominal = jnp.array([1, 0.1])
params = params_nominal * jr.uniform(
    pkey, shape=(meta_batch_size, 2), minval=0.9, maxval=1.1
)

def f_xu(x, u, args):
    """Toy system dynamics: a point mass driven by a force, with linear friction.

    Args:
        x: state, unpacked as (position, velocity).
        u: input; only its first entry (the force) is read.
        args: parameters, unpacked as (M, b) = (mass, friction coefficient).

    Returns:
        Array with the state derivative [d(position)/dt, d(velocity)/dt].
    """
    _position, velocity = x  # the position itself does not enter the dynamics
    force = u[0]
    mass, friction = args
    acceleration = -friction/mass * velocity + 1/mass * force
    return jnp.array([velocity, acceleration])

# f_xu(jnp.zeros(nx), jnp.zeros(nu), params_nominal)  # quick sanity check with a single parameter vector


In [None]:
def solve(ts, x0, u, params):
    """Simulate one trajectory of the toy system and return the states saved at `ts`.

    Args:
        ts: time grid; used as the interpolation knots for `u`, as the save
            points, and (first/last entry) as the integration interval.
        x0: initial state passed to the solver as `y0`.
        u: input samples handed to `ZOHInterpolation` together with `ts`
            (presumably a zero-order hold of `u` over `ts` -- confirm in
            the `interpolation` module).
        params: parameters forwarded to `f_xu` through the solver's `args`.

    Returns:
        `sol.ys`, the solver's saved states at the times in `ts`.

    Note:
        Reads the module-level `dt0` as the solver's initial step size.
    """
    held_input = ZOHInterpolation(ts=ts, ys=u)

    def vector_field(t, y, args):
        # The solver's `y` is the system state; evaluate the input at time t.
        return f_xu(y, held_input.evaluate(t), args)

    solution = diffrax.diffeqsolve(
        diffrax.ODETerm(vector_field),
        diffrax.Tsit5(),
        t0=ts[0],
        t1=ts[-1],
        dt0=dt0,
        y0=x0,
        args=params,
        saveat=diffrax.SaveAt(ts=ts),
    )
    return solution.ys

# Smoke test: a single trajectory (first initial state and input sequence, nominal parameters).
solve(ts, x0[0, 0], u[0, 0], params_nominal).shape

In [None]:
# simulate with randomized initial states, inputs, and per-system parameters
# (`params` is mapped over its leading axis by solve_meta, so each system gets its own row;
# within one system the K repetitions share the same parameters)
solve_reps = jax.vmap(solve, in_axes=(None, 0, 0, None)) # solve K repetitions for one system
solve_meta = jax.vmap(solve_reps, in_axes=(None, 0, 0, 0)) # solve meta_batch_size systems, K repetitions each
ys = solve_meta(ts, x0, u, params)
ys.shape

In [None]:
# plot some data: both states over time, first repetition of the first 4 systems.
# ys has axes (system, repetition, time, state); the repetition axis is indexed
# explicitly so that matplotlib receives a (time, system) array matching ts.
plt.figure()
# red: state 0 (position)
plt.plot(ts, ys[:4, 0, :, 0].T, "r")
# blue: state 1 (velocity)
plt.plot(ts, ys[:4, 0, :, 1].T, "b");

In [None]:
# simulate with randomized initial states, inputs, and systems
# x0 and u carry two leading axes (system, repetition), so solve has to be mapped twice:
# a single vmap would pass a whole (K, nx) block of initial states to one solve call.
ys = jax.vmap(
    jax.vmap(solve, in_axes=(None, 0, 0, None)),  # K repetitions sharing one parameter vector
    in_axes=(None, 0, 0, 0),  # meta_batch_size systems, one row of params each
)(ts, x0, u, params)
ys.shape

In [None]:
# plot some data: both states over time, first repetition of the first 4 systems.
# ys has axes (system, repetition, time, state); the repetition axis is indexed
# explicitly so that matplotlib receives a (time, system) array matching ts.
plt.figure()
# red: state 0 (position)
plt.plot(ts, ys[:4, 0, :, 0].T, "r")
# blue: state 1 (velocity)
plt.plot(ts, ys[:4, 0, :, 1].T, "b");