Fix datatype issue with sparse attention softmax (microsoft#363)
Fixes a datatype issue with softmax where the number of blocks passed to the Triton kernel source was a torch.Tensor but should have been a Python integer. In some environments (e.g., conda) this caused Triton to fail to serialize the input (and crash in our tests). Switching to the datatype Triton expects resolved the issue.

Co-authored-by: Shaden Smith <Shaden.Smith@microsoft.com>
jeffra and Shaden Smith committed Sep 10, 2020
1 parent 093f09f commit dca0b78
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion deepspeed/ops/sparse_attention/softmax.py
Expand Up @@ -234,7 +234,7 @@ def __init__(self, layout, block, bench=False):
bench: optional: set if you want to do benchmarking
"""

-        self.num_blocks = layout.sum()
+        self.num_blocks = layout.sum().item()
self.spdims = layout.shape
self.layout = layout
self.block = block
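The root cause can be sketched in a few lines (a minimal example, assuming a toy binary layout tensor; the variable names mirror the snippet above): `layout.sum()` returns a 0-dim torch.Tensor, while `.item()` extracts the plain Python int that Triton can serialize.

```python
import torch

# Toy sparsity layout: 1 marks a block that participates in attention.
layout = torch.tensor([[1, 0], [0, 1]])

num_blocks_tensor = layout.sum()         # 0-dim torch.Tensor, e.g. tensor(2)
num_blocks_int = layout.sum().item()     # plain Python int, e.g. 2

print(type(num_blocks_tensor))  # <class 'torch.Tensor'>
print(type(num_blocks_int))     # <class 'int'>
```

Passing `num_blocks_int` instead of `num_blocks_tensor` to the kernel avoids the serialization failure.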
