
Commit

Actgan sal prep
* Prep for ACTGAN SAL release

GitOrigin-RevId: 18a0b574982197ec87aa45f155e60c1bceb188c6
johntmyers committed Dec 20, 2022
1 parent c678d70 commit 553c1cf
Showing 10 changed files with 384 additions and 162 deletions.
1 change: 1 addition & 0 deletions docs/index.rst
@@ -20,6 +20,7 @@ Modules
   api/batch.rst
   utils/index.rst
   models/timeseries_dgan.rst
+  models/actgan.rst


Indices and tables
15 changes: 15 additions & 0 deletions docs/models/actgan.rst
@@ -0,0 +1,15 @@
ACTGAN
======

The ACTGAN sub-package contains an alternate implementation of the SDV CTGAN model. It
adds automatic detection of datetime fields and optional use of a binary encoder
(instead of one-hot encoding) for high-cardinality discrete columns, which reduces memory usage.

Please see the "ACTGAN_Demo" Notebook in the "examples" directory in the repository root.


.. automodule:: gretel_synthetics.actgan.actgan_wrapper
:members:

.. automodule:: gretel_synthetics.actgan.structures
:members:
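
For orientation, here is a minimal usage sketch distilled from the demo notebook added in this commit. The CSV path is a placeholder, and the keyword values simply mirror the notebook rather than being recommendations:

```python
import pandas as pd

from gretel_synthetics.actgan import ACTGAN

# Placeholder path; the demo notebook uses a public 311 call center dataset.
train_df = pd.read_csv("your_training_data.csv")

model = ACTGAN(
    verbose=True,
    binary_encoder_cutoff=10,       # cardinality threshold for switching to binary encoding
    auto_transform_datetimes=True,  # detect datetime columns, configure UnixTimestampEncoder
    epochs=100,
)
model.fit(train_df)
syn_df = model.sample(100)
```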
141 changes: 141 additions & 0 deletions examples/ACTGAN_Demo.ipynb
@@ -0,0 +1,141 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "7e237ea2",
"metadata": {},
"source": [
"# ACTGAN\n",
"\n",
"This Notebook provides an overview of how to use ACTGAN. It is compatable with SDV CTGAN from version 0.17.X of SDV. The notable changes are exposed through additional keyword parameters when creating the `ACTGAN` instance. Specifically:\n",
"\n",
"- Binary encoding usage. CTGAN uses One Hot Encoding for discrete/categorical columns which can lead to memory issues depending on the cardinality of these columns. You may now specify a cardinality cutoff that will trigger the switch to using a binary encoder, which saves significant memory usage.\n",
"\n",
"\n",
"- Auto datetime detection. When enabled, each column will be scanned for potential DateTime values. The strfmt of each column will be determined and the underlying SDV Table Metadata will be automatically configured to use a `UnixTimestampEncoder` for these columns. This will give better variability during data sampling and prevent DateTime\n",
"columns from being treated as categorical.\n",
"\n",
"- Empty field detection. Any columns that are empty (or all NaN) will be transformed for fitting and reverse transformed to being empty during sampling. Empty columns can cause training execptions otherwise.\n",
"\n",
"\n",
"- Epoch callback. Optionally allow the passing of an `EpochInfo` object to any callable when a training epoch completes."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2347c31b",
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"\n",
"from gretel_synthetics.actgan import ACTGAN"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8d0676a7",
"metadata": {},
"outputs": [],
"source": [
"train_df = pd.read_csv(\"http://gretel-public-website.s3-website-us-west-2.amazonaws.com/datasets/311_call_center_10k.csv\")\n",
"train_df.head()\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "05d2208a",
"metadata": {},
"outputs": [],
"source": [
"class EpochTracker:\n",
" \"\"\"\n",
" Simple example that just accumulates ``EpochInfo`` events,\n",
" but demonstrates how you can route epoch information to\n",
" arbitrary callables during model fitting.\n",
" \"\"\"\n",
" \n",
" def __init__(self):\n",
" self.epochs = []\n",
" \n",
" def add(self, epoch_data):\n",
" self.epochs.append(epoch_data)\n",
" \n",
"epoch_tracker = EpochTracker()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7f127e44",
"metadata": {},
"outputs": [],
"source": [
"model = ACTGAN(\n",
" verbose=True,\n",
" binary_encoder_cutoff=10, # use a binary encoder for data transforms if the cardinality of a column is below this value\n",
" auto_transform_datetimes=True,\n",
" epochs=100,\n",
" epoch_callback=epoch_tracker.add\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8b8ef7f0",
"metadata": {},
"outputs": [],
"source": [
"model.fit(train_df)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "42618395",
"metadata": {},
"outputs": [],
"source": [
"# Tracked and stored epoch information\n",
"\n",
"epoch_tracker.epochs[42]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9e5acde6",
"metadata": {},
"outputs": [],
"source": [
"syn_df = model.sample(100)\n",
"syn_df.head()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.10"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
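
The notebook only samples unconditionally, but the `sample` signature in the actgan.py diff below also accepts `condition_column` and `condition_value` to raise the probability of a chosen category. A hedged sketch, assuming the `ACTGAN` wrapper forwards these keywords to `ACTGANSynthesizer.sample` unchanged; the column and category names are hypothetical:

```python
# Hypothetical column/value for the 311 dataset; assumes the wrapper
# forwards condition_column / condition_value to the synthesizer's sample().
conditioned_df = model.sample(
    100,
    condition_column="complaint_type",
    condition_value="Noise",
)
conditioned_df.head()
```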
1 change: 1 addition & 0 deletions requirements.txt
@@ -5,6 +5,7 @@ pandas>=1.1.0
sentencepiece==0.1.97
smart_open>=2.1.0,<6.0
tensorflow==2.8.0
+sdv<0.18.0
tensorflow_estimator==2.8
tensorflow_privacy==0.7.3
tensorflow_probability==0.16.0
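The new `sdv<0.18.0` pin matches the compatibility note in the demo notebook above: ACTGAN tracks the SDV CTGAN interface as of the 0.17.x series, so later SDV releases, which presumably change that interface, are excluded.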
108 changes: 47 additions & 61 deletions src/gretel_synthetics/actgan/actgan.py
@@ -1,7 +1,4 @@
"""ACTGANSynthesizer module."""

import logging
import warnings

from typing import Callable, Optional, Sequence

@@ -119,54 +116,52 @@ def forward(self, input_):
class ACTGANSynthesizer(BaseSynthesizer):
    """Anyway Conditional Table GAN Synthesizer.

-    This is the core class of the ACTGAN project
-    For more details about the process, please check the [Modeling Tabular data using
-    Conditional GAN](https://arxiv.org/abs/1907.00503) paper and our blogs
+    This is the core class of the ACTGAN interface.

    Args:
-        embedding_dim (int):
+        embedding_dim:
            Size of the random sample passed to the Generator. Defaults to 128.
-        generator_dim (tuple or list of ints):
+        generator_dim:
            Size of the output samples for each one of the Residuals. A Residual Layer
            will be created for each one of the values provided. Defaults to (256, 256).
-        discriminator_dim (tuple or list of ints):
+        discriminator_dim:
            Size of the output samples for each one of the Discriminator Layers. A Linear Layer
            will be created for each one of the values provided. Defaults to (256, 256).
-        generator_lr (float):
+        generator_lr:
            Learning rate for the generator. Defaults to 2e-4.
-        generator_decay (float):
+        generator_decay:
            Generator weight decay for the Adam Optimizer. Defaults to 1e-6.
-        discriminator_lr (float):
+        discriminator_lr:
            Learning rate for the discriminator. Defaults to 2e-4.
-        discriminator_decay (float):
+        discriminator_decay:
            Discriminator weight decay for the Adam Optimizer. Defaults to 1e-6.
-        batch_size (int):
+        batch_size:
            Number of data samples to process in each step.
-        discriminator_steps (int):
+        discriminator_steps:
            Number of discriminator updates to do for each generator update.
            From the WGAN paper: https://arxiv.org/abs/1701.07875. WGAN paper
            default is 5. Default used is 1 to match original CTGAN implementation.
-        binary_encoder_cutoff (int):
+        binary_encoder_cutoff:
            For any given column, the number of unique values that should exist before
            switching over to binary encoding instead of OHE. This will help reduce
            memory consumption for datasets with a lot of unique values.
-        binary_encoder_nan_handler: (str):
+        binary_encoder_nan_handler:
            Binary encoding currently may produce errant NaN values during reverse transformation. By default
            these NaN's will be left in place, however if this value is set to "mode" then those NaN's will
            be replaced by a random value that is a known mode for a given column.
-        log_frequency (boolean):
+        log_frequency:
            Whether to use log frequency of categorical levels in conditional
            sampling. Defaults to ``True``.
-        verbose (boolean):
+        verbose:
            Whether to have log progress results. Defaults to ``False``.
-        epochs (int):
+        epochs:
            Number of training epochs. Defaults to 300.
-        epoch_callback (callable, optional):
+        epoch_callback:
            If set to a callable, call the function with `EpochInfo` as the arg
-        pac (int):
+        pac:
            Number of samples to group together when applying the discriminator.
            Defaults to 10.
-        cuda (bool):
+        cuda:
            Whether to attempt to use cuda for GPU computation.
            If this is False or CUDA is not available, CPU will be used.
            Defaults to ``True``.
@@ -324,9 +319,8 @@ def _validate_discrete_columns(
"""Check whether ``discrete_columns`` exists in ``train_data``.
Args:
train_data (numpy.ndarray or pandas.DataFrame):
Training Data. It must be a 2-dimensional numpy array or a pandas.DataFrame.
discrete_columns (list-like):
train_data: Training Data. It must be a 2-dimensional numpy array or a pandas.DataFrame.
discrete_columns:
List of discrete columns to be used to generate the Conditional
Vector. If ``train_data`` is a Numpy array, this list should
contain the integer indices of the columns. Otherwise, if it is
@@ -347,46 +341,41 @@

    @random_state
    def fit(
-        self,
-        train_data: DFLike,
-        discrete_columns: Optional[Sequence[str]] = None,
-        epochs: Optional[int] = None,
+        self, train_data: DFLike, discrete_columns: Optional[Sequence[str]] = None
    ) -> None:
-        """Fit the ACTGAN Synthesizer models to the training data.
+        """Fit the ACTGAN Synthesizer models to the training data."""
+        transformed_train_data = self._pre_fit_transform(
+            train_data, discrete_columns=discrete_columns
+        )
+        self._actual_fit(transformed_train_data)

-        Args:
-            train_data (numpy.ndarray or pandas.DataFrame):
-                Training Data. It must be a 2-dimensional numpy array or a pandas.DataFrame.
-            discrete_columns (list-like):
-                List of discrete columns to be used to generate the Conditional
-                Vector. If ``train_data`` is a Numpy array, this list should
-                contain the integer indices of the columns. Otherwise, if it is
-                a ``pandas.DataFrame``, this list should contain the column names.
-        """
+    def _pre_fit_transform(
+        self, train_data: DFLike, discrete_columns: Optional[Sequence[str]] = None
+    ) -> np.ndarray:
        if discrete_columns is None:
            discrete_columns = ()

        self._validate_discrete_columns(train_data, discrete_columns)

-        if epochs is None:
-            epochs = self._epochs
-        else:
-            warnings.warn(
-                (
-                    "`epochs` argument in `fit` method has been deprecated and will be removed "
-                    "in a future version. Please pass `epochs` to the constructor instead"
-                ),
-                DeprecationWarning,
-            )

        self._transformer = DataTransformer(
            binary_encoder_cutoff=self._binary_encoder_cutoff,
            binary_encoder_nan_handler=self._binary_encoder_nan_handler,
            verbose=self._verbose,
        )
        self._transformer.fit(train_data, discrete_columns)

        train_data = self._transformer.transform(train_data)

+        return train_data
+
+    def _actual_fit(self, train_data: DFLike) -> None:
+        """Fit the ACTGAN Synthesizer models to the training data.
+
+        Args:
+            train_data: Training Data. It must be a 2-dimensional numpy array or a pandas.DataFrame.
+        """
+        epochs = self._epochs

        self._data_sampler = DataSampler(
            train_data, self._transformer.output_info_list, self._log_frequency
        )
@@ -424,7 +413,7 @@ def fit(

        steps_per_epoch = max(len(train_data) // self._batch_size, 1)
        for i in range(epochs):
-            for id_ in range(steps_per_epoch):
+            for _ in range(steps_per_epoch):

                for n in range(self._discriminator_steps):
                    fakez = torch.normal(mean=mean, std=std)
@@ -522,19 +511,16 @@ def sample(
        n: int,
        condition_column: Optional[str] = None,
        condition_value: Optional[str] = None,
-    ):
+    ) -> pd.DataFrame:
        """Sample data similar to the training data.

        Choosing a condition_column and condition_value will increase the probability of the
        discrete condition_value happening in the condition_column.

        Args:
-            n (int):
-                Number of rows to sample.
-            condition_column (string):
-                Name of a discrete column.
-            condition_value (string):
-                Name of the category in the condition_column which we wish to increase the
+            n: Number of rows to sample.
+            condition_column: Name of a discrete column.
+            condition_value: Name of the category in the condition_column which we wish to increase the
                probability of happening.

        Returns:
@@ -554,7 +540,7 @@ def sample(

        steps = n // self._batch_size + 1
        data = []
-        for i in range(steps):
+        for _ in range(steps):
            mean = torch.zeros(self._batch_size, self._embedding_dim)
            std = mean + 1
            fakez = torch.normal(mean=mean, std=std).to(self._device)
@@ -581,7 +567,7 @@ def sample(
        transformed_data = self._transformer.inverse_transform(data)
        return transformed_data

-    def set_device(self, device):
+    def set_device(self, device: str) -> None:
        """Set the `device` to be used ('GPU' or 'CPU)."""
        self._device = device
        if self._generator is not None:
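
To make the `binary_encoder_cutoff` trade-off described in the docstring above concrete: one-hot encoding produces one output column per category, while binary encoding needs only about log2(cardinality) columns. A minimal illustrative sketch of such a selection rule follows; it is not the library's internal code, and the exact boundary comparison is an assumption:

```python
import math

def choose_encoding(cardinality: int, cutoff: int) -> tuple:
    """Illustrative only: report each encoding's width and which one a
    binary_encoder_cutoff-style rule would pick for a column."""
    one_hot_width = cardinality  # one output column per category
    binary_width = max(1, math.ceil(math.log2(cardinality)))
    # Assumption: binary encoding kicks in once cardinality reaches the cutoff.
    chosen = "binary" if cardinality >= cutoff else "one_hot"
    return chosen, one_hot_width, binary_width

print(choose_encoding(8, 150))     # ('one_hot', 8, 3)
print(choose_encoding(5000, 150))  # ('binary', 5000, 13)
```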
