Skip to content

Commit

Permalink
Merge pull request #2830 from ericpre/protect_axes_manager_attribute
Browse files Browse the repository at this point in the history
Protect axes manager attribute
  • Loading branch information
jlaehne committed Apr 15, 2022
2 parents 9c3d8c9 + 6f3b5f2 commit 3e89e22
Show file tree
Hide file tree
Showing 28 changed files with 212 additions and 127 deletions.
3 changes: 0 additions & 3 deletions hyperspy/_signals/complex_signal1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,6 @@ class ComplexSignal1D(ComplexSignal, CommonSignal1D):

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if self.axes_manager.signal_dimension != 1:
self.axes_manager.set_signal_dimension(1)


class LazyComplexSignal1D(ComplexSignal1D, LazyComplexSignal):

Expand Down
2 changes: 0 additions & 2 deletions hyperspy/_signals/complex_signal2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,6 @@ class ComplexSignal2D(ComplexSignal, CommonSignal2D):

def __init__(self, *args, **kw):
super().__init__(*args, **kw)
if self.axes_manager.signal_dimension != 2:
self.axes_manager.set_signal_dimension(2)

def add_phase_ramp(self, ramp_x, ramp_y, offset=0):
"""Add a linear phase ramp to the wave.
Expand Down
2 changes: 1 addition & 1 deletion hyperspy/_signals/eds.py
Original file line number Diff line number Diff line change
Expand Up @@ -670,7 +670,7 @@ def get_lines_intensity(self,
f'X-ray line intensity of {self.metadata.General.title}: '
f'{Xray_line} at {line_energy:.2f} '
f'{self.axes_manager.signal_axes[0].units}')
img.axes_manager.set_signal_dimension(0)
img = img.transpose(signal_axes=[])
if plot_result and img.axes_manager.navigation_size == 1:
if img._lazy:
img.compute()
Expand Down
6 changes: 3 additions & 3 deletions hyperspy/_signals/eels.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,8 +562,8 @@ def estimate_elastic_scattering_intensity(
else:
ax = self.axes_manager.signal_axes[0]
# I0 = self._get_navigation_signal()
# I0.axes_manager.set_signal_dimension(0)
threshold.axes_manager.set_signal_dimension(0)
# I0 = I0.transpose(signal_axes=[])
threshold = threshold.transpose(signal_axes=[])
binned = ax.is_binned

def estimating_function(data, threshold=None):
Expand Down Expand Up @@ -821,7 +821,7 @@ def estimate_thickness(self,
s.tmp_parameters.folder = self.tmp_parameters.folder
s.tmp_parameters.extension = \
self.tmp_parameters.extension
s.axes_manager.set_signal_dimension(0)
s = s.transpose(signal_axes=[])
s.set_signal_type("")
return s

Expand Down
10 changes: 1 addition & 9 deletions hyperspy/_signals/signal1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,8 +268,6 @@ def __init__(self, *args, **kwargs):
if kwargs.get('ragged', False):
raise ValueError("Signal1D can't be ragged.")
super().__init__(*args, **kwargs)
if self.axes_manager.signal_dimension != 1:
self.axes_manager.set_signal_dimension(1)

def _get_spikes_diagnosis_histogram_data(self, signal_mask=None,
navigation_mask=None,
Expand Down Expand Up @@ -1619,7 +1617,7 @@ def estimating_function(spectrum,
self.metadata.General.title +
" full-width at %.1f maximum right position" % factor)
for signal in (left, width, right):
signal.axes_manager.set_signal_dimension(0)
signal = signal.transpose(signal_axes=[])
signal.set_signal_type("")
if return_interval is True:
return [width, left, right]
Expand Down Expand Up @@ -1657,10 +1655,4 @@ def plot(self,

class LazySignal1D(LazySignal, Signal1D):

"""
"""
_lazy = True

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.axes_manager.set_signal_dimension(1)
3 changes: 0 additions & 3 deletions hyperspy/_signals/signal2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,14 +313,11 @@ class Signal2D(BaseSignal, CommonSignal2D):
"""
"""
_signal_dimension = 2
_lazy = False

def __init__(self, *args, **kwargs):
if kwargs.get('ragged', False):
raise ValueError("Signal2D can't be ragged.")
super().__init__(*args, **kwargs)
if self.axes_manager.signal_dimension != 2:
self.axes_manager.set_signal_dimension(2)

def plot(self,
navigator="auto",
Expand Down
134 changes: 89 additions & 45 deletions hyperspy/axes.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@

from hyperspy.api_nogui import _ureg
from hyperspy.events import Events, Event
from hyperspy.exceptions import VisibleDeprecationWarning
from hyperspy.misc.array_tools import (
numba_closest_index_round,
numba_closest_index_floor,
Expand Down Expand Up @@ -61,7 +62,7 @@
class ndindex_nat(np.ndindex):

def __next__(self):
return super(ndindex_nat, self).__next__()[::-1]
return super().__next__()[::-1]


def generate_uniform_axis(offset, scale, size, offset_index=0):
Expand Down Expand Up @@ -274,7 +275,7 @@ class BaseDataAxis(t.HasTraits):
low_index = t.Int(0)
high_index = t.Int()
slice = t.Instance(slice)
navigate = t.Bool(t.Undefined)
navigate = t.Bool(False)
is_binned = t.Bool(t.Undefined)
index = t.Range('low_index', 'high_index')
axis = t.Array()
Expand Down Expand Up @@ -1489,12 +1490,6 @@ def __init__(self, axes_list):
if self._axes:
self.remove(self._axes)
self.create_axes(axes_list)
# set_signal_dimension is called only if there is no current
# view. It defaults to spectrum
navigates = [i.navigate for i in self._axes]
if t.Undefined in navigates:
# Default to Signal1D view if the view is not fully defined
self.set_signal_dimension(len(axes_list))

self._update_attributes()
self._update_trait_handlers()
Expand Down Expand Up @@ -2009,76 +2004,126 @@ def update_axes_attributes_from(self, axes,
def _update_attributes(self):
getitem_tuple = []
values = []
self.signal_axes = ()
self.navigation_axes = ()
signal_axes = ()
navigation_axes = ()
for axis in self._axes:
# Until we find a better place, take property of the axes
# here to avoid difficult to debug bugs.
axis.axes_manager = self
if axis.slice is None:
getitem_tuple += axis.index,
values.append(axis.value)
self.navigation_axes += axis,
navigation_axes += axis,
else:
getitem_tuple += axis.slice,
self.signal_axes += axis,
if not self.signal_axes and self.navigation_axes:
signal_axes += axis,
if not signal_axes and navigation_axes:
getitem_tuple[-1] = slice(axis.index, axis.index + 1)

self.signal_axes = self.signal_axes[::-1]
self.navigation_axes = self.navigation_axes[::-1]
self._signal_axes = signal_axes[::-1]
self._navigation_axes = navigation_axes[::-1]
self._getitem_tuple = tuple(getitem_tuple)

if len(self.signal_axes) == 1 and self.signal_axes[0].size == 1:
self.signal_dimension = 0
self._signal_dimension = 0
else:
self.signal_dimension = len(self.signal_axes)
self.navigation_dimension = len(self.navigation_axes)
self._signal_dimension = len(self.signal_axes)
self._navigation_dimension = len(self.navigation_axes)

self._signal_size = (np.prod(self.signal_shape)
if self.signal_shape else 0)
self._navigation_size = (np.prod(self.navigation_shape)
if self.navigation_shape else 0)

self._update_max_index()

@property
def signal_axes(self):
"""The signal axes as a tuple."""
return self._signal_axes

@property
def navigation_axes(self):
"""The navigation axes as a tuple."""
return self._navigation_axes

@property
def signal_shape(self):
"""The shape of the signal space."""
return tuple([axis.size for axis in self._signal_axes])

@property
def navigation_shape(self):
"""The shape of the navigation space."""
if self.navigation_dimension != 0:
self.navigation_shape = tuple([
axis.size for axis in self.navigation_axes])
return tuple([axis.size for axis in self._navigation_axes])
else:
self.navigation_shape = ()
return ()

self.signal_shape = tuple([axis.size for axis in self.signal_axes])
self.navigation_size = (np.cumprod(self.navigation_shape)[-1]
if self.navigation_shape else 0)
self.signal_size = (np.cumprod(self.signal_shape)[-1]
if self.signal_shape else 0)
self._update_max_index()
@property
def signal_size(self):
"""The size of the signal space."""
return self._signal_size

def set_signal_dimension(self, value):
"""Set the dimension of the signal.
@property
def navigation_size(self):
"""The size of the navigation space."""
return self._navigation_size

Attributes
----------
value : int
@property
def navigation_dimension(self):
"""The dimension of the navigation space."""
return self._navigation_dimension

Raises
------
ValueError
If value if greater than the number of axes or is negative.
@property
def signal_dimension(self):
"""The dimension of the signal space."""
return self._signal_dimension

"""
if self.ragged and value > 0:
raise ValueError("Signal containing ragged array must have zero "
"signal dimension.")
if len(self._axes) == 0:
def _set_signal_dimension(self, value):
if len(self._axes) == 0 or self._signal_dimension == value:
# Nothing to be done
return
elif self.ragged and value > 0:
raise ValueError("Signal containing ragged array "
"must have zero signal dimension.")
elif value > len(self._axes):
raise ValueError(
"The signal dimension cannot be greater"
" than the number of axes which is %i" % len(self._axes))
"The signal dimension cannot be greater "
f"than the number of axes which is {len(self._axes)}")
elif value < 0:
raise ValueError(
"The signal dimension must be a positive integer")

# Figure out which axis needs navigate=True
tl = [True] * len(self._axes)
if value != 0:
tl[-value:] = (False,) * value

for axis in self._axes:
# Changing navigate attribute will update the axis._slice
# which in turn will trigger _on_slice_changed and call
# _update_attribute
axis.navigate = tl.pop(0)

def set_signal_dimension(self, value):
"""Set the dimension of the signal.
Attributes
----------
value : int
Raises
------
ValueError
If value if greater than the number of axes or is negative.
"""
warnings.warn(("Using `set_signal_dimension` is deprecated, use "
"`as_signal1D`, `as_signal2D` or `transpose` of the "
"signal instance instead."),
VisibleDeprecationWarning)
self._set_signal_dimension(value)

def key_navigator(self, event):
'Set hotkeys for controlling the indices of the navigator plot'

Expand Down Expand Up @@ -2169,7 +2214,6 @@ def _get_navigation_axes_dicts(self):
self.navigation_axes[::-1]]

def show(self):
from hyperspy.exceptions import VisibleDeprecationWarning
msg = (
"The `AxesManager.show` method is deprecated and will be removed "
"in v2.0. Use `gui` instead.")
Expand Down
10 changes: 6 additions & 4 deletions hyperspy/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -725,10 +725,12 @@ def dict2signal(signal_dict, lazy=False):
lazy=lazy)(**signal_dict)
if signal._lazy:
signal._make_lazy()
if signal.axes_manager.signal_dimension != signal_dimension:
# This may happen when the signal dimension couldn't be matched with
# any specialised subclass
signal.axes_manager.set_signal_dimension(signal_dimension)


# This may happen when the signal dimension couldn't be matched with
# any specialised subclass
signal.axes_manager._set_signal_dimension(signal_dimension)

if "post_process" in signal_dict:
for f in signal_dict['post_process']:
signal = f(signal)
Expand Down
2 changes: 1 addition & 1 deletion hyperspy/models/edsmodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -906,7 +906,7 @@ def get_lines_intensity(self,
line_energy,
self.signal.axes_manager.signal_axes[0].units,
self.signal.metadata.General.title))
img.axes_manager.set_signal_dimension(0)
img = img.transpose(signal_axes=[])
if plot_result and img.axes_manager.signal_dimension == 0:
print("%s at %s %s : Intensity = %.2f"
% (xray_line,
Expand Down

0 comments on commit 3e89e22

Please sign in to comment.