Skip to content

Commit

Permalink
Memoize package_of function
Browse files Browse the repository at this point in the history
This can be quite slow when str(obj) is slow
  • Loading branch information
mrocklin committed May 1, 2017
1 parent 1d8dff8 commit 779d27f
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 14 deletions.
15 changes: 8 additions & 7 deletions dask/array/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,7 +465,7 @@ def _concatenate2(arrays, axes=[]):
return arrays
if len(axes) > 1:
arrays = [_concatenate2(a, axes=axes[1:]) for a in arrays]
module = package_of(max(arrays, key=lambda x: x.__array_priority__)) or np
module = package_of(type(max(arrays, key=lambda x: x.__array_priority__))) or np
return module.concatenate(arrays, axis=axes[0])


Expand Down Expand Up @@ -1790,10 +1790,11 @@ def normalize_chunks(chunks, shape=None):
"""
if chunks is None:
raise ValueError(chunks_none_error_message)
if isinstance(chunks, list):
chunks = tuple(chunks)
if isinstance(chunks, Number):
chunks = (chunks,) * len(shape)
if type(chunks) is not tuple:
if type(chunks) is list:
chunks = tuple(chunks)
if isinstance(chunks, Number):
chunks = (chunks,) * len(shape)
if not chunks and shape and all(s == 0 for s in shape):
chunks = ((),) * len(shape)

Expand Down Expand Up @@ -2479,7 +2480,7 @@ def transpose(a, axes=None):

def _tensordot(a, b, axes):
x = max([a, b], key=lambda x: x.__array_priority__)
module = package_of(x) or np
module = package_of(type(x)) or np
x = module.tensordot(a, b, axes=axes)
ind = [slice(None, None)] * x.ndim
for a in sorted(axes[0]):
Expand Down Expand Up @@ -3432,7 +3433,7 @@ def concatenate3(arrays):

advanced = max(core.flatten(arrays, container=(list, tuple)),
key=lambda x: getattr(x, '__array_priority__', 0))
module = package_of(advanced) or np
module = package_of(type(advanced)) or np
if module is not np and hasattr(module, 'concatenate'):
x = unpack_singleton(arrays)
return _concatenate2(arrays, axes=list(range(x.ndim)))
Expand Down
23 changes: 16 additions & 7 deletions dask/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1055,14 +1055,23 @@ def ensure_dict(d):
return dict(d)


def package_of(obj):
_packages = {}


def package_of(typ):
""" Return package containing object's definition
Or return None if not found
"""
# http://stackoverflow.com/questions/43462701/get-package-of-python-object/43462865#43462865
mod = inspect.getmodule(obj)
if not mod:
return
base, _sep, _stem = mod.__name__.partition('.')
return sys.modules[base]
try:
return _packages[typ]
except KeyError:
# http://stackoverflow.com/questions/43462701/get-package-of-python-object/43462865#43462865
mod = inspect.getmodule(typ)
if not mod:
result = None
else:
base, _sep, _stem = mod.__name__.partition('.')
result = sys.modules[base]
_packages[typ] = result
return result

0 comments on commit 779d27f

Please sign in to comment.