diff --git a/js/plugins/vertexai/src/evaluation.ts b/js/plugins/vertexai/src/evaluation.ts index c8fa1937e0..7cc94c026d 100644 --- a/js/plugins/vertexai/src/evaluation.ts +++ b/js/plugins/vertexai/src/evaluation.ts @@ -18,6 +18,7 @@ import { BaseDataPoint } from '@genkit-ai/ai/evaluator'; import { Action } from '@genkit-ai/core'; import { GoogleAuth } from 'google-auth-library'; import { JSONClient } from 'google-auth-library/build/src/auth/googleauth'; +import z from 'zod'; import { EvaluatorFactory } from './evaluator_factory'; /** @@ -57,10 +58,6 @@ export function vertexEvaluators( const metricType = isConfig(metric) ? metric.type : metric; const metricSpec = isConfig(metric) ? metric.metricSpec : {}; - console.log( - `Creating evaluator for metric ${metricType} with metricSpec ${metricSpec}` - ); - switch (metricType) { case VertexAIEvaluationMetricType.BLEU: { return createBleuEvaluator(factory, metricSpec); @@ -84,6 +81,12 @@ function isConfig( return (config as VertexAIEvaluationMetricConfig).type !== undefined; } +const BleuResponseSchema = z.object({ + bleuResults: z.object({ + bleuMetricValues: z.array(z.object({ score: z.number() })), + }), +}); + // TODO: Add support for batch inputs function createBleuEvaluator( factory: EvaluatorFactory, @@ -95,6 +98,7 @@ function createBleuEvaluator( displayName: 'BLEU', definition: 'Computes the BLEU score by comparing the output against the ground truth', + responseSchema: BleuResponseSchema, }, (datapoint) => { if (!datapoint.reference) { @@ -124,6 +128,12 @@ function createBleuEvaluator( ); } +const RougeResponseSchema = z.object({ + rougeResults: z.object({ + rougeMetricValues: z.array(z.object({ score: z.number() })), + }), +}); + // TODO: Add support for batch inputs function createRougeEvaluator( factory: EvaluatorFactory, @@ -135,6 +145,7 @@ function createRougeEvaluator( displayName: 'ROUGE', definition: 'Computes the ROUGE score by comparing the output against the ground truth', + responseSchema: RougeResponseSchema, }, (datapoint) => { if (!datapoint.reference) { @@ -162,6 +173,14 @@ function createRougeEvaluator( ); } +const SafetyResponseSchema = z.object({ + safetyResult: z.object({ + score: z.number(), + explanation: z.string(), + confidence: z.number(), + }), +}); + function createSafetyEvaluator( factory: EvaluatorFactory, metricSpec: any @@ -171,6 +190,7 @@ function createSafetyEvaluator( metric: VertexAIEvaluationMetricType.SAFETY, displayName: 'Safety', definition: 'Assesses the level of safety of an output', + responseSchema: SafetyResponseSchema, }, (datapoint) => { return { @@ -182,7 +202,7 @@ function createSafetyEvaluator( }, }; }, - (response: any, datapoint: BaseDataPoint) => { + (response, datapoint: BaseDataPoint) => { return { testCaseId: datapoint.testCaseId, evaluation: { @@ -196,6 +216,14 @@ function createSafetyEvaluator( ); } +const GroundednessResponseSchema = z.object({ + groundednessResult: z.object({ + score: z.number(), + explanation: z.string(), + confidence: z.number(), + }), +}); + function createGroundednessEvaluator( factory: EvaluatorFactory, metricSpec: any @@ -206,6 +234,7 @@ function createGroundednessEvaluator( displayName: 'Groundedness', definition: 'Assesses the ability to provide or reference information included only in the context', + responseSchema: GroundednessResponseSchema, }, (datapoint) => { return { @@ -218,7 +247,7 @@ function createGroundednessEvaluator( }, }; }, - (response: any, datapoint: BaseDataPoint) => { + (response, datapoint: BaseDataPoint) => { return { testCaseId: datapoint.testCaseId, evaluation: { diff --git a/js/plugins/vertexai/src/evaluator_factory.ts b/js/plugins/vertexai/src/evaluator_factory.ts index d2db0ff6ad..017de6a948 100644 --- a/js/plugins/vertexai/src/evaluator_factory.ts +++ b/js/plugins/vertexai/src/evaluator_factory.ts @@ -19,6 +19,7 @@ import { Action } from '@genkit-ai/core'; import { runInNewSpan } from '@genkit-ai/core/tracing'; import { GoogleAuth } from 'google-auth-library'; import { JSONClient } from 'google-auth-library/build/src/auth/googleauth'; +import z from 'zod'; import { VertexAIEvaluationMetricType } from './evaluation'; export class EvaluatorFactory { @@ -28,14 +29,18 @@ export class EvaluatorFactory { private readonly projectId: string ) {} - create( + create( config: { metric: VertexAIEvaluationMetricType; displayName: string; definition: string; + responseSchema: ResponseType; }, toRequest: (datapoint: BaseDataPoint) => any, - responseHandler: (response: any, datapoint: BaseDataPoint) => any + responseHandler: ( + response: z.infer, + datapoint: BaseDataPoint + ) => any ): Action { return defineEvaluator( { @@ -44,14 +49,21 @@ export class EvaluatorFactory { definition: config.definition, }, async (datapoint: BaseDataPoint) => { - const response = await this.evaluateInstances(toRequest(datapoint)); + const responseSchema = config.responseSchema; + const response = await this.evaluateInstances( + toRequest(datapoint), + responseSchema + ); return responseHandler(response, datapoint); } ); } - async evaluateInstances(partialRequest: any) { + async evaluateInstances( + partialRequest: any, + responseSchema: ResponseType + ): Promise> { const locationName = `projects/${this.projectId}/locations/${this.location}`; return await runInNewSpan( { @@ -64,15 +76,22 @@ export class EvaluatorFactory { location: locationName, ...partialRequest, }; + metadata.input = request; const client = await this.auth.getClient(); + const url = `https://${this.location}-aiplatform.googleapis.com/v1beta1/${locationName}:evaluateInstances`; const response = await client.request({ - url: `https://${this.location}-aiplatform.googleapis.com/v1beta1/${locationName}:evaluateInstances`, + url, method: 'POST', body: JSON.stringify(request), }); metadata.output = response.data; - return response.data as any; + + try { + return responseSchema.parse(response.data); + } catch (e) { + throw new Error(`Error parsing ${url} API response: ${e}`); + } } ); }