Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

plot residual #3186

Merged
merged 5 commits into from Aug 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
36 changes: 36 additions & 0 deletions examples/model_fitting/plot_residual.py
@@ -0,0 +1,36 @@
"""
Plot Residual
=============

Fit an affine function and plot the residual.

"""

import numpy as np
import hyperspy.api as hs

#%%
# Create a signal:
data = np.arange(1000, dtype=np.int64).reshape((10, 100))
s = hs.signals.Signal1D(data)

#%%
# Add noise:
s.add_poissonian_noise(random_state=0)

#%%
# Create model:
m = s.create_model()
line = hs.model.components1D.Expression("a * x + b", name="Affine")
m.append(line)

#%%
# Fit for all navigation positions:
m.multifit()

#%%
# Plot the fitted model with residual:
m.plot(plot_residual=True)

#%%
# sphinx_gallery_thumbnail_number = 2
35 changes: 32 additions & 3 deletions hyperspy/models/model1d.py
Expand Up @@ -258,6 +258,7 @@ def __init__(self, signal1D, dictionary=None):
self._plot_components = False
self._suspend_update = False
self._model_line = None
self._residual_line = None
self.axis = self.axes_manager.signal_axes[0]
self.axes_manager.events.indices_changed.connect(
self._on_navigating, [])
Expand Down Expand Up @@ -377,7 +378,8 @@ def remove(self, things):
remove.__doc__ = BaseModel.remove.__doc__

def __call__(self, non_convolved=False, onlyactive=False,
component_list=None, binned=None):
component_list=None, binned=None,
ignore_channel_switches = False):
"""Returns the corresponding model for the current coordinates

Parameters
Expand All @@ -393,6 +395,9 @@ def __call__(self, non_convolved=False, onlyactive=False,
binned : bool or None
Specify whether the binned attribute of the signal axes needs to be
taken into account.
ignore_channel_switches: bool
If true, the entire signal axis are returned
without checking channel_switches.

cursor: 1 or 2

Expand All @@ -412,7 +417,8 @@ def __call__(self, non_convolved=False, onlyactive=False,
component for component in component_list if component.active]

if self.convolved is False or non_convolved is True:
axis = self.axis.axis[self.channel_switches]
slice_ = slice(None) if ignore_channel_switches else self.channel_switches
axis = self.axis.axis[slice_]
sum_ = np.zeros(len(axis))
for component in component_list:
sum_ += component.function(axis)
Expand Down Expand Up @@ -708,15 +714,24 @@ def _model2plot(self, axes_manager, out_of_range2nans=True):
ns[np.where(self.channel_switches)] = s
s = ns
return s

def _residual_for_plot(self,**kwargs):
"""From an model1D object, the original signal is subtracted
by the model signal then returns the residual
"""

def plot(self, plot_components=False, **kwargs):
return self.signal.__call__() - self.__call__(ignore_channel_switches=True)

def plot(self, plot_components=False,plot_residual=False, **kwargs):
"""Plot the current spectrum to the screen and a map with a
cursor to explore the SI.

Parameters
----------
plot_components : bool
If True, add a line per component to the signal figure.
plot_residual : bool
If True, add a residual line (Signal - Model) to the signal figure.
**kwargs : dict
All extra keyword arguements are passed to
:py:meth:`~._signals.signal1d.Signal1D.plot`
Expand All @@ -740,6 +755,20 @@ def plot(self, plot_components=False, **kwargs):
self._model_line = l2
self._plot = self.signal._plot
self._connect_parameters2update_plot(self)

#Optional to plot the residual of (Signal - Model)
if plot_residual:
l3 = hyperspy.drawing.signal1d.Signal1DLine()
# _residual_for_plot outputs the residual (Signal - Model)
l3.data_function = self._residual_for_plot
l3.set_line_properties(color='green', type='line')
# Add the line to the figure
_plot.signal_plot.add_line(l3)
l3.plot()
# Quick access to _residual_line if needed
self._residual_line = l3


if plot_components is True:
self.enable_plot_components()
else:
Expand Down
5 changes: 4 additions & 1 deletion hyperspy/tests/drawing/test_plot_model1d.py
Expand Up @@ -29,7 +29,7 @@

class TestModelPlot:
def setup_method(self, method):
s = Signal1D(np.arange(1000).reshape((10, 100)))
s = Signal1D(np.arange(1000, dtype=np.int64).reshape((10, 100)))
s.add_poissonian_noise(random_state=0)
m = s.create_model()
line = Expression("a * x", name="line", a=1)
Expand Down Expand Up @@ -64,3 +64,6 @@ def test_default_navigator_plot(self):
def test_no_navigator(self):
self.m.plot(navigator=None)
assert self.m.signal._plot.navigator_plot is None

def test_plot_residual(self):
self.m.plot(plot_residual=True)
1 change: 0 additions & 1 deletion hyperspy/tests/model/test_model.py
Expand Up @@ -23,7 +23,6 @@

import hyperspy.api as hs
from hyperspy.decorators import lazifyTestClass
from hyperspy.misc.test_utils import ignore_warning
from hyperspy.misc.utils import slugify


Expand Down
1 change: 1 addition & 0 deletions upcoming_changes/3186.new.rst
@@ -0,0 +1 @@
Added a plot_residual argument to model1d.plot. This argument will plot a residual signal (Signal - Model) to the signal plot.