Skip to content

Commit

Permalink
add bag.repartition
Browse files Browse the repository at this point in the history
  • Loading branch information
mrocklin committed Feb 14, 2016
1 parent 058d637 commit 4cee85e
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 0 deletions.
21 changes: 21 additions & 0 deletions dask/bag/core.py
Expand Up @@ -838,6 +838,27 @@ def to_imperative(self):
from dask.imperative import Value
return [Value(k, [self.dask]) for k in self._keys()]

def repartition(self, npartitions):
    """ Coalesce bag into fewer partitions

    Consecutive input partitions are concatenated into each output
    partition.  Partition boundaries are spread as evenly as integer
    arithmetic allows, so output partitions may differ in how many input
    partitions they absorb.

    Parameters
    ----------
    npartitions : int
        Desired number of partitions.  Must be at least 1 and no greater
        than the current ``self.npartitions``.

    Returns
    -------
    Bag
        A new bag with ``npartitions`` partitions whose graph contains this
        bag's graph plus one concatenation task per output partition.

    Raises
    ------
    ValueError
        If ``npartitions`` is less than 1.
    NotImplementedError
        If ``npartitions`` is greater than the current number of partitions.

    Examples
    --------
    >>> b.repartition(5)  # set to have 5 partitions  # doctest: +SKIP
    """
    if npartitions < 1:
        # Zero would divide by zero below; a negative count would silently
        # produce a bag with an empty graph.
        raise ValueError(
            "npartitions must be a positive integer, got %r"
            % (npartitions,))
    if npartitions > self.npartitions:
        raise NotImplementedError(
            "Repartition only supports going to fewer partitions\n"
            " old: %d  new: %d" % (self.npartitions, npartitions))

    # boundaries[i]:boundaries[i + 1] is the range of input partitions that
    # feed output partition i.  Floor division keeps the arithmetic exact
    # and behaves the same on Python 2 and Python 3.
    boundaries = [i * self.npartitions // npartitions
                  for i in range(npartitions + 1)]

    name = 'repartition-%d-%s' % (npartitions, self.name)
    dsk = dict(((name, i),
                (list, (toolz.concat,
                        [(self.name, j)
                         for j in range(boundaries[i], boundaries[i + 1])])))
               for i in range(npartitions))
    return Bag(merge(self.dask, dsk), name, npartitions)


# Register token normalizers for the bag types: an Item is represented by its
# ``key`` attribute and a Bag by its ``name`` attribute.
normalize_token.register(Item, lambda a: a.key)
normalize_token.register(Bag, lambda a: a.name)
Expand Down
15 changes: 15 additions & 0 deletions dask/bag/tests/test_bag.py
Expand Up @@ -829,3 +829,18 @@ def test_range():
assert len(b.dask) == npartitions
assert b.npartitions == npartitions
assert list(b) == list(range(100))


def test_repartition():
    """Repartitioning to fewer partitions preserves contents and order;
    asking for more partitions than exist raises NotImplementedError."""
    for x, y in [(10, 5), (7, 3), (5, 1), (5, 4)]:
        b = db.from_sequence(range(20), npartitions=x)
        c = b.repartition(y)

        # The source bag is left untouched and the result has the new count.
        assert b.npartitions == x
        assert c.npartitions == y
        assert list(b) == c.compute(get=dask.get)

    try:
        b.repartition(100)
    except NotImplementedError as e:
        assert '100' in str(e)
    else:
        # Without this branch the test would pass silently if no exception
        # were raised at all.
        raise AssertionError(
            "repartition to more partitions should raise NotImplementedError")

0 comments on commit 4cee85e

Please sign in to comment.