Skip to content

Commit

Permalink
[sparse] Updates bcoo_dot_general cuSparse lowering rule by adding …
Browse files Browse the repository at this point in the history
…sorted indices.

PiperOrigin-RevId: 428621454
  • Loading branch information
tlu7 authored and jax authors committed Feb 14, 2022
1 parent 7204ac3 commit 273ea62
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 23 deletions.
19 changes: 14 additions & 5 deletions jax/experimental/sparse/bcoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,9 @@ def f(data, indices):
f = broadcasting_vmap(f)
return f(data, indices)

_bcoo_sort_indices_rule = xla.lower_fun(
_bcoo_sort_indices, multiple_results=True, new_style=True)

def _unbatch_bcoo(data, indices, shape):
n_batch = _validate_bcoo(data, indices, shape).n_batch
if n_batch == 0:
Expand Down Expand Up @@ -736,6 +739,11 @@ def _bcoo_dot_general_gpu_translation_rule(
ctx, avals_in, avals_out, lhs_data, lhs_indices, rhs, *, dimension_numbers,
lhs_spinfo: BCOOInfo):

if not config.jax_bcoo_cusparse_lowering:
return _bcoo_dot_general_default_translation_rule(
ctx, avals_in, avals_out, lhs_data, lhs_indices, rhs,
dimension_numbers=dimension_numbers, lhs_spinfo=lhs_spinfo)

(lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers
lhs_data_aval, lhs_indices_aval, rhs_aval, = avals_in
n_batch, n_sparse, n_dense, nse = _validate_bcoo(
Expand All @@ -757,10 +765,11 @@ def _bcoo_dot_general_gpu_translation_rule(
ctx, avals_in, avals_out, lhs_data, lhs_indices, rhs,
dimension_numbers=dimension_numbers, lhs_spinfo=lhs_spinfo)
else:
# The lhs indices are row-wise sorted.
lhs_indices_row, lhs_indices_col, lhs_data = lax.sort(
[lhs_indices[:, 0], lhs_indices[:, 1], lhs_data])
lhs_indices = jnp.hstack((lhs_indices_row, lhs_indices_col))
# Sorts lhs by row indices.
lhs_data, lhs_indices = _bcoo_sort_indices_rule(
ctx, avals_in[:2], avals_in[:2], lhs_data, lhs_indices,
shape=lhs_spinfo.shape)

return _bcoo_dot_general_cuda_translation_rule(
ctx, avals_in, avals_out, lhs_data, lhs_indices, rhs,
dimension_numbers=dimension_numbers, lhs_spinfo=lhs_spinfo)
Expand Down Expand Up @@ -833,7 +842,7 @@ def _bcoo_dot_general_batch_rule(batched_args, batch_dims, *, dimension_numbers,

xla.register_translation(
bcoo_dot_general_p, _bcoo_dot_general_default_translation_rule)
if config.jax_bcoo_cusparse_lowering and cusparse and cusparse.is_supported:
if cusparse and cusparse.is_supported:
xla.register_translation(bcoo_dot_general_p,
_bcoo_dot_general_gpu_translation_rule,
platform='gpu')
Expand Down
47 changes: 29 additions & 18 deletions tests/sparse_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
from jax.util import split_list
import numpy as np
import scipy.sparse

config.parse_flags_with_absl()
FLAGS = config.FLAGS

Expand Down Expand Up @@ -915,36 +916,46 @@ def create_unsorted_indices(data, indices):
lhs_2d_sparse, lhs_sparse_2d_indicdes = create_unsorted_indices(
lhs_2d_sparse, lhs_sparse_2d_indicdes)

dimension_numbers = (([1], [0]), ([], []))
expected_2d = lax.dot_general(
lhs_2d_dense, rhs, dimension_numbers=dimension_numbers)
actual_2d = sparse.bcoo_dot_general(
lhs_2d_sparse, lhs_sparse_2d_indicdes, rhs,
dimension_numbers=dimension_numbers,
lhs_spinfo=BCOOInfo(lhs_2d_dense.shape))
dimension_numbers_2d = (([1], [0]), ([], []))

def args_maker_2d():
return lhs_2d_sparse, lhs_sparse_2d_indicdes, lhs_2d_dense, rhs

def f_dense_2d(data, indices, lhs, rhs):
return lax.dot_general(lhs, rhs, dimension_numbers=dimension_numbers_2d)

def f_sparse_2d(data, indices, lhs, rhs):
return sparse.bcoo_dot_general(data, indices, rhs,
dimension_numbers=dimension_numbers_2d,
lhs_spinfo=BCOOInfo(lhs.shape))
with self.subTest(msg="2D"):
self.assertAllClose(expected_2d, actual_2d)
self._CompileAndCheck(f_sparse_2d, args_maker_2d)
self._CheckAgainstNumpy(f_dense_2d, f_sparse_2d, args_maker_2d)

# It creates out-of-bound indices when nse > nnz.
lhs_1d_dense = jnp.array([0, 1, 0, 2, 0], dtype=jnp.float32)
lhs_1d_sparse, lhs_sparse_1d_indicdes = sparse.bcoo_fromdense(
lhs_1d_dense, nse=7)
lhs_1d_dense, nse=5)

# Random permutate the indices to make them unsorted.
lhs_1d_sparse, lhs_sparse_1d_indicdes = create_unsorted_indices(
lhs_1d_sparse, lhs_sparse_1d_indicdes)

dimension_numbers = (([0], [0]), ([], []))
expected_1d = lax.dot_general(
lhs_1d_dense, rhs, dimension_numbers=dimension_numbers)
actual_1d = sparse.bcoo_dot_general(
lhs_1d_sparse, lhs_sparse_1d_indicdes, rhs,
dimension_numbers=dimension_numbers,
lhs_spinfo=BCOOInfo(lhs_1d_dense.shape))
dimension_numbers_1d = (([0], [0]), ([], []))

def args_maker_1d():
return lhs_1d_sparse, lhs_sparse_1d_indicdes, lhs_1d_dense, rhs

def f_dense_1d(data, indices, lhs, rhs):
return lax.dot_general(lhs, rhs, dimension_numbers=dimension_numbers_1d)

def f_sparse_1d(data, indices, lhs, rhs):
return sparse.bcoo_dot_general(data, indices, rhs,
dimension_numbers=dimension_numbers_1d,
lhs_spinfo=BCOOInfo(lhs.shape))

with self.subTest(msg="1D"):
self.assertAllClose(expected_1d, actual_1d)
self._CompileAndCheck(f_sparse_1d, args_maker_1d)
self._CheckAgainstNumpy(f_dense_1d, f_sparse_1d, args_maker_1d)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": props.testcase_name(), "props": props}
Expand Down

0 comments on commit 273ea62

Please sign in to comment.