Skip to content

Commit

Permalink
Merge 15a74ea into b5bb3cf
Browse files Browse the repository at this point in the history
  • Loading branch information
fornellas committed Jul 17, 2020
2 parents b5bb3cf + 15a74ea commit 4e55774
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 32 deletions.
62 changes: 44 additions & 18 deletions tests/patch_attribute_testslide.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,32 +33,58 @@ def patch_attribute_tests(context):

@context.shared_context
def patching_works(context):
@context.example
def patching_works(self):
def sm_hasattr(obj, name):
try:
return hasattr(obj, name)
except UndefinedAttribute:
return False

if sm_hasattr(self.real_target, self.attribute):
original_value = getattr(self.real_target, self.attribute)
@context.function
def strict_mock_hasattr(self, obj, name):
try:
return hasattr(obj, name)
except UndefinedAttribute:
return False

@context.before
def before(self):
if self.strict_mock_hasattr(self.real_target, self.attribute):
self.original_value = getattr(self.real_target, self.attribute)
else:
original_value = None
self.assertNotEqual(original_value, self.new_value)
self.patch_attribute(self.target, self.attribute, self.new_value)
self.assertEqual(getattr(self.real_target, self.attribute), self.new_value)
self.original_value = None
self.assertNotEqual(
self.original_value,
self.new_value,
"Previous test tainted this result!",
)

@context.after
def after(self):
self.assertEqual(
getattr(self.real_target, self.attribute),
self.new_value,
"Patching did not work",
)

unpatch_all_mocked_attributes()
if original_value:
if self.original_value:
self.assertEqual(
getattr(self.real_target, self.attribute), original_value
getattr(self.real_target, self.attribute),
self.original_value,
"Unpatching did not work.",
)
else:
self.assertFalse(sm_hasattr(self.real_target, self.attribute))
self.assertFalse(
self.strict_mock_hasattr(self.real_target, self.attribute),
"Unpatching did not work",
)

@context.example
def patching_works(self):
self.patch_attribute(self.target, self.attribute, self.new_value)

@context.example
def double_patching_works(self):
self.patch_attribute(self.target, self.attribute, "whatever")
self.patch_attribute(self.target, self.attribute, self.new_value)

@context.shared_context
def common(context, fails_if_class_attribute):
context.merge_context("patching works")
context.nest_context("patching works")

@context.example
def it_fails_if_attribute_is_callable(self):
Expand Down
9 changes: 6 additions & 3 deletions testslide/patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# LICENSE file in the root directory of this source tree.

import inspect
from typing import Callable, Any


class _DescriptorProxy(object):
Expand Down Expand Up @@ -37,7 +38,7 @@ def __delete__(self, instance):
del self.instance_attr_map[instance]


def _is_instance_method(target, method):
def _is_instance_method(target: Any, method):
if inspect.ismodule(target):
return False

Expand All @@ -56,7 +57,7 @@ def _is_instance_method(target, method):
return False


def _mock_instance_attribute(instance, attr, value):
def _mock_instance_attribute(instance: Any, attr: str, value: Any):
"""
Patch attribute at instance with given value. This works for any instance
attribute, even when the attribute is defined via the descriptor protocol using
Expand All @@ -80,7 +81,9 @@ def unpatch_class():
return unpatch_class


def _patch(target, attribute, new_value, restore, restore_value=None):
def _patch(
target: Any, attribute: str, new_value: Any, restore: Any, restore_value: Any = None
) -> Callable:
if _is_instance_method(target, attribute):
unpatcher = _mock_instance_attribute(target, attribute, new_value)
elif hasattr(type(target), attribute) and isinstance(
Expand Down
42 changes: 31 additions & 11 deletions testslide/patch_attribute.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@
from .patch import _patch
from .lib import _bail_if_private
from testslide.strict_mock import UndefinedAttribute
from typing import List, Callable
from typing import Callable, Any, Dict, Tuple

_unpatchers: List[Callable] = []
_restore_values: Dict[Tuple[Any, str], Any] = {}
_unpatchers: Dict[Tuple[Any, str], Callable] = {}


def unpatch_all_mocked_attributes():
Expand All @@ -18,19 +19,22 @@ def unpatch_all_mocked_attributes():
active patch_attribute() patches.
"""
unpatch_exceptions = []
for unpatcher in _unpatchers:
for unpatcher in _unpatchers.values():
try:
unpatcher()
except Exception as e:
unpatch_exceptions.append(e)
del _unpatchers[:]
_restore_values.clear()
_unpatchers.clear()
if unpatch_exceptions:
raise RuntimeError(
"Exceptions raised when unpatching: {}".format(unpatch_exceptions)
)


def patch_attribute(target, attribute, new_value, allow_private=False):
def patch_attribute(
target: Any, attribute: str, new_value: Any, allow_private: bool = False
) -> None:
"""
Patch target's attribute with new_value. The target can be any Python
object, such as modules, classes or instances; attribute is a string with
Expand All @@ -42,9 +46,13 @@ def patch_attribute(target, attribute, new_value, allow_private=False):
class, which may affect other instances. patch_attribute() takes care of
what's needed, so only the target instance is affected.
"""
_bail_if_private(attribute, allow_private)

if isinstance(target, str):
target = testslide._importer(target)

key = (id(target), attribute)

if isinstance(target, testslide.StrictMock):
template_class = target._template
if template_class:
Expand All @@ -55,21 +63,28 @@ def patch_attribute(target, attribute, new_value, allow_private=False):
"You can either use mock_callable() / mock_async_callable() instead."
)

def sm_hasattr(obj, name):
def strict_mock_hasattr(obj, name):
try:
return hasattr(obj, name)
except UndefinedAttribute:
return False

if sm_hasattr(target, attribute):
if strict_mock_hasattr(target, attribute) and key not in _unpatchers:
restore = True
restore_value = getattr(target, attribute)
else:
restore = False
restore_value = None
skip_unpatcher = False
else:
restore = True
restore_value = getattr(target, attribute)
if key in _unpatchers:
restore = False
restore_value = _restore_values[key]
skip_unpatcher = True
else:
restore = True
restore_value = getattr(target, attribute)
skip_unpatcher = False
if isinstance(restore_value, type):
raise ValueError(
"Attribute can not be a class!\n"
Expand All @@ -80,6 +95,11 @@ def sm_hasattr(obj, name):
"Attribute can not be callable!\n"
"You can either use mock_callable() / mock_async_callable() instead."
)
_bail_if_private(attribute, allow_private)

if restore:
_restore_values[key] = restore_value

unpatcher = _patch(target, attribute, new_value, restore, restore_value)
_unpatchers.append(unpatcher)

if not skip_unpatcher:
_unpatchers[key] = unpatcher

0 comments on commit 4e55774

Please sign in to comment.