diff --git a/flash_dmattn/integrations/modeling_flash_dynamic_mask_attention_utils.py b/flash_dmattn/integrations/modeling_flash_dynamic_mask_attention_utils.py
index cae5b3e..a64c71d 100644
--- a/flash_dmattn/integrations/modeling_flash_dynamic_mask_attention_utils.py
+++ b/flash_dmattn/integrations/modeling_flash_dynamic_mask_attention_utils.py
@@ -1,4 +1,4 @@
-# Copyright 2025 Jingze Shi and the HuggingFace Inc. team. All rights reserved.
+# Copyright 2025 Jingze Shi and Liangdong Wang and the HuggingFace Inc. team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -17,7 +17,6 @@
 
 from .import_utils import is_flash_dmattn_available
 from transformers.utils import logging
-from transformers.integrations import flash_attention
 
 
 logger = logging.get_logger(__name__)
@@ -26,7 +25,10 @@ def fdma_peft_integration_check(q, k, v, bias, target_dtype: Optional[torch.dtype] = None):
 
     if target_dtype and q.dtype == torch.float32:
         logger.warning_once(f"Casting fp32 inputs back to {target_dtype} for flash-dmattn compatibility.")
-        q, k, v, bias = q.to(target_dtype), k.to(target_dtype), v.to(target_dtype), bias.to(target_dtype)
+        q = q.to(target_dtype) if q is not None else None
+        k = k.to(target_dtype) if k is not None else None
+        v = v.to(target_dtype) if v is not None else None
+        bias = bias.to(target_dtype) if bias is not None else None
     return q, k, v, bias
 
 
@@ -66,7 +68,6 @@ def _flash_dynamic_mask_attention_forward(
 ):
     dtype = query_states.dtype
     min_dtype = torch.finfo(dtype).min
-    batch_size, _, num_kv_heads, _ = key_states.shape
 
     if not all(k in globals() for k in ("_flash_fn")):
         flash_fn = _lazy_imports(implementation)
@@ -85,22 +86,34 @@ def _flash_dynamic_mask_attention_forward(
         query_states, key_states, value_states, attention_bias, target_dtype
     )
 
-    if attention_mask is not None and attention_mask.dim() == 4:
-        if attention_bias.dim() == 3:
-            attention_bias = attention_bias.unsqueeze(-2)
-        attention_bias = attention_bias.masked_fill(
-            ~attention_mask,
-            min_dtype
-        )
-
-    if keep_window_size is not None and key_length > keep_window_size:
-        topk_values, topk_indices = torch.topk(
-            attention_bias, keep_window_size, dim=-1, largest=True, sorted=False
-        )
-        attention_mask = torch.zeros_like(attention_bias, dtype=torch.bool, device=attention_bias.device)
-        attention_mask = attention_mask.scatter(-1, topk_indices, topk_values != min_dtype)
-    else:
-        attention_mask = None
+    if (
+        attention_bias is not None
+        and keep_window_size is not None
+        and key_length > keep_window_size
+    ):
+        if attention_mask is not None:
+            if attention_mask.dim() == 4 and attention_bias.dim() == 3:
+                attention_bias_for_topk = attention_bias.unsqueeze(-2).expand_as(attention_mask)
+            else:
+                attention_bias_for_topk = attention_bias
+
+            topk_indices = torch.topk(
+                attention_bias_for_topk.masked_fill(~attention_mask, min_dtype).detach(),
+                keep_window_size,
+                dim=-1, largest=True, sorted=False,
+            ).indices
+            attention_mask = torch.zeros_like(attention_bias_for_topk, dtype=torch.bool).scatter_(
+                -1, topk_indices, True
+            ) & attention_mask
+        else:
+            topk_indices = torch.topk(
+                attention_bias.detach(),
+                keep_window_size,
+                dim=-1, largest=True, sorted=False,
+            ).indices
+            attention_mask = torch.zeros_like(attention_bias, dtype=torch.bool).scatter_(
+                -1, topk_indices, True
+            )
 
     out = flash_fn(
         query_states, key_states, value_states, attn_mask=attention_mask, attn_bias=attention_bias, scale=softmax_scale, is_causal=is_causal
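For reviewers, a minimal standalone sketch of the masked top-k selection introduced in the last hunk. It assumes matching 4-D `(batch, num_heads, query_len, key_len)` shapes for the boolean `attention_mask` and the floating-point `attention_bias` (the patch additionally broadcasts a 3-D bias via `expand_as`); the helper name `build_topk_attention_mask` is illustrative and not part of the flash-dmattn API:

```python
import torch


def build_topk_attention_mask(attention_bias, attention_mask, keep_window_size):
    """Keep only the keep_window_size highest-bias keys per query position."""
    min_dtype = torch.finfo(attention_bias.dtype).min

    if attention_bias.size(-1) <= keep_window_size:
        # Nothing to prune: every key already fits inside the window.
        return attention_mask

    if attention_mask is not None:
        # Rank keys by bias, but push already-masked keys to the bottom so
        # they can never win a top-k slot.
        scores = attention_bias.masked_fill(~attention_mask, min_dtype).detach()
    else:
        scores = attention_bias.detach()

    topk_indices = torch.topk(
        scores, keep_window_size, dim=-1, largest=True, sorted=False
    ).indices
    topk_mask = torch.zeros_like(scores, dtype=torch.bool).scatter_(-1, topk_indices, True)

    # Intersect with the original mask so padded/causal positions stay masked.
    return topk_mask & attention_mask if attention_mask is not None else topk_mask


if __name__ == "__main__":
    batch, heads, q_len, k_len, window = 1, 2, 4, 16, 8
    bias = torch.randn(batch, heads, q_len, k_len)
    mask = torch.ones(batch, heads, q_len, k_len, dtype=torch.bool)
    mask[..., -3:] = False  # e.g. three padded keys at the end
    new_mask = build_topk_attention_mask(bias, mask, window)
    assert (new_mask.sum(-1) <= window).all()
    assert not (new_mask & ~mask).any()  # padded keys are never re-enabled
```

Unlike the previous version, the bias tensor itself is no longer overwritten with `min_dtype`; only the boolean mask is tightened, and the `.detach()` keeps the top-k selection out of the autograd graph.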