From 469fdb2f351447157e667581f18875d474795502 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Wed, 1 Jul 2026 16:10:35 +0000 Subject: [PATCH] Make client-side cancellation work over the 2026 transports The modern session stamp suppressed the courtesy notifications/cancelled for every request, so abandoning a request over 2026 streamable HTTP left the POST open and the server running, and over 2026 stream-pair transports never sent the frame the spec requires. The frame is now the dispatcher's uniform abandon signal: stream transports write it (the 2026 stdio cancellation spelling), while the streamable-HTTP transport translates it into aborting the named request's own in-flight POST - closing the response stream is that wire's cancellation signal, and no client-to-server notification ever POSTs at 2026. Each POST records the era it was sent under so a late cancel is interpreted per the named request, not whatever was negotiated since; pre-2026 wires still POST the frame (a disconnect is explicitly not a cancel there). The negotiation methods keep their cancellation opt-out on every path. Callers can also supply their own request id via CallOptions["request_id"] on both dispatchers - groundwork for demultiplexing subscriptions/listen streams, whose id must be known before the result arrives. Ids reach the peer verbatim ("7" stays a string), collide loudly only for the caller who chose them (minting skips occupied keys), and share one coerced collision domain so the in-memory dispatcher raises exactly where the wire one would. --- src/mcp/client/session.py | 11 +- src/mcp/client/streamable_http.py | 111 +++++- src/mcp/shared/direct_dispatcher.py | 33 +- src/mcp/shared/dispatcher.py | 27 ++ src/mcp/shared/jsonrpc_dispatcher.py | 76 +++-- tests/client/test_session.py | 83 +++-- tests/client/test_streamable_http.py | 435 +++++++++++++++++++++++- tests/shared/test_dispatcher.py | 113 ++++++ tests/shared/test_jsonrpc_dispatcher.py | 50 ++- 9 files changed, 848 insertions(+), 91 deletions(-) diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 804180e05..5c09304e4 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -93,7 +93,16 @@ def stamp(data: dict[str, Any], opts: CallOptions) -> None: meta[PROTOCOL_VERSION_META_KEY] = protocol_version meta[CLIENT_INFO_META_KEY] = client_info meta[CLIENT_CAPABILITIES_META_KEY] = capabilities - opts["cancel_on_abandon"] = False + # `cancel_on_abandon` stays at the dispatcher default (True): the + # courtesy `notifications/cancelled` is the abandon signal. On the + # stream transports it is the 2026 wire's cancellation spelling; the + # streamable-HTTP transport translates it into aborting the request's + # own POST instead of writing it (the 2026 HTTP wire has no + # client-to-server notifications - closing the stream is the signal). + # The negotiation methods still opt out, mirroring `_preconnect_stamp`: + # the spec forbids cancelling them. + if data["method"] in ("initialize", "server/discover"): + opts["cancel_on_abandon"] = False headers = opts.setdefault("headers", {}) headers[MCP_PROTOCOL_VERSION_HEADER] = protocol_version headers[MCP_METHOD_HEADER] = data["method"] diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index f28eb7c7a..09e5048cc 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -26,6 +26,7 @@ RequestId, jsonrpc_message_adapter, ) +from mcp_types.version import MODERN_PROTOCOL_VERSIONS from pydantic import ValidationError from mcp.client._transport import TransportStreams @@ -33,6 +34,7 @@ from mcp.shared._context_streams import ContextReceiveStream, ContextSendStream, create_context_streams from mcp.shared._httpx_utils import create_mcp_http_client from mcp.shared.inbound import MCP_PROTOCOL_VERSION_HEADER +from mcp.shared.jsonrpc_dispatcher import cancelled_request_id_from_params from mcp.shared.message import ClientMessageMetadata, SessionMessage logger = logging.getLogger(__name__) @@ -70,6 +72,19 @@ class RequestContext: read_stream_writer: StreamWriter +@dataclass(slots=True) +class _InFlightPost: + """A request POST in flight: its abort scope and the era it was sent under. + + `modern` is the negotiated-version cache as of this request's dequeue, so a + later cancel frame is interpreted under the era the request actually ran + with, not whatever the cache says by then. + """ + + scope: anyio.CancelScope + modern: bool + + class StreamableHTTPTransport: """StreamableHTTP client transport implementation.""" @@ -81,11 +96,18 @@ def __init__(self, url: str) -> None: """ self.url = url self.session_id: str | None = None - # Captured from each stamped POST's metadata. Reused on outbound HTTP that carries - # no per-message header (transport-internal GET/DELETE, and dispatcher-written - # response/error/cancel POSTs that bypass the session's stamp). Cleared when an - # `initialize` POST goes out so a probe-stamped value cannot leak onto the handshake. + # Captured from each stamped message's metadata, synchronously in the + # post_writer loop so the cache always reflects wire order (a POST task's + # scheduling is arbitrary). Reused on outbound HTTP that carries no + # per-message header (transport-internal GET/DELETE, and dispatcher-written + # response/error POSTs that bypass the session's stamp), and consulted by + # `_consume_modern_cancellation`. Cleared when an `initialize` message is + # dequeued so a probe-stamped value cannot leak onto the handshake. self._protocol_version_header: str | None = None + # Every request's POST runs inside one of these so an outbound + # `notifications/cancelled` at 2026 can abort it; see + # `_consume_modern_cancellation`. Keys are verbatim-typed ("1" is not 1). + self._in_flight_posts: dict[RequestId, _InFlightPost] = {} def _prepare_headers(self) -> dict[str, str]: """Build MCP-specific request headers for any outbound HTTP request. @@ -93,9 +115,9 @@ def _prepare_headers(self) -> dict[str, str]: These are merged with the ``httpx.AsyncClient`` defaults (these take precedence). The cached ``MCP-Protocol-Version`` is included whenever present so messages that don't pass through the session's stamp — - response/error/cancel POSTs, transport-internal GET/DELETE — still - carry the negotiated version. Per-message headers are layered on top - by the caller. + response/error POSTs, legacy cancel frames, transport-internal + GET/DELETE — still carry the negotiated version. Per-message headers + are layered on top by the caller. """ headers: dict[str, str] = { "accept": "application/json, text/event-stream", @@ -245,19 +267,57 @@ async def _handle_resumption_request(self, ctx: RequestContext) -> None: await event_source.response.aclose() break + def _consume_modern_cancellation(self, session_message: SessionMessage) -> bool: + """Translate an outbound `notifications/cancelled` at 2026; True means "do not POST". + + The 2026 wire defines no client-to-server notifications over streamable + HTTP: closing a request's response stream IS its cancellation signal. + The dispatcher still emits the courtesy frame as its abandon signal + (every outbound cancel names one of our own request ids - the spec + forbids cancelling a request the sender did not issue), so this + transport translates it: when the named request's POST is in flight, + that POST's own recorded era decides - abort-and-swallow at 2026, POST + the frame below it (where the frame is the signal and a disconnect + explicitly is not). With no POST to consult, the cached negotiated + version decides; at 2026 the frame is swallowed even unmatched, so a + late cancel racing the response cannot leak onto the wire. + """ + message = session_message.message + if not (isinstance(message, JSONRPCNotification) and message.method == "notifications/cancelled"): + return False + request_id = cancelled_request_id_from_params(message.params) + post = self._in_flight_posts.get(request_id) if request_id is not None else None + if post is not None: + if not post.modern: + return False + logger.debug("aborting in-flight POST for cancelled request %r", request_id) + post.scope.cancel() + return True + return self._protocol_version_header in MODERN_PROTOCOL_VERSIONS + + async def _run_request_post( + self, + post_fn: Callable[[], Awaitable[None]], + post: _InFlightPost, + request_id: RequestId, + ) -> None: + """Run one request's POST inside its abort scope (see `_consume_modern_cancellation`).""" + try: + with post.scope: + await post_fn() + finally: + # Identity-guarded: a reused id may already have a successor + # registered while this task unwinds - popping by key alone would + # evict the live entry and leave the new POST unabortable. + if self._in_flight_posts.get(request_id) is post: + del self._in_flight_posts[request_id] + async def _handle_post_request(self, ctx: RequestContext) -> None: """Handle a POST request with response processing.""" message = ctx.session_message.message - is_initialization = self._is_initialization_request(message) - if is_initialization: - # `initialize` is the negotiation, not a "subsequent request" — discard any - # probe-stamped value so the discover→fallback path can't leak it onto the handshake. - self._protocol_version_header = None headers = self._prepare_headers() if ctx.metadata is not None and ctx.metadata.headers is not None: headers.update(ctx.metadata.headers) - if MCP_PROTOCOL_VERSION_HEADER in ctx.metadata.headers: - self._protocol_version_header = ctx.metadata.headers[MCP_PROTOCOL_VERSION_HEADER] async with ctx.client.stream( "POST", @@ -302,7 +362,7 @@ async def _handle_post_request(self, ctx: RequestContext) -> None: await ctx.read_stream_writer.send(session_message) return - if is_initialization: + if self._is_initialization_request(message): self._maybe_extract_session_id_from_response(response) # Per https://modelcontextprotocol.io/specification/2025-06-18/basic#notifications: @@ -455,6 +515,8 @@ async def post_writer( async def _handle_message(session_message: SessionMessage) -> None: message = session_message.message + if self._consume_modern_cancellation(session_message): + return metadata = ( session_message.metadata if isinstance(session_message.metadata, ClientMessageMetadata) @@ -470,6 +532,15 @@ async def _handle_message(session_message: SessionMessage) -> None: if self._is_initialized_notification(message): start_get_stream() + if self._is_initialization_request(message): + # `initialize` is the negotiation, not a "subsequent request" — discard any + # probe-stamped value so the discover→fallback path can't leak it onto the handshake. + self._protocol_version_header = None + elif metadata is not None and metadata.headers is not None: + stamped_version = metadata.headers.get(MCP_PROTOCOL_VERSION_HEADER) + if stamped_version is not None: + self._protocol_version_header = stamped_version + ctx = RequestContext( client=client, session_id=self.session_id, @@ -486,7 +557,15 @@ async def handle_request_async(): # If this is a request, start a new task to handle it if isinstance(message, JSONRPCRequest): - tg.start_soon(handle_request_async) + # Register the abort scope before the spawn: the next + # message through this loop can already be the abandon + # signal for this id, ahead of the task ever running. + post = _InFlightPost( + scope=anyio.CancelScope(), + modern=self._protocol_version_header in MODERN_PROTOCOL_VERSIONS, + ) + self._in_flight_posts[message.id] = post + tg.start_soon(self._run_request_post, handle_request_async, post, message.id) else: await handle_request_async() diff --git a/src/mcp/shared/direct_dispatcher.py b/src/mcp/shared/direct_dispatcher.py index fd3e69d49..62c74b808 100644 --- a/src/mcp/shared/direct_dispatcher.py +++ b/src/mcp/shared/direct_dispatcher.py @@ -28,7 +28,7 @@ from pydantic import ValidationError from mcp.shared._compat import resync_tracer -from mcp.shared.dispatcher import CallOptions, OnNotify, OnRequest, ProgressFnT +from mcp.shared.dispatcher import CallOptions, OnNotify, OnRequest, ProgressFnT, coerce_request_id from mcp.shared.exceptions import MCPError, NoBackChannelError from mcp.shared.message import MessageMetadata from mcp.shared.transport_context import TransportContext @@ -56,7 +56,8 @@ class _DirectDispatchContext: _back_request: _Request _back_notify: _Notify request_id: RequestId | None = None - """A dispatcher-synthesized id for requests; `None` for notifications.""" + """The caller-supplied `CallOptions["request_id"]`, else a dispatcher-synthesized + id for requests; `None` for notifications.""" message_metadata: MessageMetadata = None # TODO(maxisbey): remove for Context rework """Always `None`: in-memory dispatch attaches no transport metadata.""" _on_progress: ProgressFnT | None = None @@ -106,6 +107,7 @@ def __init__(self, transport_ctx: TransportContext, *, raise_handler_exceptions: self._on_request: OnRequest | None = None self._on_notify: OnNotify | None = None self._next_id = 0 + self._in_flight_ids: set[RequestId] = set() self._ready = anyio.Event() self._close_event = anyio.Event() self._running = False @@ -227,9 +229,28 @@ async def _dispatch_request( # waiting on a peer whose run() has not started yet. await self._wait_ready() assert self._on_request is not None - # Synthesize an id: the DispatchContext contract reserves None for notifications. - self._next_id += 1 - dctx = self._make_context(on_progress=opts.get("on_progress"), request_id=self._next_id) + supplied_id = opts.get("request_id") + if supplied_id is not None: + request_id: RequestId = supplied_id + # Collisions use the same coerced domain as JSONRPCDispatcher's + # pending keys, so this in-memory stand-in raises for exactly + # the ids the wire dispatcher would; the context still sees + # the verbatim value. + in_flight_key = coerce_request_id(request_id) + if in_flight_key in self._in_flight_ids: + raise ValueError(f"request id {request_id!r} is already in flight") + else: + # Synthesize an id (the DispatchContext contract reserves None + # for notifications), minting past any key a supplied id + # occupies: the collision error is reserved for the caller + # who actually chose the id. + self._next_id += 1 + while self._next_id in self._in_flight_ids: + self._next_id += 1 + request_id = self._next_id + in_flight_key = request_id + self._in_flight_ids.add(in_flight_key) + dctx = self._make_context(on_progress=opts.get("on_progress"), request_id=request_id) try: return await self._on_request(dctx, method, params) except MCPError: @@ -247,6 +268,8 @@ async def _dispatch_request( raise MCPError(code=INTERNAL_ERROR, message=str(e)) from e logger.exception("request handler raised") raise MCPError(code=INTERNAL_ERROR, message="Internal server error") from None + finally: + self._in_flight_ids.discard(in_flight_key) except TimeoutError: raise MCPError( code=REQUEST_TIMEOUT, diff --git a/src/mcp/shared/dispatcher.py b/src/mcp/shared/dispatcher.py index de83189f1..16360d314 100644 --- a/src/mcp/shared/dispatcher.py +++ b/src/mcp/shared/dispatcher.py @@ -34,11 +34,26 @@ "OnRequest", "Outbound", "ProgressFnT", + "coerce_request_id", ] TransportT_co = TypeVar("TransportT_co", bound=TransportContext, covariant=True) +def coerce_request_id(request_id: RequestId) -> RequestId: + """Coerce a stringified int request id back to int so a peer-echoed id still correlates (matches the TS SDK). + + This is the collision/correlation domain dispatchers share: "7" and 7 are one + id for correlation purposes, even where the wire carries the verbatim value. + """ + if isinstance(request_id, str): + try: + return int(request_id) + except ValueError: + pass + return request_id + + class ProgressFnT(Protocol): """Callback invoked when a progress notification arrives for a pending request.""" @@ -51,6 +66,18 @@ class CallOptions(TypedDict, total=False): All keys are optional. Dispatchers ignore keys they do not understand. """ + request_id: RequestId + """Send the request under this caller-supplied id instead of a dispatcher-minted one. + + The peer sees the value verbatim ("7" stays a string). A value that collides + with one of the sender's own in-flight request ids raises `ValueError`. + Callers that need to know a request's id before its result arrives (a + `subscriptions/listen` stream is demultiplexed by it) mint their own ids + here; string ids that don't parse as integers can never collide with the + dispatcher's minted sequence. Per the class contract, dispatchers that + predate this key ignore it and mint as usual. + """ + timeout: float """Seconds to wait for a result before raising and sending `notifications/cancelled`.""" diff --git a/src/mcp/shared/jsonrpc_dispatcher.py b/src/mcp/shared/jsonrpc_dispatcher.py index 64fcd3298..793c59bc7 100644 --- a/src/mcp/shared/jsonrpc_dispatcher.py +++ b/src/mcp/shared/jsonrpc_dispatcher.py @@ -39,7 +39,15 @@ from mcp.shared._compat import resync_tracer from mcp.shared._otel import inject_trace_context, otel_span from mcp.shared._stream_protocols import ReadStream, WriteStream -from mcp.shared.dispatcher import CallOptions, DispatchContext, Dispatcher, OnNotify, OnRequest, ProgressFnT +from mcp.shared.dispatcher import ( + CallOptions, + DispatchContext, + Dispatcher, + OnNotify, + OnRequest, + ProgressFnT, + coerce_request_id, +) from mcp.shared.exceptions import MCPError, NoBackChannelError from mcp.shared.message import ( ClientMessageMetadata, @@ -49,7 +57,12 @@ ) from mcp.shared.transport_context import TransportContext -__all__ = ["JSONRPCDispatcher", "handler_exception_to_error_data", "progress_token_from_params"] +__all__ = [ + "JSONRPCDispatcher", + "cancelled_request_id_from_params", + "handler_exception_to_error_data", + "progress_token_from_params", +] logger = logging.getLogger(__name__) @@ -93,14 +106,13 @@ def progress_token_from_params(params: Mapping[str, Any] | None) -> ProgressToke return None -def _coerce_id(request_id: RequestId) -> RequestId: - """Coerce a stringified int request ID back to int so a peer-echoed ID still correlates (matches the TS SDK).""" - if isinstance(request_id, str): - try: - return int(request_id) - except ValueError: - pass - return request_id +def cancelled_request_id_from_params(params: Mapping[str, Any] | None) -> RequestId | None: + """Read `params.requestId` from a `notifications/cancelled`; reject bool (True would alias request id 1).""" + match params: + case {"requestId": str() | int() as request_id} if not isinstance(request_id, bool): + return request_id + case _: + return None @dataclass(slots=True) @@ -314,7 +326,22 @@ async def send_raw_request( if not self._running: raise RuntimeError("JSONRPCDispatcher.send_raw_request called before run()") opts = opts or {} - request_id = self._allocate_id() + supplied_id = opts.get("request_id") + if supplied_id is not None: + request_id: RequestId = supplied_id + # The pending key gets the same coercion `_resolve_pending` applies + # to inbound response ids, so a supplied "7" still correlates + # whether the peer echoes "7" or 7. The wire id stays verbatim. + pending_key = coerce_request_id(request_id) + if pending_key in self._pending: + raise ValueError(f"request id {request_id!r} is already in flight") + else: + # Mint past any key a supplied id occupies: the collision error is + # reserved for the caller who actually chose the id. + request_id = self._allocate_id() + while request_id in self._pending: + request_id = self._allocate_id() + pending_key = request_id out_params = dict(params) if params is not None else {} out_meta = dict(out_params.get("_meta") or {}) on_progress = opts.get("on_progress") @@ -327,7 +354,7 @@ async def send_raw_request( # a WouldBlock later just means the waiter already has its one outcome. send, receive = anyio.create_memory_object_stream[dict[str, Any] | ErrorData](1) pending = _Pending(send=send, receive=receive, on_progress=on_progress) - self._pending[request_id] = pending + self._pending[pending_key] = pending plan = _plan_outbound(_related_request_id, opts) # Spec MUST: only previously-issued requests may be cancelled. A write @@ -398,7 +425,7 @@ async def send_raw_request( raise finally: # Remove the waiter on every path so a late response is dropped, not leaked. - self._pending.pop(request_id, None) + self._pending.pop(pending_key, None) send.close() receive.close() @@ -548,7 +575,7 @@ async def _dispatch_request( # TODO(maxisbey): duplicate ids blind-overwrite (v1/TS parity); revisit # rejecting with INVALID_REQUEST. Key coerced so a stringified # `notifications/cancelled` id still correlates. - self._in_flight[_coerce_id(req.id)] = _InFlight(scope=scope, dctx=dctx) + self._in_flight[coerce_request_id(req.id)] = _InFlight(scope=scope, dctx=dctx) if req.method in self._inline_methods: # Spawn so `sender_ctx` applies, but park the read loop until the # handler returns - that's the inline ordering guarantee. @@ -579,22 +606,17 @@ def _dispatch_notification( layer owns) and still teed to `on_notify` afterwards. """ if msg.method == "notifications/cancelled": - match msg.params: - # bool subclasses int: the guards keep True from aliasing request id 1. - case {"requestId": str() | int() as rid} if ( - not isinstance(rid, bool) and (in_flight := self._in_flight.get(_coerce_id(rid))) is not None - ): - in_flight.dctx.cancel_requested.set() - if self._peer_cancel_mode == "interrupt": - in_flight.scope.cancel() - case _: - pass + rid = cancelled_request_id_from_params(msg.params) + if rid is not None and (in_flight := self._in_flight.get(coerce_request_id(rid))) is not None: + in_flight.dctx.cancel_requested.set() + if self._peer_cancel_mode == "interrupt": + in_flight.scope.cancel() elif msg.method == "notifications/progress": match msg.params: case {"progressToken": str() | int() as token, "progress": int() | float() as progress} if ( not isinstance(token, bool) and not isinstance(progress, bool) - and (pending := self._pending.get(_coerce_id(token))) is not None + and (pending := self._pending.get(coerce_request_id(token))) is not None and pending.on_progress is not None ): total = msg.params.get("total") @@ -620,7 +642,7 @@ def _dispatch_notification( self._spawn(_contained_notify(on_notify), dctx, msg.method, msg.params, sender_ctx=sender_ctx) def _resolve_pending(self, request_id: RequestId | None, outcome: dict[str, Any] | ErrorData) -> None: - pending = self._pending.get(_coerce_id(request_id)) if request_id is not None else None + pending = self._pending.get(coerce_request_id(request_id)) if request_id is not None else None if pending is None: logger.debug("dropping response for unknown/late request id %r", request_id) return @@ -680,7 +702,7 @@ async def _handle_request( # since handler return, so a peer cancel can't interleave. # Identity guard: don't evict a duplicate id's newer entry. dctx.close() - key = _coerce_id(req.id) + key = coerce_request_id(req.id) if (entry := self._in_flight.get(key)) is not None and entry.dctx is dctx: del self._in_flight[key] # A write interrupted by cancellation may still have delivered diff --git a/tests/client/test_session.py b/tests/client/test_session.py index f76991f65..2a53f67ce 100644 --- a/tests/client/test_session.py +++ b/tests/client/test_session.py @@ -1330,43 +1330,43 @@ def test_adopt_raises_when_no_mutual_modern_version_is_supported() -> None: assert session.protocol_version is None -@pytest.mark.anyio -async def test_initialize_opts_out_of_cancel_on_abandon_while_other_requests_leave_it_unset(): - """`send_request` passes `cancel_on_abandon=False` for `initialize` — the spec forbids - cancelling it — and leaves the option unset for every other method.""" +class _OptsRecordingDispatcher: + """Records `send_raw_request` opts and answers from a per-method script (default `{}`).""" - class RecordingDispatcher: - """Records `send_raw_request` opts and answers with canned results.""" + def __init__(self, answers: dict[str, dict[str, Any]] | None = None) -> None: + self.calls: list[tuple[str, CallOptions]] = [] + self._answers = answers or {} - def __init__(self) -> None: - self.calls: list[tuple[str, CallOptions]] = [] + async def run( + self, + on_request: OnRequest, + on_notify: OnNotify, + *, + task_status: anyio.abc.TaskStatus[None] = anyio.TASK_STATUS_IGNORED, + ) -> None: + task_status.started() + await anyio.sleep_forever() - async def run( - self, - on_request: OnRequest, - on_notify: OnNotify, - *, - task_status: anyio.abc.TaskStatus[None] = anyio.TASK_STATUS_IGNORED, - ) -> None: - task_status.started() - await anyio.sleep_forever() + async def send_raw_request( + self, method: str, params: Mapping[str, Any] | None, opts: CallOptions | None = None + ) -> dict[str, Any]: + self.calls.append((method, opts or {})) + return self._answers.get(method, {}) - async def send_raw_request( - self, method: str, params: Mapping[str, Any] | None, opts: CallOptions | None = None - ) -> dict[str, Any]: - self.calls.append((method, opts or {})) - if method == "initialize": - return InitializeResult( - protocol_version=LATEST_HANDSHAKE_VERSION, - capabilities=ServerCapabilities(), - server_info=Implementation(name="mock-server", version="0.1.0"), - ).model_dump(by_alias=True, mode="json", exclude_none=True) - return {} + async def notify(self, method: str, params: Mapping[str, Any] | None, opts: CallOptions | None = None) -> None: + pass - async def notify(self, method: str, params: Mapping[str, Any] | None, opts: CallOptions | None = None) -> None: - pass - dispatcher = RecordingDispatcher() +@pytest.mark.anyio +async def test_initialize_opts_out_of_cancel_on_abandon_while_other_requests_leave_it_unset(): + """`send_request` passes `cancel_on_abandon=False` for `initialize` — the spec forbids + cancelling it — and leaves the option unset for every other method.""" + init_answer = InitializeResult( + protocol_version=LATEST_HANDSHAKE_VERSION, + capabilities=ServerCapabilities(), + server_info=Implementation(name="mock-server", version="0.1.0"), + ).model_dump(by_alias=True, mode="json", exclude_none=True) + dispatcher = _OptsRecordingDispatcher({"initialize": init_answer}) with anyio.fail_after(5): async with ClientSession(dispatcher=dispatcher) as session: await session.initialize() @@ -1376,6 +1376,27 @@ async def notify(self, method: str, params: Mapping[str, Any] | None, opts: Call assert "cancel_on_abandon" not in opts_by_method["ping"] +@pytest.mark.anyio +async def test_modern_stamp_leaves_cancel_on_abandon_at_the_dispatcher_default(): + """Post-adopt modern requests leave `cancel_on_abandon` unset (the dispatcher default, + True): the courtesy frame is the abandon signal — the 2026 cancellation spelling on + stream transports, and the streamable-HTTP transport's cue to abort the request's own + POST. The negotiation methods still opt out on every path: `send_discover`'s explicit + opts, and the stamp's own carve-out for a `server/discover` sent through the generic + `send_request`.""" + dispatcher = _OptsRecordingDispatcher({"server/discover": _discover_result_dict()}) + with anyio.fail_after(5): + async with ClientSession(dispatcher=dispatcher) as session: + await session.discover() + await session.send_ping() + await session.send_request(types.DiscoverRequest(params=types.RequestParams()), types.DiscoverResult) + assert [method for method, _ in dispatcher.calls] == ["server/discover", "ping", "server/discover"] + negotiation_opts, ping_opts, stamped_negotiation_opts = (opts for _, opts in dispatcher.calls) + assert negotiation_opts.get("cancel_on_abandon") is False + assert "cancel_on_abandon" not in ping_opts + assert stamped_negotiation_opts.get("cancel_on_abandon") is False + + def test_constructor_rejects_streams_and_dispatcher_together(): client_side, _server_side = create_direct_dispatcher_pair() s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](1) diff --git a/tests/client/test_streamable_http.py b/tests/client/test_streamable_http.py index 99ff6f03e..defda41f8 100644 --- a/tests/client/test_streamable_http.py +++ b/tests/client/test_streamable_http.py @@ -8,16 +8,37 @@ import base64 import json +from collections.abc import AsyncIterator, Callable, Mapping +from typing import Any import anyio import httpx import pytest from inline_snapshot import snapshot -from mcp_types import METHOD_NOT_FOUND, JSONRPCError, JSONRPCNotification, JSONRPCRequest, JSONRPCResponse +from mcp_types import ( + CLIENT_CAPABILITIES_META_KEY, + CLIENT_INFO_META_KEY, + METHOD_NOT_FOUND, + PROTOCOL_VERSION_META_KEY, + JSONRPCError, + JSONRPCNotification, + JSONRPCRequest, + JSONRPCResponse, +) +from mcp_types.version import LATEST_MODERN_VERSION +from starlette.types import Receive, Scope, Send from mcp.client.streamable_http import streamable_http_client -from mcp.shared.inbound import MCP_PROTOCOL_VERSION_HEADER, encode_header_value -from mcp.shared.message import ClientMessageMetadata, SessionMessage +from mcp.server import Server +from mcp.server._streamable_http_modern import handle_modern_request +from mcp.server.subscriptions import InMemorySubscriptionBus, ListenHandler, ServerEvent +from mcp.shared.dispatcher import CallOptions, DispatchContext +from mcp.shared.inbound import MCP_METHOD_HEADER, MCP_PROTOCOL_VERSION_HEADER, encode_header_value +from mcp.shared.jsonrpc_dispatcher import JSONRPCDispatcher +from mcp.shared.message import ClientMessageMetadata, ServerMessageMetadata, SessionMessage +from mcp.shared.transport_context import TransportContext +from tests.interaction.transports import StreamingASGITransport +from tests.shared.test_dispatcher import Recorder, echo_handlers @pytest.mark.parametrize( @@ -154,3 +175,411 @@ def handler(request: httpx.Request) -> httpx.Response: assert MCP_PROTOCOL_VERSION_HEADER not in recorded[1].headers assert recorded[2].headers[MCP_PROTOCOL_VERSION_HEADER] == "2025-11-25" assert recorded[3].headers[MCP_PROTOCOL_VERSION_HEADER] == "2025-11-25" + + +class _ParkedSSEStream(httpx.AsyncByteStream): + """An SSE response body that emits one comment line, then parks until closed. + + `opened` fires once the transport is iterating the body (the POST is truly in + flight); `closed` fires when httpx tears the body down — the observable proof + that an abort, not a response, ended the stream. + """ + + def __init__(self) -> None: + self.opened = anyio.Event() + self.closed = anyio.Event() + self._release = anyio.Event() + + async def __aiter__(self) -> AsyncIterator[bytes]: + self.opened.set() + yield b": parked\n\n" + await self._release.wait() + + async def aclose(self) -> None: + self.closed.set() + self._release.set() + + +def _sse_or_ack_handler( + parked: _ParkedSSEStream, posted: list[dict[str, Any]], frame_posted: anyio.Event +) -> Callable[[httpx.Request], httpx.Response]: + """Requests get the parked SSE body; notifications get 202 and set `frame_posted`.""" + + def handler(request: httpx.Request) -> httpx.Response: + body = json.loads(request.content) + posted.append(body) + if "id" in body: + return httpx.Response(200, headers={"content-type": "text/event-stream"}, stream=parked) + frame_posted.set() + return httpx.Response(202) + + return handler + + +@pytest.mark.anyio +async def test_modern_cancelled_frame_aborts_the_matching_in_flight_post() -> None: + """At 2026 an outbound `notifications/cancelled` never POSTs — closing the named + request's response stream IS the wire's cancellation signal — so the transport + aborts the in-flight POST and swallows the frame.""" + parked = _ParkedSSEStream() + posted: list[dict[str, Any]] = [] + + def handler(request: httpx.Request) -> httpx.Response: + posted.append(json.loads(request.content)) + return httpx.Response(200, headers={"content-type": "text/event-stream"}, stream=parked) + + with anyio.fail_after(5): + async with ( + httpx.AsyncClient(transport=httpx.MockTransport(handler)) as http, + streamable_http_client("http://test/mcp", http_client=http) as (_read, write), + ): + await write.send( + SessionMessage( + message=JSONRPCRequest(jsonrpc="2.0", id="listen-1", method="subscriptions/listen", params={}), + metadata=ClientMessageMetadata(headers={MCP_PROTOCOL_VERSION_HEADER: LATEST_MODERN_VERSION}), + ) + ) + await parked.opened.wait() + await write.send( + SessionMessage( + JSONRPCNotification( + jsonrpc="2.0", method="notifications/cancelled", params={"requestId": "listen-1"} + ) + ) + ) + await parked.closed.wait() + assert [body["method"] for body in posted] == ["subscriptions/listen"] + + +@pytest.mark.anyio +@pytest.mark.parametrize("stamped_version", [None, "2025-11-25"], ids=["no-version-yet", "2025-11-25"]) +async def test_legacy_cancelled_frame_posts_and_leaves_the_stream_open(stamped_version: str | None) -> None: + """Below 2026 — or before any stamped POST has revealed the version — the frame is + the spec's cancellation signal: it POSTs, and the request's stream stays open + (a 2025 disconnect is explicitly not a cancel).""" + parked = _ParkedSSEStream() + posted: list[dict[str, Any]] = [] + frame_posted = anyio.Event() + handler = _sse_or_ack_handler(parked, posted, frame_posted) + + with anyio.fail_after(5): + async with ( + httpx.AsyncClient(transport=httpx.MockTransport(handler)) as http, + streamable_http_client("http://test/mcp", http_client=http) as (_read, write), + ): + metadata = ( + ClientMessageMetadata(headers={MCP_PROTOCOL_VERSION_HEADER: stamped_version}) + if stamped_version is not None + else None + ) + await write.send( + SessionMessage( + message=JSONRPCRequest(jsonrpc="2.0", id=1, method="tools/call", params={}), + metadata=metadata, + ) + ) + await parked.opened.wait() + await write.send( + SessionMessage( + JSONRPCNotification(jsonrpc="2.0", method="notifications/cancelled", params={"requestId": 1}) + ) + ) + await frame_posted.wait() + # Checked before teardown: exiting the transport cancels the parked POST. + assert not parked.closed.is_set() + assert [body["method"] for body in posted] == ["tools/call", "notifications/cancelled"] + + +@pytest.mark.anyio +@pytest.mark.parametrize( + "params", + [ + pytest.param({"requestId": 999}, id="unknown-id"), + pytest.param({"requestId": True}, id="bool-must-not-alias-request-id-1"), + pytest.param({"requestId": "1"}, id="string-1-must-not-match-int-1"), + pytest.param({}, id="no-request-id"), + pytest.param(None, id="no-params"), + ], +) +async def test_modern_cancelled_frames_matching_no_post_are_swallowed(params: dict[str, Any] | None) -> None: + """At 2026 the frame is swallowed even when it aborts nothing — the wire defines no + client-to-server notifications, so a late cancel racing the response must not leak + a POST — and a mismatched id must not abort someone else's stream.""" + parked = _ParkedSSEStream() + posted: list[dict[str, Any]] = [] + + def handler(request: httpx.Request) -> httpx.Response: + body = json.loads(request.content) + posted.append(body) + if body.get("id") == 1: + return httpx.Response(200, headers={"content-type": "text/event-stream"}, stream=parked) + return httpx.Response(200, json={"jsonrpc": "2.0", "id": body["id"], "result": {}}) + + with anyio.fail_after(5): + async with ( + httpx.AsyncClient(transport=httpx.MockTransport(handler)) as http, + streamable_http_client("http://test/mcp", http_client=http) as (read, write), + ): + await write.send( + SessionMessage( + message=JSONRPCRequest(jsonrpc="2.0", id=1, method="subscriptions/listen", params={}), + metadata=ClientMessageMetadata(headers={MCP_PROTOCOL_VERSION_HEADER: LATEST_MODERN_VERSION}), + ) + ) + await parked.opened.wait() + await write.send( + SessionMessage(JSONRPCNotification(jsonrpc="2.0", method="notifications/cancelled", params=params)) + ) + # A follow-up request completing proves the loop moved past the swallowed frame. + await write.send(SessionMessage(JSONRPCRequest(jsonrpc="2.0", id=2, method="ping", params={}))) + reply = await read.receive() + # Checked before teardown: exiting the transport cancels the parked POST. + assert not parked.closed.is_set() + assert isinstance(reply, SessionMessage) + assert isinstance(reply.message, JSONRPCResponse) + assert reply.message.id == 2 + assert [body["method"] for body in posted] == ["subscriptions/listen", "ping"] + + +@pytest.mark.anyio +async def test_handler_scoped_cancelled_frames_are_translated_at_modern_too() -> None: + """A cancel carrying `ServerMessageMetadata` (a handler abandoning its own + back-channel request) still names one of OUR outbound ids — every spec-legal + cancel names a request its sender issued — so at 2026 it aborts that POST and + stays off the wire like any other.""" + parked = _ParkedSSEStream() + posted: list[dict[str, Any]] = [] + frame_posted = anyio.Event() + handler = _sse_or_ack_handler(parked, posted, frame_posted) + + with anyio.fail_after(5): + async with ( + httpx.AsyncClient(transport=httpx.MockTransport(handler)) as http, + streamable_http_client("http://test/mcp", http_client=http) as (_read, write), + ): + await write.send( + SessionMessage( + message=JSONRPCRequest(jsonrpc="2.0", id=1, method="tools/call", params={}), + metadata=ClientMessageMetadata(headers={MCP_PROTOCOL_VERSION_HEADER: LATEST_MODERN_VERSION}), + ) + ) + await parked.opened.wait() + await write.send( + SessionMessage( + message=JSONRPCNotification( + jsonrpc="2.0", method="notifications/cancelled", params={"requestId": 1} + ), + metadata=ServerMessageMetadata(related_request_id=99), + ) + ) + await parked.closed.wait() + assert [body["method"] for body in posted] == ["tools/call"] + assert not frame_posted.is_set() + + +@pytest.mark.anyio +async def test_cancel_for_a_request_sent_under_2025_still_posts_after_modern_adoption() -> None: + """The translation follows the era the NAMED request was sent under, not the + cache at cancel time: a request POSTed under 2025 keeps 2025 cancellation + semantics (frame on the wire, stream left open) even after a later message + flips the negotiated version to 2026.""" + parked = _ParkedSSEStream() + posted: list[dict[str, Any]] = [] + frame_posted = anyio.Event() + + def handler(request: httpx.Request) -> httpx.Response: + body = json.loads(request.content) + posted.append(body) + if body.get("id") == 1: + return httpx.Response(200, headers={"content-type": "text/event-stream"}, stream=parked) + if "id" in body: + return httpx.Response(200, json={"jsonrpc": "2.0", "id": body["id"], "result": {}}) + frame_posted.set() + return httpx.Response(202) + + with anyio.fail_after(5): + async with ( + httpx.AsyncClient(transport=httpx.MockTransport(handler)) as http, + streamable_http_client("http://test/mcp", http_client=http) as (read, write), + ): + await write.send( + SessionMessage( + message=JSONRPCRequest(jsonrpc="2.0", id=1, method="tools/call", params={}), + metadata=ClientMessageMetadata(headers={MCP_PROTOCOL_VERSION_HEADER: "2025-11-25"}), + ) + ) + await parked.opened.wait() + # A modern-stamped request flips the cached negotiated version. + await write.send( + SessionMessage( + message=JSONRPCRequest(jsonrpc="2.0", id=2, method="ping", params={}), + metadata=ClientMessageMetadata(headers={MCP_PROTOCOL_VERSION_HEADER: LATEST_MODERN_VERSION}), + ) + ) + reply = await read.receive() + assert isinstance(reply, SessionMessage) + assert isinstance(reply.message, JSONRPCResponse) + await write.send( + SessionMessage( + JSONRPCNotification(jsonrpc="2.0", method="notifications/cancelled", params={"requestId": 1}) + ) + ) + await frame_posted.wait() + # Checked before teardown: exiting the transport cancels the parked POST. + assert not parked.closed.is_set() + assert [body["method"] for body in posted] == ["tools/call", "ping", "notifications/cancelled"] + + +class _SignalingBus(InMemorySubscriptionBus): + """Signals subscribe/unsubscribe so a test observes the stream lifecycle through + the bus Protocol (the public seam) instead of polling handler internals.""" + + def __init__(self) -> None: + super().__init__() + self.subscribed = anyio.Event() + self.unsubscribed = anyio.Event() + + def subscribe(self, listener: Callable[[ServerEvent], None]) -> Callable[[], None]: + unsubscribe = super().subscribe(listener) + self.subscribed.set() + + def unsubscribe_and_signal() -> None: + unsubscribe() + self.unsubscribed.set() + + return unsubscribe_and_signal + + +@pytest.mark.anyio +async def test_scope_cancel_aborts_a_modern_listen_post_end_to_end() -> None: + """Over a real ASGI bridge: cancelling the caller of a parked `subscriptions/listen` + closes the POST's response stream — the server treats the disconnect as the cancel + and releases the subscription — and no `notifications/cancelled` crosses the wire.""" + bus = _SignalingBus() + server = Server("test", on_subscriptions_listen=ListenHandler(bus)) + + async def app(scope: Scope, receive: Receive, send: Send) -> None: + async with server.lifespan(server) as lifespan_state: + await handle_modern_request(server, None, False, lifespan_state, scope, receive, send) + + posted_methods: list[str] = [] + + async def record_request(request: httpx.Request) -> None: + posted_methods.append(json.loads(request.content)["method"]) + + acked = anyio.Event() + + async def on_notify(dctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None) -> None: + assert method == "notifications/subscriptions/acknowledged" + acked.set() + + on_request, _ = echo_handlers(Recorder()) + + with anyio.fail_after(15): + async with ( + httpx.AsyncClient( + transport=StreamingASGITransport(app), + base_url="http://testserver", + event_hooks={"request": [record_request]}, + ) as http, + streamable_http_client("http://testserver/mcp", http_client=http) as (read, write), + ): + dispatcher: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(read, write) + async with anyio.create_task_group() as tg: # pragma: no branch + await tg.start(dispatcher.run, on_request, on_notify) + listen_scope = anyio.CancelScope() + + async def send_listen() -> None: + params: dict[str, Any] = { + "_meta": { + PROTOCOL_VERSION_META_KEY: LATEST_MODERN_VERSION, + CLIENT_INFO_META_KEY: {"name": "test-client", "version": "0"}, + CLIENT_CAPABILITIES_META_KEY: {}, + }, + "notifications": {"toolsListChanged": True}, + } + opts: CallOptions = { + "request_id": "listen-1", + "headers": { + MCP_PROTOCOL_VERSION_HEADER: LATEST_MODERN_VERSION, + MCP_METHOD_HEADER: "subscriptions/listen", + }, + } + with listen_scope: + await dispatcher.send_raw_request("subscriptions/listen", params, opts) + + tg.start_soon(send_listen) + await acked.wait() + assert bus.subscribed.is_set() + assert not bus.unsubscribed.is_set() + listen_scope.cancel() + await bus.unsubscribed.wait() + tg.cancel_scope.cancel() + assert posted_methods == ["subscriptions/listen"] + + +class _CompletingSSEStream(httpx.AsyncByteStream): + """An SSE body that delivers one JSON-RPC response, then parks in `aclose`. + + Holding `aclose` keeps the finished POST task alive past its response, so a + test can re-register the same request id underneath it before releasing. + """ + + def __init__(self, response_body: dict[str, Any]) -> None: + self._event = f"data: {json.dumps(response_body)}\n\n".encode() + self.release = anyio.Event() + + async def __aiter__(self) -> AsyncIterator[bytes]: + yield self._event + + async def aclose(self) -> None: + await self.release.wait() + + +@pytest.mark.anyio +async def test_a_finished_post_task_does_not_evict_a_reused_ids_new_registration() -> None: + """Request ids are reusable once resolved; a finished POST task unwinding late + must not pop the successor's registration, or a cancel for the reused id would + find nothing to abort and the live POST would leak past the cancellation.""" + completing = _CompletingSSEStream({"jsonrpc": "2.0", "id": "dup-1", "result": {}}) + parked = _ParkedSSEStream() + posted: list[dict[str, Any]] = [] + streams = [completing, parked] + + def handler(request: httpx.Request) -> httpx.Response: + posted.append(json.loads(request.content)) + return httpx.Response(200, headers={"content-type": "text/event-stream"}, stream=streams.pop(0)) + + with anyio.fail_after(5): + async with ( + httpx.AsyncClient(transport=httpx.MockTransport(handler)) as http, + streamable_http_client("http://test/mcp", http_client=http) as (read, write), + ): + modern = ClientMessageMetadata(headers={MCP_PROTOCOL_VERSION_HEADER: LATEST_MODERN_VERSION}) + await write.send( + SessionMessage( + message=JSONRPCRequest(jsonrpc="2.0", id="dup-1", method="tools/call", params={}), + metadata=modern, + ) + ) + reply = await read.receive() + assert isinstance(reply, SessionMessage) + assert isinstance(reply.message, JSONRPCResponse) + # The first task is now parked in `aclose`; reuse its id underneath it. + await write.send( + SessionMessage( + message=JSONRPCRequest(jsonrpc="2.0", id="dup-1", method="subscriptions/listen", params={}), + metadata=modern, + ) + ) + await parked.opened.wait() + completing.release.set() + await anyio.wait_all_tasks_blocked() + # The successor's registration survived: a cancel still aborts it. + await write.send( + SessionMessage( + JSONRPCNotification(jsonrpc="2.0", method="notifications/cancelled", params={"requestId": "dup-1"}) + ) + ) + await parked.closed.wait() + assert [body["method"] for body in posted] == ["tools/call", "subscriptions/listen"] diff --git a/tests/shared/test_dispatcher.py b/tests/shared/test_dispatcher.py index 1f8208337..03ef27c8d 100644 --- a/tests/shared/test_dispatcher.py +++ b/tests/shared/test_dispatcher.py @@ -19,6 +19,7 @@ INVALID_REQUEST, REQUEST_TIMEOUT, ErrorData, + RequestId, Tool, ) @@ -396,6 +397,118 @@ async def test_direct_close_makes_run_return(): server.close() +@pytest.mark.anyio +async def test_send_raw_request_honors_caller_supplied_request_id_verbatim_typed(pair_factory: PairFactory): + """A caller-supplied `CallOptions["request_id"]` reaches the peer's context verbatim — + "7" stays a string, never the integer 7 — and the next call without one still mints + a dispatcher id as before.""" + async with running_pair(pair_factory) as (client, _server, _crec, srec): + with anyio.fail_after(5): + await client.send_raw_request("first", None, {"request_id": "7"}) + await client.send_raw_request("second", None) + supplied, minted = (ctx.request_id for ctx in srec.contexts) + assert supplied == "7" + assert type(supplied) is str + assert type(minted) is int + + +@pytest.mark.anyio +async def test_send_raw_request_with_in_flight_request_id_raises_and_frees_id_on_completion( + pair_factory: PairFactory, +): + """Reusing an id while it is in flight is a loud `ValueError` — silent reuse would + corrupt response correlation. Once the first request completes, the id is free + again: the reservation is in-flight-scoped, not permanent.""" + entered = anyio.Event() + release = anyio.Event() + + async def parked( + ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None + ) -> dict[str, Any]: + entered.set() + await release.wait() + return {"served": method} + + async with running_pair(pair_factory, server_on_request=parked) as (client, *_): + with anyio.fail_after(5): + async with anyio.create_task_group() as tg: + + async def first() -> None: + await client.send_raw_request("slow", None, {"request_id": "listen-1"}) + + tg.start_soon(first) + await entered.wait() + with pytest.raises(ValueError, match="already in flight"): + await client.send_raw_request("duplicate", None, {"request_id": "listen-1"}) + release.set() + result = await client.send_raw_request("again", None, {"request_id": "listen-1"}) + assert result == {"served": "again"} + + +@pytest.mark.anyio +async def test_minted_ids_skip_a_caller_supplied_id_still_in_flight(pair_factory: PairFactory): + """The dispatcher mints PAST a key a supplied id occupies — the collision error + is reserved for the caller who chose the id, never an innocent minted request.""" + entered = anyio.Event() + release = anyio.Event() + seen_ids: list[RequestId | None] = [] + + async def maybe_park( + ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None + ) -> dict[str, Any]: + seen_ids.append(ctx.request_id) + if method == "park": + entered.set() + await release.wait() + return {} + + async with running_pair(pair_factory, server_on_request=maybe_park) as (client, *_): + with anyio.fail_after(5): + async with anyio.create_task_group() as tg: + + async def parked() -> None: + await client.send_raw_request("park", None, {"request_id": "3"}) + + tg.start_soon(parked) + await entered.wait() + # The counter mints 1 and 2, then skips the occupied 3 to 4. + for _ in range(3): + await client.send_raw_request("plain", None) + release.set() + assert [request_id for request_id in seen_ids if request_id != "3"] == [1, 2, 4] + + +@pytest.mark.anyio +async def test_supplied_numeric_string_id_collides_with_its_int_twin(pair_factory: PairFactory): + """ "7" and 7 are one id in the collision domain on BOTH dispatchers, so the + in-memory pair raises exactly where the wire dispatcher (whose pending keys + are coerced for response correlation) would.""" + entered = anyio.Event() + release = anyio.Event() + + async def parked( + ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None + ) -> dict[str, Any]: + entered.set() + await release.wait() + return {} + + async with running_pair(pair_factory, server_on_request=parked) as (client, *_): + with anyio.fail_after(5): + async with anyio.create_task_group() as tg: + + async def first() -> None: + await client.send_raw_request("slow", None, {"request_id": 7}) + + tg.start_soon(first) + await entered.wait() + with pytest.raises(ValueError, match="already in flight"): + await client.send_raw_request("duplicate", None, {"request_id": "7"}) + release.set() + # Completion frees the id for either spelling. + assert await client.send_raw_request("again", None, {"request_id": "7"}) == {} + + if TYPE_CHECKING: _d: Dispatcher[TransportContext] = DirectDispatcher(TransportContext(kind="direct", can_send_request=True)) _o: Outbound = _d diff --git a/tests/shared/test_jsonrpc_dispatcher.py b/tests/shared/test_jsonrpc_dispatcher.py index 82d16bc4b..e91fc2de2 100644 --- a/tests/shared/test_jsonrpc_dispatcher.py +++ b/tests/shared/test_jsonrpc_dispatcher.py @@ -34,11 +34,10 @@ from mcp.server import Server, ServerRequestContext from mcp.shared._compat import resync_tracer from mcp.shared._context_streams import ContextReceiveStream, ContextSendStream -from mcp.shared.dispatcher import CallOptions, DispatchContext +from mcp.shared.dispatcher import CallOptions, DispatchContext, coerce_request_id from mcp.shared.exceptions import MCPError, NoBackChannelError from mcp.shared.jsonrpc_dispatcher import ( # pyright: ignore[reportPrivateUsage] JSONRPCDispatcher, - _coerce_id, _OutboundPlan, _Pending, _plan_outbound, @@ -1821,7 +1820,7 @@ async def respond_stringly() -> None: @pytest.mark.anyio async def test_error_response_with_string_id_correlates_to_int_keyed_pending_request(): - """A JSONRPCError echoing the request ID as a JSON string still resolves the waiter (same `_coerce_id` path).""" + """A JSONRPCError echoing the request ID as a JSON string still resolves the waiter (`coerce_request_id` path).""" c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) client: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(s2c_recv, c2s_send) @@ -1900,10 +1899,10 @@ async def on_progress(progress: float, total: float | None, message: str | None) assert seen == [0.5] -def test_coerce_id_passes_through_non_numeric_string_and_int(): - assert _coerce_id("7") == 7 - assert _coerce_id("not-an-int") == "not-an-int" - assert _coerce_id(42) == 42 +def test_coerce_request_id_passes_through_non_numeric_string_and_int(): + assert coerce_request_id("7") == 7 + assert coerce_request_id("not-an-int") == "not-an-int" + assert coerce_request_id(42) == 42 @pytest.mark.anyio @@ -2154,7 +2153,7 @@ async def on_notify(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> ids=["string-cancel-for-int-request", "int-cancel-for-string-request"], ) async def test_cancelled_correlates_across_string_and_int_request_id_forms(request_id: RequestId, cancel_id: object): - """A peer that stringifies the id between request and cancel still cancels (same `_coerce_id` path).""" + """A peer that stringifies the id between request and cancel still cancels (same `coerce_request_id` path).""" c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) server: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(c2s_recv, s2c_send) @@ -2381,3 +2380,38 @@ async def call() -> None: assert observed[0][0] == "notifications/cancelled" assert observed[0][1]["requestId"] == request_id assert observed[0][1]["reason"] == "user clicked stop" + + +@pytest.mark.anyio +async def test_send_raw_request_with_caller_supplied_string_id_is_verbatim_on_the_wire(): + """A supplied "7" goes on the wire as the string "7", and the response still + correlates when the peer echoes it back as the integer 7 — the pending key gets + the same coercion `_resolve_pending` applies to inbound ids.""" + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](4) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](4) + client: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(s2c_recv, c2s_send) + on_request, on_notify = echo_handlers(Recorder()) + result_box: list[dict[str, Any]] = [] + done = anyio.Event() + try: + with anyio.fail_after(5): + async with anyio.create_task_group() as tg: + + async def call() -> None: + result_box.append(await client.send_raw_request("tools/list", None, {"request_id": "7"})) + done.set() + + await tg.start(client.run, on_request, on_notify) + tg.start_soon(call) + wire = await c2s_recv.receive() + assert isinstance(wire, SessionMessage) + assert isinstance(wire.message, JSONRPCRequest) + assert wire.message.id == "7" + assert type(wire.message.id) is str + await s2c_send.send(SessionMessage(JSONRPCResponse(jsonrpc="2.0", id=7, result={"ok": True}))) + await done.wait() + tg.cancel_scope.cancel() + finally: + for stream in (c2s_send, c2s_recv, s2c_send, s2c_recv): + stream.close() + assert result_box == [{"ok": True}]