diff --git a/.github/workflows/tests-conda.yml b/.github/workflows/tests-conda.yml index 743ff98d..9574a073 100644 --- a/.github/workflows/tests-conda.yml +++ b/.github/workflows/tests-conda.yml @@ -97,7 +97,7 @@ jobs: - name: INSTALL - Project run: | - pip install --editable=. + pip install --editable=.[indexing] - name: Run tests env: diff --git a/.github/workflows/tests-python.yml b/.github/workflows/tests-python.yml index 830d25ec..0718242f 100644 --- a/.github/workflows/tests-python.yml +++ b/.github/workflows/tests-python.yml @@ -67,7 +67,7 @@ jobs: - name: Install project run: | pip3 install --requirement=requirements-test.txt - pip3 install --editable=. + pip3 install --editable=.[indexing] - name: Run tests env: diff --git a/.gitignore b/.gitignore index ec33450d..306c7ebd 100644 --- a/.gitignore +++ b/.gitignore @@ -12,6 +12,7 @@ core.* *.idx *.grib2 +*.iarr !sample_data/hrrr/20201214/subset_20201214_hrrr.t00z.wrfsfcf12.grib2 .idea diff --git a/herbie/index/__init__.py b/herbie/index/__init__.py new file mode 100644 index 00000000..7c406231 --- /dev/null +++ b/herbie/index/__init__.py @@ -0,0 +1,6 @@ +# MIT License +# (c) 2023 Andreas Motl +# https://github.com/earthobservations +from herbie.index.monkey import monkeypatch_iarray + +monkeypatch_iarray() diff --git a/herbie/index/core.py b/herbie/index/core.py new file mode 100644 index 00000000..bef42bfe --- /dev/null +++ b/herbie/index/core.py @@ -0,0 +1,340 @@ +# MIT License +# (c) 2023 Andreas Motl +# https://github.com/earthobservations +import dataclasses +import logging +import os.path +import typing as t +from pathlib import Path + +import iarray_community as ia +import numpy as np +import pandas as pd +import shapely +import xarray as xr +from ndindex import Slice +from scipy.constants import convert_temperature +from shapely.geometry import CAP_STYLE, Point, Polygon + +from herbie.index.model import BBox, Circle, DataSchema, QueryParameter +from herbie.index.util import ( + dataset_get_data_variable_names, + dataset_info, + is_sequence, + round_clipped, + unit, +) + +logger = logging.getLogger(__name__) + + +class NwpIndex: + """ + Manage a multidimensional index of NWP data, using Caterva and ironArray. + + - https://caterva.readthedocs.io/ + - https://ironarray.io/docs/html/ + + TODO: Think about making this an xarray accessor, e.g. `ds.xindex`. + + - https://docs.xarray.dev/en/stable/internals/extending-xarray.html + """ + + # Where the ironArray files (`.iarr`) will be stored. + # FIXME: Segfaults when path contains spaces. => Report to ironArray fame. + # `/Users/amo/Library/Application Support/herbie/index-iarray/precipitation_amount_1hour_Accumulation.iarr` + # BASEDIR = platformdirs.user_data_path("herbie").joinpath("index-iarray") + + # Alternatively, just use the working directory for now. + BASEDIR = Path(os.path.curdir) + + # Configure ironArray. + IA_CONFIG = dict( + codec=ia.Codec.ZSTD, + clevel=1, + # How to choose the best numbers? + # https://ironarray.io/docs/html/tutorials/03.Slicing_Datasets_and_Creating_Views.html#Optimization-Tips + chunks=(360, 360, 720), + blocks=(180, 180, 360), + # chunks=(360, 128, 1440), + # blocks=(8, 8, 720), + # TODO: Does it really work? + # nthreads=12, + ) + + def __init__(self, name, resolution=None, schema=None, dataset=None, irondata=None): + self.name: str = name + self._resolution: float = resolution + self.dataset: xr.Dataset = dataset + self.irondata: ia.IArray = irondata + + self.path = self.BASEDIR.joinpath(self.name).with_suffix(".iarr") + self.schema: DataSchema = schema or DataSchema(path=self.path) + + def exists(self): + return self.path.exists() + + @property + def resolution(self): + if self._resolution: + return self._resolution + elif self.schema.ds is not None: + return self.schema.get_resolution() + else: + raise ValueError("Resolution is required for querying the Dataset by geospatial coordinates") + + @resolution.setter + def resolution(self, value): + self._resolution = value + + def load(self): + """ + Load data from ironArray file. + """ + + # Load data. + # TODO: Handle multiple variable names. + self.irondata: ia.IArray = ia.open(str(self.path)) + logger.info(f"Loaded IArray from: {self.path}") + logger.debug(f"IArray info:\n{self.irondata.info}") + + # Load schema. + self.schema.load() + + return self + + def save(self, dataset: xr.Dataset): + """ + Save data to ironArray file, effectively indexing it on all dimensions. + + Derived from ironArray's `fetch_data.py` example program [1,2], + and its documentation about "Configuring ironArray" [3]. + + [1] https://github.com/ironArray/iron-array-notebooks/blob/76fe0e9f93a75443e3aed73a9ffc36119d4aad6c/tutorials/fetch_data.py#L11-L18 + [2] https://github.com/ironArray/iron-array-notebooks/blob/76fe0e9f93a75443e3aed73a9ffc36119d4aad6c/tutorials/fetch_data.py#L37-L41 + [3] https://ironarray.io/docs/html/tutorials/02.Configuring_ironArray.html + """ + + # Use data from first data variable within dataset. + # TODO: Handle multiple variable names. + data_variables = dataset_get_data_variable_names(dataset) + data_variable = data_variables[0] + logger.info(f"Discovered dataset variable: {data_variable}") + logger.info(f"Storing and indexing to: {self.path}") + logger.debug(f"Dataset info:\n{dataset_info(dataset)}") + + data = dataset[data_variable] + logger.info( + f"Data variable '{data_variable}' has shape={data.shape} and dtype={data.dtype}" + ) + with ia.config(**self.IA_CONFIG): + ia_data = ia.empty( + shape=data.shape, dtype=data.dtype, urlpath=str(self.path) + ) + logger.info("Populating IArray") + ia_data[:] = data.values + logger.info(f"IArray is ready") + logger.debug(f"IArray info:\n{ia_data.info}") + self.irondata = ia_data + + # Save schema. + self.schema.save(ds=dataset) + + def query(self, time=None, location: t.Union[BBox, Circle] = None, lat=None, lon=None) -> "Result": + """ + Query ironArray by multiple dimensions. + """ + + if location is not None: + + # Select location by circle (point and distance). + if isinstance(location, Circle): + circle: Circle = location + # At 38 degrees North latitude (which passes through Stockton California + # and Charlottesville Virginia), one degree of longitude equals 54.6 miles. + # => 0.25 degrees equal 13.65 miles. + # + # -- https://www.usgs.gov/faqs/how-much-distance-does-degree-minute-and-second-cover-your-maps + # + # FIXME: Verify this, and apply the correct conversion for other places on earth. + factor = 54.6 * self.resolution + distance = (circle.distance / (factor * unit.miles)).magnitude + + # Compute minimum bounding rectangle from circle. + point = Point([circle.point.longitude, circle.point.latitude]) \ + .buffer(distance, cap_style=CAP_STYLE.square) + bbox: Polygon = point.minimum_rotated_rectangle + location = BBox(*bbox.bounds) + + # Select location by bounding box. + # https://boundingbox.klokantech.com/ + if isinstance(location, BBox): + lat = [location.lat1, location.lat2] + lon = [location.lon1, location.lon2] + + else: + raise ValueError(f"Unable to process location={location}, type={type(location)}") + + # Compute slices for time or time range, and geolocation point or range (bbox). + time_slice = self.time_slice(coordinate="time", value=time) + lat_slice = self.geo_slice(coordinate="lat", value=lat) + lon_slice = self.geo_slice(coordinate="lon", value=lon) + + # Slice data. + data = self.irondata[time_slice, lat_slice, lon_slice] + + # Rebuild Dataset from result. + coords = { + "time": self.schema.ds.coords["time"][time_slice.start: time_slice.stop], + "lat": self.schema.ds.coords["lat"][lat_slice.start: lat_slice.stop], + "lon": self.schema.ds.coords["lon"][lon_slice.start: lon_slice.stop], + } + ds = self.to_dataset(data, coords=coords) + + return Result(qp=QueryParameter(time=time, lat=lat, lon=lon), ds=ds) + + def to_dataset(self, irondata, coords): + """ + Re-create Xarray Dataset from ironArray data and coordinates. + + The intention is to emit a Dataset which has the same character + as the Dataset originally loaded from GRIB/netCDF/HDF5/Zarr. + """ + + # Re-create empty Dataset with original shape. + schema = self.schema.ds.copy(deep=True) + dataset = xr.Dataset(data_vars=schema.data_vars, coords=coords, attrs=schema.attrs) + + # Populate data. + # TODO: Handle more than one variable. + # TODO: Is there a faster operation than using `list(irondata)`? + variable0_info = self.schema.metadata["variables"][0] + variable0_name = variable0_info["name"] + dataset[variable0_name] = xr.DataArray(list(irondata), **variable0_info) + + return dataset + + def geo_slice(self, coordinate: str, value: t.Union[float, t.Sequence, np.ndarray]): + """ + Compute slice for geolocation point or range (bbox). + """ + + coord = self.schema.ds.coords[coordinate] + + if value is None: + idx = np.where(coord)[0] + effective_slice = Slice(start=idx[0], stop=idx[-1] + 1) + elif isinstance(value, float): + idx = np.where(coord == self.round_location(value))[0][0] + effective_slice = Slice(start=idx, stop=idx + 2) + elif isinstance(value, (t.Sequence, np.ndarray)): + value = sorted(value) + idx = np.where( + np.logical_and( + coord >= self.round_location(value[0]), + coord <= self.round_location(value[1]), + ) + )[0] + effective_slice = Slice(start=idx[0], stop=idx[-1] + 1) + else: + raise ValueError( + f"Unable to process value for {coordinate}={value}, type={type(value)}" + ) + + return effective_slice + + def time_slice( + self, coordinate: str, value: t.Union[float, t.Sequence, np.ndarray] + ): + """ + Compute slice for time or time range. + """ + + coord = self.schema.ds.coords[coordinate] + + if value is None: + idx = np.where(coord)[0] + effective_slice = Slice(idx[0], idx[-1] + 1) + elif isinstance(value, str): + idx = np.where(coord == np.datetime64(value))[0][0] + effective_slice = Slice(idx, idx + 2) + elif isinstance(value, (t.Sequence, np.ndarray, pd.DatetimeIndex)): + idx = np.where( + np.logical_and( + coord >= np.datetime64(value[0]), + coord <= np.datetime64(value[1]), + ) + )[0] + effective_slice = Slice(start=idx[0], stop=idx[-1] + 1) + else: + raise ValueError( + f"Unable to process value for {coordinate}={value}, type={type(value)}" + ) + + return effective_slice + + def round_location(self, value): + return round_clipped(value, self.resolution) + + +@dataclasses.dataclass +class Result: + """ + Wrap query result, and provide convenience accessor methods and value converters. + """ + + qp: QueryParameter + ds: xr.Dataset + + @property + def pv(self): + """ + Return primary variable name. That is, the first one. + + # TODO: Handle multiple variable names. + """ + return list(self.ds.data_vars.keys())[0] + + def select_first(self) -> xr.DataArray: + return self.ds[self.pv][0][0][0] + + def select_first_point(self): + da = self.ds[self.pv] + return da.sel(lat=da["lat"][0], lon=da["lon"][0]) + + def select_first_timestamp(self): + da = self.ds[self.pv] + return da.sel(time=da["time"][0]) + + def kelvin_to_celsius(self): + da = self.ds[self.pv] + da.values = convert_temperature(da.values, "Kelvin", "Celsius") + return self + + def kelvin_to_fahrenheit(self): + da = self.ds[self.pv] + da.values = convert_temperature(da.values, "Kelvin", "Fahrenheit") + return self + + @property + def data(self) -> xr.DataArray: + """ + Auto-select shape of return value, based on the shape of the query parameters. + """ + all_defined = all( + v is not None for v in [self.qp.time, self.qp.lat, self.qp.lon] + ) + is_time_range = is_sequence(self.qp.time) + is_lat_range = is_sequence(self.qp.lat) + is_lon_range = is_sequence(self.qp.lon) + if all_defined and not any([is_time_range, is_lat_range, is_lon_range]): + return self.select_first() + elif not any([is_lat_range, is_lon_range]): + return self.select_first_point() + elif self.qp.time and not is_time_range: + return self.select_first_timestamp() + else: + raise ValueError( + f"Unable to auto-select shape of return value, " + f"query parameters have unknown shape: {self.qp}" + ) diff --git a/herbie/index/loader.py b/herbie/index/loader.py new file mode 100644 index 00000000..cf6e4f34 --- /dev/null +++ b/herbie/index/loader.py @@ -0,0 +1,60 @@ +# MIT License +# (c) 2023 Andreas Motl +# https://github.com/earthobservations +import logging + +import fsspec +import numpy as np +import platformdirs +import s3fs +import xarray as xr + +logger = logging.getLogger(__name__) + + +CACHE_BASEDIR = platformdirs.user_cache_path("herbie").joinpath("index-download") + + +def open_era5_zarr(parameter, year, month, datestart=None, dateend=None) -> xr.Dataset: + """ + Load "ERA5 forecasts reanalysis" data from ECMWF, using Zarr. + The ERA5 HRES atmospheric data has a resolution of 31km, 0.28125 degrees [1]. + + The implementation is derived from ironArray's "Slicing Datasets and Creating + Views" documentation [2]. For processing data more efficiently, downloaded data + is cached locally, using fsspec's "filecache" filesystem [3]. + + [1] https://confluence.ecmwf.int/display/CKB/ERA5%3A+data+documentation#heading-Spatialgrid + [2] https://ironarray.io/docs/html/tutorials/03.Slicing_Datasets_and_Creating_Views.html + [3] https://filesystem-spec.readthedocs.io/en/latest/features.html#caching-files-locally + """ + location = f"era5-pds/zarr/{year}/{month:02d}/data/{parameter}.zarr/" + logger.info(f"Loading NWP data from {location}") + logger.info(f"Using local cache at {CACHE_BASEDIR}") + + # ERA5 is on AWS S3, it can be accessed anonymously. + fs = s3fs.S3FileSystem(anon=True) + + # Add local cache, using fsspec fame. + fs = fsspec.filesystem("filecache", cache_storage=str(CACHE_BASEDIR), fs=fs) + + # Access resource in Zarr format. + # Possible engines: ['scipy', 'cfgrib', 'gini', 'store', 'zarr'] + s3map = s3fs.S3Map(location, s3=fs) + ds = xr.open_dataset(s3map, engine="zarr") + + # The name of the `time` coordinate may be different between datasets. + time_field_candidates = ["time0", "time1"] + for candidate in time_field_candidates: + if candidate in ds.coords: + ds = ds.rename({candidate: "time"}) + + # Select subset of data based on time range. + if datestart and dateend: + indexers = {"time": slice(np.datetime64(datestart), np.datetime64(dateend))} + ds = ds.sel(indexers=indexers) + + # Rearrange coordinates data from longitude 0 to 360 degrees (long3) to -180 to 180 degrees (long1). + ds = ds.assign(lon=ds["lon"] - 180) + + return ds diff --git a/herbie/index/model.py b/herbie/index/model.py new file mode 100644 index 00000000..b8b7b50d --- /dev/null +++ b/herbie/index/model.py @@ -0,0 +1,136 @@ +# MIT License +# (c) 2023 Andreas Motl +# https://github.com/earthobservations +import dataclasses +import json +import typing as t +from pathlib import Path + +import xarray as xr +from pint import Quantity + +from herbie.index.util import dataset_get_data_variable_names, dataset_without_data + + +@dataclasses.dataclass +class DataSchema: + """ + Manage saving and loading an Xarray Dataset schema in netCDF format. + + That means, on saving, all data variables are dropped, but metadata + information about them is stored alongside the data. This information + is reused when re-creating the Dataset in the same shape. + """ + + path: Path + ds: xr.Dataset = None + metadata: t.Dict = None + nc_file: Path = dataclasses.field(init=False) + json_file: Path = dataclasses.field(init=False) + + def __post_init__(self): + self.nc_file = self.path.joinpath("schema.nc") + self.json_file = self.path.joinpath("schema.json") + + def load(self): + """ + Load metadata information for Dataset from netCDF file. + """ + self.ds = xr.load_dataset(self.nc_file) + with open(self.json_file, "r") as fp: + self.metadata = json.load(fp) + + def save(self, ds: xr.Dataset): + """ + Strip data off Dataset, and save its metadata information into netCDF file. + """ + + self.ds = dataset_without_data(ds) + self.metadata = self.get_metadata(ds) + + self.ds.to_netcdf(self.nc_file) + with open(self.json_file, "w") as fp: + json.dump(self.metadata, fp, indent=2) + + @staticmethod + def get_metadata(ds: xr.Dataset): + """ + Get metadata from Dataset. + + This metadata is needed in order to save it for reconstructing the + complete Dataset later. + """ + result = [] + data_variables = dataset_get_data_variable_names(ds) + for variable in data_variables: + da: xr.DataArray = ds[variable] + item = { + "name": da.name, + "attrs": dict(da.attrs), + "dims": list(da.dims), + } + result.append(item) + return {"variables": result} + + def get_resolution(self): + """ + Derive resolution of grid from coordinates. + """ + lat_coord = self.ds.coords["lat"] + lon_coord = self.ds.coords["lon"] + lat_delta = lat_coord[1].values - lat_coord[0].values + lon_delta = lon_coord[1].values - lon_coord[0].values + if abs(lat_delta) == abs(lon_delta): + return abs(lat_delta) + else: + raise ValueError( + "Resolution computed from coordinates deviates between latitude and longitude" + ) + + +@dataclasses.dataclass +class QueryParameter: + """ + Manage query parameters. + """ + + time: t.Optional[str] = None + lat: t.Optional[float] = None + lon: t.Optional[float] = None + + +@dataclasses.dataclass +class Point: + """ + Manage geopoint information. + """ + + longitude: float + latitude: float + + +@dataclasses.dataclass +class Circle: + """ + Manage geolocation circle information. + + Radius in kilometers. + """ + + point: Point + distance: Quantity + + +@dataclasses.dataclass +class BBox: + """ + Manage bounding box information. + + # min_x, min_y, max_x, max_y + # (lon1, lat1, lon2, lat2) = c.bbox + """ + + lon1: float + lat1: float + lon2: float + lat2: float diff --git a/herbie/index/monkey.py b/herbie/index/monkey.py new file mode 100644 index 00000000..5363fec8 --- /dev/null +++ b/herbie/index/monkey.py @@ -0,0 +1,25 @@ +# MIT License +# (c) 2023 Andreas Motl +# https://github.com/earthobservations +from iarray_community import IArray + +iarray_info_items_original = IArray.info_items + + +@property +def iarray_info_items(self): + """ + Just a minor patch for ironArray to extend info output. + + TODO: Submit patch to upstream repository. + https://github.com/ironArray/iarray-community + """ + items = iarray_info_items_original.fget(self) + items += [("codec", self.codec)] + items += [("clevel", self.clevel)] + items += [("size", self.size)] + return items + + +def monkeypatch_iarray(): + IArray.info_items = iarray_info_items diff --git a/herbie/index/util.py b/herbie/index/util.py new file mode 100644 index 00000000..1d5f39ac --- /dev/null +++ b/herbie/index/util.py @@ -0,0 +1,51 @@ +# MIT License +# (c) 2023 Andreas Motl +# https://github.com/earthobservations +import io +import logging +import sys +import typing as t + +import numpy as np +import pint +import xarray as xr + +unit = pint.UnitRegistry() + + +def round_clipped(value, clipping): + """ + https://stackoverflow.com/a/7859208 + :param value: + :param clipping: + :return: + """ + return round(float(value) / clipping) * clipping + + +def setup_logging(level=logging.INFO): + log_format = "%(asctime)-15s [%(name)-20s] %(levelname)-7s: %(message)s" + logging.basicConfig(format=log_format, stream=sys.stderr, level=level) + + requests_log = logging.getLogger("botocore") + requests_log.setLevel(logging.INFO) + + +def dataset_info(ds: xr.Dataset) -> str: + buf = io.StringIO() + ds.info(buf) + buf.seek(0) + return buf.read() + + +def is_sequence(value): + return not isinstance(value, str) and isinstance(value, (t.Sequence, np.ndarray)) + + +def dataset_get_data_variable_names(ds: xr.Dataset): + return list(ds.data_vars.keys()) + + +def dataset_without_data(ds: xr.Dataset): + data_variables = dataset_get_data_variable_names(ds) + return ds.drop_vars(names=data_variables) diff --git a/setup.cfg b/setup.cfg index c0930423..3de9f588 100644 --- a/setup.cfg +++ b/setup.cfg @@ -67,6 +67,11 @@ docs = sphinx-design sphinx-markdown-tables sphinxcontrib-mermaid +indexing = + iarray-community + s3fs + platformdirs + zarr #tests = # pytest diff --git a/tests/test_index_era5.py b/tests/test_index_era5.py new file mode 100644 index 00000000..c4d427de --- /dev/null +++ b/tests/test_index_era5.py @@ -0,0 +1,386 @@ +# MIT License +# (c) 2023 Andreas Motl +# https://github.com/earthobservations +import datetime +from unittest import mock + +import numpy as np +import pandas as pd +import pytest +import xarray as xr +from xarray.testing import assert_equal + +from herbie.index.core import NwpIndex +from herbie.index.loader import open_era5_zarr +from herbie.index.model import BBox, Circle, Point +from herbie.index.util import unit + +TEMP2M = "air_temperature_at_2_metres" + +TIMERANGE = np.arange( + start=np.datetime64("1987-10-01 08:00"), + stop=np.datetime64("1987-10-01 10:59"), + step=datetime.timedelta(hours=1), +) + + +@pytest.fixture +def era5_temp2m(): + """ + Provide an instance of `NwpIndex` to the test cases. + """ + nwp = NwpIndex(name=TEMP2M) + if not nwp.exists(): + nwp.save(dataset=open_era5_zarr(TEMP2M, 1987, 10, TIMERANGE[0], TIMERANGE[-1])) + nwp.load() + return nwp + + +def test_query_era5_monterey_fahrenheit_point_time(era5_temp2m): + """ + Query indexed ERA5 NWP data for a specific geopoint and time. + """ + + # Temperatures in Monterey, in Fahrenheit. + item = ( + era5_temp2m.query(time="1987-10-01 08:00", lat=36.6083, lon=-121.8674) + .kelvin_to_fahrenheit() + .data + ) + + # Verify values. + assert item.values == np.array(73.805008, dtype=np.float32) + + # Verify coordinate. + assert dict(item.coords) == dict( + time=xr.DataArray(data=np.datetime64("1987-10-01 08:00"), name="time"), + lat=xr.DataArray(data=np.float32(36.5), name="lat"), + lon=xr.DataArray(data=np.float32(-121.75), name="lon"), + ) + + +def test_query_era5_berlin_celsius_point_timerange(era5_temp2m): + """ + Query indexed ERA5 NWP data for the whole time range at a specific geopoint. + """ + + # Temperatures in Berlin, in Celsius. + result = era5_temp2m.query(lat=52.51074, lon=13.43506).kelvin_to_celsius().data + assert len(result.data) == 3 + assert result.shape == (3,) + + # Verify values and coordinates. + reference = xr.DataArray( + data=np.array([6.600006, 6.600006, 6.600006], dtype=np.float32), + coords=dict( + time=xr.DataArray(data=TIMERANGE), + lat=xr.DataArray(data=np.float32(52.5)), + lon=xr.DataArray(data=np.float32(13.5)), + ), + ) + reference = reference.swap_dims(dim_0="time") + assert_equal(result, reference) + + +def test_query_era5_bbox_time(era5_temp2m): + """ + Query indexed ERA5 NWP data for a given area, defined by a bounding box. + + http://bboxfinder.com/ + """ + + # Temperatures in Monterey area, in Fahrenheit. + result = ( + era5_temp2m.query( + time="1987-10-01 08:00", + lat=(36.450837, 36.700907), + lon=(-122.166252, -121.655045), + ) + .kelvin_to_fahrenheit() + .data + ) + assert len(result.data) == 2 + assert result.shape == (2, 3) + + # Verify values and coordinates. + reference = xr.DataArray( + data=np.array( + [[73.58001, 71.89251, 70.88001], [74.93001, 75.717514, 73.80501]], + dtype=np.float32, + ), + dims=("lat", "lon"), + coords=dict( + time=xr.DataArray(data=np.datetime64("1987-10-01 08:00")), + lat=xr.DataArray( + data=np.array([36.75, 36.5], dtype=np.float32), dims=("lat",) + ), + lon=xr.DataArray( + data=np.array([-122.25, -122.0, -121.75], dtype=np.float32), + dims=("lon",), + ), + ), + ) + assert_equal(result, reference) + + +def test_query_era5_geoslice_time(era5_temp2m): + """ + Query indexed ERA5 NWP data for a given slice on the latitude coordinate, + along the same longitude coordinates. + """ + + # Temperatures for whole slice. + result = ( + era5_temp2m.query( + time="1987-10-01 08:00", lat=None, lon=(-122.166252, -121.655045) + ) + .kelvin_to_celsius() + .data + ) + assert len(result.data) == 721 + assert result.shape == (721, 3) + + # Verify coordinates. + reference = xr.DataArray( + data=mock.ANY, + dims=("lat", "lon"), + coords=dict( + time=xr.DataArray(data=np.datetime64("1987-10-01 08:00")), + lat=xr.DataArray( + data=np.arange(start=90.0, stop=-90.01, step=-0.25, dtype=np.float32), + dims=("lat",), + ), + lon=xr.DataArray( + data=np.array([-122.25, -122.0, -121.75], dtype=np.float32), + dims=("lon",), + ), + ), + ) + assert_equal(result, reference) + + # Verify values of first and last record, and its coordinate. + assert result[0].values.tolist() == [ + -21.587493896484375, + -21.587493896484375, + -21.587493896484375, + ] + assert result[0].coords["lat"] == 90 + + assert result[-1].values.tolist() == [ + -43.837493896484375, + -43.837493896484375, + -43.837493896484375, + ] + assert result[-1].coords["lat"] == -90.0 + + +def test_query_era5_point_timerange_tuple(era5_temp2m): + """ + Query indexed ERA5 NWP data within given time range. + This variant uses a `tuple` for defining time range boundaries. + + While the input dataset contains three records, filtering by + time range should only yield two records. + """ + + # Temperatures for whole slice. + result = ( + era5_temp2m.query( + time=(np.datetime64("1987-10-01 08:00"), np.datetime64("1987-10-01 09:05")), + lat=52.51074, + lon=13.43506, + ) + .kelvin_to_celsius() + .data + ) + assert len(result.data) == 2 + assert result.shape == (2,) + + # Verify values and coordinates. + timerange = np.arange( + start=np.datetime64("1987-10-01 08:00"), + stop=np.datetime64("1987-10-01 09:01"), + step=datetime.timedelta(hours=1), + ) + reference = xr.DataArray( + data=np.array([6.600006, 6.600006], dtype=np.float32), + coords=dict( + time=xr.DataArray(data=timerange), + lat=xr.DataArray(data=np.float32(52.5)), + lon=xr.DataArray(data=np.float32(13.5)), + ), + ) + reference = reference.swap_dims(dim_0="time") + assert_equal(result, reference) + + +def test_query_era5_point_timerange_numpy(era5_temp2m): + """ + Query indexed ERA5 NWP data at a specific point within given time range. + This variant uses an `np.array` for defining time range boundaries. + + While the input dataset contains three records, filtering by + time range should only yield two records. + """ + + # Define timerange used for querying. + timerange = np.arange( + start=np.datetime64("1987-10-01 08:00"), + stop=np.datetime64("1987-10-01 09:01"), + step=datetime.timedelta(hours=1), + ) + + # Temperatures for whole slice. + result = ( + era5_temp2m.query(time=timerange, lat=52.51074, lon=13.43506) + .kelvin_to_celsius() + .data + ) + assert len(result.data) == 2 + assert result.shape == (2,) + + # Verify values and coordinates. + reference = xr.DataArray( + data=np.array([6.600006, 6.600006], dtype=np.float32), + coords=dict( + time=xr.DataArray(data=timerange), + lat=xr.DataArray(data=np.float32(52.5)), + lon=xr.DataArray(data=np.float32(13.5)), + ), + ) + reference = reference.swap_dims(dim_0="time") + assert_equal(result, reference) + + +def test_query_era5_bbox_timerange(era5_temp2m): + """ + Query indexed ERA5 NWP data within a given bounding box area and time range. + + This variant uses a pandas `DatetimeIndex` for defining the time range + boundaries, and a `BBox` instance for defining a geospatial bounding box. + """ + + data_var = "air_temperature_at_2_metres" + + # Temperatures in Berlin area, in Celsius. + ds = ( + era5_temp2m.query( + time=pd.date_range( + start="1987-10-01 08:00", end="1987-10-01 09:00", freq="H" + ), + location=BBox(lon1=13.000, lat1=52.700, lon2=13.600, lat2=52.300), + ) + .kelvin_to_celsius() + .ds + ) + assert len(ds) == 1 + assert ds[data_var].shape == (2, 3, 3) + assert ds[data_var].dims == ("time", "lat", "lon") + + # Verify values and coordinates. + reference = xr.DataArray( + dims=("time", "lat", "lon"), + data=np.array( + [ + [ + [6.412506, 6.412506, 6.475006], + [6.537506, 6.537506, 6.600006], + [6.662506, 6.662506, 6.725006], + ], + [ + [6.350006, 6.412506, 6.412506], + [6.537506, 6.537506, 6.600006], + [6.662506, 6.662506, 6.662506], + ], + ], + dtype=np.float32, + ), + coords=dict( + time=xr.DataArray( + data=np.arange( + start=np.datetime64("1987-10-01 08:00:00"), + stop=np.datetime64("1987-10-01 09:00:01"), + step=datetime.timedelta(hours=1), + ), + name="time", + dims=("time",), + ), + lat=xr.DataArray( + data=np.array([52.75, 52.5, 52.25], dtype=np.float32), dims=("lat",) + ), + lon=xr.DataArray( + data=np.array([13.0, 13.25, 13.5], dtype=np.float32), dims=("lon",) + ), + ), + ) + assert_equal(ds[data_var], reference) + + +def test_query_era5_circle_timerange(era5_temp2m): + """ + Query indexed ERA5 NWP data within a given circular area and time range. + + This variant uses a pandas `DatetimeIndex` for defining the time range + boundaries, and a `Circle` instance for defining a geospatial bounding box. + """ + + data_var = "air_temperature_at_2_metres" + + # Temperatures in Monterey area, in Fahrenheit. + ds = ( + era5_temp2m.query( + time=pd.date_range( + start="1987-10-01 08:00", end="1987-10-01 09:00", freq="H" + ), + location=Circle( + Point(longitude=-121.8674, latitude=36.6083), distance=3.5 * unit.miles + ), + ) + .kelvin_to_fahrenheit() + .ds + ) + assert len(ds) == 1 + assert ds[data_var].shape == (2, 3, 3) + assert ds[data_var].dims == ("time", "lat", "lon") + + # Verify values and coordinates. + reference = xr.DataArray( + dims=("time", "lat", "lon"), + data=np.array( + [ + [ + [71.89251, 70.88001, 69.64251], + [75.717514, 73.80501, 72.11751], + [78.530014, 77.29251, 76.280014], + ], + [ + [72.23001, 71.217514, 69.98001], + [76.61751, 74.48001, 72.68001], + [79.76751, 78.41751, 77.18001], + ], + ], + dtype=np.float32, + ), + coords=dict( + time=xr.DataArray( + data=np.arange( + start=np.datetime64("1987-10-01 08:00:00"), + stop=np.datetime64("1987-10-01 09:00:00.001"), + step=datetime.timedelta(hours=1), + ), + name="time", + dims=("time",), + ), + lat=xr.DataArray( + data=np.array([36.75, 36.5, 36.25], dtype=np.float32), + name="lat", + dims=("lat",), + ), + lon=xr.DataArray( + data=np.array([-122.0, -121.75, -121.5], dtype=np.float32), + name="lon", + dims=("lon",), + ), + ), + ) + assert_equal(ds[data_var], reference)