In [11]:
# base model in numpyro

import jax
import numpyro
from jax import numpy as jnp
from numpyro.infer import MCMC, NUTS
dist = numpyro.distributions

# Unit scale for each of the five components of x (read by the model below).
scale = jnp.ones((5,))

def model():
    """Base model: a scalar latent ``z`` and a vector ``x`` centred on it.

    ``z`` has a Normal(0, 1) prior. ``x`` is Normal with mean ``z`` and the
    module-level ``scale`` vector as its scale, so ``x`` takes its shape from
    ``scale``. Neither site is given ``obs=``, so both are sampled.
    """
    z = numpyro.sample("z", dist.Normal(loc=0, scale=1))
    numpyro.sample("x", dist.Normal(loc=z, scale=scale))

# Fixed PRNG seed for this run.
key = jax.random.PRNGKey(0)

# NUTS sampler: 1000 warmup iterations, then 1000 retained draws.
kernel = NUTS(model=model)
mcmc = MCMC(kernel, num_warmup=1000, num_samples=1000)
mcmc.run(key)

# Average the draws of x over the sample axis, leaving one mean per component.
# With z ~ Normal(0, 1) and x centred on z, each mean should be near 0 up to
# Monte Carlo error.
xs = mcmc.get_samples()["x"]
print(xs.mean(axis=0))

sample: 100%|██████████| 2000/2000 [00:01<00:00, 1015.90it/s, 7 steps of size 3.67e-01. acc. prob=0.92]

[0.1608343  0.1762394  0.15983294 0.16351682 0.17033295]





In [21]:
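# numpyro does not seem to handle NaN-coded missing data automatically,
# so we deal with it ourselves: condition on the observed entries via `obs=`
# and sample the missing entries as ordinary latent variables.
# NOTE(review): `numpyro.sample` also accepts an `obs_mask` argument that may cover this case - confirm.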

# Five-component data vector: the first two entries are observed values,
# the last three are NaN placeholders for missing entries.
x_obs = jnp.array([-1, 3] + [jnp.nan] * 3)

def model():
    """Partially observed model: condition on the known entries, sample the rest.

    ``z`` has a Normal(0, 1) prior. Site ``x1`` covers the first two
    components and is conditioned on ``x_obs[:2]``; site ``x2`` covers the
    remaining components and has no ``obs=``, so it is sampled. Both are
    Normal with mean ``z`` and the matching slice of the module-level ``scale``.
    """
    z = numpyro.sample("z", dist.Normal(loc=0, scale=1))
    # The split index 2 is hard-coded to match where the NaN entries of x_obs begin.
    numpyro.sample("x1", dist.Normal(loc=z, scale=scale[:2]), obs=x_obs[:2])
    numpyro.sample("x2", dist.Normal(loc=z, scale=scale[2:]))

# Fixed PRNG seed for this run.
key = jax.random.PRNGKey(1)

# NUTS sampler: 1000 warmup iterations, then 10000 retained draws.
kernel = NUTS(model=model)
mcmc = MCMC(kernel, num_warmup=1000, num_samples=10000)
mcmc.run(key)

# Average the draws of the unobserved components over the sample axis.
xs = mcmc.get_samples()["x2"]
print(xs.mean(axis=0))

# Analytic check: with unit scales everywhere, the Normal(0, 1) prior on z acts
# like one extra observation at 0, so the posterior mean of z is the average of
# {0, -1, 3}. Each component of x2 is centred on z and shares that mean.
expected_mean = (0 - 1 + 3) / 3

print(f"{expected_mean=}")

sample: 100%|██████████| 11000/11000 [00:04<00:00, 2524.93it/s, 7 steps of size 5.66e-01. acc. prob=0.92] 

[0.6402163  0.63642496 0.6509023 ]
expected_mean=0.6666666666666666



