Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FIX-#2405: Make sure named aggregation work for Series objects #6892

Merged
merged 5 commits into from
Jan 30, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions modin/pandas/base.py
Expand Up @@ -655,6 +655,7 @@ def aggregate(self, func=None, axis=0, *args, **kwargs): # noqa: PR01, RT01, D2
Aggregate using one or more operations over the specified axis.
"""
axis = self._get_axis_number(axis)

anmyachev marked this conversation as resolved.
Show resolved Hide resolved
result = None

if axis == 0:
Expand Down
39 changes: 39 additions & 0 deletions modin/pandas/groupby.py
Expand Up @@ -1924,8 +1924,45 @@
)
)

def _validate_func_kwargs(self, kwargs: dict):
"""
Validate types of user-provided "named aggregation" kwargs.

`TypeError` is raised if aggfunc is not `str` or callable.
anmyachev marked this conversation as resolved.
Show resolved Hide resolved

Parameters
----------
kwargs : dict

Returns
-------
columns : List[str]
List of user-provided keys.
func : List[Union[str, callable[...,Any]]]
List of user-provided aggfuncs.

Notes
-----
Copied from pandas
anmyachev marked this conversation as resolved.
Show resolved Hide resolved
"""
tuple_given_message = "func is expected but received {} in **kwargs."
anmyachev marked this conversation as resolved.
Show resolved Hide resolved
columns = list(kwargs)
func = []
anmyachev marked this conversation as resolved.
Show resolved Hide resolved
for col_func in kwargs.values():
if not (isinstance(col_func, str) or callable(col_func)):
raise TypeError(tuple_given_message.format(type(col_func).__name__))

Check warning on line 1953 in modin/pandas/groupby.py

View check run for this annotation

Codecov / codecov/patch

modin/pandas/groupby.py#L1953

Added line #L1953 was not covered by tests
anmyachev marked this conversation as resolved.
Show resolved Hide resolved
func.append(col_func)
if not columns:
no_arg_message = "Must provide 'func' or named aggregation **kwargs."
raise TypeError(no_arg_message)

Check warning on line 1957 in modin/pandas/groupby.py

View check run for this annotation

Codecov / codecov/patch

modin/pandas/groupby.py#L1956-L1957

Added lines #L1956 - L1957 were not covered by tests
anmyachev marked this conversation as resolved.
Show resolved Hide resolved
return columns, func

def aggregate(self, func=None, *args, engine=None, engine_kwargs=None, **kwargs):
engine_default = engine is None and engine_kwargs is None
# if func is None, will switch to user-provided "named aggregation" kwargs
if func_is_none := func is None:
columns, func = self._validate_func_kwargs(kwargs)
kwargs = {}
if isinstance(func, dict) and engine_default:
raise SpecificationError("nested renamer is not supported")
elif is_list_like(func) and engine_default:
Expand All @@ -1946,6 +1983,8 @@
# because there is no need to identify which original column's aggregation
# the new column represents. alternatively we could give the query compiler
# a hint that it's for a series, not a dataframe.
if func_is_none:
return result.set_axis(labels=columns, axis=1, copy=False)
return result.set_axis(
labels=self._try_get_str_func(func), axis=1, copy=False
)
Expand Down
8 changes: 8 additions & 0 deletions modin/pandas/test/test_groupby.py
Expand Up @@ -2989,6 +2989,14 @@ def test_groupby_apply_series_result(modify_config):
)


def test_groupby_named_aggregation():
ser = pd.Series([10, 10, 10, 1, 1, 1, 2, 3], name="data")

eval_general(
ser, ser._to_pandas(), lambda ser: ser.groupby(level=0).agg(result=("max"))
anmyachev marked this conversation as resolved.
Show resolved Hide resolved
)


### TEST GROUPBY WARNINGS ###


Expand Down