Skip to content

Commit

Permalink
Backport PR pandas-dev#48027 on branch 1.5.x (ENH: Support masks in groupby prod) (pandas-dev#48302)
Browse files Browse the repository at this point in the history

Backport PR pandas-dev#48027: ENH: Support masks in groupby prod

Co-authored-by: Patrick Hoefler <61934744+phofl@users.noreply.github.com>
  • Loading branch information
meeseeksmachine and phofl committed Aug 30, 2022
1 parent 97cf8e2 commit 3ca5773
Show file tree
Hide file tree
Showing 5 changed files with 61 additions and 21 deletions.
2 changes: 1 addition & 1 deletion doc/source/whatsnew/v1.5.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1165,7 +1165,7 @@ Groupby/resample/rolling
- Bug when using ``engine="numba"`` would return the same jitted function when modifying ``engine_kwargs`` (:issue:`46086`)
- Bug in :meth:`.DataFrameGroupBy.transform` fails when ``axis=1`` and ``func`` is ``"first"`` or ``"last"`` (:issue:`45986`)
- Bug in :meth:`DataFrameGroupBy.cumsum` with ``skipna=False`` giving incorrect results (:issue:`46216`)
- Bug in :meth:`.GroupBy.sum` and :meth:`.GroupBy.cumsum` with integer dtypes losing precision (:issue:`37493`)
- Bug in :meth:`.GroupBy.sum`, :meth:`.GroupBy.prod` and :meth:`.GroupBy.cumsum` with integer dtypes losing precision (:issue:`37493`)
- Bug in :meth:`.GroupBy.cumsum` with ``timedelta64[ns]`` dtype failing to recognize ``NaT`` as a null value (:issue:`46216`)
- Bug in :meth:`.GroupBy.cumsum` with integer dtypes causing overflows when sum was bigger than maximum of dtype (:issue:`37493`)
- Bug in :meth:`.GroupBy.cummin` and :meth:`.GroupBy.cummax` with nullable dtypes incorrectly altering the original data in place (:issue:`46220`)
Expand Down
6 changes: 4 additions & 2 deletions pandas/_libs/groupby.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,12 @@ def group_sum(
is_datetimelike: bool = ...,
) -> None: ...
def group_prod(
out: np.ndarray, # floating[:, ::1]
out: np.ndarray, # int64float_t[:, ::1]
counts: np.ndarray, # int64_t[::1]
values: np.ndarray, # ndarray[floating, ndim=2]
values: np.ndarray, # ndarray[int64float_t, ndim=2]
labels: np.ndarray, # const intp_t[:]
mask: np.ndarray | None,
result_mask: np.ndarray | None = ...,
min_count: int = ...,
) -> None: ...
def group_var(
Expand Down
34 changes: 27 additions & 7 deletions pandas/_libs/groupby.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -682,21 +682,24 @@ def group_sum(
@cython.wraparound(False)
@cython.boundscheck(False)
def group_prod(
floating[:, ::1] out,
int64float_t[:, ::1] out,
int64_t[::1] counts,
ndarray[floating, ndim=2] values,
ndarray[int64float_t, ndim=2] values,
const intp_t[::1] labels,
const uint8_t[:, ::1] mask,
uint8_t[:, ::1] result_mask=None,
Py_ssize_t min_count=0,
) -> None:
"""
Only aggregates on axis=0
"""
cdef:
Py_ssize_t i, j, N, K, lab, ncounts = len(counts)
floating val, count
floating[:, ::1] prodx
int64float_t val, count
int64float_t[:, ::1] prodx
int64_t[:, ::1] nobs
Py_ssize_t len_values = len(values), len_labels = len(labels)
bint isna_entry, uses_mask = mask is not None

if len_values != len_labels:
raise ValueError("len(index) != len(labels)")
Expand All @@ -716,15 +719,32 @@ def group_prod(
for j in range(K):
val = values[i, j]

# not nan
if val == val:
if uses_mask:
isna_entry = mask[i, j]
elif int64float_t is float32_t or int64float_t is float64_t:
isna_entry = not val == val
else:
isna_entry = False

if not isna_entry:
nobs[lab, j] += 1
prodx[lab, j] *= val

for i in range(ncounts):
for j in range(K):
if nobs[i, j] < min_count:
out[i, j] = NAN

# else case is not possible
if uses_mask:
result_mask[i, j] = True
# Be deterministic, out was initialized as empty
out[i, j] = 0
elif int64float_t is float32_t or int64float_t is float64_t:
out[i, j] = NAN
else:
# we only get here when < mincount which gets handled later
pass

else:
out[i, j] = prodx[i, j]

Expand Down
21 changes: 15 additions & 6 deletions pandas/core/groupby/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@ def __init__(self, kind: str, how: str, has_dropped_na: bool) -> None:
"sum",
"ohlc",
"cumsum",
"prod",
}

_cython_arity = {"ohlc": 4} # OHLC
Expand Down Expand Up @@ -221,13 +222,13 @@ def _get_cython_vals(self, values: np.ndarray) -> np.ndarray:
values = ensure_float64(values)

elif values.dtype.kind in ["i", "u"]:
if how in ["var", "prod", "mean"] or (
if how in ["var", "mean"] or (
self.kind == "transform" and self.has_dropped_na
):
# result may still include NaN, so we have to cast
values = ensure_float64(values)

elif how in ["sum", "ohlc", "cumsum"]:
elif how in ["sum", "ohlc", "prod", "cumsum"]:
# Avoid overflow during group op
if values.dtype.kind == "i":
values = ensure_int64(values)
Expand Down Expand Up @@ -597,8 +598,16 @@ def _call_cython_op(
min_count=min_count,
is_datetimelike=is_datetimelike,
)
elif self.how == "ohlc":
func(result, counts, values, comp_ids, min_count, mask, result_mask)
elif self.how in ["ohlc", "prod"]:
func(
result,
counts,
values,
comp_ids,
min_count=min_count,
mask=mask,
result_mask=result_mask,
)
else:
func(result, counts, values, comp_ids, min_count, **kwargs)
else:
Expand Down Expand Up @@ -631,8 +640,8 @@ def _call_cython_op(
# need to have the result set to np.nan, which may require casting,
# see GH#40767
if is_integer_dtype(result.dtype) and not is_datetimelike:
# Neutral value for sum is 0, so don't fill empty groups with nan
cutoff = max(0 if self.how == "sum" else 1, min_count)
# if the op keeps the int dtypes, we have to use 0
cutoff = max(0 if self.how in ["sum", "prod"] else 1, min_count)
empty_groups = counts < cutoff
if empty_groups.any():
if result_mask is not None and self.uses_mask():
Expand Down
19 changes: 14 additions & 5 deletions pandas/tests/groupby/test_groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -2847,8 +2847,8 @@ def test_single_element_list_grouping():
values, _ = next(iter(df.groupby(["a"])))


@pytest.mark.parametrize("func", ["sum", "cumsum"])
def test_groupby_sum_avoid_casting_to_float(func):
@pytest.mark.parametrize("func", ["sum", "cumsum", "prod"])
def test_groupby_avoid_casting_to_float(func):
# GH#37493
val = 922337203685477580
df = DataFrame({"a": 1, "b": [val]})
Expand All @@ -2859,12 +2859,13 @@ def test_groupby_sum_avoid_casting_to_float(func):
tm.assert_frame_equal(result, expected)


def test_groupby_sum_support_mask(any_numeric_ea_dtype):
@pytest.mark.parametrize("func, val", [("sum", 3), ("prod", 2)])
def test_groupby_sum_support_mask(any_numeric_ea_dtype, func, val):
# GH#37493
df = DataFrame({"a": 1, "b": [1, 2, pd.NA]}, dtype=any_numeric_ea_dtype)
result = df.groupby("a").sum()
result = getattr(df.groupby("a"), func)()
expected = DataFrame(
{"b": [3]},
{"b": [val]},
index=Index([1], name="a", dtype=any_numeric_ea_dtype),
dtype=any_numeric_ea_dtype,
)
Expand All @@ -2887,6 +2888,14 @@ def test_groupby_overflow(val, dtype):
expected = DataFrame({"b": [val, val * 2]}, dtype=f"{dtype}64")
tm.assert_frame_equal(result, expected)

result = df.groupby("a").prod()
expected = DataFrame(
{"b": [val * val]},
index=Index([1], name="a", dtype=f"{dtype}64"),
dtype=f"{dtype}64",
)
tm.assert_frame_equal(result, expected)


@pytest.mark.parametrize("skipna, val", [(True, 3), (False, pd.NA)])
def test_groupby_cumsum_mask(any_numeric_ea_dtype, skipna, val):
Expand Down

0 comments on commit 3ca5773

Please sign in to comment.