Skip to content

Commit

Permalink
fix(python): address several edge-cases found when asserting NaN equality (pola-rs#5732)
Browse files Browse the repository at this point in the history
  • Loading branch information
alexander-beedie authored and chitralverma committed Dec 10, 2022
1 parent eac3345 commit f462a20
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 12 deletions.
27 changes: 22 additions & 5 deletions py-polars/polars/testing/asserts.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,9 +167,14 @@ def _assert_series_inner(
if left.dtype != right.dtype:
raise_assert_detail(obj, "Dtype mismatch", left.dtype, right.dtype)

# confirm that we can call 'is_nan' on both sides
left_is_float = left.dtype in (Float32, Float64)
right_is_float = right.dtype in (Float32, Float64)
comparing_float_dtypes = left_is_float and right_is_float

# create mask of which (if any) values are unequal
unequal = left != right
if unequal.any() and nans_compare_equal and left.dtype in (Float32, Float64):
if unequal.any() and nans_compare_equal and comparing_float_dtypes:
# handle NaN values (which compare unequal to themselves)
unequal = unequal & ~(
(left.is_nan() & right.is_nan()).fill_null(pli.lit(False))
Expand All @@ -182,13 +187,25 @@ def _assert_series_inner(
obj, "Exact value mismatch", left=list(left), right=list(right)
)
else:
# apply check with tolerance, but only to the known-unequal matches
# apply check with tolerance (to the known-unequal matches).
left, right = left.filter(unequal), right.filter(unequal)
mismatch, nan_info = False, ""
if (((left - right).abs() > (atol + rtol * right.abs())).sum() != 0) or (
(left.is_null() != right.is_null()).any()
):
left.is_null() != right.is_null()
).any():
mismatch = True
elif comparing_float_dtypes:
# note: take special care with NaN values.
if not nans_compare_equal and (left.is_nan() == right.is_nan()).any():
nan_info = " (nans_compare_equal=False)"
mismatch = True
elif (left.is_nan() != right.is_nan()).any():
nan_info = f" (nans_compare_equal={nans_compare_equal})"
mismatch = True

if mismatch:
raise_assert_detail(
obj, "Value mismatch", left=list(left), right=list(right)
obj, f"Value mismatch{nan_info}", left=list(left), right=list(right)
)


Expand Down
46 changes: 39 additions & 7 deletions py-polars/tests/unit/test_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,49 @@ def test_compare_series_empty_equal() -> None:


def test_compare_series_nans_assert_equal() -> None:
# NaN values do not _compare_ equal, but should _assert_ as equal here
# note: NaN values do not _compare_ equal, but should _assert_ equal (by default)
nan = float("NaN")

srs1 = pl.Series([1.0, 2.0, nan])
srs2 = pl.Series([1.0, 2.0, nan])
assert_series_equal(srs1, srs2)
srs1 = pl.Series([1.0, 2.0, nan, 4.0, None, 6.0])
srs2 = pl.Series([1.0, nan, 3.0, 4.0, None, 6.0])
srs3 = pl.Series([1.0, 2.0, 3.0, 4.0, None, 6.0])

for srs in (srs1, srs2, srs3):
assert_series_equal(srs, srs)
assert_series_equal(srs, srs, check_exact=True)

srs1 = pl.Series([1.0, 2.0, nan])
srs2 = pl.Series([1.0, nan, 3.0])
with pytest.raises(AssertionError):
assert_series_equal(srs1, srs2, check_exact=True)
assert_series_equal(srs1, srs1, nans_compare_equal=False)
with pytest.raises(AssertionError):
assert_series_equal(srs1, srs1, nans_compare_equal=False, check_exact=True)

for check_exact, nans_equal in (
(False, False),
(False, True),
(True, False),
(True, True),
):
if check_exact:
check_msg = "Exact value mismatch"
else:
check_msg = f"Value mismatch.*nans_compare_equal={nans_equal}"

with pytest.raises(AssertionError, match=check_msg):
assert_series_equal(
srs1, srs2, check_exact=check_exact, nans_compare_equal=nans_equal
)
with pytest.raises(AssertionError, match=check_msg):
assert_series_equal(
srs1, srs3, check_exact=check_exact, nans_compare_equal=nans_equal
)

srs4 = pl.Series([1.0, 2.0, 3.0, 4.0, None, 6.0])
srs5 = pl.Series([1.0, 2.0, 3.0, 4.0, nan, 6.0])
srs6 = pl.Series([1, 2, 3, 4, None, 6])

assert_series_equal(srs4, srs6, check_dtype=False)
with pytest.raises(AssertionError):
assert_series_equal(srs5, srs6, check_dtype=False)


def test_compare_series_nulls() -> None:
Expand Down

0 comments on commit f462a20

Please sign in to comment.