# MCMC

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/markean/aimz/blob/main/docs/notebooks/mcmc.ipynb)

In [1]:
import jax.numpy as jnp
import numpyro.distributions as dist
from jax import random
from jax.typing import ArrayLike
from numpyro import sample
from numpyro.infer import MCMC, NUTS

from aimz.model import ImpactModel

%load_ext watermark

## Model

In [None]:
def model(X: ArrayLike, y: ArrayLike | None = None) -> None:
    """Bayesian linear regression with Gaussian observation noise.

    Places ``Normal()`` priors on one weight per column of ``X`` (site ``"w"``)
    and on the intercept (site ``"b"``), and an ``Exponential()`` prior on the
    noise scale (site ``"sigma"``).

    Args:
        X: Design matrix; ``X.shape[1]`` determines the number of weights.
        y: Observed targets. When ``None``, the ``"y"`` site is left
            unobserved (e.g. for prediction).
    """
    n_features = X.shape[1]
    weights = sample("w", dist.Normal().expand((n_features,)))
    intercept = sample("b", dist.Normal())
    # The linear predictor is deterministic, so computing it after "sigma"
    # leaves the order of sample sites (w, b, sigma, y) unchanged.
    noise_scale = sample("sigma", dist.Exponential())
    linear_predictor = jnp.dot(X, weights) + intercept
    sample("y", dist.Normal(linear_predictor, noise_scale), obs=y)


# Simulate a linear-regression dataset: y = X @ w + b + e.
SEED = 42
N_OBS = 1000  # number of observations (rows of X)
N_FEATURES = 10  # number of predictors (columns of X)

rng_key = random.key(SEED)
rng_key, rng_key_w, rng_key_b, rng_key_x, rng_key_e = random.split(rng_key, 5)
w = random.normal(rng_key_w, (N_FEATURES,))  # true coefficients
b = random.normal(rng_key_b)  # true intercept
X = random.normal(rng_key_x, (N_OBS, N_FEATURES))
e = random.normal(rng_key_e, (N_OBS,))  # standard-normal noise
y = jnp.dot(X, w) + b + e

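## MCMC

`ImpactModel` can be used with MCMC in two ways. You can pass an `MCMC` object as the `inference` argument and call `fit_on_batch`. Alternatively, you can run `MCMC` yourself and attach its draws with `set_posterior_sample`. Both approaches are shown below.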

In [3]:
# Option 1: hand ImpactModel a NUTS-based MCMC object as its inference engine,
# then fit and predict through ImpactModel itself.
rng_key, rng_subkey = random.split(rng_key)
im = ImpactModel(
    model,
    inference=MCMC(NUTS(model), num_warmup=500, num_samples=500),
    rng_key=rng_subkey,
)
im.fit_on_batch(X, y)
im.predict_on_batch(X)

Backend: cpu, Devices: 1
Posterior sampling...


sample: 100%|██████████| 1000/1000 [00:00<00:00, 1332.75it/s, 7 steps of size 8.17e-01. acc. prob=0.83]


In [4]:
# Option 2: run MCMC outside ImpactModel and attach the posterior draws to it.
mcmc = MCMC(NUTS(model), num_warmup=1000, num_samples=1000)
rng_key, rng_subkey = random.split(rng_key)
# Consume the fresh subkey. `rng_key` is kept for future splits, so passing it
# here would reuse the same randomness later in the notebook.
mcmc.run(rng_subkey, X, y)
mcmc.print_summary()

im.set_posterior_sample(mcmc.get_samples())
im.predict_on_batch(X)

sample: 100%|██████████| 2000/2000 [00:00<00:00, 2093.19it/s, 7 steps of size 6.75e-01. acc. prob=0.86]



                mean       std    median      5.0%     95.0%     n_eff     r_hat
         b      0.41      0.03      0.41      0.36      0.46   2005.80      1.00
     sigma      1.03      0.02      1.03      1.00      1.07   1167.20      1.00
      w[0]      0.59      0.03      0.59      0.54      0.65   2714.41      1.00
      w[1]      0.87      0.03      0.87      0.82      0.92   1583.38      1.00
      w[2]     -0.91      0.04     -0.91     -0.96     -0.85   1940.84      1.00
      w[3]     -0.60      0.03     -0.60     -0.65     -0.54   1964.39      1.00
      w[4]     -1.24      0.03     -1.24     -1.29     -1.19   1349.14      1.00
      w[5]     -0.80      0.03     -0.80     -0.85     -0.75   1894.85      1.00
      w[6]     -0.51      0.03     -0.51     -0.56     -0.46   1808.16      1.00
      w[7]     -1.22      0.03     -1.22     -1.27     -1.16   1750.85      1.00
      w[8]     -0.16      0.03     -0.16     -0.22     -0.11   2075.41      1.00
      w[9]     -0.10      0

In [5]:
# Record the versions of the imported packages so results can be reproduced.
%watermark -iv

numpyro: 0.19.0
aimz   : 0.5.0
jax    : 0.7.0

