diff --git a/nowcasting_dataset/data_sources/datasource_output.py b/nowcasting_dataset/data_sources/datasource_output.py index 709aa89a..dc1a5901 100644 --- a/nowcasting_dataset/data_sources/datasource_output.py +++ b/nowcasting_dataset/data_sources/datasource_output.py @@ -58,9 +58,14 @@ def check_nan_and_inf(self, data: xr.Dataset, variable_name: str = None): """Check that all values are non NaNs and not infinite""" if np.isnan(data).any(): - message = f"Some {self.__class__.__name__} data values are NaNs" + message = f"Some {self.__class__.__name__} data values are NaNs. " if variable_name is not None: message += f" ({variable_name})" + + # find out which example has nans in it + for i in range(data.shape[0]): + if np.isnan(data[i]).any(): + message += f" Nans in example {i}." logger.error(message) raise Exception(message) diff --git a/nowcasting_dataset/data_sources/gsp/gsp_data_source.py b/nowcasting_dataset/data_sources/gsp/gsp_data_source.py index 93e1e33c..4916e48a 100644 --- a/nowcasting_dataset/data_sources/gsp/gsp_data_source.py +++ b/nowcasting_dataset/data_sources/gsp/gsp_data_source.py @@ -225,8 +225,8 @@ def get_locations(self, t0_datetimes_utc: pd.DatetimeIndex) -> List[SpaceTimeLoc if total_gsp_nan_count == 0: # get random GSP metadata - indexes = list( - self.rng.integers(low=0, high=len(self.metadata), size=len(t0_datetimes_utc)) + indexes = sorted( + list(self.rng.integers(low=0, high=len(self.metadata), size=len(t0_datetimes_utc))) ) metadata = self.metadata.iloc[indexes] diff --git a/nowcasting_dataset/data_sources/satellite/satellite_data_source.py b/nowcasting_dataset/data_sources/satellite/satellite_data_source.py index e329ce3d..195e13a9 100644 --- a/nowcasting_dataset/data_sources/satellite/satellite_data_source.py +++ b/nowcasting_dataset/data_sources/satellite/satellite_data_source.py @@ -54,7 +54,9 @@ def __post_init__( if self.is_live: # This is to account for the delay in satellite data self.total_seq_length = ( - self.history_length 
- (self.live_delay_minutes / self.time_resolution_minutes) + 1 + self.history_length + - int(self.live_delay_minutes / self.time_resolution_minutes) + + 1 ) self._shape_of_example = ( diff --git a/nowcasting_dataset/manager/manager_live.py b/nowcasting_dataset/manager/manager_live.py index cd11e9e0..84eee3f8 100644 --- a/nowcasting_dataset/manager/manager_live.py +++ b/nowcasting_dataset/manager/manager_live.py @@ -148,7 +148,7 @@ def sample_spatial_and_temporal_locations_for_examples( return locations - def create_batches(self) -> None: + def create_batches(self, use_async: Optional[bool] = True) -> None: """Create batches (if necessary). Make dirs: ` / / `. @@ -216,33 +216,38 @@ def create_batches(self) -> None: f"About to submit create_batches task for {data_source_name}, {split_name}" ) - # Sometimes when debuggin it is easy to use non async - # data_source.create_batches(**kwargs_for_create_batches) - - async_result = pool.apply_async( - data_source.create_batches, - kwds=kwargs_for_create_batches, - callback=partial( - callback, data_source_name=data_source_name, split_name=split_name - ), - error_callback=partial( - error_callback, - data_source_name=data_source_name, - split_name=split_name, - an_error_has_occured=an_error_has_occured, - ), - ) - async_results_from_create_batches.append(async_result) - - # Wait for all async_results to finish: - for async_result in async_results_from_create_batches: - async_result.wait() - if an_error_has_occured.is_set(): - # An error has occurred but, at this point in the code, we don't know which - # worker process raised the exception. But, with luck, the worker process - # will have logged an informative exception via the _error_callback func. - raise RuntimeError( - f"A worker process raised an exception whilst working on {split_name}!" 
+ if not use_async: + # Run in this process instead of the worker pool; easier to step through when debugging + data_source.create_batches(**kwargs_for_create_batches) + else: + + async_result = pool.apply_async( + data_source.create_batches, + kwds=kwargs_for_create_batches, + callback=partial( + callback, data_source_name=data_source_name, split_name=split_name + ), + error_callback=partial( + error_callback, + data_source_name=data_source_name, + split_name=split_name, + an_error_has_occured=an_error_has_occured, + ), ) + async_results_from_create_batches.append(async_result) + + # Wait for all async_results to finish: + for async_result in async_results_from_create_batches: + async_result.wait() + if an_error_has_occured.is_set(): + # An error has occurred but, at this point in the code, + # we don't know which worker process raised the exception. + # But, with luck, the worker process + # will have logged an informative exception via the + # _error_callback func. + raise RuntimeError( + f"A worker process raised an exception whilst " + f"working on {split_name}!" 
+ ) logger.info(f"Finished creating batches for {split_name}!") diff --git a/tests/manager/test_manager_live.py b/tests/manager/test_manager_live.py index 59189ced..00f86752 100644 --- a/tests/manager/test_manager_live.py +++ b/tests/manager/test_manager_live.py @@ -129,6 +129,36 @@ def test_batches(test_configuration_filename, sat, gsp): assert os.path.exists(f"{dst_path}/live/satellite/000000.nc") +def test_batches_not_async(test_configuration_filename, sat, gsp): + """Test that batches can be made""" + + manager = ManagerLive() + manager.load_yaml_configuration(filename=test_configuration_filename) + + with tempfile.TemporaryDirectory() as local_temp_path, tempfile.TemporaryDirectory() as dst_path: # noqa 101 + + # set local temp path, and dst path + manager.config.output_data.filepath = Path(dst_path) + manager.local_temp_path = Path(local_temp_path) + + # Set data sources + manager.data_sources = {"gsp": gsp, "satellite": sat} + manager.data_source_which_defines_geospatial_locations = gsp + + # make file for locations + manager.create_files_specifying_spatial_and_temporal_locations_of_each_example( + t0_datetime=datetime(2020, 4, 1, 13) + ) # noqa 101 + + # make batches + manager.create_batches(use_async=False) + + assert os.path.exists(f"{dst_path}/live") + assert os.path.exists(f"{dst_path}/live/gsp") + assert os.path.exists(f"{dst_path}/live/gsp/000000.nc") + assert os.path.exists(f"{dst_path}/live/satellite/000000.nc") + + def test_run_error(test_configuration_filename): """Test to initialize data sources and get batches"""