Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support version 2022.12 of the Array API #293

Merged
merged 3 commits into from
Aug 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/array-api-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ jobs:
with:
repository: data-apis/array-api-tests
path: array-api-tests
ref: 2022.05.18
submodules: "true"
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v3
Expand Down Expand Up @@ -83,8 +82,9 @@ jobs:
# don't run special cases yet
array_api_tests/test_special_cases.py

# don't test signatures yet as some are not implemented
# don't test signatures or names yet as some are not implemented
array_api_tests/test_signatures.py
array_api_tests/test_has_names.py

# very slow
array_api_tests/test_creation_functions.py::test_eye
Expand Down
16 changes: 13 additions & 3 deletions cubed/array_api/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
__all__ = []

__array_api_version__ = "2021.12"
__array_api_version__ = "2022.12"

__all__ += ["__array_api_version__"]

Expand Down Expand Up @@ -48,12 +48,14 @@
"zeros_like",
]

from .data_type_functions import astype, can_cast, finfo, iinfo, result_type
from .data_type_functions import astype, can_cast, finfo, iinfo, isdtype, result_type

__all__ += ["astype", "can_cast", "finfo", "iinfo", "result_type"]
__all__ += ["astype", "can_cast", "finfo", "iinfo", "isdtype", "result_type"]

from .dtypes import (
bool,
complex64,
complex128,
float32,
float64,
int8,
Expand All @@ -68,6 +70,8 @@

__all__ += [
"bool",
"complex64",
"complex128",
"float32",
"float64",
"int8",
Expand Down Expand Up @@ -97,6 +101,7 @@
bitwise_right_shift,
bitwise_xor,
ceil,
conj,
cos,
cosh,
divide,
Expand All @@ -107,6 +112,7 @@
floor_divide,
greater,
greater_equal,
imag,
isfinite,
isinf,
isnan,
Expand All @@ -126,6 +132,7 @@
not_equal,
positive,
pow,
real,
remainder,
round,
sign,
Expand Down Expand Up @@ -156,6 +163,7 @@
"bitwise_right_shift",
"bitwise_xor",
"ceil",
"conj",
"cos",
"cosh",
"divide",
Expand All @@ -166,6 +174,7 @@
"floor_divide",
"greater",
"greater_equal",
"imag",
"isfinite",
"isinf",
"isnan",
Expand All @@ -185,6 +194,7 @@
"not_equal",
"positive",
"pow",
"real",
"remainder",
"round",
"sign",
Expand Down
29 changes: 26 additions & 3 deletions cubed/array_api/array_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@
from cubed.array_api.data_type_functions import result_type
from cubed.array_api.dtypes import (
_boolean_dtypes,
_complex_floating_dtypes,
_dtype_categories,
_floating_dtypes,
_integer_dtypes,
_integer_or_boolean_dtypes,
_numeric_dtypes,
)
Expand Down Expand Up @@ -144,13 +146,13 @@ def __truediv__(self, other, /):
return elemwise(np.divide, self, other, dtype=result_type(self, other))

def __floordiv__(self, other, /):
other = self._check_allowed_dtypes(other, "numeric", "__floordiv__")
other = self._check_allowed_dtypes(other, "real numeric", "__floordiv__")
if other is NotImplemented:
return other
return elemwise(np.floor_divide, self, other, dtype=result_type(self, other))

def __mod__(self, other, /):
other = self._check_allowed_dtypes(other, "numeric", "__mod__")
other = self._check_allowed_dtypes(other, "real numeric", "__mod__")
if other is NotImplemented:
return other
return elemwise(np.remainder, self, other, dtype=result_type(self, other))
Expand Down Expand Up @@ -349,9 +351,16 @@ def __bool__(self, /):
raise TypeError("bool is only allowed on arrays with 0 dimensions")
return bool(self.compute())

def __complex__(self, /):
if self.ndim != 0:
raise TypeError("complex is only allowed on arrays with 0 dimensions")
return complex(self.compute())

def __float__(self, /):
if self.ndim != 0:
raise TypeError("float is only allowed on arrays with 0 dimensions")
if self.dtype in _complex_floating_dtypes:
raise TypeError("float is not allowed on complex floating-point arrays")
return float(self.compute())

def __index__(self, /):
Expand All @@ -362,14 +371,16 @@ def __index__(self, /):
def __int__(self, /):
if self.ndim != 0:
raise TypeError("int is only allowed on arrays with 0 dimensions")
if self.dtype in _complex_floating_dtypes:
raise TypeError("int is not allowed on complex floating-point arrays")
return int(self.compute())

# Utility methods

def _check_allowed_dtypes(self, other, dtype_category, op):
if self.dtype not in _dtype_categories[dtype_category]:
raise TypeError(f"Only {dtype_category} dtypes are allowed in {op}")
if isinstance(other, (int, float, bool)):
if isinstance(other, (int, complex, float, bool)):
other = self._promote_scalar(other)
elif isinstance(other, CoreArray):
if other.dtype not in _dtype_categories[dtype_category]:
Expand All @@ -392,11 +403,23 @@ def _promote_scalar(self, scalar):
raise TypeError(
"Python int scalars cannot be promoted with bool arrays"
)
if self.dtype in _integer_dtypes:
info = np.iinfo(self.dtype)
if not (info.min <= scalar <= info.max):
raise OverflowError(
"Python int scalars must be within the bounds of the dtype for integer arrays"
)
# int + array(floating) is allowed
elif isinstance(scalar, float):
if self.dtype not in _floating_dtypes:
raise TypeError(
"Python float scalars can only be promoted with floating-point arrays."
)
elif isinstance(scalar, complex):
if self.dtype not in _complex_floating_dtypes:
raise TypeError(
"Python complex scalars can only be promoted with complex floating-point arrays."
)
else:
raise TypeError("'scalar' must be a Python scalar")
return asarray(scalar, dtype=self.dtype, spec=self.spec)
7 changes: 4 additions & 3 deletions cubed/array_api/creation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,12 +144,13 @@ def full(
# note that write_empty_chunks=False means no chunks are written to disk, so it is very efficient to create large arrays
shape = normalize_shape(shape)
if dtype is None:
if isinstance(fill_value, int):
# check bool first since True/False are instances of int and float
if isinstance(fill_value, bool):
dtype = np.bool_
elif isinstance(fill_value, int):
dtype = np.int64
elif isinstance(fill_value, float):
dtype = np.float64
elif isinstance(fill_value, bool):
dtype = np.bool_
else:
raise TypeError("Invalid input to full")
chunksize = to_chunksize(normalize_chunks(chunks, shape=shape, dtype=dtype))
Expand Down
128 changes: 120 additions & 8 deletions cubed/array_api/data_type_functions.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,22 @@
import numpy.array_api as nxp
from dataclasses import dataclass

import numpy as np
from numpy.array_api._typing import Dtype

from cubed.core import CoreArray, map_blocks

from .dtypes import (
_all_dtypes,
_boolean_dtypes,
_complex_floating_dtypes,
_integer_dtypes,
_numeric_dtypes,
_real_floating_dtypes,
_result_type,
_signed_integer_dtypes,
_unsigned_integer_dtypes,
)


def astype(x, dtype, /, *, copy=True):
if not copy and dtype == x.dtype:
Expand All @@ -14,21 +29,118 @@ def _astype(a, astype_dtype):


def can_cast(from_, to, /):
# Copied from numpy.array_api
# TODO: replace with `nxp.can_cast` when NumPy 1.25 is widely used (e.g. in Xarray)

if isinstance(from_, CoreArray):
from_ = from_.dtype
return nxp.can_cast(from_, to)
elif from_ not in _all_dtypes:
raise TypeError(f"{from_=}, but should be an array_api array or dtype")
if to not in _all_dtypes:
raise TypeError(f"{to=}, but should be a dtype")
try:
# We promote `from_` and `to` together. We then check if the promoted
# dtype is `to`, which indicates if `from_` can (up)cast to `to`.
dtype = _result_type(from_, to)
return to == dtype
except TypeError:
# _result_type() raises if the dtypes don't promote together
return False


@dataclass
class finfo_object:
bits: int
eps: float
max: float
min: float
smallest_normal: float
dtype: Dtype


@dataclass
class iinfo_object:
bits: int
max: int
min: int
dtype: Dtype


def finfo(type, /):
return nxp.finfo(type)
# Copied from numpy.array_api
# TODO: replace with `nxp.finfo(type)` when NumPy 1.25 is widely used (e.g. in Xarray)

fi = np.finfo(type)
return finfo_object(
fi.bits,
float(fi.eps),
float(fi.max),
float(fi.min),
float(fi.smallest_normal),
fi.dtype,
)


def iinfo(type, /):
return nxp.iinfo(type)
# Copied from numpy.array_api
# TODO: replace with `nxp.iinfo(type)` when NumPy 1.25 is widely used (e.g. in Xarray)

ii = np.iinfo(type)
return iinfo_object(ii.bits, ii.max, ii.min, ii.dtype)


def isdtype(dtype, kind):
# Copied from numpy.array_api
# TODO: replace with `nxp.isdtype(dtype, kind)` when NumPy 1.25 is widely used (e.g. in Xarray)

if isinstance(kind, tuple):
# Disallow nested tuples
if any(isinstance(k, tuple) for k in kind):
raise TypeError("'kind' must be a dtype, str, or tuple of dtypes and strs")
return any(isdtype(dtype, k) for k in kind)
elif isinstance(kind, str):
if kind == "bool":
return dtype in _boolean_dtypes
elif kind == "signed integer":
return dtype in _signed_integer_dtypes
elif kind == "unsigned integer":
return dtype in _unsigned_integer_dtypes
elif kind == "integral":
return dtype in _integer_dtypes
elif kind == "real floating":
return dtype in _real_floating_dtypes
elif kind == "complex floating":
return dtype in _complex_floating_dtypes
elif kind == "numeric":
return dtype in _numeric_dtypes
else:
raise ValueError(f"Unrecognized data type kind: {kind!r}")
elif kind in _all_dtypes:
return dtype == kind
else:
raise TypeError(
f"'kind' must be a dtype, str, or tuple of dtypes and strs, not {type(kind).__name__}"
)


def result_type(*arrays_and_dtypes):
# Use numpy.array_api promotion rules (stricter than numpy)
return nxp.result_type(
*(a.dtype if isinstance(a, CoreArray) else a for a in arrays_and_dtypes)
)
# Copied from numpy.array_api
# TODO: replace with `nxp.result_type` when NumPy 1.25 is widely used (e.g. in Xarray)

A = []
for a in arrays_and_dtypes:
if isinstance(a, CoreArray):
a = a.dtype
elif isinstance(a, np.ndarray) or a not in _all_dtypes:
raise TypeError("result_type() inputs must be array_api arrays or dtypes")
A.append(a)

if len(A) == 0:
raise ValueError("at least one array or dtype is required")
elif len(A) == 1:
return A[0]
else:
t = A[0]
for t2 in A[1:]:
t = _result_type(t, t2)
return t
Loading