Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions CHANGES.txt
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,16 @@ Release 2.0.0
Note that async introspection metadata (e.g. `inspect.iscoroutinefunction`)
for stub wrappers is currently implemented only on Python 3.12+.

- Expanded `mock({...})` constructor shorthands:
- `"async <name>"` marks methods as async and supports either `...` or a function value.
- `{"__enter__": ...}` / `{"__aenter__": ...}` now install default matching
`__exit__` / `__aexit__` handlers when not provided.
- `{"__iter__": [..]}` and `{"__aiter__": [..]}` now normalize values into
proper iterator / async-iterator behavior.
- In constructor dict shorthands, zero-argument functions now widen to
accept arbitrary call arguments (e.g. `lambda: "ok"` behaves like
`lambda *a, **kw: "ok"`).

- Added first-class property/descriptor stubbing support, including class-level property
stubbing via `when(F).p.thenReturn(...)` and `thenCallOriginalImplementation()` support for
property stubs (including chained answers like
Expand Down
1 change: 1 addition & 0 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ Read
:maxdepth: 1

walk-through
mock-shorthands
recipes
the-functions
any-and-ellipses
Expand Down
109 changes: 109 additions & 0 deletions docs/mock-shorthands.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
mock() configuration and shorthands
===================================

If you really dig mock driven development, you use dumb mock()s and don't patch
real objects and modules all the time.

The standard setup works as expected::

cat = mock()
when(cat).meow().thenReturn("Miau!")

# Use it
cat.meow()

To get you up to speed, we have several shortcuts:

cat = mock({"age": 12})
cat.age # => 12

You can also define functions::

cat = mock({"meow": lambda: "Miau!"})
cat.meow() # => "Miau!"

Note that such a lambda without any arguments defined, accepts all possible arguments
and always returns the same answer. It is thus the same as saying

when(cat).meow(...).thenReturn("Miau!") # note the Ellipsis

If you want to define async functions, use

response = mock({"async text": lambda: "Hi"})
session = mock({"async get": lambda: response})

To build up a complete `aiohttp` example,

import aiohttp
from mockito import when, unstub

async def fetch_text(location, session):
async with session.get(location, raise_for_status=True) as resp:
return await resp.text()

you also need to define the context/with handlers:

resp = mock({
"__aenter__": ...,
"async text": lambda: "Fake!"
})

session = mock({
# since __aenter__ is async by protocol "async __aenter__" is not needed (but allowed)
"__aenter__": ..., # <== ... denotes to install a standard return value of self
# it always installs a standard __aexit__ returning None or False
# if not provided by the user

"async get": lambda: resp, # <== install async method with *args, **kwargs
# equivalent to when(session).get(...).thenReturn(resp)
})

.. note::

``__aenter__``, ``__aexit__``, ``__anext__`` are async by definition,
use either ``mock({"__aenter__": ...})`` or
``mock({"async __aenter__": ...})``.

For ``__aiter__``, we have a special shortcode:

numbers = mock({"__aiter__": [1, 2, 3]}) # install a function that wraps these values
# in an async iterator for easy use

async for number in numbers:
...


You can also just mark a function async::

session = mock({
"__aenter__": ...,
"async get": ..., # <== record the intent that this is an async method
# and install a `return None` handler as well
# You can override that handler clearly later,
# see right below
})
when(session).get(..., raise_for_status=True).thenReturn(resp) # async! as marked before

# This session can be used as return value for the global constructor, e.g.
when(aiohttp).ClientSession().thenReturn(session)

# and then passed around
body = await fetch_text('https://example.com', session)
assert body == 'Fake!'

We have the same shortcuts available for `__enter__` and `__iter__`.

mock({"__enter__": ...}) # installs a standard enter that return self
# and a standard exit handler returning None if nothing else is
# provided by the user.

mock({"__iter__": [4, 5, 6]}) # install handler and wrap in an iterator

Remember or note that when you rather use specced mock()s you're more or less limited by what the spec
implements. If you for example use `aiohttp.ClientSession` as the blueprint for your mock,
we already know that `get` is async and you don't need to tell mockito so.

mock({
"get": lambda: response # Look up if ClientSession defines "async def get"
# and follow suit.
} , spec=ClientSession)
5 changes: 3 additions & 2 deletions mockito/invocation.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,8 +432,9 @@ def __init__(
if strict is not None:
self.strict = strict

self.refers_coroutine = is_coroutine_method(
mock.peek_original_method(method_name)
self.refers_coroutine = (
is_coroutine_method(mock.peek_original_method(method_name))
or mock.is_marked_as_coroutine(method_name)
)
self.discard_first_arg = mock.will_have_self_or_cls(method_name)
default_answer = (
Expand Down
177 changes: 172 additions & 5 deletions mockito/mocking.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import functools
from collections import deque
from contextlib import contextmanager
from typing import AsyncIterator, Callable, Iterable, Iterator, cast

from . import invocation, signature, utils
from .mock_registry import mock_registry
Expand All @@ -40,6 +41,9 @@

_MISSING_ATTRIBUTE = object()

_CONFIG_ASYNC_PREFIX = "async "
_ASYNC_BY_PROTOCOL_METHODS = {"__aenter__", "__aexit__", "__anext__"}


class _Dummy:
# We spell out `__call__` here for convenience. All other magic methods
Expand Down Expand Up @@ -223,6 +227,7 @@ def __init__(
list[tuple[str, object | None, object]] = []

self._observers: list = []
self._methods_marked_as_coroutine: set[str] = set()

def attach(self, observer) -> None:
if observer not in self._observers:
Expand Down Expand Up @@ -428,9 +433,16 @@ def unstub(self) -> None:
self.restore_method(method_name, original_method)
self.stubbed_invocations = deque()
self.invocations = []
self._methods_marked_as_coroutine = set()

# SPECCING

def mark_as_coroutine(self, method_name: str) -> None:
self._methods_marked_as_coroutine.add(method_name)

def is_marked_as_coroutine(self, method_name: str) -> bool:
return method_name in self._methods_marked_as_coroutine

def has_method(self, method_name: str) -> bool:
if self.spec is None:
return True
Expand Down Expand Up @@ -609,11 +621,166 @@ def __repr__(self):
obj = Dummy()
theMock = Mock(Dummy, strict=strict, spec=spec)

for n, v in config.items():
if inspect.isfunction(v):
invocation.StubbedInvocation(theMock, n)(Ellipsis).thenAnswer(v)
else:
setattr(Dummy, n, v)
normalized_names = {
_normalize_config_key(raw_name)[0]
for raw_name in config
}

for raw_name, value in config.items():
_configure_mock_from_shorthand(
theMock,
Dummy,
obj,
raw_name,
value,
normalized_names,
)

mock_registry.register(obj, theMock)
return obj


def _configure_mock_from_shorthand(
theMock: Mock,
Dummy: type,
obj: object,
raw_name: str,
value: object,
configured_names: set[str],
) -> None:
method_name, marked_async = _normalize_config_key(raw_name)
should_be_async = marked_async or method_name in _ASYNC_BY_PROTOCOL_METHODS

if method_name in {"__enter__", "__aenter__"} and value is Ellipsis:
_stub_from_shorthand(
theMock,
method_name,
return_value=obj,
force_async=(method_name == "__aenter__"),
)

companion_exit = "__aexit__" if method_name == "__aenter__" else "__exit__"
if companion_exit not in configured_names:
_stub_from_shorthand(
theMock,
companion_exit,
return_value=False,
force_async=(companion_exit == "__aexit__"),
)
return

if method_name == "__iter__":
iter_answer = _normalize_iter_answer(value)
_stub_from_shorthand(theMock, method_name, answer=iter_answer)
return

if method_name == "__aiter__":
aiter_answer = _normalize_aiter_answer(value)
_stub_from_shorthand(theMock, method_name, answer=aiter_answer)
return

if inspect.isfunction(value):
function_answer = _widen_zero_arg_callable(value)
_stub_from_shorthand(
theMock,
method_name,
answer=function_answer,
force_async=should_be_async,
)
return

if should_be_async:
if value is Ellipsis:
_stub_from_shorthand(theMock, method_name, force_async=True)
return

raise TypeError(
"Async shorthand '%s' expects a function value or Ellipsis. "
"Use `lambda: value` for fixed async return values."
% raw_name
)

setattr(Dummy, method_name, value)


def _normalize_config_key(raw_name: str) -> tuple[str, bool]:
if raw_name.startswith(_CONFIG_ASYNC_PREFIX):
return raw_name[len(_CONFIG_ASYNC_PREFIX):], True
return raw_name, False


def _stub_from_shorthand(
theMock: Mock,
method_name: str,
*,
answer: object = OMITTED,
return_value: object = OMITTED,
force_async: bool = False,
) -> None:
if force_async:
theMock.mark_as_coroutine(method_name)

stubbed = invocation.StubbedInvocation(theMock, method_name)(Ellipsis)

if answer is not OMITTED:
stubbed.thenAnswer(answer) # type: ignore[arg-type]
elif return_value is not OMITTED:
stubbed.thenReturn(return_value)


def _widen_zero_arg_callable(function: object):
if not inspect.isfunction(function):
return function

try:
params = inspect.signature(function).parameters
except Exception:
return function

if params:
return function

def widened(*args, **kwargs):
return function()

widened.__name__ = function.__name__
widened.__doc__ = function.__doc__
return widened


def _normalize_iter_answer(value) -> Callable[..., Iterator[object]]:
def answer(*args, **kwargs) -> Iterator[object]:
result = value(*args, **kwargs) if callable(value) else value
return iter(cast(Iterable[object], result))

return answer


def _normalize_aiter_answer(value) -> Callable[..., AsyncIterator[object]]:
def answer(*args, **kwargs) -> AsyncIterator[object]:
result = value(*args, **kwargs) if callable(value) else value
return _normalize_aiter_result(result)

return answer


def _normalize_aiter_result(value) -> AsyncIterator[object]:
if hasattr(value, "__anext__"):
return cast(AsyncIterator[object], value)

aiter = getattr(value, "__aiter__", None)
if callable(aiter):
candidate = aiter()
if hasattr(candidate, "__anext__"):
return cast(AsyncIterator[object], candidate)
raise TypeError(
"__aiter__() must return an async iterator implementing __anext__"
)

iterator = iter(cast(Iterable[object], value))

async def generator() -> AsyncIterator[object]:
for item in iterator:
yield item

return generator()
Loading
Loading