In [None]:
import jax
import matplotlib.pyplot as plt
from genjax._src.adev.core import Dual, expectation
from genjax._src.adev.primitives import flip_enum

# Root PRNG key for the notebook; later cells split fresh subkeys off it.
SEED = 314159
key = jax.random.PRNGKey(SEED)

In [None]:
def jax_model(key, theta):
    """Sample one Bernoulli(theta) draw and return the loss for that outcome.

    Args:
        key: JAX PRNG key used for the Bernoulli draw.
        theta: Success probability of the draw; also enters the loss.

    Returns:
        ``0.0`` when the draw is true, ``-theta / 2`` when it is false.
    """

    def loss_if_true(_):
        return 0.0

    def loss_if_false(_):
        # Uses the enclosing ``theta``; the ``None`` operand is a placeholder.
        return -theta / 2

    drew_true = jax.random.bernoulli(key, theta)
    return jax.lax.cond(drew_true, loss_if_true, loss_if_false, None)


def expected_val(theta):
    """Closed-form value ``(theta**2 - theta) / 2``.

    This equals the mean of ``jax_model``'s two outcomes weighted by their
    probabilities: ``theta * 0 + (1 - theta) * (-theta / 2)``.
    """
    numerator = theta**2 - theta
    return numerator / 2


# Naive baseline: differentiate the sampled loss directly with jax.grad.
# Only the branch selected by the Bernoulli draw is differentiated; the draw
# itself contributes nothing to this gradient.
grad = jax.jit(jax.grad(jax_model, argnums=1))

learning_rate = 0.01  # gradient-descent step size
arg = 0.2  # initial theta
epochs = 1000
vals = []
for _ in range(epochs):
    # Rebinds the notebook-level `key`, so later cells see the advanced key.
    key, subkey = jax.random.split(key)
    grad_val = grad(subkey, arg)
    arg = arg - learning_rate * grad_val
    # Track the closed-form expected loss at the updated theta.
    vals.append(expected_val(arg))


fig, ax = plt.subplots()
ax.plot(vals)
ax.set(
    xlabel="epoch",
    ylabel="expected_val(theta)",
    title="Expected loss under naive jax.grad updates",
)
plt.show()

In [None]:
@expectation
def flip_exact_loss(theta):
    """Coin-flip loss wrapped by ``expectation``.

    Draws ``b`` via ``flip_enum(theta)`` and returns ``0.0`` when ``b`` is
    true, ``-theta / 2.0`` otherwise.

    NOTE(review): going by its name, ``flip_enum`` presumably enumerates both
    flip outcomes rather than sampling one; confirm against the genjax ADEV
    primitives.
    """

    def when_true(_):
        return 0.0

    def when_false(_):
        # Reads the enclosing ``theta``; the operand passed to ``cond`` is unused.
        return -theta / 2.0

    b = flip_enum(theta)
    return jax.lax.cond(b, when_true, when_false, theta)


# Jit-compile the JVP estimator of the wrapped loss and evaluate it once.
jvp_estimator = jax.jit(flip_exact_loss.jvp_estimate)
# NOTE(review): presumably Dual(primal, tangent), i.e. theta = 0.7 with a unit
# tangent; confirm against genjax's `Dual` definition.
theta_dual = Dual(0.7, 1.0)
v, p_tangent = jvp_estimator(key, theta_dual)