Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions onnxscript/function_libs/torch_lib/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -6081,10 +6081,12 @@ def aten_masked_fill(self: TTensor, mask: BOOL, value: TTensor) -> TTensor:
def aten_masked_scatter(self: TTensor, mask: TTensor, source: TTensor) -> TTensor:
"""masked_scatter(Tensor self, Tensor mask, Tensor source) -> Tensor"""

if len(mask.shape) < len(self.shape):
mask = op.Expand(mask, op.Shape(self))
else:
self = op.Expand(self, op.Shape(mask))
# Broadcast self and mask to their common shape so NonZero enumerates every
# masked element. The previous rank-only check missed same-rank broadcasting
# (e.g. mask (1, S, 1) vs self (1, S, D)): it left mask un-expanded, so only a
# subset of masked positions were scattered (pytorch/pytorch#186146).
self = op.Expand(self, op.Shape(mask))
mask = op.Expand(mask, op.Shape(self))
Comment thread
titaiwangms marked this conversation as resolved.
index = op.Transpose(op.NonZero(mask), perm=[1, 0])

# NOTE: source can have more elements than needed.
Expand Down
35 changes: 35 additions & 0 deletions tests/function_libs/torch_lib/extra_opinfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -2465,6 +2465,33 @@ def __init__(self):
# in ops_test_data.py and opinfo_core.OpInfo("unique_name", ...)
# To avoid name duplication, it is possible to rename the OpInfo and specify
# the `op` field explicitly.
def sample_inputs_masked_scatter(op_info, device, dtype, requires_grad, **kwargs):
del op_info
del kwargs

make_arg = functools.partial(
torch_testing.make_tensor, device=device, dtype=dtype, requires_grad=requires_grad
)
# (self_shape, mask_shape) with mask broadcastable to self. The same-rank
# broadcasting cases (e.g. (1, 5, 4) / (1, 5, 1)) are the regression target for
# pytorch/pytorch#186146 — the dynamo exporter previously left mask un-expanded.
cases = (
((1, 5, 4), (1, 5, 4)), # no broadcast
((1, 5, 4), (1, 5, 1)), # same-rank broadcast over last dim
((1, 5, 4), (5, 4)), # lower-rank mask
((2, 3), (2, 1)), # same-rank broadcast
((3, 1), (3, 4)), # self broadcast up to mask
)
for self_shape, mask_shape in cases:
self_tensor = make_arg(self_shape)
mask = torch.zeros(mask_shape, dtype=torch.bool, device=device)
mask.view(-1)[::2] = True
broadcast_shape = torch.broadcast_shapes(self_shape, mask_shape)
num_selected = int(mask.expand(broadcast_shape).sum())
source = make_arg((max(num_selected, 1),))
yield opinfo_core.SampleInput(self_tensor, args=(mask, source))


OP_DB: List[opinfo_core.OpInfo] = [
opinfo_core.OpInfo(
"bilinear",
Expand Down Expand Up @@ -3101,4 +3128,12 @@ def __init__(self):
sample_inputs_func=sample_inputs_roi_pool,
supports_out=False,
),
opinfo_core.OpInfo(
"ops.aten.masked_scatter",
aten_name="masked_scatter",
op=torch.ops.aten.masked_scatter.default,
dtypes=common_dtype.all_types(),
sample_inputs_func=sample_inputs_masked_scatter,
supports_out=False,
),
]
1 change: 1 addition & 0 deletions tests/function_libs/torch_lib/ops_test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -847,6 +847,7 @@ def _where_input_wrangler(
reason="fixme: ORT does not have an implementation for Where with bool inputs.",
),
TorchLibOpInfo("masked_scatter", core_ops.aten_masked_scatter),
TorchLibOpInfo("ops.aten.masked_scatter", core_ops.aten_masked_scatter),
TorchLibOpInfo(
"matmul",
core_ops.aten_matmul,
Expand Down
Loading