Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TYPE: Standardize resampling type #1571

Merged
merged 1 commit into from
Apr 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions datacube/api/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from datacube.storage import reproject_and_fuse, BandInfo
from datacube.utils import ignore_exceptions_if
from odc.geo import CRS, yx_, res_, resyx_, Resolution, XY
from odc.geo.warp import Resampling
from odc.geo.xr import xr_coords
from datacube.utils.dates import normalise_dt
from odc.geo.geom import intersects, box, bbox_union, Geometry
Expand Down Expand Up @@ -244,7 +245,7 @@ def load(self,
measurements: str | list[str] | None = None,
output_crs: Any = None,
resolution: int | float | tuple[int | float, int | float] | Resolution | None = None,
resampling: str | dict[str, str] | None = None,
resampling: Resampling | dict[str, Resampling] | None = None,
align: XY[float] | Iterable[float] | None = None,
skip_broken_datasets: bool = False,
dask_chunks: dict[str, str | int] | None = None,
Expand Down Expand Up @@ -878,7 +879,7 @@ def _cbk(*ignored):
@staticmethod
def load_data(sources: xarray.DataArray, geobox: GeoBox,
measurements: Mapping[str, Measurement] | list[Measurement],
resampling: str | dict[str, str] | None = None,
resampling: Resampling | dict[str, Resampling] | None = None,
fuse_func: FuserFunction | Mapping[str, FuserFunction | None] | None = None,
dask_chunks: dict[str, str | int] | None = None,
skip_broken_datasets: bool = False,
Expand Down Expand Up @@ -969,7 +970,7 @@ def __exit__(self, type_, value, traceback):


def per_band_load_data_settings(measurements: list[Measurement] | Mapping[str, Measurement],
resampling: str | Mapping[str, str] | None = None,
resampling: Resampling | Mapping[str, Resampling] | None = None,
fuse_func: FuserFunction | Mapping[str, FuserFunction | None] | None = None
) -> list[Measurement]:
def with_resampling(m, resampling, default=None):
Expand All @@ -982,7 +983,7 @@ def with_fuser(m, fuser, default=None):
m['fuser'] = fuser.get(m.name, default)
return m

if isinstance(resampling, str):
if resampling is not None and not isinstance(resampling, dict):
resampling = {'*': resampling}

if fuse_func is None or callable(fuse_func):
Expand Down
3 changes: 2 additions & 1 deletion datacube/storage/_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from odc.geo.geobox import GeoBox
from odc.geo.roi import roi_is_empty
from odc.geo.xr import xr_coords
from odc.geo.warp import Resampling
from datacube.model import Measurement
from datacube.drivers._types import ReaderDriver
from ..drivers.datasource import DataSource
Expand All @@ -47,7 +48,7 @@ def reproject_and_fuse(datasources: List[DataSource],
destination: np.ndarray,
dst_geobox: GeoBox,
dst_nodata: Optional[Union[int, float]],
resampling: str = 'nearest',
resampling: Resampling = 'nearest',
fuse_func: Optional[FuserFunction] = None,
skip_broken_datasets: bool = False,
progress_cbk: Optional[ProgressFunction] = None,
Expand Down
10 changes: 7 additions & 3 deletions datacube/utils/cog.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from .io import check_write_path
from odc.geo.geobox import GeoBox
from odc.geo.math import align_up
from odc.geo.warp import Resampling, resampling_s2rio

from deprecat import deprecat

Expand All @@ -38,7 +39,7 @@ def _write_cog(
nodata: Optional[float] = None,
overwrite: bool = False,
blocksize: Optional[int] = None,
overview_resampling: Optional[str] = None,
overview_resampling: Optional[Resampling] = None,
overview_levels: Optional[List[int]] = None,
ovr_blocksize: Optional[int] = None,
use_windowed_writes: bool = False,
Expand Down Expand Up @@ -118,7 +119,10 @@ def _write_cog(
fname, overwrite
) # aborts if overwrite=False and file exists already

resampling = rasterio.enums.Resampling[overview_resampling]
if isinstance(overview_resampling, str):
resampling = resampling_s2rio(overview_resampling)
else:
resampling = overview_resampling

if (blocksize % 16) != 0:
warnings.warn("Block size must be a multiple of 16, will be adjusted")
Expand Down Expand Up @@ -219,7 +223,7 @@ def write_cog(
overwrite: bool = False,
blocksize: Optional[int] = None,
ovr_blocksize: Optional[int] = None,
overview_resampling: Optional[str] = None,
overview_resampling: Optional[Resampling] = None,
overview_levels: Optional[List[int]] = None,
use_windowed_writes: bool = False,
intermediate_compression: Union[bool, str, Dict[str, Any]] = False,
Expand Down
3 changes: 3 additions & 0 deletions docs/about/whats_new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@ What's New
v1.9.next
=========

- Standardize resampling input supported to `odc.geo.warp.Resampling`.


v1.9.0-rc3 (27th March 2024)
============================

Expand Down
28 changes: 20 additions & 8 deletions tests/storage/test_storage_read.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
# SPDX-License-Identifier: Apache-2.0
import numpy as np

import pytest
from rasterio.enums import Resampling

from datacube.storage._read import (
read_time_slice,
read_time_slice_v2,
Expand All @@ -27,14 +30,20 @@
)


nearest_resampling_parametrize = pytest.mark.parametrize(
"nearest_resampling", ['nearest', Resampling.nearest, Resampling.nearest.value]
)


def test_pick_read_scale():
assert pick_read_scale(0.7) == 1
assert pick_read_scale(1.3) == 1
assert pick_read_scale(2.3) == 2
assert pick_read_scale(1.99999) == 2


def test_read_paste(tmpdir):
@nearest_resampling_parametrize
def test_read_paste(nearest_resampling, tmpdir):
from datacube.testutils import mk_test_image
from datacube.testutils.io import write_gtiff
from pathlib import Path
Expand All @@ -46,7 +55,7 @@ def test_read_paste(tmpdir):

mm = write_gtiff(pp/'tst-read-paste-128x64-int16.tif', xx, nodata=None)

def _read(geobox, resampling='nearest',
def _read(geobox, resampling=nearest_resampling,
fallback_nodata=-999,
dst_nodata=-999,
check_paste=False):
Expand Down Expand Up @@ -112,7 +121,8 @@ def _read(geobox, resampling='nearest',
np.testing.assert_array_equal(xx[1::2, 1::2], yy)


def test_read_with_reproject(tmpdir):
@nearest_resampling_parametrize
def test_read_with_reproject(nearest_resampling, tmpdir):
from datacube.testutils import mk_test_image
from datacube.testutils.io import write_gtiff
from pathlib import Path
Expand All @@ -131,7 +141,7 @@ def test_read_with_reproject(tmpdir):
assert mm.geobox == tile

def _read(geobox,
resampling='nearest',
resampling=nearest_resampling,
fallback_nodata=None,
dst_nodata=-999):
with RasterFileDataSource(mm.path, 1, nodata=fallback_nodata).open() as rdr:
Expand Down Expand Up @@ -171,7 +181,8 @@ def _read(geobox,
assert nvalid > nempty


def test_read_paste_v2(tmpdir):
@nearest_resampling_parametrize
def test_read_paste_v2(nearest_resampling, tmpdir):
from datacube.testutils import mk_test_image
from datacube.testutils.io import write_gtiff
from datacube.testutils.iodriver import open_reader
Expand All @@ -184,7 +195,7 @@ def test_read_paste_v2(tmpdir):

mm = write_gtiff(pp/'tst-read-paste-128x64-int16.tif', xx, nodata=None)

def _read(geobox, resampling='nearest',
def _read(geobox, resampling=nearest_resampling,
fallback_nodata=-999,
dst_nodata=-999,
check_paste=False):
Expand Down Expand Up @@ -256,7 +267,8 @@ def _read(geobox, resampling='nearest',
np.testing.assert_array_equal(xx[1::2, 1::2], yy)


def test_read_with_reproject_v2(tmpdir):
@nearest_resampling_parametrize
def test_read_with_reproject_v2(nearest_resampling, tmpdir):
from datacube.testutils import mk_test_image
from datacube.testutils.io import write_gtiff
from datacube.testutils.iodriver import open_reader
Expand All @@ -268,7 +280,7 @@ def test_read_with_reproject_v2(tmpdir):
assert (xx != -999).all()
tile = AlbersGS.tile_geobox((17, -40))[:64, :128]

def _read(geobox, resampling='nearest',
def _read(geobox, resampling=nearest_resampling,
fallback_nodata=-999,
dst_nodata=-999):

Expand Down
Loading