diff --git a/js/ai/src/generate.ts b/js/ai/src/generate.ts index 065624bb85..9f61eea3c7 100755 --- a/js/ai/src/generate.ts +++ b/js/ai/src/generate.ts @@ -39,6 +39,7 @@ import { ModelReference, Part, Role, + ToolRequestPart, ToolResponsePart, } from './model.js'; import * as telemetry from './telemetry.js'; @@ -104,6 +105,16 @@ export class Message implements MessageData { return this.content.find((part) => part.data)?.data as T | null; } + /** + * Returns all tool request found in this message. + * @returns Array of all tool request found in this message. + */ + toolRequests(): ToolRequestPart[] { + return this.content.filter( + (part) => !!part.toolRequest + ) as ToolRequestPart[]; + } + /** * Converts the Message to a plain JS object. * @returns Plain JS object representing the data contained in the message. @@ -183,6 +194,14 @@ export class Candidate implements CandidateData { return this.message.data(); } + /** + * Returns all tool request found in this candidate. + * @returns Array of all tool request found in this candidate. + */ + toolRequests(): ToolRequestPart[] { + return this.message.toolRequests(); + } + /** * Determine whether this candidate has output that conforms to a provided schema. * @@ -290,6 +309,15 @@ export class GenerateResponse implements GenerateResponseData { return this.candidates[index]?.data() || null; } + /** + * Returns all tool request found in the candidate. + * @param index The candidate index from which to extract tool requests, defaults to first candidate. + * @returns Array of all tool request found in the candidate. + */ + toolRequests(index: number = 0): ToolRequestPart[] { + return this.candidates[index].toolRequests(); + } + /** * Appends the message generated by the selected candidate to the messages already * present in the generation request. The result of this method can be safely @@ -361,6 +389,16 @@ export class GenerateResponseChunk return this.content.find((part) => part.data)?.data as T | null; } + /** + * Returns all tool request found in this chunk. + * @returns Array of all tool request found in this chunk. + */ + toolRequests(): ToolRequestPart[] { + return this.content.filter( + (part) => !!part.toolRequest + ) as ToolRequestPart[]; + } + toJSON(): GenerateResponseChunkData { return { index: this.index, content: this.content, custom: this.custom }; } diff --git a/js/ai/tests/generate/generate_test.ts b/js/ai/tests/generate/generate_test.ts index adb3edb6d6..f8717f9f86 100644 --- a/js/ai/tests/generate/generate_test.ts +++ b/js/ai/tests/generate/generate_test.ts @@ -230,6 +230,88 @@ describe('GenerateResponse', () => { assert.deepStrictEqual(response.output(0), { abc: '123' }); }); }); + describe('#toolRequests()', () => { + it('returns empty array if no tools requests found', () => { + const response = new GenerateResponse({ + candidates: [ + { + index: 0, + finishReason: 'stop', + message: { role: 'model', content: [{ text: '{"abc":"123"}' }] }, + }, + { + index: 0, + finishReason: 'stop', + message: { role: 'model', content: [{ text: '{"abc":123}' }] }, + }, + ], + }); + assert.deepStrictEqual(response.toolRequests(), []); + assert.deepStrictEqual(response.toolRequests(0), []); + }); + it('picks the first candidate if no index provided', () => { + const toolCall = { + toolRequest: { + name: 'foo', + ref: 'abc', + input: 'banana', + }, + }; + const response = new GenerateResponse({ + candidates: [ + { + index: 0, + finishReason: 'stop', + message: { + role: 'model', + content: [toolCall], + }, + }, + { + index: 0, + finishReason: 'stop', + message: { role: 'model', content: [{ text: '{"abc":123}' }] }, + }, + ], + }); + assert.deepStrictEqual(response.toolRequests(), [toolCall]); + assert.deepStrictEqual(response.toolRequests(0), [toolCall]); + }); + it('returns all tool call', () => { + const toolCall1 = { + toolRequest: { + name: 'foo', + ref: 'abc', + input: 'banana', + }, + }; + const toolCall2 = { + toolRequest: { + name: 'bar', + ref: 'bcd', + input: 'apple', + }, + }; + const response = new GenerateResponse({ + candidates: [ + { + index: 0, + finishReason: 'stop', + message: { + role: 'model', + content: [toolCall1, toolCall2], + }, + }, + { + index: 0, + finishReason: 'stop', + message: { role: 'model', content: [{ text: '{"abc":123}' }] }, + }, + ], + }); + assert.deepStrictEqual(response.toolRequests(), [toolCall1, toolCall2]); + }); + }); }); describe('toGenerateRequest', () => {