diff --git a/conftest.py b/conftest.py index 213221c3..b47b2c31 100644 --- a/conftest.py +++ b/conftest.py @@ -9,6 +9,7 @@ from nowcasting_dataset.config.load import load_yaml_configuration from nowcasting_dataset.data_sources import SatelliteDataSource from nowcasting_dataset.data_sources.gsp.gsp_data_source import GSPDataSource +from nowcasting_dataset.data_sources.metadata.metadata_data_source import MetadataDataSource pytest.IMAGE_SIZE_PIXELS = 128 @@ -50,6 +51,14 @@ def sat_data_source(sat_filename: Path): ) +@pytest.fixture +def general_data_source(): + + return MetadataDataSource( + history_minutes=0, forecast_minutes=5, object_at_center="GSP", convert_to_numpy=True + ) + + @pytest.fixture def gsp_data_source(): return GSPDataSource( @@ -65,9 +74,9 @@ def gsp_data_source(): @pytest.fixture def configuration(): filename = os.path.join(os.path.dirname(nowcasting_dataset.__file__), "config", "gcp.yaml") - config = load_yaml_configuration(filename) + configuration = load_yaml_configuration(filename) - return config + return configuration @pytest.fixture diff --git a/notebooks/2021-09/2021-09-07/sat_data.py b/notebooks/2021-09/2021-09-07/sat_data.py index 79d28f52..4e8f5e18 100644 --- a/notebooks/2021-09/2021-09-07/sat_data.py +++ b/notebooks/2021-09/2021-09-07/sat_data.py @@ -1,6 +1,6 @@ from datetime import datetime -from nowcasting_dataset.data_sources.satellite_data_source import SatelliteDataSource +from nowcasting_dataset.data_sources.satellite.satellite_data_source import SatelliteDataSource s = SatelliteDataSource( filename="gs://solar-pv-nowcasting-data/satellite/EUMETSAT/SEVIRI_RSS/OSGB36/" diff --git a/notebooks/2021-10/2021-10-01/pydantic.py b/notebooks/2021-10/2021-10-01/pydantic.py new file mode 100644 index 00000000..b6ff73b6 --- /dev/null +++ b/notebooks/2021-10/2021-10-01/pydantic.py @@ -0,0 +1,129 @@ +from pydantic import BaseModel, Field, validator +from typing import Union +import numpy as np +import xarray as xr +import torch +from nowcasting_dataset.config.model import Configuration + + +Array = Union[xr.DataArray, np.ndarray, torch.Tensor] + + +class Satellite(BaseModel): + + # width: int = Field(..., g=0, description="The width of the satellite image") + # height: int = Field(..., g=0, description="The width of the satellite image") + # num_channels: int = Field(..., g=0, description="The width of the satellite image") + + # Shape: [batch_size,] seq_length, width, height, channel + image_data: Array = Field( + ..., + description="Satellites images. Shape: [batch_size,] seq_length, width, height, channel", + ) + x_coords: Array = Field( + ..., + description="The x (OSGB geo-spatial) coordinates of the satellite images. Shape: [batch_size,] width", + ) + y_coords: Array = Field( + ..., + description="The y (OSGB geo-spatial) coordinates of the satellite images. Shape: [batch_size,] width", + ) + + # @validator("sat_data") + # def image_shape(cls, v): + # assert v.shape[-1] == cls.num_channels + # assert v.shape[-2] == cls.height + # assert v.shape[-3] == cls.width + # + # @validator("x_coords") + # def x_coords_shape(cls, v): + # assert v.shape[-1] == cls.width + # + # @validator("y_coords") + # def y_coords_shape(cls, v): + # assert v.shape[-1] == cls.height + # + class Config: + arbitrary_types_allowed = True + + +class Batch(BaseModel): + + batch_size: int = Field( + ..., + g=0, + description="The size of this batch. 
If the batch size is 0, " + "then this item stores one data item", + ) + + satellite: Satellite + + +class FakeDataset(torch.utils.data.Dataset): + """Fake dataset.""" + + def __init__(self, configuration: Configuration = Configuration(), length: int = 10): + """ + Init + + Args: + configuration: configuration object + length: length of dataset + """ + self.batch_size = configuration.process.batch_size + self.seq_length_5 = ( + configuration.process.seq_len_5_minutes + ) # the sequence data in 5 minute steps + self.seq_length_30 = ( + configuration.process.seq_len_30_minutes + ) # the sequence data in 30 minute steps + self.satellite_image_size_pixels = configuration.process.satellite_image_size_pixels + self.nwp_image_size_pixels = configuration.process.nwp_image_size_pixels + self.number_sat_channels = len(configuration.process.sat_channels) + self.number_nwp_channels = len(configuration.process.nwp_channels) + self.length = length + + def __len__(self): + """Number of pieces of data""" + return self.length + + def per_worker_init(self, worker_id: int): + """Not needed""" + pass + + def __getitem__(self, idx): + """ + Get item, use for iter and next method + + Args: + idx: batch index + + Returns: Dictionary of random data + + """ + + sat = Satellite( + image_data=np.random.randn( + self.batch_size, + self.seq_length_5, + self.satellite_image_size_pixels, + self.satellite_image_size_pixels, + self.number_sat_channels, + ), + x_coords=torch.sort(torch.randn(self.batch_size, self.satellite_image_size_pixels))[0], + y_coords=torch.sort( + torch.randn(self.batch_size, self.satellite_image_size_pixels), descending=True + )[0], + ) + + # Note need to return as nested dict + return Batch(satellite=sat, batch_size=self.batch_size).dict() + + +train = torch.utils.data.DataLoader(FakeDataset()) +i = iter(train) +x = next(i) + +x = Batch(**x) +# IT WORKS +assert type(x.satellite.image_data) == torch.Tensor diff --git a/notebooks/2021-10/2021-10-08/no_validation.py b/notebooks/2021-10/2021-10-08/no_validation.py new file mode 100644 index 00000000..e69de29b diff --git a/notebooks/2021-10/2021-10-08/xr_compression.py b/notebooks/2021-10/2021-10-08/xr_compression.py new file mode 100644 index 00000000..54c0c5d1 --- /dev/null +++ b/notebooks/2021-10/2021-10-08/xr_compression.py @@ -0,0 +1,97 @@ +import os + +import numpy as np +import xarray as xr +from nowcasting_dataset.utils import coord_to_range + + +def get_satellite_xrarray_data_array( + batch_size, seq_length_5, satellite_image_size_pixels, number_sat_channels=10 +): + + r = np.random.randn( + # self.batch_size, + seq_length_5, + satellite_image_size_pixels, + satellite_image_size_pixels, + number_sat_channels, + ) + + time = np.sort(np.random.randn(seq_length_5)) + + x_coords = np.sort(np.random.randint(0, 1000, (satellite_image_size_pixels))) + y_coords = np.sort(np.random.randint(0, 1000, (satellite_image_size_pixels)))[::-1].copy() + + sat_xr = xr.DataArray( + data=r, + dims=["time", "x", "y", "channels"], + coords=dict( + # batch=range(0,self.batch_size), + x=list(x_coords), + y=list(y_coords), + time=list(time), + channels=range(0, number_sat_channels), + ), + attrs=dict( + description="Ambient temperature.", + units="degC", + ), + name="sata_data", + ) + + return sat_xr + + +def sat_data_array_to_dataset(sat_xr): + ds = sat_xr.to_dataset(name="sat_data") + # ds["sat_data"] = ds["sat_data"].astype(np.int16) + + for dim in ["time", "x", "y"]: + # This does seem like the right way to do it + # 
https://ecco-v4-python-tutorial.readthedocs.io/ECCO_v4_Saving_Datasets_and_DataArrays_to_NetCDF.html + ds = coord_to_range(ds, dim, prefix="sat") + ds = ds.rename( + { + "channels": f"sat_channels", + "x": f"sat_x", + "y": f"sat_y", + } + ) + + # ds["sat_x_coords"] = ds["sat_x_coords"].astype(np.int32) + # ds["sat_y_coords"] = ds["sat_y_coords"].astype(np.int32) + + return ds + + +def to_netcdf(batch_xr, local_filename): + encoding = {name: {"compression": "lzf"} for name in batch_xr.data_vars} + batch_xr.to_netcdf(local_filename, engine="h5netcdf", mode="w", encoding=encoding) + + +# 1. try to save netcdf files not using coord to range function +sat_xrs = [get_satellite_xrarray_data_array(4, 19, 32) for _ in range(0, 10)] + +### error ### +# cant do this step as x/y index has duplicate values +sat_dataset = xr.merge(sat_xrs) +to_netcdf(sat_dataset, "test_no_alignment.nc") +### + +# but can save it as separate files +os.mkdir("test_no_alignment") +[sat_xrs[i].to_netcdf(f"test_no_alignment/{i}.nc", engine="h5netcdf") for i in range(0, 10)] +# 10 files about 1.5MB + +# 2. +sat_xrs = [get_satellite_xrarray_data_array(4, 19, 32) for _ in range(0, 10)] +sat_xrs = [sat_data_array_to_dataset(sat_xr) for sat_xr in sat_xrs] + +sat_dataset = xr.concat(sat_xrs, dim="example") +to_netcdf(sat_dataset, "test_alignment.nc") +# this 15 MB + + +# conclusion +# no major improvement in compression by joining datasets together, buts by joining array together, +# it does make it easier to get array ready ML diff --git a/notebooks/2021-10/2021-10-08/xr_pydantic.py b/notebooks/2021-10/2021-10-08/xr_pydantic.py new file mode 100644 index 00000000..946b3e40 --- /dev/null +++ b/notebooks/2021-10/2021-10-08/xr_pydantic.py @@ -0,0 +1,99 @@ +from pydantic import BaseModel, Field, validator +from typing import Union, List +import numpy as np +import xarray as xr +import torch +from nowcasting_dataset.config.model import Configuration + + +Array = Union[xr.DataArray, np.ndarray, torch.Tensor] + + +class Satellite(BaseModel): + # Shape: [batch_size,] seq_length, width, height, channel + image_data: xr.DataArray = Field( + ..., + description="Satellites images. 
Shape: [batch_size,] seq_length, width, height, channel", + ) + + class Config: + arbitrary_types_allowed = True + + @validator("image_data") + def v_image_data(cls, v): + print("validating image data") + return v + + +class Batch(BaseModel): + + batch_size: int = 0 + satellite: Satellite + + @validator("batch_size") + def v_image_data(cls, v): + print("validating batch size") + return v + + +s = Satellite(image_data=xr.DataArray()) +s_dict = s.dict() + +x = Satellite(**s_dict) +x = Satellite.construct(Satellite.__fields_set__, **s_dict) + + +batch = Batch(batch_size=5, satellite=s) + +b_dict = batch.dict() + +x = Batch(**b_dict) +x = Batch.construct(Batch.__fields_set__, **b_dict) + + +# class Satellite(BaseModel): +# +# image_data: xr.DataArray +# +# # validate +# +# def to_dataset(self): +# pass +# +# def from_dateset(self): +# pass +# +# def to_numpy(self) -> SatelliteNumpy: +# pass +# +# +# class SatelliteNumpy(BaseModel): +# +# image_data: np.ndarray +# x: np.ndarray +# # more +# +# +# class Example(BaseModel): +# +# satelllite: Satellite +# # more +# +# +# class Batch(BaseModel): +# +# batch_size: int = 0 +# examples: List[Example] +# +# def to/from_netcdf(): +# pass +# +# +# class BatchNumpy(BaseModel): +# +# batch_size: int = 0 +# satellite: SatellliteNumpy +# # more +# +# def from_batch(self) -> BatchNumpy: +# """ change to Batch numpy structure """ diff --git a/nowcasting_dataset/config/gcp.yaml b/nowcasting_dataset/config/gcp.yaml index 4f5ded46..7974cfd8 100644 --- a/nowcasting_dataset/config/gcp.yaml +++ b/nowcasting_dataset/config/gcp.yaml @@ -6,10 +6,10 @@ input_data: satellite_zarr_path: gs://solar-pv-nowcasting-data/satellite/EUMETSAT/SEVIRI_RSS/OSGB36/all_zarr_int16_single_timestep.zarr solar_pv_data_filename: gs://solar-pv-nowcasting-data/PV/PVOutput.org/UK_PV_timeseries_batch.nc solar_pv_metadata_filename: gs://solar-pv-nowcasting-data/PV/PVOutput.org/UK_PV_metadata.csv - gsp_zarr_path: gs://solar-pv-nowcasting-data/PV/PVOutput.org/PV/GSP/v1/pv_gsp.zarr + gsp_zarr_path: gs://solar-pv-nowcasting-data/PV/GSP/v1/pv_gsp.zarr topographic_filename: gs://solar-pv-nowcasting-data/Topographic/europe_dem_1km_osgb.tif output_data: - filepath: gs://solar-pv-nowcasting-data/prepared_ML_training_data/v6/ + filepath: gs://solar-pv-nowcasting-data/prepared_ML_training_data/v7/ process: local_temp_path: ~/temp/ seed: 1234 diff --git a/nowcasting_dataset/config/model.py b/nowcasting_dataset/config/model.py index 30b1c0a2..4a1c9b25 100644 --- a/nowcasting_dataset/config/model.py +++ b/nowcasting_dataset/config/model.py @@ -7,8 +7,8 @@ from pydantic import BaseModel, Field from pydantic import validator -from nowcasting_dataset.data_sources.nwp_data_source import NWP_VARIABLE_NAMES -from nowcasting_dataset.data_sources.satellite_data_source import SAT_VARIABLE_NAMES +from nowcasting_dataset.consts import NWP_VARIABLE_NAMES +from nowcasting_dataset.consts import SAT_VARIABLE_NAMES class General(BaseModel): @@ -107,7 +107,7 @@ class Process(BaseModel): ) history_minutes: int = Field(30, ge=0, description="how many historic minutes are used") satellite_image_size_pixels: int = Field(64, description="the size of the satellite images") - nwp_image_size_pixels: int = Field(2, description="the size of the nwp images") + nwp_image_size_pixels: int = Field(64, description="the size of the nwp images") sat_channels: tuple = Field( SAT_VARIABLE_NAMES, description="the satellite channels that are used" diff --git a/nowcasting_dataset/consts.py b/nowcasting_dataset/consts.py index 2b6b0245..d5e1cd54 
100644 --- a/nowcasting_dataset/consts.py +++ b/nowcasting_dataset/consts.py @@ -3,6 +3,7 @@ from typing import Union import numpy as np +import torch import xarray as xr # DEFAULT PATHS @@ -23,7 +24,7 @@ NWP_FILENAME = "gs://" + str(BUCKET / "NWP/UK_Met_Office/UKV_zarr") # Typing -Array = Union[xr.DataArray, np.ndarray] +Array = Union[xr.DataArray, np.ndarray, torch.Tensor] PV_SYSTEM_ID: str = "pv_system_id" PV_SYSTEM_ROW_NUMBER = "pv_system_row_number" PV_SYSTEM_X_COORDS = "pv_system_x_coords" diff --git a/nowcasting_dataset/data_sources/README.md b/nowcasting_dataset/data_sources/README.md new file mode 100644 index 00000000..90ff7c7a --- /dev/null +++ b/nowcasting_dataset/data_sources/README.md @@ -0,0 +1,42 @@ +This folder contains the code for the different data sources. + +# Data Sources +- metadata: metadata for the batch like t0_dt, x_meters_center .... +- datetime: datetime information like 'hour_of_day' +- gsp: Grid Supply Point data from Sheffield Solar (e.g. the estimated total solar PV power generation for each +GSP region, and the geospatial shape of each GSP region). +- nwp: Numerical Weather predictions from UK Met Office +- pv: PV output data from pvoutput.org +- satellite: satellite data from ... +- sun: Sun position data (e.g. the estimated total solar PV power generation for each GSP region, +and the geospatial shape of each GSP region). +- topographic: Topographic data e.g. the elevation of the land. + +# data_source.py + +General class used for making a data source. It has the following functions +- get_batch: gets a whole batch of data for that data source +- datetime_index: gets the all available datatimes of the source +- get_example: gets one "example" (a single consecutive sequence). Each batch is made up of multiple examples. +- get_locations_for_batch: Samples the geospatial x,y location for each example in a batch. This is useful because, + typically, we want a single DataSource to dictate the geospatial locations of the examples (for example, + we want each example to be centered on the centroid of the grid supply point region). All the other + `DataSources` will use these same geospatial locations. + + +# datasource_output.py + +General pydantic model of output of the data source. Contains the following methods +- to_numpy: changes all data points to numpy objects +- split: converts a batch to a list of items +- join: joins list of items to one +- to_xr_dataset: changes data items to xarrays and returns a dataset +- from_xr_dataset: loads from an xarray dataset +- select_time_period: subselect data, depending on a time period + +# Data Source folder + +Roughly each of the data source folders follows this pattern +- A class which defines how to load the data source, how to select for batches etc. This inherits from 'data_source.DataSource', +- A class which contains the output model of the data source. This is the information used in the batches. +This inherits from 'datasource_output.DataSourceOutput'. 
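To make the batching pattern described in the README concrete, here is a minimal sketch (not part of this patch). The `Dummy` class and its `data` field are invented for illustration; `batch_size`, `to_numpy()` and `create_batch_from_examples()` are the members added in `datasource_output.py` below, and the sketch assumes `nowcasting_dataset.utils.to_numpy` passes plain ints and numpy arrays through unchanged.

import numpy as np

from nowcasting_dataset.data_sources.datasource_output import DataSourceOutput


class Dummy(DataSourceOutput):
    """Hypothetical output model with a single array field."""

    data: np.ndarray  # Shape: [batch_size,] seq_length


# Each DataSource.get_example() call returns a single Example (batch_size=0).
examples = [Dummy(batch_size=0, data=np.random.randn(19)) for _ in range(4)]

# DataSource.get_batch() converts each example to numpy and stacks the fields:
batch = Dummy.create_batch_from_examples(examples)
assert batch.batch_size == 4
assert batch.data.shape == (4, 19)
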
diff --git a/nowcasting_dataset/data_sources/__init__.py b/nowcasting_dataset/data_sources/__init__.py index 4d2f9b84..f4cf7556 100644 --- a/nowcasting_dataset/data_sources/__init__.py +++ b/nowcasting_dataset/data_sources/__init__.py @@ -1,7 +1,9 @@ """ Various DataSources """ from nowcasting_dataset.data_sources.data_source import DataSource -from nowcasting_dataset.data_sources.datetime_data_source import DatetimeDataSource -from nowcasting_dataset.data_sources.nwp_data_source import NWPDataSource -from nowcasting_dataset.data_sources.pv_data_source import PVDataSource -from nowcasting_dataset.data_sources.satellite_data_source import SatelliteDataSource -from nowcasting_dataset.data_sources.topographic_data_source import TopographicDataSource +from nowcasting_dataset.data_sources.datetime.datetime_data_source import DatetimeDataSource +from nowcasting_dataset.data_sources.nwp.nwp_data_source import NWPDataSource +from nowcasting_dataset.data_sources.pv.pv_data_source import PVDataSource +from nowcasting_dataset.data_sources.satellite.satellite_data_source import SatelliteDataSource +from nowcasting_dataset.data_sources.topographic.topographic_data_source import ( + TopographicDataSource, +) diff --git a/nowcasting_dataset/data_sources/data_source.py b/nowcasting_dataset/data_sources/data_source.py index c131e588..5500b777 100644 --- a/nowcasting_dataset/data_sources/data_source.py +++ b/nowcasting_dataset/data_sources/data_source.py @@ -11,7 +11,7 @@ import nowcasting_dataset.time as nd_time from nowcasting_dataset import square -from nowcasting_dataset.dataset.example import Example, to_numpy +from nowcasting_dataset.data_sources.datasource_output import DataSourceOutput logger = logging.getLogger(__name__) @@ -99,7 +99,7 @@ def get_batch( t0_datetimes: pd.DatetimeIndex, x_locations: Iterable[Number], y_locations: Iterable[Number], - ) -> List[Example]: + ) -> DataSourceOutput: """ Get Batch Data @@ -115,12 +115,13 @@ def get_batch( examples = [] zipped = zip(t0_datetimes, x_locations, y_locations) for t0_datetime, x_location, y_location in zipped: - example = self.get_example(t0_datetime, x_location, y_location) + output: DataSourceOutput = self.get_example(t0_datetime, x_location, y_location) + if self.convert_to_numpy: - example = to_numpy(example) - examples.append(example) + output.to_numpy() + examples.append(output) - return examples + return DataSourceOutput.create_batch_from_examples(examples) def datetime_index(self) -> pd.DatetimeIndex: """Returns a complete list of all available datetimes.""" @@ -153,7 +154,7 @@ def get_example( t0_dt: pd.Timestamp, #: Datetime of "now": The most recent obs. x_meters_center: Number, #: Centre, in OSGB coordinates. y_meters_center: Number, #: Centre, in OSGB coordinates. 
- ) -> Example: + ) -> DataSourceOutput: """Must be overridden by child classes.""" raise NotImplementedError() @@ -214,7 +215,7 @@ def data(self): def get_example( self, t0_dt: pd.Timestamp, x_meters_center: Number, y_meters_center: Number - ) -> Example: + ) -> DataSourceOutput: """ Get Example data @@ -293,5 +294,5 @@ def open(self) -> None: def _open_data(self) -> xr.DataArray: raise NotImplementedError() - def _put_data_into_example(self, selected_data: xr.DataArray) -> Example: + def _put_data_into_example(self, selected_data: xr.DataArray) -> DataSourceOutput: raise NotImplementedError() diff --git a/nowcasting_dataset/data_sources/datasource_output.py b/nowcasting_dataset/data_sources/datasource_output.py new file mode 100644 index 00000000..3b566360 --- /dev/null +++ b/nowcasting_dataset/data_sources/datasource_output.py @@ -0,0 +1,193 @@ +""" General Data Source output pydantic class. """ +from __future__ import annotations +from pydantic import BaseModel, Field +import pandas as pd +import xarray as xr +import numpy as np +from typing import List, Union +import logging +from datetime import datetime + +from nowcasting_dataset.utils import to_numpy + +logger = logging.getLogger(__name__) + + +class DataSourceOutput(BaseModel): + """General Data Source output pydantic class. + + Data source output classes should inherit from this class + """ + + class Config: + """ Allowed classes e.g. tensor.Tensor""" + + # TODO maybe there is a better way to do this + arbitrary_types_allowed = True + + batch_size: int = Field( + 0, + ge=0, + description="The size of this batch. If the batch size is 0, " + "then this item stores one data item i.e Example", + ) + + def to_numpy(self): + """Change to numpy""" + for k, v in self.dict().items(): + self.__setattr__(k, to_numpy(v)) + + def to_xr_data_array(self): + """ Change to xr DataArray""" + raise NotImplementedError() + + @staticmethod + def create_batch_from_examples(data): + """ + Join a list of data source items to a batch. + + Note that this only works for numpy objects, so objects are changed into numpy + """ + _ = [d.to_numpy() for d in data] + + # use the first item in the list, and then update each item + batch = data[0] + for k in batch.dict().keys(): + + # set batch size to the list of the items + if k == "batch_size": + batch.batch_size = len(data) + else: + + # get list of one variable from the list of data items. + one_variable_list = [d.__getattribute__(k) for d in data] + batch.__setattr__(k, np.stack(one_variable_list, axis=0)) + + return batch + + def split(self) -> List[DataSourceOutput]: + """ + Split the datasource from a batch to a list of items + + Returns: List of single data source items + """ + cls = self.__class__ + + items = [] + for batch_idx in range(self.batch_size): + d = {k: v[batch_idx] for k, v in self.dict().items() if k != "batch_size"} + d["batch_size"] = 0 + items.append(cls(**d)) + + return items + + def to_xr_dataset(self, **kwargs): + """ Make a xr dataset. Each data source needs to define this """ + raise NotImplementedError + + def from_xr_dataset(self): + """ Load from xr dataset. 
Each data source needs to define this """ + raise NotImplementedError + + def get_datetime_index(self): + """ Datetime index for the data """ + pass + + def select_time_period( + self, + keys: List[str], + history_minutes: int, + forecast_minutes: int, + t0_dt_of_first_example: Union[datetime, pd.Timestamp], + ): + """ + Selects a subset of data between the indicies of [start, end] for each key in keys + + Note that class is edited so nothing is returned. + + Args: + keys: Keys in batch to use + t0_dt_of_first_example: datetime of the current time (t0) in the first example of the batch + history_minutes: How many minutes of history to use + forecast_minutes: How many minutes of future data to use for forecasting + + """ + logger.debug( + f"Taking a sub-selection of the batch data based on a history minutes of {history_minutes} " + f"and forecast minutes of {forecast_minutes}" + ) + + start_time_of_first_batch = t0_dt_of_first_example - pd.to_timedelta( + f"{history_minutes} minute 30 second" + ) + end_time_of_first_example = t0_dt_of_first_example + pd.to_timedelta( + f"{forecast_minutes} minute 30 second" + ) + + logger.debug(f"New start time for first batch is {start_time_of_first_batch}") + logger.debug(f"New end time for first batch is {end_time_of_first_example}") + + start_time_of_first_example = to_numpy(start_time_of_first_batch) + end_time_of_first_example = to_numpy(end_time_of_first_example) + + if self.get_datetime_index() is not None: + + time_of_first_example = to_numpy(pd.to_datetime(self.get_datetime_index()[0])) + + # find the start and end index, that we will then use to slice the data + start_i, end_i = np.searchsorted( + time_of_first_example, [start_time_of_first_example, end_time_of_first_example] + ) + + # slice all the data + for key in keys: + if "time" in self.__getattribute__(key).dims: + self.__setattr__( + key, self.__getattribute__(key).isel(time=slice(start_i, end_i)) + ) + elif "time_30" in self.__getattribute__(key).dims: + self.__setattr__( + key, self.__getattribute__(key).isel(time_30=slice(start_i, end_i)) + ) + + logger.debug(f"{self.__class__.__name__} {key}: {self.__getattribute__(key).shape}") + + +def pad_nans(array, pad_width) -> np.ndarray: + """ Pad nans with nans""" + array = array.astype(np.float32) + return np.pad(array, pad_width, constant_values=np.NaN) + + +def pad_data( + data: DataSourceOutput, + pad_size: int, + one_dimensional_arrays: List[str], + two_dimensional_arrays: List[str], +): + """ + Pad (if necessary) so returned arrays are always of size + + data has two types of arrays in it, one dimensional arrays and two dimensional arrays + the one dimensional arrays are padded in that dimension + the two dimensional arrays are padded in the second dimension + + Note that class is edited so nothing is returned. 
+ + Args: + data: typed dictionary of data objects + pad_size: the maount that should be padded + one_dimensional_arrays: list of data items that should be padded by one dimension + two_dimensional_arrays: list of data tiems that should be padded in the third dimension (and more) + + """ + # Pad (if necessary) so returned arrays are always of size + pad_shape = (0, pad_size) # (before, after) + + for name in one_dimensional_arrays: + data.__setattr__(name, pad_nans(data.__getattribute__(name), pad_width=pad_shape)) + + for variable in two_dimensional_arrays: + data.__setattr__( + variable, pad_nans(data.__getattribute__(variable), pad_width=((0, 0), pad_shape)) + ) # (axis0, axis1) diff --git a/nowcasting_dataset/data_sources/datetime_data_source.py b/nowcasting_dataset/data_sources/datetime/datetime_data_source.py similarity index 94% rename from nowcasting_dataset/data_sources/datetime_data_source.py rename to nowcasting_dataset/data_sources/datetime/datetime_data_source.py index f51d3596..9d1088a2 100644 --- a/nowcasting_dataset/data_sources/datetime_data_source.py +++ b/nowcasting_dataset/data_sources/datetime/datetime_data_source.py @@ -7,7 +7,7 @@ from nowcasting_dataset import time as nd_time from nowcasting_dataset.data_sources.data_source import DataSource -from nowcasting_dataset.dataset.example import Example +from nowcasting_dataset.data_sources.datetime.datetime_model import Datetime @dataclass @@ -20,7 +20,7 @@ def __post_init__(self): def get_example( self, t0_dt: pd.Timestamp, x_meters_center: Number, y_meters_center: Number - ) -> Example: + ) -> Datetime: """ Get example data diff --git a/nowcasting_dataset/data_sources/datetime/datetime_model.py b/nowcasting_dataset/data_sources/datetime/datetime_model.py new file mode 100644 index 00000000..9629e9e6 --- /dev/null +++ b/nowcasting_dataset/data_sources/datetime/datetime_model.py @@ -0,0 +1,100 @@ +""" Model for output of datetime data """ +from pydantic import validator +import xarray as xr +import numpy as np +from nowcasting_dataset.data_sources.datasource_output import DataSourceOutput +from nowcasting_dataset.consts import Array, DATETIME_FEATURE_NAMES +from nowcasting_dataset.utils import coord_to_range + + +class Datetime(DataSourceOutput): + """ Model for output of datetime data """ + + hour_of_day_sin: Array #: Shape: [batch_size,] seq_length + hour_of_day_cos: Array + day_of_year_sin: Array + day_of_year_cos: Array + datetime_index: Array + + @property + def sequence_length(self): + """The sequence length of the pv data""" + return self.hour_of_day_sin.shape[-1] + + @validator("hour_of_day_cos") + def v_hour_of_day_cos(cls, v, values): + """ Validate 'hour_of_day_cos' """ + assert v.shape[-1] == values["hour_of_day_sin"].shape[-1] + return v + + @validator("day_of_year_sin") + def v_day_of_year_sin(cls, v, values): + """ Validate 'day_of_year_sin' """ + assert v.shape[-1] == values["hour_of_day_sin"].shape[-1] + return v + + @validator("day_of_year_cos") + def v_day_of_year_cos(cls, v, values): + """ Validate 'day_of_year_cos' """ + assert v.shape[-1] == values["hour_of_day_sin"].shape[-1] + return v + + @staticmethod + def fake(batch_size, seq_length_5): + """ Make a fake Datetime object """ + return Datetime( + batch_size=batch_size, + hour_of_day_sin=np.random.randn( + batch_size, + seq_length_5, + ), + hour_of_day_cos=np.random.randn( + batch_size, + seq_length_5, + ), + day_of_year_sin=np.random.randn( + batch_size, + seq_length_5, + ), + day_of_year_cos=np.random.randn( + batch_size, + seq_length_5, + ), + 
datetime_index=np.sort(np.random.randn(batch_size, seq_length_5))[:, ::-1].copy(), + # copy is needed as torch doesnt not support negative strides + ) + + def to_xr_dataset(self, _): + """ Make a xr dataset """ + individual_datasets = [] + for name in DATETIME_FEATURE_NAMES: + + var = self.__getattribute__(name) + + data = xr.DataArray( + var, + dims=["time"], + coords={"time": self.datetime_index}, + name=name, + ) + + ds = data.to_dataset() + ds = coord_to_range(ds, "time", prefix=None) + individual_datasets.append(ds) + + return xr.merge(individual_datasets) + + @staticmethod + def from_xr_dataset(xr_dataset): + """ Change xr dataset to model. If data does not exist, then return None """ + if "hour_of_day_sin" in xr_dataset.keys(): + return Datetime( + batch_size=xr_dataset["hour_of_day_sin"].shape[0], + hour_of_day_sin=xr_dataset["hour_of_day_sin"], + hour_of_day_cos=xr_dataset["hour_of_day_cos"], + day_of_year_sin=xr_dataset["day_of_year_sin"], + day_of_year_cos=xr_dataset["day_of_year_cos"], + datetime_index=xr_dataset["hour_of_day_sin"].time, + ) + else: + return None diff --git a/nowcasting_dataset/data_sources/gsp/gsp_data_source.py b/nowcasting_dataset/data_sources/gsp/gsp_data_source.py index c834efda..fe9c9212 100644 --- a/nowcasting_dataset/data_sources/gsp/gsp_data_source.py +++ b/nowcasting_dataset/data_sources/gsp/gsp_data_source.py @@ -15,19 +15,16 @@ import xarray as xr from nowcasting_dataset.consts import ( - GSP_ID, - GSP_YIELD, - GSP_X_COORDS, - GSP_Y_COORDS, DEFAULT_N_GSP_PER_EXAMPLE, - OBJECT_AT_CENTER, ) from nowcasting_dataset.data_sources.data_source import ImageDataSource from nowcasting_dataset.data_sources.gsp.eso import get_gsp_metadata_from_eso -from nowcasting_dataset.dataset.example import Example from nowcasting_dataset.geospatial import lat_lon_to_osgb from nowcasting_dataset.square import get_bounding_box_mask -from nowcasting_dataset.utils import scale_to_0_to_1, pad_data + +# from nowcasting_dataset.utils import scale_to_0_to_1, pad_data +from nowcasting_dataset.utils import scale_to_0_to_1 +from nowcasting_dataset.data_sources.gsp.gsp_model import GSP logger = logging.getLogger(__name__) @@ -160,7 +157,7 @@ def get_locations_for_batch( def get_example( self, t0_dt: pd.Timestamp, x_meters_center: Number, y_meters_center: Number - ) -> Example: + ) -> GSP: """ Get data example from one time point (t0_dt) and for x and y coords (x_meters_center), (y_meters_center). @@ -206,30 +203,18 @@ def get_example( gsp_y_coords = self.metadata[self.metadata["gsp_id"].isin(all_gsp_ids)].location_y # Save data into the Example dict... - example = Example( - t0_dt=t0_dt, - gsp_id=all_gsp_ids, - gsp_yield=selected_gsp_power, - x_meters_center=x_meters_center, - y_meters_center=y_meters_center, - gsp_x_coords=gsp_x_coords, - gsp_y_coords=gsp_y_coords, - gsp_datetime_index=selected_gsp_power.index, - ) - if self.get_center: - example[OBJECT_AT_CENTER] = "gsp" - - # Pad (if necessary) so returned arrays are always of size n_gsp_per_example. 
- pad_size = self.n_gsp_per_example - len(all_gsp_ids) - example = pad_data( - data=example, - one_dimensional_arrays=[GSP_ID, GSP_X_COORDS, GSP_Y_COORDS], - two_dimensional_arrays=[GSP_YIELD], - pad_size=pad_size, + gsp = GSP( + gsp_id=all_gsp_ids.values, + gsp_yield=selected_gsp_power.values, + gsp_x_coords=gsp_x_coords.values, + gsp_y_coords=gsp_y_coords.values, + gsp_datetime_index=selected_gsp_power.index.values, ) - return example + gsp.pad() + + return gsp def _get_central_gsp_id( self, diff --git a/nowcasting_dataset/data_sources/gsp/gsp_model.py b/nowcasting_dataset/data_sources/gsp/gsp_model.py new file mode 100644 index 00000000..d990c9d8 --- /dev/null +++ b/nowcasting_dataset/data_sources/gsp/gsp_model.py @@ -0,0 +1,184 @@ +""" Model for output of GSP data """ +from pydantic import Field, validator +import numpy as np +import xarray as xr + +from nowcasting_dataset.data_sources.datasource_output import DataSourceOutput, pad_data +from nowcasting_dataset.consts import Array + +from nowcasting_dataset.consts import ( + GSP_ID, + GSP_YIELD, + GSP_X_COORDS, + GSP_Y_COORDS, + GSP_DATETIME_INDEX, + DEFAULT_N_GSP_PER_EXAMPLE, +) +from nowcasting_dataset.time import make_random_time_vectors +import logging + +logger = logging.getLogger(__name__) + + +class GSP(DataSourceOutput): + """ Model for output of GSP data """ + + # Shape: [batch_size,] seq_length, width, height, channel + gsp_yield: Array = Field( + ..., + description=" GSP yield from all GSP in the region of interest (ROI). \ + : Includes central GSP system, which will always be the first entry. \ + : shape = [batch_size, ] seq_length, n_gsp_per_example", + ) + + #: GSP identification. + #: shape = [batch_size, ] n_pv_systems_per_example + gsp_id: Array = Field(..., description="gsp id from NG") + + gsp_datetime_index: Array = Field( + ..., + description="The datetime associated with the gsp data. shape = [batch_size, ] sequence length,", + ) + + gsp_x_coords: Array = Field( + ..., + description="The x (OSGB geo-spatial) coordinates of the gsp. " + "This is in fact the x centroid of the GSP region" + "Shape: [batch_size,] n_gsp_per_example", + ) + gsp_y_coords: Array = Field( + ..., + description="The y (OSGB geo-spatial) coordinates of the gsp. 
" + "This are in fact the y centroid of the GSP region" + "Shape: [batch_size,] n_gsp_per_example", + ) + + @property + def number_of_gsp(self): + """The number of Grid Supply Points in this example""" + return self.gsp_yield.shape[-1] + + @property + def sequence_length(self): + """The sequence length of the GSP PV power timeseries data""" + return self.gsp_yield.shape[-2] + + @validator("gsp_yield") + def gsp_yield_shape(cls, v, values): + """ Validate 'gsp_yield' """ + if values["batch_size"] > 0: + assert len(v.shape) == 3 + else: + assert len(v.shape) == 2 + return v + + @validator("gsp_x_coords") + def x_coordinates_shape(cls, v, values): + """ Validate 'gsp_x_coords' """ + assert v.shape[-1] == values["gsp_yield"].shape[-1] + return v + + @validator("gsp_y_coords") + def y_coordinates_shape(cls, v, values): + """ Validate 'gsp_y_coords' """ + assert v.shape[-1] == values["gsp_yield"].shape[-1] + return v + + @staticmethod + def fake(batch_size, seq_length_30, n_gsp_per_batch, time_30=None): + """ Make a fake GSP object """ + if time_30 is None: + _, _, time_30 = make_random_time_vectors( + batch_size=batch_size, seq_len_5_minutes=0, seq_len_30_minutes=seq_length_30 + ) + + return GSP( + batch_size=batch_size, + gsp_yield=np.random.randn( + batch_size, + seq_length_30, + n_gsp_per_batch, + ), + gsp_id=np.sort(np.random.randint(0, 340, (batch_size, n_gsp_per_batch))), + gsp_datetime_index=time_30, + gsp_x_coords=np.sort(np.random.randn(batch_size, n_gsp_per_batch)), + gsp_y_coords=np.sort(np.random.randn(batch_size, n_gsp_per_batch))[:, ::-1].copy(), + ) + # copy is needed as torch doesnt not support negative strides + + def pad(self, n_gsp_per_example: int = DEFAULT_N_GSP_PER_EXAMPLE): + """ + Pad out data + + Args: + n_gsp_per_example: The number of gsp's there are per example. + + Note that nothing is returned as the changes are made inplace. + """ + assert self.batch_size == 0, "Padding only works for batch_size=0, i.e one Example" + + pad_size = n_gsp_per_example - self.gsp_yield.shape[-1] + pad_data( + data=self, + one_dimensional_arrays=[GSP_ID, GSP_X_COORDS, GSP_Y_COORDS], + two_dimensional_arrays=[GSP_YIELD], + pad_size=pad_size, + ) + + def get_datetime_index(self) -> Array: + """ Get the datetime index of this data """ + return self.gsp_datetime_index + + def to_xr_dataset(self, i): + """ Make a xr dataset """ + logger.debug(f"Making xr dataset for batch {i}") + assert self.batch_size == 0 + + example_dim = {"example": np.array([i], dtype=np.int32)} + + # GSP + n_gsp = len(self.gsp_id) + + one_dataset = xr.DataArray(self.gsp_yield, dims=["time_30", "gsp"], name="gsp_yield") + one_dataset = one_dataset.to_dataset(name="gsp_yield") + one_dataset[GSP_DATETIME_INDEX] = xr.DataArray( + self.gsp_datetime_index, + dims=["time_30"], + coords=[np.arange(len(self.gsp_datetime_index))], + ) + + # GSP + for name in [GSP_ID, GSP_X_COORDS, GSP_Y_COORDS]: + + var = self.__getattribute__(name) + + one_dataset[name] = xr.DataArray( + var[None, :], + coords={ + **example_dim, + **{"gsp": np.arange(n_gsp, dtype=np.int32)}, + }, + dims=["example", "gsp"], + ) + + one_dataset[GSP_YIELD] = one_dataset[GSP_YIELD].astype(np.float32) + one_dataset[GSP_ID] = one_dataset[GSP_ID].astype(np.float32) + one_dataset[GSP_X_COORDS] = one_dataset[GSP_X_COORDS].astype(np.float32) + one_dataset[GSP_Y_COORDS] = one_dataset[GSP_Y_COORDS].astype(np.float32) + + return one_dataset + + @staticmethod + def from_xr_dataset(xr_dataset): + """ Change xr dataset to model. 
If data does not exist, then return None """ + if "gsp_yield" in xr_dataset.keys(): + return GSP( + batch_size=xr_dataset["gsp_yield"].shape[0], + gsp_yield=xr_dataset[GSP_YIELD], + gsp_id=xr_dataset[GSP_ID], + gsp_datetime_index=xr_dataset[GSP_DATETIME_INDEX], + gsp_x_coords=xr_dataset[GSP_X_COORDS], + gsp_y_coords=xr_dataset[GSP_Y_COORDS], + ) + else: + return None diff --git a/nowcasting_dataset/data_sources/metadata/metadata_data_source.py b/nowcasting_dataset/data_sources/metadata/metadata_data_source.py new file mode 100644 index 00000000..1af77596 --- /dev/null +++ b/nowcasting_dataset/data_sources/metadata/metadata_data_source.py @@ -0,0 +1,60 @@ +""" Datetime DataSource - add hour and year features """ +from dataclasses import dataclass +from numbers import Number +from typing import List, Tuple + +import pandas as pd +import numpy as np + +from nowcasting_dataset.data_sources.data_source import DataSource +from nowcasting_dataset.data_sources.metadata.metadata_model import Metadata +from nowcasting_dataset.utils import to_numpy + + +@dataclass +class MetadataDataSource(DataSource): + """Add metadata to the batch""" + + object_at_center: str = "GSP" + + def __post_init__(self): + """Post init""" + super().__post_init__() + + def get_example( + self, t0_dt: pd.Timestamp, x_meters_center: Number, y_meters_center: Number + ) -> Metadata: + """ + Get example data + + Args: + t0_dt: list of timestamps + x_meters_center: x center of patches - not needed + y_meters_center: y center of patches - not needed + + Returns: batch data of datetime features + + """ + if self.object_at_center == "GSP": + object_at_center_label = 1 + elif self.object_at_center == "PV": + object_at_center_label = 2 + else: + object_at_center_label = 0 + + return Metadata( + t0_dt=to_numpy(t0_dt), #: Shape: [batch_size,] + x_meters_center=np.array(x_meters_center), + y_meters_center=np.array(y_meters_center), + object_at_center_label=object_at_center_label, + ) + + def get_locations_for_batch( + self, t0_datetimes: pd.DatetimeIndex + ) -> Tuple[List[Number], List[Number]]: + """This method is not needed for MetadataDataSource""" + raise NotImplementedError() + + def datetime_index(self) -> pd.DatetimeIndex: + """This method is not needed for MetadataDataSource""" + raise NotImplementedError() diff --git a/nowcasting_dataset/data_sources/metadata/metadata_model.py b/nowcasting_dataset/data_sources/metadata/metadata_model.py new file mode 100644 index 00000000..2f97026e --- /dev/null +++ b/nowcasting_dataset/data_sources/metadata/metadata_model.py @@ -0,0 +1,74 @@ +""" Model for output of general/metadata data, useful for a batch """ +from typing import Union, List +import numpy as np +import xarray as xr +import torch +from pydantic import validator, Field + +from nowcasting_dataset.data_sources.datasource_output import DataSourceOutput +from nowcasting_dataset.time import make_random_time_vectors + +# seems to be a pandas dataseries + + +class Metadata(DataSourceOutput): + """Model for output of general/metadata data""" + + # TODO add descriptions + t0_dt: Union[xr.DataArray, np.ndarray, torch.Tensor, int] #: Shape: [batch_size,] + x_meters_center: Union[xr.DataArray, np.ndarray, torch.Tensor, int] + y_meters_center: Union[xr.DataArray, np.ndarray, torch.Tensor, int] + object_at_center_label: Union[xr.DataArray, np.ndarray, torch.Tensor, int] = Field( + ..., + description="What object is at the center of the batch data " + "0: Nothing at the center, " + "1: GSP system, " + "2: PV system", + ) + + @staticmethod + def 
fake(batch_size, t0_dt=None): + """Make a xr dataset""" + if t0_dt is None: + t0_dt, _, _ = make_random_time_vectors( + batch_size=batch_size, seq_len_5_minutes=0, seq_len_30_minutes=0 + ) + + return Metadata( + batch_size=batch_size, + t0_dt=t0_dt, + x_meters_center=np.random.randn( + batch_size, + ), + y_meters_center=np.random.randn( + batch_size, + ), + object_at_center_label=np.array([1] * batch_size), + ) + + def to_xr_dataset(self, i): + """Make a xr dataset""" + individual_datasets = [] + for name in ["t0_dt", "x_meters_center", "y_meters_center", "object_at_center_label"]: + + var = self.__getattribute__(name) + + example_dim = {"example": np.array([i], dtype=np.int32)} + + data = xr.DataArray([var], coords=example_dim, dims=["example"], name=name) + + ds = data.to_dataset() + individual_datasets.append(ds) + + return xr.merge(individual_datasets) + + @staticmethod + def from_xr_dataset(xr_dataset): + """Change xr dataset to model. If data does not exist, then return None""" + return Metadata( + batch_size=xr_dataset["t0_dt"].shape[0], + t0_dt=xr_dataset["t0_dt"], + x_meters_center=xr_dataset["x_meters_center"], + y_meters_center=xr_dataset["y_meters_center"], + object_at_center_label=xr_dataset["object_at_center_label"], + ) diff --git a/nowcasting_dataset/data_sources/nwp_data_source.py b/nowcasting_dataset/data_sources/nwp/nwp_data_source.py similarity index 93% rename from nowcasting_dataset/data_sources/nwp_data_source.py rename to nowcasting_dataset/data_sources/nwp/nwp_data_source.py index b8a182c2..cb12dcd6 100644 --- a/nowcasting_dataset/data_sources/nwp_data_source.py +++ b/nowcasting_dataset/data_sources/nwp/nwp_data_source.py @@ -3,7 +3,7 @@ from concurrent import futures from dataclasses import dataclass, InitVar from numbers import Number -from typing import Iterable, Optional, List +from typing import Iterable, Optional import numpy as np import pandas as pd @@ -11,7 +11,8 @@ from nowcasting_dataset import utils from nowcasting_dataset.data_sources.data_source import ZarrDataSource -from nowcasting_dataset.dataset.example import Example, to_numpy +from nowcasting_dataset.data_sources.nwp.nwp_model import NWP +from nowcasting_dataset.data_sources.datasource_output import DataSourceOutput _LOG = logging.getLogger(__name__) @@ -125,7 +126,7 @@ def get_batch( t0_datetimes: pd.DatetimeIndex, x_locations: Iterable[Number], y_locations: Iterable[Number], - ) -> List[Example]: + ) -> NWP: """ Get batch data @@ -177,21 +178,25 @@ def get_batch( t0_dt = t0_datetimes[i] selected_data = self._post_process_example(selected_data, t0_dt) - example = self._put_data_into_example(selected_data) + output: DataSourceOutput = self._put_data_into_example(selected_data) if self.convert_to_numpy: - example = to_numpy(example) - examples.append(example) - return examples + output.to_numpy() + examples.append(output) + + return DataSourceOutput.create_batch_from_examples(examples) def _open_data(self) -> xr.DataArray: return open_nwp(self.filename, consolidated=self.consolidated) - def _put_data_into_example(self, selected_data: xr.DataArray) -> Example: - return Example( + def _put_data_into_example(self, selected_data: xr.DataArray) -> NWP: + + return NWP( nwp=selected_data, nwp_x_coords=selected_data.x, nwp_y_coords=selected_data.y, nwp_target_time=selected_data.target_time, + nwp_init_time=np.array(selected_data.init_time.data), + nwp_channel_names=self.channels, # TODO perhaps could get this from selected data instead ) def _get_time_slice(self, t0_dt: pd.Timestamp) -> xr.DataArray: 
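As a rough shape check for the new NWP model introduced in the next file of this diff (illustrative only, not part of the patch), the `fake()` factory is expected to follow the [batch_size, channel, seq_length, width, height] layout described in the field docstrings:

from nowcasting_dataset.data_sources.nwp.nwp_model import NWP

nwp = NWP.fake(
    batch_size=4, seq_length_5=19, nwp_image_size_pixels=64, number_nwp_channels=10
)
assert nwp.nwp.shape == (4, 10, 19, 64, 64)  # [batch, channel, time, x, y]
assert nwp.sequence_length == 19
assert nwp.width == 64 and nwp.height == 64
assert nwp.nwp_x_coords.shape == (4, 64)
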
diff --git a/nowcasting_dataset/data_sources/nwp/nwp_model.py b/nowcasting_dataset/data_sources/nwp/nwp_model.py new file mode 100644 index 00000000..285f4958 --- /dev/null +++ b/nowcasting_dataset/data_sources/nwp/nwp_model.py @@ -0,0 +1,175 @@ +""" Model for output of NWP data """ +from pydantic import Field, validator +from typing import Union, List +import numpy as np +import xarray as xr +import torch + +from nowcasting_dataset.data_sources.datasource_output import DataSourceOutput +from nowcasting_dataset.consts import ( + Array, + NWP_VARIABLE_NAMES, + NWP_DATA, +) +from nowcasting_dataset.utils import coord_to_range +from nowcasting_dataset.time import make_random_time_vectors +import logging + +logger = logging.getLogger(__name__) + + +class NWP(DataSourceOutput): + """ Model for output of NWP data """ + + # Shape: [batch_size,] seq_length, width, height, channel + nwp: Array = Field( + ..., + description=" Numerical weather predictions (NWPs) \ + : Shape: [batch_size,] channel, seq_length, width, height", + ) + + nwp_x_coords: Array = Field( + ..., + description="The x (OSGB geo-spatial) coordinates of the NWP data. " + "Shape: [batch_size,] width", + ) + nwp_y_coords: Array = Field( + ..., + description="The y (OSGB geo-spatial) coordinates of the NWP data. " + "Shape: [batch_size,] height", + ) + + nwp_target_time: Array = Field( + ..., + description="Time index of nwp data at 5 minutes past the hour {0, 5, ..., 55}. " + "Datetimes become Unix epochs (UTC) represented as int64 just before being" + "passed into the ML model. The 'target time' is the time the NWP is _about_.", + ) + + nwp_init_time: Union[xr.DataArray, np.ndarray, torch.Tensor, int] = Field( + ..., description="The time when the nwp forecast was made" + ) + + nwp_channel_names: Union[List[List[str]], List[str], np.ndarray] = Field( + ..., description="List of the nwp channels" + ) + + @property + def width(self): + """The width of the nwp data""" + return self.nwp.shape[-2] + + @property + def height(self): + """The width of the nwp data""" + return self.nwp.shape[-1] + + @property + def sequence_length(self): + """The sequence length of the NWP timeseries""" + return self.nwp.shape[-3] + + @validator("nwp_x_coords") + def x_coordinates_shape(cls, v, values): + """ Validate 'nwp_x_coords' """ + assert v.shape[-1] == values["nwp"].shape[-2] + return v + + @validator("nwp_y_coords") + def y_coordinates_shape(cls, v, values): + """ Validate 'nwp_y_coords' """ + assert v.shape[-1] == values["nwp"].shape[-1] + return v + + @staticmethod + def fake(batch_size, seq_length_5, nwp_image_size_pixels, number_nwp_channels, time_5=None): + """ Create fake data """ + if time_5 is None: + _, time_5, _ = make_random_time_vectors( + batch_size=batch_size, seq_len_5_minutes=seq_length_5, seq_len_30_minutes=0 + ) + + return NWP( + batch_size=batch_size, + nwp=np.random.randn( + batch_size, + number_nwp_channels, + seq_length_5, + nwp_image_size_pixels, + nwp_image_size_pixels, + ), + nwp_x_coords=np.sort(np.random.randn(batch_size, nwp_image_size_pixels)), + nwp_y_coords=np.sort(np.random.randn(batch_size, nwp_image_size_pixels))[ + :, ::-1 + ].copy(), + # copy is needed as torch doesnt not support negative strides + nwp_target_time=time_5, + nwp_init_time=np.sort( + np.random.randn( + batch_size, + ) + ), + nwp_channel_names=[ + NWP_VARIABLE_NAMES[0:number_nwp_channels] for _ in range(batch_size) + ], + ) + + def get_datetime_index(self) -> Array: + """ Get the datetime index of this data """ + return self.nwp_target_time + + def 
to_xr_data_array(self): + """ Change to data_array. Sets the nwp field in-place.""" + self.nwp = xr.DataArray( + self.nwp, + dims=["variable", "target_time", "x", "y"], + coords={ + "variable": self.nwp_channel_names, + "target_time": self.nwp_target_time, + "init_time": self.nwp_init_time, + "x": self.nwp_x_coords, + "y": self.nwp_y_coords, + }, + ) + + def to_xr_dataset(self, i): + """ Make a xr dataset """ + logger.debug(f"Making xr dataset for batch {i}") + if type(self.nwp) != xr.DataArray: + self.to_xr_data_array() + + ds = self.nwp.to_dataset(name="nwp") + ds["nwp"] = ds["nwp"].astype(np.float32) + ds = ds.round(2) + + ds = ds.rename({"target_time": "time"}) + for dim in ["time", "x", "y"]: + ds = coord_to_range(ds, dim, prefix="nwp") + ds = ds.rename( + { + "variable": f"nwp_variable", + "x": "nwp_x", + "y": "nwp_y", + } + ) + + ds["nwp_x_coords"] = ds["nwp_x_coords"].astype(np.float32) + ds["nwp_y_coords"] = ds["nwp_y_coords"].astype(np.float32) + + return ds + + @staticmethod + def from_xr_dataset(xr_dataset): + """ Change xr dataset to model. If data does not exist, then return None """ + if NWP_DATA in xr_dataset.keys(): + return NWP( + batch_size=xr_dataset[NWP_DATA].shape[0], + nwp=xr_dataset[NWP_DATA], + nwp_channel_names=xr_dataset[NWP_DATA].nwp_variable.values, + nwp_init_time=xr_dataset[NWP_DATA].init_time, + nwp_target_time=xr_dataset["nwp_time_coords"], + nwp_x_coords=xr_dataset[NWP_DATA].nwp_x, + nwp_y_coords=xr_dataset[NWP_DATA].nwp_y, + ) + else: + return None diff --git a/nowcasting_dataset/data_sources/pv_data_source.py b/nowcasting_dataset/data_sources/pv/pv_data_source.py similarity index 91% rename from nowcasting_dataset/data_sources/pv_data_source.py rename to nowcasting_dataset/data_sources/pv/pv_data_source.py index 2bf4cfed..8f8bf558 100644 --- a/nowcasting_dataset/data_sources/pv_data_source.py +++ b/nowcasting_dataset/data_sources/pv/pv_data_source.py @@ -17,17 +17,11 @@ from nowcasting_dataset import geospatial, utils from nowcasting_dataset.consts import ( - PV_SYSTEM_ID, - PV_SYSTEM_ROW_NUMBER, - PV_SYSTEM_X_COORDS, - PV_SYSTEM_Y_COORDS, - PV_YIELD, DEFAULT_N_PV_SYSTEMS_PER_EXAMPLE, - OBJECT_AT_CENTER, ) from nowcasting_dataset.data_sources.data_source import ImageDataSource -from nowcasting_dataset.dataset.example import Example from nowcasting_dataset.square import get_bounding_box_mask +from nowcasting_dataset.data_sources.pv.pv_model import PV logger = logging.getLogger(__name__) @@ -202,7 +196,7 @@ def _get_all_pv_system_ids_in_roi( def get_example( self, t0_dt: pd.Timestamp, x_meters_center: Number, y_meters_center: Number - ) -> Example: + ) -> PV: """ Get Example data for PV data @@ -238,42 +232,21 @@ def get_example( pv_system_row_number = np.flatnonzero(self.pv_metadata.index.isin(all_pv_system_ids)) pv_system_x_coords = self.pv_metadata.location_x[all_pv_system_ids] pv_system_y_coords = self.pv_metadata.location_y[all_pv_system_ids] - # Save data into the Example dict... - example = Example( - t0_dt=t0_dt, - pv_system_id=all_pv_system_ids, - pv_system_row_number=pv_system_row_number, - pv_yield=selected_pv_power, - x_meters_center=x_meters_center, - y_meters_center=y_meters_center, - pv_system_x_coords=pv_system_x_coords, - pv_system_y_coords=pv_system_y_coords, - pv_datetime_index=selected_pv_power.index, - ) - if self.get_center: - example[OBJECT_AT_CENTER] = "pv" + # Save data into the PV object... - # Pad (if necessary) so returned arrays are always of size n_pv_systems_per_example. 
- pad_size = self.n_pv_systems_per_example - len(all_pv_system_ids) - - one_dimensional_arrays = [ - PV_SYSTEM_ID, - PV_SYSTEM_ROW_NUMBER, - PV_SYSTEM_X_COORDS, - PV_SYSTEM_Y_COORDS, - ] - - pad_nans_variables = [PV_YIELD] - - example = utils.pad_data( - data=example, - one_dimensional_arrays=one_dimensional_arrays, - two_dimensional_arrays=pad_nans_variables, - pad_size=pad_size, + pv = PV( + pv_system_id=all_pv_system_ids.values, + pv_system_row_number=pv_system_row_number, + pv_yield=selected_pv_power.values, + pv_system_x_coords=pv_system_x_coords.values, + pv_system_y_coords=pv_system_y_coords.values, + pv_datetime_index=selected_pv_power.index.values, ) - return example + pv.pad() + + return pv def get_locations_for_batch( self, t0_datetimes: pd.DatetimeIndex diff --git a/nowcasting_dataset/data_sources/pv/pv_model.py b/nowcasting_dataset/data_sources/pv/pv_model.py new file mode 100644 index 00000000..488806a6 --- /dev/null +++ b/nowcasting_dataset/data_sources/pv/pv_model.py @@ -0,0 +1,193 @@ +""" Model for output of PV data """ +from pydantic import Field, validator +import numpy as np +import xarray as xr + +from nowcasting_dataset.data_sources.datasource_output import DataSourceOutput, pad_data +from nowcasting_dataset.consts import ( + Array, + PV_YIELD, + PV_DATETIME_INDEX, + PV_SYSTEM_Y_COORDS, + PV_SYSTEM_X_COORDS, + PV_SYSTEM_ROW_NUMBER, + PV_SYSTEM_ID, + DEFAULT_N_PV_SYSTEMS_PER_EXAMPLE, +) +from nowcasting_dataset.time import make_random_time_vectors +import logging + +logger = logging.getLogger(__name__) + + +class PV(DataSourceOutput): + """ Model for output of PV data """ + + # Shape: [batch_size,] seq_length, width, height, channel + pv_yield: Array = Field( + ..., + description=" PV yield from all PV systems in the region of interest (ROI). \ + : Includes central PV system, which will always be the first entry. \ + : shape = [batch_size, ] seq_length, n_pv_systems_per_example", + ) + + #: PV identification. + #: shape = [batch_size, ] n_pv_systems_per_example + pv_system_id: Array = Field(..., description="PV system ID, e.g. from PVoutput.org") + pv_system_row_number: Array = Field(..., description="pv row number, made by OCF TODO") + + pv_datetime_index: Array = Field( + ..., + description="The datetime associated with the pv system data. shape = [batch_size, ] sequence length,", + ) + + pv_system_x_coords: Array = Field( + ..., + description="The x (OSGB geo-spatial) coordinates of the pv systems. " + "Shape: [batch_size,] n_pv_systems_per_example", + ) + pv_system_y_coords: Array = Field( + ..., + description="The y (OSGB geo-spatial) coordinates of the pv systems. 
" + "Shape: [batch_size,] n_pv_systems_per_example", + ) + + @property + def number_of_pv_systems(self): + """The number of pv systems""" + return self.pv_yield.shape[-1] + + @property + def sequence_length(self): + """The sequence length of the pv data""" + return self.pv_yield.shape[-2] + + @validator("pv_system_x_coords") + def x_coordinates_shape(cls, v, values): + """ Validate 'pv_system_x_coords' """ + assert v.shape[-1] == values["pv_yield"].shape[-1] + return v + + @validator("pv_system_y_coords") + def y_coordinates_shape(cls, v, values): + """ Validate 'pv_system_y_coords' """ + assert v.shape[-1] == values["pv_yield"].shape[-1] + return v + + @staticmethod + def fake(batch_size, seq_length_5, n_pv_systems_per_batch, time_5=None): + """ Create fake data """ + if time_5 is None: + _, time_5, _ = make_random_time_vectors( + batch_size=batch_size, seq_len_5_minutes=seq_length_5, seq_len_30_minutes=0 + ) + + return PV( + batch_size=batch_size, + pv_yield=np.random.randn( + batch_size, + seq_length_5, + n_pv_systems_per_batch, + ), + pv_system_id=np.sort(np.random.randint(0, 10000, (batch_size, n_pv_systems_per_batch))), + pv_system_row_number=np.sort( + np.random.randint(0, 1000, (batch_size, n_pv_systems_per_batch)) + ), + pv_datetime_index=time_5, + pv_system_x_coords=np.sort(np.random.randn(batch_size, n_pv_systems_per_batch)), + pv_system_y_coords=np.sort(np.random.randn(batch_size, n_pv_systems_per_batch))[ + :, ::-1 + ].copy(), # copy is needed as torch doesnt not support negative strides + ) + + def pad(self, n_pv_systems_per_example: int = DEFAULT_N_PV_SYSTEMS_PER_EXAMPLE): + """ + Pad out data + + Args: + n_pv_systems_per_example: The number of pv systems there are per example. + + Note that nothing is returned as the changes are made inplace. 
+ """ + assert self.batch_size == 0, "Padding only works for batch_size=0, i.e one Example" + + pad_size = n_pv_systems_per_example - self.pv_yield.shape[-1] + # Pad (if necessary) so returned arrays are always of size + pad_shape = (0, pad_size) # (before, after) + + one_dimensional_arrays = [ + PV_SYSTEM_ID, + PV_SYSTEM_ROW_NUMBER, + PV_SYSTEM_X_COORDS, + PV_SYSTEM_Y_COORDS, + ] + + pad_data( + data=self, + pad_size=pad_size, + one_dimensional_arrays=one_dimensional_arrays, + two_dimensional_arrays=[PV_YIELD], + ) + + def get_datetime_index(self) -> Array: + """ Get the datetime index of this data """ + return self.pv_datetime_index + + def to_xr_dataset(self, i): + """ Make a xr dataset """ + logger.debug(f"Making xr dataset for batch {i}") + assert self.batch_size == 0 + + example_dim = {"example": np.array([i], dtype=np.int32)} + + # PV + one_dataset = xr.DataArray(self.pv_yield, dims=["time", "pv_system"]) + one_dataset = one_dataset.to_dataset(name="pv_yield") + n_pv_systems = len(self.pv_system_id) + + one_dataset[PV_DATETIME_INDEX] = xr.DataArray( + self.pv_datetime_index, + dims=["time"], + coords=[np.arange(len(self.pv_datetime_index))], + ) + + # 1D + for name in [ + PV_SYSTEM_ID, + PV_SYSTEM_ROW_NUMBER, + PV_SYSTEM_X_COORDS, + PV_SYSTEM_Y_COORDS, + ]: + var = self.__getattribute__(name) + + one_dataset[name] = xr.DataArray( + var[None, :], + coords={ + **example_dim, + **{"pv_system": np.arange(n_pv_systems, dtype=np.int32)}, + }, + dims=["example", "pv_system"], + ) + + one_dataset["pv_system_id"] = one_dataset["pv_system_id"].astype(np.float32) + one_dataset["pv_system_row_number"] = one_dataset["pv_system_row_number"].astype(np.float32) + one_dataset["pv_system_x_coords"] = one_dataset["pv_system_x_coords"].astype(np.float32) + one_dataset["pv_system_y_coords"] = one_dataset["pv_system_y_coords"].astype(np.float32) + + return one_dataset + + @staticmethod + def from_xr_dataset(xr_dataset): + """ Change xr dataset to model. 
If data does not exist, then return None """ + if PV_YIELD in xr_dataset.keys(): + return PV( + batch_size=xr_dataset[PV_YIELD].shape[0], + pv_yield=xr_dataset[PV_YIELD], + pv_system_id=xr_dataset[PV_SYSTEM_ID], + pv_system_row_number=xr_dataset[PV_SYSTEM_ROW_NUMBER], + pv_datetime_index=xr_dataset[PV_DATETIME_INDEX], + pv_system_x_coords=xr_dataset[PV_SYSTEM_X_COORDS], + pv_system_y_coords=xr_dataset[PV_SYSTEM_Y_COORDS], + ) + else: + return None diff --git a/nowcasting_dataset/data_sources/satellite_data_source.py b/nowcasting_dataset/data_sources/satellite/satellite_data_source.py similarity index 94% rename from nowcasting_dataset/data_sources/satellite_data_source.py rename to nowcasting_dataset/data_sources/satellite/satellite_data_source.py index 19b22a53..5fbfdb21 100644 --- a/nowcasting_dataset/data_sources/satellite_data_source.py +++ b/nowcasting_dataset/data_sources/satellite/satellite_data_source.py @@ -3,14 +3,15 @@ from concurrent import futures from dataclasses import dataclass, InitVar from numbers import Number -from typing import Iterable, Optional, List +from typing import Iterable, Optional import numpy as np import pandas as pd import xarray as xr from nowcasting_dataset.data_sources.data_source import ZarrDataSource -from nowcasting_dataset.dataset.example import Example, to_numpy +from nowcasting_dataset.data_sources.satellite.satellite_model import Satellite +from nowcasting_dataset.data_sources.datasource_output import DataSourceOutput _LOG = logging.getLogger("nowcasting_dataset") @@ -107,7 +108,7 @@ def get_batch( t0_datetimes: pd.DatetimeIndex, x_locations: Iterable[Number], y_locations: Iterable[Number], - ) -> List[Example]: + ) -> Satellite: """ Get batch data @@ -150,17 +151,21 @@ def get_batch( example = self.get_example(t0_datetime, x_location, y_location) examples.append(example) + output = DataSourceOutput.create_batch_from_examples(examples) + if self.convert_to_numpy: - examples = [to_numpy(example) for example in examples] + output.to_numpy() self._cache = {} - return examples - def _put_data_into_example(self, selected_data: xr.DataArray) -> Example: - return Example( + return output + + def _put_data_into_example(self, selected_data: xr.DataArray) -> Satellite: + return Satellite( sat_data=selected_data, sat_x_coords=selected_data.x, sat_y_coords=selected_data.y, sat_datetime_index=selected_data.time, + sat_channel_names=self.channels, ) def _get_time_slice(self, t0_dt: pd.Timestamp) -> xr.DataArray: diff --git a/nowcasting_dataset/data_sources/satellite/satellite_model.py b/nowcasting_dataset/data_sources/satellite/satellite_model.py new file mode 100644 index 00000000..059e0359 --- /dev/null +++ b/nowcasting_dataset/data_sources/satellite/satellite_model.py @@ -0,0 +1,144 @@ +""" Model for output of satellite data """ +from pydantic import Field, validator +from typing import Union, List +import numpy as np +import xarray as xr + +from nowcasting_dataset.data_sources.datasource_output import DataSourceOutput +from nowcasting_dataset.consts import Array, SAT_VARIABLE_NAMES +from nowcasting_dataset.utils import coord_to_range +from nowcasting_dataset.time import make_random_time_vectors +import logging + +logger = logging.getLogger(__name__) + + +class Satellite(DataSourceOutput): + """ Model for output of satellite data """ + + # Shape: [batch_size,] seq_length, width, height, channel + sat_data: Array = Field( + ..., + description="Satellites images. 
Shape: [batch_size,] seq_length, width, height, channel", + ) + sat_x_coords: Array = Field( + ..., + description="aThe x (OSGB geo-spatial) coordinates of the satellite images. Shape: [batch_size,] width", + ) + sat_y_coords: Array = Field( + ..., + description="The y (OSGB geo-spatial) coordinates of the satellite images. Shape: [batch_size,] height", + ) + + sat_datetime_index: Array = Field( + ..., + description="Time index of satellite data at 5 minutes past the hour {0, 5, ..., 55}. " + "*not* the {4, 9, ..., 59} timings of the satellite imagery. " + "Datetimes become Unix epochs (UTC) represented as int64 just before being" + "passed into the ML model.", + ) + + sat_channel_names: Union[List[List[str]], List[str], np.ndarray] = Field( + ..., description="List of the satellite channels" + ) + + @validator("sat_x_coords") + def x_coordinates_shape(cls, v, values): + """ Validate 'sat_x_coords' """ + assert v.shape[-1] == values["sat_data"].shape[-3] + return v + + @validator("sat_y_coords") + def y_coordinates_shape(cls, v, values): + """ Validate 'sat_y_coords' """ + assert v.shape[-1] == values["sat_data"].shape[-2] + return v + + @staticmethod + def fake( + batch_size=32, + seq_length_5=19, + satellite_image_size_pixels=64, + number_sat_channels=7, + time_5=None, + ): + """ Create fake data """ + if time_5 is None: + _, time_5, _ = make_random_time_vectors( + batch_size=batch_size, seq_len_5_minutes=seq_length_5, seq_len_30_minutes=0 + ) + + s = Satellite( + batch_size=batch_size, + sat_data=np.random.randn( + batch_size, + seq_length_5, + satellite_image_size_pixels, + satellite_image_size_pixels, + number_sat_channels, + ), + sat_x_coords=np.sort(np.random.randn(batch_size, satellite_image_size_pixels)), + sat_y_coords=np.sort(np.random.randn(batch_size, satellite_image_size_pixels))[ + :, ::-1 + ].copy() + # copy is needed as torch doesnt not support negative strides + , + sat_datetime_index=time_5, + sat_channel_names=[ + SAT_VARIABLE_NAMES[0:number_sat_channels] for _ in range(batch_size) + ], + ) + + return s + + def get_datetime_index(self) -> Array: + """ Get the datetime index of this data """ + return self.sat_datetime_index + + def to_xr_dataset(self, i): + """ Make a xr dataset """ + logger.debug(f"Making xr dataset for batch {i}") + if type(self.sat_data) != xr.DataArray: + self.sat_data = xr.DataArray( + self.sat_data, + coords={ + "time": self.sat_datetime_index, + "x": self.sat_x_coords, + "y": self.sat_y_coords, + "variable": self.sat_channel_names, # assume all channels are the same + }, + ) + + ds = self.sat_data.to_dataset(name="sat_data") + ds["sat_data"] = ds["sat_data"].astype(np.int16) + ds = ds.round(2) + + for dim in ["time", "x", "y"]: + ds = coord_to_range(ds, dim, prefix="sat") + ds = ds.rename( + { + "variable": f"sat_variable", + "x": f"sat_x", + "y": f"sat_y", + } + ) + + ds["sat_x_coords"] = ds["sat_x_coords"].astype(np.int32) + ds["sat_y_coords"] = ds["sat_y_coords"].astype(np.int32) + + return ds + + @staticmethod + def from_xr_dataset(xr_dataset): + """ Change xr dataset to model. 
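For orientation, a rough standalone sketch of the coord_to_range idea used by to_xr_dataset above (assuming it swaps a real coordinate for an integer index and keeps the original values in a `*_coords` variable, so examples at different locations can later be concatenated; function and variable names here are illustrative):

import numpy as np
import xarray as xr


def coord_to_range_sketch(da: xr.DataArray, dim: str, prefix: str) -> xr.DataArray:
    original = da[dim].values
    da = da.assign_coords({dim: np.arange(len(original), dtype=np.int32)})
    da = da.assign_coords({f"{prefix}_{dim}_coords": (dim, original)})
    return da


image = xr.DataArray(
    np.random.rand(4, 4),
    dims=["x", "y"],
    coords={"x": [100.0, 200.0, 300.0, 400.0], "y": [5.0, 6.0, 7.0, 8.0]},
)
image = coord_to_range_sketch(image, "x", prefix="sat")
assert list(image["x"].values) == [0, 1, 2, 3]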
If data does not exist, then return None """ + if "sat_data" in xr_dataset.keys(): + return Satellite( + batch_size=xr_dataset["sat_data"].shape[0], + sat_data=xr_dataset["sat_data"], + sat_x_coords=xr_dataset["sat_x_coords"], + sat_y_coords=xr_dataset["sat_y_coords"], + sat_datetime_index=xr_dataset["sat_time_coords"], + sat_channel_names=xr_dataset["sat_data"].sat_variable.values, + ) + else: + return None diff --git a/nowcasting_dataset/data_sources/sun/sun_data_source.py b/nowcasting_dataset/data_sources/sun/sun_data_source.py index 1b91650f..755ec736 100644 --- a/nowcasting_dataset/data_sources/sun/sun_data_source.py +++ b/nowcasting_dataset/data_sources/sun/sun_data_source.py @@ -1,22 +1,17 @@ """ Loading Raw data """ from nowcasting_dataset.data_sources.data_source import DataSource -from nowcasting_dataset.dataset.example import Example -from nowcasting_dataset import time as nd_time from dataclasses import dataclass import pandas as pd from numbers import Number from typing import List, Tuple, Union, Optional from pathlib import Path import numpy as np -import io -import gcsfs -import xarray as xr -from nowcasting_dataset.geospatial import osgb_to_lat_lon from datetime import datetime -from nowcasting_dataset.consts import SUN_AZIMUTH_ANGLE, SUN_ELEVATION_ANGLE from nowcasting_dataset.data_sources.sun.raw_data_load_save import load_from_zarr, x_y_to_name +from nowcasting_dataset.data_sources.sun.sun_model import Sun + @dataclass class SunDataSource(DataSource): @@ -33,7 +28,7 @@ def __post_init__(self): def get_example( self, t0_dt: pd.Timestamp, x_meters_center: Number, y_meters_center: Number - ) -> Example: + ) -> Sun: """ Get example data from t0_dt and x and y xoordinates @@ -71,11 +66,13 @@ def get_example( azimuth = self.azimuth.loc[start_dt:end_dt][name] elevation = self.elevation.loc[start_dt:end_dt][name] - example = Example() - example[SUN_AZIMUTH_ANGLE] = azimuth.values - example[SUN_ELEVATION_ANGLE] = elevation.values + sun = Sun( + sun_azimuth_angle=azimuth.values, + sun_elevation_angle=elevation.values, + sun_datetime_index=azimuth.index.values, + ) - return example + return sun def _load(self): diff --git a/nowcasting_dataset/data_sources/sun/sun_model.py b/nowcasting_dataset/data_sources/sun/sun_model.py new file mode 100644 index 00000000..11562361 --- /dev/null +++ b/nowcasting_dataset/data_sources/sun/sun_model.py @@ -0,0 +1,115 @@ +""" Model for Sun features """ +from pydantic import Field, validator +import numpy as np +import xarray as xr + +from nowcasting_dataset.data_sources.datasource_output import DataSourceOutput +from nowcasting_dataset.consts import Array, SUN_AZIMUTH_ANGLE, SUN_ELEVATION_ANGLE +from nowcasting_dataset.utils import coord_to_range +from nowcasting_dataset.time import make_random_time_vectors +import logging + +logger = logging.getLogger(__name__) + + +class Sun(DataSourceOutput): + """ Model for Sun features """ + + sun_azimuth_angle: Array = Field( + ..., + description="PV azimuth angles i.e where the sun is. " "Shape: [batch_size,] seq_length", + ) + + sun_elevation_angle: Array = Field( + ..., + description="PV elevation angles i.e where the sun is. " "Shape: [batch_size,] seq_length", + ) + sun_datetime_index: Array + + @validator("sun_elevation_angle") + def elevation_shape(cls, v, values): + """ + Validate 'sun_elevation_angle'. 
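The shape checks above follow the usual pydantic v1 pattern (as used elsewhere in this diff) where `values` holds the fields that have already been validated, so a later field can be checked against an earlier one. A cut-down, hypothetical model showing the same idea:

import numpy as np
from pydantic import BaseModel, validator


class SunSketch(BaseModel):
    sun_azimuth_angle: np.ndarray
    sun_elevation_angle: np.ndarray

    class Config:
        arbitrary_types_allowed = True

    @validator("sun_elevation_angle")
    def elevation_matches_azimuth(cls, v, values):
        # azimuth was validated first, so it is available in `values`
        assert v.shape == values["sun_azimuth_angle"].shape
        return v


SunSketch(sun_azimuth_angle=np.zeros((4, 19)), sun_elevation_angle=np.zeros((4, 19)))  # passes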
+
+        This is done by checking the shape is the same as the "sun_azimuth_angle"
+        """
+        assert v.shape == values["sun_azimuth_angle"].shape
+        return v
+
+    @validator("sun_datetime_index")
+    def sun_datetime_index_shape(cls, v, values):
+        """
+        Validate 'sun_datetime_index'.
+
+        This is done by checking the last dimension is the same as the last dim of 'sun_azimuth_angle'
+        i.e. the time dimension
+        """
+        assert v.shape[-1] == values["sun_azimuth_angle"].shape[-1]
+        return v
+
+    @staticmethod
+    def fake(batch_size, seq_length_5, time_5=None):
+        """ Create fake data """
+        if time_5 is None:
+            _, time_5, _ = make_random_time_vectors(
+                batch_size=batch_size, seq_len_5_minutes=seq_length_5, seq_len_30_minutes=0
+            )
+
+        return Sun(
+            batch_size=batch_size,
+            sun_azimuth_angle=np.random.randn(
+                batch_size,
+                seq_length_5,
+            ),
+            sun_elevation_angle=np.random.randn(
+                batch_size,
+                seq_length_5,
+            ),
+            sun_datetime_index=time_5,
+        )
+
+    def get_datetime_index(self):
+        """ Get the datetime index of this data """
+        return self.sun_datetime_index
+
+    def to_xr_dataset(self, i):
+        """ Make a xr dataset """
+        logger.debug(f"Making xr dataset for batch {i}")
+        individual_datasets = []
+        for name in [SUN_AZIMUTH_ANGLE, SUN_ELEVATION_ANGLE]:
+
+            var = self.__getattribute__(name)
+
+            data = xr.DataArray(
+                var,
+                dims=["time"],
+                coords={"time": self.sun_datetime_index},
+                name=name,
+            )
+
+            ds = data.to_dataset()
+            ds = coord_to_range(ds, "time", prefix=None)
+            individual_datasets.append(ds)
+
+        data = xr.DataArray(
+            self.sun_datetime_index,
+            dims=["time"],
+            coords=[np.arange(len(self.sun_datetime_index))],
+        )
+        ds = data.to_dataset(name="sun_datetime_index")
+        individual_datasets.append(ds)
+
+        return xr.merge(individual_datasets)
+
+    @staticmethod
+    def from_xr_dataset(xr_dataset):
+        """ Change xr dataset to model.
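A standalone sketch of the merge pattern used by to_xr_dataset above: each field becomes its own small Dataset on a shared time index, and the pieces are combined at the end (variable names are illustrative):

import numpy as np
import xarray as xr

seq_len = 19
time_index = np.arange(seq_len)
pieces = []
for name in ["sun_azimuth_angle", "sun_elevation_angle"]:
    da = xr.DataArray(
        np.random.rand(seq_len), dims=["time"], coords={"time": time_index}, name=name
    )
    pieces.append(da.to_dataset())

merged = xr.merge(pieces)
assert set(merged.data_vars) == {"sun_azimuth_angle", "sun_elevation_angle"}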
If data does not exist, then return None """ + if SUN_AZIMUTH_ANGLE in xr_dataset.keys(): + return Sun( + batch_size=xr_dataset[SUN_AZIMUTH_ANGLE].shape[0], + sun_azimuth_angle=xr_dataset[SUN_AZIMUTH_ANGLE], + sun_elevation_angle=xr_dataset[SUN_ELEVATION_ANGLE], + sun_datetime_index=xr_dataset["sun_datetime_index"], + ) + else: + return None diff --git a/nowcasting_dataset/data_sources/topographic_data_source.py b/nowcasting_dataset/data_sources/topographic/topographic_data_source.py similarity index 97% rename from nowcasting_dataset/data_sources/topographic_data_source.py rename to nowcasting_dataset/data_sources/topographic/topographic_data_source.py index 94df5819..e396dd26 100644 --- a/nowcasting_dataset/data_sources/topographic_data_source.py +++ b/nowcasting_dataset/data_sources/topographic/topographic_data_source.py @@ -10,8 +10,9 @@ from nowcasting_dataset.consts import TOPOGRAPHIC_DATA from nowcasting_dataset.data_sources.data_source import ImageDataSource -from nowcasting_dataset.dataset.example import Example from nowcasting_dataset.geospatial import OSGB + +from nowcasting_dataset.data_sources.topographic.topographic_model import Topographic from nowcasting_dataset.utils import OpenData # Means computed with @@ -66,7 +67,7 @@ def __post_init__(self, image_size_pixels: int, meters_per_pixel: int): def get_example( self, t0_dt: pd.Timestamp, x_meters_center: Number, y_meters_center: Number - ) -> Example: + ) -> Topographic: """ Get a single example @@ -113,7 +114,7 @@ def get_example( return self._put_data_into_example(selected_data) - def _put_data_into_example(self, selected_data: xr.DataArray) -> Example: + def _put_data_into_example(self, selected_data: xr.DataArray) -> Topographic: """ Insert the data and coordinates into an Example @@ -123,7 +124,7 @@ def _put_data_into_example(self, selected_data: xr.DataArray) -> Example: Returns: Example containing the Topographic data """ - return Example( + return Topographic( topo_data=selected_data, topo_x_coords=selected_data.x, topo_y_coords=selected_data.y, diff --git a/nowcasting_dataset/data_sources/topographic/topographic_model.py b/nowcasting_dataset/data_sources/topographic/topographic_model.py new file mode 100644 index 00000000..06f4350f --- /dev/null +++ b/nowcasting_dataset/data_sources/topographic/topographic_model.py @@ -0,0 +1,112 @@ +""" Model for Topogrpahic features """ +from pydantic import Field, validator +import xarray as xr +import numpy as np +import logging +from nowcasting_dataset.data_sources.datasource_output import DataSourceOutput +from nowcasting_dataset.consts import Array + +from nowcasting_dataset.consts import TOPOGRAPHIC_DATA, TOPOGRAPHIC_X_COORDS, TOPOGRAPHIC_Y_COORDS +from nowcasting_dataset.utils import coord_to_range + +logger = logging.getLogger(__name__) + + +class Topographic(DataSourceOutput): + """ + Topographic/elevation map features. + """ + + # Shape: [batch_size,] width, height + topo_data: Array = Field( + ..., + description="Elevation map of the area covered by the satellite data. " + "Shape: [batch_size], width, height", + ) + topo_x_coords: Array = Field( + ..., + description="The x (OSGB geo-spatial) coordinates of the topographic images. Shape: [batch_size,] width", + ) + topo_y_coords: Array = Field( + ..., + description="The y (OSGB geo-spatial) coordinates of the topographic images. 
Shape: [batch_size,] height", + ) + + @property + def height(self): + """ The height of the topographic image """ + return self.topo_data.shape[-1] + + @property + def width(self): + """ The width of the topographic image """ + return self.topo_data.shape[-2] + + @validator("topo_x_coords") + def x_coordinates_shape(cls, v, values): + """ Validate 'topo_x_coords' """ + assert v.shape[-1] == values["topo_data"].shape[-2] + return v + + @validator("topo_y_coords") + def y_coordinates_shape(cls, v, values): + """ Validate 'topo_y_coords' """ + assert v.shape[-1] == values["topo_data"].shape[-1] + return v + + @staticmethod + def fake(batch_size, satellite_image_size_pixels): + """ Create fake data """ + return Topographic( + batch_size=batch_size, + topo_data=np.random.randn( + batch_size, + satellite_image_size_pixels, + satellite_image_size_pixels, + ), + topo_x_coords=np.sort(np.random.randn(batch_size, satellite_image_size_pixels)), + topo_y_coords=np.sort(np.random.randn(batch_size, satellite_image_size_pixels))[ + :, ::-1 + ].copy(), + # copy is needed as torch doesnt not support negative strides + ) + + def to_xr_dataset(self, i): + """ Make a xr dataset """ + logger.debug(f"Making xr dataset for batch {i}") + data = xr.DataArray( + self.topo_data, + coords={ + "x": self.topo_x_coords, + "y": self.topo_y_coords, + }, + ) + + ds = data.to_dataset(name=TOPOGRAPHIC_DATA) + for dim in ["x", "y"]: + ds = coord_to_range(ds, dim, prefix="topo") + ds = ds.rename( + { + "x": f"topo_x", + "y": f"topo_y", + } + ) + + ds[TOPOGRAPHIC_DATA] = ds[TOPOGRAPHIC_DATA].astype(np.float32) + ds[TOPOGRAPHIC_X_COORDS] = ds[TOPOGRAPHIC_X_COORDS].astype(np.float32) + ds[TOPOGRAPHIC_Y_COORDS] = ds[TOPOGRAPHIC_Y_COORDS].astype(np.float32) + + return ds + + @staticmethod + def from_xr_dataset(xr_dataset): + """ Change xr dataset to model. If data does not exist, then return None """ + if TOPOGRAPHIC_DATA in xr_dataset.keys(): + return Topographic( + batch_size=xr_dataset[TOPOGRAPHIC_DATA].shape[0], + topo_data=xr_dataset[TOPOGRAPHIC_DATA], + topo_x_coords=xr_dataset[TOPOGRAPHIC_DATA].topo_x, + topo_y_coords=xr_dataset[TOPOGRAPHIC_DATA].topo_y, + ) + else: + return None diff --git a/nowcasting_dataset/dataset/README.md b/nowcasting_dataset/dataset/README.md index f472084e..d7f59549 100644 --- a/nowcasting_dataset/dataset/README.md +++ b/nowcasting_dataset/dataset/README.md @@ -4,7 +4,7 @@ This folder contains the following files ## batch.py -Functions used to 'play with' batch data, where "batch data" is a List of Example objects; i.e. `List[Example]`. +'Batch' pydantic class, to hold batch data in. An 'Example' is one item in the batch. ## datamodule.py @@ -23,12 +23,6 @@ NetCDFDataset - torch.utils.data.Dataset: Use for loading pre-made batches NowcastingDataset - torch.utils.data.IterableDataset: Dataset for making batches -## example.py - -Main thing in here is a Typed Dictionary. This is used to store one element of data use for one step in the ML models. -There is also a validation function. See this file for documentation about exactly what data is available in each ML -training Example. 
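All of the *_model.py files above share the same from_xr_dataset contract: build the model only when its main variable is present in the dataset, otherwise return None so the corresponding field on an Example/Batch stays unset. A minimal illustration of that contract (a sketch, not the library code itself):

import numpy as np
import xarray as xr


def topographic_from_xr_sketch(ds: xr.Dataset):
    """Return a plain dict of the topographic fields, or None if they are absent."""
    if "topo_data" not in ds:
        return None
    return {"batch_size": ds["topo_data"].shape[0], "topo_data": ds["topo_data"].values}


assert topographic_from_xr_sketch(xr.Dataset()) is None
full = xr.Dataset({"topo_data": (("example", "x", "y"), np.random.rand(2, 8, 8))})
assert topographic_from_xr_sketch(full)["batch_size"] == 2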
- -## validatey.py +## validate.py Contains a class that can validate the prepare ml dataset diff --git a/nowcasting_dataset/dataset/batch.py b/nowcasting_dataset/dataset/batch.py index e684515e..7b6f8388 100644 --- a/nowcasting_dataset/dataset/batch.py +++ b/nowcasting_dataset/dataset/batch.py @@ -1,211 +1,226 @@ """ batch functions """ import logging from pathlib import Path -from typing import List, Optional +from typing import List, Optional, Union -import numpy as np import xarray as xr - -from nowcasting_dataset.consts import ( - GSP_ID, - GSP_YIELD, - GSP_X_COORDS, - GSP_Y_COORDS, - GSP_DATETIME_INDEX, - DATETIME_FEATURE_NAMES, - T0_DT, - TOPOGRAPHIC_DATA, - TOPOGRAPHIC_X_COORDS, - TOPOGRAPHIC_Y_COORDS, -) -from nowcasting_dataset.dataset.example import Example +from pydantic import BaseModel, Field + +from nowcasting_dataset.config.model import Configuration + +from nowcasting_dataset.data_sources.datetime.datetime_model import Datetime +from nowcasting_dataset.data_sources.metadata.metadata_model import Metadata +from nowcasting_dataset.data_sources.gsp.gsp_model import GSP +from nowcasting_dataset.data_sources.nwp.nwp_model import NWP +from nowcasting_dataset.data_sources.pv.pv_model import PV +from nowcasting_dataset.data_sources.satellite.satellite_model import Satellite +from nowcasting_dataset.data_sources.sun.sun_model import Sun +from nowcasting_dataset.data_sources.topographic.topographic_model import Topographic +from nowcasting_dataset.time import make_random_time_vectors from nowcasting_dataset.utils import get_netcdf_filename _LOG = logging.getLogger(__name__) -def write_batch_locally(batch: List[Example], batch_i: int, path: Path): +class Example(BaseModel): + """Single Data item""" + + metadata: Metadata + satellite: Optional[Satellite] + topographic: Optional[Topographic] + pv: Optional[PV] + sun: Optional[Sun] + gsp: Optional[GSP] + nwp: Optional[NWP] + datetime: Optional[Datetime] + + def change_type_to_numpy(self): + """Change data to numpy""" + for data_source in self.data_sources: + if data_source is not None: + data_source.to_numpy() + + @property + def data_sources(self): + """ The different data sources """ + return [ + self.satellite, + self.topographic, + self.pv, + self.sun, + self.gsp, + self.nwp, + self.datetime, + self.metadata, + ] + + +class Batch(Example): """ - Write a batch to a locally file - - Args: - batch: A batch of data - batch_i: The number of the batch - path: The directory to write the batch into. - """ - dataset = batch_to_dataset(batch) - dataset = fix_dtypes(dataset) - encoding = {name: {"compression": "lzf"} for name in dataset.data_vars} - filename = get_netcdf_filename(batch_i) - local_filename = path / filename - dataset.to_netcdf(local_filename, engine="h5netcdf", mode="w", encoding=encoding) + Batch data object. + Contains the following data sources + - gsp, satellite, topogrpahic, sun, pv, nwp and datetime. 
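A cut-down illustration of the Example pattern introduced above, where every data source is Optional and generic operations simply loop over whichever sources are present (the class names here are hypothetical):

from typing import List, Optional
from pydantic import BaseModel


class SourceSketch(BaseModel):
    name: str


class ExampleSketch(BaseModel):
    satellite: Optional[SourceSketch]
    pv: Optional[SourceSketch]

    @property
    def data_sources(self) -> List[Optional[SourceSketch]]:
        return [self.satellite, self.pv]


example = ExampleSketch(satellite=SourceSketch(name="sat"), pv=None)
present = [source for source in example.data_sources if source is not None]
assert len(present) == 1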
+ Also contains metadata of the class -def fix_dtypes(concat_ds): """ - TODO - """ - ds_dtypes = { - "example": np.int32, - "sat_x_coords": np.int32, - "sat_y_coords": np.int32, - "nwp": np.float32, - "nwp_x_coords": np.float32, - "nwp_y_coords": np.float32, - "pv_system_id": np.float32, - "pv_system_row_number": np.float32, - "pv_system_x_coords": np.float32, - "pv_system_y_coords": np.float32, - GSP_YIELD: np.float32, - GSP_ID: np.float32, - GSP_X_COORDS: np.float32, - GSP_Y_COORDS: np.float32, - TOPOGRAPHIC_X_COORDS: np.float32, - TOPOGRAPHIC_Y_COORDS: np.float32, - TOPOGRAPHIC_DATA: np.float32, - } - - for name, dtype in ds_dtypes.items(): - concat_ds[name] = concat_ds[name].astype(dtype) - - assert concat_ds["sat_data"].dtype == np.int16 - return concat_ds - - -def batch_to_dataset(batch: List[Example]) -> xr.Dataset: + + batch_size: int = Field( + ..., + g=0, + description="The size of this batch. If the batch size is 0, " + "then this item stores one data item", + ) + + def batch_to_dataset(self) -> xr.Dataset: + """Change batch to xr.Dataset so it can be saved and compressed""" + return batch_to_dataset(batch=self) + + @staticmethod + def load_batch_from_dataset(xr_dataset: xr.Dataset): + """Change xr.Datatset to Batch object""" + # get a list of data sources + data_sources_names = Example.__fields__.keys() + + # collect data sources + data_sources_dict = {} + for data_source_name in data_sources_names: + cls = Example.__fields__[data_source_name].type_ + data_sources_dict[data_source_name] = cls.from_xr_dataset(xr_dataset=xr_dataset) + + data_sources_dict["batch_size"] = data_sources_dict["metadata"].batch_size + + return Batch(**data_sources_dict) + + def split(self) -> List[Example]: + """Split batch into list of data items""" + # collect split data + split_data_dict = {} + for data_source in self.data_sources: + if data_source is not None: + cls = data_source.__class__.__name__.lower() + split_data_dict[cls] = data_source.split() + + # make in to Example objects + data_items = [] + for batch_idx in range(self.batch_size): + split_data_one_example_dict = {k: v[batch_idx] for k, v in split_data_dict.items()} + data_items.append(Example(**split_data_one_example_dict)) + + return data_items + + @staticmethod + def fake(configuration: Configuration = Configuration()): + """Create fake batch""" + process = configuration.process + + t0_dt, time_5, time_30 = make_random_time_vectors( + batch_size=process.batch_size, + seq_len_5_minutes=process.seq_len_5_minutes, + seq_len_30_minutes=process.seq_len_30_minutes, + ) + + return Batch( + batch_size=process.batch_size, + metadata=Metadata.fake(batch_size=process.batch_size, t0_dt=t0_dt), + satellite=Satellite.fake( + process.batch_size, + process.seq_len_5_minutes, + process.satellite_image_size_pixels, + len(process.nwp_channels), + time_5=time_5, + ), + topographic=Topographic.fake( + batch_size=process.batch_size, + satellite_image_size_pixels=process.satellite_image_size_pixels, + ), + pv=PV.fake( + batch_size=process.batch_size, + seq_length_5=process.seq_len_5_minutes, + n_pv_systems_per_batch=128, + time_5=time_5, + ), + sun=Sun.fake(batch_size=process.batch_size, seq_length_5=process.seq_len_5_minutes), + gsp=GSP.fake( + batch_size=process.batch_size, + seq_length_30=process.seq_len_30_minutes, + n_gsp_per_batch=32, + time_30=time_30, + ), + nwp=NWP.fake( + batch_size=process.batch_size, + seq_length_5=process.seq_len_5_minutes, + nwp_image_size_pixels=process.nwp_image_size_pixels, + number_nwp_channels=len(process.nwp_channels), + 
time_5=time_5, + ), + datetime=Datetime.fake( + batch_size=process.batch_size, seq_length_5=process.seq_len_5_minutes + ), + ) + + def save_netcdf(self, batch_i: int, path: Path): + """ + Save batch to netcdf file + + Args: + batch_i: the batch id, used to make the filename + path: the path where it will be saved. This can be local or in the cloud. + + """ + batch_xr = self.batch_to_dataset() + + encoding = {name: {"compression": "lzf"} for name in batch_xr.data_vars} + filename = get_netcdf_filename(batch_i) + local_filename = path / filename + batch_xr.to_netcdf(local_filename, engine="h5netcdf", mode="w", encoding=encoding) + + @staticmethod + def load_netcdf(local_netcdf_filename: Path): + """Load batch from netcdf file""" + netcdf_batch = xr.load_dataset(local_netcdf_filename) + + return Batch.load_batch_from_dataset(netcdf_batch) + + +def batch_to_dataset(batch: Batch) -> xr.Dataset: """Concat all the individual fields in an Example into a single Dataset. Args: batch: List of Example objects, which together constitute a single batch. """ datasets = [] - for i, example in enumerate(batch): - try: - individual_datasets = [] - example_dim = {"example": np.array([i], dtype=np.int32)} - for name in ["sat_data", "nwp"]: - ds = example[name].to_dataset(name=name) - short_name = name.replace("_data", "") - if name == "nwp": - ds = ds.rename({"target_time": "time"}) - for dim in ["time", "x", "y"]: - ds = coord_to_range(ds, dim, prefix=short_name) - ds = ds.rename( - { - "variable": f"{short_name}_variable", - "x": f"{short_name}_x", - "y": f"{short_name}_y", - } - ) - individual_datasets.append(ds) - - # Datetime features - for name in DATETIME_FEATURE_NAMES: - ds = example[name].rename(name).to_xarray().to_dataset().rename({"index": "time"}) - ds = coord_to_range(ds, "time", prefix=None) - individual_datasets.append(ds) - - # PV - one_dataset = xr.DataArray(example["pv_yield"], dims=["time", "pv_system"]) - one_dataset = one_dataset.to_dataset(name="pv_yield") - n_pv_systems = len(example["pv_system_id"]) - - # GSP - n_gsp = len(example[GSP_ID]) - one_dataset[GSP_YIELD] = xr.DataArray(example[GSP_YIELD], dims=["time_30", "gsp"]) - one_dataset[GSP_DATETIME_INDEX] = xr.DataArray( - example[GSP_DATETIME_INDEX], - dims=["time_30"], - coords=[np.arange(len(example[GSP_DATETIME_INDEX]))], - ) - - # Topographic - ds = example[TOPOGRAPHIC_DATA].to_dataset(name=TOPOGRAPHIC_DATA) - topo_name = "topo" - for dim in ["x", "y"]: - ds = coord_to_range(ds, dim, prefix=topo_name) - ds = ds.rename( - { - "x": f"{topo_name}_x", - "y": f"{topo_name}_y", - } - ) - individual_datasets.append(ds) - - # This will expand all dataarrays to have an 'example' dim. 
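A hypothetical round trip through the save_netcdf / load_netcdf methods defined above (the path is illustrative, and the h5netcdf backend assumed by the lzf encoding needs to be installed):

from pathlib import Path
from nowcasting_dataset.dataset.batch import Batch
from nowcasting_dataset.utils import get_netcdf_filename

path = Path("/tmp")
batch = Batch.fake()                      # random data built from the default Configuration
batch.save_netcdf(batch_i=0, path=path)   # writes path / get_netcdf_filename(0)
reloaded = Batch.load_netcdf(path / get_netcdf_filename(0))
assert reloaded.batch_size == batch.batch_size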
- # 0D - for name in ["x_meters_center", "y_meters_center", T0_DT]: - try: - one_dataset[name] = xr.DataArray( - [example[name]], coords=example_dim, dims=["example"] - ) - except Exception as e: - _LOG.error( - f"Could not make pv_yield data for {name} with example_dim={example_dim}" - ) - if name not in example.keys(): - _LOG.error(f"{name} not in data keys: {example.keys()}") - _LOG.error(e) - raise Exception - - # 1D - for name in [ - "pv_system_id", - "pv_system_row_number", - "pv_system_x_coords", - "pv_system_y_coords", - ]: - one_dataset[name] = xr.DataArray( - example[name][None, :], - coords={ - **example_dim, - **{"pv_system": np.arange(n_pv_systems, dtype=np.int32)}, - }, - dims=["example", "pv_system"], - ) - - # GSP - for name in [GSP_ID, GSP_X_COORDS, GSP_Y_COORDS]: - try: - one_dataset[name] = xr.DataArray( - example[name][None, :], - coords={ - **example_dim, - **{"gsp": np.arange(n_gsp, dtype=np.int32)}, - }, - dims=["example", "gsp"], - ) - except Exception as e: - _LOG.debug(f"Could not add {name} to dataset. {example[name].shape}") - _LOG.error(e) - raise e - - individual_datasets.append(one_dataset) - - # Merge - merged_ds = xr.merge(individual_datasets) - datasets.append(merged_ds) - - except Exception as e: - print(e) - _LOG.error(e) - raise Exception + + # loop over each item in the batch + for i, example in enumerate(batch.split()): + + individual_datasets = [] + + for data_source in example.data_sources: + if data_source is not None: + individual_datasets.append(data_source.to_xr_dataset(i)) + + # Merge + merged_ds = xr.merge(individual_datasets) + datasets.append(merged_ds) return xr.concat(datasets, dim="example") -def coord_to_range( - da: xr.DataArray, dim: str, prefix: Optional[str], dtype=np.int32 -) -> xr.DataArray: +def write_batch_locally(batch: Union[Batch, dict], batch_i: int, path: Path): """ - TODO - - TODO: Actually, I think this is over-complicated? I think we can - just strip off the 'coord' from the dimension. + Write a batch to a locally file + Args: + batch: A batch of data + batch_i: The number of the batch + path: The directory to write the batch into. 
""" - coord = da[dim] - da[dim] = np.arange(len(coord), dtype=dtype) - if prefix is not None: - da[f"{prefix}_{dim}_coords"] = xr.DataArray(coord, coords=[da[dim]], dims=[dim]) - return da + if type(batch): + batch = Batch(**batch) + + dataset = batch.batch_to_dataset() + encoding = {name: {"compression": "lzf"} for name in dataset.data_vars} + filename = get_netcdf_filename(batch_i) + local_filename = path / filename + dataset.to_netcdf(local_filename, engine="h5netcdf", mode="w", encoding=encoding) diff --git a/nowcasting_dataset/dataset/datamodule.py b/nowcasting_dataset/dataset/datamodule.py index 12ff765a..2d089e7b 100644 --- a/nowcasting_dataset/dataset/datamodule.py +++ b/nowcasting_dataset/dataset/datamodule.py @@ -15,6 +15,7 @@ from nowcasting_dataset import utils from nowcasting_dataset.data_sources.gsp.gsp_data_source import GSPDataSource from nowcasting_dataset.data_sources.sun.sun_data_source import SunDataSource +from nowcasting_dataset.data_sources.metadata.metadata_data_source import MetadataDataSource from nowcasting_dataset.dataset import datasets from nowcasting_dataset.dataset.split.split import split_data, SplitMethod @@ -205,6 +206,15 @@ def prepare_data(self) -> None: ) self.data_sources.append(self.datetime_data_source) + self.data_sources.append( + MetadataDataSource( + history_minutes=self.history_minutes, + forecast_minutes=self.forecast_minutes, + convert_to_numpy=self.convert_to_numpy, + object_at_center="GSP", + ) + ) + def setup(self, stage="fit"): """Split data, etc. diff --git a/nowcasting_dataset/dataset/datasets.py b/nowcasting_dataset/dataset/datasets.py index 6739b356..9f160350 100644 --- a/nowcasting_dataset/dataset/datasets.py +++ b/nowcasting_dataset/dataset/datasets.py @@ -29,13 +29,12 @@ SATELLITE_DATETIME_INDEX, NWP_TARGET_TIME, PV_DATETIME_INDEX, - DATETIME_FEATURE_NAMES, DEFAULT_REQUIRED_KEYS, - T0_DT, ) -from nowcasting_dataset.data_sources.satellite_data_source import SAT_VARIABLE_NAMES -from nowcasting_dataset.dataset import example -from nowcasting_dataset.utils import set_fsspec_for_multiprocess +from nowcasting_dataset.data_sources.satellite.satellite_data_source import SAT_VARIABLE_NAMES + +from nowcasting_dataset.utils import set_fsspec_for_multiprocess, to_numpy +from nowcasting_dataset.dataset.batch import Batch logger = logging.getLogger(__name__) @@ -170,7 +169,7 @@ def __len__(self): """ Length of dataset """ return self.n_batches - def __getitem__(self, batch_idx: int) -> example.Example: + def __getitem__(self, batch_idx: int) -> Batch: """Returns a whole batch at once. 
Args: @@ -205,19 +204,21 @@ def __getitem__(self, batch_idx: int) -> example.Example: else: local_netcdf_filename = remote_netcdf_filename - netcdf_batch = xr.load_dataset(local_netcdf_filename) + batch = Batch.load_netcdf(local_netcdf_filename) + # netcdf_batch = xr.load_dataset(local_netcdf_filename) if self.cloud != "local": os.remove(local_netcdf_filename) - batch = example.xr_to_example(batch_xr=netcdf_batch, required_keys=self.required_keys) + # batch = example.xr_to_example(batch_xr=netcdf_batch, required_keys=self.required_keys) + # Todo this may should be done when the data is created if SATELLITE_DATA in self.required_keys: - sat_data = batch[SATELLITE_DATA] + sat_data = batch.satellite.sat_data if sat_data.dtype == np.int16: sat_data = sat_data.astype(np.float32) sat_data = sat_data - SAT_MEAN - sat_data /= SAT_STD - batch[SATELLITE_DATA] = sat_data + sat_data = sat_data / SAT_STD + batch.satellite.sat_data = sat_data if self.select_subset_data: batch = subselect_data( @@ -228,7 +229,7 @@ def __getitem__(self, batch_idx: int) -> example.Example: current_timestep_index=self.current_timestep_5_index, ) - batch = example.to_numpy(batch) + batch.change_type_to_numpy() return batch @@ -312,12 +313,13 @@ def _get_batch(self) -> torch.Tensor: t0_datetimes = self._get_t0_datetimes_for_batch() x_locations, y_locations = self._get_locations_for_batch(t0_datetimes) - examples = None + examples = {} n_threads = len(self.data_sources) with futures.ThreadPoolExecutor(max_workers=n_threads) as executor: # Submit tasks to the executor. future_examples_per_source = [] for data_source in self.data_sources: + future_examples = executor.submit( data_source.get_batch, t0_datetimes=t0_datetimes, @@ -329,13 +331,17 @@ def _get_batch(self) -> torch.Tensor: # Collect results from each thread. for future_examples in future_examples_per_source: examples_from_source = future_examples.result() - if examples is None: - examples = examples_from_source - else: - for i in range(self.batch_size): - examples[i].update(examples_from_source[i]) - return self.collate_fn(examples) + # print(type(examples_from_source)) + name = type(examples_from_source).__name__.lower() + examples[name] = examples_from_source.dict() + + examples["batch_size"] = len(t0_datetimes) + + b = Batch(**examples) + + # return as dictionary because .... # TODO + return b.dict() def _get_t0_datetimes_for_batch(self) -> pd.DatetimeIndex: # Pick random datetimes. 
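Since _get_batch and the fake datasets hand batches to the DataLoader as plain dicts, downstream code can rebuild the validated pydantic object, as ValidatorDataset does; a minimal sketch:

from nowcasting_dataset.dataset.batch import Batch

batch_dict = Batch.fake().dict()  # what _get_batch / FakeDataset yield to the DataLoader
batch = Batch(**batch_dict)       # dict -> validated Batch, re-running the shape validators
assert batch.batch_size == batch_dict["batch_size"]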
@@ -371,40 +377,13 @@ def worker_init_fn(worker_id): dataset_obj.per_worker_init(worker_info.id) -def select_time_period( - batch: example.Example, - keys: List[str], - time_of_first_example: pd.DatetimeIndex, - start_time: xr.DataArray, - end_time: xr.DataArray, -) -> example.Example: - """ - Selects a subset of data between the indicies of [start, end] for each key in keys - - Args: - batch: Example containing the data - keys: Keys in batch to use - time_of_first_example: Datetime of the current time in the first example of the batch - start_time: Start time DataArray - end_time: End time DataArray - - Returns: - Example containing the subselected data - """ - start_i, end_i = np.searchsorted(time_of_first_example, [start_time.data, end_time.data]) - for key in keys: - batch[key] = batch[key].isel(time=slice(start_i, end_i)) - - return batch - - def subselect_data( - batch: example.Example, + batch: Batch, required_keys: Union[Tuple[str], List[str]], history_minutes: int, forecast_minutes: int, current_timestep_index: Optional[int] = None, -) -> example.Example: +) -> Batch: """ Subselects the data temporally. This function selects all data within the time range [t0 - history_minutes, t0 + forecast_minutes] @@ -423,80 +402,68 @@ def subselect_data( f"and forecast minutes if {forecast_minutes}" ) - # We are subsetting the data - date_time_index_to_use = ( - SATELLITE_DATETIME_INDEX if SATELLITE_DATA in required_keys else NWP_TARGET_TIME - ) - - # t0_dt or if not available use a different datetime index - if T0_DT in batch.keys(): - current_time_of_first_batch = batch[T0_DT][0] + # We are subsetting the data, so we need to select the t0_dt, i.e the time now for eahc Example. + # We infact only need this from the first example in each batch + if current_timestep_index is None: + # t0_dt or if not available use a different datetime index + t0_dt_of_first_example = batch.metadata.t0_dt[0].values else: - current_time_of_first_batch = batch[date_time_index_to_use].isel( - time=current_timestep_index - )[0] - - # Datetimes are in seconds, so just need to convert minutes to second + 30sec buffer - # Only need to do it for the first example in the batch, as masking indicies should be the same for all of them - # The extra 30 seconds is added to ensure that the first and last timestep are always contained - # within the [start_time, end_time] range - start_time = current_time_of_first_batch - pd.to_timedelta( - f"{history_minutes} minute 30 second" - ) - end_time = current_time_of_first_batch + pd.to_timedelta(f"{forecast_minutes} minute 30 second") - used_datetime_features = [k for k in DATETIME_FEATURE_NAMES if k in required_keys] - if SATELLITE_DATA in required_keys: - batch = select_time_period( - batch, - keys=[SATELLITE_DATA, SATELLITE_DATETIME_INDEX] + used_datetime_features, - time_of_first_example=batch[SATELLITE_DATETIME_INDEX][0].data, - start_time=start_time, - end_time=end_time, - ) - _LOG.debug( - f"Sat Datetime Shape: {batch[SATELLITE_DATETIME_INDEX].shape} Sat Data Shape: {batch[SATELLITE_DATA].shape}" + if SATELLITE_DATA in required_keys: + t0_dt_of_first_example = batch.satellite.sat_datetime_index[ + 0, current_timestep_index + ].values + else: + t0_dt_of_first_example = batch.satellite.sat_datetime_index[ + 0, current_timestep_index + ].values + + # make this a datetime object + t0_dt_of_first_example = pd.to_datetime(t0_dt_of_first_example) + + if batch.satellite is not None: + batch.satellite.select_time_period( + keys=[SATELLITE_DATA, SATELLITE_DATETIME_INDEX], + 
history_minutes=history_minutes, + forecast_minutes=forecast_minutes, + t0_dt_of_first_example=t0_dt_of_first_example, ) # Now for NWP, if used - if NWP_DATA in required_keys: - batch = select_time_period( - batch, - keys=[NWP_DATA, NWP_TARGET_TIME] + used_datetime_features - if SATELLITE_DATA not in required_keys - else [NWP_DATA, NWP_TARGET_TIME], - time_of_first_example=batch[NWP_TARGET_TIME][0].data, - start_time=start_time, - end_time=end_time, - ) - _LOG.debug( - f"NWP Datetime Shape: {batch[NWP_TARGET_TIME].shape} NWP Data Shape: {batch[NWP_DATA].shape}" + if batch.nwp is not None: + batch.nwp.select_time_period( + keys=[NWP_DATA, NWP_TARGET_TIME], + history_minutes=history_minutes, + forecast_minutes=forecast_minutes, + t0_dt_of_first_example=t0_dt_of_first_example, ) - - # Now GSP, if used - if GSP_YIELD in required_keys and GSP_DATETIME_INDEX in batch: - batch = select_time_period( - batch, + # + # Now for GSP, if used + if batch.gsp is not None: + batch.gsp.select_time_period( keys=[GSP_DATETIME_INDEX, GSP_YIELD], - time_of_first_example=batch[GSP_DATETIME_INDEX][0].data, - start_time=start_time, - end_time=end_time, - ) - _LOG.debug( - f"GSP Datetime Shape: {batch[GSP_DATETIME_INDEX].shape} GSP Data Shape: {batch[GSP_YIELD].shape}" + history_minutes=history_minutes, + forecast_minutes=forecast_minutes, + t0_dt_of_first_example=t0_dt_of_first_example, ) - # Now PV systems, if used - if PV_YIELD in required_keys and PV_DATETIME_INDEX in batch: - batch = select_time_period( - batch, - keys=[PV_DATETIME_INDEX, PV_YIELD, SUN_ELEVATION_ANGLE, SUN_AZIMUTH_ANGLE], - time_of_first_example=batch[PV_DATETIME_INDEX][0].data, - start_time=start_time, - end_time=end_time, + # Now for PV, if used + if batch.pv is not None: + batch.pv.select_time_period( + keys=[PV_DATETIME_INDEX, PV_YIELD], + history_minutes=history_minutes, + forecast_minutes=forecast_minutes, + t0_dt_of_first_example=t0_dt_of_first_example, ) - _LOG.debug( - f"PV Datetime Shape: {batch[PV_DATETIME_INDEX].shape} PV Data Shape: {batch[PV_YIELD].shape}" - f" PV Azimuth Shape: {batch[SUN_ELEVATION_ANGLE].shape} PV Elevation Shape: {batch[SUN_AZIMUTH_ANGLE].shape}" + + # Now for SUN, if used + if batch.sun is not None: + batch.sun.select_time_period( + keys=[SUN_ELEVATION_ANGLE, SUN_AZIMUTH_ANGLE], + history_minutes=history_minutes, + forecast_minutes=forecast_minutes, + t0_dt_of_first_example=t0_dt_of_first_example, ) + # DATETIME TODO + return batch diff --git a/nowcasting_dataset/dataset/example.py b/nowcasting_dataset/dataset/example.py deleted file mode 100644 index 96cf9101..00000000 --- a/nowcasting_dataset/dataset/example.py +++ /dev/null @@ -1,143 +0,0 @@ -""" Example Data Class """ -from numbers import Number -from typing import TypedDict, List - -import numpy as np -import pandas as pd - -from nowcasting_dataset.consts import * - - -class Example(TypedDict): - """Simple class for structuring data for each ML example. - - Using typing.TypedDict gives us several advantages: - 1. Single 'source of truth' for the type and documentation of the fields - in each example. - 2. A static type checker can check the types are correct. - - Instead of TypedDict, we could use typing.NamedTuple, - which would provide runtime checks, but the deal-breaker with Tuples is - that they're immutable so we cannot change the values in the transforms. - """ - - # timestamp of now. 
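A standalone sketch of the time-window selection that select_time_period performs for each data source, mirroring the deleted helper above: keep timestamps within [t0 - history_minutes, t0 + forecast_minutes], with the 30 second buffer so the boundary steps are never dropped:

import pandas as pd


def select_window(times: pd.DatetimeIndex, t0: pd.Timestamp, history_minutes: int, forecast_minutes: int):
    """Return the slice of `times` inside the (buffered) history/forecast window around t0."""
    start = t0 - pd.to_timedelta(f"{history_minutes} minute 30 second")
    end = t0 + pd.to_timedelta(f"{forecast_minutes} minute 30 second")
    start_i, end_i = times.searchsorted([start, end])
    return times[start_i:end_i]


times = pd.date_range("2021-01-01 12:00", periods=19, freq="5T")
t0 = times[9]
# 6 history steps + t0 + 3 forecast steps = 10 timestamps
assert len(select_window(times, t0, history_minutes=30, forecast_minutes=15)) == 10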
In this data object there will be both - # - historic data before this timestamp, - # - and future data after this timestamp - # shape is [batch_size,] - t0_dt = Array - - # IMAGES - # Shape: [batch_size,] seq_length, width, height, channel - sat_data: Array - sat_x_coords: Array #: OSGB geo-spatial coordinates. - sat_y_coords: Array - - # Topographic data - # Elevation map of the area covered by the satellite data - # Shape: [batch_size,] width, height - topo_data: Array - topo_x_coords: Array - topo_y_coords: Array - - #: PV yield from all PV systems in the region of interest (ROI). - #: Includes central PV system, which will always be the first entry. - #: shape = [batch_size, ] seq_length, n_pv_systems_per_example - pv_yield: Array - - # PV azimuth and elevation angles i.e where the sun is. - #: shape = [batch_size, ] seq_length - sun_azimuth_angle: Array - sun_elevation_angle: Array - - #: PV identification. - #: shape = [batch_size, ] n_pv_systems_per_example - pv_system_id: Array - pv_system_row_number: Array #: In the range [0, len(pv_metadata)]. - - #: PV system geographical location (in OSGB coords). - #: shape = [batch_size, ] n_pv_systems_per_example - pv_system_x_coords: Array - pv_system_y_coords: Array - pv_datetime_index: Array #: shape = [batch_size, ] seq_length - - # Numerical weather predictions (NWPs) - nwp: Array #: Shape: [batch_size,] channel, seq_length, width, height - nwp_x_coords: Array - nwp_y_coords: Array - - # METADATA - x_meters_center: Number #: In OSGB coordinations - y_meters_center: Number #: In OSGB coordinations - - # Datetimes (abbreviated to "dt") - # At 5-minutes past the hour {0, 5, ..., 55} - # *not* the {4, 9, ..., 59} timings of the satellite imagery. - # Datetimes become Unix epochs (UTC) represented as int64 just before being - # passed into the ML model. - # t0_dt is 'now', the most recent observation. - sat_datetime_index: Array - nwp_target_time: Array - hour_of_day_sin: Array #: Shape: [batch_size,] seq_length - hour_of_day_cos: Array - day_of_year_sin: Array - day_of_year_cos: Array - - #: GSP PV yield from all GSP in the region of interest (ROI). - # : Includes central GSP, which will always be the first entry. This will be a numpy array of values. - gsp_yield: Array #: shape = [batch_size, ] seq_length, n_gsp_systems_per_example - # GSP identification. - gsp_id: Array #: shape = [batch_size, ] n_gsp_per_example - #: GSP geographical location (in OSGB coords). - gsp_x_coords: Array #: shape = [batch_size, ] n_gsp_per_example - gsp_y_coords: Array #: shape = [batch_size, ] n_gsp_per_example - gsp_datetime_index: Array #: shape = [batch_size, ] seq_length - - # if the centroid type is a GSP, or a PV system - object_at_center: str #: shape = [batch_size, ] - - -def xr_to_example(batch_xr: xr.core.dataset.Dataset, required_keys: List[str]) -> Example: - """ - Change xr dataset to Example - - Args: - batch_xr: batch data in xarray format - required_keys: the keys that are need - - Returns: Example object of the xarray data - - """ - batch = Example( - sat_datetime_index=batch_xr.sat_time_coords, - nwp_target_time=batch_xr.nwp_time_coords, - ) - for key in required_keys: - try: - batch[key] = batch_xr[key] - except KeyError: - pass - - return batch - - -def to_numpy(example: Example) -> Example: - """ - Change items in Example to numpy objects - """ - for key, value in example.items(): - if isinstance(value, xr.DataArray): - # TODO: Use to_numpy() or as_numpy(), introduced in xarray v0.19? 
- value = value.data - - if isinstance(value, (pd.Series, pd.DataFrame)): - value = value.values - elif isinstance(value, pd.DatetimeIndex): - value = value.values.astype("datetime64[s]").astype(np.int32) - elif isinstance(value, pd.Timestamp): - value = np.int32(value.timestamp()) - elif isinstance(value, np.ndarray) and np.issubdtype(value.dtype, np.datetime64): - value = value.astype("datetime64[s]").astype(np.int32) - - example[key] = value - return example diff --git a/nowcasting_dataset/dataset/validate.py b/nowcasting_dataset/dataset/validate.py index 80a063d8..0e23202e 100644 --- a/nowcasting_dataset/dataset/validate.py +++ b/nowcasting_dataset/dataset/validate.py @@ -28,7 +28,9 @@ TOPOGRAPHIC_Y_COORDS, ) from nowcasting_dataset.dataset.datasets import NetCDFDataset, logger -from nowcasting_dataset.dataset.example import Example + +# from nowcasting_dataset.dataset.example import Example +from nowcasting_dataset.dataset.batch import Batch class ValidatorDataset: @@ -65,6 +67,10 @@ def validate(self): for batch_idx, batch in enumerate(self.batches): logger.info(f"Validating batch {batch_idx}") + # change dict to Batch, this does some validation + if type(batch) == dict: + batch = Batch(**batch) + all_day_from_batch_unique = self.validate_and_get_day_datetimes_for_one_batch( batch=batch ) @@ -75,7 +81,7 @@ def validate(self): self.day_datetimes = day_datetimes - def validate_and_get_day_datetimes_for_one_batch(self, batch): + def validate_and_get_day_datetimes_for_one_batch(self, batch: Batch): """ For one batch, validate, and return the day datetimes in that batch @@ -85,12 +91,10 @@ def validate_and_get_day_datetimes_for_one_batch(self, batch): Returns: list of days that the batch has data for """ - validate_batch_from_configuration(batch, configuration=self.configuration) - - if type(batch[GSP_DATETIME_INDEX]) == torch.Tensor: - batch[GSP_DATETIME_INDEX] = batch[GSP_DATETIME_INDEX].detach().numpy() + if type(batch.metadata.t0_dt) == torch.Tensor: + batch.metadata.t0_dt = batch.metadata.t0_dt.detach().numpy() - all_datetimes_from_batch = pd.to_datetime(batch[GSP_DATETIME_INDEX].reshape(-1), unit="s") + all_datetimes_from_batch = pd.to_datetime(batch.metadata.t0_dt.reshape(-1), unit="s") return pd.DatetimeIndex(all_datetimes_from_batch.date).unique() @@ -105,18 +109,9 @@ def __init__(self, configuration: Configuration, length: int = 10): configuration: configuration object length: length of dataset """ - self.batch_size = configuration.process.batch_size - self.seq_length_5 = ( - configuration.process.seq_len_5_minutes - ) # the sequence data in 5 minute steps - self.seq_length_30 = ( - configuration.process.seq_len_30_minutes - ) # the sequence data in 30 minute steps - self.satellite_image_size_pixels = configuration.process.satellite_image_size_pixels - self.nwp_image_size_pixels = configuration.process.nwp_image_size_pixels - self.number_sat_channels = len(configuration.process.sat_channels) self.number_nwp_channels = len(configuration.process.nwp_channels) self.length = length + self.configuration = configuration def __len__(self): """ Number of pieces of data """ @@ -136,236 +131,7 @@ def __getitem__(self, idx): Returns: Dictionary of random data """ - x = { - "sat_data": torch.randn( - self.batch_size, - self.seq_length_5, - self.satellite_image_size_pixels, - self.satellite_image_size_pixels, - self.number_sat_channels, - ), - "pv_yield": torch.randn( - self.batch_size, self.seq_length_5, DEFAULT_N_PV_SYSTEMS_PER_EXAMPLE - ), - "pv_system_id": torch.randint(940, 
(self.batch_size, DEFAULT_N_PV_SYSTEMS_PER_EXAMPLE)), - "pv_system_x_coords": torch.randn(self.batch_size, DEFAULT_N_PV_SYSTEMS_PER_EXAMPLE), - "pv_system_y_coords": torch.randn(self.batch_size, DEFAULT_N_PV_SYSTEMS_PER_EXAMPLE), - "pv_system_row_number": torch.randint( - 940, (self.batch_size, DEFAULT_N_PV_SYSTEMS_PER_EXAMPLE) - ), - "nwp": torch.randn( - self.batch_size, - self.number_nwp_channels, - self.seq_length_5, - self.nwp_image_size_pixels, - self.nwp_image_size_pixels, - ), - "hour_of_day_sin": torch.randn(self.batch_size, self.seq_length_5), - "hour_of_day_cos": torch.randn(self.batch_size, self.seq_length_5), - "day_of_year_sin": torch.randn(self.batch_size, self.seq_length_5), - "day_of_year_cos": torch.randn(self.batch_size, self.seq_length_5), - "gsp_yield": torch.randn( - self.batch_size, self.seq_length_30, DEFAULT_N_GSP_PER_EXAMPLE - ), - "gsp_id": torch.randint(340, (self.batch_size, DEFAULT_N_GSP_PER_EXAMPLE)), - "topo_data": torch.randn( - self.batch_size, self.satellite_image_size_pixels, self.satellite_image_size_pixels - ), - } - - # add a nan - x["pv_yield"][0, 0, :] = float("nan") - - # add fake x and y coords, and make sure they are sorted - x["sat_x_coords"], _ = torch.sort( - torch.randn(self.batch_size, self.satellite_image_size_pixels) - ) - x["sat_y_coords"], _ = torch.sort( - torch.randn(self.batch_size, self.satellite_image_size_pixels), descending=True - ) - x["gsp_x_coords"], _ = torch.sort(torch.randn(self.batch_size, DEFAULT_N_GSP_PER_EXAMPLE)) - x["gsp_y_coords"], _ = torch.sort( - torch.randn(self.batch_size, DEFAULT_N_GSP_PER_EXAMPLE), descending=True - ) - x["topo_x_coords"], _ = torch.sort( - torch.randn(self.batch_size, self.satellite_image_size_pixels) - ) - x["topo_y_coords"], _ = torch.sort( - torch.randn(self.batch_size, self.satellite_image_size_pixels), descending=True - ) - - x["nwp_x_coords"], _ = torch.sort(torch.randn(self.batch_size, self.nwp_image_size_pixels)) - x["nwp_y_coords"], _ = torch.sort( - torch.randn(self.batch_size, self.nwp_image_size_pixels), descending=True - ) - - # add sorted (fake) time series - x["sat_datetime_index"], _ = torch.sort(torch.randn(self.batch_size, self.seq_length_5)) - x["nwp_target_time"], _ = torch.sort(torch.randn(self.batch_size, self.seq_length_5)) - x["gsp_datetime_index"], _ = torch.sort(torch.randn(self.batch_size, self.seq_length_30)) - - x["x_meters_center"], _ = torch.sort(torch.randn(self.batch_size)) - x["y_meters_center"], _ = torch.sort(torch.randn(self.batch_size)) - - # clip yield values from 0 to 1 - x["pv_yield"] = torch.clip(x["pv_yield"], min=0, max=1) - x["gsp_yield"] = torch.clip(x["gsp_yield"], min=0, max=1) - - return x - - -def validate_example( - data: Example, - seq_len_30_minutes: int, - seq_len_5_minutes: int, - sat_image_size: int = 64, - n_sat_channels: int = 1, - nwp_image_size: int = 0, - n_nwp_channels: int = 1, - n_pv_systems_per_example: int = DEFAULT_N_PV_SYSTEMS_PER_EXAMPLE, - n_gsp_per_example: int = DEFAULT_N_GSP_PER_EXAMPLE, - batch: bool = False, -): - """ - Validate the size and shape of the data - - Args: - data: Typed dictionary of the data - seq_len_30_minutes: the length of the sequence for 30 minutely data - seq_len_5_minutes: the length of the sequence for 5 minutely data - sat_image_size: the satellite image size - n_sat_channels: the number of satellite channgles - nwp_image_size: the nwp image size - n_nwp_channels: the number of nwp channels - n_pv_systems_per_example: the number pv systems with nan padding - n_gsp_per_example: the number gsp 
systems with nan padding - batch: if this example class is a batch or not - """ - n_gsp_id = data[GSP_ID].shape[-1] - assert ( - n_gsp_id == n_gsp_per_example - ), f"gsp_is is len {n_gsp_id}, but should be {n_gsp_per_example}" - assert data[GSP_YIELD].shape[-2:] == ( - seq_len_30_minutes, - n_gsp_id, - ), f"gsp_yield is size {data[GSP_YIELD].shape}, but should be {(seq_len_30_minutes, n_gsp_id)}" - assert data[GSP_X_COORDS].shape[-1] == n_gsp_id - assert data[GSP_Y_COORDS].shape[-1] == n_gsp_id - assert data[GSP_DATETIME_INDEX].shape[-1] == seq_len_30_minutes - - # check the GSP data is between 0 and 1 - assert ( - np.nanmax(data[GSP_YIELD]) <= 1.0 - ), f"Maximum GSP value is {np.nanmax(data[GSP_YIELD])} but it should be <= 1" - assert ( - np.nanmin(data[GSP_YIELD]) >= 0.0 - ), f"Maximum GSP value is {np.nanmin(data[GSP_YIELD])} but it should be >= 0" + x = Batch.fake(configuration=self.configuration) + x.change_type_to_numpy() - if OBJECT_AT_CENTER in data.keys(): - assert data[OBJECT_AT_CENTER] == "gsp" - - if not batch: - # add an extract dimension so that its similar to batch data - data["x_meters_center"] = np.expand_dims(data["x_meters_center"], axis=0) - data["y_meters_center"] = np.expand_dims(data["y_meters_center"], axis=0) - - # loop over batch - for d in data["x_meters_center"]: - assert type(d) in [ - np.float64, - torch.Tensor, - ], f"x_meters_center should be np.float64 but is {type(d)}" - for d in data["y_meters_center"]: - assert type(d) in [ - np.float64, - torch.Tensor, - ], f"y_meters_center should be np.float64 but is {type(d)}" - - assert data[PV_SYSTEM_ID].shape[-1] == n_pv_systems_per_example - assert data[PV_YIELD].shape[-2:] == (seq_len_5_minutes, n_pv_systems_per_example) - assert data[PV_SYSTEM_X_COORDS].shape[-1] == n_pv_systems_per_example - assert data[PV_SYSTEM_Y_COORDS].shape[-1] == n_pv_systems_per_example - - if not batch: - # add an extra dimension so that it's similar to batch data - data[PV_SYSTEM_ID] = np.expand_dims(data[PV_SYSTEM_ID], axis=0) - data[PV_SYSTEM_ROW_NUMBER] = np.expand_dims(data[PV_SYSTEM_ID], axis=0) - - # loop over batch - for i in range(len(data[PV_SYSTEM_ID])): - n_pv_systems = (data[PV_SYSTEM_ID][i, ~np.isnan(data[PV_SYSTEM_ID][i])]).shape[-1] - n_pv_syetem_row_numbers = ( - data[PV_SYSTEM_ROW_NUMBER][i, ~np.isnan(data[PV_SYSTEM_ROW_NUMBER][i])] - ).shape[-1] - assert n_pv_syetem_row_numbers == n_pv_systems, ( - f"Number of PV systems ({n_pv_systems}) does not match the " - f"pv systems row numbers ({n_pv_syetem_row_numbers})" - ) - - if n_pv_systems > 0: - # check the PV data is between 0 and 1 - assert ( - np.nanmax(data[PV_YIELD]) <= 1.0 - ), f"Maximum PV value is {np.nanmax(data[PV_YIELD])} but it should be <= 1" - assert ( - np.nanmin(data[PV_YIELD]) >= 0.0 - ), f"Maximum PV value is {np.nanmin(data[PV_YIELD])} but it should be <= 1" - - if SUN_AZIMUTH_ANGLE in data.keys(): - assert data[SUN_AZIMUTH_ANGLE].shape[-1] == seq_len_5_minutes - if SUN_ELEVATION_ANGLE in data.keys(): - assert data[SUN_ELEVATION_ANGLE].shape[-1] == seq_len_5_minutes - - assert data["sat_data"].shape[-4:] == ( - seq_len_5_minutes, - sat_image_size, - sat_image_size, - n_sat_channels, - ) - assert data["sat_x_coords"].shape[-1] == sat_image_size - assert data["sat_y_coords"].shape[-1] == sat_image_size - assert data["sat_datetime_index"].shape[-1] == seq_len_5_minutes - - assert data[TOPOGRAPHIC_DATA].shape[-2:] == (sat_image_size, sat_image_size) - assert data[TOPOGRAPHIC_Y_COORDS].shape[-1] == sat_image_size - assert 
data[TOPOGRAPHIC_X_COORDS].shape[-1] == sat_image_size - - nwp_correct_shape = ( - n_nwp_channels, - seq_len_5_minutes, - nwp_image_size, - nwp_image_size, - ) - nwp_shape = data["nwp"].shape[-4:] - assert ( - nwp_shape == nwp_correct_shape - ), f"NWP shape should be ({nwp_correct_shape}), but instead it is {nwp_shape}" - assert data["nwp_x_coords"].shape[-1] == nwp_image_size - assert data["nwp_y_coords"].shape[-1] == nwp_image_size - assert data["nwp_target_time"].shape[-1] == seq_len_5_minutes - - for feature in DATETIME_FEATURE_NAMES: - assert data[feature].shape[-1] == seq_len_5_minutes - - -def validate_batch_from_configuration(data: Example, configuration: Configuration): - """ - Validate data using a configuration - - Args: - data: batch of data - configuration: confgiruation of the data - - """ - validate_example( - data=data, - seq_len_30_minutes=configuration.process.seq_len_30_minutes, - seq_len_5_minutes=configuration.process.seq_len_5_minutes, - sat_image_size=configuration.process.satellite_image_size_pixels, - n_sat_channels=len(configuration.process.sat_channels), - nwp_image_size=configuration.process.nwp_image_size_pixels, - n_nwp_channels=len(configuration.process.nwp_channels), - n_pv_systems_per_example=DEFAULT_N_PV_SYSTEMS_PER_EXAMPLE, - n_gsp_per_example=DEFAULT_N_GSP_PER_EXAMPLE, - batch=True, - ) + return x.dict() diff --git a/nowcasting_dataset/time.py b/nowcasting_dataset/time.py index 9c8d4c05..54ee4445 100644 --- a/nowcasting_dataset/time.py +++ b/nowcasting_dataset/time.py @@ -6,9 +6,10 @@ import numpy as np import pandas as pd import pvlib +import random from nowcasting_dataset import geospatial, utils -from nowcasting_dataset.dataset.example import Example +from nowcasting_dataset.data_sources.datetime.datetime_model import Datetime logger = logging.getLogger(__name__) @@ -54,7 +55,7 @@ def select_daylight_datetimes( def intersection_of_datetimeindexes(indexes: List[pd.DatetimeIndex]) -> pd.DatetimeIndex: - """ Get intersections of datetime indexes """ + """Get intersections of datetime indexes""" assert len(indexes) > 0 intersection = indexes[0] for index in indexes[1:]: @@ -143,7 +144,7 @@ def get_t0_datetimes( def timesteps_to_duration(n_timesteps: int, minute_delta: int = 5) -> pd.Timedelta: - """ Change timesteps to a time duration """ + """Change timesteps to a time duration""" assert n_timesteps >= 0 return pd.Timedelta(n_timesteps * minute_delta, unit="minutes") @@ -164,7 +165,7 @@ def datetime_features(index: pd.DatetimeIndex) -> pd.DataFrame: return pd.DataFrame(features, index=index).astype(np.float32) -def datetime_features_in_example(index: pd.DatetimeIndex) -> Example: +def datetime_features_in_example(index: pd.DatetimeIndex) -> Datetime: """ Make datetime features with sin and cos @@ -178,10 +179,14 @@ def datetime_features_in_example(index: pd.DatetimeIndex) -> Example: dt_features["hour_of_day"] /= 24 dt_features["day_of_year"] /= 365 dt_features = utils.sin_and_cos(dt_features) - example = Example() + + datetime_dict = {} for col_name, series in dt_features.iteritems(): - example[col_name] = series - return example + datetime_dict[col_name] = series.values + + datetime_dict["datetime_index"] = series.index.values + + return Datetime(**datetime_dict) def fill_30_minutes_timestamps_to_5_minutes(index: pd.DatetimeIndex) -> pd.DatetimeIndex: @@ -226,3 +231,43 @@ def fill_30_minutes_timestamps_to_5_minutes(index: pd.DatetimeIndex) -> pd.Datet # drop nans and take index return index_with_gaps.dropna().index + + +def 
make_random_time_vectors(batch_size, seq_len_5_minutes, seq_len_30_minutes): + """ + Make random time vectors + + 1. t0_dt, Get random datetimes from 2019 + 2. Expand t0_dt to make 5 and 30 mins sequences + + Args: + batch_size: the batch size + seq_len_5_minutes: the length of the sequence in 5 mins deltas + seq_len_30_minutes: the length of the sequence in 30 mins deltas + + Returns: + - t0_dt: [batch_size] random init datetimes + - time_5: [batch_size, seq_len_5_minutes] random sequence of datetimes, with 5 mins deltas. + t0_dt is in the middle of the sequence + - time_30: [batch_size, seq_len_30_minutes] random sequence of datetimes, with 30 mins deltas. + t0_dt is in the middle of the sequence + """ + delta_5 = pd.Timedelta(minutes=5) + delta_30 = pd.Timedelta(minutes=30) + + data_range = pd.date_range("2019-01-01", "2021-01-01", freq="5T") + t0_dt = pd.Series(random.choices(data_range, k=batch_size)) + time_5 = ( + pd.DataFrame([t0_dt + i * delta_5 for i in range(seq_len_5_minutes)]) + - int(seq_len_5_minutes / 2) * delta_5 + ) + time_30 = ( + pd.DataFrame([t0_dt + i * delta_30 for i in range(seq_len_30_minutes)]) + - int(seq_len_30_minutes / 2) * delta_30 + ) + + t0_dt = utils.to_numpy(t0_dt) + time_5 = utils.to_numpy(time_5.T) + time_30 = utils.to_numpy(time_30.T) + + return t0_dt, time_5, time_30 diff --git a/nowcasting_dataset/utils.py b/nowcasting_dataset/utils.py index fcd0559d..57ea8aa4 100644 --- a/nowcasting_dataset/utils.py +++ b/nowcasting_dataset/utils.py @@ -2,16 +2,17 @@ import hashlib import logging from pathlib import Path -from typing import List +from typing import List, Optional import fsspec.asyn import numpy as np import pandas as pd +import torch +import xarray as xr import tempfile import gcsfs from nowcasting_dataset.consts import Array -from nowcasting_dataset.dataset.example import Example logger = logging.getLogger(__name__) @@ -114,38 +115,41 @@ def pad_nans(array, pad_width) -> np.ndarray: return np.pad(array, pad_width, constant_values=np.NaN) -def pad_data( - data: Example, - pad_size: int, - one_dimensional_arrays: List[str], - two_dimensional_arrays: List[str], -) -> Example: - """ - Pad (if necessary) so returned arrays are always of size +def to_numpy(value): + """ Change generic data to numpy""" + if isinstance(value, xr.DataArray): + # TODO: Use to_numpy() or as_numpy(), introduced in xarray v0.19?
+ value = value.data - data has two types of arrays in it, one dimensional arrays and two dimensional arrays - the one dimensional arrays are padded in that dimension - the two dimensional arrays are padded in the second dimension + if isinstance(value, (pd.Series, pd.DataFrame)): + value = value.values + elif isinstance(value, pd.DatetimeIndex): + value = value.values.astype("datetime64[s]").astype(np.int32) + elif isinstance(value, pd.Timestamp): + value = np.int32(value.timestamp()) + elif isinstance(value, np.ndarray) and np.issubdtype(value.dtype, np.datetime64): + value = value.astype("datetime64[s]").astype(np.int32) + elif isinstance(value, torch.Tensor): + value = value.numpy() - Args: - data: typed dictionary of data objects - pad_size: the maount that should be padded - one_dimensional_arrays: list of data items that should be padded by one dimension - two_dimensional_arrays: list of data tiems that should be padded in the third dimension (and more) + return value - Returns: Example data +def coord_to_range( + da: xr.DataArray, dim: str, prefix: Optional[str], dtype=np.int32 +) -> xr.DataArray: """ - # Pad (if necessary) so returned arrays are always of size - pad_shape = (0, pad_size) # (before, after) + TODO - for name in one_dimensional_arrays: - data[name] = pad_nans(data[name], pad_width=pad_shape) + TODO: Actually, I think this is over-complicated? I think we can + just strip off the 'coord' from the dimension. - for variable in two_dimensional_arrays: - data[variable] = pad_nans(data[variable], pad_width=((0, 0), pad_shape)) # (axis0, axis1) - - return data + """ + coord = da[dim] + da[dim] = np.arange(len(coord), dtype=dtype) + if prefix is not None: + da[f"{prefix}_{dim}_coords"] = xr.DataArray(coord, coords=[da[dim]], dims=[dim]) + return da class OpenData: diff --git a/scripts/generate_data_for_tests/get_test_data.py b/scripts/generate_data_for_tests/get_test_data.py index ae06cd47..6f32cb68 100644 --- a/scripts/generate_data_for_tests/get_test_data.py +++ b/scripts/generate_data_for_tests/get_test_data.py @@ -10,7 +10,9 @@ import xarray as xr import nowcasting_dataset -from nowcasting_dataset.data_sources.nwp_data_source import open_nwp, NWP_VARIABLE_NAMES +from nowcasting_dataset.data_sources.nwp.nwp_data_source import open_nwp, NWP_VARIABLE_NAMES +from nowcasting_dataset.config.model import Configuration +from nowcasting_dataset.dataset.batch import Batch # set up BUCKET = Path("solar-pv-nowcasting-data") @@ -135,3 +137,15 @@ # save to file sun_xr.to_zarr(f"{local_path}/tests/data/sun/test.zarr", mode="w") + + +######## +# batch0.nc +######## + +c = Configuration() +c.process.nwp_channels = c.process.nwp_channels[0:1] +c.process.sat_channels = c.process.sat_channels[0:1] + +f = Batch.fake(configuration=c) +f.save_netcdf(batch_i=0, path=Path(f"{local_path}/tests/data")) diff --git a/scripts/prepare_ml_data.py b/scripts/prepare_ml_data.py index 52998baa..981301c9 100755 --- a/scripts/prepare_ml_data.py +++ b/scripts/prepare_ml_data.py @@ -23,8 +23,8 @@ from nowcasting_dataset.dataset.datamodule import NowcastingDataModule from nowcasting_dataset.dataset.batch import write_batch_locally -from nowcasting_dataset.data_sources.satellite_data_source import SAT_VARIABLE_NAMES -from nowcasting_dataset.data_sources.nwp_data_source import NWP_VARIABLE_NAMES +from nowcasting_dataset.data_sources.satellite.satellite_data_source import SAT_VARIABLE_NAMES +from nowcasting_dataset.data_sources.nwp.nwp_data_source import NWP_VARIABLE_NAMES from pathy import Pathy from pathlib 
import Path import fsspec diff --git a/scripts/rechunk_nwp_data.py b/scripts/rechunk_nwp_data.py index fecfd0bb..d26e9a25 100755 --- a/scripts/rechunk_nwp_data.py +++ b/scripts/rechunk_nwp_data.py @@ -6,7 +6,7 @@ import gcsfs import rechunker import zarr -from nowcasting_dataset.data_sources.nwp_data_source import NWP_VARIABLE_NAMES +from nowcasting_dataset.data_sources.nwp.nwp_data_source import NWP_VARIABLE_NAMES import os import numpy as np diff --git a/scripts/validate_ml_data.py b/scripts/validate_ml_data.py index f441b943..388d76ec 100644 --- a/scripts/validate_ml_data.py +++ b/scripts/validate_ml_data.py @@ -7,7 +7,7 @@ from nowcasting_dataset.config.load import load_configuration_from_gcs, load_yaml_configuration from nowcasting_dataset.dataset.datasets import NetCDFDataset, worker_init_fn from nowcasting_dataset.dataset.validate import ValidatorDataset -from nowcasting_dataset.utils import get_maximum_batch_id_from_gcs +from nowcasting_dataset.cloud.utils import get_maximum_batch_id logging.basicConfig(format="%(asctime)s %(levelname)s %(pathname)s %(lineno)d %(message)s") _LOG = logging.getLogger("nowcasting_dataset") @@ -27,9 +27,9 @@ LOCAL_TEMP_PATH = Path("~/temp/").expanduser() # find how many datasets there are -maximum_batch_id_train = get_maximum_batch_id_from_gcs(f"gs://{DST_TRAIN_PATH}") -maximum_batch_id_validation = get_maximum_batch_id_from_gcs(f"gs://{DST_VALIDATION_PATH}") -maximum_batch_id_test = get_maximum_batch_id_from_gcs(f"gs://{DST_TEST_PATH}") +maximum_batch_id_train = get_maximum_batch_id(f"gs://{DST_TRAIN_PATH}") +maximum_batch_id_validation = get_maximum_batch_id(f"gs://{DST_VALIDATION_PATH}") +maximum_batch_id_test = get_maximum_batch_id(f"gs://{DST_TEST_PATH}") dataloader_config = dict( pin_memory=True, @@ -49,6 +49,7 @@ f"gs://{DST_TRAIN_PATH}", LOCAL_TEMP_PATH, cloud="gcp", + configuration=config, ), **dataloader_config, ) @@ -60,6 +61,7 @@ f"gs://{DST_VALIDATION_PATH}", LOCAL_TEMP_PATH, cloud="gcp", + configuration=config, ), **dataloader_config, ) @@ -70,6 +72,7 @@ f"gs://{DST_TEST_PATH}", LOCAL_TEMP_PATH, cloud="gcp", + configuration=config, ), **dataloader_config, ) diff --git a/tests/data/0.nc b/tests/data/0.nc index d5b9b6ad..b1bdcd18 100644 Binary files a/tests/data/0.nc and b/tests/data/0.nc differ diff --git a/tests/data_sources/gsp/test_gsp_data_source.py b/tests/data_sources/gsp/test_gsp_data_source.py index 2070ccf8..86b5be11 100644 --- a/tests/data_sources/gsp/test_gsp_data_source.py +++ b/tests/data_sources/gsp/test_gsp_data_source.py @@ -1,10 +1,7 @@ import os from datetime import datetime -import pandas as pd - import nowcasting_dataset -from nowcasting_dataset.consts import T0_DT from nowcasting_dataset.data_sources.gsp.gsp_data_source import GSPDataSource from nowcasting_dataset.geospatial import osgb_to_lat_lon @@ -72,10 +69,10 @@ def test_gsp_pv_data_source_get_example(): t0_dt=gsp.gsp_power.index[0], x_meters_center=x_locations[0], y_meters_center=y_locations[0] ) - assert len(l["gsp_id"]) == len(l["gsp_yield"][0]) - assert len(l["gsp_x_coords"]) == len(l["gsp_y_coords"]) - assert len(l["gsp_x_coords"]) > 0 - assert type(l[T0_DT]) == pd.Timestamp + assert len(l.gsp_id) == len(l.gsp_yield[0]) + assert len(l.gsp_x_coords) == len(l.gsp_y_coords) + assert len(l.gsp_x_coords) > 0 + # assert type(l[T0_DT]) == pd.Timestamp def test_gsp_pv_data_source_get_batch(): @@ -105,9 +102,9 @@ def test_gsp_pv_data_source_get_batch(): y_locations=y_locations[0:batch_size], ) - assert len(batch) == batch_size - assert len(batch[0]["gsp_yield"]) == 4 
- assert len(batch[0]["gsp_id"]) == len(batch[0]["gsp_x_coords"]) - assert len(batch[1]["gsp_x_coords"]) == len(batch[1]["gsp_y_coords"]) - assert len(batch[2]["gsp_x_coords"]) > 0 - assert T0_DT in batch[3].keys() + assert batch.batch_size == batch_size + assert len(batch.gsp_yield[0]) == 4 + assert len(batch.gsp_id[0]) == len(batch.gsp_x_coords[0]) + assert len(batch.gsp_x_coords[1]) == len(batch.gsp_y_coords[1]) + assert len(batch.gsp_x_coords[2]) > 0 + # assert T0_DT in batch[3].keys() diff --git a/tests/data_sources/sun/test_sun_data_source.py b/tests/data_sources/sun/test_sun_data_source.py index 943b65bd..8902548f 100644 --- a/tests/data_sources/sun/test_sun_data_source.py +++ b/tests/data_sources/sun/test_sun_data_source.py @@ -1,6 +1,7 @@ from nowcasting_dataset.data_sources.sun.sun_data_source import SunDataSource from datetime import datetime -from nowcasting_dataset.dataset.example import Example + +# from nowcasting_dataset.dataset.example import Example from nowcasting_dataset.consts import SUN_ELEVATION_ANGLE, SUN_AZIMUTH_ANGLE import pandas as pd @@ -26,10 +27,8 @@ def test_get_example(test_data_folder): example = sun_data_source.get_example(t0_dt=start_dt, x_meters_center=x, y_meters_center=y) - assert SUN_ELEVATION_ANGLE in example.keys() - assert SUN_AZIMUTH_ANGLE in example.keys() - assert len(example[SUN_ELEVATION_ANGLE]) == 19 - assert len(example[SUN_AZIMUTH_ANGLE]) == 19 + assert len(example.sun_elevation_angle) == 19 + assert len(example.sun_azimuth_angle) == 19 def test_get_example_different_year(test_data_folder): @@ -45,7 +44,5 @@ def test_get_example_different_year(test_data_folder): example = sun_data_source.get_example(t0_dt=start_dt, x_meters_center=x, y_meters_center=y) - assert SUN_ELEVATION_ANGLE in example.keys() - assert SUN_AZIMUTH_ANGLE in example.keys() - assert len(example[SUN_ELEVATION_ANGLE]) == 19 - assert len(example[SUN_AZIMUTH_ANGLE]) == 19 + assert len(example.sun_elevation_angle) == 19 + assert len(example.sun_azimuth_angle) == 19 diff --git a/tests/data_sources/test_datasource_output.py b/tests/data_sources/test_datasource_output.py new file mode 100644 index 00000000..934d3222 --- /dev/null +++ b/tests/data_sources/test_datasource_output.py @@ -0,0 +1,102 @@ +from nowcasting_dataset.data_sources.datetime.datetime_model import Datetime +from nowcasting_dataset.data_sources.gsp.gsp_model import GSP +from nowcasting_dataset.data_sources.pv.pv_model import PV +from nowcasting_dataset.data_sources.nwp.nwp_model import NWP +from nowcasting_dataset.data_sources.satellite.satellite_model import Satellite +from nowcasting_dataset.data_sources.sun.sun_model import Sun +from nowcasting_dataset.data_sources.topographic.topographic_model import Topographic + + +def test_datetime(): + + s = Datetime.fake( + batch_size=4, + seq_length_5=13, + ) + + +def test_gsp(): + + s = GSP.fake(batch_size=4, seq_length_30=13, n_gsp_per_batch=32) + + +def test_gsp_pad(): + + s = GSP.fake(batch_size=4, seq_length_30=13, n_gsp_per_batch=7).split()[0] + s.to_numpy() + s.pad(n_gsp_per_example=32) + + assert s.gsp_yield.shape == (13, 32) + + +def test_gsp_split(): + + s = GSP.fake(batch_size=4, seq_length_30=13, n_gsp_per_batch=32) + split = s.split() + + assert len(split) == 4 + assert type(split[0]) == GSP + assert (split[0].gsp_yield == s.gsp_yield[0]).all() + + +def test_gsp_join(): + + s = GSP.fake(batch_size=2, seq_length_30=13, n_gsp_per_batch=32).split() + + s: GSP = GSP.create_batch_from_examples(s) + + assert s.batch_size == 2 + assert len(s.gsp_yield.shape) == 3 
+ assert s.gsp_yield.shape[0] == 2 + assert s.gsp_yield.shape[1] == 13 + assert s.gsp_yield.shape[2] == 32 + + +def test_nwp(): + + s = NWP.fake(batch_size=4, seq_length_5=13, nwp_image_size_pixels=64, number_nwp_channels=8) + + +def test_nwp_split(): + + s = NWP.fake(batch_size=4, seq_length_5=13, nwp_image_size_pixels=64, number_nwp_channels=8) + s = s.split() + + +def test_pv(): + + s = PV.fake(batch_size=4, seq_length_5=13, n_pv_systems_per_batch=128) + + +def test_nwp_pad(): + + s = PV.fake(batch_size=4, seq_length_5=13, n_pv_systems_per_batch=37).split()[0] + s.to_numpy() + s.pad(n_pv_systems_per_example=128) + + assert s.pv_yield.shape == (13, 128) + + +def test_satellite(): + + s = Satellite.fake( + batch_size=4, seq_length_5=13, satellite_image_size_pixels=64, number_sat_channels=7 + ) + + assert s.sat_x_coords is not None + + +def test_sun(): + + s = Sun.fake( + batch_size=4, + seq_length_5=13, + ) + + +def test_topo(): + + s = Topographic.fake( + batch_size=4, + satellite_image_size_pixels=64, + ) diff --git a/tests/data_sources/test_nwp_data_source.py b/tests/data_sources/test_nwp_data_source.py index 94b7dd09..e698c492 100644 --- a/tests/data_sources/test_nwp_data_source.py +++ b/tests/data_sources/test_nwp_data_source.py @@ -1,7 +1,7 @@ import os import nowcasting_dataset -from nowcasting_dataset.data_sources.nwp_data_source import NWPDataSource +from nowcasting_dataset.data_sources.nwp.nwp_data_source import NWPDataSource def test_nwp_data_source_init(): @@ -59,8 +59,8 @@ def test_nwp_data_source_batch(): t0_datetimes = nwp._data.init_time[2:10].values x = nwp._data.x[0:4].values - y = nwp._data.x[0:4].values + y = nwp._data.y[0:4].values batch = nwp.get_batch(t0_datetimes=t0_datetimes, x_locations=x, y_locations=y) - assert len(batch) == 4 + assert batch.batch_size == 4 diff --git a/tests/data_sources/test_pv_data_source.py b/tests/data_sources/test_pv_data_source.py index da36d150..93dbc41c 100644 --- a/tests/data_sources/test_pv_data_source.py +++ b/tests/data_sources/test_pv_data_source.py @@ -5,7 +5,7 @@ import pandas as pd import nowcasting_dataset -from nowcasting_dataset.data_sources.pv_data_source import ( +from nowcasting_dataset.data_sources.pv.pv_data_source import ( PVDataSource, drop_pv_systems_which_produce_overnight, ) @@ -40,12 +40,11 @@ def test_get_example_and_batch(): example = pv_data_source.get_example( pv_data_source.pv_power.index[0], x_locations[0], y_locations[0] ) - assert "pv_yield" in example.keys() batch = pv_data_source.get_batch( - pv_data_source.pv_power.index[0:5], x_locations[0:10], y_locations[0:10] + pv_data_source.pv_power.index[6:11], x_locations[0:10], y_locations[0:10] ) - assert len(batch) == 5 + assert batch.batch_size == 5 def test_drop_pv_systems_which_produce_overnight(): diff --git a/tests/data_sources/test_satellite_data_source.py b/tests/data_sources/test_satellite_data_source.py index 8d5c710f..12b817a7 100644 --- a/tests/data_sources/test_satellite_data_source.py +++ b/tests/data_sources/test_satellite_data_source.py @@ -38,7 +38,7 @@ def test_get_example(sat_data_source, x, y, left, right, top, bottom): sat_data_source.open() t0_dt = pd.Timestamp("2019-01-01T13:00") example = sat_data_source.get_example(t0_dt=t0_dt, x_meters_center=x, y_meters_center=y) - sat_data = example["sat_data"] + sat_data = example.sat_data assert left == sat_data.x.values[0] assert right == sat_data.x.values[-1] # sat_data.y is top-to-bottom. 
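For reference, here is a minimal usage sketch of the new per-data-source model API, assembled only from the calls exercised in the tests above (GSP.fake, split, to_numpy, pad, create_batch_from_examples); treat the exact signatures and shapes as illustrative rather than definitive.

    from nowcasting_dataset.data_sources.gsp.gsp_model import GSP

    # build a fake batch of 4 examples, each with up to 7 GSPs and 13 half-hourly steps
    batch = GSP.fake(batch_size=4, seq_length_30=13, n_gsp_per_batch=7)

    # split the batch into per-example models, convert one to numpy and pad it to 32 GSPs
    example = batch.split()[0]
    example.to_numpy()
    example.pad(n_gsp_per_example=32)
    assert example.gsp_yield.shape == (13, 32)

    # per-example models can be re-joined into a batch
    examples = GSP.fake(batch_size=2, seq_length_30=13, n_gsp_per_batch=32).split()
    rejoined = GSP.create_batch_from_examples(examples)
    assert rejoined.batch_size == 2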
diff --git a/tests/data_sources/test_topographic_data_source.py b/tests/data_sources/test_topographic_data_source.py index a5482d9b..168c42e3 100644 --- a/tests/data_sources/test_topographic_data_source.py +++ b/tests/data_sources/test_topographic_data_source.py @@ -32,7 +32,7 @@ def test_get_example_2km(x, y, left, right, top, bottom): ) t0_dt = pd.Timestamp("2019-01-01T13:00") example = topo_source.get_example(t0_dt=t0_dt, x_meters_center=x, y_meters_center=y) - topo_data = example["topo_data"] + topo_data = example.topo_data assert topo_data.shape == (128, 128) assert len(topo_data.x) == 128 assert len(topo_data.y) == 128 diff --git a/tests/dataset/test_batch.py b/tests/dataset/test_batch.py new file mode 100644 index 00000000..2624bf36 --- /dev/null +++ b/tests/dataset/test_batch.py @@ -0,0 +1,74 @@ +from nowcasting_dataset.data_sources.gsp.gsp_model import GSP +import numpy as np + +from nowcasting_dataset.dataset.batch import Batch, GSP +from nowcasting_dataset.dataset.validate import FakeDataset +import torch +from nowcasting_dataset.config.model import Configuration + +import xarray as xr + + +def test_model(): + + _ = Batch.fake() + + +def test_model_to_numpy(): + + f = Batch.fake() + + f.change_type_to_numpy() + + assert type(f.gsp) == GSP + + +def test_model_split(): + + f = Batch.fake() + + data = f.split() + + assert len(data) == f.batch_size + assert type(data[0].gsp) == GSP + + +def test_model_to_xr_dataset(configuration): + + f = Batch.fake(configuration=configuration) + f_xr = f.batch_to_dataset() + + assert type(f_xr) == xr.Dataset + + +def test_model_from_xr_dataset(): + + f = Batch.fake() + + f_xr = f.batch_to_dataset() + + _ = Batch.load_batch_from_dataset(xr_dataset=f_xr) + + +def test_model_from_xr_dataset_to_numpy(): + + f = Batch.fake() + + f_xr = f.batch_to_dataset() + fs = Batch.load_batch_from_dataset(xr_dataset=f_xr) + # check they are the same + fs.change_type_to_numpy() + f.gsp.to_numpy() + assert f.gsp.gsp_yield.shape == fs.gsp.gsp_yield.shape + assert (f.gsp.gsp_yield[0].astype(np.float32) == fs.gsp.gsp_yield[0]).all() + assert (f.gsp.gsp_yield.astype(np.float32) == fs.gsp.gsp_yield).all() + + +def test_fake_dataset(): + train = torch.utils.data.DataLoader(FakeDataset(configuration=Configuration()), batch_size=None) + i = iter(train) + x = next(i) + + x = Batch(**x) + # IT WORKS + assert type(x.satellite.sat_data) == torch.Tensor diff --git a/tests/test_datamodule.py b/tests/test_datamodule.py index 0bca540d..bf7f4509 100644 --- a/tests/test_datamodule.py +++ b/tests/test_datamodule.py @@ -9,16 +9,11 @@ import nowcasting_dataset from nowcasting_dataset.config.load import load_yaml_configuration -from nowcasting_dataset.consts import GSP_DATETIME_INDEX, DEFAULT_REQUIRED_KEYS + from nowcasting_dataset.dataset import datamodule -from nowcasting_dataset.dataset.batch import batch_to_dataset from nowcasting_dataset.dataset.datamodule import NowcastingDataModule -from nowcasting_dataset.dataset.example import Example -from nowcasting_dataset.dataset.example import ( - xr_to_example, -) from nowcasting_dataset.dataset.split.split import SplitMethod -from nowcasting_dataset.dataset.validate import validate_example, validate_batch_from_configuration +from nowcasting_dataset.dataset.batch import Batch logging.basicConfig(format="%(asctime)s %(levelname)s %(pathname)s %(lineno)d %(message)s") _LOG = logging.getLogger("nowcasting_dataset") @@ -116,26 +111,28 @@ def test_data_module(config_filename): data_generator = iter(data_module.train_dataset) batch = 
next(data_generator) - assert len(batch) == config.process.batch_size + assert batch["batch_size"] == config.process.batch_size - for key in list(Example.__annotations__.keys()): - assert key in batch[0].keys() + _ = Batch(**batch) - seq_len_30_minutes = 4 # 30 minutes history, 60 minutes in the future plus now, is 4) - seq_len_5_minutes = ( - 19 # 30 minutes history (=6), 60 minutes in the future (=12) plus now, is 19) - ) + # for key in list(Example.__annotations__.keys()): + # assert key in batch[0].keys() + # + # seq_len_30_minutes = 4 # 30 minutes history, 60 minutes in the future plus now, is 4) + # seq_len_5_minutes = ( + # 19 # 30 minutes history (=6), 60 minutes in the future (=12) plus now, is 19) + # ) - for x in batch: - validate_example( - data=x, - n_nwp_channels=len(config.process.nwp_channels), - nwp_image_size=config.process.nwp_image_size_pixels, - n_sat_channels=len(config.process.sat_channels), - sat_image_size=config.process.satellite_image_size_pixels, - seq_len_30_minutes=seq_len_30_minutes, - seq_len_5_minutes=seq_len_5_minutes, - ) + # for x in batch: + # validate_example( + # data=x, + # n_nwp_channels=len(config.process.nwp_channels), + # nwp_image_size=config.process.nwp_image_size_pixels, + # n_sat_channels=len(config.process.sat_channels), + # sat_image_size=config.process.satellite_image_size_pixels, + # seq_len_30_minutes=seq_len_30_minutes, + # seq_len_5_minutes=seq_len_5_minutes, + # ) def test_batch_to_batch_to_dataset(): @@ -155,6 +152,7 @@ def test_batch_to_batch_to_dataset(): nwp_base_path=config.input_data.nwp_zarr_path, gsp_filename=config.input_data.gsp_zarr_path, topographic_filename=config.input_data.topographic_filename, + sun_filename=config.input_data.sun_zarr_path, pin_memory=True, #: Passed to DataLoader. num_workers=0, #: Passed to DataLoader. prefetch_factor=8, #: Passed to DataLoader. 
@@ -179,14 +177,11 @@ def test_batch_to_batch_to_dataset(): data_generator = iter(data_module.train_dataset) batch = next(data_generator) - batch_xr = batch_to_dataset(batch=batch) - assert type(batch_xr) == xr.Dataset - assert GSP_DATETIME_INDEX in batch_xr - assert pd.DataFrame(batch_xr[GSP_DATETIME_INDEX]).isnull().sum().sum() == 0 + batch = Batch(**batch) + + batch_xr = batch.batch_to_dataset() - # validate batch - from nowcasting_dataset.dataset.example import to_numpy + assert type(batch_xr) == xr.Dataset + assert pd.DataFrame(batch_xr.gsp_datetime_index).isnull().sum().sum() == 0 - batch0 = xr_to_example(batch_xr=batch_xr, required_keys=DEFAULT_REQUIRED_KEYS) - batch0 = to_numpy(batch0) - validate_batch_from_configuration(data=batch0, configuration=config) + _ = Batch.load_batch_from_dataset(batch_xr) diff --git a/tests/test_dataset.py b/tests/test_dataset.py index a1d650da..abb36aff 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -7,20 +7,20 @@ @pytest.fixture -def dataset(sat_data_source): +def dataset(sat_data_source, general_data_source): all_datetimes = sat_data_source.datetime_index() t0_datetimes = nd_time.get_t0_datetimes(datetimes=all_datetimes, total_seq_len=2, history_len=0) return NowcastingDataset( batch_size=8, n_batches_per_epoch_per_worker=64, n_samples_per_timestep=2, - data_sources=[sat_data_source], + data_sources=[sat_data_source, general_data_source], t0_datetimes=t0_datetimes, ) @pytest.fixture -def dataset_gsp(gsp_data_source): +def dataset_gsp(gsp_data_source, general_data_source): all_datetimes = gsp_data_source.datetime_index() t0_datetimes = nd_time.get_t0_datetimes( datetimes=all_datetimes, @@ -34,7 +34,7 @@ def dataset_gsp(gsp_data_source): batch_size=8, n_batches_per_epoch_per_worker=64, n_samples_per_timestep=2, - data_sources=[gsp_data_source], + data_sources=[gsp_data_source, general_data_source], t0_datetimes=t0_datetimes, ) @@ -55,8 +55,8 @@ def test_get_batch(dataset: NowcastingDataset): dataset.per_worker_init(worker_id=1) example = dataset._get_batch() assert isinstance(example, dict) - assert "sat_data" in example - assert example["sat_data"].shape == ( + assert "satellite" in example + assert example["satellite"]["sat_data"].shape == ( 8, 2, pytest.IMAGE_SIZE_PIXELS, @@ -70,4 +70,4 @@ def test_get_batch_gsp(dataset_gsp: NowcastingDataset): example = dataset_gsp._get_batch() assert isinstance(example, dict) - assert GSP_DATETIME_INDEX in example.keys() + assert "gsp" in example.keys() diff --git a/tests/test_netcdf_dataset.py b/tests/test_netcdf_dataset.py index ce3e43ce..aab733e8 100644 --- a/tests/test_netcdf_dataset.py +++ b/tests/test_netcdf_dataset.py @@ -10,6 +10,7 @@ import xarray as xr import nowcasting_dataset +import nowcasting_dataset.dataset.batch from nowcasting_dataset.config.model import Configuration from nowcasting_dataset.consts import ( SATELLITE_X_COORDS, @@ -25,18 +26,18 @@ GSP_DATETIME_INDEX, T0_DT, ) -from nowcasting_dataset.dataset import example + +# from nowcasting_dataset.dataset import example +from nowcasting_dataset.dataset.batch import Batch from nowcasting_dataset.dataset.datasets import NetCDFDataset, worker_init_fn, subselect_data def test_subselect_date(test_data_folder): - dataset = xr.open_dataset(f"{test_data_folder}/0.nc") - x = example.Example( - sat_data=dataset["sat_data"], - nwp=dataset["nwp"], - nwp_target_time=dataset["nwp_time_coords"], - sat_datetime_index=dataset["sat_time_coords"], - ) + + # x = Batch.load_netcdf(f"{test_data_folder}/0.nc") + x = Batch.fake() + x = 
x.batch_to_dataset() + x = Batch.load_batch_from_dataset(x) batch = subselect_data( x, @@ -46,19 +47,16 @@ def test_subselect_date(test_data_folder): forecast_minutes=10, ) - assert batch[SATELLITE_DATA].shape[1] == 5 - assert batch[NWP_DATA].shape[2] == 5 + assert batch.satellite.sat_data.shape[1] == 5 + assert batch.nwp.nwp.shape[2] == 5 -def test_subselect_date_with_t0_dt(test_data_folder): - dataset = xr.open_dataset(f"{test_data_folder}/0.nc") - x = example.Example( - sat_data=dataset["sat_data"], - nwp=dataset["nwp"], - nwp_target_time=dataset["nwp_time_coords"], - sat_datetime_index=dataset["sat_time_coords"], - ) - x[T0_DT] = x[SATELLITE_DATETIME_INDEX].isel(time=7) +def test_subselect_date_with_t0_dt(test_data_folder): + + # x = Batch.load_netcdf(f"{test_data_folder}/0.nc") + x = Batch.fake() + x = x.batch_to_dataset() + x = Batch.load_batch_from_dataset(x) batch = subselect_data( x, @@ -67,8 +65,8 @@ def test_subselect_date_with_t0_dt(test_data_folder): forecast_minutes=10, ) - assert batch[SATELLITE_DATA].shape[1] == 5 - assert batch[NWP_DATA].shape[2] == 5 + assert batch.satellite.sat_data.shape[1] == 5 + assert batch.nwp.nwp.shape[2] == 5 def test_netcdf_dataset_local_using_configuration(configuration: Configuration): @@ -105,11 +103,12 @@ def test_netcdf_dataset_local_using_configuration(configuration: Configuration): t = iter(train_dataset) data = next(t) - sat_data = data[SATELLITE_DATA] + sat_data = data.satellite.sat_data + + # TODO # Sat is in 5min increments, so should have 2 history + current + 2 future assert sat_data.shape[1] == 5 - assert data[NWP_DATA].shape[2] == 5 + assert data.nwp.nwp.shape[2] == 5 # Make sure file isn't deleted! assert os.path.exists(os.path.join(DATA_PATH, "0.nc")) @@ -142,13 +141,13 @@ def test_get_dataloaders_gcp(configuration: Configuration): train_dataset.per_worker_init(1) t = iter(train_dataset) - data = next(t) + data: Batch = next(t) # image - z = data[SATELLITE_DATA][0][0][:, :, 0] - _ = data[GSP_YIELD][0][:, 0] + z = data.satellite.sat_data[0][0][:, :, 0] + _ = data.gsp.gsp_yield[0][:, 0] - _ = pd.to_datetime(data[SATELLITE_DATETIME_INDEX][0, 0], unit="s") + _ = pd.to_datetime(data.satellite.sat_datetime_index[0, 0], unit="s") fig = go.Figure(data=go.Contour(z=z)) @@ -206,15 +205,15 @@ def test_required_keys_gcp(configuration: Configuration): os.path.join(DATA_PATH, "train"), os.path.join(TEMP_PATH, "train"), cloud="gcp", - required_keys=[ - NWP_DATA, - NWP_X_COORDS, - NWP_Y_COORDS, - SATELLITE_DATA, - SATELLITE_X_COORDS, - SATELLITE_Y_COORDS, - GSP_DATETIME_INDEX, - ], + # required_keys=[ + # NWP_DATA, + # NWP_X_COORDS, + # NWP_Y_COORDS, + # SATELLITE_DATA, + # SATELLITE_X_COORDS, + # SATELLITE_Y_COORDS, + # GSP_DATETIME_INDEX, + # ], configuration=configuration, ) diff --git a/tests/test_time.py b/tests/test_time.py index 2c5db87e..8b64b3d3 100644 --- a/tests/test_time.py +++ b/tests/test_time.py @@ -66,16 +66,13 @@ def test_timesteps_to_duration(): def test_datetime_features_in_example(): index = pd.date_range("2020-01-01", "2020-01-06 23:00", freq="h") example = nd_time.datetime_features_in_example(index) - assert len(example["hour_of_day_sin"]) == len(index) + assert len(example.hour_of_day_sin) == len(index) for col_name in ["hour_of_day_sin", "hour_of_day_cos"]: - assert col_name in example np.testing.assert_array_almost_equal( - example[col_name], np.tile(example[col_name][:24], reps=6) + example.__getattribute__(col_name), + np.tile(example.__getattribute__(col_name)[:24], reps=6), ) - assert "day_of_year_sin" in example -
assert "day_of_year_cos" in example - @pytest.mark.parametrize("history_length", [2, 3, 12]) @pytest.mark.parametrize("forecast_length", [2, 3, 12]) diff --git a/tests/test_utils.py b/tests/test_utils.py index 46644ba7..2b8bf5cb 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -3,7 +3,8 @@ import pytest from nowcasting_dataset import utils -from nowcasting_dataset.dataset.example import Example + +# from nowcasting_dataset.dataset.example import Example def test_is_monotically_increasing(): @@ -33,22 +34,3 @@ def test_sin_and_cos(): def test_get_netcdf_filename(): assert utils.get_netcdf_filename(10) == "10.nc" assert utils.get_netcdf_filename(10, add_hash=True) == "77eb6f_10.nc" - - -def test_pad_data(): - seq_length = 4 - n_gsp_system_ids = 17 - - data = Example() - data["gsp_yield"] = np.random.random((seq_length, n_gsp_system_ids)) - data["gsp_system_id"] = np.random.random((n_gsp_system_ids)) - - data = utils.pad_data( - data=data, - pad_size=1, - one_dimensional_arrays=["gsp_system_id"], - two_dimensional_arrays=["gsp_yield"], - ) - - assert data["gsp_yield"].shape == (seq_length, n_gsp_system_ids + 1) - assert data["gsp_system_id"].shape == (n_gsp_system_ids + 1,)