diff --git a/.github/workflows/array-api-tests.yml b/.github/workflows/array-api-tests.yml index 28a84fb0..0c1812e0 100644 --- a/.github/workflows/array-api-tests.yml +++ b/.github/workflows/array-api-tests.yml @@ -33,7 +33,6 @@ jobs: with: repository: data-apis/array-api-tests path: array-api-tests - ref: 2022.05.18 submodules: "true" - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v3 @@ -83,8 +82,9 @@ jobs: # don't run special cases yet array_api_tests/test_special_cases.py - # don't test signatures yet as some are not implemented + # don't test signatures or names yet as some are not implemented array_api_tests/test_signatures.py + array_api_tests/test_has_names.py # very slow array_api_tests/test_creation_functions.py::test_eye diff --git a/cubed/array_api/__init__.py b/cubed/array_api/__init__.py index eb1af8d8..4dfdc500 100644 --- a/cubed/array_api/__init__.py +++ b/cubed/array_api/__init__.py @@ -1,6 +1,6 @@ __all__ = [] -__array_api_version__ = "2021.12" +__array_api_version__ = "2022.12" __all__ += ["__array_api_version__"] @@ -48,12 +48,14 @@ "zeros_like", ] -from .data_type_functions import astype, can_cast, finfo, iinfo, result_type +from .data_type_functions import astype, can_cast, finfo, iinfo, isdtype, result_type -__all__ += ["astype", "can_cast", "finfo", "iinfo", "result_type"] +__all__ += ["astype", "can_cast", "finfo", "iinfo", "isdtype", "result_type"] from .dtypes import ( bool, + complex64, + complex128, float32, float64, int8, @@ -68,6 +70,8 @@ __all__ += [ "bool", + "complex64", + "complex128", "float32", "float64", "int8", @@ -97,6 +101,7 @@ bitwise_right_shift, bitwise_xor, ceil, + conj, cos, cosh, divide, @@ -107,6 +112,7 @@ floor_divide, greater, greater_equal, + imag, isfinite, isinf, isnan, @@ -126,6 +132,7 @@ not_equal, positive, pow, + real, remainder, round, sign, @@ -156,6 +163,7 @@ "bitwise_right_shift", "bitwise_xor", "ceil", + "conj", "cos", "cosh", "divide", @@ -166,6 +174,7 @@ "floor_divide", "greater", "greater_equal", + "imag", "isfinite", "isinf", "isnan", @@ -185,6 +194,7 @@ "not_equal", "positive", "pow", + "real", "remainder", "round", "sign", diff --git a/cubed/array_api/array_object.py b/cubed/array_api/array_object.py index c5491b53..3d449fdb 100644 --- a/cubed/array_api/array_object.py +++ b/cubed/array_api/array_object.py @@ -7,8 +7,10 @@ from cubed.array_api.data_type_functions import result_type from cubed.array_api.dtypes import ( _boolean_dtypes, + _complex_floating_dtypes, _dtype_categories, _floating_dtypes, + _integer_dtypes, _integer_or_boolean_dtypes, _numeric_dtypes, ) @@ -144,13 +146,13 @@ def __truediv__(self, other, /): return elemwise(np.divide, self, other, dtype=result_type(self, other)) def __floordiv__(self, other, /): - other = self._check_allowed_dtypes(other, "numeric", "__floordiv__") + other = self._check_allowed_dtypes(other, "real numeric", "__floordiv__") if other is NotImplemented: return other return elemwise(np.floor_divide, self, other, dtype=result_type(self, other)) def __mod__(self, other, /): - other = self._check_allowed_dtypes(other, "numeric", "__mod__") + other = self._check_allowed_dtypes(other, "real numeric", "__mod__") if other is NotImplemented: return other return elemwise(np.remainder, self, other, dtype=result_type(self, other)) @@ -349,9 +351,16 @@ def __bool__(self, /): raise TypeError("bool is only allowed on arrays with 0 dimensions") return bool(self.compute()) + def __complex__(self, /): + if self.ndim != 0: + raise TypeError("complex is only allowed on arrays with 0 dimensions") + return complex(self.compute()) + def __float__(self, /): if self.ndim != 0: raise TypeError("float is only allowed on arrays with 0 dimensions") + if self.dtype in _complex_floating_dtypes: + raise TypeError("float is not allowed on complex floating-point arrays") return float(self.compute()) def __index__(self, /): @@ -362,6 +371,8 @@ def __index__(self, /): def __int__(self, /): if self.ndim != 0: raise TypeError("int is only allowed on arrays with 0 dimensions") + if self.dtype in _complex_floating_dtypes: + raise TypeError("int is not allowed on complex floating-point arrays") return int(self.compute()) # Utility methods @@ -369,7 +380,7 @@ def __int__(self, /): def _check_allowed_dtypes(self, other, dtype_category, op): if self.dtype not in _dtype_categories[dtype_category]: raise TypeError(f"Only {dtype_category} dtypes are allowed in {op}") - if isinstance(other, (int, float, bool)): + if isinstance(other, (int, complex, float, bool)): other = self._promote_scalar(other) elif isinstance(other, CoreArray): if other.dtype not in _dtype_categories[dtype_category]: @@ -392,11 +403,23 @@ def _promote_scalar(self, scalar): raise TypeError( "Python int scalars cannot be promoted with bool arrays" ) + if self.dtype in _integer_dtypes: + info = np.iinfo(self.dtype) + if not (info.min <= scalar <= info.max): + raise OverflowError( + "Python int scalars must be within the bounds of the dtype for integer arrays" + ) + # int + array(floating) is allowed elif isinstance(scalar, float): if self.dtype not in _floating_dtypes: raise TypeError( "Python float scalars can only be promoted with floating-point arrays." ) + elif isinstance(scalar, complex): + if self.dtype not in _complex_floating_dtypes: + raise TypeError( + "Python complex scalars can only be promoted with complex floating-point arrays." + ) else: raise TypeError("'scalar' must be a Python scalar") return asarray(scalar, dtype=self.dtype, spec=self.spec) diff --git a/cubed/array_api/creation_functions.py b/cubed/array_api/creation_functions.py index 356152ce..1a717211 100644 --- a/cubed/array_api/creation_functions.py +++ b/cubed/array_api/creation_functions.py @@ -144,12 +144,13 @@ def full( # note that write_empty_chunks=False means no chunks are written to disk, so it is very efficient to create large arrays shape = normalize_shape(shape) if dtype is None: - if isinstance(fill_value, int): + # check bool first since True/False are instances of int and float + if isinstance(fill_value, bool): + dtype = np.bool_ + elif isinstance(fill_value, int): dtype = np.int64 elif isinstance(fill_value, float): dtype = np.float64 - elif isinstance(fill_value, bool): - dtype = np.bool_ else: raise TypeError("Invalid input to full") chunksize = to_chunksize(normalize_chunks(chunks, shape=shape, dtype=dtype)) diff --git a/cubed/array_api/data_type_functions.py b/cubed/array_api/data_type_functions.py index d0774568..1a7dd3a8 100644 --- a/cubed/array_api/data_type_functions.py +++ b/cubed/array_api/data_type_functions.py @@ -1,7 +1,22 @@ -import numpy.array_api as nxp +from dataclasses import dataclass + +import numpy as np +from numpy.array_api._typing import Dtype from cubed.core import CoreArray, map_blocks +from .dtypes import ( + _all_dtypes, + _boolean_dtypes, + _complex_floating_dtypes, + _integer_dtypes, + _numeric_dtypes, + _real_floating_dtypes, + _result_type, + _signed_integer_dtypes, + _unsigned_integer_dtypes, +) + def astype(x, dtype, /, *, copy=True): if not copy and dtype == x.dtype: @@ -14,21 +29,118 @@ def _astype(a, astype_dtype): def can_cast(from_, to, /): + # Copied from numpy.array_api + # TODO: replace with `nxp.can_cast` when NumPy 1.25 is widely used (e.g. in Xarray) + if isinstance(from_, CoreArray): from_ = from_.dtype - return nxp.can_cast(from_, to) + elif from_ not in _all_dtypes: + raise TypeError(f"{from_=}, but should be an array_api array or dtype") + if to not in _all_dtypes: + raise TypeError(f"{to=}, but should be a dtype") + try: + # We promote `from_` and `to` together. We then check if the promoted + # dtype is `to`, which indicates if `from_` can (up)cast to `to`. + dtype = _result_type(from_, to) + return to == dtype + except TypeError: + # _result_type() raises if the dtypes don't promote together + return False + + +@dataclass +class finfo_object: + bits: int + eps: float + max: float + min: float + smallest_normal: float + dtype: Dtype + + +@dataclass +class iinfo_object: + bits: int + max: int + min: int + dtype: Dtype def finfo(type, /): - return nxp.finfo(type) + # Copied from numpy.array_api + # TODO: replace with `nxp.finfo(type)` when NumPy 1.25 is widely used (e.g. in Xarray) + + fi = np.finfo(type) + return finfo_object( + fi.bits, + float(fi.eps), + float(fi.max), + float(fi.min), + float(fi.smallest_normal), + fi.dtype, + ) def iinfo(type, /): - return nxp.iinfo(type) + # Copied from numpy.array_api + # TODO: replace with `nxp.iinfo(type)` when NumPy 1.25 is widely used (e.g. in Xarray) + + ii = np.iinfo(type) + return iinfo_object(ii.bits, ii.max, ii.min, ii.dtype) + + +def isdtype(dtype, kind): + # Copied from numpy.array_api + # TODO: replace with `nxp.isdtype(dtype, kind)` when NumPy 1.25 is widely used (e.g. in Xarray) + + if isinstance(kind, tuple): + # Disallow nested tuples + if any(isinstance(k, tuple) for k in kind): + raise TypeError("'kind' must be a dtype, str, or tuple of dtypes and strs") + return any(isdtype(dtype, k) for k in kind) + elif isinstance(kind, str): + if kind == "bool": + return dtype in _boolean_dtypes + elif kind == "signed integer": + return dtype in _signed_integer_dtypes + elif kind == "unsigned integer": + return dtype in _unsigned_integer_dtypes + elif kind == "integral": + return dtype in _integer_dtypes + elif kind == "real floating": + return dtype in _real_floating_dtypes + elif kind == "complex floating": + return dtype in _complex_floating_dtypes + elif kind == "numeric": + return dtype in _numeric_dtypes + else: + raise ValueError(f"Unrecognized data type kind: {kind!r}") + elif kind in _all_dtypes: + return dtype == kind + else: + raise TypeError( + f"'kind' must be a dtype, str, or tuple of dtypes and strs, not {type(kind).__name__}" + ) def result_type(*arrays_and_dtypes): - # Use numpy.array_api promotion rules (stricter than numpy) - return nxp.result_type( - *(a.dtype if isinstance(a, CoreArray) else a for a in arrays_and_dtypes) - ) + # Copied from numpy.array_api + # TODO: replace with `nxp.result_type` when NumPy 1.25 is widely used (e.g. in Xarray) + + A = [] + for a in arrays_and_dtypes: + if isinstance(a, CoreArray): + a = a.dtype + elif isinstance(a, np.ndarray) or a not in _all_dtypes: + raise TypeError("result_type() inputs must be array_api arrays or dtypes") + A.append(a) + + if len(A) == 0: + raise ValueError("at least one array or dtype is required") + elif len(A) == 1: + return A[0] + else: + t = A[0] + for t2 in A[1:]: + t = _result_type(t, t2) + return t diff --git a/cubed/array_api/dtypes.py b/cubed/array_api/dtypes.py index 099ed402..b3263791 100644 --- a/cubed/array_api/dtypes.py +++ b/cubed/array_api/dtypes.py @@ -1,14 +1,73 @@ -# Use type code from numpy.array_api -from numpy.array_api._dtypes import ( # noqa: F401 - _boolean_dtypes, - _dtype_categories, - _floating_dtypes, - _integer_dtypes, - _integer_or_boolean_dtypes, - _numeric_dtypes, +# Copied from numpy.array_api +import numpy as np + +# Note: we use dtype objects instead of dtype classes. The spec does not +# require any behavior on dtypes other than equality. +int8 = np.dtype("int8") +int16 = np.dtype("int16") +int32 = np.dtype("int32") +int64 = np.dtype("int64") +uint8 = np.dtype("uint8") +uint16 = np.dtype("uint16") +uint32 = np.dtype("uint32") +uint64 = np.dtype("uint64") +float32 = np.dtype("float32") +float64 = np.dtype("float64") +complex64 = np.dtype("complex64") +complex128 = np.dtype("complex128") +# Note: This name is changed +bool = np.dtype("bool") + +_all_dtypes = ( + int8, + int16, + int32, + int64, + uint8, + uint16, + uint32, + uint64, + float32, + float64, + complex64, + complex128, + bool, +) +_boolean_dtypes = (bool,) +_real_floating_dtypes = (float32, float64) +_floating_dtypes = (float32, float64, complex64, complex128) +_complex_floating_dtypes = (complex64, complex128) +_integer_dtypes = (int8, int16, int32, int64, uint8, uint16, uint32, uint64) +_signed_integer_dtypes = (int8, int16, int32, int64) +_unsigned_integer_dtypes = (uint8, uint16, uint32, uint64) +_integer_or_boolean_dtypes = ( bool, + int8, + int16, + int32, + int64, + uint8, + uint16, + uint32, + uint64, +) +_real_numeric_dtypes = ( + float32, + float64, + int8, + int16, + int32, + int64, + uint8, + uint16, + uint32, + uint64, +) +_numeric_dtypes = ( float32, float64, + complex64, + complex128, int8, int16, int32, @@ -19,6 +78,96 @@ uint64, ) -_signed_integer_dtypes = (int8, int16, int32, int64) +_dtype_categories = { + "all": _all_dtypes, + "real numeric": _real_numeric_dtypes, + "numeric": _numeric_dtypes, + "integer": _integer_dtypes, + "integer or boolean": _integer_or_boolean_dtypes, + "boolean": _boolean_dtypes, + "real floating-point": _floating_dtypes, + "complex floating-point": _complex_floating_dtypes, + "floating-point": _floating_dtypes, +} -_unsigned_integer_dtypes = (uint8, uint16, uint32, uint64) +_promotion_table = { + (int8, int8): int8, + (int8, int16): int16, + (int8, int32): int32, + (int8, int64): int64, + (int16, int8): int16, + (int16, int16): int16, + (int16, int32): int32, + (int16, int64): int64, + (int32, int8): int32, + (int32, int16): int32, + (int32, int32): int32, + (int32, int64): int64, + (int64, int8): int64, + (int64, int16): int64, + (int64, int32): int64, + (int64, int64): int64, + (uint8, uint8): uint8, + (uint8, uint16): uint16, + (uint8, uint32): uint32, + (uint8, uint64): uint64, + (uint16, uint8): uint16, + (uint16, uint16): uint16, + (uint16, uint32): uint32, + (uint16, uint64): uint64, + (uint32, uint8): uint32, + (uint32, uint16): uint32, + (uint32, uint32): uint32, + (uint32, uint64): uint64, + (uint64, uint8): uint64, + (uint64, uint16): uint64, + (uint64, uint32): uint64, + (uint64, uint64): uint64, + (int8, uint8): int16, + (int8, uint16): int32, + (int8, uint32): int64, + (int16, uint8): int16, + (int16, uint16): int32, + (int16, uint32): int64, + (int32, uint8): int32, + (int32, uint16): int32, + (int32, uint32): int64, + (int64, uint8): int64, + (int64, uint16): int64, + (int64, uint32): int64, + (uint8, int8): int16, + (uint16, int8): int32, + (uint32, int8): int64, + (uint8, int16): int16, + (uint16, int16): int32, + (uint32, int16): int64, + (uint8, int32): int32, + (uint16, int32): int32, + (uint32, int32): int64, + (uint8, int64): int64, + (uint16, int64): int64, + (uint32, int64): int64, + (float32, float32): float32, + (float32, float64): float64, + (float64, float32): float64, + (float64, float64): float64, + (complex64, complex64): complex64, + (complex64, complex128): complex128, + (complex128, complex64): complex128, + (complex128, complex128): complex128, + (float32, complex64): complex64, + (float32, complex128): complex128, + (float64, complex64): complex128, + (float64, complex128): complex128, + (complex64, float32): complex64, + (complex64, float64): complex128, + (complex128, float32): complex128, + (complex128, float64): complex128, + (bool, bool): bool, +} + + +def _result_type(type1, type2): + if (type1, type2) in _promotion_table: + return _promotion_table[type1, type2] + raise TypeError(f"{type1} and {type2} cannot be type promoted together") diff --git a/cubed/array_api/elementwise_functions.py b/cubed/array_api/elementwise_functions.py index 62d5e8c5..b1d0c2b7 100644 --- a/cubed/array_api/elementwise_functions.py +++ b/cubed/array_api/elementwise_functions.py @@ -3,10 +3,17 @@ from cubed.array_api.data_type_functions import result_type from cubed.array_api.dtypes import ( _boolean_dtypes, + _complex_floating_dtypes, _floating_dtypes, _integer_dtypes, _integer_or_boolean_dtypes, _numeric_dtypes, + _real_floating_dtypes, + _real_numeric_dtypes, + complex64, + complex128, + float32, + float64, ) from cubed.core import elemwise @@ -14,7 +21,13 @@ def abs(x, /): if x.dtype not in _numeric_dtypes: raise TypeError("Only numeric dtypes are allowed in abs") - return elemwise(np.abs, x, dtype=x.dtype) + if x.dtype == complex64: + dtype = float32 + elif x.dtype == complex128: + dtype = float64 + else: + dtype = x.dtype + return elemwise(np.abs, x, dtype=dtype) def acos(x, /): @@ -54,8 +67,8 @@ def atan(x, /): def atan2(x1, x2, /): - if x1.dtype not in _floating_dtypes or x2.dtype not in _floating_dtypes: - raise TypeError("Only floating-point dtypes are allowed in atan2") + if x1.dtype not in _real_floating_dtypes or x2.dtype not in _real_floating_dtypes: + raise TypeError("Only real floating-point dtypes are allowed in atan2") return elemwise(np.arctan2, x1, x2, dtype=result_type(x1, x2)) @@ -111,14 +124,20 @@ def bitwise_xor(x1, x2, /): def ceil(x, /): - if x.dtype not in _numeric_dtypes: - raise TypeError("Only numeric dtypes are allowed in ceil") + if x.dtype not in _real_numeric_dtypes: + raise TypeError("Only real numeric dtypes are allowed in ceil") if x.dtype in _integer_dtypes: # Note: The return dtype of ceil is the same as the input return x return elemwise(np.ceil, x, dtype=x.dtype) +def conj(x, /): + if x.dtype not in _complex_floating_dtypes: + raise TypeError("Only complex floating-point dtypes are allowed in conj") + return elemwise(np.conj, x, dtype=x.dtype) + + def cos(x, /): if x.dtype not in _floating_dtypes: raise TypeError("Only floating-point dtypes are allowed in cos") @@ -154,8 +173,8 @@ def equal(x1, x2, /): def floor(x, /): - if x.dtype not in _numeric_dtypes: - raise TypeError("Only numeric dtypes are allowed in floor") + if x.dtype not in _real_numeric_dtypes: + raise TypeError("Only real numeric dtypes are allowed in floor") if x.dtype in _integer_dtypes: # Note: The return dtype of floor is the same as the input return x @@ -163,8 +182,8 @@ def floor(x, /): def floor_divide(x1, x2, /): - if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: - raise TypeError("Only numeric dtypes are allowed in floor_divide") + if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes: + raise TypeError("Only real numeric dtypes are allowed in floor_divide") return elemwise(np.floor_divide, x1, x2, dtype=result_type(x1, x2)) @@ -176,6 +195,16 @@ def greater_equal(x1, x2, /): return elemwise(np.greater_equal, x1, x2, dtype=np.bool_) +def imag(x, /): + if x.dtype == complex64: + dtype = float32 + elif x.dtype == complex128: + dtype = float64 + else: + raise TypeError("Only complex floating-point dtypes are allowed in imag") + return elemwise(np.imag, x, dtype=dtype) + + def isfinite(x, /): if x.dtype not in _numeric_dtypes: raise TypeError("Only numeric dtypes are allowed in isfinite") @@ -227,8 +256,8 @@ def log10(x, /): def logaddexp(x1, x2, /): - if x1.dtype not in _floating_dtypes or x2.dtype not in _floating_dtypes: - raise TypeError("Only floating-point dtypes are allowed in logaddexp") + if x1.dtype not in _real_floating_dtypes or x2.dtype not in _real_floating_dtypes: + raise TypeError("Only real floating-point dtypes are allowed in logaddexp") return elemwise(np.logaddexp, x1, x2, dtype=result_type(x1, x2)) @@ -284,9 +313,19 @@ def pow(x1, x2, /): return elemwise(np.power, x1, x2, dtype=result_type(x1, x2)) +def real(x, /): + if x.dtype == complex64: + dtype = float32 + elif x.dtype == complex128: + dtype = float64 + else: + raise TypeError("Only complex floating-point dtypes are allowed in real") + return elemwise(np.real, x, dtype=dtype) + + def remainder(x1, x2, /): - if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: - raise TypeError("Only numeric dtypes are allowed in remainder") + if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes: + raise TypeError("Only real numeric dtypes are allowed in remainder") return elemwise(np.remainder, x1, x2, dtype=result_type(x1, x2)) @@ -345,8 +384,8 @@ def tanh(x, /): def trunc(x, /): - if x.dtype not in _numeric_dtypes: - raise TypeError("Only numeric dtypes are allowed in trunc") + if x.dtype not in _real_numeric_dtypes: + raise TypeError("Only real numeric dtypes are allowed in trunc") if x.dtype in _integer_dtypes: # Note: The return dtype of trunc is the same as the input return x diff --git a/cubed/array_api/manipulation_functions.py b/cubed/array_api/manipulation_functions.py index 0395e54c..3f6dea47 100644 --- a/cubed/array_api/manipulation_functions.py +++ b/cubed/array_api/manipulation_functions.py @@ -181,7 +181,7 @@ def permute_dims(x, /, axes): ) -def reshape(x, /, shape): +def reshape(x, /, shape, *, copy=None): # based on dask reshape known_sizes = [s for s in shape if s != -1] diff --git a/cubed/array_api/searching_functions.py b/cubed/array_api/searching_functions.py index 8f9c4674..5f13879b 100644 --- a/cubed/array_api/searching_functions.py +++ b/cubed/array_api/searching_functions.py @@ -1,14 +1,14 @@ import numpy as np from cubed.array_api.data_type_functions import result_type -from cubed.array_api.dtypes import _numeric_dtypes +from cubed.array_api.dtypes import _real_numeric_dtypes from cubed.array_api.manipulation_functions import reshape from cubed.core.ops import arg_reduction, elemwise def argmax(x, /, *, axis=None, keepdims=False): - if x.dtype not in _numeric_dtypes: - raise TypeError("Only numeric dtypes are allowed in argmax") + if x.dtype not in _real_numeric_dtypes: + raise TypeError("Only real numeric dtypes are allowed in argmax") if axis is None: x = reshape(x, (-1,)) axis = 0 @@ -17,8 +17,8 @@ def argmax(x, /, *, axis=None, keepdims=False): def argmin(x, /, *, axis=None, keepdims=False): - if x.dtype not in _numeric_dtypes: - raise TypeError("Only numeric dtypes are allowed in argmin") + if x.dtype not in _real_numeric_dtypes: + raise TypeError("Only real numeric dtypes are allowed in argmin") if axis is None: x = reshape(x, (-1,)) axis = 0 diff --git a/cubed/array_api/statistical_functions.py b/cubed/array_api/statistical_functions.py index 153d5a6e..9a448ac9 100644 --- a/cubed/array_api/statistical_functions.py +++ b/cubed/array_api/statistical_functions.py @@ -4,19 +4,29 @@ from cubed.array_api.dtypes import ( _numeric_dtypes, + _real_floating_dtypes, + _real_numeric_dtypes, _signed_integer_dtypes, _unsigned_integer_dtypes, + complex64, + complex128, + float32, + float64, + int64, + uint64, ) from cubed.core import reduction def max(x, /, *, axis=None, keepdims=False): - if x.dtype not in _numeric_dtypes: - raise TypeError("Only numeric dtypes are allowed in max") + if x.dtype not in _real_numeric_dtypes: + raise TypeError("Only real numeric dtypes are allowed in max") return reduction(x, np.max, axis=axis, dtype=x.dtype, keepdims=keepdims) def mean(x, /, *, axis=None, keepdims=False): + if x.dtype not in _real_floating_dtypes: + raise TypeError("Only real floating-point dtypes are allowed in mean") # This implementation uses NumPy and Zarr's structured arrays to store a # pair of fields needed to keep per-chunk counts and totals for computing # the mean. Structured arrays are row-based, so are less efficient than @@ -85,8 +95,8 @@ def _numel(x, **kwargs): def min(x, /, *, axis=None, keepdims=False): - if x.dtype not in _numeric_dtypes: - raise TypeError("Only numeric dtypes are allowed in min") + if x.dtype not in _real_numeric_dtypes: + raise TypeError("Only real numeric dtypes are allowed in min") return reduction(x, np.min, axis=axis, dtype=x.dtype, keepdims=keepdims) @@ -95,11 +105,13 @@ def prod(x, /, *, axis=None, dtype=None, keepdims=False): raise TypeError("Only numeric dtypes are allowed in prod") if dtype is None: if x.dtype in _signed_integer_dtypes: - dtype = np.int64 + dtype = int64 elif x.dtype in _unsigned_integer_dtypes: - dtype = np.uint64 - elif x.dtype == np.float32: - dtype = np.float64 + dtype = uint64 + elif x.dtype == float32: + dtype = float64 + elif x.dtype == complex64: + dtype = complex128 else: dtype = x.dtype return reduction(x, np.prod, axis=axis, dtype=dtype, keepdims=keepdims) @@ -110,11 +122,13 @@ def sum(x, /, *, axis=None, dtype=None, keepdims=False): raise TypeError("Only numeric dtypes are allowed in sum") if dtype is None: if x.dtype in _signed_integer_dtypes: - dtype = np.int64 + dtype = int64 elif x.dtype in _unsigned_integer_dtypes: - dtype = np.uint64 - elif x.dtype == np.float32: - dtype = np.float64 + dtype = uint64 + elif x.dtype == float32: + dtype = float64 + elif x.dtype == complex64: + dtype = complex128 else: dtype = x.dtype return reduction(x, np.sum, axis=axis, dtype=dtype, keepdims=keepdims)