From 9bd0fd623bc732e9faf4d9695c98c891e2694650 Mon Sep 17 00:00:00 2001
From: LoserCheems
Date: Thu, 30 Oct 2025 20:30:15 +0800
Subject: [PATCH] Adds block-wise smoothing to attention mask

Introduces an optional block_size to aggregate top-k selections along the
key dimension using a majority vote, reducing fragmentation and encouraging
locality in the dynamic mask.

Validates block_size as a positive integer, handles remainder tails, and
forwards the parameter through mask creation. Updates docs accordingly.

Preserves previous behavior when unset and uses in-place ops for efficiency.
---
 flash_dmattn/utils/mask.py | 38 +++++++++++++++++++++++++++++++++++++-
 1 file changed, 37 insertions(+), 1 deletion(-)

diff --git a/flash_dmattn/utils/mask.py b/flash_dmattn/utils/mask.py
index 69f6eb7..a5b3315 100644
--- a/flash_dmattn/utils/mask.py
+++ b/flash_dmattn/utils/mask.py
@@ -22,6 +22,7 @@ def dynamic_mask(
     attention_mask: Optional[torch.Tensor],
     window_size: int,
     min_dtype: float,
+    block_size: Optional[int] = None,
 ):
     r"""
     This function generates a dynamic mask based on the top-k attention bias.
@@ -33,11 +34,18 @@
             ({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, key_len).
         window_size (int): The number of top elements to consider for the mask.
         min_dtype (float): The minimum value to use for masking.
+        block_size (Optional[int]): Optional size of aggregation blocks to smooth the
+            resulting mask along the key dimension.
 
     Returns:
         attention_mask (Tensor): The attention mask tensor of shape
             ({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, key_len).
     """
+    if block_size is not None:
+        if int(block_size) != block_size or block_size <= 0:
+            raise ValueError(f"block_size must be a positive integer, got {block_size}.")
+        block_size = int(block_size)
+
     attention_bias = attention_bias.masked_fill(~attention_mask, min_dtype) if attention_mask is not None else attention_bias
     topk_values, topk_indices = torch.topk(
         attention_bias.detach(),
@@ -46,6 +54,26 @@
     attention_mask = torch.zeros_like(
         attention_bias, dtype=torch.bool, device=attention_bias.device
     ).scatter_(-1, topk_indices, topk_values != min_dtype)
+
+    if block_size is not None and block_size > 1:
+        key_len = attention_mask.shape[-1]
+        full_len = (key_len // block_size) * block_size
+
+        if full_len:
+            block_view = attention_mask[..., :full_len]
+            block_shape = (*block_view.shape[:-1], full_len // block_size, block_size)
+            blocks = block_view.view(*block_shape)
+            block_counts = blocks.sum(dim=-1).to(torch.int32)
+            block_keep = (block_counts * 2) > block_size
+            blocks.copy_(block_keep.unsqueeze(-1).expand_as(blocks))
+
+        if key_len > full_len:
+            tail_slice = attention_mask[..., full_len:]
+            tail_len = tail_slice.shape[-1]
+            tail_counts = tail_slice.sum(dim=-1, keepdim=True).to(torch.int32)
+            tail_keep = (tail_counts * 2) > tail_len
+            tail_slice.copy_(tail_keep.expand_as(tail_slice))
+
     return attention_mask
 
 
@@ -57,6 +85,7 @@ def create_mask(
     key_len: int,
     window_size: int,
     min_dtype: float,
+    block_size: Optional[int] = None,
 ) -> torch.Tensor:
     r"""
     This function creates a mask tensor for Flash Dynamic Mask Attention.
@@ -73,6 +102,7 @@
         key_len (int): The sequence length of the key.
         window_size (int): The number of top elements to consider for the attention mask.
         min_dtype (float): The minimum value to use for masking.
+        block_size (Optional[int]): Optional size of aggregation blocks after top-k masking.
 
     Returns:
         attention (Tensor): The attention mask tensor of shape
@@ -103,6 +133,12 @@
     )
 
     # Generate dynamic mask based on attention_bias and attention_mask
-    attention_mask = dynamic_mask(attention_bias, attention_mask, window_size, min_dtype)
+    attention_mask = dynamic_mask(
+        attention_bias,
+        attention_mask,
+        window_size,
+        min_dtype,
+        block_size=block_size,
+    )
 
     return attention_mask
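
For illustration, the block-wise majority vote added in the @@ -46,6 +54,26 @@
hunk can be exercised on its own. The sketch below is not part of the patch or
the module: the helper name smooth_mask_blocks and the example shapes are
assumptions chosen for demonstration, but the voting rule mirrors the added
code, in that a block of block_size keys stays selected only when more than
half of its entries survived the top-k step, the remainder tail is voted on
separately, and the mask is updated in place.

import torch

def smooth_mask_blocks(mask: torch.Tensor, block_size: int) -> torch.Tensor:
    # Illustrative helper (hypothetical name): majority vote over blocks of
    # `block_size` keys along the last dimension, modifying `mask` in place.
    key_len = mask.shape[-1]
    full_len = (key_len // block_size) * block_size

    if full_len:
        # Splitting only the last dimension keeps this a view on `mask`,
        # so copy_ writes the vote result back into the original tensor.
        blocks = mask[..., :full_len].view(*mask.shape[:-1], full_len // block_size, block_size)
        keep = blocks.sum(dim=-1).to(torch.int32) * 2 > block_size
        blocks.copy_(keep.unsqueeze(-1).expand_as(blocks))

    if key_len > full_len:
        # The remainder tail (fewer than block_size keys) gets its own vote.
        tail = mask[..., full_len:]
        keep = tail.sum(dim=-1, keepdim=True).to(torch.int32) * 2 > tail.shape[-1]
        tail.copy_(keep.expand_as(tail))

    return mask

# Example: fragmented top-k picks become one clean block after smoothing.
mask = torch.zeros(1, 2, 4, 10, dtype=torch.bool)
mask[..., [0, 1, 2, 7]] = True            # keys 0, 1, 2, 7 selected per query
smooth_mask_blocks(mask, block_size=4)    # keys 0-3 kept; key 7 and the tail cleared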