diff --git a/js/plugins/vertexai/src/common/index.ts b/js/plugins/vertexai/src/common/index.ts index 77e098be0a..56f4725785 100644 --- a/js/plugins/vertexai/src/common/index.ts +++ b/js/plugins/vertexai/src/common/index.ts @@ -40,9 +40,19 @@ function parseFirebaseProjectId(): string | undefined { } } +/** Test-only hook: a truthy `params` makes getDerivedParams() return it as-is; pass `undefined` to restore the normal lookup. @hidden */ +export function __setFakeDerivedParams(params: any) { + __fake_getDerivedParams = params; +} +let __fake_getDerivedParams: any; + export async function getDerivedParams( options?: PluginOptions ): Promise { + if (__fake_getDerivedParams) { + return __fake_getDerivedParams; + } + let authOptions = options?.googleAuth; let authClient: GoogleAuth; const providedProjectId = diff --git a/js/plugins/vertexai/src/common/types.ts b/js/plugins/vertexai/src/common/types.ts index a88fef9e68..642f5980a3 100644 --- a/js/plugins/vertexai/src/common/types.ts +++ b/js/plugins/vertexai/src/common/types.ts @@ -14,7 +14,9 @@ * limitations under the License. */ +import { ModelReference } from 'genkit'; import { GoogleAuthOptions } from 'google-auth-library'; +import { GeminiConfigSchema } from '../gemini.js'; /** Common options for Vertex AI plugin configuration */ export interface CommonPluginOptions { @@ -27,4 +29,9 @@ export interface CommonPluginOptions { } /** Combined plugin options, extending common options with subplugin-specific options */ -export interface PluginOptions extends CommonPluginOptions {} +export interface PluginOptions extends CommonPluginOptions { + models?: ( + | ModelReference + | string + )[]; +} diff --git a/js/plugins/vertexai/src/gemini.ts b/js/plugins/vertexai/src/gemini.ts index 9281173b47..60b5abc363 100644 --- a/js/plugins/vertexai/src/gemini.ts +++ b/js/plugins/vertexai/src/gemini.ts @@ -39,6 +39,7 @@ import { MediaPart, MessageData, ModelAction, + ModelInfo, ModelMiddleware, ModelReference, Part, @@ -166,6 +167,69 @@ export const GeminiConfigSchema = GenerationCommonConfigSchema.extend({ .optional(), }); +/** + * Known model names, to 
allow code completion for convenience. Allows other model names. + */ +export type GeminiVersionString = + | keyof typeof SUPPORTED_GEMINI_MODELS + | (string & {}); + +/** + * Returns a reference to a model that can be used in generate calls. + * + * ```js + * await ai.generate({ + * prompt: 'hi', + * model: gemini('gemini-1.5-flash') + * }); + * ``` + */ +export function gemini( + version: GeminiVersionString, + options: GeminiConfig = {} +): ModelReference { + const nearestModel = nearestGeminiModelRef(version); + return modelRef({ + name: `vertexai/${version}`, + config: options, + configSchema: GeminiConfigSchema, + info: { + ...nearestModel.info, + // Reuse the matched model's label only when `version` names it exactly, otherwise create a new label + label: nearestModel.name === `vertexai/${version}` + ? nearestModel.info?.label + : `Vertex AI - ${version}`, + }, + }); +} + +function nearestGeminiModelRef( + version: GeminiVersionString, + options: GeminiConfig = {} +): ModelReference { + const matchingKey = longestMatchingPrefix( + version, + Object.keys(SUPPORTED_GEMINI_MODELS) + ); + if (matchingKey) { + return SUPPORTED_GEMINI_MODELS[matchingKey].withConfig({ + ...options, + version, + }); + } + return GENERIC_GEMINI_MODEL.withConfig({ ...options, version }); +} + +function longestMatchingPrefix(version: string, potentialMatches: string[]) { + return potentialMatches + .filter((p) => version.startsWith(p)) + .reduce( + (longest, current) => + current.length > longest.length ? current : longest, + '' + ); +} + /** * Gemini model configuration options. 
* * @@ -268,6 +332,21 @@ export const gemini20FlashExp = modelRef({ configSchema: GeminiConfigSchema, }); +export const GENERIC_GEMINI_MODEL = modelRef({ + name: 'vertexai/gemini', + configSchema: GeminiConfigSchema, + info: { + label: 'Google Gemini', + supports: { + multiturn: true, + media: true, + tools: true, + toolChoice: true, + systemRole: true, + }, + }, +}); + export const SUPPORTED_V1_MODELS = { 'gemini-1.0-pro': gemini10Pro, }; @@ -281,11 +360,11 @@ export const SUPPORTED_V15_MODELS = { export const SUPPORTED_GEMINI_MODELS = { ...SUPPORTED_V1_MODELS, ...SUPPORTED_V15_MODELS, -}; +} as const; function toGeminiRole( role: MessageData['role'], - model?: ModelReference + modelInfo?: ModelInfo ): string { switch (role) { case 'user': @@ -293,7 +372,7 @@ function toGeminiRole( case 'model': return 'model'; case 'system': - if (model && SUPPORTED_V15_MODELS[model.name]) { + if (modelInfo?.supports?.systemRole) { // We should have already pulled out the supported system messages, // anything remaining is unsupported; throw an error. throw new Error( @@ -387,10 +466,10 @@ export function toGeminiSystemInstruction(message: MessageData): Content { export function toGeminiMessage( message: MessageData, - model?: ModelReference + modelInfo?: ModelInfo ): Content { return { - role: toGeminiRole(message.role, model), + role: toGeminiRole(message.role, modelInfo), parts: message.content.map(toGeminiPart), }; } @@ -581,7 +660,7 @@ export function cleanSchema(schema: JSONSchema): JSONSchema { /** * Define a Vertex AI Gemini model. 
*/ -export function defineGeminiModel( +export function defineGeminiKnownModel( ai: Genkit, name: string, vertexClientFactory: ( @@ -594,11 +673,34 @@ export function defineGeminiModel( const model: ModelReference = SUPPORTED_GEMINI_MODELS[name]; if (!model) throw new Error(`Unsupported model: ${name}`); + return defineGeminiModel( + ai, + modelName, + name, + model.info, + vertexClientFactory, + options + ); +} + +/** + * Define a Vertex AI Gemini model from an explicit registry name, API version and (optional) model info. + */ +export function defineGeminiModel( + ai: Genkit, + modelName: string, + version: string, + modelInfo: ModelInfo | undefined, + vertexClientFactory: ( + request: GenerateRequest + ) => VertexAI, + options: PluginOptions +): ModelAction { const middlewares: ModelMiddleware[] = []; - if (SUPPORTED_V1_MODELS[name]) { + if (SUPPORTED_V1_MODELS[version]) { middlewares.push(simulateSystemPrompt()); } - if (model?.info?.supports?.media) { + if (modelInfo?.supports?.media) { // the gemini api doesn't support downloading media from http(s) middlewares.push(downloadRequestMedia({ maxBytes: 1024 * 1024 * 20 })); } @@ -606,7 +708,7 @@ export function defineGeminiModel( return ai.defineModel( { name: modelName, - ...model.info, + ...modelInfo, configSchema: GeminiConfigSchema, use: middlewares, }, @@ -619,7 +721,7 @@ export function defineGeminiModel( // Handle system instructions separately let systemInstruction: Content | undefined = undefined; - if (SUPPORTED_V15_MODELS[name]) { + if (!SUPPORTED_V1_MODELS[version]) { const systemMessage = messages.find((m) => m.role === 'system'); if (systemMessage) { messages.splice(messages.indexOf(systemMessage), 1); @@ -659,7 +761,7 @@ export function defineGeminiModel( toolConfig, history: messages .slice(0, -1) - .map((message) => toGeminiMessage(message, model)), + .map((message) => toGeminiMessage(message, modelInfo)), generationConfig: { candidateCount: request.candidates || undefined, temperature: request.config?.temperature, @@ -673,9 +775,7 @@ export function 
defineGeminiModel( }; // Handle cache - const modelVersion = (request.config?.version || - model.version || - name) as string; + const modelVersion = (request.config?.version || version) as string; const cacheConfigDetails = extractCacheConfig(request); const apiClient = new ApiClient( @@ -727,7 +827,7 @@ export function defineGeminiModel( }); } - const msg = toGeminiMessage(messages[messages.length - 1], model); + const msg = toGeminiMessage(messages[messages.length - 1], modelInfo); if (cache) { genModel = vertex.preview.getGenerativeModelFromCachedContent( diff --git a/js/plugins/vertexai/src/index.ts b/js/plugins/vertexai/src/index.ts index 18114488a2..47de7c9a02 100644 --- a/js/plugins/vertexai/src/index.ts +++ b/js/plugins/vertexai/src/index.ts @@ -35,7 +35,9 @@ import { } from './embedder.js'; import { SUPPORTED_GEMINI_MODELS, + defineGeminiKnownModel, defineGeminiModel, + gemini, gemini10Pro, gemini15Flash, gemini15Pro, @@ -51,6 +53,7 @@ import { } from './imagen.js'; export { type PluginOptions } from './common/types.js'; export { + gemini, gemini10Pro, gemini15Flash, gemini15Pro, @@ -78,8 +81,33 @@ export function vertexAI(options?: PluginOptions): GenkitPlugin { imagenModel(ai, name, authClient, { projectId, location }) ); Object.keys(SUPPORTED_GEMINI_MODELS).map((name) => - defineGeminiModel(ai, name, vertexClientFactory, { projectId, location }) + defineGeminiKnownModel(ai, name, vertexClientFactory, { + projectId, + location, + }) ); + if (options?.models) { + for (const modelOrRef of options.models) { + const modelName = + typeof modelOrRef === 'string' + ? modelOrRef + : // strip out only a leading `vertexai/` prefix (un-prefixed names pass through unchanged) + modelOrRef.name.replace(/^vertexai\//, ''); + const modelRef = + typeof modelOrRef === 'string' ? 
gemini(modelOrRef) : modelOrRef; + defineGeminiModel( + ai, + modelRef.name, + modelName, + modelRef.info, + vertexClientFactory, + { + projectId, + location, + } + ); + } + } Object.keys(SUPPORTED_EMBEDDER_MODELS).map((name) => defineVertexAIEmbedder(ai, name, authClient, { projectId, location }) diff --git a/js/plugins/vertexai/src/modelgarden/mistral.ts b/js/plugins/vertexai/src/modelgarden/mistral.ts index 05c04f8a4e..b2aac4dde9 100644 --- a/js/plugins/vertexai/src/modelgarden/mistral.ts +++ b/js/plugins/vertexai/src/modelgarden/mistral.ts @@ -124,6 +124,8 @@ function toMistralRole(role: Role): MistralRole { return 'tool'; case 'system': return 'system'; + default: + throw new Error(`Unknown role ${role}`); } } function toMistralToolRequest(toolRequest: Record): FunctionCall { diff --git a/js/plugins/vertexai/tests/plugin_test.ts b/js/plugins/vertexai/tests/plugin_test.ts new file mode 100644 index 0000000000..9fc731b707 --- /dev/null +++ b/js/plugins/vertexai/tests/plugin_test.ts @@ -0,0 +1,143 @@ +/** + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +import * as assert from 'assert'; +import { genkit } from 'genkit'; +import { ModelInfo } from 'genkit/model'; +import { describe, it } from 'node:test'; +import { __setFakeDerivedParams } from '../src/common/index.js'; +import { GENERIC_GEMINI_MODEL, gemini } from '../src/gemini.js'; +import vertexAI, { gemini15Flash, gemini15Pro } from '../src/index.js'; + +describe('plugin', () => { + __setFakeDerivedParams({ + projectId: 'test', + location: 'us-central1', + }); + + it('should init the plugin without requiring the api key', async () => { + const ai = genkit({ + plugins: [vertexAI()], + }); + + assert.ok(ai); + }); + + it('should pre-register a few flagship models', async () => { + const ai = genkit({ + plugins: [vertexAI()], + }); + + assert.ok(await ai.registry.lookupAction(`/model/${gemini15Flash.name}`)); + assert.ok(await ai.registry.lookupAction(`/model/${gemini15Pro.name}`)); + }); + + it('allow referencing models using `gemini` helper', async () => { + const ai = genkit({ + plugins: [vertexAI()], + }); + + const pro = await ai.registry.lookupAction( + `/model/${gemini('gemini-1.5-pro').name}` + ); + assert.ok(pro); + assert.strictEqual(pro.__action.name, 'vertexai/gemini-1.5-pro'); + const flash = await ai.registry.lookupAction( + `/model/${gemini('gemini-1.5-flash').name}` + ); + assert.ok(flash); + assert.strictEqual(flash.__action.name, 'vertexai/gemini-1.5-flash'); + }); + + it('references explicitly registered models', async () => { + const flash002Ref = gemini('gemini-1.5-flash-002'); + const ai = genkit({ + plugins: [ + vertexAI({ + location: 'us-central1', + models: ['gemini-1.5-pro-002', flash002Ref, 'gemini-4.0-banana'], + }), + ], + }); + + const pro002Ref = gemini('gemini-1.5-pro-002'); + assert.strictEqual(pro002Ref.name, 'vertexai/gemini-1.5-pro-002'); + assertEqualModelInfo( + pro002Ref.info!, + 'Vertex AI - gemini-1.5-pro-002', + gemini15Pro.info! 
+ ); + const pro002 = await ai.registry.lookupAction(`/model/${pro002Ref.name}`); + assert.ok(pro002); + assert.strictEqual(pro002.__action.name, 'vertexai/gemini-1.5-pro-002'); + assertEqualModelInfo( + pro002.__action.metadata?.model, + 'Vertex AI - gemini-1.5-pro-002', + gemini15Pro.info! + ); + + assert.strictEqual(flash002Ref.name, 'vertexai/gemini-1.5-flash-002'); + assertEqualModelInfo( + flash002Ref.info!, + 'Vertex AI - gemini-1.5-flash-002', + gemini15Flash.info! + ); + const flash002 = await ai.registry.lookupAction( + `/model/${flash002Ref.name}` + ); + assert.ok(flash002); + assert.strictEqual(flash002.__action.name, 'vertexai/gemini-1.5-flash-002'); + assertEqualModelInfo( + flash002.__action.metadata?.model, + 'Vertex AI - gemini-1.5-flash-002', + gemini15Flash.info! + ); + + const bananaRef = gemini('gemini-4.0-banana'); + assert.strictEqual(bananaRef.name, 'vertexai/gemini-4.0-banana'); + assertEqualModelInfo( + bananaRef.info!, + 'Vertex AI - gemini-4.0-banana', + GENERIC_GEMINI_MODEL.info! // no known prefix matches, so the generic model info is expected + ); + const banana = await ai.registry.lookupAction(`/model/${bananaRef.name}`); + assert.ok(banana); + assert.strictEqual(banana.__action.name, 'vertexai/gemini-4.0-banana'); + assertEqualModelInfo( + banana.__action.metadata?.model, + 'Vertex AI - gemini-4.0-banana', + GENERIC_GEMINI_MODEL.info! 
// no known prefix matches, so the generic model info is expected + ); + + // this one is not registered + const flash003Ref = gemini('gemini-1.5-flash-003'); + assert.strictEqual(flash003Ref.name, 'vertexai/gemini-1.5-flash-003'); + const flash003 = await ai.registry.lookupAction( + `/model/${flash003Ref.name}` + ); + assert.strictEqual(flash003, undefined); + }); +}); + +function assertEqualModelInfo( + actualInfo: ModelInfo, + expectedLabel: string, + expectedInfo: ModelInfo +) { + assert.strictEqual(actualInfo.label, expectedLabel); + assert.deepStrictEqual(actualInfo.supports, expectedInfo.supports); + assert.deepStrictEqual(actualInfo.versions, expectedInfo.versions); +}