# GenJAX Basics: Beta-Bernoulli

A first exploration of generative functions and inference in GenJAX,
based on the quickstart example from genjax.gen.dev.

In [5]:
import jax
import jax.numpy as jnp
import genjax
from genjax import beta, flip, gen, Target, ChoiceMap
from genjax.inference.smc import ImportanceK

## 1. Define a Generative Model

A generative function is a computational object representing a probability distribution.
The `@gen` decorator turns the Python function below into a generative function.

This model:
1. Draws `p` from a Beta(α, β) distribution
2. Flips a coin with probability `p`
3. Returns the coin flip result

The `@ "p"` and `@ "v"` suffixes assign **addresses**: names for the random choices, which we can later constrain or observe.

In [6]:
@gen
def beta_bernoulli(α, β):
    p = beta(α, β) @ "p"  # Draw p from Beta distribution
    v = flip(p) @ "v"     # Flip coin with probability p
    return v

## 2. Understanding Traces

When we simulate a generative function, we get a **trace**: a record of the random choices made, the return value, and the log probability (score) of those choices.

In [11]:
key = jax.random.key(42)
trace = beta_bernoulli.simulate(key, (2.0, 2.0))

print(f"Return value: {trace.get_retval()}")
print(f"Choices: {trace.get_choices()}")
print(f"  p = {trace.get_choices()['p']}")
print(f"  v = {trace.get_choices()['v']}")
print(f"Log probability: {trace.get_score()}")

Return value: True
Choices: Static({'v': Choice(v=<jax.Array(True, dtype=bool)>), 'p': Choice(v=<jax.Array(0.76402074, dtype=float32)>)})
  p = 0.7640207409858704
  v = True
Log probability: -0.19057175517082214
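
The score printed above can be checked by hand, assuming it is the log joint density of the sampled choices: log Beta(p; 2, 2) + log Bernoulli(v; p). The sketch below uses only the Python standard library. The helper names `log_beta_pdf`, `log_bernoulli_pmf`, and `log_joint` are illustrative and are not part of GenJAX.

```python
import math

def log_beta_pdf(p, a, b):
    """Log density of Beta(a, b) at p, for 0 < p < 1."""
    log_norm = math.lgamma(a + b) - math.lgamma(a) - math.lgamma(b)
    return log_norm + (a - 1.0) * math.log(p) + (b - 1.0) * math.log(1.0 - p)

def log_bernoulli_pmf(v, p):
    """Log probability of outcome v for a coin with P(True) = p."""
    return math.log(p) if v else math.log(1.0 - p)

def log_joint(p, v, a, b):
    """Log joint density of (p, v) under the Beta-Bernoulli model."""
    return log_beta_pdf(p, a, b) + log_bernoulli_pmf(v, p)

# Values copied from the trace output above.
score = log_joint(0.7640207409858704, True, 2.0, 2.0)
print(f"{score:.4f}")  # -0.1906
```

This agrees with the trace's score to about four decimal places. Small differences beyond that are expected, since the trace stores `p` in float32.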


## 3. Inference: What is p given we observed v=True?

This is the central question in probabilistic programming:
given observations, what can we infer about the latent variables?

We create a **Target**, which represents the (unnormalized) posterior, by specifying:
- The model
- Arguments to the model
- Constraints (observations)

In [15]:
@jax.jit
def run_inference(obs: bool):
    # Create a posterior target
    posterior_target = Target(
        beta_bernoulli,          # the model
        (2.0, 2.0),              # arguments (α, β)
        ChoiceMap.d({"v": obs}), # constraint: we observed v
    )
    
    # Use Sampling Importance Resampling (SIR) with K=50 particles
    alg = ImportanceK(posterior_target, k_particles=50)
    
    # Draw 50 independent approximate posterior samples, one per PRNG key
    key = jax.random.key(314)
    sub_keys = jax.random.split(key, 50)
    
    # vmap runs inference in parallel across all keys
    _, p_chm = jax.vmap(alg.random_weighted, in_axes=(0, None))(
        sub_keys, posterior_target
    )
    
    # Average the sampled p values to estimate E[p | v]
    return jnp.mean(p_chm["p"])
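
To make the idea behind sampling importance resampling concrete, here is a minimal standard-library sketch for this model. It assumes particles are proposed from the Beta prior, so each particle's importance weight is just the likelihood of the observed flip. Whether `ImportanceK` uses exactly this proposal by default is an assumption here, and `sir_sample_p` is a hypothetical helper, not a GenJAX function.

```python
import random

def sir_sample_p(obs, alpha=2.0, beta=2.0, k_particles=50, rng=random):
    """Return one approximate posterior sample of p given the observed flip."""
    # 1. Propose K particles from the prior Beta(alpha, beta).
    particles = [rng.betavariate(alpha, beta) for _ in range(k_particles)]
    # 2. Weight each particle by the likelihood of the observation.
    weights = [p if obs else 1.0 - p for p in particles]
    # 3. Resample a single particle in proportion to its weight.
    return rng.choices(particles, weights=weights, k=1)[0]

rng = random.Random(0)  # seeded generator for reproducibility
samples = [sir_sample_p(True, rng=rng) for _ in range(2000)]
estimate = sum(samples) / len(samples)
print(f"E[p | v=True] is approximately {estimate:.3f}")
```

Each call plays the role of one `alg.random_weighted` call in the cell above: K particles go in, one resampled value of `p` comes out. Averaging many such samples estimates the posterior mean.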

In [16]:
p_given_true = run_inference(True)
p_given_false = run_inference(False)

print(f"E[p | v=True]  = {p_given_true:.4f}")
print(f"E[p | v=False] = {p_given_false:.4f}")

E[p | v=True]  = 0.5646
E[p | v=False] = 0.4338
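
These estimates can be compared against the exact answer. The Beta prior is conjugate to the Bernoulli likelihood, so after one observed flip the posterior is Beta(α + 1, β) for `v=True` and Beta(α, β + 1) for `v=False`. The short sketch below computes the exact posterior means. `exact_posterior_mean` is an illustrative helper, not a GenJAX function.

```python
def exact_posterior_mean(alpha, beta, v):
    """Mean of the Beta posterior over p after observing a single flip v."""
    a_post = alpha + (1.0 if v else 0.0)  # a True flip adds one to alpha
    b_post = beta + (0.0 if v else 1.0)   # a False flip adds one to beta
    return a_post / (a_post + b_post)

print(exact_posterior_mean(2.0, 2.0, True))   # 0.6
print(exact_posterior_mean(2.0, 2.0, False))  # 0.4
```

The estimates above (0.5646 and 0.4338) differ from 0.6 and 0.4 by about 0.03 to 0.04. That is roughly the Monte Carlo error expected from averaging only 50 samples. Both posteriors have a standard deviation of 0.2, so the standard error of the mean is about 0.2 / √50 ≈ 0.03.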


## Key Concepts Learned

1. **Generative functions** (`@gen`) define probability distributions
2. **Addresses** (`@ "name"`) label random choices for later reference
3. **Traces** record all random choices, the return value, and the log probability (score) of those choices
4. **Targets** specify inference problems (model + constraints)
5. **Inference algorithms** (like `ImportanceK`) produce approximate samples from the posterior
6. **JAX integration**: the model and inference code shown here compose with `jax.jit` and `jax.vmap`

## Exercises

1. Try different α, β values - how does the prior affect the posterior?
2. What happens with more/fewer particles?
3. What if we observe multiple coin flips?