Skip to content

Commit

Permalink
Added in test_rechunk_arguments function which tests the case where i…
Browse files Browse the repository at this point in the history
…terating signal chunks != signal chunks. Changes to get_iterating_kwargs to check for chunk spanning and rechunk if necessary.
  • Loading branch information
CSSFrancis committed Mar 8, 2022
1 parent da78103 commit e7b4437
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 3 deletions.
11 changes: 8 additions & 3 deletions hyperspy/signal.py
Original file line number Diff line number Diff line change
Expand Up @@ -2698,7 +2698,6 @@ def plot(self, navigator="auto", axes_manager=None, plot_markers=True,
if axes_manager.navigation_dimension == 0:
# 0d signal without navigation axis: don't make a figure
# and instead, we display the value
print(self.data)
return
self._plot = mpl_he.MPL_HyperExplorer()
elif axes_manager.signal_dimension == 1:
Expand Down Expand Up @@ -4902,6 +4901,7 @@ def _map_iterate(
_logger.info(
"The chunk size needs to span the full signal size, rechunking..."
)

old_sig = s_input.rechunk(inplace=False, nav_chunks=None)
else:
old_sig = s_input
Expand Down Expand Up @@ -5019,9 +5019,7 @@ def get_block_pattern(self, args, output_shape):
arg_pairs = [(a, p) for a, p in zip(args, arg_patterns)]
return arg_pairs, adjust_chunks, new_axis, output_pattern


def _get_iterating_kwargs(self, iterating_kwargs):
signal_dim_shape = self.axes_manager.signal_shape
nav_chunks = self.get_chunk_size(self.axes_manager.navigation_axes)
args, arg_keys = (), ()
for key in iterating_kwargs:
Expand All @@ -5036,6 +5034,13 @@ def _get_iterating_kwargs(self, iterating_kwargs):
axes = iterating_kwargs[key].axes_manager.navigation_axes
if iterating_kwargs[key].get_chunk_size(axes) != nav_chunks:
iterating_kwargs[key].rechunk(nav_chunks=nav_chunks, sig_chunks=-1)
chunk_span = np.equal(iterating_kwargs[key].data.chunksize,
iterating_kwargs[key].data.shape)
chunk_span = [
chunk_span[i] for i in iterating_kwargs[key].axes_manager.signal_indices_in_array
]
if not all(chunk_span):
iterating_kwargs[key].rechunk(nav_chunks=nav_chunks, sig_chunks=-1)
else:
iterating_kwargs[key] = iterating_kwargs[key].as_lazy()
iterating_kwargs[key].rechunk(nav_chunks=nav_chunks, sig_chunks=-1)
Expand Down
18 changes: 18 additions & 0 deletions hyperspy/tests/signals/test_map_method.py
Original file line number Diff line number Diff line change
Expand Up @@ -951,6 +951,24 @@ def test_crop_signal2d_lazy_all_input(self):
assert not np.any(s_out.data)
assert s_out.axes_manager.shape == (39, 28, 44, 40)

def test_rechunk_arguments(self):
chunk_shape = (2, 2, 2, 2, 2)

def add_sum(image, add1, add2):
temp_add = add1.sum(-1) + add2
out = image + np.sum(temp_add)
return out
x = np.ones((4, 5, 10, 11))
s = hs.signals.Signal2D(x)
s_add1 = hs.signals.BaseSignal(2 * np.ones((4, 5, 2, 3, 2))).transpose(3)
s_add2 = hs.signals.BaseSignal(3 * np.ones((4, 5, 2, 3))).transpose(2)

s = hs.signals.Signal2D(da.from_array(s.data, chunks=(2, 2, 2, 2))).as_lazy()
s_add1 = hs.signals.Signal2D(da.from_array(s_add1.data, chunks=chunk_shape)).as_lazy().transpose(
navigation_axes=(1, 2))
s_out = s.map(add_sum, inplace=False, add1=s_add1, add2=s_add2, lazy_output=False)
assert (s_out.axes_manager.shape == s.axes_manager.shape)


class TestLazyNavChunkSize1:
@staticmethod
Expand Down

0 comments on commit e7b4437

Please sign in to comment.