index_add #13559

Merged · 3 commits · Apr 10, 2023
@@ -280,3 +280,45 @@ def where(condition, input=None, other=None):
@to_ivy_arrays_and_back
def conj(input):
return ivy.conj(input)


@to_ivy_arrays_and_back
def index_add(input, dim, index, source, *, alpha=1, out=None):
    # Potential Bug:
    # There is a precision issue with the torch backend (not caused by
    # ivy) where the fractional part of a float16 addend is rounded away
    # as the other operand grows in magnitude. float16 carries roughly 3
    # decimal digits, and above 32 its representable values are spaced
    # 1/32 apart, so in the last example below the 0.014 is lost entirely:
    #
    # >>> a = torch.tensor(-14., dtype=torch.float16)
    # >>> b = torch.tensor(1.014, dtype=torch.float16)
    # >>> a+b
    # tensor(-12.9844, dtype=torch.float16)
    # >>> a = torch.tensor(-24., dtype=torch.float16)
    # >>> a+b
    # tensor(-22.9844, dtype=torch.float16)
    # >>> a = torch.tensor(-34., dtype=torch.float16)
    # >>> a+b
    # tensor(-33., dtype=torch.float16)
    # Move the indexed axis to the front so updates can be built row-wise.
    input = ivy.swapaxes(input, dim, 0)
    source = ivy.swapaxes(source, dim, 0)
    _to_adds = []
    # Pair each target index with its position in `source` and sort by target,
    # so that repeated indices become adjacent and can be summed together.
    index = sorted(zip(ivy.to_list(index), range(len(index))), key=(lambda x: x[0]))
    while index:
        _curr_idx = index[0][0]
        # Rows that receive no update are filled with zeros.
        while len(_to_adds) < _curr_idx:
            _to_adds.append(ivy.zeros_like(source[0]))
        _to_add_cum = ivy.get_item(source, index[0][1])
        # Accumulate every source row that targets the same index.
        while (1 < len(index)) and (index[0][0] == index[1][0]):
            _to_add_cum = ivy.add(_to_add_cum, ivy.get_item(source, index.pop(1)[1]))
        index.pop(0)
        _to_adds.append(_to_add_cum)
    # Zero-pad the updates up to the full length of the indexed axis.
    while len(_to_adds) < input.shape[0]:
        _to_adds.append(ivy.zeros_like(source[0]))
    _to_adds = ivy.stack(_to_adds)
    if len(input.shape) < 2:
        # Flatten because the paddle backend treats scalars as 1-d arrays.
        _to_adds = ivy.flatten(_to_adds)

    # out = input + alpha * updates, then restore the original axis order.
    ret = ivy.add(input, _to_adds, alpha=alpha)
    ret = ivy.swapaxes(ret, 0, dim, out=out)
    return ret
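
For orientation, here is a minimal usage sketch of the new frontend. It assumes the function is re-exported through ivy's torch frontend namespace like its siblings in this file, and uses the numpy backend purely for illustration; it is not part of the diff.

# Illustrative only: the import path below mirrors where ivy's torch
# frontends live, and is an assumption rather than part of this PR.
import ivy
from ivy.functional.frontends import torch as torch_frontend

ivy.set_backend("numpy")
x = ivy.zeros((5, 3))
s = ivy.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]])
idx = ivy.array([0, 4, 2])
# Row i of `s` is added (scaled by alpha) to row idx[i] of `x`:
# result rows 0, 2 and 4 become s[0], s[2] and s[1] respectively.
out = torch_frontend.index_add(x, 0, idx, s, alpha=1)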
@@ -1217,3 +1217,94 @@ def test_torch_conj(
on_device=on_device,
input=x[0],
)


@st.composite
def _arrays_dim_idx_n_dtypes(draw):
    num_dims = draw(st.shared(helpers.ints(min_value=1, max_value=4), key="num_dims"))
    num_arrays = 2
    common_shape = draw(
        helpers.lists(
            x=helpers.ints(min_value=2, max_value=3),
            min_size=num_dims - 1,
            max_size=num_dims - 1,
        )
    )
    _dim = draw(helpers.ints(min_value=0, max_value=num_dims - 1))
    # Two distinct sizes for the indexed axis: the smaller array becomes
    # `source` and the larger `input` (see test_torch_index_add below).
    unique_dims = draw(
        helpers.lists(
            x=helpers.ints(min_value=2, max_value=3),
            min_size=num_arrays,
            max_size=num_arrays,
        )
    )

    min_dim = min(unique_dims)
    max_dim = max(unique_dims)
    # Indices target the larger array's axis; exclude_min=False keeps 0 drawable.
    _idx = draw(
        helpers.array_values(
            shape=min_dim,
            dtype="int64",
            min_value=0,
            max_value=max_dim,
            exclude_min=False,
        )
    )

    xs = list()
    available_input_types = draw(helpers.get_dtypes("numeric"))
    # float16 summation is unstable across backends; see the note in the
    # frontend implementation.
    available_input_types.remove("float16")
    input_dtypes = draw(
        helpers.array_dtypes(
            available_dtypes=available_input_types,
            num_arrays=num_arrays,
            shared_dtype=True,
        )
    )
    for ud, dt in zip(unique_dims, input_dtypes):
        # Both arrays share `common_shape` except along the indexed axis.
        x = draw(
            helpers.array_values(
                shape=common_shape[:_dim] + [ud] + common_shape[_dim:],
                dtype=dt,
                large_abs_safety_factor=2.5,
                small_abs_safety_factor=2.5,
                safety_factor_scale="log",
            )
        )
        xs.append(x)
    return xs, input_dtypes, _dim, _idx
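
For intuition about what this strategy feeds the test, torch.index_add's contract can be phrased as a small NumPy reference. This sketch is illustrative only and is not part of the PR; the helper name index_add_ref is invented here.

import numpy as np

def index_add_ref(input, dim, index, source, alpha=1):
    # out[..., index[i], ...] += alpha * source[..., i, ...] along `dim`,
    # accumulating when the same index appears more than once
    # (np.add.at performs unbuffered, duplicate-safe addition).
    out = np.array(input, copy=True)
    view = np.moveaxis(out, dim, 0)  # a view: writes land in `out`
    np.add.at(view, np.asarray(index), alpha * np.moveaxis(np.asarray(source), dim, 0))
    return out

# A (5, 3) input updated by a (3, 3) source along dim 0:
x = np.zeros((5, 3))
s = np.arange(9.0).reshape(3, 3)
print(index_add_ref(x, 0, [0, 4, 2], s))  # rows 0, 4, 2 receive s[0], s[1], s[2]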


# index_add
@handle_frontend_test(
    fn_tree="torch.index_add",
    xs_dtypes_dim_idx=_arrays_dim_idx_n_dtypes(),
    alpha=st.integers(min_value=1, max_value=2),
)
def test_torch_index_add(
    *,
    xs_dtypes_dim_idx,
    alpha,
    on_device,
    fn_tree,
    frontend,
    test_flags,
):
    xs, input_dtypes, axis, indices = xs_dtypes_dim_idx
    # The array with the smaller size along `axis` acts as `source`; the
    # larger one is the `input` being updated.
    if xs[0].shape[axis] < xs[1].shape[axis]:
        source, input = xs
    else:
        input, source = xs
    helpers.test_frontend_function(
        # dtypes follow the argument order: input, index (int64), source.
        input_dtypes=[input_dtypes[0], "int64", input_dtypes[1]],
        frontend=frontend,
        test_flags=test_flags,
        fn_tree=fn_tree,
        on_device=on_device,
        rtol=1e-03,
        input=input,
        dim=axis,
        index=indices,
        source=source,
        alpha=alpha,
    )
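
To make the duplicate-index behaviour concrete (the case handled by the inner while-loop in the frontend implementation), here is a tiny self-contained check using NumPy's unbuffered np.add.at as a known-good model of scatter-add accumulation; it is an illustration, not part of the test suite.

import numpy as np

# Repeated indices accumulate rather than overwrite:
x = np.zeros(3)
np.add.at(x, np.array([1, 1]), np.array([10.0, 5.0]))
print(x)  # [ 0. 15.  0.] -- both source values land on index 1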