In [1]:
pip install diffrax equinox optax

Collecting diffrax
  Downloading diffrax-0.5.0-py3-none-any.whl (141 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/141.7 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m[90m━━[0m [32m133.1/141.7 kB[0m [31m4.6 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m141.7/141.7 kB[0m [31m3.5 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting equinox
  Downloading equinox-0.11.4-py3-none-any.whl (175 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m175.2/175.2 kB[0m [31m11.7 MB/s[0m eta [36m0:00:00[0m
Collecting jaxtyping>=0.2.24 (from diffrax)
  Downloading jaxtyping-0.2.28-py3-none-any.whl (40 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m40.7/40.7 kB[0m [31m3.8 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting lineax>=0.0.4 (from diffrax)
  Downloading lineax-0.0.5-py3-none-any.whl (66 kB)
[2K     [90m━━━━━━━━━━━━━━━

In [25]:
import time
import diffrax
import equinox as eqx
import jax
import jax.nn as jnn
import jax.numpy as jnp
import jax.random as jr
import matplotlib.pyplot as plt
import optax
import math

In [13]:
class NN(eqx.Module):
    hidden_size: int
    cell: eqx.Module
    linear: eqx.nn.Linear
    bias: jax.Array

    def __init__(self, in_size, out_size, hidden_size, *, key):
        ckey, lkey = jr.split(key)
        self.hidden_size = hidden_size
        self.cell = eqx.nn.GRUCell(in_size, hidden_size, key=ckey)
        self.linear = eqx.nn.Linear(hidden_size, out_size, use_bias=False, key=lkey)
        self.bias = jnp.zeros(out_size)

    def __call__(self, t, input):
        hidden = jnp.zeros((self.hidden_size,))

        def f(carry, inp):
            return self.cell(inp, carry), None

        out, _ = jax.lax.scan(f, hidden, input)

        return self.linear(out) + self.bias

$$
    \left\{\begin{array}{l}
    \frac{d T}{d t}=r T\left(1-\frac{T}{K}\right)-n E T,  \\
    \frac{d E}{d t}=\sigma+\mu T E-\eta E,
    \end{array}\right.
$$

In [29]:
def _get_data(ts, *, key):
    y0 = jr.uniform(key, (2,), minval = jnp.array([0,0.5]), maxval = 2)
    #y0 = jnp.array([1.62,0.6])
    r=2.5
    K=2.
    n=0.8
    omega=0.5
    miu=4.
    nu=1.5
    def f(t, y, args):
        # El tamaño de y es (L,2)
        T = y[0]
        E = y[1]
        x0 = r*T*(1-(T/K))-(n*T*E)
        x1 = omega+miu*T*E-nu*E
        return jnp.stack([x0, x1], axis=-1)

    solver = diffrax.Tsit5()
    dt0 = 0.01
    saveat = diffrax.SaveAt(ts=ts)
    sol = diffrax.diffeqsolve(
        diffrax.ODETerm(f), solver, ts[0], ts[-1], dt0, y0, saveat=saveat
    )
    ys = sol.ys
    return ys


def get_data(dataset_size, *, key):
    ts = jnp.linspace(0, 10, 100)
    key = jr.split(key, dataset_size)
    ys = jax.vmap(lambda key: _get_data(ts, key=key))(key)
    return ys, ts

In [30]:
def dataloader(arrays, batch_size, *, key):
    dataset_size = arrays[0].shape[0]
    assert all(array.shape[0] == dataset_size for array in arrays)
    indices = jnp.arange(dataset_size)
    while True:
        perm = jr.permutation(key, indices)
        (key,) = jr.split(key, 1)
        start = 0
        end = batch_size
        while end < dataset_size:
            batch_perm = perm[start:end]
            yield tuple(array[batch_perm] for array in arrays)
            start = end
            end = start + batch_size

In [31]:
def main(
    dataset_size=10000,
    batch_size=32,
    lr_strategy=(3e-3, 3e-3),
    steps_strategy=(1000, 1000),
    hidden_size=16,
    depth=1,
    seed=5678,
    print_every=100,
):
    key = jr.PRNGKey(seed)
    loader_key, data_key, model_key = jr.split(key, 3)

    ts, ys = get_data(dataset_size, key=data_key)

    # Data size is 2 in this example.
    _, length_size, data_size = ys.shape
    model = NN(in_size=data_size, out_size=data_size, hidden_size=hidden_size, key=model_key)

    @eqx.filter_value_and_grad
    def grad_loss(model, ti, yi):
        print(f"yi: {yi}")
        y_pred = jax.vmap(model, in_axes=(None, 0))(ti, yi[:, 0, :])
        return jnp.mean((yi - y_pred) ** 2)

    @eqx.filter_jit
    def make_step(ti, yi, model, opt_state):
        loss, grads = grad_loss(model, ti, yi)
        updates, opt_state = optim.update(grads, opt_state)
        model = eqx.apply_updates(model, updates)
        return loss, model, opt_state

    for lr, steps in zip(lr_strategy, steps_strategy):
        optim = optax.adam(lr)
        opt_state = optim.init(model)
        _ts = ts[: int(length_size)]
        _ys = ys[:, : int(length_size)]
        for step, (yi,) in zip(
            range(steps), dataloader((_ys,), batch_size, key=loader_key)
        ):
            start = time.time()
            loss, model, opt_state = make_step(_ts, yi, model, opt_state)
            end = time.time()
            if (step % print_every) == 0 or step == steps - 1:
                print(f"Step: {step}, Loss: {loss}, Computation time: {end - start}")


    return ts, ys, model

In [32]:
ts, ys, model = main()

ValueError: not enough values to unpack (expected 3, got 1)