diff --git a/flash_dmattn/flash_dmattn_triton.py b/flash_dmattn/flash_dmattn_triton.py
index 0bef421..f273451 100644
--- a/flash_dmattn/flash_dmattn_triton.py
+++ b/flash_dmattn/flash_dmattn_triton.py
@@ -132,10 +132,10 @@ def _fwd_kernel(
             )

         # Check if any element in mask is non-zero
-        any_active = tl.sum(mask > 0) > 0
+        any_active = tl.sum(mask) > 0

         # compute acc_s
-        acc_s = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
+        acc_s = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
         if any_active:
             # Load k
             if EVEN_N & EVEN_M:  # If we just do "if EVEN_N", there seems to be some race condition
@@ -166,7 +166,7 @@ def _fwd_kernel(
             if IS_CAUSAL:
                 acc_s += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float("-inf"))
             # Apply dynamic mask
-            acc_s += tl.where(mask > 0.0, 0, float("-inf"))
+            acc_s += tl.where(mask > 0, 0, float("-inf"))

             # Load bias
             if EVEN_M & EVEN_N:
@@ -184,6 +184,7 @@ def _fwd_kernel(
             # can then fuse the mult and add into an fma instruction. But if we have bias we need to
             # to multiply with softmax_scale here.
             acc_s = acc_s * softmax_scale + bias
+            acc_s = tl.where(acc_s != float("-inf"), acc_s * softmax_scale + bias, acc_s)
             m_ij = tl.maximum(tl.max(acc_s, 1), lse_i)
             p = tl.exp(acc_s - m_ij[:, None])
             l_ij = tl.sum(p, 1)
@@ -197,8 +198,8 @@ def _fwd_kernel(
         acc_o_scale = tl.load(t_ptrs)
         acc_o = acc_o * acc_o_scale[:, None]

-        # update acc_o
         if any_active:
+            # load v
             if EVEN_N & EVEN_M:  # If we just do "if EVEN_N", there seems to be some race condition
                 if EVEN_HEADDIM:
                     v = tl.load(v_ptrs + start_n * stride_vn)
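For reference, a minimal NumPy sketch of the pattern these hunks touch: the per-block `any_active` check (`tl.sum(mask) > 0`), the `-inf` dynamic mask, and the guarded `softmax_scale`/`bias` application. Shapes, names, and the plain-softmax finish below are illustrative only, not taken from the kernel.

```python
import numpy as np

def masked_block_scores(q, k, mask, bias, softmax_scale):
    # Mirrors the new check: sum the float mask directly instead of tl.sum(mask > 0).
    if mask.sum() <= 0:
        return None  # caller would skip the K/V loads and dot products for this block
    acc_s = (q @ k.T).astype(np.float32)
    acc_s += np.where(mask > 0, 0.0, -np.inf)  # dynamic mask: knocked-out keys go to -inf
    # Mirrors the added tl.where guard: scale and bias are applied only to entries
    # that are not already -inf.
    acc_s = np.where(acc_s != -np.inf, acc_s * softmax_scale + bias, acc_s)
    return acc_s

rng = np.random.default_rng(0)
q = rng.standard_normal((4, 8)).astype(np.float32)
k = rng.standard_normal((6, 8)).astype(np.float32)
bias = rng.standard_normal((4, 6)).astype(np.float32)
mask = (rng.random((4, 6)) > 0.5).astype(np.float32)
mask[:, 0] = 1.0  # keep at least one active key per query row

scores = masked_block_scores(q, k, mask, bias, softmax_scale=0.125)
p = np.exp(scores - scores.max(axis=1, keepdims=True))  # masked columns become 0
print(np.isneginf(scores).sum(), p.sum(axis=1))
```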