Add initialize_zarr
dcherian committed Nov 16, 2023
1 parent 22ca9ba commit 994af64
Showing 4 changed files with 148 additions and 4 deletions.
1 change: 1 addition & 0 deletions doc/api.rst
@@ -27,6 +27,7 @@ Top-level functions
combine_nested
where
infer_freq
initialize_zarr
full_like
zeros_like
ones_like
3 changes: 2 additions & 1 deletion xarray/__init__.py
@@ -9,7 +9,7 @@
open_mfdataset,
save_mfdataset,
)
from xarray.backends.zarr import open_zarr
from xarray.backends.zarr import initialize_zarr, open_zarr
from xarray.coding.cftime_offsets import cftime_range, date_range, date_range_like
from xarray.coding.cftimeindex import CFTimeIndex
from xarray.coding.frequencies import infer_freq
@@ -75,6 +75,7 @@
"full_like",
"get_options",
"infer_freq",
"initialize_zarr",
"load_dataarray",
"load_dataset",
"map_blocks",
93 changes: 90 additions & 3 deletions xarray/backends/zarr.py
@@ -3,8 +3,8 @@
import json
import os
import warnings
from collections.abc import Iterable
from typing import TYPE_CHECKING, Any
from collections.abc import Hashable, Iterable
from typing import TYPE_CHECKING, Any, Literal

import numpy as np

@@ -19,6 +19,7 @@
)
from xarray.backends.store import StoreBackendEntrypoint
from xarray.core import indexing
from xarray.core.common import zeros_like
from xarray.core.parallelcompat import guess_chunkmanager
from xarray.core.pycompat import integer_types
from xarray.core.utils import (
@@ -34,11 +34,97 @@
from xarray.backends.common import AbstractDataStore
from xarray.core.dataset import Dataset


# need some special secret attributes to tell us the dimensions
DIMENSION_KEY = "_ARRAY_DIMENSIONS"


def initialize_zarr(
store,
ds: Dataset,
*,
region_dims: Iterable[Hashable] | None = None,
mode: Literal["w", "w-"] = "w-",
**kwargs,
) -> Dataset:
"""
Initialize a Zarr store with metadata.
This function initializes a Zarr store with metadata that describes the entire datasets.
If ``region_dims`` is specified, it will also
1. Write variables that don't contain any of ``region_dims``, and
2. Return a dataset with variables that do contain one or more of ``region_dims``.
This dataset can be used for region writes in parallel.
Parameters
----------
store : MutableMapping or str
Zarr store to write to.
ds : Dataset
Dataset to write.
region_dims : Iterable[Hashable], optional
An iterable of dimension names that will be passed to the ``region``
kwarg of ``to_zarr`` later.
mode : {'w', 'w-'}
Write mode for initializing the store.
Returns
-------
Dataset
Dataset containing variables with one or more ``region_dims``
dimensions. Use this for writing to the store in parallel later.
Raises
------
ValueError
"""

if "compute" in kwargs:
raise ValueError("The ``compute`` kwarg is not supported in `initialize_zarr`.")

if not ds.chunks:
raise ValueError("This function should be used with chunked Datasets.")

if mode not in ["w", "w-"]:
raise ValueError(
f"Only mode='w' or mode='w-' is allowed for initialize_zarr. Received mode={mode!r}"
)

    # TODO: what should we do here?
    # compute=False only skips dask variables.
    # - We could replace all dask variables with zeros_like
    # - and then write all other variables eagerly.
    # Right now we do two writes for eager variables.
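    # With compute=False, this writes all array metadata plus the data of
    # eager (numpy-backed) variables; data for dask-backed variables is
    # deferred to later region writes.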
template = zeros_like(ds)
template.to_zarr(store, mode=mode, **kwargs, compute=False)

if region_dims:
after_drop = ds.drop_dims(region_dims)

# we have to remove the dropped variables from the encoding dictionary :/
new_encoding = kwargs.pop("encoding", None)
if new_encoding:
new_encoding = {k: v for k, v in new_encoding.items() if k in after_drop}

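        # mode="a": append into the store initialized above, eagerly writing
        # the variables that have none of region_dims.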
after_drop.to_zarr(
store, **kwargs, encoding=new_encoding, compute=True, mode="a"
)

# can't use drop_dims since that will also remove any variable
# with any of the dims to be dropped
# even if they also have one or more of region_dims
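        # e.g. with dims {"x", "y"} and region_dims=("x",), a variable with
        # dims ("y",) must be dropped, but one with dims ("x", "y") must be
        # kept; drop_dims("y") would drop both.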
dims_to_drop = set(ds.dims) - set(region_dims)
vars_to_drop = [
name
for name, var in ds._variables.items()
if set(var.dims).issubset(dims_to_drop)
]
return ds.drop_vars(vars_to_drop)

else:
return ds


def encode_zarr_attr_value(value):
"""
Encode an attribute value as something that can be serialized as JSON
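For orientation, here is a minimal sketch of the workflow this function enables, mirroring the test added below (the store path, the dataset ``ds``, and the one-chunk-per-region loop are illustrative assumptions, not part of the commit):

import xarray as xr

# Serial, one-time setup: writes metadata for the entire dataset, plus any
# variables that have none of the region dimensions.
to_write = xr.initialize_zarr("store.zarr", ds, region_dims=("x",))

# Later, potentially from many independent workers: each writes a disjoint
# region of the remaining variables.
for i in range(to_write.sizes["x"]):
    to_write.isel(x=[i]).to_zarr("store.zarr", region={"x": slice(i, i + 1)})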
55 changes: 55 additions & 0 deletions xarray/tests/test_backends.py
@@ -48,6 +48,7 @@
)
from xarray.backends.pydap_ import PydapDataStore
from xarray.backends.scipy_ import ScipyBackendEntrypoint
from xarray.backends.zarr import initialize_zarr
from xarray.coding.strings import check_vlen_dtype, create_vlen_dtype
from xarray.coding.variables import SerializationWarning
from xarray.conventions import encode_dataset_coordinates
@@ -5434,3 +5435,57 @@ def test_zarr_region_transpose(tmp_path):
ds_region.to_zarr(
tmp_path / "test.zarr", region={"x": slice(0, 1), "y": slice(0, 1)}
)


@requires_dask
@requires_zarr
def test_initialize_zarr(tmp_path) -> None:
    # TODO:
    # 1. with encoding
    # 2. with regions
    # 3. w-
    # 4. mode = r?
    # 5. mode = w+?
    # 6. no region_dims
x = np.arange(0, 50, 10)
y = np.arange(0, 20, 2)
data = dask.array.ones((5, 10), chunks=(1, -1))
ds = xr.Dataset(
{
"xy": (("x", "y"), data),
"xonly": ("x", data[:, 0]),
"yonly": ("y", data[0, :]),
"eager_xonly": ("x", data[:, 0].compute()),
"eager_yonly": ("y", data[0, :].compute().astype(int)),
"scalar": 2,
},
coords={"x": x, "y": y},
)
store = tmp_path / "foo.zarr"

with pytest.raises(ValueError, match="Only mode"):
initialize_zarr(store, ds, mode="r")

expected_on_disk = ds.copy(deep=True).assign(
{
# chunked variables are all NaNs (really fill_value?)
"xy": xr.full_like(ds.xy, fill_value=np.nan),
"xonly": xr.full_like(ds.xonly, fill_value=np.nan),
# eager variables with region_dim are all zeros (since we do zeros_like)
"eager_xonly": xr.full_like(ds.xonly, fill_value=0),
# eager variables without region_dim are identical
# but are subject to two writes, first zeros then actual values
"eager_yonly": ds.yonly,
}
)
expected_after_init = ds.drop_vars(["yonly", "eager_yonly", "y", "scalar"])
after_init = initialize_zarr(store, ds, region_dims=("x",))
assert_identical(expected_after_init, after_init)

with xr.open_zarr(store) as actual:
assert_identical(expected_on_disk, actual)

for i in range(ds.sizes["x"]):
after_init.isel(x=[i]).to_zarr(store, region={"x": slice(i, i + 1)})
with xr.open_zarr(store) as actual:
assert_identical(ds, actual)
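Since each region write above touches a disjoint slab of the store, the serial loop could in principle run concurrently. A hypothetical variant (not part of this commit, and assuming the underlying store tolerates concurrent writers):

from concurrent.futures import ThreadPoolExecutor

def write_region(i):
    # Regions are disjoint along x, so no two writes overlap.
    after_init.isel(x=[i]).to_zarr(store, region={"x": slice(i, i + 1)})

with ThreadPoolExecutor() as pool:
    list(pool.map(write_region, range(ds.sizes["x"])))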
