Skip to content

Commit

Permalink
Fix __init__ signals when using attributes={'_lazy':True} as keywor…
Browse files Browse the repository at this point in the history
…d argument
  • Loading branch information
ericpre committed Dec 8, 2021
1 parent 41a0d34 commit ed46373
Show file tree
Hide file tree
Showing 5 changed files with 18 additions and 17 deletions.
13 changes: 5 additions & 8 deletions hyperspy/_signals/lazy.py
Expand Up @@ -160,9 +160,8 @@ def __init__(self, *args, **kwargs):
# the NumPy array originates from.
self._cache_dask_chunk = None
self._cache_dask_chunk_slice = None
if hasattr(self, "events"):
if not self._clear_cache_dask_data in self.events.data_changed._connected_all:
self.events.data_changed.connect(self._clear_cache_dask_data)
if not self._clear_cache_dask_data in self.events.data_changed.connected:
self.events.data_changed.connect(self._clear_cache_dask_data)

def compute(self, close_file=False, show_progressbar=None, **kwargs):
"""Attempt to store the full signal in memory.
Expand Down Expand Up @@ -272,9 +271,7 @@ def _get_file_handle(self, warn=True):
"an hdf5 file.")

def _clear_cache_dask_data(self, obj=None):
if self._cache_dask_chunk is not None:
del self._cache_dask_chunk
self._cache_dask_chunk = None
self._cache_dask_chunk = None
self._cache_dask_chunk_slice = None

def _get_dask_chunks(self, axis=None, dtype=None):
Expand Down Expand Up @@ -498,8 +495,8 @@ def _get_cache_dask_chunk(self, indices):
chunks = self._get_navigation_chunk_size()
navigation_indices = indices[:-sig_dim]
chunk_slice = _get_navigation_dimension_chunk_slice(navigation_indices, chunks)
if chunk_slice != self._cache_dask_chunk_slice:
self._clear_cache_dask_data()
if (chunk_slice != self._cache_dask_chunk_slice or
self._cache_dask_chunk is None):
self._cache_dask_chunk = np.asarray(self.data.__getitem__(chunk_slice))
self._cache_dask_chunk_slice = chunk_slice

Expand Down
3 changes: 0 additions & 3 deletions hyperspy/_signals/signal2d.py
Expand Up @@ -954,6 +954,3 @@ def find_peaks(self, method='local_max', interactive=True,
class LazySignal2D(LazySignal, Signal2D):

_lazy = True

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
13 changes: 8 additions & 5 deletions hyperspy/signal.py
Expand Up @@ -2180,7 +2180,6 @@ class for more details).
self.models = ModelManager(self)
self.learning_results = LearningResults()
kwds['data'] = data
self._load_dictionary(kwds)
self._plot = None
self.inav = SpecialSlicersSignal(self, True)
self.isig = SpecialSlicersSignal(self, False)
Expand All @@ -2198,6 +2197,7 @@ class for more details).
Arguments:
obj: The signal that owns the data.
""", arguments=['obj'])
self._load_dictionary(kwds)

def _create_metadata(self):
self.metadata = DictionaryTreeBrowser()
Expand Down Expand Up @@ -5332,13 +5332,16 @@ def get_current_signal(self, auto_title=True, auto_filename=True):
key_dict[key] = marker.get_data_position(key)
marker.set_data(**key_dict)

cs = self.__class__(
class_ = hyperspy.io.assign_signal_subclass(
dtype=self.data.dtype,
signal_dimension=self.axes_manager.signal_dimension,
signal_type=self._signal_type,
lazy=False)

cs = class_(
self(),
axes=self.axes_manager._get_signal_axes_dicts(),
metadata=metadata.as_dictionary())
if self._lazy:
cs._lazy = False
cs._assign_subclass()

if cs.metadata.has_item('Markers'):
temp_marker_dict = cs.metadata.Markers.as_dictionary()
Expand Down
5 changes: 5 additions & 0 deletions hyperspy/tests/signals/test_assign_subclass.py
Expand Up @@ -199,3 +199,8 @@ def test_complex_to_dielectric_function(self):
assert isinstance(self.s, hs.signals.DielectricFunction)
self.s.set_signal_type("")
assert isinstance(self.s, hs.signals.ComplexSignal1D)


def test_create_lazy_signal():
# Check that this syntax is working
_ = hs.signals.BaseSignal([0, 1, 2], attributes={'_lazy': True})
1 change: 0 additions & 1 deletion hyperspy/tests/signals/test_tools.py
Expand Up @@ -24,7 +24,6 @@

from hyperspy import signals
from hyperspy.decorators import lazifyTestClass
from hyperspy.signal_tools import SpikesRemoval, SpikesRemovalInteractive


def _verify_test_sum_x_E(self, s):
Expand Down

0 comments on commit ed46373

Please sign in to comment.