# Neural ODE

System identification of an autonomous ODE (the Lotka–Volterra predator–prey model) from a dataset of sparse, irregularly sampled trajectories.

In [None]:
import time
from tqdm import tqdm
import diffrax
import equinox as eqx  # https://github.com/patrick-kidger/equinox
from flax import nnx
import jax
import jax.nn as jnn
import jax.numpy as jnp
import jax.random as jr
import matplotlib.pyplot as plt
import optax  # https://github.com/deepmind/optax

In [None]:
%matplotlib widget

In [None]:
# ---- data ----
nx = 2  # state dimension: (prey, predator)
dataset_size = 2_000  # number of trajectories in the dataset
seq_len = 40  # samples per trajectory; the sample times are irregular

# ---- model ----
width = 64  # size of each hidden layer of the vector-field MLP
depth = 2  # NOTE(review): not referenced by `Func`, which hard-codes two hidden layers

# ---- optimisation ----
batch_size = 32  # trajectories per mini-batch
lr = 1e-3  # AdamW learning rate
steps = 10_000  # number of training steps

# ---- logging ----
print_every = 100  # progress-bar refresh interval, in steps

In [None]:
# One root seed; the five subkeys drive the initial states, the two sample-time
# draws, the model initialisation and the data-loader shuffling respectively.
seed = 1234
key = jr.PRNGKey(seed)
ykey, tkey1, tkey2, model_key, loader_key = jr.split(key, num=5)

In [None]:
# Generate `dataset_size` Lotka-Volterra trajectories, each observed at
# `seq_len` irregularly spaced (uniformly random, then sorted) times.
# Note: the first sample time of each trajectory is random rather than `t0`;
# `solve` starts integration from `y0` at that first sample time.

# Initial (prey, predator) populations, uniform in [6, 14).
y0 = 8 * jr.uniform(ykey, (dataset_size, nx)) + 6

t0 = 0
# Per-trajectory time horizon, uniform in [140, 141).
t1 = 140 + jr.uniform(tkey1, (dataset_size,))
ts = jr.uniform(tkey2, (dataset_size, seq_len)) * (t1[:, None] - t0) + t0
ts = jnp.sort(ts)  # sort each trajectory's times (last axis)
dt0 = 0.1  # step size passed to `solve` (no step-size controller is given there)
params = jnp.array([0.1, 0.02, 0.4, 0.02])  # (alpha, beta, gamma, delta)

def vector_field(t, y, args):
    """Right-hand side of the Lotka-Volterra predator-prey system.

    Args:
        t: time; unused because the system is autonomous.
        y: current state ``(prey, predator)``.
        args: model parameters ``(alpha, beta, gamma, delta)``.

    Returns:
        ``jnp.array([d_prey, d_predator])``.
    """
    prey, predator = y
    alpha, beta, gamma, delta = args
    return jnp.array(
        [
            alpha * prey - beta * prey * predator,
            -gamma * predator + delta * prey * predator,
        ]
    )

def solve(ts, y0):
    """Integrate the Lotka-Volterra system and sample it at ``ts``.

    Integration runs from ``y0`` at ``ts[0]`` to ``ts[-1]`` with Tsit5, step
    ``dt0`` and parameters ``params`` (both read from module globals; no
    step-size controller is passed).

    Returns:
        The solver's ``ys``: the state at every time in ``ts``.
    """
    term = diffrax.ODETerm(vector_field)
    solution = diffrax.diffeqsolve(
        term,
        diffrax.Tsit5(),
        t0=ts[0],
        t1=ts[-1],
        dt0=dt0,
        y0=y0,
        args=params,
        saveat=diffrax.SaveAt(ts=ts),
    )
    return solution.ys

# Solve every trajectory in parallel (vectorised over the leading axis of ts and y0).
ys = jax.vmap(solve, in_axes=(0, 0))(ts, y0)

(ts.shape, ys.shape)

In [None]:
# Normalise each state component with dataset-wide statistics.
# `mu`: mean of the per-trajectory means (equals the overall mean because every
# trajectory has `seq_len` samples).
# `std`: mean of the per-trajectory standard deviations -- NOT the pooled std,
# so between-trajectory spread is not included.
mu = ys.mean(axis=1).mean(axis=0)
std = ys.std(axis=1).mean(axis=0)
ys = (ys - mu) / std

In [None]:
def dataloader(arrays, batch_size, *, key):
    """Yield an endless stream of shuffled mini-batches.

    Every epoch draws a fresh permutation of the row indices and slices it into
    consecutive batches of exactly ``batch_size`` rows.  A trailing remainder
    smaller than ``batch_size`` is dropped, but a final *full* batch is always
    kept (including when the dataset size is an exact multiple of
    ``batch_size``).

    Args:
        arrays: sequence of arrays sharing the same leading dimension.
        batch_size: rows per batch; must satisfy ``1 <= batch_size <= N``,
            where ``N`` is the shared leading dimension.
        key: JAX PRNG key that drives the shuffling.

    Yields:
        A tuple with one ``batch_size``-row slice per input array, taken at the
        same (shuffled) row indices.

    Raises:
        ValueError: if the arrays disagree on their leading dimension, or if
            ``batch_size`` is out of range (previously this made the generator
            spin forever without yielding anything).
    """
    dataset_size = arrays[0].shape[0]
    if any(array.shape[0] != dataset_size for array in arrays):
        raise ValueError("all arrays must share the same leading dimension")
    if not 1 <= batch_size <= dataset_size:
        raise ValueError(
            f"batch_size must be in [1, {dataset_size}], got {batch_size}"
        )
    indices = jnp.arange(dataset_size)
    while True:
        # Split first so the key used for this permutation is never reused.
        key, perm_key = jr.split(key)
        perm = jr.permutation(perm_key, indices)
        # Batch starts 0, b, 2b, ... with start + b <= N: every full batch is
        # yielded, including the last one when N is a multiple of b.
        for start in range(0, dataset_size - batch_size + 1, batch_size):
            batch_perm = perm[start:start + batch_size]
            yield tuple(array[batch_perm] for array in arrays)

In [None]:
# Endless generator of shuffled (ts, ys) mini-batches used by the training loop.
train_dl = dataloader((ts, ys), batch_size=batch_size, key=loader_key)

In [None]:
# plot some data
plt.figure()
plt.plot(ts[:4, :].T, ys[:4, :, 0].T, "r--*")
plt.plot(ts[:4, :].T, ys[:4, :, 1].T, "b--*");

In [None]:
class Func(nnx.Module):
    """MLP vector field for the neural ODE.

    Two hidden ``nnx.Linear`` layers with GELU activations map the state to its
    time derivative.  The output is multiplied by ``scale`` (presumably to keep
    the initial learned dynamics slow).  The ``t`` and ``args`` arguments
    required by ``diffrax.ODETerm`` are accepted but ignored, because the system
    is autonomous.

    Args:
        nx: state dimension (input and output size).
        rngs: ``nnx.Rngs`` used to initialise the linear layers.
        hidden_width: size of each hidden layer.  Defaults to the notebook-level
            ``width`` hyper-parameter read at construction time, which was the
            previous, implicit behaviour.
        scale: multiplier applied to the network output (default ``1e-3``).
    """

    def __init__(self, nx, rngs: nnx.Rngs, hidden_width=None, scale=1e-3):
        if hidden_width is None:
            hidden_width = width  # fall back to the global hyper-parameter
        self.linear1 = nnx.Linear(nx, hidden_width, rngs=rngs)
        self.linear2 = nnx.Linear(hidden_width, hidden_width, rngs=rngs)
        self.linear3 = nnx.Linear(hidden_width, nx, rngs=rngs)
        self.scale = scale

    def __call__(self, t, x, args):
        h = nnx.gelu(self.linear1(x))
        h = nnx.gelu(self.linear2(h))
        dx = self.linear3(h)
        return dx * self.scale
    
class NeuralOde(nnx.Module):
    """Neural ODE that integrates a learned vector field with diffrax.

    Calling the module solves ``dy/dt = func(t, y, args)`` from ``ts[0]`` to
    ``ts[-1]`` with Tsit5 and a PID step-size controller, and returns the state
    at every time in ``ts``.
    """

    def __init__(self, func: Func, **kwargs):
        super().__init__(**kwargs)
        self.func = func  # learnable vector field, called as func(t, y, args)

    def __call__(self, ts, y0):
        """Return the trajectory sampled at ``ts``, starting from ``y0`` at ``ts[0]``.

        ``ts`` needs at least two entries: ``ts[1] - ts[0]`` is used as the
        initial step size, which the PID controller then adapts.
        """
        term = diffrax.ODETerm(self.func)
        solver = diffrax.Tsit5()
        controller = diffrax.PIDController(rtol=1e-3, atol=1e-6)
        first_step = ts[1] - ts[0]
        sol = diffrax.diffeqsolve(
            term,
            solver,
            t0=ts[0],
            t1=ts[-1],
            dt0=first_step,
            y0=y0,
            stepsize_controller=controller,
            saveat=diffrax.SaveAt(ts=ts),
        )
        return sol.ys

In [None]:
# nnx.Rngs is seeded directly with the JAX subkey `model_key` split off at the top.
vector_field_nn = Func(nx, rngs=nnx.Rngs(model_key))
nnx.display(vector_field_nn)

In [None]:
# x=jnp.ones((3, nx)) # (B, din)
# y = vector_field_nn(_, jnp.ones((3, 2)), _)
# y

In [None]:
# Wrap the learnable vector field in the ODE solver (NeuralOde takes no rngs).
simulator = NeuralOde(vector_field_nn)
nnx.display(simulator)

In [None]:
# Sanity check: simulate the untrained model along the first trajectory's
# sample times, starting from its first observation, and plot the first state.
ys_sim = simulator(ts[0], y0=ys[0, 0])

plt.figure()
plt.plot(ts[0], ys_sim[:, 0], "r*")

In [None]:
def loss_fn(simulator, t, y):
    """Mean squared error between observed and simulated trajectories.

    Each trajectory in the batch is simulated from its first observation
    ``y[:, 0, :]`` over its own time vector ``t[i]`` (vectorised with
    ``nnx.vmap``), and the squared error is averaged over every entry.
    """
    y0_batch = y[:, 0, :]
    y_sim = nnx.vmap(simulator)(t, y0_batch)
    residual = y - y_sim
    return jnp.mean(residual**2)


loss_fn(simulator, ts, ys)

In [None]:
# Compiled function returning (loss, gradients w.r.t. its first argument, the model).
loss_grad_fn = nnx.jit(nnx.value_and_grad(loss_fn))
loss, grad = loss_grad_fn(simulator, ts, ys)

In [None]:
@nnx.jit
def train_step(model, optimizer: nnx.Optimizer, t, y):
    """Run one optimisation step on a batch and return the loss before the update.

    ``optimizer.update`` applies the gradients to the model held by the
    optimizer; nothing else is returned.
    """
    # NOTE(review): newer Flax releases changed this call to
    # `optimizer.update(model, grads)`; confirm against the installed version.
    batch_loss, batch_grads = loss_grad_fn(model, t, y)
    optimizer.update(batch_grads)
    return batch_loss

In [None]:
# AdamW over the simulator's parameters, using the learning rate from the config cell.
optimizer = nnx.Optimizer(simulator, optax.adamw(learning_rate=lr))

In [None]:
# Run exactly `steps` optimisation steps.  Zipping with `range(steps)` bounds the
# loop even when `steps <= 0`.  Previously `if step == steps - 1: break` never
# fired in that case, and the loop ran forever on the endless `train_dl`.
# The `with` block closes the progress bar once training ends.
LOSS = []  # per-step training loss, kept as JAX scalars
with tqdm(zip(range(steps), train_dl), total=max(steps, 0)) as pbar:
    for step, (ts_batch, ys_batch) in pbar:
        loss = train_step(simulator, optimizer, ts_batch, ys_batch)
        if step % print_every == 0:
            pbar.set_postfix_str(f"Step: {step}, Loss: {loss}")
        LOSS.append(loss)

In [None]:
plt.figure()  # training-loss curve, one point per optimisation step
plt.plot(range(len(LOSS)), LOSS)

In [None]:
# Held-out validation trajectories.  `next(train_dl)` would return a batch the
# optimiser has already trained on, so draw fresh trajectories instead.  They use
# the same recipe as the training set (random initial state, sorted irregular
# sample times, `solve`) and are normalised with the *training* statistics
# `mu` / `std`.  Keep this recipe in sync with the data-generation cell.
val_size = batch_size  # must exceed the plotted `idx` below
val_ykey, val_tkey1, val_tkey2 = jr.split(jr.fold_in(key, 1), 3)
val_y0 = 8 * jr.uniform(val_ykey, (val_size, nx)) + 6
val_t1 = 140 + jr.uniform(val_tkey1, (val_size,))
val_ts = jnp.sort(
    jr.uniform(val_tkey2, (val_size, seq_len)) * (val_t1[:, None] - t0) + t0
)
val_ys = (jax.vmap(solve)(val_ts, val_y0) - mu) / std


def dense_grid(ts, num=1000):
    """Return ``num`` evenly spaced times from ``ts[0]`` to ``ts[-1]`` (inclusive).

    Used to evaluate the fitted model on a fine grid for smooth plotting.
    ``num`` fixes the output shape, so it must be a Python int (not traced).
    Its default of 1000 matches the previously hard-coded value.
    """
    return jnp.linspace(ts[0], ts[-1], num)


# Simulate every validation trajectory on a dense time grid, starting from its
# first observation.  Then overlay the data (markers) and the model prediction
# (lines) for trajectory `idx`: prey in red, predator in blue.
val_ts_dense = jax.vmap(dense_grid)(val_ts)
val_yhat_dense = nnx.vmap(simulator)(val_ts_dense, val_ys[:, 0])

idx = 10
plt.figure()
for component, marker_fmt, line_fmt in ((0, "r*", "r"), (1, "b*", "b")):
    plt.plot(val_ts[idx], val_ys[idx, :, component], marker_fmt)
    plt.plot(val_ts_dense[idx], val_yhat_dense[idx, :, component], line_fmt)