Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Add k arg to diag #5683

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
57 changes: 27 additions & 30 deletions dask/array/creation.py
Expand Up @@ -496,52 +496,49 @@ def eye(N, chunks="auto", M=None, k=0, dtype=float):


@derived_from(np)
def diag(v):
name = "diag-" + tokenize(v)
def diag(v, k=0):
name = "diag-" + tokenize(v, k)

meta = meta_from_array(v, 2 if v.ndim == 1 else 1)

if isinstance(v, np.ndarray) or (
hasattr(v, "__array_function__") and not isinstance(v, Array)
):
if v.ndim == 1:
chunks = ((v.shape[0],), (v.shape[0],))
dsk = {(name, 0, 0): (np.diag, v)}
chunks = ((v.shape[0] + abs(k),), (v.shape[0] + abs(k),))
dsk = {(name, 0, 0): (np.diag, v, k)}
elif v.ndim == 2:
chunks = ((min(v.shape),),)
dsk = {(name, 0): (np.diag, v)}
chunks = ((min(v.shape) - abs(k),),)
dsk = {(name, 0): (np.diag, v, k)}
else:
raise ValueError("Array must be 1d or 2d only")
return Array(dsk, name, chunks, meta=meta)
if not isinstance(v, Array):
raise TypeError(
"v must be a dask array or numpy array, got {0}".format(type(v))
)
if v.ndim != 1:
if v.chunks[0] == v.chunks[1]:
dsk = {
(name, i): (np.diag, row[i]) for i, row in enumerate(v.__dask_keys__())
}
graph = HighLevelGraph.from_collections(name, dsk, dependencies=[v])
return Array(graph, name, (v.chunks[0],), meta=meta)
else:
raise NotImplementedError(
"Extracting diagonals from non-square chunked arrays"
)
chunks_1d = v.chunks[0]
blocks = v.__dask_keys__()
dsk = {}
for i, m in enumerate(chunks_1d):
for j, n in enumerate(chunks_1d):
key = (name, i, j)
if i == j:
dsk[key] = (np.diag, blocks[i])
else:
dsk[key] = (np.zeros, (m, n))
dsk[key] = (partial(zeros_like_safe, shape=(m, n)), meta)

graph = HighLevelGraph.from_collections(name, dsk, dependencies=[v])
return Array(graph, name, (chunks_1d, chunks_1d), meta=meta)
if v.ndim != 1:
return diagonal(v, offset=k)

if k == 0:
chunks_1d = v.chunks[0]
blocks = v.__dask_keys__()
dsk = {}
for i, m in enumerate(chunks_1d):
for j, n in enumerate(chunks_1d):
key = (name, i, j)
if i == j:
dsk[key] = (np.diag, blocks[i])
else:
dsk[key] = (np.zeros, (m, n))
dsk[key] = (partial(zeros_like_safe, shape=(m, n)), meta)
graph = HighLevelGraph.from_collections(name, dsk, dependencies=[v])
return Array(graph, name, (chunks_1d, chunks_1d), meta=meta)
elif k > 0:
return pad(diag(v), [[0, k], [k, 0]], mode="constant")
elif k < 0:
return pad(diag(v), [[-k, 0], [0, -k]], mode="constant")


@derived_from(np)
Expand Down
18 changes: 9 additions & 9 deletions dask/array/svg.py
Expand Up @@ -47,9 +47,9 @@ def svg_2d(chunks, offset=(0, 0), skew=(0, 0), size=200, sizes=None):

lines, (min_x, max_x, min_y, max_y) = svg_grid(x, y, offset=offset, skew=skew)

header = '<svg width="%d" height="%d" style="stroke:rgb(0,0,0);stroke-width:1" >\n' % (
max_x + 50,
max_y + 50,
header = (
'<svg width="%d" height="%d" style="stroke:rgb(0,0,0);stroke-width:1" >\n'
% (max_x + 50, max_y + 50,)
)
footer = "\n</svg>"

Expand Down Expand Up @@ -85,9 +85,9 @@ def svg_3d(chunks, size=200, sizes=None, offset=(0, 0)):
z, y, offset=(ox + max_x + 10, oy + max_x), skew=(0, 0)
)

header = '<svg width="%d" height="%d" style="stroke:rgb(0,0,0);stroke-width:1" >\n' % (
max_z + 50,
max_y + 50,
header = (
'<svg width="%d" height="%d" style="stroke:rgb(0,0,0);stroke-width:1" >\n'
% (max_z + 50, max_y + 50,)
)
footer = "\n</svg>"

Expand Down Expand Up @@ -152,9 +152,9 @@ def svg_nd(chunks, size=200):

out.append(o)

header = '<svg width="%d" height="%d" style="stroke:rgb(0,0,0);stroke-width:1" >\n' % (
left,
total_height,
header = (
'<svg width="%d" height="%d" style="stroke:rgb(0,0,0);stroke-width:1" >\n'
% (left, total_height,)
)
footer = "\n</svg>"
return header + "\n\n".join(out) + footer
Expand Down
25 changes: 13 additions & 12 deletions dask/array/tests/test_creation.py
Expand Up @@ -385,32 +385,33 @@ def test_eye():
assert 4 < x.npartitions < 32


def test_diag():
@pytest.mark.parametrize("k", [0, 3, -3])
def test_diag(k):
v = np.arange(11)
assert_eq(da.diag(v), np.diag(v))
assert_eq(da.diag(v, k), np.diag(v, k))

v = da.arange(11, chunks=3)
darr = da.diag(v)
nparr = np.diag(v)
darr = da.diag(v, k)
nparr = np.diag(v, k)
assert_eq(darr, nparr)
assert sorted(da.diag(v).dask) == sorted(da.diag(v).dask)
assert sorted(da.diag(v, k).dask) == sorted(da.diag(v, k).dask)

v = v + v + 3
darr = da.diag(v)
nparr = np.diag(v)
darr = da.diag(v, k)
nparr = np.diag(v, k)
assert_eq(darr, nparr)

v = da.arange(11, chunks=11)
darr = da.diag(v)
nparr = np.diag(v)
darr = da.diag(v, k)
nparr = np.diag(v, k)
assert_eq(darr, nparr)
assert sorted(da.diag(v).dask) == sorted(da.diag(v).dask)
assert sorted(da.diag(v, k).dask) == sorted(da.diag(v, k).dask)

x = np.arange(64).reshape((8, 8))
assert_eq(da.diag(x), np.diag(x))
assert_eq(da.diag(x, k), np.diag(x, k))

d = da.from_array(x, chunks=(4, 4))
assert_eq(da.diag(d), np.diag(x))
assert_eq(da.diag(d, k), np.diag(x, k))


def test_diagonal():
Expand Down