In [1]:
import genjax
from genjax import gen
from genjax.inference import importance_sampling

import jax.numpy as jnp
import jax
from jax import jit, vmap

In [2]:
genjax.__file__

'/home/ubuntu/workspace/bayes3d/assets/genjax/src/genjax/__init__.py'

In [3]:
@gen
def model(x):
    y = genjax.normal(x,1.0) @ "y"
    return y

@gen
def prop():
    y = genjax.normal(0.0,5.0) @ "y"
    return y

key = jax.random.PRNGKey(0)

model_args = (1.,)
prop_args = ()

model.simulate(key, model_args)
prop.simulate(key, prop_args)

(Array([4146024105,  967050713], dtype=uint32),
 BuiltinTrace(gen_fn=BuiltinGenerativeFunction(source=<function prop at 0x7f44c712e1f0>), args=(), retval=Array(-6.2576942, dtype=float32), choices=Trie(inner={'y': DistributionTrace(gen_fn=Normal(), args=(0.0, 5.0), value=Array(-6.2576942, dtype=float32), score=Array(-8.716225, dtype=float32))}), cache=Trie(inner={}), score=Array(-8.716225, dtype=float32)))

In [4]:
def importance_sampling_fix(key, model, model_args, observations, proposal, proposal_args,   N):

    key, *sub_keys = jax.random.split(key, N + 1)
    sub_keys = jnp.array(sub_keys)

    _, p_trs = jax.vmap(proposal.simulate, in_axes=(0, None))(
        sub_keys,
        proposal_args,
    )

    # tree_util.map didnt exist but tree_map does
    observations = jax.tree_util.tree_map(
        lambda v: jnp.repeats(v, 1), observations
    )
    chm = p_trs.get_choices().merge(observations)

    key, *sub_keys = jax.random.split(key, N + 1)
    sub_keys = jnp.array(sub_keys)

    _, (lws, m_trs) = jax.vmap(model.importance, in_axes=(0, 0, None))(
        sub_keys,
        chm,
        model_args,
    )
    lws = lws - p_trs.get_score()

    log_total_weight = jax.scipy.special.logsumexp(lws)
    log_normalized_weights = lws - log_total_weight
    log_ml_estimate = log_total_weight - jnp.log(N)

    return key, (m_trs, log_normalized_weights, log_ml_estimate)

In [5]:
proposal = prop
proposal_args = prop_args
observations = genjax.choice_map({})
N = 10

key, (m_trs, log_normalized_weights, log_ml_estimate) = importance_sampling_fix(key, model, model_args, observations, proposal, proposal_args, N)