better take on sparse tensors, put layout on the correct device
update the pip packages, minor cleanup
blefaudeux committed Apr 17, 2022
1 parent 05b3821 commit 4ff051b
Showing 6 changed files with 12 additions and 13 deletions.
3 changes: 1 addition & 2 deletions docs/source/conf.py
@@ -35,8 +35,7 @@
 author = "Facebook AI Research"

 # The full version, including alpha/beta/rc tags
-release = "0.0.10"
-
+release = "0.0.11.dev"

 # -- General configuration ---------------------------------------------------
1 change: 0 additions & 1 deletion requirements-benchmark.txt
@@ -7,5 +7,4 @@ scikit-learn == 0.24.1
 tqdm == 4.59.0
 pandas == 1.2.4
 seaborn == 0.11.1
-triton == 2.0.0.dev20220323
 pytorch-lightning >= 1.3
2 changes: 1 addition & 1 deletion requirements-test.txt
@@ -26,4 +26,4 @@ hydra-core >= 1.1
 fairscale >= 0.4.5

 # Dependency for fused layers, optional
-triton == 2.0.0.dev20220323
+triton == 2.0.0.dev20220403
2 changes: 1 addition & 1 deletion tests/test_sparse_tensors.py
@@ -19,7 +19,7 @@
 def _create_blocksparse_tensor(
     device, block_size=32, Z=8, C=2, H=64, W=64, dtype=torch.float32
 ):
-    layout = torch.randint(2, (C, H // block_size, W // block_size))
+    layout = torch.randint(2, (C, H // block_size, W // block_size), device=device)
     layout[:, :, 0] = 1
     layout[:, 0, :] = 1
     values = torch.randn(Z, layout.sum(), block_size, block_size, device=device).to(
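
This test change matters because torch.randint allocates on the CPU by default: without device=device, the layout lands on a different device from the values whenever the test runs on GPU, which would trip the new device-consistency assert in BlockSparseTensor.__init__ (see the blocksparse_tensor.py diff below). A minimal sketch of the mismatch, assuming a CUDA device is available:

import torch

device = torch.device("cuda")
block_size, C, H, W = 32, 2, 64, 64

# Default allocation: torch.randint places the layout on the CPU.
layout_cpu = torch.randint(2, (C, H // block_size, W // block_size))
assert layout_cpu.device.type == "cpu"

# With device=..., the layout lives next to the values it describes.
layout = torch.randint(2, (C, H // block_size, W // block_size), device=device)
values = torch.randn(8, int(layout.sum()), block_size, block_size, device=device)
assert values.device == layout.device  # the invariant the constructor now enforces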
2 changes: 1 addition & 1 deletion xformers/__init__.py
@@ -8,7 +8,7 @@
 import torch

 # Please update the doc version in docs/source/conf.py as well.
-__version__ = "0.0.10"
+__version__ = "0.0.11.dev"

 _is_sparse_available = True
 _is_triton_available = torch.cuda.is_available()
15 changes: 8 additions & 7 deletions xformers/sparse/blocksparse_tensor.py
@@ -20,8 +20,8 @@
 blocksparse_softmax = None


-def _can_use_triton(a, b):
-    if a.device.type == "cpu" or b.device.type == "cpu":
+def _can_use_triton(a):
+    if a.device.type == "cpu":
         return False

     if blocksparse_matmul is None:
@@ -107,6 +107,9 @@ def __new__(cls, values, layout):

     def __init__(self, values, layout):
         assert values.shape[-2] == values.shape[-1]
+        assert (
+            values.device == layout.device
+        ), "Both values and layout need to reside on the same device"
         block_size = values.shape[-1]
         # TODO: make this check conditioned on the use of Triton
         assert block_size >= 16, "Minimum block size is 16, for now at least"
@@ -175,9 +178,7 @@ def _wrap(cls, values, bmat):
     def _bmm(cls, arg0, arg1):
         if not (isinstance(arg0, cls) and type(arg1) == torch.Tensor):
             return NotImplemented
-        if _can_use_triton(arg1, arg0.__sparse_dot_dsd.layout):
-            # Triton requires all the tensors to be on GPU,
-            # which may not be the case depending on what layout was passed
+        if _can_use_triton(arg1):
             res = arg0.__sparse_dot_dsd(arg0.__values, arg1)
         else:
             res = _spmm(arg1, arg0.__layout, arg0.__values)
@@ -189,7 +190,7 @@ def _masked_matmul(cls, a, b, mask):
             return NotImplemented
         b = b.transpose(-2, -1)
         assert b.is_contiguous()
-        if _can_use_triton(a, mask.__sparse_dot_sdd.layout):
+        if _can_use_triton(a):
             res = mask.__sparse_dot_sdd(a, b)
         else:
             res = _sddmm(a, b, mask.__layout)
@@ -199,7 +200,7 @@ def _masked_matmul(cls, a, b, mask):
     def _softmax(cls, arg0, dim):
         if not (dim == -1 or dim == 2):
             return NotImplemented
-        if _can_use_triton(arg0, arg0.__sparse_softmax.layout):
+        if _can_use_triton(arg0):
             # TODO triton softmax performs an in-place operation
             # res = arg0.__sparse_softmax(arg0.__values)
             res = arg0.__sparse_softmax(arg0.__values.clone())
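
The _can_use_triton simplification and the new constructor assert work as a pair: since __init__ now guarantees that values and layout share a device, the Triton gate only needs to inspect a single tensor instead of comparing two. A sketch of the resulting logic follows; the trailing return True is an assumption, since the diff shows only the head of the function, and blocksparse_matmul stands in for the optional Triton import:

import torch

blocksparse_matmul = None  # placeholder: set by the optional Triton import

def _can_use_triton(a: torch.Tensor) -> bool:
    # Triton kernels only run on GPU tensors; the constructor assert makes
    # one device check sufficient for both values and layout.
    if a.device.type == "cpu":
        return False
    # Triton (or its blocksparse ops) may not be importable.
    if blocksparse_matmul is None:
        return False
    return True  # assumed tail; the diff only shows the function head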
