Add block-wise smoothing to attention mask #201
Conversation
Introduces an optional block_size to aggregate top-k selections along the key dimension using a majority vote, reducing fragmentation and encouraging locality in the dynamic mask. Validates block_size as a positive integer, handles remainder tails, and forwards the parameter through mask creation. Updates docs accordingly. Preserves previous behavior when unset and uses in-place ops for efficiency.
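A minimal sketch of the behavior described above, assuming a boolean top-k mask whose last dimension is the key dimension. The helper name smooth_blockwise is illustrative only; the real implementation lives inside the repository's mask-creation code and operates in place, whereas this version returns a new tensor:

```python
import torch

def smooth_blockwise(attention_mask: torch.Tensor, block_size: int) -> torch.Tensor:
    """Sketch of block-wise majority-vote smoothing of a boolean top-k mask."""
    if not isinstance(block_size, int) or block_size <= 0:
        raise ValueError("block_size must be a positive integer")
    key_len = attention_mask.shape[-1]
    full_len = (key_len // block_size) * block_size
    pieces = []
    if full_len > 0:
        # Group the first full_len key positions into blocks of block_size.
        blocks = attention_mask[..., :full_len].reshape(
            *attention_mask.shape[:-1], -1, block_size
        )
        # Majority vote: keep a block iff more than half of its entries are selected.
        keep = (blocks.sum(dim=-1, keepdim=True) * 2) > block_size
        pieces.append(keep.expand_as(blocks).reshape(*attention_mask.shape[:-1], full_len))
    if key_len > full_len:
        # The remainder tail, shorter than one block, gets its own vote.
        tail = attention_mask[..., full_len:]
        keep_tail = (tail.sum(dim=-1, keepdim=True) * 2) > tail.shape[-1]
        pieces.append(keep_tail.expand_as(tail))
    return torch.cat(pieces, dim=-1)
```

For example, with block_size=3 a block of [True, False, True] is kept while [True, False, False] is dropped, and a single-element tail of [True] survives its own vote:

```python
mask = torch.tensor([[True, False, True, True, False, False, True]])
print(smooth_blockwise(mask, block_size=3))
# -> tensor([[ True,  True,  True, False, False, False,  True]])
```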
Pull Request Overview
This PR adds an optional block_size parameter to enable block-wise aggregation and smoothing of attention masks after top-k selection. The feature applies a majority voting mechanism where blocks are kept if more than half of their elements are selected.
- Added block_size parameter to dynamic_mask() and create_mask() functions
- Implemented block-wise smoothing logic with majority voting for both full blocks and tail elements
- Added input validation for block_size to ensure it's a positive integer (sketched below)
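The validation itself is not quoted anywhere on this page, so the following is only a sketch of the kind of check the last bullet describes; the helper name and error message are assumptions, not code from this PR:

```python
def _validate_block_size(block_size) -> None:
    # Hypothetical helper; the actual check in dynamic_mask()/create_mask() may differ.
    if block_size is None:
        return  # unset: fall back to the previous, un-smoothed behavior
    # bool is a subclass of int in Python, so exclude it explicitly.
    if not isinstance(block_size, int) or isinstance(block_size, bool) or block_size <= 0:
        raise ValueError(f"block_size must be a positive integer, got {block_size!r}")
```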
block_shape = (*block_view.shape[:-1], full_len // block_size, block_size)
blocks = block_view.view(*block_shape)
block_counts = blocks.sum(dim=-1).to(torch.int32)
block_keep = (block_counts * 2) > block_size
Copilot AI commented on Oct 30, 2025
The condition (block_counts * 2) > block_size is a magic number pattern that obscures the majority voting logic. Consider extracting this as a named constant or adding an inline comment to clarify this represents 'keep block if more than 50% of elements are True'.
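One way to apply the comment-based variant of this suggestion is to annotate the quoted lines in place; the comment wording below is an illustration, not the maintainers':

```python
# Majority vote: keep the block when strictly more than half of its block_size
# elements were selected by top-k, i.e. count > block_size / 2, written as
# count * 2 > block_size to stay in integer arithmetic.
block_counts = blocks.sum(dim=-1).to(torch.int32)
block_keep = (block_counts * 2) > block_size
```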
tail_slice = attention_mask[..., full_len:]
tail_len = tail_slice.shape[-1]
tail_counts = tail_slice.sum(dim=-1, keepdim=True).to(torch.int32)
tail_keep = (tail_counts * 2) > tail_len
Copilot AI commented on Oct 30, 2025
Same as line 67, this condition (tail_counts * 2) > tail_len uses a magic number pattern. Consider adding a comment or using a named constant to clarify this represents the majority voting threshold (>50%).
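Since the identical test appears for both full blocks and the tail, the two comments could also be addressed at once with a small named helper; the name below is hypothetical, not code from this PR:

```python
import torch

def _majority_keep(counts: torch.Tensor, group_len: int) -> torch.Tensor:
    """True where strictly more than half of `group_len` elements are selected."""
    return (counts * 2) > group_len

# Intended call sites, matching the hunks quoted on this page:
#   block_keep = _majority_keep(block_counts, block_size)
#   tail_keep  = _majority_keep(tail_counts, tail_len)
```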
block_counts = blocks.sum(dim=-1).to(torch.int32)
block_keep = (block_counts * 2) > block_size
blocks.copy_(block_keep.unsqueeze(-1).expand_as(blocks))

if key_len > full_len:
    tail_slice = attention_mask[..., full_len:]
    tail_len = tail_slice.shape[-1]
    tail_counts = tail_slice.sum(dim=-1, keepdim=True).to(torch.int32)
Copilot AI commented on Oct 30, 2025
Converting to torch.int32 may cause unnecessary precision loss or overflow for large block sizes. Consider using torch.int64 for safer integer arithmetic, or document why int32 is sufficient for the expected use cases.
Suggested change:
-block_counts = blocks.sum(dim=-1).to(torch.int32)
+block_counts = blocks.sum(dim=-1).to(torch.int64)
 block_keep = (block_counts * 2) > block_size
 blocks.copy_(block_keep.unsqueeze(-1).expand_as(blocks))
 if key_len > full_len:
     tail_slice = attention_mask[..., full_len:]
     tail_len = tail_slice.shape[-1]
-    tail_counts = tail_slice.sum(dim=-1, keepdim=True).to(torch.int32)
+    tail_counts = tail_slice.sum(dim=-1, keepdim=True).to(torch.int64)
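For context on the overflow concern: each summed value is a count of boolean selections within one block (or the tail), so it is bounded by block_size (or tail_len). The check below is purely illustrative, with an assumed block_size value; whether to adopt the int64 suggestion or document why int32 suffices is a maintainer decision:

```python
import torch

block_size = 4096  # illustrative value, not taken from the PR
# Worst case: every position in every block is selected.
blocks = torch.ones(2, 8, block_size, dtype=torch.bool)
counts = blocks.sum(dim=-1)  # summing bools promotes to int64 by default
assert counts.max().item() == block_size
# The count can never exceed block_size, so int32 only overflows for
# block sizes beyond 2**31 - 1.
assert block_size <= torch.iinfo(torch.int32).max
```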
if key_len > full_len:
    tail_slice = attention_mask[..., full_len:]
    tail_len = tail_slice.shape[-1]
    tail_counts = tail_slice.sum(dim=-1, keepdim=True).to(torch.int32)
Copilot AI commented on Oct 30, 2025
Same as line 66, converting to torch.int32 may cause precision loss. Consider using torch.int64 for consistency and to handle edge cases with large tensors safely.
Suggested change:
-tail_counts = tail_slice.sum(dim=-1, keepdim=True).to(torch.int32)
+tail_counts = tail_slice.sum(dim=-1, keepdim=True).to(torch.int64)
Introduce an optional block_size parameter to enhance attention mask generation by aggregating top-k selections, reducing fragmentation, and improving locality. Validate block_size as a positive integer and ensure previous behavior remains intact when unset. Update documentation accordingly.