diff --git a/README.md b/README.md
index 5e8129c96..c76efbc8e 100644
--- a/README.md
+++ b/README.md
@@ -886,9 +886,10 @@ async def generate_poem(topic: str, ctx: Context[ServerSession, None]) -> str:
         max_tokens=100,
     )
 
-    if result.content.type == "text":
-        return result.content.text
-    return str(result.content)
+    content = result.content[0] if isinstance(result.content, list) else result.content
+    if content.type == "text":
+        return content.text
+    return str(content)
 ```
 
 _Full example: [examples/snippets/servers/sampling.py](https://github.com/modelcontextprotocol/python-sdk/blob/main/examples/snippets/servers/sampling.py)_
diff --git a/examples/servers/everything-server/mcp_everything_server/server.py b/examples/servers/everything-server/mcp_everything_server/server.py
index 32c3e1d91..9d50f52cd 100644
--- a/examples/servers/everything-server/mcp_everything_server/server.py
+++ b/examples/servers/everything-server/mcp_everything_server/server.py
@@ -134,8 +134,9 @@ async def test_sampling(prompt: str, ctx: Context[ServerSession, None]) -> str:
         max_tokens=100,
     )
 
-    if result.content.type == "text":
-        model_response = result.content.text
+    content = result.content[0] if isinstance(result.content, list) else result.content
+    if content.type == "text":
+        model_response = content.text
     else:
         model_response = "No response"
 
diff --git a/examples/snippets/servers/sampling.py b/examples/snippets/servers/sampling.py
index 0099836c2..8c3f0d16c 100644
--- a/examples/snippets/servers/sampling.py
+++ b/examples/snippets/servers/sampling.py
@@ -20,6 +20,7 @@ async def generate_poem(topic: str, ctx: Context[ServerSession, None]) -> str:
         max_tokens=100,
     )
 
-    if result.content.type == "text":
-        return result.content.text
-    return str(result.content)
+    content = result.content[0] if isinstance(result.content, list) else result.content
+    if content.type == "text":
+        return content.text
+    return str(content)
diff --git a/src/mcp/__init__.py b/src/mcp/__init__.py
index e93b95c90..077ff9af6 100644
--- a/src/mcp/__init__.py
+++ b/src/mcp/__init__.py
@@ -41,7 +41,11 @@
     ResourcesCapability,
     ResourceUpdatedNotification,
     RootsCapability,
+    SamplingCapability,
+    SamplingContextCapability,
     SamplingMessage,
+    SamplingMessageContentBlock,
+    SamplingToolsCapability,
     ServerCapabilities,
     ServerNotification,
     ServerRequest,
@@ -50,7 +54,10 @@
     StopReason,
     SubscribeRequest,
     Tool,
+    ToolChoice,
+    ToolResultContent,
     ToolsCapability,
+    ToolUseContent,
     UnsubscribeRequest,
 )
 from .types import (
@@ -65,6 +72,7 @@
     "ClientResult",
     "ClientSession",
     "ClientSessionGroup",
+    "CompleteRequest",
     "CreateMessageRequest",
     "CreateMessageResult",
     "ErrorData",
@@ -77,6 +85,7 @@
     "InitializedNotification",
     "JSONRPCError",
     "JSONRPCRequest",
+    "JSONRPCResponse",
     "ListPromptsRequest",
     "ListPromptsResult",
     "ListResourcesRequest",
@@ -91,12 +100,16 @@
     "PromptsCapability",
     "ReadResourceRequest",
     "ReadResourceResult",
+    "Resource",
     "ResourcesCapability",
     "ResourceUpdatedNotification",
-    "Resource",
     "RootsCapability",
+    "SamplingCapability",
+    "SamplingContextCapability",
     "SamplingMessage",
+    "SamplingMessageContentBlock",
     "SamplingRole",
+    "SamplingToolsCapability",
     "ServerCapabilities",
     "ServerNotification",
     "ServerRequest",
@@ -107,10 +120,11 @@
     "StopReason",
     "SubscribeRequest",
     "Tool",
+    "ToolChoice",
+    "ToolResultContent",
     "ToolsCapability",
+    "ToolUseContent",
     "UnsubscribeRequest",
     "stdio_client",
     "stdio_server",
-    "CompleteRequest",
-    "JSONRPCResponse",
 ]
diff --git a/src/mcp/types.py b/src/mcp/types.py
index 871322740..df59f43ce 100644
--- a/src/mcp/types.py
+++ b/src/mcp/types.py
@@ -1,5 +1,5 @@
 from collections.abc import Callable
-from typing import Annotated, Any, Generic, Literal, TypeAlias, TypeVar
+from typing import Annotated, Any, Generic, Literal, TypeAlias, TypeVar, Union
 
 from pydantic import BaseModel, ConfigDict, Field, FileUrl, RootModel
 from pydantic.networks import AnyUrl, UrlConstraints
@@ -250,8 +250,24 @@ class RootsCapability(BaseModel):
     model_config = ConfigDict(extra="allow")
 
 
-class SamplingCapability(BaseModel):
-    """Capability for sampling operations."""
+class SamplingContextCapability(BaseModel):
+    """
+    Capability for context inclusion during sampling.
+
+    Indicates support for non-'none' values in the includeContext parameter.
+    SOFT-DEPRECATED: New implementations should use tools parameter instead.
+    """
+
+    model_config = ConfigDict(extra="allow")
+
+
+class SamplingToolsCapability(BaseModel):
+    """
+    Capability indicating support for tool calling during sampling.
+
+    When present in ClientCapabilities.sampling, indicates that the client
+    supports the tools and toolChoice parameters in sampling requests.
+    """
 
     model_config = ConfigDict(extra="allow")
 
@@ -262,13 +278,34 @@ class ElicitationCapability(BaseModel):
     model_config = ConfigDict(extra="allow")
 
 
+class SamplingCapability(BaseModel):
+    """
+    Sampling capability structure, allowing fine-grained capability advertisement.
+    """
+
+    context: SamplingContextCapability | None = None
+    """
+    Present if the client supports non-'none' values for includeContext parameter.
+    SOFT-DEPRECATED: New implementations should use tools parameter instead.
+    """
+    tools: SamplingToolsCapability | None = None
+    """
+    Present if the client supports tools and toolChoice parameters in sampling requests.
+    Presence indicates full tool calling support during sampling.
+    """
+    model_config = ConfigDict(extra="allow")
+
+
 class ClientCapabilities(BaseModel):
     """Capabilities a client may support."""
 
     experimental: dict[str, dict[str, Any]] | None = None
     """Experimental, non-standard capabilities that the client supports."""
     sampling: SamplingCapability | None = None
-    """Present if the client supports sampling from an LLM."""
+    """
+    Present if the client supports sampling from an LLM.
+    Can contain fine-grained capabilities like context and tools support.
+    """
     elicitation: ElicitationCapability | None = None
     """Present if the client supports elicitation from the user."""
     roots: RootsCapability | None = None
@@ -742,11 +779,89 @@ class AudioContent(BaseModel):
     model_config = ConfigDict(extra="allow")
 
 
+class ToolUseContent(BaseModel):
+    """
+    Content representing an assistant's request to invoke a tool.
+
+    This content type appears in assistant messages when the LLM wants to call a tool
+    during sampling. The server should execute the tool and return a ToolResultContent
+    in the next user message.
+    """
+
+    type: Literal["tool_use"]
+    """Discriminator for tool use content."""
+
+    name: str
+    """The name of the tool to invoke. Must match a tool name from the request's tools array."""
+
+    id: str
+    """Unique identifier for this tool call, used to correlate with ToolResultContent."""
+
+    input: dict[str, Any]
+    """Arguments to pass to the tool. Must conform to the tool's inputSchema."""
+
+    meta: dict[str, Any] | None = Field(alias="_meta", default=None)
+    """
+    See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields)
+    for notes on _meta usage.
+    """
+    model_config = ConfigDict(extra="allow")
+
+
+class ToolResultContent(BaseModel):
+    """
+    Content representing the result of a tool execution.
+
+    This content type appears in user messages as a response to a ToolUseContent
+    from the assistant. It contains the output of executing the requested tool.
+    """
+
+    type: Literal["tool_result"]
+    """Discriminator for tool result content."""
+
+    toolUseId: str
+    """The unique identifier that corresponds to the tool call's id field."""
+
+    content: list[Union[TextContent, ImageContent, AudioContent, "ResourceLink", "EmbeddedResource"]] = []
+    """
+    A list of content objects representing the tool result.
+    Defaults to empty list if not provided.
+    """
+
+    structuredContent: dict[str, Any] | None = None
+    """
+    Optional structured tool output that matches the tool's outputSchema (if defined).
+    """
+
+    isError: bool | None = None
+    """Whether the tool execution resulted in an error."""
+
+    meta: dict[str, Any] | None = Field(alias="_meta", default=None)
+    """
+    See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields)
+    for notes on _meta usage.
+    """
+    model_config = ConfigDict(extra="allow")
+
+
+SamplingMessageContentBlock: TypeAlias = TextContent | ImageContent | AudioContent | ToolUseContent | ToolResultContent
+"""Content block types allowed in sampling messages."""
+
+
 class SamplingMessage(BaseModel):
     """Describes a message issued to or received from an LLM API."""
 
     role: Role
-    content: TextContent | ImageContent | AudioContent
+    content: SamplingMessageContentBlock | list[SamplingMessageContentBlock]
+    """
+    Message content. Can be a single content block or an array of content blocks
+    for multi-modal messages and tool interactions.
+    """
+    meta: dict[str, Any] | None = Field(alias="_meta", default=None)
+    """
+    See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields)
+    for notes on _meta usage.
+    """
     model_config = ConfigDict(extra="allow")
 
 
@@ -1035,6 +1150,25 @@ class ModelPreferences(BaseModel):
     model_config = ConfigDict(extra="allow")
 
 
+class ToolChoice(BaseModel):
+    """
+    Controls tool usage behavior during sampling.
+
+    Allows the server to specify whether and how the LLM should use tools
+    in its response.
+    """
+
+    mode: Literal["auto", "required", "none"] | None = None
+    """
+    Controls when tools are used:
+    - "auto": Model decides whether to use tools (default)
+    - "required": Model MUST use at least one tool before completing
+    - "none": Model should not use tools
+    """
+
+    model_config = ConfigDict(extra="allow")
+
+
 class CreateMessageRequestParams(RequestParams):
     """Parameters for creating a message."""
 
@@ -1057,6 +1191,16 @@ class CreateMessageRequestParams(RequestParams):
     stopSequences: list[str] | None = None
     metadata: dict[str, Any] | None = None
     """Optional metadata to pass through to the LLM provider."""
+    tools: list["Tool"] | None = None
+    """
+    Tool definitions for the LLM to use during sampling.
+    Requires clientCapabilities.sampling.tools to be present.
+    """
+    toolChoice: ToolChoice | None = None
+    """
+    Controls tool usage behavior.
+    Requires clientCapabilities.sampling.tools and the tools parameter to be present.
+    """
     model_config = ConfigDict(extra="allow")
 
 
@@ -1067,18 +1211,26 @@ class CreateMessageRequest(Request[CreateMessageRequestParams, Literal["sampling
     params: CreateMessageRequestParams
 
 
-StopReason = Literal["endTurn", "stopSequence", "maxTokens"] | str
+StopReason = Literal["endTurn", "stopSequence", "maxTokens", "toolUse"] | str
 
 
 class CreateMessageResult(Result):
     """The client's response to a sampling/create_message request from the server."""
 
     role: Role
-    content: TextContent | ImageContent | AudioContent
+    """The role of the message sender (typically 'assistant' for LLM responses)."""
+    content: SamplingMessageContentBlock | list[SamplingMessageContentBlock]
+    """
+    Response content. May be a single content block or an array.
+    May include ToolUseContent if stopReason is 'toolUse'.
+    """
     model: str
     """The name of the model that generated the message."""
     stopReason: StopReason | None = None
-    """The reason why sampling stopped, if known."""
+    """
+    The reason why sampling stopped, if known.
+    'toolUse' indicates the model wants to use a tool.
+    """
 
 
 class ResourceTemplateReference(BaseModel):
diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py
index 43b321d96..6cf0627e9 100644
--- a/tests/shared/test_streamable_http.py
+++ b/tests/shared/test_streamable_http.py
@@ -210,7 +210,10 @@ async def handle_call_tool(name: str, args: dict[str, Any]) -> list[TextContent]
                 )
 
                 # Return the sampling result in the tool response
-                response = sampling_result.content.text if sampling_result.content.type == "text" else None
+                content = (
+                    sampling_result.content[0] if isinstance(sampling_result.content, list) else sampling_result.content
+                )
+                response = content.text if content.type == "text" else None
                 return [
                     TextContent(
                         type="text",
@@ -1239,7 +1242,12 @@ async def sampling_callback(
         nonlocal sampling_callback_invoked, captured_message_params
         sampling_callback_invoked = True
         captured_message_params = params
-        message_received = params.messages[0].content.text if params.messages[0].content.type == "text" else None
+        msg_content = (
+            params.messages[0].content[0]
+            if isinstance(params.messages[0].content, list)
+            else params.messages[0].content
+        )
+        message_received = msg_content.text if msg_content.type == "text" else None
 
         return types.CreateMessageResult(
             role="assistant",
diff --git a/tests/test_types.py b/tests/test_types.py
index bb2042c80..0fdc778cf 100644
--- a/tests/test_types.py
+++ b/tests/test_types.py
@@ -1,16 +1,26 @@
+from typing import Any
+
 import pytest
 
 from mcp.types import (
     LATEST_PROTOCOL_VERSION,
     ClientCapabilities,
     ClientRequest,
+    CreateMessageRequestParams,
+    CreateMessageResult,
     Implementation,
     InitializeRequest,
     InitializeRequestParams,
     JSONRPCMessage,
     JSONRPCRequest,
     ListToolsResult,
+    SamplingCapability,
+    SamplingMessage,
+    TextContent,
     Tool,
+    ToolChoice,
+    ToolResultContent,
+    ToolUseContent,
 )
 
 
@@ -60,6 +70,210 @@ async def test_method_initialization():
     assert initialize_request.params.protocolVersion == LATEST_PROTOCOL_VERSION
 
 
+@pytest.mark.anyio
+async def test_tool_use_content():
+    """Test ToolUseContent type for SEP-1577."""
+    tool_use_data = {
+        "type": "tool_use",
+        "name": "get_weather",
+        "id": "call_abc123",
+        "input": {"location": "San Francisco", "unit": "celsius"},
+    }
+
+    tool_use = ToolUseContent.model_validate(tool_use_data)
+    assert tool_use.type == "tool_use"
+    assert tool_use.name == "get_weather"
+    assert tool_use.id == "call_abc123"
+    assert tool_use.input == {"location": "San Francisco", "unit": "celsius"}
+
+    # Test serialization
+    serialized = tool_use.model_dump(by_alias=True, exclude_none=True)
+    assert serialized["type"] == "tool_use"
+    assert serialized["name"] == "get_weather"
+
+
+@pytest.mark.anyio
+async def test_tool_result_content():
+    """Test ToolResultContent type for SEP-1577."""
+    tool_result_data = {
+        "type": "tool_result",
+        "toolUseId": "call_abc123",
+        "content": [{"type": "text", "text": "It's 72°F in San Francisco"}],
+        "isError": False,
+    }
+
+    tool_result = ToolResultContent.model_validate(tool_result_data)
+    assert tool_result.type == "tool_result"
+    assert tool_result.toolUseId == "call_abc123"
+    assert len(tool_result.content) == 1
+    assert tool_result.isError is False
+
+    # Test with empty content (should default to [])
+    minimal_result_data = {"type": "tool_result", "toolUseId": "call_xyz"}
+    minimal_result = ToolResultContent.model_validate(minimal_result_data)
+    assert minimal_result.content == []
+
+
+@pytest.mark.anyio
+async def test_tool_choice():
+    """Test ToolChoice type for SEP-1577."""
+    # Test with mode
+    tool_choice_data = {"mode": "required"}
+    tool_choice = ToolChoice.model_validate(tool_choice_data)
+    assert tool_choice.mode == "required"
+
+    # Test with minimal data (all fields optional)
+    minimal_choice = ToolChoice.model_validate({})
+    assert minimal_choice.mode is None
+
+    # Test different modes
+    auto_choice = ToolChoice.model_validate({"mode": "auto"})
+    assert auto_choice.mode == "auto"
+
+    none_choice = ToolChoice.model_validate({"mode": "none"})
+    assert none_choice.mode == "none"
+
+
+@pytest.mark.anyio
+async def test_sampling_message_with_user_role():
+    """Test SamplingMessage with user role for SEP-1577."""
+    # Test with single content
+    user_msg_data = {"role": "user", "content": {"type": "text", "text": "Hello"}}
+    user_msg = SamplingMessage.model_validate(user_msg_data)
+    assert user_msg.role == "user"
+    assert isinstance(user_msg.content, TextContent)
+
+    # Test with array of content including tool result
+    multi_content_data: dict[str, Any] = {
+        "role": "user",
+        "content": [
+            {"type": "text", "text": "Here's the result:"},
+            {"type": "tool_result", "toolUseId": "call_123", "content": []},
+        ],
+    }
+    multi_msg = SamplingMessage.model_validate(multi_content_data)
+    assert multi_msg.role == "user"
+    assert isinstance(multi_msg.content, list)
+    assert len(multi_msg.content) == 2
+
+
+@pytest.mark.anyio
+async def test_sampling_message_with_assistant_role():
+    """Test SamplingMessage with assistant role for SEP-1577."""
+    # Test with tool use content
+    assistant_msg_data = {
+        "role": "assistant",
+        "content": {
+            "type": "tool_use",
+            "name": "search",
+            "id": "call_456",
+            "input": {"query": "MCP protocol"},
+        },
+    }
+    assistant_msg = SamplingMessage.model_validate(assistant_msg_data)
+    assert assistant_msg.role == "assistant"
+    assert isinstance(assistant_msg.content, ToolUseContent)
+
+    # Test with array of mixed content
+    multi_content_data: dict[str, Any] = {
+        "role": "assistant",
+        "content": [
+            {"type": "text", "text": "Let me search for that..."},
+            {"type": "tool_use", "name": "search", "id": "call_789", "input": {}},
+        ],
+    }
+    multi_msg = SamplingMessage.model_validate(multi_content_data)
+    assert isinstance(multi_msg.content, list)
+    assert len(multi_msg.content) == 2
+
+
+@pytest.mark.anyio
+async def test_sampling_message_backward_compatibility():
+    """Test that SamplingMessage maintains backward compatibility."""
+    # Old-style message (single content, no tools)
+    old_style_data = {"role": "user", "content": {"type": "text", "text": "Hello"}}
+    old_msg = SamplingMessage.model_validate(old_style_data)
+    assert old_msg.role == "user"
+    assert isinstance(old_msg.content, TextContent)
+
+    # New-style message with tool content
+    new_style_data: dict[str, Any] = {
+        "role": "assistant",
+        "content": {"type": "tool_use", "name": "test", "id": "call_1", "input": {}},
+    }
+    new_msg = SamplingMessage.model_validate(new_style_data)
+    assert new_msg.role == "assistant"
+    assert isinstance(new_msg.content, ToolUseContent)
+
+    # Array content
+    array_style_data: dict[str, Any] = {
+        "role": "user",
+        "content": [{"type": "text", "text": "Result:"}, {"type": "tool_result", "toolUseId": "call_1", "content": []}],
+    }
+    array_msg = SamplingMessage.model_validate(array_style_data)
+    assert isinstance(array_msg.content, list)
+
+
+@pytest.mark.anyio
+async def test_create_message_request_params_with_tools():
+    """Test CreateMessageRequestParams with tools for SEP-1577."""
+    tool = Tool(
+        name="get_weather",
+        description="Get weather information",
+        inputSchema={"type": "object", "properties": {"location": {"type": "string"}}},
+    )
+
+    params = CreateMessageRequestParams(
+        messages=[SamplingMessage(role="user", content=TextContent(type="text", text="What's the weather?"))],
+        maxTokens=1000,
+        tools=[tool],
+        toolChoice=ToolChoice(mode="auto"),
+    )
+
+    assert params.tools is not None
+    assert len(params.tools) == 1
+    assert params.tools[0].name == "get_weather"
+    assert params.toolChoice is not None
+    assert params.toolChoice.mode == "auto"
+
+
+@pytest.mark.anyio
+async def test_create_message_result_with_tool_use():
+    """Test CreateMessageResult with tool use content for SEP-1577."""
+    result_data = {
+        "role": "assistant",
+        "content": {"type": "tool_use", "name": "search", "id": "call_123", "input": {"query": "test"}},
+        "model": "claude-3",
+        "stopReason": "toolUse",
+    }
+
+    result = CreateMessageResult.model_validate(result_data)
+    assert result.role == "assistant"
+    assert isinstance(result.content, ToolUseContent)
+    assert result.stopReason == "toolUse"
+    assert result.model == "claude-3"
+
+
+@pytest.mark.anyio
+async def test_client_capabilities_with_sampling_tools():
+    """Test ClientCapabilities with nested sampling capabilities for SEP-1577."""
+    # New structured format
+    capabilities_data: dict[str, Any] = {
+        "sampling": {"tools": {}},
+    }
+    capabilities = ClientCapabilities.model_validate(capabilities_data)
+    assert capabilities.sampling is not None
+    assert isinstance(capabilities.sampling, SamplingCapability)
+    assert capabilities.sampling.tools is not None
+
+    # With both context and tools
+    full_capabilities_data: dict[str, Any] = {"sampling": {"context": {}, "tools": {}}}
+    full_caps = ClientCapabilities.model_validate(full_capabilities_data)
+    assert isinstance(full_caps.sampling, SamplingCapability)
+    assert full_caps.sampling.context is not None
+    assert full_caps.sampling.tools is not None
+
+
 def test_tool_preserves_json_schema_2020_12_fields():
     """Verify that JSON Schema 2020-12 keywords are preserved in Tool.inputSchema.