Wide BNN sampling

Code for the paper "Simplicity of the wide Bayesian neural network weight posterior: theory and accelerated sampling" by Jiri Hron, Roman Novak, Jeffrey Pennington, and Jascha Sohl-Dickstein.

The main contribution is a reparametrisation of Bayesian neural network (BNN) posteriors which, combined with Hamiltonian Monte Carlo, enables 10-200x faster mixing than the standard parametrisation. Intriguingly, sampling becomes faster the wider the BNN. The reparametrisation is derived from large-width BNN theory (e.g., Matthews et al., Lee et al., Hron et al.), and can be shown to transform the exact BNN weight-space posterior into a distribution whose KL divergence from the multivariate standard normal vanishes in the large-width limit. This is the source of the speed-up at large width, but we have sometimes observed 10x faster mixing even far from the wide regime (i.e., when the width is much smaller than the dataset size).

The code in this repository provides an efficient way of computing both the reparametrised density and the parameters at the same time. As detailed in the paper, the implementation is based on a Cholesky decomposition together with a forward and a backward solve, akin to the usual implementation of a Cholesky solver.
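
For intuition, here is a minimal, self-contained sketch of this kind of Cholesky-based reparametrisation for the simplified case of a Gaussian-linear readout layer. The function name and the simplified setting are illustrative assumptions and do not mirror reparametrisation.py.

```python
# Illustrative sketch only (simplified setting, not reparametrisation.py):
# for a Gaussian prior/likelihood over readout weights w with features phi
# (n x d), the exact posterior is Gaussian, and a Cholesky factor of its
# precision maps standard-normal draws to posterior samples via triangular
# forward/backward solves -- the flavour of computation described above.
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular

def gaussian_readout_reparam(phi, y, prior_var, noise_var, eps):
  d = phi.shape[-1]
  # Posterior precision: phi^T phi / noise_var + I / prior_var  (d x d).
  prec = phi.T @ phi / noise_var + jnp.eye(d) / prior_var
  chol = jnp.linalg.cholesky(prec)                    # lower-triangular L
  # Posterior mean via a forward and a backward triangular solve.
  rhs = phi.T @ y / noise_var
  mean = solve_triangular(
      chol.T, solve_triangular(chol, rhs, lower=True), lower=False)
  # Map a standard-normal draw eps to a posterior sample: w = mean + L^{-T} eps.
  w = mean + solve_triangular(chol.T, eps, lower=False)
  # log |det dw/d eps| = -sum(log diag L), the density correction term.
  log_det_jac = -jnp.sum(jnp.log(jnp.diag(chol)))
  return w, log_det_jac
```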

We rely on JAX, a high-performance machine learning library based on XLA with a simple NumPy/Autograd-like API, and Neural Tangents, a high-level neural network API enabling computation with finite as well as infinite neural networks. See setup.py for other dependencies.
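
For reference, a finite-width network can be declared with Neural Tangents' stax API roughly as in the sketch below; the layer widths are arbitrary placeholders, and the architectures actually used in our experiments are defined in models.py.

```python
# Illustrative only: a small finite-width MLP built with neural_tangents.stax.
# The widths and depth are placeholders, not the configurations from the paper.
from jax import random
from neural_tangents import stax

init_fn, apply_fn, kernel_fn = stax.serial(
    stax.Dense(512), stax.Relu(),
    stax.Dense(512), stax.Relu(),
    stax.Dense(1))

key = random.PRNGKey(0)
_, params = init_fn(key, input_shape=(-1, 784))  # finite-width parameters
```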

Using the code

The code has several dependencies described in setup.py. To install them automatically, use

git clone https://github.com/google/wide_bnn_sampling
cd wide_bnn_sampling
pip install -e .

A dependency not included is jaxlib, whose installation differs depending on the available hardware; please follow the relevant instructions in JAX's repository. If you just want to quickly try the code with the CPU backend, you can run pip install jax jaxlib --upgrade.

To launch an experiment, modify the provided config.py as needed, and invoke

python3 wide_bnn_sampling/main.py --config wide_bnn_sampling/config.py --store_dir <results-directory>

The high-level structure of the main.py dependencies is described below:

  • config.py: Configuration flags for the dataset, neural network architecture, the sampler, and auxiliary experiment run settings.
  • datasets.py: Loading and preprocessing of data.
  • measurements.py: Logging utilities.
  • models.py: Constructs neural network architectures with Neural Tangents.
  • reparametrisation.py: Effective implementation of the reparametrisation under the assumption of Gaussian likelihood and prior (details in the paper).
  • samplers.py: Custom implementation of HMC/LMC, and Metropolis-Hastings with a simple Gaussian proposal (a generic HMC step is sketched after this list).
  • utils.py: Auxiliary methods primarily used within main.py.
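
To illustrate the kind of sampler implemented in samplers.py, below is a textbook single HMC step (leapfrog integration plus a Metropolis correction). It is a generic sketch under standard-Gaussian momentum, not the repository's implementation, which additionally handles the reparametrised density and other details.

```python
# Generic HMC step: leapfrog integration followed by a Metropolis correction.
# Shown for illustration only; samplers.py differs in details.
import jax
import jax.numpy as jnp

def hmc_step(key, theta, log_prob, step_size, n_leapfrog):
  key_mom, key_acc = jax.random.split(key)
  grad_fn = jax.grad(log_prob)
  p0 = jax.random.normal(key_mom, theta.shape)   # resample momentum

  # Leapfrog integration of the Hamiltonian dynamics.
  p = p0 + 0.5 * step_size * grad_fn(theta)
  q = theta
  for _ in range(n_leapfrog - 1):
    q = q + step_size * p
    p = p + step_size * grad_fn(q)
  q = q + step_size * p
  p = p + 0.5 * step_size * grad_fn(q)

  # Metropolis-Hastings acceptance based on the change in total energy.
  h_old = -log_prob(theta) + 0.5 * jnp.sum(p0 ** 2)
  h_new = -log_prob(q) + 0.5 * jnp.sum(p ** 2)
  accept = jnp.log(jax.random.uniform(key_acc)) < h_old - h_new
  return jnp.where(accept, q, theta), accept
```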

CAVEAT: Despite using several tricks for improved stability, we observed a significant deterioration of acceptance probabilities when the computational precision is low. We recommend using at least float32, and float64 where feasible. The relevant flags in JAX are jax_enable_x64 (and jax_default_matmul_precision if on TPU).
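
For instance, the following standard JAX configuration calls request float64 computation and pin the default matrix multiplication precision:

```python
# Standard JAX precision flags mentioned above; set these before any computation.
import jax

jax.config.update('jax_enable_x64', True)                     # enable float64
jax.config.update('jax_default_matmul_precision', 'float32')  # mainly relevant on TPU
```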

Contributing

See CONTRIBUTING.md for details.

License

Apache 2.0; see LICENSE for details.

Disclaimer

This project is not an official Google project. It is not supported by Google and Google specifically disclaims all warranties as to its quality, merchantability, or fitness for a particular purpose.
