Closed
Labels
contribution welcome (We welcome code contributions for this), module: torchlib (Related to the torch/aten function lib in development)
Description
onnxscript/onnxscript/function_libs/torch_lib/ops/core.py
Lines 7739 to 7749 in 81f8444
@torch_op(("aten::scatter.value", "aten::scatter.src"), trace_only=True)
def aten_scatter(
    self: TReal,
    dim: int,  # we have to use int here because ScatterElements() will use this attribute
    index: TInt,
    src: TReal,
) -> TReal:
    """scatter_add(Tensor self, int dim, Tensor index, Tensor src) -> Tensor"""
    update = op.Expand(src, op.Shape(index))
    return op.ScatterElements(self, index, update, axis=dim)
aten::scatter.value does not require value to have the same type as self, because PyTorch casts the value implicitly. The ops implementation must therefore enforce type compatibility itself, using CastLike. As the calls below show, an int, a float, and a bool value all produce the same float result:
>>> data = torch.zeros((3, 3)).float()
>>> indices = torch.tensor([[1, 0, 2], [0, 2, 1]])
>>> torch.ops.aten.scatter.value(data, 0, indices, 1)
tensor([[1., 1., 0.],
        [1., 0., 1.],
        [0., 1., 1.]])
>>> torch.ops.aten.scatter.value(data, 0, indices, int(1))
tensor([[1., 1., 0.],
        [1., 0., 1.],
        [0., 1., 1.]])
>>> torch.ops.aten.scatter.value(data, 0, indices, float(1))
tensor([[1., 1., 0.],
        [1., 0., 1.],
        [0., 1., 1.]])
>>> torch.ops.aten.scatter.value(data, 0, indices, True)
tensor([[1., 1., 0.],
        [1., 0., 1.],
        [0., 1., 1.]])
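
To make the requested behavior concrete, here is a minimal NumPy sketch of the Expand, CastLike, ScatterElements sequence. It is only a reference model, not the torchlib code: the helper name `scatter_value_reference` is hypothetical, and `np.put_along_axis` stands in for ScatterElements under the assumption that the indices along `dim` do not collide.

```python
import numpy as np


def scatter_value_reference(self_arr, dim, index, value):
    """NumPy model of Expand -> CastLike -> ScatterElements (hypothetical helper)."""
    # Expand: broadcast the scalar value to the shape of `index`.
    update = np.broadcast_to(np.asarray(value), index.shape)
    # CastLike: convert the update to the dtype of `self`, mirroring the
    # implicit cast that aten::scatter.value performs.
    update = update.astype(self_arr.dtype)
    # ScatterElements: write the update into a copy of `self` along `dim`.
    out = self_arr.copy()
    np.put_along_axis(out, index, update, axis=dim)
    return out


data = np.zeros((3, 3), dtype=np.float32)
indices = np.array([[1, 0, 2], [0, 2, 1]])

# int, float, and bool values should all yield the same float32 result.
results = [scatter_value_reference(data, 0, indices, v) for v in (1, 1.0, True)]
```

In the ONNX graph, the corresponding change would presumably be a CastLike applied to the expanded update, with self as the target type, before it is passed to ScatterElements.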