
JAX implementation of Sequential Neural Likelihood Estimation (SNLE) and Sequential Neural Ratio Estimation (SNRE) simulation-based inference algorithms


jtamanas/SaxBI



SaxBI is a JAX implementation of likelihood-free simulation-based inference (SBI) methods. It currently implements two algorithms: Sequential Neural Likelihood Estimation (SNLE) and Sequential Neural Ratio Estimation (SNRE). The package offers a simple, functional API for carrying out approximate posterior inference.
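For readers new to ratio estimation, the core idea behind SNRE fits in a few lines of JAX. The snippet below is a minimal sketch of the standard classifier loss, assuming a generic `logit_fn(params, theta, x)`; the names `ratio_estimation_loss` and `toy_logit_fn` are hypothetical and not part of SaxBI's API, and SaxBI's own loss may differ in detail. Jointly simulated pairs (theta, x) are labelled 1, pairs whose theta has been shuffled across the batch are labelled 0, and at the optimum the classifier logit approximates log p(x | theta) / p(x).

```python
import jax
import jax.numpy as jnp


def ratio_estimation_loss(logit_fn, params, rng, theta, x):
    """Binary cross-entropy loss used in neural ratio estimation (a sketch).

    Pairs (theta_i, x_i) drawn jointly get label 1; pairs with theta shuffled
    across the batch approximate draws from p(theta)p(x) and get label 0.
    """
    # Break the theta-x dependence by permuting theta along the batch axis.
    theta_marginal = jax.random.permutation(rng, theta)
    logits_joint = logit_fn(params, theta, x)
    logits_marginal = logit_fn(params, theta_marginal, x)
    # -log sigmoid(z) for label 1, -log(1 - sigmoid(z)) = -log sigmoid(-z) for label 0.
    loss_joint = -jax.nn.log_sigmoid(logits_joint)
    loss_marginal = -jax.nn.log_sigmoid(-logits_marginal)
    return jnp.mean(loss_joint) + jnp.mean(loss_marginal)


def toy_logit_fn(params, theta, x):
    # Hypothetical stand-in for a Flax classifier: a scaled inner product.
    return params["w"] * jnp.sum(theta * x, axis=-1) + params["b"]


rng = jax.random.PRNGKey(0)
k_theta, k_noise, k_perm = jax.random.split(rng, 3)
theta = jax.random.normal(k_theta, (256, 2))
x = theta + 0.1 * jax.random.normal(k_noise, (256, 2))  # toy simulator output
params = {"w": 1.0, "b": 0.0}

loss = ratio_estimation_loss(toy_logit_fn, params, k_perm, theta, x)
# Gradients with respect to the classifier parameters, as a training step would use.
grads = jax.grad(lambda p: ratio_estimation_loss(toy_logit_fn, p, k_perm, theta, x))(params)
```

With all logits equal to zero the loss is 2 log 2, the value of an uninformative classifier; any learned dependence between theta and x pushes it below that.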

The fully automated pipeline features:

  • Flax-based autoregressive normalizing flows with affine, piecewise affine, and piecewise rational quadratic spline transforms
  • Flax-based classifiers with or without residual skip connections
  • Hamiltonian Monte Carlo sampling with NUTS kernels implemented in NumPyro
  • And more!
  • Probably some bugs too... Let me know what you find 😅
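Both estimators end up handing the sampler the same kind of target. As a sketch (not SaxBI's internal code), the potential energy an HMC/NUTS kernel works with is the negative of the surrogate log-likelihood plus the log prior: in SNLE the surrogate is the flow's log q(x | theta), and in SNRE it is the classifier logit log r(x, theta). Either way it differs from the true log posterior only by a theta-independent constant. The helpers `make_potential_fn`, `toy_log_likelihood`, and `toy_log_prior` are hypothetical stand-ins.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


def make_potential_fn(log_likelihood_fn, log_prior, x_obs):
    """Return U(theta) = -(log_likelihood_fn(theta, x_obs) + log_prior(theta))."""
    def potential_fn(theta):
        return -(log_likelihood_fn(theta, x_obs) + log_prior(theta))
    return potential_fn


# Hypothetical stand-in for a trained surrogate: Gaussian likelihood, scale 0.5.
def toy_log_likelihood(theta, x):
    return jnp.sum(norm.logpdf(x, loc=theta, scale=0.5))


# Hypothetical standard-normal prior.
def toy_log_prior(theta):
    return jnp.sum(norm.logpdf(theta))


x_obs = jnp.array([1.0, -0.5])
potential_fn = make_potential_fn(toy_log_likelihood, toy_log_prior, x_obs)
# HMC needs the gradient of the potential, which JAX provides automatically.
grad_U = jax.grad(potential_fn)(jnp.zeros(2))
```

NumPyro's `NUTS` kernel accepts a function like this through its `potential_fn` argument, in which case initial values are passed to `MCMC.run` via `init_params`; whether SaxBI wires up its sampler exactly this way is an assumption.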

Installation

SaxBI requires Python 3.9 or higher. It can be installed from the repository's root directory with

```shell
python setup.py install
```

Basic Usage

The main workhorse of this package is the pipeline function, which takes five required arguments: rng, X_true, get_simulator, log_prior, and sample_prior. We recommend writing a simulator.py file from which the last four of these can be imported. The pipeline function returns the Flax model, its trained parameters, and samples from the posterior at the final iteration.

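```python
import jax

from saxbi import pipeline
# simulator.py is the user-written module recommended above.
from simulator import X_true, get_simulator, log_prior, sample_prior

rng = jax.random.PRNGKey(16)

# Returns the Flax model, its trained parameters, and final-iteration posterior samples.
model, params, Theta_post = pipeline(rng, X_true, get_simulator, log_prior, sample_prior)
```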
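As an illustration of the recommended simulator.py, here is a hypothetical toy module for a two-parameter Gaussian model with a box-uniform prior. The argument lists (a simulator taking `(rng, theta)`, `sample_prior(rng, num_samples)`, and `log_prior(theta)` returning one value per parameter vector) are assumptions made for this sketch, not signatures documented in this README; the `examples/` directory shows what `pipeline` actually expects.

```python
# simulator.py: a hypothetical toy module. The function signatures below are
# assumptions, not the documented interface of SaxBI's pipeline.
import jax
import jax.numpy as jnp
from jax.scipy.stats import uniform

NOISE_SCALE = 0.1      # hypothetical observation noise
LOW, HIGH = -3.0, 3.0  # hypothetical box-uniform prior bounds
DIM = 2                # number of parameters


def get_simulator():
    """Return a simulator mapping (rng, batch of theta) to a batch of observations."""
    def simulator(rng, theta):
        noise = NOISE_SCALE * jax.random.normal(rng, theta.shape)
        return theta + noise
    return simulator


def log_prior(theta):
    """Log density of the box-uniform prior, summed over parameter dimensions."""
    return jnp.sum(uniform.logpdf(theta, loc=LOW, scale=HIGH - LOW), axis=-1)


def sample_prior(rng, num_samples):
    """Draw num_samples parameter vectors uniformly from [LOW, HIGH]^DIM."""
    return jax.random.uniform(rng, (num_samples, DIM), minval=LOW, maxval=HIGH)


# Stand-in "observed" data, simulated here from a known parameter for testing.
X_true = get_simulator()(jax.random.PRNGKey(0), jnp.array([[0.5, -1.0]]))
```

In a real application X_true would be the measured data rather than a simulation.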

The examples/ directory holds a few canonical examples from the literature to show off the syntax in greater detail.

SBI Algorithm References

Sequential Neural Likelihood Estimation (SNLE)

Sequential Neural likelihood-to-evidence Ratio Estimation (SNRE)

Todo

  • Add diagnostics (like MMD, ROC AUC)
  • Add support for Mining Gold (i.e. using simulator derivatives to improve likelihood estimators)
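MMD is not yet part of SaxBI. For orientation only, a minimal unbiased estimator of the squared maximum mean discrepancy between two sample sets might look like the sketch below. It assumes a Gaussian kernel with a fixed, hand-picked bandwidth (median-heuristic bandwidths are a common alternative), and the function names are hypothetical. Such a diagnostic could, for example, compare Theta_post against reference posterior samples.

```python
import jax
import jax.numpy as jnp


def gaussian_kernel(a, b, bandwidth):
    # Pairwise squared distances between rows of a (n, d) and b (m, d).
    sq_dists = jnp.sum((a[:, None, :] - b[None, :, :]) ** 2, axis=-1)
    return jnp.exp(-sq_dists / (2.0 * bandwidth**2))


def mmd_squared(x, y, bandwidth=1.0):
    """Unbiased estimate of MMD^2 between samples x (n, d) and y (m, d)."""
    n, m = x.shape[0], y.shape[0]
    k_xx = gaussian_kernel(x, x, bandwidth)
    k_yy = gaussian_kernel(y, y, bandwidth)
    k_xy = gaussian_kernel(x, y, bandwidth)
    # Exclude the diagonal k(x_i, x_i) terms from the within-sample averages.
    term_xx = (jnp.sum(k_xx) - jnp.trace(k_xx)) / (n * (n - 1))
    term_yy = (jnp.sum(k_yy) - jnp.trace(k_yy)) / (m * (m - 1))
    return term_xx + term_yy - 2.0 * jnp.mean(k_xy)


# Usage: samples from the same distribution vs. a shifted one.
key_a, key_b, key_c = jax.random.split(jax.random.PRNGKey(0), 3)
x_ref = jax.random.normal(key_a, (500, 2))
x_same = jax.random.normal(key_b, (500, 2))
x_shift = jax.random.normal(key_c, (500, 2)) + jnp.array([2.0, 0.0])
mmd_same = mmd_squared(x_ref, x_same)
mmd_shift = mmd_squared(x_ref, x_shift)
```

Because the estimator is unbiased, it can come out slightly negative when the two distributions match; values near zero indicate agreement.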
