Skip to content

Commit

Permalink
RDS-705: Fix bugs in actgan when there are no discrete columns
Browse files: browse the repository at this point in the history
GitOrigin-RevId: 2e3c783bc3c620c434f8cadb3a3df5f57dfc363c
  • Loading branch information
kboyd committed Aug 25, 2023
1 parent dc03187 commit c8cfc19
Show file tree
Hide file tree
Showing 2 changed files with 94 additions and 4 deletions.
14 changes: 10 additions & 4 deletions src/gretel_synthetics/actgan/actgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -793,7 +793,9 @@ def _actual_fit(self, train_data: TrainData) -> None:
== ConditionalVectorType.SINGLE_DISCRETE
):
if fake_cond_vec is None:
loss_reconstruction = 0.0
loss_reconstruction = torch.tensor(
0.0, dtype=torch.float32, device=self._device
)
else:
loss_reconstruction = self._cond_loss(
fake, fake_cond_vec, fake_column_mask
Expand Down Expand Up @@ -895,9 +897,13 @@ def sample(
if fixed_cond_vec_torch is None:
# In SINGLE_DISCRETE mode, so we generate a different cond vec
# for every batch to match expected discrete distributions.
cond_vec = torch.from_numpy(
self._condvec_sampler.sample_original_condvec(self._batch_size)
).to(self._device)
cond_vec_numpy = self._condvec_sampler.sample_original_condvec(
self._batch_size
)
if cond_vec_numpy is not None:
cond_vec = torch.from_numpy(cond_vec_numpy).to(self._device)
else:
cond_vec = None
else:
cond_vec = fixed_cond_vec_torch

Expand Down
84 changes: 84 additions & 0 deletions tests/actgan/test_actgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,9 +90,93 @@ def test_actgan_implementation(
conditional_select_mean_columns=conditional_select_mean_columns,
)

# Check training
model.fit(df)

# Check unconditional generation
df_synth = model.sample(num_rows=100)

assert df_synth.shape == (100, len(df.columns))
assert list(df.columns) == list(df_synth.columns)

if conditional_vector_type != ConditionalVectorType.SINGLE_DISCRETE:
# Check conditional generation from numeric column
df_synth = model.sample_remaining_columns(
pd.DataFrame(
{
"int_column": [10] * 10,
}
)
)
assert list(df.columns) == list(df_synth.columns)

# Check conditional generation from discrete column
df_synth = model.sample_remaining_columns(
pd.DataFrame(
{
"categorical_column": ["b"] * 10,
}
)
)
assert list(df.columns) == list(df_synth.columns)


@pytest.mark.parametrize(
    "log_frequency,conditional_vector_type,force_conditioning",
    itertools.product(
        [False, True],
        [
            ConditionalVectorType.SINGLE_DISCRETE,
            ConditionalVectorType.ANYWAY,
        ],
        [False, True],
    ),
)
def test_actgan_implementation_all_numeric(
    log_frequency, conditional_vector_type, force_conditioning
):
    """Smoke-test ACTGAN on a frame that holds only numeric columns.

    Fits a deliberately tiny model (one epoch, small layers) on 100 random
    rows and then samples from it, confirming that neither step crashes,
    i.e., that all the tensor shapes line up. For every conditional vector
    type other than SINGLE_DISCRETE, sampling conditioned on the integer
    column is exercised as well.
    """
    row_count = 100
    # Draw the integer column first, then the float column, so the global
    # NumPy random state is consumed in a fixed order.
    int_values = np.random.randint(0, 200, size=row_count)
    float_values = np.random.random(size=row_count)
    df = pd.DataFrame(
        {
            "int_column": int_values,
            "float_column": float_values,
        }
    )

    def _assert_same_columns(synthetic):
        # Synthetic output must expose the training columns, in order.
        assert list(df.columns) == list(synthetic.columns)

    supports_conditioning = (
        conditional_vector_type != ConditionalVectorType.SINGLE_DISCRETE
    )
    # The mean-column selection is only configured outside SINGLE_DISCRETE.
    conditional_select_mean_columns = 2 if supports_conditioning else None

    model = ACTGAN(
        epochs=1,
        batch_size=20,
        generator_dim=[32, 32],
        discriminator_dim=[32, 32],
        log_frequency=log_frequency,
        conditional_vector_type=conditional_vector_type,
        force_conditioning=force_conditioning,
        conditional_select_mean_columns=conditional_select_mean_columns,
    )

    # Training must complete without raising.
    model.fit(df)

    # Unconditional generation: requested row count and original columns.
    df_synth = model.sample(num_rows=100)
    assert df_synth.shape == (100, len(df.columns))
    _assert_same_columns(df_synth)

    if not supports_conditioning:
        return

    # Conditional generation seeded from the numeric column.
    seed_frame = pd.DataFrame(
        {
            "int_column": [10] * 10,
        }
    )
    df_synth = model.sample_remaining_columns(seed_frame)
    _assert_same_columns(df_synth)

0 comments on commit c8cfc19

Please sign in to comment.