diff --git a/package-lock.json b/package-lock.json index c94ab5c9c..75e963d45 100644 --- a/package-lock.json +++ b/package-lock.json @@ -23,6 +23,7 @@ "zod-to-json-schema": "^3.24.1" }, "devDependencies": { + "@anthropic-ai/sdk": "^0.65.0", "@eslint/js": "^9.8.0", "@jest-mock/express": "^3.0.0", "@types/content-type": "^1.1.8", @@ -61,6 +62,27 @@ "node": ">=6.0.0" } }, + "node_modules/@anthropic-ai/sdk": { + "version": "0.65.0", + "resolved": "https://registry.npmjs.org/@anthropic-ai/sdk/-/sdk-0.65.0.tgz", + "integrity": "sha512-zIdPOcrCVEI8t3Di40nH4z9EoeyGZfXbYSvWdDLsB/KkaSYMnEgC7gmcgWu83g2NTn1ZTpbMvpdttWDGGIk6zw==", + "dev": true, + "license": "MIT", + "dependencies": { + "json-schema-to-ts": "^3.1.1" + }, + "bin": { + "anthropic-ai-sdk": "bin/cli" + }, + "peerDependencies": { + "zod": "^3.25.0 || ^4.0.0" + }, + "peerDependenciesMeta": { + "zod": { + "optional": true + } + } + }, "node_modules/@babel/code-frame": { "version": "7.26.2", "resolved": "https://registry.npmjs.org/@babel/code-frame/-/code-frame-7.26.2.tgz", diff --git a/package.json b/package.json index b5b9b8ec9..b1eaa902d 100644 --- a/package.json +++ b/package.json @@ -76,6 +76,7 @@ "zod-to-json-schema": "^3.24.1" }, "devDependencies": { + "@anthropic-ai/sdk": "^0.65.0", "@eslint/js": "^9.8.0", "@jest-mock/express": "^3.0.0", "@types/content-type": "^1.1.8", diff --git a/src/examples/backfill/backfillSampling.ts b/src/examples/backfill/backfillSampling.ts new file mode 100644 index 000000000..c377021cd --- /dev/null +++ b/src/examples/backfill/backfillSampling.ts @@ -0,0 +1,240 @@ +/* + This example implements an stdio MCP proxy that backfills sampling requests using the Claude API. + + Usage: + npx -y @modelcontextprotocol/inspector \ + npx -y --silent tsx src/examples/backfill/backfillSampling.ts -- \ + npx -y --silent @modelcontextprotocol/server-everything +*/ + +import { Anthropic } from "@anthropic-ai/sdk"; +import { Base64ImageSource, ContentBlock, ContentBlockParam, MessageParam } from "@anthropic-ai/sdk/resources/messages.js"; +import { StdioServerTransport } from '../../server/stdio.js'; +import { StdioClientTransport } from '../../client/stdio.js'; +import { + CancelledNotification, + CancelledNotificationSchema, + isInitializeRequest, + isJSONRPCRequest, + ElicitRequest, + ElicitRequestSchema, + CreateMessageRequest, + CreateMessageRequestSchema, + CreateMessageResult, + JSONRPCResponse, + isInitializedNotification, + CallToolRequest, + CallToolRequestSchema, + isJSONRPCNotification, +} from "../../types.js"; +import { Transport } from "../../shared/transport.js"; + +const DEFAULT_MAX_TOKENS = process.env.DEFAULT_MAX_TOKENS ? parseInt(process.env.DEFAULT_MAX_TOKENS) : 1000; + +// TODO: move to SDK + +const isCancelledNotification: (value: unknown) => value is CancelledNotification = + ((value: any) => CancelledNotificationSchema.safeParse(value).success) as any; + +const isCallToolRequest: (value: unknown) => value is CallToolRequest = + ((value: any) => CallToolRequestSchema.safeParse(value).success) as any; + +const isElicitRequest: (value: unknown) => value is ElicitRequest = + ((value: any) => ElicitRequestSchema.safeParse(value).success) as any; + +const isCreateMessageRequest: (value: unknown) => value is CreateMessageRequest = + ((value: any) => CreateMessageRequestSchema.safeParse(value).success) as any; + + +function contentToMcp(content: ContentBlock): CreateMessageResult['content'][number] { + switch (content.type) { + case 'text': + return {type: 'text', text: content.text}; + default: + throw new Error(`Unsupported content type: ${content.type}`); + } +} + +function contentFromMcp(content: CreateMessageRequest['params']['messages'][number]['content']): ContentBlockParam { + switch (content.type) { + case 'text': + return {type: 'text', text: content.text}; + case 'image': + return { + type: 'image', + source: { + data: content.data, + media_type: content.mimeType as Base64ImageSource['media_type'], + type: 'base64', + }, + }; + case 'audio': + default: + throw new Error(`Unsupported content type: ${content.type}`); + } +} + +export type NamedTransport = { + name: 'client' | 'server', + transport: T, +} + +export async function setupBackfill(client: NamedTransport, server: NamedTransport, api: Anthropic) { + const backfillMeta = await (async () => { + const models = new Set(); + let defaultModel: string | undefined; + for await (const info of api.models.list()) { + models.add(info.id); + if (info.id.indexOf('sonnet') >= 0 && defaultModel === undefined) { + defaultModel = info.id; + } + } + if (defaultModel === undefined) { + if (models.size === 0) { + throw new Error("No models available from the API"); + } + defaultModel = models.values().next().value; + } + return { + sampling_models: Array.from(models), + sampling_default_model: defaultModel, + }; + })(); + + function pickModel(preferences: CreateMessageRequest['params']['modelPreferences'] | undefined): string { + if (preferences?.hints) { + for (const hint of Object.values(preferences.hints)) { + if (hint.name !== undefined && backfillMeta.sampling_models.includes(hint.name)) { + return hint.name; + } + } + } + // TODO: linear model on preferences?.{intelligencePriority, speedPriority, costPriority} to pick betwen haiku, sonnet, opus. + return backfillMeta.sampling_default_model!; + } + + let clientSupportsSampling: boolean | undefined; + // let clientSupportsElicitation: boolean | undefined; + + const propagateMessage = (source: NamedTransport, target: NamedTransport) => { + source.transport.onmessage = async (message, extra) => { + console.error(`[proxy]: Message from ${source.name} transport: ${JSON.stringify(message)}; extra: ${JSON.stringify(extra)}`); + + if (isJSONRPCRequest(message)) { + if (isInitializeRequest(message)) { + if (!(clientSupportsSampling = !!message.params.capabilities.sampling)) { + message.params.capabilities.sampling = {} + message.params._meta = {...(message.params._meta ?? {}), ...backfillMeta}; + } + } else if (isCreateMessageRequest(message)) { + if ((message.params.includeContext ?? 'none') !== 'none') { + const errorMessage = "includeContext != none not supported by MCP sampling backfill" + console.error(`[proxy]: ${errorMessage}`); + source.transport.send({ + jsonrpc: "2.0", + id: message.id, + error: { + code: -32601, // Method not found + message: errorMessage, + }, + }, {relatedRequestId: message.id}); + return; + } + + try { + // message.params. + const msg = await api.messages.create({ + model: pickModel(message.params.modelPreferences), + system: message.params.systemPrompt === undefined ? undefined : [ + { + type: "text", + text: message.params.systemPrompt + }, + ], + messages: message.params.messages.map(({role, content}) => ({ + role, + content: [contentFromMcp(content)] + })), + max_tokens: message.params.maxTokens ?? DEFAULT_MAX_TOKENS, + temperature: message.params.temperature, + stop_sequences: message.params.stopSequences, + ...(message.params.metadata ?? {}), + }); + + if (msg.content.length !== 1) { + throw new Error(`Expected exactly one content item in the response, got ${msg.content.length}`); + } + + source.transport.send({ + jsonrpc: "2.0", + id: message.id, + result: { + model: msg.model, + stopReason: msg.stop_reason, + role: msg.role, + content: contentToMcp(msg.content[0]), + }, + }); + } catch (error) { + console.error(`[proxy]: Error processing message: ${(error as Error).message}`); + + source.transport.send({ + jsonrpc: "2.0", + id: message.id, + error: { + code: -32601, // Method not found + message: `Error processing message: ${(error as Error).message}`, + }, + }, {relatedRequestId: message.id}); + } + return; + // } else if (isElicitRequest(message) && !clientSupportsElicitation) { + // // TODO: form + // return; + } + } else if (isJSONRPCNotification(message)) { + if (isInitializedNotification(message) && source.name === 'server') { + if (!clientSupportsSampling) { + message.params = {...(message.params ?? {}), _meta: {...(message.params?._meta ?? {}), ...backfillMeta}}; + } + } + } + + try { + const relatedRequestId = isCancelledNotification(message)? message.params.requestId : undefined; + await target.transport.send(message, {relatedRequestId}); + } catch (error) { + console.error(`[proxy]: Error sending message to ${target.name}:`, error); + } + }; + }; + propagateMessage(server, client); + propagateMessage(client, server); + + const addErrorHandler = (transport: NamedTransport) => { + transport.transport.onerror = async (error: Error) => { + console.error(`[proxy]: Error from ${transport.name} transport:`, error); + }; + }; + + addErrorHandler(client); + addErrorHandler(server); + + await server.transport.start(); + await client.transport.start(); +} + +async function main() { + const args = process.argv.slice(2); + const client: NamedTransport = {name: 'client', transport: new StdioClientTransport({command: args[0], args: args.slice(1)})}; + const server: NamedTransport = {name: 'server', transport: new StdioServerTransport()}; + + const api = new Anthropic(); + await setupBackfill(client, server, api); + console.error("[proxy]: Transports started."); +} + +main().catch((error) => { + console.error("[proxy]: Fatal error:", error); + process.exit(1); +});