From ca44b2a76ebb8bc449a789439563168acfc5c619 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Sat, 15 Jul 2023 14:21:35 -0500 Subject: [PATCH 1/2] Fix test_take to make axis optional when ndim == 1 I didn't explicitly test axis=None because it's not clear to me that should actually be supported, given that that's the same as axis=0. --- array_api_tests/test_indexing_functions.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/array_api_tests/test_indexing_functions.py b/array_api_tests/test_indexing_functions.py index 9f2cf319..4f6db45c 100644 --- a/array_api_tests/test_indexing_functions.py +++ b/array_api_tests/test_indexing_functions.py @@ -22,7 +22,14 @@ def test_take(x, data): # * negative axis # * negative indices # * different dtypes for indices - axis = data.draw(st.integers(0, max(x.ndim - 1, 0)), label="axis") + + # axis is optional but only if x.ndim == 1 + _axis_st = st.integers(0, max(x.ndim - 1, 0)) + if x.ndim == 1: + kw = data.draw(hh.kwargs(axis=_axis_st)) + else: + kw = {"axis": data.draw(_axis_st)} + axis = kw.get("axis", 0) _indices = data.draw( st.lists(st.integers(0, x.shape[axis] - 1), min_size=1, unique=True), label="_indices", @@ -30,7 +37,7 @@ def test_take(x, data): indices = xp.asarray(_indices, dtype=dh.default_int) note(f"{indices=}") - out = xp.take(x, indices, axis=axis) + out = xp.take(x, indices, **kw) ph.assert_dtype("take", in_dtype=x.dtype, out_dtype=out.dtype) ph.assert_shape( From f60471a713806228da5c986d1d7cd2fd197485ca Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Wed, 28 Feb 2024 13:48:34 +0000 Subject: [PATCH 2/2] Skip flaky `test_bitwise_xor` --- array_api_tests/test_operators_and_elementwise_functions.py | 1 + 1 file changed, 1 insertion(+) diff --git a/array_api_tests/test_operators_and_elementwise_functions.py b/array_api_tests/test_operators_and_elementwise_functions.py index 5e7f717c..4c5e8d79 100644 --- a/array_api_tests/test_operators_and_elementwise_functions.py +++ b/array_api_tests/test_operators_and_elementwise_functions.py @@ -905,6 +905,7 @@ def test_bitwise_right_shift(ctx, data): ) +@pytest.mark.skip("sometimes triggers hypothesis.errors.DeadlineExceeded") # TODO: fix! @pytest.mark.parametrize( "ctx", make_binary_params("bitwise_xor", dh.bool_and_all_int_dtypes) )