Conversation

@LoserCheems
Collaborator

Simplifies causal masking by using cache_position directly, improving memory efficiency and streamlining the implementation by removing the complex attention-mask handling logic. Updates function signatures and documentation accordingly.

Simplifies causal masking implementation by using cache_position directly instead of constructing full attention masks.

Removes complex attention mask handling logic that converted boolean masks and applied slicing operations.

Introduces more efficient causal masking using torch.arange comparison against cache_position values.

Updates function signatures and documentation across all attention implementations to reflect the parameter change.

Improves memory efficiency by eliminating large intermediate mask tensors in favor of position-based calculations.
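The position-based masking the PR describes can be sketched as follows. This is a minimal illustration, not the PR's actual code: the helper name make_causal_bias is hypothetical, and min_dtype stands for the most negative value of the bias dtype, as in the diff quoted later in the review.

```python
import torch

def make_causal_bias(attn_bias: torch.Tensor, cache_position: torch.Tensor, min_dtype: float) -> torch.Tensor:
    # Key positions strictly greater than a query's cache position are masked out,
    # replacing the full attention-mask tensor with a position comparison.
    key_idx = torch.arange(attn_bias.shape[-1], device=attn_bias.device)
    return attn_bias.masked_fill(key_idx > cache_position.reshape(-1, 1), min_dtype)

# Decoding one new token at absolute position 3 against a 5-entry key cache:
bias = make_causal_bias(torch.zeros(1, 5), torch.tensor([3]), float("-inf"))
# keys 0..3 remain visible; key 4 (a future position) is masked
```

Only a (query_len, key_len) boolean comparison is materialized, instead of a full 4-D attention-mask tensor.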
Contributor

Copilot AI left a comment

Pull Request Overview

This PR replaces attention mask handling with cache position for improved efficiency in dynamic mask attention implementations. The change simplifies causal masking by directly using cache position tensors instead of complex attention mask logic, enhancing memory efficiency and streamlining the codebase.

Key changes:

  • Replace attention_mask parameter with cache_position in function signatures
  • Simplify causal masking logic using cache position directly in prepare_dynamic_mask
  • Update benchmark result naming conventions to use standardized abbreviations

Reviewed Changes

Copilot reviewed 4 out of 4 changed files in this pull request and generated 3 comments.

  • benchmarks/forward_performance.py: Updates function signatures and benchmark result keys to use cache_position and standardized naming.
  • benchmarks/forward_equivalence.py: Replaces attention mask with cache position and removes complex mask-creation logic.
  • benchmarks/backward_performance.py: Updates backward-pass benchmarks to use cache_position and consistent naming.
  • benchmarks/backward_equivalence.py: Simplifies backward equivalence tests by using cache_position instead of attention masks.


    zoh_states: torch.Tensor,
    keep_window_size: int = 2048,
-   attention_mask: torch.Tensor | None = None,
+   cache_position: torch.Tensor = None,

Copilot AI Aug 29, 2025

The parameter should use an Optional[torch.Tensor] type annotation rather than defaulting to None without proper typing. This makes the API explicit about accepting None values.
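With the suggested annotation, the signature would read as follows. This is a sketch only: the function body is elided in the review, so the placeholder return below is illustrative, not the real implementation.

```python
from typing import Optional

import torch

def prepare_dynamic_mask(
    zoh_states: torch.Tensor,
    keep_window_size: int = 2048,
    cache_position: Optional[torch.Tensor] = None,
):
    # Body elided in the review; cache_position may legitimately be None
    # (e.g. when no causal masking is applied). Placeholder return for illustration.
    return zoh_states
```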

Comment on lines 102 to 105
attn_bias = attn_bias.masked_fill(
-   attention_mask[:, :, :, : attn_bias.shape[-1]] != 0, min_dtype
+   torch.arange(attn_bias.shape[-1], device=attn_bias.device) > cache_position.reshape(-1, 1),
+   min_dtype
)

Copilot AI Aug 29, 2025

There is a potential tensor-broadcasting issue when cache_position has more than one element. The reshape(-1, 1) assumes cache_position can be treated as a column vector, but the subsequent comparison could fail if its length is incompatible with the dimensions of attn_bias.
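The broadcasting concern can be made concrete. reshape(-1, 1) yields a column vector whose length must match the query dimension of attn_bias for masked_fill to broadcast; the sketch below (using a 2-D bias for clarity, an assumption for illustration) shows the working cases and the failure mode.

```python
import torch

key_len = 4
key_idx = torch.arange(key_len)

# Decoding: one query, one cache position -> (1, 4) mask, broadcasts cleanly.
mask_decode = key_idx > torch.tensor([3]).reshape(-1, 1)
assert mask_decode.shape == (1, key_len)

# Prefill: one cache position per query row -> (4, 4) causal pattern.
mask_prefill = key_idx > torch.tensor([0, 1, 2, 3]).reshape(-1, 1)
assert mask_prefill.shape == (key_len, key_len)

# Mismatch: 3 cache positions against a bias with 4 query rows cannot broadcast.
bias = torch.zeros(4, key_len)
bad = key_idx > torch.tensor([0, 1, 2]).reshape(-1, 1)  # shape (3, 4)
try:
    bias.masked_fill(bad, float("-inf"))
except RuntimeError:
    pass  # broadcasting fails, as the review comment warns
```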

Comment on lines 723 to 740
# # Vary sequence length
# (1, 2, 1, 256, 256, 128, 1024, True),
# (1, 2, 1, 512, 512, 128, 1024, True),
# (1, 2, 1, 1024, 1024, 128, 1024, True),
# (1, 2, 1, 2048, 2048, 128, 1024, True),
# (1, 2, 1, 4096, 4096, 128, 1024, True),
# (1, 2, 1, 8192, 8192, 128, 1024, True),
# (1, 2, 1, 16384, 16384, 128, 1024, True),
# (1, 2, 1, 32768, 32768, 128, 1024, True),

# # Inference
(1, 2, 1, 1, 256, 128, 2048, True),
(1, 2, 1, 1, 512, 128, 2048, True),
(1, 2, 1, 1, 1024, 128, 2048, True),
(1, 2, 1, 1, 2048, 128, 2048, True),
(1, 2, 1, 1, 4096, 128, 2048, True),
(1, 2, 1, 1, 8192, 128, 2048, True),
(1, 2, 1, 1, 16384, 128, 2048, True),
(1, 2, 1, 1, 32768, 128, 2048, True),
(1, 2, 1, 1, 65536, 128, 2048, True),
(1, 2, 1, 1, 131072, 128, 2048, True),
(1, 2, 1, 1, 262144, 128, 2048, True),
(1, 2, 1, 1, 524288, 128, 2048, True),

# Vary batch size
(1, 2, 1, 4096, 4096, 32, 2048, True),
(2, 2, 1, 4096, 4096, 32, 2048, True),
(4, 2, 1, 4096, 4096, 32, 2048, True),
(8, 2, 1, 4096, 4096, 32, 2048, True),

# Vary head count
(1, 1, 1, 4096, 4096, 32, 2048, True),
(1, 2, 1, 4096, 4096, 32, 2048, True),
(1, 4, 1, 4096, 4096, 32, 2048, True),
(1, 8, 2, 4096, 4096, 32, 2048, True),

# Vary head dimension
(1, 2, 1, 4096, 4096, 32, 2048, True),
(1, 2, 1, 4096, 4096, 64, 2048, True),
(1, 2, 1, 4096, 4096, 96, 2048, True),
(1, 2, 1, 4096, 4096, 128, 2048, True),
(1, 2, 1, 4096, 4096, 192, 2048, True),
(1, 2, 1, 4096, 4096, 256, 2048, True),
# (1, 2, 1, 1, 256, 128, 1024, True),
# (1, 2, 1, 1, 512, 128, 1024, True),
# (1, 2, 1, 1, 1024, 128, 1024, True),
# (1, 2, 1, 1, 2048, 128, 1024, True),
# (1, 2, 1, 1, 4096, 128, 1024, True),
# (1, 2, 1, 1, 8192, 128, 1024, True),
# (1, 2, 1, 1, 16384, 128, 1024, True),

Copilot AI Aug 29, 2025

Large blocks of commented-out test configurations should be removed rather than left as comments. This reduces code clutter and improves maintainability. If these configurations are needed for future testing, they should be documented separately or conditionally enabled.

@LoserCheems
Collaborator Author

Using cache_position for decoding instead of passing in a 4D attention mask is a prerequisite for integrating with the Transformers framework.

Uncomments all test configuration sets to run the full suite of performance benchmarks covering sequence length variations, inference scenarios, batch size scaling, head count variations, and head dimension testing.

Provides complete performance profiling across different model architectures and use cases.
@LoserCheems LoserCheems merged commit a629142 into main Aug 29, 2025