
Conversation

@LoserCheems
Collaborator

Streamline the flex attention interface by removing unnecessary return values and standardizing parameter names for clarity and consistency. Adjustments include renaming parameters and adding tensor transpositions so inputs match the expected format; compilation is disabled to aid debugging.

Renames parameters for consistency and adds tensor transpositions to match the expected input format.

Removes the attention weights/LSE computation and return value to streamline the interface, keeping only the attention output.

Disables compilation for debugging purposes and renames the `scaling` parameter to `scale`.
Renames function parameters to follow more conventional naming patterns:
- `causal` becomes `is_causal` for boolean clarity
- `q`, `k`, `v` become `query`, `key`, `value` for readability
- `mask`, `bias` become `attn_mask`, `attn_bias` for specificity

Updates function signatures, internal usage, and wrapper function to maintain consistency throughout the codebase.
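
To make the resulting convention concrete, here is a minimal pure-PyTorch sketch of the semantics behind the renamed interface. The function name `reference_attention` and its defaults are illustrative assumptions, not the library's Triton or flex kernels; it only shows how `query`/`key`/`value`, `attn_mask`, `attn_bias`, `is_causal`, and `scale` fit together, and that only the attention output is returned.

```python
import math
import torch

def reference_attention(query, key, value, attn_mask=None, attn_bias=None,
                        is_causal=False, scale=None):
    # Reference semantics only, in [B, H, N, D] layout; not the actual kernels.
    scale = scale if scale is not None else 1.0 / math.sqrt(query.shape[-1])
    scores = torch.matmul(query, key.transpose(-2, -1)) * scale      # [B, H, Q_LEN, KV_LEN]
    if attn_bias is not None:
        scores = scores + attn_bias                                   # additive attention bias
    if attn_mask is not None:
        scores = scores.masked_fill(attn_mask == 0, float("-inf"))    # 0/False means masked out
    if is_causal:
        q_len, kv_len = scores.shape[-2], scores.shape[-1]
        causal = torch.tril(torch.ones(q_len, kv_len, dtype=torch.bool, device=scores.device))
        scores = scores.masked_fill(~causal, float("-inf"))
    attn = torch.softmax(scores, dim=-1)
    return torch.matmul(attn, value)  # attention output only; no weights/LSE are returned
```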
Contributor

Copilot AI left a comment


Pull Request Overview

This PR streamlines the flex attention interface by standardizing parameter names and removing unnecessary return values. The changes focus on improving clarity and consistency across the attention implementation.

Key changes:

  • Rename parameters from causal to is_causal and standardize attention parameter names (q/k/v to query/key/value, mask/bias to attn_mask/attn_bias)
  • Remove unnecessary return values from flex attention function (attention weights/LSE)
  • Add tensor transpositions to match expected input formats and disable compilation for debugging

Reviewed Changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.

File | Description
--- | ---
flash_dmattn/flash_dmattn_triton.py | Standardizes parameter names throughout the triton implementation and creates a wrapper function
flash_dmattn/flash_dmattn_flex.py | Adds tensor transpositions, removes the attention weights return value, and disables compilation
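
For orientation, here is a sketch of what the flex path in flash_dmattn_flex.py plausibly looks like after this change, based on PyTorch's public `torch.nn.attention.flex_attention` API (available in recent PyTorch releases). The function name `flex_dmattn_sketch` and the exact `score_mod` are assumptions for illustration; only the transposition, the mask/bias slicing discussed in the comments below, and the single return value are taken from the PR description.

```python
import torch
from torch.nn.attention.flex_attention import flex_attention

def flex_dmattn_sketch(query, key, value, attn_mask, attn_bias, scale=None):
    # Hypothetical sketch; the real code in flash_dmattn_flex.py may differ.
    # Inputs assumed to arrive as [B, N, H, D]; flex_attention works on [B, H, N, D].
    query = query.transpose(1, 2).contiguous()   # [B, H, Q_LEN, D]
    key = key.transpose(1, 2).contiguous()       # [B, H, KV_LEN, D]
    value = value.transpose(1, 2).contiguous()   # [B, H, KV_LEN, D]

    # Trim mask/bias to the key sequence length (the slicing discussed below).
    kv_len = key.shape[-2]
    attn_mask = attn_mask[:, :, :, :kv_len]      # [B, H, Q_LEN, KV_LEN]
    attn_bias = attn_bias[:, :, :, :kv_len]      # [B, H, Q_LEN, KV_LEN]

    def score_mod(score, b, h, q_idx, kv_idx):
        # Add the bias, then mask out disallowed key positions.
        score = score + attn_bias[b, h, q_idx, kv_idx]
        return torch.where(attn_mask[b, h, q_idx, kv_idx] > 0, score, float("-inf"))

    out = flex_attention(query, key, value, score_mod=score_mod, scale=scale)
    return out.transpose(1, 2)                   # back to [B, N, H, D]; output only, no LSE
```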

Comment on lines +20 to +21
attn_mask = attn_mask[:, :, :, : key.shape[-2]]
attn_bias = attn_bias[:, :, :, : key.shape[-2]]

Copilot AI Aug 7, 2025


The tensor slicing uses key.shape[-2] but after transposition on line 18, the key tensor shape has changed. This should use the sequence length dimension from the transposed tensor, which would be key.shape[2] instead of key.shape[-2].

Suggested change
- attn_mask = attn_mask[:, :, :, : key.shape[-2]]
- attn_bias = attn_bias[:, :, :, : key.shape[-2]]
+ attn_mask = attn_mask[:, :, :, : key.shape[2]]
+ attn_bias = attn_bias[:, :, :, : key.shape[2]]

key = key.transpose(1, 2).contiguous() # [B, H, KV_LEN, D]
value = value.transpose(1, 2).contiguous() # [B, H, KV_LEN, D]
attn_mask = attn_mask[:, :, :, : key.shape[-2]]
attn_bias = attn_bias[:, :, :, : key.shape[-2]]

Copilot AI Aug 7, 2025


Similar to the attn_mask slicing, this uses key.shape[-2] but should use key.shape[2] after the tensor transposition performed on line 18.

Suggested change
- attn_bias = attn_bias[:, :, :, : key.shape[-2]]
+ attn_bias = attn_bias[:, :, :, : key.shape[2]]

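For context on the indexing raised in the two comments above, here is a standalone check with hypothetical shapes, assuming the [B, H, KV_LEN, D] layout produced by the transposition: on a 4-dimensional tensor, dim 2 and dim -2 refer to the same axis, so either spelling slices the mask/bias to the key sequence length.

```python
import torch

# Hypothetical sizes: B=2, H=4, Q_LEN=8, KV_LEN=16, D=32, mask padded to 32 key positions.
key = torch.randn(2, 16, 4, 32).transpose(1, 2).contiguous()  # [B, H, KV_LEN, D]
attn_mask = torch.ones(2, 4, 8, 32)                           # [B, H, Q_LEN, 32]

assert key.shape[2] == key.shape[-2] == 16   # dims 2 and -2 coincide on a 4-D tensor
print(attn_mask[:, :, :, : key.shape[-2]].shape)              # torch.Size([2, 4, 8, 16])
```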
Activates the compile flag to improve performance through kernel optimization during flex attention computation.
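
The compile toggle referred to here most plausibly wraps PyTorch's `flex_attention` in `torch.compile`, which is the documented way to get a fused kernel out of it; a minimal sketch of that pattern follows. The flag name `COMPILE_FLEX` is an assumption, not the actual identifier in flash_dmattn_flex.py.

```python
import torch
from torch.nn.attention.flex_attention import flex_attention

# Hypothetical module-level switch; the real flag in flash_dmattn_flex.py may differ.
COMPILE_FLEX = True

if COMPILE_FLEX:
    # torch.compile fuses the score_mod into the generated attention kernel,
    # which is where the speedup over the eager flex_attention path comes from.
    flex_attention = torch.compile(flex_attention)
# Leaving the flag off keeps the eager path: slower, but easier to step through when debugging.
```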
@LoserCheems merged commit 6403ad1 into main on Aug 7, 2025