diff --git a/src/server/index.test.ts b/src/server/index.test.ts index 86eaf6d9e..97b02a55c 100644 --- a/src/server/index.test.ts +++ b/src/server/index.test.ts @@ -1582,6 +1582,341 @@ test('should respect log level for transport without sessionId', async () => { expect(clientTransport.onmessage).toHaveBeenCalled(); }); +describe('createMessage validation', () => { + test('should throw when tools are provided without sampling.tools capability', async () => { + const server = new Server({ name: 'test server', version: '1.0' }, { capabilities: {} }); + + const client = new Client( + { name: 'test client', version: '1.0' }, + { capabilities: { sampling: {} } } // No tools capability + ); + + client.setRequestHandler(CreateMessageRequestSchema, async () => ({ + model: 'test-model', + role: 'assistant', + content: { type: 'text', text: 'Response' } + })); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + await expect( + server.createMessage({ + messages: [{ role: 'user', content: { type: 'text', text: 'hello' } }], + maxTokens: 100, + tools: [{ name: 'test_tool', inputSchema: { type: 'object' } }] + }) + ).rejects.toThrow('Client does not support sampling tools capability.'); + }); + + test('should throw when toolChoice is provided without sampling.tools capability', async () => { + const server = new Server({ name: 'test server', version: '1.0' }, { capabilities: {} }); + + const client = new Client( + { name: 'test client', version: '1.0' }, + { capabilities: { sampling: {} } } // No tools capability + ); + + client.setRequestHandler(CreateMessageRequestSchema, async () => ({ + model: 'test-model', + role: 'assistant', + content: { type: 'text', text: 'Response' } + })); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + await expect( + server.createMessage({ + messages: [{ role: 'user', content: { type: 'text', text: 'hello' } }], + maxTokens: 100, + toolChoice: { mode: 'auto' } + }) + ).rejects.toThrow('Client does not support sampling tools capability.'); + }); + + test('should throw when tool_result is mixed with other content', async () => { + const server = new Server({ name: 'test server', version: '1.0' }, { capabilities: {} }); + + const client = new Client({ name: 'test client', version: '1.0' }, { capabilities: { sampling: { tools: {} } } }); + + client.setRequestHandler(CreateMessageRequestSchema, async () => ({ + model: 'test-model', + role: 'assistant', + content: { type: 'text', text: 'Response' } + })); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + await expect( + server.createMessage({ + messages: [ + { role: 'user', content: { type: 'text', text: 'hello' } }, + { role: 'assistant', content: { type: 'tool_use', id: 'call_1', name: 'test_tool', input: {} } }, + { + role: 'user', + content: [ + { type: 'tool_result', toolUseId: 'call_1', content: [] }, + { type: 'text', text: 'mixed content' } // Mixed! + ] + } + ], + maxTokens: 100, + tools: [{ name: 'test_tool', inputSchema: { type: 'object' } }] + }) + ).rejects.toThrow('The last message must contain only tool_result content if any is present'); + }); + + test('should throw when tool_result has no matching tool_use in previous message', async () => { + const server = new Server({ name: 'test server', version: '1.0' }, { capabilities: {} }); + + const client = new Client({ name: 'test client', version: '1.0' }, { capabilities: { sampling: { tools: {} } } }); + + client.setRequestHandler(CreateMessageRequestSchema, async () => ({ + model: 'test-model', + role: 'assistant', + content: { type: 'text', text: 'Response' } + })); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + // tool_result without previous tool_use + await expect( + server.createMessage({ + messages: [ + { role: 'user', content: { type: 'text', text: 'hello' } }, + { role: 'user', content: { type: 'tool_result', toolUseId: 'call_1', content: [] } } + ], + maxTokens: 100, + tools: [{ name: 'test_tool', inputSchema: { type: 'object' } }] + }) + ).rejects.toThrow('tool_result blocks are not matching any tool_use from the previous message'); + }); + + test('should throw when tool_result IDs do not match tool_use IDs', async () => { + const server = new Server({ name: 'test server', version: '1.0' }, { capabilities: {} }); + + const client = new Client({ name: 'test client', version: '1.0' }, { capabilities: { sampling: { tools: {} } } }); + + client.setRequestHandler(CreateMessageRequestSchema, async () => ({ + model: 'test-model', + role: 'assistant', + content: { type: 'text', text: 'Response' } + })); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + await expect( + server.createMessage({ + messages: [ + { role: 'user', content: { type: 'text', text: 'hello' } }, + { role: 'assistant', content: { type: 'tool_use', id: 'call_1', name: 'test_tool', input: {} } }, + { role: 'user', content: { type: 'tool_result', toolUseId: 'wrong_id', content: [] } } + ], + maxTokens: 100, + tools: [{ name: 'test_tool', inputSchema: { type: 'object' } }] + }) + ).rejects.toThrow('ids of tool_result blocks and tool_use blocks from previous message do not match'); + }); + + test('should allow text-only messages with tools (no tool_results)', async () => { + const server = new Server({ name: 'test server', version: '1.0' }, { capabilities: {} }); + + const client = new Client({ name: 'test client', version: '1.0' }, { capabilities: { sampling: { tools: {} } } }); + + client.setRequestHandler(CreateMessageRequestSchema, async () => ({ + model: 'test-model', + role: 'assistant', + content: { type: 'text', text: 'Response' } + })); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + await expect( + server.createMessage({ + messages: [{ role: 'user', content: { type: 'text', text: 'hello' } }], + maxTokens: 100, + tools: [{ name: 'test_tool', inputSchema: { type: 'object' } }] + }) + ).resolves.toMatchObject({ model: 'test-model' }); + }); + + test('should allow valid matching tool_result/tool_use IDs', async () => { + const server = new Server({ name: 'test server', version: '1.0' }, { capabilities: {} }); + + const client = new Client({ name: 'test client', version: '1.0' }, { capabilities: { sampling: { tools: {} } } }); + + client.setRequestHandler(CreateMessageRequestSchema, async () => ({ + model: 'test-model', + role: 'assistant', + content: { type: 'text', text: 'Response' } + })); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + await expect( + server.createMessage({ + messages: [ + { role: 'user', content: { type: 'text', text: 'hello' } }, + { role: 'assistant', content: { type: 'tool_use', id: 'call_1', name: 'test_tool', input: {} } }, + { role: 'user', content: { type: 'tool_result', toolUseId: 'call_1', content: [] } } + ], + maxTokens: 100, + tools: [{ name: 'test_tool', inputSchema: { type: 'object' } }] + }) + ).resolves.toMatchObject({ model: 'test-model' }); + }); + + test('should throw when user sends text instead of tool_result after tool_use', async () => { + const server = new Server({ name: 'test server', version: '1.0' }, { capabilities: {} }); + + const client = new Client({ name: 'test client', version: '1.0' }, { capabilities: { sampling: { tools: {} } } }); + + client.setRequestHandler(CreateMessageRequestSchema, async () => ({ + model: 'test-model', + role: 'assistant', + content: { type: 'text', text: 'Response' } + })); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + // User ignores tool_use and sends text instead + await expect( + server.createMessage({ + messages: [ + { role: 'user', content: { type: 'text', text: 'hello' } }, + { role: 'assistant', content: { type: 'tool_use', id: 'call_1', name: 'test_tool', input: {} } }, + { role: 'user', content: { type: 'text', text: 'actually nevermind' } } + ], + maxTokens: 100, + tools: [{ name: 'test_tool', inputSchema: { type: 'object' } }] + }) + ).rejects.toThrow('ids of tool_result blocks and tool_use blocks from previous message do not match'); + }); + + test('should throw when only some tool_results are provided for parallel tool_use', async () => { + const server = new Server({ name: 'test server', version: '1.0' }, { capabilities: {} }); + + const client = new Client({ name: 'test client', version: '1.0' }, { capabilities: { sampling: { tools: {} } } }); + + client.setRequestHandler(CreateMessageRequestSchema, async () => ({ + model: 'test-model', + role: 'assistant', + content: { type: 'text', text: 'Response' } + })); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + // Parallel tool_use but only one tool_result provided + await expect( + server.createMessage({ + messages: [ + { role: 'user', content: { type: 'text', text: 'hello' } }, + { + role: 'assistant', + content: [ + { type: 'tool_use', id: 'call_1', name: 'tool_a', input: {} }, + { type: 'tool_use', id: 'call_2', name: 'tool_b', input: {} } + ] + }, + { role: 'user', content: { type: 'tool_result', toolUseId: 'call_1', content: [] } } + ], + maxTokens: 100, + tools: [ + { name: 'tool_a', inputSchema: { type: 'object' } }, + { name: 'tool_b', inputSchema: { type: 'object' } } + ] + }) + ).rejects.toThrow('ids of tool_result blocks and tool_use blocks from previous message do not match'); + }); + + test('should validate tool_use/tool_result even without tools in current request', async () => { + const server = new Server({ name: 'test server', version: '1.0' }, { capabilities: {} }); + + const client = new Client({ name: 'test client', version: '1.0' }, { capabilities: { sampling: { tools: {} } } }); + + client.setRequestHandler(CreateMessageRequestSchema, async () => ({ + model: 'test-model', + role: 'assistant', + content: { type: 'text', text: 'Response' } + })); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + // Previous request returned tool_use, now sending tool_result without tools param + await expect( + server.createMessage({ + messages: [ + { role: 'user', content: { type: 'text', text: 'hello' } }, + { role: 'assistant', content: { type: 'tool_use', id: 'call_1', name: 'test_tool', input: {} } }, + { role: 'user', content: { type: 'tool_result', toolUseId: 'wrong_id', content: [] } } + ], + maxTokens: 100 + // Note: no tools param - this is a follow-up request after tool execution + }) + ).rejects.toThrow('ids of tool_result blocks and tool_use blocks from previous message do not match'); + }); + + test('should allow valid tool_use/tool_result without tools in current request', async () => { + const server = new Server({ name: 'test server', version: '1.0' }, { capabilities: {} }); + + const client = new Client({ name: 'test client', version: '1.0' }, { capabilities: { sampling: { tools: {} } } }); + + client.setRequestHandler(CreateMessageRequestSchema, async () => ({ + model: 'test-model', + role: 'assistant', + content: { type: 'text', text: 'Response' } + })); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + // Previous request returned tool_use, now sending matching tool_result without tools param + await expect( + server.createMessage({ + messages: [ + { role: 'user', content: { type: 'text', text: 'hello' } }, + { role: 'assistant', content: { type: 'tool_use', id: 'call_1', name: 'test_tool', input: {} } }, + { role: 'user', content: { type: 'tool_result', toolUseId: 'call_1', content: [] } } + ], + maxTokens: 100 + // Note: no tools param - this is a follow-up request after tool execution + }) + ).resolves.toMatchObject({ model: 'test-model' }); + }); + + test('should handle empty messages array', async () => { + const server = new Server({ name: 'test server', version: '1.0' }, { capabilities: {} }); + + const client = new Client({ name: 'test client', version: '1.0' }, { capabilities: { sampling: {} } }); + + client.setRequestHandler(CreateMessageRequestSchema, async () => ({ + model: 'test-model', + role: 'assistant', + content: { type: 'text', text: 'Response' } + })); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + // Empty messages array should not crash + await expect( + server.createMessage({ + messages: [], + maxTokens: 100 + }) + ).resolves.toMatchObject({ model: 'test-model' }); + }); +}); + test('should respect log level for transport with sessionId', async () => { const server = new Server( { diff --git a/src/server/index.ts b/src/server/index.ts index 8ec838e51..cd8ec6d28 100644 --- a/src/server/index.ts +++ b/src/server/index.ts @@ -30,7 +30,9 @@ import { type ServerRequest, type ServerResult, SetLevelRequestSchema, - SUPPORTED_PROTOCOL_VERSIONS + SUPPORTED_PROTOCOL_VERSIONS, + type ToolResultContent, + type ToolUseContent } from '../types.js'; import { AjvJsonSchemaValidator } from '../validation/ajv-provider.js'; import type { JsonSchemaType, jsonSchemaValidator } from '../validation/types.js'; @@ -326,6 +328,48 @@ export class Server< } async createMessage(params: CreateMessageRequest['params'], options?: RequestOptions) { + // Capability check - only required when tools/toolChoice are provided + if (params.tools || params.toolChoice) { + if (!this._clientCapabilities?.sampling?.tools) { + throw new Error('Client does not support sampling tools capability.'); + } + } + + // Message structure validation - always validate tool_use/tool_result pairs. + // These may appear even without tools/toolChoice in the current request when + // a previous sampling request returned tool_use and this is a follow-up with results. + if (params.messages.length > 0) { + const lastMessage = params.messages[params.messages.length - 1]; + const lastContent = Array.isArray(lastMessage.content) ? lastMessage.content : [lastMessage.content]; + const hasToolResults = lastContent.some(c => c.type === 'tool_result'); + + const previousMessage = params.messages.length > 1 ? params.messages[params.messages.length - 2] : undefined; + const previousContent = previousMessage + ? Array.isArray(previousMessage.content) + ? previousMessage.content + : [previousMessage.content] + : []; + const hasPreviousToolUse = previousContent.some(c => c.type === 'tool_use'); + + if (hasToolResults) { + if (lastContent.some(c => c.type !== 'tool_result')) { + throw new Error('The last message must contain only tool_result content if any is present'); + } + if (!hasPreviousToolUse) { + throw new Error('tool_result blocks are not matching any tool_use from the previous message'); + } + } + if (hasPreviousToolUse) { + const toolUseIds = new Set(previousContent.filter(c => c.type === 'tool_use').map(c => (c as ToolUseContent).id)); + const toolResultIds = new Set( + lastContent.filter(c => c.type === 'tool_result').map(c => (c as ToolResultContent).toolUseId) + ); + if (toolUseIds.size !== toolResultIds.size || ![...toolUseIds].every(id => toolResultIds.has(id))) { + throw new Error('ids of tool_result blocks and tool_use blocks from previous message do not match'); + } + } + } + return this.request({ method: 'sampling/createMessage', params }, CreateMessageResultSchema, options); }