
behavior discrepancy when aten::scatter.src receives a scalar index or scalar source #2600

@linshokaku

Description


```python
@torch_op(("aten::scatter.value", "aten::scatter.src"), trace_only=True)
def aten_scatter(
    self: TReal,
    dim: int,  # we have to use int here because ScatterElements() will use this attribute
    index: TInt,
    src: TReal,
) -> TReal:
    """scatter_add(Tensor self, int dim, Tensor index, Tensor src) -> Tensor"""
    update = op.Expand(src, op.Shape(index))
    return op.ScatterElements(self, index, update, axis=dim)
```

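#2564 reports a similar issue. I've also confirmed that `aten::scatter` implicitly unsqueezes 0-dimensional `index` and `src` tensors, treating them as 1-D, so the torchlib lowering needs to do the same. When a 0-dimensional tensor is passed, the lowering above does not work correctly, because ONNX `ScatterElements` requires `data`, `indices`, and `updates` to have the same rank r >= 1.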
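The rank problem can be checked at the shape level. The following is a hypothetical pure-Python model, not torchlib code. It assumes `data` has shape `(3,)`, models `op.Expand` with NumPy-style (multidirectional) broadcasting, models `ScatterElements` only through its same-rank rule, and uses illustrative helper names (`broadcast_shapes`, `model_current_lowering`).

```python
def broadcast_shapes(a, b):
    """NumPy-style (multidirectional) broadcast of two shapes, as ONNX Expand uses."""
    ndim = max(len(a), len(b))
    # Left-pad the shorter shape with 1s so both have the same rank.
    pa = (1,) * (ndim - len(a)) + tuple(a)
    pb = (1,) * (ndim - len(b)) + tuple(b)
    out = []
    for x, y in zip(pa, pb):
        if x != y and 1 not in (x, y):
            raise ValueError(f"cannot broadcast {a} and {b}")
        out.append(y if x == 1 else x)
    return tuple(out)


def model_current_lowering(data_shape, index_shape, src_shape):
    """Return (update_shape, ranks_ok) for the aten_scatter lowering (shapes only)."""
    # update = op.Expand(src, op.Shape(index))
    update_shape = broadcast_shapes(src_shape, index_shape)
    # ONNX ScatterElements: data, indices and updates must share one rank r >= 1.
    r = len(data_shape)
    ranks_ok = r >= 1 and len(index_shape) == r and len(update_shape) == r
    return update_shape, ranks_ok


# () stands for a 0-d tensor, (1,) for a 1-element 1-D tensor.
for index_shape, src_shape in [((), ()), ((1,), ()), ((), (1,)), ((1,), (1,))]:
    update_shape, ok = model_current_lowering((3,), index_shape, src_shape)
    print(f"index {index_shape}, src {src_shape} -> update {update_shape}, ranks ok: {ok}")
```

In this simplified model, every case with a 0-d `index` violates the rank rule. A 0-d `src` paired with a 1-d `index` is broadcast by `Expand` to a valid shape. Other parts of the export path are not modeled here.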

In PyTorch, every combination of 0-d and 1-d `index`/`src`, for both `scatter` overloads, gives the same result:

```python
>>> import torch
>>> data = torch.zeros(3).float()
>>> torch.ops.aten.scatter.src(data, 0, torch.tensor(0).long(), torch.tensor(1).float())
tensor([1., 0., 0.])
>>> torch.ops.aten.scatter.src(data, 0, torch.tensor([0]).long(), torch.tensor(1).float())
tensor([1., 0., 0.])
>>> torch.ops.aten.scatter.src(data, 0, torch.tensor(0).long(), torch.tensor([1]).float())
tensor([1., 0., 0.])
>>> torch.ops.aten.scatter.src(data, 0, torch.tensor([0]).long(), torch.tensor([1]).float())
tensor([1., 0., 0.])
>>> torch.ops.aten.scatter.value(data, 0, torch.tensor(0).long(), 1.)
tensor([1., 0., 0.])
>>> torch.ops.aten.scatter.value(data, 0, torch.tensor([0]).long(), 1.)
tensor([1., 0., 0.])
```
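As a reference for the expected behavior, the six calls can be mirrored in plain Python. This is a hypothetical sketch, not PyTorch or torchlib code. Python scalars stand in for 0-d tensors, lists stand in for 1-d tensors, `dim` is fixed to 0, and the helper names `as_1d`, `scatter_src`, and `scatter_value` are illustrative. The key step is `as_1d`, which models the implicit unsqueeze of 0-d inputs.

```python
def as_1d(x):
    """Promote a 0-d stand-in (plain scalar) to a 1-element list, like an implicit unsqueeze."""
    return list(x) if isinstance(x, list) else [x]


def scatter_src(data, index, src):
    """Sketch of aten::scatter.src(data, 0, index, src) for 1-D `data`."""
    idx, vals = as_1d(index), as_1d(src)
    # PyTorch documents index.size(d) <= src.size(d); here d is the only dim.
    if len(idx) > len(vals):
        raise ValueError("index must not be larger than src along dim 0")
    out = list(data)
    for k, i in enumerate(idx):
        out[i] = vals[k]
    return out


def scatter_value(data, index, value):
    """Sketch of aten::scatter.value(data, 0, index, value) for 1-D `data`."""
    out = list(data)
    for i in as_1d(index):
        out[i] = value
    return out


data = [0.0, 0.0, 0.0]
results = [
    scatter_src(data, 0, 1.0),        # 0-d index, 0-d src
    scatter_src(data, [0], 1.0),      # 1-d index, 0-d src
    scatter_src(data, 0, [1.0]),      # 0-d index, 1-d src
    scatter_src(data, [0], [1.0]),    # 1-d index, 1-d src
    scatter_value(data, 0, 1.0),      # 0-d index, scalar value
    scatter_value(data, [0], 1.0),    # 1-d index, scalar value
]
assert all(r == [1.0, 0.0, 0.0] for r in results)
```

A lowering that matches these semantics would presumably need the same promotion (for example, unsqueezing a 0-d `index` to shape `[1]`) before calling `ScatterElements`.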

Labels: `contribution welcome` (We welcome code contributions for this), `module: torchlib` (Related to the torch/aten function lib in development)
