diff --git a/vcr/patch.py b/vcr/patch.py index dc4e72e8..ab0036b1 100644 --- a/vcr/patch.py +++ b/vcr/patch.py @@ -57,41 +57,22 @@ class PatcherBuilder(object): def __init__(self, cassette): self._cassette = cassette - self._class_to_cassette_subclass = {} def build_patchers(self): patcher_args = itertools.chain(self._httplib(), self._requests(), self._urllib3(), self._httplib2(), self._boto()) - for args in patcher_args: - patcher = self._build_patcher(*args) + for obj, patched_attribute, replacement_class in patcher_args: + patcher = self._build_patcher(obj, patched_attribute, replacement_class) if patcher: yield patcher + if hasattr(replacement_class, 'cassette'): + yield mock.patch.object(replacement_class, 'cassette', self._cassette) def _build_patcher(self, obj, patched_attribute, replacement_class): if not hasattr(obj, patched_attribute): return - - if isinstance(replacement_class, dict): - for key in replacement_class: - replacement_class[key] = self._get_cassette_subclass(replacement_class[key]) - else: - replacement_class = self._get_cassette_subclass(replacement_class) return mock.patch.object(obj, patched_attribute, replacement_class) - def _get_cassette_subclass(self, klass): - if klass.cassette is not None: - return klass - if klass not in self._class_to_cassette_subclass: - self._class_to_cassette_subclass[klass] = self._build_cassette_subclass(klass) - return self._class_to_cassette_subclass[klass] - - def _build_cassette_subclass(self, base_class): - bases = (base_class,) - if not issubclass(base_class, object): # Check for old style class - bases += (object,) - return type('{0}{1}'.format(base_class.__name__, self._cassette._path), - bases, dict(cassette=self._cassette)) - def _httplib(self): yield httplib, 'HTTPConnection', VCRHTTPConnection yield httplib, 'HTTPSConnection', VCRHTTPSConnection