Skip to content

Commit

Permalink
make test() accessor handle mixed-parts responses
Browse files Browse the repository at this point in the history
  • Loading branch information
hsubox76 committed May 7, 2024
1 parent 4c83b34 commit 8e0fef7
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 38 deletions.
121 changes: 88 additions & 33 deletions packages/vertexai/src/requests/response-helpers.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -40,54 +40,85 @@ const fakeResponseText: GenerateContentResponse = {
]
};

const functionCallPart1 = {
functionCall: {
name: 'find_theaters',
args: {
location: 'Mountain View, CA',
movie: 'Barbie'
}
}
};

const functionCallPart2 = {
functionCall: {
name: 'find_times',
args: {
location: 'Mountain View, CA',
movie: 'Barbie',
time: '20:00'
}
}
};

const fakeResponseFunctionCall: GenerateContentResponse = {
candidates: [
{
index: 0,
content: {
role: 'model',
parts: [
{
functionCall: {
name: 'find_theaters',
args: {
location: 'Mountain View, CA',
movie: 'Barbie'
}
}
}
]
parts: [functionCallPart1]
}
}
]
};

const fakeResponseFunctionCalls: GenerateContentResponse = {
candidates: [
{
index: 0,
content: {
role: 'model',
parts: [functionCallPart1, functionCallPart2]
}
}
]
};

const fakeResponseMixed1: GenerateContentResponse = {
candidates: [
{
index: 0,
content: {
role: 'model',
parts: [{ text: 'some text' }, functionCallPart2]
}
}
]
};

const fakeResponseMixed2: GenerateContentResponse = {
candidates: [
{
index: 0,
content: {
role: 'model',
parts: [functionCallPart1, { text: 'some text' }]
}
}
]
};

const fakeResponseMixed3: GenerateContentResponse = {
candidates: [
{
index: 0,
content: {
role: 'model',
parts: [
{
functionCall: {
name: 'find_theaters',
args: {
location: 'Mountain View, CA',
movie: 'Barbie'
}
}
},
{
functionCall: {
name: 'find_times',
args: {
location: 'Mountain View, CA',
movie: 'Barbie',
time: '20:00'
}
}
}
{ text: 'some text' },
functionCallPart1,
{ text: ' and more text' }
]
}
}
Expand All @@ -109,19 +140,43 @@ describe('response-helpers methods', () => {
it('good response text', async () => {
const enhancedResponse = addHelpers(fakeResponseText);
expect(enhancedResponse.text()).to.equal('Some text and some more text');
expect(enhancedResponse.functionCalls()).to.be.undefined;
});
it('good response functionCall', async () => {
const enhancedResponse = addHelpers(fakeResponseFunctionCall);
expect(enhancedResponse.text()).to.equal('');
expect(enhancedResponse.functionCalls()).to.deep.equal([
fakeResponseFunctionCall.candidates?.[0].content.parts[0].functionCall
functionCallPart1.functionCall
]);
});
it('good response functionCalls', async () => {
const enhancedResponse = addHelpers(fakeResponseFunctionCalls);
expect(enhancedResponse.text()).to.equal('');
expect(enhancedResponse.functionCalls()).to.deep.equal([
functionCallPart1.functionCall,
functionCallPart2.functionCall
]);
});
it('good response text/functionCall', async () => {
const enhancedResponse = addHelpers(fakeResponseMixed1);
expect(enhancedResponse.functionCalls()).to.deep.equal([
functionCallPart2.functionCall
]);
expect(enhancedResponse.text()).to.equal('some text');
});
it('good response functionCall/text', async () => {
const enhancedResponse = addHelpers(fakeResponseMixed2);
expect(enhancedResponse.functionCalls()).to.deep.equal([
functionCallPart1.functionCall
]);
expect(enhancedResponse.text()).to.equal('some text');
});
it('good response text/functionCall/text', async () => {
const enhancedResponse = addHelpers(fakeResponseMixed3);
expect(enhancedResponse.functionCalls()).to.deep.equal([
fakeResponseFunctionCalls.candidates?.[0].content.parts[0].functionCall,
fakeResponseFunctionCalls.candidates?.[0].content.parts[1].functionCall
functionCallPart1.functionCall
]);
expect(enhancedResponse.text()).to.equal('some text and more text');
});
it('bad response safety', async () => {
const enhancedResponse = addHelpers(badFakeResponse);
Expand Down
16 changes: 11 additions & 5 deletions packages/vertexai/src/requests/response-helpers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -85,13 +85,19 @@ export function addHelpers(
}

/**
* Returns text of first candidate.
* Returns all text found in all parts of first candidate.
*/
export function getText(response: GenerateContentResponse): string {
if (response.candidates?.[0].content?.parts?.[0]?.text) {
return response.candidates[0].content.parts
.map(({ text }) => text)
.join('');
const textStrings = [];
if (response.candidates?.[0].content?.parts) {
for (const part of response.candidates?.[0].content?.parts) {
if (part.text) {
textStrings.push(part.text);
}
}
}
if (textStrings.length > 0) {
return textStrings.join('');
} else {
return '';
}
Expand Down

0 comments on commit 8e0fef7

Please sign in to comment.