I want to do my first inference task. How do I do it?

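We will do it with importance sampling, which works as follows. We choose a distribution `q`, called the proposal, that we can easily sample from, and a distribution `p` of interest (the target), typically the posterior of a model given some observations. We draw samples `x` from `q` and attach to each one an importance weight `p(x)/q(x)`, which corrects for the fact that the samples came from `q` rather than from `p`.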
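In formulas, this is the standard self-normalized importance sampling estimator, which only needs `p` up to a normalizing constant (as is the case for a posterior), since any constant cancels in the ratio of weights:

$$
x_i \sim q, \qquad \log w_i = \log p(x_i) - \log q(x_i), \qquad i = 1, \dots, N
$$

$$
\mathbb{E}_{p}[f(x)] \approx \sum_{i=1}^{N} \frac{w_i}{\sum_{j=1}^{N} w_j} \, f(x_i)
$$

Working with $\log w_i$ instead of $w_i$ avoids numerical underflow when densities are very small.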

In [None]:
import genjax
import jax
from jax import jit
import jax.numpy as jnp


# A simple Python version of the algorithm, to get the idea
def importance_sample(hard, easy):
    def _inner(key, hard_args, easy_args):
        # Sample x from the easy distribution, the proposal `q`
        trace = easy.simulate(key, *easy_args)
        chm = trace.get_sample()
        # Log density of the proposal at x: log q(x)
        easy_logpdf = trace.get_score()
        # Log density of the hard distribution (the target) at the same x: log p(x)
        hard_logpdf, _ = hard.assess(chm, *hard_args)
        # Log importance weight log(p(x) / q(x)); it corrects for the fact that
        # x was drawn from the easy distribution instead of the hard one
        importance_weight = hard_logpdf - easy_logpdf
        # Return the trace and its log importance weight
        return (trace, importance_weight)

    return _inner

We can test it on a very simple example, in which both the target and the proposal are normal distributions with different parameters.

In [None]:
complex_distribution = genjax.normal
simple_distribution = genjax.normal

complex_args = (0.0, 1.0)
simple_args = (3.0, 4.0)
key = jax.random.PRNGKey(0)
sample, importance_weight = jit(
    importance_sample(complex_distribution, simple_distribution)
)(key, (complex_args,), (simple_args,))
print(importance_weight, sample.get_sample())
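To see exactly what the weight is, the same single-sample computation can be written without GenJAX. This is a minimal pure-JAX sketch, assuming `genjax.normal(mu, sigma)` is parameterized by a mean and a standard deviation; `importance_sample_gaussian` is a hypothetical name, not a GenJAX function:

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

# Assumption: normal(mu, sigma) takes a mean and a standard deviation.
P_LOC, P_SCALE = 0.0, 1.0  # target p = N(0, 1)
Q_LOC, Q_SCALE = 3.0, 4.0  # proposal q = N(3, 4)


def importance_sample_gaussian(key):
    # Draw x ~ q by shifting and scaling a standard normal draw.
    x = Q_LOC + Q_SCALE * jax.random.normal(key)
    # Log importance weight: log p(x) - log q(x).
    log_w = norm.logpdf(x, P_LOC, P_SCALE) - norm.logpdf(x, Q_LOC, Q_SCALE)
    return x, log_w


x, log_w = jax.jit(importance_sample_gaussian)(jax.random.PRNGKey(0))
print(x, log_w)
```

The weight is positive in log space where the target puts more mass than the proposal (near 0) and negative where the proposal dominates (far from 0).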

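In GenJAX, every generative function comes equipped with a default proposal, which we can use for importance sampling: random choices that are not observed are sampled from the model itself, while observed ones are fixed to their observed values.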

In [None]:
from genjax import beta, bernoulli, gen
from genjax import ChoiceMapBuilder as C


@gen
def beta_bernoulli_process(u):
    # Beta concentration parameters must be strictly positive
    p = beta(1.0, u) @ "p"
    v = bernoulli(p) @ "v"
    return v


obs = C["v"].set(1)
args = (0.5,)
# Together with the observation `obs`, which fixes the value at address "v",
# the generative function defines a (potentially complex) posterior over "p".
# `.importance` uses the model's default proposal, targeting that posterior,
# and runs importance sampling once.
trace, weight = beta_bernoulli_process.importance(key, obs, args)

# This returns the new trace and its log importance weight
# (log density of the trace under the model minus log density under the proposal)
print(trace.get_sample())
print(weight)
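For this model, the default proposal samples the unobserved `p` from its prior, so the log weight reduces to the log-likelihood of the observation, log P(v = 1 | p) = log p. A minimal pure-JAX sketch of that computation, assuming a Beta(1.0, 0.5) prior and that `p` is the success probability of `v`; `default_importance` is a hypothetical helper, not part of GenJAX:

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import bernoulli

# Hypothetical prior hyperparameters: p ~ Beta(ALPHA, U).
ALPHA, U = 1.0, 0.5


def default_importance(key, v_obs):
    # Unobserved address "p": sample it from its prior.
    p = jax.random.beta(key, ALPHA, U)
    # Observed address "v": do not sample it, score the observed value instead.
    log_w = bernoulli.logpmf(v_obs, p)
    return p, log_w


p, log_w = default_importance(jax.random.PRNGKey(0), 1)
print(p, log_w)  # for v = 1 the log weight equals log(p)
```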

We can also run it in parallel, drawing many independent importance samples at once by vectorizing over PRNG keys with `jax.vmap`!

In [None]:
import jax.numpy as jnp

jitted = jax.jit(
    jax.vmap(
        importance_sample(complex_distribution, simple_distribution),
        in_axes=(0, None, None),
    )
)
key, *sub_keys = jax.random.split(key, 100 + 1)
sub_keys = jnp.array(sub_keys)
(sample, importance_weight) = jitted(sub_keys, (complex_args,), (simple_args,))
print(sample.get_choices(), importance_weight)
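Once we have many weighted samples, the weights can be normalized and used to estimate expectations under the target. A minimal pure-JAX sketch for the same two Gaussians, assuming `normal(mu, sigma)` uses a standard deviation; `log_weight_sample` is a hypothetical name. It also computes the effective sample size, a common diagnostic of how well the proposal matches the target:

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp
from jax.scipy.stats import norm


def log_weight_sample(key):
    # One importance sample: x ~ q = N(3, 4); log w = log p(x) - log q(x), p = N(0, 1).
    x = 3.0 + 4.0 * jax.random.normal(key)
    return x, norm.logpdf(x, 0.0, 1.0) - norm.logpdf(x, 3.0, 4.0)


# One key per sample; vmap maps the single-sample function over the keys.
keys = jax.random.split(jax.random.PRNGKey(0), 100_000)
xs, log_ws = jax.jit(jax.vmap(log_weight_sample))(keys)

# Normalize the weights in log space for numerical stability.
norm_ws = jnp.exp(log_ws - logsumexp(log_ws))
estimate = jnp.sum(norm_ws * xs)  # self-normalized estimate of E_p[x], which is 0
ess = 1.0 / jnp.sum(norm_ws**2)  # effective sample size
print(jnp.mean(xs), estimate, ess)
```

The unweighted mean of `xs` is close to 3 (the proposal mean), while the weighted estimate is close to 0 (the target mean): the weights are what correct the bias.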

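We can convert `N` weighted samples from importance sampling into `K` unweighted samples that approximately follow the posterior, by drawing `K` times from the `N` samples with probabilities proportional to their importance weights.
This is called sampling importance resampling, here with `K` draws (K-SIR).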
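The resampling step itself does not depend on GenJAX. A minimal sketch in plain JAX; `resample` is a hypothetical helper, and the particles are assumed to be stored as a pytree (here a dict) whose leaves have the particle index as their leading axis:

```python
import jax
import jax.numpy as jnp


def resample(key, particles, log_weights, k):
    # Draw k indices with probability proportional to exp(log_weights);
    # jax.random.categorical takes unnormalized log-probabilities (logits).
    idx = jax.random.categorical(key, log_weights, shape=(k,))
    # Gather the selected particles from every leaf of the pytree.
    return jax.tree_util.tree_map(lambda leaf: leaf[idx], particles)


particles = {"p": jnp.array([0.1, 0.5, 0.9])}
# Only the last particle has nonzero weight.
log_weights = jnp.log(jnp.array([0.0, 0.0, 1.0]))
out = resample(jax.random.PRNGKey(0), particles, log_weights, 5)
print(out)
```

A particle with log weight `-inf` is never selected, so resampling discards samples that the target considers impossible, while particles with large weights can be selected several times.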

In [None]:
N = 1000
K = 100


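def sir(N, K, dist, chm):
    def _inner(key, args):
        key, subkey = jax.random.split(key, 2)
        # N weighted samples from the default proposal, one key per sample
        samples, weights = jax.vmap(dist.importance, in_axes=(0, None, None))(
            jax.random.split(key, N), chm, args
        )

        # K indices drawn with probability proportional to exp(weights);
        # the log weights are passed to `categorical` as unnormalized logits
        idx = jax.vmap(jax.jit(genjax.categorical.simulate), in_axes=(0, None))(
            jax.random.split(subkey, K), (weights,)
        ).get_retval()

        # Select the resampled particles in every leaf of the batched choice map
        choicemap = samples.get_choices()
        final_samples = jax.tree.map(lambda x: x[idx], choicemap)
        return final_samples

    return _inner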


# Testing
key = jax.random.PRNGKey(0)
chm = C["v"].set(1)
args = (0.5,)
samples = jit(sir(N, K, beta_bernoulli_process, chm))(key, args)
print(samples)
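A useful sanity check for SIR is a model with a closed-form posterior. With a Beta(a, b) prior on `p` and a single observation `v = 1`, the posterior is Beta(a + 1, b), with mean (a + 1) / (a + 1 + b). A minimal pure-JAX sketch, assuming a Beta(1.0, 0.5) prior (exact posterior mean 0.8) and that `p` is the success probability of `v`; `sir_beta_bernoulli` is a hypothetical function:

```python
import jax
import jax.numpy as jnp

# Hypothetical prior p ~ Beta(ALPHA, U); with one observation v = 1 the exact
# posterior is Beta(ALPHA + 1, U), whose mean is (ALPHA + 1) / (ALPHA + 1 + U).
ALPHA, U = 1.0, 0.5


def sir_beta_bernoulli(key, n, k):
    key_prop, key_resample = jax.random.split(key)
    # Step 1: n weighted particles, sampled from the prior (the default proposal).
    ps = jax.random.beta(key_prop, ALPHA, U, shape=(n,))
    log_ws = jnp.log(ps)  # log P(v = 1 | p) = log p
    # Step 2: k indices drawn with probability proportional to the weights.
    idx = jax.random.categorical(key_resample, log_ws, shape=(k,))
    return ps[idx]


samples = jax.jit(sir_beta_bernoulli, static_argnums=(1, 2))(
    jax.random.PRNGKey(0), 1000, 100
)
print(samples.mean(), (ALPHA + 1) / (ALPHA + 1 + U))
```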

In [None]:
# Another way to do basically the same thing, using library functions
from genjax import Target, smc
from jax import random, vmap

N = 1000
K = 100
key = jax.random.PRNGKey(0)
chm = C["v"].set(1)
args = (0.5,)
# The target distribution: here, the posterior of the model given the observations `chm`
target_posterior = Target(beta_bernoulli_process, args, chm)
# The inference strategy: here, SIR with N particles
alg = smc.ImportanceK(target_posterior, k_particles=N)
# To get K independent samples from the posterior, we run N-particle SIR K times.
# This differs slightly from the previous example: here each of the K final
# samples is obtained from its own, independent set of N particles.
sub_keys = random.split(key, K)
posterior_samples = jit(vmap(alg.simulate, in_axes=(0, None)))(
    sub_keys, (target_posterior,)
).get_retval()

print(posterior_samples)

# `random_weighted` returns a pair; we keep the second element, the sampled choice maps
_, p_chm = jax.vmap(alg.random_weighted, in_axes=(0, None))(sub_keys, target_posterior)

# An estimate of the posterior mean of `p` from K = 100 independent runs of SIR,
# each with N = 1000 particles
print(jnp.mean(p_chm["p"]))

alg.get_num_particles
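The difference between the two SIR variants can be made explicit in plain JAX. This sketch assumes a hypothetical Beta(1.0, 0.5) prior on `p`, treats `p` as the success probability of the observed `v = 1`, and uses a hypothetical `one_sir_sample` helper: each call runs importance sampling with its own N particles and returns a single resampled particle, and `vmap` over K keys gives K independent posterior samples.

```python
import jax
import jax.numpy as jnp

ALPHA, U = 1.0, 0.5  # hypothetical prior p ~ Beta(ALPHA, U); observation v = 1
N, K = 1000, 100


def one_sir_sample(key):
    key_prop, key_pick = jax.random.split(key)
    # A fresh set of N particles from the prior, used only by this run.
    ps = jax.random.beta(key_prop, ALPHA, U, shape=(N,))
    # Pick a single particle with probability proportional to P(v = 1 | p) = p.
    i = jax.random.categorical(key_pick, jnp.log(ps))
    return ps[i]


keys = jax.random.split(jax.random.PRNGKey(0), K)
posterior_samples = jax.jit(jax.vmap(one_sir_sample))(keys)
print(posterior_samples.shape, posterior_samples.mean())  # exact posterior mean is 0.8
```

Compared with drawing K times from one set of N particles, this uses K times as many proposal samples, but the K outputs are independent instead of sharing (and possibly duplicating) the same particles.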