diff --git a/ocf_datapipes/convert/numpy_batch/__init__.py b/ocf_datapipes/convert/numpy_batch/__init__.py index 13dd7dfac..e576e4860 100644 --- a/ocf_datapipes/convert/numpy_batch/__init__.py +++ b/ocf_datapipes/convert/numpy_batch/__init__.py @@ -1 +1,7 @@ """Conversion from Xarray to NumpyBatch""" +from .gsp import convert_gsp_to_numpy_batch +from .nwp import convert_nwp_to_numpy_batch +from .pv import convert_pv_to_numpy_batch +from .satellite import convert_satellite_to_numpy_batch +from .sensor import convert_sensor_to_numpy_batch +from .wind import convert_wind_to_numpy_batch diff --git a/ocf_datapipes/convert/numpy_batch/gsp.py b/ocf_datapipes/convert/numpy_batch/gsp.py index f860939e5..0b6909aad 100644 --- a/ocf_datapipes/convert/numpy_batch/gsp.py +++ b/ocf_datapipes/convert/numpy_batch/gsp.py @@ -9,6 +9,31 @@ logger = logging.getLogger(__name__) +def convert_gsp_to_numpy_batch(xr_data): + """Convert from Xarray to NumpyBatch""" + + example: NumpyBatch = { + BatchKey.gsp: xr_data.values, + BatchKey.gsp_t0_idx: xr_data.attrs["t0_idx"], + BatchKey.gsp_id: xr_data.gsp_id.values, + BatchKey.gsp_nominal_capacity_mwp: xr_data.isel(time_utc=0)["nominal_capacity_mwp"].values, + BatchKey.gsp_effective_capacity_mwp: ( + xr_data.isel(time_utc=0)["effective_capacity_mwp"].values + ), + BatchKey.gsp_time_utc: datetime64_to_float(xr_data["time_utc"].values), + } + + # Coordinates + for batch_key, dataset_key in ( + (BatchKey.gsp_y_osgb, "y_osgb"), + (BatchKey.gsp_x_osgb, "x_osgb"), + ): + if dataset_key in xr_data.coords.keys(): + example[batch_key] = xr_data[dataset_key].values + + return example + + @functional_datapipe("convert_gsp_to_numpy_batch") class ConvertGSPToNumpyBatchIterDataPipe(IterDataPipe): """Convert GSP Xarray to NumpyBatch""" @@ -25,29 +50,5 @@ def __init__(self, source_datapipe: IterDataPipe): def __iter__(self) -> NumpyBatch: """Convert from Xarray to NumpyBatch""" - logger.debug("Converting GSP to numpy to batch") for xr_data in self.source_datapipe: - example: NumpyBatch = { - BatchKey.gsp: xr_data.values, - BatchKey.gsp_t0_idx: xr_data.attrs["t0_idx"], - BatchKey.gsp_id: xr_data.gsp_id.values, - BatchKey.gsp_nominal_capacity_mwp: xr_data.isel(time_utc=0)[ - "nominal_capacity_mwp" - ].values, - BatchKey.gsp_effective_capacity_mwp: ( - xr_data.isel(time_utc=0)["effective_capacity_mwp"].values - ), - BatchKey.gsp_time_utc: datetime64_to_float(xr_data["time_utc"].values), - } - - # Coordinates - for batch_key, dataset_key in ( - (BatchKey.gsp_y_osgb, "y_osgb"), - (BatchKey.gsp_x_osgb, "x_osgb"), - ): - if dataset_key in xr_data.coords.keys(): - values = xr_data[dataset_key].values - # Expand dims so AddFourierSpaceTime works! 
- example[batch_key] = values # np.expand_dims(values, axis=1) - - yield example + yield convert_gsp_to_numpy_batch(xr_data) diff --git a/ocf_datapipes/convert/numpy_batch/nwp.py b/ocf_datapipes/convert/numpy_batch/nwp.py index 6dc3d953f..f17240e27 100644 --- a/ocf_datapipes/convert/numpy_batch/nwp.py +++ b/ocf_datapipes/convert/numpy_batch/nwp.py @@ -6,6 +6,31 @@ from ocf_datapipes.utils.utils import datetime64_to_float +def convert_nwp_to_numpy_batch(xr_data): + """Convert from Xarray to NWPBatchKey""" + + example: NWPNumpyBatch = { + NWPBatchKey.nwp: xr_data.values, + NWPBatchKey.nwp_t0_idx: xr_data.attrs["t0_idx"], + NWPBatchKey.nwp_channel_names: xr_data.channel.values, + NWPBatchKey.nwp_init_time_utc: datetime64_to_float(xr_data.init_time_utc.values), + NWPBatchKey.nwp_step: (xr_data.step.values / np.timedelta64(1, "h")).astype(np.int64), + } + + if "target_time_utc" in xr_data.coords: + target_time = xr_data.target_time_utc.values + example[NWPBatchKey.nwp_target_time_utc] = datetime64_to_float(target_time) + + for batch_key, dataset_key in ( + (NWPBatchKey.nwp_y_osgb, "y_osgb"), + (NWPBatchKey.nwp_x_osgb, "x_osgb"), + ): + if dataset_key in xr_data.coords: + example[batch_key] = xr_data[dataset_key].values + + return example + + @functional_datapipe("convert_nwp_to_numpy_batch") class ConvertNWPToNumpyBatchIterDataPipe(IterDataPipe): """Convert NWP Xarray objects to NWPNumpyBatch""" @@ -23,26 +48,4 @@ def __init__(self, source_datapipe: IterDataPipe): def __iter__(self) -> NWPNumpyBatch: """Convert from Xarray to NWPBatchKey""" for xr_data in self.source_datapipe: - example: NWPNumpyBatch = { - NWPBatchKey.nwp: xr_data.values, - NWPBatchKey.nwp_t0_idx: xr_data.attrs["t0_idx"], - } - if "target_time_utc" in xr_data.coords: - target_time = xr_data.target_time_utc.values - example[NWPBatchKey.nwp_target_time_utc] = datetime64_to_float(target_time) - example[NWPBatchKey.nwp_channel_names] = xr_data.channel.values - example[NWPBatchKey.nwp_step] = (xr_data.step.values / np.timedelta64(1, "h")).astype( - np.int64 - ) - example[NWPBatchKey.nwp_init_time_utc] = datetime64_to_float( - xr_data.init_time_utc.values - ) - - for batch_key, dataset_key in ( - (NWPBatchKey.nwp_y_osgb, "y_osgb"), - (NWPBatchKey.nwp_x_osgb, "x_osgb"), - ): - if dataset_key in xr_data.coords.keys(): - example[batch_key] = xr_data[dataset_key].values - - yield example + yield convert_nwp_to_numpy_batch(xr_data) diff --git a/ocf_datapipes/convert/numpy_batch/pv.py b/ocf_datapipes/convert/numpy_batch/pv.py index fd9a478cc..1476c7140 100644 --- a/ocf_datapipes/convert/numpy_batch/pv.py +++ b/ocf_datapipes/convert/numpy_batch/pv.py @@ -10,6 +10,23 @@ logger = logging.getLogger(__name__) +def convert_pv_to_numpy_batch(xr_data): + """Convert PV Xarray to NumpyBatch""" + example: NumpyBatch = { + BatchKey.pv: xr_data.values, + BatchKey.pv_t0_idx: xr_data.attrs["t0_idx"], + BatchKey.pv_ml_id: xr_data["ml_id"].values, + BatchKey.pv_id: xr_data["pv_system_id"].values.astype(np.float32), + BatchKey.pv_observed_capacity_wp: (xr_data["observed_capacity_wp"].values), + BatchKey.pv_nominal_capacity_wp: (xr_data["nominal_capacity_wp"].values), + BatchKey.pv_time_utc: datetime64_to_float(xr_data["time_utc"].values), + BatchKey.pv_latitude: xr_data["latitude"].values, + BatchKey.pv_longitude: xr_data["longitude"].values, + } + + return example + + @functional_datapipe("convert_pv_to_numpy_batch") class ConvertPVToNumpyBatchIterDataPipe(IterDataPipe): """Convert PV Xarray to NumpyBatch""" @@ -25,20 +42,6 @@ def __init__(self, 
source_datapipe: IterDataPipe):
         self.source_datapipe = source_datapipe

     def __iter__(self) -> NumpyBatch:
-        """Iterate and convert PV Xarray to NumpyBatch"""
+        """Convert PV Xarray to NumpyBatch"""
         for xr_data in self.source_datapipe:
-            logger.debug("Converting PV xarray to numpy example")
-
-            example: NumpyBatch = {
-                BatchKey.pv: xr_data.values,
-                BatchKey.pv_t0_idx: xr_data.attrs["t0_idx"],
-                BatchKey.pv_ml_id: xr_data["ml_id"].values,
-                BatchKey.pv_id: xr_data["pv_system_id"].values.astype(np.float32),
-                BatchKey.pv_observed_capacity_wp: (xr_data["observed_capacity_wp"].values),
-                BatchKey.pv_nominal_capacity_wp: (xr_data["nominal_capacity_wp"].values),
-                BatchKey.pv_time_utc: datetime64_to_float(xr_data["time_utc"].values),
-                BatchKey.pv_latitude: xr_data["latitude"].values,
-                BatchKey.pv_longitude: xr_data["longitude"].values,
-            }
-
-            yield example
+            yield convert_pv_to_numpy_batch(xr_data)
diff --git a/ocf_datapipes/convert/numpy_batch/satellite.py b/ocf_datapipes/convert/numpy_batch/satellite.py
index ec650894d..ae7afaf97 100644
--- a/ocf_datapipes/convert/numpy_batch/satellite.py
+++ b/ocf_datapipes/convert/numpy_batch/satellite.py
@@ -5,6 +5,49 @@
 from ocf_datapipes.utils.utils import datetime64_to_float


+def _convert_satellite_to_numpy_batch(xr_data):
+    example: NumpyBatch = {
+        BatchKey.satellite_actual: xr_data.values,
+        BatchKey.satellite_t0_idx: xr_data.attrs["t0_idx"],
+        BatchKey.satellite_time_utc: datetime64_to_float(xr_data["time_utc"].values),
+    }
+
+    for batch_key, dataset_key in (
+        (BatchKey.satellite_y_geostationary, "y_geostationary"),
+        (BatchKey.satellite_x_geostationary, "x_geostationary"),
+    ):
+        # Satellite coords are already float32.
+        example[batch_key] = xr_data[dataset_key].values
+
+    return example
+
+
+def _convert_hrvsatellite_to_numpy_batch(xr_data):
+    example: NumpyBatch = {
+        BatchKey.hrvsatellite_actual: xr_data.values,
+        BatchKey.hrvsatellite_t0_idx: xr_data.attrs["t0_idx"],
+        BatchKey.hrvsatellite_time_utc: datetime64_to_float(xr_data["time_utc"].values),
+    }
+
+    for batch_key, dataset_key in (
+        (BatchKey.hrvsatellite_y_geostationary, "y_geostationary"),
+        (BatchKey.hrvsatellite_x_geostationary, "x_geostationary"),
+    ):
+        # HRVSatellite coords are already float32.
+        example[batch_key] = xr_data[dataset_key].values
+
+    return example
+
+
+def convert_satellite_to_numpy_batch(xr_data, is_hrv=False):
+    """Converts Xarray Satellite to NumpyBatch object"""
+    if is_hrv:
+        example = _convert_hrvsatellite_to_numpy_batch(xr_data)
+    else:
+        example = _convert_satellite_to_numpy_batch(xr_data)
+    return example
+
+
 @functional_datapipe("convert_satellite_to_numpy_batch")
 class ConvertSatelliteToNumpyBatchIterDataPipe(IterDataPipe):
     """Converts Xarray Satellite to NumpyBatch object"""
@@ -24,31 +67,4 @@ def __init__(self, source_datapipe: IterDataPipe, is_hrv: bool = False):

     def __iter__(self) -> NumpyBatch:
         """Convert each example to a NumpyBatch object"""
         for xr_data in self.source_datapipe:
-            if self.is_hrv:
-                example: NumpyBatch = {
-                    BatchKey.hrvsatellite_actual: xr_data.values,
-                    BatchKey.hrvsatellite_t0_idx: xr_data.attrs["t0_idx"],
-                    BatchKey.hrvsatellite_time_utc: datetime64_to_float(xr_data["time_utc"].values),
-                }
-
-                for batch_key, dataset_key in (
-                    (BatchKey.hrvsatellite_y_geostationary, "y_geostationary"),
-                    (BatchKey.hrvsatellite_x_geostationary, "x_geostationary"),
-                ):
-                    # HRVSatellite coords are already float32.
-                    example[batch_key] = xr_data[dataset_key].values
-            else:
-                example: NumpyBatch = {
-                    BatchKey.satellite_actual: xr_data.values,
-                    BatchKey.satellite_t0_idx: xr_data.attrs["t0_idx"],
-                    BatchKey.satellite_time_utc: datetime64_to_float(xr_data["time_utc"].values),
-                }
-
-                for batch_key, dataset_key in (
-                    (BatchKey.satellite_y_geostationary, "y_geostationary"),
-                    (BatchKey.satellite_x_geostationary, "x_geostationary"),
-                ):
-                    # HRVSatellite coords are already float32.
-                    example[batch_key] = xr_data[dataset_key].values
-
-            yield example
+            yield convert_satellite_to_numpy_batch(xr_data, self.is_hrv)
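# ── Reviewer sketch (illustrative only, not part of the patch) ────────────────
# The conversion logic is now importable as plain functions, so it can be unit
# tested or reused outside a datapipe. `fake_sat_da` is a made-up minimal input;
# the converter only needs a "t0_idx" attr, a "time_utc" coord and the
# geostationary x/y coords.
import numpy as np
import pandas as pd
import xarray as xr

from ocf_datapipes.convert.numpy_batch import convert_satellite_to_numpy_batch

fake_sat_da = xr.DataArray(
    np.random.rand(3, 2, 4, 4).astype(np.float32),
    dims=("time_utc", "channel", "y_geostationary", "x_geostationary"),
    coords={
        "time_utc": pd.date_range("2023-01-01", periods=3, freq="5min"),
        "channel": ["IR_016", "VIS006"],
        "y_geostationary": np.arange(4, dtype=np.float32),
        "x_geostationary": np.arange(4, dtype=np.float32),
    },
    attrs={"t0_idx": 1},
)
example = convert_satellite_to_numpy_batch(fake_sat_da)                  # non-HRV keys
hrv_example = convert_satellite_to_numpy_batch(fake_sat_da, is_hrv=True)  # HRV keys
# ── end sketch ────────────────────────────────────────────────────────────────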
diff --git a/ocf_datapipes/convert/numpy_batch/sensor.py b/ocf_datapipes/convert/numpy_batch/sensor.py
index 754aec447..3e8874106 100644
--- a/ocf_datapipes/convert/numpy_batch/sensor.py
+++ b/ocf_datapipes/convert/numpy_batch/sensor.py
@@ -10,34 +10,37 @@
 logger = logging.getLogger(__name__)


+def convert_sensor_to_numpy_batch(xr_data):
+    """Convert Sensor Xarray to NumpyBatch"""
+
+    example: NumpyBatch = {
+        BatchKey.sensor: xr_data.values,
+        BatchKey.sensor_t0_idx: xr_data.attrs["t0_idx"],
+        BatchKey.sensor_id: xr_data["station_id"].values.astype(np.float32),
+        # BatchKey.sensor_observed_capacity_wp: (xr_data["observed_capacity_wp"].values),
+        # BatchKey.sensor_nominal_capacity_wp: (xr_data["nominal_capacity_wp"].values),
+        BatchKey.sensor_time_utc: datetime64_to_float(xr_data["time_utc"].values),
+        BatchKey.sensor_latitude: xr_data["latitude"].values,
+        BatchKey.sensor_longitude: xr_data["longitude"].values,
+    }
+    return example
+
+
 @functional_datapipe("convert_sensor_to_numpy_batch")
 class ConvertSensorToNumpyBatchIterDataPipe(IterDataPipe):
     """Convert Sensor Xarray to NumpyBatch"""

     def __init__(self, source_datapipe: IterDataPipe):
         """
-        Convert PV Xarray objects to NumpyBatch objects
+        Convert sensor Xarray objects to NumpyBatch objects

         Args:
-            source_datapipe: Datapipe emitting PV Xarray objects
+            source_datapipe: Datapipe emitting sensor Xarray objects
         """
         super().__init__()
         self.source_datapipe = source_datapipe

     def __iter__(self) -> NumpyBatch:
-        """Iterate and convert PV Xarray to NumpyBatch"""
+        """Iterate and convert sensor Xarray to NumpyBatch"""
         for xr_data in self.source_datapipe:
-            logger.debug("Converting Sensor xarray to numpy example")
-
-            example: NumpyBatch = {
-                BatchKey.sensor: xr_data.values,
-                BatchKey.sensor_t0_idx: xr_data.attrs["t0_idx"],
-                BatchKey.sensor_id: xr_data["station_id"].values.astype(np.float32),
-                # BatchKey.sensor_observed_capacity_wp: (xr_data["observed_capacity_wp"].values),
-                # BatchKey.sensor_nominal_capacity_wp: (xr_data["nominal_capacity_wp"].values),
-                BatchKey.sensor_time_utc: datetime64_to_float(xr_data["time_utc"].values),
-                BatchKey.sensor_latitude: xr_data["latitude"].values,
-                BatchKey.sensor_longitude: xr_data["longitude"].values,
-            }
-
-            yield example
+            yield convert_sensor_to_numpy_batch(xr_data)
diff --git a/ocf_datapipes/convert/numpy_batch/wind.py b/ocf_datapipes/convert/numpy_batch/wind.py
index 8495a6584..3f04a9a0d 100644
--- a/ocf_datapipes/convert/numpy_batch/wind.py
+++ b/ocf_datapipes/convert/numpy_batch/wind.py
@@ -10,6 +10,24 @@
 logger = logging.getLogger(__name__)


+def convert_wind_to_numpy_batch(xr_data):
+    """Convert Wind Xarray to NumpyBatch"""
+
+    example: NumpyBatch = {
+        BatchKey.wind: xr_data.values,
+        BatchKey.wind_t0_idx: xr_data.attrs["t0_idx"],
+        BatchKey.wind_ml_id: xr_data["ml_id"].values,
+        BatchKey.wind_id: xr_data["wind_system_id"].values.astype(np.float32),
+        BatchKey.wind_observed_capacity_mwp:
(xr_data["observed_capacity_mwp"].values), + BatchKey.wind_nominal_capacity_mwp: (xr_data["nominal_capacity_mwp"].values), + BatchKey.wind_time_utc: datetime64_to_float(xr_data["time_utc"].values), + BatchKey.wind_latitude: xr_data["latitude"].values, + BatchKey.wind_longitude: xr_data["longitude"].values, + } + + return example + + @functional_datapipe("convert_wind_to_numpy_batch") class ConvertWindToNumpyBatchIterDataPipe(IterDataPipe): """Convert Wind Xarray to NumpyBatch""" @@ -27,18 +45,4 @@ def __init__(self, source_datapipe: IterDataPipe): def __iter__(self) -> NumpyBatch: """Iterate and convert PV Xarray to NumpyBatch""" for xr_data in self.source_datapipe: - logger.debug("Converting Wind xarray to numpy example") - - example: NumpyBatch = { - BatchKey.wind: xr_data.values, - BatchKey.wind_t0_idx: xr_data.attrs["t0_idx"], - BatchKey.wind_ml_id: xr_data["ml_id"].values, - BatchKey.wind_id: xr_data["wind_system_id"].values.astype(np.float32), - BatchKey.wind_observed_capacity_mwp: (xr_data["observed_capacity_mwp"].values), - BatchKey.wind_nominal_capacity_mwp: (xr_data["nominal_capacity_mwp"].values), - BatchKey.wind_time_utc: datetime64_to_float(xr_data["time_utc"].values), - BatchKey.wind_latitude: xr_data["latitude"].values, - BatchKey.wind_longitude: xr_data["longitude"].values, - } - - yield example + yield convert_wind_to_numpy_batch(xr_data) diff --git a/ocf_datapipes/load/gsp/gsp.py b/ocf_datapipes/load/gsp/gsp.py index 57800032a..318a86753 100644 --- a/ocf_datapipes/load/gsp/gsp.py +++ b/ocf_datapipes/load/gsp/gsp.py @@ -12,14 +12,6 @@ logger = logging.getLogger(__name__) -try: - from ocf_datapipes.utils.eso import get_gsp_metadata_from_eso, get_gsp_shape_from_eso - - _has_pvlive = True -except ImportError: - print("Unable to import PVLive utils, please provide filenames with OpenGSP") - _has_pvlive = False - @functional_datapipe("open_gsp") class OpenGSPIterDataPipe(IterDataPipe): @@ -44,40 +36,34 @@ def __init__( sample_period_duration: Sample period of the GSP data """ self.gsp_pv_power_zarr_path = gsp_pv_power_zarr_path - if ( - gsp_id_to_region_id_filename is None - or sheffield_solar_region_path is None - and _has_pvlive - ): - self.gsp_id_to_region_id_filename = get_gsp_metadata_from_eso() - self.sheffield_solar_region_path = get_gsp_shape_from_eso() - else: - self.gsp_id_to_region_id_filename = gsp_id_to_region_id_filename - self.sheffield_solar_region_path = sheffield_solar_region_path + + self.gsp_id_to_region_id_filename = gsp_id_to_region_id_filename + self.sheffield_solar_region_path = sheffield_solar_region_path self.threshold_mw = threshold_mw self.sample_period_duration = sample_period_duration def __iter__(self) -> xr.DataArray: """Get and return GSP data""" gsp_id_to_shape = get_gsp_id_to_shape( - self.gsp_id_to_region_id_filename, self.sheffield_solar_region_path + self.gsp_id_to_region_id_filename, + self.sheffield_solar_region_path, ) - self._gsp_id_to_shape = gsp_id_to_shape # Save, mostly for plotting to check all is fine! 
logger.debug(f"Getting GSP data from {self.gsp_pv_power_zarr_path}") - # Load GSP generation xr.Dataset: + # Load GSP generation xr.Dataset gsp_pv_power_mw_ds = xr.open_dataset(self.gsp_pv_power_zarr_path, engine="zarr") - # Ensure the centroids have the same GSP ID index as the GSP PV power: + # Ensure the centroids have the same GSP ID index as the GSP PV power gsp_id_to_shape = gsp_id_to_shape.loc[gsp_pv_power_mw_ds.gsp_id] + data_array = put_gsp_data_into_an_xr_dataarray( gsp_pv_power_mw=gsp_pv_power_mw_ds.generation_mw.data.astype(np.float32), time_utc=gsp_pv_power_mw_ds.datetime_gmt.data, gsp_id=gsp_pv_power_mw_ds.gsp_id.data, # TODO: Try using `gsp_id_to_shape.geometry.envelope.centroid`. See issue #76. - x_osgb=gsp_id_to_shape.geometry.centroid.x.astype(np.float32), - y_osgb=gsp_id_to_shape.geometry.centroid.y.astype(np.float32), + x_osgb=gsp_id_to_shape.x_osgb.astype(np.float32), + y_osgb=gsp_id_to_shape.y_osgb.astype(np.float32), nominal_capacity_mwp=gsp_pv_power_mw_ds.installedcapacity_mwp.data.astype(np.float32), effective_capacity_mwp=gsp_pv_power_mw_ds.capacity_mwp.data.astype(np.float32), ) diff --git a/ocf_datapipes/load/gsp/utils.py b/ocf_datapipes/load/gsp/utils.py index 3ec228935..faf74243d 100644 --- a/ocf_datapipes/load/gsp/utils.py +++ b/ocf_datapipes/load/gsp/utils.py @@ -1,9 +1,21 @@ """ Utils for GSP loading""" +from typing import Optional + import geopandas as gpd import numpy as np import pandas as pd import xarray as xr +from ocf_datapipes.utils.location import Location + +try: + from ocf_datapipes.utils.eso import get_gsp_metadata_from_eso, get_gsp_shape_from_eso + + _has_pvlive = True +except ImportError: + print("Unable to import PVLive utils, please provide filenames with OpenGSP") + _has_pvlive = False + def put_gsp_data_into_an_xr_dataarray( gsp_pv_power_mw: np.ndarray, @@ -48,7 +60,8 @@ def put_gsp_data_into_an_xr_dataarray( def get_gsp_id_to_shape( - gsp_id_to_region_id_filename: str, sheffield_solar_region_path: str + gsp_id_to_region_id_filename: Optional[str] = None, + sheffield_solar_region_path: Optional[str] = None, ) -> gpd.GeoDataFrame: """ Get the GSP ID to the shape @@ -60,6 +73,16 @@ def get_gsp_id_to_shape( Returns: GeoDataFrame containing the mapping from ID to shape """ + + did_provide_filepaths = None not in [gsp_id_to_region_id_filename, sheffield_solar_region_path] + assert _has_pvlive or did_provide_filepaths + + if not did_provide_filepaths: + if gsp_id_to_region_id_filename is None: + gsp_id_to_region_id_filename = get_gsp_metadata_from_eso() + if sheffield_solar_region_path is None: + sheffield_solar_region_path = get_gsp_shape_from_eso() + # Load mapping from GSP ID to Sheffield Solar GSP ID to GSP name: gsp_id_to_region_id = pd.read_csv( gsp_id_to_region_id_filename, @@ -94,4 +117,42 @@ def get_gsp_id_to_shape( # For the national forecast, GSP ID 0, we want the shape to be the # union of all the other shapes gsp_id_to_shape = pd.concat([gsp_id_to_shape, gsp_0]).sort_index() + + # Add central coordinates + gsp_id_to_shape["x_osgb"] = gsp_id_to_shape.geometry.centroid.x.astype(np.float32) + gsp_id_to_shape["y_osgb"] = gsp_id_to_shape.geometry.centroid.y.astype(np.float32) + return gsp_id_to_shape + + +class GSPLocationLookup: + """Query object for GSP location from GSP ID""" + + def __init__( + self, + gsp_id_to_region_id_filename: Optional[str] = None, + sheffield_solar_region_path: Optional[str] = None, + ): + """Query object for GSP location from GSP ID + + Args: + gsp_id_to_region_id_filename: Filename of the mapping file + 
            sheffield_solar_region_path: Path to the region shapes
+
+        """
+        self.gsp_id_to_shape = get_gsp_id_to_shape(
+            gsp_id_to_region_id_filename,
+            sheffield_solar_region_path,
+        )
+
+    def __call__(self, gsp_id: int) -> Location:
+        """Returns the location for the input GSP ID.
+
+        Args:
+            gsp_id: Integer ID of the GSP
+        """
+        return Location(
+            x=self.gsp_id_to_shape.loc[gsp_id].x_osgb.astype(np.float32),
+            y=self.gsp_id_to_shape.loc[gsp_id].y_osgb.astype(np.float32),
+            id=gsp_id,
+        )
diff --git a/ocf_datapipes/select/pick_locations.py b/ocf_datapipes/select/pick_locations.py
index 591da8254..81b913956 100644
--- a/ocf_datapipes/select/pick_locations.py
+++ b/ocf_datapipes/select/pick_locations.py
@@ -17,51 +17,78 @@ class PickLocationsIterDataPipe(IterDataPipe):
     def __init__(
         self,
         source_datapipe: IterDataPipe,
-        return_all_locations: bool = False,
+        return_all: bool = False,
+        shuffle: bool = False,
     ):
         """
-        Picks random locations from a dataset
+        Datapipe to yield locations from the input data source.

         Args:
             source_datapipe: Datapipe emitting Xarray Dataset
-            return_all_locations: Whether to return all locations,
-                if True, also returns them in order
+            return_all: Whether to return all locations,
+                if True, also returns them in structured order
+            shuffle: If `return_all` is True this sets whether the locations are
+                shuffled before being returned.
         """
         super().__init__()
         self.source_datapipe = source_datapipe
-        self.return_all_locations = return_all_locations
+        self.return_all = return_all
+        self.shuffle = shuffle
+
+    def _yield_all_iter(self, xr_dataset):
+        """Samples without replacement from possible locations"""
+        # Get the spatial coords
+        xr_coord_system, xr_x_dim, xr_y_dim = spatial_coord_type(xr_dataset)
+
+        loc_indices = np.arange(len(xr_dataset[xr_x_dim]))
+
+        if self.shuffle:
+            loc_indices = np.random.permutation(loc_indices)
+
+        # Iterate through all locations in dataset
+        for loc_index in loc_indices:
+            # Get the location ID
+            loc_id = None
+            for id_dim_name in ["pv_system_id", "gsp_id", "station_id"]:
+                if id_dim_name in xr_dataset.coords.keys():
+                    loc_id = int(xr_dataset[id_dim_name][loc_index].values)
+
+            location = Location(
+                coordinate_system=xr_coord_system,
+                x=xr_dataset[xr_x_dim][loc_index].values,
+                y=xr_dataset[xr_y_dim][loc_index].values,
+                id=loc_id,
+            )
+
+            yield location
+
+    def _yield_random_iter(self, xr_dataset):
+        """Samples with replacement from possible locations"""
+        # Get the spatial coords
+        xr_coord_system, xr_x_dim, xr_y_dim = spatial_coord_type(xr_dataset)
+
+        while True:
+            loc_index = np.random.randint(0, len(xr_dataset[xr_x_dim]))
+
+            # Get the location ID
+            loc_id = None
+            for id_dim_name in ["pv_system_id", "gsp_id", "station_id"]:
+                if id_dim_name in xr_dataset.coords.keys():
+                    loc_id = int(xr_dataset[id_dim_name][loc_index].values)
+
+            location = Location(
+                coordinate_system=xr_coord_system,
+                x=xr_dataset[xr_x_dim][loc_index].values,
+                y=xr_dataset[xr_y_dim][loc_index].values,
+                id=loc_id,
+            )
+
+            yield location

     def __iter__(self) -> Location:
-        """Returns locations from the inputs datapipe"""
-        for xr_dataset in self.source_datapipe:
-            loc_type, xr_x_dim, xr_y_dim = spatial_coord_type(xr_dataset)
-
-            if self.return_all_locations:
-                logger.debug("Going to return all locations")
-
-                # Iterate through all locations in dataset
-                for location_idx in range(len(xr_dataset[xr_x_dim])):
-                    location = Location(
-                        x=xr_dataset[xr_x_dim][location_idx].values,
-                        y=xr_dataset[xr_y_dim][location_idx].values,
-                        coordinate_system=loc_type,
-                    )
-                    if "pv_system_id" in xr_dataset.coords.keys():
-                        location.id = int(xr_dataset["pv_system_id"][location_idx].values)
-                    logger.debug(f"Got all location {location}")
-                    yield location
-            else:
-                # Pick 1 random location from the input dataset
-                logger.debug("Selecting random idx")
-                location_idx = np.random.randint(0, len(xr_dataset[xr_x_dim]))
-                logger.debug(f"{location_idx=}")
-                location = Location(
-                    x=xr_dataset[xr_x_dim][location_idx].values,
-                    y=xr_dataset[xr_y_dim][location_idx].values,
-                    coordinate_system=loc_type,
-                )
-                if "pv_system_id" in xr_dataset.coords.keys():
-                    location.id = int(xr_dataset["pv_system_id"][location_idx].values)
-                logger.debug(f"Have selected location.id {location.id}")
-                logger.debug(f"{location=}")
-                yield location
+        """Returns locations from the input datapipe"""
+        xr_dataset = next(iter(self.source_datapipe))
+
+        if self.return_all:
+            return self._yield_all_iter(xr_dataset)
+        else:
+            return self._yield_random_iter(xr_dataset)
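# ── Reviewer sketch (illustrative only, not part of the patch) ────────────────
# Expected behaviour of the two sampling modes introduced above. `fake_pv_ds`
# is hypothetical: any Xarray object whose spatial coords are recognised by
# spatial_coord_type (e.g. with pv_system_id / x_osgb / y_osgb coords) would do.
from itertools import islice

from torch.utils.data.datapipes.iter import IterableWrapper

from ocf_datapipes.select.pick_locations import PickLocationsIterDataPipe

source = IterableWrapper([fake_pv_ds])  # hypothetical dataset
# Every location exactly once, in shuffled order:
all_locations = list(PickLocationsIterDataPipe(source, return_all=True, shuffle=True))

source = IterableWrapper([fake_pv_ds])
# Infinite sampling with replacement -- take the first five:
five_locations = list(islice(PickLocationsIterDataPipe(source, return_all=False), 5))
# ── end sketch ────────────────────────────────────────────────────────────────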
"pv_system_id" in xr_dataset.coords.keys(): - location.id = int(xr_dataset["pv_system_id"][location_idx].values) - logger.debug(f"Got all location {location}") - yield location - else: - # Pick 1 random location from the input dataset - logger.debug("Selecting random idx") - location_idx = np.random.randint(0, len(xr_dataset[xr_x_dim])) - logger.debug(f"{location_idx=}") - location = Location( - x=xr_dataset[xr_x_dim][location_idx].values, - y=xr_dataset[xr_y_dim][location_idx].values, - coordinate_system=loc_type, - ) - if "pv_system_id" in xr_dataset.coords.keys(): - location.id = int(xr_dataset["pv_system_id"][location_idx].values) - logger.debug(f"Have selected location.id {location.id}") - logger.debug(f"{location=}") - yield location + xr_dataset = next(iter(self.source_datapipe)) + + if self.return_all: + return self._yield_all_iter(xr_dataset) + else: + return self._yield_random_iter(xr_dataset) diff --git a/ocf_datapipes/select/pick_locations_and_t0_times.py b/ocf_datapipes/select/pick_locations_and_t0_times.py index 06cec1677..5e72af414 100644 --- a/ocf_datapipes/select/pick_locations_and_t0_times.py +++ b/ocf_datapipes/select/pick_locations_and_t0_times.py @@ -32,7 +32,7 @@ def __init__( source_datapipe: Datapipe emitting Xarray Dataset return_all: Whether to return all t0-location pairs, if True, also returns them in structured order - shuffle: If `return_all` sets whether the pairs are + shuffle: If `return_all` is True this sets whether the pairs are shuffled before being returned. time_dim_name: time dimension name, defaulted to 'time_utc' """ @@ -43,7 +43,9 @@ def __init__( self.time_dim_name = time_dim_name def _yield_all_iter(self, xr_dataset): + # Get the spatial coords xr_coord_system, xr_x_dim, xr_y_dim = spatial_coord_type(xr_dataset) + t_index, x_index = np.meshgrid( np.arange(len(xr_dataset[self.time_dim_name])), np.arange(len(xr_dataset[xr_x_dim])), @@ -56,47 +58,41 @@ def _yield_all_iter(self, xr_dataset): # Iterate through all locations in dataset for t_index, loc_index in index_pairs: + # Get the location ID + loc_id = None + for id_dim_name in ["pv_system_id", "gsp_id", "station_id"]: + if id_dim_name in xr_dataset.coords.keys(): + loc_id = int(xr_dataset[id_dim_name][loc_index].values) + t0 = xr_dataset[self.time_dim_name][t_index].values location = Location( coordinate_system=xr_coord_system, x=xr_dataset[xr_x_dim][loc_index].values, y=xr_dataset[xr_y_dim][loc_index].values, + id=loc_id, ) - # for pv - if "pv_system_id" in xr_dataset.coords.keys(): - location.id = int(xr_dataset["pv_system_id"][loc_index].values) - - # for gsp - if "gsp_id" in xr_dataset.coords.keys(): - location.id = int(xr_dataset["gsp_id"][loc_index].values) - - # for sensor - if "station_id" in xr_dataset.coords.keys(): - location.id = int(xr_dataset["station_id"][loc_index].values) - yield location, t0 def _yield_random_iter(self, xr_dataset): + # Get the spatial coords xr_coord_system, xr_x_dim, xr_y_dim = spatial_coord_type(xr_dataset) + while True: - location_idx = np.random.randint(0, len(xr_dataset[xr_x_dim])) + loc_index = np.random.randint(0, len(xr_dataset[xr_x_dim])) + + # Get the location ID + loc_id = None + for id_dim_name in ["pv_system_id", "gsp_id", "station_id"]: + if id_dim_name in xr_dataset.coords.keys(): + loc_id = int(xr_dataset[id_dim_name][loc_index].values) location = Location( coordinate_system=xr_coord_system, - x=xr_dataset[xr_x_dim][location_idx].values, - y=xr_dataset[xr_y_dim][location_idx].values, + x=xr_dataset[xr_x_dim][loc_index].values, + 
+                y=xr_dataset[xr_y_dim][loc_index].values,
+                id=loc_id,
             )
-            if "pv_system_id" in xr_dataset.coords.keys():
-                location.id = int(xr_dataset["pv_system_id"][location_idx].values)
-
-            # for gsp
-            if "gsp_id" in xr_dataset.coords.keys():
-                location.id = int(xr_dataset["gsp_id"][location_idx].values)
-
-            # for sensor
-            if "station_id" in xr_dataset.coords.keys():
-                location.id = int(xr_dataset["station_id"][location_idx].values)

             t0 = np.random.choice(xr_dataset[self.time_dim_name].values)
diff --git a/ocf_datapipes/select/pick_t0_times.py b/ocf_datapipes/select/pick_t0_times.py
index 88afdccf9..1f677cd8d 100644
--- a/ocf_datapipes/select/pick_t0_times.py
+++ b/ocf_datapipes/select/pick_t0_times.py
@@ -10,31 +10,52 @@
 @functional_datapipe("pick_t0_times")
 class PickT0TimesIterDataPipe(IterDataPipe):
-    """Picks random t0 times from a dataset"""
+    """Picks (random) t0 times from a dataset"""

     def __init__(
         self,
         source_datapipe: IterDataPipe,
+        return_all: bool = False,
+        shuffle: bool = False,
         dim_name: str = "time_utc",
     ):
         """
-        Picks random t0 times from a dataset
+        Datapipe to yield t0 times from the input data source.

         Args:
-            source_datapipe: Datapipe emitting Xarray objects
-            dim_name: The time dimension name to use
+            source_datapipe: Datapipe emitting Xarray objects.
+            return_all: Whether to return all t0 values, else sample with replacement. If True,
+                the default behaviour is to return t0 values in order - see `shuffle` parameter.
+            shuffle: If `return_all` is True this sets whether the t0 values are
+                shuffled before being returned.
+            dim_name: The time dimension name to use.
         """
         self.source_datapipe = source_datapipe
+        self.return_all = return_all
+        self.shuffle = shuffle
         self.dim_name = dim_name

+    def _yield_random_iter(self, xr_dataset):
+        """Sample t0 with replacement"""
+        while True:
+            t0 = np.random.choice(xr_dataset[self.dim_name].values)
+            yield t0
+
+    def _yield_all_iter(self, xr_dataset):
+        """Yield all the t0s in order, and maybe with a shuffle"""
+        all_t0s = np.copy(xr_dataset[self.dim_name].values)
+        if self.shuffle:
+            all_t0s = np.random.permutation(all_t0s)
+        for t0 in all_t0s:
+            yield t0
+
     def __iter__(self) -> pd.Timestamp:
-        """Get the latest timestamp and return it"""
-        for xr_data in self.source_datapipe:
-            logger.debug(f"Selecting t0 from {len(xr_data[self.dim_name])} datetimes")
+        """Yield t0 times from the input data source"""
+        xr_dataset = next(iter(self.source_datapipe))

-            if len(xr_data[self.dim_name].values) == 0:
-                raise Exception("There are no values to get t0 from")
-            t0 = np.random.choice(xr_data[self.dim_name].values)
-            logger.debug(f"t0 will be {t0}")
+        if len(xr_dataset[self.dim_name].values) == 0:
+            raise Exception("There are no values to get t0 from")

-            yield t0
+        if self.return_all:
+            return self._yield_all_iter(xr_dataset)
+        else:
+            return self._yield_random_iter(xr_dataset)
diff --git a/ocf_datapipes/select/select_spatial_slice.py b/ocf_datapipes/select/select_spatial_slice.py
index 36e4c8787..11a5dad95 100644
--- a/ocf_datapipes/select/select_spatial_slice.py
+++ b/ocf_datapipes/select/select_spatial_slice.py
@@ -21,7 +21,249 @@
 logger = logging.getLogger(__name__)

-def select_spatial_slice_pixels(
+# -------------------------------- utility functions --------------------------------
+
+
+def convert_coords_to_match_xarray(x, y, from_coords, xr_data):
+    """Convert x and y coords to coordinate system matching xarray data
+
+    Args:
+        x: Float or array-like
+        y: Float or array-like
+        from_coords: String describing coordinate system of x and y
+        xr_data: xarray data object to which coordinates should be
matched + """ + + xr_coords, xr_x_dim, xr_y_dim = spatial_coord_type(xr_data) + + assert from_coords in ["osgb", "lon_lat"] + assert xr_coords in ["geostationary", "osgb", "lon_lat"] + + if xr_coords == "geostationary": + if from_coords == "osgb": + x, y = osgb_to_geostationary_area_coords(x, y, xr_data) + + elif from_coords == "lon_lat": + x, y = lon_lat_to_geostationary_area_coords(x, y, xr_data) + + elif xr_coords == "lon_lat": + if from_coords == "osgb": + x, y = osgb_to_lon_lat(x, y) + + # else the from_coords=="lon_lat" and we don't need to convert + + elif xr_coords == "osgb": + if from_coords == "lon_lat": + x, y = lon_lat_to_osgb(x, y) + + # else the from_coords=="osgb" and we don't need to convert + + return x, y + + +def _get_idx_of_pixel_closest_to_poi( + xr_data: xr.DataArray, + location: Location, +) -> Location: + """ + Return x and y index location of pixel at center of region of interest. + + Args: + xr_data: Xarray dataset + location: Center + Returns: + The Location for the center pixel + """ + xr_coords, xr_x_dim, xr_y_dim = spatial_coord_type(xr_data) + + if xr_coords not in ["osgb", "lon_lat"]: + raise NotImplementedError(f"Only 'osgb' and 'lon_lat' are supported - not '{xr_coords}'") + + # Convert location coords to match xarray data + x, y = convert_coords_to_match_xarray( + location.x, + location.y, + from_coords=location.coordinate_system, + xr_data=xr_data, + ) + + # Check that the requested point lies within the data + assert xr_data[xr_x_dim].min() < x < xr_data[xr_x_dim].max() + assert xr_data[xr_y_dim].min() < y < xr_data[xr_y_dim].max() + + x_index = xr_data.get_index(xr_x_dim) + y_index = xr_data.get_index(xr_y_dim) + + closest_x = x_index.get_indexer([x], method="nearest")[0] + closest_y = y_index.get_indexer([y], method="nearest")[0] + + return Location(x=closest_x, y=closest_y, coordinate_system="idx") + + +def _get_idx_of_pixel_closest_to_poi_geostationary( + xr_data: xr.DataArray, + center_osgb: Location, +) -> Location: + """ + Return x and y index location of pixel at center of region of interest. + + Args: + xr_data: Xarray dataset + center_osgb: Center in OSGB coordinates + + Returns: + Location for the center pixel in geostationary coordinates + """ + + xr_coords, xr_x_dim, xr_y_dim = spatial_coord_type(xr_data) + + x, y = osgb_to_geostationary_area_coords(x=center_osgb.x, y=center_osgb.y, xr_data=xr_data) + center_geostationary = Location(x=x, y=y, coordinate_system="geostationary") + + # Check that the requested point lies within the data + assert xr_data[xr_x_dim].min() < x < xr_data[xr_x_dim].max() + assert xr_data[xr_y_dim].min() < y < xr_data[xr_y_dim].max() + + # Get the index into x and y nearest to x_center_geostationary and y_center_geostationary: + x_index_at_center = searchsorted( + xr_data[xr_x_dim].values, center_geostationary.x, assume_ascending=True + ) + + # y_geostationary is in descending order: + y_index_at_center = searchsorted( + xr_data[xr_y_dim].values, center_geostationary.y, assume_ascending=False + ) + + return Location(x=x_index_at_center, y=y_index_at_center, coordinate_system="idx") + + +def _get_points_from_unstructured_grids( + xr_data: xr.DataArray, + location: Location, + location_idx_name: str = "values", + num_points: int = 1, +): + """ + Get the closest points from an unstructured grid (i.e. 
Icosahedral grid)
+
+    This is primarily used for the Icosahedral grid, which is not a regular grid,
+    and so is not an image
+
+    Args:
+        xr_data: Xarray dataset
+        location: Location of center point
+        location_idx_name: Name of the index values dimension
+            (i.e. where we index into to get the lat/lon for that point)
+        num_points: Number of points to return (should be width * height)
+
+    Returns:
+        The closest points from the grid
+    """
+    xr_coords, xr_x_dim, xr_y_dim = spatial_coord_type(xr_data)
+    assert xr_coords == "lon_lat"
+
+    # Check if need to convert from different coordinate system to lat/lon
+    if location.coordinate_system == "osgb":
+        longitude, latitude = osgb_to_lon_lat(x=location.x, y=location.y)
+        location = Location(
+            x=longitude,
+            y=latitude,
+            coordinate_system="lon_lat",
+        )
+    elif location.coordinate_system == "geostationary":
+        raise NotImplementedError(
+            "Does not currently support geostationary coordinates when using unstructured grids"
+        )
+
+    # Extract lon, lat, and locidx data
+    lons = xr_data.longitude.values
+    lats = xr_data.latitude.values
+    locidx = xr_data[location_idx_name].values
+
+    # Create a KDTree
+    tree = KDTree(list(zip(lons, lats)))
+
+    # Query with the [longitude, latitude] of your point
+    _, idx = tree.query([location.x, location.y], k=num_points)
+
+    # Retrieve the location_idxs for these grid points
+    location_idxs = locidx[idx]
+
+    data = xr_data.sel({location_idx_name: location_idxs})
+    return data
+
+
+# ---------------------------- sub-functions for slicing ----------------------------
+
+
+def _slice_partial_spatial_pixel_window_from_xarray(
+    xr_data,
+    left_idx,
+    right_idx,
+    top_idx,
+    bottom_idx,
+    left_pad_pixels,
+    right_pad_pixels,
+    top_pad_pixels,
+    bottom_pad_pixels,
+    xr_x_dim,
+    xr_y_dim,
+):
+    """Return spatial window of given pixel size when window partially overlaps input data"""
+
+    dx = np.median(np.diff(xr_data[xr_x_dim].values))
+    dy = np.median(np.diff(xr_data[xr_y_dim].values))
+
+    if left_pad_pixels > 0:
+        assert right_pad_pixels == 0
+        x_sel = np.concatenate(
+            [
+                xr_data[xr_x_dim].values[0] - np.arange(left_pad_pixels, 0, -1) * dx,
+                xr_data[xr_x_dim].values[0:right_idx],
+            ]
+        )
+        xr_data = xr_data.isel({xr_x_dim: slice(0, right_idx)}).reindex({xr_x_dim: x_sel})
+
+    elif right_pad_pixels > 0:
+        assert left_pad_pixels == 0
+        x_sel = np.concatenate(
+            [
+                xr_data[xr_x_dim].values[left_idx:],
+                xr_data[xr_x_dim].values[-1] + np.arange(1, right_pad_pixels + 1) * dx,
+            ]
+        )
+        xr_data = xr_data.isel({xr_x_dim: slice(left_idx, None)}).reindex({xr_x_dim: x_sel})
+
+    else:
+        xr_data = xr_data.isel({xr_x_dim: slice(left_idx, right_idx)})
+
+    if top_pad_pixels > 0:
+        assert bottom_pad_pixels == 0
+        y_sel = np.concatenate(
+            [
+                xr_data[xr_y_dim].values[0] - np.arange(top_pad_pixels, 0, -1) * dy,
+                xr_data[xr_y_dim].values[0:bottom_idx],
+            ]
+        )
+        xr_data = xr_data.isel({xr_y_dim: slice(0, bottom_idx)}).reindex({xr_y_dim: y_sel})
+
+    elif bottom_pad_pixels > 0:
+        assert top_pad_pixels == 0
+        y_sel = np.concatenate(
+            [
+                xr_data[xr_y_dim].values[top_idx:],
+                xr_data[xr_y_dim].values[-1] + np.arange(1, bottom_pad_pixels + 1) * dy,
+            ]
+        )
+        xr_data = xr_data.isel({xr_y_dim: slice(top_idx, None)}).reindex({xr_y_dim: y_sel})
+
+    else:
+        xr_data = xr_data.isel({xr_y_dim: slice(top_idx, bottom_idx)})
+
+    return xr_data
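# ── Reviewer sketch (illustrative only, not part of the patch) ────────────────
# Tiny demonstration of the reindex-based padding used above: coordinates past
# the data edge are extrapolated from the median grid spacing, and reindexing
# fills the new pixels with NaN. All names here are local to the example.
import numpy as np
import xarray as xr

da = xr.DataArray(np.arange(5.0), coords={"x_osgb": np.arange(5.0) * 2.0}, dims="x_osgb")
dx = np.median(np.diff(da["x_osgb"].values))
left_pad_pixels = 2
x_sel = np.concatenate(
    [
        da["x_osgb"].values[0] - np.arange(left_pad_pixels, 0, -1) * dx,
        da["x_osgb"].values[0:3],
    ]
)
padded = da.isel(x_osgb=slice(0, 3)).reindex(x_osgb=x_sel)
print(padded.values)  # [nan nan  0.  1.  2.] -- two extrapolated pixels on the left
# ── end sketch ────────────────────────────────────────────────────────────────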
+def slice_spatial_pixel_window_from_xarray(
     xr_data, center_idx, width_pixels, height_pixels, xr_x_dim, xr_y_dim, allow_partial_slice
 ):
     """Select a spatial slice from an xarray object
@@ -64,7 +306,7 @@ def select_spatial_slice_pixels(
             (bottom_idx - (data_height_pixels - 1)) if bottom_pad_required else 0
         )

-        xr_data = select_partial_spatial_slice_pixels(
+        xr_data = _slice_partial_spatial_pixel_window_from_xarray(
             xr_data,
             left_idx,
             right_idx,
             top_idx,
             bottom_idx,
             left_pad_pixels,
             right_pad_pixels,
             top_pad_pixels,
             bottom_pad_pixels,
             xr_x_dim,
             xr_y_dim,
@@ -105,71 +347,164 @@ def select_spatial_slice_pixels(

     return xr_data


-def select_partial_spatial_slice_pixels(
-    xr_data,
-    left_idx,
-    right_idx,
-    top_idx,
-    bottom_idx,
-    left_pad_pixels,
-    right_pad_pixels,
-    top_pad_pixels,
-    bottom_pad_pixels,
-    xr_x_dim,
-    xr_y_dim,
-):
-    """Return spatial window of given pixel size when window partially overlaps input data"""
-
-    dx = np.median(np.diff(xr_data[xr_x_dim].values))
-    dy = np.median(np.diff(xr_data[xr_y_dim].values))
-
-    if left_pad_pixels > 0:
-        assert right_pad_pixels == 0
-        x_sel = np.concatenate(
-            [
-                xr_data[xr_x_dim].values[0] - np.arange(left_pad_pixels, 0, -1) * dx,
-                xr_data[xr_x_dim].values[0:right_idx],
-            ]
-        )
-        xr_data = xr_data.isel({xr_x_dim: slice(0, right_idx)}).reindex({xr_x_dim: x_sel})
-
-    elif right_pad_pixels > 0:
-        assert left_pad_pixels == 0
-        x_sel = np.concatenate(
-            [
-                xr_data[xr_x_dim].values[left_idx:],
-                xr_data[xr_x_dim].values[-1] + np.arange(1, right_pad_pixels + 1) * dx,
-            ]
-        )
-        xr_data = xr_data.isel({xr_x_dim: slice(left_idx, None)}).reindex({xr_x_dim: x_sel})
-
-    else:
-        xr_data = xr_data.isel({xr_x_dim: slice(left_idx, right_idx)})
-
-    if top_pad_pixels > 0:
-        assert bottom_pad_pixels == 0
-        y_sel = np.concatenate(
-            [
-                xr_data[xr_y_dim].values[0] - np.arange(top_pad_pixels, 0, -1) * dy,
-                xr_data[xr_y_dim].values[0:bottom_idx],
-            ]
-        )
-        xr_data = xr_data.isel({xr_y_dim: slice(0, bottom_idx)}).reindex({xr_y_dim: y_sel})
-
-    elif bottom_pad_pixels > 0:
-        assert top_pad_pixels == 0
-        y_sel = np.concatenate(
-            [
-                xr_data[xr_y_dim].values[top_idx:],
-                xr_data[xr_y_dim].values[-1] + np.arange(1, bottom_pad_pixels + 1) * dy,
-            ]
+# ---------------------------- main functions for slicing ---------------------------
+
+
+def select_spatial_slice_pixels(
+    xr_data: Union[xr.Dataset, xr.DataArray],
+    location: Location,
+    roi_width_pixels: int,
+    roi_height_pixels: int,
+    allow_partial_slice: bool = False,
+    location_idx_name: Optional[str] = None,
+):
+    """
+    Select spatial slice based off pixels from location point of interest
+
+    If `allow_partial_slice` is set to True, then slices may be made which intersect the border
+    of the input data. The additional x and y coordinates that would be required for this slice
+    are extrapolated based on the average spacing of these coordinates in the input data.
+    However, currently slices cannot be made where the centre of the window is outside of the
+    input data.
+
+    Args:
+        xr_data: Xarray DataArray or Dataset to slice from
+        location: Location of interest
+        roi_width_pixels: ROI width in pixels
+        roi_height_pixels: ROI height in pixels
+        allow_partial_slice: Whether to allow a partial slice.
+        location_idx_name: Name for location index of unstructured grid data,
+            None if not relevant
+    """
+    xr_coords, xr_x_dim, xr_y_dim = spatial_coord_type(xr_data)
+    if location_idx_name is not None:
+        selected = _get_points_from_unstructured_grids(
+            xr_data=xr_data,
+            location=location,
+            location_idx_name=location_idx_name,
+            num_points=roi_width_pixels * roi_height_pixels,
+        )
+    else:
+        if xr_coords == "geostationary":
+            center_idx: Location = _get_idx_of_pixel_closest_to_poi_geostationary(
+                xr_data=xr_data,
+                center_osgb=location,
+            )
+        else:
+            center_idx: Location = _get_idx_of_pixel_closest_to_poi(
+                xr_data=xr_data,
+                location=location,
+            )
+
+        selected = slice_spatial_pixel_window_from_xarray(
+            xr_data,
+            center_idx,
+            roi_width_pixels,
+            roi_height_pixels,
+            xr_x_dim,
+            xr_y_dim,
+            allow_partial_slice=allow_partial_slice,
+        )
+
+    return selected
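# ── Reviewer sketch (illustrative only, not part of the patch) ────────────────
# The window selection is now callable as a plain function. A hypothetical call
# on an OSGB-gridded DataArray (`nwp_da` is made up), padding with NaNs where
# the 64x64 window hangs over the edge of the data:
from ocf_datapipes.select.select_spatial_slice import select_spatial_slice_pixels
from ocf_datapipes.utils.location import Location

patch = select_spatial_slice_pixels(
    xr_data=nwp_da,  # hypothetical DataArray with x_osgb/y_osgb coords
    location=Location(x=123_000.0, y=456_000.0, coordinate_system="osgb"),
    roi_width_pixels=64,
    roi_height_pixels=64,
    allow_partial_slice=True,
)
# ── end sketch ────────────────────────────────────────────────────────────────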
+
+
+def select_spatial_slice_meters(
+    xr_data: Union[xr.Dataset, xr.DataArray],
+    location: Location,
+    roi_width_meters: int,
+    roi_height_meters: int,
+    dim_name: Optional[str] = None,
+):
+    """
+    Select spatial slice based off distance in meters from point of interest
+
+    Args:
+        xr_data: Xarray DataArray or Dataset to slice from
+        location: Location of interest
+        roi_width_meters: ROI width in meters
+        roi_height_meters: ROI height in meters
+        dim_name: Dimension name to select for ID, None for coordinates
+
+    Notes:
+        Using spatial slicing based on distance rather than number of pixels will often yield
+        slices which can vary by 1 pixel in height and/or width.
+
+        E.g. Suppose the Xarray data has x-coords = [1,2,3,4,5]. We want to slice a spatial
+        window with a size which equates to 2.2 along the x-axis. If we choose to slice around
+        the point x=3 this will slice out the x-coords [2,3,4]. If we choose to slice around the
+        point x=2.5 this will slice out the x-coords [2,3]. Hence the returned slice can have
+        size either 2 or 3 in the x-axis depending on the spatial location selected.
+
+        Also, if selecting over a large span of latitudes, this may also cause pixel sizes of
+        the yielded outputs to change. For example, if the Xarray data is on a regularly spaced
+        longitude-latitude grid, then the structure of the grid means that the longitudes near
+        to the poles are spaced closer together (measured in meters) than at the equator. So
+        slices near the equator will have fewer pixels in the x-axis than slices taken near the
+        poles.
+ """ + # Get the spatial coords of the xarray data + xr_coords, xr_x_dim, xr_y_dim = spatial_coord_type(xr_data) + + half_width = roi_width_meters // 2 + half_height = roi_height_meters // 2 + + # Find the bounding box values for the location in either lat-lon or OSGB coord systems + if location.coordinate_system == "lon_lat": + right, top = move_lon_lat_by_meters( + location.x, + location.y, + half_width, + half_height, ) - xr_data = xr_data.isel({xr_y_dim: slice(top_idx, None)}).reindex({xr_x_dim: y_sel}) + left, bottom = move_lon_lat_by_meters( + location.x, + location.y, + -half_width, + -half_height, + ) + + elif location.coordinate_system == "osgb": + left = location.x - half_width + right = location.x + half_width + bottom = location.y - half_height + top = location.y + half_height else: - xr_data = xr_data.isel({xr_y_dim: slice(top_idx, bottom_idx)}) + raise ValueError(f"Location coord system not recognized: {location.coordinate_system}") + + # Change the bounding coordinates [left, right, bottom, top] to the same + # coordinate system as the xarray data + (left, right), (bottom, top) = convert_coords_to_match_xarray( + x=np.array([left, right], dtype=np.float32), + y=np.array([bottom, top], dtype=np.float32), + from_coords=location.coordinate_system, + xr_data=xr_data, + ) + + # Do it off coordinates, not ID + if dim_name is None: + # Select a patch from the xarray data + x_mask = (left <= xr_data[xr_x_dim]) & (xr_data[xr_x_dim] <= right) + y_mask = (bottom <= xr_data[xr_y_dim]) & (xr_data[xr_y_dim] <= top) + selected = xr_data.isel({xr_x_dim: x_mask, xr_y_dim: y_mask}) + + else: + # Select data in the region of interest and ID: + # This also works for unstructured grids + + id_mask = ( + (left <= xr_data[xr_x_dim]) + & (xr_data[xr_x_dim] <= right) + & (bottom <= xr_data[xr_y_dim]) + & (xr_data[xr_y_dim] <= top) + ) + selected = xr_data.isel({dim_name: id_mask}) + return selected - return xr_data + +# ------------------------------ datapipes for slicing ------------------------------ @functional_datapipe("select_spatial_slice_pixels") @@ -180,8 +515,8 @@ def __init__( self, source_datapipe: IterDataPipe, location_datapipe: IterDataPipe, - roi_height_pixels: int, roi_width_pixels: int, + roi_height_pixels: int, allow_partial_slice: bool = False, location_idx_name: Optional[str] = None, ): @@ -213,35 +548,14 @@ def __init__( def __iter__(self) -> Union[xr.DataArray, xr.Dataset]: for xr_data, location in self.source_datapipe.zip_ocf(self.location_datapipe): logger.debug("Selecting spatial slice with pixels") - xr_coords, xr_x_dim, xr_y_dim = spatial_coord_type(xr_data) - if self.location_idx_name is not None: - selected = _get_points_from_unstructured_grids( - xr_data=xr_data, - location=location, - location_idx_name=self.location_idx_name, - num_points=self.roi_width_pixels * self.roi_height_pixels, - ) - yield selected - - if xr_coords == "geostationary": - center_idx: Location = _get_idx_of_pixel_closest_to_poi_geostationary( - xr_data=xr_data, - center_osgb=location, - ) - else: - center_idx: Location = _get_idx_of_pixel_closest_to_poi( - xr_data=xr_data, - location=location, - ) selected = select_spatial_slice_pixels( - xr_data, - center_idx, - self.roi_width_pixels, - self.roi_height_pixels, - xr_x_dim, - xr_y_dim, + xr_data=xr_data, + location=location, + roi_width_pixels=self.roi_width_pixels, + roi_height_pixels=self.roi_height_pixels, allow_partial_slice=self.allow_partial_slice, + location_idx_name=self.location_idx_name, ) yield selected @@ -255,9 +569,9 @@ def __init__( 
self, source_datapipe: IterDataPipe, location_datapipe: IterDataPipe, - roi_height_meters: int, roi_width_meters: int, - dim_name: Optional[str] = None, # "pv_system_id", + roi_height_meters: int, + dim_name: Optional[str] = None, ): """ Select spatial slice based off pixels from point of interest @@ -265,8 +579,8 @@ def __init__( Args: source_datapipe: Datapipe of Xarray data location_datapipe: Location datapipe - roi_height_meters: ROI height in meters roi_width_meters: ROI width in meters + roi_height_meters: ROI height in meters dim_name: Dimension name to select for ID, None for coordinates Notes: @@ -297,233 +611,12 @@ def __iter__(self) -> Union[xr.DataArray, xr.Dataset]: # Compute the index for left and right: logger.debug("Getting Spatial Slice Meters") - # Get the spatial coords of the xarray data - xr_coords, xr_x_dim, xr_y_dim = spatial_coord_type(xr_data) - - half_height = self.roi_height_meters // 2 - half_width = self.roi_width_meters // 2 - - # Find the bounding box values for the location in either lat-lon or OSGB coord systems - if location.coordinate_system == "lon_lat": - right, top = move_lon_lat_by_meters( - location.x, - location.y, - half_width, - half_height, - ) - left, bottom = move_lon_lat_by_meters( - location.x, - location.y, - -half_width, - -half_height, - ) - - elif location.coordinate_system == "osgb": - left = location.x - half_width - right = location.x + half_width - bottom = location.y - half_height - top = location.y + half_height - - else: - raise ValueError( - f"Location coord system not recognized: {location.coordinate_system}" - ) - - # Change the bounding coordinates [left, right, bottom, top] to the same - # coordinate system as the xarray data - (left, right), (bottom, top) = convert_coords_to_match_xarray( - x=np.array([left, right], dtype=np.float32), - y=np.array([bottom, top], dtype=np.float32), - from_coords=location.coordinate_system, + selected = select_spatial_slice_meters( xr_data=xr_data, + location=location, + roi_width_meters=self.roi_width_meters, + roi_height_meters=self.roi_height_meters, + dim_name=self.dim_name, ) - # Do it off coordinates, not ID - if self.dim_name is None: - # Select a patch from the xarray data - x_mask = (left <= xr_data[xr_x_dim]) & (xr_data[xr_x_dim] <= right) - y_mask = (bottom <= xr_data[xr_y_dim]) & (xr_data[xr_y_dim] <= top) - selected = xr_data.isel({xr_x_dim: x_mask, xr_y_dim: y_mask}) - - else: - # Select data in the region of interest and ID: - # This also works for unstructured grids - - id_mask = ( - (left <= xr_data[xr_x_dim]) - & (xr_data[xr_x_dim] <= right) - & (bottom <= xr_data[xr_y_dim]) - & (xr_data[xr_y_dim] <= top) - ) - selected = xr_data.isel({self.dim_name: id_mask}) - yield selected - - -def convert_coords_to_match_xarray(x, y, from_coords, xr_data): - """Convert x and y coords to cooridnate system matching xarray data - - Args: - x: Float or array-like - y: Float or array-like - from_coords: String describing coordinate system of x and y - xr_data: xarray data object to which coordinates should be matched - """ - - xr_coords, xr_x_dim, xr_y_dim = spatial_coord_type(xr_data) - - assert from_coords in ["osgb", "lon_lat"] - assert xr_coords in ["geostationary", "osgb", "lon_lat"] - - if xr_coords == "geostationary": - if from_coords == "osgb": - x, y = osgb_to_geostationary_area_coords(x, y, xr_data) - - elif from_coords == "lon_lat": - x, y = lon_lat_to_geostationary_area_coords(x, y, xr_data) - - elif xr_coords == "lon_lat": - if from_coords == "osgb": - x, y = osgb_to_lon_lat(x, y) 
- - # else the from_coords=="lon_lat" and we don't need to convert - - elif xr_coords == "osgb": - if from_coords == "lon_lat": - x, y = lon_lat_to_osgb(x, y) - - # else the from_coords=="osgb" and we don't need to convert - - return x, y - - -def _get_idx_of_pixel_closest_to_poi( - xr_data: xr.DataArray, - location: Location, -) -> Location: - """ - Return x and y index location of pixel at center of region of interest. - - Args: - xr_data: Xarray dataset - location: Center - Returns: - The Location for the center pixel - """ - xr_coords, xr_x_dim, xr_y_dim = spatial_coord_type(xr_data) - - if xr_coords not in ["osgb", "lon_lat"]: - raise NotImplementedError(f"Only 'osgb' and 'lon_lat' are supported - not '{xr_coords}'") - - # Convert location coords to match xarray data - x, y = convert_coords_to_match_xarray( - location.x, - location.y, - from_coords=location.coordinate_system, - xr_data=xr_data, - ) - - # Check that the requested point lies within the data - assert xr_data[xr_x_dim].min() < x < xr_data[xr_x_dim].max() - assert xr_data[xr_y_dim].min() < y < xr_data[xr_y_dim].max() - - x_index = xr_data.get_index(xr_x_dim) - y_index = xr_data.get_index(xr_y_dim) - - closest_x = x_index.get_indexer([x], method="nearest")[0] - closest_y = y_index.get_indexer([y], method="nearest")[0] - - return Location(x=closest_x, y=closest_y, coordinate_system="idx") - - -def _get_idx_of_pixel_closest_to_poi_geostationary( - xr_data: xr.DataArray, - center_osgb: Location, -) -> Location: - """ - Return x and y index location of pixel at center of region of interest. - - Args: - xr_data: Xarray dataset - center_osgb: Center in OSGB coordinates - - Returns: - Location for the center pixel in geostationary coordinates - """ - - xr_coords, xr_x_dim, xr_y_dim = spatial_coord_type(xr_data) - - x, y = osgb_to_geostationary_area_coords(x=center_osgb.x, y=center_osgb.y, xr_data=xr_data) - center_geostationary = Location(x=x, y=y, coordinate_system="geostationary") - - # Check that the requested point lies within the data - assert xr_data[xr_x_dim].min() < x < xr_data[xr_x_dim].max() - assert xr_data[xr_y_dim].min() < y < xr_data[xr_y_dim].max() - - # Get the index into x and y nearest to x_center_geostationary and y_center_geostationary: - x_index_at_center = searchsorted( - xr_data[xr_x_dim].values, center_geostationary.x, assume_ascending=True - ) - - # y_geostationary is in descending order: - y_index_at_center = searchsorted( - xr_data[xr_y_dim].values, center_geostationary.y, assume_ascending=False - ) - - return Location(x=x_index_at_center, y=y_index_at_center, coordinate_system="idx") - - -def _get_points_from_unstructured_grids( - xr_data: xr.DataArray, - location: Location, - location_idx_name: str = "values", - num_points: int = 1, -): - """ - Get the closest points from an unstructured grid (i.e. Icosahedral grid) - - This is primarily used for the Icosahedral grid, which is not a regular grid, - and so is not an image - - Args: - xr_data: Xarray dataset - location: Location of center point - location_idx_name: Name of the index values dimension - (i.e. 
where we index into to get the lat/lon for that point) - num_points: Number of points to return (should be width * height) - - Returns: - The closest points from the grid - """ - xr_coords, xr_x_dim, xr_y_dim = spatial_coord_type(xr_data) - assert xr_coords == "lon_lat" - - # Check if need to convert from different coordinate system to lat/lon - if location.coordinate_system == "osgb": - longitude, latitude = osgb_to_lon_lat(x=location.x, y=location.y) - location = Location( - x=longitude, - y=latitude, - coordinate_system="lon_lat", - ) - elif location.coordinate_system == "geostationary": - raise NotImplementedError( - "Does not currently support geostationary coordinates when using unstructured grids" - ) - - # Extract lat, lon, and locidx data - lat = xr_data.longitude.values - lon = xr_data.latitude.values - locidx = xr_data[location_idx_name].values - - # Create a KDTree - tree = KDTree(list(zip(lat, lon))) - - # Query with the [longitude, latitude] of your point - _, idx = tree.query([location.x, location.y], k=num_points) - - # Retrieve the location_idxs for these grid points - location_idxs = locidx[idx] - - data = xr_data.sel({location_idx_name: location_idxs}) - return data diff --git a/ocf_datapipes/training/common.py b/ocf_datapipes/training/common.py index 683b9c8e3..01007c116 100644 --- a/ocf_datapipes/training/common.py +++ b/ocf_datapipes/training/common.py @@ -937,20 +937,20 @@ def check_nans_in_satellite_data(batch: NumpyBatch) -> NumpyBatch: """ if np.any(np.isnan(batch[BatchKey.satellite_actual])): logger.error("Found nans values in satellite data") - logger.error(batch[BatchKey.satellite_actual].shape) # loop over time and channels for dim in [0, 1]: for t in range(batch[BatchKey.satellite_actual].shape[dim]): if dim == 0: - sate_data_one_step = batch[BatchKey.satellite_actual][t] + sat_data_one_step = batch[BatchKey.satellite_actual][t] else: - sate_data_one_step = batch[BatchKey.satellite_actual][:, t] - nans = np.isnan(sate_data_one_step) + sat_data_one_step = batch[BatchKey.satellite_actual][:, t] + + nans = np.isnan(sat_data_one_step) if np.any(nans): - percent_nans = np.sum(nans) / np.prod(sate_data_one_step.shape) * 100 + percent_nans = np.mean(nans) * 100 logger.error( f"Found nans values in satellite data at index {t} ({dim=}). " @@ -1093,28 +1093,18 @@ def add_selected_time_slices_from_datapipes(used_datapipes: dict): return datapipes_to_return -def create_t0_and_loc_datapipes( +def create_valid_t0_periods_datapipe( datapipes_dict: dict, configuration: Configuration, key_for_t0: str = "gsp", - shuffle: bool = True, ): - """ - Takes source datapipes and returns datapipes of appropriate sample pairs of locations and times. - - The (location, t0) pairs are sampled without replacement. + """Create datapipe yielding t0 periods which are valid for the input data sources. Args: datapipes_dict: Dictionary of datapipes of input sources for which we want to select appropriate location and times. configuration: Configuration object for inputs. key_for_t0: Key to use for the t0 datapipe. Must be "gsp" or "pv". - shuffle: Whether to use the internal shuffle function when yielding location times. Else - location times will be heavily ordered. 
- - Returns: - location datapipe, t0 datapipe - """ assert key_for_t0 in datapipes_dict assert key_for_t0 in [ @@ -1233,9 +1223,41 @@ def create_t0_and_loc_datapipes( overlapping_datapipe = contiguous_time_datapipes[0] # Select time periods and set length - key_datapipe = key_datapipe.filter_time_periods(time_periods=overlapping_datapipe) + valid_t0_periods_datapipe = key_datapipe.filter_time_periods(time_periods=overlapping_datapipe) + + return valid_t0_periods_datapipe + + +def create_t0_and_loc_datapipes( + datapipes_dict: dict, + configuration: Configuration, + key_for_t0: str = "gsp", + shuffle: bool = True, +): + """ + Takes source datapipes and returns datapipes of appropriate sample pairs of locations and times. + + The (location, t0) pairs are sampled without replacement. + + Args: + datapipes_dict: Dictionary of datapipes of input sources for which we want to select + appropriate location and times. + configuration: Configuration object for inputs. + key_for_t0: Key to use for the t0 datapipe. Must be "gsp" or "pv". + shuffle: Whether to use the internal shuffle function when yielding location times. Else + location times will be heavily ordered. + + Returns: + location datapipe, t0 datapipe + """ + + valid_t0_periods_datapipe = create_valid_t0_periods_datapipe( + datapipes_dict, + configuration, + key_for_t0, + ) - t0_loc_datapipe = key_datapipe.pick_locs_and_t0s(return_all=True, shuffle=shuffle) + t0_loc_datapipe = valid_t0_periods_datapipe.pick_locs_and_t0s(return_all=True, shuffle=shuffle) location_pipe, t0_datapipe = t0_loc_datapipe.unzip(sequence_length=2) diff --git a/ocf_datapipes/training/pvnet.py b/ocf_datapipes/training/pvnet.py index 914ab2db4..7b82ac85b 100644 --- a/ocf_datapipes/training/pvnet.py +++ b/ocf_datapipes/training/pvnet.py @@ -75,42 +75,12 @@ def slice_datapipes_by_space( return -def construct_sliced_data_pipeline( - config_filename: str, - location_pipe: IterDataPipe, - t0_datapipe: IterDataPipe, - production: bool = False, - check_satellite_no_zeros: bool = False, -) -> IterDataPipe: - """Constructs data pipeline for the input data config file. - - This yields samples from the location and time datapipes. - - Args: - config_filename: Path to config file. - location_pipe: Datapipe yielding locations. - t0_datapipe: Datapipe yielding times. - production: Whether constucting pipeline for production inference. - check_satellite_no_zeros: Whether to check that satellite data has no zeros. 
- """ - - datapipes_dict = _get_datapipes_dict( - config_filename, - production=production, - ) - - configuration = datapipes_dict.pop("config") +def process_and_combine_datapipes(datapipes_dict, configuration, check_satellite_no_nans=False): + """Normalize and convert data to numpy arrays""" # Unpack for convenience conf_nwp = configuration.input_data.nwp - # Slice all of the datasets by spce - this is an in-place operation - slice_datapipes_by_space(datapipes_dict, location_pipe, configuration) - - # Slice all of the datasets by time - this is an in-place operation - slice_datapipes_by_time(datapipes_dict, t0_datapipe, configuration, production) - - # Spatially slice, normalize, and convert data to numpy arrays numpy_modalities = [] # Normalise the inputs and convert to numpy format @@ -158,8 +128,7 @@ def construct_sliced_data_pipeline( logger.debug("Combine all the data sources") combined_datapipe = MergeNumpyModalities(numpy_modalities).add_sun_position(modality_name="gsp") - logger.info("Filtering out samples with no data") - if check_satellite_no_zeros: + if check_satellite_no_nans: # in production we don't want any nans in the satellite data combined_datapipe = combined_datapipe.map(check_nans_in_satellite_data) @@ -168,6 +137,46 @@ def construct_sliced_data_pipeline( return combined_datapipe +def construct_sliced_data_pipeline( + config_filename: str, + location_pipe: IterDataPipe, + t0_datapipe: IterDataPipe, + production: bool = False, + check_satellite_no_nans: bool = False, +) -> IterDataPipe: + """Constructs data pipeline for the input data config file. + + This yields samples from the location and time datapipes. + + Args: + config_filename: Path to config file. + location_pipe: Datapipe yielding locations. + t0_datapipe: Datapipe yielding times. + production: Whether constucting pipeline for production inference. + check_satellite_no_nans: Whether to check that satellite data has no nans. + """ + + datapipes_dict = _get_datapipes_dict( + config_filename, + production=production, + ) + + configuration = datapipes_dict.pop("config") + + # Slice all of the datasets by space - this is an in-place operation + slice_datapipes_by_space(datapipes_dict, location_pipe, configuration) + + # Slice all of the datasets by time - this is an in-place operation + slice_datapipes_by_time(datapipes_dict, t0_datapipe, configuration, production) + + # Normalise, and combine the data sources into NumpyBatches + combined_datapipe = process_and_combine_datapipes( + datapipes_dict, configuration, check_satellite_no_nans + ) + + return combined_datapipe + + def pvnet_datapipe( config_filename: str, start_time: Optional[datetime] = None, diff --git a/ocf_datapipes/training/pvnet_all_gsp.py b/ocf_datapipes/training/pvnet_all_gsp.py new file mode 100644 index 000000000..aa3da98c9 --- /dev/null +++ b/ocf_datapipes/training/pvnet_all_gsp.py @@ -0,0 +1,511 @@ +"""Create the training/validation datapipe for UK PVNet batches for all GSPs + +The main public functions are: + + [1] `pvnet_all_gsp_datapipe()` + This constructs a datapipe yielding batches with inputs for all 317 UK GSPs for random t0 + times. + + [2] `construct_sliced_data_pipeline()` + Given a datapipe yielding t0 times, this function constructs a datapipe yielding batches + with inputs for all 317 UK GSPs for the yielded t0 times. This function is used inside [1]. 
+ +""" + +import logging +from datetime import datetime +from typing import List, Optional, Tuple, Union + +import xarray as xr +from torch.utils.data.datapipes._decorator import functional_datapipe +from torch.utils.data.datapipes.datapipe import IterDataPipe + +from ocf_datapipes.batch import MergeNumpyModalities, MergeNWPNumpyModalities +from ocf_datapipes.batch.merge_numpy_examples_to_batch import stack_np_examples_into_batch +from ocf_datapipes.config.model import Configuration +from ocf_datapipes.convert.numpy_batch import ( + convert_gsp_to_numpy_batch, + convert_nwp_to_numpy_batch, + convert_pv_to_numpy_batch, + convert_satellite_to_numpy_batch, +) +from ocf_datapipes.load.gsp.utils import GSPLocationLookup +from ocf_datapipes.select.select_spatial_slice import ( + select_spatial_slice_meters, + select_spatial_slice_pixels, +) +from ocf_datapipes.training.common import ( + _get_datapipes_dict, + check_nans_in_satellite_data, + concat_xr_time_utc, + create_valid_t0_periods_datapipe, + fill_nans_in_arrays, + fill_nans_in_pv, + normalize_gsp, + normalize_pv, + slice_datapipes_by_time, +) +from ocf_datapipes.utils.consts import ( + NWP_MEANS, + NWP_STDS, + RSS_MEAN, + RSS_STD, +) +from ocf_datapipes.utils.location import Location + +xr.set_options(keep_attrs=True) +logger = logging.getLogger("pvnet_all_gsp_datapipe") + + +# ---------------------------------- Utility datapipes --------------------------------- + + +def xr_compute(xr_data): + """Compute the xarray object""" + return xr_data.compute() + + +class SampleRepeat: + """Use a single input element to create a list of identical values""" + + def __init__(self, num_repeats): + """Use a single input element to create a list of identical values + + Args: + num_repeats: Length of the returned list of duplicated values + """ + self.num_repeats = num_repeats + + def __call__(self, x): + """Repeat the input a number of times as a list""" + return [x for _ in range(self.num_repeats)] + + +@functional_datapipe("list_map") +class ListMap(IterDataPipe): + """Datapipe used to appky function to each item in yielded list""" + + def __init__(self, source_datapipe: IterDataPipe, func, *args, **kwargs): + """Datapipe used to appky function to each item in yielded list. + + Args: + source_datapipe: The source datapipe yielding lists of samples + func: The function to apply to all items in the list + *args: Args to pass to the function + **kwargs: Keyword arguments to pass to the function + + """ + + self.source_datapipe = source_datapipe + self.func = func + self._args = args + self._kwargs = kwargs + + def __iter__(self): + for element_list in self.source_datapipe: + yield [self.func(x, *self._args, **self._kwargs) for x in element_list] + + +# ------------------------------ Multi-location datapipes ------------------------------ +# These are datapipes rewritten to run on all GSPs + + +@functional_datapipe("select_all_gsp_spatial_slices_pixels") +class SelectAllGSPSpatialSlicePixelsIterDataPipe(IterDataPipe): + """Select all the spatial slices""" + + def __init__( + self, + source_datapipe: IterDataPipe, + locations: List[Location], + roi_height_pixels: int, + roi_width_pixels: int, + allow_partial_slice: bool = False, + location_idx_name: Optional[str] = None, + ): + """ + Select spatial slices for all GSPs + + If `allow_partial_slice` is set to True, then slices may be made which intersect the border + of the input data. 
The additional x and y coordinates that would be required for this slice + are extrapolated based on the average spacing of these coordinates in the input data. + However, currently slices cannot be made where the centre of the window is outside of the + input data. + + Args: + source_datapipe: Datapipe of Xarray data + locations: List of all locations to create samples for + roi_height_pixels: ROI height in pixels + roi_width_pixels: ROI width in pixels + allow_partial_slice: Whether to allow a partial slice. + location_idx_name: Name for location index of unstructured grid data, + None if not relevant + """ + self.source_datapipe = source_datapipe + self.locations = locations + self.roi_height_pixels = roi_height_pixels + self.roi_width_pixels = roi_width_pixels + self.allow_partial_slice = allow_partial_slice + self.location_idx_name = location_idx_name + + def __iter__(self) -> Union[xr.DataArray, xr.Dataset]: + for xr_data in self.source_datapipe: + loc_slices = [] + + for location in self.locations: + selected = select_spatial_slice_pixels( + xr_data, + location, + self.roi_width_pixels, + self.roi_height_pixels, + self.allow_partial_slice, + self.location_idx_name, + ) + + loc_slices.append(selected) + + yield loc_slices + + +@functional_datapipe("select_all_gsp_spatial_slice_meters") +class SelectAllGSPSpatialSliceMetersIterDataPipe(IterDataPipe): + """Select spatial slice based off meters from point of interest""" + + def __init__( + self, + source_datapipe: IterDataPipe, + locations: List[Location], + roi_height_meters: int, + roi_width_meters: int, + dim_name: Optional[str] = None, + ): + """ + Select spatial slice based off meters from point of interest + + Args: + source_datapipe: Datapipe of Xarray data + locations: List of all locations to create samples for + roi_width_meters: ROI width in meters + roi_height_meters: ROI height in meters + dim_name: Dimension name to select for ID, None for coordinates + + Notes: + Using spatial slicing based on distance rather than number of pixels will often yield + slices which can vary by 1 pixel in height and/or width. + + E.g. Suppose the Xarray data has x-coords = [1,2,3,4,5]. We want to slice a spatial + window with a size which equates to 2.2 along the x-axis. If we choose to slice around + the point x=3 this will slice out the x-coords [2,3,4]. If we choose to slice around the + point x=2.5 this will slice out the x-coords [2,3]. Hence the returned slice can have + size either 2 or 3 in the x-axis depending on the spatial location selected. + + Also, if selecting over a large span of latitudes, this may also cause pixel sizes of + the yielded outputs to change. For example, if the Xarray data is on a regularly spaced + longitude-latitude grid, then the structure of the grid means that the longitudes near + to the poles are spaced closer together (measured in meters) than at the equator. So + slices near the equator will have fewer pixels in the x-axis than slices taken near the + poles.
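+
+            A quick numeric check of the example above (illustrative only):
+
+                import numpy as np
+
+                x = np.array([1, 2, 3, 4, 5])
+                half_width = 2.2 / 2
+                # Window centred on x=3 keeps [2, 3, 4] -> 3 pixels
+                print(x[np.abs(x - 3) <= half_width])
+                # Window centred on x=2.5 keeps [2, 3] -> 2 pixels
+                print(x[np.abs(x - 2.5) <= half_width])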
+ """ + self.source_datapipe = source_datapipe + self.locations = locations + self.roi_width_meters = roi_width_meters + self.roi_height_meters = roi_height_meters + self.dim_name = dim_name + + def __iter__(self) -> Union[xr.DataArray, xr.Dataset]: + for xr_data in self.source_datapipe: + loc_slices = [] + + for location in self.locations: + selected = select_spatial_slice_meters( + xr_data=xr_data, + location=location, + roi_width_meters=self.roi_width_meters, + roi_height_meters=self.roi_height_meters, + dim_name=self.dim_name, + ) + + loc_slices.append(selected) + + yield loc_slices + + +# ------------------------------- Time pipeline functions ------------------------------ + + +def construct_time_pipeline( + config_filename: str, + start_time: Optional[datetime] = None, + end_time: Optional[datetime] = None, +) -> Tuple[IterDataPipe, IterDataPipe]: + """Construct time pipeline for the input data config file. + + Args: + config_filename: Path to config file. + start_time: Minimum time for time datapipe. + end_time: Maximum time for time datapipe. + """ + + datapipes_dict = _get_datapipes_dict(config_filename) + + # Get config + config = datapipes_dict.pop("config") + + if (start_time is not None) or (end_time is not None): + datapipes_dict["gsp"] = datapipes_dict["gsp"].filter_times(start_time, end_time) + + # Get overlapping time periods + t0_datapipe = create_t0_datapipe( + datapipes_dict, + configuration=config, + shuffle=True, + ) + + return t0_datapipe + + +def create_t0_datapipe( + datapipes_dict: dict, + configuration: Configuration, + shuffle: bool = True, +): + """ + Takes source datapipes and returns datapipes of appropriate t0 times. + + The t0 times are sampled without replacement. + + Args: + datapipes_dict: Dictionary of datapipes of input sources for which we want to select + appropriate t0 times. + configuration: Configuration object for inputs. + shuffle: Whether to use the internal shuffle function when yielding times. Else + location times will be heavily ordered. 
+ + Returns: + t0 datapipe + + """ + valid_t0_periods_datapipe = create_valid_t0_periods_datapipe( + datapipes_dict, + configuration, + key_for_t0="gsp", + ) + + t0_datapipe = valid_t0_periods_datapipe.pick_t0_times(return_all=True, shuffle=shuffle) + + return t0_datapipe + + +# ------------------------------- Space pipeline functions ----------------------------- + + +def slice_datapipes_by_space_all_gsps( + datapipes_dict: dict, + locations: list[Location], + configuration: Configuration, +) -> None: + """Slice the dictionary of datapipes by space in-place""" + conf_nwp = configuration.input_data.nwp + conf_sat = configuration.input_data.satellite + + if "nwp" in datapipes_dict: + for nwp_key, nwp_datapipe in datapipes_dict["nwp"].items(): + datapipes_dict["nwp"][nwp_key] = nwp_datapipe.select_all_gsp_spatial_slices_pixels( + locations, + roi_width_pixels=conf_nwp[nwp_key].nwp_image_size_pixels_width, + roi_height_pixels=conf_nwp[nwp_key].nwp_image_size_pixels_height, + ) + + if "sat" in datapipes_dict: + datapipes_dict["sat"] = datapipes_dict["sat"].select_all_gsp_spatial_slices_pixels( + locations, + roi_width_pixels=conf_sat.satellite_image_size_pixels_width, + roi_height_pixels=conf_sat.satellite_image_size_pixels_height, + ) + + if "pv" in datapipes_dict: + # No spatial slice for PV since it is always the same, just repeat for GSPs + datapipes_dict["pv"] = datapipes_dict["pv"].map(SampleRepeat(len(locations))) + + # GSP always assumed to be in data + datapipes_dict["gsp"] = datapipes_dict["gsp"].select_all_gsp_spatial_slice_meters( + locations, + roi_width_meters=1, + roi_height_meters=1, + dim_name="gsp_id", + ) + + +# -------------------------------- Processing functions -------------------------------- + + +def pre_spatial_slice_process(datapipes_dict, configuration): + """Apply pre-processing steps to the dictionary of datapipes in place + + These steps are normalisation and recombining past and future GSP/PV data + """ + conf_nwp = configuration.input_data.nwp + + if "nwp" in datapipes_dict: + for nwp_key, nwp_datapipe in datapipes_dict["nwp"].items(): + datapipes_dict["nwp"][nwp_key] = nwp_datapipe.map(xr_compute).normalize( + mean=NWP_MEANS[conf_nwp[nwp_key].nwp_provider], + std=NWP_STDS[conf_nwp[nwp_key].nwp_provider], + ) + + if "sat" in datapipes_dict: + datapipes_dict["sat"] = ( + datapipes_dict["sat"].map(xr_compute).normalize(mean=RSS_MEAN, std=RSS_STD) + ) + + if "pv" in datapipes_dict: + # Recombine PV arrays - see function doc for further explanation + datapipes_dict["pv"] = ( + datapipes_dict["pv"] + .zip_ocf(datapipes_dict["pv_future"]) + .map(concat_xr_time_utc) + .map(normalize_pv) + .map(fill_nans_in_pv) + ) + + del datapipes_dict["pv_future"] + + # GSP always assumed to be in data + # Recombine GSP arrays - see function doc for further explanation + datapipes_dict["gsp"] = ( + datapipes_dict["gsp"] + .zip_ocf(datapipes_dict["gsp_future"]) + .map(concat_xr_time_utc) + .map(normalize_gsp) + ) + + del datapipes_dict["gsp_future"] + + +def post_spatial_slice_process(datapipes_dict, check_satellite_no_nans=False): + """Convert the dictionary of datapipes to NumpyBatches, combine, and fill nans""" + + numpy_modalities = [] + + if "nwp" in datapipes_dict: + nwp_numpy_modalities = dict() + + for nwp_key, nwp_datapipe in datapipes_dict["nwp"].items(): + nwp_numpy_modalities[nwp_key] = nwp_datapipe.list_map(convert_nwp_to_numpy_batch).map( + stack_np_examples_into_batch + ) + + # Combine the NWPs into NumpyBatch + nwp_numpy_modalities = 
MergeNWPNumpyModalities(nwp_numpy_modalities) + numpy_modalities.append(nwp_numpy_modalities) + + if "sat" in datapipes_dict: + numpy_modalities.append( + datapipes_dict["sat"] + .list_map(convert_satellite_to_numpy_batch) + .map(stack_np_examples_into_batch) + ) + + if "pv" in datapipes_dict: + numpy_modalities.append( + datapipes_dict["pv"] + .list_map(convert_pv_to_numpy_batch) + .map(stack_np_examples_into_batch) + ) + + # GSP always assumed to be in data + numpy_modalities.append( + datapipes_dict["gsp"].list_map(convert_gsp_to_numpy_batch).map(stack_np_examples_into_batch) + ) + + # Combine all the data sources + combined_datapipe = MergeNumpyModalities(numpy_modalities).add_sun_position(modality_name="gsp") + + if check_satellite_no_nans: + # in production we don't want any nans in the satellite data + combined_datapipe = combined_datapipe.map(check_nans_in_satellite_data) + + combined_datapipe = combined_datapipe.map(fill_nans_in_arrays) + + return combined_datapipe + + +# --------------------------- High level pipeline functions ---------------------------- + + +def construct_sliced_data_pipeline( + config_filename: str, + t0_datapipe: IterDataPipe, + production: bool = False, + check_satellite_no_nans: bool = False, +) -> IterDataPipe: + """Constructs data pipeline for the input data config file. + + This yields batches with inputs for all GSPs for each t0 time yielded by the time datapipe. + + Args: + config_filename: Path to config file. + t0_datapipe: Datapipe yielding times. + production: Whether constructing pipeline for production inference. + check_satellite_no_nans: Whether to check that satellite data has no nans. + """ + + datapipes_dict = _get_datapipes_dict( + config_filename, + production=production, + ) + + # Get the location objects for all 317 regional GSPs + gsp_id_to_loc = GSPLocationLookup() + locations = [gsp_id_to_loc(gsp_id) for gsp_id in range(1, 318)] + + # Pop config + configuration = datapipes_dict.pop("config") + + # Slice all of the datasets by time - this is an in-place operation + slice_datapipes_by_time(datapipes_dict, t0_datapipe, configuration, production) + + # Run compute and normalise all the data + pre_spatial_slice_process(datapipes_dict, configuration) + + # Slice all of the datasets by space - this is an in-place operation + slice_datapipes_by_space_all_gsps(datapipes_dict, locations, configuration) + + # Convert to NumpyBatch + combined_datapipe = post_spatial_slice_process(datapipes_dict, check_satellite_no_nans) + + return combined_datapipe + + +def pvnet_all_gsp_datapipe( + config_filename: str, + start_time: Optional[datetime] = None, + end_time: Optional[datetime] = None, +) -> IterDataPipe: + """ + Construct pvnet pipeline for the input data config file. + + Args: + config_filename: Path to config file. + start_time: Minimum time at which a sample can be selected. + end_time: Maximum time at which a sample can be selected. + """ + + # Open datasets from the config and filter to usable times + t0_datapipe = construct_time_pipeline( + config_filename, + start_time, + end_time, + ) + + # Shard after we have the times.
These are already shuffled so no need to shuffle again + t0_datapipe = t0_datapipe.sharding_filter() + + # In this function we re-open the datasets to make a clean separation before/after sharding + # This function yields batches with inputs for all 317 GSPs for the yielded t0 times + datapipe = construct_sliced_data_pipeline( + config_filename, + t0_datapipe, + ) + + return datapipe diff --git a/tests/conftest.py b/tests/conftest.py index 34cc5d4b7..94d03972f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -589,6 +589,11 @@ def nwp_ukv_data_filename(nwp_ukv_data): yield filename +@pytest.fixture() +def pvnet_config_filename(): + return f"{_top_test_directory}/data/configs/pvnet_test_config.yaml" + + @pytest.fixture() def configuration_filename(): return f"{_top_test_directory}/data/configs/test.yaml" diff --git a/tests/data/configs/pvnet_test_config.yaml b/tests/data/configs/pvnet_test_config.yaml new file mode 100644 index 000000000..fb4d6c3ea --- /dev/null +++ b/tests/data/configs/pvnet_test_config.yaml @@ -0,0 +1,43 @@ +general: + description: Test config for PVNet + name: pvnet_test + +input_data: + default_history_minutes: 120 + default_forecast_minutes: 480 + + gsp: + gsp_zarr_path: tests/data/gsp/test.zarr + history_minutes: 120 + forecast_minutes: 480 + time_resolution_minutes: 30 + dropout_timedeltas_minutes: null + dropout_fraction: 0 + + nwp: + ukv: + nwp_provider: ukv + nwp_zarr_path: tests/data/nwp_data/test.zarr + history_minutes: 60 + forecast_minutes: 120 + time_resolution_minutes: 60 + nwp_channels: + - t # 2-metre temperature + nwp_image_size_pixels_height: 2 + nwp_image_size_pixels_width: 2 + dropout_timedeltas_minutes: [-180] + dropout_fraction: 1.0 + max_staleness_minutes: null + + satellite: + satellite_zarr_path: tests/data/sat_data.zarr + history_minutes: 90 + forecast_minutes: 0 + live_delay_minutes: 0 + time_resolution_minutes: 5 + satellite_channels: + - IR_016 + satellite_image_size_pixels_height: 2 + satellite_image_size_pixels_width: 2 + dropout_timedeltas_minutes: null + dropout_fraction: 0 diff --git a/tests/data/gsp/test.zarr/.zmetadata b/tests/data/gsp/test.zarr/.zmetadata index 0aade4476..9fc94d56f 100644 --- a/tests/data/gsp/test.zarr/.zmetadata +++ b/tests/data/gsp/test.zarr/.zmetadata @@ -6,8 +6,8 @@ }, "capacity_mwp/.zarray": { "chunks": [ - 9430, - 20 + 10, + 318 ], "compressor": { "blocksize": 0, @@ -21,8 +21,8 @@ "filters": null, "order": "C", "shape": [ - 49, - 22 + 96, + 318 ], "zarr_format": 2 }, @@ -34,7 +34,7 @@ }, "datetime_gmt/.zarray": { "chunks": [ - 37717 + 96 ], "compressor": { "blocksize": 0, @@ -48,7 +48,7 @@ "filters": null, "order": "C", "shape": [ - 49 + 96 ], "zarr_format": 2 }, @@ -57,17 +57,17 @@ "datetime_gmt" ], "calendar": "proleptic_gregorian", - "units": "minutes since 2014-01-01" + "units": "minutes since 2020-04-01 00:00:00" }, "generation_mw/.zarray": { "chunks": [ - 9430, - 20 + 10, + 318 ], "compressor": { "blocksize": 0, "clevel": 5, - "cname": "zstd", + "cname": "lz4", "id": "blosc", "shuffle": 1 }, @@ -76,8 +76,8 @@ "filters": null, "order": "C", "shape": [ - 49, - 22 + 96, + 318 ], "zarr_format": 2 }, @@ -103,7 +103,7 @@ "filters": null, "order": "C", "shape": [ - 22 + 318 ], "zarr_format": 2 }, @@ -114,8 +114,8 @@ }, "installedcapacity_mwp/.zarray": { "chunks": [ - 9430, - 20 + 10, + 318 ], "compressor": { "blocksize": 0, @@ -129,8 +129,8 @@ "filters": null, "order": "C", "shape": [ - 49, - 22 + 96, + 318 ], "zarr_format": 2 }, diff --git a/tests/data/gsp/test.zarr/capacity_mwp/.zarray b/tests/data/gsp/test.zarr/capacity_mwp/.zarray index 7534fa734..1b355953d 100644 --- 
a/tests/data/gsp/test.zarr/capacity_mwp/.zarray +++ b/tests/data/gsp/test.zarr/capacity_mwp/.zarray @@ -1,7 +1,7 @@ { "chunks": [ - 9430, - 20 + 10, + 318 ], "compressor": { "blocksize": 0, @@ -15,8 +15,8 @@ "filters": null, "order": "C", "shape": [ - 49, - 22 + 96, + 318 ], "zarr_format": 2 } diff --git a/tests/data/gsp/test.zarr/capacity_mwp/0.0 b/tests/data/gsp/test.zarr/capacity_mwp/0.0 index 8e5b65e7c..3c2036b08 100644 Binary files a/tests/data/gsp/test.zarr/capacity_mwp/0.0 and b/tests/data/gsp/test.zarr/capacity_mwp/0.0 differ diff --git a/tests/data/gsp/test.zarr/capacity_mwp/0.1 b/tests/data/gsp/test.zarr/capacity_mwp/0.1 deleted file mode 100644 index dd08e7c6c..000000000 Binary files a/tests/data/gsp/test.zarr/capacity_mwp/0.1 and /dev/null differ diff --git a/tests/data/gsp/test.zarr/capacity_mwp/1.0 b/tests/data/gsp/test.zarr/capacity_mwp/1.0 new file mode 100644 index 000000000..d4396ada6 Binary files /dev/null and b/tests/data/gsp/test.zarr/capacity_mwp/1.0 differ diff --git a/tests/data/gsp/test.zarr/capacity_mwp/2.0 b/tests/data/gsp/test.zarr/capacity_mwp/2.0 new file mode 100644 index 000000000..d4396ada6 Binary files /dev/null and b/tests/data/gsp/test.zarr/capacity_mwp/2.0 differ diff --git a/tests/data/gsp/test.zarr/capacity_mwp/3.0 b/tests/data/gsp/test.zarr/capacity_mwp/3.0 new file mode 100644 index 000000000..72a61750f Binary files /dev/null and b/tests/data/gsp/test.zarr/capacity_mwp/3.0 differ diff --git a/tests/data/gsp/test.zarr/capacity_mwp/4.0 b/tests/data/gsp/test.zarr/capacity_mwp/4.0 new file mode 100644 index 000000000..c3cf2cedf Binary files /dev/null and b/tests/data/gsp/test.zarr/capacity_mwp/4.0 differ diff --git a/tests/data/gsp/test.zarr/capacity_mwp/5.0 b/tests/data/gsp/test.zarr/capacity_mwp/5.0 new file mode 100644 index 000000000..565dfe7ec Binary files /dev/null and b/tests/data/gsp/test.zarr/capacity_mwp/5.0 differ diff --git a/tests/data/gsp/test.zarr/capacity_mwp/6.0 b/tests/data/gsp/test.zarr/capacity_mwp/6.0 new file mode 100644 index 000000000..565dfe7ec Binary files /dev/null and b/tests/data/gsp/test.zarr/capacity_mwp/6.0 differ diff --git a/tests/data/gsp/test.zarr/capacity_mwp/7.0 b/tests/data/gsp/test.zarr/capacity_mwp/7.0 new file mode 100644 index 000000000..ee65a22f1 Binary files /dev/null and b/tests/data/gsp/test.zarr/capacity_mwp/7.0 differ diff --git a/tests/data/gsp/test.zarr/capacity_mwp/8.0 b/tests/data/gsp/test.zarr/capacity_mwp/8.0 new file mode 100644 index 000000000..56e351216 Binary files /dev/null and b/tests/data/gsp/test.zarr/capacity_mwp/8.0 differ diff --git a/tests/data/gsp/test.zarr/capacity_mwp/9.0 b/tests/data/gsp/test.zarr/capacity_mwp/9.0 new file mode 100644 index 000000000..0b973f644 Binary files /dev/null and b/tests/data/gsp/test.zarr/capacity_mwp/9.0 differ diff --git a/tests/data/gsp/test.zarr/datetime_gmt/.zarray b/tests/data/gsp/test.zarr/datetime_gmt/.zarray index 3fad0aa37..513e45066 100644 --- a/tests/data/gsp/test.zarr/datetime_gmt/.zarray +++ b/tests/data/gsp/test.zarr/datetime_gmt/.zarray @@ -1,6 +1,6 @@ { "chunks": [ - 37717 + 96 ], "compressor": { "blocksize": 0, @@ -14,7 +14,7 @@ "filters": null, "order": "C", "shape": [ - 49 + 96 ], "zarr_format": 2 } diff --git a/tests/data/gsp/test.zarr/datetime_gmt/.zattrs b/tests/data/gsp/test.zarr/datetime_gmt/.zattrs index 83ad40ac4..355ac1574 100644 --- a/tests/data/gsp/test.zarr/datetime_gmt/.zattrs +++ b/tests/data/gsp/test.zarr/datetime_gmt/.zattrs @@ -3,5 +3,5 @@ "datetime_gmt" ], "calendar": "proleptic_gregorian", - "units": "minutes since 
2014-01-01" + "units": "minutes since 2020-04-01 00:00:00" } diff --git a/tests/data/gsp/test.zarr/datetime_gmt/0 b/tests/data/gsp/test.zarr/datetime_gmt/0 index e1b4c4836..136f201fb 100644 Binary files a/tests/data/gsp/test.zarr/datetime_gmt/0 and b/tests/data/gsp/test.zarr/datetime_gmt/0 differ diff --git a/tests/data/gsp/test.zarr/generation_mw/.zarray b/tests/data/gsp/test.zarr/generation_mw/.zarray index b6ee6cf8c..1b355953d 100644 --- a/tests/data/gsp/test.zarr/generation_mw/.zarray +++ b/tests/data/gsp/test.zarr/generation_mw/.zarray @@ -1,12 +1,12 @@ { "chunks": [ - 9430, - 20 + 10, + 318 ], "compressor": { "blocksize": 0, "clevel": 5, - "cname": "zstd", + "cname": "lz4", "id": "blosc", "shuffle": 1 }, @@ -15,8 +15,8 @@ "filters": null, "order": "C", "shape": [ - 49, - 22 + 96, + 318 ], "zarr_format": 2 } diff --git a/tests/data/gsp/test.zarr/generation_mw/0.0 b/tests/data/gsp/test.zarr/generation_mw/0.0 index c0bb8bdd8..5a41ef13d 100644 Binary files a/tests/data/gsp/test.zarr/generation_mw/0.0 and b/tests/data/gsp/test.zarr/generation_mw/0.0 differ diff --git a/tests/data/gsp/test.zarr/generation_mw/0.1 b/tests/data/gsp/test.zarr/generation_mw/0.1 deleted file mode 100644 index 567024122..000000000 Binary files a/tests/data/gsp/test.zarr/generation_mw/0.1 and /dev/null differ diff --git a/tests/data/gsp/test.zarr/generation_mw/1.0 b/tests/data/gsp/test.zarr/generation_mw/1.0 new file mode 100644 index 000000000..cf836e406 Binary files /dev/null and b/tests/data/gsp/test.zarr/generation_mw/1.0 differ diff --git a/tests/data/gsp/test.zarr/generation_mw/2.0 b/tests/data/gsp/test.zarr/generation_mw/2.0 new file mode 100644 index 000000000..93d0551d8 Binary files /dev/null and b/tests/data/gsp/test.zarr/generation_mw/2.0 differ diff --git a/tests/data/gsp/test.zarr/generation_mw/3.0 b/tests/data/gsp/test.zarr/generation_mw/3.0 new file mode 100644 index 000000000..f83851351 Binary files /dev/null and b/tests/data/gsp/test.zarr/generation_mw/3.0 differ diff --git a/tests/data/gsp/test.zarr/generation_mw/4.0 b/tests/data/gsp/test.zarr/generation_mw/4.0 new file mode 100644 index 000000000..5a41ef13d Binary files /dev/null and b/tests/data/gsp/test.zarr/generation_mw/4.0 differ diff --git a/tests/data/gsp/test.zarr/generation_mw/5.0 b/tests/data/gsp/test.zarr/generation_mw/5.0 new file mode 100644 index 000000000..c5e6b786f Binary files /dev/null and b/tests/data/gsp/test.zarr/generation_mw/5.0 differ diff --git a/tests/data/gsp/test.zarr/generation_mw/6.0 b/tests/data/gsp/test.zarr/generation_mw/6.0 new file mode 100644 index 000000000..227500da1 Binary files /dev/null and b/tests/data/gsp/test.zarr/generation_mw/6.0 differ diff --git a/tests/data/gsp/test.zarr/generation_mw/7.0 b/tests/data/gsp/test.zarr/generation_mw/7.0 new file mode 100644 index 000000000..7cb60a60a Binary files /dev/null and b/tests/data/gsp/test.zarr/generation_mw/7.0 differ diff --git a/tests/data/gsp/test.zarr/generation_mw/8.0 b/tests/data/gsp/test.zarr/generation_mw/8.0 new file mode 100644 index 000000000..209c416db Binary files /dev/null and b/tests/data/gsp/test.zarr/generation_mw/8.0 differ diff --git a/tests/data/gsp/test.zarr/generation_mw/9.0 b/tests/data/gsp/test.zarr/generation_mw/9.0 new file mode 100644 index 000000000..827e35769 Binary files /dev/null and b/tests/data/gsp/test.zarr/generation_mw/9.0 differ diff --git a/tests/data/gsp/test.zarr/gsp_id/.zarray b/tests/data/gsp/test.zarr/gsp_id/.zarray index 9a1f6f31e..dafc11290 100644 --- a/tests/data/gsp/test.zarr/gsp_id/.zarray +++ 
b/tests/data/gsp/test.zarr/gsp_id/.zarray @@ -14,7 +14,7 @@ "filters": null, "order": "C", "shape": [ - 22 + 318 ], "zarr_format": 2 } diff --git a/tests/data/gsp/test.zarr/gsp_id/0 b/tests/data/gsp/test.zarr/gsp_id/0 index b3ff4eaaa..840af40af 100644 Binary files a/tests/data/gsp/test.zarr/gsp_id/0 and b/tests/data/gsp/test.zarr/gsp_id/0 differ diff --git a/tests/data/gsp/test.zarr/installedcapacity_mwp/.zarray b/tests/data/gsp/test.zarr/installedcapacity_mwp/.zarray index 7534fa734..1b355953d 100644 --- a/tests/data/gsp/test.zarr/installedcapacity_mwp/.zarray +++ b/tests/data/gsp/test.zarr/installedcapacity_mwp/.zarray @@ -1,7 +1,7 @@ { "chunks": [ - 9430, - 20 + 10, + 318 ], "compressor": { "blocksize": 0, @@ -15,8 +15,8 @@ "filters": null, "order": "C", "shape": [ - 49, - 22 + 96, + 318 ], "zarr_format": 2 } diff --git a/tests/data/gsp/test.zarr/installedcapacity_mwp/0.0 b/tests/data/gsp/test.zarr/installedcapacity_mwp/0.0 index ef2d86fc6..31e97017c 100644 Binary files a/tests/data/gsp/test.zarr/installedcapacity_mwp/0.0 and b/tests/data/gsp/test.zarr/installedcapacity_mwp/0.0 differ diff --git a/tests/data/gsp/test.zarr/installedcapacity_mwp/0.1 b/tests/data/gsp/test.zarr/installedcapacity_mwp/0.1 deleted file mode 100644 index 7824ab3eb..000000000 Binary files a/tests/data/gsp/test.zarr/installedcapacity_mwp/0.1 and /dev/null differ diff --git a/tests/data/gsp/test.zarr/installedcapacity_mwp/1.0 b/tests/data/gsp/test.zarr/installedcapacity_mwp/1.0 new file mode 100644 index 000000000..d03374c2a Binary files /dev/null and b/tests/data/gsp/test.zarr/installedcapacity_mwp/1.0 differ diff --git a/tests/data/gsp/test.zarr/installedcapacity_mwp/2.0 b/tests/data/gsp/test.zarr/installedcapacity_mwp/2.0 new file mode 100644 index 000000000..d03374c2a Binary files /dev/null and b/tests/data/gsp/test.zarr/installedcapacity_mwp/2.0 differ diff --git a/tests/data/gsp/test.zarr/installedcapacity_mwp/3.0 b/tests/data/gsp/test.zarr/installedcapacity_mwp/3.0 new file mode 100644 index 000000000..d03374c2a Binary files /dev/null and b/tests/data/gsp/test.zarr/installedcapacity_mwp/3.0 differ diff --git a/tests/data/gsp/test.zarr/installedcapacity_mwp/4.0 b/tests/data/gsp/test.zarr/installedcapacity_mwp/4.0 new file mode 100644 index 000000000..8dc697c7c Binary files /dev/null and b/tests/data/gsp/test.zarr/installedcapacity_mwp/4.0 differ diff --git a/tests/data/gsp/test.zarr/installedcapacity_mwp/5.0 b/tests/data/gsp/test.zarr/installedcapacity_mwp/5.0 new file mode 100644 index 000000000..6e62da389 Binary files /dev/null and b/tests/data/gsp/test.zarr/installedcapacity_mwp/5.0 differ diff --git a/tests/data/gsp/test.zarr/installedcapacity_mwp/6.0 b/tests/data/gsp/test.zarr/installedcapacity_mwp/6.0 new file mode 100644 index 000000000..6e62da389 Binary files /dev/null and b/tests/data/gsp/test.zarr/installedcapacity_mwp/6.0 differ diff --git a/tests/data/gsp/test.zarr/installedcapacity_mwp/7.0 b/tests/data/gsp/test.zarr/installedcapacity_mwp/7.0 new file mode 100644 index 000000000..6e62da389 Binary files /dev/null and b/tests/data/gsp/test.zarr/installedcapacity_mwp/7.0 differ diff --git a/tests/data/gsp/test.zarr/installedcapacity_mwp/8.0 b/tests/data/gsp/test.zarr/installedcapacity_mwp/8.0 new file mode 100644 index 000000000..6e62da389 Binary files /dev/null and b/tests/data/gsp/test.zarr/installedcapacity_mwp/8.0 differ diff --git a/tests/data/gsp/test.zarr/installedcapacity_mwp/9.0 b/tests/data/gsp/test.zarr/installedcapacity_mwp/9.0 new file mode 100644 index 000000000..836867bf6 Binary files 
/dev/null and b/tests/data/gsp/test.zarr/installedcapacity_mwp/9.0 differ diff --git a/tests/production/test_pvnet_production.py b/tests/production/test_pvnet_production.py index 572d19766..8f2fae68c 100644 --- a/tests/production/test_pvnet_production.py +++ b/tests/production/test_pvnet_production.py @@ -40,7 +40,7 @@ def test_construct_sliced_data_pipeline(configuration_filename, gsp_yields): configuration_filename, location_pipe=loc_pipe, t0_datapipe=t0_pipe, - check_satellite_no_zeros=True, + check_satellite_no_nans=True, production=True, ) @@ -48,7 +48,7 @@ def test_construct_sliced_data_pipeline(configuration_filename, gsp_yields): @freeze_time("2020-04-01 02:30:00") -def test_construct_sliced_data_pipeline_satellite_with_zeros(configuration_filename, gsp_yields): +def test_construct_sliced_data_pipeline_satellite_with_nans(configuration_filename, gsp_yields): # This is randomly chosen, but real, GSP location loc_pipe = IterableWrapper([Location(x=246699.328125, y=849771.9375, id=18)]) @@ -59,8 +59,8 @@ def test_construct_sliced_data_pipeline_satellite_with_zeros(configuration_filen configuration_filename, location_pipe=loc_pipe, t0_datapipe=t0_pipe, - check_satellite_no_zeros=True, + check_satellite_no_nans=True, production=True, ) with pytest.raises(ValueError): - _ = next(iter(dp)) + batch = next(iter(dp)) diff --git a/tests/select/test_pick_locations.py b/tests/select/test_pick_locations.py index b52cec234..753014273 100644 --- a/tests/select/test_pick_locations.py +++ b/tests/select/test_pick_locations.py @@ -28,7 +28,7 @@ def test_pick_locations_all_locations(gsp_datapipe): sample_period_duration=timedelta(minutes=30), history_duration=timedelta(hours=1), ) - location_datapipe = PickLocations(gsp_datapipe, return_all_locations=True) + location_datapipe = PickLocations(gsp_datapipe, return_all=True) loc_iterator = iter(location_datapipe) for i in range(len(dataset["x_osgb"])): loc_data = next(loc_iterator) diff --git a/tests/select/test_select_spatial_slice.py b/tests/select/test_select_spatial_slice.py index 5e089fd38..cdb3a89f2 100644 --- a/tests/select/test_select_spatial_slice.py +++ b/tests/select/test_select_spatial_slice.py @@ -8,10 +8,10 @@ SelectSpatialSlicePixels, ) -from ocf_datapipes.select.select_spatial_slice import select_spatial_slice_pixels +from ocf_datapipes.select.select_spatial_slice import slice_spatial_pixel_window_from_xarray -def test_select_spatial_slice_pixels_function(): +def test_slice_spatial_pixel_window_from_xarray_function(): # Create dummy data x = np.arange(100) y = np.arange(100)[::-1] @@ -29,7 +29,7 @@ def test_select_spatial_slice_pixels_function(): center_idx = Location(x=10, y=10, coordinate_system="idx") # Select window which lies within data - xr_selected = select_spatial_slice_pixels( + xr_selected = slice_spatial_pixel_window_from_xarray( xr_data, center_idx, width_pixels=10, @@ -44,7 +44,7 @@ def test_select_spatial_slice_pixels_function(): assert not xr_selected.data.isnull().any() # Select window where the edge of the window lies at the edge of the data - xr_selected = select_spatial_slice_pixels( + xr_selected = slice_spatial_pixel_window_from_xarray( xr_data, center_idx, width_pixels=20, @@ -59,7 +59,7 @@ def test_select_spatial_slice_pixels_function(): assert not xr_selected.data.isnull().any() # Select window which is partially outside the boundary of the data - xr_selected = select_spatial_slice_pixels( + xr_selected = slice_spatial_pixel_window_from_xarray( xr_data, center_idx, width_pixels=30, @@ -114,7 +114,7 @@ def 
test_select_spatial_slice_pixel_icon_eu(passiv_datapipe, icon_eu_datapipe): def test_select_spatial_slice_pixel_icon_global(passiv_datapipe, icon_global_datapipe): - loc_datapipe = PickLocations(passiv_datapipe, return_all_locations=True) + loc_datapipe = PickLocations(passiv_datapipe, return_all=True) icon_global_datapipe = SelectSpatialSlicePixels( icon_global_datapipe, location_datapipe=loc_datapipe, @@ -146,7 +146,7 @@ def test_select_spatial_slice_meters_icon_eu(passiv_datapipe, icon_eu_datapipe): def test_select_spatial_slice_meters_icon_global(passiv_datapipe, icon_global_datapipe): - loc_datapipe = PickLocations(passiv_datapipe, return_all_locations=True) + loc_datapipe = PickLocations(passiv_datapipe, return_all=True) icon_global_datapipe = SelectSpatialSliceMeters( icon_global_datapipe, location_datapipe=loc_datapipe, diff --git a/tests/training/test_common.py b/tests/training/test_common.py index 89efef6a2..9237f417b 100644 --- a/tests/training/test_common.py +++ b/tests/training/test_common.py @@ -1,22 +1,21 @@ +from datetime import datetime +import numpy as np import pytest + from torch.utils.data.datapipes.datapipe import IterDataPipe from torch.utils.data.datapipes.iter import Zipper +from torch.utils.data import DataLoader + from ocf_datapipes.config.model import Configuration from ocf_datapipes.utils import Location -from torch.utils.data import DataLoader from ocf_datapipes.training.common import ( add_selected_time_slices_from_datapipes, get_and_return_overlapping_time_periods_and_t0, open_and_return_datapipes, create_t0_and_loc_datapipes, + construct_loctime_pipelines, ) -import fsspec -from pyaml_env import parse_config - -import pandas as pd -import numpy as np - def test_open_and_return_datapipes(configuration_filename): used_datapipes = open_and_return_datapipes(configuration_filename) @@ -111,3 +110,17 @@ def test_create_t0_and_loc_datapipes(configuration_filename): loc0, t0 = next(iter(location_pipe.zip(t0_datapipe))) assert isinstance(loc0, Location) assert isinstance(t0, np.datetime64) + + +def test_construct_loctime_pipelines(configuration_filename): + start_time = datetime(1900, 1, 1) + end_time = datetime(2050, 1, 1) + + loc_pipe, t0_pipe = construct_loctime_pipelines( + configuration_filename, + start_time=start_time, + end_time=end_time, + ) + + next(iter(loc_pipe)) + next(iter(t0_pipe)) diff --git a/tests/training/test_pvnet.py b/tests/training/test_pvnet.py index 16b5f6c0d..d5073875f 100644 --- a/tests/training/test_pvnet.py +++ b/tests/training/test_pvnet.py @@ -6,26 +6,11 @@ construct_sliced_data_pipeline, pvnet_datapipe, ) -from ocf_datapipes.training.common import construct_loctime_pipelines from ocf_datapipes.batch import BatchKey, NWPBatchKey from ocf_datapipes.utils import Location -def test_construct_loctime_pipelines(configuration_filename): - start_time = datetime(1900, 1, 1) - end_time = datetime(2050, 1, 1) - - loc_pipe, t0_pipe = construct_loctime_pipelines( - configuration_filename, - start_time=start_time, - end_time=end_time, - ) - - next(iter(loc_pipe)) - next(iter(t0_pipe)) - - -def test_construct_sliced_data_pipeline(configuration_filename): +def test_construct_sliced_data_pipeline(pvnet_config_filename): # This is randomly chosen, but real, GSP location loc_pipe = IterableWrapper([Location(x=246699.328125, y=849771.9375, id=18)]) @@ -33,7 +18,7 @@ def test_construct_sliced_data_pipeline(configuration_filename): t0_pipe = IterableWrapper([datetime(2020, 4, 1, 13, 30)]) dp = construct_sliced_data_pipeline( - configuration_filename, + 
pvnet_config_filename, location_pipe=loc_pipe, t0_datapipe=t0_pipe, ) @@ -44,12 +29,12 @@ assert NWPBatchKey.nwp in batch[BatchKey.nwp][nwp_source] -def test_pvnet_datapipe(configuration_filename): +def test_pvnet_datapipe(pvnet_config_filename): start_time = datetime(1900, 1, 1) end_time = datetime(2050, 1, 1) dp = pvnet_datapipe( - configuration_filename, + pvnet_config_filename, start_time=start_time, end_time=end_time, ) diff --git a/tests/training/test_pvnet_all_gsp.py b/tests/training/test_pvnet_all_gsp.py new file mode 100644 index 000000000..dbfc94b1b --- /dev/null +++ b/tests/training/test_pvnet_all_gsp.py @@ -0,0 +1,40 @@ +from datetime import datetime + +from torch.utils.data.datapipes.iter import IterableWrapper + +from ocf_datapipes.training.pvnet_all_gsp import ( + construct_sliced_data_pipeline, + pvnet_all_gsp_datapipe, +) +from ocf_datapipes.batch import BatchKey, NWPBatchKey + + +def test_construct_sliced_data_pipeline(pvnet_config_filename): + # This is a randomly chosen time in the middle of the test data + t0_pipe = IterableWrapper([datetime(2020, 4, 1, 13, 30)]) + + dp = construct_sliced_data_pipeline( + pvnet_config_filename, + t0_datapipe=t0_pipe, + ) + + batch = next(iter(dp)) + assert BatchKey.nwp in batch + for nwp_source in batch[BatchKey.nwp].keys(): + assert NWPBatchKey.nwp in batch[BatchKey.nwp][nwp_source] + + +def test_pvnet_all_gsp_datapipe(pvnet_config_filename): + start_time = datetime(1900, 1, 1) + end_time = datetime(2050, 1, 1) + + dp = pvnet_all_gsp_datapipe( + pvnet_config_filename, + start_time=start_time, + end_time=end_time, + ) + + batch = next(iter(dp)) + assert BatchKey.nwp in batch + for nwp_source in batch[BatchKey.nwp].keys(): + assert NWPBatchKey.nwp in batch[BatchKey.nwp][nwp_source] diff --git a/tests/transform/xarray/test_normalize.py b/tests/transform/xarray/test_normalize.py index fb00ee6f3..37b787bff 100644 --- a/tests/transform/xarray/test_normalize.py +++ b/tests/transform/xarray/test_normalize.py @@ -39,7 +39,9 @@ def test_normalize_gsp(gsp_datapipe): ) data = next(iter(normed_gsp_datapipe)) assert np.min(data) >= 0.0 - assert np.max(data) <= 1.0 + + # Some GSPs are noisy and seem to have values above 1 + assert np.max(data) <= 1.5 def test_normalize_passiv(passiv_datapipe):