diff --git a/doc/common-issues.rst b/doc/common-issues.rst index f0c9716689..3d2d06a5a7 100644 --- a/doc/common-issues.rst +++ b/doc/common-issues.rst @@ -78,7 +78,7 @@ will receive the following error:: File "/Library/Python/2.7/site-packages/pymongo/collection.py", line 1560, in count return self._count(cmd, collation, session) File "/Library/Python/2.7/site-packages/pymongo/collection.py", line 1504, in _count - with self._socket_for_reads() as (sock_info, slave_ok): + with self._socket_for_reads() as (connection, slave_ok): File "/System/Library/Frameworks/Python.framework/Versions/2.7/lib/python2.7/contextlib.py", line 17, in __enter__ return self.gen.next() File "/Library/Python/2.7/site-packages/pymongo/mongo_client.py", line 982, in _socket_for_reads diff --git a/pymongo/aggregation.py b/pymongo/aggregation.py index cd86564fef..580bb1b130 100644 --- a/pymongo/aggregation.py +++ b/pymongo/aggregation.py @@ -29,7 +29,7 @@ from pymongo.collection import Collection from pymongo.command_cursor import CommandCursor from pymongo.database import Database - from pymongo.pool import SocketInfo + from pymongo.pool import Connection from pymongo.read_preferences import _ServerMode from pymongo.server import Server from pymongo.typings import _Pipeline @@ -52,7 +52,7 @@ def __init__( explicit_session: bool, let: Optional[Mapping[str, Any]] = None, user_fields: Optional[MutableMapping[str, Any]] = None, - result_processor: Optional[Callable[[Mapping[str, Any], SocketInfo], None]] = None, + result_processor: Optional[Callable[[Mapping[str, Any], Connection], None]] = None, comment: Any = None, ) -> None: if "explain" in options: @@ -134,7 +134,7 @@ def get_cursor( self, session: ClientSession, server: Server, - sock_info: SocketInfo, + conn: Connection, read_preference: _ServerMode, ) -> CommandCursor: # Serialize command. 
@@ -146,7 +146,7 @@ def get_cursor( # - server version is >= 4.2 or # - server version is >= 3.2 and pipeline doesn't use $out if ("readConcern" not in cmd) and ( - not self._performs_write or (sock_info.max_wire_version >= 8) + not self._performs_write or (conn.max_wire_version >= 8) ): read_concern = self._target.read_concern else: @@ -161,7 +161,7 @@ def get_cursor( write_concern = None # Run command. - result = sock_info.command( + result = conn.command( self._database.name, cmd, read_preference, @@ -176,7 +176,7 @@ def get_cursor( ) if self._result_processor: - self._result_processor(result, sock_info) + self._result_processor(result, conn) # Extract cursor from result or mock/fake one if necessary. if "cursor" in result: @@ -193,14 +193,14 @@ def get_cursor( cmd_cursor = self._cursor_class( self._cursor_collection(cursor), cursor, - sock_info.address, + conn.address, batch_size=self._batch_size or 0, max_await_time_ms=self._max_await_time_ms, session=session, explicit_session=self._explicit_session, comment=self._options.get("comment"), ) - cmd_cursor._maybe_pin_connection(sock_info) + cmd_cursor._maybe_pin_connection(conn) return cmd_cursor diff --git a/pymongo/auth.py b/pymongo/auth.py index b41e885420..063df41e40 100644 --- a/pymongo/auth.py +++ b/pymongo/auth.py @@ -35,7 +35,7 @@ if TYPE_CHECKING: from pymongo.hello import Hello - from pymongo.pool import SocketInfo + from pymongo.pool import Connection HAVE_KERBEROS = True _USE_PRINCIPAL = False @@ -220,9 +220,7 @@ def _authenticate_scram_start( return nonce, first_bare, cmd -def _authenticate_scram( - credentials: MongoCredential, sock_info: SocketInfo, mechanism: str -) -> None: +def _authenticate_scram(credentials: MongoCredential, conn: Connection, mechanism: str) -> None: """Authenticate using SCRAM.""" username = credentials.username if mechanism == "SCRAM-SHA-256": @@ -239,13 +237,13 @@ def _authenticate_scram( # Make local _hmac = hmac.HMAC - ctx = sock_info.auth_ctx + ctx = conn.auth_ctx if ctx 
and ctx.speculate_succeeded(): nonce, first_bare = ctx.scram_data res = ctx.speculative_authenticate else: nonce, first_bare, cmd = _authenticate_scram_start(credentials, mechanism) - res = sock_info.command(source, cmd) + res = conn.command(source, cmd) server_first = res["payload"] parsed = _parse_scram_response(server_first) @@ -285,7 +283,7 @@ def _authenticate_scram( ("payload", Binary(client_final)), ] ) - res = sock_info.command(source, cmd) + res = conn.command(source, cmd) parsed = _parse_scram_response(res["payload"]) if not hmac.compare_digest(parsed[b"v"], server_sig): @@ -301,7 +299,7 @@ def _authenticate_scram( ("payload", Binary(b"")), ] ) - res = sock_info.command(source, cmd) + res = conn.command(source, cmd) if not res["done"]: raise OperationFailure("SASL conversation failed to complete.") @@ -345,7 +343,7 @@ def _canonicalize_hostname(hostname: str) -> str: return name[0].lower() -def _authenticate_gssapi(credentials: MongoCredential, sock_info: SocketInfo) -> None: +def _authenticate_gssapi(credentials: MongoCredential, conn: Connection) -> None: """Authenticate using GSSAPI.""" if not HAVE_KERBEROS: raise ConfigurationError( @@ -358,7 +356,7 @@ def _authenticate_gssapi(credentials: MongoCredential, sock_info: SocketInfo) -> props = credentials.mechanism_properties # Starting here and continuing through the while loop below - establish # the security context. See RFC 4752, Section 3.1, first paragraph. 
- host = sock_info.address[0] + host = conn.address[0] if props.canonicalize_host_name: host = _canonicalize_hostname(host) service = props.service_name + "@" + host @@ -413,7 +411,7 @@ def _authenticate_gssapi(credentials: MongoCredential, sock_info: SocketInfo) -> ("autoAuthorize", 1), ] ) - response = sock_info.command("$external", cmd) + response = conn.command("$external", cmd) # Limit how many times we loop to catch protocol / library issues for _ in range(10): @@ -430,7 +428,7 @@ def _authenticate_gssapi(credentials: MongoCredential, sock_info: SocketInfo) -> ("payload", payload), ] ) - response = sock_info.command("$external", cmd) + response = conn.command("$external", cmd) if result == kerberos.AUTH_GSS_COMPLETE: break @@ -453,7 +451,7 @@ def _authenticate_gssapi(credentials: MongoCredential, sock_info: SocketInfo) -> ("payload", payload), ] ) - sock_info.command("$external", cmd) + conn.command("$external", cmd) finally: kerberos.authGSSClientClean(ctx) @@ -462,7 +460,7 @@ def _authenticate_gssapi(credentials: MongoCredential, sock_info: SocketInfo) -> raise OperationFailure(str(exc)) -def _authenticate_plain(credentials: MongoCredential, sock_info: SocketInfo) -> None: +def _authenticate_plain(credentials: MongoCredential, conn: Connection) -> None: """Authenticate using SASL PLAIN (RFC 4616)""" source = credentials.source username = credentials.username @@ -476,52 +474,50 @@ def _authenticate_plain(credentials: MongoCredential, sock_info: SocketInfo) -> ("autoAuthorize", 1), ] ) - sock_info.command(source, cmd) + conn.command(source, cmd) -def _authenticate_x509(credentials: MongoCredential, sock_info: SocketInfo) -> None: +def _authenticate_x509(credentials: MongoCredential, conn: Connection) -> None: """Authenticate using MONGODB-X509.""" - ctx = sock_info.auth_ctx + ctx = conn.auth_ctx if ctx and ctx.speculate_succeeded(): # MONGODB-X509 is done after the speculative auth step. 
return - cmd = _X509Context(credentials, sock_info.address).speculate_command() - sock_info.command("$external", cmd) + cmd = _X509Context(credentials, conn.address).speculate_command() + conn.command("$external", cmd) -def _authenticate_mongo_cr(credentials: MongoCredential, sock_info: SocketInfo) -> None: +def _authenticate_mongo_cr(credentials: MongoCredential, conn: Connection) -> None: """Authenticate using MONGODB-CR.""" source = credentials.source username = credentials.username password = credentials.password # Get a nonce - response = sock_info.command(source, {"getnonce": 1}) + response = conn.command(source, {"getnonce": 1}) nonce = response["nonce"] key = _auth_key(nonce, username, password) # Actually authenticate query = SON([("authenticate", 1), ("user", username), ("nonce", nonce), ("key", key)]) - sock_info.command(source, query) + conn.command(source, query) -def _authenticate_default(credentials: MongoCredential, sock_info: SocketInfo) -> None: - if sock_info.max_wire_version >= 7: - if sock_info.negotiated_mechs: - mechs = sock_info.negotiated_mechs +def _authenticate_default(credentials: MongoCredential, conn: Connection) -> None: + if conn.max_wire_version >= 7: + if conn.negotiated_mechs: + mechs = conn.negotiated_mechs else: source = credentials.source - cmd = sock_info.hello_cmd() + cmd = conn.hello_cmd() cmd["saslSupportedMechs"] = source + "." 
+ credentials.username - mechs = sock_info.command(source, cmd, publish_events=False).get( - "saslSupportedMechs", [] - ) + mechs = conn.command(source, cmd, publish_events=False).get("saslSupportedMechs", []) if "SCRAM-SHA-256" in mechs: - return _authenticate_scram(credentials, sock_info, "SCRAM-SHA-256") + return _authenticate_scram(credentials, conn, "SCRAM-SHA-256") else: - return _authenticate_scram(credentials, sock_info, "SCRAM-SHA-1") + return _authenticate_scram(credentials, conn, "SCRAM-SHA-1") else: - return _authenticate_scram(credentials, sock_info, "SCRAM-SHA-1") + return _authenticate_scram(credentials, conn, "SCRAM-SHA-1") _AUTH_MAP: Mapping[str, Callable] = { @@ -606,12 +602,12 @@ def speculate_command(self) -> Optional[MutableMapping[str, Any]]: def authenticate( - credentials: MongoCredential, sock_info: SocketInfo, reauthenticate: bool = False + credentials: MongoCredential, conn: Connection, reauthenticate: bool = False ) -> None: - """Authenticate sock_info.""" + """Authenticate connection.""" mechanism = credentials.mechanism auth_func = _AUTH_MAP[mechanism] if mechanism == "MONGODB-OIDC": - _authenticate_oidc(credentials, sock_info, reauthenticate) + _authenticate_oidc(credentials, conn, reauthenticate) else: - auth_func(credentials, sock_info) + auth_func(credentials, conn) diff --git a/pymongo/auth_aws.py b/pymongo/auth_aws.py index edefd3c930..2f7dcb857f 100644 --- a/pymongo/auth_aws.py +++ b/pymongo/auth_aws.py @@ -49,7 +49,7 @@ def set_cached_credentials(creds): if TYPE_CHECKING: from bson.typings import _ReadableBuffer from pymongo.auth import MongoCredential - from pymongo.pool import SocketInfo + from pymongo.pool import Connection class _AwsSaslContext(AwsSaslContext): # type: ignore @@ -67,7 +67,7 @@ def bson_decode(self, data: _ReadableBuffer) -> Mapping[str, Any]: return bson.decode(data) -def _authenticate_aws(credentials: MongoCredential, sock_info: SocketInfo) -> None: +def _authenticate_aws(credentials: MongoCredential, 
conn: Connection) -> None: """Authenticate using MONGODB-AWS.""" if not _HAVE_MONGODB_AWS: raise ConfigurationError( @@ -75,7 +75,7 @@ def _authenticate_aws(credentials: MongoCredential, sock_info: SocketInfo) -> No "install with: python -m pip install 'pymongo[aws]'" ) - if sock_info.max_wire_version < 9: + if conn.max_wire_version < 9: raise ConfigurationError("MONGODB-AWS authentication requires MongoDB version 4.4 or later") try: @@ -90,7 +90,7 @@ def _authenticate_aws(credentials: MongoCredential, sock_info: SocketInfo) -> No client_first = SON( [("saslStart", 1), ("mechanism", "MONGODB-AWS"), ("payload", client_payload)] ) - server_first = sock_info.command("$external", client_first) + server_first = conn.command("$external", client_first) res = server_first # Limit how many times we loop to catch protocol / library issues for _ in range(10): @@ -102,7 +102,7 @@ def _authenticate_aws(credentials: MongoCredential, sock_info: SocketInfo) -> No ("payload", client_payload), ] ) - res = sock_info.command("$external", cmd) + res = conn.command("$external", cmd) if res["done"]: # SASL complete. 
break diff --git a/pymongo/auth_oidc.py b/pymongo/auth_oidc.py index 62648b2c0a..e5d0afeb89 100644 --- a/pymongo/auth_oidc.py +++ b/pymongo/auth_oidc.py @@ -29,7 +29,7 @@ if TYPE_CHECKING: from pymongo.auth import MongoCredential - from pymongo.pool import SocketInfo + from pymongo.pool import Connection @dataclass @@ -242,25 +242,23 @@ def clear(self) -> None: self.idp_resp = None self.token_exp_utc = None - def run_command( - self, sock_info: SocketInfo, cmd: Mapping[str, Any] - ) -> Optional[Mapping[str, Any]]: + def run_command(self, conn: Connection, cmd: Mapping[str, Any]) -> Optional[Mapping[str, Any]]: try: - return sock_info.command("$external", cmd, no_reauth=True) # type: ignore[call-arg] + return conn.command("$external", cmd, no_reauth=True) # type: ignore[call-arg] except OperationFailure as exc: self.clear() if exc.code == _REAUTHENTICATION_REQUIRED_CODE: if "jwt" in bson.decode(cmd["payload"]): if self.idp_info_gen_id > self.reauth_gen_id: raise - return self.authenticate(sock_info, reauthenticate=True) + return self.authenticate(conn, reauthenticate=True) raise def authenticate( - self, sock_info: SocketInfo, reauthenticate: bool = False + self, conn: Connection, reauthenticate: bool = False ) -> Optional[Mapping[str, Any]]: if reauthenticate: - prev_id = getattr(sock_info, "oidc_token_gen_id", None) + prev_id = getattr(conn, "oidc_token_gen_id", None) # Check if we've already changed tokens. 
if prev_id == self.token_gen_id: self.reauth_gen_id = self.idp_info_gen_id @@ -268,7 +266,7 @@ def authenticate( if not self.properties.refresh_token_callback: self.clear() - ctx = sock_info.auth_ctx + ctx = conn.auth_ctx cmd = None if ctx and ctx.speculate_succeeded(): @@ -276,10 +274,10 @@ def authenticate( else: cmd = self.auth_start_cmd() assert cmd is not None - resp = self.run_command(sock_info, cmd) + resp = self.run_command(conn, cmd) if resp["done"]: - sock_info.oidc_token_gen_id = self.token_gen_id + conn.oidc_token_gen_id = self.token_gen_id return None server_resp: Dict = bson.decode(resp["payload"]) @@ -289,7 +287,7 @@ def authenticate( conversation_id = resp["conversationId"] token = self.get_current_token() - sock_info.oidc_token_gen_id = self.token_gen_id + conn.oidc_token_gen_id = self.token_gen_id bin_payload = Binary(bson.encode({"jwt": token})) cmd = SON( [ @@ -298,7 +296,7 @@ def authenticate( ("payload", bin_payload), ] ) - resp = self.run_command(sock_info, cmd) + resp = self.run_command(conn, cmd) if not resp["done"]: self.clear() raise OperationFailure("SASL conversation failed to complete.") @@ -306,8 +304,8 @@ def authenticate( def _authenticate_oidc( - credentials: MongoCredential, sock_info: SocketInfo, reauthenticate: bool + credentials: MongoCredential, conn: Connection, reauthenticate: bool ) -> Optional[Mapping[str, Any]]: """Authenticate using MONGODB-OIDC.""" - authenticator = _get_authenticator(credentials, sock_info.address) - return authenticator.authenticate(sock_info, reauthenticate=reauthenticate) + authenticator = _get_authenticator(credentials, conn.address) + return authenticator.authenticate(conn, reauthenticate=reauthenticate) diff --git a/pymongo/bulk.py b/pymongo/bulk.py index f7fdc805bf..a0ada2fb36 100644 --- a/pymongo/bulk.py +++ b/pymongo/bulk.py @@ -65,7 +65,7 @@ if TYPE_CHECKING: from pymongo.collection import Collection - from pymongo.pool import SocketInfo + from pymongo.pool import Connection from 
pymongo.typings import _DocumentOut, _DocumentType, _Pipeline _DELETE_ALL: int = 0 @@ -311,7 +311,7 @@ def _execute_command( generator: Iterator[Any], write_concern: WriteConcern, session: Optional[ClientSession], - sock_info: SocketInfo, + conn: Connection, op_id: int, retryable: bool, full_result: MutableMapping[str, Any], @@ -326,9 +326,9 @@ def _execute_command( self.next_run = None run = self.current_run - # sock_info.command validates the session, but we use - # sock_info.write_command. - sock_info.validate_session(client, session) + # Connection.command validates the session, but we use + # Connection.write_command + conn.validate_session(client, session) last_run = False while run: @@ -341,7 +341,7 @@ def _execute_command( bwc = self.bulk_ctx_class( db_name, cmd_name, - sock_info, + conn, op_id, listeners, session, @@ -369,11 +369,11 @@ def _execute_command( if retryable and not self.started_retryable_write: session._start_retryable_write() self.started_retryable_write = True - session._apply_to(cmd, retryable, ReadPreference.PRIMARY, sock_info) - sock_info.send_cluster_time(cmd, session, client) - sock_info.add_server_api(cmd) + session._apply_to(cmd, retryable, ReadPreference.PRIMARY, conn) + conn.send_cluster_time(cmd, session, client) + conn.add_server_api(cmd) # CSOT: apply timeout before encoding the command. - sock_info.apply_timeout(client, cmd) + conn.apply_timeout(client, cmd) ops = islice(run.ops, run.idx_offset, None) # Run as many ops as possible in one command. 
@@ -430,13 +430,13 @@ def execute_command( op_id = _randint() def retryable_bulk( - session: Optional[ClientSession], sock_info: SocketInfo, retryable: bool + session: Optional[ClientSession], conn: Connection, retryable: bool ) -> None: self._execute_command( generator, write_concern, session, - sock_info, + conn, op_id, retryable, full_result, @@ -450,7 +450,7 @@ def retryable_bulk( _raise_bulk_write_error(full_result) return full_result - def execute_op_msg_no_results(self, sock_info: SocketInfo, generator: Iterator[Any]) -> None: + def execute_op_msg_no_results(self, conn: Connection, generator: Iterator[Any]) -> None: """Execute write commands with OP_MSG and w=0 writeConcern, unordered.""" db_name = self.collection.database.name client = self.collection.database.client @@ -466,7 +466,7 @@ def execute_op_msg_no_results(self, sock_info: SocketInfo, generator: Iterator[A bwc = self.bulk_ctx_class( db_name, cmd_name, - sock_info, + conn, op_id, listeners, None, @@ -482,7 +482,7 @@ def execute_op_msg_no_results(self, sock_info: SocketInfo, generator: Iterator[A ("writeConcern", {"w": 0}), ] ) - sock_info.add_server_api(cmd) + conn.add_server_api(cmd) ops = islice(run.ops, run.idx_offset, None) # Run as many ops as possible. 
to_send = bwc.execute_unack(cmd, ops, client) @@ -491,7 +491,7 @@ def execute_op_msg_no_results(self, sock_info: SocketInfo, generator: Iterator[A def execute_command_no_results( self, - sock_info: SocketInfo, + conn: Connection, generator: Iterator[Any], write_concern: WriteConcern, ) -> None: @@ -516,7 +516,7 @@ def execute_command_no_results( generator, initial_write_concern, None, - sock_info, + conn, op_id, False, full_result, @@ -527,7 +527,7 @@ def execute_command_no_results( def execute_no_results( self, - sock_info: SocketInfo, + conn: Connection, generator: Iterator[Any], write_concern: WriteConcern, ) -> None: @@ -538,11 +538,11 @@ def execute_no_results( raise ConfigurationError("arrayFilters is unsupported for unacknowledged writes.") # Guard against unsupported unacknowledged writes. unack = write_concern and not write_concern.acknowledged - if unack and self.uses_hint_delete and sock_info.max_wire_version < 9: + if unack and self.uses_hint_delete and conn.max_wire_version < 9: raise ConfigurationError( "Must be connected to MongoDB 4.4+ to use hint on unacknowledged delete commands." ) - if unack and self.uses_hint_update and sock_info.max_wire_version < 8: + if unack and self.uses_hint_update and conn.max_wire_version < 8: raise ConfigurationError( "Must be connected to MongoDB 4.2+ to use hint on unacknowledged update commands." 
) @@ -553,8 +553,8 @@ def execute_no_results( ) if self.ordered: - return self.execute_command_no_results(sock_info, generator, write_concern) - return self.execute_op_msg_no_results(sock_info, generator) + return self.execute_command_no_results(conn, generator, write_concern) + return self.execute_op_msg_no_results(conn, generator) def execute(self, write_concern: WriteConcern, session: Optional[ClientSession]) -> Any: """Execute operations.""" @@ -573,8 +573,8 @@ def execute(self, write_concern: WriteConcern, session: Optional[ClientSession]) client = self.collection.database.client if not write_concern.acknowledged: - with client._socket_for_writes(session) as sock_info: - self.execute_no_results(sock_info, generator, write_concern) + with client._conn_for_writes(session) as connection: + self.execute_no_results(connection, generator, write_concern) return None else: return self.execute_command(generator, write_concern, session) diff --git a/pymongo/change_stream.py b/pymongo/change_stream.py index 10bfd36236..a40b5b2f14 100644 --- a/pymongo/change_stream.py +++ b/pymongo/change_stream.py @@ -78,7 +78,7 @@ from pymongo.collection import Collection from pymongo.database import Database from pymongo.mongo_client import MongoClient - from pymongo.pool import SocketInfo + from pymongo.pool import Connection def _resumable(exc: PyMongoError) -> bool: @@ -213,7 +213,7 @@ def _aggregation_pipeline(self) -> List[Dict[str, Any]]: full_pipeline.extend(self._pipeline) return full_pipeline - def _process_result(self, result: Mapping[str, Any], sock_info: SocketInfo) -> None: + def _process_result(self, result: Mapping[str, Any], conn: Connection) -> None: """Callback that caches the postBatchResumeToken or startAtOperationTime from a changeStream aggregate command response containing an empty batch of change documents. 
@@ -228,7 +228,7 @@ def _process_result(self, result: Mapping[str, Any], sock_info: SocketInfo) -> N self._start_at_operation_time is None and self._uses_resume_after is False and self._uses_start_after is False - and sock_info.max_wire_version >= 7 + and conn.max_wire_version >= 7 ): self._start_at_operation_time = result.get("operationTime") # PYTHON-2181: informative error on missing operationTime. diff --git a/pymongo/client_session.py b/pymongo/client_session.py index a43982e43d..b4fe5d5da3 100644 --- a/pymongo/client_session.py +++ b/pymongo/client_session.py @@ -160,7 +160,7 @@ from bson.son import SON from bson.timestamp import Timestamp from pymongo import _csot -from pymongo.cursor import _SocketManager +from pymongo.cursor import _ConnectionManager from pymongo.errors import ( ConfigurationError, ConnectionFailure, @@ -178,7 +178,7 @@ if TYPE_CHECKING: from types import TracebackType - from pymongo.pool import SocketInfo + from pymongo.pool import Connection from pymongo.server import Server @@ -400,7 +400,7 @@ def __init__(self, opts: Optional[TransactionOptions], client: MongoClient): self.state = _TxnState.NONE self.sharded = False self.pinned_address: Optional[Tuple[str, Optional[int]]] = None - self.sock_mgr: Optional[_SocketManager] = None + self.conn_mgr: Optional[_ConnectionManager] = None self.recovery_token = None self.attempt = 0 self.client = client @@ -412,23 +412,23 @@ def starting(self) -> bool: return self.state == _TxnState.STARTING @property - def pinned_conn(self) -> Optional[SocketInfo]: - if self.active() and self.sock_mgr: - return self.sock_mgr.sock + def pinned_conn(self) -> Optional[Connection]: + if self.active() and self.conn_mgr: + return self.conn_mgr.conn return None - def pin(self, server: Server, sock_info: SocketInfo) -> None: + def pin(self, server: Server, conn: Connection) -> None: self.sharded = True self.pinned_address = server.description.address if server.description.server_type == SERVER_TYPE.LoadBalancer: - 
sock_info.pin_txn() - self.sock_mgr = _SocketManager(sock_info, False) + conn.pin_txn() + self.conn_mgr = _ConnectionManager(conn, False) def unpin(self) -> None: self.pinned_address = None - if self.sock_mgr: - self.sock_mgr.close() - self.sock_mgr = None + if self.conn_mgr: + self.conn_mgr.close() + self.conn_mgr = None def reset(self) -> None: self.unpin() @@ -438,11 +438,11 @@ def reset(self) -> None: self.attempt = 0 def __del__(self) -> None: - if self.sock_mgr: + if self.conn_mgr: # Reuse the cursor closing machinery to return the socket to the # pool soon. - self.client._close_cursor_soon(0, None, self.sock_mgr) - self.sock_mgr = None + self.client._close_cursor_soon(0, None, self.conn_mgr) + self.conn_mgr = None def _reraise_with_unknown_commit(exc: Any) -> NoReturn: @@ -839,12 +839,12 @@ def _finish_transaction_with_retry(self, command_name: str) -> Dict[str, Any]: - `command_name`: Either "commitTransaction" or "abortTransaction". """ - def func(session: ClientSession, sock_info: SocketInfo, retryable: bool) -> Dict[str, Any]: - return self._finish_transaction(sock_info, command_name) + def func(session: ClientSession, conn: Connection, retryable: bool) -> Dict[str, Any]: + return self._finish_transaction(conn, command_name) return self._client._retry_internal(True, func, self, None) - def _finish_transaction(self, sock_info: SocketInfo, command_name: str) -> Dict[str, Any]: + def _finish_transaction(self, conn: Connection, command_name: str) -> Dict[str, Any]: self._transaction.attempt += 1 opts = self._transaction.opts assert opts @@ -868,7 +868,7 @@ def _finish_transaction(self, sock_info: SocketInfo, command_name: str) -> Dict[ cmd["recoveryToken"] = self._transaction.recovery_token return self._client.admin._command( - sock_info, cmd, session=self, write_concern=wc, parse_write_concern_error=True + conn, cmd, session=self, write_concern=wc, parse_write_concern_error=True ) def _advance_cluster_time(self, cluster_time: Optional[Mapping[str, Any]]) -> 
None: @@ -954,13 +954,13 @@ def _pinned_address(self) -> Optional[Tuple[str, Optional[int]]]: return None @property - def _pinned_connection(self) -> Optional[SocketInfo]: + def _pinned_connection(self) -> Optional[Connection]: """The connection this transaction was started on.""" return self._transaction.pinned_conn - def _pin(self, server: Server, sock_info: SocketInfo) -> None: + def _pin(self, server: Server, conn: Connection) -> None: """Pin this session to the given Server or to the given connection.""" - self._transaction.pin(server, sock_info) + self._transaction.pin(server, conn) def _unpin(self) -> None: """Unpin this session from any pinned Server.""" @@ -985,12 +985,12 @@ def _apply_to( command: MutableMapping[str, Any], is_retryable: bool, read_preference: ReadPreference, - sock_info: SocketInfo, + conn: Connection, ) -> None: self._check_ended() self._materialize() if self.options.snapshot: - self._update_read_concern(command, sock_info) + self._update_read_concern(command, conn) self._server_session.last_use = time.monotonic() command["lsid"] = self._server_session.session_id @@ -1016,7 +1016,7 @@ def _apply_to( rc = self._transaction.opts.read_concern.document if rc: command["readConcern"] = rc - self._update_read_concern(command, sock_info) + self._update_read_concern(command, conn) command["txnNumber"] = self._server_session.transaction_id command["autocommit"] = False @@ -1025,11 +1025,11 @@ def _start_retryable_write(self) -> None: self._check_ended() self._server_session.inc_transaction_id() - def _update_read_concern(self, cmd: MutableMapping[str, Any], sock_info: SocketInfo) -> None: + def _update_read_concern(self, cmd: MutableMapping[str, Any], conn: Connection) -> None: if self.options.causal_consistency and self.operation_time is not None: cmd.setdefault("readConcern", {})["afterClusterTime"] = self.operation_time if self.options.snapshot: - if sock_info.max_wire_version < 13: + if conn.max_wire_version < 13: raise 
ConfigurationError("Snapshot reads require MongoDB 5.0 or later") rc = cmd.setdefault("readConcern", {}) rc["level"] = "snapshot" diff --git a/pymongo/collection.py b/pymongo/collection.py index 6a1bcf8c06..af93a9aa77 100644 --- a/pymongo/collection.py +++ b/pymongo/collection.py @@ -126,7 +126,7 @@ class ReturnDocument: from pymongo.client_session import ClientSession from pymongo.collation import Collation from pymongo.database import Database - from pymongo.pool import SocketInfo + from pymongo.pool import Connection from pymongo.read_concern import ReadConcern from pymongo.server import Server @@ -262,17 +262,17 @@ def __init__( else: self.__create(name, kwargs, collation, session) - def _socket_for_reads( + def _conn_for_reads( self, session: ClientSession - ) -> ContextManager[Tuple[SocketInfo, Union[PrimaryPreferred, Primary]]]: - return self.__database.client._socket_for_reads(self._read_preference_for(session), session) + ) -> ContextManager[Tuple[Connection, Union[PrimaryPreferred, Primary]]]: + return self.__database.client._conn_for_reads(self._read_preference_for(session), session) - def _socket_for_writes(self, session: Optional[ClientSession]) -> ContextManager[SocketInfo]: - return self.__database.client._socket_for_writes(session) + def _conn_for_writes(self, session: Optional[ClientSession]) -> ContextManager[Connection]: + return self.__database.client._conn_for_writes(session) def _command( self, - sock_info: SocketInfo, + conn: Connection, command: Mapping[str, Any], read_preference: Optional[_ServerMode] = None, codec_options: Optional[CodecOptions] = None, @@ -288,7 +288,7 @@ def _command( """Internal command helper. :Parameters: - - `sock_info` - A SocketInfo instance. + - `conn` - A Connection instance. - `command` - The command itself, as a :class:`~bson.son.SON` instance. - `read_preference` (optional) - The read preference to use. - `codec_options` (optional) - An instance of @@ -313,7 +313,7 @@ def _command( The result document. 
""" with self.__database.client._tmp_session(session) as s: - return sock_info.command( + return conn.command( self.__database.name, command, read_preference or self._read_preference_for(session), @@ -348,16 +348,16 @@ def __create( if "size" in options: options["size"] = float(options["size"]) cmd.update(options) - with self._socket_for_writes(session) as sock_info: - if qev2_required and sock_info.max_wire_version < 21: + with self._conn_for_writes(session) as conn: + if qev2_required and conn.max_wire_version < 21: raise ConfigurationError( "Driver support of Queryable Encryption is incompatible with server. " "Upgrade server to use Queryable Encryption. " - f"Got maxWireVersion {sock_info.max_wire_version} but need maxWireVersion >= 21 (MongoDB >=7.0)" + f"Got maxWireVersion {conn.max_wire_version} but need maxWireVersion >= 21 (MongoDB >=7.0)" ) self._command( - sock_info, + conn, cmd, read_preference=ReadPreference.PRIMARY, write_concern=self._write_concern_for(session), @@ -597,12 +597,12 @@ def _insert_one( command["comment"] = comment def _insert_command( - session: ClientSession, sock_info: SocketInfo, retryable_write: bool + session: ClientSession, conn: Connection, retryable_write: bool ) -> None: if bypass_doc_val: command["bypassDocumentValidation"] = True - result = sock_info.command( + result = conn.command( self.__database.name, command, write_concern=write_concern, @@ -765,7 +765,7 @@ def gen() -> Iterator[Tuple[int, Mapping[str, Any]]]: def _update( self, - sock_info: SocketInfo, + conn: Connection, criteria: Mapping[str, Any], document: Union[Mapping[str, Any], _Pipeline], upsert: bool = False, @@ -801,7 +801,7 @@ def _update( else: update_doc["arrayFilters"] = array_filters if hint is not None: - if not acknowledged and sock_info.max_wire_version < 8: + if not acknowledged and conn.max_wire_version < 8: raise ConfigurationError( "Must be connected to MongoDB 4.2+ to use hint on unacknowledged update commands." 
) @@ -821,7 +821,7 @@ def _update( # The command result has to be published for APM unmodified # so we make a shallow copy here before adding updatedExisting. - result = sock_info.command( + result = conn.command( self.__database.name, command, write_concern=write_concern, @@ -865,10 +865,10 @@ def _update_retryable( """Internal update / replace helper.""" def _update( - session: Optional[ClientSession], sock_info: SocketInfo, retryable_write: bool + session: Optional[ClientSession], conn: Connection, retryable_write: bool ) -> Optional[Mapping[str, Any]]: return self._update( - sock_info, + conn, criteria, document, upsert=upsert, @@ -1255,7 +1255,7 @@ def drop( def _delete( self, - sock_info: SocketInfo, + conn: Connection, criteria: Mapping[str, Any], multi: bool, write_concern: Optional[WriteConcern] = None, @@ -1280,7 +1280,7 @@ def _delete( else: delete_doc["collation"] = collation if hint is not None: - if not acknowledged and sock_info.max_wire_version < 9: + if not acknowledged and conn.max_wire_version < 9: raise ConfigurationError( "Must be connected to MongoDB 4.4+ to use hint on unacknowledged delete commands." ) @@ -1297,7 +1297,7 @@ def _delete( command["comment"] = comment # Delete command. 
- result = sock_info.command( + result = conn.command( self.__database.name, command, write_concern=write_concern, @@ -1325,10 +1325,10 @@ def _delete_retryable( """Internal delete helper.""" def _delete( - session: Optional[ClientSession], sock_info: SocketInfo, retryable_write: bool + session: Optional[ClientSession], conn: Connection, retryable_write: bool ) -> Mapping[str, Any]: return self._delete( - sock_info, + conn, criteria, multi, write_concern=write_concern, @@ -1738,7 +1738,7 @@ def find_raw_batches(self, *args: Any, **kwargs: Any) -> RawBatchCursor[_Documen def _count_cmd( self, session: ClientSession, - sock_info: SocketInfo, + conn: Connection, read_preference: Optional[_ServerMode], cmd: Mapping[str, Any], collation: Optional[Collation], @@ -1747,7 +1747,7 @@ def _count_cmd( # XXX: "ns missing" checks can be removed when we drop support for # MongoDB 3.0, see SERVER-17051. res = self._command( - sock_info, + conn, cmd, read_preference=read_preference, allowable_errors=["ns missing"], @@ -1762,7 +1762,7 @@ def _count_cmd( def _aggregate_one_result( self, - sock_info: SocketInfo, + conn: Connection, read_preference: Optional[_ServerMode], cmd: Mapping[str, Any], collation: Optional[_CollationIn], @@ -1770,7 +1770,7 @@ def _aggregate_one_result( ) -> Optional[Mapping[str, Any]]: """Internal helper to run an aggregate that returns a single result.""" result = self._command( - sock_info, + conn, cmd, read_preference, allowable_errors=[26], # Ignore NamespaceNotFound. 
@@ -1821,12 +1821,12 @@ def estimated_document_count(self, comment: Optional[Any] = None, **kwargs: Any) def _cmd( session: ClientSession, server: Server, - sock_info: SocketInfo, + conn: Connection, read_preference: Optional[_ServerMode], ) -> int: cmd: SON[str, Any] = SON([("count", self.__name)]) cmd.update(kwargs) - return self._count_cmd(session, sock_info, read_preference, cmd, collation=None) + return self._count_cmd(session, conn, read_preference, cmd, collation=None) return self._retryable_non_cursor_read(_cmd, None) @@ -1910,10 +1910,10 @@ def count_documents( def _cmd( session: ClientSession, server: Server, - sock_info: SocketInfo, + conn: Connection, read_preference: Optional[_ServerMode], ) -> int: - result = self._aggregate_one_result(sock_info, read_preference, cmd, collation, session) + result = self._aggregate_one_result(conn, read_preference, cmd, collation, session) if not result: return 0 return result["n"] @@ -1922,7 +1922,7 @@ def _cmd( def _retryable_non_cursor_read( self, - func: Callable[[ClientSession, Server, SocketInfo, Optional[_ServerMode]], T], + func: Callable[[ClientSession, Server, Connection, Optional[_ServerMode]], T], session: Optional[ClientSession], ) -> T: """Non-cursor read helper to handle implicit session creation.""" @@ -1993,8 +1993,8 @@ def __create_indexes( command (like maxTimeMS) can be passed as keyword arguments. 
""" names = [] - with self._socket_for_writes(session) as sock_info: - supports_quorum = sock_info.max_wire_version >= 9 + with self._conn_for_writes(session) as conn: + supports_quorum = conn.max_wire_version >= 9 def gen_indexes() -> Iterator[Mapping[str, Any]]: for index in indexes: @@ -2015,7 +2015,7 @@ def gen_indexes() -> Iterator[Mapping[str, Any]]: ) self._command( - sock_info, + conn, cmd, read_preference=ReadPreference.PRIMARY, codec_options=_UNICODE_REPLACE_CODEC_OPTIONS, @@ -2236,9 +2236,9 @@ def drop_index( cmd.update(kwargs) if comment is not None: cmd["comment"] = comment - with self._socket_for_writes(session) as sock_info: + with self._conn_for_writes(session) as conn: self._command( - sock_info, + conn, cmd, read_preference=ReadPreference.PRIMARY, allowable_errors=["ns not found", 26], @@ -2285,7 +2285,7 @@ def list_indexes( def _cmd( session: ClientSession, server: Server, - sock_info: SocketInfo, + conn: Connection, read_preference: _ServerMode, ) -> CommandCursor[_DocumentType]: cmd = SON([("listIndexes", self.__name), ("cursor", {})]) @@ -2293,9 +2293,9 @@ def _cmd( cmd["comment"] = comment try: - cursor = self._command( - sock_info, cmd, read_preference, codec_options, session=session - )["cursor"] + cursor = self._command(conn, cmd, read_preference, codec_options, session=session)[ + "cursor" + ] except OperationFailure as exc: # Ignore NamespaceNotFound errors to match the behavior # of reading from *.system.indexes. 
@@ -2305,12 +2305,12 @@ def _cmd( cmd_cursor = CommandCursor( coll, cursor, - sock_info.address, + conn.address, session=session, explicit_session=explicit_session, comment=cmd.get("comment"), ) - cmd_cursor._maybe_pin_connection(sock_info) + cmd_cursor._maybe_pin_connection(conn) return cmd_cursor with self.__database.client._tmp_session(session, False) as s: @@ -2479,9 +2479,9 @@ def gen_indexes(): cmd = SON([("createSearchIndexes", self.name), ("indexes", list(gen_indexes()))]) cmd.update(kwargs) - with self._socket_for_writes(session) as sock_info: + with self._conn_for_writes(session) as conn: resp = self._command( - sock_info, + conn, cmd, read_preference=ReadPreference.PRIMARY, codec_options=_UNICODE_REPLACE_CODEC_OPTIONS, @@ -2514,9 +2514,9 @@ def drop_search_index( cmd.update(kwargs) if comment is not None: cmd["comment"] = comment - with self._socket_for_writes(session) as sock_info: + with self._conn_for_writes(session) as conn: self._command( - sock_info, + conn, cmd, read_preference=ReadPreference.PRIMARY, allowable_errors=["ns not found", 26], @@ -2551,9 +2551,9 @@ def update_search_index( cmd.update(kwargs) if comment is not None: cmd["comment"] = comment - with self._socket_for_writes(session) as sock_info: + with self._conn_for_writes(session) as conn: self._command( - sock_info, + conn, cmd, read_preference=ReadPreference.PRIMARY, allowable_errors=["ns not found", 26], @@ -2980,9 +2980,9 @@ def rename( cmd["comment"] = comment write_concern = self._write_concern_for_cmd(cmd, session) - with self._socket_for_writes(session) as sock_info: + with self._conn_for_writes(session) as conn: with self.__database.client._tmp_session(session) as s: - return sock_info.command( + return conn.command( "admin", cmd, write_concern=write_concern, @@ -3049,11 +3049,11 @@ def distinct( def _cmd( session: ClientSession, server: Server, - sock_info: SocketInfo, + conn: Connection, read_preference: Optional[_ServerMode], ) -> List: return self._command( - sock_info, + 
conn, cmd, read_preference=read_preference, read_concern=self.read_concern, @@ -3112,7 +3112,7 @@ def __find_and_modify( write_concern = self._write_concern_for_cmd(cmd, session) def _find_and_modify( - session: ClientSession, sock_info: SocketInfo, retryable_write: bool + session: ClientSession, conn: Connection, retryable_write: bool ) -> Any: acknowledged = write_concern.acknowledged if array_filters is not None: @@ -3122,17 +3122,17 @@ def _find_and_modify( ) cmd["arrayFilters"] = list(array_filters) if hint is not None: - if sock_info.max_wire_version < 8: + if conn.max_wire_version < 8: raise ConfigurationError( "Must be connected to MongoDB 4.2+ to use hint on find and modify commands." ) - elif not acknowledged and sock_info.max_wire_version < 9: + elif not acknowledged and conn.max_wire_version < 9: raise ConfigurationError( "Must be connected to MongoDB 4.4+ to use hint on unacknowledged find and modify commands." ) cmd["hint"] = hint out = self._command( - sock_info, + conn, cmd, read_preference=ReadPreference.PRIMARY, write_concern=write_concern, diff --git a/pymongo/command_cursor.py b/pymongo/command_cursor.py index 4a3d0311ca..c89b87ce36 100644 --- a/pymongo/command_cursor.py +++ b/pymongo/command_cursor.py @@ -29,7 +29,7 @@ ) from bson import CodecOptions, _convert_raw_document_lists_to_streams -from pymongo.cursor import _CURSOR_CLOSED_ERRORS, _SocketManager +from pymongo.cursor import _CURSOR_CLOSED_ERRORS, _ConnectionManager from pymongo.errors import ConnectionFailure, InvalidOperation, OperationFailure from pymongo.message import _CursorAddress, _GetMore, _OpMsg, _OpReply, _RawBatchGetMore from pymongo.response import PinnedResponse @@ -38,7 +38,7 @@ if TYPE_CHECKING: from pymongo.client_session import ClientSession from pymongo.collection import Collection - from pymongo.pool import SocketInfo + from pymongo.pool import Connection class CommandCursor(Generic[_DocumentType]): @@ -157,19 +157,19 @@ def _post_batch_resume_token(self) -> 
Optional[Mapping[str, Any]]: """ return self.__postbatchresumetoken - def _maybe_pin_connection(self, sock_info: SocketInfo) -> None: + def _maybe_pin_connection(self, conn: Connection) -> None: client = self.__collection.database.client if not client._should_pin_cursor(self.__session): return if not self.__sock_mgr: - sock_info.pin_cursor() - sock_mgr = _SocketManager(sock_info, False) + conn.pin_cursor() + conn_mgr = _ConnectionManager(conn, False) # Ensure the connection gets returned when the entire result is # returned in the first batch. if self.__id == 0: - sock_mgr.close() + conn_mgr.close() else: - self.__sock_mgr = sock_mgr + self.__sock_mgr = conn_mgr def __send_message(self, operation: _GetMore) -> None: """Send a getmore message and handle the response.""" @@ -197,7 +197,7 @@ def __send_message(self, operation: _GetMore) -> None: if isinstance(response, PinnedResponse): if not self.__sock_mgr: - self.__sock_mgr = _SocketManager(response.socket_info, response.more_to_come) + self.__sock_mgr = _ConnectionManager(response.conn, response.more_to_come) if response.from_command: cursor = response.docs[0]["cursor"] documents = cursor["nextBatch"] diff --git a/pymongo/cursor.py b/pymongo/cursor.py index b718d905e5..444e755c0c 100644 --- a/pymongo/cursor.py +++ b/pymongo/cursor.py @@ -64,7 +64,7 @@ from pymongo.client_session import ClientSession from pymongo.collection import Collection from pymongo.message import _OpMsg, _OpReply - from pymongo.pool import SocketInfo + from pymongo.pool import Connection from pymongo.read_preferences import _ServerMode @@ -139,11 +139,11 @@ class CursorType: """ -class _SocketManager: - """Used with exhaust cursors to ensure the socket is returned.""" +class _ConnectionManager: + """Used with exhaust cursors to ensure the connection is returned.""" - def __init__(self, sock: SocketInfo, more_to_come: bool): - self.sock: Optional[SocketInfo] = sock + def __init__(self, conn: Connection, more_to_come: bool): + self.conn: 
Optional[Connection] = conn self.more_to_come = more_to_come self.lock = _create_lock() @@ -151,10 +151,10 @@ def update_exhaust(self, more_to_come: bool) -> None: self.more_to_come = more_to_come def close(self) -> None: - """Return this instance's socket to the connection pool.""" - if self.sock: - self.sock.unpin() - self.sock = None + """Return this instance's connection to the connection pool.""" + if self.conn: + self.conn.unpin() + self.conn = None _Sort = Sequence[Union[str, Tuple[str, Union[int, str, Mapping[str, Any]]]]] @@ -1085,7 +1085,7 @@ def __send_message(self, operation: Union[_Query, _GetMore]) -> None: self.__address = response.address if isinstance(response, PinnedResponse): if not self.__sock_mgr: - self.__sock_mgr = _SocketManager(response.socket_info, response.more_to_come) + self.__sock_mgr = _ConnectionManager(response.conn, response.more_to_come) cmd_name = operation.name docs = response.docs diff --git a/pymongo/database.py b/pymongo/database.py index a55555cc4e..8de0b1d34a 100644 --- a/pymongo/database.py +++ b/pymongo/database.py @@ -48,7 +48,7 @@ from pymongo.typings import _CollationIn, _DocumentType, _DocumentTypeArg, _Pipeline if TYPE_CHECKING: - from pymongo.pool import SocketInfo + from pymongo.pool import Connection from pymongo.server import Server @@ -689,7 +689,7 @@ def watch( @overload def _command( self, - sock_info: SocketInfo, + conn: Connection, command: Union[str, MutableMapping[str, Any]], value: int = 1, check: bool = True, @@ -706,7 +706,7 @@ def _command( @overload def _command( self, - sock_info: SocketInfo, + conn: Connection, command: Union[str, MutableMapping[str, Any]], value: int = 1, check: bool = True, @@ -722,7 +722,7 @@ def _command( def _command( self, - sock_info: SocketInfo, + conn: Connection, command: Union[str, MutableMapping[str, Any]], value: int = 1, check: bool = True, @@ -742,7 +742,7 @@ def _command( command.update(kwargs) with self.__client._tmp_session(session) as s: - return 
sock_info.command( + return conn.command( self.__name, command, read_preference, @@ -889,12 +889,12 @@ def command( if read_preference is None: read_preference = (session and session._txn_read_preference()) or ReadPreference.PRIMARY - with self.__client._socket_for_reads(read_preference, session) as ( - sock_info, + with self.__client._conn_for_reads(read_preference, session) as ( + connection, read_preference, ): return self._command( - sock_info, + connection, command, value, check, @@ -973,12 +973,12 @@ def cursor_command( read_preference = ( tmp_session and tmp_session._txn_read_preference() ) or ReadPreference.PRIMARY - with self.__client._socket_for_reads(read_preference, tmp_session) as ( - sock_info, + with self.__client._conn_for_reads(read_preference, tmp_session) as ( + conn, read_preference, ): response = self._command( - sock_info, + conn, command, value, True, @@ -993,13 +993,13 @@ def cursor_command( cmd_cursor = CommandCursor( coll, response["cursor"], - sock_info.address, + conn.address, max_await_time_ms=max_await_time_ms, session=tmp_session, explicit_session=session is not None, comment=comment, ) - cmd_cursor._maybe_pin_connection(sock_info) + cmd_cursor._maybe_pin_connection(conn) return cmd_cursor else: raise InvalidOperation("Command does not return a cursor.") @@ -1015,11 +1015,11 @@ def _retryable_read_command( def _cmd( session: Optional[ClientSession], server: Server, - sock_info: SocketInfo, + conn: Connection, read_preference: _ServerMode, ) -> Dict[str, Any]: return self._command( - sock_info, + conn, command, read_preference=read_preference, session=session, @@ -1029,7 +1029,7 @@ def _cmd( def _list_collections( self, - sock_info: SocketInfo, + conn: Connection, session: Optional[ClientSession], read_preference: _ServerMode, **kwargs: Any, @@ -1039,18 +1039,18 @@ def _list_collections( cmd = SON([("listCollections", 1), ("cursor", {})]) cmd.update(kwargs) with self.__client._tmp_session(session, close=False) as tmp_session: - cursor 
= self._command( - sock_info, cmd, read_preference=read_preference, session=tmp_session - )["cursor"] + cursor = self._command(conn, cmd, read_preference=read_preference, session=tmp_session)[ + "cursor" + ] cmd_cursor = CommandCursor( coll, cursor, - sock_info.address, + conn.address, session=tmp_session, explicit_session=session is not None, comment=cmd.get("comment"), ) - cmd_cursor._maybe_pin_connection(sock_info) + cmd_cursor._maybe_pin_connection(conn) return cmd_cursor def list_collections( @@ -1090,12 +1090,10 @@ def list_collections( def _cmd( session: Optional[ClientSession], server: Server, - sock_info: SocketInfo, + conn: Connection, read_preference: _ServerMode, ) -> CommandCursor[_DocumentType]: - return self._list_collections( - sock_info, session, read_preference=read_preference, **kwargs - ) + return self._list_collections(conn, session, read_preference=read_preference, **kwargs) return self.__client._retryable_read(_cmd, read_pref, session) @@ -1154,9 +1152,9 @@ def _drop_helper( if comment is not None: command["comment"] = comment - with self.__client._socket_for_writes(session) as sock_info: + with self.__client._conn_for_writes(session) as connection: return self._command( - sock_info, + connection, command, allowable_errors=["ns not found", 26], write_concern=self._write_concern_for(session), diff --git a/pymongo/helpers.py b/pymongo/helpers.py index 9a4a2b04c3..c0c4f51edb 100644 --- a/pymongo/helpers.py +++ b/pymongo/helpers.py @@ -307,7 +307,7 @@ def _handle_exception() -> None: def _handle_reauth(func: F) -> F: def inner(*args: Any, **kwargs: Any) -> Any: no_reauth = kwargs.pop("no_reauth", False) - from pymongo.pool import SocketInfo + from pymongo.pool import Connection try: return func(*args, **kwargs) @@ -315,19 +315,19 @@ def inner(*args: Any, **kwargs: Any) -> Any: if no_reauth: raise if exc.code == _REAUTHENTICATION_REQUIRED_CODE: - # Look for an argument that either is a SocketInfo - # or has a socket_info attribute, so we can 
trigger + # Look for an argument that either is a Connection + # or has a conn attribute, so we can trigger # a reauth. - sock_info = None + conn = None for arg in args: - if isinstance(arg, SocketInfo): - sock_info = arg + if isinstance(arg, Connection): + conn = arg break - if hasattr(arg, "sock_info"): - sock_info = arg.sock_info + if hasattr(arg, "conn"): + conn = arg.conn break - if sock_info: - sock_info.authenticate(reauthenticate=True) + if conn: + conn.authenticate(reauthenticate=True) else: raise return func(*args, **kwargs) diff --git a/pymongo/message.py b/pymongo/message.py index 735f8a8ccd..c4cf4d5690 100644 --- a/pymongo/message.py +++ b/pymongo/message.py @@ -227,14 +227,14 @@ def _gen_find_command( return cmd -def _gen_get_more_command(cursor_id, coll, batch_size, max_await_time_ms, comment, sock_info): +def _gen_get_more_command(cursor_id, coll, batch_size, max_await_time_ms, comment, conn): """Generate a getMore command document.""" cmd = SON([("getMore", cursor_id), ("collection", coll)]) if batch_size: cmd["batchSize"] = batch_size if max_await_time_ms is not None: cmd["maxTimeMS"] = max_await_time_ms - if comment is not None and sock_info.max_wire_version >= 9: + if comment is not None and conn.max_wire_version >= 9: cmd["comment"] = comment return cmd @@ -264,7 +264,7 @@ class _Query: ) # For compatibility with the _GetMore class. - sock_mgr = None + conn_mgr = None cursor_id = None def __init__( @@ -311,24 +311,23 @@ def reset(self): def namespace(self): return f"{self.db}.{self.coll}" - def use_command(self, sock_info): + def use_command(self, conn): use_find_cmd = False if not self.exhaust: use_find_cmd = True - elif sock_info.max_wire_version >= 8: + elif conn.max_wire_version >= 8: # OP_MSG supports exhaust on MongoDB 4.2+ use_find_cmd = True elif not self.read_concern.ok_for_legacy: raise ConfigurationError( "read concern level of %s is not valid " - "with a max wire version of %d." 
- % (self.read_concern.level, sock_info.max_wire_version) + "with a max wire version of %d." % (self.read_concern.level, conn.max_wire_version) ) - sock_info.validate_session(self.client, self.session) + conn.validate_session(self.client, self.session) return use_find_cmd - def as_command(self, sock_info, apply_timeout=False): + def as_command(self, conn, apply_timeout=False): """Return a find command document for this query.""" # We use the command twice: on the wire and for command monitoring. # Generate it once, for speed and to avoid repeating side-effects. @@ -353,24 +352,24 @@ def as_command(self, sock_info, apply_timeout=False): self.name = "explain" cmd = SON([("explain", cmd)]) session = self.session - sock_info.add_server_api(cmd) + conn.add_server_api(cmd) if session: - session._apply_to(cmd, False, self.read_preference, sock_info) + session._apply_to(cmd, False, self.read_preference, conn) # Explain does not support readConcern. if not explain and not session.in_transaction: - session._update_read_concern(cmd, sock_info) - sock_info.send_cluster_time(cmd, session, self.client) + session._update_read_concern(cmd, conn) + conn.send_cluster_time(cmd, session, self.client) # Support auto encryption client = self.client if client._encrypter and not client._encrypter._bypass_auto_encryption: cmd = client._encrypter.encrypt(self.db, cmd, self.codec_options) # Support CSOT if apply_timeout: - sock_info.apply_timeout(client, cmd) + conn.apply_timeout(client, cmd) self._as_command = cmd, self.db return self._as_command - def get_message(self, read_preference, sock_info, use_cmd=False): + def get_message(self, read_preference, conn, use_cmd=False): """Get a query message, possibly setting the secondaryOk bit.""" # Use the read_preference decided by _socket_from_server. 
self.read_preference = read_preference @@ -384,14 +383,14 @@ def get_message(self, read_preference, sock_info, use_cmd=False): spec = self.spec if use_cmd: - spec = self.as_command(sock_info, apply_timeout=True)[0] + spec = self.as_command(conn, apply_timeout=True)[0] request_id, msg, size, _ = _op_msg( 0, spec, self.db, read_preference, self.codec_options, - ctx=sock_info.compression_context, + ctx=conn.compression_context, ) return request_id, msg, size @@ -405,7 +404,7 @@ def get_message(self, read_preference, sock_info, use_cmd=False): else: ntoreturn = self.limit - if sock_info.is_mongos: + if conn.is_mongos: spec = _maybe_add_read_preference(spec, read_preference) return _query( @@ -416,7 +415,7 @@ def get_message(self, read_preference, sock_info, use_cmd=False): spec, None if use_cmd else self.fields, self.codec_options, - ctx=sock_info.compression_context, + ctx=conn.compression_context, ) @@ -433,7 +432,7 @@ class _GetMore: "read_preference", "session", "client", - "sock_mgr", + "conn_mgr", "_as_command", "exhaust", "comment", @@ -452,7 +451,7 @@ def __init__( session, client, max_await_time_ms, - sock_mgr, + conn_mgr, exhaust, comment, ): @@ -465,7 +464,7 @@ def __init__( self.session = session self.client = client self.max_await_time_ms = max_await_time_ms - self.sock_mgr = sock_mgr + self.conn_mgr = conn_mgr self._as_command = None self.exhaust = exhaust self.comment = comment @@ -476,18 +475,18 @@ def reset(self): def namespace(self): return f"{self.db}.{self.coll}" - def use_command(self, sock_info): + def use_command(self, conn): use_cmd = False if not self.exhaust: use_cmd = True - elif sock_info.max_wire_version >= 8: + elif conn.max_wire_version >= 8: # OP_MSG supports exhaust on MongoDB 4.2+ use_cmd = True - sock_info.validate_session(self.client, self.session) + conn.validate_session(self.client, self.session) return use_cmd - def as_command(self, sock_info, apply_timeout=False): + def as_command(self, conn, apply_timeout=False): """Return a 
getMore command document for this query.""" # See _Query.as_command for an explanation of this caching. if self._as_command is not None: @@ -499,35 +498,35 @@ def as_command(self, sock_info, apply_timeout=False): self.ntoreturn, self.max_await_time_ms, self.comment, - sock_info, + conn, ) if self.session: - self.session._apply_to(cmd, False, self.read_preference, sock_info) - sock_info.add_server_api(cmd) - sock_info.send_cluster_time(cmd, self.session, self.client) + self.session._apply_to(cmd, False, self.read_preference, conn) + conn.add_server_api(cmd) + conn.send_cluster_time(cmd, self.session, self.client) # Support auto encryption client = self.client if client._encrypter and not client._encrypter._bypass_auto_encryption: cmd = client._encrypter.encrypt(self.db, cmd, self.codec_options) # Support CSOT if apply_timeout: - sock_info.apply_timeout(client, cmd=None) + conn.apply_timeout(client, cmd=None) self._as_command = cmd, self.db return self._as_command - def get_message(self, dummy0, sock_info, use_cmd=False): + def get_message(self, dummy0, conn, use_cmd=False): """Get a getmore message.""" ns = self.namespace() - ctx = sock_info.compression_context + ctx = conn.compression_context if use_cmd: - spec = self.as_command(sock_info, apply_timeout=True)[0] - if self.sock_mgr: + spec = self.as_command(conn, apply_timeout=True)[0] + if self.conn_mgr: flags = _OpMsg.EXHAUST_ALLOWED else: flags = 0 request_id, msg, size, _ = _op_msg( - flags, spec, self.db, None, self.codec_options, ctx=sock_info.compression_context + flags, spec, self.db, None, self.codec_options, ctx=conn.compression_context ) return request_id, msg, size @@ -535,10 +534,10 @@ def get_message(self, dummy0, sock_info, use_cmd=False): class _RawBatchQuery(_Query): - def use_command(self, sock_info): + def use_command(self, conn): # Compatibility checks. 
- super().use_command(sock_info) - if sock_info.max_wire_version >= 8: + super().use_command(conn) + if conn.max_wire_version >= 8: # MongoDB 4.2+ supports exhaust over OP_MSG return True elif not self.exhaust: @@ -547,10 +546,10 @@ def use_command(self, sock_info): class _RawBatchGetMore(_GetMore): - def use_command(self, sock_info): + def use_command(self, conn): # Compatibility checks. - super().use_command(sock_info) - if sock_info.max_wire_version >= 8: + super().use_command(conn) + if conn.max_wire_version >= 8: # MongoDB 4.2+ supports exhaust over OP_MSG return True elif not self.exhaust: @@ -794,11 +793,11 @@ def _get_more(collection_name, num_to_return, cursor_id, ctx=None): class _BulkWriteContext: - """A wrapper around SocketInfo for use with write splitting functions.""" + """A wrapper around Connection for use with write splitting functions.""" __slots__ = ( "db_name", - "sock_info", + "conn", "op_id", "name", "field", @@ -812,10 +811,10 @@ class _BulkWriteContext: ) def __init__( - self, database_name, cmd_name, sock_info, operation_id, listeners, session, op_type, codec + self, database_name, cmd_name, conn, operation_id, listeners, session, op_type, codec ): self.db_name = database_name - self.sock_info = sock_info + self.conn = conn self.op_id = operation_id self.listeners = listeners self.publish = listeners.enabled_for_commands @@ -823,7 +822,7 @@ def __init__( self.field = _FIELD_MAP[self.name] self.start_time = datetime.datetime.now() if self.publish else None self.session = session - self.compress = True if sock_info.compression_context else False + self.compress = True if conn.compression_context else False self.op_type = op_type self.codec = codec @@ -855,20 +854,20 @@ def execute_unack(self, cmd, docs, client): @property def max_bson_size(self): """A proxy for SockInfo.max_bson_size.""" - return self.sock_info.max_bson_size + return self.conn.max_bson_size @property def max_message_size(self): """A proxy for SockInfo.max_message_size.""" if 
self.compress: # Subtract 16 bytes for the message header. - return self.sock_info.max_message_size - 16 - return self.sock_info.max_message_size + return self.conn.max_message_size - 16 + return self.conn.max_message_size @property def max_write_batch_size(self): """A proxy for SockInfo.max_write_batch_size.""" - return self.sock_info.max_write_batch_size + return self.conn.max_write_batch_size @property def max_split_size(self): @@ -876,14 +875,14 @@ def max_split_size(self): return self.max_bson_size def unack_write(self, cmd, request_id, msg, max_doc_size, docs): - """A proxy for SocketInfo.unack_write that handles event publishing.""" + """A proxy for Connection.unack_write that handles event publishing.""" if self.publish: assert self.start_time is not None duration = datetime.datetime.now() - self.start_time cmd = self._start(cmd, request_id, docs) start = datetime.datetime.now() try: - result = self.sock_info.unack_write(msg, max_doc_size) + result = self.conn.unack_write(msg, max_doc_size) if self.publish: duration = (datetime.datetime.now() - start) + duration if result is not None: @@ -910,14 +909,14 @@ def unack_write(self, cmd, request_id, msg, max_doc_size, docs): @_handle_reauth def write_command(self, cmd, request_id, msg, docs): - """A proxy for SocketInfo.write_command that handles event publishing.""" + """A proxy for Connection.write_command that handles event publishing.""" if self.publish: assert self.start_time is not None duration = datetime.datetime.now() - self.start_time self._start(cmd, request_id, docs) start = datetime.datetime.now() try: - reply = self.sock_info.write_command(request_id, msg, self.codec) + reply = self.conn.write_command(request_id, msg, self.codec) if self.publish: duration = (datetime.datetime.now() - start) + duration self._succeed(request_id, reply, duration) @@ -941,9 +940,9 @@ def _start(self, cmd, request_id, docs): cmd, self.db_name, request_id, - self.sock_info.address, + self.conn.address, self.op_id, - 
self.sock_info.service_id, + self.conn.service_id, ) return cmd @@ -954,9 +953,9 @@ def _succeed(self, request_id, reply, duration): reply, self.name, request_id, - self.sock_info.address, + self.conn.address, self.op_id, - self.sock_info.service_id, + self.conn.service_id, ) def _fail(self, request_id, failure, duration): @@ -966,9 +965,9 @@ def _fail(self, request_id, failure, duration): failure, self.name, request_id, - self.sock_info.address, + self.conn.address, self.op_id, - self.sock_info.service_id, + self.conn.service_id, ) @@ -997,14 +996,14 @@ def _batch_command(self, cmd, docs): def execute(self, cmd, docs, client): batched_cmd, to_send = self._batch_command(cmd, docs) - result = self.sock_info.command( + result = self.conn.command( self.db_name, batched_cmd, codec_options=self.codec, session=self.session, client=client ) return result, to_send def execute_unack(self, cmd, docs, client): batched_cmd, to_send = self._batch_command(cmd, docs) - self.sock_info.command( + self.conn.command( self.db_name, batched_cmd, write_concern=WriteConcern(w=0), @@ -1124,7 +1123,7 @@ def _batched_op_msg_compressed(operation, command, docs, ack, opts, ctx): """ data, to_send = _encode_batched_op_msg(operation, command, docs, ack, opts, ctx) - request_id, msg = _compress(2013, data, ctx.sock_info.compression_context) + request_id, msg = _compress(2013, data, ctx.conn.compression_context) return request_id, msg, to_send @@ -1162,7 +1161,7 @@ def _do_batched_op_msg(namespace, operation, command, docs, opts, ctx): ack = bool(command["writeConcern"].get("w", 1)) else: ack = True - if ctx.sock_info.compression_context: + if ctx.conn.compression_context: return _batched_op_msg_compressed(operation, command, docs, ack, opts, ctx) return _batched_op_msg(operation, command, docs, ack, opts, ctx) diff --git a/pymongo/mongo_client.py b/pymongo/mongo_client.py index c8a265622c..b5f2f08af9 100644 --- a/pymongo/mongo_client.py +++ b/pymongo/mongo_client.py @@ -1160,18 +1160,18 @@ def 
options(self) -> ClientOptions: def _end_sessions(self, session_ids): """Send endSessions command(s) with the given session ids.""" try: - # Use SocketInfo.command directly to avoid implicitly creating + # Use Connection.command directly to avoid implicitly creating # another session. - with self._socket_for_reads(ReadPreference.PRIMARY_PREFERRED, None) as ( - sock_info, + with self._conn_for_reads(ReadPreference.PRIMARY_PREFERRED, None) as ( + conn, read_pref, ): - if not sock_info.supports_sessions: + if not conn.supports_sessions: return for i in range(0, len(session_ids), common._MAX_END_SESSIONS): spec = SON([("endSessions", session_ids[i : i + common._MAX_END_SESSIONS])]) - sock_info.command("admin", spec, read_preference=read_pref, client=self) + conn.command("admin", spec, read_preference=read_pref, client=self) except PyMongoError: # Drivers MUST ignore any errors returned by the endSessions # command. @@ -1216,7 +1216,7 @@ def _get_topology(self): return self._topology @contextlib.contextmanager - def _get_socket(self, server, session): + def _checkout(self, server, session): in_txn = session and session.in_transaction with _MongoClientErrorHandler(self, server, session) as err_handler: # Reuse the pinned connection, if it exists. @@ -1224,23 +1224,23 @@ def _get_socket(self, server, session): err_handler.contribute_socket(session._pinned_connection) yield session._pinned_connection return - with server.get_socket(handler=err_handler) as sock_info: + with server.checkout(handler=err_handler) as conn: # Pin this session to the selected server or connection. 
if in_txn and server.description.server_type in ( SERVER_TYPE.Mongos, SERVER_TYPE.LoadBalancer, ): - session._pin(server, sock_info) - err_handler.contribute_socket(sock_info) + session._pin(server, conn) + err_handler.contribute_socket(conn) if ( self._encrypter and not self._encrypter._bypass_auto_encryption - and sock_info.max_wire_version < 8 + and conn.max_wire_version < 8 ): raise ConfigurationError( "Auto-encryption requires a minimum MongoDB version of 4.2" ) - yield sock_info + yield conn def _select_server(self, server_selector, session, address=None): """Select a server to run an operation on this client. @@ -1273,15 +1273,15 @@ def _select_server(self, server_selector, session, address=None): session._unpin() raise - def _socket_for_writes(self, session): + def _conn_for_writes(self, session): server = self._select_server(writable_server_selector, session) - return self._get_socket(server, session) + return self._checkout(server, session) @contextlib.contextmanager - def _socket_from_server(self, read_preference, server, session): + def _conn_from_server(self, read_preference, server, session): assert read_preference is not None, "read_preference must not be None" - # Get a socket for a server matching the read preference, and yield - # sock_info with the effective read preference. The Server Selection + # Get a connection for a server matching the read preference, and yield + # conn with the effective read preference. The Server Selection # Spec says not to send any $readPreference to standalones and to # always send primaryPreferred when directly connected to a repl set # member. 
@@ -1289,22 +1289,22 @@ def _socket_from_server(self, read_preference, server, session): topology = self._get_topology() single = topology.description.topology_type == TOPOLOGY_TYPE.Single - with self._get_socket(server, session) as sock_info: + with self._checkout(server, session) as conn: if single: - if sock_info.is_repl and not (session and session.in_transaction): + if conn.is_repl and not (session and session.in_transaction): # Use primary preferred to ensure any repl set member # can handle the request. read_preference = ReadPreference.PRIMARY_PREFERRED - elif sock_info.is_standalone: + elif conn.is_standalone: # Don't send read preference to standalones. read_preference = ReadPreference.PRIMARY - yield sock_info, read_preference + yield conn, read_preference - def _socket_for_reads(self, read_preference, session): + def _conn_for_reads(self, read_preference, session): assert read_preference is not None, "read_preference must not be None" _ = self._get_topology() server = self._select_server(read_preference, session) - return self._socket_from_server(read_preference, server, session) + return self._conn_from_server(read_preference, server, session) def _should_pin_cursor(self, session): return self.__options.load_balanced and not (session and session.in_transaction) @@ -1319,22 +1319,26 @@ def _run_operation(self, operation, unpack_res, address=None): - `address` (optional): Optional address when sending a message to a specific server, used for getMore. 
""" - if operation.sock_mgr: + if operation.conn_mgr: server = self._select_server( operation.read_preference, operation.session, address=address ) - with operation.sock_mgr.lock: + with operation.conn_mgr.lock: with _MongoClientErrorHandler(self, server, operation.session) as err_handler: - err_handler.contribute_socket(operation.sock_mgr.sock) + err_handler.contribute_socket(operation.conn_mgr.conn) return server.run_operation( - operation.sock_mgr.sock, operation, True, self._event_listeners, unpack_res + operation.conn_mgr.conn, + operation, + True, + self._event_listeners, + unpack_res, ) - def _cmd(session, server, sock_info, read_preference): + def _cmd(session, server, conn, read_preference): operation.reset() # Reset op in case of retry. return server.run_operation( - sock_info, operation, read_preference, self._event_listeners, unpack_res + conn, operation, read_preference, self._event_listeners, unpack_res ) return self._retryable_read( @@ -1388,8 +1392,8 @@ def is_retrying(): supports_session = ( session is not None and server.description.retryable_writes_supported ) - with self._get_socket(server, session) as sock_info: - max_wire_version = sock_info.max_wire_version + with self._checkout(server, session) as conn: + max_wire_version = conn.max_wire_version if retryable and not supports_session: if is_retrying(): # A retry is not possible because this server does @@ -1397,7 +1401,7 @@ def is_retrying(): assert last_error is not None raise last_error retryable = False - return func(session, sock_info, retryable) + return func(session, conn, retryable) except ServerSelectionTimeoutError: if is_retrying(): # The application may think the write was never attempted @@ -1455,13 +1459,16 @@ def _retryable_read(self, func, read_pref, session, address=None, retryable=True raise last_error try: server = self._select_server(read_pref, session, address=address) - with self._socket_from_server(read_pref, server, session) as (sock_info, read_pref): + with 
self._conn_from_server(read_pref, server, session) as ( + conn, + read_pref, + ): if retrying and not retryable: # A retry is not possible because this server does # not support retryable reads, raise the last error. assert last_error is not None raise last_error - return func(session, server, sock_info, read_pref) + return func(session, server, conn, read_pref) except ServerSelectionTimeoutError: if retrying: # The application may think the write was never attempted @@ -1566,7 +1573,7 @@ def __getitem__(self, name: str) -> database.Database[_DocumentType]: return database.Database(self, name) def _cleanup_cursor( - self, locks_allowed, cursor_id, address, sock_mgr, session, explicit_session + self, locks_allowed, cursor_id, address, conn_mgr, session, explicit_session ): """Cleanup a cursor from cursor.close() or __del__. @@ -1578,33 +1585,33 @@ def _cleanup_cursor( - `locks_allowed`: True if we are allowed to acquire locks. - `cursor_id`: The cursor id which may be 0. - `address`: The _CursorAddress. - - `sock_mgr`: The _SocketManager for the pinned connection or None. + - `conn_mgr`: The _ConnectionManager for the pinned connection or None. - `session`: The cursor's session. - `explicit_session`: True if the session was passed explicitly. """ if locks_allowed: if cursor_id: - if sock_mgr and sock_mgr.more_to_come: + if conn_mgr and conn_mgr.more_to_come: # If this is an exhaust cursor and we haven't completely # exhausted the result set we *must* close the socket # to stop the server from sending more data. - sock_mgr.sock.close_socket(ConnectionClosedReason.ERROR) + conn_mgr.conn.close_conn(ConnectionClosedReason.ERROR) else: - self._close_cursor_now(cursor_id, address, session=session, sock_mgr=sock_mgr) - if sock_mgr: - sock_mgr.close() + self._close_cursor_now(cursor_id, address, session=session, conn_mgr=conn_mgr) + if conn_mgr: + conn_mgr.close() else: # The cursor will be closed later in a different session. 
- if cursor_id or sock_mgr: - self._close_cursor_soon(cursor_id, address, sock_mgr) + if cursor_id or conn_mgr: + self._close_cursor_soon(cursor_id, address, conn_mgr) if session and not explicit_session: session._end_session(lock=locks_allowed) - def _close_cursor_soon(self, cursor_id, address, sock_mgr=None): + def _close_cursor_soon(self, cursor_id, address, conn_mgr=None): """Request that a cursor and/or connection be cleaned up soon.""" - self.__kill_cursors_queue.append((address, cursor_id, sock_mgr)) + self.__kill_cursors_queue.append((address, cursor_id, conn_mgr)) - def _close_cursor_now(self, cursor_id, address=None, session=None, sock_mgr=None): + def _close_cursor_now(self, cursor_id, address=None, session=None, conn_mgr=None): """Send a kill cursors message with the given id. The cursor is closed synchronously on the current thread. @@ -1613,10 +1620,10 @@ def _close_cursor_now(self, cursor_id, address=None, session=None, sock_mgr=None raise TypeError("cursor_id must be an instance of int") try: - if sock_mgr: - with sock_mgr.lock: + if conn_mgr: + with conn_mgr.lock: # Cursor is pinned to LB outside of a transaction. - self._kill_cursor_impl([cursor_id], address, session, sock_mgr.sock) + self._kill_cursor_impl([cursor_id], address, session, conn_mgr.conn) else: self._kill_cursors([cursor_id], address, self._get_topology(), session) except PyMongoError: @@ -1633,14 +1640,14 @@ def _kill_cursors(self, cursor_ids, address, topology, session): # Application called close_cursor() with no address. 
server = topology.select_server(writable_server_selector) - with self._get_socket(server, session) as sock_info: - self._kill_cursor_impl(cursor_ids, address, session, sock_info) + with self._checkout(server, session) as conn: + self._kill_cursor_impl(cursor_ids, address, session, conn) - def _kill_cursor_impl(self, cursor_ids, address, session, sock_info): + def _kill_cursor_impl(self, cursor_ids, address, session, conn): namespace = address.namespace db, coll = namespace.split(".", 1) spec = SON([("killCursors", coll), ("cursors", cursor_ids)]) - sock_info.command(db, spec, session=session, client=self) + conn.command(db, spec, session=session, client=self) def _process_kill_cursors(self): """Process any pending kill cursors requests.""" @@ -1650,18 +1657,18 @@ def _process_kill_cursors(self): # Other threads or the GC may append to the queue concurrently. while True: try: - address, cursor_id, sock_mgr = self.__kill_cursors_queue.pop() + address, cursor_id, conn_mgr = self.__kill_cursors_queue.pop() except IndexError: break - if sock_mgr: - pinned_cursors.append((address, cursor_id, sock_mgr)) + if conn_mgr: + pinned_cursors.append((address, cursor_id, conn_mgr)) else: address_to_cursor_ids[address].append(cursor_id) - for address, cursor_id, sock_mgr in pinned_cursors: + for address, cursor_id, conn_mgr in pinned_cursors: try: - self._cleanup_cursor(True, cursor_id, address, sock_mgr, None, False) + self._cleanup_cursor(True, cursor_id, address, conn_mgr, None, False) except Exception as exc: if isinstance(exc, InvalidOperation) and self._topology._closed: # Raise the exception when client is closed so that it @@ -1925,9 +1932,9 @@ def drop_database( if not isinstance(name, str): raise TypeError("name_or_database must be an instance of str or a Database") - with self._socket_for_writes(session) as sock_info: + with self._conn_for_writes(session) as conn: self[name]._command( - sock_info, + conn, {"dropDatabase": 1, "comment": comment}, 
read_preference=ReadPreference.PRIMARY, write_concern=self._write_concern_for(session), @@ -2149,11 +2156,11 @@ def __init__(self, client, server, session): self.service_id = None self.handled = False - def contribute_socket(self, sock_info, completed_handshake=True): + def contribute_socket(self, conn, completed_handshake=True): """Provide socket information to the error handler.""" - self.max_wire_version = sock_info.max_wire_version - self.sock_generation = sock_info.generation - self.service_id = sock_info.service_id + self.max_wire_version = conn.max_wire_version + self.sock_generation = conn.generation + self.service_id = conn.service_id self.completed_handshake = completed_handshake def handle(self, exc_type, exc_val): diff --git a/pymongo/monitor.py b/pymongo/monitor.py index 2fc0bf8bab..117b9454a9 100644 --- a/pymongo/monitor.py +++ b/pymongo/monitor.py @@ -245,9 +245,9 @@ def _check_once(self): if self._cancel_context and self._cancel_context.cancelled: self._reset_connection() - with self._pool.get_socket() as sock_info: - self._cancel_context = sock_info.cancel_context - response, round_trip_time = self._check_with_socket(sock_info) + with self._pool.checkout() as conn: + self._cancel_context = conn.cancel_context + response, round_trip_time = self._check_with_socket(conn) if not response.awaitable: self._rtt_monitor.add_sample(round_trip_time) @@ -393,11 +393,11 @@ def _run(self): def _ping(self): """Run a "hello" command and return the RTT.""" - with self._pool.get_socket() as sock_info: + with self._pool.checkout() as conn: if self._executor._stopped: raise Exception("_RttMonitor closed") start = time.monotonic() - sock_info.hello() + conn.hello() return time.monotonic() - start diff --git a/pymongo/monitoring.py b/pymongo/monitoring.py index 24ac7f06bc..2a3a662c9b 100644 --- a/pymongo/monitoring.py +++ b/pymongo/monitoring.py @@ -135,15 +135,15 @@ def pool_closed(self, event): logging.info("[pool {0.address}] pool closed".format(event)) def 
connection_created(self, event): - logging.info("[pool {0.address}][conn #{0.connection_id}] " + logging.info("[pool {0.address}][connection #{0.connection_id}] " "connection created".format(event)) def connection_ready(self, event): - logging.info("[pool {0.address}][conn #{0.connection_id}] " + logging.info("[pool {0.address}][connection #{0.connection_id}] " "connection setup succeeded".format(event)) def connection_closed(self, event): - logging.info("[pool {0.address}][conn #{0.connection_id}] " + logging.info("[pool {0.address}][connection #{0.connection_id}] " "connection closed, reason: " "{0.reason}".format(event)) @@ -156,11 +156,11 @@ def connection_check_out_failed(self, event): "failed, reason: {0.reason}".format(event)) def connection_checked_out(self, event): - logging.info("[pool {0.address}][conn #{0.connection_id}] " + logging.info("[pool {0.address}][connection #{0.connection_id}] " "connection checked out of pool".format(event)) def connection_checked_in(self, event): - logging.info("[pool {0.address}][conn #{0.connection_id}] " + logging.info("[pool {0.address}][connection #{0.connection_id}] " "connection checked into pool".format(event)) @@ -268,7 +268,7 @@ class ConnectionPoolListener(_EventListener): def pool_created(self, event: "PoolCreatedEvent") -> None: """Abstract method to handle a :class:`PoolCreatedEvent`. - Emitted when a Connection Pool is created. + Emitted when a connection Pool is created. :Parameters: - `event`: An instance of :class:`PoolCreatedEvent`. @@ -278,7 +278,7 @@ def pool_created(self, event: "PoolCreatedEvent") -> None: def pool_ready(self, event: "PoolReadyEvent") -> None: """Abstract method to handle a :class:`PoolReadyEvent`. - Emitted when a Connection Pool is marked ready. + Emitted when a connection Pool is marked ready. :Parameters: - `event`: An instance of :class:`PoolReadyEvent`. 
@@ -290,7 +290,7 @@ def pool_ready(self, event: "PoolReadyEvent") -> None: def pool_cleared(self, event: "PoolClearedEvent") -> None: """Abstract method to handle a `PoolClearedEvent`. - Emitted when a Connection Pool is cleared. + Emitted when a connection Pool is cleared. :Parameters: - `event`: An instance of :class:`PoolClearedEvent`. @@ -300,7 +300,7 @@ def pool_cleared(self, event: "PoolClearedEvent") -> None: def pool_closed(self, event: "PoolClosedEvent") -> None: """Abstract method to handle a `PoolClosedEvent`. - Emitted when a Connection Pool is closed. + Emitted when a connection Pool is closed. :Parameters: - `event`: An instance of :class:`PoolClosedEvent`. @@ -310,7 +310,7 @@ def pool_closed(self, event: "PoolClosedEvent") -> None: def connection_created(self, event: "ConnectionCreatedEvent") -> None: """Abstract method to handle a :class:`ConnectionCreatedEvent`. - Emitted when a Connection Pool creates a Connection object. + Emitted when a connection Pool creates a Connection object. :Parameters: - `event`: An instance of :class:`ConnectionCreatedEvent`. @@ -320,7 +320,7 @@ def connection_created(self, event: "ConnectionCreatedEvent") -> None: def connection_ready(self, event: "ConnectionReadyEvent") -> None: """Abstract method to handle a :class:`ConnectionReadyEvent`. - Emitted when a Connection has finished its setup, and is now ready to + Emitted when a connection has finished its setup, and is now ready to use. :Parameters: @@ -331,7 +331,7 @@ def connection_ready(self, event: "ConnectionReadyEvent") -> None: def connection_closed(self, event: "ConnectionClosedEvent") -> None: """Abstract method to handle a :class:`ConnectionClosedEvent`. - Emitted when a Connection Pool closes a Connection. + Emitted when a connection Pool closes a connection. :Parameters: - `event`: An instance of :class:`ConnectionClosedEvent`. 
@@ -361,7 +361,7 @@ def connection_check_out_failed(self, event: "ConnectionCheckOutFailedEvent") -> def connection_checked_out(self, event: "ConnectionCheckedOutEvent") -> None: """Abstract method to handle a :class:`ConnectionCheckedOutEvent`. - Emitted when the driver successfully checks out a Connection. + Emitted when the driver successfully checks out a connection. :Parameters: - `event`: An instance of :class:`ConnectionCheckedOutEvent`. @@ -371,7 +371,7 @@ def connection_checked_out(self, event: "ConnectionCheckedOutEvent") -> None: def connection_checked_in(self, event: "ConnectionCheckedInEvent") -> None: """Abstract method to handle a :class:`ConnectionCheckedInEvent`. - Emitted when the driver checks in a Connection back to the Connection + Emitted when the driver checks in a connection back to the connection Pool. :Parameters: @@ -948,7 +948,7 @@ def __init__(self, address: _Address, connection_id: int) -> None: @property def connection_id(self) -> int: - """The ID of the Connection.""" + """The ID of the connection.""" return self.__connection_id def __repr__(self): @@ -1066,7 +1066,7 @@ def __repr__(self): class ConnectionCheckedOutEvent(_ConnectionIdEvent): - """Published when the driver successfully checks out a Connection. + """Published when the driver successfully checks out a connection. 
:Parameters: - `address`: The address (host, port) pair of the server this diff --git a/pymongo/network.py b/pymongo/network.py index 4cff1e5294..139f7b2aec 100644 --- a/pymongo/network.py +++ b/pymongo/network.py @@ -52,7 +52,7 @@ from pymongo.compression_support import SnappyContext, ZlibContext, ZstdContext from pymongo.mongo_client import MongoClient from pymongo.monitoring import _EventListeners - from pymongo.pool import SocketInfo + from pymongo.pool import Connection from pymongo.read_concern import ReadConcern from pymongo.read_preferences import _ServerMode from pymongo.typings import _Address @@ -62,7 +62,7 @@ def command( - sock_info: SocketInfo, + conn: Connection, dbname: str, spec: MutableMapping[str, Any], is_mongos: bool, @@ -88,7 +88,7 @@ def command( """Execute a command over the socket, or raise socket.error. :Parameters: - - `sock`: a raw socket instance + - `conn`: a Connection instance - `dbname`: name of the database on which to run the command - `spec`: a command document as an ordered dict type, eg SON. - `is_mongos`: are we connected to a mongos? @@ -98,7 +98,7 @@ def command( - `client`: optional MongoClient instance for updating $clusterTime. - `check`: raise OperationFailure if there are errors - `allowable_errors`: errors to ignore if `check` is True - - `address`: the (host, port) of `sock` + - `address`: the (host, port) of `conn` - `listeners`: An instance of :class:`~pymongo.monitoring.EventListeners` - `max_bson_size`: The maximum encoded bson size for this server - `read_concern`: The read concern for this command. 
@@ -125,7 +125,7 @@ def command( if read_concern.level: spec["readConcern"] = read_concern.document if session: - session._update_read_concern(spec, sock_info) + session._update_read_concern(spec, conn) if collation is not None: spec["collation"] = collation @@ -142,7 +142,7 @@ def command( # Support CSOT if client: - sock_info.apply_timeout(client, spec) + conn.apply_timeout(client, spec) _csot.apply_write_concern(spec, write_concern) if use_op_msg: @@ -167,19 +167,19 @@ def command( encoding_duration = datetime.datetime.now() - start assert listeners is not None listeners.publish_command_start( - orig, dbname, request_id, address, service_id=sock_info.service_id + orig, dbname, request_id, address, service_id=conn.service_id ) start = datetime.datetime.now() try: - sock_info.sock.sendall(msg) + conn.conn.sendall(msg) if use_op_msg and unacknowledged: # Unacknowledged, fake a successful command response. reply = None response_doc = {"ok": 1} else: - reply = receive_message(sock_info, request_id) - sock_info.more_to_come = reply.more_to_come + reply = receive_message(conn, request_id) + conn.more_to_come = reply.more_to_come unpacked_docs = reply.unpack_response( codec_options=codec_options, user_fields=user_fields ) @@ -190,7 +190,7 @@ def command( if check: helpers._check_command_response( response_doc, - sock_info.max_wire_version, + conn.max_wire_version, allowable_errors, parse_write_concern_error=parse_write_concern_error, ) @@ -203,7 +203,7 @@ def command( failure = message._convert_exception(exc) assert listeners is not None listeners.publish_command_failure( - duration, failure, name, request_id, address, service_id=sock_info.service_id + duration, failure, name, request_id, address, service_id=conn.service_id ) raise if publish: @@ -215,7 +215,7 @@ def command( name, request_id, address, - service_id=sock_info.service_id, + service_id=conn.service_id, speculative_hello=speculative_hello, ) @@ -230,21 +230,19 @@ def command( def receive_message( - 
sock_info: SocketInfo, request_id: int, max_message_size: int = MAX_MESSAGE_SIZE + conn: Connection, request_id: int, max_message_size: int = MAX_MESSAGE_SIZE ) -> Union[_OpReply, _OpMsg]: """Receive a raw BSON message or raise socket.error.""" if _csot.get_timeout(): deadline = _csot.get_deadline() else: - timeout = sock_info.sock.gettimeout() + timeout = conn.conn.gettimeout() if timeout: deadline = time.monotonic() + timeout else: deadline = None # Ignore the response's request id. - length, _, response_to, op_code = _UNPACK_HEADER( - _receive_data_on_socket(sock_info, 16, deadline) - ) + length, _, response_to, op_code = _UNPACK_HEADER(_receive_data_on_socket(conn, 16, deadline)) # No request_id for exhaust cursor "getMore". if request_id is not None: if request_id != response_to: @@ -260,11 +258,11 @@ def receive_message( ) if op_code == 2012: op_code, _, compressor_id = _UNPACK_COMPRESSION_HEADER( - _receive_data_on_socket(sock_info, 9, deadline) + _receive_data_on_socket(conn, 9, deadline) ) - data = decompress(_receive_data_on_socket(sock_info, length - 25, deadline), compressor_id) + data = decompress(_receive_data_on_socket(conn, length - 25, deadline), compressor_id) else: - data = _receive_data_on_socket(sock_info, length - 16, deadline) + data = _receive_data_on_socket(conn, length - 16, deadline) try: unpack_reply = _UNPACK_REPLY[op_code] @@ -276,12 +274,12 @@ def receive_message( _POLL_TIMEOUT = 0.5 -def wait_for_read(sock_info: SocketInfo, deadline: Optional[float]) -> None: +def wait_for_read(conn: Connection, deadline: Optional[float]) -> None: """Block until at least one byte is read, or a timeout, or a cancel.""" - context = sock_info.cancel_context + context = conn.cancel_context # Only Monitor connections can be cancelled. if context: - sock = sock_info.sock + sock = conn.conn timed_out = False while True: # SSLSocket can have buffered data which won't be caught by select. 
@@ -300,7 +298,7 @@ def wait_for_read(sock_info: SocketInfo, deadline: Optional[float]) -> None: timeout = max(min(remaining, _POLL_TIMEOUT), 0) else: timeout = _POLL_TIMEOUT - readable = sock_info.socket_checker.select(sock, read=True, timeout=timeout) + readable = conn.socket_checker.select(sock, read=True, timeout=timeout) if context.cancelled: raise _OperationCancelled("hello cancelled") if readable: @@ -313,21 +311,19 @@ def wait_for_read(sock_info: SocketInfo, deadline: Optional[float]) -> None: BLOCKING_IO_ERRORS = (BlockingIOError, *ssl_support.BLOCKING_IO_ERRORS) -def _receive_data_on_socket( - sock_info: SocketInfo, length: int, deadline: Optional[float] -) -> memoryview: +def _receive_data_on_socket(conn: Connection, length: int, deadline: Optional[float]) -> memoryview: buf = bytearray(length) mv = memoryview(buf) bytes_read = 0 while bytes_read < length: try: - wait_for_read(sock_info, deadline) + wait_for_read(conn, deadline) # CSOT: Update timeout. When the timeout has expired perform one # final non-blocking recv. This helps avoid spurious timeouts when # the response is actually already buffered on the client. if _csot.get_timeout() and deadline is not None: - sock_info.set_socket_timeout(max(deadline - time.monotonic(), 0)) - chunk_length = sock_info.sock.recv_into(mv[bytes_read:]) + conn.set_conn_timeout(max(deadline - time.monotonic(), 0)) + chunk_length = conn.conn.recv_into(mv[bytes_read:]) except BLOCKING_IO_ERRORS: raise socket.timeout("timed out") except OSError as exc: # noqa: B014 diff --git a/pymongo/pool.py b/pymongo/pool.py index a827d10f9c..9baaecc715 100644 --- a/pymongo/pool.py +++ b/pymongo/pool.py @@ -614,19 +614,19 @@ def cancelled(self): return self._cancelled -class SocketInfo: - """Store a socket with some metadata. +class Connection: + """Store a connection with some metadata. 
:Parameters: - - `sock`: a raw socket object + - `conn`: a raw connection object - `pool`: a Pool instance - `address`: the server's (host, port) - `id`: the id of this socket in it's pool """ - def __init__(self, sock, pool, address, id): + def __init__(self, conn, pool, address, id): self.pool_ref = weakref.ref(pool) - self.sock = sock + self.conn = conn self.address = address self.id = id self.authed = set() @@ -673,12 +673,12 @@ def __init__(self, sock, pool, address, id): self.last_timeout = self.opts.socket_timeout self.connect_rtt = 0.0 - def set_socket_timeout(self, timeout): - """Cache last timeout to avoid duplicate calls to sock.settimeout.""" + def set_conn_timeout(self, timeout): + """Cache last timeout to avoid duplicate calls to conn.settimeout.""" if timeout == self.last_timeout: return self.last_timeout = timeout - self.sock.settimeout(timeout) + self.conn.settimeout(timeout) def apply_timeout(self, client, cmd): # CSOT: use remaining timeout when set. @@ -686,7 +686,7 @@ def apply_timeout(self, client, cmd): if timeout is None: # Reset the socket timeout unless we're performing a streaming monitor check. if not self.more_to_come: - self.set_socket_timeout(self.opts.socket_timeout) + self.set_conn_timeout(self.opts.socket_timeout) return None # RTT validation. 
rtt = _csot.get_rtt() @@ -701,7 +701,7 @@ def apply_timeout(self, client, cmd): ) if cmd is not None: cmd["maxTimeMS"] = int(max_time_ms * 1000) - self.set_socket_timeout(timeout) + self.set_conn_timeout(timeout) return timeout def pin_txn(self): @@ -715,9 +715,9 @@ def pin_cursor(self): def unpin(self): pool = self.pool_ref() if pool: - pool.return_socket(self) + pool.checkin(self) else: - self.close_socket(ConnectionClosedReason.STALE) + self.close_conn(ConnectionClosedReason.STALE) def hello_cmd(self): # Handshake spec requires us to use OP_MSG+hello command for the @@ -748,7 +748,7 @@ def _hello(self, cluster_time, topology_version, heartbeat_frequency): awaitable = True # If connect_timeout is None there is no timeout. if self.opts.connect_timeout: - self.set_socket_timeout(self.opts.connect_timeout + heartbeat_frequency) + self.set_conn_timeout(self.opts.connect_timeout + heartbeat_frequency) if not performing_handshake and cluster_time is not None: cmd["$clusterTime"] = cluster_time @@ -919,7 +919,7 @@ def send_message(self, message, max_doc_size): ) try: - self.sock.sendall(message) + self.conn.sendall(message) except BaseException as error: self._raise_connection_failure(error) @@ -999,15 +999,15 @@ def validate_session(self, client, session): if session._client is not client: raise InvalidOperation("Can only use session with the MongoClient that started it") - def close_socket(self, reason): + def close_conn(self, reason): """Close this connection with a reason.""" if self.closed: return - self._close_socket() + self._close_conn() if reason and self.enabled_for_cmap: self.listeners.publish_connection_closed(self.address, self.id, reason) - def _close_socket(self): + def _close_conn(self): """Close this connection.""" if self.closed: return @@ -1017,13 +1017,13 @@ def _close_socket(self): # Note: We catch exceptions to avoid spurious errors on interpreter # shutdown. 
try: - self.sock.close() + self.conn.close() except Exception: pass - def socket_closed(self): + def conn_closed(self): """Return True if we know socket has been closed, False otherwise.""" - return self.socket_checker.socket_closed(self.sock) + return self.socket_checker.socket_closed(self.conn) def send_cluster_time(self, command, session, client): """Add $clusterTime.""" @@ -1060,12 +1060,12 @@ def _raise_connection_failure(self, error): # KeyboardInterrupt from the start, rather than as an initial # socket.error, so we catch that, close the socket, and reraise it. # - # The connection closed event will be emitted later in return_socket. + # The connection closed event will be emitted later in checkin. if self.ready: reason = None else: reason = ConnectionClosedReason.ERROR - self.close_socket(reason) + self.close_conn(reason) # SSLError from PyOpenSSL inherits directly from Exception. if isinstance(error, (IOError, OSError, SSLError)): _raise_connection_failure(self.address, error) @@ -1073,17 +1073,17 @@ def _raise_connection_failure(self, error): raise def __eq__(self, other): - return self.sock == other.sock + return self.conn == other.conn def __ne__(self, other): return not self == other def __hash__(self): - return hash(self.sock) + return hash(self.conn) def __repr__(self): - return "SocketInfo({}){} at {}".format( - repr(self.sock), + return "Connection({}){} at {}".format( + repr(self.conn), self.closed and " CLOSED" or "", id(self), ) @@ -1256,7 +1256,7 @@ def __init__(self, address, options, handshake=True): :Parameters: - `address`: a (hostname, port) tuple - `options`: a PoolOptions instance - - `handshake`: whether to call hello for each new SocketInfo + - `handshake`: whether to call hello for each new Connection """ if options.pause_enabled: self.state = PoolState.PAUSED @@ -1268,7 +1268,7 @@ def __init__(self, address, options, handshake=True): # LIFO pool. Sockets are ordered on idle time. 
Sockets claimed # and returned to pool from the left side. Stale sockets removed # from the right side. - self.sockets: collections.deque = collections.deque() + self.conns: collections.deque = collections.deque() self.lock = _create_lock() self.active_sockets = 0 # Monotonically increasing connection ID required for CMAP Events. @@ -1344,17 +1344,17 @@ def _reset(self, close, pause=True, service_id=None): self.active_sockets = 0 self.operation_count = 0 if service_id is None: - sockets, self.sockets = self.sockets, collections.deque() + sockets, self.conns = self.conns, collections.deque() else: discard: collections.deque = collections.deque() keep: collections.deque = collections.deque() - for sock_info in self.sockets: - if sock_info.service_id == service_id: - discard.append(sock_info) + for conn in self.conns: + if conn.service_id == service_id: + discard.append(conn) else: - keep.append(sock_info) + keep.append(conn) sockets = discard - self.sockets = keep + self.conns = keep if close: self.state = PoolState.CLOSED @@ -1367,15 +1367,15 @@ def _reset(self, close, pause=True, service_id=None): # PoolClosedEvent but that reset() SHOULD close sockets *after* # publishing the PoolClearedEvent. 
if close: - for sock_info in sockets: - sock_info.close_socket(ConnectionClosedReason.POOL_CLOSED) + for conn in sockets: + conn.close_conn(ConnectionClosedReason.POOL_CLOSED) if self.enabled_for_cmap: listeners.publish_pool_closed(self.address) else: if old_state != PoolState.PAUSED and self.enabled_for_cmap: listeners.publish_pool_cleared(self.address, service_id=service_id) - for sock_info in sockets: - sock_info.close_socket(ConnectionClosedReason.STALE) + for conn in sockets: + conn.close_conn(ConnectionClosedReason.STALE) def update_is_writable(self, is_writable): """Updates the is_writable attribute on all sockets currently in the @@ -1383,7 +1383,7 @@ def update_is_writable(self, is_writable): """ self.is_writable = is_writable with self.lock: - for _socket in self.sockets: + for _socket in self.conns: _socket.update_is_writable(self.is_writable) def reset(self, service_id=None): @@ -1412,16 +1412,16 @@ def remove_stale_sockets(self, reference_generation): if self.opts.max_idle_time_seconds is not None: with self.lock: while ( - self.sockets - and self.sockets[-1].idle_time_seconds() > self.opts.max_idle_time_seconds + self.conns + and self.conns[-1].idle_time_seconds() > self.opts.max_idle_time_seconds ): - sock_info = self.sockets.pop() - sock_info.close_socket(ConnectionClosedReason.IDLE) + conn = self.conns.pop() + conn.close_conn(ConnectionClosedReason.IDLE) while True: with self.size_cond: # There are enough sockets in the pool. - if len(self.sockets) + self.active_sockets >= self.opts.min_pool_size: + if len(self.conns) + self.active_sockets >= self.opts.min_pool_size: return if self.requests >= self.opts.min_pool_size: return @@ -1435,14 +1435,14 @@ def remove_stale_sockets(self, reference_generation): return self._pending += 1 incremented = True - sock_info = self.connect() + conn = self.connect() with self.lock: # Close connection and return if the pool was reset during # socket creation or while acquiring the pool lock. 
if self.gen.get_overall() != reference_generation: - sock_info.close_socket(ConnectionClosedReason.STALE) + conn.close_conn(ConnectionClosedReason.STALE) return - self.sockets.appendleft(sock_info) + self.conns.appendleft(conn) finally: if incremented: # Notify after adding the socket to the pool. @@ -1455,12 +1455,12 @@ def remove_stale_sockets(self, reference_generation): self.size_cond.notify() def connect(self, handler=None): - """Connect to Mongo and return a new SocketInfo. + """Connect to Mongo and return a new Connection. Can raise ConnectionFailure. Note that the pool does not keep a reference to the socket -- you - must call return_socket() when you're done with it. + must call checkin() when you're done with it. """ with self.lock: conn_id = self.next_connection_id @@ -1483,33 +1483,33 @@ def connect(self, handler=None): raise - sock_info = SocketInfo(sock, self, self.address, conn_id) + conn = Connection(sock, self, self.address, conn_id) try: if self.handshake: - sock_info.hello() - self.is_writable = sock_info.is_writable + conn.hello() + self.is_writable = conn.is_writable if handler: - handler.contribute_socket(sock_info, completed_handshake=False) + handler.contribute_socket(conn, completed_handshake=False) - sock_info.authenticate() + conn.authenticate() except BaseException: - sock_info.close_socket(ConnectionClosedReason.ERROR) + conn.close_conn(ConnectionClosedReason.ERROR) raise - return sock_info + return conn @contextlib.contextmanager - def get_socket(self, handler=None): - """Get a socket from the pool. Use with a "with" statement. + def checkout(self, handler=None): + """Get a connection from the pool. Use with a "with" statement. - Returns a :class:`SocketInfo` object wrapping a connected + Returns a :class:`Connection` object wrapping a connected :class:`socket.socket`. 
This method should always be used in a with-statement:: - with pool.get_socket() as socket_info: - socket_info.send_message(msg) - data = socket_info.receive_message(op_code, request_id) + with pool.get_conn() as connection: + connection.send_message(msg) + data = connection.receive_message(op_code, request_id) Can raise ConnectionFailure or OperationFailure. @@ -1520,36 +1520,36 @@ def get_socket(self, handler=None): if self.enabled_for_cmap: listeners.publish_connection_check_out_started(self.address) - sock_info = self._get_socket(handler=handler) + conn = self._get_conn(handler=handler) if self.enabled_for_cmap: - listeners.publish_connection_checked_out(self.address, sock_info.id) + listeners.publish_connection_checked_out(self.address, conn.id) try: - yield sock_info + yield conn except BaseException: # Exception in caller. Ensure the connection gets returned. # Note that when pinned is True, the session owns the # connection and it is responsible for checking the connection # back into the pool. - pinned = sock_info.pinned_txn or sock_info.pinned_cursor + pinned = conn.pinned_txn or conn.pinned_cursor if handler: # Perform SDAM error handling rules while the connection is # still checked out. 
exc_type, exc_val, _ = sys.exc_info() handler.handle(exc_type, exc_val) - if not pinned and sock_info.active: - self.return_socket(sock_info) + if not pinned and conn.active: + self.checkin(conn) raise - if sock_info.pinned_txn: + if conn.pinned_txn: with self.lock: - self.__pinned_sockets.add(sock_info) + self.__pinned_sockets.add(conn) self.ntxns += 1 - elif sock_info.pinned_cursor: + elif conn.pinned_cursor: with self.lock: - self.__pinned_sockets.add(sock_info) + self.__pinned_sockets.add(conn) self.ncursors += 1 - elif sock_info.active: - self.return_socket(sock_info) + elif conn.active: + self.checkin(conn) def _raise_if_not_ready(self, emit_event): if self.state != PoolState.READY: @@ -1559,8 +1559,8 @@ def _raise_if_not_ready(self, emit_event): ) _raise_connection_failure(self.address, AutoReconnect("connection pool paused")) - def _get_socket(self, handler=None): - """Get or create a SocketInfo. Can raise ConnectionFailure.""" + def _get_conn(self, handler=None): + """Get or create a Connection. Can raise ConnectionFailure.""" # We use the pid here to avoid issues with fork / multiprocessing. # See test.test_client:TestClient.test_fork for an example of # what could go wrong otherwise @@ -1600,7 +1600,7 @@ def _get_socket(self, handler=None): self.requests += 1 # We've now acquired the semaphore and must release it on error. - sock_info = None + conn = None incremented = False emitted_event = False try: @@ -1608,40 +1608,40 @@ def _get_socket(self, handler=None): self.active_sockets += 1 incremented = True - while sock_info is None: + while conn is None: # CMAP: we MUST wait for either maxConnecting OR for a socket # to be checked back into the pool. 
with self._max_connecting_cond: self._raise_if_not_ready(emit_event=False) - while not (self.sockets or self._pending < self._max_connecting): + while not (self.conns or self._pending < self._max_connecting): if not _cond_wait(self._max_connecting_cond, deadline): # Timed out, notify the next thread to ensure a # timeout doesn't consume the condition. - if self.sockets or self._pending < self._max_connecting: + if self.conns or self._pending < self._max_connecting: self._max_connecting_cond.notify() emitted_event = True self._raise_wait_queue_timeout() self._raise_if_not_ready(emit_event=False) try: - sock_info = self.sockets.popleft() + conn = self.conns.popleft() except IndexError: self._pending += 1 - if sock_info: # We got a socket from the pool - if self._perished(sock_info): - sock_info = None + if conn: # We got a socket from the pool + if self._perished(conn): + conn = None continue else: # We need to create a new connection try: - sock_info = self.connect(handler=handler) + conn = self.connect(handler=handler) finally: with self._max_connecting_cond: self._pending -= 1 self._max_connecting_cond.notify() except BaseException: - if sock_info: + if conn: # We checked out a socket but authentication failed. - sock_info.close_socket(ConnectionClosedReason.ERROR) + conn.close_conn(ConnectionClosedReason.ERROR) with self.size_cond: self.requests -= 1 if incremented: @@ -1654,45 +1654,45 @@ def _get_socket(self, handler=None): ) raise - sock_info.active = True - return sock_info + conn.active = True + return conn - def return_socket(self, sock_info): - """Return the socket to the pool, or if it's closed discard it. + def checkin(self, conn): + """Return the connection to the pool, or if it's closed discard it. :Parameters: - - `sock_info`: The socket to check into the pool. + - `conn`: The connection to check into the pool. 
""" - txn = sock_info.pinned_txn - cursor = sock_info.pinned_cursor - sock_info.active = False - sock_info.pinned_txn = False - sock_info.pinned_cursor = False - self.__pinned_sockets.discard(sock_info) + txn = conn.pinned_txn + cursor = conn.pinned_cursor + conn.active = False + conn.pinned_txn = False + conn.pinned_cursor = False + self.__pinned_sockets.discard(conn) listeners = self.opts._event_listeners if self.enabled_for_cmap: - listeners.publish_connection_checked_in(self.address, sock_info.id) + listeners.publish_connection_checked_in(self.address, conn.id) if self.pid != os.getpid(): self.reset_without_pause() else: if self.closed: - sock_info.close_socket(ConnectionClosedReason.POOL_CLOSED) - elif sock_info.closed: + conn.close_conn(ConnectionClosedReason.POOL_CLOSED) + elif conn.closed: # CMAP requires the closed event be emitted after the check in. if self.enabled_for_cmap: listeners.publish_connection_closed( - self.address, sock_info.id, ConnectionClosedReason.ERROR + self.address, conn.id, ConnectionClosedReason.ERROR ) else: with self.lock: # Hold the lock to ensure this section does not race with # Pool.reset(). - if self.stale_generation(sock_info.generation, sock_info.service_id): - sock_info.close_socket(ConnectionClosedReason.STALE) + if self.stale_generation(conn.generation, conn.service_id): + conn.close_conn(ConnectionClosedReason.STALE) else: - sock_info.update_last_checkin_time() - sock_info.update_is_writable(self.is_writable) - self.sockets.appendleft(sock_info) + conn.update_last_checkin_time() + conn.update_is_writable(self.is_writable) + self.conns.appendleft(conn) # Notify any threads waiting to create a connection. self._max_connecting_cond.notify() @@ -1706,7 +1706,7 @@ def return_socket(self, sock_info): self.operation_count -= 1 self.size_cond.notify() - def _perished(self, sock_info): + def _perished(self, conn): """Return True and close the connection if it is "perished". 
This side-effecty function checks if this socket has been idle for @@ -1720,24 +1720,24 @@ def _perished(self, sock_info): pool, to keep performance reasonable - we can't avoid AutoReconnects completely anyway. """ - idle_time_seconds = sock_info.idle_time_seconds() + idle_time_seconds = conn.idle_time_seconds() # If socket is idle, open a new one. if ( self.opts.max_idle_time_seconds is not None and idle_time_seconds > self.opts.max_idle_time_seconds ): - sock_info.close_socket(ConnectionClosedReason.IDLE) + conn.close_conn(ConnectionClosedReason.IDLE) return True if self._check_interval_seconds is not None and ( 0 == self._check_interval_seconds or idle_time_seconds > self._check_interval_seconds ): - if sock_info.socket_closed(): - sock_info.close_socket(ConnectionClosedReason.ERROR) + if conn.conn_closed(): + conn.close_conn(ConnectionClosedReason.ERROR) return True - if self.stale_generation(sock_info.generation, sock_info.service_id): - sock_info.close_socket(ConnectionClosedReason.STALE) + if self.stale_generation(conn.generation, conn.service_id): + conn.close_conn(ConnectionClosedReason.STALE) return True return False @@ -1772,5 +1772,5 @@ def __del__(self): # Avoid ResourceWarnings in Python 3 # Close all sockets without calling reset() or close() because it is # not safe to acquire a lock in __del__. - for sock_info in self.sockets: - sock_info.close_socket(None) + for conn in self.conns: + conn.close_conn(None) diff --git a/pymongo/pyopenssl_context.py b/pymongo/pyopenssl_context.py index d6762bcaa2..140e6ba841 100644 --- a/pymongo/pyopenssl_context.py +++ b/pymongo/pyopenssl_context.py @@ -81,7 +81,7 @@ def _is_ip_address(address): return False -# According to the docs for Connection.send it can raise +# According to the docs for socket.send it can raise # WantX509LookupError and should be retried. 
BLOCKING_IO_ERRORS = (_SSL.WantReadError, _SSL.WantWriteError, _SSL.WantX509LookupError) @@ -347,7 +347,7 @@ def wrap_socket( server_hostname=None, session=None, ): - """Wrap an existing Python socket sock and return a TLS socket + """Wrap an existing Python socket connection and return a TLS socket object. """ ssl_conn = _sslConn(self._ctx, sock, suppress_ragged_eofs) diff --git a/pymongo/response.py b/pymongo/response.py index bd4795bfb0..f94b0c447b 100644 --- a/pymongo/response.py +++ b/pymongo/response.py @@ -21,7 +21,7 @@ from datetime import timedelta from pymongo.message import _OpMsg, _OpReply - from pymongo.pool import SocketInfo + from pymongo.pool import Connection from pymongo.typings import _Address @@ -85,13 +85,13 @@ def docs(self) -> List[Mapping[str, Any]]: class PinnedResponse(Response): - __slots__ = ("_socket_info", "_more_to_come") + __slots__ = ("_conn", "_more_to_come") def __init__( self, data: Union[_OpMsg, _OpReply], address: _Address, - socket_info: SocketInfo, + conn: Connection, request_id: int, duration: Optional[timedelta], from_command: bool, @@ -103,7 +103,7 @@ def __init__( :Parameters: - `data`: A network response message. - `address`: (host, port) of the source server. - - `socket_info`: The SocketInfo used for the initial query. + - `conn`: The Connection used for the initial query. - `request_id`: The request id of this operation. - `duration`: The duration of the operation. - `from_command`: If the response is the result of a db command. @@ -112,18 +112,18 @@ def __init__( exhausted. """ super().__init__(data, address, request_id, duration, from_command, docs) - self._socket_info = socket_info + self._conn = conn self._more_to_come = more_to_come @property - def socket_info(self) -> SocketInfo: - """The SocketInfo used for the initial query. + def conn(self) -> Connection: + """The Connection used for the initial query. 
The server will send batches on this socket, without waiting for getMores from the client, until the result set is exhausted or there is an error. """ - return self._socket_info + return self._conn @property def more_to_come(self) -> bool: diff --git a/pymongo/server.py b/pymongo/server.py index 349af4a41d..2ea9327f56 100644 --- a/pymongo/server.py +++ b/pymongo/server.py @@ -42,7 +42,7 @@ from pymongo.mongo_client import _MongoClientErrorHandler from pymongo.monitor import Monitor from pymongo.monitoring import _EventListeners - from pymongo.pool import Pool, SocketInfo + from pymongo.pool import Connection, Pool from pymongo.server_description import ServerDescription _CURSOR_DOC_FIELDS = {"cursor": {"firstBatch": 1, "nextBatch": 1}} @@ -105,7 +105,7 @@ def request_check(self) -> None: @_handle_reauth def run_operation( self, - sock_info: SocketInfo, + conn: Connection, operation: Union[_Query, _GetMore], read_preference: bool, listeners: _EventListeners, @@ -118,7 +118,7 @@ def run_operation( Can raise ConnectionFailure, OperationFailure, etc. :Parameters: - - `sock_info`: A SocketInfo instance. + - `conn`: A Connection instance. - `operation`: A _Query or _GetMore object. - `read_preference`: The read preference to use. - `listeners`: Instance of _EventListeners or None. 
@@ -129,27 +129,27 @@ def run_operation( if publish: start = datetime.now() - use_cmd = operation.use_command(sock_info) - more_to_come = operation.sock_mgr and operation.sock_mgr.more_to_come + use_cmd = operation.use_command(conn) + more_to_come = operation.conn_mgr and operation.conn_mgr.more_to_come if more_to_come: request_id = 0 else: - message = operation.get_message(read_preference, sock_info, use_cmd) + message = operation.get_message(read_preference, conn, use_cmd) request_id, data, max_doc_size = self._split_message(message) if publish: - cmd, dbn = operation.as_command(sock_info) + cmd, dbn = operation.as_command(conn) listeners.publish_command_start( - cmd, dbn, request_id, sock_info.address, service_id=sock_info.service_id + cmd, dbn, request_id, conn.address, service_id=conn.service_id ) start = datetime.now() try: if more_to_come: - reply = sock_info.receive_message(None) + reply = conn.receive_message(None) else: - sock_info.send_message(data, max_doc_size) - reply = sock_info.receive_message(request_id) + conn.send_message(data, max_doc_size) + reply = conn.receive_message(request_id) # Unpack and check for command errors. if use_cmd: @@ -168,7 +168,7 @@ def run_operation( if use_cmd: first = docs[0] operation.client._process_response(first, operation.session) - _check_command_response(first, sock_info.max_wire_version) + _check_command_response(first, conn.max_wire_version) except Exception as exc: if publish: duration = datetime.now() - start @@ -181,8 +181,8 @@ def run_operation( failure, operation.name, request_id, - sock_info.address, - service_id=sock_info.service_id, + conn.address, + service_id=conn.service_id, ) raise @@ -205,8 +205,8 @@ def run_operation( res, operation.name, request_id, - sock_info.address, - service_id=sock_info.service_id, + conn.address, + service_id=conn.service_id, ) # Decrypt response. 
@@ -219,7 +219,7 @@ def run_operation( response: Response if client._should_pin_cursor(operation.session) or operation.exhaust: - sock_info.pin_cursor() + conn.pin_cursor() if isinstance(reply, _OpMsg): # In OP_MSG, the server keeps sending only if the # more_to_come flag is set. @@ -227,12 +227,12 @@ def run_operation( else: # In OP_REPLY, the server keeps sending until cursor_id is 0. more_to_come = bool(operation.exhaust and reply.cursor_id) - if operation.sock_mgr: - operation.sock_mgr.update_exhaust(more_to_come) + if operation.conn_mgr: + operation.conn_mgr.update_exhaust(more_to_come) response = PinnedResponse( data=reply, address=self._description.address, - socket_info=sock_info, + conn=conn, duration=duration, request_id=request_id, from_command=use_cmd, @@ -251,10 +251,10 @@ def run_operation( return response - def get_socket( + def checkout( self, handler: Optional[_MongoClientErrorHandler] = None - ) -> ContextManager[SocketInfo]: - return self.pool.get_socket(handler) + ) -> ContextManager[Connection]: + return self.pool.checkout(handler) @property def description(self) -> ServerDescription: diff --git a/test/mockupdb/test_handshake.py b/test/mockupdb/test_handshake.py index 883d518f5b..3d002cbbf1 100644 --- a/test/mockupdb/test_handshake.py +++ b/test/mockupdb/test_handshake.py @@ -106,7 +106,7 @@ def test_client_handshake_data(self): self.addCleanup(client.close) - # New monitoring sockets send data during handshake. + # New monitoring connections send data during handshake. heartbeat = primary.receives("ismaster") _check_handshake_data(heartbeat) heartbeat.ok(primary_response) @@ -169,7 +169,7 @@ def test_client_handshake_saslSupportedMechs(self): self.addCleanup(client.close) - # New monitoring sockets send data during handshake. + # New monitoring connections send data during handshake. 
heartbeat = server.receives("ismaster") heartbeat.ok(primary_response) diff --git a/test/pymongo_mocks.py b/test/pymongo_mocks.py index 2e7fda21e0..4810809333 100644 --- a/test/pymongo_mocks.py +++ b/test/pymongo_mocks.py @@ -38,7 +38,7 @@ def __init__(self, client, pair, *args, **kwargs): Pool.__init__(self, (client_context.host, client_context.port), *args, **kwargs) @contextlib.contextmanager - def get_socket(self, handler=None): + def checkout(self, handler=None): client = self.client host_and_port = f"{self.mock_host}:{self.mock_port}" if host_and_port in client.mock_down_hosts: @@ -48,10 +48,10 @@ def get_socket(self, handler=None): client.mock_standalones + client.mock_members + client.mock_mongoses ), ("bad host: %s" % host_and_port) - with Pool.get_socket(self, handler) as sock_info: - sock_info.mock_host = self.mock_host - sock_info.mock_port = self.mock_port - yield sock_info + with Pool.checkout(self, handler) as conn: + conn.mock_host = self.mock_host + conn.mock_port = self.mock_port + yield conn class DummyMonitor: diff --git a/test/test_auth.py b/test/test_auth.py index f9a9af4d5a..160e718c09 100644 --- a/test/test_auth.py +++ b/test/test_auth.py @@ -60,7 +60,7 @@ class AutoAuthenticateThread(threading.Thread): """Used in testing threaded authentication. This does collection.find_one() with a 1-second delay to ensure it must - check out and authenticate multiple sockets from the pool concurrently. + check out and authenticate multiple connections from the pool concurrently. :Parameters: `collection`: An auth-protected collection containing one document. @@ -217,7 +217,7 @@ def test_gssapi_threaded(self): # Need one document in the collection. AutoAuthenticateThread does # collection.find_one with a 1-second delay, forcing it to check out - # multiple sockets from the pool concurrently, proving that + # multiple connections from the pool concurrently, proving that # auto-authentication works with GSSAPI. 
collection = db.test if not collection.count_documents({}): diff --git a/test/test_client.py b/test/test_client.py index bba6b37287..24f4603b27 100644 --- a/test/test_client.py +++ b/test/test_client.py @@ -96,7 +96,7 @@ ) from pymongo.mongo_client import MongoClient from pymongo.monitoring import ServerHeartbeatListener, ServerHeartbeatStartedEvent -from pymongo.pool import _METADATA, PoolOptions, SocketInfo +from pymongo.pool import _METADATA, Connection, PoolOptions from pymongo.read_preferences import ReadPreference from pymongo.server_description import ServerDescription from pymongo.server_selectors import readable_server_selector, writable_server_selector @@ -538,13 +538,13 @@ def test_multiple_uris(self): def test_max_idle_time_reaper_default(self): with client_knobs(kill_cursor_frequency=0.1): - # Assert reaper doesn't remove sockets when maxIdleTimeMS not set + # Assert reaper doesn't remove connections when maxIdleTimeMS not set client = rs_or_single_client() server = client._get_topology().select_server(readable_server_selector) - with server._pool.get_socket() as sock_info: + with server._pool.checkout() as conn: pass - self.assertEqual(1, len(server._pool.sockets)) - self.assertTrue(sock_info in server._pool.sockets) + self.assertEqual(1, len(server._pool.conns)) + self.assertTrue(conn in server._pool.conns) client.close() def test_max_idle_time_reaper_removes_stale_minPoolSize(self): @@ -552,27 +552,27 @@ def test_max_idle_time_reaper_removes_stale_minPoolSize(self): # Assert reaper removes idle socket and replaces it with a new one client = rs_or_single_client(maxIdleTimeMS=500, minPoolSize=1) server = client._get_topology().select_server(readable_server_selector) - with server._pool.get_socket() as sock_info: + with server._pool.checkout() as conn: pass # When the reaper runs at the same time as the get_socket, two - # sockets could be created and checked into the pool. 
- self.assertGreaterEqual(len(server._pool.sockets), 1) - wait_until(lambda: sock_info not in server._pool.sockets, "remove stale socket") - wait_until(lambda: 1 <= len(server._pool.sockets), "replace stale socket") + # connections could be created and checked into the pool. + self.assertGreaterEqual(len(server._pool.conns), 1) + wait_until(lambda: conn not in server._pool.conns, "remove stale socket") + wait_until(lambda: 1 <= len(server._pool.conns), "replace stale socket") client.close() def test_max_idle_time_reaper_does_not_exceed_maxPoolSize(self): with client_knobs(kill_cursor_frequency=0.1): - # Assert reaper respects maxPoolSize when adding new sockets. + # Assert reaper respects maxPoolSize when adding new connections. client = rs_or_single_client(maxIdleTimeMS=500, minPoolSize=1, maxPoolSize=1) server = client._get_topology().select_server(readable_server_selector) - with server._pool.get_socket() as sock_info: + with server._pool.checkout() as conn: pass # When the reaper runs at the same time as the get_socket, - # maxPoolSize=1 should prevent two sockets from being created. - self.assertEqual(1, len(server._pool.sockets)) - wait_until(lambda: sock_info not in server._pool.sockets, "remove stale socket") - wait_until(lambda: 1 == len(server._pool.sockets), "replace stale socket") + # maxPoolSize=1 should prevent two connections from being created. 
+ self.assertEqual(1, len(server._pool.conns)) + wait_until(lambda: conn not in server._pool.conns, "remove stale socket") + wait_until(lambda: 1 == len(server._pool.conns), "replace stale socket") client.close() def test_max_idle_time_reaper_removes_stale(self): @@ -580,15 +580,15 @@ def test_max_idle_time_reaper_removes_stale(self): # Assert reaper has removed idle socket and NOT replaced it client = rs_or_single_client(maxIdleTimeMS=500) server = client._get_topology().select_server(readable_server_selector) - with server._pool.get_socket() as sock_info_one: + with server._pool.checkout() as conn_one: pass - # Assert that the pool does not close sockets prematurely. + # Assert that the pool does not close connections prematurely. time.sleep(0.300) - with server._pool.get_socket() as sock_info_two: + with server._pool.checkout() as conn_two: pass - self.assertIs(sock_info_one, sock_info_two) + self.assertIs(conn_one, conn_two) wait_until( - lambda: 0 == len(server._pool.sockets), + lambda: 0 == len(server._pool.conns), "stale socket reaped and new one NOT added to the pool", ) client.close() @@ -597,48 +597,50 @@ def test_min_pool_size(self): with client_knobs(kill_cursor_frequency=0.1): client = rs_or_single_client() server = client._get_topology().select_server(readable_server_selector) - self.assertEqual(0, len(server._pool.sockets)) + self.assertEqual(0, len(server._pool.conns)) # Assert that pool started up at minPoolSize client = rs_or_single_client(minPoolSize=10) server = client._get_topology().select_server(readable_server_selector) - wait_until(lambda: 10 == len(server._pool.sockets), "pool initialized with 10 sockets") + wait_until( + lambda: 10 == len(server._pool.conns), "pool initialized with 10 connections" + ) # Assert that if a socket is closed, a new one takes its place - with server._pool.get_socket() as sock_info: - sock_info.close_socket(None) + with server._pool.checkout() as conn: + conn.close_conn(None) wait_until( - lambda: 10 == 
len(server._pool.sockets), + lambda: 10 == len(server._pool.conns), "a closed socket gets replaced from the pool", ) - self.assertFalse(sock_info in server._pool.sockets) + self.assertFalse(conn in server._pool.conns) def test_max_idle_time_checkout(self): # Use high frequency to test _get_socket_no_auth. with client_knobs(kill_cursor_frequency=99999999): client = rs_or_single_client(maxIdleTimeMS=500) server = client._get_topology().select_server(readable_server_selector) - with server._pool.get_socket() as sock_info: + with server._pool.checkout() as conn: pass - self.assertEqual(1, len(server._pool.sockets)) + self.assertEqual(1, len(server._pool.conns)) time.sleep(1) # Sleep so that the socket becomes stale. - with server._pool.get_socket() as new_sock_info: - self.assertNotEqual(sock_info, new_sock_info) - self.assertEqual(1, len(server._pool.sockets)) - self.assertFalse(sock_info in server._pool.sockets) - self.assertTrue(new_sock_info in server._pool.sockets) + with server._pool.checkout() as new_con: + self.assertNotEqual(conn, new_con) + self.assertEqual(1, len(server._pool.conns)) + self.assertFalse(conn in server._pool.conns) + self.assertTrue(new_con in server._pool.conns) - # Test that sockets are reused if maxIdleTimeMS is not set. + # Test that connections are reused if maxIdleTimeMS is not set. 
client = rs_or_single_client() server = client._get_topology().select_server(readable_server_selector) - with server._pool.get_socket() as sock_info: + with server._pool.checkout() as conn: pass - self.assertEqual(1, len(server._pool.sockets)) + self.assertEqual(1, len(server._pool.conns)) time.sleep(1) - with server._pool.get_socket() as new_sock_info: - self.assertEqual(sock_info, new_sock_info) - self.assertEqual(1, len(server._pool.sockets)) + with server._pool.checkout() as new_con: + self.assertEqual(conn, new_con) + self.assertEqual(1, len(server._pool.conns)) def test_constants(self): """This test uses MongoClient explicitly to make sure that host and @@ -933,11 +935,11 @@ def test_close_closes_sockets(self): topology = client._topology client.close() for server in topology._servers.values(): - self.assertFalse(server._pool.sockets) + self.assertFalse(server._pool.conns) self.assertTrue(server._monitor._executor._stopped) self.assertTrue(server._monitor._rtt_monitor._executor._stopped) - self.assertFalse(server._monitor._pool.sockets) - self.assertFalse(server._monitor._rtt_monitor._pool.sockets) + self.assertFalse(server._monitor._pool.conns) + self.assertFalse(server._monitor._rtt_monitor._pool.conns) def test_bad_uri(self): with self.assertRaises(InvalidURI): @@ -1130,8 +1132,8 @@ def test_waitQueueTimeoutMS(self): def test_socketKeepAlive(self): pool = get_pool(self.client) - with pool.get_socket() as sock_info: - keepalive = sock_info.sock.getsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE) + with pool.checkout() as conn: + keepalive = conn.conn.getsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE) self.assertTrue(keepalive) @no_type_check @@ -1184,7 +1186,7 @@ def test_contextlib(self): # The socket used for the previous commands has been returned to the # pool - self.assertEqual(1, len(get_pool(client).sockets)) + self.assertEqual(1, len(get_pool(client).conns)) with contextlib.closing(client): self.assertEqual("bar", 
client.pymongo_test.test.find_one()["foo"]) @@ -1223,7 +1225,7 @@ def test_interrupt_signal(self): # main thread while find() is in-progress: On Windows, SIGALRM is # unavailable so we use a second thread. In our Evergreen setup on # Linux, the thread technique causes an error in the test at - # sock.recv(): TypeError: 'int' object is not callable + # conn.recv(): TypeError: 'int' object is not callable # We don't know what causes this, so we hack around it. if sys.platform == "win32": @@ -1271,16 +1273,16 @@ def test_operation_failure(self): self.addCleanup(client.close) client.pymongo_test.test.find_one() pool = get_pool(client) - socket_count = len(pool.sockets) + socket_count = len(pool.conns) self.assertGreaterEqual(socket_count, 1) - old_sock_info = next(iter(pool.sockets)) + old_conn = next(iter(pool.conns)) client.pymongo_test.test.drop() client.pymongo_test.test.insert_one({"_id": "foo"}) self.assertRaises(OperationFailure, client.pymongo_test.test.insert_one, {"_id": "foo"}) - self.assertEqual(socket_count, len(pool.sockets)) - new_sock_info = next(iter(pool.sockets)) - self.assertEqual(old_sock_info, new_sock_info) + self.assertEqual(socket_count, len(pool.conns)) + new_con = next(iter(pool.conns)) + self.assertEqual(old_conn, new_con) def test_lazy_connect_w0(self): # Ensure that connect-on-demand works when the first operation is @@ -1326,13 +1328,13 @@ def test_exhaust_network_error(self): connected(client) # Cause a network error. - sock_info = one(pool.sockets) - sock_info.sock.close() + conn = one(pool.conns) + conn.conn.close() cursor = collection.find(cursor_type=CursorType.EXHAUST) with self.assertRaises(ConnectionFailure): next(cursor) - self.assertTrue(sock_info.closed) + self.assertTrue(conn.closed) # The semaphore was decremented despite the error. self.assertEqual(0, pool.requests) @@ -1347,10 +1349,10 @@ def test_auth_network_error(self): # Cause a network error on the actual socket. 
pool = get_pool(c) - socket_info = one(pool.sockets) - socket_info.sock.close() + socket_info = one(pool.conns) + socket_info.conn.close() - # SocketInfo.authenticate logs, but gets a socket.error. Should be + # Connection.authenticate logs, but gets a socket.error. Should be # reraised as AutoReconnect. self.assertRaises(AutoReconnect, c.test.collection.find_one) @@ -1586,7 +1588,7 @@ def stall_connect(*args, **kwargs): self.addCleanup(delattr, pool, "connect") # Wait for the background thread to start creating connections - wait_until(lambda: len(pool.sockets) > 1, "start creating connections") + wait_until(lambda: len(pool.conns) > 1, "start creating connections") # Assert that application operations do not block. for _ in range(10): @@ -1847,7 +1849,7 @@ def test_exhaust_query_server_error(self): collection = client.pymongo_test.test pool = get_pool(client) - sock_info = one(pool.sockets) + conn = one(pool.conns) # This will cause OperationFailure in all mongo versions since # the value for $orderby must be a document. @@ -1856,10 +1858,10 @@ def test_exhaust_query_server_error(self): ) self.assertRaises(OperationFailure, cursor.next) - self.assertFalse(sock_info.closed) + self.assertFalse(conn.closed) # The socket was checked in and the semaphore was decremented. - self.assertIn(sock_info, pool.sockets) + self.assertIn(conn, pool.conns) self.assertEqual(0, pool.requests) def test_exhaust_getmore_server_error(self): @@ -1874,7 +1876,7 @@ def test_exhaust_getmore_server_error(self): pool = get_pool(client) pool._check_interval_seconds = None # Never check. - sock_info = one(pool.sockets) + conn = one(pool.conns) cursor = collection.find(cursor_type=CursorType.EXHAUST) @@ -1884,21 +1886,21 @@ def test_exhaust_getmore_server_error(self): # Cause a server error on getmore. def receive_message(request_id): # Discard the actual server response. 
- SocketInfo.receive_message(sock_info, request_id) + Connection.receive_message(conn, request_id) # responseFlags bit 1 is QueryFailure. msg = struct.pack(" 1) + self.assertTrue(len(cx_pool.conns) > 1) self.assertEqual(0, cx_pool.requests) def test_max_pool_size_none(self): @@ -479,7 +479,7 @@ def f(): joinall(threads) self.assertEqual(nthreads, self.n_passed) - self.assertTrue(len(cx_pool.sockets) > 1) + self.assertTrue(len(cx_pool.conns) > 1) self.assertEqual(cx_pool.max_pool_size, float("inf")) def test_max_pool_size_zero(self): @@ -502,7 +502,7 @@ def test_max_pool_size_with_connection_failure(self): # socket from pool" instead of AutoReconnect. for _i in range(2): with self.assertRaises(AutoReconnect) as context: - with test_pool.get_socket(): + with test_pool.checkout(): pass # Testing for AutoReconnect instead of ConnectionFailure, above, diff --git a/test/test_read_preferences.py b/test/test_read_preferences.py index 6156b6b3fc..a3343d07c9 100644 --- a/test/test_read_preferences.py +++ b/test/test_read_preferences.py @@ -284,18 +284,18 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **client_options) @contextlib.contextmanager - def _socket_for_reads(self, read_preference, session): - context = super()._socket_for_reads(read_preference, session) - with context as (sock_info, read_preference): - self.record_a_read(sock_info.address) - yield sock_info, read_preference + def _conn_for_reads(self, read_preference, session): + context = super()._conn_for_reads(read_preference, session) + with context as (conn, read_preference): + self.record_a_read(conn.address) + yield conn, read_preference @contextlib.contextmanager - def _socket_from_server(self, read_preference, server, session): - context = super()._socket_from_server(read_preference, server, session) - with context as (sock_info, read_preference): - self.record_a_read(sock_info.address) - yield sock_info, read_preference + def _conn_from_server(self, read_preference, server, session): + 
context = super()._conn_from_server(read_preference, server, session) + with context as (conn, read_preference): + self.record_a_read(conn.address) + yield conn, read_preference def record_a_read(self, address): server = self._get_topology().select_server_by_address(address, 0) diff --git a/test/test_server_selection_in_window.py b/test/test_server_selection_in_window.py index 6c015e0ed2..d97c4b4e8a 100644 --- a/test/test_server_selection_in_window.py +++ b/test/test_server_selection_in_window.py @@ -141,7 +141,7 @@ def test_load_balancing(self): ) self.addCleanup(client.close) wait_until(lambda: len(client.nodes) == 2, "discover both nodes") - wait_until(lambda: len(get_pool(client).sockets) >= 10, "create 10 connections") + wait_until(lambda: len(get_pool(client).conns) >= 10, "create 10 connections") # Delay find commands on delay_finds = { "configureFailPoint": "failCommand", diff --git a/test/utils.py b/test/utils.py index 86edae8808..ec17e2862b 100644 --- a/test/utils.py +++ b/test/utils.py @@ -279,12 +279,12 @@ def failed(self, event): self.add_event(event) -class MockSocketInfo: +class MockConnection: def __init__(self): self.cancel_context = _CancellationContext() self.more_to_come = False - def close_socket(self, reason): + def close_conn(self, reason): pass def __enter__(self): @@ -304,10 +304,10 @@ def __init__(self, address, options, handshake=True): def stale_generation(self, gen, service_id): return self.gen.stale(gen, service_id) - def get_socket(self, handler=None): - return MockSocketInfo() + def checkout(self, handler=None): + return MockConnection() - def return_socket(self, *args, **kwargs): + def checkin(self, *args, **kwargs): pass def _reset(self, service_id=None):