Skip to content

Commit

Permalink
Fix inconsistency in value2index for uniform and non-uniform axis.
Browse files Browse the repository at this point in the history
  • Loading branch information
ericpre committed Jun 30, 2021
1 parent 3561a67 commit 0659013
Show file tree
Hide file tree
Showing 4 changed files with 126 additions and 28 deletions.
25 changes: 19 additions & 6 deletions hyperspy/axes.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
numba_closest_index_floor,
numba_closest_index_ceil,
round_half_towards_zero,
round_half_away_from_zero,
)
from hyperspy.misc.utils import isiterable, ordinal
from hyperspy.misc.math_tools import isfloat
Expand Down Expand Up @@ -1192,14 +1193,26 @@ def value2index(self, value, rounding=round):

value = self._parse_value(value)

multiplier = 1E12
index = 1 / multiplier * np.trunc(
(value - self.offset) / self.scale * multiplier
)

if rounding is round:
rounding = round_half_towards_zero
elif rounding is math.ceil:
rounding = np.ceil
elif rounding is math.floor:
rounding = np.floor
# When value are negative, we need to use half away from zero
# approach on the index, because the index is always positive
index = np.where(
value >= 0 if np.sign(self.scale) > 0 else value < 0,
round_half_towards_zero(index, decimals=0),
round_half_away_from_zero(index, decimals=0),
)
else:
if rounding is math.ceil:
rounding = np.ceil
elif rounding is math.floor:
rounding = np.floor

index = rounding((value - self.offset) / self.scale)
index = rounding(index)

if isinstance(value, np.ndarray):
index = index.astype(int)
Expand Down
44 changes: 36 additions & 8 deletions hyperspy/misc/array_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -470,15 +470,12 @@ def numba_closest_index_round(axis_array, value_array):
rtol = 1e-12
machineepsilon = np.min(np.abs(np.diff(axis_array))) * rtol
for i, v in enumerate(value_array.flat):
if v >= 0:
index_array.flat[i] = np.abs(axis_array - v + machineepsilon).argmin()
else:
index_array.flat[i] = np.abs(axis_array - v - machineepsilon).argmin()
index_array.flat[i] = np.abs(axis_array - v + np.sign(v) * machineepsilon).argmin()
return index_array


@njit(cache=True)
def numba_closest_index_floor(axis_array, value_array):
def numba_closest_index_floor(axis_array, value_array): # pragma: no cover
"""For each value in value_array, find the closest smaller value in
axis_array and return the result as a numpy array of the same shape
as value_array.
Expand All @@ -504,7 +501,7 @@ def numba_closest_index_floor(axis_array, value_array):


@njit(cache=True)
def numba_closest_index_ceil(axis_array, value_array):
def numba_closest_index_ceil(axis_array, value_array): # pragma: no cover
"""For each value in value_array, find the closest larger value in
axis_array and return the result as a numpy array of the same shape
as value_array.
Expand All @@ -528,7 +525,8 @@ def numba_closest_index_ceil(axis_array, value_array):
return index_array


def round_half_towards_zero(array, decimals=0):
@njit(cache=True)
def round_half_towards_zero(array, decimals=0): # pragma: no cover
"""
Round input array using "half towards zero" strategy.
Expand All @@ -546,4 +544,34 @@ def round_half_towards_zero(array, decimals=0):
An array of the same type as a, containing the rounded values.
"""
multiplier = 10 ** decimals
return np.ceil(array * multiplier - 0.5) / multiplier

return np.where(array >= 0,
np.ceil(array * multiplier - 0.5) / multiplier,
np.floor(array * multiplier + 0.5) / multiplier
)


@njit(cache=True)
def round_half_away_from_zero(array, decimals=0): # pragma: no cover
"""
Round input array using "half away from zero" strategy.
Parameters
----------
array : ndarray
Input array.
decimals : int, optional
Number of decimal places to round to (default: 0).
Returns
-------
rounded_array : ndarray
An array of the same type as a, containing the rounded values.
"""
multiplier = 10 ** decimals

return np.where(array >= 0,
np.floor(array * multiplier + 0.5) / multiplier,
np.ceil(array * multiplier - 0.5) / multiplier
)
43 changes: 29 additions & 14 deletions hyperspy/tests/axes/test_data_axis.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,7 +483,7 @@ def test_deepcopy_on_trait_change(self):
def test_uniform_value2index(self):
#Tests for value2index
#Works as intended
assert self.axis.value2index(10.15) == 2
assert self.axis.value2index(10.15) == 1
assert self.axis.value2index(10.17, rounding=math.floor) == 1
assert self.axis.value2index(10.13, rounding=math.ceil) == 2
# Test that output is integer
Expand Down Expand Up @@ -516,14 +516,14 @@ def test_uniform_value2index(self):
#Arrays work as intended
arval = np.array([[10.15, 10.15], [10.24, 10.28]])
assert np.all(self.axis.value2index(arval) \
== np.array([[2, 2], [2, 3]]))
== np.array([[1, 1], [2, 3]]))
assert np.all(self.axis.value2index(arval, rounding=math.floor) \
== np.array([[1, 1], [2, 2]]))
assert np.all(self.axis.value2index(arval, rounding=math.ceil)\
== np.array([[2, 2], [3, 3]]))
#List in --> array out
assert np.all(self.axis.value2index(arval.tolist()) \
== np.array([[2, 2], [2, 3]]))
== np.array([[1, 1], [2, 3]]))
#One value out of bound in array in --> error out (both sides)
arval[1,1] = 111
with pytest.raises(ValueError):
Expand Down Expand Up @@ -744,14 +744,29 @@ def test_value_range_to_indices_v1_greater_than_v2(self):


def test_rounding_consistency_axis_type():
inax = [[-11.0, -10.9],
[-10.9, -11.0],
[+10.9, +11.0],
[+11.0, +10.9]]
inval = [-10.95, -10.95, 10.95, 10.95]

for i, j in zip(inax, inval):
ax = UniformDataAxis(scale=i[1]-i[0], offset=i[0], size=len(i))
nua_idx = super(type(ax),ax).value2index(j, rounding=round)
unif_idx = ax.value2index(j, rounding=round)
assert nua_idx == unif_idx
scales = [0.1, -0.1, 0.1, -0.1]
offsets = [-11.0, -10.9, 10.9, 11.0]
values = [-10.95, -10.95, 10.95, 10.95]

for i, (scale, offset, value) in enumerate(zip(scales, offsets, values)):
ax = UniformDataAxis(scale=scale, offset=offset, size=3)
ua_idx = ax.value2index(value)
nua_idx = super(type(ax), ax).value2index(value)
print('scale', scale)
print('offset', offset)
print('Axis values:', ax.axis)
print(f"value: {value} --> uniform: {ua_idx}, non-uniform: {nua_idx}")
print("\n")
assert nua_idx == ua_idx


@pytest.mark.parametrize('shift', (0.05, 0.025))
def test_rounding_consistency_axis_type_half(shift):

axis = UniformDataAxis(size=20, scale=0.1, offset=-1.0);
test_vals = axis.axis[:-1] + shift

uaxis_indices = axis.value2index(test_vals)
nuaxis_indices = super(type(axis), axis).value2index(test_vals)

np.testing.assert_allclose(uaxis_indices, nuaxis_indices)
42 changes: 42 additions & 0 deletions hyperspy/tests/misc/test_array_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
get_array_memory_size_in_GiB,
get_signal_chunk_slice,
numba_histogram,
round_half_towards_zero,
round_half_away_from_zero,
)

dt = [("x", np.uint8), ("y", np.uint16), ("text", (bytes, 6))]
Expand Down Expand Up @@ -188,8 +190,48 @@ def test_get_signal_chunk_slice_not_square(sig_chunks, index, expected):
chunk_slice = get_signal_chunk_slice(index, data.chunks)
assert chunk_slice == expected


@pytest.mark.parametrize('dtype', ['<u2', 'u2', '>u2', '<f4', 'f4', '>f4'])
def test_numba_histogram(dtype):
arr = np.arange(100, dtype=dtype)
np.testing.assert_array_equal(numba_histogram(arr, 5, (0, 100)), [20, 20, 20, 20, 20])


def test_round_half_towards_zero_integer():
a = np.array([-2.0, -1.7, -1.5, -0.2, 0.0, 0.2, 1.5, 1.7, 2.0])
np.testing.assert_allclose(
round_half_towards_zero(a, decimals=0),
np.array([-2.0, -2.0, -1.0, 0.0, 0.0, 0.0, 1.0, 2.0, 2.0])
)
np.testing.assert_allclose(
round_half_towards_zero(a, decimals=0),
round_half_towards_zero(a)
)


def test_round_half_towards_zero():
a = np.array([-2.01, -1.56, -1.55, -1.50, -0.22, 0.0, 0.22, 1.50, 1.55, 1.56, 2.01])
np.testing.assert_allclose(
round_half_towards_zero(a, decimals=1),
np.array([-2.0, -1.6, -1.5, -1.5, -0.2, 0.0, 0.2, 1.5, 1.5, 1.6, 2.0])
)


def test_round_half_away_from_zero_integer():
a = np.array([-2.0, -1.7, -1.5, -0.2, 0.0, 0.2, 1.5, 1.7, 2.0])
np.testing.assert_allclose(
round_half_away_from_zero(a, decimals=0),
np.array([-2.0, -2.0, -2.0, 0.0, 0.0, 0.0, 2.0, 2.0, 2.0])
)
np.testing.assert_allclose(
round_half_away_from_zero(a, decimals=0),
round_half_away_from_zero(a)
)


def test_round_half_away_from_zero():
a = np.array([-2.01, -1.56, -1.55, -1.50, -0.22, 0.0, 0.22, 1.50, 1.55, 1.56, 2.01])
np.testing.assert_allclose(
round_half_away_from_zero(a, decimals=1),
np.array([-2.0, -1.6, -1.6, -1.5, -0.2, 0.0, 0.2, 1.5, 1.6, 1.6, 2.0])
)

0 comments on commit 0659013

Please sign in to comment.