One thing we can do is write a custom proposal for importance sampling.
The idea is to sample from this proposal instead of the default one that GenJAX uses when calling `model.importance`.
The default proposal is informed only by the structure of the model, not by the posterior defined by both the model and the observations.

In [None]:
import jax.numpy as jnp
from genjax import ChoiceMapBuilder as C
from genjax import gen, normal
from jax import jit, random, vmap
from jax.scipy.special import logsumexp

Let's first define a simple model with a broad normal prior and some observations.

In [None]:
@gen
def model():
    # Initially, the prior is a pretty broad normal distribution centred at 0
    x = normal(0.0, 100.0) @ "x"
    # We add some observations, which will shift the posterior towards these values
    _ = normal(x, 1.0) @ "obs1"
    _ = normal(x, 1.0) @ "obs2"
    _ = normal(x, 1.0) @ "obs3"
    return x


# We create some data, 3 observed values at 234
obs = C["obs1"].set(234.0) ^ C["obs2"].set(234.0) ^ C["obs3"].set(234.0)

We then run importance sampling with the default proposal,
and print the average and maximum weights of the samples, to give us a sense of how well the proposal is doing.

In [None]:
key = random.PRNGKey(0)
key, *sub_keys = random.split(key, 1000 + 1)
sub_keys = jnp.array(sub_keys)
args = ()
jitted = jit(vmap(model.importance, in_axes=(0, None, None)))
trace, weight = jitted(sub_keys, obs, args)
# The weights are in log space, so we average them with logsumexp
print("The average weight is", logsumexp(weight) - jnp.log(len(weight)))
print("The maximum weight is", weight.max())

We can see that both the average and even the maximum weight are quite low, which means that the proposal is not doing a great job. Note that the printed values are log weights.
Ideally, the weights, taken relative to their mean, should centre around 1 (that is, around 0 in log space) and be tightly concentrated around that value.
A relative weight much higher than 1 means that the proposal is too narrow and is missing modes: for that to happen, one has to sample a value that is very unlikely under the proposal but very likely under the target.
A relative weight much lower than 1 means that the proposal is too broad and is wasting samples. This is what happens here: the default proposal uses the broad prior `normal(0.0, 100.0)`, which puts very little mass near the observed values centred around 234.0.
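One way to put a number on how concentrated the weights are is the effective sample size (ESS), a standard importance sampling diagnostic. The sketch below is not part of GenJAX: `effective_sample_size` is a helper name introduced here for illustration, and it assumes the weights are given in log space, as returned by `model.importance`.

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def effective_sample_size(log_weights):
    # Hypothetical helper (not a GenJAX function).
    # ESS = (sum w)^2 / sum(w^2), computed in log space for numerical stability.
    return jnp.exp(2.0 * logsumexp(log_weights) - logsumexp(2.0 * log_weights))


# Equal weights: every sample counts, so the ESS equals the number of samples
print(effective_sample_size(jnp.zeros(10)))
# One dominant weight: the ESS collapses to roughly a single sample
print(effective_sample_size(jnp.array([0.0, -50.0, -50.0])))
```

Applied to an array of 1000 log weights, an ESS close to 1 would indicate that a single lucky sample dominates the estimate, while an ESS close to 1000 would indicate a proposal that matches the target well.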

We now define a custom proposal, which will be a normal distribution centred around the observed values.


In [None]:
@gen
def proposal(obs):
    avg_val = jnp.array(obs).mean()
    std = jnp.array(obs).std()
    # To avoid a degenerate proposal, we add a small value to the standard deviation
    x = normal(avg_val, 0.1 + std) @ "x"
    # Note that this is not very elegant: we'd like to propose only the latent variable `x`,
    # but we need to add the observations to get a full trace
    _ = normal(x, 1.0) @ "obs1"
    _ = normal(x, 1.0) @ "obs2"
    _ = normal(x, 1.0) @ "obs3"
    return x
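
Because the prior and the likelihood are both normal, this particular model has a closed-form posterior for `x`, which gives a reference point for judging a proposal. The snippet below is a sketch using the standard conjugate-normal update; the variable names are introduced here for illustration, and it relies on every observation having the same noise standard deviation of 1.0.

```python
import math

prior_mean, prior_std = 0.0, 100.0
obs_std = 1.0
observations = [234.0, 234.0, 234.0]

# Precisions (inverse variances) add up under a normal prior and normal likelihood
prior_precision = 1.0 / prior_std**2
obs_precision = len(observations) / obs_std**2
posterior_precision = prior_precision + obs_precision

# The posterior mean is the precision-weighted average of the prior mean and the data
posterior_mean = (
    prior_precision * prior_mean + sum(observations) / obs_std**2
) / posterior_precision
posterior_std = math.sqrt(1.0 / posterior_precision)

print(posterior_mean, posterior_std)  # roughly 233.99 and 0.58
```

The posterior sits almost exactly on the observations, far in the tail of the prior, which is why a proposal centred on the data is a natural choice.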

To do things by hand first, let's reimplement the importance function.
It samples from the proposal and then computes the importance weight as the difference between the target and proposal log densities.

In [None]:
def importance_sample(target, proposal):
    def _inner(key, target_args, proposal_args):
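        # Sample a full trace from the proposal
        trace = proposal.simulate(key, *proposal_args)
        chm = trace.get_sample()
        # Log density of the sampled choices under the proposal
        proposal_logpdf = trace.get_score()
        # Log density of the same choices under the target
        target_logpdf, _ = target.assess(chm, *target_args)
        # Importance weight, in log space
        importance_weight = target_logpdf - proposal_logpdf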
        return (trace, importance_weight)

    return _inner
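
In symbols, the function above computes the following, where $\tau$ denotes the choice map sampled from the proposal, $q$ the proposal density and $p$ the target density. This notation is introduced here for illustration and does not appear in the GenJAX API.

```latex
\tau \sim q(\cdot), \qquad
\log w(\tau) = \log p(\tau) - \log q(\tau)
```

Here $\log q(\tau)$ corresponds to `trace.get_score()` and $\log p(\tau)$ to the first value returned by `target.assess`.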

We then run importance sampling with the custom proposal.

In [None]:
key = random.PRNGKey(0)
key, *sub_keys = random.split(key, 1000 + 1)
sub_keys = jnp.array(sub_keys)
args_for_model = ()
args_for_proposal = (jnp.array([obs["obs1"], obs["obs2"], obs["obs3"]]),)
jitted = jit(vmap(importance_sample(model, proposal), in_axes=(0, None, None)))
# Each argument tuple is wrapped once more because `_inner` unpacks it with `*`
trace, new_weight = jitted(sub_keys, (args_for_model,), (args_for_proposal,))

We see that the new values, both average and maximum, are much higher than before, which means that the custom proposal is doing a much better job.

In [None]:
print("The new average weight is", logsumexp(new_weight) - jnp.log(len(new_weight)))
print("The new maximum weight is", new_weight.max())