From 5e25ceb68e7639c51a6db8f0906f3f105554c437 Mon Sep 17 00:00:00 2001 From: cssfrancis Date: Tue, 8 Mar 2022 16:40:23 -0600 Subject: [PATCH] Moved _get_block_pattern -> hyperspy.misc.utils and Added docstrings to _get_block_pattern --- hyperspy/misc/utils.py | 40 +++++++++++++++++++++-- hyperspy/signal.py | 28 ++-------------- hyperspy/tests/signals/test_map_method.py | 11 ++++--- 3 files changed, 46 insertions(+), 33 deletions(-) diff --git a/hyperspy/misc/utils.py b/hyperspy/misc/utils.py index 8c61632d6f..e1fe9f81c6 100644 --- a/hyperspy/misc/utils.py +++ b/hyperspy/misc/utils.py @@ -1176,10 +1176,10 @@ def process_function_blockwise(data, The function to applied to the signal axis nav_indexes : tuple The indexes of the navigation axes for the dataset. - output_signal_shape: tuple + output_signal_size: tuple The shape of the output signal. For a ragged signal, this is equal to 1 - block_info : dict - The block info as described by the ``dask.array.map_blocks`` function + output_dtype : dtype + The data type for the output. arg_keys : tuple The list of keys for the passed arguments (args). Together this makes a set of key:value pairs to be passed to the function. 
@@ -1219,6 +1219,40 @@ def process_function_blockwise(data, return output_array +def _get_block_pattern(args, output_shape): + """ Returns the block pattern used by the `blockwise` function for a + set of arguments given a resulting output_shape + + Parameters + ---------- + args: list + A list of all the arguments which are used for `da.blockwise` + output_shape: tuple + The output shape for the function passed to `da.blockwise` given args + """ + arg_patterns = tuple(tuple(range(a.ndim)) for a in args) + arg_shapes = tuple(a.shape for a in args) + output_pattern = tuple(range(len(output_shape))) + all_ind = arg_shapes + (output_shape,) + max_len = max((len(i) for i in all_ind)) # max number of dimensions + max_arg_len = max((len(i) for i in arg_shapes)) + adjust_chunks = {} + new_axis = {} + output_shape = output_shape + (0,) * (max_len - len(output_shape)) + for i in range(max_len): + shapes = np.array( + [s[i] if len(s) > i else -1 for s in (output_shape,) + arg_shapes] + ) + is_equal_shape = shapes == shapes[0] # if in shapes == output shapes + if not all(is_equal_shape): + if i > max_arg_len - 1: # output shape is a new axis + new_axis[i] = output_shape[i] + else: # output shape is an existing axis + adjust_chunks[i] = output_shape[i] # adjusting chunks based on output + arg_pairs = [(a, p) for a, p in zip(args, arg_patterns)] + return arg_pairs, adjust_chunks, new_axis, output_pattern + + def guess_output_signal_size(test_data, function, ragged, diff --git a/hyperspy/signal.py b/hyperspy/signal.py index 76beefa427..d1887d0d7f 100644 --- a/hyperspy/signal.py +++ b/hyperspy/signal.py @@ -56,7 +56,8 @@ from hyperspy.misc.slicing import SpecialSlicers, FancySlicing from hyperspy.misc.utils import slugify from hyperspy.misc.utils import is_binned # remove in v2.0 -from hyperspy.misc.utils import process_function_blockwise, guess_output_signal_size +from hyperspy.misc.utils import ( + process_function_blockwise, guess_output_signal_size,_get_block_pattern) from 
hyperspy.misc.utils import add_scalar_axis from hyperspy.docstrings.signal import ( ONE_AXIS_PARAMETER, MANY_AXIS_PARAMETER, OUT_ARG, NAN_FUNC, OPTIMIZE_ARG, @@ -4928,7 +4929,7 @@ def _map_iterate( if output_dtype is None: output_dtype = temp_output_dtype output_shape = self.axes_manager._navigation_shape_in_array + output_signal_size - arg_pairs, adjust_chunks, new_axis, output_pattern = old_sig.get_block_pattern( + arg_pairs, adjust_chunks, new_axis, output_pattern = _get_block_pattern( (old_sig.data,) + args, output_shape ) @@ -4996,29 +4997,6 @@ def _map_iterate( sig.data = sig.data.compute(num_workers=max_workers) return sig - def get_block_pattern(self, args, output_shape): - arg_patterns = tuple(tuple(range(a.ndim)) for a in args) - arg_shapes = tuple(a.shape for a in args) - output_pattern = tuple(range(len(output_shape))) - all_ind = arg_shapes + (output_shape,) - max_len = max((len(i) for i in all_ind)) # max number of dimensions - max_arg_len = max((len(i) for i in arg_shapes)) - adjust_chunks = {} - new_axis = {} - output_shape = output_shape + (0,) * (max_len - len(output_shape)) - for i in range(max_len): - shapes = np.array( - [s[i] if len(s) > i else -1 for s in (output_shape,) + arg_shapes] - ) - is_equal_shape = shapes == shapes[0] # if in shapes == output shapes - if not all(is_equal_shape): - if i > max_arg_len - 1: - new_axis[i] = output_shape[i] - else: - adjust_chunks[i] = output_shape[i] - arg_pairs = [(a, p) for a, p in zip(args, arg_patterns)] - return arg_pairs, adjust_chunks, new_axis, output_pattern - def _get_iterating_kwargs(self, iterating_kwargs): nav_chunks = self.get_chunk_size(self.axes_manager.navigation_axes) args, arg_keys = (), () diff --git a/hyperspy/tests/signals/test_map_method.py b/hyperspy/tests/signals/test_map_method.py index cf1cda2f67..7b8e4a2e9a 100644 --- a/hyperspy/tests/signals/test_map_method.py +++ b/hyperspy/tests/signals/test_map_method.py @@ -27,6 +27,7 @@ from hyperspy.decorators import lazifyTestClass 
from hyperspy.exceptions import VisibleDeprecationWarning from hyperspy._signals.lazy import LazySignal +from hyperspy.misc.utils import _get_block_pattern @lazifyTestClass(ragged=False) @@ -651,7 +652,7 @@ def test_no_change_2d_signal(self, input_shape): chunks = (10,) * len(input_shape) dask_array = da.random.random(input_shape, chunks=chunks) s = hs.signals.Signal2D(dask_array).as_lazy() - arg_pairs, adjust_chunks, new_axis, output_pattern = s.get_block_pattern((s.data,), input_shape) + arg_pairs, adjust_chunks, new_axis, output_pattern = _get_block_pattern((s.data,), input_shape) assert new_axis == {} assert adjust_chunks == {} @@ -664,25 +665,25 @@ def test_no_change_1d_signal(self, input_shape): dask_array = da.random.random(input_shape, chunks=chunks) s = hs.signals.Signal1D(dask_array).as_lazy() output_signal_size = input_shape[-1:] - arg_pairs, adjust_chunks, new_axis, output_pattern = s.get_block_pattern((s.data,), input_shape) + arg_pairs, adjust_chunks, new_axis, output_pattern = _get_block_pattern((s.data,), input_shape) assert new_axis == {} assert adjust_chunks == {} def test_different_output_signal_size_signal2d(self): s = hs.signals.Signal2D(np.zeros((4, 5))) - arg_pairs, adjust_chunks, new_axis, output_pattern = s.get_block_pattern((s.data,), (1,)) + arg_pairs, adjust_chunks, new_axis, output_pattern = _get_block_pattern((s.data,), (1,)) assert new_axis == {} assert adjust_chunks == {0: 1, 1: 0} def test_different_output_signal_size_signal2d_2(self): s = hs.signals.Signal2D(np.zeros((7, 10, 5))) - arg_pairs, adjust_chunks, new_axis, output_pattern = s.get_block_pattern((s.data,), (7, 2)) + arg_pairs, adjust_chunks, new_axis, output_pattern = _get_block_pattern((s.data,), (7, 2)) assert new_axis == {} assert adjust_chunks == {1: 2, 2: 0} def test_different_output_signal_size_signal2d_3(self): s = hs.signals.Signal2D(np.zeros((3, 2, 7, 10, 5))) - arg_pairs, adjust_chunks, new_axis, output_pattern = s.get_block_pattern((s.data,),(3, 2, 5,)) + 
arg_pairs, adjust_chunks, new_axis, output_pattern = _get_block_pattern((s.data,), (3, 2, 5,)) assert new_axis == {} assert adjust_chunks == {2: 5, 3: 0, 4: 0}