JAXMg provides a C++ interface between JAX and cuSolverMg, NVIDIA’s multi-GPU linear solver. We provide a jittable API for the following routines:
- cusolverMgPotrs: Solves the system of linear equations $Ax=b$, where $A$ is an $N\times N$ symmetric (Hermitian) positive-definite matrix, via a Cholesky decomposition.
- cusolverMgPotri: Computes the inverse of an $N\times N$ symmetric (Hermitian) positive-definite matrix via a Cholesky decomposition.
- cusolverMgSyevd: Computes the eigenvalues and eigenvectors of an $N\times N$ symmetric (Hermitian) matrix.
For more details, see the API.
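As a quick reference, the three routines correspond to the standard factorizations below. This is textbook linear algebra, not a description of cuSolverMg's internals; the symbols $L$, $y$, $V$, and $\Lambda$ do not appear elsewhere in this README and are introduced here only for illustration.

$$
A = L L^{H}, \qquad L y = b, \qquad L^{H} x = y \quad\Longrightarrow\quad A x = L L^{H} x = L y = b
$$

$$
A^{-1} = \left(L L^{H}\right)^{-1} = L^{-H} L^{-1}
$$

$$
A = V \Lambda V^{H}, \qquad V^{H} V = I, \qquad \Lambda = \mathrm{diag}(\lambda_1, \ldots, \lambda_N)
$$

Here $L$ is the lower-triangular Cholesky factor of the positive-definite matrix $A$, so the solve reduces to one forward and one backward triangular substitution. In the eigendecomposition, the columns of $V$ are the eigenvectors and the $\lambda_i$ are the (real) eigenvalues of the symmetric (Hermitian) matrix.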
The provided binary is compiled with:
| Component | Version |
|---|---|
| GCC | 11.5.0 |
| CUDA | 12.8.0 |
| cuDNN | 9.2.0.82-12 |
NOTE: We require JAX>=0.6.0, since it ships with CUDA 12.x binaries, which this package relies on. No local version of CUDA is required.
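The version requirement above can be checked before installing. The following is a minimal sketch using only the Python standard library; the helper names (`release_tuple`, `jax_is_new_enough`, `installed_jax_version`) are hypothetical and not part of JAXMg or JAX.

```python
import re
from importlib import metadata

# Minimum JAX release stated in the note above.
MIN_JAX = (0, 6, 0)


def release_tuple(version):
    """Turn '0.6.1.dev20250101' into (0, 6, 1), ignoring pre/dev/local suffixes."""
    match = re.match(r"\d+(?:\.\d+)*", version)
    if match is None:
        raise ValueError(f"unrecognised version string: {version!r}")
    parts = [int(p) for p in match.group(0).split(".")]
    parts += [0] * (3 - len(parts))  # pad short versions, e.g. '0.6' -> (0, 6, 0)
    return tuple(parts)


def jax_is_new_enough(version, minimum=MIN_JAX):
    """Return True if the given JAX version string satisfies the minimum."""
    return release_tuple(version) >= minimum


def installed_jax_version():
    """Return the installed JAX version string, or None if JAX is not installed."""
    try:
        return metadata.version("jax")
    except metadata.PackageNotFoundError:
        return None


if __name__ == "__main__":
    found = installed_jax_version()
    if found is None:
        print("JAX is not installed")
    elif jax_is_new_enough(found):
        print(f"JAX {found} satisfies >= 0.6.0")
    else:
        print(f"JAX {found} is older than 0.6.0")
```

The comparison only looks at the numeric release components, so a release candidate such as `0.6.0rc1` is treated as `0.6.0`.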
Clone the repository and install with `pip install jaxmg`. This will install a GPU-compatible version of JAX.
To verify the installation (requires at least one GPU), run `pytest`. There are two types of tests:
- SPMD tests: Single Process Multiple GPU tests.
- MPMD tests: Multiple Processes Multiple GPU tests.
As of CUDA 13, there is a new distributed linear algebra library called cuSolverMp, with capabilities similar to those of cuSolverMg. Unlike cuSolverMg, it supports multi-node computations and more than 16 devices. Given the similarities in syntax, it should be straightforward to eventually switch to this API. Doing so will require sharding the data into a 2D cyclic layout and handling the solver orchestration with MPI.
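To give a sense of what the 2D cyclic sharding involves, here is a minimal pure-Python sketch of the index arithmetic. It assumes that "cyclic 2D form" means the ScaLAPACK-style 2D block-cyclic layout with the first block placed on process (0, 0). The function `block_cyclic_owner` and its parameters are hypothetical; the sketch makes no cuSolverMp or MPI calls.

```python
def block_cyclic_owner(i, j, block_rows, block_cols, grid_rows, grid_cols):
    """Map global entry (i, j) to its owning process and local index.

    Assumes a 2D block-cyclic layout: the matrix is cut into blocks of
    shape (block_rows, block_cols), and blocks are dealt out round-robin
    over a (grid_rows, grid_cols) process grid, starting at process (0, 0).
    Returns ((p_row, p_col), (local_i, local_j)).
    """
    # Which global block does the entry fall in?
    bi, bj = i // block_rows, j // block_cols
    # Blocks are assigned cyclically along each grid dimension.
    p_row, p_col = bi % grid_rows, bj % grid_cols
    # Local position: number of earlier blocks this process owns along the
    # dimension, times the block size, plus the offset within the block.
    local_i = (bi // grid_rows) * block_rows + i % block_rows
    local_j = (bj // grid_cols) * block_cols + j % block_cols
    return (p_row, p_col), (local_i, local_j)


def owner_grid(n_rows, n_cols, block_rows, block_cols, grid_rows, grid_cols):
    """Return a nested list whose (i, j) entry is the owning process of A[i, j]."""
    return [
        [
            block_cyclic_owner(i, j, block_rows, block_cols, grid_rows, grid_cols)[0]
            for j in range(n_cols)
        ]
        for i in range(n_rows)
    ]


if __name__ == "__main__":
    # 8x8 matrix, 2x2 blocks, 2x2 process grid.
    for row in owner_grid(8, 8, 2, 2, 2, 2):
        print(" ".join(f"{p}{q}" for p, q in row))
```

With this layout every process holds blocks from all regions of the matrix, which keeps the work balanced as a factorization proceeds through the rows and columns.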
(Citation details will be available soon.)
I acknowledge support from the Flatiron Institute. The Flatiron Institute is a division of the Simons Foundation.
