Skip to content

Commit

Permalink
Rename method= keyword to shuffle= in bag.groupby
Browse files (browse the repository at this point in the history)
  • Loading branch information
mrocklin committed May 4, 2018
1 parent 4756b59 commit bc1d137
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 23 deletions.
25 changes: 13 additions & 12 deletions dask/bag/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import io
import itertools
import math
import types
import uuid
import warnings
from collections import Iterable, Iterator, defaultdict
Expand Down Expand Up @@ -1190,7 +1189,7 @@ def __iter__(self):
return iter(self.compute())

def groupby(self, grouper, method=None, npartitions=None, blocksize=2**20,
max_branch=None):
max_branch=None, shuffle=None):
""" Group collection by key function
This requires a full dataset read, serialization and shuffle.
Expand All @@ -1200,7 +1199,7 @@ def groupby(self, grouper, method=None, npartitions=None, blocksize=2**20,
----------
grouper: function
Function on which to group elements
method: str
shuffle: str
Either 'disk' for an on-disk shuffle or 'tasks' to use the task
scheduling framework. Use 'disk' if you are on a single machine
and 'tasks' if you are on a distributed cluster.
Expand All @@ -1224,20 +1223,22 @@ def groupby(self, grouper, method=None, npartitions=None, blocksize=2**20,
--------
Bag.foldby
"""
if method is None:
get = _globals.get('get')
if (isinstance(get, types.MethodType) and
'distributed' in get.__func__.__module__):
method = 'tasks'
if method is not None:
raise Exception("The method= keyword has been moved to shuffle=")
if shuffle is None:
shuffle = _globals.get('shuffle')
if shuffle is None:
if 'distributed' in _globals.get('scheduler', ''):
shuffle = 'tasks'
else:
method = 'disk'
if method == 'disk':
shuffle = 'disk'
if shuffle == 'disk':
return groupby_disk(self, grouper, npartitions=npartitions,
blocksize=blocksize)
elif method == 'tasks':
elif shuffle == 'tasks':
return groupby_tasks(self, grouper, max_branch=max_branch)
else:
msg = "Shuffle method must be 'disk' or 'tasks'"
msg = "Shuffle must be 'disk' or 'tasks'"
raise NotImplementedError(msg)

def to_dataframe(self, meta=None, columns=None):
Expand Down
22 changes: 11 additions & 11 deletions dask/bag/tests/test_bag.py
Original file line number Diff line number Diff line change
Expand Up @@ -1113,7 +1113,7 @@ def test_accumulate():

def test_groupby_tasks():
b = db.from_sequence(range(160), npartitions=4)
out = b.groupby(lambda x: x % 10, max_branch=4, method='tasks')
out = b.groupby(lambda x: x % 10, max_branch=4, shuffle='tasks')
partitions = dask.get(out.dask, out.__dask_keys__())

for a in partitions:
Expand All @@ -1122,7 +1122,7 @@ def test_groupby_tasks():
assert not set(pluck(0, a)) & set(pluck(0, b))

b = db.from_sequence(range(1000), npartitions=100)
out = b.groupby(lambda x: x % 123, method='tasks')
out = b.groupby(lambda x: x % 123, shuffle='tasks')
assert len(out.dask) < 100**2
partitions = dask.get(out.dask, out.__dask_keys__())

Expand All @@ -1132,7 +1132,7 @@ def test_groupby_tasks():
assert not set(pluck(0, a)) & set(pluck(0, b))

b = db.from_sequence(range(10000), npartitions=345)
out = b.groupby(lambda x: x % 2834, max_branch=24, method='tasks')
out = b.groupby(lambda x: x % 2834, max_branch=24, shuffle='tasks')
partitions = dask.get(out.dask, out.__dask_keys__())

for a in partitions:
Expand All @@ -1145,26 +1145,26 @@ def test_groupby_tasks_names():
b = db.from_sequence(range(160), npartitions=4)
func = lambda x: x % 10
func2 = lambda x: x % 20
assert (set(b.groupby(func, max_branch=4, method='tasks').dask) ==
set(b.groupby(func, max_branch=4, method='tasks').dask))
assert (set(b.groupby(func, max_branch=4, method='tasks').dask) !=
set(b.groupby(func, max_branch=2, method='tasks').dask))
assert (set(b.groupby(func, max_branch=4, method='tasks').dask) !=
set(b.groupby(func2, max_branch=4, method='tasks').dask))
assert (set(b.groupby(func, max_branch=4, shuffle='tasks').dask) ==
set(b.groupby(func, max_branch=4, shuffle='tasks').dask))
assert (set(b.groupby(func, max_branch=4, shuffle='tasks').dask) !=
set(b.groupby(func, max_branch=2, shuffle='tasks').dask))
assert (set(b.groupby(func, max_branch=4, shuffle='tasks').dask) !=
set(b.groupby(func2, max_branch=4, shuffle='tasks').dask))


@pytest.mark.parametrize('size,npartitions,groups', [(1000, 20, 100),
(12345, 234, 1042)])
def test_groupby_tasks_2(size, npartitions, groups):
func = lambda x: x % groups
b = db.range(size, npartitions=npartitions).groupby(func, method='tasks')
b = db.range(size, npartitions=npartitions).groupby(func, shuffle='tasks')
result = b.compute(scheduler='sync')
assert dict(result) == groupby(func, range(size))


def test_groupby_tasks_3():
func = lambda x: x % 10
b = db.range(20, npartitions=5).groupby(func, method='tasks', max_branch=2)
b = db.range(20, npartitions=5).groupby(func, shuffle='tasks', max_branch=2)
result = b.compute(scheduler='sync')
assert dict(result) == groupby(func, range(20))
# assert b.npartitions == 5
Expand Down

0 comments on commit bc1d137

Please sign in to comment.