Skip to content

Commit

Permalink
da.diagonal chunk aware
Browse files Browse the repository at this point in the history
  • Loading branch information
horta committed Feb 10, 2019
1 parent cb57c3c commit 079ac89
Show file tree
Hide file tree
Showing 2 changed files with 105 additions and 92 deletions.
96 changes: 49 additions & 47 deletions dask/array/creation.py
Expand Up @@ -16,6 +16,7 @@
stack, concatenate, block,
broadcast_to, broadcast_arrays)
from .wrap import empty, ones, zeros, full
from .utils import AxisError


def empty_like(a, dtype=None, chunks=None):
Expand Down Expand Up @@ -509,69 +510,70 @@ def diagonal(a, offset=0, axis1=0, axis2=1):
# NumPy uses `diag` as we do here.
raise ValueError("diag requires an array of at least two dimensions")

if axis1 == axis2:
raise ValueError("axis1 and axis2 cannot be the same")
def _axis_fmt(axis, name, ndim):
if axis < 0:
t = ndim + axis
if t < 0:
raise AxisError(f"{name}: axis {axis} is out of bounds for "
f"array of dimension {ndim}")
axis = t
return axis

if offset >= 0:
dim1_offset = 0
dim2_offset = +offset
else:
dim1_offset = -offset
dim2_offset = 0
axis1 = _axis_fmt(axis1, "axis1", a.ndim)
axis2 = _axis_fmt(axis2, "axis2", a.ndim)

if axis1 > axis2:
axis1, axis2 = axis2, axis1
dim1_offset, dim2_offset = dim2_offset, dim1_offset
offset *= -1

left_idx = list(range(len(a.shape)))
del left_idx[axis2]
del left_idx[axis1]

left_shape = tuple([a.shape[i] for i in left_idx])
right_shape = (max(min(a.shape[axis1] - dim1_offset,
a.shape[axis2] - dim2_offset), 0),)
shape = left_shape + right_shape
if axis1 == axis2:
raise ValueError("axis1 and axis2 cannot be the same")

if isinstance(a, np.ndarray):

chunks = (1,) * len(left_shape) + (right_shape,)
chunks = normalize_chunks(chunks=chunks, shape=shape)

dsk = {}
a_indices = [range(i) for i in a.shape]
a_indices[axis1] = [slice(None)]
a_indices[axis2] = [slice(None)]
sub_shape = (a.shape[axis1], a.shape[axis2]) + (1,) * (len(shape) - 1)
for ai in product(*a_indices):
ci = tuple(i for i in ai if not isinstance(i, slice))
dsk[(name,) + ci + (0,)] = (np.diagonal, a[ai].reshape(sub_shape),
offset)

return Array(dsk, name, chunks, dtype=a.dtype)
return diagonal(asarray(a), offset, axis1, axis2)

if not isinstance(a, Array):
raise TypeError("a must be a dask array or numpy array, "
"got {0}".format(type(a)))

chunks = list(a.chunks)
chunks[axis1] = -1
chunks[axis2] = -1
a = a.rechunk(chunks)
if axis1 > axis2:
axis1, axis2 = axis2, axis1
offset = -offset

a_indices = [range(len(i)) for i in a.chunks]
def _diag_len(dim1, dim2, offset):
return max(0, min(min(dim1, dim2), dim1 + offset, dim2 - offset))

def _diag(x, offset, axis1, axis2):
def _diagonal(x, offset, axis1, axis2):
    """Extract the ``offset``-th diagonal of ``x`` over ``axis1``/``axis2``.

    ``x`` is passed through ``asarray`` first and the result is handed to
    ``np.diagonal`` with the given offset and axes.
    """
    block = asarray(x)
    return np.diagonal(block, offset=offset, axis1=axis1, axis2=axis2)

diag_chunks = []
chunk_offsets = []
cum1 = [0] + list(np.cumsum(a.chunks[axis1]))[:-1]
cum2 = [0] + list(np.cumsum(a.chunks[axis2]))[:-1]
for co1, c1 in zip(cum1, a.chunks[axis1]):
chunk_offsets.append([])
for co2, c2 in zip(cum2, a.chunks[axis2]):
k = offset + co1 - co2
diag_chunks.append(_diag_len(c1, c2, k))
chunk_offsets[-1].append(k)

dsk = {}
for ai in product(*a_indices):
ci = tuple(r for i, r in enumerate(ai) if i not in (axis1, axis2))
tsk = reduce(getitem, ai, a.__dask_keys__())
dsk.update({(name,) + ci + (0,): (_diag, tsk, offset, axis1, axis2)})
idx_set = set(range(a.ndim)) - set([axis1, axis2])
n1 = len(a.chunks[axis1])
n2 = len(a.chunks[axis2])
for idx in product(*(range(len(a.chunks[i])) for i in idx_set)):
for i, (i1, i2) in enumerate(product(range(n1), range(n2))):
tsk = reduce(getitem, idx[:axis1], a.__dask_keys__())[i1]
tsk = reduce(getitem, idx[axis1:axis2 - 1], tsk)[i2]
tsk = reduce(getitem, idx[axis2 - 1:], tsk)
k = chunk_offsets[i1][i2]
dsk[(name,) + idx + (i,)] = (_diagonal, tsk, k, axis1, axis2)

left_shape = tuple(a.shape[i] for i in idx_set)
right_shape = (_diag_len(a.shape[axis1], a.shape[axis2], offset),)
shape = left_shape + right_shape

left_chunks = tuple(a.chunks[i] for i in idx_set)
right_shape = (tuple(diag_chunks),)
chunks = left_chunks + right_shape

chunks = tuple(a.chunks[i] for i in left_idx) + (right_shape,)
graph = HighLevelGraph.from_collections(name, dsk, dependencies=[a])
return Array(graph, name, shape=shape, chunks=chunks, dtype=a.dtype)

Expand Down
101 changes: 56 additions & 45 deletions dask/array/tests/test_creation.py
Expand Up @@ -7,7 +7,7 @@

import dask
import dask.array as da
from dask.array.utils import assert_eq, same_keys
from dask.array.utils import assert_eq, same_keys, AxisError


@pytest.mark.parametrize(
Expand Down Expand Up @@ -387,67 +387,78 @@ def test_diagonal():
with pytest.raises(ValueError):
da.diagonal(v, axis1=0, axis2=0)

assert_eq(da.diagonal(v), np.diagonal(v))
with pytest.raises(AxisError):
da.diagonal(v, axis1=-4)

assert_eq(da.diagonal(v, offset=1), np.diagonal(v, offset=1))
assert_eq(da.diagonal(v, offset=-1), np.diagonal(v, offset=-1))
with pytest.raises(AxisError):
da.diagonal(v, axis2=-4)

v = np.arange(16).reshape((2, 2, 4))
v = np.arange(4 * 5 * 6).reshape((4, 5, 6))
v = da.from_array(v, chunks=2)
assert_eq(da.diagonal(v), np.diagonal(v))
# Empty diagonal.
assert_eq(da.diagonal(v, offset=10), np.diagonal(v, offset=10))
assert_eq(da.diagonal(v, offset=-10), np.diagonal(v, offset=-10))

v = np.arange(2 * 3 * 4 * 5 * 6).reshape((2, 3, 4, 5, 6))
assert_eq(da.diagonal(v, axis1=1, axis2=3), np.diagonal(v, axis1=1, axis2=3))
assert_eq(da.diagonal(v, axis1=1, axis2=3, offset=1),
np.diagonal(v, axis1=1, axis2=3, offset=1))

assert_eq(da.diagonal(v, axis1=3, axis2=1, offset=1),
np.diagonal(v, axis1=3, axis2=1, offset=1))
with pytest.raises(ValueError):
da.diagonal(v, axis1=-2)

assert_eq(da.diagonal(v, axis1=3, axis2=1, offset=-5),
np.diagonal(v, axis1=3, axis2=1, offset=-5))
# Negative axis.
assert_eq(da.diagonal(v, axis1=-1), np.diagonal(v, axis1=-1))
assert_eq(da.diagonal(v, offset=1, axis1=-1), np.diagonal(v, offset=1, axis1=-1))

assert_eq(da.diagonal(v, axis1=3, axis2=1, offset=-6),
np.diagonal(v, axis1=3, axis2=1, offset=-6))
# Heterogeneous chunks.
v = np.arange(2 * 3 * 4 * 5 * 6).reshape((2, 3, 4, 5, 6))
v = da.from_array(v, chunks=(1, (1, 2), (1, 2, 1), (2, 1, 2), (5, 1)))

assert_eq(da.diagonal(v, axis1=-3, axis2=1, offset=-6),
np.diagonal(v, axis1=-3, axis2=1, offset=-6))
assert_eq(da.diagonal(v), np.diagonal(v))
assert_eq(da.diagonal(v, offset=2, axis1=3, axis2=1),
np.diagonal(v, offset=2, axis1=3, axis2=1))

assert_eq(da.diagonal(v, axis1=-3, axis2=1, offset=-6),
np.diagonal(v, axis1=-3, axis2=1, offset=-6))
assert_eq(da.diagonal(v, offset=-2, axis1=3, axis2=1),
np.diagonal(v, offset=-2, axis1=3, axis2=1))

v = da.from_array(v, chunks=2)
assert_eq(da.diagonal(v, axis1=3, axis2=1, offset=1),
np.diagonal(v, axis1=3, axis2=1, offset=1))
assert_eq(da.diagonal(v, offset=-2, axis1=3, axis2=4),
np.diagonal(v, offset=-2, axis1=3, axis2=4))

v = da.from_array(v, chunks=2)
assert_eq(da.diagonal(v, axis1=3, axis2=1, offset=-1),
np.diagonal(v, axis1=3, axis2=1, offset=-1))
assert_eq(da.diagonal(v, 1), np.diagonal(v, 1))
assert_eq(da.diagonal(v, -1), np.diagonal(v, -1))
# Positional arguments
assert_eq(da.diagonal(v, 1, 2, 1), np.diagonal(v, 1, 2, 1))

v = da.from_array(v, chunks=2)
assert_eq(da.diagonal(v), np.diagonal(v))
v = np.arange(2 * 3 * 4 * 5 * 6).reshape((2, 3, 4, 5, 6))
assert_eq(da.diagonal(v, axis1=1, axis2=3), np.diagonal(v, axis1=1, axis2=3))
assert_eq(da.diagonal(v, offset=1, axis1=1, axis2=3),
np.diagonal(v, offset=1, axis1=1, axis2=3))

v = np.arange(6).reshape((2, 3))
v = da.from_array(v, chunks=2)
assert_eq(da.diagonal(v), np.diagonal(v))
assert_eq(da.diagonal(v, offset=1, axis1=3, axis2=1),
np.diagonal(v, offset=1, axis1=3, axis2=1))

v = np.arange(6).reshape((2, 3, 1))
v = da.from_array(v, chunks=2)
assert_eq(da.diagonal(v), np.diagonal(v))
assert_eq(da.diagonal(v, offset=-5, axis1=3, axis2=1),
np.diagonal(v, offset=-5, axis1=3, axis2=1))

# assert sorted(da.diagonal(v).dask) == sorted(da.diagonal(v).dask)
assert_eq(da.diagonal(v, offset=-6, axis1=3, axis2=1),
np.diagonal(v, offset=-6, axis1=3, axis2=1))

v = v + v + 3
darr = da.diagonal(v)
nparr = np.diagonal(v)
assert_eq(darr, nparr)
assert_eq(da.diagonal(v, offset=-6, axis1=-3, axis2=1),
np.diagonal(v, offset=-6, axis1=-3, axis2=1))

x = np.arange(384).reshape((8, 8, 6))
assert_eq(da.diagonal(x, axis1=2, offset=-1),
np.diagonal(x, axis1=2, offset=-1))
assert_eq(da.diagonal(v, offset=-6, axis1=-3, axis2=1),
np.diagonal(v, offset=-6, axis1=-3, axis2=1))

d = da.from_array(x, chunks=(4, 4, 2))
assert_eq(da.diagonal(d, axis1=2, offset=-1),
np.diagonal(x, axis1=2, offset=-1))
v = da.from_array(v, chunks=2)
assert_eq(da.diagonal(v, offset=1, axis1=3, axis2=1),
np.diagonal(v, offset=1, axis1=3, axis2=1))
assert_eq(da.diagonal(v, offset=-1, axis1=3, axis2=1),
np.diagonal(v, offset=-1, axis1=3, axis2=1))

v = np.arange(384).reshape((8, 8, 6))
assert_eq(da.diagonal(v, offset=-1, axis1=2),
np.diagonal(v, offset=-1, axis1=2))

v = da.from_array(v, chunks=(4, 4, 2))
assert_eq(da.diagonal(v, offset=-1, axis1=2),
np.diagonal(v, offset=-1, axis1=2))


@pytest.mark.parametrize('dtype', [None, 'f8', 'i8'])
Expand Down

0 comments on commit 079ac89

Please sign in to comment.