Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix add_widget for undefined roi #2767

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion hyperspy/drawing/_widgets/line2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def _set_axes(self, axes):
# _set_axes overwrites self._size so we back it up
size = self._size
position = self._pos
super(Line2DWidget, self)._set_axes(axes)
super()._set_axes(axes)
# Restore self._size
self._size = size
self._pos = position
Expand Down
71 changes: 42 additions & 29 deletions hyperspy/roi.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ def _parse_axes(self, axes, axes_manager):

Parameters
----------
axes : specification of axes to use, default is None
axes : specification of axes to use
The axes argument specifies which axes the ROI will be applied on.
The axes in the collection can be either of the following:

Expand Down Expand Up @@ -354,7 +354,7 @@ def _apply_roi2widget(self, widget):
"""
raise NotImplementedError()

def _set_default_values(self, signal):
def _set_default_values(self, signal, axes=None):
"""When the ROI is called interactively with Undefined parameters,
use these values instead.
"""
Expand Down Expand Up @@ -406,9 +406,7 @@ def interactive(self, signal, navigation_signal="same", out=None,
'fft_shift', False):
raise NotImplementedError('ROIs are not supported when data '
'are shifted during plotting.')
# Undefined if roi initialised without specifying parameters
if t.Undefined in tuple(self):
self._set_default_values(signal)

if isinstance(navigation_signal, str) and navigation_signal == "same":
navigation_signal = signal
if navigation_signal is not None:
Expand Down Expand Up @@ -484,7 +482,13 @@ def add_widget(self, signal, axes=None, widget=None, color='green',
kwargs:
All keyword argument are passed to the widget constructor.
"""

axes = self._parse_axes(axes, signal.axes_manager,)

# Undefined if roi initialised without specifying parameters
if t.Undefined in tuple(self):
self._set_default_values(signal, axes=axes)

if widget is None:
widget = self._get_widget_type(
axes, signal)(
Expand Down Expand Up @@ -623,14 +627,15 @@ class Point1DROI(BasePointROI):
_ndim = 1

def __init__(self, value=None):
super(Point1DROI, self).__init__()
super().__init__()
value = value if value is not None else t.Undefined
self.value = value

def _set_default_values(self, signal):
ax0, *_ = self._parse_axes(None, signal.axes_manager)
def _set_default_values(self, signal, axes=None):
if axes is None:
axes = self._parse_axes(None, signal.axes_manager)
# If roi parameters are undefined, use center of axes
self.value = ax0._parse_value('rel0.5')
self.value = axes[0]._parse_value('rel0.5')

@property
def parameters(self):
Expand Down Expand Up @@ -682,18 +687,19 @@ class Point2DROI(BasePointROI):
_ndim = 2

def __init__(self, x=None, y=None):
super(Point2DROI, self).__init__()
super().__init__()
x, y = (
para if para is not None
else t.Undefined for para in (x, y))

self.x, self.y = x, y

def _set_default_values(self, signal):
ax0, ax1 = self._parse_axes(None, signal.axes_manager)
def _set_default_values(self, signal, axes=None):
if axes is None:
axes = self._parse_axes(None, signal.axes_manager)
# If roi parameters are undefined, use center of axes
self.x = ax0._parse_value("rel0.5")
self.y = ax1._parse_value("rel0.5")
self.x = axes[0]._parse_value("rel0.5")
self.y = axes[1]._parse_value("rel0.5")

@property
def parameters(self):
Expand Down Expand Up @@ -749,10 +755,11 @@ def __init__(self, left=None, right=None):
else t.Undefined for para in (left, right))
self.left, self.right = left, right

def _set_default_values(self, signal):
ax0, *_ = self._parse_axes(None, signal.axes_manager)
def _set_default_values(self, signal, axes=None):
if axes is None:
axes = self._parse_axes(None, signal.axes_manager)
# If roi parameters are undefined, use center of axes
self.left, self.right = _get_central_half_limits_of_axis(ax0)
self.left, self.right = _get_central_half_limits_of_axis(axes[0])

@property
def parameters(self):
Expand Down Expand Up @@ -833,14 +840,16 @@ def __getitem__(self, *args, **kwargs):
_tuple = (self.left, self.right, self.top, self.bottom)
return _tuple.__getitem__(*args, **kwargs)

def _set_default_values(self, signal):
def _set_default_values(self, signal, axes=None):
# Need to turn of bounds checking or undefined values trigger error
old_bounds_check = self._bounds_check
self._bounds_check = False
ax0, ax1 = self._parse_axes(None, signal.axes_manager)
if axes is None:
axes = self._parse_axes(None, signal.axes_manager)

# If roi parameters are undefined, use center of axes
self.left, self.right = _get_central_half_limits_of_axis(ax0)
self.top, self.bottom = _get_central_half_limits_of_axis(ax1)
self.left, self.right = _get_central_half_limits_of_axis(axes[0])
self.top, self.bottom = _get_central_half_limits_of_axis(axes[1])
self._bounds_check = old_bounds_check

@property
Expand Down Expand Up @@ -982,8 +991,11 @@ def __init__(self, cx=None, cy=None, r=None, r_inner=0):
self._bounds_check = True # Use reponsibly!
self.cx, self.cy, self.r, self.r_inner = cx, cy, r, r_inner

def _set_default_values(self, signal):
ax0, ax1 = self._parse_axes(None, signal.axes_manager)
def _set_default_values(self, signal, axes=None):
if axes is None:
axes = self._parse_axes(None, signal.axes_manager)
ax0, ax1 = axes

# If roi parameters are undefined, use center of axes
self.cx = ax0._parse_value('rel0.5')
self.cy = ax1._parse_value('rel0.5')
Expand Down Expand Up @@ -1146,19 +1158,20 @@ class Line2DROI(BaseInteractiveROI):
_ndim = 2

def __init__(self, x1=None, y1=None, x2=None, y2=None, linewidth=0):
super(Line2DROI, self).__init__()
super().__init__()
x1, y1, x2, y2 = (
para if para is not None
else t.Undefined for para in (x1, y1, x2, y2))

self.x1, self.y1, self.x2, self.y2 = x1, y1, x2, y2
self.linewidth = linewidth

def _set_default_values(self, signal):
ax0, ax1 = self._parse_axes(None, signal.axes_manager)
def _set_default_values(self, signal, axes=None):
if axes is None:
axes = self._parse_axes(None, signal.axes_manager)
# If roi parameters are undefined, use center of axes
self.x1, self.x2 = _get_central_half_limits_of_axis(ax0)
self.y1, self.y2 = _get_central_half_limits_of_axis(ax1)
self.x1, self.x2 = _get_central_half_limits_of_axis(axes[0])
self.y1, self.y2 = _get_central_half_limits_of_axis(axes[1])

@property
def parameters(self):
Expand Down Expand Up @@ -1459,4 +1472,4 @@ def __call__(self, signal, out=None, axes=None, order=0):

def _get_central_half_limits_of_axis(ax):
"Return indices of the central half of a DataAxis"
return ax._parse_value("rel0.25"), ax._parse_value("rel0.75")
return ax._parse_value("rel0.25"), ax._parse_value("rel0.75")
32 changes: 25 additions & 7 deletions hyperspy/tests/utils/test_roi.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,12 +157,21 @@ def test_spanroi_getitem(self):
r = SpanROI(15, 30)
assert tuple(r) == (15, 30)

def test_widget_initialisation(self):
self.s_s.plot()
for roi in [Point1DROI, Point2DROI, RectangularROI, SpanROI, Line2DROI, CircleROI]:
@pytest.mark.parametrize('axes', [None, 'signal'])
def test_add_widget_ROI_undefined(self, axes):
s = self.s_i
s.plot()
if axes == 'signal':
axes = s.axes_manager.signal_axes
for roi in [Point1DROI, Point2DROI, RectangularROI, SpanROI, Line2DROI,
CircleROI]:
r = roi()
r._set_default_values(self.s_s)
r.add_widget(self.s_s)
r.add_widget(s, axes=axes)
if axes is None:
expected_axes = s.axes_manager.navigation_axes
else:
expected_axes = axes
r.signal_map[s][1][0] in expected_axes

def test_span_spectrum_sig(self):
s = self.s_s
Expand Down Expand Up @@ -491,7 +500,7 @@ def test_line2droi_angle(self):
r.angle(axis='z')

def test_repr_None(self):
# Setting the args=None sets them as traits.Undefined, which didn't
# Setting the args=None sets them as traits.Undefined, which didn't
# have a string representation in the old %s style.
for roi in [Point1DROI, Point2DROI, RectangularROI, SpanROI]:
r = roi()
Expand Down Expand Up @@ -521,11 +530,20 @@ def test_undefined_call(self):
r(self.s_s)

def test_default_values_call(self):
for roi in [Point1DROI, Point2DROI, RectangularROI, SpanROI, Line2DROI, CircleROI]:
for roi in [Point1DROI, Point2DROI, RectangularROI, SpanROI, Line2DROI,
CircleROI]:
r = roi()
r._set_default_values(self.s_s)
r(self.s_s)

def test_default_values_call_specify_signal_axes(self):
s = self.s_i
for roi in [Point1DROI, Point2DROI, RectangularROI, SpanROI, Line2DROI,
CircleROI]:
r = roi()
r._set_default_values(s, axes=s.axes_manager.signal_axes)
r(s)

def test_get_central_half_limits(self):
ax = self.s_s.axes_manager[0]
assert _get_central_half_limits_of_axis(ax) == (73.75, 221.25)
Expand Down