diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index bcd4456e2..528e56c99 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -51,7 +51,7 @@ jobs:
 
       - name: Test titiler.xarray
         run: |
-          python -m pip install -e src/titiler/xarray["test"]
+          python -m pip install -e src/titiler/xarray["test,all"]
          python -m pytest src/titiler/xarray --cov=titiler.xarray --cov-report=xml --cov-append --cov-report=term-missing
 
       - name: Test titiler.mosaic
diff --git a/src/titiler/xarray/pyproject.toml b/src/titiler/xarray/pyproject.toml
index 78747a86d..123740604 100644
--- a/src/titiler/xarray/pyproject.toml
+++ b/src/titiler/xarray/pyproject.toml
@@ -30,19 +30,25 @@ classifiers = [
 dynamic = ["version"]
 dependencies = [
     "titiler.core==0.19.0.dev",
-    "cftime",
-    "h5netcdf",
     "xarray",
     "rioxarray",
-    "zarr",
     "fsspec",
-    "s3fs",
-    "aiohttp",
-    "pandas",
-    "httpx",
+    "zarr",
+    "h5netcdf",
+    "cftime",
 ]
 
 [project.optional-dependencies]
+s3 = [
+    "s3fs",
+]
+http = [
+    "aiohttp",
+]
+all = [
+    "s3fs",
+    "aiohttp",
+]
 test = [
     "pytest",
     "pytest-cov",
diff --git a/src/titiler/xarray/tests/fixtures/generate_fixtures.ipynb b/src/titiler/xarray/tests/fixtures/generate_fixtures.ipynb
index 52c997147..2b2b733b2 100644
--- a/src/titiler/xarray/tests/fixtures/generate_fixtures.ipynb
+++ b/src/titiler/xarray/tests/fixtures/generate_fixtures.ipynb
@@ -138,6 +138,23 @@
     "    ds.to_zarr(store=f\"pyramid.zarr\", mode=\"w\", group=ix)"
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import json\n",
+    "import fsspec\n",
+    "from kerchunk.hdf import SingleHdf5ToZarr\n",
+    "\n",
+    "with fsspec.open(\"dataset_3d.nc\", mode=\"rb\") as infile:\n",
+    "    h5chunks = SingleHdf5ToZarr(infile, \"dataset_3d.nc\", inline_threshold=100)\n",
+    "\n",
+    "    with open(\"reference.json\", \"w\") as f:\n",
+    "        f.write(json.dumps(h5chunks.translate()))\n"
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": null,
diff --git a/src/titiler/xarray/tests/test_io_tools.py b/src/titiler/xarray/tests/test_io_tools.py
index bb280d71f..918fc709c 100644
--- a/src/titiler/xarray/tests/test_io_tools.py
+++ b/src/titiler/xarray/tests/test_io_tools.py
@@ -109,12 +109,19 @@ def test_get_variable():
 
 @pytest.mark.parametrize(
-    "filename",
-    ["dataset_2d.nc", "dataset_3d.nc", "dataset_3d.zarr"],
+    "protocol,filename",
+    [
+        ("file://", "dataset_2d.nc"),
+        ("file://", "dataset_3d.nc"),
+        ("file://", "dataset_3d.zarr"),
+        ("", "dataset_2d.nc"),
+        ("", "dataset_3d.nc"),
+        ("", "dataset_3d.zarr"),
+    ],
 )
-def test_reader(filename):
+def test_reader(protocol, filename):
     """test reader."""
-    src_path = os.path.join(prefix, filename)
+    src_path = protocol + os.path.join(prefix, filename)
 
     assert Reader.list_variables(src_path) == ["dataset"]
 
     with Reader(src_path, variable="dataset") as src:
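Note: the reader tests above now run each fixture twice, once with an explicit `file://` scheme and once as a bare path, since both must resolve to the same local store. A minimal sketch of the source paths the parametrization produces (the `prefix` value here is a hypothetical stand-in for the tests' fixtures directory):

```python
import os

# Hypothetical stand-in for the tests' fixtures directory.
prefix = "/tmp/fixtures"

for protocol, filename in [("file://", "dataset_3d.zarr"), ("", "dataset_3d.zarr")]:
    # Produces "file:///tmp/fixtures/dataset_3d.zarr" and
    # "/tmp/fixtures/dataset_3d.zarr"; both must open identically.
    src_path = protocol + os.path.join(prefix, filename)
    print(src_path)
```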
diff --git a/src/titiler/xarray/titiler/xarray/dependencies.py b/src/titiler/xarray/titiler/xarray/dependencies.py
index 5c13d9e7b..0b4b87cc0 100644
--- a/src/titiler/xarray/titiler/xarray/dependencies.py
+++ b/src/titiler/xarray/titiler/xarray/dependencies.py
@@ -22,14 +22,6 @@ class XarrayIOParams(DefaultDependency):
         ),
     ] = None
 
-    reference: Annotated[
-        Optional[bool],
-        Query(
-            title="reference",
-            description="Whether the dataset is a kerchunk reference",
-        ),
-    ] = None
-
     decode_times: Annotated[
         Optional[bool],
         Query(
@@ -38,14 +30,6 @@ class XarrayIOParams(DefaultDependency):
         ),
     ] = None
 
-    consolidated: Annotated[
-        Optional[bool],
-        Query(
-            title="consolidated",
-            description="Whether to expect and open zarr store with consolidated metadata",
-        ),
-    ] = None
-
     # cache_client
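With the `reference` and `consolidated` query parameters removed, the dataset flavor is inferred from the source URL itself: the scheme selects the filesystem, and kerchunk references keep working through a `reference://` URL. A minimal sketch mirroring the `parse_protocol` rewrite in `io.py` below (the example URLs are hypothetical):

```python
from urllib.parse import urlparse


def parse_protocol(src_path: str) -> str:
    # Mirrors the rewritten io.py helper: use the URL scheme,
    # defaulting to local files when no scheme is present.
    return urlparse(src_path).scheme or "file"


assert parse_protocol("s3://bucket/store.zarr") == "s3"
assert parse_protocol("https://example.com/dataset_3d.nc") == "https"
assert parse_protocol("reference://tmp/reference.json") == "reference"
assert parse_protocol("/local/path/dataset_3d.zarr") == "file"
```

One caveat: unlike the old `^(s3|https|http)` regex, `urlparse` treats a bare Windows path such as `C:\data.nc` as having scheme `c`, so such paths no longer fall back to `file`.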
""" if protocol == "s3": + assert s3fs is not None, "s3fs must be installed to support S3:// url" + s3_filesystem = s3fs.S3FileSystem() return ( s3_filesystem.open(src_path) @@ -66,11 +69,12 @@ def get_filesystem( else s3fs.S3Map(root=src_path, s3=s3_filesystem) ) - elif protocol == "reference": - reference_args = {"fo": src_path, "remote_options": {"anon": anon}} - return fsspec.filesystem("reference", **reference_args).get_mapper("") - elif protocol in ["https", "http", "file"]: + if protocol.startswith("http"): + assert ( + aiohttp is not None + ), "aiohttp must be installed to support HTTP:// url" + filesystem = fsspec.filesystem(protocol) # type: ignore return ( filesystem.open(src_path) @@ -85,9 +89,7 @@ def get_filesystem( def xarray_open_dataset( src_path: str, group: Optional[Any] = None, - reference: Optional[bool] = False, decode_times: Optional[bool] = True, - consolidated: Optional[bool] = True, cache_client: Optional[CacheClient] = None, ) -> xarray.Dataset: """Open dataset.""" @@ -98,7 +100,7 @@ def xarray_open_dataset( if data_bytes: return pickle.loads(data_bytes) - protocol = parse_protocol(src_path, reference=reference) + protocol = parse_protocol(src_path) xr_engine = xarray_engine(src_path) file_handler = get_filesystem(src_path, protocol, xr_engine) @@ -115,19 +117,26 @@ def xarray_open_dataset( # NetCDF arguments if xr_engine == "h5netcdf": - xr_open_args["engine"] = "h5netcdf" - xr_open_args["lock"] = False - else: - # Zarr arguments - xr_open_args["engine"] = "zarr" - xr_open_args["consolidated"] = consolidated + xr_open_args.update( + { + "engine": "h5netcdf", + "lock": False, + } + ) - # Additional arguments when dealing with a reference file. - if reference: - xr_open_args["consolidated"] = False - xr_open_args["backend_kwargs"] = {"consolidated": False} + ds = xarray.open_dataset(file_handler, **xr_open_args) + + # Fallback to Zarr + else: + if protocol == "reference": + xr_open_args.update( + { + "consolidated": False, + "backend_kwargs": {"consolidated": False}, + } + ) - ds = xarray.open_dataset(file_handler, **xr_open_args) + ds = xarray.open_zarr(file_handler, **xr_open_args) if cache_client: # Serialize the dataset to bytes using pickle @@ -245,9 +254,7 @@ class Reader(XarrayReader): opener: Callable[..., xarray.Dataset] = attr.ib(default=xarray_open_dataset) group: Optional[Any] = attr.ib(default=None) - reference: bool = attr.ib(default=False) decode_times: bool = attr.ib(default=False) - consolidated: Optional[bool] = attr.ib(default=True) cache_client: Optional[CacheClient] = attr.ib(default=None) # xarray.DataArray options @@ -266,9 +273,7 @@ def __attrs_post_init__(self): self.ds = self.opener( self.src_path, group=self.group, - reference=self.reference, decode_times=self.decode_times, - consolidated=self.consolidated, cache_client=self.cache_client, ) @@ -293,14 +298,10 @@ def list_variables( cls, src_path: str, group: Optional[Any] = None, - reference: Optional[bool] = False, - consolidated: Optional[bool] = True, ) -> List[str]: """List available variable in a dataset.""" with xarray_open_dataset( src_path, group=group, - reference=reference, - consolidated=consolidated, ) as ds: return list(ds.data_vars) # type: ignore