### I have a generative function with a single variable but 2000 observations, or I just want to use/apply it repeatedly. What do I do? [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ChiSym/genjax/blob/main/docs/cookbook/inactive/expressivity/iterating_computation.ipynb)

In [None]:
import sys

if "google.colab" in sys.modules:
    %pip install --quiet "genjax[genstudio]"

In [None]:
import jax
import jax.numpy as jnp

import genjax
from genjax import bernoulli, gen, pretty

key = jax.random.key(0)
pretty()

First, create a simple generative function:

In [None]:
@gen
def double_flip(p, q):
    v1 = bernoulli(p) @ "v1"
    v2 = bernoulli(q) @ "v2"
    return v1 + v2


Now we can create a vectorized version that takes a batch of `p` values and calls the function once for each value in the batch, reusing the same `q`. The `in_axes` argument tells the `vmap` combinator which arguments are mapped over and which are not: `0` means we map over the leading axis of that argument, and `None` means the argument is shared across the whole batch.

In [None]:
batched_double_flip = double_flip.vmap(in_axes=(0, None))
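The `in_axes` convention is the same one that `jax.vmap` uses. The following deterministic sketch uses a hypothetical `shift` function, not part of GenJAX. It shows that `in_axes=(0, None)` amounts to looping over the first argument while reusing the second:

```python
import jax
import jax.numpy as jnp


def shift(p, q):
    # Hypothetical deterministic toy function of two arguments.
    return p + q


ps = jnp.array([0.1, 0.2, 0.3])
# in_axes=(0, None): map over the leading axis of `ps`, share the scalar q.
mapped = jax.vmap(shift, in_axes=(0, None))(ps, 0.5)
# The same computation written as an explicit loop, then stacked.
looped = jnp.stack([shift(ps[i], 0.5) for i in range(ps.shape[0])])
print(mapped)
```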

Now we can use the batched version to generate a batch of samples. First, fix the batch size:

In [None]:
size_of_batch = 20


Next, we create a batch of `p` values. A single key is enough, because the `vmap` combinator splits it across the batch.

In [None]:
key, subkey = jax.random.split(key)
p = jax.random.uniform(subkey, (size_of_batch,))
q = 0.5

This runs the generative function once for `(p1, q)`, once for `(p2, q)`, and so on:

In [None]:
key, subkey = jax.random.split(key)
traces = batched_double_flip.simulate(subkey, (p, q))
traces.get_retval()
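The next sketch is a rough pure-JAX analogue of what the batched call computes. It does not use GenJAX and returns values only, with no traces. The function `double_flip_jax` is hypothetical, and it uses `jax.random.bernoulli` with a probability argument to stand in for the two Bernoulli draws. Like the combinator, it derives one key per batch element from a single key:

```python
import jax
import jax.numpy as jnp


def double_flip_jax(key, p, q):
    # Hypothetical pure-JAX stand-in for double_flip: the sum of two coin flips.
    k1, k2 = jax.random.split(key)
    v1 = jax.random.bernoulli(k1, p).astype(jnp.int32)
    v2 = jax.random.bernoulli(k2, q).astype(jnp.int32)
    return v1 + v2


batch = 20
demo_key = jax.random.PRNGKey(0)
p_key, flip_key = jax.random.split(demo_key)
demo_p = jax.random.uniform(p_key, (batch,))
demo_q = 0.5
# Derive one key per batch element from a single key.
flip_keys = jax.random.split(flip_key, batch)
# Map over the keys and p (axis 0); share q across the batch (None).
values = jax.vmap(double_flip_jax, in_axes=(0, 0, None))(flip_keys, demo_p, demo_q)
print(values)
```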

We can also call `double_flip` on `(p1, q1)`, `(p2, q2)`, and so on, by mapping over both arguments with `in_axes=(0, 0)`:

In [None]:
key, subkey = jax.random.split(key)
p = jax.random.uniform(subkey, (size_of_batch,))
key, subkey = jax.random.split(key)
q = jax.random.uniform(subkey, (size_of_batch,))
batched_double_flip_v2 = double_flip.vmap(in_axes=(0, 0))
key, subkey = jax.random.split(key)
traces = batched_double_flip_v2.simulate(subkey, (p, q))
traces.get_retval()

Note: all mapped arguments must have the same length along the mapped axis. Here `p` has 20 entries and `q` has 21, so the call fails:

In [None]:
try:
    key, subkey = jax.random.split(key)
    p = jax.random.uniform(subkey, (size_of_batch,))
    key, subkey = jax.random.split(key)
    q = jax.random.uniform(subkey, (size_of_batch + 1,))
    key, subkey = jax.random.split(key)
    traces = batched_double_flip_v2.simulate(subkey, (p, q))
    print(traces.get_retval())
except ValueError as e:
    print(e)
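Plain `jax.vmap` has the same restriction: every mapped axis must have the same length. The following minimal check uses only JAX. The exact wording of the error message may differ between JAX versions.

```python
import jax
import jax.numpy as jnp

a = jnp.ones(20)
b = jnp.ones(21)
try:
    # Both arguments mapped along axis 0, but their lengths differ (20 vs 21).
    jax.vmap(lambda x, y: x + y, in_axes=(0, 0))(a, b)
    raised = False
except ValueError as e:
    raised = True
    print(e)

# Mapping only `a` and sharing a scalar second argument works fine.
ok = jax.vmap(lambda x, y: x + y, in_axes=(0, None))(a, 2.0)
```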

What about iterating `vmap`, e.g. if we want to apply a generative function acting on a pixel over a 2D space?

In [None]:
image = jnp.zeros([300, 500], dtype=jnp.float32)

We first create a generative function that acts on a single pixel value.

In [None]:
@gen
def sample_pixel(pixel):
    new_pixel = genjax.normal(pixel, 1.0) @ "new_pixel"
    return new_pixel


key, subkey = jax.random.split(key)
tr = sample_pixel.simulate(subkey, (0.0,))
tr.get_choices()["new_pixel"]

Now what if we want to apply a generative function over a 2D space?

We can use a nested `vmap` combinator:

In [None]:
sample_image = sample_pixel.vmap(in_axes=(0,)).vmap(in_axes=(0,))
key, subkey = jax.random.split(key)
tr = sample_image.simulate(subkey, (image,))
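In plain JAX, the same nesting is written `jax.vmap(jax.vmap(f))`. The inner map runs over the columns of one row, and the outer map runs over the rows. The sketch below uses a hypothetical `noisy_pixel` function that returns samples without traces. It assumes one legacy `PRNGKey` per pixel:

```python
import jax
import jax.numpy as jnp


def noisy_pixel(key, pixel):
    # Hypothetical pure-JAX analogue of sample_pixel: pixel + N(0, 1) noise.
    return pixel + jax.random.normal(key)


H, W = 3, 4
demo_image = jnp.zeros((H, W), dtype=jnp.float32)
# One legacy uint32 key per pixel, arranged as an (H, W, 2) grid.
pixel_keys = jax.random.split(jax.random.PRNGKey(0), H * W).reshape(H, W, 2)
# Inner vmap maps over the columns of a row; outer vmap maps over the rows.
sampled = jax.vmap(jax.vmap(noisy_pixel))(pixel_keys, demo_image)
print(sampled)
```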

We can access the `new_pixel` value of any pixel in the sampled trace `tr` by indexing with its row, its column, and the address:

In [None]:
(
    tr.get_choices(),
    tr.get_choices()[0, 0, "new_pixel"],
    tr.get_choices()[299, 499, "new_pixel"],
)

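We can also wrap this model in a bigger model. Here we use a small `2 x 3` image, and the outer model returns the first row of the sampled image shifted by `p`.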

In [None]:
image = jnp.zeros([2, 3], dtype=jnp.float32)


@gen
def model(p):
    sampled_image = sample_image(image) @ "sampled_image"
    return sampled_image[0] + p


key, subkey = jax.random.split(key)
tr = model.simulate(subkey, (0.0,))
tr

We can use slices (`:`) to access the `new_pixel` value of every pixel in the image at once:

In [None]:
tr.get_choices()["sampled_image", :, :, "new_pixel"]

Alternatively, we can flatten the two dimensions into one, use a single `vmap` combinator, and reshape the result afterwards.
This can be more efficient in some cases and usually compiles faster.

In [None]:
sample_image_flat = sample_pixel.vmap(in_axes=(0,))
key, subkey = jax.random.split(key)
tr = sample_image_flat.simulate(subkey, (image.flatten(),))
# reshape the flat sample back to the original image shape
out_image = tr.get_choices()[:, "new_pixel"].reshape(image.shape)
out_image
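Flattening is safe because mapping over the flattened array and then reshaping gives the same result as the nested map. The following deterministic pure-JAX check uses a hypothetical per-pixel function, `brighten`:

```python
import jax
import jax.numpy as jnp


def brighten(pixel):
    # Hypothetical deterministic per-pixel function.
    return 2.0 * pixel + 1.0


demo_image = jnp.arange(6.0).reshape(2, 3)
# Nested map: rows (outer) and columns (inner).
nested = jax.vmap(jax.vmap(brighten))(demo_image)
# Single map over the row-major flattened pixels, then restore the 2D shape.
flat = jax.vmap(brighten)(demo_image.flatten()).reshape(demo_image.shape)
print(flat)
```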

But wait: now I have both `jax.vmap` and `genjax.vmap`. When do I use one vs. the other?

The rule of thumb is that `jax.vmap` should only be applied to deterministic code. In particular, `model.simulate` is deterministic for a given random key, which we control explicitly, so we can apply `jax.vmap` to it along the desired axes. However, a `@genjax.gen` function that makes addressed random choices (with `@`) should not be vmapped with `jax.vmap`. Use `model.vmap` (or, equivalently, `genjax.vmap`) instead.
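The important property is that a function of an explicit key is deterministic, which makes `jax.vmap` over a batch of keys safe. The following minimal pure-JAX illustration uses a hypothetical `draw` function in place of `model.simulate`:

```python
import jax
import jax.numpy as jnp


def draw(key):
    # Deterministic given `key`: the same key always yields the same sample.
    return jax.random.normal(key)


demo_keys = jax.random.split(jax.random.PRNGKey(0), 5)
samples = jax.vmap(draw)(demo_keys)
# Re-running with the same keys reproduces exactly the same samples.
again = jax.vmap(draw)(demo_keys)
print(samples)
```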

Oh, but my iteration is actually over time, not space: I want to reuse the same model by composing it with itself, e.g. for a Hidden Markov Model (HMM).

For this, we can use the `scan` combinator.

In [None]:
@gen
def hmm_kernel(x):
    z = genjax.normal(x, 1.0) @ "z"
    y = genjax.normal(z, 1.0) @ "y"
    return y


@genjax.scan(n=10)
@gen
def hmm(x, _):
    x1 = hmm_kernel(x) @ "x1"
    return x1, None
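The scanned function above follows the carry convention of `jax.lax.scan`. It takes `(carry, x)` and returns `(new_carry, output)`, which is why `hmm` returns `x1, None`. The following sketch reproduces the same HMM dynamics in pure JAX. It assumes a hypothetical `hmm_step` function and produces samples only, with no trace. It is not how GenJAX implements `scan`.

```python
import jax
import jax.numpy as jnp


def hmm_step(x, key):
    # One HMM step: z ~ N(x, 1), then y ~ N(z, 1); y becomes the next carry.
    kz, ky = jax.random.split(key)
    z = x + jax.random.normal(kz)
    y = z + jax.random.normal(ky)
    return y, (z, y)  # (new carry, per-step outputs stacked by scan)


step_keys = jax.random.split(jax.random.PRNGKey(0), 10)  # one key per step
x0 = jnp.array(0.0, dtype=jnp.float32)
final_x, (zs, ys) = jax.lax.scan(hmm_step, x0, step_keys)
print(zs, ys)
```

Unlike `hmm`, which outputs `None` at each step, this sketch stacks `(z, y)` so that the intermediate values can be inspected.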

Let's test `hmm`:

In [None]:
key, subkey = jax.random.split(key)
initial_x = 0.0
tr_1 = hmm.simulate(subkey, (initial_x, None))
print("Value of z at the beginning:")
tr_1.get_choices()[0, "x1", "z"]

In [None]:
print("Value of y at the end:")
tr_1.get_choices()[9, "x1", "y"]

In [None]:
tr_1.get_choices()[:, "x1", "z"]

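Alternatively, we can create the same HMM model directly by inlining the kernel into the scanned function. The addresses then no longer have the `"x1"` level: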

In [None]:
@genjax.scan(n=10)
@gen
def hmm_v2(x, _):
    z = genjax.normal(x, 1.0) @ "z"
    y = genjax.normal(z, 1.0) @ "y"
    return y, None

Testing the second version.

In [None]:
key, subkey = jax.random.split(key)
tr_2 = hmm_v2.simulate(subkey, (initial_x, None))
tr_2.get_choices()[0, "z"], tr_2.get_choices()[9, "y"], tr_2.get_choices()[:, "z"]

If the iterations do not depend on one another, we can instead use the `repeat` combinator.
It runs the generative function `n` times on the same arguments, with independent randomness for each run, and returns the stacked results.

In [None]:
@genjax.gen
def model(y):
    x = genjax.normal(y, 0.01) @ "x"
    y = genjax.normal(x, 0.01) @ "y"
    return y


arg = 3.0
key, subkey = jax.random.split(key)
tr = model.repeat(n=10).simulate(subkey, (arg,))

tr.get_choices()[:, "x"], tr.get_retval()
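Unlike `scan`, `repeat` does not chain iterations: each copy receives the same argument. The following pure-JAX analogue of `model.repeat(n=10)` assumes that `repeat` behaves like a `jax.vmap` over keys with a shared argument. The hypothetical `model_jax` mirrors `model` above but produces no traces.

```python
import jax
import jax.numpy as jnp


def model_jax(key, y):
    # Hypothetical pure-JAX mirror of `model`: x ~ N(y, 0.01), then y ~ N(x, 0.01).
    kx, ky = jax.random.split(key)
    x = y + 0.01 * jax.random.normal(kx)
    return x + 0.01 * jax.random.normal(ky)


repeat_keys = jax.random.split(jax.random.PRNGKey(0), 10)
# Keys are mapped; the argument 3.0 is shared by all 10 independent copies.
draws = jax.vmap(model_jax, in_axes=(0, None))(repeat_keys, 3.0)
print(draws)
```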

The `repeat` combinator can, for instance, be combined with `jax.vmap`. Here we map over 3 keys while sharing the array of 5 arguments:

In [None]:
key, subkey = jax.random.split(key)
keys = jax.random.split(subkey, 3)
args = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])
n = 3
tr = jax.jit(jax.vmap(model.repeat(n=n).simulate, in_axes=(0, None)))(keys, (args,))
tr.get_choices()

Note that this produces `len(keys) * len(args) * n` draws for each address, i.e. 3 * 5 * 3 = 45 in this case.
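The count follows from the shapes. There are 3 keys, each key runs `n = 3` repeats, and each repeat is vectorized over the 5 arguments. The following pure-JAX sketch uses hypothetical `model_jax` and `repeat_jax` functions (no traces) and reproduces that shape:

```python
import jax
import jax.numpy as jnp


def model_jax(key, y):
    # Hypothetical mirror of `model`, vectorized over an array argument y.
    kx, ky = jax.random.split(key)
    x = y + 0.01 * jax.random.normal(kx, y.shape)
    return x + 0.01 * jax.random.normal(ky, y.shape)


def repeat_jax(key, y, n=3):
    # Hypothetical stand-in for `repeat`: n independent copies on the same y.
    copy_keys = jax.random.split(key, n)
    return jax.vmap(model_jax, in_axes=(0, None))(copy_keys, y)


demo_args = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])
outer_keys = jax.random.split(jax.random.PRNGKey(0), 3)
# Map over the 3 outer keys; share the 5 arguments.
out = jax.vmap(repeat_jax, in_axes=(0, None))(outer_keys, demo_args)
print(out.shape)  # (keys, repeats, args)
```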

In [None]:
tr.get_retval()