diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 243eef5ae..64bacccd1 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -8,6 +8,7 @@ from typing import Any, Generic, Protocol, TypeVar import anyio +from anyio.abc import TaskGroup from anyio.streams.memory import MemoryObjectSendStream from opentelemetry.trace import SpanKind from pydantic import BaseModel, TypeAdapter @@ -201,6 +202,15 @@ def __init__( self._progress_callbacks = {} self._response_routers = [] self._exit_stack = AsyncExitStack() + self._task_group: TaskGroup = anyio.create_task_group() + self._started = False + + def _require_started(self) -> None: + if not self._started: + raise RuntimeError( + "Session is not running. Use it as an async context manager " + "(e.g. `async with ClientSession(...) as session:`)." + ) def add_response_router(self, router: ResponseRouter) -> None: """Register a response router to handle responses for non-standard requests. @@ -218,8 +228,11 @@ def add_response_router(self, router: ResponseRouter) -> None: self._response_routers.append(router) async def __aenter__(self) -> Self: + if self._started: + raise RuntimeError("Session is already running") self._task_group = anyio.create_task_group() await self._task_group.__aenter__() + self._started = True self._task_group.start_soon(self._receive_loop) return self @@ -234,7 +247,10 @@ async def __aexit__( # would be very surprising behavior), so make sure to cancel the tasks # in the task group. self._task_group.cancel_scope.cancel() - return await self._task_group.__aexit__(exc_type, exc_val, exc_tb) + try: + return await self._task_group.__aexit__(exc_type, exc_val, exc_tb) + finally: + self._started = False async def send_request( self, @@ -251,6 +267,7 @@ async def send_request( Do not use this method to emit notifications! Use send_notification() instead. """ + self._require_started() request_id = self._request_id self._request_id = request_id + 1 @@ -313,6 +330,7 @@ async def send_notification( related_request_id: RequestId | None = None, ) -> None: """Emits a notification, which is a one-way message that does not expect a response.""" + self._require_started() # Some transport implementations may need to set the related_request_id # to attribute to the notifications to the request that triggered them. jsonrpc_notification = JSONRPCNotification( diff --git a/tests/client/test_resource_cleanup.py b/tests/client/test_resource_cleanup.py index c7bf8fafa..eafb994db 100644 --- a/tests/client/test_resource_cleanup.py +++ b/tests/client/test_resource_cleanup.py @@ -50,9 +50,10 @@ async def mock_send(*args: Any, **kwargs: Any): initial_stream_count = len(session._response_streams) # Run the test with the patched method - with patch.object(session._write_stream, "send", mock_send): - with pytest.raises(RuntimeError): - await session.send_request(request, EmptyResult) + async with session: + with patch.object(session._write_stream, "send", mock_send): + with pytest.raises(RuntimeError, match="Simulated network error"): # pragma: no branch + await session.send_request(request, EmptyResult) # Verify that no response streams were leaked assert len(session._response_streams) == initial_stream_count, ( diff --git a/tests/client/test_session.py b/tests/client/test_session.py index f25c964f0..8c3cd9a6a 100644 --- a/tests/client/test_session.py +++ b/tests/client/test_session.py @@ -110,6 +110,43 @@ async def message_handler( # pragma: no cover assert isinstance(initialized_notification, InitializedNotification) +@pytest.mark.anyio +async def test_client_session_requires_context_manager(): + client_to_server_send, _client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) + _server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) + + async with ( + client_to_server_send, + _client_to_server_receive, + _server_to_client_send, + server_to_client_receive, + ): + session = ClientSession(server_to_client_receive, client_to_server_send) + + with pytest.raises(RuntimeError, match="async context manager"): + await session.initialize() + + +@pytest.mark.anyio +async def test_client_session_reentry_raises_runtime_error(): + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) + + async with ( + client_to_server_send, + client_to_server_receive, + server_to_client_send, + server_to_client_receive, + ): + session = ClientSession(server_to_client_receive, client_to_server_send) + await session.__aenter__() + try: + with pytest.raises(RuntimeError, match="already running"): + await session.__aenter__() + finally: + await session.__aexit__(None, None, None) + + @pytest.mark.anyio async def test_client_session_custom_client_info(): client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1)