JAX samplers: 3 ways
blog post: link
Code to reproduce the results in the "Jax 3 ways" blog post.
Installation:
virtualenv venv; source venv/bin/activate
pip install -r requirements.txt
Scripts
To run all the examples, run the bash script run_all_samplers.sh from the repository root.
ULA (unadjusted Langevin algorithm)
run_ULA_increase_data.py
run_ULA_increase_dimension.py
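
For orientation, here is a minimal sketch of what a ULA sampler can look like in JAX. It is not taken from the scripts above: the names log_post, ula_kernel and ula_sampler, the standard Gaussian target, the step size and the number of samples are all assumptions chosen for illustration. Each update takes a gradient step on the log-density and adds Gaussian noise scaled by sqrt(2 * dt).

```python
import jax
import jax.numpy as jnp


def log_post(theta):
    # Hypothetical target: standard Gaussian log-density, up to a constant.
    # The repository's scripts may use a different model.
    return -0.5 * jnp.sum(theta ** 2)


grad_log_post = jax.grad(log_post)


def ula_kernel(key, theta, dt):
    # One ULA update:
    # theta + dt * grad log pi(theta) + sqrt(2 * dt) * N(0, I)
    key, subkey = jax.random.split(key)
    noise = jax.random.normal(subkey, shape=theta.shape)
    theta = theta + dt * grad_log_post(theta) + jnp.sqrt(2.0 * dt) * noise
    return key, theta


def ula_sampler(key, theta0, dt, num_samples):
    # Run the chain with lax.scan, carrying the PRNG key and current state
    # and stacking every visited state into one array.
    def step(carry, _):
        key, theta = carry
        key, theta = ula_kernel(key, theta, dt)
        return (key, theta), theta

    _, samples = jax.lax.scan(step, (key, theta0), None, length=num_samples)
    return samples


# Illustrative settings only: 5 dimensions, step size 1e-2, 1000 samples.
samples = ula_sampler(jax.random.PRNGKey(0), jnp.zeros(5), 1e-2, 1000)
```

Here samples has shape (num_samples, dimension). Using lax.scan is one of several ways to write the loop; a plain Python loop over ula_kernel would produce the same chain for the same key.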
MALA (Metropolis-adjusted Langevin algorithm)
run_MALA_increase_data.py
run_MALA_increase_dimension.py
SGLD (stochastic gradient Langevin dynamics)
run_SGLD_increase_data.py
run_SGLD_increase_dimension.py