Skip to content

Commit

Permalink
_full
Browse files Browse the repository at this point in the history
  • Loading branch information
dcherian committed Jun 3, 2023
1 parent d148074 commit a37c559
Showing 1 changed file with 29 additions and 12 deletions.
41 changes: 29 additions & 12 deletions numpy_groupies/aggregate_numpy.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import numpy as np
from packaging.version import Version

from .utils import (
aggregate_common_doc,
Expand All @@ -13,13 +14,27 @@
check_fill_value,
input_validation,
iscomplexobj,
maxval,
minimum_dtype,
minimum_dtype_scalar,
minval,
maxval,
)


def _full(size, fill_value, *, dtype=None, like=None):
"""Backcompat for numpy < 1.20.0 which does not support the `like` kwarg"""
if (
like is not None # numpy bug?
and not np.isscalar(like) # scalars don't work
and Version(np.__version__) >= Version("1.20.0")
):
kwargs = {"like": like}
else:
kwargs = {}

return np.full(size, fill_value=fill_value, dtype=dtype, **kwargs)


def _sum(group_idx, a, size, fill_value, dtype=None):
dtype = minimum_dtype_scalar(fill_value, dtype, a)

Expand All @@ -44,7 +59,7 @@ def _sum(group_idx, a, size, fill_value, dtype=None):

def _prod(group_idx, a, size, fill_value, dtype=None):
dtype = minimum_dtype_scalar(fill_value, dtype, a)
ret = np.full(size, fill_value, dtype=dtype, like=a)
ret = _full(size, fill_value, dtype=dtype, like=a)
if fill_value != 1:
ret[group_idx] = 1 # product starts from 1
np.multiply.at(ret, group_idx, a)
Expand All @@ -57,7 +72,7 @@ def _len(group_idx, a, size, fill_value, dtype=None):

def _last(group_idx, a, size, fill_value, dtype=None):
dtype = minimum_dtype(fill_value, dtype or a.dtype)
ret = np.full(size, fill_value, dtype=dtype, like=a)
ret = _full(size, fill_value, dtype=dtype, like=a)
# repeated indexing gives last value, see:
# the phrase "leaving behind the last value" on this page:
# http://wiki.scipy.org/Tentative_NumPy_Tutorial
Expand All @@ -67,14 +82,14 @@ def _last(group_idx, a, size, fill_value, dtype=None):

def _first(group_idx, a, size, fill_value, dtype=None):
dtype = minimum_dtype(fill_value, dtype or a.dtype)
ret = np.full(size, fill_value, dtype=dtype, like=a)
ret = _full(size, fill_value, dtype=dtype, like=a)
ret[group_idx[::-1]] = a[::-1] # same trick as _last, but in reverse
return ret


def _all(group_idx, a, size, fill_value, dtype=None):
check_boolean(fill_value)
ret = np.full(size, fill_value, dtype=bool, like=a)
ret = _full(size, fill_value, dtype=bool, like=a)
if not fill_value:
ret[group_idx] = True
ret[group_idx.compress(np.logical_not(a))] = False
Expand All @@ -83,7 +98,7 @@ def _all(group_idx, a, size, fill_value, dtype=None):

def _any(group_idx, a, size, fill_value, dtype=None):
check_boolean(fill_value)
ret = np.full(size, fill_value, dtype=bool, like=a)
ret = _full(size, fill_value, dtype=bool, like=a)
if fill_value:
ret[group_idx] = False
ret[group_idx.compress(a)] = True
Expand All @@ -93,7 +108,7 @@ def _any(group_idx, a, size, fill_value, dtype=None):
def _min(group_idx, a, size, fill_value, dtype=None):
dtype = minimum_dtype(fill_value, dtype or a.dtype)
dmax = maxval(fill_value, dtype)
ret = np.full(size, fill_value, dtype=dtype, like=a)
ret = _full(size, fill_value, dtype=dtype, like=a)
if fill_value != dmax:
ret[group_idx] = dmax # min starts from maximum
np.minimum.at(ret, group_idx, a)
Expand All @@ -103,7 +118,7 @@ def _min(group_idx, a, size, fill_value, dtype=None):
def _max(group_idx, a, size, fill_value, dtype=None):
dtype = minimum_dtype(fill_value, dtype or a.dtype)
dmin = minval(fill_value, dtype)
ret = np.full(size, fill_value, dtype=dtype, like=a)
ret = _full(size, fill_value, dtype=dtype, like=a)
if fill_value != dmin:
ret[group_idx] = dmin # max starts from minimum
np.maximum.at(ret, group_idx, a)
Expand All @@ -115,7 +130,7 @@ def _argmax(group_idx, a, size, fill_value, dtype=int, _nansqueeze=False):
group_max = _max(group_idx, a_, size, np.nan)
# nan should never be maximum, so use a and not a_
is_max = a == group_max[group_idx]
ret = np.full(size, fill_value, dtype=dtype, like=a)
ret = _full(size, fill_value, dtype=dtype, like=a)
group_idx_max = group_idx[is_max]
(argmax,) = is_max.nonzero()
ret[group_idx_max[::-1]] = argmax[
Expand All @@ -129,7 +144,7 @@ def _argmin(group_idx, a, size, fill_value, dtype=int, _nansqueeze=False):
group_min = _min(group_idx, a_, size, np.nan)
# nan should never be minimum, so use a and not a_
is_min = a == group_min[group_idx]
ret = np.full(size, fill_value, dtype=dtype, like=a)
ret = _full(size, fill_value, dtype=dtype, like=a)
group_idx_min = group_idx[is_min]
(argmin,) = is_min.nonzero()
ret[group_idx_min[::-1]] = argmin[
Expand All @@ -148,7 +163,9 @@ def _mean(group_idx, a, size, fill_value, dtype=np.dtype(np.float64)):
sums.real = np.bincount(group_idx, weights=a.real, minlength=size)
sums.imag = np.bincount(group_idx, weights=a.imag, minlength=size)
else:
sums = np.bincount(group_idx, weights=a, minlength=size).astype(dtype, copy=False)
sums = np.bincount(group_idx, weights=a, minlength=size).astype(
dtype, copy=False
)

with np.errstate(divide="ignore", invalid="ignore"):
ret = sums.astype(dtype, copy=False) / counts
Expand Down Expand Up @@ -223,7 +240,7 @@ def _generic_callable(
"""groups a by inds, and then applies foo to each group in turn, placing
the results in an array."""
groups = _array(group_idx, a, size, ())
ret = np.full(size, fill_value, dtype=dtype or np.float64)
ret = _full(size, fill_value, dtype=dtype or np.float64)

for i, grp in enumerate(groups):
if np.ndim(grp) == 1 and len(grp) > 0:
Expand Down

0 comments on commit a37c559

Please sign in to comment.