In [None]:
import jax.numpy as jnp
import jax.random as jrand
from jax import jit, make_jaxpr, vmap

from genjax import beta, flip, gen
from genjax.delayed import assume, delay, observe

In [None]:
# "JAX Python"
def sum_plus_three(x):
    """Return ``jnp.sum(x) + 3.0``.

    Written only with built-in JAX operations, so its jaxpr (shown below)
    contains just a ``reduce_sum`` followed by an ``add``. It is deliberately
    not called ``fn``: that name is reused for the delayed-sampling program
    further down, and defining it twice would let the later cell silently
    shadow this one.
    """
    v = jnp.sum(x)
    z = v + 3.0
    return z

make_jaxpr(sum_plus_three)(jnp.ones(5))

{ lambda ; a:f32[5]. let
    b:f32[] = reduce_sum[axes=(0,)] a
    c:f32[] = add b 3.0
  in (c,) }

# A jaxpr is built from a set of primitives provided by JAX:
# add_p, sub_p, ... -- all of the primitive array operations.
# You can also add your own primitives.

# When you add your own primitive -- you need to tell JAX a few things:
# * (abstract evaluation) If I give your primitive arrays of this shape and dtype, 
#   what will you give me back? You have to answer this, if you want your
#   primitive to work with `jax.make_jaxpr`
# * (batching) How does your primitive work with vmap? If I give you arrays,
#   and I tell you that the batching dimension is this, what do you give me back?
#   If you answer this, your primitive will work with vmap.

# My primitives are never vmapped over -- I always run an interpreter and
# eliminate my primitives into pure JAX primitives -- and then vmap works
# without a problem.
vmap(run_my_interpreter(fn_with_my_primitive))
# run_my_interpreter(...) is in "pure JAX" (no extension)

# Sketch only: `some_other_gen_fn`, `normal` and `some_likelihood_jax_computation`
# are placeholders that are not defined in this notebook, and `...` stands for
# elided arguments.
@gen
def model(args):
    # `gen_fn(...) @ "addr"` binds the call's result to address "addr" in the
    # trace (via the `trace_p` primitive mentioned below).
    x = some_other_gen_fn(...) @ "x"
    y = normal(some_likelihood_jax_computation(x), 3.0) @ "y"

# We introduce a primitive called `trace_p`; the interpreters in `static.py`
# then eliminate `trace_p`, leaving pure JAX.
# model.simulate -> runs an interpreter, which turns primitives into pure JAX
vmap(model.simulate)(...)

**Marginalizing out the beta in the beta-bernoulli model:**

$\text{flip}(c; p) = p^c (1 - p)^{1 - c}$

$\text{beta}(p; \alpha, \beta) = \frac{1}{B(\alpha, \beta)} p^{\alpha - 1} (1 - p)^{\beta - 1}$

$\text{marg}(c; \alpha, \beta) = \frac{1}{B(\alpha, \beta)} \int_0^1 p^{\alpha - 1 + c}(1-p)^{\beta -c} dp$

$\text{marg}(0; \alpha, \beta) = \frac{1}{B(\alpha, \beta)} \int_0^1 p^{\alpha - 1}(1-p)^{\beta} dp = \frac{B(\alpha, \beta + 1)}{B(\alpha, \beta)}$

$\text{marg}(1; \alpha, \beta) = \frac{1}{B(\alpha, \beta)} \int_0^1 p^{\alpha}(1-p)^{\beta - 1} dp = \frac{B(\alpha + 1, \beta)}{B(\alpha, \beta)}$

In [None]:
# Low-level target language that supports static delayed sampling.
def fn(obs, a, b):
    """Beta-Bernoulli program written directly in the delayed-sampling language.

    `assume` introduces the latent bias ~ Beta(a, b) and `observe` conditions a
    Flip(bias) on `obs`. Under `delay(...)` this pair is presumably handled
    exactly (compare the posterior-mean check further down).

    Returns the pair `(p, v)`: the latent bias and the result of `observe`.
    """
    # Exact (conjugate) part of the program.
    bias = assume(beta, a, b)            # bias ~ Beta(a, b)
    observed = observe(obs, flip, bias)  # obs ~ Flip(bias)
    # Non-linear computations with ordinary JAX primitives could follow here.
    return bias, observed

# Lower a GenFn into this language.
# Lift a sampler written in this language back into a GenFn.
# model.generate(DelayedSampling("x", "y"), choice_map({"z":3.0}))

In [None]:
@gen
def model():
    """Two-stage beta-Bernoulli model used to sketch compositional lowering.

    A coin bias `p` is drawn by a sub-model at address "s1". A flip with that
    bias is then drawn by a second sub-model at address "s2". The inline notes
    record the intended lowering: the beta draw becomes `assume(beta, ...)`
    and the flip becomes `observe(...)` against the choice map.
    """

    @gen
    def sample_bias():
        p = beta(1.0, 1.0) @ "p" # beta.lower(1.0, 1.0) -> assume(beta, 1.0, 1.0) # ("s1", "p")
        return p

    @gen
    def flip_coin(p):
        x = flip(p) @ "f" # flip.lower(p) -> observe(chm["s2", "f"], flip, p) # ("s2", "f")
        return x

    # Step 1 in lowering -- sample_bias.lower() -> Repr
    p = sample_bias() @ "s1"
    # Step 2 in lowering -- flip_coin.lower(p) -> Repr
    # (called for its traced choice at "s2"; the model returns nothing)
    flip_coin(p) @ "s2"

# model.lower(choice_map) -> Repr

In [None]:
# model.generate(DelayedSampling("x", "y"), choice_map)
# -- Repr: lambda c, a, b:
#             %p = assume(beta, a, b)
#             observe(c, flip, %p)
# "Use delayed sampling" -> sample for some subset of the random variables that you care about
# Map that sample back into the choice map space.

In [None]:
# VmapCombinator
# s = ScanCombinator(model)
# Their logic is kind of complicated -- their "P" distribution involves some programmatic control flow dependency stuff

In [None]:
# s = ScanCombinator(model :: (C, S1) -> G (C, S2)) :: (C, Vec S1) -> G (C, Vec S2)
# s.generate("do delayed sampling _within_ model but not between iterations")
# s.generate("do delayed sampling _across iterations_ of the scan")
# instead of thinking about 1 : N
# 1 : 2 -- maybe this will work depending on dependency structure in the scan?
# Unroll parts of the scan, and then do delayed sampling on those parts.
# Vmap -- each slice is independent of the others -- so you can
# (TODO: finish this thought)
# Switch -- (TODO)

In [None]:
def the_actual_beta_posterior_mean(v, N, a, b):
    """Closed-form posterior mean of the coin bias under a Beta(a, b) prior.

    Applies the conjugate beta-Bernoulli update for N observations that all
    equal ``v`` (1/True or 0/False):
    bias | data ~ Beta(a + N*v, b + N - N*v), whose mean is alpha / (alpha + beta).
    """
    heads = N * v
    alpha_post = a + heads
    beta_post = b + N - heads
    return alpha_post / (alpha_post + beta_post)

In [None]:
coin = jnp.array(True)
a = 2.0
b = 2.0

# Run the delayed sampler in parallel: vmap over a batch of PRNG keys (axis 0)
# while broadcasting the observation and the Beta(a, b) prior parameters.
jitted = jit(vmap(delay(fn), in_axes=(0, None, None, None)))
keys = jrand.split(jrand.key(3), 10000)
(p, *_), w = jitted(keys, coin, a, b)
# Monte Carlo mean of p next to the closed-form posterior mean after one observation.
jnp.mean(p), the_actual_beta_posterior_mean(coin, 1, a, b)

In [None]:
%%timeit
(p, p2, *_), w = jitted(
    jrand.split(jrand.key(1), 1000), 
    coin, a, b,
)
jnp.mean(p), jnp.mean(p2), the_actual_beta_posterior_mean(coin, 1, a, b)

In [None]:
# Trace the delayed sampler for a single, un-batched key and display the
# jaxpr that `make_jaxpr` produces for `delay(fn)`.
mjaxpr = make_jaxpr(delay(fn))
mjaxpr(jrand.key(1), coin, a, b)