From dca0b7841e4449e9a03d0f603df634b35d2e015a Mon Sep 17 00:00:00 2001 From: Jeff Rasley Date: Thu, 10 Sep 2020 00:07:30 -0700 Subject: [PATCH] Fix datatype issue with sparse attention softmax (#363) Fixes a dataype issue with softmax where the number of blocks being sent to the Triton kernel source was a torch.Tensor but should have been a python integer. On some environments (e.g., conda) this resulted in triton not knowing how to serialize the input (and crashing in our tests). Once switching to the correct datatype that triton expects this seems to have solved the issue. Co-authored-by: Shaden Smith --- deepspeed/ops/sparse_attention/softmax.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepspeed/ops/sparse_attention/softmax.py b/deepspeed/ops/sparse_attention/softmax.py index 814e9cc50e19..41267298a0a4 100644 --- a/deepspeed/ops/sparse_attention/softmax.py +++ b/deepspeed/ops/sparse_attention/softmax.py @@ -234,7 +234,7 @@ def __init__(self, layout, block, bench=False): bench: optional: set if you want to do benchmarking """ - self.num_blocks = layout.sum() + self.num_blocks = layout.sum().item() self.spdims = layout.shape self.layout = layout self.block = block