Skip to content

Commit

Permalink
fix: broadcasting issue in compositional set_item + added example to …
Browse files Browse the repository at this point in the history
…set_item test (#28181)
  • Loading branch information
mattbarrett98 committed Feb 7, 2024
1 parent b4defbc commit dc7be7e
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 12 deletions.
10 changes: 0 additions & 10 deletions ivy/functional/ivy/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -2931,16 +2931,6 @@ def _broadcast_to(input, target_shape):
return ivy.reshape(input, target_shape)
else:
input = input if len(input.shape) else ivy.expand_dims(input, axis=0)
new_dims = ()
i_i = len(input.shape) - 1
for i_t in range(len(target_shape) - 1, -1, -1):
if len(input.shape) + len(new_dims) >= len(target_shape):
break
if i_i < 0 or target_shape[i_t] != input.shape[i_i]:
new_dims += (i_t,)
else:
i_i -= 1
input = ivy.expand_dims(input, axis=new_dims)
return ivy.broadcast_to(input, target_shape)


Expand Down
33 changes: 31 additions & 2 deletions ivy_tests/test_ivy/test_functional/test_core/test_general.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from types import SimpleNamespace

import pytest
from hypothesis import given, assume, strategies as st
from hypothesis import given, assume, example, strategies as st
import numpy as np
from collections.abc import Sequence

Expand All @@ -15,7 +15,11 @@
import ivy

import ivy_tests.test_ivy.helpers as helpers
from ivy_tests.test_ivy.helpers import handle_test, BackendHandler
from ivy_tests.test_ivy.helpers import (
handle_test,
BackendHandler,
test_parameter_flags as pf,
)
from ivy_tests.test_ivy.helpers.assertions import assert_all_close
from ivy_tests.test_ivy.test_functional.test_core.test_elementwise import pow_helper

Expand Down Expand Up @@ -1643,6 +1647,31 @@ def test_set_inplace_mode(mode):
container_flags=st.just([False]),
test_with_copy=st.just(True),
)
@example(
dtypes_x_query_val=(
["int32", "int32"],
np.ones((1, 3, 3, 3)),
(slice(None, None, None), slice(None, None, None), slice(None, None, None), 1),
np.zeros((3, 1)),
),
copy=False,
fn_name="set_item",
test_flags=pf.FunctionTestFlags(
ground_truth_backend="numpy",
num_positional_args=3,
instance_method=False,
with_out=False,
with_copy=False,
test_gradients=False,
test_trace=False,
transpile=False,
as_variable=[False],
native_arrays=[False],
container=[False],
precision_mode=False,
test_cython_wrapper=False,
),
)
def test_set_item(
dtypes_x_query_val,
copy,
Expand Down

0 comments on commit dc7be7e

Please sign in to comment.