diff --git a/nowcasting_dataset/data_sources/README.md b/nowcasting_dataset/data_sources/README.md index 818aa626..1ed8e11c 100644 --- a/nowcasting_dataset/data_sources/README.md +++ b/nowcasting_dataset/data_sources/README.md @@ -15,9 +15,11 @@ and the geospatial shape of each GSP region). # data_source.py General class used for making a data source. It has the following functions -- get_batch: gets a whole batch of data for that data source +- get_batch: gets a whole batch of data for that data source. The list of 'xr.Dataset' examples are converted to +one xr.Dataset by changing the coordinates to indexes, and then joining the examples along an extra dimension. - datetime_index: gets the all available datatimes of the source - get_example: gets one "example" (a single consecutive sequence). Each batch is made up of multiple examples. + Each example is a 'xr.Dataset' - get_locations_for_batch: Samples the geospatial x,y location for each example in a batch. This is useful because, typically, we want a single DataSource to dictate the geospatial locations of the examples (for example, we want each example to be centered on the centroid of the grid supply point region). All the other @@ -27,20 +29,19 @@ General class used for making a data source. It has the following functions # datasource_output.py General pydantic model of output of the data source. Contains the following methods -- to_numpy: changes all data points to numpy objects -- split: converts a batch to a list of items -- join: joins list of items to one -- to_xr_dataset: changes data items to xarrays and returns a dataset -- from_xr_dataset: loads from an xarray dataset -- select_time_period: subselect data, depending on a time period +- save_netcdf: save to netcdf file +- check_nan_and_inf: check if any values are nans or infinite +- check_dataset_greater_than_or_equal_to: check values are >= a value +- check_dataset_less_than_or_equal_to: check values are <= a value +- check_dataset_not_equal: check values are !>= a value +- check_data_var_dim: check the dimensions of a data variable # Data Source folder Roughly each of the data source folders follows this pattern - A class which defines how to load the data source, how to select for batches etc. This inherits from 'data_source.DataSource', -- A class which contains the output model of the data source, built from an xarray Dataset. This is the information used in the batches. +- A class which contains the output model of the data source, built from a xarray Dataset. This is the information used in the batches. This inherits from 'datasource_output.DataSourceOutput'. -- A second class (pydantic) which moves the xarray Dataset to tensor fields. This will be used for training in ML models # fake diff --git a/nowcasting_dataset/data_sources/data_source.py b/nowcasting_dataset/data_sources/data_source.py index 4d1e336e..1ec5bb85 100644 --- a/nowcasting_dataset/data_sources/data_source.py +++ b/nowcasting_dataset/data_sources/data_source.py @@ -16,7 +16,10 @@ from nowcasting_dataset import square from nowcasting_dataset.consts import SPATIAL_AND_TEMPORAL_LOCATIONS_COLUMN_NAMES from nowcasting_dataset.data_sources.datasource_output import DataSourceOutput -from nowcasting_dataset.dataset.xr_utils import join_list_dataset_to_batch_dataset, make_dim_index +from nowcasting_dataset.dataset.xr_utils import ( + convert_coordinates_to_indexes_for_list_datasets, + join_list_dataset_to_batch_dataset, +) logger = logging.getLogger(__name__) @@ -257,10 +260,10 @@ def get_batch( examples = [future_example.result() for future_example in future_examples] # Get the DataSource class, this could be one of the data sources like Sun - cls = examples[0].__class__ + cls = self.get_data_model_for_batch() # Set the coords to be indices before joining into a batch - examples = [make_dim_index(example) for example in examples] + examples = convert_coordinates_to_indexes_for_list_datasets(examples) # join the examples together, and cast them to the cls, so that validation can occur return cls(join_list_dataset_to_batch_dataset(examples)) @@ -271,6 +274,10 @@ def datetime_index(self) -> pd.DatetimeIndex: # of a list of datetimes (e.g. for DatetimeDataSource). raise NotImplementedError() + def get_data_model_for_batch(self): + """Get the model that is used in the batch""" + raise NotImplementedError() + def get_contiguous_time_periods(self) -> pd.DataFrame: """Get all the time periods for which this DataSource has contiguous data. @@ -378,7 +385,7 @@ def data(self): def get_example( self, t0_dt: pd.Timestamp, x_meters_center: Number, y_meters_center: Number - ) -> DataSourceOutput: + ) -> xr.Dataset: """ Get Example data @@ -419,7 +426,7 @@ def get_example( f"actual shape {selected_data.shape}" ) - return selected_data.load() + return selected_data.load().to_dataset(name="data") def geospatial_border(self) -> List[Tuple[Number, Number]]: """ diff --git a/nowcasting_dataset/data_sources/datasource_output.py b/nowcasting_dataset/data_sources/datasource_output.py index 689507d6..09a1db84 100644 --- a/nowcasting_dataset/data_sources/datasource_output.py +++ b/nowcasting_dataset/data_sources/datasource_output.py @@ -58,13 +58,15 @@ def check_nan_and_inf(self, data: xr.Dataset, variable_name: str = None): if isnan(data).any(): message = f"Some {self.__class__.__name__} data values are NaNs" - message += f" ({variable_name})" if variable_name is not None else None + if variable_name is not None: + message += f" ({variable_name})" logger.error(message) raise Exception(message) if isinf(data).any(): message = f"Some {self.__class__.__name__} data values are Infinite" - message += f" ({variable_name})" if variable_name is not None else None + if variable_name is not None: + message += f" ({variable_name})" logger.error(message) raise Exception(message) diff --git a/nowcasting_dataset/data_sources/fake.py b/nowcasting_dataset/data_sources/fake.py index 309ea1bf..5d2c1cfd 100644 --- a/nowcasting_dataset/data_sources/fake.py +++ b/nowcasting_dataset/data_sources/fake.py @@ -2,6 +2,8 @@ Wanted to keep this out of the testing frame works, as other repos, might want to use this """ +from typing import List + import numpy as np import pandas as pd import xarray as xr @@ -15,8 +17,8 @@ from nowcasting_dataset.data_sources.sun.sun_model import Sun from nowcasting_dataset.data_sources.topographic.topographic_model import Topographic from nowcasting_dataset.dataset.xr_utils import ( - convert_data_array_to_dataset, - join_list_data_array_to_batch_dataset, + convert_coordinates_to_indexes, + convert_coordinates_to_indexes_for_list_datasets, join_list_dataset_to_batch_dataset, ) @@ -28,7 +30,7 @@ def gsp_fake( ): """Create fake data""" # make batch of arrays - xr_arrays = [ + xr_datasets = [ create_gsp_pv_dataset( seq_length=seq_length_30, freq="30T", @@ -37,8 +39,11 @@ def gsp_fake( for _ in range(batch_size) ] + # change dimensions to dimension indexes + xr_datasets = convert_coordinates_to_indexes_for_list_datasets(xr_datasets) + # make dataset - xr_dataset = join_list_dataset_to_batch_dataset(xr_arrays) + xr_dataset = join_list_dataset_to_batch_dataset(xr_datasets) return GSP(xr_dataset) @@ -47,6 +52,9 @@ def metadata_fake(batch_size): """Make a xr dataset""" xr_arrays = [create_metadata_dataset() for _ in range(batch_size)] + # change to indexes + xr_arrays = [convert_coordinates_to_indexes(xr_array) for xr_array in xr_arrays] + # make dataset xr_dataset = join_list_dataset_to_batch_dataset(xr_arrays) @@ -81,7 +89,7 @@ def nwp_fake( def pv_fake(batch_size, seq_length_5, n_pv_systems_per_batch): """Create fake data""" # make batch of arrays - xr_arrays = [ + xr_datasets = [ create_gsp_pv_dataset( seq_length=seq_length_5, freq="5T", @@ -90,8 +98,11 @@ def pv_fake(batch_size, seq_length_5, n_pv_systems_per_batch): for _ in range(batch_size) ] + # change dimensions to dimension indexes + xr_datasets = convert_coordinates_to_indexes_for_list_datasets(xr_datasets) + # make dataset - xr_dataset = join_list_dataset_to_batch_dataset(xr_arrays) + xr_dataset = join_list_dataset_to_batch_dataset(xr_datasets) return PV(xr_dataset) @@ -150,6 +161,7 @@ def topographic_fake(batch_size, image_size_pixels): x=np.sort(np.random.randn(image_size_pixels)), y=np.sort(np.random.randn(image_size_pixels))[::-1].copy(), ), + name="data", ) for _ in range(batch_size) ] @@ -184,6 +196,7 @@ def create_image_array( ) ), coords=coords, + name="data", ) # Fake data for testing! return image_data_array @@ -197,7 +210,7 @@ def create_gsp_pv_dataset( """Create gsp or pv fake dataset""" ALL_COORDS = { "time": pd.date_range("2021-01-01", freq=freq, periods=seq_length), - "id": np.random.randint(low=0, high=1000, size=number_of_systems), + "id": np.random.choice(range(1000), number_of_systems, replace=False), } coords = [(dim, ALL_COORDS[dim]) for dim in dims] data_array = xr.DataArray( @@ -208,22 +221,20 @@ def create_gsp_pv_dataset( coords=coords, ) # Fake data for testing! - data = convert_data_array_to_dataset(data_array) + data = data_array.to_dataset(name="data") x_coords = xr.DataArray( - data=np.sort(np.random.randn(number_of_systems)), - dims=["id_index"], - coords=dict( - id_index=range(number_of_systems), + data=np.sort( + np.random.choice(range(2 * number_of_systems), number_of_systems, replace=False) ), + dims=["id"], ) y_coords = xr.DataArray( - data=np.sort(np.random.randn(number_of_systems)), - dims=["id_index"], - coords=dict( - id_index=range(number_of_systems), + data=np.sort( + np.random.choice(range(2 * number_of_systems), number_of_systems, replace=False) ), + dims=["id"], ) data["x_coords"] = x_coords @@ -265,13 +276,14 @@ def create_sun_dataset( coords=coords, ) # Fake data for testing! - data = convert_data_array_to_dataset(data_array) - sun = data.rename({"data": "elevation"}) - sun["azimuth"] = data.data + sun = data_array.to_dataset(name="elevation") + sun["azimuth"] = sun.elevation sun.__setitem__("azimuth", sun.azimuth.clip(min=0, max=360)) sun.__setitem__("elevation", sun.elevation.clip(min=-90, max=90)) + sun = convert_coordinates_to_indexes(sun) + return sun @@ -282,11 +294,11 @@ def create_metadata_dataset() -> xr.Dataset: "data": pd.date_range("2021-01-01", freq="5T", periods=1) + pd.Timedelta("30T"), } - data = convert_data_array_to_dataset(xr.DataArray.from_dict(d)) + data = (xr.DataArray.from_dict(d)).to_dataset(name="data") for v in ["x_meters_center", "y_meters_center", "object_at_center_label"]: d: dict = {"dims": ("t0_dt",), "data": [np.random.randint(0, 1000)]} - d: xr.Dataset = convert_data_array_to_dataset(xr.DataArray.from_dict(d)).rename({"data": v}) + d: xr.Dataset = (xr.DataArray.from_dict(d)).to_dataset(name=v) data[v] = getattr(d, v) return data @@ -307,7 +319,7 @@ def create_datetime_dataset( coords=coords, ) # Fake data - data = convert_data_array_to_dataset(data_array) + data = data_array.to_dataset() ds = data.rename({"data": "day_of_year_cos"}) ds["day_of_year_sin"] = data.rename({"data": "day_of_year_sin"}).day_of_year_sin @@ -315,3 +327,12 @@ def create_datetime_dataset( ds["hour_of_day_sin"] = data.rename({"data": "hour_of_day_sin"}).hour_of_day_sin return data + + +def join_list_data_array_to_batch_dataset(data_arrays: List[xr.DataArray]) -> xr.Dataset: + """Join a list of xr.DataArrays into an xr.Dataset by concatenating on the example dim.""" + datasets = [ + convert_coordinates_to_indexes(data_arrays[i].to_dataset()) for i in range(len(data_arrays)) + ] + + return join_list_dataset_to_batch_dataset(datasets) diff --git a/nowcasting_dataset/data_sources/gsp/gsp_data_source.py b/nowcasting_dataset/data_sources/gsp/gsp_data_source.py index e92e1d2a..d0a052c3 100644 --- a/nowcasting_dataset/data_sources/gsp/gsp_data_source.py +++ b/nowcasting_dataset/data_sources/gsp/gsp_data_source.py @@ -18,7 +18,6 @@ from nowcasting_dataset.data_sources.data_source import ImageDataSource from nowcasting_dataset.data_sources.gsp.eso import get_gsp_metadata_from_eso from nowcasting_dataset.data_sources.gsp.gsp_model import GSP -from nowcasting_dataset.dataset.xr_utils import convert_data_array_to_dataset from nowcasting_dataset.geospatial import lat_lon_to_osgb from nowcasting_dataset.square import get_bounding_box_mask from nowcasting_dataset.utils import scale_to_0_to_1 @@ -73,6 +72,10 @@ def sample_period_minutes(self) -> int: """Override the default sample minutes""" return 30 + def get_data_model_for_batch(self): + """Get the model that is used in the batch""" + return GSP + def load(self): """ Load the meta data and load the GSP power data @@ -153,7 +156,7 @@ def get_locations(self, t0_datetimes: pd.DatetimeIndex) -> Tuple[List[Number], L def get_example( self, t0_dt: pd.Timestamp, x_meters_center: Number, y_meters_center: Number - ) -> GSP: + ) -> xr.Dataset: """ Get data example from one time point (t0_dt) and for x and y coords. @@ -201,41 +204,31 @@ def get_example( da = xr.DataArray( data=selected_gsp_power.values, dims=["time", "id"], - coords=dict( - id=all_gsp_ids.values.astype(int), - time=selected_gsp_power.index.values, - ), ) # convert to dataset - gsp = convert_data_array_to_dataset(da) + gsp = da.to_dataset(name="data") # add gsp x coords gsp_x_coords = xr.DataArray( data=gsp_x_coords.values, - dims=["id_index"], - coords=dict( - id_index=range(len(all_gsp_ids.values)), - ), + dims=["id"], ) gsp_y_coords = xr.DataArray( data=gsp_y_coords.values, - dims=["id_index"], - coords=dict( - id_index=range(len(all_gsp_ids.values)), - ), + dims=["id"], ) gsp["x_coords"] = gsp_x_coords gsp["y_coords"] = gsp_y_coords # pad out so that there are always 32 gsp, fill with 0 - pad_n = self.n_gsp_per_example - len(gsp.id_index) - gsp = gsp.pad(id_index=(0, pad_n), data=((0, 0), (0, pad_n)), constant_values=0) + pad_n = self.n_gsp_per_example - len(gsp.id) + gsp = gsp.pad(id=(0, pad_n), data=((0, 0), (0, pad_n)), constant_values=0) - gsp.__setitem__("id_index", range(self.n_gsp_per_example)) + gsp.__setitem__("id", range(self.n_gsp_per_example)) - return GSP(gsp) + return gsp def _get_central_gsp_id( self, diff --git a/nowcasting_dataset/data_sources/metadata/metadata_data_source.py b/nowcasting_dataset/data_sources/metadata/metadata_data_source.py index de4bdbf1..ed9d1f63 100644 --- a/nowcasting_dataset/data_sources/metadata/metadata_data_source.py +++ b/nowcasting_dataset/data_sources/metadata/metadata_data_source.py @@ -8,7 +8,6 @@ from nowcasting_dataset.data_sources.data_source import DataSource from nowcasting_dataset.data_sources.metadata.metadata_model import Metadata -from nowcasting_dataset.dataset.xr_utils import convert_data_array_to_dataset from nowcasting_dataset.utils import to_numpy @@ -18,9 +17,13 @@ class MetadataDataSource(DataSource): object_at_center: str = "GSP" + def get_data_model_for_batch(self): + """Get the model that is used in the batch""" + return Metadata + def get_example( self, t0_dt: pd.Timestamp, x_meters_center: Number, y_meters_center: Number - ) -> Metadata: + ) -> xr.Dataset: """ Get example data @@ -55,13 +58,11 @@ def get_example( "object_at_center_label": {"dims": ("t0_dt_index"), "data": [object_at_center_label]}, } - data = convert_data_array_to_dataset(xr.DataArray.from_dict(d_all["t0_dt"])) + data = (xr.DataArray.from_dict(d_all["t0_dt"])).to_dataset() for v in ["x_meters_center", "y_meters_center", "object_at_center_label"]: d: dict = d_all[v] - d: xr.Dataset = convert_data_array_to_dataset(xr.DataArray.from_dict(d)).rename( - {"data": v} - ) + d: xr.Dataset = (xr.DataArray.from_dict(d)).to_dataset().rename({"data": v}) data[v] = getattr(d, v) - return Metadata(data) + return data diff --git a/nowcasting_dataset/data_sources/nwp/nwp_data_source.py b/nowcasting_dataset/data_sources/nwp/nwp_data_source.py index a5f922e5..f1bd980d 100644 --- a/nowcasting_dataset/data_sources/nwp/nwp_data_source.py +++ b/nowcasting_dataset/data_sources/nwp/nwp_data_source.py @@ -10,6 +10,7 @@ from nowcasting_dataset import utils from nowcasting_dataset.consts import NWP_VARIABLE_NAMES from nowcasting_dataset.data_sources.data_source import ZarrDataSource +from nowcasting_dataset.data_sources.nwp.nwp_model import NWP _LOG = logging.getLogger(__name__) @@ -76,6 +77,10 @@ def open(self) -> None: def _open_data(self) -> xr.DataArray: return open_nwp(self.zarr_path, consolidated=self.consolidated) + def get_data_model_for_batch(self): + """Get the model that is used in the batch""" + return NWP + def _get_time_slice(self, t0_dt: pd.Timestamp) -> xr.DataArray: """ Select the numerical weather predictions for a single time slice. @@ -108,10 +113,9 @@ def _get_time_slice(self, t0_dt: pd.Timestamp) -> xr.DataArray: selected["target_time"] = init_time + selected.step return selected - def _post_process_example( - self, selected_data: xr.DataArray, t0_dt: pd.Timestamp - ) -> xr.DataArray: + def _post_process_example(self, selected_data: xr.Dataset, t0_dt: pd.Timestamp) -> xr.Dataset: """Resamples to 5 minutely.""" + start_dt = self._get_start_dt(t0_dt) end_dt = self._get_end_dt(t0_dt) diff --git a/nowcasting_dataset/data_sources/pv/pv_data_source.py b/nowcasting_dataset/data_sources/pv/pv_data_source.py index c76c8399..0ee23094 100644 --- a/nowcasting_dataset/data_sources/pv/pv_data_source.py +++ b/nowcasting_dataset/data_sources/pv/pv_data_source.py @@ -19,7 +19,6 @@ from nowcasting_dataset.consts import DEFAULT_N_PV_SYSTEMS_PER_EXAMPLE from nowcasting_dataset.data_sources.data_source import ImageDataSource from nowcasting_dataset.data_sources.pv.pv_model import PV -from nowcasting_dataset.dataset.xr_utils import convert_data_array_to_dataset from nowcasting_dataset.square import get_bounding_box_mask logger = logging.getLogger(__name__) @@ -64,6 +63,10 @@ def load(self): self._load_pv_power() self.pv_metadata, self.pv_power = align_pv_system_ids(self.pv_metadata, self.pv_power) + def get_data_model_for_batch(self): + """Get the model that is used in the batch""" + return PV + def _load_metadata(self): logger.debug(f"Loading PV metadata from {self.metadata_filename}") @@ -204,7 +207,7 @@ def _get_all_pv_system_ids_in_roi( def get_example( self, t0_dt: pd.Timestamp, x_meters_center: Number, y_meters_center: Number - ) -> PV: + ) -> xr.Dataset: """ Get Example data for PV data @@ -254,42 +257,33 @@ def get_example( ) # convert to dataset - pv = convert_data_array_to_dataset(da) + pv = da.to_dataset(name="data") # add pv x coords x_coords = xr.DataArray( data=pv_system_x_coords.values, - dims=["id_index"], - coords=dict( - id_index=range(len(all_pv_system_ids.values)), - ), + dims=["id"], ) y_coords = xr.DataArray( data=pv_system_y_coords.values, - dims=["id_index"], - coords=dict( - id_index=range(len(all_pv_system_ids.values)), - ), + dims=["id"], ) pv_system_row_number = xr.DataArray( data=pv_system_row_number, - dims=["id_index"], - coords=dict( - id_index=range(len(all_pv_system_ids.values)), - ), + dims=["id"], ) pv["x_coords"] = x_coords pv["y_coords"] = y_coords pv["pv_system_row_number"] = pv_system_row_number # pad out so that there are always n_pv_systems_per_example, pad with zeros - pad_n = self.n_pv_systems_per_example - len(pv.id_index) - pv = pv.pad(id_index=(0, pad_n), data=((0, 0), (0, pad_n)), constant_values=0) + pad_n = self.n_pv_systems_per_example - len(pv.id) + pv = pv.pad(id=(0, pad_n), data=((0, 0), (0, pad_n)), constant_values=0) - pv.__setitem__("id_index", range(self.n_pv_systems_per_example)) + pv.__setitem__("id", range(self.n_pv_systems_per_example)) - return PV(pv) + return pv def get_locations(self, t0_datetimes: pd.DatetimeIndex) -> Tuple[List[Number], List[Number]]: """Find a valid geographical location for each t0_datetime. diff --git a/nowcasting_dataset/data_sources/satellite/satellite_data_source.py b/nowcasting_dataset/data_sources/satellite/satellite_data_source.py index 4dfe64b6..95554b2f 100644 --- a/nowcasting_dataset/data_sources/satellite/satellite_data_source.py +++ b/nowcasting_dataset/data_sources/satellite/satellite_data_source.py @@ -48,14 +48,17 @@ def open(self) -> None: def _open_data(self) -> xr.DataArray: return open_sat_data(zarr_path=self.zarr_path, consolidated=self.consolidated) - def _dataset_to_data_source_output(output: xr.Dataset) -> Satellite: - return Satellite(output) + def get_data_model_for_batch(self): + """Get the model that is used in the batch""" + return Satellite def _get_time_slice(self, t0_dt: pd.Timestamp) -> xr.DataArray: start_dt = self._get_start_dt(t0_dt) end_dt = self._get_end_dt(t0_dt) data = self.data.sel(time=slice(start_dt, end_dt)) + assert type(data) == xr.DataArray + return data def datetime_index(self, remove_night: bool = True) -> pd.DatetimeIndex: diff --git a/nowcasting_dataset/data_sources/sun/sun_data_source.py b/nowcasting_dataset/data_sources/sun/sun_data_source.py index 66a13429..a5160e3b 100644 --- a/nowcasting_dataset/data_sources/sun/sun_data_source.py +++ b/nowcasting_dataset/data_sources/sun/sun_data_source.py @@ -7,12 +7,12 @@ import numpy as np import pandas as pd +import xarray as xr import nowcasting_dataset.filesystem.utils as nd_fs_utils from nowcasting_dataset.data_sources.data_source import DataSource from nowcasting_dataset.data_sources.sun.raw_data_load_save import load_from_zarr, x_y_to_name from nowcasting_dataset.data_sources.sun.sun_model import Sun -from nowcasting_dataset.dataset.xr_utils import convert_data_array_to_dataset logger = logging.getLogger(__name__) @@ -28,13 +28,17 @@ def __post_init__(self): super().__post_init__() self._load() + def get_data_model_for_batch(self): + """Get the model that is used in the batch""" + return Sun + def check_input_paths_exist(self) -> None: """Check input paths exist. If not, raise a FileNotFoundError.""" nd_fs_utils.check_path_exists(self.zarr_path) def get_example( self, t0_dt: pd.Timestamp, x_meters_center: Number, y_meters_center: Number - ) -> Sun: + ) -> xr.Dataset: """ Get example data from t0_dt and x and y xoordinates @@ -78,9 +82,8 @@ def get_example( azimuth = azimuth.to_xarray().rename({"index": "time"}) elevation = elevation.to_xarray().rename({"index": "time"}) - sun = convert_data_array_to_dataset(azimuth).rename({"data": "azimuth"}) - elevation = convert_data_array_to_dataset(elevation) - sun["elevation"] = elevation.data + sun = azimuth.to_dataset(name="azimuth") + sun["elevation"] = elevation return Sun(sun) diff --git a/nowcasting_dataset/data_sources/topographic/topographic_data_source.py b/nowcasting_dataset/data_sources/topographic/topographic_data_source.py index 9d0517d0..07c7c514 100644 --- a/nowcasting_dataset/data_sources/topographic/topographic_data_source.py +++ b/nowcasting_dataset/data_sources/topographic/topographic_data_source.py @@ -11,7 +11,6 @@ import nowcasting_dataset.filesystem.utils as nd_fs_utils from nowcasting_dataset.data_sources.data_source import ImageDataSource from nowcasting_dataset.data_sources.topographic.topographic_model import Topographic -from nowcasting_dataset.dataset.xr_utils import convert_data_array_to_dataset from nowcasting_dataset.geospatial import OSGB from nowcasting_dataset.utils import OpenData @@ -46,13 +45,17 @@ def __post_init__(self, image_size_pixels: int, meters_per_pixel: int): self._stored_pixel_size_meters = abs(self._data.coords["x"][1] - self._data.coords["x"][0]) self._meters_per_pixel = meters_per_pixel + def get_data_model_for_batch(self): + """Get the model that is used in the batch""" + return Topographic + def check_input_paths_exist(self) -> None: """Check input paths exist. If not, raise a FileNotFoundError.""" nd_fs_utils.check_path_exists(self.filename) def get_example( self, t0_dt: pd.Timestamp, x_meters_center: Number, y_meters_center: Number - ) -> Topographic: + ) -> xr.Dataset: """ Get a single example @@ -97,8 +100,8 @@ def get_example( f"actual shape {selected_data.shape}" ) - # TODO: Issue #318: Coordinates should be changed just before creating a batch. - topo_xd = convert_data_array_to_dataset(selected_data) + # change to dataset + topo_xd = selected_data.to_dataset(name="data") return Topographic(topo_xd) diff --git a/nowcasting_dataset/dataset/xr_utils.py b/nowcasting_dataset/dataset/xr_utils.py index c8276bf6..12e37129 100644 --- a/nowcasting_dataset/dataset/xr_utils.py +++ b/nowcasting_dataset/dataset/xr_utils.py @@ -9,15 +9,6 @@ import xarray as xr -# TODO: This function is only used in fake.py for testing. -# Maybe we should move this function to fake.py? -def join_list_data_array_to_batch_dataset(data_arrays: List[xr.DataArray]) -> xr.Dataset: - """Join a list of xr.DataArrays into an xr.Dataset by concatenating on the example dim.""" - datasets = [convert_data_array_to_dataset(data_arrays[i]) for i in range(len(data_arrays))] - - return join_list_dataset_to_batch_dataset(datasets) - - def join_list_dataset_to_batch_dataset(datasets: list[xr.Dataset]) -> xr.Dataset: """Join a list of data sets to a dataset by expanding dims""" @@ -29,18 +20,14 @@ def join_list_dataset_to_batch_dataset(datasets: list[xr.Dataset]) -> xr.Dataset return xr.concat(new_datasets, dim="example") -# TODO: Issue #318: Maybe remove this function and, in calling code, do data_array.to_dataset() -# followed by make_dim_index, to make it more explicit what's happening? At the moment, -# in the calling code, it's not clear that the coordinates are being changed. -def convert_data_array_to_dataset(data_xarray: xr.DataArray) -> xr.Dataset: - """Convert data array to dataset. Reindex dim so that it can be merged with batch""" - data = xr.Dataset({"data": data_xarray}) - return make_dim_index(dataset=data) +def convert_coordinates_to_indexes_for_list_datasets( + examples: List[xr.Dataset], +) -> List[xr.Dataset]: + """Set the coords to be indices before joining into a batch""" + return [convert_coordinates_to_indexes(example) for example in examples] -# TODO: Issue #318: Maybe rename this function... maybe to coord_to_range()? -# Not sure what's best right now! :) -def make_dim_index(dataset: xr.Dataset) -> xr.Dataset: +def convert_coordinates_to_indexes(dataset: xr.Dataset) -> xr.Dataset: """Reindex dims so that it can be merged with batch. For each dimension in dataset, change the coords to 0.. len(original_coords), @@ -50,6 +37,8 @@ def make_dim_index(dataset: xr.Dataset) -> xr.Dataset: This is useful to align multiple examples into a single batch. """ + assert type(dataset) == xr.Dataset + original_dim_names = dataset.dims for original_dim_name in original_dim_names: