Conversation

@LoserCheems
Collaborator

Clarify the shape descriptions for attention_mask and attention_bias in the documentation. Eliminate unnecessary dimension checks in the attention forward function to streamline the code.

Contributor

Copilot AI left a comment

Pull Request Overview

This PR improves the documentation and code for flash dynamic mask attention by clarifying tensor shape descriptions and removing redundant dimension handling logic.

  • Clarifies the shape documentation for the attention_mask and attention_bias parameters to be more explicit about the supported dimensions (see the sketch after this list)
  • Removes unnecessary dimension compatibility checks in the attention forward function
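
For illustration, here is a minimal docstring sketch of the kind of shape clarification described above. The function name and the dimension labels (batch_size, num_kv_heads, query_len, key_len) are assumptions made for this sketch, not text taken from the repository.

def flash_dynamic_mask_attention_forward(
    query_states,           # (batch_size, num_heads, query_len, head_dim), assumed layout
    key_states,             # (batch_size, num_kv_heads, key_len, head_dim), assumed layout
    value_states,           # (batch_size, num_kv_heads, key_len, head_dim), assumed layout
    attention_mask=None,
    attention_bias=None,
):
    """
    attention_mask (torch.Tensor, optional):
        Boolean mask where True marks positions that may be attended to, of shape
        (batch_size, num_kv_heads, query_len, key_len) or a shape broadcastable to it.
    attention_bias (torch.Tensor, optional):
        Additive bias applied to the attention scores before softmax, of shape
        (batch_size, num_kv_heads, query_len, key_len).
    """
    ...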

Reviewed Changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 1 comment.

  • flash_dynamic_mask_attention.py: Updates parameter documentation to clarify supported tensor shapes
  • modeling_flash_dynamic_mask_attention_utils.py: Removes redundant dimension checking and reshaping logic

    and key_length > window_size
):
    min_dtype = torch.finfo(query_states.dtype).min
    if attention_mask is not None:

Copilot AI commented on Oct 13, 2025

After removing the dimension compatibility checks, this code will fail if attention_mask and attention_bias have incompatible shapes for broadcasting in the masked_fill operation. The removed checks were handling dimension mismatches between 3D and 4D tensors.

Suggested change

    if attention_mask is not None:
        # Ensure attention_mask is broadcastable to attention_bias
        if attention_mask.shape != attention_bias.shape:
            # Try to unsqueeze or expand attention_mask to match attention_bias.
            # Common case: attention_mask is (batch, seq_len) and attention_bias is
            # (batch, num_heads, seq_len, seq_len), or attention_mask is (batch, 1, seq_len, seq_len).
            # We'll try to expand attention_mask to match attention_bias.
            try:
                attention_mask = attention_mask.expand_as(attention_bias)
            except RuntimeError:
                # Try unsqueeze and expand for the common 3D -> 4D case
                while attention_mask.dim() < attention_bias.dim():
                    attention_mask = attention_mask.unsqueeze(1)
                attention_mask = attention_mask.expand_as(attention_bias)
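
To make the broadcasting concern concrete, here is a small self-contained sketch (all shapes invented purely for illustration) of how a 3D mask fails to broadcast against a 4D bias in masked_fill until a singleton head dimension is inserted:

import torch

# Invented shapes for illustration only.
batch, num_heads, q_len, k_len = 2, 4, 8, 8
attention_bias = torch.zeros(batch, num_heads, q_len, k_len)
attention_mask = torch.ones(batch, q_len, k_len, dtype=torch.bool)  # 3D mask, no head dimension

min_dtype = torch.finfo(attention_bias.dtype).min
try:
    # Fails: (batch, q_len, k_len) cannot broadcast against (batch, num_heads, q_len, k_len).
    attention_bias.masked_fill(~attention_mask, min_dtype)
except RuntimeError as err:
    print(f"broadcast failure: {err}")

# Adding a singleton head dimension makes the mask broadcastable.
attention_bias = attention_bias.masked_fill(~attention_mask.unsqueeze(1), min_dtype)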

@LoserCheems merged commit 3d91162 into main on Oct 13, 2025
1 check passed