This package provides a framework for implementing Brownian motion on Riemannian manifolds. The theory is developed in this paper. Check out the documentation page.
- Simulation of SDEs on manifolds.
- Simulation of Riemannian Brownian motions and Riemannian uniform distributions.
- Simulation of Riemannian Langevin processes and sampling distributions on manifolds with given densities.
- Several example manifolds with flexible choices of metrics.
- Several simulation schemes.
- Extensible to new manifolds and new SDEs.
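To make the first few features concrete, here is a small, self-contained sketch in plain JAX. It does not use this package's API. It simulates Riemannian Langevin dynamics dX = ½ grad log p(X) dt + dB on the unit sphere S² in ambient coordinates, where dB is Riemannian Brownian motion with Itô form P(X) dW - (n-1)/2 X dt and P(X) = I - X Xᵀ. With a zero potential the same code simulates Brownian motion. The discretization (projected Euler-Maruyama followed by renormalization as a retraction) is an illustrative choice, and the function names are hypothetical. The schemes implemented in the package may differ.

```python
import jax
import jax.numpy as jnp


def tangent_project(x, v):
    """Project an ambient vector v onto the tangent space of the unit sphere at x."""
    return v - jnp.dot(x, v) * x


def langevin_step(x, z, h, grad_log_p):
    """One projected Euler-Maruyama step of Riemannian Langevin dynamics on S^{n-1}.

    Ito form in ambient coordinates:
        dX = [1/2 P(X) grad_log_p(X) - (n-1)/2 X] dt + P(X) dW,
    whose stationary density is proportional to p w.r.t. the uniform measure.
    With grad_log_p = 0 this is Riemannian Brownian motion (generator Delta/2).
    """
    n = x.shape[0]
    noise = tangent_project(x, jnp.sqrt(h) * z)
    drift = 0.5 * h * tangent_project(x, grad_log_p(x)) - 0.5 * (n - 1) * h * x
    y = x + drift + noise
    # Retraction: rescale the iterate back onto the sphere.
    return y / jnp.linalg.norm(y)


def simulate(key, x0, t_final, n_steps, grad_log_p):
    """Return the endpoint of one simulated path started at x0."""
    h = t_final / n_steps

    def body(x, i):
        # Fresh standard normal increment for step i, derived from the path key.
        z = jax.random.normal(jax.random.fold_in(key, i), x.shape)
        return langevin_step(x, z, h, grad_log_p), None

    x_end, _ = jax.lax.scan(body, x0, jnp.arange(n_steps))
    return x_end


n_paths = 4000
x0 = jnp.array([0.0, 0.0, 1.0])
keys = jax.random.split(jax.random.PRNGKey(0), n_paths)

# Brownian motion (zero potential): on S^2, E[X_t] = exp(-t) * x0.
bm_grad = lambda x: jnp.zeros_like(x)
bm_end = jax.vmap(lambda k: simulate(k, x0, 1.0, 100, bm_grad))(keys)

# Langevin sampling of a von Mises-Fisher density p(x) ∝ exp(kappa * mu . x).
kappa = 1.0
mu = jnp.array([1.0, 0.0, 0.0])
vmf_grad = lambda x: kappa * mu
vmf_end = jax.vmap(lambda k: simulate(k, x0, 8.0, 800, vmf_grad))(keys)

print(jnp.mean(bm_end, axis=0))  # roughly exp(-1) * x0
print(jnp.mean(vmf_end @ mu))    # roughly coth(kappa) - 1/kappa
```

The sample mean of the Brownian endpoints should decay like exp(-t) x0, and the Langevin chains should reproduce the von Mises-Fisher mean coth(κ) - 1/κ, up to Monte Carlo and O(h) discretization error.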
Requirement: JAX (pip install jax). If you have access to a GPU, install the CUDA-enabled build of JAX following JAX's installation notes (for example, pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html).
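After installing JAX, a quick way to confirm which backend it picked up is to query it directly. On a machine where the CUDA build is installed correctly, the backend would be expected to report "gpu". Otherwise JAX falls back to "cpu". This uses only the standard `jax.default_backend` and `jax.devices` functions.

```python
import jax

# Backend JAX selected by default, e.g. "cpu" or "gpu".
backend = jax.default_backend()
# All devices visible to that backend.
devices = jax.devices()
print(backend, devices)
```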
To install from git, run the following (this assumes the build package is installed; otherwise run pip install build first):
pip install git+https://github.com/dnguyend/jax-rb
Alternatively, you can clone the project to a local directory and add that directory to your PYTHONPATH. See an example that uses sys.path.append; setting PYTHONPATH works the same way.
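A minimal sketch of the sys.path approach follows. The clone location `~/jax-rb` is a hypothetical placeholder, and the import name `jax_rb` in the final comment is assumed from the repository name. Setting `PYTHONPATH` to the same directory before starting Python has the equivalent effect for every session instead of just the current one.

```python
import os
import sys

# Hypothetical path to your local clone of jax-rb; adjust to where you cloned it.
REPO_DIR = os.path.expanduser("~/jax-rb")

# Make the clone importable for this interpreter session only.
if REPO_DIR not in sys.path:
    sys.path.append(REPO_DIR)

# The package can then be imported from the clone, e.g.:
# import jax_rb  # package name assumed from the repository name
```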
To build the project manually from a cloned directory, go to the jax-rb folder and run
python -m build
This assumes JAX is installed.
To build the documentation, install Sphinx (pip install sphinx, pip install sphinx-rtd-theme), then go to the jax-rb/docs folder and run
make html
Then open jax-rb/docs/_build/html/index.html in your browser to navigate the documentation. We plan to eventually upload the package to PyPI so that users can install it from there.