bamcmc: a Bayesian MCMC sampling package with coupled A/B chains and nested R-hat diagnostics.
- Coupled A/B Sampling: Chains are split into two groups (A and B), and each group's proposal distribution is informed by the other group's current state
- Nested R-hat Diagnostics: Supports a superchain/subchain structure for the nested R-hat convergence diagnostic, which is designed for running many short chains (Margossian et al., 2022)
- GPU Acceleration: Built on JAX for efficient GPU-based sampling
- Flexible Proposal System: 13 proposal types for different sampling scenarios
- COUPLED_TRANSFORM Sampler: Theta-preserving updates for Non-Centered Parameterization (NCP) in hierarchical models
- Registry Pattern: Easy registration of custom posterior models
- Multi-run Sampling: Automatic checkpoint management with reset/resume schedules
- Cross-session Caching: The JAX compilation cache persists across Python sessions
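
Nested R-hat compares the variance B between superchain means with a within-superchain variance W that pools two parts: the spread of subchain means inside each superchain and the spread of draws inside each subchain, giving nested R-hat = sqrt(1 + B / W). The following is a minimal NumPy sketch of that definition as we read it from Margossian et al. (2022), for one scalar quantity. The function name `nested_rhat` and the (superchains, subchains, iterations) array layout are our own assumptions, not bamcmc's API; the package itself computes its diagnostics in JAX.

```python
import numpy as np


def nested_rhat(draws: np.ndarray) -> float:
    """Nested R-hat for one scalar quantity (illustrative sketch).

    `draws` has shape (K, M, N): K superchains, M subchains per
    superchain, N iterations per subchain. Needs K >= 2.
    """
    K, M, N = draws.shape
    if K < 2:
        raise ValueError("nested R-hat needs at least two superchains")

    sub_means = draws.mean(axis=2)        # (K, M) subchain means
    super_means = sub_means.mean(axis=1)  # (K,) superchain means

    # B: variance of the superchain means around the grand mean.
    b = super_means.var(ddof=1)

    # W: average over superchains of (between-subchain variance +
    # mean within-subchain variance); each part is 0 when M == 1 or N == 1.
    between_sub = sub_means.var(axis=1, ddof=1) if M > 1 else np.zeros(K)
    within_sub = draws.var(axis=2, ddof=1).mean(axis=1) if N > 1 else np.zeros(K)
    w = np.mean(between_sub + within_sub)

    return float(np.sqrt(1.0 + b / w))


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    mixed = rng.normal(size=(4, 8, 100))                # superchains agree
    stuck = mixed + 3.0 * np.arange(4)[:, None, None]   # superchains offset
    print(nested_rhat(mixed))  # close to 1
    print(nested_rhat(stuck))  # far above 1
```

Because B only measures disagreement between superchains, many short subchains can still give a value near 1 as long as every superchain has forgotten its starting point.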

| Proposal type | Description |
|---|---|
| SELF_MEAN | Random walk centered on current state |
| CHAIN_MEAN | Independent proposal centered on population mean |
| MIXTURE | Probabilistic mix of SELF_MEAN and CHAIN_MEAN |
| MULTINOMIAL | For discrete parameters on integer grid |
| MALA | Metropolis-adjusted Langevin (gradient-based) |
| MEAN_MALA | MALA with gradient at coupled mean |
| MEAN_WEIGHTED | Adaptive interpolation based on Mahalanobis distance |
| MODE_WEIGHTED | Interpolation toward mode (highest log-posterior chain) |
| MCOV_WEIGHTED | Mean-covariance weighted with configurable blend |
| MCOV_WEIGHTED_VEC | Vectorized per-parameter distance and interpolation |
| MCOV_SMOOTH | Smooth three-zone transition: chain_mean at equilibrium, tracking when far |
| MCOV_MODE | Mode-targeting with scalar Mahalanobis distance scaling |
| MCOV_MODE_VEC | Mode-targeting with per-parameter distance and interpolation |
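
The coupling behind a CHAIN_MEAN-style proposal can be illustrated with a small NumPy sketch. This is our own illustration under stated assumptions, not bamcmc's implementation: each chain in group A proposes independently from a Gaussian whose mean and inflated covariance are computed from group B's current state, then the roles swap. Because B is held fixed while A moves, each half-sweep is an ordinary Metropolis-Hastings step with a fixed independence proposal, and the Hastings correction log q(x) - log q(x') keeps the target invariant. The names `independence_update` and `coupled_sweep`, the inflation factor 1.5 and the jitter term are hypothetical choices.

```python
import numpy as np


def independence_update(x, other, log_post, rng, inflate=1.5):
    """One MH step for every chain in `x` (shape (n, d)), using a Gaussian
    independence proposal fitted to `other`'s current state (sketch)."""
    n, d = x.shape
    mu = other.mean(axis=0)
    # Inflated covariance of the other group plus a small jitter (assumptions).
    cov = inflate * np.atleast_2d(np.cov(other, rowvar=False)) + 1e-9 * np.eye(d)
    chol = np.linalg.cholesky(cov)
    prec = np.linalg.inv(cov)

    def log_q(y):
        # Gaussian proposal log-density up to a constant; the constant
        # cancels in the Hastings ratio.
        r = y - mu
        return -0.5 * np.einsum("ij,jk,ik->i", r, prec, r)

    prop = mu + rng.standard_normal((n, d)) @ chol.T
    # Independence proposal, so the correction q(x) / q(x') is required.
    log_alpha = log_post(prop) - log_post(x) + log_q(x) - log_q(prop)
    accept = np.log(rng.uniform(size=n)) < log_alpha
    return np.where(accept[:, None], prop, x), accept


def coupled_sweep(a, b, log_post, rng):
    """Update group A with a proposal built from B, then B from the new A."""
    a, _ = independence_update(a, b, log_post, rng)
    b, _ = independence_update(b, a, log_post, rng)
    return a, b


if __name__ == "__main__":
    rng = np.random.default_rng(1)
    log_post = lambda x: -0.5 * np.sum(x**2, axis=1)  # standard normal target
    a = rng.normal(5.0, 3.0, size=(500, 2))  # deliberately poor start
    b = rng.normal(5.0, 3.0, size=(500, 2))
    for _ in range(100):
        a, b = coupled_sweep(a, b, log_post, rng)
    draws = np.vstack([a, b])
    print(draws.mean(axis=0), draws.std(axis=0))  # roughly [0, 0] and [1, 1]
```

Inflating the covariance keeps the proposal's tails heavier than the target's once the groups have converged, which is what keeps an independence sampler from getting stuck in the tails.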

| Sampler type | Description |
|---|---|
| METROPOLIS_HASTINGS | Standard MH with configurable proposal |
| DIRECT_CONJUGATE | Direct/Gibbs sampling for conjugate priors |
| COUPLED_TRANSFORM | MH with deterministic coupled transforms (theta-preserving NCP) |
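
A theta-preserving update for a non-centered hierarchy theta_i = mu + tau * eta_i can be sketched as follows. This is a hypothetical illustration of the general idea, not bamcmc's COUPLED_TRANSFORM code: propose new hyperparameters (mu, log tau) with a symmetric random walk, then solve for eta' so that every theta_i is unchanged. Because eta' is a deterministic function of the old state, the acceptance ratio needs the Jacobian (tau / tau')^n of the eta rescaling. The function name, the log-scale parameterization of tau, the step size and the toy model are assumptions.

```python
import numpy as np


def theta_preserving_move(mu, log_tau, eta, log_post, rng, step=0.3):
    """Joint MH move on (mu, log_tau) that rescales eta so that
    theta = mu + exp(log_tau) * eta is left unchanged (sketch)."""
    theta = mu + np.exp(log_tau) * eta

    # Symmetric Gaussian random walk on the hyperparameters.
    mu_new = mu + step * rng.standard_normal()
    log_tau_new = log_tau + step * rng.standard_normal()
    # Deterministic transform: solve theta = mu_new + tau_new * eta_new.
    eta_new = (theta - mu_new) / np.exp(log_tau_new)

    # |d eta_new / d eta| = (tau / tau_new)^n; the additive random walk on
    # (mu, log_tau) is volume preserving, so this is the full Jacobian term.
    log_jac = eta.size * (log_tau - log_tau_new)

    log_alpha = (log_post(mu_new, log_tau_new, eta_new)
                 - log_post(mu, log_tau, eta) + log_jac)
    if np.log(rng.uniform()) < log_alpha:
        return mu_new, log_tau_new, eta_new, True
    return mu, log_tau, eta, False


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    y = rng.normal(1.0, 1.0, size=8)  # simulated toy data

    def log_post(mu, log_tau, eta):
        # Toy NCP model: y_i ~ N(mu + tau * eta_i, 1), eta_i ~ N(0, 1),
        # mu ~ N(0, 5^2), log_tau ~ N(0, 1); log density up to a constant.
        theta = mu + np.exp(log_tau) * eta
        return (-0.5 * np.sum((y - theta) ** 2) - 0.5 * np.sum(eta ** 2)
                - 0.5 * (mu / 5.0) ** 2 - 0.5 * log_tau ** 2)

    state = (0.0, 0.0, rng.normal(size=8))
    theta_before = state[0] + np.exp(state[1]) * state[2]
    mu, log_tau, eta, accepted = theta_preserving_move(*state, log_post, rng)
    theta_after = mu + np.exp(log_tau) * eta
    print(accepted, np.allclose(theta_before, theta_after))  # theta unchanged either way
```

Since this move never changes theta, a full sampler would alternate it with ordinary updates of eta; its job is to let the hyperparameters move without dragging the group-level effects along.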

Install in editable mode from the repository root:

```bash
pip install -e .
```

GPU acceleration requires JAX installed with CUDA support.
```python
from bamcmc import register_posterior, BlockSpec, SamplerType, ProposalType, rmcmc

# my_log_posterior_fn, my_batch_type_fn, my_initial_vector_fn and `data`
# are placeholders for your own model functions and dataset.

# Register your posterior under a name that mcmc_config can refer to
register_posterior('my_model', {
    'log_posterior': my_log_posterior_fn,
    'batch_type': my_batch_type_fn,
    'initial_vector': my_initial_vector_fn,
})

# Run MCMC with a run schedule
mcmc_config = {
    'posterior_id': 'my_model',  # name passed to register_posterior
    'num_chains_a': 500,         # chains in group A
    'num_chains_b': 500,         # chains in group B
    'burn_iter': 1000,
    'num_collect': 5000,
    'thin_iteration': 10,
    'reset_runs': 3,             # 3 reset runs, then...
    'resume_runs': 5,            # ...5 resume runs
}
summary = rmcmc(
    mcmc_config,
    data,
    output_dir='./output',
)

# Or use rmcmc_single for single-run control
from bamcmc import rmcmc_single

results, checkpoint = rmcmc_single(mcmc_config, data)
history = results['history']
diagnostics = results['diagnostics']
```

See docs/README.md for detailed package documentation, including:
- Core concepts (BlockSpec, proposals, coupled chains)
- Data format requirements
- Adding new proposals and posteriors
- Performance considerations
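
The quick start selects a model by its 'posterior_id' string, which relies on a registry of named posteriors. The pattern can be sketched as below; this is a hypothetical toy, not bamcmc's internal code, and its validation rule (all three keys from the quick start are required) is an assumption. The names `toy_register_posterior` and `toy_get_posterior` are made up for illustration.

```python
from typing import Any, Callable, Dict

# Hypothetical stand-in for a posterior registry; bamcmc's real
# register_posterior may validate and store entries differently.
REQUIRED_KEYS = ("log_posterior", "batch_type", "initial_vector")
_REGISTRY: Dict[str, Dict[str, Callable[..., Any]]] = {}


def toy_register_posterior(posterior_id: str, spec: Dict[str, Callable[..., Any]]) -> None:
    """Store `spec` under `posterior_id` after checking the required keys."""
    missing = [k for k in REQUIRED_KEYS if k not in spec]
    if missing:
        raise KeyError(f"posterior {posterior_id!r} is missing {missing}")
    _REGISTRY[posterior_id] = dict(spec)


def toy_get_posterior(posterior_id: str) -> Dict[str, Callable[..., Any]]:
    """Look up a model the way a config's 'posterior_id' would select it."""
    if posterior_id not in _REGISTRY:
        raise KeyError(f"unknown posterior_id {posterior_id!r}; "
                       f"registered: {sorted(_REGISTRY)}")
    return _REGISTRY[posterior_id]


if __name__ == "__main__":
    # Placeholder callables; see docs/README.md for the signatures
    # bamcmc actually expects.
    placeholder = lambda *args, **kwargs: None
    toy_register_posterior("my_model", {
        "log_posterior": placeholder,
        "batch_type": placeholder,
        "initial_vector": placeholder,
    })
    config = {"posterior_id": "my_model"}
    print(sorted(toy_get_posterior(config["posterior_id"])))
```

Keeping models behind a string key lets a plain, serializable config dict (and any checkpoint that stores it) refer to a model without holding the functions themselves.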

Run the test suite with:

```bash
pytest tests/ -v
```

License: MIT (see LICENSE).