From 0e7ea3118dcd647e5d38de061e5f3cf5555b93d9 Mon Sep 17 00:00:00 2001 From: Felix Weinberger Date: Thu, 4 Dec 2025 11:54:21 +0000 Subject: [PATCH] fix: skip priming events for clients with old protocol versions Priming events (SEP-1699) have empty SSE data which older clients cannot handle - they try to JSON.parse("") and crash. Only send priming events to clients with protocol version >= 2025-11-25, which includes the fix for handling empty SSE data. For the initialize request, the protocol version is extracted from the request params. For subsequent requests, it's taken from the mcp-protocol-version header. --- src/mcp/server/streamable_http.py | 62 ++++++++++--- tests/shared/test_streamable_http.py | 131 ++++++++++++++++++++++++++- 2 files changed, 174 insertions(+), 19 deletions(-) diff --git a/src/mcp/server/streamable_http.py b/src/mcp/server/streamable_http.py index d5cd387f1..2613b530c 100644 --- a/src/mcp/server/streamable_http.py +++ b/src/mcp/server/streamable_http.py @@ -238,30 +238,50 @@ def _create_session_message( # pragma: no cover message: JSONRPCMessage, request: Request, request_id: RequestId, + protocol_version: str, ) -> SessionMessage: - """Create a session message with metadata including close_sse_stream callback.""" + """Create a session message with metadata including close_sse_stream callback. - async def close_stream_callback() -> None: - self.close_sse_stream(request_id) + The close_sse_stream callbacks are only provided when the client supports + resumability (protocol version >= 2025-11-25). Old clients can't resume if + the stream is closed early because they didn't receive a priming event. + """ + # Only provide close callbacks when client supports resumability + if self._event_store and protocol_version >= "2025-11-25": - async def close_standalone_stream_callback() -> None: - self.close_standalone_sse_stream() + async def close_stream_callback() -> None: + self.close_sse_stream(request_id) + + async def close_standalone_stream_callback() -> None: + self.close_standalone_sse_stream() + + metadata = ServerMessageMetadata( + request_context=request, + close_sse_stream=close_stream_callback, + close_standalone_sse_stream=close_standalone_stream_callback, + ) + else: + metadata = ServerMessageMetadata(request_context=request) - metadata = ServerMessageMetadata( - request_context=request, - close_sse_stream=close_stream_callback, - close_standalone_sse_stream=close_standalone_stream_callback, - ) return SessionMessage(message, metadata=metadata) - async def _send_priming_event( # pragma: no cover + async def _maybe_send_priming_event( self, request_id: RequestId, sse_stream_writer: MemoryObjectSendStream[dict[str, Any]], + protocol_version: str, ) -> None: - """Send priming event for SSE resumability if event_store is configured.""" + """Send priming event for SSE resumability if event_store is configured. + + Only sends priming events to clients with protocol version >= 2025-11-25, + which includes the fix for handling empty SSE data. Older clients would + crash trying to parse empty data as JSON. + """ if not self._event_store: return + # Priming events have empty data which older clients cannot handle. + if protocol_version < "2025-11-25": + return priming_event_id = await self._event_store.store_event( str(request_id), # Convert RequestId to StreamId (str) None, # Priming event has no payload @@ -499,6 +519,15 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re return + # Extract protocol version for priming event decision. + # For initialize requests, get from request params. + # For other requests, get from header (already validated). + protocol_version = ( + str(message.root.params.get("protocolVersion", DEFAULT_NEGOTIATED_VERSION)) + if is_initialization_request and message.root.params + else request.headers.get(MCP_PROTOCOL_VERSION_HEADER, DEFAULT_NEGOTIATED_VERSION) + ) + # Extract the request ID outside the try block for proper scope request_id = str(message.root.id) # pragma: no cover # Register this stream for the request ID @@ -560,7 +589,7 @@ async def sse_writer(): try: async with sse_stream_writer, request_stream_reader: # Send priming event for SSE resumability - await self._send_priming_event(request_id, sse_stream_writer) + await self._maybe_send_priming_event(request_id, sse_stream_writer, protocol_version) # Process messages from the request-specific stream async for event_message in request_stream_reader: @@ -605,7 +634,7 @@ async def sse_writer(): async with anyio.create_task_group() as tg: tg.start_soon(response, scope, receive, send) # Then send the message to be processed by the server - session_message = self._create_session_message(message, request, request_id) + session_message = self._create_session_message(message, request, request_id, protocol_version) await writer.send(session_message) except Exception: logger.exception("SSE response error") @@ -864,6 +893,9 @@ async def _replay_events(self, last_event_id: str, request: Request, send: Send) if self.mcp_session_id: headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id + # Get protocol version from header (already validated in _validate_protocol_version) + replay_protocol_version = request.headers.get(MCP_PROTOCOL_VERSION_HEADER, DEFAULT_NEGOTIATED_VERSION) + # Create SSE stream for replay sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[dict[str, str]](0) @@ -884,7 +916,7 @@ async def send_event(event_message: EventMessage) -> None: self._sse_stream_writers[stream_id] = sse_stream_writer # Send priming event for this new connection - await self._send_priming_event(stream_id, sse_stream_writer) + await self._maybe_send_priming_event(stream_id, sse_stream_writer, replay_protocol_version) # Create new request streams for this connection self._request_streams[stream_id] = anyio.create_memory_object_stream[EventMessage](0) diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index 4ed2c88be..a626e7385 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -10,6 +10,7 @@ import time from collections.abc import Generator from typing import Any +from unittest.mock import MagicMock import anyio import httpx @@ -41,9 +42,16 @@ from mcp.server.transport_security import TransportSecuritySettings from mcp.shared.context import RequestContext from mcp.shared.exceptions import McpError -from mcp.shared.message import ClientMessageMetadata, SessionMessage +from mcp.shared.message import ClientMessageMetadata, ServerMessageMetadata, SessionMessage from mcp.shared.session import RequestResponder -from mcp.types import InitializeResult, TextContent, TextResourceContents, Tool +from mcp.types import ( + InitializeResult, + JSONRPCMessage, + JSONRPCRequest, + TextContent, + TextResourceContents, + Tool, +) from tests.test_helpers import wait_for_server # Test constants @@ -1761,6 +1769,116 @@ async def test_handle_sse_event_skips_empty_data(): await read_stream.aclose() +@pytest.mark.anyio +async def test_priming_event_not_sent_for_old_protocol_version(): + """Test that _maybe_send_priming_event skips for old protocol versions (backwards compat).""" + # Create a transport with an event store + transport = StreamableHTTPServerTransport( + "/mcp", + event_store=SimpleEventStore(), + ) + + # Create a mock stream writer + write_stream, read_stream = anyio.create_memory_object_stream[dict[str, Any]](1) + + try: + # Call _maybe_send_priming_event with OLD protocol version - should NOT send + await transport._maybe_send_priming_event("test-request-id", write_stream, "2025-06-18") + + # Nothing should have been written to the stream + assert write_stream.statistics().current_buffer_used == 0 + + # Now test with NEW protocol version - should send + await transport._maybe_send_priming_event("test-request-id-2", write_stream, "2025-11-25") + + # Should have written a priming event + assert write_stream.statistics().current_buffer_used == 1 + finally: + await write_stream.aclose() + await read_stream.aclose() + + +@pytest.mark.anyio +async def test_priming_event_not_sent_without_event_store(): + """Test that _maybe_send_priming_event returns early when no event_store is configured.""" + # Create a transport WITHOUT an event store + transport = StreamableHTTPServerTransport("/mcp") + + # Create a mock stream writer + write_stream, read_stream = anyio.create_memory_object_stream[dict[str, Any]](1) + + try: + # Call _maybe_send_priming_event - should return early without sending + await transport._maybe_send_priming_event("test-request-id", write_stream, "2025-11-25") + + # Nothing should have been written to the stream + assert write_stream.statistics().current_buffer_used == 0 + finally: + await write_stream.aclose() + await read_stream.aclose() + + +@pytest.mark.anyio +async def test_priming_event_includes_retry_interval(): + """Test that _maybe_send_priming_event includes retry field when retry_interval is set.""" + # Create a transport with an event store AND retry_interval + transport = StreamableHTTPServerTransport( + "/mcp", + event_store=SimpleEventStore(), + retry_interval=5000, + ) + + # Create a mock stream writer + write_stream, read_stream = anyio.create_memory_object_stream[dict[str, Any]](1) + + try: + # Call _maybe_send_priming_event with new protocol version + await transport._maybe_send_priming_event("test-request-id", write_stream, "2025-11-25") + + # Should have written a priming event with retry field + assert write_stream.statistics().current_buffer_used == 1 + + # Read the event and verify it has retry field + event = await read_stream.receive() + assert "retry" in event + assert event["retry"] == 5000 + finally: + await write_stream.aclose() + await read_stream.aclose() + + +@pytest.mark.anyio +async def test_close_sse_stream_callback_not_provided_for_old_protocol_version(): + """Test that close_sse_stream callbacks are NOT provided for old protocol versions.""" + # Create a transport with an event store + transport = StreamableHTTPServerTransport( + "/mcp", + event_store=SimpleEventStore(), + ) + + # Create a mock message and request + mock_message = JSONRPCMessage(root=JSONRPCRequest(jsonrpc="2.0", id="test-1", method="tools/list")) + mock_request = MagicMock() + + # Call _create_session_message with OLD protocol version + session_msg = transport._create_session_message(mock_message, mock_request, "test-request-id", "2025-06-18") + + # Callbacks should NOT be provided for old protocol version + assert session_msg.metadata is not None + assert isinstance(session_msg.metadata, ServerMessageMetadata) + assert session_msg.metadata.close_sse_stream is None + assert session_msg.metadata.close_standalone_sse_stream is None + + # Now test with NEW protocol version - should provide callbacks + session_msg_new = transport._create_session_message(mock_message, mock_request, "test-request-id-2", "2025-11-25") + + # Callbacks SHOULD be provided for new protocol version + assert session_msg_new.metadata is not None + assert isinstance(session_msg_new.metadata, ServerMessageMetadata) + assert session_msg_new.metadata.close_sse_stream is not None + assert session_msg_new.metadata.close_standalone_sse_stream is not None + + @pytest.mark.anyio async def test_streamablehttp_client_receives_priming_event( event_server: tuple[SimpleEventStore, str], @@ -2060,7 +2178,9 @@ async def on_resumption_token(token: str) -> None: @pytest.mark.anyio -async def test_standalone_get_stream_reconnection(basic_server: None, basic_server_url: str) -> None: +async def test_standalone_get_stream_reconnection( + event_server: tuple[SimpleEventStore, str], +) -> None: """ Test that standalone GET stream automatically reconnects after server closes it. @@ -2069,8 +2189,11 @@ async def test_standalone_get_stream_reconnection(basic_server: None, basic_serv 2. Server closes GET stream 3. Client reconnects with Last-Event-ID 4. Client receives notification 2 on new connection + + Note: Requires event_server fixture (with event store) because close_standalone_sse_stream + callback is only provided when event_store is configured and protocol version >= 2025-11-25. """ - server_url = basic_server_url + _, server_url = event_server received_notifications: list[str] = [] async def message_handler(