Refactor attention block smoothing for consistency #205
Conversation
Introduces a shared block smoothing helper so top-k and ReLU masks keep consistent voting behavior while avoiding duplicated logic. Stops coercing attention scores to float, keeping native dtypes when ranking.
Pull Request Overview
This PR refactors duplicate block smoothing logic in mask generation functions by extracting it into a reusable block_smooth function. The refactoring reduces code duplication while maintaining identical functionality.
Key changes:
- Introduced a new `block_smooth` helper function to encapsulate block-based mask smoothing logic (a rough sketch of such a helper follows this list)
- Replaced duplicate block smoothing code in both `topk_mask` and `relu_mask` with calls to the new helper
- Changed the dtype conversion for block counts from `torch.int32` to `torch.int64`
- Removed the redundant `block_size` validation and casting that existed in both functions
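For reference, here is a minimal sketch of what a majority-vote block smoothing helper along these lines might look like. The function body, reshaping strategy, and error handling are assumptions for illustration and not the PR's actual code:

```python
import torch

def block_smooth(attention_mask: torch.Tensor, key_len: int, block_size: int) -> torch.Tensor:
    # Hypothetical sketch: smooth a boolean mask of shape (..., key_len) by majority vote
    # over non-overlapping blocks of `block_size` along the key dimension.
    if not isinstance(block_size, int) or block_size <= 0:
        raise ValueError("block_size must be a positive integer")

    num_full = key_len // block_size
    full_len = num_full * block_size
    smoothed = attention_mask.clone()

    if num_full > 0:
        # Count True votes inside each full block and keep the block iff a strict majority is True.
        blocks = attention_mask[..., :full_len].reshape(*attention_mask.shape[:-1], num_full, block_size)
        votes = blocks.sum(dim=-1, dtype=torch.int64)  # per-block True counts
        keep = votes * 2 > block_size                  # strict majority vote
        smoothed[..., :full_len] = keep.unsqueeze(-1).expand_as(blocks).reshape(
            *attention_mask.shape[:-1], full_len
        )

    tail = key_len - full_len
    if tail > 0:
        # Vote over the ragged tail using its actual length.
        tail_votes = attention_mask[..., full_len:key_len].sum(dim=-1, dtype=torch.int64)
        smoothed[..., full_len:key_len] = (tail_votes * 2 > tail).unsqueeze(-1)
    return smoothed
```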
```python
    attention_mask: torch.Tensor,
    key_len: int,
    block_size: int,
):
```
Copilot AI · Nov 6, 2025
The newly introduced block_smooth function is missing a docstring. Add documentation explaining its purpose, parameters, and return value to maintain consistency with other functions in the module like topk_mask and relu_mask.
Suggested change:

```python
):
    """
    Smooths the attention mask by aggregating values in non-overlapping blocks along the key dimension.
    For each block of size `block_size`, the block is set to True if more than half of its elements are True;
    otherwise it is set to False. The tail (if `key_len` is not divisible by `block_size`) is handled similarly.

    Args:
        attention_mask (torch.Tensor): Boolean tensor of shape (..., key_len) representing the attention mask.
        key_len (int): The length of the key dimension.
        block_size (int): The size of each block for smoothing. Must be a positive integer.

    Returns:
        torch.Tensor: The smoothed attention mask tensor of the same shape as the input.
    """
```
```python
    key_len: int,
    block_size: int,
):
    if block_size <= 0:
```
Copilot AI · Nov 6, 2025
The validation only checks if block_size <= 0, but doesn't verify that block_size is an integer. The previous code included int(block_size) != block_size checks in both topk_mask and relu_mask. Consider adding integer validation here: if not isinstance(block_size, int) or block_size <= 0: to maintain the same input validation behavior.
Suggested change:

```diff
-    if block_size <= 0:
+    if not isinstance(block_size, int) or block_size <= 0:
```
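To see why the stricter check matters: a float such as 2.5 slips past the sign check but would break the downstream block arithmetic. A hypothetical guard along the suggested lines:

```python
block_size = 2.5  # hypothetical bad input: passes `block_size <= 0` but is not a valid block size
if not isinstance(block_size, int) or block_size <= 0:
    raise ValueError(f"block_size must be a positive integer, got {block_size!r}")
```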
Copilot AI · Nov 6, 2025
There is trailing whitespace on line 47. Remove the extra whitespace after the return statement.
```diff
 attention_bias = attention_bias.masked_fill(~attention_mask, min_dtype) if attention_mask is not None else attention_bias
 topk_values, topk_indices = torch.topk(
-    attention_bias.to(torch.float),
+    attention_bias,
```
Copilot AI · Nov 6, 2025
The .to(torch.float) conversion was removed from the torch.topk call. While this may be intentional to preserve the original dtype, it changes the existing behavior. If attention_bias is not already torch.float, this could affect numerical precision in the topk operation. Verify this is the intended behavior or document why the dtype conversion was removed.
Suggested change:

```diff
-    attention_bias,
+    attention_bias.to(torch.float),
```
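One way to sanity-check the behavior change (not part of the PR) is to compare the indices `torch.topk` selects with and without the float cast on a low-precision dtype:

```python
import torch

# Illustrative check, not from the PR: do native-dtype and float32 top-k agree on bfloat16 scores?
scores = torch.randn(4, 128).to(torch.bfloat16)
_, idx_native = torch.topk(scores, k=16, dim=-1)                 # rank in the native dtype
_, idx_float = torch.topk(scores.to(torch.float), k=16, dim=-1)  # rank after upcasting
print("index sets match:", torch.equal(idx_native, idx_float))
```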
Summary
Root Cause
Changes
Reproduction
Tests
Compatibility
Checklist