In [None]:
import jax
import matplotlib.pyplot as plt
from genjax._src.adev.core import Dual, expectation
from genjax._src.adev.primitives import flip_enum

# Number of gradient-descent steps taken by each optimisation experiment below.
EPOCHS = 300
# Root PRNG key with a fixed seed; the experiments below split subkeys from it.
key = jax.random.PRNGKey(314159)

In [None]:
def jax_model(key, theta):
    b = jax.random.bernoulli(key, theta)
    return jax.lax.cond(b, lambda _: 0.0, lambda _: -theta / 2, None)


def expected_val(theta):
    """Exact expected loss ``E[jax_model(key, theta)]`` over the Bernoulli draw.

    With probability ``1 - theta`` the loss is ``-theta / 2`` (else ``0``), giving
    ``(theta**2 - theta) / 2``.
    """
    theta_sq = theta**2
    return (theta_sq - theta) / 2


# Naive approach: differentiate a single sampled execution of `jax_model`.
# JAX assigns no gradient to the Bernoulli draw's dependence on `theta`, so each
# step only sees the derivative of whichever branch was sampled (0 or -1/2).
grad = jax.jit(jax.grad(jax_model, argnums=1))

# Cell-local copy of the root key: re-running this cell reproduces the same
# trajectory and leaves the global `key` untouched for the other experiments.
jax_key = key
arg = 0.2  # initial theta
vals = []  # exact expected loss (via `expected_val`) after each step
for _ in range(EPOCHS):
    jax_key, subkey = jax.random.split(jax_key)
    grad_val = grad(subkey, arg)
    arg = arg - 0.01 * grad_val  # plain gradient descent, learning rate 0.01
    vals.append(expected_val(arg))

In [None]:
@expectation
def flip_exact_loss(theta):
    """ADEV counterpart of `jax_model`: 0.0 when the flip is True, else -theta/2.

    Wrapped by ADEV's `expectation`, which exposes `jvp_estimate` for
    differentiating the expected loss. NOTE(review): `flip_enum` is presumably a
    Bernoulli(theta) flip whose derivative is estimated by enumerating both
    outcomes; confirm against the ADEV primitives.
    """
    heads = flip_enum(theta)

    def on_heads(_):
        return 0.0

    def on_tails(_):
        return -theta / 2.0

    return jax.lax.cond(heads, on_heads, on_tails, theta)


# ADEV approach: `jvp_estimate` takes a Dual(primal, tangent). Seeding the
# tangent with 1.0 and reading `.tangent` back yields the (presumably unbiased)
# estimate of d/dtheta of the expected loss.
adev_grad = jax.jit(flip_exact_loss.jvp_estimate)

# Cell-local copy of the root key: re-running this cell reproduces the same
# trajectory and leaves the global `key` untouched.
adev_key = key
arg = 0.2  # initial theta (same starting point as the JAX experiment)
adev_vals = []  # exact expected loss (via `expected_val`) after each step
for _ in range(EPOCHS):
    adev_key, subkey = jax.random.split(adev_key)
    grad_val = adev_grad(subkey, Dual(arg, 1.0)).tangent
    arg = arg - 0.01 * grad_val  # plain gradient descent, learning rate 0.01
    adev_vals.append(expected_val(arg))

In [None]:
# Compare the exact expected loss reached by each gradient estimator.
fig, ax = plt.subplots()
ax.plot(vals, label="JAX")
ax.plot(adev_vals, label="ADEV")
# (theta**2 - theta) / 2 is minimised at theta = 0.5, where it equals -1/8.
ax.axhline(-0.125, color="gray", linestyle="--", label="optimum")
ax.set(
    title="Expected loss during gradient descent on theta",
    xlabel="epoch",
    ylabel="expected loss",
)
ax.legend();