Skip to content

Commit

Permalink
RDS-594: Resolve apply_along_axis bug in DGAN
Browse files Browse the repository at this point in the history
GitOrigin-RevId: a1384922727b1369b30d04e03225a51c766f1efb
  • Loading branch information
kboyd committed Mar 31, 2023
1 parent 5d387d9 commit 616c5ad
Show file tree
Hide file tree
Showing 2 changed files with 118 additions and 11 deletions.
37 changes: 26 additions & 11 deletions src/gretel_synthetics/timeseries_dgan/dgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,18 +282,28 @@ def train_numpy(
if "ContinuousOutput" in str(val.__class__)
]

valid_examples = validation_check(
features[:, :, continuous_features_ind].astype("float")
)
# Only using valid examples for the entire dataset.
features = features[valid_examples]
# Apply linear interpolations for continuous features:
features[:, :, continuous_features_ind] = nan_linear_interpolation(
features[:, :, continuous_features_ind].astype("float")
)
if continuous_features_ind:
# DGAN does not handle nans in continuous features (though in
# categorical features, the encoding will treat nans as just another
# category). To ensure we have none of these problematic nans, we
# will interpolate to replace nans with actual float values, but if
# we have too many nans in an example interpolation is unreliable.

# Find valid examples based on minimal number of nans.
valid_examples = validation_check(
features[:, :, continuous_features_ind].astype("float")
)

if attributes is not None:
attributes = attributes[valid_examples]
# Only use valid examples for the entire dataset.
features = features[valid_examples]
if attributes is not None:
attributes = attributes[valid_examples]

# Apply linear interpolations to replace nans for continuous
# features:
features[:, :, continuous_features_ind] = nan_linear_interpolation(
features[:, :, continuous_features_ind].astype("float")
)

if self.additional_attribute_outputs:
(
Expand Down Expand Up @@ -424,6 +434,11 @@ def train_dataframe(
logging.warning(
f"The `example_id_column` was not provided, DGAN will autosplit dataset into sequences of size {self.config.max_sequence_len}!" # noqa
)
if len(df) < self.config.max_sequence_len:
raise ValueError(
f"Received {len(df)} rows in long data format, but DGAN requires max_sequence_len={self.config.max_sequence_len} rows to make a training example. Note training will require at least 2 examples." # noqa
)

df = df[
: math.floor(len(df) / self.config.max_sequence_len)
* self.config.max_sequence_len
Expand Down
92 changes: 92 additions & 0 deletions tests/timeseries_dgan/test_dgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -1779,3 +1779,95 @@ def test_save_and_load_dataframe_no_attributes(config: DGANConfig, tmp_path):

assert type(loaded_dg) == DGAN
assert list(synthetic_df.columns) == list(df.columns)


def test_dataframe_long_no_continuous_features(config: DGANConfig):
    """Train on a long-format frame in which every column is discrete.

    There is no explicit assertion: the test passes as long as
    ``train_dataframe`` returns without raising.
    """
    n_rows = 9
    columns = {
        "a": np.random.choice(["foo", "bar", "baz"], size=n_rows),
        "b": np.random.choice(["yes", "no"], size=n_rows),
    }
    long_df = pd.DataFrame(columns)

    # Keep the run tiny: short sequences and a single epoch.
    config.max_sequence_len = 3
    config.sample_len = 1
    config.epochs = 1

    model = DGAN(config=config)
    model.train_dataframe(
        df=long_df,
        df_style=DfStyle.LONG,
        discrete_columns=["a", "b"],
    )


def test_dataframe_wide_no_continuous_features(config: DGANConfig):
    """Train on a wide-format frame in which every column is discrete.

    There is no explicit assertion: the test passes as long as
    ``train_dataframe`` returns without raising.
    """
    day_columns = ["2023-01-01", "2023-01-02", "2023-01-03"]
    n_rows = 6

    # One random yes/no column per day, generated in column order.
    wide_df = pd.DataFrame(
        {
            day: np.random.choice(["yes", "no"], size=n_rows)
            for day in day_columns
        }
    )

    # Keep the run tiny: short sequences and a single epoch.
    config.max_sequence_len = 3
    config.sample_len = 1
    config.epochs = 1

    model = DGAN(config=config)
    model.train_dataframe(
        df=wide_df,
        df_style=DfStyle.WIDE,
        discrete_columns=list(day_columns),
    )


def test_dataframe_long_partial_example(config: DGANConfig):
    """Expect a ``ValueError`` when there are fewer rows than one sequence.

    The frame has 9 rows while ``max_sequence_len`` is 10, so not even a
    single training example can be built from it.
    """
    n_rows = 9
    columns = {
        "a": np.random.choice(["foo", "bar", "baz"], size=n_rows),
        "b": np.random.random(size=n_rows),
    }
    long_df = pd.DataFrame(columns)

    # Sequence length deliberately exceeds the number of available rows.
    config.max_sequence_len = 10
    config.sample_len = 1
    config.epochs = 1

    model = DGAN(config=config)

    expected_error = pytest.raises(
        ValueError, match="requires max_sequence_len"
    )
    with expected_error:
        model.train_dataframe(
            df=long_df,
            df_style=DfStyle.LONG,
            discrete_columns=["a"],
        )


def test_dataframe_long_one_and_partial_example(config: DGANConfig):
    """Expect a ``ValueError`` when the rows cover only one full sequence.

    With 9 rows and ``max_sequence_len`` of 5 there is more than one
    sequence's worth of data but not enough for two examples, which are
    required for training.
    """
    n_rows = 9
    columns = {
        "a": np.random.choice(["foo", "bar", "baz"], size=n_rows),
        "b": np.random.random(size=n_rows),
    }
    long_df = pd.DataFrame(columns)

    # 9 rows at length 5: one complete sequence plus a partial remainder.
    config.max_sequence_len = 5
    config.sample_len = 1
    config.epochs = 1

    model = DGAN(config=config)

    expected_error = pytest.raises(
        ValueError, match="multiple examples to train"
    )
    with expected_error:
        model.train_dataframe(
            df=long_df,
            df_style=DfStyle.LONG,
            discrete_columns=["a"],
        )

0 comments on commit 616c5ad

Please sign in to comment.