Skip to content
This repository was archived by the owner on Sep 11, 2023. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion nowcasting_dataset/consts.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,5 @@
SPATIAL_AND_TEMPORAL_LOCATIONS_OF_EACH_EXAMPLE_FILENAME = (
"spatial_and_temporal_locations_of_each_example.csv"
)
SPATIAL_AND_TEMPORAL_LOCATIONS_COLUMN_NAMES = ("t0_datetime_UTC", "x_center_OSGB", "y_center_OSGB")

LOG_LEVELS = ("DEBUG", "INFO", "WARNING", "ERROR")
80 changes: 24 additions & 56 deletions nowcasting_dataset/data_sources/data_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
import nowcasting_dataset.filesystem.utils as nd_fs_utils
import nowcasting_dataset.time as nd_time
from nowcasting_dataset import square, utils
from nowcasting_dataset.consts import SPATIAL_AND_TEMPORAL_LOCATIONS_COLUMN_NAMES
from nowcasting_dataset.data_sources.datasource_output import DataSourceOutput
from nowcasting_dataset.data_sources.metadata.metadata_model import SpaceTimeLocation
from nowcasting_dataset.dataset.xr_utils import (
convert_coordinates_to_indexes_for_list_datasets,
join_list_dataset_to_batch_dataset,
Expand Down Expand Up @@ -137,7 +137,7 @@ def open(self):
@utils.exception_logger
def create_batches(
self,
spatial_and_temporal_locations_of_each_example: pd.DataFrame,
spatial_and_temporal_locations_of_each_example: List[SpaceTimeLocation],
idx_of_first_batch: int,
batch_size: int,
dst_path: Path,
Expand Down Expand Up @@ -174,16 +174,6 @@ def create_batches(
)
assert upload_every_n_batches >= 0, "`upload_every_n_batches` must be >= 0"

spatial_and_temporal_locations_of_each_example_columns = (
spatial_and_temporal_locations_of_each_example.columns.to_list()
)
assert spatial_and_temporal_locations_of_each_example_columns == list(
SPATIAL_AND_TEMPORAL_LOCATIONS_COLUMN_NAMES
), (
f"The provided data columns {spatial_and_temporal_locations_of_each_example_columns}"
f" do not match {SPATIAL_AND_TEMPORAL_LOCATIONS_COLUMN_NAMES=}"
)

self.open()

# Figure out where to write batches to:
Expand All @@ -200,7 +190,7 @@ def create_batches(
batch_idx=batch_idx, batch_size=batch_size
)

locations_for_batch = spatial_and_temporal_locations_of_each_example.iloc[
locations_for_batch = spatial_and_temporal_locations_of_each_example[
start_example_idx:end_example_idx
]
locations_for_batches.append(locations_for_batch)
Expand All @@ -211,11 +201,7 @@ def create_batches(
logger.debug(f"{self.__class__.__name__} creating batch {batch_idx}!")

# Generate batch.
batch = self.get_batch(
t0_datetimes_utc=locations_for_batch.t0_datetime_UTC,
x_centers_osgb=locations_for_batch.x_center_OSGB,
y_centers_osgb=locations_for_batch.y_center_OSGB,
)
batch = self.get_batch(locations=locations_for_batch)

# Save batch to disk.
batch.save_netcdf(
Expand All @@ -239,43 +225,26 @@ def create_batches(
dst_path=dst_path, local_path=path_to_write_to
)

def get_batch(
self,
t0_datetimes_utc: pd.DatetimeIndex,
x_centers_osgb: Iterable[Number],
y_centers_osgb: Iterable[Number],
) -> DataSourceOutput:
def get_batch(self, locations: List[SpaceTimeLocation]) -> DataSourceOutput:
"""
Get Batch Data

Args:
t0_datetimes_utc: list of timestamps for the datetime of the batches.
The batch will also include data for historic and future depending
on `history_minutes` and `future_minutes`.
The batch size is given by the length of the t0_datetimes.
x_centers_osgb: x center batch locations
y_centers_osgb: y center batch locations
locations: List of locations object
A location object contains
- a timestamp of the example (t0_datetime_utc),
- the x center location of the example (x_location_osgb)
- the y center location of the example(y_location_osgb)

Returns: Batch data.
"""
assert len(t0_datetimes_utc) == len(x_centers_osgb), (
f"len(t0_datetimes) != len(x_locations): "
f"{len(t0_datetimes_utc)} != {len(x_centers_osgb)}"
)
assert len(t0_datetimes_utc) == len(y_centers_osgb), (
f"len(t0_datetimes) != len(y_locations): "
f"{len(t0_datetimes_utc)} != {len(y_centers_osgb)}"
)
zipped = list(zip(t0_datetimes_utc, x_centers_osgb, y_centers_osgb))
batch_size = len(t0_datetimes_utc)

batch_size = len(locations)

with futures.ThreadPoolExecutor(max_workers=batch_size) as executor:
future_examples = []
for coords in zipped:
t0_datetime, x_location, y_location = coords
future_example = executor.submit(
self.get_example, t0_datetime, x_location, y_location
)
for location in locations:
future_example = executor.submit(self.get_example, location)
future_examples.append(future_example)

# Get the examples back. Loop round each future so we can log a helpful error.
Expand Down Expand Up @@ -378,9 +347,7 @@ def _get_time_slice(self, t0_datetime_utc: pd.Timestamp):

def get_example(
self,
t0_datetime_utc: pd.Timestamp, #: Datetime of "now": The most recent obs.
x_center_osgb: Number, #: Centre, in OSGB coordinates.
y_center_osgb: Number, #: Centre, in OSGB coordinates.
location: SpaceTimeLocation, #: Location object of the most recent observation
) -> xr.Dataset:
"""Must be overridden by child classes."""
raise NotImplementedError()
Expand Down Expand Up @@ -452,23 +419,24 @@ def data(self):
raise RuntimeError("Please run `open()` before accessing data!")
return self._data

def get_example(
self, t0_datetime_utc: pd.Timestamp, x_center_osgb: Number, y_center_osgb: Number
) -> xr.Dataset:
def get_example(self, location: SpaceTimeLocation) -> xr.Dataset:
"""
Get Example data

Args:
t0_datetime_utc: list of timestamps for the datetime of the batches.
The batch will also include data for historic and future depending
on `history_minutes` and `future_minutes`.
x_center_osgb: x center batch locations
y_center_osgb: y center batch locations
location: A location object of the example which contains
- a timestamp of the example (t0_datetime_utc),
- the x center location of the example (x_location_osgb)
- the y center location of the example(y_location_osgb)

Returns: Example Data

"""

t0_datetime_utc = location.t0_datetime_utc
x_center_osgb = location.x_center_osgb
y_center_osgb = location.y_center_osgb

logger.debug(
f"Getting example for {t0_datetime_utc=}, " f"{x_center_osgb=} and {y_center_osgb=}"
)
Expand Down
70 changes: 40 additions & 30 deletions nowcasting_dataset/data_sources/fake/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
)
from nowcasting_dataset.data_sources.gsp.eso import get_gsp_metadata_from_eso
from nowcasting_dataset.data_sources.gsp.gsp_model import GSP
from nowcasting_dataset.data_sources.metadata.metadata_model import Metadata
from nowcasting_dataset.data_sources.metadata.metadata_model import Metadata, SpaceTimeLocation
from nowcasting_dataset.data_sources.nwp.nwp_model import NWP
from nowcasting_dataset.data_sources.optical_flow.optical_flow_model import OpticalFlow
from nowcasting_dataset.data_sources.pv.pv_model import PV
Expand Down Expand Up @@ -108,9 +108,9 @@ def gsp_fake(
t0_datetimes_utc = make_t0_datetimes_utc(batch_size)
x_centers_osgb, y_centers_osgb = make_random_x_and_y_osgb_centers(batch_size)
else:
t0_datetimes_utc = metadata.t0_datetime_utc
x_centers_osgb = metadata.x_center_osgb
y_centers_osgb = metadata.y_center_osgb
t0_datetimes_utc = metadata.t0_datetimes_utc
x_centers_osgb = metadata.x_centers_osgb
y_centers_osgb = metadata.y_centers_osgb

# make batch of arrays
xr_datasets = [
Expand Down Expand Up @@ -157,13 +157,17 @@ def metadata_fake(
# choose random index
index = np.random.choice(len(metadata), size=batch_size)

lat = metadata.iloc[index].centroid_lat
lon = metadata.iloc[index].centroid_lon
lat = list(metadata.iloc[index].centroid_lat)
lon = list(metadata.iloc[index].centroid_lon)
ids = list(metadata.iloc[index].index)
id_types = ["gsp"] * batch_size

else:
# get random OSGB center in the UK
lat = np.random.uniform(51, 55, batch_size)
lon = np.random.uniform(-2.5, 1, batch_size)
ids = [None] * batch_size
id_types = [None] * batch_size

x_centers_osgb, y_centers_osgb = lat_lon_to_osgb(lat=lat, lon=lon)

Expand All @@ -172,13 +176,19 @@ def metadata_fake(
batch_size=batch_size, temporally_align_examples=temporally_align_examples
)

metadata_dict = {}
metadata_dict["batch_size"] = batch_size
metadata_dict["x_center_osgb"] = list(x_centers_osgb)
metadata_dict["y_center_osgb"] = list(y_centers_osgb)
metadata_dict["t0_datetime_utc"] = list(t0_datetimes_utc)
# would be good to parrelize this
locations = [
SpaceTimeLocation(
t0_datetime_utc=t0_datetimes_utc[i],
x_center_osgb=x_centers_osgb[i],
y_center_osgb=y_centers_osgb[i],
id=ids[i],
id_type=id_types[i],
)
for i in range(0, batch_size)
]

return Metadata(**metadata_dict)
return Metadata(batch_size=batch_size, space_time_locations=locations)


def nwp_fake(
Expand All @@ -201,9 +211,9 @@ def nwp_fake(
t0_datetimes_utc = make_t0_datetimes_utc(batch_size)
x_centers_osgb, y_centers_osgb = make_random_x_and_y_osgb_centers(batch_size)
else:
t0_datetimes_utc = metadata.t0_datetime_utc
x_centers_osgb = metadata.x_center_osgb
y_centers_osgb = metadata.y_center_osgb
t0_datetimes_utc = metadata.t0_datetimes_utc
x_centers_osgb = metadata.x_centers_osgb
y_centers_osgb = metadata.y_centers_osgb

# make batch of arrays
xr_arrays = [
Expand Down Expand Up @@ -248,9 +258,9 @@ def pv_fake(
t0_datetimes_utc = make_t0_datetimes_utc(batch_size)
x_centers_osgb, y_centers_osgb = make_random_x_and_y_osgb_centers(batch_size)
else:
t0_datetimes_utc = metadata.t0_datetime_utc
x_centers_osgb = metadata.x_center_osgb
y_centers_osgb = metadata.y_center_osgb
t0_datetimes_utc = metadata.t0_datetimes_utc
x_centers_osgb = metadata.x_centers_osgb
y_centers_osgb = metadata.y_centers_osgb

# make batch of arrays
xr_datasets = [
Expand Down Expand Up @@ -296,9 +306,9 @@ def satellite_fake(
t0_datetimes_utc = make_t0_datetimes_utc(batch_size)
x_centers_osgb, y_centers_osgb = make_random_x_and_y_osgb_centers(batch_size)
else:
t0_datetimes_utc = metadata.t0_datetime_utc
x_centers_osgb = metadata.x_center_osgb
y_centers_osgb = metadata.y_center_osgb
t0_datetimes_utc = metadata.t0_datetimes_utc
x_centers_osgb = metadata.x_centers_osgb
y_centers_osgb = metadata.y_centers_osgb

# make batch of arrays
xr_arrays = [
Expand Down Expand Up @@ -340,9 +350,9 @@ def hrv_satellite_fake(
t0_datetimes_utc = make_t0_datetimes_utc(batch_size)
x_centers_osgb, y_centers_osgb = make_random_x_and_y_osgb_centers(batch_size)
else:
t0_datetimes_utc = metadata.t0_datetime_utc
x_centers_osgb = metadata.x_center_osgb
y_centers_osgb = metadata.y_center_osgb
t0_datetimes_utc = metadata.t0_datetimes_utc
x_centers_osgb = metadata.x_centers_osgb
y_centers_osgb = metadata.y_centers_osgb

# make batch of arrays
xr_arrays = [
Expand Down Expand Up @@ -385,9 +395,9 @@ def optical_flow_fake(
t0_datetimes_utc = make_t0_datetimes_utc(batch_size)
x_centers_osgb, y_centers_osgb = make_random_x_and_y_osgb_centers(batch_size)
else:
t0_datetimes_utc = metadata.t0_datetime_utc
x_centers_osgb = metadata.x_center_osgb
y_centers_osgb = metadata.y_center_osgb
t0_datetimes_utc = metadata.t0_datetimes_utc
x_centers_osgb = metadata.x_centers_osgb
y_centers_osgb = metadata.y_centers_osgb

# make batch of arrays
xr_arrays = [
Expand Down Expand Up @@ -421,7 +431,7 @@ def sun_fake(
if metadata is None:
t0_datetimes_utc = make_t0_datetimes_utc(batch_size)
else:
t0_datetimes_utc = metadata.t0_datetime_utc
t0_datetimes_utc = metadata.t0_datetimes_utc

# create dataset with both azimuth and elevation, index with time
# make batch of arrays
Expand All @@ -442,8 +452,8 @@ def topographic_fake(batch_size, image_size_pixels, metadata: Optional[Metadata]
if metadata is None:
x_centers_osgb, y_centers_osgb = make_random_x_and_y_osgb_centers(batch_size)
else:
x_centers_osgb = metadata.x_center_osgb
y_centers_osgb = metadata.y_center_osgb
x_centers_osgb = metadata.x_centers_osgb
y_centers_osgb = metadata.y_centers_osgb

# make batch of arrays
xr_arrays = []
Expand Down
Loading