Skip to content

Commit

Permalink
fix: reset broadcasting issues in jax backend setitem (#27909)
Browse files Browse the repository at this point in the history
  • Loading branch information
Ishticode authored Jan 14, 2024
1 parent f3dd7eb commit fec52dc
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions ivy/functional/backends/jax/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,7 @@ def _mask_to_index(query, x):
raise ivy.exceptions.IvyException("too many indices")
elif not len(query.shape):
query = jnp.tile(query, x.shape[0])
expected_shape = x[query].shape
return jnp.where(query), expected_shape
return jnp.where(query)


def get_item(
Expand All @@ -75,7 +74,7 @@ def get_item(
return jnp.array([], dtype=x.dtype)
else:
return jnp.expand_dims(x, 0)
query, _ = _mask_to_index(query, x)
query = _mask_to_index(query, x)
elif isinstance(query, list):
query = (query,)
return x.__getitem__(query)
Expand All @@ -90,9 +89,10 @@ def set_item(
copy: Optional[bool] = False,
) -> JaxArray:
if ivy.is_array(query) and ivy.is_bool_dtype(query):
query, expected_shape = _mask_to_index(query, x)
if ivy.is_array(val):
val = _broadcast_to(val, expected_shape)._data
query = _mask_to_index(query, x)
expected_shape = x[query].shape
if ivy.is_array(val):
val = _broadcast_to(val, expected_shape)._data
ret = x.at[query].set(val)
if copy:
return ret
Expand Down

0 comments on commit fec52dc

Please sign in to comment.