From 947adef9a863051e3cb5f5bec1734007961fb33d Mon Sep 17 00:00:00 2001 From: Pavel Jbanov Date: Wed, 4 Sep 2024 22:32:53 -0400 Subject: [PATCH 1/6] feat: allow specifying middleware on the generate function #450 --- js/ai/src/generate.ts | 7 +- js/ai/src/generateAction.ts | 257 ++++++++++++---------- js/ai/src/model.ts | 1 + js/ai/tests/generate/generate_test.ts | 119 +++++++++- js/plugins/dotprompt/src/prompt.ts | 1 + js/plugins/dotprompt/tests/prompt_test.ts | 3 + 6 files changed, 258 insertions(+), 130 deletions(-) diff --git a/js/ai/src/generate.ts b/js/ai/src/generate.ts index 8d412942ec..e7cb20642c 100755 --- a/js/ai/src/generate.ts +++ b/js/ai/src/generate.ts @@ -27,7 +27,7 @@ import { z } from 'zod'; import { DocumentData } from './document.js'; import { extractJson } from './extract.js'; import { - generateAction, + generateHelper, GenerateUtilParamSchema, inferRoleFromParts, } from './generateAction.js'; @@ -41,6 +41,7 @@ import { MessageData, ModelAction, ModelArgument, + ModelMiddleware, ModelReference, Part, ToolDefinition, @@ -490,6 +491,8 @@ export interface GenerateOptions< returnToolRequests?: boolean; /** When provided, models supporting streaming will call the provided callback with chunks as generation progresses. */ streamingCallback?: StreamingCallback; + /** Middlewera to be used with this model call. */ + use?: ModelMiddleware[]; } async function resolveModel(options: GenerateOptions): Promise { @@ -612,7 +615,7 @@ export async function generate< resolvedOptions.streamingCallback, async () => new GenerateResponse( - await generateAction(params), + await generateHelper(params, resolvedOptions.use), await toGenerateRequest(resolvedOptions) ) ); diff --git a/js/ai/src/generateAction.ts b/js/ai/src/generateAction.ts index 2938d1ab2d..95e0a147a4 100644 --- a/js/ai/src/generateAction.ts +++ b/js/ai/src/generateAction.ts @@ -18,6 +18,7 @@ import { Action, defineAction, getStreamingCallback, + Middleware, runWithStreamingCallback, } from '@genkit-ai/core'; import { lookupAction } from '@genkit-ai/core/registry'; @@ -37,7 +38,9 @@ import { import { CandidateData, GenerateRequest, + GenerateRequestSchema, GenerateResponseChunkData, + GenerateResponseData, GenerateResponseSchema, MessageData, MessageSchema, @@ -85,141 +88,163 @@ export const generateAction = defineAction( inputSchema: GenerateUtilParamSchema, outputSchema: GenerateResponseSchema, }, - async (input) => { - const model = (await lookupAction(`/model/${input.model}`)) as ModelAction; - if (!model) { - throw new Error(`Model ${input.model} not found`); - } + async (input) => generateHelper(input) +); - let tools: ToolAction[] | undefined; - if (input.tools?.length) { - if (!model.__action.metadata?.model.supports?.tools) { - throw new Error( - `Model ${input.model} does not support tools, but some tools were supplied to generate(). Please call generate() without tools if you would like to use this model.` - ); - } - tools = await Promise.all( - input.tools.map(async (toolRef) => { - if (typeof toolRef === 'string') { - const tool = (await lookupAction(toolRef)) as ToolAction; - if (!tool) { - throw new Error(`Tool ${toolRef} not found`); - } - return tool; - } - throw ''; - }) +export async function generateHelper( + input: z.infer, + middleware?: Middleware[] +): Promise { + const model = (await lookupAction(`/model/${input.model}`)) as ModelAction; + if (!model) { + throw new Error(`Model ${input.model} not found`); + } + + let tools: ToolAction[] | undefined; + if (input.tools?.length) { + if (!model.__action.metadata?.model.supports?.tools) { + throw new Error( + `Model ${input.model} does not support tools, but some tools were supplied to generate(). Please call generate() without tools if you would like to use this model.` ); } + tools = await Promise.all( + input.tools.map(async (toolRef) => { + if (typeof toolRef === 'string') { + const tool = (await lookupAction(toolRef)) as ToolAction; + if (!tool) { + throw new Error(`Tool ${toolRef} not found`); + } + return tool; + } + throw ''; + }) + ); + } - const request = await actionToGenerateRequest(input, tools); + const request = await actionToGenerateRequest(input, tools); - const accumulatedChunks: GenerateResponseChunkData[] = []; + const accumulatedChunks: GenerateResponseChunkData[] = []; - const streamingCallback = getStreamingCallback(); - const response = await runWithStreamingCallback( - streamingCallback - ? (chunk: GenerateResponseChunkData) => { - // Store accumulated chunk data - accumulatedChunks.push(chunk); - if (streamingCallback) { - streamingCallback!( - new GenerateResponseChunk(chunk, accumulatedChunks) - ); - } + const streamingCallback = getStreamingCallback(); + const response = await runWithStreamingCallback( + streamingCallback + ? (chunk: GenerateResponseChunkData) => { + // Store accumulated chunk data + accumulatedChunks.push(chunk); + if (streamingCallback) { + streamingCallback!( + new GenerateResponseChunk(chunk, accumulatedChunks) + ); } - : undefined, - async () => new GenerateResponse(await model(request)) - ); + } + : undefined, + async () => { + const dispatch = async ( + index: number, + req: z.infer + ) => { + if (!middleware || index === middleware.length) { + // end of the chain, call the original model action + return await model(req); + } - // throw NoValidCandidates if all candidates are blocked or - if ( - !response.candidates.some((c) => - ['stop', 'length'].includes(c.finishReason) - ) - ) { - throw new NoValidCandidatesError({ - message: `All candidates returned finishReason issues: ${JSON.stringify(response.candidates.map((c) => c.finishReason))}`, - response, - }); + const currentMiddleware = middleware[index]; + return currentMiddleware(req, async (modifiedReq) => + dispatch(index + 1, modifiedReq || req) + ); + }; + + return new GenerateResponse(await dispatch(0, request)); } + ); - if (input.output?.jsonSchema && !response.toolRequests()?.length) { - // find a candidate with valid output schema - const candidateErrors = response.candidates.map((c) => { - // don't validate messages that have no text or data - if (c.text() === '' && c.data() === null) return null; + // throw NoValidCandidates if all candidates are blocked or + if ( + !response.candidates.some((c) => + ['stop', 'length'].includes(c.finishReason) + ) + ) { + throw new NoValidCandidatesError({ + message: `All candidates returned finishReason issues: ${JSON.stringify(response.candidates.map((c) => c.finishReason))}`, + response, + }); + } - try { - parseSchema(c.output(), { - jsonSchema: input.output?.jsonSchema, - }); - return null; - } catch (e) { - return e as Error; - } - }); - // if all candidates have a non-null error... - if (candidateErrors.every((c) => !!c)) { - throw new NoValidCandidatesError({ - message: `Generation resulted in no candidates matching provided output schema.${candidateErrors.map((e, i) => `\n\nCandidate[${i}] ${e!.toString()}`)}`, - response, - detail: { - candidateErrors: candidateErrors, - }, + if (input.output?.jsonSchema && !response.toolRequests()?.length) { + // find a candidate with valid output schema + const candidateErrors = response.candidates.map((c) => { + // don't validate messages that have no text or data + if (c.text() === '' && c.data() === null) return null; + + try { + parseSchema(c.output(), { + jsonSchema: input.output?.jsonSchema, }); + return null; + } catch (e) { + return e as Error; } + }); + // if all candidates have a non-null error... + if (candidateErrors.every((c) => !!c)) { + throw new NoValidCandidatesError({ + message: `Generation resulted in no candidates matching provided output schema.${candidateErrors.map((e, i) => `\n\nCandidate[${i}] ${e!.toString()}`)}`, + response, + detail: { + candidateErrors: candidateErrors, + }, + }); } + } - // Pick the first valid candidate. - let selected: Candidate | undefined; - for (const candidate of response.candidates) { - if (isValidCandidate(candidate, tools || [])) { - selected = candidate; - break; - } + // Pick the first valid candidate. + let selected: Candidate | undefined; + for (const candidate of response.candidates) { + if (isValidCandidate(candidate, tools || [])) { + selected = candidate; + break; } + } - if (!selected) { - throw new Error('No valid candidates found'); - } + if (!selected) { + throw new Error('No valid candidates found'); + } - const toolCalls = selected.message.content.filter( - (part) => !!part.toolRequest - ); - if (input.returnToolRequests || toolCalls.length === 0) { - return response.toJSON(); - } - const toolResponses: ToolResponsePart[] = await Promise.all( - toolCalls.map(async (part) => { - if (!part.toolRequest) { - throw Error( - 'Tool request expected but not provided in tool request part' - ); - } - const tool = tools?.find( - (tool) => tool.__action.name === part.toolRequest?.name - ); - if (!tool) { - throw Error('Tool not found'); - } - return { - toolResponse: { - name: part.toolRequest.name, - ref: part.toolRequest.ref, - output: await tool(part.toolRequest?.input), - }, - }; - }) - ); - const nextRequest = { - ...input, - history: [...request.messages, selected.message], - prompt: toolResponses, - }; - return await generateAction(nextRequest); + const toolCalls = selected.message.content.filter( + (part) => !!part.toolRequest + ); + if (input.returnToolRequests || toolCalls.length === 0) { + return response.toJSON(); } -); + const toolResponses: ToolResponsePart[] = await Promise.all( + toolCalls.map(async (part) => { + if (!part.toolRequest) { + throw Error( + 'Tool request expected but not provided in tool request part' + ); + } + const tool = tools?.find( + (tool) => tool.__action.name === part.toolRequest?.name + ); + if (!tool) { + throw Error('Tool not found'); + } + return { + toolResponse: { + name: part.toolRequest.name, + ref: part.toolRequest.ref, + output: await tool(part.toolRequest?.input), + }, + }; + }) + ); + const nextRequest = { + ...input, + history: [...request.messages, selected.message], + prompt: toolResponses, + }; + return await generateHelper(nextRequest); +} async function actionToGenerateRequest( options: z.infer, diff --git a/js/ai/src/model.ts b/js/ai/src/model.ts index 0881d56a0e..2c47d7610c 100644 --- a/js/ai/src/model.ts +++ b/js/ai/src/model.ts @@ -289,6 +289,7 @@ export function defineModel< configSchema?: CustomOptionsSchema; /** Descriptive name for this model e.g. 'Google AI - Gemini Pro'. */ label?: string; + /** Middlewera to be used with this model. */ use?: ModelMiddleware[]; }, runner: ( diff --git a/js/ai/tests/generate/generate_test.ts b/js/ai/tests/generate/generate_test.ts index 67173d6da7..7995aa9736 100644 --- a/js/ai/tests/generate/generate_test.ts +++ b/js/ai/tests/generate/generate_test.ts @@ -14,22 +14,27 @@ * limitations under the License. */ +import { __hardResetRegistryForTesting } from '@genkit-ai/core/registry'; import assert from 'node:assert'; -import { describe, it } from 'node:test'; +import { beforeEach, describe, it } from 'node:test'; import { z } from 'zod'; -import { GenerateResponseChunk, generate } from '../../src/generate'; +import { GenerateResponseChunk } from '../../lib/generate.js'; +import { GenerateResponseChunkData } from '../../lib/model.js'; import { Candidate, GenerateOptions, GenerateResponse, Message, + generate, toGenerateRequest, } from '../../src/generate.js'; -import { GenerateResponseChunkData, defineModel } from '../../src/model'; import { CandidateData, GenerateRequest, MessageData, + ModelAction, + ModelMiddleware, + defineModel, } from '../../src/model.js'; import { defineTool } from '../../src/tool.js'; @@ -582,19 +587,109 @@ describe('GenerateResponseChunk', () => { }); }); -const echo = defineModel( - { name: 'echo', supports: { tools: true } }, - async (input) => ({ - candidates: [ - { index: 0, message: input.messages[0], finishReason: 'stop' }, - ], - }) -); +describe('generate', () => { + beforeEach(__hardResetRegistryForTesting); + + var echoModel: ModelAction; + + beforeEach(() => { + echoModel = defineModel( + { + name: 'echoModel', + }, + async (request) => { + return { + candidates: [ + { + index: 0, + finishReason: 'stop', + message: { + role: 'model', + content: [ + { + text: + 'Echo: ' + + request.messages + .map((m) => m.content.map((c) => c.text).join()) + .join(), + }, + ], + }, + }, + ], + }; + } + ); + }); + + it('applies middleware', async () => { + const wrapRequest: ModelMiddleware = async (req, next) => { + return next({ + ...req, + messages: [ + { + role: 'user', + content: [ + { + text: + '(' + + req.messages + .map((m) => m.content.map((c) => c.text).join()) + .join() + + ')', + }, + ], + }, + ], + }); + }; + const wrapResponse: ModelMiddleware = async (req, next) => { + const res = await next(req); + return { + candidates: [ + { + index: 0, + finishReason: 'stop', + message: { + role: 'model', + content: [ + { + text: + '[' + + res.candidates[0].message.content + .map((c) => c.text) + .join() + + ']', + }, + ], + }, + }, + ], + }; + }; + + const response = await generate({ + prompt: 'banana', + model: echoModel, + use: [wrapRequest, wrapResponse], + }); + + const want = '[Echo: (banana)]'; + assert.deepStrictEqual(response.text(), want); + }); +}); describe('generate', () => { + beforeEach(() => { + defineModel({ name: 'echo', supports: { tools: true } }, async (input) => ({ + candidates: [ + { index: 0, message: input.messages[0], finishReason: 'stop' }, + ], + })); + }); it('should preserve the request in the returned response, enabling toHistory()', async () => { const response = await generate({ - model: echo, + model: 'echo', prompt: 'Testing toHistory', }); diff --git a/js/plugins/dotprompt/src/prompt.ts b/js/plugins/dotprompt/src/prompt.ts index 42aa32dec3..f1c73185b9 100644 --- a/js/plugins/dotprompt/src/prompt.ts +++ b/js/plugins/dotprompt/src/prompt.ts @@ -207,6 +207,7 @@ export class Dotprompt implements PromptMetadata { tools: (options.tools || []).concat(this.tools || []), streamingCallback: options.streamingCallback, returnToolRequests: options.returnToolRequests, + use: options.use, } as GenerateOptions; } diff --git a/js/plugins/dotprompt/tests/prompt_test.ts b/js/plugins/dotprompt/tests/prompt_test.ts index ac012a55d3..f5a3cd8fb7 100644 --- a/js/plugins/dotprompt/tests/prompt_test.ts +++ b/js/plugins/dotprompt/tests/prompt_test.ts @@ -96,14 +96,17 @@ describe('Prompt', () => { const prompt = testPrompt(`Hello {{name}}, how are you?`); const streamingCallback = (c) => console.log(c); + const middleware = []; const rendered = await prompt.render({ input: { name: 'Michael' }, streamingCallback, returnToolRequests: true, + use: middleware, }); assert.strictEqual(rendered.streamingCallback, streamingCallback); assert.strictEqual(rendered.returnToolRequests, true); + assert.strictEqual(rendered.use, middleware); }); it('should support system prompt with history', async () => { From bc6c9c6786536f4058f969a42ad6ffabba28564f Mon Sep 17 00:00:00 2001 From: Pavel Jbanov Date: Wed, 4 Sep 2024 22:35:11 -0400 Subject: [PATCH 2/6] typo --- js/ai/src/generate.ts | 2 +- js/ai/src/model.ts | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/js/ai/src/generate.ts b/js/ai/src/generate.ts index e7cb20642c..0d85ea0cb8 100755 --- a/js/ai/src/generate.ts +++ b/js/ai/src/generate.ts @@ -491,7 +491,7 @@ export interface GenerateOptions< returnToolRequests?: boolean; /** When provided, models supporting streaming will call the provided callback with chunks as generation progresses. */ streamingCallback?: StreamingCallback; - /** Middlewera to be used with this model call. */ + /** Middleware to be used with this model call. */ use?: ModelMiddleware[]; } diff --git a/js/ai/src/model.ts b/js/ai/src/model.ts index 2c47d7610c..732bae896d 100644 --- a/js/ai/src/model.ts +++ b/js/ai/src/model.ts @@ -289,7 +289,7 @@ export function defineModel< configSchema?: CustomOptionsSchema; /** Descriptive name for this model e.g. 'Google AI - Gemini Pro'. */ label?: string; - /** Middlewera to be used with this model. */ + /** Middleware to be used with this model. */ use?: ModelMiddleware[]; }, runner: ( From b9c730e4834924c9c5a32bdbda39b7c9410e769b Mon Sep 17 00:00:00 2001 From: Pavel Jbanov Date: Thu, 5 Sep 2024 07:47:39 -0400 Subject: [PATCH 3/6] added tracing --- js/ai/src/generateAction.ts | 27 ++++++++++++++++++++++++++- 1 file changed, 26 insertions(+), 1 deletion(-) diff --git a/js/ai/src/generateAction.ts b/js/ai/src/generateAction.ts index 95e0a147a4..9b929ff8e7 100644 --- a/js/ai/src/generateAction.ts +++ b/js/ai/src/generateAction.ts @@ -27,6 +27,7 @@ import { toJsonSchema, validateSchema, } from '@genkit-ai/core/schema'; +import { runInNewSpan, SPAN_TYPE_ATTR } from '@genkit-ai/core/tracing'; import { z } from 'zod'; import { DocumentDataSchema } from './document.js'; import { @@ -94,6 +95,30 @@ export const generateAction = defineAction( export async function generateHelper( input: z.infer, middleware?: Middleware[] +): Promise { + // do tracing + return await runInNewSpan( + { + metadata: { + name: 'generate', + }, + labels: { + [SPAN_TYPE_ATTR]: 'helper', + }, + }, + async (metadata) => { + metadata.name = 'generate'; + metadata.input = input; + const output = await generate(input, middleware); + metadata.output = JSON.stringify(output); + return output; + } + ); +} + +async function generate( + input: z.infer, + middleware?: Middleware[] ): Promise { const model = (await lookupAction(`/model/${input.model}`)) as ModelAction; if (!model) { @@ -243,7 +268,7 @@ export async function generateHelper( history: [...request.messages, selected.message], prompt: toolResponses, }; - return await generateHelper(nextRequest); + return await generateHelper(nextRequest, middleware); } async function actionToGenerateRequest( From 4fa3f84280a2b6fedb13898b6b4c529100fa50ae Mon Sep 17 00:00:00 2001 From: Pavel Jbanov Date: Thu, 5 Sep 2024 07:48:11 -0400 Subject: [PATCH 4/6] action call generate --- js/ai/src/generateAction.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/js/ai/src/generateAction.ts b/js/ai/src/generateAction.ts index 9b929ff8e7..beeb3a3f96 100644 --- a/js/ai/src/generateAction.ts +++ b/js/ai/src/generateAction.ts @@ -89,7 +89,7 @@ export const generateAction = defineAction( inputSchema: GenerateUtilParamSchema, outputSchema: GenerateResponseSchema, }, - async (input) => generateHelper(input) + async (input) => generate(input) ); export async function generateHelper( From 42145e3655e0400a5db19bff2a36d6367957a794 Mon Sep 17 00:00:00 2001 From: Pavel Jbanov Date: Mon, 9 Sep 2024 17:36:43 -0400 Subject: [PATCH 5/6] feedback! --- js/ai/src/generateAction.ts | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/js/ai/src/generateAction.ts b/js/ai/src/generateAction.ts index beeb3a3f96..22f161079b 100644 --- a/js/ai/src/generateAction.ts +++ b/js/ai/src/generateAction.ts @@ -92,6 +92,9 @@ export const generateAction = defineAction( async (input) => generate(input) ); +/** + * Encapsulates all generate logic. This is similar to `generateAction` except not an action and can take middleware. + */ export async function generateHelper( input: z.infer, middleware?: Middleware[] @@ -232,7 +235,10 @@ async function generate( } if (!selected) { - throw new Error('No valid candidates found'); + throw new NoValidCandidatesError({ + message: 'No valid candidates found', + response, + }); } const toolCalls = selected.message.content.filter( From 82d9620b79c1690e2727808affb04c41913c927a Mon Sep 17 00:00:00 2001 From: Pavel Jbanov Date: Tue, 10 Sep 2024 09:32:17 -0400 Subject: [PATCH 6/6] format --- js/plugins/google-cloud/tests/traces_test.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/js/plugins/google-cloud/tests/traces_test.ts b/js/plugins/google-cloud/tests/traces_test.ts index 6dd33e2dc8..1e3842f104 100644 --- a/js/plugins/google-cloud/tests/traces_test.ts +++ b/js/plugins/google-cloud/tests/traces_test.ts @@ -24,9 +24,9 @@ import { import { registerFlowStateStore } from '@genkit-ai/core/registry'; import { defineFlow, run } from '@genkit-ai/flow'; import { - googleCloud, __forceFlushSpansForTesting, __getSpanExporterForTesting, + googleCloud, } from '@genkit-ai/google-cloud'; import { ReadableSpan } from '@opentelemetry/sdk-trace-base'; import assert from 'node:assert';