Commit b7805a8 (forked from pydata/xarray)

Deprecate squeeze in GroupBy.
dcherian committed Dec 2, 2023
1 parent: 5213f0d · commit: b7805a8
Showing 4 changed files with 116 additions and 59 deletions.
doc/whats-new.rst (2 additions, 0 deletions)

@@ -43,6 +43,8 @@ Breaking changes

Deprecations
~~~~~~~~~~~~
- The `squeeze` kwarg to GroupBy is now deprecated. (:issue:`2157`)
By `Deepak Cherian <https://github.com/dcherian>`_.

- As part of an effort to standardize the API, we're renaming the ``dims``
  keyword arg to ``dim`` for the minority of functions which currently use
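
For context, a minimal usage sketch of what the deprecated kwarg means for callers. This is not part of the commit: the DataArray and its coordinate values are invented, and the behaviour described follows the warning text and the _maybe_squeeze_indices helper added in xarray/core/groupby.py further down.

    import numpy as np
    import xarray as xr

    da = xr.DataArray(np.arange(3), dims="x", coords={"x": [10, 20, 30]})

    # Relying on the default keeps the old squeezing behaviour for now, but it
    # emits a user-level warning whenever the groups could be squeezed.
    for _, scalar_group in da.groupby("x"):
        pass

    # Passing squeeze=False silences the warning and opts in to the future
    # behaviour: each group keeps "x" as a length-1 dimension.
    # (An explicit squeeze=True also avoids the warning while keeping the old
    # squeezing behaviour.)
    for _, group in da.groupby("x", squeeze=False):
        assert group.sizes["x"] == 1
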
xarray/core/dataarray.py (1 addition, 1 deletion)

@@ -6620,7 +6620,7 @@ def interp_calendar(
def groupby(
self,
group: Hashable | DataArray | IndexVariable,
squeeze: bool = True,
squeeze: bool | None = None,
restore_coord_dims: bool = False,
) -> DataArrayGroupBy:
"""Returns a DataArrayGroupBy object for performing grouped operations.
xarray/core/groupby.py (77 additions, 28 deletions)

@@ -37,6 +37,7 @@
from xarray.core.types import Dims, QuantileMethods, T_DataArray, T_Xarray
from xarray.core.utils import (
either_dict_or_kwargs,
emit_user_level_warning,
hashable,
is_scalar,
maybe_wrap_array,
@@ -73,6 +74,21 @@ def check_reduce_dims(reduce_dims, dimensions):
)


def _maybe_squeeze_indices(
indices, squeeze: bool | None, grouper: ResolvedGrouper, warn: bool
):
if squeeze in [None, True] and grouper.can_squeeze:
if squeeze is None and warn:
emit_user_level_warning(
"The `squeeze` kwarg to GroupBy is being removed."
"Pass .groupby(..., squeeze=False) to silence this warning."
)
if isinstance(indices, slice):
assert indices.stop - indices.start == 1
indices = indices.start
return indices


def unique_value_groups(
ar, sort: bool = True
) -> tuple[np.ndarray | pd.Index, T_GroupIndices, np.ndarray]:
@@ -366,10 +382,10 @@ def dims(self):
return self.group1d.dims

@abstractmethod
def _factorize(self, squeeze: bool) -> T_FactorizeOut:
def factorize(self) -> T_FactorizeOut:
raise NotImplementedError

def factorize(self, squeeze: bool) -> None:
def _factorize(self) -> None:
# This design makes it clear to mypy that
# codes, group_indices, unique_coord, and full_index
# are set by the factorize method on the derived class.
@@ -378,7 +394,7 @@ def factorize(self, squeeze: bool) -> None:
self.group_indices,
self.unique_coord,
self.full_index,
) = self._factorize(squeeze)
) = self.factorize()

@property
def is_unique_and_monotonic(self) -> bool:
@@ -393,15 +409,19 @@ def group_as_index(self) -> pd.Index:
self._group_as_index = self.group1d.to_index()
return self._group_as_index

@property
def can_squeeze(self):
is_dimension = self.group.dims == (self.group.name,)
return is_dimension and self.is_unique_and_monotonic


@dataclass
class ResolvedUniqueGrouper(ResolvedGrouper):
grouper: UniqueGrouper

def _factorize(self, squeeze) -> T_FactorizeOut:
is_dimension = self.group.dims == (self.group.name,)
if is_dimension and self.is_unique_and_monotonic:
return self._factorize_dummy(squeeze)
def factorize(self) -> T_FactorizeOut:
if self.can_squeeze:
return self._factorize_dummy()
else:
return self._factorize_unique()

@@ -424,15 +444,12 @@ def _factorize_unique(self) -> T_FactorizeOut:

return codes, group_indices, unique_coord, full_index

def _factorize_dummy(self, squeeze) -> T_FactorizeOut:
def _factorize_dummy(self) -> T_FactorizeOut:
size = self.group.size
# no need to factorize
if not squeeze:
# use slices to do views instead of fancy indexing
# equivalent to: group_indices = group_indices.reshape(-1, 1)
group_indices: T_GroupIndices = [slice(i, i + 1) for i in range(size)]
else:
group_indices = list(range(size))
# use slices to do views instead of fancy indexing
# equivalent to: group_indices = group_indices.reshape(-1, 1)
group_indices: T_GroupIndices = [slice(i, i + 1) for i in range(size)]
size_range = np.arange(size)
if isinstance(self.group, _DummyGroup):
codes = self.group.to_dataarray().copy(data=size_range)
@@ -448,7 +465,7 @@ def _factorize_dummy(self, squeeze) -> T_FactorizeOut:
class ResolvedBinGrouper(ResolvedGrouper):
grouper: BinGrouper

def _factorize(self, squeeze: bool) -> T_FactorizeOut:
def factorize(self) -> T_FactorizeOut:
from xarray.core.dataarray import DataArray

data = self.group1d.values
@@ -546,7 +563,7 @@ def first_items(self) -> tuple[pd.Series, np.ndarray]:
_apply_loffset(self.grouper.loffset, first_items)
return first_items, codes

def _factorize(self, squeeze: bool) -> T_FactorizeOut:
def factorize(self) -> T_FactorizeOut:
full_index, first_items, codes_ = self._get_index_and_items()
sbins = first_items.values.astype(np.int64)
group_indices: T_GroupIndices = [
@@ -591,14 +608,14 @@ class TimeResampleGrouper(Grouper):
loffset: datetime.timedelta | str | None


def _validate_groupby_squeeze(squeeze: bool) -> None:
def _validate_groupby_squeeze(squeeze: bool | None) -> None:
# While we don't generally check the type of every arg, passing
# multiple dimensions as multiple arguments is common enough, and the
# consequences hidden enough (strings evaluate as true) to warrant
# checking here.
# A future version could make squeeze kwarg only, but would face
# backward-compat issues.
if not isinstance(squeeze, bool):
if squeeze is not None and not isinstance(squeeze, bool):
raise TypeError(f"`squeeze` must be True or False, but {squeeze} was supplied")


@@ -730,7 +747,7 @@ def __init__(
self._original_obj = obj

for grouper_ in self.groupers:
grouper_.factorize(squeeze)
grouper_._factorize()

(grouper,) = self.groupers
self._original_group = grouper.group
@@ -762,9 +779,14 @@ def sizes(self) -> Mapping[Hashable, int]:
Dataset.sizes
"""
if self._sizes is None:
self._sizes = self._obj.isel(
{self._group_dim: self._group_indices[0]}
).sizes
(grouper,) = self.groupers
index = _maybe_squeeze_indices(
self._group_indices[0],
self._squeeze,
grouper,
warn=True,
)
self._sizes = self._obj.isel({self._group_dim: index}).sizes

return self._sizes

@@ -798,14 +820,22 @@ def groups(self) -> dict[GroupKey, GroupIndex]:
# provided to mimic pandas.groupby
if self._groups is None:
(grouper,) = self.groupers
self._groups = dict(zip(grouper.unique_coord.values, self._group_indices))
squeezed_indices = (
_maybe_squeeze_indices(ind, self._squeeze, grouper, warn=idx > 0)
for idx, ind in enumerate(self._group_indices)
)
self._groups = dict(zip(grouper.unique_coord.values, squeezed_indices))
return self._groups

def __getitem__(self, key: GroupKey) -> T_Xarray:
"""
Get DataArray or Dataset corresponding to a particular group label.
"""
return self._obj.isel({self._group_dim: self.groups[key]})
(grouper,) = self.groupers
index = _maybe_squeeze_indices(
self.groups[key], self._squeeze, grouper, warn=True
)
return self._obj.isel({self._group_dim: index})

def __len__(self) -> int:
(grouper,) = self.groupers
@@ -826,7 +856,11 @@ def __repr__(self) -> str:

def _iter_grouped(self) -> Iterator[T_Xarray]:
"""Iterate over each element in this group"""
for indices in self._group_indices:
(grouper,) = self.groupers
for idx, indices in enumerate(self._group_indices):
indices = _maybe_squeeze_indices(
indices, self._squeeze, grouper, warn=idx > 0
)
yield self._obj.isel({self._group_dim: indices})

def _infer_concat_args(self, applied_example):
@@ -1309,7 +1343,11 @@ class DataArrayGroupByBase(GroupBy["DataArray"], DataArrayGroupbyArithmetic):
@property
def dims(self) -> tuple[Hashable, ...]:
if self._dims is None:
self._dims = self._obj.isel({self._group_dim: self._group_indices[0]}).dims
(grouper,) = self.groupers
index = _maybe_squeeze_indices(
self._group_indices[0], self._squeeze, grouper, warn=True
)
self._dims = self._obj.isel({self._group_dim: index}).dims

return self._dims

@@ -1318,7 +1356,11 @@ def _iter_grouped_shortcut(self):
metadata
"""
var = self._obj.variable
for indices in self._group_indices:
(grouper,) = self.groupers
for idx, indices in enumerate(self._group_indices):
indices = _maybe_squeeze_indices(
indices, self._squeeze, grouper, warn=idx > 0
)
yield var[{self._group_dim: indices}]

def _concat_shortcut(self, applied, dim, positions=None):
@@ -1517,7 +1559,14 @@ class DatasetGroupByBase(GroupBy["Dataset"], DatasetGroupbyArithmetic):
@property
def dims(self) -> Frozen[Hashable, int]:
if self._dims is None:
self._dims = self._obj.isel({self._group_dim: self._group_indices[0]}).dims
(grouper,) = self.groupers
index = _maybe_squeeze_indices(
self._group_indices[0],
self._squeeze,
grouper,
warn=True,
)
self._dims = self._obj.isel({self._group_dim: index}).dims

return self._dims

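
To make the index handling above concrete, here is a hypothetical illustration, not part of the commit and with invented data, of what _maybe_squeeze_indices does to the group indices exposed through .groups: _factorize_dummy now always produces length-1 slices, the squeezed code path collapses each slice to its integer start, and the warning is emitted only when squeeze was left unset.

    import numpy as np
    import xarray as xr

    da = xr.DataArray(np.arange(3), dims="x", coords={"x": [1, 2, 3]})

    # squeeze=False: the length-1 slices are kept, so selecting a group
    # preserves the "x" dimension.
    print(da.groupby("x", squeeze=False).groups)
    # e.g. {1: slice(0, 1, None), 2: slice(1, 2, None), 3: slice(2, 3, None)}

    # squeeze=True: each slice collapses to its integer start, reproducing the
    # old behaviour where the grouped dimension is dropped on selection.
    print(da.groupby("x", squeeze=True).groups)
    # e.g. {1: 0, 2: 1, 3: 2}
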
[Diff for the fourth changed file did not load.]
