Use bayeux to access a wide range of samplers #775

Merged
merged 34 commits into from
Mar 29, 2024
44966d3
use bayeux to access a wide range of samplers
GStechschulte Feb 4, 2024
061a1b0
use bayeux to access a wide range of samplers
GStechschulte Feb 4, 2024
8afe534
add notebook links to family table (#774)
GStechschulte Feb 4, 2024
9f1d9d1
access methods programatically
GStechschulte Feb 5, 2024
9b42fc2
clean bayeux idata to be consistent with pymc model coords
GStechschulte Feb 10, 2024
91ce2a0
rename alternative sampler args in tests
GStechschulte Feb 10, 2024
89a2aee
change docstring to reflect bayeux sampler names
GStechschulte Feb 10, 2024
d6058ad
bayeux dependencies are numpyro/jax/jaxlib/blackjax
GStechschulte Feb 10, 2024
722c8b5
rename idata coords and dims to PyMC model
GStechschulte Feb 19, 2024
ccc2877
add JAX based sampler dependencies
GStechschulte Feb 19, 2024
74b4e8b
Update code of conduct (#783)
tomicapretto Feb 21, 2024
47bb161
[WIP] Fix HSGP predictions (#780)
tomicapretto Feb 29, 2024
9f6fc2a
bayeux 0.1.9 updates
GStechschulte Mar 1, 2024
10bb508
bump bayeux version
GStechschulte Mar 1, 2024
f7bf97f
remove TFP methods, optimizers, and resolve pylint errors
GStechschulte Mar 1, 2024
1147d96
alternative backends docs
GStechschulte Mar 1, 2024
cdcf104
tests for JAX based samplers except TFP
GStechschulte Mar 1, 2024
bf1e478
add TFP backend example
GStechschulte Mar 1, 2024
27a41e6
add TFP MCMC methods
GStechschulte Mar 1, 2024
98f7da8
don't use flowmc, chees, meads for categorical model
GStechschulte Mar 3, 2024
4ae1092
call model.backend.inference_methods to show list of samplers
GStechschulte Mar 3, 2024
81936a2
docstring changes
GStechschulte Mar 3, 2024
f6d8894
inference_methods attribute and change JAX random seed
GStechschulte Mar 3, 2024
02d1df6
Add FutureWarning to inference_method parameter
GStechschulte Mar 4, 2024
dd278d4
black formatting and resolve pylint errors
GStechschulte Mar 4, 2024
b0e94a4
fix package name
GStechschulte Mar 4, 2024
65fd945
drop 3.9 and add 3.12 to testing matrix
GStechschulte Mar 19, 2024
4712f1a
change Python versions in requires-python and target-version
GStechschulte Mar 19, 2024
d508214
remove python 3.11 black target-version
GStechschulte Mar 19, 2024
1d05684
pin requires-python to <3.13
GStechschulte Mar 19, 2024
f06715e
pip upgrade setuptools
GStechschulte Mar 19, 2024
ef575d3
Bump PyMC to 5.12
tomicapretto Mar 28, 2024
9bf90a6
Upgrade black and pylint
tomicapretto Mar 28, 2024
9f9d769
remove upgrading of setup tools
GStechschulte Mar 29, 2024
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -11,8 +11,12 @@

### Maintenance and fixes

* Fix bug in predictions with models using HSGP (#780)

### Documentation

* Our Code of Conduct now includes how to send a report (#783)

### Deprecation

## 0.13.0
32 changes: 32 additions & 0 deletions CODE_OF_CONDUCT.md
@@ -1,5 +1,9 @@
# Bambi Community Code of Conduct

Bambi adopts the NumFOCUS Code of Conduct directly. In other words, we expect our community to treat others with kindness and understanding.

# The short version

Be kind to others. Do not insult or put down others.
Behave professionally. Remember that harassment and sexist, racist,
or exclusionary jokes are not appropriate.
@@ -15,3 +19,31 @@ or religion. We do not tolerate harassment of community members
in any form.

Thank you for helping make this a welcoming, friendly community for all.

# How to Submit a Report

If you feel that there has been a Code of Conduct violation, an anonymous
reporting form is available.

**If you feel your safety is in jeopardy or the situation is an
emergency, we urge you to contact local law enforcement before making
a report. (In the U.S., dial 911.)**

We are committed to promptly addressing any reported issues.
If you have experienced or witnessed behavior that violates this
Code of Conduct, please complete the form below to
make a report.

**REPORTING FORM:** https://numfocus.typeform.com/to/ynjGdT

Reports are sent to the NumFOCUS Code of Conduct Enforcement Team
(see below).

You can view the Privacy Policy and Terms of Service for TypeForm here.
The NumFOCUS Privacy Policy is here:
https://www.numfocus.org/privacy-policy

# Full Code of Conduct

The full text of the NumFOCUS/Bambi Code of Conduct can be found on
NumFOCUS's website https://numfocus.org/code-of-conduct
114 changes: 71 additions & 43 deletions bambi/backend/pymc.py
@@ -1,8 +1,9 @@
import functools
import importlib
import logging
import operator
import traceback


from copy import deepcopy
from importlib.metadata import version

@@ -12,7 +13,6 @@
import pytensor.tensor as pt
from pytensor.tensor.special import softmax


from bambi.backend.links import cloglog, identity, inverse_squared, logit, probit, arctan_2
from bambi.backend.model_components import ConstantComponent, DistributionalComponent
from bambi.utils import get_aliased_name
@@ -46,6 +46,8 @@ def __init__(self):
self.model = None
self.spec = None
self.components = {}
self.bayeux_methods = _get_bayeux_methods()
self.pymc_methods = {"mcmc": ["mcmc"], "vi": ["vi"]}

def build(self, spec):
"""Compile the PyMC model from an abstract model specification.
@@ -95,7 +97,7 @@ def run(
"""Run PyMC sampler."""
inference_method = inference_method.lower()
# NOTE: Methods return different types of objects (idata, approximation, and dictionary)
- if inference_method in ["mcmc", "nuts_numpyro", "nuts_blackjax"]:
+ if inference_method in (self.pymc_methods["mcmc"] + self.bayeux_methods["mcmc"]):
result = self._run_mcmc(
draws,
tune,
@@ -110,7 +112,7 @@ def run(
inference_method,
**kwargs,
)
- elif inference_method == "vi":
+ elif inference_method in self.pymc_methods["vi"]:
result = self._run_vi(**kwargs)
elif inference_method == "laplace":
result = self._run_laplace(draws, omit_offsets, include_mean)
@@ -169,8 +171,8 @@ def _run_mcmc(
sampler_backend="mcmc",
**kwargs,
):
- with self.model:
- if sampler_backend == "mcmc":
+ if sampler_backend in self.pymc_methods["mcmc"]:
+ with self.model:
try:
idata = pm.sample(
draws=draws,
@@ -205,41 +207,35 @@ def _run_mcmc(
)
else:
raise
- elif sampler_backend == "nuts_numpyro":
- import pymc.sampling_jax # pylint: disable=import-outside-toplevel
-
- if not chains:
- # sample_numpyro_nuts does not handle chains = None like pm.sample does
- chains = 4
- idata = pymc.sampling_jax.sample_numpyro_nuts(
- draws=draws,
- tune=tune,
- chains=chains,
- random_seed=random_seed,
- **kwargs,
- )
- elif sampler_backend == "nuts_blackjax":
- import pymc.sampling_jax # pylint: disable=import-outside-toplevel
-
- # sample_blackjax_nuts does not handle chains = None like pm.sample does
- if not chains:
- chains = 4
- idata = pymc.sampling_jax.sample_blackjax_nuts(
- draws=draws,
- tune=tune,
- chains=chains,
- random_seed=random_seed,
- **kwargs,
- )
- else:
- raise ValueError(
- f"sampler_backend value {sampler_backend} is not valid. Please choose one of"
- f"'mcmc', 'nuts_numpyro' or 'nuts_blackjax'"
- )
- idata = self._clean_results(idata, omit_offsets, include_mean)
+ idata_from = "pymc"
+ elif sampler_backend in self.bayeux_methods["mcmc"]:
+ import bayeux as bx # pylint: disable=import-outside-toplevel
+ import jax # pylint: disable=import-outside-toplevel
+
+ # Set the seed for reproducibility if provided
+ if random_seed is not None:
+ if not isinstance(random_seed, int):
+ random_seed = random_seed[0]
+ np.random.seed(random_seed)
+
+ jax_seed = jax.random.PRNGKey(np.random.randint(2**32 - 1))
+
+ bx_model = bx.Model.from_pymc(self.model)
+ bx_sampler = operator.attrgetter(sampler_backend)(
+ bx_model.mcmc # pylint: disable=no-member
+ )
+ idata = bx_sampler(seed=jax_seed, **kwargs)
+ idata_from = "bayeux"
+ else:
+ raise ValueError(
+ f"sampler_backend value {sampler_backend} is not valid. Please choose one of"
+ f" {self.pymc_methods['mcmc'] + self.bayeux_methods['mcmc']}"
+ )
+
+ idata = self._clean_results(idata, omit_offsets, include_mean, idata_from)
return idata
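
Note on the sampler lookup in `_run_mcmc`: the bayeux branch resolves the sampler by name with `operator.attrgetter` and then calls it with a seed plus the user's keyword arguments. The following is a minimal sketch of that lookup pattern only. `resolve_sampler`, the `fake_mcmc` namespace and the `num_chains` keyword are hypothetical stand-ins for illustration and are not part of bambi or bayeux.

```python
import operator
from types import SimpleNamespace


def resolve_sampler(namespace, name, allowed):
    """Return the callable called `name` on `namespace`, or raise ValueError."""
    if name not in allowed:
        raise ValueError(
            f"sampler_backend value {name} is not valid. Please choose one of {allowed}"
        )
    # attrgetter looks the attribute up by its string name
    return operator.attrgetter(name)(namespace)


# Hypothetical stand-in for `bx_model.mcmc`: fake samplers that echo their inputs
fake_mcmc = SimpleNamespace(
    blackjax_nuts=lambda seed, **kwargs: ("blackjax_nuts", seed, kwargs),
    numpyro_nuts=lambda seed, **kwargs: ("numpyro_nuts", seed, kwargs),
)
allowed = ["blackjax_nuts", "numpyro_nuts"]

sampler = resolve_sampler(fake_mcmc, "numpyro_nuts", allowed)
result = sampler(seed=0, num_chains=2)  # `num_chains` is a placeholder keyword
```

Checking membership before the lookup keeps the error message tied to the list of known methods instead of surfacing a bare `AttributeError`.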

- def _clean_results(self, idata, omit_offsets, include_mean):
+ def _clean_results(self, idata, omit_offsets, include_mean, idata_from):
for group in idata.groups():

getattr(idata, group).attrs["modeling_interface"] = "bambi"
@@ -258,6 +254,15 @@ def _clean_results(self, idata, omit_offsets, include_mean):

dims_original = list(self.model.coords)

# Identify bayeux idata and rename dims and coordinates to match PyMC model
if idata_from == "bayeux":
pymc_model_dims = [dim for dim in dims_original if "_obs" not in dim]
bayeux_dims = [
dim for dim in idata.posterior.dims if not dim.startswith(("chain", "draw"))
]
cleaned_dims = dict(zip(bayeux_dims, pymc_model_dims))
idata = idata.rename(cleaned_dims)

# Discard dims that are in the model but unused in the posterior
dims_original = [dim for dim in dims_original if dim in idata.posterior.dims]
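
Note on the dimension renaming: the bayeux branch pairs the posterior's non-sampling dimensions with the PyMC model's coordinate names by position. A minimal sketch of how that mapping is built, using plain lists instead of an `InferenceData` object; every dimension name below is made up for illustration.

```python
# Hypothetical names, chosen only to show the filtering and pairing
model_coords = ["x_dim", "group_dim", "__obs__"]
posterior_dims = ["chain", "draw", "x_dim_0", "group_dim_0"]

# Drop observation dims from the model side
pymc_model_dims = [dim for dim in model_coords if "_obs" not in dim]

# Drop the sampling dims from the posterior side
bayeux_dims = [dim for dim in posterior_dims if not dim.startswith(("chain", "draw"))]

# Pair old and new names by position
cleaned_dims = dict(zip(bayeux_dims, pymc_model_dims))
```

The pairing depends on both lists having the same order and length. `zip` silently truncates to the shorter input, so on Python 3.10 or later `zip(..., strict=True)` would be one way to surface a mismatch.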

@@ -272,7 +277,6 @@ def _clean_results(self, idata, omit_offsets, include_mean):
idata.posterior = idata.posterior.transpose(*dims_new)

# Compute the actual intercept in all distributional components that have an intercept

for pymc_component in self.distributional_components.values():
bambi_component = pymc_component.component
if (
@@ -317,8 +321,8 @@ def _run_laplace(self, draws, omit_offsets, include_mean):

Mainly for pedagogical use, provides reasonable results for approximately
Gaussian posteriors. The approximation can be very poor for some models
- like hierarchical ones. Use ``mcmc``, ``nuts_numpyro``, ``nuts_blackjax``
- or ``vi`` for better approximations.
+ like hierarchical ones. Use ``mcmc``, ``vi``, or JAX based MCMC methods
+ for better approximations.

Parameters
----------
@@ -352,7 +356,7 @@ def _run_laplace(self, draws, omit_offsets, include_mean):
samples = np.random.multivariate_normal(modes, cov, size=draws)

idata = _posterior_samples_to_idata(samples, self.model)
- idata = self._clean_results(idata, omit_offsets, include_mean)
+ idata = self._clean_results(idata, omit_offsets, include_mean, idata_from="pymc")
return idata

@property
@@ -367,6 +371,10 @@ def constant_components(self):
def distributional_components(self):
return {k: v for k, v in self.components.items() if isinstance(v, DistributionalComponent)}

@property
def inference_methods(self):
return {"pymc": self.pymc_methods, "bayeux": self.bayeux_methods}


def _posterior_samples_to_idata(samples, model):
"""Create InferenceData from samples.
@@ -406,3 +414,23 @@ def _posterior_samples_to_idata(samples, model):

idata = pm.to_inference_data(pm.backends.base.MultiTrace([strace]), model=model)
return idata


def _get_bayeux_methods():
"""Gets a dictionary of usable bayeux methods if the bayeux package is installed
within the user's environment.

Returns
-------
dict
A dict where the keys are the module names and the values are the methods
available in that module.
"""
bx_methods = {}
if importlib.util.find_spec("bayeux") is None:
return bx_methods

import bayeux as bx # pylint: disable=import-outside-toplevel

# Dummy log density to get access to all methods
return bx.Model(lambda x: -(x**2), 0.0).methods
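
Note on `_get_bayeux_methods`: it returns an empty dict when bayeux is not installed, so the backend can still be constructed without the optional dependency. A minimal sketch of that optional-dependency pattern using only the standard library; `optional_methods`, the package names and the method names are hypothetical and stand in for the bayeux-specific call.

```python
import importlib.util  # importing the submodule explicitly guarantees `importlib.util` is available


def optional_methods(package_name, loader):
    """Return loader() if `package_name` is importable, otherwise an empty dict."""
    if importlib.util.find_spec(package_name) is None:
        return {}
    return loader()


# A package name that should not exist: the loader is never called
missing = optional_methods("surely_not_installed_pkg_xyz", lambda: {"mcmc": ["never_reached"]})

# A standard-library module that is always importable: the loader runs
present = optional_methods("json", lambda: {"mcmc": ["example_method"]})
```

With bambi and bayeux both installed, the docstring in `bambi/models.py` indicates that the available names can be listed via `model.backend.inference_methods['bayeux']` and passed as `inference_method` to `fit`.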
11 changes: 10 additions & 1 deletion bambi/interpret/utils.py
@@ -236,11 +236,20 @@ def get_model_covariates(model: Model) -> np.ndarray:
for term in terms.values():
if hasattr(term, "components"):
for component in term.components:
- # if the component is a function call, use the argument names
+ # if the component is a function call, look for relevant argument names
if isinstance(component, Call):
# Add variable names passed as unnamed arguments
covariates.append(
[arg.name for arg in component.call.args if isinstance(arg, LazyVariable)]
)
# Add variable names passed as named arguments
covariates.append(
[
kwarg_value.name
for kwarg_value in component.call.kwargs.values()
if isinstance(kwarg_value, LazyVariable)
]
)
else:
covariates.append([component.name])
elif hasattr(term, "factor"):
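
Note on the `get_model_covariates` change: variable names are now collected from keyword arguments of a call as well as from positional ones. A minimal sketch of that collection logic, using small stand-in classes rather than the real `Call` and `LazyVariable` types; `FakeVariable`, `FakeCall`, `variable_names` and all argument names are hypothetical.

```python
from dataclasses import dataclass, field


@dataclass
class FakeVariable:
    """Stand-in for a lazily evaluated variable reference."""
    name: str


@dataclass
class FakeCall:
    """Stand-in for a parsed function call with positional and named arguments."""
    args: list = field(default_factory=list)
    kwargs: dict = field(default_factory=dict)


def variable_names(call):
    """Collect variable names from positional and from named arguments."""
    positional = [arg.name for arg in call.args if isinstance(arg, FakeVariable)]
    named = [value.name for value in call.kwargs.values() if isinstance(value, FakeVariable)]
    return positional, named


# One variable passed positionally, one passed by keyword, plus two literals
call = FakeCall(
    args=[FakeVariable("x"), 3],
    kwargs={"by": FakeVariable("group"), "m": 10},
)
positional, named = variable_names(call)
```

Literal arguments are skipped in both places, so only names that refer to data columns end up in the covariate list.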
9 changes: 8 additions & 1 deletion bambi/model_components.py
@@ -239,11 +239,12 @@ def predict_common(
X = np.delete(X, term_slice, axis=1)

# Add HSGP components contribution to the linear predictor
hsgp_slices = []
for term_name, term in self.hsgp_terms.items():
# Extract data for the HSGP component from the design matrix
term_slice = self.design.common.slices[term_name]
x_slice = X[:, term_slice]
- X = np.delete(X, term_slice, axis=1)
+ hsgp_slices.append(term_slice)
term_aliased_name = get_aliased_name(term)
hsgp_to_stack_dims = (f"{term_aliased_name}_weights_dim",)

@@ -288,6 +289,12 @@ def predict_common(
# Add contribution to the linear predictor
linear_predictor += hsgp_contribution

# Remove columns of X that are associated with HSGP contributions
# All the slices _must be_ deleted at the same time. Otherwise the slice objects don't
# reflect the right columns of X at the time they're used
if hsgp_slices:
X = np.delete(X, np.r_[tuple(hsgp_slices)], axis=1)

if self.common_terms or self.intercept_term:
# Create DataArray
X_terms = [get_aliased_name(term) for term in self.common_terms.values()]
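
Note on the HSGP fix: the slices are computed against the original column layout of `X`, so deleting them one at a time shifts the remaining columns and leaves later slices pointing at the wrong place. A minimal sketch with a toy matrix and made-up slice positions, comparing the two approaches.

```python
import numpy as np

# Toy design matrix with 6 columns; column j holds the values j and j + 6
X = np.arange(12).reshape(2, 6)

# Slices computed against the ORIGINAL column layout (made-up positions)
slices = [slice(1, 3), slice(4, 5)]

# One deletion per slice: after the first deletion only 4 columns remain,
# so slice(4, 5) is out of range and removes nothing
X_sequential = X
for term_slice in slices:
    X_sequential = np.delete(X_sequential, term_slice, axis=1)

# One deletion for all slices: np.r_ expands them into the index array [1, 2, 4],
# which is interpreted against the original layout
X_once = np.delete(X, np.r_[tuple(slices)], axis=1)
```

Here the sequential version keeps four columns while the single call keeps the intended three, which is the mismatch the comment in the diff warns about.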
9 changes: 5 additions & 4 deletions bambi/models.py
@@ -266,9 +266,9 @@ def fit(
using the ``fit`` function.
Finally, ``"laplace"``, in which case a Laplace approximation is used and is not
recommended other than for pedagogical use.
- To use the PyMC numpyro and blackjax samplers, use ``nuts_numpyro`` or ``nuts_blackjax``
- respectively. Both methods will only work if you can use NUTS sampling, so your model
- must be differentiable.
+ To get a list of JAX based inference methods, call
+ ``model.backend.inference_methods['bayeux']``. This will return a dictionary of the
+ available methods such as ``blackjax_nuts``, ``numpyro_nuts``, among others.
init : str
Initialization method. Defaults to ``"auto"``. The available methods are:
* auto: Use ``"jitter+adapt_diag"`` and if this method fails it uses ``"adapt_diag"``.
@@ -306,7 +306,8 @@ def fit(
Returns
-------
An ArviZ ``InferenceData`` instance if inference_method is ``"mcmc"`` (default),
- "nuts_numpyro", "nuts_blackjax" or "laplace".
+ "laplace", or one of the MCMC methods in
+ ``model.backend.inference_methods['bayeux']['mcmc']``.
An ``Approximation`` object if ``"vi"``.
"""
method = kwargs.pop("method", None)
3 changes: 3 additions & 0 deletions docs/_quarto.yml
@@ -89,6 +89,9 @@ website:
- notebooks/plot_comparisons.ipynb
- notebooks/plot_slopes.ipynb
- notebooks/interpret_advanced_usage.ipynb
- section: Alternative sampling backends
contents:
- notebooks/alternative_samplers.ipynb

quartodoc:
style: pkgdown