144 changes: 139 additions & 5 deletions onnxscript/function_libs/torch_lib/ops/core.py
@@ -4223,6 +4223,23 @@ def aten_index_put(
    See implementation of `torch.onnx.symbolic_opset11.index_put
    <https://github.com/pytorch/pytorch/blob/main/torch/onnx/symbolic_opset11.py#L212>`_.
    """
    if (
        len(indices) > 1
        and any(
            isinstance(index, torch.onnx._internal.exporter._tensors.SymbolicTensor)  # pylint: disable=protected-access
            for index in indices
        )
        and len(values.shape) == 1
    ):
        return _aten_index_put_dynamic(self, indices, values, accumulate=accumulate)

Review thread on the SymbolicTensor check:

Collaborator: Is there a pointer to the real use cases that show up? I thought each index in `indices` is supposed to be a 1-D tensor. What are the ways in which SymbolicTensors are created?

Collaborator: Is a SymbolicTensor a 1-element tensor version of a SymbolicInt?

Collaborator: SymbolicTensor is just ir.Value with magic methods defined to support Python operators.

Collaborator: Here the symbolic tensors are created from the symints.
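For reference, a minimal eager-mode sketch (toy shapes, not taken from the diff) of the advanced-indexing semantics this dispatch has to reproduce:

import torch

# Eager semantics of aten.index_put with two 1-D integer indices:
# each pair (index1[j], index2[j]) selects one element to overwrite.
x = torch.zeros(4, 5)
index1 = torch.tensor([1, 2])
index2 = torch.tensor([3, 4])
values = torch.tensor([10.0, 11.0])
x.index_put_((index1, index2), values)
assert x[1, 3] == 10.0 and x[2, 4] == 11.0
# With accumulate=True the values are added instead of overwritten:
x.index_put_((index1, index2), values, accumulate=True)
assert x[1, 3] == 20.0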

    n_none = [i for i, ind in enumerate(indices) if ind is not None]
    if (
        len(n_none) == 1
        and len(indices[n_none[0]].shape) == 1
        and len(self.shape) == len(values.shape)
    ):
        return _aten_index_put_scatter_nd(self, indices, values, accumulate)

    def _make_reshape_list_broadcastable(reshape_list, values_shape):
        # Remove ones until the rank of reshape_list matches values_shape.
@@ -4292,14 +4309,131 @@ def _make_reshape_list_broadcastable(reshape_list, values_shape):
     # Flatten values to match the indices
     flat_values = op.Reshape(values, [-1])

-    if accumulate:
-        result = op.ScatterND(self, new_index, flat_values, reduction="add")
-    else:
-        result = op.ScatterND(self, new_index, flat_values)
+    scatter_kwargs = dict(reduction="add") if accumulate else {}
+    result = op.ScatterND(self, new_index, flat_values, **scatter_kwargs)
     return result
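A small numpy sketch (toy data, not from the diff) of the ScatterND behavior the `reduction` attribute toggles; `np.add.at` plays the role of reduction="add":

import numpy as np

data = np.array([1.0, 2.0, 3.0, 4.0])
row_indices = np.array([[1], [3]])  # ScatterND-style (N, 1) indices
updates = np.array([10.0, 20.0])

out_none = data.copy()
out_none[row_indices[:, 0]] = updates            # reduction="none": overwrite
out_add = data.copy()
np.add.at(out_add, row_indices[:, 0], updates)   # reduction="add": accumulate

# out_none -> [1., 10., 3., 20.]   out_add -> [1., 12., 3., 24.]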


def _aten_index_put_scatter_nd(
    x: TReal,
    indices: Sequence[INT64],
    values: TReal,
    accumulate: bool = False,
) -> TReal:
    """index_put with a single non-None 1-D index: scatter along that axis
    with ScatterND, transposing the axis to the front first when it is not
    axis 0."""

    def _1dint(i: int):
        # 1-D INT64 constant [i]
        return op.Constant(value_ints=ir.AttrInt64s("value_ints", [i]))

    n_none = [i for i, ind in enumerate(indices) if ind is not None]
    assert len(n_none) == 1, f"Unable to handle this case: n_none={n_none}"
    unsq = op.Unsqueeze(indices[n_none[0]], _1dint(1))
    if n_none[0] == 0:
        return op.ScatterND(x, unsq, values, reduction="add" if accumulate else "none")

    perm = list(range(len(x.shape)))
    perm[n_none[0]], perm[0] = perm[0], perm[n_none[0]]
    return op.Transpose(
        op.ScatterND(
            op.Transpose(x, perm=perm),
            unsq,
            op.Transpose(values, perm=perm),
            reduction="add" if accumulate else "none",
        ),
        perm=perm,
    )
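A numpy sketch (toy shapes, assumed data) of the transpose trick used above: scattering along axis k is the same as moving axis k to the front, scattering along axis 0, and moving it back.

import numpy as np

x = np.arange(24, dtype=np.float32).reshape(2, 3, 4)
values = (x + 1) * 100
index = np.array([1, 2, 0])          # single non-None index on axis 1

expected = x.copy()
expected[:, index, :] = values       # what aten.index_put computes here

perm = [1, 0, 2]                     # swap axis 1 to the front
xt = x.transpose(perm).copy()
xt[index] = values.transpose(perm)   # ScatterND along axis 0
assert np.allclose(xt.transpose(perm), expected)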


def _aten_index_put_dynamic(
    x: TReal,
    indices: Sequence[INT64],
    values: TReal,
    accumulate: bool = False,
) -> TReal:
    """index_put when some indices are symbolic: builds flat (linearized)
    indices as the broadcast sum of index_i * stride_i over all dimensions,
    expanding every None index into a Range over that dimension, then
    scatters into the flattened tensor."""

    def _1dint(i: int):
        # 1-D INT64 constant [i]
        return op.Constant(value_ints=ir.AttrInt64s("value_ints", [i]))

    def _0dint(i: int):
        # scalar INT64 constant i
        return op.Constant(value_int=ir.AttrInt64("value_int", i))

    def _make_range_or_cast(ind, shape_x, static_shape: bool, dim: int):
        # Returns (index tensor, expanded flag). A None index stands for the
        # whole dimension and becomes Range(0, x.shape[dim], 1).
        if ind is not None:
            return op.Cast(ind, to=INT64.dtype), False
        return (
            op.Cast(
                op.Range(  # Range does not return a typed result
                    _0dint(0),
                    op.Squeeze(op.Shape(x, start=dim, end=dim + 1)),
                    _0dint(1),
                ),
                to=INT64.dtype,
            ),
            True,
        )

    shape_x = op.Shape(x)
    exped = []
    fixed = []
    reshape_value_shape2 = []
    expand_value_shape = []
    for i, ind in enumerate(indices):
        if isinstance(ind, torch.onnx._internal.exporter._tensors.SymbolicTensor):  # pylint: disable=protected-access
            ind.dtype = ir.DataType.INT64
        ind, expanded = _make_range_or_cast(ind, shape_x, False, i)
        if expanded:
            exped.append((i, ind))
            expand_value_shape.append(op.Shape(x, start=i, end=i + 1))
            reshape_value_shape2.append(_1dint(1))
        else:
            expand_value_shape.append(_1dint(1))
            reshape_value_shape2.append(op.Shape(ind))
            fixed.append((i, ind))

    reshape_value_shape1 = [_1dint(1)] * len(indices)
    if len(fixed) <= 1:
        reshape_value_shape1 = None
    elif fixed:
        reshape_value_shape1[fixed[-1][0]] = _1dint(-1)

    def _mkstride(x, i):
        # Number of elements between two consecutive steps along dimension i.
        if i >= len(x.shape) - 1:
            return _1dint(1)
        if i == len(x.shape) - 2:
            return op.Shape(x, start=i + 1)
        return op.ReduceProd(op.Shape(x, start=i + 1), keepdims=1)

    shape = [1] * (len(x.shape) + 1)
    r_fixed = []
    if fixed:
        # All fixed (non-None) indices broadcast against each other on the
        # trailing axis.
        new_shape = shape.copy()
        new_shape[-1] = -1
        r_fixed = [op.Reshape(op.Mul(_mkstride(x, i), f), new_shape) for i, f in fixed]

    r_exped = []
    for i, e in exped:
        # Each expanded (None) index gets its own broadcasting axis.
        new_shape = shape.copy()
        new_shape[i] = -1
        r_exped.append(op.Reshape(op.Mul(_mkstride(x, i), e), new_shape))

    # Final sum: broadcast-add the per-dimension contributions into the full
    # grid of flat indices.
    unflat = None
    for a in [*r_fixed, *r_exped]:
        if unflat is None:
            unflat = a
            continue
        unflat = op.Add(unflat, a)

    # Broadcast values to the same grid, then scatter into the flattened x.
    expanded_values = values
    if reshape_value_shape1 is not None:
        expanded_values = op.Reshape(expanded_values, op.Concat(*reshape_value_shape1, axis=0))
    expanded_values = op.Expand(expanded_values, op.Concat(*expand_value_shape, axis=0))
    flat_ind = op.Reshape(unflat, _1dint(-1))
    expanded_values = op.Reshape(expanded_values, _1dint(-1))
    flat_x = op.Reshape(x, _1dint(-1))
    scat_kwargs = {"reduction": "add"} if accumulate else {}
    flat_up_x = op.ScatterElements(flat_x, flat_ind, expanded_values, **scat_kwargs)
    return op.Reshape(flat_up_x, op.Shape(x))
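The same flattening strategy, sketched in numpy (assumed toy shapes; dim 0 plays the role of a None index that gets expanded to a Range):

import numpy as np

x = np.zeros((2, 4, 5), dtype=np.float32)
index1 = np.array([1, 2])                     # fixed index on dim 1
index2 = np.array([3, 4])                     # fixed index on dim 2
update = np.array([10.0, 11.0])

strides = np.array(x.strides) // x.itemsize   # [20, 5, 1], elements per step
i0 = np.arange(x.shape[0]).reshape(-1, 1)     # expanded None index on dim 0
# Broadcast sum of index * stride over all dims -> grid of flat indices.
flat_ind = i0 * strides[0] + index1 * strides[1] + index2 * strides[2]

flat_x = x.reshape(-1)                        # view of x
flat_x[flat_ind.reshape(-1)] = np.broadcast_to(update, flat_ind.shape).reshape(-1)

expected = np.zeros((2, 4, 5), dtype=np.float32)
expected[:, index1, index2] = update          # eager index_put equivalent
assert np.allclose(x, expected)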


@torch_op("aten::index_put", trace_only=True)
def aten_index_put_bool(
    self: TReal,
68 changes: 68 additions & 0 deletions tests/function_libs/torch_lib/e2e_ops_tests.py
@@ -5,6 +5,7 @@

import unittest

import numpy as np
import torch
from torch.onnx._internal.exporter import _testing

@@ -225,6 +226,73 @@ def forward(self, q, k, v):
        )
        _testing.assert_onnx_program(onnx_program)

    def test_index_put_dynamic(self):
        for dimension in [3, 4, 2]:
            with self.subTest(dimension=dimension):

                class Model(torch.nn.Module):
                    def __init__(self, dimension):
                        super().__init__()
                        self.params = torch.zeros(
                            (4, 5)
                            if dimension == 2
                            else ((2, 4, 5) if dimension == 3 else (1, 1, 4, 5))
                        )
                        self.dimension = dimension

                    def forward(self, update, index1, index2):
                        copy = self.params.clone()
                        if self.dimension == 2:
                            copy[index1, index2] = update
                        elif self.dimension == 3:
                            copy[:, index1, index2] = update
                        else:
                            copy[:, :, index1, index2] = update
                        return copy

                update = (torch.arange(2) + 10).reshape((2,)).to(torch.float32)
                index1 = torch.tensor([1, 2], dtype=torch.int64)
                index2 = torch.tensor([3, 4], dtype=torch.int64)
                feeds = dict(zip(["update", "index1", "index2"], (update, index1, index2)))
                onnx_program = torch.onnx.export(
                    Model(dimension),
                    tuple(feeds.values()),
                    input_names=["update", "index1", "index2"],
                    output_names=["output"],
                    opset_version=18,
                    dynamo=True,
                    dynamic_shapes={
                        "update": {0: "dn"},
                        "index1": {0: "dn"},
                        "index2": {0: "dn"},
                    },
                )
                _testing.assert_onnx_program(onnx_program)

    def test_index_put_scatter_nd(self):
        class Model(torch.nn.Module):
            def forward(self, x, index, update):
                x = x.clone()
                return torch.ops.aten.index_put(x, [None, index, None], update)

        shape = (2, 3, 2)
        N = int(np.prod(shape))
        x = torch.arange(N, dtype=torch.float32).reshape(shape)
        update = (torch.arange(N, dtype=torch.float32).reshape(shape) + 1) * 100
        index = ((torch.arange(shape[-2])).to(torch.int64) + 1) % shape[-2]

        feeds = dict(zip(["x", "index", "update"], (x, index, update)))
        onnx_program = torch.onnx.export(
            Model(),
            tuple(feeds.values()),
            input_names=["x", "index", "update"],
            output_names=["output"],
            opset_version=18,
            dynamo=True,
            dynamic_shapes=({0: "a", 1: "b", 2: "c"}, {0: "d"}, {0: "e", 1: "f", 2: "g"}),
        )
        _testing.assert_onnx_program(onnx_program)

    def test_bitwise_and_scalar(self):
        class Model(torch.nn.Module):
            def forward(self, x):