# MCMC Support

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/markean/aimz/blob/main/docs/notebooks/mcmc.ipynb)

While aimz is primarily designed around variational inference and predictive sampling, it also supports MCMC methods through the [NumPyro backend](https://num.pyro.ai/en/stable/mcmc.html#numpyro.infer.mcmc.MCMC), using the same aimz interface (e.g., `.fit_on_batch()` and `.predict_on_batch()`). This lets users apply MCMC to more complex models in settings where variational inference may be less effective and the dataset is relatively small.

In [1]:
import jax.numpy as jnp
import numpyro.distributions as dist
from jax import random
from jax.typing import ArrayLike
from numpyro import plate, sample
from numpyro.infer import MCMC, NUTS

from aimz import ImpactModel

%load_ext watermark

## Model and Data

We set up a linear regression model and create synthetic data for both features and targets as an example.

In [2]:
def model(X: ArrayLike, y: ArrayLike | None = None) -> None:
    """Linear regression model."""
    w = sample("w", dist.Normal().expand((X.shape[1],)))
    b = sample("b", dist.Normal())
    mu = jnp.dot(X, w) + b
    sigma = sample("sigma", dist.Exponential())
    with plate("data", size=X.shape[0]):
        sample("y", dist.Normal(mu, sigma), obs=y)


rng_key = random.key(42)
rng_key, rng_key_w, rng_key_b, rng_key_x, rng_key_e = random.split(rng_key, 5)
w = random.normal(rng_key_w, (10,))
b = random.normal(rng_key_b)
X = random.normal(rng_key_x, (1000, 10))
e = random.normal(rng_key_e, (1000,))
y = jnp.dot(X, w) + b + e

## MCMC Sampling and Prediction

MCMC sampling can be performed with the `ImpactModel` class by passing an `MCMC` instance as the `inference` argument. Users can configure the sampler, the number of warm-up steps, and other MCMC-specific parameters. Calling `.fit_on_batch()` starts the sampling process. Internally, aimz executes the sampler via the `.run()` method and stores the posterior samples obtained from `.get_samples()`.

Note that calling `.fit()` with `MCMC` as the inference method raises a `TypeError`, because `.fit()` is intended for mini-batch training or subsampling. Regardless of the number of chains (`num_chains`), the posterior samples are combined across chains to ensure compatibility with the rest of the aimz interface. Posterior predictive sampling can then be performed with the `.predict()` or `.predict_on_batch()` methods.
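
To make "combined across chains" concrete, the following is a minimal sketch of the array bookkeeping involved, not aimz's actual implementation. It uses plain NumPy and made-up draws to stay dependency-light; the shapes and the chain-major ordering are assumptions for illustration. In NumPyro itself, `mcmc.get_samples()` returns draws with the chain axis already merged, while `mcmc.get_samples(group_by_chain=True)` keeps a leading chain axis.

```python
import numpy as np

# Illustrative sizes only; they are not taken from the run shown on this page.
num_chains, num_samples, num_features = 4, 500, 10

rng = np.random.default_rng(0)
# Hypothetical per-chain draws for the weight vector "w": (chain, draw, feature).
w_by_chain = rng.normal(size=(num_chains, num_samples, num_features))

# Merge the chain and draw axes into a single leading sample axis.
w_combined = w_by_chain.reshape(num_chains * num_samples, num_features)

print(w_combined.shape)  # (2000, 10)

# With this ordering, chain 1's first draw directly follows chain 0's last draw.
assert np.array_equal(w_combined[num_samples], w_by_chain[1, 0])
```

After merging, downstream code only sees a single sample axis, which is why the number of chains does not affect how predictions are requested.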

In [3]:
rng_key, rng_subkey = random.split(rng_key)
im = ImpactModel(
    model,
    rng_key=rng_subkey,
    inference=MCMC(NUTS(model), num_warmup=500, num_samples=500),
)
im.fit_on_batch(X, y)
im.inference.print_summary()
im.predict_on_batch(X)

Backend: cpu, Devices: 1
Posterior sampling...


sample: 100%|██████████| 1000/1000 [00:00<00:00, 1406.16it/s, 7 steps of size 8.17e-01. acc. prob=0.83]



                mean       std    median      5.0%     95.0%     n_eff     r_hat
         b      0.41      0.03      0.41      0.36      0.46    671.95      1.00
     sigma      1.04      0.02      1.04      1.00      1.07    873.28      1.00
      w[0]      0.59      0.03      0.59      0.53      0.65    926.56      1.00
      w[1]      0.86      0.03      0.86      0.82      0.93   1032.70      1.00
      w[2]     -0.90      0.03     -0.90     -0.95     -0.85    601.99      1.00
      w[3]     -0.60      0.04     -0.60     -0.66     -0.55    786.20      1.00
      w[4]     -1.24      0.03     -1.24     -1.29     -1.18   1111.40      1.00
      w[5]     -0.80      0.03     -0.80     -0.85     -0.75    765.69      1.00
      w[6]     -0.51      0.03     -0.51     -0.56     -0.45    650.49      1.00
      w[7]     -1.22      0.03     -1.22     -1.27     -1.16    663.89      1.00
      w[8]     -0.17      0.03     -0.17     -0.21     -0.11    508.14      1.00
      w[9]     -0.10      0

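The `r_hat` column in the summary above is a convergence diagnostic: values close to 1.00 suggest that different parts of the sampling run agree with each other. As a rough illustration, here is a minimal NumPy sketch of a split-chain Gelman-Rubin statistic. The function `split_r_hat` is a hypothetical helper written for this example, and it may differ in detail from what NumPyro computes for its summary table.

```python
import numpy as np


def split_r_hat(draws: np.ndarray) -> float:
    """Split-chain Gelman-Rubin statistic for draws of shape (chains, samples)."""
    _, num_samples = draws.shape
    half = num_samples // 2
    # Treat the first and second half of every chain as separate chains.
    halves = np.concatenate([draws[:, :half], draws[:, half : 2 * half]], axis=0)
    # Average within-chain variance and variance of the chain means.
    within = halves.var(axis=1, ddof=1).mean()
    between = halves.mean(axis=1).var(ddof=1)
    # Pooled variance estimate, compared against the within-chain variance.
    pooled = within * (half - 1) / half + between
    return float(np.sqrt(pooled / within))


rng = np.random.default_rng(0)
# A single well-mixed chain of independent draws (made-up data).
mixed = rng.normal(size=(1, 1000))
# The same chain with a slow upward drift, mimicking a sampler that has not converged.
drifting = mixed + np.linspace(0.0, 5.0, 1000)

print(split_r_hat(mixed), split_r_hat(drifting))
```

The well-mixed chain yields a value near 1, while the drifting chain yields a clearly larger value because its two halves have different means.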
## Using External MCMC Samples

Users can run MCMC sampling directly using NumPyro and then insert the posterior samples into an `ImpactModel` instance using the `.set_posterior_sample()` method for downstream analysis.
For example:

In [4]:
mcmc = MCMC(NUTS(model), num_warmup=1000, num_samples=1000)
rng_key, rng_subkey = random.split(rng_key)
mcmc.run(rng_subkey, X, y)

im.set_posterior_sample(mcmc.get_samples())
im.predict_on_batch(X)

sample: 100%|██████████| 2000/2000 [00:00<00:00, 2155.09it/s, 7 steps of size 6.75e-01. acc. prob=0.86]
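
For intuition about what posterior predictive sampling does with a set of posterior draws, the following is a minimal NumPy sketch for the linear regression model defined earlier. It is not aimz's implementation and says nothing about the type returned by `.predict_on_batch()`; the posterior draws here are random placeholders standing in for `mcmc.get_samples()`, and all sizes are assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
# Illustrative sizes only.
num_draws, num_obs, num_features = 200, 50, 10

# Hypothetical posterior draws keyed by the model's sample site names.
posterior = {
    "w": rng.normal(size=(num_draws, num_features)),
    "b": rng.normal(size=(num_draws,)),
    "sigma": rng.exponential(size=(num_draws,)),
}
X_new = rng.normal(size=(num_obs, num_features))

# Mean for every (draw, observation) pair: shape (num_draws, num_obs).
mu = posterior["w"] @ X_new.T + posterior["b"][:, None]

# One simulated target per pair, using that draw's noise scale.
y_pred = rng.normal(loc=mu, scale=posterior["sigma"][:, None])

print(y_pred.shape)  # (200, 50)
```

Each row of `y_pred` is one simulated dataset under a single posterior draw, so summarizing over the first axis (for example with a mean or quantiles) gives per-observation predictive summaries.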


In [5]:
%watermark -iv

aimz   : 0.7.0
numpyro: 0.19.0
jax    : 0.7.2

