Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions array_api_strict/_searching_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from ._array_object import Array
from ._dtypes import _real_numeric_dtypes, _result_type
from ._dtypes import bool as _bool
from ._flags import requires_api_version, requires_data_dependent_shapes
from ._flags import requires_api_version, requires_data_dependent_shapes, get_array_api_strict_flags
from ._helpers import _maybe_normalize_py_scalars


Expand Down Expand Up @@ -64,7 +64,7 @@ def count_nonzero(
@requires_api_version('2023.12')
def searchsorted(
x1: Array,
x2: Array,
x2: Array | int | float,
/,
*,
side: Literal["left", "right"] = "left",
Expand All @@ -75,6 +75,12 @@ def searchsorted(

See its docstring for more information.
"""
flags = get_array_api_strict_flags()
if flags["api_version"] >= "2025.12":

if isinstance(x2, bool | int | float | complex):
x2 = x1._promote_scalar(x2)

if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes:
raise TypeError("Only real numeric dtypes are allowed in searchsorted")

Expand Down
Loading