Dask.array tracks dtypes when possible #87

Merged (4 commits, Mar 24, 2015)
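As a quick orientation before the diff: the point of the change is that an `Array` carries its dtype through construction and through operations, so `.dtype` no longer has to compute a block in the common case. A minimal sketch of the behaviour, assuming the API as it appears in this PR (the `blockshape=` keyword of that era); illustrative only:

```python
import numpy as np
import dask.array as da  # API as of this PR: from_array takes blockshape=

x = da.from_array(np.arange(6, dtype='i4'), blockshape=(3,))
x.dtype        # dtype('int32') -- recorded at construction, nothing computed
y = x + 1.5    # elemwise infers the result dtype (here via a tiny dummy evaluation)
y.dtype        # dtype('float64') -- still nothing computed
```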
77 changes: 60 additions & 17 deletions dask/array/core.py
@@ -11,7 +11,7 @@
from functools import partial, wraps
from toolz.curried import (identity, pipe, partition, concat, unique, pluck,
frequencies, join, first, memoize, map, groupby, valmap, accumulate,
merge, curry, compose)
merge, curry, compose, reduce)
import numpy as np
from . import chunk
from .slicing import slice_array, insert_many, remove_full_slices
@@ -365,7 +365,7 @@ def rec_concatenate(arrays, axis=0):
return np.concatenate(arrays, axis=axis)


def map_blocks(x, func, blockshape=None, blockdims=None):
def map_blocks(x, func, blockshape=None, blockdims=None, dtype=None):
""" Map a function across all blocks of a dask array

You must also specify the blockdims/blockshape of the resulting array. If
@@ -411,7 +411,7 @@ def map_blocks(x, func, blockshape=None, blockdims=None):
else:
dsk = dict(((name,) + k[1:], (func, k)) for k in core.flatten(x._keys()))

return Array(merge(dsk, x.dask), name, blockdims=blockdims)
return Array(merge(dsk, x.dask), name, blockdims=blockdims, dtype=dtype)


def blockdims_from_blockshape(shape, blockshape):
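`map_blocks` gains a `dtype=` keyword so that, when the mapped function changes the element type, the caller can declare the result dtype up front instead of leaving it unknown. A small sketch under the same assumptions as above (signatures as shown in this hunk):

```python
import numpy as np
import dask.array as da  # assuming the from_array/map_blocks signatures in this diff

x = da.from_array(np.arange(10, dtype='i8'), blockshape=(5,))
y = x.map_blocks(lambda block: block.astype('f4'), dtype='f4')
y.dtype   # dtype('float32'), known without evaluating any block
```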
@@ -441,16 +441,20 @@ class Array(object):
block sizes along each dimension
"""

__slots__ = 'dask', 'name', 'blockdims'
__slots__ = 'dask', 'name', 'blockdims', '_dtype'

def __init__(self, dask, name, shape=None, blockshape=None, blockdims=None):
def __init__(self, dask, name, shape=None, blockshape=None, blockdims=None,
dtype=None):
self.dask = dask
self.name = name
if blockdims is None:
blockdims = blockdims_from_blockshape(shape, blockshape)
if blockdims is None:
raise ValueError("Either give shape and blockshape or blockdims")
self.blockdims = tuple(map(tuple, blockdims))
if isinstance(dtype, (str, list)):
dtype = np.dtype(dtype)
self._dtype = dtype

@property
def numblocks(self):
@@ -464,7 +468,10 @@ def __len__(self):
return sum(self.blockdims[0])

@property
@memoize(key=lambda args, kwargs: (id(args[0]), args[0].name, args[0].blockdims))
def dtype(self):
if self._dtype is not None:
return self._dtype
if self.shape:
return self[(0,) * self.ndim].compute().dtype
Review comment (Member): Can you cache the dtype if you need to calculate it?

Reply (Member Author): Fixed. Note though that I think we've covered most of the operations. I haven't yet done ones, zeros, and da.random, but I think that this PR covers most things.
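For reference, a minimal sketch of the caching being discussed, with the block peek stubbed out; this is an illustration of the idea, not the code that was merged:

```python
import numpy as np

class CachedDtypeExample(object):
    """Toy stand-in for Array: compute the dtype once, then remember it."""

    def __init__(self, dtype=None):
        self._dtype = dtype

    def _dtype_from_first_block(self):
        # Stand-in for self[(0,) * self.ndim].compute().dtype in the real Array.
        return np.dtype('f8')

    @property
    def dtype(self):
        if self._dtype is None:
            # First access pays for one block; later accesses are free.
            self._dtype = self._dtype_from_first_block()
        return self._dtype
```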

else:
@@ -539,7 +546,13 @@ def __getitem__(self, index):
if (isinstance(index, (str, unicode)) or
( isinstance(index, list)
and all(isinstance(i, (str, unicode)) for i in index))):
return elemwise(getitem, self, index)
if self._dtype is not None and isinstance(index, (str, unicode)):
dt = self._dtype[index]
elif self._dtype is not None and isinstance(index, list):
dt = np.dtype([(name, self._dtype[name]) for name in index])
else:
dt = None
return elemwise(getitem, self, index, dtype=dt)

# Slicing
out = next(names)
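The field-selection branch above reads the result dtype straight off the structured dtype, either one field or a sub-record built from a list of field names. The mapping in plain NumPy:

```python
import numpy as np

rec = np.dtype([('x', 'f8'), ('y', 'i4'), ('label', 'S8')])
rec['x']                                               # dtype('float64') -- one field
np.dtype([(name, rec[name]) for name in ['x', 'y']])   # record dtype restricted to x, y
```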
@@ -551,7 +564,8 @@

dsk, blockdims = slice_array(out, self.name, self.blockdims, index)

return Array(merge(self.dask, dsk), out, blockdims=blockdims)
return Array(merge(self.dask, dsk), out, blockdims=blockdims,
dtype=self._dtype)

@wraps(np.dot)
def dot(self, other):
@@ -687,9 +701,9 @@ def vnorm(self, ord=None, axis=None, keepdims=False):
return vnorm(self, ord=ord, axis=axis, keepdims=keepdims)

@wraps(map_blocks)
def map_blocks(self, func, blockshape=None, blockdims=None):
def map_blocks(self, func, blockshape=None, blockdims=None, dtype=None):
return map_blocks(self, func, blockshape=blockshape,
blockdims=blockdims)
blockdims=blockdims, dtype=dtype)


def from_array(x, blockdims=None, blockshape=None, name=None, **kwargs):
@@ -707,11 +721,12 @@ def from_array(x, blockdims=None, blockshape=None, name=None, **kwargs):
blockdims = blockdims_from_blockshape(x.shape, blockshape)
name = name or next(names)
dask = merge({name: x}, getem(name, blockdims=blockdims))
return Array(dask, name, blockdims=blockdims)
return Array(dask, name, blockdims=blockdims, dtype=x.dtype)


def atop(func, out, out_ind, *args):
def atop(func, out, out_ind, *args, **kwargs):
""" Array object version of dask.array.top """
dtype = kwargs.get('dtype', None)
arginds = list(partition(2, args)) # [x, ij, y, jk] -> [(x, ij), (y, jk)]
numblocks = dict([(a.name, a.numblocks) for a, ind in arginds])
argindsstr = list(concat([(a.name, ind) for a, ind in arginds]))
@@ -729,7 +744,8 @@ def atop(func, out, out_ind, *args):
blockdims = tuple(blockdimss[i] for i in out_ind)

dsks = [a.dask for a, _ in arginds]
return Array(merge(dsk, *dsks), out, shape, blockdims=blockdims)
return Array(merge(dsk, *dsks), out, shape, blockdims=blockdims,
dtype=dtype)


def get(dsk, keys, get=None, **kwargs):
@@ -842,7 +858,13 @@ def stack(seq, axis=0):

dsk = dict(zip(keys, values))
dsk2 = merge(dsk, *[a.dask for a in seq])
return Array(dsk2, name, shape, blockdims=blockdims)

if all(a._dtype is not None for a in seq):
dt = reduce(np.promote_types, [a._dtype for a in seq])
else:
dt = None

return Array(dsk2, name, shape, blockdims=blockdims, dtype=dt)


concatenate_names = ('concatenate-%d' % i for i in count(1))
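`stack` (above) and `concatenate` (next hunk) only report a dtype when every input has one, and then fold the inputs' dtypes together with NumPy's promotion rules. Standalone, with `functools.reduce` standing in for the toolz `reduce` imported in core.py:

```python
import numpy as np
from functools import reduce

dtypes = [np.dtype('i2'), np.dtype('f4'), np.dtype('i8')]
reduce(np.promote_types, dtypes)   # dtype('float64')
```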
@@ -913,15 +935,20 @@ def concatenate(seq, axis=0):
dsk = dict(zip(keys, values))
dsk2 = merge(dsk, *[a.dask for a in seq])

return Array(dsk2, name, shape, blockdims=blockdims)
if all(a._dtype is not None for a in seq):
dt = reduce(np.promote_types, [a._dtype for a in seq])
else:
dt = None

return Array(dsk2, name, shape, blockdims=blockdims, dtype=dt)


@wraps(np.transpose)
def transpose(a, axes=None):
axes = axes or tuple(range(a.ndim))[::-1]
return atop(curry(np.transpose, axes=axes),
next(names), axes,
a, tuple(range(a.ndim)))
a, tuple(range(a.ndim)), dtype=a._dtype)


@curry
@@ -968,12 +995,17 @@ def tensordot(lhs, rhs, axes=2):
out_index.remove(left_index[l])
right_index[r] = left_index[l]

if lhs._dtype is not None and rhs._dtype is not None:
dt = np.promote_types(lhs._dtype, rhs._dtype)
else:
dt = None

func = many(binop=np.tensordot, reduction=sum,
axes=(left_axes, right_axes))
return atop(func,
next(names), out_index,
lhs, tuple(left_index),
rhs, tuple(right_index))
rhs, tuple(right_index), dtype=dt)


def insert_to_ooc(out, arr):
@@ -1027,13 +1059,23 @@ def elemwise(op, *args, **kwargs):
arrays = [arg for arg in args if isinstance(arg, Array)]
other = [(i, arg) for i, arg in enumerate(args) if not isinstance(arg, Array)]

if not all(a._dtype is not None for a in arrays):
dt = None
elif all(hasattr(a, 'dtype') for a in args): # Just numpy like things
dt = reduce(np.promote_types, [a.dtype for a in args])
else: # fall back: result dtype is value dependent
vals = [np.empty((1,), dtype=a.dtype) if hasattr(a, 'dtype') else a
for a in args]
dt = op(*vals).dtype

if other:
op2 = partial_by_order(op, other)
else:
op2 = op

return atop(op2, name, expr_inds,
*concat((a, tuple(range(a.ndim)[::-1])) for a in arrays))
*concat((a, tuple(range(a.ndim)[::-1])) for a in arrays),
dtype=dt)


def wrap_elemwise(func):
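`elemwise` infers the output dtype along two paths: plain promotion when every argument exposes a dtype, and a one-element dummy evaluation when some argument (a Python scalar, say) does not. A standalone sketch of both paths; `op` and `args` are illustrative:

```python
import numpy as np
from functools import reduce

op = np.add
args = [np.ones(3, dtype='i4'), 2.5]   # an array plus a plain Python float

if all(hasattr(a, 'dtype') for a in args):
    dt = reduce(np.promote_types, [a.dtype for a in args])
else:
    # Value-dependent fallback: run the op on 1-element stand-ins and inspect the result.
    vals = [np.empty((1,), dtype=a.dtype) if hasattr(a, 'dtype') else a
            for a in args]
    dt = op(*vals).dtype               # dtype('float64')
```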
@@ -1085,6 +1127,7 @@ def isnull(values):
import pandas as pd
return elemwise(pd.isnull, values)


def notnull(values):
""" pandas.notnull for dask arrays """
return ~isnull(values)
74 changes: 52 additions & 22 deletions dask/array/reductions.py
@@ -11,7 +11,7 @@
from ..utils import ignoring


def reduction(x, chunk, aggregate, axis=None, keepdims=None):
def reduction(x, chunk, aggregate, axis=None, keepdims=None, dtype=None):
""" General version of reductions

>>> reduction(my_array, np.sum, np.sum, axis=0, keepdims=False) # doctest: +SKIP
@@ -30,88 +30,116 @@ def reduction(x, chunk, aggregate, axis=None, keepdims=None):
inds2 = tuple(i for i in inds if i not in axis)

result = atop(compose(aggregate2, curry(_concatenate2, axes=axis)),
next(names), inds2, tmp, inds)
next(names), inds2, tmp, inds, dtype=dtype)

if keepdims:
dsk = result.dask.copy()
for k in flatten(result._keys()):
k2 = (k[0],) + insert_many(k[1:], axis, 0)
dsk[k2] = dsk.pop(k)
blockdims = insert_many(result.blockdims, axis, [1])
return Array(dsk, result.name, blockdims=blockdims)
return Array(dsk, result.name, blockdims=blockdims, dtype=dtype)
else:
return result


@wraps(chunk.sum)
def sum(a, axis=None, keepdims=False):
return reduction(a, chunk.sum, chunk.sum, axis=axis, keepdims=keepdims)
if a._dtype is not None:
dt = np.empty((1,), dtype=a._dtype).sum().dtype
else:
dt = None
return reduction(a, chunk.sum, chunk.sum, axis=axis, keepdims=keepdims,
dtype=dt)


@wraps(chunk.prod)
def prod(a, axis=None, keepdims=False):
return reduction(a, chunk.prod, chunk.prod, axis=axis, keepdims=keepdims)
if a._dtype is not None:
dt = np.empty((1,), dtype=a._dtype).prod().dtype
else:
dt = None
return reduction(a, chunk.prod, chunk.prod, axis=axis, keepdims=keepdims,
dtype=dt)
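`sum` and `prod` cannot simply reuse the input dtype because NumPy widens small integer and boolean inputs when accumulating, so the code probes a one-element array and asks NumPy what it would return. For example (results for a typical 64-bit build, where the default integer is int64):

```python
import numpy as np

np.empty((1,), dtype='i1').sum().dtype    # dtype('int64')   -- int8 is widened
np.empty((1,), dtype='f4').sum().dtype    # dtype('float32') -- floats keep their width
np.empty((1,), dtype=bool).prod().dtype   # dtype('int64')   -- bools accumulate as ints
```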


@wraps(chunk.min)
def min(a, axis=None, keepdims=False):
return reduction(a, chunk.min, chunk.min, axis=axis, keepdims=keepdims)
return reduction(a, chunk.min, chunk.min, axis=axis, keepdims=keepdims,
dtype=a._dtype)


@wraps(chunk.max)
def max(a, axis=None, keepdims=False):
return reduction(a, chunk.max, chunk.max, axis=axis, keepdims=keepdims)
return reduction(a, chunk.max, chunk.max, axis=axis, keepdims=keepdims,
dtype=a._dtype)


@wraps(chunk.argmin)
def argmin(a, axis=None):
return arg_reduction(a, chunk.min, chunk.argmin, axis=axis)
return arg_reduction(a, chunk.min, chunk.argmin, axis=axis, dtype='i8')


@wraps(chunk.nanargmin)
def nanargmin(a, axis=None):
return arg_reduction(a, chunk.nanmin, chunk.nanargmin, axis=axis)
return arg_reduction(a, chunk.nanmin, chunk.nanargmin, axis=axis,
dtype='i8')


@wraps(chunk.argmax)
def argmax(a, axis=None):
return arg_reduction(a, chunk.max, chunk.argmax, axis=axis)
return arg_reduction(a, chunk.max, chunk.argmax, axis=axis, dtype='i8')


@wraps(chunk.nanargmax)
def nanargmax(a, axis=None):
return arg_reduction(a, chunk.nanmax, chunk.nanargmax, axis=axis)
return arg_reduction(a, chunk.nanmax, chunk.nanargmax, axis=axis,
dtype='i8')


@wraps(chunk.any)
def any(a, axis=None, keepdims=False):
return reduction(a, chunk.any, chunk.any, axis=axis, keepdims=keepdims)
return reduction(a, chunk.any, chunk.any, axis=axis, keepdims=keepdims,
dtype=np.bool_)


@wraps(chunk.all)
def all(a, axis=None, keepdims=False):
return reduction(a, chunk.all, chunk.all, axis=axis, keepdims=keepdims)
return reduction(a, chunk.all, chunk.all, axis=axis, keepdims=keepdims,
dtype=np.bool_)


@wraps(chunk.nansum)
def nansum(a, axis=None, keepdims=False):
return reduction(a, chunk.nansum, chunk.sum, axis=axis, keepdims=keepdims)
if a._dtype is not None:
dt = chunk.nansum(np.empty((1,), dtype=a._dtype)).dtype
else:
dt = None
return reduction(a, chunk.nansum, chunk.sum, axis=axis, keepdims=keepdims,
dtype=dt)


with ignoring(AttributeError):
@wraps(chunk.nanprod)
def nanprod(a, axis=None, keepdims=False):
return reduction(a, chunk.nanprod, chunk.prod, axis=axis, keepdims=keepdims)
if a._dtype is not None:
dt = chunk.nanprod(np.empty((1,), dtype=a._dtype)).dtype
else:
dt = None
return reduction(a, chunk.nanprod, chunk.prod, axis=axis,
keepdims=keepdims, dtype=dt)


@wraps(chunk.nanmin)
def nanmin(a, axis=None, keepdims=False):
return reduction(a, chunk.nanmin, chunk.min, axis=axis, keepdims=keepdims)
return reduction(a, chunk.nanmin, chunk.min, axis=axis, keepdims=keepdims,
dtype=a._dtype)


@wraps(chunk.nanmax)
def nanmax(a, axis=None, keepdims=False):
return reduction(a, chunk.nanmax, chunk.max, axis=axis, keepdims=keepdims)
return reduction(a, chunk.nanmax, chunk.max, axis=axis, keepdims=keepdims,
dtype=a._dtype)


def numel(x, **kwargs):
@@ -139,12 +167,13 @@ def mean_agg(pair, **kwargs):

@wraps(chunk.mean)
def mean(a, axis=None, keepdims=False):
return reduction(a, mean_chunk, mean_agg, axis=axis, keepdims=keepdims)
return reduction(a, mean_chunk, mean_agg, axis=axis, keepdims=keepdims,
dtype='f8')


def nanmean(a, axis=None, keepdims=False):
return reduction(a, partial(mean_chunk, sum=chunk.nansum, numel=nannumel),
mean_agg, axis=axis, keepdims=keepdims)
mean_agg, axis=axis, keepdims=keepdims, dtype='f8')
with ignoring(AttributeError):
nanmean = wraps(chunk.nanmean)(nanmean)
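`mean` and the variance reductions below hard-code `'f8'` because NumPy returns a floating-point result even for integer input:

```python
import numpy as np

np.arange(10, dtype='i4').mean().dtype   # dtype('float64')
np.arange(10, dtype='i4').var().dtype    # dtype('float64')
```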

Expand Down Expand Up @@ -172,7 +201,8 @@ def var_agg(A, ddof=None, **kwargs):

@wraps(chunk.var)
def var(a, axis=None, keepdims=False, ddof=0):
return reduction(a, var_chunk, partial(var_agg, ddof=ddof), axis=axis, keepdims=keepdims)
return reduction(a, var_chunk, partial(var_agg, ddof=ddof), axis=axis,
keepdims=keepdims, dtype='f8')


def nanvar(a, axis=None, keepdims=False, ddof=0):
@@ -230,7 +260,7 @@ def arg_aggregate(func, argfunc, dims, pairs):
return np.choose(args, argmins + offsets)


def arg_reduction(a, func, argfunc, axis=0):
def arg_reduction(a, func, argfunc, axis=0, dtype=None):
""" General version of argmin/argmax

>>> arg_reduction(my_array, np.min, axis=0) # doctest: +SKIP
@@ -249,4 +279,4 @@ def argreduce(x):

return atop(partial(arg_aggregate, func, argfunc, a.blockdims[axis]),
next(names), [i for i in range(a.ndim) if i != axis],
a2, list(range(a.ndim)))
a2, list(range(a.ndim)), dtype=dtype)