Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions array_api_tests/hypothesis_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,9 +454,12 @@ def scalars(draw, dtypes, finite=False, **kwds):
"""
Strategy to generate a scalar that matches a dtype strategy

dtypes should be one of the shared_* dtypes strategies.
dtypes should be one of the shared_* dtypes strategies or a sequence of dtypes.
"""
dtype = draw(dtypes)
if isinstance(dtypes, Sequence):
dtype = draw(sampled_from(dtypes))
else:
dtype = draw(dtypes)
mM = kwds.pop('mM', None)
if dh.is_int_dtype(dtype):
if mM is None:
Expand Down
73 changes: 66 additions & 7 deletions array_api_tests/test_searching_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,37 +243,96 @@ def test_where(shapes, dtypes, data):
@pytest.mark.min_version("2023.12")
@given(data=st.data())
def test_searchsorted(data):
# TODO: test side="right"
# TODO: Allow different dtypes for x1 and x2
x1_dtype = data.draw(st.sampled_from(dh.real_dtypes))
_x1 = data.draw(
st.lists(xps.from_dtype(dh.default_float), min_size=1, unique=True),
st.lists(
xps.from_dtype(x1_dtype, allow_nan=False, allow_infinity=False),
min_size=1,
unique=True
),
label="_x1",
)
x1 = xp.asarray(_x1, dtype=dh.default_float)
x1 = xp.asarray(_x1, dtype=x1_dtype)
if data.draw(st.booleans(), label="use sorter?"):
sorter = xp.argsort(x1)
else:
sorter = None
x1 = xp.sort(x1)
note(f"{x1=}")

x2 = data.draw(
st.lists(st.sampled_from(_x1), unique=True, min_size=1).map(
lambda o: xp.asarray(o, dtype=dh.default_float)
lambda o: xp.asarray(o, dtype=x1_dtype)
),
label="x2",
)
# make x2.ndim > 1, if it makes sense
factors = hh._factorize(x2.shape[0])
if len(factors) > 1:
x2 = xp.reshape(x2, tuple(factors))

repro_snippet = ph.format_snippet(f"xp.searchsorted({x1!r}, {x2!r}, sorter={sorter!r})")
kw = data.draw(hh.kwargs(side=st.sampled_from(["left", "right"])))

repro_snippet = ph.format_snippet(
f"xp.searchsorted({x1!r}, {x2!r}, sorter={sorter!r}, **kw) with {kw=}"
)
try:
out = xp.searchsorted(x1, x2, sorter=sorter)
out = xp.searchsorted(x1, x2, sorter=sorter, **kw)

ph.assert_dtype(
"searchsorted",
in_dtype=[x1.dtype, x2.dtype],
out_dtype=out.dtype,
expected=xp.__array_namespace_info__().default_dtypes()["indexing"],
)
# TODO: shapes and values testing
# TODO: values testing
ph.assert_shape("searchsorted", out_shape=out.shape, expected=x2.shape)
except Exception as exc:
exc.add_note(repro_snippet)
raise


### @pytest.mark.min_version("2025.12")
@given(data=st.data())
def test_searchsorted_with_scalars(data):
# 1. draw x1, sorter and side exactly the same as in test_searchsorted
x1_dtype = data.draw(st.sampled_from(dh.real_dtypes))
_x1 = data.draw(
st.lists(
xps.from_dtype(x1_dtype, allow_nan=False, allow_infinity=False),
min_size=1,
unique=True
),
label="_x1",
)
x1 = xp.asarray(_x1, dtype=x1_dtype)
if data.draw(st.booleans(), label="use sorter?"):
sorter = xp.argsort(x1)
else:
sorter = None
x1 = xp.sort(x1)

kw = data.draw(hh.kwargs(side=st.sampled_from(["left", "right"])))

# 2. draw x2, a real-valued scalar
x2 = data.draw(hh.scalars(st.just(x1.dtype), finite=True))

# 3. testing: similar to test_searchsorted, modulo `out.shape == ()`
repro_snippet = ph.format_snippet(
f"xp.searchsorted({x1!r}, {x2!r}, sorter={sorter!r}, **kw) with {kw = }"
)
try:
out = xp.searchsorted(x1, x2, sorter=sorter, **kw)

ph.assert_dtype(
"searchsorted",
in_dtype=[x1.dtype], #, x2.dtype
out_dtype=out.dtype,
expected=xp.__array_namespace_info__().default_dtypes()["indexing"],
)
# TODO: values testing
ph.assert_shape("searchsorted", out_shape=out.shape, expected=())
except Exception as exc:
exc.add_note(repro_snippet)
raise
Loading