Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve plotting of lazy signals, by keeping current chunk #2568

Merged
merged 28 commits into from Dec 8, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
08ec188
Improve plotting of lazy signals, by keeping current chunk
magnunor Oct 29, 2020
444dadf
Minor style fix in _signals/lazy.py
magnunor Oct 30, 2020
2955e71
Add tests for chunk plotting with lazy signals
magnunor Nov 4, 2020
40ac510
Fix some wrong lazy chunk plot tests
magnunor Nov 4, 2020
b3145ec
Merge branch 'RELEASE_next_minor' into improve_lazy_plotting
magnunor Nov 5, 2020
b124dc5
Rename _get_temporary_plotting_dask_chunk to _get_temporary_dask_chunk
magnunor Nov 5, 2020
72c9a74
Improved lazy plotting: fix issue with data changing using map
magnunor Nov 5, 2020
58699d6
Improved lazy plotting: fix issue when __call__() had not been run
magnunor Nov 5, 2020
ab0ed34
Style fix in _clear_temp_dask_data
magnunor Nov 5, 2020
dfdf33b
Merge remote-tracking branch 'upstream/RELEASE_next_minor' into impro…
ericpre Dec 2, 2020
8a1cf38
Merge branch 'RELEASE_next_minor' into improve_lazy_plotting
magnunor Feb 21, 2021
efc7d4e
Merge branch 'RELEASE_next_minor' into improve_lazy_plotting
magnunor May 25, 2021
c3a2f63
Add lazy plotting improvements to upcoming changes
magnunor May 25, 2021
9244e4f
Merge branch 'RELEASE_next_minor' into improve_lazy_plotting
magnunor Sep 5, 2021
890187a
Rename _get_navigation_dimension_chunk_slice, and several variable names
magnunor Nov 7, 2021
c399e25
Fix some whitespace issues in _get_navigation_dimension_chunk_slice
magnunor Nov 7, 2021
68c474d
Rename _temp_dask_ to _cache_dask_, and position to indices
magnunor Nov 7, 2021
638339b
Improvements to lazy plotting functionality
magnunor Nov 14, 2021
2ae2887
Fix issue in LazySignal init
magnunor Nov 21, 2021
92f084a
Rename unit test to make purpose clearer
magnunor Nov 21, 2021
f51bf43
Improve lazy caching unit tests
magnunor Nov 21, 2021
3f20522
Small change, get_current_signal: make it work with LazySignal.__init__
magnunor Nov 26, 2021
b3eb354
LazySignal init: see if signal has events attribute
magnunor Nov 26, 2021
a978cd6
Add docstring to _get_cache_dask_chunk
magnunor Nov 28, 2021
b691883
Update upcoming_changes, to include multifit
magnunor Nov 28, 2021
cc1a4d3
Merge branch 'RELEASE_next_minor' into improve_lazy_plotting
magnunor Nov 28, 2021
41a0d34
Use existing function to get navigation chunks
magnunor Nov 28, 2021
ed46373
Fix __init__ signals when using `attributes={'_lazy':True}` as keywor…
ericpre Dec 7, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
137 changes: 137 additions & 0 deletions hyperspy/_signals/lazy.py
Expand Up @@ -84,12 +84,85 @@ def to_array(thing, chunks=None):
raise ValueError


def _get_navigation_dimension_chunk_slice(navigation_indices, chunks):
"""Get the slice necessary to get the dask data chunk containing the
navigation indices.

Parameters
----------
navigation_indices : iterable
chunks : iterable

Returns
-------
chunk_slice : list of slices

Examples
--------
Making all the variables

>>> import dask.array as da
>>> from hyperspy._signals.lazy import _get_navigation_dimension_chunk_slice
>>> data = da.random.random((128, 128, 256, 256), chunks=(32, 32, 32, 32))
>>> s = hs.signals.Signal2D(data).as_lazy()
>>> sig_dim = s.axes_manager.signal_dimension
>>> nav_chunks = s.data.chunks[:-sig_dim]
>>> navigation_indices = s.axes_manager._getitem_tuple[:-sig_dim]

The navigation index here is (0, 0), giving us the slice which contains
this index.

>>> chunk_slice = _get_navigation_dimension_chunk_slice(navigation_indices, nav_chunks)
>>> print(chunk_slice)
(slice(0, 32, None), slice(0, 32, None))
>>> data_chunk = data[chunk_slice]

Moving the navigator to a new position, by directly setting the indices.
Normally, this is done by moving the navigator while plotting the data.
Note the "inversion" of the axes here: the indices is given in (x, y),
while the chunk_slice is given in (y, x).

>>> s.axes_manager.indices = (128, 70)
>>> navigation_indices = s.axes_manager._getitem_tuple[:-sig_dim]
>>> chunk_slice = _get_navigation_dimension_chunk_slice(navigation_indices, nav_chunks)
>>> print(chunk_slice)
(slice(64, 96, None), slice(96, 128, None))
>>> data_chunk = data[chunk_slice]

"""
chunk_slice_list = da.core.slices_from_chunks(chunks)
for chunk_slice in chunk_slice_list:
is_slice = True
for index_nav in range(len(navigation_indices)):
temp_slice = chunk_slice[index_nav]
nav = navigation_indices[index_nav]
if not (temp_slice.start <= nav < temp_slice.stop):
is_slice = False
break
if is_slice:
return chunk_slice
return False


class LazySignal(BaseSignal):
"""A Lazy Signal instance that delays computation until explicitly saved
(assuming storing the full result of computation in memory is not feasible)
"""
_lazy = True

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# The _cache_dask_chunk and _cache_dask_chunk_slice attributes are
# used to temporarily cache data contained in one chunk, when
# self.__call__ is used. Typically done when using plot or fitting.
# _cache_dask_chunk has the NumPy array itself, while
# _cache_dask_chunk_slice has the navigation dimension chunk which
# the NumPy array originates from.
self._cache_dask_chunk = None
self._cache_dask_chunk_slice = None
if not self._clear_cache_dask_data in self.events.data_changed.connected:
self.events.data_changed.connect(self._clear_cache_dask_data)

def compute(self, close_file=False, show_progressbar=None, **kwargs):
"""Attempt to store the full signal in memory.

Expand Down Expand Up @@ -197,6 +270,10 @@ def _get_file_handle(self, warn=True):
"the file is already closed or it is not "
"an hdf5 file.")

def _clear_cache_dask_data(self, obj=None):
self._cache_dask_chunk = None
self._cache_dask_chunk_slice = None

def _get_dask_chunks(self, axis=None, dtype=None):
"""Returns dask chunks.

Expand Down Expand Up @@ -370,6 +447,66 @@ def get_dask_function(numpy_name):
s._remove_axis([ax.index_in_axes_manager for ax in axes])
return s

def _get_cache_dask_chunk(self, indices):
"""Method for handling caching of dask chunks, when using __call__.

When accessing data in a chunked HDF5 file, the whole chunks needs
to be loaded into memory. So even if you only want to access a single
index in the navigation dimension, the whole chunk in the navigation
dimension needs to be loaded into memory. This method keeps (caches)
this chunk in memory after loading it, so moving to a different
position with the same chunk will be much faster, reducing amount of
data which needs be read from the disk.

If a navigation index (via the indices parameter) in a different chunk
is asked for, the currently cached chunk is discarded, and the new
chunk is loaded into memory.

This only works for functions using self.__call__, for example
plot and fitting functions. This will not work with the region of
interest functionality.

The cached chunk is stored in the attribute s._cache_dask_chunk,
and the slice needed to extract this chunk is in
s._cache_dask_chunk_slice. To these, use s._clear_cache_dask_data()

Parameter
---------
indices : tuple
Must be the same length as navigation dimensions in self.

Returns
-------
value : NumPy array
Same shape as the signal shape of self.

Examples
--------
>>> import dask.array as da
>>> s = hs.signals.Signal2D(da.ones((5, 10, 20, 30, 40))).as_lazy()
>>> value = s._get_cache_dask_chunk((3, 6, 2))
>>> cached_chunk = s._cache_dask_chunk # Cached array
>>> cached_chunk_slice = s._cache_dask_chunk_slice # Slice of chunk
>>> s._clear_cache_dask_data() # Clearing both of these

"""

sig_dim = self.axes_manager.signal_dimension
chunks = self._get_navigation_chunk_size()
navigation_indices = indices[:-sig_dim]
chunk_slice = _get_navigation_dimension_chunk_slice(navigation_indices, chunks)
if (chunk_slice != self._cache_dask_chunk_slice or
self._cache_dask_chunk is None):
self._cache_dask_chunk = np.asarray(self.data.__getitem__(chunk_slice))
self._cache_dask_chunk_slice = chunk_slice

indices = list(indices)
for i, temp_slice in enumerate(chunk_slice):
indices[i] -= temp_slice.start
indices = tuple(indices)
value = self._cache_dask_chunk[indices]
return value

def rebin(self, new_shape=None, scale=None,
crop=False, dtype=None, out=None, rechunk=True):
factors = self._validate_rebin_args_and_get_factors(
Expand Down
3 changes: 0 additions & 3 deletions hyperspy/_signals/signal2d.py
Expand Up @@ -954,6 +954,3 @@ def find_peaks(self, method='local_max', interactive=True,
class LazySignal2D(LazySignal, Signal2D):

_lazy = True

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
23 changes: 15 additions & 8 deletions hyperspy/signal.py
Expand Up @@ -2180,7 +2180,6 @@ class for more details).
self.models = ModelManager(self)
self.learning_results = LearningResults()
kwds['data'] = data
self._load_dictionary(kwds)
self._plot = None
self.inav = SpecialSlicersSignal(self, True)
self.isig = SpecialSlicersSignal(self, False)
Expand All @@ -2198,6 +2197,7 @@ class for more details).
Arguments:
obj: The signal that owns the data.
""", arguments=['obj'])
self._load_dictionary(kwds)

def _create_metadata(self):
self.metadata = DictionaryTreeBrowser()
Expand Down Expand Up @@ -2649,10 +2649,12 @@ def _get_undefined_axes_list(self, ragged=False):
def __call__(self, axes_manager=None, fft_shift=False):
if axes_manager is None:
axes_manager = self.axes_manager
value = np.atleast_1d(self.data.__getitem__(
axes_manager._getitem_tuple))
if isinstance(value, da.Array):
value = np.asarray(value)
indices = axes_manager._getitem_tuple
if self._lazy:
value = self._get_cache_dask_chunk(indices)
else:
value = self.data.__getitem__(indices)
value = np.atleast_1d(value)
if fft_shift:
value = np.fft.fftshift(value)
return value
Expand Down Expand Up @@ -5330,11 +5332,16 @@ def get_current_signal(self, auto_title=True, auto_filename=True):
key_dict[key] = marker.get_data_position(key)
marker.set_data(**key_dict)

cs = self.__class__(
class_ = hyperspy.io.assign_signal_subclass(
dtype=self.data.dtype,
signal_dimension=self.axes_manager.signal_dimension,
signal_type=self._signal_type,
lazy=False)

cs = class_(
self(),
axes=self.axes_manager._get_signal_axes_dicts(),
metadata=metadata.as_dictionary(),
attributes={'_lazy': False})
metadata=metadata.as_dictionary())

if cs.metadata.has_item('Markers'):
temp_marker_dict = cs.metadata.Markers.as_dictionary()
Expand Down
5 changes: 5 additions & 0 deletions hyperspy/tests/signals/test_assign_subclass.py
Expand Up @@ -199,3 +199,8 @@ def test_complex_to_dielectric_function(self):
assert isinstance(self.s, hs.signals.DielectricFunction)
self.s.set_signal_type("")
assert isinstance(self.s, hs.signals.ComplexSignal1D)


def test_create_lazy_signal():
# Check that this syntax is working
_ = hs.signals.BaseSignal([0, 1, 2], attributes={'_lazy': True})