Skip to content

Commit

Permalink
fix type selection
Browse files Browse the repository at this point in the history
  • Loading branch information
Czaki committed Nov 26, 2023
1 parent 373bd4f commit 322743b
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 4 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
*.swp
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
Expand Down
19 changes: 16 additions & 3 deletions napari_animation/_tests/test_interpolation.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import numbers
from dataclasses import asdict
from typing import NamedTuple

Expand All @@ -14,17 +15,29 @@
from napari_animation.interpolation.utils import nested_assert_close


def _expected_type(a, b):
if isinstance(a, numbers.Integral) and isinstance(b, numbers.Real):
return type(b)
return type(a)


# Actual tests
@pytest.mark.parametrize("a", [0.0, 0])
@pytest.mark.parametrize("b", [100.0, 100])
@pytest.mark.parametrize("a", [0.0, 0, np.float32(0)])
@pytest.mark.parametrize("b", [100.0, 100, np.float32(100)])
@pytest.mark.parametrize("fraction", [0, 0.0, 0.5, 1.0, 1])
def test_interpolate_num(a, b, fraction):
"""Check that interpolation of numbers produces valid output"""
result = interpolate_num(a, b, fraction)
assert isinstance(result, type(a))
assert isinstance(result, _expected_type(a, b))
assert result == fraction * b


@pytest.mark.parametrize("b", [1.0, np.float32(1)])
def test_interpolate_proper_type(b):
result = interpolate_num(0, b, 0.5)
assert np.isclose(result, 0.5)


@pytest.mark.parametrize("a,b", [([0.0, 0.0, 0.0], [1.0, 1.0, 1.0])])
@pytest.mark.parametrize(
"fraction,expected",
Expand Down
2 changes: 1 addition & 1 deletion napari_animation/interpolation/base_interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def interpolate_num(a: Number, b: Number, fraction: float) -> Number:
Interpolated value between a and b at fraction.
"""
number_cls = type(a)
if isinstance(b, Real):
if isinstance(b, Real) and isinstance(a, Integral):
number_cls = type(b)
return number_cls(a + (b - a) * fraction)

Expand Down

0 comments on commit 322743b

Please sign in to comment.