diff --git a/flash_dmattn/integrations/modeling_flash_dynamic_mask_attention_utils.py b/flash_dmattn/integrations/modeling_flash_dynamic_mask_attention_utils.py
index b21a69d..c2638b8 100644
--- a/flash_dmattn/integrations/modeling_flash_dynamic_mask_attention_utils.py
+++ b/flash_dmattn/integrations/modeling_flash_dynamic_mask_attention_utils.py
@@ -27,10 +27,11 @@
 # `globals()` is not compatible with dynamo, hence we have do define them in global scope ourselves
-_flash_fn = None
-_flash_varlen_fn = None
+_fdma_fn = None
+_fdma_varlen_fn = None
 _pad_fn = None
 _unpad_fn = None
+_create_mask_fn = None
 # function that processes kwargs, generalized to handle any supported kwarg within the function
 _process_flash_kwargs_fn = None

@@ -53,13 +54,12 @@ def _lazy_imports(implementation: Optional[str]):
     """
     is_fdma = is_flash_dmattn_available()

-    pad_input, unpad_input = _pad_input, _unpad_input
-
     if (implementation == "flash_dmattn" and is_fdma) or (implementation is None and is_fdma):
         from flash_dmattn import flash_dmattn_func, flash_dmattn_varlen_func
         from flash_dmattn.utils.padding import pad_input, unpad_input
+        from flash_dmattn.utils.mask import create_mask

-    return flash_dmattn_func, flash_dmattn_varlen_func, pad_input, unpad_input
+    return flash_dmattn_func, flash_dmattn_varlen_func, pad_input, unpad_input, create_mask


 def _lazy_define_process_function(flash_function):
@@ -90,15 +90,15 @@ def lazy_import_flash_dynamic_mask_attention(implementation: Optional[str], forc
     NOTE: For fullgraph, this needs to be called before compile, while no fullgraph can work without preloading.
     See `load_and_register_kernel` in `integrations.hub_kernels`.
     """
-    global _flash_fn, _flash_varlen_fn, _pad_fn, _unpad_fn
-    if force_import or any(k is None for k in [_flash_fn, _flash_varlen_fn, _pad_fn, _unpad_fn]):
-        _flash_fn, _flash_varlen_fn, _pad_fn, _unpad_fn = _lazy_imports(implementation)
-
+    global _fdma_fn, _fdma_varlen_fn, _pad_fn, _unpad_fn, _create_mask_fn
+    if force_import or any(k is None for k in [_fdma_fn, _fdma_varlen_fn, _pad_fn, _unpad_fn, _create_mask_fn]):
+        _fdma_fn, _fdma_varlen_fn, _pad_fn, _unpad_fn, _create_mask_fn = _lazy_imports(implementation)
+
     global _process_flash_kwargs_fn
     if force_import or _process_flash_kwargs_fn is None:
-        _process_flash_kwargs_fn = _lazy_define_process_function(_flash_varlen_fn)
+        _process_flash_kwargs_fn = _lazy_define_process_function(_fdma_varlen_fn)

-    return (_flash_fn, _flash_varlen_fn, _pad_fn, _unpad_fn), _process_flash_kwargs_fn
+    return (_fdma_fn, _fdma_varlen_fn, _pad_fn, _unpad_fn, _create_mask_fn), _process_flash_kwargs_fn


 def _index_first_axis(tensor, indices):
@@ -113,57 +113,6 @@ def _index_first_axis(tensor, indices):
     return reshaped_tensor[indices]


-def _unpad_input(hidden_states, attention_mask, unused_mask=None):
-    """
-    unpad_input function for flash attention variants that do not have them within their pkg themselves, e.g. fa3.
-
-    Arguments:
-        hidden_states: (batch, seqlen, ...)
-        attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
-        unused_mask: (batch, seqlen), bool / int, 1 means the element is allocated but unused.
-
-    Return:
-        hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask + unused_mask.
-        indices: (total_nnz), the indices of masked tokens from the flattened input sequence.
-        cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
-        max_seqlen_in_batch: int
-        seqused: (batch), returns the number of tokens selected in attention_mask + unused_mask.
-    """
-    all_masks = (attention_mask + unused_mask) if unused_mask is not None else attention_mask
-    seqlens_in_batch = all_masks.sum(dim=-1, dtype=torch.int32)
-    used_seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
-    indices = torch.nonzero(all_masks.flatten(), as_tuple=False).flatten()
-    max_seqlen_in_batch = seqlens_in_batch.max().item()
-    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
-
-    return (
-        _index_first_axis(hidden_states, indices),
-        indices,
-        cu_seqlens,
-        max_seqlen_in_batch,
-        used_seqlens_in_batch,
-    )
-
-
-def _pad_input(hidden_states, indices, batch, seqlen):
-    """
-    pad_input function for flash attention variants that do not have them within their pkg themselves, e.g. fa3.
-
-    Arguments:
-        hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask.
-        indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence.
-        batch: int, batch size for the padded sequence.
-        seqlen: int, maximum sequence length for the padded sequence.
-
-    Return:
-        hidden_states: (batch, seqlen, ...)
-    """
-    dim = hidden_states.shape[1:]
-    output = torch.zeros((batch * seqlen), *dim, device=hidden_states.device, dtype=hidden_states.dtype)
-    output[indices] = hidden_states
-    return output.view(batch, seqlen, *dim)
-
-
 def _get_unpad_data(attention_mask: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, int]:
     """
     Retrieves indexing data required to repad unpadded (ragged) tensors.
@@ -527,7 +476,7 @@ def _flash_dynamic_mask_attention_forward(
             "If shape of attention_mask is (batch_size, seq_len), attention_bias has to be None."
         )

-    (flash_fn, flash_varlen_fn, pad_fn, unpad_fn), process_flash_kwargs_fn = lazy_import_flash_dynamic_mask_attention(implementation)
+    (fdma_fn, fdma_varlen_fn, pad_fn, unpad_fn, create_mask_fn), process_flash_kwargs_fn = lazy_import_flash_dynamic_mask_attention(implementation)

     # PEFT possibly silently casts tensors to fp32, this potentially reconverts to correct dtype or is a no op
     query_states, key_states, value_states, attention_bias = fdma_peft_integration_check(
@@ -546,10 +495,10 @@
         **kwargs,
     )

-    # We will use `flash_varlen_fn` to prevent cross-example attention and also allow padding free approach under two cases:
+    # We will use `fdma_varlen_fn` to prevent cross-example attention and also allow padding free approach under two cases:
     # Case 1. If position ids is provided and the position ids indicate packed sequences, see `_is_packed_sequence`.
     # Case 2. Some models pass directly pre-computed `cu_seqlens` so we don't need to infer it from position ids. It is safe to
-    # use `flash_varlen_fn` knowing we already have all necessary the kwargs.
+    # use `fdma_varlen_fn` knowing we already have all the necessary kwargs.
     #
     # NOTE: it is user's responsibility to take care of flattening `position_ids` if that's needed by the model.
     # See #39121 for more information.
@@ -569,7 +518,7 @@ def _flash_dynamic_mask_attention_forward(
         if "mps" in str(q.device):
             cu_seq_lens_k = cu_seq_lens_k.clone()

-        out_unpad = flash_varlen_fn(
+        out_unpad = fdma_varlen_fn(
             q,
             k,
             v,
@@ -600,7 +549,7 @@ def _flash_dynamic_mask_attention_forward(
         if "mps" in str(q.device):
             cu_seq_lens_k = cu_seq_lens_k.clone()

-        out = flash_varlen_fn(
+        out = fdma_varlen_fn(
             q,
             k,
             v,
@@ -624,24 +573,17 @@ def _flash_dynamic_mask_attention_forward(
             and window_size is not None
             and key_length > window_size
         ):
-            min_dtype = torch.finfo(query_states.dtype).min
-            if attention_mask is not None:
-                topk_values, topk_indices = torch.topk(
-                    attention_bias.masked_fill(~attention_mask, min_dtype).detach(),
-                    window_size, dim=-1, largest=True, sorted=False
-                )
-                attention_mask = torch.zeros_like(
-                    attention_bias, dtype=torch.bool, device=attention_bias.device
-                ).scatter_(-1, topk_indices, topk_values != min_dtype)
-            else:
-                topk_values, topk_indices = torch.topk(
-                    attention_bias.detach(), window_size, dim=-1, largest=True, sorted=False
-                )
-                attention_mask = torch.zeros_like(
-                    attention_bias, dtype=torch.bool, device=attention_bias.device
-                ).scatter_(-1, topk_indices, topk_values != min_dtype)
+            attention_mask = create_mask_fn(
+                attention_bias,
+                attention_mask,
+                batch_size=query_states.size(0),
+                query_len=query_length,
+                key_len=key_length,
+                window_size=window_size,
+                min_dtype=torch.finfo(attention_bias.dtype).min,
+            )

-        out = flash_fn(
+        out = fdma_fn(
             query_states,
             key_states,
             value_states,
diff --git a/flash_dmattn/utils/mask.py b/flash_dmattn/utils/mask.py
new file mode 100644
index 0000000..69f6eb7
--- /dev/null
+++ b/flash_dmattn/utils/mask.py
@@ -0,0 +1,108 @@
+# Copyright 2025 Jingze Shi and Liangdong Wang. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional
+
+import torch
+
+
+def dynamic_mask(
+    attention_bias: torch.Tensor,
+    attention_mask: Optional[torch.Tensor],
+    window_size: int,
+    min_dtype: float,
+):
+    r"""
+    This function generates a dynamic mask based on the top-k attention bias.
+
+    Args:
+        attention_bias (torch.Tensor): The attention bias tensor of shape
+            ({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, key_len).
+        attention_mask (Optional[torch.Tensor]): The attention mask boolean tensor of shape
+            ({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, key_len).
+        window_size (int): The number of top elements to consider for the mask.
+        min_dtype (float): The minimum value to use for masking.
+
+    Returns:
+        attention_mask (Tensor): The attention mask tensor of shape
+            ({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, key_len).
+    """
+    attention_bias = attention_bias.masked_fill(~attention_mask, min_dtype) if attention_mask is not None else attention_bias
+    topk_values, topk_indices = torch.topk(
+        attention_bias.detach(),
+        window_size, dim=-1, largest=True, sorted=False
+    )
+    attention_mask = torch.zeros_like(
+        attention_bias, dtype=torch.bool, device=attention_bias.device
+    ).scatter_(-1, topk_indices, topk_values != min_dtype)
+    return attention_mask
+
+
+def create_mask(
+    attention_bias: torch.Tensor,
+    attention_mask: Optional[torch.Tensor],
+    batch_size: int,
+    query_len: int,
+    key_len: int,
+    window_size: int,
+    min_dtype: float,
+) -> torch.Tensor:
+    r"""
+    This function creates a mask tensor for Flash Dynamic Mask Attention.
+
+    If attention_mask is not of shape (batch_size, seq_len), it needs to match the shape of attention_bias.
+
+    Args:
+        attention_bias (torch.Tensor): The attention bias tensor of shape
+            ({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, key_len).
+        attention_mask (Optional[torch.Tensor]): The attention mask boolean tensor of shape
+            (batch_size, seq_len) or ({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, key_len).
+        batch_size (int): The batch size.
+        query_len (int): The sequence length of the query.
+        key_len (int): The sequence length of the key.
+        window_size (int): The number of top elements to consider for the attention mask.
+        min_dtype (float): The minimum value to use for masking.
+
+    Returns:
+        attention_mask (Tensor): The attention mask tensor of shape
+            ({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, key_len).
+    """
+
+    # If attention_mask is of shape (batch_size, seq_len), reshape it to (batch_size, 1, 1, key_len)
+    if attention_mask is not None and attention_mask.dim() == 2:
+        if attention_mask.shape[-1] == key_len:
+            attention_mask = attention_mask.view(batch_size, 1, 1, key_len)
+        elif attention_mask.shape[-1] == query_len:
+            pad_len = key_len - query_len
+            if pad_len > 0:
+                pad_mask = torch.ones(
+                    (batch_size, 1, 1, pad_len),
+                    dtype=torch.bool,
+                    device=attention_mask.device,
+                )
+                attention_mask = torch.cat(
+                    [pad_mask, attention_mask.view(batch_size, 1, 1, query_len)],
+                    dim=-1,
+                )
+            else:
+                attention_mask = attention_mask.view(batch_size, 1, 1, query_len)
+        else:
+            raise ValueError(
+                f"attention_mask shape {attention_mask.shape} is not compatible with key_len {key_len} or query_len {query_len}."
+            )
+
+    # Generate dynamic mask based on attention_bias and attention_mask
+    attention_mask = dynamic_mask(attention_bias, attention_mask, window_size, min_dtype)
+
+    return attention_mask
diff --git a/flash_dmattn/utils/padding.py b/flash_dmattn/utils/padding.py
index 27350f4..b675af7 100644
--- a/flash_dmattn/utils/padding.py
+++ b/flash_dmattn/utils/padding.py
@@ -167,4 +167,4 @@ def upad_input(
         indices_q,
         (cu_seqlens_q, cu_seqlens_k),
         (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
-    )
\ No newline at end of file
+    )
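For reviewers: below is a minimal usage sketch of the new `create_mask` helper introduced in `flash_dmattn/utils/mask.py`. The call signature follows the function added in this diff; the tensor shapes, the random bias, and the fake padding mask are illustrative assumptions only.

```python
# Minimal sketch: build a top-k dynamic attention mask from an attention bias
# plus an optional (batch_size, key_len) padding mask, using the new helper.
import torch

from flash_dmattn.utils.mask import create_mask

batch_size, num_kv_heads, query_len, key_len, window_size = 2, 4, 8, 16, 4

# Illustrative inputs: a random attention bias and a padding mask that marks
# the last three key positions as padding.
attention_bias = torch.randn(batch_size, num_kv_heads, query_len, key_len)
attention_mask = torch.ones(batch_size, key_len, dtype=torch.bool)
attention_mask[:, -3:] = False

# create_mask broadcasts the 2D padding mask to (batch_size, 1, 1, key_len) and
# keeps the `window_size` highest-bias keys per query; padded keys are filled
# with min_dtype first, so they are not selected while enough valid keys remain.
dynamic_attention_mask = create_mask(
    attention_bias,
    attention_mask,
    batch_size=batch_size,
    query_len=query_len,
    key_len=key_len,
    window_size=window_size,
    min_dtype=torch.finfo(attention_bias.dtype).min,
)

print(dynamic_attention_mask.shape)                 # torch.Size([2, 4, 8, 16])
print(dynamic_attention_mask.sum(dim=-1).unique())  # tensor([4]) -> window_size keys kept per query
```

This is the same helper that `_flash_dynamic_mask_attention_forward` now invokes as `create_mask_fn`, replacing the previously inlined top-k masking branch.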