Minor blocksparse refactoring, update block size restrictions, relax power of two constraint (#277)

* Relax device size restrictions

* Refactor device creation and run all tests

* linting

Co-authored-by: Cole Hawkins <colehawk@amazon.com>
2 people authored and blefaudeux committed Apr 20, 2022
1 parent c4c7b5f commit 72fb5c7
Showing 6 changed files with 45 additions and 38 deletions.
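
In short, the commit adds 128 to the supported block sizes and replaces the power-of-two requirement on the sequence length with a multiple-of-block-size check. Below is a minimal sketch of a configuration the relaxed constraints now allow; the import path, the `attention(q, k, v)` call, and the `(batch, heads, seq, head_dim)` input layout are assumptions based on the tests and assertions in this diff, not code taken from the commit itself.

```python
import torch

# Import path assumed; the tests below construct BlockSparseAttention(layout, block)
from xformers.components.attention import BlockSparseAttention

block_size = 128  # newly supported block size
seq_len = 384     # a multiple of 128, but not a power of two
n_heads, batch, head_dim = 2, 2, 64

# Causal (lower-triangular) block layout, one block mask per head
n_blocks = seq_len // block_size
layout = torch.tril(torch.ones([n_heads, n_blocks, n_blocks], dtype=torch.long))

attention = BlockSparseAttention(layout, block_size).cuda()

# Blocksparse only works on fp16 (see blocksparse.py below); input shapes are assumed
q = torch.randn(batch, n_heads, seq_len, head_dim, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)
out = attention(q, k, v)  # seq_len=384 would have tripped the old power-of-two assertion
```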
2 changes: 1 addition & 1 deletion examples/microGPT.py
@@ -21,7 +21,7 @@


class GPT(pl.LightningModule):
""" the full GPT language model, with a context size of block_size """
"""the full GPT language model, with a context size of block_size"""

def __init__(
self,
12 changes: 7 additions & 5 deletions tests/test_triton_blocksparse.py
@@ -98,7 +98,7 @@ def test_matmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE, Z=32, H=2, M=512, N=384, K


@pytest.mark.skipif(not _triton_available, reason="Triton requires a recent CUDA gpu")
@pytest.mark.parametrize("BLOCK", [32])
@pytest.mark.parametrize("BLOCK", [32, 128])
@pytest.mark.parametrize("WIDTH", [256, 576, 1024, 1792])
@pytest.mark.parametrize("DTYPE", [torch.float16, torch.float32])
def test_softmax(BLOCK, WIDTH, DTYPE):
@@ -127,12 +127,12 @@ def test_softmax(BLOCK, WIDTH, DTYPE):


@pytest.mark.skipif(not _triton_available, reason="Triton requires a recent CUDA gpu")
@pytest.mark.parametrize("block", [32, 43]) # 16, 32,
@pytest.mark.parametrize("block", [32, 43, 128]) # 16, 32,
def test_attention_fwd_bwd(
block,
input_scale=1.0,
scale=1 / 8.0,
-n_ctx=256,
+n_ctx=384,
dtype=torch.float16,
batch_size=2,
n_heads=2,
@@ -152,12 +152,14 @@ def loss_fn(x):

# Triton:
n_blocks = n_ctx // block
-layout = torch.tril(torch.ones([n_heads, n_blocks, n_blocks], dtype=torch.long))
+layout = torch.tril(
+    torch.ones([n_heads, n_blocks, n_blocks], dtype=torch.long), diagonal=-1
+)
query, key, value = [x.clone() for x in qkvs]
query.retain_grad()
key.retain_grad()
value.retain_grad()
-if block not in [16, 32, 64]:
+if block not in [16, 32, 64, 128]:
# Check that unsupported dimensions are caught
with pytest.raises(AssertionError):
_ = BlockSparseAttention(layout, block)
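
The test parameters change accordingly: block=128 joins the parametrization, the layout is now built with diagonal=-1 (strictly below the block diagonal), and n_ctx moves from 256 to 384, a length that is divisible by the supported block sizes exercised here but is not a power of two, so it covers the relaxed sequence-length check (block=43 is still rejected earlier by the block-size assertion). A quick, illustrative check of that arithmetic, using a hypothetical helper name:

```python
def is_power_of_two(n: int) -> bool:
    # Hypothetical helper, for illustration only
    return n > 0 and (n & (n - 1)) == 0

n_ctx = 384
assert not is_power_of_two(n_ctx)              # would have failed the old constraint
assert all(n_ctx % b == 0 for b in (32, 128))  # passes the new multiple-of-block check
```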
2 changes: 1 addition & 1 deletion tests/test_triton_fused_linear.py
@@ -38,7 +38,7 @@
"dtype", [torch.float32]
) # Triton use tensor cores, which return slightly different results to pytorch mm
def test_fused_matmul(shape, dtype):
""" Check that the matrix multiply kernel and Pytorch's give the same results"""
"""Check that the matrix multiply kernel and Pytorch's give the same results"""
torch.random.manual_seed(0)

# Raw fused matrix multiply first, to catch gross errors
2 changes: 1 addition & 1 deletion xformers/benchmarks/utils.py
@@ -26,7 +26,7 @@


def pretty_print(results, title, units):
""" Printout the contents of a dict as a human-readable and Markdown compatible array"""
"""Printout the contents of a dict as a human-readable and Markdown compatible array"""
print(title)
header = " Units: {:<45}".format(units)
print("| " + header + "|" + "".join("{0:<20}|".format(k) for k in results.keys()))
63 changes: 34 additions & 29 deletions xformers/components/attention/blocksparse.py
@@ -82,7 +82,8 @@ def __init__(
16,
32,
64,
), "Only block sizes in [16, 32, 64] are supported"
128,
), "Only block sizes in [16, 32, 64, 128] are supported"

super().__init__()

@@ -112,6 +113,32 @@ def update_mask_type(self, mask: torch.Tensor):
)
mask = bool_mask_to_additive(mask)

+def create_triton_kernels(self, device):
+# blocksparse operators
+self.sparse_dot_sdd = blocksparse_matmul(
+self.layout,
+self.block_size,
+"sdd",
+trans_a=False,
+trans_b=True,
+device=device,
+)
+
+self.sparse_dot_dsd = blocksparse_matmul(
+self.layout,
+self.block_size,
+"dsd",
+trans_a=False,
+trans_b=False,
+device=device,
+)
+
+self.sparse_softmax = blocksparse_softmax(
+self.layout,
+self.block_size,
+device=device,
+)
+
def forward(
self,
q: torch.Tensor,
@@ -132,31 +159,9 @@ def forward(
"""

# Delayed triton init, to make sure that we get the right device
+# Infer device from query
if not hasattr(self, "sparse_dot_sdd"):
-# blocksparse operators
-self.sparse_dot_sdd = blocksparse_matmul(
-self.layout,
-self.block_size,
-"sdd",
-trans_a=False,
-trans_b=True,
-device=q.device,
-)
-
-self.sparse_dot_dsd = blocksparse_matmul(
-self.layout,
-self.block_size,
-"dsd",
-trans_a=False,
-trans_b=False,
-device=q.device,
-)
-
-self.sparse_softmax = blocksparse_softmax(
-self.layout,
-self.block_size,
-device=q.device,
-)
+self.create_triton_kernels(q.device)

assert (
q.shape[-2] == k.shape[-2]
Expand All @@ -169,10 +174,10 @@ def forward(
k.shape[-2] == self.layout.shape[-2] * self.block_size
), "Actual sequence size and layout are inconsistent"

-assert math.log(
-q.shape[-2], 2
-).is_integer(), (
-"For now blocksparse only works on power-of-two sequence lengths"
+assert (
+q.shape[-2] % self.block_size
+) == 0, "Sequence length {} must be a multiple of block size {}".format(
+q.shape[-2], self.block_size
)

# Blocksparse only works on fp16
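
In blocksparse.py the kernel construction is factored out into create_triton_kernels and still invoked lazily on the first forward pass, so the Triton operators are built for the device the queries actually live on; the power-of-two sequence-length assertion is replaced by the multiple-of-block-size check shown above. A stripped-down sketch of that lazy, device-bound initialization pattern (an illustrative stand-in, not the xformers class):

```python
import torch


class LazyKernelModule(torch.nn.Module):
    """Illustrative stand-in: build device-bound state only when the first
    input arrives, so the module picks up the right device automatically."""

    def _create_kernels(self, device: torch.device) -> None:
        # Placeholder for device-specific setup (e.g. compiling Triton kernels)
        self.scale = torch.tensor(0.5, device=device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Deferred init, mirroring `if not hasattr(self, "sparse_dot_sdd"): ...`
        if not hasattr(self, "scale"):
            self._create_kernels(x.device)
        return x * self.scale
```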
2 changes: 1 addition & 1 deletion xformers/utils.py
@@ -91,7 +91,7 @@ def rmf(filename: str) -> None:

@contextlib.contextmanager
def temp_files_ctx(num: int) -> Generator:
""" A context to get tempfiles and ensure they are cleaned up. """
"""A context to get tempfiles and ensure they are cleaned up."""
files = [tempfile.mkstemp()[1] for _ in range(num)]

yield tuple(files)
