Remove unused parameters and simplify mask logic #77
Conversation
Simplifies the mask initialization by removing the unused keep_window_size parameter from both attention computation functions. This streamlines the interface and reduces unnecessary parameter passing without affecting functionality.
Removes the keep_window_size parameter from the Mask struct and eliminates the conditional branching logic that determined whether masking was needed. Consolidates the masking logic into a single code path that always applies the mask check, reducing code complexity and potential branching overhead. The previous optimization that skipped masking when no window size limit was needed has been removed in favor of a more straightforward approach.
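The consolidated single-code-path idea can be sketched in plain Python (a hedged illustration only; the actual logic in csrc/src/mask.h is CUDA C++, and the names below are assumptions, not the kernel's identifiers):

```python
NEG_INF = float("-inf")

def apply_causal_mask(scores, causal_offset):
    """Single code path: the mask check always runs, rather than being
    skipped when no window-size limit is in effect.

    scores: 2D list of attention scores, rows = query positions,
            columns = key positions.
    causal_offset: key_len - query_len, so query row r may attend to
            key columns c <= r + causal_offset.
    """
    return [
        [s if c <= r + causal_offset else NEG_INF for c, s in enumerate(row)]
        for r, row in enumerate(scores)
    ]
```

Always applying the check trades a skipped branch for simpler control flow, which is the trade-off the PR description refers to.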
Corrects the cache position tensor generation to properly handle cases where key length differs from query length by calculating the starting position as key_len - query_len instead of starting from zero. This ensures proper causal masking behavior when dealing with cached key-value pairs in attention computations.
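The corrected start position can be sanity-checked with plain integer ranges (Python's `range` mirrors `torch.arange` over integers here; this helper is a sketch, not code from the benchmarks):

```python
def cache_positions(key_len, query_len):
    """Positions of the current queries within the full key sequence.
    With cached key-value pairs, the query_len queries are the *last*
    positions of the sequence, so the range starts at
    key_len - query_len rather than 0."""
    if key_len < query_len:
        raise ValueError(f"key_len ({key_len}) must be >= query_len ({query_len})")
    return list(range(key_len - query_len, key_len))
```

For example, decoding 3 new tokens against a cache of 8 keys places the queries at positions 5, 6, 7, which is what the causal mask must see.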
Pull Request Overview
This PR removes the unused keep_window_size parameter from the Mask constructor and related functions to simplify the codebase. It also corrects cache position calculations in attention mechanisms to ensure proper causal masking behavior.
- Removes `keep_window_size` parameter from Mask struct constructor and function calls
- Fixes cache position calculation to use `key_len - query_len` instead of `0`
- Comments out alternative masking logic implementation
Reviewed Changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| csrc/src/mask.h | Removes keep_window_size parameter from Mask struct and comments out alternative masking implementation |
| csrc/src/flash_fwd_kernel.h | Updates Mask constructor calls to remove keep_window_size parameter |
| benchmarks/benchmark_forward_performance.py | Fixes cache position calculation for proper causal masking |
| benchmarks/benchmark_forward_equivalence.py | Fixes cache position calculation for proper causal masking |
Comments suppressed due to low confidence (2)
```diff
  # Create custom causal mask with cache position
- cache_position = torch.arange(0, query_len + 0, device=device)
+ cache_position = torch.arange(key_len - query_len, key_len, device=device)
```
Copilot AI · Jul 28, 2025
The cache position calculation assumes key_len >= query_len, but this assumption may not always hold. Consider adding validation or handling the case where key_len < query_len to prevent potential negative start values in torch.arange.
```python
)
A = torch.randn(num_kv_heads, device=device, dtype=torch.bfloat16)

# Create custom causal mask with cache position
```
Copilot AI · Jul 28, 2025
The cache position calculation assumes key_len >= query_len, but this assumption may not always hold. Consider adding validation or handling the case where key_len < query_len to prevent potential negative start values in torch.arange.
Suggested change:

```diff
  # Create custom causal mask with cache position
+ if key_len < query_len:
+     raise ValueError(f"Invalid configuration: key_len ({key_len}) must be >= query_len ({query_len}).")
```
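The failure mode this guard prevents can be shown with plain integer ranges standing in for `torch.arange` (the lengths below are hypothetical values chosen to trigger the bug):

```python
# More queries than keys: invalid for cached decoding, but nothing
# stops the unguarded arithmetic from running.
key_len, query_len = 2, 5

# Unguarded: the start value key_len - query_len goes negative, and the
# range silently yields meaningless negative positions instead of failing.
positions = list(range(key_len - query_len, key_len))
# positions == [-3, -2, -1, 0, 1]
```

A silent wrong answer here corrupts the causal mask downstream, which is why an explicit `ValueError` at construction time is the safer behavior.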
Eliminate the `keep_window_size` parameter from the Mask constructor and related functions to streamline the interface. Adjust cache position calculations for attention mechanisms to ensure proper causal masking behavior. This reduces code complexity and enhances performance.