Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -948,8 +948,9 @@ async def generate_poem(topic: str, ctx: Context[ServerSession, None]) -> str:
max_tokens=100,
)

if all(c.type == "text" for c in result.content_as_list):
return "\n".join(c.text for c in result.content_as_list if c.type == "text")
# Since we're not passing tools param, result.content is single content
if result.content.type == "text":
return result.content.text
return str(result.content)
```

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -178,8 +178,9 @@ async def test_sampling(prompt: str, ctx: Context[ServerSession, None]) -> str:
max_tokens=100,
)

if any(c.type == "text" for c in result.content_as_list):
model_response = "\n".join(c.text for c in result.content_as_list if c.type == "text")
# Since we're not passing tools param, result.content is single content
if result.content.type == "text":
model_response = result.content.text
else:
model_response = "No response"

Expand Down
5 changes: 3 additions & 2 deletions examples/snippets/servers/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ async def generate_poem(topic: str, ctx: Context[ServerSession, None]) -> str:
max_tokens=100,
)

if all(c.type == "text" for c in result.content_as_list):
return "\n".join(c.text for c in result.content_as_list if c.type == "text")
# Since we're not passing tools param, result.content is single content
if result.content.type == "text":
return result.content.text
return str(result.content)
4 changes: 4 additions & 0 deletions src/mcp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
CompleteRequest,
CreateMessageRequest,
CreateMessageResult,
CreateMessageResultWithTools,
ErrorData,
GetPromptRequest,
GetPromptResult,
Expand Down Expand Up @@ -42,6 +43,7 @@
ResourceUpdatedNotification,
RootsCapability,
SamplingCapability,
SamplingContent,
SamplingContextCapability,
SamplingMessage,
SamplingMessageContentBlock,
Expand Down Expand Up @@ -75,6 +77,7 @@
"CompleteRequest",
"CreateMessageRequest",
"CreateMessageResult",
"CreateMessageResultWithTools",
"ErrorData",
"GetPromptRequest",
"GetPromptResult",
Expand Down Expand Up @@ -105,6 +108,7 @@
"ResourceUpdatedNotification",
"RootsCapability",
"SamplingCapability",
"SamplingContent",
"SamplingContextCapability",
"SamplingMessage",
"SamplingMessageContentBlock",
Expand Down
88 changes: 67 additions & 21 deletions src/mcp/server/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ async def handle_list_prompts(ctx: RequestContext) -> list[types.Prompt]:
"""

from enum import Enum
from typing import Any, TypeVar
from typing import Any, TypeVar, overload

import anyio
import anyio.lowlevel
Expand Down Expand Up @@ -233,6 +233,7 @@ async def send_resource_updated(self, uri: AnyUrl) -> None: # pragma: no cover
)
)

@overload
async def create_message(
self,
messages: list[types.SamplingMessage],
Expand All @@ -244,10 +245,47 @@ async def create_message(
stop_sequences: list[str] | None = None,
metadata: dict[str, Any] | None = None,
model_preferences: types.ModelPreferences | None = None,
tools: list[types.Tool] | None = None,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

note: this confused me (why not leave it out?), but its a kwarg, and is needed to discriminate between the overloads

tools: None = None,
tool_choice: types.ToolChoice | None = None,
related_request_id: types.RequestId | None = None,
) -> types.CreateMessageResult:
"""Overload: Without tools, returns single content."""
...

@overload
async def create_message(
self,
messages: list[types.SamplingMessage],
*,
max_tokens: int,
system_prompt: str | None = None,
include_context: types.IncludeContext | None = None,
temperature: float | None = None,
stop_sequences: list[str] | None = None,
metadata: dict[str, Any] | None = None,
model_preferences: types.ModelPreferences | None = None,
tools: list[types.Tool],
tool_choice: types.ToolChoice | None = None,
related_request_id: types.RequestId | None = None,
) -> types.CreateMessageResultWithTools:
"""Overload: With tools, returns array-capable content."""
...

async def create_message(
self,
messages: list[types.SamplingMessage],
*,
max_tokens: int,
system_prompt: str | None = None,
include_context: types.IncludeContext | None = None,
temperature: float | None = None,
stop_sequences: list[str] | None = None,
metadata: dict[str, Any] | None = None,
model_preferences: types.ModelPreferences | None = None,
tools: list[types.Tool] | None = None,
tool_choice: types.ToolChoice | None = None,
related_request_id: types.RequestId | None = None,
) -> types.CreateMessageResult | types.CreateMessageResultWithTools:
"""Send a sampling/create_message request.

Args:
Expand Down Expand Up @@ -278,27 +316,35 @@ async def create_message(
validate_sampling_tools(client_caps, tools, tool_choice)
validate_tool_use_result_messages(messages)

request = types.ServerRequest(
types.CreateMessageRequest(
params=types.CreateMessageRequestParams(
messages=messages,
systemPrompt=system_prompt,
includeContext=include_context,
temperature=temperature,
maxTokens=max_tokens,
stopSequences=stop_sequences,
metadata=metadata,
modelPreferences=model_preferences,
tools=tools,
toolChoice=tool_choice,
),
)
)
metadata_obj = ServerMessageMetadata(related_request_id=related_request_id)

# Use different result types based on whether tools are provided
if tools is not None:
return await self.send_request(
request=request,
result_type=types.CreateMessageResultWithTools,
metadata=metadata_obj,
)
return await self.send_request(
request=types.ServerRequest(
types.CreateMessageRequest(
params=types.CreateMessageRequestParams(
messages=messages,
systemPrompt=system_prompt,
includeContext=include_context,
temperature=temperature,
maxTokens=max_tokens,
stopSequences=stop_sequences,
metadata=metadata,
modelPreferences=model_preferences,
tools=tools,
toolChoice=tool_choice,
),
)
),
request=request,
result_type=types.CreateMessageResult,
metadata=ServerMessageMetadata(
related_request_id=related_request_id,
),
metadata=metadata_obj,
)

async def list_roots(self) -> types.ListRootsResult:
Expand Down
26 changes: 25 additions & 1 deletion src/mcp/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1146,6 +1146,10 @@ class ToolResultContent(BaseModel):
SamplingMessageContentBlock: TypeAlias = TextContent | ImageContent | AudioContent | ToolUseContent | ToolResultContent
"""Content block types allowed in sampling messages."""

SamplingContent: TypeAlias = TextContent | ImageContent | AudioContent
"""Basic content types for sampling responses (without tool use).
Used for backwards-compatible CreateMessageResult when tools are not used."""


class SamplingMessage(BaseModel):
"""Describes a message issued to or received from an LLM API."""
Expand Down Expand Up @@ -1543,7 +1547,27 @@ class CreateMessageRequest(Request[CreateMessageRequestParams, Literal["sampling


class CreateMessageResult(Result):
"""The client's response to a sampling/create_message request from the server."""
"""The client's response to a sampling/create_message request from the server.

This is the backwards-compatible version that returns single content (no arrays).
Used when the request does not include tools.
"""

role: Role
"""The role of the message sender (typically 'assistant' for LLM responses)."""
content: SamplingContent
"""Response content. Single content block (text, image, or audio)."""
model: str
"""The name of the model that generated the message."""
stopReason: StopReason | None = None
"""The reason why sampling stopped, if known."""


class CreateMessageResultWithTools(Result):
"""The client's response to a sampling/create_message request when tools were provided.

This version supports array content for tool use flows.
"""

role: Role
"""The role of the message sender (typically 'assistant' for LLM responses)."""
Expand Down
78 changes: 78 additions & 0 deletions tests/client/test_sampling_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@
from mcp.types import (
CreateMessageRequestParams,
CreateMessageResult,
CreateMessageResultWithTools,
SamplingMessage,
TextContent,
ToolUseContent,
)


Expand Down Expand Up @@ -56,3 +58,79 @@ async def test_sampling_tool(message: str):
assert result.isError is True
assert isinstance(result.content[0], TextContent)
assert result.content[0].text == "Error executing tool test_sampling: Sampling not supported"


@pytest.mark.anyio
async def test_create_message_backwards_compat_single_content():
"""Test backwards compatibility: create_message without tools returns single content."""
from mcp.server.fastmcp import FastMCP

server = FastMCP("test")

# Callback returns single content (text)
callback_return = CreateMessageResult(
role="assistant",
content=TextContent(type="text", text="Hello from LLM"),
model="test-model",
stopReason="endTurn",
)

async def sampling_callback(
context: RequestContext[ClientSession, None],
params: CreateMessageRequestParams,
) -> CreateMessageResult:
return callback_return

@server.tool("test_backwards_compat")
async def test_tool(message: str):
# Call create_message WITHOUT tools
result = await server.get_context().session.create_message(
messages=[SamplingMessage(role="user", content=TextContent(type="text", text=message))],
max_tokens=100,
)
# Backwards compat: result should be CreateMessageResult
assert isinstance(result, CreateMessageResult)
# Content should be single (not a list) - this is the key backwards compat check
assert isinstance(result.content, TextContent)
assert result.content.text == "Hello from LLM"
# CreateMessageResult should NOT have content_as_list (that's on WithTools)
assert not hasattr(result, "content_as_list") or not callable(getattr(result, "content_as_list", None))
return True

async with create_session(server._mcp_server, sampling_callback=sampling_callback) as client_session:
result = await client_session.call_tool("test_backwards_compat", {"message": "Test"})
assert result.isError is False
assert isinstance(result.content[0], TextContent)
assert result.content[0].text == "true"


@pytest.mark.anyio
async def test_create_message_result_with_tools_type():
"""Test that CreateMessageResultWithTools supports content_as_list."""
# Test the type itself, not the overload (overload requires client capability setup)
result = CreateMessageResultWithTools(
role="assistant",
content=ToolUseContent(type="tool_use", id="call_123", name="get_weather", input={"city": "SF"}),
model="test-model",
stopReason="toolUse",
)

# CreateMessageResultWithTools should have content_as_list
content_list = result.content_as_list
assert len(content_list) == 1
assert content_list[0].type == "tool_use"

# It should also work with array content
result_array = CreateMessageResultWithTools(
role="assistant",
content=[
TextContent(type="text", text="Let me check the weather"),
ToolUseContent(type="tool_use", id="call_456", name="get_weather", input={"city": "NYC"}),
],
model="test-model",
stopReason="toolUse",
)
content_list_array = result_array.content_as_list
assert len(content_list_array) == 2
assert content_list_array[0].type == "text"
assert content_list_array[1].type == "tool_use"
5 changes: 3 additions & 2 deletions tests/shared/test_streamable_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,8 +242,9 @@ async def handle_call_tool(name: str, args: dict[str, Any]) -> list[TextContent]
)

# Return the sampling result in the tool response
if all(c.type == "text" for c in sampling_result.content_as_list):
response = "\n".join(c.text for c in sampling_result.content_as_list if c.type == "text")
# Since we're not passing tools param, result.content is single content
if sampling_result.content.type == "text":
response = sampling_result.content.text
else:
response = str(sampling_result.content)
return [
Expand Down
25 changes: 23 additions & 2 deletions tests/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
ClientRequest,
CreateMessageRequestParams,
CreateMessageResult,
CreateMessageResultWithTools,
Implementation,
InitializeRequest,
InitializeRequestParams,
Expand Down Expand Up @@ -239,15 +240,16 @@ async def test_create_message_request_params_with_tools():

@pytest.mark.anyio
async def test_create_message_result_with_tool_use():
"""Test CreateMessageResult with tool use content for SEP-1577."""
"""Test CreateMessageResultWithTools with tool use content for SEP-1577."""
result_data = {
"role": "assistant",
"content": {"type": "tool_use", "name": "search", "id": "call_123", "input": {"query": "test"}},
"model": "claude-3",
"stopReason": "toolUse",
}

result = CreateMessageResult.model_validate(result_data)
# Tool use content uses CreateMessageResultWithTools
result = CreateMessageResultWithTools.model_validate(result_data)
assert result.role == "assistant"
assert isinstance(result.content, ToolUseContent)
assert result.stopReason == "toolUse"
Expand All @@ -259,6 +261,25 @@ async def test_create_message_result_with_tool_use():
assert content_list[0] == result.content


@pytest.mark.anyio
async def test_create_message_result_basic():
"""Test CreateMessageResult with basic text content (backwards compatible)."""
result_data = {
"role": "assistant",
"content": {"type": "text", "text": "Hello!"},
"model": "claude-3",
"stopReason": "endTurn",
}

# Basic content uses CreateMessageResult (single content, no arrays)
result = CreateMessageResult.model_validate(result_data)
assert result.role == "assistant"
assert isinstance(result.content, TextContent)
assert result.content.text == "Hello!"
assert result.stopReason == "endTurn"
assert result.model == "claude-3"


@pytest.mark.anyio
async def test_client_capabilities_with_sampling_tools():
"""Test ClientCapabilities with nested sampling capabilities for SEP-1577."""
Expand Down