From 2caa090ebdfb99d347803ae4a4f3b68e4aac151e Mon Sep 17 00:00:00 2001 From: peterdudfield Date: Wed, 17 Nov 2021 11:58:18 +0000 Subject: [PATCH 1/2] update to split on training in one year, test after a date --- nowcasting_dataset/config/model.py | 2 +- nowcasting_dataset/dataset/split/split.py | 2 +- nowcasting_dataset/manager.py | 8 +++++++- 3 files changed, 9 insertions(+), 3 deletions(-) diff --git a/nowcasting_dataset/config/model.py b/nowcasting_dataset/config/model.py index 46bc2efe..0f2970ac 100644 --- a/nowcasting_dataset/config/model.py +++ b/nowcasting_dataset/config/model.py @@ -292,7 +292,7 @@ class Process(BaseModel): ), ) split_method: split.SplitMethod = Field( - split.SplitMethod.DAY, + split.SplitMethod.DAY_RANDOM_TEST_DATE, description=( "The method used to split the t0 datetimes into train, validation and test sets." ), diff --git a/nowcasting_dataset/dataset/split/split.py b/nowcasting_dataset/dataset/split/split.py index f873728e..65aa1b96 100644 --- a/nowcasting_dataset/dataset/split/split.py +++ b/nowcasting_dataset/dataset/split/split.py @@ -49,7 +49,7 @@ class SplitName(Enum): def split_data( datetimes: Union[List[pd.Timestamp], pd.DatetimeIndex], method: SplitMethod, - train_test_validation_split: Tuple[int] = (3, 1, 1), + train_test_validation_split: Tuple[int, int, int] = (3, 1, 1), train_test_validation_specific: TrainValidationTestSpecific = ( default_train_test_validation_specific ), diff --git a/nowcasting_dataset/manager.py b/nowcasting_dataset/manager.py index b68e80b1..1e2d1912 100644 --- a/nowcasting_dataset/manager.py +++ b/nowcasting_dataset/manager.py @@ -167,7 +167,13 @@ def create_files_specifying_spatial_and_temporal_locations_of_each_example_if_ne freq=self.config.process.t0_datetime_frequency ) split_t0_datetimes = split.split_data( - datetimes=t0_datetimes, method=self.config.process.split_method + datetimes=t0_datetimes, + method=self.config.process.split_method, + train_test_validation_split=(3, 0, 1), + train_validation_test_datetime_split=[ + pd.Timestamp("2020-01-01"), + pd.Timestamp("2021-01-01"), + ], ) for split_name, datetimes_for_split in split_t0_datetimes._asdict().items(): n_batches = self._get_n_batches_for_split_name(split_name) From 88b70887db261d4508e774253f5a1f17c85decbd Mon Sep 17 00:00:00 2001 From: peterdudfield Date: Wed, 17 Nov 2021 12:53:27 +0000 Subject: [PATCH 2/2] add TODO --- nowcasting_dataset/manager.py | 1 + 1 file changed, 1 insertion(+) diff --git a/nowcasting_dataset/manager.py b/nowcasting_dataset/manager.py index 1e2d1912..f1115d28 100644 --- a/nowcasting_dataset/manager.py +++ b/nowcasting_dataset/manager.py @@ -166,6 +166,7 @@ def create_files_specifying_spatial_and_temporal_locations_of_each_example_if_ne t0_datetimes = self.get_t0_datetimes_across_all_data_sources( freq=self.config.process.t0_datetime_frequency ) + # TODO: move hard code values to config file #426 split_t0_datetimes = split.split_data( datetimes=t0_datetimes, method=self.config.process.split_method,