Skip to content

Commit

Permalink
Dask.array tracks dtypes when possible
Browse files Browse the repository at this point in the history
The Array class now holds a `_dtype` attribute.  Various dask.array functions
propagate dtype information, repeating a bit of numpy logic where necessary.

If this logic fails (the dtype is unknown), we fall back on computing a single
element of the dask array and reading the dtype from the result.

Fixes dask#64
  • Loading branch information
mrocklin committed Mar 24, 2015
1 parent 88639c7 commit 475ad60
Show file tree
Hide file tree
Showing 3 changed files with 147 additions and 38 deletions.
73 changes: 57 additions & 16 deletions dask/array/core.py
Expand Up @@ -365,7 +365,7 @@ def rec_concatenate(arrays, axis=0):
return np.concatenate(arrays, axis=axis)


def map_blocks(x, func, blockshape=None, blockdims=None):
def map_blocks(x, func, blockshape=None, blockdims=None, dtype=None):
""" Map a function across all blocks of a dask array
You must also specify the blockdims/blockshape of the resulting array. If
Expand Down Expand Up @@ -411,7 +411,7 @@ def map_blocks(x, func, blockshape=None, blockdims=None):
else:
dsk = dict(((name,) + k[1:], (func, k)) for k in core.flatten(x._keys()))

return Array(merge(dsk, x.dask), name, blockdims=blockdims)
return Array(merge(dsk, x.dask), name, blockdims=blockdims, dtype=dtype)


def blockdims_from_blockshape(shape, blockshape):
Expand Down Expand Up @@ -441,16 +441,20 @@ class Array(object):
block sizes along each dimension
"""

__slots__ = 'dask', 'name', 'blockdims'
__slots__ = 'dask', 'name', 'blockdims', '_dtype'

def __init__(self, dask, name, shape=None, blockshape=None, blockdims=None):
def __init__(self, dask, name, shape=None, blockshape=None, blockdims=None,
dtype=None):
self.dask = dask
self.name = name
if blockdims is None:
blockdims = blockdims_from_blockshape(shape, blockshape)
if blockdims is None:
raise ValueError("Either give shape and blockshape or blockdims")
self.blockdims = tuple(map(tuple, blockdims))
if isinstance(dtype, (str, list)):
dtype = np.dtype(dtype)
self._dtype = dtype

@property
def numblocks(self):
Expand All @@ -465,6 +469,8 @@ def __len__(self):

@property
def dtype(self):
if self._dtype is not None:
return self._dtype
if self.shape:
return self[(0,) * self.ndim].compute().dtype
else:
Expand Down Expand Up @@ -539,7 +545,13 @@ def __getitem__(self, index):
if (isinstance(index, (str, unicode)) or
( isinstance(index, list)
and all(isinstance(i, (str, unicode)) for i in index))):
return elemwise(getitem, self, index)
if self._dtype is not None and isinstance(index, (str, unicode)):
dt = self._dtype[index]
elif self._dtype is not None and isinstance(index, list):
dt = np.dtype([(name, self._dtype[name]) for name in index])
else:
dt = None
return elemwise(getitem, self, index, dtype=dt)

# Slicing
out = next(names)
Expand All @@ -551,7 +563,8 @@ def __getitem__(self, index):

dsk, blockdims = slice_array(out, self.name, self.blockdims, index)

return Array(merge(self.dask, dsk), out, blockdims=blockdims)
return Array(merge(self.dask, dsk), out, blockdims=blockdims,
dtype=self._dtype)

@wraps(np.dot)
def dot(self, other):
Expand Down Expand Up @@ -687,9 +700,9 @@ def vnorm(self, ord=None, axis=None, keepdims=False):
return vnorm(self, ord=ord, axis=axis, keepdims=keepdims)

@wraps(map_blocks)
def map_blocks(self, func, blockshape=None, blockdims=None):
def map_blocks(self, func, blockshape=None, blockdims=None, dtype=None):
return map_blocks(self, func, blockshape=blockshape,
blockdims=blockdims)
blockdims=blockdims, dtype=dtype)


def from_array(x, blockdims=None, blockshape=None, name=None, **kwargs):
Expand All @@ -707,11 +720,12 @@ def from_array(x, blockdims=None, blockshape=None, name=None, **kwargs):
blockdims = blockdims_from_blockshape(x.shape, blockshape)
name = name or next(names)
dask = merge({name: x}, getem(name, blockdims=blockdims))
return Array(dask, name, blockdims=blockdims)
return Array(dask, name, blockdims=blockdims, dtype=x.dtype)


def atop(func, out, out_ind, *args):
def atop(func, out, out_ind, *args, **kwargs):
""" Array object version of dask.array.top """
dtype = kwargs.get('dtype', None)
arginds = list(partition(2, args)) # [x, ij, y, jk] -> [(x, ij), (y, jk)]
numblocks = dict([(a.name, a.numblocks) for a, ind in arginds])
argindsstr = list(concat([(a.name, ind) for a, ind in arginds]))
Expand All @@ -729,7 +743,8 @@ def atop(func, out, out_ind, *args):
blockdims = tuple(blockdimss[i] for i in out_ind)

dsks = [a.dask for a, _ in arginds]
return Array(merge(dsk, *dsks), out, shape, blockdims=blockdims)
return Array(merge(dsk, *dsks), out, shape, blockdims=blockdims,
dtype=dtype)


def get(dsk, keys, get=None, **kwargs):
Expand Down Expand Up @@ -842,7 +857,13 @@ def stack(seq, axis=0):

dsk = dict(zip(keys, values))
dsk2 = merge(dsk, *[a.dask for a in seq])
return Array(dsk2, name, shape, blockdims=blockdims)

if all(a._dtype is not None for a in seq):
dt = reduce(np.promote_types, [a._dtype for a in seq])
else:
dt = None

return Array(dsk2, name, shape, blockdims=blockdims, dtype=dt)


concatenate_names = ('concatenate-%d' % i for i in count(1))
Expand Down Expand Up @@ -913,15 +934,20 @@ def concatenate(seq, axis=0):
dsk = dict(zip(keys, values))
dsk2 = merge(dsk, *[a.dask for a in seq])

return Array(dsk2, name, shape, blockdims=blockdims)
if all(a._dtype is not None for a in seq):
dt = reduce(np.promote_types, [a._dtype for a in seq])
else:
dt = None

return Array(dsk2, name, shape, blockdims=blockdims, dtype=dt)


@wraps(np.transpose)
def transpose(a, axes=None):
axes = axes or tuple(range(a.ndim))[::-1]
return atop(curry(np.transpose, axes=axes),
next(names), axes,
a, tuple(range(a.ndim)))
a, tuple(range(a.ndim)), dtype=a._dtype)


@curry
Expand Down Expand Up @@ -968,12 +994,17 @@ def tensordot(lhs, rhs, axes=2):
out_index.remove(left_index[l])
right_index[r] = left_index[l]

if lhs._dtype is not None and rhs._dtype is not None :
dt = np.promote_types(lhs._dtype, rhs._dtype)
else:
dt = None

func = many(binop=np.tensordot, reduction=sum,
axes=(left_axes, right_axes))
return atop(func,
next(names), out_index,
lhs, tuple(left_index),
rhs, tuple(right_index))
rhs, tuple(right_index), dtype=dt)


def insert_to_ooc(out, arr):
Expand Down Expand Up @@ -1027,13 +1058,23 @@ def elemwise(op, *args, **kwargs):
arrays = [arg for arg in args if isinstance(arg, Array)]
other = [(i, arg) for i, arg in enumerate(args) if not isinstance(arg, Array)]

if not all(a._dtype is not None for a in arrays):
dt = None
elif all(hasattr(a, 'dtype') for a in args): # Just numpy like things
dt = reduce(np.promote_types, [a.dtype for a in args])
else: # crap, value dependent
vals = [np.empty((1,), dtype=a.dtype) if hasattr(a, 'dtype') else a
for a in args]
dt = op(*vals).dtype

if other:
op2 = partial_by_order(op, other)
else:
op2 = op

return atop(op2, name, expr_inds,
*concat((a, tuple(range(a.ndim)[::-1])) for a in arrays))
*concat((a, tuple(range(a.ndim)[::-1])) for a in arrays),
dtype=dt)


def wrap_elemwise(func):
Expand Down
74 changes: 52 additions & 22 deletions dask/array/reductions.py
Expand Up @@ -11,7 +11,7 @@
from ..utils import ignoring


def reduction(x, chunk, aggregate, axis=None, keepdims=None):
def reduction(x, chunk, aggregate, axis=None, keepdims=None, dtype=None):
""" General version of reductions
>>> reduction(my_array, np.sum, np.sum, axis=0, keepdims=False) # doctest: +SKIP
Expand All @@ -30,88 +30,116 @@ def reduction(x, chunk, aggregate, axis=None, keepdims=None):
inds2 = tuple(i for i in inds if i not in axis)

result = atop(compose(aggregate2, curry(_concatenate2, axes=axis)),
next(names), inds2, tmp, inds)
next(names), inds2, tmp, inds, dtype=dtype)

if keepdims:
dsk = result.dask.copy()
for k in flatten(result._keys()):
k2 = (k[0],) + insert_many(k[1:], axis, 0)
dsk[k2] = dsk.pop(k)
blockdims = insert_many(result.blockdims, axis, [1])
return Array(dsk, result.name, blockdims=blockdims)
return Array(dsk, result.name, blockdims=blockdims, dtype=dtype)
else:
return result


@wraps(chunk.sum)
def sum(a, axis=None, keepdims=False):
return reduction(a, chunk.sum, chunk.sum, axis=axis, keepdims=keepdims)
if a._dtype is not None:
dt = np.empty((1,), dtype=a._dtype).sum().dtype
else:
dt = None
return reduction(a, chunk.sum, chunk.sum, axis=axis, keepdims=keepdims,
dtype=dt)


@wraps(chunk.prod)
def prod(a, axis=None, keepdims=False):
return reduction(a, chunk.prod, chunk.prod, axis=axis, keepdims=keepdims)
if a._dtype is not None:
dt = np.empty((1,), dtype=a._dtype).prod().dtype
else:
dt = None
return reduction(a, chunk.prod, chunk.prod, axis=axis, keepdims=keepdims,
dtype=dt)


@wraps(chunk.min)
def min(a, axis=None, keepdims=False):
return reduction(a, chunk.min, chunk.min, axis=axis, keepdims=keepdims)
return reduction(a, chunk.min, chunk.min, axis=axis, keepdims=keepdims,
dtype=a._dtype)


@wraps(chunk.max)
def max(a, axis=None, keepdims=False):
return reduction(a, chunk.max, chunk.max, axis=axis, keepdims=keepdims)
return reduction(a, chunk.max, chunk.max, axis=axis, keepdims=keepdims,
dtype=a._dtype)


@wraps(chunk.argmin)
def argmin(a, axis=None):
return arg_reduction(a, chunk.min, chunk.argmin, axis=axis)
return arg_reduction(a, chunk.min, chunk.argmin, axis=axis, dtype='i8')


@wraps(chunk.nanargmin)
def nanargmin(a, axis=None):
return arg_reduction(a, chunk.nanmin, chunk.nanargmin, axis=axis)
return arg_reduction(a, chunk.nanmin, chunk.nanargmin, axis=axis,
dtype='i8')


@wraps(chunk.argmax)
def argmax(a, axis=None):
return arg_reduction(a, chunk.max, chunk.argmax, axis=axis)
return arg_reduction(a, chunk.max, chunk.argmax, axis=axis, dtype='i8')


@wraps(chunk.nanargmax)
def nanargmax(a, axis=None):
return arg_reduction(a, chunk.nanmax, chunk.nanargmax, axis=axis)
return arg_reduction(a, chunk.nanmax, chunk.nanargmax, axis=axis,
dtype='i8')


@wraps(chunk.any)
def any(a, axis=None, keepdims=False):
return reduction(a, chunk.any, chunk.any, axis=axis, keepdims=keepdims)
return reduction(a, chunk.any, chunk.any, axis=axis, keepdims=keepdims,
dtype=np.bool_)


@wraps(chunk.all)
def all(a, axis=None, keepdims=False):
return reduction(a, chunk.all, chunk.all, axis=axis, keepdims=keepdims)
return reduction(a, chunk.all, chunk.all, axis=axis, keepdims=keepdims,
dtype=np.bool_)


@wraps(chunk.nansum)
def nansum(a, axis=None, keepdims=False):
return reduction(a, chunk.nansum, chunk.sum, axis=axis, keepdims=keepdims)
if a._dtype is not None:
dt = chunk.nansum(np.empty((1,), dtype=a._dtype)).dtype
else:
dt = None
return reduction(a, chunk.nansum, chunk.sum, axis=axis, keepdims=keepdims,
dtype=dt)


with ignoring(AttributeError):
    # ``chunk.nanprod`` is looked up when the decorator below is evaluated.
    # If that lookup raises AttributeError (presumably because the installed
    # NumPy does not provide nanprod -- confirm against the chunk module),
    # the error is swallowed and ``nanprod`` is simply not defined.
    @wraps(chunk.nanprod)
    def nanprod(a, axis=None, keepdims=False):
        # Infer the result dtype without touching real data: apply the
        # chunk-level function to a one-element dummy array of the input
        # dtype.  ``nanprod`` exists only as a function, never as an
        # ndarray method, so it must be called as ``chunk.nanprod(arr)``
        # (the same pattern ``nansum`` uses) rather than ``arr.nanprod()``,
        # which would raise AttributeError on every call.
        if a._dtype is not None:
            dt = chunk.nanprod(np.empty((1,), dtype=a._dtype)).dtype
        else:
            # Unknown input dtype: leave the output dtype unknown as well.
            dt = None
        return reduction(a, chunk.nanprod, chunk.prod, axis=axis,
                         keepdims=keepdims, dtype=dt)


@wraps(chunk.nanmin)
def nanmin(a, axis=None, keepdims=False):
return reduction(a, chunk.nanmin, chunk.min, axis=axis, keepdims=keepdims)
return reduction(a, chunk.nanmin, chunk.min, axis=axis, keepdims=keepdims,
dtype=a._dtype)


@wraps(chunk.nanmax)
def nanmax(a, axis=None, keepdims=False):
return reduction(a, chunk.nanmax, chunk.max, axis=axis, keepdims=keepdims)
return reduction(a, chunk.nanmax, chunk.max, axis=axis, keepdims=keepdims,
dtype=a._dtype)


def numel(x, **kwargs):
Expand Down Expand Up @@ -139,12 +167,13 @@ def mean_agg(pair, **kwargs):

@wraps(chunk.mean)
def mean(a, axis=None, keepdims=False):
return reduction(a, mean_chunk, mean_agg, axis=axis, keepdims=keepdims)
return reduction(a, mean_chunk, mean_agg, axis=axis, keepdims=keepdims,
dtype='f8')


def nanmean(a, axis=None, keepdims=False):
return reduction(a, partial(mean_chunk, sum=chunk.nansum, numel=nannumel),
mean_agg, axis=axis, keepdims=keepdims)
mean_agg, axis=axis, keepdims=keepdims, dtype='f8')
with ignoring(AttributeError):
nanmean = wraps(chunk.nanmean)(nanmean)

Expand Down Expand Up @@ -172,7 +201,8 @@ def var_agg(A, ddof=None, **kwargs):

@wraps(chunk.var)
def var(a, axis=None, keepdims=False, ddof=0):
return reduction(a, var_chunk, partial(var_agg, ddof=ddof), axis=axis, keepdims=keepdims)
return reduction(a, var_chunk, partial(var_agg, ddof=ddof), axis=axis,
keepdims=keepdims, dtype='f8')


def nanvar(a, axis=None, keepdims=False, ddof=0):
Expand Down Expand Up @@ -230,7 +260,7 @@ def arg_aggregate(func, argfunc, dims, pairs):
return np.choose(args, argmins + offsets)


def arg_reduction(a, func, argfunc, axis=0):
def arg_reduction(a, func, argfunc, axis=0, dtype=None):
""" General version of argmin/argmax
>>> arg_reduction(my_array, np.min, axis=0) # doctest: +SKIP
Expand All @@ -249,4 +279,4 @@ def argreduce(x):

return atop(partial(arg_aggregate, func, argfunc, a.blockdims[axis]),
next(names), [i for i in range(a.ndim) if i != axis],
a2, list(range(a.ndim)))
a2, list(range(a.ndim)), dtype=dtype)

0 comments on commit 475ad60

Please sign in to comment.