Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Initial implementation of the Python Array API standard #16099

Merged
merged 1 commit into from
Nov 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 49 additions & 0 deletions .github/workflows/jax-array-api.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
name: JAX Array API

on:
workflow_dispatch: # allows triggering the workflow run manually
pull_request: # Automatically trigger on pull requests affecting particular files
branches:
- main
paths:
- '**workflows/jax-array-api.yml'
- '**experimental/array_api/**'

jobs:
build:

runs-on: ubuntu-latest
strategy:
matrix:
python-version: [3.11]

steps:
- name: Checkout jax
uses: actions/checkout@v3
- name: Checkout array-api-tests
uses: actions/checkout@v3
with:
repository: data-apis/array-api-tests
ref: '83f0bcdcc5286250dbb26be5d37511702970b4dc' # Latest commit as of 2023-11-15
submodules: 'true'
path: 'array-api-tests'
- name: Fix array-apis bug
# Temporary workaround for https://github.com/data-apis/array-api/issues/631
run: |
sed -i -e 's/\\/\\\\/g' array-api-tests/array-api/spec/API_specification/signatures/*.py
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v1
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install .[cpu]
python -m pip install "hypothesis<6.88.4" # 6.88.4 breaks with a Return-type annotation warning
python -m pip install -r array-api-tests/requirements.txt
- name: Run the test suite
env:
ARRAY_API_TESTS_MODULE: jax.experimental.array_api
JAX_ENABLE_X64: 'true'
run: |
cd ${GITHUB_WORKSPACE}/array-api-tests
pytest --ci array_api_tests --max-examples=5 --derandomize --disable-deadline --skips-file ${GITHUB_WORKSPACE}/jax/experimental/array_api/skips.txt
4 changes: 4 additions & 0 deletions docs/jax.experimental.array_api.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
``jax.experimental.array_api`` module
=====================================

.. automodule:: jax.experimental.array_api
1 change: 1 addition & 0 deletions docs/jax.experimental.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ Experimental Modules
.. toctree::
:maxdepth: 1

jax.experimental.array_api
jax.experimental.checkify
jax.experimental.host_callback
jax.experimental.maps
Expand Down
13 changes: 12 additions & 1 deletion jax/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -841,6 +841,17 @@ pytype_library(
deps = [":jax"],
)

pytype_library(
name = "experimental_array_api",
srcs = glob(
[
"experimental/array_api/*.py",
],
),
visibility = [":internal"],
deps = [":jax"],
)

pytype_library(
name = "experimental_sparse",
srcs = glob(
Expand Down Expand Up @@ -873,7 +884,7 @@ pytype_library(
"example_libraries/optimizers.py",
],
visibility = ["//visibility:public"],
deps = [":jax"],
deps = [":jax"] + py_deps("numpy"),
)

pytype_library(
Expand Down
220 changes: 220 additions & 0 deletions jax/experimental/array_api/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,220 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This module includes experimental JAX support for the `Python array API standard`_.
Support for this is currently experimental and not fully complete.

Example Usage::

>>> from jax.experimental import array_api as xp

>>> xp.__array_api_version__
'2022.12'

>>> arr = xp.arange(1000)

>>> arr.sum()
Array(499500, dtype=int32)

The ``xp`` namespace is the array API compliant analog of :mod:`jax.numpy`,
and implements most of the API listed in the standard.

.. _Python array API standard: https://data-apis.org/array-api/latest/
"""

from __future__ import annotations

from jax.experimental.array_api._version import __array_api_version__ as __array_api_version__

from jax.experimental.array_api import (
fft as fft,
linalg as linalg,
)

from jax.experimental.array_api._constants import (
e as e,
inf as inf,
nan as nan,
newaxis as newaxis,
pi as pi,
)

from jax.experimental.array_api._creation_functions import (
arange as arange,
asarray as asarray,
empty as empty,
empty_like as empty_like,
eye as eye,
from_dlpack as from_dlpack,
full as full,
full_like as full_like,
linspace as linspace,
meshgrid as meshgrid,
ones as ones,
ones_like as ones_like,
tril as tril,
triu as triu,
zeros as zeros,
zeros_like as zeros_like,
)

from jax.experimental.array_api._data_type_functions import (
astype as astype,
can_cast as can_cast,
finfo as finfo,
iinfo as iinfo,
isdtype as isdtype,
result_type as result_type,
)

from jax.experimental.array_api._dtypes import (
bool as bool,
int8 as int8,
int16 as int16,
int32 as int32,
int64 as int64,
uint8 as uint8,
uint16 as uint16,
uint32 as uint32,
uint64 as uint64,
float32 as float32,
float64 as float64,
complex64 as complex64,
complex128 as complex128,
)

from jax.experimental.array_api._elementwise_functions import (
abs as abs,
acos as acos,
acosh as acosh,
add as add,
asin as asin,
asinh as asinh,
atan as atan,
atan2 as atan2,
atanh as atanh,
bitwise_and as bitwise_and,
bitwise_invert as bitwise_invert,
bitwise_left_shift as bitwise_left_shift,
bitwise_or as bitwise_or,
bitwise_right_shift as bitwise_right_shift,
bitwise_xor as bitwise_xor,
ceil as ceil,
conj as conj,
cos as cos,
cosh as cosh,
divide as divide,
equal as equal,
exp as exp,
expm1 as expm1,
floor as floor,
floor_divide as floor_divide,
greater as greater,
greater_equal as greater_equal,
imag as imag,
isfinite as isfinite,
isinf as isinf,
isnan as isnan,
less as less,
less_equal as less_equal,
log as log,
log10 as log10,
log1p as log1p,
log2 as log2,
logaddexp as logaddexp,
logical_and as logical_and,
logical_not as logical_not,
logical_or as logical_or,
logical_xor as logical_xor,
multiply as multiply,
negative as negative,
not_equal as not_equal,
positive as positive,
pow as pow,
real as real,
remainder as remainder,
round as round,
sign as sign,
sin as sin,
sinh as sinh,
sqrt as sqrt,
square as square,
subtract as subtract,
tan as tan,
tanh as tanh,
trunc as trunc,
)

from jax.experimental.array_api._indexing_functions import (
take as take,
)

from jax.experimental.array_api._manipulation_functions import (
broadcast_arrays as broadcast_arrays,
broadcast_to as broadcast_to,
concat as concat,
expand_dims as expand_dims,
flip as flip,
permute_dims as permute_dims,
reshape as reshape,
roll as roll,
squeeze as squeeze,
stack as stack,
)

from jax.experimental.array_api._searching_functions import (
argmax as argmax,
argmin as argmin,
nonzero as nonzero,
where as where,
)

from jax.experimental.array_api._set_functions import (
unique_all as unique_all,
unique_counts as unique_counts,
unique_inverse as unique_inverse,
unique_values as unique_values,
)

from jax.experimental.array_api._sorting_functions import (
argsort as argsort,
sort as sort,
)

from jax.experimental.array_api._statistical_functions import (
max as max,
mean as mean,
min as min,
prod as prod,
std as std,
sum as sum,
var as var
)

from jax.experimental.array_api._utility_functions import (
all as all,
any as any,
)

from jax.experimental.array_api._linear_algebra_functions import (
matmul as matmul,
matrix_transpose as matrix_transpose,
tensordot as tensordot,
vecdot as vecdot,
)

from jax.experimental.array_api import _array_methods
_array_methods.add_array_object_methods()
del _array_methods
45 changes: 45 additions & 0 deletions jax/experimental/array_api/_array_methods.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import Any, Callable, Optional, Union

import jax
from jax._src.array import ArrayImpl
from jax.experimental.array_api._version import __array_api_version__

from jax._src.lib import xla_extension as xe


def _array_namespace(self, /, *, api_version: None | str = None):
if api_version is not None and api_version != __array_api_version__:
raise ValueError(f"{api_version=!r} is not available; "
f"available versions are: {[__array_api_version__]}")
return jax.experimental.array_api


def _to_device(self, device: xe.Device | Callable[[], xe.Device], /, *,
stream: Optional[Union[int, Any]] = None):
if stream is not None:
raise NotImplementedError("stream argument of array.to_device()")
# The type of device is defined by Array.device. In JAX, this is a callable that
# returns a device, so we must handle this case to satisfy the API spec.
return jax.device_put(self, device() if callable(device) else device)


def add_array_object_methods():
# TODO(jakevdp): set on tracers as well?
setattr(ArrayImpl, "__array_namespace__", _array_namespace)
setattr(ArrayImpl, "to_device", _to_device)
21 changes: 21 additions & 0 deletions jax/experimental/array_api/_constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np

e = np.e
inf = np.inf
nan = np.nan
newaxis = np.newaxis
pi = np.pi
Loading
Loading