In [1]:
import jax.numpy as jnp
from diffrax import diffeqsolve, ODETerm, Dopri5


def f(t, y, args):
    """Vector field for exponential decay, dy/dt = -y.

    ``t`` and ``args`` are accepted but unused; they are part of the
    ``(t, y, args)`` calling convention diffrax uses for an ``ODETerm``.
    """
    dy_dt = -y
    return dy_dt


# Integrate dy/dt = -y (see `f`) from t=0 to t=1 with Dopri5, starting from
# y(0) = [2, 3] with an initial step size of 0.1.
y0 = jnp.array([2., 3.])
term = ODETerm(f)
solver = Dopri5()
solution = diffeqsolve(
    term,
    solver,
    t0=0,
    t1=1,
    dt0=0.1,
    y0=y0,
)

In [3]:
# Start of the integration window as recorded on the solution: t0=0 was passed
# to diffeqsolve above, and it comes back as a float32 JAX array.
solution.t0

Array(0., dtype=float32)

In [11]:
# ys has one row per saved time point and one column per state component
# (y0 has 2 entries). Only one time point is saved here because no `saveat`
# was passed, and diffrax's default saves just the final time t1.
solution.ys.shape

(1, 2)

In [1]:
import flax.linen as nn
from model.actor_critic_rnn import NeuralODE
import jax
import jax.numpy as jnp

# Dense widths: the encoder and derivative net each output 10 features; the
# decoder outputs 4, matching the last dim of `coords`, so predictions can be
# compared with coords-shaped targets in the loss cell below.
model = NeuralODE(
    encoder=nn.Dense(10),
    derivative_net=nn.Dense(10),
    decoder=nn.Dense(4),
)

# Fixed PRNG seed plus a (1, 4) batch of ones as the example input that
# `model.init` traces to create the parameters.
rng = jax.random.PRNGKey(0)
coords = jnp.ones((1, 4))
params = jax.jit(model.init)(rng, coords)

In [2]:
@jax.jit
def compute_loss(params, coords, true_coords):
    """Sum of absolute errors (L1 loss) between model predictions and targets.

    Args:
        params: variables for the module-level ``model``, as returned by ``model.init``.
        coords: model input.
        true_coords: targets; must broadcast against the model's output.

    Returns:
        A scalar: the summed absolute residual, suitable for ``jax.grad``.
    """
    residual = model.apply(params, coords) - true_coords
    return jnp.sum(jnp.abs(residual))


# Differentiate the L1 loss w.r.t. the parameters (argument 0), using an
# all-zero target with the same shape as the inputs.
grads = jax.grad(compute_loss, argnums=0)(
    params,
    coords,
    jnp.zeros_like(coords),
)