Replace attention_mask with cache_position for improved efficiency #140
Conversation
Simplifies causal masking implementation by using cache_position directly instead of constructing full attention masks. Removes complex attention mask handling logic that converted boolean masks and applied slicing operations. Introduces more efficient causal masking using torch.arange comparison against cache_position values. Updates function signatures and documentation across all attention implementations to reflect the parameter change. Improves memory efficiency by eliminating large intermediate mask tensors in favor of position-based calculations.
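The description above amounts to replacing a full `(batch, heads, q_len, kv_len)` boolean mask with a comparison of key positions against per-query cache positions. A minimal NumPy sketch of that idea (the PR itself uses `torch.arange` and `masked_fill`; the function and variable names here are illustrative, not taken from the diff):

```python
import numpy as np

def causal_bias_from_cache_position(attn_bias, cache_position, min_value):
    # Key positions 0..kv_len-1 along the last axis.
    key_positions = np.arange(attn_bias.shape[-1])
    # (q_len, 1) broadcast against (kv_len,) yields a (q_len, kv_len)
    # boolean mask on the fly: no large intermediate attention-mask
    # tensor is ever materialized.
    future = key_positions > cache_position.reshape(-1, 1)
    return np.where(future, min_value, attn_bias)

# Prefill example: 4 queries attending over 4 keys.
bias = causal_bias_from_cache_position(
    np.zeros((4, 4)), np.arange(4), float("-inf")
)
# Row i keeps keys 0..i and masks keys i+1.. with -inf.
```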
Pull Request Overview
This PR replaces attention mask handling with cache position for improved efficiency in dynamic mask attention implementations. The change simplifies causal masking by directly using cache position tensors instead of complex attention mask logic, enhancing memory efficiency and streamlining the codebase.
Key changes:
- Replace `attention_mask` parameter with `cache_position` in function signatures
- Simplify causal masking logic using cache position directly in `prepare_dynamic_mask`
- Update benchmark result naming conventions to use standardized abbreviations
Reviewed Changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| benchmarks/forward_performance.py | Updates function signatures and benchmark result keys to use cache_position and standardized naming |
| benchmarks/forward_equivalence.py | Replaces attention mask with cache position and removes complex mask creation logic |
| benchmarks/backward_performance.py | Updates backward pass benchmarks to use cache_position and consistent naming |
| benchmarks/backward_equivalence.py | Simplifies backward equivalence tests by using cache_position instead of attention masks |
```diff
     zoh_states: torch.Tensor,
     keep_window_size: int = 2048,
-    attention_mask: torch.Tensor | None = None,
+    cache_position: torch.Tensor = None,
```
Copilot AI · Aug 29, 2025
The parameter should use Optional[torch.Tensor] type annotation instead of defaulting to None without proper typing. This makes the API more explicit about accepting None values.
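The reviewer's suggestion, sketched against a paraphrase of the signature from the diff (the body here is a placeholder, not the real implementation):

```python
from typing import Optional

def prepare_dynamic_mask(
    zoh_states,                                       # torch.Tensor in the real code
    keep_window_size: int = 2048,
    cache_position: Optional["torch.Tensor"] = None,  # explicit Optional rather than a bare None default
):
    # Placeholder body for illustration only; the real function builds
    # the causal attention bias from cache_position.
    return cache_position is not None
```

With `Optional[...]` (or the equivalent `torch.Tensor | None` already used for the old `attention_mask` parameter), type checkers flag callers that assume the argument is always a tensor.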
```diff
 attn_bias = attn_bias.masked_fill(
-    attention_mask[:, :, :, : attn_bias.shape[-1]] != 0, min_dtype
+    torch.arange(attn_bias.shape[-1], device=attn_bias.device) > cache_position.reshape(-1, 1),
+    min_dtype
 )
```
Copilot AI · Aug 29, 2025
Potential issue with tensor broadcasting when cache_position has more than one element. The reshape(-1, 1) assumes cache_position can be reshaped to a column vector, but this could fail if cache_position has incompatible dimensions with the comparison operation.
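The concern can be illustrated with NumPy, which follows the same broadcasting rules as PyTorch (shapes below are hypothetical examples, not values from the benchmarks):

```python
import numpy as np

kv_len = 6

# Decode step: one query, cache_position holds a single absolute position.
decode_pos = np.array([4])
decode_mask = np.arange(kv_len) > decode_pos.reshape(-1, 1)
# (1, 1) broadcast against (kv_len,) -> shape (1, kv_len)

# Prefill: q_len queries, cache_position holds one position per query.
prefill_pos = np.arange(4)
prefill_mask = np.arange(kv_len) > prefill_pos.reshape(-1, 1)
# (4, 1) broadcast against (kv_len,) -> shape (4, kv_len)

# reshape(-1, 1) only behaves as intended when cache_position is 1-D with
# one entry per query row; a cache_position with extra dimensions would
# silently broadcast to a different mask shape instead of raising.
```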
benchmarks/forward_performance.py (outdated)
```python
# # Vary sequence length
# (1, 2, 1, 256, 256, 128, 1024, True),
# (1, 2, 1, 512, 512, 128, 1024, True),
# (1, 2, 1, 1024, 1024, 128, 1024, True),
# (1, 2, 1, 2048, 2048, 128, 1024, True),
# (1, 2, 1, 4096, 4096, 128, 1024, True),
# (1, 2, 1, 8192, 8192, 128, 1024, True),
# (1, 2, 1, 16384, 16384, 128, 1024, True),
# (1, 2, 1, 32768, 32768, 128, 1024, True),

# # Inference
(1, 2, 1, 1, 256, 128, 2048, True),
(1, 2, 1, 1, 512, 128, 2048, True),
(1, 2, 1, 1, 1024, 128, 2048, True),
(1, 2, 1, 1, 2048, 128, 2048, True),
(1, 2, 1, 1, 4096, 128, 2048, True),
(1, 2, 1, 1, 8192, 128, 2048, True),
(1, 2, 1, 1, 16384, 128, 2048, True),
(1, 2, 1, 1, 32768, 128, 2048, True),
(1, 2, 1, 1, 65536, 128, 2048, True),
(1, 2, 1, 1, 131072, 128, 2048, True),
(1, 2, 1, 1, 262144, 128, 2048, True),
(1, 2, 1, 1, 524288, 128, 2048, True),

# Vary batch size
(1, 2, 1, 4096, 4096, 32, 2048, True),
(2, 2, 1, 4096, 4096, 32, 2048, True),
(4, 2, 1, 4096, 4096, 32, 2048, True),
(8, 2, 1, 4096, 4096, 32, 2048, True),

# Vary head count
(1, 1, 1, 4096, 4096, 32, 2048, True),
(1, 2, 1, 4096, 4096, 32, 2048, True),
(1, 4, 1, 4096, 4096, 32, 2048, True),
(1, 8, 2, 4096, 4096, 32, 2048, True),

# Vary head dimension
(1, 2, 1, 4096, 4096, 32, 2048, True),
(1, 2, 1, 4096, 4096, 64, 2048, True),
(1, 2, 1, 4096, 4096, 96, 2048, True),
(1, 2, 1, 4096, 4096, 128, 2048, True),
(1, 2, 1, 4096, 4096, 192, 2048, True),
(1, 2, 1, 4096, 4096, 256, 2048, True),
# (1, 2, 1, 1, 256, 128, 1024, True),
# (1, 2, 1, 1, 512, 128, 1024, True),
# (1, 2, 1, 1, 1024, 128, 1024, True),
# (1, 2, 1, 1, 2048, 128, 1024, True),
# (1, 2, 1, 1, 4096, 128, 1024, True),
# (1, 2, 1, 1, 8192, 128, 1024, True),
# (1, 2, 1, 1, 16384, 128, 1024, True),
```
Copilot AI · Aug 29, 2025
Large blocks of commented-out test configurations should be removed rather than left as comments. This reduces code clutter and improves maintainability. If these configurations are needed for future testing, they should be documented separately or conditionally enabled.
Uncomments all test configuration sets to run the full suite of performance benchmarks covering sequence length variations, inference scenarios, batch size scaling, head count variations, and head dimension testing. Provides complete performance profiling across different model architectures and use cases.
Simplify causal masking by directly using cache_position, enhancing memory efficiency and streamlining the implementation by removing complex attention mask handling logic. Update function signatures and documentation accordingly.