From 4417b164d967415adc84e65dbd42c2ac5ea89e3e Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Thu, 23 Oct 2025 21:08:52 +0800 Subject: [PATCH 1/6] Add missing trailing newline to flash_dmattn/utils/padding.py --- flash_dmattn/utils/padding.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flash_dmattn/utils/padding.py b/flash_dmattn/utils/padding.py index 27350f4..b675af7 100644 --- a/flash_dmattn/utils/padding.py +++ b/flash_dmattn/utils/padding.py @@ -167,4 +167,4 @@ def upad_input( indices_q, (cu_seqlens_q, cu_seqlens_k), (max_seqlen_in_batch_q, max_seqlen_in_batch_k), - ) \ No newline at end of file + ) From a06bff1c1d6140b4db1cc762cfaab660ef444c71 Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Thu, 23 Oct 2025 21:11:32 +0800 Subject: [PATCH 2/6] =?UTF-8?q?Adds=20dynamic=20top=E2=80=91k=20attention?= =?UTF-8?q?=20mask=20utilities?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Introduces utilities to build boolean masks for Flash Dynamic Mask Attention by selecting top‑k positions from an attention bias, improving sparsity and compute efficiency. Handles 2D mask reshaping and padding to align query/key lengths, respects existing masks, and excludes invalid positions via a configurable minimum value. --- flash_dmattn/utils/mask.py | 108 +++++++++++++++++++++++++++++++++++++ 1 file changed, 108 insertions(+) create mode 100644 flash_dmattn/utils/mask.py diff --git a/flash_dmattn/utils/mask.py b/flash_dmattn/utils/mask.py new file mode 100644 index 0000000..4978c59 --- /dev/null +++ b/flash_dmattn/utils/mask.py @@ -0,0 +1,108 @@ +# Copyright 2025 Jingze Shi and Liangdong Wang. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +import torch + + +def dynamic_mask( + attention_bias: torch.Tensor, + attention_mask: Optional[torch.Tensor], + window_size: int, + min_dtype: float, +): + r""" + This function generates a dynamic mask based on the top-k attention bias. + + Args: + attention_bias (torch.Tensor): The attention bias tensor of shape + ({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, key_len). + attention_mask (Optional[torch.Tensor]): The attention mask boolean tensor of shape + ({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, key_len). + window_size (int): The number of top elements to consider for the mask. + min_dtype (float): The minimum value to use for masking. + + Returns: + attention_mask (Tensor): The attention mask tensor of shape + ({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, key_len). 
+ """ + attention_bias = attention_bias.masked_fill(~attention_mask, min_dtype) if attention_mask is not None else attention_bias + topk_values, topk_indices = torch.topk( + attention_bias.detach(), + window_size, dim=-1, largest=True, sorted=False + ) + attention_mask = torch.zeros_like( + attention_bias, dtype=torch.bool, device=attention_bias.device + ).scatter_(-1, topk_indices, topk_values != min_dtype) + return attention_mask + + +def create_mask( + attention_bias: torch.Tensor, + attention_mask: Optional[torch.Tensor], + batch_size: int, + query_len: int, + key_len: int, + window_size: int, + min_dtype: float, +) -> torch.Tensor: + r""" + This function creates a mask tensor for Flash Dynamic Mask Attention. + + If attention_mask is not of shape (batch_size, seq_len), it needs to match the shape of attention_bias. + + Args: + Args: + attention_bias (torch.Tensor): The attention bias tensor of shape + ({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, {key_len|1}). + attention_mask (Optional[torch.Tensor]): The attention mask boolean tensor of shape + (batch_size, seq_len) or ({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, {key_len|1}). + batch_size (int): The batch size. + query_len (int): The sequence length of the query. + key_len (int): The sequence length of the key. + window_size (int): The number of top elements to consider for the attention mask. + min_dtype (float): The minimum value to use for masking. + + Returns: + attention (Tensor): The attention mask tensor of shape + ({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, {key_len|1}). + """ + + # If attention_mask is of shape (batch_size, seq_len), reshape it to (batch_size, 1, 1, key_len) + if attention_mask is not None and attention_mask.dim() == 2: + if attention_mask.shape[-1] == key_len: + attention_mask = attention_mask.view(batch_size, 1, 1, key_len) + elif attention_mask.shape[-1] == query_len: + pad_len = key_len - query_len + if pad_len > 0: + pad_mask = torch.ones( + (batch_size, 1, 1, pad_len), + dtype=torch.bool, + device=attention_mask.device, + ) + attention_mask = torch.cat( + [pad_mask, attention_mask.view(batch_size, 1, 1, query_len)], + dim=-1, + ) + else: + attention_mask = attention_mask.view(batch_size, 1, 1, query_len) + else: + raise ValueError( + f"attention_mask shape {attention_mask.shape} is not compatible with key_len {key_len} or query_len {query_len}." + ) + + attention_mask = dynamic_mask(attention_bias, attention_mask, window_size, min_dtype) + + return attention_mask From 510ef4d217459892fe3bc27f87171fd2dcfe8cc8 Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Thu, 23 Oct 2025 21:13:15 +0800 Subject: [PATCH 3/6] Refactors FDMA utils and centralizes mask creation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Renames internal attention callsites to FDMA-prefixed names for clarity and consistency. Adds lazy import and wiring for a mask creation utility and uses it to build sliding‑window masks instead of ad‑hoc top‑k logic, improving readability and numerical correctness by using attention bias dtype for min. Removes local pad/unpad fallbacks in favor of package implementations. Updates lazy loader return signature and processing hook accordingly. 
--- ...ling_flash_dynamic_mask_attention_utils.py | 110 +++++------------- 1 file changed, 26 insertions(+), 84 deletions(-) diff --git a/flash_dmattn/integrations/modeling_flash_dynamic_mask_attention_utils.py b/flash_dmattn/integrations/modeling_flash_dynamic_mask_attention_utils.py index b21a69d..c2638b8 100644 --- a/flash_dmattn/integrations/modeling_flash_dynamic_mask_attention_utils.py +++ b/flash_dmattn/integrations/modeling_flash_dynamic_mask_attention_utils.py @@ -27,10 +27,11 @@ # `globals()` is not compatible with dynamo, hence we have do define them in global scope ourselves -_flash_fn = None -_flash_varlen_fn = None +_fdma_fn = None +_fdma_varlen_fn = None _pad_fn = None _unpad_fn = None +_create_mask_fn = None # function that processes kwargs, generalized to handle any supported kwarg within the function _process_flash_kwargs_fn = None @@ -53,13 +54,12 @@ def _lazy_imports(implementation: Optional[str]): """ is_fdma = is_flash_dmattn_available() - pad_input, unpad_input = _pad_input, _unpad_input - if (implementation == "flash_dmattn" and is_fdma) or (implementation is None and is_fdma): from flash_dmattn import flash_dmattn_func, flash_dmattn_varlen_func from flash_dmattn.utils.padding import pad_input, unpad_input + from flash_dmattn.utils.mask import create_mask - return flash_dmattn_func, flash_dmattn_varlen_func, pad_input, unpad_input + return flash_dmattn_func, flash_dmattn_varlen_func, pad_input, unpad_input, create_mask def _lazy_define_process_function(flash_function): @@ -90,15 +90,15 @@ def lazy_import_flash_dynamic_mask_attention(implementation: Optional[str], forc NOTE: For fullgraph, this needs to be called before compile, while no fullgraph can work without preloading. See `load_and_register_kernel` in `integrations.hub_kernels`. """ - global _flash_fn, _flash_varlen_fn, _pad_fn, _unpad_fn - if force_import or any(k is None for k in [_flash_fn, _flash_varlen_fn, _pad_fn, _unpad_fn]): - _flash_fn, _flash_varlen_fn, _pad_fn, _unpad_fn = _lazy_imports(implementation) - + global _fdma_fn, _fdma_varlen_fn, _pad_fn, _unpad_fn, _create_mask_fn + if force_import or any(k is None for k in [_fdma_fn, _fdma_varlen_fn, _pad_fn, _unpad_fn, _create_mask_fn]): + _fdma_fn, _fdma_varlen_fn, _pad_fn, _unpad_fn, _create_mask_fn = _lazy_imports(implementation) + global _process_flash_kwargs_fn if force_import or _process_flash_kwargs_fn is None: - _process_flash_kwargs_fn = _lazy_define_process_function(_flash_varlen_fn) + _process_flash_kwargs_fn = _lazy_define_process_function(_fdma_varlen_fn) - return (_flash_fn, _flash_varlen_fn, _pad_fn, _unpad_fn), _process_flash_kwargs_fn + return (_fdma_fn, _fdma_varlen_fn, _pad_fn, _unpad_fn, _create_mask_fn), _process_flash_kwargs_fn def _index_first_axis(tensor, indices): @@ -113,57 +113,6 @@ def _index_first_axis(tensor, indices): return reshaped_tensor[indices] -def _unpad_input(hidden_states, attention_mask, unused_mask=None): - """ - unpad_input function for flash attention variants that do not have them within their pkg themselves, e.g. fa3. - - Arguments: - hidden_states: (batch, seqlen, ...) - attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid. - unused_mask: (batch, seqlen), bool / int, 1 means the element is allocated but unused. - - Return: - hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask + unused_mask. - indices: (total_nnz), the indices of masked tokens from the flattened input sequence. 
- cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states. - max_seqlen_in_batch: int - seqused: (batch), returns the number of tokens selected in attention_mask + unused_mask. - """ - all_masks = (attention_mask + unused_mask) if unused_mask is not None else attention_mask - seqlens_in_batch = all_masks.sum(dim=-1, dtype=torch.int32) - used_seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) - indices = torch.nonzero(all_masks.flatten(), as_tuple=False).flatten() - max_seqlen_in_batch = seqlens_in_batch.max().item() - cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) - - return ( - _index_first_axis(hidden_states, indices), - indices, - cu_seqlens, - max_seqlen_in_batch, - used_seqlens_in_batch, - ) - - -def _pad_input(hidden_states, indices, batch, seqlen): - """ - pad_input function for flash attention variants that do not have them within their pkg themselves, e.g. fa3. - - Arguments: - hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask. - indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence. - batch: int, batch size for the padded sequence. - seqlen: int, maximum sequence length for the padded sequence. - - Return: - hidden_states: (batch, seqlen, ...) - """ - dim = hidden_states.shape[1:] - output = torch.zeros((batch * seqlen), *dim, device=hidden_states.device, dtype=hidden_states.dtype) - output[indices] = hidden_states - return output.view(batch, seqlen, *dim) - - def _get_unpad_data(attention_mask: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, int]: """ Retrieves indexing data required to repad unpadded (ragged) tensors. @@ -527,7 +476,7 @@ def _flash_dynamic_mask_attention_forward( "If shape of attention_mask is (batch_size, seq_len), attention_bias has to be None." ) - (flash_fn, flash_varlen_fn, pad_fn, unpad_fn), process_flash_kwargs_fn = lazy_import_flash_dynamic_mask_attention(implementation) + (fdma_fn, fdma_varlen_fn, pad_fn, unpad_fn, create_mask_fn), process_flash_kwargs_fn = lazy_import_flash_dynamic_mask_attention(implementation) # PEFT possibly silently casts tensors to fp32, this potentially reconverts to correct dtype or is a no op query_states, key_states, value_states, attention_bias = fdma_peft_integration_check( @@ -546,10 +495,10 @@ def _flash_dynamic_mask_attention_forward( **kwargs, ) - # We will use `flash_varlen_fn` to prevent cross-example attention and also allow padding free approach under two cases: + # We will use `fdma_varlen_fn` to prevent cross-example attention and also allow padding free approach under two cases: # Case 1. If position ids is provided and the position ids indicate packed sequences, see `_is_packed_sequence`. # Case 2. Some models pass directly pre-computed `cu_seqlens` so we don't need to infer it from position ids. It is safe to - # use `flash_varlen_fn` knowing we already have all necessary the kwargs. + # use `fdma_varlen_fn` knowing we already have all necessary the kwargs. # # NOTE: it is user's responsibility to take care of flattening `position_ids` if that's needed by the model. # See #39121 for more information. 
@@ -569,7 +518,7 @@ def _flash_dynamic_mask_attention_forward( if "mps" in str(q.device): cu_seq_lens_k = cu_seq_lens_k.clone() - out_unpad = flash_varlen_fn( + out_unpad = fdma_varlen_fn( q, k, v, @@ -600,7 +549,7 @@ def _flash_dynamic_mask_attention_forward( if "mps" in str(q.device): cu_seq_lens_k = cu_seq_lens_k.clone() - out = flash_varlen_fn( + out = fdma_varlen_fn( q, k, v, @@ -624,24 +573,17 @@ def _flash_dynamic_mask_attention_forward( and window_size is not None and key_length > window_size ): - min_dtype = torch.finfo(query_states.dtype).min - if attention_mask is not None: - topk_values, topk_indices = torch.topk( - attention_bias.masked_fill(~attention_mask, min_dtype).detach(), - window_size, dim=-1, largest=True, sorted=False - ) - attention_mask = torch.zeros_like( - attention_bias, dtype=torch.bool, device=attention_bias.device - ).scatter_(-1, topk_indices, topk_values != min_dtype) - else: - topk_values, topk_indices = torch.topk( - attention_bias.detach(), window_size, dim=-1, largest=True, sorted=False - ) - attention_mask = torch.zeros_like( - attention_bias, dtype=torch.bool, device=attention_bias.device - ).scatter_(-1, topk_indices, topk_values != min_dtype) + attention_mask = create_mask_fn( + attention_bias, + attention_mask, + batch_size=query_states.size(0), + query_len=query_length, + key_len=key_length, + window_size=window_size, + min_dtype=torch.finfo(attention_bias.dtype).min, + ) - out = flash_fn( + out = fdma_fn( query_states, key_states, value_states, From 7d4cf23bddaf0569e83d2a25d37f42874d6ee6c3 Mon Sep 17 00:00:00 2001 From: Jingze Date: Thu, 23 Oct 2025 21:20:31 +0800 Subject: [PATCH 4/6] Update flash_dmattn/utils/mask.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- flash_dmattn/utils/mask.py | 1 - 1 file changed, 1 deletion(-) diff --git a/flash_dmattn/utils/mask.py b/flash_dmattn/utils/mask.py index 4978c59..c826b75 100644 --- a/flash_dmattn/utils/mask.py +++ b/flash_dmattn/utils/mask.py @@ -64,7 +64,6 @@ def create_mask( If attention_mask is not of shape (batch_size, seq_len), it needs to match the shape of attention_bias. Args: - Args: attention_bias (torch.Tensor): The attention bias tensor of shape ({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, {key_len|1}). attention_mask (Optional[torch.Tensor]): The attention mask boolean tensor of shape From a2b5309eaf6643fec1061ab58357c25336b2a9bc Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Thu, 23 Oct 2025 21:22:45 +0800 Subject: [PATCH 5/6] Normalize shape notation in create_mask docstrings (use key_len instead of {key_len|1}) --- flash_dmattn/utils/mask.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/flash_dmattn/utils/mask.py b/flash_dmattn/utils/mask.py index c826b75..b14e129 100644 --- a/flash_dmattn/utils/mask.py +++ b/flash_dmattn/utils/mask.py @@ -65,9 +65,9 @@ def create_mask( Args: attention_bias (torch.Tensor): The attention bias tensor of shape - ({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, {key_len|1}). - attention_mask (Optional[torch.Tensor]): The attention mask boolean tensor of shape - (batch_size, seq_len) or ({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, {key_len|1}). + ({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, key_len). + attention_mask (Optional[torch.Tensor]): The attention mask boolean tensor of shape + (batch_size, seq_len) or ({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, key_len). batch_size (int): The batch size. 
query_len (int): The sequence length of the query. key_len (int): The sequence length of the key. @@ -76,7 +76,7 @@ def create_mask( Returns: attention (Tensor): The attention mask tensor of shape - ({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, {key_len|1}). + ({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, key_len). """ # If attention_mask is of shape (batch_size, seq_len), reshape it to (batch_size, 1, 1, key_len) From 0dbd67397deb0615d33906fcb547db0d09d392ce Mon Sep 17 00:00:00 2001 From: LoserCheems Date: Thu, 23 Oct 2025 21:38:12 +0800 Subject: [PATCH 6/6] Add blank line for readability before dynamic mask generation in create_mask --- flash_dmattn/utils/mask.py | 1 + 1 file changed, 1 insertion(+) diff --git a/flash_dmattn/utils/mask.py b/flash_dmattn/utils/mask.py index b14e129..69f6eb7 100644 --- a/flash_dmattn/utils/mask.py +++ b/flash_dmattn/utils/mask.py @@ -102,6 +102,7 @@ def create_mask( f"attention_mask shape {attention_mask.shape} is not compatible with key_len {key_len} or query_len {query_len}." ) + # Generate dynamic mask based on attention_bias and attention_mask attention_mask = dynamic_mask(attention_bias, attention_mask, window_size, min_dtype) return attention_mask
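
[Editorial note] For review convenience, a small self-check sketch showing that the centralized create_mask path reproduces the inline top-k masking removed in PATCH 3 when no pre-existing mask is supplied. CPU execution, float32 bias, and the small sizes are assumptions for the sketch only; it is not part of the series.

import torch

from flash_dmattn.utils.mask import create_mask

torch.manual_seed(0)
bias = torch.randn(1, 2, 16, 16)
window_size = 4
min_dtype = torch.finfo(bias.dtype).min

# Inline logic removed by PATCH 3 (the attention_mask-is-None branch).
topk_values, topk_indices = torch.topk(
    bias.detach(), window_size, dim=-1, largest=True, sorted=False
)
old_mask = torch.zeros_like(bias, dtype=torch.bool).scatter_(
    -1, topk_indices, topk_values != min_dtype
)

# Centralized path added by PATCH 2 and wired in by PATCH 3.
new_mask = create_mask(
    bias,
    None,
    batch_size=1,
    query_len=16,
    key_len=16,
    window_size=window_size,
    min_dtype=min_dtype,
)

# Both paths select the same top-k positions per query row.
assert torch.equal(old_mask, new_mask)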