From b32176f6e1faed8bc9781b492222c2f656e2fd02 Mon Sep 17 00:00:00 2001 From: Loser Cheems Date: Fri, 4 Jul 2025 13:39:55 +0800 Subject: [PATCH] Fixes mask comparison and scaling logic in attention kernel Simplifies mask comparison from `> 0` to direct boolean evaluation for consistency. Removes trailing whitespace for code cleanliness. Corrects duplicate softmax scaling application that was causing incorrect attention computations. Improves code readability by moving comment to more appropriate location. --- flash_dmattn/flash_dmattn_triton.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/flash_dmattn/flash_dmattn_triton.py b/flash_dmattn/flash_dmattn_triton.py index 0bef421..f273451 100644 --- a/flash_dmattn/flash_dmattn_triton.py +++ b/flash_dmattn/flash_dmattn_triton.py @@ -132,10 +132,10 @@ def _fwd_kernel( ) # Check if any element in mask is non-zero - any_active = tl.sum(mask > 0) > 0 + any_active = tl.sum(mask) > 0 # compute acc_s - acc_s = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + acc_s = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) if any_active: # Load k if EVEN_N & EVEN_M: # If we just do "if EVEN_N", there seems to be some race condition @@ -166,7 +166,7 @@ def _fwd_kernel( if IS_CAUSAL: acc_s += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float("-inf")) # Apply dynamic mask - acc_s += tl.where(mask > 0.0, 0, float("-inf")) + acc_s += tl.where(mask > 0, 0, float("-inf")) # Load bias if EVEN_M & EVEN_N: @@ -184,6 +184,7 @@ def _fwd_kernel( # can then fuse the mult and add into an fma instruction. But if we have bias we need to # to multiply with softmax_scale here. acc_s = acc_s * softmax_scale + bias + acc_s = tl.where(acc_s != float("-inf"), acc_s * softmax_scale + bias, acc_s) m_ij = tl.maximum(tl.max(acc_s, 1), lse_i) p = tl.exp(acc_s - m_ij[:, None]) l_ij = tl.sum(p, 1) @@ -197,8 +198,8 @@ def _fwd_kernel( acc_o_scale = tl.load(t_ptrs) acc_o = acc_o * acc_o_scale[:, None] - # update acc_o if any_active: + # load v if EVEN_N & EVEN_M: # If we just do "if EVEN_N", there seems to be some race condition if EVEN_HEADDIM: v = tl.load(v_ptrs + start_n * stride_vn)