In [None]:
import jax
import matplotlib.pyplot as plt
from genjax._src.adev.core import Dual, expectation
from genjax._src.adev.primitives import flip_enum

key = jax.random.PRNGKey(314159)
EPOCHS = 300

We are often interested in the expected (average) return value of a probabilistic program. For instance, a run of the program might represent one run of a simulation, and we would like to maximize the average reward across many simulations (or, equivalently, minimize an average loss).
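Here is a minimal sketch of what "average return value" means. It estimates the average by running a small program many times and averaging the results, then compares that estimate with the exact expectation, which we get by enumerating both outcomes of the program's single coin flip. The program `simulate` and the helper names are hypothetical, chosen only for illustration. `jax.vmap` runs the independent simulations in parallel.

```python
import jax
import jax.numpy as jnp


def simulate(key, theta):
    # Hypothetical toy program: flip a coin with probability theta,
    # return 0.0 on heads and -theta / 2 on tails.
    b = jax.random.bernoulli(key, theta)
    return jnp.where(b, 0.0, -theta / 2)


def exact_expectation(theta):
    # Enumerate both outcomes and weight each return value by its probability.
    return theta * 0.0 + (1.0 - theta) * (-theta / 2)


def monte_carlo_expectation(key, theta, n=100_000):
    # Run the program n times with independent keys and average the returns.
    keys = jax.random.split(key, n)
    returns = jax.vmap(simulate, in_axes=(0, None))(keys, theta)
    return returns.mean()


theta = 0.2
mc = monte_carlo_expectation(jax.random.PRNGKey(0), theta)
print(float(mc), exact_expectation(theta))  # both close to -0.08
```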

A popular family of optimization techniques is iterative, gradient-based methods such as (stochastic) gradient descent.
One might think we could simply use JAX's automatic differentiation (AD) to compute gradients of probabilistic programs.
Running the program gives a stochastic estimate of its average behaviour. Differentiating it would then give a new program whose runs are stochastic estimates of the gradient of that average.
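When all of a program's randomness has finitely many outcomes, its expectation is a finite, probability-weighted sum. We can therefore get the exact gradient by differentiating that sum directly. The sketch below does this for a hypothetical coin-flip program (`expected_loss` and `exact_grad` are illustrative names). The result is the ground truth that any gradient estimate should match on average. Note that the gradient is zero at θ = 1/2, the minimizer.

```python
import jax


def expected_loss(theta):
    # Hypothetical toy program: heads (prob. theta) returns 0.0,
    # tails (prob. 1 - theta) returns -theta / 2. Enumerating both
    # outcomes gives the expectation in closed form.
    return theta * 0.0 + (1.0 - theta) * (-theta / 2.0)


# Differentiating the enumerated sum gives the exact gradient, theta - 1/2.
exact_grad = jax.grad(expected_loss)

print(exact_grad(0.2))  # approximately -0.3
print(exact_grad(0.5))  # 0.0: theta = 1/2 minimizes the expected loss
```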

If we try this on a simple example in JAX, however, we do not get what we want at all!

In [None]:
def jax_model(key, theta):
    b = jax.random.bernoulli(key, theta)
    return jax.lax.cond(b, lambda _: 0.0, lambda _: -theta / 2, None)


def expected_val(theta):
    # Exact expected loss: theta * 0.0 + (1 - theta) * (-theta / 2).
    return (theta**2 - theta) / 2


grad = jax.jit(jax.grad(jax_model, argnums=1))

arg = 0.2
vals = []
for _ in range(EPOCHS):
    key, subkey = jax.random.split(key)
    grad_val = grad(subkey, arg)
    arg = arg - 0.01 * grad_val
    vals.append(expected_val(arg))

JAX happily computes gradients and runs gradient descent, but let's check whether we actually minimized the expected loss.

In [None]:
plt.plot(vals)

The curve starts off decreasing, but then it goes back up. Instead of minimizing the expected loss, we end up increasing it again!
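One way to see what goes wrong is to average many of the per-sample gradients that `jax.grad` returns for `jax_model` at a fixed θ. The sketch below repeats the model so it runs on its own. At θ = 0.2 the average comes out near (θ − 1)/2 = −0.4, rather than the true gradient θ − 1/2 = −0.3 (the derivative of `expected_val`). Because (θ − 1)/2 is negative for every θ < 1, each update keeps pushing θ upward, past the minimizer θ = 1/2. Once θ passes 1/2, the expected loss rises again.

```python
import jax


def jax_model(key, theta):
    # Same model as above: 0.0 on heads, -theta / 2 on tails.
    b = jax.random.bernoulli(key, theta)
    return jax.lax.cond(b, lambda _: 0.0, lambda _: -theta / 2, None)


# Per-sample gradient with respect to theta, vectorized over many keys.
per_sample_grad = jax.vmap(jax.grad(jax_model, argnums=1), in_axes=(0, None))

theta = 0.2
keys = jax.random.split(jax.random.PRNGKey(0), 100_000)
naive_mean = per_sample_grad(keys, theta).mean()

print(float(naive_mean))  # close to (theta - 1) / 2 = -0.4
print(theta - 0.5)        # true gradient of the expected loss: about -0.3
```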

The reason is that differentiating a single sampled run ignores how the probability of the `bernoulli` coin flip itself changes with `theta`, so that contribution to the gradient is lost. We will come back to this in more detail in follow-up notebooks.
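For this example, we can write out the missing contribution explicitly. Let $p_\theta(1) = \theta$ be the probability of heads, and let $f_\theta(1) = 0$ and $f_\theta(0) = -\theta/2$ be the returned values. The product rule then splits the derivative of the expectation into two sums. Differentiating a single sampled run with `jax.grad` only estimates the first one:

$$
\begin{aligned}
\frac{d}{d\theta}\,\mathbb{E}_{b \sim p_\theta}\big[f_\theta(b)\big]
&= \frac{d}{d\theta} \sum_{b \in \{0,1\}} p_\theta(b)\, f_\theta(b) \\
&= \underbrace{\sum_{b} p_\theta(b)\, \frac{\partial f_\theta(b)}{\partial \theta}}_{\text{estimated by jax.grad}}
\;+\; \underbrace{\sum_{b} f_\theta(b)\, \frac{\partial p_\theta(b)}{\partial \theta}}_{\text{missing}} \\
&= (1-\theta)\left(-\tfrac{1}{2}\right) \;+\; \left(-\tfrac{\theta}{2}\right)(-1) \\
&= \theta - \tfrac{1}{2}.
\end{aligned}
$$

The final line agrees with the derivative of `expected_val`. The first sum alone equals $(\theta - 1)/2$.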

ADEV is a new algorithm for computing correct (unbiased) gradient estimates of expectations of probabilistic programs. It accounts for the change in the expectation that comes from changing the underlying measure, i.e. the distribution from which all the randomness is drawn.
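ADEV's own strategies are not shown here. One classical way to account for the change in the measure is the score-function (REINFORCE) estimator, which adds $f_\theta(b)\,\partial_\theta \log p_\theta(b)$ to the pathwise derivative. The sketch below implements that swapped-in technique in plain JAX for the same coin-flip model. It is an illustration only, not how GenJAX implements `flip_enum`. The name `score_function_grad` is hypothetical. Averaged over many samples at θ = 0.2, the estimate approaches the true gradient θ − 1/2 = −0.3.

```python
import jax
import jax.numpy as jnp


def score_function_grad(key, theta):
    # Sample the coin; the outcome b is held fixed while differentiating.
    b = jax.random.bernoulli(key, theta)

    def f(t):
        # Returned value: 0.0 on heads, -t / 2 on tails.
        return jnp.where(b, 0.0, -t / 2.0)

    def log_p(t):
        # Log-probability of the sampled outcome under Bernoulli(t).
        return jnp.where(b, jnp.log(t), jnp.log1p(-t))

    # Pathwise term plus the score-function term f(b) * d/dt log p(b).
    return jax.grad(f)(theta) + f(theta) * jax.grad(log_p)(theta)


theta = 0.2
keys = jax.random.split(jax.random.PRNGKey(0), 100_000)
grads = jax.vmap(score_function_grad, in_axes=(0, None))(keys, theta)
print(float(grads.mean()))  # close to theta - 1/2 = -0.3
```

Each sample is unbiased on its own. The trade-off is extra variance from the score term, which is one reason ADEV offers different gradient strategies per primitive.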

GenJAX implements ADEV. If we rewrite the example above slightly using GenJAX, we can see how differently the optimization behaves with correct gradient estimates.

In [None]:
@expectation
def flip_exact_loss(theta):
    b = flip_enum(theta)
    return jax.lax.cond(
        b,
        lambda _: 0.0,
        lambda _: -theta / 2.0,
        theta,
    )


adev_grad = jax.jit(flip_exact_loss.jvp_estimate)

arg = 0.2
adev_vals = []
for _ in range(EPOCHS):
    key, subkey = jax.random.split(key)
    # Push the tangent 1.0 forward; the output tangent is the gradient estimate.
    grad_val = adev_grad(subkey, Dual(arg, 1.0)).tangent
    arg = arg - 0.01 * grad_val
    adev_vals.append(expected_val(arg))

In [None]:
plt.plot(vals)
plt.plot(adev_vals)
plt.legend(["JAX", "ADEV"])
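To read this plot, it helps to know where the expected loss bottoms out. `expected_val` is minimized at θ = 1/2, where it equals −1/8. As a noise-free reference, the sketch below runs the same 300 updates with step size 0.01, but uses the exact gradient θ − 1/2 instead of a sampled estimate. This is a plain-Python baseline, not GenJAX code. The distance to 1/2 shrinks by a factor of 0.99 per step, so after 300 steps θ is still about 0.015 short of 1/2.

```python
def expected_val(theta):
    # Exact expected loss of the coin-flip model, as defined earlier.
    return (theta**2 - theta) / 2


def exact_grad(theta):
    # Derivative of expected_val: theta - 1/2.
    return theta - 0.5


EPOCHS = 300
theta = 0.2
reference_vals = []
for _ in range(EPOCHS):
    # Same update rule and step size as the loops above, but noise-free.
    theta = theta - 0.01 * exact_grad(theta)
    reference_vals.append(expected_val(theta))

print(theta)               # about 0.485
print(reference_vals[-1])  # about -0.1249, just above the minimum -0.125
```

Running `plt.plot(reference_vals)` alongside the curves above overlays this reference.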

The example above used the forward-mode version of ADEV. GenJAX also supports a reverse-mode version, which is likewise fully compatible with JAX.
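In plain JAX, the two modes correspond to `jax.jvp` and `jax.grad`. `jax.jvp` pushes a tangent forward, much like passing `Dual(arg, 1.0)` above. `jax.grad` pulls a cotangent backward. For a function of a single scalar, both return the same derivative. The sketch below checks this on the exact expected loss `expected_val`, redefined so the block runs on its own. It is only an analogy and does not call GenJAX's `jvp_estimate` or `grad_estimate`.

```python
import jax


def expected_val(theta):
    # Exact expected loss of the coin-flip model, as defined earlier.
    return (theta**2 - theta) / 2


theta = 0.2

# Forward mode: push the tangent 1.0 through the function.
_, forward = jax.jvp(expected_val, (theta,), (1.0,))

# Reverse mode: pull the output cotangent back to the input.
reverse = jax.grad(expected_val)(theta)

print(float(forward), float(reverse))  # both approximately theta - 1/2 = -0.3
```

Reverse mode becomes the cheaper choice when a scalar loss depends on many parameters, since one backward pass yields the whole gradient.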

In [None]:
rev_adev_grad = jax.jit(flip_exact_loss.grad_estimate)

arg = 0.2
rev_adev_vals = []
for _ in range(EPOCHS):
    key, subkey = jax.random.split(key)
    (grad_val,) = rev_adev_grad(subkey, (arg,))
    arg = arg - 0.01 * grad_val
    rev_adev_vals.append(expected_val(arg))

plt.plot(rev_adev_vals)
plt.legend(["reverse mode ADEV"])