In [6]:
import genjax
from genjax import dippl
from genjax import gensp
import jax
import jax.numpy as jnp
import adevjax

# NOTE(review): presumably sets up GenJAX's pretty printing for this session
# (with local variables shown) -- confirm against the genjax documentation.
console = genjax.pretty(show_locals=True)
# Root PRNG key for the notebook; code further down splits it into sub-keys.
key = jax.random.PRNGKey(314159)


@genjax.gen
def abnormal_flipper():
    """Model: a flip with parameter 0.5 at address "flip" plus an extra score.

    The extra score is 10.0 when the flip is true and 0.0 otherwise.
    """
    outcome = dippl.flip_enum(0.5) @ "flip"
    bonus = jax.lax.cond(outcome, lambda: 10.0, lambda: 0.0)

    # `accum_score` is the score primitive: it accumulates `bonus` onto the
    # log prob. It is only active inside genjax.gensp terms.
    gensp.accum_score(bonus)


# Lift to GenSP: wrap the model as a choice map distribution over the
# selected address "flip".
# NOTE(review): the meaning of the trailing `None` argument is not visible
# here -- confirm against `gensp.choice_map_distribution`.
lifted_model = gensp.choice_map_distribution(
    abnormal_flipper, genjax.select("flip"), None
)

# Our variational family, parametrized by p.
@genjax.gen
def variational_family(p):
    """Variational family: a single flip at address "flip" with parameter `p`."""
    # Only the traced choice matters here; the sampled value itself is unused.
    _ = dippl.flip_enum(p) @ "flip"


# Lift to GenSP: wrap the variational family as a choice map distribution
# over the selected address "flip".
# NOTE(review): the meaning of the trailing `None` argument is not visible
# here -- confirm against `gensp.choice_map_distribution`.
lifted_family = gensp.choice_map_distribution(
    variational_family, genjax.select("flip"), None
)

# Gradient estimator written in the loss language: lightweight syntax over
# GenSP's ADEV compatible interfaces and ADEV primitives, including `add_cost`.
def elbo_grad(key, p):
    """Return a gradient estimate of the variational loss with respect to `p`.

    Args:
        key: PRNG key handed to the gradient estimator.
        p: Parameter passed to the variational family.

    Returns:
        The single gradient component produced for `p`.
    """

    # The loss is written imperatively in terms of two interfaces:
    # * `upper` invokes `random_weighted`, returns the sample, and
    #      accumulates (-w) via ADEV's `add_cost` primitive.
    # * `lower` invokes `estimate_logpdf` and accumulates (w)
    #      via ADEV's `add_cost` primitive.
    @dippl.loss
    def flip_variational_loss(prob):
        sample = dippl.upper(lifted_family)(prob)
        dippl.lower(lifted_model)(sample)

    # ADEV derives `jvp` and `grad` estimators automatically. One argument
    # goes in as a 1-tuple, and the result is unpacked as a 1-tuple as well.
    grads = flip_variational_loss.grad_estimate(key, (p,))
    (p_grad,) = grads
    return p_grad


# Trials: look at grads for increasing `p` (overwhelmingly beneficial for the objective).
# Expected behavior: the gradient should decrease with increasing `p`
# (until numerical instability)
# NOTE(review): the `sub_key` produced by this first split is overwritten
# inside the loop before it is ever used; the split only advances `key`.
key, sub_key = jax.random.split(key)
# Wrap the estimator with `jax.jit`; the same jitted callable serves every trial.
jitted = jax.jit(elbo_grad)
for p in [0.0001, 0.001, 0.1, 0.3, 0.5, 0.7, 0.9, 0.9999, 0.99995]:
    # Split off a fresh sub-key for each trial.
    key, sub_key = jax.random.split(key)
    p_grad = jitted(sub_key, p)
    print(p_grad)

19.21024
16.906754
12.197225
10.847298
10.0
9.152702
7.8027754
0.7899256
0.09672737
