fix(dask): aggregation with multi-key groupby fails on dask backend
patcao authored and cpcloud committed Jul 8, 2022
1 parent 4725571 commit 4f8bc70
Showing 2 changed files with 70 additions and 4 deletions.
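
For context, a minimal, hedged sketch of the user-facing pattern this commit fixes. The in-memory table, connection, and column names below are assumptions for illustration (the commit's own test runs against the shared backend fixtures instead); the shape that mattered is a multi-key groupby combined with a destructured struct reduction on the dask backend:

import dask.dataframe as dd
import pandas as pd

import ibis
import ibis.expr.datatypes as dt
from ibis.udf.vectorized import reduction


@reduction(
    input_type=[dt.double],
    output_type=dt.Struct(['mean', 'std'], [dt.double, dt.double]),
)
def mean_and_std(v):
    return v.mean(), v.std()


pdf = pd.DataFrame(
    {
        'bigint_col': [1, 1, 2, 2],
        'int_col': [10, 10, 20, 30],
        'double_col': [1.0, 2.0, 3.0, 4.0],
    }
)
con = ibis.dask.connect({'t': dd.from_pandas(pdf, npartitions=1)})
t = con.table('t')

# Multi-key groupby + destructured reduction: the shape that previously
# failed on dask because the meta only described a single-level index.
expr = t.groupby(['bigint_col', 'int_col']).aggregate(
    mean_and_std(t['double_col']).destructure()
)
print(expr.execute())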
ibis/backends/dask/execution/util.py: 27 changes (23 additions & 4 deletions)
@@ -2,6 +2,7 @@

 import dask.dataframe as dd
 import dask.delayed
+import numpy as np
 import pandas as pd
 from dask.dataframe.groupby import SeriesGroupBy

@@ -37,10 +38,26 @@ def register_types_to_dispatcher(
     dispatcher.register(ibis_op, *types_to_register)(fn)


-def make_meta_series(dtype, name=None, index_name=None):
+def make_meta_series(
+    dtype: np.dtype,
+    name: Optional[str] = None,
+    meta_index: Optional[pd.Index] = None,
+):
+    if isinstance(meta_index, pd.MultiIndex):
+        index_names = meta_index.names
+        series_index = pd.MultiIndex(
+            levels=[[]] * len(index_names),
+            codes=[[]] * len(index_names),
+            names=index_names,
+        )
+    elif isinstance(meta_index, pd.Index):
+        series_index = pd.Index([], name=meta_index.name)
+    else:
+        series_index = pd.Index([])
+
     return pd.Series(
         [],
-        index=pd.Index([], name=index_name),
+        index=series_index,
         dtype=dtype,
         name=name,
     )
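
The key change above: when the incoming meta index is a MultiIndex, the placeholder series now reproduces every level instead of collapsing everything to one named pd.Index. A small standalone sketch of what the new branch builds (the key names here are illustrative, not from the commit):

import pandas as pd

# Empty MultiIndex matching a groupby on two keys.
names = ['bigint_col', 'int_col']
empty_multi = pd.MultiIndex(
    levels=[[]] * len(names),
    codes=[[]] * len(names),
    names=names,
)
meta = pd.Series([], index=empty_multi, dtype='float64', name='mean')

# The meta carries both key levels, so dask will not drop one of them
# when aligning real partitions against this placeholder.
print(meta.index.names)  # ['bigint_col', 'int_col']

Per the NOTE in the hunk below, the detailed meta matters for dask <= 2020.12.0, where an underspecified meta caused the key index to be dropped downstream.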
@@ -209,15 +226,17 @@ def _coerce_to_dataframe(
         # NOTE - We add a detailed meta here so we do not drop the key index
         # downstream. This seems to be fixed in versions of dask > 2020.12.0
         dtypes = map(ibis_dtype_to_pandas, types)

         series = [
             data.apply(
                 _select_item_in_iter,
                 selection=i,
-                meta=make_meta_series(dtype, index_name=data.index.name),
+                meta=make_meta_series(
+                    dtype, meta_index=data._meta_nonempty.index
+                ),
             )
             for i, dtype in enumerate(dtypes)
         ]

         result = dd.concat(series, axis=1)

     elif isinstance(data, (tuple, list)):
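
The call-site change feeds make_meta_series the full index of the non-empty meta rather than data.index.name, which can describe at most one level. A hedged illustration (the frame below is an assumption, and _meta_nonempty is dask-internal API, used here only to show the index structure):

import dask.dataframe as dd
import pandas as pd

pdf = pd.DataFrame({'a': [1, 1, 2], 'b': [3, 4, 4], 'x': [1.0, 2.0, 3.0]})
ddf = dd.from_pandas(pdf, npartitions=1)
grouped = ddf.groupby(['a', 'b'])['x'].mean()

# The meta index of a multi-key groupby result is a two-level
# MultiIndex; a single index name cannot capture both levels.
print(grouped._meta_nonempty.index.names)  # ['a', 'b']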
ibis/backends/tests/test_aggregation.py: 47 changes (47 additions & 0 deletions)

@@ -112,6 +112,53 @@ def test_aggregate_grouped(
     backend.assert_frame_equal(result2, expected)


+@mark.notimpl(
+    [
+        "clickhouse",
+        "datafusion",
+        "duckdb",
+        "impala",
+        "mysql",
+        "postgres",
+        "pyspark",
+        "sqlite",
+    ]
+)
+def test_aggregate_multikey_group_reduction(backend, alltypes, df):
+    """Test .aggregate() on a multi-key groupby with a reduction operation."""
+
+    @reduction(
+        input_type=[dt.double],
+        output_type=dt.Struct(['mean', 'std'], [dt.double, dt.double]),
+    )
+    def mean_and_std(v):
+        return v.mean(), v.std()
+
+    grouping_key_cols = ['bigint_col', 'int_col']
+
+    expr1 = alltypes.groupby(grouping_key_cols).aggregate(
+        mean_and_std(alltypes['double_col']).destructure()
+    )
+
+    result1 = expr1.execute()
+
+    # Note: `reset_index` turns the grouping keys back into columns
+    expected = (
+        df.groupby(grouping_key_cols)['double_col']
+        .agg(['mean', 'std'])
+        .reset_index()
+    )
+
+    # Row ordering may differ across backends, so sort on the grouping key
+    result1 = result1.sort_values(by=grouping_key_cols).reset_index(drop=True)
+    expected = expected.sort_values(by=grouping_key_cols).reset_index(
+        drop=True
+    )
+
+    backend.assert_frame_equal(result1, expected)

 @pytest.mark.parametrize(
     ('result_fn', 'expected_fn'),
     [
