diff --git a/js/plugins/vertexai/src/gemini.ts b/js/plugins/vertexai/src/gemini.ts index 797b3c0d9a..6d5e223e9a 100644 --- a/js/plugins/vertexai/src/gemini.ts +++ b/js/plugins/vertexai/src/gemini.ts @@ -279,6 +279,61 @@ export function toGeminiMessage( }; } +export function toGeminiMessages( + messages: MessageData[], + model?: ModelReference +): Content[] { + const geminiMessages: Content[] = []; + + for (let i = 0; i < messages.length; i++) { + if (messages[i].content.some((part) => part.toolResponse)) { + const toolResponseMessage = messages[i]; + + // get the most recent user message in the history + const previousUserMessageIndex = messages + .slice(0, i) + .reverse() + .findIndex((m) => m.role === toGeminiRole('user', model)); + + if (previousUserMessageIndex === -1) { + throw new Error( + 'Tool response message must be preceded by a user message' + ); + } + + const actualPreviousUserMessageIndex = i - 1 - previousUserMessageIndex; + const previousUserMessage = messages[actualPreviousUserMessageIndex]; + + const newToolResponseMessage = { + role: toGeminiRole(toolResponseMessage.role, model), + parts: [toGeminiToolResponsePart(toolResponseMessage.content[0])], + }; + + const otherParts = toolResponseMessage.content.filter( + (part) => !part.toolResponse + ); + + if (otherParts.length > 0) { + const newPreviousUserMessage = { + role: toGeminiRole(previousUserMessage.role, model), + parts: [...previousUserMessage.content, ...otherParts].map( + toGeminiPart + ), + }; + + // Modify geminiMessages in place to replace the previous user message + geminiMessages[actualPreviousUserMessageIndex] = newPreviousUserMessage; + } + + geminiMessages.push(newToolResponseMessage); + } else { + geminiMessages.push(toGeminiMessage(messages[i], model)); + } + } + + return geminiMessages; +} + function fromGeminiFinishReason( reason: GenerateContentCandidate['finishReason'] ): CandidateData['finishReason'] { @@ -490,14 +545,14 @@ export function geminiModel(name: string, vertex: VertexAI): ModelAction { } } + const geminiMessages = toGeminiMessages(messages, model); + const chatRequest: StartChatParams = { systemInstruction, tools: request.tools?.length ? [{ functionDeclarations: request.tools?.map(toGeminiTool) }] : [], - history: messages - .slice(0, -1) - .map((message) => toGeminiMessage(message, model)), + history: geminiMessages.slice(0, -1), generationConfig: { candidateCount: request.candidates || undefined, temperature: request.config?.temperature, @@ -508,7 +563,7 @@ export function geminiModel(name: string, vertex: VertexAI): ModelAction { }, safetySettings: request.config?.safetySettings, }; - const msg = toGeminiMessage(messages[messages.length - 1], model); + const msg = geminiMessages[geminiMessages.length - 1]; if (streamingCallback) { const result = await client .startChat(chatRequest) diff --git a/js/plugins/vertexai/tests/gemini_test.ts b/js/plugins/vertexai/tests/gemini_test.ts index dbb39e8f1c..5fe42025e8 100644 --- a/js/plugins/vertexai/tests/gemini_test.ts +++ b/js/plugins/vertexai/tests/gemini_test.ts @@ -21,10 +21,11 @@ import { describe, it } from 'node:test'; import { fromGeminiCandidate, toGeminiMessage, + toGeminiMessages, toGeminiSystemInstruction, } from '../src/gemini.js'; -describe('toGeminiMessages', () => { +describe('toGeminiMessage', () => { const testCases = [ { should: 'should transform genkit message (text content) correctly', @@ -345,3 +346,79 @@ describe('fromGeminiCandidate', () => { }); } }); + +describe('toGeminiMessages', () => { + it('should handle tool request messages correctly', () => { + const messages = [ + { + role: 'user', + content: [{ text: 'What is the weather like today?.' }], + }, + { + role: 'model', + content: [ + { + toolRequest: { + name: 'tellMeTheWeather', + input: { request: 'tell me the weather' }, + }, + }, + ], + }, + { + role: 'tool', + content: [ + { + toolResponse: { + name: 'tellMeTheWeather', + output: 'it is sunny today', + }, + }, + { text: 'This is extra context', metadata: { purpose: 'context' } }, + ], + }, + ]; + + const expectedOutput = [ + { + role: 'user', + parts: [ + { + text: 'What is the weather like today?.', + }, + { text: 'This is extra context' }, + ], + }, + { + role: 'model', + parts: [ + { + functionCall: { + name: 'tellMeTheWeather', + args: { + request: 'tell me the weather', + }, + }, + }, + ], + }, + { + role: 'function', + parts: [ + { + functionResponse: { + name: 'tellMeTheWeather', + response: { + name: 'tellMeTheWeather', + content: 'it is sunny today', + }, + }, + }, + ], + }, + ]; + const output = toGeminiMessages(messages as MessageData[]); + + assert.deepEqual(output, expectedOutput); + }); +});