This repository has been archived by the owner on May 1, 2024. It is now read-only.

google-deepmind/annealed_flow_transport

Continual Repeated Annealed Flow Transport Monte Carlo (CRAFT)

and Annealed Flow Transport Monte Carlo (AFT)

The release contains implementations of:

  • Continual Repeated Annealed Flow Transport Monte Carlo (CRAFT), Matthews et al. (2022).
  • Annealed Flow Transport Monte Carlo (AFT), Arbel et al. (2021).
  • Stochastic Normalizing Flows (SNF), Wu et al. (2020).
  • Sequential Monte Carlo samplers (SMC), Del Moral et al. (2006).
  • Variational inference with Normalizing Flows (VI), Rezende and Mohamed (2015).
  • Particle Markov Chain Monte Carlo, in its Particle Independent Metropolis-Hastings form (PIMH), Andrieu et al. (2010).

The implementation of AFT is based on Algorithm 2 of Arbel et al. (2021). See https://arxiv.org/abs/2102.07501 for more details.

The implementation of SNFs differs from the original one in that it exploits the connection with Annealed Importance Sampling with added normalizing flows. The training dynamics are still the same.

The implementation of Particle Markov Chain Monte Carlo is specialized to the case of a final target of interest rather than a time series and assumes an independent proposal, which can be based on SMC, VI or CRAFT.
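To make the independent-proposal setting concrete, here is a minimal sketch of a PIMH loop. It is not the code in this repository: `propose` is a hypothetical callable standing in for one run of SMC, VI or CRAFT, and is assumed to return a sample together with a log normalizing constant estimate. A proposal is accepted with probability min(1, Z_new / Z_current), following Andrieu et al. (2010).

```python
import math
import random


def pimh_chain(propose, num_iters, seed=0):
    """Runs Particle Independent Metropolis-Hastings with an independent proposal.

    `propose(rng)` is a hypothetical callable returning a pair
    (sample, log_z_estimate), for example from one run of a particle sampler.
    """
    rng = random.Random(seed)
    sample, log_z = propose(rng)
    chain, num_accepted = [sample], 0
    for _ in range(num_iters):
        new_sample, new_log_z = propose(rng)
        # Accept with probability min(1, Z_new / Z_current), computed in log space.
        accept_prob = math.exp(min(0.0, new_log_z - log_z))
        if rng.random() < accept_prob:
            sample, log_z = new_sample, new_log_z
            num_accepted += 1
        chain.append(sample)
    return chain, num_accepted / max(num_iters, 1)
```

Because each proposal is drawn independently of the current state, the acceptance rate is a direct indication of how variable the normalizing constant estimates of the underlying sampler are.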

This is research code. If you are interested in comparing to our work and have questions about how to use it please do reach out.

Installation

The code uses Python 3. We recommend an editable install (pip install -e .), ideally inside a virtual environment:

virtualenv -p python3.10 ~/venv/annealed_flow_transport
source ~/venv/annealed_flow_transport/bin/activate
pip install -e .

A GPU is highly recommended. To use one you will need to install JAX with CUDA support. For example:

pip install --upgrade jax jaxlib==0.1.68+cuda111 -f https://storage.googleapis.com/jax-releases/jax_releases.html

The CUDA version will need to match your GPU drivers. See the JAX documentation for more discussion.

To run the unit tests use the following command:

python -m pytest

Usage

The entry point to the code is main.py, which takes a config file as an argument. For example, running the following command from the base directory executes a simple one-dimensional toy example:

python main.py --config=configs/single_normal.py

This example uses CRAFT to anneal between two one-dimensional normal distributions with the same scale and different locations. The script should print a sequence of steps and return an estimate of the log normalizing constant.
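For intuition about what such an estimate looks like, the following is a minimal sketch of plain annealed importance sampling between two one-dimensional normals, written in pure Python. It is not the repository's CRAFT implementation and uses no flows or resampling. The function names, the geometric annealing path, the locations 0 and 2, and all default settings are assumptions chosen for illustration. Since both endpoints are normalized densities here, the true log normalizing constant ratio is 0, so the estimate should be close to 0.

```python
import math
import random


def log_normal(x, loc, scale=1.0):
    """Log density of a one-dimensional normal distribution."""
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2.0 * math.pi)


def log_annealed(x, beta, loc_initial, loc_final):
    """Unnormalized log density on the geometric path between the two normals."""
    return (1.0 - beta) * log_normal(x, loc_initial) + beta * log_normal(x, loc_final)


def estimate_log_z(loc_initial=0.0, loc_final=2.0, num_particles=1000,
                   num_temps=10, mcmc_steps=3, step_size=1.0, seed=0):
    """Annealed importance sampling estimate of the log normalizing constant ratio."""
    rng = random.Random(seed)
    particles = [rng.gauss(loc_initial, 1.0) for _ in range(num_particles)]
    log_weights = [0.0] * num_particles
    for t in range(1, num_temps + 1):
        beta_prev, beta = (t - 1) / num_temps, t / num_temps
        for i, x in enumerate(particles):
            # Incremental importance weight between adjacent temperatures.
            log_weights[i] += (log_annealed(x, beta, loc_initial, loc_final)
                               - log_annealed(x, beta_prev, loc_initial, loc_final))
            # Random walk Metropolis moves that leave the current target invariant.
            for _ in range(mcmc_steps):
                proposal = x + step_size * rng.gauss(0.0, 1.0)
                log_accept = (log_annealed(proposal, beta, loc_initial, loc_final)
                              - log_annealed(x, beta, loc_initial, loc_final))
                if rng.random() < math.exp(min(0.0, log_accept)):
                    x = proposal
            particles[i] = x
    # Log of the mean weight, computed stably with the log-sum-exp trick.
    max_lw = max(log_weights)
    mean_weight = sum(math.exp(lw - max_lw) for lw in log_weights) / num_particles
    return max_lw + math.log(mean_weight)


if __name__ == "__main__":
    print("log Z estimate:", estimate_log_z())
```

AFT and CRAFT add learned normalizing flows between temperatures on top of this kind of annealing scheme; the sketch only shows the baseline weighting and MCMC steps.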

The config files use the ConfigDict from ml_collections to specify all details of the desired experiment, such as the algorithm, the MCMC kernel, the base distribution and the target distribution. More examples can be found in the configs directory.

To store a model snapshot that can later be evaluated using PIMH, set config.save_checkpoint = True and specify config.params_filename in the config file. The evaluation is then run using:

python evaluate.py --config=configs/single_normal.py
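As a rough illustration of the checkpoint settings described above, a config file might contain something like the sketch below. It assumes the usual ml_collections convention of a get_config function that returns a ConfigDict. Only save_checkpoint and params_filename are taken from this README; the path is a hypothetical placeholder, and every other field a real experiment needs is deliberately left out.

```python
from ml_collections import config_dict


def get_config():
    """Returns a config fragment with checkpointing switched on."""
    config = config_dict.ConfigDict()
    # The two fields named in this README.
    config.save_checkpoint = True
    config.params_filename = "/tmp/craft_params.pkl"  # Hypothetical path; choose your own.
    # The algorithm, MCMC kernel, base distribution and target distribution
    # must also be set; see the files in the configs directory for the
    # field names the code expects.
    return config
```

In practice it is likely simpler to copy one of the files in the configs directory and change these two fields there.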

We have not released code for writing results to disk. The function train.run_experiments called from main.py returns a NamedDict containing a summary of results that could be caught and recorded if required.
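One way to record the returned summary is to serialize it to JSON. The helper below is a minimal sketch under stated assumptions: save_results and to_serializable are hypothetical names, the summary is assumed to be a mapping or namedtuple-like container of numbers and arrays, and the field names in the usage example are invented for illustration.

```python
import json
from collections.abc import Mapping
from pathlib import Path


def to_serializable(value):
    """Recursively converts a result value into JSON-friendly Python types."""
    if hasattr(value, "_asdict"):  # namedtuple-style containers
        value = value._asdict()
    if isinstance(value, Mapping):
        return {str(k): to_serializable(v) for k, v in value.items()}
    if isinstance(value, (list, tuple)):
        return [to_serializable(v) for v in value]
    if hasattr(value, "tolist"):  # NumPy or JAX arrays and scalars
        return value.tolist()
    return value


def save_results(results, path):
    """Writes a results summary to `path` as JSON and returns the path."""
    path = Path(path)
    path.write_text(json.dumps(to_serializable(results), indent=2))
    return path


if __name__ == "__main__":
    # Stand-in for the object returned by train.run_experiments; the keys
    # below are invented for this example.
    example = {"log_normalizer_estimate": -0.01, "step_losses": [0.3, 0.2]}
    print(save_results(example, "results.json").read_text())
```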

License information

The code is licensed under the Apache 2.0 license, which can be found in full in the LICENSE file.

We have released a pickled model parameters file for the VAE example, which is licensed under the Creative Commons Attribution 4.0 International license (CC BY 4.0). The full text can be found at https://creativecommons.org/licenses/by/4.0/legalcode.

Giving Credit

If you use this code in your work, please cite the corresponding paper, i.e. Arbel et al. 2021 for AFT and Matthews et al. 2022 for CRAFT.

If you use one of our baselines (VI, SMC, SNF or PIMH), please cite whichever of the two papers below first used the method: Arbel et al. 2021 for SMC and VI, and Matthews et al. 2022 for PIMH and SNF.

@InProceedings{CRAFT2022,
  title={Continual Repeated Annealed Flow Transport Monte Carlo},
  author={Alexander G. D. G. Matthews and Michael Arbel and Danilo J. Rezende and Arnaud Doucet},
  booktitle = {Proceedings of the 39th International Conference on Machine Learning},
  series = {Proceedings of Machine Learning Research},
  year={2022},
  month = {Jul}
}
@InProceedings{AnnealedFlowTransport2021,
  title={Annealed Flow Transport Monte Carlo},
  author={Michael Arbel and Alexander G. D. G. Matthews and Arnaud Doucet},
  booktitle = {Proceedings of the 38th International Conference on Machine Learning},
  series = {Proceedings of Machine Learning Research},
  year={2021},
  month = {18--24 Jul}
}

Disclaimer

This is not an official Google product.
