webgpu: Support QKV bias in FlashAttention for MultiHeadAttention #28380

Merged
guschmue merged 2 commits into main from webgpu-flash-attention-bias on May 6, 2026

Conversation

qjia7 (Contributor) commented on May 6, 2026

Summary

  • Remove the bias == nullptr requirement from CanApplyFlashAttention, enabling FlashAttention for MultiHeadAttention nodes with QKV bias (e.g., whisper decoder).
  • Apply TransferBSDToBNSH to add the bias and transpose Q/K/V to BNSH format before calling FlashAttention (a CPU-level sketch of this transform follows this list).
  • Handle cross-attention (only Q needs bias+transpose, K/V already BNSH from encoder) and self-attention (all Q/K/V need bias+transpose) separately.
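
As a concrete illustration of what the bias-add plus BSD → BNSH transpose does to the data, here is a minimal CPU-side sketch. It is for illustration only: the function name and loop implementation are not part of onnxruntime (TransferBSDToBNSH does this work in a WebGPU shader), but the index mapping is the same idea.

```cpp
#include <cstddef>
#include <vector>

// Add the per-channel bias and rearrange [B, S, N*H] (BSD) into [B, N, S, H] (BNSH).
// Hypothetical CPU helper for illustration; the real transform runs on the GPU.
std::vector<float> BiasAddAndTransposeBSDToBNSH(const std::vector<float>& input,  // [B, S, N*H]
                                                const std::vector<float>& bias,   // [N*H]
                                                size_t B, size_t S, size_t N, size_t H) {
  std::vector<float> output(B * N * S * H);
  for (size_t b = 0; b < B; ++b)
    for (size_t s = 0; s < S; ++s)
      for (size_t n = 0; n < N; ++n)
        for (size_t h = 0; h < H; ++h) {
          const size_t d = n * H + h;                        // channel index within N*H
          const size_t src = (b * S + s) * N * H + d;        // BSD offset
          const size_t dst = ((b * N + n) * S + s) * H + h;  // BNSH offset
          output[dst] = input[src] + bias[d];
        }
  return output;
}
```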

Motivation

The whisper decoder's MultiHeadAttention nodes all have QKV bias, which previously forced them into the slower unfused attention path. Enabling FlashAttention for these nodes yields a ~45% speedup on whisper-tiny-int4 (~92 → ~134 tokens/s).

Test plan

  • Existing MHA unit tests with bias data now exercise the FlashAttention path on WebGPU with Subgroups support
  • whisper-tiny-int4 end-to-end: correct transcription at ~134 tps (vs ~92 tps baseline)
  • clang-format passes
  • D3D12 build succeeds

qjia7 added 2 commits May 6, 2026 13:32
Previously FlashAttention required bias == nullptr, forcing all MHA nodes
with QKV bias (e.g. whisper decoder) to fall back to the slower unfused
attention path.

Apply TransferBSDToBNSH to add bias and transpose Q/K/V to BNSH before
calling FlashAttention. This handles both self-attention (all of Q/K/V need
bias+transpose) and cross-attention (only Q needs it; K/V arrive already in
BNSH from the encoder); a caller-side sketch follows this commit message.

Verified: whisper-tiny-int4 produces correct transcription at ~134 tps
(up from ~92 tps baseline, ~45% speedup).
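
Building on the CPU sketch in the PR description, the following is a hedged sketch of the caller-side branching described in the commit message above. RunFlashAttention is a trivial stub standing in for the real WebGPU FlashAttention dispatch, and the combined [Q | K | V] bias layout is an assumption about how the MHA bias input is sliced; neither is copied from the actual multihead_attention.cc code.

```cpp
#include <cstddef>
#include <vector>

// Defined in the earlier sketch: bias-add + BSD -> BNSH rearrangement.
std::vector<float> BiasAddAndTransposeBSDToBNSH(const std::vector<float>& input,
                                                const std::vector<float>& bias,
                                                size_t B, size_t S, size_t N, size_t H);

// Stub standing in for the FlashAttention kernel launch; only the control flow matters here.
void RunFlashAttention(const std::vector<float>& Q_bnsh,
                       const std::vector<float>& K_bnsh,
                       const std::vector<float>& V_bnsh) {
  (void)Q_bnsh; (void)K_bnsh; (void)V_bnsh;
}

void AttentionWithBias(const std::vector<float>& Q_bsd,     // [B, S, N*H]
                       const std::vector<float>& K_in,      // BSD (self) or BNSH (cross)
                       const std::vector<float>& V_in,      // BSD (self) or BNSH (cross)
                       const std::vector<float>& bias_qkv,  // assumed [Q | K | V] concat, 3*N*H
                       size_t B, size_t S, size_t kv_S, size_t N, size_t H,
                       bool is_cross_attention) {
  const size_t D = N * H;

  // Q always arrives as BSD with its own bias slice, so it is always bias-added and transposed.
  const std::vector<float> bias_q(bias_qkv.begin(), bias_qkv.begin() + D);
  const auto Q = BiasAddAndTransposeBSDToBNSH(Q_bsd, bias_q, B, S, N, H);

  if (is_cross_attention) {
    // Cross-attention: K/V come from the encoder already in BNSH, so they pass through untouched.
    RunFlashAttention(Q, K_in, V_in);
    return;
  }

  // Self-attention: K and V are also BSD and take their own slices of the combined bias.
  const std::vector<float> bias_k(bias_qkv.begin() + D, bias_qkv.begin() + 2 * D);
  const std::vector<float> bias_v(bias_qkv.begin() + 2 * D, bias_qkv.begin() + 3 * D);
  const auto K = BiasAddAndTransposeBSDToBNSH(K_in, bias_k, B, kv_S, N, H);
  const auto V = BiasAddAndTransposeBSDToBNSH(V_in, bias_v, B, kv_S, N, H);
  RunFlashAttention(Q, K, V);
}
```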
The bias parameter is no longer checked inside CanApplyFlashAttention
since bias handling was moved to the caller (multihead_attention.cc).
Remove the dead parameter from the declaration, definition, and all
three call sites (attention.cc, multihead_attention.cc,
group_query_attention.cc).

Verified: All 16 MultiHeadAttentionTest cases pass, including tests
with bias data that exercise the FlashAttention-with-bias path.
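
For reference, the declaration change this commit describes looks roughly like the sketch below. The parameter list is a simplified placeholder, not the actual signature in flash_attention.h, which carries additional arguments.

```cpp
// Simplified, hypothetical view of the flash_attention.h change; the real
// declaration has more parameters than shown here.
class Tensor;  // onnxruntime tensor type, forward-declared for this snippet

// Before: a non-null bias disqualified the node from the FlashAttention path.
bool CanApplyFlashAttention(const Tensor* bias, const Tensor* present_key,
                            const Tensor* present_value /*, ... */);

// After: the bias parameter is dropped; multihead_attention.cc applies the bias
// itself (via TransferBSDToBNSH) before calling ApplyFlashAttention.
bool CanApplyFlashAttention(const Tensor* present_key,
                            const Tensor* present_value /*, ... */);
```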

Copilot AI left a comment

Pull request overview

This PR enables the WebGPU FlashAttention fast path for com.microsoft.MultiHeadAttention nodes that include QKV bias by pre-applying the bias and converting inputs to the format expected by FlashAttention, instead of forcing the unfused attention path.

Changes:

  • Remove the bias-related gating from CanApplyFlashAttention by dropping the bias parameter and the bias == nullptr check.
  • In WebGPU MultiHeadAttention, when bias is present, run TransferBSDToBNSH to apply bias + transpose into BNSH prior to calling ApplyFlashAttention, with separate handling for cross-attention vs self-attention.
  • Update WebGPU Attention and GroupQueryAttention call sites to match the new CanApplyFlashAttention signature.

Reviewed changes

Copilot reviewed 5 out of 5 changed files in this pull request and generated no comments.

Summary per file:
  • onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc: adds bias-aware pre-processing (bias + transpose to BNSH) so FlashAttention can be used for MHA with QKV bias.
  • onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc: updates the CanApplyFlashAttention call to the new signature.
  • onnxruntime/contrib_ops/webgpu/bert/flash_attention.h: updates the CanApplyFlashAttention declaration to remove the bias parameter.
  • onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc: updates the CanApplyFlashAttention definition to remove bias gating.
  • onnxruntime/contrib_ops/webgpu/bert/attention.cc: updates the CanApplyFlashAttention call to the new signature.


qjia7 requested review from guschmue and hariharans29 on May 6, 2026 08:49
guschmue added the ep:WebGPU (ort-web webgpu provider) label on May 6, 2026
guschmue merged commit 3b007a6 into main on May 6, 2026
91 of 93 checks passed
guschmue deleted the webgpu-flash-attention-bias branch on May 6, 2026 16:43