diff --git a/examples/clients/simple-task-client/README.md b/examples/clients/simple-task-client/README.md new file mode 100644 index 000000000..103be0f1f --- /dev/null +++ b/examples/clients/simple-task-client/README.md @@ -0,0 +1,43 @@ +# Simple Task Client + +A minimal MCP client demonstrating polling for task results over streamable HTTP. + +## Running + +First, start the simple-task server in another terminal: + +```bash +cd examples/servers/simple-task +uv run mcp-simple-task +``` + +Then run the client: + +```bash +cd examples/clients/simple-task-client +uv run mcp-simple-task-client +``` + +Use `--url` to connect to a different server. + +## What it does + +1. Connects to the server via streamable HTTP +2. Calls the `long_running_task` tool as a task +3. Polls the task status until completion +4. Retrieves and prints the result + +## Expected output + +```text +Available tools: ['long_running_task'] + +Calling tool as a task... +Task created: <task-id> + Status: working - Starting work... + Status: working - Processing step 1... + Status: working - Processing step 2... + Status: completed - + +Result: Task completed! 
+``` diff --git a/examples/clients/simple-task-client/mcp_simple_task_client/__init__.py b/examples/clients/simple-task-client/mcp_simple_task_client/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/examples/clients/simple-task-client/mcp_simple_task_client/__main__.py b/examples/clients/simple-task-client/mcp_simple_task_client/__main__.py new file mode 100644 index 000000000..2fc2cda8d --- /dev/null +++ b/examples/clients/simple-task-client/mcp_simple_task_client/__main__.py @@ -0,0 +1,5 @@ +import sys + +from .main import main + +sys.exit(main()) # type: ignore[call-arg] diff --git a/examples/clients/simple-task-client/mcp_simple_task_client/main.py b/examples/clients/simple-task-client/mcp_simple_task_client/main.py new file mode 100644 index 000000000..ea997d7ea --- /dev/null +++ b/examples/clients/simple-task-client/mcp_simple_task_client/main.py @@ -0,0 +1,73 @@ +"""Simple task client demonstrating MCP tasks polling over streamable HTTP.""" + +import asyncio + +import click +from mcp import ClientSession +from mcp.client.streamable_http import streamablehttp_client +from mcp.types import ( + CallToolRequest, + CallToolRequestParams, + CallToolResult, + ClientRequest, + CreateTaskResult, + TaskMetadata, + TextContent, +) + + +async def run(url: str) -> None: + async with streamablehttp_client(url) as (read, write, _): + async with ClientSession(read, write) as session: + await session.initialize() + + # List tools + tools = await session.list_tools() + print(f"Available tools: {[t.name for t in tools.tools]}") + + # Call the tool as a task + print("\nCalling tool as a task...") + result = await session.send_request( + ClientRequest( + CallToolRequest( + params=CallToolRequestParams( + name="long_running_task", + arguments={}, + task=TaskMetadata(ttl=60000), + ) + ) + ), + CreateTaskResult, + ) + task_id = result.task.taskId + print(f"Task created: {task_id}") + + # Poll until done + while True: + status = await 
session.experimental.get_task(task_id) + print(f" Status: {status.status} - {status.statusMessage or ''}") + + if status.status == "completed": + break + elif status.status in ("failed", "cancelled"): + print(f"Task ended with status: {status.status}") + return + + await asyncio.sleep(0.5) + + # Get the result + task_result = await session.experimental.get_task_result(task_id, CallToolResult) + content = task_result.content[0] + if isinstance(content, TextContent): + print(f"\nResult: {content.text}") + + +@click.command() +@click.option("--url", default="http://localhost:8000/mcp", help="Server URL") +def main(url: str) -> int: + asyncio.run(run(url)) + return 0 + + +if __name__ == "__main__": + main() diff --git a/examples/clients/simple-task-client/pyproject.toml b/examples/clients/simple-task-client/pyproject.toml new file mode 100644 index 000000000..da10392e3 --- /dev/null +++ b/examples/clients/simple-task-client/pyproject.toml @@ -0,0 +1,43 @@ +[project] +name = "mcp-simple-task-client" +version = "0.1.0" +description = "A simple MCP client demonstrating task polling" +readme = "README.md" +requires-python = ">=3.10" +authors = [{ name = "Anthropic, PBC." }] +keywords = ["mcp", "llm", "tasks", "client"] +license = { text = "MIT" } +classifiers = [ + "Development Status :: 4 - Beta", + "Intended Audience :: Developers", + "License :: OSI Approved :: MIT License", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.10", +] +dependencies = ["click>=8.0", "mcp"] + +[project.scripts] +mcp-simple-task-client = "mcp_simple_task_client.main:main" + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.hatch.build.targets.wheel] +packages = ["mcp_simple_task_client"] + +[tool.pyright] +include = ["mcp_simple_task_client"] +venvPath = "." 
+venv = ".venv" + +[tool.ruff.lint] +select = ["E", "F", "I"] +ignore = [] + +[tool.ruff] +line-length = 120 +target-version = "py310" + +[dependency-groups] +dev = ["pyright>=1.1.378", "ruff>=0.6.9"] diff --git a/examples/servers/simple-task/README.md b/examples/servers/simple-task/README.md new file mode 100644 index 000000000..6914e0414 --- /dev/null +++ b/examples/servers/simple-task/README.md @@ -0,0 +1,37 @@ +# Simple Task Server + +A minimal MCP server demonstrating the experimental tasks feature over streamable HTTP. + +## Running + +```bash +cd examples/servers/simple-task +uv run mcp-simple-task +``` + +The server starts on `http://localhost:8000/mcp` by default. Use `--port` to change. + +## What it does + +This server exposes a single tool `long_running_task` that: + +1. Must be called as a task (with `task` metadata in the request) +2. Takes ~3 seconds to complete +3. Sends status updates during execution +4. Returns a result when complete + +## Usage with the client + +In one terminal, start the server: + +```bash +cd examples/servers/simple-task +uv run mcp-simple-task +``` + +In another terminal, run the client: + +```bash +cd examples/clients/simple-task-client +uv run mcp-simple-task-client +``` diff --git a/examples/servers/simple-task/mcp_simple_task/__init__.py b/examples/servers/simple-task/mcp_simple_task/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/examples/servers/simple-task/mcp_simple_task/__main__.py b/examples/servers/simple-task/mcp_simple_task/__main__.py new file mode 100644 index 000000000..e7ef16530 --- /dev/null +++ b/examples/servers/simple-task/mcp_simple_task/__main__.py @@ -0,0 +1,5 @@ +import sys + +from .server import main + +sys.exit(main()) # type: ignore[call-arg] diff --git a/examples/servers/simple-task/mcp_simple_task/server.py b/examples/servers/simple-task/mcp_simple_task/server.py new file mode 100644 index 000000000..845f05323 --- /dev/null +++ 
b/examples/servers/simple-task/mcp_simple_task/server.py @@ -0,0 +1,125 @@ +"""Simple task server demonstrating MCP tasks over streamable HTTP.""" + +from collections.abc import AsyncIterator +from contextlib import asynccontextmanager +from dataclasses import dataclass +from typing import Any + +import anyio +import click +import mcp.types as types +from anyio.abc import TaskGroup +from mcp.server.lowlevel import Server +from mcp.server.streamable_http_manager import StreamableHTTPSessionManager +from mcp.shared.experimental.tasks import InMemoryTaskStore, task_execution +from starlette.applications import Starlette +from starlette.routing import Mount + + +@dataclass +class AppContext: + task_group: TaskGroup + store: InMemoryTaskStore + + +@asynccontextmanager +async def lifespan(server: Server[AppContext, Any]) -> AsyncIterator[AppContext]: + store = InMemoryTaskStore() + async with anyio.create_task_group() as tg: + yield AppContext(task_group=tg, store=store) + store.cleanup() + + +server: Server[AppContext, Any] = Server("simple-task-server", lifespan=lifespan) + + +@server.list_tools() +async def list_tools() -> list[types.Tool]: + return [ + types.Tool( + name="long_running_task", + description="A task that takes a few seconds to complete with status updates", + inputSchema={"type": "object", "properties": {}}, + ) + ] + + +@server.call_tool() +async def handle_call_tool(name: str, arguments: dict[str, Any]) -> list[types.TextContent] | types.CreateTaskResult: + ctx = server.request_context + app = ctx.lifespan_context + + if not ctx.experimental.is_task: + return [types.TextContent(type="text", text="Error: This tool must be called as a task")] + + # Create the task + metadata = ctx.experimental.task_metadata + assert metadata is not None + task = await app.store.create_task(metadata) + + # Spawn background work + async def do_work() -> None: + async with task_execution(task.taskId, app.store) as task_ctx: + await task_ctx.update_status("Starting 
work...") + await anyio.sleep(1) + + await task_ctx.update_status("Processing step 1...") + await anyio.sleep(1) + + await task_ctx.update_status("Processing step 2...") + await anyio.sleep(1) + + await task_ctx.complete( + types.CallToolResult(content=[types.TextContent(type="text", text="Task completed!")]) + ) + + app.task_group.start_soon(do_work) + return types.CreateTaskResult(task=task) + + +@server.experimental.get_task() +async def handle_get_task(request: types.GetTaskRequest) -> types.GetTaskResult: + app = server.request_context.lifespan_context + task = await app.store.get_task(request.params.taskId) + if task is None: + raise ValueError(f"Task {request.params.taskId} not found") + return types.GetTaskResult( + taskId=task.taskId, + status=task.status, + statusMessage=task.statusMessage, + createdAt=task.createdAt, + ttl=task.ttl, + pollInterval=task.pollInterval, + ) + + +@server.experimental.get_task_result() +async def handle_get_task_result(request: types.GetTaskPayloadRequest) -> types.GetTaskPayloadResult: + app = server.request_context.lifespan_context + result = await app.store.get_result(request.params.taskId) + if result is None: + raise ValueError(f"Result for task {request.params.taskId} not found") + assert isinstance(result, types.CallToolResult) + return types.GetTaskPayloadResult(**result.model_dump()) + + +@click.command() +@click.option("--port", default=8000, help="Port to listen on") +def main(port: int) -> int: + import uvicorn + + session_manager = StreamableHTTPSessionManager(app=server) + + @asynccontextmanager + async def app_lifespan(app: Starlette) -> AsyncIterator[None]: + async with session_manager.run(): + yield + + starlette_app = Starlette( + routes=[Mount("/mcp", app=session_manager.handle_request)], + lifespan=app_lifespan, + ) + + print(f"Starting server on http://localhost:{port}/mcp") + uvicorn.run(starlette_app, host="127.0.0.1", port=port) + return 0 diff --git a/examples/servers/simple-task/pyproject.toml 
b/examples/servers/simple-task/pyproject.toml new file mode 100644 index 000000000..a8fba8bdc --- /dev/null +++ b/examples/servers/simple-task/pyproject.toml @@ -0,0 +1,43 @@ +[project] +name = "mcp-simple-task" +version = "0.1.0" +description = "A simple MCP server demonstrating tasks" +readme = "README.md" +requires-python = ">=3.10" +authors = [{ name = "Anthropic, PBC." }] +keywords = ["mcp", "llm", "tasks"] +license = { text = "MIT" } +classifiers = [ + "Development Status :: 4 - Beta", + "Intended Audience :: Developers", + "License :: OSI Approved :: MIT License", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.10", +] +dependencies = ["anyio>=4.5", "click>=8.0", "mcp", "starlette", "uvicorn"] + +[project.scripts] +mcp-simple-task = "mcp_simple_task.server:main" + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.hatch.build.targets.wheel] +packages = ["mcp_simple_task"] + +[tool.pyright] +include = ["mcp_simple_task"] +venvPath = "." +venv = ".venv" + +[tool.ruff.lint] +select = ["E", "F", "I"] +ignore = [] + +[tool.ruff] +line-length = 120 +target-version = "py310" + +[dependency-groups] +dev = ["pyright>=1.1.378", "ruff>=0.6.9"] diff --git a/src/mcp/client/experimental/__init__.py b/src/mcp/client/experimental/__init__.py new file mode 100644 index 000000000..b6579b191 --- /dev/null +++ b/src/mcp/client/experimental/__init__.py @@ -0,0 +1,9 @@ +""" +Experimental client features. + +WARNING: These APIs are experimental and may change without notice. +""" + +from mcp.client.experimental.tasks import ExperimentalClientFeatures + +__all__ = ["ExperimentalClientFeatures"] diff --git a/src/mcp/client/experimental/tasks.py b/src/mcp/client/experimental/tasks.py new file mode 100644 index 000000000..136abd1da --- /dev/null +++ b/src/mcp/client/experimental/tasks.py @@ -0,0 +1,131 @@ +""" +Experimental client-side task support. + +This module provides client methods for interacting with MCP tasks. 
+ +WARNING: These APIs are experimental and may change without notice. + +Example: + # Get task status + status = await session.experimental.get_task(task_id) + + # Get task result when complete + if status.status == "completed": + result = await session.experimental.get_task_result(task_id, CallToolResult) + + # List all tasks + tasks = await session.experimental.list_tasks() + + # Cancel a task + await session.experimental.cancel_task(task_id) +""" + +from typing import TYPE_CHECKING, TypeVar + +import mcp.types as types + +if TYPE_CHECKING: + from mcp.client.session import ClientSession + +ResultT = TypeVar("ResultT", bound=types.Result) + + +class ExperimentalClientFeatures: + """ + Experimental client features for tasks and other experimental APIs. + + WARNING: These APIs are experimental and may change without notice. + + Access via session.experimental: + status = await session.experimental.get_task(task_id) + """ + + def __init__(self, session: "ClientSession") -> None: + self._session = session + + async def get_task(self, task_id: str) -> types.GetTaskResult: + """ + Get the current status of a task. + + Args: + task_id: The task identifier + + Returns: + GetTaskResult containing the task status and metadata + """ + return await self._session.send_request( + types.ClientRequest( + types.GetTaskRequest( + params=types.GetTaskRequestParams(taskId=task_id), + ) + ), + types.GetTaskResult, + ) + + async def get_task_result( + self, + task_id: str, + result_type: type[ResultT], + ) -> ResultT: + """ + Get the result of a completed task. 
+ + The result type depends on the original request type: + - tools/call tasks return CallToolResult + - Other request types return their corresponding result type + + Args: + task_id: The task identifier + result_type: The expected result type (e.g., CallToolResult) + + Returns: + The task result, validated against result_type + """ + return await self._session.send_request( + types.ClientRequest( + types.GetTaskPayloadRequest( + params=types.GetTaskPayloadRequestParams(taskId=task_id), + ) + ), + result_type, + ) + + async def list_tasks( + self, + cursor: str | None = None, + ) -> types.ListTasksResult: + """ + List all tasks. + + Args: + cursor: Optional pagination cursor + + Returns: + ListTasksResult containing tasks and optional next cursor + """ + params = types.PaginatedRequestParams(cursor=cursor) if cursor else None + return await self._session.send_request( + types.ClientRequest( + types.ListTasksRequest(params=params), + ), + types.ListTasksResult, + ) + + async def cancel_task(self, task_id: str) -> types.CancelTaskResult: + """ + Cancel a running task. 
+ + Args: + task_id: The task identifier + + Returns: + CancelTaskResult with the updated task state + """ + return await self._session.send_request( + types.ClientRequest( + types.CancelTaskRequest( + params=types.CancelTaskRequestParams(taskId=task_id), + ) + ), + types.CancelTaskResult, + ) diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 3835a2a57..354baa813 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -9,6 +9,7 @@ from typing_extensions import deprecated import mcp.types as types +from mcp.client.experimental import ExperimentalClientFeatures from mcp.shared.context import RequestContext from mcp.shared.message import SessionMessage from mcp.shared.session import BaseSession, ProgressFnT, RequestResponder @@ -135,6 +136,7 @@ def __init__( self._message_handler = message_handler or _default_message_handler self._tool_output_schemas: dict[str, dict[str, Any] | None] = {} self._server_capabilities: types.ServerCapabilities | None = None + self._experimental: ExperimentalClientFeatures | None = None async def initialize(self) -> types.InitializeResult: sampling = types.SamplingCapability() if self._sampling_callback is not _default_sampling_callback else None @@ -184,6 +186,20 @@ def get_server_capabilities(self) -> types.ServerCapabilities | None: """ return self._server_capabilities + @property + def experimental(self) -> "ExperimentalClientFeatures": + """Experimental APIs for tasks and other features. + + WARNING: These APIs are experimental and may change without notice. 
+ + Example: + status = await session.experimental.get_task(task_id) + result = await session.experimental.get_task_result(task_id, CallToolResult) + """ + if self._experimental is None: + self._experimental = ExperimentalClientFeatures(self) + return self._experimental + async def send_ping(self) -> types.EmptyResult: """Send a ping request.""" return await self.send_request( @@ -537,6 +553,8 @@ async def _received_request(self, responder: RequestResponder[types.ServerReques case types.PingRequest(): # pragma: no cover with responder: return await responder.respond(types.ClientResult(root=types.EmptyResult())) + case _: + raise NotImplementedError() async def _handle_incoming( self, diff --git a/src/mcp/server/lowlevel/experimental.py b/src/mcp/server/lowlevel/experimental.py new file mode 100644 index 000000000..575738104 --- /dev/null +++ b/src/mcp/server/lowlevel/experimental.py @@ -0,0 +1,137 @@ +"""Experimental handlers for the low-level MCP server. + +WARNING: These APIs are experimental and may change without notice. +""" + +import logging +from collections.abc import Awaitable, Callable + +from mcp.server.lowlevel.func_inspection import create_call_wrapper +from mcp.types import ( + CancelTaskRequest, + CancelTaskResult, + GetTaskPayloadRequest, + GetTaskPayloadResult, + GetTaskRequest, + GetTaskResult, + ListTasksRequest, + ListTasksResult, + ServerCapabilities, + ServerResult, + ServerTasksCapability, + ServerTasksRequestsCapability, + TasksCancelCapability, + TasksListCapability, + TasksToolsCapability, +) + +logger = logging.getLogger(__name__) + + +class ExperimentalHandlers: + """Experimental request/notification handlers. + + WARNING: These APIs are experimental and may change without notice. 
+ """ + + def __init__( + self, + request_handlers: dict[type, Callable[..., Awaitable[ServerResult]]], + notification_handlers: dict[type, Callable[..., Awaitable[None]]], + ): + self._request_handlers = request_handlers + self._notification_handlers = notification_handlers + + def update_capabilities(self, capabilities: ServerCapabilities) -> None: + capabilities.tasks = ServerTasksCapability() + if ListTasksRequest in self._request_handlers: + capabilities.tasks.list = TasksListCapability() + if CancelTaskRequest in self._request_handlers: + capabilities.tasks.cancel = TasksCancelCapability() + + capabilities.tasks.requests = ServerTasksRequestsCapability( + tools=TasksToolsCapability() + ) # assuming always supported for now + + def list_tasks( + self, + ) -> Callable[ + [Callable[[ListTasksRequest], Awaitable[ListTasksResult]]], + Callable[[ListTasksRequest], Awaitable[ListTasksResult]], + ]: + """Register a handler for listing tasks. + + WARNING: This API is experimental and may change without notice. + """ + + def decorator( + func: Callable[[ListTasksRequest], Awaitable[ListTasksResult]], + ) -> Callable[[ListTasksRequest], Awaitable[ListTasksResult]]: + logger.debug("Registering handler for ListTasksRequest") + wrapper = create_call_wrapper(func, ListTasksRequest) + + async def handler(req: ListTasksRequest): + result = await wrapper(req) + return ServerResult(result) + + self._request_handlers[ListTasksRequest] = handler + return func + + return decorator + + def get_task(self): + """Register a handler for getting task status. + + WARNING: This API is experimental and may change without notice. 
+ """ + + def decorator(func: Callable[[GetTaskRequest], Awaitable[GetTaskResult]]): + logger.debug("Registering handler for GetTaskRequest") + wrapper = create_call_wrapper(func, GetTaskRequest) + + async def handler(req: GetTaskRequest): + result = await wrapper(req) + return ServerResult(result) + + self._request_handlers[GetTaskRequest] = handler + return func + + return decorator + + def get_task_result(self): + """Register a handler for getting task results/payload. + + WARNING: This API is experimental and may change without notice. + """ + + def decorator(func: Callable[[GetTaskPayloadRequest], Awaitable[GetTaskPayloadResult]]): + logger.debug("Registering handler for GetTaskPayloadRequest") + wrapper = create_call_wrapper(func, GetTaskPayloadRequest) + + async def handler(req: GetTaskPayloadRequest): + result = await wrapper(req) + return ServerResult(result) + + self._request_handlers[GetTaskPayloadRequest] = handler + return func + + return decorator + + def cancel_task(self): + """Register a handler for cancelling tasks. + + WARNING: This API is experimental and may change without notice. 
+ """ + + def decorator(func: Callable[[CancelTaskRequest], Awaitable[CancelTaskResult]]): + logger.debug("Registering handler for CancelTaskRequest") + wrapper = create_call_wrapper(func, CancelTaskRequest) + + async def handler(req: CancelTaskRequest): + result = await wrapper(req) + return ServerResult(result) + + self._request_handlers[CancelTaskRequest] = handler + return func + + return decorator diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index 49d289fb7..5c205f2ee 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -82,11 +82,12 @@ async def main(): from typing_extensions import TypeVar import mcp.types as types +from mcp.server.lowlevel.experimental import ExperimentalHandlers from mcp.server.lowlevel.func_inspection import create_call_wrapper from mcp.server.lowlevel.helper_types import ReadResourceContents from mcp.server.models import InitializationOptions from mcp.server.session import ServerSession -from mcp.shared.context import RequestContext +from mcp.shared.context import Experimental, RequestContext from mcp.shared.exceptions import McpError from mcp.shared.message import ServerMessageMetadata, SessionMessage from mcp.shared.session import RequestResponder @@ -154,6 +155,7 @@ def __init__( } self.notification_handlers: dict[type, Callable[..., Awaitable[None]]] = {} self._tool_cache: dict[str, types.Tool] = {} + self._experimental_handlers: ExperimentalHandlers | None = None logger.debug("Initializing server %r", name) def create_initialization_options( @@ -219,7 +221,7 @@ def get_capabilities( if types.CompleteRequest in self.request_handlers: completions_capability = types.CompletionsCapability() - return types.ServerCapabilities( + capabilities = types.ServerCapabilities( prompts=prompts_capability, resources=resources_capability, tools=tools_capability, @@ -227,6 +229,9 @@ def get_capabilities( experimental=experimental_capabilities, 
completions=completions_capability, ) + if self._experimental_handlers: + self._experimental_handlers.update_capabilities(capabilities) + return capabilities @property def request_context( @@ -235,6 +240,18 @@ def request_context( """If called outside of a request context, this will raise a LookupError.""" return request_ctx.get() + @property + def experimental(self) -> ExperimentalHandlers: + """Experimental APIs for tasks and other features. + + WARNING: These APIs are experimental and may change without notice. + """ + + # We create this inline so we only add these capabilities _if_ they're actually used + if self._experimental_handlers is None: + self._experimental_handlers = ExperimentalHandlers(self.request_handlers, self.notification_handlers) + return self._experimental_handlers + def list_prompts(self): def decorator( func: Callable[[], Awaitable[list[types.Prompt]]] @@ -480,7 +497,13 @@ def call_tool(self, *, validate_input: bool = True): def decorator( func: Callable[ ..., - Awaitable[UnstructuredContent | StructuredContent | CombinationContent | types.CallToolResult], + Awaitable[ + UnstructuredContent + | StructuredContent + | CombinationContent + | types.CallToolResult + | types.CreateTaskResult + ], ], ): logger.debug("Registering handler for CallToolRequest") @@ -506,6 +529,9 @@ async def handler(req: types.CallToolRequest): maybe_structured_content: StructuredContent | None if isinstance(results, types.CallToolResult): return types.ServerResult(results) + elif isinstance(results, types.CreateTaskResult): + # Task-augmented execution returns task info instead of result + return types.ServerResult(results) elif isinstance(results, tuple) and len(results) == 2: # tool returned both structured and unstructured content unstructured_content, maybe_structured_content = cast(CombinationContent, results) @@ -666,13 +692,14 @@ async def _handle_message( async def _handle_request( self, message: RequestResponder[types.ClientRequest, types.ServerResult], - 
req: Any, + req: types.ClientRequestType, session: ServerSession, lifespan_context: LifespanResultT, raise_exceptions: bool, ): logger.info("Processing request of type %s", type(req).__name__) - if handler := self.request_handlers.get(type(req)): # type: ignore + + if handler := self.request_handlers.get(type(req)): logger.debug("Dispatching request of type %s", type(req).__name__) token = None @@ -692,6 +719,7 @@ async def _handle_request( message.request_meta, session, lifespan_context, + Experimental(task_metadata=message.request_params.task if message.request_params else None), request=request_data, ) ) diff --git a/src/mcp/shared/context.py b/src/mcp/shared/context.py index f3006e7d5..090fdff69 100644 --- a/src/mcp/shared/context.py +++ b/src/mcp/shared/context.py @@ -1,20 +1,30 @@ -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import Any, Generic from typing_extensions import TypeVar from mcp.shared.session import BaseSession -from mcp.types import RequestId, RequestParams +from mcp.types import RequestId, RequestParams, TaskMetadata SessionT = TypeVar("SessionT", bound=BaseSession[Any, Any, Any, Any, Any]) LifespanContextT = TypeVar("LifespanContextT") RequestT = TypeVar("RequestT", default=Any) +@dataclass +class Experimental: + task_metadata: TaskMetadata | None = None + + @property + def is_task(self) -> bool: + return self.task_metadata is not None + + @dataclass class RequestContext(Generic[SessionT, LifespanContextT, RequestT]): request_id: RequestId meta: RequestParams.Meta | None session: SessionT lifespan_context: LifespanContextT + experimental: Experimental = field(default_factory=Experimental) request: RequestT | None = None diff --git a/src/mcp/shared/experimental/__init__.py b/src/mcp/shared/experimental/__init__.py new file mode 100644 index 000000000..9bb0f72c6 --- /dev/null +++ b/src/mcp/shared/experimental/__init__.py @@ -0,0 +1,8 @@ +"""Experimental MCP features. 
+ +WARNING: These APIs are experimental and may change without notice. +""" + +from mcp.shared.experimental import tasks + +__all__ = ["tasks"] diff --git a/src/mcp/shared/experimental/tasks/__init__.py b/src/mcp/shared/experimental/tasks/__init__.py new file mode 100644 index 000000000..9d7cf2eed --- /dev/null +++ b/src/mcp/shared/experimental/tasks/__init__.py @@ -0,0 +1,38 @@ +""" +Experimental task management for MCP. + +This module provides: +- TaskStore: Abstract interface for task state storage +- TaskContext: Context object for task work to interact with state/notifications +- InMemoryTaskStore: Reference implementation for testing/development +- Helper functions: run_task, task_execution, is_terminal, create_task_state, generate_task_id + +Architecture: +- TaskStore is pure storage - it doesn't know about execution +- TaskContext wraps store + session, providing a clean API for task work +- run_task is an optional convenience for spawning in-process tasks + +WARNING: These APIs are experimental and may change without notice. +""" + +from mcp.shared.experimental.tasks.context import TaskContext +from mcp.shared.experimental.tasks.helpers import ( + create_task_state, + generate_task_id, + is_terminal, + run_task, + task_execution, +) +from mcp.shared.experimental.tasks.in_memory_task_store import InMemoryTaskStore +from mcp.shared.experimental.tasks.store import TaskStore + +__all__ = [ + "TaskStore", + "TaskContext", + "InMemoryTaskStore", + "run_task", + "task_execution", + "is_terminal", + "create_task_state", + "generate_task_id", +] diff --git a/src/mcp/shared/experimental/tasks/context.py b/src/mcp/shared/experimental/tasks/context.py new file mode 100644 index 000000000..3c9c7831c --- /dev/null +++ b/src/mcp/shared/experimental/tasks/context.py @@ -0,0 +1,140 @@ +""" +TaskContext - Context for task work to interact with state and notifications. 
+""" + +from typing import TYPE_CHECKING + +from mcp.shared.experimental.tasks.store import TaskStore +from mcp.types import ( + Result, + ServerNotification, + Task, + TaskStatusNotification, + TaskStatusNotificationParams, +) + +if TYPE_CHECKING: + from mcp.server.session import ServerSession + + +class TaskContext: + """ + Context provided to task work for state management and notifications. + + This wraps a TaskStore and optional session, providing a clean API + for task work to update status, complete, fail, and send notifications. + + Example: + async def my_task_work(ctx: TaskContext) -> CallToolResult: + await ctx.update_status("Starting processing...") + + for i, item in enumerate(items): + await ctx.update_status(f"Processing item {i+1}/{len(items)}") + if ctx.is_cancelled: + return CallToolResult(content=[TextContent(type="text", text="Cancelled")]) + process(item) + + return CallToolResult(content=[TextContent(type="text", text="Done!")]) + """ + + def __init__( + self, + task: Task, + store: TaskStore, + session: "ServerSession | None" = None, + ): + self._task = task + self._store = store + self._session = session + self._cancelled = False + + @property + def task_id(self) -> str: + """The task identifier.""" + return self._task.taskId + + @property + def task(self) -> Task: + """The current task state.""" + return self._task + + @property + def is_cancelled(self) -> bool: + """Whether cancellation has been requested.""" + return self._cancelled + + def request_cancellation(self) -> None: + """ + Request cancellation of this task. + + This sets is_cancelled=True. Task work should check this + periodically and exit gracefully if set. + """ + self._cancelled = True + + async def update_status(self, message: str, *, notify: bool = True) -> None: + """ + Update the task's status message. 
+ + Args: + message: The new status message + notify: Whether to send a notification to the client + """ + self._task = await self._store.update_task( + self.task_id, + status_message=message, + ) + if notify: + await self._send_notification() + + async def complete(self, result: Result, *, notify: bool = True) -> None: + """ + Mark the task as completed with the given result. + + Args: + result: The task result + notify: Whether to send a notification to the client + """ + await self._store.store_result(self.task_id, result) + self._task = await self._store.update_task( + self.task_id, + status="completed", + ) + if notify: + await self._send_notification() + + async def fail(self, error: str, *, notify: bool = True) -> None: + """ + Mark the task as failed with an error message. + + Args: + error: The error message + notify: Whether to send a notification to the client + """ + self._task = await self._store.update_task( + self.task_id, + status="failed", + status_message=error, + ) + if notify: + await self._send_notification() + + async def _send_notification(self) -> None: + """Send a task status notification to the client.""" + if self._session is None: + return + + await self._session.send_notification( + ServerNotification( + TaskStatusNotification( + params=TaskStatusNotificationParams( + taskId=self._task.taskId, + status=self._task.status, + statusMessage=self._task.statusMessage, + createdAt=self._task.createdAt, + ttl=self._task.ttl, + pollInterval=self._task.pollInterval, + ) + ) + ) + ) diff --git a/src/mcp/shared/experimental/tasks/helpers.py b/src/mcp/shared/experimental/tasks/helpers.py new file mode 100644 index 000000000..23f21d735 --- /dev/null +++ b/src/mcp/shared/experimental/tasks/helpers.py @@ -0,0 +1,187 @@ +""" +Helper functions for task management. 
+""" + +from collections.abc import AsyncIterator, Awaitable, Callable +from contextlib import asynccontextmanager +from datetime import UTC, datetime +from typing import TYPE_CHECKING +from uuid import uuid4 + +from anyio.abc import TaskGroup + +from mcp.shared.experimental.tasks.context import TaskContext +from mcp.shared.experimental.tasks.store import TaskStore +from mcp.types import CreateTaskResult, Result, Task, TaskMetadata, TaskStatus + +if TYPE_CHECKING: + from mcp.server.session import ServerSession + + +def is_terminal(status: TaskStatus) -> bool: + """ + Check if a task status represents a terminal state. + + Terminal states are those where the task has finished and will not change. + + Args: + status: The task status to check + + Returns: + True if the status is terminal (completed, failed, or cancelled) + """ + return status in ("completed", "failed", "cancelled") + + +def generate_task_id() -> str: + """Generate a unique task ID.""" + return str(uuid4()) + + +def create_task_state( + metadata: TaskMetadata, + task_id: str | None = None, +) -> Task: + """ + Create a Task object with initial state. + + This is a helper for TaskStore implementations. + + Args: + metadata: Task metadata + task_id: Optional task ID (generated if not provided) + + Returns: + A new Task in "working" status + """ + return Task( + taskId=task_id or generate_task_id(), + status="working", + createdAt=datetime.now(UTC), + ttl=metadata.ttl, + pollInterval=500, # Default 500ms poll interval + ) + + +@asynccontextmanager +async def task_execution( + task_id: str, + store: TaskStore, + session: "ServerSession | None" = None, +) -> AsyncIterator[TaskContext]: + """ + Context manager for safe task execution. + + Loads a task from the store and provides a TaskContext for the work. + If an unhandled exception occurs, the task is automatically marked as failed + and the exception is suppressed (since the failure is captured in task state). 
+ + This is the recommended pattern for executing task work, especially in + distributed scenarios where the worker may be a separate process. + + Args: + task_id: The task identifier to execute + store: The task store (must be accessible by the worker) + session: Optional session for sending notifications (often None for workers) + + Yields: + TaskContext for updating status and completing/failing the task + + Raises: + ValueError: If the task is not found in the store + + Example (in-memory): + async def work(): + async with task_execution(task.taskId, store) as ctx: + await ctx.update_status("Processing...") + result = await do_work() + await ctx.complete(result) + + task_group.start_soon(work) + + Example (distributed worker): + async def worker_process(task_id: str): + store = RedisTaskStore(redis_url) + async with task_execution(task_id, store) as ctx: + await ctx.update_status("Working...") + result = await do_work() + await ctx.complete(result) + """ + task = await store.get_task(task_id) + if task is None: + raise ValueError(f"Task {task_id} not found") + + ctx = TaskContext(task, store, session) + try: + yield ctx + except Exception as e: + # Auto-fail the task if an exception occurs and task isn't already terminal + # Exception is suppressed since failure is captured in task state + if not is_terminal(ctx.task.status): + await ctx.fail(str(e), notify=session is not None) + # Don't re-raise - the failure is recorded in task state + + +async def run_task( + task_group: TaskGroup, + store: TaskStore, + metadata: TaskMetadata, + work: Callable[[TaskContext], Awaitable[Result]], + *, + session: "ServerSession | None" = None, + task_id: str | None = None, +) -> tuple[CreateTaskResult, TaskContext]: + """ + Create a task and spawn work to execute it. + + This is a convenience helper for in-process task execution. + For distributed systems, you'll want to handle task creation + and execution separately. 
+ + Args: + task_group: The anyio TaskGroup to spawn work in + store: The task store for state management + metadata: Task metadata (ttl, etc.) + work: Async function that does the actual work + session: Optional session for sending notifications + task_id: Optional task ID (generated if not provided) + + Returns: + Tuple of (CreateTaskResult to return to client, TaskContext for cancellation) + + Example: + async with anyio.create_task_group() as tg: + @server.call_tool() + async def handle_tool(name: str, args: dict): + ctx = server.request_context + if ctx.experimental.is_task: + result, task_ctx = await run_task( + tg, + store, + ctx.experimental.task_metadata, + lambda ctx: do_long_work(ctx, args), + session=ctx.session, + ) + # Optionally store task_ctx for cancellation handling + return result + else: + return await do_work_sync(args) + """ + task = await store.create_task(metadata, task_id) + ctx = TaskContext(task, store, session) + + async def execute() -> None: + try: + result = await work(ctx) + # Only complete if not already in terminal state (e.g., cancelled) + if not is_terminal(ctx.task.status): + await ctx.complete(result) + except Exception as e: + # Only fail if not already in terminal state + if not is_terminal(ctx.task.status): + await ctx.fail(str(e)) + + # Spawn the work in the task group + task_group.start_soon(execute) + + return CreateTaskResult(task=task), ctx diff --git a/src/mcp/shared/experimental/tasks/in_memory_task_store.py b/src/mcp/shared/experimental/tasks/in_memory_task_store.py new file mode 100644 index 000000000..edd4d2f5c --- /dev/null +++ b/src/mcp/shared/experimental/tasks/in_memory_task_store.py @@ -0,0 +1,187 @@ +""" +In-memory implementation of TaskStore for demonstration purposes. + +This implementation stores all tasks in memory and provides automatic cleanup +based on the TTL duration specified in the task metadata using lazy expiration. + +Note: This is not suitable for production use as all data is lost on restart. 
@dataclass
class StoredTask:
    """Internal storage record: a task plus its result and expiry time."""

    task: Task
    result: Result | None = None
    # Time when this task should be removed (None = never)
    expires_at: datetime | None = field(default=None)


class InMemoryTaskStore(TaskStore):
    """
    A simple in-memory implementation of TaskStore.

    Features:
    - Automatic TTL-based cleanup (lazy expiration, triggered by
      create_task, get_task, list_tasks and get_all_tasks)
    - No internal locking: intended for use from a single event loop
    - Pagination support for list_tasks

    Limitations:
    - All data lost on restart
    - Not suitable for distributed systems
    - No persistence

    For production, implement TaskStore with Redis, PostgreSQL, etc.
    """

    def __init__(self, page_size: int = 10) -> None:
        """
        Create an empty store.

        Args:
            page_size: Maximum number of tasks returned per list_tasks page

        Raises:
            ValueError: If page_size is smaller than 1
        """
        # A zero or negative page size would make list_tasks return nothing
        # or silently skip tasks, so reject it up front.
        if page_size < 1:
            raise ValueError(f"page_size must be at least 1, got {page_size}")
        self._tasks: dict[str, StoredTask] = {}
        self._page_size = page_size

    def _calculate_expiry(self, ttl_ms: int | None) -> datetime | None:
        """Return now + ttl_ms (milliseconds), or None when there is no TTL."""
        if ttl_ms is None:
            return None
        return datetime.now(UTC) + timedelta(milliseconds=ttl_ms)

    def _is_expired(self, stored: StoredTask) -> bool:
        """Return True when the record has an expiry time that has been reached."""
        if stored.expires_at is None:
            return False
        return datetime.now(UTC) >= stored.expires_at

    def _cleanup_expired(self) -> None:
        """Remove all expired tasks. Called lazily during access operations."""
        # Collect first: a dict must not change size while being iterated.
        expired_ids = [task_id for task_id, stored in self._tasks.items() if self._is_expired(stored)]
        for task_id in expired_ids:
            del self._tasks[task_id]

    async def create_task(
        self,
        metadata: TaskMetadata,
        task_id: str | None = None,
    ) -> Task:
        """
        Create a new task with the given metadata.

        Raises:
            ValueError: If a task with the same ID is already stored
        """
        # Cleanup expired tasks on access
        self._cleanup_expired()

        task = create_task_state(metadata, task_id)

        if task.taskId in self._tasks:
            raise ValueError(f"Task with ID {task.taskId} already exists")

        self._tasks[task.taskId] = StoredTask(
            task=task,
            expires_at=self._calculate_expiry(metadata.ttl),
        )

        # Return a copy to prevent external modification
        return Task(**task.model_dump())

    async def get_task(self, task_id: str) -> Task | None:
        """Get a copy of a task by ID, or None if it is unknown or expired."""
        # Cleanup expired tasks on access
        self._cleanup_expired()

        stored = self._tasks.get(task_id)
        if stored is None:
            return None

        # Return a copy to prevent external modification
        return Task(**stored.task.model_dump())

    async def update_task(
        self,
        task_id: str,
        status: TaskStatus | None = None,
        status_message: str | None = None,
    ) -> Task:
        """
        Update a task's status and/or message and return a copy of the result.

        Arguments left as None leave the corresponding field untouched.

        Raises:
            ValueError: If the task is not in the store
        """
        stored = self._tasks.get(task_id)
        if stored is None:
            raise ValueError(f"Task with ID {task_id} not found")

        if status is not None:
            stored.task.status = status

        if status_message is not None:
            stored.task.statusMessage = status_message

        # If task is now terminal and has TTL, reset expiry timer
        if status is not None and is_terminal(status) and stored.task.ttl is not None:
            stored.expires_at = self._calculate_expiry(stored.task.ttl)

        return Task(**stored.task.model_dump())

    async def store_result(self, task_id: str, result: Result) -> None:
        """
        Store the result for a task.

        Raises:
            ValueError: If the task is not in the store
        """
        stored = self._tasks.get(task_id)
        if stored is None:
            raise ValueError(f"Task with ID {task_id} not found")

        stored.result = result

    async def get_result(self, task_id: str) -> Result | None:
        """Get the stored result for a task, or None if the task or result is absent."""
        stored = self._tasks.get(task_id)
        if stored is None:
            return None

        return stored.result

    async def list_tasks(
        self,
        cursor: str | None = None,
    ) -> tuple[list[Task], str | None]:
        """
        List tasks with pagination, in insertion order.

        Args:
            cursor: ID of the last task of the previous page, or None for the first page

        Returns:
            Tuple of (task copies, next_cursor); next_cursor is None on the last page

        Raises:
            ValueError: If the cursor does not name a stored task
        """
        # Cleanup expired tasks on access
        self._cleanup_expired()

        all_task_ids = list(self._tasks)

        start_index = 0
        if cursor is not None:
            # Check membership directly instead of catching list.index()'s
            # ValueError, so the error raised here has no misleading chained
            # "during handling of the above exception" traceback.
            if cursor not in self._tasks:
                raise ValueError(f"Invalid cursor: {cursor}")
            start_index = all_task_ids.index(cursor) + 1

        end_index = start_index + self._page_size
        tasks = [Task(**self._tasks[tid].task.model_dump()) for tid in all_task_ids[start_index:end_index]]

        # More tasks remain only when the page stops short of the end; the
        # cursor is then the ID of the last task on this page.
        next_cursor = all_task_ids[end_index - 1] if end_index < len(all_task_ids) else None

        return tasks, next_cursor

    async def delete_task(self, task_id: str) -> bool:
        """Delete a task. Returns True if it existed, False otherwise."""
        return self._tasks.pop(task_id, None) is not None

    # --- Testing/debugging helpers ---

    def cleanup(self) -> None:
        """Cleanup all tasks (useful for testing or graceful shutdown)."""
        self._tasks.clear()

    def get_all_tasks(self) -> list[Task]:
        """Get all tasks (useful for debugging). Returns copies to prevent modification."""
        self._cleanup_expired()
        return [Task(**stored.task.model_dump()) for stored in self._tasks.values()]
class TaskStore(ABC):
    """
    Storage contract for task state.

    A TaskStore only persists and retrieves task records; running the work
    is someone else's job. Backends may be in-memory, a database, Redis, etc.

    Every method is a coroutine so that I/O-bound backends fit the interface.
    """

    @abstractmethod
    async def create_task(self, metadata: TaskMetadata, task_id: str | None = None) -> Task:
        """
        Register a new task.

        Args:
            metadata: Task metadata (ttl, etc.)
            task_id: ID to use for the task; when None the implementation
                is expected to generate one.

        Returns:
            The newly created Task, with status="working"

        Raises:
            ValueError: If task_id already exists
        """

    @abstractmethod
    async def get_task(self, task_id: str) -> Task | None:
        """
        Look up a task.

        Args:
            task_id: The task identifier

        Returns:
            The matching Task, or None when there is no such task
        """

    @abstractmethod
    async def update_task(
        self,
        task_id: str,
        status: TaskStatus | None = None,
        status_message: str | None = None,
    ) -> Task:
        """
        Change a task's status, its status message, or both.

        Args:
            task_id: The task identifier
            status: Status to set (if changing)
            status_message: Status message to set (if changing)

        Returns:
            The Task after the update

        Raises:
            ValueError: If task not found
        """

    @abstractmethod
    async def store_result(self, task_id: str, result: Result) -> None:
        """
        Attach a result to a task.

        Args:
            task_id: The task identifier
            result: The result to keep for later retrieval

        Raises:
            ValueError: If task not found
        """

    @abstractmethod
    async def get_result(self, task_id: str) -> Result | None:
        """
        Fetch the result previously stored for a task.

        Args:
            task_id: The task identifier

        Returns:
            The stored Result, or None if not available
        """

    @abstractmethod
    async def list_tasks(self, cursor: str | None = None) -> tuple[list[Task], str | None]:
        """
        Return one page of tasks.

        Args:
            cursor: Pagination cursor from a previous call, or None to start
                from the beginning

        Returns:
            Tuple of (tasks, next_cursor). next_cursor is None if no more pages.
        """

    @abstractmethod
    async def delete_task(self, task_id: str) -> bool:
        """
        Remove a task.

        Args:
            task_id: The task identifier

        Returns:
            True if deleted, False if not found
        """
Literal["never", "optional", "always"] + + +class TaskMetadata(BaseModel): + model_config = ConfigDict(extra="allow") + + ttl: Annotated[int, Field(strict=True)] | None = None class RequestParams(BaseModel): @@ -52,6 +60,16 @@ class Meta(BaseModel): model_config = ConfigDict(extra="allow") + task: TaskMetadata | None = None + """ + If specified, the caller is requesting task-augmented execution for this request. + The request will return a CreateTaskResult immediately, and the actual result can be + retrieved later via tasks/result. + + Task augmentation is subject to capability negotiation - receivers MUST declare support + for task augmentation of specific request types in their capabilities. + """ + meta: Meta | None = Field(alias="_meta", default=None) @@ -262,6 +280,71 @@ class ElicitationCapability(BaseModel): model_config = ConfigDict(extra="allow") +class TasksListCapability(BaseModel): + """Capability for tasks listing operations.""" + + model_config = ConfigDict(extra="allow") + + +class TasksCancelCapability(BaseModel): + """Capability for tasks cancel operations.""" + + model_config = ConfigDict(extra="allow") + + +class TasksCreateMessageCapability(BaseModel): + """Capability for tasks create messages.""" + + model_config = ConfigDict(extra="allow") + + +class TasksSamplingCapability(BaseModel): + """Capability for tasks sampling operations.""" + + model_config = ConfigDict(extra="allow") + + createMessage: TasksCreateMessageCapability | None = None + + +class TasksCreateElicitationCapability(BaseModel): + """Capability for tasks create elicitation operations.""" + + model_config = ConfigDict(extra="allow") + + +class TasksElicitationCapability(BaseModel): + """Capability for tasks elicitation operations.""" + + model_config = ConfigDict(extra="allow") + + create: TasksCreateElicitationCapability | None = None + + +class ClientTasksRequestsCapability(BaseModel): + """Capability for tasks requests operations.""" + + model_config = 
ConfigDict(extra="allow") + + sampling: TasksSamplingCapability | None = None + + elicitation: TasksElicitationCapability | None = None + + +class ClientTasksCapability(BaseModel): + """Capability for client tasks operations.""" + + model_config = ConfigDict(extra="allow") + + list: TasksListCapability | None = None + """Whether this client supports tasks/list.""" + + cancel: TasksCancelCapability | None = None + """Whether this client supports tasks/cancel.""" + + requests: ClientTasksRequestsCapability | None = None + """Specifies which request types can be augmented with tasks.""" + + class ClientCapabilities(BaseModel): """Capabilities a client may support.""" @@ -273,6 +356,9 @@ class ClientCapabilities(BaseModel): """Present if the client supports elicitation from the user.""" roots: RootsCapability | None = None """Present if the client supports listing roots.""" + tasks: ClientTasksCapability | None = None + """Present if the client supports task-augmented requests.""" + model_config = ConfigDict(extra="allow") @@ -314,6 +400,37 @@ class CompletionsCapability(BaseModel): model_config = ConfigDict(extra="allow") +class TasksCallCapability(BaseModel): + """Capability for tasks call operations.""" + + model_config = ConfigDict(extra="allow") + + +class TasksToolsCapability(BaseModel): + """Capability for tasks tools operations.""" + + model_config = ConfigDict(extra="allow") + call: TasksCallCapability | None = None + + +class ServerTasksRequestsCapability(BaseModel): + """Capability for tasks requests operations.""" + + model_config = ConfigDict(extra="allow") + + tools: TasksToolsCapability | None = None + + +class ServerTasksCapability(BaseModel): + """Capability for server tasks operations.""" + + model_config = ConfigDict(extra="allow") + + list: TasksListCapability | None = None + cancel: TasksCancelCapability | None = None + requests: ServerTasksRequestsCapability | None = None + + class ServerCapabilities(BaseModel): """Capabilities that a server may 
support.""" @@ -329,9 +446,144 @@ class ServerCapabilities(BaseModel): """Present if the server offers any tools to call.""" completions: CompletionsCapability | None = None """Present if the server offers autocompletion suggestions for prompts and resources.""" + tasks: ServerTasksCapability | None = None + """Present if the server supports task-augmented requests.""" model_config = ConfigDict(extra="allow") +TaskStatus = Literal["working", "input_required", "completed", "failed", "cancelled"] + + +class RelatedTaskMetadata(BaseModel): + """ + Metadata for associating messages with a task. + + Include this in the `_meta` field under the key `io.modelcontextprotocol/related-task`. + """ + + model_config = ConfigDict(extra="allow") + taskId: str + + +class Task(BaseModel): + """Data associated with a task.""" + + model_config = ConfigDict(extra="allow") + + taskId: str + """The task identifier.""" + + status: TaskStatus + """Current task state.""" + + statusMessage: str | None = None + """ + Optional human-readable message describing the current task state. 
+ This can provide context for any status, including: + - Reasons for "cancelled" status + - Summaries for "completed" status + - Diagnostic information for "failed" status (e.g., error details, what went wrong) + """ + + createdAt: datetime # Pydantic will enforce ISO 8601 and re-serialize as a string later + """ISO 8601 timestamp when the task was created.""" + + ttl: Annotated[int, Field(strict=True)] | None + """Actual retention duration from creation in milliseconds, null for unlimited.""" + + pollInterval: Annotated[int, Field(strict=True)] | None = None + + +class CreateTaskResult(Result): + """A response to a task-augmented request.""" + + task: Task + + +class GetTaskRequestParams(RequestParams): + model_config = ConfigDict(extra="allow") + taskId: str + """The task identifier to query.""" + + +class GetTaskRequest(Request[GetTaskRequestParams, Literal["tasks/get"]]): + """A request to retrieve the state of a task.""" + + method: Literal["tasks/get"] = "tasks/get" + + params: GetTaskRequestParams + + +class GetTaskResult(Result, Task): + """The response to a tasks/get request.""" + + +class GetTaskPayloadRequestParams(RequestParams): + model_config = ConfigDict(extra="allow") + + taskId: str + """The task identifier to retrieve results for.""" + + +class GetTaskPayloadRequest(Request[GetTaskPayloadRequestParams, Literal["tasks/result"]]): + """A request to retrieve the result of a completed task.""" + + method: Literal["tasks/result"] = "tasks/result" + params: GetTaskPayloadRequestParams + + +class GetTaskPayloadResult(Result): + """ + The response to a tasks/result request. + The structure matches the result type of the original request. + For example, a tools/call task would return the CallToolResult structure. 
+ """ + + +class CancelTaskRequestParams(RequestParams): + model_config = ConfigDict(extra="allow") + + taskId: str + """The task identifier to cancel.""" + + +class CancelTaskRequest(Request[CancelTaskRequestParams, Literal["tasks/cancel"]]): + """A request to cancel a task.""" + + method: Literal["tasks/cancel"] = "tasks/cancel" + params: CancelTaskRequestParams + + +class CancelTaskResult(Result, Task): + """The response to a tasks/cancel request.""" + + +class ListTasksRequest(PaginatedRequest[Literal["tasks/list"]]): + """A request to retrieve a list of tasks.""" + + method: Literal["tasks/list"] = "tasks/list" + + +class ListTasksResult(PaginatedResult): + """The response to a tasks/list request.""" + + tasks: list[Task] + + +class TaskStatusNotificationParams(NotificationParams, Task): + """Parameters for a `notifications/tasks/status` notification.""" + + +class TaskStatusNotification(Notification[TaskStatusNotificationParams, Literal["notifications/tasks/status"]]): + """ + An optional notification from the receiver to the requestor, informing them that a task's status has changed. + Receivers are not required to send these notifications + """ + + method: Literal["notifications/tasks/status"] = "notifications/tasks/status" + params: TaskStatusNotificationParams + + class InitializeRequestParams(RequestParams): """Parameters for the initialize request.""" @@ -865,6 +1117,20 @@ class ToolAnnotations(BaseModel): of a memory tool is not. Default: true """ + + taskHint: TaskHint | None = None + """ + Indicates whether this tool supports task-augmented execution. + This allows clients to handle long-running operations through polling + the task system. 
+ + - "never": Tool does not support task-augmented execution (default when absent) + - "optional": Tool may support task-augmented execution + - "always": Tool requires task-augmented execution + + Default: "never" + """ + model_config = ConfigDict(extra="allow") @@ -1230,10 +1496,14 @@ class RootsListChangedNotification( class CancelledNotificationParams(NotificationParams): """Parameters for cancellation notifications.""" - requestId: RequestId + requestId: RequestId | None = None """The ID of the request to cancel.""" reason: str | None = None """An optional string describing the reason for the cancellation.""" + + taskId: str | None = None + """Deprecated: Use the `tasks/cancel` request instead of this notification for task cancellation.""" + model_config = ConfigDict(extra="allow") @@ -1247,29 +1517,41 @@ class CancelledNotification(Notification[CancelledNotificationParams, Literal["n params: CancelledNotificationParams -class ClientRequest( - RootModel[ - PingRequest - | InitializeRequest - | CompleteRequest - | SetLevelRequest - | GetPromptRequest - | ListPromptsRequest - | ListResourcesRequest - | ListResourceTemplatesRequest - | ReadResourceRequest - | SubscribeRequest - | UnsubscribeRequest - | CallToolRequest - | ListToolsRequest - ] -): +ClientRequestType: TypeAlias = ( + PingRequest + | InitializeRequest + | CompleteRequest + | SetLevelRequest + | GetPromptRequest + | ListPromptsRequest + | ListResourcesRequest + | ListResourceTemplatesRequest + | ReadResourceRequest + | SubscribeRequest + | UnsubscribeRequest + | CallToolRequest + | ListToolsRequest + | GetTaskRequest + | GetTaskPayloadRequest + | ListTasksRequest + | CancelTaskRequest +) + + +class ClientRequest(RootModel[ClientRequestType]): pass -class ClientNotification( - RootModel[CancelledNotification | ProgressNotification | InitializedNotification | RootsListChangedNotification] -): +ClientNotificationType: TypeAlias = ( + CancelledNotification + | ProgressNotification + | 
InitializedNotification + | RootsListChangedNotification + | TaskStatusNotification +) + + +class ClientNotification(RootModel[ClientNotificationType]): pass @@ -1311,40 +1593,72 @@ class ElicitResult(Result): """ -class ClientResult(RootModel[EmptyResult | CreateMessageResult | ListRootsResult | ElicitResult]): +ClientResultType: TypeAlias = ( + EmptyResult + | CreateMessageResult + | ListRootsResult + | ElicitResult + | GetTaskResult + | GetTaskPayloadResult + | ListTasksResult + | CancelTaskResult +) + + +class ClientResult(RootModel[ClientResultType]): pass -class ServerRequest(RootModel[PingRequest | CreateMessageRequest | ListRootsRequest | ElicitRequest]): +ServerRequestType: TypeAlias = ( + PingRequest + | CreateMessageRequest + | ListRootsRequest + | ElicitRequest + | GetTaskRequest + | GetTaskPayloadRequest + | ListTasksRequest + | CancelTaskRequest +) + + +class ServerRequest(RootModel[ServerRequestType]): pass -class ServerNotification( - RootModel[ - CancelledNotification - | ProgressNotification - | LoggingMessageNotification - | ResourceUpdatedNotification - | ResourceListChangedNotification - | ToolListChangedNotification - | PromptListChangedNotification - ] -): +ServerNotificationType: TypeAlias = ( + CancelledNotification + | ProgressNotification + | LoggingMessageNotification + | ResourceUpdatedNotification + | ResourceListChangedNotification + | ToolListChangedNotification + | PromptListChangedNotification + | TaskStatusNotification +) + + +class ServerNotification(RootModel[ServerNotificationType]): pass -class ServerResult( - RootModel[ - EmptyResult - | InitializeResult - | CompleteResult - | GetPromptResult - | ListPromptsResult - | ListResourcesResult - | ListResourceTemplatesResult - | ReadResourceResult - | CallToolResult - | ListToolsResult - ] -): +ServerResultType: TypeAlias = ( + EmptyResult + | InitializeResult + | CompleteResult + | GetPromptResult + | ListPromptsResult + | ListResourcesResult + | ListResourceTemplatesResult + | 
ReadResourceResult + | CallToolResult + | ListToolsResult + | GetTaskResult + | GetTaskPayloadResult + | ListTasksResult + | CancelTaskResult + | CreateTaskResult +) + + +class ServerResult(RootModel[ServerResultType]): pass diff --git a/tests/experimental/__init__.py b/tests/experimental/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/experimental/tasks/__init__.py b/tests/experimental/tasks/__init__.py new file mode 100644 index 000000000..6e8649d28 --- /dev/null +++ b/tests/experimental/tasks/__init__.py @@ -0,0 +1 @@ +"""Tests for MCP task support.""" diff --git a/tests/experimental/tasks/client/__init__.py b/tests/experimental/tasks/client/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/experimental/tasks/client/test_tasks.py b/tests/experimental/tasks/client/test_tasks.py new file mode 100644 index 000000000..fc451a99b --- /dev/null +++ b/tests/experimental/tasks/client/test_tasks.py @@ -0,0 +1,508 @@ +"""Tests for the experimental client task methods (session.experimental).""" + +from dataclasses import dataclass, field +from typing import Any + +import anyio +import pytest +from anyio import Event +from anyio.abc import TaskGroup + +from mcp.client.session import ClientSession +from mcp.server import Server +from mcp.server.lowlevel import NotificationOptions +from mcp.server.models import InitializationOptions +from mcp.server.session import ServerSession +from mcp.shared.experimental.tasks import InMemoryTaskStore, task_execution +from mcp.shared.message import SessionMessage +from mcp.shared.session import RequestResponder +from mcp.types import ( + CallToolRequest, + CallToolRequestParams, + CallToolResult, + CancelTaskRequest, + CancelTaskResult, + ClientRequest, + ClientResult, + CreateTaskResult, + GetTaskPayloadRequest, + GetTaskPayloadResult, + GetTaskRequest, + GetTaskResult, + ListTasksRequest, + ListTasksResult, + ServerNotification, + ServerRequest, + TaskMetadata, + TextContent, + 
Tool, +) + + +@dataclass +class AppContext: + """Application context passed via lifespan_context.""" + + task_group: TaskGroup + store: InMemoryTaskStore + task_done_events: dict[str, Event] = field(default_factory=lambda: {}) + + +@pytest.mark.anyio +async def test_session_experimental_get_task() -> None: + """Test session.experimental.get_task() method.""" + # Note: We bypass the normal lifespan mechanism + server: Server[AppContext, Any] = Server("test-server") # type: ignore[assignment] + store = InMemoryTaskStore() + + @server.list_tools() + async def list_tools(): + return [Tool(name="test_tool", description="Test", inputSchema={"type": "object"})] + + @server.call_tool() + async def handle_call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent] | CreateTaskResult: + ctx = server.request_context + app = ctx.lifespan_context + + if ctx.experimental.is_task: + task_metadata = ctx.experimental.task_metadata + assert task_metadata is not None + task = await app.store.create_task(task_metadata) + + done_event = Event() + app.task_done_events[task.taskId] = done_event + + async def do_work(): + async with task_execution(task.taskId, app.store) as task_ctx: + await task_ctx.complete( + CallToolResult(content=[TextContent(type="text", text="Done")]), + notify=False, + ) + done_event.set() + + app.task_group.start_soon(do_work) + return CreateTaskResult(task=task) + + return [TextContent(type="text", text="Sync")] + + @server.experimental.get_task() + async def handle_get_task(request: GetTaskRequest) -> GetTaskResult: + app = server.request_context.lifespan_context + task = await app.store.get_task(request.params.taskId) + if task is None: + raise ValueError(f"Task {request.params.taskId} not found") + return GetTaskResult( + taskId=task.taskId, + status=task.status, + statusMessage=task.statusMessage, + createdAt=task.createdAt, + ttl=task.ttl, + pollInterval=task.pollInterval, + ) + + # Set up streams + server_to_client_send, server_to_client_receive = 
anyio.create_memory_object_stream[SessionMessage](10) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) + + async def message_handler( + message: RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception, + ) -> None: + if isinstance(message, Exception): + raise message + + async def run_server(app_context: AppContext): + async with ServerSession( + client_to_server_receive, + server_to_client_send, + InitializationOptions( + server_name="test-server", + server_version="1.0.0", + capabilities=server.get_capabilities( + notification_options=NotificationOptions(), + experimental_capabilities={}, + ), + ), + ) as server_session: + async for message in server_session.incoming_messages: + await server._handle_message(message, server_session, app_context, raise_exceptions=False) + + async with anyio.create_task_group() as tg: + app_context = AppContext(task_group=tg, store=store) + tg.start_soon(run_server, app_context) + + async with ClientSession( + server_to_client_receive, + client_to_server_send, + message_handler=message_handler, + ) as client_session: + await client_session.initialize() + + # Create a task + create_result = await client_session.send_request( + ClientRequest( + CallToolRequest( + params=CallToolRequestParams( + name="test_tool", + arguments={}, + task=TaskMetadata(ttl=60000), + ) + ) + ), + CreateTaskResult, + ) + task_id = create_result.task.taskId + + # Wait for task to complete + await app_context.task_done_events[task_id].wait() + + # Use session.experimental to get task status + task_status = await client_session.experimental.get_task(task_id) + + assert task_status.taskId == task_id + assert task_status.status == "completed" + + tg.cancel_scope.cancel() + + store.cleanup() + + +@pytest.mark.anyio +async def test_session_experimental_get_task_result() -> None: + """Test session.experimental.get_task_result() method.""" + server: Server[AppContext, Any] = 
Server("test-server") # type: ignore[assignment] + store = InMemoryTaskStore() + + @server.list_tools() + async def list_tools(): + return [Tool(name="test_tool", description="Test", inputSchema={"type": "object"})] + + @server.call_tool() + async def handle_call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent] | CreateTaskResult: + ctx = server.request_context + app = ctx.lifespan_context + + if ctx.experimental.is_task: + task_metadata = ctx.experimental.task_metadata + assert task_metadata is not None + task = await app.store.create_task(task_metadata) + + done_event = Event() + app.task_done_events[task.taskId] = done_event + + async def do_work(): + async with task_execution(task.taskId, app.store) as task_ctx: + await task_ctx.complete( + CallToolResult(content=[TextContent(type="text", text="Task result content")]), + notify=False, + ) + done_event.set() + + app.task_group.start_soon(do_work) + return CreateTaskResult(task=task) + + return [TextContent(type="text", text="Sync")] + + @server.experimental.get_task_result() + async def handle_get_task_result(request: GetTaskPayloadRequest) -> GetTaskPayloadResult: + app = server.request_context.lifespan_context + result = await app.store.get_result(request.params.taskId) + if result is None: + raise ValueError(f"Result for task {request.params.taskId} not found") + assert isinstance(result, CallToolResult) + return GetTaskPayloadResult(**result.model_dump()) + + # Set up streams + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) + + async def message_handler( + message: RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception, + ) -> None: + if isinstance(message, Exception): + raise message + + async def run_server(app_context: AppContext): + async with ServerSession( + client_to_server_receive, + 
server_to_client_send, + InitializationOptions( + server_name="test-server", + server_version="1.0.0", + capabilities=server.get_capabilities( + notification_options=NotificationOptions(), + experimental_capabilities={}, + ), + ), + ) as server_session: + async for message in server_session.incoming_messages: + await server._handle_message(message, server_session, app_context, raise_exceptions=False) + + async with anyio.create_task_group() as tg: + app_context = AppContext(task_group=tg, store=store) + tg.start_soon(run_server, app_context) + + async with ClientSession( + server_to_client_receive, + client_to_server_send, + message_handler=message_handler, + ) as client_session: + await client_session.initialize() + + # Create a task + create_result = await client_session.send_request( + ClientRequest( + CallToolRequest( + params=CallToolRequestParams( + name="test_tool", + arguments={}, + task=TaskMetadata(ttl=60000), + ) + ) + ), + CreateTaskResult, + ) + task_id = create_result.task.taskId + + # Wait for task to complete + await app_context.task_done_events[task_id].wait() + + # Use TaskClient to get task result + task_result = await client_session.experimental.get_task_result(task_id, CallToolResult) + + assert len(task_result.content) == 1 + content = task_result.content[0] + assert isinstance(content, TextContent) + assert content.text == "Task result content" + + tg.cancel_scope.cancel() + + store.cleanup() + + +@pytest.mark.anyio +async def test_session_experimental_list_tasks() -> None: + """Test TaskClient.list_tasks() method.""" + server: Server[AppContext, Any] = Server("test-server") # type: ignore[assignment] + store = InMemoryTaskStore() + + @server.list_tools() + async def list_tools(): + return [Tool(name="test_tool", description="Test", inputSchema={"type": "object"})] + + @server.call_tool() + async def handle_call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent] | CreateTaskResult: + ctx = server.request_context + app = 
ctx.lifespan_context + + if ctx.experimental.is_task: + task_metadata = ctx.experimental.task_metadata + assert task_metadata is not None + task = await app.store.create_task(task_metadata) + + done_event = Event() + app.task_done_events[task.taskId] = done_event + + async def do_work(): + async with task_execution(task.taskId, app.store) as task_ctx: + await task_ctx.complete( + CallToolResult(content=[TextContent(type="text", text="Done")]), + notify=False, + ) + done_event.set() + + app.task_group.start_soon(do_work) + return CreateTaskResult(task=task) + + return [TextContent(type="text", text="Sync")] + + @server.experimental.list_tasks() + async def handle_list_tasks(request: ListTasksRequest) -> ListTasksResult: + app = server.request_context.lifespan_context + tasks_list, next_cursor = await app.store.list_tasks(cursor=request.params.cursor if request.params else None) + return ListTasksResult(tasks=tasks_list, nextCursor=next_cursor) + + # Set up streams + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) + + async def message_handler( + message: RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception, + ) -> None: + if isinstance(message, Exception): + raise message + + async def run_server(app_context: AppContext): + async with ServerSession( + client_to_server_receive, + server_to_client_send, + InitializationOptions( + server_name="test-server", + server_version="1.0.0", + capabilities=server.get_capabilities( + notification_options=NotificationOptions(), + experimental_capabilities={}, + ), + ), + ) as server_session: + async for message in server_session.incoming_messages: + await server._handle_message(message, server_session, app_context, raise_exceptions=False) + + async with anyio.create_task_group() as tg: + app_context = AppContext(task_group=tg, store=store) + 
tg.start_soon(run_server, app_context) + + async with ClientSession( + server_to_client_receive, + client_to_server_send, + message_handler=message_handler, + ) as client_session: + await client_session.initialize() + + # Create two tasks + for _ in range(2): + create_result = await client_session.send_request( + ClientRequest( + CallToolRequest( + params=CallToolRequestParams( + name="test_tool", + arguments={}, + task=TaskMetadata(ttl=60000), + ) + ) + ), + CreateTaskResult, + ) + await app_context.task_done_events[create_result.task.taskId].wait() + + # Use TaskClient to list tasks + list_result = await client_session.experimental.list_tasks() + + assert len(list_result.tasks) == 2 + + tg.cancel_scope.cancel() + + store.cleanup() + + +@pytest.mark.anyio +async def test_session_experimental_cancel_task() -> None: + """Test TaskClient.cancel_task() method.""" + server: Server[AppContext, Any] = Server("test-server") # type: ignore[assignment] + store = InMemoryTaskStore() + + @server.list_tools() + async def list_tools(): + return [Tool(name="test_tool", description="Test", inputSchema={"type": "object"})] + + @server.call_tool() + async def handle_call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent] | CreateTaskResult: + ctx = server.request_context + app = ctx.lifespan_context + + if ctx.experimental.is_task: + task_metadata = ctx.experimental.task_metadata + assert task_metadata is not None + task = await app.store.create_task(task_metadata) + # Don't start any work - task stays in "working" status + return CreateTaskResult(task=task) + + return [TextContent(type="text", text="Sync")] + + @server.experimental.get_task() + async def handle_get_task(request: GetTaskRequest) -> GetTaskResult: + app = server.request_context.lifespan_context + task = await app.store.get_task(request.params.taskId) + if task is None: + raise ValueError(f"Task {request.params.taskId} not found") + return GetTaskResult( + taskId=task.taskId, + status=task.status, + 
statusMessage=task.statusMessage, + createdAt=task.createdAt, + ttl=task.ttl, + pollInterval=task.pollInterval, + ) + + @server.experimental.cancel_task() + async def handle_cancel_task(request: CancelTaskRequest) -> CancelTaskResult: + app = server.request_context.lifespan_context + task = await app.store.get_task(request.params.taskId) + if task is None: + raise ValueError(f"Task {request.params.taskId} not found") + await app.store.update_task(request.params.taskId, status="cancelled") + # CancelTaskResult extends Task, so we need to return the updated task info + updated_task = await app.store.get_task(request.params.taskId) + assert updated_task is not None + return CancelTaskResult( + taskId=updated_task.taskId, + status=updated_task.status, + createdAt=updated_task.createdAt, + ttl=updated_task.ttl, + ) + + # Set up streams + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) + + async def message_handler( + message: RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception, + ) -> None: + if isinstance(message, Exception): + raise message + + async def run_server(app_context: AppContext): + async with ServerSession( + client_to_server_receive, + server_to_client_send, + InitializationOptions( + server_name="test-server", + server_version="1.0.0", + capabilities=server.get_capabilities( + notification_options=NotificationOptions(), + experimental_capabilities={}, + ), + ), + ) as server_session: + async for message in server_session.incoming_messages: + await server._handle_message(message, server_session, app_context, raise_exceptions=False) + + async with anyio.create_task_group() as tg: + app_context = AppContext(task_group=tg, store=store) + tg.start_soon(run_server, app_context) + + async with ClientSession( + server_to_client_receive, + client_to_server_send, + 
message_handler=message_handler, + ) as client_session: + await client_session.initialize() + + # Create a task (but don't complete it) + create_result = await client_session.send_request( + ClientRequest( + CallToolRequest( + params=CallToolRequestParams( + name="test_tool", + arguments={}, + task=TaskMetadata(ttl=60000), + ) + ) + ), + CreateTaskResult, + ) + task_id = create_result.task.taskId + + # Verify task is working + status_before = await client_session.experimental.get_task(task_id) + assert status_before.status == "working" + + # Cancel the task + await client_session.experimental.cancel_task(task_id) + + # Verify task is cancelled + status_after = await client_session.experimental.get_task(task_id) + assert status_after.status == "cancelled" + + tg.cancel_scope.cancel() + + store.cleanup() diff --git a/tests/experimental/tasks/server/__init__.py b/tests/experimental/tasks/server/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/experimental/tasks/server/test_context.py b/tests/experimental/tasks/server/test_context.py new file mode 100644 index 000000000..f1232fddd --- /dev/null +++ b/tests/experimental/tasks/server/test_context.py @@ -0,0 +1,166 @@ +"""Tests for TaskContext and helper functions.""" + +import pytest + +from mcp.shared.experimental.tasks import ( + InMemoryTaskStore, + TaskContext, + create_task_state, +) +from mcp.types import CallToolResult, TaskMetadata, TextContent + +# --- TaskContext tests --- + + +@pytest.mark.anyio +async def test_task_context_properties() -> None: + """Test TaskContext basic properties.""" + store = InMemoryTaskStore() + task = await store.create_task(metadata=TaskMetadata(ttl=60000)) + ctx = TaskContext(task, store, session=None) + + assert ctx.task_id == task.taskId + assert ctx.task.taskId == task.taskId + assert ctx.task.status == "working" + assert ctx.is_cancelled is False + + store.cleanup() + + +@pytest.mark.anyio +async def test_task_context_update_status() -> None: + 
"""Test TaskContext.update_status.""" + store = InMemoryTaskStore() + task = await store.create_task(metadata=TaskMetadata(ttl=60000)) + ctx = TaskContext(task, store, session=None) + + await ctx.update_status("Processing...", notify=False) + + assert ctx.task.statusMessage == "Processing..." + retrieved = await store.get_task(task.taskId) + assert retrieved is not None + assert retrieved.statusMessage == "Processing..." + + store.cleanup() + + +@pytest.mark.anyio +async def test_task_context_update_status_multiple() -> None: + """Test multiple status updates.""" + store = InMemoryTaskStore() + task = await store.create_task(metadata=TaskMetadata(ttl=60000)) + ctx = TaskContext(task, store, session=None) + + await ctx.update_status("Step 1...", notify=False) + assert ctx.task.statusMessage == "Step 1..." + + await ctx.update_status("Step 2...", notify=False) + assert ctx.task.statusMessage == "Step 2..." + + await ctx.update_status("Step 3...", notify=False) + assert ctx.task.statusMessage == "Step 3..." 
+ + store.cleanup() + + +@pytest.mark.anyio +async def test_task_context_complete() -> None: + """Test TaskContext.complete.""" + store = InMemoryTaskStore() + task = await store.create_task(metadata=TaskMetadata(ttl=60000)) + ctx = TaskContext(task, store, session=None) + + result = CallToolResult(content=[TextContent(type="text", text="Done!")]) + await ctx.complete(result, notify=False) + + assert ctx.task.status == "completed" + + stored_result = await store.get_result(task.taskId) + assert stored_result == result + + store.cleanup() + + +@pytest.mark.anyio +async def test_task_context_fail() -> None: + """Test TaskContext.fail.""" + store = InMemoryTaskStore() + task = await store.create_task(metadata=TaskMetadata(ttl=60000)) + ctx = TaskContext(task, store, session=None) + + await ctx.fail("Something went wrong", notify=False) + + assert ctx.task.status == "failed" + assert ctx.task.statusMessage == "Something went wrong" + + store.cleanup() + + +@pytest.mark.anyio +async def test_task_context_cancellation() -> None: + """Test TaskContext cancellation flag.""" + store = InMemoryTaskStore() + task = await store.create_task(metadata=TaskMetadata(ttl=60000)) + ctx = TaskContext(task, store, session=None) + + assert ctx.is_cancelled is False + + ctx.request_cancellation() + + assert ctx.is_cancelled is True + + store.cleanup() + + +@pytest.mark.anyio +async def test_task_context_no_notification_without_session() -> None: + """Test that notification doesn't fail when no session is provided.""" + store = InMemoryTaskStore() + task = await store.create_task(metadata=TaskMetadata(ttl=60000)) + ctx = TaskContext(task, store, session=None) + + # These should not raise even with notify=True (default) + await ctx.update_status("Status update") + await ctx.complete(CallToolResult(content=[TextContent(type="text", text="Done")])) + + store.cleanup() + + +# --- create_task_state helper tests --- + + +def test_create_task_state_generates_id() -> None: + """Test 
create_task_state generates a task ID.""" + metadata = TaskMetadata(ttl=60000) + task = create_task_state(metadata) + + assert task.taskId is not None + assert len(task.taskId) > 0 + assert task.status == "working" + assert task.ttl == 60000 + assert task.pollInterval == 500 # Default poll interval + + +def test_create_task_state_uses_provided_id() -> None: + """Test create_task_state uses provided task ID.""" + metadata = TaskMetadata(ttl=60000) + task = create_task_state(metadata, task_id="my-task-id") + + assert task.taskId == "my-task-id" + + +def test_create_task_state_null_ttl() -> None: + """Test create_task_state with null TTL.""" + metadata = TaskMetadata(ttl=None) + task = create_task_state(metadata) + + assert task.ttl is None + assert task.status == "working" + + +def test_create_task_state_has_created_at() -> None: + """Test create_task_state sets createdAt timestamp.""" + metadata = TaskMetadata(ttl=60000) + task = create_task_state(metadata) + + assert task.createdAt is not None diff --git a/tests/experimental/tasks/server/test_integration.py b/tests/experimental/tasks/server/test_integration.py new file mode 100644 index 000000000..e1d29915e --- /dev/null +++ b/tests/experimental/tasks/server/test_integration.py @@ -0,0 +1,372 @@ +"""End-to-end integration tests for tasks functionality. + +These tests demonstrate the full task lifecycle: +1. Client sends task-augmented request (tools/call with task metadata) +2. Server creates task and returns CreateTaskResult immediately +3. Background work executes (using task_execution context manager) +4. Client polls with tasks/get +5. 
Client retrieves result with tasks/result +""" + +from dataclasses import dataclass, field +from typing import Any + +import anyio +import pytest +from anyio import Event +from anyio.abc import TaskGroup + +from mcp.client.session import ClientSession +from mcp.server import Server +from mcp.server.lowlevel import NotificationOptions +from mcp.server.models import InitializationOptions +from mcp.server.session import ServerSession +from mcp.shared.experimental.tasks import InMemoryTaskStore, task_execution +from mcp.shared.message import SessionMessage +from mcp.shared.session import RequestResponder +from mcp.types import ( + CallToolRequest, + CallToolRequestParams, + CallToolResult, + ClientRequest, + ClientResult, + CreateTaskResult, + GetTaskPayloadRequest, + GetTaskPayloadRequestParams, + GetTaskPayloadResult, + GetTaskRequest, + GetTaskRequestParams, + GetTaskResult, + ListTasksRequest, + ListTasksResult, + ServerNotification, + ServerRequest, + TaskMetadata, + TextContent, + Tool, + ToolAnnotations, +) + + +@dataclass +class AppContext: + """Application context passed via lifespan_context.""" + + task_group: TaskGroup + store: InMemoryTaskStore + # Events to signal when tasks complete (for testing without sleeps) + task_done_events: dict[str, Event] = field(default_factory=lambda: {}) + + +@pytest.mark.anyio +async def test_task_lifecycle_with_task_execution() -> None: + """ + Test the complete task lifecycle using the task_execution pattern. + + This demonstrates the recommended way to implement task-augmented tools: + 1. Create task in store + 2. Spawn work using task_execution() context manager + 3. Return CreateTaskResult immediately + 4. 
Work executes in background, auto-fails on exception + """ + # Note: We bypass the normal lifespan mechanism and pass context directly to _handle_message + server: Server[AppContext, Any] = Server("test-tasks") # type: ignore[assignment] + store = InMemoryTaskStore() + + @server.list_tools() + async def list_tools(): + return [ + Tool( + name="process_data", + description="Process data asynchronously", + inputSchema={ + "type": "object", + "properties": {"input": {"type": "string"}}, + }, + annotations=ToolAnnotations(taskHint="always"), + ) + ] + + @server.call_tool() + async def handle_call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent] | CreateTaskResult: + ctx = server.request_context + app = ctx.lifespan_context + + if name == "process_data" and ctx.experimental.is_task: + # 1. Create task in store + task_metadata = ctx.experimental.task_metadata + assert task_metadata is not None + task = await app.store.create_task(task_metadata) + + # 2. Create event to signal completion (for testing) + done_event = Event() + app.task_done_events[task.taskId] = done_event + + # 3. Define work function using task_execution for safety + async def do_work(): + async with task_execution(task.taskId, app.store) as task_ctx: + await task_ctx.update_status("Processing input...", notify=False) + # Simulate work + input_value = arguments.get("input", "") + result_text = f"Processed: {input_value.upper()}" + await task_ctx.complete( + CallToolResult(content=[TextContent(type="text", text=result_text)]), + notify=False, + ) + # Signal completion + done_event.set() + + # 4. Spawn work in task group (from lifespan_context) + app.task_group.start_soon(do_work) + + # 5. 
Return CreateTaskResult immediately + return CreateTaskResult(task=task) + + # Non-task execution path + return [TextContent(type="text", text="Sync result")] + + # Register task query handlers (delegate to store) + @server.experimental.get_task() + async def handle_get_task(request: GetTaskRequest) -> GetTaskResult: + app = server.request_context.lifespan_context + task = await app.store.get_task(request.params.taskId) + if task is None: + raise ValueError(f"Task {request.params.taskId} not found") + return GetTaskResult( + taskId=task.taskId, + status=task.status, + statusMessage=task.statusMessage, + createdAt=task.createdAt, + ttl=task.ttl, + pollInterval=task.pollInterval, + ) + + @server.experimental.get_task_result() + async def handle_get_task_result( + request: GetTaskPayloadRequest, + ) -> GetTaskPayloadResult: + app = server.request_context.lifespan_context + result = await app.store.get_result(request.params.taskId) + if result is None: + raise ValueError(f"Result for task {request.params.taskId} not found") + assert isinstance(result, CallToolResult) + # Return as GetTaskPayloadResult (which accepts extra fields) + return GetTaskPayloadResult(**result.model_dump()) + + @server.experimental.list_tasks() + async def handle_list_tasks(request: ListTasksRequest) -> ListTasksResult: + app = server.request_context.lifespan_context + tasks, next_cursor = await app.store.list_tasks(cursor=request.params.cursor if request.params else None) + return ListTasksResult(tasks=tasks, nextCursor=next_cursor) + + # Set up client-server communication + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) + + async def message_handler( + message: RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception, + ) -> None: + if isinstance(message, Exception): + raise message + + async def 
run_server(app_context: AppContext): + async with ServerSession( + client_to_server_receive, + server_to_client_send, + InitializationOptions( + server_name="test-server", + server_version="1.0.0", + capabilities=server.get_capabilities( + notification_options=NotificationOptions(), + experimental_capabilities={}, + ), + ), + ) as server_session: + async for message in server_session.incoming_messages: + await server._handle_message(message, server_session, app_context, raise_exceptions=False) + + async with anyio.create_task_group() as tg: + # Create app context with task group and store + app_context = AppContext(task_group=tg, store=store) + tg.start_soon(run_server, app_context) + + async with ClientSession( + server_to_client_receive, + client_to_server_send, + message_handler=message_handler, + ) as client_session: + await client_session.initialize() + + # === Step 1: Send task-augmented tool call === + create_result = await client_session.send_request( + ClientRequest( + CallToolRequest( + params=CallToolRequestParams( + name="process_data", + arguments={"input": "hello world"}, + task=TaskMetadata(ttl=60000), + ), + ) + ), + CreateTaskResult, + ) + + assert isinstance(create_result, CreateTaskResult) + assert create_result.task.status == "working" + task_id = create_result.task.taskId + + # === Step 2: Wait for task to complete === + await app_context.task_done_events[task_id].wait() + + task_status = await client_session.send_request( + ClientRequest(GetTaskRequest(params=GetTaskRequestParams(taskId=task_id))), + GetTaskResult, + ) + + assert task_status.taskId == task_id + assert task_status.status == "completed" + + # === Step 3: Retrieve the actual result === + task_result = await client_session.send_request( + ClientRequest(GetTaskPayloadRequest(params=GetTaskPayloadRequestParams(taskId=task_id))), + CallToolResult, + ) + + assert len(task_result.content) == 1 + content = task_result.content[0] + assert isinstance(content, TextContent) + assert 
content.text == "Processed: HELLO WORLD" + + tg.cancel_scope.cancel() + + store.cleanup() + + +@pytest.mark.anyio +async def test_task_auto_fails_on_exception() -> None: + """Test that task_execution automatically fails the task on unhandled exception.""" + # Note: We bypass the normal lifespan mechanism and pass context directly to _handle_message + server: Server[AppContext, Any] = Server("test-tasks-failure") # type: ignore[assignment] + store = InMemoryTaskStore() + + @server.list_tools() + async def list_tools(): + return [ + Tool( + name="failing_task", + description="A task that fails", + inputSchema={"type": "object", "properties": {}}, + ) + ] + + @server.call_tool() + async def handle_call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent] | CreateTaskResult: + ctx = server.request_context + app = ctx.lifespan_context + + if name == "failing_task" and ctx.experimental.is_task: + task_metadata = ctx.experimental.task_metadata + assert task_metadata is not None + task = await app.store.create_task(task_metadata) + + # Create event to signal completion (for testing) + done_event = Event() + app.task_done_events[task.taskId] = done_event + + async def do_failing_work(): + async with task_execution(task.taskId, app.store) as task_ctx: + await task_ctx.update_status("About to fail...", notify=False) + raise RuntimeError("Something went wrong!") + # Note: complete() is never called, but task_execution + # will automatically call fail() due to the exception + # This line is reached because task_execution suppresses the exception + done_event.set() + + app.task_group.start_soon(do_failing_work) + return CreateTaskResult(task=task) + + return [TextContent(type="text", text="Sync")] + + @server.experimental.get_task() + async def handle_get_task(request: GetTaskRequest) -> GetTaskResult: + app = server.request_context.lifespan_context + task = await app.store.get_task(request.params.taskId) + if task is None: + raise ValueError(f"Task 
{request.params.taskId} not found") + return GetTaskResult( + taskId=task.taskId, + status=task.status, + statusMessage=task.statusMessage, + createdAt=task.createdAt, + ttl=task.ttl, + pollInterval=task.pollInterval, + ) + + # Set up streams + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) + + async def message_handler( + message: RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception, + ) -> None: + if isinstance(message, Exception): + raise message + + async def run_server(app_context: AppContext): + async with ServerSession( + client_to_server_receive, + server_to_client_send, + InitializationOptions( + server_name="test-server", + server_version="1.0.0", + capabilities=server.get_capabilities( + notification_options=NotificationOptions(), + experimental_capabilities={}, + ), + ), + ) as server_session: + async for message in server_session.incoming_messages: + await server._handle_message(message, server_session, app_context, raise_exceptions=False) + + async with anyio.create_task_group() as tg: + app_context = AppContext(task_group=tg, store=store) + tg.start_soon(run_server, app_context) + + async with ClientSession( + server_to_client_receive, + client_to_server_send, + message_handler=message_handler, + ) as client_session: + await client_session.initialize() + + # Send task request + create_result = await client_session.send_request( + ClientRequest( + CallToolRequest( + params=CallToolRequestParams( + name="failing_task", + arguments={}, + task=TaskMetadata(ttl=60000), + ), + ) + ), + CreateTaskResult, + ) + + task_id = create_result.task.taskId + + # Wait for task to complete (even though it fails) + await app_context.task_done_events[task_id].wait() + + # Check that task was auto-failed + task_status = await client_session.send_request( + 
ClientRequest(GetTaskRequest(params=GetTaskRequestParams(taskId=task_id))), + GetTaskResult, + ) + + assert task_status.status == "failed" + assert task_status.statusMessage == "Something went wrong!" + + tg.cancel_scope.cancel() + + store.cleanup() diff --git a/tests/experimental/tasks/server/test_server.py b/tests/experimental/tasks/server/test_server.py new file mode 100644 index 000000000..2077d7196 --- /dev/null +++ b/tests/experimental/tasks/server/test_server.py @@ -0,0 +1,440 @@ +"""Tests for server-side task support (handlers, capabilities, integration).""" + +from datetime import UTC, datetime +from typing import Any + +import anyio +import pytest + +from mcp.client.session import ClientSession +from mcp.server import Server +from mcp.server.lowlevel import NotificationOptions +from mcp.server.models import InitializationOptions +from mcp.server.session import ServerSession +from mcp.shared.message import SessionMessage +from mcp.shared.session import RequestResponder +from mcp.types import ( + CallToolRequest, + CallToolRequestParams, + CallToolResult, + CancelTaskRequest, + CancelTaskRequestParams, + CancelTaskResult, + ClientRequest, + ClientResult, + GetTaskPayloadRequest, + GetTaskPayloadRequestParams, + GetTaskPayloadResult, + GetTaskRequest, + GetTaskRequestParams, + GetTaskResult, + ListTasksRequest, + ListTasksResult, + ListToolsRequest, + ListToolsResult, + ServerNotification, + ServerRequest, + ServerResult, + Task, + TaskMetadata, + TextContent, + Tool, + ToolAnnotations, +) + +# --- Experimental handler tests --- + + +@pytest.mark.anyio +async def test_list_tasks_handler() -> None: + """Test that experimental list_tasks handler works.""" + server = Server("test") + + test_tasks = [ + Task( + taskId="task-1", + status="working", + createdAt=datetime.now(UTC), + ttl=60000, + pollInterval=1000, + ), + Task( + taskId="task-2", + status="completed", + createdAt=datetime.now(UTC), + ttl=60000, + pollInterval=1000, + ), + ] + + 
@server.experimental.list_tasks() + async def handle_list_tasks(request: ListTasksRequest) -> ListTasksResult: + return ListTasksResult(tasks=test_tasks) + + handler = server.request_handlers[ListTasksRequest] + request = ListTasksRequest(method="tasks/list") + result = await handler(request) + + assert isinstance(result, ServerResult) + assert isinstance(result.root, ListTasksResult) + assert len(result.root.tasks) == 2 + assert result.root.tasks[0].taskId == "task-1" + assert result.root.tasks[1].taskId == "task-2" + + +@pytest.mark.anyio +async def test_get_task_handler() -> None: + """Test that experimental get_task handler works.""" + server = Server("test") + + @server.experimental.get_task() + async def handle_get_task(request: GetTaskRequest) -> GetTaskResult: + return GetTaskResult( + taskId=request.params.taskId, + status="working", + createdAt=datetime.now(UTC), + ttl=60000, + pollInterval=1000, + ) + + handler = server.request_handlers[GetTaskRequest] + request = GetTaskRequest( + method="tasks/get", + params=GetTaskRequestParams(taskId="test-task-123"), + ) + result = await handler(request) + + assert isinstance(result, ServerResult) + assert isinstance(result.root, GetTaskResult) + assert result.root.taskId == "test-task-123" + assert result.root.status == "working" + + +@pytest.mark.anyio +async def test_get_task_result_handler() -> None: + """Test that experimental get_task_result handler works.""" + server = Server("test") + + @server.experimental.get_task_result() + async def handle_get_task_result(request: GetTaskPayloadRequest) -> GetTaskPayloadResult: + return GetTaskPayloadResult() + + handler = server.request_handlers[GetTaskPayloadRequest] + request = GetTaskPayloadRequest( + method="tasks/result", + params=GetTaskPayloadRequestParams(taskId="test-task-123"), + ) + result = await handler(request) + + assert isinstance(result, ServerResult) + assert isinstance(result.root, GetTaskPayloadResult) + + +@pytest.mark.anyio +async def 
test_cancel_task_handler() -> None: + """Test that experimental cancel_task handler works.""" + server = Server("test") + + @server.experimental.cancel_task() + async def handle_cancel_task(request: CancelTaskRequest) -> CancelTaskResult: + return CancelTaskResult( + taskId=request.params.taskId, + status="cancelled", + createdAt=datetime.now(UTC), + ttl=60000, + ) + + handler = server.request_handlers[CancelTaskRequest] + request = CancelTaskRequest( + method="tasks/cancel", + params=CancelTaskRequestParams(taskId="test-task-123"), + ) + result = await handler(request) + + assert isinstance(result, ServerResult) + assert isinstance(result.root, CancelTaskResult) + assert result.root.taskId == "test-task-123" + assert result.root.status == "cancelled" + + +# --- Server capabilities tests --- + + +@pytest.mark.anyio +async def test_server_capabilities_include_tasks() -> None: + """Test that server capabilities include tasks when handlers are registered.""" + server = Server("test") + + @server.experimental.list_tasks() + async def handle_list_tasks(request: ListTasksRequest) -> ListTasksResult: + return ListTasksResult(tasks=[]) + + @server.experimental.cancel_task() + async def handle_cancel_task(request: CancelTaskRequest) -> CancelTaskResult: + return CancelTaskResult( + taskId=request.params.taskId, + status="cancelled", + createdAt=datetime.now(UTC), + ttl=None, + ) + + capabilities = server.get_capabilities( + notification_options=NotificationOptions(), + experimental_capabilities={}, + ) + + assert capabilities.tasks is not None + assert capabilities.tasks.list is not None + assert capabilities.tasks.cancel is not None + assert capabilities.tasks.requests is not None + assert capabilities.tasks.requests.tools is not None + + +@pytest.mark.anyio +async def test_server_capabilities_partial_tasks() -> None: + """Test capabilities with only some task handlers registered.""" + server = Server("test") + + @server.experimental.list_tasks() + async def 
handle_list_tasks(request: ListTasksRequest) -> ListTasksResult: + return ListTasksResult(tasks=[]) + + # Only list_tasks registered, not cancel_task + + capabilities = server.get_capabilities( + notification_options=NotificationOptions(), + experimental_capabilities={}, + ) + + assert capabilities.tasks is not None + assert capabilities.tasks.list is not None + assert capabilities.tasks.cancel is None # Not registered + + +# --- Tool annotation tests --- + + +@pytest.mark.anyio +async def test_tool_with_task_hint_annotation() -> None: + """Test that tools can declare taskHint in annotations.""" + server = Server("test") + + @server.list_tools() + async def list_tools(): + return [ + Tool( + name="quick_tool", + description="Fast tool", + inputSchema={"type": "object", "properties": {}}, + annotations=ToolAnnotations(taskHint="never"), + ), + Tool( + name="long_tool", + description="Long running tool", + inputSchema={"type": "object", "properties": {}}, + annotations=ToolAnnotations(taskHint="always"), + ), + Tool( + name="flexible_tool", + description="Can be either", + inputSchema={"type": "object", "properties": {}}, + annotations=ToolAnnotations(taskHint="optional"), + ), + ] + + tools_handler = server.request_handlers[ListToolsRequest] + request = ListToolsRequest(method="tools/list") + result = await tools_handler(request) + + assert isinstance(result, ServerResult) + assert isinstance(result.root, ListToolsResult) + tools = result.root.tools + + assert tools[0].annotations is not None + assert tools[0].annotations.taskHint == "never" + assert tools[1].annotations is not None + assert tools[1].annotations.taskHint == "always" + assert tools[2].annotations is not None + assert tools[2].annotations.taskHint == "optional" + + +# --- Integration tests --- + + +@pytest.mark.anyio +async def test_task_metadata_in_call_tool_request() -> None: + """Test that task metadata is accessible via RequestContext when calling a tool.""" + server = Server("test") + 
captured_task_metadata: TaskMetadata | None = None + + @server.list_tools() + async def list_tools(): + return [ + Tool( + name="long_task", + description="A long running task", + inputSchema={"type": "object", "properties": {}}, + annotations=ToolAnnotations(taskHint="optional"), + ) + ] + + @server.call_tool() + async def handle_call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent]: + nonlocal captured_task_metadata + ctx = server.request_context + captured_task_metadata = ctx.experimental.task_metadata + return [TextContent(type="text", text="done")] + + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) + + async def message_handler( + message: RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception, + ) -> None: + if isinstance(message, Exception): + raise message + + async def run_server(): + async with ServerSession( + client_to_server_receive, + server_to_client_send, + InitializationOptions( + server_name="test-server", + server_version="1.0.0", + capabilities=server.get_capabilities( + notification_options=NotificationOptions(), + experimental_capabilities={}, + ), + ), + ) as server_session: + async with anyio.create_task_group() as tg: + + async def handle_messages(): + async for message in server_session.incoming_messages: + await server._handle_message(message, server_session, {}, False) + + tg.start_soon(handle_messages) + await anyio.sleep_forever() + + async with anyio.create_task_group() as tg: + tg.start_soon(run_server) + + async with ClientSession( + server_to_client_receive, + client_to_server_send, + message_handler=message_handler, + ) as client_session: + await client_session.initialize() + + # Call tool with task metadata + await client_session.send_request( + ClientRequest( + CallToolRequest( + params=CallToolRequestParams( + name="long_task", + 
arguments={}, + task=TaskMetadata(ttl=60000), + ), + ) + ), + CallToolResult, + ) + + tg.cancel_scope.cancel() + + assert captured_task_metadata is not None + assert captured_task_metadata.ttl == 60000 + + +@pytest.mark.anyio +async def test_task_metadata_is_task_property() -> None: + """Test that RequestContext.experimental.is_task works correctly.""" + server = Server("test") + is_task_values: list[bool] = [] + + @server.list_tools() + async def list_tools(): + return [ + Tool( + name="test_tool", + description="Test tool", + inputSchema={"type": "object", "properties": {}}, + ) + ] + + @server.call_tool() + async def handle_call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent]: + ctx = server.request_context + is_task_values.append(ctx.experimental.is_task) + return [TextContent(type="text", text="done")] + + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) + + async def message_handler( + message: RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception, + ) -> None: + if isinstance(message, Exception): + raise message + + async def run_server(): + async with ServerSession( + client_to_server_receive, + server_to_client_send, + InitializationOptions( + server_name="test-server", + server_version="1.0.0", + capabilities=server.get_capabilities( + notification_options=NotificationOptions(), + experimental_capabilities={}, + ), + ), + ) as server_session: + async with anyio.create_task_group() as tg: + + async def handle_messages(): + async for message in server_session.incoming_messages: + await server._handle_message(message, server_session, {}, False) + + tg.start_soon(handle_messages) + await anyio.sleep_forever() + + async with anyio.create_task_group() as tg: + tg.start_soon(run_server) + + async with ClientSession( + server_to_client_receive, + 
client_to_server_send, + message_handler=message_handler, + ) as client_session: + await client_session.initialize() + + # Call without task metadata + await client_session.send_request( + ClientRequest( + CallToolRequest( + params=CallToolRequestParams(name="test_tool", arguments={}), + ) + ), + CallToolResult, + ) + + # Call with task metadata + await client_session.send_request( + ClientRequest( + CallToolRequest( + params=CallToolRequestParams( + name="test_tool", + arguments={}, + task=TaskMetadata(ttl=60000), + ), + ) + ), + CallToolResult, + ) + + tg.cancel_scope.cancel() + + assert len(is_task_values) == 2 + assert is_task_values[0] is False # First call without task + assert is_task_values[1] is True # Second call with task diff --git a/tests/experimental/tasks/server/test_store.py b/tests/experimental/tasks/server/test_store.py new file mode 100644 index 000000000..773136ec4 --- /dev/null +++ b/tests/experimental/tasks/server/test_store.py @@ -0,0 +1,231 @@ +"""Tests for InMemoryTaskStore.""" + +import pytest + +from mcp.shared.experimental.tasks import InMemoryTaskStore +from mcp.types import CallToolResult, TaskMetadata, TextContent + + +@pytest.mark.anyio +async def test_create_and_get() -> None: + """Test InMemoryTaskStore create and get operations.""" + store = InMemoryTaskStore() + + task = await store.create_task(metadata=TaskMetadata(ttl=60000)) + + assert task.taskId is not None + assert task.status == "working" + assert task.ttl == 60000 + + retrieved = await store.get_task(task.taskId) + assert retrieved is not None + assert retrieved.taskId == task.taskId + assert retrieved.status == "working" + + store.cleanup() + + +@pytest.mark.anyio +async def test_create_with_custom_id() -> None: + """Test InMemoryTaskStore create with custom task ID.""" + store = InMemoryTaskStore() + + task = await store.create_task( + metadata=TaskMetadata(ttl=60000), + task_id="my-custom-id", + ) + + assert task.taskId == "my-custom-id" + assert task.status == 
"working" + + retrieved = await store.get_task("my-custom-id") + assert retrieved is not None + assert retrieved.taskId == "my-custom-id" + + store.cleanup() + + +@pytest.mark.anyio +async def test_create_duplicate_id_raises() -> None: + """Test that creating a task with duplicate ID raises.""" + store = InMemoryTaskStore() + + await store.create_task(metadata=TaskMetadata(ttl=60000), task_id="duplicate") + + with pytest.raises(ValueError, match="already exists"): + await store.create_task(metadata=TaskMetadata(ttl=60000), task_id="duplicate") + + store.cleanup() + + +@pytest.mark.anyio +async def test_get_nonexistent_returns_none() -> None: + """Test that getting a nonexistent task returns None.""" + store = InMemoryTaskStore() + + retrieved = await store.get_task("nonexistent") + assert retrieved is None + + store.cleanup() + + +@pytest.mark.anyio +async def test_update_status() -> None: + """Test InMemoryTaskStore status updates.""" + store = InMemoryTaskStore() + + task = await store.create_task(metadata=TaskMetadata(ttl=60000)) + + updated = await store.update_task(task.taskId, status="completed", status_message="All done!") + + assert updated.status == "completed" + assert updated.statusMessage == "All done!" + + retrieved = await store.get_task(task.taskId) + assert retrieved is not None + assert retrieved.status == "completed" + assert retrieved.statusMessage == "All done!" 
+ + store.cleanup() + + +@pytest.mark.anyio +async def test_update_nonexistent_raises() -> None: + """Test that updating a nonexistent task raises.""" + store = InMemoryTaskStore() + + with pytest.raises(ValueError, match="not found"): + await store.update_task("nonexistent", status="completed") + + store.cleanup() + + +@pytest.mark.anyio +async def test_store_and_get_result() -> None: + """Test InMemoryTaskStore result storage and retrieval.""" + store = InMemoryTaskStore() + + task = await store.create_task(metadata=TaskMetadata(ttl=60000)) + + # Store result + result = CallToolResult(content=[TextContent(type="text", text="Result data")]) + await store.store_result(task.taskId, result) + + # Retrieve result + retrieved_result = await store.get_result(task.taskId) + assert retrieved_result == result + + store.cleanup() + + +@pytest.mark.anyio +async def test_get_result_nonexistent_returns_none() -> None: + """Test that getting result for nonexistent task returns None.""" + store = InMemoryTaskStore() + + result = await store.get_result("nonexistent") + assert result is None + + store.cleanup() + + +@pytest.mark.anyio +async def test_get_result_no_result_returns_none() -> None: + """Test that getting result when none stored returns None.""" + store = InMemoryTaskStore() + + task = await store.create_task(metadata=TaskMetadata(ttl=60000)) + result = await store.get_result(task.taskId) + assert result is None + + store.cleanup() + + +@pytest.mark.anyio +async def test_list_tasks() -> None: + """Test InMemoryTaskStore list operation.""" + store = InMemoryTaskStore() + + # Create multiple tasks + for _ in range(3): + await store.create_task(metadata=TaskMetadata(ttl=60000)) + + tasks, next_cursor = await store.list_tasks() + assert len(tasks) == 3 + assert next_cursor is None # Less than page size + + store.cleanup() + + +@pytest.mark.anyio +async def test_list_tasks_pagination() -> None: + """Test InMemoryTaskStore pagination.""" + store = 
InMemoryTaskStore(page_size=2) + + # Create 5 tasks + for _ in range(5): + await store.create_task(metadata=TaskMetadata(ttl=60000)) + + # First page + tasks, next_cursor = await store.list_tasks() + assert len(tasks) == 2 + assert next_cursor is not None + + # Second page + tasks, next_cursor = await store.list_tasks(cursor=next_cursor) + assert len(tasks) == 2 + assert next_cursor is not None + + # Third page (last) + tasks, next_cursor = await store.list_tasks(cursor=next_cursor) + assert len(tasks) == 1 + assert next_cursor is None + + store.cleanup() + + +@pytest.mark.anyio +async def test_list_tasks_invalid_cursor() -> None: + """Test that invalid cursor raises.""" + store = InMemoryTaskStore() + + await store.create_task(metadata=TaskMetadata(ttl=60000)) + + with pytest.raises(ValueError, match="Invalid cursor"): + await store.list_tasks(cursor="invalid-cursor") + + store.cleanup() + + +@pytest.mark.anyio +async def test_delete_task() -> None: + """Test InMemoryTaskStore delete operation.""" + store = InMemoryTaskStore() + + task = await store.create_task(metadata=TaskMetadata(ttl=60000)) + + deleted = await store.delete_task(task.taskId) + assert deleted is True + + retrieved = await store.get_task(task.taskId) + assert retrieved is None + + # Delete non-existent + deleted = await store.delete_task(task.taskId) + assert deleted is False + + store.cleanup() + + +@pytest.mark.anyio +async def test_get_all_tasks_helper() -> None: + """Test the get_all_tasks debugging helper.""" + store = InMemoryTaskStore() + + await store.create_task(metadata=TaskMetadata(ttl=60000)) + await store.create_task(metadata=TaskMetadata(ttl=60000)) + + all_tasks = store.get_all_tasks() + assert len(all_tasks) == 2 + + store.cleanup()