Sparse CSR CUDA: Add torch.sparse.sampled_addmm (pytorch#68007)
Summary:
Pull Request resolved: pytorch#68007

This PR adds a new function to the sparse module.
`sampled_addmm` computes α*(A @ B) * spy(C) + β*C, where C is a sparse CSR matrix, A and B are dense (strided) matrices, and spy(C) is the sparsity-pattern matrix of C.
The function is currently restricted to single 2D matrices; batched input is not supported.
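
For reference, a minimal usage sketch of the new API (a sketch, not part of the commit; it assumes a CUDA build with the cuSPARSE generic SDDMM support gated below, i.e. CUDA 11.2.1+). The dense reference computation mirrors the formula above and the check used in the new tests:

```python
import torch

device, dtype = "cuda", torch.float64

# C is a 3x4 sparse CSR matrix; A (3x2) and B (2x4) are dense (strided).
C = torch.tensor([[1., 0., 0., 2.],
                  [0., 3., 0., 0.],
                  [0., 0., 4., 0.]], device=device, dtype=dtype).to_sparse_csr()
A = torch.randn(3, 2, device=device, dtype=dtype)
B = torch.randn(2, 4, device=device, dtype=dtype)
alpha, beta = 0.5, 2.0

result = torch.sparse.sampled_addmm(C, A, B, alpha=alpha, beta=beta)

# Dense reference: alpha * (A @ B) masked by the sparsity pattern of C, plus beta * C.
spy_C = (C.to_dense() != 0).to(dtype)
expected = alpha * (A @ B) * spy_C + beta * C.to_dense()
torch.testing.assert_close(result.to_dense(), expected)
```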

cc nikitaved pearu cpuhrsch IvanYashchuk

Test Plan: Imported from OSS

Reviewed By: mrshenli

Differential Revision: D32435799

Pulled By: cpuhrsch

fbshipit-source-id: b1ffac795080aef3fa05eaeeded03402bc097392
IvanYashchuk authored and pull[bot] committed Feb 13, 2023
1 parent 52a55e3 commit 60ac95f
Showing 7 changed files with 351 additions and 2 deletions.
9 changes: 8 additions & 1 deletion aten/src/ATen/cuda/CUDASparse.h
@@ -18,9 +18,16 @@
#define AT_USE_CUSPARSE_GENERIC_SPSV() 0
#endif

// cuSparse Generic API spsm function was added in CUDA 11.3.1
#if defined(CUDART_VERSION) && defined(CUSPARSE_VERSION) && (CUSPARSE_VERSION >= 11600)
#define AT_USE_CUSPARSE_GENERIC_SPSM() 1
#else
#define AT_USE_CUSPARSE_GENERIC_SPSM() 0
#endif

// cuSparse Generic API sddmm function was added in CUDA 11.2.1 (cuSparse version 11400)
#if defined(CUDART_VERSION) && defined(CUSPARSE_VERSION) && (CUSPARSE_VERSION >= 11400)
#define AT_USE_CUSPARSE_GENERIC_SDDMM() 1
#else
#define AT_USE_CUSPARSE_GENERIC_SDDMM() 0
#endif
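
This compile-time gate cannot be queried from Python directly; a rough runtime proxy (a sketch, not an official API) mirrors the conservative (11, 3) cut-off used by the test helper added further below, since the patch version is not exposed:

```python
import torch

def sddmm_probably_available() -> bool:
    # Mirrors the spirit of AT_USE_CUSPARSE_GENERIC_SDDMM(): cusparseSDDMM needs
    # CUDA 11.2.1, but torch.version.cuda carries no patch number, so require
    # 11.3+ to be safe (same policy as _check_cusparse_sddmm_available in the tests).
    if torch.version.cuda is None:  # CPU-only or ROCm build
        return False
    major, minor = (int(x) for x in torch.version.cuda.split(".")[:2])
    return (major, minor) >= (11, 3)
```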
10 changes: 10 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -5004,6 +5004,16 @@
dispatch:
CompositeExplicitAutograd: _sparse_addmm

- func: sparse_sampled_addmm.out(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)
python_module: sparse
dispatch:
SparseCsrCUDA: sparse_sampled_addmm_out_sparse_csr_cuda

- func: sparse_sampled_addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor
python_module: sparse
dispatch:
SparseCsrCUDA: sparse_sampled_addmm_sparse_csr_cuda

- func: addmm.out(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)
structured: True
dispatch:
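
The two sparse_sampled_addmm schema entries above register a functional and an out= variant under the `torch.sparse` Python module. A hedged sketch of both call forms (device, shapes, and dtypes chosen only for illustration):

```python
import torch

C = torch.eye(4, device="cuda").to_sparse_csr()
A = torch.randn(4, 3, device="cuda")
B = torch.randn(3, 4, device="cuda")

# Functional variant: allocates and returns a new sparse CSR tensor.
y = torch.sparse.sampled_addmm(C, A, B, beta=1.0, alpha=1.0)

# out= variant: the kernel resizes `out` to C's sparsity pattern and writes into it
# (see resize_as_sparse_csr_ in SparseBlas.cpp below), so an empty CSR tensor suffices.
out = torch.empty(0, 0, device="cuda", layout=torch.sparse_csr)
torch.sparse.sampled_addmm(C, A, B, beta=1.0, alpha=1.0, out=out)
```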
79 changes: 79 additions & 0 deletions aten/src/ATen/native/sparse/cuda/SparseBlas.cpp
@@ -9,6 +9,85 @@
namespace at {
namespace native {

/*
Computes `result` <- α*(A @ B) * spy(C) + β*C, where spy(C) is the sparsity pattern matrix of C.
Args:
* `mat1` - [in] dense Tensor A of size m × k.
* `mat2` - [in] dense Tensor B of size k × n.
* `self` - [in] sparse Tensor C of size m × n.
* `result` - [out] sparse Tensor of size m × n.
*/
Tensor& sparse_sampled_addmm_out_sparse_csr_cuda(
const Tensor& self,
const Tensor& mat1,
const Tensor& mat2,
const Scalar& beta,
const Scalar& alpha,
Tensor& result) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(self.is_sparse_csr());

TORCH_CHECK(mat1.layout() == kStrided, "sampled_addmm: Expected mat1 to have strided layout, but got ", mat1.layout());
TORCH_CHECK(mat2.layout() == kStrided, "sampled_addmm: Expected mat2 to have strided layout, but got ", mat2.layout());

TORCH_CHECK(result.layout() == kSparseCsr, "sampled_addmm: Expected result to have sparse csr layout, but got ", result.layout());

TORCH_CHECK(mat1.scalar_type() == mat2.scalar_type(), "sampled_addmm: Expected mat1 and mat2 to have the same dtype, but got ", mat1.scalar_type(), " and ", mat2.scalar_type());
TORCH_CHECK(mat1.scalar_type() == self.scalar_type(), "sampled_addmm: Expected mat1 and self to have the same dtype, but got ", mat1.scalar_type(), " and ", self.scalar_type());
TORCH_CHECK(result.scalar_type() == self.scalar_type(), "sampled_addmm: Expected result and self to have the same dtype, but got ", result.scalar_type(), " and ", self.scalar_type());

TORCH_CHECK(
mat1.dim() == 2, "sampled_addmm: Expected mat1 to be a matrix, got ", mat1.dim(), "-D tensor");
TORCH_CHECK(
mat2.dim() == 2, "sampled_addmm: Expected mat2 to be a matrix, got ", mat2.dim(), "-D tensor");
TORCH_CHECK(
result.dim() == 2, "sampled_addmm: Expected result to be a matrix, got ", result.dim(), "-D tensor");

IntArrayRef mat1_sizes = mat1.sizes();
IntArrayRef mat2_sizes = mat2.sizes();
TORCH_CHECK(
mat1_sizes[1] == mat2_sizes[0],
"sampled_addmm: mat1 and mat2 shapes cannot be multiplied (",
mat1_sizes[0],
"x",
mat1_sizes[1],
" and ",
mat2_sizes[0],
"x",
mat2_sizes[1],
")");

IntArrayRef self_sizes = self.sizes();
TORCH_CHECK(
self_sizes[0] == mat1_sizes[0], "sampled_addmm: self dim 0 must match mat1 dim 0");
TORCH_CHECK(
self_sizes[1] == mat2_sizes[1], "sampled_addmm: self dim 1 must match mat2 dim 1");

if (&result != &self) {
at::native::resize_as_sparse_csr_(result, self);
result.copy_(self);
}

// there's a segfault when calling cuSPARSE on 0-sized matrices
if (mat1.numel() == 0 || mat2.numel() == 0) {
return result;
}

sparse::impl::cuda::sampled_addmm_out_sparse_csr(mat1, mat2, beta, alpha, result);
return result;
}

Tensor sparse_sampled_addmm_sparse_csr_cuda(
const Tensor& self,
const Tensor& mat1,
const Tensor& mat2,
const Scalar& beta,
const Scalar& alpha) {
auto result = at::empty({0, 0}, self.options());
at::native::sparse_sampled_addmm_out_sparse_csr_cuda(self, mat1, mat2, beta, alpha, result);
return result;
}

Tensor& addmm_out_sparse_csr_cuda(
const Tensor& self,
const Tensor& mat1,
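
The checks in sparse_sampled_addmm_out_sparse_csr_cuda above require `self` to be an m × n sparse CSR tensor while `mat1` and `mat2` are strided m × k and k × n matrices. A short illustrative sketch (shapes and the error text are taken from the TORCH_CHECK messages above; the example itself is not part of the commit):

```python
import torch

m, k, n = 3, 5, 4
C = torch.rand(m, n, device="cuda").to_sparse_csr()   # self: m x n, sparse CSR
A = torch.rand(m, k, device="cuda")                    # mat1: m x k, strided
B = torch.rand(k, n, device="cuda")                    # mat2: k x n, strided
ok = torch.sparse.sampled_addmm(C, A, B)               # shapes line up

bad_B = torch.rand(k + 1, n, device="cuda")            # inner dimensions disagree
try:
    torch.sparse.sampled_addmm(C, A, bad_B)
except RuntimeError as err:
    assert "cannot be multiplied" in str(err)
```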
79 changes: 79 additions & 0 deletions aten/src/ATen/native/sparse/cuda/SparseBlasImpl.cpp
@@ -893,6 +893,85 @@ void triangular_solve_out_sparse_csr(
#endif // !AT_USE_CUSPARSE_GENERIC_SPSV()
}

void sampled_addmm_out_sparse_csr(
const Tensor& A,
const Tensor& B,
const Scalar& beta,
const Scalar& alpha,
const at::sparse_csr::SparseCsrTensor& C) {
#if !AT_USE_CUSPARSE_GENERIC_SDDMM()
TORCH_CHECK(
false,
"Calling sampled_addmm with sparse GPU tensors requires compiling ",
"PyTorch with CUDA 11.2.1+. ",
"Please use PyTorch built with newer CUDA version.");
#else
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(A.layout() == Layout::Strided);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(B.layout() == Layout::Strided);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(C.is_sparse_csr());

auto descA = at::cuda::sparse::CuSparseDnMatDescriptor(A);
auto descB = at::cuda::sparse::CuSparseDnMatDescriptor(B);
auto descC = at::cuda::sparse::CuSparseSpMatCsrDescriptor(C);

cusparseOperation_t opA = CUSPARSE_OPERATION_NON_TRANSPOSE;
cusparseOperation_t opB = CUSPARSE_OPERATION_NON_TRANSPOSE;

AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(
C.scalar_type(),
"sampled_addmm_out_sparse_csr",
[&] {
auto beta_ = beta.to<scalar_t>();
auto alpha_ = alpha.to<scalar_t>();
auto compute_type = at::cuda::getCudaDataType<scalar_t>();
auto handle = at::cuda::getCurrentCUDASparseHandle();
size_t buffer_size = 0;
TORCH_CUDASPARSE_CHECK(cusparseSDDMM_bufferSize(
handle,
opA,
opB,
&alpha_,
descA.descriptor(),
descB.descriptor(),
&beta_,
descC.descriptor(),
compute_type,
CUSPARSE_SDDMM_ALG_DEFAULT,
&buffer_size // output
));

auto& allocator = *c10::cuda::CUDACachingAllocator::get();
auto buffer = allocator.allocate(buffer_size);

TORCH_CUDASPARSE_CHECK(cusparseSDDMM_preprocess(
handle,
opA,
opB,
&alpha_,
descA.descriptor(),
descB.descriptor(),
&beta_,
descC.descriptor(),
compute_type,
CUSPARSE_SDDMM_ALG_DEFAULT,
buffer.get()));

TORCH_CUDASPARSE_CHECK(cusparseSDDMM(
handle,
opA,
opB,
&alpha_,
descA.descriptor(),
descB.descriptor(),
&beta_,
descC.descriptor(),
compute_type,
CUSPARSE_SDDMM_ALG_DEFAULT,
buffer.get()));
});
#endif
}

} // namespace cuda
} // namespace impl
} // namespace sparse
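
On builds where the generic SDDMM API is unavailable, the TORCH_CHECK at the top of sampled_addmm_out_sparse_csr raises instead of dispatching to cuSPARSE. A small sketch of what a caller would see on such a build (only the pre-11.2.1 path raises; on supported builds the call simply succeeds):

```python
import torch

C = torch.eye(2, device="cuda").to_sparse_csr()
A = torch.randn(2, 2, device="cuda")
try:
    torch.sparse.sampled_addmm(C, A, A)
except RuntimeError as err:
    # Raised by the TORCH_CHECK above when AT_USE_CUSPARSE_GENERIC_SDDMM() is 0.
    assert "CUDA 11.2.1+" in str(err)
```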
7 changes: 7 additions & 0 deletions aten/src/ATen/native/sparse/cuda/SparseBlasImpl.h
@@ -39,6 +39,13 @@ void triangular_solve_out_sparse_csr(
bool transpose,
bool unitriangular);

void sampled_addmm_out_sparse_csr(
const Tensor& mat1,
const Tensor& mat2,
const Scalar& beta,
const Scalar& alpha,
const at::sparse_csr::SparseCsrTensor& result);

} // namespace cuda
} // namespace impl
} // namespace sparse
118 changes: 117 additions & 1 deletion test/test_sparse_csr.py
@@ -10,7 +10,7 @@
(TEST_WITH_ROCM, TEST_SCIPY, TestCase, run_tests, load_tests, coalescedonoff)
from torch.testing._internal.common_device_type import \
(ops, instantiate_device_type_tests, dtypes, dtypesIfCUDA, onlyCPU, onlyCUDA, skipCUDAIfNoCusparseGeneric,
precisionOverride, skipMeta, skipCUDAIf, skipCPUIfNoMklSparse)
precisionOverride, skipMeta, skipCUDAIf, skipCUDAIfRocm, skipCPUIfNoMklSparse)
from torch.testing._internal.common_methods_invocations import (sparse_csr_unary_ufuncs, )
from torch.testing._internal.common_cuda import _get_torch_cuda_version
from torch.testing._internal.common_dtype import floating_types, get_all_dtypes
@@ -35,6 +35,12 @@ def _check_cusparse_spgemm_available():
min_supported_version = (11, 0)
return version >= min_supported_version

def _check_cusparse_sddmm_available():
version = _get_torch_cuda_version()
# cusparseSDDMM was added in 11.2.1 but we don't have access to patch version
min_supported_version = (11, 3)
return version >= min_supported_version

# This should be just an import from test_linalg instead of code duplication
# but https://github.com/pytorch/pytorch/pull/63511#discussion_r733989701
def _test_addmm_addmv(test_case, f, t, m, v, *, alpha=None, beta=None, transpose_out=False, layout=torch.strided, all_sparse=False):
@@ -978,6 +984,116 @@ def run_test(n, k, upper, unitriangular, transpose):
itertools.product([True, False], repeat=3)):
run_test(n, k, upper, unitriangular, transpose)

@skipCUDAIfRocm
@onlyCUDA
@skipCUDAIf(
not _check_cusparse_sddmm_available(),
"cuSparse Generic API SDDMM is not available"
)
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
@precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3,
torch.float64: 1e-8, torch.complex128: 1e-8})
def test_sampled_addmm(self, device, dtype):
def run_test(c, a, b, op_a, op_b, *, alpha=None, beta=None):
if dtype.is_complex:
alpha = random.random() + 0.3j if alpha is None else alpha
beta = random.random() + 0.6j if beta is None else beta
else:
alpha = random.random() if alpha is None else alpha
beta = random.random() if beta is None else beta

if op_a and a.shape == b.shape:
a = a.mH
if op_b and a.shape == b.shape:
b = b.mH

actual = torch.sparse.sampled_addmm(c, a, b, alpha=alpha, beta=beta)

out = torch.sparse_csr_tensor(
*map(torch.clone, (actual.crow_indices(), actual.col_indices())),
torch.empty_like(actual.values()),
size=c.shape
)
torch.sparse.sampled_addmm(c, a, b, alpha=alpha, beta=beta, out=out)

spy_c = torch.sparse_csr_tensor(c.crow_indices(), c.col_indices(), torch.ones_like(c.values()), size=c.shape)
expected = alpha * (a @ b) * spy_c.to_dense() + beta * c.to_dense()
self.assertEqual(actual.to_dense(), out.to_dense())
self.assertEqual(actual.to_dense(), expected)

for index_dtype in [torch.int32, torch.int64]:
for (m, n, k), noncontiguous in zip(itertools.product([1, 5], repeat=3), [True, False]):
nnz = random.randint(0, m * n)
c = self.genSparseCSRTensor((m, n), nnz, dtype=dtype, device=device, index_dtype=index_dtype)
a = make_tensor((m, k), dtype=dtype, device=device, noncontiguous=noncontiguous)
b = make_tensor((k, n), dtype=dtype, device=device, noncontiguous=noncontiguous)
for op_a, op_b in itertools.product([True, False], repeat=2):
run_test(c, a, b, op_a, op_b)

@skipCUDAIfRocm
@onlyCUDA
@skipCUDAIf(
not _check_cusparse_sddmm_available(),
"cuSparse Generic API SDDMM is not available"
)
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
@precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3,
torch.float64: 1e-8, torch.complex128: 1e-8})
def test_sampled_addmm_zero_sized(self, device, dtype):
def run_test(c, a, b):
actual = torch.sparse.sampled_addmm(c, a, b)
self.assertEqual(actual.shape, c.shape)

for m, n, k in itertools.product([0, 5], repeat=3):
c = torch.empty(m, n, dtype=dtype, device=device, layout=torch.sparse_csr)
a = make_tensor((m, k), dtype=dtype, device=device)
b = make_tensor((k, n), dtype=dtype, device=device)
run_test(c, a, b)

@skipCUDAIfRocm
@onlyCUDA
@skipCUDAIf(
not _check_cusparse_sddmm_available(),
"cuSparse Generic API SDDMM is not available"
)
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
def test_sampled_addmm_errors(self, device, dtype):
# test that the errors are the same for dense and sparse sampled versions
# import re

# shapes must be compatible for matrix multiplication
a = make_tensor((2, 3), dtype=dtype, device=device)
a_sparse = a.to_sparse_csr()
with self.assertRaisesRegex(RuntimeError, r"cannot be multiplied"):
torch.sparse.sampled_addmm(a_sparse, a, a)

# mat1 must be a matrix
with self.assertRaisesRegex(RuntimeError, r"Expected mat1 to be a matrix"):
torch.sparse.sampled_addmm(a_sparse, a.unsqueeze(0), a)

# mat2 must be a matrix
with self.assertRaisesRegex(RuntimeError, r"Expected mat2 to be a matrix"):
torch.sparse.sampled_addmm(a_sparse, a, a.unsqueeze(0))

a = make_tensor((2, 2), dtype=dtype, device=device)
b = make_tensor((3, 3), dtype=dtype, device=device)
b_sparse = b.to_sparse_csr()
with self.assertRaisesRegex(RuntimeError, r"self dim 0 must match mat1 dim 0"):
torch.sparse.sampled_addmm(b_sparse, a, a)

b = make_tensor((2, 3), dtype=dtype, device=device)
b_sparse = b.to_sparse_csr()
with self.assertRaisesRegex(RuntimeError, r"self dim 1 must match mat2 dim 1"):
torch.sparse.sampled_addmm(b_sparse, a, a)

a = make_tensor((2, 2), dtype=dtype, device=device)
a_sparse = a.to_sparse_csr()
with self.assertRaisesRegex(RuntimeError, r"Expected mat1 to have strided layout"):
torch.sparse.sampled_addmm(a_sparse, a_sparse, a_sparse)

with self.assertRaisesRegex(RuntimeError, r"Expected mat2 to have strided layout"):
torch.sparse.sampled_addmm(a_sparse, a, a_sparse)

@dtypes(*get_all_dtypes())
def test_coo_csr_conversion(self, device, dtype):
for m, n in itertools.product([5, 2, 0], [5, 2, 0]):
