I want to do my first inference task. How do I do it?

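We will do it with importance sampling, which works as follows.
We need a distribution `p` of interest, typically the posterior of a model that has received observations, and a distribution `q`, called the proposal, that is easy to sample from. We draw a sample `x` from `q` and give it the importance weight `p(x) / q(x)` (in log space, `log p(x) - log q(x)`), which corrects for having sampled from `q` instead of `p`. In the code below, `hard` plays the role of `p` and `easy` the role of `q`.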
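Written out for $N$ independent draws, with $w_i$ the importance weight of the $i$-th sample, the weights give a self-normalized estimate of any expectation under $p$:

$$
x_i \sim q, \qquad \log w_i = \log p(x_i) - \log q(x_i), \qquad
\mathbb{E}_{p}[f(x)] \approx \frac{\sum_{i=1}^{N} w_i \, f(x_i)}{\sum_{j=1}^{N} w_j}.
$$

When $p$ is only known up to a normalizing constant $Z$, as for a posterior, that constant cancels in the ratio, and the plain average $\frac{1}{N}\sum_i w_i$ of the unnormalized weights is an estimate of $Z$.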

In [8]:
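import genjax
import jax
import jax.numpy as jnp

# A Python version of the algorithm, to get the idea.
# `hard` is the target p we care about; `easy` is the proposal q we sample from.
def importance_sample(hard, easy):
    def _inner(key, hard_args, easy_args):
        # Draw x ~ q
        sample = easy.simulate(key, *easy_args)
        # Score x under the proposal and under the target
        easy_logpdf = easy.logpdf(sample, *easy_args)
        hard_logpdf = hard.logpdf(sample, *hard_args)
        # log w = log p(x) - log q(x)
        importance_weight = hard_logpdf - easy_logpdf
        return (importance_weight, sample)
    return _inner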
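To see the algorithm without any GenJAX machinery, here is a standalone sketch in plain Python (standard library only). It is not the GenJAX code path: the helpers `normal_logpdf`, `mixture_logpdf` and `importance_sample_plain` are hypothetical names, and it assumes the target is a 0.5/0.5 mixture whose components are (mean, standard deviation) pairs, mirroring the `mix_args` used later, with a standard normal proposal, mirroring `d_args`.

```python
import math
import random

def normal_logpdf(x, mu, sigma):
    # log N(x; mu, sigma^2)
    return -0.5 * ((x - mu) / sigma) ** 2 - math.log(sigma) - 0.5 * math.log(2 * math.pi)

def mixture_logpdf(x, weights, components):
    # log sum_k w_k N(x; mu_k, sigma_k), computed stably with log-sum-exp
    terms = [math.log(w) + normal_logpdf(x, mu, s) for w, (mu, s) in zip(weights, components)]
    m = max(terms)
    return m + math.log(sum(math.exp(t - m) for t in terms))

def importance_sample_plain(target_logpdf, proposal_sample, proposal_logpdf):
    # Returns a function that draws one weighted sample, like `importance_sample`
    def _inner(rng):
        x = proposal_sample(rng)                       # x ~ q
        log_w = target_logpdf(x) - proposal_logpdf(x)  # log p(x) - log q(x)
        return log_w, x
    return _inner

weights = [0.5, 0.5]                     # assumed mixture weights
components = [(-3.0, 0.8), (1.0, 0.3)]   # assumed (mean, std) per component
sampler = importance_sample_plain(
    lambda x: mixture_logpdf(x, weights, components),  # target p
    lambda rng: rng.gauss(0.0, 1.0),                   # draw from q = N(0, 1)
    lambda x: normal_logpdf(x, 0.0, 1.0),              # log q(x)
)
rng = random.Random(0)
log_w, x = sampler(rng)
print(log_w, x)
```

Passing the densities as plain functions keeps the weighting step identical whatever the target is; only `target_logpdf` changes.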

We can test it on a simple example: a mixture of two normals as the target and a single normal as the proposal.

In [11]:
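complex_distribution = genjax.mixture_combinator(genjax.categorical, genjax.vmap_combinator(genjax.normal, in_axes=(0, None)))  # target p
simple_distribution = genjax.normal  # proposal q
hard, easy = complex_distribution, simple_distribution  # names used by the cells below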

In [None]:
key = jax.random.PRNGKey(0)
key, sub_key = jax.random.split(key)
# Arguments for the target (`mix_args`) and for the proposal (`d_args`)
mix_args = ([0.5, 0.5], [(-3.0, 0.8), (1.0, 0.3)])
d_args = ((0.0,), (1.0,))
# Draw one weighted sample (un-jitted; the parallel cell below is jitted)
importance_weight, sample = importance_sample(hard, easy)(sub_key, mix_args, d_args)
print((importance_weight, sample))

We can also draw many weighted samples in parallel by vectorizing over PRNG keys with `jax.vmap`.

In [None]:
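# Vectorize over the key (axis 0), broadcast both argument tuples (None), then compile
jitted = jax.jit(jax.vmap(importance_sample(hard, easy), in_axes=(0, None, None)))
# Split into 101 keys: one to keep for later, 100 fresh ones (one per sample)
keys = jax.random.split(key, 100 + 1)
key, sub_keys = keys[0], keys[1:]
(importance_weight, sample) = jitted(sub_keys, mix_args, d_args)
print((importance_weight, sample))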
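The vmapped call returns a batch of log weights. A common next step, sketched here in plain Python with hypothetical helper names, is to normalize them with the log-sum-exp trick and form the self-normalized estimate; averaging the unnormalized weights also estimates the target's normalizing constant. With a JAX array you could pass `importance_weight.tolist()`.

```python
import math

def log_sum_exp(log_ws):
    # Stable log(sum(exp(log_ws))): subtract the max before exponentiating
    m = max(log_ws)
    return m + math.log(sum(math.exp(lw - m) for lw in log_ws))

def normalized_weights(log_ws):
    # w_i / sum_j w_j, computed in log space to avoid underflow
    lse = log_sum_exp(log_ws)
    return [math.exp(lw - lse) for lw in log_ws]

def self_normalized_estimate(log_ws, values):
    # Estimate of E_p[f] from the values f(x_i) and their log importance weights
    return sum(w * v for w, v in zip(normalized_weights(log_ws), values))

def log_marginal_estimate(log_ws):
    # log((1/N) * sum_i w_i): estimates log Z when p is unnormalized
    return log_sum_exp(log_ws) - math.log(len(log_ws))

log_ws = [0.0, math.log(3.0)]  # two samples with weights 1 and 3
values = [0.0, 4.0]            # f evaluated at each sample
print(normalized_weights(log_ws))
print(self_normalized_estimate(log_ws, values))
print(log_marginal_estimate(log_ws))
```

Here weights 1 and 3 normalize to 0.25 and 0.75, so the estimate is 0.75 * 4 = 3, and the average weight is 2. Working in log space matters because log weights of real models are often in the hundreds or thousands of negative units, where `exp` alone underflows to zero.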

We can also use GenJAX's built-in importance sampling.

In [12]:
# First, importance sampling with the model's default proposal
model_trace_2, weight = complex_distribution.importance(key, genjax.ChoiceMap({"p": 3},{}), mix_args)

IndexError: tuple index out of range
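In Gen-style systems, importance sampling with the default proposal typically samples every unconstrained choice from the model's own prior and weights the result by the likelihood of the constrained choices, a scheme known as likelihood weighting. For the mixture, that means sampling the component index from its prior and weighting by the density of the observed value under that component. The plain-Python sketch below illustrates that idea only; it is not GenJAX's implementation, and the observed value `0.0`, the (mean, standard deviation) reading of the component parameters, and the function names are assumptions for illustration.

```python
import math
import random

def normal_logpdf(x, mu, sigma):
    # log N(x; mu, sigma^2)
    return -0.5 * ((x - mu) / sigma) ** 2 - math.log(sigma) - 0.5 * math.log(2 * math.pi)

def likelihood_weighting(rng, x_obs, mix_weights, components, n):
    # n traces: component z drawn from its prior, log weight = log-likelihood of x_obs
    traces = []
    for _ in range(n):
        z = rng.choices(range(len(mix_weights)), weights=mix_weights)[0]
        mu, sigma = components[z]
        traces.append((z, normal_logpdf(x_obs, mu, sigma)))
    return traces

def posterior_component_prob(traces, k):
    # Self-normalized estimate of P(z = k | x_obs)
    m = max(lw for _, lw in traces)
    ws = [(z, math.exp(lw - m)) for z, lw in traces]
    total = sum(w for _, w in ws)
    return sum(w for z, w in ws if z == k) / total

mix_weights = [0.5, 0.5]                # assumed mixture weights
components = [(-3.0, 0.8), (1.0, 0.3)]  # assumed (mean, std) per component
x_obs = 0.0                             # hypothetical observation
traces = likelihood_weighting(random.Random(0), x_obs, mix_weights, components, 20_000)
estimate = posterior_component_prob(traces, 1)

# Exact answer for comparison: P(z = k | x) is proportional to w_k * N(x; mu_k, sigma_k)
joint = [w * math.exp(normal_logpdf(x_obs, mu, s)) for w, (mu, s) in zip(mix_weights, components)]
exact = joint[1] / sum(joint)
print(estimate, exact)
```

Because the proposal here is the prior, the prior terms cancel in `p / q` and only the likelihood of the observation is left in each weight.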

In [None]:
#TODO: a version for GenSP here could be good.
