diff --git a/doc/changelog.rst b/doc/changelog.rst index c8e17bc30b..082c22fafc 100644 --- a/doc/changelog.rst +++ b/doc/changelog.rst @@ -1,7 +1,7 @@ Changelog ========= -Changes in Version 4.15.1 (2025/09/11) +Changes in Version 4.15.1 (2025/09/16) -------------------------------------- Version 4.15.1 is a bug fix release. diff --git a/pymongo/asynchronous/encryption.py b/pymongo/asynchronous/encryption.py index 2b1895b832..d32a5b3204 100644 --- a/pymongo/asynchronous/encryption.py +++ b/pymongo/asynchronous/encryption.py @@ -64,7 +64,6 @@ from pymongo.asynchronous.cursor import AsyncCursor from pymongo.asynchronous.database import AsyncDatabase from pymongo.asynchronous.mongo_client import AsyncMongoClient -from pymongo.asynchronous.pool import AsyncBaseConnection from pymongo.common import CONNECT_TIMEOUT from pymongo.daemon import _spawn_daemon from pymongo.encryption_options import AutoEncryptionOpts, RangeOpts, TextOpts @@ -77,11 +76,11 @@ ServerSelectionTimeoutError, ) from pymongo.helpers_shared import _get_timeout_details -from pymongo.network_layer import PyMongoKMSProtocol, async_receive_kms, async_sendall +from pymongo.network_layer import async_socket_sendall from pymongo.operations import UpdateOne from pymongo.pool_options import PoolOptions from pymongo.pool_shared import ( - _configured_protocol_interface, + _async_configured_socket, _raise_connection_failure, ) from pymongo.read_concern import ReadConcern @@ -94,8 +93,10 @@ if TYPE_CHECKING: from pymongocrypt.mongocrypt import MongoCryptKmsContext + from pymongo.pyopenssl_context import _sslConn from pymongo.typings import _Address + _IS_SYNC = False _HTTPS_PORT = 443 @@ -110,10 +111,9 @@ _KEY_VAULT_OPTS = CodecOptions(document_class=RawBSONDocument) -async def _connect_kms(address: _Address, opts: PoolOptions) -> AsyncBaseConnection: +async def _connect_kms(address: _Address, opts: PoolOptions) -> Union[socket.socket, _sslConn]: try: - interface = await _configured_protocol_interface(address, opts, PyMongoKMSProtocol) - return AsyncBaseConnection(interface, opts) + return await _async_configured_socket(address, opts) except Exception as exc: _raise_connection_failure(address, exc, timeout_details=_get_timeout_details(opts)) @@ -198,11 +198,19 @@ async def kms_request(self, kms_context: MongoCryptKmsContext) -> None: try: conn = await _connect_kms(address, opts) try: - await async_sendall(conn.conn.get_conn, message) + await async_socket_sendall(conn, message) while kms_context.bytes_needed > 0: # CSOT: update timeout. - conn.set_conn_timeout(max(_csot.clamp_remaining(_KMS_CONNECT_TIMEOUT), 0)) - data = await async_receive_kms(conn, kms_context.bytes_needed) + conn.settimeout(max(_csot.clamp_remaining(_KMS_CONNECT_TIMEOUT), 0)) + data: memoryview | bytes + if _IS_SYNC: + data = conn.recv(kms_context.bytes_needed) + else: + from pymongo.network_layer import ( # type: ignore[attr-defined] + async_receive_data_socket, + ) + + data = await async_receive_data_socket(conn, kms_context.bytes_needed) if not data: raise OSError("KMS connection closed") kms_context.feed(data) @@ -221,7 +229,7 @@ async def kms_request(self, kms_context: MongoCryptKmsContext) -> None: address, exc, msg_prefix=msg_prefix, timeout_details=_get_timeout_details(opts) ) finally: - await conn.close_conn(None) + conn.close() except MongoCryptError: raise # Propagate MongoCryptError errors directly. except Exception as exc: diff --git a/pymongo/asynchronous/pool.py b/pymongo/asynchronous/pool.py index 8c169b4c52..196ec9040f 100644 --- a/pymongo/asynchronous/pool.py +++ b/pymongo/asynchronous/pool.py @@ -123,89 +123,7 @@ def _set_non_inheritable_non_atomic(fd: int) -> None: # noqa: ARG001 _IS_SYNC = False -class AsyncBaseConnection: - """A base connection object for server and kms connections.""" - - def __init__(self, conn: AsyncNetworkingInterface, opts: PoolOptions): - self.conn = conn - self.socket_checker: SocketChecker = SocketChecker() - self.cancel_context: _CancellationContext = _CancellationContext() - self.is_sdam = False - self.closed = False - self.last_timeout: float | None = None - self.more_to_come = False - self.opts = opts - self.max_wire_version = -1 - - def set_conn_timeout(self, timeout: Optional[float]) -> None: - """Cache last timeout to avoid duplicate calls to conn.settimeout.""" - if timeout == self.last_timeout: - return - self.last_timeout = timeout - self.conn.get_conn.settimeout(timeout) - - def apply_timeout( - self, client: AsyncMongoClient[Any], cmd: Optional[MutableMapping[str, Any]] - ) -> Optional[float]: - # CSOT: use remaining timeout when set. - timeout = _csot.remaining() - if timeout is None: - # Reset the socket timeout unless we're performing a streaming monitor check. - if not self.more_to_come: - self.set_conn_timeout(self.opts.socket_timeout) - return None - # RTT validation. - rtt = _csot.get_rtt() - if rtt is None: - rtt = self.connect_rtt - max_time_ms = timeout - rtt - if max_time_ms < 0: - timeout_details = _get_timeout_details(self.opts) - formatted = format_timeout_details(timeout_details) - # CSOT: raise an error without running the command since we know it will time out. - errmsg = f"operation would exceed time limit, remaining timeout:{timeout:.5f} <= network round trip time:{rtt:.5f} {formatted}" - if self.max_wire_version != -1: - raise ExecutionTimeout( - errmsg, - 50, - {"ok": 0, "errmsg": errmsg, "code": 50}, - self.max_wire_version, - ) - else: - raise TimeoutError(errmsg) - if cmd is not None: - cmd["maxTimeMS"] = int(max_time_ms * 1000) - self.set_conn_timeout(timeout) - return timeout - - async def close_conn(self, reason: Optional[str]) -> None: - """Close this connection with a reason.""" - if self.closed: - return - await self._close_conn() - - async def _close_conn(self) -> None: - """Close this connection.""" - if self.closed: - return - self.closed = True - self.cancel_context.cancel() - # Note: We catch exceptions to avoid spurious errors on interpreter - # shutdown. - try: - await self.conn.close() - except Exception: # noqa: S110 - pass - - def conn_closed(self) -> bool: - """Return True if we know socket has been closed, False otherwise.""" - if _IS_SYNC: - return self.socket_checker.socket_closed(self.conn.get_conn) - else: - return self.conn.is_closing() - - -class AsyncConnection(AsyncBaseConnection): +class AsyncConnection: """Store a connection with some metadata. :param conn: a raw connection object @@ -223,27 +141,29 @@ def __init__( id: int, is_sdam: bool, ): - super().__init__(conn, pool.opts) self.pool_ref = weakref.ref(pool) - self.address: tuple[str, int] = address - self.id: int = id + self.conn = conn + self.address = address + self.id = id self.is_sdam = is_sdam + self.closed = False self.last_checkin_time = time.monotonic() self.performed_handshake = False self.is_writable: bool = False self.max_wire_version = MAX_WIRE_VERSION - self.max_bson_size: int = MAX_BSON_SIZE - self.max_message_size: int = MAX_MESSAGE_SIZE - self.max_write_batch_size: int = MAX_WRITE_BATCH_SIZE + self.max_bson_size = MAX_BSON_SIZE + self.max_message_size = MAX_MESSAGE_SIZE + self.max_write_batch_size = MAX_WRITE_BATCH_SIZE self.supports_sessions = False self.hello_ok: bool = False - self.is_mongos: bool = False + self.is_mongos = False self.op_msg_enabled = False self.listeners = pool.opts._event_listeners self.enabled_for_cmap = pool.enabled_for_cmap self.enabled_for_logging = pool.enabled_for_logging self.compression_settings = pool.opts._compression_settings self.compression_context: Union[SnappyContext, ZlibContext, ZstdContext, None] = None + self.socket_checker: SocketChecker = SocketChecker() self.oidc_token_gen_id: Optional[int] = None # Support for mechanism negotiation on the initial handshake. self.negotiated_mechs: Optional[list[str]] = None @@ -254,6 +174,9 @@ def __init__( self.pool_gen = pool.gen self.generation = self.pool_gen.get_overall() self.ready = False + self.cancel_context: _CancellationContext = _CancellationContext() + self.opts = pool.opts + self.more_to_come: bool = False # For load balancer support. self.service_id: Optional[ObjectId] = None self.server_connection_id: Optional[int] = None @@ -269,6 +192,44 @@ def __init__( # For gossiping $clusterTime from the connection handshake to the client. self._cluster_time = None + def set_conn_timeout(self, timeout: Optional[float]) -> None: + """Cache last timeout to avoid duplicate calls to conn.settimeout.""" + if timeout == self.last_timeout: + return + self.last_timeout = timeout + self.conn.get_conn.settimeout(timeout) + + def apply_timeout( + self, client: AsyncMongoClient[Any], cmd: Optional[MutableMapping[str, Any]] + ) -> Optional[float]: + # CSOT: use remaining timeout when set. + timeout = _csot.remaining() + if timeout is None: + # Reset the socket timeout unless we're performing a streaming monitor check. + if not self.more_to_come: + self.set_conn_timeout(self.opts.socket_timeout) + return None + # RTT validation. + rtt = _csot.get_rtt() + if rtt is None: + rtt = self.connect_rtt + max_time_ms = timeout - rtt + if max_time_ms < 0: + timeout_details = _get_timeout_details(self.opts) + formatted = format_timeout_details(timeout_details) + # CSOT: raise an error without running the command since we know it will time out. + errmsg = f"operation would exceed time limit, remaining timeout:{timeout:.5f} <= network round trip time:{rtt:.5f} {formatted}" + raise ExecutionTimeout( + errmsg, + 50, + {"ok": 0, "errmsg": errmsg, "code": 50}, + self.max_wire_version, + ) + if cmd is not None: + cmd["maxTimeMS"] = int(max_time_ms * 1000) + self.set_conn_timeout(timeout) + return timeout + def pin_txn(self) -> None: self.pinned_txn = True assert not self.pinned_cursor @@ -612,6 +573,26 @@ async def close_conn(self, reason: Optional[str]) -> None: error=reason, ) + async def _close_conn(self) -> None: + """Close this connection.""" + if self.closed: + return + self.closed = True + self.cancel_context.cancel() + # Note: We catch exceptions to avoid spurious errors on interpreter + # shutdown. + try: + await self.conn.close() + except Exception: # noqa: S110 + pass + + def conn_closed(self) -> bool: + """Return True if we know socket has been closed, False otherwise.""" + if _IS_SYNC: + return self.socket_checker.socket_closed(self.conn.get_conn) + else: + return self.conn.is_closing() + def send_cluster_time( self, command: MutableMapping[str, Any], diff --git a/pymongo/network_layer.py b/pymongo/network_layer.py index a3900e30c1..2e5b61f8ae 100644 --- a/pymongo/network_layer.py +++ b/pymongo/network_layer.py @@ -22,11 +22,10 @@ import struct import sys import time -from asyncio import BaseProtocol, BaseTransport, BufferedProtocol, Future, Transport +from asyncio import AbstractEventLoop, BaseTransport, BufferedProtocol, Future, Transport from typing import ( TYPE_CHECKING, Any, - Callable, Optional, Union, ) @@ -39,30 +38,208 @@ from pymongo.message import _UNPACK_REPLY, _OpMsg, _OpReply from pymongo.socket_checker import _errno_from_exception -if TYPE_CHECKING: - from pymongo.asynchronous.pool import AsyncBaseConnection, AsyncConnection +try: + from ssl import SSLError, SSLSocket + + _HAVE_SSL = True +except ImportError: + _HAVE_SSL = False + +try: from pymongo.pyopenssl_context import _sslConn - from pymongo.synchronous.pool import BaseConnection, Connection + + _HAVE_PYOPENSSL = True +except ImportError: + _HAVE_PYOPENSSL = False + _sslConn = SSLSocket # type: ignore[assignment, misc] + +from pymongo.ssl_support import ( + BLOCKING_IO_LOOKUP_ERROR, + BLOCKING_IO_READ_ERROR, + BLOCKING_IO_WRITE_ERROR, +) + +if TYPE_CHECKING: + from pymongo.asynchronous.pool import AsyncConnection + from pymongo.synchronous.pool import Connection _UNPACK_HEADER = struct.Struct(" None: + timeout = sock.gettimeout() + sock.settimeout(0.0) + loop = asyncio.get_running_loop() + try: + if _HAVE_SSL and isinstance(sock, (SSLSocket, _sslConn)): + await asyncio.wait_for(_async_socket_sendall_ssl(sock, buf, loop), timeout=timeout) + else: + await asyncio.wait_for(loop.sock_sendall(sock, buf), timeout=timeout) # type: ignore[arg-type] + except asyncio.TimeoutError as exc: + # Convert the asyncio.wait_for timeout error to socket.timeout which pool.py understands. + raise socket.timeout("timed out") from exc + finally: + sock.settimeout(timeout) + + +if sys.platform != "win32": + + async def _async_socket_sendall_ssl( + sock: Union[socket.socket, _sslConn], buf: bytes, loop: AbstractEventLoop + ) -> None: + view = memoryview(buf) + sent = 0 + + def _is_ready(fut: Future[Any]) -> None: + if fut.done(): + return + fut.set_result(None) + + while sent < len(buf): + try: + sent += sock.send(view[sent:]) # type:ignore[arg-type] + except BLOCKING_IO_ERRORS as exc: + fd = sock.fileno() + # Check for closed socket. + if fd == -1: + raise SSLError("Underlying socket has been closed") from None + if isinstance(exc, BLOCKING_IO_READ_ERROR): + fut = loop.create_future() + loop.add_reader(fd, _is_ready, fut) + try: + await fut + finally: + loop.remove_reader(fd) + if isinstance(exc, BLOCKING_IO_WRITE_ERROR): + fut = loop.create_future() + loop.add_writer(fd, _is_ready, fut) + try: + await fut + finally: + loop.remove_writer(fd) + if _HAVE_PYOPENSSL and isinstance(exc, BLOCKING_IO_LOOKUP_ERROR): + fut = loop.create_future() + loop.add_reader(fd, _is_ready, fut) + try: + loop.add_writer(fd, _is_ready, fut) + await fut + finally: + loop.remove_reader(fd) + loop.remove_writer(fd) + + async def _async_socket_receive_ssl( + conn: _sslConn, length: int, loop: AbstractEventLoop, once: Optional[bool] = False + ) -> memoryview: + mv = memoryview(bytearray(length)) + total_read = 0 + + def _is_ready(fut: Future[Any]) -> None: + if fut.done(): + return + fut.set_result(None) + + while total_read < length: + try: + read = conn.recv_into(mv[total_read:]) + if read == 0: + raise OSError("connection closed") + # KMS responses update their expected size after the first batch, stop reading after one loop + if once: + return mv[:read] + total_read += read + except BLOCKING_IO_ERRORS as exc: + fd = conn.fileno() + # Check for closed socket. + if fd == -1: + raise SSLError("Underlying socket has been closed") from None + if isinstance(exc, BLOCKING_IO_READ_ERROR): + fut = loop.create_future() + loop.add_reader(fd, _is_ready, fut) + try: + await fut + finally: + loop.remove_reader(fd) + if isinstance(exc, BLOCKING_IO_WRITE_ERROR): + fut = loop.create_future() + loop.add_writer(fd, _is_ready, fut) + try: + await fut + finally: + loop.remove_writer(fd) + if _HAVE_PYOPENSSL and isinstance(exc, BLOCKING_IO_LOOKUP_ERROR): + fut = loop.create_future() + loop.add_reader(fd, _is_ready, fut) + try: + loop.add_writer(fd, _is_ready, fut) + await fut + finally: + loop.remove_reader(fd) + loop.remove_writer(fd) + return mv + +else: + # The default Windows asyncio event loop does not support loop.add_reader/add_writer: + # https://docs.python.org/3/library/asyncio-platforms.html#asyncio-platform-support + # Note: In PYTHON-4493 we plan to replace this code with asyncio streams. + async def _async_socket_sendall_ssl( + sock: Union[socket.socket, _sslConn], buf: bytes, dummy: AbstractEventLoop + ) -> None: + view = memoryview(buf) + total_length = len(buf) + total_sent = 0 + # Backoff starts at 1ms, doubles on timeout up to 512ms, and halves on success + # down to 1ms. + backoff = 0.001 + while total_sent < total_length: + try: + sent = sock.send(view[total_sent:]) + except BLOCKING_IO_ERRORS: + await asyncio.sleep(backoff) + sent = 0 + if sent > 0: + backoff = max(backoff / 2, 0.001) + else: + backoff = min(backoff * 2, 0.512) + total_sent += sent + + async def _async_socket_receive_ssl( + conn: _sslConn, length: int, dummy: AbstractEventLoop, once: Optional[bool] = False + ) -> memoryview: + mv = memoryview(bytearray(length)) + total_read = 0 + # Backoff starts at 1ms, doubles on timeout up to 512ms, and halves on success + # down to 1ms. + backoff = 0.001 + while total_read < length: + try: + read = conn.recv_into(mv[total_read:]) + if read == 0: + raise OSError("connection closed") + # KMS responses update their expected size after the first batch, stop reading after one loop + if once: + return mv[:read] + except BLOCKING_IO_ERRORS: + await asyncio.sleep(backoff) + read = 0 + if read > 0: + backoff = max(backoff / 2, 0.001) + else: + backoff = min(backoff * 2, 0.512) + total_read += read + return mv def sendall(sock: Union[socket.socket, _sslConn], buf: bytes) -> None: sock.sendall(buf) -async def _poll_cancellation(conn: AsyncBaseConnection) -> None: +async def _poll_cancellation(conn: AsyncConnection) -> None: while True: if conn.cancel_context.cancelled: return @@ -70,7 +247,49 @@ async def _poll_cancellation(conn: AsyncBaseConnection) -> None: await asyncio.sleep(_POLL_TIMEOUT) -def wait_for_read(conn: BaseConnection, deadline: Optional[float]) -> None: +async def async_receive_data_socket( + sock: Union[socket.socket, _sslConn], length: int +) -> memoryview: + sock_timeout = sock.gettimeout() + timeout = sock_timeout + + sock.settimeout(0.0) + loop = asyncio.get_running_loop() + try: + if _HAVE_SSL and isinstance(sock, (SSLSocket, _sslConn)): + return await asyncio.wait_for( + _async_socket_receive_ssl(sock, length, loop, once=True), # type: ignore[arg-type] + timeout=timeout, + ) + else: + return await asyncio.wait_for( + _async_socket_receive(sock, length, loop), # type: ignore[arg-type] + timeout=timeout, + ) + except asyncio.TimeoutError as err: + raise socket.timeout("timed out") from err + finally: + sock.settimeout(sock_timeout) + + +async def _async_socket_receive( + conn: socket.socket, length: int, loop: AbstractEventLoop +) -> memoryview: + mv = memoryview(bytearray(length)) + bytes_read = 0 + while bytes_read < length: + chunk_length = await loop.sock_recv_into(conn, mv[bytes_read:]) + if chunk_length == 0: + raise OSError("connection closed") + bytes_read += chunk_length + return mv + + +_PYPY = "PyPy" in sys.version +_WINDOWS = sys.platform == "win32" + + +def wait_for_read(conn: Connection, deadline: Optional[float]) -> None: """Block until at least one byte is read, or a timeout, or a cancel.""" sock = conn.conn.sock timed_out = False @@ -103,7 +322,7 @@ def wait_for_read(conn: BaseConnection, deadline: Optional[float]) -> None: raise socket.timeout("timed out") -def receive_data(conn: BaseConnection, length: int, deadline: Optional[float]) -> memoryview: +def receive_data(conn: Connection, length: int, deadline: Optional[float]) -> memoryview: buf = bytearray(length) mv = memoryview(buf) bytes_read = 0 @@ -193,7 +412,7 @@ def sock(self) -> Any: class AsyncNetworkingInterface(NetworkingInterfaceBase): - def __init__(self, conn: tuple[Transport, PyMongoBaseProtocol]): + def __init__(self, conn: tuple[Transport, PyMongoProtocol]): super().__init__(conn) @property @@ -211,7 +430,7 @@ def is_closing(self) -> bool: return self.conn[0].is_closing() @property - def get_conn(self) -> PyMongoBaseProtocol: + def get_conn(self) -> PyMongoProtocol: return self.conn[1] @property @@ -250,51 +469,9 @@ def recv_into(self, buffer: bytes | memoryview) -> int: return self.conn.recv_into(buffer) -class PyMongoBaseProtocol(BaseProtocol): +class PyMongoProtocol(BufferedProtocol): def __init__(self, timeout: Optional[float] = None): self.transport: Transport = None # type: ignore[assignment] - self._timeout = timeout - self._closed = asyncio.get_running_loop().create_future() - self._connection_lost = False - - def settimeout(self, timeout: float | None) -> None: - self._timeout = timeout - - @property - def gettimeout(self) -> float | None: - """The configured timeout for the socket that underlies our protocol pair.""" - return self._timeout - - def close(self, exc: Optional[Exception] = None) -> None: - self.transport.abort() - self._resolve_pending(exc) - self._connection_lost = True - - def connection_lost(self, exc: Optional[Exception] = None) -> None: - self._resolve_pending(exc) - if not self._closed.done(): - self._closed.set_result(None) - - def _resolve_pending(self, exc: Optional[Exception] = None) -> None: - pass - - async def wait_closed(self) -> None: - await self._closed - - async def write(self, message: bytes) -> None: - """Write a message to this connection's transport.""" - if self.transport.is_closing(): - raise OSError("Connection is closed") - self.transport.write(message) - self.transport.resume_reading() - - async def read(self, *args: Any) -> Any: - raise NotImplementedError - - -class PyMongoProtocol(PyMongoBaseProtocol, BufferedProtocol): - def __init__(self, timeout: Optional[float] = None): - super().__init__(timeout) # Each message is reader in 2-3 parts: header, compression header, and message body # The message buffer is allocated after the header is read. self._header = memoryview(bytearray(16)) @@ -308,14 +485,25 @@ def __init__(self, timeout: Optional[float] = None): self._expecting_compression = False self._message_size = 0 self._op_code = 0 + self._connection_lost = False self._read_waiter: Optional[Future[Any]] = None + self._timeout = timeout self._is_compressed = False self._compressor_id: Optional[int] = None self._max_message_size = MAX_MESSAGE_SIZE self._response_to: Optional[int] = None + self._closed = asyncio.get_running_loop().create_future() self._pending_messages: collections.deque[Future[Any]] = collections.deque() self._done_messages: collections.deque[Future[Any]] = collections.deque() + def settimeout(self, timeout: float | None) -> None: + self._timeout = timeout + + @property + def gettimeout(self) -> float | None: + """The configured timeout for the socket that underlies our protocol pair.""" + return self._timeout + def connection_made(self, transport: BaseTransport) -> None: """Called exactly once when a connection is made. The transport argument is the transport representing the write side of the connection. @@ -323,6 +511,13 @@ def connection_made(self, transport: BaseTransport) -> None: self.transport = transport # type: ignore[assignment] self.transport.set_write_buffer_limits(MAX_MESSAGE_SIZE, MAX_MESSAGE_SIZE) + async def write(self, message: bytes) -> None: + """Write a message to this connection's transport.""" + if self.transport.is_closing(): + raise OSError("Connection is closed") + self.transport.write(message) + self.transport.resume_reading() + async def read(self, request_id: Optional[int], max_message_size: int) -> tuple[bytes, int]: """Read a single MongoDB Wire Protocol message from this connection.""" if self.transport: @@ -465,7 +660,7 @@ def process_compression_header(self) -> tuple[int, int]: op_code, _, compressor_id = _UNPACK_COMPRESSION_HEADER(self._compression_header) return op_code, compressor_id - def _resolve_pending(self, exc: Optional[Exception] = None) -> None: + def _resolve_pending_messages(self, exc: Optional[Exception] = None) -> None: pending = list(self._pending_messages) for msg in pending: if not msg.done(): @@ -475,92 +670,21 @@ def _resolve_pending(self, exc: Optional[Exception] = None) -> None: msg.set_exception(exc) self._done_messages.append(msg) + def close(self, exc: Optional[Exception] = None) -> None: + self.transport.abort() + self._resolve_pending_messages(exc) + self._connection_lost = True -class PyMongoKMSProtocol(PyMongoBaseProtocol): - def __init__(self, timeout: Optional[float] = None): - super().__init__(timeout) - self._buffers: collections.deque[memoryview[bytes]] = collections.deque() - self._bytes_ready = 0 - self._pending_reads: collections.deque[int] = collections.deque() - self._pending_listeners: collections.deque[Future[Any]] = collections.deque() - - def connection_made(self, transport: BaseTransport) -> None: - """Called exactly once when a connection is made. - The transport argument is the transport representing the write side of the connection. - """ - self.transport = transport # type: ignore[assignment] - - def data_received(self, data: bytes) -> None: - if self._connection_lost: - return - - self._bytes_ready += len(data) - self._buffers.append(memoryview(data)) - - if not len(self._pending_reads): - return + def connection_lost(self, exc: Optional[Exception] = None) -> None: + self._resolve_pending_messages(exc) + if not self._closed.done(): + self._closed.set_result(None) - bytes_needed = self._pending_reads.popleft() - data = self._read(bytes_needed) - waiter = self._pending_listeners.popleft() - waiter.set_result(data) - - async def read(self, bytes_needed: int) -> bytes: - """Read up to the requested bytes from this connection.""" - # Note: all reads are "up-to" bytes_needed because we don't know if the kms_context - # has processed a Content-Length header and is requesting a response or not. - # Wait for other listeners first. - if len(self._pending_listeners): - await asyncio.gather(*self._pending_listeners) - # If there are bytes ready, then there is no need to wait further. - if self._bytes_ready > 0: - return self._read(bytes_needed) - if self.transport: - try: - self.transport.resume_reading() - # Known bug in SSL Protocols, fixed in Python 3.11: https://github.com/python/cpython/issues/89322 - except AttributeError: - raise OSError("connection is already closed") from None - if self.transport and self.transport.is_closing(): - raise OSError("connection is already closed") - self._pending_reads.append(bytes_needed) - read_waiter = asyncio.get_running_loop().create_future() - self._pending_listeners.append(read_waiter) - return await read_waiter - - def _resolve_pending(self, exc: Optional[Exception] = None) -> None: - while self._pending_listeners: - fut = self._pending_listeners.popleft() - fut.set_result(b"") - - def _read(self, bytes_needed: int) -> bytes: - """Read bytes.""" - # Send the bytes to the listener. - if self._bytes_ready < bytes_needed: - bytes_needed = self._bytes_ready - self._bytes_ready -= bytes_needed - - output_buf = memoryview(bytearray(bytes_needed)) - n_remaining = bytes_needed - out_index = 0 - while n_remaining > 0: - buffer = self._buffers.popleft() - buf_size = len(buffer) - # if we didn't exhaust the buffer, read the partial data and return the buffer. - if buf_size > n_remaining: - output_buf[out_index : n_remaining + out_index] = buffer[:n_remaining] - buffer = buffer[n_remaining:] - n_remaining = 0 - self._buffers.appendleft(buffer) - # otherwise exhaust the buffer. - else: - output_buf[out_index : out_index + buf_size] = buffer[:] - out_index += buf_size - n_remaining -= buf_size - return bytes(output_buf) + async def wait_closed(self) -> None: + await self._closed -async def async_sendall(conn: PyMongoBaseProtocol, buf: bytes) -> None: +async def async_sendall(conn: PyMongoProtocol, buf: bytes) -> None: try: await asyncio.wait_for(conn.write(buf), timeout=conn.gettimeout) except asyncio.TimeoutError as exc: @@ -568,18 +692,12 @@ async def async_sendall(conn: PyMongoBaseProtocol, buf: bytes) -> None: raise socket.timeout("timed out") from exc -async def async_receive_kms(conn: AsyncBaseConnection, bytes_needed: int) -> bytes: - """Receive raw bytes from the kms connection.""" - - def callback(result: Any) -> bytes: - return result - - return await _async_receive_data(conn, callback, bytes_needed) - - -async def _async_receive_data( - conn: AsyncBaseConnection, callback: Callable[..., Any], *args: Any -) -> Any: +async def async_receive_message( + conn: AsyncConnection, + request_id: Optional[int], + max_message_size: int = MAX_MESSAGE_SIZE, +) -> Union[_OpReply, _OpMsg]: + """Receive a raw BSON message or raise socket.error.""" timeout: Optional[Union[float, int]] timeout = conn.conn.gettimeout if _csot.get_timeout(): @@ -595,8 +713,8 @@ async def _async_receive_data( # timeouts on AWS Lambda and other FaaS environments. timeout = max(deadline - time.monotonic(), 0) - read_task = create_task(conn.conn.get_conn.read(*args)) cancellation_task = create_task(_poll_cancellation(conn)) + read_task = create_task(conn.conn.get_conn.read(request_id, max_message_size)) tasks = [read_task, cancellation_task] try: done, pending = await asyncio.wait( @@ -609,7 +727,14 @@ async def _async_receive_data( if len(done) == 0: raise socket.timeout("timed out") if read_task in done: - return callback(read_task.result()) + data, op_code = read_task.result() + try: + unpack_reply = _UNPACK_REPLY[op_code] + except KeyError: + raise ProtocolError( + f"Got opcode {op_code!r} but expected {_UNPACK_REPLY.keys()!r}" + ) from None + return unpack_reply(data) raise _OperationCancelled("operation cancelled") except asyncio.CancelledError: for task in tasks: @@ -618,31 +743,6 @@ async def _async_receive_data( raise -async def async_receive_message( - conn: AsyncConnection, - request_id: Optional[int], - max_message_size: int = MAX_MESSAGE_SIZE, -) -> Union[_OpReply, _OpMsg]: - """Receive a raw BSON message or raise socket.error.""" - - def callback(result: Any) -> _OpMsg | _OpReply: - data, op_code = result - try: - unpack_reply = _UNPACK_REPLY[op_code] - except KeyError: - raise ProtocolError( - f"Got opcode {op_code!r} but expected {_UNPACK_REPLY.keys()!r}" - ) from None - return unpack_reply(data) - - return await _async_receive_data(conn, callback, request_id, max_message_size) - - -def receive_kms(conn: BaseConnection, bytes_needed: int) -> bytes: - """Receive raw bytes from the kms connection.""" - return conn.conn.sock.recv(bytes_needed) - - def receive_message( conn: Connection, request_id: Optional[int], max_message_size: int = MAX_MESSAGE_SIZE ) -> Union[_OpReply, _OpMsg]: @@ -670,7 +770,7 @@ def receive_message( f"Message length ({length!r}) is larger than server max " f"message size ({max_message_size!r})" ) - data: bytes | memoryview + data: memoryview | bytes if op_code == 2012: op_code, _, compressor_id = _UNPACK_COMPRESSION_HEADER(receive_data(conn, 9, deadline)) data = decompress(receive_data(conn, length - 25, deadline), compressor_id) diff --git a/pymongo/pool_shared.py b/pymongo/pool_shared.py index 0536dc3835..ac562af542 100644 --- a/pymongo/pool_shared.py +++ b/pymongo/pool_shared.py @@ -16,6 +16,7 @@ from __future__ import annotations import asyncio +import functools import socket import ssl import sys @@ -24,6 +25,7 @@ Any, NoReturn, Optional, + Union, ) from pymongo import _csot @@ -35,17 +37,13 @@ _CertificateError, ) from pymongo.helpers_shared import _get_timeout_details, format_timeout_details -from pymongo.network_layer import ( - AsyncNetworkingInterface, - NetworkingInterface, - PyMongoBaseProtocol, - PyMongoProtocol, -) +from pymongo.network_layer import AsyncNetworkingInterface, NetworkingInterface, PyMongoProtocol from pymongo.pool_options import PoolOptions from pymongo.ssl_support import PYSSLError, SSLError, _has_sni SSLErrors = (PYSSLError, SSLError) if TYPE_CHECKING: + from pymongo.pyopenssl_context import _sslConn from pymongo.typings import _Address try: @@ -246,10 +244,64 @@ async def _async_create_connection(address: _Address, options: PoolOptions) -> s raise OSError("getaddrinfo failed") +async def _async_configured_socket( + address: _Address, options: PoolOptions +) -> Union[socket.socket, _sslConn]: + """Given (host, port) and PoolOptions, return a raw configured socket. + + Can raise socket.error, ConnectionFailure, or _CertificateError. + + Sets socket's SSL and timeout options. + """ + sock = await _async_create_connection(address, options) + ssl_context = options._ssl_context + + if ssl_context is None: + sock.settimeout(options.socket_timeout) + return sock + + host = address[0] + try: + # We have to pass hostname / ip address to wrap_socket + # to use SSLContext.check_hostname. + if _has_sni(False): + loop = asyncio.get_running_loop() + ssl_sock = await loop.run_in_executor( + None, + functools.partial(ssl_context.wrap_socket, sock, server_hostname=host), # type: ignore[assignment, misc, unused-ignore] + ) + else: + loop = asyncio.get_running_loop() + ssl_sock = await loop.run_in_executor(None, ssl_context.wrap_socket, sock) # type: ignore[assignment, misc, unused-ignore] + except _CertificateError: + sock.close() + # Raise _CertificateError directly like we do after match_hostname + # below. + raise + except (OSError, *SSLErrors) as exc: + sock.close() + # We raise AutoReconnect for transient and permanent SSL handshake + # failures alike. Permanent handshake failures, like protocol + # mismatch, will be turned into ServerSelectionTimeoutErrors later. + details = _get_timeout_details(options) + _raise_connection_failure(address, exc, "SSL handshake failed: ", timeout_details=details) + if ( + ssl_context.verify_mode + and not ssl_context.check_hostname + and not options.tls_allow_invalid_hostnames + ): + try: + ssl.match_hostname(ssl_sock.getpeercert(), hostname=host) # type:ignore[attr-defined, unused-ignore] + except _CertificateError: + ssl_sock.close() + raise + + ssl_sock.settimeout(options.socket_timeout) + return ssl_sock + + async def _configured_protocol_interface( - address: _Address, - options: PoolOptions, - protocol_kls: type[PyMongoBaseProtocol] = PyMongoProtocol, + address: _Address, options: PoolOptions ) -> AsyncNetworkingInterface: """Given (host, port) and PoolOptions, return a configured AsyncNetworkingInterface. @@ -264,7 +316,7 @@ async def _configured_protocol_interface( if ssl_context is None: return AsyncNetworkingInterface( await asyncio.get_running_loop().create_connection( - lambda: protocol_kls(timeout=timeout), sock=sock + lambda: PyMongoProtocol(timeout=timeout), sock=sock ) ) @@ -273,7 +325,7 @@ async def _configured_protocol_interface( # We have to pass hostname / ip address to wrap_socket # to use SSLContext.check_hostname. transport, protocol = await asyncio.get_running_loop().create_connection( # type: ignore[call-overload] - lambda: protocol_kls(timeout=timeout), + lambda: PyMongoProtocol(timeout=timeout), sock=sock, server_hostname=host, ssl=ssl_context, @@ -373,9 +425,56 @@ def _create_connection(address: _Address, options: PoolOptions) -> socket.socket raise OSError("getaddrinfo failed") -def _configured_socket_interface( - address: _Address, options: PoolOptions, *args: Any -) -> NetworkingInterface: +def _configured_socket(address: _Address, options: PoolOptions) -> Union[socket.socket, _sslConn]: + """Given (host, port) and PoolOptions, return a raw configured socket. + + Can raise socket.error, ConnectionFailure, or _CertificateError. + + Sets socket's SSL and timeout options. + """ + sock = _create_connection(address, options) + ssl_context = options._ssl_context + + if ssl_context is None: + sock.settimeout(options.socket_timeout) + return sock + + host = address[0] + try: + # We have to pass hostname / ip address to wrap_socket + # to use SSLContext.check_hostname. + if _has_sni(True): + ssl_sock = ssl_context.wrap_socket(sock, server_hostname=host) # type: ignore[assignment, misc, unused-ignore] + else: + ssl_sock = ssl_context.wrap_socket(sock) # type: ignore[assignment, misc, unused-ignore] + except _CertificateError: + sock.close() + # Raise _CertificateError directly like we do after match_hostname + # below. + raise + except (OSError, *SSLErrors) as exc: + sock.close() + # We raise AutoReconnect for transient and permanent SSL handshake + # failures alike. Permanent handshake failures, like protocol + # mismatch, will be turned into ServerSelectionTimeoutErrors later. + details = _get_timeout_details(options) + _raise_connection_failure(address, exc, "SSL handshake failed: ", timeout_details=details) + if ( + ssl_context.verify_mode + and not ssl_context.check_hostname + and not options.tls_allow_invalid_hostnames + ): + try: + ssl.match_hostname(ssl_sock.getpeercert(), hostname=host) # type:ignore[attr-defined, unused-ignore] + except _CertificateError: + ssl_sock.close() + raise + + ssl_sock.settimeout(options.socket_timeout) + return ssl_sock + + +def _configured_socket_interface(address: _Address, options: PoolOptions) -> NetworkingInterface: """Given (host, port) and PoolOptions, return a NetworkingInterface wrapping a configured socket. Can raise socket.error, ConnectionFailure, or _CertificateError. diff --git a/pymongo/synchronous/encryption.py b/pymongo/synchronous/encryption.py index a08302c211..f9d51a9eab 100644 --- a/pymongo/synchronous/encryption.py +++ b/pymongo/synchronous/encryption.py @@ -71,11 +71,11 @@ ServerSelectionTimeoutError, ) from pymongo.helpers_shared import _get_timeout_details -from pymongo.network_layer import PyMongoKMSProtocol, receive_kms, sendall +from pymongo.network_layer import sendall from pymongo.operations import UpdateOne from pymongo.pool_options import PoolOptions from pymongo.pool_shared import ( - _configured_socket_interface, + _configured_socket, _raise_connection_failure, ) from pymongo.read_concern import ReadConcern @@ -85,7 +85,6 @@ from pymongo.synchronous.cursor import Cursor from pymongo.synchronous.database import Database from pymongo.synchronous.mongo_client import MongoClient -from pymongo.synchronous.pool import BaseConnection from pymongo.typings import _DocumentType, _DocumentTypeArg from pymongo.uri_parser_shared import _parse_kms_tls_options, parse_host from pymongo.write_concern import WriteConcern @@ -93,8 +92,10 @@ if TYPE_CHECKING: from pymongocrypt.mongocrypt import MongoCryptKmsContext + from pymongo.pyopenssl_context import _sslConn from pymongo.typings import _Address + _IS_SYNC = True _HTTPS_PORT = 443 @@ -109,10 +110,9 @@ _KEY_VAULT_OPTS = CodecOptions(document_class=RawBSONDocument) -def _connect_kms(address: _Address, opts: PoolOptions) -> BaseConnection: +def _connect_kms(address: _Address, opts: PoolOptions) -> Union[socket.socket, _sslConn]: try: - interface = _configured_socket_interface(address, opts, PyMongoKMSProtocol) - return BaseConnection(interface, opts) + return _configured_socket(address, opts) except Exception as exc: _raise_connection_failure(address, exc, timeout_details=_get_timeout_details(opts)) @@ -197,11 +197,19 @@ def kms_request(self, kms_context: MongoCryptKmsContext) -> None: try: conn = _connect_kms(address, opts) try: - sendall(conn.conn.get_conn, message) + sendall(conn, message) while kms_context.bytes_needed > 0: # CSOT: update timeout. - conn.set_conn_timeout(max(_csot.clamp_remaining(_KMS_CONNECT_TIMEOUT), 0)) - data = receive_kms(conn, kms_context.bytes_needed) + conn.settimeout(max(_csot.clamp_remaining(_KMS_CONNECT_TIMEOUT), 0)) + data: memoryview | bytes + if _IS_SYNC: + data = conn.recv(kms_context.bytes_needed) + else: + from pymongo.network_layer import ( # type: ignore[attr-defined] + receive_data_socket, + ) + + data = receive_data_socket(conn, kms_context.bytes_needed) if not data: raise OSError("KMS connection closed") kms_context.feed(data) @@ -220,7 +228,7 @@ def kms_request(self, kms_context: MongoCryptKmsContext) -> None: address, exc, msg_prefix=msg_prefix, timeout_details=_get_timeout_details(opts) ) finally: - conn.close_conn(None) + conn.close() except MongoCryptError: raise # Propagate MongoCryptError errors directly. except Exception as exc: diff --git a/pymongo/synchronous/pool.py b/pymongo/synchronous/pool.py index f35ca4d0fd..f7f6a26c68 100644 --- a/pymongo/synchronous/pool.py +++ b/pymongo/synchronous/pool.py @@ -123,89 +123,7 @@ def _set_non_inheritable_non_atomic(fd: int) -> None: # noqa: ARG001 _IS_SYNC = True -class BaseConnection: - """A base connection object for server and kms connections.""" - - def __init__(self, conn: NetworkingInterface, opts: PoolOptions): - self.conn = conn - self.socket_checker: SocketChecker = SocketChecker() - self.cancel_context: _CancellationContext = _CancellationContext() - self.is_sdam = False - self.closed = False - self.last_timeout: float | None = None - self.more_to_come = False - self.opts = opts - self.max_wire_version = -1 - - def set_conn_timeout(self, timeout: Optional[float]) -> None: - """Cache last timeout to avoid duplicate calls to conn.settimeout.""" - if timeout == self.last_timeout: - return - self.last_timeout = timeout - self.conn.get_conn.settimeout(timeout) - - def apply_timeout( - self, client: MongoClient[Any], cmd: Optional[MutableMapping[str, Any]] - ) -> Optional[float]: - # CSOT: use remaining timeout when set. - timeout = _csot.remaining() - if timeout is None: - # Reset the socket timeout unless we're performing a streaming monitor check. - if not self.more_to_come: - self.set_conn_timeout(self.opts.socket_timeout) - return None - # RTT validation. - rtt = _csot.get_rtt() - if rtt is None: - rtt = self.connect_rtt - max_time_ms = timeout - rtt - if max_time_ms < 0: - timeout_details = _get_timeout_details(self.opts) - formatted = format_timeout_details(timeout_details) - # CSOT: raise an error without running the command since we know it will time out. - errmsg = f"operation would exceed time limit, remaining timeout:{timeout:.5f} <= network round trip time:{rtt:.5f} {formatted}" - if self.max_wire_version != -1: - raise ExecutionTimeout( - errmsg, - 50, - {"ok": 0, "errmsg": errmsg, "code": 50}, - self.max_wire_version, - ) - else: - raise TimeoutError(errmsg) - if cmd is not None: - cmd["maxTimeMS"] = int(max_time_ms * 1000) - self.set_conn_timeout(timeout) - return timeout - - def close_conn(self, reason: Optional[str]) -> None: - """Close this connection with a reason.""" - if self.closed: - return - self._close_conn() - - def _close_conn(self) -> None: - """Close this connection.""" - if self.closed: - return - self.closed = True - self.cancel_context.cancel() - # Note: We catch exceptions to avoid spurious errors on interpreter - # shutdown. - try: - self.conn.close() - except Exception: # noqa: S110 - pass - - def conn_closed(self) -> bool: - """Return True if we know socket has been closed, False otherwise.""" - if _IS_SYNC: - return self.socket_checker.socket_closed(self.conn.get_conn) - else: - return self.conn.is_closing() - - -class Connection(BaseConnection): +class Connection: """Store a connection with some metadata. :param conn: a raw connection object @@ -223,27 +141,29 @@ def __init__( id: int, is_sdam: bool, ): - super().__init__(conn, pool.opts) self.pool_ref = weakref.ref(pool) - self.address: tuple[str, int] = address - self.id: int = id + self.conn = conn + self.address = address + self.id = id self.is_sdam = is_sdam + self.closed = False self.last_checkin_time = time.monotonic() self.performed_handshake = False self.is_writable: bool = False self.max_wire_version = MAX_WIRE_VERSION - self.max_bson_size: int = MAX_BSON_SIZE - self.max_message_size: int = MAX_MESSAGE_SIZE - self.max_write_batch_size: int = MAX_WRITE_BATCH_SIZE + self.max_bson_size = MAX_BSON_SIZE + self.max_message_size = MAX_MESSAGE_SIZE + self.max_write_batch_size = MAX_WRITE_BATCH_SIZE self.supports_sessions = False self.hello_ok: bool = False - self.is_mongos: bool = False + self.is_mongos = False self.op_msg_enabled = False self.listeners = pool.opts._event_listeners self.enabled_for_cmap = pool.enabled_for_cmap self.enabled_for_logging = pool.enabled_for_logging self.compression_settings = pool.opts._compression_settings self.compression_context: Union[SnappyContext, ZlibContext, ZstdContext, None] = None + self.socket_checker: SocketChecker = SocketChecker() self.oidc_token_gen_id: Optional[int] = None # Support for mechanism negotiation on the initial handshake. self.negotiated_mechs: Optional[list[str]] = None @@ -254,6 +174,9 @@ def __init__( self.pool_gen = pool.gen self.generation = self.pool_gen.get_overall() self.ready = False + self.cancel_context: _CancellationContext = _CancellationContext() + self.opts = pool.opts + self.more_to_come: bool = False # For load balancer support. self.service_id: Optional[ObjectId] = None self.server_connection_id: Optional[int] = None @@ -269,6 +192,44 @@ def __init__( # For gossiping $clusterTime from the connection handshake to the client. self._cluster_time = None + def set_conn_timeout(self, timeout: Optional[float]) -> None: + """Cache last timeout to avoid duplicate calls to conn.settimeout.""" + if timeout == self.last_timeout: + return + self.last_timeout = timeout + self.conn.get_conn.settimeout(timeout) + + def apply_timeout( + self, client: MongoClient[Any], cmd: Optional[MutableMapping[str, Any]] + ) -> Optional[float]: + # CSOT: use remaining timeout when set. + timeout = _csot.remaining() + if timeout is None: + # Reset the socket timeout unless we're performing a streaming monitor check. + if not self.more_to_come: + self.set_conn_timeout(self.opts.socket_timeout) + return None + # RTT validation. + rtt = _csot.get_rtt() + if rtt is None: + rtt = self.connect_rtt + max_time_ms = timeout - rtt + if max_time_ms < 0: + timeout_details = _get_timeout_details(self.opts) + formatted = format_timeout_details(timeout_details) + # CSOT: raise an error without running the command since we know it will time out. + errmsg = f"operation would exceed time limit, remaining timeout:{timeout:.5f} <= network round trip time:{rtt:.5f} {formatted}" + raise ExecutionTimeout( + errmsg, + 50, + {"ok": 0, "errmsg": errmsg, "code": 50}, + self.max_wire_version, + ) + if cmd is not None: + cmd["maxTimeMS"] = int(max_time_ms * 1000) + self.set_conn_timeout(timeout) + return timeout + def pin_txn(self) -> None: self.pinned_txn = True assert not self.pinned_cursor @@ -610,6 +571,26 @@ def close_conn(self, reason: Optional[str]) -> None: error=reason, ) + def _close_conn(self) -> None: + """Close this connection.""" + if self.closed: + return + self.closed = True + self.cancel_context.cancel() + # Note: We catch exceptions to avoid spurious errors on interpreter + # shutdown. + try: + self.conn.close() + except Exception: # noqa: S110 + pass + + def conn_closed(self) -> bool: + """Return True if we know socket has been closed, False otherwise.""" + if _IS_SYNC: + return self.socket_checker.socket_closed(self.conn.get_conn) + else: + return self.conn.is_closing() + def send_cluster_time( self, command: MutableMapping[str, Any], diff --git a/tools/synchro.py b/tools/synchro.py index 9a760c0ad7..e502f96281 100644 --- a/tools/synchro.py +++ b/tools/synchro.py @@ -120,9 +120,9 @@ "_async_create_lock": "_create_lock", "_async_create_condition": "_create_condition", "_async_cond_wait": "_cond_wait", - "async_receive_kms": "receive_kms", "AsyncNetworkingInterface": "NetworkingInterface", "_configured_protocol_interface": "_configured_socket_interface", + "_async_configured_socket": "_configured_socket", "SpecRunnerTask": "SpecRunnerThread", "AsyncMockConnection": "MockConnection", "AsyncMockPool": "MockPool",