Skip to content

Commit

Permalink
Add initial array_api interface
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed May 24, 2023
1 parent 1831b3c commit 596164b
Show file tree
Hide file tree
Showing 6 changed files with 828 additions and 0 deletions.
131 changes: 131 additions & 0 deletions jax/experimental/array_api/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

__array_api_version__ = '2022.12'

from jax.experimental.array_api._constants import (
e as e,
inf as inf,
nan as nan,
newaxis as newaxis,
pi as pi,
)

from jax.experimental.array_api._creation_functions import (
arange as arange,
asarray as asarray,
empty as empty,
empty_like as empty_like,
eye as eye,
from_dlpack as from_dlpack,
full as full,
full_like as full_like,
linspace as linspace,
meshgrid as meshgrid,
ones as ones,
ones_like as ones_like,
tril as tril,
triu as triu,
zeros as zeros,
zeros_like as zeros_like,
)

from jax.experimental.array_api._data_type_functions import (
astype as astype,
can_cast as can_cast,
finfo as finfo,
iinfo as iinfo,
isdtype as isdtype,
result_type as result_type,
)

from jax.experimental.array_api._dtypes import (
bool as bool,
int8 as int8,
int16 as int16,
int32 as int32,
int64 as int64,
uint8 as uint8,
uint16 as uint16,
uint32 as uint32,
uint64 as uint64,
float32 as float32,
float64 as float64,
complex64 as complex64,
complex128 as complex128,
)

from jax.experimental.array_api._elementwise_functions import (
abs as abs,
acos as acos,
acosh as acosh,
add as add,
asin as asin,
asinh as asinh,
atan as atan,
atan2 as atan2,
atanh as atanh,
bitwise_and as bitwise_and,
bitwise_invert as bitwise_invert,
bitwise_left_shift as bitwise_left_shift,
bitwise_or as bitwise_or,
bitwise_right_shift as bitwise_right_shift,
bitwise_xor as bitwise_xor,
ceil as ceil,
conj as conj,
cos as cos,
cosh as cosh,
divide as divide,
equal as equal,
exp as exp,
expm1 as expm1,
floor as floor,
floor_divide as floor_divide,
greater as greater,
greater_equal as greater_equal,
imag as imag,
isfinite as isfinite,
isinf as isinf,
isnan as isnan,
jax as jax,
less as less,
less_equal as less_equal,
log as log,
log10 as log10,
log1p as log1p,
log2 as log2,
logaddexp as logaddexp,
logical_and as logical_and,
logical_not as logical_not,
logical_or as logical_or,
logical_xor as logical_xor,
multiply as multiply,
negative as negative,
not_equal as not_equal,
np as np,
positive as positive,
pow as pow,
real as real,
remainder as remainder,
round as round,
sign as sign,
sin as sin,
sinh as sinh,
sqrt as sqrt,
square as square,
subtract as subtract,
tan as tan,
tanh as tanh,
trunc as trunc,
)
7 changes: 7 additions & 0 deletions jax/experimental/array_api/_constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
import numpy as np

e = np.e
inf = np.inf
nan = np.nan
newaxis = np.newaxis
pi = np.pi
65 changes: 65 additions & 0 deletions jax/experimental/array_api/_creation_functions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import jax
import jax.numpy as jnp


def arange(start, /, stop=None, step=1, *, dtype=None, device=None):
return jax.device_put(jnp.arange(start, stop, step, dtype=dtype), device=device)

def asarray(obj, /, *, dtype=None, device=None, copy=None):
return jax.device_put(jnp.array(obj, dtype=dtype, copy=copy), device=device)

def empty(shape, *, dtype=None, device=None):
return jax.device_put(jnp.empty(shape, dtype=dtype), device=device)

def empty_like(x, /, *, dtype=None, device=None):
return jax.device_put(jnp.empty_like(x, dtype=dtype), device=device)

def eye(n_rows, n_cols=None, /, *, k=0, dtype=None, device=None):
return jax.device_put(jnp.eye(n_rows, n_cols, k=k, dtype=dtype), device=device)

def from_dlpack(x, /):
return jnp.from_dlpack(x)

def full(shape, fill_value, *, dtype=None, device=None):
return jax.device_put(jnp.full(shape, fill_value, dtype=dtype), device=device)

def full_like(x, /, fill_value, *, dtype=None, device=None):
return jax.device_put(jnp.full_like(x, fill_value=fill_value, dtype=dtype), device=device)

def linspace(start, stop, /, num, *, dtype=None, device=None, endpoint=True):
return jax.device_put(jnp.linspace(start, stop, num=num, dtype=dtype, endpoint=endpoint), device=device)

def meshgrid(*arrays, indexing='xy'):
return jnp.meshgrid(*arrays, indexing=indexing)

def ones(shape, *, dtype=None, device=None):
return jax.device_put(jnp.ones(shape, dtype=dtype), device=device)

def ones_like(x, /, *, dtype=None, device=None):
return jax.device_put(jnp.ones_like(x, dtype=dtype), device=device)

def tril(x, /, *, k=0):
return jnp.tril(x, k=k)

def triu(x, /, *, k=0):
return jnp.triu(x, k=k)

def zeros(shape, *, dtype=None, device=None):
return jax.device_put(jnp.zeros(shape, dtype=dtype), device=device)

def zeros_like(x, /, *, dtype=None, device=None):
return jax.device_put(jnp.zeros_like(x, dtype=dtype), device=device)
Loading

0 comments on commit 596164b

Please sign in to comment.