In [1]:
import genjax
from genjax import dippl
from genjax import gensp
import jax
import jax.numpy as jnp
import adevjax

# Pretty-printer setup from genjax; `show_locals=True` is passed through as given.
# NOTE(review): the colored reprs in the cell outputs look like `rich` output —
# presumably this call installs that printer. `console` itself is not used in the
# visible cells.
console = genjax.pretty(show_locals=True)
# Fixed seed for the notebook's global PRNG key. Later cells thread `key`
# through `jax.random.split`, so every random draw below derives from this value.
SEED = 314159
key = jax.random.PRNGKey(SEED)

I0000 00:00:1696011130.487402  147892 tfrt_cpu_pjrt_client.cc:349] TfrtCpuClient created.
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


In [2]:
@genjax.gen
def abnormal_flipper():
    """Toy model: one `dippl.flip_enum(0.5)` draw at address "flip" plus a score bonus.

    When the draw is True, 10.0 is added to the trace score through
    `gensp.accum_score`; otherwise 0.0 is added. Returns None.
    NOTE(review): 0.5 is presumably the probability of True — confirm against dippl.
    """
    coin = dippl.flip_enum(0.5) @ "flip"
    bonus = jax.lax.cond(coin, lambda: 10.0, lambda: 0.0)
    gensp.accum_score(bonus)


# Wrap the model as a distribution over the choices selected at "flip", so that it
# exposes `random_weighted` (sampling) and `estimate_logpdf` (scoring a choice map).
lifted_model = gensp.choice_map_distribution(
    abnormal_flipper,
    genjax.select("flip"),
    None,
)

In [3]:
# Advance the global key, then draw once from the lifted model. The displayed
# value is the (weight, choice map) pair returned by `random_weighted`.
key, sub_key = jax.random.split(key, 2)
lifted_model.random_weighted(sub_key)


[1m([0m
    [1;35mArray[0m[1m([0m[1;36m9.525923[0m, [33mdtype[0m=[35mfloat32[0m[1m)[0m,
    [1;35mValueChoiceMap[0m[1m([0m
        [33mvalue[0m=[1;35mHierarchicalChoiceMap[0m[1m([0m
            [33mtrie[0m=[1;35mTrie[0m[1m([0m[33minner[0m=[1m{[0m[32m'flip'[0m: [1;35mValueChoiceMap[0m[1m([0m[33mvalue[0m=[1;35mArray[0m[1m([0m[3;92mTrue[0m, [33mdtype[0m=[35mbool[0m[1m)[0m[1m)[0m[1m}[0m[1m)[0m
        [1m)[0m
    [1m)[0m
[1m)[0m

In [15]:
@genjax.gen
def variational_family(p):
    """Variational family: a single `dippl.flip_enum(p)` draw traced at "flip".

    Args:
        p: passed straight to `dippl.flip_enum` (presumably the probability of
            True — confirm against dippl).

    The drawn value is not used locally. Only its trace address matters: it
    must match the "flip" address selected on the model. Returns None.
    """
    _ = dippl.flip_enum(p) @ "flip"


# Expose the family as a distribution over its "flip" choice map, so that
# `random_weighted(key, p)` yields a (weight, choice map) pair.
lifted_family = gensp.choice_map_distribution(
    variational_family, genjax.select("flip"), None
)

In [16]:
def variational_grad(key, p):
    """Estimate the gradient, with respect to `p`, of E[model_w - family_w].

    The expectation is over choice maps drawn from `lifted_family` with
    parameter `p`. `family_w` is the weight returned with each sample, and
    `model_w` is `lifted_model.estimate_logpdf` evaluated at that sample.

    Args:
        key: JAX PRNG key. It is split into a sampling key, a log-density key,
            and a remaining key that is handed to the ADEV estimator.
        p: parameter of the variational family.

    Returns:
        Whatever `adevjax.E(...).grad_estimate` returns. Judging by the output
        below, this is a tuple with one entry per element of `(p,)`.
    """
    # Two sequential splits, in this order: sampling key first, then the key
    # for the model's logpdf estimate. `key` keeps the remainder for ADEV.
    key, sample_key = jax.random.split(key)
    key, logpdf_key = jax.random.split(key)

    @adevjax.adev
    def objective(theta):
        # NOTE(review): model_w - family_w is an ELBO-style quantity, meant to be
        # maximized, not a loss. Confirm the sign before passing this gradient
        # to a minimizer.
        family_weight, choices = lifted_family.random_weighted(sample_key, theta)
        model_weight = lifted_model.estimate_logpdf(logpdf_key, choices)
        return model_weight - family_weight

    expectation = adevjax.E(objective)
    return expectation.grad_estimate(key, (p,))

In [19]:
# Take a fresh key from the global one, then evaluate the jit-compiled gradient
# estimator at p = 0.9.
key, sub_key = jax.random.split(key, 2)
jitted_variational_grad = jax.jit(variational_grad)
jitted_variational_grad(sub_key, 0.9)

[1m([0m[1;35mArray[0m[1m([0m[1;36m9.41095[0m, [33mdtype[0m=[35mfloat32[0m[1m)[0m,[1m)[0m