For regular chunked stochastic forward simulation, as with the MPR backend, we would create the buffer and step through it, overwriting the noise samples with the computed time series, something like this:

In [34]:
import jax.numpy as np
import jax
from jax.test_util import check_grads

# Toy delayed network: nn nodes and nt steps to simulate. The history length nh
# is chosen so that nh + t - lengths never indexes before column 0.
nn = 84
nt = 100
keys = jax.random.split(jax.random.PRNGKey(0), 10)

# Coupling weights are |N(0, 1)| + 0.5, so every entry is at least 0.5.
weights = np.abs(jax.random.normal(keys[0], (nn, nn))) + 0.5

# Integer delays (in steps) drawn from [0, 255).
# NOTE(review): a delay of 0 makes `run` read column nh + t, i.e. the noise
# sample that step t is about to overwrite -- confirm zero delays are intended.
lengths = jax.random.randint(keys[1], (nn, nn), 0, 255)
nh = lengths.max() + 1

# Noise buffer: nh history columns followed by the nt columns to be simulated.
buffer = jax.random.normal(keys[2], (nn, nh + nt))

# Source-node index for every (target, source) pair: ns[i, j] == j.
ns = np.tile(np.arange(nn), (nn, 1))

In [23]:
@jax.jit
def run(buffer):
    """Simulate ``nt`` steps, overwriting the buffer's noise columns.

    For each step, column ``nh + step`` is replaced by the mean over sources
    ``j`` of ``weights[i, j] * buffer[j, nh + step - lengths[i, j]]``, so each
    node ``i`` averages its weighted, delayed inputs. The Python loop is
    unrolled while tracing. Reads the module-level ``nt``, ``nh``, ``ns``,
    ``weights`` and ``lengths``.
    """
    # TODO convert to scan?
    for step in range(nt):
        column = nh + step
        delayed = buffer[ns, column - lengths]
        buffer = buffer.at[:, column].set(np.mean(weights * delayed, axis=1))
    return buffer


# Reference result of the loop version, reused as the loss target below.
b1 = run(buffer)
%timeit run(buffer)

117 µs ± 1.13 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [24]:
def loss(b):
    """Sum of squared differences between the reference ``b1`` and ``run(b)``."""
    residual = b1 - run(b)
    return np.sum(np.square(residual))


# Jit-compiled gradient of the loss w.r.t. the input buffer; display its shape.
gloss = jax.jit(jax.grad(loss))
gloss(b1).shape

(84, 265)

In [25]:
%timeit gloss(b1+1)

570 µs ± 3.02 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [22]:
# n=1 run 16us & gloss 52us
# n=10 run 118us & gloss 590us
# n=100 run 1150 us & gloss 32700us

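The tricky bit with a scan is the carry: the whole buffer has to be passed along as the carry at every step, because each step reads delayed values from it and writes a new column into it.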

In [26]:
jax.lax.scan?

In [45]:
def make_run(nt, *, unroll=1):
    """Build a ``jax.lax.scan`` version of the delayed-coupling simulation.

    Parameters
    ----------
    nt : int
        Number of time steps to simulate.
    unroll : int, optional (keyword-only)
        Forwarded to ``jax.lax.scan``: how many scan iterations are unrolled
        into each iteration of the underlying loop. The default of 1 keeps the
        plain rolled scan; larger values move toward the fully unrolled
        Python-loop version, which may reduce per-step overhead at the cost
        of compile time.

    Returns
    -------
    op : callable
        The single-step scan body ``op(buffer, t) -> (buffer, x_t)``.
    run : callable
        Jitted ``run(buffer)`` applying ``op`` for ``t = 0 .. nt - 1`` and
        returning the final buffer.

    Notes
    -----
    ``op`` reads the module-level ``weights``, ``lengths``, ``ns`` and ``nh``
    when it is called, not when ``make_run`` is called.
    """
    def op(buffer, t):
        # Weighted mean of each node's delayed inputs, written into column
        # nh + t; the buffer is the scan carry, x_t the per-step output.
        x_t = np.mean(weights * buffer[ns, nh + t - lengths], axis=1)
        buffer = buffer.at[:, nh + t].set(x_t)
        return buffer, x_t

    @jax.jit
    def run(buffer):
        # The stacked per-step outputs are discarded: every x_t is already
        # stored in the carried buffer.
        final_buffer, _ = jax.lax.scan(op, buffer, np.arange(nt), unroll=unroll)
        return final_buffer

    return op, run

# Build the scan-based simulator for the same nt as the loop version and
# display the shape of the buffer it returns.
op, run2 = make_run(nt)
run2(buffer).shape
%timeit run2(buffer)

1.32 ms ± 6.91 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [38]:
def loss2(b):
    """Same squared-error loss against ``b1``, but through the scan-based ``run2``."""
    return np.square(b1 - run2(b)).sum()


# Jit-compiled gradient through the scan; display its shape.
gloss2 = jax.jit(jax.grad(loss2))
gloss2(b1).shape
%timeit gloss2(b1+1)

5.03 ms ± 10.5 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


The gradient is about 4x slower than the forward pass (5.03 ms vs 1.32 ms). Could it be any better with a custom VJP? Well, let's look at what it's actually doing:

In [51]:
def f(b):
    """Sum of squares of the buffer after a single scan step at ``t = 0``."""
    next_buffer, _ = op(b, 0)
    return np.sum(np.square(next_buffer))


# Print the jaxpr of the gradient to inspect the primitives jax.grad emits.
jax.make_jaxpr(jax.grad(f))(b1)

{ lambda a:i32[] b:i32[84,84] c:i32[84,84] d:f32[84,84]; e:f32[84,265]. let
    f:i32[] = add a 0
    g:i32[84,84] = sub f b
    h:bool[84,84] = lt c 0
    i:i32[84,84] = add c 84
    j:i32[84,84] = select_n h c i
    k:bool[84,84] = lt g 0
    l:i32[84,84] = add g 265
    m:i32[84,84] = select_n k g l
    n:i32[84,84,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(84, 84, 1)
    ] j
    o:i32[84,84,1] = broadcast_in_dim[
      broadcast_dimensions=(0, 1)
      shape=(84, 84, 1)
    ] m
    p:i32[84,84,2] = concatenate[dimension=2] n o
    q:f32[84,84] = gather[
      dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0, 1), start_index_map=(0, 1))
      fill_value=None
      indices_are_sorted=False
      mode=GatherScatterMode.PROMISE_IN_BOUNDS
      slice_sizes=(1, 1)
      unique_indices=False
    ] e p
    r:f32[84,84] = mul d q
    s:f32[84] = reduce_sum[axes=(1,)] r
    t:f32[84] = div s 84.0
    u:i32[] = add a 0
    v:bool[] = 

This is interesting because it tells us that computing the gradient here involves a gather, scatter, scatter, gather sequence. Couldn't some of that work be kept from the forward pass? In any case, the jaxpr is useful for understanding what clever transformations like `jax.grad` are doing.