Skip to content

Commit

Permalink
Add axes argument to _set_defaults_value and use it in add_widget
Browse files Browse the repository at this point in the history
…, where the axes are already known
  • Loading branch information
ericpre committed Jun 9, 2021
1 parent 605f5a6 commit 94d87d2
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 38 deletions.
65 changes: 38 additions & 27 deletions hyperspy/roi.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ def _parse_axes(self, axes, axes_manager):
Parameters
----------
axes : specification of axes to use, default is None
axes : specification of axes to use
The axes argument specifies which axes the ROI will be applied on.
The axes in the collection can be either of the following:
Expand Down Expand Up @@ -354,7 +354,7 @@ def _apply_roi2widget(self, widget):
"""
raise NotImplementedError()

def _set_default_values(self, signal):
def _set_default_values(self, signal, axes=None):
"""When the ROI is called interactively with Undefined parameters,
use these values instead.
"""
Expand Down Expand Up @@ -482,11 +482,13 @@ def add_widget(self, signal, axes=None, widget=None, color='green',
kwargs:
All keyword argument are passed to the widget constructor.
"""

axes = self._parse_axes(axes, signal.axes_manager,)

# Undefined if roi initialised without specifying parameters
if t.Undefined in tuple(self):
self._set_default_values(signal)
self._set_default_values(signal, axes=axes)

axes = self._parse_axes(axes, signal.axes_manager,)
if widget is None:
widget = self._get_widget_type(
axes, signal)(
Expand Down Expand Up @@ -625,14 +627,15 @@ class Point1DROI(BasePointROI):
_ndim = 1

def __init__(self, value=None):
super(Point1DROI, self).__init__()
super().__init__()
value = value if value is not None else t.Undefined
self.value = value

def _set_default_values(self, signal):
ax0, *_ = self._parse_axes(None, signal.axes_manager)
def _set_default_values(self, signal, axes=None):
if axes is None:
axes = self._parse_axes(None, signal.axes_manager)
# If roi parameters are undefined, use center of axes
self.value = ax0._parse_value('rel0.5')
self.value = axes[0]._parse_value('rel0.5')

@property
def parameters(self):
Expand Down Expand Up @@ -684,18 +687,19 @@ class Point2DROI(BasePointROI):
_ndim = 2

def __init__(self, x=None, y=None):
super(Point2DROI, self).__init__()
super().__init__()
x, y = (
para if para is not None
else t.Undefined for para in (x, y))

self.x, self.y = x, y

def _set_default_values(self, signal):
ax0, ax1 = self._parse_axes(None, signal.axes_manager)
def _set_default_values(self, signal, axes=None):
if axes is None:
axes = self._parse_axes(None, signal.axes_manager)
# If roi parameters are undefined, use center of axes
self.x = ax0._parse_value("rel0.5")
self.y = ax1._parse_value("rel0.5")
self.x = axes[0]._parse_value("rel0.5")
self.y = axes[1]._parse_value("rel0.5")

@property
def parameters(self):
Expand Down Expand Up @@ -751,10 +755,11 @@ def __init__(self, left=None, right=None):
else t.Undefined for para in (left, right))
self.left, self.right = left, right

def _set_default_values(self, signal):
ax0, *_ = self._parse_axes(None, signal.axes_manager)
def _set_default_values(self, signal, axes=None):
if axes is None:
axes = self._parse_axes(None, signal.axes_manager)
# If roi parameters are undefined, use center of axes
self.left, self.right = _get_central_half_limits_of_axis(ax0)
self.left, self.right = _get_central_half_limits_of_axis(axes[0])

@property
def parameters(self):
Expand Down Expand Up @@ -835,14 +840,16 @@ def __getitem__(self, *args, **kwargs):
_tuple = (self.left, self.right, self.top, self.bottom)
return _tuple.__getitem__(*args, **kwargs)

def _set_default_values(self, signal):
def _set_default_values(self, signal, axes=None):
# Need to turn of bounds checking or undefined values trigger error
old_bounds_check = self._bounds_check
self._bounds_check = False
ax0, ax1 = self._parse_axes(None, signal.axes_manager)
if axes is None:
axes = self._parse_axes(None, signal.axes_manager)

# If roi parameters are undefined, use center of axes
self.left, self.right = _get_central_half_limits_of_axis(ax0)
self.top, self.bottom = _get_central_half_limits_of_axis(ax1)
self.left, self.right = _get_central_half_limits_of_axis(axes[0])
self.top, self.bottom = _get_central_half_limits_of_axis(axes[1])
self._bounds_check = old_bounds_check

@property
Expand Down Expand Up @@ -984,8 +991,11 @@ def __init__(self, cx=None, cy=None, r=None, r_inner=0):
self._bounds_check = True # Use reponsibly!
self.cx, self.cy, self.r, self.r_inner = cx, cy, r, r_inner

def _set_default_values(self, signal):
ax0, ax1 = self._parse_axes(None, signal.axes_manager)
def _set_default_values(self, signal, axes=None):
if axes is None:
axes = self._parse_axes(None, signal.axes_manager)
ax0, ax1 = axes

# If roi parameters are undefined, use center of axes
self.cx = ax0._parse_value('rel0.5')
self.cy = ax1._parse_value('rel0.5')
Expand Down Expand Up @@ -1148,19 +1158,20 @@ class Line2DROI(BaseInteractiveROI):
_ndim = 2

def __init__(self, x1=None, y1=None, x2=None, y2=None, linewidth=0):
super(Line2DROI, self).__init__()
super().__init__()
x1, y1, x2, y2 = (
para if para is not None
else t.Undefined for para in (x1, y1, x2, y2))

self.x1, self.y1, self.x2, self.y2 = x1, y1, x2, y2
self.linewidth = linewidth

def _set_default_values(self, signal):
ax0, ax1 = self._parse_axes(None, signal.axes_manager)
def _set_default_values(self, signal, axes=None):
if axes is None:
axes = self._parse_axes(None, signal.axes_manager)
# If roi parameters are undefined, use center of axes
self.x1, self.x2 = _get_central_half_limits_of_axis(ax0)
self.y1, self.y2 = _get_central_half_limits_of_axis(ax1)
self.x1, self.x2 = _get_central_half_limits_of_axis(axes[0])
self.y1, self.y2 = _get_central_half_limits_of_axis(axes[1])

@property
def parameters(self):
Expand Down
34 changes: 23 additions & 11 deletions hyperspy/tests/utils/test_roi.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,18 +157,21 @@ def test_spanroi_getitem(self):
r = SpanROI(15, 30)
assert tuple(r) == (15, 30)

def test_widget_initialisation(self):
self.s_s.plot()
for roi in [Point1DROI, Point2DROI, RectangularROI, SpanROI, Line2DROI, CircleROI]:
r = roi()
r._set_default_values(self.s_s)
r.add_widget(self.s_s)

def test_add_widget_ROI_undefined(self):
@pytest.mark.parametrize('axes', [None, 'signal'])
def test_add_widget_ROI_undefined(self, axes):
s = self.s_i
s.plot()
line = Line2DROI()
line.add_widget(s)
if axes == 'signal':
axes = s.axes_manager.signal_axes
for roi in [Point1DROI, Point2DROI, RectangularROI, SpanROI, Line2DROI,
CircleROI]:
r = roi()
r.add_widget(s, axes=axes)
if axes is None:
expected_axes = s.axes_manager.navigation_axes
else:
expected_axes = axes
r.signal_map[s][1][0] in expected_axes

def test_span_spectrum_sig(self):
s = self.s_s
Expand Down Expand Up @@ -527,11 +530,20 @@ def test_undefined_call(self):
r(self.s_s)

def test_default_values_call(self):
for roi in [Point1DROI, Point2DROI, RectangularROI, SpanROI, Line2DROI, CircleROI]:
for roi in [Point1DROI, Point2DROI, RectangularROI, SpanROI, Line2DROI,
CircleROI]:
r = roi()
r._set_default_values(self.s_s)
r(self.s_s)

def test_default_values_call_specify_signal_axes(self):
s = self.s_i
for roi in [Point1DROI, Point2DROI, RectangularROI, SpanROI, Line2DROI,
CircleROI]:
r = roi()
r._set_default_values(s, axes=s.axes_manager.signal_axes)
r(s)

def test_get_central_half_limits(self):
ax = self.s_s.axes_manager[0]
assert _get_central_half_limits_of_axis(ax) == (73.75, 221.25)
Expand Down

0 comments on commit 94d87d2

Please sign in to comment.