diff --git a/.changeset/fine-symbols-jam.md b/.changeset/fine-symbols-jam.md new file mode 100644 index 00000000..1e754b34 --- /dev/null +++ b/.changeset/fine-symbols-jam.md @@ -0,0 +1,5 @@ +--- +'@openai/agents-extensions': patch +--- + +fix: preserve Gemini thought_signature in multi-turn tool calls diff --git a/packages/agents-extensions/src/aiSdk.ts b/packages/agents-extensions/src/aiSdk.ts index 9279a7e9..f2b5d363 100644 --- a/packages/agents-extensions/src/aiSdk.ts +++ b/packages/agents-extensions/src/aiSdk.ts @@ -711,7 +711,9 @@ export class AiSdkModel implements Model { name: toolCall.toolName, arguments: toolCallArguments, status: 'completed', - providerData: hasToolCalls ? result.providerMetadata : undefined, + providerData: + toolCall.providerMetadata ?? + (hasToolCalls ? result.providerMetadata : undefined), }); } @@ -916,6 +918,9 @@ export class AiSdkModel implements Model { name: (part as any).toolName, arguments: (part as any).input ?? '', status: 'completed', + ...((part as any).providerMetadata + ? { providerData: (part as any).providerMetadata } + : {}), }; } break; diff --git a/packages/agents-extensions/test/aiSdk.test.ts b/packages/agents-extensions/test/aiSdk.test.ts index 897517ea..39ecf936 100644 --- a/packages/agents-extensions/test/aiSdk.test.ts +++ b/packages/agents-extensions/test/aiSdk.test.ts @@ -745,6 +745,106 @@ describe('AiSdkModel.getResponse', () => { ]); }); + test('preserves per-tool-call providerMetadata (e.g., Gemini thoughtSignature)', async () => { + const toolCallProviderMetadata = { + google: { thoughtSignature: 'sig123' }, + }; + const resultProviderMetadata = { + google: { usageMetadata: { totalTokenCount: 100 } }, + }; + + const model = new AiSdkModel( + stubModel({ + async doGenerate() { + return { + content: [ + { + type: 'tool-call', + toolCallId: 'c1', + toolName: 'get_weather', + input: { location: 'Tokyo' }, + providerMetadata: toolCallProviderMetadata, + }, + ], + usage: { inputTokens: 10, outputTokens: 20, totalTokens: 30 }, + providerMetadata: resultProviderMetadata, + response: { id: 'resp-1' }, + finishReason: 'tool-calls', + warnings: [], + } as any; + }, + }), + ); + + const res = await withTrace('t', () => + model.getResponse({ + input: 'What is the weather in Tokyo?', + tools: [ + { + type: 'function', + name: 'get_weather', + description: 'Get weather', + parameters: { type: 'object', properties: {} }, + }, + ], + handoffs: [], + modelSettings: {}, + outputType: 'text', + tracing: false, + } as any), + ); + + expect(res.output).toHaveLength(1); + expect(res.output[0]).toMatchObject({ + type: 'function_call', + callId: 'c1', + name: 'get_weather', + providerData: toolCallProviderMetadata, + }); + // Ensure we get per-tool-call metadata, not result-level metadata + expect(res.output[0].providerData).not.toEqual(resultProviderMetadata); + }); + + test('falls back to result.providerMetadata when toolCall.providerMetadata is undefined', async () => { + const resultProviderMetadata = { fallback: true }; + + const model = new AiSdkModel( + stubModel({ + async doGenerate() { + return { + content: [ + { + type: 'tool-call', + toolCallId: 'c1', + toolName: 'foo', + input: {}, + // No providerMetadata on tool call + }, + ], + usage: { inputTokens: 1, outputTokens: 2, totalTokens: 3 }, + providerMetadata: resultProviderMetadata, + response: { id: 'id' }, + finishReason: 'tool-calls', + warnings: [], + } as any; + }, + }), + ); + + const res = await withTrace('t', () => + model.getResponse({ + input: 'hi', + tools: [], + handoffs: [], + modelSettings: {}, + outputType: 'text', + tracing: false, + } as any), + ); + + expect(res.output[0].providerData).toEqual(resultProviderMetadata); + }); + test('propagates errors', async () => { const model = new AiSdkModel( stubModel({ @@ -905,6 +1005,116 @@ describe('AiSdkModel.getStreamedResponse', () => { ]); }); + test('preserves per-tool-call providerMetadata in streaming mode (e.g., Gemini thoughtSignature)', async () => { + const toolCallProviderMetadata = { + google: { thoughtSignature: 'stream-sig-456' }, + }; + + const parts = [ + { + type: 'tool-call', + toolCallId: 'c1', + toolName: 'get_weather', + input: '{"location":"Tokyo"}', + providerMetadata: toolCallProviderMetadata, + }, + { type: 'response-metadata', id: 'resp-stream-1' }, + { + type: 'finish', + finishReason: 'tool-calls', + usage: { inputTokens: 10, outputTokens: 20 }, + }, + ]; + + const model = new AiSdkModel( + stubModel({ + async doStream() { + return { + stream: partsStream(parts), + } as any; + }, + }), + ); + + const events: any[] = []; + for await (const ev of model.getStreamedResponse({ + input: 'What is the weather?', + tools: [ + { + type: 'function', + name: 'get_weather', + description: 'Get weather', + parameters: { type: 'object', properties: {} }, + }, + ], + handoffs: [], + modelSettings: {}, + outputType: 'text', + tracing: false, + } as any)) { + events.push(ev); + } + + const final = events.at(-1); + expect(final.type).toBe('response_done'); + expect(final.response.output).toHaveLength(1); + expect(final.response.output[0]).toMatchObject({ + type: 'function_call', + callId: 'c1', + name: 'get_weather', + providerData: toolCallProviderMetadata, + }); + }); + + test('omits providerData in streaming mode when providerMetadata is not present', async () => { + const parts = [ + { + type: 'tool-call', + toolCallId: 'c1', + toolName: 'foo', + input: '{}', + // No providerMetadata + }, + { + type: 'finish', + finishReason: 'tool-calls', + usage: { inputTokens: 1, outputTokens: 2 }, + }, + ]; + + const model = new AiSdkModel( + stubModel({ + async doStream() { + return { + stream: partsStream(parts), + } as any; + }, + }), + ); + + const events: any[] = []; + for await (const ev of model.getStreamedResponse({ + input: 'hi', + tools: [], + handoffs: [], + modelSettings: {}, + outputType: 'text', + tracing: false, + } as any)) { + events.push(ev); + } + + const final = events.at(-1); + expect(final.type).toBe('response_done'); + expect(final.response.output[0]).toMatchObject({ + type: 'function_call', + callId: 'c1', + name: 'foo', + }); + // providerData should not be present when providerMetadata was not provided + expect(final.response.output[0].providerData).toBeUndefined(); + }); + test('propagates stream errors', async () => { const err = new Error('bad'); const parts = [{ type: 'error', error: err }];