From 10e781cfb875f811151d835de25a0d2cb2b15fe4 Mon Sep 17 00:00:00 2001 From: Prashant Mital <5883388+prashantmital@users.noreply.github.com> Date: Thu, 27 May 2021 17:16:50 -0700 Subject: [PATCH] PYTHON-1636 Support exhaust cursors in OP_MSG (#629) (cherry picked from commit d26bf933ed22f9355959ed6c957b6fcce1ead228) --- pymongo/cursor.py | 18 ++++++++++---- pymongo/message.py | 43 +++++++++++++++++++------------- pymongo/response.py | 15 +++++++++-- pymongo/server.py | 31 ++++++++++++++--------- test/test_monitoring.py | 55 +++++++++++++++++------------------------ 5 files changed, 94 insertions(+), 68 deletions(-) diff --git a/pymongo/cursor.py b/pymongo/cursor.py index 1bb58aab60..95f00fd5f8 100644 --- a/pymongo/cursor.py +++ b/pymongo/cursor.py @@ -81,17 +81,21 @@ class CursorType(object): # This has to be an old style class due to # http://bugs.jython.org/issue1057 -class _SocketManager: +class _ExhaustManager: """Used with exhaust cursors to ensure the socket is returned. """ - def __init__(self, sock, pool): + def __init__(self, sock, pool, more_to_come): self.sock = sock self.pool = pool + self.more_to_come = more_to_come self.__closed = False def __del__(self): self.close() + def update_exhaust(self, more_to_come): + self.more_to_come = more_to_come + def close(self): """Return this instance's socket to the connection pool. """ @@ -1043,10 +1047,14 @@ def __send_message(self, operation): raise self.__address = response.address - if self.__exhaust and not self.__exhaust_mgr: + if self.__exhaust: # 'response' is an ExhaustResponse. - self.__exhaust_mgr = _SocketManager(response.socket_info, - response.pool) + if not self.__exhaust_mgr: + self.__exhaust_mgr = _ExhaustManager(response.socket_info, + response.pool, + response.more_to_come) + else: + self.__exhaust_mgr.update_exhaust(response.more_to_come) cmd_name = operation.name docs = response.docs diff --git a/pymongo/message.py b/pymongo/message.py index ed96133a22..f9d0e9e2e3 100644 --- a/pymongo/message.py +++ b/pymongo/message.py @@ -269,9 +269,11 @@ def namespace(self): def use_command(self, sock_info, exhaust): use_find_cmd = False - if sock_info.max_wire_version >= 4: - if not exhaust: - use_find_cmd = True + if sock_info.max_wire_version >= 4 and not exhaust: + use_find_cmd = True + elif sock_info.max_wire_version >= 8: + # OP_MSG supports exhaust on MongoDB 4.2+ + use_find_cmd = True elif not self.read_concern.ok_for_legacy: raise ConfigurationError( 'read concern level of %s is not valid ' @@ -398,8 +400,15 @@ def namespace(self): return _UJOIN % (self.db, self.coll) def use_command(self, sock_info, exhaust): + use_cmd = False + if sock_info.max_wire_version >= 4 and not exhaust: + use_cmd = True + elif sock_info.max_wire_version >= 8: + # OP_MSG supports exhaust on MongoDB 4.2+ + use_cmd = True + sock_info.validate_session(self.client, self.session) - return sock_info.max_wire_version >= 4 and not exhaust + return use_cmd def as_command(self, sock_info): """Return a getMore command document for this query.""" @@ -433,8 +442,12 @@ def get_message(self, dummy0, sock_info, use_cmd=False): if use_cmd: spec = self.as_command(sock_info)[0] if sock_info.op_msg_enabled: + if self.exhaust_mgr: + flags = _OpMsg.EXHAUST_ALLOWED + else: + flags = 0 request_id, msg, size, _ = _op_msg( - 0, spec, self.db, None, + flags, spec, self.db, None, False, False, self.codec_options, ctx=sock_info.compression_context) return request_id, msg, size @@ -448,27 +461,23 @@ class _RawBatchQuery(_Query): def use_command(self, socket_info, exhaust): # Compatibility checks. super(_RawBatchQuery, self).use_command(socket_info, exhaust) - # Use OP_MSG when available. - if socket_info.op_msg_enabled and not exhaust: + if socket_info.max_wire_version >= 8: + # MongoDB 4.2+ supports exhaust over OP_MSG + return True + elif socket_info.op_msg_enabled and not exhaust: return True return False - def get_message(self, set_slave_ok, sock_info, use_cmd=False): - return super(_RawBatchQuery, self).get_message( - set_slave_ok, sock_info, use_cmd) - class _RawBatchGetMore(_GetMore): def use_command(self, socket_info, exhaust): - # Use OP_MSG when available. - if socket_info.op_msg_enabled and not exhaust: + if socket_info.max_wire_version >= 8: + # MongoDB 4.2+ supports exhaust over OP_MSG + return True + elif socket_info.op_msg_enabled and not exhaust: return True return False - def get_message(self, set_slave_ok, sock_info, use_cmd=False): - return super(_RawBatchGetMore, self).get_message( - set_slave_ok, sock_info, use_cmd) - class _CursorAddress(tuple): """The server address (host, port) of a cursor, with namespace property.""" diff --git a/pymongo/response.py b/pymongo/response.py index 56cc532f57..474e2c4d3b 100644 --- a/pymongo/response.py +++ b/pymongo/response.py @@ -67,11 +67,12 @@ def docs(self): """The decoded document(s).""" return self._docs + class ExhaustResponse(Response): - __slots__ = ('_socket_info', '_pool') + __slots__ = ('_socket_info', '_pool', '_more_to_come') def __init__(self, data, address, socket_info, pool, request_id, duration, - from_command, docs): + from_command, docs, more_to_come): """Represent a response to an exhaust cursor's initial query. :Parameters: @@ -82,6 +83,9 @@ def __init__(self, data, address, socket_info, pool, request_id, duration, - `request_id`: The request id of this operation. - `duration`: The duration of the operation. - `from_command`: If the response is the result of a db command. + - `docs`: List of documents. + - `more_to_come`: Bool indicating whether cursor is ready to be + exhausted. """ super(ExhaustResponse, self).__init__(data, address, @@ -90,6 +94,7 @@ def __init__(self, data, address, socket_info, pool, request_id, duration, from_command, docs) self._socket_info = socket_info self._pool = pool + self._more_to_come = more_to_come @property def socket_info(self): @@ -105,3 +110,9 @@ def socket_info(self): def pool(self): """The Pool from which the SocketInfo came.""" return self._pool + + @property + def more_to_come(self): + """If true, server is ready to send batches on the socket until the + result set is exhausted or there is an error.""" + return self._more_to_come diff --git a/pymongo/server.py b/pymongo/server.py index 6a74b23514..84887b5dbc 100644 --- a/pymongo/server.py +++ b/pymongo/server.py @@ -20,7 +20,7 @@ from pymongo.errors import NotMasterError, OperationFailure from pymongo.helpers import _check_command_response -from pymongo.message import _convert_exception +from pymongo.message import _convert_exception, _OpMsg from pymongo.response import Response, ExhaustResponse from pymongo.server_type import SERVER_TYPE @@ -95,16 +95,15 @@ def run_operation_with_response( if publish: start = datetime.now() - send_message = not operation.exhaust_mgr - - if send_message: - use_cmd = operation.use_command(sock_info, exhaust) + use_cmd = operation.use_command(sock_info, exhaust) + more_to_come = (operation.exhaust_mgr + and operation.exhaust_mgr.more_to_come) + if more_to_come: + request_id = 0 + else: message = operation.get_message( set_slave_okay, sock_info, use_cmd) request_id, data, max_doc_size = self._split_message(message) - else: - use_cmd = False - request_id = 0 if publish: cmd, dbn = operation.as_command(sock_info) @@ -113,11 +112,11 @@ def run_operation_with_response( start = datetime.now() try: - if send_message: + if more_to_come: + reply = sock_info.receive_message(None) + else: sock_info.send_message(data, max_doc_size) reply = sock_info.receive_message(request_id) - else: - reply = sock_info.receive_message(None) # Unpack and check for command errors. if use_cmd: @@ -176,6 +175,13 @@ def run_operation_with_response( decrypted, operation.codec_options, user_fields) if exhaust: + if isinstance(reply, _OpMsg): + # In OP_MSG, the server keeps sending only if the + # more_to_come flag is set. + more_to_come = reply.more_to_come + else: + # In OP_REPLY, the server keeps sending until cursor_id is 0. + more_to_come = bool(reply.cursor_id) response = ExhaustResponse( data=reply, address=self._description.address, @@ -184,7 +190,8 @@ def run_operation_with_response( duration=duration, request_id=request_id, from_command=use_cmd, - docs=docs) + docs=docs, + more_to_come=more_to_come) else: response = Response( data=reply, diff --git a/test/test_monitoring.py b/test/test_monitoring.py index 6b33211159..241bd1d3dc 100644 --- a/test/test_monitoring.py +++ b/test/test_monitoring.py @@ -458,7 +458,7 @@ def test_not_master_error(self): @client_context.require_no_mongos def test_exhaust(self): self.client.pymongo_test.test.drop() - self.client.pymongo_test.test.insert_many([{} for _ in range(10)]) + self.client.pymongo_test.test.insert_many([{} for _ in range(11)]) self.listener.results.clear() cursor = self.client.pymongo_test.test.find( projection={'_id': False}, @@ -472,12 +472,10 @@ def test_exhaust(self): self.assertEqual(0, len(results['failed'])) self.assertTrue( isinstance(started, monitoring.CommandStartedEvent)) - self.assertEqualCommand( - SON([('find', 'test'), - ('filter', {}), - ('projection', {'_id': False}), - ('batchSize', 5)]), - started.command) + self.assertEqualCommand(SON([('find', 'test'), + ('filter', {}), + ('projection', {'_id': False}), + ('batchSize', 5)]), started.command) self.assertEqual('find', started.command_name) self.assertEqual(cursor.address, started.connection_id) self.assertEqual('pymongo_test', started.database_name) @@ -498,32 +496,25 @@ def test_exhaust(self): self.listener.results.clear() tuple(cursor) results = self.listener.results - started = results['started'][0] - succeeded = results['succeeded'][0] self.assertEqual(0, len(results['failed'])) - self.assertTrue( - isinstance(started, monitoring.CommandStartedEvent)) - self.assertEqualCommand( - SON([('getMore', cursor_id), - ('collection', 'test'), - ('batchSize', 5)]), - started.command) - self.assertEqual('getMore', started.command_name) - self.assertEqual(cursor.address, started.connection_id) - self.assertEqual('pymongo_test', started.database_name) - self.assertTrue(isinstance(started.request_id, int)) - self.assertTrue( - isinstance(succeeded, monitoring.CommandSucceededEvent)) - self.assertTrue(isinstance(succeeded.duration_micros, int)) - self.assertEqual('getMore', succeeded.command_name) - self.assertTrue(isinstance(succeeded.request_id, int)) - self.assertEqual(cursor.address, succeeded.connection_id) - expected_result = { - 'cursor': {'id': 0, - 'ns': 'pymongo_test.test', - 'nextBatch': [{} for _ in range(5)]}, - 'ok': 1} - self.assertEqualReply(expected_result, succeeded.reply) + for event in results['started']: + self.assertTrue(isinstance(event, monitoring.CommandStartedEvent)) + self.assertEqualCommand(SON([('getMore', cursor_id), + ('collection', 'test'), + ('batchSize', 5)]), event.command) + self.assertEqual('getMore', event.command_name) + self.assertEqual(cursor.address, event.connection_id) + self.assertEqual('pymongo_test', event.database_name) + self.assertTrue(isinstance(event.request_id, int)) + for event in results['succeeded']: + self.assertTrue( + isinstance(event, monitoring.CommandSucceededEvent)) + self.assertTrue(isinstance(event.duration_micros, int)) + self.assertEqual('getMore', event.command_name) + self.assertTrue(isinstance(event.request_id, int)) + self.assertEqual(cursor.address, event.connection_id) + # Last getMore receives a response with cursor id 0. + self.assertEqual(0, results['succeeded'][-1].reply['cursor']['id']) def test_kill_cursors(self): with client_knobs(kill_cursor_frequency=0.01):