-
Notifications
You must be signed in to change notification settings - Fork 39
Add block-wise smoothing to attention mask #201
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -22,6 +22,7 @@ def dynamic_mask( | |||||||||||||||||||||||||||||||||||||
| attention_mask: Optional[torch.Tensor], | ||||||||||||||||||||||||||||||||||||||
| window_size: int, | ||||||||||||||||||||||||||||||||||||||
| min_dtype: float, | ||||||||||||||||||||||||||||||||||||||
| block_size: Optional[int] = None, | ||||||||||||||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||||||||||||||
| r""" | ||||||||||||||||||||||||||||||||||||||
| This function generates a dynamic mask based on the top-k attention bias. | ||||||||||||||||||||||||||||||||||||||
|
|
@@ -33,11 +34,18 @@ def dynamic_mask( | |||||||||||||||||||||||||||||||||||||
| ({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, key_len). | ||||||||||||||||||||||||||||||||||||||
| window_size (int): The number of top elements to consider for the mask. | ||||||||||||||||||||||||||||||||||||||
| min_dtype (float): The minimum value to use for masking. | ||||||||||||||||||||||||||||||||||||||
| block_size (Optional[int]): Optional size of aggregation blocks to smooth the | ||||||||||||||||||||||||||||||||||||||
| resulting mask along the key dimension. | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| Returns: | ||||||||||||||||||||||||||||||||||||||
| attention_mask (Tensor): The attention mask tensor of shape | ||||||||||||||||||||||||||||||||||||||
| ({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, key_len). | ||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||
| if block_size is not None: | ||||||||||||||||||||||||||||||||||||||
| if int(block_size) != block_size or block_size <= 0: | ||||||||||||||||||||||||||||||||||||||
| raise ValueError(f"block_size must be a positive integer, got {block_size}.") | ||||||||||||||||||||||||||||||||||||||
| block_size = int(block_size) | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| attention_bias = attention_bias.masked_fill(~attention_mask, min_dtype) if attention_mask is not None else attention_bias | ||||||||||||||||||||||||||||||||||||||
| topk_values, topk_indices = torch.topk( | ||||||||||||||||||||||||||||||||||||||
| attention_bias.detach(), | ||||||||||||||||||||||||||||||||||||||
|
|
@@ -46,6 +54,26 @@ def dynamic_mask( | |||||||||||||||||||||||||||||||||||||
| attention_mask = torch.zeros_like( | ||||||||||||||||||||||||||||||||||||||
| attention_bias, dtype=torch.bool, device=attention_bias.device | ||||||||||||||||||||||||||||||||||||||
| ).scatter_(-1, topk_indices, topk_values != min_dtype) | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| if block_size is not None and block_size > 1: | ||||||||||||||||||||||||||||||||||||||
| key_len = attention_mask.shape[-1] | ||||||||||||||||||||||||||||||||||||||
| full_len = (key_len // block_size) * block_size | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| if full_len: | ||||||||||||||||||||||||||||||||||||||
| block_view = attention_mask[..., :full_len] | ||||||||||||||||||||||||||||||||||||||
| block_shape = (*block_view.shape[:-1], full_len // block_size, block_size) | ||||||||||||||||||||||||||||||||||||||
| blocks = block_view.view(*block_shape) | ||||||||||||||||||||||||||||||||||||||
| block_counts = blocks.sum(dim=-1).to(torch.int32) | ||||||||||||||||||||||||||||||||||||||
| block_keep = (block_counts * 2) > block_size | ||||||||||||||||||||||||||||||||||||||
| blocks.copy_(block_keep.unsqueeze(-1).expand_as(blocks)) | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| if key_len > full_len: | ||||||||||||||||||||||||||||||||||||||
| tail_slice = attention_mask[..., full_len:] | ||||||||||||||||||||||||||||||||||||||
| tail_len = tail_slice.shape[-1] | ||||||||||||||||||||||||||||||||||||||
| tail_counts = tail_slice.sum(dim=-1, keepdim=True).to(torch.int32) | ||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+66
to
+73
|
||||||||||||||||||||||||||||||||||||||
| block_counts = blocks.sum(dim=-1).to(torch.int32) | |
| block_keep = (block_counts * 2) > block_size | |
| blocks.copy_(block_keep.unsqueeze(-1).expand_as(blocks)) | |
| if key_len > full_len: | |
| tail_slice = attention_mask[..., full_len:] | |
| tail_len = tail_slice.shape[-1] | |
| tail_counts = tail_slice.sum(dim=-1, keepdim=True).to(torch.int32) | |
| block_counts = blocks.sum(dim=-1).to(torch.int64) | |
| block_keep = (block_counts * 2) > block_size | |
| blocks.copy_(block_keep.unsqueeze(-1).expand_as(blocks)) | |
| if key_len > full_len: | |
| tail_slice = attention_mask[..., full_len:] | |
| tail_len = tail_slice.shape[-1] | |
| tail_counts = tail_slice.sum(dim=-1, keepdim=True).to(torch.int64) |
Copilot
AI
Oct 30, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same as line 66, converting to torch.int32 may cause precision loss. Consider using torch.int64 for consistency and to handle edge cases with large tensors safely.
| tail_counts = tail_slice.sum(dim=-1, keepdim=True).to(torch.int32) | |
| tail_counts = tail_slice.sum(dim=-1, keepdim=True).to(torch.int64) |
Copilot
AI
Oct 30, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same as line 67, this condition (tail_counts * 2) > tail_len uses a magic number pattern. Consider adding a comment or using a named constant to clarify this represents the majority voting threshold (>50%).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The condition
(block_counts * 2) > block_sizeis a magic number pattern that obscures the majority voting logic. Consider extracting this as a named constant or adding an inline comment to clarify this represents 'keep block if more than 50% of elements are True'.