diff --git a/mocket/__init__.py b/mocket/__init__.py index 1012569a..221b3cc3 100644 --- a/mocket/__init__.py +++ b/mocket/__init__.py @@ -4,3 +4,5 @@ except ImportError: # Py3 from mocket.mocket import mocketize, Mocket, MocketEntry + +__all__ = (mocketize, Mocket, MocketEntry) diff --git a/mocket/mocket.py b/mocket/mocket.py index dc893870..aba54059 100644 --- a/mocket/mocket.py +++ b/mocket/mocket.py @@ -190,11 +190,15 @@ def get_entry(self, data): def sendall(self, data, *args, **kwargs): entry = self.get_entry(data) - if not entry: - return self.true_sendall(data, *args, **kwargs) - entry.collect(data) + + if entry: + entry.collect(data) + response = entry.get_response() + else: + response = self.true_sendall(data, *args, **kwargs) + self.fd.seek(0) - self.fd.write(entry.get_response()) + self.fd.write(response) self.fd.seek(0) def recv(self, buffersize, flags=None): @@ -254,50 +258,46 @@ def true_sendall(self, data, *args, **kwargs): line = gzip_buffer.getvalue() r_lines.append(line) encoded_response = b'\r\n'.join(r_lines) - written = len(encoded_response) # if not available, call the real sendall except KeyError: self._connect() self.true_socket.sendall(data, *args, **kwargs) - written, r = 0, io.BytesIO() + r = io.BytesIO() while True: recv = self.true_socket.recv(self._buflen) - r.write(recv) - written += len(recv) - if len(recv) < self._buflen: + if r.write(recv) < self._buflen: break - response_dict['request'] = req - lines = response_dict['response'] = [] - gzipped_lines = response_dict['gzip'] = [] - - # update the dictionary with the response obtained encoded_response = r.getvalue() - for line_no, line in enumerate(encoded_response.split(b'\r\n')): - try: - line = decode_utf8(line) - except UnicodeDecodeError: - f = gzip.GzipFile(mode='rb', fileobj=io.BytesIO(line)) - try: - line = f.read(len(line)) - finally: - f.close() - line = decode_utf8(line) - gzipped_lines.append(line_no + 1) + # dump the resulting dictionary to a JSON file + if Mocket.get_truesocket_recording_dir(): + response_dict['request'] = req + lines = response_dict['response'] = [] + gzipped_lines = response_dict['gzip'] = [] - lines.append(line) + # update the dictionary with the response obtained + for line_no, line in enumerate(encoded_response.split(b'\r\n')): + + try: + line = decode_utf8(line) + except UnicodeDecodeError: + f = gzip.GzipFile(mode='rb', fileobj=io.BytesIO(line)) + try: + line = f.read(len(line)) + finally: + f.close() + line = decode_utf8(line) + gzipped_lines.append(line_no + 1) + + lines.append(line) - # dump the resulting dictionary to a JSON file - if self._truesocket_recording_dir: with io.open(path, mode='w', encoding=encoding) as f: f.write(decode_utf8(json.dumps(responses, indent=4, sort_keys=True))) - # write the response to the mocket socket - self.fd.write(encoded_response) - # flush the mocket socket - self.fd.seek(- written, 1) + # response back to .sendall() which writes it to the mocket socket and flush the BytesIO + return encoded_response def send(self, data, *args, **kwargs): # pragma: no cover entry = self.get_entry(data) @@ -313,12 +313,6 @@ def __getattr__(self, name): return getattr(self.true_socket, name) # pragma: no cover -class RecordingMocketSocket(MocketSocket): - def __init__(self, *args, **kwargs): - super(RecordingMocketSocket, self).__init__(*args, **kwargs) - self._truesocket_recording_dir = True - - class Mocket(object): _entries = collections.defaultdict(list) _requests = [] @@ -358,23 +352,17 @@ def remove_last_request(cls): @staticmethod def enable(namespace=None, truesocket_recording_dir=None): - if namespace: - Mocket._namespace = namespace + Mocket._namespace = namespace + Mocket._truesocket_recording_dir = truesocket_recording_dir + if truesocket_recording_dir: # JSON dumps will be saved here assert os.path.isdir(truesocket_recording_dir) - Mocket._truesocket_recording_dir = truesocket_recording_dir - - socket.socket = socket.__dict__['socket'] = RecordingMocketSocket - socket._socketobject = socket.__dict__['_socketobject'] = RecordingMocketSocket - socket.SocketType = socket.__dict__['SocketType'] = RecordingMocketSocket - ssl.SSLSocket = ssl.__dict__['SSLSocket'] = RecordingMocketSocket - else: - socket.socket = socket.__dict__['socket'] = MocketSocket - socket._socketobject = socket.__dict__['_socketobject'] = MocketSocket - socket.SocketType = socket.__dict__['SocketType'] = MocketSocket - ssl.SSLSocket = ssl.__dict__['SSLSocket'] = MocketSocket + socket.socket = socket.__dict__['socket'] = MocketSocket + socket._socketobject = socket.__dict__['_socketobject'] = MocketSocket + socket.SocketType = socket.__dict__['SocketType'] = MocketSocket + ssl.SSLSocket = ssl.__dict__['SSLSocket'] = MocketSocket socket.create_connection = socket.__dict__['create_connection'] = create_connection socket.gethostname = socket.__dict__['gethostname'] = lambda: 'localhost' socket.gethostbyname = socket.__dict__['gethostbyname'] = lambda host: '127.0.0.1'