Skip to content

Commit

Permalink
Top level array namespace (#473)
Browse files Browse the repository at this point in the history
* Add array API functions to top-level cubed namespace

* Update to 2022 version of array API

* Add test for array functions in top-level cubed namespace
  • Loading branch information
tomwhite committed Jun 6, 2024
1 parent 6040d35 commit 560ceb9
Show file tree
Hide file tree
Showing 3 changed files with 279 additions and 3 deletions.
271 changes: 269 additions & 2 deletions cubed/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
defaults=[{"spec": {"allowed_mem": 200_000_000, "reserved_mem": 100_000_000}}],
)

from .array_api import Array
from .core.array import compute, measure_reserved_mem, visualize
from .core.gufunc import apply_gufunc
from .core.ops import from_array, from_zarr, map_blocks, store, to_zarr
Expand All @@ -26,7 +25,6 @@
__all__ = [
"__version__",
"Callback",
"Array",
"Spec",
"TaskEndEvent",
"apply_gufunc",
Expand All @@ -44,3 +42,272 @@
"to_zarr",
"visualize",
]

# Array API

__array_api_version__ = "2022.12"

__all__ += ["__array_api_version__"]

from .array_api.array_object import Array

__all__ += ["Array"]

from .array_api.constants import e, inf, nan, newaxis, pi

__all__ += ["e", "inf", "nan", "newaxis", "pi"]

from .array_api.creation_functions import (
arange,
asarray,
empty,
empty_like,
eye,
full,
full_like,
linspace,
meshgrid,
ones,
ones_like,
tril,
triu,
zeros,
zeros_like,
)

__all__ += [
"arange",
"asarray",
"empty",
"empty_like",
"eye",
"full",
"full_like",
"linspace",
"meshgrid",
"ones",
"ones_like",
"tril",
"triu",
"zeros",
"zeros_like",
]

from .array_api.data_type_functions import (
astype,
can_cast,
finfo,
iinfo,
isdtype,
result_type,
)

__all__ += ["astype", "can_cast", "finfo", "iinfo", "isdtype", "result_type"]

from .array_api.dtypes import (
bool,
complex64,
complex128,
float32,
float64,
int8,
int16,
int32,
int64,
uint8,
uint16,
uint32,
uint64,
)

__all__ += [
"bool",
"complex64",
"complex128",
"float32",
"float64",
"int8",
"int16",
"int32",
"int64",
"uint8",
"uint16",
"uint32",
"uint64",
]

from .array_api.elementwise_functions import (
abs,
acos,
acosh,
add,
asin,
asinh,
atan,
atan2,
atanh,
bitwise_and,
bitwise_invert,
bitwise_left_shift,
bitwise_or,
bitwise_right_shift,
bitwise_xor,
ceil,
conj,
cos,
cosh,
divide,
equal,
exp,
expm1,
floor,
floor_divide,
greater,
greater_equal,
imag,
isfinite,
isinf,
isnan,
less,
less_equal,
log,
log1p,
log2,
log10,
logaddexp,
logical_and,
logical_not,
logical_or,
logical_xor,
multiply,
negative,
not_equal,
positive,
pow,
real,
remainder,
round,
sign,
sin,
sinh,
sqrt,
square,
subtract,
tan,
tanh,
trunc,
)

__all__ += [
"abs",
"acos",
"acosh",
"add",
"asin",
"asinh",
"atan",
"atan2",
"atanh",
"bitwise_and",
"bitwise_invert",
"bitwise_left_shift",
"bitwise_or",
"bitwise_right_shift",
"bitwise_xor",
"ceil",
"conj",
"cos",
"cosh",
"divide",
"equal",
"exp",
"expm1",
"floor",
"floor_divide",
"greater",
"greater_equal",
"imag",
"isfinite",
"isinf",
"isnan",
"less",
"less_equal",
"log",
"log1p",
"log2",
"log10",
"logaddexp",
"logical_and",
"logical_not",
"logical_or",
"logical_xor",
"multiply",
"negative",
"not_equal",
"positive",
"pow",
"real",
"remainder",
"round",
"sign",
"sin",
"sinh",
"sqrt",
"square",
"subtract",
"tan",
"tanh",
"trunc",
]

from .array_api.indexing_functions import take

__all__ += ["take"]

from .array_api.linear_algebra_functions import (
matmul,
matrix_transpose,
outer,
tensordot,
vecdot,
)

__all__ += ["matmul", "matrix_transpose", "outer", "tensordot", "vecdot"]

from .array_api.manipulation_functions import (
broadcast_arrays,
broadcast_to,
concat,
expand_dims,
moveaxis,
permute_dims,
reshape,
roll,
squeeze,
stack,
)

__all__ += [
"broadcast_arrays",
"broadcast_to",
"concat",
"expand_dims",
"moveaxis",
"permute_dims",
"reshape",
"roll",
"squeeze",
"stack",
]

from .array_api.searching_functions import argmax, argmin, where

__all__ += ["argmax", "argmin", "where"]

from .array_api.statistical_functions import max, mean, min, prod, sum

__all__ += ["max", "mean", "min", "prod", "sum"]

from .array_api.utility_functions import all, any

__all__ += ["all", "any"]
2 changes: 1 addition & 1 deletion cubed/array_api/array_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,7 @@ def __abs__(self, /):
return elemwise(nxp.abs, self, dtype=dtype)

def __array_namespace__(self, /, *, api_version=None):
if api_version is not None and not api_version.startswith("2021."):
if api_version is not None and not api_version.startswith("2022."):
raise ValueError(f"Unrecognized array API version: {api_version!r}")
import cubed.array_api as array_api

Expand Down
9 changes: 9 additions & 0 deletions cubed/tests/test_array_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,15 @@ def test_add(spec, any_executor):
)


def test_add_top_level_namespace(spec, executor):
a = cubed.ones((10, 10), chunks=(10, 2), spec=spec)
b = cubed.ones((10, 10), chunks=(2, 10), spec=spec)
c = cubed.add(a, b)
assert_array_equal(
c.compute(executor=executor), np.ones((10, 10)) + np.ones((10, 10))
)


def test_add_different_chunks(spec, executor):
a = xp.ones((10, 10), chunks=(10, 2), spec=spec)
b = xp.ones((10, 10), chunks=(2, 10), spec=spec)
Expand Down

0 comments on commit 560ceb9

Please sign in to comment.