diff --git a/README.md b/README.md index 5dbc4bd9d..8abcfe877 100644 --- a/README.md +++ b/README.md @@ -2153,6 +2153,56 @@ if __name__ == "__main__": _Full example: [examples/snippets/clients/streamable_basic.py](https://github.com/modelcontextprotocol/python-sdk/blob/main/examples/snippets/clients/streamable_basic.py)_ +### Handling Server Notifications + +Servers may send notifications, which derive from the `ServerNotification` class. To handle these, follow the following steps: + +1. For each notification type you want to support, write a callback function that follows implements the matching protocol, such as `ToolListChangedFnT` for the tool list changed notification. +2. Pass that function to the appropriate parameter when instantiating your client, e.g. `tool_list_changed_callback` for the tool list changed notification. This will be called every time your client receives the matching notification. + + +```python +# Snippets demonstrating handling known and custom server notifications + +import asyncio + +from mcp import ClientSession, StdioServerParameters +from mcp.client.stdio import stdio_client + +# Create dummy server parameters for stdio connection +server_params = StdioServerParameters( + command="uv", + args=["run"], + env={}, +) + + +# Create a custom handler for the resource list changed notification +async def custom_resource_list_changed_handler() -> None: + """Custom handler for resource list changed notifications.""" + print("RESOURCE LIST CHANGED") + + +async def run(): + async with stdio_client(server_params) as (read, write): + async with ClientSession( + read, + write, + resource_list_changed_callback=custom_resource_list_changed_handler, + ) as session: + # Initialize the connection + await session.initialize() + + # Do client stuff here + + +if __name__ == "__main__": + asyncio.run(run()) +``` + +_Full example: [examples/snippets/clients/server_notification_client.py](https://github.com/modelcontextprotocol/python-sdk/blob/main/examples/snippets/clients/server_notification_client.py)_ + + ### Client Display Utilities When building MCP clients, the SDK provides utilities to help display human-readable names for tools, resources, and prompts: diff --git a/examples/snippets/clients/server_notification_client.py b/examples/snippets/clients/server_notification_client.py new file mode 100644 index 000000000..a51277cf8 --- /dev/null +++ b/examples/snippets/clients/server_notification_client.py @@ -0,0 +1,36 @@ +# Snippets demonstrating handling known and custom server notifications + +import asyncio + +from mcp import ClientSession, StdioServerParameters +from mcp.client.stdio import stdio_client + +# Create dummy server parameters for stdio connection +server_params = StdioServerParameters( + command="uv", + args=["run"], + env={}, +) + + +# Create a custom handler for the resource list changed notification +async def custom_resource_list_changed_handler() -> None: + """Custom handler for resource list changed notifications.""" + print("RESOURCE LIST CHANGED") + + +async def run(): + async with stdio_client(server_params) as (read, write): + async with ClientSession( + read, + write, + resource_list_changed_callback=custom_resource_list_changed_handler, + ) as session: + # Initialize the connection + await session.initialize() + + # Do client stuff here + + +if __name__ == "__main__": + asyncio.run(run()) diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 3835a2a57..06a345100 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -48,6 +48,38 @@ async def __call__( ) -> None: ... # pragma: no branch +class ProgressNotificationFnT(Protocol): + async def __call__( + self, + params: types.ProgressNotificationParams, + ) -> None: ... + + +class ResourceUpdatedFnT(Protocol): + async def __call__( + self, + params: types.ResourceUpdatedNotificationParams, + ) -> None: ... + + +class ResourceListChangedFnT(Protocol): + async def __call__( + self, + ) -> None: ... + + +class ToolListChangedFnT(Protocol): + async def __call__( + self, + ) -> None: ... + + +class PromptListChangedFnT(Protocol): + async def __call__( + self, + ) -> None: ... + + class MessageHandlerFnT(Protocol): async def __call__( self, @@ -96,6 +128,32 @@ async def _default_logging_callback( pass +async def _default_progress_callback( + params: types.ProgressNotificationParams, +) -> None: + """Note: Default progress handling happens in the BaseSession class. This callback will only be called after the + default progress handling has completed.""" + pass + + +async def _default_resource_updated_callback( + params: types.ResourceUpdatedNotificationParams, +) -> None: + pass + + +async def _default_resource_list_changed_callback() -> None: + pass + + +async def _default_tool_list_changed_callback() -> None: + pass + + +async def _default_prompt_list_changed_callback() -> None: + pass + + ClientResponse: TypeAdapter[types.ClientResult | types.ErrorData] = TypeAdapter(types.ClientResult | types.ErrorData) @@ -117,6 +175,11 @@ def __init__( elicitation_callback: ElicitationFnT | None = None, list_roots_callback: ListRootsFnT | None = None, logging_callback: LoggingFnT | None = None, + progress_notification_callback: ProgressNotificationFnT | None = None, + resource_updated_callback: ResourceUpdatedFnT | None = None, + resource_list_changed_callback: ResourceListChangedFnT | None = None, + tool_list_changed_callback: ToolListChangedFnT | None = None, + prompt_list_changed_callback: PromptListChangedFnT | None = None, message_handler: MessageHandlerFnT | None = None, client_info: types.Implementation | None = None, ) -> None: @@ -132,6 +195,11 @@ def __init__( self._elicitation_callback = elicitation_callback or _default_elicitation_callback self._list_roots_callback = list_roots_callback or _default_list_roots_callback self._logging_callback = logging_callback or _default_logging_callback + self._progress_notification_callback = progress_notification_callback or _default_progress_callback + self._resource_updated_callback = resource_updated_callback or _default_resource_updated_callback + self._resource_list_changed_callback = resource_list_changed_callback or _default_resource_list_changed_callback + self._tool_list_changed_callback = tool_list_changed_callback or _default_tool_list_changed_callback + self._prompt_list_changed_callback = prompt_list_changed_callback or _default_prompt_list_changed_callback self._message_handler = message_handler or _default_message_handler self._tool_output_schemas: dict[str, dict[str, Any] | None] = {} self._server_capabilities: types.ServerCapabilities | None = None @@ -547,9 +615,20 @@ async def _handle_incoming( async def _received_notification(self, notification: types.ServerNotification) -> None: """Handle notifications from the server.""" - # Process specific notification types match notification.root: case types.LoggingMessageNotification(params=params): await self._logging_callback(params) + case types.ProgressNotification(params=params): + await self._progress_notification_callback(params) + case types.ResourceUpdatedNotification(params=params): + await self._resource_updated_callback(params) + case types.ResourceListChangedNotification(): + await self._resource_list_changed_callback() + case types.ToolListChangedNotification(): + await self._tool_list_changed_callback() + case types.PromptListChangedNotification(): + await self._prompt_list_changed_callback() case _: + # CancelledNotification is handled separately in shared/session.py + # and should never reach this point. This case is defensive. pass diff --git a/src/mcp/shared/memory.py b/src/mcp/shared/memory.py index 06d404e31..53ab3b107 100644 --- a/src/mcp/shared/memory.py +++ b/src/mcp/shared/memory.py @@ -13,7 +13,19 @@ from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream import mcp.types as types -from mcp.client.session import ClientSession, ElicitationFnT, ListRootsFnT, LoggingFnT, MessageHandlerFnT, SamplingFnT +from mcp.client.session import ( + ClientSession, + ElicitationFnT, + ListRootsFnT, + LoggingFnT, + MessageHandlerFnT, + ProgressNotificationFnT, + PromptListChangedFnT, + ResourceListChangedFnT, + ResourceUpdatedFnT, + SamplingFnT, + ToolListChangedFnT, +) from mcp.server import Server from mcp.server.fastmcp import FastMCP from mcp.shared.message import SessionMessage @@ -53,6 +65,11 @@ async def create_connected_server_and_client_session( sampling_callback: SamplingFnT | None = None, list_roots_callback: ListRootsFnT | None = None, logging_callback: LoggingFnT | None = None, + progress_notification_callback: ProgressNotificationFnT | None = None, + resource_updated_callback: ResourceUpdatedFnT | None = None, + resource_list_changed_callback: ResourceListChangedFnT | None = None, + tool_list_changed_callback: ToolListChangedFnT | None = None, + prompt_list_changed_callback: PromptListChangedFnT | None = None, message_handler: MessageHandlerFnT | None = None, client_info: types.Implementation | None = None, raise_exceptions: bool = False, @@ -88,6 +105,11 @@ async def create_connected_server_and_client_session( sampling_callback=sampling_callback, list_roots_callback=list_roots_callback, logging_callback=logging_callback, + progress_notification_callback=progress_notification_callback, + resource_updated_callback=resource_updated_callback, + resource_list_changed_callback=resource_list_changed_callback, + tool_list_changed_callback=tool_list_changed_callback, + prompt_list_changed_callback=prompt_list_changed_callback, message_handler=message_handler, client_info=client_info, elicitation_callback=elicitation_callback, diff --git a/tests/client/test_notification_callbacks.py b/tests/client/test_notification_callbacks.py new file mode 100644 index 000000000..ff24a0c3e --- /dev/null +++ b/tests/client/test_notification_callbacks.py @@ -0,0 +1,508 @@ +""" +Tests for client notification callbacks. + +This module tests all notification types that can be sent from the server to the client, +ensuring that the callback mechanism works correctly for each notification type. +""" + +from collections.abc import Callable +from typing import TYPE_CHECKING, Any + +import pytest +from pydantic import AnyUrl + +import mcp.types as types +from mcp.shared.memory import ( + create_connected_server_and_client_session as create_session, +) +from mcp.shared.session import RequestResponder +from mcp.types import TextContent + +if TYPE_CHECKING: + from _pytest.fixtures import FixtureRequest + + +class ProgressNotificationCollector: + """Collector for ProgressNotification events.""" + + def __init__(self) -> None: + """Initialize the collector.""" + self.notifications: list[types.ProgressNotificationParams] = [] + + async def __call__(self, params: types.ProgressNotificationParams) -> None: + """Collect a progress notification.""" + self.notifications.append(params) + + +class ResourceUpdatedCollector: + """Collector for ResourceUpdatedNotification events.""" + + def __init__(self) -> None: + """Initialize the collector.""" + self.notifications: list[types.ResourceUpdatedNotificationParams] = [] + + async def __call__(self, params: types.ResourceUpdatedNotificationParams) -> None: + """Collect a resource updated notification.""" + self.notifications.append(params) + + +class ResourceListChangedCollector: + """Collector for ResourceListChangedNotification events.""" + + def __init__(self) -> None: + """Initialize the collector.""" + self.notification_count: int = 0 + + async def __call__(self) -> None: + """Collect a resource list changed notification.""" + self.notification_count += 1 + + +class ToolListChangedCollector: + """Collector for ToolListChangedNotification events.""" + + def __init__(self) -> None: + """Initialize the collector.""" + self.notification_count: int = 0 + + async def __call__(self) -> None: + """Collect a tool list changed notification.""" + self.notification_count += 1 + + +class PromptListChangedCollector: + """Collector for PromptListChangedNotification events.""" + + def __init__(self) -> None: + """Initialize the collector.""" + self.notification_count: int = 0 + + async def __call__(self) -> None: + """Collect a prompt list changed notification.""" + self.notification_count += 1 + + +@pytest.fixture +def progress_collector() -> ProgressNotificationCollector: + """Create a progress notification collector.""" + return ProgressNotificationCollector() + + +@pytest.fixture +def resource_updated_collector() -> ResourceUpdatedCollector: + """Create a resource updated collector.""" + return ResourceUpdatedCollector() + + +@pytest.fixture +def resource_list_changed_collector() -> ResourceListChangedCollector: + """Create a resource list changed collector.""" + return ResourceListChangedCollector() + + +@pytest.fixture +def tool_list_changed_collector() -> ToolListChangedCollector: + """Create a tool list changed collector.""" + return ToolListChangedCollector() + + +@pytest.fixture +def prompt_list_changed_collector() -> PromptListChangedCollector: + """Create a prompt list changed collector.""" + return PromptListChangedCollector() + + +@pytest.mark.anyio +async def test_progress_notification_callback(progress_collector: ProgressNotificationCollector) -> None: + """Test that progress notifications are correctly received by the callback.""" + from mcp.server.fastmcp import FastMCP + + server = FastMCP("test") + + @server.tool("send_progress") + async def send_progress_tool(progress: float, total: float, message: str) -> bool: + """Send a progress notification to the client.""" + # Get the progress token from the request metadata + ctx = server.get_context() + if ctx.request_context.meta and ctx.request_context.meta.progressToken: + await ctx.session.send_progress_notification( + progress_token=ctx.request_context.meta.progressToken, + progress=progress, + total=total, + message=message, + ) + return True + + async def message_handler( + message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, + ) -> None: + """Handle exceptions from the session.""" + if isinstance(message, Exception): + raise message + + async with create_session( + server._mcp_server, + progress_notification_callback=progress_collector, + message_handler=message_handler, + ) as client_session: + # Call tool with progress token in metadata + result = await client_session.call_tool( + "send_progress", + {"progress": 50.0, "total": 100.0, "message": "Halfway there"}, + meta={"progressToken": "test-token-123"}, + ) + assert result.isError is False + + # Verify the progress notification was received + assert len(progress_collector.notifications) == 1 + notification = progress_collector.notifications[0] + assert notification.progressToken == "test-token-123" + assert notification.progress == 50.0 + assert notification.total == 100.0 + assert notification.message == "Halfway there" + + +@pytest.mark.anyio +async def test_resource_updated_callback(resource_updated_collector: ResourceUpdatedCollector) -> None: + """Test that resource updated notifications are correctly received by the callback.""" + from mcp.server.fastmcp import FastMCP + + server = FastMCP("test") + + @server.tool("update_resource") + async def update_resource_tool(uri: str) -> bool: + """Send a resource updated notification to the client.""" + await server.get_context().session.send_resource_updated(AnyUrl(uri)) + return True + + async def message_handler( + message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, + ) -> None: + """Handle exceptions from the session.""" + if isinstance(message, Exception): + raise message + + async with create_session( + server._mcp_server, + resource_updated_callback=resource_updated_collector, + message_handler=message_handler, + ) as client_session: + # Trigger resource update notification + result = await client_session.call_tool("update_resource", {"uri": "file:///test/resource.txt"}) + assert result.isError is False + + # Verify the notification was received + assert len(resource_updated_collector.notifications) == 1 + notification = resource_updated_collector.notifications[0] + assert str(notification.uri) == "file:///test/resource.txt" + + +@pytest.mark.anyio +async def test_resource_list_changed_callback( + resource_list_changed_collector: ResourceListChangedCollector, +) -> None: + """Test that resource list changed notifications are correctly received by the callback.""" + from mcp.server.fastmcp import FastMCP + + server = FastMCP("test") + + @server.tool("change_resource_list") + async def change_resource_list_tool() -> bool: + """Send a resource list changed notification to the client.""" + await server.get_context().session.send_resource_list_changed() + return True + + async def message_handler( + message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, + ) -> None: + """Handle exceptions from the session.""" + if isinstance(message, Exception): + raise message + + async with create_session( + server._mcp_server, + resource_list_changed_callback=resource_list_changed_collector, + message_handler=message_handler, + ) as client_session: + # Trigger resource list changed notification + result = await client_session.call_tool("change_resource_list", {}) + assert result.isError is False + + # Verify the notification was received + assert resource_list_changed_collector.notification_count == 1 + + +@pytest.mark.anyio +async def test_tool_list_changed_callback(tool_list_changed_collector: ToolListChangedCollector) -> None: + """Test that tool list changed notifications are correctly received by the callback.""" + from mcp.server.fastmcp import FastMCP + + server = FastMCP("test") + + @server.tool("change_tool_list") + async def change_tool_list_tool() -> bool: + """Send a tool list changed notification to the client.""" + await server.get_context().session.send_tool_list_changed() + return True + + async def message_handler( + message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, + ) -> None: + """Handle exceptions from the session.""" + if isinstance(message, Exception): + raise message + + async with create_session( + server._mcp_server, + tool_list_changed_callback=tool_list_changed_collector, + message_handler=message_handler, + ) as client_session: + # Trigger tool list changed notification + result = await client_session.call_tool("change_tool_list", {}) + assert result.isError is False + + # Verify the notification was received + assert tool_list_changed_collector.notification_count == 1 + + +@pytest.mark.anyio +async def test_prompt_list_changed_callback(prompt_list_changed_collector: PromptListChangedCollector) -> None: + """Test that prompt list changed notifications are correctly received by the callback.""" + from mcp.server.fastmcp import FastMCP + + server = FastMCP("test") + + @server.tool("change_prompt_list") + async def change_prompt_list_tool() -> bool: + """Send a prompt list changed notification to the client.""" + await server.get_context().session.send_prompt_list_changed() + return True + + async def message_handler( + message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, + ) -> None: + """Handle exceptions from the session.""" + if isinstance(message, Exception): + raise message + + async with create_session( + server._mcp_server, + prompt_list_changed_callback=prompt_list_changed_collector, + message_handler=message_handler, + ) as client_session: + # Trigger prompt list changed notification + result = await client_session.call_tool("change_prompt_list", {}) + assert result.isError is False + + # Verify the notification was received + assert prompt_list_changed_collector.notification_count == 1 + + +@pytest.mark.anyio +@pytest.mark.parametrize( + "notification_type,callback_param,collector_fixture,tool_name,tool_args,verification", + [ + ( + "progress", + "progress_notification_callback", + "progress_collector", + "send_progress", + {"progress": 75.0, "total": 100.0, "message": "Almost done"}, + lambda c: ( # type: ignore[misc] + len(c.notifications) == 1 # type: ignore[attr-defined] + and c.notifications[0].progress == 75.0 # type: ignore[attr-defined] + and c.notifications[0].total == 100.0 # type: ignore[attr-defined] + and c.notifications[0].message == "Almost done" # type: ignore[attr-defined] + ), + ), + ( + "resource_updated", + "resource_updated_callback", + "resource_updated_collector", + "update_resource", + {"uri": "file:///test/data.json"}, + lambda c: ( # type: ignore[misc] + len(c.notifications) == 1 # type: ignore[attr-defined] + and str(c.notifications[0].uri) == "file:///test/data.json" # type: ignore[attr-defined] + ), + ), + ( + "resource_list_changed", + "resource_list_changed_callback", + "resource_list_changed_collector", + "change_resource_list", + {}, + lambda c: c.notification_count == 1, # type: ignore[attr-defined] + ), + ( + "tool_list_changed", + "tool_list_changed_callback", + "tool_list_changed_collector", + "change_tool_list", + {}, + lambda c: c.notification_count == 1, # type: ignore[attr-defined] + ), + ( + "prompt_list_changed", + "prompt_list_changed_callback", + "prompt_list_changed_collector", + "change_prompt_list", + {}, + lambda c: c.notification_count == 1, # type: ignore[attr-defined] + ), + ], +) +async def test_notification_callback_parametrized( + notification_type: str, + callback_param: str, + collector_fixture: str, + tool_name: str, + tool_args: dict[str, Any], + verification: Callable[[Any], bool], + request: "FixtureRequest", +) -> None: + """Parametrized test for all notification callbacks.""" + from mcp.server.fastmcp import FastMCP + + # Get the collector from the fixture + collector = request.getfixturevalue(collector_fixture) + + server = FastMCP("test") + + # Define all tools (simpler than dynamic tool creation) + @server.tool("send_progress") + async def send_progress_tool(progress: float, total: float, message: str) -> bool: + """Send a progress notification to the client.""" + ctx = server.get_context() + if ctx.request_context.meta and ctx.request_context.meta.progressToken: + await ctx.session.send_progress_notification( + progress_token=ctx.request_context.meta.progressToken, + progress=progress, + total=total, + message=message, + ) + return True + + @server.tool("update_resource") + async def update_resource_tool(uri: str) -> bool: + """Send a resource updated notification to the client.""" + await server.get_context().session.send_resource_updated(AnyUrl(uri)) + return True + + @server.tool("change_resource_list") + async def change_resource_list_tool() -> bool: + """Send a resource list changed notification to the client.""" + await server.get_context().session.send_resource_list_changed() + return True + + @server.tool("change_tool_list") + async def change_tool_list_tool() -> bool: + """Send a tool list changed notification to the client.""" + await server.get_context().session.send_tool_list_changed() + return True + + @server.tool("change_prompt_list") + async def change_prompt_list_tool() -> bool: + """Send a prompt list changed notification to the client.""" + await server.get_context().session.send_prompt_list_changed() + return True + + async def message_handler( + message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, + ) -> None: + """Handle exceptions from the session.""" + if isinstance(message, Exception): + raise message + + # Create session with the appropriate callback + session_kwargs: dict[str, Any] = {callback_param: collector, "message_handler": message_handler} + + async with create_session(server._mcp_server, **session_kwargs) as client_session: # type: ignore[arg-type] + # Call the appropriate tool + meta = {"progressToken": "param-test-token"} if notification_type == "progress" else None + result = await client_session.call_tool(tool_name, tool_args, meta=meta) + assert result.isError is False + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == "true" + + # Verify using the provided verification function + assert verification(collector), f"Verification failed for {notification_type}" + + +@pytest.mark.anyio +async def test_all_default_callbacks_with_notifications() -> None: + """Test that all default notification callbacks work (they do nothing). + + This single test covers multiple default callbacks by not providing + custom callbacks and triggering various notification types. + """ + from mcp.server.fastmcp import FastMCP + + server = FastMCP("test-server") + + @server.tool("send_progress") + async def send_progress_tool(progress: float, total: float) -> bool: + """Send a progress notification.""" + ctx = server.get_context() + if ctx.request_context.meta and ctx.request_context.meta.progressToken: + await ctx.session.send_progress_notification( + progress_token=ctx.request_context.meta.progressToken, + progress=progress, + total=total, + ) + return True + + @server.tool("send_resource_updated") + async def send_resource_updated_tool(uri: str) -> bool: + """Send a resource updated notification.""" + from pydantic import AnyUrl + + await server.get_context().session.send_resource_updated(uri=AnyUrl(uri)) + return True + + @server.tool("send_resource_list_changed") + async def send_resource_list_changed_tool() -> bool: + """Send a resource list changed notification.""" + await server.get_context().session.send_resource_list_changed() + return True + + @server.tool("send_tool_list_changed") + async def send_tool_list_changed_tool() -> bool: + """Send a tool list changed notification.""" + await server.get_context().session.send_tool_list_changed() + return True + + @server.tool("send_prompt_list_changed") + async def send_prompt_list_changed_tool() -> bool: + """Send a prompt list changed notification.""" + await server.get_context().session.send_prompt_list_changed() + return True + + # Create session WITHOUT custom callbacks - all will use defaults + async with create_session(server._mcp_server) as client_session: + # Test progress notification with default callback + result1 = await client_session.call_tool( + "send_progress", + {"progress": 50.0, "total": 100.0}, + meta={"progressToken": "test-token"}, + ) + assert result1.isError is False + + # Test resource updated with default callback + result2 = await client_session.call_tool( + "send_resource_updated", + {"uri": "file:///test.txt"}, + ) + assert result2.isError is False + + # Test resource list changed with default callback + result3 = await client_session.call_tool("send_resource_list_changed", {}) + assert result3.isError is False + + # Test tool list changed with default callback + result4 = await client_session.call_tool("send_tool_list_changed", {}) + assert result4.isError is False + + # Test prompt list changed with default callback + result5 = await client_session.call_tool("send_prompt_list_changed", {}) + assert result5.isError is False