In [1]:
import genjax
from genjax import dippl
from genjax import gensp
import jax
import jax.numpy as jnp
import adevjax

console = genjax.pretty(show_locals=True)
key = jax.random.PRNGKey(314159)


# Define our model, and lift to a GenSP choice map distribution.
@genjax.gen
def model():
    x = dippl.normal_reparam(0.0, 10.0) @ "x"
    y = dippl.normal_reparam(0.0, 10.0) @ "y"
    z = dippl.normal_reparam(x**2 + y**2, 0.1) @ "z"


lifted_model = gensp.choice_map_distribution(
    model,
)

# Now, we define our variational proposal.
@genjax.gen
def parametrized_proposal(μ1, σ1, μ2, σ2):
    x = dippl.normal_reparam(μ1, jnp.exp(σ1)) @ "x"
    y = dippl.normal_reparam(μ2, jnp.exp(σ2)) @ "y"


# We embed our proposal into _differentiable_ sampling importance resampling.
variational_sir = dippl.importance_enum(
    100,
    gensp.choice_map_distribution(parametrized_proposal, genjax.select("x", "y")),
)

# Now, we define a grad estimator using the loss language
# (this is lightweight syntax over GenSP's ADEV compatible interfaces
# and ADEV primitives, including `add_cost`)
def elbo_grad(key, fixed_data, params):

    # Loss is defined imperatively using two interfaces:
    # * `upper` invokes `random_weighted` and returns the sample `v`,
    #      and accumulates (-w) via ADEV's `add_cost` primitive.
    #
    # * `lower` invokes `estimate_logpdf` and accumulates (w)
    #      via ADEV's `add_cost` primitive.
    @dippl.loss
    def variational_loss(μ1, σ1, μ2, σ2):
        tgt = gensp.target(model, (), fixed_data)
        v = dippl.do_upper(variational_sir)(tgt, (μ1, σ1, μ2, σ2))
        merged = gensp.merge(v, genjax.ValueChoiceMap(fixed_data))
        dippl.do_lower(lifted_model)(merged)

    # We can automatically derive `jvp` and `grad` estimators using ADEV.
    # (p_grad,) = flip_variational_loss.grad_estimate(key, (p,))
    (μ1, σ1, μ2, σ2) = params
    params_grad = variational_loss.grad_estimate(key, (μ1, σ1, μ2, σ2))
    return params_grad


# Trials:
jitted = jax.jit(elbo_grad)
fixed_data = genjax.choice_map({"z": 5.0})
μ1, σ1, μ2, σ2 = (2.0, 0.01, 2.0, 0.01)


def train(key, fixed_data, params, num_loops=100):
    def _inner(carry, xs):
        (μ1, σ1, μ2, σ2) = carry
        key = xs
        key, sub_key = jax.random.split(key)
        (μ1_grad, σ1_grad, μ2_grad, σ2_grad) = elbo_grad(
            key, fixed_data, (μ1, σ1, μ2, σ2)
        )

        # Nothing too fancy when it comes to optimizer -- just fixed learning rates.
        μ1 = μ1 + 1e-6 * μ1_grad
        σ1 = σ1 + 1e-6 * σ1_grad
        μ2 = μ2 + 1e-6 * μ2_grad
        σ2 = σ2 + 1e-6 * σ2_grad
        return (μ1, σ1, μ2, σ2), (μ1, σ1, μ2, σ2)

    keys = jax.random.split(key, num_loops)
    _, (μ1, σ1, μ2, σ2) = jax.lax.scan(_inner, params, keys)
    return (μ1, σ1, μ2, σ2)


key, sub_key = jax.random.split(key)
jitted = jax.jit(train)
params = jitted(key, fixed_data, (μ1, σ1, μ2, σ2))
params


[1m([0m
    [1;35mArray[0m[1m([0m[1m[[0m[1;36m1.7186848[0m, [1;36m1.6114966[0m, [1;36m1.5304842[0m, [1;36m1.4996581[0m, [1;36m1.4640118[0m, [1;36m1.4495015[0m,
       [1;36m1.4263766[0m, [1;36m1.4083099[0m, [1;36m1.4019818[0m, [1;36m1.3983003[0m, [1;36m1.3956025[0m, [1;36m1.4003799[0m,
       [1;36m1.4070473[0m, [1;36m1.4110829[0m, [1;36m1.4112598[0m, [1;36m1.4093934[0m, [1;36m1.4073944[0m, [1;36m1.4187672[0m,
       [1;36m1.4126804[0m, [1;36m1.4144053[0m, [1;36m1.4214003[0m, [1;36m1.4298216[0m, [1;36m1.4375436[0m, [1;36m1.4436388[0m,
       [1;36m1.4502875[0m, [1;36m1.4568324[0m, [1;36m1.467121[0m , [1;36m1.4691129[0m, [1;36m1.4722154[0m, [1;36m1.4694514[0m,
       [1;36m1.4781789[0m, [1;36m1.4855672[0m, [1;36m1.4879848[0m, [1;36m1.4917557[0m, [1;36m1.4945[0m   , [1;36m1.4978713[0m,
       [1;36m1.4926543[0m, [1;36m1.5013129[0m, [1;36m1.4995786[0m, [1;36m1.5050484[0m, [1;36m1.5100108[0m, [1;36m

In [2]:
%%timeit
params = jitted(key, fixed_data, (μ1, σ1, μ2, σ2))

9.49 ms ± 1.06 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)
