webgpu: Support QKV bias in FlashAttention for MultiHeadAttention#28380
Merged
Conversation
Previously FlashAttention required bias == nullptr, forcing all MHA nodes with QKV bias (e.g. whisper decoder) to fall back to the slower unfused attention path. Apply TransferBSDToBNSH to add bias and transpose Q/K/V to BNSH before calling FlashAttention. Handles both self-attention (all Q/K/V need bias+transpose) and cross-attention (only Q needs it, K/V already BNSH from encoder). Verified: whisper-tiny-int4 produces correct transcription at ~134 tps (up from ~92 tps baseline, ~45% speedup).
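Below is a minimal, self-contained sketch of the dispatch this adds. The type and helper names here (MhaInputs, transfer_bsd_to_bnsh, run_flash_attention, the bias_offset values) are illustrative stand-ins, not the actual onnxruntime WebGPU API; see multihead_attention.cc for the real code.

```cpp
// Illustrative sketch of the bias-aware FlashAttention dispatch.
#include <iostream>

struct Tensor {};  // stand-in for onnxruntime's Tensor

struct MhaInputs {
  Tensor q, k, v;
  const Tensor* bias = nullptr;     // optional packed QKV bias
  bool is_cross_attention = false;  // K/V come from the encoder, already BNSH
};

// Stand-in for TransferBSDToBNSH: adds the relevant bias slice and
// transposes a BSD-layout tensor into BNSH layout.
Tensor transfer_bsd_to_bnsh(const Tensor& in, const Tensor* bias, int bias_offset) {
  (void)bias;
  (void)bias_offset;  // illustrative slice index only
  return in;          // the real code dispatches a WebGPU shader here
}

void run_flash_attention(const Tensor&, const Tensor&, const Tensor&) {
  std::cout << "FlashAttention on BNSH inputs\n";
}

void apply_attention(const MhaInputs& in) {
  if (in.bias == nullptr) {
    // No bias: inputs can go straight to the fused kernel, as before.
    run_flash_attention(in.q, in.k, in.v);
    return;
  }
  // Bias present: fold it in and convert to BNSH before the fused kernel.
  Tensor q = transfer_bsd_to_bnsh(in.q, in.bias, /*bias_offset=*/0);
  if (in.is_cross_attention) {
    // Cross-attention: K/V already arrive in BNSH from the encoder,
    // so only Q needs bias + transpose.
    run_flash_attention(q, in.k, in.v);
  } else {
    // Self-attention: Q, K, and V all carry a bias slice and need transposing.
    Tensor k = transfer_bsd_to_bnsh(in.k, in.bias, /*bias_offset=*/1);
    Tensor v = transfer_bsd_to_bnsh(in.v, in.bias, /*bias_offset=*/2);
    run_flash_attention(q, k, v);
  }
}

int main() {
  MhaInputs self_attn;
  Tensor bias;
  self_attn.bias = &bias;
  apply_attention(self_attn);  // exercises the self-attention branch
}
```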
The bias parameter is no longer checked inside CanApplyFlashAttention since bias handling was moved to the caller (multihead_attention.cc). Remove the dead parameter from the declaration, definition, and all three call sites (attention.cc, multihead_attention.cc, group_query_attention.cc). Verified: All 16 MultiHeadAttentionTest cases pass, including tests with bias data that exercise the FlashAttention-with-bias path.
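The shape of the signature change is roughly as follows; the surrounding parameters are omitted and the exact declaration lives in flash_attention.h:

```cpp
struct Tensor;  // stand-in for onnxruntime's Tensor

// Before (illustrative): the helper took the bias and required
// bias == nullptr before allowing the fused path.
bool CanApplyFlashAttention(const Tensor* bias /*, other inputs... */);

// After: bias is pre-applied by the caller (multihead_attention.cc),
// so the parameter and the nullptr check are removed from the helper.
bool CanApplyFlashAttention(/* other inputs... */);
```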
Pull request overview
This PR enables the WebGPU FlashAttention fast path for com.microsoft.MultiHeadAttention nodes that include QKV bias by pre-applying the bias and converting inputs to the format expected by FlashAttention, instead of forcing the unfused attention path.
Changes:
- Remove the bias-related gating from CanApplyFlashAttention by dropping the bias parameter and the bias == nullptr check.
- In WebGPU MultiHeadAttention, when bias is present, run TransferBSDToBNSH to apply bias and transpose into BNSH prior to calling ApplyFlashAttention, with separate handling for cross-attention vs. self-attention.
- Update the WebGPU Attention and GroupQueryAttention call sites to match the new CanApplyFlashAttention signature.
Reviewed changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated no comments.
| File | Description |
|---|---|
| onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc | Adds bias-aware pre-processing (bias + transpose to BNSH) so FlashAttention can be used for MHA with QKV bias. |
| onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc | Updates CanApplyFlashAttention call to new signature. |
| onnxruntime/contrib_ops/webgpu/bert/flash_attention.h | Updates CanApplyFlashAttention declaration to remove the bias parameter. |
| onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc | Updates CanApplyFlashAttention definition to remove bias gating. |
| onnxruntime/contrib_ops/webgpu/bert/attention.cc | Updates CanApplyFlashAttention call to new signature. |
guschmue approved these changes on May 6, 2026.
Summary
- Remove the bias == nullptr requirement from CanApplyFlashAttention, enabling FlashAttention for MultiHeadAttention nodes with QKV bias (e.g., whisper decoder).
- Use TransferBSDToBNSH to add bias and transpose Q/K/V to BNSH format before calling FlashAttention.

Motivation
Whisper decoder's MultiHeadAttention nodes all have QKV bias, which previously forced them into the slower unfused attention path. Enabling FlashAttention for these nodes yields ~45% speedup on whisper-tiny-int4 (~92 → ~134 tokens/s).
Test plan