
lll6924/automatically-marginalized-MCMC


Automatic_Marginalization

Implementation of the preprint "Automatically Marginalized MCMC in Probabilistic Programming".

Dependencies

  • JAX 0.4.1
  • NumPyro 0.10.1

Assumptions of model implementations

  • Each model is declared as a subclass of model.Model that implements the functions model, args, kwargs, and name.
  • model is a function that declares a model with the same syntax as NumPyro, except that numpyro.sample statements are replaced with primitives.my_sample. my_sample is a hack to annotate random variables in Jaxprs, and the function name register is used for it.
  • args is assumed to be a tuple of parameters that are always passed to model (for example, covariates).
  • kwargs is assumed to be a dictionary of observations; without it, model becomes the generative model.
  • name is the name of the model.
  • The implementation as a whole is experimental and can be sensitive to how a model is written. If you find a problem, feel free to open an issue.
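A minimal, self-contained sketch of the interface described above. Model and my_sample below are hypothetical stand-ins written for illustration only; they are not the repository's model.Model or primitives.my_sample, and the real my_sample signature may differ. Here my_sample simply records each site in a dictionary and returns the observation when one is given, otherwise a draw from a plain Python sampler. The sketch also assumes model is called as model(*args(), **kwargs()).

```python
import random
from abc import ABC, abstractmethod


class Model(ABC):
    """Hypothetical stand-in for model.Model: subclasses implement four functions."""

    @abstractmethod
    def model(self, *args, **kwargs):
        """Declare the model; observations arrive as keyword arguments."""

    @abstractmethod
    def args(self):
        """Tuple of parameters always passed to model (e.g. covariates)."""

    @abstractmethod
    def kwargs(self):
        """Dictionary of observations; omit it to get the generative model."""

    @abstractmethod
    def name(self):
        """Name of the model."""


def my_sample(name, sampler, trace, obs=None):
    """Hypothetical stand-in for primitives.my_sample.

    Records the site in `trace` and returns `obs` if it is given,
    otherwise a fresh draw from the zero-argument callable `sampler`.
    """
    value = obs if obs is not None else sampler()
    trace[name] = value
    return value


class ToyNormal(Model):
    """Hypothetical toy model: mu ~ Normal(0, 1); y_i ~ Normal(mu, 1)."""

    def __init__(self, y):
        self._y = list(y)

    def model(self, n, y=None):
        trace = {}
        mu = my_sample("mu", lambda: random.gauss(0.0, 1.0), trace)
        for i in range(n):
            obs = None if y is None else y[i]
            my_sample(f"y_{i}", lambda: random.gauss(mu, 1.0), trace, obs=obs)
        return trace

    def args(self):
        return (len(self._y),)

    def kwargs(self):
        return {"y": self._y}

    def name(self):
        return "ToyNormal"


m = ToyNormal([0.5, -0.2, 1.1])
conditioned = m.model(*m.args(), **m.kwargs())  # y_i fixed to the observations
generative = m.model(*m.args())                 # y_i drawn from the model
```

Calling model with only args gives the generative model, matching the convention that kwargs holds the observations.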

Running Commands

To reproduce the results with HMC-M, use the following commands:

python -m run.main --model HierarchicalPartialPooling --model_parameters "{'dataset':'rat_tumors'}" --rng_key $rng_key
python -m run.main --model HierarchicalPartialPooling --model_parameters "{'dataset':'baseball_large'}" --rng_key $rng_key
python -m run.main --model HierarchicalPartialPooling --model_parameters "{'dataset':'baseball_small'}" --rng_key $rng_key
python -m run.main --model ElectricCompany --rng_key $rng_key --protected "['mua0','mua1','mua2','mua3']"
python -m run.main --model PulmonaryFibrosis --rng_key $rng_key --protected "['m_a','s_a','m_b','s_b']"

To reproduce the results with HMC, use the following commands:

python -m run.hmc --model HierarchicalPartialPooling --model_parameters "{'dataset':'rat_tumors'}" --rng_key $rng_key
python -m run.hmc --model HierarchicalPartialPooling --model_parameters "{'dataset':'baseball_large'}" --rng_key $rng_key
python -m run.hmc --model HierarchicalPartialPooling --model_parameters "{'dataset':'baseball_small'}" --rng_key $rng_key
python -m run.hmc --model ElectricCompany --rng_key $rng_key
python -m run.hmc --model PulmonaryFibrosisVectorized --rng_key $rng_key

To reproduce the results with HMC-R, use the following commands:

python -m run.hmc --model ElectricCompanyReparameterized --rng_key $rng_key
python -m run.hmc --model PulmonaryFibrosisReparameterized --rng_key $rng_key
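Each command above takes a single $rng_key. To repeat an experiment over several keys, a small driver along these lines could be used. It assumes --rng_key accepts an integer seed; build_commands, run_all, and the seed range are illustrative helpers, not part of the repository. It builds one argument list per seed in the same flag order as the commands above.

```python
import subprocess


def build_commands(module, model, seeds, model_parameters=None, protected=None):
    """Return one argv list per seed for `python -m <module>` (hypothetical helper)."""
    commands = []
    for seed in seeds:
        argv = ["python", "-m", module, "--model", model]
        if model_parameters is not None:
            # Same Python-dict-literal string style as the commands above.
            argv += ["--model_parameters", model_parameters]
        argv += ["--rng_key", str(seed)]
        if protected is not None:
            argv += ["--protected", protected]
        commands.append(argv)
    return commands


def run_all(commands):
    """Run each command from the repository root, stopping at the first failure."""
    for argv in commands:
        subprocess.run(argv, check=True)


cmds = build_commands("run.main", "ElectricCompany", range(3),
                      protected="['mua0','mua1','mua2','mua3']")
```

Passing argument lists (rather than one shell string) to subprocess.run avoids re-quoting the dictionary and list literals.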

To reproduce the results in Appendix F, run the following commands with $n set to the number of branches.

python -m run.main --model TMP --model_parameters "{'N':'$n'}" --just_compile
python -m run.main --model TMP --model_parameters "{'N':'$n'}" --just_compile --no_marginalization
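A sketch that produces both Appendix F commands for several branch counts, so the compile step can be compared with and without marginalization. The branch counts and the helper name appendix_f_commands are illustrative assumptions. N is written as a quoted string inside the dictionary, exactly as in the commands above.

```python
def appendix_f_commands(branch_counts):
    """For each n, return (with marginalization, without marginalization) argv lists."""
    pairs = []
    for n in branch_counts:
        base = ["python", "-m", "run.main", "--model", "TMP",
                "--model_parameters", "{'N':'%d'}" % n, "--just_compile"]
        pairs.append((base, base + ["--no_marginalization"]))
    return pairs


pairs = appendix_f_commands([2, 4, 8])
```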

All of the above commands can be found under experiment/.

Additional Notes

Our code depends on the patterns of the Jaxprs produced when tracing a program. These patterns may differ across JAX/NumPyro versions and system environments, but the code should work in most cases.
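Since the traced Jaxprs can change between versions, a quick comparison against the versions listed under Dependencies may help when something breaks. This stdlib-only sketch reads installed distribution versions with importlib.metadata; check_versions and PINNED are illustrative names, and the check does not inspect Jaxprs themselves.

```python
from importlib.metadata import PackageNotFoundError, version

# Versions listed under Dependencies.
PINNED = {"jax": "0.4.1", "numpyro": "0.10.1"}


def check_versions(pinned=PINNED):
    """Map each package to (installed version or None, whether it matches the pin)."""
    report = {}
    for pkg, wanted in pinned.items():
        try:
            installed = version(pkg)
        except PackageNotFoundError:
            installed = None
        report[pkg] = (installed, installed == wanted)
    return report


for pkg, (installed, ok) in check_versions().items():
    print(f"{pkg}: installed={installed}, matches pin={ok}")
```

To look at the Jaxpr that a given function traces to, jax.make_jaxpr can be used.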
