In [4]:
import numpy as np
import numpyro
import numpyro.distributions as dist
from numpyro.infer.reparam import TransformReparam

In [3]:
# Eight-schools data: one observed value and one standard error per school.
y = np.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0])
sigma = np.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0])

# Derive the number of schools from the data instead of hard-coding it, so the
# plate size cannot drift out of sync if the arrays are edited.
J = len(y)
assert sigma.shape == y.shape, "y and sigma must have one entry per school"

In [5]:
def eight_schools_noncentered(J, sigma, y=None):
    """Eight-schools hierarchical model with a reparameterised ``theta`` site.

    Sample sites registered, in order:
      * ``mu``    ~ Normal(0, 5)
      * ``tau``   ~ HalfCauchy(5)
      * ``theta`` ~ Normal(0, 1) pushed through ``AffineTransform(mu, tau)``,
        drawn inside plate ``'J'`` and handled by ``TransformReparam``
      * ``obs``   ~ Normal(theta, sigma), conditioned on ``y``

    Args:
        J: size of the ``'J'`` plate.
        sigma: scale of the observation distribution (presumably one value per
            plate element -- confirm against the caller).
        y: observed values passed as ``obs`` to the ``'obs'`` site; defaults
            to ``None``.

    Returns:
        None. The function only registers sample sites.
    """
    mu = numpyro.sample('mu', dist.Normal(0, 5))
    tau = numpyro.sample('tau', dist.HalfCauchy(5))

    # Only the 'theta' site is reparameterised.
    reparam_config = {'theta': TransformReparam()}

    with numpyro.plate('J', J):
        with numpyro.handlers.reparam(config=reparam_config):
            standard_normal = dist.Normal(0., 1.)
            shift_and_scale = dist.transforms.AffineTransform(mu, tau)
            theta = numpyro.sample(
                'theta',
                dist.TransformedDistribution(standard_normal, shift_and_scale),
            )
        # 'obs' stays inside the plate but outside the reparam handler.
        numpyro.sample('obs', dist.Normal(theta, sigma), obs=y)

In [7]:
from jax import random
from numpyro.infer import MCMC, NUTS

In [8]:
# NUTS kernel built around the eight_schools_noncentered model function.
nuts_kernel = NUTS(eight_schools_noncentered)

In [9]:
# MCMC driver around nuts_kernel: 500 warm-up iterations, then 1000 samples.
mcmc = MCMC(
    nuts_kernel,
    num_warmup=500,
    num_samples=1000,
)

In [10]:
# JAX PRNG key built from the fixed seed 0; it is handed to mcmc.run.
rng_key = random.PRNGKey(0)

In [None]:
# Run the sampler: J, sigma and y are forwarded to the model, and the
# 'potential_energy' field is requested in addition to the samples.
mcmc.run(
    rng_key,
    J,
    sigma,
    y=y,
    extra_fields=('potential_energy',),
)

sample:  53% 797/1500 [00:20<00:12, 54.33it/s, 7 steps of size 4.67e-01. acc. prob=0.91] 

In [3]:
# 80 yield measurements; the first 50 rows are labelled group 0 and the
# remaining 30 rows group 1.
data = pd.DataFrame({
    'yield': [
        7, 13, 13, 11, 5, 6, 8, 11, 10, 11,
        11, 11, 11, 14, 8, 15, 10, 9, 13, 12,
        8, 15, 7, 11, 5, 11, 15, 10, 13, 9,
        8, 12, 13, 6, 8, 5, 13, 8, 5, 10,
        18, 9, 7, 12, 11, 5, 9, 10, 13, 13,
        7, 12, 8, 16, 10, 6, 12, 13, 10, 12,
        9, 7, 12, 11, 8, 15, 13, 11, 9, 17,
        11, 10, 15, 19, 11, 13, 12, 9, 10, 10,
    ],
    'group': [0] * 50 + [1] * 30,
})