diff --git a/src/mcp/client/session_group.py b/src/mcp/client/session_group.py index 2c55bb775..da45923e2 100644 --- a/src/mcp/client/session_group.py +++ b/src/mcp/client/session_group.py @@ -11,20 +11,23 @@ import contextlib import logging from collections.abc import Callable +from dataclasses import dataclass from datetime import timedelta from types import TracebackType -from typing import Any, TypeAlias +from typing import Any, TypeAlias, overload import anyio from pydantic import BaseModel -from typing_extensions import Self +from typing_extensions import Self, deprecated import mcp from mcp import types +from mcp.client.session import ElicitationFnT, ListRootsFnT, LoggingFnT, MessageHandlerFnT, SamplingFnT from mcp.client.sse import sse_client from mcp.client.stdio import StdioServerParameters from mcp.client.streamable_http import streamablehttp_client from mcp.shared.exceptions import McpError +from mcp.shared.session import ProgressFnT class SseServerParameters(BaseModel): @@ -65,6 +68,21 @@ class StreamableHttpParameters(BaseModel): ServerParameters: TypeAlias = StdioServerParameters | SseServerParameters | StreamableHttpParameters +# Use dataclass instead of pydantic BaseModel +# because pydantic BaseModel cannot handle Protocol fields. +@dataclass +class ClientSessionParameters: + """Parameters for establishing a client session to an MCP server.""" + + read_timeout_seconds: timedelta | None = None + sampling_callback: SamplingFnT | None = None + elicitation_callback: ElicitationFnT | None = None + list_roots_callback: ListRootsFnT | None = None + logging_callback: LoggingFnT | None = None + message_handler: MessageHandlerFnT | None = None + client_info: types.Implementation | None = None + + class ClientSessionGroup: """Client for managing connections to multiple MCP servers. @@ -172,11 +190,49 @@ def tools(self) -> dict[str, types.Tool]: """Returns the tools as a dictionary of names to tools.""" return self._tools - async def call_tool(self, name: str, args: dict[str, Any]) -> types.CallToolResult: + @overload + async def call_tool( + self, + name: str, + arguments: dict[str, Any], + read_timeout_seconds: timedelta | None = None, + progress_callback: ProgressFnT | None = None, + *, + meta: dict[str, Any] | None = None, + ) -> types.CallToolResult: ... + + @overload + @deprecated("The 'args' parameter is deprecated. Use 'arguments' instead.") + async def call_tool( + self, + name: str, + *, + args: dict[str, Any], + read_timeout_seconds: timedelta | None = None, + progress_callback: ProgressFnT | None = None, + meta: dict[str, Any] | None = None, + ) -> types.CallToolResult: ... + + async def call_tool( + self, + name: str, + arguments: dict[str, Any] | None = None, + read_timeout_seconds: timedelta | None = None, + progress_callback: ProgressFnT | None = None, + *, + meta: dict[str, Any] | None = None, + args: dict[str, Any] | None = None, + ) -> types.CallToolResult: """Executes a tool given its name and arguments.""" session = self._tool_to_session[name] session_tool_name = self.tools[name].name - return await session.call_tool(session_tool_name, args) + return await session.call_tool( + session_tool_name, + arguments if args is None else args, + read_timeout_seconds=read_timeout_seconds, + progress_callback=progress_callback, + meta=meta, + ) async def disconnect_from_server(self, session: mcp.ClientSession) -> None: """Disconnects from a single MCP server.""" @@ -225,13 +281,16 @@ async def connect_with_session( async def connect_to_server( self, server_params: ServerParameters, + session_params: ClientSessionParameters | None = None, ) -> mcp.ClientSession: """Connects to a single MCP server.""" - server_info, session = await self._establish_session(server_params) + server_info, session = await self._establish_session(server_params, session_params or ClientSessionParameters()) return await self.connect_with_session(server_info, session) async def _establish_session( - self, server_params: ServerParameters + self, + server_params: ServerParameters, + session_params: ClientSessionParameters, ) -> tuple[types.Implementation, mcp.ClientSession]: """Establish a client session to an MCP server.""" @@ -259,7 +318,20 @@ async def _establish_session( ) read, write, _ = await session_stack.enter_async_context(client) - session = await session_stack.enter_async_context(mcp.ClientSession(read, write)) + session = await session_stack.enter_async_context( + mcp.ClientSession( + read, + write, + read_timeout_seconds=session_params.read_timeout_seconds, + sampling_callback=session_params.sampling_callback, + elicitation_callback=session_params.elicitation_callback, + list_roots_callback=session_params.list_roots_callback, + logging_callback=session_params.logging_callback, + message_handler=session_params.message_handler, + client_info=session_params.client_info, + ) + ) + result = await session.initialize() # Session successfully initialized. diff --git a/tests/client/test_session_group.py b/tests/client/test_session_group.py index 3a19cff68..e61ea572b 100644 --- a/tests/client/test_session_group.py +++ b/tests/client/test_session_group.py @@ -5,7 +5,12 @@ import mcp from mcp import types -from mcp.client.session_group import ClientSessionGroup, SseServerParameters, StreamableHttpParameters +from mcp.client.session_group import ( + ClientSessionGroup, + ClientSessionParameters, + SseServerParameters, + StreamableHttpParameters, +) from mcp.client.stdio import StdioServerParameters from mcp.shared.exceptions import McpError @@ -62,7 +67,7 @@ def hook(name: str, server_info: types.Implementation) -> str: # pragma: no cov # --- Test Execution --- result = await mcp_session_group.call_tool( name="server1-my_tool", - args={ + arguments={ "name": "value1", "args": {}, }, @@ -73,6 +78,9 @@ def hook(name: str, server_info: types.Implementation) -> str: # pragma: no cov mock_session.call_tool.assert_called_once_with( "my_tool", {"name": "value1", "args": {}}, + read_timeout_seconds=None, + progress_callback=None, + meta=None, ) async def test_connect_to_server(self, mock_exit_stack: contextlib.AsyncExitStack): @@ -329,7 +337,7 @@ async def test_establish_session_parameterized( ( returned_server_info, returned_session, - ) = await group._establish_session(server_params_instance) + ) = await group._establish_session(server_params_instance, ClientSessionParameters()) # --- Assertions --- # 1. Assert the correct specific client function was called @@ -357,7 +365,17 @@ async def test_establish_session_parameterized( mock_client_cm_instance.__aenter__.assert_awaited_once() # 2. Assert ClientSession was called correctly - mock_ClientSession_class.assert_called_once_with(mock_read_stream, mock_write_stream) + mock_ClientSession_class.assert_called_once_with( + mock_read_stream, + mock_write_stream, + read_timeout_seconds=None, + sampling_callback=None, + elicitation_callback=None, + list_roots_callback=None, + logging_callback=None, + message_handler=None, + client_info=None, + ) mock_raw_session_cm.__aenter__.assert_awaited_once() mock_entered_session.initialize.assert_awaited_once()