Skip to content
forked from pydata/xarray

Commit

Permalink
Support rechunking to a frequency.
Browse files Browse the repository at this point in the history
  • Loading branch information
dcherian committed Jun 22, 2024
1 parent 01fbf50 commit 9323c08
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 3 deletions.
43 changes: 40 additions & 3 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2666,11 +2666,13 @@ def chunk(
sizes along that dimension will not be updated; non-dask arrays will be
converted into dask arrays with a single block.
Along datetime-like dimensions, a :py:class:`groupers.TimeResampler` object is also accepted.
Parameters
----------
chunks : int, tuple of int, "auto" or mapping of hashable to int, optional
chunks : int, tuple of int, "auto" or mapping of hashable to int or a TimeResampler, optional
Chunk sizes along each dimension, e.g., ``5``, ``"auto"``, or
``{"x": 5, "y": 5}``.
``{"x": 5, "y": 5}`` or ``{"x": 5, "time": TimeResampler(freq="YE")}``.
name_prefix : str, default: "xarray-"
Prefix for the name of any new dask arrays.
token : str, optional
Expand Down Expand Up @@ -2705,6 +2707,8 @@ def chunk(
xarray.unify_chunks
dask.array.from_array
"""
from xarray.core.dataarray import DataArray

if chunks is None and not chunks_kwargs:
warnings.warn(
"None value for 'chunks' is deprecated. "
Expand All @@ -2730,6 +2734,39 @@ def chunk(
f"chunks keys {tuple(bad_dims)} not found in data dimensions {tuple(self.sizes.keys())}"
)

def _resolve_frequency(name: Hashable, freq: str) -> tuple[int, ...]:
    """Convert a frequency string for dimension ``name`` into explicit chunk sizes.

    Counts how many elements of the datetime-like coordinate ``name`` fall
    into each resampling bin of ``freq`` and returns those counts as a
    tuple of ints, usable directly as dask chunk sizes.

    Raises
    ------
    ValueError
        If ``name`` has no backing variable (a virtual/dimension-only
        variable), or if the variable is not datetime-like.
    """
    variable = self._variables.get(name, None)
    if variable is None:
        raise ValueError(
            f"Cannot chunk by frequency string {freq!r} for virtual variables."
        )
    elif not _contains_datetime_like_objects(variable):
        raise ValueError(
            f"chunks={freq!r} only supported for datetime variables. "
            f"Received variable {name!r} with dtype {variable.dtype!r} instead."
        )

    # Resample an all-ones array indexed by the datetime coordinate: each
    # per-bin sum is exactly the number of elements in that frequency bin,
    # i.e. the desired chunk size along this dimension.
    # (Renamed from `chunks` to avoid shadowing the enclosing parameter;
    # annotation fixed: `tuple[int]` would mean a 1-tuple.)
    resolved: tuple[int, ...] = tuple(
        DataArray(
            np.ones(variable.shape, dtype=int),
            dims=(name,),
            coords={name: variable},
        )
        # TODO: This could be generalized to `freq` being a `Resampler` object,
        # and using `groupby` instead of `resample`
        .resample({name: freq})
        .sum()
        .data.tolist()
    )
    return resolved

chunks_mapping_ints = {
name: (
_resolve_frequency(name, chunks) if isinstance(chunks, str) else chunks
)
for name, chunks in chunks_mapping.items()
}

chunkmanager = guess_chunkmanager(chunked_array_type)
if from_array_kwargs is None:
from_array_kwargs = {}
Expand All @@ -2738,7 +2775,7 @@ def chunk(
k: _maybe_chunk(
k,
v,
chunks_mapping,
chunks_mapping_ints,
token,
lock,
name_prefix,
Expand Down
2 changes: 2 additions & 0 deletions xarray/core/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,8 @@ def copy(
# FYI in some cases we don't allow `None`, which this doesn't take account of.
# FYI the `str` is for a size string, e.g. "16MB", supported by dask.
T_ChunkDim: TypeAlias = Union[str, int, Literal["auto"], None, tuple[int, ...]]
T_FreqStr: TypeAlias = str
T_ChunkDim: TypeAlias = Union[T_FreqStr, int, Literal["auto"], None, tuple[int, ...]]
# We allow the tuple form of this (though arguably we could transition to named dims only)
T_Chunks: TypeAlias = Union[T_ChunkDim, Mapping[Any, T_ChunkDim]]
T_NormalizedChunks = tuple[tuple[int, ...], ...]
Expand Down
39 changes: 39 additions & 0 deletions xarray/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1192,6 +1192,45 @@ def get_dask_names(ds):
):
data.chunk({"foo": 10})

@requires_dask
@pytest.mark.parametrize(
    "calendar",
    (
        "standard",
        pytest.param(
            "gregorian",
            marks=pytest.mark.skipif(not has_cftime, reason="needs cftime"),
        ),
    ),
)
@pytest.mark.parametrize("freq", ["D", "W", "5ME", "YE"])
def test_chunk_by_frequency(self, freq, calendar) -> None:
    """Chunk sizes produced from a frequency string match the resampled bin counts."""
    import dask.array

    N = 365 * 2
    ds = Dataset(
        {
            # Shape and chunks spelled as proper one-tuples: `(N)` is just
            # the int N, which reads as a tuple but is not one.
            "pr": ("time", dask.array.random.random((N,), chunks=(20,))),
            "ones": ("time", np.ones((N,))),
        },
        coords={
            "time": xr.date_range(
                "2001-01-01", periods=N, freq="D", calendar=calendar
            )
        },
    )
    actual = ds.chunk(time=freq).chunksizes["time"]
    # Summing an all-ones variable per resampling bin yields the number of
    # elements in each bin — the expected chunk sizes.
    expected = tuple(ds.ones.resample(time=freq).sum().data.tolist())
    assert expected == actual

def test_chunk_by_frequency_errors(self) -> None:
    """Frequency-based chunking rejects virtual and non-datetime variables."""
    # NOTE(review): renamed from `test_chunk_by_frequecy_errors` (typo);
    # pytest discovers tests by the `test_` prefix, so nothing references
    # the old name.
    ds = Dataset({"foo": ("x", [1, 2, 3])})
    # "x" is a dimension without a coordinate variable -> virtual variable.
    with pytest.raises(ValueError, match="virtual variable"):
        ds.chunk(x="YE")
    ds["x"] = ("x", [1, 2, 3])
    # Now "x" has a backing variable, but it is integer-typed, not datetime-like.
    with pytest.raises(ValueError, match="datetime variables"):
        ds.chunk(x="YE")

@requires_dask
def test_dask_is_lazy(self) -> None:
store = InaccessibleVariableDataStore()
Expand Down

0 comments on commit 9323c08

Please sign in to comment.