# Neural ODE

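System identification, from a possibly sparse dataset, of time-invariant ODEs driven by an external input (the vector field has no explicit time dependence, but it is forced by the input signal u, so the ODEs are not autonomous in the strict sense).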

In [None]:
import time
from tqdm import tqdm
import diffrax
import jax
import jax.numpy as jnp
import jax.random as jr
import matplotlib.pyplot as plt
#from interpolation import ZOHInterpolation as Interpolation
from diffrax import LinearInterpolation as Interpolation


In [None]:
%matplotlib widget

In [None]:
seed = 1234
key = jr.PRNGKey(seed)
# One independent PRNG stream each for: initial states, input signals, system parameters.
x0key, ukey, pkey = jr.split(key, num=3)

In [None]:
# Dataset dimensions.
nx, nu = 2, 1  # state size (position, velocity) and input size (scalar force)
dataset_size = 64  # number of independent sequences in the dataset
# Time steps per sequence. The time grid built from dt later in the notebook
# is uniform; irregular/sparse sampling is not implemented in this section.
seq_len = 10_000

In [None]:
def multisine(N, pmin, pmax, P, key):
    """Generate one random-phase multisine, normalized to unit standard deviation.

    Harmonics ``pmin, ..., pmax - 1`` of the base frequency ``1/N`` (cycles per
    sample) get equal amplitude and an independent uniform random phase in
    ``[0, 2*pi)``. Every other frequency bin is zero. The resulting period of
    ``N`` samples is repeated ``P`` times.

    Args:
        N: Number of samples in one period.
        pmin: First excited harmonic (inclusive).
        pmax: Last excited harmonic (exclusive). Callers should keep
            ``pmax <= N // 2 + 1`` so every harmonic fits in the one-sided spectrum.
        P: Number of periods to concatenate.
        key: JAX PRNG key used to draw the phases.

    Returns:
        A 1-D array of length ``N * P``.
    """
    uf = jnp.zeros((N // 2 + 1,), dtype=complex)
    for p in range(pmin, pmax):
        # Sequential splitting: one fresh subkey per harmonic.
        key, subkey = jr.split(key)
        phase = jr.uniform(subkey, minval=0, maxval=jnp.pi * 2)
        uf = uf.at[p].set(jnp.exp(1j * phase))

    # Pass n=N explicitly. irfft's default output length is 2 * (len(uf) - 1),
    # which equals N only for even N. For odd N it would silently return
    # N - 1 samples.
    uk = jnp.fft.irfft(uf / 2, n=N)
    uk /= jnp.std(uk)
    return jnp.concatenate([uk] * P)

def multisines(N, pmin, pmax, P, batch_size, key):
    """Generate ``batch_size`` independent multisines, stacked along axis 0.

    Row ``i`` is ``multisine(N, pmin, pmax, P, k_i)``, where ``k_i`` is the
    i-th subkey split from ``key``. ``N``, ``pmin``, ``pmax`` and ``P`` are
    shared by every row.
    """
    def _single(subkey):
        return multisine(N, pmin, pmax, P, subkey)

    return jax.vmap(_single)(jr.split(key, batch_size))

In [None]:
def f_xu(x, u, args):
    """Duffing oscillator vector field driven by an external force.

    For the state x = (p, v) this returns dx/dt = (dp/dt, dv/dt) with
        dp/dt = v
        dv/dt = -delta * v - alpha * p - beta * p**3 + F,   where F = u[0].

    ``args`` must unpack into five values (alpha, beta, delta, gamma, omega).
    gamma and omega (amplitude and frequency of the classic cosine forcing
    gamma * cos(omega * t)) are accepted but unused: the forcing comes from
    the input ``u`` instead.
    """
    pos, vel = x
    alpha, beta, delta, _gamma, _omega = args
    force = u[0]
    # Same term order as the textbook form, so floating-point results are unchanged.
    acc = -delta * vel - alpha * pos - beta * pos**3 + force
    return jnp.array([vel, acc])

In [None]:
# Uniform time grid shared by all sequences. dt is also the fixed step used by
# the simulators further down.
dt = 0.005
ts = dt * jnp.arange(seq_len)
t0, t1 = ts[0], ts[-1]

# Random initial states, uniform in [-1, 1) for each state component.
x0 = jr.uniform(x0key, (dataset_size, nx), minval=-1, maxval=1)

# Multisine inputs exciting harmonics 1..20 over a single period, with a
# trailing singleton axis for the scalar input channel.
# (Alternative tried earlier: u = jr.uniform(ukey, (dataset_size, seq_len, nu), minval=-1, maxval=1))
u = multisines(seq_len, pmin=1, pmax=21, P=1, batch_size=dataset_size, key=ukey)[..., None]

# Duffing parameters in f_xu order (alpha, beta, delta, gamma, omega). Each
# sequence gets its own copy, scaled elementwise by factors drawn uniformly
# from [0.9, 1.1).
params_nominal = jnp.array([1.0, 5.0, 0.02, 8., 0.5])
params = params_nominal * jr.uniform(
    pkey, (dataset_size, len(params_nominal)), minval=0.9, maxval=1.1
)

In [None]:
# Sanity check: inputs of the first two sequences (one line per column).
plt.figure()
plt.plot(u[:2, :, 0].T)

In [None]:
def solve_diffrax(ts, x0, u, params, dt0=None):
    """Simulate the input-driven model ``f_xu`` with diffrax on the time grid ``ts``.

    The input samples ``u`` (aligned with ``ts``) become a continuous-time
    signal through ``Interpolation``, which is diffrax's ``LinearInterpolation``
    per the import cell. That signal is evaluated inside every vector-field call.

    Args:
        ts: 1-D array of time stamps. Integration runs from ``ts[0]`` to
            ``ts[-1]``, and the solution is saved at every entry of ``ts``.
        x0: Initial state.
        u: Input samples, one per entry of ``ts`` along the leading axis.
        params: Model parameters, forwarded to ``f_xu`` as ``args``.
        dt0: Constant step size. Defaults to the notebook-level ``dt``, read
            at call time, which is the step the original implementation
            always used.

    Returns:
        ``sol.ys``: the solution saved at ``ts``.
    """
    if dt0 is None:
        dt0 = dt  # backward-compatible default: the global step size
    u_fun = Interpolation(ts=ts, ys=u)

    def vector_field(t, x, args):
        return f_xu(x, u_fun.evaluate(t), args)

    # Explicit Euler with a constant step. Alternatives tried in this notebook:
    # diffrax.Tsit5() / diffrax.Dopri5() with
    # diffrax.PIDController(rtol=1e-6, atol=1e-6) (optionally jump_ts=ts).
    sol = diffrax.diffeqsolve(
        terms=diffrax.ODETerm(vector_field),
        solver=diffrax.Euler(),
        t0=ts[0],
        t1=ts[-1],
        dt0=dt0,
        y0=x0,
        saveat=diffrax.SaveAt(ts=ts),
        stepsize_controller=diffrax.ConstantStepSize(),
        args=params,
        # NOTE(review): max_steps=None removes the step cap. Check the diffrax
        # docs before differentiating through this solve; an unbounded step
        # count may not be supported by every adjoint.
        max_steps=None,
    )
    return sol.ys


solve_diffrax(ts, x0[0], u[0], params_nominal).shape

In [None]:
#for _ in range(10)

In [None]:
# Simulate every sequence with its own random initial state (x0), its own
# multisine input (u) and its own randomly perturbed parameter vector (params,
# i.e. params_nominal scaled per sequence), not the nominal parameters.
x = jax.jit(jax.vmap(solve_diffrax, in_axes=(None, 0, 0, 0)))(ts, x0, u, params)
x.shape

In [None]:
# Position (red) and velocity (blue) of the first simulated sequence versus time.
plt.figure()
plt.plot(ts, x[0, :, 0], "r", label="p")
plt.plot(ts, x[0, :, 1], "b", label="v")
plt.legend()

In [None]:
# Named discretize_rk4: under its former name (discretize_euler) the Euler
# definition that follows shadowed it, so this RK4 variant was unreachable.
def discretize_rk4(fun_ct, dt):
    """Return a classic 4th-order Runge-Kutta step for ``fun_ct`` with step ``dt``.

    The returned ``step(x, u, args)`` holds the input ``u`` constant over the
    step: all four stages see the same ``u``. It returns ``(x_next, x)``, so it
    can be used directly as a ``jax.lax.scan`` body that emits the pre-step state.
    """
    half_dt = dt / 2

    def fun_rk(x, u, args):
        k1 = fun_ct(x, u, args)
        k2 = fun_ct(x + half_dt * k1, u, args)
        k3 = fun_ct(x + half_dt * k2, u, args)
        k4 = fun_ct(x + dt * k3, u, args)
        x_new = x + dt / 6.0 * (k1 + 2.0 * k2 + 2.0 * k3 + k4)
        return x_new, x

    return fun_rk

def discretize_euler(fun_ct, dt):
    """Return an explicit (forward) Euler step for ``fun_ct`` with step ``dt``.

    The returned ``step(x, u, args)`` computes ``x + dt * fun_ct(x, u, args)``
    and returns ``(x_next, x)``, i.e. the new state plus the state it started from.
    """
    def step(x, u, args):
        x_next = x + dt * fun_ct(x, u, args)
        return x_next, x

    return step

def solve_dt(fn_ct, ts, x0, u, args, discretize=None):
    """Simulate a continuous-time model with a fixed-step discretization via ``lax.scan``.

    Args:
        fn_ct: Continuous-time vector field ``fn_ct(x, u, args) -> dx/dt``.
        ts: Step size. Despite the name, this is the scalar step (the notebook
            passes ``dt``); it is forwarded unchanged as the ``dt`` argument of
            the discretization factory. The name is kept for backward compatibility.
        x0: Initial state, used as the scan carry.
        u: Input sequence, scanned along its leading axis (one entry per step).
        args: Model parameters, passed through to ``fn_ct``.
        discretize: Factory ``discretize(fn_ct, dt) -> step``, where
            ``step(x, u, args)`` returns ``(x_next, x)``. Defaults to
            ``discretize_euler``, looked up at call time.

    Returns:
        The stacked second outputs of ``step``, one per entry of ``u``. With a
        step that returns ``(x_next, x)``, element ``k`` is the state after
        ``k`` steps, so element 0 is ``x0``.
    """
    if discretize is None:
        discretize = discretize_euler  # late lookup, as in the original code
    step = discretize(fn_ct, ts)

    def scan_body(x, u_k):
        return step(x, u_k, args)

    _, x_sim = jax.lax.scan(scan_body, x0, u)
    return x_sim

# Batch solve_dt over sequences: x0, u and params are mapped along axis 0,
# while the vector field and the step size are shared by every sequence.
solve_dt_batch = jax.vmap(solve_dt, in_axes=(None, None, 0, 0, 0))

# NOTE: despite its name, x_rk4 is produced by solve_dt's explicit-Euler step
# (discretize_euler), for comparison with the diffrax Euler solve stored in x.
x_rk4 = solve_dt_batch(f_xu, dt, x0, u, params)

In [None]:
# Compare the diffrax trajectory (x, black) with the lax.scan trajectory
# (x_rk4, blue) for one sequence; red is their difference.
# One panel per state component.
idx = 3
fig, ax = plt.subplots(1, 2, figsize=(12, 4))
for state_idx, state_name in enumerate(("p", "v")):
    ax[state_idx].plot(ts, x[idx, :, state_idx], "k", label="diffrax")
    ax[state_idx].plot(ts, x_rk4[idx, :, state_idx], "b", label="scan")
    ax[state_idx].plot(
        ts, x[idx, :, state_idx] - x_rk4[idx, :, state_idx], "r", label="difference"
    )
    ax[state_idx].set_title(state_name)
    ax[state_idx].set_xlabel("t")
    ax[state_idx].legend()