In [1]:
import genjax
from genjax import dippl
from genjax import gensp
import jax
import jax.numpy as jnp
import adevjax

# Fixed seed so every Restart & Run All starts from the same PRNG key.
SEED = 314159

# NOTE(review): genjax.pretty presumably installs rich pretty-printing and
# returns a console handle -- confirm against the genjax docs.
console = genjax.pretty(show_locals=True)
# Root PRNG key; split it to obtain a sub-key before each random operation.
key = jax.random.PRNGKey(SEED)

I0000 00:00:1696092911.388769   65821 tfrt_cpu_pjrt_client.cc:349] TfrtCpuClient created.
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


In [2]:
@genjax.gen
def abnormal_flipper():
    """Model with one `flip_enum(0.5)` draw at address "flip" and a flip-dependent score.

    The score term is 10.0 when the draw is true and 0.0 otherwise.
    The function returns nothing; only the traced choice and the score matter.
    """
    is_heads = dippl.flip_enum(0.5) @ "flip"
    bonus = jax.lax.cond(is_heads, lambda: 10.0, lambda: 0.0)
    # Score primitive: accumulates `bonus` onto the log probability.
    gensp.accum_score(bonus)


# Lift the model to a distribution over the choices selected by "flip".
# NOTE(review): the trailing None is presumably "no custom proposal" --
# confirm against the gensp API.
lifted_model = gensp.choice_map_distribution(
    abnormal_flipper,
    genjax.select("flip"),
    None,
)

In [3]:
# Consume a fresh sub-key for this draw and keep `key` for later cells.
key, sub_key = jax.random.split(key, 2)
# `random_weighted` returns a pair; `w` is presumably the sample's weight -- TODO confirm.
w, v = lifted_model.random_weighted(sub_key)
v  # last expression: display the sampled value




└── [1m:flip[0m
    └──  bool[]

In [4]:
@genjax.gen
def variational_family(p):
    """Variational family: a single `flip_enum(p)` draw traced at address "flip".

    `p` is the parameter passed to `flip_enum`. The function returns nothing;
    the draw is made only so that it is recorded under "flip".
    """
    _ = dippl.flip_enum(p) @ "flip"


# Lift the variational family to a distribution over the choices selected by
# "flip". NOTE(review): the trailing None is presumably "no custom proposal"
# -- confirm against the gensp API.
lifted_family = gensp.choice_map_distribution(
    variational_family,
    genjax.select("flip"),
    None,
)

In [5]:
def variational_grad(key, p):
    """Return a gradient estimate of the variational loss with respect to `p`.

    The loss feeds `p` through `dippl.upper(lifted_family)` and passes the
    result to `dippl.lower(lifted_model)`.
    NOTE(review): presumably `upper` samples from the family and `lower`
    scores that sample under the model (an ELBO-style objective) -- confirm
    against the dippl docs.

    Args:
        key: PRNG key handed to `grad_estimate`.
        p: parameter of the variational family.

    Returns:
        The single gradient component corresponding to `p`.
    """

    @dippl.loss
    def flip_variational_loss(theta):
        choices = dippl.upper(lifted_family)(theta)
        dippl.lower(lifted_model)(choices)

    grads = flip_variational_loss.grad_estimate(key, (p,))
    # Exactly one argument was differentiated, so exactly one gradient comes back.
    (p_grad,) = grads
    return p_grad

In [6]:
# `p` is an ordinary (non-static) jit argument, so the Python floats below
# are traced rather than baked into the compiled function.
jitted = jax.jit(variational_grad)
# Sweep the family parameter, including values very close to 1.
for p in [0.1, 0.3, 0.5, 0.7, 0.9, 0.9999, 0.99995]:
    # Fresh sub-key for each gradient estimate.
    key, sub_key = jax.random.split(key)
    p_grad = jitted(sub_key, p)
    print(p_grad)

12.197225
10.847298
10.0
9.152702
7.8027754
0.7899256
0.09672737
