Optionally return stored data after storing #2980

Merged
merged 19 commits into from Jan 3, 2018
182 changes: 151 additions & 31 deletions dask/array/core.py
@@ -18,12 +18,12 @@
import warnings

try:
from cytoolz import (partition, concat, join, first,
from cytoolz import (partition, concat, concatv, join, first,
groupby, valmap, accumulate, assoc)
from cytoolz.curried import filter, pluck

except ImportError:
from toolz import (partition, concat, join, first,
from toolz import (partition, concat, concatv, join, first,
groupby, valmap, accumulate, assoc)
from toolz.curried import filter, pluck
from toolz import pipe, map, reduce
@@ -32,14 +32,15 @@
from . import chunk
from .numpy_compat import _make_sliced_dtype
from .slicing import slice_array, replace_ellipsis
from ..base import Base, tokenize, dont_optimize, compute_as_if_collection
from ..base import (Base, tokenize, dont_optimize, compute_as_if_collection,
persist)
from ..context import _globals, globalmethod
from ..utils import (homogeneous_deepmap, ndeepmap, ignoring, concrete,
is_integer, IndexCallable, funcname, derived_from,
SerializableLock, ensure_dict, Dispatch)
from ..compatibility import unicode, long, getargspec, zip_longest, apply
from ..core import quote
from ..delayed import to_task_dask
from ..delayed import Delayed, to_task_dask
from .. import threaded, core
from .. import sharedict
from ..sharedict import ShareDict
@@ -815,7 +816,8 @@ def broadcast_chunks(*chunkss):
return tuple(result)


def store(sources, targets, lock=True, regions=None, compute=True, **kwargs):
def store(sources, targets, lock=True, regions=None, compute=True,
return_stored=False, **kwargs):
""" Store dask arrays in array-like objects, overwrite data in target

This stores dask arrays into an object that supports numpy-style setitem
@@ -842,6 +844,8 @@ def store(sources, targets, lock=True, regions=None, compute=True, **kwargs):
for the corresponding source and target in sources and targets, respectively.
compute: boolean, optional
If true compute immediately, return ``dask.delayed.Delayed`` otherwise
return_stored: boolean, optional
Optionally return the stored result (default False).

Examples
--------
@@ -859,6 +863,7 @@ def store(sources, targets, lock=True, regions=None, compute=True, **kwargs):

>>> store([x, y, z], [dset1, dset2, dset3]) # doctest: +SKIP
"""

if isinstance(sources, Array):
sources = [sources]
targets = [targets]
@@ -880,31 +885,73 @@ def store(sources, targets, lock=True, regions=None, compute=True, **kwargs):
raise ValueError("Different number of sources [%d] and targets [%d] than regions [%d]"
% (len(sources), len(targets), len(regions)))

updates = {}
keys = []
# Optimize all sources together
sources_dsk = sharedict.merge(*[e.__dask_graph__() for e in sources])
sources_dsk = Array.__dask_optimize__(
sources_dsk,
[e.__dask_keys__() for e in sources]
)
Member Author: While you are looking @jcrist, it would be good if you could give these optimization lines a quick look. Based on the docs this seemed OK, but you certainly know better since you wrote all of these __dask_*__ functions. 😉

Member: That's exactly how they should be used.

Member Author: Thanks for checking.
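For readers less familiar with the dask collection protocol being discussed, here is a minimal sketch of that merge-then-optimize pattern, assuming the 2018-era API used in this diff (dask.sharedict was later superseded by HighLevelGraph); the arrays and shapes are illustrative only:

import dask.array as da
from dask import sharedict
from dask.array.core import Array

x = da.ones((4, 4), chunks=(2, 2))
y = da.zeros((4, 4), chunks=(2, 2))
sources = [x, y]

# Merge every source's graph into one shared graph, then run the
# Array-level optimizer once over all of their output keys together.
sources_dsk = sharedict.merge(*[s.__dask_graph__() for s in sources])
sources_dsk = Array.__dask_optimize__(
    sources_dsk,
    [s.__dask_keys__() for s in sources]
)

Optimizing all sources in one pass keeps a single consistent graph and avoids re-optimizing shared tasks once per source.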


tgt_dsks = []
store_keys = []
store_dsks = []
if return_stored:
load_names = []
load_dsks = []
for tgt, src, reg in zip(targets, sources, regions):
# if out is a delayed object update dictionary accordingly
try:
dsk = {}
dsk.update(tgt.dask)
each_tgt_dsk = {}
each_tgt_dsk.update(tgt.dask)
tgt = tgt.key
except AttributeError:
dsk = {}
each_tgt_dsk = {}

src = Array(sources_dsk, src.name, src.chunks, src.dtype)

each_store_dsk = insert_to_ooc(
src, tgt, lock=lock, region=reg, return_stored=return_stored
)

if return_stored:
load_names.append('load-store-%s' % src.name)
load_dsks.append(retrieve_from_ooc(
each_store_dsk.keys(),
each_store_dsk
))

tgt_dsks.append(each_tgt_dsk)

store_keys.extend(each_store_dsk.keys())
store_dsks.append(each_store_dsk)

update = insert_to_ooc(src, tgt, lock=lock, region=reg)
keys.extend(update)
store_dsks_mrg = sharedict.merge(*concatv(
store_dsks, tgt_dsks, [sources_dsk]
))

update.update(dsk)
updates.update(update)
if return_stored:
if compute:
store_dlyds = [Delayed(k, store_dsks_mrg) for k in store_keys]
store_dlyds = persist(*store_dlyds)
store_dsks_mrg = sharedict.merge(*[e.dask for e in store_dlyds])

name = 'store-' + tokenize(*keys)
dsk = sharedict.merge((name, updates), *[src.dask for src in sources])
if compute:
compute_as_if_collection(Array, dsk, keys, **kwargs)
load_dsks_mrg = sharedict.merge(store_dsks_mrg, *load_dsks)

result = tuple(
Array(load_dsks_mrg, n, src.chunks, src.dtype) for n in load_names
)

return result
else:
from ..delayed import Delayed
dsk.update({name: keys})
return Delayed(name, dsk)
name = 'store-' + tokenize(*store_keys)
dsk = sharedict.merge({name: store_keys}, store_dsks_mrg)
result = Delayed(name, dsk)

if compute:
result.compute()
Member Author: Missed that this call should still get **kwargs. Fixing with PR #3300.

return None
else:
return result
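For completeness, a rough usage sketch of the new return_stored keyword, along the lines of the docstring example above (the HDF5 file and dataset here are hypothetical and not part of the PR):

import dask.array as da
import h5py

x = da.ones((10, 10), chunks=(5, 5))

f = h5py.File('myfile.hdf5')
dset = f.create_dataset('/data', shape=x.shape, dtype=x.dtype)

# With return_stored=True the data is still written to dset (compute=True by
# default), but store now returns dask arrays backed by the stored chunks.
stored, = da.store([x], [dset], return_stored=True)

print(stored.sum().compute())  # 100.0, read back from the target dataset

With compute=False as well, nothing is written until the returned arrays are computed, at which point each chunk is stored and then read back.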


def blockdims_from_blockshape(shape, chunks):
@@ -2157,9 +2204,7 @@ def atop(func, out_ind, *args, **kwargs):
concatenate : bool, keyword only
If true concatenate arrays along dummy indices, else provide lists
adjust_chunks : dict
Dictionary mapping index to information to adjust chunk sizes. Can
either be a constant chunksize, a tuple of all chunksizes, or a
function that converts old chunksize to new chunksize
Dictionary mapping index to function to be applied to chunk sizes
new_axes : dict, keyword only
New indexes and their dimension lengths

@@ -2568,7 +2613,7 @@ def concatenate(seq, axis=0, allow_unknown_chunksizes=False):
return Array(dsk2, name, chunks, dtype=dt)


def store_chunk(x, out, index, lock, region):
def store_chunk(x, out, index, lock, region, return_stored):
"""
A function inserted in a Dask graph for storing a chunk.

@@ -2584,15 +2629,21 @@ def store_chunk(x, out, index, lock, region):
Lock to use before writing to ``out``.
region: slice-like or None
Where relative to ``out`` to store ``x``.
return_stored: bool
Whether to return ``out``.

Examples
--------

>>> a = np.ones((5, 6))
>>> b = np.empty(a.shape)
>>> store_chunk(a, b, (slice(None), slice(None)), False, None)
>>> store_chunk(a, b, (slice(None), slice(None)), False, None, False)
"""

result = None
if return_stored:
result = out
Member: Sorry, I missed this bit before. Is this still used? From below it looks like the result is always re-read from the store, but I may be missing something?

Member Author: Yes, this is still used.

In an effort to chain the store_chunk calls to the load_chunk calls, we return out unchanged so that it can be fed as an argument into load_chunk via a key. Does that make sense? It could be there is a better way to do this; it just seemed simple enough to implement at the time.

We could also move this branch to the end and simply return either out or None if that would make it clearer.

Member: Ah, missed that it was out and not x. Fine by me.
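To make that chaining concrete, here is a small runnable sketch (shapes are arbitrary; insert_to_ooc and retrieve_from_ooc are the helpers added in this diff):

import numpy as np
import dask.array as da
from dask.array.core import insert_to_ooc, retrieve_from_ooc

d = da.ones((4, 6), chunks=(2, 3))
a = np.empty(d.shape)

# Build the store graph with return_stored=True, then chain load tasks
# onto each store key.
store_dsk = insert_to_ooc(d, a, return_stored=True)
load_dsk = retrieve_from_ooc(store_dsk.keys(), store_dsk)

# Each load task's first argument is a store key, so the `out` array
# returned by store_chunk is fed directly into load_chunk.
for load_key, task in sorted(load_dsk.items()):
    print(load_key, '<-', task[1])

Because store_chunk returns out when return_stored is True, each load task re-reads the freshly written region of the target rather than forwarding the original source chunk.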


subindex = index
if region is not None:
subindex = fuse_slice(region, index)
@@ -2605,10 +2656,10 @@ def store_chunk(x, out, index, lock, region):
if lock:
lock.release()

return None
return result


def insert_to_ooc(arr, out, lock=True, region=None):
def insert_to_ooc(arr, out, lock=True, region=None, return_stored=False):
"""
Creates a Dask graph for storing chunks from ``arr`` in ``out``.

@@ -2624,6 +2675,9 @@ def insert_to_ooc(arr, out, lock=True, region=None):
region: slice-like, optional
Where in ``out`` to store ``arr``'s results
(default is ``None``, meaning all of ``out``).
return_stored: bool, optional
Whether to return ``out``
(default is ``False``, meaning ``None`` is returned).

Examples
--------
@@ -2642,13 +2696,79 @@ def insert_to_ooc(arr, out, lock=True, region=None):
dsk = dict()
for t, slc in zip(core.flatten(arr.__dask_keys__()), slices):
store_key = (name,) + t[1:]
dsk[store_key] = (
store_chunk, t, out, slc, lock, region
)
dsk[store_key] = (store_chunk, t, out, slc, lock, region, return_stored)

return dsk


def load_chunk(x, index, lock, region):
"""
A function inserted in a Dask graph for loading a chunk.

Parameters
----------
x: array-like
An array (potentially a NumPy one)
index: slice-like
Where to read from ``x``.
lock: Lock-like or False
Lock to use before reading from ``x``.
region: slice-like or None
Where relative to ``x`` the ``index`` applies.

Examples
--------

>>> a = np.ones((5, 6))
>>> load_chunk(a, (slice(None), slice(None)), False, None) # doctest: +SKIP
"""

result = None

subindex = index
if region is not None:
subindex = fuse_slice(region, index)

if lock:
lock.acquire()
try:
result = x[subindex]
finally:
if lock:
lock.release()

return result


def retrieve_from_ooc(keys, dsk):
"""
Creates a Dask graph for loading stored ``keys`` from ``dsk``.

Parameters
----------
keys: Sequence
A sequence containing Dask graph keys to load
dsk: Mapping
A Dask graph corresponding to a Dask Array

Examples
--------
>>> import dask.array as da
>>> d = da.ones((5, 6), chunks=(2, 3))
>>> a = np.empty(d.shape)
>>> g = insert_to_ooc(d, a)
>>> retrieve_from_ooc(g.keys(), g) # doctest: +SKIP
"""

load_dsk = dict()
for each_key in keys:
load_key = ('load-%s' % each_key[0],) + each_key[1:]
# Reuse the result and arguments from `store_chunk` in `load_chunk`.
load_dsk[load_key] = (load_chunk, each_key,) + dsk[each_key][3:-1]

return load_dsk


def asarray(a):
"""Convert the input to a dask array.
