# MCMC using Pyro

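This notebook reproduces the MCMC example from the [Pyro documentation](http://pyro.ai/examples/mcmc.html): a hierarchical model of the classic "eight schools" data, sampled with the No-U-Turn Sampler (NUTS).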

In [1]:
import argparse
import logging

import torch

import pyro
import pyro.distributions as dist
import pyro.poutine as poutine
from pyro.infer import MCMC, NUTS

# Seed Pyro's random number generators so Restart & Run All reproduces the chain.
pyro.set_rng_seed(0)
# Enable Pyro's argument validation, except when Python runs with -O (__debug__ is False).
pyro.enable_validation(__debug__)
# Emit log records as bare messages at INFO level.
logging.basicConfig(level=logging.INFO, format='%(message)s')

In [2]:
# Eight-schools data: observed per-school effects `y` and the per-school
# scales `sigma` that `model` uses as the Normal likelihood scale.
# Build them directly in the default floating dtype (the dtype that
# `.type(torch.Tensor)` used to convert to) instead of casting afterwards.
y = torch.tensor([28, 8, -3, 7, -1, 1, 18, 12], dtype=torch.get_default_dtype())
sigma = torch.tensor([15, 10, 16, 11, 9, 11, 10, 18], dtype=torch.get_default_dtype())

# Number of schools, derived from the data rather than hard-coded so it cannot drift.
J = len(y)
assert sigma.shape == y.shape, "need exactly one scale per observed effect"

In [3]:
def model(sigma):
    """Hierarchical model of per-group effects in non-centred form.

    Each group's effect is ``theta = mu + tau * eta``. Here ``mu`` is a shared
    location, ``tau`` is a shared HalfCauchy scale, and ``eta`` holds
    standard-normal per-group offsets.

    Args:
        sigma: tensor of per-group observation scales (1-D in this notebook).
            Its shape sets the number of groups; it used to be hard-coded to 8.

    Returns:
        The value at the ``"obs"`` sample site. When the site is conditioned
        (see ``conditioned_model``), this is the observed data.
    """
    # One standard-normal offset per group, shaped like `sigma` rather than a fixed 8.
    eta = pyro.sample('eta', dist.Normal(torch.zeros(sigma.shape), torch.ones(sigma.shape)))
    mu = pyro.sample('mu', dist.Normal(torch.zeros(1), 10 * torch.ones(1)))
    tau = pyro.sample('tau', dist.HalfCauchy(scale=25 * torch.ones(1)))

    # Non-centred reparameterization: theta is a deterministic function of the samples above.
    theta = mu + tau * eta

    return pyro.sample("obs", dist.Normal(theta, sigma))


def conditioned_model(model, sigma, y):
    """Run ``model(sigma)`` with its ``"obs"`` site fixed to the observations ``y``.

    This is the callable given to NUTS. Its three positional arguments are
    supplied, in this order, by ``mcmc.run``.
    """
    return poutine.condition(model, data={"obs": y})(sigma)

In [4]:
# No-U-Turn Sampler over the conditioned model, with JIT compilation enabled.
# `jit_compile` takes a bool. The previous value, 'store_true', was an argparse
# action name copied from the command-line script; it only worked because
# non-empty strings are truthy.
nuts_kernel = NUTS(conditioned_model, jit_compile=True)

In [5]:
# A single chain: 1000 warm-up iterations followed by 1000 retained draws
# (2000 iterations in total, as the progress bar below shows).
mcmc = MCMC(
    nuts_kernel,
    num_chains=1,
    warmup_steps=1000,
    num_samples=1000,
)

In [6]:
# These positional arguments are forwarded to `conditioned_model(model, sigma, y)`,
# so their order must match that signature.
mcmc.run(model, sigma, y)

Sample: 100%|███████████████████████████████████████| 2000/2000 [00:53, 37.48it/s, step size=4.51e-01, acc. prob=0.833]


In [7]:
# prob=0.5 sets the probability mass of the reported interval: the 25.0% / 75.0%
# columns in the output below.
# NOTE(review): Pyro appears to compute this interval as an HPDI rather than as
# quantiles. That would explain why the bounds for the skewed `tau` posterior are
# asymmetric around its median. Confirm against the installed Pyro version.
mcmc.summary(prob=0.5)


                mean       std    median     25.0%     75.0%     n_eff     r_hat
    eta[0]      0.39      0.93      0.41     -0.43      0.84    786.59      1.00
    eta[1]      0.05      0.86      0.05     -0.62      0.53    901.37      1.00
    eta[2]     -0.17      0.96     -0.24     -0.90      0.43   1047.54      1.00
    eta[3]      0.01      0.84     -0.01     -0.55      0.52    718.53      1.00
    eta[4]     -0.31      0.89     -0.32     -0.88      0.27   1058.78      1.00
    eta[5]     -0.16      0.88     -0.18     -0.72      0.43    969.34      1.00
    eta[6]      0.36      0.95      0.42     -0.15      1.10    656.40      1.00
    eta[7]      0.07      0.96      0.09     -0.52      0.79   1141.12      1.00
     mu[0]      6.39      4.11      6.33      3.36      8.78    730.24      1.00
    tau[0]      5.53      4.60      4.54      0.00      4.55    397.46      1.00

Number of divergences: 0
