Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions mocket/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,5 @@
except ImportError:
# Py3
from mocket.mocket import mocketize, Mocket, MocketEntry

__all__ = (mocketize, Mocket, MocketEntry)
90 changes: 39 additions & 51 deletions mocket/mocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,11 +190,15 @@ def get_entry(self, data):

def sendall(self, data, *args, **kwargs):
entry = self.get_entry(data)
if not entry:
return self.true_sendall(data, *args, **kwargs)
entry.collect(data)

if entry:
entry.collect(data)
response = entry.get_response()
else:
response = self.true_sendall(data, *args, **kwargs)

self.fd.seek(0)
self.fd.write(entry.get_response())
self.fd.write(response)
self.fd.seek(0)

def recv(self, buffersize, flags=None):
Expand Down Expand Up @@ -254,50 +258,46 @@ def true_sendall(self, data, *args, **kwargs):
line = gzip_buffer.getvalue()
r_lines.append(line)
encoded_response = b'\r\n'.join(r_lines)
written = len(encoded_response)

# if not available, call the real sendall
except KeyError:
self._connect()
self.true_socket.sendall(data, *args, **kwargs)
written, r = 0, io.BytesIO()
r = io.BytesIO()
while True:
recv = self.true_socket.recv(self._buflen)
r.write(recv)
written += len(recv)
if len(recv) < self._buflen:
if r.write(recv) < self._buflen:
break

response_dict['request'] = req
lines = response_dict['response'] = []
gzipped_lines = response_dict['gzip'] = []

# update the dictionary with the response obtained
encoded_response = r.getvalue()

for line_no, line in enumerate(encoded_response.split(b'\r\n')):
try:
line = decode_utf8(line)
except UnicodeDecodeError:
f = gzip.GzipFile(mode='rb', fileobj=io.BytesIO(line))
try:
line = f.read(len(line))
finally:
f.close()
line = decode_utf8(line)
gzipped_lines.append(line_no + 1)
# dump the resulting dictionary to a JSON file
if Mocket.get_truesocket_recording_dir():
response_dict['request'] = req
lines = response_dict['response'] = []
gzipped_lines = response_dict['gzip'] = []

lines.append(line)
# update the dictionary with the response obtained
for line_no, line in enumerate(encoded_response.split(b'\r\n')):

try:
line = decode_utf8(line)
except UnicodeDecodeError:
f = gzip.GzipFile(mode='rb', fileobj=io.BytesIO(line))
try:
line = f.read(len(line))
finally:
f.close()
line = decode_utf8(line)
gzipped_lines.append(line_no + 1)

lines.append(line)

# dump the resulting dictionary to a JSON file
if self._truesocket_recording_dir:
with io.open(path, mode='w', encoding=encoding) as f:
f.write(decode_utf8(json.dumps(responses, indent=4, sort_keys=True)))

# write the response to the mocket socket
self.fd.write(encoded_response)
# flush the mocket socket
self.fd.seek(- written, 1)
# response back to .sendall() which writes it to the mocket socket and flush the BytesIO
return encoded_response

def send(self, data, *args, **kwargs): # pragma: no cover
entry = self.get_entry(data)
Expand All @@ -313,12 +313,6 @@ def __getattr__(self, name):
return getattr(self.true_socket, name) # pragma: no cover


class RecordingMocketSocket(MocketSocket):
def __init__(self, *args, **kwargs):
super(RecordingMocketSocket, self).__init__(*args, **kwargs)
self._truesocket_recording_dir = True


class Mocket(object):
_entries = collections.defaultdict(list)
_requests = []
Expand Down Expand Up @@ -358,23 +352,17 @@ def remove_last_request(cls):

@staticmethod
def enable(namespace=None, truesocket_recording_dir=None):
if namespace:
Mocket._namespace = namespace
Mocket._namespace = namespace
Mocket._truesocket_recording_dir = truesocket_recording_dir

if truesocket_recording_dir:
# JSON dumps will be saved here
assert os.path.isdir(truesocket_recording_dir)
Mocket._truesocket_recording_dir = truesocket_recording_dir

socket.socket = socket.__dict__['socket'] = RecordingMocketSocket
socket._socketobject = socket.__dict__['_socketobject'] = RecordingMocketSocket
socket.SocketType = socket.__dict__['SocketType'] = RecordingMocketSocket
ssl.SSLSocket = ssl.__dict__['SSLSocket'] = RecordingMocketSocket
else:
socket.socket = socket.__dict__['socket'] = MocketSocket
socket._socketobject = socket.__dict__['_socketobject'] = MocketSocket
socket.SocketType = socket.__dict__['SocketType'] = MocketSocket
ssl.SSLSocket = ssl.__dict__['SSLSocket'] = MocketSocket

socket.socket = socket.__dict__['socket'] = MocketSocket
socket._socketobject = socket.__dict__['_socketobject'] = MocketSocket
socket.SocketType = socket.__dict__['SocketType'] = MocketSocket
ssl.SSLSocket = ssl.__dict__['SSLSocket'] = MocketSocket
socket.create_connection = socket.__dict__['create_connection'] = create_connection
socket.gethostname = socket.__dict__['gethostname'] = lambda: 'localhost'
socket.gethostbyname = socket.__dict__['gethostbyname'] = lambda host: '127.0.0.1'
Expand Down