diff --git a/mocket/__init__.py b/mocket/__init__.py index ee36a6d5..29250012 100644 --- a/mocket/__init__.py +++ b/mocket/__init__.py @@ -1,10 +1,10 @@ try: # Py2 - from mocket import mocketize, Mocket, MocketEntry + from mocket import mocketize, Mocket, MocketEntry, Mocketizer except ImportError: # Py3 - from mocket.mocket import mocketize, Mocket, MocketEntry + from mocket.mocket import mocketize, Mocket, MocketEntry, Mocketizer -__all__ = (mocketize, Mocket, MocketEntry) +__all__ = (mocketize, Mocket, MocketEntry, Mocketizer) __version__ = '2.2.0' diff --git a/mocket/mocket.py b/mocket/mocket.py index 0d1b8095..708c5dee 100644 --- a/mocket/mocket.py +++ b/mocket/mocket.py @@ -463,17 +463,19 @@ def get_response(self): class Mocketizer(object): - def __init__(self, instance, namespace=None, truesocket_recording_dir=None): + def __init__(self, instance=None, namespace=None, truesocket_recording_dir=None): self.instance = instance self.truesocket_recording_dir = truesocket_recording_dir self.namespace = namespace or text_type(id(self)) def __enter__(self): Mocket.enable(namespace=self.namespace, truesocket_recording_dir=self.truesocket_recording_dir) - self.check_and_call('mocketize_setup') + if self.instance: + self.check_and_call('mocketize_setup') def __exit__(self, type, value, tb): - self.check_and_call('mocketize_teardown') + if self.instance: + self.check_and_call('mocketize_teardown') Mocket.disable() Mocket.reset() diff --git a/tests/main/test_mocket.py b/tests/main/test_mocket.py index 2b745c33..f253da63 100644 --- a/tests/main/test_mocket.py +++ b/tests/main/test_mocket.py @@ -4,7 +4,7 @@ import pytest -from mocket import Mocket, mocketize, MocketEntry +from mocket import Mocket, mocketize, MocketEntry, Mocketizer from mocket.compat import encode_to_bytes @@ -75,7 +75,6 @@ def test_empty_getresponse(self): entry = MocketEntry(('localhost', 8080), []) self.assertEqual(entry.get_response(), encode_to_bytes('')) - @mocketize def test_subsequent_recv_requests_have_correct_length(self): Mocket.register( MocketEntry( @@ -86,13 +85,14 @@ def test_subsequent_recv_requests_have_correct_length(self): ] ) ) - _so = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - _so.connect(('localhost', 80)) - _so.sendall(b'first\r\n') - assert _so.recv(4096) == b'Long payload' - _so.sendall(b'second\r\n') - assert _so.recv(4096) == b'Short' - _so.close() + with Mocketizer(): + _so = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + _so.connect(('localhost', 80)) + _so.sendall(b'first\r\n') + assert _so.recv(4096) == b'Long payload' + _so.sendall(b'second\r\n') + assert _so.recv(4096) == b'Short' + _so.close() class MocketizeTestCase(TestCase):