Skip to content

Commit

Permalink
Generate oneway promotable dtypes for elwise/op tests
Browse files Browse the repository at this point in the history
  • Loading branch information
honno committed Feb 3, 2022
1 parent 3364e48 commit 263b764
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 33 deletions.
9 changes: 8 additions & 1 deletion array_api_tests/meta/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@
from hypothesis import strategies as st

from .. import _array_module as xp
from .. import dtype_helpers as dh
from .. import shape_helpers as sh
from .. import xps
from ..test_creation_functions import frange
from ..test_manipulation_functions import roll_ndindex
from ..test_operators_and_elementwise_functions import (
mock_int_dtype,
oneway_broadcastable_shapes,
oneway_promotable_dtypes,
)
from ..test_signatures import extension_module

Expand Down Expand Up @@ -120,6 +122,11 @@ def test_int_to_dtype(x, dtype):
assert mock_int_dtype(x, dtype) == d


@given(oneway_promotable_dtypes(dh.all_dtypes))
def test_oneway_promotable_dtypes(D):
assert D.result_dtype == dh.result_type(*D)


@given(oneway_broadcastable_shapes())
def test_oneway_broadcastable_shapes(S):
assert sh.broadcast_shapes(*S) == S.result_shape
assert S.result_shape == sh.broadcast_shapes(*S)
83 changes: 51 additions & 32 deletions array_api_tests/test_operators_and_elementwise_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,26 @@ def boolean_and_all_integer_dtypes() -> st.SearchStrategy[DataType]:
return xps.boolean_dtypes() | all_integer_dtypes()


class OnewayPromotableDtypes(NamedTuple):
input_dtype: DataType
result_dtype: DataType


@st.composite
def oneway_promotable_dtypes(
draw, dtypes: List[DataType]
) -> st.SearchStrategy[OnewayPromotableDtypes]:
"""Return a strategy for input dtypes that promote to result dtypes."""
d1, d2 = draw(hh.mutually_promotable_dtypes(dtypes=dtypes))
result_dtype = dh.result_type(d1, d2)
if d1 == result_dtype:
return OnewayPromotableDtypes(d2, d1)
elif d2 == result_dtype:
return OnewayPromotableDtypes(d1, d2)
else:
reject()


class OnewayBroadcastableShapes(NamedTuple):
input_shape: Shape
result_shape: Shape
Expand Down Expand Up @@ -326,8 +346,14 @@ def __repr__(self):


def make_binary_params(
elwise_func_name: str, dtypes_strat: st.SearchStrategy[DataType]
elwise_func_name: str, dtypes: List[DataType]
) -> List[Param[BinaryParamContext]]:
if hh.FILTER_UNDEFINED_DTYPES:
dtypes = [d for d in dtypes if not isinstance(d, xp._UndefinedStub)]
shared_oneway_dtypes = st.shared(oneway_promotable_dtypes(dtypes))
left_dtypes = shared_oneway_dtypes.map(lambda D: D.result_dtype)
right_dtypes = shared_oneway_dtypes.map(lambda D: D.input_dtype)

def make_param(
func_name: str, func_type: FuncType, right_is_scalar: bool
) -> Param[BinaryParamContext]:
Expand All @@ -338,32 +364,29 @@ def make_param(
left_sym = "x1"
right_sym = "x2"

shared_dtypes = st.shared(dtypes_strat)
if right_is_scalar:
left_strat = xps.arrays(dtype=shared_dtypes, shape=hh.shapes(**shapes_kw))
right_strat = shared_dtypes.flatmap(
lambda d: xps.from_dtype(d, **finite_kw)
)
left_strat = xps.arrays(dtype=left_dtypes, shape=hh.shapes(**shapes_kw))
right_strat = right_dtypes.flatmap(lambda d: xps.from_dtype(d, **finite_kw))
else:
if func_type is FuncType.IOP:
shared_oneway_shapes = st.shared(oneway_broadcastable_shapes())
left_strat = xps.arrays(
dtype=shared_dtypes,
dtype=left_dtypes,
shape=shared_oneway_shapes.map(lambda S: S.result_shape),
)
right_strat = xps.arrays(
dtype=shared_dtypes,
dtype=right_dtypes,
shape=shared_oneway_shapes.map(lambda S: S.input_shape),
)
else:
mutual_shapes = st.shared(
hh.mutually_broadcastable_shapes(2, **shapes_kw)
)
left_strat = xps.arrays(
dtype=shared_dtypes, shape=mutual_shapes.map(lambda pair: pair[0])
dtype=left_dtypes, shape=mutual_shapes.map(lambda pair: pair[0])
)
right_strat = xps.arrays(
dtype=shared_dtypes, shape=mutual_shapes.map(lambda pair: pair[1])
dtype=right_dtypes, shape=mutual_shapes.map(lambda pair: pair[1])
)

if func_type is FuncType.FUNC:
Expand Down Expand Up @@ -540,7 +563,7 @@ def test_acosh(x):
)


@pytest.mark.parametrize("ctx,", make_binary_params("add", xps.numeric_dtypes()))
@pytest.mark.parametrize("ctx,", make_binary_params("add", dh.numeric_dtypes))
@given(data=st.data())
def test_add(ctx, data):
left = data.draw(ctx.left_strat, label=ctx.left_sym)
Expand Down Expand Up @@ -605,7 +628,7 @@ def test_atanh(x):


@pytest.mark.parametrize(
"ctx", make_binary_params("bitwise_and", boolean_and_all_integer_dtypes())
"ctx", make_binary_params("bitwise_and", dh.bool_and_all_int_dtypes)
)
@given(data=st.data())
def test_bitwise_and(ctx, data):
Expand All @@ -624,7 +647,7 @@ def test_bitwise_and(ctx, data):


@pytest.mark.parametrize(
"ctx", make_binary_params("bitwise_left_shift", all_integer_dtypes())
"ctx", make_binary_params("bitwise_left_shift", dh.all_int_dtypes)
)
@given(data=st.data())
def test_bitwise_left_shift(ctx, data):
Expand Down Expand Up @@ -664,7 +687,7 @@ def test_bitwise_invert(ctx, data):


@pytest.mark.parametrize(
"ctx", make_binary_params("bitwise_or", boolean_and_all_integer_dtypes())
"ctx", make_binary_params("bitwise_or", dh.bool_and_all_int_dtypes)
)
@given(data=st.data())
def test_bitwise_or(ctx, data):
Expand All @@ -683,7 +706,7 @@ def test_bitwise_or(ctx, data):


@pytest.mark.parametrize(
"ctx", make_binary_params("bitwise_right_shift", all_integer_dtypes())
"ctx", make_binary_params("bitwise_right_shift", dh.all_int_dtypes)
)
@given(data=st.data())
def test_bitwise_right_shift(ctx, data):
Expand All @@ -704,7 +727,7 @@ def test_bitwise_right_shift(ctx, data):


@pytest.mark.parametrize(
"ctx", make_binary_params("bitwise_xor", boolean_and_all_integer_dtypes())
"ctx", make_binary_params("bitwise_xor", dh.bool_and_all_int_dtypes)
)
@given(data=st.data())
def test_bitwise_xor(ctx, data):
Expand Down Expand Up @@ -746,7 +769,7 @@ def test_cosh(x):
unary_assert_against_refimpl("cosh", x, out, math.cosh)


@pytest.mark.parametrize("ctx", make_binary_params("divide", xps.floating_dtypes()))
@pytest.mark.parametrize("ctx", make_binary_params("divide", dh.float_dtypes))
@given(data=st.data())
def test_divide(ctx, data):
left = data.draw(ctx.left_strat, label=ctx.left_sym)
Expand All @@ -769,7 +792,7 @@ def test_divide(ctx, data):
)


@pytest.mark.parametrize("ctx", make_binary_params("equal", xps.scalar_dtypes()))
@pytest.mark.parametrize("ctx", make_binary_params("equal", dh.all_dtypes))
@given(data=st.data())
def test_equal(ctx, data):
left = data.draw(ctx.left_strat, label=ctx.left_sym)
Expand Down Expand Up @@ -821,9 +844,7 @@ def test_floor(x):
unary_assert_against_refimpl("floor", x, out, math.floor, strict_check=True)


@pytest.mark.parametrize(
"ctx", make_binary_params("floor_divide", xps.numeric_dtypes())
)
@pytest.mark.parametrize("ctx", make_binary_params("floor_divide", dh.numeric_dtypes))
@given(data=st.data())
def test_floor_divide(ctx, data):
left = data.draw(
Expand All @@ -842,7 +863,7 @@ def test_floor_divide(ctx, data):
binary_param_assert_against_refimpl(ctx, left, right, res, "//", operator.floordiv)


@pytest.mark.parametrize("ctx", make_binary_params("greater", xps.numeric_dtypes()))
@pytest.mark.parametrize("ctx", make_binary_params("greater", dh.numeric_dtypes))
@given(data=st.data())
def test_greater(ctx, data):
left = data.draw(ctx.left_strat, label=ctx.left_sym)
Expand All @@ -862,9 +883,7 @@ def test_greater(ctx, data):
)


@pytest.mark.parametrize(
"ctx", make_binary_params("greater_equal", xps.numeric_dtypes())
)
@pytest.mark.parametrize("ctx", make_binary_params("greater_equal", dh.numeric_dtypes))
@given(data=st.data())
def test_greater_equal(ctx, data):
left = data.draw(ctx.left_strat, label=ctx.left_sym)
Expand Down Expand Up @@ -908,7 +927,7 @@ def test_isnan(x):
unary_assert_against_refimpl("isnan", x, out, math.isnan, res_stype=bool)


@pytest.mark.parametrize("ctx", make_binary_params("less", xps.numeric_dtypes()))
@pytest.mark.parametrize("ctx", make_binary_params("less", dh.numeric_dtypes))
@given(data=st.data())
def test_less(ctx, data):
left = data.draw(ctx.left_strat, label=ctx.left_sym)
Expand All @@ -928,7 +947,7 @@ def test_less(ctx, data):
)


@pytest.mark.parametrize("ctx", make_binary_params("less_equal", xps.numeric_dtypes()))
@pytest.mark.parametrize("ctx", make_binary_params("less_equal", dh.numeric_dtypes))
@given(data=st.data())
def test_less_equal(ctx, data):
left = data.draw(ctx.left_strat, label=ctx.left_sym)
Expand Down Expand Up @@ -1040,7 +1059,7 @@ def test_logical_xor(x1, x2):
)


@pytest.mark.parametrize("ctx", make_binary_params("multiply", xps.numeric_dtypes()))
@pytest.mark.parametrize("ctx", make_binary_params("multiply", dh.numeric_dtypes))
@given(data=st.data())
def test_multiply(ctx, data):
left = data.draw(ctx.left_strat, label=ctx.left_sym)
Expand Down Expand Up @@ -1073,7 +1092,7 @@ def test_negative(ctx, data):
)


@pytest.mark.parametrize("ctx", make_binary_params("not_equal", xps.scalar_dtypes()))
@pytest.mark.parametrize("ctx", make_binary_params("not_equal", dh.all_dtypes))
@given(data=st.data())
def test_not_equal(ctx, data):
left = data.draw(ctx.left_strat, label=ctx.left_sym)
Expand Down Expand Up @@ -1105,7 +1124,7 @@ def test_positive(ctx, data):
ph.assert_array(ctx.func_name, out, x)


@pytest.mark.parametrize("ctx", make_binary_params("pow", xps.numeric_dtypes()))
@pytest.mark.parametrize("ctx", make_binary_params("pow", dh.numeric_dtypes))
@given(data=st.data())
def test_pow(ctx, data):
left = data.draw(ctx.left_strat, label=ctx.left_sym)
Expand All @@ -1129,7 +1148,7 @@ def test_pow(ctx, data):
)


@pytest.mark.parametrize("ctx", make_binary_params("remainder", xps.numeric_dtypes()))
@pytest.mark.parametrize("ctx", make_binary_params("remainder", dh.numeric_dtypes))
@given(data=st.data())
def test_remainder(ctx, data):
left = data.draw(ctx.left_strat, label=ctx.left_sym)
Expand Down Expand Up @@ -1200,7 +1219,7 @@ def test_sqrt(x):
)


@pytest.mark.parametrize("ctx", make_binary_params("subtract", xps.numeric_dtypes()))
@pytest.mark.parametrize("ctx", make_binary_params("subtract", dh.numeric_dtypes))
@given(data=st.data())
def test_subtract(ctx, data):
left = data.draw(ctx.left_strat, label=ctx.left_sym)
Expand Down

0 comments on commit 263b764

Please sign in to comment.