better take on sparse tensors, put layout on the correct device
blefaudeux committed Apr 17, 2022
1 parent 05b3821 · commit 6c97cca
Showing 2 changed files with 9 additions and 8 deletions.
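
The gist of the change, as a minimal standalone sketch (assumes only PyTorch; the shapes mirror the test helper and the device falls back to CPU when no GPU is present): the block layout is now created on the same device as the values, so downstream code can decide on the Triton path by inspecting a single tensor.

import torch

# Sketch of the invariant this commit establishes: layout and values share a
# device, so a single device check is enough further down the stack.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
block_size, Z, C, H, W = 32, 8, 2, 64, 64

layout = torch.randint(2, (C, H // block_size, W // block_size), device=device)
layout[:, :, 0] = 1  # make sure the layout is not empty, as the test helper does
values = torch.randn(Z, int(layout.sum()), block_size, block_size, device=device)

assert values.device == layout.device  # mirrors the assert added in __init__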
tests/test_sparse_tensors.py — 2 changes: 1 addition & 1 deletion
@@ -19,7 +19,7 @@
 def _create_blocksparse_tensor(
     device, block_size=32, Z=8, C=2, H=64, W=64, dtype=torch.float32
 ):
-    layout = torch.randint(2, (C, H // block_size, W // block_size))
+    layout = torch.randint(2, (C, H // block_size, W // block_size), device=device)
     layout[:, :, 0] = 1
     layout[:, 0, :] = 1
     values = torch.randn(Z, layout.sum(), block_size, block_size, device=device).to(
xformers/sparse/blocksparse_tensor.py — 15 changes: 8 additions & 7 deletions
@@ -20,8 +20,8 @@
     blocksparse_softmax = None


-def _can_use_triton(a, b):
-    if a.device.type == "cpu" or b.device.type == "cpu":
+def _can_use_triton(a):
+    if a.device.type == "cpu":
         return False

     if blocksparse_matmul is None:
@@ -107,6 +107,9 @@ def __new__(cls, values, layout):

     def __init__(self, values, layout):
         assert values.shape[-2] == values.shape[-1]
+        assert (
+            values.device == layout.device
+        ), "Both values and layout need to reside on the same device"
         block_size = values.shape[-1]
         # TODO: make this check conditioned on the use of Triton
         assert block_size >= 16, "Minimum block size is 16, for now at least"
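
What the new assert buys at construction time, sketched below: a layout left on the CPU while the values live on the GPU now fails fast with a clear message instead of surfacing deep inside the Triton ops. The xformers.sparse.BlockSparseTensor import path and the CUDA guard are assumptions for illustration, not part of this diff.

import torch

from xformers.sparse import BlockSparseTensor  # import path assumed

if torch.cuda.is_available():
    block_size = 32
    layout = torch.randint(2, (2, 2, 2))  # created on the CPU by default
    layout[:, :, 0] = 1
    values = torch.randn(8, int(layout.sum()), block_size, block_size, device="cuda")

    try:
        BlockSparseTensor(values, layout)  # values on CUDA, layout on CPU
    except AssertionError as err:
        print(err)  # "Both values and layout need to reside on the same device"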
@@ -175,9 +178,7 @@ def _wrap(cls, values, bmat):
     def _bmm(cls, arg0, arg1):
         if not (isinstance(arg0, cls) and type(arg1) == torch.Tensor):
             return NotImplemented
-        if _can_use_triton(arg1, arg0.__sparse_dot_dsd.layout):
-            # Triton requires all the tensors to be on GPU,
-            # which may not be the case depending on what layout was passed
+        if _can_use_triton(arg1):
             res = arg0.__sparse_dot_dsd(arg0.__values, arg1)
         else:
             res = _spmm(arg1, arg0.__layout, arg0.__values)
@@ -189,7 +190,7 @@ def _masked_matmul(cls, a, b, mask):
             return NotImplemented
         b = b.transpose(-2, -1)
         assert b.is_contiguous()
-        if _can_use_triton(a, mask.__sparse_dot_sdd.layout):
+        if _can_use_triton(a):
             res = mask.__sparse_dot_sdd(a, b)
         else:
             res = _sddmm(a, b, mask.__layout)
@@ -199,7 +200,7 @@ def _masked_matmul(cls, a, b, mask):
     def _softmax(cls, arg0, dim):
         if not (dim == -1 or dim == 2):
             return NotImplemented
-        if _can_use_triton(arg0, arg0.__sparse_softmax.layout):
+        if _can_use_triton(arg0):
             # TODO triton softmax performs an in-place operation
             # res = arg0.__sparse_softmax(arg0.__values)
             res = arg0.__sparse_softmax(arg0.__values.clone())
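
For readability, a paraphrased, self-contained version of the simplified gate (not the verbatim library code: the triton.ops.blocksparse import path and the body past the truncated hunk are assumptions). One operand is enough to decide because the assert added in __init__ guarantees that values and layout share a device.

try:
    from triton.ops.blocksparse import matmul as blocksparse_matmul
except ImportError:  # Triton (or its blocksparse ops) is not available
    blocksparse_matmul = None


def can_use_triton(a):
    # Triton kernels only run on GPU...
    if a.device.type == "cpu":
        return False
    # ...and only when the blocksparse ops imported successfully.
    return blocksparse_matmul is not None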
