Skip to content

Commit

Permalink
Join keep cases together
Browse files Browse the repository at this point in the history
Reuse the `Delayed` store objects in the case of `keep=False` as well.
Should simplify some logic between the two cases.
  • Loading branch information
jakirkham committed Dec 13, 2017
1 parent cf14ecd commit ff7ca87
Showing 1 changed file with 25 additions and 25 deletions.
50 changes: 25 additions & 25 deletions dask/array/core.py
Expand Up @@ -884,13 +884,10 @@ def store(sources, targets, lock=True, regions=None, compute=True, keep=False, *
raise ValueError("Different number of sources [%d] and targets [%d] than regions [%d]"
% (len(sources), len(targets), len(regions)))

store_dlyds = []
if keep:
load_names = []
load_dsks = []
store_dlyds = []
else:
updates = {}
keys = []
for tgt, src, reg in zip(targets, sources, regions):
# if out is a delayed object update dictionary accordingly
try:
Expand All @@ -906,25 +903,21 @@ def store(sources, targets, lock=True, regions=None, compute=True, keep=False, *
load_names.append('load-%s' % src.name)
load_dsks.append(retrieve_from_ooc(each_store_dsk))

store_dlyds.append([])
for each_store_key in each_store_dsk:
store_dlyds[-1].append(
Delayed(
each_store_key,
sharedict.merge(
(
each_store_key,
{each_store_key: each_store_dsk[each_store_key]}
),
dsk,
src.dask
)
store_dlyds.append([])
for each_store_key in each_store_dsk:
store_dlyds[-1].append(
Delayed(
each_store_key,
sharedict.merge(
(
each_store_key,
{each_store_key: each_store_dsk[each_store_key]}
),
dsk,
src.dask
)
)
else:
keys.extend(each_store_dsk)
dsk.update(each_store_dsk)
updates.update(dsk)
)

if keep:
if compute:
Expand All @@ -950,13 +943,20 @@ def store(sources, targets, lock=True, regions=None, compute=True, keep=False, *
results = tuple(results)
return results
else:
keys = [e.key for e in core.flatten(store_dlyds)]
name = 'store-' + tokenize(*keys)
dsk = sharedict.merge((name, updates), *[src.dask for src in sources])
key_names = tuple(set(e[0] for e in keys))
dsk = sharedict.merge(
(name, {name: e for e in key_names}),
*[e.dask for e in core.flatten(store_dlyds)]
)
dsk.update({name: keys})
result = Delayed(name, dsk)

if compute:
compute_as_if_collection(Array, dsk, keys, **kwargs)
result.compute()
else:
dsk.update({name: keys})
return Delayed(name, dsk)
return result


def blockdims_from_blockshape(shape, chunks):
Expand Down

0 comments on commit ff7ca87

Please sign in to comment.