From ed463736df726f292cca3f49cf03e3a69410d62a Mon Sep 17 00:00:00 2001 From: Eric Prestat Date: Tue, 7 Dec 2021 23:05:32 +0000 Subject: [PATCH] Fix __init__ signals when using `attributes={'_lazy':True}` as keyword argument --- hyperspy/_signals/lazy.py | 13 +++++-------- hyperspy/_signals/signal2d.py | 3 --- hyperspy/signal.py | 13 ++++++++----- hyperspy/tests/signals/test_assign_subclass.py | 5 +++++ hyperspy/tests/signals/test_tools.py | 1 - 5 files changed, 18 insertions(+), 17 deletions(-) diff --git a/hyperspy/_signals/lazy.py b/hyperspy/_signals/lazy.py index b8e3985d75..198d946084 100644 --- a/hyperspy/_signals/lazy.py +++ b/hyperspy/_signals/lazy.py @@ -160,9 +160,8 @@ def __init__(self, *args, **kwargs): # the NumPy array originates from. self._cache_dask_chunk = None self._cache_dask_chunk_slice = None - if hasattr(self, "events"): - if not self._clear_cache_dask_data in self.events.data_changed._connected_all: - self.events.data_changed.connect(self._clear_cache_dask_data) + if not self._clear_cache_dask_data in self.events.data_changed.connected: + self.events.data_changed.connect(self._clear_cache_dask_data) def compute(self, close_file=False, show_progressbar=None, **kwargs): """Attempt to store the full signal in memory. @@ -272,9 +271,7 @@ def _get_file_handle(self, warn=True): "an hdf5 file.") def _clear_cache_dask_data(self, obj=None): - if self._cache_dask_chunk is not None: - del self._cache_dask_chunk - self._cache_dask_chunk = None + self._cache_dask_chunk = None self._cache_dask_chunk_slice = None def _get_dask_chunks(self, axis=None, dtype=None): @@ -498,8 +495,8 @@ def _get_cache_dask_chunk(self, indices): chunks = self._get_navigation_chunk_size() navigation_indices = indices[:-sig_dim] chunk_slice = _get_navigation_dimension_chunk_slice(navigation_indices, chunks) - if chunk_slice != self._cache_dask_chunk_slice: - self._clear_cache_dask_data() + if (chunk_slice != self._cache_dask_chunk_slice or + self._cache_dask_chunk is None): self._cache_dask_chunk = np.asarray(self.data.__getitem__(chunk_slice)) self._cache_dask_chunk_slice = chunk_slice diff --git a/hyperspy/_signals/signal2d.py b/hyperspy/_signals/signal2d.py index d701465b29..d1161197c4 100644 --- a/hyperspy/_signals/signal2d.py +++ b/hyperspy/_signals/signal2d.py @@ -954,6 +954,3 @@ def find_peaks(self, method='local_max', interactive=True, class LazySignal2D(LazySignal, Signal2D): _lazy = True - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) diff --git a/hyperspy/signal.py b/hyperspy/signal.py index 035a82e0d2..a8828c4aa9 100644 --- a/hyperspy/signal.py +++ b/hyperspy/signal.py @@ -2180,7 +2180,6 @@ class for more details). self.models = ModelManager(self) self.learning_results = LearningResults() kwds['data'] = data - self._load_dictionary(kwds) self._plot = None self.inav = SpecialSlicersSignal(self, True) self.isig = SpecialSlicersSignal(self, False) @@ -2198,6 +2197,7 @@ class for more details). Arguments: obj: The signal that owns the data. """, arguments=['obj']) + self._load_dictionary(kwds) def _create_metadata(self): self.metadata = DictionaryTreeBrowser() @@ -5332,13 +5332,16 @@ def get_current_signal(self, auto_title=True, auto_filename=True): key_dict[key] = marker.get_data_position(key) marker.set_data(**key_dict) - cs = self.__class__( + class_ = hyperspy.io.assign_signal_subclass( + dtype=self.data.dtype, + signal_dimension=self.axes_manager.signal_dimension, + signal_type=self._signal_type, + lazy=False) + + cs = class_( self(), axes=self.axes_manager._get_signal_axes_dicts(), metadata=metadata.as_dictionary()) - if self._lazy: - cs._lazy = False - cs._assign_subclass() if cs.metadata.has_item('Markers'): temp_marker_dict = cs.metadata.Markers.as_dictionary() diff --git a/hyperspy/tests/signals/test_assign_subclass.py b/hyperspy/tests/signals/test_assign_subclass.py index 8e2fd659db..04dc289eb9 100644 --- a/hyperspy/tests/signals/test_assign_subclass.py +++ b/hyperspy/tests/signals/test_assign_subclass.py @@ -199,3 +199,8 @@ def test_complex_to_dielectric_function(self): assert isinstance(self.s, hs.signals.DielectricFunction) self.s.set_signal_type("") assert isinstance(self.s, hs.signals.ComplexSignal1D) + + +def test_create_lazy_signal(): + # Check that this syntax is working + _ = hs.signals.BaseSignal([0, 1, 2], attributes={'_lazy': True}) diff --git a/hyperspy/tests/signals/test_tools.py b/hyperspy/tests/signals/test_tools.py index 863054f453..67758116d4 100644 --- a/hyperspy/tests/signals/test_tools.py +++ b/hyperspy/tests/signals/test_tools.py @@ -24,7 +24,6 @@ from hyperspy import signals from hyperspy.decorators import lazifyTestClass -from hyperspy.signal_tools import SpikesRemoval, SpikesRemovalInteractive def _verify_test_sum_x_E(self, s):