Skip to content

Commit

Permalink
Moved _get_block_pattern -> hyperspy.misc.utils and Added docstrings …
Browse files Browse the repository at this point in the history
…to _get_block_pattern
  • Loading branch information
CSSFrancis committed Mar 8, 2022
1 parent e7b4437 commit 5e25ceb
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 33 deletions.
40 changes: 37 additions & 3 deletions hyperspy/misc/utils.py
Expand Up @@ -1176,10 +1176,10 @@ def process_function_blockwise(data,
The function to be applied to the signal axis
nav_indexes : tuple
The indexes of the navigation axes for the dataset.
output_signal_shape: tuple
output_signal_size: tuple
The shape of the output signal. For a ragged signal, this is equal to 1
block_info : dict
The block info as described by the ``dask.array.map_blocks`` function
output_dtype : dtype
The data type for the output.
arg_keys : tuple
The list of keys for the passed arguments (args). Together this makes
a set of key:value pairs to be passed to the function.
Expand Down Expand Up @@ -1219,6 +1219,40 @@ def process_function_blockwise(data,
return output_array


def _get_block_pattern(args, output_shape):
""" Returns the block pattern used by the `blockwise` function for a
set of arguments give a resulting output_shape
Parameters
----------
args: list
A list of all the arguments which are used for `da.blockwise`
output_shape: tuple
The output shape for the function passed to `da.blockwise` given args
"""
arg_patterns = tuple(tuple(range(a.ndim)) for a in args)
arg_shapes = tuple(a.shape for a in args)
output_pattern = tuple(range(len(output_shape)))
all_ind = arg_shapes + (output_shape,)
max_len = max((len(i) for i in all_ind)) # max number of dimensions
max_arg_len = max((len(i) for i in arg_shapes))
adjust_chunks = {}
new_axis = {}
output_shape = output_shape + (0,) * (max_len - len(output_shape))
for i in range(max_len):
shapes = np.array(
[s[i] if len(s) > i else -1 for s in (output_shape,) + arg_shapes]
)
is_equal_shape = shapes == shapes[0] # if in shapes == output shapes
if not all(is_equal_shape):
if i > max_arg_len - 1: # output shape is a new axis
new_axis[i] = output_shape[i]
else: # output shape is an existing axis
adjust_chunks[i] = output_shape[i] # adjusting chunks based on output
arg_pairs = [(a, p) for a, p in zip(args, arg_patterns)]
return arg_pairs, adjust_chunks, new_axis, output_pattern


def guess_output_signal_size(test_data,
function,
ragged,
Expand Down
28 changes: 3 additions & 25 deletions hyperspy/signal.py
Expand Up @@ -56,7 +56,8 @@
from hyperspy.misc.slicing import SpecialSlicers, FancySlicing
from hyperspy.misc.utils import slugify
from hyperspy.misc.utils import is_binned # remove in v2.0
from hyperspy.misc.utils import process_function_blockwise, guess_output_signal_size
from hyperspy.misc.utils import (
process_function_blockwise, guess_output_signal_size,_get_block_pattern)
from hyperspy.misc.utils import add_scalar_axis
from hyperspy.docstrings.signal import (
ONE_AXIS_PARAMETER, MANY_AXIS_PARAMETER, OUT_ARG, NAN_FUNC, OPTIMIZE_ARG,
Expand Down Expand Up @@ -4928,7 +4929,7 @@ def _map_iterate(
if output_dtype is None:
output_dtype = temp_output_dtype
output_shape = self.axes_manager._navigation_shape_in_array + output_signal_size
arg_pairs, adjust_chunks, new_axis, output_pattern = old_sig.get_block_pattern(
arg_pairs, adjust_chunks, new_axis, output_pattern = _get_block_pattern(
(old_sig.data,) + args, output_shape
)

Expand Down Expand Up @@ -4996,29 +4997,6 @@ def _map_iterate(
sig.data = sig.data.compute(num_workers=max_workers)
return sig

def get_block_pattern(self, args, output_shape):
    """Return the block pattern used by ``da.blockwise`` for ``args``.

    ``self`` is not read; the result depends only on ``args`` and
    ``output_shape``.

    Parameters
    ----------
    args : sequence
        The arguments used for ``da.blockwise``. Every element must
        expose ``ndim`` and ``shape`` attributes.
    output_shape : tuple
        The output shape of the mapped function given ``args``. Any
        sequence of sizes is accepted.

    Returns
    -------
    arg_pairs : list of tuple
        One ``(argument, tuple(range(argument.ndim)))`` pair per argument.
    adjust_chunks : dict
        Axis index -> output size, for axes present in the arguments
        whose size differs from the output (``0`` when the output does
        not have that axis).
    new_axis : dict
        Axis index -> output size, for output axes lying beyond the
        dimensionality of all the arguments.
    output_pattern : tuple
        ``tuple(range(len(output_shape)))``.
    """
    # Accept any sequence: the tuple concatenation below needs a tuple.
    output_shape = tuple(output_shape)
    arg_shapes = tuple(a.shape for a in args)
    output_pattern = tuple(range(len(output_shape)))

    # ``default=0`` keeps an empty ``args`` from raising in ``max``.
    max_arg_len = max((len(shape) for shape in arg_shapes), default=0)
    max_len = max(max_arg_len, len(output_shape))  # max number of dimensions

    # Pad the output with zero-sized axes so every axis index is valid.
    padded_output = output_shape + (0,) * (max_len - len(output_shape))

    adjust_chunks = {}
    new_axis = {}
    for axis, out_size in enumerate(padded_output):
        # An argument lacking this axis never matches the output size.
        differs = any(
            len(shape) <= axis or shape[axis] != out_size
            for shape in arg_shapes
        )
        if not differs:
            continue
        if axis >= max_arg_len:  # axis only exists in the output
            new_axis[axis] = out_size
        else:  # existing axis whose size changes
            adjust_chunks[axis] = out_size

    arg_pairs = [(a, tuple(range(a.ndim))) for a in args]
    return arg_pairs, adjust_chunks, new_axis, output_pattern

def _get_iterating_kwargs(self, iterating_kwargs):
nav_chunks = self.get_chunk_size(self.axes_manager.navigation_axes)
args, arg_keys = (), ()
Expand Down
11 changes: 6 additions & 5 deletions hyperspy/tests/signals/test_map_method.py
Expand Up @@ -27,6 +27,7 @@
from hyperspy.decorators import lazifyTestClass
from hyperspy.exceptions import VisibleDeprecationWarning
from hyperspy._signals.lazy import LazySignal
from hyperspy.misc.utils import _get_block_pattern


@lazifyTestClass(ragged=False)
Expand Down Expand Up @@ -651,7 +652,7 @@ def test_no_change_2d_signal(self, input_shape):
chunks = (10,) * len(input_shape)
dask_array = da.random.random(input_shape, chunks=chunks)
s = hs.signals.Signal2D(dask_array).as_lazy()
arg_pairs, adjust_chunks, new_axis, output_pattern = s.get_block_pattern((s.data,), input_shape)
arg_pairs, adjust_chunks, new_axis, output_pattern = _get_block_pattern((s.data,), input_shape)
assert new_axis == {}
assert adjust_chunks == {}

Expand All @@ -664,25 +665,25 @@ def test_no_change_1d_signal(self, input_shape):
dask_array = da.random.random(input_shape, chunks=chunks)
s = hs.signals.Signal1D(dask_array).as_lazy()
output_signal_size = input_shape[-1:]
arg_pairs, adjust_chunks, new_axis, output_pattern = s.get_block_pattern((s.data,), input_shape)
arg_pairs, adjust_chunks, new_axis, output_pattern = _get_block_pattern((s.data,), input_shape)
assert new_axis == {}
assert adjust_chunks == {}

def test_different_output_signal_size_signal2d(self):
s = hs.signals.Signal2D(np.zeros((4, 5)))
arg_pairs, adjust_chunks, new_axis, output_pattern = s.get_block_pattern((s.data,), (1,))
arg_pairs, adjust_chunks, new_axis, output_pattern = _get_block_pattern((s.data,), (1,))
assert new_axis == {}
assert adjust_chunks == {0: 1, 1: 0}

def test_different_output_signal_size_signal2d_2(self):
s = hs.signals.Signal2D(np.zeros((7, 10, 5)))
arg_pairs, adjust_chunks, new_axis, output_pattern = s.get_block_pattern((s.data,), (7, 2))
arg_pairs, adjust_chunks, new_axis, output_pattern = _get_block_pattern((s.data,), (7, 2))
assert new_axis == {}
assert adjust_chunks == {1: 2, 2: 0}

def test_different_output_signal_size_signal2d_3(self):
s = hs.signals.Signal2D(np.zeros((3, 2, 7, 10, 5)))
arg_pairs, adjust_chunks, new_axis, output_pattern = s.get_block_pattern((s.data,),(3, 2, 5,))
arg_pairs, adjust_chunks, new_axis, output_pattern = _get_block_pattern((s.data,), (3, 2, 5,))
assert new_axis == {}
assert adjust_chunks == {2: 5, 3: 0, 4: 0}

Expand Down

0 comments on commit 5e25ceb

Please sign in to comment.