Skip to content

Commit

Permalink
Merge pull request #15 from ShaneHarvey/PYTHON-1377
Browse files Browse the repository at this point in the history
PYTHON-1377 Implement OP_MSG parsing and responses in MockupDB
  • Loading branch information
ajdavis committed Jun 19, 2018
2 parents 3ea04af + 4ce8055 commit 571f6fb
Show file tree
Hide file tree
Showing 2 changed files with 225 additions and 66 deletions.
285 changes: 222 additions & 63 deletions mockupdb/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,13 +83,13 @@ def reraise(exctype, value, trace=None):
'MockupDB', 'go', 'going', 'Future', 'wait_until', 'interactive_server',

'OP_REPLY', 'OP_UPDATE', 'OP_INSERT', 'OP_QUERY', 'OP_GET_MORE',
'OP_DELETE', 'OP_KILL_CURSORS',
'OP_DELETE', 'OP_KILL_CURSORS', 'OP_MSG',

'QUERY_FLAGS', 'UPDATE_FLAGS', 'INSERT_FLAGS', 'DELETE_FLAGS',
'REPLY_FLAGS',
'REPLY_FLAGS', 'OP_MSG_FLAGS',

'Request', 'Command', 'OpQuery', 'OpGetMore', 'OpKillCursors', 'OpInsert',
'OpUpdate', 'OpDelete', 'OpReply',
'OpUpdate', 'OpDelete', 'OpReply', 'OpMsg',

'Matcher', 'absent',
]
Expand Down Expand Up @@ -239,6 +239,7 @@ def wait_until(predicate, success_description, timeout=10):
OP_GET_MORE = 2005
OP_DELETE = 2006
OP_KILL_CURSORS = 2007
OP_MSG = 2013

QUERY_FLAGS = OrderedDict([
('TailableCursor', 2),
Expand All @@ -263,7 +264,13 @@ def wait_until(predicate, success_description, timeout=10):
('CursorNotFound', 1),
('QueryFailure', 2)])

OP_MSG_FLAGS = OrderedDict([
('checksumPresent', 1),
('moreToCome', 2)])

_UNPACK_BYTE = struct.Struct("<b").unpack
_UNPACK_INT = struct.Struct("<i").unpack
_UNPACK_UINT = struct.Struct("<I").unpack
_UNPACK_LONG = struct.Struct("<q").unpack


Expand Down Expand Up @@ -554,6 +561,128 @@ def __repr__(self):
return '%s(%s)' % (name, ', '.join(str(part) for part in parts))


class CommandBase(Request):
"""A command the client executes on the server."""
is_command = True

# Check command name case-insensitively.
_non_matched_attrs = Request._non_matched_attrs + ('command_name', )

@property
def command_name(self):
"""The command name or None.
>>> Command({'count': 'collection'}).command_name
'count'
>>> Command('aggregate', 'collection', cursor=absent).command_name
'aggregate'
"""
if self.docs and self.docs[0]:
return list(self.docs[0])[0]

def _matches_docs(self, docs, other_docs):
assert len(docs) == len(other_docs) == 1
doc, = docs
other_doc, = other_docs
items = list(doc.items())
other_items = list(other_doc.items())

# Compare command name case-insensitively.
if items and other_items:
if items[0][0].lower() != other_items[0][0].lower():
return False
if not _bson_values_equal(items[0][1], other_items[0][1]):
return False
return super(CommandBase, self)._matches_docs(
[OrderedDict(items[1:])],
[OrderedDict(other_items[1:])])


class OpMsg(CommandBase):
"""An OP_MSG request the client executes on the server."""
opcode = OP_MSG
is_command = True
_flags_map = OP_MSG_FLAGS

@classmethod
def unpack(cls, msg, client, server, request_id):
"""Parse message and return an `OpMsg`.
Takes the client message as bytes, the client and server socket objects,
and the client request id.
"""
flags, = _UNPACK_UINT(msg[:4])
pos = 4
first_payload_type, = _UNPACK_BYTE(msg[pos:pos+1])
pos += 1
first_payload_size, = _UNPACK_INT(msg[pos:pos+4])
if flags != 0:
raise ValueError('OP_MSG flag must be 0 not %r' % (flags,))
if first_payload_type != 0:
raise ValueError('First OP_MSG payload type must be 0 not %r' % (
first_payload_type,))

# Parse the initial document and add the optional payload type 1.
payload_document = _bson.decode_all(msg[pos:pos+first_payload_size],
CODEC_OPTIONS)[0]
pos += first_payload_size
if len(msg) != pos:
payload_type, = _UNPACK_BYTE(msg[pos:pos+1])
pos += 1
if payload_type != 1:
raise ValueError('Second OP_MSG payload type must be 1 not %r'
% (payload_type,))
section_size, = _UNPACK_INT(msg[pos:pos+4])
pos += 4
if len(msg) != pos + section_size:
raise ValueError('More than two OP_MSG sections unsupported')
identifier, pos = _get_c_string(msg, pos)
documents = _bson.decode_all(msg[pos:], CODEC_OPTIONS)
payload_document[identifier] = documents

database = payload_document['$db']
return OpMsg(payload_document, namespace=database, flags=flags,
_client=client, request_id=request_id,
_server=server)

def __init__(self, *args, **kwargs):
super(OpMsg, self).__init__(*args, **kwargs)
if len(self._docs) > 1:
raise_args_err('OpMsg too many documents', ValueError)

@property
def slave_ok(self):
"""True if this OpMsg can read from a secondary."""
read_preference = self.doc.get('$readPreference')
return read_preference and read_preference.get('mode') != 'primary'

slave_okay = slave_ok
"""Synonym for `.slave_ok`."""

@property
def command_name(self):
"""The command name or None.
>>> OpMsg({'count': 'collection'}).command_name
'count'
>>> OpMsg('aggregate', 'collection', cursor=absent).command_name
'aggregate'
"""
if self.docs and self.docs[0]:
return list(self.docs[0])[0]

def _replies(self, *args, **kwargs):
reply = make_op_msg_reply(*args, **kwargs)
if not reply.docs:
reply.docs = [{'ok': 1}]
else:
if len(reply.docs) > 1:
raise ValueError('OP_MSG reply with multiple documents: %s'
% (reply.docs, ))
reply.doc.setdefault('ok', 1)
super(OpMsg, self)._replies(reply)


class OpQuery(Request):
"""A query (besides a command) the client executes on the server.
Expand Down Expand Up @@ -635,41 +764,8 @@ def __repr__(self):
return rep + ')'


class Command(OpQuery):
class Command(CommandBase, OpQuery):
"""A command the client executes on the server."""
is_command = True

# Check command name case-insensitively.
_non_matched_attrs = OpQuery._non_matched_attrs + ('command_name', )

@property
def command_name(self):
"""The command name or None.
>>> Command({'count': 'collection'}).command_name
'count'
>>> Command('aggregate', 'collection', cursor=absent).command_name
'aggregate'
"""
if self.docs and self.docs[0]:
return list(self.docs[0])[0]

def _matches_docs(self, docs, other_docs):
assert len(docs) == len(other_docs) == 1
doc, = docs
other_doc, = other_docs
items = list(doc.items())
other_items = list(other_doc.items())

# Compare command name case-insensitively.
if items and other_items:
if items[0][0].lower() != other_items[0][0].lower():
return False
if not _bson_values_equal(items[0][1], other_items[0][1]):
return False
return super(Command, self)._matches_docs(
[OrderedDict(items[1:])],
[OrderedDict(other_items[1:])])

def _replies(self, *args, **kwargs):
reply = make_reply(*args, **kwargs)
Expand Down Expand Up @@ -823,23 +919,12 @@ def unpack(cls, msg, client, server, request_id):
request_id=request_id, _server=server)


class OpReply(object):
class Reply(object):
"""A reply from `MockupDB` to the client."""
def __init__(self, *args, **kwargs):
self._flags = kwargs.pop('flags', 0)
self._cursor_id = kwargs.pop('cursor_id', 0)
self._starting_from = kwargs.pop('starting_from', 0)
self._docs = make_docs(*args, **kwargs)

@property
def docs(self):
"""The reply documents, if any."""
return self._docs

@docs.setter
def docs(self, docs):
self._docs = make_docs(docs)

@property
def doc(self):
"""Contents of reply.
Expand All @@ -850,6 +935,35 @@ def doc(self):
assert len(self._docs) == 1, '%s has more than one document' % self
return self._docs[0]

def __str__(self):
return docs_repr(*self._docs)

def __repr__(self):
rep = '%s(%s' % (self.__class__.__name__, self)
if self._flags:
rep += ', flags=' + '|'.join(
name for name, value in REPLY_FLAGS.items()
if self._flags & value)

return rep + ')'


class OpReply(Reply):
"""An OP_REPLY reply from `MockupDB` to the client."""
def __init__(self, *args, **kwargs):
self._cursor_id = kwargs.pop('cursor_id', 0)
self._starting_from = kwargs.pop('starting_from', 0)
super(OpReply, self).__init__(*args, **kwargs)

@property
def docs(self):
"""The reply documents, if any."""
return self._docs

@docs.setter
def docs(self, docs):
self._docs = make_docs(docs)

def update(self, *args, **kwargs):
"""Update the document. Same as ``dict().update()``.
Expand Down Expand Up @@ -879,20 +993,47 @@ def reply_bytes(self, request):
message += struct.pack("<i", OP_REPLY)
return message + data

def __str__(self):
return docs_repr(*self._docs)

def __repr__(self):
rep = '%s(%s' % (self.__class__.__name__, self)
if self._starting_from:
rep += ', starting_from=%d' % self._starting_from
class OpMsgReply(Reply):
"""A OP_MSG reply from `MockupDB` to the client."""
def __init__(self, *args, **kwargs):
super(OpMsgReply, self).__init__(*args, **kwargs)
assert len(self._docs) <= 1, 'OpMsgReply can only have one document'

if self._flags:
rep += ', flags=' + '|'.join(
name for name, value in REPLY_FLAGS.items()
if self._flags & value)
@property
def docs(self):
"""The reply documents, if any."""
return self._docs

return rep + ')'
@docs.setter
def docs(self, docs):
self._docs = make_docs(docs)
assert len(self._docs) == 1, 'OpMsgReply must have one document'

def update(self, *args, **kwargs):
"""Update the document. Same as ``dict().update()``.
>>> reply = OpMsgReply({'ismaster': True})
>>> reply.update(maxWireVersion=3)
>>> reply.doc['maxWireVersion']
3
>>> reply.update({'maxWriteBatchSize': 10, 'msg': 'isdbgrid'})
"""
self.doc.update(*args, **kwargs)

def reply_bytes(self, request):
"""Take a `Request` and return an OP_MSG message as bytes."""
flags = struct.pack("<I", self._flags)
payload_type = struct.pack("<b", 0)
payload_data = _bson.BSON.encode(self.doc)
data = b''.join([flags, payload_type, payload_data])

reply_id = random.randint(0, 1000000)
response_to = request.request_id

header = struct.pack(
"<iiii", 16 + len(data), reply_id, response_to, OP_MSG)
return header + data


absent = {'absent': 1}
Expand Down Expand Up @@ -1070,8 +1211,15 @@ def __init__(self, port=None, verbose=False,
{'ismaster': True,
'minWireVersion': min_wire_version,
'maxWireVersion': max_wire_version})
if max_wire_version >= 6:
self.autoresponds(OpMsg('ismaster'),
{'ismaster': True,
'minWireVersion': min_wire_version,
'maxWireVersion': max_wire_version})
elif auto_ismaster:
self.autoresponds(Command('ismaster'), auto_ismaster)
if max_wire_version >= 6:
self.autoresponds(OpMsg('ismaster'), auto_ismaster)

@_synchronized
def run(self):
Expand Down Expand Up @@ -1509,7 +1657,8 @@ def bind_socket(address):
raise socket.error('could not bind socket')


OPCODES = {OP_QUERY: OpQuery,
OPCODES = {OP_MSG: OpMsg,
OP_QUERY: OpQuery,
OP_INSERT: OpInsert,
OP_UPDATE: OpUpdate,
OP_DELETE: OpDelete,
Expand Down Expand Up @@ -1554,7 +1703,7 @@ def mock_server_receive(sock, length):


def make_docs(*args, **kwargs):
"""Make the documents for a `Request` or `OpReply`.
"""Make the documents for a `Request` or `Reply`.
Takes a variety of argument styles, returns a list of dicts.
Expand Down Expand Up @@ -1646,14 +1795,24 @@ def make_prototype_request(*args, **kwargs):

def make_reply(*args, **kwargs):
# Error we might raise.
if args and isinstance(args[0], OpReply):
if args and isinstance(args[0], (OpReply, OpMsgReply)):
if args[1:] or kwargs:
raise_args_err("can't interpret args")
return args[0]

return OpReply(*args, **kwargs)


def make_op_msg_reply(*args, **kwargs):
# Error we might raise.
if args and isinstance(args[0], (OpReply, OpMsgReply)):
if args[1:] or kwargs:
raise_args_err("can't interpret args")
return args[0]

return OpMsgReply(*args, **kwargs)


def unprefixed(bson_str):
rep = unicode(repr(bson_str))
if rep.startswith(u'u"') or rep.startswith(u"u'"):
Expand Down

0 comments on commit 571f6fb

Please sign in to comment.