Use bayeux to access a wide range of samplers (#775)
* use bayeux to access a wide range of samplers

* use bayeux to access a wide range of samplers

* add notebook links to family table (#774)

* access methods programmatically

* clean bayeux idata to be consistent with pymc model coords

* rename alternative sampler args in tests

* change docstring to reflect bayeux sampler names

* bayeux dependencies are numpyro/jax/jaxlib/blackjax

* rename idata coords and dims to PyMC model

* add JAX based sampler dependencies

* Update code of conduct (#783)

* Update code of conduct

* update changelog

* [WIP] Fix HSGP predictions (#780)

* Delete all HSGP slices at the same time

* Make interpret consider kwargs in function calls

* Update code of conduct (#783)

* Update code of conduct

* update changelog

* Update formulae to >=0.5.3

* start a test for the hsgp and 'by'

* update changelog

* bayeux 0.1.9 updates

* bump bayeux version

* remove TFP methods, optimizers, and resolve pylint errors

* alternative backends docs

* tests for JAX based samplers except TFP

* add TFP backend example

* add TFP MCMC methods

* don't use flowmc, chees, meads for categorical model

* call model.backend.inference_methods to show list of samplers

* docstring changes

* inference_methods attribute and change JAX random seed

* Add FutureWarning to inference_method parameter

* black formatting and resolve pylint errors

* fix package name

* drop 3.9 and add 3.12 to testing matrix

* change Python versions in requires-python and target-version

* remove python 3.11 black target-version

* pin requires-python to <3.13

* pip upgrade setuptools

* Bump PyMC to 5.12

* Upgrade black and pylint

* remove upgrading of setuptools

---------

Co-authored-by: Tomás Capretto <tomicapretto@gmail.com>
GStechschulte and tomicapretto committed Mar 29, 2024
1 parent ff685b7 commit 714ccb7
Showing 16 changed files with 7,097 additions and 87 deletions.
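The commit message above amounts to a new workflow: build a Bambi model, inspect the sampler names the backend now exposes, and pass one of them to ``fit``. A minimal sketch of that workflow follows; the formula and data are hypothetical placeholders, and the exact bayeux method names depend on the installed JAX stack.

```python
import numpy as np
import pandas as pd
import bambi as bmb

# Hypothetical toy data; any Bambi-compatible dataset works here.
rng = np.random.default_rng(0)
data = pd.DataFrame({"y": rng.normal(size=100), "x": rng.normal(size=100)})

model = bmb.Model("y ~ x", data)
model.build()  # the backend (and its new `inference_methods` property) exists after build

# New in this commit: PyMC and bayeux methods can be listed programmatically.
print(model.backend.inference_methods)
# e.g. {"pymc": {"mcmc": ["mcmc"], "vi": ["vi"]}, "bayeux": {"mcmc": [...]}}

# Fit with a JAX-based sampler by name. "blackjax_nuts" replaces the old
# "nuts_blackjax" spelling, which is still accepted but emits a FutureWarning.
idata = model.fit(inference_method="blackjax_nuts", draws=1000)
```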
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
@@ -11,7 +11,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.9", "3.10", "3.11"]
python-version: ["3.10", "3.11", "3.12"]

name: Set up Python ${{ matrix.python-version }}
steps:
130 changes: 87 additions & 43 deletions bambi/backend/pymc.py
@@ -1,7 +1,9 @@
import functools
import importlib
import logging
import operator
import traceback

import warnings

from copy import deepcopy
from importlib.metadata import version
@@ -12,7 +14,6 @@
import pytensor.tensor as pt
from pytensor.tensor.special import softmax


from bambi.backend.links import cloglog, identity, inverse_squared, logit, probit, arctan_2
from bambi.backend.model_components import ConstantComponent, DistributionalComponent
from bambi.utils import get_aliased_name
@@ -46,6 +47,8 @@ def __init__(self):
self.model = None
self.spec = None
self.components = {}
self.bayeux_methods = _get_bayeux_methods()
self.pymc_methods = {"mcmc": ["mcmc"], "vi": ["vi"]}

def build(self, spec):
"""Compile the PyMC model from an abstract model specification.
@@ -94,8 +97,24 @@ def run(
):
"""Run PyMC sampler."""
inference_method = inference_method.lower()

if inference_method == "nuts_numpyro":
inference_method = "numpyro_nuts"
warnings.warn(
"'nuts_numpyro' has been replaced by 'numpyro_nuts' and will be "
"removed in a future release",
category=FutureWarning,
)
elif inference_method == "nuts_blackjax":
inference_method = "blackjax_nuts"
warnings.warn(
"'nuts_blackjax' has been replaced by 'blackjax_nuts' and will "
"be removed in a future release",
category=FutureWarning,
)

# NOTE: Methods return different types of objects (idata, approximation, and dictionary)
if inference_method in ["mcmc", "nuts_numpyro", "nuts_blackjax"]:
if inference_method in (self.pymc_methods["mcmc"] + self.bayeux_methods["mcmc"]):
result = self._run_mcmc(
draws,
tune,
@@ -110,7 +129,7 @@
inference_method,
**kwargs,
)
elif inference_method == "vi":
elif inference_method in self.pymc_methods["vi"]:
result = self._run_vi(**kwargs)
elif inference_method == "laplace":
result = self._run_laplace(draws, omit_offsets, include_mean)
@@ -169,8 +188,8 @@ def _run_mcmc(
sampler_backend="mcmc",
**kwargs,
):
with self.model:
if sampler_backend == "mcmc":
if sampler_backend in self.pymc_methods["mcmc"]:
with self.model:
try:
idata = pm.sample(
draws=draws,
@@ -205,41 +224,35 @@
)
else:
raise
elif sampler_backend == "nuts_numpyro":
import pymc.sampling_jax # pylint: disable=import-outside-toplevel

if not chains:
# sample_numpyro_nuts does not handle chains = None like pm.sample does
chains = 4
idata = pymc.sampling_jax.sample_numpyro_nuts(
draws=draws,
tune=tune,
chains=chains,
random_seed=random_seed,
**kwargs,
)
elif sampler_backend == "nuts_blackjax":
import pymc.sampling_jax # pylint: disable=import-outside-toplevel

# sample_blackjax_nuts does not handle chains = None like pm.sample does
if not chains:
chains = 4
idata = pymc.sampling_jax.sample_blackjax_nuts(
draws=draws,
tune=tune,
chains=chains,
random_seed=random_seed,
**kwargs,
)
else:
raise ValueError(
f"sampler_backend value {sampler_backend} is not valid. Please choose one of"
f"'mcmc', 'nuts_numpyro' or 'nuts_blackjax'"
)
idata = self._clean_results(idata, omit_offsets, include_mean)
idata_from = "pymc"
elif sampler_backend in self.bayeux_methods["mcmc"]:
import bayeux as bx # pylint: disable=import-outside-toplevel
import jax # pylint: disable=import-outside-toplevel

# Set the seed for reproducibility if provided
if random_seed is not None:
if not isinstance(random_seed, int):
random_seed = random_seed[0]
np.random.seed(random_seed)

jax_seed = jax.random.PRNGKey(np.random.randint(2**32 - 1))

bx_model = bx.Model.from_pymc(self.model)
bx_sampler = operator.attrgetter(sampler_backend)(
bx_model.mcmc # pylint: disable=no-member
)
idata = bx_sampler(seed=jax_seed, **kwargs)
idata_from = "bayeux"
else:
raise ValueError(
f"sampler_backend value {sampler_backend} is not valid. Please choose one of"
f" {self.pymc_methods['mcmc'] + self.bayeux_methods['mcmc']}"
)

idata = self._clean_results(idata, omit_offsets, include_mean, idata_from)
return idata

def _clean_results(self, idata, omit_offsets, include_mean):
def _clean_results(self, idata, omit_offsets, include_mean, idata_from):
for group in idata.groups():

getattr(idata, group).attrs["modeling_interface"] = "bambi"
@@ -258,6 +271,15 @@ def _clean_results(self, idata, omit_offsets, include_mean):

dims_original = list(self.model.coords)

# Identify bayeux idata and rename dims and coordinates to match PyMC model
if idata_from == "bayeux":
pymc_model_dims = [dim for dim in dims_original if "_obs" not in dim]
bayeux_dims = [
dim for dim in idata.posterior.dims if not dim.startswith(("chain", "draw"))
]
cleaned_dims = dict(zip(bayeux_dims, pymc_model_dims))
idata = idata.rename(cleaned_dims)

# Discard dims that are in the model but unused in the posterior
dims_original = [dim for dim in dims_original if dim in idata.posterior.dims]

@@ -272,7 +294,6 @@ def _clean_results(self, idata, omit_offsets, include_mean):
idata.posterior = idata.posterior.transpose(*dims_new)

# Compute the actual intercept in all distributional components that have an intercept

for pymc_component in self.distributional_components.values():
bambi_component = pymc_component.component
if (
@@ -317,8 +338,8 @@ def _run_laplace(self, draws, omit_offsets, include_mean):
Mainly for pedagogical use, provides reasonable results for approximately
Gaussian posteriors. The approximation can be very poor for some models
like hierarchical ones. Use ``mcmc``, ``nuts_numpyro``, ``nuts_blackjax``
or ``vi`` for better approximations.
like hierarchical ones. Use ``mcmc``, ``vi``, or JAX based MCMC methods
for better approximations.
Parameters
----------
@@ -352,7 +373,7 @@ def _run_laplace(self, draws, omit_offsets, include_mean):
samples = np.random.multivariate_normal(modes, cov, size=draws)

idata = _posterior_samples_to_idata(samples, self.model)
idata = self._clean_results(idata, omit_offsets, include_mean)
idata = self._clean_results(idata, omit_offsets, include_mean, idata_from="pymc")
return idata

@property
Expand All @@ -367,6 +388,10 @@ def constant_components(self):
def distributional_components(self):
return {k: v for k, v in self.components.items() if isinstance(v, DistributionalComponent)}

@property
def inference_methods(self):
return {"pymc": self.pymc_methods, "bayeux": self.bayeux_methods}


def _posterior_samples_to_idata(samples, model):
"""Create InferenceData from samples.
@@ -406,3 +431,22 @@ def _posterior_samples_to_idata(samples, model):

idata = pm.to_inference_data(pm.backends.base.MultiTrace([strace]), model=model)
return idata


def _get_bayeux_methods():
"""Gets a dictionary of usable bayeux methods if the bayeux package is installed
within the user's environment.
Returns
-------
dict
A dict where the keys are the module names and the values are the methods
available in that module.
"""
if importlib.util.find_spec("bayeux") is None:
return {"mcmc": []}

import bayeux as bx # pylint: disable=import-outside-toplevel

# Dummy log density to get access to all methods
return bx.Model(lambda x: -(x**2), 0.0).methods
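
``_get_bayeux_methods`` above relies on bayeux's own discovery mechanism: a ``bx.Model`` built around any log density reports the inference algorithms it can dispatch to, so a dummy model is enough to enumerate them. The sketch below restates that idea outside Bambi; it assumes bayeux is installed, and the printed method names will vary with the installed bayeux/JAX versions.

```python
import bayeux as bx

# A throwaway model: the log density is irrelevant, we only want the catalogue
# of inference methods, exactly as _get_bayeux_methods does above.
dummy = bx.Model(log_density=lambda x: -(x**2), test_point=0.0)

methods = dummy.methods  # dict keyed by module name, e.g. "mcmc"
print(methods["mcmc"])   # e.g. ["blackjax_nuts", "numpyro_nuts", ...]

# Bambi's backend then converts the PyMC model and resolves a sampler by name,
# mirroring bx.Model.from_pymc(...) and operator.attrgetter(...) in the diff above.
```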
1 change: 1 addition & 0 deletions bambi/data/__init__.py
@@ -1,4 +1,5 @@
"""Code for loading datasets."""

from .datasets import clear_data_home, load_data

__all__ = ["clear_data_home", "load_data"]
1 change: 1 addition & 0 deletions bambi/data/datasets.py
@@ -1,4 +1,5 @@
"""Base IO code for datasets. Heavily influenced by Arviz's (and scikit-learn's) implementation."""

import hashlib
import itertools
import os
1 change: 1 addition & 0 deletions bambi/defaults/__init__.py
@@ -1,4 +1,5 @@
"""Settings for default priors, families, etc. in Bambi."""

from bambi.defaults.utils import get_default_prior
from bambi.defaults.families import get_builtin_family

1 change: 1 addition & 0 deletions bambi/families/__init__.py
@@ -1,4 +1,5 @@
"""Classes to construct model families."""

from bambi.families.family import Family
from bambi.families.likelihood import Likelihood
from bambi.families.link import Link
1 change: 1 addition & 0 deletions bambi/interpret/utils.py
@@ -102,6 +102,7 @@ def set_default_variable_values(self) -> np.ndarray:
If categoric dtype the returned value is the unique levels of `variable'.
"""
values = None # Otherwise pylint complains
terms = get_model_terms(self.model)
# get default values for each variable in the model
for term in terms.values():
9 changes: 5 additions & 4 deletions bambi/models.py
@@ -266,9 +266,9 @@ def fit(
using the ``fit`` function.
Finally, ``"laplace"``, in which case a Laplace approximation is used and is not
recommended other than for pedagogical use.
To use the PyMC numpyro and blackjax samplers, use ``nuts_numpyro`` or ``nuts_blackjax``
respectively. Both methods will only work if you can use NUTS sampling, so your model
must be differentiable.
To get a list of JAX based inference methods, call
``model.backend.inference_methods['bayeux']``. This will return a dictionary of the
available methods such as ``blackjax_nuts``, ``numpyro_nuts``, among others.
init : str
Initialization method. Defaults to ``"auto"``. The available methods are:
* auto: Use ``"jitter+adapt_diag"`` and if this method fails it uses ``"adapt_diag"``.
@@ -306,7 +306,8 @@
Returns
-------
An ArviZ ``InferenceData`` instance if inference_method is ``"mcmc"`` (default),
"nuts_numpyro", "nuts_blackjax" or "laplace".
"laplace", or one of the MCMC methods in
``model.backend.inference_methods['bayeux']['mcmc']``.
An ``Approximation`` object if ``"vi"``.
"""
method = kwargs.pop("method", None)
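The updated ``fit`` docstring above points users at ``model.backend.inference_methods['bayeux']['mcmc']``, while the backend's ``run`` method (see the pymc.py diff) keeps the old ``nuts_numpyro`` and ``nuts_blackjax`` spellings working behind a ``FutureWarning``. A short sketch of both behaviours, reusing the hypothetical ``model`` built in the earlier example:

```python
import warnings

# JAX-based MCMC method names the docstring refers to.
print(model.backend.inference_methods["bayeux"]["mcmc"])

# The old spelling is remapped to "numpyro_nuts" and warns; behaviour is otherwise unchanged.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    idata = model.fit(inference_method="nuts_numpyro")
assert any(issubclass(w.category, FutureWarning) for w in caught)

# Per the docstring: MCMC methods and "laplace" return an ArviZ InferenceData,
# while "vi" returns a PyMC Approximation instead.
```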
1 change: 1 addition & 0 deletions bambi/priors/__init__.py
@@ -1,4 +1,5 @@
"""Classes to represent prior distributions and methods to set automatic priors"""

from .prior import Prior
from .scaler import PriorScaler

18 changes: 6 additions & 12 deletions bambi/terms/base.py
@@ -13,33 +13,27 @@ class BaseTerm(ABC):

@property
@abstractmethod
def term(self):
...
def term(self): ...

@property
@abstractmethod
def data(self):
...
def data(self): ...

@property
@abstractmethod
def name(self):
...
def name(self): ...

@property
@abstractmethod
def shape(self):
...
def shape(self): ...

@property
@abstractmethod
def levels(self):
...
def levels(self): ...

@property
@abstractmethod
def categorical(self):
...
def categorical(self): ...

@property
def alias(self):
1 change: 1 addition & 0 deletions bambi/transformations.py
@@ -175,6 +175,7 @@ def weighted(x, weights):

weighted.__metadata__ = {"kind": "weighted"}


# pylint: disable = invalid-name
@register_stateful_transform
class HSGP: # pylint: disable = too-many-instance-attributes
3 changes: 3 additions & 0 deletions docs/_quarto.yml
@@ -89,6 +89,9 @@ website:
- notebooks/plot_comparisons.ipynb
- notebooks/plot_slopes.ipynb
- notebooks/interpret_advanced_usage.ipynb
- section: Alternative sampling backends
contents:
- notebooks/alternative_samplers.ipynb

quartodoc:
style: pkgdown
(Diffs for the remaining changed files are not shown here.)
