In this notebook I benchmark the implementation of the CWRNN vs. the implementation of a standard RNN.
The CWRNN should avoid some computations, but the sparse multiplication logic could  actually slow down the execution, compared to a simple matvec for computinh $h_{t+1}$ in the standard RNN. 
I'll execute two experiments
- unbatched small inputs
- batched big inputs

In [12]:
import sys

# sys.path.append("..")
import jax
import jax.random as jr
import numpy as np
import equinox as eqx  # pytree utilities
from rnn_jax.cells import ElmanRNNCell, ClockWorkRNNCell
from rnn_jax.layers import RNNEncoder
import matplotlib.pyplot as plt
import timeit

In [20]:
rnn = RNNEncoder(ElmanRNNCell(1, 64, key=jr.key(0)))
cwrnn = RNNEncoder(
    ClockWorkRNNCell(
        1, [16, 16, 16, 16], [2, 4, 8, 16], nonlinearity=jax.nn.relu, key=jr.key(0)
    )
)
x = jr.uniform(jr.key(1), (256, 1))
jit_rnn = eqx.filter_jit(rnn)
jit_rnn(x)
jit_cwrnn = eqx.filter_jit(cwrnn)
jit_cwrnn(x)
print("rnn:")
%timeit jit_rnn(x)
print("cw-rnn:")
%timeit jit_cwrnn(x)

rnn:
85.1 μs ± 2.03 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
cw-rnn:
244 μs ± 10.2 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


CW-RNN is considerably slower ($\approx$ 3x) in the unbatched case

In [18]:
x = jr.uniform(jr.key(1), (256, 100, 10))
rnn = RNNEncoder(ElmanRNNCell(10, 1024, key=jr.key(0)))
cwrnn = RNNEncoder(
    ClockWorkRNNCell(
        10, [256 // 4] * 4, [4, 16, 64, 256], nonlinearity=jax.nn.relu, key=jr.key(0)
    )
)
vmap_rnn = eqx.filter_vmap(rnn)
vmap_rnn(x)
vmap_cwrnn = eqx.filter_vmap(cwrnn)
vmap_cwrnn(x)
print("rnn:")
%timeit vmap_cwrnn(x)
print("cw-rnn:")
%timeit vmap_cwrnn(x)

rnn:
98.7 ms ± 1.54 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
cw-rnn:
97.7 ms ± 1.45 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


When batching, cwrnn catches up in speed. Probably optimizing more the logic should make it go faster than a standard RNN. 