Skip to content

Commit

Permalink
[sparse] Enable batch mode of COO matmat from cusparse kernels.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 465405490
  • Loading branch information
tlu7 authored and jax authors committed Aug 4, 2022
1 parent c5d4eb5 commit 07da502
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 0 deletions.
20 changes: 20 additions & 0 deletions jaxlib/cuda/cusparse.cc
Original file line number Diff line number Diff line change
Expand Up @@ -487,18 +487,38 @@ std::pair<size_t, py::bytes> BuildCooMatmatDescriptor(
cusparseDnMatDescr_t mat_b = 0;
cusparseDnMatDescr_t mat_c = 0;

// All three matrices A, B, and C must have the same batch_count.
// TODO(tianjianlu): use batch_count from matrix descriptor.
int batch_count = 1;

// Three batch modes are supported, C_i = A_i B, C_i = A B_i, and
// Ci = A_i B_i, where `i` denotes the batch dimension. Use `batch_stride` to
// trigger individual mode, e.g., using `batch_stride_B = 0` in C_i = A_i B.
int batch_stride_A = A.rows * A.cols;
int batch_stride_B = B.rows * B.cols;
int batch_stride_C = C.rows * C.cols;

// bufferSize does not reference these pointers, but does error on NULL.
int val = 0;
void* empty = &val;
JAX_THROW_IF_ERROR(JAX_AS_STATUS(
cusparseCreateCoo(&mat_a, A.rows, A.cols, A.nnz, empty, empty, empty,
A.index_type, CUSPARSE_INDEX_BASE_ZERO, A.value_type)));
JAX_THROW_IF_ERROR(JAX_AS_STATUS(
cusparseCooSetStridedBatch(
mat_a, /*batchCount=*/batch_count, /*batchStride=*/batch_stride_A)));
JAX_THROW_IF_ERROR(
JAX_AS_STATUS(cusparseCreateDnMat(&mat_b, B.rows, B.cols, /*ld=*/B.cols,
empty, B.type, CUSPARSE_ORDER_ROW)));
JAX_THROW_IF_ERROR(JAX_AS_STATUS(
cusparseDnMatSetStridedBatch(
mat_b, /*batchCount=*/batch_count, /*batchStride=*/batch_stride_B)));
JAX_THROW_IF_ERROR(
JAX_AS_STATUS(cusparseCreateDnMat(&mat_c, C.rows, C.cols, /*ld=*/C.cols,
empty, C.type, CUSPARSE_ORDER_ROW)));
JAX_THROW_IF_ERROR(JAX_AS_STATUS(
cusparseDnMatSetStridedBatch(
mat_c, /*batchCount=*/batch_count, /*batchStride=*/batch_stride_C)));
size_t buffer_size;
CudaConst alpha = CudaOne(C.type);
CudaConst beta = CudaZero(C.type);
Expand Down
20 changes: 20 additions & 0 deletions jaxlib/cuda/cusparse_kernels.cc
Original file line number Diff line number Diff line change
Expand Up @@ -509,15 +509,35 @@ static absl::Status CooMatmat_(cudaStream_t stream, void** buffers,
cusparseDnMatDescr_t mat_b = 0;
cusparseDnMatDescr_t mat_c = 0;

// All three matrices A, B, and C must have the same batch_count.
// TODO(tianjianlu): use batch_count from matrix descriptor.
int batch_count = 1;

// Three batch modes are supported, C_i = A_i B, C_i = A B_i, and
// Ci = A_i B_i, where `i` denotes the batch dimension. Use `batch_stride` to
// trigger individual mode, e.g., using `batch_stride_B = 0` in C_i = A_i B.
int batch_stride_A = d.A.rows * d.A.cols;
int batch_stride_B = d.B.rows * d.B.cols;
int batch_stride_C = d.C.rows * d.C.cols;

JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseCreateCoo(
&mat_a, d.A.rows, d.A.cols, d.A.nnz, coo_row_ind, coo_col_ind, coo_values,
d.A.index_type, CUSPARSE_INDEX_BASE_ZERO, d.A.value_type)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
cusparseCooSetStridedBatch(
mat_a, /*batchCount=*/batch_count, /*batchStride=*/batch_stride_A)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseCreateDnMat(
&mat_b, d.B.rows, d.B.cols,
/*ld=*/d.B.cols, Bbuf, d.B.type, CUSPARSE_ORDER_ROW)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
cusparseDnMatSetStridedBatch(
mat_b, /*batchCount=*/batch_count, /*batchStride=*/batch_stride_B)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseCreateDnMat(
&mat_c, d.C.rows, d.C.cols,
/*ld=*/d.C.cols, Cbuf, d.C.type, CUSPARSE_ORDER_ROW)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
cusparseDnMatSetStridedBatch(
mat_c, /*batchCount=*/batch_count, /*batchStride=*/batch_stride_C)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusparseSpMM(
handle.get(), d.op_A, /*opB=*/CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha,
mat_a, mat_b, &beta, mat_c, d.C.type, CUSPARSE_SPMM_ALG_DEFAULT, buf)));
Expand Down

0 comments on commit 07da502

Please sign in to comment.