Merged
38 changes: 37 additions & 1 deletion flash_dmattn/utils/mask.py
@@ -22,6 +22,7 @@ def dynamic_mask(
    attention_mask: Optional[torch.Tensor],
    window_size: int,
    min_dtype: float,
    block_size: Optional[int] = None,
):
    r"""
    This function generates a dynamic mask based on the top-k attention bias.
@@ -33,11 +34,18 @@ ({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, key_len).
            ({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, key_len).
        window_size (int): The number of top elements to consider for the mask.
        min_dtype (float): The minimum value to use for masking.
        block_size (Optional[int]): Optional size of aggregation blocks to smooth the
            resulting mask along the key dimension.

    Returns:
        attention_mask (Tensor): The attention mask tensor of shape
            ({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, key_len).
    """
    if block_size is not None:
        if int(block_size) != block_size or block_size <= 0:
            raise ValueError(f"block_size must be a positive integer, got {block_size}.")
        block_size = int(block_size)

    attention_bias = attention_bias.masked_fill(~attention_mask, min_dtype) if attention_mask is not None else attention_bias
    topk_values, topk_indices = torch.topk(
        attention_bias.detach(),
@@ -46,6 +54,26 @@
    attention_mask = torch.zeros_like(
        attention_bias, dtype=torch.bool, device=attention_bias.device
    ).scatter_(-1, topk_indices, topk_values != min_dtype)

    if block_size is not None and block_size > 1:
        key_len = attention_mask.shape[-1]
        full_len = (key_len // block_size) * block_size

        if full_len:
            block_view = attention_mask[..., :full_len]
            block_shape = (*block_view.shape[:-1], full_len // block_size, block_size)
            blocks = block_view.view(*block_shape)
            block_counts = blocks.sum(dim=-1).to(torch.int32)
            block_keep = (block_counts * 2) > block_size
Copilot AI (Oct 30, 2025): The condition (block_counts * 2) > block_size is a magic number pattern that obscures the majority voting logic. Consider extracting this as a named constant or adding an inline comment to clarify this represents "keep block if more than 50% of elements are True".
            blocks.copy_(block_keep.unsqueeze(-1).expand_as(blocks))

        if key_len > full_len:
            tail_slice = attention_mask[..., full_len:]
            tail_len = tail_slice.shape[-1]
            tail_counts = tail_slice.sum(dim=-1, keepdim=True).to(torch.int32)
Comment on lines +66 to +73

Copilot AI (Oct 30, 2025): Converting to torch.int32 may cause unnecessary precision loss or overflow for large block sizes. Consider using torch.int64 for safer integer arithmetic, or document why int32 is sufficient for the expected use cases.

Suggested change:
-            block_counts = blocks.sum(dim=-1).to(torch.int32)
-            block_keep = (block_counts * 2) > block_size
-            blocks.copy_(block_keep.unsqueeze(-1).expand_as(blocks))
-        if key_len > full_len:
-            tail_slice = attention_mask[..., full_len:]
-            tail_len = tail_slice.shape[-1]
-            tail_counts = tail_slice.sum(dim=-1, keepdim=True).to(torch.int32)
+            block_counts = blocks.sum(dim=-1).to(torch.int64)
+            block_keep = (block_counts * 2) > block_size
+            blocks.copy_(block_keep.unsqueeze(-1).expand_as(blocks))
+        if key_len > full_len:
+            tail_slice = attention_mask[..., full_len:]
+            tail_len = tail_slice.shape[-1]
+            tail_counts = tail_slice.sum(dim=-1, keepdim=True).to(torch.int64)
Copilot AI (Oct 30, 2025): Same as line 66, converting to torch.int32 may cause precision loss. Consider using torch.int64 for consistency and to handle edge cases with large tensors safely.

Suggested change:
-            tail_counts = tail_slice.sum(dim=-1, keepdim=True).to(torch.int32)
+            tail_counts = tail_slice.sum(dim=-1, keepdim=True).to(torch.int64)
            tail_keep = (tail_counts * 2) > tail_len
Copilot AI (Oct 30, 2025): Same as line 67, this condition (tail_counts * 2) > tail_len uses a magic number pattern. Consider adding a comment or using a named constant to clarify this represents the majority voting threshold (>50%).
            tail_slice.copy_(tail_keep.expand_as(tail_slice))

    return attention_mask
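
For context, here is a minimal usage sketch of the new parameter. It is not part of the diff; it assumes the module is importable as flash_dmattn.utils.mask, and the shapes and values are illustrative only.

import torch
from flash_dmattn.utils.mask import dynamic_mask

# Illustrative bias: batch=1, heads=1, query_len=1, key_len=10.
attention_bias = torch.randn(1, 1, 1, 10)
min_dtype = torch.finfo(attention_bias.dtype).min

# Plain top-k: at most window_size keys stay True per query.
mask = dynamic_mask(attention_bias, None, window_size=4, min_dtype=min_dtype)

# With block_size=4: each full block of 4 keys becomes all-True only when more
# than half of its keys survived top-k; the 2-key tail is voted on separately,
# so the result is block-aligned rather than an exact top-k selection.
smoothed = dynamic_mask(attention_bias, None, window_size=4, min_dtype=min_dtype, block_size=4)
print(mask.sum().item(), smoothed.sum().item())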


@@ -57,6 +85,7 @@ def create_mask(
    key_len: int,
    window_size: int,
    min_dtype: float,
    block_size: Optional[int] = None,
) -> torch.Tensor:
    r"""
    This function creates a mask tensor for Flash Dynamic Mask Attention.
@@ -73,6 +102,7 @@ key_len (int): The sequence length of the key.
        window_size (int): The number of top elements to consider for the attention mask.
        min_dtype (float): The minimum value to use for masking.
        block_size (Optional[int]): Optional size of aggregation blocks after top-k masking.

    Returns:
        attention (Tensor): The attention mask tensor of shape
@@ -103,6 +133,12 @@ )
    )

    # Generate dynamic mask based on attention_bias and attention_mask
    attention_mask = dynamic_mask(attention_bias, attention_mask, window_size, min_dtype)
    attention_mask = dynamic_mask(
        attention_bias,
        attention_mask,
        window_size,
        min_dtype,
        block_size=block_size,
    )

    return attention_mask
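
To make the review suggestions above concrete, here is a hedged sketch (not part of the PR) of the block-aggregation step rewritten with int64 counts and a commented majority-vote threshold. The helper name block_majority_vote is hypothetical; the snippet only mirrors the logic shown in the diff.

import torch

def block_majority_vote(mask: torch.Tensor, block_size: int) -> torch.Tensor:
    # Smooth a boolean key mask in place so every block of `block_size` keys is
    # either all True or all False, decided by majority vote (> 50% True).
    key_len = mask.shape[-1]
    full_len = (key_len // block_size) * block_size

    if full_len:
        blocks = mask[..., :full_len].view(*mask.shape[:-1], full_len // block_size, block_size)
        # int64 counts, as suggested in review, rule out overflow for very large blocks.
        block_counts = blocks.sum(dim=-1).to(torch.int64)
        block_keep = (block_counts * 2) > block_size  # keep block if > 50% of its keys are True
        blocks.copy_(block_keep.unsqueeze(-1).expand_as(blocks))

    if key_len > full_len:  # ragged tail shorter than block_size
        tail = mask[..., full_len:]
        tail_counts = tail.sum(dim=-1, keepdim=True).to(torch.int64)
        tail_keep = (tail_counts * 2) > tail.shape[-1]
        tail.copy_(tail_keep.expand_as(tail))

    return mask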