Skip to content

Commit

Permalink
last review comments, stack for FDA
Browse files Browse the repository at this point in the history
  • Loading branch information
jlaehne committed May 9, 2021
1 parent 94cf419 commit 967fe74
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 10 deletions.
29 changes: 22 additions & 7 deletions hyperspy/misc/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -934,7 +934,8 @@ def stack(signal_list, axis=None, new_axis_name="stack_element", lazy=None,
stacking axis of the first signal is uniform, it is extended up to the
new length; if it is non-uniform, the axes vectors of all signals are
concatenated along this direction; if it is a `FunctionalDataAxis`,
it is converted to a non-uniform DataAxis and treated as such.
it is extended based on the expression of the first signal (and its sub
axis `x` is handled as above depending on whether it is uniform or not).
new_axis_name : str
The name of the new axis when `axis` is None.
If an axis with this name already
Expand Down Expand Up @@ -963,7 +964,7 @@ def stack(signal_list, axis=None, new_axis_name="stack_element", lazy=None,
"""
from hyperspy.signals import BaseSignal
from hyperspy.axes import FunctionalDataAxis
from hyperspy.axes import FunctionalDataAxis, UniformDataAxis, DataAxis
import dask.array as da
from numbers import Number

Expand Down Expand Up @@ -1009,13 +1010,11 @@ def stack(signal_list, axis=None, new_axis_name="stack_element", lazy=None,
# Matching axis calibration is checked here
broadcasted_sigs = broadcast_signals(*signal_list, ignore_axis=axis_input)

if axis is not None:
if axis_input is not None:
step_sizes = [s.axes_manager[axis_input].size for s in broadcasted_sigs]
axis = broadcasted_sigs[0].axes_manager[axis_input]
if not axis.is_uniform:
# stack axes if non-uniform (convert to DataAxis if FunctionalDataAxis)
if type(axis) is FunctionalDataAxis:
axis.axes_manager[axis_input].convert_to_non_uniform_axis()
# stack axes if non-uniform (DataAxis)
if type(axis) is DataAxis:
for _s in signal_list[1:]:
_axis = _s.axes_manager[axis_input]
if (axis.axis[0] < axis.axis[-1] and axis.axis[-1] < _axis.axis[0]) \
Expand All @@ -1025,6 +1024,22 @@ def stack(signal_list, axis=None, new_axis_name="stack_element", lazy=None,
raise ValueError("Signals can only be stacked along a "
"non-uniform axes if the axis values do not overlap"
" and have the correct order.")
# stack axes if FunctionalDataAxis and its x axis is uniform
elif type(axis) is FunctionalDataAxis and \
type(axis.axes_manager[axis_input].x) is UniformDataAxis:
axis.x.size = np.sum(step_sizes)
# stack axes if FunctionalDataAxis and its x axis is not uniform
elif type(axis) is FunctionalDataAxis and \
type(axis.axes_manager[axis_input].x) is DataAxis:
for _s in signal_list[1:]:
_axis = _s.axes_manager[axis_input]
if (axis.x.axis[0] < axis.x.axis[-1] and axis.x.axis[-1] < _axis.x.axis[0]) \
or (axis.x.axis[-1] < axis.x.axis[0] and _axis.x.axis[-1] < axis.x.axis[0]):
axis.x.axis = np.concatenate((axis.x.axis, _axis.x.axis))
else:
raise ValueError("Signals can only be stacked along a "
"non-uniform axes if the axis values do not overlap"
" and have the correct order.")

datalist = [s.data for s in broadcasted_sigs]
newdata = (
Expand Down
4 changes: 2 additions & 2 deletions hyperspy/signal.py
Original file line number Diff line number Diff line change
Expand Up @@ -4284,8 +4284,8 @@ def integrate1D(self, axis, out=None):
The integration is performed using
`Simpson's rule <https://en.wikipedia.org/wiki/Simpson%%27s_rule>`_ if
`axis.is_binned` is ``False`` or `axis.is_uniform` is ``False`` and
simple summation over the given axis if both are ``True``.
`axis.is_binned` is ``False`` and simple summation over the given axis
if ``True``.
Parameters
----------
Expand Down
11 changes: 10 additions & 1 deletion hyperspy/tests/utils/test_stack.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,10 +190,19 @@ def test_stack_non_uniform_axis(self):
s2.axes_manager[2].axis = s2.axes_manager[2].axis[::-1]
rs = utils.stack([s2, s], axis=2)
assert rs.axes_manager[2].axis.size == rs.data.shape[2]
# Test stacking of functional data axes

def test_stack_functional_data_axis(self):
s = self.signal
s2 = s.deepcopy()
# Test stacking of functional data axes with uniform x vector
s.axes_manager[0].convert_to_functional_data_axis(expression='x')
s2.axes_manager[0].offset = 2
s2.axes_manager[0].convert_to_functional_data_axis(expression='x')
rs = utils.stack([s, s2], axis=0)
assert rs.axes_manager[0].axis.size == rs.data.shape[1]
# Test stacking of functional data axes with uniform x vector
s.axes_manager[0].x.convert_to_non_uniform_axis()
s2.axes_manager[0].x.convert_to_non_uniform_axis()
rs = utils.stack([s, s2], axis=0)
assert rs.axes_manager[0].axis.size == rs.data.shape[1]

0 comments on commit 967fe74

Please sign in to comment.