Fix for hipsparse in ROCm.
reza-amd committed Mar 25, 2022
1 parent 58efa00 commit 8cd0294
Showing 6 changed files with 42 additions and 91 deletions.
14 changes: 3 additions & 11 deletions jax/_src/lax/linalg.py
@@ -37,10 +37,9 @@
 from jax._src.lib import cuda_linalg
 from jax._src.lib import cusolver
-from jax._src.lib import cusparse
 from jax._src.lib import hip_linalg
 from jax._src.lib import hipsolver
-from jax._src.lib import hipsparse
+from jax._src.lib import sparse_apis

 from jax._src.lib import xla_client

@@ -1406,22 +1405,15 @@ def svd_batching_rule(batched_args, batch_dims, full_matrices, compute_uv):

 def _tridiagonal_solve_gpu_translation_rule(ctx, avals_in, avals_out, dl, d, du,
                                             b, *, m, n, ldb, t):
-  if cusparse:
-    return [cusparse.gtsv2(ctx.builder, dl, d, du, b, m=m, n=n, ldb=ldb, t=t)]
-  if hipsparse:
-    return [hipsparse.gtsv2(ctx.builder, dl, d, du, b, m=m, n=n, ldb=ldb, t=t)]
+  return [sparse_apis.gtsv2(ctx.builder, dl, d, du, b, m=m, n=n, ldb=ldb, t=t)]

 tridiagonal_solve_p = Primitive('tridiagonal_solve')
 tridiagonal_solve_p.multiple_results = False
 tridiagonal_solve_p.def_impl(
     functools.partial(xla.apply_primitive, tridiagonal_solve_p))
 tridiagonal_solve_p.def_abstract_eval(lambda dl, d, du, b, *, m, n, ldb, t: b)
 # TODO(tomhennigan): Consider AD rules using lax.custom_linear_solve?
-if cusparse is not None and hasattr(cusparse, "gtsv2"):
-  xla.register_translation(tridiagonal_solve_p,
-                           _tridiagonal_solve_gpu_translation_rule,
-                           platform='gpu')
-if hipsparse is not None and hasattr(hipsparse, "gtsv2"):
+if sparse_apis and hasattr(sparse_apis, "gtsv2"):
   xla.register_translation(tridiagonal_solve_p,
                            _tridiagonal_solve_gpu_translation_rule,
                            platform='gpu')
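The new guard above registers the GPU rule by feature detection: it fires only when a sparse backend module loaded and actually exposes `gtsv2`. A minimal runnable sketch of that pattern follows; the registry, the `SimpleNamespace` backend stand-in, and the lambda rule are illustrative, not JAX's real internals.

```python
# Sketch: conditional registration keyed on backend presence + symbol check.
# `sparse_apis` is faked here; in JAX it is cusparse, hipsparse, or None,
# depending on which jaxlib extension imported.
from types import SimpleNamespace

_translations = {}  # (primitive_name, platform) -> lowering rule

def register_translation(prim, rule, platform="cpu"):
    _translations[(prim, platform)] = rule

sparse_apis = SimpleNamespace(gtsv2=lambda *args, **kw: "gtsv2 lowering")

# Older backend builds may lack gtsv2, so check the attribute, not just the module.
if sparse_apis and hasattr(sparse_apis, "gtsv2"):
    register_translation(
        "tridiagonal_solve",
        lambda ctx, *args, **kw: [sparse_apis.gtsv2(ctx, *args, **kw)],
        platform="gpu")

assert ("tridiagonal_solve", "gpu") in _translations
```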
2 changes: 2 additions & 0 deletions jax/_src/lib/__init__.py
@@ -126,6 +126,8 @@ def _parse_version(v: str) -> Tuple[int, ...]:
 except ImportError:
   hipsparse = None

+sparse_apis = cusparse or hipsparse or None
+
 try:
   import jaxlib.cuda_prng as cuda_prng  # pytype: disable=import-error
 except ImportError:
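The added one-liner is the heart of the change: a successfully imported module object is always truthy, so `or`-chaining selects the first backend that loaded, and the trailing `or None` normalizes "nothing found" to an explicit None. A self-contained sketch of the pattern, mirroring the try/except import style visible in this file (the guards mean it runs even without jaxlib installed):

```python
# First-available-backend alias: cusparse wins when both are present,
# hipsparse is used on ROCm builds, and None signals "no sparse backend".
try:
    import jaxlib.cusparse as cusparse  # pytype: disable=import-error
except ImportError:
    cusparse = None

try:
    import jaxlib.hipsparse as hipsparse  # pytype: disable=import-error
except ImportError:
    hipsparse = None

sparse_apis = cusparse or hipsparse or None
print("active sparse backend:", getattr(sparse_apis, "__name__", None))
```

Downstream code can then branch once on `sparse_apis` instead of repeating the cusparse/hipsparse pair everywhere, which is exactly what the remaining files in this commit do.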
12 changes: 7 additions & 5 deletions jax/experimental/sparse/bcoo.py
@@ -38,10 +38,11 @@
 from jax._src.lax.lax import (
   ranges_like, remaining, _dot_general_batch_dim_nums, _dot_general_shape_rule,
   DotDimensionNumbers)
-from jax._src.lib import cusparse
 from jax._src.lib import xla_client as xc
 from jax._src.numpy.setops import _unique

+from jax._src.lib import sparse_apis
+
 xops = xc._xla.ops

 Dtype = Any
@@ -687,11 +688,12 @@ def _bcoo_dot_general_cuda_translation_rule(
   assert lhs_data_aval.dtype in [np.float32, np.float64, np.complex64, np.complex128]
   assert lhs_data_aval.dtype == rhs_aval.dtype
   assert lhs_indices_aval.dtype == np.int32
+  assert sparse_apis is not None

   if rhs_ndim == 1:
-    bcoo_dot_general_fn = cusparse.coo_matvec
+    bcoo_dot_general_fn = sparse_apis.coo_matvec
   elif rhs_ndim == 2:
-    bcoo_dot_general_fn = cusparse.coo_matmat
+    bcoo_dot_general_fn = sparse_apis.coo_matmat
     if rhs_contract[0] == 1:
       rhs = xops.Transpose(rhs, permutation=[1, 0])
     else:
@@ -751,7 +753,7 @@ def _bcoo_dot_general_gpu_translation_rule(

   dtype = lhs_data_aval.dtype
   if dtype not in [np.float32, np.float64, np.complex64, np.complex128]:
-    warnings.warn(f'bcoo_dot_general cusparse lowering not available for '
+    warnings.warn(f'bcoo_dot_general cusparse/hipsparse lowering not available for '
                   f'dtype={dtype}. Falling back to default implementation.',
                   CuSparseEfficiencyWarning)
     return _bcoo_dot_general_default_translation_rule(
@@ -842,7 +844,7 @@ def _bcoo_dot_general_batch_rule(batched_args, batch_dims, *, dimension_numbers,

 xla.register_translation(
     bcoo_dot_general_p, _bcoo_dot_general_default_translation_rule)
-if cusparse and cusparse.is_supported:
+if sparse_apis and sparse_apis.is_supported:
   xla.register_translation(bcoo_dot_general_p,
                            _bcoo_dot_general_gpu_translation_rule,
                            platform='gpu')
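Two patterns recur in the rules above: a dtype check that warns and falls back to the dense lowering, and rank-based selection of a matvec vs. matmat kernel from whichever backend `sparse_apis` aliases. A rough, self-contained sketch of that dispatch; the backend is faked with `SimpleNamespace`, and `pick_kernel` and its names are illustrative, not JAX API:

```python
import warnings
from types import SimpleNamespace

import numpy as np

# Stand-in for the cusparse/hipsparse module selected in jax/_src/lib.
sparse_apis = SimpleNamespace(coo_matvec=lambda: "matvec kernel",
                              coo_matmat=lambda: "matmat kernel")

_SUPPORTED = {np.dtype(t) for t in (np.float32, np.float64,
                                    np.complex64, np.complex128)}

def pick_kernel(rhs_ndim, dtype):
    if np.dtype(dtype) not in _SUPPORTED:
        warnings.warn(f"sparse lowering not available for dtype={dtype}. "
                      "Falling back to default implementation.")
        return None  # caller would use the dense default translation rule
    assert sparse_apis is not None
    if rhs_ndim == 1:
        return sparse_apis.coo_matvec   # sparse-matrix @ vector
    elif rhs_ndim == 2:
        return sparse_apis.coo_matmat   # sparse-matrix @ matrix
    raise ValueError(f"unsupported rhs_ndim={rhs_ndim}")

print(pick_kernel(2, np.float32)())  # -> matmat kernel
print(pick_kernel(1, np.int32))      # warns, -> None
```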
44 changes: 11 additions & 33 deletions jax/experimental/sparse/coo.py
@@ -27,19 +27,11 @@
 from jax.experimental.sparse._base import JAXSparse
 from jax.experimental.sparse.util import _coo_extract, _safe_asarray, CuSparseEfficiencyWarning
 from jax import tree_util
+from jax._src.lib import sparse_apis
 from jax._src.numpy.lax_numpy import _promote_dtypes
 from jax._src.lib import xla_client
 import jax.numpy as jnp

-try:
-  from jax._src.lib import cusparse
-except ImportError:
-  cusparse = None
-
-try:
-  from jax._src.lib import hipsparse
-except ImportError:
-  hipsparse = None

 xops = xla_client.ops

@@ -186,10 +178,8 @@ def _coo_todense_gpu_translation_rule(ctx, avals_in, avals_out, data, row, col,
     return _coo_todense_translation_rule(ctx, avals_in, avals_out, data, row, col,
                                          spinfo=spinfo)

-  if cusparse is not None:
-    result = cusparse.coo_todense(ctx.builder, data, row, col, shape=shape)
-  else:
-    result = hipsparse.coo_todense(ctx.builder, data, row, col, shape=spinfo.shape)
+  result = sparse_apis.coo_todense(ctx.builder, data, row, col, shape=shape)

   return [xops.Transpose(result, (1, 0))] if transpose else [result]

@@ -210,7 +200,7 @@ def _coo_todense_transpose(ct, data, row, col, *, spinfo):
 ad.defjvp(coo_todense_p, _coo_todense_jvp, None, None)
 ad.primitive_transposes[coo_todense_p] = _coo_todense_transpose
 xla.register_translation(coo_todense_p, _coo_todense_translation_rule)
-if (cusparse and cusparse.is_supported) or (hipsparse and hipsparse.is_supported):
+if sparse_apis and sparse_apis.is_supported:
   xla.register_translation(coo_todense_p, _coo_todense_gpu_translation_rule,
                            platform='gpu')

@@ -284,12 +274,8 @@ def _coo_fromdense_gpu_translation_rule(ctx, avals_in, avals_out, mat, *, nse,
                   "Falling back to default implementation.", CuSparseEfficiencyWarning)
     return _coo_fromdense_translation_rule(ctx, avals_in, avals_out, mat,
                                            nse=nse, index_dtype=index_dtype)
-  if cusparse is not None:
-    data, row, col = cusparse.coo_fromdense(
-        ctx.builder, mat, nnz=nse, index_dtype=np.dtype(index_dtype))
-  else:
-    data, row, col = hipsparse.coo_fromdense(
-        ctx.builder, mat, nnz=nse, index_dtype=np.dtype(index_dtype))
+  data, row, col = sparse_apis.coo_fromdense(
+      ctx.builder, mat, nnz=nse, index_dtype=np.dtype(index_dtype))
   return [data, row, col]

 def _coo_fromdense_jvp(primals, tangents, *, nse, index_dtype):
@@ -321,7 +307,7 @@ def _coo_fromdense_transpose(ct, M, *, nse, index_dtype):
 ad.primitive_transposes[coo_fromdense_p] = _coo_fromdense_transpose

 xla.register_translation(coo_fromdense_p, _coo_fromdense_translation_rule)
-if (cusparse and cusparse.is_supported) or (hipsparse and hipsparse.is_supported):
+if sparse_apis and sparse_apis.is_supported:
   xla.register_translation(coo_fromdense_p,
                            _coo_fromdense_gpu_translation_rule,
                            platform='gpu')
@@ -411,11 +397,7 @@ def _coo_matvec_gpu_translation_rule(ctx, avals_in, avals_out, data, row, col,
     return _coo_matvec_translation_rule(ctx, avals_in, avals_out, data, row, col, v,
                                         spinfo=spinfo, transpose=transpose)

-  if cusparse is not None:
-    return [cusparse.coo_matvec(ctx.builder, data, row, col, v, shape=shape,
-                                transpose=transpose)]
-  else:
-    return [hipsparse.coo_matvec(ctx.builder, data, row, col, v, shape=shape,
-                                 transpose=transpose)]
+  return [sparse_apis.coo_matvec(ctx.builder, data, row, col, v, shape=shape,
+                                 transpose=transpose)]

 def _coo_matvec_jvp_mat(data_dot, data, row, col, v, *, spinfo, transpose):
@@ -439,7 +421,7 @@ def _coo_matvec_transpose(ct, data, row, col, v, *, spinfo, transpose):
 ad.defjvp(coo_matvec_p, _coo_matvec_jvp_mat, None, None, _coo_matvec_jvp_vec)
 ad.primitive_transposes[coo_matvec_p] = _coo_matvec_transpose
 xla.register_translation(coo_matvec_p, _coo_matvec_translation_rule)
-if (cusparse and cusparse.is_supported) or (hipsparse and hipsparse.is_supported):
+if sparse_apis and sparse_apis.is_supported:
   xla.register_translation(coo_matvec_p, _coo_matvec_gpu_translation_rule,
                            platform='gpu')

@@ -526,11 +508,7 @@ def _coo_matmat_gpu_translation_rule(ctx, avals_in, avals_out, data, row, col,
     return _coo_matmat_translation_rule(ctx, avals_in, avals_out, data, row, col, B,
                                         spinfo=spinfo, transpose=transpose)

-  if cusparse is not None:
-    return [cusparse.coo_matmat(ctx.builder, data, row, col, B, shape=shape,
-                                transpose=transpose)]
-  else:
-    return [hipsparse.coo_matmat(ctx.builder, data, row, col, B, shape=shape,
-                                 transpose=transpose)]
+  return [sparse_apis.coo_matmat(ctx.builder, data, row, col, B, shape=shape,
+                                 transpose=transpose)]

 def _coo_matmat_jvp_left(data_dot, data, row, col, B, *, spinfo, transpose):
@@ -551,6 +529,6 @@ def _coo_matmat_transpose(ct, data, row, col, B, *, spinfo, transpose):
 ad.defjvp(coo_matmat_p, _coo_matmat_jvp_left, None, None, _coo_matmat_jvp_right)
 ad.primitive_transposes[coo_matmat_p] = _coo_matmat_transpose
 xla.register_translation(coo_matmat_p, _coo_matmat_translation_rule)
-if (cusparse and cusparse.is_supported) or (hipsparse and hipsparse.is_supported):
+if sparse_apis and sparse_apis.is_supported:
   xla.register_translation(coo_matmat_p, _coo_matmat_gpu_translation_rule,
                            platform='gpu')
47 changes: 12 additions & 35 deletions jax/experimental/sparse/csr.py
@@ -27,18 +27,10 @@
 from jax.experimental.sparse.coo import _coo_matmat, _coo_matvec, _coo_todense, COOInfo
 from jax.experimental.sparse.util import _csr_to_coo, _csr_extract, _safe_asarray, CuSparseEfficiencyWarning
 from jax import tree_util
+from jax._src.lib import sparse_apis
 from jax._src.numpy.lax_numpy import _promote_dtypes
 import jax.numpy as jnp

-try:
-  from jax._src.lib import cusparse
-except ImportError:
-  cusparse = None
-
-try:
-  from jax._src.lib import hipsparse
-except ImportError:
-  hipsparse = None

 @tree_util.register_pytree_node_class
 class CSR(JAXSparse):
@@ -190,10 +182,7 @@ def _csr_todense_gpu_translation_rule(ctx, avals_in, avals_out, data, indices,
                   "Falling back to default implementation.", CuSparseEfficiencyWarning)
     return _csr_todense_translation_rule(ctx, avals_in, avals_out, data, indices,
                                          indptr, shape=shape)
-  if cusparse:
-    return [cusparse.csr_todense(ctx.builder, data, indices, indptr, shape=shape)]
-  else:
-    return [hipsparse.csr_todense(ctx.builder, data, indices, indptr, shape=shape)]
+  return [sparse_apis.csr_todense(ctx.builder, data, indices, indptr, shape=shape)]

 def _csr_todense_jvp(data_dot, data, indices, indptr, *, shape):
   return csr_todense(data_dot, indices, indptr, shape=shape)
@@ -212,7 +201,7 @@ def _csr_todense_transpose(ct, data, indices, indptr, *, shape):
 ad.defjvp(csr_todense_p, _csr_todense_jvp, None, None)
 ad.primitive_transposes[csr_todense_p] = _csr_todense_transpose
 xla.register_translation(csr_todense_p, _csr_todense_translation_rule)
-if (cusparse and cusparse.is_supported) or (hipsparse and hipsparse.is_supported):
+if sparse_apis and sparse_apis.is_supported:
   xla.register_translation(csr_todense_p, _csr_todense_gpu_translation_rule,
                            platform='gpu')

@@ -274,12 +263,8 @@ def _csr_fromdense_gpu_translation_rule(ctx, avals_in, avals_out, mat, *, nse,
                   "Falling back to default implementation.", CuSparseEfficiencyWarning)
     return _csr_fromdense_translation_rule(ctx, avals_in, avals_out, mat,
                                            nse=nse, index_dtype=index_dtype)
-  if cusparse:
-    data, indices, indptr = cusparse.csr_fromdense(
-        ctx.builder, mat, nnz=nse, index_dtype=np.dtype(index_dtype))
-  else:
-    data, indices, indptr = hipsparse.csr_fromdense(
-        ctx.builder, mat, nnz=nse, index_dtype=np.dtype(index_dtype))
+  data, indices, indptr = sparse_apis.csr_fromdense(
+      ctx.builder, mat, nnz=nse, index_dtype=np.dtype(index_dtype))
   return [data, indices, indptr]

 def _csr_fromdense_jvp(primals, tangents, *, nse, index_dtype):
@@ -310,7 +295,7 @@ def _csr_fromdense_transpose(ct, M, *, nse, index_dtype):
 ad.primitive_jvps[csr_fromdense_p] = _csr_fromdense_jvp
 ad.primitive_transposes[csr_fromdense_p] = _csr_fromdense_transpose
 xla.register_translation(csr_fromdense_p, _csr_fromdense_translation_rule)
-if (cusparse and cusparse.is_supported) or (hipsparse and hipsparse.is_supported):
+if sparse_apis and sparse_apis.is_supported:
   xla.register_translation(csr_fromdense_p,
                            _csr_fromdense_gpu_translation_rule,
                            platform='gpu')
@@ -366,12 +351,8 @@ def _csr_matvec_gpu_translation_rule(ctx, avals_in, avals_out, data, indices,
                   "Falling back to default implementation.", CuSparseEfficiencyWarning)
     return _csr_matvec_translation_rule(ctx, avals_in, avals_out, data, indices, indptr, v,
                                         shape=shape, transpose=transpose)
-  if cusparse:
-    return [cusparse.csr_matvec(ctx.builder, data, indices, indptr, v,
-                                shape=shape, transpose=transpose)]
-  else:
-    return [hipsparse.csr_matvec(ctx.builder, data, indices, indptr, v,
-                                 shape=shape, transpose=transpose)]
+  return [sparse_apis.csr_matvec(ctx.builder, data, indices, indptr, v,
+                                 shape=shape, transpose=transpose)]

 def _csr_matvec_jvp_mat(data_dot, data, indices, indptr, v, *, shape, transpose):
   return csr_matvec(data_dot, indices, indptr, v, shape=shape, transpose=transpose)
@@ -395,7 +376,7 @@ def _csr_matvec_transpose(ct, data, indices, indptr, v, *, shape, transpose):
 ad.defjvp(csr_matvec_p, _csr_matvec_jvp_mat, None, None, _csr_matvec_jvp_vec)
 ad.primitive_transposes[csr_matvec_p] = _csr_matvec_transpose
 xla.register_translation(csr_matvec_p, _csr_matvec_translation_rule)
-if (cusparse and cusparse.is_supported) or (hipsparse and hipsparse.is_supported):
+if sparse_apis and sparse_apis.is_supported:
   xla.register_translation(csr_matvec_p, _csr_matvec_gpu_translation_rule,
                            platform='gpu')

@@ -452,12 +433,8 @@ def _csr_matmat_gpu_translation_rule(ctx, avals_in, avals_out, data, indices,
                   "Falling back to default implementation.", CuSparseEfficiencyWarning)
     return _csr_matmat_translation_rule(ctx, avals_in, avals_out, data, indices, indptr, B,
                                         shape=shape, transpose=transpose)
-  if cusparse is not None:
-    return [cusparse.csr_matmat(ctx.builder, data, indices, indptr, B,
-                                shape=shape, transpose=transpose)]
-  else:
-    return [hipsparse.csr_matmat(ctx.builder, data, indices, indptr, B,
-                                 shape=shape, transpose=transpose)]
+  return [sparse_apis.csr_matmat(ctx.builder, data, indices, indptr, B,
+                                 shape=shape, transpose=transpose)]

 def _csr_matmat_jvp_left(data_dot, data, indices, indptr, B, *, shape, transpose):
   return csr_matmat(data_dot, indices, indptr, B, shape=shape, transpose=transpose)
@@ -479,6 +456,6 @@ def _csr_matmat_transpose(ct, data, indices, indptr, B, *, shape, transpose):
 ad.defjvp(csr_matmat_p, _csr_matmat_jvp_left, None, None, _csr_matmat_jvp_right)
 ad.primitive_transposes[csr_matmat_p] = _csr_matmat_transpose
 xla.register_translation(csr_matmat_p, _csr_matmat_translation_rule)
-if (cusparse and cusparse.is_supported) or (hipsparse and hipsparse.is_supported):
+if sparse_apis and sparse_apis.is_supported:
   xla.register_translation(csr_matmat_p, _csr_matmat_gpu_translation_rule,
                            platform='gpu')
14 changes: 7 additions & 7 deletions tests/sparse_test.py
@@ -32,8 +32,7 @@
 from jax.experimental.sparse import coo as sparse_coo
 from jax.experimental.sparse.bcoo import BCOOInfo
 from jax import lax
-from jax._src.lib import cusparse
-from jax._src.lib import hipsparse
+from jax._src.lib import sparse_apis
 from jax._src.lib import xla_bridge
 from jax import jit
 from jax import tree_util
@@ -56,7 +55,7 @@
   np.complex128: 1E-10,
 }

-GPU_LOWERING_ENABLED = (cusparse and cusparse.is_supported) or (hipsparse and hipsparse.is_supported)
+GPU_LOWERING_ENABLED = (sparse_apis and sparse_apis.is_supported)

 class BcooDotGeneralProperties(NamedTuple):
   lhs_shape: Tuple[int]
@@ -441,7 +440,8 @@ def test_coo_sorted_indices(self):
     mat_resorted = mat_unsorted._sort_rows()
     self.assertArraysEqual(mat.todense(), mat_resorted.todense())

-  @unittest.skipIf(not GPU_LOWERING_ENABLED, "test requires cusparse")
+  @unittest.skipIf(not GPU_LOWERING_ENABLED, "test requires cusparse/hipsparse")
+  @jtu.skip_on_devices("rocm")  # TODO(rocm): see SWDEV-328107
   def test_coo_sorted_indices_gpu_lowerings(self):
     dtype = jnp.float32

@@ -510,15 +510,15 @@ def test_gpu_translation_rule(self):
       cuda_version = None if version == "<unknown>" else int(
           version.split()[-1])
       if cuda_version is None or cuda_version < 11000:
-        self.assertFalse(cusparse and cusparse.is_supported)
+        self.assertFalse(sparse_apis and sparse_apis.is_supported)
         self.assertNotIn(sparse.csr_todense_p,
                          xla._backend_specific_translations["gpu"])
       else:
-        self.assertTrue(cusparse and cusparse.is_supported)
+        self.assertTrue(sparse_apis and sparse_apis.is_supported)
         self.assertIn(sparse.csr_todense_p,
                       xla._backend_specific_translations["gpu"])
     else:
-      self.assertTrue(hipsparse and hipsparse.is_supported)
+      self.assertTrue(sparse_apis and sparse_apis.is_supported)
       self.assertIn(sparse.csr_todense_p,
                     xla._backend_specific_translations["gpu"])
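The tests now key every backend assertion off the single `GPU_LOWERING_ENABLED` flag instead of separate cusparse/hipsparse checks. A small sketch of that guard pattern in plain unittest; the `sparse_apis` value is faked here, whereas the real one comes from `jax._src.lib`:

```python
import unittest

sparse_apis = None  # pretend neither cusparse nor hipsparse imported
GPU_LOWERING_ENABLED = bool(sparse_apis and sparse_apis.is_supported)

class SparseLoweringTest(unittest.TestCase):
    # Skipped unless a sparse GPU backend (cusparse or hipsparse) loaded.
    @unittest.skipIf(not GPU_LOWERING_ENABLED, "test requires cusparse/hipsparse")
    def test_coo_sorted_indices_gpu_lowerings(self):
        self.assertTrue(GPU_LOWERING_ENABLED)

if __name__ == "__main__":
    unittest.main()
```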

