Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BUG: Fix integer / float scalar promotion #23148

Merged
merged 3 commits into from Feb 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
25 changes: 9 additions & 16 deletions numpy/core/src/umath/scalarmath.c.src
Expand Up @@ -1179,6 +1179,11 @@ convert_to_@name@(PyObject *value, @type@ *result, npy_bool *may_need_deferring)
* (Half, Float, Double, LongDouble,
* CFloat, CDouble, CLongDouble)*4,
* (Half, Float, Double, LongDouble)*3#
* #NAME = (BYTE, UBYTE, SHORT, USHORT, INT, UINT,
* LONG, ULONG, LONGLONG, ULONGLONG)*12,
* (HALF, FLOAT, DOUBLE, LONGDOUBLE,
* CFLOAT, CDOUBLE, CLONGDOUBLE)*4,
* (HALF, FLOAT, DOUBLE, LONGDOUBLE)*3#
* #type = (npy_byte, npy_ubyte, npy_short, npy_ushort, npy_int, npy_uint,
* npy_long, npy_ulong, npy_longlong, npy_ulonglong)*12,
* (npy_half, npy_float, npy_double, npy_longdouble,
Expand All @@ -1202,24 +1207,12 @@ convert_to_@name@(PyObject *value, @type@ *result, npy_bool *may_need_deferring)
* (npy_half, npy_float, npy_double, npy_longdouble,
* npy_cfloat, npy_cdouble, npy_clongdouble)*4,
* (npy_half, npy_float, npy_double, npy_longdouble)*3#
* #oname = (byte, ubyte, short, ushort, int, uint,
* long, ulong, longlong, ulonglong)*11,
* double*10,
* (half, float, double, longdouble,
* cfloat, cdouble, clongdouble)*4,
* (half, float, double, longdouble)*3#
* #OName = (Byte, UByte, Short, UShort, Int, UInt,
* Long, ULong, LongLong, ULongLong)*11,
* Double*10,
* (Half, Float, Double, LongDouble,
* CFloat, CDouble, CLongDouble)*4,
* (Half, Float, Double, LongDouble)*3#
* #ONAME = (BYTE, UBYTE, SHORT, USHORT, INT, UINT,
* LONG, ULONG, LONGLONG, ULONGLONG)*11,
* DOUBLE*10,
* (HALF, FLOAT, DOUBLE, LONGDOUBLE,
* CFLOAT, CDOUBLE, CLONGDOUBLE)*4,
* (HALF, FLOAT, DOUBLE, LONGDOUBLE)*3#
*/
#define IS_@name@
/* drop the "true_" from "true_divide" for floating point warnings: */
Expand All @@ -1234,7 +1227,7 @@ static PyObject *
@name@_@oper@(PyObject *a, PyObject *b)
{
PyObject *ret;
@otype@ arg1, arg2, other_val;
@type@ arg1, arg2, other_val;

/*
* Check if this operation may be considered forward. Note `is_forward`
Expand Down Expand Up @@ -1263,7 +1256,7 @@ static PyObject *
PyObject *other = is_forward ? b : a;

npy_bool may_need_deferring;
conversion_result res = convert_to_@oname@(
conversion_result res = convert_to_@name@(
other, &other_val, &may_need_deferring);
if (res == CONVERSION_ERROR) {
return NULL; /* an error occurred (should never happen) */
Expand Down Expand Up @@ -1305,7 +1298,7 @@ static PyObject *
*/
return PyGenericArrType_Type.tp_as_number->nb_@oper@(a,b);
case CONVERT_PYSCALAR:
if (@ONAME@_setitem(other, (char *)&other_val, NULL) < 0) {
if (@NAME@_setitem(other, (char *)&other_val, NULL) < 0) {
return NULL;
}
break;
Expand Down Expand Up @@ -1345,7 +1338,7 @@ static PyObject *
#if @twoout@
int retstatus = @name@_ctype_@oper@(arg1, arg2, &out, &out2);
#else
int retstatus = @oname@_ctype_@oper@(arg1, arg2, &out);
int retstatus = @name@_ctype_@oper@(arg1, arg2, &out);
#endif

#if @fperr@
Expand Down
59 changes: 47 additions & 12 deletions numpy/core/tests/test_scalarmath.py
Expand Up @@ -75,17 +75,7 @@ def test_leak(self):
np.add(1, 1)


@pytest.mark.slow
@settings(max_examples=10000, deadline=2000)
@given(sampled_from(reasonable_operators_for_scalars),
hynp.arrays(dtype=hynp.scalar_dtypes(), shape=()),
hynp.arrays(dtype=hynp.scalar_dtypes(), shape=()))
def test_array_scalar_ufunc_equivalence(op, arr1, arr2):
"""
This is a thorough test attempting to cover important promotion paths
and ensuring that arrays and scalars stay as aligned as possible.
However, if it creates troubles, it should maybe just be removed.
"""
def check_ufunc_scalar_equivalence(op, arr1, arr2):
scalar1 = arr1[()]
scalar2 = arr2[()]
assert isinstance(scalar1, np.generic)
Expand All @@ -95,6 +85,11 @@ def test_array_scalar_ufunc_equivalence(op, arr1, arr2):
comp_ops = {operator.ge, operator.gt, operator.le, operator.lt}
if op in comp_ops and (np.isnan(scalar1) or np.isnan(scalar2)):
pytest.xfail("complex comp ufuncs use sort-order, scalars do not.")
if op == operator.pow and arr2.item() in [-1, 0, 0.5, 1, 2]:
# array**scalar special case can have different result dtype
# (Other powers may have issues also, but are not hit here.)
# TODO: It would be nice to resolve this issue.
pytest.skip("array**2 can have incorrect/weird result dtype")

# ignore fpe's since they may just mismatch for integers anyway.
with warnings.catch_warnings(), np.errstate(all="ignore"):
Expand All @@ -107,7 +102,47 @@ def test_array_scalar_ufunc_equivalence(op, arr1, arr2):
op(scalar1, scalar2)
else:
scalar_res = op(scalar1, scalar2)
assert_array_equal(scalar_res, res)
assert_array_equal(scalar_res, res, strict=True)


@pytest.mark.slow
@settings(max_examples=10000, deadline=2000)
@given(sampled_from(reasonable_operators_for_scalars),
hynp.arrays(dtype=hynp.scalar_dtypes(), shape=()),
hynp.arrays(dtype=hynp.scalar_dtypes(), shape=()))
def test_array_scalar_ufunc_equivalence(op, arr1, arr2):
"""
This is a thorough test attempting to cover important promotion paths
and ensuring that arrays and scalars stay as aligned as possible.
However, if it creates troubles, it should maybe just be removed.
"""
check_ufunc_scalar_equivalence(op, arr1, arr2)


@pytest.mark.slow
@given(sampled_from(reasonable_operators_for_scalars),
hynp.scalar_dtypes(), hynp.scalar_dtypes())
def test_array_scalar_ufunc_dtypes(op, dt1, dt2):
# Same as above, but don't worry about sampling weird values so that we
# do not have to sample as much
arr1 = np.array(2, dtype=dt1)
arr2 = np.array(3, dtype=dt2) # some power do weird things.

check_ufunc_scalar_equivalence(op, arr1, arr2)


@pytest.mark.parametrize("fscalar", [np.float16, np.float32])
def test_int_float_promotion_truediv(fscalar):
# Promotion for mixed int and float32/float16 must not go to float64
i = np.int8(1)
f = fscalar(1)
expected = np.result_type(i, f)
assert (i / f).dtype == expected
assert (f / i).dtype == expected
# But normal int / int true division goes to float64:
assert (i / i).dtype == np.dtype("float64")
# For int16, result has to be ast least float32 (takes ufunc path):
assert (np.int16(1) / f).dtype == np.dtype("float32")


class TestBaseMath:
Expand Down