diff --git a/src/array_api_extra/_lib/_utils/_helpers.py b/src/array_api_extra/_lib/_utils/_helpers.py index fbe986a1..22da8259 100644 --- a/src/array_api_extra/_lib/_utils/_helpers.py +++ b/src/array_api_extra/_lib/_utils/_helpers.py @@ -212,10 +212,10 @@ def asarrays( } kind = same_dtype[type(cast(complex, b))] if xp.isdtype(a.dtype, kind): - xb = xp.asarray(b, dtype=a.dtype) + xb = xp.asarray(b, dtype=a.dtype, device=_compat.device(a)) else: # Undefined behaviour. Let the function deal with it, if it can. - xb = xp.asarray(b) + xb = xp.asarray(b, device=_compat.device(a)) else: # Neither a nor b are Array API objects. diff --git a/tests/test_funcs.py b/tests/test_funcs.py index ac7f4d8c..a80dd3e0 100644 --- a/tests/test_funcs.py +++ b/tests/test_funcs.py @@ -888,6 +888,19 @@ def test_device(self, xp: ModuleType, device: Device, equal_nan: bool): b = xp.asarray([1e-9, 1e-4, xp.nan], device=device) res = isclose(a, b, equal_nan=equal_nan) assert get_device(res) == device + + def test_array_on_device_with_scalar(self, xp: ModuleType, device: Device): + a = xp.asarray([0.01, 0.5, 0.8, 0.9, 1.00001], device=device) + b = 1 + res = isclose(a, b) + assert get_device(res) == device + xp_assert_equal(res, xp.asarray([False, False, False, False, True])) + + a = 0.1 + b = xp.asarray([0.01, 0.5, 0.8, 0.9, 0.100001], device=device) + res = isclose(a, b) + assert get_device(res) == device + xp_assert_equal(res, xp.asarray([False, False, False, False, True])) class TestKron: