Status: Closed
Labels: contribution welcome, module: torchlib
Description
`onnxscript/onnxscript/function_libs/torch_lib/ops/core.py`, lines 7739 to 7749 at commit 81f8444:
```python
@torch_op(("aten::scatter.value", "aten::scatter.src"), trace_only=True)
def aten_scatter(
    self: TReal,
    dim: int,  # we have to use int here because ScatterElements() will use this attribute
    index: TInt,
    src: TReal,
) -> TReal:
    """scatter_add(Tensor self, int dim, Tensor index, Tensor src) -> Tensor"""
    update = op.Expand(src, op.Shape(index))
    return op.ScatterElements(self, index, update, axis=dim)
```
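To see which inputs reach `ScatterElements` with mismatched ranks, the shape logic of the two ops above can be traced with NumPy. This is a sketch under stated assumptions: ONNX `Expand` uses NumPy-style broadcasting, so `np.broadcast_shapes` is used to model `op.Expand(src, op.Shape(index))`; `trace_shapes`, `is_valid_for_scatter_elements`, and `DATA_RANK` are hypothetical names introduced only for illustration, and `data` is assumed to be rank 1 as in the examples below.

```python
import numpy as np

DATA_RANK = 1  # assumption: `data` is torch.zeros(3), i.e. rank 1


def trace_shapes(index_shape, src_shape):
    """Hypothetical helper: return the (indices, updates) shapes that reach
    ScatterElements under the lowering above.

    op.Expand(src, op.Shape(index)) broadcasts src against index's shape,
    modeled here with np.broadcast_shapes; `index` is passed through unchanged.
    """
    update_shape = np.broadcast_shapes(tuple(src_shape), tuple(index_shape))
    return tuple(index_shape), update_shape


def is_valid_for_scatter_elements(indices_shape, updates_shape, data_rank=DATA_RANK):
    # ONNX ScatterElements: indices and updates must have the same shape,
    # and their rank must equal the rank of data (r >= 1).
    return (
        data_rank >= 1
        and len(indices_shape) == data_rank
        and indices_shape == updates_shape
    )


for index_shape, src_shape in [((), ()), ((1,), ()), ((), (1,)), ((1,), (1,))]:
    idx, upd = trace_shapes(index_shape, src_shape)
    print(index_shape, src_shape, "->", idx, upd, is_valid_for_scatter_elements(idx, upd))
```

In this model, `Expand` already lifts a 0-d `src` to the shape of a 1-d `index`, so only the cases with a 0-d `index` end up invalid.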
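#2564 reports a similar issue. In addition, I've confirmed that `aten::scatter` implicitly unsqueezes 0-dimensional tensors, so a 0-d `index` or `src` behaves like a tensor of shape `[1]`; the torchlib implementation needs to be corrected to match.
`ScatterElements`, however, does not work correctly when given a 0-dimensional tensor, because ONNX requires `data`, `indices`, and `updates` to have the same rank r >= 1. In PyTorch, all of the following calls produce the same result: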
```python
>>> import torch
>>> data = torch.zeros(3).float()
>>> torch.ops.aten.scatter.src(data, 0, torch.tensor(0).long(), torch.tensor(1).float())
tensor([1., 0., 0.])
>>> torch.ops.aten.scatter.src(data, 0, torch.tensor([0]).long(), torch.tensor(1).float())
tensor([1., 0., 0.])
>>> torch.ops.aten.scatter.src(data, 0, torch.tensor(0).long(), torch.tensor([1]).float())
tensor([1., 0., 0.])
>>> torch.ops.aten.scatter.src(data, 0, torch.tensor([0]).long(), torch.tensor([1]).float())
tensor([1., 0., 0.])
>>> torch.ops.aten.scatter.value(data, 0, torch.tensor(0).long(), 1.)
tensor([1., 0., 0.])
>>> torch.ops.aten.scatter.value(data, 0, torch.tensor([0]).long(), 1.)
tensor([1., 0., 0.])
```
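One way to reproduce the PyTorch behavior is to unsqueeze 0-d inputs before the `Expand` + `ScatterElements` steps. The following is a minimal NumPy sketch of that idea, not the torchlib fix: `aten_scatter_ref` is a hypothetical name, `np.put_along_axis` stands in for `ScatterElements` (reduction `none`), and it assumes `src` is a scalar or already has `index`'s shape, which covers only the cases shown above. In an ONNX graph the reshape would presumably become an `Unsqueeze` or `Reshape` node; that is an assumption.

```python
import numpy as np


def aten_scatter_ref(data, dim, index, src):
    """Hypothetical NumPy reference for aten::scatter.src / aten::scatter.value
    that mimics PyTorch's implicit unsqueeze of 0-d inputs.

    Assumes `src` is a scalar or already has index's shape; this is not a
    general model of aten::scatter.
    """
    data = np.asarray(data)
    index = np.asarray(index, dtype=np.int64)
    src = np.asarray(src, dtype=data.dtype)
    # Implicit unsqueeze: treat a 0-d index or src as shape (1,).
    if index.ndim == 0:
        index = index.reshape(1)
    if src.ndim == 0:
        src = src.reshape(1)
    # Analogue of op.Expand(src, op.Shape(index)) for these inputs.
    update = np.broadcast_to(src, index.shape)
    # ScatterElements requires indices/updates to have the same rank as data.
    if not (index.ndim == update.ndim == data.ndim):
        raise ValueError(
            f"rank mismatch: data {data.shape}, index {index.shape}, update {update.shape}"
        )
    out = data.copy()
    # Write update[i] into out at position index[i] along `dim`.
    np.put_along_axis(out, index, update, axis=dim)
    return out


data = np.zeros(3, dtype=np.float32)
# 0-d vs. 1-d index and src; a Python float src corresponds to scatter.value.
for index, src in [(0, 1.0), ([0], 1.0), (0, [1.0]), ([0], [1.0])]:
    print(aten_scatter_ref(data, 0, index, src))
```

All four calls print `[1. 0. 0.]`, matching the PyTorch results above, including the two `scatter.value` cases.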