diff --git a/pymongo/change_stream.py b/pymongo/change_stream.py index 80820dff91..0edf513a3c 100644 --- a/pymongo/change_stream.py +++ b/pymongo/change_stream.py @@ -68,6 +68,19 @@ from pymongo.mongo_client import MongoClient +def _resumable(exc: PyMongoError) -> bool: + """Return True if given a resumable change stream error.""" + if isinstance(exc, (ConnectionFailure, CursorNotFound)): + return True + if isinstance(exc, OperationFailure): + if exc._max_wire_version is None: + return False + return ( + exc._max_wire_version >= 9 and exc.has_error_label("ResumableChangeStreamError") + ) or (exc._max_wire_version < 9 and exc.code in _RESUMABLE_GETMORE_ERRORS) + return False + + class ChangeStream(Generic[_DocumentType]): """The internal abstract base class for change stream cursors. @@ -343,20 +356,21 @@ def try_next(self) -> Optional[_DocumentType]: # Attempt to get the next change with at most one getMore and at most # one resume attempt. try: - change = self._cursor._try_next(True) - except (ConnectionFailure, CursorNotFound): - self._resume() - change = self._cursor._try_next(False) - except OperationFailure as exc: - if exc._max_wire_version is None: - raise - is_resumable = ( - exc._max_wire_version >= 9 and exc.has_error_label("ResumableChangeStreamError") - ) or (exc._max_wire_version < 9 and exc.code in _RESUMABLE_GETMORE_ERRORS) - if not is_resumable: - raise - self._resume() - change = self._cursor._try_next(False) + try: + change = self._cursor._try_next(True) + except PyMongoError as exc: + if not _resumable(exc): + raise + self._resume() + change = self._cursor._try_next(False) + except PyMongoError as exc: + # Close the stream after a fatal error. + if not _resumable(exc) and not exc.timeout: + self.close() + raise + except Exception: + self.close() + raise # Check if the cursor was invalidated. if not self._cursor.alive: diff --git a/test/test_change_stream.py b/test/test_change_stream.py index b5b260086d..a8b793333e 100644 --- a/test/test_change_stream.py +++ b/test/test_change_stream.py @@ -486,7 +486,7 @@ def _get_expected_resume_token(self, stream, listener, previous_change=None): return response["cursor"]["postBatchResumeToken"] @no_type_check - def _test_raises_error_on_missing_id(self, expected_exception, expected_exception2): + def _test_raises_error_on_missing_id(self, expected_exception): """ChangeStream will raise an exception if the server response is missing the resume token. """ @@ -494,7 +494,8 @@ def _test_raises_error_on_missing_id(self, expected_exception, expected_exceptio self.watched_collection().insert_one({}) with self.assertRaises(expected_exception): next(change_stream) - with self.assertRaises(expected_exception2): + # The cursor should now be closed. + with self.assertRaises(StopIteration): next(change_stream) @no_type_check @@ -526,14 +527,14 @@ def test_update_resume_token_legacy(self): # Prose test no. 2 @client_context.require_version_min(4, 1, 8) def test_raises_error_on_missing_id_418plus(self): - # Server returns an error on 4.1.8+, subsequent next() resumes and gets the same error. - self._test_raises_error_on_missing_id(OperationFailure, OperationFailure) + # Server returns an error on 4.1.8+ + self._test_raises_error_on_missing_id(OperationFailure) # Prose test no. 2 @client_context.require_version_max(4, 1, 8) def test_raises_error_on_missing_id_418minus(self): - # PyMongo raises an error, closes the cursor, subsequent next() raises StopIteration. - self._test_raises_error_on_missing_id(InvalidOperation, StopIteration) + # PyMongo raises an error + self._test_raises_error_on_missing_id(InvalidOperation) # Prose test no. 3 @no_type_check