From 251133344499c6ad84ca0f78a260a0b47f684c36 Mon Sep 17 00:00:00 2001 From: Brian Caswell Date: Tue, 21 Apr 2026 22:20:55 +0000 Subject: [PATCH] fix: dispatch client request handlers concurrently (#2489) BaseSession._receive_loop awaited each incoming request handler inline, serializing server->client requests (e.g. concurrent sampling calls via asyncio.gather peaked at one in flight). Add an opt-in '_dispatch_requests_concurrently' flag on BaseSession that spawns each request handler in the session's task group. ClientSession enables it; ServerSession stays serial to preserve the initialize ordering that its state machine relies on. Also fix two RequestResponder races that concurrent dispatch widens: - __enter__ no longer replaces the cancel scope, so a cancel() that arrives before the handler enters the context targets the same scope the handler will later run under. - cancel() is idempotent and safe to call before entry. Handler exceptions are translated into a JSON-RPC error response so a raising handler can't wedge the peer. --- src/mcp/client/session.py | 3 + src/mcp/shared/session.py | 55 +++++++-- tests/shared/test_session.py | 214 +++++++++++++++++++++++++++++++++++ 3 files changed, 262 insertions(+), 10 deletions(-) diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 0cea454a7..43b3472f8 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -123,6 +123,9 @@ def __init__( experimental_task_handlers: ExperimentalTaskHandlers | None = None, ) -> None: super().__init__(read_stream, write_stream, read_timeout_seconds=read_timeout_seconds) + # Dispatch incoming server->client requests concurrently so a slow + # sampling/elicitation callback doesn't serialize other in-flight requests. + self._dispatch_requests_concurrently = True self._client_info = client_info or DEFAULT_CLIENT_INFO self._sampling_callback = sampling_callback or _default_sampling_callback self._sampling_capabilities = sampling_capabilities diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 243eef5ae..b3cbf45a6 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -99,7 +99,8 @@ def __init__( def __enter__(self) -> RequestResponder[ReceiveRequestT, SendResultT]: """Enter the context manager, enabling request cancellation tracking.""" self._entered = True - self._cancel_scope = anyio.CancelScope() + # Enter the scope created in __init__ so pre-entry cancel() targets + # the same scope the handler will later run under. self._cancel_scope.__enter__() return self @@ -140,11 +141,12 @@ async def respond(self, response: SendResultT | ErrorData) -> None: ) async def cancel(self) -> None: - """Cancel this request and mark it as completed.""" - if not self._entered: # pragma: no cover - raise RuntimeError("RequestResponder must be used as a context manager") - if not self._cancel_scope: # pragma: no cover - raise RuntimeError("No active cancel scope") + """Cancel this request and mark it as completed. + + Safe to call before the context manager has been entered. + """ + if self._completed: + return self._cancel_scope.cancel() self._completed = True # Mark as completed so it's removed from in_flight @@ -158,6 +160,10 @@ async def cancel(self) -> None: def in_flight(self) -> bool: # pragma: no cover return not self._completed and not self.cancelled + @property + def completed(self) -> bool: + return self._completed + @property def cancelled(self) -> bool: return self._cancel_scope.cancel_called @@ -185,6 +191,10 @@ class BaseSession( _progress_callbacks: dict[RequestId, ProgressFnT] _response_routers: list[ResponseRouter] + # When True, incoming requests are dispatched to the session's task group + # so handlers run concurrently with the receive loop. + _dispatch_requests_concurrently: bool = False + def __init__( self, read_stream: ReadStream[SessionMessage | Exception], @@ -348,6 +358,29 @@ def _receive_request_adapter(self) -> TypeAdapter[ReceiveRequestT]: def _receive_notification_adapter(self) -> TypeAdapter[ReceiveNotificationT]: raise NotImplementedError + async def _dispatch_request( + self, + responder: RequestResponder[ReceiveRequestT, SendResultT], + ) -> None: + """Run the per-request handler chain, translating handler exceptions + into a JSON-RPC error response so they can't wedge the peer. + """ + request_id = responder.request_id + try: + await self._received_request(responder) + if not responder.completed: + await self._handle_incoming(responder) + except Exception: + logging.warning("Request handler raised an exception", exc_info=True) + if not responder.completed: + error_response = JSONRPCError( + jsonrpc="2.0", + id=request_id, + error=ErrorData(code=INVALID_PARAMS, message="Invalid request parameters", data=""), + ) + await self._write_stream.send(SessionMessage(message=error_response)) + self._in_flight.pop(request_id, None) + async def _receive_loop(self) -> None: async with self._read_stream, self._write_stream: try: @@ -370,10 +403,6 @@ async def _handle_session_message(message: SessionMessage) -> None: context=sender_context, ) self._in_flight[responder.request_id] = responder - await self._received_request(responder) - - if not responder._completed: # type: ignore[reportPrivateUsage] - await self._handle_incoming(responder) except Exception: # For request validation errors, send a proper JSON-RPC error # response instead of crashing the server @@ -386,6 +415,12 @@ async def _handle_session_message(message: SessionMessage) -> None: ) session_message = SessionMessage(message=error_response) await self._write_stream.send(session_message) + return + + if self._dispatch_requests_concurrently: + self._task_group.start_soon(self._dispatch_request, responder) + else: + await self._dispatch_request(responder) elif isinstance(message.message, JSONRPCNotification): try: diff --git a/tests/shared/test_session.py b/tests/shared/test_session.py index d7c6cc3b5..6272f48a0 100644 --- a/tests/shared/test_session.py +++ b/tests/shared/test_session.py @@ -4,6 +4,7 @@ from mcp import Client, types from mcp.client.session import ClientSession from mcp.server import Server, ServerRequestContext +from mcp.shared._context import RequestContext from mcp.shared.exceptions import MCPError from mcp.shared.memory import create_client_server_memory_streams from mcp.shared.message import SessionMessage @@ -416,3 +417,216 @@ async def make_request(client_session: ClientSession): # Pending request completed successfully assert len(result_holder) == 1 assert isinstance(result_holder[0], EmptyResult) + + +@pytest.mark.anyio +async def test_concurrent_server_to_client_requests_run_in_parallel(): + """Regression test for #2489. + + A server tool fans out N concurrent ``ServerSession.create_message`` calls + via ``anyio.create_task_group``. The client sampling callback records the + peak number of concurrently-in-flight calls. Before the fix, requests were + serialized end-to-end by ``BaseSession._receive_loop`` and peak was 1. + """ + n = 4 + + inflight = 0 + peak = 0 + started = anyio.Event() + release = anyio.Event() + + async def sampling_callback( + context: RequestContext[ClientSession], + params: types.CreateMessageRequestParams, + ) -> types.CreateMessageResult: + nonlocal inflight, peak + inflight += 1 + peak = max(peak, inflight) + if peak == n: + started.set() + try: + with anyio.fail_after(5): + await release.wait() + finally: + inflight -= 1 + msg = params.messages[0].content + echo = msg.text if isinstance(msg, types.TextContent) else "" + return types.CreateMessageResult( + role="assistant", + content=types.TextContent(type="text", text=f"echo:{echo}"), + model="test-model", + ) + + async def handle_call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> types.CallToolResult: + results: list[str] = [""] * n + + async def one(i: int) -> None: + r = await ctx.session.create_message( + messages=[ + types.SamplingMessage( + role="user", + content=types.TextContent(type="text", text=str(i)), + ) + ], + max_tokens=8, + ) + results[i] = r.content.text if isinstance(r.content, types.TextContent) else "" + + async with anyio.create_task_group() as tg: # pragma: no branch + for i in range(n): + tg.start_soon(one, i) + return types.CallToolResult(content=[types.TextContent(type="text", text=",".join(results))]) + + async def handle_list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="fanout", input_schema={"type": "object"})]) + + server = Server(name="fanout", on_call_tool=handle_call_tool, on_list_tools=handle_list_tools) + + async with Client(server, sampling_callback=sampling_callback) as client: + async with anyio.create_task_group() as tg: # pragma: no branch + + async def call() -> None: + await client.call_tool("fanout", {}) + + tg.start_soon(call) + with anyio.fail_after(5): + await started.wait() + release.set() + + assert peak == n, f"server->client requests were serialized: peak in-flight={peak}, expected {n}" + + +@pytest.mark.anyio +async def test_sampling_callback_exception_returns_error_response(): + """A raising sampling callback must produce a JSON-RPC error response so + the server-side ``await ctx.session.create_message(...)`` doesn't hang. + """ + + async def sampling_callback( + context: RequestContext[ClientSession], + params: types.CreateMessageRequestParams, + ) -> types.CreateMessageResult: + raise RuntimeError("boom") + + caught: list[MCPError] = [] + + async def handle_call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> types.CallToolResult: + try: + await ctx.session.create_message( + messages=[ + types.SamplingMessage( + role="user", + content=types.TextContent(type="text", text="x"), + ) + ], + max_tokens=8, + ) + except MCPError as e: + caught.append(e) + return types.CallToolResult(content=[types.TextContent(type="text", text="ok")]) + + async def handle_list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="boom", input_schema={"type": "object"})]) + + server = Server(name="raise", on_call_tool=handle_call_tool, on_list_tools=handle_list_tools) + + async with Client(server, sampling_callback=sampling_callback) as client: + with anyio.fail_after(5): + await client.call_tool("boom", {}) + + assert len(caught) == 1 + + +@pytest.mark.anyio +async def test_double_cancel_does_not_send_second_response(): + """Cancel called twice on the same responder must not emit a second response.""" + + class _Dummy: + _send_response_calls = 0 + + async def _send_response(self, *, request_id: types.RequestId, response: object) -> None: + self._send_response_calls += 1 + + dummy = _Dummy() + responder = RequestResponder[types.ServerRequest, types.ClientResult]( + request_id=1, + request_meta=None, + request=types.PingRequest(method="ping"), + session=dummy, # type: ignore[arg-type] + on_complete=lambda _r: None, + ) + with responder: + await responder.cancel() + await responder.cancel() + assert dummy._send_response_calls == 1 + + +@pytest.mark.anyio +async def test_cancel_before_context_entered_marks_scope_cancelled(): + """Regression: with concurrent dispatch, a CancelledNotification can + arrive before the handler task has entered ``with responder:``. + ``cancel()`` must not raise, and the scope entered later must already + be cancelled. + """ + + class _Dummy: + async def _send_response(self, *, request_id: types.RequestId, response: object) -> None: + pass + + responder = RequestResponder[types.ServerRequest, types.ClientResult]( + request_id=7, + request_meta=None, + request=types.PingRequest(method="ping"), + session=_Dummy(), # type: ignore[arg-type] + on_complete=lambda _r: None, + ) + + await responder.cancel() + assert responder.cancelled + assert responder._cancel_scope.cancel_called + + +@pytest.mark.anyio +async def test_handler_that_responds_then_raises_emits_no_duplicate_error(): + """If a request handler completes the response and then raises, the + dispatch path must not emit a second JSON-RPC error for the same id. + """ + + class _RaiseAfterRespond(ClientSession): + async def _received_request(self, responder: RequestResponder[types.ServerRequest, types.ClientResult]) -> None: + with responder: + await responder.respond(types.EmptyResult()) + raise RuntimeError("after respond") + + class _CapturingWrite: + def __init__(self) -> None: + self.sent: list[SessionMessage] = [] + + async def send(self, msg: SessionMessage) -> None: + self.sent.append(msg) + + async with create_client_server_memory_streams() as (client_streams, _server_streams): + client_read, client_write = client_streams + session = _RaiseAfterRespond(client_read, client_write) + + capture = _CapturingWrite() + session._write_stream = capture # type: ignore[assignment] + + responder = RequestResponder[types.ServerRequest, types.ClientResult]( + request_id=99, + request_meta=None, + request=types.PingRequest(method="ping"), + session=session, + on_complete=lambda r: session._in_flight.pop(r.request_id, None), + ) + session._in_flight[99] = responder + + await session._dispatch_request(responder) + + assert len(capture.sent) == 1, capture.sent + assert isinstance(capture.sent[0].message, JSONRPCResponse) + assert capture.sent[0].message.id == 99