From e7ba9637a65467334e95e658069186e981c54475 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Fri, 12 May 2017 13:52:36 -0400 Subject: [PATCH] simplify solution Suggestion by jcrist --- dask/bag/core.py | 22 +++++++--------------- 1 file changed, 7 insertions(+), 15 deletions(-) diff --git a/dask/bag/core.py b/dask/bag/core.py index fcde2239742..a3289905c12 100644 --- a/dask/bag/core.py +++ b/dask/bag/core.py @@ -733,11 +733,10 @@ def reduction(self, perpartition, aggregate, split_every=None, fmt = '%s-aggregate-%s' % (name or funcname(aggregate), token) depth = 0 - while k > 1: + while k > split_every: c = fmt + str(depth) - is_last = k <= split_every dsk2 = dict(((c, i), (empty_safe_aggregate, aggregate, - [(b, j) for j in inds], is_last)) + [(b, j) for j in inds], False)) for i, inds in enumerate(partition_all(split_every, range(k)))) dsk.update(dsk2) @@ -745,21 +744,14 @@ def reduction(self, perpartition, aggregate, split_every=None, b = c depth += 1 - if self.npartitions == 1: - dsk[(a, 0)] = (aggregate, [dsk[(a, 0)]]) - - if not self.npartitions: - task = (aggregate, []) - if out_type is Item: - return Item({b: task}, b) - else: - return Bag({(b, 0): task}, b, 1) + dsk[(fmt, 0)] = (empty_safe_aggregate, aggregate, + [(b, j) for j in range(k)], True) if out_type is Item: - dsk[b] = dsk.pop((b, 0)) - return Item(merge(self.dask, dsk), b) + dsk[fmt] = dsk.pop((fmt, 0)) + return Item(merge(self.dask, dsk), fmt) else: - return Bag(merge(self.dask, dsk), b, 1) + return Bag(merge(self.dask, dsk), fmt, 1) def sum(self, split_every=None): """ Sum all elements """