In [None]:
import jax
import jax.numpy as jnp

import genjax
from genjax import ChoiceMapBuilder as C
from genjax import gen

# Root PRNG key for the notebook; later cells split subkeys off of it.
key = jax.random.PRNGKey(0)

When doing inference with iterative algorithms like MCMC, we often need to make small adjustments to the choice map.

For instance, consider the following generative model.


In [None]:
@gen
def generate_datum(x, prob_outlier, noise, slope, intercept):
    """Sample one observation `y` at input `x`, allowing for outliers.

    A Bernoulli choice at address "is_outlier" (probability `prob_outlier`)
    selects the parameters of the normal distribution sampled at address "y":
    mean 0.0 and std 10.0 when it is true, otherwise mean
    `slope * x + intercept` and std `noise`.

    Returns:
        The sampled value `y`.
    """

    def outlier_params():
        return 0.0, 10.0

    def inlier_params():
        return slope * x + intercept, noise

    is_outlier = genjax.flip(prob_outlier) @ "is_outlier"
    mu, std = jax.lax.cond(is_outlier, outlier_params, inlier_params)
    return genjax.normal(mu, std) @ "y"


# Vectorized version of generate_datum: maps over the first argument (x) and
# shares the remaining four arguments across all data points.
generate_data = generate_datum.vmap(in_axes=(0, None, None, None, None))


@gen
def model(xs):
    """Linear-regression-with-outliers model over the inputs `xs`.

    Samples the global parameters "slope", "intercept", "noise" and
    "prob_outlier", then samples one observation per element of `xs` under
    the address "data" using the vectorized `generate_data`.

    Args:
        xs: array of inputs; `generate_data` maps over its leading axis.

    Returns:
        The sampled observations, one per input.
    """
    slope = genjax.normal(0.0, 2.0) @ "slope"
    intercept = genjax.normal(0.0, 2.0) @ "intercept"
    # The gamma shape parameter must be strictly positive; the previous value
    # of 0.0 is outside the distribution's domain. 1.0 is a stand-in choice of
    # prior -- NOTE(review): adjust if a different prior on the noise is intended.
    noise = genjax.gamma(1.0, 1.0) @ "noise"
    prob_outlier = genjax.beta(1.0, 1.0) @ "prob_outlier"
    ys = generate_data(xs, prob_outlier, noise, slope, intercept) @ "data"
    return ys


def make_observations(ys):
    """Build a choice map constraining ("data", i, "y") to `ys[i]` for every i.

    Args:
        ys: array of observed values; one constraint is created per entry
            along its leading axis.

    Returns:
        The vmapped choice map of observations.
    """

    def constrain(idx):
        return C["data", idx, "y"].set(ys[idx])

    indices = jnp.arange(ys.shape[0])
    return jax.vmap(constrain)(indices)


@gen
def is_outlier_proposal(trace):
    """Propose an "is_outlier" value based on the current one in `trace`.

    Reads the current "is_outlier" choice and samples a new one with
    probability 0.0 of being true if the current value is true, and 1.0
    otherwise.

    Returns:
        The proposed "is_outlier" value.
    """
    # TODO: problem of accessing this as I get a vmapped trace as input which
    # I don't assume here, so I need the right index etc.
    current = trace.get_choices()["is_outlier"]
    flip_prob = jax.lax.cond(current, lambda: 0.0, lambda: 1.0)
    return genjax.flip(flip_prob) @ "is_outlier"

Let's write an inference program using Metropolis-Hastings (MH).

In [None]:
# Number of Metropolis-Hastings iterations run by inference_program_1.
N_iters = 1000


def metropolis_hastings_move(key, model, proposal, trace, observations):
    """Perform one Metropolis-Hastings step on `trace` with a custom proposal.

    The proposal is run on the current trace to obtain candidate choices, the
    model trace is updated with those choices, and the reverse move is scored
    by assessing the proposal (given the updated trace) on the choices that
    the update discarded.

    Args:
        key: JAX PRNG key. It is split internally so that the proposal, the
            model update and the accept/reject draw each get their own subkey.
        model: generative function that produced `trace`.
        proposal: generative function taking the current trace as its only
            argument.
        trace: current model trace.
        observations: not used by this implementation; kept so existing
            callers keep working.

    Returns:
        The updated trace if the move is accepted, otherwise `trace`.
    """
    # Never reuse a key that has already been consumed: derive three
    # independent subkeys up front.
    propose_key, update_key, accept_key = jax.random.split(key, 3)

    # The model arguments are unchanged by this move.
    model_args = trace.get_args()
    argdiffs = genjax.Diff.tree_diff_no_change(model_args)

    fwd_choices, fwd_weight, _ = proposal.propose(propose_key, (trace,))
    new_trace, weight, _, discard = model.update(
        update_key, trace, fwd_choices, argdiffs
    )
    bwd_weight, _ = proposal.assess(discard, (new_trace,))

    # Acceptance ratio on the log scale; compared against log(u), u ~ U(0, 1).
    log_accept_ratio = weight - fwd_weight + bwd_weight
    log_u = jnp.log(jax.random.uniform(accept_key))
    return jax.lax.cond(
        log_u < log_accept_ratio, lambda: new_trace, lambda: trace
    )


def inference_program_1(key, xs, ys):
    """Run `N_iters` Metropolis-Hastings moves on the outlier indicators.

    The initial trace is obtained from `model.importance` with the
    observations built from `ys`; each iteration then applies
    `metropolis_hastings_move` with `is_outlier_proposal` repeated once per
    data point.

    Args:
        key: JAX PRNG key; a fresh subkey is split off for the initial trace
            and for every MH iteration.
        xs: model inputs.
        ys: observed values, constrained at ("data", i, "y").

    Returns:
        The trace after the final MH iteration.
    """
    constraints = make_observations(ys)

    key, init_key = jax.random.split(key)
    # TODO: should I use generate here?
    trace, _ = model.importance(init_key, constraints, (xs,))

    # The proposal does not change across iterations, so build it once.
    proposal = is_outlier_proposal.repeat(n=xs.shape[0])
    for _ in range(N_iters):
        # Passing the same key to every move would make all iterations use
        # identical randomness; split a new subkey each time.
        key, move_key = jax.random.split(key)
        trace = metropolis_hastings_move(
            move_key, model, proposal, trace, constraints
        )

    return trace


# testing
# Simulate synthetic data with prob_outlier=0.1, noise=1.0, slope=1.0,
# intercept=1.0, then run the MH inference program on it.
key, subkey = jax.random.split(key)
xs = jnp.linspace(0, 10, 100)
# Use the freshly split subkey here: `key` is split again below, so it must
# not also be consumed by simulate.
ys = generate_data.simulate(subkey, (xs, 0.1, 1.0, 1.0, 1.0)).get_retval()
key, subkey = jax.random.split(key)
inference_program_1(subkey, xs, ys)

Let's now try to do the same inference using the built-in update and see if there's a speedup.