# Neural ODE

System identification of an autonomous ODE from a dataset of irregularly (and possibly sparsely) sampled trajectories

In [None]:
import time
from tqdm import tqdm
import diffrax
import equinox as eqx  # https://github.com/patrick-kidger/equinox
import jax
from flax import linen as nn
from typing import Sequence, Dict, Any
import jax.nn as jnn
import jax.numpy as jnp
import jax.random as jr
import matplotlib.pyplot as plt
import optax  # https://github.com/deepmind/optax

In [None]:
%matplotlib widget

In [None]:
# Dataset
nx = 2  # dimension of the state vector fed to / returned by the model
dataset_size = 2000  # number of sequences in the dataset
seq_len = 40  # samples per sequence; the sampling times are irregular

# Model
depth = 2  # number of hidden layers
width_size = 64  # units per hidden layer

# Optimization (not used in the cells shown here; presumably read by a
# training loop further down -- confirm)
lr = 1e-3
batch_size = 32
steps = 10000

# Logging
print_every = 100

In [None]:
# One root key, split into independent sub-keys so that each random draw
# (initial states, time horizons, sample times, model, data loader) gets its own.
seed = 1234
key = jr.PRNGKey(seed)
ykey, tkey1, tkey2, model_key, loader_key = jr.split(key, num=5)

In [None]:
# Synthetic dataset: `dataset_size` trajectories of the Lotka-Volterra model,
# each one observed at `seq_len` irregularly spaced time points.

# Initial (prey, predator) states, drawn uniformly between 6 and 14.
y0 = 6 + 8 * jr.uniform(ykey, (dataset_size, 2))

# Every trajectory gets its own time window starting at t0; the end of the
# window is drawn uniformly between 140 and 141.
t0 = 0
t1 = 140 + jr.uniform(tkey1, (dataset_size,))

# Uniform sample times inside each window, sorted along the time axis so that
# ts[i, 0] / ts[i, -1] are the first / last observation of trajectory i.
ts = t0 + (t1[:, None] - t0) * jr.uniform(tkey2, (dataset_size, seq_len))
ts = jnp.sort(ts, axis=-1)

dt0 = 0.1  # initial step size handed to the ODE solver
args = jnp.array([0.1, 0.02, 0.4, 0.02])  # (α, β, γ, δ) read by vector_field


def vector_field(t, y, args):
    """Right-hand side of the Lotka-Volterra (predator-prey) equations.

    Args:
        t: time; unused, the system is autonomous.
        y: pair ``(prey, predator)`` of current populations.
        args: the four coefficients ``(α, β, γ, δ)``.

    Returns:
        Array ``[d_prey, d_predator]`` with the time derivatives of ``y``.
    """
    α, β, γ, δ = args
    prey, predator = y
    # Prey grow at rate α and are consumed through the β interaction term;
    # predators decay at rate γ and grow through the δ interaction term.
    growth = (
        α * prey - β * prey * predator,
        -γ * predator + δ * prey * predator,
    )
    return jnp.array(list(growth))


def solve(ts, y0):
    """Integrate the reference Lotka-Volterra system and sample it at ``ts``.

    Args:
        ts: 1-D array of observation times; integration runs from ``ts[0]``
            to ``ts[-1]``.
        y0: state at time ``ts[0]``.

    Returns:
        The solution evaluated at every entry of ``ts`` (``sol.ys``).
    """
    term = diffrax.ODETerm(vector_field)
    solution = diffrax.diffeqsolve(
        term,
        diffrax.Tsit5(),
        t0=ts[0],
        t1=ts[-1],
        dt0=dt0,
        y0=y0,
        args=args,  # module-level model coefficients
        saveat=diffrax.SaveAt(ts=ts),
    )
    return solution.ys


# Solve all trajectories at once by mapping `solve` over the leading
# (dataset) axis of both the time grids and the initial states.
ys = jax.vmap(solve, in_axes=(0, 0))(ts, y0)

# Notebook display: shapes of the time grids and of the simulated states.
(ts.shape, ys.shape)

In [None]:
# normalize data (currently disabled)

# mu = jnp.mean(jnp.mean(ys, axis=1), axis=0)
# std = jnp.mean(jnp.std(ys, axis=1), axis=0)
# ys = (ys - mu) / std

In [None]:
# Plot the first four trajectories against their own time grids:
# state component 0 (prey) in red, component 1 (predator) in blue.
plt.figure()
plt.plot(ts[:4].T, ys[:4, :, 0].T, "r--*")
plt.plot(ts[:4].T, ys[:4, :, 1].T, "b--*")

In [None]:
class MLP(nn.Module):
    """Fully connected network with tanh activations between layers.

    One ``nn.Dense`` layer is created per entry of ``features``; every layer
    except the last is followed by ``tanh``, the last one is left linear.
    """

    # Output size of each Dense layer, in order.
    features: Sequence[int]
    # Extra keyword arguments forwarded to every nn.Dense; None means none.
    layer_kwargs: Dict[str, Any] = None

    def setup(self):
        dense_kwargs = {} if self.layer_kwargs is None else self.layer_kwargs
        self.layers = [nn.Dense(size, **dense_kwargs) for size in self.features]

    def __call__(self, inputs):
        h = inputs
        # Hidden layers: affine map followed by tanh.
        for hidden in self.layers[:-1]:
            h = nn.tanh(hidden(h))
        # Output layer: affine map only (the slice is empty when there are no layers).
        for head in self.layers[-1:]:
            h = head(h)
        return h

In [None]:
# Neural vector field: `depth` hidden layers of `width_size` units, mapping an
# nx-dimensional state to an nx-dimensional derivative. Kernels are drawn from
# a narrow normal distribution (stddev 1e-2).
model = MLP(
    features=[width_size] * depth + [nx],
    layer_kwargs={"kernel_init": jax.nn.initializers.normal(stddev=1e-2)},
)

# Initialize with the key reserved for the model when the seed was split,
# rather than a hard-coded key, so that changing `seed` also changes the
# initial parameters.
y, params = model.init_with_output(model_key, jnp.ones(nx))

# Notebook display: shape of the model output for a single state.
y.shape

In [None]:
def vector_field_nn(t, y, params):
    """Learned vector field: apply the MLP to the state.

    Args:
        t: time; unused, the learned dynamics are autonomous.
        y: current state.
        params: flax variables of ``model``.

    Returns:
        The model output for ``y``, used as the state derivative.
    """
    dy = model.apply(params, y)
    return dy


# Notebook display: output shape for a single state.
vector_field_nn(0, jnp.ones(nx), params).shape

In [None]:
def simulate(params, ts, y0):
    """Integrate the learned dynamics and sample the solution at ``ts``.

    Args:
        params: flax variables of the model, passed to the vector field as
            the solver's ``args``.
        ts: 1-D array of observation times; integration runs from ``ts[0]``
            to ``ts[-1]``.
        y0: state at time ``ts[0]``.

    Returns:
        The simulated states at every entry of ``ts`` (``sol.ys``).
    """
    term = diffrax.ODETerm(vector_field_nn)
    solution = diffrax.diffeqsolve(
        term,
        diffrax.Tsit5(),
        t0=ts[0],
        t1=ts[-1],
        dt0=dt0,
        y0=y0,
        args=params,
        saveat=diffrax.SaveAt(ts=ts),
    )
    return solution.ys


# Roll out the (untrained) model on the first sequence, starting from that
# sequence's true initial state.
ys_sim = simulate(params, ts=ts[0], y0=y0[0])

# First state component of the rollout against its sample times.
plt.figure()
plt.plot(ts[0], ys_sim[:, 0], "r*")

In [None]:
# Vectorize the rollout over the dataset axis of the time grids and initial
# states; the parameters are shared across the batch (in_axes None).
batched_sim = jax.vmap(simulate, in_axes=(None, 0, 0))
# Start each rollout from the first observed state of the matching sequence.
ys_sim = batched_sim(params, ts, ys[:, 0, :])
# plt.figure()
# plt.plot(ts[0], ys_sim[0, :, 0], "r*")

In [None]:
def loss_fn(params, ts, ys):
    """Mean squared error between observed and simulated trajectories.

    Args:
        params: flax variables of the model.
        ts: batch of time grids, one row per sequence.
        ys: batch of observed states, indexed as ``ys[sequence, time, state]``.

    Returns:
        Scalar mean of the squared simulation error over all entries.
    """
    # Each rollout starts from the first observation of its sequence.
    y_init = ys[:, 0, :]
    ys_hat = jax.vmap(simulate, in_axes=(None, 0, 0))(params, ts, y_init)
    residual = ys - ys_hat
    return jnp.mean(jnp.square(residual))


# Notebook display: loss of the freshly initialized model on the whole dataset.
loss_fn(params, ts, ys)

In [None]:
# Compiled function returning both the loss and its gradient with respect to
# the first argument (the model parameters).
loss_grad_fn = jax.jit(jax.value_and_grad(loss_fn, argnums=0))

# Evaluate once on the full dataset.
loss, grad = loss_grad_fn(params, ts, ys)

In [None]:
# Notebook display: the loss value and its gradient with respect to the parameters.
loss, grad