In [None]:
import jax
import jax.numpy as jnp

import genjax
from genjax import ChoiceMapBuilder as C
from genjax import gen

# Fixed seed so that every run of the notebook draws the same random numbers.
key = jax.random.PRNGKey(seed=0)

When running iterative inference algorithms such as MCMC, we often need to make small adjustments to an existing choice map.

For instance, consider the following generative model.


In [None]:
@gen
def generate_datum(x, prob_outlier, noise, slope, intercept):
    """Sample one observation `y` at input `x` from a line-plus-outliers model.

    With probability `prob_outlier` the datum is flagged as an outlier (address
    "is_outlier") and `y` is drawn from a normal with mean 0.0 and std 10.0;
    otherwise `y` is drawn from a normal centred on the line
    `slope * x + intercept` with std `noise`. `y` is recorded at address "y".
    """

    def _outlier_params(_):
        # Outliers ignore the line entirely.
        return 0.0, 10.0

    def _inlier_params(_):
        return slope * x + intercept, noise

    is_outlier = genjax.flip(prob_outlier) @ "is_outlier"
    mean, scale = jax.lax.cond(is_outlier, _outlier_params, _inlier_params, None)
    return genjax.normal(mean, scale) @ "y"


@gen
def model(xs):
    """Linear regression with a per-datum outlier indicator.

    Samples a line (`slope`, `intercept`), an inlier noise scale (`noise`), and
    an outlier rate (`prob_outlier`). It then draws one datum per element along
    axis 0 of `xs` by vmapping `generate_datum`, and records those draws under
    the "data" address.

    Returns:
        The batch of sampled `y` values, one per element along axis 0 of `xs`.
    """
    slope = genjax.normal(0.0, 2.0) @ "slope"
    intercept = genjax.normal(0.0, 2.0) @ "intercept"
    # A Gamma distribution requires a strictly positive shape parameter; a shape
    # of 0.0 gives a degenerate prior on `noise` (and thus an ill-defined inlier
    # likelihood). Gamma(1, 1) is a proper prior on the positive noise scale.
    noise = genjax.gamma(1.0, 1.0) @ "noise"
    prob_outlier = genjax.beta(1.0, 1.0) @ "prob_outlier"
    # Map over the data inputs only; the global parameters are shared (None).
    ys = (
        generate_datum.vmap(in_axes=(0, None, None, None, None))(
            xs, prob_outlier, noise, slope, intercept
        )
        @ "data"
    )
    return ys


def make_observations(ys):
    """Build a choice map constraining address ("data", i, "y") to `ys[i]`.

    Args:
        ys: observed values, one per datum along the leading axis.

    Returns:
        The choice map produced by vmapping the per-index constraint over
        0..len(ys)-1.
    """
    # vmap over the *indices* (paired with their values), not over the values
    # themselves: mapping over `ys` directly would use each observed value as
    # an index into `ys`.
    indices = jnp.arange(len(ys))
    return jax.vmap(lambda idx, y: C["data", idx, "y"].set(y))(indices, ys)


# TODO: writing this proposal feels clunkier in GenJAX than in Gen.jl; the goal
# is for it to be just as straightforward.
@gen
def is_outlier_proposal(target, i):
    """Proposal over an outlier rate and a single outlier indicator.

    Draws a fresh rate from Beta(1, 1) at address "prob_outlier", then a
    Bernoulli indicator at address "is_outlier" using that rate, and returns
    the indicator.

    NOTE(review): `target` and `i` are currently unused, and both addresses are
    top-level, whereas `model` records each datum's indicator inside its
    vmapped "data" call. Confirm how these addresses are meant to line up with
    the model's before using this as an MH proposal.
    """
    rate = genjax.beta(1.0, 1.0) @ "prob_outlier"
    flagged = genjax.flip(rate) @ "is_outlier"
    return flagged

Let's write an inference program using Metropolis–Hastings (MH).

In [None]:
N_iters = 1_000  # number of refinement steps performed by `inference_program`


def inference_program(key, xs, ys):
    """Run a (work-in-progress) iterative inference loop on `model`.

    Builds observation constraints from `ys` and draws an initial trace of
    `model(xs)` by importance sampling. It then performs `N_iters` refinement
    steps, threading a fresh PRNG subkey into every call.

    Args:
        key: JAX PRNG key; split internally, never reused.
        xs: model inputs, passed as the single argument of `model`.
        ys: observed values, converted to constraints by `make_observations`.

    Returns:
        The trace after the final refinement step.
    """
    constraints = make_observations(ys)

    # GenJAX's `importance` takes (key, constraint, args). Split the key so the
    # initial draw and each loop step consume independent randomness instead of
    # replaying the same random numbers every time.
    key, subkey = jax.random.split(key)
    trace, w = model.importance(subkey, constraints, (xs,))

    # TODO: replace with a real MH step (propose with `is_outlier_proposal`,
    # then accept/reject). NOTE(review): GenJAX's `importance` does not appear
    # to accept `proposal=`, `trace=` or `weights=` keyword arguments; this
    # call is a placeholder for the intended update - confirm before running.
    for _ in range(N_iters):
        key, subkey = jax.random.split(key)
        trace, w = model.importance(
            subkey,
            constraints,
            (xs,),
            proposal=is_outlier_proposal,
            trace=trace,
            weights=w,
        )

    return trace

Next, let's implement the same procedure with GenJAX's built-in `update` and see whether it gives a speedup.