Skip to content

Commit

Permalink
RDS-293: Fix bugs and add tests around batch and data set size.
Browse files Browse the repository at this point in the history
GitOrigin-RevId: ec48898f6e4f17aa9088e10936939f6f44070fb3
  • Loading branch information
kboyd committed Jul 8, 2022
1 parent e29b6c7 commit 19af35c
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 1 deletion.
13 changes: 12 additions & 1 deletion src/gretel_synthetics/timeseries_dgan/dgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -540,12 +540,23 @@ def _train(
Args:
dataset: torch Dataset containing tuple of (attributes, additional_attributes, features)
"""
if len(dataset) <= 1:
raise ValueError(
f"DGAN requires multiple examples to train, received {len(dataset)} example."
+ "Consider splitting a single long sequence into many subsequences to obtain "
+ "multiple examples for training."
)

# Our optimization setup does not work on batches of size 1. So if
# drop_last=False would produce a last batch of size of 1, we use
# drop_last=True instead.
drop_last = len(dataset) % self.config.batch_size == 1

loader = DataLoader(
dataset,
self.config.batch_size,
shuffle=True,
drop_last=False,
drop_last=drop_last,
num_workers=2,
prefetch_factor=4,
persistent_workers=True,
Expand Down
57 changes: 57 additions & 0 deletions tests/timeseries_dgan/test_dgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,29 @@ def test_train_numpy_no_attributes_2(config: DGANConfig):
assert synthetic_features.shape == (n_samples, features.shape[1], features.shape[2])


def test_train_numpy_batch_size_of_1(config: DGANConfig):
    """Model should still train when (# of examples) % batch_size == 1.

    91 examples with batch_size=10 leaves a final batch of exactly one
    example, a case the training setup must handle (by dropping it).
    """
    config.batch_size = 10
    config.epochs = 1

    n_examples = 91
    feature_array = np.random.rand(n_examples, 20, 2)
    attribute_array = np.random.randint(0, 3, (n_examples, 1))

    dgan = DGAN(config=config)
    dgan.train_numpy(
        features=feature_array,
        attributes=attribute_array,
        feature_types=[OutputType.CONTINUOUS, OutputType.CONTINUOUS],
        attribute_types=[OutputType.DISCRETE],
    )

    gen_attributes, gen_features = dgan.generate_numpy(11)
    assert gen_attributes is not None
    assert gen_attributes.shape == (11, 1)
    assert gen_features.shape == (11, 20, 2)


def test_train_dataframe_wide(config: DGANConfig):
n = 50
df = pd.DataFrame(
Expand Down Expand Up @@ -311,11 +334,45 @@ def test_train_dataframe_batch_size_larger_than_dataset(config: DGANConfig):
df_style=DfStyle.WIDE,
)

# We want to confirm the training does update the model params, so we create
# some fixed noise inputs and check if they produce different outputs before
# and after some more training.
attribute_noise = dg.attribute_noise_func(50)
feature_noise = dg.feature_noise_func(50)
before_attributes, before_features = dg.generate_numpy(
attribute_noise=attribute_noise, feature_noise=feature_noise
)

dg.train_dataframe(
df=df,
attribute_columns=["a1", "a2"],
discrete_columns=["a1"],
df_style=DfStyle.WIDE,
)

after_attributes, after_features = dg.generate_numpy(
attribute_noise=attribute_noise, feature_noise=feature_noise
)
# Generated data should be different.
assert np.any(np.not_equal(before_attributes, after_attributes))
assert np.any(np.not_equal(before_features, after_features))

synthetic_df = dg.generate_dataframe(5)
assert synthetic_df.shape == (5, 6)
assert list(synthetic_df.columns) == list(df.columns)


def test_train_1_example(config: DGANConfig, feature_data):
    """Training on a single example should raise a ValueError."""
    all_features, feature_types = feature_data
    # Slice down to just the first example.
    single_example = all_features[0:1, :]

    model = DGAN(config=config)

    with pytest.raises(ValueError, match="multiple examples to train"):
        model.train_numpy(features=single_example, feature_types=feature_types)


def test_train_dataframe_batch_size_not_divisible_by_dataset_length(config: DGANConfig):
n = 1000
df = pd.DataFrame(
Expand Down

0 comments on commit 19af35c

Please sign in to comment.