flash_dmattn/integrations/flash_dynamic_mask_attention.py (6 changes: 3 additions & 3 deletions)
@@ -30,10 +30,10 @@ def flash_dynamic_mask_attention_forward(
         query (torch.Tensor): The query tensor of shape (batch_size, num_heads, query_len, head_dim).
         key (torch.Tensor): The key tensor of shape (batch_size, num_kv_heads, key_len, head_dim).
         value (torch.Tensor): The value tensor of shape (batch_size, num_kv_heads, key_len, head_dim).
-        attention_mask (Optional[torch.Tensor]): The attention mask boolean tensor of shape
-            (batch_size, seq_len) or (batch_size, {num_heads|num_kv_heads|1}, {query_len|0}, key_len).
+        attention_mask (Optional[torch.Tensor]): The attention mask boolean tensor of shape
+            (batch_size, seq_len) or ({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, {key_len|1}).
         attention_bias (Optional[torch.Tensor]): The attention bias float tensor of shape
-            (batch_size, {num_heads|num_kv_heads|1}, {query_len|0}, key_len).
+            ({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, {key_len|1}).
         scaling (Optional[float]): The scaling factor for the attention scores.
         window_size (Optional[int]): The size of the window to keep.
         softcap (Optional[float]): The softcap value for the attention scores.
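
As a quick illustration of the relaxed shapes in the updated docstring, the following minimal sketch checks that a mask and a bias with singleton dimensions broadcast against each other. The concrete sizes are made up for the example and are not taken from this PR.

```python
import torch

# Illustrative sizes only (not from the PR).
batch, num_kv_heads, query_len, key_len = 2, 4, 8, 128

# Per the updated docstring, either tensor may use singleton dimensions
# in place of batch_size, the head dimension, query_len, or key_len.
attention_mask = torch.rand(batch, 1, 1, key_len) > 0.1            # bool, ({batch|1}, 1, 1, key_len)
attention_bias = torch.randn(1, num_kv_heads, query_len, key_len)  # float, (1, num_kv_heads, query_len, key_len)

# The shapes broadcast to a common 4D shape without any manual expansion.
print(torch.broadcast_shapes(attention_mask.shape, attention_bias.shape))
# torch.Size([2, 4, 8, 128])
```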
@@ -626,11 +626,6 @@ def _flash_dynamic_mask_attention_forward(
     ):
         min_dtype = torch.finfo(query_states.dtype).min
         if attention_mask is not None:

Copilot AI commented on Oct 13, 2025:

After removing the dimension compatibility checks, this code will fail if attention_mask and attention_bias have incompatible shapes for broadcasting in the masked_fill operation. The removed checks were handling dimension mismatches between 3D and 4D tensors.

Suggested change
-        if attention_mask is not None:
+        if attention_mask is not None:
+            # Ensure attention_mask is broadcastable to attention_bias
+            if attention_mask.shape != attention_bias.shape:
+                # Try to unsqueeze or expand attention_mask to match attention_bias
+                # Common case: attention_mask is (batch, seq_len) and attention_bias is (batch, num_heads, seq_len, seq_len)
+                # or attention_mask is (batch, 1, seq_len, seq_len)
+                # We'll try to expand attention_mask to match attention_bias
+                try:
+                    attention_mask = attention_mask.expand_as(attention_bias)
+                except RuntimeError:
+                    # Try unsqueeze and expand for common 3D->4D case
+                    while attention_mask.dim() < attention_bias.dim():
+                        attention_mask = attention_mask.unsqueeze(1)
+                    attention_mask = attention_mask.expand_as(attention_bias)

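To make the concern above concrete, here is a small standalone sketch (hypothetical shapes, not code from this PR) of why masked_fill raises when the mask is not broadcastable against the bias, and how inserting singleton dimensions, as the suggestion's 3D-to-4D loop does, resolves it.

```python
import torch

# Hypothetical shapes chosen so that a raw 2D mask does not broadcast.
batch, num_heads, query_len, key_len = 2, 4, 8, 8
attention_bias = torch.randn(batch, num_heads, query_len, key_len)
attention_mask = torch.ones(batch, key_len, dtype=torch.bool)  # (batch, seq_len)
min_dtype = torch.finfo(attention_bias.dtype).min

try:
    # Broadcasting aligns trailing dimensions, so batch lines up with
    # query_len here and masked_fill raises a RuntimeError.
    attention_bias.masked_fill(~attention_mask, min_dtype)
except RuntimeError as err:
    print("masked_fill failed:", err)

# Adding singleton head and query dimensions makes the mask broadcastable.
expanded_mask = attention_mask[:, None, None, :]  # (batch, 1, 1, key_len)
out = attention_bias.masked_fill(~expanded_mask, min_dtype)
print(out.shape)  # torch.Size([2, 4, 8, 8])
```
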
-            if attention_mask.dim() == 4 and attention_bias.dim() == 3:
-                attention_bias = attention_bias.unsqueeze(-2).expand(-1, -1, query_length, -1)
-            if attention_mask.dim() == 3 and attention_bias.dim() == 4:
-                attention_mask = attention_mask.unsqueeze(-2).expand(-1, -1, query_length, -1)
-
             topk_values, topk_indices = torch.topk(
                 attention_bias.masked_fill(~attention_mask, min_dtype).detach(),
                 window_size, dim=-1, largest=True, sorted=False
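
For reference, the window selection in the remaining context lines of this hunk can be reproduced in isolation. This is a minimal sketch with made-up shapes and window_size, not the library's implementation.

```python
import torch

# Made-up shapes for illustration; the real function derives these from its inputs.
batch, num_heads, query_len, key_len, window_size = 1, 2, 4, 16, 8

attention_bias = torch.randn(batch, num_heads, query_len, key_len)
attention_mask = torch.rand(batch, num_heads, query_len, key_len) > 0.2  # True = keep
min_dtype = torch.finfo(attention_bias.dtype).min

# Masked-out positions are pushed to the dtype minimum so top-k never selects them,
# then the window_size largest bias values per query are kept.
topk_values, topk_indices = torch.topk(
    attention_bias.masked_fill(~attention_mask, min_dtype).detach(),
    window_size, dim=-1, largest=True, sorted=False
)
print(topk_indices.shape)  # torch.Size([1, 2, 4, 8])
```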