Skip to content

Commit

Permalink
Merge remote-tracking branch 'upstream/RELEASE_next_patch' into RELEA…
Browse files Browse the repository at this point in the history
…SE_next_minor
  • Loading branch information
ericpre committed Sep 17, 2022
2 parents b625c5c + 8f9aff4 commit fbe08e2
Show file tree
Hide file tree
Showing 9 changed files with 100 additions and 55 deletions.
2 changes: 1 addition & 1 deletion doc/user_guide/bibliography.rst
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ Bibliography
:ref:`[Zhou2011] <Zhou2011>`
T. Zhou and D. Tao, "GoDec: Randomized Low-rank
& Sparse Matrix Decomposition in Noisy Case", *ICML-11* (2011): 33–40
[`<http://www.icml-2011.org/papers/41_icmlpaper.pdf>`_].
[`<https://icml.cc/Conferences/2011/papers/41_icmlpaper.pdf>`_].

.. _Schaffer2004:

Expand Down
82 changes: 47 additions & 35 deletions hyperspy/_signals/signal1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@
# You should have received a copy of the GNU General Public License
# along with HyperSpy. If not, see <https://www.gnu.org/licenses/#GPL>.

import os
import logging
import math
import os
import warnings

import matplotlib.pyplot as plt
import numpy as np
import dask.array as da
from scipy import interpolate
Expand All @@ -29,7 +29,8 @@

from hyperspy.signal import BaseSignal
from hyperspy._signals.common_signal1d import CommonSignal1D
from hyperspy.signal_tools import SpikesRemoval, SpikesRemovalInteractive
from hyperspy.signal_tools import (
SpikesRemoval, SpikesRemovalInteractive, SimpleMessage)
from hyperspy.models.model1d import Model1D
from hyperspy.misc.lowess_smooth import lowess
from hyperspy.misc.utils import is_binned # remove in v2.0
Expand Down Expand Up @@ -270,9 +271,15 @@ def __init__(self, *args, **kwargs):
raise ValueError("Signal1D can't be ragged.")
super().__init__(*args, **kwargs)

def _get_spikes_diagnosis_histogram_data(self, signal_mask=None,
navigation_mask=None,
**kwargs):
def _spikes_diagnosis(
self,
signal_mask=None,
navigation_mask=None,
show_plot=False,
use_gui=False,
**kwargs
):

self._check_signal_dimension_equals_one()
dc = self.data
axis = self.axes_manager.signal_axes[0].axis
Expand All @@ -281,22 +288,41 @@ def _get_spikes_diagnosis_histogram_data(self, signal_mask=None,
axis = axis[~signal_mask]
if navigation_mask is not None:
dc = dc[~navigation_mask, :]
if dc.size == 0:
raise ValueError("The data size must be higher than 0.")
der = abs(np.gradient(dc, axis, axis=-1))
n = ((~navigation_mask).sum() if navigation_mask else
self.axes_manager.navigation_size)

# arbitrary cutoff for number of spectra necessary before histogram
# data is compressed by finding maxima of each spectrum
tmp = BaseSignal(der) if n < 2000 else BaseSignal(
np.ravel(der.max(-1)))
tmp = BaseSignal(der) if n < 2000 else BaseSignal(np.ravel(der.max(-1)))

s_ = tmp.get_histogram(**kwargs)
s_.axes_manager[0].name = "Derivative magnitude"
s_.metadata.Signal.quantity = "Counts"
s_.metadata.General.title = "Spikes Analysis"

if s_.data.size == 1:
message = "The derivative of the data is constant."
if use_gui:
m = SimpleMessage(text=message)
try:
m.gui()
except (NotImplementedError, ImportError):
# This is only available for traitsui, in case of ipywidgets
# we show a warning
warnings.warn(message)
else:
warnings.warn(message)
elif show_plot:
s_.plot(norm="log")

# get histogram signal using smart binning and plot
return tmp.get_histogram(**kwargs)
return s_

def spikes_diagnosis(self, signal_mask=None,
navigation_mask=None,
**kwargs):
"""Plots a histogram to help in choosing the threshold for
def spikes_diagnosis(self, signal_mask=None, navigation_mask=None, **kwargs):
"""
Plots a histogram to help in choosing the threshold for
spikes removal.
Parameters
Expand All @@ -312,27 +338,13 @@ def spikes_diagnosis(self, signal_mask=None,
spikes_removal_tool
"""
tmph = self._get_spikes_diagnosis_histogram_data(signal_mask,
navigation_mask,
**kwargs)
tmph.plot()

# Customize plot appearance
plt.gca().set_title('')
plt.gca().fill_between(tmph.axes_manager[0].axis,
tmph.data,
facecolor='#fddbc7',
interpolate=True,
color='none')
ax = tmph._plot.signal_plot.ax
axl = tmph._plot.signal_plot.ax_lines[0]
axl.set_line_properties(color='#b2182b')
plt.xlabel('Derivative magnitude')
plt.ylabel('Log(Counts)')
ax.set_yscale('log')
ax.set_ylim(10 ** -1, plt.ylim()[1])
ax.set_xlim(plt.xlim()[0], 1.1 * plt.xlim()[1])
plt.draw()
self._spikes_diagnosis(
signal_mask=signal_mask,
navigation_mask=navigation_mask,
show_plot=True,
use_gui=False,
**kwargs
)

spikes_diagnosis.__doc__ %= (SIGNAL_MASK_ARG, NAVIGATION_MASK_ARG)

Expand Down
7 changes: 6 additions & 1 deletion hyperspy/drawing/_widgets/range.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,18 @@

import inspect
import logging
from packaging.version import Version

import matplotlib
import numpy as np

from hyperspy.drawing.widget import ResizableDraggableWidgetBase
from hyperspy.defaults_parser import preferences

from hyperspy.external.matplotlib.widgets import SpanSelector
if Version(matplotlib.__version__) >= Version('3.6.0'):
from matplotlib.widgets import SpanSelector
else:
from hyperspy.external.matplotlib.widgets import SpanSelector

_logger = logging.getLogger(__name__)

Expand Down
17 changes: 12 additions & 5 deletions hyperspy/signal_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -1757,10 +1757,13 @@ def __init__(self, signal, navigation_mask=None, signal_mask=None,
signal.axes_manager.indices = self.coordinates[0]
if threshold == 'auto':
# Find the first zero of the spikes diagnosis plot
hist = signal._get_spikes_diagnosis_histogram_data(
hist = signal._spikes_diagnosis(
signal_mask=signal_mask,
navigation_mask=navigation_mask,
max_num_bins=max_num_bins)
max_num_bins=max_num_bins,
show_plot=False,
use_gui=False,
)
zero_index = np.where(hist.data == 0)[0]
if zero_index.shape[0] > 0:
index = zero_index[0]
Expand Down Expand Up @@ -1950,9 +1953,13 @@ def _click_to_show_instructions_fired(self):
title="Instructions"),

def _show_derivative_histogram_fired(self):
self.signal.spikes_diagnosis(signal_mask=self.signal_mask,
navigation_mask=self.navigation_mask,
max_num_bins=self.max_num_bins)
self.signal._spikes_diagnosis(
signal_mask=self.signal_mask,
navigation_mask=self.navigation_mask,
max_num_bins=self.max_num_bins,
show_plot=True,
use_gui=True,
)

def _reset_line(self):
if self.interpolated_line is not None:
Expand Down
27 changes: 24 additions & 3 deletions hyperspy/tests/signals/test_eels.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,18 +260,22 @@ def test_spikes_diagnosis(self):
self.signal.spikes_diagnosis(zero_loss_peak_mask_width=5.0)

zlp_mask = self.signal.get_zero_loss_peak_mask()
hist_data = self.signal._get_spikes_diagnosis_histogram_data(
signal_mask=zlp_mask, bins=25)
hist_data = self.signal._spikes_diagnosis(signal_mask=zlp_mask, bins=25)
expected_data = np.zeros(25)
expected_data[0] = 232
expected_data[12] = 1
expected_data[-1] = 1
np.testing.assert_allclose(hist_data.data, expected_data)

hist_data2 = self.signal._get_spikes_diagnosis_histogram_data(bins=25)
hist_data2 = self.signal._spikes_diagnosis(bins=25)
expected_data2 = np.array([285, 11, 13, 0, 0, 1, 12, 0])
np.testing.assert_allclose(hist_data2.data[:8], expected_data2)

# mask all to check that it raises an error when there is no data
signal_mask = self.signal.inav[0,1].data.astype(bool)
with pytest.raises(ValueError):
self.signal.spikes_diagnosis(signal_mask=signal_mask)


def test_spikes_removal_tool_no_zlp():
s = hs.datasets.artificial_data.get_core_loss_eels_line_scan_signal()
Expand All @@ -280,6 +284,23 @@ def test_spikes_removal_tool_no_zlp():
s.spikes_removal_tool(zero_loss_peak_mask_width=5.0)


def test_spikes_diagnosis_constant_derivative():
s = hs.signals.Signal1D(np.arange(20).reshape(2, 10))
with pytest.warns():
s._spikes_diagnosis(use_gui=False)

hs.preferences.GUIs.enable_traitsui_gui = False
with pytest.warns():
s._spikes_diagnosis(use_gui=True)

hs.preferences.GUIs.enable_traitsui_gui = True
try:
import hyperspy_gui_traitsui
s._spikes_diagnosis(use_gui=True)
except ImportError:
pass


@lazifyTestClass
class TestPowerLawExtrapolation:

Expand Down
17 changes: 7 additions & 10 deletions hyperspy/tests/signals/test_spike_removal_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def test_spikes_removal_tool():
s.data[1, 2, 14] += 1

sr = SpikesRemovalInteractive(s)
sr._show_derivative_histogram_fired()
sr.threshold = 1.5
sr.find()
assert s.axes_manager.indices == (0, 1)
Expand Down Expand Up @@ -60,17 +61,13 @@ def test_spikes_removal_tool():
assert s.axes_manager.indices == (0, 0)


add_noise_params = [
[False, 5],
[True, 1]
]


def test_spikes_removal_tool_navigation_dimension_0():
#Artificial Signal
s = Signal1D(np.ones(1234))
#Add a spike
s.data[333] = 666
s.data[333] = 5
np.random.seed(1)
s.add_gaussian_noise(0.01)

assert s.axes_manager.navigation_dimension == 0

Expand All @@ -81,7 +78,7 @@ def test_spikes_removal_tool_navigation_dimension_0():

sr.apply()

np.testing.assert_allclose(s.data[333], 1, atol=1e-4)
np.testing.assert_allclose(s.data[333], 1, atol=0.02)


@pytest.mark.parametrize(("add_noise, decimal"), [(True, 1), (False, 5)])
Expand Down Expand Up @@ -115,8 +112,8 @@ def test_spikes_removal_tool_non_interactive_masking():
navigation_mask[1, 0] = True
signal_mask = np.zeros((30,), dtype='bool')
signal_mask[28:] = True
sr = s.spikes_removal_tool(threshold=0.5, interactive=False, add_noise=False,
navigation_mask=navigation_mask, signal_mask=signal_mask)
s.spikes_removal_tool(threshold=0.5, interactive=False, add_noise=False,
navigation_mask=navigation_mask, signal_mask=signal_mask)
np.testing.assert_almost_equal(s.data[1, 0, 1], 3, decimal=5)
np.testing.assert_almost_equal(s.data[0, 2, 29], 2, decimal=5)
np.testing.assert_almost_equal(s.data[1, 2, 14], 1, decimal=5)
1 change: 1 addition & 0 deletions upcoming_changes/3005.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix handling constant derivative in :py:meth:`~._signals.signal1D.Signal1D.spikes_removal_tool`
1 change: 1 addition & 0 deletions upcoming_changes/3015.maintenance.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix hyperlink in bibliography
1 change: 1 addition & 0 deletions upcoming_changes/3016.maintenance.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix matplotlib ``SpanSelector`` import for matplotlib 3.6

0 comments on commit fbe08e2

Please sign in to comment.