Skip to content

Commit

Permalink
[array-api] add simple smoketest target for standard CI testing
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Nov 27, 2023
1 parent 5274ca9 commit 0d073a4
Show file tree
Hide file tree
Showing 2 changed files with 247 additions and 0 deletions.
9 changes: 9 additions & 0 deletions tests/BUILD
Expand Up @@ -48,6 +48,15 @@ jax_test(
srcs = ["api_util_test.py"],
)

py_test(
name = "array_api_test",
srcs = ["array_api_test.py"],
deps = [
"//jax",
"//jax:experimental_array_api",
] + py_deps("absl/testing"),
)

jax_test(
name = "array_interoperability_test",
srcs = ["array_interoperability_test.py"],
Expand Down
238 changes: 238 additions & 0 deletions tests/array_api_test.py
@@ -0,0 +1,238 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Smoketest for jax.experimental.array_api
The full test suite for the array API is run via the array-api-tests CI;
this is just a minimal smoke test to catch issues early.
"""
from __future__ import annotations

from types import ModuleType

from absl.testing import absltest
import jax
from jax import config
from jax.experimental import array_api

config.parse_flags_with_absl()

MAIN_NAMESPACE = {
'abs',
'acos',
'acosh',
'add',
'all',
'annotations',
'any',
'arange',
'argmax',
'argmin',
'argsort',
'asarray',
'asin',
'asinh',
'astype',
'atan',
'atan2',
'atanh',
'bitwise_and',
'bitwise_invert',
'bitwise_left_shift',
'bitwise_or',
'bitwise_right_shift',
'bitwise_xor',
'bool',
'broadcast_arrays',
'broadcast_to',
'can_cast',
'ceil',
'complex128',
'complex64',
'concat',
'conj',
'cos',
'cosh',
'divide',
'e',
'empty',
'empty_like',
'equal',
'exp',
'expand_dims',
'expm1',
'eye',
'fft',
'finfo',
'flip',
'float32',
'float64',
'floor',
'floor_divide',
'from_dlpack',
'full',
'full_like',
'greater',
'greater_equal',
'iinfo',
'imag',
'inf',
'int16',
'int32',
'int64',
'int8',
'isdtype',
'isfinite',
'isinf',
'isnan',
'less',
'less_equal',
'linalg',
'linspace',
'log',
'log10',
'log1p',
'log2',
'logaddexp',
'logical_and',
'logical_not',
'logical_or',
'logical_xor',
'matmul',
'matrix_transpose',
'max',
'mean',
'meshgrid',
'min',
'multiply',
'nan',
'negative',
'newaxis',
'nonzero',
'not_equal',
'ones',
'ones_like',
'permute_dims',
'pi',
'positive',
'pow',
'prod',
'real',
'remainder',
'reshape',
'result_type',
'roll',
'round',
'sign',
'sin',
'sinh',
'sort',
'sqrt',
'square',
'squeeze',
'stack',
'std',
'subtract',
'sum',
'take',
'tan',
'tanh',
'tensordot',
'tril',
'triu',
'trunc',
'uint16',
'uint32',
'uint64',
'uint8',
'unique_all',
'unique_counts',
'unique_inverse',
'unique_values',
'var',
'vecdot',
'where',
'zeros',
'zeros_like',
}

LINALG_NAMESPACE = {
'cholesky',
'cross',
'det',
'diagonal',
'eigh',
'eigvalsh',
'inv',
'jax',
'matmul',
'matrix_norm',
'matrix_power',
'matrix_rank',
'matrix_transpose',
'outer',
'pinv',
'qr',
'slogdet',
'solve',
'svd',
'svdvals',
'tensordot',
'trace',
'vecdot',
'vector_norm',
}

FFT_NAMESPACE = {
'fft',
'fftfreq',
'fftn',
'fftshift',
'hfft',
'ifft',
'ifftn',
'ifftshift',
'ihfft',
'irfft',
'irfftn',
'rfft',
'rfftfreq',
'rfftn',
}


def names(module: ModuleType) -> set[str]:
return {name for name in dir(module) if not name.startswith('_')}


class ArrayAPISmokeTest(absltest.TestCase):
"""Smoke test for the array API."""

def test_main_namespace(self):
self.assertSetEqual(names(array_api), MAIN_NAMESPACE)

def test_linalg_namespace(self):
self.assertSetEqual(names(array_api.linalg), LINALG_NAMESPACE)

def test_fft_namespace(self):
self.assertSetEqual(names(array_api.fft), FFT_NAMESPACE)

def test_array_namespace_method(self):
x = array_api.arange(20)
self.assertIsInstance(x, jax.Array)
self.assertIs(x.__array_namespace__(), array_api)


if __name__ == '__main__':
absltest.main()

0 comments on commit 0d073a4

Please sign in to comment.