Add Dask Array._meta attribute (#4543)
This adds a `._meta` attribute to Dask Array that records a small
example of the chunk type. This should help maintain metadata about
chunk types, which is useful for sparse and GPU arrays.

Fixes #2977
pentschev authored and mrocklin committed Jun 7, 2019
1 parent 3a55adf commit 32f0fac
Showing 19 changed files with 1,065 additions and 143 deletions.
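
For orientation before the diff: a minimal sketch of what the new attribute exposes on an ordinary NumPy-backed array. The printed values follow from the constructor change below; treat them as illustrative, not captured output.

import dask.array as da

x = da.ones((1000, 1000), chunks=(100, 100))
# _meta is a zero-length array with the same backend type, ndim and
# dtype as the array's chunks
print(type(x._meta))  # <class 'numpy.ndarray'>
print(x._meta.shape)  # (0, 0)
print(x._meta.dtype)  # float64
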
36 changes: 35 additions & 1 deletion dask/array/blockwise.py
@@ -3,12 +3,42 @@

import toolz

import numpy as np

from .. import base, utils
from ..delayed import unpack_collections
from ..highlevelgraph import HighLevelGraph
from ..blockwise import blockwise as core_blockwise


def blockwise_meta(func, dtype, *args, **kwargs):
    arrays = args[::2]
    ndims = [a.ndim if hasattr(a, 'ndim') else 0 for a in arrays]
    args_meta = [arg._meta if hasattr(arg, '_meta') else
                 arg[tuple(slice(0, 0, None) for _ in range(nd))] if nd > 0 else arg
                 for arg, nd in zip(arrays, ndims)]
    kwargs_meta = {k: v._meta if hasattr(v, '_meta') else v for k, v in kwargs.items()}

    # TODO: look for an alternative to this; it causes issues when using
    # map_blocks() with np.vectorize, such as dask.array.routines._isnonzero_vec().
    if isinstance(func, np.vectorize):
        meta = func(*args_meta)
        return meta.astype(dtype)

    try:
        meta = func(*args_meta, **kwargs_meta)
    except TypeError:
        # The concatenate argument is introduced by this function and may
        # not be supported by some external functions, such as NumPy's
        kwargs_meta.pop('concatenate', None)
        meta = func(*args_meta, **kwargs_meta)
    except ValueError:
        return None

    return meta.astype(dtype)
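
The inference trick above deserves a standalone illustration: slicing every axis of an input down to length zero yields an empty array of the right backend type, and calling the function on those empty inputs is cheap while still producing a result with the correct type and dtype. A minimal sketch with plain NumPy:

import numpy as np

a = np.ones((4, 4))
# Zero-length slices along every axis, as blockwise_meta builds for
# arguments that carry no _meta of their own
empty = a[tuple(slice(0, 0) for _ in range(a.ndim))]
print(empty.shape)  # (0, 0)

meta = np.add(empty, empty)    # the function runs on no actual data
print(meta.shape, meta.dtype)  # (0, 0) float64
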


def blockwise(func, out_ind, *args, **kwargs):
""" Tensor operation: Generalized inner and outer products
@@ -203,7 +233,11 @@ def blockwise(func, out_ind, *args, **kwargs):
"adjust_chunks values must be callable, int, or tuple")
chunks = tuple(chunks)

    return Array(graph, out, chunks, dtype=dtype)
    try:
        meta = blockwise_meta(func, dtype, *args, **kwargs)
        return Array(graph, out, chunks, meta=meta)
    except Exception:
        return Array(graph, out, chunks, dtype=dtype)


def atop(*args, **kwargs):
88 changes: 69 additions & 19 deletions dask/array/core.py
@@ -731,7 +731,7 @@ def store(sources, targets, lock=True, regions=None, compute=True,
        sources_dsk,
        list(core.flatten([e.__dask_keys__() for e in sources]))
    )
    sources2 = [Array(sources_dsk, e.name, e.chunks, e.dtype) for e in sources]
    sources2 = [Array(sources_dsk, e.name, e.chunks, meta=e) for e in sources]

# Optimize all targets together
targets2 = []
@@ -774,7 +774,7 @@ def store(sources, targets, lock=True, regions=None, compute=True,
    )

    result = tuple(
        Array(load_store_dsk, 'load-store-%s' % t, s.chunks, s.dtype)
        Array(load_store_dsk, 'load-store-%s' % t, s.chunks, meta=s)
        for s, t in zip(sources, toks)
    )

@@ -857,28 +857,40 @@ class Array(DaskMethodsMixin):
        Shape of the entire array
    chunks: iterable of tuples
        block sizes along each dimension
    dtype : str or dtype
        Typecode or data-type for the new Dask Array
    meta : empty ndarray
        empty ndarray created with the same NumPy backend, ndim and dtype as
        the Dask Array being created (overrides dtype)

    See Also
    --------
    dask.array.from_array
    """
    __slots__ = 'dask', '_name', '_cached_keys', '_chunks', 'dtype'
    __slots__ = 'dask', '_name', '_cached_keys', '_chunks', '_meta'

    def __new__(cls, dask, name, chunks, dtype, shape=None):
    def __new__(cls, dask, name, chunks, dtype=None, meta=None, shape=None):
        self = super(Array, cls).__new__(cls)
        assert isinstance(dask, Mapping)
        if not isinstance(dask, HighLevelGraph):
            dask = HighLevelGraph.from_collections(name, dask, dependencies=())
        self.dask = dask
        self.name = name
        if dtype is None:
            raise ValueError("You must specify the dtype of the array")
        self.dtype = np.dtype(dtype)
        if dtype is not None and meta is not None:
            raise TypeError("You must not specify both meta and dtype")
        if dtype is None and meta is None:
            raise ValueError("You must specify the meta or dtype of the array")

        self._chunks = normalize_chunks(chunks, shape, dtype=self.dtype)
        self._chunks = normalize_chunks(chunks, shape, dtype=dtype or meta.dtype)
        if self._chunks is None:
            raise ValueError(CHUNKS_NONE_ERROR_MESSAGE)

        if dtype:
            self._meta = np.empty((0,) * self.ndim, dtype=dtype)
        else:
            from .utils import meta_from_array
            self._meta = meta_from_array(meta, meta.ndim)

        for plugin in config.get('array_plugins', ()):
            result = plugin(self)
            if result is not None:
@@ -944,8 +956,8 @@ def chunksize(self):
        return tuple(max(c) for c in self.chunks)

    @property
    def _meta(self):
        return np.empty(shape=(), dtype=self.dtype)
    def dtype(self):
        return self._meta.dtype
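
Since dtype is now read off _meta, the two can no longer disagree. A quick sketch of the expected behaviour, an assumption following from the constructor above rather than output recorded in this commit:

import numpy as np
import dask.array as da

x = da.zeros((10, 10), chunks=5, dtype='i8')
# dtype is no longer stored on the Array; it is derived from the meta array
assert x.dtype == x._meta.dtype == np.dtype('int64')
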

    def _get_chunks(self):
        return self._chunks
@@ -1217,7 +1229,7 @@ def __setitem__(self, key, value):
        if isinstance(value, Array) and value.ndim > 1:
            raise ValueError('boolean index array should have 1 dimension')
        y = where(key, value, self)
        self.dtype = y.dtype
        self._meta = y._meta
        self.dask = y.dask
        self.name = y.name
        self._chunks = y.chunks
@@ -1267,7 +1279,37 @@ def __getitem__(self, index):
        dsk, chunks = slice_array(out, self.name, self.chunks, index2)

        graph = HighLevelGraph.from_collections(out, dsk, dependencies=[self])
        return Array(graph, out, chunks, dtype=self.dtype)

        if isinstance(index2, tuple):
            new_index = []
            for i in range(len(index2)):
                if not isinstance(index2[i], tuple):
                    types = [Integral, list, np.ndarray]
                    cond = any([isinstance(index2[i], t) for t in types])
                    new_index.append(slice(0, 0) if cond else index2[i])
                else:
                    new_index.append(tuple([Ellipsis if j is not None else
                                            None for j in index2[i]]))
            new_index = tuple(new_index)
            meta = self._meta[new_index].astype(self.dtype)
        else:
            meta = self._meta[index2].astype(self.dtype)

        # Exception for object dtype and ndim == 1, which results in primitive types
        if not (meta.dtype == object and meta.ndim == 1):

            # If meta still has more dimensions than the actual data
            if meta.ndim > len(chunks):
                meta = np.sum(meta, axis=tuple([i for i in range(meta.ndim - len(chunks))]))

            # Ensure all dimensions are of length 0
            if not np.isscalar(meta):
                meta = meta[tuple([slice(0, 0) for i in range(meta.ndim)])]
            # If the returned array is 0-D, ensure _meta is 0-D
            if len(chunks) == 0:
                meta = meta.sum()

        return Array(graph, out, chunks, meta=meta)
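
The index rewriting above exists because an integer index would fail on a zero-length axis, so scalar-like indices are swapped for slice(0, 0); the later np.sum then removes the dimensions that real integer indexing would have dropped. A standalone sketch of both steps:

import numpy as np

meta = np.empty((0, 0))
# meta[0] would raise IndexError (axis 0 has size 0), hence slice(0, 0):
sliced = meta[slice(0, 0), :]
print(sliced.shape)  # (0, 0) -- one dimension too many for, say, x[0]

# Summing over the extra leading axes drops meta to the output's ndim
# while keeping it empty:
print(np.sum(sliced, axis=0).shape)  # (0,)
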

    def _vindex(self, key):
        if not isinstance(key, tuple):
@@ -1336,7 +1378,7 @@ def _blocks(self, index):
        layer = {(name,) + key: tuple(new_keys[key].tolist()) for key in keys}

        graph = HighLevelGraph.from_collections(name, layer, dependencies=[self])
        return Array(graph, name, chunks, self.dtype)
        return Array(graph, name, chunks, meta=self)

    @property
    def blocks(self):
@@ -1887,7 +1929,7 @@ def copy(self):
        if self.npartitions == 1:
            return self.map_blocks(M.copy)
        else:
            return Array(self.dask, self.name, self.chunks, self.dtype)
            return Array(self.dask, self.name, self.chunks, meta=self)

    def __deepcopy__(self, memo):
        c = self.copy()
@@ -2331,7 +2373,12 @@ def from_array(x, chunks='auto', name=None, lock=False, asarray=True, fancy=True
                    dtype=x.dtype)
        dsk[original_name] = x

    return Array(dsk, name, chunks, dtype=x.dtype)
    # Workaround for TileDB: its indexing is 1-based and it doesn't seem
    # to support 0-length slicing
    if x.__class__.__module__.split('.')[0] == 'tiledb' and hasattr(x, '_ctx_'):
        return Array(dsk, name, chunks, dtype=x.dtype)

    return Array(dsk, name, chunks, meta=x)


def from_zarr(url, component=None, storage_options=None, chunks=None, name=None, **kwargs):
@@ -2947,6 +2994,11 @@ def concatenate(seq, axis=0, allow_unknown_chunksizes=False):
    for i, ind in enumerate(inds):
        ind[axis] = -(i + 1)

    from .utils import meta_from_array
    metas = [getattr(s, '_meta', s) for s in seq]
    metas = [meta_from_array(m, getattr(m, 'ndim', 1)) for m in metas]
    meta = np.concatenate(metas)

    uc_args = list(concat(zip(seq, inds)))
    _, seq = unify_chunks(*uc_args, warn=False)

@@ -2961,8 +3013,6 @@ def concatenate(seq, axis=0, allow_unknown_chunksizes=False):
    if len(set(seq_dtypes)) > 1:
        dt = reduce(np.promote_types, seq_dtypes)
        seq = [x.astype(dt) for x in seq]
    else:
        dt = seq_dtypes[0]

    names = [a.name for a in seq]

@@ -2976,7 +3026,7 @@
    dsk = dict(zip(keys, values))
    graph = HighLevelGraph.from_collections(name, dsk, dependencies=seq)

    return Array(graph, name, chunks, dtype=dt)
    return Array(graph, name, chunks, meta=meta)
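
Concatenating the zero-length metas yields the promoted dtype (and backend type) directly, which is why the explicit else branch computing dt for the uniform-dtype case could be dropped. A standalone sketch:

import numpy as np

metas = [np.empty((0,), dtype='float32'), np.empty((0,), dtype='float64')]
meta = np.concatenate(metas)
print(meta.shape, meta.dtype)  # (0,) float64 -- NumPy's promotion, no data moved
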


def load_store_chunk(x, out, index, lock, return_stored, load_stored):
@@ -3357,7 +3407,7 @@ def handle_out(out, result):
"out=%s, result=%s" % (str(out.shape), str(result.shape)))
out._chunks = result.chunks
out.dask = result.dask
out.dtype = result.dtype
out._meta = result._meta
out.name = result.name
elif out is not None:
msg = ("The out parameter is not fully supported."
32 changes: 21 additions & 11 deletions dask/array/creation.py
@@ -11,13 +11,13 @@
from ..highlevelgraph import HighLevelGraph
from ..base import tokenize
from ..compatibility import Sequence
from ..utils import derived_from
from . import chunk
from .core import (Array, asarray, normalize_chunks,
                   stack, concatenate, block,
                   broadcast_to, broadcast_arrays)
from .wrap import empty, ones, zeros, full
from .utils import AxisError
from ..utils import derived_from
from .utils import AxisError, meta_from_array, zeros_like_safe


def empty_like(a, dtype=None, chunks=None):
@@ -473,7 +473,11 @@ def eye(N, chunks='auto', M=None, k=0, dtype=float):
@derived_from(np)
def diag(v):
    name = 'diag-' + tokenize(v)
    if isinstance(v, np.ndarray):

    meta = meta_from_array(v, 2 if v.ndim == 1 else 1)

    if (isinstance(v, np.ndarray) or
            (hasattr(v, '__array_function__') and not isinstance(v, Array))):
        if v.ndim == 1:
            chunks = ((v.shape[0],), (v.shape[0],))
            dsk = {(name, 0, 0): (np.diag, v)}
@@ -482,7 +486,7 @@ def diag(v):
            dsk = {(name, 0): (np.diag, v)}
        else:
            raise ValueError("Array must be 1d or 2d only")
        return Array(dsk, name, chunks, dtype=v.dtype)
        return Array(dsk, name, chunks, meta=meta)
    if not isinstance(v, Array):
        raise TypeError("v must be a dask array or numpy array, "
                        "got {0}".format(type(v)))
@@ -491,7 +495,7 @@ def diag(v):
        dsk = {(name, i): (np.diag, row[i])
               for i, row in enumerate(v.__dask_keys__())}
        graph = HighLevelGraph.from_collections(name, dsk, dependencies=[v])
        return Array(graph, name, (v.chunks[0],), dtype=v.dtype)
        return Array(graph, name, (v.chunks[0],), meta=meta)
    else:
        raise NotImplementedError("Extracting diagonals from non-square "
                                  "chunked arrays")
@@ -505,9 +509,10 @@ def diag(v):
                dsk[key] = (np.diag, blocks[i])
            else:
                dsk[key] = (np.zeros, (m, n))
                dsk[key] = (partial(zeros_like_safe, shape=(m, n)), meta)

    graph = HighLevelGraph.from_collections(name, dsk, dependencies=[v])
    return Array(graph, name, (chunks_1d, chunks_1d), dtype=v.dtype)
    return Array(graph, name, (chunks_1d, chunks_1d), meta=meta)


@derived_from(np)
@@ -574,7 +579,8 @@ def _diag_len(dim1, dim2, offset):
    chunks = left_chunks + right_shape

    graph = HighLevelGraph.from_collections(name, dsk, dependencies=[a])
    return Array(graph, name, shape=shape, chunks=chunks, dtype=a.dtype)
    meta = meta_from_array(a, len(shape))
    return Array(graph, name, shape=shape, chunks=chunks, meta=meta)


def triu(m, k=0):
@@ -616,13 +622,15 @@ def triu(m, k=0):
    for i in range(rdim):
        for j in range(hdim):
            if chunk * (j - i + 1) < k:
                dsk[(name, i, j)] = (np.zeros, (m.chunks[0][i], m.chunks[1][j]))
                dsk[(name, i, j)] = (partial(zeros_like_safe,
                                             shape=(m.chunks[0][i], m.chunks[1][j])),
                                     m._meta)
            elif chunk * (j - i - 1) < k <= chunk * (j - i + 1):
                dsk[(name, i, j)] = (np.triu, (m.name, i, j), k - (chunk * (j - i)))
            else:
                dsk[(name, i, j)] = (m.name, i, j)
    graph = HighLevelGraph.from_collections(name, dsk, dependencies=[m])
    return Array(graph, name, shape=m.shape, chunks=m.chunks, dtype=m.dtype)
    return Array(graph, name, shape=m.shape, chunks=m.chunks, meta=m)


def tril(m, k=0):
@@ -668,9 +676,11 @@ def tril(m, k=0):
            elif chunk * (j - i - 1) < k <= chunk * (j - i + 1):
                dsk[(name, i, j)] = (np.tril, (m.name, i, j), k - (chunk * (j - i)))
            else:
                dsk[(name, i, j)] = (np.zeros, (m.chunks[0][i], m.chunks[1][j]))
                dsk[(name, i, j)] = (partial(zeros_like_safe,
                                             shape=(m.chunks[0][i], m.chunks[1][j])),
                                     m._meta)
    graph = HighLevelGraph.from_collections(name, dsk, dependencies=[m])
    return Array(graph, name, shape=m.shape, chunks=m.chunks, dtype=m.dtype)
    return Array(graph, name, shape=m.shape, chunks=m.chunks, meta=m)
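
zeros_like_safe itself is not shown in this diff. A plausible sketch of such a helper, assuming its job is to produce zeros with the backend type of the meta array while tolerating backends whose zeros_like lacks a shape= argument; this is an illustration, not the actual dask implementation:

import numpy as np

def zeros_like_safe(a, shape, **kwargs):
    # Hypothetical sketch: zeros_like lets the backend of `a` (e.g. a CuPy
    # or sparse meta array) produce the zeros; fall back to plain NumPy
    # zeros when shape= is not understood (older NumPy or backends).
    try:
        return np.zeros_like(a, shape=shape, **kwargs)
    except TypeError:
        return np.zeros(shape, dtype=a.dtype, **kwargs)

# This matches the task-tuple usage above:
#   (partial(zeros_like_safe, shape=(m, n)), meta)
# which the scheduler calls as zeros_like_safe(meta, shape=(m, n)).
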


def _np_fromfunction(func, shape, dtype, offset, func_kwargs):
4 changes: 3 additions & 1 deletion dask/array/gufunc.py
@@ -10,6 +10,7 @@
from toolz import concat, merge, unique

from .core import Array, asarray, blockwise, getitem, apply_infer_dtype
from .utils import normalize_meta
from ..highlevelgraph import HighLevelGraph
from ..core import flatten

@@ -397,11 +398,12 @@ def apply_gufunc(func, signature, *args, **kwargs):
        leaf_name = "%s_%d-%s" % (name, i, token)
        leaf_dsk = {(leaf_name,) + key[1:] + core_chunkinds: ((getitem, key, i) if nout else key) for key in keys}
        graph = HighLevelGraph.from_collections(leaf_name, leaf_dsk, dependencies=[tmp])
        meta = normalize_meta(tmp._meta, len(output_shape), dtype=odt)
        leaf_arr = Array(graph,
                         leaf_name,
                         chunks=output_chunks,
                         shape=output_shape,
                         dtype=odt)
                         meta=meta)

    ### Axes:
    if keepdims:
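
normalize_meta is likewise not part of this diff; presumably it coerces an existing meta array to a target ndim and dtype, mirroring the __getitem__ logic above. A hedged sketch of what such a helper might look like, an assumption for illustration rather than dask's actual code:

import numpy as np

def normalize_meta(meta, ndim, dtype=None):
    # Hypothetical sketch: make every axis zero-length, fix up the number
    # of dimensions, then cast to the requested dtype.
    meta = meta[tuple(slice(0, 0) for _ in range(meta.ndim))]
    if meta.ndim > ndim:
        # drop surplus leading axes, as in __getitem__ above
        meta = meta.sum(axis=tuple(range(meta.ndim - ndim)))
    elif meta.ndim < ndim:
        # add length-zero axes in front
        meta = meta.reshape((0,) * (ndim - meta.ndim) + meta.shape)
    if dtype is not None:
        meta = meta.astype(dtype)
    return meta
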
