diff --git a/nowcasting_dataset/config/gcp.yaml b/nowcasting_dataset/config/gcp.yaml index b813abdb..3fecccfc 100644 --- a/nowcasting_dataset/config/gcp.yaml +++ b/nowcasting_dataset/config/gcp.yaml @@ -29,7 +29,7 @@ input_data: history_minutes: 30 pv_filename: gs://solar-pv-nowcasting-data/PV/Passive/ocf_formatted/v0/passiv.netcdf pv_metadata_filename: gs://solar-pv-nowcasting-data/PV/Passive/ocf_formatted/v0/system_metadata.csv - get_center: False + get_center: false satellite: forecast_minutes: 60 history_minutes: 30 diff --git a/nowcasting_dataset/config/model.py b/nowcasting_dataset/config/model.py index 9b55d19e..46bc2efe 100644 --- a/nowcasting_dataset/config/model.py +++ b/nowcasting_dataset/config/model.py @@ -98,6 +98,13 @@ class PV(DataSourceMixin): ) pv_image_size_pixels: int = IMAGE_SIZE_PIXELS_FIELD pv_meters_per_pixel: int = METERS_PER_PIXEL_FIELD + get_center: bool = Field( + False, + description="If the batches are centered on one PV system (or not). " + "The other option is to have one GSP at the center of a batch. 
" + "Typically, get_center would be set to true if and only if " + "PVDataSource is used to define the geospatial positions of each example.", + ) class Satellite(DataSourceMixin): diff --git a/nowcasting_dataset/config/on_premises.yaml b/nowcasting_dataset/config/on_premises.yaml index 416ae3a7..82874ac0 100644 --- a/nowcasting_dataset/config/on_premises.yaml +++ b/nowcasting_dataset/config/on_premises.yaml @@ -30,7 +30,7 @@ input_data: pv: pv_filename: /mnt/storage_b/data/ocf/solar_pv_nowcasting/nowcasting_dataset_pipeline/PV/Passiv/ocf_formatted/v0/passiv.netcdf pv_metadata_filename: /mnt/storage_b/data/ocf/solar_pv_nowcasting/nowcasting_dataset_pipeline/PV/Passiv/ocf_formatted/v0/system_metadata_OCF_ONLY.csv - get_center: False + get_center: false #---------------------- Satellite ------------- satellite: diff --git a/nowcasting_dataset/data_sources/data_source.py b/nowcasting_dataset/data_sources/data_source.py index 83c4acb9..762bf911 100644 --- a/nowcasting_dataset/data_sources/data_source.py +++ b/nowcasting_dataset/data_sources/data_source.py @@ -81,6 +81,7 @@ def __post_init__(self): def _get_start_dt( self, t0_dt: Union[pd.Timestamp, pd.DatetimeIndex] ) -> Union[pd.Timestamp, pd.DatetimeIndex]: + return t0_dt - self.history_duration def _get_end_dt( diff --git a/nowcasting_dataset/data_sources/nwp/nwp_data_source.py b/nowcasting_dataset/data_sources/nwp/nwp_data_source.py index cd18a957..b77ae8da 100644 --- a/nowcasting_dataset/data_sources/nwp/nwp_data_source.py +++ b/nowcasting_dataset/data_sources/nwp/nwp_data_source.py @@ -128,6 +128,11 @@ def _post_process_example(self, selected_data: xr.Dataset, t0_dt: pd.Timestamp) start_dt = self._get_start_dt(t0_dt) end_dt = self._get_end_dt(t0_dt) + # if t0_dt is not on the hour, e.g. 13.05. 
+ # Then if the history_minutes is 1 hour, + # so start_dt will be 12.05, but we want the 12.00 time step too + start_dt = start_dt.floor("H") + selected_data = selected_data.sel(target_time=slice(start_dt, end_dt)) selected_data = selected_data.rename({"target_time": "time", "variable": "channels"}) selected_data.data = selected_data.data.astype(np.float16) @@ -140,6 +145,7 @@ def datetime_index(self) -> pd.DatetimeIndex: nwp = self._open_data() else: nwp = self._data + # We need to return the `target_times` (the times the NWPs are _about_). # The `target_time` is the `init_time` plus the forecast horizon `step`. # `step` is an array of timedeltas, so we can just add `init_time` to `step`. @@ -148,6 +154,7 @@ target_times = np.unique(target_times) target_times = np.sort(target_times) target_times = pd.DatetimeIndex(target_times) + return target_times @property diff --git a/nowcasting_dataset/data_sources/sun/sun_data_source.py b/nowcasting_dataset/data_sources/sun/sun_data_source.py index 790de1a7..400bc40c 100644 --- a/nowcasting_dataset/data_sources/sun/sun_data_source.py +++ b/nowcasting_dataset/data_sources/sun/sun_data_source.py @@ -85,7 +85,7 @@ def get_example( sun = azimuth.to_dataset(name="azimuth") sun["elevation"] = elevation - return Sun(sun) + return sun def _load(self): diff --git a/nowcasting_dataset/data_sources/topographic/topographic_data_source.py b/nowcasting_dataset/data_sources/topographic/topographic_data_source.py index 07c7c514..3b8af08a 100644 --- a/nowcasting_dataset/data_sources/topographic/topographic_data_source.py +++ b/nowcasting_dataset/data_sources/topographic/topographic_data_source.py @@ -103,7 +103,7 @@ def get_example( # change to dataset topo_xd = selected_data.to_dataset(name="data") - return Topographic(topo_xd) + return topo_xd def _post_process_example( self, selected_data: xr.DataArray, t0_dt: pd.Timestamp ) -> xr.DataArray: diff --git a/nowcasting_dataset/dataset/xr_utils.py 
b/nowcasting_dataset/dataset/xr_utils.py index 12e37129..79c7c75b 100644 --- a/nowcasting_dataset/dataset/xr_utils.py +++ b/nowcasting_dataset/dataset/xr_utils.py @@ -37,7 +37,7 @@ def convert_coordinates_to_indexes(dataset: xr.Dataset) -> xr.Dataset: This is useful to align multiple examples into a single batch. """ - assert type(dataset) == xr.Dataset + assert type(dataset) == xr.Dataset, f" Should be xr.Dataset but found {type(dataset)}" original_dim_names = dataset.dims diff --git a/tests/config/test.yaml b/tests/config/test.yaml index 37f846cc..feffb673 100644 --- a/tests/config/test.yaml +++ b/tests/config/test.yaml @@ -11,9 +11,11 @@ input_data: nwp_image_size_pixels: 2 nwp_zarr_path: tests/data/nwp_data/test.zarr history_minutes: 60 + forecast_minutes: 60 pv: pv_filename: tests/data/pv_data/test.nc pv_metadata_filename: tests/data/pv_metadata/UK_PV_metadata.csv + get_center: false satellite: satellite_channels: - HRV diff --git a/tests/data_sources/test_nwp_data_source.py b/tests/data_sources/test_nwp_data_source.py index 4103c5cb..0d69442e 100644 --- a/tests/data_sources/test_nwp_data_source.py +++ b/tests/data_sources/test_nwp_data_source.py @@ -41,7 +41,7 @@ def test_nwp_data_source_batch(): # noqa: D103 nwp.open() - t0_datetimes = nwp._data.init_time[2:6].values + t0_datetimes = [pd.Timestamp(t) for t in nwp._data.init_time[2:6].values] x = nwp._data.x[0:4].values y = nwp._data.y[0:4].values @@ -54,6 +54,29 @@ assert batch.data.shape == (4, 1, 3, 2, 2) + +def test_nwp_data_source_batch_not_on_hour(): # noqa: D103 + nwp = NWPDataSource( + zarr_path=NWP_ZARR_PATH, + history_minutes=60, + forecast_minutes=60, + channels=["t"], + ) + + nwp.open() + + t0_datetimes = [pd.Timestamp("2019-01-01 12:05:00")] + x = nwp._data.x[0:1].values + y = nwp._data.y[0:1].values + + batch = nwp.get_batch(t0_datetimes=t0_datetimes, x_locations=x, y_locations=y) + + # batch size 1 + # channel 1 + # time series, 1 in the past, 1 now, 
1 in the future + # x,y of size 2 + assert batch.data.shape == (1, 1, 3, 2, 2) + + def test_nwp_get_contiguous_time_periods(): # noqa: D103 nwp = NWPDataSource( zarr_path=NWP_ZARR_PATH, diff --git a/tests/test_manager.py b/tests/test_manager.py index 836b5cc2..0cc91b1d 100644 --- a/tests/test_manager.py +++ b/tests/test_manager.py @@ -163,3 +163,21 @@ def test_save_config(): manager.save_yaml_configuration() assert os.path.exists(f"{dst_path}/configuration.yaml") + + +def test_run(): + """Test to initialize data sources and get batches""" + + manager = Manager() + local_path = Path(nowcasting_dataset.__file__).parent.parent + filename = local_path / "tests" / "config" / "test.yaml" + manager.load_yaml_configuration(filename=filename) + manager.initialise_data_sources() + + with tempfile.TemporaryDirectory() as local_temp_path, tempfile.TemporaryDirectory() as dst_path: # noqa 101 + + manager.config.output_data.filepath = Path(dst_path) + manager.local_temp_path = Path(local_temp_path) + + manager.create_files_specifying_spatial_and_temporal_locations_of_each_example_if_necessary() # noqa 101 + manager.create_batches(overwrite_batches=True)