Skip to content

Commit

Permalink
Add test and fix for shape mismatch with binary encoded columns in AC…
Browse files Browse the repository at this point in the history
…TGAN anyway

GitOrigin-RevId: 06a104561b10046eeb1abd734925441ccf5cf3af
  • Loading branch information
kboyd committed Sep 19, 2023
1 parent faeeeae commit 58d894c
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 3 deletions.
8 changes: 7 additions & 1 deletion src/gretel_synthetics/actgan/actgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,11 +467,17 @@ def _column_loss(
# Assumes tanh is only used for 1 column at a time, ie data.shape[1]==1
return functional.mse_loss(data_act, cond_vec, reduction="none").flatten()
elif activation_fn == torch.sigmoid:
return functional.binary_cross_entropy_with_logits(
bce = functional.binary_cross_entropy_with_logits(
data,
cond_vec,
reduction="none",
)
# bce is computed for each representation column, so shape is
# [batch_size, # of bits needed to represent unique values]. All
# other losses in this function return a 1-d tensor of shape
# [batch_size], so we take the mean loss across the representions
# columns to convert from [batch_size, k] to [batch_size] shape.
return bce.mean(dim=1)
else:
return functional.cross_entropy(
data, torch.argmax(cond_vec, dim=1), reduction="none"
Expand Down
9 changes: 7 additions & 2 deletions tests/actgan/test_actgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,18 +47,22 @@ def test_binary_encoder_cutoff(test_df):


@pytest.mark.parametrize(
"log_frequency,conditional_vector_type,force_conditioning",
"log_frequency,conditional_vector_type,force_conditioning,binary_encoder_cutoff",
itertools.product(
[False, True],
[
ConditionalVectorType.SINGLE_DISCRETE,
ConditionalVectorType.ANYWAY,
],
[False, True],
[1, 500],
),
)
def test_actgan_implementation(
log_frequency, conditional_vector_type, force_conditioning
log_frequency,
conditional_vector_type,
force_conditioning,
binary_encoder_cutoff,
):
# Test basic actgan setup with various parameters and to confirm training
# and synthesize does not crash, i.e., all the tensor shapes match. Use a
Expand Down Expand Up @@ -88,6 +92,7 @@ def test_actgan_implementation(
conditional_vector_type=conditional_vector_type,
force_conditioning=force_conditioning,
conditional_select_mean_columns=conditional_select_mean_columns,
binary_encoder_cutoff=binary_encoder_cutoff,
)

# Check training
Expand Down

0 comments on commit 58d894c

Please sign in to comment.