From 518b54e1acd637ff96739872ba9b002e654fc540 Mon Sep 17 00:00:00 2001 From: Pavel Jbanov Date: Mon, 27 Jan 2025 12:41:04 -0500 Subject: [PATCH 1/7] feat(js/plugins/vertexai): introduced gemini model ref helper and ability to register versions --- js/plugins/vertexai/src/common/types.ts | 9 +- js/plugins/vertexai/src/gemini.ts | 130 +++++++++++++++-- js/plugins/vertexai/src/index.ts | 30 +++- .../vertexai/src/modelgarden/mistral.ts | 2 + js/plugins/vertexai/tests/plugin_test.ts | 137 ++++++++++++++++++ 5 files changed, 291 insertions(+), 17 deletions(-) create mode 100644 js/plugins/vertexai/tests/plugin_test.ts diff --git a/js/plugins/vertexai/src/common/types.ts b/js/plugins/vertexai/src/common/types.ts index a88fef9e68..642f5980a3 100644 --- a/js/plugins/vertexai/src/common/types.ts +++ b/js/plugins/vertexai/src/common/types.ts @@ -14,7 +14,9 @@ * limitations under the License. */ +import { ModelReference } from 'genkit'; import { GoogleAuthOptions } from 'google-auth-library'; +import { GeminiConfigSchema } from '../gemini'; /** Common options for Vertex AI plugin configuration */ export interface CommonPluginOptions { @@ -27,4 +29,9 @@ export interface CommonPluginOptions { } /** Combined plugin options, extending common options with subplugin-specific options */ -export interface PluginOptions extends CommonPluginOptions {} +export interface PluginOptions extends CommonPluginOptions { + models?: ( + | ModelReference + | string + )[]; +} diff --git a/js/plugins/vertexai/src/gemini.ts b/js/plugins/vertexai/src/gemini.ts index 9281173b47..1c8a81bc6f 100644 --- a/js/plugins/vertexai/src/gemini.ts +++ b/js/plugins/vertexai/src/gemini.ts @@ -39,6 +39,7 @@ import { MediaPart, MessageData, ModelAction, + ModelInfo, ModelMiddleware, ModelReference, Part, @@ -166,6 +167,69 @@ export const GeminiConfigSchema = GenerationCommonConfigSchema.extend({ .optional(), }); +/** + * Known model names, to allow code completion for convenience. Allows other model names. 
+ */ +export type GeminiVersionString = + | keyof typeof SUPPORTED_GEMINI_MODELS + | (string & {}); + +/** + * Returns a reference to a model that can be used in generate calls. + * + * ```js + * await ai.generate({ + * prompt: 'hi', + * model: gemini('gemini-1.5-flash') + * }); + * ``` + */ +export function gemini( + version: GeminiVersionString, + options: GeminiConfig = {} +): ModelReference { + const nearestModel = nearestGeminiModelRef(version); + return modelRef({ + name: `vertexai/${version}`, + config: options, + configSchema: GeminiConfigSchema, + info: { + ...nearestModel.info, + // If exact suffix match for a known model, use its label, otherwise create a new label + label: nearestModel.name.endsWith(version) + ? nearestModel.info?.label + : `Vertex AI - ${version}`, + }, + }); +} + +function nearestGeminiModelRef( + version: GeminiVersionString, + options: GeminiConfig = {} +): ModelReference { + const matchingKey = longestMatchingPrefix( + version, + Object.keys(SUPPORTED_GEMINI_MODELS) + ); + if (matchingKey) { + return SUPPORTED_GEMINI_MODELS[matchingKey].withConfig({ + ...options, + version, + }); + } + return GENERIC_GEMINI_MODEL.withConfig({ ...options, version }); +} + +function longestMatchingPrefix(version: string, potentialMatches: string[]) { + return potentialMatches + .filter((p) => version.startsWith(p)) + .reduce( + (longest, current) => + current.length > longest.length ? current : longest, + '' + ); +} + /** * Gemini model configuration options. 
* @@ -268,6 +332,21 @@ export const gemini20FlashExp = modelRef({ configSchema: GeminiConfigSchema, }); +export const GENERIC_GEMINI_MODEL = modelRef({ + name: 'vertexai/gemini', + configSchema: GeminiConfigSchema, + info: { + label: 'Google Gemini', + supports: { + multiturn: true, + media: true, + tools: true, + toolChoice: true, + systemRole: true, + }, + }, +}); + export const SUPPORTED_V1_MODELS = { 'gemini-1.0-pro': gemini10Pro, }; @@ -281,11 +360,11 @@ export const SUPPORTED_V15_MODELS = { export const SUPPORTED_GEMINI_MODELS = { ...SUPPORTED_V1_MODELS, ...SUPPORTED_V15_MODELS, -}; +} as const; function toGeminiRole( role: MessageData['role'], - model?: ModelReference + modelInfo?: ModelInfo ): string { switch (role) { case 'user': @@ -293,7 +372,7 @@ function toGeminiRole( case 'model': return 'model'; case 'system': - if (model && SUPPORTED_V15_MODELS[model.name]) { + if (modelInfo && modelInfo.supports?.systemRole) { // We should have already pulled out the supported system messages, // anything remaining is unsupported; throw an error. throw new Error( @@ -387,10 +466,10 @@ export function toGeminiSystemInstruction(message: MessageData): Content { export function toGeminiMessage( message: MessageData, - model?: ModelReference + modelInfo?: ModelInfo ): Content { return { - role: toGeminiRole(message.role, model), + role: toGeminiRole(message.role, modelInfo), parts: message.content.map(toGeminiPart), }; } @@ -581,7 +660,7 @@ export function cleanSchema(schema: JSONSchema): JSONSchema { /** * Define a Vertex AI Gemini model. 
*/ -export function defineGeminiModel( +export function defineGeminiKnownModel( ai: Genkit, name: string, vertexClientFactory: ( @@ -594,11 +673,34 @@ export function defineGeminiModel( const model: ModelReference = SUPPORTED_GEMINI_MODELS[name]; if (!model) throw new Error(`Unsupported model: ${name}`); + return defineGeminiModel( + ai, + modelName, + name, + model?.info, + vertexClientFactory, + options + ); +} + +/** + * Define a Vertex AI Gemini model. + */ +export function defineGeminiModel( + ai: Genkit, + modelName: string, + version: string, + modelInfo: ModelInfo | undefined, + vertexClientFactory: ( + request: GenerateRequest + ) => VertexAI, + options: PluginOptions +): ModelAction { const middlewares: ModelMiddleware[] = []; - if (SUPPORTED_V1_MODELS[name]) { + if (!modelInfo?.supports?.systemRole) { middlewares.push(simulateSystemPrompt()); } - if (model?.info?.supports?.media) { + if (modelInfo?.supports?.media) { // the gemini api doesn't support downloading media from http(s) middlewares.push(downloadRequestMedia({ maxBytes: 1024 * 1024 * 20 })); } @@ -606,7 +708,7 @@ export function defineGeminiModel( return ai.defineModel( { name: modelName, - ...model.info, + ...modelInfo, configSchema: GeminiConfigSchema, use: middlewares, }, @@ -619,7 +721,7 @@ export function defineGeminiModel( // Handle system instructions separately let systemInstruction: Content | undefined = undefined; - if (SUPPORTED_V15_MODELS[name]) { + if (modelInfo?.supports?.systemRole) { const systemMessage = messages.find((m) => m.role === 'system'); if (systemMessage) { messages.splice(messages.indexOf(systemMessage), 1); @@ -659,7 +761,7 @@ export function defineGeminiModel( toolConfig, history: messages .slice(0, -1) - .map((message) => toGeminiMessage(message, model)), + .map((message) => toGeminiMessage(message, modelInfo)), generationConfig: { candidateCount: request.candidates || undefined, temperature: request.config?.temperature, @@ -673,9 +775,7 @@ export function 
defineGeminiModel( }; // Handle cache - const modelVersion = (request.config?.version || - model.version || - name) as string; + const modelVersion = (request.config?.version || version) as string; const cacheConfigDetails = extractCacheConfig(request); const apiClient = new ApiClient( @@ -727,7 +827,7 @@ export function defineGeminiModel( }); } - const msg = toGeminiMessage(messages[messages.length - 1], model); + const msg = toGeminiMessage(messages[messages.length - 1], modelInfo); if (cache) { genModel = vertex.preview.getGenerativeModelFromCachedContent( diff --git a/js/plugins/vertexai/src/index.ts b/js/plugins/vertexai/src/index.ts index 18114488a2..47de7c9a02 100644 --- a/js/plugins/vertexai/src/index.ts +++ b/js/plugins/vertexai/src/index.ts @@ -35,7 +35,9 @@ import { } from './embedder.js'; import { SUPPORTED_GEMINI_MODELS, + defineGeminiKnownModel, defineGeminiModel, + gemini, gemini10Pro, gemini15Flash, gemini15Pro, @@ -51,6 +53,7 @@ import { } from './imagen.js'; export { type PluginOptions } from './common/types.js'; export { + gemini, gemini10Pro, gemini15Flash, gemini15Pro, @@ -78,8 +81,33 @@ export function vertexAI(options?: PluginOptions): GenkitPlugin { imagenModel(ai, name, authClient, { projectId, location }) ); Object.keys(SUPPORTED_GEMINI_MODELS).map((name) => - defineGeminiModel(ai, name, vertexClientFactory, { projectId, location }) + defineGeminiKnownModel(ai, name, vertexClientFactory, { + projectId, + location, + }) ); + if (options?.models) { + for (const modelOrRef of options?.models) { + const modelName = + typeof modelOrRef === 'string' + ? modelOrRef + : // strip out the `vertexai/` prefix + modelOrRef.name.split('/')[1]; + const modelRef = + typeof modelOrRef === 'string' ? 
gemini(modelOrRef) : modelOrRef; + defineGeminiModel( + ai, + modelRef.name, + modelName, + modelRef.info, + vertexClientFactory, + { + projectId, + location, + } + ); + } + } Object.keys(SUPPORTED_EMBEDDER_MODELS).map((name) => defineVertexAIEmbedder(ai, name, authClient, { projectId, location }) diff --git a/js/plugins/vertexai/src/modelgarden/mistral.ts b/js/plugins/vertexai/src/modelgarden/mistral.ts index 05c04f8a4e..b2aac4dde9 100644 --- a/js/plugins/vertexai/src/modelgarden/mistral.ts +++ b/js/plugins/vertexai/src/modelgarden/mistral.ts @@ -124,6 +124,8 @@ function toMistralRole(role: Role): MistralRole { return 'tool'; case 'system': return 'system'; + default: + throw new Error(`Unknwon role ${role}`); } } function toMistralToolRequest(toolRequest: Record): FunctionCall { diff --git a/js/plugins/vertexai/tests/plugin_test.ts b/js/plugins/vertexai/tests/plugin_test.ts new file mode 100644 index 0000000000..e8bdda6ab9 --- /dev/null +++ b/js/plugins/vertexai/tests/plugin_test.ts @@ -0,0 +1,137 @@ +/** + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +import * as assert from 'assert'; +import { genkit } from 'genkit'; +import { ModelInfo } from 'genkit/model'; +import { describe, it } from 'node:test'; +import { GENERIC_GEMINI_MODEL, gemini } from '../src/gemini.js'; +import vertexAI, { gemini15Flash, gemini15Pro } from '../src/index.js'; + +describe('plugin', () => { + it('should init the plugin without requiring the api key', async () => { + const ai = genkit({ + plugins: [vertexAI()], + }); + + assert.ok(ai); + }); + + it('should pre-register a few flagship models', async () => { + const ai = genkit({ + plugins: [vertexAI()], + }); + + assert.ok(await ai.registry.lookupAction(`/model/${gemini15Flash.name}`)); + assert.ok(await ai.registry.lookupAction(`/model/${gemini15Pro.name}`)); + }); + + it('allow referencing models using `gemini` helper', async () => { + const ai = genkit({ + plugins: [vertexAI()], + }); + + const pro = await ai.registry.lookupAction( + `/model/${gemini('gemini-1.5-pro').name}` + ); + assert.ok(pro); + assert.strictEqual(pro.__action.name, 'vertexai/gemini-1.5-pro'); + const flash = await ai.registry.lookupAction( + `/model/${gemini('gemini-1.5-flash').name}` + ); + assert.ok(flash); + assert.strictEqual(flash.__action.name, 'vertexai/gemini-1.5-flash'); + }); + + it('references explicitly registered models', async () => { + const flash002Ref = gemini('gemini-1.5-flash-002'); + const ai = genkit({ + plugins: [ + vertexAI({ + location: 'us-central1', + models: ['gemini-1.5-pro-002', flash002Ref, 'gemini-4.0-banana'], + }), + ], + }); + + const pro002Ref = gemini('gemini-1.5-pro-002'); + assert.strictEqual(pro002Ref.name, 'vertexai/gemini-1.5-pro-002'); + assertEqualModelInfo( + pro002Ref.info!, + 'Google AI - gemini-1.5-pro-002', + gemini15Pro.info! 
+ ); + const pro002 = await ai.registry.lookupAction(`/model/${pro002Ref.name}`); + assert.ok(pro002); + assert.strictEqual(pro002.__action.name, 'vertexai/gemini-1.5-pro-002'); + assertEqualModelInfo( + pro002.__action.metadata?.model, + 'Google AI - gemini-1.5-pro-002', + gemini15Pro.info! + ); + + assert.strictEqual(flash002Ref.name, 'vertexai/gemini-1.5-flash-002'); + assertEqualModelInfo( + flash002Ref.info!, + 'Google AI - gemini-1.5-flash-002', + gemini15Flash.info! + ); + const flash002 = await ai.registry.lookupAction( + `/model/${flash002Ref.name}` + ); + assert.ok(flash002); + assert.strictEqual(flash002.__action.name, 'vertexai/gemini-1.5-flash-002'); + assertEqualModelInfo( + flash002.__action.metadata?.model, + 'Google AI - gemini-1.5-flash-002', + gemini15Flash.info! + ); + + const bananaRef = gemini('gemini-4.0-banana'); + assert.strictEqual(bananaRef.name, 'vertexai/gemini-4.0-banana'); + assertEqualModelInfo( + bananaRef.info!, + 'Google AI - gemini-4.0-banana', + GENERIC_GEMINI_MODEL.info! // <---- generic model fallback + ); + const banana = await ai.registry.lookupAction(`/model/${bananaRef.name}`); + assert.ok(banana); + assert.strictEqual(banana.__action.name, 'vertexai/gemini-4.0-banana'); + assertEqualModelInfo( + banana.__action.metadata?.model, + 'Google AI - gemini-4.0-banana', + GENERIC_GEMINI_MODEL.info! 
// <---- generic model fallback + ); + + // this one is not registered + const flash003Ref = gemini('gemini-1.5-flash-003'); + assert.strictEqual(flash003Ref.name, 'vertexai/gemini-1.5-flash-003'); + const flash003 = await ai.registry.lookupAction( + `/model/${flash003Ref.name}` + ); + assert.ok(flash003 === undefined); + }); +}); + +function assertEqualModelInfo( + modelAction: ModelInfo, + expectedLabel: string, + expectedInfo: ModelInfo +) { + assert.strictEqual(modelAction.label, expectedLabel); + assert.deepStrictEqual(modelAction.supports, expectedInfo.supports); + assert.deepStrictEqual(modelAction.versions, expectedInfo.versions); +} From d0b6559a02f932556e9b4ef04d64aa3ab5e34835 Mon Sep 17 00:00:00 2001 From: Pavel Jbanov Date: Mon, 27 Jan 2025 13:59:47 -0500 Subject: [PATCH 2/7] fix tests --- js/plugins/vertexai/tests/plugin_test.ts | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/js/plugins/vertexai/tests/plugin_test.ts b/js/plugins/vertexai/tests/plugin_test.ts index e8bdda6ab9..c78ab5ad67 100644 --- a/js/plugins/vertexai/tests/plugin_test.ts +++ b/js/plugins/vertexai/tests/plugin_test.ts @@ -71,7 +71,7 @@ describe('plugin', () => { assert.strictEqual(pro002Ref.name, 'vertexai/gemini-1.5-pro-002'); assertEqualModelInfo( pro002Ref.info!, - 'Google AI - gemini-1.5-pro-002', + 'Vertex AI - gemini-1.5-pro-002', gemini15Pro.info! ); const pro002 = await ai.registry.lookupAction(`/model/${pro002Ref.name}`); @@ -79,14 +79,14 @@ describe('plugin', () => { assert.strictEqual(pro002.__action.name, 'vertexai/gemini-1.5-pro-002'); assertEqualModelInfo( pro002.__action.metadata?.model, - 'Google AI - gemini-1.5-pro-002', + 'Vertex AI - gemini-1.5-pro-002', gemini15Pro.info! ); assert.strictEqual(flash002Ref.name, 'vertexai/gemini-1.5-flash-002'); assertEqualModelInfo( flash002Ref.info!, - 'Google AI - gemini-1.5-flash-002', + 'Vertex AI - gemini-1.5-flash-002', gemini15Flash.info! 
); const flash002 = await ai.registry.lookupAction( @@ -96,7 +96,7 @@ describe('plugin', () => { assert.strictEqual(flash002.__action.name, 'vertexai/gemini-1.5-flash-002'); assertEqualModelInfo( flash002.__action.metadata?.model, - 'Google AI - gemini-1.5-flash-002', + 'Vertex AI - gemini-1.5-flash-002', gemini15Flash.info! ); @@ -104,7 +104,7 @@ describe('plugin', () => { assert.strictEqual(bananaRef.name, 'vertexai/gemini-4.0-banana'); assertEqualModelInfo( bananaRef.info!, - 'Google AI - gemini-4.0-banana', + 'Vertex AI - gemini-4.0-banana', GENERIC_GEMINI_MODEL.info! // <---- generic model fallback ); const banana = await ai.registry.lookupAction(`/model/${bananaRef.name}`); @@ -112,7 +112,7 @@ describe('plugin', () => { assert.strictEqual(banana.__action.name, 'vertexai/gemini-4.0-banana'); assertEqualModelInfo( banana.__action.metadata?.model, - 'Google AI - gemini-4.0-banana', + 'Vertex AI - gemini-4.0-banana', GENERIC_GEMINI_MODEL.info! // <---- generic model fallback ); From 77fd800a6bfade4afb14ad8535dfa2e47b5c9da2 Mon Sep 17 00:00:00 2001 From: Pavel Jbanov Date: Mon, 27 Jan 2025 15:38:41 -0500 Subject: [PATCH 3/7] fake derived params --- js/plugins/vertexai/src/common/index.ts | 9 +++++++++ js/plugins/vertexai/tests/plugin_test.ts | 6 ++++++ 2 files changed, 15 insertions(+) diff --git a/js/plugins/vertexai/src/common/index.ts b/js/plugins/vertexai/src/common/index.ts index 77e098be0a..9b71fb966b 100644 --- a/js/plugins/vertexai/src/common/index.ts +++ b/js/plugins/vertexai/src/common/index.ts @@ -40,9 +40,18 @@ function parseFirebaseProjectId(): string | undefined { } } +export function __setFakeDerivedParams(params: any) { + __fake_getDerivedParams = params; +} +let __fake_getDerivedParams: any; + export async function getDerivedParams( options?: PluginOptions ): Promise { + if (__fake_getDerivedParams) { + return __fake_getDerivedParams; + } + let authOptions = options?.googleAuth; let authClient: GoogleAuth; const providedProjectId = diff --git 
a/js/plugins/vertexai/tests/plugin_test.ts b/js/plugins/vertexai/tests/plugin_test.ts index c78ab5ad67..9fc731b707 100644 --- a/js/plugins/vertexai/tests/plugin_test.ts +++ b/js/plugins/vertexai/tests/plugin_test.ts @@ -18,10 +18,16 @@ import * as assert from 'assert'; import { genkit } from 'genkit'; import { ModelInfo } from 'genkit/model'; import { describe, it } from 'node:test'; +import { __setFakeDerivedParams } from '../src/common/index.js'; import { GENERIC_GEMINI_MODEL, gemini } from '../src/gemini.js'; import vertexAI, { gemini15Flash, gemini15Pro } from '../src/index.js'; describe('plugin', () => { + __setFakeDerivedParams({ + projectId: 'test', + location: 'us-central1', + }); + it('should init the plugin without requiring the api key', async () => { const ai = genkit({ plugins: [vertexAI()], From f8b0cef7f0e18b5797b4d8363b5cdc65186d4926 Mon Sep 17 00:00:00 2001 From: Pavel Jbanov Date: Mon, 27 Jan 2025 16:11:15 -0500 Subject: [PATCH 4/7] Update js/plugins/vertexai/src/common/index.ts Co-authored-by: Michael Doyle --- js/plugins/vertexai/src/common/index.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/js/plugins/vertexai/src/common/index.ts b/js/plugins/vertexai/src/common/index.ts index 9b71fb966b..2f01742270 100644 --- a/js/plugins/vertexai/src/common/index.ts +++ b/js/plugins/vertexai/src/common/index.ts @@ -39,7 +39,7 @@ function parseFirebaseProjectId(): string | undefined { return undefined; } } - +/** @hidden */ export function __setFakeDerivedParams(params: any) { __fake_getDerivedParams = params; } From 8d21fb72ece4006eb116f3f86aa5c90555ce83eb Mon Sep 17 00:00:00 2001 From: Pavel Jbanov Date: Mon, 27 Jan 2025 16:38:16 -0500 Subject: [PATCH 5/7] use SUPPORTED_V1_MODELS for system role middleware --- js/plugins/vertexai/src/gemini.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/js/plugins/vertexai/src/gemini.ts b/js/plugins/vertexai/src/gemini.ts index 1c8a81bc6f..5381df4311 100644 --- 
a/js/plugins/vertexai/src/gemini.ts +++ b/js/plugins/vertexai/src/gemini.ts @@ -697,7 +697,7 @@ export function defineGeminiModel( options: PluginOptions ): ModelAction { const middlewares: ModelMiddleware[] = []; - if (!modelInfo?.supports?.systemRole) { + if (SUPPORTED_V1_MODELS[version]) { middlewares.push(simulateSystemPrompt()); } if (modelInfo?.supports?.media) { From 0570d91da98bb056c9d67d67e09bcb717f3b6b86 Mon Sep 17 00:00:00 2001 From: Pavel Jbanov Date: Mon, 27 Jan 2025 16:39:51 -0500 Subject: [PATCH 6/7] more --- js/plugins/vertexai/src/gemini.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/js/plugins/vertexai/src/gemini.ts b/js/plugins/vertexai/src/gemini.ts index 5381df4311..60b5abc363 100644 --- a/js/plugins/vertexai/src/gemini.ts +++ b/js/plugins/vertexai/src/gemini.ts @@ -721,7 +721,7 @@ export function defineGeminiModel( // Handle system instructions separately let systemInstruction: Content | undefined = undefined; - if (modelInfo?.supports?.systemRole) { + if (!SUPPORTED_V1_MODELS[version]) { const systemMessage = messages.find((m) => m.role === 'system'); if (systemMessage) { messages.splice(messages.indexOf(systemMessage), 1); From 04bfb9393bfbc4d89082ed1fd1866e9b532e3628 Mon Sep 17 00:00:00 2001 From: Pavel Jbanov Date: Mon, 27 Jan 2025 20:41:57 -0500 Subject: [PATCH 7/7] formatting --- js/plugins/vertexai/src/common/index.ts | 1 + 1 file changed, 1 insertion(+) diff --git a/js/plugins/vertexai/src/common/index.ts b/js/plugins/vertexai/src/common/index.ts index 2f01742270..56f4725785 100644 --- a/js/plugins/vertexai/src/common/index.ts +++ b/js/plugins/vertexai/src/common/index.ts @@ -39,6 +39,7 @@ function parseFirebaseProjectId(): string | undefined { return undefined; } } + /** @hidden */ export function __setFakeDerivedParams(params: any) { __fake_getDerivedParams = params;