Avoid calling ShareDict.__getitem__
ShareDict.__getitem__ is slow, but it gets called any time a ShareDict
is included in a toolz.merge call. I found the offending call sites by
raising an exception inside ShareDict.__getitem__; that check is
disabled for now, because __getitem__ is something that ShareDict
should arguably support.
mrocklin committed Mar 2, 2017
1 parent 9c59bc8 commit 20f15aa
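As context for the diff below, here is a minimal sketch of the pattern this commit moves toward, using only the calls visible in the changes themselves (dask.sharedict.merge and dask.utils.ensure_dict); the layer names and toy tasks are illustrative, and exact signatures may differ across dask versions:

    from dask import sharedict
    from dask.utils import ensure_dict

    # Two plain task-graph layers, keyed by layer name.
    a = {('a', 0): (sum, [1, 2])}
    b = {('b', 0): (sum, [3, 4])}

    # Compose layers structurally. Unlike toolz.merge over a ShareDict,
    # this never walks the graph key by key through ShareDict.__getitem__.
    combined = sharedict.merge(('a', a), ('b', b))

    # Convert to a plain dict exactly once, at a boundary that genuinely
    # needs one (e.g. just before graph optimization).
    plain = ensure_dict(combined)

The same idea recurs in every file touched here: graphs are composed with sharedict.merge((name, dsk), x.dask) rather than merged key by key, and ensure_dict(dsk) replaces dict(dsk) at the few boundaries that truly require a plain dict.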
Showing 9 changed files with 42 additions and 32 deletions.
18 changes: 9 additions & 9 deletions dask/array/core.py
@@ -29,7 +29,7 @@
 from ..base import Base, tokenize, normalize_token
 from ..utils import (homogeneous_deepmap, ndeepmap, ignoring, concrete,
                      is_integer, IndexCallable, funcname, derived_from,
-                     SerializableLock)
+                     SerializableLock, ensure_dict)
 from ..compatibility import unicode, long, getargspec, zip_longest, apply
 from ..delayed import to_task_dask
 from .. import threaded, core
@@ -400,16 +400,15 @@ def top(func, output, out_indices, *arrind_pairs, **kwargs):
     if not kwargs:  # will not be used in an apply, should be a tuple
         valtups = [tuple(vt) for vt in valtups]
 
-    dsk = {}
-
     # Add heads to tuples
     keys = [(output,) + kt for kt in keytups]
 
+    dsk = {}
     # Unpack delayed objects in kwargs
     if kwargs:
         task, dsk2 = to_task_dask(kwargs)
         if dsk2:
-            dsk.update(dsk2)
+            dsk.update(ensure_dict(dsk2))
             kwargs2 = task
         else:
             kwargs2 = kwargs
@@ -418,6 +417,7 @@ def top(func, output, out_indices, *arrind_pairs, **kwargs):
     vals = [(func,) + vt for vt in valtups]
 
     dsk.update(dict(zip(keys, vals)))
+
     return dsk


@@ -867,7 +867,7 @@ def store(sources, targets, lock=True, regions=None, compute=True, **kwargs):
     else:
         from ..delayed import Delayed
         dsk.update({name: keys})
-        return Delayed(name, dict(dsk))
+        return Delayed(name, dsk)
 
 
 def blockdims_from_blockshape(shape, chunks):
@@ -2961,7 +2961,7 @@ def bincount(x, weights=None, minlength=None):
 
     chunks = ((minlength,),)
 
-    dsk.update(x.dask)
+    dsk = sharedict.merge((name, dsk), x.dask)
     if weights is not None:
         dsk.update(weights.dask)
 
@@ -3212,8 +3212,8 @@ def triu(m, k=0):
                 dsk[(name, i, j)] = (np.triu, (m.name, i, j), k - (chunk * (j - i)))
             else:
                 dsk[(name, i, j)] = (m.name, i, j)
-    dsk.update(m.dask)
-    return Array(dsk, name, shape=m.shape, chunks=m.chunks, dtype=m.dtype)
+    return Array(sharedict.merge((name, dsk), m.dask), name,
+                 shape=m.shape, chunks=m.chunks, dtype=m.dtype)
 
 
 def tril(m, k=0):
@@ -3262,7 +3262,7 @@ def tril(m, k=0):
                 dsk[(name, i, j)] = (np.tril, (m.name, i, j), k - (chunk * (j - i)))
             else:
                 dsk[(name, i, j)] = (np.zeros, (m.chunks[0][i], m.chunks[1][j]))
-    dsk.update(m.dask)
+    dsk = sharedict.merge(m.dask, (name, dsk))
     return Array(dsk, name, shape=m.shape, chunks=m.chunks, dtype=m.dtype)


11 changes: 7 additions & 4 deletions dask/array/ghost.py
@@ -6,11 +6,12 @@
 from toolz import merge, pipe, concat, partial
 from toolz.curried import map
 
-from . import chunk, wrap
-from .core import Array, map_blocks, concatenate, concatenate3, reshapelist
+from .. import sharedict
 from ..base import tokenize
 from ..core import flatten
 from ..utils import concrete
+from .core import Array, map_blocks, concatenate, concatenate3, reshapelist
+from . import chunk, wrap
 
 
 def fractional_slice(task, axes):
@@ -125,8 +126,10 @@ def ghost_internal(x, axes):
             mid.append(bd + axes.get(i, 0) * 2)
         chunks.append(left + mid + right)
 
-    return Array(merge(interior_slices, ghost_blocks, x.dask),
-                 name, chunks, dtype=x.dtype)
+    dsk = merge(interior_slices, ghost_blocks)
+    dsk = sharedict.merge(x.dask, (name, dsk))
+
+    return Array(dsk, name, chunks, dtype=x.dtype)
 
 
 def trim_internal(x, axes):
2 changes: 2 additions & 0 deletions dask/array/optimization.py
@@ -7,6 +7,7 @@
 from .core import getarray, getarray_nofancy
 from ..core import flatten
 from ..optimize import cull, fuse, inline_functions
+from ..utils import ensure_dict
 
 
 def optimize(dsk, keys, fuse_keys=None, fast_functions=None,
@@ -18,6 +19,7 @@ def optimize(dsk, keys, fuse_keys=None, fast_functions=None,
     2. Remove full slicing, e.g. x[:]
     3. Inline fast functions like getitem and np.transpose
     """
+    dsk = ensure_dict(dsk)
     keys = list(flatten(keys))
     if fast_functions is not None:
         inline_functions_fast_functions = fast_functions
6 changes: 4 additions & 2 deletions dask/array/percentile.py
@@ -9,6 +9,7 @@
 
 from .core import Array
 from ..base import tokenize
+from .. import sharedict
 
 
 @wraps(np.percentile)
@@ -55,8 +56,9 @@ def percentile(a, q, interpolation='linear'):
     if np.issubdtype(dtype, np.integer):
         dtype = (np.array([], dtype=dtype) / 0.5).dtype
 
-    return Array(merge(a.dask, dsk, dsk2), name2, chunks=((len(q),),),
-                 dtype=dtype)
+    dsk = merge(dsk, dsk2)
+    dsk = sharedict.merge(a.dask, (name2, dsk))
+    return Array(dsk, name2, chunks=((len(q),),), dtype=dtype)
 
 
 def merge_percentiles(finalq, qs, vals, Ns, interpolation='lower'):
7 changes: 5 additions & 2 deletions dask/array/random.py
@@ -7,6 +7,7 @@
 
 from .core import (normalize_chunks, Array, slices_from_chunks,
                    broadcast_shapes, broadcast_to)
+from .. import sharedict
 from ..base import tokenize
 from ..utils import ignoring, random_state_data
 
@@ -81,13 +82,14 @@ def _broadcast_any(ar, shape, chunks):
         # Broadcast all arguments, get tiny versions as well
         # Start adding the relevant bits to the graph
         dsk = {}
+        dsks = []
         lookup = {}
         small_args = []
         for i, ar in enumerate(args):
             if isinstance(ar, (np.ndarray, Array)):
                 res = _broadcast_any(ar, size, chunks)
                 if isinstance(res, Array):
-                    dsk.update(res.dask)
+                    dsks.append(res.dask)
                     lookup[i] = res.name
                 elif isinstance(res, np.ndarray):
                     name = 'array-{}'.format(tokenize(res))
@@ -102,7 +104,7 @@ def _broadcast_any(ar, shape, chunks):
             if isinstance(ar, (np.ndarray, Array)):
                 res = _broadcast_any(ar, size, chunks)
                 if isinstance(res, Array):
-                    dsk.update(res.dask)
+                    dsks.append(res.dask)
                     lookup[key] = res.name
                 elif isinstance(res, np.ndarray):
                     name = 'array-{}'.format(tokenize(res))
@@ -147,6 +149,7 @@ def _broadcast_any(ar, shape, chunks):
                 kwrg[k] = (getitem, lookup[k], slc)
             vals.append((_apply_random, func.__name__, state, size, arg, kwrg))
         dsk.update(dict(zip(keys, vals)))
+        dsk = sharedict.merge((name, dsk), *dsks)
         return Array(dsk, name, chunks + extra_chunks, dtype=dtype)
 
     @doc_wraps(np.random.RandomState.beta)
11 changes: 5 additions & 6 deletions dask/array/ufunc.py
@@ -4,9 +4,8 @@
 
 import numpy as np
 
-from toolz.curried import merge
 from .core import Array, elemwise
-from .. import core
+from .. import core, sharedict
 from ..utils import skip_doctest
 

@@ -128,8 +127,8 @@ def frexp(x):
     ldt = l.dtype
     rdt = r.dtype
 
-    L = Array(merge(tmp.dask, ldsk), left, chunks=tmp.chunks, dtype=ldt)
-    R = Array(merge(tmp.dask, rdsk), right, chunks=tmp.chunks, dtype=rdt)
+    L = Array(sharedict.merge(tmp.dask, (left, ldsk)), left, chunks=tmp.chunks, dtype=ldt)
+    R = Array(sharedict.merge(tmp.dask, (right, rdsk)), right, chunks=tmp.chunks, dtype=rdt)
     return L, R


@@ -151,8 +150,8 @@ def modf(x):
     ldt = l.dtype
     rdt = r.dtype
 
-    L = Array(merge(tmp.dask, ldsk), left, chunks=tmp.chunks, dtype=ldt)
-    R = Array(merge(tmp.dask, rdsk), right, chunks=tmp.chunks, dtype=rdt)
+    L = Array(sharedict.merge(tmp.dask, (left, ldsk)), left, chunks=tmp.chunks, dtype=ldt)
+    R = Array(sharedict.merge(tmp.dask, (right, rdsk)), right, chunks=tmp.chunks, dtype=rdt)
     return L, R


11 changes: 6 additions & 5 deletions dask/base.py
@@ -11,10 +11,11 @@
 from toolz import merge, groupby, curry, identity
 from toolz.functoolz import Compose
 
+from . import sharedict
 from .compatibility import bind_method, unicode, PY3
 from .context import _globals
 from .core import flatten
-from .utils import Dispatch
+from .utils import Dispatch, ensure_dict
 from .sharedict import ShareDict
 
 __all__ = ("Base", "compute", "normalize_token", "tokenize", "visualize")
@@ -97,7 +98,7 @@ def compute(self, **kwargs):
     @classmethod
     def _get(cls, dsk, keys, get=None, **kwargs):
         get = get or _globals['get'] or cls._default_get
-        dsk2 = cls._optimize(dict(dsk), keys, **kwargs)
+        dsk2 = cls._optimize(ensure_dict(dsk), keys, **kwargs)
         return get(dsk2, keys, **kwargs)
 
     @classmethod
@@ -251,7 +252,7 @@ def visualize(*args, **kwargs):
     optimize_graph = kwargs.pop('optimize_graph', False)
     from dask.dot import dot_graph
     if optimize_graph:
-        dsks.extend([arg._optimize(dict(arg.dask), arg._keys())
+        dsks.extend([arg._optimize(ensure_dict(arg.dask), arg._keys())
                      for arg in args])
     else:
         dsks.extend([arg.dask for arg in args])
@@ -408,12 +409,12 @@ def collections_to_dsk(collections, optimize_graph=True, **kwargs):
         groups = {opt: _extract_graph_and_keys(val)
                   for opt, val in groups.items()}
         for opt in optimizations:
-            groups = {k: [opt(dict(dsk), keys), keys]
+            groups = {k: [opt(ensure_dict(dsk), keys), keys]
                       for k, (dsk, keys) in groups.items()}
         dsk = merge([opt(dsk, keys, **kwargs)
                      for opt, (dsk, keys) in groups.items()])
     else:
-        dsk = merge(dict(c.dask) for c in collections)
+        dsk = ensure_dict(sharedict.merge(*[c.dask for c in collections]))
 
     return dsk

4 changes: 2 additions & 2 deletions dask/dataframe/io/io.py
@@ -19,7 +19,7 @@
 from ..shuffle import set_partition
 from ..utils import insert_meta_param_description
 
-from ...utils import M
+from ...utils import M, ensure_dict
 
 lock = Lock()
 
@@ -399,7 +399,7 @@ def from_dask_array(x, columns=None):
         else:
             dsk[name, i] = (pd.DataFrame, chunk, ind, meta.columns)
 
-    return new_dd_object(merge(x.dask, dsk), name, meta, divisions)
+    return new_dd_object(merge(ensure_dict(x.dask), dsk), name, meta, divisions)
 
 
 def _link(token, result):
4 changes: 2 additions & 2 deletions dask/delayed.py
@@ -11,7 +11,7 @@
 from . import base, threaded
 from .compatibility import apply
 from .core import quote
-from .utils import concrete, funcname, methodcaller
+from .utils import concrete, funcname, methodcaller, ensure_dict
 from . import sharedict
 
 __all__ = ['Delayed', 'delayed']
@@ -69,7 +69,7 @@ def to_task_dask(expr):
     if isinstance(expr, base.Base):
         name = 'finalize-' + tokenize(expr, pure=True)
         keys = expr._keys()
-        dsk = expr._optimize(dict(expr.dask), keys)
+        dsk = expr._optimize(ensure_dict(expr.dask), keys)
         dsk[name] = (expr._finalize, (concrete, keys))
         return name, dsk
     if isinstance(expr, tuple) and type(expr) != tuple: