I have a generative function with a single variable but 2000 observations or I just want to use/apply it repeatedly, what do I do?

In [1]:
import genjax
import jax
from genjax import bernoulli
from genjax import gen
import jax.numpy as jnp

# First start by creating a simple generative function
@gen
def double_flip(p,q):
    v1 = bernoulli(p) @ "v1" 
    v2 = bernoulli(q) @ "v2" 
    return v1+v2

# Now we can create a vectorized version that takes a batch of p values
# and calls the function for each value in the batch.
# The `in_axes` tell the `vmap_combinator` which arguments are 
# mapped over, and which are not. 
# The value `0` means we will map over this argument and `None` means we will not.
batched_double_flip = genjax.vmap_combinator(double_flip, in_axes=(0,None))

# Now we can use the batched version to generate a batch of samples
key = jax.random.PRNGKey(0)
size_of_batch = 20
# To do so, we have to create batched keys and p values
p = jax.random.uniform(key, (size_of_batch,))
q = 0.5
# We will run the generative function once for (p1, q), once for (p2, q), ...
traces = batched_double_flip.simulate(key, (p, q))
print(traces.get_retval())

# We can also use call it on (p1,q1), (p2,q2), ...
p = jax.random.uniform(key, (size_of_batch,))
q = jax.random.uniform(key, (size_of_batch,))
batched_double_flip_v2 = genjax.vmap_combinator(double_flip, in_axes = (0, 0))
traces = batched_double_flip_v2.simulate(key, (p, q))
print(traces.get_retval())

[2 0 2 1 1 2 1 2 2 1 1 2 2 0 2 1 0 1 2 1]
[2 0 2 1 1 2 1 2 2 1 1 2 2 0 2 1 0 1 2 1]


In [2]:
# Note: We cannot batch different variables with different shapes
try:
    p = jax.random.uniform(key, (size_of_batch,))
    q = jax.random.uniform(key, (size_of_batch + 1,))
    traces = batched_double_flip_v2.simulate(key, (p, q))
    print(traces.get_retval())
except:
    print("Error: The batched version of the generative function is not working correctly")

Error: The batched version of the generative function is not working correctly


What about iterating `vmap`, e.g. if we want to apply a generative function acting on a pixel over a 2D space?

In [3]:
key = jax.random.PRNGKey(42)
image = jnp.zeros([300,500], dtype = jnp.float32)

# Generative function on one "pixel" value
@gen
def sample_pixel(pixel):
    new_pixel = genjax.normal(pixel, 1.0) @ "new_pixel"
    return new_pixel

tr = sample_pixel.simulate(key, (0.0,))
print("new_pixel:", tr.get_sample()["new_pixel"])

# Now what if we want to apply a generative function over a 2D space? 
# We can use a nested `vmap_combinator`
sample_image = genjax.vmap_combinator(in_axes = (0,))(genjax.vmap_combinator(in_axes=(0,))(sample_pixel))

tr = sample_image.simulate(key,(image,))
#We can access the new_pixel value for each pixel in the image
print(tr.get_sample())
print("2D space choicemap:", tr.get_sample()[0, 0, "new_pixel"])
print("2D space choicemap:", tr.get_sample()[299, 499, "new_pixel"])

# Model wrapped in a bigger model
image = jnp.zeros([2,3], dtype=jnp.float32)
@gen
def model(p):
    sampled_image = sample_image(image) @ "sampled_image"
    return sampled_image[0]+p

tr = model.simulate(key,(0.0,))
# We can use ellipsis to access the new_pixel value for each pixel in the image
print("sampled_image:", tr.get_sample()["sampled_image",...,...,"new_pixel"])
print()

# Alternatively, we can flatten the 2 dimensions into one and use a single `vmap_combinator`. 
# This can be more efficient in some cases and usually has a faster compile time.
sample_image_flat = genjax.vmap_combinator(in_axes=(0,))(sample_pixel)
tr = sample_image_flat.simulate(key,(image.flatten(),))
# resize the sample to the original shape
out_image = tr.get_sample()[...,"new_pixel"].reshape(image.shape)
print("sampled_image:", out_image)

new_pixel: 1.3694694
IdxChm(addr=<jax.Array int32(300,) [≥0, ≤299] zero:1 nonzero:299>, c=IdxChm(addr=<jax.Array int32(300, 500)>, c=StaticChm(addr='new_pixel', c=ValueChm(v=<jax.Array float32(300, 500)>))))
2D space choicemap: Mask(flag=<jax.Array(True, dtype=bool)>, value=<jax.Array(0.80852467, dtype=float32)>)
2D space choicemap: Mask(flag=<jax.Array(True, dtype=bool)>, value=<jax.Array(0.32702535, dtype=float32)>)
sampled_image: [[ 0.13146028 -1.1258085  -0.3549479 ]
 [ 0.6431166   0.59783465  0.7178338 ]]

sampled_image: [[ 0.15701398  0.8691822  -0.88012   ]
 [-1.0565741  -0.9048738  -0.4763998 ]]


Oh but my iteration is actually over time, not space, i.e. I may want to reuse the same model by composing it with itself, e.g. for an HMM. 
For this, we can use the `scan` combinator.

In [4]:
# Simple kernel for a Hidden Markov Model (HMM) example.
@gen
def hmm_kernel(x):
    z = genjax.normal(x, 1.0) @ "z"
    y = genjax.normal(z, 1.0) @ "y"
    return y

# Now we can create a function that runs the kernel multiple times
@genjax.scan_combinator(max_length=10)
@gen
def hmm(x, c):
    x1 = hmm_kernel(x) @ "x1"
    return x1, None

# Alternatively, we can directly create the same HMM model
@genjax.scan_combinator(max_length=10)
@gen
def hmm_v2(x, c):
    z = genjax.normal(x, 1.0) @ "z"
    y = genjax.normal(z, 1.0) @ "y"
    return y, None

key = jax.random.PRNGKey(42)
key, subkey = jax.random.split(key)
initial_x = 0.0
tr_1 = hmm.simulate(key, (initial_x, None))
print("Value of z at the beginning:", tr_1.get_sample()[0, "x1","z"])
print("Value of y at the end:", tr_1.get_sample()[9, "x1","y"])
print(tr_1.get_sample()[..., "x1","z"])

tr_2 = hmm_v2.simulate(subkey, (initial_x, None))
print(tr_2.get_sample()[0, "z"])
print(tr_2.get_sample()[9,"y"])
print(tr_2.get_sample()[...,"z"])

Value of z at the beginning: Mask(flag=<jax.Array(True, dtype=bool)>, value=<jax.Array(-1.6903946, dtype=float32)>)
Value of y at the end: Mask(flag=<jax.Array(True, dtype=bool)>, value=<jax.Array(-1.7305961, dtype=float32)>)
[-1.6903946  0.6667787 -1.222213  -3.3950129 -2.656077  -1.0463438
 -2.2370918 -2.027141  -2.7775748 -3.7542048]
Mask(flag=<jax.Array(True, dtype=bool)>, value=<jax.Array(-0.06991103, dtype=float32)>)
Mask(flag=<jax.Array(True, dtype=bool)>, value=<jax.Array(4.7136474, dtype=float32)>)
[-0.06991103  0.25033963 -0.5039929   0.12699662 -1.3981494  -0.7294941
 -0.4651397  -0.539045    1.7028391   3.6123393 ]


In [5]:
# Alternatively, we can call the generative function with a repeat combinator
# This will run the generative function multiple times on a single argument and return the results
@genjax.gen
def model(y):
    x = genjax.normal(y, 0.01) @ "x"
    y = genjax.normal(x, 0.01) @ "y"
    return y

key = jax.random.PRNGKey(0)
arg = 3.0
tr = model.repeat(num_repeats=10).simulate(key, (arg,))

print(tr.get_sample()[...,"x"])
print()
print(tr.get_retval())
print()

# It can be combined with vmap
sub_keys = jax.random.split(key, 3)
args = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])
num_repeats = 3
tr = jax.jit(jax.vmap(model.repeat(num_repeats=num_repeats).simulate, in_axes=(0, None)))(sub_keys, (args,))
print(tr.get_sample())
print()

# Note that it's running a computation |keys| * |args| * |num_repeats| times,
# i.e. 45 times in this case
print(tr.get_retval())

[3.0018916 3.0128644 3.0139802 3.026225  3.0105915 3.0164824 3.002301
 3.0027137 3.0074248 2.9685202]

[3.0059068 3.0282145 3.00671   3.0299015 3.0067174 3.0182729 3.0149763
 3.0054522 3.0251455 2.9722319]

IdxChm(addr=<jax.Array([0, 1, 2], dtype=int32)>, c=AddrMapChm(addr_map={Ellipsis: '_internal'}, c=StaticChm(addr='_internal', c=XorChm(c1=StaticChm(addr='x', c=ValueChm(v=<jax.Array float32(3, 3, 5) ≈3.0 ±1.4 [≥0.99, ≤5.0] nonzero:45>)), c2=StaticChm(addr='y', c=ValueChm(v=<jax.Array float32(3, 3, 5) ≈3.0 ±1.4 [≥0.99, ≤5.0] nonzero:45>))))))

[[[1.0132105  2.0106866  3.020013   3.9952815  4.9921556 ]
  [1.0091456  2.0109982  3.0036159  4.011836   5.004953  ]
  [0.9880202  2.0039277  2.9946487  4.0137424  4.9885173 ]]

 [[0.9929348  2.0283499  2.995508   3.9934149  5.0182095 ]
  [1.0136458  1.982539   3.0122788  3.9929068  5.0067253 ]
  [1.0344796  2.002393   3.0159934  3.9936728  4.9772143 ]]

 [[0.99285704 2.001703   3.0308223  3.9815516  5.0091867 ]
  [0.99326736 1.9940647  3.0176