Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add systemInstruction and toolConfig #8146

Merged
merged 4 commits into from
Apr 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
8 changes: 6 additions & 2 deletions packages/vertexai/src/methods/chat-session-helpers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,17 @@ const VALID_PART_FIELDS: Array<keyof Part> = [
const VALID_PARTS_PER_ROLE: { [key in Role]: Array<keyof Part> } = {
user: ['text', 'inlineData'],
function: ['functionResponse'],
model: ['text', 'functionCall']
model: ['text', 'functionCall'],
// System instructions shouldn't be in history anyway.
system: ['text']
};

const VALID_PREVIOUS_CONTENT_ROLES: { [key in Role]: Role[] } = {
user: ['model'],
function: ['model'],
model: ['user', 'function']
model: ['user', 'function'],
// System instructions shouldn't be in history.
system: []
};

export function validateChatHistory(history: Content[]): void {
Expand Down
4 changes: 4 additions & 0 deletions packages/vertexai/src/methods/chat-session.ts
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,8 @@ export class ChatSession {
safetySettings: this.params?.safetySettings,
generationConfig: this.params?.generationConfig,
tools: this.params?.tools,
toolConfig: this.params?.toolConfig,
systemInstruction: this.params?.systemInstruction,
contents: [...this._history, newContent]
};
let finalResult = {} as GenerateContentResult;
Expand Down Expand Up @@ -135,6 +137,8 @@ export class ChatSession {
safetySettings: this.params?.safetySettings,
generationConfig: this.params?.generationConfig,
tools: this.params?.tools,
toolConfig: this.params?.toolConfig,
systemInstruction: this.params?.systemInstruction,
contents: [...this._history, newContent]
};
const streamPromise = generateContentStream(
Expand Down
4 changes: 3 additions & 1 deletion packages/vertexai/src/methods/generate-content.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import * as request from '../requests/request';
import { generateContent } from './generate-content';
import {
GenerateContentRequest,
HarmBlockMethod,
HarmBlockThreshold,
HarmCategory
} from '../types';
Expand All @@ -47,7 +48,8 @@ const fakeRequestParams: GenerateContentRequest = {
safetySettings: [
{
category: HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
threshold: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE
threshold: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
method: HarmBlockMethod.SEVERITY
}
]
};
Expand Down
163 changes: 161 additions & 2 deletions packages/vertexai/src/models/generative-model.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,15 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import { expect } from 'chai';
import { use, expect } from 'chai';
import { GenerativeModel } from './generative-model';
import { VertexAI } from '../public-types';
import { FunctionCallingMode, VertexAI } from '../public-types';
import * as request from '../requests/request';
import { match, restore, stub } from 'sinon';
import { getMockResponse } from '../../test-utils/mock-response';
import sinonChai from 'sinon-chai';

use(sinonChai);

const fakeVertexAI: VertexAI = {
app: {
Expand Down Expand Up @@ -53,4 +59,157 @@ describe('GenerativeModel', () => {
});
expect(genModel.model).to.equal('tunedModels/my-model');
});
it('passes params through to generateContent', async () => {
const genModel = new GenerativeModel(fakeVertexAI, {
model: 'my-model',
tools: [{ functionDeclarations: [{ name: 'myfunc' }] }],
toolConfig: { functionCallingConfig: { mode: FunctionCallingMode.NONE } },
systemInstruction: { role: 'system', parts: [{ text: 'be friendly' }] }
});
expect(genModel.tools?.length).to.equal(1);
expect(genModel.toolConfig?.functionCallingConfig.mode).to.equal(
FunctionCallingMode.NONE
);
expect(genModel.systemInstruction?.parts[0].text).to.equal('be friendly');
const mockResponse = getMockResponse(
'unary-success-basic-reply-short.json'
);
const makeRequestStub = stub(request, 'makeRequest').resolves(
mockResponse as Response
);
await genModel.generateContent('hello');
expect(makeRequestStub).to.be.calledWith(
'publishers/google/models/my-model',
request.Task.GENERATE_CONTENT,
match.any,
false,
match((value: string) => {
return (
value.includes('myfunc') &&
value.includes(FunctionCallingMode.NONE) &&
value.includes('be friendly')
);
}),
{}
);
restore();
});
it('generateContent overrides model values', async () => {
const genModel = new GenerativeModel(fakeVertexAI, {
model: 'my-model',
tools: [{ functionDeclarations: [{ name: 'myfunc' }] }],
toolConfig: { functionCallingConfig: { mode: FunctionCallingMode.NONE } },
systemInstruction: { role: 'system', parts: [{ text: 'be friendly' }] }
});
expect(genModel.tools?.length).to.equal(1);
expect(genModel.toolConfig?.functionCallingConfig.mode).to.equal(
FunctionCallingMode.NONE
);
expect(genModel.systemInstruction?.parts[0].text).to.equal('be friendly');
const mockResponse = getMockResponse(
'unary-success-basic-reply-short.json'
);
const makeRequestStub = stub(request, 'makeRequest').resolves(
mockResponse as Response
);
await genModel.generateContent({
contents: [{ role: 'user', parts: [{ text: 'hello' }] }],
tools: [{ functionDeclarations: [{ name: 'otherfunc' }] }],
toolConfig: { functionCallingConfig: { mode: FunctionCallingMode.AUTO } },
systemInstruction: { role: 'system', parts: [{ text: 'be formal' }] }
});
expect(makeRequestStub).to.be.calledWith(
'publishers/google/models/my-model',
request.Task.GENERATE_CONTENT,
match.any,
false,
match((value: string) => {
return (
value.includes('otherfunc') &&
value.includes(FunctionCallingMode.AUTO) &&
value.includes('be formal')
);
}),
{}
);
restore();
});
it('passes params through to chat.sendMessage', async () => {
const genModel = new GenerativeModel(fakeVertexAI, {
model: 'my-model',
tools: [{ functionDeclarations: [{ name: 'myfunc' }] }],
toolConfig: { functionCallingConfig: { mode: FunctionCallingMode.NONE } },
systemInstruction: { role: 'system', parts: [{ text: 'be friendly' }] }
});
expect(genModel.tools?.length).to.equal(1);
expect(genModel.toolConfig?.functionCallingConfig.mode).to.equal(
FunctionCallingMode.NONE
);
expect(genModel.systemInstruction?.parts[0].text).to.equal('be friendly');
const mockResponse = getMockResponse(
'unary-success-basic-reply-short.json'
);
const makeRequestStub = stub(request, 'makeRequest').resolves(
mockResponse as Response
);
await genModel.startChat().sendMessage('hello');
expect(makeRequestStub).to.be.calledWith(
'publishers/google/models/my-model',
request.Task.GENERATE_CONTENT,
match.any,
false,
match((value: string) => {
return (
value.includes('myfunc') &&
value.includes(FunctionCallingMode.NONE) &&
value.includes('be friendly')
);
}),
{}
);
restore();
});
it('startChat overrides model values', async () => {
const genModel = new GenerativeModel(fakeVertexAI, {
model: 'my-model',
tools: [{ functionDeclarations: [{ name: 'myfunc' }] }],
toolConfig: { functionCallingConfig: { mode: FunctionCallingMode.NONE } },
systemInstruction: { role: 'system', parts: [{ text: 'be friendly' }] }
});
expect(genModel.tools?.length).to.equal(1);
expect(genModel.toolConfig?.functionCallingConfig.mode).to.equal(
FunctionCallingMode.NONE
);
expect(genModel.systemInstruction?.parts[0].text).to.equal('be friendly');
const mockResponse = getMockResponse(
'unary-success-basic-reply-short.json'
);
const makeRequestStub = stub(request, 'makeRequest').resolves(
mockResponse as Response
);
await genModel
.startChat({
tools: [{ functionDeclarations: [{ name: 'otherfunc' }] }],
toolConfig: {
functionCallingConfig: { mode: FunctionCallingMode.AUTO }
},
systemInstruction: { role: 'system', parts: [{ text: 'be formal' }] }
})
.sendMessage('hello');
expect(makeRequestStub).to.be.calledWith(
'publishers/google/models/my-model',
request.Task.GENERATE_CONTENT,
match.any,
false,
match((value: string) => {
return (
value.includes('otherfunc') &&
value.includes(FunctionCallingMode.AUTO) &&
value.includes('be formal')
);
}),
{}
);
restore();
});
});
14 changes: 13 additions & 1 deletion packages/vertexai/src/models/generative-model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import {
generateContentStream
} from '../methods/generate-content';
import {
Content,
CountTokensRequest,
CountTokensResponse,
GenerateContentRequest,
Expand All @@ -31,7 +32,8 @@ import {
RequestOptions,
SafetySetting,
StartChatParams,
Tool
Tool,
ToolConfig
} from '../types';
import { ChatSession } from '../methods/chat-session';
import { countTokens } from '../methods/count-tokens';
Expand All @@ -52,6 +54,8 @@ export class GenerativeModel {
safetySettings: SafetySetting[];
requestOptions?: RequestOptions;
tools?: Tool[];
toolConfig?: ToolConfig;
systemInstruction?: Content;

constructor(
vertexAI: VertexAI,
Expand Down Expand Up @@ -88,6 +92,8 @@ export class GenerativeModel {
this.generationConfig = modelParams.generationConfig || {};
this.safetySettings = modelParams.safetySettings || [];
this.tools = modelParams.tools;
this.toolConfig = modelParams.toolConfig;
this.systemInstruction = modelParams.systemInstruction;
this.requestOptions = requestOptions || {};
}

Expand All @@ -106,6 +112,8 @@ export class GenerativeModel {
generationConfig: this.generationConfig,
safetySettings: this.safetySettings,
tools: this.tools,
toolConfig: this.toolConfig,
systemInstruction: this.systemInstruction,
...formattedParams
},
this.requestOptions
Expand All @@ -129,6 +137,8 @@ export class GenerativeModel {
generationConfig: this.generationConfig,
safetySettings: this.safetySettings,
tools: this.tools,
toolConfig: this.toolConfig,
systemInstruction: this.systemInstruction,
...formattedParams
},
this.requestOptions
Expand All @@ -145,6 +155,8 @@ export class GenerativeModel {
this.model,
{
tools: this.tools,
toolConfig: this.toolConfig,
systemInstruction: this.systemInstruction,
...startChatParams
},
this.requestOptions
Expand Down
21 changes: 20 additions & 1 deletion packages/vertexai/src/types/enums.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ export type Role = (typeof POSSIBLE_ROLES)[number];
* Possible roles.
* @public
*/
export const POSSIBLE_ROLES = ['user', 'model', 'function'] as const;
export const POSSIBLE_ROLES = ['user', 'model', 'function', 'system'] as const;

/**
* Harm categories that would cause prompts or candidates to be blocked.
Expand Down Expand Up @@ -133,3 +133,22 @@ export enum FinishReason {
// Unknown reason.
OTHER = 'OTHER'
}

/**
* @public
*/
export enum FunctionCallingMode {
// Unspecified function calling mode. This value should not be used.
MODE_UNSPECIFIED = 'MODE_UNSPECIFIED',
// Default model behavior, model decides to predict either a function call
// or a natural language repspose.
AUTO = 'AUTO',
// Model is constrained to always predicting a function call only.
// If "allowed_function_names" is set, the predicted function call will be
// limited to any one of "allowed_function_names", else the predicted
// function call will be any one of the provided "function_declarations".
ANY = 'ANY',
// Model will not predict any function call. Model behavior is same as when
// not passing any function declarations.
NONE = 'NONE'
}
29 changes: 28 additions & 1 deletion packages/vertexai/src/types/requests.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,12 @@
*/

import { Content } from './content';
import { HarmBlockMethod, HarmBlockThreshold, HarmCategory } from './enums';
import {
FunctionCallingMode,
HarmBlockMethod,
HarmBlockThreshold,
HarmCategory
} from './enums';

/**
* Base parameters for a number of methods.
Expand All @@ -34,6 +39,8 @@ export interface BaseParams {
export interface ModelParams extends BaseParams {
model: string;
tools?: Tool[];
toolConfig?: ToolConfig;
systemInstruction?: Content;
}

/**
Expand All @@ -43,6 +50,8 @@ export interface ModelParams extends BaseParams {
export interface GenerateContentRequest extends BaseParams {
contents: Content[];
tools?: Tool[];
toolConfig?: ToolConfig;
systemInstruction?: Content;
}

/**
Expand Down Expand Up @@ -77,6 +86,8 @@ export interface GenerationConfig {
export interface StartChatParams extends BaseParams {
history?: Content[];
tools?: Tool[];
toolConfig?: ToolConfig;
systemInstruction?: Content;
}

/**
Expand Down Expand Up @@ -220,3 +231,19 @@ export interface FunctionDeclarationSchemaProperty {
/** Optional. The example of the property. */
example?: unknown;
}

/**
* Tool config. This config is shared for all tools provided in the request.
* @public
*/
export interface ToolConfig {
functionCallingConfig: FunctionCallingConfig;
}

/**
* @public
*/
export interface FunctionCallingConfig {
mode?: FunctionCallingMode;
allowedFunctionNames?: string[];
}