Skip to content

Commit

Permalink
Merge pull request #3199 from CSSFrancis/n-d-navigator
Browse files Browse the repository at this point in the history
N d navigator
  • Loading branch information
ericpre committed Feb 7, 2024
2 parents 4039c37 + d7d2cf5 commit 950741a
Show file tree
Hide file tree
Showing 5 changed files with 145 additions and 28 deletions.
10 changes: 3 additions & 7 deletions hyperspy/_signals/lazy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1221,13 +1221,9 @@ def plot(self, navigator="auto", **kwargs):
)
navigator = "auto"
if navigator == "auto":
nav_dim = self.axes_manager.navigation_dimension
if nav_dim in [1, 2]:
if self.navigator is None:
self.compute_navigator()
navigator = self.navigator
elif nav_dim > 2:
navigator = "slider"
if self.navigator is None:
self.compute_navigator()
navigator = self.navigator
super().plot(navigator=navigator, **kwargs)

def compute_navigator(self, index=None, chunks_number=None, show_progressbar=None):
Expand Down
38 changes: 20 additions & 18 deletions hyperspy/signal.py
Original file line number Diff line number Diff line change
Expand Up @@ -3056,15 +3056,18 @@ def get_1D_sum_explorer_wrapper(*args, **kwargs):
navigator = sum_wrapper(self, am.signal_axes + am.navigation_axes[1:])
return np.nan_to_num(to_numpy(navigator.data)).squeeze()

def get_dynamic_explorer_wrapper(*args, **kwargs):
navigator.axes_manager.indices = self.axes_manager.indices[
navigator.axes_manager.signal_dimension :
]
navigator.axes_manager._update_attributes()
if np.issubdtype(navigator._get_current_data().dtype, np.complexfloating):
return abs(navigator._get_current_data(as_numpy=True))
else:
return navigator(as_numpy=True)
def get_dynamic_image_explorer(*args, **kwargs):
am = self.axes_manager
nav_ind = am.indices[2:] # image at first 2 nav indices
slices = [slice(None)] * len(am.navigation_axes)
slices[2:] = nav_ind
new_nav = navigator.transpose(
signal_axes=len(am.navigation_axes)
) # roll axes to signal axes
ind = new_nav.isig.__getitem__(
slices=slices
) # Get the value from the nav reverse because hyperspy
return np.nan_to_num(to_numpy(ind.data)).squeeze()

if not isinstance(navigator, BaseSignal) and navigator == "auto":
if self.navigator is not None:
Expand Down Expand Up @@ -3124,21 +3127,20 @@ def is_shape_compatible(navigation_shape, shape):
if is_shape_compatible(
axes_manager.navigation_shape, navigator.axes_manager.signal_shape
):
self._plot.navigator_data_function = get_static_explorer_wrapper
if len(axes_manager.navigation_shape) > 2:
self._plot.navigator_data_function = get_dynamic_image_explorer
else:
self._plot.navigator_data_function = get_static_explorer_wrapper
# Static transposed navigator
elif is_shape_compatible(
axes_manager.navigation_shape,
navigator.axes_manager.navigation_shape,
):
navigator = navigator.T
self._plot.navigator_data_function = get_static_explorer_wrapper
# Dynamic navigator
elif (
axes_manager.navigation_shape
== navigator.axes_manager.signal_shape
+ navigator.axes_manager.navigation_shape
):
self._plot.navigator_data_function = get_dynamic_explorer_wrapper
if len(axes_manager.navigation_shape) > 2:
self._plot.navigator_data_function = get_dynamic_image_explorer
else:
self._plot.navigator_data_function = get_static_explorer_wrapper
else:
raise ValueError(
"The dimensions of the provided (or stored) navigator "
Expand Down
4 changes: 1 addition & 3 deletions hyperspy/tests/drawing/test_plot_lazy.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,9 @@ def test_plot_lazy(ndim):

if ndim == 0:
assert s._plot.navigator_data_function is None
elif ndim in [1, 2]:
else:
assert s.navigator.data.shape == tuple([N] * ndim)
assert isinstance(s.navigator, hs.signals.BaseSignal)
else:
assert s._plot.navigator_data_function == "slider"


@pytest.mark.parametrize(
Expand Down
120 changes: 120 additions & 0 deletions hyperspy/tests/drawing/test_plot_signal2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@
import hyperspy.api as hs
from hyperspy.drawing.utils import make_cmap, plot_RGB_map
from hyperspy.tests.drawing.test_plot_signal import _TestPlot
from hyperspy.decorators import lazifyTestClass


scalebar_color = "blue"
default_tol = 2.0
Expand Down Expand Up @@ -831,3 +833,121 @@ def test_plot_images_bool():
s = hs.signals.Signal2D(data)

hs.plot.plot_images(s)


def test_plot_static_signal_nav():
s = hs.signals.Signal2D(np.ones((20, 20, 10, 10)))
nav = hs.signals.Signal2D(np.ones((20, 20)))
s.plot(navigator=nav)


@lazifyTestClass
class TestDynamicNavigatorPlot:
def setup_method(self, method):
self.signal5d2d = hs.signals.Signal2D(
np.arange((10**5)).reshape(
(
10,
10,
10,
10,
10,
)
)
)
self.signal6d2d = hs.signals.Signal2D(
np.arange((10**6)).reshape((10, 10, 10, 10, 10, 10))
)
self.signal4d1d = hs.signals.Signal1D(
np.arange((10**4)).reshape((10, 10, 10, 10))
)
self.signal5d1d = hs.signals.Signal1D(
np.arange((10**5)).reshape(
(
10,
10,
10,
10,
10,
)
)
)

def test_plot_5d(self):
import hyperspy.api as hs
import numpy as np

s = self.signal5d2d
nav = hs.signals.BaseSignal(np.arange((10 * 10 * 10)).reshape(10, 10, 10))
s.plot(navigator=nav)
data1 = s._plot.navigator_plot._current_data
s.axes_manager.indices = (0, 0, 1)
data2 = s._plot.navigator_plot._current_data
assert not np.array_equal(data1, data2)
s.axes_manager.indices = (0, 2, 1)
data3 = s._plot.navigator_plot._current_data
assert np.array_equal(data2, data3)

def test_plot_5d_2(self):
import hyperspy.api as hs
import numpy as np

s = self.signal5d2d
nav = hs.signals.Signal2D(np.arange((10 * 10 * 10)).reshape(10, 10, 10))
s.plot(navigator=nav)
data1 = s._plot.navigator_plot._current_data
s.axes_manager.indices = (0, 0, 1)
data2 = s._plot.navigator_plot._current_data
assert not np.array_equal(data1, data2)
s.axes_manager.indices = (0, 2, 1)
data3 = s._plot.navigator_plot._current_data
assert np.array_equal(data2, data3)

def test_plot_6d(self):
import hyperspy.api as hs
import numpy as np

s = self.signal6d2d
nav = hs.signals.BaseSignal(
np.arange((10 * 10 * 10 * 10)).reshape(10, 10, 10, 10)
)
s.plot(navigator=nav)
data1 = s._plot.navigator_plot._current_data
s.axes_manager.indices = (0, 0, 0, 1)
data2 = s._plot.navigator_plot._current_data
assert not np.array_equal(data1, data2)
s.axes_manager.indices = (0, 2, 0, 1)
data3 = s._plot.navigator_plot._current_data
assert np.array_equal(data2, data3)

def test_plot_4d_1dSignal(self):
import hyperspy.api as hs
import numpy as np

s = self.signal4d1d
nav = hs.signals.BaseSignal(np.arange((10 * 10 * 10)).reshape(10, 10, 10))
s.plot(navigator=nav)
data1 = s._plot.navigator_plot._current_data
s.axes_manager.indices = (0, 0, 1)
data2 = s._plot.navigator_plot._current_data
assert not np.array_equal(data1, data2)
s.axes_manager.indices = (0, 2, 1)
data3 = s._plot.navigator_plot._current_data
assert np.array_equal(data2, data3)

def test_plot_5d_1dsignal(self):
import hyperspy.api as hs
import numpy as np

s = self.signal5d1d
nav = hs.signals.BaseSignal(
np.arange((10 * 10 * 10 * 10)).reshape(10, 10, 10, 10)
)
s.plot(navigator=nav)
data1 = s._plot.navigator_plot._current_data
s.axes_manager.indices = (0, 0, 0, 1)
data2 = s._plot.navigator_plot._current_data
assert not np.array_equal(data1, data2)
s.axes_manager.indices = (0, 2, 0, 1)
data3 = s._plot.navigator_plot._current_data
assert np.array_equal(data2, data3)
1 change: 1 addition & 0 deletions upcoming_changes/3199.enhancements.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add an dynamic navigator which updates when the number of navigation dimensions is greater than 3

0 comments on commit 950741a

Please sign in to comment.