
[WebGPU] Fuse Q/K RMSNorm into GroupQueryAttention for Qwen3-style models #28484

Open

hariharans29 wants to merge 15 commits into main from hari/webgpu_perf_2
Conversation

hariharans29 (Member) commented May 12, 2026

Description

Adds a per-head Q/K RMS normalization prologue to the WebGPU GroupQueryAttention
contrib op so that Qwen3-style models can fold the standalone
SimplifiedLayerNormalization dispatches on Q and K into the attention kernel.

The MS-domain GroupQueryAttention schema gains two optional inputs and one
attribute:

  • q_norm_weight (input 14, optional, shape (head_size,))
  • k_norm_weight (input 15, optional, shape (head_size,))
  • qk_norm_epsilon (FLOAT attribute, default 1e-6)
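
For reference, the prologue is the standard RMS normalization applied independently to each length-head_size vector of Q and K, scaled by the corresponding weight. A minimal scalar sketch (illustrative only; the actual implementation is the WGSL shader described below):

```cpp
#include <cmath>
#include <cstddef>

// Illustrative scalar reference of the per-head Q/K RMSNorm prologue:
// scale each head vector by 1/sqrt(mean(x^2) + epsilon), then multiply
// elementwise by the norm weight (same math as SimplifiedLayerNormalization).
void RmsNormHead(float* head, const float* norm_weight, size_t head_size, float epsilon) {
  float sum_sq = 0.0f;
  for (size_t i = 0; i < head_size; ++i) {
    sum_sq += head[i] * head[i];
  }
  const float inv_rms = 1.0f / std::sqrt(sum_sq / static_cast<float>(head_size) + epsilon);
  for (size_t i = 0; i < head_size; ++i) {
    head[i] = head[i] * inv_rms * norm_weight[i];
  }
}
```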

A new Level 2 optimizer pass GroupQueryAttentionPreNormFusion rewrites the
Reshape -> SimplifiedLayerNormalization -> Reshape -> GroupQueryAttention
pattern produced by the Qwen3 ONNX export into a single GroupQueryAttention
node that carries the norm weights directly. The pass is scoped to the WebGPU
EP only.
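
Schematically, the matched subgraph and its replacement (Q side shown; K is symmetric):

```
  Q projection                                 Q projection
      |                                            |
  Reshape (collapse to ..., head_size)         GroupQueryAttention
      |                              ==>         q_norm_weight  -> input 14
  SimplifiedLayerNormalization                   k_norm_weight  -> input 15
      |                                          qk_norm_epsilon = SLN epsilon
  Reshape (expand to ..., hidden_size)
      |
  GroupQueryAttention
```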

WebGPU runtime paths

  • sequence_length == 1 (decode): the per-head RMSNorm is folded directly into
    the FusedQKRotaryEmbedding dispatch.
  • sequence_length > 1 (prefill): falls back to two standalone
    SimplifiedLayerNorm dispatches into scratch qNorm/kNorm tensors, then runs
    FusedQKRotaryEmbedding without the norm prologue. This matches the
    pre-fusion graph's dispatch sequence exactly, so prefill cannot regress.
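
A minimal sketch of how this split looks in the kernel (control flow only; helper signatures and scratch-tensor names are assumptions, not the actual WebGPU GQA source):

```cpp
// Illustrative sketch, not the actual onnxruntime WebGPU kernel code.
if (has_qk_norm && parameters.sequence_length == 1) {
  // Decode: the per-head RMSNorm runs inside the rotary kernel, so no
  // extra dispatches are issued relative to the unfused graph.
  ORT_RETURN_IF_ERROR(RunFusedQKRotaryEmbedding(context, query, key,
                                                q_norm_weight, k_norm_weight,
                                                qk_norm_epsilon_));
} else if (has_qk_norm) {
  // Prefill: normalize Q and K into scratch tensors via the shared
  // SimplifiedLayerNorm program, then run the rotary kernel without the
  // norm prologue -- the same dispatch sequence as the pre-fusion graph.
  ORT_RETURN_IF_ERROR(RunLayerNormProgram(context, query, q_norm_weight, qk_norm_epsilon_, &q_norm_scratch));
  ORT_RETURN_IF_ERROR(RunLayerNormProgram(context, key, k_norm_weight, qk_norm_epsilon_, &k_norm_scratch));
  ORT_RETURN_IF_ERROR(RunFusedQKRotaryEmbedding(context, q_norm_scratch, k_norm_scratch,
                                                /*q_norm_weight*/ nullptr,
                                                /*k_norm_weight*/ nullptr,
                                                /*epsilon*/ 0.0f));
}
```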

Gating

The optimizer only rewrites nodes that the WebGPU kernel can handle (a sketch of these checks follows the list):

  • do_rotary == 1 (the fused decode path runs inside the rotary kernel; the
    prefill fallback also runs only on the do_rotary=1 branch)
  • Non-packed QKV only (key at slot 1 and value at slot 2 must both be wired)
  • Already-fused nodes (q/k_norm_weight present) are skipped
  • The pattern must use a single norm weight per side with matching epsilon and
    axis=-1, and the surrounding reshapes must collapse to (... head_size) /
    expand back to (... hidden_size).
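
The node-level gates could look roughly like this (function and variable names are assumptions for illustration; only the conditions themselves come from the list above, and the Reshape/SLN pattern and epsilon checks happen separately):

```cpp
// Illustrative sketch of the gating, not the actual transformer source.
bool GqaNodeIsEligible(const onnxruntime::Node& gqa) {
  // do_rotary == 1: both the fused decode path and the prefill fallback
  // live on the rotary branch of the WebGPU kernel.
  const auto* do_rotary = graph_utils::GetNodeAttribute(gqa, "do_rotary");
  if (do_rotary == nullptr || do_rotary->i() != 1) return false;

  // Non-packed QKV: key (slot 1) and value (slot 2) must both be wired.
  const auto& inputs = gqa.InputDefs();
  if (inputs.size() < 3 || !inputs[1]->Exists() || !inputs[2]->Exists()) return false;

  // Skip nodes that already carry norm weights (inputs 14/15).
  if (inputs.size() > 14 && inputs[14]->Exists()) return false;
  if (inputs.size() > 15 && inputs[15]->Exists()) return false;

  return true;
}
```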

Defense-in-depth on other EPs

The CPU, CUDA, and JSEP GroupQueryAttention kernels reject hand-authored
models that wire inputs 14/15 with EP-specific error messages, since none of
those runtimes implement the prologue:

  • CPU: GroupQueryAttention (CPU): q_norm_weight / k_norm_weight inputs are not supported...
  • CUDA: GroupQueryAttention (CUDA): q_norm_weight / k_norm_weight inputs are not supported...
  • JSEP: GroupQueryAttention (JSEP): q_norm_weight / k_norm_weight inputs are not supported...
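
On the CPU side, such a guard can be as simple as the following sketch (the structure is an assumption; the message text follows the list above, with its elision kept as-is):

```cpp
// Illustrative defense-in-depth guard; the actual kernel code may differ.
const Tensor* q_norm_weight = context->Input<Tensor>(14);  // nullptr if absent
const Tensor* k_norm_weight = context->Input<Tensor>(15);
if (q_norm_weight != nullptr || k_norm_weight != nullptr) {
  return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED,
                         "GroupQueryAttention (CPU): q_norm_weight / k_norm_weight "
                         "inputs are not supported...");
}
```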

Motivation and Context

Roughly 6-8% decode TPS improvement on the WebGPU + D3D backend on Windows, measured with Qwen3-1.7B on an RTX 5060Ti.

Adds the GroupQueryAttentionPreNormFusion optimizer, which folds the pre-rotary Reshape -> SimplifiedLayerNormalization -> Reshape pattern on Q and K into kMSDomain GroupQueryAttention via the new optional q_norm_weight/k_norm_weight inputs and an epsilon attribute. The WebGPU GQA decode fast path applies the RMSNorm inline in FusedQKRotaryEmbedding. Includes an optimizer unit test.
Move the LayerNormProgram configuration/dispatch logic out of LayerNorm<>::ComputeInternal into a free RunLayerNormProgram helper exposed by layer_norm.h. Callers other than the LayerNorm kernel (e.g. the GroupQueryAttention prefill fallback, which runs SimplifiedLayerNorm on Q and K when fused per-head RMSNorm is requested) then do not need to duplicate the program-setup code.

No behavior change. Emitted WGSL and dispatch are byte-identical.
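
A hypothetical shape of that helper (the signature is inferred from this description, not copied from layer_norm.h):

```cpp
// Sketch only: a free function owning the LayerNormProgram setup/dispatch so
// callers other than the LayerNorm kernel (e.g. the GQA prefill fallback) can
// reuse it without duplicating the configuration code.
Status RunLayerNormProgram(ComputeContext& context,
                           const Tensor* input,
                           const Tensor* scale,
                           const Tensor* bias,  // nullptr for SimplifiedLayerNorm
                           float epsilon,
                           bool simplified,
                           Tensor* output);
```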
Scope the fusion to the WebGPU EP only (drop JSEP from the compatible-EP set) and add explicit gates so the rewrite cannot land on configurations the WebGPU kernel does not implement:

* do_rotary must be 1 -- the fused decode prologue is folded into the rotary kernel and the prefill fallback also runs only on the do_rotary=1 path.
* Slots 1 (key) and 2 (value) must be wired -- excludes the packed-QKV form, which the WebGPU fused prologue does not support.
* Hardened the already-fused early-skip to cover both inputs 14 and 15.

Add unit tests covering each new gate (do_rotary=0 skip, packed-QKV skip, already-fused skip, JSEP EP skip) and update the existing fixture to register the optimizer with WebGPU only and to set do_rotary=1 on the GQA node.
Add defense-in-depth rejections in the CPU and CUDA contrib GroupQueryAttention kernels and in the JSEP TS dispatch so a hand-authored model that wires inputs 14/15 on those EPs fails fast instead of silently dropping the per-head Q/K RMS normalization.

Each EP uses an EP-specific error message (CPU / CUDA / JSEP) so the diagnostic identifies which kernel rejected the node. The WebGPU EP is the only runtime that implements the prologue; the GroupQueryAttentionPreNormFusion optimizer pass is scoped accordingly.
CheckUnfusedGraph asserts that slot 14 on the GQA node is empty, but the SkipsAlreadyFusedNode test deliberately pre-wires slot 14 to exercise the already-fused early-skip path. Replace the shared checker with a local one that only verifies the surrounding SimplifiedLayerNormalization and Reshape ops were not removed.
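
A sketch of what that local checker could look like (CountOpsInGraph is the usual ORT test helper; the exact assertions here are illustrative, not taken from the test file):

```cpp
// Hypothetical local checker for SkipsAlreadyFusedNode: only asserts that the
// pre-norm pattern survived, without inspecting GQA input slot 14.
auto check_not_fused = [](const Graph& graph) {
  const auto op_counts = CountOpsInGraph(graph);
  EXPECT_GT(op_counts.count("SimplifiedLayerNormalization"), size_t{0});
  EXPECT_GT(op_counts.count("Reshape"), size_t{0});
};
```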
Copilot AI (Contributor) left a comment

Pull request overview

This PR extends the MS-domain GroupQueryAttention contrib op to optionally fuse a Qwen3-style per-head Q/K RMSNorm prologue (via new optional inputs q_norm_weight/k_norm_weight and attribute qk_norm_epsilon), and adds a Level 2 graph transformer (GroupQueryAttentionPreNormFusion) to fold the exported Reshape -> SimplifiedLayerNormalization -> Reshape pattern into a single GroupQueryAttention node for the WebGPU EP.

Changes:

  • Add GroupQueryAttentionPreNormFusion optimizer pass (WebGPU-only) plus a dedicated unit test.
  • Extend GroupQueryAttention schema with optional norm inputs and epsilon attribute; add EP-side defensive rejections where unsupported.
  • Implement WebGPU runtime support: decode path fuses norm into FusedQKRotaryEmbedding, and prefill path falls back to standalone simplified layer norm dispatches using shared RunLayerNormProgram.

Reviewed changes

Copilot reviewed 14 out of 15 changed files in this pull request and generated 8 comments.

Summary per file:

  • onnxruntime/test/optimizer/group_query_attention_pre_norm_fusion_test.cc: new unit tests covering positive fusion and key rejection/gating scenarios
  • onnxruntime/core/providers/webgpu/nn/layer_norm.h: exposes the RunLayerNormProgram() helper for shared LayerNorm dispatch setup
  • onnxruntime/core/providers/webgpu/nn/layer_norm.cc: refactors LayerNorm compute to reuse RunLayerNormProgram()
  • onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.h: namespace closing indentation fix
  • onnxruntime/core/optimizer/group_query_attention_pre_norm_fusion.h: declares the new Level 2 graph transformer for Q/K pre-norm fusion
  • onnxruntime/core/optimizer/group_query_attention_pre_norm_fusion.cc: implements the fusion matcher and graph rewrite
  • onnxruntime/core/optimizer/graph_transformer_utils.cc: registers the transformer (WebGPU-only compatibility set)
  • onnxruntime/core/graph/contrib_ops/bert_defs.cc: extends the GroupQueryAttention schema with the new inputs/attribute
  • onnxruntime/contrib_ops/webgpu/bert/rotary_embedding.h: adds has_qk_norm and new uniforms for the fused norm+RoPE shader
  • onnxruntime/contrib_ops/webgpu/bert/rotary_embedding.cc: implements the fused RMSNorm+RoPE shader logic and updated IO/uniform wiring
  • onnxruntime/contrib_ops/webgpu/bert/group_query_attention.h: adds kernel attribute storage for qk_norm_epsilon
  • onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc: implements the decode fused path and prefill fallback using the LayerNorm helper
  • onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc: rejects models that provide the unsupported norm inputs
  • onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc: rejects models that provide the unsupported norm inputs
  • js/web/lib/wasm/jsep/webgpu/ops/group-query-attention.ts: rejects models that provide the unsupported norm inputs in the JSEP implementation
Comments suppressed due to low confidence (1)

onnxruntime/core/optimizer/group_query_attention_pre_norm_fusion.cc:50

  • The pattern comment has the Reshape order reversed. The code matches: projection -> inner Reshape -> SLN -> outer Reshape -> consumer. Please fix the comment so it matches the actual matcher, otherwise future modifications are likely to break the intended semantics.
// Walks back from `consumer` via input slot `consumer_input_index` and matches:
//     producer_proj -> Reshape(reshape_outer) -> SimplifiedLayerNormalization(sln) -> Reshape(reshape_inner) -> consumer
// On success returns true and fills the out-pointers. Each intermediate node must have a single
// consumer (the next op in the chain) and must not be a graph output.


@hariharans29 hariharans29 added the ep:WebGPU ort-web webgpu provider label May 13, 2026
@tianleiwu tianleiwu requested a review from Copilot May 13, 2026 17:01
Copilot AI (Contributor) left a comment

Pull request overview

Copilot reviewed 14 out of 15 changed files in this pull request and generated 3 comments.

Comment on lines +221 to +223
const Tensor* q_norm_weight = context.InputCount() > 14 ? context.Input<Tensor>(14) : nullptr;
const Tensor* k_norm_weight = context.InputCount() > 15 ? context.Input<Tensor>(15) : nullptr;
const bool has_qk_norm = (q_norm_weight != nullptr) && (k_norm_weight != nullptr);
Comment on lines +127 to +135
// Norm weight must be an initializer of shape [head_size].
NodeArg* norm_weight_arg = sln->MutableInputDefs()[1];
const ONNX_NAMESPACE::TensorProto* norm_weight_tensor =
    graph_utils::GetConstantInitializer(graph, norm_weight_arg->Name());
if (norm_weight_tensor == nullptr) {
  return false;
}
if (norm_weight_tensor->dims_size() != 1 || norm_weight_tensor->dims(0) != expected_head_size) {
  return false;
}
Comment on lines +335 to +338
if (
(context.inputs.length > 14 && context.inputs[14] && context.inputs[14].dims.length > 0) ||
(context.inputs.length > 15 && context.inputs[15] && context.inputs[15].dims.length > 0)
) {
