
Conversation

@LoserCheems (Collaborator)

Eliminate the keep_window_size parameter from the Mask constructor and related functions to streamline the interface, and adjust cache position calculations for attention mechanisms to ensure correct causal masking behavior. This reduces code complexity and avoids unnecessary branching.

Simplifies the mask initialization by removing the unused keep_window_size parameter from both attention computation functions.

This streamlines the interface and reduces unnecessary parameter passing without affecting functionality.
Removes the keep_window_size parameter from the Mask struct and eliminates
the conditional branching logic that determined whether masking was needed.

Consolidates the masking logic into a single code path that always applies
the mask check, reducing code complexity and potential branching overhead.

The previous optimization that skipped masking when no window size limit
was needed has been removed in favor of a more straightforward approach.
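The consolidation described above can be sketched as follows. This is a minimal illustration in Python rather than the actual C++ code in csrc/src/mask.h; the function name and shapes are hypothetical. The point is the single code path: the mask check is always applied instead of being skipped behind a window-size condition.

```python
import torch

def apply_causal_mask(scores, query_len, key_len):
    # Single code path: always build and apply the causal mask,
    # rather than branching on whether masking is needed.
    # scores: (query_len, key_len) attention logits.
    offset = key_len - query_len  # queries align to the end of the key sequence
    q_idx = torch.arange(query_len).unsqueeze(1)  # (query_len, 1)
    k_idx = torch.arange(key_len).unsqueeze(0)    # (1, key_len)
    mask = k_idx <= q_idx + offset                # causal: key pos <= query pos
    return scores.masked_fill(~mask, float("-inf"))
```

The unconditional check trades a possible skipped branch for simpler, uniform control flow, which is usually a win inside a GPU kernel.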
Corrects the cache position tensor generation to properly handle
cases where key length differs from query length by calculating
the starting position as key_len - query_len instead of starting
from zero.

This ensures proper causal masking behavior when dealing with
cached key-value pairs in attention computations.
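Concretely, the fix changes where the position tensor starts. A small sketch with hypothetical lengths (the real values come from the benchmark configuration):

```python
import torch

query_len, key_len = 2, 6  # e.g. 2 new queries against 6 cached keys
device = "cpu"

# Before: positions started at zero, misaligning queries with cached keys.
# cache_position = torch.arange(0, query_len, device=device)  # -> [0, 1]

# After: queries occupy the last query_len slots of the key sequence.
cache_position = torch.arange(key_len - query_len, key_len, device=device)  # -> [4, 5]
```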
Copilot AI review requested due to automatic review settings July 28, 2025 06:38



Copilot AI left a comment


Pull Request Overview

This PR removes the unused keep_window_size parameter from the Mask constructor and related functions to simplify the codebase. It also corrects cache position calculations in attention mechanisms to ensure proper causal masking behavior.

  • Removes keep_window_size parameter from Mask struct constructor and function calls
  • Fixes cache position calculation to use key_len - query_len instead of 0
  • Comments out alternative masking logic implementation

Reviewed Changes

Copilot reviewed 4 out of 4 changed files in this pull request and generated 2 comments.

File — Description
csrc/src/mask.h — Removes keep_window_size parameter from Mask struct and comments out alternative masking implementation
csrc/src/flash_fwd_kernel.h — Updates Mask constructor calls to remove keep_window_size parameter
benchmarks/benchmark_forward_performance.py — Fixes cache position calculation for proper causal masking
benchmarks/benchmark_forward_equivalence.py — Fixes cache position calculation for proper causal masking
Comments suppressed due to low confidence (2)


  # Create custom causal mask with cache position
- cache_position = torch.arange(0, query_len + 0, device=device)
+ cache_position = torch.arange(key_len - query_len, key_len, device=device)

Copilot AI Jul 28, 2025


The cache position calculation assumes key_len >= query_len, but this assumption may not always hold. Consider adding validation or handling the case where key_len < query_len to prevent potential negative start values in torch.arange.
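A guarded version of the calculation, following the reviewer's suggestion, might look like this (the helper name `make_cache_position` is hypothetical, not part of the repository):

```python
import torch

def make_cache_position(query_len, key_len, device="cpu"):
    # Guard against key_len < query_len, which would produce a negative
    # start value and a position tensor longer than query_len.
    if key_len < query_len:
        raise ValueError(
            f"Invalid configuration: key_len ({key_len}) must be >= query_len ({query_len})."
        )
    return torch.arange(key_len - query_len, key_len, device=device)
```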

)
A = torch.randn(num_kv_heads, device=device, dtype=torch.bfloat16)

# Create custom causal mask with cache position

Copilot AI Jul 28, 2025


The cache position calculation assumes key_len >= query_len, but this assumption may not always hold. Consider adding validation or handling the case where key_len < query_len to prevent potential negative start values in torch.arange.

Suggested change
- # Create custom causal mask with cache position
+ # Create custom causal mask with cache position
+ if key_len < query_len:
+     raise ValueError(f"Invalid configuration: key_len ({key_len}) must be >= query_len ({query_len}).")

@LoserCheems LoserCheems merged commit aa0cbcc into main Jul 28, 2025