Skip to content

Commit

Permalink
RDS 204 - supporting data sets smaller than batch size or batch size …
Browse files Browse the repository at this point in the history
…is not divisible by dataset length

Batch size changes: make DGAN work when len(dataset) % batch_size != 0 or when batch_size > len(dataset).

GitOrigin-RevId: be3be4bfe853ce81e1731ac33d9ce4e3e9bcb425
  • Loading branch information
santhosh97 committed Jun 23, 2022
1 parent 2deea52 commit 57c4cec
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 7 deletions.
12 changes: 5 additions & 7 deletions src/gretel_synthetics/timeseries_dgan/dgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,6 +436,7 @@ def _build(
self.config.feature_num_units,
self.config.feature_num_layers,
)

self.generator.to(self.device, non_blocking=True)

if self.attribute_outputs is None:
Expand Down Expand Up @@ -533,7 +534,7 @@ def _train(
dataset,
self.config.batch_size,
shuffle=True,
drop_last=True,
drop_last=False,
num_workers=2,
prefetch_factor=4,
persistent_workers=True,
Expand Down Expand Up @@ -571,15 +572,12 @@ def _train(

for real_batch in loader:
global_step += 1

with torch.cuda.amp.autocast(
enabled=self.config.mixed_precision_training
):
attribute_noise = self.attribute_noise_func(
real_batch[0].shape[0]
).to(self.device, non_blocking=True)
feature_noise = self.feature_noise_func(real_batch[0].shape[0]).to(
self.device, non_blocking=True
)
attribute_noise = self.attribute_noise_func(real_batch[0].shape[0])
feature_noise = self.feature_noise_func(real_batch[0].shape[0])

# Both real and generated batch are always three element tuple of
# tensors. The tuple is structured as follows: (attribute_output,
Expand Down
59 changes: 59 additions & 0 deletions tests/timeseries_dgan/test_dgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,65 @@ def test_train_dataframe(config: DGANConfig):
assert list(synthetic_df.columns) == list(df.columns)


def test_train_dataframe_batch_size_larger_than_dataset(config: DGANConfig):
    """Train on a DataFrame with fewer rows than config.batch_size.

    Regression test for the drop_last=True behavior: with 50 rows and a
    batch size of 1000, every batch would previously be dropped and
    training would see no data.
    """
    num_rows = 50
    # Two attribute columns followed by a 4-step time series (one column
    # per date). Random values are generated in the same order as before.
    raw_columns = {
        "a1": np.random.randint(0, 3, size=num_rows),
        "a2": np.random.rand(num_rows),
        "2022-01-01": np.random.rand(num_rows),
        "2022-02-01": np.random.rand(num_rows),
        "2022-03-01": np.random.rand(num_rows),
        "2022-04-01": np.random.rand(num_rows),
    }
    df = pd.DataFrame(raw_columns)

    config.max_sequence_len = 4
    config.sample_len = 1
    # Deliberately much larger than the 50-row dataset.
    config.batch_size = 1000

    model = DGAN(config=config)
    model.train_dataframe(
        df=df,
        df_attribute_columns=["a1", "a2"],
        attribute_types=[OutputType.DISCRETE, OutputType.CONTINUOUS],
    )

    synthetic_df = model.generate_dataframe(5)
    # 5 generated examples, 6 columns (2 attributes + 4 time points).
    assert synthetic_df.shape == (5, 6)
    assert list(synthetic_df.columns) == list(df.columns)


def test_train_dataframe_batch_size_not_divisible_by_dataset_length(config: DGANConfig):
    """Train when len(dataset) is not a multiple of config.batch_size.

    With 1000 rows and batch size 300, the final batch holds only 100
    examples; training must handle that short batch rather than drop it.
    """
    num_rows = 1000
    # Same schema as the other batch-size tests: two attributes plus a
    # 4-step time series. Random draws happen in the original order.
    raw_columns = {
        "a1": np.random.randint(0, 3, size=num_rows),
        "a2": np.random.rand(num_rows),
        "2022-01-01": np.random.rand(num_rows),
        "2022-02-01": np.random.rand(num_rows),
        "2022-03-01": np.random.rand(num_rows),
        "2022-04-01": np.random.rand(num_rows),
    }
    df = pd.DataFrame(raw_columns)

    config.max_sequence_len = 4
    config.sample_len = 2
    # 1000 % 300 != 0, so the last batch is a partial one.
    config.batch_size = 300

    model = DGAN(config=config)
    model.train_dataframe(
        df=df,
        df_attribute_columns=["a1", "a2"],
        attribute_types=[OutputType.DISCRETE, OutputType.CONTINUOUS],
    )

    synthetic_df = model.generate_dataframe(5)
    # 5 generated examples, 6 columns (2 attributes + 4 time points).
    assert synthetic_df.shape == (5, 6)
    assert list(synthetic_df.columns) == list(df.columns)


def test_train_dataframe_no_attributes(config: DGANConfig):
n = 50
df = pd.DataFrame(
Expand Down

0 comments on commit 57c4cec

Please sign in to comment.