Skip to content
60 changes: 34 additions & 26 deletions mocket/mocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@
'true_getaddrinfo',
'true_ssl_wrap_socket',
'true_ssl_socket',
'true_ssl_context',
'true_inet_pton',
'create_connection',
'MocketSocket',
'Mocket',
Expand All @@ -48,11 +50,8 @@
true_getaddrinfo = socket.getaddrinfo
true_ssl_wrap_socket = ssl.wrap_socket
true_ssl_socket = ssl.SSLSocket
try:
true_ssl_context = ssl.SSLContext
except AttributeError:
# Python 2.6
true_ssl_context = None
true_ssl_context = ssl.SSLContext
true_inet_pton = socket.inet_pton


class SuperFakeSSLContext(object):
Expand Down Expand Up @@ -98,7 +97,6 @@ def wrap_bio(self, incoming, outcoming, *args, **kwargs):
# FIXME: fake SSLObject implementation
ssl_obj = MocketSocket()
ssl_obj._host = kwargs['server_hostname']
# ssl_obj.fd = outcoming
return ssl_obj

def __getattr__(self, name):
Expand All @@ -119,8 +117,8 @@ class MocketSocket(object):
family = None
type = None
proto = None
_host = '127.0.0.1'
_port = 80
_host = None
_port = None
_address = None
cipher = lambda s: ("ADH", "AES256", "SHA")
compression = lambda s: ssl.OP_NO_COMPRESSION
Expand Down Expand Up @@ -177,8 +175,9 @@ def getsockname(self):
return socket.gethostbyname(self._address[0]), self._address[1]

def getpeercert(self, *args, **kwargs):
if not self._host:
self._host, _ = self._address
if not (self._host and self._port):
self._address = self._host, self._port = Mocket._address

now = datetime.now()
shift = now + timedelta(days=30 * 12)
return {
Expand All @@ -205,21 +204,16 @@ def getpeercert(self, *args, **kwargs):
def unwrap(self):
return self

def write(self, c):
return len(c)
def write(self, data):
return self.send(encode_to_bytes(data))

def fileno(self):
if not self.fd.r_fd:
self.fd.r_fd, self.fd.w_fd = os.pipe()
return self.fd.r_fd
Mocket.r_fd, Mocket.w_fd = os.pipe()
return Mocket.r_fd

def connect(self, address):
self._address = self._host, self._port = address

# def close(self):
# if self.true_socket and self._connected:
# self.true_socket.close()
# self._closed = True
Mocket._address = address

def makefile(self, mode='r', bufsize=-1):
self._mode = mode
Expand All @@ -243,12 +237,17 @@ def sendall(self, data, *args, **kwargs):
self.fd.truncate()
self.fd.seek(0)

def read(self, buffersize):
return self.fd.read(buffersize)

def recv(self, buffersize, flags=None):
if Mocket.r_fd and Mocket.w_fd:
return os.read(Mocket.r_fd, buffersize)
return self.fd.read(buffersize)

def _connect(self): # pragma: no cover
if not self._connected:
self.true_socket.connect(self._address)
self.true_socket.connect(Mocket._address)
self._connected = True

def true_sendall(self, data, *args, **kwargs):
Expand Down Expand Up @@ -318,15 +317,17 @@ def true_sendall(self, data, *args, **kwargs):

def send(self, data, *args, **kwargs): # pragma: no cover
entry = self.get_entry(data)
if entry:
if self._entry != entry:
self.sendall(data, *args, **kwargs)
if entry and self._entry != entry:
self.sendall(data, *args, **kwargs)
self._entry = entry
return len(data)

# def __getattribute__(self, name):
# return super(MocketSocket, self).__getattribute__(name)

def __getattr__(self, name):
# useful when clients call methods on real
# socket we do not provide on the fake one
""" Useful when clients call methods on real
socket we do not provide on the fake one. """
return getattr(self.true_socket, name) # pragma: no cover


Expand All @@ -335,6 +336,8 @@ class Mocket(object):
_requests = []
_namespace = text_type(id(_entries))
_truesocket_recording_dir = None
r_fd = None
w_fd = None

@classmethod
def register(cls, *entries):
Expand Down Expand Up @@ -387,6 +390,10 @@ def enable(namespace=None, truesocket_recording_dir=None):
lambda host, port, family=None, socktype=None, proto=None, flags=None: [(2, 1, 6, '', (host, port))]
ssl.wrap_socket = ssl.__dict__['wrap_socket'] = FakeSSLContext.wrap_socket
ssl.SSLContext = ssl.__dict__['SSLSocket'] = FakeSSLContext
socket.inet_pton = socket.__dict__['inet_pton'] = lambda family, ip: byte_type(
'\x7f\x00\x00\x01',
'utf-8'
)

@staticmethod
def disable():
Expand All @@ -400,6 +407,7 @@ def disable():
ssl.wrap_socket = ssl.__dict__['SSLSocket'] = true_ssl_wrap_socket
ssl.SSLSocket = ssl.__dict__['wrap_socket'] = true_ssl_socket
ssl.SSLContext = ssl.__dict__['SSLSocket'] = true_ssl_context
socket.inet_pton = socket.__dict__['inet_pton'] = true_inet_pton

@classmethod
def get_namespace(cls):
Expand Down
5 changes: 4 additions & 1 deletion mocket/mockhttp.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,10 @@ def can_handle(self, data):
requestline, _ = decode_from_bytes(data).split(CRLF, 1)
method, path, version = self._parse_requestline(requestline)
except ValueError:
return self == Mocket._last_entry
try:
return self == Mocket._last_entry
except AttributeError:
return False
uri = urlsplit(path)
kw = dict(keep_blank_values=True)
ch = uri.path == self.path and parse_qs(uri.query, **kw) == parse_qs(self.query, **kw) and method == self.method
Expand Down
9 changes: 4 additions & 5 deletions mocket/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,10 @@


class MocketSocketCore(io.BytesIO):
r_fd = None
w_fd = None

def write(self, content):
super(MocketSocketCore, self).write(content)

if self.r_fd and self.w_fd:
os.write(self.w_fd, content)
from mocket import Mocket

if Mocket.r_fd and Mocket.w_fd:
os.write(Mocket.w_fd, content)
2 changes: 1 addition & 1 deletion runtests.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def runtests(args=None):
if major == 3 and minor >= 5:
python35 = True

pip.main(['install', 'aiohttp'])
pip.main(['install', 'aiohttp', 'async_timeout'])

if not any(a for a in args[1:] if not a.startswith('-')):
args.append('tests/main')
Expand Down
41 changes: 34 additions & 7 deletions tests/tests35/test_http_aiohttp.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import aiohttp
import asyncio
import async_timeout
from unittest import TestCase

from mocket.mocket import mocketize
Expand All @@ -8,21 +9,47 @@

class AioHttpEntryTestCase(TestCase):
@mocketize
def test_session(self):
def test_http_session(self):
url = 'http://httpbin.org/ip'
body = "asd" * 100
Entry.single_register(Entry.GET, url, body=body, status=404)
Entry.single_register(Entry.POST, url, body=body*2, status=201)

async def main(l):
async with aiohttp.ClientSession(loop=l) as session:
async with session.get(url) as get_response:
assert get_response.status == 404
assert await get_response.text() == body
with async_timeout.timeout(3):
async with session.get(url) as get_response:
assert get_response.status == 404
assert await get_response.text() == body

async with session.post(url, data=body*6) as post_response:
assert post_response.status == 201
assert await post_response.text() == body*2
with async_timeout.timeout(3):
async with session.post(url, data=body * 6) as post_response:
assert post_response.status == 201
assert await post_response.text() == body * 2

loop = asyncio.get_event_loop()
loop.set_debug(True)
loop.run_until_complete(main(loop))

@mocketize
def test_https_session(self):
url = 'https://httpbin.org/ip'
body = "asd" * 100
Entry.single_register(Entry.GET, url, body=body, status=404)
Entry.single_register(Entry.POST, url, body=body*2, status=201)

async def main(l):
async with aiohttp.ClientSession(loop=l) as session:
with async_timeout.timeout(3):
async with session.get(url) as get_response:
assert get_response.status == 404
assert await get_response.text() == body

with async_timeout.timeout(3):
async with session.post(url, data=body * 6) as post_response:
assert post_response.status == 201
assert await post_response.text() == body * 2

loop = asyncio.get_event_loop()
loop.set_debug(True)
loop.run_until_complete(main(loop))