# Alternative sampling backends

In Bambi, the sampler is selected automatically based on the types of variables in the model. For inference, Bambi supports both MCMC and variational inference. By default, Bambi uses PyMC's implementation of the adaptive Hamiltonian Monte Carlo (HMC) algorithm known as the No-U-Turn Sampler (NUTS). This sampler is a good choice for many models. However, it is not the only sampling method, nor is PyMC the only library implementing NUTS. 
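To give a sense of what these samplers automate, here is a minimal sketch of plain HMC on a one-dimensional standard normal target, written with NumPy only. It is not NUTS and not how PyMC or any other backend implements sampling: the step size and number of leapfrog steps are fixed by hand here, whereas NUTS tunes the step size during warm-up and chooses the trajectory length automatically.

```python
import numpy as np


def log_prob(q):
    # Standard normal target: log p(q) = -q^2 / 2, up to an additive constant
    return -0.5 * q**2


def grad_log_prob(q):
    # Gradient of the log density above
    return -q


def hmc(num_draws, step_size=0.2, num_steps=20, seed=0):
    """Plain (non-adaptive) HMC with a fixed step size and trajectory length."""
    rng = np.random.default_rng(seed)
    q = 0.0
    draws = np.empty(num_draws)
    for i in range(num_draws):
        p0 = rng.normal()  # resample the momentum from a standard normal
        q_new, p = q, p0
        # Leapfrog integration: half momentum step, alternating full steps, half step
        p = p + 0.5 * step_size * grad_log_prob(q_new)
        for _ in range(num_steps - 1):
            q_new = q_new + step_size * p
            p = p + step_size * grad_log_prob(q_new)
        q_new = q_new + step_size * p
        p = p + 0.5 * step_size * grad_log_prob(q_new)
        # Metropolis correction on the Hamiltonian H(q, p) = -log p(q) + p^2 / 2
        current_h = -log_prob(q) + 0.5 * p0**2
        proposed_h = -log_prob(q_new) + 0.5 * p**2
        if np.log(rng.uniform()) < current_h - proposed_h:
            q = q_new
        draws[i] = q
    return draws


draws = hmc(num_draws=5000, seed=1)
print(f"mean: {draws.mean():.3f}, std: {draws.std():.3f}")
```

Picking `step_size` and `num_steps` by hand is exactly the tuning burden that adaptive samplers such as NUTS remove.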

Beyond PyMC, Bambi supports multiple backends for MCMC sampling, such as NumPyro and Blackjax. This notebook covers how to use these alternatives in Bambi.

_Note_: Bambi uses [bayeux](https://github.com/jax-ml/bayeux) to access a variety of sampling backends. To use these backends, you need to install the optional dependencies listed in Bambi's [pyproject.toml](https://github.com/bambinos/bambi/blob/main/pyproject.toml) file.

In [1]:
import arviz as az
import bambi as bmb
import bayeux as bx
import numpy as np
import pandas as pd

## Bayeux

Bambi leverages `bayeux` to access different sampling backends. In short, `bayeux` lets you write a probabilistic model in JAX and immediately have access to state-of-the-art inference methods. 

Because every Bambi model is built on top of a PyMC model, Bambi can hand that PyMC model to `bayeux`. We can then choose from a variety of MCMC methods to perform inference. 
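Conceptually, `bayeux` works from a log-density function plus a test point (its README builds models as `bx.Model(log_density=..., test_point=...)`). The sketch below writes such a log density by hand for a simple linear regression, `y ~ Normal(a + b*x, sigma)`, so you can see what a sampler actually evaluates. It is an illustration under stated assumptions, not what Bambi generates: the flat priors on `a`, `b`, and `log(sigma)` are assumptions, and it uses NumPy so it runs without JAX (swapping `numpy` for `jax.numpy` would make it differentiable).

```python
import numpy as np


def make_log_density(x, y):
    """Unnormalized log posterior for y ~ Normal(a + b * x, sigma).

    Illustrative assumptions: flat priors on a, b, and log(sigma). sigma is
    parameterized on the log scale so every parameter is unconstrained, which
    is the kind of space gradient-based samplers work in.
    """

    def log_density(params):
        a, b, log_sigma = params["a"], params["b"], params["log_sigma"]
        sigma = np.exp(log_sigma)
        resid = y - (a + b * x)
        # Gaussian log-likelihood, dropping the constant -n/2 * log(2 * pi)
        return -len(y) * log_sigma - 0.5 * np.sum(resid**2) / sigma**2

    return log_density


# Usage on a small synthetic dataset
rng = np.random.default_rng(0)
x = rng.normal(size=50)
y = 0.5 * x + rng.normal(size=50)
log_density = make_log_density(x, y)
test_point = {"a": 0.0, "b": 0.0, "log_sigma": 0.0}
print(log_density(test_point))
```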

To demonstrate the available backends, we will first simulate data and build a model.

In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
num_samples = 100
num_features = 1
noise_std = 1.0
random_seed = 42

np.random.seed(random_seed)

coefficients = np.random.randn(num_features)
X = np.random.randn(num_samples, num_features)
error = np.random.normal(scale=noise_std, size=num_samples)
y = X @ coefficients + error

data = pd.DataFrame({"y": y, "x": X.flatten()})
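As an optional sanity check before sampling, an ordinary least-squares fit with NumPy gives point estimates for the same `y ~ x` model. The block below re-creates the simulated data with the same seed and draw order so it runs on its own. With weakly informative priors and 100 observations, the posterior means reported in the comparison section should land close to these OLS estimates, which need not match the true coefficient exactly because they reflect the noise in this particular sample.

```python
import numpy as np

# Re-create the simulated data exactly as above (same seed, same draw order)
np.random.seed(42)
coefficients = np.random.randn(1)
X = np.random.randn(100, 1)
error = np.random.normal(scale=1.0, size=100)
y = X @ coefficients + error

# Ordinary least squares with an explicit intercept column
design = np.column_stack([np.ones(len(y)), X[:, 0]])
beta_hat, *_ = np.linalg.lstsq(design, y, rcond=None)
intercept_hat, slope_hat = beta_hat

# Residual standard deviation with n - 2 degrees of freedom
resid = y - design @ beta_hat
sigma_hat = np.sqrt(resid @ resid / (len(y) - 2))

print(f"true slope: {coefficients[0]:.3f}")
print(f"OLS intercept: {intercept_hat:.3f}, slope: {slope_hat:.3f}, sigma: {sigma_hat:.3f}")
```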

In [4]:
model = bmb.Model("y ~ x", data)
model.build()

We can access `bmb.inference_methods.names`, which returns a nested dictionary of the backends and their lists of inference methods.

In [5]:
methods = bmb.inference_methods.names
methods

{'pymc': {'mcmc': ['mcmc'], 'vi': ['vi']},
 'bayeux': {'mcmc': ['tfp_hmc',
   'tfp_nuts',
   'tfp_snaper_hmc',
   'blackjax_hmc',
   'blackjax_chees_hmc',
   'blackjax_meads_hmc',
   'blackjax_nuts',
   'blackjax_hmc_pathfinder',
   'blackjax_nuts_pathfinder',
   'flowmc_rqspline_hmc',
   'flowmc_rqspline_mala',
   'flowmc_realnvp_hmc',
   'flowmc_realnvp_mala',
   'numpyro_hmc',
   'numpyro_nuts',
   'nutpie']}}

With the PyMC backend, we have access to its implementation of the NUTS sampler and to mean-field variational inference.

In [6]:
methods["pymc"]

{'mcmc': ['mcmc'], 'vi': ['vi']}

`bayeux` gives us access to the TensorFlow Probability, Blackjax, flowMC, and NumPyro backends, as well as nutpie.

In [7]:
methods["bayeux"]

{'mcmc': ['tfp_hmc',
  'tfp_nuts',
  'tfp_snaper_hmc',
  'blackjax_hmc',
  'blackjax_chees_hmc',
  'blackjax_meads_hmc',
  'blackjax_nuts',
  'blackjax_hmc_pathfinder',
  'blackjax_nuts_pathfinder',
  'flowmc_rqspline_hmc',
  'flowmc_rqspline_mala',
  'flowmc_realnvp_hmc',
  'flowmc_realnvp_mala',
  'numpyro_hmc',
  'numpyro_nuts',
  'nutpie']}

The values under the `mcmc` and `vi` keys are the names you pass to the `inference_method` argument of `model.fit`. This is shown in the section below.
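Because the method names are the leaves of this nested dictionary, a few lines of plain Python can flatten it or look up which backend a name belongs to. The dictionary below is copied from the output above, and `find_method` is a hypothetical helper for illustration, not part of Bambi; in a live session you could pass `bmb.inference_methods.names` instead.

```python
# Copy of the dictionary printed by bmb.inference_methods.names above
methods = {
    "pymc": {"mcmc": ["mcmc"], "vi": ["vi"]},
    "bayeux": {
        "mcmc": [
            "tfp_hmc", "tfp_nuts", "tfp_snaper_hmc",
            "blackjax_hmc", "blackjax_chees_hmc", "blackjax_meads_hmc",
            "blackjax_nuts", "blackjax_hmc_pathfinder", "blackjax_nuts_pathfinder",
            "flowmc_rqspline_hmc", "flowmc_rqspline_mala",
            "flowmc_realnvp_hmc", "flowmc_realnvp_mala",
            "numpyro_hmc", "numpyro_nuts", "nutpie",
        ]
    },
}


def find_method(methods, name):
    """Hypothetical helper: return (backend, kind) for an inference_method name."""
    for backend, kinds in methods.items():
        for kind, names in kinds.items():
            if name in names:
                return backend, kind
    raise ValueError(f"unknown inference_method: {name!r}")


# Every valid value for inference_method, flattened into a single list
all_names = [name for kinds in methods.values() for names in kinds.values() for name in names]

print(find_method(methods, "numpyro_nuts"))
print(len(all_names))
```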

## Specifying an `inference_method`

By default, Bambi uses the PyMC NUTS implementation. To use a different backend, pass the name of the `bayeux` MCMC method to the `inference_method` parameter of the `fit` method.

### Blackjax

In [8]:
blackjax_nuts_idata = model.fit(inference_method="blackjax_nuts")
blackjax_nuts_idata

Different backends have different naming conventions for the parameters specific to that MCMC method. Thus, to specify backend-specific parameters, pass your own `kwargs` to the `fit` method.

To identify the kwargs specific to each method, call `bmb.inference_methods.get_kwargs` with the method name.

In [9]:
bmb.inference_methods.get_kwargs("blackjax_nuts")

{<function blackjax.adaptation.window_adaptation.window_adaptation(algorithm, logdensity_fn: Callable, is_mass_matrix_diagonal: bool = True, initial_step_size: float = 1.0, target_acceptance_rate: float = 0.8, progress_bar: bool = False, adaptation_info_fn: Callable = <function return_all_adapt_info at 0x14f0cf6a0>, integrator=<function generate_euclidean_integrator.<locals>.euclidean_integrator at 0x14f096e80>, **extra_parameters) -> blackjax.base.AdaptationAlgorithm>: {'logdensity_fn': <function bayeux._src.shared.constrain.<locals>.wrap_log_density.<locals>.wrapped(args)>,
  'is_mass_matrix_diagonal': True,
  'initial_step_size': 1.0,
  'target_acceptance_rate': 0.8,
  'progress_bar': False,
  'adaptation_info_fn': <function blackjax.adaptation.base.return_all_adapt_info(state, info, adaptation_state)>,
  'algorithm': GenerateSamplingAPI(differentiable=<function as_top_level_api at 0x14f0ccea0>, init=<function init at 0x14f095bc0>, build_kernel=<function build_kernel at 0x14f01ae80>

Now, we can identify the kwargs we would like to change and pass to the `fit` method.

In [10]:
kwargs = {
    "adapt.run": {"num_steps": 500},
    "num_chains": 4,
    "num_draws": 250,
    "num_adapt_draws": 250,
}

blackjax_nuts_idata = model.fit(inference_method="blackjax_nuts", **kwargs)
blackjax_nuts_idata

### TensorFlow Probability

In [11]:
tfp_nuts_idata = model.fit(inference_method="tfp_nuts")
tfp_nuts_idata

### NumPyro

In [12]:
numpyro_nuts_idata = model.fit(inference_method="numpyro_nuts")
numpyro_nuts_idata

sample: 100%|██████████| 1500/1500 [00:02<00:00, 667.23it/s] 


### flowMC

In [13]:
flowmc_idata = model.fit(inference_method="flowmc_realnvp_hmc")
flowmc_idata

['n_dim', 'n_chains', 'n_local_steps', 'n_global_steps', 'n_loop', 'output_thinning', 'verbose']


Global Tuning: 100%|██████████| 5/5 [00:10<00:00,  2.15s/it]
Global Sampling: 100%|██████████| 5/5 [00:00<00:00, 62.57it/s]


## Sampler comparisons

With ArviZ, we can compare the summaries of the inference results from each sampler. _Note:_ We can't use `az.compare` here, because not every InferenceData object contains the pointwise log-likelihood values it requires; calling it would raise an error.

In [14]:
az.summary(blackjax_nuts_idata)

| | mean | sd | hdi_3% | hdi_97% | mcse_mean | mcse_sd | ess_bulk | ess_tail | r_hat |
|---|---|---|---|---|---|---|---|---|---|
| Intercept | 0.02 | 0.096 | -0.171 | 0.189 | 0.003 | 0.003 | 774.0 | 667.0 | 1.01 |
| sigma | 0.946 | 0.068 | 0.826 | 1.073 | 0.002 | 0.002 | 984.0 | 766.0 | 1.01 |
| x | 0.357 | 0.108 | 0.151 | 0.548 | 0.003 | 0.002 | 1072.0 | 781.0 | 1.01 |


In [15]:
az.summary(tfp_nuts_idata)

| | mean | sd | hdi_3% | hdi_97% | mcse_mean | mcse_sd | ess_bulk | ess_tail | r_hat |
|---|---|---|---|---|---|---|---|---|---|
| Intercept | 0.023 | 0.096 | -0.155 | 0.205 | 0.001 | 0.001 | 6927.0 | 5296.0 | 1.0 |
| sigma | 0.948 | 0.069 | 0.826 | 1.082 | 0.001 | 0.001 | 7051.0 | 6070.0 | 1.0 |
| x | 0.36 | 0.106 | 0.16 | 0.556 | 0.001 | 0.001 | 6749.0 | 5613.0 | 1.0 |


In [16]:
az.summary(numpyro_nuts_idata)

| | mean | sd | hdi_3% | hdi_97% | mcse_mean | mcse_sd | ess_bulk | ess_tail | r_hat |
|---|---|---|---|---|---|---|---|---|---|
| Intercept | 0.023 | 0.096 | -0.155 | 0.204 | 0.001 | 0.001 | 7083.0 | 5374.0 | 1.0 |
| sigma | 0.948 | 0.068 | 0.82 | 1.075 | 0.001 | 0.001 | 7256.0 | 6005.0 | 1.0 |
| x | 0.361 | 0.106 | 0.158 | 0.562 | 0.001 | 0.001 | 6932.0 | 5546.0 | 1.0 |


In [17]:
az.summary(flowmc_idata)

| | mean | sd | hdi_3% | hdi_97% | mcse_mean | mcse_sd | ess_bulk | ess_tail | r_hat |
|---|---|---|---|---|---|---|---|---|---|
| Intercept | 0.024 | 0.096 | -0.149 | 0.207 | 0.003 | 0.002 | 876.0 | 615.0 | 1.02 |
| sigma | 0.947 | 0.067 | 0.822 | 1.066 | 0.001 | 0.001 | 5554.0 | 5920.0 | 1.0 |
| x | 0.361 | 0.104 | 0.161 | 0.55 | 0.001 | 0.001 | 5081.0 | 4653.0 | 1.0 |
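Since `az.compare` is not an option here, one lightweight alternative is to line up the summaries side by side. The sketch below hard-codes the posterior means and `r_hat` values from the four tables above; in a live session, `az.summary(idata)[["mean", "r_hat"]]` returns the same columns for each InferenceData object. The 1.01 cutoff is the commonly used `r_hat` threshold, applied here as an assumption rather than a Bambi default.

```python
import pandas as pd

params = ["Intercept", "sigma", "x"]

# Posterior means copied from the az.summary tables above
means = pd.DataFrame(
    {
        "blackjax_nuts": [0.020, 0.946, 0.357],
        "tfp_nuts": [0.023, 0.948, 0.360],
        "numpyro_nuts": [0.023, 0.948, 0.361],
        "flowmc_realnvp_hmc": [0.024, 0.947, 0.361],
    },
    index=params,
)

# r_hat values copied from the same tables
r_hat = pd.DataFrame(
    {
        "blackjax_nuts": [1.01, 1.01, 1.01],
        "tfp_nuts": [1.0, 1.0, 1.0],
        "numpyro_nuts": [1.0, 1.0, 1.0],
        "flowmc_realnvp_hmc": [1.02, 1.0, 1.0],
    },
    index=params,
)

# Largest disagreement in posterior means across samplers, per parameter
spread = means.max(axis=1) - means.min(axis=1)

# (sampler, parameter) pairs whose r_hat exceeds 1.01
flagged = [
    (sampler, param)
    for sampler in r_hat.columns
    for param in r_hat.index
    if r_hat.loc[param, sampler] > 1.01
]

print(spread)
print(flagged)
```

All four samplers agree on the posterior means to within a few thousandths, while the flowMC run is the one worth a second look for the intercept, which also has a noticeably lower bulk ESS than the TensorFlow Probability and NumPyro runs.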


## Summary

Thanks to `bayeux`, we can use several additional sampling backends (TensorFlow Probability, Blackjax, flowMC, NumPyro, and nutpie) and more than ten alternative MCMC methods in Bambi. Using any of them is as simple as passing the method name to the `inference_method` argument of the `fit` method.

In [18]:
%load_ext watermark
%watermark -n -u -v -iv -w

Last updated: Tue Dec 10 2024

Python implementation: CPython
Python version       : 3.12.8
IPython version      : 8.30.0

arviz : 0.20.0
numpy : 1.26.4
bayeux: 0.1.14
bambi : 0.14.1.dev10+g46d5572b.d20241109
pandas: 2.2.3

Watermark: 2.5.0

