Skip to content

Commit

Permalink
Update to support version 2022.12 of the Array API
Browse files Browse the repository at this point in the history
  • Loading branch information
tomwhite committed Aug 21, 2023
1 parent f5ef6eb commit 553601e
Show file tree
Hide file tree
Showing 8 changed files with 405 additions and 58 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/array-api-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ jobs:
with:
repository: data-apis/array-api-tests
path: array-api-tests
ref: 2022.05.18
submodules: "true"
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v3
Expand Down Expand Up @@ -83,8 +82,9 @@ jobs:
# don't run special cases yet
array_api_tests/test_special_cases.py
# don't test signatures yet as some are not implemented
# don't test signatures or names yet as some are not implemented
array_api_tests/test_signatures.py
array_api_tests/test_has_names.py
# very slow
array_api_tests/test_creation_functions.py::test_eye
Expand Down
16 changes: 13 additions & 3 deletions cubed/array_api/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
__all__ = []

__array_api_version__ = "2021.12"
__array_api_version__ = "2022.12"

__all__ += ["__array_api_version__"]

Expand Down Expand Up @@ -48,12 +48,14 @@
"zeros_like",
]

from .data_type_functions import astype, can_cast, finfo, iinfo, result_type
from .data_type_functions import astype, can_cast, finfo, iinfo, isdtype, result_type

__all__ += ["astype", "can_cast", "finfo", "iinfo", "result_type"]
__all__ += ["astype", "can_cast", "finfo", "iinfo", "isdtype", "result_type"]

from .dtypes import (
bool,
complex64,
complex128,
float32,
float64,
int8,
Expand All @@ -68,6 +70,8 @@

__all__ += [
"bool",
"complex64",
"complex128",
"float32",
"float64",
"int8",
Expand Down Expand Up @@ -97,6 +101,7 @@
bitwise_right_shift,
bitwise_xor,
ceil,
conj,
cos,
cosh,
divide,
Expand All @@ -107,6 +112,7 @@
floor_divide,
greater,
greater_equal,
imag,
isfinite,
isinf,
isnan,
Expand All @@ -126,6 +132,7 @@
not_equal,
positive,
pow,
real,
remainder,
round,
sign,
Expand Down Expand Up @@ -156,6 +163,7 @@
"bitwise_right_shift",
"bitwise_xor",
"ceil",
"conj",
"cos",
"cosh",
"divide",
Expand All @@ -166,6 +174,7 @@
"floor_divide",
"greater",
"greater_equal",
"imag",
"isfinite",
"isinf",
"isnan",
Expand All @@ -185,6 +194,7 @@
"not_equal",
"positive",
"pow",
"real",
"remainder",
"round",
"sign",
Expand Down
29 changes: 26 additions & 3 deletions cubed/array_api/array_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@
from cubed.array_api.data_type_functions import result_type
from cubed.array_api.dtypes import (
_boolean_dtypes,
_complex_floating_dtypes,
_dtype_categories,
_floating_dtypes,
_integer_dtypes,
_integer_or_boolean_dtypes,
_numeric_dtypes,
)
Expand Down Expand Up @@ -144,13 +146,13 @@ def __truediv__(self, other, /):
return elemwise(np.divide, self, other, dtype=result_type(self, other))

def __floordiv__(self, other, /):
other = self._check_allowed_dtypes(other, "numeric", "__floordiv__")
other = self._check_allowed_dtypes(other, "real numeric", "__floordiv__")
if other is NotImplemented:
return other
return elemwise(np.floor_divide, self, other, dtype=result_type(self, other))

def __mod__(self, other, /):
other = self._check_allowed_dtypes(other, "numeric", "__mod__")
other = self._check_allowed_dtypes(other, "real numeric", "__mod__")
if other is NotImplemented:
return other
return elemwise(np.remainder, self, other, dtype=result_type(self, other))
Expand Down Expand Up @@ -349,9 +351,16 @@ def __bool__(self, /):
raise TypeError("bool is only allowed on arrays with 0 dimensions")
return bool(self.compute())

def __complex__(self, /):
if self.ndim != 0:
raise TypeError("complex is only allowed on arrays with 0 dimensions")
return complex(self.compute())

def __float__(self, /):
if self.ndim != 0:
raise TypeError("float is only allowed on arrays with 0 dimensions")
if self.dtype in _complex_floating_dtypes:
raise TypeError("float is not allowed on complex floating-point arrays")
return float(self.compute())

def __index__(self, /):
Expand All @@ -362,14 +371,16 @@ def __index__(self, /):
def __int__(self, /):
if self.ndim != 0:
raise TypeError("int is only allowed on arrays with 0 dimensions")
if self.dtype in _complex_floating_dtypes:
raise TypeError("int is not allowed on complex floating-point arrays")
return int(self.compute())

# Utility methods

def _check_allowed_dtypes(self, other, dtype_category, op):
if self.dtype not in _dtype_categories[dtype_category]:
raise TypeError(f"Only {dtype_category} dtypes are allowed in {op}")
if isinstance(other, (int, float, bool)):
if isinstance(other, (int, complex, float, bool)):
other = self._promote_scalar(other)
elif isinstance(other, CoreArray):
if other.dtype not in _dtype_categories[dtype_category]:
Expand All @@ -392,11 +403,23 @@ def _promote_scalar(self, scalar):
raise TypeError(
"Python int scalars cannot be promoted with bool arrays"
)
if self.dtype in _integer_dtypes:
info = np.iinfo(self.dtype)
if not (info.min <= scalar <= info.max):
raise OverflowError(
"Python int scalars must be within the bounds of the dtype for integer arrays"
)
# int + array(floating) is allowed
elif isinstance(scalar, float):
if self.dtype not in _floating_dtypes:
raise TypeError(
"Python float scalars can only be promoted with floating-point arrays."
)
elif isinstance(scalar, complex):
if self.dtype not in _complex_floating_dtypes:
raise TypeError(
"Python complex scalars can only be promoted with complex floating-point arrays."
)
else:
raise TypeError("'scalar' must be a Python scalar")
return asarray(scalar, dtype=self.dtype, spec=self.spec)
128 changes: 120 additions & 8 deletions cubed/array_api/data_type_functions.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,22 @@
import numpy.array_api as nxp
from dataclasses import dataclass

import numpy as np
from numpy.array_api._typing import Dtype

from cubed.core import CoreArray, map_blocks

from .dtypes import (
_all_dtypes,
_boolean_dtypes,
_complex_floating_dtypes,
_integer_dtypes,
_numeric_dtypes,
_real_floating_dtypes,
_result_type,
_signed_integer_dtypes,
_unsigned_integer_dtypes,
)


def astype(x, dtype, /, *, copy=True):
if not copy and dtype == x.dtype:
Expand All @@ -14,21 +29,118 @@ def _astype(a, astype_dtype):


def can_cast(from_, to, /):
# Copied from numpy.array_api
# TODO: replace with `nxp.can_cast` when NumPy 1.25 is widely used (e.g. in Xarray)

if isinstance(from_, CoreArray):
from_ = from_.dtype
return nxp.can_cast(from_, to)
elif from_ not in _all_dtypes:
raise TypeError(f"{from_=}, but should be an array_api array or dtype")
if to not in _all_dtypes:
raise TypeError(f"{to=}, but should be a dtype")
try:
# We promote `from_` and `to` together. We then check if the promoted
# dtype is `to`, which indicates if `from_` can (up)cast to `to`.
dtype = _result_type(from_, to)
return to == dtype
except TypeError:
# _result_type() raises if the dtypes don't promote together
return False


@dataclass
class finfo_object:
bits: int
eps: float
max: float
min: float
smallest_normal: float
dtype: Dtype


@dataclass
class iinfo_object:
bits: int
max: int
min: int
dtype: Dtype


def finfo(type, /):
return nxp.finfo(type)
# Copied from numpy.array_api
# TODO: replace with `nxp.finfo(type)` when NumPy 1.25 is widely used (e.g. in Xarray)

fi = np.finfo(type)
return finfo_object(
fi.bits,
float(fi.eps),
float(fi.max),
float(fi.min),
float(fi.smallest_normal),
fi.dtype,
)


def iinfo(type, /):
return nxp.iinfo(type)
# Copied from numpy.array_api
# TODO: replace with `nxp.iinfo(type)` when NumPy 1.25 is widely used (e.g. in Xarray)

ii = np.iinfo(type)
return iinfo_object(ii.bits, ii.max, ii.min, ii.dtype)


def isdtype(dtype, kind):
# Copied from numpy.array_api
# TODO: replace with `nxp.isdtype(dtype, kind)` when NumPy 1.25 is widely used (e.g. in Xarray)

if isinstance(kind, tuple):
# Disallow nested tuples
if any(isinstance(k, tuple) for k in kind):
raise TypeError("'kind' must be a dtype, str, or tuple of dtypes and strs")
return any(isdtype(dtype, k) for k in kind)
elif isinstance(kind, str):
if kind == "bool":
return dtype in _boolean_dtypes
elif kind == "signed integer":
return dtype in _signed_integer_dtypes
elif kind == "unsigned integer":
return dtype in _unsigned_integer_dtypes
elif kind == "integral":
return dtype in _integer_dtypes
elif kind == "real floating":
return dtype in _real_floating_dtypes
elif kind == "complex floating":
return dtype in _complex_floating_dtypes
elif kind == "numeric":
return dtype in _numeric_dtypes
else:
raise ValueError(f"Unrecognized data type kind: {kind!r}")
elif kind in _all_dtypes:
return dtype == kind
else:
raise TypeError(
f"'kind' must be a dtype, str, or tuple of dtypes and strs, not {type(kind).__name__}"
)


def result_type(*arrays_and_dtypes):
# Use numpy.array_api promotion rules (stricter than numpy)
return nxp.result_type(
*(a.dtype if isinstance(a, CoreArray) else a for a in arrays_and_dtypes)
)
# Copied from numpy.array_api
# TODO: replace with `nxp.result_type` when NumPy 1.25 is widely used (e.g. in Xarray)

A = []
for a in arrays_and_dtypes:
if isinstance(a, CoreArray):
a = a.dtype
elif isinstance(a, np.ndarray) or a not in _all_dtypes:
raise TypeError("result_type() inputs must be array_api arrays or dtypes")
A.append(a)

if len(A) == 0:
raise ValueError("at least one array or dtype is required")
elif len(A) == 1:
return A[0]
else:
t = A[0]
for t2 in A[1:]:
t = _result_type(t, t2)
return t
Loading

0 comments on commit 553601e

Please sign in to comment.