Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
151 changes: 143 additions & 8 deletions mockito/matchers.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
"""

from abc import ABC, abstractmethod
import functools
import re
builtin_any = any

Expand Down Expand Up @@ -137,15 +138,46 @@ def matches(self, arg):
return True

def __repr__(self):
return "<Any: %s>" % self.wanted_type
return "<Any: %s>" % _any_wanted_type_label(self.wanted_type)


def _any_wanted_type_label(wanted_type):
if isinstance(wanted_type, type):
return _type_label(wanted_type)

if (
isinstance(wanted_type, tuple)
and all(isinstance(t, type) for t in wanted_type)
):
items = [_type_label(t) for t in wanted_type]
if len(items) == 1:
return '(%s,)' % items[0]
return '(%s)' % ', '.join(items)

return _safe_repr(wanted_type)


def _type_label(type_):
module = _safe_getattr(type_, '__module__')
qualname = _safe_getattr(type_, '__qualname__') or _safe_getattr(type_, '__name__')
if qualname is None:
return _safe_repr(type_)

if module is None or module == 'builtins':
return qualname

return '%s.%s' % (module, qualname)


class ValueMatcher(Matcher):
def __init__(self, value):
self.value = value

def __repr__(self):
return "<%s: %s>" % (self.__class__.__name__, self.value)
return "<%s: %s>" % (
self.__class__.__name__,
_safe_repr(self.value),
)


class Eq(ValueMatcher):
Expand Down Expand Up @@ -223,7 +255,93 @@ def matches(self, arg):
return self.predicate(arg)

def __repr__(self):
return "<ArgThat>"
return "<ArgThat: %s>" % _arg_that_predicate_label(self.predicate)


def _arg_that_predicate_label(predicate):
try:
return _arg_that_predicate_label_unchecked(predicate)
except Exception:
predicate_class = _safe_getattr(
_safe_getattr(predicate, '__class__'),
'__name__',
)
if predicate_class is None:
return 'callable'

return 'callable %s' % predicate_class


def _arg_that_predicate_label_unchecked(predicate):
if isinstance(predicate, functools.partial):
return _arg_that_partial_label(predicate)

function_line = _line_of_callable(predicate)
function_name = _safe_getattr(predicate, '__name__')
if function_name is not None:
if function_name == '<lambda>':
return _label_with_line('lambda', function_line)
return _label_with_line('def %s' % function_name, function_line)

predicate_class = _safe_getattr(
_safe_getattr(predicate, '__class__'),
'__name__',
)
if predicate_class is None:
predicate_class = 'object'

call = _safe_getattr(predicate, '__call__')
call_line = _line_of_callable(call)
return _label_with_line(
'callable %s.__call__' % predicate_class,
call_line,
)


def _arg_that_partial_label(predicate):
partial_func = _safe_getattr(predicate, 'func')
partial_name = _safe_getattr(partial_func, '__name__')

if partial_name is not None:
return 'partial %s' % partial_name

return 'partial'


def _line_of_callable(value):
if value is None:
return None

func = _safe_getattr(value, '__func__', value)
code = _safe_getattr(func, '__code__')
if code is None:
return None

return _safe_getattr(code, 'co_firstlineno')


def _safe_getattr(value, name, default=None):
try:
return getattr(value, name)
except Exception:
return default


def _safe_repr(value):
try:
return repr(value)
except Exception:
try:
return object.__repr__(value)
except Exception:
return '<unrepresentable>'


def _label_with_line(label, line_number):
if line_number is None:
return label

return '%s at line %s' % (label, line_number)


class Contains(Matcher):
Expand All @@ -236,24 +354,41 @@ def matches(self, arg):
return self.sub and len(self.sub) > 0 and arg.find(self.sub) > -1

def __repr__(self):
return "<Contains: '%s'>" % self.sub
return "<Contains: %s>" % _safe_repr(self.sub)


class Matches(Matcher):
def __init__(self, regex, flags=0):
self.regex = re.compile(regex, flags)
self.flags = _explicit_regex_flags(regex, flags)

def matches(self, arg):
if not isinstance(arg, str):
return
return self.regex.match(arg) is not None

def __repr__(self):
if self.regex.flags:
return "<Matches: %s flags=%d>" % (self.regex.pattern,
self.regex.flags)
if self.flags:
return "<Matches: %r flags=%d>" % (self.regex.pattern, self.flags)
else:
return "<Matches: %s>" % self.regex.pattern
return "<Matches: %r>" % self.regex.pattern


def _explicit_regex_flags(regex, flags):
if flags:
return flags

compiled_flags = _safe_getattr(regex, 'flags')
pattern = _safe_getattr(regex, 'pattern')
if compiled_flags is None or pattern is None:
return 0

try:
baseline_flags = re.compile(pattern).flags
except Exception:
return compiled_flags

return compiled_flags & ~baseline_flags


class ArgumentCaptor(Matcher, Capturing):
Expand Down
169 changes: 169 additions & 0 deletions tests/matcher_repr_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
from functools import partial
import re

import numpy as np

from mockito import and_, any as any_, arg_that, contains, eq, gt, matches, not_, or_


def test_value_matchers_use_repr_for_string_values():
assert repr(eq("foo")) == "<Eq: 'foo'>"


def test_composed_matchers_include_quoted_nested_values():
assert repr(not_(eq("foo"))) == "<Not: <Eq: 'foo'>>"
assert repr(and_(eq("foo"), gt(1))) == "<And: [<Eq: 'foo'>, <Gt: 1>]>"
assert repr(or_(eq("foo"), gt(1))) == "<Or: [<Eq: 'foo'>, <Gt: 1>]>"


def test_any_repr_uses_pretty_names_for_types():
assert repr(any_(int)) == "<Any: int>"
assert repr(any_((int, str))) == "<Any: (int, str)>"


def test_any_repr_quotes_non_type_values():
assert repr(any_("foo")) == "<Any: 'foo'>"


def test_any_repr_handles_types_with_broken_introspection():
class EvilMeta(type):
def __getattribute__(cls, name):
if name in {'__module__', '__qualname__', '__name__'}:
raise RuntimeError('boom')
return super().__getattribute__(name)

class Evil(metaclass=EvilMeta):
pass

matcher_repr = repr(any_(Evil))
assert matcher_repr.startswith("<Any: <class '")
assert "Evil" in matcher_repr


def test_any_repr_handles_values_with_broken_repr():
class BrokenRepr:
def __repr__(self):
raise RuntimeError('boom')

matcher_repr = repr(any_(BrokenRepr()))
assert matcher_repr.startswith('<Any: <')
assert 'BrokenRepr object' in matcher_repr


def test_value_matcher_repr_handles_values_with_broken_repr():
class BrokenRepr:
def __repr__(self):
raise RuntimeError('boom')

matcher_repr = repr(eq(BrokenRepr()))
assert matcher_repr.startswith('<Eq: <')
assert 'BrokenRepr object' in matcher_repr


def test_contains_repr_handles_values_with_broken_repr():
class BrokenRepr:
def __repr__(self):
raise RuntimeError('boom')

matcher_repr = repr(contains(BrokenRepr()))
assert matcher_repr.startswith('<Contains: <')
assert 'BrokenRepr object' in matcher_repr


def test_contains_repr_uses_safe_quoted_substring():
assert repr(contains("a'b")) == "<Contains: \"a'b\">"


def test_matches_repr_shows_only_explicit_flags():
assert repr(matches("f..")) == "<Matches: 'f..'>"
assert repr(matches("f..", re.IGNORECASE)) == (
f"<Matches: 'f..' flags={int(re.IGNORECASE)}>"
)


def test_matches_repr_shows_flags_for_compiled_patterns():
compiled = re.compile('f..', re.IGNORECASE)

assert repr(matches(compiled)) == (
f"<Matches: 'f..' flags={int(re.IGNORECASE)}>"
)


def test_arg_that_repr_includes_named_function_name():
# Predicate display name: "def is_positive"
def is_positive(value):
return value > 0

matcher = arg_that(is_positive)

assert repr(matcher) == (
f"<ArgThat: def is_positive at line {is_positive.__code__.co_firstlineno}>"
)


def test_arg_that_repr_includes_lambda_name():
# Predicate display name: "lambda"
predicate = lambda value: value > 0
matcher = arg_that(predicate)

assert repr(matcher) == (
f"<ArgThat: lambda at line {predicate.__code__.co_firstlineno}>"
)


def test_arg_that_repr_for_callable_instance_includes_class_name():
# Predicate display name: "callable IsPositive.__call__"
class IsPositive:
def __call__(self, value):
return value > 0

predicate = IsPositive()
matcher = arg_that(predicate)

assert repr(matcher) == (
"<ArgThat: callable IsPositive.__call__ at line "
f"{predicate.__call__.__func__.__code__.co_firstlineno}>"
)


def test_arg_that_repr_for_builtin_callable_has_no_line_number():
matcher = arg_that(len)

assert repr(matcher) == "<ArgThat: def len>"


def test_arg_that_repr_for_partial_uses_underlying_function_name():
predicate = partial(pow, exp=2)
matcher = arg_that(predicate)

assert repr(matcher) == "<ArgThat: partial pow>"


def test_arg_that_repr_for_numpy_ufunc_uses_function_name_without_line():
matcher = arg_that(np.isfinite)

assert repr(matcher) == "<ArgThat: def isfinite>"


def test_arg_that_repr_for_partial_numpy_function_uses_wrapped_name():
predicate = partial(np.allclose, b=0.0)
matcher = arg_that(predicate)

assert repr(matcher) == "<ArgThat: partial allclose>"


def test_arg_that_repr_handles_callables_with_broken_name_introspection():
class BrokenNameCallable:
def __getattribute__(self, name):
if name == '__name__':
raise RuntimeError("boom")
return super().__getattribute__(name)

def __call__(self, value):
return value > 0

matcher = arg_that(BrokenNameCallable())

matcher_repr = repr(matcher)
assert matcher_repr.startswith("<ArgThat: callable BrokenNameCallable")
assert "__name__" not in matcher_repr
Loading