# [Train RNN - equinox](https://docs.kidger.site/equinox/examples/train_rnn/)

In [1]:
import math
import time
from typing import Tuple, Sequence, Optional

import diffrax
import equinox as eqx  # https://github.com/patrick-kidger/equinox
import jax
import jax.lax as lax
import jax.nn as jnn
import jax.numpy as jnp
import jax.random as jr
import jax.scipy as jsp
from jaxtyping import Array, Float, PRNGKeyArray
import matplotlib
import matplotlib.pyplot as plt
import optax  # https://github.com/deepmind/optax

from tools._dataset.datasets import SpiralDataset
from tools._dataset.dataloader import dataloader
from tools._model.discrete_cde import RNN

class RNN(eqx.Module):
    """Sequence classifier built around a discrete CDE layer.

    Pipeline (see ``__call__``): the first observation is mapped to an initial
    hidden state by an MLP, the whole sequence is fed through ``self.rnn``, the
    final state is projected by a bias-free linear layer plus a separate bias,
    and a sigmoid turns the logits into probabilities.

    NOTE(review): ``DiscreteCDELayer`` is not imported in the visible code (only
    ``RNN`` is imported from ``tools._model.discrete_cde``, and this class
    shadows that import). Presumably it comes from the same module or from an
    earlier notebook cell -- confirm and import it explicitly.
    """

    hidden_size: int
    initial: eqx.nn.MLP  # maps the first observation xs[0, :] to the initial hidden state
    rnn: DiscreteCDELayer
    linear: eqx.nn.Linear  # created with use_bias=False; the bias lives in ``bias``
    bias: jax.Array  # output bias, zero-initialised

    def __init__(self, in_size: int, out_size: int, hidden_size: int, width_size: int, depth: int, *, key: PRNGKeyArray):
        """Build the sub-modules.

        Args:
            in_size: feature dimension of each sequence element.
            out_size: number of outputs (probabilities) produced.
            hidden_size: size of the recurrent hidden state.
            width_size: width forwarded to the MLP and to ``DiscreteCDELayer``.
            depth: depth forwarded to the MLP and to ``DiscreteCDELayer``.
            key: PRNG key, split three ways for the sub-module initialisers.
        """
        ikey, gkey, lkey = jr.split(key, 3)
        self.hidden_size = hidden_size
        self.initial = eqx.nn.MLP(in_size, hidden_size, width_size, depth, key=ikey)
        self.rnn = DiscreteCDELayer(in_size, hidden_size, width_size, depth, key=gkey)
        self.linear = eqx.nn.Linear(hidden_size, out_size, use_bias=False, key=lkey)
        self.bias = jnp.zeros(out_size)

    def __call__(self, xs: Array, key: Optional[PRNGKeyArray] = None) -> Array:
        """Classify a single (un-batched) sequence.

        Args:
            xs: the input sequence; indexed as ``xs[0, :]``, so it is at least
                2-D (presumably ``(length, in_size)`` -- confirm against the
                dataset).
            key: currently unused; kept so a dropout key can be threaded
                through later without changing the signature.

        Returns:
            ``sigmoid(linear(yT) + bias)``, probabilities of shape ``(out_size,)``.
        """
        x0 = xs[0, :]
        y0 = self.initial(x0)
        # Only the final state is needed; the second output of self.rnn is unused.
        yT, _ = self.rnn(y0, xs)
        logits = self.linear(yT)
        # Sigmoid because this model performs binary classification.
        return jax.nn.sigmoid(logits + self.bias)

# Train and Eval

In [4]:
def main(
    dataset_size=4096,
    length=100,
    out_size=1,
    add_noise=False,
    batch_size=32,
    lr=3e-3,
    steps=100,
    hidden_size=8,
    width_size=128,
    depth=1,
    seed=5678,
):
    """Train ``RNN`` on ``SpiralDataset`` with Adam, then report held-out metrics.

    Args:
        dataset_size: number of samples generated for the training set and,
            separately, for the test set.
        length: forwarded to ``SpiralDataset``.
        out_size: model output size (1 for the binary task scored below).
        add_noise: forwarded to ``SpiralDataset``.
        batch_size: minibatch size passed to ``dataloader``.
        lr: Adam learning rate.
        steps: number of optimisation steps.
        hidden_size: forwarded to ``RNN`` as the hidden-state size.
        width_size: forwarded to ``RNN``.
        depth: forwarded to ``RNN``.
        seed: seed from which every PRNG key below is derived.
    """
    key = jr.PRNGKey(seed)
    train_data_key, test_data_key, model_key, loader_key = jr.split(key, 4)

    train_dataset = SpiralDataset(dataset_size, length, add_noise, key=train_data_key)
    _, ys, _, labels, in_size = train_dataset.make_dataset()

    model = RNN(in_size, out_size, hidden_size, width_size, depth, key=model_key)

    # The optimiser is created before ``make_step``, which closes over it.
    optim = optax.adam(lr)
    opt_state = optim.init(eqx.filter(model, eqx.is_inexact_array))

    @eqx.filter_jit
    def loss(model: eqx.Module, x: Array, y: Array, *, key: Optional[PRNGKeyArray] = None) -> Tuple[Array, Array]:
        """Return ``(mean binary cross-entropy, accuracy)`` of ``model`` on batch ``x``."""
        # NOTE: when dropout is used, split ``key`` into one key per sample and
        # call jax.vmap(model)(x, batched_keys) instead.
        preds = jax.vmap(model)(x)

        # Binary cross-entropy; add a trailing axis so labels broadcast against preds.
        y = lax.expand_dims(y, dimensions=[1])
        bxe = y * jnp.log(preds) + (1 - y) * jnp.log(1 - preds)
        bxe = -jnp.mean(bxe)
        acc = jnp.mean((preds > 0.5) == (y == 1))
        return bxe, acc

    grad_loss = eqx.filter_value_and_grad(loss, has_aux=True)

    @eqx.filter_jit
    def make_step(model: eqx.Module, data: Tuple[Array, ...], opt_state: Tuple, key: Optional[PRNGKeyArray] = None) -> Tuple[Array, Array, eqx.Module, Tuple]:
        """Apply one Adam update on ``data = (x, y)``.

        Returns ``(loss, accuracy, updated_model, updated_opt_state)``.
        ``key`` is currently unused (reserved for dropout).
        """
        x, y = data
        (bxe, acc), grads = grad_loss(model, x, y)
        updates, opt_state = optim.update(grads, opt_state, model)
        model = eqx.apply_updates(model, updates)
        return bxe, acc, model, opt_state

    # Train in the default (training) mode; inference mode is only enabled for
    # evaluation below.
    for step, data_i in zip(
        range(steps), dataloader((ys, labels), batch_size, key=loader_key)
    ):
        start = time.time()
        bxe, acc, model, opt_state = make_step(model, data_i, opt_state)
        # JAX dispatches work asynchronously: wait for the result so the timing
        # measures the computation (step 0 also includes compilation).
        bxe.block_until_ready()
        end = time.time()
        print(
            f"Step: {step}, Loss: {bxe.item()}, Accuracy: {acc}, Computation time: "
            f"{end - start}"
        )

    # Evaluate on a separately generated test set (test_data_key), with layers
    # such as dropout switched to inference mode.
    model = eqx.nn.inference_mode(model)
    test_dataset = SpiralDataset(dataset_size, length, add_noise, key=test_data_key)
    _, test_ys, _, test_labels, _ = test_dataset.make_dataset()
    test_bxe, test_acc = loss(model, test_ys, test_labels)
    print(f"Test loss: {test_bxe}, Test Accuracy: {test_acc}")

In [5]:
# Clear Equinox's internal caches, then JAX's compilation/staging caches, so
# this run of main() starts from a clean slate; then train and evaluate.
for _clear_cache in (eqx.clear_caches, jax.clear_caches):
    _clear_cache()
main()

Step: 0, Loss: 2.6379456520080566, Accuracy: 0.46875, Computation time: 0.3254389762878418
Step: 1, Loss: 0.9420516490936279, Accuracy: 0.4375, Computation time: 0.004773855209350586
Step: 2, Loss: 0.9228380918502808, Accuracy: 0.28125, Computation time: 0.00456690788269043
Step: 3, Loss: 0.6698753237724304, Accuracy: 0.40625, Computation time: 0.004745960235595703
Step: 4, Loss: 0.6057009696960449, Accuracy: 0.59375, Computation time: 0.004636049270629883
Step: 5, Loss: 0.5476295948028564, Accuracy: 0.59375, Computation time: 0.004507780075073242
Step: 6, Loss: 0.44034063816070557, Accuracy: 0.625, Computation time: 0.0046901702880859375
Step: 7, Loss: 0.39380085468292236, Accuracy: 1.0, Computation time: 0.00453495979309082
Step: 8, Loss: 0.37781476974487305, Accuracy: 1.0, Computation time: 0.004513978958129883
Step: 9, Loss: 0.3344845771789551, Accuracy: 1.0, Computation time: 0.00455927848815918
Step: 10, Loss: 0.28384533524513245, Accuracy: 1.0, Computation time: 0.00453400611877