Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

non-uniform updates (NotImplementedErrors, tests, spelling, saving-fix) #2761

Merged
merged 10 commits into from
Jun 3, 2021
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
27 changes: 21 additions & 6 deletions hyperspy/_signals/signal1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,10 +513,12 @@ def interpolate_in_between(
Parameters
----------
start, end : int or float
The limits of the interval. If int they are taken as the
axis index. If float they are taken as the axis value.
The limits of the interval. If int, they are taken as the
axis index. If float, they are taken as the axis value.
delta : int or float
The windows around the (start, end) to use for interpolation
The windows around the (start, end) to use for interpolation. If
int, they are taken as index steps. If float, they are taken in
units of the axis value.
%s
%s
%s
Expand All @@ -537,9 +539,15 @@ def interpolate_in_between(
i1 = axis._get_index(start)
i2 = axis._get_index(end)
if isinstance(delta, float):
delta = int(delta / axis.scale)
i0 = int(np.clip(i1 - delta, 0, np.inf))
i3 = int(np.clip(i2 + delta, 0, axis.size))
if isinstance(start, int):
start = axis.axis[start]
if isinstance(end, int):
end = axis.axis[end]
i0 = axis._get_index(start-delta) if start-delta < axis.low_value else 0
i3 = axis._get_index(end+delta) if end+delta > axis.high_value else axis.size
else:
i0 = int(np.clip(i1 - delta, 0, np.inf))
i3 = int(np.clip(i2 + delta, 0, axis.size))

def interpolating_function(dat):
dat_int = interpolate.interp1d(
Expand Down Expand Up @@ -616,12 +624,17 @@ def estimate_shift1D(
------
SignalDimensionError
If the signal dimension is not 1.
NotImplementedError
If the signal axis is a non-uniform axis.
"""
if show_progressbar is None:
show_progressbar = preferences.General.show_progressbar
self._check_signal_dimension_equals_one()
ip = number_of_interpolation_points + 1
axis = self.axes_manager.signal_axes[0]
if not axis.is_uniform:
raise NotImplementedError(
"The function is not implemented for non-uniform signal axes.")
self._check_navigation_mask(mask)
# we compute for now
if isinstance(start, da.Array):
Expand Down Expand Up @@ -863,6 +876,8 @@ def calibrate(self, display=True, toolkit=None):
------
SignalDimensionError
If the signal dimension is not 1.
NotImplementedError
If called with a non-uniform axes.
"""
self._check_signal_dimension_equals_one()
calibration = Signal1DCalibration(self)
Expand Down
2 changes: 1 addition & 1 deletion hyperspy/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -1245,7 +1245,7 @@ def print_current_values(self, only_free=False, fancy=True):
def _get_scaling_factor(signal, axis, parameter):
"""
Convenience function to get the scaling factor required to take into
account binned and/or non uniform axes.
account binned and/or non-uniform axes.

Parameters
----------
Expand Down
4 changes: 2 additions & 2 deletions hyperspy/datasets/artificial_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,10 +452,10 @@ def get_luminescence_signal(navigation_dimension=0,
if add_baseline:
data += 350.

#if not uniform, transformation into non-linear axis
#if not uniform, transformation into non-uniform axis
if not uniform:
hc = 1239.84198 #nm/eV
#converting to non uniform axis
#converting to non-uniform axis
sig.axes_manager.signal_axes[0].convert_to_functional_data_axis(\
expression="a/x",
name='Energy',
Expand Down
45 changes: 11 additions & 34 deletions hyperspy/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -753,6 +753,7 @@ def save(filename, signal, overwrite=None, **kwds):
# Check if the writer can write
sd = signal.axes_manager.signal_dimension
nd = signal.axes_manager.navigation_dimension
nua = signal.axes_manager.all_uniform

if writer.writes is False:
raise ValueError(
Expand All @@ -771,6 +772,14 @@ def save(filename, signal, overwrite=None, **kwds):
f"Please try one of {strlist2enumeration(yes_we_can)}"
)

if writer.non_uniform_axis is False and nua is False:
jlaehne marked this conversation as resolved.
Show resolved Hide resolved
yes_we_can = [plugin.format_name for plugin in io_plugins
ericpre marked this conversation as resolved.
Show resolved Hide resolved
if plugin.non_uniform_axis is True]
raise OSError("Writing to this format is not supported for "
ericpre marked this conversation as resolved.
Show resolved Hide resolved
"non-uniform axes. Use one of the following "
f"formats: {strlist2enumeration(yes_we_can)}"
)

# Create the directory if it does not exist
ensure_directory(filename.parent)
is_file = filename.is_file()
Expand All @@ -782,40 +791,8 @@ def save(filename, signal, overwrite=None, **kwds):
elif overwrite is False and is_file:
write = False # Don't write the file
else:
# Check if the writer can write
sd = signal.axes_manager.signal_dimension
nd = signal.axes_manager.navigation_dimension
nua = signal.axes_manager.all_uniform
if writer.writes is False:
raise ValueError('Writing to this format is not '
'supported, supported file extensions are: %s ' %
strlist2enumeration(default_write_ext))
if writer.writes is not True and (sd, nd) not in writer.writes:
yes_we_can = [plugin.format_name for plugin in io_plugins
if plugin.writes is True or
plugin.writes is not False and
(sd, nd) in plugin.writes]
raise IOError('This file format cannot write this data. '
'The following formats can: %s' %
strlist2enumeration(yes_we_can))
if writer.non_uniform_axis is False and nua is False:
yes_we_can = [plugin.format_name for plugin in io_plugins
if plugin.non_uniform_axis is True]
raise OSError('Writing to this format is not supported for non '
'uniform axes.'
'Use one of the following formats: %s' %
strlist2enumeration(yes_we_can))
ensure_directory(filename)
is_file = os.path.isfile(filename)
if overwrite is None:
write = overwrite_method(filename) # Ask what to do
elif overwrite is True or (overwrite is False and not is_file):
write = True # Write the file
elif overwrite is False and is_file:
write = False # Don't write the file
else:
raise ValueError("`overwrite` parameter can only be None, True or "
"False.")
raise ValueError("`overwrite` parameter can only be None, True or "
"False.")
if write:
# Pass as a string for now, pathlib.Path not
# properly supported in io_plugins
Expand Down
4 changes: 2 additions & 2 deletions hyperspy/io_plugins/README
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ All the read/write plugins must provide a python file containing:
writes_images = <Bool>
writes_spectrum = <Bool>
writes_spectrum_image = <Bool>
# Support for non linear axis
non_linear_axis = <Bool>
# Support for non-uniform axis
non_uniform_axis = <Bool>

- A function called file_reader with at least one attribute: filename

Expand Down
4 changes: 2 additions & 2 deletions hyperspy/io_plugins/emd.py
Original file line number Diff line number Diff line change
Expand Up @@ -684,7 +684,7 @@ def _read_data_from_groups(self, group_path, dataset_name, stack_key=None,
simu_om.get('cellDimension', 0)[0])
if not math.isclose(total_thickness, len(array_list) * scale,
rel_tol=1e-4):
_logger.warning("Depth axis is non uniform and its offset "
_logger.warning("Depth axis is non-uniform and its offset "
"and scale can't be set accurately.")
# When non-uniform/non-linear axis are implemented, adjust
# the final depth to the "total_thickness"
Expand Down Expand Up @@ -798,7 +798,7 @@ def _parse_axis(axis_data):
offset, scale = axis_data[0], np.diff(axis_data).mean()
else:
# This is a string, return default values
# When non-linear axis is supported we should be able to parse
# When non-uniform axis is supported we should be able to parse
# string
offset, scale = 0, 1
return offset, scale
Expand Down
2 changes: 1 addition & 1 deletion hyperspy/io_plugins/hspy.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@
# CHANGES
#
# v3.1
# - add read support for non-linear DataAxis defined by 'axis' vector
# - add read support for non-uniform DataAxis defined by 'axis' vector
# - move metadata.Signal.binned attribute to axes.is_binned parameter
#
# v3.0
Expand Down
4 changes: 3 additions & 1 deletion hyperspy/io_plugins/jeol.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@
reads_spectrum_image = True
# Writing capabilities
writes = False
non_uniform_axis = False
# ----------------------


jTYPE = {
Expand Down Expand Up @@ -724,4 +726,4 @@ def read_eds(filename, **kwargs):
extension_to_reader_mapping = {"img": read_img,
"map": read_img,
"pts": read_pts,
"eds": read_eds}
"eds": read_eds}
2 changes: 1 addition & 1 deletion hyperspy/models/model1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ def low_loss(self, value):
'navigation dimension as the core-loss.')
if not value.axes_manager.signal_axes[0].is_uniform:
raise ValueError('Low loss convolution is not supported with '
'non linear signal axes.')
'non-uniform signal axes.')
self._low_loss = value
self.set_convolution_axis()
self.convolved = True
Expand Down
2 changes: 1 addition & 1 deletion hyperspy/signal_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ def __init__(self, signal):
raise SignalDimensionError(
signal.axes_manager.signal_dimension, 1)
if not isinstance(self.axis, UniformDataAxis):
raise ValueError("The calibration tool supports only uniform axes.")
raise NotImplementedError("The calibration tool supports only uniform axes.")
self.units = self.axis.units
self.scale = self.axis.scale
self.offset = self.axis.offset
Expand Down
17 changes: 16 additions & 1 deletion hyperspy/tests/io/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,10 @@ def test_io_overwriting_None_existing_file_n(self):
self.new_s.save(FULLFILENAME)
assert not self._check_file_is_written(FULLFILENAME)

def test_io_overwriting_invalid_parameter(self):
with pytest.raises(ValueError, match="parameter can only be"):
self.new_s.save(FULLFILENAME, overwrite="spam")

def teardown_method(self, method):
self._clean_file()

Expand All @@ -100,7 +104,7 @@ def setup_method(self, method):
def test_io_nonuniform(self):
assert(self.s.axes_manager[0].is_uniform == False)
self.s.save('tmp.hspy', overwrite = True)
with pytest.raises(AttributeError):
with pytest.raises(OSError):
self.s.save('tmp.msa', overwrite = True)

def test_nonuniform_writer_characteristic(self):
Expand All @@ -111,6 +115,17 @@ def test_nonuniform_writer_characteristic(self):
print(plugin.format_name + ' IO-plugin is missing the '
'characteristic `non_uniform_axis`')

def test_nonuniform_error(self):
assert(self.s.axes_manager[0].is_uniform == False)
no_we_cant = [plugin.file_extensions[plugin.default_extension] for
plugin in io_plugins if (plugin.writes is True or
plugin.writes is not False and (1, 0) in plugin.writes)
and plugin.non_uniform_axis is False]
for ext in no_we_cant:
with pytest.raises(OSError, match = "not supported for"):
filename = 'tmp.' + ext
self.s.save(filename, overwrite = True)

def teardown_method(self):
if os.path.exists('tmp.hspy'):
os.remove('tmp.hspy')
Expand Down
34 changes: 32 additions & 2 deletions hyperspy/tests/signals/test_1D_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,11 +218,26 @@ def test_single_spectrum(self):
np.testing.assert_array_equal(s.data, np.arange(20))
assert m.data_changed.called

def test_single_spectrum_nonuniform(self):
jlaehne marked this conversation as resolved.
Show resolved Hide resolved
s = self.s.inav[0]
m = mock.Mock()
s.events.data_changed.connect(m.data_changed)
s.axes_manager[-1].convert_to_non_uniform_axis()
s.interpolate_in_between(8, 12)
np.testing.assert_array_equal(s.data, np.arange(20))
assert m.data_changed.called

def test_single_spectrum_in_units(self):
s = self.s.inav[0]
s.interpolate_in_between(0.8, 1.2)
np.testing.assert_array_equal(s.data, np.arange(20))

def test_single_spectrum_in_units_nonuniform(self):
s = self.s.inav[0]
s.axes_manager[-1].convert_to_non_uniform_axis()
s.interpolate_in_between(0.8, 1.2)
np.testing.assert_array_equal(s.data, np.arange(20))

def test_two_spectra(self):
s = self.s
s.interpolate_in_between(8, 12)
Expand All @@ -248,8 +263,23 @@ def test_delta_float(self):
s.interpolate_in_between(8, 12, delta=0.31, kind='cubic')
print(s.data[8:12])
np.testing.assert_allclose(
s.data[8:12], np.array([45.09388598, 104.16170809,
155.48258721, 170.33564422]),
s.data[8:12], np.array([46.595205, 109.802805,
164.512803, 178.615201]),
atol=1,
)

def test_delta_float_nonuniform(self):
s = self.s.inav[0]
s.change_dtype('float')
tmp = np.zeros(s.data.shape)
tmp[12] = s.data[12]
s.data += tmp * 9.
s.axes_manager[0].convert_to_non_uniform_axis()
s.interpolate_in_between(8, 12, delta=0.31, kind='cubic')
print(s.data[8:12])
np.testing.assert_allclose(
s.data[8:12], np.array([46.595205, 109.802805,
164.512803, 178.615201]),
atol=1,
)

Expand Down
4 changes: 4 additions & 0 deletions hyperspy/tests/test_non-uniform_not-implemented.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,12 @@ def test_signal():
def test_signal1d():
s = Signal1D((1))
s.axes_manager[0].convert_to_non_uniform_axis()
with pytest.raises(NotImplementedError):
s.calibrate()
with pytest.raises(NotImplementedError):
s.shift1D([1])
with pytest.raises(NotImplementedError):
s.estimate_shift1D([1])
with pytest.raises(NotImplementedError):
s.smooth_savitzky_golay()
with pytest.raises(NotImplementedError):
Expand Down