From 55c819f1f5b9e3c3cd705fdf66f779bf0e95fb99 Mon Sep 17 00:00:00 2001 From: Michael Bleigh Date: Mon, 6 May 2024 23:47:17 -0700 Subject: [PATCH 01/13] [WIP] Adds context as first-class feature of Dotprompt. --- js/dotprompt/src/prompt.ts | 41 ++++++++++++++---- js/dotprompt/src/template.ts | 50 +++++++++++++++++---- js/dotprompt/tests/template_test.ts | 67 ++++++++++++++++++++++++++++- 3 files changed, 140 insertions(+), 18 deletions(-) diff --git a/js/dotprompt/src/prompt.ts b/js/dotprompt/src/prompt.ts index ba27ce479b..880e491c08 100644 --- a/js/dotprompt/src/prompt.ts +++ b/js/dotprompt/src/prompt.ts @@ -9,7 +9,7 @@ * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +cd * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ @@ -25,8 +25,9 @@ import { toGenerateRequest, } from '@genkit-ai/ai'; import { GenerationCommonConfigSchema, MessageData } from '@genkit-ai/ai/model'; +import { DocumentData, DocumentDataSchema } from '@genkit-ai/ai/retriever'; import { GenkitError } from '@genkit-ai/core'; -import { parseSchema } from '@genkit-ai/core/schema'; +import { parseSchema, toJsonSchema } from '@genkit-ai/core/schema'; import { createHash } from 'crypto'; import fm, { FrontMatterResult } from 'front-matter'; import z from 'zod'; @@ -46,6 +47,7 @@ export type PromptGenerateOptions = Omit< > & { model?: string; input?: V; + context?: DocumentData[]; }; export class Dotprompt implements PromptMetadata { @@ -113,8 +115,8 @@ export class Dotprompt implements PromptMetadata { this._render = compile(this.template, options); } - renderText(input: Variables): string { - const result = this.renderMessages(input); + renderText(input: Variables, context?: DocumentData[]): string { + const result = this.renderMessages(input, context); if (result.length !== 1) { throw new Error("Multi-message prompt can't be rendered as text."); } @@ -128,7 +130,7 @@ export class Dotprompt implements PromptMetadata { return out; } - renderMessages(input?: Variables): MessageData[] { + renderMessages(input?: Variables, context?: DocumentData[]): MessageData[] { input = parseSchema(input, { schema: this.input?.schema, jsonSchema: this.input?.jsonSchema, @@ -145,15 +147,36 @@ export class Dotprompt implements PromptMetadata { { name: `${this.name}${this.variant ? `.${this.variant}` : ''}`, description: 'Defined by Dotprompt', - inputSchema: this.input?.schema, - inputJsonSchema: this.input?.jsonSchema, + inputSchema: this.input?.schema + ? z.object({ + input: this.input.schema, + context: z.array(DocumentDataSchema).optional(), + }) + : undefined, + inputJsonSchema: this.input?.jsonSchema + ? { + type: 'object', + properties: { + input: this.input?.jsonSchema, + context: { + type: 'array', + items: toJsonSchema({ schema: DocumentDataSchema }), + }, + }, + } + : undefined, metadata: { type: 'prompt', prompt: this.toJSON(), }, }, - async (input: Variables) => - toGenerateRequest(this.render({ input: input })) + async ({ + input, + context, + }: { + input?: Variables; + context?: DocumentData[]; + }) => toGenerateRequest(this.render({ input, context })) ); } diff --git a/js/dotprompt/src/template.ts b/js/dotprompt/src/template.ts index 7461bfadb4..c00abc5c23 100644 --- a/js/dotprompt/src/template.ts +++ b/js/dotprompt/src/template.ts @@ -14,7 +14,14 @@ * limitations under the License. */ -import { MediaPart, MessageData, Part, Role } from '@genkit-ai/ai/model'; +import { + MediaPart, + MessageData, + Part, + Role, + TextPart, +} from '@genkit-ai/ai/model'; +import { Document, DocumentData } from '@genkit-ai/ai/retriever'; import Handlebars from 'handlebars'; import { PromptMetadata } from './metadata.js'; @@ -41,6 +48,23 @@ function mediaHelper(options: Handlebars.HelperOptions) { } Promptbars.registerHelper('media', mediaHelper); +function contextHelper(options: Handlebars.HelperOptions) { + const context = options.data?.metadata?.context || []; + const items = context.map((d: DocumentData, i) => { + let text = new Document(d).text(); + if (options.hash.cite === true) { + text += `[${i}]`; + } else if (d.metadata?.[options.hash.cite]) { + text += `[${d.metadata[options.hash.cite]}]`; + } + return Promptbars.escapeExpression(text); + }); + + return new Promptbars.SafeString(`<<>> +- ${items.join('\n- ')}`); +} +Promptbars.registerHelper('context', contextHelper); + const ROLE_REGEX = /(<<>>/g; function toMessages(renderedString: string): MessageData[] { @@ -72,23 +96,30 @@ function toMessages(renderedString: string): MessageData[] { })); } -const MEDIA_REGEX = /(<<>>/g; +const PART_REGEX = /(<<>>/g; function toParts(source: string): Part[] { const parts: Part[] = []; - for (const piece of source - .split(MEDIA_REGEX) - .filter((s) => s.trim() !== '')) { + const pieces = source.split(PART_REGEX).filter((s) => s.trim() !== ''); + for (let i = 0; i < pieces.length; i++) { + const piece = pieces[i]; if (piece.startsWith('<<( media: true, role: true, history: true, + context: true, }, - knownHelpersOnly: true, + // knownHelpersOnly: true, }); return ( input: Variables, - options?: { context?: any[]; history?: MessageData[] } + options?: { context?: DocumentData[]; history?: MessageData[] } ) => { const renderedString = renderString(input, { - data: { prompt: metadata, context: options?.context || null }, + data: { + metadata: { prompt: metadata, context: options?.context || null }, + }, }); return toMessages(renderedString); }; diff --git a/js/dotprompt/tests/template_test.ts b/js/dotprompt/tests/template_test.ts index ba01191fad..812abe454a 100644 --- a/js/dotprompt/tests/template_test.ts +++ b/js/dotprompt/tests/template_test.ts @@ -106,10 +106,75 @@ describe('compile', () => { template: '{{json . indent=2}}', want: [{ role: 'user', content: [{ text: '{\n "test": true\n}' }] }], }, + { + should: 'allow defining context', + input: {}, + context: [ + { content: [{ text: 'abc' }, { text: 'def' }] }, + { content: [{ text: 'hgi' }] }, + ], + template: '{{context}}', + want: [ + { + role: 'user', + content: [ + { text: '\n- abcdef\n- hgi', metadata: { purpose: 'context' } }, + ], + }, + ], + }, + { + should: 'allow defining context with custom citations', + input: {}, + context: [ + { + content: [{ text: 'abc' }, { text: 'def' }], + metadata: { ref: 'first' }, + }, + { content: [{ text: 'hgi' }], metadata: { ref: 'second' } }, + ], + template: '{{context cite="ref"}}', + want: [ + { + role: 'user', + content: [ + { + text: '\n- abcdef[first]\n- hgi[second]', + metadata: { purpose: 'context' }, + }, + ], + }, + ], + }, + { + should: 'allow defining context with numbered citations', + input: {}, + context: [ + { + content: [{ text: 'abc' }, { text: 'def' }], + metadata: { ref: 'first' }, + }, + { content: [{ text: 'hgi' }], metadata: { ref: 'second' } }, + ], + template: '{{context cite=true}}', + want: [ + { + role: 'user', + content: [ + { + text: '\n- abcdef[0]\n- hgi[1]', + metadata: { purpose: 'context' }, + }, + ], + }, + ], + }, ]) { it(test.should, () => { assert.deepEqual( - compile(test.template, { model: 'test/example' })(test.input), + compile(test.template, { model: 'test/example' })(test.input, { + context: test.context, + }), test.want ); }); From 94ff913680a8f4a9d12b5840a23f33214239b8bf Mon Sep 17 00:00:00 2001 From: Michael Bleigh Date: Tue, 7 May 2024 12:05:29 -0700 Subject: [PATCH 02/13] move context to ai package --- js/ai/src/model.ts | 11 +- js/ai/src/model/middleware.ts | 51 +++++++++ js/ai/tests/model/middleware_test.ts | 164 +++++++++++++++++++++++++++ 3 files changed, 223 insertions(+), 3 deletions(-) diff --git a/js/ai/src/model.ts b/js/ai/src/model.ts index 2758a485fb..386bba6a81 100644 --- a/js/ai/src/model.ts +++ b/js/ai/src/model.ts @@ -24,7 +24,8 @@ import { import { toJsonSchema } from '@genkit-ai/core/schema'; import { performance } from 'node:perf_hooks'; import { z } from 'zod'; -import { conformOutput, validateSupport } from './model/middleware.js'; +import { DocumentDataSchema } from './document.js'; +import { augmentWithContext, conformOutput, validateSupport } from './model/middleware.js'; import * as telemetry from './telemetry.js'; // @@ -127,6 +128,8 @@ export const ModelInfoSchema = z.object({ systemRole: z.boolean().optional(), /** Model can output this type of data. */ output: z.array(OutputFormatSchema).optional(), + /** Model can natively support document-based context grounding. */ + context: z.boolean().optional(), }) .optional(), }); @@ -166,6 +169,7 @@ export const GenerateRequestSchema = z.object({ config: z.any().optional(), tools: z.array(ToolDefinitionSchema).optional(), output: OutputConfigSchema.optional(), + context: z.array(DocumentDataSchema).optional(), candidates: z.number().optional(), }); @@ -264,11 +268,12 @@ export function defineModel< ) => Promise ): ModelAction { const label = options.label || `${options.name} GenAI model`; - const middleware = [ + const middleware: ModelMiddleware[] = [ ...(options.use || []), validateSupport(options), - conformOutput(), ]; + if (!options?.supports?.context) middleware.push(augmentWithContext()); + middleware.push(conformOutput()); const act = defineAction( { actionType: 'model', diff --git a/js/ai/src/model/middleware.ts b/js/ai/src/model/middleware.ts index 866046896e..3d3d0b78d7 100644 --- a/js/ai/src/model/middleware.ts +++ b/js/ai/src/model/middleware.ts @@ -14,6 +14,7 @@ * limitations under the License. */ +import { Document } from '../document.js'; import { ModelInfo, ModelMiddleware, Part } from '../model.js'; /** @@ -171,3 +172,53 @@ export function simulateSystemPrompt(options?: { return next({ ...req, messages }); }; } + +export interface AugmentWithContextOptions { + /** Preceding text to place before the rendered context documents. */ + preface?: string | null; + /** A function to render a document into a text part to be included in the message. */ + itemTemplate?: (d: Document, options?: AugmentWithContextOptions) => string; + /** The metadata key to use for citation reference. Pass `null` to provide no citations. */ + citationKey?: string | null; +} + +const CONTEXT_PREFACE = + 'Use the following information to complete your task:\n\n'; +const CONTEXT_ITEM_TEMPLATE = ( + d: Document, + index: number, + options?: AugmentWithContextOptions +) => { + let out = '- '; + if (options?.citationKey) { + out += `[${d.metadata![options.citationKey]}]: `; + } else if (options?.citationKey === undefined) { + out += `[${d.metadata?.['ref'] || d.metadata?.['id'] || index}]: `; + } + out += d.text() + '\n'; + return out; +}; +export function augmentWithContext( + options?: AugmentWithContextOptions +): ModelMiddleware { + const preface = + typeof options?.preface === 'undefined' ? CONTEXT_PREFACE : options.preface; + const itemTemplate = options?.itemTemplate || CONTEXT_ITEM_TEMPLATE; + const citationKey = options?.citationKey; + return (req, next) => { + // if there is no context in the request, no-op + if (!req.context?.length) return next(req); + const userMessage = req.messages.at(-1); + // if there are no messages, no-op + if (!userMessage) return next(req); + // if there is already a context part, no-op + if (userMessage?.content.find((p) => p.metadata?.purpose === 'context')) + return next(req); + let out = `${preface || ''}`; + req.context?.forEach((d, i) => { + out += itemTemplate(new Document(d), i, options); + }); + userMessage.content.push({ text: out, metadata: { purpose: 'context' } }); + return next(req); + }; +} diff --git a/js/ai/tests/model/middleware_test.ts b/js/ai/tests/model/middleware_test.ts index ee191d8f5d..0ba79ea2e4 100644 --- a/js/ai/tests/model/middleware_test.ts +++ b/js/ai/tests/model/middleware_test.ts @@ -16,13 +16,17 @@ import assert from 'node:assert'; import { beforeEach, describe, it } from 'node:test'; +import { DocumentData } from '../../src/document.js'; import { GenerateRequest, GenerateResponseData, + MessageData, Part, defineModel, } from '../../src/model.js'; import { + AugmentWithContextOptions, + augmentWithContext, simulateSystemPrompt, validateSupport, } from '../../src/model/middleware.js'; @@ -258,3 +262,163 @@ describe('simulateSystemPrompt', () => { }); }); }); + +describe('augmentWithContext', () => { + async function testRequest( + messages: MessageData[], + context?: DocumentData[], + options?: AugmentWithContextOptions + ) { + const changedRequest = await new Promise( + (resolve, reject) => { + augmentWithContext(options)( + { + messages, + context, + }, + resolve as any + ); + } + ); + return changedRequest.messages; + } + + it('should not change a message with empty context', async () => { + const messages: MessageData[] = [ + { role: 'user', content: [{ text: 'first part' }] }, + ]; + assert.deepEqual(await testRequest(messages, undefined), messages); + assert.deepEqual(await testRequest(messages, []), messages); + }); + + it('should not change a message that already has a context part', async () => { + const messages: MessageData[] = [ + { + role: 'user', + content: [{ text: 'first part', metadata: { purpose: 'context' } }], + }, + ]; + assert.deepEqual( + await testRequest(messages, [{ content: [{ text: 'i am context' }] }]), + messages + ); + }); + + it('should append a new text part', async () => { + const messages: MessageData[] = [ + { role: 'user', content: [{ text: 'first part' }] }, + ]; + const result = await testRequest(messages, [ + { content: [{ text: 'i am context' }] }, + { content: [{ text: 'i am more context' }] }, + ]); + assert.deepEqual(result[0].content.at(-1), { + text: 'Use the following information to complete your task:\n\n- [0]: i am context\n- [1]: i am more context\n', + metadata: { purpose: 'context' }, + }); + }); + + it('should use a custom preface', async () => { + const messages: MessageData[] = [ + { role: 'user', content: [{ text: 'first part' }] }, + ]; + const result = await testRequest( + messages, + [ + { content: [{ text: 'i am context' }] }, + { content: [{ text: 'i am more context' }] }, + ], + { preface: 'Check this out:\n\n' } + ); + assert.deepEqual(result[0].content.at(-1), { + text: 'Check this out:\n\n- [0]: i am context\n- [1]: i am more context\n', + metadata: { purpose: 'context' }, + }); + }); + + it('should elide a null preface', async () => { + const messages: MessageData[] = [ + { role: 'user', content: [{ text: 'first part' }] }, + ]; + const result = await testRequest( + messages, + [ + { content: [{ text: 'i am context' }] }, + { content: [{ text: 'i am more context' }] }, + ], + { preface: null } + ); + assert.deepEqual(result[0].content.at(-1), { + text: '- [0]: i am context\n- [1]: i am more context\n', + metadata: { purpose: 'context' }, + }); + }); + + it('should use a citationKey', async () => { + const messages: MessageData[] = [ + { role: 'user', content: [{ text: 'first part' }] }, + ]; + const result = await testRequest( + messages, + [ + { content: [{ text: 'i am context' }], metadata: { uid: 'first' } }, + { + content: [{ text: 'i am more context' }], + metadata: { uid: 'second' }, + }, + ], + { citationKey: 'uid' } + ); + assert.deepEqual(result[0].content.at(-1), { + text: 'Use the following information to complete your task:\n\n- [first]: i am context\n- [second]: i am more context\n', + metadata: { purpose: 'context' }, + }); + }); + + it('should use "ref", "id", and index, in that order, if citationKey is unspecified', async () => { + const messages: MessageData[] = [ + { role: 'user', content: [{ text: 'first part' }] }, + ]; + const result = await testRequest( + messages, + [ + { + content: [{ text: 'i am context' }], + metadata: { ref: 'first', id: 'wrong' }, + }, + { + content: [{ text: 'i am more context' }], + metadata: { id: 'second' }, + }, + { + content: [{ text: 'i am even more context' }], + }, + ] + ); + assert.deepEqual(result[0].content.at(-1), { + text: 'Use the following information to complete your task:\n\n- [first]: i am context\n- [second]: i am more context\n- [2]: i am even more context\n', + metadata: { purpose: 'context' }, + }); + }); + + it('should use a custom itemTemplate', async () => { + const messages: MessageData[] = [ + { role: 'user', content: [{ text: 'first part' }] }, + ]; + const result = await testRequest( + messages, + [ + { content: [{ text: 'i am context' }], metadata: { uid: 'first' } }, + { + content: [{ text: 'i am more context' }], + metadata: { uid: 'second' }, + }, + ], + { itemTemplate: (d) => `* (${d.metadata!.uid}) -- ${d.text()}\n`} + ); + assert.deepEqual(result[0].content.at(-1), { + text: 'Use the following information to complete your task:\n\n* (first) -- i am context\n* (second) -- i am more context\n', + metadata: { purpose: 'context' }, + }); + }) +}); From 358ff7c41e91f38864113d84efc469d9fc310dc5 Mon Sep 17 00:00:00 2001 From: Michael Bleigh Date: Tue, 7 May 2024 12:21:12 -0700 Subject: [PATCH 03/13] Adds isPrompt() helper --- js/ai/src/prompt.ts | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/js/ai/src/prompt.ts b/js/ai/src/prompt.ts index df62c30e0d..0d54500841 100644 --- a/js/ai/src/prompt.ts +++ b/js/ai/src/prompt.ts @@ -35,6 +35,10 @@ export type PromptAction = Action< }; }; +export function isPrompt(arg: any): boolean { + return typeof arg === 'function' && (arg as any).__action?.metadata?.type === 'prompt'; +} + export function definePrompt( { name, From b70f98608a56789d05e3df6554105bb2e76091da Mon Sep 17 00:00:00 2001 From: Michael Bleigh Date: Tue, 7 May 2024 12:28:29 -0700 Subject: [PATCH 04/13] docs --- docs/models.md | 18 ++++++++++++++++++ docs/rag.md | 25 ++----------------------- 2 files changed, 20 insertions(+), 23 deletions(-) diff --git a/docs/models.md b/docs/models.md index cf7e83a341..48a0892b48 100644 --- a/docs/models.md +++ b/docs/models.md @@ -169,6 +169,24 @@ await generate({ }); ``` +## Retriever context + +Documents from a retriever can be passed directly to `generate` to provide +grounding context: + +```javascript +const docs = await companyPolicyRetriever({query: question}); + +await generate({ + model: geminiPro, + prompt: `Answer using the available context from company policy: ${question}`, + context: docs, +}) +``` + +The document context is automatically appended to the content of the prompt +sent to the model. + ## Message history Genkit models support maintaining a history of the messages sent to the model diff --git a/docs/rag.md b/docs/rag.md index b7425f602d..bc6db947aa 100644 --- a/docs/rag.md +++ b/docs/rag.md @@ -181,7 +181,6 @@ import { configureGenkit } from '@genkit-ai/core'; import { defineFlow } from '@genkit-ai/flow'; import { generate } from '@genkit-ai/ai/generate'; import { retrieve } from '@genkit-ai/ai/retriever'; -import { definePrompt } from '@genkit-ai/dotprompt'; import { devLocalRetrieverRef, devLocalVectorstore, @@ -211,31 +210,11 @@ export const ragFlow = defineFlow( query: input, options: { k: 3 }, }); - const facts = docs.map((d) => d.text()); - - const promptGenerator = definePrompt( - { - name: 'bob-facts', - model: 'google-vertex/gemini-pro', - input: { - schema: z.object({ - facts: z.array(z.string()), - question: z.string(), - }), - }, - }, - '{{#each people}}{{this}}\n\n{{/each}}\n{{question}}' - ); - const prompt = await promptGenerator.generate({ - input: { - facts, - question: input, - }, - }); const llmResponse = await generate({ model: geminiPro, - prompt: prompt.text(), + prompt: `Answer this question: ${input}`, + context: docs, }); const output = llmResponse.text(); From 75e576f5bc2b6e54b1922eeb4560172cbef13e26 Mon Sep 17 00:00:00 2001 From: Michael Bleigh Date: Tue, 7 May 2024 12:29:38 -0700 Subject: [PATCH 05/13] Adds docs for context. --- docs/models.md | 4 ++-- js/ai/src/model.ts | 6 ++++- js/ai/src/prompt.ts | 5 ++++- js/ai/tests/model/middleware_test.ts | 33 +++++++++++++--------------- 4 files changed, 26 insertions(+), 22 deletions(-) diff --git a/docs/models.md b/docs/models.md index 48a0892b48..e8cb3b1082 100644 --- a/docs/models.md +++ b/docs/models.md @@ -175,13 +175,13 @@ Documents from a retriever can be passed directly to `generate` to provide grounding context: ```javascript -const docs = await companyPolicyRetriever({query: question}); +const docs = await companyPolicyRetriever({ query: question }); await generate({ model: geminiPro, prompt: `Answer using the available context from company policy: ${question}`, context: docs, -}) +}); ``` The document context is automatically appended to the content of the prompt diff --git a/js/ai/src/model.ts b/js/ai/src/model.ts index 386bba6a81..10036ea621 100644 --- a/js/ai/src/model.ts +++ b/js/ai/src/model.ts @@ -25,7 +25,11 @@ import { toJsonSchema } from '@genkit-ai/core/schema'; import { performance } from 'node:perf_hooks'; import { z } from 'zod'; import { DocumentDataSchema } from './document.js'; -import { augmentWithContext, conformOutput, validateSupport } from './model/middleware.js'; +import { + augmentWithContext, + conformOutput, + validateSupport, +} from './model/middleware.js'; import * as telemetry from './telemetry.js'; // diff --git a/js/ai/src/prompt.ts b/js/ai/src/prompt.ts index 0d54500841..f179791672 100644 --- a/js/ai/src/prompt.ts +++ b/js/ai/src/prompt.ts @@ -36,7 +36,10 @@ export type PromptAction = Action< }; export function isPrompt(arg: any): boolean { - return typeof arg === 'function' && (arg as any).__action?.metadata?.type === 'prompt'; + return ( + typeof arg === 'function' && + (arg as any).__action?.metadata?.type === 'prompt' + ); } export function definePrompt( diff --git a/js/ai/tests/model/middleware_test.ts b/js/ai/tests/model/middleware_test.ts index 0ba79ea2e4..2ef5384e92 100644 --- a/js/ai/tests/model/middleware_test.ts +++ b/js/ai/tests/model/middleware_test.ts @@ -379,22 +379,19 @@ describe('augmentWithContext', () => { const messages: MessageData[] = [ { role: 'user', content: [{ text: 'first part' }] }, ]; - const result = await testRequest( - messages, - [ - { - content: [{ text: 'i am context' }], - metadata: { ref: 'first', id: 'wrong' }, - }, - { - content: [{ text: 'i am more context' }], - metadata: { id: 'second' }, - }, - { - content: [{ text: 'i am even more context' }], - }, - ] - ); + const result = await testRequest(messages, [ + { + content: [{ text: 'i am context' }], + metadata: { ref: 'first', id: 'wrong' }, + }, + { + content: [{ text: 'i am more context' }], + metadata: { id: 'second' }, + }, + { + content: [{ text: 'i am even more context' }], + }, + ]); assert.deepEqual(result[0].content.at(-1), { text: 'Use the following information to complete your task:\n\n- [first]: i am context\n- [second]: i am more context\n- [2]: i am even more context\n', metadata: { purpose: 'context' }, @@ -414,11 +411,11 @@ describe('augmentWithContext', () => { metadata: { uid: 'second' }, }, ], - { itemTemplate: (d) => `* (${d.metadata!.uid}) -- ${d.text()}\n`} + { itemTemplate: (d) => `* (${d.metadata!.uid}) -- ${d.text()}\n` } ); assert.deepEqual(result[0].content.at(-1), { text: 'Use the following information to complete your task:\n\n* (first) -- i am context\n* (second) -- i am more context\n', metadata: { purpose: 'context' }, }); - }) + }); }); From 70ea5d665945039f96878593272600a674d0a797 Mon Sep 17 00:00:00 2001 From: Michael Bleigh Date: Tue, 7 May 2024 12:39:42 -0700 Subject: [PATCH 06/13] newlines --- js/ai/src/model/middleware.ts | 5 +++-- js/ai/tests/model/middleware_test.ts | 15 ++++++++------- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/js/ai/src/model/middleware.ts b/js/ai/src/model/middleware.ts index 3d3d0b78d7..5518cc2431 100644 --- a/js/ai/src/model/middleware.ts +++ b/js/ai/src/model/middleware.ts @@ -182,8 +182,8 @@ export interface AugmentWithContextOptions { citationKey?: string | null; } -const CONTEXT_PREFACE = - 'Use the following information to complete your task:\n\n'; +export const CONTEXT_PREFACE = + '\n\nUse the following information to complete your task:\n\n'; const CONTEXT_ITEM_TEMPLATE = ( d: Document, index: number, @@ -218,6 +218,7 @@ export function augmentWithContext( req.context?.forEach((d, i) => { out += itemTemplate(new Document(d), i, options); }); + out += "\n"; userMessage.content.push({ text: out, metadata: { purpose: 'context' } }); return next(req); }; diff --git a/js/ai/tests/model/middleware_test.ts b/js/ai/tests/model/middleware_test.ts index 2ef5384e92..ec5ceae847 100644 --- a/js/ai/tests/model/middleware_test.ts +++ b/js/ai/tests/model/middleware_test.ts @@ -26,6 +26,7 @@ import { } from '../../src/model.js'; import { AugmentWithContextOptions, + CONTEXT_PREFACE, augmentWithContext, simulateSystemPrompt, validateSupport, @@ -313,7 +314,7 @@ describe('augmentWithContext', () => { { content: [{ text: 'i am more context' }] }, ]); assert.deepEqual(result[0].content.at(-1), { - text: 'Use the following information to complete your task:\n\n- [0]: i am context\n- [1]: i am more context\n', + text: `${CONTEXT_PREFACE}- [0]: i am context\n- [1]: i am more context\n\n`, metadata: { purpose: 'context' }, }); }); @@ -328,10 +329,10 @@ describe('augmentWithContext', () => { { content: [{ text: 'i am context' }] }, { content: [{ text: 'i am more context' }] }, ], - { preface: 'Check this out:\n\n' } + { preface: '\n\nCheck this out:\n\n' } ); assert.deepEqual(result[0].content.at(-1), { - text: 'Check this out:\n\n- [0]: i am context\n- [1]: i am more context\n', + text: '\n\nCheck this out:\n\n- [0]: i am context\n- [1]: i am more context\n\n', metadata: { purpose: 'context' }, }); }); @@ -349,7 +350,7 @@ describe('augmentWithContext', () => { { preface: null } ); assert.deepEqual(result[0].content.at(-1), { - text: '- [0]: i am context\n- [1]: i am more context\n', + text: '- [0]: i am context\n- [1]: i am more context\n\n', metadata: { purpose: 'context' }, }); }); @@ -370,7 +371,7 @@ describe('augmentWithContext', () => { { citationKey: 'uid' } ); assert.deepEqual(result[0].content.at(-1), { - text: 'Use the following information to complete your task:\n\n- [first]: i am context\n- [second]: i am more context\n', + text: `${CONTEXT_PREFACE}- [first]: i am context\n- [second]: i am more context\n\n`, metadata: { purpose: 'context' }, }); }); @@ -393,7 +394,7 @@ describe('augmentWithContext', () => { }, ]); assert.deepEqual(result[0].content.at(-1), { - text: 'Use the following information to complete your task:\n\n- [first]: i am context\n- [second]: i am more context\n- [2]: i am even more context\n', + text: `${CONTEXT_PREFACE}- [first]: i am context\n- [second]: i am more context\n- [2]: i am even more context\n\n`, metadata: { purpose: 'context' }, }); }); @@ -414,7 +415,7 @@ describe('augmentWithContext', () => { { itemTemplate: (d) => `* (${d.metadata!.uid}) -- ${d.text()}\n` } ); assert.deepEqual(result[0].content.at(-1), { - text: 'Use the following information to complete your task:\n\n* (first) -- i am context\n* (second) -- i am more context\n', + text: `${CONTEXT_PREFACE}* (first) -- i am context\n* (second) -- i am more context\n\n`, metadata: { purpose: 'context' }, }); }); From 6470016e32b5372e3ce21c5d3479852f16d1ed2d Mon Sep 17 00:00:00 2001 From: Michael Bleigh Date: Tue, 7 May 2024 12:40:05 -0700 Subject: [PATCH 07/13] format --- js/ai/src/model/middleware.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/js/ai/src/model/middleware.ts b/js/ai/src/model/middleware.ts index 5518cc2431..f15c794568 100644 --- a/js/ai/src/model/middleware.ts +++ b/js/ai/src/model/middleware.ts @@ -218,7 +218,7 @@ export function augmentWithContext( req.context?.forEach((d, i) => { out += itemTemplate(new Document(d), i, options); }); - out += "\n"; + out += '\n'; userMessage.content.push({ text: out, metadata: { purpose: 'context' } }); return next(req); }; From cb9b80accc6131a4bef73d64fe7b3b13e9bf2e7b Mon Sep 17 00:00:00 2001 From: Michael Bleigh Date: Tue, 7 May 2024 16:27:24 -0700 Subject: [PATCH 08/13] adds to GenerateOptions --- js/ai/src/generate.ts | 75 ++++++++++++++++++++------------------ js/ai/src/prompt.ts | 5 ++- js/dotprompt/src/prompt.ts | 2 +- 3 files changed, 44 insertions(+), 38 deletions(-) diff --git a/js/ai/src/generate.ts b/js/ai/src/generate.ts index 7b8595a35e..451a981453 100755 --- a/js/ai/src/generate.ts +++ b/js/ai/src/generate.ts @@ -46,6 +46,7 @@ import { ToolArgument, toToolDefinition, } from './tool.js'; +import { DocumentData } from '@google-cloud/firestore'; /** * Message represents a single role's contribution to a generation. Each message @@ -386,36 +387,36 @@ function inferRoleFromParts(parts: Part[]): Role { } export async function toGenerateRequest( - prompt: GenerateOptions + options: GenerateOptions ): Promise { const promptMessage: MessageData = { role: 'user', content: [] }; - if (typeof prompt.prompt === 'string') { - promptMessage.content.push({ text: prompt.prompt }); - } else if (Array.isArray(prompt.prompt)) { - promptMessage.role = inferRoleFromParts(prompt.prompt); - promptMessage.content.push(...prompt.prompt); + if (typeof options.prompt === 'string') { + promptMessage.content.push({ text: options.prompt }); + } else if (Array.isArray(options.prompt)) { + promptMessage.role = inferRoleFromParts(options.prompt); + promptMessage.content.push(...options.prompt); } else { - promptMessage.role = inferRoleFromParts([prompt.prompt]); - promptMessage.content.push(prompt.prompt); + promptMessage.role = inferRoleFromParts([options.prompt]); + promptMessage.content.push(options.prompt); } - const messages: MessageData[] = [...(prompt.history || []), promptMessage]; + const messages: MessageData[] = [...(options.history || []), promptMessage]; let tools: Action[] | undefined; - if (prompt.tools) { - tools = await resolveTools(prompt.tools); + if (options.tools) { + tools = await resolveTools(options.tools); } const out = { messages, - candidates: prompt.candidates, - config: prompt.config, + candidates: options.candidates, + config: options.config, tools: tools?.map((tool) => toToolDefinition(tool)) || [], output: { format: - prompt.output?.format || - (prompt.output?.schema || prompt.output?.jsonSchema ? 'json' : 'text'), + options.output?.format || + (options.output?.schema || options.output?.jsonSchema ? 'json' : 'text'), schema: toJsonSchema({ - schema: prompt.output?.schema, - jsonSchema: prompt.output?.jsonSchema, + schema: options.output?.schema, + jsonSchema: options.output?.jsonSchema, }), }, }; @@ -431,6 +432,8 @@ export interface GenerateOptions< model: ModelArgument; /** The prompt for which to generate a response. Can be a string for a simple text prompt or one or more parts for multi-modal prompts. */ prompt: string | Part | Part[]; + /** Retrieved documents to be used as context for this generation. */ + context?: DocumentData[], /** Conversation history for multi-turn prompting when supported by the underlying model. */ history?: MessageData[]; /** List of registered tool names or actions to treat as a tool for this generation if supported by the underlying model. */ @@ -530,29 +533,29 @@ export async function generate< | GenerateOptions | PromiseLike> ): Promise>> { - const prompt: GenerateOptions = + const resolvedOptions: GenerateOptions = await Promise.resolve(options); - const model = await resolveModel(prompt.model); + const model = await resolveModel(resolvedOptions.model); if (!model) { - throw new Error(`Model ${JSON.stringify(prompt.model)} not found`); + throw new Error(`Model ${JSON.stringify(resolvedOptions.model)} not found`); } let tools: ToolAction[] | undefined; - if (prompt.tools?.length) { + if (resolvedOptions.tools?.length) { if (!model.__action.metadata?.model.supports?.tools) { throw new Error( - `Model ${JSON.stringify(prompt.model)} does not support tools, but some tools were supplied to generate(). Please call generate() without tools if you would like to use this model.` + `Model ${JSON.stringify(resolvedOptions.model)} does not support tools, but some tools were supplied to generate(). Please call generate() without tools if you would like to use this model.` ); } - tools = await resolveTools(prompt.tools); + tools = await resolveTools(resolvedOptions.tools); } - const request = await toGenerateRequest(prompt); - telemetry.recordGenerateActionInputLogs(model.__action.name, prompt, request); + const request = await toGenerateRequest(resolvedOptions); + telemetry.recordGenerateActionInputLogs(model.__action.name, resolvedOptions, request); const response = await runWithStreamingCallback( - prompt.streamingCallback + resolvedOptions.streamingCallback ? (chunk: GenerateResponseChunkData) => - prompt.streamingCallback!(new GenerateResponseChunk(chunk)) + resolvedOptions.streamingCallback!(new GenerateResponseChunk(chunk)) : undefined, async () => new GenerateResponse>(await model(request), request) ); @@ -569,13 +572,13 @@ export async function generate< }); } - if (prompt.output?.schema || prompt.output?.jsonSchema) { + if (resolvedOptions.output?.schema || resolvedOptions.output?.jsonSchema) { // find a candidate with valid output schema const candidateValidations = response.candidates.map((c) => { try { return validateSchema(c.output(), { - jsonSchema: prompt.output?.jsonSchema, - schema: prompt.output?.schema, + jsonSchema: resolvedOptions.output?.jsonSchema, + schema: resolvedOptions.output?.schema, }); } catch (e) { return { @@ -612,10 +615,10 @@ export async function generate< const toolCalls = selected.message.content.filter( (part) => !!part.toolRequest ); - if (prompt.returnToolRequests || toolCalls.length === 0) { + if (resolvedOptions.returnToolRequests || toolCalls.length === 0) { telemetry.recordGenerateActionOutputLogs( model.__action.name, - prompt, + resolvedOptions, response ); return response; @@ -642,10 +645,10 @@ export async function generate< }; }) ); - prompt.history = request.messages; - prompt.history.push(selected.message); - prompt.prompt = toolResponses; - return await generate(prompt); + resolvedOptions.history = request.messages; + resolvedOptions.history.push(selected.message); + resolvedOptions.prompt = toolResponses; + return await generate(resolvedOptions); } export type GenerateStreamOptions< diff --git a/js/ai/src/prompt.ts b/js/ai/src/prompt.ts index f179791672..f5554f173e 100644 --- a/js/ai/src/prompt.ts +++ b/js/ai/src/prompt.ts @@ -19,6 +19,7 @@ import { lookupAction } from '@genkit-ai/core/registry'; import z from 'zod'; import { GenerateOptions } from './generate'; import { GenerateRequest, GenerateRequestSchema, ModelArgument } from './model'; +import { DocumentData } from '@google-cloud/firestore'; export type PromptFn = ( input: z.infer @@ -86,6 +87,7 @@ export async function renderPrompt< >(params: { prompt: PromptArgument; input: z.infer; + context?: DocumentData[], model: ModelArgument; config?: z.infer; }): Promise { @@ -101,5 +103,6 @@ export async function renderPrompt< config: { ...(rendered.config || {}), ...params.config }, history: rendered.messages.slice(0, rendered.messages.length - 1), prompt: rendered.messages[rendered.messages.length - 1].content, + context: params.context, }; -} +} \ No newline at end of file diff --git a/js/dotprompt/src/prompt.ts b/js/dotprompt/src/prompt.ts index 880e491c08..58c12435ae 100644 --- a/js/dotprompt/src/prompt.ts +++ b/js/dotprompt/src/prompt.ts @@ -9,7 +9,7 @@ * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, -cd * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ From 3780880bfa210b0b3c562ee3f5aeb8b975eb77a9 Mon Sep 17 00:00:00 2001 From: Michael Bleigh Date: Tue, 7 May 2024 16:33:00 -0700 Subject: [PATCH 09/13] update --- js/dotprompt/src/prompt.ts | 30 ++++-------------------------- 1 file changed, 4 insertions(+), 26 deletions(-) diff --git a/js/dotprompt/src/prompt.ts b/js/dotprompt/src/prompt.ts index 58c12435ae..99e5f8438c 100644 --- a/js/dotprompt/src/prompt.ts +++ b/js/dotprompt/src/prompt.ts @@ -47,7 +47,6 @@ export type PromptGenerateOptions = Omit< > & { model?: string; input?: V; - context?: DocumentData[]; }; export class Dotprompt implements PromptMetadata { @@ -147,36 +146,14 @@ export class Dotprompt implements PromptMetadata { { name: `${this.name}${this.variant ? `.${this.variant}` : ''}`, description: 'Defined by Dotprompt', - inputSchema: this.input?.schema - ? z.object({ - input: this.input.schema, - context: z.array(DocumentDataSchema).optional(), - }) - : undefined, - inputJsonSchema: this.input?.jsonSchema - ? { - type: 'object', - properties: { - input: this.input?.jsonSchema, - context: { - type: 'array', - items: toJsonSchema({ schema: DocumentDataSchema }), - }, - }, - } - : undefined, + inputSchema: this.input?.schema, + inputJsonSchema: this.input?.jsonSchema, metadata: { type: 'prompt', prompt: this.toJSON(), }, }, - async ({ - input, - context, - }: { - input?: Variables; - context?: DocumentData[]; - }) => toGenerateRequest(this.render({ input, context })) + async (input?: Variables) => toGenerateRequest(this.render({input})) ); } @@ -197,6 +174,7 @@ export class Dotprompt implements PromptMetadata { config: { ...this.config, ...options.config } || {}, history: messages.slice(0, messages.length - 1), prompt: messages[messages.length - 1].content, + context: options.context, candidates: options.candidates || this.candidates || 1, output: { format: options.output?.format || this.output?.format || undefined, From 52702c69027e4d18ba17f94f80166952b7bac96d Mon Sep 17 00:00:00 2001 From: Michael Bleigh Date: Tue, 7 May 2024 16:33:20 -0700 Subject: [PATCH 10/13] format --- js/ai/src/generate.ts | 14 ++++++++++---- js/ai/src/prompt.ts | 6 +++--- js/dotprompt/src/prompt.ts | 6 +++--- 3 files changed, 16 insertions(+), 10 deletions(-) diff --git a/js/ai/src/generate.ts b/js/ai/src/generate.ts index 451a981453..a5fa2a63d0 100755 --- a/js/ai/src/generate.ts +++ b/js/ai/src/generate.ts @@ -22,6 +22,7 @@ import { } from '@genkit-ai/core'; import { lookupAction } from '@genkit-ai/core/registry'; import { toJsonSchema, validateSchema } from '@genkit-ai/core/schema'; +import { DocumentData } from '@google-cloud/firestore'; import { z } from 'zod'; import { extractJson } from './extract.js'; import { @@ -46,7 +47,6 @@ import { ToolArgument, toToolDefinition, } from './tool.js'; -import { DocumentData } from '@google-cloud/firestore'; /** * Message represents a single role's contribution to a generation. Each message @@ -413,7 +413,9 @@ export async function toGenerateRequest( output: { format: options.output?.format || - (options.output?.schema || options.output?.jsonSchema ? 'json' : 'text'), + (options.output?.schema || options.output?.jsonSchema + ? 'json' + : 'text'), schema: toJsonSchema({ schema: options.output?.schema, jsonSchema: options.output?.jsonSchema, @@ -433,7 +435,7 @@ export interface GenerateOptions< /** The prompt for which to generate a response. Can be a string for a simple text prompt or one or more parts for multi-modal prompts. */ prompt: string | Part | Part[]; /** Retrieved documents to be used as context for this generation. */ - context?: DocumentData[], + context?: DocumentData[]; /** Conversation history for multi-turn prompting when supported by the underlying model. */ history?: MessageData[]; /** List of registered tool names or actions to treat as a tool for this generation if supported by the underlying model. */ @@ -551,7 +553,11 @@ export async function generate< } const request = await toGenerateRequest(resolvedOptions); - telemetry.recordGenerateActionInputLogs(model.__action.name, resolvedOptions, request); + telemetry.recordGenerateActionInputLogs( + model.__action.name, + resolvedOptions, + request + ); const response = await runWithStreamingCallback( resolvedOptions.streamingCallback ? (chunk: GenerateResponseChunkData) => diff --git a/js/ai/src/prompt.ts b/js/ai/src/prompt.ts index f5554f173e..c2050b8c36 100644 --- a/js/ai/src/prompt.ts +++ b/js/ai/src/prompt.ts @@ -16,10 +16,10 @@ import { Action, defineAction, JSONSchema7 } from '@genkit-ai/core'; import { lookupAction } from '@genkit-ai/core/registry'; +import { DocumentData } from '@google-cloud/firestore'; import z from 'zod'; import { GenerateOptions } from './generate'; import { GenerateRequest, GenerateRequestSchema, ModelArgument } from './model'; -import { DocumentData } from '@google-cloud/firestore'; export type PromptFn = ( input: z.infer @@ -87,7 +87,7 @@ export async function renderPrompt< >(params: { prompt: PromptArgument; input: z.infer; - context?: DocumentData[], + context?: DocumentData[]; model: ModelArgument; config?: z.infer; }): Promise { @@ -105,4 +105,4 @@ export async function renderPrompt< prompt: rendered.messages[rendered.messages.length - 1].content, context: params.context, }; -} \ No newline at end of file +} diff --git a/js/dotprompt/src/prompt.ts b/js/dotprompt/src/prompt.ts index 99e5f8438c..9b57cc58b6 100644 --- a/js/dotprompt/src/prompt.ts +++ b/js/dotprompt/src/prompt.ts @@ -25,9 +25,9 @@ import { toGenerateRequest, } from '@genkit-ai/ai'; import { GenerationCommonConfigSchema, MessageData } from '@genkit-ai/ai/model'; -import { DocumentData, DocumentDataSchema } from '@genkit-ai/ai/retriever'; +import { DocumentData } from '@genkit-ai/ai/retriever'; import { GenkitError } from '@genkit-ai/core'; -import { parseSchema, toJsonSchema } from '@genkit-ai/core/schema'; +import { parseSchema } from '@genkit-ai/core/schema'; import { createHash } from 'crypto'; import fm, { FrontMatterResult } from 'front-matter'; import z from 'zod'; @@ -153,7 +153,7 @@ export class Dotprompt implements PromptMetadata { prompt: this.toJSON(), }, }, - async (input?: Variables) => toGenerateRequest(this.render({input})) + async (input?: Variables) => toGenerateRequest(this.render({ input })) ); } From c72f3b579885ae7fdb452a875f4ff5cc99544cee Mon Sep 17 00:00:00 2001 From: Michael Bleigh Date: Tue, 7 May 2024 16:36:49 -0700 Subject: [PATCH 11/13] remove dotprompt stuff, it doesnt work --- js/dotprompt/src/template.ts | 36 ++++---- js/dotprompt/tests/template_test.ts | 130 ++++++++++++++-------------- 2 files changed, 82 insertions(+), 84 deletions(-) diff --git a/js/dotprompt/src/template.ts b/js/dotprompt/src/template.ts index c00abc5c23..c5449b0fc7 100644 --- a/js/dotprompt/src/template.ts +++ b/js/dotprompt/src/template.ts @@ -48,22 +48,22 @@ function mediaHelper(options: Handlebars.HelperOptions) { } Promptbars.registerHelper('media', mediaHelper); -function contextHelper(options: Handlebars.HelperOptions) { - const context = options.data?.metadata?.context || []; - const items = context.map((d: DocumentData, i) => { - let text = new Document(d).text(); - if (options.hash.cite === true) { - text += `[${i}]`; - } else if (d.metadata?.[options.hash.cite]) { - text += `[${d.metadata[options.hash.cite]}]`; - } - return Promptbars.escapeExpression(text); - }); - - return new Promptbars.SafeString(`<<>> -- ${items.join('\n- ')}`); -} -Promptbars.registerHelper('context', contextHelper); +// function contextHelper(options: Handlebars.HelperOptions) { +// const context = options.data?.metadata?.context || []; +// const items = context.map((d: DocumentData, i) => { +// let text = new Document(d).text(); +// if (options.hash.cite === true) { +// text += `[${i}]`; +// } else if (d.metadata?.[options.hash.cite]) { +// text += `[${d.metadata[options.hash.cite]}]`; +// } +// return Promptbars.escapeExpression(text); +// }); + +// return new Promptbars.SafeString(`<<>> +// - ${items.join('\n- ')}`); +// } +// Promptbars.registerHelper('context', contextHelper); const ROLE_REGEX = /(<<>>/g; @@ -134,9 +134,9 @@ export function compile( media: true, role: true, history: true, - context: true, + // context: true, }, - // knownHelpersOnly: true, + knownHelpersOnly: true, }); return ( diff --git a/js/dotprompt/tests/template_test.ts b/js/dotprompt/tests/template_test.ts index 812abe454a..bd560063e0 100644 --- a/js/dotprompt/tests/template_test.ts +++ b/js/dotprompt/tests/template_test.ts @@ -106,75 +106,73 @@ describe('compile', () => { template: '{{json . indent=2}}', want: [{ role: 'user', content: [{ text: '{\n "test": true\n}' }] }], }, - { - should: 'allow defining context', - input: {}, - context: [ - { content: [{ text: 'abc' }, { text: 'def' }] }, - { content: [{ text: 'hgi' }] }, - ], - template: '{{context}}', - want: [ - { - role: 'user', - content: [ - { text: '\n- abcdef\n- hgi', metadata: { purpose: 'context' } }, - ], - }, - ], - }, - { - should: 'allow defining context with custom citations', - input: {}, - context: [ - { - content: [{ text: 'abc' }, { text: 'def' }], - metadata: { ref: 'first' }, - }, - { content: [{ text: 'hgi' }], metadata: { ref: 'second' } }, - ], - template: '{{context cite="ref"}}', - want: [ - { - role: 'user', - content: [ - { - text: '\n- abcdef[first]\n- hgi[second]', - metadata: { purpose: 'context' }, - }, - ], - }, - ], - }, - { - should: 'allow defining context with numbered citations', - input: {}, - context: [ - { - content: [{ text: 'abc' }, { text: 'def' }], - metadata: { ref: 'first' }, - }, - { content: [{ text: 'hgi' }], metadata: { ref: 'second' } }, - ], - template: '{{context cite=true}}', - want: [ - { - role: 'user', - content: [ - { - text: '\n- abcdef[0]\n- hgi[1]', - metadata: { purpose: 'context' }, - }, - ], - }, - ], - }, + // { + // should: 'allow defining context', + // input: {}, + // context: [ + // { content: [{ text: 'abc' }, { text: 'def' }] }, + // { content: [{ text: 'hgi' }] }, + // ], + // template: '{{context}}', + // want: [ + // { + // role: 'user', + // content: [ + // { text: '\n- abcdef\n- hgi', metadata: { purpose: 'context' } }, + // ], + // }, + // ], + // }, + // { + // should: 'allow defining context with custom citations', + // input: {}, + // context: [ + // { + // content: [{ text: 'abc' }, { text: 'def' }], + // metadata: { ref: 'first' }, + // }, + // { content: [{ text: 'hgi' }], metadata: { ref: 'second' } }, + // ], + // template: '{{context cite="ref"}}', + // want: [ + // { + // role: 'user', + // content: [ + // { + // text: '\n- abcdef[first]\n- hgi[second]', + // metadata: { purpose: 'context' }, + // }, + // ], + // }, + // ], + // }, + // { + // should: 'allow defining context with numbered citations', + // input: {}, + // context: [ + // { + // content: [{ text: 'abc' }, { text: 'def' }], + // metadata: { ref: 'first' }, + // }, + // { content: [{ text: 'hgi' }], metadata: { ref: 'second' } }, + // ], + // template: '{{context cite=true}}', + // want: [ + // { + // role: 'user', + // content: [ + // { + // text: '\n- abcdef[0]\n- hgi[1]', + // metadata: { purpose: 'context' }, + // }, + // ], + // }, + // ], + // }, ]) { it(test.should, () => { assert.deepEqual( - compile(test.template, { model: 'test/example' })(test.input, { - context: test.context, - }), + compile(test.template, { model: 'test/example' })(test.input), test.want ); }); From d5b385a570620eaab9767ba4b76f34d26cd3d8d5 Mon Sep 17 00:00:00 2001 From: Michael Bleigh Date: Tue, 7 May 2024 19:31:23 -0700 Subject: [PATCH 12/13] address feedback --- js/dotprompt/src/template.ts | 18 --------- js/dotprompt/tests/template_test.ts | 63 ----------------------------- 2 files changed, 81 deletions(-) diff --git a/js/dotprompt/src/template.ts b/js/dotprompt/src/template.ts index c5449b0fc7..e7c13bb496 100644 --- a/js/dotprompt/src/template.ts +++ b/js/dotprompt/src/template.ts @@ -48,23 +48,6 @@ function mediaHelper(options: Handlebars.HelperOptions) { } Promptbars.registerHelper('media', mediaHelper); -// function contextHelper(options: Handlebars.HelperOptions) { -// const context = options.data?.metadata?.context || []; -// const items = context.map((d: DocumentData, i) => { -// let text = new Document(d).text(); -// if (options.hash.cite === true) { -// text += `[${i}]`; -// } else if (d.metadata?.[options.hash.cite]) { -// text += `[${d.metadata[options.hash.cite]}]`; -// } -// return Promptbars.escapeExpression(text); -// }); - -// return new Promptbars.SafeString(`<<>> -// - ${items.join('\n- ')}`); -// } -// Promptbars.registerHelper('context', contextHelper); - const ROLE_REGEX = /(<<>>/g; function toMessages(renderedString: string): MessageData[] { @@ -134,7 +117,6 @@ export function compile( media: true, role: true, history: true, - // context: true, }, knownHelpersOnly: true, }); diff --git a/js/dotprompt/tests/template_test.ts b/js/dotprompt/tests/template_test.ts index bd560063e0..ba01191fad 100644 --- a/js/dotprompt/tests/template_test.ts +++ b/js/dotprompt/tests/template_test.ts @@ -106,69 +106,6 @@ describe('compile', () => { template: '{{json . indent=2}}', want: [{ role: 'user', content: [{ text: '{\n "test": true\n}' }] }], }, - // { - // should: 'allow defining context', - // input: {}, - // context: [ - // { content: [{ text: 'abc' }, { text: 'def' }] }, - // { content: [{ text: 'hgi' }] }, - // ], - // template: '{{context}}', - // want: [ - // { - // role: 'user', - // content: [ - // { text: '\n- abcdef\n- hgi', metadata: { purpose: 'context' } }, - // ], - // }, - // ], - // }, - // { - // should: 'allow defining context with custom citations', - // input: {}, - // context: [ - // { - // content: [{ text: 'abc' }, { text: 'def' }], - // metadata: { ref: 'first' }, - // }, - // { content: [{ text: 'hgi' }], metadata: { ref: 'second' } }, - // ], - // template: '{{context cite="ref"}}', - // want: [ - // { - // role: 'user', - // content: [ - // { - // text: '\n- abcdef[first]\n- hgi[second]', - // metadata: { purpose: 'context' }, - // }, - // ], - // }, - // ], - // }, - // { - // should: 'allow defining context with numbered citations', - // input: {}, - // context: [ - // { - // content: [{ text: 'abc' }, { text: 'def' }], - // metadata: { ref: 'first' }, - // }, - // { content: [{ text: 'hgi' }], metadata: { ref: 'second' } }, - // ], - // template: '{{context cite=true}}', - // want: [ - // { - // role: 'user', - // content: [ - // { - // text: '\n- abcdef[0]\n- hgi[1]', - // metadata: { purpose: 'context' }, - // }, - // ], - // }, - // ], - // }, ]) { it(test.should, () => { assert.deepEqual( From 1dbf78102dc2ad13ff34ba61c5998a30b16786df Mon Sep 17 00:00:00 2001 From: Michael Bleigh Date: Tue, 7 May 2024 19:31:38 -0700 Subject: [PATCH 13/13] format --- js/dotprompt/src/template.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/js/dotprompt/src/template.ts b/js/dotprompt/src/template.ts index e7c13bb496..9bbac75d74 100644 --- a/js/dotprompt/src/template.ts +++ b/js/dotprompt/src/template.ts @@ -21,7 +21,7 @@ import { Role, TextPart, } from '@genkit-ai/ai/model'; -import { Document, DocumentData } from '@genkit-ai/ai/retriever'; +import { DocumentData } from '@genkit-ai/ai/retriever'; import Handlebars from 'handlebars'; import { PromptMetadata } from './metadata.js';