From af1ab743a76982297bca063aac93d8e5f4aef88a Mon Sep 17 00:00:00 2001 From: Tom White Date: Wed, 22 May 2024 09:12:52 +0100 Subject: [PATCH] Implement NumPy's `__array_function__` protocol for array methods that are not in the Array API Standard (#468) * nanmean, nansum, pad --- cubed/array_api/array_object.py | 24 ++++++++++++++++++++++++ cubed/nan_functions.py | 9 +++++++-- cubed/pad.py | 7 +++++++ cubed/tests/test_nan_functions.py | 12 ++++++++---- cubed/tests/test_pad.py | 8 ++++++-- 5 files changed, 52 insertions(+), 8 deletions(-) diff --git a/cubed/array_api/array_object.py b/cubed/array_api/array_object.py index 3a2b23ff..f4ba89e6 100644 --- a/cubed/array_api/array_object.py +++ b/cubed/array_api/array_object.py @@ -29,6 +29,24 @@ 120 # cubed doesn't have a config module like dask does so hard-code this for now ) +_HANDLED_FUNCTIONS = {} + + +def implements(*numpy_functions): + """Register an __array_function__ implementation for cubed.Array + + Note that this is **only** used for functions that are not defined in the + Array API Standard. + """ + + def decorator(cubed_func): + for numpy_function in numpy_functions: + _HANDLED_FUNCTIONS[numpy_function] = cubed_func + + return cubed_func + + return decorator + class Array(CoreArray): """Chunked array backed by Zarr storage that conforms to the Python Array API standard.""" @@ -44,6 +62,12 @@ def __array__(self, dtype=None) -> np.ndarray: x = np.array(x) return x + def __array_function__(self, func, types, args, kwargs): + # Only dispatch to functions that are not defined in the Array API Standard + if func in _HANDLED_FUNCTIONS: + return _HANDLED_FUNCTIONS[func](*args, **kwargs) + return NotImplemented + def __repr__(self): return f"cubed.Array<{self.name}, shape={self.shape}, dtype={self.dtype}, chunks={self.chunks}>" diff --git a/cubed/nan_functions.py b/cubed/nan_functions.py index 2acd308b..2de161df 100644 --- a/cubed/nan_functions.py +++ b/cubed/nan_functions.py @@ -1,5 +1,6 @@ import numpy as np +from cubed.array_api.array_object import implements from cubed.array_api.dtypes import ( _numeric_dtypes, _signed_integer_dtypes, @@ -18,9 +19,12 @@ # https://github.com/data-apis/array-api/issues/621 -def nanmean(x, /, *, axis=None, keepdims=False, use_new_impl=True, split_every=None): +@implements(np.nanmean) +def nanmean( + x, /, *, axis=None, dtype=None, keepdims=False, use_new_impl=True, split_every=None +): """Compute the arithmetic mean along the specified axis, ignoring NaNs.""" - dtype = x.dtype + dtype = dtype or x.dtype intermediate_dtype = [("n", nxp.int64), ("total", nxp.float64)] return reduction( x, @@ -61,6 +65,7 @@ def _nannumel(x, **kwargs): return nxp.sum(~(nxp.isnan(x)), **kwargs) +@implements(np.nansum) def nansum( x, /, *, axis=None, dtype=None, keepdims=False, use_new_impl=True, split_every=None ): diff --git a/cubed/pad.py b/cubed/pad.py index c292c65e..afb0c4d9 100644 --- a/cubed/pad.py +++ b/cubed/pad.py @@ -1,6 +1,13 @@ +import numpy as np + +from cubed.array_api.array_object import implements from cubed.array_api.manipulation_functions import concat +# TODO: refactor once pad is standardized: +# https://github.com/data-apis/array-api/issues/187 + +@implements(np.pad) def pad(x, pad_width, mode=None, chunks=None): """Pad an array.""" if len(pad_width) != x.ndim: diff --git a/cubed/tests/test_nan_functions.py b/cubed/tests/test_nan_functions.py index 53264e79..f67ce71d 100644 --- a/cubed/tests/test_nan_functions.py +++ b/cubed/tests/test_nan_functions.py @@ -11,9 +11,11 @@ def spec(tmp_path): return cubed.Spec(tmp_path, allowed_mem=100000) -def test_nanmean(spec): +@pytest.mark.parametrize("namespace", [cubed, np]) +def test_nanmean(spec, namespace): a = xp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, xp.nan]], chunks=(2, 2), spec=spec) - b = cubed.nanmean(a) + b = namespace.nanmean(a) + assert isinstance(b, cubed.Array) assert_array_equal( b.compute(), np.nanmean(np.array([[1, 2, 3], [4, 5, 6], [7, 8, np.nan]])) ) @@ -26,9 +28,11 @@ def test_nanmean_allnan(spec): assert_array_equal(b.compute(), np.nanmean(np.array([np.nan]))) -def test_nansum(spec): +@pytest.mark.parametrize("namespace", [cubed, np]) +def test_nansum(spec, namespace): a = xp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, xp.nan]], chunks=(2, 2), spec=spec) - b = cubed.nansum(a) + b = namespace.nansum(a) + assert isinstance(b, cubed.Array) assert_array_equal( b.compute(), np.nansum(np.array([[1, 2, 3], [4, 5, 6], [7, 8, np.nan]])) ) diff --git a/cubed/tests/test_pad.py b/cubed/tests/test_pad.py index 7ba985f4..027348da 100644 --- a/cubed/tests/test_pad.py +++ b/cubed/tests/test_pad.py @@ -11,11 +11,15 @@ def spec(tmp_path): return cubed.Spec(tmp_path, allowed_mem=100000) -def test_pad(spec): +@pytest.mark.parametrize("namespace", [cubed, np]) +def test_pad(spec, namespace): an = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) a = xp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, 9]], chunks=(2, 2), spec=spec) - b = cubed.pad(a, ((1, 0), (0, 0)), mode="symmetric") + # check that we can dispatch via the numpy namespace (via __array_function__) + # since pad is not yet a part of the Array API Standard + b = namespace.pad(a, ((1, 0), (0, 0)), mode="symmetric") + assert isinstance(b, cubed.Array) assert b.chunks == ((2, 2), (2, 1)) assert_array_equal(b.compute(), np.pad(an, ((1, 0), (0, 0)), mode="symmetric"))