diff --git a/package-lock.json b/package-lock.json index c94ab5c9c..7374cedcc 100644 --- a/package-lock.json +++ b/package-lock.json @@ -19,10 +19,11 @@ "express-rate-limit": "^7.5.0", "pkce-challenge": "^5.0.0", "raw-body": "^3.0.0", - "zod": "^3.23.8", + "zod": "^3.25.0", "zod-to-json-schema": "^3.24.1" }, "devDependencies": { + "@anthropic-ai/sdk": "^0.65.0", "@eslint/js": "^9.8.0", "@jest-mock/express": "^3.0.0", "@types/content-type": "^1.1.8", @@ -61,6 +62,27 @@ "node": ">=6.0.0" } }, + "node_modules/@anthropic-ai/sdk": { + "version": "0.65.0", + "resolved": "https://registry.npmjs.org/@anthropic-ai/sdk/-/sdk-0.65.0.tgz", + "integrity": "sha512-zIdPOcrCVEI8t3Di40nH4z9EoeyGZfXbYSvWdDLsB/KkaSYMnEgC7gmcgWu83g2NTn1ZTpbMvpdttWDGGIk6zw==", + "dev": true, + "license": "MIT", + "dependencies": { + "json-schema-to-ts": "^3.1.1" + }, + "bin": { + "anthropic-ai-sdk": "bin/cli" + }, + "peerDependencies": { + "zod": "^3.25.0 || ^4.0.0" + }, + "peerDependenciesMeta": { + "zod": { + "optional": true + } + } + }, "node_modules/@babel/code-frame": { "version": "7.26.2", "resolved": "https://registry.npmjs.org/@babel/code-frame/-/code-frame-7.26.2.tgz", @@ -465,6 +487,16 @@ "@babel/core": "^7.0.0-0" } }, + "node_modules/@babel/runtime": { + "version": "7.28.4", + "resolved": "https://registry.npmjs.org/@babel/runtime/-/runtime-7.28.4.tgz", + "integrity": "sha512-Q/N6JNWvIvPnLDvjlE1OUBLPQHH6l3CltCEsHIujp45zQUSSh8K+gHnaEX45yAT1nyngnINhvWtzN+Nb9D8RAQ==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=6.9.0" + } + }, "node_modules/@babel/template": { "version": "7.27.0", "resolved": "https://registry.npmjs.org/@babel/template/-/template-7.27.0.tgz", @@ -4896,6 +4928,20 @@ "integrity": "sha512-xyFwyhro/JEof6Ghe2iz2NcXoj2sloNsWr/XsERDK/oiPCfaNhl5ONfp+jQdAZRQQ0IJWNzH9zIZF7li91kh2w==", "dev": true }, + "node_modules/json-schema-to-ts": { + "version": "3.1.1", + "resolved": "https://registry.npmjs.org/json-schema-to-ts/-/json-schema-to-ts-3.1.1.tgz", + "integrity": 
"sha512-+DWg8jCJG2TEnpy7kOm/7/AxaYoaRbjVB4LFZLySZlWn8exGs3A4OLJR966cVvU26N7X9TWxl+Jsw7dzAqKT6g==", + "dev": true, + "license": "MIT", + "dependencies": { + "@babel/runtime": "^7.18.3", + "ts-algebra": "^2.0.0" + }, + "engines": { + "node": ">=16" + } + }, "node_modules/json-schema-traverse": { "version": "0.4.1", "resolved": "https://registry.npmjs.org/json-schema-traverse/-/json-schema-traverse-0.4.1.tgz", @@ -6238,6 +6284,13 @@ "node": ">=0.6" } }, + "node_modules/ts-algebra": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/ts-algebra/-/ts-algebra-2.0.0.tgz", + "integrity": "sha512-FPAhNPFMrkwz76P7cdjdmiShwMynZYN6SgOujD1urY4oNm80Ou9oMdmbR45LotcKOXoy7wSmHkRFE6Mxbrhefw==", + "dev": true, + "license": "MIT" + }, "node_modules/ts-api-utils": { "version": "1.3.0", "resolved": "https://registry.npmjs.org/ts-api-utils/-/ts-api-utils-1.3.0.tgz", @@ -6639,9 +6692,9 @@ } }, "node_modules/zod": { - "version": "3.24.1", - "resolved": "https://registry.npmjs.org/zod/-/zod-3.24.1.tgz", - "integrity": "sha512-muH7gBL9sI1nciMZV67X5fTKKBLtwpZ5VBp1vsOQzj1MhrBZ4wlVCm3gedKZWLp0Oyel8sIGfeiz54Su+OVT+A==", + "version": "3.25.76", + "resolved": "https://registry.npmjs.org/zod/-/zod-3.25.76.tgz", + "integrity": "sha512-gzUt/qt81nXsFGKIFcC3YnfEAx5NkunCfnDlvuBSSFS02bcXu4Lmea0AFIUwbLWxWPx3d9p8S5QoaujKcNQxcQ==", "license": "MIT", "funding": { "url": "https://github.com/sponsors/colinhacks" diff --git a/package.json b/package.json index b5b9b8ec9..cb4e2d2ea 100644 --- a/package.json +++ b/package.json @@ -72,10 +72,11 @@ "express-rate-limit": "^7.5.0", "pkce-challenge": "^5.0.0", "raw-body": "^3.0.0", - "zod": "^3.23.8", + "zod": "^3.25.0", "zod-to-json-schema": "^3.24.1" }, "devDependencies": { + "@anthropic-ai/sdk": "^0.65.0", "@eslint/js": "^9.8.0", "@jest-mock/express": "^3.0.0", "@types/content-type": "^1.1.8", diff --git a/src/examples/backfill/backfillSampling.ts b/src/examples/backfill/backfillSampling.ts new file mode 100644 index 000000000..0c8c97827 --- /dev/null 
+++ b/src/examples/backfill/backfillSampling.ts new file mode 100644 index 000000000..0c8c97827 --- /dev/null 
+++ b/src/examples/backfill/backfillSampling.ts @@ -0,0 +1,356 @@ +/* + This example implements a stdio MCP proxy that backfills context-agnostic sampling requests using the Claude API. + + Usage: + npx -y @modelcontextprotocol/inspector \ + npx -y --silent tsx src/examples/backfill/backfillSampling.ts -- \ + npx -y --silent @modelcontextprotocol/server-everything +*/ + +import { Anthropic } from "@anthropic-ai/sdk"; +import { + Base64ImageSource, + ContentBlock, + ContentBlockParam, + TextBlockParam, + ImageBlockParam, + Tool as ClaudeTool, + ToolChoiceAuto, + ToolChoiceAny, + ToolChoiceTool, + ToolChoiceNone, +} from "@anthropic-ai/sdk/resources/messages.js"; +import { StdioServerTransport } from '../../server/stdio.js'; +import { StdioClientTransport } from '../../client/stdio.js'; +import { + CancelledNotification, + CancelledNotificationSchema, + isInitializeRequest, + isJSONRPCRequest, + ElicitRequest, + ElicitRequestSchema, + CreateMessageRequest, + CreateMessageRequestSchema, + CreateMessageResult, + JSONRPCResponse, + isInitializedNotification, + CallToolRequest, + CallToolRequestSchema, + isJSONRPCNotification, + Tool, + ToolCallContent, + LoggingMessageNotification, + JSONRPCNotification, + AssistantMessageContent, + UserMessageContent, + ElicitResult, + ElicitResultSchema, +TextContent, +} from "../../types.js"; +import { Transport } from "../../shared/transport.js"; + +const DEFAULT_MAX_TOKENS = process.env.DEFAULT_MAX_TOKENS ? 
parseInt(process.env.DEFAULT_MAX_TOKENS) : 1000; + +// TODO: move to SDK + +const isCancelledNotification: (value: unknown) => value is CancelledNotification = + ((value: any) => CancelledNotificationSchema.safeParse(value).success) as any; + +const isCallToolRequest: (value: unknown) => value is CallToolRequest = + ((value: any) => CallToolRequestSchema.safeParse(value).success) as any; + +const isElicitRequest: (value: unknown) => value is ElicitRequest = + ((value: any) => ElicitRequestSchema.safeParse(value).success) as any; + +const isElicitResult: (value: unknown) => value is ElicitResult = + ((value: any) => ElicitResultSchema.safeParse(value).success) as any; + +const isCreateMessageRequest: (value: unknown) => value is CreateMessageRequest = + ((value: any) => CreateMessageRequestSchema.safeParse(value).success) as any; + +/** + * Converts MCP Tool definition to Claude API tool format + */ +function toolToClaudeFormat(tool: Tool): ClaudeTool { + return { + name: tool.name, + description: tool.description || "", + input_schema: tool.inputSchema, + }; +} + +/** + * Converts MCP ToolChoice to Claude API tool_choice format + */ +function toolChoiceToClaudeFormat(toolChoice: CreateMessageRequest['params']['toolChoice']): ToolChoiceAuto | ToolChoiceAny | ToolChoiceNone | ToolChoiceTool | undefined { + switch (toolChoice?.mode) { + case "auto": + return { type: "auto", disable_parallel_tool_use: toolChoice.disable_parallel_tool_use }; + case "required": + return { type: "any", disable_parallel_tool_use: toolChoice.disable_parallel_tool_use }; + case "none": + return { type: "none" }; + case undefined: + return undefined; + default: + throw new Error(`Unsupported toolChoice mode: ${toolChoice}`); + } +} + +function contentToMcp(content: ContentBlock): CreateMessageResult['content'] { + switch (content.type) { + case 'text': + return { type: 'text', text: content.text }; + case 'tool_use': + return { + type: 'tool_use', + id: content.id, + name: content.name, + 
input: content.input, + } as ToolCallContent; + default: + throw new Error(`[contentToMcp] Unsupported content type: ${(content as any).type}`); + } +} + +function stopReasonToMcp(reason: string | null): CreateMessageResult['stopReason'] { + switch (reason) { + case 'max_tokens': + return 'maxTokens'; + case 'stop_sequence': + return 'stopSequence'; + case 'tool_use': + return 'toolUse'; + case 'end_turn': + return 'endTurn'; + case null: + return undefined; + default: + throw new Error(`[stopReasonToMcp] Unsupported stop reason: ${reason}`); + } +} + + +function contentBlockFromMcp(content: AssistantMessageContent | UserMessageContent): ContentBlockParam { + switch (content.type) { + case 'text': + return {type: 'text', text: content.text}; + case 'image': + return { + type: 'image', + source: { + data: content.data, + media_type: content.mimeType as Base64ImageSource['media_type'], + type: 'base64', + }, + }; + case 'tool_result': + return { + type: 'tool_result', + tool_use_id: content.toolUseId, + content: content.content.map(c => { + if (c.type === 'text') { + return {type: 'text', text: c.text}; + } else if (c.type === 'image') { + return { + type: 'image', + source: { + type: 'base64', + data: c.data, + media_type: c.mimeType as Base64ImageSource['media_type'], + }, + }; + } else { + throw new Error(`[contentBlockFromMcp] Unsupported content type in tool_result: ${c.type}`); + } + }), + is_error: content.isError, + }; + case 'tool_use': + return { + type: 'tool_use', + id: content.id, + name: content.name, + input: content.input, + }; + case 'audio': + default: + throw new Error(`[contentBlockFromMcp] Unsupported content type: ${(content as any).type}`); + } +} + +function contentFromMcp(content: CreateMessageRequest['params']['messages'][number]['content']): ContentBlockParam[] { + // Handle both single content block and arrays + const contentArray = Array.isArray(content) ? 
content : [content]; + return contentArray.map(contentBlockFromMcp); +} + +export type NamedTransport = { + name: 'client' | 'server', + transport: T, +} + +export async function setupBackfill(client: NamedTransport, server: NamedTransport, api: Anthropic) { + const backfillMeta = await (async () => { + const models = new Set(); + let defaultModel: string | undefined; + for await (const info of api.models.list()) { + models.add(info.id); + if (info.id.indexOf('sonnet') >= 0 && defaultModel === undefined) { + defaultModel = info.id; + } + } + if (defaultModel === undefined) { + if (models.size === 0) { + throw new Error("No models available from the API"); + } + defaultModel = models.values().next().value; + } + return { + sampling_models: Array.from(models), + sampling_default_model: defaultModel, + }; + })(); + + function pickModel(preferences: CreateMessageRequest['params']['modelPreferences'] | undefined): string { + if (preferences?.hints) { + for (const hint of Object.values(preferences.hints)) { + if (hint.name !== undefined && backfillMeta.sampling_models.includes(hint.name)) { + return hint.name; + } + } + } + // TODO: linear model on preferences?.{intelligencePriority, speedPriority, costPriority} to pick between haiku, sonnet, opus. 
+ return backfillMeta.sampling_default_model!; + } + + let clientSupportsSampling: boolean | undefined; + + const propagateMessage = (source: NamedTransport, target: NamedTransport) => { + source.transport.onmessage = async (message, extra) => { + if (isJSONRPCRequest(message)) { + + const sendInternalError = (errorMessage: string) => { + console.error(`[proxy -> ${source.name}]: Error: ${errorMessage}`); + source.transport.send({ + jsonrpc: "2.0", + id: message.id, + error: { + code: -32603, // Internal error + message: errorMessage, + }, + }, {relatedRequestId: message.id}); + }; + + if (isInitializeRequest(message)) { + if (!(clientSupportsSampling = !!message.params.capabilities.sampling)) { + message.params.capabilities.sampling = {} + message.params._meta = {...(message.params._meta ?? {}), ...backfillMeta}; + } + } else if (isCreateMessageRequest(message)) {// && !clientSupportsSampling) { + if ((message.params.includeContext ?? 'none') !== 'none') { + sendInternalError("includeContext != none not supported by MCP sampling backfill"); + return; + } + + try { + // Note that having tools + tool_choice = 'none' does not disable tools, unlike in OpenAI's API. + // We forcibly empty out the tools list in that case, which messes with the prompt caching. + const tools = message.params.toolChoice?.mode === 'none' ? undefined + : message.params.tools?.map(toolToClaudeFormat); + const tool_choice = toolChoiceToClaudeFormat(message.params.toolChoice); + + // TODO: switch to streaming if maxTokens is too large + // "Streaming is required when max_tokens is greater than 21,333 tokens" + const msg = await api.messages.create({ + model: pickModel(message.params.modelPreferences), + system: message.params.systemPrompt === undefined ? undefined : [ + { + type: "text", + text: message.params.systemPrompt + }, + ], + messages: message.params.messages.map(({role, content}) => ({ + role, + content: contentFromMcp(content) + })), + max_tokens: message.params.maxTokens ?? 
DEFAULT_MAX_TOKENS, + temperature: message.params.temperature, + stop_sequences: message.params.stopSequences, + tools: tools && tools.length > 0 ? tools : undefined, + tool_choice: tool_choice, + ...(message.params.metadata ?? {}), + }); + + source.transport.send({ + jsonrpc: "2.0", + id: message.id, + result: { + model: msg.model, + stopReason: stopReasonToMcp(msg.stop_reason), + role: 'assistant', // Always assistant in MCP responses + content: (Array.isArray(msg.content) ? msg.content : [msg.content]).map(contentToMcp), + _meta: { + usage: msg.usage, + }, + }, + }); + } catch (error) { + sendInternalError(`Error processing message: ${(error as Error).message}`); + } + return; + } + } else if (isJSONRPCNotification(message)) { + if (isInitializedNotification(message) && source.name === 'server') { + if (!clientSupportsSampling) { + message.params = {...(message.params ?? {}), _meta: {...(message.params?._meta ?? {}), ...backfillMeta}}; + } + } + } + + try { + const relatedRequestId = isCancelledNotification(message)? 
message.params.requestId : undefined; + await target.transport.send(message, {relatedRequestId}); + } catch (error) { + source.transport.send({ + jsonrpc: "2.0", + method: "notifications/message", + params: { + type: "log_message", + level: "error", + message: `Error sending message to ${target.name}: ${(error as Error).message}`, + } + }); + } + }; + }; + propagateMessage(server, client); + propagateMessage(client, server); + + const addErrorHandler = (transport: NamedTransport) => { + transport.transport.onerror = async (error: Error) => { + console.error(`[proxy]: Error from ${transport.name} transport:`, error); + }; + }; + + addErrorHandler(client); + addErrorHandler(server); + + await server.transport.start(); + await client.transport.start(); +} + +async function main() { + const args = process.argv.slice(2); + const client: NamedTransport = {name: 'client', transport: new StdioClientTransport({command: args[0], args: args.slice(1)})}; + const server: NamedTransport = {name: 'server', transport: new StdioServerTransport()}; + + const api = new Anthropic(); + await setupBackfill(client, server, api); + console.error("[proxy]: Transports started."); +} + +main().catch((error) => { + console.error("[proxy]: Fatal error:", error); + process.exit(1); +}); diff --git a/src/examples/server/toolLoopSampling.test.ts b/src/examples/server/toolLoopSampling.test.ts new file mode 100644 index 000000000..85b4358fc --- /dev/null +++ b/src/examples/server/toolLoopSampling.test.ts @@ -0,0 +1,481 @@ +/** + * Tests for toolLoopSampling.ts + * + * These tests verify that the localResearch tool correctly implements a tool loop + * by simulating an LLM that makes ripgrep and read tool calls. 
+ */ + +import { Client } from "../../client/index.js"; +import { StdioClientTransport } from "../../client/stdio.js"; +import { + CreateMessageRequestSchema, + CreateMessageResult, + CallToolResultSchema, + ToolCallContent, + SamplingMessage, +} from "../../types.js"; +import { resolve } from "node:path"; + +describe("toolLoopSampling server", () => { + jest.setTimeout(30000); // 30 second timeout for integration tests + + let client: Client; + let transport: StdioClientTransport; + + beforeEach(() => { + // Create client with sampling capability + client = new Client( + { + name: "test-client", + version: "1.0.0", + }, + { + capabilities: { + sampling: { + tools: {}, // Indicate we support tool calling in sampling + }, + }, + } + ); + + // Create transport that spawns the toolLoopSampling server + transport = new StdioClientTransport({ + command: "npx", + args: [ + "-y", + "--silent", + "tsx", + resolve(__dirname, "toolLoopSampling.ts"), + ], + }); + }); + + afterEach(async () => { + await transport.close(); + }); + + test("should handle a tool loop with ripgrep and read", async () => { + // Track sampling request count to simulate different LLM responses + let samplingCallCount = 0; + + // Set up sampling handler that simulates an LLM + client.setRequestHandler( + CreateMessageRequestSchema, + async (request): Promise => { + samplingCallCount++; + + // Extract the last message to understand context + const messages = request.params.messages; + const lastMessage = messages[messages.length - 1]; + + // Helper to get content as array + const getContentArray = (content: any) => Array.isArray(content) ? 
content : [content]; + const lastContent = getContentArray(lastMessage.content)[0]; + + console.error( + `[test] Sampling call ${samplingCallCount}, messages: ${messages.length}, last message type: ${lastContent.type}` + ); + + // First call: Return tool_use for ripgrep + if (samplingCallCount === 1) { + return { + model: "test-model", + role: "assistant", + content: { + type: "tool_use", + id: "call_1", + name: "ripgrep", + input: { + pattern: "McpServer", + path: "src", + }, + } as ToolCallContent, + stopReason: "toolUse", + }; + } + + // Second call: After getting ripgrep results, return tool_use for read + if (samplingCallCount === 2) { + // Verify we got a tool result + expect(lastContent.type).toBe("tool_result"); + + return { + model: "test-model", + role: "assistant", + content: { + type: "tool_use", + id: "call_2", + name: "read", + input: { + path: "src/server/mcp.ts", + }, + } as ToolCallContent, + stopReason: "toolUse", + }; + } + + // Third call: After getting read results, return final answer + if (samplingCallCount === 3) { + // Verify we got another tool result + expect(lastContent.type).toBe("tool_result"); + + return { + model: "test-model", + role: "assistant", + content: { + type: "text", + text: "I found the McpServer class in src/server/mcp.ts. 
It's the main server class for MCP.", + }, + stopReason: "endTurn", + }; + } + + // Should not reach here + throw new Error( + `Unexpected sampling call count: ${samplingCallCount}` + ); + } + ); + + // Connect client to server + await client.connect(transport); + + // Call the localResearch tool + const result = await client.request( + { + method: "tools/call", + params: { + name: "localResearch", + arguments: { + query: "Find the McpServer class definition", + }, + }, + }, + CallToolResultSchema + ); + + // Verify the result + expect(result.content).toBeDefined(); + expect(result.content.length).toBeGreaterThan(0); + expect(result.content[0].type).toBe("text"); + + // Verify we got the expected response + if (result.content[0].type === "text") { + expect(result.content[0].text).toContain("McpServer"); + } + + // Verify we made exactly 3 sampling calls (tool loop worked correctly) + expect(samplingCallCount).toBe(3); + }); + + test("should handle errors in tool execution", async () => { + let samplingCallCount = 0; + + // Set up sampling handler that requests an invalid path + client.setRequestHandler( + CreateMessageRequestSchema, + async (request): Promise => { + samplingCallCount++; + + const messages = request.params.messages; + const lastMessage = messages[messages.length - 1]; + + // First call: Return tool_use for ripgrep with path outside CWD + if (samplingCallCount === 1) { + return { + model: "test-model", + role: "assistant", + content: { + type: "tool_use", + id: "call_1", + name: "ripgrep", + input: { + pattern: "test", + path: "../../etc/passwd", // Try to escape CWD + }, + } as ToolCallContent, + stopReason: "toolUse", + }; + } + + // Second call: Should receive error in tool result + if (samplingCallCount === 2) { + const getContentArray = (content: any) => Array.isArray(content) ? 
content : [content]; + const lastContent = getContentArray(lastMessage.content)[0]; + expect(lastContent.type).toBe("tool_result"); + if (lastContent.type === "tool_result") { + // Verify error is present in tool result + const content = lastContent.content as Record< + string, + unknown + >; + expect(content.error).toBeDefined(); + expect( + typeof content.error === "string" && + content.error.includes("outside the current directory") + ).toBe(true); + } + + // Return final answer acknowledging the error + return { + model: "test-model", + role: "assistant", + content: { + type: "text", + text: "I encountered an error: the path is outside the current directory.", + }, + stopReason: "endTurn", + }; + } + + throw new Error( + `Unexpected sampling call count: ${samplingCallCount}` + ); + } + ); + + await client.connect(transport); + + // Call the localResearch tool + const result = await client.request( + { + method: "tools/call", + params: { + name: "localResearch", + arguments: { + query: "Search outside current directory", + }, + }, + }, + CallToolResultSchema + ); + + // Verify we got a response (even though there was an error) + expect(result.content).toBeDefined(); + expect(result.content.length).toBeGreaterThan(0); + expect(result.content[0].type).toBe("text"); + }); + + test("should handle invalid tool names", async () => { + let samplingCallCount = 0; + + // Set up sampling handler that requests an unknown tool + client.setRequestHandler( + CreateMessageRequestSchema, + async (request): Promise => { + samplingCallCount++; + + const messages = request.params.messages; + const lastMessage = messages[messages.length - 1]; + + // First call: Return tool_use for unknown tool + if (samplingCallCount === 1) { + return { + model: "test-model", + role: "assistant", + content: { + type: "tool_use", + id: "call_1", + name: "unknown_tool", + input: { + foo: "bar", + }, + } as ToolCallContent, + stopReason: "toolUse", + }; + } + + // Second call: Should receive error in 
tool result + if (samplingCallCount === 2) { + const getContentArray = (content: any) => Array.isArray(content) ? content : [content]; + const lastContent = getContentArray(lastMessage.content)[0]; + expect(lastContent.type).toBe("tool_result"); + if (lastContent.type === "tool_result") { + const content = lastContent.content as Record< + string, + unknown + >; + expect(content.error).toBeDefined(); + expect( + typeof content.error === "string" && + content.error.includes("Unknown tool") + ).toBe(true); + } + + return { + model: "test-model", + role: "assistant", + content: { + type: "text", + text: "The requested tool does not exist.", + }, + stopReason: "endTurn", + }; + } + + throw new Error( + `Unexpected sampling call count: ${samplingCallCount}` + ); + } + ); + + await client.connect(transport); + + const result = await client.request( + { + method: "tools/call", + params: { + name: "localResearch", + arguments: { + query: "Use unknown tool", + }, + }, + }, + CallToolResultSchema + ); + + expect(result.content).toBeDefined(); + expect(samplingCallCount).toBe(2); + }); + + test("should handle malformed tool inputs", async () => { + let samplingCallCount = 0; + + // Set up sampling handler that sends malformed input + client.setRequestHandler( + CreateMessageRequestSchema, + async (request): Promise => { + samplingCallCount++; + + const messages = request.params.messages; + const lastMessage = messages[messages.length - 1]; + + // First call: Return tool_use with missing required fields + if (samplingCallCount === 1) { + return { + model: "test-model", + role: "assistant", + content: { + type: "tool_use", + id: "call_1", + name: "ripgrep", + input: { + // Missing 'pattern' and 'path' required fields + foo: "bar", + }, + } as ToolCallContent, + stopReason: "toolUse", + }; + } + + // Second call: Should receive validation error + if (samplingCallCount === 2) { + const getContentArray = (content: any) => Array.isArray(content) ? 
content : [content]; + const lastContent = getContentArray(lastMessage.content)[0]; + expect(lastContent.type).toBe("tool_result"); + if (lastContent.type === "tool_result") { + const content = lastContent.content as Record< + string, + unknown + >; + expect(content.error).toBeDefined(); + // Verify it's a validation error + expect( + typeof content.error === "string" && + content.error.includes("Invalid input") + ).toBe(true); + } + + return { + model: "test-model", + role: "assistant", + content: { + type: "text", + text: "I provided invalid input to the tool.", + }, + stopReason: "endTurn", + }; + } + + throw new Error( + `Unexpected sampling call count: ${samplingCallCount}` + ); + } + ); + + await client.connect(transport); + + const result = await client.request( + { + method: "tools/call", + params: { + name: "localResearch", + arguments: { + query: "Test malformed input", + }, + }, + }, + CallToolResultSchema + ); + + expect(result.content).toBeDefined(); + expect(samplingCallCount).toBe(2); + }); + + test("should respect maximum iteration limit", async () => { + let samplingCallCount = 0; + + // Set up sampling handler that keeps requesting tools indefinitely + client.setRequestHandler( + CreateMessageRequestSchema, + async (request): Promise => { + samplingCallCount++; + + // Always return tool calls (never final answer) + return { + model: "test-model", + role: "assistant", + content: { + type: "tool_use", + id: `call_${samplingCallCount}`, + name: "ripgrep", + input: { + pattern: "test", + path: "src", + }, + } as ToolCallContent, + stopReason: "toolUse", + }; + } + ); + + await client.connect(transport); + + // Call localResearch with infinite loop scenario + const result = await client.request( + { + method: "tools/call", + params: { + name: "localResearch", + arguments: { + query: "Infinite loop test", + }, + }, + }, + CallToolResultSchema + ); + + // Verify we got an error response (not a throw) + expect(result.content).toBeDefined(); + 
expect(result.content.length).toBeGreaterThan(0); + expect(result.content[0].type).toBe("text"); + + // Verify the error message mentions the iteration limit + if (result.content[0].type === "text") { + expect(result.content[0].text).toContain("Tool loop exceeded maximum iterations"); + } + + // Verify we hit the iteration limit (20 iterations as defined in toolLoopSampling.ts) + expect(samplingCallCount).toBe(20); + }); +}); diff --git a/src/examples/server/toolLoopSampling.ts b/src/examples/server/toolLoopSampling.ts new file mode 100644 index 000000000..1ebbfbf9f --- /dev/null +++ b/src/examples/server/toolLoopSampling.ts @@ -0,0 +1,484 @@ +/* + This example demonstrates a tool loop using MCP sampling with locally defined tools. + + It exposes a "localResearch" tool that uses an LLM with ripgrep and read capabilities + to intelligently search and read files in the current directory. + + Usage: + npx -y @modelcontextprotocol/inspector \ + npx -- -y --silent tsx src/examples/backfill/backfillSampling.ts \ + npx -y --silent tsx src/examples/server/toolLoopSampling.ts + + Then connect with an MCP client and call the "localResearch" tool with a query like: + "Find all TypeScript files that export a Server class" +*/ + +import { McpServer } from "../../server/mcp.js"; +import { StdioServerTransport } from "../../server/stdio.js"; +import { z } from "zod"; +import { spawn } from "node:child_process"; +import { readFile } from "node:fs/promises"; +import { resolve, relative } from "node:path"; +import type { + SamplingMessage, + Tool, + ToolCallContent, + CreateMessageResult, + CreateMessageRequest, + ToolResultContent, + CallToolResult, +} from "../../types.js"; + +const CWD = process.cwd(); + +/** + * Interface for tracking aggregated token usage across API calls. 
+ */ +interface AggregatedUsage { + input_tokens: number; + output_tokens: number; + cache_creation_input_tokens: number; + cache_read_input_tokens: number; + api_calls: number; +} + +/** + * Zod schemas for validating tool inputs + */ +const RipgrepInputSchema = z.object({ + pattern: z.string(), + path: z.string(), +}); + +const ReadInputSchema = z.object({ + path: z.string(), + startLineInclusive: z.number().int().positive().optional(), + endLineInclusive: z.number().int().positive().optional(), +}); + +/** + * Ensures a path is canonical and within the current working directory. + * Throws an error if the path attempts to escape CWD. + */ +function ensureSafePath(inputPath: string): string { + const resolved = resolve(CWD, inputPath); + const rel = relative(CWD, resolved); + + // Check if the path escapes CWD (starts with .. or is absolute outside CWD) + if (rel.startsWith("..") || resolve(CWD, rel) !== resolved) { + throw new Error(`Path "${inputPath}" is outside the current directory`); + } + + return resolved; +} + + +function makeErrorCallToolResult(error: any): CallToolResult { + return { + content: [ + { + type: "text", + text: error instanceof Error ? `${error.message}\n${error.stack}` : `${error}`, + }, + ], + isError: true, + } +} + +/** + * Executes ripgrep to search for a pattern in files. + * Returns search results as a string. 
+ */ +async function executeRipgrep( + server: McpServer, + pattern: string, + path: string +): Promise { + try { + await server.sendLoggingMessage({ + level: "info", + data: `Searching pattern "${pattern}" under ${path}`, + }); + + const safePath = ensureSafePath(path); + + const output = await new Promise((resolve, reject) => { + const command = ["rg", "--json", "--max-count", "50", "--", pattern, safePath]; + const rg = spawn(command[0], command.slice(1)); + + let stdout = ""; + let stderr = ""; + rg.stdout.on("data", (data) => stdout += data.toString()); + rg.stderr.on("data", (data) => stderr += data.toString()); + rg.on("close", (code) => { + if (code === 0 || code === 1) { + // code 1 means no matches, which is fine + resolve(stdout || "No matches found"); + } else { + reject(new Error(`ripgrep exited with code ${code}:\n${stderr}`)); + } + }); + rg.on("error", err => reject(new Error(`Failed to start \`${command.map(a => a.indexOf(' ') >= 0 ? `"${a}"` : a).join(' ')}\`: ${err.message}\n${stderr}`))); + }); + const structuredContent = { output }; + return { + content: [{ type: "text", text: JSON.stringify(structuredContent) }], + structuredContent, + }; + } catch (error) { + return makeErrorCallToolResult(error); + } +} + + +/** + * Reads a file from the filesystem, optionally within a line range. + * Returns file contents as a string. + */ +async function executeRead( + server: McpServer, + path: string, + startLineInclusive?: number, + endLineInclusive?: number +): Promise { + try { + // Log the read operation + if (startLineInclusive !== undefined || endLineInclusive !== undefined) { + await server.sendLoggingMessage({ + level: "info", + data: `Reading file ${path} (lines ${startLineInclusive ?? 1}-${endLineInclusive ?? 
"end"})`, + }); + } else { + await server.sendLoggingMessage({ + level: "info", + data: `Reading file ${path}`, + }); + } + + const safePath = ensureSafePath(path); + const fileContent = await readFile(safePath, "utf-8"); + if (typeof fileContent !== "string") { + throw new Error(`Result of reading file ${path} is not text: ${fileContent}`); + } + + let content = fileContent; + + // If line range specified, extract only those lines + if (startLineInclusive !== undefined || endLineInclusive !== undefined) { + const lines = fileContent.split("\n"); + + const start = (startLineInclusive ?? 1) - 1; // Convert to 0-indexed + const end = endLineInclusive ?? lines.length; // Default to end of file + + if (start < 0 || start >= lines.length) { + throw new Error(`Start line ${startLineInclusive} is out of bounds (file has ${lines.length} lines)`); + } + if (end < start) { + throw new Error(`End line ${endLineInclusive} is before start line ${startLineInclusive}`); + } + + content = lines.slice(start, end) + .map((line, idx) => `${start + idx + 1}: ${line}`) + .join("\n"); + } + + const structuredContent = { content } + return { + content: [{ type: "text", text: JSON.stringify(structuredContent) }], + structuredContent, + }; + } catch (error) { + return makeErrorCallToolResult(error); + } +} + +/** + * Defines the local tools available to the LLM during sampling. + */ +const LOCAL_TOOLS: Tool[] = [ + { + name: "ripgrep", + description: + "Search for a pattern in files using ripgrep. Returns matching lines with file paths and line numbers.", + inputSchema: { + type: "object", + properties: { + pattern: { + type: "string", + description: "The regex pattern to search for", + }, + path: { + type: "string", + description: "The file or directory path to search in (relative to current directory)", + }, + }, + required: ["pattern", "path"], + }, + }, + { + name: "read", + description: + "Read the contents of a file. Use this to examine files found by ripgrep. 
" + + "You can optionally specify a line range to read only specific lines. " + + "Tip: When ripgrep finds matches, note the line numbers and request a few lines before and after for context.", + inputSchema: { + type: "object", + properties: { + path: { + type: "string", + description: "The file path to read (relative to current directory)", + }, + startLineInclusive: { + type: "number", + description: "Optional: First line to read (1-indexed, inclusive). Use with endLineInclusive to read a specific range.", + }, + endLineInclusive: { + type: "number", + description: "Optional: Last line to read (1-indexed, inclusive). If not specified, reads to end of file.", + }, + }, + required: ["path"], + }, + }, +]; + +/** + * Executes a local tool and returns the result. + */ +async function executeLocalTool( + server: McpServer, + toolName: string, + toolInput: Record +): Promise { + try { + switch (toolName) { + case "ripgrep": { + const validated = RipgrepInputSchema.parse(toolInput); + return await executeRipgrep(server, validated.pattern, validated.path); + } + case "read": { + const validated = ReadInputSchema.parse(toolInput); + return await executeRead( + server, + validated.path, + validated.startLineInclusive, + validated.endLineInclusive + ); + } + default: + return makeErrorCallToolResult(`Unknown tool: ${toolName}`); + } + } catch (error) { + if (error instanceof z.ZodError) { + return makeErrorCallToolResult(`Invalid input for tool '${toolName}': ${error.errors.map(e => e.message).join(", ")}`); + } + return makeErrorCallToolResult(error); + } +} + +/** + * Runs a tool loop using sampling. + * Continues until the LLM provides a final answer. 
+ */ +async function runToolLoop( + server: McpServer, + initialQuery: string +): Promise<{ answer: string; transcript: SamplingMessage[]; usage: AggregatedUsage }> { + const messages: SamplingMessage[] = [ + { + role: "user", + content: { + type: "text", + text: initialQuery, + }, + }, + ]; + + // Initialize usage tracking + const aggregatedUsage: AggregatedUsage = { + input_tokens: 0, + output_tokens: 0, + cache_creation_input_tokens: 0, + cache_read_input_tokens: 0, + api_calls: 0, + }; + + const MAX_ITERATIONS = 20; + let iteration = 0; + + const systemPrompt = + "You are a helpful assistant that searches through files to answer questions. " + + "You have access to ripgrep (for searching) and read (for reading file contents). " + + "Use ripgrep to find relevant files, then read them to provide accurate answers. " + + "All paths are relative to the current working directory. " + + "Be concise and focus on providing the most relevant information." + + "You will be allowed up to " + MAX_ITERATIONS + " iterations of tool use to find the information needed. When you have enough information or reach the last iteration, provide a final answer."; + + let request: CreateMessageRequest["params"] | undefined + let response: CreateMessageResult | undefined + while (iteration < MAX_ITERATIONS) { + iteration++; + + // Request message from LLM with available tools + response = await server.server.createMessage(request = { + messages, + systemPrompt, + maxTokens: 4000, + tools: iteration < MAX_ITERATIONS ? LOCAL_TOOLS : undefined, + // Don't allow tool calls at the last iteration: finish with an answer no matter what! + toolChoice: { mode: iteration < MAX_ITERATIONS ? 
"auto" : "none" }, + }); + + // Aggregate usage statistics from the response + if (response._meta?.usage) { + const usage = response._meta.usage as any; + aggregatedUsage.input_tokens += usage.input_tokens || 0; + aggregatedUsage.output_tokens += usage.output_tokens || 0; + aggregatedUsage.cache_creation_input_tokens += usage.cache_creation_input_tokens || 0; + aggregatedUsage.cache_read_input_tokens += usage.cache_read_input_tokens || 0; + aggregatedUsage.api_calls += 1; + } + + // Add assistant's response to message history + // SamplingMessage now supports arrays of content + messages.push({ + role: "assistant", + content: response.content, + }); + + if (response.stopReason === "toolUse") { + const contentArray = Array.isArray(response.content) ? response.content : [response.content]; + const toolCalls = contentArray.filter( + (content): content is ToolCallContent => content.type === "tool_use" + ); + + await server.sendLoggingMessage({ + level: "info", + data: `Loop iteration ${iteration}: ${toolCalls.length} tool invocation(s) requested`, + }); + + const toolResults: ToolResultContent[] = await Promise.all(toolCalls.map(async (toolCall) => { + const result = await executeLocalTool(server, toolCall.name, toolCall.input); + return { + type: "tool_result", + toolUseId: toolCall.id, + content: result.content, + structuredContent: result.structuredContent, + isError: result.isError, + } + })) + + messages.push({ + role: "user", + content: iteration < MAX_ITERATIONS ? toolResults : [ + ...toolResults, + { + type: "text", + text: "Using the information retrieved from the tools, please now provide a concise final answer to the original question (last iteration of the tool loop).", + } + ], + }); + } else if (response.stopReason === "endTurn") { + const contentArray = Array.isArray(response.content) ? 
response.content : [response.content]; + const unexpectedBlocks = contentArray.filter(content => content.type !== "text"); + if (unexpectedBlocks.length > 0) { + throw new Error(`Expected text content in final answer, but got: ${unexpectedBlocks.map(b => b.type).join(", ")}`); + } + + await server.sendLoggingMessage({ + level: "info", + data: `Tool loop completed after ${iteration} iteration(s)`, + }); + + return { + answer: contentArray.map(block => block.text).join("\n\n"), + transcript: messages, + usage: aggregatedUsage + }; + } else if (response?.stopReason === "maxTokens") { + throw new Error("LLM response hit max tokens limit"); + } else { + throw new Error(`Unsupported stop reason: ${response.stopReason}`); + } + } + + throw new Error(`Tool loop exceeded maximum iterations (${MAX_ITERATIONS}); request: ${JSON.stringify(request)}\nresponse: ${JSON.stringify(response)}`); +} + +// Create and configure MCP server +const mcpServer = new McpServer({ + name: "tool-loop-sampling-server", + version: "1.0.0", +}); + +// Register the localResearch tool that uses sampling with a tool loop +mcpServer.registerTool( + "localResearch", + { + description: + "Search for information in files using an AI assistant with ripgrep and file reading capabilities. 
" + + "The assistant will intelligently use ripgrep to find relevant files and read them to answer your query.", + inputSchema: { + query: z + .string() + .default("describe main classes") + .describe( + "A natural language query describing what to search for (e.g., 'Find all TypeScript files that export a Server class')" + ), + maxIterations: z.number().int().positive().optional().default(20).describe("Maximum number of tool use iterations (default 20)"), + }, + }, + async ({ query, maxIterations }) => { + try { + const { answer, transcript, usage } = await runToolLoop(mcpServer, query); + + // Calculate total input tokens + const totalInputTokens = + usage.input_tokens + + usage.cache_creation_input_tokens + + usage.cache_read_input_tokens; + + // Format usage summary + const usageSummary = + `--- Token Usage Summary ---\n` + + `Total Input Tokens: ${totalInputTokens}\n` + + ` - Regular: ${usage.input_tokens}\n` + + ` - Cache Creation: ${usage.cache_creation_input_tokens}\n` + + ` - Cache Read: ${usage.cache_read_input_tokens}\n` + + `Total Output Tokens: ${usage.output_tokens}\n` + + `Total Tokens: ${totalInputTokens + usage.output_tokens}\n` + + `API Calls: ${usage.api_calls}`; + + return { + content: [ + { + type: "text", + text: answer, + }, + { + type: "text", + text: `\n\n${usageSummary}`, + }, + { + type: "text", + text: `\n\n--- Debug Transcript (${transcript.length} messages) ---\n${JSON.stringify(transcript, null, 2)}`, + }, + ], + }; + } catch (error) { + return makeErrorCallToolResult(error); + } + } +); + +async function main() { + const transport = new StdioServerTransport(); + await mcpServer.connect(transport); + console.error("MCP Tool Loop Sampling Server is running..."); + console.error(`Working directory: ${CWD}`); +} + +main().catch((error) => { + console.error("Server error:", error); + process.exit(1); +}); diff --git a/src/examples/server/toolWithSampleServer.ts b/src/examples/server/toolWithSampleServer.ts index 44e5cecbb..961e5f516 
100644 --- a/src/examples/server/toolWithSampleServer.ts +++ b/src/examples/server/toolWithSampleServer.ts @@ -34,11 +34,21 @@ mcpServer.registerTool( maxTokens: 500, }); + // Extract all text content blocks from the response + const parts: string[] = []; + for (const content of Array.isArray(response.content) ? response.content : [response.content]) { + if (content.type === "text") { + parts.push(content.text); + } else { + throw new Error(`Unexpected content type: ${content.type}`); + } + } + return { content: [ { type: "text", - text: response.content.type === "text" ? response.content.text : "Unable to generate summary", + text: parts.join('\n'), }, ], }; diff --git a/src/types.test.ts b/src/types.test.ts index 0aee62a93..2bc1be00d 100644 --- a/src/types.test.ts +++ b/src/types.test.ts @@ -5,7 +5,16 @@ import { ContentBlockSchema, PromptMessageSchema, CallToolResultSchema, - CompleteRequestSchema + CompleteRequestSchema, + ToolCallContentSchema, + ToolResultContentSchema, + ToolChoiceSchema, + UserMessageSchema, + AssistantMessageSchema, + SamplingMessageSchema, + CreateMessageRequestSchema, + CreateMessageResultSchema, + ClientCapabilitiesSchema, } from "./types.js"; describe("Types", () => { @@ -312,4 +321,459 @@ describe("Types", () => { } }); }); + + describe("SEP-1577: Sampling with Tools", () => { + describe("ToolCallContent", () => { + test("should validate a tool call content", () => { + const toolCall = { + type: "tool_use", + id: "call_123", + name: "get_weather", + input: { city: "San Francisco", units: "celsius" } + }; + + const result = ToolCallContentSchema.safeParse(toolCall); + expect(result.success).toBe(true); + if (result.success) { + expect(result.data.type).toBe("tool_use"); + expect(result.data.id).toBe("call_123"); + expect(result.data.name).toBe("get_weather"); + expect(result.data.input).toEqual({ city: "San Francisco", units: "celsius" }); + } + }); + + test("should validate tool call with _meta", () => { + const toolCall = { + type: 
"tool_use", + id: "call_456", + name: "search", + input: { query: "test" }, + _meta: { custom: "data" } + }; + + const result = ToolCallContentSchema.safeParse(toolCall); + expect(result.success).toBe(true); + if (result.success) { + expect(result.data._meta).toEqual({ custom: "data" }); + } + }); + + test("should fail validation for missing required fields", () => { + const invalidToolCall = { + type: "tool_use", + name: "test" + // missing id and input + }; + + const result = ToolCallContentSchema.safeParse(invalidToolCall); + expect(result.success).toBe(false); + }); + }); + + describe("ToolResultContent", () => { + test("should validate a tool result content", () => { + const toolResult = { + type: "tool_result", + toolUseId: "call_123", + content: [{ type: "text", text: "sunny, 72" }], + structuredContent: { temperature: 72, condition: "sunny" } + }; + + const result = ToolResultContentSchema.safeParse(toolResult); + expect(result.success).toBe(true); + if (result.success) { + expect(result.data.type).toBe("tool_result"); + expect(result.data.toolUseId).toBe("call_123"); + expect(result.data.content).toEqual([{ type: "text", text: "sunny, 72" }]); + expect(result.data.structuredContent).toEqual({ temperature: 72, condition: "sunny" }); + } + }); + + test("should validate tool result with error flag", () => { + const toolResult = { + type: "tool_result", + toolUseId: "call_456", + content: [{ type: "text", text: "API_ERROR: Service unavailable" }], + isError: true + }; + + const result = ToolResultContentSchema.safeParse(toolResult); + expect(result.success).toBe(true); + if (result.success) { + expect(result.data.content).toEqual([{ type: "text", text: "API_ERROR: Service unavailable" }]); + expect(result.data.isError).toBe(true); + } + }); + + test("should fail validation for missing required fields", () => { + const invalidToolResult = { + type: "tool_result", + content: [{ type: "text", text: "test" }] + // missing toolUseId + }; + + const result = ToolResultContentSchema.safeParse(invalidToolResult); + expect(result.success).toBe(false); + }); + }); + + describe("ToolChoice", () => { + test("should validate tool choice with mode auto", () => { + const toolChoice = { + mode: 
"auto" + }; + + const result = ToolChoiceSchema.safeParse(toolChoice); + expect(result.success).toBe(true); + if (result.success) { + expect(result.data.mode).toBe("auto"); + } + }); + + test("should validate tool choice with mode required", () => { + const toolChoice = { + mode: "required", + disable_parallel_tool_use: true + }; + + const result = ToolChoiceSchema.safeParse(toolChoice); + expect(result.success).toBe(true); + if (result.success) { + expect(result.data.mode).toBe("required"); + expect(result.data.disable_parallel_tool_use).toBe(true); + } + }); + + test("should validate empty tool choice", () => { + const toolChoice = {}; + + const result = ToolChoiceSchema.safeParse(toolChoice); + expect(result.success).toBe(true); + }); + + test("should fail validation for invalid mode", () => { + const invalidToolChoice = { + mode: "invalid" + }; + + const result = ToolChoiceSchema.safeParse(invalidToolChoice); + expect(result.success).toBe(false); + }); + }); + + describe("UserMessage and AssistantMessage", () => { + test("should validate user message with text", () => { + const userMessage = { + role: "user", + content: { type: "text", text: "What's the weather?" } + }; + + const result = UserMessageSchema.safeParse(userMessage); + expect(result.success).toBe(true); + if (result.success) { + expect(result.data.role).toBe("user"); + expect(result.data.content.type).toBe("text"); + } + }); + + test("should validate user message with tool result", () => { + const userMessage = { + role: "user", + content: { + type: "tool_result", + toolUseId: "call_123", + content: [{ type: "text", text: "72" }] + } + }; + + const result = UserMessageSchema.safeParse(userMessage); + expect(result.success).toBe(true); + if (result.success) { + expect(result.data.content.type).toBe("tool_result"); + } + }); + + test("should validate assistant message with text", () => { + const assistantMessage = { + role: "assistant", + content: { type: "text", text: "I'll check the weather for you." 
} + }; + + const result = AssistantMessageSchema.safeParse(assistantMessage); + expect(result.success).toBe(true); + if (result.success) { + expect(result.data.role).toBe("assistant"); + } + }); + + test("should validate assistant message with tool call", () => { + const assistantMessage = { + role: "assistant", + content: { + type: "tool_use", + id: "call_123", + name: "get_weather", + input: { city: "SF" } + } + }; + + const result = AssistantMessageSchema.safeParse(assistantMessage); + expect(result.success).toBe(true); + if (result.success) { + expect(result.data.content.type).toBe("tool_use"); + } + }); + + test("should fail validation for assistant with tool result", () => { + const invalidMessage = { + role: "assistant", + content: { + type: "tool_result", + toolUseId: "call_123", + content: {} + } + }; + + const result = AssistantMessageSchema.safeParse(invalidMessage); + expect(result.success).toBe(false); + }); + + test("should fail validation for user with tool call", () => { + const invalidMessage = { + role: "user", + content: { + type: "tool_use", + id: "call_123", + name: "test", + input: {} + } + }; + + const result = UserMessageSchema.safeParse(invalidMessage); + expect(result.success).toBe(false); + }); + }); + + describe("SamplingMessage", () => { + test("should validate user message via discriminated union", () => { + const message = { + role: "user", + content: { type: "text", text: "Hello" } + }; + + const result = SamplingMessageSchema.safeParse(message); + expect(result.success).toBe(true); + if (result.success) { + expect(result.data.role).toBe("user"); + } + }); + + test("should validate assistant message via discriminated union", () => { + const message = { + role: "assistant", + content: { type: "text", text: "Hi there!" 
} + }; + + const result = SamplingMessageSchema.safeParse(message); + expect(result.success).toBe(true); + if (result.success) { + expect(result.data.role).toBe("assistant"); + } + }); + }); + + describe("CreateMessageRequest", () => { + test("should validate request without tools", () => { + const request = { + method: "sampling/createMessage", + params: { + messages: [ + { role: "user", content: { type: "text", text: "Hello" } } + ], + maxTokens: 1000 + } + }; + + const result = CreateMessageRequestSchema.safeParse(request); + expect(result.success).toBe(true); + if (result.success) { + expect(result.data.params.tools).toBeUndefined(); + } + }); + + test("should validate request with tools", () => { + const request = { + method: "sampling/createMessage", + params: { + messages: [ + { role: "user", content: { type: "text", text: "What's the weather?" } } + ], + maxTokens: 1000, + tools: [ + { + name: "get_weather", + description: "Get weather for a location", + inputSchema: { + type: "object", + properties: { + location: { type: "string" } + }, + required: ["location"] + } + } + ], + tool_choice: { + mode: "auto" + } + } + }; + + const result = CreateMessageRequestSchema.safeParse(request); + expect(result.success).toBe(true); + if (result.success) { + expect(result.data.params.tools).toHaveLength(1); + expect(result.data.params.tool_choice?.mode).toBe("auto"); + } + }); + + test("should validate request with includeContext (soft-deprecated)", () => { + const request = { + method: "sampling/createMessage", + params: { + messages: [ + { role: "user", content: { type: "text", text: "Help" } } + ], + maxTokens: 1000, + includeContext: "thisServer" + } + }; + + const result = CreateMessageRequestSchema.safeParse(request); + expect(result.success).toBe(true); + if (result.success) { + expect(result.data.params.includeContext).toBe("thisServer"); + } + }); + }); + + describe("CreateMessageResult", () => { + test("should validate result with text content", () => { + 
const result = { + model: "claude-3-5-sonnet-20241022", + role: "assistant", + content: { type: "text", text: "Here's the answer." }, + stopReason: "endTurn" + }; + + const parseResult = CreateMessageResultSchema.safeParse(result); + expect(parseResult.success).toBe(true); + if (parseResult.success) { + expect(parseResult.data.role).toBe("assistant"); + expect(parseResult.data.stopReason).toBe("endTurn"); + } + }); + + test("should validate result with tool call", () => { + const result = { + model: "claude-3-5-sonnet-20241022", + role: "assistant", + content: { + type: "tool_use", + id: "call_123", + name: "get_weather", + input: { city: "SF" } + }, + stopReason: "toolUse" + }; + + const parseResult = CreateMessageResultSchema.safeParse(result); + expect(parseResult.success).toBe(true); + if (parseResult.success) { + expect(parseResult.data.stopReason).toBe("toolUse"); + expect(parseResult.data.content.type).toBe("tool_use"); + } + }); + + test("should validate all new stop reasons", () => { + const stopReasons = ["endTurn", "stopSequence", "maxTokens", "toolUse", "refusal", "other"]; + + stopReasons.forEach(stopReason => { + const result = { + model: "test", + role: "assistant", + content: { type: "text", text: "test" }, + stopReason + }; + + const parseResult = CreateMessageResultSchema.safeParse(result); + expect(parseResult.success).toBe(true); + }); + }); + + test("should allow custom stop reason string", () => { + const result = { + model: "test", + role: "assistant", + content: { type: "text", text: "test" }, + stopReason: "custom_provider_reason" + }; + + const parseResult = CreateMessageResultSchema.safeParse(result); + expect(parseResult.success).toBe(true); + }); + + test("should fail for user role in result", () => { + const result = { + model: "test", + role: "user", + content: { type: "text", text: "test" } + }; + + const parseResult = CreateMessageResultSchema.safeParse(result); + expect(parseResult.success).toBe(false); + }); + }); + + 
describe("ClientCapabilities with sampling", () => { + test("should validate capabilities with sampling.tools", () => { + const capabilities = { + sampling: { + tools: {} + } + }; + + const result = ClientCapabilitiesSchema.safeParse(capabilities); + expect(result.success).toBe(true); + if (result.success) { + expect(result.data.sampling?.tools).toBeDefined(); + } + }); + + test("should validate capabilities with sampling.context", () => { + const capabilities = { + sampling: { + context: {} + } + }; + + const result = ClientCapabilitiesSchema.safeParse(capabilities); + expect(result.success).toBe(true); + if (result.success) { + expect(result.data.sampling?.context).toBeDefined(); + } + }); + + test("should validate capabilities with both", () => { + const capabilities = { + sampling: { + context: {}, + tools: {} + } + }; + + const result = ClientCapabilitiesSchema.safeParse(capabilities); + expect(result.success).toBe(true); + if (result.success) { + expect(result.data.sampling?.context).toBeDefined(); + expect(result.data.sampling?.tools).toBeDefined(); + } + }); + }); + }); }); diff --git a/src/types.ts b/src/types.ts index ee2ceb5ed..6cd9cd36c 100644 --- a/src/types.ts +++ b/src/types.ts @@ -285,7 +286,22 @@ export const ClientCapabilitiesSchema = z /** * Present if the client supports sampling from an LLM. */ - sampling: z.optional(z.object({}).passthrough()), + sampling: z.optional( + z + .object({ + /** + * Present if the client supports non-'none' values for includeContext parameter. + * SOFT-DEPRECATED: New implementations should use tools parameter instead. 
+ */ + context: z.optional(z.object({}).passthrough()), + /** + * Present if the client supports tools and tool_choice parameters in sampling requests. + * Presence indicates full tool calling support. + */ + tools: z.optional(z.object({}).passthrough()), + }) + .passthrough(), + ), /** * Present if the client supports eliciting user input. */ @@ -821,6 +837,36 @@ export const AudioContentSchema = z }) .passthrough(); +/** + * A tool call request from an assistant (LLM). + * Represents the assistant's request to use a tool. + */ +export const ToolCallContentSchema = z + .object({ + type: z.literal("tool_use"), + /** + * The name of the tool to invoke. + * Must match a tool name from the request's tools array. + */ + name: z.string(), + /** + * Unique identifier for this tool call. + * Used to correlate with ToolResultContent in subsequent messages. + */ + id: z.string(), + /** + * Arguments to pass to the tool. + * Must conform to the tool's inputSchema. + */ + input: z.object({}).passthrough(), + /** + * See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields) + * for notes on _meta usage. + */ + _meta: z.optional(z.object({}).passthrough()), + }) + .passthrough(); + /** * The contents of a resource, embedded into a prompt or tool call result. */ @@ -1147,15 +1193,91 @@ export const ModelPreferencesSchema = z .passthrough(); /** - * Describes a message issued to or received from an LLM API. + * Controls tool usage behavior in sampling requests. 
*/ -export const SamplingMessageSchema = z +export const ToolChoiceSchema = z .object({ - role: z.enum(["user", "assistant"]), - content: z.union([TextContentSchema, ImageContentSchema, AudioContentSchema]), + /** + * Controls when tools are used: + * - "auto": Model decides whether to use tools (default) + * - "required": Model MUST use at least one tool before completing + */ + mode: z.optional(z.enum(["auto", "required", "none"])), + /** + * If true, model should not use multiple tools in parallel. + * Some models may ignore this hint. + * Default: false + */ + disable_parallel_tool_use: z.optional(z.boolean()), + }) + .passthrough(); + +/** + * The result of a tool execution, provided by the user (server). + * Represents the outcome of invoking a tool requested via ToolCallContent. + */ +export const ToolResultContentSchema = z.object({ + type: z.literal("tool_result"), + toolUseId: z.string().describe("The unique identifier for the corresponding tool call."), + content: z.array(z.union([TextContentSchema, ImageContentSchema, AudioContentSchema])), + structuredContent: z.object({}).passthrough().optional(), + isError: z.optional(z.boolean()), +}) + +export const UserMessageContentSchema = z.discriminatedUnion("type", [ + TextContentSchema, + ImageContentSchema, + AudioContentSchema, + ToolResultContentSchema, +]); + +/** + * A message from the user (server) in a sampling conversation. + */ +export const UserMessageSchema = z + .object({ + role: z.literal("user"), + content: z.union([UserMessageContentSchema, z.array(UserMessageContentSchema)]), + /** + * See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields) + * for notes on _meta usage. 
+ */ + _meta: z.optional(z.object({}).passthrough()), + }) + .passthrough(); + +export const AssistantMessageContentSchema = z.discriminatedUnion("type", [ + TextContentSchema, + ImageContentSchema, + AudioContentSchema, + ToolCallContentSchema, +]); + +/** + * A message from the assistant (LLM) in a sampling conversation. + */ +export const AssistantMessageSchema = z + .object({ + role: z.literal("assistant"), + content: z.union([AssistantMessageContentSchema, z.array(AssistantMessageContentSchema)]), + /** + * See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields) + * for notes on _meta usage. + */ + _meta: z.optional(z.object({}).passthrough()), }) .passthrough(); +/** + * Describes a message issued to or received from an LLM API. + * This is a discriminated union of UserMessage and AssistantMessage, where + * each role has its own set of allowed content types. + */ +export const SamplingMessageSchema = z.discriminatedUnion("role", [ + UserMessageSchema, + AssistantMessageSchema, +]); + /** * A request from the server to sample an LLM via the client. The client has full discretion over which model to select. The client should also inform the user before beginning sampling, to allow them to inspect the request (human in the loop) and decide whether to approve it. */ @@ -1168,7 +1290,9 @@ export const CreateMessageRequestSchema = RequestSchema.extend({ */ systemPrompt: z.optional(z.string()), /** + * SOFT-DEPRECATED: Use tools parameter instead. * A request to include context from one or more MCP servers (including the caller), to be attached to the prompt. The client MAY ignore this request. + * Requires clientCapabilities.sampling.context. 
*/ includeContext: z.optional(z.enum(["none", "thisServer", "allServers"])), temperature: z.optional(z.number()), @@ -1185,6 +1309,16 @@ export const CreateMessageRequestSchema = RequestSchema.extend({ * The server's preferences for which model to select. */ modelPreferences: z.optional(ModelPreferencesSchema), + /** + * Tool definitions for the LLM to use. + * Requires clientCapabilities.sampling.tools. + */ + tools: z.optional(z.array(ToolSchema)), + /** + * Controls tool usage behavior. + * Requires clientCapabilities.sampling.tools and tools parameter. + */ + toolChoice: z.optional(ToolChoiceSchema), }), }); @@ -1198,16 +1332,24 @@ export const CreateMessageResultSchema = ResultSchema.extend({ model: z.string(), /** * The reason why sampling stopped. + * - "endTurn": Model completed naturally + * - "stopSequence": Hit a stop sequence + * - "maxTokens": Reached token limit + * - "toolUse": Model wants to use a tool + * - "refusal": Model refused the request + * - "other": Other provider-specific reason */ stopReason: z.optional( - z.enum(["endTurn", "stopSequence", "maxTokens"]).or(z.string()), + z.enum(["endTurn", "stopSequence", "maxTokens", "toolUse"]).or(z.string()), ), - role: z.enum(["user", "assistant"]), - content: z.discriminatedUnion("type", [ - TextContentSchema, - ImageContentSchema, - AudioContentSchema - ]), + /** + * The role is always "assistant" in responses from the LLM. + */ + role: z.literal("assistant"), + /** + * Response content. May be ToolCallContent if stopReason is "toolUse". 
+ */ + content: z.union([AssistantMessageContentSchema, z.array(AssistantMessageContentSchema)]), }); /* Elicitation */ @@ -1630,6 +1772,8 @@ export type GetPromptRequest = Infer; export type TextContent = Infer; export type ImageContent = Infer; export type AudioContent = Infer; +export type ToolCallContent = Infer; +export type ToolResultContent = Infer; export type EmbeddedResource = Infer; export type ResourceLink = Infer; export type ContentBlock = Infer; @@ -1653,9 +1797,14 @@ export type SetLevelRequest = Infer; export type LoggingMessageNotification = Infer; /* Sampling */ +export type ToolChoice = Infer; +export type UserMessage = Infer; +export type AssistantMessage = Infer; export type SamplingMessage = Infer; export type CreateMessageRequest = Infer; export type CreateMessageResult = Infer; +export type AssistantMessageContent = Infer; +export type UserMessageContent = Infer; /* Elicitation */ export type BooleanSchema = Infer;