In [None]:
import jax.random as jrand
from jax import make_jaxpr

from genjax import Diff, Selection, beta, flip, gen

In [None]:
# Simple example: a tiny generative model with two Beta-distributed random
# choices ("v1", "v2") whose mean parameterizes a `flip` at address "f".


@gen
def model(alpha1, alpha2):
    """Sample "v1" ~ beta(alpha1, 1.0) and "v2" ~ beta(alpha2, 2.0), then
    draw "f" from `flip` with the mean of the two samples as its parameter
    (presumably the heads probability); the flip outcome is returned."""
    v1 = beta(alpha1, 1.0) @ "v1"
    v2 = beta(alpha2, 2.0) @ "v2"
    mean_v = (v1 + v2) / 2
    return flip(mean_v) @ "f"


key = jrand.key(1)

# (alpha1, alpha2) used to simulate the starting trace, and the new arguments
# handed to `blanket`. Only alpha1 changes (1.0 -> 2.0). Naming both tuples
# keeps the argdiffs and the `make_jaxpr` call below in sync.
old_args = (1.0, 2.0)
new_args = (2.0, 2.0)

# Start with a trace ...
tr = model.simulate(key, old_args)

# Call `blanket` on the trace with the selection at "v1" and one diff per
# argument: alpha1 changed (`unknown_change`), alpha2 unchanged (`no_change`).
_, blanket_fn = model.blanket(
    tr,
    Selection.at["v1"],
    (Diff.unknown_change(new_args[0]), Diff.no_change(new_args[1])),
)

# `blanket` returns a function; tracing it with `make_jaxpr` displays its IR.
make_jaxpr(blanket_fn)(*new_args)

In [None]:
# More complex: keep the selection at "v1", but change the *argument* alpha2
# (2.0 -> 3.0). alpha2 parameterizes the distribution of "v2", so this shows
# an argument change reaching an unselected address.
#
# `model` and `key` come from the previous cell. The generative function is
# deliberately not redefined here: an identical copy would silently shadow the
# first one and could drift out of sync if only one of them were edited.
old_args = (1.0, 2.0)  # (alpha1, alpha2) used to simulate the starting trace
new_args = (1.0, 3.0)  # only alpha2 changes

# Start from a trace simulated with the same key and arguments as before.
tr = model.simulate(key, old_args)

# Only alpha2 changed, so it is marked `unknown_change`; alpha1 is `no_change`.
_, blanket_fn = model.blanket(
    tr,
    Selection.at["v1"],
    (Diff.no_change(new_args[0]), Diff.unknown_change(new_args[1])),
)

# `blanket` returns a function; tracing it with `make_jaxpr` displays its IR.
make_jaxpr(blanket_fn)(*new_args)