From 6c560b14a7a1b9c2bb0d1bcf4db4efd30a7856df Mon Sep 17 00:00:00 2001 From: Rohit Santhanam Date: Wed, 6 Apr 2022 14:45:47 +0000 Subject: [PATCH] Consolidation of hipsolver/cusolver APIs. --- jax/_src/lax/linalg.py | 101 ++++++++------------------------------- jax/_src/lib/__init__.py | 1 + 2 files changed, 21 insertions(+), 81 deletions(-) diff --git a/jax/_src/lax/linalg.py b/jax/_src/lax/linalg.py index fc44c0656430..51f201568ed3 100644 --- a/jax/_src/lax/linalg.py +++ b/jax/_src/lax/linalg.py @@ -38,10 +38,9 @@ from jax._src.lib import lapack from jax._src.lib import cuda_linalg -from jax._src.lib import cusolver from jax._src.lib import hip_linalg -from jax._src.lib import hipsolver from jax._src.lib import sparse_apis +from jax._src.lib import solver_apis from jax._src.lib import xla_client @@ -365,16 +364,10 @@ def _cholesky_cpu_gpu_translation_rule(potrf_impl, ctx, avals_in, avals_out, partial(_cholesky_cpu_gpu_translation_rule, lapack.potrf), platform='cpu') -if cusolver is not None: +if solver_apis is not None: xla.register_translation( cholesky_p, - partial(_cholesky_cpu_gpu_translation_rule, cusolver.potrf), - platform='gpu') - -if hipsolver is not None: - xla.register_translation( - cholesky_p, - partial(_cholesky_cpu_gpu_translation_rule, hipsolver.potrf), + partial(_cholesky_cpu_gpu_translation_rule, solver_apis.potrf), platform='gpu') def _cholesky_cpu_gpu_lowering(potrf_impl, ctx, operand): @@ -398,16 +391,10 @@ def _cholesky_cpu_gpu_lowering(potrf_impl, ctx, operand): partial(_cholesky_cpu_gpu_lowering, lapack.potrf_mhlo), platform='cpu') - if cusolver is not None: + if solver_apis is not None: mlir.register_lowering( cholesky_p, - partial(_cholesky_cpu_gpu_lowering, cusolver.potrf_mhlo), - platform='gpu') - - if hipsolver is not None: - mlir.register_lowering( - cholesky_p, - partial(_cholesky_cpu_gpu_lowering, hipsolver.potrf_mhlo), + partial(_cholesky_cpu_gpu_lowering, solver_apis.potrf_mhlo), platform='gpu') # Asymmetric eigendecomposition @@ -668,22 +655,13 @@ def eigh_batching_rule(batched_args, batch_dims, lower): platform='cpu') -if cusolver is not None: +if solver_apis is not None: xla.register_translation( - eigh_p, partial(_eigh_cpu_gpu_translation_rule, cusolver.syevd), + eigh_p, partial(_eigh_cpu_gpu_translation_rule, solver_apis.syevd), platform='gpu') if jax._src.lib.version >= (0, 3, 3): mlir.register_lowering( - eigh_p, partial(_eigh_cpu_gpu_lowering, cusolver.syevd_mhlo), - platform='gpu') - -if hipsolver is not None: - xla.register_translation( - eigh_p, partial(_eigh_cpu_gpu_translation_rule, hipsolver.syevd), - platform='gpu') - if jax._src.lib.version >= (0, 3, 3): - mlir.register_lowering( - eigh_p, partial(_eigh_cpu_gpu_lowering, hipsolver.syevd_mhlo), + eigh_p, partial(_eigh_cpu_gpu_lowering, solver_apis.syevd_mhlo), platform='gpu') @@ -914,27 +892,17 @@ def _triangular_solve_gpu_lower( mhlo.TransposeAttr.get(transpose)).results -if cusolver is not None: +if solver_apis is not None: xla.register_translation( triangular_solve_p, - partial(_triangular_solve_gpu_translation_rule, cusolver.trsm), + partial(_triangular_solve_gpu_translation_rule, solver_apis.trsm), platform='gpu') if jax._src.lib.version >= (0, 3, 3): mlir.register_lowering( triangular_solve_p, - partial(_triangular_solve_gpu_lower, cusolver.trsm_mhlo), + partial(_triangular_solve_gpu_lower, solver_apis.trsm_mhlo), platform='gpu') -if hipsolver is not None: - xla.register_translation( - triangular_solve_p, - partial(_triangular_solve_gpu_translation_rule, hipsolver.trsm), - platform='gpu') - if jax._src.lib.version >= (0, 3, 3): - mlir.register_lowering( - triangular_solve_p, - partial(_triangular_solve_gpu_lower, hipsolver.trsm_mhlo), - platform='gpu') # Support operation for LU decomposition: Transformation of the pivots returned # by LU decomposition into permutations. @@ -1276,22 +1244,13 @@ def _lu_tpu_translation_rule(ctx, avals_in, avals_out, operand): partial(_lu_cpu_gpu_lowering, lapack.getrf_mhlo), platform='cpu') -if cusolver is not None: +if solver_apis is not None: xla.register_translation( - lu_p, partial(_lu_cpu_gpu_translation_rule, cusolver.getrf), + lu_p, partial(_lu_cpu_gpu_translation_rule, solver_apis.getrf), platform='gpu') if jax._src.lib.version >= (0, 3, 3): mlir.register_lowering( - lu_p, partial(_lu_cpu_gpu_lowering, cusolver.getrf_mhlo), - platform='gpu') - -if hipsolver is not None: - xla.register_translation( - lu_p, partial(_lu_cpu_gpu_translation_rule, hipsolver.getrf), - platform='gpu') - if jax._src.lib.version >= (0, 3, 3): - mlir.register_lowering( - lu_p, partial(_lu_cpu_gpu_lowering, hipsolver.getrf_mhlo), + lu_p, partial(_lu_cpu_gpu_lowering, solver_apis.getrf_mhlo), platform='gpu') xla.register_translation(lu_p, _lu_tpu_translation_rule, platform='tpu') @@ -1501,26 +1460,15 @@ def _qr_cpu_gpu_lowering(geqrf_impl, orgqr_impl, ctx, operand, *, qr_p, partial(_qr_cpu_gpu_lowering, lapack.geqrf_mhlo, lapack.orgqr_mhlo), platform='cpu') -if cusolver is not None: - xla.register_translation( - qr_p, - partial(_qr_cpu_gpu_translation_rule, cusolver.geqrf, cusolver.orgqr), - platform='gpu') - if jax._src.lib.version >= (0, 3, 3): - mlir.register_lowering( - qr_p, - partial(_qr_cpu_gpu_lowering, cusolver.geqrf_mhlo, cusolver.orgqr_mhlo), - platform='gpu') - -if hipsolver is not None: +if solver_apis is not None: xla.register_translation( qr_p, - partial(_qr_cpu_gpu_translation_rule, hipsolver.geqrf, hipsolver.orgqr), + partial(_qr_cpu_gpu_translation_rule, solver_apis.geqrf, solver_apis.orgqr), platform='gpu') if jax._src.lib.version >= (0, 3, 3): mlir.register_lowering( qr_p, - partial(_qr_cpu_gpu_lowering, hipsolver.geqrf_mhlo, hipsolver.orgqr_mhlo), + partial(_qr_cpu_gpu_lowering, solver_apis.geqrf_mhlo, solver_apis.orgqr_mhlo), platform='gpu') # Singular value decomposition @@ -1730,22 +1678,13 @@ def svd_batching_rule(batched_args, batch_dims, full_matrices, compute_uv): svd_p, partial(_svd_cpu_gpu_lowering, lapack.gesdd_mhlo), platform='cpu') -if cusolver is not None: - xla.register_translation( - svd_p, partial(_svd_cpu_gpu_translation_rule, cusolver.gesvd), - platform='gpu') - if jax._src.lib.version >= (0, 3, 3): - mlir.register_lowering( - svd_p, partial(_svd_cpu_gpu_lowering, cusolver.gesvd_mhlo), - platform='gpu') - -if hipsolver is not None: +if solver_apis is not None: xla.register_translation( - svd_p, partial(_svd_cpu_gpu_translation_rule, hipsolver.gesvd), + svd_p, partial(_svd_cpu_gpu_translation_rule, solver_apis.gesvd), platform='gpu') if jax._src.lib.version >= (0, 3, 3): mlir.register_lowering( - svd_p, partial(_svd_cpu_gpu_lowering, hipsolver.gesvd_mhlo), + svd_p, partial(_svd_cpu_gpu_lowering, solver_apis.gesvd_mhlo), platform='gpu') def _tridiagonal_solve_gpu_translation_rule(ctx, avals_in, avals_out, dl, d, du, diff --git a/jax/_src/lib/__init__.py b/jax/_src/lib/__init__.py index 5d3a3ae3a752..85bc3d481472 100644 --- a/jax/_src/lib/__init__.py +++ b/jax/_src/lib/__init__.py @@ -127,6 +127,7 @@ def _parse_version(v: str) -> Tuple[int, ...]: hipsparse = None sparse_apis = cusparse or hipsparse or None +solver_apis = cusolver or hipsolver or None try: import jaxlib.cuda_prng as cuda_prng # pytype: disable=import-error