9 changes: 5 additions & 4 deletions flash_dmattn/flash_dmattn_triton.py
@@ -132,10 +132,10 @@ def _fwd_kernel(
)

# Check if any element in mask is non-zero
- any_active = tl.sum(mask > 0) > 0
+ any_active = tl.sum(mask) > 0

Copilot AI Jul 4, 2025

Consider using tl.any(mask) (or tl.any(mask != 0)) instead of tl.sum(mask) > 0 to more directly express 'any active elements' and potentially improve efficiency.

Suggested change:
- any_active = tl.sum(mask) > 0
+ any_active = tl.any(mask)

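A caveat on this suggestion: tl.any is not available in every Triton release, so a portable way to express the same "any active" test is an ordinary reduction. A minimal sketch, where the kernel name and signature are illustrative rather than taken from this PR:

```python
import triton
import triton.language as tl

@triton.jit
def any_active_demo(mask_ptr, out_ptr, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    mask = tl.load(mask_ptr + offs)
    # tl.sum(mask) > 0 is only valid because mask is non-negative here;
    # tl.max(mask) > 0 expresses the same "any active" test without
    # depending on a tl.any primitive being available.
    any_active = tl.max(mask) > 0
    tl.store(out_ptr, any_active.to(tl.int32))
```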

# compute acc_s
- acc_s = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
+ acc_s = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
if any_active:
# Load k
if EVEN_N & EVEN_M: # If we just do "if EVEN_N", there seems to be some race condition
@@ -166,7 +166,7 @@
if IS_CAUSAL:
acc_s += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float("-inf"))
# Apply dynamic mask
- acc_s += tl.where(mask > 0.0, 0, float("-inf"))
+ acc_s += tl.where(mask > 0, 0, float("-inf"))

# Load bias
if EVEN_M & EVEN_N:
@@ -184,6 +184,7 @@
# can then fuse the mult and add into an fma instruction. But if we have bias we need
# to multiply with softmax_scale here.
acc_s = acc_s * softmax_scale + bias

Copilot AI Jul 4, 2025


This unconditional scaling duplicates the conditional scaling on the next line, resulting in double application of softmax_scale. Remove this line so that the tl.where version is the single scaling step.

Suggested change:
- acc_s = acc_s * softmax_scale + bias
+ # Removed duplicate scaling to avoid double application of softmax_scale.

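To make the double application concrete, here is a plain-Python illustration with hypothetical values:

```python
score, softmax_scale, bias = 2.0, 0.125, 0.5

intended = score * softmax_scale + bias  # 0.75: scale and bias applied once
# With both the unconditional line and the guarded tl.where line present,
# every finite score is scaled and biased twice:
actual = (score * softmax_scale + bias) * softmax_scale + bias  # 0.59375
```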
+ acc_s = tl.where(acc_s != float("-inf"), acc_s * softmax_scale + bias, acc_s)
m_ij = tl.maximum(tl.max(acc_s, 1), lse_i)
p = tl.exp(acc_s - m_ij[:, None])
l_ij = tl.sum(p, 1)
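For readers following the numerics: the m_ij / p / l_ij lines above, together with the acc_o_scale rescaling in the next hunk, implement FlashAttention's online-softmax recurrence. A rough NumPy sketch of one step, where names mirror the kernel, m_i (the running max carried between iterations) is assumed, and the code is an illustration rather than the kernel's exact logic:

```python
import numpy as np

def online_softmax_step(acc_o, lse_i, m_i, scores, v):
    """One K/V block of the online-softmax recurrence (NumPy sketch)."""
    m_ij = np.maximum(scores.max(axis=1), lse_i)   # new running max
    p = np.exp(scores - m_ij[:, None])             # overflow-safe exponentials
    l_ij = p.sum(axis=1)                           # this block's normalizer
    acc_o_scale = np.exp(m_i - m_ij)               # correct the old accumulator
    acc_o = acc_o * acc_o_scale[:, None] + p @ v   # fold in this block's contribution
    lse_i = m_ij + np.log(np.exp(lse_i - m_ij) + l_ij)  # update running log-sum-exp
    return acc_o, lse_i, m_ij
```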
@@ -197,8 +198,8 @@
acc_o_scale = tl.load(t_ptrs)
acc_o = acc_o * acc_o_scale[:, None]

- # update acc_o
+ if any_active:
# load v

Copilot AI Jul 4, 2025


[nitpick] The comment # load v is vague; consider expanding it to clarify what data is being loaded and why, e.g., # load V tensor block for value projection.

Suggested change:
- # load v
+ # Load the V tensor block for value projection in the attention mechanism.

if EVEN_N & EVEN_M: # If we just do "if EVEN_N", there seems to be some race condition
if EVEN_HEADDIM:
v = tl.load(v_ptrs + start_n * stride_vn)
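The EVEN_N & EVEN_M specializations above exist so that fully in-bounds tiles can use unmasked loads, while ragged edge tiles fall back to bounds-masked ones. A hedged sketch of the pattern, with function and parameter names that are illustrative rather than the kernel's own:

```python
import triton
import triton.language as tl

@triton.jit
def load_v_block(v_ptr, stride_vn, start_n, seqlen_k,
                 BLOCK_N: tl.constexpr, BLOCK_HEADDIM: tl.constexpr,
                 EVEN_N: tl.constexpr):
    # Helper meant to be inlined into a larger Triton kernel.
    offs_n = start_n + tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_HEADDIM)
    v_ptrs = v_ptr + offs_n[:, None] * stride_vn + offs_d[None, :]
    if EVEN_N:
        # Fast path: the block is fully in bounds, no mask needed.
        v = tl.load(v_ptrs)
    else:
        # Edge tile: mask out-of-range rows and substitute zeros.
        v = tl.load(v_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0)
    return v
```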