diff --git a/docs/dotprompt.md b/docs/dotprompt.md index ce09b925e7..71cfed0589 100644 --- a/docs/dotprompt.md +++ b/docs/dotprompt.md @@ -62,7 +62,7 @@ conditional portions to your prompt or iterate through structured content. The file format utilizes YAML frontmatter to provide metadata for a prompt inline with the template. -## Defining Input/Output Schemas with Picoschema +## Defining Input/Output Schemas Dotprompt includes a compact, YAML-optimized schema definition format called Picoschema to make it easy to define the most important attributes of a schema @@ -142,6 +142,50 @@ output: minimum: 20 ``` +### Leveraging Reusable Schemas + +In addition to directly defining schemas in the `.prompt` file, you can reference +a schema registered with `defineSchema` by name. To register a schema: + +```ts +import { defineSchema } from '@genkit-ai/core'; +import { z } from 'zod'; + +const MySchema = defineSchema( + 'MySchema', + z.object({ + field1: z.string(), + field2: z.number(), + }) +); +``` + +Within your prompt, you can provide the name of the registered schema: + +```yaml +# myPrompt.prompt +--- +model: vertexai/gemini-1.5-flash +output: + schema: MySchema +--- +``` + +The Dotprompt library will automatically resolve the name to the underlying +registered Zod schema. 
You can then utilize the schema to strongly type the +output of a Dotprompt: + +```ts +import { prompt } from "@genkit-ai/dotprompt"; + +const myPrompt = await prompt("myPrompt"); + +const result = await myPrompt.generate({...}); + +// now strongly typed as MySchema +result.output(); +``` + ## Overriding Prompt Metadata While `.prompt` files allow you to embed metadata such as model configuration in diff --git a/js/core/src/index.ts b/js/core/src/index.ts index 2155298f44..71ebb997c4 100644 --- a/js/core/src/index.ts +++ b/js/core/src/index.ts @@ -21,4 +21,5 @@ export * from './action.js'; export * from './config.js'; export { GenkitError } from './error.js'; export * from './flowTypes.js'; +export { defineJsonSchema, defineSchema } from './schema.js'; export * from './telemetryTypes.js'; diff --git a/js/core/src/registry.ts b/js/core/src/registry.ts index d24935d387..94647c92be 100644 --- a/js/core/src/registry.ts +++ b/js/core/src/registry.ts @@ -20,6 +20,7 @@ import { FlowStateStore } from './flowTypes.js'; import { logger } from './logging.js'; import { PluginProvider } from './plugin.js'; import { startReflectionApi } from './reflectionApi.js'; +import { JSONSchema } from './schema.js'; import { TraceStore } from './tracing/types.js'; export type AsyncProvider = () => Promise; @@ -28,6 +29,7 @@ const ACTIONS_BY_ID = 'genkit__ACTIONS_BY_ID'; const TRACE_STORES_BY_ENV = 'genkit__TRACE_STORES_BY_ENV'; const FLOW_STATE_STORES_BY_ENV = 'genkit__FLOW_STATE_STORES_BY_ENV'; const PLUGINS_BY_NAME = 'genkit__PLUGINS_BY_NAME'; +const SCHEMAS_BY_NAME = 'genkit__SCHEMAS_BY_NAME'; function actionsById(): Record> { if (global[ACTIONS_BY_ID] === undefined) { @@ -53,6 +55,15 @@ function pluginsByName(): Record { } return global[PLUGINS_BY_NAME]; } +function schemasByName(): Record< + string, + { schema?: z.ZodTypeAny; jsonSchema?: JSONSchema } +> { + if (global[SCHEMAS_BY_NAME] === undefined) { + global[SCHEMAS_BY_NAME] = {}; + } + return global[SCHEMAS_BY_NAME]; +} /** * Type 
of a runnable action. @@ -211,6 +222,17 @@ export async function initializePlugin(name: string) { return undefined; } +export function registerSchema( + name: string, + data: { schema?: z.ZodTypeAny; jsonSchema?: JSONSchema } +) { + schemasByName()[name] = data; +} + +export function lookupSchema(name: string) { + return schemasByName()[name]; +} + /** * Development mode only. Starts a Reflection API so that the actions can be called by the Runner. */ diff --git a/js/core/src/schema.ts b/js/core/src/schema.ts index 5c6609469a..16a45160d3 100644 --- a/js/core/src/schema.ts +++ b/js/core/src/schema.ts @@ -19,6 +19,7 @@ import addFormats from 'ajv-formats'; import { z } from 'zod'; import zodToJsonSchema from 'zod-to-json-schema'; import { GenkitError } from './error.js'; +import { registerSchema } from './registry.js'; const ajv = new Ajv(); addFormats(ajv); @@ -109,3 +110,16 @@ export function parseSchema( if (!valid) throw new ValidationError({ data, errors: errors!, schema }); return data as T; } + +export function defineSchema( + name: string, + schema: T +): T { + registerSchema(name, { schema }); + return schema; +} + +export function defineJsonSchema(name: string, jsonSchema: JSONSchema) { + registerSchema(name, { jsonSchema }); + return jsonSchema; +} diff --git a/js/plugins/dotprompt/src/metadata.ts b/js/plugins/dotprompt/src/metadata.ts index 131ea1373c..839a7068d5 100644 --- a/js/plugins/dotprompt/src/metadata.ts +++ b/js/plugins/dotprompt/src/metadata.ts @@ -24,6 +24,7 @@ import { ModelArgument, } from '@genkit-ai/ai/model'; import { ToolArgument } from '@genkit-ai/ai/tool'; +import { lookupSchema } from '@genkit-ai/core/registry'; import { JSONSchema, parseSchema, toJsonSchema } from '@genkit-ai/core/schema'; import z from 'zod'; import { picoschema } from './picoschema.js'; @@ -92,8 +93,8 @@ export const PromptFrontmatterSchema = z.object({ config: GenerationCommonConfigSchema.passthrough().optional(), input: z .object({ - schema: z.unknown(), default: 
z.any(), + schema: z.unknown(), }) .optional(), output: z @@ -122,21 +123,37 @@ function stripUndefinedOrNull(obj: any) { return obj; } +function fmSchemaToSchema(fmSchema: any) { + if (!fmSchema) return {}; + if (typeof fmSchema === 'string') return lookupSchema(fmSchema); + return { jsonSchema: picoschema(fmSchema) }; +} + export function toMetadata(attributes: unknown): Partial { const fm = parseSchema>(attributes, { schema: PromptFrontmatterSchema, }); + + let input: PromptMetadata['input'] | undefined; + if (fm.input) { + input = { default: fm.input.default, ...fmSchemaToSchema(fm.input.schema) }; + } + + let output: PromptMetadata['output'] | undefined; + if (fm.output) { + output = { + format: fm.output.format, + ...fmSchemaToSchema(fm.output.schema), + }; + } + return stripUndefinedOrNull({ name: fm.name, variant: fm.variant, model: fm.model, config: fm.config, - input: fm.input - ? { default: fm.input.default, jsonSchema: picoschema(fm.input.schema) } - : undefined, - output: fm.output - ? 
{ format: fm.output.format, jsonSchema: picoschema(fm.output.schema) } - : undefined, + input, + output, metadata: fm.metadata, tools: fm.tools, candidates: fm.candidates, diff --git a/js/plugins/dotprompt/src/prompt.ts b/js/plugins/dotprompt/src/prompt.ts index 24bbc80aa0..6d1db8851d 100644 --- a/js/plugins/dotprompt/src/prompt.ts +++ b/js/plugins/dotprompt/src/prompt.ts @@ -89,7 +89,7 @@ export class Dotprompt implements PromptMetadata { throw new GenkitError({ source: 'Dotprompt', status: 'INVALID_ARGUMENT', - message: `Error parsing YAML frontmatter of '${name}' prompt: ${e.message}`, + message: `Error parsing YAML frontmatter of '${name}' prompt: ${e.stack}`, }); } } @@ -166,9 +166,9 @@ export class Dotprompt implements PromptMetadata { ); } - private _generateOptions( + private _generateOptions( options: PromptGenerateOptions - ): GenerateOptions { + ): GenerateOptions { const messages = this.renderMessages(options.input, { history: options.history, context: options.context, @@ -188,17 +188,19 @@ export class Dotprompt implements PromptMetadata { tools: (options.tools || []).concat(this.tools || []), streamingCallback: options.streamingCallback, returnToolRequests: options.returnToolRequests, - }; + } as GenerateOptions; } - render(opt: PromptGenerateOptions): GenerateOptions { - return this._generateOptions(opt); + render( + opt: PromptGenerateOptions + ): GenerateOptions { + return this._generateOptions(opt); } - async generate( + async generate( opt: PromptGenerateOptions - ): Promise { - return generate(this.render(opt)); + ): Promise>> { + return generate(this.render(opt)); } async generateStream( diff --git a/js/plugins/dotprompt/tests/prompt_test.ts b/js/plugins/dotprompt/tests/prompt_test.ts index ce0f14464d..d1dd4ce321 100644 --- a/js/plugins/dotprompt/tests/prompt_test.ts +++ b/js/plugins/dotprompt/tests/prompt_test.ts @@ -21,6 +21,7 @@ import { defineModel } from '@genkit-ai/ai/model'; import { toJsonSchema, ValidationError } from 
'@genkit-ai/core/schema'; import z from 'zod'; import { registerPluginProvider } from '../../../core/src/registry.js'; +import { defineJsonSchema, defineSchema } from '../../../core/src/schema.js'; import { defineDotprompt, Dotprompt, prompt } from '../src/index.js'; import { PromptMetadata } from '../src/metadata.js'; @@ -200,6 +201,24 @@ output: }, }); }); + + it('should use registered schemas', () => { + const MyInput = defineSchema('MyInput', z.number()); + defineJsonSchema('MyOutput', { type: 'boolean' }); + + const p = Dotprompt.parse( + 'example2', + `--- +input: + schema: MyInput +output: + schema: MyOutput +---` + ); + + assert.deepEqual(p.input, { schema: MyInput }); + assert.deepEqual(p.output, { jsonSchema: { type: 'boolean' } }); + }); }); describe('defineDotprompt', () => { diff --git a/js/testapps/prompt-file/package.json b/js/testapps/prompt-file/package.json index 7d1383f76f..ec4efbb0eb 100644 --- a/js/testapps/prompt-file/package.json +++ b/js/testapps/prompt-file/package.json @@ -10,6 +10,7 @@ "@genkit-ai/googleai": "workspace:*", "zod": "^3.22.4" }, + "main": "lib/index.js", "scripts": { "build": "tsc", "test": "echo \"Error: no test specified\" && exit 1" diff --git a/js/testapps/prompt-file/prompts/recipe.prompt b/js/testapps/prompt-file/prompts/recipe.prompt index c51e1887a7..78400dc6b0 100644 --- a/js/testapps/prompt-file/prompts/recipe.prompt +++ b/js/testapps/prompt-file/prompts/recipe.prompt @@ -4,12 +4,7 @@ input: schema: food: string output: - schema: - title: string, recipe title - ingredients(array): - name: string - quantity: string - steps(array, the steps required to complete the recipe): string + schema: Recipe --- You are a chef famous for making creative recipes that can be prepared in 45 minutes or less. 
diff --git a/js/testapps/prompt-file/src/index.ts b/js/testapps/prompt-file/src/index.ts index bff74e4808..9339232bd0 100644 --- a/js/testapps/prompt-file/src/index.ts +++ b/js/testapps/prompt-file/src/index.ts @@ -14,7 +14,7 @@ * limitations under the License. */ -import { configureGenkit } from '@genkit-ai/core'; +import { configureGenkit, defineSchema } from '@genkit-ai/core'; import { dotprompt, prompt } from '@genkit-ai/dotprompt'; import { defineFlow } from '@genkit-ai/flow'; import { googleAI } from '@genkit-ai/googleai'; @@ -26,6 +26,24 @@ configureGenkit({ logLevel: 'debug', }); +/* +title: string, recipe title + ingredients(array): + name: string + quantity: string + steps(array, the steps required to complete the recipe): string + */ +const RecipeSchema = defineSchema( + 'Recipe', + z.object({ + title: z.string().describe('recipe title'), + ingredients: z.array(z.object({ name: z.string(), quantity: z.string() })), + steps: z + .array(z.string()) + .describe('the steps required to complete the recipe'), + }) +); + // This example demonstrates using prompt files in a flow // Load the prompt file during initialization. // If it fails, due to the prompt file being invalid, the process will crash, @@ -38,9 +56,12 @@ prompt('recipe').then((recipePrompt) => { inputSchema: z.object({ food: z.string(), }), - outputSchema: z.any(), + outputSchema: RecipeSchema, }, - async (input) => (await recipePrompt.generate({ input: input })).output() + async (input) => + ( + await recipePrompt.generate({ input: input }) + ).output()! ); });