Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: Add array API standard v2022.12 support to numpy.array_api #23881

Merged
merged 15 commits into from
Jun 14, 2023
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions doc/release/upcoming_changes/23789.new_feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
Array API v2022.12 support in ``numpy.array_api``
-------------------------------------------------

- ``numpy.array_api`` now full supports the `v2022.12 version
<https://data-apis.org/array-api/2022.12>`__ of the array API standard. Note
that this does not yet include the optional ``fft`` extension in the
standard.
8 changes: 7 additions & 1 deletion numpy/array_api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@
"The numpy.array_api submodule is still experimental. See NEP 47.", stacklevel=2
)

__array_api_version__ = "2021.12"
__array_api_version__ = "2022.12"

__all__ = ["__array_api_version__"]

Expand Down Expand Up @@ -173,6 +173,7 @@
broadcast_to,
can_cast,
finfo,
isdtype,
iinfo,
result_type,
)
Expand All @@ -198,6 +199,8 @@
uint64,
float32,
float64,
complex64,
complex128,
bool,
)

Expand Down Expand Up @@ -232,6 +235,7 @@
bitwise_right_shift,
bitwise_xor,
ceil,
conj,
cos,
cosh,
divide,
Expand All @@ -242,6 +246,7 @@
floor_divide,
greater,
greater_equal,
imag,
isfinite,
isinf,
isnan,
Expand All @@ -261,6 +266,7 @@
not_equal,
positive,
pow,
real,
remainder,
round,
sign,
Expand Down
55 changes: 38 additions & 17 deletions numpy/array_api/_array_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
_integer_dtypes,
_integer_or_boolean_dtypes,
_floating_dtypes,
_complex_floating_dtypes,
_numeric_dtypes,
_result_type,
_dtype_categories,
Expand Down Expand Up @@ -139,7 +140,7 @@ def _check_allowed_dtypes(self, other: bool | int | float | Array, dtype_categor

if self.dtype not in _dtype_categories[dtype_category]:
raise TypeError(f"Only {dtype_category} dtypes are allowed in {op}")
if isinstance(other, (int, float, bool)):
if isinstance(other, (int, complex, float, bool)):
other = self._promote_scalar(other)
elif isinstance(other, Array):
if other.dtype not in _dtype_categories[dtype_category]:
Expand Down Expand Up @@ -189,11 +190,23 @@ def _promote_scalar(self, scalar):
raise TypeError(
"Python int scalars cannot be promoted with bool arrays"
)
if self.dtype in _integer_dtypes:
info = np.iinfo(self.dtype)
if not (info.min <= scalar <= info.max):
raise OverflowError(
"Python int scalars must be within the bounds of the dtype for integer arrays"
)
# int + array(floating) is allowed
elif isinstance(scalar, float):
if self.dtype not in _floating_dtypes:
raise TypeError(
"Python float scalars can only be promoted with floating-point arrays."
)
elif isinstance(scalar, complex):
if self.dtype not in _complex_floating_dtypes:
raise TypeError(
"Python complex scalars can only be promoted with complex floating-point arrays."
)
else:
raise TypeError("'scalar' must be a Python scalar")

Expand Down Expand Up @@ -454,11 +467,19 @@ def __bool__(self: Array, /) -> bool:
# Note: This is an error here.
if self._array.ndim != 0:
raise TypeError("bool is only allowed on arrays with 0 dimensions")
if self.dtype not in _boolean_dtypes:
raise ValueError("bool is only allowed on boolean arrays")
res = self._array.__bool__()
return res

def __complex__(self: Array, /) -> float:
asmeurer marked this conversation as resolved.
Show resolved Hide resolved
"""
Performs the operation __complex__.
"""
# Note: This is an error here.
if self._array.ndim != 0:
raise TypeError("complex is only allowed on arrays with 0 dimensions")
res = self._array.__complex__()
return res

def __dlpack__(self: Array, /, *, stream: None = None) -> PyCapsule:
"""
Performs the operation __dlpack__.
Expand Down Expand Up @@ -492,16 +513,16 @@ def __float__(self: Array, /) -> float:
# Note: This is an error here.
if self._array.ndim != 0:
raise TypeError("float is only allowed on arrays with 0 dimensions")
if self.dtype not in _floating_dtypes:
raise ValueError("float is only allowed on floating-point arrays")
if self.dtype in _complex_floating_dtypes:
raise TypeError("float is not allowed on complex floating-point arrays")
res = self._array.__float__()
return res

def __floordiv__(self: Array, other: Union[int, float, Array], /) -> Array:
"""
Performs the operation __floordiv__.
"""
other = self._check_allowed_dtypes(other, "numeric", "__floordiv__")
other = self._check_allowed_dtypes(other, "real numeric", "__floordiv__")
if other is NotImplemented:
return other
self, other = self._normalize_two_args(self, other)
Expand All @@ -512,7 +533,7 @@ def __ge__(self: Array, other: Union[int, float, Array], /) -> Array:
"""
Performs the operation __ge__.
"""
other = self._check_allowed_dtypes(other, "numeric", "__ge__")
other = self._check_allowed_dtypes(other, "real numeric", "__ge__")
if other is NotImplemented:
return other
self, other = self._normalize_two_args(self, other)
Expand Down Expand Up @@ -542,7 +563,7 @@ def __gt__(self: Array, other: Union[int, float, Array], /) -> Array:
"""
Performs the operation __gt__.
"""
other = self._check_allowed_dtypes(other, "numeric", "__gt__")
other = self._check_allowed_dtypes(other, "real numeric", "__gt__")
if other is NotImplemented:
return other
self, other = self._normalize_two_args(self, other)
Expand All @@ -556,8 +577,8 @@ def __int__(self: Array, /) -> int:
# Note: This is an error here.
if self._array.ndim != 0:
raise TypeError("int is only allowed on arrays with 0 dimensions")
if self.dtype not in _integer_dtypes:
raise ValueError("int is only allowed on integer arrays")
if self.dtype in _complex_floating_dtypes:
raise TypeError("int is not allowed on complex floating-point arrays")
res = self._array.__int__()
return res

Expand All @@ -581,7 +602,7 @@ def __le__(self: Array, other: Union[int, float, Array], /) -> Array:
"""
Performs the operation __le__.
"""
other = self._check_allowed_dtypes(other, "numeric", "__le__")
other = self._check_allowed_dtypes(other, "real numeric", "__le__")
if other is NotImplemented:
return other
self, other = self._normalize_two_args(self, other)
Expand All @@ -603,7 +624,7 @@ def __lt__(self: Array, other: Union[int, float, Array], /) -> Array:
"""
Performs the operation __lt__.
"""
other = self._check_allowed_dtypes(other, "numeric", "__lt__")
other = self._check_allowed_dtypes(other, "real numeric", "__lt__")
if other is NotImplemented:
return other
self, other = self._normalize_two_args(self, other)
Expand All @@ -626,7 +647,7 @@ def __mod__(self: Array, other: Union[int, float, Array], /) -> Array:
"""
Performs the operation __mod__.
"""
other = self._check_allowed_dtypes(other, "numeric", "__mod__")
other = self._check_allowed_dtypes(other, "real numeric", "__mod__")
if other is NotImplemented:
return other
self, other = self._normalize_two_args(self, other)
Expand Down Expand Up @@ -808,7 +829,7 @@ def __ifloordiv__(self: Array, other: Union[int, float, Array], /) -> Array:
"""
Performs the operation __ifloordiv__.
"""
other = self._check_allowed_dtypes(other, "numeric", "__ifloordiv__")
other = self._check_allowed_dtypes(other, "real numeric", "__ifloordiv__")
if other is NotImplemented:
return other
self._array.__ifloordiv__(other._array)
Expand All @@ -818,7 +839,7 @@ def __rfloordiv__(self: Array, other: Union[int, float, Array], /) -> Array:
"""
Performs the operation __rfloordiv__.
"""
other = self._check_allowed_dtypes(other, "numeric", "__rfloordiv__")
other = self._check_allowed_dtypes(other, "real numeric", "__rfloordiv__")
if other is NotImplemented:
return other
self, other = self._normalize_two_args(self, other)
Expand Down Expand Up @@ -874,7 +895,7 @@ def __imod__(self: Array, other: Union[int, float, Array], /) -> Array:
"""
Performs the operation __imod__.
"""
other = self._check_allowed_dtypes(other, "numeric", "__imod__")
other = self._check_allowed_dtypes(other, "real numeric", "__imod__")
if other is NotImplemented:
return other
self._array.__imod__(other._array)
Expand All @@ -884,7 +905,7 @@ def __rmod__(self: Array, other: Union[int, float, Array], /) -> Array:
"""
Performs the operation __rmod__.
"""
other = self._check_allowed_dtypes(other, "numeric", "__rmod__")
other = self._check_allowed_dtypes(other, "real numeric", "__rmod__")
if other is NotImplemented:
return other
self, other = self._normalize_two_args(self, other)
Expand Down
55 changes: 53 additions & 2 deletions numpy/array_api/_data_type_functions.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,17 @@
from __future__ import annotations

from ._array_object import Array
from ._dtypes import _all_dtypes, _result_type
from ._dtypes import (
_all_dtypes,
_boolean_dtypes,
_signed_integer_dtypes,
_unsigned_integer_dtypes,
_integer_dtypes,
_real_floating_dtypes,
_complex_floating_dtypes,
_numeric_dtypes,
_result_type,
)

from dataclasses import dataclass
from typing import TYPE_CHECKING, List, Tuple, Union
Expand Down Expand Up @@ -80,13 +90,15 @@ class finfo_object:
max: float
min: float
smallest_normal: float
dtype: Dtype


@dataclass
class iinfo_object:
bits: int
max: int
min: int
dtype: Dtype


def finfo(type: Union[Dtype, Array], /) -> finfo_object:
Expand All @@ -104,6 +116,7 @@ def finfo(type: Union[Dtype, Array], /) -> finfo_object:
float(fi.max),
float(fi.min),
float(fi.smallest_normal),
fi.dtype,
)


Expand All @@ -114,9 +127,47 @@ def iinfo(type: Union[Dtype, Array], /) -> iinfo_object:
See its docstring for more information.
"""
ii = np.iinfo(type)
return iinfo_object(ii.bits, ii.max, ii.min)
return iinfo_object(ii.bits, ii.max, ii.min, ii.dtype)


# Note: isdtype is a new function from the 2022.12 array API specification.
def isdtype(
dtype: Dtype, kind: Union[Dtype, str, Tuple[Union[Dtype, str], ...]]
) -> bool:
"""
Returns a boolean indicating whether a provided dtype is of a specified data type ``kind``.

See
https://data-apis.org/array-api/latest/API_specification/generated/array_api.isdtype.html
for more details
"""
if isinstance(kind, tuple):
# Disallow nested tuples
if any(isinstance(k, tuple) for k in kind):
raise TypeError("'kind' must be a dtype, str, or tuple of dtypes and strs")
return any(isdtype(dtype, k) for k in kind)
elif isinstance(kind, str):
if kind == 'bool':
return dtype in _boolean_dtypes
elif kind == 'signed integer':
return dtype in _signed_integer_dtypes
elif kind == 'unsigned integer':
return dtype in _unsigned_integer_dtypes
elif kind == 'integral':
return dtype in _integer_dtypes
elif kind == 'real floating':
return dtype in _real_floating_dtypes
elif kind == 'complex floating':
return dtype in _complex_floating_dtypes
elif kind == 'numeric':
return dtype in _numeric_dtypes
else:
raise ValueError(f"Unrecognized data type kind: {kind!r}")
elif kind in _all_dtypes:
return dtype == kind
else:
raise TypeError(f"'kind' must be a dtype, str, or tuple of dtypes and strs, not {type(kind).__name__}")

def result_type(*arrays_and_dtypes: Union[Array, Dtype]) -> Dtype:
"""
Array API compatible wrapper for :py:func:`np.result_type <numpy.result_type>`.
Expand Down
Loading
Loading