Skip to content

Commit

Permalink
refactor: change Dask reader API again
Browse files Browse the repository at this point in the history
Move `cfg=` into `.open` rather `.read`
  • Loading branch information
Kirill888 committed Jun 28, 2024
1 parent 5915054 commit a6d6afb
Show file tree
Hide file tree
Showing 6 changed files with 79 additions and 32 deletions.
8 changes: 5 additions & 3 deletions odc/loader/_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def resolve_sources(
out: List[List[tuple[int, RasterSource]]] = []

for layer in self.srcs:
_srcs: List[RasterSource] = []
_srcs: List[tuple[int, RasterSource]] = []
for idx in layer:
src = srcs[idx].get(self.band, None)
if src is not None:
Expand Down Expand Up @@ -280,10 +280,12 @@ def _task_futures(
src_hash = tokenize(src)
rdr = rdr_cache.get(src_hash, None)
if rdr is None:
rdr = dask_reader.open(src, ctx, layer_name=layer_name, idx=i_src)
rdr = dask_reader.open(
src, cfg, ctx, layer_name=layer_name, idx=i_src
)
rdr_cache[src_hash] = rdr

fut = rdr.read(cfg, dst_gbox, selection=task.selection, idx=idx)
fut = rdr.read(dst_gbox, selection=task.selection, idx=idx)
keys_out.append(fut.key)
dsk.update(fut.dask)

Expand Down
15 changes: 12 additions & 3 deletions odc/loader/_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def __init__(
env: dict[str, Any] | None = None,
ctx: Any | None = None,
src: RasterSource | None = None,
cfg: RasterLoadParams | None = None,
layer_name: str = "",
idx: int = -1,
) -> None:
Expand All @@ -63,26 +64,28 @@ def __init__(
self._env = env
self._ctx = ctx
self._src = src
self._cfg = cfg
self._layer_name = layer_name
self._src_idx = idx

def read(
self,
cfg: RasterLoadParams,
dst_geobox: GeoBox,
*,
selection: Optional[ReaderSubsetSelection] = None,
idx: tuple[int, ...],
) -> Any:
assert self._src is not None
assert self._ctx is not None
assert self._cfg is not None

read_op = delayed(_dask_read_adaptor, name=self._layer_name)

# TODO: supply `dask_key_name=` that makes sense
return read_op(
self._src,
self._ctx,
cfg,
self._cfg,
dst_geobox,
self._driver,
self._env,
Expand All @@ -91,13 +94,19 @@ def read(
)

def open(
self, src: RasterSource, ctx: Any, layer_name: str, idx: int
self,
src: RasterSource,
cfg: RasterLoadParams,
ctx: Any,
layer_name: str,
idx: int,
) -> "ReaderDaskAdaptor":
return ReaderDaskAdaptor(
self._driver,
self._env,
ctx,
src,
cfg,
layer_name=layer_name,
idx=idx,
)
Expand Down
72 changes: 54 additions & 18 deletions odc/loader/_zarr.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@
from dask.base import tokenize
from dask.delayed import Delayed, delayed
from fsspec.core import url_to_fs
from odc.geo.geobox import GeoBox, GeoBoxBase
from odc.geo.gcp import GCPGeoBox
from odc.geo.geobox import GeoBox, GeoBoxBase, GeoboxTiles
from odc.geo.xr import ODCExtensionDa, ODCExtensionDs, xr_coords, xr_reproject

from .types import (
Expand Down Expand Up @@ -106,9 +107,17 @@ def __init__(
chunks: None | dict[str, int],
driver: Any | None = None,
) -> None:
gbt: GeoboxTiles | None = None
if chunks is not None:
cy, cx = (
chunks.get(name, fallback)
for name, fallback in zip(["y", "x"], geobox.shape.yx)
)
gbt = GeoboxTiles(geobox, (cy, cx))
self.geobox = geobox
self.chunks = chunks
self.driver = driver
self.gbt = gbt

def with_env(self, env: dict[str, Any]) -> "Context":
assert isinstance(env, dict)
Expand Down Expand Up @@ -183,6 +192,7 @@ def from_raster_source(

coords = {**xx.coords}
if geobox is not None:
assert isinstance(geobox, (GeoBox, GCPGeoBox))
coords.update(xr_coords(geobox, dims=xx.odc.spatial_dims or ("y", "x")))

xx = xr.DataArray(
Expand Down Expand Up @@ -246,38 +256,42 @@ class XrMemReaderDask:
def __init__(
self,
src: xr.DataArray | None = None,
cfg: RasterLoadParams | None = None,
layer_name: str = "",
) -> None:
self._layer_name = layer_name
self._xx = src
self._cfg = cfg

def read(
self,
cfg: RasterLoadParams,
dst_geobox: GeoBox,
*,
selection: ReaderSubsetSelection | None = None,
idx: tuple[int, ...] = (),
) -> Delayed:
assert self._xx is not None
assert self._cfg is not None
assert isinstance(idx, tuple)
xx = self._xx
assert isinstance(xx.odc, ODCExtensionDa)
assert isinstance(xx.odc.geobox, GeoBox)
assert xx.odc.spatial_dims is not None

xx = _select_extra_dims(self._xx, selection, cfg)
assert xx.odc.geobox is not None
yx_roi = xx.odc.geobox.overlap_roi(dst_geobox)
selection = _extra_dims_selector(selection, self._cfg)
selection.update(zip(xx.odc.spatial_dims, yx_roi))

yy = xr_reproject(
xx,
dst_geobox,
resampling=cfg.resampling,
dst_nodata=cfg.fill_value,
dtype=cfg.dtype,
chunks=dst_geobox.shape.yx,
)
return delayed(_with_roi)(yy.data, dask_key_name=(self._layer_name, *idx))
xx = self._xx.isel(selection)
out_key = (self._layer_name, *idx)
fut = delayed(_with_roi)(xx.data, dask_key_name=out_key)

return fut

def open(
self,
src: RasterSource,
cfg: RasterLoadParams,
ctx: Context,
*,
layer_name: str,
Expand All @@ -290,7 +304,20 @@ def open(

assert xx.odc.geobox is not None
assert not any(map(math.isnan, xx.odc.geobox.transform[:6]))
return XrMemReaderDask(xx, layer_name=layer_name)
assert ctx.gbt is not None
gbt = ctx.gbt
assert isinstance(gbt.base, GeoBox)

xx_warped = xr_reproject(
xx,
gbt.base,
resampling=cfg.resampling,
dst_nodata=cfg.fill_value,
dtype=cfg.dtype,
chunks=gbt.chunk_shape((0, 0)).yx,
)

return XrMemReaderDask(xx_warped, cfg, layer_name=layer_name)


class XrMemReaderDriver:
Expand Down Expand Up @@ -483,16 +510,25 @@ def _with_roi(xx: np.ndarray) -> tuple[tuple[slice, slice], np.ndarray]:
return (slice(None), slice(None)), xx


def _extra_dims_selector(
selection: ReaderSubsetSelection, cfg: RasterLoadParams
) -> dict[str, Any]:
if selection is None:
return {}

assert isinstance(selection, (slice, int)) or len(selection) == 1
assert len(cfg.extra_dims) == 1
(band_dim,) = cfg.extra_dims
return {band_dim: selection}


def _select_extra_dims(
src: xr.DataArray, selection: ReaderSubsetSelection, cfg: RasterLoadParams
) -> xr.DataArray:
if selection is None:
return src

assert isinstance(selection, (slice, int)) or len(selection) == 1
assert len(cfg.extra_dims) == 1
(band_dim,) = cfg.extra_dims
return src.isel({band_dim: selection})
return src.isel(_extra_dims_selector(selection, cfg))


def extract_zarr_spec(src: SomeDoc) -> ZarrSpecDict | None:
Expand Down
4 changes: 2 additions & 2 deletions odc/loader/test_memreader.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,12 +222,12 @@ def test_memreader_zarr(sample_ds: xr.Dataset):
ctx = driver.new_load(gbox, chunks={})
assert isinstance(ctx, Context)

rdr = driver.dask_reader.open(src, ctx, layer_name=f"xx-{tk}", idx=0)
rdr = driver.dask_reader.open(src, cfg, ctx, layer_name=f"xx-{tk}", idx=0)
assert isinstance(rdr, XrMemReaderDask)
assert rdr._xx is not None
assert is_dask_collection(rdr._xx)

fut = rdr.read(cfg, gbox)
fut = rdr.read(gbox)
assert is_dask_collection(fut)

roi, xx = fut.compute(scheduler="synchronous")
Expand Down
10 changes: 5 additions & 5 deletions odc/loader/test_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,16 +389,16 @@ def test_dask_reader_adaptor(dtype: str):
ctx = base_driver.new_load(gbox, chunks={"x": 64, "y": 64})

src = RasterSource("mem://", meta=meta)
rdr = driver.open(src, ctx, layer_name="aa", idx=0)
cfg = RasterLoadParams.same_as(src)
rdr = driver.open(src, cfg, ctx, layer_name="aa", idx=0)

assert isinstance(rdr, ReaderDaskAdaptor)

cfg = RasterLoadParams.same_as(src)
xx = rdr.read(cfg, gbox, idx=(0,))
xx = rdr.read(gbox, idx=(0,))
assert is_dask_collection(xx)
assert xx.key == ("aa", 0)
assert rdr.read(cfg, gbox, idx=(1,)).key == ("aa", 1)
assert rdr.read(cfg, gbox, idx=(1, 2, 3)).key == ("aa", 1, 2, 3)
assert rdr.read(gbox, idx=(1,)).key == ("aa", 1)
assert rdr.read(gbox, idx=(1, 2, 3)).key == ("aa", 1, 2, 3)

yy = xx.compute(scheduler="synchronous")
assert isinstance(yy, tuple)
Expand Down
2 changes: 1 addition & 1 deletion odc/loader/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,7 +471,6 @@ class DaskRasterReader(Protocol):

def read(
self,
cfg: RasterLoadParams,
dst_geobox: GeoBox,
*,
selection: Optional[ReaderSubsetSelection] = None,
Expand All @@ -481,6 +480,7 @@ def read(
def open(
self,
src: RasterSource,
cfg: RasterLoadParams,
ctx: Any,
*,
layer_name: str,
Expand Down

0 comments on commit a6d6afb

Please sign in to comment.