Skip to content

Commit

Permalink
Consolidation of hipsolver/cusolver APIs.
Browse files Browse the repository at this point in the history
  • Loading branch information
rsanthanam-amd committed Apr 7, 2022
1 parent 832d9aa commit 6c560b1
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 81 deletions.
101 changes: 20 additions & 81 deletions jax/_src/lax/linalg.py
Expand Up @@ -38,10 +38,9 @@
from jax._src.lib import lapack

from jax._src.lib import cuda_linalg
from jax._src.lib import cusolver
from jax._src.lib import hip_linalg
from jax._src.lib import hipsolver
from jax._src.lib import sparse_apis
from jax._src.lib import solver_apis

from jax._src.lib import xla_client

Expand Down Expand Up @@ -365,16 +364,10 @@ def _cholesky_cpu_gpu_translation_rule(potrf_impl, ctx, avals_in, avals_out,
partial(_cholesky_cpu_gpu_translation_rule, lapack.potrf),
platform='cpu')

if cusolver is not None:
if solver_apis is not None:
xla.register_translation(
cholesky_p,
partial(_cholesky_cpu_gpu_translation_rule, cusolver.potrf),
platform='gpu')

if hipsolver is not None:
xla.register_translation(
cholesky_p,
partial(_cholesky_cpu_gpu_translation_rule, hipsolver.potrf),
partial(_cholesky_cpu_gpu_translation_rule, solver_apis.potrf),
platform='gpu')

def _cholesky_cpu_gpu_lowering(potrf_impl, ctx, operand):
Expand All @@ -398,16 +391,10 @@ def _cholesky_cpu_gpu_lowering(potrf_impl, ctx, operand):
partial(_cholesky_cpu_gpu_lowering, lapack.potrf_mhlo),
platform='cpu')

if cusolver is not None:
if solver_apis is not None:
mlir.register_lowering(
cholesky_p,
partial(_cholesky_cpu_gpu_lowering, cusolver.potrf_mhlo),
platform='gpu')

if hipsolver is not None:
mlir.register_lowering(
cholesky_p,
partial(_cholesky_cpu_gpu_lowering, hipsolver.potrf_mhlo),
partial(_cholesky_cpu_gpu_lowering, solver_apis.potrf_mhlo),
platform='gpu')

# Asymmetric eigendecomposition
Expand Down Expand Up @@ -668,22 +655,13 @@ def eigh_batching_rule(batched_args, batch_dims, lower):
platform='cpu')


if cusolver is not None:
if solver_apis is not None:
xla.register_translation(
eigh_p, partial(_eigh_cpu_gpu_translation_rule, cusolver.syevd),
eigh_p, partial(_eigh_cpu_gpu_translation_rule, solver_apis.syevd),
platform='gpu')
if jax._src.lib.version >= (0, 3, 3):
mlir.register_lowering(
eigh_p, partial(_eigh_cpu_gpu_lowering, cusolver.syevd_mhlo),
platform='gpu')

if hipsolver is not None:
xla.register_translation(
eigh_p, partial(_eigh_cpu_gpu_translation_rule, hipsolver.syevd),
platform='gpu')
if jax._src.lib.version >= (0, 3, 3):
mlir.register_lowering(
eigh_p, partial(_eigh_cpu_gpu_lowering, hipsolver.syevd_mhlo),
eigh_p, partial(_eigh_cpu_gpu_lowering, solver_apis.syevd_mhlo),
platform='gpu')


Expand Down Expand Up @@ -914,27 +892,17 @@ def _triangular_solve_gpu_lower(
mhlo.TransposeAttr.get(transpose)).results


if cusolver is not None:
if solver_apis is not None:
xla.register_translation(
triangular_solve_p,
partial(_triangular_solve_gpu_translation_rule, cusolver.trsm),
partial(_triangular_solve_gpu_translation_rule, solver_apis.trsm),
platform='gpu')
if jax._src.lib.version >= (0, 3, 3):
mlir.register_lowering(
triangular_solve_p,
partial(_triangular_solve_gpu_lower, cusolver.trsm_mhlo),
partial(_triangular_solve_gpu_lower, solver_apis.trsm_mhlo),
platform='gpu')

if hipsolver is not None:
xla.register_translation(
triangular_solve_p,
partial(_triangular_solve_gpu_translation_rule, hipsolver.trsm),
platform='gpu')
if jax._src.lib.version >= (0, 3, 3):
mlir.register_lowering(
triangular_solve_p,
partial(_triangular_solve_gpu_lower, hipsolver.trsm_mhlo),
platform='gpu')

# Support operation for LU decomposition: Transformation of the pivots returned
# by LU decomposition into permutations.
Expand Down Expand Up @@ -1276,22 +1244,13 @@ def _lu_tpu_translation_rule(ctx, avals_in, avals_out, operand):
partial(_lu_cpu_gpu_lowering, lapack.getrf_mhlo),
platform='cpu')

if cusolver is not None:
if solver_apis is not None:
xla.register_translation(
lu_p, partial(_lu_cpu_gpu_translation_rule, cusolver.getrf),
lu_p, partial(_lu_cpu_gpu_translation_rule, solver_apis.getrf),
platform='gpu')
if jax._src.lib.version >= (0, 3, 3):
mlir.register_lowering(
lu_p, partial(_lu_cpu_gpu_lowering, cusolver.getrf_mhlo),
platform='gpu')

if hipsolver is not None:
xla.register_translation(
lu_p, partial(_lu_cpu_gpu_translation_rule, hipsolver.getrf),
platform='gpu')
if jax._src.lib.version >= (0, 3, 3):
mlir.register_lowering(
lu_p, partial(_lu_cpu_gpu_lowering, hipsolver.getrf_mhlo),
lu_p, partial(_lu_cpu_gpu_lowering, solver_apis.getrf_mhlo),
platform='gpu')

xla.register_translation(lu_p, _lu_tpu_translation_rule, platform='tpu')
Expand Down Expand Up @@ -1501,26 +1460,15 @@ def _qr_cpu_gpu_lowering(geqrf_impl, orgqr_impl, ctx, operand, *,
qr_p, partial(_qr_cpu_gpu_lowering, lapack.geqrf_mhlo, lapack.orgqr_mhlo),
platform='cpu')

if cusolver is not None:
xla.register_translation(
qr_p,
partial(_qr_cpu_gpu_translation_rule, cusolver.geqrf, cusolver.orgqr),
platform='gpu')
if jax._src.lib.version >= (0, 3, 3):
mlir.register_lowering(
qr_p,
partial(_qr_cpu_gpu_lowering, cusolver.geqrf_mhlo, cusolver.orgqr_mhlo),
platform='gpu')

if hipsolver is not None:
if solver_apis is not None:
xla.register_translation(
qr_p,
partial(_qr_cpu_gpu_translation_rule, hipsolver.geqrf, hipsolver.orgqr),
partial(_qr_cpu_gpu_translation_rule, solver_apis.geqrf, solver_apis.orgqr),
platform='gpu')
if jax._src.lib.version >= (0, 3, 3):
mlir.register_lowering(
qr_p,
partial(_qr_cpu_gpu_lowering, hipsolver.geqrf_mhlo, hipsolver.orgqr_mhlo),
partial(_qr_cpu_gpu_lowering, solver_apis.geqrf_mhlo, solver_apis.orgqr_mhlo),
platform='gpu')

# Singular value decomposition
Expand Down Expand Up @@ -1730,22 +1678,13 @@ def svd_batching_rule(batched_args, batch_dims, full_matrices, compute_uv):
svd_p, partial(_svd_cpu_gpu_lowering, lapack.gesdd_mhlo),
platform='cpu')

if cusolver is not None:
xla.register_translation(
svd_p, partial(_svd_cpu_gpu_translation_rule, cusolver.gesvd),
platform='gpu')
if jax._src.lib.version >= (0, 3, 3):
mlir.register_lowering(
svd_p, partial(_svd_cpu_gpu_lowering, cusolver.gesvd_mhlo),
platform='gpu')

if hipsolver is not None:
if solver_apis is not None:
xla.register_translation(
svd_p, partial(_svd_cpu_gpu_translation_rule, hipsolver.gesvd),
svd_p, partial(_svd_cpu_gpu_translation_rule, solver_apis.gesvd),
platform='gpu')
if jax._src.lib.version >= (0, 3, 3):
mlir.register_lowering(
svd_p, partial(_svd_cpu_gpu_lowering, hipsolver.gesvd_mhlo),
svd_p, partial(_svd_cpu_gpu_lowering, solver_apis.gesvd_mhlo),
platform='gpu')

def _tridiagonal_solve_gpu_translation_rule(ctx, avals_in, avals_out, dl, d, du,
Expand Down
1 change: 1 addition & 0 deletions jax/_src/lib/__init__.py
Expand Up @@ -127,6 +127,7 @@ def _parse_version(v: str) -> Tuple[int, ...]:
hipsparse = None

sparse_apis = cusparse or hipsparse or None
solver_apis = cusolver or hipsolver or None

try:
import jaxlib.cuda_prng as cuda_prng # pytype: disable=import-error
Expand Down

0 comments on commit 6c560b1

Please sign in to comment.