Skip to content

Commit

Permalink
Add tests and fix bug with dgan and float example id columns
Browse files Browse the repository at this point in the history
GitOrigin-RevId: 9a998db3026a7e6ea8f20ee719531711163fed54
  • Loading branch information
kboyd committed Jan 27, 2023
1 parent b945989 commit 0e72ab1
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 3 deletions.
11 changes: 8 additions & 3 deletions src/gretel_synthetics/timeseries_dgan/dgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -1338,9 +1338,14 @@ def create(
# Assume all examples are for the same time points, e.g., always
# from 2020 even if df has examples from different years.
df_time_example = df[[time_column, example_id_column]]
time_values = df_time_example.groupby(example_id_column).apply(
pd.DataFrame.to_numpy
)[0][:, 0]
# Use first example grouping (iloc[0]), then grab the time
# column values used by that example from the numpy array
# ([:,0]).
time_values = (
df_time_example.groupby(example_id_column)
.apply(pd.DataFrame.to_numpy)
.iloc[0][:, 0]
)

time_column_values = list(sorted(time_values))
else:
Expand Down
52 changes: 52 additions & 0 deletions tests/timeseries_dgan/test_dgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -892,6 +892,29 @@ def test_train_dataframe_long_attribute_mismatch_nans(config: DGANConfig):
)


def test_train_dataframe_long_float_example_id(config: DGANConfig):
# Reproduce error from production where example_id_column is float and
# there's no 0.0 value. Should train with no errors.
n = 50
df = pd.DataFrame(
{
"example_id": np.repeat(np.arange(10.0, 15.0, 0.5), repeats=5),
"time": [str(x) for x in pd.date_range("2022-01-01", periods=n)],
"f": np.random.rand(n),
}
)

config.max_sequence_len = 5
dg = DGAN(config=config)

dg.train_dataframe(
df,
example_id_column="example_id",
time_column="time",
df_style=DfStyle.LONG,
)


def test_train_numpy_with_strings(config: DGANConfig):
n = 50
features = np.stack(
Expand Down Expand Up @@ -1460,6 +1483,35 @@ def test_long_data_frame_converter_example_id_object(df_long):
assert features.dtype == "float64"


def test_long_data_frame_converter_example_id_float():
# Check converter creation with a float example id column that has no values
# of 0.0.

df_long = pd.DataFrame(
{
"example_id": [1.0, 1.0, 2.0, 2.0],
"time": [
"2022-01-01",
"2022-01-02",
"2022-01-03",
"2022-01-04",
],
"f": [2.0, 3.0, 4.0, 5.0],
}
)

converter = _LongDataFrameConverter.create(
df_long,
example_id_column="example_id",
time_column="time",
)

attributes, features = converter.convert(df_long)
assert attributes is None
assert features.dtype == "float64"
assert features.shape == (2, 2, 1)


def test_long_data_frame_converter_save_and_load(df_long):
converter = _LongDataFrameConverter.create(
df_long,
Expand Down

0 comments on commit 0e72ab1

Please sign in to comment.