Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions js/ai/src/generate.ts
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ import {
ModelReference,
Part,
Role,
ToolRequestPart,
ToolResponsePart,
} from './model.js';
import * as telemetry from './telemetry.js';
Expand Down Expand Up @@ -104,6 +105,16 @@ export class Message<T = unknown> implements MessageData {
return this.content.find((part) => part.data)?.data as T | null;
}

/**
* Returns all tool request found in this message.
* @returns Array of all tool request found in this message.
*/
toolRequests(): ToolRequestPart[] {
return this.content.filter(
(part) => !!part.toolRequest
) as ToolRequestPart[];
}

/**
* Converts the Message to a plain JS object.
* @returns Plain JS object representing the data contained in the message.
Expand Down Expand Up @@ -183,6 +194,14 @@ export class Candidate<O = unknown> implements CandidateData {
return this.message.data();
}

/**
* Returns all tool request found in this candidate.
* @returns Array of all tool request found in this candidate.
*/
toolRequests(): ToolRequestPart[] {
return this.message.toolRequests();
}

/**
* Determine whether this candidate has output that conforms to a provided schema.
*
Expand Down Expand Up @@ -290,6 +309,15 @@ export class GenerateResponse<O = unknown> implements GenerateResponseData {
return this.candidates[index]?.data() || null;
}

/**
* Returns all tool request found in the candidate.
* @param index The candidate index from which to extract tool requests, defaults to first candidate.
* @returns Array of all tool request found in the candidate.
*/
toolRequests(index: number = 0): ToolRequestPart[] {
return this.candidates[index].toolRequests();
}

/**
* Appends the message generated by the selected candidate to the messages already
* present in the generation request. The result of this method can be safely
Expand Down Expand Up @@ -361,6 +389,16 @@ export class GenerateResponseChunk<T = unknown>
return this.content.find((part) => part.data)?.data as T | null;
}

/**
* Returns all tool request found in this chunk.
* @returns Array of all tool request found in this chunk.
*/
toolRequests(): ToolRequestPart[] {
return this.content.filter(
(part) => !!part.toolRequest
) as ToolRequestPart[];
}

toJSON(): GenerateResponseChunkData {
return { index: this.index, content: this.content, custom: this.custom };
}
Expand Down
82 changes: 82 additions & 0 deletions js/ai/tests/generate/generate_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,88 @@ describe('GenerateResponse', () => {
assert.deepStrictEqual(response.output(0), { abc: '123' });
});
});
describe('#toolRequests()', () => {
it('returns empty array if no tools requests found', () => {
const response = new GenerateResponse({
candidates: [
{
index: 0,
finishReason: 'stop',
message: { role: 'model', content: [{ text: '{"abc":"123"}' }] },
},
{
index: 0,
finishReason: 'stop',
message: { role: 'model', content: [{ text: '{"abc":123}' }] },
},
],
});
assert.deepStrictEqual(response.toolRequests(), []);
assert.deepStrictEqual(response.toolRequests(0), []);
});
it('picks the first candidate if no index provided', () => {
const toolCall = {
toolRequest: {
name: 'foo',
ref: 'abc',
input: 'banana',
},
};
const response = new GenerateResponse({
candidates: [
{
index: 0,
finishReason: 'stop',
message: {
role: 'model',
content: [toolCall],
},
},
{
index: 0,
finishReason: 'stop',
message: { role: 'model', content: [{ text: '{"abc":123}' }] },
},
],
});
assert.deepStrictEqual(response.toolRequests(), [toolCall]);
assert.deepStrictEqual(response.toolRequests(0), [toolCall]);
});
it('returns all tool call', () => {
const toolCall1 = {
toolRequest: {
name: 'foo',
ref: 'abc',
input: 'banana',
},
};
const toolCall2 = {
toolRequest: {
name: 'bar',
ref: 'bcd',
input: 'apple',
},
};
const response = new GenerateResponse({
candidates: [
{
index: 0,
finishReason: 'stop',
message: {
role: 'model',
content: [toolCall1, toolCall2],
},
},
{
index: 0,
finishReason: 'stop',
message: { role: 'model', content: [{ text: '{"abc":123}' }] },
},
],
});
assert.deepStrictEqual(response.toolRequests(), [toolCall1, toolCall2]);
});
});
});

describe('toGenerateRequest', () => {
Expand Down