[sparse] Lower batch-mode bcoo_dot_general to cusparseSpMM.
PiperOrigin-RevId: 473777597
tlu7 authored and jax authors committed Sep 12, 2022
1 parent 5c266ba commit 3243e23
Showing 3 changed files with 255 additions and 41 deletions.
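
As a hedged illustration (not part of the commit), the kind of computation this change accelerates is a batched sparse-times-dense product: the lhs is a BCOO array with one batch dimension (n_batch=1, n_sparse=2) and the rhs is a dense 2-D array. The BCOO/bcoo_dot_general API below is the public entry point; the shapes and values are made up:

import jax.numpy as jnp
from jax.experimental import sparse

lhs_dense = jnp.arange(24.0).reshape(2, 3, 4)      # (batch, rows, cols)
lhs = sparse.BCOO.fromdense(lhs_dense, n_batch=1)  # batched sparse operand
rhs = jnp.ones((4, 5))                             # dense, shared across batches

# Contract the last axis of lhs with the first axis of rhs; on GPU with
# CUDA >= 11.6.1 and jaxlib >= 0.3.18 this now lowers to cusparseSpMM.
out = sparse.bcoo_dot_general(
    lhs, rhs, dimension_numbers=(([2], [0]), ([], [])))
print(out.shape)  # (2, 3, 5)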
157 changes: 123 additions & 34 deletions jax/experimental/sparse/bcoo.py
@@ -42,6 +42,8 @@
_const, ranges_like, remaining, _dot_general_batch_dim_nums, _dot_general_shape_rule,
DotDimensionNumbers)
from jax._src.lib.mlir import ir
from jax._src.lib import xla_bridge
from jax._src.lib import version as jaxlib_version
from jax._src.lib.mlir.dialects import mhlo
from jax._src.numpy.setops import _unique

@@ -736,9 +738,8 @@ def _bcoo_dot_general_cuda_lowering(

# Checks the shapes of lhs and rhs.
assert props.n_dense == 0
assert props.n_batch == 0
assert props.n_sparse in [1, 2]
assert rhs_ndim in [1, 2]
assert (props.n_batch, props.n_sparse, rhs_ndim) in [
(0, 1, 1), (0, 1, 2), (0, 2, 1), (0, 2, 2), (1, 2, 2)]

# Checks the operation dimensions.
assert len(lhs_batch) == 0
@@ -761,54 +762,123 @@ def _bcoo_dot_general_cuda_lowering(
else:
raise ValueError(f"rhs has to be 1d or 2d; got {rhs_ndim}d.")

lhs_transpose = False
if props.n_sparse == 1:
# Converts lhs to a row vector.
col = _collapse_mhlo(lhs_indices, start=0, end=1)
row = mlir.full_like_aval(
0, core.ShapedArray(ir.RankedTensorType(col.type).shape,
np.dtype(np.int32)))
lhs_shape = (1, lhs_spinfo.shape[0])
dot_product = bcoo_dot_general_fn(
lhs_data, row, col, rhs, shape=lhs_shape, transpose=lhs_transpose,
data_dtype=lhs_data_aval.dtype, index_dtype=lhs_indices_aval.dtype,
x_dtype=rhs_aval.dtype)

if rhs_ndim == 1:
# Transforms a single-element array to a scalar.
return [mhlo.ReshapeOp(
ir.RankedTensorType.get(
[], ir.RankedTensorType(dot_product.type).element_type),
dot_product).result]
if props.n_batch == 0:
# non-batch mode.
lhs_transpose = False
if props.n_sparse == 1:
# Converts lhs to a row vector.
col = _collapse_mhlo(lhs_indices, start=0, end=1)
row = mlir.full_like_aval(
0, core.ShapedArray(ir.RankedTensorType(col.type).shape,
np.dtype(np.int32)))
lhs_shape = (1, lhs_spinfo.shape[0])
dot_product = bcoo_dot_general_fn(
lhs_data, row, col, rhs, shape=lhs_shape, transpose=lhs_transpose,
data_dtype=lhs_data_aval.dtype, index_dtype=lhs_indices_aval.dtype,
x_dtype=rhs_aval.dtype)

if rhs_ndim == 1:
# Transforms a single-element array to a scalar.
return [mhlo.ReshapeOp(
ir.RankedTensorType.get(
[], ir.RankedTensorType(dot_product.type).element_type),
dot_product).result]
else:
return [_collapse_mhlo(dot_product, start=0, end=1)]
elif props.n_sparse == 2:
lhs_indices_shape = ir.RankedTensorType(lhs_indices.type).shape
row = _collapse_mhlo(
mhlo.SliceOp(
lhs_indices,
start_indices=mlir.dense_int_elements([0, 0]),
limit_indices=mlir.dense_int_elements([lhs_indices_shape[0], 1]),
strides=mlir.dense_int_elements([1, 1])).result,
start=0, end=1)
col = _collapse_mhlo(
mhlo.SliceOp(
lhs_indices,
start_indices=mlir.dense_int_elements([0, 1]),
limit_indices=mlir.dense_int_elements([lhs_indices_shape[0], 2]),
strides=mlir.dense_int_elements([1, 1])).result,
start=0, end=1)

if lhs_contract[0] == 0:
lhs_transpose = True

return [bcoo_dot_general_fn(
lhs_data, row, col, rhs, shape=lhs_spinfo.shape,
transpose=lhs_transpose, data_dtype=lhs_data_aval.dtype,
index_dtype=lhs_indices_aval.dtype,
x_dtype=rhs_aval.dtype)]
else:
return [_collapse_mhlo(dot_product, start=0, end=1)]
elif props.n_sparse == 2:
raise ValueError(f"lhs has to be 1d or 2d; got {props.n_sparse}d.")
elif props.n_batch == 1:
# batch mode.
lhs_indices_shape = ir.RankedTensorType(lhs_indices.type).shape
lhs_data_shape = ir.RankedTensorType(lhs_data.type).shape
batch_count, _, _ = lhs_indices_shape
rhs_shape = ir.RankedTensorType(rhs.type).shape

# Squeeze the batch dimension for both indices and data.
lhs_indices_2d_shape = (np.prod(np.array(lhs_indices_shape)[:-1]),
lhs_indices_shape[-1])
lhs_data_1d_shape = (np.prod(np.array(lhs_data_shape)), )

lhs_indices_2d = mhlo.ReshapeOp(
ir.RankedTensorType.get(
lhs_indices_2d_shape,
ir.RankedTensorType(lhs_indices.type).element_type),
lhs_indices).result

lhs_data_1d = mhlo.ReshapeOp(
ir.RankedTensorType.get(
lhs_data_1d_shape,
ir.RankedTensorType(lhs_data.type).element_type),
lhs_data).result

row = _collapse_mhlo(
mhlo.SliceOp(
lhs_indices,
lhs_indices_2d,
start_indices=mlir.dense_int_elements([0, 0]),
limit_indices=mlir.dense_int_elements([lhs_indices_shape[0], 1]),
limit_indices=mlir.dense_int_elements([lhs_indices_2d_shape[0], 1]),
strides=mlir.dense_int_elements([1, 1])).result,
start=0, end=1)

col = _collapse_mhlo(
mhlo.SliceOp(
lhs_indices,
lhs_indices_2d,
start_indices=mlir.dense_int_elements([0, 1]),
limit_indices=mlir.dense_int_elements([lhs_indices_shape[0], 2]),
limit_indices=mlir.dense_int_elements([lhs_indices_2d_shape[0], 2]),
strides=mlir.dense_int_elements([1, 1])).result,
start=0, end=1)

if lhs_contract[0] == 0:
lhs_transpose = True
# Broadcast rhs to give it the same batch size as lhs.
# TODO(tianjianlu): remove the broadcasting and use batch_stride = 0 for a
# non-batched rhs instead. An issue in the cusparse library
# (https://github.com/NVIDIA/CUDALibrarySamples/issues/81#issuecomment-1205562643)
# currently disallows batch_stride = 0 for a non-batched rhs.
batched_rhs_shape = (batch_count,) + tuple(rhs_shape)
batched_rhs = mhlo.BroadcastInDimOp(
ir.RankedTensorType.get(batched_rhs_shape,
ir.RankedTensorType(rhs.type).element_type),
rhs,
broadcast_dimensions=mlir.dense_int_elements([1, 2])).result
batched_rhs_2d_shape = (np.prod(np.array(batched_rhs_shape)[:-1]), batched_rhs_shape[-1])
batched_rhs_2d = mhlo.ReshapeOp(
ir.RankedTensorType.get(
batched_rhs_2d_shape,
ir.RankedTensorType(batched_rhs.type).element_type),
batched_rhs).result

lhs_transpose = lhs_contract[0] == props.n_batch

return [bcoo_dot_general_fn(
lhs_data, row, col, rhs, shape=lhs_spinfo.shape,
lhs_data_1d, row, col, batched_rhs_2d, shape=lhs_spinfo.shape,
transpose=lhs_transpose, data_dtype=lhs_data_aval.dtype,
index_dtype=lhs_indices_aval.dtype,
x_dtype=rhs_aval.dtype)]
else:
raise ValueError(f"lhs has to be 1d or 2d; got {props.n_sparse}d.")
raise ValueError(f"n_batch has to be 0 or 1; got {props.n_batch}.")
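
For intuition, a hedged NumPy sketch (separate from the MLIR lowering above) of the batch-squeezing step: indices of shape (batch, nnz, 2) and data of shape (batch, nnz) are flattened so that a single row/col/data triple, walked with a per-batch stride of nnz elements, is what cusparseSpMM's batched COO mode sees. All values here are made up:

import numpy as np

batch, nnz = 2, 3
lhs_indices = np.array([[[0, 1], [1, 0], [2, 2]],
                        [[0, 0], [1, 2], [2, 1]]])  # (batch, nnz, 2)
lhs_data = np.arange(1.0, 7.0).reshape(batch, nnz)  # (batch, nnz)

indices_2d = lhs_indices.reshape(batch * nnz, 2)    # squeeze the batch dim
data_1d = lhs_data.reshape(batch * nnz)
row, col = indices_2d[:, 0], indices_2d[:, 1]       # as in the SliceOps above

assert row.shape == col.shape == data_1d.shape == (batch * nnz,)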

def _bcoo_dot_general_gpu_lowering(
coo_matvec_lowering, coo_matmat_lowering,
@@ -820,7 +890,7 @@ def _bcoo_dot_general_gpu_lowering(
ctx, lhs_data, lhs_indices, rhs,
dimension_numbers=dimension_numbers, lhs_spinfo=lhs_spinfo)

(lhs_contract, _), (lhs_batch, rhs_batch) = dimension_numbers
(lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers
lhs_data_aval, lhs_indices_aval, rhs_aval, = ctx.avals_in
n_batch, n_sparse, n_dense, _ = _validate_bcoo(
lhs_data_aval, lhs_indices_aval, lhs_spinfo.shape)
@@ -834,7 +904,7 @@ def _bcoo_dot_general_gpu_lowering(
ctx, lhs_data, lhs_indices, rhs,
dimension_numbers=dimension_numbers, lhs_spinfo=lhs_spinfo)

if (n_batch or n_dense or
if (n_batch > 1 or n_dense or
n_sparse not in [1, 2] or rhs_aval.ndim not in [1, 2] or
lhs_batch or rhs_batch or len(lhs_contract) != 1):
return _bcoo_dot_general_default_lowering(
@@ -850,6 +920,25 @@ def _bcoo_dot_general_gpu_lowering(
ctx, lhs_data, lhs_indices, rhs,
dimension_numbers=dimension_numbers, lhs_spinfo=lhs_spinfo)

if n_batch == 1:
# Support for batched computation in cusparseSpMM with COO format was added
# in CUDA 11.6.1: https://docs.nvidia.com/cuda/cuda-toolkit-release-notes/index.html#cusparse-11.6.1
cuda_version = int(xla_bridge.get_backend().platform_version.split()[-1])

# TODO(tianjianlu): enable the batch mode of cusparseSpMv.
cuda_supported_batch_mode = (
n_sparse == 2 and rhs_aval.ndim == 2 and
len(lhs_contract) == 1 and lhs_contract[0] in [1, 2] and
len(rhs_contract) == 1 and rhs_contract[0] in [0, 1] and
cuda_version >= 11061 and jaxlib_version >= (0, 3, 18))
if not cuda_supported_batch_mode:
warnings.warn("bcoo_dot_general GPU lowering currently does not "
"support this batch-mode computation. Falling back to "
"the default implementation.", CuSparseEfficiencyWarning)
return _bcoo_dot_general_default_lowering(
ctx, lhs_data, lhs_indices, rhs,
dimension_numbers=dimension_numbers, lhs_spinfo=lhs_spinfo)

return _bcoo_dot_general_cuda_lowering(
coo_matvec_lowering, coo_matmat_lowering, ctx, lhs_data, lhs_indices, rhs,
dimension_numbers=dimension_numbers, lhs_spinfo=lhs_spinfo)
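
The dispatch predicate added above can be read as a standalone function. This is a hedged sketch: the name supports_cusparse_batched_spmm and the sample version values are illustrative, not part of the commit, and it assumes platform_version ends in an integer CUDA release token such as "11080":

def supports_cusparse_batched_spmm(n_sparse, rhs_ndim, lhs_contract,
                                   rhs_contract, cuda_version,
                                   jaxlib_version):
  """True if cusparseSpMM's batched COO mode can handle this dot_general."""
  return (n_sparse == 2 and rhs_ndim == 2 and
          len(lhs_contract) == 1 and lhs_contract[0] in (1, 2) and
          len(rhs_contract) == 1 and rhs_contract[0] in (0, 1) and
          cuda_version >= 11061 and      # batched COO support: cusparse 11.6.1
          jaxlib_version >= (0, 3, 18))

assert supports_cusparse_batched_spmm(2, 2, [2], [0], 11080, (0, 3, 18))
assert not supports_cusparse_batched_spmm(2, 2, [2], [0], 11050, (0, 3, 18))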
27 changes: 20 additions & 7 deletions jaxlib/gpu_sparse.py
@@ -283,22 +283,28 @@ def _coo_matmat_mhlo(platform, gpu_sparse, data, row, col, B, *, shape,
x_dtype, data_dtype, index_dtype):
"""COO matrix multiplied by a dense matrix."""
data_type, _, nnz = _validate_coo_mhlo(data, row, col)
rows, cols = shape
is_batched_matmat = False
batch_count = 1
if len(shape) == 2:
rows, cols = shape
elif len(shape) == 3:
is_batched_matmat = True
batch_count, rows, cols = shape
# Redefine nnz as nnz per batch.
nnz = nnz // batch_count

B_shape = ir.RankedTensorType(B.type).shape
_, Ccols = B_shape

if compute_dtype is None:
compute_dtype = data_dtype
compute_type = data_type

# TODO(tianjianlu): use user-defined batch count after enabling batch mode.
batch_count = 1

# TODO(tianjianlu): use the batch stride to trigger different modes of batch
# computation. batch_stride = 0 is currently not allowed because of an issue
# in cusparse: https://github.com/NVIDIA/CUDALibrarySamples/issues/81#issuecomment-1205562643
# Set the batch strides to the per-batch buffer sizes for now.
lhs_batch_stride = rows * cols
lhs_batch_stride = nnz
B_rows = rows if transpose else cols
rhs_batch_stride = B_rows * Ccols

Expand All @@ -308,17 +314,24 @@ def _coo_matmat_mhlo(platform, gpu_sparse, data, row, col, B, *, shape,
rhs_batch_stride)
out_size = cols if transpose else rows

if is_batched_matmat:
out_shape = [batch_count, out_size, Ccols]
out_layout = [2, 1, 0]
else:
out_shape = [out_size, Ccols]
out_layout = [1, 0]

out = custom_call(
f"{platform}sparse_coo_matmat",
[
ir.RankedTensorType.get([out_size, Ccols], compute_type),
ir.RankedTensorType.get(out_shape, compute_type),
ir.RankedTensorType.get([buffer_size],
ir.IntegerType.get_signless(8)),
],
[data, row, col, B],
backend_config=opaque,
operand_layouts=[[0], [0], [0], [1, 0]],
result_layouts=[[1, 0], [0]])
result_layouts=[out_layout, [0]])
return out[0]

cuda_coo_matmat = partial(_coo_matmat_mhlo, "cu", _cusparse)
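
A hedged sketch of the stride bookkeeping introduced above: with shape = (batch_count, rows, cols), each batch advances the sparse buffers by nnz-per-batch entries and the dense rhs by one B_rows x Ccols matrix. The names mirror the function; the numbers are made up:

batch_count, rows, cols = 2, 3, 4
nnz_total, Ccols, transpose = 10, 5, False

nnz = nnz_total // batch_count           # nnz is redefined per batch
lhs_batch_stride = nnz                   # was rows * cols before this change
B_rows = rows if transpose else cols
rhs_batch_stride = B_rows * Ccols
out_size = cols if transpose else rows
out_shape = [batch_count, out_size, Ccols]   # emitted with layout [2, 1, 0]
print(lhs_batch_stride, rhs_batch_stride, out_shape)  # 5 20 [2, 3, 5]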
