From d449594691426ce59380f16830374a8faba2dec4 Mon Sep 17 00:00:00 2001 From: Michael Doyle Date: Wed, 12 Jun 2024 10:29:51 -0400 Subject: [PATCH] Add explicit return types for plugins that define models, retriever, indexers, etc --- js/plugins/chroma/src/index.ts | 6 +++-- js/plugins/dev-local-vectorstore/src/index.ts | 6 +++-- js/plugins/googleai/src/embedder.ts | 8 +++++-- js/plugins/googleai/src/gemini.ts | 13 ++++++----- js/plugins/ollama/src/index.ts | 3 ++- js/plugins/pinecone/src/index.ts | 6 +++-- js/plugins/vertexai/src/anthropic.ts | 3 ++- js/plugins/vertexai/src/embedder.ts | 8 +++++-- js/plugins/vertexai/src/gemini.ts | 22 ++++++++++++++----- js/plugins/vertexai/src/imagen.ts | 7 ++++-- js/plugins/vertexai/src/index.ts | 12 +++++----- 11 files changed, 62 insertions(+), 32 deletions(-) diff --git a/js/plugins/chroma/src/index.ts b/js/plugins/chroma/src/index.ts index 20a0443b5a..cb04a0a12a 100644 --- a/js/plugins/chroma/src/index.ts +++ b/js/plugins/chroma/src/index.ts @@ -20,7 +20,9 @@ import { defineIndexer, defineRetriever, Document, + IndexerAction, indexerRef, + RetrieverAction, retrieverRef, } from '@genkit-ai/ai/retriever'; import { genkitPlugin, PluginProvider } from '@genkit-ai/core'; @@ -120,7 +122,7 @@ export function chromaRetriever< createCollectionIfMissing?: boolean; embedder: EmbedderArgument; embedderOptions?: z.infer; -}) { +}): RetrieverAction> { const { embedder, collectionName, embedderOptions } = params; return defineRetriever( { @@ -191,7 +193,7 @@ export function chromaIndexer< createCollectionIfMissing?: boolean; embedder: EmbedderArgument; embedderOptions?: z.infer; -}) { +}): IndexerAction { const { collectionName, embedder, embedderOptions } = { ...params, }; diff --git a/js/plugins/dev-local-vectorstore/src/index.ts b/js/plugins/dev-local-vectorstore/src/index.ts index 8e646a0944..f66bd568fe 100644 --- a/js/plugins/dev-local-vectorstore/src/index.ts +++ b/js/plugins/dev-local-vectorstore/src/index.ts @@ -21,7 +21,9 @@ import { defineRetriever, Document, DocumentData, + IndexerAction, indexerRef, + RetrieverAction, retrieverRef, } from '@genkit-ai/ai/retriever'; import { genkitPlugin, PluginProvider } from '@genkit-ai/core'; @@ -173,7 +175,7 @@ export function configureDevLocalRetriever< indexName: string; embedder: EmbedderArgument; embedderOptions?: z.infer; -}) { +}): RetrieverAction { const { embedder, embedderOptions } = params; const vectorstore = defineRetriever( { @@ -209,7 +211,7 @@ export function configureDevLocalIndexer< indexName: string; embedder: EmbedderArgument; embedderOptions?: z.infer; -}) { +}): IndexerAction { const { embedder, embedderOptions } = params; const vectorstore = defineIndexer( { name: `devLocalVectorstore/${params.indexName}` }, diff --git a/js/plugins/googleai/src/embedder.ts b/js/plugins/googleai/src/embedder.ts index 449f76b18c..9d7805e244 100644 --- a/js/plugins/googleai/src/embedder.ts +++ b/js/plugins/googleai/src/embedder.ts @@ -14,7 +14,11 @@ * limitations under the License. */ -import { defineEmbedder, embedderRef } from '@genkit-ai/ai/embedder'; +import { + defineEmbedder, + EmbedderAction, + embedderRef, +} from '@genkit-ai/ai/embedder'; import { EmbedContentRequest, GoogleGenerativeAI } from '@google/generative-ai'; import { string, z } from 'zod'; import { PluginOptions } from './index.js'; @@ -60,7 +64,7 @@ export const SUPPORTED_MODELS = { export function textEmbeddingGeckoEmbedder( name: string, options: PluginOptions -) { +): EmbedderAction { let apiKey = options?.apiKey || process.env.GOOGLE_GENAI_API_KEY || diff --git a/js/plugins/googleai/src/gemini.ts b/js/plugins/googleai/src/gemini.ts index 03a15b2ab4..cf6825c5e3 100644 --- a/js/plugins/googleai/src/gemini.ts +++ b/js/plugins/googleai/src/gemini.ts @@ -150,7 +150,7 @@ export const geminiUltra = modelRef({ export const SUPPORTED_V1_MODELS: Record< string, - ModelReference + ModelReference > = { 'gemini-pro': geminiPro, 'gemini-pro-vision': geminiProVision, @@ -159,7 +159,7 @@ export const SUPPORTED_V1_MODELS: Record< export const SUPPORTED_V15_MODELS: Record< string, - ModelReference + ModelReference > = { 'gemini-1.5-pro-latest': gemini15Pro, 'gemini-1.5-flash-latest': gemini15Flash, @@ -172,7 +172,7 @@ const SUPPORTED_MODELS = { function toGeminiRole( role: MessageData['role'], - model?: ModelReference + model?: ModelReference ): string { switch (role) { case 'user': @@ -331,7 +331,7 @@ function fromGeminiPart(part: GeminiPart): Part { export function toGeminiMessage( message: MessageData, - model?: ModelReference + model?: ModelReference ): GeminiMessage { return { role: toGeminiRole(message.role, model), @@ -387,7 +387,7 @@ export function googleAIModel( apiKey?: string, apiVersion?: string, baseUrl?: string -): ModelAction { +): ModelAction { const modelName = `googleai/${name}`; if (!apiKey) { @@ -400,7 +400,8 @@ export function googleAIModel( ); } - const model: ModelReference = SUPPORTED_MODELS[name]; + const model: ModelReference = + SUPPORTED_MODELS[name]; if (!model) throw new Error(`Unsupported model: ${name}`); const middleware: ModelMiddleware[] = []; diff --git a/js/plugins/ollama/src/index.ts b/js/plugins/ollama/src/index.ts index 8ed8760e05..a05dcfc69b 100644 --- a/js/plugins/ollama/src/index.ts +++ b/js/plugins/ollama/src/index.ts @@ -22,6 +22,7 @@ import { GenerationCommonConfigSchema, getBasicUsageStats, MessageData, + ModelAction, } from '@genkit-ai/ai/model'; import { genkitPlugin, Plugin } from '@genkit-ai/core'; import { logger } from '@genkit-ai/core/logging'; @@ -63,7 +64,7 @@ function ollamaModel( model: ModelDefinition, serverAddress: string, requestHeaders?: RequestHeaders -) { +): ModelAction { return defineModel( { name: `ollama/${model.name}`, diff --git a/js/plugins/pinecone/src/index.ts b/js/plugins/pinecone/src/index.ts index 7746f31845..9813f91cdd 100644 --- a/js/plugins/pinecone/src/index.ts +++ b/js/plugins/pinecone/src/index.ts @@ -20,7 +20,9 @@ import { defineIndexer, defineRetriever, Document, + IndexerAction, indexerRef, + RetrieverAction, retrieverRef, } from '@genkit-ai/ai/retriever'; import { genkitPlugin, PluginProvider } from '@genkit-ai/core'; @@ -130,7 +132,7 @@ export function configurePineconeRetriever< textKey?: string; embedder: EmbedderArgument; embedderOptions?: z.infer; -}) { +}): RetrieverAction { const { indexId, embedder, embedderOptions } = { ...params, }; @@ -185,7 +187,7 @@ export function configurePineconeIndexer< textKey?: string; embedder: EmbedderArgument; embedderOptions?: z.infer; -}) { +}): IndexerAction> { const { indexId, embedder, embedderOptions } = { ...params, }; diff --git a/js/plugins/vertexai/src/anthropic.ts b/js/plugins/vertexai/src/anthropic.ts index d6bf2ed088..64fcbb29d4 100644 --- a/js/plugins/vertexai/src/anthropic.ts +++ b/js/plugins/vertexai/src/anthropic.ts @@ -28,6 +28,7 @@ import { GenerateResponseData, GenerationCommonConfigSchema, Part as GenkitPart, + ModelAction, ModelReference, defineModel, getBasicUsageStats, @@ -96,7 +97,7 @@ export function anthropicModel( modelName: string, projectId: string, region: string -) { +): ModelAction { const client = new AnthropicVertex({ region, projectId, diff --git a/js/plugins/vertexai/src/embedder.ts b/js/plugins/vertexai/src/embedder.ts index 072464409d..865c6f0926 100644 --- a/js/plugins/vertexai/src/embedder.ts +++ b/js/plugins/vertexai/src/embedder.ts @@ -16,6 +16,7 @@ import { defineEmbedder, + EmbedderAction, embedderRef, EmbedderReference, } from '@genkit-ai/ai/embedder'; @@ -119,7 +120,10 @@ export const textEmbeddingGeckoMultilingual001 = embedderRef({ export const textEmbeddingGecko = textEmbeddingGecko003; -export const SUPPORTED_EMBEDDER_MODELS: Record = { +export const SUPPORTED_EMBEDDER_MODELS: Record< + string, + EmbedderReference +> = { 'textembedding-gecko@003': textEmbeddingGecko003, 'textembedding-gecko@002': textEmbeddingGecko002, 'textembedding-gecko@001': textEmbeddingGecko001, @@ -147,7 +151,7 @@ export function textEmbeddingGeckoEmbedder( name: string, client: GoogleAuth, options: PluginOptions -) { +): EmbedderAction { const embedder = SUPPORTED_EMBEDDER_MODELS[name]; // TODO: Figure out how to allow different versions while still sharing a single implementation. const predict = predictModel( diff --git a/js/plugins/vertexai/src/gemini.ts b/js/plugins/vertexai/src/gemini.ts index 797b3c0d9a..05a208d391 100644 --- a/js/plugins/vertexai/src/gemini.ts +++ b/js/plugins/vertexai/src/gemini.ts @@ -149,13 +149,19 @@ export const gemini15Flash = modelRef({ configSchema: GeminiConfigSchema, }); -export const SUPPORTED_V1_MODELS = { +export const SUPPORTED_V1_MODELS: Record< + string, + ModelReference +> = { 'gemini-1.0-pro': geminiPro, 'gemini-1.0-pro-vision': geminiProVision, // 'gemini-ultra': geminiUltra, }; -export const SUPPORTED_V15_MODELS = { +export const SUPPORTED_V15_MODELS: Record< + string, + ModelReference +> = { 'gemini-1.5-pro': gemini15Pro, 'gemini-1.5-flash': gemini15Flash, 'gemini-1.5-pro-preview': gemini15ProPreview, @@ -169,7 +175,7 @@ export const SUPPORTED_GEMINI_MODELS = { function toGeminiRole( role: MessageData['role'], - model?: ModelReference + model?: ModelReference ): string { switch (role) { case 'user': @@ -271,7 +277,7 @@ export function toGeminiSystemInstruction(message: MessageData): Content { export function toGeminiMessage( message: MessageData, - model?: ModelReference + model?: ModelReference ): Content { return { role: toGeminiRole(message.role, model), @@ -441,10 +447,14 @@ const convertSchemaProperty = (property) => { /** * */ -export function geminiModel(name: string, vertex: VertexAI): ModelAction { +export function geminiModel( + name: string, + vertex: VertexAI +): ModelAction { const modelName = `vertexai/${name}`; - const model: ModelReference = SUPPORTED_GEMINI_MODELS[name]; + const model: ModelReference = + SUPPORTED_GEMINI_MODELS[name]; if (!model) throw new Error(`Unsupported model: ${name}`); const middlewares: ModelMiddleware[] = []; diff --git a/js/plugins/vertexai/src/imagen.ts b/js/plugins/vertexai/src/imagen.ts index dd654f922f..f18e58ece9 100644 --- a/js/plugins/vertexai/src/imagen.ts +++ b/js/plugins/vertexai/src/imagen.ts @@ -20,6 +20,7 @@ import { GenerateRequest, GenerationCommonConfigSchema, getBasicUsageStats, + ModelAction, modelRef, } from '@genkit-ai/ai/model'; import { GoogleAuth } from 'google-auth-library'; @@ -39,7 +40,6 @@ const ImagenConfigSchema = GenerationCommonConfigSchema.extend({ /** Any non-negative integer you provide to make output images deterministic. Providing the same seed number always results in the same output images. Accepted integer values: 1 - 2147483647. */ seed: z.number().optional(), }); -type ImagenConfig = z.infer; export const imagen2 = modelRef({ name: 'vertexai/imagen2', @@ -109,7 +109,10 @@ interface ImagenInstance { /** * */ -export function imagen2Model(client: GoogleAuth, options: PluginOptions) { +export function imagen2Model( + client: GoogleAuth, + options: PluginOptions +): ModelAction { const predict = predictModel< ImagenInstance, ImagenPrediction, diff --git a/js/plugins/vertexai/src/index.ts b/js/plugins/vertexai/src/index.ts index 5c00e2fa33..47236dd771 100644 --- a/js/plugins/vertexai/src/index.ts +++ b/js/plugins/vertexai/src/index.ts @@ -14,16 +14,16 @@ * limitations under the License. */ -import { ModelReference } from '@genkit-ai/ai/model'; -import { genkitPlugin, Plugin } from '@genkit-ai/core'; +import { ModelAction, ModelReference } from '@genkit-ai/ai/model'; +import { Plugin, genkitPlugin } from '@genkit-ai/core'; import { VertexAI } from '@google-cloud/vertexai'; import { GoogleAuth, GoogleAuthOptions } from 'google-auth-library'; import { + SUPPORTED_ANTHROPIC_MODELS, anthropicModel, claude3Haiku, claude3Opus, claude3Sonnet, - SUPPORTED_ANTHROPIC_MODELS, } from './anthropic.js'; import { SUPPORTED_EMBEDDER_MODELS, @@ -42,6 +42,7 @@ import { vertexEvaluators, } from './evaluation.js'; import { + SUPPORTED_GEMINI_MODELS, gemini15Flash, gemini15FlashPreview, gemini15Pro, @@ -49,11 +50,11 @@ import { geminiModel, geminiPro, geminiProVision, - SUPPORTED_GEMINI_MODELS, } from './gemini.js'; import { imagen2, imagen2Model } from './imagen.js'; export { + VertexAIEvaluationMetricType as VertexAIEvaluationMetricType, claude3Haiku, claude3Opus, claude3Sonnet, @@ -71,7 +72,6 @@ export { textEmbeddingGecko003, textEmbeddingGeckoMultilingual001, textMultilingualEmbedding002, - VertexAIEvaluationMetricType as VertexAIEvaluationMetricType, }; export interface PluginOptions { @@ -120,7 +120,7 @@ export const vertexAI: Plugin<[PluginOptions] | []> = genkitPlugin( ? options.evaluation.metrics : []; - const models = [ + const models: ModelAction[] = [ imagen2Model(authClient, { projectId, location }), ...Object.keys(SUPPORTED_GEMINI_MODELS).map((name) => geminiModel(name, vertexClient)