I have a generative function with a single variable but 2000 observations or I just want to use/apply it repeatedly, what do I do?

In [6]:
import genjax
import jax
from genjax import bernoulli
from genjax import gen

# First start by creating a simple generative function
@gen
def double_flip(p,q):
    v1 = bernoulli(p) @ "v1" 
    v2 = bernoulli(q) @ "v2" 
    return v1+v2

# Now we can create a vectorized version that takes a batch of p values
# and calls the function for each value in the batch.
# The `in_axes` tell the `vmap_combinator` which arguments are 
# mapped over, and which are not. 
# The value `0` means we will map over this argument and `None` means we will not.
batched_double_flip = genjax.vmap_combinator(double_flip, in_axes=(0,None))

# Now we can use the batched version to generate a batch of samples
key = jax.random.PRNGKey(0)
size_of_batch = 20
# To do so, we have to create batched keys and p values
p = jax.random.uniform(key, (size_of_batch,))
q = 0.5
# We will run the generative function once for (p1,q), once for (p2,q), ...
traces = batched_double_flip.simulate(key, (p,q))
print(traces.get_retval())

# We can also use call it on (p1,q1), (p2,q2), ...
p = jax.random.uniform(key, (size_of_batch,))
q = jax.random.uniform(key, (size_of_batch,))
batched_double_flip_v2 = genjax.vmap_combinator(double_flip, in_axes=(0,0))
traces = batched_double_flip_v2.simulate(key, (p,q))
print(traces.get_retval())

[2 0 2 1 1 2 1 2 2 1 1 2 2 0 2 1 0 1 2 1]
[2 0 2 1 1 2 1 2 2 1 1 2 2 0 2 1 0 1 2 1]


In [7]:
# We cannot batch different variables with different shapes
try:
    p = jax.random.uniform(key, (size_of_batch,))
    q = jax.random.uniform(key, (size_of_batch+1,))
    traces = batched_double_flip_v2.simulate(key, (p,q))
    print(traces.get_retval())
except:
    print("Error: The batched version of the generative function is not working correctly")

Error: The batched version of the generative function is not working correctly


In [13]:
import jax.numpy as jnp

#TODO: adapt example below from Arijit for iterated vmap.
image = jnp.zeros([300,500], dtype=jnp.float32)

key = jax.random.PRNGKey(42)
@gen
def sample_pixel(pixel):
    new_pixel = genjax.normal(pixel,1.0) @ "new_pixel"
    return new_pixel

# on one "pixel" value
tr = sample_pixel.simulate(key,(0.0,))
#TODO: print(tr['new_pixel'])
# prints Array(1.3694694, dtype=float32)

# Now what if we want to apply a generative function over a 2D space? We can do a nested MAP combinator
sample_image = genjax.vmap_combinator(in_axes=(0,))(genjax.vmap_combinator(in_axes=(0,))(sample_pixel))

# sample an image
tr = sample_image.simulate(key,(image,))
#TODO: print(tr.inner.inner['new_pixel'].shape)

In [None]:
# Alternatively, we can call the generative function with a repeat combinator

#TODO: don't jit inside a for loop. is there an equivalent mistake in genjax?
#TODO: is my thing compiling or is it blocked at traced time? check make_jaxpr.
#TODO: I'm running OOM, what do I do?
#TODO: why is nested vmap so slow?