Skip to content

Commit

Permalink
PROD-390: Fix RDT FloatFormatter bug
Browse files Browse the repository at this point in the history
GitOrigin-RevId: 10788f00884bc0bf2f25fcd52ccd9885602a8fed
  • Loading branch information
misberner committed Mar 21, 2023
1 parent be1e267 commit faab8d1
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 1 deletion.
17 changes: 16 additions & 1 deletion src/gretel_synthetics/actgan/actgan_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from gretel_synthetics.actgan.actgan import ACTGANSynthesizer
from gretel_synthetics.detectors.sdv import SDVTableMetadata
from gretel_synthetics.utils import torch_utils
from gretel_synthetics.utils import rdt_patches, torch_utils
from sdv.tabular.base import BaseTabularModel

if TYPE_CHECKING:
Expand Down Expand Up @@ -345,3 +345,18 @@ def __init__(
"pac": pac,
"cuda": cuda,
}

def fit(self, *args, **kwargs):
# Float formatter should not affect anything during model fitting, but it's
# hard to know with certainty what exactly is going on under the hood. Therefore,
# take a conservative approach.
with rdt_patches.patch_float_formatter_rounding_bug():
return super().fit(*args, **kwargs)

def sample(self, *args, **kwargs):
with rdt_patches.patch_float_formatter_rounding_bug():
return super().sample(*args, **kwargs)

def sample_remaining_columns(self, *args, **kwargs):
with rdt_patches.patch_float_formatter_rounding_bug():
return super().sample_remaining_columns(*args, **kwargs)
73 changes: 73 additions & 0 deletions src/gretel_synthetics/utils/rdt_patches.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
from __future__ import annotations

from contextlib import contextmanager

import numpy as np
import pandas as pd

from rdt.transformers.numerical import FloatFormatter, INTEGER_BOUNDS


@contextmanager
def patch_float_formatter_rounding_bug():
"""Returns a contextmanager object that temporarily patches FloatFormatter.
A bug in RDT's FloatFormatter in versions <=1.2.1 might result in floating-point
numbers being rounded to integers. Enclose code using FloatFormatter in a `with`
block with this object to ensure a patched version not suffering from the bug is
used.
"""
orig_reverse_transform = FloatFormatter._reverse_transform
try:
FloatFormatter._reverse_transform = _patched_float_formatter_reverse_transform
yield
finally:
FloatFormatter._reverse_transform = orig_reverse_transform


# The below function is mostly copied from
# https://github.com/sdv-dev/RDT/blob/v1.2.1/rdt/transformers/numerical.py#L188
# which is MIT-licensed, fixing a bug as detailed below.
def _patched_float_formatter_reverse_transform(self, data):
"""Convert data back into the original format.
Args:
data (pd.Series or numpy.ndarray):
Data to transform.
Returns:
numpy.ndarray
"""

if not isinstance(data, np.ndarray):
data = data.to_numpy()

if self.missing_value_replacement is not None:
data = self.null_transformer.reverse_transform(data)

if self.enforce_min_max_values:
data = data.clip(self._min_value, self._max_value)
elif self.computer_representation != "Float":
min_bound, max_bound = INTEGER_BOUNDS[self.computer_representation]
data = data.clip(min_bound, max_bound)

is_integer = np.dtype(self._dtype).kind == "i"
# BUGFIX: Instead of checking for self._learn_rounding_scheme, check if
# self._rounding_digits is not None. This implies self._learn_rounding_scheme,
# but self._rounding_digits MAY actually be None if the data cannot be rounded
# to any number of decimal digits (consider, e.g., that 0.9... and 0.1.... use
# a different exponent in the IEEE754 representation and thus have different
# numbers of bits available for decimal places). The idea that there may be
# a "maximum" number of decimal digits that suffices is a pure heuristic that
# only works for some types of input data (basically, when all values are in the
# range [1.0, 2.0) ).
if self._rounding_digits is not None:
data = data.round(self._rounding_digits)
elif is_integer:
data = data.round(0)
# END BUGFIX

if pd.isna(data).any() and is_integer:
return data

return data.astype(self._dtype)

0 comments on commit faab8d1

Please sign in to comment.