
Commit

Merge branch 'main' into depr-groupby-squeeze-2
* main:
  Fix mypy type ignore (pydata#8564)
  Support for the new compression arguments. (pydata#7551)
  FIX: reverse index output of bottleneck move_argmax/move_argmin functions (pydata#8552)
  Adapt map_blocks to use new Coordinates API (pydata#8560)
  add xeofs to ecosystem.rst (pydata#8561)
  Offer a fixture for unifying DataArray & Dataset tests (pydata#8533)
  Generalize cumulative reduction (scan) to non-dask types (pydata#8019)
dcherian committed Dec 22, 2023
2 parents d6a3f2d + 03ec3cb commit a064430
Showing 15 changed files with 348 additions and 98 deletions.
1 change: 1 addition & 0 deletions doc/ecosystem.rst
@@ -78,6 +78,7 @@ Extend xarray capabilities
- `xarray-dataclasses <https://github.com/astropenguin/xarray-dataclasses>`_: xarray extension for typed DataArray and Dataset creation.
- `xarray_einstats <https://xarray-einstats.readthedocs.io>`_: Statistics, linear algebra and einops for xarray
- `xarray_extras <https://github.com/crusaderky/xarray_extras>`_: Advanced algorithms for xarray objects (e.g. integrations/interpolations).
- `xeofs <https://github.com/nicrie/xeofs>`_: PCA/EOF analysis and related techniques, integrated with xarray and Dask for efficient handling of large-scale data.
- `xpublish <https://xpublish.readthedocs.io/>`_: Publish Xarray Datasets via a Zarr compatible REST API.
- `xrft <https://github.com/rabernat/xrft>`_: Fourier transforms for xarray data.
- `xr-scipy <https://xr-scipy.readthedocs.io>`_: A lightweight scipy wrapper for xarray.
11 changes: 11 additions & 0 deletions doc/whats-new.rst
@@ -26,6 +26,10 @@ New Features

- :py:func:`xr.cov` and :py:func:`xr.corr` now support using weights (:issue:`8527`, :pull:`7392`).
By `Llorenç Lledó <https://github.com/lluritu>`_.
- Accept the compression arguments new in netCDF4-python 1.6.0 in the netCDF4 backend.
See the `netCDF4 documentation <https://unidata.github.io/netcdf4-python/#efficient-compression-of-netcdf-variables>`_ for details.
By `Markel García-Díez <https://github.com/markelg>`_. (:issue:`6929`, :pull:`7551`). Note that some of the
new compression filters need plugins to be installed, which may not be available in all netCDF distributions.
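A minimal sketch of the user-facing syntax (assuming a netCDF4-python >= 1.6.0 build; file and variable names are illustrative):

import numpy as np
import xarray as xr

ds = xr.Dataset({"t": ("x", np.random.rand(100).astype("float32"))})
# "compression" supersedes the older zlib=True flag; filters other than
# zlib (e.g. "zstd", "blosc_lz") may need HDF5 plugins to be installed.
ds.to_netcdf(
    "compressed.nc",
    engine="netcdf4",
    encoding={"t": {"compression": "zlib", "complevel": 4}},
)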

Breaking changes
~~~~~~~~~~~~~~~~
@@ -39,6 +43,9 @@ Deprecations
Bug fixes
~~~~~~~~~

- Reverse index output of bottleneck's rolling move_argmax/move_argmin functions (:issue:`8541`, :pull:`8552`).
By `Kai Mühlbauer <https://github.com/kmuehlbauer>`_.
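A short sketch of the corrected behaviour (assuming bottleneck is installed; values illustrative):

import xarray as xr

da = xr.DataArray([1.0, 5.0, 2.0, 3.0], dims="time")
# For the window [5, 2, 3] the maximum (5) sits at position 0 of the
# window; before the fix the offset was counted from the window's right
# edge, which would have reported 2 here.
print(da.rolling(time=3).argmax().values)  # [nan nan  1.  0.]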


Documentation
~~~~~~~~~~~~~
@@ -589,6 +596,10 @@ Internal Changes

- :py:func:`as_variable` now consistently includes the variable name in any exceptions
raised. (:pull:`7995`). By `Peter Hill <https://github.com/ZedThree>`_
- Redirect cumulative reduction functions internally through the :py:class:`ChunkManagerEntrypoint`,
potentially allowing :py:meth:`~xarray.DataArray.ffill` and :py:meth:`~xarray.DataArray.bfill` to
use non-dask chunked array types (see the sketch after this list).
(:pull:`8019`) By `Tom Nicholas <https://github.com/TomNicholas>`_.
- :py:func:`encode_dataset_coordinates` now sorts coordinates automatically assigned to
`coordinates` attributes during serialization (:issue:`8026`, :pull:`8034`).
By `Ian Carroll <https://github.com/itcarroll>`_.
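Apropos of the cumulative-reduction entry above, a sketch of the behaviour it generalizes (assuming dask and bottleneck are installed):

import numpy as np
import xarray as xr

da = xr.DataArray([0.0, np.nan, np.nan, 3.0, np.nan], dims="x").chunk(2)
# ffill is a cumulative reduction (scan); it now reaches chunked arrays
# through the chunk manager's scan method rather than calling dask directly.
print(da.ffill("x").compute().values)  # [0. 0. 0. 3. 3.]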
25 changes: 17 additions & 8 deletions xarray/backends/netCDF4_.py
@@ -257,6 +257,12 @@ def _extract_nc4_variable_encoding(
"_FillValue",
"dtype",
"compression",
"significant_digits",
"quantize_mode",
"blosc_shuffle",
"szip_coding",
"szip_pixels_per_block",
"endian",
}
if lsd_okay:
valid_encodings.add("least_significant_digit")
@@ -497,20 +503,23 @@ def prepare_variable(
if name in self.ds.variables:
nc4_var = self.ds.variables[name]
else:
nc4_var = self.ds.createVariable(
default_args = dict(
varname=name,
datatype=datatype,
dimensions=variable.dims,
zlib=encoding.get("zlib", False),
complevel=encoding.get("complevel", 4),
shuffle=encoding.get("shuffle", True),
fletcher32=encoding.get("fletcher32", False),
contiguous=encoding.get("contiguous", False),
chunksizes=encoding.get("chunksizes"),
zlib=False,
complevel=4,
shuffle=True,
fletcher32=False,
contiguous=False,
chunksizes=None,
endian="native",
least_significant_digit=encoding.get("least_significant_digit"),
least_significant_digit=None,
fill_value=fill_value,
)
default_args.update(encoding)
default_args.pop("_FillValue", None)
nc4_var = self.ds.createVariable(**default_args)

nc4_var.setncatts(attrs)
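This refactor collects the ``createVariable`` defaults in a dict and then overlays the user encoding, so any supported keyword (including the new compression arguments) flows through without being enumerated. A reduced sketch of the pattern, with hypothetical values:

defaults = dict(zlib=False, complevel=4, shuffle=True)
encoding = {"compression": "zstd", "complevel": 9}  # hypothetical user encoding
defaults.update(encoding)  # user keys win; extra valid keys pass through
# defaults == {'zlib': False, 'complevel': 9, 'shuffle': True, 'compression': 'zstd'}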

2 changes: 1 addition & 1 deletion xarray/core/coordinates.py
@@ -213,7 +213,7 @@ class Coordinates(AbstractCoordinates):
:py:class:`~xarray.Coordinates` object is passed, its indexes
will be added to the newly created object.
indexes: dict-like, optional
Mapping of where keys are coordinate names and values are
Mapping where keys are coordinate names and values are
:py:class:`~xarray.indexes.Index` objects. If None (default),
pandas indexes will be created for each dimension coordinate.
Passing an empty dictionary will skip this default behavior.
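A minimal sketch of the two behaviours this docstring describes (assuming a recent xarray where ``Coordinates`` is public):

import numpy as np
import xarray as xr

indexed = xr.Coordinates({"x": np.arange(3)})                # pandas index built for "x"
unindexed = xr.Coordinates({"x": np.arange(3)}, indexes={})  # default index creation skipped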
22 changes: 22 additions & 0 deletions xarray/core/daskmanager.py
@@ -97,6 +97,28 @@ def reduction(
keepdims=keepdims,
)

def scan(
self,
func: Callable,
binop: Callable,
ident: float,
arr: T_ChunkedArray,
axis: int | None = None,
dtype: np.dtype | None = None,
**kwargs,
) -> DaskArray:
from dask.array.reductions import cumreduction

return cumreduction(
func,
binop,
ident,
arr,
axis=axis,
dtype=dtype,
**kwargs,
)
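For illustration (not part of the diff), ``cumreduction`` is the generic scan primitive dask builds its own ``cumsum``/``cumprod`` on, so the delegation above amounts to:

import dask.array as da
import numpy as np
from dask.array.reductions import cumreduction

x = da.from_array(np.arange(6), chunks=2)
# np.cumsum scans within each chunk; np.add folds the running total
# across chunk boundaries.
print(cumreduction(np.cumsum, np.add, 0, x, axis=0, dtype=x.dtype).compute())
# -> [ 0  1  3  6 10 15]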

def apply_gufunc(
self,
func: Callable,
4 changes: 2 additions & 2 deletions xarray/core/dataarray.py
@@ -80,11 +80,11 @@
try:
from dask.dataframe import DataFrame as DaskDataFrame
except ImportError:
DaskDataFrame = None # type: ignore
DaskDataFrame = None
try:
from dask.delayed import Delayed
except ImportError:
Delayed = None # type: ignore
Delayed = None # type: ignore[misc,assignment]
try:
from iris.cube import Cube as iris_Cube
except ImportError:
4 changes: 2 additions & 2 deletions xarray/core/dataset.py
@@ -167,11 +167,11 @@
try:
from dask.delayed import Delayed
except ImportError:
Delayed = None # type: ignore
Delayed = None # type: ignore[misc,assignment]
try:
from dask.dataframe import DataFrame as DaskDataFrame
except ImportError:
DaskDataFrame = None # type: ignore
DaskDataFrame = None


# list of attributes of pd.DatetimeIndex that are ndarrays of time info
89 changes: 57 additions & 32 deletions xarray/core/parallel.py
@@ -4,19 +4,29 @@
import itertools
import operator
from collections.abc import Hashable, Iterable, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Callable
from typing import TYPE_CHECKING, Any, Callable, Literal, TypedDict

import numpy as np

from xarray.core.alignment import align
from xarray.core.coordinates import Coordinates
from xarray.core.dataarray import DataArray
from xarray.core.dataset import Dataset
from xarray.core.indexes import Index
from xarray.core.merge import merge
from xarray.core.pycompat import is_dask_collection

if TYPE_CHECKING:
from xarray.core.types import T_Xarray


class ExpectedDict(TypedDict):
shapes: dict[Hashable, int]
coords: set[Hashable]
data_vars: set[Hashable]
indexes: dict[Hashable, Index]
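Replacing a plain ``dict`` with this TypedDict lets mypy check the value type behind each key. A hypothetical instance, matching the shape ``_wrapper`` receives (names illustrative):

expected: ExpectedDict = {
    "shapes": {"x": 4},            # expected output chunk sizes per dimension
    "coords": {"x", "lat"},        # coordinate names promised by the template
    "data_vars": {"temperature"},  # data variable names promised by the template
    "indexes": {},                 # sliced Index objects, empty in this sketch
}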


def unzip(iterable):
return zip(*iterable)

@@ -31,7 +41,9 @@ def assert_chunks_compatible(a: Dataset, b: Dataset):


def check_result_variables(
result: DataArray | Dataset, expected: Mapping[str, Any], kind: str
result: DataArray | Dataset,
expected: ExpectedDict,
kind: Literal["coords", "data_vars"],
):
if kind == "coords":
nice_str = "coordinate"
@@ -254,7 +266,7 @@ def _wrapper(
args: list,
kwargs: dict,
arg_is_array: Iterable[bool],
expected: dict,
expected: ExpectedDict,
):
"""
Wrapper function that receives datasets in args; converts to dataarrays when necessary;
@@ -345,33 +357,45 @@ def _wrapper(
for arg in aligned
)

merged_coordinates = merge([arg.coords for arg in aligned]).coords

_, npargs = unzip(
sorted(list(zip(xarray_indices, xarray_objs)) + others, key=lambda x: x[0])
)

# check that chunk sizes are compatible
input_chunks = dict(npargs[0].chunks)
input_indexes = dict(npargs[0]._indexes)
for arg in xarray_objs[1:]:
assert_chunks_compatible(npargs[0], arg)
input_chunks.update(arg.chunks)
input_indexes.update(arg._indexes)

coordinates: Coordinates
if template is None:
# infer template by providing zero-shaped arrays
template = infer_template(func, aligned[0], *args, **kwargs)
template_indexes = set(template._indexes)
preserved_indexes = template_indexes & set(input_indexes)
new_indexes = template_indexes - set(input_indexes)
indexes = {dim: input_indexes[dim] for dim in preserved_indexes}
indexes.update({k: template._indexes[k] for k in new_indexes})
template_coords = set(template.coords)
preserved_coord_vars = template_coords & set(merged_coordinates)
new_coord_vars = template_coords - set(merged_coordinates)

preserved_coords = merged_coordinates.to_dataset()[preserved_coord_vars]
# preserved_coords contains all coordinate variables that share a dimension
# with any index variable in preserved_indexes
# Drop any unneeded vars in a second pass, this is required for e.g.
# if the mapped function were to drop a non-dimension coordinate variable.
preserved_coords = preserved_coords.drop_vars(
tuple(k for k in preserved_coords.variables if k not in template_coords)
)

coordinates = merge(
(preserved_coords, template.coords.to_dataset()[new_coord_vars])
).coords
output_chunks: Mapping[Hashable, tuple[int, ...]] = {
dim: input_chunks[dim] for dim in template.dims if dim in input_chunks
}

else:
# template xarray object has been provided with proper sizes and chunk shapes
indexes = dict(template._indexes)
coordinates = template.coords
output_chunks = template.chunksizes
if not output_chunks:
raise ValueError(
@@ -473,6 +497,9 @@ def subset_dataset_to_block(

return (Dataset, (dict, data_vars), (dict, coords), dataset.attrs)

# variable names that depend on the computation. Currently, indexes
# cannot be modified in the mapped function, so we exclude those
computed_variables = set(template.variables) - set(coordinates.xindexes)
# iterate over all possible chunk combinations
for chunk_tuple in itertools.product(*ichunk.values()):
# mapping from dimension name to chunk index
@@ -485,29 +512,32 @@ def subset_dataset_to_block(
for isxr, arg in zip(is_xarray, npargs)
]

# expected["shapes", "coords", "data_vars", "indexes"] are used to
# raise nice error messages in _wrapper
expected = {}
# input chunk 0 along a dimension maps to output chunk 0 along the same dimension
# even if length of dimension is changed by the applied function
expected["shapes"] = {
k: output_chunks[k][v] for k, v in chunk_index.items() if k in output_chunks
}
expected["data_vars"] = set(template.data_vars.keys()) # type: ignore[assignment]
expected["coords"] = set(template.coords.keys()) # type: ignore[assignment]
expected["indexes"] = {
dim: indexes[dim][_get_chunk_slicer(dim, chunk_index, output_chunk_bounds)]
for dim in indexes
expected: ExpectedDict = {
# input chunk 0 along a dimension maps to output chunk 0 along the same dimension
# even if length of dimension is changed by the applied function
"shapes": {
k: output_chunks[k][v]
for k, v in chunk_index.items()
if k in output_chunks
},
"data_vars": set(template.data_vars.keys()),
"coords": set(template.coords.keys()),
"indexes": {
dim: coordinates.xindexes[dim][
_get_chunk_slicer(dim, chunk_index, output_chunk_bounds)
]
for dim in coordinates.xindexes
},
}

from_wrapper = (gname,) + chunk_tuple
graph[from_wrapper] = (_wrapper, func, blocked_args, kwargs, is_array, expected)

# mapping from variable name to dask graph key
var_key_map: dict[Hashable, str] = {}
for name, variable in template.variables.items():
if name in indexes:
continue
for name in computed_variables:
variable = template.variables[name]
gname_l = f"{name}-{gname}"
var_key_map[name] = gname_l

@@ -543,12 +573,7 @@ def subset_dataset_to_block(
},
)

# TODO: benbovy - flexible indexes: make it work with custom indexes
# this will need to pass both indexes and coords to the Dataset constructor
result = Dataset(
coords={k: idx.to_pandas_index() for k, idx in indexes.items()},
attrs=template.attrs,
)
result = Dataset(coords=coordinates, attrs=template.attrs)

for index in result._indexes:
result[index].attrs = template[index].attrs
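For orientation, the public entry point all of this serves; a small sketch of typical use with an explicit ``template`` (assuming dask is installed; names illustrative):

import numpy as np
import xarray as xr

ds = xr.Dataset({"a": ("x", np.arange(10.0))}, coords={"x": np.arange(10)}).chunk({"x": 5})

def double(block):
    return block * 2

# The template declares the result's coordinates and indexes up front,
# which is what the Coordinates-based bookkeeping above assembles.
result = xr.map_blocks(double, ds, template=ds)
print(result.compute()["a"].values)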
37 changes: 37 additions & 0 deletions xarray/core/parallelcompat.py
@@ -403,6 +403,43 @@ def reduction(
"""
raise NotImplementedError()

def scan(
self,
func: Callable,
binop: Callable,
ident: float,
arr: T_ChunkedArray,
axis: int | None = None,
dtype: np.dtype | None = None,
**kwargs,
) -> T_ChunkedArray:
"""
General version of a 1D scan, also known as a cumulative array reduction.

Used in ``ffill`` and ``bfill`` in xarray.

Parameters
----------
func: callable
Cumulative function like np.cumsum or np.cumprod
binop: callable
Associated binary operator like ``np.cumsum->add`` or ``np.cumprod->mul``
ident: Number
Associated identity like ``np.cumsum->0`` or ``np.cumprod->1``
arr: dask Array
axis: int, optional
dtype: dtype

Returns
-------
Chunked array

See also
--------
dask.array.cumreduction
"""
raise NotImplementedError()
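A rough sketch of what a subclass must supply, here as a hypothetical sequential prefix scan over a list of numpy blocks (illustrative only, not the dask implementation):

import numpy as np

def naive_scan(func, binop, ident, chunks, axis=0):
    # func scans within each block (e.g. np.cumsum); binop folds the
    # running total across block boundaries; ident seeds the carry.
    carry = np.asarray(ident)
    out = []
    for block in chunks:
        combined = binop(func(block, axis=axis), carry)
        out.append(combined)
        carry = np.take(combined, [-1], axis=axis)  # last cumulative value
    return np.concatenate(out, axis=axis)

print(naive_scan(np.cumsum, np.add, 0, [np.array([1, 2]), np.array([3, 4])]))
# -> [ 1  3  6 10]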

@abstractmethod
def apply_gufunc(
self,
5 changes: 5 additions & 0 deletions xarray/core/rolling.py
@@ -596,6 +596,11 @@ def _bottleneck_reduce(self, func, keep_attrs, **kwargs):
values = func(
padded.data, window=self.window[0], min_count=min_count, axis=axis
)
# index 0 is at the rightmost edge of the window
# need to reverse index here
# see GH #8541
if func in [bottleneck.move_argmin, bottleneck.move_argmax]:
values = self.window[0] - 1 - values

if self.center[0]:
values = values[valid]
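The reversal can be checked against bottleneck directly; ``move_argmax`` counts offsets from the right-hand edge of each window, so the fix maps them back to window-relative positions (a standalone sketch, assuming bottleneck is installed):

import bottleneck
import numpy as np

a = np.array([1.0, 5.0, 2.0, 3.0])
window = 3
raw = bottleneck.move_argmax(a, window=window)  # 0 = rightmost element
print(raw)               # [nan nan  1.  2.]
print(window - 1 - raw)  # [nan nan  1.  0.]  -> 0 = leftmost element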