Skip to content

Commit

Permalink
FEAT-#6934: Support 'include_groups=False' parameter in 'groupby.appl…
Browse files Browse the repository at this point in the history
…y()' (#6938)

Signed-off-by: Dmitry Chigarev <dmitry.chigarev@intel.com>
  • Loading branch information
dchigarev committed Feb 19, 2024
1 parent b875991 commit 4704751
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 17 deletions.
7 changes: 1 addition & 6 deletions modin/core/dataframe/pandas/dataframe/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -3967,12 +3967,7 @@ def apply_func(df): # pragma: no cover
df = df.squeeze(axis=1)
result = operator(df.groupby(by, **kwargs))

if (
align_result_columns
and df.empty
and result.empty
and df.columns.equals(result.columns)
):
if align_result_columns and df.empty and result.empty:
# We want to align columns only of those frames that actually performed
# some groupby aggregation, if an empty frame was originally passed
# (an empty bin on reshuffling was created) then there were no groupby
Expand Down
12 changes: 1 addition & 11 deletions modin/pandas/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -646,16 +646,6 @@ def cummax(self, axis=lib.no_default, numeric_only=False, **kwargs):
)

def apply(self, func, *args, include_groups=True, **kwargs):
if not include_groups:
return self._default_to_pandas(
lambda df: df.apply(
func,
*args,
include_groups=include_groups,
**kwargs,
)
)

func = cast_function_modin2pandas(func)
if not isinstance(func, BuiltinFunctionType):
func = wrap_udf_function(func)
Expand All @@ -665,7 +655,7 @@ def apply(self, func, *args, include_groups=True, **kwargs):
numeric_only=False,
agg_func=func,
agg_args=args,
agg_kwargs=kwargs,
agg_kwargs={**kwargs, "include_groups": include_groups},
how="group_wise",
)
reduced_index = pandas.Index([MODIN_UNNAMED_SERIES_LABEL])
Expand Down
28 changes: 28 additions & 0 deletions modin/pandas/test/test_groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -3358,3 +3358,31 @@ def test_range_groupby_categories_external_grouper(columns, cat_cols):
pd_df, pd_by = get_external_groupers(pd_df, columns, drop_from_original_df=True)

eval_general(md_df.groupby(md_by), pd_df.groupby(pd_by), lambda grp: grp.count())


@pytest.mark.parametrize("by", [["a"], ["a", "b"]])
@pytest.mark.parametrize("as_index", [True, False])
@pytest.mark.parametrize("include_groups", [True, False])
def test_include_groups(by, as_index, include_groups):
data = {
"a": [1, 1, 2, 2] * 64,
"b": [11, 11, 22, 22] * 64,
"c": [111, 111, 222, 222] * 64,
"data": [1, 2, 3, 4] * 64,
}

def func(df):
if include_groups:
assert len(df.columns.intersection(by)) == len(by)
else:
assert len(df.columns.intersection(by)) == 0
return df.sum()

md_df, pd_df = create_test_dfs(data)
eval_general(
md_df,
pd_df,
lambda df: df.groupby(by, as_index=as_index).apply(
func, include_groups=include_groups
),
)

0 comments on commit 4704751

Please sign in to comment.