# 8 schools data

In [2]:
from numpyro.infer import Predictive
from numpyro.infer.reparam import TransformReparam, LocScaleReparam
from jax import random
from numpyro.infer import MCMC
from numpyro.infer.hug import Hug
import numpyro.distributions as dist
import numpyro
import numpy as np
import matplotlib.pyplot as plt

Let us explore NumPyro using a simple example. We will use the eight schools example from Gelman et al., Bayesian Data Analysis: Sec. 5.5, 2003, which studies the effect of coaching on SAT performance in eight schools.

The data is given by:

In [3]:
J = 8
y = np.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0])
sigma = np.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0])

where `y` holds the estimated treatment effect for each school and `sigma` the corresponding standard errors.
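For a baseline before any hierarchical modelling, the sketch below computes the complete-pooling estimate, which treats all eight schools as measuring one common effect. It is a precision-weighted average with weights `1 / sigma**2`. This uses plain NumPy and is not part of the NumPyro workflow; the variable names are illustrative.

```python
import numpy as np

y = np.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0])
sigma = np.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0])

# Complete pooling: weight each school's estimate by its precision 1 / sigma_j^2.
weights = 1.0 / sigma**2
pooled_mean = np.sum(weights * y) / np.sum(weights)
# Standard error of a precision-weighted mean of independent estimates.
pooled_se = 1.0 / np.sqrt(np.sum(weights))
print(f"pooled estimate: {pooled_mean:.2f} +/- {pooled_se:.2f}")
# prints: pooled estimate: 7.69 +/- 4.07
```

The hierarchical model below interpolates between this fully pooled estimate and the eight separate, unpooled estimates in `y`.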

We build a hierarchical model for the study in which the group-level parameters `theta` for each school are drawn from a Normal distribution with unknown mean `mu` and standard deviation `tau`. The observed data are in turn generated from a Normal distribution whose mean is `theta` (the true effect) and whose standard deviation is `sigma`.
This lets us estimate the population-level parameters `mu` and `tau` by pooling information across all the observations, while still allowing for individual variation among the schools through the group-level `theta` parameters.
In `numpyro`, the model is written as:

In [4]:
def eight_schools(J, sigma, y=None):
    mu = numpyro.sample('mu', dist.Normal(0, 5))
    tau = numpyro.sample('tau', dist.HalfCauchy(5))
    with numpyro.plate('J', J):
        theta = numpyro.sample('theta', dist.Normal(mu, tau))
        numpyro.sample('obs', dist.Normal(theta, sigma), obs=y)
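The partial pooling in this model can be seen in closed form if `mu` and `tau` are held fixed. Under the normal-normal structure of `eight_schools`, each `theta_j` given `mu`, `tau` and `y_j` is Normal, with a mean that is a precision-weighted compromise between `y_j` and `mu`. The NumPy sketch below illustrates this; the function name and the fixed value `mu = 8` are hypothetical, and MCMC additionally averages over the uncertainty in `mu` and `tau`.

```python
import numpy as np

def conditional_theta_posterior(y, sigma, mu, tau):
    """Mean and sd of theta_j given mu, tau and y_j (normal-normal model)."""
    y = np.asarray(y, dtype=float)
    sigma = np.asarray(sigma, dtype=float)
    # Posterior precision = data precision + prior precision.
    precision = 1.0 / sigma**2 + 1.0 / tau**2
    # Posterior mean = precision-weighted average of y_j and mu.
    mean = (y / sigma**2 + mu / tau**2) / precision
    return mean, 1.0 / np.sqrt(precision)

y = np.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0])
sigma = np.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0])

# Hypothetical mu = 8; small tau pools strongly toward mu, large tau barely pools.
for tau in (1.0, 10.0, 100.0):
    mean, _ = conditional_theta_posterior(y, sigma, mu=8.0, tau=tau)
    print(f"tau={tau:>5}: {np.round(mean, 1)}")
```

As `tau` shrinks, every school's estimate collapses toward `mu`; as `tau` grows, each estimate approaches its own `y_j`. This is why the posterior of `tau` governs how much information the schools share.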

Let us infer the values of the unknown parameters in our model by running MCMC with the `Hug` kernel imported from `numpyro.infer.hug`. Note the use of the `extra_fields` argument in `MCMC.run`. By default, MCMC collects only samples from the target (posterior) distribution. However, additional fields such as the potential energy or the acceptance probability of each sample can easily be collected through `extra_fields`. For a list of the fields available with NumPyro's HMC-based kernels, see the `HMCState` object. In this example, we additionally collect the `potential_energy` and the `accept_prob` for each sample.
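The collected `potential_energy` is the negative log joint density of the model, which is why the next cells report `mean(-pe)` as an expected log joint density. The NumPy sketch below evaluates that quantity by hand for the centered `eight_schools` model. It is an illustrative re-implementation, not NumPyro's code. The function names are hypothetical, and the `include_jacobian` term assumes that `tau` is sampled on an unconstrained log scale (an exp transform), in which case the change of variables adds `log(tau)` to the log density.

```python
import numpy as np

y = np.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0])
sigma = np.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0])
LOG_2PI = np.log(2.0 * np.pi)

def normal_logpdf(x, loc, scale):
    return -0.5 * LOG_2PI - np.log(scale) - 0.5 * ((x - loc) / scale) ** 2

def half_cauchy_logpdf(x, scale):
    # log of 2 / (pi * scale * (1 + (x / scale)**2)), valid for x >= 0
    return np.log(2.0) - np.log(np.pi * scale) - np.log1p((x / scale) ** 2)

def potential_energy(mu, tau, theta, include_jacobian=True):
    """Negative log joint density of the centered eight_schools model."""
    log_joint = (
        normal_logpdf(mu, 0.0, 5.0)                 # mu ~ Normal(0, 5)
        + half_cauchy_logpdf(tau, 5.0)              # tau ~ HalfCauchy(5)
        + np.sum(normal_logpdf(theta, mu, tau))     # theta_j ~ Normal(mu, tau)
        + np.sum(normal_logpdf(y, theta, sigma))    # y_j ~ Normal(theta_j, sigma_j)
    )
    if include_jacobian:
        # Assumed exp transform for tau: adds log|d tau / d log(tau)| = log(tau).
        log_joint += np.log(tau)
    return -log_joint

theta = np.full(8, 5.0)  # hypothetical parameter values
print(potential_energy(mu=5.0, tau=2.0, theta=theta))
```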

In [None]:
kernel = Hug(eight_schools, step_size=0.1, trajectory_length=10)
mcmc = MCMC(kernel, num_warmup=0, num_samples=1000, num_chains=1)
rng_key = random.PRNGKey(0)
mcmc.run(rng_key, J, sigma, y=y, extra_fields=('potential_energy','accept_prob'))
mcmc.print_summary()

In [None]:
pe = mcmc.get_extra_fields()['potential_energy']
print('Expected log joint density: {:.2f}'.format(
    np.mean(-pe)))  # doctest: +SKIP

In [None]:
plt.plot(pe)

In [None]:
ap = mcmc.get_extra_fields()['accept_prob']
plt.hist(ap, range=[0,1])

In [None]:
samples = mcmc.get_samples()
plt.plot(samples['theta'])

In [None]:
plt.plot(samples['tau'])

In [None]:
plt.plot(samples['mu'])

Values of the split Gelman-Rubin diagnostic `r_hat` above 1 indicate that the chain has not fully converged. The low effective sample size `n_eff`, particularly for `tau`, and the number of divergent transitions also look problematic.
Fortunately, this is a common pathology that can be rectified by using a non-centered parameterization for `theta` in our model. This is straightforward to do in `numpyro` by using a `TransformedDistribution` instance together with a reparameterization effect handler. Let us rewrite the same model, but instead of sampling `theta` from Normal(`mu`, `tau`), we sample it from a base Normal(0, 1) distribution that is transformed using an `AffineTransform`.
Note that by doing so, `numpyro` runs the sampler on samples `theta_base` from the base Normal(0, 1) distribution instead. We see that the resulting chain does not suffer from the same pathology: the Gelman-Rubin diagnostic is 1 for all the parameters and the effective sample size looks quite good!
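To make the `r_hat` diagnostic less of a black box, the sketch below computes the classic split R-hat for one scalar parameter. Each chain is cut in half, and the between-half variance is compared with the within-half variance. It follows the textbook formula without rank normalization. NumPyro ships its own version in `numpyro.diagnostics` (for example `split_gelman_rubin`), which may differ in details, so treat this as an illustration rather than a drop-in replacement.

```python
import numpy as np

def split_rhat(samples):
    """Split R-hat for draws of one scalar parameter.

    `samples` has shape (num_samples,) or (num_chains, num_samples).
    """
    chains = np.atleast_2d(np.asarray(samples, dtype=float))
    n = chains.shape[1] // 2  # length of each half-chain
    # Split every chain into its first and second half -> 2 * num_chains chains.
    halves = np.concatenate([chains[:, :n], chains[:, n:2 * n]], axis=0)
    # Between-chain variance B and mean within-chain variance W.
    between = n * halves.mean(axis=1).var(ddof=1)
    within = halves.var(axis=1, ddof=1).mean()
    # Pooled estimate of the posterior variance.
    var_hat = (n - 1) / n * within + between / n
    return np.sqrt(var_hat / within)

rng = np.random.default_rng(0)
stationary = rng.normal(size=2000)
# A chain whose second half has drifted: its two halves disagree.
drifting = np.concatenate([rng.normal(size=1000), rng.normal(loc=3.0, size=1000)])
print(split_rhat(stationary), split_rhat(drifting))
```

A well-mixed chain gives a value very close to 1. The drifting chain gives a value well above 1, which is the signature the summary of the centered model flags.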

In [None]:
def eight_schools_noncentered(J, sigma, y=None):
    mu = numpyro.sample('mu', dist.Normal(0, 5))
    tau = numpyro.sample('tau', dist.HalfCauchy(5))
    with numpyro.plate('J', J):
        with numpyro.handlers.reparam(config={'theta': TransformReparam()}):
            theta = numpyro.sample(
                'theta',
                dist.TransformedDistribution(dist.Normal(0., 1.),
                                             dist.transforms.AffineTransform(mu, tau)))
        numpyro.sample('obs', dist.Normal(theta, sigma), obs=y)

kernel = Hug(eight_schools_noncentered, step_size=0.1)
mcmc = MCMC(kernel, num_warmup=500, num_samples=1000)
rng_key = random.PRNGKey(0)
mcmc.run(rng_key, J, sigma, y=y, extra_fields=('potential_energy',))
mcmc.print_summary(exclude_deterministic=False)
pe = mcmc.get_extra_fields()['potential_energy']
print('Expected log joint density: {:.2f}'.format(
    np.mean(-pe)))
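The non-centered model above does not change the distribution of `theta`. It only changes what the sampler explores. The NumPy sketch below checks this numerically for hypothetical fixed values of `mu` and `tau`: drawing `theta` directly from Normal(`mu`, `tau`) and drawing `theta_base` from Normal(0, 1) and then applying `mu + tau * theta_base` (the map an `AffineTransform(mu, tau)` applies) give matching moments.

```python
import numpy as np

rng = np.random.default_rng(0)
mu, tau = 4.0, 3.0  # hypothetical fixed values of the population parameters
num_draws = 200_000

# Centered: draw theta directly from Normal(mu, tau).
theta_centered = rng.normal(loc=mu, scale=tau, size=num_draws)

# Non-centered: draw theta_base ~ Normal(0, 1), then apply the affine map.
theta_base = rng.normal(size=num_draws)
theta_noncentered = mu + tau * theta_base

print(theta_centered.mean(), theta_centered.std())
print(theta_noncentered.mean(), theta_noncentered.std())
```

The benefit for MCMC is that the prior scale of `theta_base` no longer depends on `tau`. The sampler therefore avoids the funnel-shaped region where small values of `tau` force all `theta` values into a narrow band.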

In [None]:
plt.plot(pe)

In [None]:
samples = mcmc.get_samples()
plt.plot(samples['theta'])

In [None]:
plt.plot(samples['tau'])

In [None]:
plt.plot(samples['mu'])

Now assume that we have a new school for which we have not observed any test scores, but for which we would like to generate predictions. `numpyro` provides a `Predictive` class for this purpose. Note that in the absence of any observed data, we simply use the population-level parameters to generate predictions. The `Predictive` utility conditions the unobserved `mu` and `tau` sites on values drawn from the posterior distribution of our last MCMC run, and runs the model forward to generate predictions.
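Conceptually, for `new_school` this amounts to one Normal(`mu_i`, `tau_i`) draw per posterior sample `i`. The NumPy sketch below mimics that step. The posterior draws here are synthetic stand-ins (hypothetical values, not output of the MCMC run above), and `predict_new_school` is an illustrative name, not part of NumPyro.

```python
import numpy as np

def predict_new_school(mu_samples, tau_samples, rng):
    """Draw one predicted effect per posterior draw of (mu, tau)."""
    mu_samples = np.asarray(mu_samples, dtype=float)
    tau_samples = np.asarray(tau_samples, dtype=float)
    # Broadcasting gives one Normal(mu_i, tau_i) draw for each posterior sample i.
    return rng.normal(loc=mu_samples, scale=tau_samples)

rng = np.random.default_rng(1)
# Hypothetical stand-ins for mcmc.get_samples()['mu'] and ['tau'].
mu_post = rng.normal(4.0, 3.0, size=1000)
tau_post = np.abs(rng.normal(3.0, 1.0, size=1000))
obs_pred = predict_new_school(mu_post, tau_post, rng)
print(obs_pred.shape, obs_pred.mean())
```

Averaging `obs_pred`, as the next cell does with `samples_predictive['obs']`, gives a Monte Carlo estimate of the posterior predictive mean for the new school.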

In [None]:
def new_school():
    mu = numpyro.sample('mu', dist.Normal(0, 5))
    tau = numpyro.sample('tau', dist.HalfCauchy(5))
    return numpyro.sample('obs', dist.Normal(mu, tau))

predictive = Predictive(new_school, mcmc.get_samples())
samples_predictive = predictive(random.PRNGKey(1))
print(np.mean(samples_predictive['obs']))  