### Speed Gains Part 2: Optimizing updates for vmap

- introduce the current limitation of `update`
- show how to manually override it
- test on a variant of the model from cookbook 3

In [None]:
import jax
import jax.numpy as jnp

import genjax
from genjax import ChoiceMapBuilder as C
from genjax import (
    StaticRequest,
    Update,
    gen,
    normal,
    pretty,
)
from genjax._src.core.pytree import Const

pretty()
key = jax.random.PRNGKey(0)

As discussed in the previous cookbook entries, a main purpose of `update` is incremental computation: `update` algebraically simplifies the logpdf ratios that make up the weight it returns. Changes are tracked through the `Diff` system.

A limitation of the current automation is that if an address "x" has a tensor value and any index of "x" changes, the system treats all of "x" as changed, without recording which indices actually changed.

However, we can manually specify how something has changed in a more specific way.
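
To see why a finer description of the change matters, here is a minimal sketch in plain JAX rather than GenJAX. It looks only at the contribution of an address like "a" (independent standard normals) to the weight and ignores downstream terms such as "obs". The helper names `full_log_ratio` and `indexed_log_ratio` are hypothetical and are not part of any library.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


def full_log_ratio(old_a, new_a):
    # Treat the whole vector as changed: evaluate the log-density of
    # every coordinate, old and new, and sum the differences.
    return jnp.sum(norm.logpdf(new_a, 0.0, 1.0) - norm.logpdf(old_a, 0.0, 1.0))


def indexed_log_ratio(old_a, idx, new_value):
    # Only coordinate `idx` changed, so all other terms cancel in the
    # ratio and a single pair of log-density evaluations is enough.
    return norm.logpdf(new_value, 0.0, 1.0) - norm.logpdf(old_a[idx], 0.0, 1.0)


key = jax.random.PRNGKey(0)
old_a = jax.random.normal(key, (10_000,))

# Change a single entry of the vector (values chosen for illustration).
idx, new_value = 3, 1.0
new_a = old_a.at[idx].set(new_value)

full = full_log_ratio(old_a, new_a)
indexed = indexed_log_ratio(old_a, idx, new_value)
print(full, indexed)
```

Both functions return the same log-ratio up to floating-point error, but the first touches every coordinate while the second touches only one. This is the kind of saving that an index-level description of the change makes possible.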

In [None]:
@gen
def model(size_model: Const[int]):
    size_model = size_model.unwrap()
    x = normal(0.0, 1.0) @ "x"
    a = normal.vmap()(jnp.zeros(size_model), jnp.ones(size_model)) @ "a"
    b = normal.vmap()(jnp.zeros(size_model), jnp.ones(size_model)) @ "b"
    c = normal.vmap()(jnp.zeros(size_model), jnp.ones(size_model)) @ "c"
    obs = normal(jnp.sum(a) + jnp.sum(b) + jnp.sum(c) + x, 5.0) @ "obs"
    return obs

Let's create a trace from our model.

In [None]:
obs = C["obs"].set(
    1.0,
)
size_model = 10000
args = (Const(size_model),)
key, subkey = jax.random.split(key)
tr, _ = model.importance(subkey, obs, args)

tr.subtraces[3].inner

Let's first see an equivalent way to do what `update` does.
Just as `update` generalizes `importance`, there is an even more general interface, `edit`, which generalizes `update`.

We will go into the details of `edit` in a follow-up cookbook.
For now, let's see the equivalent of `update` using `edit`. For this, we introduce a `Request` to change the trace.
`edit` then answers the `Request` and changes the trace following the logic of the request.
To mimic `update`, we will perform an `Update` request.

In [None]:
change_in_value_for_a = jnp.ones(size_model)

# usual update
constraints = C["a"].set(change_in_value_for_a)
argdiffs = genjax.Diff.no_change(args)
key, subkey = jax.random.split(key)
new_tr1, _, _, _ = tr.update(subkey, constraints, argdiffs)

# update using `Request`
val = C.v(change_in_value_for_a)
request = StaticRequest({"a": Update(val)})
key, subkey = jax.random.split(key)
new_tr2, _, _, _ = request.edit(subkey, tr, args)

# comparing the values of both choicemaps after the update
jax.tree_util.tree_all(
    jax.tree.map(jnp.allclose, new_tr1.get_choices(), new_tr2.get_choices())
)

Now let's see how we can efficiently change the value of "a" at a specific index. 
For that, we create a more specific `Request` called an `IndexRequest`. This request expects another request for what to do at the given index.

In [None]:
# request = StaticRequest({
#         "a": IndexRequest(jnp.array(3), Update(C.v(1.)))})

# key, subkey = jax.random.split(key)
# new_tr, _, _, _ = request.edit(subkey, tr, args)


# jnp.sum(new_tr.get_choices()["a"] == 1.0)

Finally, let's compare both options.