What is marginalization? How and when should we do it?

In [None]:
# TODO: for now this is a basic explanation and example using GenSP; it will need to be revised and expanded.

# outline:
# 1. inference operates on full traces: how can we reduce the trace and still perform correct inference?
# 2. show example
# 3. explain the different new components
# 3.1 random_weighted
# 3.2 marginal
# 3.3 target
# 3.4 estimate_logpdf
# 3.5 estimate_normalizing_constant
# 3.6 estimate_reciprocal_normalizing_constant
# 3.7 Alg
# 4. explain how this recovers the base case

# TODO: may need to be cut into smaller pieces.


import jax
import jax.numpy as jnp
from genjax import beta, flip, gen, Target, ChoiceMap
from genjax.inference.smc import ImportanceK
from jax import jit


# Create a generative model.
@gen
def beta_bernoulli(α, β):
    """Draw a coin bias from a Beta prior, then flip that coin once.

    Traced choices:
        "p": the bias, sampled from ``beta(α, β)``.
        "v": the outcome, sampled from ``flip`` of that bias.

    Returns:
        The sampled outcome stored at address ``"v"``.
    """
    bias = beta(α, β) @ "p"
    outcome = flip(bias) @ "v"
    return outcome


@jit
def run_inference(obs: bool, key=None):
    """Estimate the posterior mean of ``p`` in ``beta_bernoulli`` given ``v == obs``.

    Runs independent trials of SIR (``ImportanceK``) against the posterior
    target and averages the sampled values of ``p`` across the trials.

    Args:
        obs: Observed value constrained at the ``"v"`` address.
        key: Optional JAX PRNG key. When omitted, a fixed key
            (``jax.random.PRNGKey(314159)``) is used. That makes calls without
            a key deterministic, and such calls all share the same random
            stream.

    Returns:
        A scalar JAX array: the mean of the sampled ``p`` values.
    """
    # Plain Python ints, so they stay static under `jit`; `jax.random.split`
    # requires a concrete number of keys.
    n_trials = 50
    k_particles = 50

    # Create an inference query - a posterior target - by specifying
    # the model, arguments to the model, and constraints.
    posterior_target = Target(
        beta_bernoulli,  # the model
        (2.0, 2.0),  # arguments to the model (α, β)
        ChoiceMap.d({"v": obs}),  # constraints
    )

    # Use a library algorithm, or design your own - more on that in the docs!
    alg = ImportanceK(posterior_target, k_particles=k_particles)

    if key is None:
        # `None` is an empty pytree, so this branch is resolved at trace time.
        key = jax.random.PRNGKey(314159)
    sub_keys = jax.random.split(key, n_trials)

    # The enclosing function is already jitted, so the vmapped call needs no
    # extra `jit` wrapper.
    _, p_chm = jax.vmap(alg.random_weighted, in_axes=(0, None))(
        sub_keys, posterior_target
    )

    # An estimate of `p` over `n_trials` independent trials of SIR
    # (with K = `k_particles` particles).
    return jnp.mean(p_chm["p"])


# Posterior-mean estimates of `p`, first for v == True, then for v == False.
tuple(run_inference(observed) for observed in (True, False))