diff --git a/.gitignore b/.gitignore index c6537d0..e82e5bb 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ +*.swp # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] diff --git a/napari_animation/_tests/test_interpolation.py b/napari_animation/_tests/test_interpolation.py index 1ec94e8..c4d62fb 100644 --- a/napari_animation/_tests/test_interpolation.py +++ b/napari_animation/_tests/test_interpolation.py @@ -1,3 +1,4 @@ +import numbers from dataclasses import asdict from typing import NamedTuple @@ -14,17 +15,29 @@ from napari_animation.interpolation.utils import nested_assert_close +def _expected_type(a, b): + if isinstance(a, numbers.Integral) and isinstance(b, numbers.Real): + return type(b) + return type(a) + + # Actual tests -@pytest.mark.parametrize("a", [0.0, 0]) -@pytest.mark.parametrize("b", [100.0, 100]) +@pytest.mark.parametrize("a", [0.0, 0, np.float32(0)]) +@pytest.mark.parametrize("b", [100.0, 100, np.float32(100)]) @pytest.mark.parametrize("fraction", [0, 0.0, 0.5, 1.0, 1]) def test_interpolate_num(a, b, fraction): """Check that interpolation of numbers produces valid output""" result = interpolate_num(a, b, fraction) - assert isinstance(result, type(a)) + assert isinstance(result, _expected_type(a, b)) assert result == fraction * b +@pytest.mark.parametrize("b", [1.0, np.float32(1)]) +def test_interpolate_proper_type(b): + result = interpolate_num(0, b, 0.5) + assert np.isclose(result, 0.5) + + @pytest.mark.parametrize("a,b", [([0.0, 0.0, 0.0], [1.0, 1.0, 1.0])]) @pytest.mark.parametrize( "fraction,expected", diff --git a/napari_animation/interpolation/base_interpolation.py b/napari_animation/interpolation/base_interpolation.py index 2996c00..bb2213f 100644 --- a/napari_animation/interpolation/base_interpolation.py +++ b/napari_animation/interpolation/base_interpolation.py @@ -1,4 +1,4 @@ -from numbers import Number +from numbers import Integral, Number, Real from typing import Sequence, Tuple, TypeVar import numpy as np @@ -85,6 +85,8 @@ def interpolate_num(a: Number, b: Number, fraction: float) -> Number: Interpolated value between a and b at fraction. """ number_cls = type(a) + if isinstance(b, Real) and isinstance(a, Integral): + number_cls = type(b) return number_cls(a + (b - a) * fraction)