[WebGPU] Fuse Q/K RMSNorm into GroupQueryAttention for Qwen3-style models #28484
Open
hariharans29 wants to merge 15 commits into
Conversation
Adds GroupQueryAttentionPreNormFusion optimizer that folds the pre-rotary Reshape -> SimplifiedLayerNormalization -> Reshape pattern on Q and K into kMSDomain GroupQueryAttention via new optional q_norm_weight/k_norm_weight inputs and an epsilon attribute. The WebGPU GQA decode fast path applies the RMSNorm inline in FusedQKRotaryEmbedding. Includes optimizer unit test.
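For reference, the per-head normalization being folded in is a standard RMSNorm with a learned scale (what `SimplifiedLayerNormalization` computes). A minimal CPU-side sketch of the math, assuming a single head of `head_size` elements (the real work happens inside the fused WGSL shader):

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// Sketch only: normalize one Q or K head in place -- RMSNorm with a per-element scale.
// `norm_weight` corresponds to q_norm_weight / k_norm_weight (shape [head_size]),
// `epsilon` to the qk_norm_epsilon attribute.
void RmsNormHead(std::vector<float>& head, const std::vector<float>& norm_weight, float epsilon) {
  float sum_sq = 0.0f;
  for (float v : head) sum_sq += v * v;
  const float inv_rms = 1.0f / std::sqrt(sum_sq / static_cast<float>(head.size()) + epsilon);
  for (std::size_t i = 0; i < head.size(); ++i) {
    head[i] = head[i] * inv_rms * norm_weight[i];
  }
}
```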
Move the LayerNormProgram configuration/dispatch logic out of LayerNorm<>::ComputeInternal into a free RunLayerNormProgram helper exposed by layer_norm.h, so callers other than the LayerNorm kernel (e.g. the GroupQueryAttention prefill fallback that runs SimplifiedLayerNorm on Q and K when fused per-head RMSNorm is requested) do not need to duplicate the program-setup code. No behavior change; the emitted WGSL and dispatch are byte-identical.
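As a rough sketch (the actual parameter list in layer_norm.h may differ), the exposed helper could look something like:

```cpp
// Hypothetical declaration -- names and parameters are illustrative, not the real API.
// The idea is that any WebGPU kernel holding a ComputeContext can dispatch a
// (Simplified)LayerNorm without duplicating the LayerNormProgram setup.
Status RunLayerNormProgram(onnxruntime::webgpu::ComputeContext& context,
                           const Tensor* input,
                           const Tensor* scale,
                           const Tensor* bias,  // nullptr for SimplifiedLayerNorm
                           Tensor* output,
                           float epsilon,
                           int64_t axis,
                           bool simplified);
```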
Scope the fusion to the WebGPU EP only (drop JSEP from the compatible-EP set) and add explicit gates so the rewrite cannot land on configurations the WebGPU kernel does not implement:

- do_rotary must be 1 -- the fused decode prologue is folded into the rotary kernel, and the prefill fallback also runs only on the do_rotary=1 path.
- Slots 1 (key) and 2 (value) must be wired -- excludes the packed-QKV form, which the WebGPU fused prologue does not support.
- Hardened the already-fused early-skip to cover both inputs 14 and 15.

Add unit tests covering each new gate (do_rotary=0 skip, packed-QKV skip, already-fused skip, JSEP EP skip) and update the existing fixture to register the optimizer with WebGPU only and to set do_rotary=1 on the GQA node.
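A hedged sketch of what these gates could look like inside the pass (node and variable names are illustrative; the real checks live in group_query_attention_pre_norm_fusion.cc):

```cpp
// Illustrative gate checks only -- the actual pass may structure these differently.
const auto& attrs = gqa_node.GetAttributes();
const auto do_rotary_it = attrs.find("do_rotary");
if (do_rotary_it == attrs.end() || do_rotary_it->second.i() != 1) {
  continue;  // fused prologue exists only on the do_rotary=1 path
}

const auto input_defs = gqa_node.InputDefs();
// Slots 1 (key) and 2 (value) must be wired; excludes the packed-QKV form.
if (input_defs.size() <= 2 || !input_defs[1]->Exists() || !input_defs[2]->Exists()) {
  continue;
}

// Already-fused early-skip: bail if either norm-weight slot (14 or 15) is populated.
if ((input_defs.size() > 14 && input_defs[14]->Exists()) ||
    (input_defs.size() > 15 && input_defs[15]->Exists())) {
  continue;
}
```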
Add defense-in-depth rejections in the CPU and CUDA contrib GroupQueryAttention kernels and in the JSEP TS dispatch so a hand-authored model that wires inputs 14/15 on those EPs fails fast instead of silently dropping the per-head Q/K RMS normalization. Each EP uses an EP-specific error message (CPU / CUDA / JSEP) so the diagnostic identifies which kernel rejected the node. The WebGPU EP is the only runtime that implements the prologue; the GroupQueryAttentionPreNormFusion optimizer pass is scoped accordingly.
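A minimal sketch of such a rejection, assuming the CPU kernel's OpKernelContext (the exact slot handling and message text are EP-specific):

```cpp
// Illustrative defense-in-depth guard -- not the exact code in the PR.
const Tensor* q_norm_weight = context->Input<Tensor>(14);
const Tensor* k_norm_weight = context->Input<Tensor>(15);
if (q_norm_weight != nullptr || k_norm_weight != nullptr) {
  return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                         "GroupQueryAttention (CPU): q_norm_weight / k_norm_weight inputs are "
                         "not supported on this execution provider.");
}
```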
CheckUnfusedGraph asserts that slot 14 on the GQA node is empty, but the SkipsAlreadyFusedNode test deliberately pre-wires slot 14 to exercise the already-fused early-skip path. Replace the shared checker with a local one that only verifies the surrounding SimplifiedLayerNormalization and Reshape ops were not removed.
Contributor
Pull request overview
This PR extends the MS-domain GroupQueryAttention contrib op to optionally fuse a Qwen3-style per-head Q/K RMSNorm prologue (via new optional inputs q_norm_weight/k_norm_weight and attribute qk_norm_epsilon), and adds a Level 2 graph transformer (GroupQueryAttentionPreNormFusion) to fold the exported Reshape -> SimplifiedLayerNormalization -> Reshape pattern into a single GroupQueryAttention node for the WebGPU EP.
Changes:
- Add `GroupQueryAttentionPreNormFusion` optimizer pass (WebGPU-only) plus a dedicated unit test.
- Extend the `GroupQueryAttention` schema with optional norm inputs and an epsilon attribute; add EP-side defensive rejections where unsupported.
- Implement WebGPU runtime support: the decode path fuses the norm into `FusedQKRotaryEmbedding`, and the prefill path falls back to standalone simplified layer norm dispatches using the shared `RunLayerNormProgram`.
Reviewed changes
Copilot reviewed 14 out of 15 changed files in this pull request and generated 8 comments.
| File | Description |
|---|---|
| onnxruntime/test/optimizer/group_query_attention_pre_norm_fusion_test.cc | New unit tests covering positive fusion and key rejection/gating scenarios |
| onnxruntime/core/providers/webgpu/nn/layer_norm.h | Exposes RunLayerNormProgram() helper for shared LayerNorm dispatch setup |
| onnxruntime/core/providers/webgpu/nn/layer_norm.cc | Refactors LayerNorm compute to reuse RunLayerNormProgram() |
| onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.h | Namespace closing indentation fix |
| onnxruntime/core/optimizer/group_query_attention_pre_norm_fusion.h | Declares new Level 2 graph transformer for Q/K pre-norm fusion |
| onnxruntime/core/optimizer/group_query_attention_pre_norm_fusion.cc | Implements the fusion matcher and graph rewrite |
| onnxruntime/core/optimizer/graph_transformer_utils.cc | Registers the transformer (WebGPU-only compatibility set) |
| onnxruntime/core/graph/contrib_ops/bert_defs.cc | Extends GroupQueryAttention schema with new inputs/attribute |
| onnxruntime/contrib_ops/webgpu/bert/rotary_embedding.h | Adds has_qk_norm and new uniforms for fused norm+RoPE shader |
| onnxruntime/contrib_ops/webgpu/bert/rotary_embedding.cc | Implements fused RMSNorm+RoPE shader logic and updated IO/uniform wiring |
| onnxruntime/contrib_ops/webgpu/bert/group_query_attention.h | Adds kernel attribute storage for qk_norm_epsilon |
| onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc | Implements decode fused path and prefill fallback using LayerNorm helper |
| onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc | Rejects models that provide unsupported norm inputs |
| onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc | Rejects models that provide unsupported norm inputs |
| js/web/lib/wasm/jsep/webgpu/ops/group-query-attention.ts | Rejects models that provide unsupported norm inputs in JSEP implementation |
Comments suppressed due to low confidence (1)
onnxruntime/core/optimizer/group_query_attention_pre_norm_fusion.cc:50
- The pattern comment has the Reshape order reversed. The code matches: projection -> inner Reshape -> SLN -> outer Reshape -> consumer. Please fix the comment so it matches the actual matcher, otherwise future modifications are likely to break the intended semantics.
    // Walks back from `consumer` via input slot `consumer_input_index` and matches:
    // producer_proj -> Reshape(reshape_outer) -> SimplifiedLayerNormalization(sln) -> Reshape(reshape_inner) -> consumer
    // On success returns true and fills the out-pointers. Each intermediate node must have a single
    // consumer (the next op in the chain) and must not be a graph output.
…ect WebGPU-only scope
…test MakeInput<int>(shape, min, max) calls Uniform(min, max - 1) which asserts when min == max. Switch the seqlens_k and total_seq_len inputs to the explicit-data overload so the (0,0) and (1,1) ranges are no longer treated as integer distributions.
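A hedged sketch of that switch, assuming the ModelTestBuilder helpers used by the optimizer tests (shapes and values are illustrative):

```cpp
// Before: the range overload asserts when min == max, e.g. builder.MakeInput<int>({1}, 0, 0).
// After: the explicit-data overload, so no integer distribution is constructed.
auto* seqlens_k = builder.MakeInput<int>({1}, std::vector<int>{0});
auto* total_seq_len = builder.MakeInput<int>({1}, std::vector<int>{1});
```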
Comment on lines +221 to +223
    const Tensor* q_norm_weight = context.InputCount() > 14 ? context.Input<Tensor>(14) : nullptr;
    const Tensor* k_norm_weight = context.InputCount() > 15 ? context.Input<Tensor>(15) : nullptr;
    const bool has_qk_norm = (q_norm_weight != nullptr) && (k_norm_weight != nullptr);
Comment on lines +127 to +135
    // Norm weight must be an initializer of shape [head_size].
    NodeArg* norm_weight_arg = sln->MutableInputDefs()[1];
    const ONNX_NAMESPACE::TensorProto* norm_weight_tensor =
        graph_utils::GetConstantInitializer(graph, norm_weight_arg->Name());
    if (norm_weight_tensor == nullptr) {
      return false;
    }
    if (norm_weight_tensor->dims_size() != 1 || norm_weight_tensor->dims(0) != expected_head_size) {
      return false;
Comment on lines +335 to +338
    if (
      (context.inputs.length > 14 && context.inputs[14] && context.inputs[14].dims.length > 0) ||
      (context.inputs.length > 15 && context.inputs[15] && context.inputs[15].dims.length > 0)
    ) {
Description
Adds a per-head Q/K RMS normalization prologue to the WebGPU `GroupQueryAttention` contrib op so that Qwen3-style models can fold the standalone `SimplifiedLayerNormalization` dispatches on Q and K into the attention kernel.

The MS-domain `GroupQueryAttention` schema gains two optional inputs and one attribute:

- `q_norm_weight` (input 14, optional, shape `(head_size,)`)
- `k_norm_weight` (input 15, optional, shape `(head_size,)`)
- `qk_norm_epsilon` (FLOAT attribute, default `1e-6`)

A new Level 2 optimizer pass `GroupQueryAttentionPreNormFusion` rewrites the `Reshape -> SimplifiedLayerNormalization -> Reshape -> GroupQueryAttention` pattern produced by the Qwen3 ONNX export into a single `GroupQueryAttention` node that carries the norm weights directly. The pass is scoped to the WebGPU EP only.
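A hedged sketch of how the schema additions in bert_defs.cc might read (ONNX `OpSchema` API; description strings and type constraints are illustrative):

```cpp
// Illustrative fragment inside the existing GroupQueryAttention schema definition.
schema.Attr("qk_norm_epsilon",
            "Epsilon for the optional per-head Q/K RMS normalization.",
            AttributeProto::FLOAT, 1e-6f);
schema.Input(14, "q_norm_weight",
             "Optional per-head RMSNorm scale applied to Q before RoPE, shape (head_size,).",
             "T", OpSchema::Optional);
schema.Input(15, "k_norm_weight",
             "Optional per-head RMSNorm scale applied to K before RoPE, shape (head_size,).",
             "T", OpSchema::Optional);
```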
WebGPU runtime paths
- Decode: the per-head RMSNorm is applied inline inside the fused `FusedQKRotaryEmbedding` prologue.
- Prefill: falls back to standalone `SimplifiedLayerNorm` dispatches into scratch `qNorm`/`kNorm` tensors, then runs the unfused `FusedQKRotaryEmbedding`. Matches the pre-fusion graph timing exactly so prefill cannot regress.
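A hedged sketch of the prefill fallback flow, reusing the hypothetical `RunLayerNormProgram` signature sketched earlier (tensor and helper names are illustrative):

```cpp
// Illustrative only: normalize Q and K into scratch tensors, then run the
// unfused rotary embedding exactly as the pre-fusion graph would have.
Tensor q_norm = context.CreateGPUTensor(query->DataType(), query->Shape());
Tensor k_norm = context.CreateGPUTensor(key->DataType(), key->Shape());
ORT_RETURN_IF_ERROR(RunLayerNormProgram(context, query, q_norm_weight, /*bias*/ nullptr,
                                        &q_norm, qk_norm_epsilon, /*axis*/ -1, /*simplified*/ true));
ORT_RETURN_IF_ERROR(RunLayerNormProgram(context, key, k_norm_weight, /*bias*/ nullptr,
                                        &k_norm, qk_norm_epsilon, /*axis*/ -1, /*simplified*/ true));
// ... then dispatch the existing (unfused) FusedQKRotaryEmbedding on q_norm / k_norm ...
```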
Gating

The optimizer only rewrites nodes that the WebGPU kernel can handle:

- `do_rotary == 1` (the fused decode path runs inside the rotary kernel; the prefill fallback also runs only on the `do_rotary=1` branch)
- The normalization must be over the last axis (`axis=-1`), and the surrounding reshapes must collapse to `(... head_size)` / expand back to `(... hidden_size)`.

Defense-in-depth on other EPs
The CPU, CUDA, and JSEP `GroupQueryAttention` kernels reject hand-authored models that wire inputs 14/15 with EP-specific error messages, since none of those runtimes implement the prologue:

- `GroupQueryAttention (CPU): q_norm_weight / k_norm_weight inputs are not supported...`
- `GroupQueryAttention (CUDA): q_norm_weight / k_norm_weight inputs are not supported...`
- `GroupQueryAttention (JSEP): q_norm_weight / k_norm_weight inputs are not supported...`

Motivation and Context
Roughly 6-8% decode TPS improvement on the WebGPU + D3D backend on Windows, measured with Qwen3-1.7B on an RTX 5060 Ti.