Simplify and standardize flex attention interface #90
Conversation
Renames parameters for consistency and adds tensor transpositions to match the expected input format. Removes the attention weights/LSE computation and return value to streamline the interface, keeping only the attention output. Disables compilation for debugging purposes and renames the `scaling` parameter to `scale`.
Renames function parameters to follow more conventional naming patterns:
- `causal` becomes `is_causal` for boolean clarity
- `q`, `k`, `v` become `query`, `key`, `value` for readability
- `mask`, `bias` become `attn_mask`, `attn_bias` for specificity

Updates function signatures, internal usage, and the wrapper function to maintain consistency throughout the codebase.
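For orientation, here is a minimal eager-mode sketch of the renamed interface described above. The wrapper name `flex_dmattn_func`, the `[B, LEN, H, D]` input layout, and the explicit score computation are illustrative assumptions; the actual wrapper in `flash_dmattn/flash_dmattn_flex.py` builds on flex attention and may differ in layout and defaults.

```python
import torch

def flex_dmattn_func(query, key, value, attn_mask=None, attn_bias=None,
                     is_causal=False, scale=None):
    """Hypothetical sketch: query/key/value are [B, LEN, H, D]; returns only
    the attention output (no attention weights / LSE), per this PR."""
    # Transpose to the [B, H, LEN, D] layout the kernel expects.
    query = query.transpose(1, 2).contiguous()   # [B, H, Q_LEN, D]
    key = key.transpose(1, 2).contiguous()       # [B, H, KV_LEN, D]
    value = value.transpose(1, 2).contiguous()   # [B, H, KV_LEN, D]

    scale = query.shape[-1] ** -0.5 if scale is None else scale
    scores = torch.matmul(query, key.transpose(-2, -1)) * scale  # [B, H, Q_LEN, KV_LEN]
    if attn_bias is not None:
        scores = scores + attn_bias[:, :, :, : key.shape[-2]]
    if attn_mask is not None:
        scores = scores.masked_fill(attn_mask[:, :, :, : key.shape[-2]] == 0, float("-inf"))
    if is_causal:
        q_len, kv_len = scores.shape[-2], scores.shape[-1]
        keep = torch.ones(q_len, kv_len, dtype=torch.bool, device=scores.device)
        scores = scores.masked_fill(~keep.tril(kv_len - q_len), float("-inf"))
    out = torch.matmul(scores.softmax(dim=-1), value)  # [B, H, Q_LEN, D]
    return out.transpose(1, 2)                         # back to [B, Q_LEN, H, D]
```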
Pull Request Overview
This PR streamlines the flex attention interface by standardizing parameter names and removing unnecessary return values. The changes focus on improving clarity and consistency across the attention implementation.
Key changes:
- Rename the `causal` parameter to `is_causal` and standardize attention parameter names (`q`/`k`/`v` to `query`/`key`/`value`, `mask`/`bias` to `attn_mask`/`attn_bias`)
- Remove the unnecessary attention weights/LSE return value from the flex attention function
- Add tensor transpositions to match expected input formats and disable compilation for debugging
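A usage sketch with the standardized names, reusing the hypothetical `flex_dmattn_func` wrapper from the sketch above (shapes are illustrative only):

```python
import torch

B, Q_LEN, KV_LEN, H, D = 2, 128, 128, 4, 64
query = torch.randn(B, Q_LEN, H, D)
key = torch.randn(B, KV_LEN, H, D)
value = torch.randn(B, KV_LEN, H, D)
attn_mask = torch.ones(B, H, Q_LEN, KV_LEN, dtype=torch.bool)
attn_bias = torch.zeros(B, H, Q_LEN, KV_LEN)

# Single return value: the attention output only, no weights / LSE.
out = flex_dmattn_func(query, key, value,
                       attn_mask=attn_mask, attn_bias=attn_bias,
                       is_causal=True, scale=None)
assert out.shape == (B, Q_LEN, H, D)
```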
Reviewed Changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| flash_dmattn/flash_dmattn_triton.py | Standardizes parameter names throughout the Triton implementation and adds a wrapper function |
| flash_dmattn/flash_dmattn_flex.py | Adds tensor transpositions, removes the attention weights return value, and disables compilation |
```python
attn_mask = attn_mask[:, :, :, : key.shape[-2]]
attn_bias = attn_bias[:, :, :, : key.shape[-2]]
```
Copilot AI commented on Aug 7, 2025
The tensor slicing uses `key.shape[-2]`, but after the transposition on line 18 the key tensor shape has changed. This should use the sequence length dimension from the transposed tensor, which would be `key.shape[2]` instead of `key.shape[-2]`.
```diff
- attn_mask = attn_mask[:, :, :, : key.shape[-2]]
- attn_bias = attn_bias[:, :, :, : key.shape[-2]]
+ attn_mask = attn_mask[:, :, :, : key.shape[2]]
+ attn_bias = attn_bias[:, :, :, : key.shape[2]]
```
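As a side note, a quick shape check (a sketch assuming `key` is the 4-D `[B, H, KV_LEN, D]` tensor produced by the transposition above): on a 4-D tensor, dimensions `2` and `-2` are the same axis, so both spellings select the KV length.

```python
import torch

# Illustrative shapes only.
B, H, KV_LEN, D = 2, 4, 128, 64
key = torch.randn(B, KV_LEN, H, D).transpose(1, 2).contiguous()  # [B, H, KV_LEN, D]

# dim 2 and dim -2 address the same axis of a 4-D tensor.
assert key.shape[-2] == key.shape[2] == KV_LEN
```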
```python
key = key.transpose(1, 2).contiguous()      # [B, H, KV_LEN, D]
value = value.transpose(1, 2).contiguous()  # [B, H, KV_LEN, D]
attn_mask = attn_mask[:, :, :, : key.shape[-2]]
attn_bias = attn_bias[:, :, :, : key.shape[-2]]
```
Copilot AI commented on Aug 7, 2025
Similar to the `attn_mask` slicing, this uses `key.shape[-2]` but should use `key.shape[2]` after the tensor transposition performed on line 18.
```diff
- attn_bias = attn_bias[:, :, :, : key.shape[-2]]
+ attn_bias = attn_bias[:, :, :, : key.shape[2]]
```
Activates the compile flag to improve performance through kernel optimization during flex attention computation.
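A minimal sketch of what enabling the compile path can look like, assuming the module wraps PyTorch's `flex_attention`; the flag name `FLEX_ATTENTION_COMPILE` and the exact wiring in `flash_dmattn_flex.py` are assumptions.

```python
import torch
from torch.nn.attention.flex_attention import flex_attention

FLEX_ATTENTION_COMPILE = True  # hypothetical module-level flag

if FLEX_ATTENTION_COMPILE:
    # Compiling fuses the score/mask modifications into a single kernel,
    # which is where the performance gain comes from.
    flex_attention = torch.compile(flex_attention, dynamic=False)
```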
Streamlines the flex attention interface by removing unnecessary return values and standardizing parameter names for clarity and consistency. Adjustments include renaming parameters and ensuring tensor transpositions match the expected input formats; disabling compilation aids in debugging.