Skip to content

Commit

Permalink
More like arguments
Browse files Browse the repository at this point in the history
  • Loading branch information
dcherian committed May 12, 2023
1 parent 934a248 commit d148074
Showing 1 changed file with 10 additions and 10 deletions.
20 changes: 10 additions & 10 deletions numpy_groupies/aggregate_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def _sum(group_idx, a, size, fill_value, dtype=None):

def _prod(group_idx, a, size, fill_value, dtype=None):
dtype = minimum_dtype_scalar(fill_value, dtype, a)
ret = np.full(size, fill_value, dtype=dtype)
ret = np.full(size, fill_value, dtype=dtype, like=a)
if fill_value != 1:
ret[group_idx] = 1 # product starts from 1
np.multiply.at(ret, group_idx, a)
Expand All @@ -57,7 +57,7 @@ def _len(group_idx, a, size, fill_value, dtype=None):

def _last(group_idx, a, size, fill_value, dtype=None):
dtype = minimum_dtype(fill_value, dtype or a.dtype)
ret = np.full(size, fill_value, dtype=dtype)
ret = np.full(size, fill_value, dtype=dtype, like=a)
# repeated indexing gives last value, see:
# the phrase "leaving behind the last value" on this page:
# http://wiki.scipy.org/Tentative_NumPy_Tutorial
Expand All @@ -67,14 +67,14 @@ def _last(group_idx, a, size, fill_value, dtype=None):

def _first(group_idx, a, size, fill_value, dtype=None):
dtype = minimum_dtype(fill_value, dtype or a.dtype)
ret = np.full(size, fill_value, dtype=dtype)
ret = np.full(size, fill_value, dtype=dtype, like=a)
ret[group_idx[::-1]] = a[::-1] # same trick as _last, but in reverse
return ret


def _all(group_idx, a, size, fill_value, dtype=None):
check_boolean(fill_value)
ret = np.full(size, fill_value, dtype=bool)
ret = np.full(size, fill_value, dtype=bool, like=a)
if not fill_value:
ret[group_idx] = True
ret[group_idx.compress(np.logical_not(a))] = False
Expand All @@ -83,7 +83,7 @@ def _all(group_idx, a, size, fill_value, dtype=None):

def _any(group_idx, a, size, fill_value, dtype=None):
check_boolean(fill_value)
ret = np.full(size, fill_value, dtype=bool)
ret = np.full(size, fill_value, dtype=bool, like=a)
if fill_value:
ret[group_idx] = False
ret[group_idx.compress(a)] = True
Expand All @@ -93,7 +93,7 @@ def _any(group_idx, a, size, fill_value, dtype=None):
def _min(group_idx, a, size, fill_value, dtype=None):
dtype = minimum_dtype(fill_value, dtype or a.dtype)
dmax = maxval(fill_value, dtype)
ret = np.full(size, fill_value, dtype=dtype)
ret = np.full(size, fill_value, dtype=dtype, like=a)
if fill_value != dmax:
ret[group_idx] = dmax # min starts from maximum
np.minimum.at(ret, group_idx, a)
Expand All @@ -103,7 +103,7 @@ def _min(group_idx, a, size, fill_value, dtype=None):
def _max(group_idx, a, size, fill_value, dtype=None):
dtype = minimum_dtype(fill_value, dtype or a.dtype)
dmin = minval(fill_value, dtype)
ret = np.full(size, fill_value, dtype=dtype)
ret = np.full(size, fill_value, dtype=dtype, like=a)
if fill_value != dmin:
ret[group_idx] = dmin # max starts from minimum
np.maximum.at(ret, group_idx, a)
Expand All @@ -115,7 +115,7 @@ def _argmax(group_idx, a, size, fill_value, dtype=int, _nansqueeze=False):
group_max = _max(group_idx, a_, size, np.nan)
# nan should never be maximum, so use a and not a_
is_max = a == group_max[group_idx]
ret = np.full(size, fill_value, dtype=dtype)
ret = np.full(size, fill_value, dtype=dtype, like=a)
group_idx_max = group_idx[is_max]
(argmax,) = is_max.nonzero()
ret[group_idx_max[::-1]] = argmax[
Expand All @@ -129,7 +129,7 @@ def _argmin(group_idx, a, size, fill_value, dtype=int, _nansqueeze=False):
group_min = _min(group_idx, a_, size, np.nan)
# nan should never be minimum, so use a and not a_
is_min = a == group_min[group_idx]
ret = np.full(size, fill_value, dtype=dtype)
ret = np.full(size, fill_value, dtype=dtype, like=a)
group_idx_min = group_idx[is_min]
(argmin,) = is_min.nonzero()
ret[group_idx_min[::-1]] = argmin[
Expand All @@ -144,7 +144,7 @@ def _mean(group_idx, a, size, fill_value, dtype=np.dtype(np.float64)):
counts = np.bincount(group_idx, minlength=size)
if iscomplexobj(a):
dtype = a.dtype # TODO: this is a bit clumsy
sums = np.empty(size, dtype=dtype)
sums = np.empty(size, dtype=dtype, like=a)
sums.real = np.bincount(group_idx, weights=a.real, minlength=size)
sums.imag = np.bincount(group_idx, weights=a.imag, minlength=size)
else:
Expand Down

0 comments on commit d148074

Please sign in to comment.