Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Trace #4717

Merged
merged 2 commits into from
Apr 19, 2019
Merged

Trace #4717

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion dask/array/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
# Absent for NumPy versions prior to 1.12.
pass
from .reductions import (sum, prod, mean, std, var, any, all, min, max,
moment,
moment, trace,
argmin, argmax,
nansum, nanmean, nanstd, nanvar, nanmin,
nanmax, nanargmin, nanargmax,
Expand Down
5 changes: 5 additions & 0 deletions dask/array/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1511,6 +1511,11 @@ def sum(self, axis=None, dtype=None, keepdims=False, split_every=None,
return sum(self, axis=axis, dtype=dtype, keepdims=keepdims,
split_every=split_every, out=out)

@derived_from(np.ndarray)
def trace(self, offset=0, axis1=0, axis2=1, dtype=None):
from .reductions import trace
return trace(self, offset=offset, axis1=axis1, axis2=axis2, dtype=dtype)

@derived_from(np.ndarray)
def prod(self, axis=None, dtype=None, keepdims=False, split_every=None,
out=None):
Expand Down
7 changes: 6 additions & 1 deletion dask/array/reductions.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from .core import _concatenate2, Array, handle_out
from .blockwise import blockwise
from ..blockwise import lol_tuples
from .creation import arange
from .creation import arange, diagonal
from .ufunc import sqrt
from .utils import validate_axis
from .wrap import zeros, ones
Expand Down Expand Up @@ -906,3 +906,8 @@ def argtopk(a, k, axis=-1, split_every=None):
a_plus_idx, chunk=chunk_combine, combine=chunk_combine,
aggregate=aggregate, axis=axis, keepdims=True, dtype=np.intp,
split_every=split_every, concatenate=False, output_size=abs(k))


@wraps(np.trace)
def trace(a, offset=0, axis1=0, axis2=1, dtype=None):
return diagonal(a, offset=offset, axis1=axis1, axis2=axis2).sum(-1, dtype=dtype)
25 changes: 25 additions & 0 deletions dask/array/tests/test_reductions.py
Original file line number Diff line number Diff line change
Expand Up @@ -566,3 +566,28 @@ def test_regres_3940(func):
assert func(a).name != func(a + 1).name
assert func(a, axis=0).name != func(a).name
assert func(a, axis=0).name != func(a, axis=1).name


def test_trace():
def _assert(a, b, *args, **kwargs):
return assert_eq(a.trace(*args, **kwargs), b.trace(*args, **kwargs))

b = np.arange(12).reshape((3, 4))
a = da.from_array(b, 1)
_assert(a, b)
_assert(a, b, 0)
_assert(a, b, 1)
_assert(a, b, -1)

b = np.arange(8).reshape((2, 2, 2))
a = da.from_array(b, 2)
_assert(a, b)
_assert(a, b, 0)
_assert(a, b, 1)
_assert(a, b, -1)
_assert(a, b, 0, 0, 1)
_assert(a, b, 0, 0, 2)
_assert(a, b, 0, 1, 2, int)
_assert(a, b, 0, 1, 2, float)
_assert(a, b, offset=1, axis1=0, axis2=2, dtype=int)
_assert(a, b, offset=1, axis1=0, axis2=2, dtype=float)
1 change: 1 addition & 0 deletions docs/source/array-api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,7 @@ Top level user functions:
tensordot
tile
topk
trace
transpose
tril
triu
Expand Down