/**
 * Asks the user on the terminal whether a pending hosted MCP tool call may run.
 *
 * The prompt shows the tool name (`item.rawItem.name`) and the call arguments
 * taken from `item.rawItem.providerData.arguments`.
 *
 * @param item - The approval item surfaced by the run as an interruption.
 * @returns `true` only when the trimmed, lower-cased answer is exactly `y`.
 */
async function promptApproval(item: RunToolApprovalItem): Promise<boolean> {
  const name = item.rawItem.name;
  const rawArguments = item.rawItem.providerData?.arguments || '{}';

  // The arguments string is produced by the model / remote server. If it is
  // not valid JSON, show it verbatim instead of aborting the approval prompt.
  let shownParams: string;
  try {
    shownParams = JSON.stringify(JSON.parse(rawArguments));
  } catch {
    shownParams = String(rawArguments);
  }

  const rl = readline.createInterface({ input: stdin, output: stdout });
  try {
    const answer = await rl.question(
      `Approve running tool (mcp: ${name}, params: ${shownParams})? (y/n) `,
    );
    return answer.toLowerCase().trim() === 'y';
  } finally {
    // Always release stdin, even when question() rejects (e.g. input closed),
    // so the process is not kept alive by a dangling readline interface.
    rl.close();
  }
}
/**
 * Asks the user on the terminal whether a pending hosted MCP tool call may run.
 *
 * The prompt shows the tool name (`item.rawItem.name`) and the call arguments
 * taken from `item.rawItem.providerData.arguments`.
 *
 * @param item - The approval item handed to the tool's `onApproval` callback.
 * @returns `true` only when the trimmed, lower-cased answer is exactly `y`.
 */
async function promptApproval(item: RunToolApprovalItem): Promise<boolean> {
  const name = item.rawItem.name;
  const rawArguments = item.rawItem.providerData?.arguments || '{}';

  // The arguments string is produced by the model / remote server. If it is
  // not valid JSON, show it verbatim instead of aborting the approval prompt.
  let shownParams: string;
  try {
    shownParams = JSON.stringify(JSON.parse(rawArguments));
  } catch {
    shownParams = String(rawArguments);
  }

  const rl = readline.createInterface({ input: stdin, output: stdout });
  try {
    const answer = await rl.question(
      `Approve running tool (mcp: ${name}, params: ${shownParams})? (y/n) `,
    );
    return answer.toLowerCase().trim() === 'y';
  } finally {
    // Always release stdin, even when question() rejects (e.g. input closed),
    // so the process is not kept alive by a dangling readline interface.
    rl.close();
  }
}
/**
 * Runs the hosted MCP example inside a single trace named
 * "Hosted MCP Example".
 *
 * @param verbose - In non-streaming mode, also prints every item of
 *   `res.output` as pretty-printed JSON before the final output.
 * @param stream - When true, runs with `{ stream: true }` and logs the raw
 *   model stream events (except argument/text deltas), then the new items.
 * @returns A promise that settles once the traced run has finished.
 */
async function main(verbose: boolean, stream: boolean): Promise<void> {
  // Await the traced callback: without this, main() resolves immediately and
  // any rejection inside the callback never reaches the caller's `.catch`.
  await withTrace('Hosted MCP Example', async () => {
    const agent = new Agent({
      name: 'MCP Assistant',
      instructions: 'You must always use the MCP tools to answer questions.',
      tools: [
        hostedMcpTool({
          serverLabel: 'gitmcp',
          serverUrl: 'https://gitmcp.io/openai/codex',
          requireApproval: 'never',
        }),
      ],
    });

    const input =
      'Which language is the repo I pointed in the MCP tool settings written in?';

    if (stream) {
      const result = await run(agent, input, { stream: true });
      for await (const event of result) {
        // Skip the high-volume delta events; everything else is logged.
        if (
          event.type === 'raw_model_stream_event' &&
          event.data.type === 'model' &&
          event.data.event.type !== 'response.mcp_call_arguments.delta' &&
          event.data.event.type !== 'response.output_text.delta'
        ) {
          console.log(`Got event of type ${JSON.stringify(event.data)}`);
        }
      }
      for (const item of result.newItems) {
        console.log(JSON.stringify(item, null, 2));
      }
      console.log(`Done streaming; final result: ${result.finalOutput}`);
    } else {
      const res = await run(agent, input);
      // The repository is primarily written in multiple languages, including Rust and TypeScript...
      if (verbose) {
        for (const item of res.output) {
          console.log(JSON.stringify(item, null, 2));
        }
      }
      console.log(res.finalOutput);
    }
  });
}
protocol.FunctionCallItem, + public rawItem: protocol.FunctionCallItem | protocol.HostedToolCallItem, public agent: Agent, ) { super(); diff --git a/packages/agents-core/src/runContext.ts b/packages/agents-core/src/runContext.ts index cd579231..8eb03202 100644 --- a/packages/agents-core/src/runContext.ts +++ b/packages/agents-core/src/runContext.ts @@ -117,7 +117,12 @@ export class RunContext { rejected: [], }; if (Array.isArray(approvalEntry.approved)) { - approvalEntry.approved.push(approvalItem.rawItem.callId); + // function tool has call_id, hosted tool call has id + const callId = + 'callId' in approvalItem.rawItem + ? approvalItem.rawItem.callId // function tools + : approvalItem.rawItem.id!; // hosted tools + approvalEntry.approved.push(callId); } this.#approvals.set(toolName, approvalEntry); } @@ -146,7 +151,12 @@ export class RunContext { }; if (Array.isArray(approvalEntry.rejected)) { - approvalEntry.rejected.push(approvalItem.rawItem.callId); + // function tool has call_id, hosted tool call has id + const callId = + 'callId' in approvalItem.rawItem + ? 
approvalItem.rawItem.callId // function tools + : approvalItem.rawItem.id!; // hosted tools + approvalEntry.rejected.push(callId); } this.#approvals.set(toolName, approvalEntry); } diff --git a/packages/agents-core/src/runImplementation.ts b/packages/agents-core/src/runImplementation.ts index e79887c2..6036a022 100644 --- a/packages/agents-core/src/runImplementation.ts +++ b/packages/agents-core/src/runImplementation.ts @@ -14,7 +14,13 @@ import { } from './items'; import logger, { Logger } from './logger'; import { ModelResponse, ModelSettings } from './model'; -import { ComputerTool, FunctionTool, Tool, FunctionToolResult } from './tool'; +import { + ComputerTool, + FunctionTool, + Tool, + FunctionToolResult, + HostedMCPTool, +} from './tool'; import { AgentInputItem, UnknownContext } from './types'; import { Runner } from './run'; import { RunContext } from './runContext'; @@ -31,6 +37,7 @@ import * as protocol from './types/protocol'; import { Computer } from './computer'; import { RunState } from './runState'; import { isZodObject } from './utils'; +import * as ProviderData from './types/providerData'; type ToolRunHandoff = { toolCall: protocol.FunctionCallItem; @@ -47,11 +54,17 @@ type ToolRunComputer = { computer: ComputerTool; }; +type ToolRunMCPApprovalRequest = { + requestItem: RunToolApprovalItem; + mcpTool: HostedMCPTool; +}; + export type ProcessedResponse = { newItems: RunItem[]; handoffs: ToolRunHandoff[]; functions: ToolRunFunction[]; computerActions: ToolRunComputer[]; + mcpApprovalRequests: ToolRunMCPApprovalRequest[]; toolsUsed: string[]; hasToolsOrApprovalsToRun(): boolean; }; @@ -69,12 +82,19 @@ export function processModelResponse( const runHandoffs: ToolRunHandoff[] = []; const runFunctions: ToolRunFunction[] = []; const runComputerActions: ToolRunComputer[] = []; + const runMCPApprovalRequests: ToolRunMCPApprovalRequest[] = []; const toolsUsed: string[] = []; const handoffMap = new Map(handoffs.map((h) => [h.toolName, h])); const functionMap 
= new Map( tools.filter((t) => t.type === 'function').map((t) => [t.name, t]), ); const computerTool = tools.find((t) => t.type === 'computer'); + const mcpToolMap = new Map( + tools + .filter((t) => t.type === 'hosted_tool' && t.providerData?.type === 'mcp') + .map((t) => t as HostedMCPTool) + .map((t) => [t.providerData.server_label, t]), + ); for (const output of modelResponse.output) { if (output.type === 'message') { @@ -83,7 +103,51 @@ export function processModelResponse( } } else if (output.type === 'hosted_tool_call') { items.push(new RunToolCallItem(output, agent)); - toolsUsed.push(output.name); + const toolName = output.name; + toolsUsed.push(toolName); + + if ( + output.providerData?.type === 'mcp_approval_request' || + output.name === 'mcp_approval_request' + ) { + // Hosted remote MCP server's approval process + const providerData = + output.providerData as ProviderData.HostedMCPApprovalRequest; + + const mcpServerLabel = providerData.server_label; + const mcpServerTool = mcpToolMap.get(mcpServerLabel); + if (typeof mcpServerTool === 'undefined') { + const message = `MCP server (${mcpServerLabel}) not found in Agent (${agent.name})`; + addErrorToCurrentSpan({ + message, + data: { mcp_server_label: mcpServerLabel }, + }); + throw new ModelBehaviorError(message); + } + + // Do this approval later: + // We support both onApproval callback (like the Python SDK does) and HITL patterns. + const approvalItem = new RunToolApprovalItem( + { + type: 'hosted_tool_call', + // We must use this name to align with the name sent from the servers + name: providerData.name, + id: providerData.id, + status: 'in_progress', + providerData, + }, + agent, + ); + runMCPApprovalRequests.push({ + requestItem: approvalItem, + mcpTool: mcpServerTool, + }); + if (!mcpServerTool.providerData.on_approval) { + // When onApproval function exists, it confirms the approval right after this. + // Thus, this approval item must be appended only for the next turn interrpution patterns. 
+ items.push(approvalItem); + } + } } else if (output.type === 'reasoning') { items.push(new RunReasoningItem(output, agent)); } else if (output.type === 'computer_call') { @@ -147,11 +211,13 @@ export function processModelResponse( handoffs: runHandoffs, functions: runFunctions, computerActions: runComputerActions, + mcpApprovalRequests: runMCPApprovalRequests, toolsUsed: toolsUsed, hasToolsOrApprovalsToRun(): boolean { return ( runHandoffs.length > 0 || runFunctions.length > 0 || + runMCPApprovalRequests.length > 0 || runComputerActions.length > 0 ); }, @@ -236,20 +302,18 @@ export async function executeInterruptedToolsAndSideEffects( runner: Runner, state: RunState>, ): Promise { - const preStepItems = originalPreStepItems.filter((item) => { - return !(item instanceof RunToolApprovalItem); - }); - - const approvalRequests = originalPreStepItems - .filter((item) => { - return item instanceof RunToolApprovalItem; - }) - .map((item) => { - return item.rawItem.callId; - }); - + // call_ids for function tools + const functionCallIds = originalPreStepItems + .filter( + (item) => + item instanceof RunToolApprovalItem && + 'callId' in item.rawItem && + item.rawItem.type === 'function_call', + ) + .map((item) => (item.rawItem as protocol.FunctionCallItem).callId); + // Run function tools that require approval after they get their approval results const functionToolRuns = processedResponse.functions.filter((run) => { - return approvalRequests.includes(run.toolCall.callId); + return functionCallIds.includes(run.toolCall.callId); }); const functionResults = await executeFunctionToolCalls( @@ -259,7 +323,46 @@ export async function executeInterruptedToolsAndSideEffects( state, ); - const newItems = functionResults.map((r) => r.runItem); + // Create the initial set of the output items + const newItems: RunItem[] = functionResults.map((r) => r.runItem); + + // Run MCP tools that require approval after they get their approval results + const mcpApprovalRuns = 
processedResponse.mcpApprovalRequests.filter( + (run) => { + return ( + run.requestItem.type === 'tool_approval_item' && + run.requestItem.rawItem.type === 'hosted_tool_call' && + run.requestItem.rawItem.providerData?.type === 'mcp_approval_request' + ); + }, + ); + for (const run of mcpApprovalRuns) { + // the approval_request_id "mcpr_123..." + const approvalRequestId = run.requestItem.rawItem.id!; + const approved = state._context.isToolApproved({ + // Since this item name must be the same with the one sent from Responses API server + toolName: run.requestItem.rawItem.name, + callId: approvalRequestId, + }); + if (typeof approved !== 'undefined') { + const providerData: ProviderData.HostedMCPApprovalResponse = { + approve: approved, + approval_request_id: approvalRequestId, + reason: undefined, + }; + // Tell Responses API server the approval result in the next turn + newItems.push( + new RunToolCallItem( + { + type: 'hosted_tool_call', + name: 'mcp_approval_response', + providerData, + }, + agent as Agent, + ), + ); + } + } const checkToolOutput = await checkForFinalOutputFromTools( agent, @@ -267,6 +370,12 @@ export async function executeInterruptedToolsAndSideEffects( state, ); + // Exclude the tool approval items, which should not be sent to Responses API, + // from the SingleStepResult's preStepItems + const preStepItems = originalPreStepItems.filter((item) => { + return !(item instanceof RunToolApprovalItem); + }); + if (checkToolOutput.isFinalOutput) { runner.emit( 'agent_end', @@ -344,6 +453,58 @@ export async function executeToolsAndSideEffects( newItems = newItems.concat(functionResults.map((r) => r.runItem)); newItems = newItems.concat(computerResults); + // run hosted MCP approval requests + if (processedResponse.mcpApprovalRequests.length > 0) { + for (const approvalRequest of processedResponse.mcpApprovalRequests) { + const toolData = approvalRequest.mcpTool + .providerData as ProviderData.HostedMCPTool; + const requestData = 
approvalRequest.requestItem.rawItem + .providerData as ProviderData.HostedMCPApprovalRequest; + if (toolData.on_approval) { + // synchronously handle the approval process here + const approvalResult = await toolData.on_approval( + state._context, + approvalRequest.requestItem, + ); + const approvalResponseData: ProviderData.HostedMCPApprovalResponse = { + approve: approvalResult.approve, + approval_request_id: requestData.id, + reason: approvalResult.reason, + }; + newItems.push( + new RunToolCallItem( + { + type: 'hosted_tool_call', + name: 'mcp_approval_response', + providerData: approvalResponseData, + }, + agent as Agent, + ), + ); + } else { + // receive a user's approval on the next turn + newItems.push(approvalRequest.requestItem); + const approvalItem = { + type: 'hosted_mcp_tool_approval' as const, + tool: approvalRequest.mcpTool, + runItem: new RunToolApprovalItem( + { + type: 'hosted_tool_call', + name: requestData.name, + id: requestData.id, + arguments: requestData.arguments, + status: 'in_progress', + providerData: requestData, + }, + agent, + ), + }; + functionResults.push(approvalItem); + // newItems.push(approvalItem.runItem); + } + } + } + // process handoffs if (processedResponse.handoffs.length > 0) { return await executeHandoffCalls( @@ -521,6 +682,7 @@ export async function executeFunctionToolCalls( }); if (approval === false) { + // rejected return withFunctionSpan( async (span) => { const response = 'Tool execution was not approved.'; @@ -554,6 +716,7 @@ export async function executeFunctionToolCalls( } if (approval !== true) { + // this approval process needs to be done in the next turn return { type: 'function_approval' as const, tool: toolRun.tool, diff --git a/packages/agents-core/src/runState.ts b/packages/agents-core/src/runState.ts index 6acb049a..980601e3 100644 --- a/packages/agents-core/src/runState.ts +++ b/packages/agents-core/src/runState.ts @@ -31,6 +31,7 @@ import * as protocol from './types/protocol'; import { AgentInputItem, 
UnknownContext } from './types'; import type { InputGuardrailResult, OutputGuardrailResult } from './guardrail'; import { safeExecute } from './utils/safeExecute'; +import { HostedMCPTool } from './tool'; /** * The schema version of the serialized run state. This is used to ensure that the serialized @@ -118,7 +119,7 @@ const itemSchema = z.discriminatedUnion('type', [ }), z.object({ type: z.literal('tool_approval_item'), - rawItem: protocol.FunctionCallItem, + rawItem: protocol.FunctionCallItem.or(protocol.HostedToolCallItem), agent: serializedAgentSchema, }), ]); @@ -152,6 +153,28 @@ const serializedProcessedResponseSchema = z.object({ computer: z.any(), }), ), + mcpApprovalRequests: z + .array( + z.object({ + requestItem: z.object({ + // protocol.HostedToolCallItem + rawItem: z.object({ + type: z.literal('hosted_tool_call'), + name: z.string(), + arguments: z.string().optional(), + status: z.string().optional(), + output: z.string().optional(), + }), + }), + // HostedMCPTool + mcpTool: z.object({ + type: z.literal('hosted_tool'), + name: z.literal('hosted_mcp'), + providerData: z.record(z.string(), z.any()), + }), + }), + ) + .optional(), }); const guardrailFunctionOutputSchema = z.object({ @@ -734,6 +757,16 @@ async function deserializeProcessedResponse( }; }, ), + mcpApprovalRequests: ( + serializedProcessedResponse.mcpApprovalRequests ?? 
[] + ).map((approvalRequest) => ({ + requestItem: new RunToolApprovalItem( + approvalRequest.requestItem + .rawItem as unknown as protocol.HostedToolCallItem, + currentAgent, + ), + mcpTool: approvalRequest.mcpTool as unknown as HostedMCPTool, + })), }; return { @@ -742,6 +775,7 @@ async function deserializeProcessedResponse( return ( result.handoffs.length > 0 || result.functions.length > 0 || + result.mcpApprovalRequests.length > 0 || result.computerActions.length > 0 ); }, diff --git a/packages/agents-core/src/tool.ts b/packages/agents-core/src/tool.ts index 8e7f0015..d815e5cb 100644 --- a/packages/agents-core/src/tool.ts +++ b/packages/agents-core/src/tool.ts @@ -16,6 +16,7 @@ import logger from './logger'; import { getCurrentSpan } from './tracing'; import { RunToolApprovalItem, RunToolCallOutputItem } from './items'; import { toSmartString } from './utils/smartString'; +import * as ProviderData from './types/providerData'; /** * A function that determines if a tool call should be approved. @@ -110,6 +111,70 @@ export function computerTool( }; } +export type HostedMCPApprovalFunction = ( + context: RunContext, + data: RunToolApprovalItem, +) => Promise<{ approve: boolean; reason?: string }>; + +/** + * A hosted MCP tool that lets the model call a remote MCP server directly + * without a round trip back to your code. + */ +export type HostedMCPTool = HostedTool & { + name: 'hosted_mcp'; + providerData: ProviderData.HostedMCPTool; +}; + +/** + * Creates a hosted MCP tool definition. + * + * @param serverLabel - The label identifying the MCP server. + * @param serverUrl - The URL of the MCP server. + * @param requireApproval - Whether tool calls require approval. 
/**
 * Creates a hosted MCP tool definition.
 *
 * @param options - Tool settings:
 *   - `serverLabel`: the label identifying the MCP server.
 *   - `serverUrl`: the URL of the MCP server.
 *   - `requireApproval`: `'never'`, `'always'`, or a per-tool filter
 *     (`{ never?: { toolNames }, always?: { toolNames } }`). When omitted it
 *     is treated as `'never'`.
 *   - `onApproval`: optional callback, accepted only together with an
 *     approval-requiring `requireApproval`; stored as `on_approval`.
 * @returns A `hosted_tool` named `hosted_mcp` whose `providerData` carries the
 *   snake_case MCP configuration.
 */
export function hostedMcpTool(
  options: {
    serverLabel: string;
    serverUrl: string;
  } & (
    | { requireApproval?: never }
    | { requireApproval: 'never' }
    | {
        requireApproval:
          | 'always'
          | {
              never?: { toolNames: string[] };
              always?: { toolNames: string[] };
            };
        onApproval?: HostedMCPApprovalFunction;
      }
  ),
): HostedMCPTool {
  const requireApproval = options.requireApproval;

  let providerData: ProviderData.HostedMCPTool;
  if (requireApproval === undefined || requireApproval === 'never') {
    // An omitted setting means "never ask", instead of falling through to
    // buildRequireApproval(undefined), which would throw a TypeError.
    providerData = {
      type: 'mcp',
      server_label: options.serverLabel,
      server_url: options.serverUrl,
      require_approval: 'never',
    };
  } else {
    providerData = {
      type: 'mcp',
      server_label: options.serverLabel,
      server_url: options.serverUrl,
      require_approval:
        typeof requireApproval === 'string'
          ? requireApproval
          : buildRequireApproval(requireApproval),
      // Only the approval-requiring option variant declares `onApproval`.
      on_approval: 'onApproval' in options ? options.onApproval : undefined,
    };
  }

  return {
    type: 'hosted_tool',
    name: 'hosted_mcp',
    providerData,
  };
}
/**
 * Converts a camelCase approval filter (`toolNames`) into the snake_case
 * shape (`tool_names`) stored in the tool's provider data.
 *
 * Only the `always` / `never` entries present on the input are emitted. The
 * tool name lists are copied so that later mutation of the caller's arrays
 * does not silently change the stored approval policy.
 *
 * @param requireApproval - Filter with optional `never` and `always` lists.
 * @returns The equivalent filter using `tool_names` keys.
 */
function buildRequireApproval(requireApproval: {
  never?: { toolNames: string[] };
  always?: { toolNames: string[] };
}): { never?: { tool_names: string[] }; always?: { tool_names: string[] } } {
  // Copy real arrays; anything else (possible only from untyped callers) is
  // passed through unchanged, exactly as before.
  const copyNames = (names: string[]): string[] =>
    Array.isArray(names) ? [...names] : names;

  const result: {
    never?: { tool_names: string[] };
    always?: { tool_names: string[] };
  } = {};
  if (requireApproval.always) {
    result.always = { tool_names: copyNames(requireApproval.always.toolNames) };
  }
  if (requireApproval.never) {
    result.never = { tool_names: copyNames(requireApproval.never.toolNames) };
  }
  return result;
}
*/ diff --git a/packages/agents-core/src/types/providerData.ts b/packages/agents-core/src/types/providerData.ts new file mode 100644 index 00000000..04d97c92 --- /dev/null +++ b/packages/agents-core/src/types/providerData.ts @@ -0,0 +1,57 @@ +import { HostedMCPApprovalFunction } from '../tool'; +import { UnknownContext } from './aliases'; + +/** + * OpenAI providerData type definition + */ +export type HostedMCPTool = { + type: 'mcp'; + server_label: string; + server_url: string; +} & ( + | { require_approval?: 'never'; on_approval?: never } + | { + require_approval: + | 'always' + | { + never?: { tool_names: string[] }; + always?: { tool_names: string[] }; + }; + on_approval?: HostedMCPApprovalFunction; + } +); + +export type HostedMCPListTools = { + id: string; + server_label: string; + tools: { + input_schema: unknown; + name: string; + annotations?: unknown | null; + description?: string | null; + }[]; + error?: string | null; +}; +export type HostedMCPCall = { + id: string; + arguments: string; + name: string; + server_label: string; + error?: string | null; + // excluding this large data field + // output?: string | null; +}; + +export type HostedMCPApprovalRequest = { + id: string; + name: string; + arguments: string; + server_label: string; +}; + +export type HostedMCPApprovalResponse = { + id?: string; + approve: boolean; + approval_request_id: string; + reason?: string; +}; diff --git a/packages/agents-core/test/run.test.ts b/packages/agents-core/test/run.test.ts index 1ef76f19..7aef1368 100644 --- a/packages/agents-core/test/run.test.ts +++ b/packages/agents-core/test/run.test.ts @@ -88,7 +88,7 @@ describe('Runner.run', () => { const rawItem = { name: 'toolZ', - call_id: 'c1', + callId: 'c1', type: 'function_call', arguments: '{}', } as any; @@ -126,6 +126,7 @@ describe('Runner.run', () => { }, ], handoffs: [], + mcpApprovalRequests: [], computerActions: [], } as any; diff --git a/packages/agents-core/test/runState.test.ts 
b/packages/agents-core/test/runState.test.ts index 9984d100..341d9caf 100644 --- a/packages/agents-core/test/runState.test.ts +++ b/packages/agents-core/test/runState.test.ts @@ -252,6 +252,7 @@ describe('deserialize helpers', () => { functions: [], handoffs: [], computerActions: [{ toolCall: call, computer: tool }], + mcpApprovalRequests: [], toolsUsed: [], hasToolsOrApprovalsToRun: () => true, }; @@ -277,6 +278,7 @@ describe('deserialize helpers', () => { functions: [], handoffs: [], computerActions: [{ toolCall: call, computer: tool }], + mcpApprovalRequests: [], toolsUsed: [], hasToolsOrApprovalsToRun: () => true, }; diff --git a/packages/agents-core/test/tool.test.ts b/packages/agents-core/test/tool.test.ts index 4d951fc4..27da8855 100644 --- a/packages/agents-core/test/tool.test.ts +++ b/packages/agents-core/test/tool.test.ts @@ -1,5 +1,5 @@ import { describe, it, expect } from 'vitest'; -import { computerTool, tool } from '../src/tool'; +import { computerTool, hostedMcpTool, tool } from '../src/tool'; import { z } from 'zod/v3'; import { Computer } from '../src'; import { RunContext } from '../src/runContext'; @@ -35,6 +35,21 @@ describe('Tool', () => { }); }); +describe('create a tool using hostedMcpTool utility', () => { + it('hostedMcpTool', () => { + const t = hostedMcpTool({ + serverLabel: 'gitmcp', + serverUrl: 'https://gitmcp.io/openai/codex', + requireApproval: 'never', + }); + expect(t).toBeDefined(); + expect(t.type).toBe('hosted_tool'); + expect(t.name).toBe('hosted_mcp'); + expect(t.providerData.type).toBe('mcp'); + expect(t.providerData.server_label).toBe('gitmcp'); + }); +}); + describe('tool.invoke', () => { it('parses input and returns result', async () => { const t = tool({ diff --git a/packages/agents-openai/src/openaiResponsesModel.ts b/packages/agents-openai/src/openaiResponsesModel.ts index c37a8756..c544ad10 100644 --- a/packages/agents-openai/src/openaiResponsesModel.ts +++ b/packages/agents-openai/src/openaiResponsesModel.ts @@ -33,6 
+33,8 @@ import { ImageGenerationStatus, WebSearchStatus, } from './tools'; +import { camelOrSnakeToSnakeCase } from './utils/providerData'; +import { ProviderData } from '@openai/agents-core/types'; type ToolChoice = ToolChoiceOptions | ToolChoiceTypes | ToolChoiceFunction; @@ -40,6 +42,9 @@ const HostedToolChoice = z.enum([ 'file_search', 'web_search_preview', 'computer_use_preview', + 'code_interpreter', + 'image_generation', + 'mcp', ]); const DefaultToolChoice = z.enum(['auto', 'required', 'none']); @@ -133,8 +138,8 @@ function converTool<_TContext = unknown>( return { tool: { type: 'web_search_preview', - user_location: tool.providerData.userLocation, - search_context_size: tool.providerData.searchContextSize, + user_location: tool.providerData.user_location, + search_context_size: tool.providerData.search_context_size, }, include: undefined, }; @@ -142,12 +147,17 @@ function converTool<_TContext = unknown>( return { tool: { type: 'file_search', - vector_store_ids: tool.providerData.vectorStoreId, - max_num_results: tool.providerData.maxNumResults, - ranking_options: tool.providerData.rankingOptions, + vector_store_ids: + tool.providerData.vector_store_ids || + // for backwards compatibility + (typeof tool.providerData.vector_store_id === 'string' + ? [tool.providerData.vector_store_id] + : tool.providerData.vector_store_id), + max_num_results: tool.providerData.max_num_results, + ranking_options: tool.providerData.ranking_options, filters: tool.providerData.filters, }, - include: tool.providerData.includeSearchResults + include: tool.providerData.include_search_results ? 
/**
 * Maps the `require_approval` value stored in a hosted MCP tool's provider
 * data to the value sent in the Responses API `mcp` tool definition.
 *
 * - `undefined` and `'never'` become `'never'`.
 * - `'always'` stays `'always'`.
 * - A filter object is forwarded with only the sides (`never` / `always`)
 *   that were actually configured, so no empty `{}` filter is sent for a
 *   side the user never specified.
 *
 * @param requireApproval - The provider-data approval setting.
 * @returns The approval setting for the request payload.
 */
function convertMCPRequireApproval(
  requireApproval: ProviderData.HostedMCPTool['require_approval'],
): OpenAI.Responses.Tool.Mcp.McpToolApprovalFilter | 'always' | 'never' | null {
  if (requireApproval === 'never' || requireApproval === undefined) {
    return 'never';
  }

  if (requireApproval === 'always') {
    return 'always';
  }

  const filter: OpenAI.Responses.Tool.Mcp.McpToolApprovalFilter = {};
  if (requireApproval.never) {
    filter.never = { tool_names: requireApproval.never.tool_names };
  }
  if (requireApproval.always) {
    filter.always = { tool_names: requireApproval.always.tool_names };
  }
  return filter;
}
entry.text, - ...entry.providerData, + ...camelOrSnakeToSnakeCase(entry.providerData), }; } else if (entry.type === 'input_image') { const imageEntry: OpenAI.Responses.ResponseInputImage = { @@ -217,7 +256,7 @@ function getInputMessageContent( } return { ...imageEntry, - ...entry.providerData, + ...camelOrSnakeToSnakeCase(entry.providerData), }; } else if (entry.type === 'input_file') { const fileEntry: OpenAI.Responses.ResponseInputFile = { @@ -230,7 +269,7 @@ function getInputMessageContent( } return { ...fileEntry, - ...entry.providerData, + ...camelOrSnakeToSnakeCase(entry.providerData), }; } @@ -247,7 +286,7 @@ function getOutputMessageContent( type: 'output_text', text: entry.text, annotations: [], - ...entry.providerData, + ...camelOrSnakeToSnakeCase(entry.providerData), }; } @@ -255,7 +294,7 @@ function getOutputMessageContent( return { type: 'refusal', refusal: entry.refusal, - ...entry.providerData, + ...camelOrSnakeToSnakeCase(entry.providerData), }; } @@ -275,7 +314,7 @@ function getMessageItem( id: item.id, role: 'system', content: item.content, - ...item.providerData, + ...camelOrSnakeToSnakeCase(item.providerData), }; } @@ -285,7 +324,7 @@ function getMessageItem( id: item.id, role: 'user', content: item.content, - ...item.providerData, + ...camelOrSnakeToSnakeCase(item.providerData), }; } @@ -293,7 +332,7 @@ function getMessageItem( id: item.id, role: 'user', content: item.content.map(getInputMessageContent), - ...item.providerData, + ...camelOrSnakeToSnakeCase(item.providerData), }; } @@ -304,7 +343,7 @@ function getMessageItem( role: 'assistant', content: item.content.map(getOutputMessageContent), status: item.status, - ...item.providerData, + ...camelOrSnakeToSnakeCase(item.providerData), }; return assistantMessage; } @@ -349,7 +388,7 @@ function getInputItems( call_id: item.callId, arguments: item.arguments, status: item.status, - ...item.providerData, + ...camelOrSnakeToSnakeCase(item.providerData), }; return entry; @@ -367,7 +406,8 @@ function 
getInputItems( id: item.id, call_id: item.callId, output: item.output.text, - ...item.providerData, + status: item.status, + ...camelOrSnakeToSnakeCase(item.providerData), }; return entry; @@ -380,9 +420,10 @@ function getInputItems( summary: item.content.map((content) => ({ type: 'summary_text', text: content.text, - ...content.providerData, + ...camelOrSnakeToSnakeCase(content.providerData), })), - ...item.providerData, + encrypted_content: item.providerData?.encryptedContent, + ...camelOrSnakeToSnakeCase(item.providerData), }; return entry; } @@ -395,7 +436,7 @@ function getInputItems( action: item.action, status: item.status, pending_safety_checks: [], - ...item.providerData, + ...camelOrSnakeToSnakeCase(item.providerData), }; return entry; @@ -407,7 +448,9 @@ function getInputItems( id: item.id, call_id: item.callId, output: buildResponseOutput(item), - ...item.providerData, + status: item.providerData?.status, + acknowledged_safety_checks: item.providerData?.acknowledgedSafetyChecks, + ...camelOrSnakeToSnakeCase(item.providerData), }; return entry; } @@ -418,7 +461,7 @@ function getInputItems( type: 'web_search_call', id: item.id!, status: WebSearchStatus.parse(item.status ?? 'failed'), - ...item.providerData, + ...camelOrSnakeToSnakeCase(item.providerData), }; return entry; @@ -430,7 +473,8 @@ function getInputItems( id: item.id!, status: FileSearchStatus.parse(item.status ?? 'failed'), queries: item.providerData?.queries ?? [], - ...item.providerData, + results: item.providerData?.results, + ...camelOrSnakeToSnakeCase(item.providerData), }; return entry; @@ -443,7 +487,8 @@ function getInputItems( code: item.providerData?.code ?? '', results: item.providerData?.results ?? [], status: CodeInterpreterStatus.parse(item.status ?? 
'failed'), - ...item.providerData, + container_id: item.providerData?.containerId, + ...camelOrSnakeToSnakeCase(item.providerData), }; return entry; @@ -455,12 +500,76 @@ function getInputItems( id: item.id!, result: item.providerData?.result ?? null, status: ImageGenerationStatus.parse(item.status ?? 'failed'), - ...item.providerData, + ...camelOrSnakeToSnakeCase(item.providerData), }; return entry; } + if ( + item.providerData?.type === 'mcp_list_tools' || + item.name === 'mcp_list_tools' + ) { + const providerData = + item.providerData as ProviderData.HostedMCPListTools; + const entry: OpenAI.Responses.ResponseInputItem.McpListTools = { + type: 'mcp_list_tools', + id: item.id!, + tools: camelOrSnakeToSnakeCase(providerData.tools) as any, + server_label: providerData.server_label, + error: providerData.error, + ...camelOrSnakeToSnakeCase(item.providerData), + }; + return entry; + } else if ( + item.providerData?.type === 'mcp_approval_request' || + item.name === 'mcp_approval_request' + ) { + const providerData = + item.providerData as ProviderData.HostedMCPApprovalRequest; + const entry: OpenAI.Responses.ResponseInputItem.McpApprovalRequest = { + type: 'mcp_approval_request', + id: providerData.id ?? 
item.id!, + name: providerData.name, + arguments: providerData.arguments, + server_label: providerData.server_label, + ...camelOrSnakeToSnakeCase(item.providerData), + }; + return entry; + } else if ( + item.providerData?.type === 'mcp_approval_response' || + item.name === 'mcp_approval_response' + ) { + const providerData = + item.providerData as ProviderData.HostedMCPApprovalResponse; + const entry: OpenAI.Responses.ResponseInputItem.McpApprovalResponse = { + type: 'mcp_approval_response', + id: providerData.id, + approve: providerData.approve, + approval_request_id: providerData.approval_request_id, + reason: providerData.reason, + ...camelOrSnakeToSnakeCase(providerData), + }; + return entry; + } else if ( + item.providerData?.type === 'mcp_call' || + item.name === 'mcp_call' + ) { + const providerData = item.providerData as ProviderData.HostedMCPCall; + const entry: OpenAI.Responses.ResponseInputItem.McpCall = { + type: 'mcp_call', + id: providerData.id ?? item.id!, + name: providerData.name, + arguments: providerData.arguments, + server_label: providerData.server_label, + error: providerData.error, + // output, which can be a large text string, is optional here, so we don't include it + // output: item.output, + ...camelOrSnakeToSnakeCase(providerData), + }; + return entry; + } + throw new UserError( `Unsupported built-in tool call type: ${JSON.stringify(item)}`, ); @@ -469,7 +578,7 @@ function getInputItems( if (item.type === 'unknown') { return { id: item.id, - ...item.providerData, + ...camelOrSnakeToSnakeCase(item.providerData), } as OpenAI.Responses.ResponseItem; } @@ -495,7 +604,7 @@ function convertToMessageContentItem( const { type, text, ...remainingItem } = item; return { type, - text: text, + text, ...remainingItem, }; } @@ -504,7 +613,7 @@ function convertToMessageContentItem( const { type, refusal, ...remainingItem } = item; return { type, - refusal: refusal, + refusal, ...remainingItem, }; } @@ -517,14 +626,14 @@ function convertToOutputItem( ): 
protocol.OutputModelItem[] { return items.map((item) => { if (item.type === 'message') { - const { id, type, role, content, status, ...remainingItem } = item; + const { id, type, role, content, status, ...providerData } = item; return { - type, id, + type, role, content: content.map(convertToMessageContentItem), status, - providerData: remainingItem, + providerData, }; } else if ( item.type === 'file_search_call' || @@ -532,64 +641,95 @@ function convertToOutputItem( item.type === 'image_generation_call' || item.type === 'code_interpreter_call' ) { - const { id, type, status, ...remainingItem } = item; - const outputData = - 'result' in remainingItem && remainingItem.result !== null - ? remainingItem.result // type: "image_generation_call" - : undefined; + const { status, ...remainingItem } = item; + let outputData = undefined; + if ('result' in remainingItem && remainingItem.result !== null) { + // type: "image_generation_call" + outputData = remainingItem.result; + delete (remainingItem as any).result; + } const output: protocol.HostedToolCallItem = { type: 'hosted_tool_call', - id, - name: type, - status: status, + id: item.id!, + name: item.type, + status, output: outputData, providerData: remainingItem, }; return output; } else if (item.type === 'function_call') { - const { - id, - call_id, - name, - status, - arguments: args, - ...remainingItem - } = item; + const { call_id, name, status, arguments: args, ...providerData } = item; const output: protocol.FunctionCallItem = { type: 'function_call', - id: id, + id: item.id!, callId: call_id, - name: name, - status: status, + name, + status, arguments: args, - providerData: remainingItem, + providerData, }; return output; } else if (item.type === 'computer_call') { - const { id, call_id, status, action, ...remainingItem } = item; + const { call_id, status, action, ...providerData } = item; const output: protocol.ComputerUseCallItem = { type: 'computer_call', - id: id, + id: item.id!, callId: call_id, - status: 
status, - action: action, - providerData: remainingItem, + status, + action, + providerData, + }; + return output; + } else if (item.type === 'mcp_list_tools') { + const { ...providerData } = item; + const output: protocol.HostedToolCallItem = { + type: 'hosted_tool_call', + id: item.id!, + name: item.type, + status: 'completed', + output: undefined, + providerData, + }; + return output; + } else if (item.type === 'mcp_approval_request') { + const { ...providerData } = item; + const output: protocol.HostedToolCallItem = { + type: 'hosted_tool_call', + id: item.id!, + name: 'mcp_approval_request', + status: 'completed', + output: undefined, + providerData, + }; + return output; + } else if (item.type === 'mcp_call') { + // Avoid duplicating potentially large output data + const { output: outputData, ...providerData } = item; + const output: protocol.HostedToolCallItem = { + type: 'hosted_tool_call', + id: item.id!, + name: item.type, + status: 'completed', + output: outputData || undefined, + providerData, }; return output; } else if (item.type === 'reasoning') { - const { id, summary, ...remainingItem } = item; + // Avoid duplicating potentially large summary data + const { summary, ...providerData } = item; const output: protocol.ReasoningItem = { type: 'reasoning', - id: id, + id: item.id!, content: summary.map((content) => { + // Avoid duplicating potentially large text const { text, ...remainingContent } = content; return { type: 'input_text', - text: text, + text, providerData: remainingContent, }; }), - providerData: remainingItem, + providerData, }; return output; } diff --git a/packages/agents-openai/src/tools.ts b/packages/agents-openai/src/tools.ts index 2fdec7ac..19476808 100644 --- a/packages/agents-openai/src/tools.ts +++ b/packages/agents-openai/src/tools.ts @@ -1,6 +1,7 @@ import { HostedTool } from '@openai/agents-core'; import type OpenAI from 'openai'; import { z } from '@openai/zod/v3'; +import * as ProviderData from
'./types/providerData'; // ----------------------------------------------------- // Status enums @@ -50,15 +51,16 @@ export type WebSearchTool = { export function webSearchTool( options: Partial> = {}, ): HostedTool { + const providerData: ProviderData.WebSearchTool = { + type: 'web_search', + name: options.name ?? 'web_search_preview', + user_location: options.userLocation, + search_context_size: options.searchContextSize ?? 'medium', + }; return { type: 'hosted_tool', name: options.name ?? 'web_search_preview', - providerData: { - type: 'web_search', - name: options.name ?? 'web_search_preview', - userLocation: options.userLocation, - searchContextSize: options.searchContextSize ?? 'medium', - }, + providerData, }; } @@ -103,18 +105,19 @@ export function fileSearchTool( const vectorIds = Array.isArray(vectorStoreIds) ? vectorStoreIds : [vectorStoreIds]; + const providerData: ProviderData.FileSearchTool = { + type: 'file_search', + name: options.name ?? 'file_search', + vector_store_ids: vectorIds, + max_num_results: options.maxNumResults, + include_search_results: options.includeSearchResults, + ranking_options: options.rankingOptions, + filters: options.filters, + }; return { type: 'hosted_tool', name: options.name ?? 'file_search', - providerData: { - type: 'file_search', - name: options.name ?? 'file_search', - vectorStoreId: vectorIds, - maxNumResults: options.maxNumResults, - includeSearchResults: options.includeSearchResults, - rankingOptions: options.rankingOptions, - filters: options.filters, - }, + providerData, }; } @@ -134,14 +137,15 @@ export type CodeInterpreterTool = { export function codeInterpreterTool( options: Partial> = {}, ): HostedTool { + const providerData: ProviderData.CodeInterpreterTool = { + type: 'code_interpreter', + name: options.name ?? 'code_interpreter', + container: options.container ?? { type: 'auto' }, + }; return { type: 'hosted_tool', name: options.name ?? 
'code_interpreter', - providerData: { - type: 'code_interpreter', - name: options.name ?? 'code_interpreter', - container: options.container, - }, + providerData, }; } @@ -170,21 +174,24 @@ export type ImageGenerationTool = { export function imageGenerationTool( options: Partial> = {}, ): HostedTool { + const providerData: ProviderData.ImageGenerationTool = { + type: 'image_generation', + name: options.name ?? 'image_generation', + background: options.background, + input_image_mask: options.inputImageMask, + model: options.model, + moderation: options.moderation, + output_compression: options.outputCompression, + output_format: options.outputFormat, + partial_images: options.partialImages, + quality: options.quality, + size: options.size, + }; return { type: 'hosted_tool', name: options.name ?? 'image_generation', - providerData: { - type: 'image_generation', - name: options.name ?? 'image_generation', - background: options.background, - inputImageMask: options.inputImageMask, - model: options.model, - moderation: options.moderation, - outputCompression: options.outputCompression, - outputFormat: options.outputFormat, - partialImages: options.partialImages, - quality: options.quality, - size: options.size, - }, + providerData, }; } + +// HostedMCPTool exists in agents-core package diff --git a/packages/agents-openai/src/types/providerData.ts b/packages/agents-openai/src/types/providerData.ts new file mode 100644 index 00000000..73f8f8ff --- /dev/null +++ b/packages/agents-openai/src/types/providerData.ts @@ -0,0 +1,40 @@ +import OpenAI from 'openai'; + +export type WebSearchTool = Omit & { + type: 'web_search'; + name: 'web_search_preview' | string; +}; + +export type FileSearchTool = Omit & { + type: 'file_search'; + name: 'file_search' | string; + include_search_results?: boolean; +}; + +export type CodeInterpreterTool = Omit< + OpenAI.Responses.Tool.CodeInterpreter, + 'type' +> & { + type: 'code_interpreter'; + name: 'code_interpreter' | string; +}; + +export 
type ImageGenerationTool = Omit< + OpenAI.Responses.Tool.ImageGeneration, + | 'type' + | 'background' + | 'model' + | 'moderation' + | 'output_format' + | 'quality' + | 'size' +> & { + type: 'image_generation'; + name: 'image_generation' | string; + background?: 'transparent' | 'opaque' | 'auto' | string; + model?: 'gpt-image-1' | string; + moderation?: 'auto' | 'low' | string; + output_format?: 'png' | 'webp' | 'jpeg' | string; + quality?: 'low' | 'medium' | 'high' | 'auto' | string; + size?: '1024x1024' | '1024x1536' | '1536x1024' | 'auto' | string; +}; diff --git a/packages/agents-openai/src/utils/providerData.ts b/packages/agents-openai/src/utils/providerData.ts new file mode 100644 index 00000000..6db93a2a --- /dev/null +++ b/packages/agents-openai/src/utils/providerData.ts @@ -0,0 +1,21 @@ +/** + * Converts camelCase or snake_case keys of an object to snake_case, recursing into nested plain objects. Arrays and non-object values are returned unchanged, so objects nested inside arrays are not converted. + */ +export function camelOrSnakeToSnakeCase< + T extends Record | undefined, +>(providerData: T | undefined): Record | undefined { + if ( + !providerData || + typeof providerData !== 'object' || + Array.isArray(providerData) + ) { + return providerData; + } + + const result: Record = {}; + for (const [key, value] of Object.entries(providerData)) { + const snakeKey = key.replace(/([A-Z])/g, '_$1').toLowerCase(); + result[snakeKey] = camelOrSnakeToSnakeCase(value); + } + return result; +} diff --git a/packages/agents-openai/test/openaiResponsesModel.helpers.test.ts b/packages/agents-openai/test/openaiResponsesModel.helpers.test.ts index 68f60c68..96a326d3 100644 --- a/packages/agents-openai/test/openaiResponsesModel.helpers.test.ts +++ b/packages/agents-openai/test/openaiResponsesModel.helpers.test.ts @@ -72,8 +72,8 @@ describe('converTool', () => { type: 'hosted_tool', providerData: { type: 'web_search', - userLocation: {}, - searchContextSize: 'low', + user_location: {}, + search_context_size: 'low', }, } as any); expect(web.tool).toEqual({ @@ -86,9 +86,9 @@ describe('converTool', () =>
{ type: 'hosted_tool', providerData: { type: 'file_search', - vectorStoreId: ['v'], - maxNumResults: 5, - includeSearchResults: true, + vector_store_ids: ['v'], + max_num_results: 5, + include_search_results: true, }, } as any); expect(file.tool).toEqual({ diff --git a/packages/agents-openai/test/utils/providerData.test.ts b/packages/agents-openai/test/utils/providerData.test.ts new file mode 100644 index 00000000..3abf2ee0 --- /dev/null +++ b/packages/agents-openai/test/utils/providerData.test.ts @@ -0,0 +1,70 @@ +import { describe, it, expect } from 'vitest'; +import { camelOrSnakeToSnakeCase } from '../../src/utils/providerData'; + +describe('camelToSnakeCase', () => { + it('converts flat camelCase keys to snake_case', () => { + expect(camelOrSnakeToSnakeCase({ fooBar: 1, bazQux: 2 })).toEqual({ + foo_bar: 1, + baz_qux: 2, + }); + }); + it('converts snake_case keys to snake_case', () => { + expect( + camelOrSnakeToSnakeCase({ foo_bar_buz: 1, baz_qux: 2, foo_bar: 3 }), + ).toEqual({ + foo_bar_buz: 1, + baz_qux: 2, + foo_bar: 3, + }); + }); + it('converts mixed keys to snake_case', () => { + expect( + camelOrSnakeToSnakeCase({ foo_barBuz: 1, bazQux: 2, foo_bar: 3 }), + ).toEqual({ + foo_bar_buz: 1, + baz_qux: 2, + foo_bar: 3, + }); + }); + + it('handles nested objects', () => { + expect( + camelOrSnakeToSnakeCase({ + outerKey: { innerKey: 42, anotherInner: { deepKey: 'x' } }, + }), + ).toEqual({ + outer_key: { inner_key: 42, another_inner: { deep_key: 'x' } }, + }); + }); + + it('handles nested objects with mixed keys', () => { + expect( + camelOrSnakeToSnakeCase({ + outerKey: { innerKey: 42, anotherInner: { deep_key: 'x' } }, + }), + ).toEqual({ + outer_key: { inner_key: 42, another_inner: { deep_key: 'x' } }, + }); + }); + + it('handles arrays and primitives', () => { + expect(camelOrSnakeToSnakeCase([1, 2, 3])).toEqual([1, 2, 3]); + expect(camelOrSnakeToSnakeCase(undefined)).toBe(undefined); + }); + + it('leaves already snake_case keys as is', () => { + expect( 
+ camelOrSnakeToSnakeCase({ already_snake: 1, also_snake_case: 2 }), + ).toEqual({ + already_snake: 1, + also_snake_case: 2, + }); + }); + + it('handles mixed keys', () => { + expect(camelOrSnakeToSnakeCase({ fooBar: 1, already_snake: 2 })).toEqual({ + foo_bar: 1, + already_snake: 2, + }); + }); +}); diff --git a/packages/agents-realtime/src/realtimeSession.ts b/packages/agents-realtime/src/realtimeSession.ts index b8b637be..a6bef503 100644 --- a/packages/agents-realtime/src/realtimeSession.ts +++ b/packages/agents-realtime/src/realtimeSession.ts @@ -714,7 +714,11 @@ export class RealtimeSession< const tool = this.#currentAgent.tools.find( (tool) => tool.name === approvalItem.rawItem.name, ); - if (tool && tool.type === 'function') { + if ( + tool && + tool.type === 'function' && + approvalItem.rawItem.type === 'function_call' + ) { await this.#handleFunctionToolCall(approvalItem.rawItem, tool); } else { throw new ModelBehaviorError( @@ -739,7 +743,11 @@ export class RealtimeSession< const tool = this.#currentAgent.tools.find( (tool) => tool.name === approvalItem.rawItem.name, ); - if (tool && tool.type === 'function') { + if ( + tool && + tool.type === 'function' && + approvalItem.rawItem.type === 'function_call' + ) { await this.#handleFunctionToolCall(approvalItem.rawItem, tool); } else { throw new ModelBehaviorError(