Conversation

@LoserCheems
Collaborator

Introduce an optional block_size parameter to enhance attention mask generation by aggregating top-k selections, reducing fragmentation, and improving locality. Validate block_size as a positive integer and ensure previous behavior remains intact when unset. Update documentation accordingly.

Introduces an optional block_size to aggregate top-k selections along the key dimension using a majority vote, reducing fragmentation and encouraging locality in the dynamic mask.

Validates block_size as a positive integer, handles remainder tails, and forwards the parameter through mask creation. Updates docs accordingly.

Preserves previous behavior when unset and uses in-place ops for efficiency.
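
To illustrate the idea, here is a minimal standalone sketch of the majority vote on a toy mask (the tensor and `block_size` value are made up for demonstration; this is not the PR's actual code):

```python
import torch

# Toy top-k mask over a key dimension of length 8, with block_size = 4.
# Block 1 has 3/4 elements selected (majority) -> kept whole;
# block 2 has 1/4 selected (minority)          -> dropped whole.
mask = torch.tensor([True, True, True, False, False, True, False, False])
block_size = 4

blocks = mask.view(-1, block_size)        # (num_blocks, block_size)
counts = blocks.sum(dim=-1)               # selected elements per block
keep = (counts * 2) > block_size          # majority vote: > 50% selected
smoothed = keep.unsqueeze(-1).expand_as(blocks).reshape(-1)
print(smoothed)
# tensor([ True,  True,  True,  True, False, False, False, False])
```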
Contributor

Copilot AI left a comment

Pull Request Overview

This PR adds an optional block_size parameter to enable block-wise aggregation and smoothing of attention masks after top-k selection. The feature applies a majority voting mechanism where blocks are kept if more than half of their elements are selected.

  • Added block_size parameter to dynamic_mask() and create_mask() functions
  • Implemented block-wise smoothing logic with majority voting for both full blocks and tail elements
  • Added input validation for block_size to ensure it's a positive integer
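
Taken together, a minimal self-contained sketch of that flow might look like the following (`smooth_mask` and its exact signature are hypothetical, shown only to make the described behavior concrete):

```python
import torch

def smooth_mask(attention_mask: torch.Tensor, block_size: int) -> torch.Tensor:
    """Hypothetical sketch: in-place majority-vote smoothing of a boolean
    top-k mask along the last (key) dimension."""
    if not isinstance(block_size, int) or block_size <= 0:
        raise ValueError("block_size must be a positive integer")

    key_len = attention_mask.shape[-1]
    full_len = (key_len // block_size) * block_size  # span covered by full blocks

    if full_len > 0:
        # Split the fully covered span into blocks and vote per block.
        blocks = attention_mask[..., :full_len].view(
            *attention_mask.shape[:-1], full_len // block_size, block_size
        )
        block_keep = (blocks.sum(dim=-1) * 2) > block_size  # > 50% selected
        blocks.copy_(block_keep.unsqueeze(-1).expand_as(blocks))

    if key_len > full_len:
        # Remainder tail shorter than one block: vote against its own length.
        tail = attention_mask[..., full_len:]
        tail_keep = (tail.sum(dim=-1, keepdim=True) * 2) > tail.shape[-1]
        tail.copy_(tail_keep.expand_as(tail))

    return attention_mask
```

When `block_size` is left unset, this smoothing step is simply skipped, which is how the PR preserves the previous behavior.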

```python
block_shape = (*block_view.shape[:-1], full_len // block_size, block_size)
blocks = block_view.view(*block_shape)
block_counts = blocks.sum(dim=-1).to(torch.int32)
block_keep = (block_counts * 2) > block_size
```

Copilot AI Oct 30, 2025

The condition (block_counts * 2) > block_size is a magic number pattern that obscures the majority voting logic. Consider extracting this as a named constant or adding an inline comment to clarify this represents 'keep block if more than 50% of elements are True'.
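
One way the suggested refactor might look (a sketch, not the merged code; `majority_keep` is a name invented here):

```python
import torch

def majority_keep(counts: torch.Tensor, size: int) -> torch.Tensor:
    # Keep when strictly more than 50% of elements are selected;
    # 2 * counts > size keeps the comparison in integer arithmetic.
    return (counts * 2) > size

# Reads as intended at both call sites in the snippets above:
#   block_keep = majority_keep(block_counts, block_size)
#   tail_keep  = majority_keep(tail_counts, tail_len)
print(majority_keep(torch.tensor([3, 1]), 4))  # tensor([ True, False])
```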

```python
tail_slice = attention_mask[..., full_len:]
tail_len = tail_slice.shape[-1]
tail_counts = tail_slice.sum(dim=-1, keepdim=True).to(torch.int32)
tail_keep = (tail_counts * 2) > tail_len
```

Copilot AI Oct 30, 2025

Same as line 67, this condition (tail_counts * 2) > tail_len uses a magic number pattern. Consider adding a comment or using a named constant to clarify this represents the majority voting threshold (>50%).

Comment on lines +66 to +73
```python
block_counts = blocks.sum(dim=-1).to(torch.int32)
block_keep = (block_counts * 2) > block_size
blocks.copy_(block_keep.unsqueeze(-1).expand_as(blocks))

if key_len > full_len:
    tail_slice = attention_mask[..., full_len:]
    tail_len = tail_slice.shape[-1]
    tail_counts = tail_slice.sum(dim=-1, keepdim=True).to(torch.int32)
```

Copilot AI Oct 30, 2025

Converting to torch.int32 may cause unnecessary precision loss or overflow for large block sizes. Consider using torch.int64 for safer integer arithmetic, or document why int32 is sufficient for the expected use cases.

Suggested change

```diff
-block_counts = blocks.sum(dim=-1).to(torch.int32)
+block_counts = blocks.sum(dim=-1).to(torch.int64)
 block_keep = (block_counts * 2) > block_size
 blocks.copy_(block_keep.unsqueeze(-1).expand_as(blocks))
 if key_len > full_len:
     tail_slice = attention_mask[..., full_len:]
     tail_len = tail_slice.shape[-1]
-    tail_counts = tail_slice.sum(dim=-1, keepdim=True).to(torch.int32)
+    tail_counts = tail_slice.sum(dim=-1, keepdim=True).to(torch.int64)
```
```python
if key_len > full_len:
    tail_slice = attention_mask[..., full_len:]
    tail_len = tail_slice.shape[-1]
    tail_counts = tail_slice.sum(dim=-1, keepdim=True).to(torch.int32)
```

Copilot AI Oct 30, 2025

Same as line 66, converting to torch.int32 may cause precision loss. Consider using torch.int64 for consistency and to handle edge cases with large tensors safely.

Suggested change

```diff
-tail_counts = tail_slice.sum(dim=-1, keepdim=True).to(torch.int32)
+tail_counts = tail_slice.sum(dim=-1, keepdim=True).to(torch.int64)
```
LoserCheems merged commit d883adc into main on Oct 30, 2025
9 of 10 checks passed