Fix aten_masked_scatter for broadcasting masks#2929
Merged
titaiwangms merged 1 commit intoJun 4, 2026
Conversation
The rank-only check left `mask` un-expanded when `mask` and `self` have equal rank but `mask` still needs broadcasting (e.g. mask `(1, S, 1)` vs self `(1, S, D)`): it took the `else` branch and expanded `self` (a no-op) while leaving `mask` at its original shape. `NonZero(mask)` then enumerated only a subset of the masked positions, so `ScatterND` wrote the wrong source elements (numerically wrong output). Expand both `self` and `mask` to their common shape. Add an OpInfo sample (`ops.aten.masked_scatter`) covering broadcastable masks, which the existing torch OpInfo samples did not exercise. Context / repro: pytorch/pytorch#186146 Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Contributor
Author
|
@microsoft-github-policy-service agree company="Noble Machines" |
Contributor
There was a problem hiding this comment.
Pull request overview
This PR aims to fix incorrect aten_masked_scatter results when mask is broadcastable to self (including same-rank broadcasting), and adds OpInfo-based test coverage to prevent regressions in the torchlib exporter path.
Changes:
- Update
aten_masked_scatterto broadcastself/maskbeforeNonZero(mask)so indices cover all masked elements. - Add a custom
ops.aten.masked_scatterOpInfo with broadcastable-mask sample inputs (including same-rank broadcasting). - Register the new OpInfo in the torchlib ops test matrix.
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 1 comment.
| File | Description |
|---|---|
| onnxscript/function_libs/torch_lib/ops/core.py | Adjusts aten_masked_scatter broadcasting logic prior to NonZero/ScatterND. |
| tests/function_libs/torch_lib/ops_test_data.py | Adds TorchLibOpInfo entry so the new OpInfo is exercised by the torchlib test suite. |
| tests/function_libs/torch_lib/extra_opinfo.py | Defines ops.aten.masked_scatter OpInfo and sample inputs covering broadcastable masks. |
Codecov Report✅ All modified and coverable lines are covered by tests. Additional details and impacted files@@ Coverage Diff @@
## main #2929 +/- ##
==========================================
- Coverage 72.63% 72.63% -0.01%
==========================================
Files 259 259
Lines 31666 31665 -1
Branches 2982 2981 -1
==========================================
- Hits 23000 22999 -1
Misses 7656 7656
Partials 1010 1010 ☔ View full report in Codecov by Harness. |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
aten_masked_scatterreturns wrong results whenmaskis broadcastable toselfbut has the same rank (e.g.mask (1, S, 1)vsself (1, S, D)): the rank-only check expandsselfand leavesmaskun-expanded, soNonZero(mask)enumerates only a subset of masked positions andScatterNDwrites the wrongsourceelements (e.g. output[1, 2, 3, 4]instead of[1, 9, 17, 25]).Fix — expand both to their common shape:
The original branch (#2112, the initial add for Gemma3) only handled rank differences and has been unchanged since; the Gemma3 path pre-expands the mask (
expand_as), so it was unaffected. Surfaced viatorch.onnx.export(dynamo=True)on a model that injects features withmasked_scatter— pytorch/pytorch#186146.Test: adds an
ops.aten.masked_scatterOpInfo with broadcastable-mask samples (the existing torch OpInfo didn't cover same-rank broadcasting) — fails onmain, passes here.