I'm doing importance sampling as advised but it's bad, what can I do? 

One thing one can do is write a custom proposal for importance sampling.
The idea is to sample from this one instead of the default one used by genjax when using `model.importance`.
The default one is only informed by the structure of the model, and not by the posterior defined by both the model and the observations.

In [4]:
# Let's first define a simple model with a broad normal prior and some observations
from genjax import gen, normal
from genjax import ChoiceMapBuilder as C
from jax import random, jit, vmap
import jax.numpy as jnp
import jax

@gen
def model():
    # Initially, the prior is a pretty broad normal distribution centred at 0
    x = normal(0.0, 100.0) @ "x"
    # We add some observations, which will shift the posterior towards these values
    _ = normal(x, 1.0) @ "obs1"
    _ = normal(x, 1.0) @ "obs2"
    _ = normal(x, 1.0) @ "obs3"
    return x

# We create some data, 3 observed values at 234
obs = C["obs1"].set(234.0) ^ C["obs2"].set(234.0) ^ C["obs3"].set(234.0)

# We then run importance sampling with a default proposal
# And print the average weight of the samples, to give us a sense of how well the proposal is doing
key = random.PRNGKey(0)
key, *sub_keys = random.split(key, 1000 + 1)
sub_keys = jnp.array(sub_keys)
args = ()
jitted = jit(vmap(model.importance, in_axes=(0,None,None)))
trace, weight = jitted(sub_keys, obs, args)

# We can see that both the average and even maximum weight are quite low, which means that the proposal is not doing a great job.
# Ideally the weight should center around 1 and be quite concentrated around that value.
print("The average weight is", jax.scipy.special.logsumexp(weight))
print("The maximum weight is", weight.max())

# We now define a custom proposal, which will be a normal distribution centred around the observed values

@gen
def proposal(obs):
    avg_val = jnp.array(obs).mean()
    std = jnp.array(obs).std()
    x = normal(avg_val, 0.1 + std) @ "x" # To avoid a degenerate proposal, we add a small value to the standard deviation
    # Note that this is not very elegant as we'd like to only propose the latent variable `x`, but we need to add the observations to get a full trace
    _ = normal(x, 1.0) @ "obs1"
    _ = normal(x, 1.0) @ "obs2"
    _ = normal(x, 1.0) @ "obs3"
    return x

# To do things by hand first, let's reimplement the importance function
# It samples from the proposal and then computes the importance weight
def importance_sample(hard, easy):
    def _inner(key, hard_args, easy_args):
        trace = easy.simulate(key, *easy_args)   
        chm = trace.get_sample()
        easy_logpdf = trace.get_score()  
        hard_logpdf, _ = hard.assess(chm, *hard_args) 
        importance_weight = hard_logpdf - easy_logpdf
        return (trace, importance_weight) 
    return _inner

# We then run importance sampling with the custom proposal
key = random.PRNGKey(0)
key, *sub_keys = random.split(key, 1000 + 1)
sub_keys = jnp.array(sub_keys)
args_for_model = ()
args_for_proposal = (jnp.array([obs["obs1"], obs["obs2"], obs["obs3"]]),)
jitted = jit(vmap(importance_sample(model, proposal), in_axes=(0, None, None)))
trace, new_weight = jitted(sub_keys, (args_for_model,), (args_for_proposal,))
# We see that the new values, both average and maximum, are much higher than before, which means that the custom proposal is doing a much better job
print("The new average weight is", jax.scipy.special.logsumexp(new_weight))
print("The new maximum weight is", new_weight.max())


The average weight is -9.114648
The maximum weight is -9.121351
The new average weight is -0.808337
The new maximum weight is -1.6441832


In [16]:

# An equivalent way to do this is to use the importance function from the genjax library

from genjax import Target, smc
from jax import random, vmap

k_particles = 1000
args_for_model = (0.0, 1.0)
args_for_proposal = (jnp.array([obs["obs1"], obs["obs2"], obs["obs3"]]),)
key = random.PRNGKey(0)
target_posterior = Target(model, (args_for_model,), obs)
# TODO: find a way to see the proposal as a SampleDistribution.
# TODO: find a more elegant way to define the proposal only on the latent variable x
proposal = Target(proposal, args_for_proposal, ())
alg = smc.ImportanceK(target_posterior, proposal, k_particles=k_particles)
