diff --git a/nowcasting_dataset/config/model.py b/nowcasting_dataset/config/model.py index 46bc2efe..0f2970ac 100644 --- a/nowcasting_dataset/config/model.py +++ b/nowcasting_dataset/config/model.py @@ -292,7 +292,7 @@ class Process(BaseModel): ), ) split_method: split.SplitMethod = Field( - split.SplitMethod.DAY, + split.SplitMethod.DAY_RANDOM_TEST_DATE, description=( "The method used to split the t0 datetimes into train, validation and test sets." ), diff --git a/nowcasting_dataset/dataset/split/split.py b/nowcasting_dataset/dataset/split/split.py index f873728e..65aa1b96 100644 --- a/nowcasting_dataset/dataset/split/split.py +++ b/nowcasting_dataset/dataset/split/split.py @@ -49,7 +49,7 @@ class SplitName(Enum): def split_data( datetimes: Union[List[pd.Timestamp], pd.DatetimeIndex], method: SplitMethod, - train_test_validation_split: Tuple[int] = (3, 1, 1), + train_test_validation_split: Tuple[int, int, int] = (3, 1, 1), train_test_validation_specific: TrainValidationTestSpecific = ( default_train_test_validation_specific ), diff --git a/nowcasting_dataset/manager.py b/nowcasting_dataset/manager.py index b68e80b1..f1115d28 100644 --- a/nowcasting_dataset/manager.py +++ b/nowcasting_dataset/manager.py @@ -166,8 +166,15 @@ def create_files_specifying_spatial_and_temporal_locations_of_each_example_if_ne t0_datetimes = self.get_t0_datetimes_across_all_data_sources( freq=self.config.process.t0_datetime_frequency ) + # TODO: move hard-coded values to config file #426 split_t0_datetimes = split.split_data( - datetimes=t0_datetimes, method=self.config.process.split_method + datetimes=t0_datetimes, + method=self.config.process.split_method, + train_test_validation_split=(3, 0, 1), + train_validation_test_datetime_split=[ + pd.Timestamp("2020-01-01"), + pd.Timestamp("2021-01-01"), + ], ) for split_name, datetimes_for_split in split_t0_datetimes._asdict().items(): n_batches = self._get_n_batches_for_split_name(split_name)