Skip to content

Commit

Permalink
Merge pull request #23881 from asmeurer/array_api-2022
Browse files Browse the repository at this point in the history
ENH: Add array API standard v2022.12 support to numpy.array_api
  • Loading branch information
mattip committed Jun 14, 2023
2 parents 26edc98 + 91153af commit c178bac
Show file tree
Hide file tree
Showing 13 changed files with 332 additions and 108 deletions.
7 changes: 7 additions & 0 deletions doc/release/upcoming_changes/23789.new_feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
Array API v2022.12 support in ``numpy.array_api``
-------------------------------------------------

- ``numpy.array_api`` now full supports the `v2022.12 version
<https://data-apis.org/array-api/2022.12>`__ of the array API standard. Note
that this does not yet include the optional ``fft`` extension in the
standard.
8 changes: 7 additions & 1 deletion numpy/array_api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@
"The numpy.array_api submodule is still experimental. See NEP 47.", stacklevel=2
)

__array_api_version__ = "2021.12"
__array_api_version__ = "2022.12"

__all__ = ["__array_api_version__"]

Expand Down Expand Up @@ -173,6 +173,7 @@
broadcast_to,
can_cast,
finfo,
isdtype,
iinfo,
result_type,
)
Expand All @@ -198,6 +199,8 @@
uint64,
float32,
float64,
complex64,
complex128,
bool,
)

Expand Down Expand Up @@ -232,6 +235,7 @@
bitwise_right_shift,
bitwise_xor,
ceil,
conj,
cos,
cosh,
divide,
Expand All @@ -242,6 +246,7 @@
floor_divide,
greater,
greater_equal,
imag,
isfinite,
isinf,
isnan,
Expand All @@ -261,6 +266,7 @@
not_equal,
positive,
pow,
real,
remainder,
round,
sign,
Expand Down
55 changes: 38 additions & 17 deletions numpy/array_api/_array_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
_integer_dtypes,
_integer_or_boolean_dtypes,
_floating_dtypes,
_complex_floating_dtypes,
_numeric_dtypes,
_result_type,
_dtype_categories,
Expand Down Expand Up @@ -139,7 +140,7 @@ def _check_allowed_dtypes(self, other: bool | int | float | Array, dtype_categor

if self.dtype not in _dtype_categories[dtype_category]:
raise TypeError(f"Only {dtype_category} dtypes are allowed in {op}")
if isinstance(other, (int, float, bool)):
if isinstance(other, (int, complex, float, bool)):
other = self._promote_scalar(other)
elif isinstance(other, Array):
if other.dtype not in _dtype_categories[dtype_category]:
Expand Down Expand Up @@ -189,11 +190,23 @@ def _promote_scalar(self, scalar):
raise TypeError(
"Python int scalars cannot be promoted with bool arrays"
)
if self.dtype in _integer_dtypes:
info = np.iinfo(self.dtype)
if not (info.min <= scalar <= info.max):
raise OverflowError(
"Python int scalars must be within the bounds of the dtype for integer arrays"
)
# int + array(floating) is allowed
elif isinstance(scalar, float):
if self.dtype not in _floating_dtypes:
raise TypeError(
"Python float scalars can only be promoted with floating-point arrays."
)
elif isinstance(scalar, complex):
if self.dtype not in _complex_floating_dtypes:
raise TypeError(
"Python complex scalars can only be promoted with complex floating-point arrays."
)
else:
raise TypeError("'scalar' must be a Python scalar")

Expand Down Expand Up @@ -454,11 +467,19 @@ def __bool__(self: Array, /) -> bool:
# Note: This is an error here.
if self._array.ndim != 0:
raise TypeError("bool is only allowed on arrays with 0 dimensions")
if self.dtype not in _boolean_dtypes:
raise ValueError("bool is only allowed on boolean arrays")
res = self._array.__bool__()
return res

def __complex__(self: Array, /) -> complex:
"""
Performs the operation __complex__.
"""
# Note: This is an error here.
if self._array.ndim != 0:
raise TypeError("complex is only allowed on arrays with 0 dimensions")
res = self._array.__complex__()
return res

def __dlpack__(self: Array, /, *, stream: None = None) -> PyCapsule:
"""
Performs the operation __dlpack__.
Expand Down Expand Up @@ -492,16 +513,16 @@ def __float__(self: Array, /) -> float:
# Note: This is an error here.
if self._array.ndim != 0:
raise TypeError("float is only allowed on arrays with 0 dimensions")
if self.dtype not in _floating_dtypes:
raise ValueError("float is only allowed on floating-point arrays")
if self.dtype in _complex_floating_dtypes:
raise TypeError("float is not allowed on complex floating-point arrays")
res = self._array.__float__()
return res

def __floordiv__(self: Array, other: Union[int, float, Array], /) -> Array:
"""
Performs the operation __floordiv__.
"""
other = self._check_allowed_dtypes(other, "numeric", "__floordiv__")
other = self._check_allowed_dtypes(other, "real numeric", "__floordiv__")
if other is NotImplemented:
return other
self, other = self._normalize_two_args(self, other)
Expand All @@ -512,7 +533,7 @@ def __ge__(self: Array, other: Union[int, float, Array], /) -> Array:
"""
Performs the operation __ge__.
"""
other = self._check_allowed_dtypes(other, "numeric", "__ge__")
other = self._check_allowed_dtypes(other, "real numeric", "__ge__")
if other is NotImplemented:
return other
self, other = self._normalize_two_args(self, other)
Expand Down Expand Up @@ -542,7 +563,7 @@ def __gt__(self: Array, other: Union[int, float, Array], /) -> Array:
"""
Performs the operation __gt__.
"""
other = self._check_allowed_dtypes(other, "numeric", "__gt__")
other = self._check_allowed_dtypes(other, "real numeric", "__gt__")
if other is NotImplemented:
return other
self, other = self._normalize_two_args(self, other)
Expand All @@ -556,8 +577,8 @@ def __int__(self: Array, /) -> int:
# Note: This is an error here.
if self._array.ndim != 0:
raise TypeError("int is only allowed on arrays with 0 dimensions")
if self.dtype not in _integer_dtypes:
raise ValueError("int is only allowed on integer arrays")
if self.dtype in _complex_floating_dtypes:
raise TypeError("int is not allowed on complex floating-point arrays")
res = self._array.__int__()
return res

Expand All @@ -581,7 +602,7 @@ def __le__(self: Array, other: Union[int, float, Array], /) -> Array:
"""
Performs the operation __le__.
"""
other = self._check_allowed_dtypes(other, "numeric", "__le__")
other = self._check_allowed_dtypes(other, "real numeric", "__le__")
if other is NotImplemented:
return other
self, other = self._normalize_two_args(self, other)
Expand All @@ -603,7 +624,7 @@ def __lt__(self: Array, other: Union[int, float, Array], /) -> Array:
"""
Performs the operation __lt__.
"""
other = self._check_allowed_dtypes(other, "numeric", "__lt__")
other = self._check_allowed_dtypes(other, "real numeric", "__lt__")
if other is NotImplemented:
return other
self, other = self._normalize_two_args(self, other)
Expand All @@ -626,7 +647,7 @@ def __mod__(self: Array, other: Union[int, float, Array], /) -> Array:
"""
Performs the operation __mod__.
"""
other = self._check_allowed_dtypes(other, "numeric", "__mod__")
other = self._check_allowed_dtypes(other, "real numeric", "__mod__")
if other is NotImplemented:
return other
self, other = self._normalize_two_args(self, other)
Expand Down Expand Up @@ -808,7 +829,7 @@ def __ifloordiv__(self: Array, other: Union[int, float, Array], /) -> Array:
"""
Performs the operation __ifloordiv__.
"""
other = self._check_allowed_dtypes(other, "numeric", "__ifloordiv__")
other = self._check_allowed_dtypes(other, "real numeric", "__ifloordiv__")
if other is NotImplemented:
return other
self._array.__ifloordiv__(other._array)
Expand All @@ -818,7 +839,7 @@ def __rfloordiv__(self: Array, other: Union[int, float, Array], /) -> Array:
"""
Performs the operation __rfloordiv__.
"""
other = self._check_allowed_dtypes(other, "numeric", "__rfloordiv__")
other = self._check_allowed_dtypes(other, "real numeric", "__rfloordiv__")
if other is NotImplemented:
return other
self, other = self._normalize_two_args(self, other)
Expand Down Expand Up @@ -874,7 +895,7 @@ def __imod__(self: Array, other: Union[int, float, Array], /) -> Array:
"""
Performs the operation __imod__.
"""
other = self._check_allowed_dtypes(other, "numeric", "__imod__")
other = self._check_allowed_dtypes(other, "real numeric", "__imod__")
if other is NotImplemented:
return other
self._array.__imod__(other._array)
Expand All @@ -884,7 +905,7 @@ def __rmod__(self: Array, other: Union[int, float, Array], /) -> Array:
"""
Performs the operation __rmod__.
"""
other = self._check_allowed_dtypes(other, "numeric", "__rmod__")
other = self._check_allowed_dtypes(other, "real numeric", "__rmod__")
if other is NotImplemented:
return other
self, other = self._normalize_two_args(self, other)
Expand Down
55 changes: 53 additions & 2 deletions numpy/array_api/_data_type_functions.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,17 @@
from __future__ import annotations

from ._array_object import Array
from ._dtypes import _all_dtypes, _result_type
from ._dtypes import (
_all_dtypes,
_boolean_dtypes,
_signed_integer_dtypes,
_unsigned_integer_dtypes,
_integer_dtypes,
_real_floating_dtypes,
_complex_floating_dtypes,
_numeric_dtypes,
_result_type,
)

from dataclasses import dataclass
from typing import TYPE_CHECKING, List, Tuple, Union
Expand Down Expand Up @@ -80,13 +90,15 @@ class finfo_object:
max: float
min: float
smallest_normal: float
dtype: Dtype


@dataclass
class iinfo_object:
bits: int
max: int
min: int
dtype: Dtype


def finfo(type: Union[Dtype, Array], /) -> finfo_object:
Expand All @@ -104,6 +116,7 @@ def finfo(type: Union[Dtype, Array], /) -> finfo_object:
float(fi.max),
float(fi.min),
float(fi.smallest_normal),
fi.dtype,
)


Expand All @@ -114,9 +127,47 @@ def iinfo(type: Union[Dtype, Array], /) -> iinfo_object:
See its docstring for more information.
"""
ii = np.iinfo(type)
return iinfo_object(ii.bits, ii.max, ii.min)
return iinfo_object(ii.bits, ii.max, ii.min, ii.dtype)


# Note: isdtype is a new function from the 2022.12 array API specification.
def isdtype(
dtype: Dtype, kind: Union[Dtype, str, Tuple[Union[Dtype, str], ...]]
) -> bool:
"""
Returns a boolean indicating whether a provided dtype is of a specified data type ``kind``.
See
https://data-apis.org/array-api/latest/API_specification/generated/array_api.isdtype.html
for more details
"""
if isinstance(kind, tuple):
# Disallow nested tuples
if any(isinstance(k, tuple) for k in kind):
raise TypeError("'kind' must be a dtype, str, or tuple of dtypes and strs")
return any(isdtype(dtype, k) for k in kind)
elif isinstance(kind, str):
if kind == 'bool':
return dtype in _boolean_dtypes
elif kind == 'signed integer':
return dtype in _signed_integer_dtypes
elif kind == 'unsigned integer':
return dtype in _unsigned_integer_dtypes
elif kind == 'integral':
return dtype in _integer_dtypes
elif kind == 'real floating':
return dtype in _real_floating_dtypes
elif kind == 'complex floating':
return dtype in _complex_floating_dtypes
elif kind == 'numeric':
return dtype in _numeric_dtypes
else:
raise ValueError(f"Unrecognized data type kind: {kind!r}")
elif kind in _all_dtypes:
return dtype == kind
else:
raise TypeError(f"'kind' must be a dtype, str, or tuple of dtypes and strs, not {type(kind).__name__}")

def result_type(*arrays_and_dtypes: Union[Array, Dtype]) -> Dtype:
"""
Array API compatible wrapper for :py:func:`np.result_type <numpy.result_type>`.
Expand Down

0 comments on commit c178bac

Please sign in to comment.