Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 20 additions & 4 deletions src/mcp/client/streamable_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,8 +380,17 @@ async def _handle_reconnection(
) -> None:
"""Reconnect with Last-Event-ID to resume stream after server disconnect."""
# Bail if max retries exceeded
if attempt >= MAX_RECONNECTION_ATTEMPTS: # pragma: no cover
logger.debug(f"Max reconnection attempts ({MAX_RECONNECTION_ATTEMPTS}) exceeded")
if attempt >= MAX_RECONNECTION_ATTEMPTS:
logger.warning(f"Max reconnection attempts ({MAX_RECONNECTION_ATTEMPTS}) exceeded")
if isinstance(ctx.session_message.message, JSONRPCRequest): # pragma: no branch
error_data = ErrorData(
code=INTERNAL_ERROR,
message=f"SSE stream disconnected after {MAX_RECONNECTION_ATTEMPTS} reconnection attempts",
)
error_msg = SessionMessage(
JSONRPCError(jsonrpc="2.0", id=ctx.session_message.message.id, error=error_data)
)
await ctx.read_stream_writer.send(error_msg)
return

# Always wait - use server value or default
Expand All @@ -404,12 +413,15 @@ async def _handle_reconnection(
# Track for potential further reconnection
reconnect_last_event_id: str = last_event_id
reconnect_retry_ms = retry_interval_ms
received_data = False

async for sse in event_source.aiter_sse():
if sse.id: # pragma: no branch
reconnect_last_event_id = sse.id
if sse.retry is not None:
reconnect_retry_ms = sse.retry
if sse.data:
received_data = True

is_complete = await self._handle_sse_event(
sse,
Expand All @@ -421,9 +433,13 @@ async def _handle_reconnection(
await event_source.response.aclose()
return

# Stream ended again without response - reconnect again (reset attempt counter)
# Stream ended without response — reconnect.
# Reset attempt counter only when real data was received
# (server made progress). Otherwise increment to prevent
# infinite loops when server only sends priming events.
next_attempt = 0 if received_data else attempt + 1
logger.info("SSE stream disconnected, reconnecting...")
await self._handle_reconnection(ctx, reconnect_last_event_id, reconnect_retry_ms, 0)
await self._handle_reconnection(ctx, reconnect_last_event_id, reconnect_retry_ms, next_attempt)
except Exception as e: # pragma: no cover
logger.debug(f"Reconnection failed: {e}")
# Try to reconnect again if we still have an event ID
Expand Down
74 changes: 74 additions & 0 deletions tests/shared/test_streamable_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
CallToolRequestParams,
CallToolResult,
InitializeResult,
JSONRPCError,
JSONRPCRequest,
ListToolsResult,
PaginatedRequestParams,
Expand Down Expand Up @@ -2318,3 +2319,76 @@ async def test_streamable_http_client_preserves_custom_with_mcp_headers(

assert "content-type" in headers_data
assert headers_data["content-type"] == "application/json"


@pytest.mark.anyio
async def test_handle_reconnection_stops_after_max_attempts() -> None:
"""_handle_reconnection must not reset attempt counter on stream drop.

Regression test for https://github.com/modelcontextprotocol/python-sdk/issues/2393.
When the server accepts the SSE connection but closes the stream without
sending a complete JSON-RPC response, the client must give up after
MAX_RECONNECTION_ATTEMPTS total attempts and report an error — not retry
forever.
"""
from unittest.mock import AsyncMock, MagicMock

from mcp.client.streamable_http import MAX_RECONNECTION_ATTEMPTS, RequestContext

transport = StreamableHTTPTransport("http://test/mcp")
connect_count = 0

@asynccontextmanager
async def fake_aconnect_sse(*_args: object, **_kwargs: object):
nonlocal connect_count
connect_count += 1

response = MagicMock()
response.raise_for_status = MagicMock()
response.aclose = AsyncMock()

event_source = MagicMock()
event_source.response = response

async def aiter_sse():
yield ServerSentEvent(event="message", data="", id=f"evt-{connect_count}", retry=None)

event_source.aiter_sse = aiter_sse
yield event_source

write_stream, read_stream = create_context_streams[SessionMessage | Exception](1)

request = JSONRPCRequest(
jsonrpc="2.0",
id="req-1",
method="tools/call",
params={"name": "test_tool", "arguments": {}},
)
ctx = RequestContext(
client=MagicMock(),
session_id="test-session",
session_message=SessionMessage(request),
metadata=None,
read_stream_writer=write_stream,
)

import mcp.client.streamable_http as _mod

original = _mod.aconnect_sse
_mod.aconnect_sse = fake_aconnect_sse # type: ignore[assignment]
try:
await transport._handle_reconnection(ctx, "evt-0", 0)
finally:
_mod.aconnect_sse = original

assert connect_count == MAX_RECONNECTION_ATTEMPTS

with anyio.fail_after(1):
msg = await read_stream.receive()
assert isinstance(msg, SessionMessage)
assert isinstance(msg.message, JSONRPCError)
assert "reconnection attempts" in msg.message.error.message.lower()
assert msg.message.id == "req-1"

await write_stream.aclose()
await read_stream.aclose()
Loading