Skip to content

Commit

Permalink
Reductions support dtype= keyword argument
Browse files Browse the repository at this point in the history
This follows the NumPy API to enable users to specify the output dtype
of reduction operations.

Example
-------

In [1]: import dask.array as da

In [2]: x = da.ones((4, 4), chunks=2)

In [3]: x.sum(axis=1, dtype='i1')
Out[3]: dask.array<x_2, shape=(4,), chunks=((2, 2)), dtype=int8>

In [4]: x.sum(axis=1, dtype='i1').compute()
Out[4]: array([4, 4, 4, 4], dtype=int8)

Fixes dask#270
  • Loading branch information
mrocklin committed Jun 4, 2015
1 parent a16f838 commit 0023d94
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 38 deletions.
20 changes: 10 additions & 10 deletions dask/array/core.py
Expand Up @@ -872,29 +872,29 @@ def argmax(self, axis=None):
return argmax(self, axis=axis)

@wraps(np.sum)
def sum(self, axis=None, keepdims=False):
def sum(self, axis=None, dtype=None, keepdims=False):
from .reductions import sum
return sum(self, axis=axis, keepdims=keepdims)
return sum(self, axis=axis, dtype=dtype, keepdims=keepdims)

@wraps(np.prod)
def prod(self, axis=None, keepdims=False):
def prod(self, axis=None, dtype=None, keepdims=False):
from .reductions import prod
return prod(self, axis=axis, keepdims=keepdims)
return prod(self, axis=axis, dtype=dtype, keepdims=keepdims)

@wraps(np.mean)
def mean(self, axis=None, keepdims=False):
def mean(self, axis=None, dtype=None, keepdims=False):
from .reductions import mean
return mean(self, axis=axis, keepdims=keepdims)
return mean(self, axis=axis, dtype=dtype, keepdims=keepdims)

@wraps(np.std)
def std(self, axis=None, keepdims=False, ddof=0):
def std(self, axis=None, dtype=None, keepdims=False, ddof=0):
from .reductions import std
return std(self, axis=axis, keepdims=keepdims, ddof=ddof)
return std(self, axis=axis, dtype=dtype, keepdims=keepdims, ddof=ddof)

@wraps(np.var)
def var(self, axis=None, keepdims=False, ddof=0):
def var(self, axis=None, dtype=None, keepdims=False, ddof=0):
from .reductions import var
return var(self, axis=axis, keepdims=keepdims, ddof=ddof)
return var(self, axis=axis, dtype=dtype, keepdims=keepdims, ddof=ddof)

def vnorm(self, ord=None, axis=None, keepdims=False):
""" Vector norm """
Expand Down
81 changes: 53 additions & 28 deletions dask/array/reductions.py
Expand Up @@ -3,6 +3,7 @@
import numpy as np
from functools import partial, wraps
from toolz import compose, curry
import inspect

from .core import _concatenate2, Array, atop, names, sqrt, elemwise
from .slicing import insert_many
Expand All @@ -21,6 +22,11 @@ def reduction(x, chunk, aggregate, axis=None, keepdims=None, dtype=None):
if isinstance(axis, int):
axis = (axis,)

if dtype and 'dtype' in inspect.getargspec(chunk).args:
chunk = partial(chunk, dtype=dtype)
if dtype and 'dtype' in inspect.getargspec(aggregate).args:
aggregate = partial(aggregate, dtype=dtype)

chunk2 = partial(chunk, axis=axis, keepdims=True)
aggregate2 = partial(aggregate, axis=axis, keepdims=keepdims)

Expand All @@ -44,8 +50,10 @@ def reduction(x, chunk, aggregate, axis=None, keepdims=None, dtype=None):


@wraps(chunk.sum)
def sum(a, axis=None, keepdims=False):
if a._dtype is not None:
def sum(a, axis=None, dtype=None, keepdims=False):
if dtype is not None:
dt = dtype
elif a._dtype is not None:
dt = np.empty((1,), dtype=a._dtype).sum().dtype
else:
dt = None
Expand All @@ -54,8 +62,10 @@ def sum(a, axis=None, keepdims=False):


@wraps(chunk.prod)
def prod(a, axis=None, keepdims=False):
if a._dtype is not None:
def prod(a, axis=None, dtype=None, keepdims=False):
if dtype is not None:
dt = dtype
elif a._dtype is not None:
dt = np.empty((1,), dtype=a._dtype).prod().dtype
else:
dt = None
Expand Down Expand Up @@ -110,8 +120,10 @@ def all(a, axis=None, keepdims=False):


@wraps(chunk.nansum)
def nansum(a, axis=None, keepdims=False):
if a._dtype is not None:
def nansum(a, axis=None, dtype=None, keepdims=False):
if dtype is not None:
dt = dtype
elif a._dtype is not None:
dt = chunk.nansum(np.empty((1,), dtype=a._dtype)).dtype
else:
dt = None
Expand All @@ -121,8 +133,10 @@ def nansum(a, axis=None, keepdims=False):

with ignoring(AttributeError):
@wraps(chunk.nanprod)
def nanprod(a, axis=None, keepdims=False):
if a._dtype is not None:
def nanprod(a, axis=None, dtype=None, keepdims=False):
if dtype is not None:
dt = dtype
elif a._dtype is not None:
dt = np.empty((1,), dtype=a._dtype).nanprod().dtype
else:
dt = None
Expand Down Expand Up @@ -166,30 +180,35 @@ def mean_agg(pair, **kwargs):


@wraps(chunk.mean)
def mean(a, axis=None, keepdims=False):
if a._dtype is not None:
def mean(a, axis=None, dtype=None, keepdims=False):
if dtype is not None:
dt = dtype
elif a._dtype is not None:
dt = np.mean(np.empty(shape=(1,), dtype=a._dtype)).dtype
else:
dt = None
return reduction(a, mean_chunk, mean_agg, axis=axis, keepdims=keepdims,
dtype=dt)


def nanmean(a, axis=None, keepdims=False):
if a._dtype is not None:
def nanmean(a, axis=None, dtype=None, keepdims=False):
if dtype is not None:
dt = dtype
elif a._dtype is not None:
dt = np.mean(np.empty(shape=(1,), dtype=a._dtype)).dtype
else:
dt = None
return reduction(a, partial(mean_chunk, sum=chunk.nansum, numel=nannumel),
mean_agg, axis=axis, keepdims=keepdims, dtype=dt)

with ignoring(AttributeError):
nanmean = wraps(chunk.nanmean)(nanmean)


def var_chunk(A, sum=chunk.sum, numel=numel, **kwargs):
def var_chunk(A, sum=chunk.sum, numel=numel, dtype='f8', **kwargs):
n = numel(A, **kwargs)
x = sum(A, dtype='f8', **kwargs)
x2 = sum(A**2, dtype='f8', **kwargs)
x = sum(A, dtype=dtype, **kwargs)
x2 = sum(A**2, dtype=dtype, **kwargs)
result = np.empty(shape=n.shape, dtype=[('x', x.dtype),
('x2', x2.dtype),
('n', n.dtype)])
Expand All @@ -209,38 +228,44 @@ def var_agg(A, ddof=None, **kwargs):


@wraps(chunk.var)
def var(a, axis=None, keepdims=False, ddof=0):
def var(a, axis=None, dtype=None, keepdims=False, ddof=0):
if dtype is not None:
dt = dtype
if a._dtype is not None:
dt = np.var(np.empty(shape=(1,), dtype=a._dtype)).dtype
dt = np.var(np.ones(shape=(1,), dtype=a._dtype)).dtype
else:
dt = None
return reduction(a, var_chunk, partial(var_agg, ddof=ddof), axis=axis,
keepdims=keepdims, dtype=dt)


def nanvar(a, axis=None, keepdims=False, ddof=0):
if a._dtype is not None:
dt = np.var(np.empty(shape=(1,), dtype=a._dtype)).dtype
def nanvar(a, axis=None, dtype=None, keepdims=False, ddof=0):
if dtype is not None:
dt = dtype
elif a._dtype is not None:
dt = np.var(np.ones(shape=(1,), dtype=a._dtype)).dtype
else:
dt = None
return reduction(a, partial(var_chunk, sum=chunk.nansum, numel=nannumel),
partial(var_agg, ddof=ddof), axis=axis, keepdims=keepdims,
dtype=dt)

with ignoring(AttributeError):
nanvar = wraps(chunk.nanvar)(nanvar)

@wraps(chunk.std)
def std(a, axis=None, keepdims=False, ddof=0):
return sqrt(a.var(axis=axis, keepdims=keepdims, ddof=ddof))
def std(a, axis=None, dtype=None, keepdims=False, ddof=0):
return sqrt(a.var(axis=axis, dtype=dtype, keepdims=keepdims, ddof=ddof))


def nanstd(a, axis=None, dtype=None, keepdims=False, ddof=0):
return sqrt(nanvar(a, axis=axis, dtype=dtype, keepdims=keepdims, ddof=ddof))

def nanstd(a, axis=None, keepdims=False, ddof=0):
return sqrt(nanvar(a, axis=axis, keepdims=keepdims, ddof=ddof))
with ignoring(AttributeError):
nanstd = wraps(chunk.nanstd)(nanstd)


def vnorm(a, ord=None, axis=None, keepdims=False):
def vnorm(a, ord=None, axis=None, dtype=None, keepdims=False):
""" Vector norm
See np.linalg.norm
Expand All @@ -252,11 +277,11 @@ def vnorm(a, ord=None, axis=None, keepdims=False):
elif ord == -np.inf:
return min(abs(a), axis=axis, keepdims=keepdims)
elif ord == 1:
return sum(abs(a), axis=axis, keepdims=keepdims)
return sum(abs(a), axis=axis, dtype=dtype, keepdims=keepdims)
elif ord % 2 == 0:
return sum(a**ord, axis=axis, keepdims=keepdims)**(1./ord)
return sum(a**ord, axis=axis, dtype=dtype, keepdims=keepdims)**(1./ord)
else:
return sum(abs(a)**ord, axis=axis, keepdims=keepdims)**(1./ord)
return sum(abs(a)**ord, axis=axis, dtype=dtype, keepdims=keepdims)**(1./ord)


def arg_aggregate(func, argfunc, dims, pairs):
Expand Down
7 changes: 7 additions & 0 deletions dask/array/tests/test_reductions.py
Expand Up @@ -53,3 +53,10 @@ def test_nan():
assert eq(np.nanargmax(x, axis=0), da.nanargmax(d, axis=0))
with ignoring(AttributeError):
assert eq(np.nanprod(x), da.nanprod(d))


def test_dtype():
x = np.array([[1, 1], [2, 2], [3, 3]], dtype='i1')
d = da.from_array(x, chunks=2)

assert eq(d.sum(dtype='i1', axis=1), x.sum(dtype='i1', axis=1))

0 comments on commit 0023d94

Please sign in to comment.