Skip to content

Commit

Permalink
RDS-494: Fix nan bug with object array comparison.
Browse files Browse the repository at this point in the history
GitOrigin-RevId: e6df88e5773493eab5b8ecfc6ff5b00444806f70
  • Loading branch information
kboyd committed Dec 1, 2022
1 parent 91f2214 commit ffec93b
Showing 1 changed file with 17 additions and 16 deletions.
33 changes: 17 additions & 16 deletions tests/timeseries_dgan/test_transformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import numpy as np
import pytest
import torch

from gretel_synthetics.timeseries_dgan.config import Normalization
from gretel_synthetics.timeseries_dgan.transformations import (
Expand Down Expand Up @@ -38,24 +39,16 @@ def assert_array_equal(a: np.array, b: np.array):
# work on object arrays.
test_a = a.flatten()
test_b = b.flatten()
test_a_nan_mask = [x is np.nan for x in test_a]
test_b_nan_mask = [x is np.nan for x in test_b]
test_a_nan_mask = [isinstance(x, float) and np.isnan(x) for x in test_a]
test_b_nan_mask = [isinstance(x, float) and np.isnan(x) for x in test_b]

test_a[test_a_nan_mask] = replace_value
test_b[test_b_nan_mask] = replace_value

# Now compare 2 arrays that should not have any nans
np.testing.assert_array_equal(
test_a, test_b
), f"original arrays:\n{a}\nand\n{b}"

# Alternative using direct list comprehension for comparison
# assert np.all(
# [
# y is np.NaN if x is np.NaN else x == y
# for x, y in zip(a.flatten(), b.flatten())
# ]
# ), f"{a} == {b}"
test_a, test_b, err_msg=f"original arrays:\n{a}\nand\n{b}"
)
else:
np.testing.assert_array_equal(a, b)

Expand All @@ -69,6 +62,18 @@ def test_custom_assert_array_equal():
np.array(["a", "b", np.NaN], dtype="O"), np.array(["a", "b", np.NaN], dtype="O")
)

# Check for different ways of creating nan
assert_array_equal(
np.array(["a", 1.0, np.NaN, float("nan"), np.Inf / np.Inf], dtype="O"),
np.array(["a", 1.0, np.NaN, np.NaN, np.NaN], dtype="O"),
)

# Check for nans coming from torch versus numpy nans
a = torch.Tensor([1.0, np.NaN, float("nan"), np.Inf / np.Inf]).numpy().astype("O")
a[0] = "a"
b = np.array(["a", np.NaN, np.NaN, np.NaN], dtype="O")
assert_array_equal(a, b)

with pytest.raises(AssertionError):
assert_array_equal(np.array([0, 1, 3]), np.array([0, 1, 2]))

Expand Down Expand Up @@ -142,10 +147,6 @@ def test_one_hot_encoded_output_string():
assert_array_equal(expected2, output.inverse_transform(transformed2))


@pytest.mark.skipif(
sys.version_info < (3, 9),
reason="unknown bug makes nan comparisons for this test not work on python3.8",
)
def test_one_hot_encoded_output_nans():
output = OneHotEncodedOutput(name="foo")

Expand Down

0 comments on commit ffec93b

Please sign in to comment.