feat: allow overriding location per request in vertex ai plugin (#716)
pavelgj committed Jul 30, 2024
1 parent 8198ae2 commit f720a2f
Showing 9 changed files with 152 additions and 94 deletions.
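
From the caller's side, this change adds an optional location field to each model's and embedder's config schema, so an individual request can target a different Vertex AI region than the one the plugin was configured with. The sketch below shows the intended usage; it assumes the Genkit 0.5-era configureGenkit()/generate() API and the gemini15Flash reference exported by this plugin, so treat the exact imports as illustrative.

import { generate } from '@genkit-ai/ai';
import { configureGenkit } from '@genkit-ai/core';
import { gemini15Flash, vertexAI } from '@genkit-ai/vertexai';

configureGenkit({
  // The plugin-level location remains the default for every request.
  plugins: [vertexAI({ location: 'us-central1' })],
});

async function main() {
  const response = await generate({
    model: gemini15Flash,
    prompt: 'Say hello from another region.',
    // New in this commit: per-request override of the Vertex AI location.
    config: { location: 'europe-west4' },
  });
  console.log(response.text());
}

main();
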
42 changes: 27 additions & 15 deletions js/plugins/vertexai/src/anthropic.ts
@@ -42,6 +42,11 @@ import {
modelRef,
} from '@genkit-ai/ai/model';
import { GENKIT_CLIENT_HEADER } from '@genkit-ai/core';
import z from 'zod';

export const AnthropicConfigSchema = GenerationCommonConfigSchema.extend({
location: z.string().optional(),
});

export const claude35Sonnet = modelRef({
name: 'vertexai/claude-3-5-sonnet',
@@ -56,7 +61,7 @@ export const claude35Sonnet = modelRef({
output: ['text'],
},
},
configSchema: GenerationCommonConfigSchema,
configSchema: AnthropicConfigSchema,
});

export const claude3Sonnet = modelRef({
@@ -72,7 +77,7 @@ export const claude3Sonnet = modelRef({
output: ['text'],
},
},
configSchema: GenerationCommonConfigSchema,
configSchema: AnthropicConfigSchema,
});

export const claude3Haiku = modelRef({
@@ -88,7 +93,7 @@ export const claude3Haiku = modelRef({
output: ['text'],
},
},
configSchema: GenerationCommonConfigSchema,
configSchema: AnthropicConfigSchema,
});

export const claude3Opus = modelRef({
@@ -104,12 +109,12 @@ export const claude3Opus = modelRef({
output: ['text'],
},
},
configSchema: GenerationCommonConfigSchema,
configSchema: AnthropicConfigSchema,
});

export const SUPPORTED_ANTHROPIC_MODELS: Record<
string,
ModelReference<typeof GenerationCommonConfigSchema>
ModelReference<typeof AnthropicConfigSchema>
> = {
'claude-3-5-sonnet': claude35Sonnet,
'claude-3-sonnet': claude3Sonnet,
@@ -119,7 +124,7 @@ export const SUPPORTED_ANTHROPIC_MODELS: Record<

export function toAnthropicRequest(
model: string,
input: GenerateRequest<typeof GenerationCommonConfigSchema>
input: GenerateRequest<typeof AnthropicConfigSchema>
): MessageCreateParamsBase {
let system: string | undefined = undefined;
const messages: MessageParam[] = [];
@@ -289,7 +294,7 @@ function fromAnthropicCandidate(candidate: Message): CandidateData {
}

export function fromAnthropicResponse(
input: GenerateRequest<typeof GenerationCommonConfigSchema>,
input: GenerateRequest<typeof AnthropicConfigSchema>,
response: Message
): GenerateResponseData {
const candidates: CandidateData[] = [fromAnthropicCandidate(response)];
@@ -365,13 +370,19 @@ export function anthropicModel(
projectId: string,
region: string
) {
const client = new AnthropicVertex({
region,
projectId,
defaultHeaders: {
'X-Goog-Api-Client': GENKIT_CLIENT_HEADER,
},
});
const clients: Record<string, AnthropicVertex> = {};
const clientFactory = (region: string): AnthropicVertex => {
if (!clients[region]) {
clients[region] = new AnthropicVertex({
region,
projectId,
defaultHeaders: {
'X-Goog-Api-Client': GENKIT_CLIENT_HEADER,
},
});
}
return clients[region];
};
const model = SUPPORTED_ANTHROPIC_MODELS[modelName];
if (!model) {
throw new Error(`unsupported Anthropic model name ${modelName}`);
@@ -381,11 +392,12 @@ export function anthropicModel(
{
name: model.name,
label: model.info?.label,
configSchema: GenerationCommonConfigSchema,
configSchema: AnthropicConfigSchema,
supports: model.info?.supports,
versions: model.info?.versions,
},
async (input, streamingCallback) => {
const client = clientFactory(input.config?.location || region);
if (!streamingCallback) {
const response = await client.messages.create({
...toAnthropicRequest(input.config?.version ?? modelName, input),
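
The lazy per-region cache above (one AnthropicVertex client per requested region, created on first use and reused afterwards) is the same pattern the rest of this commit applies to predictModel and VertexAI clients. As a standalone sketch, with hypothetical names that do not appear in the diff:

// Hypothetical generic helper illustrating the per-region client cache.
function cachedByRegion<T>(create: (region: string) => T): (region: string) => T {
  const clients: Record<string, T> = {};
  return (region) => {
    if (!clients[region]) {
      clients[region] = create(region);
    }
    return clients[region];
  };
}

// Roughly how anthropicModel() resolves a client for each request:
//   const getClient = cachedByRegion((r) => new AnthropicVertex({ region: r, projectId }));
//   const client = getClient(input.config?.location || region);
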
37 changes: 29 additions & 8 deletions js/plugins/vertexai/src/embedder.ts
@@ -22,7 +22,7 @@ import {
import { GoogleAuth } from 'google-auth-library';
import { z } from 'zod';
import { PluginOptions } from './index.js';
import { predictModel } from './predict.js';
import { PredictClient, predictModel } from './predict.js';

export const TaskTypeSchema = z.enum([
'RETRIEVAL_DOCUMENT',
@@ -40,6 +40,7 @@ export const TextEmbeddingGeckoConfigSchema = z.object({
**/
taskType: TaskTypeSchema.optional(),
title: z.string().optional(),
location: z.string().optional(),
});
export type TextEmbeddingGeckoConfig = z.infer<
typeof TextEmbeddingGeckoConfigSchema
@@ -149,20 +150,40 @@ export function textEmbeddingGeckoEmbedder(
options: PluginOptions
) {
const embedder = SUPPORTED_EMBEDDER_MODELS[name];
// TODO: Figure out how to allow different versions while still sharing a single implementation.
const predict = predictModel<EmbeddingInstance, EmbeddingPrediction>(
client,
options,
name
);
const predictClients: Record<
string,
PredictClient<EmbeddingInstance, EmbeddingPrediction>
> = {};
const predictClientFactory = (
config: TextEmbeddingGeckoConfig
): PredictClient<EmbeddingInstance, EmbeddingPrediction> => {
const requestLocation = config?.location || options.location;
if (!predictClients[requestLocation]) {
// TODO: Figure out how to allow different versions while still sharing a single implementation.
predictClients[requestLocation] = predictModel<
EmbeddingInstance,
EmbeddingPrediction
>(
client,
{
...options,
location: requestLocation,
},
name
);
}
return predictClients[requestLocation];
};

return defineEmbedder(
{
name: embedder.name,
configSchema: embedder.configSchema,
info: embedder.info!,
},
async (input, options) => {
const response = await predict(
const predictClient = predictClientFactory(options);
const response = await predictClient(
input.map((i) => {
return {
content: i.text(),
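
On the embedder side the override is exposed through TextEmbeddingGeckoConfig, so a caller can pick the region per embed call. A sketch, assuming the embed() helper of this Genkit version and the textEmbeddingGecko reference exported by the plugin (field names are illustrative):

import { embed } from '@genkit-ai/ai/embedder';
import { textEmbeddingGecko } from '@genkit-ai/vertexai';

const embedding = await embed({
  embedder: textEmbeddingGecko,
  content: 'What is the capital of France?',
  // taskType and title were already supported; location is the new per-request field.
  options: { taskType: 'RETRIEVAL_QUERY', location: 'europe-west4' },
});
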
15 changes: 10 additions & 5 deletions js/plugins/vertexai/src/gemini.ts
@@ -17,6 +17,7 @@
import {
CandidateData,
defineModel,
GenerateRequest,
GenerationCommonConfigSchema,
getBasicUsageStats,
MediaPart,
@@ -53,8 +54,9 @@ const SafetySettingsSchema = z.object({
threshold: z.nativeEnum(HarmBlockThreshold),
});

const GeminiConfigSchema = GenerationCommonConfigSchema.extend({
export const GeminiConfigSchema = GenerationCommonConfigSchema.extend({
safetySettings: z.array(SafetySettingsSchema).optional(),
location: z.string().optional(),
});

export const geminiPro = modelRef({
Expand Down Expand Up @@ -438,10 +440,12 @@ const convertSchemaProperty = (property) => {
}
};

/**
*
*/
export function geminiModel(name: string, vertex: VertexAI): ModelAction {
export function geminiModel(
name: string,
vertexClientFactory: (
request: GenerateRequest<typeof GeminiConfigSchema>
) => VertexAI
): ModelAction {
const modelName = `vertexai/${name}`;

const model: ModelReference<z.ZodTypeAny> = SUPPORTED_GEMINI_MODELS[name];
@@ -464,6 +468,7 @@ export function geminiModel(name: string, vertex: VertexAI): ModelAction {
use: middlewares,
},
async (request, streamingCallback) => {
const vertex = vertexClientFactory(request);
const client = vertex.preview.getGenerativeModel(
{
model: request.config?.version || model.version || name,
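
GeminiConfigSchema is now exported and gains the same optional location alongside the existing safetySettings, and geminiModel() asks a client factory for a VertexAI instance per request instead of holding a single client. A caller-side sketch under the same API assumptions as the first example (gemini15Pro as re-exported by the plugin is illustrative):

import { generate } from '@genkit-ai/ai';
import { gemini15Pro } from '@genkit-ai/vertexai';
import { HarmBlockThreshold, HarmCategory } from '@google-cloud/vertexai';

const summary = await generate({
  model: gemini15Pro,
  prompt: 'Summarize the report in three bullet points.',
  config: {
    // Routed through the per-location VertexAI client cache set up in index.ts below.
    location: 'asia-northeast1',
    safetySettings: [
      {
        category: HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
        threshold: HarmBlockThreshold.BLOCK_ONLY_HIGH,
      },
    ],
  },
});
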
39 changes: 28 additions & 11 deletions js/plugins/vertexai/src/imagen.ts
@@ -25,7 +25,7 @@ import {
import { GoogleAuth } from 'google-auth-library';
import z from 'zod';
import { PluginOptions } from './index.js';
import { predictModel } from './predict.js';
import { PredictClient, predictModel } from './predict.js';

const ImagenConfigSchema = GenerationCommonConfigSchema.extend({
/** Language of the prompt text. */
@@ -38,8 +38,8 @@ const ImagenConfigSchema = GenerationCommonConfigSchema.extend({
negativePrompt: z.string().optional(),
/** Any non-negative integer you provide to make output images deterministic. Providing the same seed number always results in the same output images. Accepted integer values: 1 - 2147483647. */
seed: z.number().optional(),
location: z.string().optional(),
});
type ImagenConfig = z.infer<typeof ImagenConfigSchema>;

export const imagen2 = modelRef({
name: 'vertexai/imagen2',
@@ -106,15 +106,31 @@ interface ImagenInstance {
image?: { bytesBase64Encoded: string };
}

/**
*
*/
export function imagen2Model(client: GoogleAuth, options: PluginOptions) {
const predict = predictModel<
ImagenInstance,
ImagenPrediction,
ImagenParameters
>(client, options, 'imagegeneration@005');
const predictClients: Record<
string,
PredictClient<ImagenInstance, ImagenPrediction, ImagenParameters>
> = {};
const predictClientFactory = (
request: GenerateRequest<typeof ImagenConfigSchema>
): PredictClient<ImagenInstance, ImagenPrediction, ImagenParameters> => {
const requestLocation = request.config?.location || options.location;
if (!predictClients[requestLocation]) {
predictClients[requestLocation] = predictModel<
ImagenInstance,
ImagenPrediction,
ImagenParameters
>(
client,
{
...options,
location: requestLocation,
},
'imagegeneration@005'
);
}
return predictClients[requestLocation];
};

return defineModel(
{
@@ -134,7 +150,8 @@ export function imagen2Model(client: GoogleAuth, options: PluginOptions) {
parameters: toParameters(request),
};

const response = await predict([instance], toParameters(request));
const predictClient = predictClientFactory(request);
const response = await predictClient([instance], toParameters(request));

const candidates: CandidateData[] = response.predictions.map((p, i) => {
const b64data = p.bytesBase64Encoded;
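
For Imagen the override flows into the same kind of cached PredictClient. A caller-side sketch, assuming imagen2 is exported by the plugin and that generated media is read with media() in this Genkit version:

import { generate } from '@genkit-ai/ai';
import { imagen2 } from '@genkit-ai/vertexai';

const result = await generate({
  model: imagen2,
  prompt: 'A watercolor lighthouse at dawn',
  // seed and negativePrompt were already part of ImagenConfigSchema; location is new.
  config: { location: 'us-central1', seed: 42 },
});
const image = result.media(); // e.g. { url: 'data:image/png;base64,...' }
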
24 changes: 17 additions & 7 deletions js/plugins/vertexai/src/index.ts
@@ -14,7 +14,7 @@
* limitations under the License.
*/

import { ModelReference } from '@genkit-ai/ai/model';
import { GenerateRequest, ModelReference } from '@genkit-ai/ai/model';
import { IndexerAction, RetrieverAction } from '@genkit-ai/ai/retriever';
import { Plugin, genkitPlugin } from '@genkit-ai/core';
import { VertexAI } from '@google-cloud/vertexai';
@@ -45,6 +45,7 @@ import {
vertexEvaluators,
} from './evaluation.js';
import {
GeminiConfigSchema,
SUPPORTED_GEMINI_MODELS,
gemini15Flash,
gemini15FlashPreview,
@@ -152,11 +153,20 @@ export const vertexAI: Plugin<[PluginOptions] | []> = genkitPlugin(
throw confError('project', 'GCLOUD_PROJECT');
}

const vertexClient = new VertexAI({
project: projectId,
location,
googleAuthOptions: options?.googleAuth,
});
const vertexClientFactoryCache: Record<string, VertexAI> = {};
const vertexClientFactory = (
request: GenerateRequest<typeof GeminiConfigSchema>
): VertexAI => {
const requestLocation = request.config?.location || location;
if (!vertexClientFactoryCache[requestLocation]) {
vertexClientFactoryCache[requestLocation] = new VertexAI({
project: projectId,
location: requestLocation,
googleAuthOptions: options?.googleAuth,
});
}
return vertexClientFactoryCache[requestLocation];
};
const metrics =
options?.evaluation && options.evaluation.metrics.length > 0
? options.evaluation.metrics
@@ -165,7 +175,7 @@ export const vertexAI: Plugin<[PluginOptions] | []> = genkitPlugin(
const models = [
imagen2Model(authClient, { projectId, location }),
...Object.keys(SUPPORTED_GEMINI_MODELS).map((name) =>
geminiModel(name, vertexClient)
geminiModel(name, vertexClientFactory)
),
];

Expand Down