Skip to content

Commit

Permalink
Fix reference implementation for ScatterND with 4D tensors (#6174)
Browse files Browse the repository at this point in the history
### Description
Implementation for ScatterND is wrong for 4D tensors.

### Motivation and Context
Fixes a bug.

---------

Signed-off-by: Xavier Dupre <xadupre@microsoft.com>
  • Loading branch information
xadupre committed Jun 17, 2024
1 parent 6f7ff97 commit 6ec8c87
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 48 deletions.
65 changes: 17 additions & 48 deletions onnx/reference/ops/op_scatter_elements.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,55 +88,24 @@ def f(x, y): # noqa: ARG001
)
return scattered

idx_xsection_shape = indices.shape[:axis] + indices.shape[axis + 1 :]

def make_slice(arr, axis, i): # type: ignore
slc = [slice(None)] * arr.ndim
slc[axis] = i
return slc

def unpack(packed): # type: ignore
unpacked = packed[0]
for i in range(1, len(packed)):
unpacked = unpacked, packed[i]
return unpacked

# We use indices and axis parameters to create idx
# idx is in a form that can be used as a NumPy advanced
# indices for scattering of updates param. in data
idx = [
[
unpack(np.indices(idx_xsection_shape).reshape(indices.ndim - 1, -1)),
indices[tuple(make_slice(indices, axis, i))].reshape(1, -1)[0],
]
for i in range(indices.shape[axis])
]
idx = list(np.concatenate(idx, axis=1))
idx.insert(axis, idx.pop())

# updates_idx is a NumPy advanced indices for indexing
# of elements in the updates
updates_idx = list(idx)
updates_idx.pop(axis)
updates_idx.insert( # type: ignore
axis,
np.repeat(np.arange(indices.shape[axis]), np.prod(idx_xsection_shape)), # type: ignore
)
if len(indices.shape) == 4:
scattered = np.copy(data)
for a in range(indices.shape[0]):
for i in range(indices.shape[1]):
for j in range(indices.shape[2]):
for k in range(indices.shape[3]):
index = [a, i, j, k]
index[axis] = indices[a, i, j, k]
tuple_index = tuple(index)
scattered[tuple_index] = f(
scattered[tuple_index],
updates[a, i, j, k],
)
return scattered

scattered = np.copy(data)
if reduction == "min":
scattered[tuple(idx)] = np.minimum(
scattered[tuple(idx)], updates[tuple(updates_idx)]
)
elif reduction == "max":
scattered[tuple(idx)] = np.maximum(
scattered[tuple(idx)], updates[tuple(updates_idx)]
)
elif reduction == "add":
scattered[tuple(idx)] += updates[tuple(updates_idx)]
else:
scattered[tuple(idx)] = updates[tuple(updates_idx)]
return scattered
raise NotImplementedError(
f"ScatterND is not implement for indices.shape={indices.shape} and axis={axis}."
)


class ScatterElements(OpRun):
Expand Down
33 changes: 33 additions & 0 deletions onnx/test/reference_evaluator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
python onnx/test/reference_evaluator_test.py TestReferenceEvaluator.test_function_attribute_nested_graph
"""

from __future__ import annotations

import itertools
Expand Down Expand Up @@ -5915,6 +5916,38 @@ class MyReferenceEvaluator(ReferenceEvaluator):
for v in oinf.functions_.values():
self.assertIsInstance(v, MyReferenceEvaluator)

def test_scatter_elements_4d(self):
model = make_model(
make_graph(
[
make_node(
"ScatterElements",
["data", "indices", "updates"],
["Z"],
axis=3,
reduction="add",
)
],
"name",
[
make_tensor_value_info("data", TensorProto.FLOAT, None),
make_tensor_value_info("indices", TensorProto.INT64, None),
make_tensor_value_info("updates", TensorProto.FLOAT, None),
],
[make_tensor_value_info("Z", TensorProto.FLOAT, None)],
),
opset_imports=[make_opsetid("", 18)],
)
data = np.zeros(2**4, dtype=np.float32).reshape((2, 2, 2, 2))
indices = np.array([[[[0]]]], dtype=np.int64)
updates = np.array([[[[1]]]], dtype=np.float32)
y = np.array(
[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=np.float32
).reshape((2, 2, 2, 2))
ref = ReferenceEvaluator(model)
got = ref.run(None, {"data": data, "indices": indices, "updates": updates})
assert_allclose(y, got[0])


if __name__ == "__main__":
unittest.main(verbosity=2)

0 comments on commit 6ec8c87

Please sign in to comment.