# [Train RNN - equinox](https://docs.kidger.site/equinox/examples/train_rnn/)

In [2]:
import math
import time
from typing import Optional, Sequence, Tuple

import diffrax
import equinox as eqx  # https://github.com/patrick-kidger/equinox
import jax
import jax.lax as lax
import jax.nn as jnn
import jax.numpy as jnp
import jax.random as jr
import jax.scipy as jsp
from jaxtyping import Array, Float, PRNGKeyArray
import matplotlib
import matplotlib.pyplot as plt
import optax  # https://github.com/deepmind/optax

from tools._dataset.datasets import SpiralDataset
from tools._dataset.dataloader import dataloader
# NOTE(review): DiscreteCDELayer is used by the RNN class defined in this notebook but was
# never imported; it is assumed to live next to RNN in this module — confirm.
from tools._model.discrete_cde import DiscreteCDELayer, RNN
from tools._loss.cross_entropy import bce_loss

class RNN(eqx.Module):
    """Binary sequence classifier built around a ``DiscreteCDELayer``.

    The first observation of a sequence is mapped to an initial hidden state by
    an MLP, the ``DiscreteCDELayer`` is then run over the sequence starting from
    that state, and a bias-free linear readout plus a separately stored bias is
    squashed through a sigmoid to give a probability.

    NOTE(review): this class shadows the ``RNN`` name imported from
    ``tools._model.discrete_cde`` at the top of the notebook.
    """

    # Size of the hidden state produced by `initial` and read by `linear`.
    hidden_size: int
    # MLP mapping the first observation (in_size) to the initial hidden state.
    initial: eqx.nn.MLP
    # Recurrent layer called as `rnn(y0, xs)`.
    rnn: DiscreteCDELayer
    # Readout from the hidden state to `out_size` logits (created without a bias).
    linear: eqx.nn.Linear
    # Readout bias, kept as its own array and initialised to zeros.
    bias: jax.Array

    def __init__(self, in_size: int, out_size: int, hidden_size: int, width_size: int, depth: int, *, key: PRNGKeyArray):
        """Build the sub-modules.

        Args:
            in_size: Number of features per observation.
            out_size: Number of output logits.
            hidden_size: Size of the hidden state.
            width_size: Width passed to the initial MLP and to the recurrent layer.
            depth: Depth passed to the initial MLP and to the recurrent layer.
            key: PRNG key; split three ways, one sub-key per parameterised sub-module.
        """
        ikey, gkey, lkey = jr.split(key, 3)
        self.hidden_size = hidden_size
        self.initial = eqx.nn.MLP(in_size, hidden_size, width_size, depth, key=ikey)
        self.rnn = DiscreteCDELayer(in_size, hidden_size, width_size, depth, key=gkey)
        self.linear = eqx.nn.Linear(hidden_size, out_size, use_bias=False, key=lkey)
        self.bias = jnp.zeros(out_size)

    def __call__(self, xs: Array, key: Optional[PRNGKeyArray] = None) -> Array:
        """Return the predicted probability for one sequence.

        Args:
            xs: A single sequence indexed as ``xs[step, feature]``; its first row
                is fed to the initial MLP, so the feature axis has size ``in_size``.
            key: Accepted for call-signature compatibility; not used here.

        Returns:
            ``sigmoid(linear(yT) + bias)``, an array of ``out_size`` values in (0, 1).
        """
        x0 = xs[0, :]
        y0 = self.initial(x0)
        # The layer returns a pair; only its first element is used for the readout.
        # NOTE(review): presumably (final state, all states) — confirm against
        # DiscreteCDELayer.
        yT, _ = self.rnn(y0, xs)
        logits = self.linear(yT)
        # sigmoid because we're performing binary classification
        return jax.nn.sigmoid(logits + self.bias)

# Train and evaluate

In [3]:
def main(
    dataset_size=4096,
    length=100,
    out_size=1,
    add_noise=False,
    batch_size=32,
    lr=3e-3,
    steps=100,
    hidden_size=8,
    width_size=128,
    depth=1,
    seed=5678,
):
    """Train the RNN classifier on a spiral dataset and print test metrics.

    Args:
        dataset_size: Passed to ``SpiralDataset`` for both the train and test sets.
        length: Passed to ``SpiralDataset`` (presumably the sequence length — confirm).
        out_size: Number of model outputs.
        add_noise: Passed to ``SpiralDataset``.
        batch_size: Batch size handed to ``dataloader``.
        lr: Learning rate for ``optax.adam``.
        steps: Number of optimisation steps.
        hidden_size: Hidden-state size of the model.
        width_size: Width passed to the model.
        depth: Depth passed to the model.
        seed: Seed for the root PRNG key from which every other key is split.

    Returns:
        None. Per-step loss/accuracy/timing and the final test loss/accuracy
        are printed.
    """
    key = jr.PRNGKey(seed)
    train_data_key, test_data_key, model_key, loader_key = jr.split(key, 4)

    dataset = SpiralDataset(dataset_size, length, add_noise, key=train_data_key)
    ts, _, coeffs, labels, in_size = dataset.make_dataset()

    model = RNN(in_size, out_size, hidden_size, width_size, depth, key=model_key)
    # Explicitly switch any mode-dependent layers to training behaviour.
    model = eqx.nn.inference_mode(model, value=False)

    # The optimiser is created before the step function that closes over it.
    optim = optax.adam(lr)
    opt_state = optim.init(eqx.filter(model, eqx.is_inexact_array))

    # bce_loss returns (loss, aux); has_aux=True keeps the aux out of the gradient.
    grad_loss = eqx.filter_value_and_grad(bce_loss, has_aux=True)

    @eqx.filter_jit
    def make_step(model: eqx.Module, data: Sequence[Array], opt_state: Tuple, *args, key: PRNGKeyArray) -> Tuple[Float, Float, eqx.Module, Tuple]:
        """Run one optimisation step; returns (loss, accuracy, model, opt_state)."""
        ts, labels, *coeffs = data
        (bxe, acc), grads = grad_loss(model, (ts, coeffs), labels, *args, key=key)
        # Hand the optimiser the same filtered parameter tree it was initialised
        # with, instead of the whole model including its non-array leaves.
        params = eqx.filter(model, eqx.is_inexact_array)
        updates, opt_state = optim.update(grads, opt_state, params)
        model = eqx.apply_updates(model, updates)
        return bxe, acc, model, opt_state

    # One sub-key drives the data loader; a fresh sub-key is split off per step.
    loader_key, batch_key = jr.split(loader_key, 2)
    batches = dataloader((ts, labels) + coeffs, batch_size, key=batch_key)
    for step, data in zip(range(steps), batches):
        loader_key, step_key = jr.split(loader_key, 2)
        start = time.perf_counter()
        bxe, acc, model, opt_state = make_step(model, data, opt_state, key=step_key)
        # JAX dispatches asynchronously: wait for the results so the timer
        # measures the computation rather than just the dispatch.
        jax.block_until_ready((bxe, acc))
        end = time.perf_counter()
        print(
            f"Step: {step}, Loss: {bxe}, Accuracy: {acc}, Computation time: "
            f"{end - start}"
        )

    # Evaluate on a freshly generated dataset with the model in inference mode.
    model = eqx.nn.inference_mode(model)
    test_data_key, inference_key = jr.split(test_data_key, 2)
    dataset = SpiralDataset(dataset_size, length, add_noise, key=test_data_key)
    ts, _, coeffs, labels, _ = dataset.make_dataset()
    bxe, acc = bce_loss(model, (ts, coeffs), labels, key=inference_key)
    print(f"Test loss: {bxe}, Test Accuracy: {acc}")

In [4]:
# Clear Equinox's and JAX's caches before the run.
# NOTE(review): presumably so that re-running this cell recompiles from scratch — confirm.
eqx.clear_caches()
jax.clear_caches()
# Train and evaluate with the default hyperparameters.
main() 

Step: 0, Loss: 2.630345344543457, Accuracy: 0.46875, Computation time: 0.5296788215637207
Step: 1, Loss: 0.6006730794906616, Accuracy: 0.6875, Computation time: 0.4188117980957031
Step: 2, Loss: 1.2619365453720093, Accuracy: 0.5, Computation time: 0.4128890037536621
Step: 3, Loss: 0.687018632888794, Accuracy: 0.53125, Computation time: 0.0051801204681396484
Step: 4, Loss: 0.5568324327468872, Accuracy: 0.5, Computation time: 0.005151987075805664
Step: 5, Loss: 0.8020071387290955, Accuracy: 0.4375, Computation time: 0.004984140396118164
Step: 6, Loss: 0.6486579179763794, Accuracy: 0.46875, Computation time: 0.005139827728271484
Step: 7, Loss: 0.4808724522590637, Accuracy: 0.9375, Computation time: 0.005090951919555664
Step: 8, Loss: 0.3821321725845337, Accuracy: 1.0, Computation time: 0.005079984664916992
Step: 9, Loss: 0.4018373489379883, Accuracy: 0.875, Computation time: 0.005011081695556641
Step: 10, Loss: 0.4957374930381775, Accuracy: 0.625, Computation time: 0.004930019378662109
St