In [None]:
# Imports and constants
import genstudio.plot as Plot
import jax
import jax.numpy as jnp

from genjax._src.adev.core import Dual, expectation
from genjax._src.adev.primitives import flip_enum, normal_reparam

# Notebook-wide settings.
EPOCHS = 400  # iterations of each gradient-ascent loop below
sigma = 0.05  # scale applied to standard-normal noise in the model
key = jax.random.key(314159)  # root PRNG key; later cells split it into subkeys

In [None]:
# Model
def noisy_jax_model(key, theta):
    """Draw one sample of the toy two-branch model using plain JAX.

    With probability ``theta`` the sample is ``eps * sigma * theta``; otherwise it
    is ``eps * sigma + theta / 2``, where ``eps ~ N(0, 1)`` and ``sigma`` is the
    notebook-level noise scale. For ``theta`` in [0, 1] the mean is
    ``(1 - theta) * theta / 2``.

    Args:
        key: JAX PRNG key for this draw.
        theta: probability of taking the first branch.

    Returns:
        A scalar JAX array.
    """
    # Independent keys for the coin flip and the noise. Reusing a single key for
    # both makes the Gaussian a deterministic function of the very uniform bits that
    # decide the flip, correlating them and biasing the mean away from
    # (theta - theta**2) / 2.
    flip_key, noise_key = jax.random.split(key)
    b = jax.random.bernoulli(flip_key, theta)
    eps = jax.random.normal(noise_key)
    return jax.lax.cond(
        b,
        lambda theta: eps * sigma * theta,
        lambda theta: eps * sigma + theta / 2,
        theta,
    )


def expected_val(theta):
    """Closed-form expected value of the toy model at ``theta``.

    The first branch has mean 0 and the second (taken with probability
    ``1 - theta``) has mean ``theta / 2``, giving ``(theta - theta**2) / 2``.
    Works elementwise on arrays.
    """
    theta_squared = theta**2
    return (theta - theta_squared) / 2

In [None]:
# Samples
thetas = jnp.arange(0.0, 1.0, 0.0005)
# Derive the per-sample keys from a fresh subkey and advance the root `key`: later
# cells keep drawing from `key` via `jax.random.split(key)`, and a JAX key must not
# be consumed twice, or supposedly independent draws can reuse the same keys.
key, samples_key = jax.random.split(key)
keys = jax.random.split(samples_key, len(thetas))

noisy_samples = jax.vmap(noisy_jax_model, in_axes=(0, 0))(keys, thetas)

# Shared axis labels, legend and grid for the plots that follow.
plot_options = Plot.new(
    Plot.color_legend(),
    {"x": {"label": "θ"}, "y": {"label": "y"}},
    Plot.aspect_ratio(1),
    Plot.grid(),
)

(
    Plot.dot({"x": thetas, "y": noisy_samples}, fill=Plot.constantly("samples"), r=2)
    + plot_options
)

In [None]:
# Adding exact expectation
# Closed-form expectation on the same θ grid as the samples.
exact_vals = jax.vmap(expected_val)(thetas)

# Reusable layer: the exact curve plus the shared axis/legend options.
expected_value_plot = (
    Plot.line(
        dict(x=thetas, y=exact_vals),
        strokeWidth=2,
        stroke=Plot.constantly("Expected value"),
    )
    + plot_options
)

# Overlay the noisy samples on the exact curve.
(
    Plot.dot(dict(x=thetas, y=noisy_samples), fill=Plot.constantly("samples"), r=2)
    + expected_value_plot
    + plot_options
)

In [None]:
# JAX-computed exact gradient of the closed-form expectation.
grad_exact = jax.jit(jax.grad(expected_val))
theta_tangent_points = [0.1, 0.3, 0.45]

plot_thetas = jnp.linspace(0, 1, 400)
y = expected_val(plot_thetas)

# Gradient ascent on the ideal (exact) expectation curve, starting from θ = 0.2.
arg = 0.2
vals = []
arg_list = []
for _ in range(EPOCHS):
    grad_val = grad_exact(arg)
    arg_list.append(arg)
    vals.append(expected_val(arg))
    # Keep θ a valid probability. Clip to *float* bounds instead of assigning the
    # ints 0 / 1: jax.grad rejects integer inputs, so the next grad_exact call would
    # fail. Clip instead of `break` so `vals` always holds EPOCHS entries for the
    # per-iteration plot.
    arg = jnp.clip(arg + 0.01 * grad_val, 0.0, 1.0)

# Endpoint colors for get_color_gradient (start, end).
color1, color2 = "#D4CC47", "#FB575D"


def hex_to_RGB(hex_str):
    """Parse a ``"#RRGGBB"`` string into a list of three 0-255 channel ints.

    Example: ``"#FFFFFF" -> [255, 255, 255]``.
    """
    channels = []
    for offset in (1, 3, 5):  # skip the leading '#', then two hex digits per channel
        pair = hex_str[offset : offset + 2]
        channels.append(int(pair, 16))
    return channels


def get_color_gradient(c1, c2, n):
    """Return ``n`` hex colors linearly interpolated from ``c1`` to ``c2``.

    Args:
        c1: start color as a ``"#RRGGBB"`` string.
        c2: end color as a ``"#RRGGBB"`` string.
        n: number of colors to produce (>= 1). The first entry is ``c1``; when
            ``n > 1`` the last entry is ``c2``.

    Returns:
        A list of ``n`` lowercase ``"#rrggbb"`` strings.

    Raises:
        ValueError: if ``n < 1``.
    """
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if n < 1:
        raise ValueError(f"n must be >= 1, got {n}")
    start = hex_to_RGB(c1)
    end = hex_to_RGB(c2)
    # A single color needs no interpolation (and n - 1 would be zero).
    mix_pcts = [0.0] if n == 1 else [i / (n - 1) for i in range(n)]
    colors = []
    for mix in mix_pcts:
        # Plain-Python arithmetic on 3 channels; no need for device arrays here.
        rgb = [round((1 - mix) * a + mix * b) for a, b in zip(start, end)]
        colors.append("#" + "".join(format(channel, "02x") for channel in rgb))
    return colors

In [None]:
# Objective value along the exact-gradient ascent trajectory. The x axis is derived
# from `vals` itself so the two series always have matching lengths.
(
    Plot.line({"x": list(range(len(vals))), "y": vals})
    + {"x": {"label": "Iteration"}, "y": {"label": "y"}}
    + Plot.title("Gradient ascent on the exact expectation")
)

In [None]:
def tangent_line_plot(theta_tan):
    """Plot layer for the exact tangent to ``expected_val`` at ``theta_tan``.

    The slope comes from ``grad_exact``; the line is evaluated on ``plot_thetas``.
    """
    slope = grad_exact(theta_tan)
    value_at_tan = expected_val(theta_tan)
    # Slope-intercept form of the line through (theta_tan, value_at_tan).
    intercept = value_at_tan - slope * theta_tan
    return Plot.line(
        dict(x=plot_thetas, y=slope * plot_thetas + intercept),
        opacity=0.75,
        stroke=Plot.constantly(f"Tangent at θ={theta_tan}"),
    )


# Exact tangents at a few θ values, drawn over the expectation curve.
(
    plot_options
    + list(map(tangent_line_plot, theta_tangent_points))
    + expected_value_plot
    + Plot.domain([0, 1], [0, 0.4])
    + Plot.title("Expectation curve and its Tangent Lines")
)

In [None]:
theta_tan = 0.3
slope = grad_exact(theta_tan)
# Exact tangent at theta_tan, in slope-intercept form on the plotting grid.
y_intercept = expected_val(theta_tan) - slope * theta_tan
tangent_line = slope * plot_thetas + y_intercept
# Eight perturbed slopes (exact slope + k/20 for k = -4..3) standing in for noisy estimates.
slope_estimates = [slope + k / 20 for k in range(-4, 4)]


def slope_plot(slope_est):
    """Plot layer for the line through (theta_tan, expected_val(theta_tan)) with slope ``slope_est``.

    Reads the notebook-level ``theta_tan`` and ``plot_thetas``.
    """
    anchor_y = expected_val(theta_tan)
    intercept = anchor_y - slope_est * theta_tan
    line_y = slope_est * plot_thetas + intercept
    return Plot.line(
        dict(x=plot_thetas, y=line_y),
        strokeWidth=2,
        stroke=Plot.constantly("Tangent estimate"),
    )


# The exact-tangent label doubles as a color_map key, so build it once from
# `theta_tan` rather than hard-coding the θ value in several places.
exact_tangent_label = f"Exact tangent at θ={theta_tan}"

exact_tangent_plot = Plot.line(
    {"x": plot_thetas, "y": tangent_line},
    strokeWidth=2,
    stroke=Plot.constantly(exact_tangent_label),
)


Plot.new(
    [slope_plot(slope_est) for slope_est in slope_estimates],
    exact_tangent_plot,
    Plot.dot({"x": thetas, "y": noisy_samples}, fill=Plot.constantly("Samples"), r=2),
    Plot.title(f"Expectation curve and Tangent Estimates at θ={theta_tan}"),
    # Keys must match the stroke/fill labels of the layers above.
    Plot.color_map({
        "Expected value": "blue",
        "Tangent estimate": "rgba(255,165,0,0.3)",
        exact_tangent_label: "rgba(255,165,0,1)",
        "Samples": "teal",
    }),
    expected_value_plot,
    plot_options,
)

In [None]:
# Naive approach: differentiate a single sample of the stochastic model w.r.t. θ.
jax_grad = jax.jit(jax.grad(noisy_jax_model, argnums=1))

arg = 0.2
vals, grads = [], []
for _ in range(EPOCHS):
    key, subkey = jax.random.split(key)  # fresh randomness every step
    grad_val = jax_grad(subkey, arg)
    arg = arg + 0.01 * grad_val
    grads.append(grad_val)
    vals.append(expected_val(arg))

In [None]:
# Exact objective along the trajectory driven by JAX's per-sample gradients.
(
    Plot.line(
        dict(x=list(range(EPOCHS)), y=vals),
        stroke=Plot.constantly("Attempting gradient ascent with JAX"),
    )
    + {"x": {"label": "Iteration"}, "y": {"label": "y"}}
    + Plot.domainX([0, EPOCHS])
    + Plot.title("Maximization of the expected value of a probabilistic function")
    + Plot.color_legend()
)

In [None]:
# θ values at which to draw JAX-estimated tangents.
theta_tangent_points = [0.15, 0.3, 0.65, 0.8]
# Plotting grid (400 points on [0, 1]) and the exact curve on it.
plot_thetas = jnp.linspace(0, 1, 400)
y = expected_val(plot_thetas)


def theta_tangent_plot(theta_tan):
    """Plot layer for a JAX-estimated tangent to the expectation at ``theta_tan``.

    The slope is one stochastic ``jax_grad`` evaluation, so each call splits and
    advances the notebook-global ``key``.
    """
    global key
    key, draw_key = jax.random.split(key)
    est_slope = jax_grad(draw_key, theta_tan)
    intercept = expected_val(theta_tan) - est_slope * theta_tan
    return Plot.line(
        dict(x=plot_thetas, y=est_slope * plot_thetas + intercept),
        opacity=0.75,
        stroke=Plot.constantly(f"Tangent estimate at θ={theta_tan}"),
    )


# JAX-estimated tangents over the exact expectation curve.
(
    plot_options
    + list(map(theta_tangent_plot, theta_tangent_points))
    + expected_value_plot
    + Plot.domain([0, 1], [0, 0.4])
    + Plot.title("Expectation curve and JAX-computed tangent estimates")
)

In [None]:
@expectation
def flip_approx_loss(theta):
    """ADEV counterpart of ``noisy_jax_model``.

    Same two-branch structure, but built from the ADEV primitives ``flip_enum``
    and ``normal_reparam``. It is wrapped in ``@expectation`` so that its
    ``jvp_estimate`` can be used below as a gradient estimator.

    NOTE(review): the estimation strategy of each primitive lives in genjax's
    ADEV module (not visible here). By their names, this is presumably
    enumeration for the flip and reparameterization for the normal — confirm.
    """
    b = flip_enum(theta)
    return jax.lax.cond(
        b,
        # first branch: noise scaled by theta (mirrors noisy_jax_model)
        lambda theta: normal_reparam(0.0, sigma) * theta,
        # second branch: presumably normal_reparam(mu, sd) draws N(mu, sd) — confirm
        lambda theta: normal_reparam(theta / 2, sigma),
        theta,
    )


# Forward-mode estimator: seeding the tangent with 1.0 and reading `.tangent`
# gives the derivative estimate used as the ascent direction.
adev_grad = jax.jit(flip_approx_loss.jvp_estimate)

arg = 0.2
adev_vals, adev_grads = [], []
for _ in range(EPOCHS):
    key, subkey = jax.random.split(key)  # fresh randomness every step
    grad_val = adev_grad(subkey, Dual(arg, 1.0)).tangent
    adev_grads.append(grad_val)
    arg += 0.01 * grad_val
    adev_vals.append(expected_val(arg))

In [None]:
# Exact objective along both trajectories: JAX (`vals`) vs ADEV (`adev_vals`).
(
    Plot.line(
        dict(x=list(range(EPOCHS)), y=vals),
        stroke=Plot.constantly("Gradient ascent with JAX"),
    )
    + Plot.line(
        dict(x=list(range(EPOCHS)), y=adev_vals),
        stroke=Plot.constantly("Gradient ascent with ADEV"),
    )
    + {"x": {"label": "Iteration"}, "y": {"label": "y"}}
    + Plot.domainX([0, EPOCHS])
    + Plot.title("Maximization of the expected value of a probabilistic function")
    + Plot.color_legend()
)

In [None]:
# Per-iteration gradient estimates from JAX (`grads`) and ADEV (`adev_grads`).
(
    Plot.line(
        {"x": list(range(EPOCHS)), "y": grads},
        stroke=Plot.constantly("Gradients from JAX"),
    )
    + Plot.line(
        {"x": list(range(EPOCHS)), "y": adev_grads},
        stroke=Plot.constantly("Gradients from ADEV"),
    )
    # The y values are gradient estimates, not model outputs, so label them as such.
    + {"x": {"label": "Iteration"}, "y": {"label": "Gradient estimate"}}
    + Plot.domainX([0, EPOCHS])
    + Plot.title("Comparison of computed gradients JAX vs ADEV")
    + Plot.color_legend()
)