Skip to content

Commit

Permalink
refactor: in zarr driver
Browse files Browse the repository at this point in the history
- some names not longer fit
- include reference to current driver in the
  loader context
  • Loading branch information
Kirill888 committed Jun 24, 2024
1 parent f3e1a23 commit 975dada
Showing 1 changed file with 12 additions and 7 deletions.
19 changes: 12 additions & 7 deletions odc/loader/_zarr.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,14 +98,14 @@ class XrMDPlugin:
def __init__(
self,
template: RasterGroupMetadata,
src: xr.Dataset | None = None,
fallback: xr.Dataset | None = None,
) -> None:
self._template = template
self._src = src
self._fallback = fallback

def _resolve_src(self, md: Any, regen_coords: bool = False) -> xr.Dataset | None:
return _resolve_src_dataset(
md, regen_coords=regen_coords, fallback=self._src, chunks={}
md, regen_coords=regen_coords, fallback=self._fallback, chunks={}
)

def extract(self, md: Any) -> RasterGroupMetadata:
Expand Down Expand Up @@ -150,9 +150,11 @@ def __init__(
self,
geobox: GeoBox,
chunks: None | dict[str, int],
driver: Any | None = None,
) -> None:
self.geobox = geobox
self.chunks = chunks
self.driver = driver

def with_env(self, env: dict[str, Any]) -> "Context":
assert isinstance(env, dict)
Expand All @@ -161,7 +163,10 @@ def with_env(self, env: dict[str, Any]) -> "Context":

class XrMemReader:
"""
Protocol for raster readers.
Implements protocol for raster readers.
- Read from in-memory xarray.Dataset
- Read from zarr spec
"""

# pylint: disable=too-few-public-methods
Expand Down Expand Up @@ -242,7 +247,7 @@ def read(

class XrMemReaderDriver:
"""
Read from in memory xarray.Dataset.
Read from in memory xarray.Dataset or zarr spec document.
"""

Reader = XrMemReader
Expand All @@ -265,7 +270,7 @@ def new_load(
*,
chunks: None | dict[str, int] = None,
) -> Context:
return Context(geobox, chunks)
return Context(geobox, chunks, driver=self)

def finalise_load(self, load_state: Context) -> Context:
return load_state
Expand All @@ -284,7 +289,7 @@ def open(self, src: RasterSource, ctx: Context) -> XrMemReader:

@property
def md_parser(self) -> MDParser:
return XrMDPlugin(self.template, src=self.src)
return XrMDPlugin(self.template, fallback=self.src)

@property
def dask_reader(self) -> DaskRasterReader | None:
Expand Down

0 comments on commit 975dada

Please sign in to comment.