Skip to content

Commit

Permalink
Fix test_take to make axis optional when ndim == 1
Browse files Browse the repository at this point in the history
I didn't explicitly test axis=None because it's not clear to me that should
actually be supported, given that that's the same as axis=0.
  • Loading branch information
asmeurer authored and honno committed Feb 28, 2024
1 parent 6a1f943 commit ca44b2a
Showing 1 changed file with 9 additions and 2 deletions.
11 changes: 9 additions & 2 deletions array_api_tests/test_indexing_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,22 @@ def test_take(x, data):
# * negative axis
# * negative indices
# * different dtypes for indices
axis = data.draw(st.integers(0, max(x.ndim - 1, 0)), label="axis")

# axis is optional but only if x.ndim == 1
_axis_st = st.integers(0, max(x.ndim - 1, 0))
if x.ndim == 1:
kw = data.draw(hh.kwargs(axis=_axis_st))
else:
kw = {"axis": data.draw(_axis_st)}
axis = kw.get("axis", 0)
_indices = data.draw(
st.lists(st.integers(0, x.shape[axis] - 1), min_size=1, unique=True),
label="_indices",
)
indices = xp.asarray(_indices, dtype=dh.default_int)
note(f"{indices=}")

out = xp.take(x, indices, axis=axis)
out = xp.take(x, indices, **kw)

ph.assert_dtype("take", in_dtype=x.dtype, out_dtype=out.dtype)
ph.assert_shape(
Expand Down

0 comments on commit ca44b2a

Please sign in to comment.