Fix Potential Integer Truncation Leading to Heap Out-of-Bounds Read/Write#27544
Conversation
Pull request overview
This PR refactors several CPU tensor kernels to improve type safety in ThreadPool::TryParallelFor usage by replacing truncated 32-bit loop indices with ptrdiff_t, reducing the risk of overflow-driven heap OOB when iterating very large workloads.
Changes:
- Update `GatherND` parallel loops and per-slice lambdas to use `ptrdiff_t` indices and adjust related pointer arithmetic casts.
- Update `ScatterND` dispatch lambda and parallel loop to use `ptrdiff_t` indices consistently.
- Update `GatherGrad` parallel loop to use `ptrdiff_t` indices (avoiding `int` truncation).
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 1 comment.
| File | Description |
|---|---|
| orttraining/orttraining/training_ops/cpu/tensor/gather_grad.cc | Uses ptrdiff_t for TryParallelFor loop indices in GatherGrad. |
| onnxruntime/core/providers/cpu/tensor/scatter_nd.cc | Uses ptrdiff_t for TryParallelFor loop indices and dispatch lambda in ScatterND. |
| onnxruntime/core/providers/cpu/tensor/gather_nd.cc | Uses ptrdiff_t for TryParallelFor loop indices and slice lambdas in GatherND. |
Comments suppressed due to low confidence (4)
onnxruntime/core/providers/cpu/tensor/gather_nd.cc:122
- This change is intended to prevent integer truncation in parallel iteration (e.g., when the total work item count exceeds 32-bit). There isn't a regression test exercising very large `num_slices` / output sizes for GatherND similar to `Gather_overflow_check` for Gather; adding one would help prevent reintroducing truncation issues in the future (with appropriate 32-bit skips and memory considerations).
```cpp
concurrency::ThreadPool::TryParallelFor(
    tp, onnxruntime::narrow<size_t>(num_slices), static_cast<double>(num_slice_dims),
    [&lambda](ptrdiff_t first, ptrdiff_t last) {
      for (ptrdiff_t slice_idx = first, end = last; slice_idx < end; ++slice_idx) {
        lambda(slice_idx);
      }
    });
```
onnxruntime/core/providers/cpu/tensor/scatter_nd.cc:405
- This change is intended to prevent integer truncation in the parallel loop iteration. There isn't a stress/regression test for ScatterND that forces `prepare.element_offsets.size()` to exceed 32-bit and exercises the `TryParallelFor` path; adding one (skipping 32-bit platforms and being mindful of memory) would help ensure this doesn't regress.
```cpp
concurrency::ThreadPool::TryParallelFor(
    tp, prepare.element_offsets.size(), static_cast<double>(prepare.element_to_copy),
    [&lambda](ptrdiff_t first, ptrdiff_t last) {
      for (ptrdiff_t i = first, end = last; i < end; ++i) {
        lambda(i);
      }
    });
```
orttraining/orttraining/training_ops/cpu/tensor/gather_grad.cc:98
- This change avoids truncating `first`/`last` (`ptrdiff_t`) to `int` in the parallel loop. There is currently no regression test that exercises `grad_size` values beyond 32-bit to validate this fix; consider adding a stress test (with 32-bit skips/memory constraints) to prevent future reintroductions of the truncation pattern.
```cpp
concurrency::ThreadPool::TryParallelFor(tp, grad_size, static_cast<double>(block_size),
                                        [&lambda](ptrdiff_t first, ptrdiff_t last) {
                                          for (ptrdiff_t index = first, end = last; index < end; ++index) {
                                            lambda(index);
                                          }
                                        });
```
onnxruntime/core/providers/cpu/tensor/gather_nd.cc:107
`err_index` is written from inside the `TryParallelFor` worker lambda without any synchronization. If multiple threads encounter an invalid index, this is a data race (undefined behavior) and the final error value is nondeterministic. Consider using an atomic (e.g., store the first failure), or computing validation sequentially / with thread-local errors and combining after the parallel loop.
```cpp
int64_t index = static_cast<int64_t>(slice_indices[dim_idx]);
const auto upper_limit = input_shape[SafeInt<size_t>(batch_dims_) + dim_idx];
const auto lower_limit = -upper_limit;
if (index < lower_limit || index >= upper_limit) {
  err_index = index;
```
AI Summary
This PR fixes integer truncation bugs in the parallel loops of three CPU tensor kernels (`GatherND`, `ScatterND`, and `GatherGrad`). The fix is straightforward and consistent across all three files.
Detailed Analysis
- gather_nd.cc — 3 functions fixed
- scatter_nd.cc — 1 function fixed
- gather_grad.cc — 1 loop fixed
Issues

| Priority | Issue | Action |
|---|---|---|
| P2 | No regression tests for overflow scenarios | Add tests similar to Gather_overflow_check |
| P2 | Pre-existing err_index data race (not from this PR) | Track separately |
| P3 | gather_grad.cc lambda parameter still int64_t | Change to ptrdiff_t for consistency |
| P3 | One remaining truncation site in uni_directional_lstm.cc | Follow-up PR |
| P3 | Commit message "update" not descriptive | Cosmetic; PR title is good |
Recommendation: Approve. This is a clean, low-risk security fix that addresses a real integer truncation vulnerability. The changes are mechanically correct and consistent with the established fix pattern from PR #27444. The missing regression tests (Issue 3) are the most notable gap but are not blocking given the straightforward nature of the type changes.
Description
This pull request refactors several tensor operation kernels (`GatherND`, `ScatterND`, and `GatherGrad`) to improve type safety and consistency in parallelized code execution. The main change is replacing `int` loop indices with `ptrdiff_t` to avoid overflow.

Parallelization and Type Safety Improvements
- Updated parallel loops in `gather_nd.cc` (`GatherNDBase::PrepareForCompute`, `GatherND::GatherNumber`, and `GatherND::GatherString`) to use `ptrdiff_t` instead of `int64_t`, and replaced index arithmetic with explicit casts to maintain correctness. [1] [2] [3]
- Updated `scatter_nd.cc` (`ScatterNDDispatchTarget`) to use `ptrdiff_t` for loop indices and index arithmetic in all reduction cases, ensuring consistent type usage in parallel execution.
- Updated `gather_grad.cc` (`GatherGrad::ComputeImpl`) to use `ptrdiff_t` for parallel loop indices, aligning with the changes in other tensor kernels.

Motivation and Context
A similar issue was fixed in #27444.