[FEATURE SUPPORT] Flexible head dims for mask/bias with in-kernel conversion path #167
Conversation
Introduces h_mask and h_bias fields to track the number of heads in attention mask and bias structures respectively. Enables better head dimension management and validation in flash attention operations.
Introduces dynamic head index calculation for mask and bias tensors to support different head configurations. Previously a fixed head-ratio calculation was used; now three scenarios are supported:
- Single-head broadcasting (h_mask/h_bias == 1)
- Multi-head with ratio-based indexing (h_mask/h_bias == h_k)
- Direct head indexing (fallback case)
Enables more flexible attention masking and bias application across different multi-head attention configurations.
Introduces conditional head index calculation for mask and bias operations based on tensor dimensions. Supports scenarios where mask/bias tensors can have a single head (h=1), match the key heads (h=h_k), or match the query heads (h=h_q). Replaces the hardcoded head index division with dynamic selection logic that adapts to different tensor head configurations in the flash attention backward kernel.
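The head-index selection these two commits describe can be pictured as a small helper; the function and variable names below are illustrative, not the kernel's actual identifiers.

```cpp
#include <cassert>

// Map a query-head index to the corresponding mask/bias head index for the
// three supported layouts. h_q: number of query heads, h_k: number of
// key/value heads, h_mb: head count of the mask or bias tensor (1, h_k, or h_q).
inline int mask_bias_head_index(int query_head, int h_q, int h_k, int h_mb) {
    if (h_mb == 1) {
        return 0;                          // single head: broadcast to every query head
    } else if (h_mb == h_k) {
        return query_head / (h_q / h_k);   // GQA/MQA: group query heads onto KV heads
    } else {
        assert(h_mb == h_q);               // fallback: one mask/bias head per query head
        return query_head;
    }
}
```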
Adds support for mask and bias tensors with 1, num_heads_k, or num_heads dimensions instead of only num_heads_k. Enables more flexible attention patterns by allowing masks and biases to be broadcast across different head configurations. Updates parameter passing to track separate head counts for masks and biases, and adds appropriate validation checks. Temporarily disables variable-length attention variants to focus on core functionality improvements.
Clarifies that attention mask and bias parameters support multiple tensor shapes to accommodate Multi-Query Attention (MQA) and Grouped Query Attention (GQA) patterns, in addition to the standard multi-head attention format. Adds explicit documentation for supported shapes including broadcast-compatible dimensions for flexible attention implementations.
Clarifies that attention mask and bias tensors support multiple shape formats to accommodate Multi-Query Attention (MQA) and Grouped-Query Attention (GQA) patterns in addition to the standard multi-head attention format. Adds explicit documentation for supported shapes: standard num_heads format, num_kv_heads format, and broadcast-compatible single head format.
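For reference, the documented shape combinations look like this; a minimal sketch using plain ATen tensors, with the attention call itself omitted since only the mask/bias head layouts are at issue.

```cpp
#include <torch/torch.h>

int main() {
    // Example sizes: 8 query heads grouped onto 2 KV heads (GQA).
    const int64_t B = 2, H = 8, H_k = 2, Q = 128, K = 128;

    // Any of these head layouts is accepted for the attention bias:
    auto bias_per_query_head = torch::zeros({B, H,   Q, K});  // (B, num_heads,    Q, K)
    auto bias_per_kv_head    = torch::zeros({B, H_k, Q, K});  // (B, num_kv_heads, Q, K)
    auto bias_broadcast      = torch::zeros({B, 1,   Q, K});  // (B, 1,            Q, K)

    // The mask follows the same layout rules as the bias.
    auto mask_per_kv_head = torch::ones({B, H_k, Q, K});
    return 0;
}
```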
Pull Request Overview
This PR adds support for flexible head dimensions in attention masks and biases, allowing for shapes (B, H, Q, K), (B, H_k, Q, K), and (B, 1, Q, K) instead of being restricted to only (B, H_k, Q, K).
- Introduces kernel-side head indexing logic to handle different head dimension configurations
- Updates API and parameter structures to track mask and bias head counts separately
- Implements proper broadcast semantics and conversion paths for different head arrangements
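The "track mask and bias head counts separately" point amounts to carrying two extra integers alongside the existing head counts. A minimal sketch of that change follows; everything except h_mask and h_bias is a placeholder rather than the actual flash.h definition.

```cpp
// Simplified stand-in for the kernel parameter block; only h_mask and h_bias
// correspond to the fields this PR adds.
struct AttnParamsSketch {
    int b;        // batch size
    int h;        // number of query heads
    int h_k;      // number of key/value heads
    int h_mask;   // heads in the mask tensor: 1, h_k, or h
    int h_bias;   // heads in the bias tensor: 1, h_k, or h
    const void *mask_ptr = nullptr;   // mask/bias data pointers, as in the existing params
    const void *bias_ptr = nullptr;
};
```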
Reviewed Changes
Copilot reviewed 6 out of 7 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| flash_dmattn/integrations/flash_dynamic_mask_attention.py | Updates documentation to reflect flexible mask/bias shape support |
| flash_dmattn/flash_dmattn_interface.py | Updates API documentation for flexible head dimensions |
| csrc/flash_dmattn/src/flash_fwd_kernel.h | Implements head indexing logic for mask/bias in forward kernel |
| csrc/flash_dmattn/src/flash_bwd_kernel.h | Implements head indexing logic for mask/bias in backward kernel |
| csrc/flash_dmattn/src/flash.h | Adds h_mask and h_bias fields to parameter structures |
| csrc/flash_dmattn/flash_api.cpp | Updates function signatures and implements flexible head dimension handling |
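On the flash_api.cpp side, the flexible head dimension handling presumably reduces to reading each tensor's head dimension and validating it against the allowed layouts. A hedged sketch follows; the helper, its name, and the params usage shown in the trailing comment are assumptions, not the PR's actual code.

```cpp
#include <torch/torch.h>

// Hypothetical helper: derive the mask/bias head count and verify it is one
// of the allowed values (1, num_heads_k, or num_heads).
static int checked_head_count(const at::Tensor &t, const char *name,
                              int64_t num_heads, int64_t num_heads_k) {
    const int64_t h = t.size(1);  // layout: batch x heads x seqlen_q x seqlen_k
    TORCH_CHECK(h == 1 || h == num_heads_k || h == num_heads,
                name, " must have 1, num_heads_k, or num_heads heads, got ", h);
    return static_cast<int>(h);
}

// Assumed usage inside the forward/backward entry points:
//   params.h_mask = checked_head_count(mask, "mask", num_heads, num_heads_k);
//   params.h_bias = checked_head_count(bias, "bias", num_heads, num_heads_k);
```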
```diff
- const at::Tensor &mask, // batch_size x num_heads_k x seqlen_q x seqlen_k
- const at::Tensor &bias, // batch_size x num_heads_k x seqlen_q x seqlen_k
+ at::Tensor &mask, // batch_size x num_heads x seqlen_q x seqlen_k or batch_size x num_heads_k x seqlen_q x seqlen_k or batch_size x 1 x seqlen_q x seqlen_k
+ at::Tensor &bias, // batch_size x num_heads x seqlen_q x seqlen_k, or batch_size x num_heads_k x seqlen_q x seqlen_k or batch_size x 1 x seqlen_q x seqlen_k
```
Copilot AI · Sep 13, 2025
The comment format is inconsistent between mask and bias parameters. The bias comment has an extra comma after the first shape specification that should be removed to match the mask comment format.
Suggested change:

```diff
- at::Tensor &bias, // batch_size x num_heads x seqlen_q x seqlen_k, or batch_size x num_heads_k x seqlen_q x seqlen_k or batch_size x 1 x seqlen_q x seqlen_k
+ at::Tensor &bias, // batch_size x num_heads x seqlen_q x seqlen_k or batch_size x num_heads_k x seqlen_q x seqlen_k or batch_size x 1 x seqlen_q x seqlen_k
```
```diff
- const at::Tensor &mask, // batch_size x num_heads_k x seqlen_q x seqlen_k
- const at::Tensor &bias, // batch_size x num_heads_k x seqlen_q x seqlen_k
+ const at::Tensor &mask, // batch_size x num_heads x seqlen_q x seqlen_k or batch_size x num_heads_k x seqlen_q x seqlen_k or batch_size x 1 x seqlen_q x seqlen_k
+ const at::Tensor &bias, // batch_size x num_heads_k x seqlen_q x seqlen_k or batch_size x num_heads x seqlen_q x seqlen_k or batch_size x 1 x seqlen_q x seqlen_k
```
Copilot AI · Sep 13, 2025
The order of shape specifications in the bias comment differs from the mask comment above it. For consistency, the bias comment should list shapes in the same order: num_heads first, then num_heads_k, then 1.
Suggested change:

```diff
- const at::Tensor &bias, // batch_size x num_heads_k x seqlen_q x seqlen_k or batch_size x num_heads x seqlen_q x seqlen_k or batch_size x 1 x seqlen_q x seqlen_k
+ const at::Tensor &bias, // batch_size x num_heads x seqlen_q x seqlen_k or batch_size x num_heads_k x seqlen_q x seqlen_k or batch_size x 1 x seqlen_q x seqlen_k
```
Summary
#163
Design
Uses h_mask/h_bias to select:
- broadcast indexing (head count 1)
- ratio-based indexing (head count == h_k)
- direct indexing (head count == h)

Changes
- Adds h_mask and h_bias to drive kernel-side head selection.
- Plumbs them through set_params_fprop/set_params_dgrad.

Implementation Notes
Tests
Docs
Checklist