In [1]:
import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.infer.reparam import LocScaleReparam
from numpyro.infer import MCMC, ESS, NUTS,AIES

Let's use a classic example for Bayesian inference, the eight schools problem, as presented in the NumPyro getting-started guide:
https://num.pyro.ai/en/stable/getting_started.html

The model can be described as follows:

\begin{align*}
y_j &\sim \text{Normal}(\theta_j, \sigma_j) \\
\theta_j &\sim \text{Normal}(\mu, \tau) \\
\mu &\sim \text{Normal}(0, 5) \\
\tau &\sim \text{HalfCauchy}(0, 5)
\end{align*}

In [2]:
# Eight-schools data: one observed value and one likelihood scale per school.
J = 8  # number of schools; matches the length of both arrays below
y = jnp.array(
    [28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0]
)  # values the "obs" site is conditioned on
sigma = jnp.array(
    [15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0]
)  # scale of the Normal likelihood for each school


def eight_schools_noreparm(J, sigma, y=None):
    """Eight-schools hierarchical model without any reparameterization.

    Samples the population-level sites ``mu`` and ``tau``; then, inside a
    plate named ``"J"`` of size ``J``, samples ``theta`` directly from
    ``Normal(mu, tau)`` and the ``obs`` site from ``Normal(theta, sigma)``.

    Args:
        J: Size of the ``"J"`` plate.
        sigma: Scale of the Normal likelihood at the ``obs`` site.
        y: Values passed as ``obs=`` to the ``obs`` site; defaults to
            ``None``.
    """
    pop_mean = numpyro.sample("mu", dist.Normal(0, 5))
    pop_scale = numpyro.sample("tau", dist.HalfCauchy(5))

    with numpyro.plate("J", J):
        school_effect = numpyro.sample("theta", dist.Normal(pop_mean, pop_scale))

        # The likelihood site stays inside the plate, next to theta.
        likelihood = dist.Normal(school_effect, sigma)
        numpyro.sample("obs", likelihood, obs=y)


def eight_schools(J, sigma, y=None):
    """Eight-schools hierarchical model with ``theta`` reparameterized.

    Same sites as ``eight_schools_noreparm`` (``mu``, ``tau``, ``theta``,
    ``obs``), except that the ``theta`` sample statement runs under a
    ``reparam`` handler configured with ``LocScaleReparam(centered=0)``.

    Args:
        J: Size of the ``"J"`` plate.
        sigma: Scale of the Normal likelihood at the ``obs`` site.
        y: Values passed as ``obs=`` to the ``obs`` site; defaults to
            ``None``.
    """
    pop_mean = numpyro.sample("mu", dist.Normal(0, 5))
    pop_scale = numpyro.sample("tau", dist.HalfCauchy(5))

    # Only the 'theta' site is listed, so no other site is reparameterized.
    theta_reparam = {'theta': LocScaleReparam(centered=0)}

    with numpyro.plate("J", J):
        with numpyro.handlers.reparam(config=theta_reparam):
            school_effect = numpyro.sample("theta", dist.Normal(pop_mean, pop_scale))

        # obs is sampled inside the plate but outside the reparam handler.
        numpyro.sample("obs", dist.Normal(school_effect, sigma), obs=y)

#### Inference with the No-U-Turn Sampler (NUTS)

In [5]:
# NUTS on the reparameterized model: 12 chains, each with 1000 warmup
# iterations followed by 2000 retained samples.
nuts_kernel = NUTS(eight_schools)

# Ask for sequential chains explicitly. Leaving chain_method at its default
# while requesting 12 chains makes MCMC warn and fall back to sequential
# execution when fewer than 12 devices are available, which is what the
# recorded run did (one progress bar per chain). Being explicit keeps the
# same behaviour without the warning.
mcmc = MCMC(
    nuts_kernel,
    num_warmup=1000,
    num_samples=2000,
    num_chains=12,
    chain_method="sequential",
)

# Fixed seed so the run is reproducible.
rng_key = jax.random.PRNGKey(0)

mcmc.run(rng_key, J, sigma, y=y)

  mcmc = MCMC(nuts_kernel, num_warmup=1000, num_samples=2000,num_chains=12)
sample: 100%|██████████| 3000/3000 [00:02<00:00, 1350.43it/s, 7 steps of size 4.13e-01. acc. prob=0.89]
sample: 100%|██████████| 3000/3000 [00:00<00:00, 7188.07it/s, 15 steps of size 3.89e-01. acc. prob=0.93]
sample: 100%|██████████| 3000/3000 [00:00<00:00, 7140.69it/s, 15 steps of size 4.12e-01. acc. prob=0.88]
sample: 100%|██████████| 3000/3000 [00:00<00:00, 7340.02it/s, 7 steps of size 5.06e-01. acc. prob=0.85]
sample: 100%|██████████| 3000/3000 [00:00<00:00, 7622.39it/s, 31 steps of size 4.27e-01. acc. prob=0.89]
sample: 100%|██████████| 3000/3000 [00:00<00:00, 6145.35it/s, 7 steps of size 4.88e-01. acc. prob=0.85]
sample: 100%|██████████| 3000/3000 [00:00<00:00, 6124.71it/s, 15 steps of size 3.68e-01. acc. prob=0.93]
sample: 100%|██████████| 3000/3000 [00:00<00:00, 7297.40it/s, 7 steps of size 4.61e-01. acc. prob=0.86]
sample: 100%|██████████| 3000/3000 [00:00<00:00, 6770.03it/s, 15 steps of size 3.75e-01.

In [6]:
# Summarize the NUTS run. exclude_deterministic=False asks for deterministic
# sites to be listed too; presumably needed here because the reparameterized
# `theta` is recorded as a deterministic site (verify against NumPyro docs).
mcmc.print_summary(exclude_deterministic=False)


                         mean       std    median      5.0%     95.0%     n_eff     r_hat
                 mu      4.39      3.30      4.42     -1.02      9.79  21658.59      1.00
                tau      3.65      3.28      2.77      0.00      7.88  15706.64      1.00
           theta[0]      6.27      5.62      5.68     -2.53     14.74  22392.60      1.00
           theta[1]      4.95      4.68      4.89     -2.49     12.56  24253.31      1.00
           theta[2]      3.92      5.29      4.20     -4.27     12.20  22059.33      1.00
           theta[3]      4.74      4.81      4.70     -3.13     12.23  26126.45      1.00
           theta[4]      3.64      4.68      3.91     -3.67     11.26  25698.43      1.00
           theta[5]      4.01      4.90      4.20     -3.60     12.04  25425.70      1.00
           theta[6]      6.36      5.08      5.85     -1.48     14.61  23419.89      1.00
           theta[7]      4.90      5.34      4.81     -3.60     13.03  24055.56      1.00
theta_dec

#### Inference with the Affine Invariant Ensemble Sampler (AIES)

In [7]:
# Ensemble kernel (AIES) on the same reparameterized model.
aies_kernel = AIES(eight_schools)

# 50 chains advanced together as a single vectorized batch; num_warmup=0
# means no initial draws are discarded before the 20,000 retained samples.
mcmc = MCMC(
    aies_kernel,
    num_warmup=0,
    num_samples=20_000,
    num_chains=50,
    chain_method="vectorized",
)

In [8]:
# Run the AIES sampler. `rng_key`, `J`, `sigma` and `y` are not defined in
# this section and must already exist from earlier cells.
mcmc.run(rng_key, J, sigma, y=y)

sample: 100%|██████████| 20000/20000 [00:03<00:00, 5764.94it/s, acc. prob=0.37] 


In [9]:
# Summarize the AIES run with the same settings as the NUTS summary, so the
# two tables list the same sites and can be compared row by row.
mcmc.print_summary(exclude_deterministic=False)


                         mean       std    median      5.0%     95.0%     n_eff     r_hat
                 mu      4.32      3.02      4.32     -0.69      9.29  25219.55      1.00
                tau      4.11      3.33      3.31      0.00      8.44   2560.29      1.01
           theta[0]      6.47      5.43      5.85     -1.88     14.71  29791.11      1.00
           theta[1]      5.04      4.44      4.92     -2.27     11.98  30584.50      1.00
           theta[2]      3.72      5.00      3.96     -4.06     11.66  30484.66      1.00
           theta[3]      4.74      4.50      4.70     -2.46     11.98  30584.00      1.00
           theta[4]      3.41      4.43      3.65     -3.70     10.58  28979.95      1.00
           theta[5]      3.89      4.57      4.05     -3.35     11.26  29296.11      1.00
           theta[6]      6.60      4.93      6.09     -1.24     14.30  29936.42      1.00
           theta[7]      4.83      5.04      4.72     -3.08     12.68  30188.04      1.00
theta_dec

#### Inference with the Ensemble Slice Sampler (ESS)

In [3]:
# Fixed seed for the ensemble slice sampler run.
rng_key = jax.random.PRNGKey(0)

# Ensemble kernel (ESS) on the same reparameterized model.
ess_kernel = ESS(eight_schools)

# Same configuration as the AIES run: 50 chains in one vectorized batch,
# no warmup, 20,000 retained samples.
mcmc = MCMC(
    ess_kernel,
    num_warmup=0,
    num_samples=20_000,
    num_chains=50,
    chain_method="vectorized",
)

In [4]:
# Run the ESS sampler on the eight-schools data. `J`, `sigma` and `y` are
# not defined in this section and must already exist from an earlier cell.
mcmc.run(rng_key, J, sigma, y=y)

sample: 100%|██████████| 20000/20000 [00:05<00:00, 3904.33it/s] 


In [5]:
# Summarize the ESS run; exclude_deterministic=False asks for deterministic
# sites to be included in the table as well.
mcmc.print_summary(exclude_deterministic=False)


                         mean       std    median      5.0%     95.0%     n_eff     r_hat
                 mu      4.36      3.03      4.38     -0.59      9.37 640510.05      1.00
                tau      4.04      3.32      3.21      0.00      8.38 633294.09      1.00
           theta[0]      6.47      5.45      5.83     -1.95     14.77 711467.24      1.00
           theta[1]      4.99      4.46      4.88     -2.28     12.10 725126.23      1.00
           theta[2]      3.82      5.02      4.08     -4.05     11.76 715660.34      1.00
           theta[3]      4.78      4.54      4.72     -2.51     12.06 665468.42      1.00
           theta[4]      3.44      4.43      3.72     -3.60     10.67 716694.15      1.00
           theta[5]      3.96      4.63      4.14     -3.39     11.47 713719.99      1.00
           theta[6]      6.53      4.91      6.01     -1.26     14.25 668140.16      1.00
           theta[7]      4.88      5.06      4.77     -3.11     12.81 689937.55      1.00
theta_dec