diff --git a/mocket/mocket.py b/mocket/mocket.py index e6c9909f..dceddf73 100644 --- a/mocket/mocket.py +++ b/mocket/mocket.py @@ -34,6 +34,8 @@ 'true_getaddrinfo', 'true_ssl_wrap_socket', 'true_ssl_socket', + 'true_ssl_context', + 'true_inet_pton', 'create_connection', 'MocketSocket', 'Mocket', @@ -48,11 +50,8 @@ true_getaddrinfo = socket.getaddrinfo true_ssl_wrap_socket = ssl.wrap_socket true_ssl_socket = ssl.SSLSocket -try: - true_ssl_context = ssl.SSLContext -except AttributeError: - # Python 2.6 - true_ssl_context = None +true_ssl_context = ssl.SSLContext +true_inet_pton = socket.inet_pton class SuperFakeSSLContext(object): @@ -98,7 +97,6 @@ def wrap_bio(self, incoming, outcoming, *args, **kwargs): # FIXME: fake SSLObject implementation ssl_obj = MocketSocket() ssl_obj._host = kwargs['server_hostname'] - # ssl_obj.fd = outcoming return ssl_obj def __getattr__(self, name): @@ -119,8 +117,8 @@ class MocketSocket(object): family = None type = None proto = None - _host = '127.0.0.1' - _port = 80 + _host = None + _port = None _address = None cipher = lambda s: ("ADH", "AES256", "SHA") compression = lambda s: ssl.OP_NO_COMPRESSION @@ -177,8 +175,9 @@ def getsockname(self): return socket.gethostbyname(self._address[0]), self._address[1] def getpeercert(self, *args, **kwargs): - if not self._host: - self._host, _ = self._address + if not (self._host and self._port): + self._address = self._host, self._port = Mocket._address + now = datetime.now() shift = now + timedelta(days=30 * 12) return { @@ -205,21 +204,16 @@ def getpeercert(self, *args, **kwargs): def unwrap(self): return self - def write(self, c): - return len(c) + def write(self, data): + return self.send(encode_to_bytes(data)) def fileno(self): - if not self.fd.r_fd: - self.fd.r_fd, self.fd.w_fd = os.pipe() - return self.fd.r_fd + Mocket.r_fd, Mocket.w_fd = os.pipe() + return Mocket.r_fd def connect(self, address): self._address = self._host, self._port = address - - # def close(self): - # if self.true_socket and self._connected: - # self.true_socket.close() - # self._closed = True + Mocket._address = address def makefile(self, mode='r', bufsize=-1): self._mode = mode @@ -243,12 +237,17 @@ def sendall(self, data, *args, **kwargs): self.fd.truncate() self.fd.seek(0) + def read(self, buffersize): + return self.fd.read(buffersize) + def recv(self, buffersize, flags=None): + if Mocket.r_fd and Mocket.w_fd: + return os.read(Mocket.r_fd, buffersize) return self.fd.read(buffersize) def _connect(self): # pragma: no cover if not self._connected: - self.true_socket.connect(self._address) + self.true_socket.connect(Mocket._address) self._connected = True def true_sendall(self, data, *args, **kwargs): @@ -318,15 +317,17 @@ def true_sendall(self, data, *args, **kwargs): def send(self, data, *args, **kwargs): # pragma: no cover entry = self.get_entry(data) - if entry: - if self._entry != entry: - self.sendall(data, *args, **kwargs) + if entry and self._entry != entry: + self.sendall(data, *args, **kwargs) self._entry = entry return len(data) + # def __getattribute__(self, name): + # return super(MocketSocket, self).__getattribute__(name) + def __getattr__(self, name): - # useful when clients call methods on real - # socket we do not provide on the fake one + """ Useful when clients call methods on real + socket we do not provide on the fake one. """ return getattr(self.true_socket, name) # pragma: no cover @@ -335,6 +336,8 @@ class Mocket(object): _requests = [] _namespace = text_type(id(_entries)) _truesocket_recording_dir = None + r_fd = None + w_fd = None @classmethod def register(cls, *entries): @@ -387,6 +390,10 @@ def enable(namespace=None, truesocket_recording_dir=None): lambda host, port, family=None, socktype=None, proto=None, flags=None: [(2, 1, 6, '', (host, port))] ssl.wrap_socket = ssl.__dict__['wrap_socket'] = FakeSSLContext.wrap_socket ssl.SSLContext = ssl.__dict__['SSLSocket'] = FakeSSLContext + socket.inet_pton = socket.__dict__['inet_pton'] = lambda family, ip: byte_type( + '\x7f\x00\x00\x01', + 'utf-8' + ) @staticmethod def disable(): @@ -400,6 +407,7 @@ def disable(): ssl.wrap_socket = ssl.__dict__['SSLSocket'] = true_ssl_wrap_socket ssl.SSLSocket = ssl.__dict__['wrap_socket'] = true_ssl_socket ssl.SSLContext = ssl.__dict__['SSLSocket'] = true_ssl_context + socket.inet_pton = socket.__dict__['inet_pton'] = true_inet_pton @classmethod def get_namespace(cls): diff --git a/mocket/mockhttp.py b/mocket/mockhttp.py index e15c0dc8..0c6a0e36 100644 --- a/mocket/mockhttp.py +++ b/mocket/mockhttp.py @@ -100,7 +100,10 @@ def can_handle(self, data): requestline, _ = decode_from_bytes(data).split(CRLF, 1) method, path, version = self._parse_requestline(requestline) except ValueError: - return self == Mocket._last_entry + try: + return self == Mocket._last_entry + except AttributeError: + return False uri = urlsplit(path) kw = dict(keep_blank_values=True) ch = uri.path == self.path and parse_qs(uri.query, **kw) == parse_qs(self.query, **kw) and method == self.method diff --git a/mocket/utils.py b/mocket/utils.py index 158a7359..3b238007 100644 --- a/mocket/utils.py +++ b/mocket/utils.py @@ -3,11 +3,10 @@ class MocketSocketCore(io.BytesIO): - r_fd = None - w_fd = None - def write(self, content): super(MocketSocketCore, self).write(content) - if self.r_fd and self.w_fd: - os.write(self.w_fd, content) + from mocket import Mocket + + if Mocket.r_fd and Mocket.w_fd: + os.write(Mocket.w_fd, content) diff --git a/runtests.py b/runtests.py index ab37b653..73eed1ad 100644 --- a/runtests.py +++ b/runtests.py @@ -17,7 +17,7 @@ def runtests(args=None): if major == 3 and minor >= 5: python35 = True - pip.main(['install', 'aiohttp']) + pip.main(['install', 'aiohttp', 'async_timeout']) if not any(a for a in args[1:] if not a.startswith('-')): args.append('tests/main') diff --git a/tests/tests35/test_http_aiohttp.py b/tests/tests35/test_http_aiohttp.py index 8acc6b32..f63a000b 100644 --- a/tests/tests35/test_http_aiohttp.py +++ b/tests/tests35/test_http_aiohttp.py @@ -1,5 +1,6 @@ import aiohttp import asyncio +import async_timeout from unittest import TestCase from mocket.mocket import mocketize @@ -8,7 +9,7 @@ class AioHttpEntryTestCase(TestCase): @mocketize - def test_session(self): + def test_http_session(self): url = 'http://httpbin.org/ip' body = "asd" * 100 Entry.single_register(Entry.GET, url, body=body, status=404) @@ -16,13 +17,39 @@ def test_session(self): async def main(l): async with aiohttp.ClientSession(loop=l) as session: - async with session.get(url) as get_response: - assert get_response.status == 404 - assert await get_response.text() == body + with async_timeout.timeout(3): + async with session.get(url) as get_response: + assert get_response.status == 404 + assert await get_response.text() == body - async with session.post(url, data=body*6) as post_response: - assert post_response.status == 201 - assert await post_response.text() == body*2 + with async_timeout.timeout(3): + async with session.post(url, data=body * 6) as post_response: + assert post_response.status == 201 + assert await post_response.text() == body * 2 loop = asyncio.get_event_loop() + loop.set_debug(True) + loop.run_until_complete(main(loop)) + + @mocketize + def test_https_session(self): + url = 'https://httpbin.org/ip' + body = "asd" * 100 + Entry.single_register(Entry.GET, url, body=body, status=404) + Entry.single_register(Entry.POST, url, body=body*2, status=201) + + async def main(l): + async with aiohttp.ClientSession(loop=l) as session: + with async_timeout.timeout(3): + async with session.get(url) as get_response: + assert get_response.status == 404 + assert await get_response.text() == body + + with async_timeout.timeout(3): + async with session.post(url, data=body * 6) as post_response: + assert post_response.status == 201 + assert await post_response.text() == body * 2 + + loop = asyncio.get_event_loop() + loop.set_debug(True) loop.run_until_complete(main(loop))