PROD-353 / RD-195: ACTGAN memory improvements
GitOrigin-RevId: d5690b17fb0baa4fc9ab83a90d6b8faeb6316146
misberner committed Mar 29, 2023
1 parent 17766ab commit 059c669
Showing 12 changed files with 1,299 additions and 493 deletions.
175 changes: 90 additions & 85 deletions src/gretel_synthetics/actgan/actgan.py
@@ -1,17 +1,24 @@
import logging

from typing import Callable, Optional, Sequence
from typing import Callable, List, Optional, Sequence, Tuple

import numpy as np
import pandas as pd
import torch

from gretel_synthetics.actgan.base import BaseSynthesizer, random_state
from gretel_synthetics.actgan.column_encodings import (
BinaryColumnEncoding,
FloatColumnEncoding,
OneHotColumnEncoding,
)
from gretel_synthetics.actgan.data_sampler import DataSampler
from gretel_synthetics.actgan.data_transformer import DataTransformer
from gretel_synthetics.actgan.structures import ActivationFn, EpochInfo
from gretel_synthetics.actgan.structures import ColumnType, EpochInfo
from gretel_synthetics.actgan.train_data import TrainData
from gretel_synthetics.typing import DFLike
from packaging import version
from rdt.transformers.base import BaseTransformer
from torch import optim
from torch.nn import (
BatchNorm1d,
@@ -113,6 +120,32 @@ def forward(self, input_):
return data


def _gumbel_softmax_stabilized(logits, tau=1, hard=False, eps=1e-10, dim=-1):
"""Deals with the instability of the gumbel_softmax for older versions of torch.
For more details about the issue:
https://drive.google.com/file/d/1AA5wPfZ1kquaRtVruCd6BiYZGcDeNxyP/view?usp=sharing
Args:
logits […, num_features]:
Unnormalized log probabilities
tau:
Non-negative scalar temperature
hard (bool):
If True, the returned samples will be discretized as one-hot vectors,
but will be differentiated as if it is the soft sample in autograd
dim (int):
A dimension along which softmax will be computed. Default: -1.
Returns:
Sampled tensor of same shape as logits from the Gumbel-Softmax distribution.
"""
for i in range(10):
transformed = functional.gumbel_softmax(
logits, tau=tau, hard=hard, eps=eps, dim=dim
)
if not torch.isnan(transformed).any():
return transformed
raise ValueError("gumbel_softmax returning NaN.")


class ACTGANSynthesizer(BaseSynthesizer):
"""Anyway Conditional Table GAN Synthesizer.
@@ -229,91 +262,39 @@ def __init__(
self._data_sampler = None
self._generator = None

self._activation_fns: List[
Tuple[int, int, Callable[[torch.Tensor], torch.Tensor]]
] = []
self._cond_loss_col_ranges: List[Tuple[int, int, int, int]] = []

if self._epoch_callback is not None and not callable(self._epoch_callback):
raise ValueError("`epoch_callback` must be a callable or `None`")

@staticmethod
def _gumbel_softmax(logits, tau=1, hard=False, eps=1e-10, dim=-1):
"""Deals with the instability of the gumbel_softmax for older versions of torch.
For more details about the issue:
https://drive.google.com/file/d/1AA5wPfZ1kquaRtVruCd6BiYZGcDeNxyP/view?usp=sharing
Args:
logits […, num_features]:
Unnormalized log probabilities
tau:
Non-negative scalar temperature
hard (bool):
If True, the returned samples will be discretized as one-hot vectors,
but will be differentiated as if it is the soft sample in autograd
dim (int):
A dimension along which softmax will be computed. Default: -1.
Returns:
Sampled tensor of same shape as logits from the Gumbel-Softmax distribution.
"""
if version.parse(torch.__version__) < version.parse("1.2.0"):
for i in range(10):
transformed = functional.gumbel_softmax(
logits, tau=tau, hard=hard, eps=eps, dim=dim
)
if not torch.isnan(transformed).any():
return transformed
raise ValueError("gumbel_softmax returning NaN.")

return functional.gumbel_softmax(logits, tau=tau, hard=hard, eps=eps, dim=dim)
_gumbel_softmax = staticmethod(
functional.gumbel_softmax
if version.parse(torch.__version__) >= version.parse("1.2.0")
else _gumbel_softmax_stabilized
)

def _apply_activate(self, data):
"""Apply proper activation function to the output of the generator."""
data_t = []
st = 0
for column_info in self._transformer.output_info_list:
for span_info in column_info:
if span_info.activation_fn == ActivationFn.TANH:
ed = st + span_info.dim
data_t.append(torch.tanh(data[:, st:ed]))
st = ed
elif span_info.activation_fn == ActivationFn.SIGMOID:
ed = st + span_info.dim
data_t.append(torch.sigmoid(data[:, st:ed]))
st = ed
elif span_info.activation_fn == ActivationFn.SOFTMAX:
ed = st + span_info.dim
transformed = self._gumbel_softmax(data[:, st:ed], tau=0.2)
data_t.append(transformed)
st = ed
else:
raise ValueError(
f"Unexpected activation function {span_info.activation_fn}."
)
data_t = [
activation_fn(data[:, st:ed])
for st, ed, activation_fn in self._activation_fns
]

return torch.cat(data_t, dim=1)

def _cond_loss(self, data, c, m):
"""Compute the cross entropy loss on the fixed discrete column."""
loss = []
st = 0
st_c = 0
for column_info in self._transformer.output_info_list:
for span_info in column_info:
if (
len(column_info) != 1
or span_info.activation_fn != ActivationFn.SOFTMAX
):
# not discrete column
st += span_info.dim
else:
ed = st + span_info.dim
ed_c = st_c + span_info.dim
tmp = functional.cross_entropy(
data[:, st:ed],
torch.argmax(c[:, st_c:ed_c], dim=1),
reduction="none",
)
loss.append(tmp)
st = ed
st_c = ed_c
loss = [
functional.cross_entropy(
data[:, st:ed],
torch.argmax(c[:, st_c:ed_c], dim=1),
reduction="none",
)
for st, ed, st_c, ed_c in self._cond_loss_col_ranges
]

loss = torch.stack(loss, dim=1) # noqa: PD013

@@ -356,7 +337,7 @@ def fit(

def _pre_fit_transform(
self, train_data: DFLike, discrete_columns: Optional[Sequence[str]] = None
) -> np.ndarray:
) -> TrainData:
if discrete_columns is None:
discrete_columns = ()

@@ -370,11 +351,36 @@ def _pre_fit_transform(
)
self._transformer.fit(train_data, discrete_columns)

train_data = self._transformer.transform(train_data)
train_data_dec = self._transformer.transform_decoded(train_data)

return train_data
self._activation_fns = []
self._cond_loss_col_ranges = []

def _actual_fit(self, train_data: DFLike) -> None:
st = 0
st_c = 0
for column_info in train_data_dec.column_infos:
for enc in column_info.encodings:
ed = st + enc.encoded_dim
if isinstance(enc, FloatColumnEncoding):
self._activation_fns.append((st, ed, torch.tanh))
elif isinstance(enc, BinaryColumnEncoding):
self._activation_fns.append((st, ed, torch.sigmoid))
elif isinstance(enc, OneHotColumnEncoding):
self._activation_fns.append(
(st, ed, lambda data: self._gumbel_softmax(data, tau=0.2))
)
if column_info.column_type == ColumnType.DISCRETE:
ed_c = st_c + enc.encoded_dim
self._cond_loss_col_ranges.append((st, ed, st_c, ed_c))
st_c = ed_c
else:
raise ValueError(f"Unexpected column encoding {type(enc)}")

st = ed

return train_data_dec

def _actual_fit(self, train_data: TrainData) -> None:
"""Fit the ACTGAN Synthesizer models to the training data.
Args:
@@ -384,10 +390,11 @@ def _actual_fit(self, train_data: DFLike) -> None:
epochs = self._epochs

self._data_sampler = DataSampler(
train_data, self._transformer.output_info_list, self._log_frequency
train_data,
self._log_frequency,
)

data_dim = self._transformer.output_dimensions
data_dim = train_data.encoded_dim

self._generator = Generator(
self._embedding_dim + self._data_sampler.dim_cond_vec(),
@@ -557,9 +564,7 @@ def sample(
else:
condvec = self._data_sampler.sample_original_condvec(self._batch_size)

if condvec is None:
pass
else:
if condvec is not None:
c1 = condvec
c1 = torch.from_numpy(c1).to(self._device)
fakez = torch.cat([fakez, c1], dim=1)
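The changes above replace per-batch loops over the transformer's column metadata in _apply_activate and _cond_loss with (start, end) spans that are precomputed once in _pre_fit_transform. The following sketch illustrates that pattern outside the ACTGAN classes; the names make_activation_spans and apply_activations are illustrative only and are not part of this commit.

import torch
from torch.nn import functional

def make_activation_spans(encoded_dims, kinds):
    # Precompute (start, end, activation) tuples once, so the per-batch step
    # below only needs to slice the tensor and apply a function per span.
    spans, st = [], 0
    for dim, kind in zip(encoded_dims, kinds):
        ed = st + dim
        fn = {
            "float": torch.tanh,
            "binary": torch.sigmoid,
            "onehot": lambda t: functional.gumbel_softmax(t, tau=0.2),
        }[kind]
        spans.append((st, ed, fn))
        st = ed
    return spans

def apply_activations(data, spans):
    # Mirrors the reworked _apply_activate: one slice and one activation per span.
    return torch.cat([fn(data[:, st:ed]) for st, ed, fn in spans], dim=1)

# Example usage: three encoded columns with widths 1, 1 and 3.
spans = make_activation_spans([1, 1, 3], ["float", "binary", "onehot"])
out = apply_activations(torch.randn(8, 5), spans)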
69 changes: 64 additions & 5 deletions src/gretel_synthetics/actgan/actgan_wrapper.py
@@ -3,20 +3,22 @@

import logging

from contextlib import contextmanager
from typing import Callable, Dict, List, Optional, Sequence, Tuple, TYPE_CHECKING, Union

import numpy as np
import pandas as pd

from gretel_synthetics.actgan.actgan import ACTGANSynthesizer
from gretel_synthetics.actgan.columnar_df import ColumnarDF
from gretel_synthetics.detectors.sdv import SDVTableMetadata
from gretel_synthetics.utils import rdt_patches, torch_utils
from rdt.transformers import BaseTransformer
from sdv.tabular.base import BaseTabularModel

if TYPE_CHECKING:
from gretel_synthetics.actgan.structures import EpochInfo
from numpy.random import RandomState
from rdt.transformers import BaseTransformer
from sdv.constraints import Constraint
from sdv.metadata import Metadata
from torch import Generator
@@ -76,14 +78,36 @@ def fit(self, data: Union[pd.DataFrame, str]) -> None:
self._metadata._field_types = detector.field_types
self._metadata._field_transformers = detector.field_transformers

super().fit(data)

def _fit(self, table_data: pd.DataFrame) -> None:
# Metadata fitting will process the data column by column, dropping original columns
# after adding transformed ones. This is very expensive in a pd.DataFrame. Therefore,
# transform it to our internal, columnar representation. This is a bit dicey, because
# we don't implement the full pd.DataFrame interface, but as long as we don't change
# the RDT dependency, this shouldn't be an issue.
# Note that there is one internal RDT method that doesn't work with a ColumnarDF, which
# we'll have to monkeypatch to a compatible version.
data.reset_index(drop=True, inplace=True)
with _patch_rdt_add_columns_to_data():
super().fit(ColumnarDF.from_df(data))

def _fit(self, table_data: Union[pd.DataFrame, ColumnarDF]) -> None:
"""Fit the model to the table.
Args:
table_data: Data to be learned.
"""
if isinstance(table_data, ColumnarDF):
# The ColumnarDF representation is only relevant for the metadata transform, which
# happens between fit() and _fit(). Therefore, to minimize any required downstream
# code changes, revert to a standard pandas dataframe now (the memory requirements are
# the same when no changes to the structure are made).
columnar_df = table_data
table_data = columnar_df.to_df()
# Python doesn't have move semantics for arguments, and since fit() is an ancestor in
# the call tree, there is no way to let garbage collection take care of the ColumnarDF,
# as the calling context in fit() retains a reference. Therefore, clear the contents
# in-place to reduce memory pressure.
columnar_df.drop(columns=columnar_df.columns, inplace=True)

self._model: ACTGANSynthesizer = self._build_model()

categoricals = []
@@ -96,7 +120,8 @@ def _fit(self, table_data: pd.DataFrame) -> None:

else:
field_data = table_data[field].dropna()
if set(field_data.unique()) == {0.0, 1.0}:
field_data_arr = field_data.to_numpy()
if np.all((field_data_arr == 0.0) | (field_data_arr == 1.0)):
# booleans encoded as float values must be modeled as bool
field_data = field_data.astype(bool)

@@ -366,3 +391,37 @@ def sample(self, *args, **kwargs):
def sample_remaining_columns(self, *args, **kwargs):
with rdt_patches.patch_float_formatter_rounding_bug():
return super().sample_remaining_columns(*args, **kwargs)


@contextmanager
def _patch_rdt_add_columns_to_data():
prev_value = BaseTransformer._add_columns_to_data
try:
BaseTransformer._add_columns_to_data = staticmethod(
_add_columns_to_data_patched
)
yield None
finally:
BaseTransformer._add_columns_to_data = staticmethod(prev_value)


def _add_columns_to_data_patched(data, columns, column_names):
"""Add new columns to a ``pandas.DataFrame`` or ``ColumnarDF``.
This is a patched version of the ``BaseTransformer._add_columns_to_data``
static method, which avoids ``pd.concat`` and thus also works with
a ``ColumnarDF``, as opposed to just a ``pd.DataFrame``, and also preserves
the type of ``data``. It is fully compatible when used with a regular
``pd.DataFrame``, and may even be more efficient in that case.
"""
if columns is not None:
if isinstance(columns, (pd.DataFrame, pd.Series)):
columns.index = data.index

if len(columns.shape) == 1:
data[column_names[0]] = columns
else:
new_data = pd.DataFrame(columns, columns=column_names)
data[column_names] = new_data

return data
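
For context on the patch above: it avoids ``pd.concat`` so that the same code path also works with ColumnarDF, the internal columnar representation referenced in fit() (defined in columnar_df.py, not shown in this excerpt), whose purpose is to make the column-by-column add/drop churn during metadata fitting cheap. Below is a rough sketch of one way such a container could look, assuming a simple dict-of-Series design; the class and method names are illustrative and not the actual ColumnarDF implementation.

import pandas as pd

class DictColumnFrame:
    """Toy columnar container: every column is an independent pd.Series."""

    def __init__(self, columns):
        self._columns = dict(columns)

    @classmethod
    def from_df(cls, df: pd.DataFrame) -> "DictColumnFrame":
        return cls({name: df[name] for name in df.columns})

    def __setitem__(self, name, values) -> None:
        # Adding a column touches only that column; nothing is concatenated or copied.
        self._columns[name] = values

    def drop(self, name) -> None:
        # Dropping a column never rewrites the remaining data.
        del self._columns[name]

    def to_df(self) -> pd.DataFrame:
        # Materialize a regular DataFrame once the column churn is over.
        return pd.DataFrame(self._columns)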
