Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 2 additions & 5 deletions csrc/flash_dmattn/flash_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -979,13 +979,10 @@ mha_bwd(
dbias_expanded = has_bias
? (
(num_heads_bias != num_heads || batch_size_bias != batch_size || seqlen_q_bias != seqlen_q) // MQA / GQA or dbias has different batch size or seqlen_q
- ? torch::empty({batch_size, num_heads, seqlen_q, seqlen_k_rounded}, opts)
+ ? torch::zeros({batch_size, num_heads, seqlen_q, seqlen_k_rounded}, opts)
: dbias
)
: torch::empty({0}, opts);
- if (has_bias) {
- dbias_expanded.zero_();
- }

Flash_bwd_params params;

Expand Down Expand Up @@ -1050,7 +1047,7 @@ mha_bwd(
}
}

- return { dq, dk, dv, dbias, softmax_d };
+ return {dq, dk, dv, dbias, softmax_d};
}

std::vector<at::Tensor>
Expand Down
4 changes: 3 additions & 1 deletion flash_dmattn/flash_dmattn_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,6 +409,7 @@ def forward(
deterministic = False
if return_softmax is None:
return_softmax = False
+ seqlen_k_bias_og = bias.shape[-1] if bias is not None else 0

# Padding to multiple of 8 for 16-bit memory allocations
head_size_og = q.size(3)
Expand Down Expand Up @@ -446,6 +447,7 @@ def forward(
ctx.is_causal = is_causal
ctx.softcap = softcap
ctx.deterministic = deterministic
+ ctx.seqlen_k_bias_og = seqlen_k_bias_og

out = out_padded[..., :head_size_og]

Expand Down Expand Up @@ -491,7 +493,7 @@ def backward(
dv = dv[..., : dout.shape[-1]]

if dbias is not None:
- dbias = dbias[..., : k.shape[1]]
+ dbias = dbias[..., :k.shape[1]].sum(dim=-1, keepdim=True) if ctx.seqlen_k_bias_og == 1 else dbias[..., : k.shape[1]]
Copy link

Copilot AI Oct 12, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nitpick] This line is quite complex and hard to read. Consider breaking it into multiple lines or extracting the conditional logic into a separate variable for better readability.

Suggested change
- dbias = dbias[..., :k.shape[1]].sum(dim=-1, keepdim=True) if ctx.seqlen_k_bias_og == 1 else dbias[..., : k.shape[1]]
+ if ctx.seqlen_k_bias_og == 1:
+     dbias = dbias[..., :k.shape[1]].sum(dim=-1, keepdim=True)
+ else:
+     dbias = dbias[..., :k.shape[1]]

Copilot uses AI. Check for mistakes.

return dq, dk, dv, None, dbias, None, None, None, None, None, None

Expand Down