In [1]:
import numpy as np

In [5]:
import numpyro
import numpyro.distributions as dist

In [6]:
def eight_schools(J, sigma, y=None, centered=True):
    """Hierarchical "eight schools" model.

    A population mean ``mu`` ~ Normal(0, 5) and scale ``tau`` ~ HalfCauchy(5)
    generate one latent effect ``theta[j]`` per group; group ``j`` is then
    observed as Normal(theta[j], sigma[j]).

    Args:
        J: number of groups; sets the size of the ``'J'`` plate.
        sigma: per-group observation scale, used as the scale of the ``'obs'``
            site inside the plate (assumed to have one entry per group).
        y: observed per-group values. When ``None`` the ``'obs'`` site is left
            unobserved, so it is sampled instead of conditioned on.
        centered: if True (the default, original behaviour) ``theta`` is
            sampled directly from Normal(mu, tau). If False the non-centered
            form is used: a standard-normal ``theta_decentered`` is sampled and
            ``theta = mu + tau * theta_decentered`` is recorded as a
            deterministic site, so ``theta`` still appears in the samples.
    """
    mu = numpyro.sample('mu', dist.Normal(0, 5))
    tau = numpyro.sample('tau', dist.HalfCauchy(5))
    with numpyro.plate('J', J):
        if centered:
            theta = numpyro.sample('theta', dist.Normal(mu, tau))
        else:
            # Sampling a fixed-scale offset and rescaling by tau avoids the
            # narrow "funnel" between tau and theta that the centered form
            # has as tau -> 0, a common source of NUTS divergences here.
            theta_decentered = numpyro.sample('theta_decentered', dist.Normal(0, 1))
            theta = numpyro.deterministic('theta', mu + tau * theta_decentered)
        numpyro.sample('obs', dist.Normal(theta, sigma), obs=y)

In [7]:
from jax import random

In [8]:
from numpyro.infer import MCMC, NUTS

In [9]:
# No-U-Turn Sampler kernel targeting the eight_schools model defined above.
nuts_kernel = NUTS(model=eight_schools)

In [13]:
# Observed per-school values (y) and their observation scales (sigma), which
# eight_schools uses as the Normal scale of its 'obs' site.
# NOTE(review): these look like the classic eight-schools data (Rubin, 1981);
# record the provenance in a markdown cell.
y = np.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0])
sigma = np.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0])
# Derive the number of groups from the data rather than hard-coding it, so J
# cannot drift out of sync if observations are added or removed.
J = len(y)
assert sigma.shape == y.shape, "y and sigma must have one entry per school"

In [10]:
# 500 warmup iterations precede 1000 sampling iterations (1500 in total, as the
# progress bar of the run cell reports).
# NOTE(review): no num_chains is given, so presumably a single chain is run.
mcmc = MCMC(
    nuts_kernel,
    num_warmup=500,
    num_samples=1000,
)

In [11]:
# Fixed seed for the JAX PRNG key that drives the sampler run.
rng_key = random.PRNGKey(seed=0)

In [14]:
mcmc.run(
    rng_key,
    J,      # forwarded positionally to eight_schools
    sigma,  # forwarded positionally to eight_schools
    y=y,    # observed data conditioning the 'obs' site
    # Record per-draw potential energy so later cells can read it back.
    extra_fields=("potential_energy",),
)

sample: 100% 1500/1500 [02:57<00:00,  8.45it/s, 63 steps of size 4.53e-02. acc. prob=0.97] 


In [15]:
# 90% interval, matching the 5.0% / 95.0% columns of the summary table.
mcmc.print_summary(prob=0.9)


                mean       std    median      5.0%     95.0%     n_eff     r_hat
        mu      4.57      3.53      4.36     -1.72      9.51    138.24      1.01
       tau      3.60      3.06      2.87      0.25      7.62     91.97      1.01
  theta[0]      6.24      5.64      5.67     -3.22     13.91    204.27      1.00
  theta[1]      5.05      4.81      4.76     -2.75     12.11    236.71      1.00
  theta[2]      4.14      5.24      4.10     -4.40     12.23    263.75      1.01
  theta[3]      4.91      5.01      4.76     -3.90     12.01    247.83      1.00
  theta[4]      3.74      5.19      3.78     -4.60     11.93    210.18      1.01
  theta[5]      4.17      4.83      4.16     -3.79     12.05    214.35      1.01
  theta[6]      6.62      5.07      6.33     -1.41     14.35    179.37      1.00
  theta[7]      5.01      5.22      4.64     -3.74     12.53    239.29      1.00

Number of divergences: 2


In [16]:
# Per-draw potential energy, collected via extra_fields in the run cell.
pe = mcmc.get_extra_fields()["potential_energy"]

In [17]:
# Negate the potential energy to report a log density, averaged over draws.
# NOTE(review): NumPyro presumably evaluates potential energy in unconstrained
# space (including the Jacobian for the positive `tau`), so this is a sampler
# diagnostic rather than the log joint on the original scale; confirm.
print(f'Expected log joint density: {np.mean(-pe):.2f}')

Expected log joint density: -53.39
