Skip to content

Commit

Permalink
A HighLevelGraph abstract layer for map_overlap (#7595)
Browse files Browse the repository at this point in the history
  • Loading branch information
GenevieveBuckley committed Jun 10, 2021
1 parent 1d0262c commit 6599905
Show file tree
Hide file tree
Showing 5 changed files with 293 additions and 178 deletions.
14 changes: 1 addition & 13 deletions dask/array/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from ..core import quote
from ..delayed import Delayed, delayed
from ..highlevelgraph import HighLevelGraph
from ..layers import reshapelist
from ..sizeof import sizeof
from ..utils import (
IndexCallable,
Expand Down Expand Up @@ -4598,19 +4599,6 @@ def shapelist(a):
return ()


def reshapelist(shape, seq):
    """Reshape a flat sequence into nested lists following ``shape``.

    Parameters
    ----------
    shape : sequence of int
        Target nesting; ``shape[0]`` is the number of groups at the
        outermost level, ``shape[1:]`` describes each group recursively.
    seq : sized iterable
        Flat items to regroup. Must support ``len()`` when ``shape`` has
        more than one entry.

    Returns
    -------
    list
        Nested lists whose innermost lists hold the items of ``seq``.

    Examples
    --------
    >>> reshapelist((2, 3), range(6))
    [[0, 1, 2], [3, 4, 5]]
    """
    if len(shape) == 1:
        return list(seq)
    # Floor division keeps the group size an exact integer; going through
    # float true division and ``int()`` is needless and can lose precision.
    n = len(seq) // shape[0]
    return [reshapelist(shape[1:], part) for part in partition(n, seq)]


def transposelist(arrays, axes, extradims=0):
"""Permute axes of nested list
Expand Down
171 changes: 29 additions & 142 deletions dask/array/overlap.py
Original file line number Diff line number Diff line change
@@ -1,44 +1,23 @@
import warnings
from itertools import product
from numbers import Integral
from operator import getitem

import numpy as np
from tlz import concat, get, merge, partial, pipe
from tlz import concat, get, partial
from tlz.curried import map

from ..base import tokenize
from ..core import flatten
from ..highlevelgraph import HighLevelGraph
from ..utils import concrete, derived_from
from ..layers import ArrayOverlapLayer
from ..utils import derived_from
from . import chunk, numpy_compat
from .core import (
Array,
concatenate,
concatenate3,
map_blocks,
reshapelist,
unify_chunks,
)
from .core import Array, concatenate, map_blocks, unify_chunks
from .creation import empty_like, full_like


def fractional_slice(task, axes):
"""
>>> fractional_slice(('x', 5.1), {0: 2}) # doctest: +SKIP
(getitem, ('x', 6), (slice(0, 2),))
>>> fractional_slice(('x', 3, 5.1), {0: 2, 1: 3}) # doctest: +SKIP
(getitem, ('x', 3, 5), (slice(None, None, None), slice(-3, None)))
>>> fractional_slice(('x', 2.9, 5.1), {0: 2, 1: 3}) # doctest: +SKIP
(getitem, ('x', 3, 5), (slice(0, 2), slice(-3, None)))
"""
rounded = (task[0],) + tuple(int(round(i)) for i in task[1:])

index = []
for i, (t, r) in enumerate(zip(task[1:], rounded[1:])):
def _overlap_internal_chunks(original_chunks, axes):
"""Get new chunks for array with overlap."""
chunks = []
for i, bds in enumerate(original_chunks):
depth = axes.get(i, 0)
if isinstance(depth, tuple):
left_depth = depth[0]
Expand All @@ -47,75 +26,16 @@ def fractional_slice(task, axes):
left_depth = depth
right_depth = depth

if t == r:
index.append(slice(None, None, None))
elif t < r and right_depth:
index.append(slice(0, right_depth))
elif t > r and left_depth:
index.append(slice(-left_depth, None))
if len(bds) == 1:
chunks.append(bds)
else:
index.append(slice(0, 0))
index = tuple(index)

if all(ind == slice(None, None, None) for ind in index):
return task
else:
return (getitem, rounded, index)


def expand_key(k, dims, name=None, axes=None):
    """Get all neighboring keys around center

    Neighbors are encoded as fractional indices: ``ind - 0.9`` stands for
    the neighbor before ``ind`` and ``ind + 0.9`` for the neighbor after it.

    Parameters
    ----------
    k: tuple
        The key around which to generate new keys
    dims: Sequence[int]
        The number of chunks in each dimension
    name: Option[str]
        The name to include in the output keys, or none to include no name
    axes: Dict[int, int]
        The axes active in the expansion. We don't expand on non-active axes

    Examples
    --------
    >>> expand_key(('x', 2, 3), dims=[5, 5], name='y', axes={0: 1, 1: 1})  # doctest: +NORMALIZE_WHITESPACE
    [[('y', 1.1, 2.1), ('y', 1.1, 3), ('y', 1.1, 3.9)],
     [('y', 2, 2.1), ('y', 2, 3), ('y', 2, 3.9)],
     [('y', 2.9, 2.1), ('y', 2.9, 3), ('y', 2.9, 3.9)]]

    >>> expand_key(('x', 0, 4), dims=[5, 5], name='y', axes={0: 1, 1: 1})  # doctest: +NORMALIZE_WHITESPACE
    [[('y', 0, 3.1), ('y', 0, 4)],
     [('y', 0.9, 3.1), ('y', 0.9, 4)]]
    """
    center = k[1:]

    def is_active(axis):
        # An axis takes part in the expansion when its configured depth is truthy.
        return bool(axes.get(axis, 0))

    def around(axis, ind):
        # Fractional indices mark the neighbors on either side of ``ind``;
        # a side is dropped when it falls outside the block grid.
        before = [ind - 0.9] if ind - 0.9 > 0 else []
        after = [ind + 0.9] if ind + 0.9 < dims[axis] - 1 else []
        return before + [ind] + after

    # Number of candidate positions (center plus existing neighbors) per axis.
    counts = [
        1 + (1 if ind > 0 else 0) + (1 if ind < dims[axis] - 1 else 0)
        for axis, ind in enumerate(center)
    ]

    candidates = []
    for axis, ind in enumerate(center):
        candidates.append(around(axis, ind) if is_active(axis) else [ind])
    if name is not None:
        candidates = [[name]] + candidates

    flat_keys = list(product(*candidates))
    # Inactive axes contribute a single position to the nested layout.
    nested_shape = [
        count if is_active(axis) else 1 for axis, count in enumerate(counts)
    ]
    return reshapelist(nested_shape, flat_keys)
left = [bds[0] + right_depth]
right = [bds[-1] + left_depth]
mid = []
for bd in bds[1:-1]:
mid.append(bd + left_depth + right_depth)
chunks.append(left + mid + right)
return chunks


def overlap_internal(x, axes):
Expand All @@ -132,51 +52,18 @@ def overlap_internal(x, axes):
The axes input informs how many cells to overlap between neighboring blocks
{0: 2, 2: 5} means share two cells in 0 axis, 5 cells in 2 axis
"""
dims = list(map(len, x.chunks))
expand_key2 = partial(expand_key, dims=dims, axes=axes)

# Make keys for each of the surrounding sub-arrays
interior_keys = pipe(
x.__dask_keys__(), flatten, map(expand_key2), map(flatten), concat, list
token = tokenize(x, axes)
name = "overlap-" + token

graph = ArrayOverlapLayer(
name=x.name,
axes=axes,
chunks=x.chunks,
numblocks=x.numblocks,
token=token,
)

name = "overlap-" + tokenize(x, axes)
getitem_name = "getitem-" + tokenize(x, axes)
interior_slices = {}
overlap_blocks = {}
for k in interior_keys:
frac_slice = fractional_slice((x.name,) + k, axes)
if (x.name,) + k != frac_slice:
interior_slices[(getitem_name,) + k] = frac_slice
else:
interior_slices[(getitem_name,) + k] = (x.name,) + k
overlap_blocks[(name,) + k] = (
concatenate3,
(concrete, expand_key2((None,) + k, name=getitem_name)),
)

chunks = []
for i, bds in enumerate(x.chunks):
depth = axes.get(i, 0)
if isinstance(depth, tuple):
left_depth = depth[0]
right_depth = depth[1]
else:
left_depth = depth
right_depth = depth

if len(bds) == 1:
chunks.append(bds)
else:
left = [bds[0] + right_depth]
right = [bds[-1] + left_depth]
mid = []
for bd in bds[1:-1]:
mid.append(bd + left_depth + right_depth)
chunks.append(left + mid + right)

dsk = merge(interior_slices, overlap_blocks)
graph = HighLevelGraph.from_collections(name, dsk, dependencies=[x])
graph = HighLevelGraph.from_collections(name, graph, dependencies=[x])
chunks = _overlap_internal_chunks(x.chunks, axes)

return Array(graph, name, chunks, meta=x)

Expand Down Expand Up @@ -534,7 +421,7 @@ def overlap(x, depth, boundary):
new_chunks = tuple(
ensure_minimum_chunksize(size, c) for size, c in zip(depths, x.chunks)
)
x1 = x.rechunk(new_chunks)
x1 = x.rechunk(new_chunks) # this is a no-op if x.chunks == new_chunks

x2 = boundaries(x1, depth2, boundary2)
x3 = overlap_internal(x2, depth2)
Expand Down
21 changes: 0 additions & 21 deletions dask/array/tests/test_overlap.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@
boundaries,
constant,
ensure_minimum_chunksize,
fractional_slice,
getitem,
nearest,
overlap,
overlap_internal,
Expand All @@ -24,25 +22,6 @@
from ..lib.stride_tricks import sliding_window_view


def test_fractional_slice():
    """Check that fractional keys are rounded and paired with the right slices."""
    cases = [
        # (task, axes, expected)
        (("x", 4.9), {0: 2}, (getitem, ("x", 5), (slice(0, 2),))),
        (
            ("x", 3, 5.1),
            {0: 2, 1: 3},
            (getitem, ("x", 3, 5), (slice(None, None, None), slice(-3, None))),
        ),
        (
            ("x", 2.9, 5.1),
            {0: 2, 1: 3},
            (getitem, ("x", 3, 5), (slice(0, 2), slice(-3, None))),
        ),
    ]
    for task, axes, expected in cases:
        assert fractional_slice(task, axes) == expected

    # The rounded key must carry a real int index, not a float.
    rounded_key = fractional_slice(("x", 4.9), {0: 2})[1]
    assert isinstance(rounded_key[1], int)


def test_overlap_internal():
x = np.arange(64).reshape((8, 8))
d = da.from_array(x, chunks=(4, 4))
Expand Down
Loading

0 comments on commit 6599905

Please sign in to comment.