Skip to content

Commit

Permalink
RDS-564: Add cbn_sample_size param for ACTGAN model
Browse files Browse the repository at this point in the history
GitOrigin-RevId: 73708653989fb6e62ce04ca8b2daff7f49255115
  • Loading branch information
misberner committed Mar 22, 2023
1 parent 5255284 commit 17766ab
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 0 deletions.
7 changes: 7 additions & 0 deletions src/gretel_synthetics/actgan/actgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,10 @@ class ACTGANSynthesizer(BaseSynthesizer):
Binary encoding currently may produce errant NaN values during reverse transformation. By default
these NaN's will be left in place, however if this value is set to "mode" then those NaN's will
be replaced by a random value that is a known mode for a given column.
cbn_sample_size:
Number of rows to sample from each column for identifying clusters for the cluster-based normalizer.
This only applies to float columns. By default, no sampling is done and all values are considered,
which may be very slow.
log_frequency:
Whether to use log frequency of categorical levels in conditional
sampling. Defaults to ``True``.
Expand Down Expand Up @@ -180,6 +184,7 @@ def __init__(
discriminator_steps: int = 1,
binary_encoder_cutoff: int = 500,
binary_encoder_nan_handler: Optional[str] = None,
cbn_sample_size: Optional[int] = None,
log_frequency: bool = True,
verbose: bool = False,
epochs: int = 300,
Expand All @@ -205,6 +210,7 @@ def __init__(
self._binary_encoder_cutoff = binary_encoder_cutoff
self._binary_encoder_nan_handler = binary_encoder_nan_handler
self._log_frequency = log_frequency
self._cbn_sample_size = cbn_sample_size
self._verbose = verbose
self._epochs = epochs
self._epoch_callback = epoch_callback
Expand Down Expand Up @@ -359,6 +365,7 @@ def _pre_fit_transform(
self._transformer = DataTransformer(
binary_encoder_cutoff=self._binary_encoder_cutoff,
binary_encoder_nan_handler=self._binary_encoder_nan_handler,
cbn_sample_size=self._cbn_sample_size,
verbose=self._verbose,
)
self._transformer.fit(train_data, discrete_columns)
Expand Down
6 changes: 6 additions & 0 deletions src/gretel_synthetics/actgan/actgan_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,10 @@ class ACTGAN(_ACTGANModel):
Binary encoding currently may produce errant NaN values during reverse transformation. By default
these NaN's will be left in place, however if this value is set to "mode" then those NaN's will
be replaced by a random value that is a known mode for a given column.
cbn_sample_size:
Number of rows to sample from each column for identifying clusters for the cluster-based normalizer.
This only applies to float columns. By default, no sampling is done and all values are considered,
which may be very slow.
log_frequency:
Whether to use log frequency of categorical levels in conditional
sampling. Defaults to ``True``.
Expand Down Expand Up @@ -302,6 +306,7 @@ def __init__(
discriminator_steps: int = 1,
binary_encoder_cutoff: int = 500,
binary_encoder_nan_handler: Optional[str] = None,
cbn_sample_size: Optional[int] = None,
log_frequency: bool = True,
verbose: bool = False,
epochs: int = 300,
Expand Down Expand Up @@ -338,6 +343,7 @@ def __init__(
"discriminator_steps": discriminator_steps,
"binary_encoder_cutoff": binary_encoder_cutoff,
"binary_encoder_nan_handler": binary_encoder_nan_handler,
"cbn_sample_size": cbn_sample_size,
"log_frequency": log_frequency,
"verbose": verbose,
"epochs": epochs,
Expand Down
8 changes: 8 additions & 0 deletions src/gretel_synthetics/actgan/data_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,7 @@ class DataTransformer:
_column_transform_info_list: List[ColumnTransformInfo]
_binary_encoder_cutoff: int
_binary_encoder_han_handler: Optional[str]
_cbn_sample_size: Optional[int]
_verbose: bool
dataframe: bool
output_dimensions: int
Expand All @@ -277,6 +278,7 @@ def __init__(
weight_threshold: float = 0.005,
binary_encoder_cutoff: int = OHE_CUTOFF,
binary_encoder_nan_handler: Optional[str] = None,
cbn_sample_size: Optional[int] = None,
verbose: bool = False,
):
"""Create a data transformer.
Expand All @@ -291,16 +293,22 @@ def __init__(
binary_encoder_nan_handler:
If NaN's are produced from the binary encoding reverse transform, this drives how to replace those
NaN's with actual values
cbn_sample_size:
How many rows to sample for identifying clusters in float columns. None means no sampling.
verbose: Provide detailed logging on data transformation details.
"""
self._max_clusters = max_clusters
self._weight_threshold = weight_threshold
self._binary_encoder_cutoff = binary_encoder_cutoff
self._binary_encoder_han_handler = binary_encoder_nan_handler
self._cbn_sample_size = cbn_sample_size
self._verbose = verbose

def _fit_continuous(self, data: pd.DataFrame) -> ColumnTransformInfo:
"""Train Bayesian GMM for continuous columns."""
if self._cbn_sample_size is not None and self._cbn_sample_size < len(data):
# Train on only a sample of the data, if requested.
data = data.sample(n=self._cbn_sample_size)
column_name = data.columns[0]
gm = ClusterBasedNormalizer(
model_missing_values=True, max_clusters=min(len(data), 10)
Expand Down

0 comments on commit 17766ab

Please sign in to comment.