diff --git a/docs/recipes.rst b/docs/recipes.rst index 4db3a3f..0ad1031 100644 --- a/docs/recipes.rst +++ b/docs/recipes.rst @@ -145,3 +145,54 @@ constructor dict values are still set on the class for compatibility. Btw, `copy` will *just work* for strict mocks and does not raise an error when not configured/expected. This is just not implemented and considered not-worth-the-effort. + + +Shared setUp stubs with tearDown safety checks +---------------------------------------------- + +Sometimes you have one "big" fixture / ``setUp`` that configures reusable stubs. + +Only some tests actually need all of them, but you also want to call +``verifyStubbedInvocationsAreUsed()`` or ``ensureNoUnverifiedInteractions()`` +unconditionally as your safety net on ``tearDown``. + +Yeah, I hate that but we need to be realistic. Use ``between=(0,)`` like so:: + + class TestService: + def setUp(self): + self.client = mock() + when(self.client).fetch("/warmup").thenReturn({"ok": True}) + ... # more + + def tearDown(self): + verify(self.client, between=(0,)).fetch(...) # mark as ok! + verifyStubbedInvocationsAreUsed() + ensureNoUnverifiedInteractions() + + +Speccing from ``typing.Protocol`` +--------------------------------- + +If your production code uses ``typing.Protocol`` interfaces, you can use them +as ``mock(spec=...)`` input directly:: + + from typing import Protocol + from mockito import mock, when + + class Service(Protocol): + async def fetch(self, path: str) -> str: + ... + + def close(self) -> bool: + ... + + service = mock(Service) + when(service).fetch('/health').thenReturn('ok') + when(service).close().thenReturn(True) + + assert await service.fetch('/health') == 'ok' # async stays async + assert service.close() is True # sync stays sync + +Such mocks are strict by default, so unknown methods and invalid call signatures +still fail early. + diff --git a/mockito/invocation.py b/mockito/invocation.py index f2463f4..a5e2af6 100644 --- a/mockito/invocation.py +++ b/mockito/invocation.py @@ -25,7 +25,6 @@ import inspect import operator from collections import deque -from functools import cached_property from typing import TYPE_CHECKING, Union from . import matchers, signature @@ -137,11 +136,12 @@ def __call__(self, *params: Any, **named_params: Any) -> Any | None: self._remember_params(params_without_first_arg, named_params) self.mock.remember(self) - matching_invocation = self._find_best_matching_stubbed_invocation() - if matching_invocation is not None: - matching_invocation.should_answer(self) - matching_invocation.capture_arguments(self) - return matching_invocation.answer_first(*params, **named_params) + for matching_invocation in self.mock.stubbed_invocations: + if matching_invocation.matches(self): + matching_invocation.should_answer(self) + matching_invocation.capture_arguments(self) + return matching_invocation.answer_first( + *params, **named_params) if self.strict: stubbed_invocations = [ @@ -170,21 +170,6 @@ def __call__(self, *params: Any, **named_params: Any) -> Any | None: return None - def _find_best_matching_stubbed_invocation(self) -> StubbedInvocation | None: - candidates = [ - candidate - for candidate in self.mock.stubbed_invocations - if candidate.matches(self) - ] - - if not candidates: - return None - - if len(candidates) == 1: - return candidates[0] - - return max(candidates, key=lambda candidate: candidate.specificity_score) - class RememberedPropertyAccess(RememberedInvocation): def ensure_mocked_object_has_method(self, method_name): @@ -409,6 +394,9 @@ def verification_has_lower_bound_of_zero( ): return True + if isinstance(verification, verificationModule.AtMost): + return True + if ( isinstance(verification, verificationModule.Between) and verification.wanted_from == 0 @@ -477,10 +465,7 @@ def __init__( if strict is not None: self.strict = strict - self.refers_coroutine = ( - is_coroutine_method(mock.peek_original_method(method_name)) - or mock.is_marked_as_coroutine(method_name) - ) + self.refers_coroutine = mock.method_expects_awaitable(method_name) self.discard_first_arg = mock.will_have_self_or_cls(method_name) default_answer = ( return_awaitable(None) if self.refers_coroutine else return_(None) @@ -523,33 +508,6 @@ def __call__(self, *params: Any, **named_params: Any) -> AnswerSelector: self.mock.finish_stubbing(self) return AnswerSelector(self, self.refers_coroutine, self.discard_first_arg) - @cached_property - def specificity_score(self) -> tuple[int, int]: - quality = 0 - - for value in self.params: - if value is not matchers.ARGS_SENTINEL: - quality += self._specificity_score(value) - - for key, value in self.named_params.items(): - if key is not matchers.KWARGS_SENTINEL: - quality += self._specificity_score(value) - - coverage = len(self.params) + len(self.named_params) - return coverage, quality - - def _specificity_score(self, value: object) -> int: - if value is Ellipsis: - return 0 - - if isinstance(value, matchers.Any) and value.wanted_type is None: - return 0 - - if isinstance(value, matchers.Matcher): - return 1 - - return 3 - def forget_self(self) -> None: if self in self.mock.stubbed_invocations: self.mock.forget_stubbed_invocation(self) @@ -719,7 +677,7 @@ def __call__(self, *params, **named_params): def create_chain_mock() -> tuple[object, Mock]: from .mocking import mock - chain_root = mock() + chain_root = mock(strict=True) theMock = mock_registry.mock_for(chain_root) assert theMock is not None, "Missing chain mock registry entry" return chain_root, theMock @@ -737,13 +695,6 @@ async def answer(*args, **kwargs) -> T: return answer -def is_coroutine_method(method: Any) -> bool: - if isinstance(method, (staticmethod, classmethod)): - method = method.__func__ - - return inspect.iscoroutinefunction(method) - - def raise_(exception: Exception | type[Exception]) -> Callable[..., NoReturn]: def answer(*args, **kwargs) -> NoReturn: raise exception diff --git a/mockito/mocking.py b/mockito/mocking.py index 5772f7e..3bf1689 100644 --- a/mockito/mocking.py +++ b/mockito/mocking.py @@ -374,7 +374,7 @@ def new_mocked_method(*args, **kwargs): pass if ( - _is_coroutine_method(original_method) + self.method_expects_awaitable(method_name, original_method) and SUPPORTS_MARKCOROUTINEFUNCTION ): new_mocked_method = inspect.markcoroutinefunction(new_mocked_method) @@ -452,7 +452,7 @@ def forget_stubbed_invocation( if self.stubbed_invocations: return - mock_registry.unstub_mock(self) + mock_registry.unstub(self.mocked_obj) def restore_method(self, method_name: str, original_method: object) -> None: if original_method is _MISSING_ATTRIBUTE: @@ -477,6 +477,23 @@ def mark_as_coroutine(self, method_name: str) -> None: def is_marked_as_coroutine(self, method_name: str) -> bool: return method_name in self._methods_marked_as_coroutine + def method_expects_awaitable( + self, + method_name: str, + original_method: object | None = None, + ) -> bool: + if original_method is None: + original_method = self.peek_original_method(method_name) + + return ( + _is_coroutine_method(original_method) + or self.is_marked_as_coroutine(method_name) + or ( + self.spec is None + and method_name in _ASYNC_BY_PROTOCOL_METHODS + ) + ) + def has_method(self, method_name: str) -> bool: if self.spec is None: return True diff --git a/tests/async_protocol_methods_test.py b/tests/async_protocol_methods_test.py new file mode 100644 index 0000000..b48813d --- /dev/null +++ b/tests/async_protocol_methods_test.py @@ -0,0 +1,65 @@ +import asyncio +import inspect + +import pytest + +from mockito import mock, when + + +pytestmark = pytest.mark.usefixtures("unstub") + + +def run(coro): + return asyncio.run(coro) + + +async def _use_async_resource(resource): + async with resource as entered: + return entered + + +async def _collect_async_iter(values): + seen = [] + async for value in values: + seen.append(value) + return seen + + +def test_when_thenReturn_on_ad_hoc_mock_aenter_and_aexit_are_awaitable(): + resource = mock() + when(resource).__aenter__().thenReturn(resource) + when(resource).__aexit__(..., ..., ...).thenReturn(False) + + pending_enter = resource.__aenter__() + assert inspect.isawaitable(pending_enter) + assert run(pending_enter) is resource + + pending_exit = resource.__aexit__(None, None, None) + assert inspect.isawaitable(pending_exit) + assert run(pending_exit) is False + + +def test_when_thenReturn_on_ad_hoc_mock_supports_async_with(): + resource = mock() + entered = object() + when(resource).__aenter__().thenReturn(entered) + when(resource).__aexit__(..., ..., ...).thenReturn(False) + + assert run(_use_async_resource(resource)) is entered + + +def test_when_thenReturn_on_ad_hoc_mock_anext_is_awaitable(): + values = mock() + when(values).__anext__().thenReturn(1) + + pending = values.__anext__() + assert inspect.isawaitable(pending) + assert run(pending) == 1 + + +def test_when_thenReturn_on_ad_hoc_mock_supports_async_for(): + values = mock() + when(values).__aiter__().thenReturn(values) + when(values).__anext__().thenReturn(1).thenReturn(2).thenRaise(StopAsyncIteration) + + assert run(_collect_async_iter(values)) == [1, 2] diff --git a/tests/call_original_implem_test.py b/tests/call_original_implem_test.py index 0075e29..aeb5cdd 100644 --- a/tests/call_original_implem_test.py +++ b/tests/call_original_implem_test.py @@ -2,8 +2,9 @@ import sys import pytest -from mockito import mock, when, verify, ArgumentError +from mockito import mock, when, verify from mockito.invocation import AnswerError +from mockito.verification import VerificationError from . import module from .test_base import TestBase @@ -110,12 +111,14 @@ def testDumbMockHasNoOriginalImplementations(self): def testDumbMockFailedThenCallOriginalImplementationDoesNotLeakStub(self): dog = mock() + with pytest.raises(VerificationError): + verify(dog).bark() with pytest.raises(AnswerError): when(dog).bark().thenCallOriginalImplementation() - with pytest.raises(ArgumentError): - verify(dog).bark(Ellipsis) + with pytest.raises(VerificationError): + verify(dog).bark() def testSpeccedMockHasOriginalImplementations(self): dog = mock({"huge": True}, spec=Dog) diff --git a/tests/chaining_test.py b/tests/chaining_test.py index 9acb5d4..c05de26 100644 --- a/tests/chaining_test.py +++ b/tests/chaining_test.py @@ -1,7 +1,7 @@ import pytest -from mockito import expect, mock, verify, when -from mockito.invocation import InvocationError +from mockito import expect, mock, verify, unstub, when +from mockito.invocation import AnswerError, InvocationError pytestmark = pytest.mark.usefixtures("unstub") @@ -106,6 +106,265 @@ def test_property_chaining_is_supported(): assert cat.age.value() == 14 assert cat.age.greater_than(12) is True +@pytest.mark.xfail(reason="Not implemented") +def test_deep_property_chain_with_method_leaf_is_supported(): + cat = mock() + + when(cat).age.expected.to.be(14).thenReturn(False) + + assert cat.age.expected.to.be(14) is False + + +@pytest.mark.xfail(reason="Not implemented") +def test_deep_property_chain_with_property_leaf_is_supported(): + cat = mock() + + when(cat).age.expected.to.value.thenReturn(14) + + assert cat.age.expected.to.value == 14 + + +@pytest.mark.xfail(reason="Not implemented") +def test_deep_property_chain_method_and_property_leaf_can_coexist(): + cat = mock() + + when(cat).age.expected.to.be(14).thenReturn(False) + when(cat).age.expected.to.value.thenReturn(14) + + assert cat.age.expected.to.be(14) is False + assert cat.age.expected.to.value == 14 + + +def test_unconfigured_context_manager_rewinds_1(): + cat = mock() + + assert cat.age() is None + + with pytest.raises((TypeError, AttributeError)): + with when(cat).age.expected.to.thenReturn: + pass + + assert cat.age() is None + + +def test_unconfigured_context_manager_rewinds_2(): + cat = mock() + + assert cat.age() is None + + with pytest.raises((TypeError, AttributeError)): + with when(cat).age.expected.to.leaf: + pass + + assert cat.age() is None + + +@pytest.mark.xfail(reason="Not implemented") +def test_failed_call_original_rewinds_1(): + cat = mock() + with pytest.raises(AnswerError): + when(cat).age.expected.to.value.thenCallOriginalImplementation() + + assert cat.age() is None + + +@pytest.mark.xfail(reason="Not implemented") +def test_failed_call_original_rewinds_2(): + cat = mock() + with pytest.raises(AnswerError): + when(cat).age.expected.to.be(14).thenCallOriginalImplementation() + + assert cat.age() is None + + +@pytest.mark.xfail(reason="Not implemented") +def test_failed_call_original_on_deep_property_leaf_rolls_back_only_leaf(): + cat = mock() + + when(cat).age.expected.to.be(14).thenReturn(False) + when(cat).age.expected.to.value.thenReturn(14) + + with pytest.raises(AnswerError) as exc: + when(cat).age.expected.to.value.thenCallOriginalImplementation() + + assert str(exc.value) == ( + "'.Dummy'>' " + "has no original implementation for 'value'." + ) + assert cat.age.expected.to.be(14) is False + assert cat.age.expected.to.value == 14 + + +@pytest.mark.xfail(reason="Not implemented") +def test_a(): + cat = mock() + assert cat.our() is None + + with when(cat).our.cat.named("spooky").is_very.brave.thenReturn(True): + assert cat.our.cat.named("spooky").is_very.brave is True + + assert cat.our() is None + with when(cat).our.cat.named("spooky").is_very.brave.thenReturn(True): + assert cat.our.cat.named("spooky").is_very.brave is True + + assert cat.our() is None + + +@pytest.mark.xfail(reason="Not implemented") +def test_b(): + cat = mock() + assert cat.our() is None + + with pytest.raises(AnswerError): + with when(cat).our.cat.named("spooky") \ + .is_very.brave.thenCallOriginalImplementation(): + ... + + assert cat.our() is None + + +@pytest.mark.xfail(reason="Not implemented") +def test_c(): + cat = mock() + assert cat.our() is None + + with pytest.raises(AnswerError): + when(cat).our.cat.named("spooky").is_very.brave.thenCallOriginalImplementation() + + assert cat.our() is None + + +@pytest.mark.xfail(reason="Not implemented") +def test_d(): + cat = mock() + assert cat.our() is None + + when(cat).our.cat.named("spooky") + with pytest.raises(AnswerError): + when(cat).our.cat.named("spooky").is_very.brave.thenCallOriginalImplementation() + + assert cat.our + assert cat.our.cat.named("spooky") is None + + +@pytest.mark.xfail(reason="Not implemented") +def test_e1(): + cat = mock() + assert cat.our() is None + + with when(cat).our.cat.named("spooky") \ + .is_very.brave.thenReturn(True).thenReturn(False): + assert cat.our.cat.named("spooky").is_very.brave is True + assert cat.our.cat.named("spooky").is_very.brave is False + assert cat.our.cat.named("spooky").is_very.brave is False + + assert cat.our() is None + + +@pytest.mark.xfail(reason="Not implemented") +def test_e2(): + cat = mock() + assert cat.our() is None + + when(cat).our.cat.named("spooky") + assert cat.our.cat.named("spooky") is None + + when(cat).our.cat.named("spooky").is_very.brave.thenReturn(True).thenReturn(False) + assert cat.our.cat.named("spooky") is not None + + unstub(cat) + assert cat.our() is None + + +@pytest.mark.xfail(reason="Not implemented") +def test_e3(): + cat = mock() + assert cat.our() is None + + when(cat).our.cat.named("spooky").is_very.brave.thenReturn(True).thenReturn(False) + assert cat.our.cat.named("spooky") is not None + + unstub(cat.our.cat) + with pytest.raises(AttributeError): + assert cat.our.cat.named("spooky") is not None + + +def test_f(): + cat = mock() + assert cat.our() is None + + with pytest.raises(AttributeError): + with when(cat).our.cat.named("spooky") \ + .is_very.brave.thenReturn(True).otherwise.null: + ... + + assert cat.our() is None + + +def test_g1(): # for illustration + cat = mock(strict=False) + assert hasattr(cat, "your") is True + assert cat.your # okay, not strict + + +def test_g1b(): # for illustration + cat = mock(strict=False) + expect(cat).your # <== doesn't change anything + assert hasattr(cat, "your") is True + assert cat.your + + +def test_g2(): # for illustration + cat = mock(strict=True) + assert hasattr(cat, "your") is False + with pytest.raises(AttributeError): + assert cat.your # 'your' is not configured + + +def test_g2b(): # for illustration + cat = mock(strict=True) + expect(cat).your # <== doesn't change anything + assert hasattr(cat, "your") is False + with pytest.raises(AttributeError): + assert cat.your # 'your' is not configured + + +@pytest.mark.xfail(reason="Needs decision") +def test_g3_non_strict_chain_child_stays_non_strict(): + cat = mock(strict=False) + + when(cat).our.cat.named("spooky").is_spooky + + spooky = cat.our.cat.named("spooky") + assert spooky is not None + assert hasattr(spooky, "is_spooky") is True + assert spooky.is_spooky + + +@pytest.mark.xfail(reason="Not implemented") +def test_g4_strict_chain_child_stays_strict(): + cat = mock(strict=True) + + when(cat).our.cat.named("spooky").is_spooky + + spooky = cat.our.cat.named("spooky") + assert spooky is not None + assert hasattr(spooky, "is_spooky") is False + with pytest.raises(AttributeError): # 'Dummy' has no attribute 'is_spooky' ... + assert spooky.is_spooky + + +@pytest.mark.xfail(reason="Not implemented") +def test_g5_ensure_we_unwind_to_previous_state(): + cat = mock() + expect(cat).our.cat.named("spooky") + assert cat.our.cat.named("spooky") is None + + with expect(cat).our.cat.named("spooky").is_spooky: + assert cat.our.cat.named("spooky") is not None + + assert cat.our.cat.named("spooky") is None + def test_context_manager_unwinds_method_chains_of_any_length(): cat = mock() @@ -200,3 +459,22 @@ def test_chain_matching_requires_candidate_matches_existing_direction(): assert cat.meow(2).purr() == "two" +def test_unexpected_chain_segment_arguments_raise_invocation_error_early(): + cat = mock() + + when(cat).meow().jump("bar").sleep().thenReturn("ok") + + with pytest.raises(InvocationError) as exc: + cat.meow().jump("baz").sleep() + + assert str(exc.value) == ( + "\nCalled but not expected:\n" + "\n" + " jump('baz')\n" + "\n" + "Stubbed invocations are:\n" + "\n" + " jump('bar')\n" + "\n" + ) + diff --git a/tests/in_order_test.py b/tests/in_order_test.py index 9e26331..f65524f 100644 --- a/tests/in_order_test.py +++ b/tests/in_order_test.py @@ -525,9 +525,10 @@ def test_in_order_verify_zero_lower_bound_does_not_fail_on_empty_queue( "verify_kwargs", [ {"times": 0}, + {"atmost": 2}, {"between": (0, 2)}, ], - ids=["times_0", "between_0_2"], + ids=["times_0", "atmost_2", "between_0_2"], ) def test_in_order_verify_zero_lower_bound_does_not_fail_when_all_calls_are_consumed( verify_kwargs, @@ -549,9 +550,10 @@ def test_in_order_verify_zero_lower_bound_does_not_fail_when_all_calls_are_consu "verify_kwargs", [ {"times": 0}, + {"atmost": 2}, {"between": (0, 2)}, ], - ids=["times_0", "between_0_2"], + ids=["times_0", "atmost_2", "between_0_2"], ) def test_in_order_zero_verify_marks_stub_as_checked_for_follow_up_global_verifications( verify_kwargs, diff --git a/tests/instancemethods_test.py b/tests/instancemethods_test.py index fb276b1..1ff128e 100644 --- a/tests/instancemethods_test.py +++ b/tests/instancemethods_test.py @@ -253,6 +253,7 @@ def testBarkOnUnusedStub(self): class TestPassIfExplicitlyVerified: @pytest.mark.parametrize('verification', [ {'times': 0}, + {'atmost': 3}, {'between': [0, 3]} ]) def testPassIfExplicitlyVerified(self, verification): @@ -303,6 +304,7 @@ def testPassIfExplicitlyVerified4(self): class TestPassIfImplicitlyVerifiedViaExpect: @pytest.mark.parametrize('verification', [ {'times': 0}, + {'atmost': 3}, {'between': [0, 3]} ]) def testPassIfImplicitlyVerified(self, verification): diff --git a/tests/pathlib_stubbing_research_test.py b/tests/pathlib_stubbing_research_test.py new file mode 100644 index 0000000..c0e4bdf --- /dev/null +++ b/tests/pathlib_stubbing_research_test.py @@ -0,0 +1,95 @@ +import pathlib +import sys + +import pytest + +from mockito import mock, when +from mockito.invocation import InvocationError + + +pytestmark = pytest.mark.usefixtures("unstub") + + +def test_pathlib_factory_can_stub_exists_per_path_value(): + when(pathlib).Path("foo").exists().thenReturn(True) + when(pathlib).Path("bar").exists().thenReturn(False) + + assert pathlib.Path("foo").exists() is True + assert pathlib.Path("bar").exists() is False + + +def test_pathlib_factory_can_stub_read_text_per_path_value(): + when(pathlib).Path("foo").read_text().thenReturn("A") + when(pathlib).Path("bar").read_text().thenReturn("B") + + assert pathlib.Path("foo").read_text() == "A" + assert pathlib.Path("bar").read_text() == "B" + + +def test_pathlib_factory_can_return_path_doubles_with_parents_property(): + foo = mock({"parents": ["root", "foo"]}, spec=pathlib.Path) + bar = mock({"parents": ["root", "bar"]}, spec=pathlib.Path) + + when(pathlib).Path("foo").thenReturn(foo) + when(pathlib).Path("bar").thenReturn(bar) + + assert pathlib.Path("foo").parents == ["root", "foo"] + assert pathlib.Path("bar").parents == ["root", "bar"] + + +@pytest.mark.xfail(reason="Not implemented", run=sys.version_info >= (3, 12)) +def test_pathlib_factory_can_stub_parents_property_per_path_via_chaining(): + when(pathlib).Path("foo").parents.thenReturn(["root", "foo"]) + when(pathlib).Path("bar").parents.thenReturn(["root", "bar"]) + + assert pathlib.Path("foo").parents == ["root", "foo"] + assert pathlib.Path("bar").parents == ["root", "bar"] + + +@pytest.mark.xfail(reason="Not implemented", run=sys.version_info >= (3, 12)) +def test_pathlib_factory_can_chain_through_parent_property_then_method(): + when(pathlib).Path("foo").parent.exists().thenReturn(True) + + assert pathlib.Path("foo").parent.exists() is True + + +def test_pathlib_factory_chain_can_distinguish_root_paths_with_operator_slash(): + when(pathlib).Path("foo").__truediv__("bar").exists().thenReturn(True) + + assert (pathlib.Path("foo") / "bar").exists() is True + + with pytest.raises(InvocationError): + (pathlib.Path("bar") / "bar").exists() + + +def test_pathlib_factory_chain_segment_mismatch_should_scream_like_os_path(): + when(pathlib).Path("foo").__truediv__("bar").exists().thenReturn(True) + + with pytest.raises(InvocationError): + (pathlib.Path("foo") / "baz").exists() + + +@pytest.mark.xfail( + reason=( + "Not implemented, not decided: decompose Path(*parts) constructor " + "stubs into __truediv__ chain matching" + ), + run=sys.version_info >= (3, 12) +) +def test_pathlib_constructor_parts_stub_can_match_slash_composition(): + when(pathlib).Path("foo", "bar", "baz").exists().thenReturn(True) + + assert (pathlib.Path("foo") / "bar" / "baz").exists() is True + + +@pytest.mark.xfail( + reason=( + "Not implemented, not decided: treat Path('a/b/c') constructor " + "stubs as segment-aware slash chains" + ), + run=sys.version_info >= (3, 12) +) +def test_pathlib_single_string_stub_can_match_slash_composition(): + when(pathlib).Path("foo/bar/baz").exists().thenReturn(True) + + assert (pathlib.Path("foo") / "bar" / "baz").exists() is True diff --git a/tests/protocol_speccing_test.py b/tests/protocol_speccing_test.py new file mode 100644 index 0000000..238baf8 --- /dev/null +++ b/tests/protocol_speccing_test.py @@ -0,0 +1,83 @@ +import asyncio +import inspect +from typing import Protocol + +import pytest + +from mockito import mock, when +from mockito.invocation import InvocationError + + +pytestmark = pytest.mark.usefixtures("unstub") + + +def run(coro): + return asyncio.run(coro) + + +class ServiceProtocol(Protocol): + async def fetch(self, path: str, timeout: int = 1) -> str: + ... + + def close(self, hard: bool = False) -> bool: + ... + + +class BaseRunnerProtocol(Protocol): + def run(self, value: int) -> int: + ... + + +class ExtendedRunnerProtocol(BaseRunnerProtocol, Protocol): + def run(self, value: int, mode: str = "safe") -> int: + ... + + +def test_protocol_spec_enforces_method_existence(): + service = mock(ServiceProtocol) + + with pytest.raises(InvocationError): + when(service).unknown() + + with pytest.raises(AttributeError): + service.unknown() + + +def test_protocol_spec_keeps_async_and_sync_methods_distinct(): + service = mock(ServiceProtocol) + + when(service).fetch("/health", timeout=1).thenReturn("ok") + when(service).close(hard=False).thenReturn(True) + + pending = service.fetch("/health", timeout=1) + assert inspect.isawaitable(pending) + assert run(pending) == "ok" + + result = service.close(hard=False) + assert not inspect.isawaitable(result) + assert result is True + + +def test_protocol_spec_enforces_method_signatures_for_stubbing_and_calls(): + service = mock(ServiceProtocol) + + with pytest.raises(TypeError): + when(service).fetch() + + with pytest.raises(TypeError): + when(service).close(True, False) + + when(service).close().thenReturn(True) + + with pytest.raises(TypeError): + service.close(True, False) + + +def test_protocol_signature_follows_override_definition_on_child_protocol(): + runner = mock(ExtendedRunnerProtocol) + + when(runner).run(1, mode="fast").thenReturn(2) + assert runner.run(1, mode="fast") == 2 + + with pytest.raises(TypeError): + when(runner).run(1, "fast", "extra") diff --git a/tests/stub_specificity_test.py b/tests/stub_specificity_test.py deleted file mode 100644 index 514ad7c..0000000 --- a/tests/stub_specificity_test.py +++ /dev/null @@ -1,88 +0,0 @@ -import pytest - -from mockito import any, args, kwargs, mock, when - - -pytestmark = pytest.mark.usefixtures("unstub") - - -class _Path: - def exists(self, location): - return f"orig:{location}" - - -def test_literal_stub_beats_ellipsis_even_if_ellipsis_added_last(): - path = mock(_Path) - - when(path).exists(".flake8").thenReturn("stubbed") - when(path).exists(...).thenCallOriginalImplementation() - - assert path.exists(".flake8") == "stubbed" - assert path.exists("README.rst") == "orig:README.rst" - - -def test_literal_stub_beats_ellipsis_even_if_literal_added_last(): - path = mock(_Path) - - when(path).exists(...).thenCallOriginalImplementation() - when(path).exists(".flake8").thenReturn("stubbed") - - assert path.exists(".flake8") == "stubbed" - assert path.exists("README.rst") == "orig:README.rst" - - -def test_typed_any_is_more_specific_than_any_and_ellipsis(): - path = mock() - - when(path).exists(...).thenReturn("ellipsis") - when(path).exists(any()).thenReturn("any") - when(path).exists(any(str)).thenReturn("typed-any") - - assert path.exists(".flake8") == "typed-any" - assert path.exists(1) == "any" - - -def test_any_and_ellipsis_have_same_specificity_and_keep_last_wins_tie_break(): - path = mock() - - when(path).exists(any()).thenReturn("any") - when(path).exists(...).thenReturn("ellipsis") - assert path.exists(1) == "ellipsis" - - other = mock() - when(other).exists(...).thenReturn("ellipsis") - when(other).exists(any()).thenReturn("any") - assert other.exists(1) == "any" - - -def test_coverage_beats_quality_when_both_match(): - subject = mock() - - when(subject).f("x", ...).thenReturn("prefix") - when(subject).f(..., retry=..., headers=...).thenReturn("kwargs-shape") - - assert subject.f("x", retry=5, headers={}) == "kwargs-shape" - - -def test_literal_beats_matchers_when_coverage_is_equal(): - subject = mock() - - when(subject).f("x", ...).thenReturn("prefix-fallback") - when(subject).f(any(str), any(int)).thenReturn("typed-exact") - - assert subject.f("x", 1) == "prefix-fallback" - - -def test_args_and_kwargs_sentinels_have_same_weight_as_ellipsis(): - subject = mock() - - when(subject).f(...).thenReturn("ellipsis") - when(subject).f(*args).thenReturn("args") - - assert subject.f(1) == "args" - - other = mock() - when(other).g(...).thenReturn("ellipsis") - when(other).g(**kwargs).thenReturn("kwargs") - - assert other.g(retry=1) == "kwargs"