Skip to content

Commit

Permalink
Merge pull request #2767 from ericpre/fix_add_widget_undefined_ROI
Browse files Browse the repository at this point in the history
Fix `add_widget` for undefined roi
  • Loading branch information
jlaehne committed Jun 30, 2021
2 parents a416442 + 94d87d2 commit eb71e12
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 37 deletions.
2 changes: 1 addition & 1 deletion hyperspy/drawing/_widgets/line2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def _set_axes(self, axes):
# _set_axes overwrites self._size so we back it up
size = self._size
position = self._pos
super(Line2DWidget, self)._set_axes(axes)
super()._set_axes(axes)
# Restore self._size
self._size = size
self._pos = position
Expand Down
71 changes: 42 additions & 29 deletions hyperspy/roi.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ def _parse_axes(self, axes, axes_manager):
Parameters
----------
axes : specification of axes to use, default is None
axes : specification of axes to use
The axes argument specifies which axes the ROI will be applied on.
The axes in the collection can be either of the following:
Expand Down Expand Up @@ -354,7 +354,7 @@ def _apply_roi2widget(self, widget):
"""
raise NotImplementedError()

def _set_default_values(self, signal):
def _set_default_values(self, signal, axes=None):
"""When the ROI is called interactively with Undefined parameters,
use these values instead.
"""
Expand Down Expand Up @@ -406,9 +406,7 @@ def interactive(self, signal, navigation_signal="same", out=None,
'fft_shift', False):
raise NotImplementedError('ROIs are not supported when data '
'are shifted during plotting.')
# Undefined if roi initialised without specifying parameters
if t.Undefined in tuple(self):
self._set_default_values(signal)

if isinstance(navigation_signal, str) and navigation_signal == "same":
navigation_signal = signal
if navigation_signal is not None:
Expand Down Expand Up @@ -484,7 +482,13 @@ def add_widget(self, signal, axes=None, widget=None, color='green',
kwargs:
All keyword argument are passed to the widget constructor.
"""

axes = self._parse_axes(axes, signal.axes_manager,)

# Undefined if roi initialised without specifying parameters
if t.Undefined in tuple(self):
self._set_default_values(signal, axes=axes)

if widget is None:
widget = self._get_widget_type(
axes, signal)(
Expand Down Expand Up @@ -623,14 +627,15 @@ class Point1DROI(BasePointROI):
_ndim = 1

def __init__(self, value=None):
super(Point1DROI, self).__init__()
super().__init__()
value = value if value is not None else t.Undefined
self.value = value

def _set_default_values(self, signal):
ax0, *_ = self._parse_axes(None, signal.axes_manager)
def _set_default_values(self, signal, axes=None):
if axes is None:
axes = self._parse_axes(None, signal.axes_manager)
# If roi parameters are undefined, use center of axes
self.value = ax0._parse_value('rel0.5')
self.value = axes[0]._parse_value('rel0.5')

@property
def parameters(self):
Expand Down Expand Up @@ -682,18 +687,19 @@ class Point2DROI(BasePointROI):
_ndim = 2

def __init__(self, x=None, y=None):
super(Point2DROI, self).__init__()
super().__init__()
x, y = (
para if para is not None
else t.Undefined for para in (x, y))

self.x, self.y = x, y

def _set_default_values(self, signal):
ax0, ax1 = self._parse_axes(None, signal.axes_manager)
def _set_default_values(self, signal, axes=None):
if axes is None:
axes = self._parse_axes(None, signal.axes_manager)
# If roi parameters are undefined, use center of axes
self.x = ax0._parse_value("rel0.5")
self.y = ax1._parse_value("rel0.5")
self.x = axes[0]._parse_value("rel0.5")
self.y = axes[1]._parse_value("rel0.5")

@property
def parameters(self):
Expand Down Expand Up @@ -749,10 +755,11 @@ def __init__(self, left=None, right=None):
else t.Undefined for para in (left, right))
self.left, self.right = left, right

def _set_default_values(self, signal):
ax0, *_ = self._parse_axes(None, signal.axes_manager)
def _set_default_values(self, signal, axes=None):
if axes is None:
axes = self._parse_axes(None, signal.axes_manager)
# If roi parameters are undefined, use center of axes
self.left, self.right = _get_central_half_limits_of_axis(ax0)
self.left, self.right = _get_central_half_limits_of_axis(axes[0])

@property
def parameters(self):
Expand Down Expand Up @@ -833,14 +840,16 @@ def __getitem__(self, *args, **kwargs):
_tuple = (self.left, self.right, self.top, self.bottom)
return _tuple.__getitem__(*args, **kwargs)

def _set_default_values(self, signal):
def _set_default_values(self, signal, axes=None):
# Need to turn of bounds checking or undefined values trigger error
old_bounds_check = self._bounds_check
self._bounds_check = False
ax0, ax1 = self._parse_axes(None, signal.axes_manager)
if axes is None:
axes = self._parse_axes(None, signal.axes_manager)

# If roi parameters are undefined, use center of axes
self.left, self.right = _get_central_half_limits_of_axis(ax0)
self.top, self.bottom = _get_central_half_limits_of_axis(ax1)
self.left, self.right = _get_central_half_limits_of_axis(axes[0])
self.top, self.bottom = _get_central_half_limits_of_axis(axes[1])
self._bounds_check = old_bounds_check

@property
Expand Down Expand Up @@ -982,8 +991,11 @@ def __init__(self, cx=None, cy=None, r=None, r_inner=0):
self._bounds_check = True # Use reponsibly!
self.cx, self.cy, self.r, self.r_inner = cx, cy, r, r_inner

def _set_default_values(self, signal):
ax0, ax1 = self._parse_axes(None, signal.axes_manager)
def _set_default_values(self, signal, axes=None):
if axes is None:
axes = self._parse_axes(None, signal.axes_manager)
ax0, ax1 = axes

# If roi parameters are undefined, use center of axes
self.cx = ax0._parse_value('rel0.5')
self.cy = ax1._parse_value('rel0.5')
Expand Down Expand Up @@ -1146,19 +1158,20 @@ class Line2DROI(BaseInteractiveROI):
_ndim = 2

def __init__(self, x1=None, y1=None, x2=None, y2=None, linewidth=0):
super(Line2DROI, self).__init__()
super().__init__()
x1, y1, x2, y2 = (
para if para is not None
else t.Undefined for para in (x1, y1, x2, y2))

self.x1, self.y1, self.x2, self.y2 = x1, y1, x2, y2
self.linewidth = linewidth

def _set_default_values(self, signal):
ax0, ax1 = self._parse_axes(None, signal.axes_manager)
def _set_default_values(self, signal, axes=None):
if axes is None:
axes = self._parse_axes(None, signal.axes_manager)
# If roi parameters are undefined, use center of axes
self.x1, self.x2 = _get_central_half_limits_of_axis(ax0)
self.y1, self.y2 = _get_central_half_limits_of_axis(ax1)
self.x1, self.x2 = _get_central_half_limits_of_axis(axes[0])
self.y1, self.y2 = _get_central_half_limits_of_axis(axes[1])

@property
def parameters(self):
Expand Down Expand Up @@ -1459,4 +1472,4 @@ def __call__(self, signal, out=None, axes=None, order=0):

def _get_central_half_limits_of_axis(ax):
"Return indices of the central half of a DataAxis"
return ax._parse_value("rel0.25"), ax._parse_value("rel0.75")
return ax._parse_value("rel0.25"), ax._parse_value("rel0.75")
32 changes: 25 additions & 7 deletions hyperspy/tests/utils/test_roi.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,12 +157,21 @@ def test_spanroi_getitem(self):
r = SpanROI(15, 30)
assert tuple(r) == (15, 30)

def test_widget_initialisation(self):
self.s_s.plot()
for roi in [Point1DROI, Point2DROI, RectangularROI, SpanROI, Line2DROI, CircleROI]:
@pytest.mark.parametrize('axes', [None, 'signal'])
def test_add_widget_ROI_undefined(self, axes):
s = self.s_i
s.plot()
if axes == 'signal':
axes = s.axes_manager.signal_axes
for roi in [Point1DROI, Point2DROI, RectangularROI, SpanROI, Line2DROI,
CircleROI]:
r = roi()
r._set_default_values(self.s_s)
r.add_widget(self.s_s)
r.add_widget(s, axes=axes)
if axes is None:
expected_axes = s.axes_manager.navigation_axes
else:
expected_axes = axes
r.signal_map[s][1][0] in expected_axes

def test_span_spectrum_sig(self):
s = self.s_s
Expand Down Expand Up @@ -491,7 +500,7 @@ def test_line2droi_angle(self):
r.angle(axis='z')

def test_repr_None(self):
# Setting the args=None sets them as traits.Undefined, which didn't
# Setting the args=None sets them as traits.Undefined, which didn't
# have a string representation in the old %s style.
for roi in [Point1DROI, Point2DROI, RectangularROI, SpanROI]:
r = roi()
Expand Down Expand Up @@ -521,11 +530,20 @@ def test_undefined_call(self):
r(self.s_s)

def test_default_values_call(self):
for roi in [Point1DROI, Point2DROI, RectangularROI, SpanROI, Line2DROI, CircleROI]:
for roi in [Point1DROI, Point2DROI, RectangularROI, SpanROI, Line2DROI,
CircleROI]:
r = roi()
r._set_default_values(self.s_s)
r(self.s_s)

def test_default_values_call_specify_signal_axes(self):
s = self.s_i
for roi in [Point1DROI, Point2DROI, RectangularROI, SpanROI, Line2DROI,
CircleROI]:
r = roi()
r._set_default_values(s, axes=s.axes_manager.signal_axes)
r(s)

def test_get_central_half_limits(self):
ax = self.s_s.axes_manager[0]
assert _get_central_half_limits_of_axis(ax) == (73.75, 221.25)
Expand Down

0 comments on commit eb71e12

Please sign in to comment.