Skip to content

Commit

Permalink
Replace deprecated scipy.interpolate.interp1d with `scipy.interpola…
Browse files Browse the repository at this point in the history
…te.make_interp_spline`
  • Loading branch information
ericpre committed Sep 25, 2023
1 parent 2109811 commit 778de67
Show file tree
Hide file tree
Showing 8 changed files with 55 additions and 77 deletions.
30 changes: 8 additions & 22 deletions hyperspy/_components/scalable_fixed_pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
# along with HyperSpy. If not, see <https://www.gnu.org/licenses/#GPL>.

import numpy as np
from scipy.interpolate import interp1d
from scipy.interpolate import make_interp_spline

from hyperspy.component import Component
from hyperspy.ui_registry import add_gui_method
Expand Down Expand Up @@ -97,37 +97,23 @@ def interpolate(self, value):
self.xscale.free = value
self.shift.free = value

def prepare_interpolator(self, kind='linear', fill_value=0, **kwargs):
def prepare_interpolator(self, **kwargs):
"""Prepare interpolation.
Parameters
----------
x : array
The spectral axis of the fixed pattern
kind : str or int, optional
Specifies the kind of interpolation as a string
('linear', 'nearest', 'zero', 'slinear', 'quadratic, 'cubic')
or as an integer specifying the order of the spline interpolator
to use. Default is 'linear'.
fill_value : float, optional
If provided, then this value will be used to fill in for requested
points outside of the data range. If not provided, then the default
is NaN.
Notes
-----
Any extra keyword argument is passed to `scipy.interpolate.interp1d`
**kwargs : dict
Keywords argument are passed to
:py:func:`scipy.interpolate.make_interp_spline`
"""

self.f = interp1d(
self.f = make_interp_spline(
self.signal.axes_manager.signal_axes[0].axis,
self.signal.data.squeeze(),
kind=kind,
bounds_error=False,
fill_value=fill_value,
**kwargs)
**kwargs
)

def _function(self, x, xscale, yscale, shift):
if self.interpolate is True:
Expand Down
15 changes: 11 additions & 4 deletions hyperspy/_signals/signal1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import warnings

import numpy as np
import numpy.ma as ma
import dask.array as da
from scipy import interpolate
from scipy.signal import savgol_filter, medfilt
Expand Down Expand Up @@ -226,7 +227,11 @@ def interpolate1D(number_of_interpolation_points, data):
ch = len(data)
old_ax = np.linspace(0, 100, ch)
new_ax = np.linspace(0, 100, ch * ip - (ip - 1))
interpolator = interpolate.interp1d(old_ax, data)

data = ma.masked_invalid(data)
interpolator = interpolate.make_interp_spline(
old_ax, data, k=1, check_finite=False,
)
return interpolator(new_ax)


Expand Down Expand Up @@ -256,9 +261,11 @@ def _shift1D(data, **kwargs):
if np.isnan(shift) or shift == 0:
return data

#This is the interpolant function
si = interpolate.interp1d(original_axis, data, bounds_error=False,
fill_value=fill_value, kind=kind)
data = ma.masked_invalid(data)
# #This is the interpolant function
si = interpolate.make_interp_spline(
original_axis, data, k=1, check_finite=False
)

#Evaluate interpolated data at shifted positions
return si(original_axis-shift)
Expand Down
2 changes: 1 addition & 1 deletion hyperspy/misc/eels/base_gos.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,4 +153,4 @@ def integrateq(self, onset_energy, angle, E0):
qint *= (4.0 * np.pi * a0 ** 2.0 * R ** 2 / E / T *
self.subshell_factor) * 1e28
self.qint = qint
return interpolate.interp1d(E, qint, kind=3)
return interpolate.make_interp_spline(E, qint, k=3)
4 changes: 3 additions & 1 deletion hyperspy/misc/eels/hydrogenic_gos.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,9 @@ def integrateq(self, onset_energy, angle, E0):
lambda x: self.gosfunc(E, np.exp(x)),
math.log(qa0sqmin), math.log(qa0sqmax))[0])
self.qint = qint
return interpolate.interp1d(self.energy_axis + energy_shift, qint)
return interpolate.make_interp_spline(
self.energy_axis + energy_shift, qint, k=1,
)

def gosfuncK(self, E, qa02):
# gosfunc calculates (=DF/DE) which IS PER EV AND PER ATOM
Expand Down
12 changes: 5 additions & 7 deletions hyperspy/signal.py
Original file line number Diff line number Diff line change
Expand Up @@ -3112,32 +3112,30 @@ def interpolate_on_axis(self,
axis=0,
inplace=False,
degree=1):
"""Replaces the given `axis` with the provided `new_axis`
and interpolates data accordingly using :py:func:`scipy.interpolate.make_interp_spline`.
"""Replaces the given ``axis`` with the provided ``new_axis``
and interpolates data accordingly using
:py:func:`scipy.interpolate.make_interp_spline`.
Parameters
----------
new_axis : UniformDataAxis, DataAxis or FunctionalDataAxis
Axis which replaces the one specified by the `axis` argument.
Axis which replaces the one specified by the ``axis`` argument.
If this new axis exceeds the range of the old axis,
a warning is raised that the data will be extrapolated.
axis : int or str, default=0
Specifies the axis which will be replaced using the index of the
axis in the `axes_manager`. The axis can be specified using the index of the
axis in `axes_manager` or the axis name.
inplace : bool, default=False
If ``True`` the data of `self` is replaced by the result and
the axis is changed inplace. Otherwise `self` is not changed
and a new signal with the changes incorporated is returned.
degree: int, default=1
Specifies the B-Spline degree of the used interpolator.
Returns
-------
s : :py:class:`~hyperspy.signal.BaseSignal` (or subclass)
s : :py:class:`~.api.signals.BaseSignal` (or subclass)
A copy of the object with the axis exchanged and the data interpolated.
This only occurs when inplace is set to ``False``, otherwise nothing is returned.
"""
Expand Down
29 changes: 7 additions & 22 deletions hyperspy/signal_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -1747,7 +1747,7 @@ def __init__(self, signal, navigation_mask=None, signal_mask=None,
_logger.info(f'Threshold value: {threshold}')
self.argmax = None
self.derivmax = None
self.kind = "linear"
self.spline_order = 1
self._temp_mask = np.zeros(self.signal().shape, dtype='bool')
self.index = 0
self.threshold = threshold
Expand Down Expand Up @@ -1826,10 +1826,7 @@ def get_interpolated_spectrum(self, axes_manager=None):
data = self.signal().copy()
axis = self.signal.axes_manager.signal_axes[0]
left, right = self.get_interpolation_range()
if self.kind == 'linear':
pad = 1
else:
pad = self.spline_order
pad = self.spline_order
ileft = left - pad
iright = right + pad
ileft = np.clip(ileft, 0, len(data))
Expand All @@ -1852,7 +1849,7 @@ def get_interpolated_spectrum(self, axes_manager=None):
# Interpolate
x = np.hstack((axis.axis[ileft:left], axis.axis[right:iright]))
y = np.hstack((data[ileft:left], data[right:iright]))
intp = interpolate.interp1d(x, y, kind=self.kind)
intp = interpolate.make_interp_spline(x, y, k=self.spline_order)
data[left:right] = intp(axis.axis[left:right])

# Add noise
Expand Down Expand Up @@ -1882,17 +1879,11 @@ def remove_all_spikes(self):

@add_gui_method(toolkey="hyperspy.Signal1D.spikes_removal_tool")
class SpikesRemovalInteractive(SpikesRemoval, SpanSelectorInSignal1D):
interpolator_kind = t.Enum(
'Linear',
'Spline',
default='Linear',
desc="the type of interpolation to use when\n"
"replacing the signal where a spike has been replaced")
threshold = t.Float(400, desc="the derivative magnitude threshold above\n"
"which to find spikes")
click_to_show_instructions = t.Button()
show_derivative_histogram = t.Button()
spline_order = t.Range(1, 10, 3,
spline_order = t.Range(1, 10, 1,
desc="the order of the spline used to\n"
"connect the reconstructed data")
interpolator = None
Expand Down Expand Up @@ -2013,19 +2004,13 @@ def on_disabling_span_selector(self):
self.interpolated_line = None

def _spline_order_changed(self, old, new):
self.kind = self.spline_order
self.span_selector_changed()
if new != old:
self.spline_order = new
self.span_selector_changed()

def _add_noise_changed(self, old, new):
self.span_selector_changed()

def _interpolator_kind_changed(self, old, new):
if new == 'linear':
self.kind = new
else:
self.kind = self.spline_order
self.span_selector_changed()

def create_interpolation_line(self):
self.interpolated_line = drawing.signal1d.Signal1DLine()
self.interpolated_line.data_function = self.get_interpolated_spectrum
Expand Down
32 changes: 16 additions & 16 deletions hyperspy/tests/component/test_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,10 +366,10 @@ def test_both_unbinned(self):
m = s.create_model()
fp = hs.model.components1D.ScalableFixedPattern(s1)
m.append(fp)
with ignore_warning(message="invalid value encountered in sqrt",
category=RuntimeWarning):
m.fit()
assert abs(fp.yscale.value - 100) <= 0.1
fp.xscale.free = False
fp.shift.free = False
m.fit()
np.testing.assert_allclose(fp.yscale.value, 100)

@pytest.mark.parametrize(("uniform"), (True, False))
def test_both_binned(self, uniform):
Expand All @@ -383,10 +383,10 @@ def test_both_binned(self, uniform):
m = s.create_model()
fp = hs.model.components1D.ScalableFixedPattern(s1)
m.append(fp)
with ignore_warning(message="invalid value encountered in sqrt",
category=RuntimeWarning):
m.fit()
assert abs(fp.yscale.value - 100) <= 0.1
fp.xscale.free = False
fp.shift.free = False
m.fit()
np.testing.assert_allclose(fp.yscale.value, 100)

def test_pattern_unbinned_signal_binned(self):
s = self.s
Expand All @@ -396,10 +396,10 @@ def test_pattern_unbinned_signal_binned(self):
m = s.create_model()
fp = hs.model.components1D.ScalableFixedPattern(s1)
m.append(fp)
with ignore_warning(message="invalid value encountered in sqrt",
category=RuntimeWarning):
m.fit()
assert abs(fp.yscale.value - 1000) <= 1
fp.xscale.free = False
fp.shift.free = False
m.fit()
np.testing.assert_allclose(fp.yscale.value, 1000)

def test_pattern_binned_signal_unbinned(self):
s = self.s
Expand All @@ -409,10 +409,10 @@ def test_pattern_binned_signal_unbinned(self):
m = s.create_model()
fp = hs.model.components1D.ScalableFixedPattern(s1)
m.append(fp)
with ignore_warning(message="invalid value encountered in sqrt",
category=RuntimeWarning):
m.fit()
assert abs(fp.yscale.value - 10) <= .1
fp.xscale.free = False
fp.shift.free = False
m.fit()
np.testing.assert_allclose(fp.yscale.value, 10)

def test_function(self):
s = self.s
Expand Down
8 changes: 4 additions & 4 deletions hyperspy/tests/signals/test_1D_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,8 @@ def test_shift1D(self):
# Check that at the edges of the spectrum the value == to the
# background value. If it wasn't it'll mean that the cropping
# code is buggy
assert (s.data[:, -1] == 2).all()
assert (s.data[:, 0] == 2).all()
np.testing.assert_allclose(s.data[:, -1], 2)
np.testing.assert_allclose(s.data[:, 0], 2)
# Check that the calibration is correct
assert s.axes_manager._axes[1].offset == self.new_offset
assert s.axes_manager._axes[1].scale == self.scale
Expand All @@ -82,8 +82,8 @@ def test_align(self):
# Check that at the edges of the spectrum the value == to the
# background value. If it wasn't it'll mean that the cropping
# code is buggy
assert (s.data[:, -1] == 2).all()
assert (s.data[:, 0] == 2).all()
np.testing.assert_allclose(s.data[:, -1], 2)
np.testing.assert_allclose(s.data[:, 0], 2)
# Check that the calibration is correct
assert (
s.axes_manager._axes[1].offset == self.new_offset)
Expand Down

0 comments on commit 778de67

Please sign in to comment.