Skip to content

Commit

Permalink
Expose existing functions in array API namespace
Browse files Browse the repository at this point in the history
  • Loading branch information
Micky774 committed Apr 15, 2024
1 parent 2c85ca6 commit 8b93da1
Show file tree
Hide file tree
Showing 5 changed files with 55 additions and 1 deletion.
7 changes: 7 additions & 0 deletions jax/experimental/array_api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@
ceil as ceil,
clip as clip,
conj as conj,
copysign as copysign,
cos as cos,
cosh as cosh,
divide as divide,
Expand All @@ -139,6 +140,8 @@
logical_not as logical_not,
logical_or as logical_or,
logical_xor as logical_xor,
maximum as maximum,
minimum as minimum,
multiply as multiply,
negative as negative,
not_equal as not_equal,
Expand All @@ -148,6 +151,7 @@
remainder as remainder,
round as round,
sign as sign,
signbit as signbit,
sin as sin,
sinh as sinh,
sqrt as sqrt,
Expand All @@ -168,7 +172,9 @@
concat as concat,
expand_dims as expand_dims,
flip as flip,
moveaxis as moveaxis,
permute_dims as permute_dims,
repeat as repeat,
reshape as reshape,
roll as roll,
squeeze as squeeze,
Expand All @@ -179,6 +185,7 @@
argmax as argmax,
argmin as argmin,
nonzero as nonzero,
searchsorted as searchsorted,
where as where,
)

Expand Down
23 changes: 22 additions & 1 deletion jax/experimental/array_api/_elementwise_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
result_type as _result_type,
isdtype as _isdtype,
)
import numpy as np


def _promote_dtypes(name, *args):
Expand Down Expand Up @@ -148,6 +147,11 @@ def conj(x, /):
return jax.numpy.conj(x)


def copysign(x1, x2, /):
"""Composes a floating-point value with the magnitude of x1_i and the sign of x2_i for each element of the input array x1."""
return jax.numpy.copysign(x1, x2)


def cos(x, /):
"""Calculates an implementation-dependent approximation to the cosine for each element x_i of the input array x."""
x, = _promote_dtypes("cos", x)
Expand Down Expand Up @@ -300,6 +304,18 @@ def logical_xor(x1, x2, /):
return jax.numpy.logical_xor(x1, x2)


def maximum(x1, x2, /):
"""Computes the maximum value for each element x1_i of the input array x1 relative to the respective element x2_i of the input array x2."""
x1, x2 = _promote_dtypes("maximum", x1, x2)
return jax.numpy.maximum(x1, x2)


def minimum(x1, x2, /):
"""Computes the minimum value for each element x1_i of the input array x1 relative to the respective element x2_i of the input array x2."""
x1, x2 = _promote_dtypes("minimum", x1, x2)
return jax.numpy.minimum(x1, x2)


def multiply(x1, x2, /):
"""Calculates the product for each element x1_i of the input array x1 with the respective element x2_i of the input array x2."""
x1, x2 = _promote_dtypes("multiply", x1, x2)
Expand Down Expand Up @@ -356,6 +372,11 @@ def sign(x, /):
return jax.numpy.sign(x)


def signbit(x, /):
"""Determines whether the sign bit is set for each element x_i of the input array x."""
return jax.numpy.signbit(x)


def sin(x, /):
"""Calculates an implementation-dependent approximation to the sine for each element x_i of the input array x."""
x, = _promote_dtypes("sin", x)
Expand Down
10 changes: 10 additions & 0 deletions jax/experimental/array_api/_manipulation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,21 @@ def flip(x: Array, /, *, axis: int | tuple[int, ...] | None = None) -> Array:
return jax.numpy.flip(x, axis=axis)


def moveaxis(x: Array, source: int | tuple[int, ...], destination: int | tuple[int, ...], /) -> Array:
"""Moves array axes (dimensions) to new positions, while leaving other axes in their original positions."""
return jax.numpy.moveaxis(x, source, destination)


def permute_dims(x: Array, /, axes: tuple[int, ...]) -> Array:
"""Permutes the axes (dimensions) of an array x."""
return jax.numpy.permute_dims(x, axes=axes)


def repeat(x: Array, repeats: int | Array, /, *, axis: int | None = None) -> Array:
"""Repeats each element of an array a specified number of times on a per-element basis."""
return jax.numpy.repeat(x, repeats=repeats, axis=axis)


def reshape(x: Array, /, shape: tuple[int, ...], *, copy: bool | None = None) -> Array:
"""Reshapes an array without changing its data."""
del copy # unused
Expand Down
9 changes: 9 additions & 0 deletions jax/experimental/array_api/_searching_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,15 @@ def nonzero(x, /):
return jax.numpy.nonzero(x)


def searchsorted(x1, x2, /, *, side='left', sorter=None):
"""
Finds the indices into x1 such that, if the corresponding elements in x2
were inserted before the indices, the order of x1, when sorted in ascending
order, would be preserved.
"""
return jax.numpy.searchsorted(x1, x2, side=side, sorter=sorter)


def where(condition, x1, x2, /):
"""Returns elements chosen from x1 or x2 depending on condition."""
dtype = _result_type(x1, x2)
Expand Down
7 changes: 7 additions & 0 deletions tests/array_api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
'complex64',
'concat',
'conj',
'copysign',
'cos',
'cosh',
'divide',
Expand Down Expand Up @@ -115,9 +116,12 @@
'matmul',
'matrix_transpose',
'max',
'maximum',
'mean',
'meshgrid',
'min',
'minimum',
'moveaxis',
'multiply',
'nan',
'negative',
Expand All @@ -133,11 +137,14 @@
'prod',
'real',
'remainder',
'repeat',
'reshape',
'result_type',
'roll',
'round',
'searchsorted',
'sign',
'signbit',
'sin',
'sinh',
'sort',
Expand Down

0 comments on commit 8b93da1

Please sign in to comment.