Skip to content

Commit

Permalink
Add some type annotations
Browse files Browse the repository at this point in the history
  • Loading branch information
keis committed Nov 29, 2020
1 parent 2822d56 commit e4fbda7
Show file tree
Hide file tree
Showing 5 changed files with 91 additions and 46 deletions.
3 changes: 1 addition & 2 deletions .travis.yml
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
language: python
os: linux
python:
- "2.7"
- "3.6"
- "3.7"
- "3.8"
- "pypy"
- "3.9"
install:
- pip install '.[tests]'
script:
Expand Down
97 changes: 59 additions & 38 deletions matchmock/__init__.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,23 @@
'''Hamcrest matchers for mock objects'''

from typing import Any, Collection, Mapping, Sequence, Tuple

from unittest.mock import Mock, _Call
from hamcrest.core.base_matcher import BaseMatcher
from hamcrest.core.description import Description
from hamcrest.core.matcher import Matcher
from hamcrest.core.helpers.wrap_matcher import wrap_matcher
from hamcrest import (equal_to, anything, has_entries,
has_item, greater_than)
from hamcrest import (
equal_to, anything, has_entries, has_item, greater_than)

__all__ = ['called', 'not_called', 'called_once',
'called_with', 'called_once_with', 'called_n_times']

_Args = Tuple
_Kwargs = Mapping[str, Any]


def describe_call(args, kwargs, desc):
def describe_call(args: _Args, kwargs: _Kwargs, desc: Description) -> None:
desc.append_text('(')
desc.append_list('', ', ', '', args)
desc.append_text(', ')
Expand All @@ -24,47 +32,51 @@ def describe_call(args, kwargs, desc):
desc.append_text(')')


class IsCall(BaseMatcher):
class IsCall(BaseMatcher[_Call]):
'''A matcher that describes a call.
The positional arguments and keyword arguments are represented with
individual submatchers.
'''
args: Matcher[_Args]
kwargs: Matcher[_Kwargs]

def __init__(self, args, kwargs):
def __init__(self, args: Matcher[_Args], kwargs: Matcher[_Kwargs]) -> None:
self.args = args
self.kwargs = kwargs

def _matches(self, item):
def _matches(self, item: _Call) -> bool:
# in python >= 3.8 this can be item.args, item.kwargs
if len(item) == 3:
_name, args, kwargs = item
else:
args, kwargs = item
return self.args.matches(args) and self.kwargs.matches(kwargs)

def describe_mismatch(self, item, mismatch_description):
def describe_mismatch(self, item: _Call, mismatch_description: Description) -> None:
# in python >= 3.8 this can be item.args, item.kwargs
if len(item) == 3:
_name, args, kwargs = item
else:
args, kwargs = item
return (self.args.matches(args, mismatch_description) and
self.kwargs.matches(kwargs, mismatch_description))
if self.args.matches(args, mismatch_description):
self.kwargs.matches(kwargs, mismatch_description)

def describe_to(self, desc):
def describe_to(self, desc: Description) -> None:
desc.append_text('(')
desc.append_description_of(self.args)
desc.append_text(', ')
desc.append_description_of(self.kwargs)
desc.append_text(')')


class IsArgs(BaseMatcher):
def __init__(self, matchers):
self.matchers = tuple(matchers)
class IsArgs(BaseMatcher[_Args]):
matchers: Sequence[Matcher]

def __init__(self, matchers: Sequence[Matcher]) -> None:
self.matchers = matchers

def matches(self, obj, mismatch_description=None):
def matches(self, obj: _Args, mismatch_description: Description = None) -> bool:
md = mismatch_description
if len(obj) < len(self.matchers):
if md:
Expand All @@ -84,23 +96,27 @@ def matches(self, obj, mismatch_description=None):
return False
return True

def describe_to(self, desc):
def describe_to(self, desc: Description) -> None:
desc.append_list('', ', ', '', self.matchers)


class IsKwargs(BaseMatcher):
def __init__(self, matchers):
self._matcher = has_entries(matchers)
self._expected_keys = set(matchers.keys())
class IsKwargs(BaseMatcher[_Kwargs]):
_value_matchers: Collection[Tuple[str, Matcher]]
_matcher: Matcher[_Kwargs]

def matches(self, obj, mismatch_description=None):
def __init__(self, value_matchers: Mapping[str, Matcher]) -> None:
self._value_matchers = value_matchers.items()
self._matcher = has_entries(value_matchers)

def matches(self, obj: _Kwargs, mismatch_description: Description = None) -> bool:
md = mismatch_description
ok = self._matcher.matches(obj, md)
if not ok:
return False

expected_keys = set(k for k, _ in self._value_matchers)
actual_keys = set(obj.keys())
extra_keys = actual_keys - self._expected_keys
extra_keys = actual_keys - expected_keys
if len(extra_keys) > 0:
if md:
md.append_text('extra keyword argument(s) ') \
Expand All @@ -110,9 +126,9 @@ def matches(self, obj, mismatch_description=None):

return True

def describe_to(self, desc):
def describe_to(self, desc: Description) -> None:
first = True
for key, value in self._matcher.value_matchers:
for key, value in self._value_matchers:
if not first:
desc.append_text(', ')
desc.append_text(key) \
Expand All @@ -121,35 +137,40 @@ def describe_to(self, desc):
first = False


def match_args(args):
def match_args(args: Sequence[Any]) -> Matcher[_Args]:
'''Create a matcher for positional arguments'''

return IsArgs(wrap_matcher(m) for m in args)
return IsArgs(tuple(wrap_matcher(m) for m in args))


def match_kwargs(kwargs):
def match_kwargs(kwargs: Mapping[str, Any]) -> Matcher[_Kwargs]:
'''Create a matcher for keyword arguments'''

return IsKwargs({k: wrap_matcher(v) for k, v in kwargs.items()})


class IsCalled(BaseMatcher):
class IsCalled(BaseMatcher[Mock]):
'''Matches a mock and asserts the number of calls and parameters'''

def __init__(self, call, count=None):
call: Matcher[_Call]
count: Matcher[int]
has_call: Matcher[Sequence[_Call]]

def __init__(self, call: Matcher[_Call], count: Matcher[int]) -> None:
self.call = call
self.count = count
self.has_call = has_item(self.call)

def _matches(self, item):
def _matches(self, item: Mock) -> bool:
if not self.count.matches(item.call_count):
return False

if len(item.call_args_list) == 0:
return True

return bool(has_item(self.call).matches(item.call_args_list))
return self.has_call.matches(item.call_args_list)

def describe_mismatch(self, item, mismatch_description):
def describe_mismatch(self, item: Mock, mismatch_description: Description) -> None:
if not self.count.matches(item.call_count):
mismatch_description.append_text(
'was called %s times' % item.call_count)
Expand All @@ -170,45 +191,45 @@ def describe_mismatch(self, item, mismatch_description):
if i != item.call_count - 1:
mismatch_description.append_text(', ')

def describe_to(self, desc):
def describe_to(self, desc: Description) -> None:
desc.append_text('Mock called ')
self.count.describe_to(desc)
desc.append_text(' times with ')
self.call.describe_to(desc)


def called():
def called() -> Matcher[Mock]:
'''Match mock that was called one or more times'''

return IsCalled(anything(), count=greater_than(0))


def called_n_times(n):
def called_n_times(n) -> Matcher[Mock]:
'''Match mock that was called exactly a given number of times'''

return IsCalled(anything(), count=equal_to(n))


def not_called():
def not_called() -> Matcher[Mock]:
'''Match mock that was never called'''

return IsCalled(anything(), count=equal_to(0))


def called_once():
def called_once() -> Matcher[Mock]:
'''Match mock that was called once regardless of arguments'''

return IsCalled(anything(), count=equal_to(1))


def called_with(*args, **kwargs):
def called_with(*args, **kwargs) -> Matcher[Mock]:
'''Match mock has at least one call with the specified arguments'''

return IsCalled(IsCall(match_args(args), match_kwargs(kwargs)),
count=greater_than(0))


def called_once_with(*args, **kwargs):
def called_once_with(*args, **kwargs) -> Matcher[Mock]:
'''Match mock that was called once and with the specified arguments'''

return IsCalled(IsCall(match_args(args), match_kwargs(kwargs)),
Expand Down
19 changes: 15 additions & 4 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
name = matchmock
author = David Keijser
author_email = keijser@gmail.com
version = 1.0.0
version = 2.0.0
description = Hamcrest matchers for mock objects.
long_description = file: README.md
long_description_content_type = text/markdown
Expand All @@ -13,10 +13,21 @@ packages = matchmock

[options.extras_require]
tests =
pytest>=4.6
pytest-flakes<2
pytest>=6
pytest-flakes>2
pytest-cov
PyHamcrest
pep8
pycodestyle
mock
coveralls
mypy

[mypy]
warn_unreachable = True
warn_unused_ignores = True
warn_redundant_casts = True
warn_unused_configs = True
warn_return_any = True
disallow_any_unimported = True
check_untyped_defs = True
incremental = False
16 changes: 15 additions & 1 deletion tests/test_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
def test_matching_call():
m = IsCall(match_args(('foo',)), match_kwargs({}))
value = call('foo')
print(value)
assert_that(m.matches(value), equal_to(True))


Expand All @@ -24,6 +23,13 @@ def test_describe_self():
assert_that(str(s), equal_to("('foo', )"))


def test_describe_self_with_kwargs():
m = IsCall(match_args(('foo',)), match_kwargs({'key': 'value'}))
s = StringDescription()
m.describe_to(s)
assert_that(str(s), equal_to("('foo', key='value')"))


def test_describe_mismatch():
m = IsCall(match_args(('foo',)), match_kwargs({}))
value = call('bar')
Expand All @@ -32,6 +38,14 @@ def test_describe_mismatch():
assert_that(str(s), equal_to("argument 0: was 'bar'"))


def test_describe_mismatch_kwargs():
m = IsCall(match_args(('foo',)), match_kwargs({'key': 'value'}))
value = call('foo', key='VALUE')
s = StringDescription()
m.describe_mismatch(value, s)
assert_that(str(s), equal_to("value for 'key' was 'VALUE'"))


def test_args_mismatch_complex():
m = IsCall(match_args([has_entries(name='foo')]), match_kwargs({}))
value = call({'name': 'baz'})
Expand Down
2 changes: 1 addition & 1 deletion tox.ini
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[tox]
envlist = py27,py37
envlist = py37

[testenv]
deps = .[tests]
Expand Down

0 comments on commit e4fbda7

Please sign in to comment.