Skip to content

Commit

Permalink
Fix StrictMock validation with inheritance (#283)
Browse files Browse the repository at this point in the history
Summary:
Closes #282 (tests to follow).

Pull Request resolved: #283

Reviewed By: deathowl

Differential Revision: D26489435

Pulled By: fornellas

fbshipit-source-id: 77602b3899237038c331286a413a6f5f83ffb620
  • Loading branch information
fornellas authored and facebook-github-bot committed Feb 19, 2021
1 parent 65dafe9 commit 276dccc
Show file tree
Hide file tree
Showing 2 changed files with 122 additions and 46 deletions.
43 changes: 36 additions & 7 deletions tests/strict_mock_testslide.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
NonExistentAttribute,
StrictMock,
UndefinedAttribute,
UnsupportedMagic,
)


Expand Down Expand Up @@ -107,18 +108,33 @@ def static_method_wrapped(extra, message):
def class_method_wrapped(cls, extra, message):
return "class_method: {}".format(message)

def __eq__(self, other):
return id(self) == id(other)

class TemplateStrictMock(StrictMock):
def __hash__(self):
return id(self)


class TemplateBaseStrictMock(StrictMock):
def __init__(self):
super().__init__(template=Template)

def instance_method(self, message):
return "mock"
@staticmethod
def static_method(message):
return 101 # Wrong type

def __len__(self):
return 100


class TemplateStrictMock(TemplateBaseStrictMock):
def __instance_method_helper(self):
return "mock"

def instance_method(self, message):
return self.__instance_method_helper()


class ContextManagerTemplate(Template):
def __enter__(self):
pass
Expand Down Expand Up @@ -275,6 +291,22 @@ def overriding_regular_methods_work(self):
def overriding_magic_methods_work(self):
self.assertEqual(len(self.strict_mock), 100)

@context.example
def type_validation_works(self):
with self.assertRaises(TypeCheckError):
self.strict_mock.static_method("whatever")

@context.example
def hash_works(self):
d = {}
d[self.strict_mock] = "value"
self.assertEqual(d[self.strict_mock], "value")

@context.example
def cant_set_hash(self):
with self.assertRaises(UnsupportedMagic):
self.strict_mock.__hash__ = lambda: 0

@context.sub_context
def given_as_an_argument(context):
@context.sub_context
Expand Down Expand Up @@ -368,10 +400,7 @@ def raises_when_setting_non_existing_attributes(self):
attr_name = "non_existing_attr"
with self.assertRaisesWithRegexMessage(
NonExistentAttribute,
f"'{attr_name}' can not be set.\n"
f"{self.strict_mock_rgx} template class does not have "
"this attribute so the mock can not have it as well.\n"
"See also: 'runtime_attrs' at StrictMock.__init__.",
f"'{attr_name}' is not part of the API.*",
):
setattr(self.strict_mock, attr_name, "whatever")

Expand Down
125 changes: 86 additions & 39 deletions testslide/strict_mock.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,13 @@ def __init__(self, strict_mock: "StrictMock", name: str) -> None:

def __str__(self) -> str:
return (
f"'{self.name}' can not be set.\n"
f"{self.strict_mock} template class does not have this attribute "
"so the mock can not have it as well.\n"
f"'{self.name}' is not part of the API.\n"
f"{self.strict_mock} template class API does not have this "
"attribute so the mock can not have it as well.\n"
"If you are inheriting StrictMock, you can define private "
"attributes, that will not interfere with the API, by prefixing "
"them with '__' (and at most one '_' suffix) "
" (https://docs.python.org/3/tutorial/classes.html#tut-private).\n"
"See also: 'runtime_attrs' at StrictMock.__init__."
)

Expand Down Expand Up @@ -223,7 +227,7 @@ class StrictMock(object):
# expected to work as they should. If implemented by the template class,
# they will have default values assigned to them, that raise
# UndefinedAttribute until configured.
_SETTABLE_MAGICS = [
__SETTABLE_MAGICS = [
"__abs__",
"__add__",
"__aenter__",
Expand Down Expand Up @@ -332,7 +336,7 @@ class StrictMock(object):

# These magics either won't work or makes no sense to be set for mock after
# an instance of a class. Trying to set them will raise UnsupportedMagic.
_UNSETTABLE_MAGICS = [
__UNSETTABLE_MAGICS = [
"__bases__",
"__class__",
"__class_getitem__",
Expand Down Expand Up @@ -377,7 +381,7 @@ def __new__(
strict_mock_instance = object.__new__(strict_mock_subclass)
return strict_mock_instance

def _setup_magic_methods(self) -> None:
def __setup_magic_methods(self) -> None:
"""
Populate all template's magic methods with expected default behavior.
This is important as things such as bool() depend on they existing
Expand All @@ -399,15 +403,23 @@ def _setup_magic_methods(self) -> None:
if klass is object:
continue
for name in klass.__dict__:
if name in type(self).__dict__:
continue
if name == "__hash__":
if klass.__dict__["__hash__"] is None:
setattr(self, name, None)
else:
setattr(self, name, lambda: id(self))
continue
if (
callable(klass.__dict__[name])
and name in self._SETTABLE_MAGICS
and name not in self._UNSETTABLE_MAGICS
and name in self.__SETTABLE_MAGICS
and name not in self.__UNSETTABLE_MAGICS
and name not in implemented_magic_methods
):
setattr(self, name, _DefaultMagic(self, name))

def _setup_default_context_manager(self, default_context_manager: bool) -> None:
def __setup_default_context_manager(self, default_context_manager: bool) -> None:
if self._template and default_context_manager:
if hasattr(self._template, "__enter__") and hasattr(
self._template, "__exit__"
Expand All @@ -427,10 +439,10 @@ async def aexit(exc_type, exc_value, traceback):
self.__aenter__ = aenter
self.__aexit__ = aexit

def _get_caller_frame(self, depth: int) -> FrameType:
def __get_caller_frame(self, depth: int) -> FrameType:
# Adding extra 3 to account for the stack:
# _get_caller_frame
# _get_caller
# __get_caller_frame
# __get_caller
# __init__
depth = depth + 3
current_frame = inspect.currentframe()
Expand All @@ -443,11 +455,11 @@ def _get_caller_frame(self, depth: int) -> FrameType:

return current_frame # type: ignore

def _get_caller(self, depth: int) -> Optional[str]:
def __get_caller(self, depth: int) -> Optional[str]:
# Doing inspect.stack will retrieve the whole stack, including context
# and that is really slow, this only retrieves the minimum, and does
# not read the file contents.
caller_frame = self._get_caller_frame(depth)
caller_frame = self.__get_caller_frame(depth)
# loading the context ends up reading files from disk and that might block
# the event loop, so we don't do it.
frameinfo = inspect.getframeinfo(caller_frame, context=0)
Expand All @@ -462,6 +474,35 @@ def _get_caller(self, depth: int) -> Optional[str]:
else:
return None

def __setup_subclass(self):
"""
When StrictMock is subclassed, any attributes defined at the subclass
will override any of StrictMock's validations. In order to overcome
this, for attributes that makes sense, we set them at StrictMock's
dynamically created subclass from __new__ using __setattr__, so that
all validations work.
"""
if type(self).mro()[1] == StrictMock:
return
for klass in type(self).mro()[1:]:
if klass == StrictMock:
break
for name in klass.__dict__.keys():
if name in [
"__doc__",
"__init__",
"__module__",
]:
continue
# https://docs.python.org/3/tutorial/classes.html#tut-private
if name.startswith(f"_{type(self).__name__}__") and not name.endswith(
"__"
):
continue
if name == "__hash__" and klass.__dict__["__hash__"] is None:
continue
StrictMock.__setattr__(self, name, getattr(self, name))

def __init__(
self,
template: Optional[type] = None,
Expand Down Expand Up @@ -494,7 +535,7 @@ def __init__(
self.__dict__["_runtime_attrs"] = runtime_attrs or []
self.__dict__["_name"] = name
self.__dict__["_type_validation"] = type_validation
self.__dict__["__caller"] = self._get_caller(1)
self.__dict__["__caller"] = self.__get_caller(1)
self.__dict__[
"_attributes_to_skip_type_validation"
] = attributes_to_skip_type_validation
Expand All @@ -505,9 +546,11 @@ def __init__(
caller_frame_info = inspect.getframeinfo(caller_frame, context=0) # type: ignore
self.__dict__["_caller_frame_info"] = caller_frame_info

self._setup_magic_methods()
self.__setup_magic_methods()

self._setup_default_context_manager(default_context_manager)
self.__setup_default_context_manager(default_context_manager)

self.__setup_subclass()

@property # type: ignore
def __class__(self) -> type:
Expand All @@ -522,11 +565,12 @@ def _template(self) -> None:
# introspection.
return testslide.mock_constructor._get_class_or_mock(self.__dict__["_template"])

# FIXME change to __runtime_attrs
@property
def _runtime_attrs(self) -> Optional[List[Any]]:
return self.__dict__["_runtime_attrs"]

def _template_has_attr(self, name: str) -> bool:
def __template_has_attr(self, name: str) -> bool:
def get_class_init(klass: type) -> Callable:
import testslide.mock_constructor # Avoid cyclic dependencies

Expand Down Expand Up @@ -564,10 +608,10 @@ def is_runtime_attr() -> bool:
)

@staticmethod
def _is_magic_method(name: str) -> bool:
def __is_magic_method(name: str) -> bool:
return name.startswith("__") and name.endswith("__")

def _validate_attribute_type(self, name: str, value: Any) -> None:
def __validate_attribute_type(self, name: str, value: Any) -> None:
if (
not self.__dict__["_type_validation"]
or name in self.__dict__["_attributes_to_skip_type_validation"]
Expand All @@ -579,12 +623,12 @@ def _validate_attribute_type(self, name: str, value: Any) -> None:
if name in annotations:
testslide.lib._validate_argument_type(annotations[name], name, value)

def _validate_and_wrap_mock_value(self, name: str, value: Any) -> Any:
def __validate_and_wrap_mock_value(self, name: str, value: Any) -> Any:
if self._template:
if not self._template_has_attr(name):
if not self.__template_has_attr(name):
raise NonExistentAttribute(self, name)

self._validate_attribute_type(name, value)
self.__validate_attribute_type(name, value)

if hasattr(self._template, name):
template_value = getattr(self._template, name)
Expand Down Expand Up @@ -658,23 +702,26 @@ def return_validation_wrapper(*args, **kwargs):
return value

def __setattr__(self, name: str, value: Any) -> None:
if self._is_magic_method(name):
if self.__is_magic_method(name):
# ...check whether we're allowed to mock...
if name in self._UNSETTABLE_MAGICS or (
name in StrictMock.__dict__ and name not in self._SETTABLE_MAGICS
):
if (
name in self.__UNSETTABLE_MAGICS
or (name in StrictMock.__dict__ and name not in self.__SETTABLE_MAGICS)
) and name != "__hash__":
raise UnsupportedMagic(self, name)
# ...or if it is something unsupported.
if name not in self._SETTABLE_MAGICS:
if name not in self.__SETTABLE_MAGICS and name != "__hash__":
raise NotImplementedError(
f"StrictMock does not implement support for {name}"
)
if name == "__hash__" and name in type(self).__dict__:
raise UnsupportedMagic(self, name)

mock_value = self._validate_and_wrap_mock_value(name, value)
mock_value = self.__validate_and_wrap_mock_value(name, value)
setattr(type(self), name, mock_value)

def __getattr__(self, name: str) -> Any:
if self._template and self._template_has_attr(name):
if self._template and self.__template_has_attr(name):
raise UndefinedAttribute(self, name)
else:
raise AttributeError(f"'{name}' was not set for {self}.")
Expand Down Expand Up @@ -707,45 +754,45 @@ def __repr__(self) -> str:
def __str__(self) -> str:
return self.__repr__()

def _get_copy(self) -> "StrictMock":
self_copy = type(self)(
def __get_copy(self) -> "StrictMock":
self_copy = StrictMock(
template=self._template,
runtime_attrs=self._runtime_attrs,
name=self._name,
type_validation=self._type_validation,
attributes_to_skip_type_validation=self._attributes_to_skip_type_validation,
)
self_copy.__dict__["__caller"] = self._get_caller(2)
self_copy.__dict__["__caller"] = self.__get_caller(2)
return self_copy

def _get_copyable_attrs(self, self_copy: "StrictMock") -> List[str]:
def __get_copyable_attrs(self, self_copy: "StrictMock") -> List[str]:
attrs = []
for name in type(self).__dict__:
if name not in self_copy.__dict__:
if (
name.startswith("__")
and name.endswith("__")
and not name in self._SETTABLE_MAGICS
and name not in self.__SETTABLE_MAGICS
):
continue
attrs.append(name)
return attrs

def __copy__(self) -> "StrictMock":
self_copy = self._get_copy()
self_copy = self.__get_copy()

for name in self._get_copyable_attrs(self_copy):
for name in self.__get_copyable_attrs(self_copy):
setattr(type(self_copy), name, type(self).__dict__[name])

return self_copy

def __deepcopy__(self, memo: Optional[Dict[Any, Any]] = None) -> "StrictMock":
if memo is None:
memo = {}
self_copy = self._get_copy()
self_copy = self.__get_copy()
memo[id(self)] = self_copy

for name in self._get_copyable_attrs(self_copy):
for name in self.__get_copyable_attrs(self_copy):
value = copy.deepcopy(type(self).__dict__[name], memo)
setattr(type(self_copy), name, value)
return self_copy
Expand Down

0 comments on commit 276dccc

Please sign in to comment.