Skip to content

Commit

Permalink
Improve test coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
ericpre committed Jan 24, 2022
1 parent 3279f3b commit 4b03b14
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 22 deletions.
18 changes: 11 additions & 7 deletions hyperspy/signal.py
Expand Up @@ -44,10 +44,8 @@
import hyperspy.misc.utils
from hyperspy.misc.utils import DictionaryTreeBrowser
from hyperspy.drawing import signal as sigdraw
from hyperspy.defaults_parser import preferences
from hyperspy.misc.io.tools import ensure_directory
from hyperspy.misc.utils import iterable_not_string
from hyperspy.external.progressbar import progressbar
from hyperspy.exceptions import SignalDimensionError, DataDimensionError
from hyperspy.misc import rgb_tools
from hyperspy.misc.utils import underline, isiterable
Expand Down Expand Up @@ -5014,7 +5012,7 @@ def _get_drop_axis_new_axis(self, output_signal_size):

def _get_iterating_kwargs(self, iterating_kwargs):
signal_dim_shape = self.axes_manager.signal_shape
nav_chunks = self.get_chunk_size(axis=self.axes_manager.navigation_axes)
nav_chunks = self.get_chunk_size(self.axes_manager.navigation_axes)
args, arg_keys = (), ()
for key in iterating_kwargs:
if not isinstance(iterating_kwargs[key], BaseSignal):
Expand All @@ -5025,12 +5023,18 @@ def _get_iterating_kwargs(self, iterating_kwargs):
"Pass signal instances instead."
)
if iterating_kwargs[key]._lazy:
if iterating_kwargs[key].get_chunk_size(axis=iterating_kwargs[key].axes_manager.navigation_axes) !=\
nav_chunks:
iterating_kwargs[key].rechunk(nav_chunks=nav_chunks, sig_chunks=-1)
axes = iterating_kwargs[key].axes_manager.navigation_axes
if iterating_kwargs[key].get_chunk_size(axes) != nav_chunks:
iterating_kwargs[key].rechunk(
nav_chunks=nav_chunks,
sig_chunks=-1
)
else:
iterating_kwargs[key] = iterating_kwargs[key].as_lazy()
iterating_kwargs[key].rechunk(nav_chunks=nav_chunks, sig_chunks=-1)
iterating_kwargs[key].rechunk(
nav_chunks=nav_chunks,
sig_chunks=-1
)
extra_dims = (len(signal_dim_shape) -
len(iterating_kwargs[key].axes_manager.signal_shape))
if extra_dims > 0:
Expand Down
21 changes: 14 additions & 7 deletions hyperspy/tests/signals/test_lazy.py
Expand Up @@ -392,9 +392,9 @@ def test_signal1d(self):


class TestHTMLRep:
def test_html_rep(self):
sig = _signal()
sig._repr_html_()

def test_html_rep(self, signal):
signal._repr_html_()

def test_html_rep_zero_dim_nav(self):
s = hs.signals.BaseSignal(da.random.random((500, 1000))).as_lazy()
Expand All @@ -414,7 +414,14 @@ def test_get_chunk_string(self):
s_string = s._get_chunk_string()
assert (s_string == "(<b>6</b>,<b>6</b>|3,2)")

def test_get_chunk_size(self):
sig = _signal()
s = sig.get_chunk_size()
assert s == ((2, 1, 3), (4, 5))

def test_get_chunk_size(signal):
sig = signal
chunk_size = sig.get_chunk_size()
assert chunk_size == ((2, 1, 3), (4, 5))
assert sig.get_chunk_size(sig.axes_manager.navigation_axes) == chunk_size
assert sig.get_chunk_size([0, 1]) == chunk_size

sig = _signal()
chunk_size = sig.get_chunk_size(axes=0)
chunk_size == ((2, 1, 3), )
17 changes: 9 additions & 8 deletions hyperspy/tests/signals/test_map_method.py
Expand Up @@ -356,15 +356,15 @@ def test_map_nav_size_error(self):
def test_keep_navigation_chunks(self):
s = self.s
s_out = s.map(lambda x: x, inplace=False, lazy_output=True)
assert (s.get_chunk_size(axis=s.axes_manager.navigation_axes) ==
s_out.get_chunk_size(axis=s_out.axes_manager.navigation_axes))
assert (s.get_chunk_size(s.axes_manager.navigation_axes) ==
s_out.get_chunk_size(s_out.axes_manager.navigation_axes))

def test_keep_navigation_chunks_cropping(self):
s = self.s
s1 = s.inav[1:-2, 2:-1]
s_out = s1.map(lambda x: x, inplace=False, lazy_output=True)
assert (s1.get_chunk_size(axis=s1.axes_manager.navigation_axes) ==
s_out.get_chunk_size(axis=s_out.axes_manager.navigation_axes))
assert (s1.get_chunk_size(s1.axes_manager.navigation_axes) ==
s_out.get_chunk_size(s_out.axes_manager.navigation_axes))

@pytest.mark.parametrize("output_signal_size", [(3,), (3, 4), (3, 4, 5)])
def test_map_output_signal_size(self, output_signal_size):
Expand Down Expand Up @@ -581,7 +581,7 @@ def test_empty(self):

def test_one_iterating_kwarg(self):
s = self.s
nav_chunks = s.get_chunk_size(axis=s.axes_manager.navigation_axes)
nav_chunks = s.get_chunk_size(axes=s.axes_manager.navigation_axes)
nav_dim = len(nav_chunks)
s_iter0 = hs.signals.Signal1D(np.random.random((10, 20, 2)))
iterating_kwargs = {"iter0": s_iter0}
Expand All @@ -595,7 +595,7 @@ def test_one_iterating_kwarg(self):

def test_many_iterating_kwarg(self):
s = self.s
nav_chunks = s.get_chunk_size(axis=s.axes_manager.navigation_axes)
nav_chunks = s.get_chunk_size(axes=s.axes_manager.navigation_axes)
nav_dim = len(nav_chunks)
s_iter0 = hs.signals.Signal1D(np.random.random((10, 20, 2)))
s_iter1 = hs.signals.Signal2D(np.random.random((10, 20, 200, 200)))
Expand All @@ -613,7 +613,7 @@ def test_many_iterating_kwarg(self):

def test_lazy_iterating_kwarg(self):
s = self.s
nav_chunks = s.get_chunk_size(axis=s.axes_manager.navigation_axes)
nav_chunks = s.get_chunk_size(axes=s.axes_manager.navigation_axes)
nav_dim = len(nav_chunks)
dask_array_iter0 = da.zeros((10, 20, 2), chunks=(5, 10, 2))
dask_array_iter1 = da.zeros((10, 20, 2), chunks=(5, 5, 2))
Expand All @@ -629,7 +629,7 @@ def test_lazy_iterating_kwarg(self):

def test_cropping_iterating_kwarg(self):
s = self.s.inav[1:]
nav_chunks = s.get_chunk_size(axis=s.axes_manager.navigation_axes)
nav_chunks = s.get_chunk_size(axes=s.axes_manager.navigation_axes)
nav_dim = len(nav_chunks)
s_iter0 = hs.signals.Signal1D(np.random.random((10, 19, 2)))
iterating_kwargs = {"iter0": s_iter0}
Expand Down Expand Up @@ -1006,6 +1006,7 @@ def test_inplace(self):
assert s.data[0, -1] == 0.0
assert s.data[-1, 0] == 0.0
assert s.data[-1, -1] == 0.0
assert s_rot is None


class TestCompareMapAllvsMapIterate:
Expand Down

0 comments on commit 4b03b14

Please sign in to comment.