From 0397298bb78f8c3e54aa5ec45fd600dc33d555bb Mon Sep 17 00:00:00 2001 From: Omar Salman Date: Tue, 14 Oct 2025 17:01:23 +0500 Subject: [PATCH] ENH allow Python scalars for functions with multiple arrays --- src/array_api_extra/_lib/_utils/_helpers.py | 4 ++-- tests/test_funcs.py | 13 +++++++++++++ 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/src/array_api_extra/_lib/_utils/_helpers.py b/src/array_api_extra/_lib/_utils/_helpers.py index fbe986a1..22da8259 100644 --- a/src/array_api_extra/_lib/_utils/_helpers.py +++ b/src/array_api_extra/_lib/_utils/_helpers.py @@ -212,10 +212,10 @@ def asarrays( } kind = same_dtype[type(cast(complex, b))] if xp.isdtype(a.dtype, kind): - xb = xp.asarray(b, dtype=a.dtype) + xb = xp.asarray(b, dtype=a.dtype, device=_compat.device(a)) else: # Undefined behaviour. Let the function deal with it, if it can. - xb = xp.asarray(b) + xb = xp.asarray(b, device=_compat.device(a)) else: # Neither a nor b are Array API objects. diff --git a/tests/test_funcs.py b/tests/test_funcs.py index ac7f4d8c..a80dd3e0 100644 --- a/tests/test_funcs.py +++ b/tests/test_funcs.py @@ -888,6 +888,19 @@ def test_device(self, xp: ModuleType, device: Device, equal_nan: bool): b = xp.asarray([1e-9, 1e-4, xp.nan], device=device) res = isclose(a, b, equal_nan=equal_nan) assert get_device(res) == device + + def test_array_on_device_with_scalar(self, xp: ModuleType, device: Device): + a = xp.asarray([0.01, 0.5, 0.8, 0.9, 1.00001], device=device) + b = 1 + res = isclose(a, b) + assert get_device(res) == device + xp_assert_equal(res, xp.asarray([False, False, False, False, True])) + + a = 0.1 + b = xp.asarray([0.01, 0.5, 0.8, 0.9, 0.100001], device=device) + res = isclose(a, b) + assert get_device(res) == device + xp_assert_equal(res, xp.asarray([False, False, False, False, True])) class TestKron: