Relax GQA seqlens_k shape validation for backward compat with older models #28259
| const seqlLenSize = seqlLens.dims.reduce((a, b) => a * b, 1); | ||
| if (seqlLenSize !== batchSize) { | ||
| throw new Error( | ||
| `Input "seqlens" must have batch_size (${batchSize}) elements, got ${seqlLenSize}`, |
should we use the same error message here and in onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h?
JS now uses same messages as C++
| const seqlLens = inputs.length > 4 ? inputs[5] : undefined; | ||
| if (seqlLens && seqlLens.dims.length !== 1 && seqlLens.dims[0] !== batchSize) { | ||
| throw new Error('Input "seqlens" is expected to have 1 dimension and the same dim 0 as batch_size'); | ||
| // Accept any shape whose total element count equals batchSize (e.g. [B], [B,1], [1,1]). |
nit: [1, 1] as an example shape will only work if batch size is 1. since we already have example shape [B, 1], maybe we can omit it.
Removed from both C++ and JS comments
| tester.AddOutput<float>("present_value", {1, kv_num_heads, 1, head_size}, | ||
| std::vector<float>(kv_num_heads * head_size, 0.0f)); | ||
| tester.SetOutputTolerance(1e6f); |
this is a large output tolerance. is it feasible to test with actual expected values? as the existing tests also do this, perhaps it can be done in another PR. at least, I think it would be worth a comment.
| // Backward compat: seqlens_k with shape {1, 1} (2D) must be accepted when batch_size=1, | ||
| // since older model builders emit this shape. Total element count (1) matches batch_size. | ||
| TEST(GroupQueryAttentionTest, SeqlensKLegacy2DShape) { | ||
| constexpr int num_heads = 1; |
consider reusing/creating another test helper like RunGQASeqlensKTest() to reduce code duplication.
Extended RunGQASeqlensKTest with seqlens_k_shape param.
…odels
PR #28031 tightened seqlens_k shape validation (&& -> ||), correctly rejecting non-1D tensors per spec. However, older model builders emit seqlens_k with shape [1,1] instead of [1], breaking HuggingFace LLMs (qwen3-0.6b, qwen3-1.7b).
Relax shape check to allow unit dimensions around the batch axis: each dim must be 1 or batch_size (accepts [B], [B,1], [1,1] but rejects [2,2] for B=4).
Also fixes the same latent && bug in JS/WebGPU EP. Value bounds checks in Compute() are unchanged.
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
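For readers following along, here is a minimal TypeScript sketch of the rule this commit message describes (element count equals batch_size, and every dim is 1 or batch_size). `checkSeqlensShape` is a hypothetical standalone helper written for illustration, not the function in group-query-attention.ts. It reflects the rule as stated at this point in the review, before the later scalar rejection was added.

```typescript
// Hypothetical helper for illustration only; not the onnxruntime source.
// Returns an error message when the shape is rejected, or undefined when accepted.
function checkSeqlensShape(dims: readonly number[], batchSize: number): string | undefined {
  // Rule 1: the total element count must equal batch_size.
  const size = dims.reduce((a, b) => a * b, 1);
  if (size !== batchSize) {
    return `seqlens_k must have batch_size (${batchSize}) elements, got ${size}.`;
  }
  // Rule 2: every dimension must be either 1 or batch_size.
  for (let i = 0; i < dims.length; i++) {
    if (dims[i] !== 1 && dims[i] !== batchSize) {
      return `seqlens_k has unexpected shape. Each dimension must be 1 or batch_size (${batchSize}), got dims[${i}] = ${dims[i]}.`;
    }
  }
  return undefined;
}
```

Under this sketch, `[4]`, `[4, 1]`, and `[1, 4]` are accepted for a batch size of 4, while `[2, 2]` passes the element-count check but is rejected by the per-dimension check.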
ba7d3a2 to c0b4397
Sorry about the force-push — Copilot CLI rewrote the branch and lost the incremental diff history. Addressed all 5 comments:
Add JS/WebGPU test for [1,1] seqlens_k shape (the exact qwen3 regression
case) and C++ test for trailing batch dim shape {1,B}.
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Pull request overview
Relaxes seqlens_k shape validation for GroupQueryAttention to restore backward compatibility with older model exporters that emit extra unit dimensions (e.g., [B,1]), while keeping the value-range checks that prevent OOB access.
Changes:
- Update C++ `CheckInputs()` validation to accept `seqlens_k` shapes with `batch_size` total elements (with additional per-dimension constraints).
- Apply equivalent validation updates in the JS/WebGPU `validateInputs()` path.
- Extend CPU and JS test coverage with legacy-shape acceptance and wrong-shape rejection cases.
Reviewed changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h | Updates seqlens_k shape validation and error messages in shared helper. |
| js/web/lib/wasm/jsep/webgpu/ops/group-query-attention.ts | Aligns WebGPU input validation with the relaxed seqlens_k shape rules. |
| onnxruntime/test/contrib_ops/group_query_attention_op_test.cc | Adds regression tests for legacy 2D shapes and invalid element-count/shape cases. |
| js/web/test/data/ops/group-query-attention.jsonc | Adds a Web test case covering legacy [1,1] seqlens_k shape acceptance. |
| // Spec requires 1D shape (batch_size), but older model builders may add unit | ||
| // dimensions (e.g. [B, 1] instead of [B]). Allow shapes where each dim is 1 or batch_size. | ||
| const auto& seqlens_k_dim = seqlens_k->Shape().GetDims(); | ||
| if (seqlens_k_dim.size() != 1 || seqlens_k_dim[0] != batch_size) { | ||
| if (seqlens_k->Shape().Size() != static_cast<int64_t>(batch_size)) { | ||
| return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, | ||
| "seqlens_k must be shape (batch_size)."); | ||
| "seqlens_k must have batch_size (", batch_size, ") elements, got ", | ||
| seqlens_k->Shape().Size(), "."); | ||
| } | ||
| for (size_t i = 0; i < seqlens_k_dim.size(); ++i) { | ||
| if (seqlens_k_dim[i] != 1 && seqlens_k_dim[i] != static_cast<int64_t>(batch_size)) { | ||
| return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, | ||
| "seqlens_k has unexpected shape. Each dimension must be 1 or batch_size (", | ||
| batch_size, "), got dim[", i, "] = ", seqlens_k_dim[i], "."); | ||
| } |
The PR description says the seqlens_k shape validation is relaxed to only check total element count (Size() == batch_size). This implementation still adds an additional per-dimension constraint (each dim must be 1 or batch_size), which will reject shapes like [2,2] for batch_size=4 even though the element count matches. Please either (a) loosen the validation to match the stated behavior, or (b) update the PR description to reflect the intended restriction and why it’s needed.
| } | ||
| // Spec requires 1D shape (batch_size), but older model builders may add unit | ||
| // dimensions (e.g. [B, 1] instead of [B]). Allow shapes where each dim is 1 or batch_size. |
This new check allows rank-0 (scalar) seqlens_k when batch_size==1 because TensorShape::Size() returns 1 for empty shapes and the dim loop is skipped. If the goal is only to accept 1D with optional unit dimensions (e.g., [B,1], [1,B]), consider explicitly rejecting scalar shapes (NumDimensions()==0) to avoid widening accepted inputs beyond what’s described by the spec/comment.
| // dimensions (e.g. [B, 1] instead of [B]). Allow shapes where each dim is 1 or batch_size. | |
| if (seqlens_k->Shape().NumDimensions() == 0) { | |
| return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, | |
| "seqlens_k must not be a scalar; expected a 1D tensor of shape (batch_size) " | |
| "or a tensor with only unit dimensions and one batch_size dimension."); | |
| } |
perhaps we could also verify that NumDimensions() > 0. sounds like we want at least 1D shape.
I highly doubt it will break any models but just want to be on side of caution if we should add this check @guschmue
| // Spec requires 1D shape (batch_size), but older model builders may add unit | ||
| // dimensions (e.g. [B, 1] instead of [B]). Allow shapes where each dim is 1 or batchSize. | ||
| const seqlLens = inputs.length > 4 ? inputs[5] : undefined; | ||
| if (seqlLens && seqlLens.dims.length !== 1 && seqlLens.dims[0] !== batchSize) { | ||
| throw new Error('Input "seqlens" is expected to have 1 dimension and the same dim 0 as batch_size'); | ||
| if (seqlLens) { | ||
| const seqlLenSize = seqlLens.dims.reduce((a, b) => a * b, 1); | ||
| if (seqlLenSize !== batchSize) { | ||
| throw new Error( | ||
| `seqlens_k must have batch_size (${batchSize}) elements, got ${seqlLenSize}.`, | ||
| ); | ||
| } | ||
| for (let i = 0; i < seqlLens.dims.length; i++) { | ||
| if (seqlLens.dims[i] !== 1 && seqlLens.dims[i] !== batchSize) { | ||
| throw new Error( | ||
| `seqlens_k has unexpected shape. Each dimension must be 1 or batch_size (${batchSize}), got dims[${i}] = ${seqlLens.dims[i]}.`, | ||
| ); | ||
| } | ||
| } | ||
| } |
As written, a scalar seqlens_k (dims=[]) will be accepted when batchSize===1 because reduce() returns 1 and the per-dim loop is skipped. If you only intend to allow older builders that add unit dimensions (e.g., [B,1]/[1,B]) and still require at least 1 dimension, add an explicit dims.length>=1 check to match the spec and the comment above this block.
| } | ||
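The scalar case raised in the comment above can be reproduced in isolation. The sketch below assumes the same two checks as the diff (element count, then per-dimension); the function names are hypothetical and chosen only for this illustration.

```typescript
// Hypothetical standalone helpers for illustration; not the onnxruntime source.
function elementCount(dims: readonly number[]): number {
  // With an initial value of 1, reduce() over an empty array returns 1.
  return dims.reduce((a, b) => a * b, 1);
}

// The two checks from the diff, without any rank check.
function passesWithoutRankCheck(dims: readonly number[], batchSize: number): boolean {
  // every() is vacuously true for an empty array, mirroring the skipped loop.
  return elementCount(dims) === batchSize && dims.every((d) => d === 1 || d === batchSize);
}

// The same checks with an explicit "at least 1 dimension" guard added first.
function passesWithRankCheck(dims: readonly number[], batchSize: number): boolean {
  return dims.length >= 1 && passesWithoutRankCheck(dims, batchSize);
}
```

With `dims = []` and a batch size of 1, `passesWithoutRankCheck` accepts the input and `passesWithRankCheck` rejects it; legacy shapes such as `[1, 1]` are accepted by both.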
Address review comments:
- Reject rank-0 (scalar) seqlens_k in both C++ and JS validation
- Use std::optional<vector> for test helper seqlens_k_shape param
- Add SeqlensKScalarRejected test case
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Addressed remaining comments:
Co-authored-by: Edward Chen <18449977+edgchen1@users.noreply.github.com>
Validated with https://huggingface.co/schmuell/Qwen3-1.7B
@vraspar your PR #28031 broke the functionality & I have tested with open source models too. FYI intel#1067
Problem
PR #28031 fixed a security bug (an out-of-bounds GEMM access triggered by a crafted seqlens_k) by changing `&&` to `||` in the shape validation in group_query_attention_helper.h. This correctly enforces the spec (1D Tensor of shape (batch_size)) but breaks models (e.g. qwen3-0.6b, qwen3-1.7b) whose builder.py emits seqlens_k with shape [1,1] instead of [1].
Fix
Relax the shape check to accept shapes with unit dimensions around the batch axis. The validation rule is:
Also fixes the same latent &&/|| bug in the JS/WebGPU EP (group-query-attention.ts).
Security: The per-element value bounds checks in Compute() are unchanged -- the OOB fix from #28031 is fully preserved.
Changes