Skip to content

Commit

Permalink
apacheGH-36709: [Python] Allow to specify use_threads=False in Table.…
Browse files Browse the repository at this point in the history
…group_by to have stable ordering (apache#36768)

### Rationale for this change

Add a `use_threads` keyword to the `group_by` method on Table and pass it through to the `Declaration.to_table` call. This makes it possible to specify `use_threads=False` to get a stable ordering of the output, which is also required for certain aggregations (e.g. `"first"` fails with the default of `use_threads=True`).

### Are these changes tested?

Yes, added a test (similar to the one we have for `filter`) that would fail more than 50% of the time if the output were no longer ordered.

* Closes: apache#36709

Authored-by: Joris Van den Bossche <jorisvandenbossche@gmail.com>
Signed-off-by: Joris Van den Bossche <jorisvandenbossche@gmail.com>
  • Loading branch information
jorisvandenbossche authored and loicalleyne committed Nov 13, 2023
1 parent 4bf4830 commit 2cf2a82
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 7 deletions.
4 changes: 2 additions & 2 deletions python/pyarrow/acero.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,10 +299,10 @@ def _sort_source(table_or_dataset, sort_keys, output_type=Table, **kwargs):
raise TypeError("Unsupported output type")


def _group_by(table, aggregates, keys, use_threads=True):
    """Execute a grouped aggregation over *table* with a two-node Acero plan.

    Builds a ``table_source -> aggregate`` declaration sequence and runs it.
    Pass ``use_threads=False`` to get a stable ordering of the output rows.
    """
    source_node = Declaration("table_source", TableSourceNodeOptions(table))
    aggregate_node = Declaration(
        "aggregate", AggregateNodeOptions(aggregates, keys=keys)
    )
    plan = Declaration.from_sequence([source_node, aggregate_node])
    return plan.to_table(use_threads=use_threads)
20 changes: 15 additions & 5 deletions python/pyarrow/table.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -4599,8 +4599,9 @@ cdef class Table(_Tabular):
"""
return self.drop_columns(columns)

def group_by(self, keys):
"""Declare a grouping over the columns of the table.
def group_by(self, keys, use_threads=True):
"""
Declare a grouping over the columns of the table.
Resulting grouping can then be used to perform aggregations
with a subsequent ``aggregate()`` method.
Expand All @@ -4609,6 +4610,9 @@ cdef class Table(_Tabular):
----------
keys : str or list[str]
Name of the columns that should be used as the grouping key.
use_threads : bool, default True
Whether to use multithreading or not. When set to True (the
default), no stable ordering of the output is guaranteed.
Returns
-------
Expand All @@ -4635,7 +4639,7 @@ cdef class Table(_Tabular):
year: [[2020,2022,2021,2019]]
n_legs_sum: [[2,6,104,5]]
"""
return TableGroupBy(self, keys)
return TableGroupBy(self, keys, use_threads=use_threads)

def join(self, right_table, keys, right_keys=None, join_type="left outer",
left_suffix=None, right_suffix=None, coalesce_keys=True,
Expand Down Expand Up @@ -5183,6 +5187,9 @@ class TableGroupBy:
Input table to execute the aggregation on.
keys : str or list[str]
Name of the grouped columns.
use_threads : bool, default True
Whether to use multithreading or not. When set to True (the default),
no stable ordering of the output is guaranteed.
Examples
--------
Expand All @@ -5208,12 +5215,13 @@ class TableGroupBy:
values_sum: [[3,7,5]]
"""

def __init__(self, table, keys):
def __init__(self, table, keys, use_threads=True):
if isinstance(keys, str):
keys = [keys]

self._table = table
self.keys = keys
self._use_threads = use_threads

def aggregate(self, aggregations):
"""
Expand Down Expand Up @@ -5328,4 +5336,6 @@ list[tuple(str, str, FunctionOptions)]
aggr_name = "_".join(target) + "_" + func_nohash
group_by_aggrs.append((target, func, opt, aggr_name))

return _pac()._group_by(self._table, group_by_aggrs, self.keys)
return _pac()._group_by(
self._table, group_by_aggrs, self.keys, use_threads=self._use_threads
)
14 changes: 14 additions & 0 deletions python/pyarrow/tests/test_exec_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,3 +321,17 @@ def test_join_extension_array_column():
result = _perform_join(
"left outer", t1, ["colB"], t3, ["colC"])
assert result["colB"] == pa.chunked_array(ext_array)

def test_group_by_ordering():
    # GH-36709 - preserve ordering in groupby by setting use_threads=False
    part_a = pa.table({'a': [1, 2, 3, 4], 'b': ['a'] * 4})
    part_b = pa.table({'a': [1, 2, 3, 4], 'b': ['b'] * 4})
    combined = pa.concat_tables([part_a, part_b])

    # Repeat enough times that an ordering regression would show up as a
    # (flaky) failure rather than slipping through on a lucky single run;
    # 50 iterations consistently caught the problem when it existed.
    for _ in range(50):
        grouped = combined.group_by("b", use_threads=False).aggregate([])
        assert grouped["b"] == pa.chunked_array([["a"], ["b"]])
15 changes: 15 additions & 0 deletions python/pyarrow/tests/test_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -2175,6 +2175,21 @@ def sorted_by_keys(d):
}


@pytest.mark.acero
def test_table_group_by_first():
    # "first" is an ordered aggregation -> requires to specify use_threads=False
    half1 = pa.table({'a': [1, 2, 3, 4], 'b': ['a', 'b'] * 2})
    half2 = pa.table({'a': [1, 2, 3, 4], 'b': ['b', 'a'] * 2})
    combined = pa.concat_tables([half1, half2])

    # The default (threaded) path rejects ordered aggregations outright.
    with pytest.raises(NotImplementedError):
        combined.group_by("b").aggregate([("a", "first")])

    # Single-threaded execution keeps row order, so "first" is well-defined.
    result = combined.group_by("b", use_threads=False).aggregate(
        [("a", "first")]
    )
    expected = pa.table({"b": ["a", "b"], "a_first": [1, 2]})
    assert result.equals(expected)


def test_table_to_recordbatchreader():
table = pa.Table.from_pydict({'x': [1, 2, 3]})
reader = table.to_reader()
Expand Down

0 comments on commit 2cf2a82

Please sign in to comment.