In [20]:
import time

import diffrax
import equinox as eqx  # https://github.com/patrick-kidger/equinox
import jax
import jax.nn as jnn
import jax.numpy as jnp
import jax.random as jrandom
import matplotlib.pyplot as plt
import optax  # https://github.com/deepmind/optax
from clu import parameter_overview

In [13]:
class Func(eqx.Module):
    mlp: eqx.nn.MLP

    def __init__(self, data_size, width_size, depth, *args, key, **kwargs):
        super().__init__(**kwargs)
        self.mlp = eqx.nn.MLP(
            in_size=data_size,
            out_size=data_size,
            width_size=width_size,
            depth=depth,
            activation=jnn.softplus,
            key=key,
        )

    def __call__(self, t, y, *args):
        return self.mlp(y)

In [37]:
data_size = 1
width_size = 4
depth = 1
key = jax.random.PRNGKey(20)

control = Func(data_size, width_size, depth, key=key)
print("Behold the control term: \n----------------------- \n", control)

T = 10
N = 1000
ts = jnp.linspace(0, 10, N+1)
y0 = jnp.array([1])

print("\nTesting the control term: \n----------------------- \n", control(t, y0))

Behold the control term: 
----------------------- 
 Func(
  mlp=MLP(
    layers=[
      Linear(
        weight=f32[4,1],
        bias=f32[4],
        in_features=1,
        out_features=4,
        use_bias=True
      ),
      Linear(
        weight=f32[1,4],
        bias=f32[1],
        in_features=4,
        out_features=1,
        use_bias=True
      )
    ],
    activation=<wrapped function softplus>,
    final_activation=<function _identity>,
    in_size=1,
    out_size=1,
    width_size=4,
    depth=1
  )
)

Testing the control term: 
----------------------- 
 [0.0610788]


In [38]:
class NeuralODE(eqx.Module):
    func: Func

    def __init__(self, data_size, width_size, depth, *args, key, **kwargs):
        super().__init__(**kwargs)
        self.func = Func(data_size, width_size, depth, key=key)

    def __call__(self, ts, y0):
        solution = diffrax.diffeqsolve(
            diffrax.ODETerm(self.func),
            diffrax.Tsit5(),
            t0=ts[0],
            t1=ts[-1],
            dt0=ts[1] - ts[0],
            y0=y0,
            stepsize_controller=diffrax.PIDController(rtol=1e-3, atol=1e-6),
            saveat=diffrax.SaveAt(ts=ts),
        )
        return solution.ys

In [39]:
neural_ode = NeuralODE(data_size, width_size, depth, key=key)
neural_ode(ts, y0)

DeviceArray([[1.       ],
             [1.0006104],
             [1.0012197],
             ...,
             [1.3250225],
             [1.3251774],
             [1.325332 ]], dtype=float32, weak_type=True)