Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BACKPORT] Fix output of df.groupby(as_index=False).size() (#2507) #2508

Merged
merged 1 commit into from
Oct 12, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions mars/dataframe/groupby/aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -630,7 +630,7 @@ def _do_custom_agg(op, custom_reduction, *input_objs):
return concat_result

@staticmethod
def _do_predefined_agg(input_obj, agg_func, **kwds):
def _do_predefined_agg(input_obj, agg_func, single_func=False, **kwds):
ndim = getattr(input_obj, 'ndim', None) or input_obj.obj.ndim
if agg_func == 'str_concat':
agg_func = lambda x: x.str.cat(**kwds)
Expand All @@ -640,8 +640,14 @@ def _do_predefined_agg(input_obj, agg_func, **kwds):
agg_func.__name__ = func_name

if ndim == 2:
result = input_obj.agg([agg_func])
result.columns = result.columns.droplevel(-1)
if single_func:
result = input_obj.agg(agg_func)
if result.ndim == 1:
# when agg_func == size, agg only returns one single series.
result = result.to_frame(agg_func)
else:
result = input_obj.agg([agg_func])
result.columns = result.columns.droplevel(-1)
return result
else:
return input_obj.agg(agg_func)
Expand Down Expand Up @@ -704,7 +710,9 @@ def _wrapped_func(col):
if map_func_name == 'custom_reduction':
agg_dfs.extend(cls._do_custom_agg(op, custom_reduction, input_obj))
else:
agg_dfs.append(cls._do_predefined_agg(input_obj, map_func_name, **kwds))
single_func = map_func_name == op.raw_func
agg_dfs.append(cls._do_predefined_agg(
input_obj, map_func_name, single_func, **kwds))

if op._size_recorder_name is not None:
# record_size
Expand Down
5 changes: 5 additions & 0 deletions mars/dataframe/groupby/tests/test_groupby_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,11 @@ def test_dataframe_groupby_agg(setup):

# test as_index=False
for method in ['tree', 'shuffle']:
r = mdf.groupby('c2', as_index=False).agg('size', method=method)
pd.testing.assert_frame_equal(
r.execute().fetch().sort_values('c2', ignore_index=True),
raw.groupby('c2', as_index=False).agg('size').sort_values('c2', ignore_index=True))

r = mdf.groupby('c2', as_index=False).agg('mean', method=method)
pd.testing.assert_frame_equal(
r.execute().fetch().sort_values('c2', ignore_index=True),
Expand Down