diff --git a/js/ai/src/embedder.ts b/js/ai/src/embedder.ts index bf9102e2e2..92fab96eef 100644 --- a/js/ai/src/embedder.ts +++ b/js/ai/src/embedder.ts @@ -14,8 +14,8 @@ * limitations under the License. */ -import { action, Action } from '@genkit-ai/core'; -import { lookupAction, registerAction } from '@genkit-ai/core/registry'; +import { Action, defineAction } from '@genkit-ai/core'; +import { lookupAction } from '@genkit-ai/core/registry'; import { setCustomMetadataAttributes } from '@genkit-ai/core/tracing'; import * as z from 'zod'; import { Document, DocumentData, DocumentDataSchema } from './document.js'; @@ -68,8 +68,9 @@ export function defineEmbedder< }, runner: EmbedderFn ) { - const embedder = action( + const embedder = defineAction( { + actionType: 'embedder', name: options.name, inputSchema: options.configSchema ? EmbedRequestSchema.extend({ @@ -94,7 +95,6 @@ export function defineEmbedder< embedder as Action, options.configSchema ); - registerAction('embedder', ewm.__action.name, ewm); return ewm; } diff --git a/js/ai/src/evaluator.ts b/js/ai/src/evaluator.ts index 02c800ffd8..1e27a83197 100644 --- a/js/ai/src/evaluator.ts +++ b/js/ai/src/evaluator.ts @@ -14,13 +14,13 @@ * limitations under the License. */ -import { action, Action } from '@genkit-ai/core'; +import { Action, defineAction } from '@genkit-ai/core'; import { logger } from '@genkit-ai/core/logging'; -import { lookupAction, registerAction } from '@genkit-ai/core/registry'; +import { lookupAction } from '@genkit-ai/core/registry'; import { + SPAN_TYPE_ATTR, runInNewSpan, setCustomMetadataAttributes, - SPAN_TYPE_ATTR, } from '@genkit-ai/core/tracing'; import * as z from 'zod'; @@ -128,8 +128,9 @@ export function defineEvaluator< options.isBilled == undefined ? true : options.isBilled; metadata[EVALUATOR_METADATA_KEY_DISPLAY_NAME] = options.displayName; metadata[EVALUATOR_METADATA_KEY_DEFINITION] = options.definition; - const evaluator = action( + const evaluator = defineAction( { + actionType: 'evaluator', name: options.name, inputSchema: EvalRequestSchema.extend({ dataset: options.dataPointType @@ -205,7 +206,6 @@ export function defineEvaluator< options.dataPointType, options.configSchema ); - registerAction('evaluator', evaluator.__action.name, evaluator); return ewm; } diff --git a/js/ai/src/model.ts b/js/ai/src/model.ts index 30d5d3e4d4..4d118917a0 100644 --- a/js/ai/src/model.ts +++ b/js/ai/src/model.ts @@ -16,11 +16,10 @@ import { Action, - action, + defineAction, getStreamingCallback, StreamingCallback, } from '@genkit-ai/core'; -import { registerAction } from '@genkit-ai/core/registry'; import { toJsonSchema } from '@genkit-ai/core/schema'; import { setCustomMetadataAttributes } from '@genkit-ai/core/tracing'; import { performance } from 'node:perf_hooks'; @@ -270,7 +269,7 @@ export function modelWithMiddleware( } /** - * + * Defines a new model and adds it to the registry. */ export function defineModel< CustomOptionsSchema extends z.ZodTypeAny = z.ZodTypeAny, @@ -293,8 +292,9 @@ export function defineModel< ) => Promise ): ModelAction { const label = options.label || `${options.name} GenAI model`; - const act = action( + const act = defineAction( { + actionType: 'model', name: options.name, description: label, inputSchema: GenerateRequestSchema, @@ -342,7 +342,6 @@ export function defineModel< act as ModelAction, middleware ) as ModelAction; - registerAction('model', ma.__action.name, ma); return ma; } diff --git a/js/ai/src/prompt.ts b/js/ai/src/prompt.ts index 2262d84897..1ec0c06ace 100644 --- a/js/ai/src/prompt.ts +++ b/js/ai/src/prompt.ts @@ -65,7 +65,7 @@ export function definePrompt( return fn(i); } ); - registerAction('prompt', name, a); + registerAction('prompt', a); return a as PromptAction; } diff --git a/js/ai/src/retriever.ts b/js/ai/src/retriever.ts index 9a77a11f56..e8b2afcdfb 100644 --- a/js/ai/src/retriever.ts +++ b/js/ai/src/retriever.ts @@ -14,8 +14,8 @@ * limitations under the License. */ -import { action, Action } from '@genkit-ai/core'; -import { lookupAction, registerAction } from '@genkit-ai/core/registry'; +import { Action, defineAction } from '@genkit-ai/core'; +import { lookupAction } from '@genkit-ai/core/registry'; import { setCustomMetadataAttributes } from '@genkit-ai/core/tracing'; import * as z from 'zod'; import { Document, DocumentData, DocumentDataSchema } from './document.js'; @@ -113,8 +113,9 @@ export function defineRetriever< }, runner: RetrieverFn ) { - const retriever = action( + const retriever = defineAction( { + actionType: 'retriever', name: options.name, inputSchema: options.configSchema ? RetrieverRequestSchema.extend({ @@ -139,7 +140,6 @@ export function defineRetriever< >, options.configSchema ); - registerAction('retriever', rwm.__action.name, rwm); return rwm; } @@ -154,8 +154,9 @@ export function defineIndexer( }, runner: IndexerFn ) { - const indexer = action( + const indexer = defineAction( { + actionType: 'indexer', name: options.name, inputSchema: options.configSchema ? IndexerRequestSchema.extend({ @@ -180,7 +181,6 @@ export function defineIndexer( indexer as Action, options.configSchema ); - registerAction('indexer', iwm.__action.name, iwm); return iwm; } diff --git a/js/ai/src/tool.ts b/js/ai/src/tool.ts index e430d4dcd6..c20043bc09 100644 --- a/js/ai/src/tool.ts +++ b/js/ai/src/tool.ts @@ -14,8 +14,8 @@ * limitations under the License. */ -import { Action, JSONSchema7, action } from '@genkit-ai/core'; -import { lookupAction, registerAction } from '@genkit-ai/core/registry'; +import { Action, defineAction, JSONSchema7 } from '@genkit-ai/core'; +import { lookupAction } from '@genkit-ai/core/registry'; import { toJsonSchema } from '@genkit-ai/core/schema'; import { setCustomMetadataAttributes } from '@genkit-ai/core/tracing'; import z from 'zod'; @@ -117,8 +117,9 @@ export function defineTool( }, fn: (input: z.infer) => Promise> ): ToolAction { - const a = action( + const a = defineAction( { + actionType: 'tool', name, description, inputSchema, @@ -132,6 +133,5 @@ export function defineTool( return fn(i); } ); - registerAction('tool', name, a); return a as ToolAction; } diff --git a/js/core/src/action.ts b/js/core/src/action.ts index 15eae303a0..44454a5f29 100644 --- a/js/core/src/action.ts +++ b/js/core/src/action.ts @@ -18,9 +18,10 @@ import { JSONSchema7 } from 'json-schema'; import { AsyncLocalStorage } from 'node:async_hooks'; import { performance } from 'node:perf_hooks'; import * as z from 'zod'; +import { ActionType, lookupPlugin, registerAction } from './registry.js'; import { parseSchema } from './schema.js'; import * as telemetry from './telemetry.js'; -import { runInNewSpan, SPAN_TYPE_ATTR } from './tracing.js'; +import { SPAN_TYPE_ATTR, runInNewSpan } from './tracing.js'; export { Status, StatusCodes, StatusSchema } from './statusTypes.js'; export { JSONSchema7 }; @@ -30,6 +31,7 @@ export interface ActionMetadata< O extends z.ZodTypeAny, M extends Record = Record, > { + actionType?: ActionType; name: string; description?: string; inputSchema?: I; @@ -49,25 +51,40 @@ export type Action< export type SideChannelData = Record; +type ActionParams< + I extends z.ZodTypeAny, + O extends z.ZodTypeAny, + M extends Record = Record, +> = { + name: + | string + | { + pluginId: string; + actionId: string; + }; + description?: string; + inputSchema?: I; + inputJsonSchema?: JSONSchema7; + outputSchema?: O; + outputJsonSchema?: JSONSchema7; + metadata?: M; +}; + /** - * + * Creates an action with the provided config. */ export function action< I extends z.ZodTypeAny, O extends z.ZodTypeAny, M extends Record = Record, >( - config: { - name: string; - description?: string; - inputSchema?: I; - inputJsonSchema?: JSONSchema7; - outputSchema?: O; - outputJsonSchema?: JSONSchema7; - metadata?: M; - }, + config: ActionParams, fn: (input: z.infer) => Promise> ): Action { + const actionName = + typeof config.name === 'string' + ? validateActionName(config.name) + : `${validatePluginName(config.name.pluginId)}/${validateActionId(config.name.actionId)}`; const actionFn = async (input: I) => { input = parseSchema(input, { schema: config.inputSchema, @@ -76,14 +93,14 @@ export function action< let output = await runInNewSpan( { metadata: { - name: config.name, + name: actionName, }, labels: { [SPAN_TYPE_ATTR]: 'action', }, }, async (metadata) => { - metadata.name = config.name; + metadata.name = actionName; metadata.input = input; const startTimeMs = performance.now(); try { @@ -111,17 +128,60 @@ export function action< return output; }; actionFn.__action = { - name: config.name, + name: actionName, description: config.description, inputSchema: config.inputSchema, inputJsonSchema: config.inputJsonSchema, outputSchema: config.outputSchema, outputJsonSchema: config.outputJsonSchema, metadata: config.metadata, - }; + } as ActionMetadata; return actionFn; } +function validateActionName(name: string) { + if (name.includes('/')) { + validatePluginName(name.split('/', 1)[0]); + validateActionId(name.substring(name.indexOf('/') + 1)); + } + return name; +} + +function validatePluginName(pluginId: string) { + if (!lookupPlugin(pluginId)) { + throw new Error( + `Unable to find plugin name used in the action name: ${pluginId}` + ); + } + return pluginId; +} + +function validateActionId(actionId: string) { + if (actionId.includes('/')) { + throw new Error(`Action name must not include slashes (/): ${actionId}`); + } + return actionId; +} + +/** + * Defines an action with the given config and registers it in the registry. + */ +export function defineAction< + I extends z.ZodTypeAny, + O extends z.ZodTypeAny, + M extends Record = Record, +>( + config: ActionParams & { + actionType: ActionType; + }, + fn: (input: z.infer) => Promise> +): Action { + const act = action(config, fn); + act.__action.actionType = config.actionType; + registerAction(config.actionType, act); + return act; +} + // Streaming callback function. export type StreamingCallback = (chunk: T) => void; diff --git a/js/core/src/plugin.ts b/js/core/src/plugin.ts index a9267506cf..a21a3a5f3d 100644 --- a/js/core/src/plugin.ts +++ b/js/core/src/plugin.ts @@ -54,7 +54,7 @@ type PluginInit = ( export type Plugin = (...args: T) => PluginProvider; /** - * + * Defines a Genkit plugin. */ export function genkitPlugin( pluginName: string, diff --git a/js/core/src/registry.ts b/js/core/src/registry.ts index 51316aa518..b1e447ad71 100644 --- a/js/core/src/registry.ts +++ b/js/core/src/registry.ts @@ -98,11 +98,10 @@ function parsePluginName(registryKey: string) { */ export function registerAction( type: ActionType, - id: string, action: Action ) { logger.info(`Registering ${type}: ${action.__action.name}`); - const key = `/${type}/${id}`; + const key = `/${type}/${action.__action.name}`; if (actionsById().hasOwnProperty(key)) { logger.warn( `WARNING: ${key} already has an entry in the registry. Overwriting.` @@ -199,6 +198,10 @@ export function registerPluginProvider(name: string, provider: PluginProvider) { }; } +export function lookupPlugin(name: string) { + return pluginsByName()[name]; +} + /** * */ diff --git a/js/core/tests/registry_test.ts b/js/core/tests/registry_test.ts index 25dd5c0004..c969eba290 100644 --- a/js/core/tests/registry_test.ts +++ b/js/core/tests/registry_test.ts @@ -31,45 +31,55 @@ describe('registry', () => { describe('listActions', () => { it('returns all registered actions', async () => { const fooSomethingAction = action( - { name: 'foo/something' }, + { name: 'foo_something' }, async () => null ); - registerAction('model', 'foo/something', fooSomethingAction); + registerAction('model', fooSomethingAction); const barSomethingAction = action( - { name: 'bar/something' }, + { name: 'bar_something' }, async () => null ); - registerAction('model', 'bar/something', barSomethingAction); + registerAction('model', barSomethingAction); assert.deepEqual(await listActions(), { - '/model/foo/something': fooSomethingAction, - '/model/bar/something': barSomethingAction, + '/model/foo_something': fooSomethingAction, + '/model/bar_something': barSomethingAction, }); }); it('returns all registered actions by plugins', async () => { - const fooSomethingAction = action( - { name: 'foo/something' }, - async () => null - ); registerPluginProvider('foo', { name: 'foo', async initializer() { - registerAction('model', 'foo/something', fooSomethingAction); + registerAction('model', fooSomethingAction); return {}; }, }); - const barSomethingAction = action( - { name: 'bar/something' }, + const fooSomethingAction = action( + { + name: { + pluginId: 'foo', + actionId: 'something', + }, + }, async () => null ); registerPluginProvider('bar', { name: 'bar', async initializer() { - registerAction('model', 'bar/something', barSomethingAction); + registerAction('model', barSomethingAction); return {}; }, }); + const barSomethingAction = action( + { + name: { + pluginId: 'bar', + actionId: 'something', + }, + }, + async () => null + ); assert.deepEqual(await listActions(), { '/model/foo/something': fooSomethingAction, @@ -111,36 +121,43 @@ describe('registry', () => { it('returns registered action', async () => { const fooSomethingAction = action( - { name: 'foo/something' }, + { name: 'foo_something' }, async () => null ); - registerAction('model', 'foo/something', fooSomethingAction); + registerAction('model', fooSomethingAction); const barSomethingAction = action( - { name: 'bar/something' }, + { name: 'bar_something' }, async () => null ); - registerAction('model', 'bar/something', barSomethingAction); + registerAction('model', barSomethingAction); assert.strictEqual( - await lookupAction('/model/foo/something'), + await lookupAction('/model/foo_something'), fooSomethingAction ); assert.strictEqual( - await lookupAction('/model/bar/something'), + await lookupAction('/model/bar_something'), barSomethingAction ); }); it('returns action registered by plugin', async () => { - const somethingAction = action({ name: 'foo/something' }, async () => null); - registerPluginProvider('foo', { name: 'foo', async initializer() { - registerAction('model', 'foo/something', somethingAction); + registerAction('model', somethingAction); return {}; }, }); + const somethingAction = action( + { + name: { + pluginId: 'foo', + actionId: 'something', + }, + }, + async () => null + ); assert.strictEqual( await lookupAction('/model/foo/something'), diff --git a/js/flow/src/flow.ts b/js/flow/src/flow.ts index f6a037a1b5..f46c8c4e63 100644 --- a/js/flow/src/flow.ts +++ b/js/flow/src/flow.ts @@ -22,13 +22,12 @@ import { FlowStateStore, Operation, StreamingCallback, - action, + defineAction, getStreamingCallback, config as globalConfig, isDevEnv, } from '@genkit-ai/core'; import { logger } from '@genkit-ai/core/logging'; -import { registerAction } from '@genkit-ai/core/registry'; import { toJsonSchema } from '@genkit-ai/core/schema'; import { SPAN_TYPE_ATTR, @@ -157,7 +156,7 @@ export function defineFlow< steps ); createdFlows().push(f); - registerAction('flow', config.name, wrapAsAction(f)); + wrapAsAction(f); return f; } @@ -803,8 +802,9 @@ function wrapAsAction< >( flow: Flow ): Action { - return action( + return defineAction( { + actionType: 'flow', name: flow.name, inputSchema: FlowActionInputSchema, outputSchema: FlowStateSchema,