In [1]:
from genjax.delayed import delay, assume, observe
from genjax import beta, flip
from jax import make_jaxpr, jit, vmap
import jax.random as jrand
import jax.numpy as jnp

**Marginalizing out the beta in the beta-bernoulli model:**

$\text{flip}(c; p) = p^c (1 - p)^{1 - c}$

$\text{beta}(p; \alpha, \beta) = \frac{1}{B(\alpha, \beta)} p^{\alpha - 1} (1 - p)^{\beta - 1}$

$\text{marg}(c; \alpha, \beta) = \frac{1}{B(\alpha, \beta)} \int_0^1 p^{\alpha - 1 + c}(1-p)^{\beta -c} dp$

$\text{marg}(0; \alpha, \beta) = \frac{1}{B(\alpha, \beta)} \int_0^1 p^{\alpha - 1}(1-p)^{\beta} dp = \frac{B(\alpha, \beta + 1)}{B(\alpha, \beta)}$

$\text{marg}(1; \alpha, \beta) = \frac{1}{B(\alpha, \beta)} \int_0^1 p^{\alpha}(1-p)^{\beta - 1} dp = \frac{B(\alpha + 1, \beta)}{B(\alpha, \beta)}$

In [19]:
# Low-level target language that supports static delayed sampling.
def fn(obs, a, b):
    """Two separate beta/flip sub-models conditioned on the same `obs`.

    Each latent is introduced with `assume(beta, a, b)`, after which `obs` is
    observed under `flip` parameterized by that latent. The statements are
    kept in assume / observe / assume / observe order.

    Args:
        obs: Value passed as the observation to both `observe` calls.
        a: First argument forwarded to `beta` in both `assume` calls.
        b: Second argument forwarded to `beta` in both `assume` calls.

    Returns:
        A 4-tuple: the two values returned by `assume`, followed by the two
        values returned by `observe`, in program order.
    """
    first_latent = assume(beta, a, b)                   # first_latent ~ Beta(a, b)
    first_observed = observe(obs, flip, first_latent)   # observe(obs, Flip(first_latent))
    second_latent = assume(beta, a, b)                  # second_latent ~ Beta(a, b)
    second_observed = observe(obs, flip, second_latent) # observe(obs, Flip(second_latent))
    return first_latent, second_latent, first_observed, second_observed

In [20]:
def the_actual_beta_posterior_mean(v, N, a, b):
    """Closed-form mean of the updated Beta parameters.

    The parameters are updated as `a + N * v` and `b + N - N * v`, and the
    returned value is the first divided by the sum of both.

    Args:
        v: Observed outcome; presumably 0/1 (or a success fraction) -- the
            caller in this notebook passes a boolean array.
        N: Multiplier applied to `v`; presumably the number of observations.
        a: First prior Beta parameter.
        b: Second prior Beta parameter.

    Returns:
        `(a + N*v) / ((a + N*v) + (b + N - N*v))`.
    """
    successes = N * v
    posterior_a = a + successes
    posterior_b = b + N - successes
    posterior_total = posterior_a + posterior_b
    return posterior_a / posterior_total

In [27]:
# Observation and Beta hyperparameters shared by every run of the sampler.
coin = jnp.array(True)
a, b = 2.0, 2.0

# Run a delayed sampler in parallel: the PRNG keys are mapped over axis 0,
# while the observation and both hyperparameters are broadcast (None).
keys = jrand.split(jrand.key(1), 1000)
batched_sampler = vmap(delay(fn), in_axes=(0, None, None, None))
(p, p2, *_), w = batched_sampler(keys, coin, a, b)  # w: presumably per-run weights -- confirm

# Sample means of both latents next to the closed-form value for N = 1.
jnp.mean(p), jnp.mean(p2), the_actual_beta_posterior_mean(coin, 1, a, b)

(Array(0.59878796, dtype=float32),
 Array(0.59689975, dtype=float32),
 Array(0.6, dtype=float32, weak_type=True))