Skip to content

Commit

Permalink
Triton kernel for bsr @ dense (#94823)
Browse files Browse the repository at this point in the history
  • Loading branch information
nikitaved authored and cyyever committed Mar 5, 2023
1 parent 503d15f commit 81f02b0
Show file tree
Hide file tree
Showing 3 changed files with 666 additions and 1 deletion.
3 changes: 3 additions & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,9 @@ ignore_errors = True
# Third party dependencies that don't have types.
#

[mypy-triton.*]
ignore_missing_imports = True

[mypy-tensorflow.*]
ignore_missing_imports = True

Expand Down
56 changes: 55 additions & 1 deletion test/test_sparse_csr.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from torch.testing._internal.common_cuda import SM53OrLater, SM80OrLater, TEST_CUSPARSE_GENERIC
from torch.testing._internal.common_utils import \
(TEST_WITH_ROCM, TEST_SCIPY, TEST_NUMPY, TEST_MKL, IS_WINDOWS, TestCase, run_tests, load_tests, coalescedonoff, parametrize,
subtest, skipIfTorchDynamo)
subtest, skipIfTorchDynamo, skipIfRocm, IS_FBCODE, IS_REMOTE_GPU)
from torch.testing._internal.common_device_type import \
(ops, instantiate_device_type_tests, dtypes, OpDTypes, dtypesIfCUDA, onlyCPU, onlyCUDA, skipCUDAIfNoSparseGeneric,
precisionOverride, skipMeta, skipCUDAIf, skipCUDAIfRocm, skipCPUIfNoMklSparse, skipCUDAIfRocmVersionLessThan)
Expand Down Expand Up @@ -1462,6 +1462,60 @@ def run_test_block_addmm_addmv(self,
self.assertEqual(actual, out)
self.assertEqual(actual, expected)

@parametrize("block_size", [16, 32, 64])
@parametrize("index_dtype", [torch.int32, torch.int64])
@onlyCUDA
@skipIfRocm
@dtypes(torch.half, torch.bfloat16)
@dtypesIfCUDA(torch.half, *[torch.bfloat16] if SM80OrLater else [])
@unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "Test requires Triton")
def test_triton_bsr_dense_bmm(self, device, dtype, index_dtype, block_size):
from functools import partial

from torch._inductor.utils import has_triton
from torch.sparse._triton_ops import bsr_dense_mm

if not has_triton():
self.skipTest("Triton is not available.")

# Note that each value in a non-zero block is in range block_size * [low^2, high^2).
tensor = partial(make_tensor, device=device, dtype=dtype, low=0.5, high=1.5)

# NOTE: batch dims with zero sizes are not supported in `to_sparse_bsr`.
batches = [(), (2,)]
size = [128, 256, 0]

# Whether to make inputs orthogonal so that the product is zero
make_orthogonal = [True, False]

for bd, bs, m, n, k, is_ortho in itertools.product(batches, batches, size, size, size, make_orthogonal):
bsr = tensor(bs + (m, k))
# NOTE: do not get confused, it will be transposed
dense = tensor(bd + (n, k))

if is_ortho:
bsr = torch.cat((bsr, torch.zeros_like(bsr)), dim=-1)
dense = torch.cat((torch.zeros_like(dense), dense), dim=-1)

bsr = bsr.to_sparse_bsr(block_size)

res_tri = bsr_dense_mm(bsr, dense.transpose(-2, -1))
res_dense = bsr.to_dense() @ dense.transpose(-2, -1)
self.assertEqual(res_tri, res_dense)

# check whether bsr_dense_mm handles different grid sizes
# None means max possible grid size which is CUDA-dependent.
grid_size = (None, 2, 4)
grid_gen = itertools.product(grid_size, repeat=3)
for is_sparse_rowspace, grid in itertools.product((True, False), grid_gen):
res_tri = bsr_dense_mm(
bsr,
dense.transpose(-2, -1),
max_grid=grid,
is_sparse_rowspace_mode=is_sparse_rowspace
)
self.assertEqual(res_tri, res_dense)

# TODO: block_size 1 is broken
@parametrize("block_size", [2, 3])
@parametrize("index_dtype", [torch.int32, torch.int64])
Expand Down

0 comments on commit 81f02b0

Please sign in to comment.