From 3ca484fd2f667ee32e11f59fdb1076fdb7c1f541 Mon Sep 17 00:00:00 2001 From: Pavel Jbanov Date: Mon, 6 May 2024 10:31:25 -0400 Subject: [PATCH 1/4] Introduced `defineAction` which handles registry for actions under the hood --- js/ai/src/embedder.ts | 8 ++-- js/ai/src/evaluator.ts | 10 ++--- js/ai/src/model.ts | 7 ++-- js/ai/src/retriever.ts | 12 +++--- js/ai/src/tool.ts | 8 ++-- js/core/src/action.ts | 73 +++++++++++++++++++++++++++------- js/core/src/registry.ts | 3 +- js/core/tests/registry_test.ts | 54 ++++++++++++++++--------- js/dotprompt/src/index.ts | 6 +-- js/dotprompt/src/registry.ts | 6 +-- js/flow/src/flow.ts | 8 ++-- 11 files changed, 123 insertions(+), 72 deletions(-) diff --git a/js/ai/src/embedder.ts b/js/ai/src/embedder.ts index bf9102e2e2..92fab96eef 100644 --- a/js/ai/src/embedder.ts +++ b/js/ai/src/embedder.ts @@ -14,8 +14,8 @@ * limitations under the License. */ -import { action, Action } from '@genkit-ai/core'; -import { lookupAction, registerAction } from '@genkit-ai/core/registry'; +import { Action, defineAction } from '@genkit-ai/core'; +import { lookupAction } from '@genkit-ai/core/registry'; import { setCustomMetadataAttributes } from '@genkit-ai/core/tracing'; import * as z from 'zod'; import { Document, DocumentData, DocumentDataSchema } from './document.js'; @@ -68,8 +68,9 @@ export function defineEmbedder< }, runner: EmbedderFn ) { - const embedder = action( + const embedder = defineAction( { + actionType: 'embedder', name: options.name, inputSchema: options.configSchema ? EmbedRequestSchema.extend({ @@ -94,7 +95,6 @@ export function defineEmbedder< embedder as Action, options.configSchema ); - registerAction('embedder', ewm.__action.name, ewm); return ewm; } diff --git a/js/ai/src/evaluator.ts b/js/ai/src/evaluator.ts index 02c800ffd8..1e27a83197 100644 --- a/js/ai/src/evaluator.ts +++ b/js/ai/src/evaluator.ts @@ -14,13 +14,13 @@ * limitations under the License. */ -import { action, Action } from '@genkit-ai/core'; +import { Action, defineAction } from '@genkit-ai/core'; import { logger } from '@genkit-ai/core/logging'; -import { lookupAction, registerAction } from '@genkit-ai/core/registry'; +import { lookupAction } from '@genkit-ai/core/registry'; import { + SPAN_TYPE_ATTR, runInNewSpan, setCustomMetadataAttributes, - SPAN_TYPE_ATTR, } from '@genkit-ai/core/tracing'; import * as z from 'zod'; @@ -128,8 +128,9 @@ export function defineEvaluator< options.isBilled == undefined ? true : options.isBilled; metadata[EVALUATOR_METADATA_KEY_DISPLAY_NAME] = options.displayName; metadata[EVALUATOR_METADATA_KEY_DEFINITION] = options.definition; - const evaluator = action( + const evaluator = defineAction( { + actionType: 'evaluator', name: options.name, inputSchema: EvalRequestSchema.extend({ dataset: options.dataPointType @@ -205,7 +206,6 @@ export function defineEvaluator< options.dataPointType, options.configSchema ); - registerAction('evaluator', evaluator.__action.name, evaluator); return ewm; } diff --git a/js/ai/src/model.ts b/js/ai/src/model.ts index 2f3b06df63..b6183340ed 100644 --- a/js/ai/src/model.ts +++ b/js/ai/src/model.ts @@ -16,11 +16,10 @@ import { Action, - action, + defineAction, getStreamingCallback, StreamingCallback, } from '@genkit-ai/core'; -import { registerAction } from '@genkit-ai/core/registry'; import { toJsonSchema } from '@genkit-ai/core/schema'; import { setCustomMetadataAttributes } from '@genkit-ai/core/tracing'; import { performance } from 'node:perf_hooks'; @@ -291,8 +290,9 @@ export function defineModel< ) => Promise ): ModelAction { const label = options.label || `${options.name} GenAI model`; - const act = action( + const act = defineAction( { + actionType: 'model', name: options.name, description: label, inputSchema: GenerateRequestSchema, @@ -340,7 +340,6 @@ export function defineModel< act as ModelAction, middleware ) as ModelAction; - registerAction('model', ma.__action.name, ma); return ma; } diff --git a/js/ai/src/retriever.ts b/js/ai/src/retriever.ts index 9a77a11f56..e8b2afcdfb 100644 --- a/js/ai/src/retriever.ts +++ b/js/ai/src/retriever.ts @@ -14,8 +14,8 @@ * limitations under the License. */ -import { action, Action } from '@genkit-ai/core'; -import { lookupAction, registerAction } from '@genkit-ai/core/registry'; +import { Action, defineAction } from '@genkit-ai/core'; +import { lookupAction } from '@genkit-ai/core/registry'; import { setCustomMetadataAttributes } from '@genkit-ai/core/tracing'; import * as z from 'zod'; import { Document, DocumentData, DocumentDataSchema } from './document.js'; @@ -113,8 +113,9 @@ export function defineRetriever< }, runner: RetrieverFn ) { - const retriever = action( + const retriever = defineAction( { + actionType: 'retriever', name: options.name, inputSchema: options.configSchema ? RetrieverRequestSchema.extend({ @@ -139,7 +140,6 @@ export function defineRetriever< >, options.configSchema ); - registerAction('retriever', rwm.__action.name, rwm); return rwm; } @@ -154,8 +154,9 @@ export function defineIndexer( }, runner: IndexerFn ) { - const indexer = action( + const indexer = defineAction( { + actionType: 'indexer', name: options.name, inputSchema: options.configSchema ? IndexerRequestSchema.extend({ @@ -180,7 +181,6 @@ export function defineIndexer( indexer as Action, options.configSchema ); - registerAction('indexer', iwm.__action.name, iwm); return iwm; } diff --git a/js/ai/src/tool.ts b/js/ai/src/tool.ts index e430d4dcd6..c20043bc09 100644 --- a/js/ai/src/tool.ts +++ b/js/ai/src/tool.ts @@ -14,8 +14,8 @@ * limitations under the License. */ -import { Action, JSONSchema7, action } from '@genkit-ai/core'; -import { lookupAction, registerAction } from '@genkit-ai/core/registry'; +import { Action, defineAction, JSONSchema7 } from '@genkit-ai/core'; +import { lookupAction } from '@genkit-ai/core/registry'; import { toJsonSchema } from '@genkit-ai/core/schema'; import { setCustomMetadataAttributes } from '@genkit-ai/core/tracing'; import z from 'zod'; @@ -117,8 +117,9 @@ export function defineTool( }, fn: (input: z.infer) => Promise> ): ToolAction { - const a = action( + const a = defineAction( { + actionType: 'tool', name, description, inputSchema, @@ -132,6 +133,5 @@ export function defineTool( return fn(i); } ); - registerAction('tool', name, a); return a as ToolAction; } diff --git a/js/core/src/action.ts b/js/core/src/action.ts index 15eae303a0..b9aad8c89f 100644 --- a/js/core/src/action.ts +++ b/js/core/src/action.ts @@ -18,9 +18,10 @@ import { JSONSchema7 } from 'json-schema'; import { AsyncLocalStorage } from 'node:async_hooks'; import { performance } from 'node:perf_hooks'; import * as z from 'zod'; +import { ActionType, registerAction } from './registry.js'; import { parseSchema } from './schema.js'; import * as telemetry from './telemetry.js'; -import { runInNewSpan, SPAN_TYPE_ATTR } from './tracing.js'; +import { SPAN_TYPE_ATTR, runInNewSpan } from './tracing.js'; export { Status, StatusCodes, StatusSchema } from './statusTypes.js'; export { JSONSchema7 }; @@ -30,6 +31,7 @@ export interface ActionMetadata< O extends z.ZodTypeAny, M extends Record = Record, > { + actionType?: ActionType; name: string; description?: string; inputSchema?: I; @@ -49,25 +51,40 @@ export type Action< export type SideChannelData = Record; +type ActionParams< + I extends z.ZodTypeAny, + O extends z.ZodTypeAny, + M extends Record = Record, +> = { + name: + | string + | { + pluginId: string; + actionId: string; + }; + description?: string; + inputSchema?: I; + inputJsonSchema?: JSONSchema7; + outputSchema?: O; + outputJsonSchema?: JSONSchema7; + metadata?: M; +}; + /** - * + * Creates an action with the provided config. */ export function action< I extends z.ZodTypeAny, O extends z.ZodTypeAny, M extends Record = Record, >( - config: { - name: string; - description?: string; - inputSchema?: I; - inputJsonSchema?: JSONSchema7; - outputSchema?: O; - outputJsonSchema?: JSONSchema7; - metadata?: M; - }, + config: ActionParams, fn: (input: z.infer) => Promise> ): Action { + const actionName = + typeof config.name === 'string' + ? validateActionNameChunk(config.name) + : `${validateActionNameChunk(config.name.pluginId)}/${validateActionNameChunk(config.name.actionId)}`; const actionFn = async (input: I) => { input = parseSchema(input, { schema: config.inputSchema, @@ -76,14 +93,14 @@ export function action< let output = await runInNewSpan( { metadata: { - name: config.name, + name: actionName, }, labels: { [SPAN_TYPE_ATTR]: 'action', }, }, async (metadata) => { - metadata.name = config.name; + metadata.name = actionName; metadata.input = input; const startTimeMs = performance.now(); try { @@ -111,17 +128,43 @@ export function action< return output; }; actionFn.__action = { - name: config.name, + name: actionName, description: config.description, inputSchema: config.inputSchema, inputJsonSchema: config.inputJsonSchema, outputSchema: config.outputSchema, outputJsonSchema: config.outputJsonSchema, metadata: config.metadata, - }; + } as ActionMetadata; return actionFn; } +function validateActionNameChunk(chunk: string) { + if (chunk.includes('/')) { + throw new Error(`Action name must not include slashes (/): ${chunk}`); + } + return chunk; +} + +/** + * Defines an action with the given config and registers it in the registry. + */ +export function defineAction< + I extends z.ZodTypeAny, + O extends z.ZodTypeAny, + M extends Record = Record, +>( + config: ActionParams & { + actionType: ActionType; + }, + fn: (input: z.infer) => Promise> +): Action { + const act = action(config, fn); + act.__action.actionType = config.actionType; + registerAction(config.actionType, act); + return act; +} + // Streaming callback function. export type StreamingCallback = (chunk: T) => void; diff --git a/js/core/src/registry.ts b/js/core/src/registry.ts index 51316aa518..c23636d7a3 100644 --- a/js/core/src/registry.ts +++ b/js/core/src/registry.ts @@ -98,11 +98,10 @@ function parsePluginName(registryKey: string) { */ export function registerAction( type: ActionType, - id: string, action: Action ) { logger.info(`Registering ${type}: ${action.__action.name}`); - const key = `/${type}/${id}`; + const key = `/${type}/${action.__action.name}`; if (actionsById().hasOwnProperty(key)) { logger.warn( `WARNING: ${key} already has an entry in the registry. Overwriting.` diff --git a/js/core/tests/registry_test.ts b/js/core/tests/registry_test.ts index 25dd5c0004..c76f57aa51 100644 --- a/js/core/tests/registry_test.ts +++ b/js/core/tests/registry_test.ts @@ -31,42 +31,52 @@ describe('registry', () => { describe('listActions', () => { it('returns all registered actions', async () => { const fooSomethingAction = action( - { name: 'foo/something' }, + { name: 'foo_something' }, async () => null ); - registerAction('model', 'foo/something', fooSomethingAction); + registerAction('model', fooSomethingAction); const barSomethingAction = action( - { name: 'bar/something' }, + { name: 'bar_something' }, async () => null ); - registerAction('model', 'bar/something', barSomethingAction); + registerAction('model', barSomethingAction); assert.deepEqual(await listActions(), { - '/model/foo/something': fooSomethingAction, - '/model/bar/something': barSomethingAction, + '/model/foo_something': fooSomethingAction, + '/model/bar_something': barSomethingAction, }); }); it('returns all registered actions by plugins', async () => { const fooSomethingAction = action( - { name: 'foo/something' }, + { + name: { + pluginId: 'foo', + actionId: 'something', + }, + }, async () => null ); registerPluginProvider('foo', { name: 'foo', async initializer() { - registerAction('model', 'foo/something', fooSomethingAction); + registerAction('model', fooSomethingAction); return {}; }, }); const barSomethingAction = action( - { name: 'bar/something' }, + { + name: { + pluginId: 'bar', + actionId: 'something', + }, + }, async () => null ); registerPluginProvider('bar', { name: 'bar', async initializer() { - registerAction('model', 'bar/something', barSomethingAction); + registerAction('model', barSomethingAction); return {}; }, }); @@ -111,33 +121,41 @@ describe('registry', () => { it('returns registered action', async () => { const fooSomethingAction = action( - { name: 'foo/something' }, + { name: 'foo_something' }, async () => null ); - registerAction('model', 'foo/something', fooSomethingAction); + registerAction('model', fooSomethingAction); const barSomethingAction = action( - { name: 'bar/something' }, + { name: 'bar_something' }, async () => null ); - registerAction('model', 'bar/something', barSomethingAction); + registerAction('model', barSomethingAction); assert.strictEqual( - await lookupAction('/model/foo/something'), + await lookupAction('/model/foo_something'), fooSomethingAction ); assert.strictEqual( - await lookupAction('/model/bar/something'), + await lookupAction('/model/bar_something'), barSomethingAction ); }); it('returns action registered by plugin', async () => { - const somethingAction = action({ name: 'foo/something' }, async () => null); + const somethingAction = action( + { + name: { + pluginId: 'foo', + actionId: 'something', + }, + }, + async () => null + ); registerPluginProvider('foo', { name: 'foo', async initializer() { - registerAction('model', 'foo/something', somethingAction); + registerAction('model', somethingAction); return {}; }, }); diff --git a/js/dotprompt/src/index.ts b/js/dotprompt/src/index.ts index b66df4b883..28dfd7d0cb 100644 --- a/js/dotprompt/src/index.ts +++ b/js/dotprompt/src/index.ts @@ -56,10 +56,6 @@ export function definePrompt( template: string ): Prompt> { const prompt = new Prompt(options, template); - registerAction( - 'prompt', - `${prompt.name}${prompt.variant ? `.${prompt.variant}` : ''}`, - prompt.action() - ); + registerAction('prompt', prompt.action()); return prompt; } diff --git a/js/dotprompt/src/registry.ts b/js/dotprompt/src/registry.ts index 67af3aec8f..e4b34f6c91 100644 --- a/js/dotprompt/src/registry.ts +++ b/js/dotprompt/src/registry.ts @@ -31,11 +31,7 @@ export async function lookupPrompt( if (registryPrompt) return Prompt.fromAction(registryPrompt); const prompt = loadPrompt(name, variant); - registerAction( - 'prompt', - `${name}${variant ? `.${variant}` : ''}`, - prompt.action() - ); + registerAction('prompt', prompt.action()); return prompt; } diff --git a/js/flow/src/flow.ts b/js/flow/src/flow.ts index 1fb5a6bd25..4c19cbbac0 100644 --- a/js/flow/src/flow.ts +++ b/js/flow/src/flow.ts @@ -16,7 +16,7 @@ import { Action, - action, + defineAction, FlowError, FlowState, FlowStateSchema, @@ -28,7 +28,6 @@ import { StreamingCallback, } from '@genkit-ai/core'; import { logger } from '@genkit-ai/core/logging'; -import { registerAction } from '@genkit-ai/core/registry'; import { toJsonSchema } from '@genkit-ai/core/schema'; import { newTrace, @@ -157,7 +156,7 @@ export function defineFlow< steps ); createdFlows().push(f); - registerAction('flow', config.name, wrapAsAction(f)); + wrapAsAction(f); return f; } @@ -800,8 +799,9 @@ function wrapAsAction< >( flow: Flow ): Action { - return action( + return defineAction( { + actionType: 'flow', name: flow.name, inputSchema: FlowActionInputSchema, outputSchema: FlowStateSchema, From 97f0146d87178a8b7bcdb4732091381e99993474 Mon Sep 17 00:00:00 2001 From: Pavel Jbanov Date: Mon, 6 May 2024 11:25:52 -0400 Subject: [PATCH 2/4] A few fixes --- js/ai/src/model.ts | 2 +- js/core/src/action.ts | 31 +++++++++++++++++++++++------- js/core/src/plugin.ts | 2 +- js/core/src/registry.ts | 4 ++++ js/core/tests/registry_test.ts | 35 +++++++++++++++++----------------- 5 files changed, 47 insertions(+), 27 deletions(-) diff --git a/js/ai/src/model.ts b/js/ai/src/model.ts index b6183340ed..c38e7413c0 100644 --- a/js/ai/src/model.ts +++ b/js/ai/src/model.ts @@ -267,7 +267,7 @@ export function modelWithMiddleware( } /** - * + * Defines a new model and adds it to the registry. */ export function defineModel< CustomOptionsSchema extends z.ZodTypeAny = z.ZodTypeAny, diff --git a/js/core/src/action.ts b/js/core/src/action.ts index b9aad8c89f..44454a5f29 100644 --- a/js/core/src/action.ts +++ b/js/core/src/action.ts @@ -18,7 +18,7 @@ import { JSONSchema7 } from 'json-schema'; import { AsyncLocalStorage } from 'node:async_hooks'; import { performance } from 'node:perf_hooks'; import * as z from 'zod'; -import { ActionType, registerAction } from './registry.js'; +import { ActionType, lookupPlugin, registerAction } from './registry.js'; import { parseSchema } from './schema.js'; import * as telemetry from './telemetry.js'; import { SPAN_TYPE_ATTR, runInNewSpan } from './tracing.js'; @@ -83,8 +83,8 @@ export function action< ): Action { const actionName = typeof config.name === 'string' - ? validateActionNameChunk(config.name) - : `${validateActionNameChunk(config.name.pluginId)}/${validateActionNameChunk(config.name.actionId)}`; + ? validateActionName(config.name) + : `${validatePluginName(config.name.pluginId)}/${validateActionId(config.name.actionId)}`; const actionFn = async (input: I) => { input = parseSchema(input, { schema: config.inputSchema, @@ -139,11 +139,28 @@ export function action< return actionFn; } -function validateActionNameChunk(chunk: string) { - if (chunk.includes('/')) { - throw new Error(`Action name must not include slashes (/): ${chunk}`); +function validateActionName(name: string) { + if (name.includes('/')) { + validatePluginName(name.split('/', 1)[0]); + validateActionId(name.substring(name.indexOf('/') + 1)); } - return chunk; + return name; +} + +function validatePluginName(pluginId: string) { + if (!lookupPlugin(pluginId)) { + throw new Error( + `Unable to find plugin name used in the action name: ${pluginId}` + ); + } + return pluginId; +} + +function validateActionId(actionId: string) { + if (actionId.includes('/')) { + throw new Error(`Action name must not include slashes (/): ${actionId}`); + } + return actionId; } /** diff --git a/js/core/src/plugin.ts b/js/core/src/plugin.ts index a9267506cf..a21a3a5f3d 100644 --- a/js/core/src/plugin.ts +++ b/js/core/src/plugin.ts @@ -54,7 +54,7 @@ type PluginInit = ( export type Plugin = (...args: T) => PluginProvider; /** - * + * Defines a Genkit plugin. */ export function genkitPlugin( pluginName: string, diff --git a/js/core/src/registry.ts b/js/core/src/registry.ts index c23636d7a3..b1e447ad71 100644 --- a/js/core/src/registry.ts +++ b/js/core/src/registry.ts @@ -198,6 +198,10 @@ export function registerPluginProvider(name: string, provider: PluginProvider) { }; } +export function lookupPlugin(name: string) { + return pluginsByName()[name]; +} + /** * */ diff --git a/js/core/tests/registry_test.ts b/js/core/tests/registry_test.ts index c76f57aa51..c969eba290 100644 --- a/js/core/tests/registry_test.ts +++ b/js/core/tests/registry_test.ts @@ -48,6 +48,13 @@ describe('registry', () => { }); it('returns all registered actions by plugins', async () => { + registerPluginProvider('foo', { + name: 'foo', + async initializer() { + registerAction('model', fooSomethingAction); + return {}; + }, + }); const fooSomethingAction = action( { name: { @@ -57,10 +64,10 @@ describe('registry', () => { }, async () => null ); - registerPluginProvider('foo', { - name: 'foo', + registerPluginProvider('bar', { + name: 'bar', async initializer() { - registerAction('model', fooSomethingAction); + registerAction('model', barSomethingAction); return {}; }, }); @@ -73,13 +80,6 @@ describe('registry', () => { }, async () => null ); - registerPluginProvider('bar', { - name: 'bar', - async initializer() { - registerAction('model', barSomethingAction); - return {}; - }, - }); assert.deepEqual(await listActions(), { '/model/foo/something': fooSomethingAction, @@ -142,6 +142,13 @@ describe('registry', () => { }); it('returns action registered by plugin', async () => { + registerPluginProvider('foo', { + name: 'foo', + async initializer() { + registerAction('model', somethingAction); + return {}; + }, + }); const somethingAction = action( { name: { @@ -152,14 +159,6 @@ describe('registry', () => { async () => null ); - registerPluginProvider('foo', { - name: 'foo', - async initializer() { - registerAction('model', somethingAction); - return {}; - }, - }); - assert.strictEqual( await lookupAction('/model/foo/something'), somethingAction From f3ae8686ecb9b7b521e88ffc940efb21a34856ab Mon Sep 17 00:00:00 2001 From: Pavel Jbanov Date: Mon, 6 May 2024 11:37:30 -0400 Subject: [PATCH 3/4] fixed prompts --- js/ai/src/prompt.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/js/ai/src/prompt.ts b/js/ai/src/prompt.ts index 2262d84897..1ec0c06ace 100644 --- a/js/ai/src/prompt.ts +++ b/js/ai/src/prompt.ts @@ -65,7 +65,7 @@ export function definePrompt( return fn(i); } ); - registerAction('prompt', name, a); + registerAction('prompt', a); return a as PromptAction; } From 2bf9b453df009f6683c6e8e31f61613c0d5cbcfb Mon Sep 17 00:00:00 2001 From: Pavel Jbanov Date: Mon, 6 May 2024 11:38:59 -0400 Subject: [PATCH 4/4] format --- js/flow/src/flow.ts | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/js/flow/src/flow.ts b/js/flow/src/flow.ts index 544e1a242c..f46c8c4e63 100644 --- a/js/flow/src/flow.ts +++ b/js/flow/src/flow.ts @@ -16,24 +16,24 @@ import { Action, - config as globalConfig, - defineAction, FlowError, FlowState, FlowStateSchema, FlowStateStore, - getStreamingCallback, - isDevEnv, Operation, StreamingCallback, + defineAction, + getStreamingCallback, + config as globalConfig, + isDevEnv, } from '@genkit-ai/core'; import { logger } from '@genkit-ai/core/logging'; import { toJsonSchema } from '@genkit-ai/core/schema'; import { + SPAN_TYPE_ATTR, newTrace, setCustomMetadataAttribute, setCustomMetadataAttributes, - SPAN_TYPE_ATTR, } from '@genkit-ai/core/tracing'; import { SpanStatusCode } from '@opentelemetry/api'; import * as bodyParser from 'body-parser'; @@ -45,9 +45,9 @@ import { Context } from './context.js'; import { FlowExecutionError, FlowStillRunningError, + InterruptError, getErrorMessage, getErrorStack, - InterruptError, } from './errors.js'; import * as telemetry from './telemetry.js'; import {