Skip to content

Conversation

@LoserCheems
Copy link
Collaborator

Ensure the attention mask correctly processes topk values by validating them before applying the mask, preventing incorrect masking of valid attention scores.

Ensures that attention mask correctly handles cases where topk values
equal the minimum dtype value by checking validity before scattering.

Previously scattered a constant 1.0 for all topk indices regardless
of their validity, which could incorrectly mask valid attention scores.
Copilot AI review requested due to automatic review settings August 29, 2025 08:23
@LoserCheems LoserCheems merged commit 61c2924 into main Aug 29, 2025
Copy link
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull Request Overview

Fixes attention mask handling to properly validate topk values before applying them to the mask, preventing incorrect masking of valid attention scores when topk values equal the minimum dtype value.

  • Extracts both topk values and indices from torch.topk instead of just indices
  • Validates topk values against min_dtype to identify valid attention scores
  • Uses the validation mask when scattering values to prevent masking valid scores

Tip: Customize your code reviews with copilot-instructions.md. Create the file or learn how to get started.

Comment on lines 360 to +361
attn_mask = torch.zeros_like(attn_bias, dtype=dtype, device=attn_bias.device)
attn_mask = attn_mask.scatter(-1, topk_indices, 1.0)
attn_mask = attn_mask.scatter(-1, topk_indices, valid_topk.to(dtype))
Copy link

Copilot AI Aug 29, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The scatter operation may produce incorrect results when valid_topk contains False values. Scattering 0.0 values (from valid_topk.to(dtype) where valid_topk is False) will explicitly set those positions to 0.0, but the attn_mask is already initialized with zeros. This could mask positions that should remain unmasked if there are fewer than keep_window_size valid values. Consider only scattering indices where valid_topk is True.

Copilot uses AI. Check for mistakes.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants