diff --git a/docs/models.md b/docs/models.md
index cf7e83a341..e8cb3b1082 100644
--- a/docs/models.md
+++ b/docs/models.md
@@ -169,6 +169,24 @@ await generate({
 });
 ```
 
+## Retriever context
+
+Documents from a retriever can be passed directly to `generate` to provide
+grounding context:
+
+```javascript
+const docs = await companyPolicyRetriever({ query: question });
+
+await generate({
+  model: geminiPro,
+  prompt: `Answer using the available context from company policy: ${question}`,
+  context: docs,
+});
+```
+
+The document context is automatically appended to the content of the prompt
+sent to the model.
+
 ## Message history
 
 Genkit models support maintaining a history of the messages sent to the model
diff --git a/docs/rag.md b/docs/rag.md
index b7425f602d..bc6db947aa 100644
--- a/docs/rag.md
+++ b/docs/rag.md
@@ -181,7 +181,6 @@ import { configureGenkit } from '@genkit-ai/core';
 import { defineFlow } from '@genkit-ai/flow';
 import { generate } from '@genkit-ai/ai/generate';
 import { retrieve } from '@genkit-ai/ai/retriever';
-import { definePrompt } from '@genkit-ai/dotprompt';
 import {
   devLocalRetrieverRef,
   devLocalVectorstore,
@@ -211,31 +210,11 @@ export const ragFlow = defineFlow(
       query: input,
       options: { k: 3 },
     });
-    const facts = docs.map((d) => d.text());
-
-    const promptGenerator = definePrompt(
-      {
-        name: 'bob-facts',
-        model: 'google-vertex/gemini-pro',
-        input: {
-          schema: z.object({
-            facts: z.array(z.string()),
-            question: z.string(),
-          }),
-        },
-      },
-      '{{#each people}}{{this}}\n\n{{/each}}\n{{question}}'
-    );
-    const prompt = await promptGenerator.generate({
-      input: {
-        facts,
-        question: input,
-      },
-    });
 
     const llmResponse = await generate({
       model: geminiPro,
-      prompt: prompt.text(),
+      prompt: `Answer this question: ${input}`,
+      context: docs,
     });
 
     const output = llmResponse.text();
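A note on what these docs promise: when `context` is supplied and the model does not declare native grounding support, the `augmentWithContext` middleware added in `js/ai/src/model/middleware.ts` below appends a labeled text part to the final user message. A sketch of the message the model ultimately receives, assuming two retrieved documents (the wording comes from `CONTEXT_PREFACE` and the default item template defined below):

```typescript
import { MessageData } from '@genkit-ai/ai/model';

// Sketch: the augmented user message for prompt "How old is Bob?" with two
// retrieved documents. The preface and '- [n]: ...' lines are produced by
// CONTEXT_PREFACE and CONTEXT_ITEM_TEMPLATE (see middleware.ts below).
const augmented: MessageData = {
  role: 'user',
  content: [
    { text: 'Answer this question: How old is Bob?' },
    {
      text:
        '\n\nUse the following information to complete your task:\n\n' +
        '- [0]: Bob lives in Spain.\n' +
        '- [1]: Bob is 42 years old.\n\n',
      metadata: { purpose: 'context' },
    },
  ],
};
```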
diff --git a/js/ai/src/generate.ts b/js/ai/src/generate.ts
index 7b8595a35e..a5fa2a63d0 100755
--- a/js/ai/src/generate.ts
+++ b/js/ai/src/generate.ts
@@ -22,6 +22,7 @@ import {
 } from '@genkit-ai/core';
 import { lookupAction } from '@genkit-ai/core/registry';
 import { toJsonSchema, validateSchema } from '@genkit-ai/core/schema';
+import { DocumentData } from './document.js';
 import { z } from 'zod';
 import { extractJson } from './extract.js';
 import {
@@ -386,36 +387,39 @@ function inferRoleFromParts(parts: Part[]): Role {
 }
 
 export async function toGenerateRequest(
-  prompt: GenerateOptions
+  options: GenerateOptions
 ): Promise<GenerateRequest> {
   const promptMessage: MessageData = { role: 'user', content: [] };
-  if (typeof prompt.prompt === 'string') {
-    promptMessage.content.push({ text: prompt.prompt });
-  } else if (Array.isArray(prompt.prompt)) {
-    promptMessage.role = inferRoleFromParts(prompt.prompt);
-    promptMessage.content.push(...prompt.prompt);
+  if (typeof options.prompt === 'string') {
+    promptMessage.content.push({ text: options.prompt });
+  } else if (Array.isArray(options.prompt)) {
+    promptMessage.role = inferRoleFromParts(options.prompt);
+    promptMessage.content.push(...options.prompt);
   } else {
-    promptMessage.role = inferRoleFromParts([prompt.prompt]);
-    promptMessage.content.push(prompt.prompt);
+    promptMessage.role = inferRoleFromParts([options.prompt]);
+    promptMessage.content.push(options.prompt);
   }
-  const messages: MessageData[] = [...(prompt.history || []), promptMessage];
+  const messages: MessageData[] = [...(options.history || []), promptMessage];
   let tools: Action<any, any>[] | undefined;
-  if (prompt.tools) {
-    tools = await resolveTools(prompt.tools);
+  if (options.tools) {
+    tools = await resolveTools(options.tools);
   }
   const out = {
     messages,
-    candidates: prompt.candidates,
-    config: prompt.config,
+    candidates: options.candidates,
+    config: options.config,
+    context: options.context,
     tools: tools?.map((tool) => toToolDefinition(tool)) || [],
     output: {
       format:
-        prompt.output?.format ||
-        (prompt.output?.schema || prompt.output?.jsonSchema ? 'json' : 'text'),
+        options.output?.format ||
+        (options.output?.schema || options.output?.jsonSchema
+          ? 'json'
+          : 'text'),
       schema: toJsonSchema({
-        schema: prompt.output?.schema,
-        jsonSchema: prompt.output?.jsonSchema,
+        schema: options.output?.schema,
+        jsonSchema: options.output?.jsonSchema,
       }),
     },
   };
 
@@ -431,6 +435,8 @@ export interface GenerateOptions<
   model: ModelArgument<CustomOptions>;
   /** The prompt for which to generate a response. Can be a string for a simple text prompt or one or more parts for multi-modal prompts. */
   prompt: string | Part | Part[];
+  /** Retrieved documents to be used as context for this generation. */
+  context?: DocumentData[];
   /** Conversation history for multi-turn prompting when supported by the underlying model. */
   history?: MessageData[];
   /** List of registered tool names or actions to treat as a tool for this generation if supported by the underlying model. */
@@ -530,29 +536,33 @@ export async function generate<
     | GenerateOptions<O, CustomOptions>
     | PromiseLike<GenerateOptions<O, CustomOptions>>
 ): Promise<GenerateResponse<z.infer<O>>> {
-  const prompt: GenerateOptions<O, CustomOptions> =
+  const resolvedOptions: GenerateOptions<O, CustomOptions> =
     await Promise.resolve(options);
-  const model = await resolveModel(prompt.model);
+  const model = await resolveModel(resolvedOptions.model);
   if (!model) {
-    throw new Error(`Model ${JSON.stringify(prompt.model)} not found`);
+    throw new Error(`Model ${JSON.stringify(resolvedOptions.model)} not found`);
   }
 
   let tools: ToolAction[] | undefined;
-  if (prompt.tools?.length) {
+  if (resolvedOptions.tools?.length) {
     if (!model.__action.metadata?.model.supports?.tools) {
       throw new Error(
-        `Model ${JSON.stringify(prompt.model)} does not support tools, but some tools were supplied to generate(). Please call generate() without tools if you would like to use this model.`
+        `Model ${JSON.stringify(resolvedOptions.model)} does not support tools, but some tools were supplied to generate(). Please call generate() without tools if you would like to use this model.`
       );
     }
-    tools = await resolveTools(prompt.tools);
+    tools = await resolveTools(resolvedOptions.tools);
   }
-  const request = await toGenerateRequest(prompt);
-  telemetry.recordGenerateActionInputLogs(model.__action.name, prompt, request);
+  const request = await toGenerateRequest(resolvedOptions);
+  telemetry.recordGenerateActionInputLogs(
+    model.__action.name,
+    resolvedOptions,
+    request
+  );
   const response = await runWithStreamingCallback(
-    prompt.streamingCallback
+    resolvedOptions.streamingCallback
       ? (chunk: GenerateResponseChunkData) =>
-          prompt.streamingCallback!(new GenerateResponseChunk(chunk))
+          resolvedOptions.streamingCallback!(new GenerateResponseChunk(chunk))
       : undefined,
     async () =>
       new GenerateResponse<z.infer<O>>(await model(request), request)
   );
@@ -569,13 +579,13 @@ export async function generate<
     });
   }
 
-  if (prompt.output?.schema || prompt.output?.jsonSchema) {
+  if (resolvedOptions.output?.schema || resolvedOptions.output?.jsonSchema) {
     // find a candidate with valid output schema
     const candidateValidations = response.candidates.map((c) => {
       try {
         return validateSchema(c.output(), {
-          jsonSchema: prompt.output?.jsonSchema,
-          schema: prompt.output?.schema,
+          jsonSchema: resolvedOptions.output?.jsonSchema,
+          schema: resolvedOptions.output?.schema,
         });
       } catch (e) {
         return {
@@ -612,10 +622,10 @@ export async function generate<
   const toolCalls = selected.message.content.filter(
     (part) => !!part.toolRequest
   );
-  if (prompt.returnToolRequests || toolCalls.length === 0) {
+  if (resolvedOptions.returnToolRequests || toolCalls.length === 0) {
     telemetry.recordGenerateActionOutputLogs(
       model.__action.name,
-      prompt,
+      resolvedOptions,
       response
     );
     return response;
@@ -642,10 +652,10 @@ export async function generate<
       };
     })
   );
-  prompt.history = request.messages;
-  prompt.history.push(selected.message);
-  prompt.prompt = toolResponses;
-  return await generate(prompt);
+  resolvedOptions.history = request.messages;
+  resolvedOptions.history.push(selected.message);
+  resolvedOptions.prompt = toolResponses;
+  return await generate(resolvedOptions);
 }
 
 export type GenerateStreamOptions<
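The `generate.ts` changes keep retrieved documents out of the prompt string itself: `toGenerateRequest` copies `options.context` onto the outgoing `GenerateRequest` and leaves augmentation to middleware. A minimal sketch of the resulting request shape (`geminiPro` stands in for a configured model reference):

```typescript
import { toGenerateRequest } from '@genkit-ai/ai';

const request = await toGenerateRequest({
  model: geminiPro, // assumed: a configured model reference
  prompt: 'How old is Bob?',
  context: [{ content: [{ text: 'Bob is 42 years old.' }] }],
});

// The documents ride along on the request; middleware decides later
// whether to merge them into the final user message.
console.log(request.context?.length); // 1
console.log(request.messages.length); // 1 (just the user prompt)
```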
diff --git a/js/ai/src/model.ts b/js/ai/src/model.ts
index 2758a485fb..10036ea621 100644
--- a/js/ai/src/model.ts
+++ b/js/ai/src/model.ts
@@ -24,7 +24,12 @@ import {
 import { toJsonSchema } from '@genkit-ai/core/schema';
 import { performance } from 'node:perf_hooks';
 import { z } from 'zod';
-import { conformOutput, validateSupport } from './model/middleware.js';
+import { DocumentDataSchema } from './document.js';
+import {
+  augmentWithContext,
+  conformOutput,
+  validateSupport,
+} from './model/middleware.js';
 import * as telemetry from './telemetry.js';
 
 //
@@ -127,6 +132,8 @@ export const ModelInfoSchema = z.object({
       systemRole: z.boolean().optional(),
       /** Model can output this type of data. */
       output: z.array(OutputFormatSchema).optional(),
+      /** Model can natively support document-based context grounding. */
+      context: z.boolean().optional(),
     })
     .optional(),
 });
@@ -166,6 +173,7 @@ export const GenerateRequestSchema = z.object({
   config: z.any().optional(),
   tools: z.array(ToolDefinitionSchema).optional(),
   output: OutputConfigSchema.optional(),
+  context: z.array(DocumentDataSchema).optional(),
   candidates: z.number().optional(),
 });
 
@@ -264,11 +272,12 @@ export function defineModel<
   ) => Promise<GenerateResponseData>
 ): ModelAction<CustomOptionsSchema> {
   const label = options.label || `${options.name} GenAI model`;
-  const middleware = [
+  const middleware: ModelMiddleware[] = [
     ...(options.use || []),
     validateSupport(options),
-    conformOutput(),
   ];
+  if (!options?.supports?.context) middleware.push(augmentWithContext());
+  middleware.push(conformOutput());
   const act = defineAction(
     {
       actionType: 'model',
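The `supports.context` flag above is what lets a model opt out of automatic augmentation: when it is unset, `defineModel` splices `augmentWithContext()` into the middleware chain ahead of `conformOutput()`. A hypothetical model definition that handles grounding natively (name and runner invented for illustration):

```typescript
import { defineModel } from '@genkit-ai/ai/model';

// Hypothetical model. Declaring `supports: { context: true }` means
// defineModel will NOT insert augmentWithContext(); request.context
// reaches the runner untouched for native handling.
const groundedModel = defineModel(
  {
    name: 'example/grounded-model',
    supports: { context: true },
  },
  async (request) => {
    // request.context holds the raw DocumentData[] here.
    return { candidates: [] };
  }
);
```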
diff --git a/js/ai/src/model/middleware.ts b/js/ai/src/model/middleware.ts
index 866046896e..f15c794568 100644
--- a/js/ai/src/model/middleware.ts
+++ b/js/ai/src/model/middleware.ts
@@ -14,6 +14,7 @@
  * limitations under the License.
  */
 
+import { Document } from '../document.js';
 import { ModelInfo, ModelMiddleware, Part } from '../model.js';
 
 /**
@@ -171,3 +172,54 @@ export function simulateSystemPrompt(options?: {
     return next({ ...req, messages });
   };
 }
+
+export interface AugmentWithContextOptions {
+  /** Preceding text to place before the rendered context documents. */
+  preface?: string | null;
+  /** A function to render a document into a text part to be included in the message. */
+  itemTemplate?: (d: Document, index: number, options?: AugmentWithContextOptions) => string;
+  /** The metadata key to use for citation reference. Pass `null` to provide no citations. */
+  citationKey?: string | null;
+}
+
+export const CONTEXT_PREFACE =
+  '\n\nUse the following information to complete your task:\n\n';
+const CONTEXT_ITEM_TEMPLATE = (
+  d: Document,
+  index: number,
+  options?: AugmentWithContextOptions
+) => {
+  let out = '- ';
+  if (options?.citationKey) {
+    out += `[${d.metadata![options.citationKey]}]: `;
+  } else if (options?.citationKey === undefined) {
+    out += `[${d.metadata?.['ref'] || d.metadata?.['id'] || index}]: `;
+  }
+  out += d.text() + '\n';
+  return out;
+};
+export function augmentWithContext(
+  options?: AugmentWithContextOptions
+): ModelMiddleware {
+  const preface =
+    typeof options?.preface === 'undefined' ? CONTEXT_PREFACE : options.preface;
+  const itemTemplate = options?.itemTemplate || CONTEXT_ITEM_TEMPLATE;
+  const citationKey = options?.citationKey;
+  return (req, next) => {
+    // if there is no context in the request, no-op
+    if (!req.context?.length) return next(req);
+    const userMessage = req.messages.at(-1);
+    // if there are no messages, no-op
+    if (!userMessage) return next(req);
+    // if there is already a context part, no-op
+    if (userMessage?.content.find((p) => p.metadata?.purpose === 'context'))
+      return next(req);
+    let out = `${preface || ''}`;
+    req.context?.forEach((d, i) => {
+      out += itemTemplate(new Document(d), i, options);
+    });
+    out += '\n';
+    userMessage.content.push({ text: out, metadata: { purpose: 'context' } });
+    return next(req);
+  };
+}
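Because `augmentWithContext` is exported, it can also be exercised directly, exactly as the unit tests below do. A sketch with a custom `citationKey` (the package import path is an assumption; the tests import from source):

```typescript
import { augmentWithContext } from '@genkit-ai/ai/model/middleware';

// Apply the middleware to a hand-built request, as the unit tests below do.
await augmentWithContext({ citationKey: 'uid' })(
  {
    messages: [{ role: 'user', content: [{ text: 'How old is Bob?' }] }],
    context: [
      { content: [{ text: 'Bob is 42.' }], metadata: { uid: 'hr-profile' } },
    ],
  },
  async (req) => {
    // The final user message now ends with a part whose text is
    // CONTEXT_PREFACE + '- [hr-profile]: Bob is 42.\n\n'
    // and whose metadata is { purpose: 'context' }.
    console.log(req?.messages.at(-1)?.content.at(-1)?.text);
    return { candidates: [] };
  }
);
```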
diff --git a/js/ai/src/prompt.ts b/js/ai/src/prompt.ts
index df62c30e0d..c2050b8c36 100644
--- a/js/ai/src/prompt.ts
+++ b/js/ai/src/prompt.ts
@@ -16,6 +16,7 @@
 
 import { Action, defineAction, JSONSchema7 } from '@genkit-ai/core';
 import { lookupAction } from '@genkit-ai/core/registry';
+import { DocumentData } from './document.js';
 import z from 'zod';
 import { GenerateOptions } from './generate';
 import { GenerateRequest, GenerateRequestSchema, ModelArgument } from './model';
@@ -35,6 +36,13 @@ export type PromptAction<I extends z.ZodTypeAny = z.ZodTypeAny> = Action<
   };
 };
 
+export function isPrompt(arg: any): boolean {
+  return (
+    typeof arg === 'function' &&
+    (arg as any).__action?.metadata?.type === 'prompt'
+  );
+}
+
 export function definePrompt<I extends z.ZodTypeAny>(
   {
     name,
@@ -79,6 +87,7 @@ export async function renderPrompt<
 >(params: {
   prompt: PromptArgument<I>;
   input: z.infer<I>;
+  context?: DocumentData[];
   model: ModelArgument<CustomOptions>;
   config?: z.infer<CustomOptions>;
 }): Promise<GenerateOptions> {
@@ -94,5 +103,6 @@ export async function renderPrompt<
     config: { ...(rendered.config || {}), ...params.config },
     history: rendered.messages.slice(0, rendered.messages.length - 1),
     prompt: rendered.messages[rendered.messages.length - 1].content,
+    context: params.context,
   };
 }
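`renderPrompt` now returns `GenerateOptions` that carry the documents, so a registered prompt can be grounded without touching its template. A sketch under assumed names (`answerPrompt`, `docs`, and `geminiPro` are stand-ins for a `PromptAction`, retrieved documents, and a configured model):

```typescript
import { generate, renderPrompt } from '@genkit-ai/ai';

// generate() accepts a PromiseLike<GenerateOptions>, so the rendered
// prompt can be passed straight through.
const response = await generate(
  renderPrompt({
    prompt: answerPrompt, // a PromptAction defined elsewhere
    input: { question: 'How old is Bob?' },
    context: docs, // DocumentData[] from a retriever
    model: geminiPro,
  })
);
```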
diff --git a/js/ai/tests/model/middleware_test.ts b/js/ai/tests/model/middleware_test.ts
index ee191d8f5d..ec5ceae847 100644
--- a/js/ai/tests/model/middleware_test.ts
+++ b/js/ai/tests/model/middleware_test.ts
@@ -16,13 +16,18 @@
 
 import assert from 'node:assert';
 import { beforeEach, describe, it } from 'node:test';
+import { DocumentData } from '../../src/document.js';
 import {
   GenerateRequest,
   GenerateResponseData,
+  MessageData,
   Part,
   defineModel,
 } from '../../src/model.js';
 import {
+  AugmentWithContextOptions,
+  CONTEXT_PREFACE,
+  augmentWithContext,
   simulateSystemPrompt,
   validateSupport,
 } from '../../src/model/middleware.js';
@@ -258,3 +263,160 @@ describe('simulateSystemPrompt', () => {
     });
   });
 });
+
+describe('augmentWithContext', () => {
+  async function testRequest(
+    messages: MessageData[],
+    context?: DocumentData[],
+    options?: AugmentWithContextOptions
+  ) {
+    const changedRequest = await new Promise<GenerateRequest>(
+      (resolve, reject) => {
+        augmentWithContext(options)(
+          {
+            messages,
+            context,
+          },
+          resolve as any
+        );
+      }
+    );
+    return changedRequest.messages;
+  }
+
+  it('should not change a message with empty context', async () => {
+    const messages: MessageData[] = [
+      { role: 'user', content: [{ text: 'first part' }] },
+    ];
+    assert.deepEqual(await testRequest(messages, undefined), messages);
+    assert.deepEqual(await testRequest(messages, []), messages);
+  });
+
+  it('should not change a message that already has a context part', async () => {
+    const messages: MessageData[] = [
+      {
+        role: 'user',
+        content: [{ text: 'first part', metadata: { purpose: 'context' } }],
+      },
+    ];
+    assert.deepEqual(
+      await testRequest(messages, [{ content: [{ text: 'i am context' }] }]),
+      messages
+    );
+  });
+
+  it('should append a new text part', async () => {
+    const messages: MessageData[] = [
+      { role: 'user', content: [{ text: 'first part' }] },
+    ];
+    const result = await testRequest(messages, [
+      { content: [{ text: 'i am context' }] },
+      { content: [{ text: 'i am more context' }] },
+    ]);
+    assert.deepEqual(result[0].content.at(-1), {
+      text: `${CONTEXT_PREFACE}- [0]: i am context\n- [1]: i am more context\n\n`,
+      metadata: { purpose: 'context' },
+    });
+  });
+
+  it('should use a custom preface', async () => {
+    const messages: MessageData[] = [
+      { role: 'user', content: [{ text: 'first part' }] },
+    ];
+    const result = await testRequest(
+      messages,
+      [
+        { content: [{ text: 'i am context' }] },
+        { content: [{ text: 'i am more context' }] },
+      ],
+      { preface: '\n\nCheck this out:\n\n' }
+    );
+    assert.deepEqual(result[0].content.at(-1), {
+      text: '\n\nCheck this out:\n\n- [0]: i am context\n- [1]: i am more context\n\n',
+      metadata: { purpose: 'context' },
+    });
+  });
+
+  it('should elide a null preface', async () => {
+    const messages: MessageData[] = [
+      { role: 'user', content: [{ text: 'first part' }] },
+    ];
+    const result = await testRequest(
+      messages,
+      [
+        { content: [{ text: 'i am context' }] },
+        { content: [{ text: 'i am more context' }] },
+      ],
+      { preface: null }
+    );
+    assert.deepEqual(result[0].content.at(-1), {
+      text: '- [0]: i am context\n- [1]: i am more context\n\n',
+      metadata: { purpose: 'context' },
+    });
+  });
+
+  it('should use a citationKey', async () => {
+    const messages: MessageData[] = [
+      { role: 'user', content: [{ text: 'first part' }] },
+    ];
+    const result = await testRequest(
+      messages,
+      [
+        { content: [{ text: 'i am context' }], metadata: { uid: 'first' } },
+        {
+          content: [{ text: 'i am more context' }],
+          metadata: { uid: 'second' },
+        },
+      ],
+      { citationKey: 'uid' }
+    );
+    assert.deepEqual(result[0].content.at(-1), {
+      text: `${CONTEXT_PREFACE}- [first]: i am context\n- [second]: i am more context\n\n`,
+      metadata: { purpose: 'context' },
+    });
+  });
+
+  it('should use "ref", "id", and index, in that order, if citationKey is unspecified', async () => {
+    const messages: MessageData[] = [
+      { role: 'user', content: [{ text: 'first part' }] },
+    ];
+    const result = await testRequest(messages, [
+      {
+        content: [{ text: 'i am context' }],
+        metadata: { ref: 'first', id: 'wrong' },
+      },
+      {
+        content: [{ text: 'i am more context' }],
+        metadata: { id: 'second' },
+      },
+      {
+        content: [{ text: 'i am even more context' }],
+      },
+    ]);
+    assert.deepEqual(result[0].content.at(-1), {
+      text: `${CONTEXT_PREFACE}- [first]: i am context\n- [second]: i am more context\n- [2]: i am even more context\n\n`,
+      metadata: { purpose: 'context' },
+    });
+  });
+
+  it('should use a custom itemTemplate', async () => {
+    const messages: MessageData[] = [
+      { role: 'user', content: [{ text: 'first part' }] },
+    ];
+    const result = await testRequest(
+      messages,
+      [
+        { content: [{ text: 'i am context' }], metadata: { uid: 'first' } },
+        {
+          content: [{ text: 'i am more context' }],
+          metadata: { uid: 'second' },
+        },
+      ],
+      { itemTemplate: (d) => `* (${d.metadata!.uid}) -- ${d.text()}\n` }
+    );
+    assert.deepEqual(result[0].content.at(-1), {
+      text: `${CONTEXT_PREFACE}* (first) -- i am context\n* (second) -- i am more context\n\n`,
+      metadata: { purpose: 'context' },
+    });
+  });
+});
diff --git a/js/dotprompt/src/prompt.ts b/js/dotprompt/src/prompt.ts
index ba27ce479b..9b57cc58b6 100644
--- a/js/dotprompt/src/prompt.ts
+++ b/js/dotprompt/src/prompt.ts
@@ -25,6 +25,7 @@ import {
   toGenerateRequest,
 } from '@genkit-ai/ai';
 import { GenerationCommonConfigSchema, MessageData } from '@genkit-ai/ai/model';
+import { DocumentData } from '@genkit-ai/ai/retriever';
 import { GenkitError } from '@genkit-ai/core';
 import { parseSchema } from '@genkit-ai/core/schema';
 import { createHash } from 'crypto';
@@ -113,8 +114,8 @@ export class Dotprompt<Variables = unknown> implements PromptMetadata {
     this._render = compile(this.template, options);
   }
 
-  renderText(input: Variables): string {
-    const result = this.renderMessages(input);
+  renderText(input: Variables, context?: DocumentData[]): string {
+    const result = this.renderMessages(input, context);
     if (result.length !== 1) {
       throw new Error("Multi-message prompt can't be rendered as text.");
     }
@@ -128,7 +129,7 @@
     return out;
   }
 
-  renderMessages(input?: Variables): MessageData[] {
+  renderMessages(input?: Variables, context?: DocumentData[]): MessageData[] {
     input = parseSchema(input, {
       schema: this.input?.schema,
       jsonSchema: this.input?.jsonSchema,
@@ -152,8 +153,7 @@
         prompt: this.toJSON(),
       },
     },
-    async (input: Variables) =>
-      toGenerateRequest(this.render({ input: input }))
+    async (input?: Variables) => toGenerateRequest(this.render({ input }))
   );
 }
 
@@ -174,6 +174,7 @@
       config: { ...this.config, ...options.config } || {},
       history: messages.slice(0, messages.length - 1),
       prompt: messages[messages.length - 1].content,
+      context: options.context,
      candidates: options.candidates || this.candidates || 1,
      output: {
        format: options.output?.format || this.output?.format || undefined,
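On the dotprompt side, the render path now forwards `options.context` into the returned `GenerateOptions`, so documents can accompany a prompt file. A sketch, assuming a hypothetical registered `qa.prompt` file and `docs` from a retriever:

```typescript
import { prompt } from '@genkit-ai/dotprompt';

const qaPrompt = await prompt('qa'); // hypothetical registered prompt
const result = await qaPrompt.generate({
  input: { question: 'How old is Bob?' },
  context: docs, // forwarded into GenerateOptions by the change above
});
console.log(result.text());
```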
diff --git a/js/dotprompt/src/template.ts b/js/dotprompt/src/template.ts
index 7461bfadb4..9bbac75d74 100644
--- a/js/dotprompt/src/template.ts
+++ b/js/dotprompt/src/template.ts
@@ -14,7 +14,14 @@
  * limitations under the License.
  */
 
-import { MediaPart, MessageData, Part, Role } from '@genkit-ai/ai/model';
+import {
+  MediaPart,
+  MessageData,
+  Part,
+  Role,
+  TextPart,
+} from '@genkit-ai/ai/model';
+import { DocumentData } from '@genkit-ai/ai/retriever';
 import Handlebars from 'handlebars';
 
 import { PromptMetadata } from './metadata.js';
@@ -72,23 +79,30 @@ function toMessages(renderedString: string): MessageData[] {
   }));
 }
 
-const MEDIA_REGEX = /(<<<dotprompt:media:url.*?)>>>/g;
+const PART_REGEX = /(<<<dotprompt:(?:media:url|section).*?)>>>/g;
 
 function toParts(source: string): Part[] {
   const parts: Part[] = [];
-  for (const piece of source
-    .split(MEDIA_REGEX)
-    .filter((s) => s.trim() !== '')) {
+  const pieces = source
+    .split(PART_REGEX)
+    .filter((s) => s.trim() !== '');
+  for (let i = 0; i < pieces.length; i++) {
+    const piece = pieces[i];
     if (piece.startsWith('<<<dotprompt:media:')) {
       const [_, url, contentType] = piece.split(' ');
       const part: MediaPart = {
         media: { url },
       };
       if (contentType) part.media.contentType = contentType;
       parts.push(part);
+    } else if (piece.startsWith('<<<dotprompt:section')) {
+      const [_, sectionType] = piece.split(' ');
+      parts.push({
+        metadata: { purpose: sectionType },
+      } as TextPart);
     } else {
       parts.push({ text: piece });
     }
   }
   return parts;
 }
@@ -107,11 +121,13 @@ export function compile<Variables = any>(
 
   return (
     input: Variables,
-    options?: { context?: any[]; history?: MessageData[] }
+    options?: { context?: DocumentData[]; history?: MessageData[] }
   ) => {
     const renderedString = renderString(input, {
-      data: { prompt: metadata, context: options?.context || null },
+      data: {
+        metadata: { prompt: metadata, context: options?.context || null },
+      },
     });
     return toMessages(renderedString);
   };
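With `compile` now exposing the prompt metadata and documents to templates through the Handlebars `@metadata` data channel, `renderMessages` can thread context all the way into a template render. A sketch using `defineDotprompt` with an invented prompt and inputs:

```typescript
import { defineDotprompt } from '@genkit-ai/dotprompt';
import { z } from 'zod';

// Hypothetical prompt; documents passed to renderMessages are reachable in
// the template via the @metadata data set up by compile() above.
const qa = defineDotprompt(
  {
    name: 'qa',
    model: 'google-vertex/gemini-pro',
    input: { schema: z.object({ question: z.string() }) },
  },
  'Answer this question: {{question}}'
);

const messages = qa.renderMessages({ question: 'How old is Bob?' }, [
  { content: [{ text: 'Bob is 42 years old.' }] },
]);
```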