Labels: good first issue (Good for newcomers); module: torchlib (Related to the torch/aten function lib in development)
Description
onnxscript/onnxscript/function_libs/torch_lib/ops/core.py
Lines 3802 to 3822 in ea79022
```python
@torch_op("aten::gather", trace_only=True)
def aten_gather(
    self: TReal,
    dim: int,
    index: TInt,
    sparse_grad: bool = False,
) -> TReal:
    """gather(Tensor self, int dim, Tensor index, *, bool sparse_grad=False) -> Tensor"""
    if len(self.shape) == 0:
        if len(index.shape) == 0:
            return op.Identity(self)
        else:
            return op.Expand(self, op.Shape(index))
    if len(index.shape) == 0:
        return op.Identity(self)
    index = op.Cast(index, to=INT64.dtype)
    result = op.GatherElements(self, index, axis=dim)
    return result
```
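When a 0-d (scalar) index is passed to `aten::gather`, PyTorch returns a 0-d tensor holding the selected element:

```python
>>> import torch
>>> x = torch.arange(3)
>>> index = torch.tensor(1)
>>> out = torch.ops.aten.gather.default(x, 0, index)
>>> out.shape
torch.Size([])
>>> out
tensor(1)
```

The current implementation instead returns `op.Identity(self)` whenever `index` is 0-d, i.e. the whole input rather than the selected element. Therefore, when a 0-d index is passed, we need to unsqueeze it to 1-D before calling `GatherElements` and squeeze the result back to 0-d afterwards: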
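```python
    # Replaces the `if len(index.shape) == 0: return op.Identity(self)` branch
    is_scalar_index = len(index.shape) == 0
    if is_scalar_index:
        # GatherElements requires index to have the same rank as self,
        # so lift the 0-d index to shape [1]
        index = op.Unsqueeze(index, [0])
    index = op.Cast(index, to=INT64.dtype)
    result = op.GatherElements(self, index, axis=dim)
    if is_scalar_index:
        # Drop the axis added above so the output is 0-d, matching PyTorch
        result = op.Squeeze(result, [0])
    return result
```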
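As a quick sanity check that does not need torch or onnxscript, the proposed branching can be modelled in NumPy. This is a sketch under the assumption that `np.take_along_axis` matches ONNX `GatherElements` semantics for in-range indices (both require `index` to have the same rank as the data); `gather_reference` is a hypothetical helper, not part of onnxscript.

```python
import numpy as np


def gather_reference(self_arr, dim, index):
    """Hypothetical NumPy model of the proposed aten_gather branching.

    np.take_along_axis stands in for ONNX GatherElements (an assumption
    for illustration, valid for in-range indices).
    """
    self_arr = np.asarray(self_arr)
    index = np.asarray(index)
    if self_arr.ndim == 0:
        # Scalar self: mirrors the Identity / Expand branches of aten_gather
        return np.broadcast_to(self_arr, index.shape).copy()
    is_scalar_index = index.ndim == 0
    if is_scalar_index:
        # Like op.Unsqueeze(index, [0]): 0-d index becomes shape (1,)
        index = np.expand_dims(index, 0)
    index = index.astype(np.int64)
    result = np.take_along_axis(self_arr, index, axis=dim)
    if is_scalar_index:
        # Like op.Squeeze(result, [0]): back to a 0-d result
        result = np.squeeze(result, axis=0)
    return result


out = gather_reference(np.arange(3), 0, np.array(1))
print(out.shape, out)  # () 1
```

With the example from above, the model yields a 0-d result equal to `1`, matching `torch.ops.aten.gather.default`, while the non-scalar path is unchanged.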