Skip to content
51 changes: 51 additions & 0 deletions docs/recipes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -145,3 +145,54 @@ constructor dict values are still set on the class for compatibility.

Btw, `copy` will *just work* for strict mocks and does not raise an error when not configured/expected. This is
just not implemented and considered not-worth-the-effort.


Shared setUp stubs with tearDown safety checks
----------------------------------------------

Sometimes you have one "big" fixture / ``setUp`` that configures reusable stubs.

Only some tests actually need all of them, but you also want to call
``verifyStubbedInvocationsAreUsed()`` or ``ensureNoUnverifiedInteractions()``
unconditionally as your safety net on ``tearDown``.

Yeah, I hate that but we need to be realistic. Use ``between=(0,)`` like so::

class TestService:
def setUp(self):
self.client = mock()
when(self.client).fetch("/warmup").thenReturn({"ok": True})
... # more

def tearDown(self):
verify(self.client, between=(0,)).fetch(...) # mark as ok!
verifyStubbedInvocationsAreUsed()
ensureNoUnverifiedInteractions()


Speccing from ``typing.Protocol``
---------------------------------

If your production code uses ``typing.Protocol`` interfaces, you can use them
as ``mock(spec=...)`` input directly::

from typing import Protocol
from mockito import mock, when

class Service(Protocol):
async def fetch(self, path: str) -> str:
...

def close(self) -> bool:
...

service = mock(Service)
when(service).fetch('/health').thenReturn('ok')
when(service).close().thenReturn(True)

assert await service.fetch('/health') == 'ok' # async stays async
assert service.close() is True # sync stays sync

Such mocks are strict by default, so unknown methods and invalid call signatures
still fail early.

71 changes: 11 additions & 60 deletions mockito/invocation.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
import inspect
import operator
from collections import deque
from functools import cached_property
from typing import TYPE_CHECKING, Union

from . import matchers, signature
Expand Down Expand Up @@ -137,11 +136,12 @@ def __call__(self, *params: Any, **named_params: Any) -> Any | None:
self._remember_params(params_without_first_arg, named_params)
self.mock.remember(self)

matching_invocation = self._find_best_matching_stubbed_invocation()
if matching_invocation is not None:
matching_invocation.should_answer(self)
matching_invocation.capture_arguments(self)
return matching_invocation.answer_first(*params, **named_params)
for matching_invocation in self.mock.stubbed_invocations:
if matching_invocation.matches(self):
matching_invocation.should_answer(self)
matching_invocation.capture_arguments(self)
return matching_invocation.answer_first(
*params, **named_params)

if self.strict:
stubbed_invocations = [
Expand Down Expand Up @@ -170,21 +170,6 @@ def __call__(self, *params: Any, **named_params: Any) -> Any | None:

return None

def _find_best_matching_stubbed_invocation(self) -> StubbedInvocation | None:
candidates = [
candidate
for candidate in self.mock.stubbed_invocations
if candidate.matches(self)
]

if not candidates:
return None

if len(candidates) == 1:
return candidates[0]

return max(candidates, key=lambda candidate: candidate.specificity_score)


class RememberedPropertyAccess(RememberedInvocation):
def ensure_mocked_object_has_method(self, method_name):
Expand Down Expand Up @@ -409,6 +394,9 @@ def verification_has_lower_bound_of_zero(
):
return True

if isinstance(verification, verificationModule.AtMost):
return True

if (
isinstance(verification, verificationModule.Between)
and verification.wanted_from == 0
Expand Down Expand Up @@ -477,10 +465,7 @@ def __init__(
if strict is not None:
self.strict = strict

self.refers_coroutine = (
is_coroutine_method(mock.peek_original_method(method_name))
or mock.is_marked_as_coroutine(method_name)
)
self.refers_coroutine = mock.method_expects_awaitable(method_name)
self.discard_first_arg = mock.will_have_self_or_cls(method_name)
default_answer = (
return_awaitable(None) if self.refers_coroutine else return_(None)
Expand Down Expand Up @@ -523,33 +508,6 @@ def __call__(self, *params: Any, **named_params: Any) -> AnswerSelector:
self.mock.finish_stubbing(self)
return AnswerSelector(self, self.refers_coroutine, self.discard_first_arg)

@cached_property
def specificity_score(self) -> tuple[int, int]:
quality = 0

for value in self.params:
if value is not matchers.ARGS_SENTINEL:
quality += self._specificity_score(value)

for key, value in self.named_params.items():
if key is not matchers.KWARGS_SENTINEL:
quality += self._specificity_score(value)

coverage = len(self.params) + len(self.named_params)
return coverage, quality

def _specificity_score(self, value: object) -> int:
if value is Ellipsis:
return 0

if isinstance(value, matchers.Any) and value.wanted_type is None:
return 0

if isinstance(value, matchers.Matcher):
return 1

return 3

def forget_self(self) -> None:
if self in self.mock.stubbed_invocations:
self.mock.forget_stubbed_invocation(self)
Expand Down Expand Up @@ -719,7 +677,7 @@ def __call__(self, *params, **named_params):
def create_chain_mock() -> tuple[object, Mock]:
from .mocking import mock

chain_root = mock()
chain_root = mock(strict=True)
theMock = mock_registry.mock_for(chain_root)
assert theMock is not None, "Missing chain mock registry entry"
return chain_root, theMock
Expand All @@ -737,13 +695,6 @@ async def answer(*args, **kwargs) -> T:
return answer


def is_coroutine_method(method: Any) -> bool:
if isinstance(method, (staticmethod, classmethod)):
method = method.__func__

return inspect.iscoroutinefunction(method)


def raise_(exception: Exception | type[Exception]) -> Callable[..., NoReturn]:
def answer(*args, **kwargs) -> NoReturn:
raise exception
Expand Down
21 changes: 19 additions & 2 deletions mockito/mocking.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,7 +374,7 @@ def new_mocked_method(*args, **kwargs):
pass

if (
_is_coroutine_method(original_method)
self.method_expects_awaitable(method_name, original_method)
and SUPPORTS_MARKCOROUTINEFUNCTION
):
new_mocked_method = inspect.markcoroutinefunction(new_mocked_method)
Expand Down Expand Up @@ -452,7 +452,7 @@ def forget_stubbed_invocation(
if self.stubbed_invocations:
return

mock_registry.unstub_mock(self)
mock_registry.unstub(self.mocked_obj)

def restore_method(self, method_name: str, original_method: object) -> None:
if original_method is _MISSING_ATTRIBUTE:
Expand All @@ -477,6 +477,23 @@ def mark_as_coroutine(self, method_name: str) -> None:
def is_marked_as_coroutine(self, method_name: str) -> bool:
return method_name in self._methods_marked_as_coroutine

def method_expects_awaitable(
self,
method_name: str,
original_method: object | None = None,
) -> bool:
if original_method is None:
original_method = self.peek_original_method(method_name)

return (
_is_coroutine_method(original_method)
or self.is_marked_as_coroutine(method_name)
or (
self.spec is None
and method_name in _ASYNC_BY_PROTOCOL_METHODS
)
)

def has_method(self, method_name: str) -> bool:
if self.spec is None:
return True
Expand Down
65 changes: 65 additions & 0 deletions tests/async_protocol_methods_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import asyncio
import inspect

import pytest

from mockito import mock, when


pytestmark = pytest.mark.usefixtures("unstub")


def run(coro):
return asyncio.run(coro)


async def _use_async_resource(resource):
async with resource as entered:
return entered


async def _collect_async_iter(values):
seen = []
async for value in values:
seen.append(value)
return seen


def test_when_thenReturn_on_ad_hoc_mock_aenter_and_aexit_are_awaitable():
resource = mock()
when(resource).__aenter__().thenReturn(resource)
when(resource).__aexit__(..., ..., ...).thenReturn(False)

pending_enter = resource.__aenter__()
assert inspect.isawaitable(pending_enter)
assert run(pending_enter) is resource

pending_exit = resource.__aexit__(None, None, None)
assert inspect.isawaitable(pending_exit)
assert run(pending_exit) is False


def test_when_thenReturn_on_ad_hoc_mock_supports_async_with():
resource = mock()
entered = object()
when(resource).__aenter__().thenReturn(entered)
when(resource).__aexit__(..., ..., ...).thenReturn(False)

assert run(_use_async_resource(resource)) is entered


def test_when_thenReturn_on_ad_hoc_mock_anext_is_awaitable():
values = mock()
when(values).__anext__().thenReturn(1)

pending = values.__anext__()
assert inspect.isawaitable(pending)
assert run(pending) == 1


def test_when_thenReturn_on_ad_hoc_mock_supports_async_for():
values = mock()
when(values).__aiter__().thenReturn(values)
when(values).__anext__().thenReturn(1).thenReturn(2).thenRaise(StopAsyncIteration)

assert run(_collect_async_iter(values)) == [1, 2]
9 changes: 6 additions & 3 deletions tests/call_original_implem_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@
import sys

import pytest
from mockito import mock, when, verify, ArgumentError
from mockito import mock, when, verify
from mockito.invocation import AnswerError
from mockito.verification import VerificationError

from . import module
from .test_base import TestBase
Expand Down Expand Up @@ -110,12 +111,14 @@ def testDumbMockHasNoOriginalImplementations(self):

def testDumbMockFailedThenCallOriginalImplementationDoesNotLeakStub(self):
dog = mock()
with pytest.raises(VerificationError):
verify(dog).bark()

with pytest.raises(AnswerError):
when(dog).bark().thenCallOriginalImplementation()

with pytest.raises(ArgumentError):
verify(dog).bark(Ellipsis)
with pytest.raises(VerificationError):
verify(dog).bark()

def testSpeccedMockHasOriginalImplementations(self):
dog = mock({"huge": True}, spec=Dog)
Expand Down
Loading
Loading