In [None]:
from functools import partial

import genstudio.plot as Plot
import jax
import jax.numpy as jnp
import jax.random as jrand

from genjax import pjax, pretty

# Turn on genjax's pretty display for subsequent outputs
# (presumably traces/pytrees render in a rich form — TODO confirm).
pretty()

In [None]:
import jax
import jax.numpy as jnp

from genjax import ChoiceMap as Chm
from genjax import Selection as Sel
from genjax import beta, flip, gen
from genjax.edits import HMC


# Generative model: a beta prior on a coin's bias, followed by one flip.
@gen
def beta_bernoulli(α, β):
    """Beta-Bernoulli model.

    Samples a bias from ``beta(α, β)`` at address "p", then samples one
    ``flip`` with that bias at address "v". The flip outcome is returned.
    """
    bias = beta(α, β) @ "p"
    outcome = flip(bias) @ "v"
    return outcome


def exact_posterior_mean(obs, α, β):
    """Closed-form posterior mean of "p" given a single observed flip.

    With prior ``p ~ Beta(α, β)`` and one Bernoulli observation ``obs``
    (``True``/``1`` = success, ``False``/``0`` = failure), the posterior is
    ``Beta(α + obs, β + 1 - obs)``, whose mean is ``(α + obs) / (α + β + 1)``.
    """
    successes_plus_prior = α + obs
    total_pseudo_counts = α + β + 1
    return successes_plus_prior / total_pseudo_counts


# Implements HMC-within-SIR:
# importance-sample k traces, edit them with HMC, then resample.
def inference_via_editing_traces(obs, α, β, k=50, n_repeats=20, eps=1e-4):
    """Approximate the posterior over "p" in `beta_bernoulli` given "v" = obs.

    Args:
        obs: value constrained at address "v".
        α, β: arguments forwarded to `beta_bernoulli`.
        k: number of particles passed to `importance_k`.
        n_repeats: passed as `n` to `tr.repeat` before the HMC edit
            (presumably the number of times the HMC edit is applied to
            each particle — TODO confirm against genjax docs).
        eps: HMC step size used for address "p".

    Returns:
        (tr, mean_p): the resampled traces and the mean of their "p" values.
    """
    # Importance sampling with "v" constrained to obs. The returned log
    # weights (lws) feed the resampling step below.
    (tr, lws) = beta_bernoulli.importance_k(k)(
        Chm.d({"v": obs}),
        (α, β),
    )
    # Edit "p" in every particle with HMC. lws_ is the weight reported by
    # the edit and is added to the importance weights before resampling.
    (tr, lws_), _ = tr.repeat(n=n_repeats).edit_k(
        HMC(Sel.at["p"], jnp.array(eps))
    )
    (tr, _) = tr.resample_k(lws + lws_)
    return tr, jnp.mean(tr["p"])


# Prior hyperparameters and the single observed flip.
α, β = 2.0, 2.0
obs = True

# Display (closed-form posterior mean, HMC-within-SIR estimate with default k).
seeded_inference = jax.jit(pjax.seed(inference_via_editing_traces))
exact_mean = exact_posterior_mean(obs, α, β)
(exact_mean, seeded_inference(jrand.key(4), obs, α, β)[1])

In [None]:
# Re-run inference with k=500 particles and plot the resampled "p" values
# (presumably `visualize` passes the "p" samples as `v`): a histogram with the
# sample mean overlaid as a red vertical rule.
# Index the result instead of unpacking into `_`; binding `_` would shadow
# IPython's last-output variable for the rest of the session.
tr = jax.jit(pjax.seed(partial(inference_via_editing_traces, k=500)))(
    jrand.key(4), obs, α, β
)[0]
tr.visualize(
    "p",
    # Marks are drawn in order, so the histogram comes first and the rule on top.
    # Rules are line marks: their color is set with `stroke`, not `fill`.
    lambda v: Plot.histogram(v)
    + Plot.ruleX([jnp.mean(v)], stroke="red")
    + {"width": 400, "height": 400, "inset": 0},
)

In [None]:
# Build the k=500 jitted sampler and call it once, so tracing/compilation
# happens here rather than inside the %%timeit cell that follows.
infer_k500 = partial(inference_via_editing_traces, k=500)
jitted = jax.jit(pjax.seed(infer_k500))
jitted(jrand.key(4), obs, α, β)

In [None]:
%%timeit
jitted(jrand.key(4), obs, α, β)