From 42e118d49cce554f278560b28597138421f737b5 Mon Sep 17 00:00:00 2001
From: LoserCheems <3314685395@qq.com>
Date: Mon, 13 Oct 2025 15:58:44 +0800
Subject: [PATCH 1/2] Fix attention_mask and attention_bias shape descriptions
 in docstring

---
 flash_dmattn/integrations/flash_dynamic_mask_attention.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/flash_dmattn/integrations/flash_dynamic_mask_attention.py b/flash_dmattn/integrations/flash_dynamic_mask_attention.py
index b583a29..f842ae6 100644
--- a/flash_dmattn/integrations/flash_dynamic_mask_attention.py
+++ b/flash_dmattn/integrations/flash_dynamic_mask_attention.py
@@ -30,10 +30,10 @@ def flash_dynamic_mask_attention_forward(
         query (torch.Tensor): The query tensor of shape (batch_size, num_heads, query_len, head_dim).
         key (torch.Tensor): The key tensor of shape (batch_size, num_kv_heads, key_len, head_dim).
         value (torch.Tensor): The value tensor of shape (batch_size, num_kv_heads, key_len, head_dim).
-        attention_mask (Optional[torch.Tensor]): The attention mask boolean tensor of shape
-            (batch_size, seq_len) or (batch_size, {num_heads|num_kv_heads|1}, {query_len|0}, key_len).
+        attention_mask (Optional[torch.Tensor]): The attention mask boolean tensor of shape
+            (batch_size, seq_len) or ({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, {key_len|1}).
         attention_bias (Optional[torch.Tensor]): The attention bias float tensor of shape
-            (batch_size, {num_heads|num_kv_heads|1}, {query_len|0}, key_len).
+            ({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, {key_len|1}).
         scaling (Optional[float]): The scaling factor for the attention scores.
         window_size (Optional[int]): The size of the window to keep.
         softcap (Optional[float]): The softcap value for the attention scores.

From 623b75d096af916cc8515906efbc33e672310987 Mon Sep 17 00:00:00 2001
From: LoserCheems <3314685395@qq.com>
Date: Mon, 13 Oct 2025 16:11:13 +0800
Subject: [PATCH 2/2] Remove redundant dimension checks for attention_mask and
 attention_bias in _flash_dynamic_mask_attention_forward

---
 .../modeling_flash_dynamic_mask_attention_utils.py | 5 -----
 1 file changed, 5 deletions(-)

diff --git a/flash_dmattn/integrations/modeling_flash_dynamic_mask_attention_utils.py b/flash_dmattn/integrations/modeling_flash_dynamic_mask_attention_utils.py
index 538acd0..b21a69d 100644
--- a/flash_dmattn/integrations/modeling_flash_dynamic_mask_attention_utils.py
+++ b/flash_dmattn/integrations/modeling_flash_dynamic_mask_attention_utils.py
@@ -626,11 +626,6 @@ def _flash_dynamic_mask_attention_forward(
     ):
         min_dtype = torch.finfo(query_states.dtype).min
         if attention_mask is not None:
-            if attention_mask.dim() == 4 and attention_bias.dim() == 3:
-                attention_bias = attention_bias.unsqueeze(-2).expand(-1, -1, query_length, -1)
-            if attention_mask.dim() == 3 and attention_bias.dim() == 4:
-                attention_mask = attention_mask.unsqueeze(-2).expand(-1, -1, query_length, -1)
-
             topk_values, topk_indices = torch.topk(
                 attention_bias.masked_fill(~attention_mask, min_dtype).detach(),
                 window_size, dim=-1, largest=True, sorted=False
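
Note (not part of the patch): the minimal sketch below, with made-up tensor sizes, illustrates why the removed unsqueeze/expand branches are redundant once attention_mask and attention_bias are both 4-D with broadcastable singleton dimensions, as the updated docstring in the first commit describes. torch.Tensor.masked_fill and torch.topk broadcast the mask against the bias directly, so no explicit expansion to query_length is needed.

import torch

# Hypothetical sizes for illustration only.
batch_size, num_heads, query_len, key_len, window_size = 2, 4, 8, 16, 4

# The bias varies per head and per query; the boolean mask only varies over
# keys, so it keeps singleton head and query dimensions ({num_heads|1}, {query_len|1}).
attention_bias = torch.randn(batch_size, num_heads, query_len, key_len)
attention_mask = torch.ones(batch_size, 1, 1, key_len, dtype=torch.bool)
attention_mask[..., key_len // 2:] = False  # mask out the second half of the keys

min_dtype = torch.finfo(attention_bias.dtype).min

# masked_fill broadcasts the (batch, 1, 1, key_len) mask against the
# (batch, heads, query, key_len) bias, so the former expand step is unnecessary.
topk_values, topk_indices = torch.topk(
    attention_bias.masked_fill(~attention_mask, min_dtype).detach(),
    window_size, dim=-1, largest=True, sorted=False,
)
print(topk_values.shape)  # torch.Size([2, 4, 8, 4])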