
need for casting in aten::scatter.value #2602

@linshokaku

Description

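```python
from onnxscript.function_libs.torch_lib.registration import torch_op
from onnxscript.function_libs.torch_lib.tensor_typing import TInt, TReal
from onnxscript.onnx_opset import opset18 as op


@torch_op(("aten::scatter.value", "aten::scatter.src"), trace_only=True)
def aten_scatter(
    self: TReal,
    dim: int,  # we have to use int here because ScatterElements() will use this attribute
    index: TInt,
    src: TReal,
) -> TReal:
    """scatter.src(Tensor self, int dim, Tensor index, Tensor src) -> Tensor"""
    # Broadcast src (a scalar for scatter.value) to the shape of index.
    update = op.Expand(src, op.Shape(index))
    return op.ScatterElements(self, index, update, axis=dim)
```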

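`aten::scatter.value` does not require `value` to have the same dtype as `self`: PyTorch implicitly casts the scalar to the dtype of `self`, so an `int`, `float`, or `bool` value all produce the same float result (see below). ONNX `ScatterElements`, however, requires `data` and `updates` to share the same element type. The torchlib implementation should therefore apply `CastLike` to `value`, casting it to the type of `self`, before expanding and scattering it.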

```python
>>> data = torch.zeros((3, 3)).float()
>>> indices = torch.tensor([[1, 0, 2], [0, 2, 1]])
>>> torch.ops.aten.scatter.value(data, 0, indices, 1)
tensor([[1., 1., 0.],
        [1., 0., 1.],
        [0., 1., 1.]])
>>> torch.ops.aten.scatter.value(data, 0, indices, int(1))
tensor([[1., 1., 0.],
        [1., 0., 1.],
        [0., 1., 1.]])
>>> torch.ops.aten.scatter.value(data, 0, indices, float(1))
tensor([[1., 1., 0.],
        [1., 0., 1.],
        [0., 1., 1.]])
>>> torch.ops.aten.scatter.value(data, 0, indices, True)
tensor([[1., 1., 0.],
        [1., 0., 1.],
        [0., 1., 1.]])
```
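For illustration, here is a minimal NumPy model of the `aten::scatter.value` lowering with the cast added. It assumes that `CastLike(value, self)` behaves like `astype(self.dtype)` and that `Expand(value, Shape(index))` behaves like `np.broadcast_to(value, index.shape)`. The helpers `scatter_elements` and `aten_scatter_value` are hypothetical stand-ins, not onnxscript APIs. `scatter_elements` raises `TypeError` on a dtype mismatch to mimic ONNX's same-type constraint on `data` and `updates`, which plain NumPy assignment would not enforce.

```python
import numpy as np


def scatter_elements(data, indices, updates, axis=0):
    """Hypothetical reference model of ONNX ScatterElements (no reduction).

    ONNX types `data` and `updates` with the same T, so a dtype mismatch
    is rejected here instead of being silently converted by NumPy.
    """
    if data.dtype != updates.dtype:
        raise TypeError(
            f"type mismatch: data is {data.dtype}, updates is {updates.dtype}"
        )
    out = data.copy()
    axis = axis % data.ndim
    # For each position in indices, replace the coordinate along `axis`
    # with the index value and write the matching update there.
    for pos in np.ndindex(indices.shape):
        target = list(pos)
        target[axis] = int(indices[pos])
        out[tuple(target)] = updates[pos]
    return out


def aten_scatter_value(self, dim, index, value):
    """Hypothetical sketch: CastLike -> Expand -> ScatterElements."""
    # CastLike(value, self): convert the Python scalar to self's dtype.
    update = np.asarray(value).astype(self.dtype)
    # Expand(update, Shape(index)): broadcast the 0-d value to index's shape.
    update = np.broadcast_to(update, index.shape)
    return scatter_elements(self, index, update, axis=dim)


data = np.zeros((3, 3), dtype=np.float32)
indices = np.array([[1, 0, 2], [0, 2, 1]])
for value in (1, 1.0, True):
    print(aten_scatter_value(data, 0, indices, value))
```

With the cast in place, `1`, `1.0`, and `True` all yield the same float32 result as the PyTorch session. Dropping the `astype` step leaves an integer or boolean `updates` array next to float32 `data`, and the type check fails, which is the mismatch the current ONNX lowering runs into.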

Metadata

Labels: contribution welcome, module: torchlib