# [Train RNN - equinox](https://docs.kidger.site/equinox/examples/train_rnn/)

In [117]:
import math
import time
from typing import Sequence, Optional

import diffrax
import equinox as eqx  # https://github.com/patrick-kidger/equinox
import jax
import jax.lax as lax
import jax.nn as jnn
import jax.numpy as jnp
import jax.random as jr
import jax.scipy as jsp
from jaxtyping import Array, PRNGKeyArray
import matplotlib
import matplotlib.pyplot as plt
import optax  # https://github.com/deepmind/optax

from tools._dataset.datasets import make_2dspiral_dataset
from tools._dataset.dataloader import dataloader

In [123]:
def default_floating_dtype():
    if jax.config.jax_enable_x64:  # pyright: ignore
        return jnp.float64
    else:
        return jnp.float32

class RNNCell(eqx.Module, strict=True):
    """A single step of a Recurrent Neural Network (RNN).

    !!! example

        This is often used by wrapping it into a `jax.lax.scan`. For example:

        ```python
        class Model(Module):
            cell: RNNCell

            def __init__(self, **kwargs):
                self.cell = RNNCell(**kwargs)

            def __call__(self, xs):
                scan_fn = lambda state, input: (self.cell(input, state), None)
                init_state = jnp.zeros(self.cell.hidden_size)
                final_state, _ = jax.lax.scan(scan_fn, init_state, xs)
                return final_state
        ```
    """

    weight_ih: Array
    weight_hh: Array
    bias: Optional[Array]
    input_size: int = eqx.field(static=True)
    hidden_size: int = eqx.field(static=True)
    use_bias: bool = eqx.field(static=True)

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        use_bias: bool = True,
        dtype=None,
        *,
        key: PRNGKeyArray,
    ):
        """**Arguments:**

        - `input_size`: The dimensionality of the input vector at each time step.
        - `hidden_size`: The dimensionality of the hidden state passed along between
            time steps.
        - `use_bias`: Whether to add on a bias after each update.
        - `dtype`: The dtype to use for all weights and biases in this GRU cell.
            Defaults to either `jax.numpy.float32` or `jax.numpy.float64` depending on
            whether JAX is in 64-bit mode.
        - `key`: A `jax.random.PRNGKey` used to provide randomness for parameter
            initialisation. (Keyword only argument.)
        """
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.use_bias = use_bias
        
        dtype = default_floating_dtype() if dtype is None else dtype
        ihkey, hhkey, bkey = jr.split(key, 3)
        lim = math.sqrt(1 / hidden_size)

        self.weight_ih = jr.uniform(
            ihkey, (hidden_size, input_size), minval=-lim, maxval=lim, dtype=dtype
        )
        self.weight_hh = jr.uniform(
            hhkey, (hidden_size, hidden_size), minval=-lim, maxval=lim, dtype=dtype
        )
        if use_bias:
            self.bias = jr.uniform(
                bkey, (hidden_size,), minval=-lim, maxval=lim, dtype=dtype
            )
        else:
            self.bias = None

    @jax.jit
    def __call__(
        self, input: Array, hidden: Array, *, key: Optional[PRNGKeyArray] = None
    ):
        """**Arguments:**

        - `input`: The input, which should be a JAX array of shape `(input_size,)`.
        - `hidden`: The hidden state, which should be a JAX array of shape
            `(hidden_size,)`.
        - `key`: Ignored; provided for compatibility with the rest of the Equinox API.
            (Keyword only argument.)

        **Returns:**

        The updated hidden state, which is a JAX array of shape `(hidden_size,)`.
        """
        if self.use_bias:
            bias = self.bias
        else:
            bias = 0
        h_i = self.weight_ih @ input
        h_h = self.weight_hh @ hidden
        new = jnn.tanh(h_i + h_h + bias)
        print(new.shape)
        return new

In [124]:
class GRUCell(eqx.Module, strict=True):
    """A single step of a Gated Recurrent Unit (GRU).

    !!! example

        This is often used by wrapping it into a `jax.lax.scan`. For example:

        ```python
        class Model(Module):
            cell: GRUCell

            def __init__(self, **kwargs):
                self.cell = GRUCell(**kwargs)

            def __call__(self, xs):
                scan_fn = lambda state, input: (self.cell(input, state), None)
                init_state = jnp.zeros(self.cell.hidden_size)
                final_state, _ = jax.lax.scan(scan_fn, init_state, xs)
                return final_state
        ```
    """

    weight_ih: Array
    weight_hh: Array
    bias: Optional[Array]
    bias_n: Optional[Array]
    input_size: int = eqx.field(static=True)
    hidden_size: int = eqx.field(static=True)
    use_bias: bool = eqx.field(static=True)

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        use_bias: bool = True,
        dtype=None,
        *,
        key: PRNGKeyArray,
    ):
        """**Arguments:**

        - `input_size`: The dimensionality of the input vector at each time step.
        - `hidden_size`: The dimensionality of the hidden state passed along between
            time steps.
        - `use_bias`: Whether to add on a bias after each update.
        - `dtype`: The dtype to use for all weights and biases in this GRU cell.
            Defaults to either `jax.numpy.float32` or `jax.numpy.float64` depending on
            whether JAX is in 64-bit mode.
        - `key`: A `jax.random.PRNGKey` used to provide randomness for parameter
            initialisation. (Keyword only argument.)
        """
        dtype = default_floating_dtype() if dtype is None else dtype
        ihkey, hhkey, bkey, bkey2 = jr.split(key, 4)
        lim = math.sqrt(1 / hidden_size)

        self.weight_ih = jr.uniform(
            ihkey, (3 * hidden_size, input_size), minval=-lim, maxval=lim, dtype=dtype
        )
        self.weight_hh = jr.uniform(
            hhkey, (3 * hidden_size, hidden_size), minval=-lim, maxval=lim, dtype=dtype
        )
        if use_bias:
            self.bias = jr.uniform(
                bkey, (3 * hidden_size,), minval=-lim, maxval=lim, dtype=dtype
            )
            self.bias_n = jr.uniform(
                bkey2, (hidden_size,), minval=-lim, maxval=lim, dtype=dtype
            )
        else:
            self.bias = None
            self.bias_n = None

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.use_bias = use_bias

    @jax.named_scope("eqx.nn.GRUCell")
    def __call__(
        self, input: Array, hidden: Array, *, key: Optional[PRNGKeyArray] = None
    ):
        """**Arguments:**

        - `input`: The input, which should be a JAX array of shape `(input_size,)`.
        - `hidden`: The hidden state, which should be a JAX array of shape
            `(hidden_size,)`.
        - `key`: Ignored; provided for compatibility with the rest of the Equinox API.
            (Keyword only argument.)

        **Returns:**

        The updated hidden state, which is a JAX array of shape `(hidden_size,)`.
        """
        if self.use_bias:
            bias = self.bias
            bias_n = self.bias_n
        else:
            bias = 0
            bias_n = 0
        igates = jnp.split(self.weight_ih @ input + bias, 3) # 単純に効率化のために一括で行列積を取って後からsplitしている.
        hgates = jnp.split(self.weight_hh @ hidden, 3)
        reset = jnn.sigmoid(igates[0] + hgates[0])
        inp = jnn.sigmoid(igates[1] + hgates[1])
        new = jnn.tanh(igates[2] + reset * (hgates[2] + bias_n))
        new = new + inp * (hidden - new)
        print(new.shape)
        return new

In [131]:
class RNNLayer(eqx.Module):
    cell: eqx.Module # eqx.nn.GRUCell --> eqx.Module
    dropout: eqx.Module
    
    def __init__(self, in_size: int, out_size: int, dropout: float = 0.0, *, key: PRNGKeyArray):
        ckey, dkey = jr.split(key, 2)
        self.cell = GRUCell(in_size, out_size, key=ckey) # eqx.nn.GRUCell --> RNNCell
        self.dropout = eqx.nn.Dropout(dropout)
        
    def __call__(self, hidden: Array, input: Array, key: Optional[jax.random.PRNGKey] = None):
        def _f(carry, xs):
            carry = self.cell(xs, carry)
            return carry, carry
        _, outputs = lax.scan(_f, hidden, input)
        outputs = self.dropout(outputs, key=key)
        return outputs[-1], outputs

class RNN(eqx.Module):
    hidden_size: int
    initial: eqx.nn.MLP
    dropout: eqx.Module
    rnns: Sequence[eqx.Module]
    linear: eqx.nn.Linear
    bias: jax.Array

    def __init__(self, in_size: int, out_size: int, hidden_size: int, width_size: int, depth: int, dropout: float = 0.0, *, key: PRNGKeyArray):
        ikey, gkey, lkey = jr.split(key, 3)
        self.hidden_size = in_size * hidden_size
        self.initial = eqx.nn.MLP(in_size, self.hidden_size, width_size, depth, key=ikey)
        self.dropout = eqx.nn.Dropout(dropout)
        
        self.rnns = []
        for _ in jnp.arange(depth):
            gkey, _gkey = jr.split(gkey, 2)
            self.rnns.append(RNNLayer(self.hidden_size, self.hidden_size, key=_gkey))
        
        self.linear = eqx.nn.Linear(self.hidden_size, out_size, use_bias=False, key=lkey)
        self.bias = jnp.zeros(out_size)

    def __call__(self, input: Array, key: Optional[jax.random.PRNGKey] = None):
        hidden = jnp.zeros((self.hidden_size,))
        
        outputs = jax.vmap(self.initial)(input) # 系列長(時間)次元に vmap
        outputs = self.dropout(outputs, key=key)

        for _rnn in self.rnns:
            hidden, outputs = _rnn(hidden, outputs)
            
        hidden = self.linear(hidden)
        # sigmoid because we're performing binary classification
        probs = jax.nn.sigmoid(hidden + self.bias)
        return probs

# Train and Eval

In [132]:
def main(
    dataset_size=4096,
    length=100,
    out_size=1,
    add_noise=False,
    batch_size=32,
    lr=3e-3,
    steps=200,
    hidden_size=8,
    width_size=128,
    depth=1,
    seed=5678,
):
    key = jr.PRNGKey(seed)
    train_data_key, test_data_key, model_key, loader_key, train_model_key, test_model_key = jr.split(key, 6)
    
    ts, ys, _, labels, in_size = make_2dspiral_dataset(
        dataset_size, length, add_noise, key=train_data_key
    )

    model = RNN(in_size, out_size, hidden_size, width_size, depth, dropout=0.2, key=model_key)
    #jax.debug.print("model: {}", model)
    # Training loop like normal.
    
    @eqx.filter_jit
    def loss(model: eqx.Module, x: Array, y: Array, key: PRNGKeyArray):
        batched_keys = jax.random.split(key, num=x.shape[0])
        y = lax.expand_dims(y, dimensions=[1])
        pred = jax.vmap(model)(x, batched_keys) # バッチ次元にvmap
        # Binary cross-entropy
        #jax.debug.print("pred: {}, y: {}", pred.shape, y.shape)
        #jax.debug.print("pred: {}, y: {}", pred, y)
        bxe = y * jnp.log(pred) + (1 - y) * jnp.log(1 - pred)
        bxe = -jnp.mean(bxe)
        acc = jnp.mean((pred > 0.5) == (y == 1))
        return bxe, acc

    grad_loss = eqx.filter_value_and_grad(loss, has_aux=True)

    @eqx.filter_jit
    def make_step(model, data_i, opt_state, key):
        key, new_key = jax.random.split(key)
        x, y  = data_i
        (bxe, acc), grads = grad_loss(model, x, y, key)
        updates, opt_state = optim.update(grads, opt_state, model)
        model = eqx.apply_updates(model, updates)
        return bxe, acc, model, opt_state, new_key

    model = eqx.nn.inference_mode(model)
    optim = optax.adam(lr)
    #opt_state = optim.init(model)
    opt_state = optim.init(eqx.filter(model, eqx.is_inexact_array)) # こっちにすると必要な勾配計算が途切れて学習されない.
    
    for step, data_i in zip(
        range(steps), dataloader((ys, labels), batch_size, key=loader_key)
    ):
        start = time.time()
        bxe, acc, model, opt_state, train_model_key = make_step(model, data_i, opt_state, key=train_model_key)
        end = time.time()
        print(
            f"Step: {step}, Loss: {bxe.item()}, Accuracy: {acc}, Computation time: "
            f"{end - start}"
        )
        
    model = eqx.nn.inference_mode(model, value=False)
    ts, ys, _, labels, _ = make_2dspiral_dataset(dataset_size, length, add_noise, key=test_data_key)
    bxe, acc = loss(model, ys, labels, key=test_model_key)
    print(f"Test loss: {bxe}, Test Accuracy: {acc}")

In [133]:
eqx.clear_caches()
jax.clear_caches()
main() 

(24,)
Step: 0, Loss: 0.7072348594665527, Accuracy: 0.46875, Computation time: 0.3663449287414551
Step: 1, Loss: 0.75129234790802, Accuracy: 0.375, Computation time: 0.004785776138305664
Step: 2, Loss: 0.6932094097137451, Accuracy: 0.4375, Computation time: 0.004559040069580078
Step: 3, Loss: 0.6912630796432495, Accuracy: 0.53125, Computation time: 0.0046541690826416016
Step: 4, Loss: 0.6814514398574829, Accuracy: 0.59375, Computation time: 0.0046939849853515625
Step: 5, Loss: 0.6781452894210815, Accuracy: 0.59375, Computation time: 0.004791975021362305
Step: 6, Loss: 0.7136348485946655, Accuracy: 0.46875, Computation time: 0.004964113235473633
Step: 7, Loss: 0.7135650515556335, Accuracy: 0.46875, Computation time: 0.004450798034667969
Step: 8, Loss: 0.6365460157394409, Accuracy: 0.75, Computation time: 0.004570960998535156
Step: 9, Loss: 0.6766389608383179, Accuracy: 0.59375, Computation time: 0.00452876091003418
Step: 10, Loss: 0.6384806632995605, Accuracy: 0.71875, Computation time: 

- `class RNNLayer` の `self.cell` を `RNNCell` 　にすると失敗。 `GRUCell` なら動く。