Skip to content

Commit

Permalink
Support delayed and single-partition bags in Bag.join
Browse files Browse the repository at this point in the history
This can significantly improve performance when joining against larger
collections due to serialization overhead on the distributed scheduler.

There is still more work to do here for multi-partition joins.

Experiments also show that GC is having a profound effect on performance
here.
  • Loading branch information
mrocklin committed Mar 9, 2018
1 parent cc9f8db commit 1841902
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 13 deletions.
58 changes: 51 additions & 7 deletions dask/bag/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -918,21 +918,65 @@ def std(self, ddof=0):
def join(self, other, on_self, on_other=None):
""" Joins collection with another collection.
Other collection must be an Iterable, and not a Bag.
Other collection must be one of the following:
1. An iterable. We recommend tuples over lists for internal
performance reasons.
2. A delayed object, pointing to a tuple. This is recommended if the
other collection is sizable and you're using the distributed
scheduler. Dask is able to pass around data wrapped in delayed
objects with greater sophistication.
3. A Bag with a single partition
You might also consider Dask Dataframe, whose join operations are much
more heavily optimized.
Parameters
----------
other: Iterable, Delayed, Bag
Other collection on which to join
on_self: callable
Function to call on elements in this collection to determine a
match
on_other: callable (defaults to on_self)
Function to call on elements in the other collection to determine a
match
Examples
--------
>>> people = from_sequence(['Alice', 'Bob', 'Charlie'])
>>> fruit = ['Apple', 'Apricot', 'Banana']
>>> list(people.join(fruit, lambda x: x[0])) # doctest: +SKIP
[('Apple', 'Alice'), ('Apricot', 'Alice'), ('Banana', 'Bob')]
"""
assert isinstance(other, Iterable)
assert not isinstance(other, Bag)
name = 'join-' + tokenize(self, other, on_self, on_other)
dsk = {}
if isinstance(other, Bag):
if other.npartitions == 1:
dsk.update(other.dask)
dsk['join-%s-other' % name] = (list, other._keys()[0])
other = other._keys()[0]
else:
msg = ("Multi-bag joins are not implemented. "
"We recommend Dask dataframe if appropriate")
raise NotImplementedError(msg)
elif isinstance(other, Delayed):
dsk.update(other.dask)
other = other._key
elif isinstance(other, Iterable):
other = other
else:
msg = ("Joined argument must be single-partition Bag, "
" delayed object, or Iterable, got %s" %
type(other).__name)
raise TypeError(msg)

if on_other is None:
on_other = on_self
name = 'join-' + tokenize(self, other, on_self, on_other)
dsk = dict(((name, i), (list, (join, on_other, other,
on_self, (self.name, i))))
for i in range(self.npartitions))

dsk.update({(name, i): (list, (join, on_other, other,
on_self, (self.name, i)))
for i in range(self.npartitions)})
return type(self)(merge(self.dask, dsk), name, self.npartitions)

def product(self, other):
Expand Down
27 changes: 21 additions & 6 deletions dask/bag/tests/test_bag.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,15 @@
b = Bag(dsk, 'x', 3)


def assert_eq(a, b):
if hasattr(a, 'compute'):
a = a.compute(get=dask.local.get_sync)
if hasattr(b, 'compute'):
b = b.compute(get=dask.local.get_sync)

assert a == b


def iseven(x):
return x % 2 == 0

Expand Down Expand Up @@ -353,12 +362,18 @@ def test_var():
assert float(b.var()) == 2.0


def test_join():
c = b.join([1, 2, 3], on_self=isodd, on_other=iseven)
assert list(c) == list(join(iseven, [1, 2, 3], isodd, list(b)))
assert (list(b.join([1, 2, 3], isodd)) ==
list(join(isodd, [1, 2, 3], isodd, list(b))))
assert c.name == b.join([1, 2, 3], on_self=isodd, on_other=iseven).name
@pytest.mark.parametrize('transform', [
identity,
dask.delayed,
lambda x: db.from_sequence(x, npartitions=1)
])
def test_join(transform):
other = transform([1, 2, 3])
c = b.join(other, on_self=isodd, on_other=iseven)
assert_eq(c, list(join(iseven, [1, 2, 3], isodd, list(b))))
assert_eq(b.join(other, isodd),
list(join(isodd, [1, 2, 3], isodd, list(b))))
assert c.name == b.join(other, on_self=isodd, on_other=iseven).name


def test_foldby():
Expand Down

0 comments on commit 1841902

Please sign in to comment.