Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 48 additions & 5 deletions src/mcp/shared/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from mcp.shared.response_router import ResponseRouter
from mcp.types import (
CONNECTION_CLOSED,
INTERNAL_ERROR,
INVALID_PARAMS,
CancelledNotification,
ClientNotification,
Expand Down Expand Up @@ -237,6 +238,34 @@ async def __aexit__(
self._task_group.cancel_scope.cancel()
return await self._task_group.__aexit__(exc_type, exc_val, exc_tb)

@staticmethod
def _process_response(
response_or_error: JSONRPCResponse | JSONRPCError | None,
result_type: type[ReceiveResultT],
) -> ReceiveResultT:
"""
Process a JSON-RPC response, validating and returning the result.

Raises McpError if the response is an error or if response_or_error is None.
The None check is a defensive guard against anyio race conditions - see #1717.
"""
if response_or_error is None:
# Defensive check for anyio fail_after race condition (#1717).
# If anyio's CancelScope incorrectly suppresses an exception,
# the response variable may never be assigned. See:
# https://github.com/agronholm/anyio/issues/589
raise McpError(
ErrorData(
code=INTERNAL_ERROR,
message="Internal error: no response received",
)
)

if isinstance(response_or_error, JSONRPCError):
raise McpError(response_or_error.error)

return result_type.model_validate(response_or_error.result)

async def send_request(
self,
request: SendRequestT,
Expand Down Expand Up @@ -287,6 +316,10 @@ async def send_request(
elif self._session_read_timeout_seconds is not None: # pragma: no cover
timeout = self._session_read_timeout_seconds.total_seconds()

# Initialize to None as a defensive guard against anyio race conditions
# where fail_after may incorrectly suppress exceptions (#1717)
response_or_error: JSONRPCResponse | JSONRPCError | None = None

try:
with anyio.fail_after(timeout):
response_or_error = await response_stream_reader.receive()
Expand All @@ -301,12 +334,22 @@ async def send_request(
),
)
)

if isinstance(response_or_error, JSONRPCError):
raise McpError(response_or_error.error)
except anyio.EndOfStream:
raise McpError(
ErrorData(
code=CONNECTION_CLOSED,
message="Connection closed: stream ended unexpectedly",
)
)
except anyio.ClosedResourceError:
raise McpError(
ErrorData(
code=CONNECTION_CLOSED,
message="Connection closed",
)
)
else:
return result_type.model_validate(response_or_error.result)

return self._process_response(response_or_error, result_type)
finally:
self._response_streams.pop(request_id, None)
self._progress_callbacks.pop(request_id, None)
Expand Down
115 changes: 115 additions & 0 deletions tests/shared/test_session.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from collections.abc import AsyncGenerator
from typing import Any
from unittest.mock import AsyncMock, patch

import anyio
import pytest
Expand All @@ -9,12 +10,18 @@
from mcp.server.lowlevel.server import Server
from mcp.shared.exceptions import McpError
from mcp.shared.memory import create_client_server_memory_streams, create_connected_server_and_client_session
from mcp.shared.session import BaseSession
from mcp.types import (
CONNECTION_CLOSED,
INTERNAL_ERROR,
CancelledNotification,
CancelledNotificationParams,
ClientNotification,
ClientRequest,
EmptyResult,
ErrorData,
JSONRPCError,
JSONRPCResponse,
TextContent,
)

Expand Down Expand Up @@ -168,3 +175,111 @@ async def mock_server():
await ev_closed.wait()
with anyio.fail_after(1):
await ev_response.wait()


class TestProcessResponse:
"""Tests for BaseSession._process_response static method."""

def test_process_response_with_valid_response(self):
"""Test that a valid JSONRPCResponse is processed correctly."""
response = JSONRPCResponse(
jsonrpc="2.0",
id=1,
result={},
)

result = BaseSession._process_response(response, EmptyResult)

assert isinstance(result, EmptyResult)

def test_process_response_with_error(self):
"""Test that a JSONRPCError raises McpError."""
error = JSONRPCError(
jsonrpc="2.0",
id=1,
error=ErrorData(code=-32600, message="Invalid request"),
)

with pytest.raises(McpError) as exc_info:
BaseSession._process_response(error, EmptyResult)

assert exc_info.value.error.code == -32600
assert exc_info.value.error.message == "Invalid request"

def test_process_response_with_none(self):
"""
Test defensive check for anyio fail_after race condition (#1717).

If anyio's CancelScope incorrectly suppresses an exception during
receive(), the response variable may never be assigned. This test
verifies we handle this gracefully instead of raising UnboundLocalError.

See: https://github.com/agronholm/anyio/issues/589
"""
with pytest.raises(McpError) as exc_info:
BaseSession._process_response(None, EmptyResult)

assert exc_info.value.error.code == INTERNAL_ERROR
assert "no response received" in exc_info.value.error.message


@pytest.mark.anyio
async def test_send_request_handles_end_of_stream():
"""Test that EndOfStream from response stream raises McpError with CONNECTION_CLOSED."""

async with create_client_server_memory_streams() as (client_streams, _):
client_read, client_write = client_streams

async with ClientSession(read_stream=client_read, write_stream=client_write) as client_session:
# Mock create_memory_object_stream to return a stream that raises EndOfStream
mock_reader = AsyncMock()
mock_reader.receive = AsyncMock(side_effect=anyio.EndOfStream)
mock_reader.aclose = AsyncMock()

mock_sender = AsyncMock()
mock_sender.aclose = AsyncMock()

# The subscripted form returns a callable that returns the tuple
with patch("mcp.shared.session.anyio.create_memory_object_stream") as mock_create:
# pyright: ignore[reportUnknownLambdaType]
mock_create.__getitem__ = lambda _s, _k: lambda _z: (mock_sender, mock_reader) # type: ignore

with pytest.raises(McpError) as exc_info:
await client_session.send_request(
ClientRequest(types.PingRequest()),
EmptyResult,
)

assert exc_info.value.error.code == CONNECTION_CLOSED
assert "stream ended unexpectedly" in exc_info.value.error.message


@pytest.mark.anyio
async def test_send_request_handles_closed_resource_error():
"""Test that ClosedResourceError from response stream raises McpError with CONNECTION_CLOSED."""

async with create_client_server_memory_streams() as (client_streams, _):
client_read, client_write = client_streams

async with ClientSession(read_stream=client_read, write_stream=client_write) as client_session:
# Mock create_memory_object_stream to return a stream that raises ClosedResourceError
mock_reader = AsyncMock()
mock_reader.receive = AsyncMock(side_effect=anyio.ClosedResourceError)
mock_reader.aclose = AsyncMock()

mock_sender = AsyncMock()
mock_sender.aclose = AsyncMock()

# The subscripted form returns a callable that returns the tuple
with patch("mcp.shared.session.anyio.create_memory_object_stream") as mock_create:
# pyright: ignore[reportUnknownLambdaType]
mock_create.__getitem__ = lambda _s, _k: lambda _z: (mock_sender, mock_reader) # type: ignore

with pytest.raises(McpError) as exc_info:
await client_session.send_request(
ClientRequest(types.PingRequest()),
EmptyResult,
)

assert exc_info.value.error.code == CONNECTION_CLOSED
assert "Connection closed" in exc_info.value.error.message