In [None]:
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

from genjax._src.adev.core import Dual, expectation
from genjax._src.adev.primitives import flip_enum

# Number of optimisation steps used by every gradient-ascent loop below.
EPOCHS = 600
# Global PRNG key; later cells split it to obtain fresh randomness.
key = jax.random.PRNGKey(seed=314159)

We are often interested in the average returned value of a probabilistic program. For instance, it could be that 
a run of the program represents a run of a simulation of some form, and we would like to maximize the average reward across many simulations (or equivalently minimize a loss).

In [None]:
def jax_model(key, theta):
    """Simulate one run of the toy program.

    A Bernoulli(``theta``) coin is flipped using ``key``; the run returns
    ``0.0`` when the flip is True and ``theta / 2`` when it is False.
    """

    def heads_branch(_):
        return 0.0

    def tails_branch(_):
        return theta / 2

    is_heads = jax.random.bernoulli(key, theta)
    return jax.lax.cond(is_heads, heads_branch, tails_branch, None)


# Grid of success probabilities at which the simple model is simulated.
thetas = jnp.arange(0.0, 1.0, 0.002)
# Advance the global key and derive the per-run keys from a fresh subkey, so
# these simulations never reuse randomness consumed by another cell.
key, subkey = jax.random.split(key)
keys = jax.random.split(subkey, len(thetas))

# One independent simulation per theta value.
samples = jax.vmap(jax_model, in_axes=(0, 0))(keys, thetas)

We can see that the simulation has two "modes" that split further apart as $\theta$ increases.

In [None]:
# Scatter one raw sample of the simple model per theta value.
ax = plt.gca()
ax.scatter(thetas, samples, s=1, label="samples")
ax.set_xlabel(r"$\theta$")
ax.set_ylabel("y")
ax.legend()

We can also easily imagine a more noisy version of the same idea.

In [None]:
# Scale of the Gaussian noise added in the noisy variant of the model.
sigma = 0.25


def more_noisy_jax_model(key, theta):
    """Noisy variant of the toy program.

    A Bernoulli(``theta``) coin is flipped. On True the run returns zero-mean
    noise scaled by ``sigma * theta**2 / 3``. On False it returns a noisy
    version of ``theta / 2``, namely ``(noise * sigma + theta) / 2``.

    The coin flip and the Gaussian noise use independent subkeys. Reusing
    ``key`` for both would make the noise a deterministic function of the same
    uniform bits that decide the flip. The "noise" would then be sign-correlated
    with the branch taken.
    """
    flip_key, noise_key = jax.random.split(key)
    b = jax.random.bernoulli(flip_key, theta)
    # One standard-normal draw; only the selected branch uses it.
    noise = jax.random.normal(noise_key)
    return jax.lax.cond(
        b,
        lambda _: noise * sigma * theta**2 / 3,
        lambda _: (noise * sigma + theta) / 2,
        None,
    )


# Finer theta grid for the noisy model.
more_thetas = jnp.arange(0.0, 1.0, 0.0005)
# Advance the global key and derive the per-run keys from a fresh subkey, so
# these simulations never reuse randomness consumed by another cell.
key, subkey = jax.random.split(key)
keys = jax.random.split(subkey, len(more_thetas))

noisy_samples = jax.vmap(more_noisy_jax_model, in_axes=(0, 0))(keys, more_thetas)

plt.scatter(more_thetas, noisy_samples, s=1, label="samples")
plt.xlabel(r"$\theta$")
plt.ylabel("y")
plt.legend()

As the noisy version makes especially clear, the samples divide into two groups. One tends to go up as $\theta$ increases, while the other stays centered around 0, with a spread that grows with $\theta$. For simplicity of the analysis, in the rest of this notebook we will stick to the simpler first example.


In that simple case, we can compute the exact average value of the random process as a function of $\theta$. We have probability $\theta$ to return $0$ and probability $1-\theta$ to return $\frac{\theta}{2}$. So overall the expected value is
$$\theta \cdot 0 + (1-\theta) \cdot \frac{\theta}{2} = \frac{\theta-\theta^2}{2}$$

We can code this and plot the result for comparison.

In [None]:
def expected_val(theta):
    """Closed-form expectation of ``jax_model`` at success probability ``theta``.

    Equals ``theta * 0 + (1 - theta) * theta / 2``. Works on scalars and
    elementwise on arrays.
    """
    theta_squared = theta**2
    return (theta - theta_squared) / 2


# Evaluate the closed-form expectation at every theta in the grid.
exact_vals = jax.vmap(expected_val)(thetas)

# Overlay the expectation curve on the raw samples for comparison.
ax = plt.gca()
ax.plot(thetas, exact_vals, color="red", label="Expected value")
ax.scatter(thetas, samples, s=1, label="Samples")
ax.set_xlabel(r"$\theta$")
ax.set_ylabel("y")
ax.legend()

We can see that the curve in red is a perfectly reasonable differentiable function. We can use JAX to compute its derivative (more generally its gradient) at various points.

In [None]:
# Exact derivative of the closed-form expectation, jit-compiled once.
grad_exact = jax.jit(jax.grad(expected_val))
theta_tangent_points = [0.0, 0.3, 0.5, 1.0]

plot_thetas = jnp.linspace(0, 1, 400)
y = expected_val(plot_thetas)

ax = plt.gca()
ax.plot(plot_thetas, y, label="Expected value")

# Tangent at theta_tan: the line through (theta_tan, E(theta_tan)) whose
# slope is E'(theta_tan), written as slope * theta + intercept.
for theta_tan in theta_tangent_points:
    slope = grad_exact(theta_tan)
    y_intercept = expected_val(theta_tan) - slope * theta_tan
    tangent_line = slope * plot_thetas + y_intercept
    ax.plot(plot_thetas, tangent_line, "--", label=rf"Tangent at $\theta$={theta_tan}")

ax.set_xlabel(r"$\theta$")
ax.set_ylabel("y")
ax.legend()
ax.set_title("Expectation curve and its Tangent Lines")
ax.grid(True)
ax.set_xlim([0, 1])
ax.set_ylim([0, 0.4])
plt.show()

A popular technique from optimization is to use iterative methods such as (stochastic) gradient ascent.
Starting from any location, say 0.2, we can repeatedly follow the gradient computed by JAX to find the maximum of the function.

In [None]:
# Projected gradient ascent on the exact expectation, starting at theta = 0.2.
# After each step theta is clipped back into [0, 1], the valid range of a
# Bernoulli probability. The bounds are floats: clamping to a Python int
# (0 or 1) would make the next `grad_exact` call fail, since `jax.grad` only
# accepts floating-point inputs.
arg = 0.2
vals = []
arg_list = []
for _ in range(EPOCHS):
    grad_val = grad_exact(arg)
    arg_list.append(arg)
    vals.append(expected_val(arg))
    arg = jnp.clip(arg + 0.01 * grad_val, 0.0, 1.0)

We can plot the evolution of the value of the function over the iterations of the algorithm.

In [None]:
# Expected value reached at each iteration of exact gradient ascent. The
# x axis is the iteration index (not theta), and the line carries a label so
# that `plt.legend()` has an entry to show.
plt.plot(vals, label="Exact gradient ascent")
plt.xlabel("Iteration number")
plt.ylabel("y")
plt.legend()
plt.grid(True)
plt.show()

We can also directly visualize the points on the curve.

In [None]:
# Endpoints of the color gradient used in the scatter plots (yellow to red).
color1, color2 = "#D4CC47", "#FB575D"


def hex_to_RGB(hex_str):
    """Convert a hex color string to a list of three 0-255 channel values.

    ``"#FFFFFF" -> [255, 255, 255]``. The leading ``#`` is optional. Without
    this handling, ``"FFFFFF"`` would be parsed from the wrong offsets. Only
    the first six hex digits are read.

    Raises:
        ValueError: if fewer than six characters follow the optional ``#``,
            or if they are not hexadecimal digits.
    """
    digits = hex_str[1:] if hex_str.startswith("#") else hex_str
    if len(digits) < 6:
        raise ValueError(f"expected a color like '#RRGGBB', got {hex_str!r}")
    # Base-16 parse of each two-digit channel: RR, GG, BB.
    return [int(digits[i : i + 2], 16) for i in range(0, 6, 2)]


def get_color_gradient(c1, c2, n):
    """Return ``n`` hex colors linearly interpolated from ``c1`` to ``c2``.

    Args:
        c1: start color as a hex string such as ``"#D4CC47"``.
        c2: end color in the same format.
        n: number of colors to produce (``0`` gives an empty list).

    Returns:
        A list of ``n`` lowercase ``"#rrggbb"`` strings. The first is ``c1``
        and, when ``n > 1``, the last is ``c2``.

    Raises:
        ValueError: if ``n`` is negative.
    """
    if n < 0:
        raise ValueError(f"n must be non-negative, got {n}")
    start = [channel / 255 for channel in hex_to_RGB(c1)]
    end = [channel / 255 for channel in hex_to_RGB(c2)]
    # A single color needs no interpolation; max() avoids dividing by zero.
    steps = max(n - 1, 1)
    colors = []
    for k in range(n):
        mix = k / steps
        rgb = [(1 - mix) * a + mix * b for a, b in zip(start, end)]
        colors.append("#" + "".join(format(int(round(v * 255)), "02x") for v in rgb))
    return colors


# Points visited by exact gradient ascent, colored from yellow (first
# iterate) to red (last iterate), over the faint exact expectation curve.
plt.scatter(
    arg_list,
    vals,
    color=get_color_gradient(color1, color2, len(arg_list)),
    s=1,
    label="Gradient ascent: yellow at start and red at the end",
)
plt.plot(thetas, exact_vals, alpha=0.3, label="Expected value")
plt.xlabel(r"$\theta$")
plt.ylabel("y")
plt.legend()
plt.grid(True)
plt.show()

In this example we can compute the average value exactly, but this will not be the case in general. One popular technique to approximate an average value is Monte Carlo integration: we draw many samples from the program and take their average.

As we use more and more samples, the estimate converges to the correct value by the law of large numbers; the central limit theorem further tells us that the error shrinks roughly like $1/\sqrt{n}$.

In [None]:
# Monte Carlo estimates of E[y] at theta = 0.3: seven independent estimates
# per sample size. Sizes are sorted so the color gradient runs from the
# smallest (yellow) to the largest (red) sample size.
number_of_samples = sorted([1, 3, 5, 10, 20, 50, 100, 200, 500, 1000] * 7)
means = []
for n in number_of_samples:
    # Draw this estimate's keys from the fresh subkey. Splitting the new
    # `key` instead would reuse it when it is split again next iteration.
    key, subkey = jax.random.split(key)
    keys = jax.random.split(subkey, n)
    samples = jax.vmap(jax_model, in_axes=(0, None))(keys, 0.3)
    mean = jnp.mean(samples)
    means.append(mean)

plt.scatter(
    number_of_samples,
    means,
    s=10,
    color=get_color_gradient(color1, color2, len(number_of_samples)),
    label="Mean estimate",
)

plt.xscale("log")
plt.axhline(expected_val(0.3), color="green", alpha=0.2, label="True value")
plt.xlabel("Number of samples")
plt.ylabel("y")
plt.legend()
plt.show()

As we just discussed, most of the time we will not be able to compute the average value in closed form and then differentiate it with JAX. One thing we may want to try is to use JAX on the probabilistic program to get a gradient estimate, and hope that by using more and more samples this will converge to the correct gradient that we could use in optimization. Let's try it in JAX.

In [None]:
plot_thetas = jnp.linspace(0, 1, 400)
y = expected_val(plot_thetas)

plt.plot(plot_thetas, y, label="Expected value")

theta_tan = 0.3
slope = grad_exact(theta_tan)

# Illustrative "estimates": the exact slope shifted by multiples of 1/20.
# They are hand-made rather than computed from samples, and only show what a
# spread of noisy gradient estimates around the true tangent could look like.
slope_estimates = [slope + i / 20 for i in range(-4, 4)]
y_intercept = expected_val(theta_tan) - slope * theta_tan
tangent_line = slope * plot_thetas + y_intercept
plt.plot(
    plot_thetas,
    tangent_line,
    "--",
    label=r"Exact tangent at $\theta$=" + f"{theta_tan}",
)
for idx, slope_est in enumerate(slope_estimates):
    y_intercept = expected_val(theta_tan) - slope_est * theta_tan
    tangent_line = slope_est * plot_thetas + y_intercept
    # Only the first estimate gets a legend entry; "_nolegend_" hides the rest.
    # This avoids drawing one estimate a second time just to attach a label.
    plt.plot(
        plot_thetas,
        tangent_line,
        "--",
        alpha=0.3,
        color="orange",
        label="Tangent Estimate" if idx == 0 else "_nolegend_",
    )

plt.xlabel(r"$\theta$")
plt.ylabel("y")
plt.legend()
plt.title(r"Expectation curve and Tangent Estimates at $\theta=$0.3")
plt.grid(True)
plt.xlim([0, 1])
plt.ylim([0, 0.4])
plt.show()

In [None]:
def jax_model(key, theta):
    """One run of the toy program (same model as defined earlier).

    Flips a Bernoulli(``theta``) coin with ``key``. Returns ``0.0`` if the
    coin comes up True and ``theta / 2`` otherwise.
    """
    coin = jax.random.bernoulli(key, theta)

    def zero_reward(_):
        return 0.0

    def half_theta(_):
        return theta / 2

    return jax.lax.cond(coin, zero_reward, half_theta, None)


# Naive estimator: differentiate a single sampled run of `jax_model` with
# respect to theta (argument 1) and treat the result as the gradient.
grad = jax.jit(jax.grad(jax_model, argnums=1))

arg = 0.2
vals = []
for _step in range(EPOCHS):
    key, subkey = jax.random.split(key)
    grad_val = grad(subkey, arg)
    arg += 0.01 * grad_val
    # Track the exact expectation at the new iterate.
    vals.append(expected_val(arg))

JAX seems happy to compute something, and we can use the iterative technique from before, but let's see whether we managed to maximize the function.

In [None]:
# Expected value reached at each iteration when using naive JAX gradients.
ax = plt.gca()
ax.plot(vals, label="Attempting gradient ascent with JAX")
ax.set_xlabel("Iteration number")
ax.set_ylabel("y")
ax.legend()
ax.set_title("Maximization of the expected value of a probabilistic function")
plt.show()

Whoops! We seem to start well, but then for some reason the curve goes back down and we end up minimizing the function instead of maximizing it!

The reason is that we failed to account for the change in the contribution of the `bernoulli` coin flip during differentiation; we will come back to this in more detail in follow-up notebooks.

For now, let's just get a sense of what the gradient estimates computed by JAX look like.

In [None]:
theta_tangent_points = [0.15, 0.3, 0.65, 0.8]

plot_thetas = jnp.linspace(0, 1, 400)
y = expected_val(plot_thetas)

plt.plot(plot_thetas, y, label="Expected value")

# One naive single-sample JAX gradient per point, drawn as a tangent line
# through the exact expectation at that point.
for theta_tan in theta_tangent_points:
    # Sample with the fresh subkey. Passing `key` would reuse the key that the
    # next split (or a later cell) consumes.
    key, subkey = jax.random.split(key)
    slope = grad(subkey, theta_tan)
    y_intercept = expected_val(theta_tan) - slope * theta_tan
    tangent_line = slope * plot_thetas + y_intercept
    plt.plot(
        plot_thetas,
        tangent_line,
        "--",
        label=r"Tangent estimate at $\theta$=" + f"{theta_tan}",
    )

plt.xlabel(r"$\theta$")
plt.ylabel("y")
plt.legend()
plt.title("Expectation curve and JAX-computed tangent estimates")
plt.grid(True)
plt.xlim([0, 1])
plt.ylim([0, 0.4])
plt.show()

Ouch. They do not look even remotely close to valid gradient estimates.

ADEV is a new algorithm that computes unbiased (i.e. correct in expectation) gradient estimates of expectations of probabilistic programs. It accounts for the change to the expectation coming from a change to the randomness present in the expectation.

GenJAX implements ADEV. Slightly rewriting the example from above using GenJAX, we can see how different the behaviour of the optimization process with the corrected gradient estimates is.

In [None]:
@expectation
def flip_exact_loss(theta):
    """Loss whose expectation ADEV differentiates: the negated output of `jax_model`.

    `flip_enum(theta)` takes the place of `jax.random.bernoulli` in
    `jax_model`. Presumably it returns True with probability `theta` and lets
    ADEV compute the gradient by enumerating both outcomes (a guess from the
    name -- confirm against the genjax ADEV primitives). The result is 0.0
    when the flip is True and `-theta / 2.0` otherwise. If the flip matches
    `bernoulli(theta)`, then minimizing this loss maximizes `expected_val`.
    """
    b = flip_enum(theta)
    # Both branches ignore their operand; `theta` is only passed to fill the
    # operand slot of `lax.cond`.
    return jax.lax.cond(
        b,
        lambda _: 0.0,
        lambda _: -theta / 2.0,
        theta,
    )


# Forward-mode ADEV: feed theta as a dual number with tangent 1.0 and read
# the derivative estimate from the tangent of the result.
adev_grad = jax.jit(flip_exact_loss.jvp_estimate)

arg = 0.2
adev_vals = []
for _epoch in range(EPOCHS):
    key, subkey = jax.random.split(key)
    dual_out = adev_grad(subkey, Dual(arg, 1.0))
    grad_val = dual_out.tangent
    # Descend on the loss (its negation is the reward being maximized).
    arg -= 0.01 * grad_val
    adev_vals.append(expected_val(arg))

In [None]:
# Overlay the two optimization traces: naive JAX gradients vs. ADEV.
ax = plt.gca()
for trace in (vals, adev_vals):
    ax.plot(trace)
ax.legend(["Gradient ascent with JAX", "Gradient ascent with ADEV"])
ax.set_title("Maximization of the expected value of a probabilistic function")
ax.set_xlabel("Iteration number")
ax.set_ylabel("y")

In the above example, by using `jvp_estimate` we used a forward-mode version of ADEV. GenJAX also supports a reverse-mode version which is also fully compatible with JAX and can be jitted.

In [None]:
# Reverse-mode ADEV: `grad_estimate` takes the key and a tuple of primal
# arguments and returns a matching tuple of gradient estimates (here a
# 1-tuple for theta).
rev_adev_grad = jax.jit(flip_exact_loss.grad_estimate)

arg = 0.2
rev_adev_vals = []
for _epoch in range(EPOCHS):
    key, subkey = jax.random.split(key)
    (grad_val,) = rev_adev_grad(subkey, (arg,))
    arg -= 0.01 * grad_val
    rev_adev_vals.append(expected_val(arg))

ax = plt.gca()
ax.plot(rev_adev_vals, color="orange")
ax.legend(["Reverse mode ADEV"])
ax.set_title("Maximization of the expected value of a probabilistic function")
ax.set_xlabel("Iteration number")
ax.set_ylabel("y")