In [1]:
import numpy as np
import numpyro
import numpyro.distributions as dist
from jax import random
from numpyro.infer import MCMC, NUTS

In [2]:
# Eight Schools data: one observed estimate `y[j]` per school together with its
# standard deviation `sigma[j]` (used as the observation noise scale of the
# 'obs' site in the models below).
y = np.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0])
sigma = np.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0])
# Derive the number of schools from the data instead of hard-coding it, so `J`
# cannot silently drift out of sync with `y` / `sigma`.
J = len(y)
assert y.shape == sigma.shape == (J,), "y and sigma must be 1-D and the same length"


In [3]:
def eight_schools(J, sigma, y=None):
    """Centered parameterization of the Eight Schools hierarchical model.

    Each school's effect ``theta[j]`` is drawn directly from
    ``Normal(mu, tau)``, and the observed estimate ``y[j]`` is modeled as
    ``Normal(theta[j], sigma[j])``.

    Args:
        J: number of schools (size of the ``'J'`` plate).
        sigma: per-school observation noise scale (standard deviation).
        y: observed per-school estimates, or ``None`` to leave ``'obs'``
            unobserved (e.g. for predictive sampling).
    """
    # Population-level hyperpriors shared by all schools.
    pop_mean = numpyro.sample('mu', dist.Normal(loc=0, scale=5))
    pop_scale = numpyro.sample('tau', dist.HalfCauchy(scale=5))

    # Sites inside the plate are conditionally independent across schools.
    with numpyro.plate('J', J):
        school_effect = numpyro.sample('theta', dist.Normal(loc=pop_mean, scale=pop_scale))
        numpyro.sample('obs', dist.Normal(loc=school_effect, scale=sigma), obs=y)

In [4]:

# Fit the centered model with NUTS. 'potential_energy' is recorded as an extra
# field so the average log joint density can be inspected after sampling.
rng_key = random.PRNGKey(0)
nuts_kernel = NUTS(model=eight_schools)
mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=1000)
mcmc.run(rng_key, J, sigma, y=y, extra_fields=('potential_energy',))

sample: 100%|██████████| 1500/1500 [00:09<00:00, 165.74it/s, 7 steps of size 2.59e-01. acc. prob=0.76] 


In [5]:
# Posterior summary with 90% intervals for the centered run. The divergence
# count printed at the end is the key diagnostic for this parameterization.
mcmc.print_summary(prob=0.9)


                mean       std    median      5.0%     95.0%     n_eff     r_hat
        mu      3.92      3.34      3.55     -1.37      9.40    170.62      1.01
       tau      4.24      3.27      3.27      0.93      8.24     84.25      1.02
  theta[0]      6.06      5.85      5.15     -2.82     14.47    147.79      1.02
  theta[1]      4.59      4.92      4.28     -3.76     11.57    323.62      1.01
  theta[2]      3.04      5.72      2.92     -6.03     11.64    353.25      1.00
  theta[3]      4.22      5.31      3.67     -3.69     12.65    265.08      1.02
  theta[4]      2.94      5.07      2.98     -3.93     11.70    363.29      1.00
  theta[5]      3.65      5.23      3.50     -4.30     11.78    324.80      1.00
  theta[6]      6.08      5.54      5.38     -1.62     15.65    192.22      1.02
  theta[7]      4.35      5.42      3.76     -4.02     12.75    369.64      1.01

Number of divergences: 29


In [6]:
# NumPyro's potential energy is the negated log joint density (evaluated in the
# sampler's unconstrained space), so averaging its negation over the collected
# draws gives the expected log joint density.
pe = mcmc.get_extra_fields()['potential_energy']
mean_log_joint = np.mean(-pe)
print(f'Expected log joint density: {mean_log_joint:.2f}')

Expected log joint density: -55.55


In [7]:
from numpyro.infer.reparam import TransformReparam

In [9]:
# Eight Schools example - Non-centered Reparametrization
def eight_schools_noncentered(J, sigma, y=None):
    """Non-centered parameterization of the Eight Schools hierarchical model.

    ``theta`` is declared as ``AffineTransform(mu, tau)`` applied to a
    standard Normal. ``TransformReparam`` rewrites that site so the sampler
    works on the base variable (reported as ``theta_base``), while ``theta``
    itself is recorded as a deterministic site. Decoupling the sampled variable
    from ``tau`` typically improves NUTS mixing for this model (compare n_eff /
    r_hat with the centered run).

    Args:
        J: number of schools (size of the ``'J'`` plate).
        sigma: per-school observation noise scale (standard deviation).
        y: observed per-school estimates, or ``None`` to leave ``'obs'``
            unobserved (e.g. for predictive sampling).
    """
    mu = numpyro.sample('mu', dist.Normal(0, 5))
    tau = numpyro.sample('tau', dist.HalfCauchy(5))
    with numpyro.plate('J', J):
        with numpyro.handlers.reparam(config={'theta': TransformReparam()}):
            theta = numpyro.sample(
                'theta',
                dist.TransformedDistribution(
                    dist.Normal(0., 1.),
                    dist.transforms.AffineTransform(mu, tau),
                ),
            )
        # The likelihood belongs inside the plate, as in `eight_schools`, so that
        # 'obs' is declared conditionally independent across schools.
        numpyro.sample('obs', dist.Normal(theta, sigma), obs=y)


# Fit the non-centered model with the same sampler settings and seed as the
# centered run, so the two summaries are directly comparable.
rng_key = random.PRNGKey(0)
nuts_kernel = NUTS(model=eight_schools_noncentered)
mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=1000)
mcmc.run(rng_key, J, sigma, y=y, extra_fields=('potential_energy',))
# Show deterministic sites too, so the reconstructed `theta` appears next to
# the sampled base variable.
mcmc.print_summary(exclude_deterministic=False)


sample: 100%|██████████| 1500/1500 [00:05<00:00, 265.53it/s, 7 steps of size 4.63e-01. acc. prob=0.89]



                   mean       std    median      5.0%     95.0%     n_eff     r_hat
           mu      4.27      3.36      4.27     -1.22      9.38    793.90      1.00
          tau      3.72      3.48      2.81      0.00      7.96    893.12      1.00
     theta[0]      6.30      5.79      5.60     -1.95     14.84    944.66      1.00
     theta[1]      4.96      4.67      4.81     -3.65     11.23   1212.58      1.00
     theta[2]      3.69      5.37      4.08     -5.19     11.64    926.15      1.00
     theta[3]      4.78      4.68      4.71     -2.54     12.14   1131.32      1.00
     theta[4]      3.48      4.78      3.97     -4.21     10.49   1042.15      1.00
     theta[5]      3.92      4.77      4.23     -3.10     12.04   1300.24      1.00
     theta[6]      6.34      5.09      5.88     -2.33     13.22    993.05      1.00
     theta[7]      4.69      5.35      4.60     -2.59     12.64    815.38      1.00
theta_base[0]      0.30      0.96      0.33     -1.22      2.02    952.88  

In [15]:
# Mean log joint density of the non-centered run (negated potential energy),
# to compare with the earlier value.
# NOTE(review): this model's latent sites are (mu, tau, theta_base) rather than
# (mu, tau, theta), so the density is over a different space than the centered
# run's. Treat the comparison as a rough sanity check, not a model-quality score.
pe = mcmc.get_extra_fields()['potential_energy']
mean_log_joint = np.mean(-pe)
print(f'Expected log joint density: {mean_log_joint:.2f}')

Expected log joint density: -46.11


In [16]:
from numpyro.infer import Predictive
# New School
def new_school():
    """Predictive model for the effect in a new, unobserved school.

    Declares the same ``mu`` / ``tau`` hyperpriors as the Eight Schools models,
    so ``Predictive`` can substitute their posterior draws. The returned
    ``'obs'`` is then a draw of a new school's effect from ``Normal(mu, tau)``.
    """
    population_mean = numpyro.sample('mu', dist.Normal(loc=0, scale=5))
    population_scale = numpyro.sample('tau', dist.HalfCauchy(scale=5))
    new_effect = numpyro.sample('obs', dist.Normal(loc=population_mean, scale=population_scale))
    return new_effect

# Posterior predictive for a new school: substitute the latest run's posterior
# draws of `mu` / `tau` into `new_school` and sample 'obs' once per draw.
posterior_draws = mcmc.get_samples()
predictive = Predictive(new_school, posterior_samples=posterior_draws)
samples_predictive = predictive(random.PRNGKey(1))
print(np.mean(samples_predictive['obs']))

4.4116945
