From f285d0e59995ac10d2edecc17493202ba4d1dfa4 Mon Sep 17 00:00:00 2001 From: Pavel Jbanov Date: Wed, 29 May 2024 11:44:42 -0400 Subject: [PATCH 1/3] Added `toolRequests` helper to generate response to make it easier to work with tools --- js/ai/src/generate.ts | 38 +++++++++++++ js/ai/tests/generate/generate_test.ts | 82 +++++++++++++++++++++++++++ 2 files changed, 120 insertions(+) diff --git a/js/ai/src/generate.ts b/js/ai/src/generate.ts index 065624bb85..9f61eea3c7 100755 --- a/js/ai/src/generate.ts +++ b/js/ai/src/generate.ts @@ -39,6 +39,7 @@ import { ModelReference, Part, Role, + ToolRequestPart, ToolResponsePart, } from './model.js'; import * as telemetry from './telemetry.js'; @@ -104,6 +105,16 @@ export class Message implements MessageData { return this.content.find((part) => part.data)?.data as T | null; } + /** + * Returns all tool request found in this message. + * @returns Array of all tool request found in this message. + */ + toolRequests(): ToolRequestPart[] { + return this.content.filter( + (part) => !!part.toolRequest + ) as ToolRequestPart[]; + } + /** * Converts the Message to a plain JS object. * @returns Plain JS object representing the data contained in the message. @@ -183,6 +194,14 @@ export class Candidate implements CandidateData { return this.message.data(); } + /** + * Returns all tool request found in this candidate. + * @returns Array of all tool request found in this candidate. + */ + toolRequests(): ToolRequestPart[] { + return this.message.toolRequests(); + } + /** * Determine whether this candidate has output that conforms to a provided schema. * @@ -290,6 +309,15 @@ export class GenerateResponse implements GenerateResponseData { return this.candidates[index]?.data() || null; } + /** + * Returns all tool request found in the candidate. + * @param index The candidate index from which to extract tool requests, defaults to first candidate. + * @returns Array of all tool request found in the candidate. + */ + toolRequests(index: number = 0): ToolRequestPart[] { + return this.candidates[index].toolRequests(); + } + /** * Appends the message generated by the selected candidate to the messages already * present in the generation request. The result of this method can be safely @@ -361,6 +389,16 @@ export class GenerateResponseChunk return this.content.find((part) => part.data)?.data as T | null; } + /** + * Returns all tool request found in this chunk. + * @returns Array of all tool request found in this chunk. + */ + toolRequests(): ToolRequestPart[] { + return this.content.filter( + (part) => !!part.toolRequest + ) as ToolRequestPart[]; + } + toJSON(): GenerateResponseChunkData { return { index: this.index, content: this.content, custom: this.custom }; } diff --git a/js/ai/tests/generate/generate_test.ts b/js/ai/tests/generate/generate_test.ts index adb3edb6d6..b322774b62 100644 --- a/js/ai/tests/generate/generate_test.ts +++ b/js/ai/tests/generate/generate_test.ts @@ -230,6 +230,88 @@ describe('GenerateResponse', () => { assert.deepStrictEqual(response.output(0), { abc: '123' }); }); }); + describe('#toolRequests()', () => { + it('returns empty array if no tools requests found', () => { + const response = new GenerateResponse({ + candidates: [ + { + index: 0, + finishReason: 'stop', + message: { role: 'model', content: [{ text: '{"abc":"123"}' }] }, + }, + { + index: 0, + finishReason: 'stop', + message: { role: 'model', content: [{ text: '{"abc":123}' }] }, + }, + ], + }); + assert.deepStrictEqual(response.toolRequests(), []); + assert.deepStrictEqual(response.toolRequests(0), []); + }); + it('picks the first candidate if no index provided', () => { + const toolCall = { + toolRequest: { + name: 'foo', + ref: 'abc', + input: 'banana', + }, + }; + const response = new GenerateResponse({ + candidates: [ + { + index: 0, + finishReason: 'stop', + message: { + role: 'model', + content: [toolCall], + }, + }, + { + index: 0, + finishReason: 'stop', + message: { role: 'model', content: [{ text: '{"abc":123}' }] }, + }, + ], + }); + assert.deepStrictEqual(response.toolRequests(), [toolCall]); + assert.deepStrictEqual(response.toolRequests(0), [toolCall]); + }); + it('returns all tool call', () => { + const toolCall1 = { + toolRequest: { + name: 'foo', + ref: 'abc', + input: 'banana', + }, + }; + const toolCall2 = { + toolRequest: { + name: 'foo', + ref: 'abc', + input: 'banana', + }, + }; + const response = new GenerateResponse({ + candidates: [ + { + index: 0, + finishReason: 'stop', + message: { + role: 'model', + content: [toolCall1, toolCall2], + }, + }, + { + index: 0, + finishReason: 'stop', + message: { role: 'model', content: [{ text: '{"abc":123}' }] }, + }, + ], + }); + assert.deepStrictEqual(response.toolRequests(), [toolCall1, toolCall2]); + }); + }); }); describe('toGenerateRequest', () => { From 2c1a909152a0291ab012dda98aea210adc1ce19d Mon Sep 17 00:00:00 2001 From: Pavel Jbanov Date: Wed, 29 May 2024 13:48:50 -0400 Subject: [PATCH 2/3] return ToolRequest instead of ToolRequestPart --- js/ai/src/generate.ts | 22 ++++++++++----------- js/ai/src/model.ts | 19 ++++++++++-------- js/ai/tests/generate/generate_test.ts | 28 +++++++++++---------------- 3 files changed, 33 insertions(+), 36 deletions(-) diff --git a/js/ai/src/generate.ts b/js/ai/src/generate.ts index 9f61eea3c7..0521aaf1a4 100755 --- a/js/ai/src/generate.ts +++ b/js/ai/src/generate.ts @@ -39,7 +39,7 @@ import { ModelReference, Part, Role, - ToolRequestPart, + ToolRequest, ToolResponsePart, } from './model.js'; import * as telemetry from './telemetry.js'; @@ -109,10 +109,10 @@ export class Message implements MessageData { * Returns all tool request found in this message. * @returns Array of all tool request found in this message. */ - toolRequests(): ToolRequestPart[] { - return this.content.filter( - (part) => !!part.toolRequest - ) as ToolRequestPart[]; + toolRequests(): ToolRequest[] { + return this.content + .filter((part) => !!part.toolRequest) + .map((part) => part.toolRequest!); } /** @@ -198,7 +198,7 @@ export class Candidate implements CandidateData { * Returns all tool request found in this candidate. * @returns Array of all tool request found in this candidate. */ - toolRequests(): ToolRequestPart[] { + toolRequests(): ToolRequest[] { return this.message.toolRequests(); } @@ -314,7 +314,7 @@ export class GenerateResponse implements GenerateResponseData { * @param index The candidate index from which to extract tool requests, defaults to first candidate. * @returns Array of all tool request found in the candidate. */ - toolRequests(index: number = 0): ToolRequestPart[] { + toolRequests(index: number = 0): ToolRequest[] { return this.candidates[index].toolRequests(); } @@ -393,10 +393,10 @@ export class GenerateResponseChunk * Returns all tool request found in this chunk. * @returns Array of all tool request found in this chunk. */ - toolRequests(): ToolRequestPart[] { - return this.content.filter( - (part) => !!part.toolRequest - ) as ToolRequestPart[]; + toolRequests(): ToolRequest[] { + return this.content + .filter((part) => !!part.toolRequest) + .map((part) => part.toolRequest!); } toJSON(): GenerateResponseChunkData { diff --git a/js/ai/src/model.ts b/js/ai/src/model.ts index 10036ea621..903fd75c64 100644 --- a/js/ai/src/model.ts +++ b/js/ai/src/model.ts @@ -62,16 +62,19 @@ export const MediaPartSchema = EmptyPartSchema.extend({ }); export type MediaPart = z.infer; +export const ToolRequestSchema = z.object({ + /** The call id or reference for a specific request. */ + ref: z.string().optional(), + /** The name of the tool to call. */ + name: z.string(), + /** The input parameters for the tool, usually a JSON object. */ + input: z.unknown().optional(), +}); +export type ToolRequest = z.infer; + export const ToolRequestPartSchema = EmptyPartSchema.extend({ /** A request for a tool to be executed, usually provided by a model. */ - toolRequest: z.object({ - /** The call id or reference for a specific request. */ - ref: z.string().optional(), - /** The name of the tool to call. */ - name: z.string(), - /** The input parameters for the tool, usually a JSON object. */ - input: z.unknown().optional(), - }), + toolRequest: ToolRequestSchema, }); export type ToolRequestPart = z.infer; diff --git a/js/ai/tests/generate/generate_test.ts b/js/ai/tests/generate/generate_test.ts index b322774b62..11ff513273 100644 --- a/js/ai/tests/generate/generate_test.ts +++ b/js/ai/tests/generate/generate_test.ts @@ -251,11 +251,9 @@ describe('GenerateResponse', () => { }); it('picks the first candidate if no index provided', () => { const toolCall = { - toolRequest: { - name: 'foo', - ref: 'abc', - input: 'banana', - }, + name: 'foo', + ref: 'abc', + input: 'banana', }; const response = new GenerateResponse({ candidates: [ @@ -264,7 +262,7 @@ describe('GenerateResponse', () => { finishReason: 'stop', message: { role: 'model', - content: [toolCall], + content: [{ toolRequest: toolCall }], }, }, { @@ -279,18 +277,14 @@ describe('GenerateResponse', () => { }); it('returns all tool call', () => { const toolCall1 = { - toolRequest: { - name: 'foo', - ref: 'abc', - input: 'banana', - }, + name: 'foo', + ref: 'abc', + input: 'banana', }; const toolCall2 = { - toolRequest: { - name: 'foo', - ref: 'abc', - input: 'banana', - }, + name: 'bar', + ref: 'bcd', + input: 'apple', }; const response = new GenerateResponse({ candidates: [ @@ -299,7 +293,7 @@ describe('GenerateResponse', () => { finishReason: 'stop', message: { role: 'model', - content: [toolCall1, toolCall2], + content: [{ toolRequest: toolCall1 }, { toolRequest: toolCall2 }], }, }, { From c2b0ff6eb4d1c8c7244a2cf2c3ee7b6d47eb34ed Mon Sep 17 00:00:00 2001 From: Pavel Jbanov Date: Wed, 29 May 2024 14:38:55 -0400 Subject: [PATCH 3/3] undo --- js/ai/src/generate.ts | 22 ++++++++++----------- js/ai/src/model.ts | 19 ++++++++---------- js/ai/tests/generate/generate_test.ts | 28 ++++++++++++++++----------- 3 files changed, 36 insertions(+), 33 deletions(-) diff --git a/js/ai/src/generate.ts b/js/ai/src/generate.ts index 0521aaf1a4..9f61eea3c7 100755 --- a/js/ai/src/generate.ts +++ b/js/ai/src/generate.ts @@ -39,7 +39,7 @@ import { ModelReference, Part, Role, - ToolRequest, + ToolRequestPart, ToolResponsePart, } from './model.js'; import * as telemetry from './telemetry.js'; @@ -109,10 +109,10 @@ export class Message implements MessageData { * Returns all tool request found in this message. * @returns Array of all tool request found in this message. */ - toolRequests(): ToolRequest[] { - return this.content - .filter((part) => !!part.toolRequest) - .map((part) => part.toolRequest!); + toolRequests(): ToolRequestPart[] { + return this.content.filter( + (part) => !!part.toolRequest + ) as ToolRequestPart[]; } /** @@ -198,7 +198,7 @@ export class Candidate implements CandidateData { * Returns all tool request found in this candidate. * @returns Array of all tool request found in this candidate. */ - toolRequests(): ToolRequest[] { + toolRequests(): ToolRequestPart[] { return this.message.toolRequests(); } @@ -314,7 +314,7 @@ export class GenerateResponse implements GenerateResponseData { * @param index The candidate index from which to extract tool requests, defaults to first candidate. * @returns Array of all tool request found in the candidate. */ - toolRequests(index: number = 0): ToolRequest[] { + toolRequests(index: number = 0): ToolRequestPart[] { return this.candidates[index].toolRequests(); } @@ -393,10 +393,10 @@ export class GenerateResponseChunk * Returns all tool request found in this chunk. * @returns Array of all tool request found in this chunk. */ - toolRequests(): ToolRequest[] { - return this.content - .filter((part) => !!part.toolRequest) - .map((part) => part.toolRequest!); + toolRequests(): ToolRequestPart[] { + return this.content.filter( + (part) => !!part.toolRequest + ) as ToolRequestPart[]; } toJSON(): GenerateResponseChunkData { diff --git a/js/ai/src/model.ts b/js/ai/src/model.ts index 903fd75c64..10036ea621 100644 --- a/js/ai/src/model.ts +++ b/js/ai/src/model.ts @@ -62,19 +62,16 @@ export const MediaPartSchema = EmptyPartSchema.extend({ }); export type MediaPart = z.infer; -export const ToolRequestSchema = z.object({ - /** The call id or reference for a specific request. */ - ref: z.string().optional(), - /** The name of the tool to call. */ - name: z.string(), - /** The input parameters for the tool, usually a JSON object. */ - input: z.unknown().optional(), -}); -export type ToolRequest = z.infer; - export const ToolRequestPartSchema = EmptyPartSchema.extend({ /** A request for a tool to be executed, usually provided by a model. */ - toolRequest: ToolRequestSchema, + toolRequest: z.object({ + /** The call id or reference for a specific request. */ + ref: z.string().optional(), + /** The name of the tool to call. */ + name: z.string(), + /** The input parameters for the tool, usually a JSON object. */ + input: z.unknown().optional(), + }), }); export type ToolRequestPart = z.infer; diff --git a/js/ai/tests/generate/generate_test.ts b/js/ai/tests/generate/generate_test.ts index 11ff513273..f8717f9f86 100644 --- a/js/ai/tests/generate/generate_test.ts +++ b/js/ai/tests/generate/generate_test.ts @@ -251,9 +251,11 @@ describe('GenerateResponse', () => { }); it('picks the first candidate if no index provided', () => { const toolCall = { - name: 'foo', - ref: 'abc', - input: 'banana', + toolRequest: { + name: 'foo', + ref: 'abc', + input: 'banana', + }, }; const response = new GenerateResponse({ candidates: [ @@ -262,7 +264,7 @@ describe('GenerateResponse', () => { finishReason: 'stop', message: { role: 'model', - content: [{ toolRequest: toolCall }], + content: [toolCall], }, }, { @@ -277,14 +279,18 @@ describe('GenerateResponse', () => { }); it('returns all tool call', () => { const toolCall1 = { - name: 'foo', - ref: 'abc', - input: 'banana', + toolRequest: { + name: 'foo', + ref: 'abc', + input: 'banana', + }, }; const toolCall2 = { - name: 'bar', - ref: 'bcd', - input: 'apple', + toolRequest: { + name: 'bar', + ref: 'bcd', + input: 'apple', + }, }; const response = new GenerateResponse({ candidates: [ @@ -293,7 +299,7 @@ describe('GenerateResponse', () => { finishReason: 'stop', message: { role: 'model', - content: [{ toolRequest: toolCall1 }, { toolRequest: toolCall2 }], + content: [toolCall1, toolCall2], }, }, {