# RNN LM with Flax and Redex combinators

The [Redex](https://github.com/manifest/redex) library makes it possible to design [Flax](https://github.com/google/flax) layers and models using combinators.

In [1]:
import logging
import jax
from jax import numpy as jnp
from flax import linen as nn
from flax.linen.recurrent import Array
from redex import combinator as cb
from flax_extra import operator as op
from flax_extra import random

Combinators provide a simple and concise way of composing functional code. They can compose, be used by, or be mixed with Flax Linen modules, other combinators, and standard Python functions.

In [2]:
# Type alias for the recurrent carry threaded through an LSTM: a pair of arrays
# (presumably the cell state and hidden state — confirm against Flax's LSTMCell).
LSTMState = tuple[Array, Array]

class LSTMCell(nn.LSTMCell):
    """`nn.LSTMCell` with an explicitly annotated call signature.

    The computation is delegated unchanged to the parent class; this subclass
    only adds type annotations to `__call__`.

    NOTE(review): presumably the annotations are what lets Redex combinators
    infer how many inputs and outputs the cell has — confirm against the Redex
    documentation.
    """

    def __call__(self, carry: LSTMState, inputs: Array) -> tuple[LSTMState, Array]:
        """Run one step of the parent cell and return its `(carry, outputs)` result."""
        step_result = super().__call__(carry, inputs)
        return step_result

class LSTM(nn.Module):
    """LSTM layer that scans an `LSTMCell` along axis 1 of its input.

    Axis 0 of the input is treated as the batch axis (its size is used to build
    the initial carry) and axis 1 is the axis the cell is scanned over.
    """

    # Size handed to `nn.LSTMCell.initialize_carry` when building the initial carry.
    d_hidden: int

    @nn.compact
    def __call__(self, inputs: Array) -> Array:
        """Apply the layer to `inputs` and return the per-step outputs.

        Args:
            inputs: array whose axis 0 is the batch and axis 1 is scanned over.

        Returns:
            Whatever remains after the pipeline's final `drop` stage.
            NOTE(review): presumably the stacked cell outputs with the final
            carry discarded — confirm against Redex's `drop` semantics.
        """
        # Stage 1: feed the same input to `initial_state` and to `identity`, so
        # the next stage receives both the initial carry and the raw inputs.
        with_initial_state = cb.branch(self.initial_state, op.identity)

        # Stage 2: the cell lifted with `nn.scan`. Axis 1 is scanned on the way
        # in and on the way out; "params" is broadcast and its RNG is not split.
        scanned_cell = nn.scan(
            LSTMCell,
            variable_broadcast="params",
            split_rngs={"params": False},
            in_axes=1,
            out_axes=1,
        )()

        # Stage 3: drop two leading values (presumably the two arrays of the
        # final carry — TODO confirm).
        drop_carry = cb.drop(n_in=2)

        pipeline = cb.serial(with_initial_state, scanned_cell, drop_carry)
        return pipeline(inputs)

    def initial_state(self, inputs: Array) -> LSTMState:
        """Build the initial carry for a batch sized like axis 0 of `inputs`.

        Draws from the "carry" RNG stream, so callers must supply that stream.
        """
        batch_dims = (inputs.shape[0],)
        carry_rng = self.make_rng("carry")
        return nn.LSTMCell.initialize_carry(carry_rng, batch_dims, self.d_hidden)

class RNNLM(nn.Module):
    """Recurrent language model assembled from Redex combinators.

    The pipeline is: shift right along axis 1, embed, a stack of `LSTM` layers,
    dropout, and a dense projection with `vocab_size` features.
    """

    # Number of embeddings and number of output features of the final Dense layer.
    vocab_size: int
    # Embedding width; also used as `d_hidden` of every LSTM layer.
    d_model: int = 512
    # Number of stacked LSTM layers.
    n_layers: int = 2
    # Rate passed to `nn.Dropout`.
    dropout_rate: float = 0.1
    # Passed to `nn.Dropout` as its `deterministic` flag.
    deterministic: bool = True

    @nn.compact
    def __call__(self, inputs: Array) -> Array:
        """Run the model on `inputs` and return the output of the final Dense layer.

        Args:
            inputs: array that is shifted along axis 1 and then embedded
                (presumably integer token ids of shape (batch, length) — TODO confirm).
        """
        # Layers are created in pipeline order so that Flax's automatic
        # submodule naming is not affected by this layout.
        shift = op.ShiftRight(axis=1)
        embed = nn.Embed(
            num_embeddings=self.vocab_size,
            features=self.d_model,
            embedding_init=nn.initializers.normal(stddev=1.0),
        )
        recurrent_stack = [LSTM(d_hidden=self.d_model) for _ in range(self.n_layers)]
        dropout = nn.Dropout(
            rate=self.dropout_rate,
            deterministic=self.deterministic,
        )
        project = nn.Dense(features=self.vocab_size)

        language_model = cb.serial(shift, embed, recurrent_stack, dropout, project)
        return language_model(inputs)

In [3]:
# Toy-sized hyperparameters for the demonstration.
BATCH_SIZE = 1
MAX_LENGTH = 4
VOCAB_SIZE = 3
D_MODEL = 2

model = RNNLM(vocab_size=VOCAB_SIZE, d_model=D_MODEL)

# JIT-compiled entry points for initializing and applying the model.
model_apply = jax.jit(model.apply)
model_init = jax.jit(model.init)

# Dummy batch of ones with shape (BATCH_SIZE, MAX_LENGTH).
sample = jnp.ones((BATCH_SIZE, MAX_LENGTH), dtype=int)

# Source of RNG keys; `next(rnkeyg)` is used below to obtain a key per call.
rnkeyg = random.sequence(seed=1356)



With debug logging enabled, we can inspect how data flows through the combinators.

In [4]:
# Switch the root logger to DEBUG for the duration of model initialization so
# that the combinators' call traces are emitted.
logging.getLogger().setLevel("DEBUG")
try:
    initial_variables = model_init(
        rngs=random.into_collection(
            key=next(rnkeyg),
            labels=["params", "carry", "dropout"],
        ),
        inputs=sample,
    )
finally:
    # Always drop back to INFO, even when initialization raises; otherwise a
    # failed run would leave DEBUG logging enabled for every later cell.
    logging.getLogger().setLevel("INFO")

DEBUG:absl:Compiling _split (6376033984) for args (ShapedArray(uint32[2]),).
DEBUG:absl:Compiling _split (6375976128) for args (ShapedArray(uint32[2]),).
DEBUG:root:constrained_call :: ShiftRight           stack_size=1  signature=Signature(n_in=1, n_out=1, start_index=0, in_shape=((),))
DEBUG:root:constrained_call :: Embed                stack_size=1  signature=Signature(n_in=1, n_out=1, start_index=0, in_shape=((),))
DEBUG:root:constrained_call :: LSTM                 stack_size=1  signature=Signature(n_in=1, n_out=1, start_index=0, in_shape=((),))
DEBUG:root:constrained_call :: Serial               stack_size=1  signature=Signature(n_in=1, n_out=3, start_index=0, in_shape=((),))
DEBUG:root:constrained_call :: Select               stack_size=1  signature=Signature(n_in=1, n_out=2, start_index=0, in_shape=((),))
DEBUG:root:constrained_call :: Parallel             stack_size=2  signature=Signature(n_in=2, n_out=3, start_index=0, in_shape=((), ()))
DEBUG:root:constrained_call :: initial_

In [5]:
# Fresh RNG collection for this call, keyed by the same stream labels that were
# used at initialization.
apply_rngs = random.into_collection(
    key=next(rnkeyg),
    labels=["params", "carry", "dropout"],
)

# Apply the initialized model to the sample batch; the result is displayed as
# the cell output.
model_apply(initial_variables, inputs=sample, rngs=apply_rngs)

DeviceArray([[[-0.00954835, -0.00654768, -0.00827007],
              [-0.01224634, -0.01415828, -0.01595778],
              [-0.01219137, -0.01792459, -0.01944371],
              [-0.01135414, -0.01874325, -0.02001233]]], dtype=float32)