### Speed Gains Part 3: Optimizing updates for scan

In the previous entry, we saw how to use `IndexRequest` to perform a more localized `update` in a `vmap` model. This is always sound in a `vmap` model because, by construction, the variables in different slices are independent (conditioned on the variables outside the `vmap`).

Another kind of vectorized model is `scan`. The values of `x` at iterations $i$ and $i+1$ are both stored in the same tensor at the traced address `x`. In a general state-space model represented with a `scan`, changing the value of `x` at iteration $i$ can affect all downstream computations. Therefore, in general we need to recompute the logpdf of the edited value and of every value traced at iteration $i+1$ and beyond.

By default, GenJAX takes the conservative route and recomputes the logpdf of all traced values, even those before $i$.
There is a special case of interest, however, where we can do better. In hidden Markov models (HMMs), the conditional probabilities simplify as $P(x_{i+1} \mid x_i, x_{i-1}) = P(x_{i+1} \mid x_i)$. That is, no information from past steps "leaks through" to the future: it is entirely captured by the previous step. A simple counter-example is a momentum term $\mu_i = \sum_{j \leq i} c_j x_j$ that affects the sampling of $x_{i+1}$.
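Under the Markov property, the joint density of the chain factorizes into one-step conditionals. As a short worked consequence, stated for the abstract chain $x_{0:n}$ rather than for any particular GenJAX model: if $x'$ differs from $x$ only at an interior index $0 < i < n$, every factor that does not mention $x_i$ cancels in the log-ratio.

$$
P(x_{0:n}) = P(x_0) \prod_{k=0}^{n-1} P(x_{k+1} \mid x_k)
$$

$$
\log \frac{P(x'_{0:n})}{P(x_{0:n})} = \log \frac{P(x'_i \mid x_{i-1})}{P(x_i \mid x_{i-1})} + \log \frac{P(x_{i+1} \mid x'_i)}{P(x_{i+1} \mid x_i)}
$$

With the momentum counter-example, by contrast, $x_i$ enters every later conditional through the momentum, so none of those factors cancel.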

In the HMM setting, if we change the value of $x_i$, we only need to recompute the logpdf of $x_i$ and $x_{i+1}$. A change to $x_i$ alters the distribution of $x_{i+1}$ but not its value, so the distributions of $x_{i+2}$ and beyond are unchanged, and so are their logpdfs.
This is exactly the structure that `IndexRequest` exploits on a `scan` model.
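To see this locality concretely without going through GenJAX, here is a minimal pure-JAX sketch. It is an illustration, not how `IndexRequest` is implemented, and the helper name `hmm_step_logpdfs` is hypothetical. It re-expresses the per-iteration log density of the HMM kernel defined below (`x ~ N(carry, 1)`, `y ~ N(0, 1)`, next carry `x + y`) with `jax.lax.scan`, edits `x` at index 3, and reports which per-iteration terms change.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


def hmm_step_logpdfs(xs, ys):
    """Hypothetical helper: per-iteration log density of the HMM kernel.

    x ~ N(carry, 1), y ~ N(0, 1), and the next carry is x + y.
    """

    def step(carry, xy):
        x, y = xy
        logpdf = norm.logpdf(x, carry, 1.0) + norm.logpdf(y, 0.0, 1.0)
        return x + y, logpdf

    # Start from carry = 0.0, as in the tutorial's simulate call.
    _, logpdfs = jax.lax.scan(step, jnp.zeros(()), (xs, ys))
    return logpdfs


key_x, key_y = jax.random.split(jax.random.key(0))
xs = jax.random.normal(key_x, (10,))
ys = jax.random.normal(key_y, (10,))

old = hmm_step_logpdfs(xs, ys)
new = hmm_step_logpdfs(xs.at[3].set(42.0), ys)

# Iterations whose log-density term changed after editing x at index 3.
changed = jnp.nonzero(old != new)[0].tolist()
print(changed)  # [3, 4]

# The total change in log density only needs the two affected terms.
total = jnp.sum(new - old)
local = jnp.sum(new[3:5] - old[3:5])
print(jnp.allclose(total, local))  # True
```

Terms 5 and beyond are bit-for-bit identical, because the carry entering iteration 5 is `x[4] + y[4]`, which the edit does not touch.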


In [None]:
import jax
import jax.numpy as jnp

import genjax
from genjax import ChoiceMapBuilder as C
from genjax import IndexRequest, StaticRequest, Update, gen, normal, pretty

key = jax.random.key(0)
pretty()

Here's a simple example of an HMM:

In [None]:
def f(x, y):
    return x + y


@gen
def kernel(carry):
    x = normal(carry, 1.0) @ "x"
    y = normal(0.0, 1.0) @ "y"
    return f(x, y)  # works in general for any deterministic function f


key, subkey = jax.random.split(key)
args = (jnp.array(0.0),)
hmm = kernel.iterate(n=10)
tr = hmm.simulate(subkey, args)
tr

Let's apply an `IndexRequest` and check that it matches what `update` computes.

In [None]:
request = IndexRequest(jnp.array(3), StaticRequest({"x": Update(C.v(42.0))}))

key, subkey = jax.random.split(key)
new_tr, w, _, _ = request.edit(subkey, tr, genjax.Diff.no_change(args))

key, subkey = jax.random.split(key)
constraint = C["x"].set(tr.get_choices()["x"].at[3].set(42.0))
argdiff = genjax.Diff.no_change(args)
new_tr1, w1, _, _ = tr.update(subkey, constraint, argdiff)

assert w == w1 and jax.tree_util.tree_all(
    jax.tree_util.tree_map(
        lambda x, y: jnp.all(x == y), new_tr.get_choices(), new_tr1.get_choices()
    )
)

Here's an example of a kernel from a general state-space model that is not an HMM because it has a momentum-like term.

In [None]:
def f(momentum, x, _):
    return 0.9 * momentum + x


@gen
def kernel(momentum):
    x = normal(momentum, 1.0) @ "x"
    y = normal(0.0, 1.0) @ "y"
    return f(momentum, x, y)


key, subkey = jax.random.split(key)
ssm = kernel.iterate(n=10)
# Reuse `args` so the argdiffs built from it below match this trace's arguments.
tr = ssm.simulate(subkey, args)
tr

Let's check that `IndexRequest` is rejected on this model. In fact, the system detects the violation and raises an error.

In [None]:
try:
    request = IndexRequest(jnp.array(3), StaticRequest({"x": Update(C.v(42.0))}))

    key, subkey = jax.random.split(key)
    new_tr, w, _, _ = request.edit(subkey, tr, genjax.Diff.no_change(args))
except AssertionError:
    print(
        "IndexRequest failed as expected - this model does not satisfy the required conditions"
    )
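The same kind of pure-JAX sketch, applied to the momentum kernel above, shows why the check fails. It is again only an illustration with a hypothetical helper name (`momentum_step_logpdfs`), not GenJAX's internal check: editing `x` at index 3 shifts the momentum carried into every later iteration, so every subsequent log-density term changes.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


def momentum_step_logpdfs(xs, ys):
    """Hypothetical helper: per-iteration log density of the momentum kernel.

    x ~ N(momentum, 1), y ~ N(0, 1), next momentum = 0.9 * momentum + x.
    """

    def step(momentum, xy):
        x, y = xy
        logpdf = norm.logpdf(x, momentum, 1.0) + norm.logpdf(y, 0.0, 1.0)
        # Every x is folded into the momentum and decays geometrically,
        # so it keeps influencing all later iterations.
        return 0.9 * momentum + x, logpdf

    _, logpdfs = jax.lax.scan(step, jnp.zeros(()), (xs, ys))
    return logpdfs


key_x, key_y = jax.random.split(jax.random.key(0))
xs = jax.random.normal(key_x, (10,))
ys = jax.random.normal(key_y, (10,))

old = momentum_step_logpdfs(xs, ys)
new = momentum_step_logpdfs(xs.at[3].set(42.0), ys)

# Iterations whose log-density term changed after editing x at index 3.
changed = jnp.nonzero(old != new)[0].tolist()
print(changed)  # [3, 4, 5, 6, 7, 8, 9]
```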

The condition under which `IndexRequest` is valid for a `scan` model is weaker than the HMM assumption. In general, the change must not impact the distributions of the variables two or more indices after the edited one. The HMM property guarantees this, since it requires every interaction between steps two apart to be mediated by the step in between. We don't necessarily need this global condition, however, as the following example shows.

In [None]:
@gen
def kernel(carry):
    _ = normal(carry, 1.0) @ "x"
    y = normal(0.0, 1.0) @ "y"
    return carry + y


key, subkey = jax.random.split(key)
non_hmm = kernel.iterate(n=10)
# Reuse `args` so the argdiffs built from it below match this trace's arguments.
tr = non_hmm.simulate(subkey, args)
tr

Let's check this for good measure.

In [None]:
request = IndexRequest(jnp.array(3), StaticRequest({"x": Update(C.v(42.0))}))

key, subkey = jax.random.split(key)
new_tr, w, _, _ = request.edit(subkey, tr, genjax.Diff.no_change(args))

key, subkey = jax.random.split(key)
constraint = C["x"].set(tr.get_choices()["x"].at[3].set(42.0))
argdiff = genjax.Diff.no_change(args)
new_tr1, w1, _, _ = tr.update(subkey, constraint, argdiff)

assert w == w1 and jax.tree_util.tree_all(
    jax.tree_util.tree_map(
        lambda x, y: jnp.all(x == y), new_tr.get_choices(), new_tr1.get_choices()
    )
)

In that last model, changing `y` at some index changes all the following carries, which in turn affects the distributions of all the following `x`, so the model is not an HMM.
However, if we only edit `x`, that particular variable satisfies the condition, and we can use `IndexRequest` here.
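The general condition can also be probed numerically. The sketch below uses a hypothetical helper, `edit_is_local`, which is not part of GenJAX and is not how GenJAX performs its own check: it evaluates the per-iteration log density of a kernel with `jax.lax.scan`, sets one choice at index `idx`, and tests whether any term at index `idx + 2` or later changes. Passing on one set of values is evidence, not a proof. Applied to a pure-JAX version of the last model, it accepts an edit to `x` and rejects an edit to `y`.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


def step_logpdfs(step, init, choices):
    """Run a (carry, choices) -> (carry, logpdf) step over every iteration."""
    _, logpdfs = jax.lax.scan(step, init, choices)
    return logpdfs


def edit_is_local(step, init, choices, name, idx, value):
    """Hypothetical check: after setting choices[name][idx] = value, do all
    log-density terms at index idx + 2 and beyond stay unchanged?"""
    edited = dict(choices)
    edited[name] = choices[name].at[idx].set(value)
    old = step_logpdfs(step, init, choices)
    new = step_logpdfs(step, init, edited)
    return bool(jnp.all(old[idx + 2 :] == new[idx + 2 :]))


def non_hmm_step(carry, c):
    # x ~ N(carry, 1) is traced but never fed back; y ~ N(0, 1) is
    # accumulated into the carry.
    logpdf = norm.logpdf(c["x"], carry, 1.0) + norm.logpdf(c["y"], 0.0, 1.0)
    return carry + c["y"], logpdf


key_x, key_y = jax.random.split(jax.random.key(0))
choices = {
    "x": jax.random.normal(key_x, (10,)),
    "y": jax.random.normal(key_y, (10,)),
}
init = jnp.zeros(())

print(edit_is_local(non_hmm_step, init, choices, "x", 3, 42.0))  # True
print(edit_is_local(non_hmm_step, init, choices, "y", 3, 42.0))  # False
```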

A more intricate example where `x` does affect the future but in a restricted way is the following:

In [None]:
@gen
def kernel(carry):
    a, b = carry
    x = normal(a, 1.0) @ "x"
    y = normal(b, 1.0) @ "y"
    return (x + y, b + y)


key, subkey = jax.random.split(key)
args = ((0.0, 0.0),)
fancy_non_hmm = kernel.iterate(n=10)
tr = fancy_non_hmm.simulate(subkey, args)
tr

Here, `x` does affect the future carry (its left component), but only through the distribution of the next `x`, whose value then overwrites that component, so the change doesn't leak any further. On the left part of the carry, the model behaves like an HMM. Thus we can use `IndexRequest(jnp.array(idx), Regenerate(S.at["x"]))`. We still can't use `IndexRequest(jnp.array(idx), Regenerate(S.at["y"]))`, for the same reason as in the previous example.
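A minimal pure-JAX sketch (illustrative only, with a hypothetical helper name `fancy_step_logpdfs`) makes the asymmetry between `x` and `y` visible. It mirrors the kernel above, with carry `(a, b)`, `x ~ N(a, 1)`, `y ~ N(b, 1)`, and next carry `(x + y, b + y)`, and lists which per-iteration log-density terms change after editing each variable at index 3.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


def fancy_step_logpdfs(xs, ys):
    """Hypothetical helper: per-iteration log density of the (a, b) kernel."""

    def step(carry, xy):
        a, b = carry
        x, y = xy
        logpdf = norm.logpdf(x, a, 1.0) + norm.logpdf(y, b, 1.0)
        # The left component is overwritten by x + y at every step, while
        # the right component keeps accumulating y.
        return (x + y, b + y), logpdf

    init = (jnp.zeros(()), jnp.zeros(()))
    _, logpdfs = jax.lax.scan(step, init, (xs, ys))
    return logpdfs


key_x, key_y = jax.random.split(jax.random.key(0))
xs = jax.random.normal(key_x, (10,))
ys = jax.random.normal(key_y, (10,))

old = fancy_step_logpdfs(xs, ys)
new_x = fancy_step_logpdfs(xs.at[3].set(42.0), ys)
new_y = fancy_step_logpdfs(xs, ys.at[3].set(42.0))

changed_x = jnp.nonzero(old != new_x)[0].tolist()
changed_y = jnp.nonzero(old != new_y)[0].tolist()
print(changed_x)  # [3, 4]: the leak stops after one step
print(changed_y)  # [3, 4, 5, 6, 7, 8, 9]: the leak never stops
```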

Let's check this for the final example.

In [None]:
request = IndexRequest(jnp.array(3), StaticRequest({"x": Update(C.v(42.0))}))

key, subkey = jax.random.split(key)
new_tr, w, _, _ = request.edit(subkey, tr, genjax.Diff.no_change(args))

key, subkey = jax.random.split(key)
constraint = C["x"].set(tr.get_choices()["x"].at[3].set(42.0))
argdiff = genjax.Diff.no_change(args)
new_tr1, w1, _, _ = tr.update(subkey, constraint, argdiff)

assert w == w1 and jax.tree_util.tree_all(
    jax.tree_util.tree_map(
        lambda x, y: jnp.all(x == y), new_tr.get_choices(), new_tr1.get_choices()
    )
)
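try:
    # Editing y at index 3 via IndexRequest should be rejected: y feeds the
    # right component of the carry, which propagates to all later steps.
    # (A plain `tr.update` would always succeed, so it cannot test this.)
    request = IndexRequest(jnp.array(3), StaticRequest({"y": Update(C.v(42.0))}))

    key, subkey = jax.random.split(key)
    new_tr2, w2, _, _ = request.edit(subkey, tr, genjax.Diff.no_change(args))
except AssertionError:
    print("IndexRequest failed as expected - y affects future carries")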