# Modeling a biased coin with BamoJAX

This notebook shows how to build a simple Beta–Bernoulli model with `bamojax`, sample its posterior with the No-U-Turn Sampler (NUTS), an adaptive form of Hamiltonian Monte Carlo provided by BlackJAX, and interpret both the posterior and the posterior predictive results.


## Data

We treat the array below as 20 coin flips collected in the lab, where 1 represents heads and 0 represents tails.


In [None]:
import jax.numpy as jnp
import jax.random as jr
import numpy as np
import matplotlib.pyplot as plt
import numpyro.distributions as dist

from bamojax.base import Model
from bamojax.samplers import mcmc_sampler
from bamojax.inference import MCMCInference

from blackjax import nuts

%config InlineBackend.figure_format = 'retina'
plt.style.use('seaborn-v0_8-darkgrid')


In [None]:
coin_flips = jnp.array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 0, 1, 1, 0], dtype=jnp.int32)
num_trials = int(coin_flips.size)
successes = int(coin_flips.sum())
failures = num_trials - successes

print(f"Observed {successes} heads and {failures} tails out of {num_trials} flips.")


## Build the model

BamoJAX represents Bayesian models as directed acyclic graphs of `Node` objects. We create a `Model`, place a Beta(1, 1) prior (uniform on [0, 1]) on the coin bias `theta`, and attach a Bernoulli likelihood node `y` whose success probability `probs` is `theta` and whose observations are the 20 flips.
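Written out, the graph encodes the generative model below, where $n$ is the number of flips and $k$ the number of heads. Multiplying the flat prior by the per-flip Bernoulli terms shows why the posterior is again a Beta distribution, which is what the conjugate check in the posterior summary relies on:

$$
\begin{aligned}
\theta &\sim \mathrm{Beta}(1, 1), \\
y_i \mid \theta &\sim \mathrm{Bernoulli}(\theta), \qquad i = 1, \dots, n, \\
p(\theta \mid y) &\propto \underbrace{\theta^{0}(1-\theta)^{0}}_{\text{prior}} \prod_{i=1}^{n} \theta^{y_i}(1-\theta)^{1-y_i} = \theta^{k}(1-\theta)^{n-k}, \qquad k = \sum_{i=1}^{n} y_i,
\end{aligned}
$$

which is the kernel of a $\mathrm{Beta}(1 + k,\ 1 + n - k)$ density.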


In [None]:
model = Model(name="Beta-Bernoulli coin")

theta = model.add_node("theta", distribution=dist.Beta(1.0, 1.0))
likelihood = model.add_node(
    "y",
    distribution=dist.Bernoulli,
    observations=coin_flips,
    parents={"probs": theta},
    shape=coin_flips.shape,
)

print("Latent nodes:", list(model.get_latent_nodes().keys()))
print("Leaf node:", [node.name for node in model.get_leaf_nodes()])
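To make the graph concrete, here is a minimal NumPy sketch of the unnormalized log posterior that these two nodes define. `log_joint` and `log_joint_from_counts` are hypothetical helpers written for illustration, not bamojax functions; they show that the sum of per-flip Bernoulli terms depends on the data only through the head and tail counts.

```python
import numpy as np

# Observed flips from the Data section (1 = heads, 0 = tails).
coin_flips = np.array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 0, 1, 1, 0])


def log_joint(theta, y):
    """Unnormalized log posterior of the Beta(1, 1)-Bernoulli model, for 0 < theta < 1.

    Adds the log prior to one Bernoulli log-probability per observed flip.
    """
    log_prior = 0.0  # the Beta(1, 1) density is 1 on (0, 1), so its log is 0
    log_lik = np.sum(y * np.log(theta) + (1 - y) * np.log1p(-theta))
    return log_prior + log_lik


def log_joint_from_counts(theta, y):
    """Same quantity computed from the sufficient statistics: k heads out of n flips."""
    k = int(np.sum(y))
    n = y.size
    return k * np.log(theta) + (n - k) * np.log1p(-theta)


grid = np.linspace(0.01, 0.99, 99)
per_flip = np.array([log_joint(t, coin_flips) for t in grid])
from_counts = np.array([log_joint_from_counts(t, coin_flips) for t in grid])

print("Max abs difference:", np.max(np.abs(per_flip - from_counts)))
print("Grid maximizer of the log posterior:", grid[np.argmax(per_flip)])
```

The two versions agree to floating-point precision, and under the flat prior the maximum sits at the sample proportion $k/n = 16/20 = 0.8$.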


## Configure and run NUTS

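BamoJAX wraps BlackJAX samplers. Here we pass `blackjax.nuts` with an initial step size of 0.1 and a unit inverse mass matrix, which has a single entry because `theta` is the only latent variable. The first run triggers JAX compilation, so expect a short pause the first time you execute the cell.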


In [None]:
initial_kernel = mcmc_sampler(
    model,
    mcmc_kernel=nuts,
    mcmc_parameters={"step_size": 0.1, "inverse_mass_matrix": jnp.array([1.0])},
)

inference = MCMCInference(
    model=model,
    num_chains=1,
    mcmc_kernel=initial_kernel,
    num_samples=2000,
    num_burn=500,
    num_warmup=500,
    return_diagnostics=True,
)

rng_key = jr.PRNGKey(2)

print("Running NUTS sampling (first call may compile JAX kernels)...")
results = inference.run(rng_key)

theta_samples = results["states"]["theta"]
diagnostics = results["info"]

print(f"Collected {theta_samples.shape[0]} posterior samples.")
print(f"Average acceptance rate: {float(diagnostics.acceptance_rate.mean()):.3f}")
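The acceptance rate says little about how correlated consecutive draws are. A common summary is the effective sample size (ESS), roughly $n / (1 + 2\sum_t \rho_t)$ where $\rho_t$ is the lag-$t$ autocorrelation. The NumPy sketch below truncates that sum at the first non-positive autocorrelation; `autocorrelation` and `effective_sample_size` are hypothetical helpers, and this truncation rule is a simplification of the estimators used by dedicated diagnostics libraries. It is demonstrated on independent draws and on a strongly correlated AR(1) chain.

```python
import numpy as np


def autocorrelation(x):
    """Sample autocorrelation of a 1-D chain at lags 0..n-1, computed with an FFT."""
    x = np.asarray(x, dtype=float)
    n = x.size
    centered = x - x.mean()
    # Zero-pad to at least 2n so the FFT product gives a linear, not circular, correlation.
    size = 2 ** int(np.ceil(np.log2(2 * n)))
    spectrum = np.fft.rfft(centered, n=size)
    acov = np.fft.irfft(np.abs(spectrum) ** 2, n=size)[:n]
    return acov / acov[0]


def effective_sample_size(x):
    """ESS ~ n / (1 + 2 * sum_t rho_t), truncating at the first non-positive rho_t."""
    rho = autocorrelation(x)
    total = 0.0
    for lag in range(1, rho.size):
        if rho[lag] <= 0.0:
            break
        total += rho[lag]
    return rho.size / (1.0 + 2.0 * total)


rng = np.random.default_rng(0)
iid_chain = rng.normal(size=4000)

# AR(1) chain with coefficient 0.9: a stand-in for a slowly mixing sampler.
ar_chain = np.zeros(4000)
for t in range(1, ar_chain.size):
    ar_chain[t] = 0.9 * ar_chain[t - 1] + rng.normal()

print(f"ESS of independent draws: {effective_sample_size(iid_chain):.0f} / 4000")
print(f"ESS of AR(1) draws:       {effective_sample_size(ar_chain):.0f} / 4000")
```

Assuming `theta_samples` is one-dimensional, as the cells above treat it, `effective_sample_size(np.asarray(theta_samples))` gives a rough sense of how many independent draws the NUTS output is worth.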


## Posterior summary

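We use the samples to compute the posterior mean and median, a central 95% credible interval, and the tail probability P(theta > 0.5 | data). For intuition, we also record the parameters of the exact conjugate posterior, Beta(1 + heads, 1 + tails), to compare against.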


In [None]:
theta_np = np.asarray(theta_samples)
posterior_mean = theta_np.mean()
posterior_ci = np.quantile(theta_np, [0.025, 0.5, 0.975])
prob_theta_gt_half = (theta_np > 0.5).mean()

posterior_alpha = 1.0 + successes
posterior_beta = 1.0 + failures

print(f"Posterior mean: {posterior_mean:.3f}")
print(f"Posterior median: {posterior_ci[1]:.3f}")
print(f"Central 95% credible interval: [{posterior_ci[0]:.3f}, {posterior_ci[2]:.3f}]")
print(f"P(theta > 0.5 | data) = {prob_theta_gt_half:.3f}")
print(f"Conjugate Beta parameters (reference): alpha={posterior_alpha:.1f}, beta={posterior_beta:.1f}")
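Because the Beta(1, 1) prior is conjugate, the MCMC summaries can be checked against the exact Beta(17, 5) posterior implied by 16 heads and 4 tails. The sketch below uses only the Python standard library and the identity that, for positive integers $a$ and $b$, the Beta($a$, $b$) CDF at $x$ equals the probability that a Binomial($a + b - 1$, $x$) variable is at least $a$. The helpers `beta_cdf_int` and `beta_quantile_int` are hypothetical names for illustration.

```python
from math import comb


def beta_cdf_int(x, a, b):
    """CDF of Beta(a, b) at x for positive integers a, b.

    Uses P(theta <= x) = P(Binomial(a + b - 1, x) >= a).
    """
    m = a + b - 1
    return sum(comb(m, j) * x**j * (1 - x) ** (m - j) for j in range(a, m + 1))


def beta_quantile_int(q, a, b, tol=1e-10):
    """Invert beta_cdf_int by bisection (the CDF is increasing on [0, 1])."""
    lo, hi = 0.0, 1.0
    while hi - lo > tol:
        mid = 0.5 * (lo + hi)
        if beta_cdf_int(mid, a, b) < q:
            lo = mid
        else:
            hi = mid
    return 0.5 * (lo + hi)


heads, tails = 16, 4                # counts in the observed array
alpha, beta = 1 + heads, 1 + tails  # Beta(1, 1) prior -> Beta(17, 5) posterior

exact_mean = alpha / (alpha + beta)
exact_p_gt_half = 1.0 - beta_cdf_int(0.5, alpha, beta)
exact_interval = (beta_quantile_int(0.025, alpha, beta), beta_quantile_int(0.975, alpha, beta))

print(f"Exact posterior mean: {exact_mean:.3f}")
print(f"Exact central 95% interval: [{exact_interval[0]:.3f}, {exact_interval[1]:.3f}]")
print(f"Exact P(theta > 0.5 | data) = {exact_p_gt_half:.4f}")
```

The exact mean is 17/22 ≈ 0.773 and the exact tail probability is 1 − 7547/2²¹ ≈ 0.996; the NUTS estimates printed above should land close to these values, up to Monte Carlo error.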


In [None]:
fig, ax = plt.subplots(figsize=(7, 4))
ax.hist(theta_np, bins=40, color="#4f6db8", alpha=0.85, density=True, label="Posterior samples")
ax.axvline(posterior_mean, color="#d05c3b", linestyle="--", linewidth=2, label=f"Mean = {posterior_mean:.3f}")
ax.set(xlabel=r"$\theta$", ylabel="Density", title="Posterior for coin bias")
ax.legend(frameon=False)
plt.show()


## Posterior predictive checks

We draw replicated datasets by passing posterior samples of `theta` through `model.sample_predictive`, then compare the number of heads in each replicate with the number we observed.


In [None]:
num_ppc_draws = 200
ppc_keys = jr.split(jr.PRNGKey(4), num_ppc_draws)
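# Spread the predictive draws across the whole chain instead of taking the last
# consecutive (and therefore autocorrelated) samples.
theta_all = np.asarray(theta_samples)
ppc_idx = np.linspace(0, theta_all.shape[0] - 1, num_ppc_draws).astype(int)
theta_subset = theta_all[ppc_idx]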

posterior_predictive_counts = []
for subkey, theta_value in zip(ppc_keys, theta_subset):
    state = {"theta": jnp.array(theta_value)}
    simulated = model.sample_predictive(subkey, state)["y"]
    posterior_predictive_counts.append(np.asarray(simulated).sum())

posterior_predictive_counts = np.asarray(posterior_predictive_counts)
ppc_interval = np.quantile(posterior_predictive_counts, [0.025, 0.975])

print(f"Posterior predictive mean heads: {posterior_predictive_counts.mean():.2f}")
print(f"Posterior predictive 95% interval for heads: [{ppc_interval[0]:.1f}, {ppc_interval[1]:.1f}]")
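Because the flips are conditionally independent given `theta`, the number of heads in a replicated dataset is Binomial(20, theta), so the loop above can be vectorized. The NumPy sketch below substitutes exact Beta(17, 5) draws for the MCMC samples (an assumption made only so it runs on its own) and adds a posterior predictive tail probability: the share of replicates with at least as many heads as observed.

```python
import numpy as np

rng = np.random.default_rng(4)
n_flips, observed_heads = 20, 16
num_draws = 4000

# Stand-in for posterior samples of theta: exact Beta(17, 5) draws.
# In the notebook these would be the MCMC samples instead.
theta_draws = rng.beta(1 + observed_heads, 1 + n_flips - observed_heads, size=num_draws)

# One replicated head count per draw: Binomial(n_flips, theta), broadcast over theta_draws.
replicated_heads = rng.binomial(n_flips, theta_draws)

# Posterior predictive tail probability for the statistic "number of heads".
ppp_upper = np.mean(replicated_heads >= observed_heads)

print(f"Mean replicated heads: {replicated_heads.mean():.2f}")
print(f"P(heads_rep >= {observed_heads} | data) = {ppp_upper:.3f}")
```

To use the notebook's own draws, replace `theta_draws` with `np.asarray(theta_samples)`. A tail probability close to 0 or 1 would signal a conflict between the model and the data for this statistic.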


In [None]:
fig, ax = plt.subplots(figsize=(7, 4))
bins = np.arange(-0.5, num_trials + 1.5, 1)
ax.hist(posterior_predictive_counts, bins=bins, color="#5aa469", alpha=0.85, rwidth=0.9)
ax.axvline(successes, color="#d05c3b", linestyle="--", linewidth=2, label=f"Observed heads = {successes}")
ax.set(xlabel=f"Number of heads out of {num_trials}", ylabel="Frequency", title="Posterior predictive distribution")
ax.set_xticks(range(0, num_trials + 1, 2))
ax.legend(frameon=False)
plt.show()
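With a conjugate prior, the posterior predictive distribution of the number of heads in 20 new flips is available exactly: it is a Beta-Binomial distribution with the posterior parameters, $P(K = k) = \binom{n}{k} B(k + \alpha, n - k + \beta) / B(\alpha, \beta)$. The standard-library sketch below evaluates it for $\alpha = 17$, $\beta = 5$, so the histogram above can be compared with exact probabilities; `log_beta_fn` and `beta_binomial_pmf` are hypothetical helper names.

```python
from math import comb, exp, lgamma


def log_beta_fn(a, b):
    """log B(a, b) computed via log-gamma for numerical stability."""
    return lgamma(a) + lgamma(b) - lgamma(a + b)


def beta_binomial_pmf(k, n, a, b):
    """P(K = k) for K ~ BetaBinomial(n, a, b)."""
    return comb(n, k) * exp(log_beta_fn(k + a, n - k + b) - log_beta_fn(a, b))


n_new, alpha, beta = 20, 17, 5  # 20 new flips under the Beta(17, 5) posterior
pmf = [beta_binomial_pmf(k, n_new, alpha, beta) for k in range(n_new + 1)]
predictive_mean = sum(k * p for k, p in enumerate(pmf))

# Inverse-CDF quantiles: the first counts whose cumulative probability reaches 2.5% and 97.5%.
cdf, lower, upper = 0.0, None, None
for k, p in enumerate(pmf):
    cdf += p
    if lower is None and cdf >= 0.025:
        lower = k
    if upper is None and cdf >= 0.975:
        upper = k

print(f"Exact predictive mean heads: {predictive_mean:.2f}")
print(f"Exact predictive 95% interval for heads: [{lower}, {upper}]")
```

The exact mean is $20 \cdot 17/22 \approx 15.45$ heads. The interval endpoints may differ slightly from the Monte Carlo ones above, since `np.quantile` interpolates between counts by default.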
