Fix aten_masked_scatter for broadcasting masks #2929

Merged: titaiwangms merged 1 commit into microsoft:main from gonultasbu:fix-masked-scatter-broadcast on Jun 4, 2026

Conversation

@gonultasbu
Contributor

aten_masked_scatter returns wrong results when mask is broadcastable to self but has the same rank (e.g. mask (1, S, 1) vs self (1, S, D)): the rank-only check expands self and leaves mask un-expanded, so NonZero(mask) enumerates only a subset of masked positions and ScatterND writes the wrong source elements (e.g. output [1, 2, 3, 4] instead of [1, 9, 17, 25]).
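
The failure can be reproduced without ONNX or PyTorch. What follows is a minimal sketch in plain Python that models NonZero as a row-major enumeration of true positions and ScatterND as sequential writes from the flattened source. The helper names, the sizes S=4 and D=8, the all-true mask, and the source values 1..32 are assumptions chosen so the numbers line up with the example above; none of this is code from the PR.

```python
from itertools import product

SELF_SHAPE = (1, 4, 8)  # (1, S, D) with assumed S=4, D=8
MASK_SHAPE = (1, 4, 1)  # (1, S, 1): same rank, still needs broadcasting


def indices(shape):
    """All index tuples of `shape` in row-major order."""
    return product(*(range(dim) for dim in shape))


def expand(tensor, shape, target_shape):
    """Broadcast a dict-backed tensor to `target_shape` (size-1 dims repeat)."""
    return {
        idx: tensor[tuple(i if dim > 1 else 0 for i, dim in zip(idx, shape))]
        for idx in indices(target_shape)
    }


def nonzero(mask, shape):
    """Row-major list of positions where `mask` is true (stand-in for NonZero)."""
    return [idx for idx in indices(shape) if mask[idx]]


def scatter(tensor, positions, source):
    """Write source[k] to the k-th position (stand-in for ScatterND)."""
    out = dict(tensor)
    for k, idx in enumerate(positions):
        out[idx] = source[k]
    return out


def first_column(tensor):
    """Values at [0, s, 0] for every s, i.e. the first element of each row."""
    return [tensor[(0, s, 0)] for s in range(SELF_SHAPE[1])]


self_ = {idx: 0 for idx in indices(SELF_SHAPE)}
mask = {idx: True for idx in indices(MASK_SHAPE)}
source = list(range(1, 33))  # flattened source, values 1..32

# Rank-only check: mask stays (1, S, 1), so only S positions are enumerated.
buggy = scatter(self_, nonzero(mask, MASK_SHAPE), source)

# Expected behaviour: mask is broadcast to (1, S, D) before enumeration.
full_mask = expand(mask, MASK_SHAPE, SELF_SHAPE)
expected = scatter(self_, nonzero(full_mask, SELF_SHAPE), source)

print(first_column(buggy))     # [1, 2, 3, 4]
print(first_column(expected))  # [1, 9, 17, 25]
```

With the un-expanded mask only four positions are written, so the first four source values land in column 0 of each row; with the broadcast mask all 32 positions are written and column 0 receives every eighth source value.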

Fix — expand both to their common shape:

self = op.Expand(self, op.Shape(mask))
mask = op.Expand(mask, op.Shape(self))
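
Both calls matter because ONNX Expand broadcasts in both directions: its output shape is the NumPy-style broadcast of the input shape and the requested shape. The first call therefore leaves self at the common shape, and the second brings mask up to it. A small sketch of that shape rule in plain Python, where `expand_shape` is a hypothetical helper written for illustration and not part of onnxscript:

```python
def expand_shape(input_shape, target_shape):
    """Output shape of a two-way (NumPy-style) broadcast, as ONNX Expand uses."""
    rank = max(len(input_shape), len(target_shape))
    # Right-align both shapes by padding the shorter one with leading 1s.
    a = (1,) * (rank - len(input_shape)) + tuple(input_shape)
    b = (1,) * (rank - len(target_shape)) + tuple(target_shape)
    out = []
    for x, y in zip(a, b):
        if x != y and 1 not in (x, y):
            raise ValueError(f"shapes {input_shape} and {target_shape} do not broadcast")
        out.append(y if x == 1 else x)
    return tuple(out)


# Same-rank case from the description: self (1, S, D), mask (1, S, 1).
self_shape = expand_shape((1, 4, 8), (1, 4, 1))   # self = Expand(self, Shape(mask))
mask_shape = expand_shape((1, 4, 1), self_shape)  # mask = Expand(mask, Shape(self))
print(self_shape, mask_shape)  # (1, 4, 8) (1, 4, 8)

# Lower-rank mask, the case the rank-only check was written for.
self_shape = expand_shape((2, 3, 4), (4,))
mask_shape = expand_shape((4,), self_shape)
print(self_shape, mask_shape)  # (2, 3, 4) (2, 3, 4)
```

In both cases the two tensors end at the same shape, so NonZero(mask) enumerates every position that will be written.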

The original branch (#2112, the initial add for Gemma3) only handled rank differences and has been unchanged since; the Gemma3 path pre-expands the mask (expand_as), so it was unaffected. Surfaced via torch.onnx.export(dynamo=True) on a model that injects features with masked_scatter (pytorch/pytorch#186146).

Test: adds an ops.aten.masked_scatter OpInfo with broadcastable-mask samples (the existing torch OpInfo didn't cover same-rank broadcasting) — fails on main, passes here.

The rank-only check left `mask` un-expanded when `mask` and `self` have equal
rank but `mask` still needs broadcasting (e.g. mask `(1, S, 1)` vs self
`(1, S, D)`): it took the `else` branch and expanded `self` (a no-op) while
leaving `mask` at its original shape. `NonZero(mask)` then enumerated only a
subset of the masked positions, so `ScatterND` wrote the wrong source elements
(numerically wrong output). Expand both `self` and `mask` to their common shape.

Add an OpInfo sample (`ops.aten.masked_scatter`) covering broadcastable masks,
which the existing torch OpInfo samples did not exercise.

Context / repro: pytorch/pytorch#186146

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
@gonultasbu
Contributor Author

@microsoft-github-policy-service agree company="Noble Machines"

Contributor

Copilot AI left a comment

Pull request overview

This PR aims to fix incorrect aten_masked_scatter results when mask is broadcastable to self (including same-rank broadcasting), and adds OpInfo-based test coverage to prevent regressions in the torchlib exporter path.

Changes:

  • Update aten_masked_scatter to broadcast self/mask before NonZero(mask) so indices cover all masked elements.
  • Add a custom ops.aten.masked_scatter OpInfo with broadcastable-mask sample inputs (including same-rank broadcasting).
  • Register the new OpInfo in the torchlib ops test matrix.

Reviewed changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 1 comment.

| File | Description |
| --- | --- |
| onnxscript/function_libs/torch_lib/ops/core.py | Adjusts aten_masked_scatter broadcasting logic prior to NonZero/ScatterND. |
| tests/function_libs/torch_lib/ops_test_data.py | Adds TorchLibOpInfo entry so the new OpInfo is exercised by the torchlib test suite. |
| tests/function_libs/torch_lib/extra_opinfo.py | Defines ops.aten.masked_scatter OpInfo and sample inputs covering broadcastable masks. |

Comment thread onnxscript/function_libs/torch_lib/ops/core.py
@codecov

codecov Bot commented Jun 4, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 72.63%. Comparing base (e0d3edc) to head (e1a7d01).

Additional details and impacted files
@@            Coverage Diff             @@
##             main    #2929      +/-   ##
==========================================
- Coverage   72.63%   72.63%   -0.01%     
==========================================
  Files         259      259              
  Lines       31666    31665       -1     
  Branches     2982     2981       -1     
==========================================
- Hits        23000    22999       -1     
  Misses       7656     7656              
  Partials     1010     1010              

Contributor

@titaiwangms titaiwangms left a comment

Thanks!

@titaiwangms titaiwangms enabled auto-merge (squash) June 4, 2026 18:02
@titaiwangms titaiwangms merged commit 50a758a into microsoft:main Jun 4, 2026
31 of 32 checks passed
@github-project-automation github-project-automation Bot moved this from Todo to Done in ONNX Script Review Board Jun 4, 2026