Skip to content

behavior discrepancy when aten::gather receives a scalar index #2564

@linshokaku

Description

@linshokaku

@torch_op("aten::gather", trace_only=True)
def aten_gather(
self: TReal,
dim: int,
index: TInt,
sparse_grad: bool = False,
) -> TReal:
"""gather(Tensor self, int dim, Tensor index, *, bool sparse_grad=False) -> Tensor"""
if len(self.shape) == 0:
if len(index.shape) == 0:
return op.Identity(self)
else:
return op.Expand(self, op.Shape(index))
if len(index.shape) == 0:
return op.Identity(self)
index = op.Cast(index, to=INT64.dtype)
result = op.GatherElements(self, index, axis=dim)
return result

When a scalar index is provided to aten::gather, the following behavior occurs:

>>> import torch
>>> x = torch.arange(3)
>>> index = torch.tensor(1)
>>> out = torch.ops.aten.gather.default(x, 0, index)
>>> out.shape
torch.Size([])
>>> out
tensor(1)

Therefore, when a 0-dimensional tensor is passed, we'll need to implement branching logic for squeeze/unsqueeze operations.

        is_scalar_index = len(index.shape) == 0
        if is_scalar_index:
            index = op.Unsqueeze(index, [0])

        index = op.Cast(index, to=INT64.dtype)
        result = op.GatherElements(self, index, axis=dim)

        if is_scalar_index:
            result = op.Squeeze(result, [0])

Metadata

Metadata

Assignees

Labels

good first issueGood for newcomersmodule: torchlibRelated to the torch/aten function lib in development

Type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions