Skip to content

Commit

Permalink
PROD-409: Fix a newly introduced rounding bug in RDT FloatFormatter
Browse files Browse the repository at this point in the history
GitOrigin-RevId: 700abe81276cdd93c5774c8b690428ca810dee2b
  • Loading branch information
misberner committed Apr 14, 2023
1 parent 90b1490 commit 459a329
Show file tree
Hide file tree
Showing 3 changed files with 879 additions and 1 deletion.
23 changes: 22 additions & 1 deletion src/gretel_synthetics/utils/rdt_patches.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import numpy as np
import pandas as pd

from rdt.transformers.numerical import FloatFormatter, INTEGER_BOUNDS
from rdt.transformers.numerical import FloatFormatter, INTEGER_BOUNDS, MAX_DECIMALS


@contextmanager
Expand All @@ -18,10 +18,15 @@ def patch_float_formatter_rounding_bug():
used.
"""
orig_reverse_transform = FloatFormatter._reverse_transform
orig_learn_rounding_digits = FloatFormatter._learn_rounding_digits
try:
FloatFormatter._reverse_transform = _patched_float_formatter_reverse_transform
FloatFormatter._learn_rounding_digits = staticmethod(
_patched_float_formatter_learn_rounding_digits
)
yield
finally:
FloatFormatter._learn_rounding_digits = staticmethod(orig_learn_rounding_digits)
FloatFormatter._reverse_transform = orig_reverse_transform


Expand Down Expand Up @@ -71,3 +76,19 @@ def _patched_float_formatter_reverse_transform(self, data):
return data

return data.astype(self._dtype)


def _patched_float_formatter_learn_rounding_digits(data):
    """Learn how many decimal digits are needed to represent ``data`` exactly.

    Args:
        data: Array-like of numeric values. Infinite entries and entries
            flagged by ``pd.isna`` (``NaN``, ``None``, ``pd.NA``) are ignored.

    Returns:
        ``0`` if none of the remaining values has a fractional part (this
        includes input with no remaining values at all); otherwise the
        smallest ``decimal`` in ``range(MAX_DECIMALS + 1)`` for which rounding
        to ``decimal`` places leaves every value unchanged; or ``None`` if
        rounding to ``MAX_DECIMALS`` places already alters some value.
    """
    values = np.array(data)
    missing = pd.isna(values)
    if values.dtype == object:
        # np.array yields an object array when numbers are mixed with
        # None/pd.NA, and np.isinf raises TypeError on object arrays. Blank
        # out the missing slots and coerce to float so such input is handled
        # like NaN-bearing float data. Other dtypes are left untouched.
        values = np.where(missing, np.nan, values).astype(float)
    roundable_data = values[~(missing | np.isinf(values))]

    # check if data has any decimals
    if not ((roundable_data % 1) != 0).any():
        # BUGFIX: if the above evaluates to true, that means none of the
        # finite, non-missing input values have any non-zero decimals.
        return 0

    # Fail fast: if MAX_DECIMALS places are not enough, no smaller count is.
    if not (roundable_data == roundable_data.round(MAX_DECIMALS)).all():
        return None

    for decimal in range(MAX_DECIMALS + 1):
        if (roundable_data == roundable_data.round(decimal)).all():
            return decimal

    return None

0 comments on commit 459a329

Please sign in to comment.