In [1]:
import pystan
import numpy as np
import matplotlib.pyplot as plt

# Eight-schools hierarchical model, non-centred parameterization:
# each school's effect is theta[j] = mu + tau * eta[j], with eta[j] ~ normal(0, 1),
# and the observed estimate y[j] ~ normal(theta[j], sigma[j]).
# mu and tau have no explicit sampling statements, so Stan applies implicit
# flat (improper uniform) priors over their declared support (tau >= 0).
# The assignment uses `=`: the legacy `<-` operator is deprecated in Stan 2.x
# and removed in later Stan releases.
schools_code = """
data {
    int<lower=0> J; // number of schools
    real y[J]; // estimated treatment effects
    real<lower=0> sigma[J]; // s.e. of effect estimates
}
parameters {
    real mu;
    real<lower=0> tau;
    real eta[J];
}
transformed parameters {
    real theta[J];
    for (j in 1:J)
        theta[j] = mu + tau * eta[j];
}
model {
    eta ~ normal(0, 1);
    y ~ normal(theta, sigma);
}
"""

# Observed data for the Stan `data` block: J schools, each with an estimated
# treatment effect y[j] and the standard error sigma[j] of that estimate.
# Key names must match the variable names declared in the Stan program.
schools_dat = dict(
    J=8,
    y=[28, 8, -3, 7, -1, 1, 18, 12],
    sigma=[15, 10, 16, 11, 9, 11, 10, 18],
)

# Compile the Stan program once, then draw posterior samples in a separate step.
# The one-shot `pystan.stan` helper is deprecated since PyStan 2.17 and compiles
# the C++ model on every call.
# A fixed seed makes the sampler's draws (and the printed summary) reproducible
# under Restart Kernel -> Run All.
SEED = 42
schools_model = pystan.StanModel(model_code=schools_code)
fit = schools_model.sampling(data=schools_dat, iter=1000, chains=4, seed=SEED)

print(fit)

# Posterior draws of the standardized school effects eta, with chains merged
# and permuted (permuted=True); the mean is taken over axis 0, the draw axis.
eta = fit.extract(permuted=True)['eta']
eta_posterior_mean = np.mean(eta, axis=0)
# This is not the cell's last expression, so a bare `np.mean(...)` would be
# evaluated and silently discarded; print it explicitly instead.
print(eta_posterior_mean)

# Visual summary and traceplot of the posterior draws via PyStan's built-in
# plotting. matplotlib is imported unconditionally at the top of this cell,
# so it is required for this cell to run, not optional.
fit.plot()
plt.show()

INFO:pystan:COMPILING THE C++ CODE FOR MODEL anon_model_7af71fe73cad6b7d51a4ce15383db4c6 NOW.


Inference for Stan model: anon_model_7af71fe73cad6b7d51a4ce15383db4c6.
4 chains, each with iter=1000; warmup=500; thin=1; 
post-warmup draws per chain=500, total post-warmup draws=2000.

           mean se_mean     sd   2.5%    25%    50%    75%  97.5%  n_eff   Rhat
mu         8.13    0.21   5.38  -1.83   4.73   8.03  11.43   18.3  629.0   1.01
tau        6.94    0.29    6.5   0.25   2.48    5.3   9.65  22.14  515.0    1.0
eta[0]     0.37    0.02   0.91  -1.44  -0.23   0.39   0.97   2.18 1695.0    1.0
eta[1]    -0.03    0.02    0.9  -1.86  -0.62-9.4e-3   0.55   1.75 1804.0    1.0
eta[2]    -0.22    0.02   0.93  -2.03  -0.83  -0.22   0.34   1.71 2000.0    1.0
eta[3]  -7.9e-3    0.02   0.89  -1.74  -0.59  -0.04   0.52   1.82 2000.0    1.0
eta[4]    -0.37    0.02   0.87  -2.05  -0.96  -0.36    0.2    1.4 2000.0    1.0
eta[5]    -0.19    0.02   0.86  -1.86  -0.77   -0.2   0.35   1.49 1396.0    1.0
eta[6]     0.33    0.02   0.88  -1.49  -0.26   0.36   0.94   1.99 1576.0    1.0
eta[7]     0.

In [2]:
# Display the input data dictionary that was passed to Stan (bare last
# expression, so the notebook renders its repr as the cell output).
schools_dat

{'J': 8,
 'sigma': [15, 10, 16, 11, 9, 11, 10, 18],
 'y': [28, 8, -3, 7, -1, 1, 18, 12]}