New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Optionally return stored data after storing #2980
Changes from all commits
d46f4b8
1b98c33
e3f18e8
e87db92
39a99c7
05d90f4
69fab37
eda594e
01a74d0
3bc167f
b22631e
9c16ecc
1084e40
604e29f
3ba9f1e
bc06ae6
32ba769
7e62dc8
ae66c63
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -18,12 +18,12 @@ | |
import warnings | ||
|
||
try: | ||
from cytoolz import (partition, concat, join, first, | ||
from cytoolz import (partition, concat, concatv, join, first, | ||
groupby, valmap, accumulate, assoc) | ||
from cytoolz.curried import filter, pluck | ||
|
||
except ImportError: | ||
from toolz import (partition, concat, join, first, | ||
from toolz import (partition, concat, concatv, join, first, | ||
groupby, valmap, accumulate, assoc) | ||
from toolz.curried import filter, pluck | ||
from toolz import pipe, map, reduce | ||
|
@@ -32,14 +32,15 @@ | |
from . import chunk | ||
from .numpy_compat import _make_sliced_dtype | ||
from .slicing import slice_array, replace_ellipsis | ||
from ..base import Base, tokenize, dont_optimize, compute_as_if_collection | ||
from ..base import (Base, tokenize, dont_optimize, compute_as_if_collection, | ||
persist) | ||
from ..context import _globals, globalmethod | ||
from ..utils import (homogeneous_deepmap, ndeepmap, ignoring, concrete, | ||
is_integer, IndexCallable, funcname, derived_from, | ||
SerializableLock, ensure_dict, Dispatch) | ||
from ..compatibility import unicode, long, getargspec, zip_longest, apply | ||
from ..core import quote | ||
from ..delayed import to_task_dask | ||
from ..delayed import Delayed, to_task_dask | ||
from .. import threaded, core | ||
from .. import sharedict | ||
from ..sharedict import ShareDict | ||
|
@@ -815,7 +816,8 @@ def broadcast_chunks(*chunkss): | |
return tuple(result) | ||
|
||
|
||
def store(sources, targets, lock=True, regions=None, compute=True, **kwargs): | ||
def store(sources, targets, lock=True, regions=None, compute=True, | ||
return_stored=False, **kwargs): | ||
""" Store dask arrays in array-like objects, overwrite data in target | ||
|
||
This stores dask arrays into object that supports numpy-style setitem | ||
|
@@ -842,6 +844,8 @@ def store(sources, targets, lock=True, regions=None, compute=True, **kwargs): | |
for the corresponding source and target in sources and targets, respectively. | ||
compute: boolean, optional | ||
If true compute immediately, return ``dask.delayed.Delayed`` otherwise | ||
return_stored: boolean, optional | ||
Optionally return the stored result (default False). | ||
|
||
Examples | ||
-------- | ||
|
@@ -859,6 +863,7 @@ def store(sources, targets, lock=True, regions=None, compute=True, **kwargs): | |
|
||
>>> store([x, y, z], [dset1, dset2, dset3]) # doctest: +SKIP | ||
""" | ||
|
||
if isinstance(sources, Array): | ||
sources = [sources] | ||
targets = [targets] | ||
|
@@ -880,31 +885,73 @@ def store(sources, targets, lock=True, regions=None, compute=True, **kwargs): | |
raise ValueError("Different number of sources [%d] and targets [%d] than regions [%d]" | ||
% (len(sources), len(targets), len(regions))) | ||
|
||
updates = {} | ||
keys = [] | ||
# Optimize all sources together | ||
sources_dsk = sharedict.merge(*[e.__dask_graph__() for e in sources]) | ||
sources_dsk = Array.__dask_optimize__( | ||
sources_dsk, | ||
[e.__dask_keys__() for e in sources] | ||
) | ||
|
||
tgt_dsks = [] | ||
store_keys = [] | ||
store_dsks = [] | ||
if return_stored: | ||
load_names = [] | ||
load_dsks = [] | ||
for tgt, src, reg in zip(targets, sources, regions): | ||
# if out is a delayed object update dictionary accordingly | ||
try: | ||
dsk = {} | ||
dsk.update(tgt.dask) | ||
each_tgt_dsk = {} | ||
each_tgt_dsk.update(tgt.dask) | ||
tgt = tgt.key | ||
except AttributeError: | ||
dsk = {} | ||
each_tgt_dsk = {} | ||
|
||
src = Array(sources_dsk, src.name, src.chunks, src.dtype) | ||
|
||
each_store_dsk = insert_to_ooc( | ||
src, tgt, lock=lock, region=reg, return_stored=return_stored | ||
) | ||
|
||
if return_stored: | ||
load_names.append('load-store-%s' % src.name) | ||
load_dsks.append(retrieve_from_ooc( | ||
each_store_dsk.keys(), | ||
each_store_dsk | ||
)) | ||
|
||
tgt_dsks.append(each_tgt_dsk) | ||
|
||
store_keys.extend(each_store_dsk.keys()) | ||
store_dsks.append(each_store_dsk) | ||
|
||
update = insert_to_ooc(src, tgt, lock=lock, region=reg) | ||
keys.extend(update) | ||
store_dsks_mrg = sharedict.merge(*concatv( | ||
store_dsks, tgt_dsks, [sources_dsk] | ||
)) | ||
|
||
update.update(dsk) | ||
updates.update(update) | ||
if return_stored: | ||
if compute: | ||
store_dlyds = [Delayed(k, store_dsks_mrg) for k in store_keys] | ||
store_dlyds = persist(*store_dlyds) | ||
store_dsks_mrg = sharedict.merge(*[e.dask for e in store_dlyds]) | ||
|
||
name = 'store-' + tokenize(*keys) | ||
dsk = sharedict.merge((name, updates), *[src.dask for src in sources]) | ||
if compute: | ||
compute_as_if_collection(Array, dsk, keys, **kwargs) | ||
load_dsks_mrg = sharedict.merge(store_dsks_mrg, *load_dsks) | ||
|
||
result = tuple( | ||
Array(load_dsks_mrg, n, src.chunks, src.dtype) for n in load_names | ||
) | ||
|
||
return result | ||
else: | ||
from ..delayed import Delayed | ||
dsk.update({name: keys}) | ||
return Delayed(name, dsk) | ||
name = 'store-' + tokenize(*store_keys) | ||
dsk = sharedict.merge({name: store_keys}, store_dsks_mrg) | ||
result = Delayed(name, dsk) | ||
|
||
if compute: | ||
result.compute() | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Missed that this call should still get `**kwargs`. |
||
return None | ||
else: | ||
return result | ||
|
||
|
||
def blockdims_from_blockshape(shape, chunks): | ||
|
@@ -2157,9 +2204,7 @@ def atop(func, out_ind, *args, **kwargs): | |
concatenate : bool, keyword only | ||
If true concatenate arrays along dummy indices, else provide lists | ||
adjust_chunks : dict | ||
Dictionary mapping index to information to adjust chunk sizes. Can | ||
either be a constant chunksize, a tuple of all chunksizes, or a | ||
function that converts old chunksize to new chunksize | ||
Dictionary mapping index to function to be applied to chunk sizes | ||
new_axes : dict, keyword only | ||
New indexes and their dimension lengths | ||
|
||
|
@@ -2568,7 +2613,7 @@ def concatenate(seq, axis=0, allow_unknown_chunksizes=False): | |
return Array(dsk2, name, chunks, dtype=dt) | ||
|
||
|
||
def store_chunk(x, out, index, lock, region): | ||
def store_chunk(x, out, index, lock, region, return_stored): | ||
""" | ||
A function inserted in a Dask graph for storing a chunk. | ||
|
||
|
@@ -2584,15 +2629,21 @@ def store_chunk(x, out, index, lock, region): | |
Lock to use before writing to ``out``. | ||
region: slice-like or None | ||
Where relative to ``out`` to store ``x``. | ||
return_stored: bool | ||
Whether to return ``out``. | ||
|
||
Examples | ||
-------- | ||
|
||
>>> a = np.ones((5, 6)) | ||
>>> b = np.empty(a.shape) | ||
>>> store_chunk(a, b, (slice(None), slice(None)), False, None) | ||
>>> store_chunk(a, b, (slice(None), slice(None)), False, None, False) | ||
""" | ||
|
||
result = None | ||
if return_stored: | ||
result = out | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Sorry, I missed this bit before. Is this still used? From below it looks like the result is always re-read from the store, but I may be missing something?
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Yes, this is still used. In an effort to chain the […] We could also move this branch to the end and simply […]
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Ah, missed that it was […] |
||
|
||
subindex = index | ||
if region is not None: | ||
subindex = fuse_slice(region, index) | ||
|
@@ -2605,10 +2656,10 @@ def store_chunk(x, out, index, lock, region): | |
if lock: | ||
lock.release() | ||
|
||
return None | ||
return result | ||
|
||
|
||
def insert_to_ooc(arr, out, lock=True, region=None): | ||
def insert_to_ooc(arr, out, lock=True, region=None, return_stored=False): | ||
""" | ||
Creates a Dask graph for storing chunks from ``arr`` in ``out``. | ||
|
||
|
@@ -2624,6 +2675,9 @@ def insert_to_ooc(arr, out, lock=True, region=None): | |
region: slice-like, optional | ||
Where in ``out`` to store ``arr``'s results | ||
(default is ``None``, meaning all of ``out``). | ||
return_stored: bool, optional | ||
Whether to return ``out`` | ||
(default is ``False``, meaning ``None`` is returned). | ||
|
||
Examples | ||
-------- | ||
|
@@ -2642,13 +2696,79 @@ def insert_to_ooc(arr, out, lock=True, region=None): | |
dsk = dict() | ||
for t, slc in zip(core.flatten(arr.__dask_keys__()), slices): | ||
store_key = (name,) + t[1:] | ||
dsk[store_key] = ( | ||
store_chunk, t, out, slc, lock, region | ||
) | ||
dsk[store_key] = (store_chunk, t, out, slc, lock, region, return_stored) | ||
|
||
return dsk | ||
|
||
|
||
def load_chunk(x, index, lock, region):
    """
    A function inserted in a Dask graph for loading a chunk.

    Parameters
    ----------
    x: array-like
        An array (potentially a NumPy one) to read the chunk from.
    index: slice-like
        Which part of ``x`` to load.
    lock: Lock-like or False
        Lock to hold while reading from ``x``
        (any falsy value disables locking).
    region: slice-like or None
        If not ``None``, it is combined with ``index`` through
        ``fuse_slice(region, index)`` before indexing ``x``.

    Returns
    -------
    The result of indexing ``x`` with the (possibly fused) index.

    Examples
    --------

    >>> a = np.ones((5, 6))
    >>> load_chunk(a, (slice(None), slice(None)), False, None)  # doctest: +SKIP
    """

    subindex = index
    if region is not None:
        subindex = fuse_slice(region, index)

    # ``lock`` may be ``False``, so it cannot be used as a context manager
    # unconditionally; acquire/release by hand and always release on error.
    if lock:
        lock.acquire()
    try:
        return x[subindex]
    finally:
        if lock:
            lock.release()
|
||
|
||
def retrieve_from_ooc(keys, dsk):
    """
    Build a Dask graph that loads back the chunks stored by ``keys``.

    Parameters
    ----------
    keys: Sequence
        Graph keys (tuples) of the store tasks whose data should be loaded
    dsk: Mapping
        A Dask graph containing a store task for each entry of ``keys``

    Returns
    -------
    dict
        One ``load_chunk`` task per key, named by prefixing the key's
        first element with ``load-``.

    Examples
    --------
    >>> import dask.array as da
    >>> d = da.ones((5, 6), chunks=(2, 3))
    >>> a = np.empty(d.shape)
    >>> g = insert_to_ooc(d, a)
    >>> retrieve_from_ooc(g.keys(), g)  # doctest: +SKIP
    """

    # Each load task takes the store task's key as its first argument (so it
    # consumes that task's result) and reuses the store task's arguments from
    # position 3 up to, but excluding, the last one.
    return {
        ('load-%s' % store_key[0],) + store_key[1:]:
            (load_chunk, store_key) + dsk[store_key][3:-1]
        for store_key in keys
    }
|
||
|
||
def asarray(a): | ||
"""Convert the input to a dask array. | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
While you are looking @jcrist, it would be good if you could give these optimization lines a quick look. Based on the docs this seemed OK, but you certainly know better since you wrote all of these `__dask_*__` functions. 😉
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That's exactly how they should be used.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for checking.