
Concurrent server-side requests are serialized end-to-end by BaseSession #2489

@demoray

Description


When a server tool issues multiple concurrent server-to-client requests over a single MCP session (for example, fanning out several ServerSession.create_message(...) sampling calls via asyncio.gather), only one is ever in flight on the receiving side: BaseSession._receive_loop awaits each incoming request handler inline before reading the next message off the stream, so subsequent requests can't even be dequeued until the current handler returns. A minimal reproducer using mcp.server.fastmcp.FastMCP + ClientSession over the SDK's own in-memory transport (5 concurrent create_message calls, sampling callback sleeping 0.5s) shows peak in-flight = 1 and elapsed ≈ 2.5s instead of ≈ 0.5s.
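The serialization can be modeled without MCP at all. Below is a toy asyncio loop (not the SDK's actual code, which is built on anyio) with the same inline-await shape as BaseSession._receive_loop; all names here are invented for illustration:

```python
import asyncio
import time


async def handle_request(i: int) -> None:
    # Stand-in for an incoming request handler (e.g. a sampling callback).
    await asyncio.sleep(0.1)


async def receive_loop(queue: asyncio.Queue) -> None:
    # Same shape as the problematic loop: the next message cannot be
    # dequeued until the current handler has returned.
    while True:
        msg = await queue.get()
        if msg is None:
            return
        await handle_request(msg)  # awaited inline -> strict serialization


async def main() -> float:
    queue: asyncio.Queue = asyncio.Queue()
    for i in range(5):
        queue.put_nowait(i)
    queue.put_nowait(None)  # sentinel to stop the loop
    t0 = time.monotonic()
    await receive_loop(queue)
    return time.monotonic() - t0


elapsed = asyncio.run(main())
print(f"elapsed ~ {elapsed:.2f}s")  # ~0.5s: 5 handlers x 0.1s, one at a time
```

Five queued messages with a 0.1s handler take ~0.5s end to end, exactly the N * SLEEP scaling observed in the reproducer below.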

Concrete use case: a tool that fans out dozens of independent sampling requests per invocation against an LLM deployment with provisioned capacity sitting idle. The work is embarrassingly parallel and there is token budget to burn, yet wall-clock time scales linearly with the fan-out because every request blocks the next. The same shape applies to elicitation, list-roots, and any other server-to-client request type, since they all flow through the same loop.

I'd like the SDK to support concurrent handling of incoming requests on a single session, likely as an opt-in knob to preserve today's strict-ordering semantics for handlers that depend on them, with whatever bounding/back-pressure mechanism the maintainers prefer.

References

"""Reproducer: concurrent server->client requests are serialized end-to-end.

A FastMCP tool fans out N concurrent ``ServerSession.create_message`` calls
via ``asyncio.gather``. The client's sampling callback records peak in-flight.

Expected with concurrent dispatch:  peak == N,  elapsed ~= SLEEP
Observed today:                     peak == 1,  elapsed ~= N * SLEEP

Root cause: ``BaseSession._receive_loop`` awaits each incoming request
handler inline before reading the next message off the stream
(mcp/shared/session.py).

Run:  python repro_concurrent_sampling.py
"""

import asyncio

from mcp.server.fastmcp import Context, FastMCP
from mcp.shared.context import RequestContext
from mcp.shared.memory import create_connected_server_and_client_session
from mcp.types import (
    CreateMessageRequestParams,
    CreateMessageResult,
    SamplingMessage,
    TextContent,
)

N = 5
SLEEP = 0.5


async def main() -> None:
    inflight = 0
    peak = 0

    async def sampling_callback(
        _ctx: RequestContext,
        params: CreateMessageRequestParams,
    ) -> CreateMessageResult:
        nonlocal inflight, peak
        inflight += 1
        peak = max(peak, inflight)
        try:
            await asyncio.sleep(SLEEP)
        finally:
            inflight -= 1
        msg = params.messages[0].content
        echo = msg.text if isinstance(msg, TextContent) else ""
        return CreateMessageResult(
            role="assistant",
            content=TextContent(type="text", text=f"echo:{echo}"),
            model="test-model",
        )

    server = FastMCP("repro")

    @server.tool()
    async def fanout(ctx: Context) -> str:
        async def one(i: int) -> str:
            r = await ctx.session.create_message(
                messages=[
                    SamplingMessage(
                        role="user",
                        content=TextContent(type="text", text=str(i)),
                    )
                ],
                max_tokens=8,
            )
            return r.content.text if isinstance(r.content, TextContent) else ""

        return ",".join(await asyncio.gather(*(one(i) for i in range(N))))

    loop = asyncio.get_running_loop()
    t0 = loop.time()
    async with create_connected_server_and_client_session(
        # The in-memory helper expects the low-level Server, not FastMCP.
        server._mcp_server,
        sampling_callback=sampling_callback,
    ) as session:
        await session.call_tool("fanout", {})
    elapsed = loop.time() - t0

    print(f"N={N}, per-call sleep={SLEEP}s")
    print(f"peak in-flight: {peak}    (concurrent: {N}, serialized: 1)")
    print(f"elapsed:        {elapsed:.2f}s  (concurrent: ~{SLEEP}s, serialized: ~{N * SLEEP}s)")

    assert peak == N, (
        f"server->client requests were serialized: peak in-flight={peak}, expected {N}"
    )


if __name__ == "__main__":
    asyncio.run(main())
