In [1]:
from cde import NeuralCDE
import jax.random as jr
import numpy as np
import diffrax
import equinox as eqx
import optax
import jax

In [10]:
# Load the reference trajectory and keep only its first 100 samples.
# NOTE(review): "tt" / "sol" are presumably the time grid and the solver
# states saved under those keys in the archive — confirm against the
# script that produced data/Tsit.npz.
d = np.load("data/Tsit.npz")
ts, us = d["tt"][:100], d["sol"][:100]

# Hermite interpolation coefficients of the trajectory; these are handed
# to the model as its control path inside compute_loss.
coeffs = diffrax.backward_hermite_coefficients(ts, us)

# Model under training, built from a fixed PRNG key so initialisation is
# reproducible.
# NOTE(review): the positional sizes are presumably data / hidden / width
# sizes followed by the MLP depth — confirm against cde.NeuralCDE.
cde = NeuralCDE(128, 128, 128, 2, key=jr.key(4321))

@eqx.filter_jit
def compute_loss(net: NeuralCDE):
    """Mean squared error between the model roll-out and the reference data.

    The model is called with the first reference sample ``us[0]``, the time
    grid ``ts`` and the interpolation coefficients ``coeffs`` (all
    notebook-level globals), asking for the unrolled output.

    Args:
        net: the NeuralCDE to evaluate.

    Returns:
        Scalar mean of the element-wise squared difference between the
        model output and ``us``.
    """
    # assumes the unrolled output broadcasts against `us` — TODO confirm
    prediction = net(us[0], ts, coeffs, unroll_out=True)
    residual = prediction - us
    return (residual**2).mean()

# Lion optimiser with a learning rate of 1e-4; its state is built only from
# the inexact-array leaves of the model, i.e. the leaves that receive
# gradients.
opt = optax.lion(learning_rate=1e-4)
opt_state = opt.init(
    eqx.filter(cde, eqx.is_inexact_array),
)

# Returns (loss, grads) for a model, differentiating w.r.t. its
# inexact-array leaves.
grad_loss = eqx.filter_value_and_grad(compute_loss)

@eqx.filter_jit
def make_step(net, opt_state):
    """Run one optimisation step and return the updated training state.

    Args:
        net: the current model.
        opt_state: the optimiser state matching ``net`` (as produced by
            ``opt.init`` on the model's inexact-array leaves, or by a
            previous call to this function).

    Returns:
        A ``(loss, net, opt_state)`` tuple: the loss evaluated *before* the
        update, the model with the update applied, and the new optimiser
        state. Nothing is modified in place; callers must rebind the
        returned values to make progress.
    """
    loss, grads = grad_loss(net)
    # Hand the optimiser the same filtered view of the model that opt_state
    # was initialised from. Lion applies weight decay, so it reads `params`;
    # passing the raw model would also expose non-array leaves (e.g. the
    # activation functions) to the optimiser.
    params = eqx.filter(net, eqx.is_inexact_array)
    updates, opt_state = opt.update(grads, opt_state, params)
    net = eqx.apply_updates(net, updates)
    return loss, net, opt_state

In [11]:
# Run a single optimisation step. The returned (loss, net, opt_state) tuple
# is only displayed, not assigned, so `cde` and `opt_state` keep their
# previous values after this cell.
make_step(cde, opt_state)

(Array(11.502912, dtype=float32),
 NeuralCDE(
   func=Func(
     mlp=MLP(
       layers=(
         Linear(
           weight=f32[128,128],
           bias=f32[128],
           in_features=128,
           out_features=128,
           use_bias=True
         ),
         Linear(
           weight=f32[128,128],
           bias=f32[128],
           in_features=128,
           out_features=128,
           use_bias=True
         ),
         Linear(
           weight=f32[16384,128],
           bias=f32[16384],
           in_features=128,
           out_features=16384,
           use_bias=True
         )
       ),
       activation=<wrapped function softplus>,
       final_activation=<wrapped function tanh>,
       use_bias=True,
       use_final_bias=True,
       in_size=128,
       out_size=16384,
       width_size=128,
       depth=2
     ),
     data_size=128,
     hidden_size=128
   )
 ),
 (ScaleByLionState(count=Array(1, dtype=int32), mu=NeuralCDE(
    func=Func(
      mlp=MLP(
        lay