From 80e64de6ab7e3bf7819c08b867c2e279d62e68cb Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 5 Jun 2023 14:33:12 -0600 Subject: [PATCH 01/15] Bump array API version in numpy.array_api --- numpy/array_api/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/numpy/array_api/__init__.py b/numpy/array_api/__init__.py index e154b995298c..7ccb9a89b3de 100644 --- a/numpy/array_api/__init__.py +++ b/numpy/array_api/__init__.py @@ -121,7 +121,7 @@ "The numpy.array_api submodule is still experimental. See NEP 47.", stacklevel=2 ) -__array_api_version__ = "2021.12" +__array_api_version__ = "2022.12" __all__ = ["__array_api_version__"] From 33064b882c53811e1952f1b29d748d644db783e6 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 5 Jun 2023 14:33:25 -0600 Subject: [PATCH 02/15] Add complex dtypes to numpy.array_api result_type --- numpy/array_api/_dtypes.py | 22 +++++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/numpy/array_api/_dtypes.py b/numpy/array_api/_dtypes.py index 476d619fee63..06acab8e80df 100644 --- a/numpy/array_api/_dtypes.py +++ b/numpy/array_api/_dtypes.py @@ -12,6 +12,8 @@ uint64 = np.dtype("uint64") float32 = np.dtype("float32") float64 = np.dtype("float64") +complex64 = np.dtype("complex64") +complex128 = np.dtype("complex128") # Note: This name is changed bool = np.dtype("bool") @@ -26,10 +28,13 @@ uint64, float32, float64, + complex64, + complex128, bool, ) _boolean_dtypes = (bool,) -_floating_dtypes = (float32, float64) +_real_floating_dtypes = (float32, float64) +_floating_dtypes = (float32, float64, complex64, complex128) _integer_dtypes = (int8, int16, int32, int64, uint8, uint16, uint32, uint64) _integer_or_boolean_dtypes = ( bool, @@ -45,6 +50,8 @@ _numeric_dtypes = ( float32, float64, + complex64, + complex128, int8, int16, int32, @@ -61,6 +68,7 @@ "integer": _integer_dtypes, "integer or boolean": _integer_or_boolean_dtypes, "boolean": _boolean_dtypes, + "real floating-point": _floating_dtypes, "floating-point": _floating_dtypes, } @@ -133,6 +141,18 @@ (float32, float64): float64, (float64, float32): float64, (float64, float64): float64, + (complex64, complex64): complex64, + (complex64, complex128): complex128, + (complex128, complex64): complex128, + (complex128, complex64): complex128, + (float32, complex64): complex64, + (float32, complex128): complex128, + (float64, complex64): complex128, + (float64, complex128): complex128, + (complex64, float32): complex64, + (complex64, float64): complex128, + (complex128, float32): complex128, + (complex128, float64): complex128, (bool, bool): bool, } From bd86d17d60f3d9fa93c35665f034ac09c1193f98 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 5 Jun 2023 16:21:30 -0600 Subject: [PATCH 03/15] Add complex dtypes to the array_api top-level namespace --- numpy/array_api/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/numpy/array_api/__init__.py b/numpy/array_api/__init__.py index 7ccb9a89b3de..eb1af3b96eca 100644 --- a/numpy/array_api/__init__.py +++ b/numpy/array_api/__init__.py @@ -198,6 +198,8 @@ uint64, float32, float64, + complex64, + complex128, bool, ) From 8b63fc295bafea3efd3d115964200dbaac7be8a5 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 5 Jun 2023 16:25:42 -0600 Subject: [PATCH 04/15] Update numpy.array_api magic methods for complex numbers Updates from the v2022.12 version of the spec: - Add __complex__. - __float__, __int__, and __bool__ are now more lenient in what dtypes they can operate on. - Support complex scalars and dtypes in all operators (except those that should not operate on complex numbers). - Disallow integer scalars that are out of the bounds of the array dtype. - Update the tests accordingly. --- numpy/array_api/_array_object.py | 55 ++++++++++++----- numpy/array_api/_dtypes.py | 16 ++++- numpy/array_api/tests/test_array_object.py | 72 ++++++++++++++-------- 3 files changed, 99 insertions(+), 44 deletions(-) diff --git a/numpy/array_api/_array_object.py b/numpy/array_api/_array_object.py index a949b5977c25..c7eb2a0a361f 100644 --- a/numpy/array_api/_array_object.py +++ b/numpy/array_api/_array_object.py @@ -24,6 +24,7 @@ _integer_dtypes, _integer_or_boolean_dtypes, _floating_dtypes, + _complex_floating_dtypes, _numeric_dtypes, _result_type, _dtype_categories, @@ -139,7 +140,7 @@ def _check_allowed_dtypes(self, other: bool | int | float | Array, dtype_categor if self.dtype not in _dtype_categories[dtype_category]: raise TypeError(f"Only {dtype_category} dtypes are allowed in {op}") - if isinstance(other, (int, float, bool)): + if isinstance(other, (int, complex, float, bool)): other = self._promote_scalar(other) elif isinstance(other, Array): if other.dtype not in _dtype_categories[dtype_category]: @@ -189,11 +190,23 @@ def _promote_scalar(self, scalar): raise TypeError( "Python int scalars cannot be promoted with bool arrays" ) + if self.dtype in _integer_dtypes: + info = np.iinfo(self.dtype) + if not (info.min <= scalar <= info.max): + raise OverflowError( + "Python int scalars must be within the bounds of the dtype for integer arrays" + ) + # int + array(floating) is allowed elif isinstance(scalar, float): if self.dtype not in _floating_dtypes: raise TypeError( "Python float scalars can only be promoted with floating-point arrays." ) + elif isinstance(scalar, complex): + if self.dtype not in _complex_floating_dtypes: + raise TypeError( + "Python complex scalars can only be promoted with complex floating-point arrays." + ) else: raise TypeError("'scalar' must be a Python scalar") @@ -454,11 +467,19 @@ def __bool__(self: Array, /) -> bool: # Note: This is an error here. if self._array.ndim != 0: raise TypeError("bool is only allowed on arrays with 0 dimensions") - if self.dtype not in _boolean_dtypes: - raise ValueError("bool is only allowed on boolean arrays") res = self._array.__bool__() return res + def __complex__(self: Array, /) -> float: + """ + Performs the operation __complex__. + """ + # Note: This is an error here. + if self._array.ndim != 0: + raise TypeError("complex is only allowed on arrays with 0 dimensions") + res = self._array.__complex__() + return res + def __dlpack__(self: Array, /, *, stream: None = None) -> PyCapsule: """ Performs the operation __dlpack__. @@ -492,8 +513,8 @@ def __float__(self: Array, /) -> float: # Note: This is an error here. if self._array.ndim != 0: raise TypeError("float is only allowed on arrays with 0 dimensions") - if self.dtype not in _floating_dtypes: - raise ValueError("float is only allowed on floating-point arrays") + if self.dtype in _complex_floating_dtypes: + raise TypeError("float is not allowed on complex floating-point arrays") res = self._array.__float__() return res @@ -501,7 +522,7 @@ def __floordiv__(self: Array, other: Union[int, float, Array], /) -> Array: """ Performs the operation __floordiv__. """ - other = self._check_allowed_dtypes(other, "numeric", "__floordiv__") + other = self._check_allowed_dtypes(other, "real numeric", "__floordiv__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) @@ -512,7 +533,7 @@ def __ge__(self: Array, other: Union[int, float, Array], /) -> Array: """ Performs the operation __ge__. """ - other = self._check_allowed_dtypes(other, "numeric", "__ge__") + other = self._check_allowed_dtypes(other, "real numeric", "__ge__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) @@ -542,7 +563,7 @@ def __gt__(self: Array, other: Union[int, float, Array], /) -> Array: """ Performs the operation __gt__. """ - other = self._check_allowed_dtypes(other, "numeric", "__gt__") + other = self._check_allowed_dtypes(other, "real numeric", "__gt__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) @@ -556,8 +577,8 @@ def __int__(self: Array, /) -> int: # Note: This is an error here. if self._array.ndim != 0: raise TypeError("int is only allowed on arrays with 0 dimensions") - if self.dtype not in _integer_dtypes: - raise ValueError("int is only allowed on integer arrays") + if self.dtype in _complex_floating_dtypes: + raise TypeError("int is not allowed on complex floating-point arrays") res = self._array.__int__() return res @@ -581,7 +602,7 @@ def __le__(self: Array, other: Union[int, float, Array], /) -> Array: """ Performs the operation __le__. """ - other = self._check_allowed_dtypes(other, "numeric", "__le__") + other = self._check_allowed_dtypes(other, "real numeric", "__le__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) @@ -603,7 +624,7 @@ def __lt__(self: Array, other: Union[int, float, Array], /) -> Array: """ Performs the operation __lt__. """ - other = self._check_allowed_dtypes(other, "numeric", "__lt__") + other = self._check_allowed_dtypes(other, "real numeric", "__lt__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) @@ -626,7 +647,7 @@ def __mod__(self: Array, other: Union[int, float, Array], /) -> Array: """ Performs the operation __mod__. """ - other = self._check_allowed_dtypes(other, "numeric", "__mod__") + other = self._check_allowed_dtypes(other, "real numeric", "__mod__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) @@ -808,7 +829,7 @@ def __ifloordiv__(self: Array, other: Union[int, float, Array], /) -> Array: """ Performs the operation __ifloordiv__. """ - other = self._check_allowed_dtypes(other, "numeric", "__ifloordiv__") + other = self._check_allowed_dtypes(other, "real numeric", "__ifloordiv__") if other is NotImplemented: return other self._array.__ifloordiv__(other._array) @@ -818,7 +839,7 @@ def __rfloordiv__(self: Array, other: Union[int, float, Array], /) -> Array: """ Performs the operation __rfloordiv__. """ - other = self._check_allowed_dtypes(other, "numeric", "__rfloordiv__") + other = self._check_allowed_dtypes(other, "real numeric", "__rfloordiv__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) @@ -874,7 +895,7 @@ def __imod__(self: Array, other: Union[int, float, Array], /) -> Array: """ Performs the operation __imod__. """ - other = self._check_allowed_dtypes(other, "numeric", "__imod__") + other = self._check_allowed_dtypes(other, "real numeric", "__imod__") if other is NotImplemented: return other self._array.__imod__(other._array) @@ -884,7 +905,7 @@ def __rmod__(self: Array, other: Union[int, float, Array], /) -> Array: """ Performs the operation __rmod__. """ - other = self._check_allowed_dtypes(other, "numeric", "__rmod__") + other = self._check_allowed_dtypes(other, "real numeric", "__rmod__") if other is NotImplemented: return other self, other = self._normalize_two_args(self, other) diff --git a/numpy/array_api/_dtypes.py b/numpy/array_api/_dtypes.py index 06acab8e80df..54268116c708 100644 --- a/numpy/array_api/_dtypes.py +++ b/numpy/array_api/_dtypes.py @@ -35,6 +35,7 @@ _boolean_dtypes = (bool,) _real_floating_dtypes = (float32, float64) _floating_dtypes = (float32, float64, complex64, complex128) +_complex_floating_dtypes = (complex64, complex128) _integer_dtypes = (int8, int16, int32, int64, uint8, uint16, uint32, uint64) _integer_or_boolean_dtypes = ( bool, @@ -47,6 +48,18 @@ uint32, uint64, ) +_real_numeric_dtypes = ( + float32, + float64, + int8, + int16, + int32, + int64, + uint8, + uint16, + uint32, + uint64, +) _numeric_dtypes = ( float32, float64, @@ -64,6 +77,7 @@ _dtype_categories = { "all": _all_dtypes, + "real numeric": _real_numeric_dtypes, "numeric": _numeric_dtypes, "integer": _integer_dtypes, "integer or boolean": _integer_or_boolean_dtypes, @@ -144,7 +158,7 @@ (complex64, complex64): complex64, (complex64, complex128): complex128, (complex128, complex64): complex128, - (complex128, complex64): complex128, + (complex128, complex128): complex128, (float32, complex64): complex64, (float32, complex128): complex128, (float64, complex64): complex128, diff --git a/numpy/array_api/tests/test_array_object.py b/numpy/array_api/tests/test_array_object.py index f6efacefaee1..0feb72c4ea33 100644 --- a/numpy/array_api/tests/test_array_object.py +++ b/numpy/array_api/tests/test_array_object.py @@ -1,6 +1,6 @@ import operator -from numpy.testing import assert_raises +from numpy.testing import assert_raises, suppress_warnings import numpy as np import pytest @@ -9,9 +9,12 @@ from .._dtypes import ( _all_dtypes, _boolean_dtypes, + _real_floating_dtypes, _floating_dtypes, + _complex_floating_dtypes, _integer_dtypes, _integer_or_boolean_dtypes, + _real_numeric_dtypes, _numeric_dtypes, int8, int16, @@ -85,13 +88,13 @@ def test_operators(): "__add__": "numeric", "__and__": "integer_or_boolean", "__eq__": "all", - "__floordiv__": "numeric", - "__ge__": "numeric", - "__gt__": "numeric", - "__le__": "numeric", + "__floordiv__": "real numeric", + "__ge__": "real numeric", + "__gt__": "real numeric", + "__le__": "real numeric", "__lshift__": "integer", - "__lt__": "numeric", - "__mod__": "numeric", + "__lt__": "real numeric", + "__mod__": "real numeric", "__mul__": "numeric", "__ne__": "all", "__or__": "integer_or_boolean", @@ -101,7 +104,6 @@ def test_operators(): "__truediv__": "floating", "__xor__": "integer_or_boolean", } - # Recompute each time because of in-place ops def _array_vals(): for d in _integer_dtypes: @@ -111,13 +113,15 @@ def _array_vals(): for d in _floating_dtypes: yield asarray(1.0, dtype=d) + + BIG_INT = int(1e30) for op, dtypes in binary_op_dtypes.items(): ops = [op] if op not in ["__eq__", "__ne__", "__le__", "__ge__", "__lt__", "__gt__"]: rop = "__r" + op[2:] iop = "__i" + op[2:] ops += [rop, iop] - for s in [1, 1.0, False]: + for s in [1, 1.0, 1j, BIG_INT, False]: for _op in ops: for a in _array_vals(): # Test array op scalar. From the spec, the following combinations @@ -125,13 +129,12 @@ def _array_vals(): # - Python bool for a bool array dtype, # - a Python int within the bounds of the given dtype for integer array dtypes, - # - a Python int or float for floating-point array dtypes - - # We do not do bounds checking for int scalars, but rather use the default - # NumPy behavior for casting in that case. + # - a Python int or float for real floating-point array dtypes + # - a Python int, float, or complex for complex floating-point array dtypes if ((dtypes == "all" or dtypes == "numeric" and a.dtype in _numeric_dtypes + or dtypes == "real numeric" and a.dtype in _real_numeric_dtypes or dtypes == "integer" and a.dtype in _integer_dtypes or dtypes == "integer_or_boolean" and a.dtype in _integer_or_boolean_dtypes or dtypes == "boolean" and a.dtype in _boolean_dtypes @@ -141,10 +144,18 @@ def _array_vals(): # isinstance here. and (a.dtype in _boolean_dtypes and type(s) == bool or a.dtype in _integer_dtypes and type(s) == int - or a.dtype in _floating_dtypes and type(s) in [float, int] + or a.dtype in _real_floating_dtypes and type(s) in [float, int] + or a.dtype in _complex_floating_dtypes and type(s) in [complex, float, int] )): - # Only test for no error - getattr(a, _op)(s) + if a.dtype in _integer_dtypes and s == BIG_INT: + assert_raises(OverflowError, lambda: getattr(a, _op)(s)) + else: + # Only test for no error + with suppress_warnings() as sup: + # ignore warnings from pow(BIG_INT) + sup.filter(RuntimeWarning, + "invalid value encountered in power") + getattr(a, _op)(s) else: assert_raises(TypeError, lambda: getattr(a, _op)(s)) @@ -174,8 +185,9 @@ def _array_vals(): # Ensure only those dtypes that are required for every operator are allowed. elif (dtypes == "all" and (x.dtype in _boolean_dtypes and y.dtype in _boolean_dtypes or x.dtype in _numeric_dtypes and y.dtype in _numeric_dtypes) + or (dtypes == "real numeric" and x.dtype in _real_numeric_dtypes and y.dtype in _real_numeric_dtypes) or (dtypes == "numeric" and x.dtype in _numeric_dtypes and y.dtype in _numeric_dtypes) - or dtypes == "integer" and x.dtype in _integer_dtypes and y.dtype in _numeric_dtypes + or dtypes == "integer" and x.dtype in _integer_dtypes and y.dtype in _integer_dtypes or dtypes == "integer_or_boolean" and (x.dtype in _integer_dtypes and y.dtype in _integer_dtypes or x.dtype in _boolean_dtypes and y.dtype in _boolean_dtypes) or dtypes == "boolean" and x.dtype in _boolean_dtypes and y.dtype in _boolean_dtypes @@ -263,31 +275,39 @@ def test_python_scalar_construtors(): b = asarray(False) i = asarray(0) f = asarray(0.0) + c = asarray(0j) assert bool(b) == False assert int(i) == 0 assert float(f) == 0.0 assert operator.index(i) == 0 - # bool/int/float should only be allowed on 0-D arrays. + # bool/int/float/complex should only be allowed on 0-D arrays. assert_raises(TypeError, lambda: bool(asarray([False]))) assert_raises(TypeError, lambda: int(asarray([0]))) assert_raises(TypeError, lambda: float(asarray([0.0]))) + assert_raises(TypeError, lambda: complex(asarray([0j]))) assert_raises(TypeError, lambda: operator.index(asarray([0]))) - # bool/int/float should only be allowed on arrays of the corresponding - # dtype - assert_raises(ValueError, lambda: bool(i)) - assert_raises(ValueError, lambda: bool(f)) + # bool should work on all types of arrays + assert bool(b) is bool(i) is bool(f) is bool(c) is False + + # int should fail on complex arrays + assert int(b) == int(i) == int(f) == 0 + assert_raises(TypeError, lambda: int(c)) - assert_raises(ValueError, lambda: int(b)) - assert_raises(ValueError, lambda: int(f)) + # float should fail on complex arrays + assert float(b) == float(i) == float(f) == 0.0 + assert_raises(TypeError, lambda: float(c)) - assert_raises(ValueError, lambda: float(b)) - assert_raises(ValueError, lambda: float(i)) + # complex should work on all types of arrays + assert complex(b) == complex(i) == complex(f) == complex(c) == 0j + # index should only work on integer arrays + assert operator.index(i) == 0 assert_raises(TypeError, lambda: operator.index(b)) assert_raises(TypeError, lambda: operator.index(f)) + assert_raises(TypeError, lambda: operator.index(c)) def test_device_property(): From 103bca57407ba69632c005609c58538c3a765123 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 5 Jun 2023 16:49:22 -0600 Subject: [PATCH 05/15] Add conj, imag, and real functions to numpy.array_api --- numpy/array_api/__init__.py | 3 ++ numpy/array_api/_dtypes.py | 1 + numpy/array_api/_elementwise_functions.py | 34 +++++++++++++++++++ .../tests/test_elementwise_functions.py | 3 ++ 4 files changed, 41 insertions(+) diff --git a/numpy/array_api/__init__.py b/numpy/array_api/__init__.py index eb1af3b96eca..dcfff33e18d2 100644 --- a/numpy/array_api/__init__.py +++ b/numpy/array_api/__init__.py @@ -234,6 +234,7 @@ bitwise_right_shift, bitwise_xor, ceil, + conj, cos, cosh, divide, @@ -244,6 +245,7 @@ floor_divide, greater, greater_equal, + imag, isfinite, isinf, isnan, @@ -263,6 +265,7 @@ not_equal, positive, pow, + real, remainder, round, sign, diff --git a/numpy/array_api/_dtypes.py b/numpy/array_api/_dtypes.py index 54268116c708..e19473c891ec 100644 --- a/numpy/array_api/_dtypes.py +++ b/numpy/array_api/_dtypes.py @@ -83,6 +83,7 @@ "integer or boolean": _integer_or_boolean_dtypes, "boolean": _boolean_dtypes, "real floating-point": _floating_dtypes, + "complex floating-point": _complex_floating_dtypes, "floating-point": _floating_dtypes, } diff --git a/numpy/array_api/_elementwise_functions.py b/numpy/array_api/_elementwise_functions.py index c758a09447a0..5ea48528a8d5 100644 --- a/numpy/array_api/_elementwise_functions.py +++ b/numpy/array_api/_elementwise_functions.py @@ -3,6 +3,7 @@ from ._dtypes import ( _boolean_dtypes, _floating_dtypes, + _complex_floating_dtypes, _integer_dtypes, _integer_or_boolean_dtypes, _numeric_dtypes, @@ -238,6 +239,17 @@ def ceil(x: Array, /) -> Array: return Array._new(np.ceil(x._array)) +def conj(x: Array, /) -> Array: + """ + Array API compatible wrapper for :py:func:`np.conj `. + + See its docstring for more information. + """ + if x.dtype not in _complex_floating_dtypes: + raise TypeError("Only complex floating-point dtypes are allowed in conj") + return Array._new(np.conj(x)) + + def cos(x: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.cos `. @@ -364,6 +376,17 @@ def greater_equal(x1: Array, x2: Array, /) -> Array: return Array._new(np.greater_equal(x1._array, x2._array)) +def imag(x: Array, /) -> Array: + """ + Array API compatible wrapper for :py:func:`np.imag `. + + See its docstring for more information. + """ + if x.dtype not in _complex_floating_dtypes: + raise TypeError("Only complex floating-point dtypes are allowed in imag") + return Array._new(np.imag(x)) + + def isfinite(x: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.isfinite `. @@ -599,6 +622,17 @@ def pow(x1: Array, x2: Array, /) -> Array: return Array._new(np.power(x1._array, x2._array)) +def real(x: Array, /) -> Array: + """ + Array API compatible wrapper for :py:func:`np.real `. + + See its docstring for more information. + """ + if x.dtype not in _complex_floating_dtypes: + raise TypeError("Only complex floating-point dtypes are allowed in real") + return Array._new(np.real(x)) + + def remainder(x1: Array, x2: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.remainder `. diff --git a/numpy/array_api/tests/test_elementwise_functions.py b/numpy/array_api/tests/test_elementwise_functions.py index b2fb44e766f8..9ab73e0e3f91 100644 --- a/numpy/array_api/tests/test_elementwise_functions.py +++ b/numpy/array_api/tests/test_elementwise_functions.py @@ -38,6 +38,7 @@ def test_function_types(): "bitwise_right_shift": "integer", "bitwise_xor": "integer or boolean", "ceil": "numeric", + "conj": "complex floating-point", "cos": "floating-point", "cosh": "floating-point", "divide": "floating-point", @@ -48,6 +49,7 @@ def test_function_types(): "floor_divide": "numeric", "greater": "numeric", "greater_equal": "numeric", + "imag": "complex floating-point", "isfinite": "numeric", "isinf": "numeric", "isnan": "numeric", @@ -67,6 +69,7 @@ def test_function_types(): "not_equal": "all", "positive": "numeric", "pow": "numeric", + "real": "complex floating-point", "remainder": "numeric", "round": "numeric", "sign": "numeric", From e023bc611661bbed26292b098945170728e67d48 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 5 Jun 2023 18:10:20 -0600 Subject: [PATCH 06/15] Update numpy.array_api sum() and prod() to handle complex dtypes --- numpy/array_api/_statistical_functions.py | 27 ++++++++++++++--------- 1 file changed, 17 insertions(+), 10 deletions(-) diff --git a/numpy/array_api/_statistical_functions.py b/numpy/array_api/_statistical_functions.py index 5bc831ac2965..da42076adc95 100644 --- a/numpy/array_api/_statistical_functions.py +++ b/numpy/array_api/_statistical_functions.py @@ -6,7 +6,7 @@ ) from ._array_object import Array from ._creation_functions import asarray -from ._dtypes import float32, float64 +from ._dtypes import float32, float64, complex64, complex128 from typing import TYPE_CHECKING, Optional, Tuple, Union @@ -62,10 +62,14 @@ def prod( ) -> Array: if x.dtype not in _numeric_dtypes: raise TypeError("Only numeric dtypes are allowed in prod") - # Note: sum() and prod() always upcast float32 to float64 for dtype=None - # We need to do so here before computing the product to avoid overflow - if dtype is None and x.dtype == float32: - dtype = float64 + # Note: sum() and prod() always upcast for dtype=None. `np.prod` does that + # for integers, but not for float32 or complex64, so we need to + # special-case it here + if dtype is None: + if x.dtype == float32: + dtype = float64 + elif x.dtype == complex64: + dtype = complex128 return Array._new(np.prod(x._array, dtype=dtype, axis=axis, keepdims=keepdims)) @@ -93,11 +97,14 @@ def sum( ) -> Array: if x.dtype not in _numeric_dtypes: raise TypeError("Only numeric dtypes are allowed in sum") - # Note: sum() and prod() always upcast integers to (u)int64 and float32 to - # float64 for dtype=None. `np.sum` does that too for integers, but not for - # float32, so we need to special-case it here - if dtype is None and x.dtype == float32: - dtype = float64 + # Note: sum() and prod() always upcast for dtype=None. `np.sum` does that + # for integers, but not for float32 or complex64, so we need to + # special-case it here + if dtype is None: + if x.dtype == float32: + dtype = float64 + elif x.dtype == complex64: + dtype = complex128 return Array._new(np.sum(x._array, axis=axis, dtype=dtype, keepdims=keepdims)) From 173fbc7009719ce802aa70634fb93031a0c00cfb Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 5 Jun 2023 18:10:36 -0600 Subject: [PATCH 07/15] Add isdtype() to numpy.array_api This is a new function in the v2022.12 version of the array API standard which is used for determining if a given dtype is part of a set of given dtype categories. This will also eventually be added to the main NumPy namespace, but for now only exists in numpy.array_api as a purely strict version. --- numpy/array_api/__init__.py | 1 + numpy/array_api/_data_type_functions.py | 50 ++++++++++++++++++- numpy/array_api/_dtypes.py | 2 + .../tests/test_data_type_functions.py | 14 +++++- 4 files changed, 65 insertions(+), 2 deletions(-) diff --git a/numpy/array_api/__init__.py b/numpy/array_api/__init__.py index dcfff33e18d2..964873faab20 100644 --- a/numpy/array_api/__init__.py +++ b/numpy/array_api/__init__.py @@ -173,6 +173,7 @@ broadcast_to, can_cast, finfo, + isdtype, iinfo, result_type, ) diff --git a/numpy/array_api/_data_type_functions.py b/numpy/array_api/_data_type_functions.py index 7026bd489563..ede7f85d38d1 100644 --- a/numpy/array_api/_data_type_functions.py +++ b/numpy/array_api/_data_type_functions.py @@ -1,7 +1,17 @@ from __future__ import annotations from ._array_object import Array -from ._dtypes import _all_dtypes, _result_type +from ._dtypes import ( + _all_dtypes, + _boolean_dtypes, + _signed_integer_dtypes, + _unsigned_integer_dtypes, + _integer_dtypes, + _real_floating_dtypes, + _complex_floating_dtypes, + _numeric_dtypes, + _result_type, +) from dataclasses import dataclass from typing import TYPE_CHECKING, List, Tuple, Union @@ -117,6 +127,44 @@ def iinfo(type: Union[Dtype, Array], /) -> iinfo_object: return iinfo_object(ii.bits, ii.max, ii.min) +# Note: isdtype is a new function from the 2022.12 array API specification. +def isdtype( + dtype: Dtype, kind: Union[Dtype, str, Tuple[Union[Dtype, str], ...]] +) -> bool: + """ + Returns a boolean indicating whether a provided dtype is of a specified data type ``kind``. + + See + https://data-apis.org/array-api/latest/API_specification/generated/array_api.isdtype.html + for more details + """ + if isinstance(kind, tuple): + # Disallow nested tuples + if any(isinstance(k, tuple) for k in kind): + raise TypeError("'kind' must be a dtype, str, or tuple of dtypes and strs") + return any(isdtype(dtype, k) for k in kind) + elif isinstance(kind, str): + if kind == 'bool': + return dtype in _boolean_dtypes + elif kind == 'signed integer': + return dtype in _signed_integer_dtypes + elif kind == 'unsigned integer': + return dtype in _unsigned_integer_dtypes + elif kind == 'integral': + return dtype in _integer_dtypes + elif kind == 'real floating': + return dtype in _real_floating_dtypes + elif kind == 'complex floating': + return dtype in _complex_floating_dtypes + elif kind == 'numeric': + return dtype in _numeric_dtypes + else: + raise ValueError(f"Unrecognized data type kind: {kind!r}") + elif kind in _all_dtypes: + return dtype == kind + else: + raise TypeError(f"'kind' must be a dtype, str, or tuple of dtypes and strs, not {type(kind).__name__}") + def result_type(*arrays_and_dtypes: Union[Array, Dtype]) -> Dtype: """ Array API compatible wrapper for :py:func:`np.result_type `. diff --git a/numpy/array_api/_dtypes.py b/numpy/array_api/_dtypes.py index e19473c891ec..0e8f666eeedd 100644 --- a/numpy/array_api/_dtypes.py +++ b/numpy/array_api/_dtypes.py @@ -37,6 +37,8 @@ _floating_dtypes = (float32, float64, complex64, complex128) _complex_floating_dtypes = (complex64, complex128) _integer_dtypes = (int8, int16, int32, int64, uint8, uint16, uint32, uint64) +_signed_integer_dtypes = (int8, int16, int32, int64) +_unsigned_integer_dtypes = (uint8, uint16, uint32, uint64) _integer_or_boolean_dtypes = ( bool, int8, diff --git a/numpy/array_api/tests/test_data_type_functions.py b/numpy/array_api/tests/test_data_type_functions.py index efe3d0abd4be..61d56ca45b1e 100644 --- a/numpy/array_api/tests/test_data_type_functions.py +++ b/numpy/array_api/tests/test_data_type_functions.py @@ -1,7 +1,8 @@ import pytest +from numpy.testing import assert_raises from numpy import array_api as xp - +import numpy as np @pytest.mark.parametrize( "from_, to, expected", @@ -17,3 +18,14 @@ def test_can_cast(from_, to, expected): can_cast() returns correct result """ assert xp.can_cast(from_, to) == expected + +def test_isdtype_strictness(): + assert_raises(TypeError, lambda: xp.isdtype(xp.float64, 64)) + assert_raises(ValueError, lambda: xp.isdtype(xp.float64, 'f8')) + + assert_raises(TypeError, lambda: xp.isdtype(xp.float64, (('integral',),))) + assert_raises(TypeError, lambda: xp.isdtype(xp.float64, np.object_)) + + # TODO: These will require https://github.com/numpy/numpy/issues/23883 + # assert_raises(TypeError, lambda: xp.isdtype(xp.float64, None)) + # assert_raises(TypeError, lambda: xp.isdtype(xp.float64, np.float64)) From ca1ef2ba12f22beec334c964d2d54fecff4f0772 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 5 Jun 2023 18:23:22 -0600 Subject: [PATCH 08/15] Add a release notes entry --- doc/release/upcoming_changes/23789.new_feature.rst | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 doc/release/upcoming_changes/23789.new_feature.rst diff --git a/doc/release/upcoming_changes/23789.new_feature.rst b/doc/release/upcoming_changes/23789.new_feature.rst new file mode 100644 index 000000000000..417b944f92e9 --- /dev/null +++ b/doc/release/upcoming_changes/23789.new_feature.rst @@ -0,0 +1,5 @@ +Array API v2022.12 support in ``numpy.array_api`` +------------------------------------------------- + +- ``numpy.array_api`` now full supports the `v2022.12 version + `__ of the array API standard. From c866ef19c71b2d0269340ce984be42fd8de45e28 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 5 Jun 2023 18:39:18 -0600 Subject: [PATCH 09/15] Update dtype strictness for complex numbers in array_api elementwise functions --- numpy/array_api/_elementwise_functions.py | 46 ++++++++++--------- .../tests/test_elementwise_functions.py | 22 ++++----- 2 files changed, 35 insertions(+), 33 deletions(-) diff --git a/numpy/array_api/_elementwise_functions.py b/numpy/array_api/_elementwise_functions.py index 5ea48528a8d5..8b696772f6dd 100644 --- a/numpy/array_api/_elementwise_functions.py +++ b/numpy/array_api/_elementwise_functions.py @@ -3,9 +3,11 @@ from ._dtypes import ( _boolean_dtypes, _floating_dtypes, + _real_floating_dtypes, _complex_floating_dtypes, _integer_dtypes, _integer_or_boolean_dtypes, + _real_numeric_dtypes, _numeric_dtypes, _result_type, ) @@ -106,8 +108,8 @@ def atan2(x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ - if x1.dtype not in _floating_dtypes or x2.dtype not in _floating_dtypes: - raise TypeError("Only floating-point dtypes are allowed in atan2") + if x1.dtype not in _real_floating_dtypes or x2.dtype not in _real_floating_dtypes: + raise TypeError("Only real floating-point dtypes are allowed in atan2") # Call result type here just to raise on disallowed type combinations _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) @@ -231,8 +233,8 @@ def ceil(x: Array, /) -> Array: See its docstring for more information. """ - if x.dtype not in _numeric_dtypes: - raise TypeError("Only numeric dtypes are allowed in ceil") + if x.dtype not in _real_numeric_dtypes: + raise TypeError("Only real numeric dtypes are allowed in ceil") if x.dtype in _integer_dtypes: # Note: The return dtype of ceil is the same as the input return x @@ -326,8 +328,8 @@ def floor(x: Array, /) -> Array: See its docstring for more information. """ - if x.dtype not in _numeric_dtypes: - raise TypeError("Only numeric dtypes are allowed in floor") + if x.dtype not in _real_numeric_dtypes: + raise TypeError("Only real numeric dtypes are allowed in floor") if x.dtype in _integer_dtypes: # Note: The return dtype of floor is the same as the input return x @@ -340,8 +342,8 @@ def floor_divide(x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ - if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: - raise TypeError("Only numeric dtypes are allowed in floor_divide") + if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes: + raise TypeError("Only real numeric dtypes are allowed in floor_divide") # Call result type here just to raise on disallowed type combinations _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) @@ -354,8 +356,8 @@ def greater(x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ - if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: - raise TypeError("Only numeric dtypes are allowed in greater") + if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes: + raise TypeError("Only real numeric dtypes are allowed in greater") # Call result type here just to raise on disallowed type combinations _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) @@ -368,8 +370,8 @@ def greater_equal(x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ - if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: - raise TypeError("Only numeric dtypes are allowed in greater_equal") + if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes: + raise TypeError("Only real numeric dtypes are allowed in greater_equal") # Call result type here just to raise on disallowed type combinations _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) @@ -426,8 +428,8 @@ def less(x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ - if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: - raise TypeError("Only numeric dtypes are allowed in less") + if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes: + raise TypeError("Only real numeric dtypes are allowed in less") # Call result type here just to raise on disallowed type combinations _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) @@ -440,8 +442,8 @@ def less_equal(x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ - if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: - raise TypeError("Only numeric dtypes are allowed in less_equal") + if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes: + raise TypeError("Only real numeric dtypes are allowed in less_equal") # Call result type here just to raise on disallowed type combinations _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) @@ -498,8 +500,8 @@ def logaddexp(x1: Array, x2: Array) -> Array: See its docstring for more information. """ - if x1.dtype not in _floating_dtypes or x2.dtype not in _floating_dtypes: - raise TypeError("Only floating-point dtypes are allowed in logaddexp") + if x1.dtype not in _real_floating_dtypes or x2.dtype not in _real_floating_dtypes: + raise TypeError("Only real floating-point dtypes are allowed in logaddexp") # Call result type here just to raise on disallowed type combinations _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) @@ -639,8 +641,8 @@ def remainder(x1: Array, x2: Array, /) -> Array: See its docstring for more information. """ - if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: - raise TypeError("Only numeric dtypes are allowed in remainder") + if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes: + raise TypeError("Only real numeric dtypes are allowed in remainder") # Call result type here just to raise on disallowed type combinations _result_type(x1.dtype, x2.dtype) x1, x2 = Array._normalize_two_args(x1, x2) @@ -755,8 +757,8 @@ def trunc(x: Array, /) -> Array: See its docstring for more information. """ - if x.dtype not in _numeric_dtypes: - raise TypeError("Only numeric dtypes are allowed in trunc") + if x.dtype not in _real_numeric_dtypes: + raise TypeError("Only real numeric dtypes are allowed in trunc") if x.dtype in _integer_dtypes: # Note: The return dtype of trunc is the same as the input return x diff --git a/numpy/array_api/tests/test_elementwise_functions.py b/numpy/array_api/tests/test_elementwise_functions.py index 9ab73e0e3f91..1228d0af2e6a 100644 --- a/numpy/array_api/tests/test_elementwise_functions.py +++ b/numpy/array_api/tests/test_elementwise_functions.py @@ -29,7 +29,7 @@ def test_function_types(): "asin": "floating-point", "asinh": "floating-point", "atan": "floating-point", - "atan2": "floating-point", + "atan2": "real floating-point", "atanh": "floating-point", "bitwise_and": "integer or boolean", "bitwise_invert": "integer or boolean", @@ -37,7 +37,7 @@ def test_function_types(): "bitwise_or": "integer or boolean", "bitwise_right_shift": "integer", "bitwise_xor": "integer or boolean", - "ceil": "numeric", + "ceil": "real numeric", "conj": "complex floating-point", "cos": "floating-point", "cosh": "floating-point", @@ -45,18 +45,18 @@ def test_function_types(): "equal": "all", "exp": "floating-point", "expm1": "floating-point", - "floor": "numeric", - "floor_divide": "numeric", - "greater": "numeric", - "greater_equal": "numeric", + "floor": "real numeric", + "floor_divide": "real numeric", + "greater": "real numeric", + "greater_equal": "real numeric", "imag": "complex floating-point", "isfinite": "numeric", "isinf": "numeric", "isnan": "numeric", - "less": "numeric", - "less_equal": "numeric", + "less": "real numeric", + "less_equal": "real numeric", "log": "floating-point", - "logaddexp": "floating-point", + "logaddexp": "real floating-point", "log10": "floating-point", "log1p": "floating-point", "log2": "floating-point", @@ -70,7 +70,7 @@ def test_function_types(): "positive": "numeric", "pow": "numeric", "real": "complex floating-point", - "remainder": "numeric", + "remainder": "real numeric", "round": "numeric", "sign": "numeric", "sin": "floating-point", @@ -80,7 +80,7 @@ def test_function_types(): "subtract": "numeric", "tan": "floating-point", "tanh": "floating-point", - "trunc": "numeric", + "trunc": "real numeric", } def _array_vals(): From 315b0d0db60977be164a251f55a25b64497d3db9 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 5 Jun 2023 18:45:33 -0600 Subject: [PATCH 10/15] Update dtype strictness for array_api statistical functions --- numpy/array_api/_statistical_functions.py | 24 +++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/numpy/array_api/_statistical_functions.py b/numpy/array_api/_statistical_functions.py index da42076adc95..98e31b51c115 100644 --- a/numpy/array_api/_statistical_functions.py +++ b/numpy/array_api/_statistical_functions.py @@ -1,11 +1,11 @@ from __future__ import annotations from ._dtypes import ( - _floating_dtypes, + _real_floating_dtypes, + _real_numeric_dtypes, _numeric_dtypes, ) from ._array_object import Array -from ._creation_functions import asarray from ._dtypes import float32, float64, complex64, complex128 from typing import TYPE_CHECKING, Optional, Tuple, Union @@ -23,8 +23,8 @@ def max( axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False, ) -> Array: - if x.dtype not in _numeric_dtypes: - raise TypeError("Only numeric dtypes are allowed in max") + if x.dtype not in _real_numeric_dtypes: + raise TypeError("Only real numeric dtypes are allowed in max") return Array._new(np.max(x._array, axis=axis, keepdims=keepdims)) @@ -35,8 +35,8 @@ def mean( axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False, ) -> Array: - if x.dtype not in _floating_dtypes: - raise TypeError("Only floating-point dtypes are allowed in mean") + if x.dtype not in _real_floating_dtypes: + raise TypeError("Only real floating-point dtypes are allowed in mean") return Array._new(np.mean(x._array, axis=axis, keepdims=keepdims)) @@ -47,8 +47,8 @@ def min( axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False, ) -> Array: - if x.dtype not in _numeric_dtypes: - raise TypeError("Only numeric dtypes are allowed in min") + if x.dtype not in _real_numeric_dtypes: + raise TypeError("Only real numeric dtypes are allowed in min") return Array._new(np.min(x._array, axis=axis, keepdims=keepdims)) @@ -82,8 +82,8 @@ def std( keepdims: bool = False, ) -> Array: # Note: the keyword argument correction is different here - if x.dtype not in _floating_dtypes: - raise TypeError("Only floating-point dtypes are allowed in std") + if x.dtype not in _real_floating_dtypes: + raise TypeError("Only real floating-point dtypes are allowed in std") return Array._new(np.std(x._array, axis=axis, ddof=correction, keepdims=keepdims)) @@ -117,6 +117,6 @@ def var( keepdims: bool = False, ) -> Array: # Note: the keyword argument correction is different here - if x.dtype not in _floating_dtypes: - raise TypeError("Only floating-point dtypes are allowed in var") + if x.dtype not in _real_floating_dtypes: + raise TypeError("Only real floating-point dtypes are allowed in var") return Array._new(np.var(x._array, axis=axis, ddof=correction, keepdims=keepdims)) From 837b1af70ecea4877c8b1fee327d73d6dace517a Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 5 Jun 2023 18:50:20 -0600 Subject: [PATCH 11/15] Update dtype strictness in array_api searching and sorting functions --- numpy/array_api/_searching_functions.py | 6 +++++- numpy/array_api/_sorting_functions.py | 5 +++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/numpy/array_api/_searching_functions.py b/numpy/array_api/_searching_functions.py index 40f5a4d2e8fc..a1f4b0c904c1 100644 --- a/numpy/array_api/_searching_functions.py +++ b/numpy/array_api/_searching_functions.py @@ -1,7 +1,7 @@ from __future__ import annotations from ._array_object import Array -from ._dtypes import _result_type +from ._dtypes import _result_type, _real_numeric_dtypes from typing import Optional, Tuple @@ -14,6 +14,8 @@ def argmax(x: Array, /, *, axis: Optional[int] = None, keepdims: bool = False) - See its docstring for more information. """ + if x.dtype not in _real_numeric_dtypes: + raise TypeError("Only real numeric dtypes are allowed in argmax") return Array._new(np.asarray(np.argmax(x._array, axis=axis, keepdims=keepdims))) @@ -23,6 +25,8 @@ def argmin(x: Array, /, *, axis: Optional[int] = None, keepdims: bool = False) - See its docstring for more information. """ + if x.dtype not in _real_numeric_dtypes: + raise TypeError("Only real numeric dtypes are allowed in argmin") return Array._new(np.asarray(np.argmin(x._array, axis=axis, keepdims=keepdims))) diff --git a/numpy/array_api/_sorting_functions.py b/numpy/array_api/_sorting_functions.py index afbb412f7f5e..9b8cb044d88a 100644 --- a/numpy/array_api/_sorting_functions.py +++ b/numpy/array_api/_sorting_functions.py @@ -1,6 +1,7 @@ from __future__ import annotations from ._array_object import Array +from ._dtypes import _real_numeric_dtypes import numpy as np @@ -14,6 +15,8 @@ def argsort( See its docstring for more information. """ + if x.dtype not in _real_numeric_dtypes: + raise TypeError("Only real numeric dtypes are allowed in argsort") # Note: this keyword argument is different, and the default is different. kind = "stable" if stable else "quicksort" if not descending: @@ -41,6 +44,8 @@ def sort( See its docstring for more information. """ + if x.dtype not in _real_numeric_dtypes: + raise TypeError("Only real numeric dtypes are allowed in sort") # Note: this keyword argument is different, and the default is different. kind = "stable" if stable else "quicksort" res = np.sort(x._array, axis=axis, kind=kind) From 4e2a03ab936ac5035640df75e71965074c7d84c6 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 5 Jun 2023 18:55:54 -0600 Subject: [PATCH 12/15] Add the dtype argument to numpy.array_api.linalg.trace --- numpy/array_api/linalg.py | 23 +++++++++++++++++++---- 1 file changed, 19 insertions(+), 4 deletions(-) diff --git a/numpy/array_api/linalg.py b/numpy/array_api/linalg.py index d214046effd3..58320db55ceb 100644 --- a/numpy/array_api/linalg.py +++ b/numpy/array_api/linalg.py @@ -1,6 +1,13 @@ from __future__ import annotations -from ._dtypes import _floating_dtypes, _numeric_dtypes +from ._dtypes import ( + _floating_dtypes, + _numeric_dtypes, + float32, + float64, + complex64, + complex128 +) from ._manipulation_functions import reshape from ._array_object import Array @@ -8,7 +15,7 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: - from ._typing import Literal, Optional, Sequence, Tuple, Union + from ._typing import Literal, Optional, Sequence, Tuple, Union, Dtype from typing import NamedTuple @@ -363,7 +370,7 @@ def tensordot(x1: Array, x2: Array, /, *, axes: Union[int, Tuple[Sequence[int], return Array._new(np.tensordot(x1._array, x2._array, axes=axes)) # Note: trace is the numpy top-level namespace, not np.linalg -def trace(x: Array, /, *, offset: int = 0) -> Array: +def trace(x: Array, /, *, offset: int = 0, dtype: Optional[Dtype] = None) -> Array: """ Array API compatible wrapper for :py:func:`np.trace `. @@ -371,9 +378,17 @@ def trace(x: Array, /, *, offset: int = 0) -> Array: """ if x.dtype not in _numeric_dtypes: raise TypeError('Only numeric dtypes are allowed in trace') + + # Note: trace() works the same as sum() and prod() (see + # _statistical_functions.py) + if dtype is None: + if x.dtype == float32: + dtype = float64 + elif x.dtype == complex64: + dtype = complex128 # Note: trace always operates on the last two axes, whereas np.trace # operates on the first two axes by default - return Array._new(np.asarray(np.trace(x._array, offset=offset, axis1=-2, axis2=-1))) + return Array._new(np.asarray(np.trace(x._array, offset=offset, axis1=-2, axis2=-1, dtype=dtype))) # Note: vecdot is not in NumPy def vecdot(x1: Array, x2: Array, /, *, axis: int = -1) -> Array: From 00ca84141e708d382c5be2c3888c48a96a54244c Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 5 Jun 2023 19:28:24 -0600 Subject: [PATCH 13/15] Add dtype to the output of array_api finfo and iinfo --- numpy/array_api/_data_type_functions.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/numpy/array_api/_data_type_functions.py b/numpy/array_api/_data_type_functions.py index ede7f85d38d1..6f972c3b5424 100644 --- a/numpy/array_api/_data_type_functions.py +++ b/numpy/array_api/_data_type_functions.py @@ -90,6 +90,7 @@ class finfo_object: max: float min: float smallest_normal: float + dtype: Dtype @dataclass @@ -97,6 +98,7 @@ class iinfo_object: bits: int max: int min: int + dtype: Dtype def finfo(type: Union[Dtype, Array], /) -> finfo_object: @@ -114,6 +116,7 @@ def finfo(type: Union[Dtype, Array], /) -> finfo_object: float(fi.max), float(fi.min), float(fi.smallest_normal), + fi.dtype, ) @@ -124,7 +127,7 @@ def iinfo(type: Union[Dtype, Array], /) -> iinfo_object: See its docstring for more information. """ ii = np.iinfo(type) - return iinfo_object(ii.bits, ii.max, ii.min) + return iinfo_object(ii.bits, ii.max, ii.min, ii.dtype) # Note: isdtype is a new function from the 2022.12 array API specification. From ae528a014bf80c006d85ba1330bc1c0a155cc90e Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Tue, 6 Jun 2023 01:32:51 -0600 Subject: [PATCH 14/15] Add a note that we are skipping numpy.array_api.fft for now --- doc/release/upcoming_changes/23789.new_feature.rst | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/doc/release/upcoming_changes/23789.new_feature.rst b/doc/release/upcoming_changes/23789.new_feature.rst index 417b944f92e9..58158486c9f1 100644 --- a/doc/release/upcoming_changes/23789.new_feature.rst +++ b/doc/release/upcoming_changes/23789.new_feature.rst @@ -2,4 +2,6 @@ Array API v2022.12 support in ``numpy.array_api`` ------------------------------------------------- - ``numpy.array_api`` now full supports the `v2022.12 version - `__ of the array API standard. + `__ of the array API standard. Note + that this does not yet include the optional ``fft`` extension in the + standard. From 91153af22ffc07d2b7653d406d93e99d136e89d6 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Tue, 6 Jun 2023 13:20:09 -0600 Subject: [PATCH 15/15] Fix __complex__ type annotation Co-authored-by: Bas van Beek <43369155+BvB93@users.noreply.github.com> --- numpy/array_api/_array_object.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/numpy/array_api/_array_object.py b/numpy/array_api/_array_object.py index c7eb2a0a361f..ec465208e8b2 100644 --- a/numpy/array_api/_array_object.py +++ b/numpy/array_api/_array_object.py @@ -470,7 +470,7 @@ def __bool__(self: Array, /) -> bool: res = self._array.__bool__() return res - def __complex__(self: Array, /) -> float: + def __complex__(self: Array, /) -> complex: """ Performs the operation __complex__. """