Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 47 additions & 15 deletions src/mcp/server/streamable_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,30 +238,50 @@ def _create_session_message( # pragma: no cover
message: JSONRPCMessage,
request: Request,
request_id: RequestId,
protocol_version: str,
) -> SessionMessage:
"""Create a session message with metadata including close_sse_stream callback."""
"""Create a session message with metadata including close_sse_stream callback.

async def close_stream_callback() -> None:
self.close_sse_stream(request_id)
The close_sse_stream callbacks are only provided when the client supports
resumability (protocol version >= 2025-11-25). Old clients can't resume if
the stream is closed early because they didn't receive a priming event.
"""
# Only provide close callbacks when client supports resumability
if self._event_store and protocol_version >= "2025-11-25":

async def close_standalone_stream_callback() -> None:
self.close_standalone_sse_stream()
async def close_stream_callback() -> None:
self.close_sse_stream(request_id)

async def close_standalone_stream_callback() -> None:
self.close_standalone_sse_stream()

metadata = ServerMessageMetadata(
request_context=request,
close_sse_stream=close_stream_callback,
close_standalone_sse_stream=close_standalone_stream_callback,
)
else:
metadata = ServerMessageMetadata(request_context=request)

metadata = ServerMessageMetadata(
request_context=request,
close_sse_stream=close_stream_callback,
close_standalone_sse_stream=close_standalone_stream_callback,
)
return SessionMessage(message, metadata=metadata)

async def _send_priming_event( # pragma: no cover
async def _maybe_send_priming_event(
self,
request_id: RequestId,
sse_stream_writer: MemoryObjectSendStream[dict[str, Any]],
protocol_version: str,
) -> None:
"""Send priming event for SSE resumability if event_store is configured."""
"""Send priming event for SSE resumability if event_store is configured.

Only sends priming events to clients with protocol version >= 2025-11-25,
which includes the fix for handling empty SSE data. Older clients would
crash trying to parse empty data as JSON.
"""
if not self._event_store:
return
# Priming events have empty data which older clients cannot handle.
if protocol_version < "2025-11-25":
return
priming_event_id = await self._event_store.store_event(
str(request_id), # Convert RequestId to StreamId (str)
None, # Priming event has no payload
Expand Down Expand Up @@ -499,6 +519,15 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re

return

# Extract protocol version for priming event decision.
# For initialize requests, get from request params.
# For other requests, get from header (already validated).
protocol_version = (
str(message.root.params.get("protocolVersion", DEFAULT_NEGOTIATED_VERSION))
if is_initialization_request and message.root.params
else request.headers.get(MCP_PROTOCOL_VERSION_HEADER, DEFAULT_NEGOTIATED_VERSION)
)

# Extract the request ID outside the try block for proper scope
request_id = str(message.root.id) # pragma: no cover
# Register this stream for the request ID
Expand Down Expand Up @@ -560,7 +589,7 @@ async def sse_writer():
try:
async with sse_stream_writer, request_stream_reader:
# Send priming event for SSE resumability
await self._send_priming_event(request_id, sse_stream_writer)
await self._maybe_send_priming_event(request_id, sse_stream_writer, protocol_version)

# Process messages from the request-specific stream
async for event_message in request_stream_reader:
Expand Down Expand Up @@ -605,7 +634,7 @@ async def sse_writer():
async with anyio.create_task_group() as tg:
tg.start_soon(response, scope, receive, send)
# Then send the message to be processed by the server
session_message = self._create_session_message(message, request, request_id)
session_message = self._create_session_message(message, request, request_id, protocol_version)
await writer.send(session_message)
except Exception:
logger.exception("SSE response error")
Expand Down Expand Up @@ -864,6 +893,9 @@ async def _replay_events(self, last_event_id: str, request: Request, send: Send)
if self.mcp_session_id:
headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id

# Get protocol version from header (already validated in _validate_protocol_version)
replay_protocol_version = request.headers.get(MCP_PROTOCOL_VERSION_HEADER, DEFAULT_NEGOTIATED_VERSION)

# Create SSE stream for replay
sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[dict[str, str]](0)

Expand All @@ -884,7 +916,7 @@ async def send_event(event_message: EventMessage) -> None:
self._sse_stream_writers[stream_id] = sse_stream_writer

# Send priming event for this new connection
await self._send_priming_event(stream_id, sse_stream_writer)
await self._maybe_send_priming_event(stream_id, sse_stream_writer, replay_protocol_version)

# Create new request streams for this connection
self._request_streams[stream_id] = anyio.create_memory_object_stream[EventMessage](0)
Expand Down
131 changes: 127 additions & 4 deletions tests/shared/test_streamable_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import time
from collections.abc import Generator
from typing import Any
from unittest.mock import MagicMock

import anyio
import httpx
Expand Down Expand Up @@ -41,9 +42,16 @@
from mcp.server.transport_security import TransportSecuritySettings
from mcp.shared.context import RequestContext
from mcp.shared.exceptions import McpError
from mcp.shared.message import ClientMessageMetadata, SessionMessage
from mcp.shared.message import ClientMessageMetadata, ServerMessageMetadata, SessionMessage
from mcp.shared.session import RequestResponder
from mcp.types import InitializeResult, TextContent, TextResourceContents, Tool
from mcp.types import (
InitializeResult,
JSONRPCMessage,
JSONRPCRequest,
TextContent,
TextResourceContents,
Tool,
)
from tests.test_helpers import wait_for_server

# Test constants
Expand Down Expand Up @@ -1761,6 +1769,116 @@ async def test_handle_sse_event_skips_empty_data():
await read_stream.aclose()


@pytest.mark.anyio
async def test_priming_event_not_sent_for_old_protocol_version():
"""Test that _maybe_send_priming_event skips for old protocol versions (backwards compat)."""
# Create a transport with an event store
transport = StreamableHTTPServerTransport(
"/mcp",
event_store=SimpleEventStore(),
)

# Create a mock stream writer
write_stream, read_stream = anyio.create_memory_object_stream[dict[str, Any]](1)

try:
# Call _maybe_send_priming_event with OLD protocol version - should NOT send
await transport._maybe_send_priming_event("test-request-id", write_stream, "2025-06-18")

# Nothing should have been written to the stream
assert write_stream.statistics().current_buffer_used == 0

# Now test with NEW protocol version - should send
await transport._maybe_send_priming_event("test-request-id-2", write_stream, "2025-11-25")

# Should have written a priming event
assert write_stream.statistics().current_buffer_used == 1
finally:
await write_stream.aclose()
await read_stream.aclose()


@pytest.mark.anyio
async def test_priming_event_not_sent_without_event_store():
"""Test that _maybe_send_priming_event returns early when no event_store is configured."""
# Create a transport WITHOUT an event store
transport = StreamableHTTPServerTransport("/mcp")

# Create a mock stream writer
write_stream, read_stream = anyio.create_memory_object_stream[dict[str, Any]](1)

try:
# Call _maybe_send_priming_event - should return early without sending
await transport._maybe_send_priming_event("test-request-id", write_stream, "2025-11-25")

# Nothing should have been written to the stream
assert write_stream.statistics().current_buffer_used == 0
finally:
await write_stream.aclose()
await read_stream.aclose()


@pytest.mark.anyio
async def test_priming_event_includes_retry_interval():
"""Test that _maybe_send_priming_event includes retry field when retry_interval is set."""
# Create a transport with an event store AND retry_interval
transport = StreamableHTTPServerTransport(
"/mcp",
event_store=SimpleEventStore(),
retry_interval=5000,
)

# Create a mock stream writer
write_stream, read_stream = anyio.create_memory_object_stream[dict[str, Any]](1)

try:
# Call _maybe_send_priming_event with new protocol version
await transport._maybe_send_priming_event("test-request-id", write_stream, "2025-11-25")

# Should have written a priming event with retry field
assert write_stream.statistics().current_buffer_used == 1

# Read the event and verify it has retry field
event = await read_stream.receive()
assert "retry" in event
assert event["retry"] == 5000
finally:
await write_stream.aclose()
await read_stream.aclose()


@pytest.mark.anyio
async def test_close_sse_stream_callback_not_provided_for_old_protocol_version():
"""Test that close_sse_stream callbacks are NOT provided for old protocol versions."""
# Create a transport with an event store
transport = StreamableHTTPServerTransport(
"/mcp",
event_store=SimpleEventStore(),
)

# Create a mock message and request
mock_message = JSONRPCMessage(root=JSONRPCRequest(jsonrpc="2.0", id="test-1", method="tools/list"))
mock_request = MagicMock()

# Call _create_session_message with OLD protocol version
session_msg = transport._create_session_message(mock_message, mock_request, "test-request-id", "2025-06-18")

# Callbacks should NOT be provided for old protocol version
assert session_msg.metadata is not None
assert isinstance(session_msg.metadata, ServerMessageMetadata)
assert session_msg.metadata.close_sse_stream is None
assert session_msg.metadata.close_standalone_sse_stream is None

# Now test with NEW protocol version - should provide callbacks
session_msg_new = transport._create_session_message(mock_message, mock_request, "test-request-id-2", "2025-11-25")

# Callbacks SHOULD be provided for new protocol version
assert session_msg_new.metadata is not None
assert isinstance(session_msg_new.metadata, ServerMessageMetadata)
assert session_msg_new.metadata.close_sse_stream is not None
assert session_msg_new.metadata.close_standalone_sse_stream is not None


@pytest.mark.anyio
async def test_streamablehttp_client_receives_priming_event(
event_server: tuple[SimpleEventStore, str],
Expand Down Expand Up @@ -2060,7 +2178,9 @@ async def on_resumption_token(token: str) -> None:


@pytest.mark.anyio
async def test_standalone_get_stream_reconnection(basic_server: None, basic_server_url: str) -> None:
async def test_standalone_get_stream_reconnection(
event_server: tuple[SimpleEventStore, str],
) -> None:
"""
Test that standalone GET stream automatically reconnects after server closes it.

Expand All @@ -2069,8 +2189,11 @@ async def test_standalone_get_stream_reconnection(basic_server: None, basic_serv
2. Server closes GET stream
3. Client reconnects with Last-Event-ID
4. Client receives notification 2 on new connection

Note: Requires event_server fixture (with event store) because close_standalone_sse_stream
callback is only provided when event_store is configured and protocol version >= 2025-11-25.
"""
server_url = basic_server_url
_, server_url = event_server
received_notifications: list[str] = []

async def message_handler(
Expand Down