@@ -1,4 +1,4 @@
-# Copyright 2025 Jingze Shi and the HuggingFace Inc. team. All rights reserved.
+# Copyright 2025 Jingze Shi and Liangdong Wang and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -17,7 +17,6 @@
from .import_utils import is_flash_dmattn_available

from transformers.utils import logging
-from transformers.integrations import flash_attention



Copilot AI Oct 4, 2025

[nitpick] The removed import statement leaves behind an empty line. Consider removing the blank line to keep the spacing consistent.

logger = logging.get_logger(__name__)
@@ -26,7 +25,10 @@
def fdma_peft_integration_check(q, k, v, bias, target_dtype: Optional[torch.dtype] = None):
    if target_dtype and q.dtype == torch.float32:
        logger.warning_once(f"Casting fp32 inputs back to {target_dtype} for flash-dmattn compatibility.")
-        q, k, v, bias = q.to(target_dtype), k.to(target_dtype), v.to(target_dtype), bias.to(target_dtype)
+        q = q.to(target_dtype) if q is not None else None
+        k = k.to(target_dtype) if k is not None else None
+        v = v.to(target_dtype) if v is not None else None
+        bias = bias.to(target_dtype) if bias is not None else None
Comment on lines 26 to +31

Copilot AI Oct 4, 2025

[nitpick] The individual None checks can be simplified. Consider using a helper function or list comprehension to reduce code duplication and improve readability.

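A minimal sketch of the kind of helper this comment suggests; the `_maybe_cast` name and the generator-based call are illustrative only and not part of this PR:

```python
# Hypothetical helper, sketched to illustrate the suggestion above (not in the PR).
def _maybe_cast(t, target_dtype):
    # Cast only when the tensor is present; propagate None untouched.
    return t.to(target_dtype) if t is not None else None

# Possible call site inside fdma_peft_integration_check:
# q, k, v, bias = (_maybe_cast(t, target_dtype) for t in (q, k, v, bias))
```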
    return q, k, v, bias


@@ -66,7 +68,6 @@ def _flash_dynamic_mask_attention_forward(
):
    dtype = query_states.dtype
    min_dtype = torch.finfo(dtype).min
-    batch_size, _, num_kv_heads, _ = key_states.shape

    if not all(k in globals() for k in ("_flash_fn")):
        flash_fn = _lazy_imports(implementation)
@@ -85,22 +86,34 @@
        query_states, key_states, value_states, attention_bias, target_dtype
    )

-    if attention_mask is not None and attention_mask.dim() == 4:
-        if attention_bias.dim() == 3:
-            attention_bias = attention_bias.unsqueeze(-2)
-        attention_bias = attention_bias.masked_fill(
-            ~attention_mask,
-            min_dtype
-        )
-
-    if keep_window_size is not None and key_length > keep_window_size:
-        topk_values, topk_indices = torch.topk(
-            attention_bias, keep_window_size, dim=-1, largest=True, sorted=False
-        )
-        attention_mask = torch.zeros_like(attention_bias, dtype=torch.bool, device=attention_bias.device)
-        attention_mask = attention_mask.scatter(-1, topk_indices, topk_values != min_dtype)
-    else:
-        attention_mask = None
+    if (
+        attention_bias is not None
+        and keep_window_size is not None
+        and key_length > keep_window_size
+    ):
+        if attention_mask is not None:
+            if attention_mask.dim() == 4 and attention_bias.dim() == 3:
+                attention_bias_for_topk = attention_bias.unsqueeze(-2).expand_as(attention_mask)
+            else:
+                attention_bias_for_topk = attention_bias
+
+            topk_indices = torch.topk(
+                attention_bias_for_topk.masked_fill(~attention_mask, min_dtype).detach(),
+                keep_window_size,
+                dim=-1, largest=True, sorted=False,
+            ).indices
+            attention_mask = torch.zeros_like(attention_bias_for_topk, dtype=torch.bool).scatter_(
+                -1, topk_indices, True
+            ) & attention_mask
+        else:
+            topk_indices = torch.topk(
+                attention_bias.detach(),
+                keep_window_size,
+                dim=-1, largest=True, sorted=False,
+            ).indices
+            attention_mask = torch.zeros_like(attention_bias, dtype=torch.bool).scatter_(
+                -1, topk_indices, True
+            )

    out = flash_fn(
        query_states, key_states, value_states, attn_mask=attention_mask, attn_bias=attention_bias, scale=softmax_scale, is_causal=is_causal
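For reference, a self-contained sketch of the top-k masking the new block performs; the shapes, padding pattern, and `keep_window_size` value below are made up for illustration and are not taken from the PR:

```python
import torch

# Toy shapes: batch=1, heads=1, query_len=2, key_len=6.
keep_window_size = 3
min_dtype = torch.finfo(torch.float32).min
attention_bias = torch.randn(1, 1, 2, 6)
attention_mask = torch.ones(1, 1, 2, 6, dtype=torch.bool)
attention_mask[..., -1] = False  # pretend the last key position is padding

# Mirror of the masked branch: push invalid keys to the dtype minimum, take the
# top-k keys per query by bias, keep only those positions, and AND with the mask.
topk_indices = torch.topk(
    attention_bias.masked_fill(~attention_mask, min_dtype).detach(),
    keep_window_size, dim=-1, largest=True, sorted=False,
).indices
new_mask = torch.zeros_like(attention_bias, dtype=torch.bool).scatter_(
    -1, topk_indices, True
) & attention_mask

print(new_mask.sum(dim=-1))  # each query row keeps at most keep_window_size keys
```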