diff --git a/nowcasting_dataset/config/model.py b/nowcasting_dataset/config/model.py index 9fab7e0c..9d4d3777 100644 --- a/nowcasting_dataset/config/model.py +++ b/nowcasting_dataset/config/model.py @@ -329,11 +329,31 @@ class Process(BaseModel): split.SplitMethod.DAY_RANDOM_TEST_DATE, description=( "The method used to split the t0 datetimes into train, validation and test sets." + " If the split method produces no t0 datetimes for any split_name, then" + " n_<split_name>_batches must also be set to 0." + ), + ) + n_train_batches: int = Field( + 250, + description=( + "Number of train batches. Must be 0 if split_method produces no t0 datetimes for" + " the train split" + ), + ) + n_validation_batches: int = Field( + 0, # Currently not using any validation batches! + description=( + "Number of validation batches. Must be 0 if split_method produces no t0 datetimes for" + " the validation split" + ), + ) + n_test_batches: int = Field( + 10, + description=( + "Number of test batches. Must be 0 if split_method produces no t0 datetimes for" + " the test split." 
), ) - n_train_batches: int = 250 - n_validation_batches: int = 10 - n_test_batches: int = 10 upload_every_n_batches: int = Field( 16, description=( diff --git a/nowcasting_dataset/config/on_premises.yaml b/nowcasting_dataset/config/on_premises.yaml index 6f9d8199..0baa6c00 100644 --- a/nowcasting_dataset/config/on_premises.yaml +++ b/nowcasting_dataset/config/on_premises.yaml @@ -63,7 +63,7 @@ input_data: topographic_filename: /mnt/storage_b/data/ocf/solar_pv_nowcasting/nowcasting_dataset_pipeline/Topographic/europe_dem_1km_osgb.tif output_data: - filepath: /mnt/storage_ssd_4tb/data/ocf/solar_pv_nowcasting/nowcasting_dataset_pipeline/prepared_ML_training_data/v11/ + filepath: /mnt/storage_ssd_4tb/data/ocf/solar_pv_nowcasting/nowcasting_dataset_pipeline/prepared_ML_training_data/v13_testing process: batch_size: 32 seed: 1234 diff --git a/nowcasting_dataset/manager.py b/nowcasting_dataset/manager.py index 88a8f78a..ac809a16 100644 --- a/nowcasting_dataset/manager.py +++ b/nowcasting_dataset/manager.py @@ -180,23 +180,39 @@ def create_files_specifying_spatial_and_temporal_locations_of_each_example_if_ne ], ) for split_name, datetimes_for_split in split_t0_datetimes._asdict().items(): - n_batches = self._get_n_batches_for_split_name(split_name) - n_examples = n_batches * self.config.process.batch_size + path_for_csv = self.config.output_data.filepath / split_name + n_batches_requested = self._get_n_batches_requested_for_split_name(split_name) + if (n_batches_requested == 0 and len(datetimes_for_split) != 0) or ( + len(datetimes_for_split) == 0 and n_batches_requested != 0 + ): + # TODO: Issue #450: Test this scenario! + msg = ( + f"For split {split_name}: n_{split_name}_batches={n_batches_requested} and" + f" {len(datetimes_for_split)=}! This is an error!" + f" If n_{split_name}_batches==0 then len(datetimes_for_split) must also" + f" equal 0, and vice-versa! Please check `n_{split_name}_batches` and" + " `split_method` in the config YAML!" 
+ ) + logger.error(msg) + raise RuntimeError(msg) + if n_batches_requested == 0: + logger.info(f"0 batches requested for {split_name} so won't create {path_for_csv}") + continue + n_examples = n_batches_requested * self.config.process.batch_size logger.debug( - f"Creating {n_batches:,d} batches x {self.config.process.batch_size:,d} examples" - f" per batch = {n_examples:,d} examples for {split_name}." + f"Creating {n_batches_requested:,d} batches x {self.config.process.batch_size:,d}" + f" examples per batch = {n_examples:,d} examples for {split_name}." ) df_of_locations = self.sample_spatial_and_temporal_locations_for_examples( t0_datetimes=datetimes_for_split, n_examples=n_examples ) output_filename = self._filename_of_locations_csv_file(split_name) - path_for_csv = self.config.output_data.filepath / split_name logger.info(f"Making {path_for_csv} if it does not exist.") nd_fs_utils.makedirs(path_for_csv, exist_ok=True) logger.debug(f"Writing {output_filename}") df_of_locations.to_csv(output_filename) - def _get_n_batches_for_split_name(self, split_name: str) -> int: + def _get_n_batches_requested_for_split_name(self, split_name: str) -> int: return getattr(self.config.process, f"n_{split_name}_batches") def _filename_of_locations_csv_file(self, split_name: str) -> Path: @@ -287,6 +303,7 @@ def sample_spatial_and_temporal_locations_for_examples( Each row of each the DataFrame specifies the position of each example, using columns: 't0_datetime_UTC', 'x_center_OSGB', 'y_center_OSGB'. """ + assert len(t0_datetimes) > 0 shuffled_t0_datetimes = np.random.choice(t0_datetimes, size=n_examples) # TODO: Issue #304. Speed this up by splitting the shuffled_t0_datetimes across # multiple processors. Currently takes about half an hour for 25,000 batches. 
@@ -341,7 +358,7 @@ def _check_if_more_batches_are_required_for_split( first_batches_to_create: dict[split.SplitName, dict[str, int]], ) -> bool: """Returns True if batches still need to be created for any DataSource.""" - n_batches_requested = self._get_n_batches_for_split_name(split_name.value) + n_batches_requested = self._get_n_batches_requested_for_split_name(split_name.value) for data_source_name in self.data_sources: if first_batches_to_create[split_name][data_source_name] < n_batches_requested: return True @@ -375,7 +392,11 @@ def create_batches(self, overwrite_batches: bool) -> None: # Check if there's any work to do. if overwrite_batches: - splits_which_need_more_batches = [split_name for split_name in split.SplitName] + splits_which_need_more_batches = [ + split_name + for split_name in split.SplitName + if self._get_n_batches_requested_for_split_name(split_name.value) > 0 + ] else: splits_which_need_more_batches = self._find_splits_which_need_more_batches( first_batches_to_create