# In [None]:
import jax.numpy as jnp
import jax.random as jrand
from jax import jit

from genjax import ChoiceMap as Chm
from genjax import Selection as Sel
from genjax import beta, flip, gen
from genjax.edits import HMC


# Generative model: a coin bias drawn from a Beta prior, then one coin flip.
@gen
def beta_bernoulli(α, β):
    """Beta-Bernoulli generative function.

    Random choices (by address):
        "p": the bias, sampled from ``beta(α, β)``.
        "v": the outcome, sampled from ``flip(p)``.

    Returns:
        The value sampled at address "v".
    """
    bias = beta(α, β) @ "p"
    outcome = flip(bias) @ "v"
    return outcome


def exact_posterior_mean(obs, α, β):
    """Return the closed-form posterior mean of "p" after one observed flip.

    With prior p ~ Beta(α, β) and a single Bernoulli observation ``obs``
    (0/1 or False/True), the posterior is Beta(α + obs, β + 1 - obs), whose
    mean simplifies to (α + obs) / (α + β + 1).
    """
    # A bool ``obs`` is promoted to 0/1 by the addition.
    numerator = α + obs
    denominator = α + β + 1
    return numerator / denominator


# HMC-within-SIR: importance-sample a batch of traces, move each one with a
# single HMC edit on "p", then resample by the accumulated weights.
@jit
def inference_via_editing_traces(key, obs, α, β):
    """Estimate the posterior mean of "p" in ``beta_bernoulli`` given "v" = obs.

    Args:
        key: PRNG key. Each genjax call below returns a key that is threaded
            into the next call.
        obs: the value constrained at address "v".
        α, β: model arguments forwarded to ``beta_bernoulli``.

    Returns:
        The mean of the "p" choices over the resampled traces.
    """
    # importance_k(500): presumably draws 500 importance samples.
    sampler = beta_bernoulli.importance_k(500)
    constraints = Chm.d({"v": obs})
    key, (traces, importance_lws) = sampler(key, constraints, (α, β))

    # A single HMC step on "p" with step size eps = 1e-3.
    hmc_move = HMC(Sel.at["p"], jnp.array(1e-3))
    key, (traces, edit_lws, *_) = traces.edit_k(key, hmc_move)

    # Resample with importance + edit weights (presumably log weights, hence
    # the sum). The key returned by resampling is not needed afterwards.
    _, (traces, _) = traces.resample_k(key, importance_lws + edit_lws)
    p_samples = traces.get_choices()["p"]
    return jnp.mean(p_samples)


# Beta(1, 1) (uniform) prior on "p", and the value observed at "v".
α = 1.0
β = 1.0
obs = True

# Side by side: closed-form posterior mean vs. the HMC-within-SIR estimate.
# In a notebook cell, this trailing tuple is displayed as the cell output.
(
    exact_posterior_mean(obs, α, β),
    inference_via_editing_traces(jrand.key(1), obs, α, β),
)