flatironinstitute/jaxmg

JAXMg: A distributed linear solver in JAX with cuSolverMg

JAXMg

JAXMg provides a C++ interface between JAX and cuSolverMg, NVIDIA’s multi-GPU linear solver. We provide a jittable API for the following routines.

  • cusolverMgPotrs: Solves the system of linear equations $Ax=b$, where $A$ is an $N\times N$ symmetric (Hermitian) positive-definite matrix, via a Cholesky decomposition.
  • cusolverMgPotri: Computes the inverse of an $N\times N$ symmetric (Hermitian) positive-definite matrix via a Cholesky decomposition.
  • cusolverMgSyevd: Computes the eigenvalues and eigenvectors of an $N\times N$ symmetric (Hermitian) matrix.

For more details, see the API.
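For orientation, the three operations listed above can be written on a single device with stock JAX routines. The sketch below is not the JAXMg API: it uses only `jax.scipy.linalg.cho_factor`, `jax.scipy.linalg.cho_solve` and `jnp.linalg.eigh`, and the helper names `make_spd` and `reference_ops` are made up for illustration. It may be useful as a small-scale reference when checking multi-GPU results.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve


def make_spd(n, seed=0):
    """Build a well-conditioned symmetric positive-definite test matrix."""
    m = jax.random.normal(jax.random.PRNGKey(seed), (n, n))
    return m @ m.T + n * jnp.eye(n)


@jax.jit
def reference_ops(a, b):
    """Single-device counterparts of the solve, inverse and eigendecomposition."""
    factor = cho_factor(a, lower=True)           # Cholesky factorisation of A
    x = cho_solve(factor, b)                     # solve A x = b
    identity = jnp.eye(a.shape[0], dtype=a.dtype)
    a_inv = cho_solve(factor, identity)          # inverse via the same factor
    eigvals, eigvecs = jnp.linalg.eigh(a)        # symmetric eigendecomposition
    return x, a_inv, eigvals, eigvecs


a = make_spd(8)
b = jnp.ones((8, 1))
x, a_inv, eigvals, eigvecs = reference_ops(a, b)
print("max residual of the solve:", jnp.max(jnp.abs(a @ x - b)))
```

JAX defaults to single precision, so the residual is small but not at double-precision level unless `jax_enable_x64` is switched on.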

The provided binary is compiled with:

Component   Version
GCC         11.5.0
CUDA        12.8.0
cuDNN       9.2.0.82-12

NOTE: We require JAX>=0.6.0, since it ships with CUDA 12.x binaries, which this package relies on. No local version of CUDA is required.
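A quick way to check the version requirement before installing is sketched below. The helper names (`parse_version`, `meets_minimum`, `report_environment`) are hypothetical and not part of JAXMg; the GPU count relies only on `jax.devices()` and the `platform` attribute of each device.

```python
import re

MIN_JAX = (0, 6, 0)  # minimum JAX version stated in the note above


def parse_version(version):
    """Return the leading (major, minor, patch) integers of a version string."""
    match = re.match(r"(\d+)\.(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognised version string: {version!r}")
    return tuple(int(group) for group in match.groups())


def meets_minimum(version, minimum=MIN_JAX):
    """True if `version` is at least `minimum`."""
    return parse_version(version) >= minimum


def report_environment():
    """Print the installed JAX version and the number of visible GPUs."""
    import jax  # imported lazily so the helpers above work without JAX

    gpus = [d for d in jax.devices() if d.platform in ("gpu", "cuda")]
    print(f"JAX {jax.__version__}, minimum met: {meets_minimum(jax.__version__)}")
    print(f"visible GPUs: {len(gpus)}")
    return gpus
```

Calling `report_environment()` in the target environment shows whether the installed JAX satisfies the requirement and how many GPUs it can see.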

Installation

Install the package with pip:

pip install jaxmg

This also installs a GPU-compatible version of JAX.

To verify the installation (requires at least one GPU), run the test suite from the root of a clone of the repository:

pytest 

There are two types of tests:

  1. SPMD tests: Single Process, Multiple GPU tests.
  2. MPMD tests: Multiple Process, Multiple GPU tests.
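To illustrate the single-process, multiple-GPU setting, here is a minimal sketch of how one process can spread a matrix over all of its visible devices using JAX's sharding utilities. It is not taken from the JAXMg test suite. The helper `shard_rows`, the mesh axis name `"x"` and the choice of sharding by rows are assumptions made for illustration; this README does not state which layout JAXMg expects.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P


def shard_rows(a):
    """Split the rows of `a` evenly across all devices visible to this process."""
    devices = np.array(jax.devices())          # 1D array of devices
    mesh = Mesh(devices, axis_names=("x",))    # 1D device mesh with axis "x"
    sharding = NamedSharding(mesh, P("x", None))  # shard rows, replicate columns
    return jax.device_put(a, sharding)


n_dev = jax.device_count()
a = jnp.eye(4 * n_dev)          # row count is divisible by the device count
a_sharded = shard_rows(a)

# Each device holds one contiguous block of 4 rows.
for shard in a_sharded.addressable_shards:
    print(shard.device, shard.data.shape)
```

The sketch also runs on a machine with a single device (including CPU-only), in which case the one shard is the whole matrix.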

cuSolverMp

As of CUDA 13, there is a new distributed linear algebra library called cuSolverMp, with capabilities similar to those of cuSolverMg. Unlike cuSolverMg, it supports multi-node computations and more than 16 devices. Given the similarities in syntax, it should be straightforward to eventually switch to this API. Doing so will require sharding data into a 2D block-cyclic layout and handling the solver orchestration with MPI.
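As a rough picture of what a 2D block-cyclic layout means, the sketch below maps a global matrix index to the process that owns it and to the local index on that process. It is plain Python and does not call cuSolverMp or MPI; the function names, block sizes and process-grid shape are illustrative assumptions.

```python
def owner(i, j, mb, nb, p_rows, p_cols):
    """Process-grid coordinates owning global element (i, j).

    mb, nb         : block size along rows / columns
    p_rows, p_cols : shape of the process grid
    """
    return ((i // mb) % p_rows, (j // nb) % p_cols)


def local_index(g, block, nprocs):
    """Local index, on the owning process, of global index g along one dimension."""
    return (g // (block * nprocs)) * block + g % block


# Ownership map of an 8x8 matrix with 2x2 blocks on a 2x2 process grid.
for i in range(8):
    print(" ".join("{}{}".format(*owner(i, j, 2, 2, 2, 2)) for j in range(8)))
```

Blocks are dealt out to the process grid in round-robin fashion along both dimensions, so each process ends up with a similar share of the matrix regardless of where the work concentrates during the factorization.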

Citations

(Citation details will be available soon.)

Acknowledgements

I acknowledge support from the Flatiron Institute. The Flatiron Institute is a division of the Simons Foundation.