Skip to content

Commit

Permalink
add dask.bag.concat
Browse files Browse the repository at this point in the history
  • Loading branch information
mrocklin committed May 12, 2015
1 parent 11fe6e8 commit d5c6f13
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 6 deletions.
2 changes: 1 addition & 1 deletion dask/bag/__init__.py
@@ -1,5 +1,5 @@
from __future__ import absolute_import, division, print_function

from .core import (Bag, Item, from_sequence, from_filenames, from_hdfs,
to_textfiles)
to_textfiles, concat)
from ..context import set_options
28 changes: 23 additions & 5 deletions dask/bag/core.py
Expand Up @@ -14,9 +14,10 @@
from functools import wraps


from toolz import (merge, concat, frequencies, merge_with, take, curry, reduce,
from toolz import (merge, frequencies, merge_with, take, curry, reduce,
join, reduceby, valmap, count, map, partition_all, filter,
pluck, groupby)
import toolz
try:
from cytoolz import (curry, frequencies, merge_with, join, reduceby,
count, pluck, groupby)
Expand Down Expand Up @@ -325,7 +326,7 @@ def topk(self, k, key=None):
topk = heapq.nlargest
dsk = dict(((a, i), (list, (topk, k, (self.name, i))))
for i in range(self.npartitions))
dsk2 = {(b, 0): (list, (topk, k, (concat, list(dsk.keys()))))}
dsk2 = {(b, 0): (list, (topk, k, (toolz.concat, list(dsk.keys()))))}
return Bag(merge(self.dask, dsk, dsk2), b, 1)

def distinct(self):
Expand Down Expand Up @@ -497,7 +498,7 @@ def foldby(self, key, binop, initial=no_default, combine=None,
dsk2 = {(b, 0): (dictitems,
(reduceby,
0, combine2,
(concat, (map, dictitems, list(dsk.keys()))),
(toolz.concat, (map, dictitems, list(dsk.keys()))),
combine_initial))}
else:
dsk2 = {(b, 0): (dictitems,
Expand Down Expand Up @@ -530,7 +531,7 @@ def _keys(self):
def compute(self, **kwargs):
results = self.get(self.dask, self._keys(), **kwargs)
if isinstance(results[0], Iterable):
results = concat(results)
results = toolz.concat(results)
if not isinstance(results, Iterator):
results = iter(results)
return results
Expand All @@ -546,7 +547,7 @@ def concat(self):
[1, 2, 3]
"""
name = next(names)
dsk = dict(((name, i), (list, (concat, (self.name, i))))
dsk = dict(((name, i), (list, (toolz.concat, (self.name, i))))
for i in range(self.npartitions))
return Bag(merge(self.dask, dsk), name, self.npartitions)

Expand Down Expand Up @@ -835,3 +836,20 @@ def takes_multiple_arguments(func):
if spec.defaults is None:
return len(spec.args) != 1
return len(spec.args) - len(spec.defaults) > 1


def concat(bags):
    """ Concatenate many bags together, unioning all elements

    Every partition of every input bag becomes one partition of the result.
    Partitions appear in the order the bags are given and, within a bag, in
    partition order.

    Parameters
    ----------
    bags : iterable of Bag
        The bags to join.  May be any iterable, including a generator.

    Returns
    -------
    Bag
        A new bag with ``sum(bag.npartitions for bag in bags)`` partitions.

    >>> a = db.from_sequence([1, 2, 3])
    >>> b = db.from_sequence([4, 5, 6])
    >>> c = db.concat([a, b])
    >>> list(c)
    [1, 2, 3, 4, 5, 6]
    """
    # ``bags`` is walked twice below (keys, then graphs); materialise it so a
    # one-shot iterator does not come up empty on the second pass.
    bags = list(bags)
    name = next(names)
    # Alias only each bag's own output partitions, ``(bag.name, i)``.  Walking
    # every key of ``bag.dask`` would also expose intermediate tasks of a
    # derived bag as if they were partitions of the result.
    partitions = [(bag.name, i)
                  for bag in bags
                  for i in range(bag.npartitions)]
    dsk = dict(((name, i), key) for i, key in enumerate(partitions))
    return Bag(merge(dsk, *[bag.dask for bag in bags]), name, len(dsk))
8 changes: 8 additions & 0 deletions dask/bag/tests/test_bag.py
Expand Up @@ -327,3 +327,11 @@ def test_bz2_stream():
text = '\n'.join(map(str, range(10000)))
compressed = bz2.compress(text.encode())
assert list(take(100, bz2_stream(compressed))) == list(map(str, range(100)))


def test_concat():
    """ db.concat yields the elements of its input bags, in order """
    first = db.from_sequence([1, 2, 3])
    second = db.from_sequence([4, 5, 6])

    combined = db.concat([first, second])
    expected = [1, 2, 3, 4, 5, 6]

    assert list(combined) == expected

0 comments on commit d5c6f13

Please sign in to comment.