# [Train RNN - equinox](https://docs.kidger.site/equinox/examples/train_rnn/)

In [1]:
import math
import time
from typing import Sequence, Optional

import diffrax
import equinox as eqx  # https://github.com/patrick-kidger/equinox
import jax
import jax.lax as lax
import jax.nn as jnn
import jax.numpy as jnp
import jax.random as jr
import jax.scipy as jsp
from jaxtyping import Array, PRNGKeyArray
import matplotlib
import matplotlib.pyplot as plt
import optax  # https://github.com/deepmind/optax

from tools._dataset.datasets import make_2dspiral_dataset
from tools._dataset.dataloader import dataloader
from tools._modules.discrete_cde import DiscreteCDELayer

In [2]:
class RNN(eqx.Module):
    """Sequence classifier built around a ``DiscreteCDELayer``.

    The first row of the input sequence is mapped by an MLP to an initial
    hidden state, the recurrent layer is run over the whole sequence, and the
    value it returns first is passed through a bias-free linear layer, a
    separate bias array and a sigmoid.
    """

    hidden_size: int
    initial: eqx.nn.MLP
    rnn: DiscreteCDELayer
    linear: eqx.nn.Linear
    bias: jax.Array

    def __init__(self, in_size: int, out_size: int, hidden_size: int, width_size: int, depth: int, *, key: PRNGKeyArray):
        """Build the sub-modules.

        Args:
            in_size: Size of one row of the input sequence.
            out_size: Number of output units.
            hidden_size: Size of the hidden state.
            width_size: Width passed to the MLP and to the recurrent layer.
            depth: Depth passed to the MLP and to the recurrent layer.
            key: PRNG key; split three ways to initialise the MLP, the
                recurrent layer and the linear readout.
        """
        ikey, gkey, lkey = jr.split(key, 3)
        self.hidden_size = hidden_size
        # Maps the first observation to the initial hidden state.
        self.initial = eqx.nn.MLP(in_size, hidden_size, width_size, depth, key=ikey)

        self.rnn = DiscreteCDELayer(in_size, hidden_size, width_size, depth, key=gkey)

        # The readout bias lives in its own array (initialised to zero), so the
        # linear layer itself is created without one.
        self.linear = eqx.nn.Linear(hidden_size, out_size, use_bias=False, key=lkey)
        self.bias = jnp.zeros(out_size)

    def __call__(self, xs: Array, key: Optional[PRNGKeyArray] = None):
        """Return sigmoid probabilities for one (unbatched) sequence.

        Args:
            xs: Input sequence; indexed as ``xs[0, :]`` to obtain the first
                row (assumes layout ``(length, in_size)`` -- TODO confirm).
            key: Currently unused; accepted so that a PRNG key can be threaded
                through (e.g. if dropout is added later).

        Returns:
            ``sigmoid(linear(yT) + bias)`` where ``yT`` is the first value
            returned by the recurrent layer (presumably the final hidden
            state -- verify against ``DiscreteCDELayer``).
        """
        x0 = xs[0, :]
        y0 = self.initial(x0)
        # Only the first return value is needed for classification.
        yT, _ = self.rnn(y0, xs)
        logits = self.linear(yT)
        # Sigmoid because we're performing binary classification.
        probs = jax.nn.sigmoid(logits + self.bias)
        return probs

# Train and Eval

In [3]:
def main(
    dataset_size=4096,
    length=100,
    out_size=1,
    add_noise=False,
    batch_size=32,
    lr=3e-3,
    steps=100,
    hidden_size=8,
    width_size=128,
    depth=1,
    seed=5678,
):
    """Train the ``RNN`` classifier and print per-step and test metrics.

    Args:
        dataset_size: Passed to ``make_2dspiral_dataset`` for both the training
            and the test set.
        length: Passed to ``make_2dspiral_dataset`` (presumably the sequence
            length -- verify against the dataset helper).
        out_size: Number of model outputs.
        add_noise: Passed to ``make_2dspiral_dataset``.
        batch_size: Batch size handed to ``dataloader``.
        lr: Learning rate for Adam.
        steps: Maximum number of optimisation steps.
        hidden_size: Hidden-state size of the model.
        width_size: Width passed to the model's sub-networks.
        depth: Depth passed to the model's sub-networks.
        seed: Seed for the root PRNG key, which is split into independent keys
            for training data, test data, model initialisation and the loader.

    Returns:
        None. Progress and final test metrics are printed.
    """
    key = jr.PRNGKey(seed)
    train_data_key, test_data_key, model_key, loader_key = jr.split(key, 4)

    _, ys, _, labels, in_size = make_2dspiral_dataset(
        dataset_size, length, add_noise, key=train_data_key
    )

    model = RNN(in_size, out_size, hidden_size, width_size, depth, key=model_key)
    # Created before the closures below that use it.
    optim = optax.adam(lr)

    @eqx.filter_jit
    def loss(model: eqx.Module, x: Array, y: Array, key: Optional[PRNGKeyArray] = None):
        """Return ``(binary cross-entropy, accuracy)`` for a batch.

        ``key`` is currently unused; when using dropout, split it per example
        and call ``jax.vmap(model)(x, batched_keys)`` instead.
        """
        y = lax.expand_dims(y, dimensions=[1])
        pred = jax.vmap(model)(x)
        # A saturated sigmoid yields exactly 0 or 1, and 0 * log(0) is NaN, so
        # keep the probabilities strictly inside (0, 1) before taking logs.
        eps = jnp.finfo(pred.dtype).eps
        safe_pred = jnp.clip(pred, eps, 1 - eps)
        # Binary cross-entropy
        bxe = y * jnp.log(safe_pred) + (1 - y) * jnp.log(1 - safe_pred)
        bxe = -jnp.mean(bxe)
        # Accuracy is computed from the unclipped predictions.
        acc = jnp.mean((pred > 0.5) == (y == 1))
        return bxe, acc

    grad_loss = eqx.filter_value_and_grad(loss, has_aux=True)

    @eqx.filter_jit
    def make_step(model: eqx.Module, data_i, opt_state, key: Optional[PRNGKeyArray] = None):
        """Run one optimiser update; return ``(bxe, acc, model, opt_state)``.

        ``key`` is currently unused (reserved for stochastic layers).
        """
        x, y = data_i
        (bxe, acc), grads = grad_loss(model, x, y)
        updates, opt_state = optim.update(grads, opt_state, model)
        model = eqx.apply_updates(model, updates)
        return bxe, acc, model, opt_state

    # Train with inference mode switched OFF so that any mode-dependent layers
    # behave as in training.
    model = eqx.nn.inference_mode(model, value=False)
    opt_state = optim.init(eqx.filter(model, eqx.is_inexact_array))

    for step, data_i in zip(
        range(steps), dataloader((ys, labels), batch_size, key=loader_key)
    ):
        start = time.perf_counter()
        bxe, acc, model, opt_state = make_step(model, data_i, opt_state)
        # JAX dispatches asynchronously; wait for the result so the timing
        # covers the computation and not just the dispatch.
        bxe.block_until_ready()
        end = time.perf_counter()
        print(
            f"Step: {step}, Loss: {bxe.item()}, Accuracy: {acc}, Computation time: "
            f"{end - start}"
        )

    # Evaluate with inference mode switched ON.
    model = eqx.nn.inference_mode(model)
    _, test_ys, _, test_labels, _ = make_2dspiral_dataset(
        dataset_size, length, add_noise, key=test_data_key
    )
    bxe, acc = loss(model, test_ys, test_labels)
    print(f"Test loss: {bxe}, Test Accuracy: {acc}")

In [4]:
# Reset compilation caches before the run: Equinox's internal caches first,
# then JAX's. NOTE(review): the order is kept deliberately -- the Equinox docs
# recommend clearing its caches before calling jax.clear_caches(); confirm.
eqx.clear_caches()
jax.clear_caches()
# Train and evaluate with the default hyperparameters.
main() 

Step: 0, Loss: 2.6379456520080566, Accuracy: 0.46875, Computation time: 0.3438549041748047
Step: 1, Loss: 0.9420516490936279, Accuracy: 0.4375, Computation time: 0.004872798919677734
Step: 2, Loss: 0.9228380918502808, Accuracy: 0.28125, Computation time: 0.0046160221099853516
Step: 3, Loss: 0.6698753237724304, Accuracy: 0.40625, Computation time: 0.004331111907958984
Step: 4, Loss: 0.6057009696960449, Accuracy: 0.59375, Computation time: 0.0044176578521728516
Step: 5, Loss: 0.5476295948028564, Accuracy: 0.59375, Computation time: 0.004488945007324219
Step: 6, Loss: 0.44034063816070557, Accuracy: 0.625, Computation time: 0.004310131072998047
Step: 7, Loss: 0.39380085468292236, Accuracy: 1.0, Computation time: 0.004458904266357422
Step: 8, Loss: 0.37781476974487305, Accuracy: 1.0, Computation time: 0.004428863525390625
Step: 9, Loss: 0.3344845771789551, Accuracy: 1.0, Computation time: 0.0045070648193359375
Step: 10, Loss: 0.28384533524513245, Accuracy: 1.0, Computation time: 0.004412889