From 33b0efe13fbfecef055274175338280ad51c364b Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Mon, 17 Nov 2025 12:08:33 +0100 Subject: [PATCH 1/4] ENH: test side=left,right in searchsorted --- array_api_tests/test_searching_functions.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/array_api_tests/test_searching_functions.py b/array_api_tests/test_searching_functions.py index 8df475d8..1283d9fa 100644 --- a/array_api_tests/test_searching_functions.py +++ b/array_api_tests/test_searching_functions.py @@ -243,7 +243,6 @@ def test_where(shapes, dtypes, data): @pytest.mark.min_version("2023.12") @given(data=st.data()) def test_searchsorted(data): - # TODO: test side="right" # TODO: Allow different dtypes for x1 and x2 _x1 = data.draw( st.lists(xps.from_dtype(dh.default_float), min_size=1, unique=True), @@ -262,10 +261,13 @@ def test_searchsorted(data): ), label="x2", ) + kw = data.draw(hh.kwargs(side=st.sampled_from(["left", "right"]))) - repro_snippet = ph.format_snippet(f"xp.searchsorted({x1!r}, {x2!r}, sorter={sorter!r})") + repro_snippet = ph.format_snippet( + f"xp.searchsorted({x1!r}, {x2!r}, sorter={sorter!r}, **kw) with {kw=}" + ) try: - out = xp.searchsorted(x1, x2, sorter=sorter) + out = xp.searchsorted(x1, x2, sorter=sorter, **kw) ph.assert_dtype( "searchsorted", @@ -273,7 +275,8 @@ def test_searchsorted(data): out_dtype=out.dtype, expected=xp.__array_namespace_info__().default_dtypes()["indexing"], ) - # TODO: shapes and values testing + # TODO: x2.ndim > 1, values testing + ph.assert_shape("searchsorted", out_shape=out.shape, expected=x2.shape) except Exception as exc: exc.add_note(repro_snippet) raise From 4bf7e34c30b2dd5738ff84338684aa961b49b267 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Mon, 17 Nov 2025 13:29:41 +0100 Subject: [PATCH 2/4] ENH: test searchsorted with x2.ndim > 1 --- array_api_tests/test_searching_functions.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/array_api_tests/test_searching_functions.py b/array_api_tests/test_searching_functions.py index 1283d9fa..af079591 100644 --- a/array_api_tests/test_searching_functions.py +++ b/array_api_tests/test_searching_functions.py @@ -255,12 +255,18 @@ def test_searchsorted(data): sorter = None x1 = xp.sort(x1) note(f"{x1=}") + x2 = data.draw( st.lists(st.sampled_from(_x1), unique=True, min_size=1).map( lambda o: xp.asarray(o, dtype=dh.default_float) ), label="x2", ) + # make x2.ndim > 1, if it makes sense + factors = hh._factorize(x2.shape[0]) + if len(factors) > 1: + x2 = xp.reshape(x2, tuple(factors)) + kw = data.draw(hh.kwargs(side=st.sampled_from(["left", "right"]))) repro_snippet = ph.format_snippet( @@ -275,7 +281,7 @@ def test_searchsorted(data): out_dtype=out.dtype, expected=xp.__array_namespace_info__().default_dtypes()["indexing"], ) - # TODO: x2.ndim > 1, values testing + # TODO: values testing ph.assert_shape("searchsorted", out_shape=out.shape, expected=x2.shape) except Exception as exc: exc.add_note(repro_snippet) From 78c969f9fd48cf4e35270c66dabaaa6d8ade6874 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Mon, 17 Nov 2025 14:04:04 +0100 Subject: [PATCH 3/4] ENH: searchsorted: draw x1.dtype --- array_api_tests/test_searching_functions.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/array_api_tests/test_searching_functions.py b/array_api_tests/test_searching_functions.py index af079591..61c6f436 100644 --- a/array_api_tests/test_searching_functions.py +++ b/array_api_tests/test_searching_functions.py @@ -244,11 +244,12 @@ def test_where(shapes, dtypes, data): @given(data=st.data()) def test_searchsorted(data): # TODO: Allow different dtypes for x1 and x2 + x1_dtype = data.draw(st.sampled_from(dh.real_dtypes)) _x1 = data.draw( - st.lists(xps.from_dtype(dh.default_float), min_size=1, unique=True), + st.lists(xps.from_dtype(x1_dtype), min_size=1, unique=True), label="_x1", ) - x1 = xp.asarray(_x1, dtype=dh.default_float) + x1 = xp.asarray(_x1, dtype=x1_dtype) if data.draw(st.booleans(), label="use sorter?"): sorter = xp.argsort(x1) else: @@ -258,7 +259,7 @@ def test_searchsorted(data): x2 = data.draw( st.lists(st.sampled_from(_x1), unique=True, min_size=1).map( - lambda o: xp.asarray(o, dtype=dh.default_float) + lambda o: xp.asarray(o, dtype=x1_dtype) ), label="x2", ) From 5bd524be003bdab3224e6827dc8f9716cc3444ac Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Mon, 17 Nov 2025 18:42:53 +0100 Subject: [PATCH 4/4] MAINT: searchsorted: restrict inputs to be finite real values --- array_api_tests/test_searching_functions.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/array_api_tests/test_searching_functions.py b/array_api_tests/test_searching_functions.py index 61c6f436..fdebd84b 100644 --- a/array_api_tests/test_searching_functions.py +++ b/array_api_tests/test_searching_functions.py @@ -246,7 +246,11 @@ def test_searchsorted(data): # TODO: Allow different dtypes for x1 and x2 x1_dtype = data.draw(st.sampled_from(dh.real_dtypes)) _x1 = data.draw( - st.lists(xps.from_dtype(x1_dtype), min_size=1, unique=True), + st.lists( + xps.from_dtype(x1_dtype, allow_nan=False, allow_infinity=False), + min_size=1, + unique=True + ), label="_x1", ) x1 = xp.asarray(_x1, dtype=x1_dtype)