Add validation of position_ids in RotaryEmbedding operators#27597
Pull request overview
This PR hardens the ONNX-domain RotaryEmbedding operator against out-of-bounds reads when user-provided position_ids contain invalid indices relative to the cos/sin cache (max_sequence_length), addressing a potential correctness and security issue.
Changes:
- CPU: Validate position_ids values upfront (when explicitly provided) and return INVALID_ARGUMENT on out-of-range values.
- CUDA: Plumb max_sequence_length into the kernel and add a device-side bounds check (pass-through on OOB since kernels can’t surface errors).
- Tests: Add CPU unit tests that assert invalid position_ids are rejected with an appropriate error substring.
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| onnxruntime/core/providers/cpu/llm/rotary_embedding.cc | Adds upfront position_ids range validation to prevent OOB cache access on CPU. |
| onnxruntime/core/providers/cuda/llm/rotary_embedding_impl.cu | Passes max_sequence_length to the CUDA kernel and guards cache indexing for explicit position_ids. |
| onnxruntime/test/providers/cpu/llm/rotary_embedding_op_test.cc | Adds negative / exceeds-max / in-batch OOB test cases for CPU failure behavior. |
@tianleiwu I remember the ONNX RotaryEmbedding was copied from the contrib op. Did you also fix it on that side?
4. Consolidated Findings

4.1 Must-Fix Issues
Both CPU and CUDA cast

    int pos = static_cast<int>(position_ids[i]);  // truncation happens here
    if (pos < 0 || pos >= max_sequence_length) { ... }  // check is on truncated value

A value like

One-line fix (CPU):

    int64_t pos64 = position_ids[i];
    if (pos64 < 0 || pos64 >= static_cast<int64_t>(max_sequence_length)) {
      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                             "position_ids value ", pos64, " at index ", i,
                             " is out of range [0, ", max_sequence_length, ")");
    }

One-line fix (CUDA):

    int64_t raw_pos = position_ids[b_s_index];
    if (raw_pos < 0 || raw_pos >= static_cast<int64_t>(max_sequence_length)) {
      output_data[i] = input_data[i];
      return;
    }
    b_s_index = static_cast<int>(raw_pos);

4.2 Tracked Separately (Separate PR Recommended)
The contrib_ops implementations at
Recommendation: File a tracked security issue and address in a follow-up PR. The scope of this PR is well-defined for the ONNX domain.

4.3 Should-Fix (In This PR)

4.4 Nice-to-Have (Non-Blocking)

5. What's Correct and Well-Done
All four reviewers agreed these aspects are positive:

6. Verdict: Approve with Changes
Required for merge:
Strongly recommended for this PR:
Separate follow-up:
Nice-to-have (author's discretion):
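The truncation concern in finding 4.1 can be illustrated with a small host-side sketch. The helper names `InRangeAfterNarrowing` and `InRangeWide` are hypothetical and do not appear in the PR; the sketch assumes a 32-bit `int` and the usual two's-complement narrowing, which is what mainstream compilers implement.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper mirroring the flawed pattern: the 64-bit value is
// narrowed to int first, so the range check only sees the low 32 bits.
bool InRangeAfterNarrowing(int64_t position_id, int max_sequence_length) {
  assert(max_sequence_length > 0);
  const int pos = static_cast<int>(position_id);  // upper bits are discarded
  return pos >= 0 && pos < max_sequence_length;
}

// Hypothetical helper mirroring the suggested fix: the comparison is done
// on the full 64-bit value before any narrowing takes place.
bool InRangeWide(int64_t position_id, int max_sequence_length) {
  assert(max_sequence_length > 0);
  return position_id >= 0 &&
         position_id < static_cast<int64_t>(max_sequence_length);
}
```

Under those assumptions, a value such as `(int64_t{1} << 32) + 5` checked against a cache of 16 rows passes the narrowed check, because its low 32 bits equal 5, while the 64-bit check rejects it.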
Pull request overview
Copilot reviewed 6 out of 6 changed files in this pull request and generated 2 comments.
Comments suppressed due to low confidence (1)
onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu:91
- The new bounds checks cover position_ids formats 0 and 1, but format 2 (past_sequence_length + s) can still produce negative or >= max_sequence_length position_id values, leading to out-of-bounds reads from cos_cache/sin_cache. Add the same range validation for format 2 (and handle negative past_sequence_lengths[b]), falling back to pass-through on OOB like the other formats.
position_id = static_cast<int>(pos);
} else if (position_ids_format == 2) {
// format 2: past_sequence_length + s
// used for Decoding (past_sequence_length = seqlens_k[b]) or First Prompt (past=0 if nullptr)
int past = (past_sequence_lengths == nullptr) ? 0 : past_sequence_lengths[b];
position_id = past + s;
}
Pull request overview
Copilot reviewed 6 out of 6 changed files in this pull request and generated 1 comment.
Comments suppressed due to low confidence (1)
onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu:93
- In the CUDA contrib rotary embedding kernel, the new bounds checks cover formats 0 and 1, but the format-2 path (position_id = past_sequence_lengths[b] + s) still computes cache_offset without validating that the resulting position_id is within [0, max_sequence_length). If position_ids_format=2 is ever used with a large/negative past_sequence_lengths value, this can still read out of bounds from cos_cache/sin_cache. Please add an equivalent bounds check for format 2 (e.g., validate past in range and that past + sequence_length <= max_sequence_length) and apply the same pass-through behavior on failure.
} else if (position_ids_format == 2) {
// format 2: past_sequence_length + s
// used for Decoding (past_sequence_length = seqlens_k[b]) or First Prompt (past=0 if nullptr)
int past = (past_sequence_lengths == nullptr) ? 0 : past_sequence_lengths[b];
position_id = past + s;
}
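One way the requested format-2 check could look is sketched below as plain host-side C++ rather than kernel code. The function `TryResolveFormat2Position` and its signature are hypothetical, introduced only for illustration. The sketch assumes the caller applies the same pass-through behavior as the other formats when it returns false.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper (not part of the PR): computes past + s for format 2
// and reports whether the result is a valid row of the cos/sin cache.
// On success the row is written to *position_id; on failure it is untouched.
bool TryResolveFormat2Position(int past, int s, int max_sequence_length,
                               int* position_id) {
  assert(position_id != nullptr && s >= 0 && max_sequence_length > 0);
  // Widen before adding so a very large `past` cannot overflow int.
  const int64_t pos = static_cast<int64_t>(past) + static_cast<int64_t>(s);
  if (past < 0 || pos >= static_cast<int64_t>(max_sequence_length)) {
    return false;  // caller would copy the input through unchanged
  }
  *position_id = static_cast<int>(pos);
  return true;
}
```

Doing the addition in 64 bits keeps the check itself free of signed overflow, which would otherwise be undefined behavior for a `past` close to the `int` maximum.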
Description
Fix an out-of-bounds read in the RotaryEmbedding operator when user-provided position_ids values exceed the cos/sin cache bounds (max_sequence_length).

Problem
When position_ids contains values that are negative or >= max_sequence_length, the kernel computes cache_offset = position_id * half_rotary_embedding_dim and reads out-of-bounds from cos_cache/sin_cache. This can cause undefined behavior (incorrect results, crashes, or memory corruption).

Fix
CPU (rotary_embedding.cc):
- Validates position_ids values before the parallel computation loop. Returns an INVALID_ARGUMENT error if any value is out of range [0, max_sequence_length).
- The validation runs only when position_ids_format != 0 (i.e., when position_ids are explicitly provided). When position_ids is not provided (format 0), the cache is shaped (B, S, H/2) and the index b * S + s is always in-bounds by construction.

CUDA (rotary_embedding_impl.cu):
- Passes the max_sequence_length parameter through to the kernel.
- Adds a bounds check in the position_ids_format != 0 branch. Out-of-bounds position IDs cause the kernel to pass through the input unchanged (errors cannot be propagated from GPU kernels).
- The check applies to the position_ids_format != 0 branch only. When format is 0 (no position_ids), the cache is (B*S, H/2) and b_s_index = b * S + s is deterministically valid. Applying the check unconditionally would incorrectly reject all batches beyond the first, since max_sequence_length == sequence_length in that case.

Tests
Added three CPU test cases for the ONNX domain RotaryEmbedding op:
- RotaryEmbedding_PositionIds_ExceedsMaxSeqLen: position_id far exceeding cache size
- RotaryEmbedding_PositionIds_Negative: negative position_id
- RotaryEmbedding_PositionIds_OOB_InBatch: OOB position_id in a multi-batch, multi-sequence scenario

Motivation and Context
Security hardening: prevent out-of-bounds memory access from untrusted model inputs.
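
The reason the CUDA check is limited to the position_ids_format != 0 branch can be made concrete with a small host-side model. The function `CountWronglyRejected` below is hypothetical and is not code from the PR. It assumes, as stated in the Fix section, that for format 0 the row index is b * S + s and max_sequence_length equals sequence_length.

```cpp
#include <cassert>

// Hypothetical model of format-0 indexing: the cache has B * S rows and row
// b * S + s is used directly. Counts how many (b, s) pairs would be rejected
// if the max_sequence_length check were applied unconditionally.
int CountWronglyRejected(int batch_size, int sequence_length) {
  assert(batch_size > 0 && sequence_length > 0);
  const int max_sequence_length = sequence_length;  // assumption for format 0
  int rejected = 0;
  for (int b = 0; b < batch_size; ++b) {
    for (int s = 0; s < sequence_length; ++s) {
      const int b_s_index = b * sequence_length + s;  // always < B * S
      if (b_s_index >= max_sequence_length) {
        ++rejected;  // valid row, but the unconditional check would drop it
      }
    }
  }
  return rejected;
}
```

With a single batch nothing is rejected, but with three batches of four positions every row of the second and third batch fails the check, even though each index is a valid row of the (B*S, H/2) cache.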