diff --git a/js/core/src/registry.ts b/js/core/src/registry.ts index b46b636baf..7924a002ef 100644 --- a/js/core/src/registry.ts +++ b/js/core/src/registry.ts @@ -25,45 +25,7 @@ import { TraceStore } from './tracing/types.js'; export type AsyncProvider = () => Promise; -const ACTIONS_BY_ID = 'genkit__ACTIONS_BY_ID'; -const TRACE_STORES_BY_ENV = 'genkit__TRACE_STORES_BY_ENV'; -const FLOW_STATE_STORES_BY_ENV = 'genkit__FLOW_STATE_STORES_BY_ENV'; -const PLUGINS_BY_NAME = 'genkit__PLUGINS_BY_NAME'; -const SCHEMAS_BY_NAME = 'genkit__SCHEMAS_BY_NAME'; - -function actionsById(): Record> { - if (global[ACTIONS_BY_ID] === undefined) { - global[ACTIONS_BY_ID] = {}; - } - return global[ACTIONS_BY_ID]; -} -function traceStoresByEnv(): Record> { - if (global[TRACE_STORES_BY_ENV] === undefined) { - global[TRACE_STORES_BY_ENV] = {}; - } - return global[TRACE_STORES_BY_ENV]; -} -function flowStateStoresByEnv(): Record> { - if (global[FLOW_STATE_STORES_BY_ENV] === undefined) { - global[FLOW_STATE_STORES_BY_ENV] = {}; - } - return global[FLOW_STATE_STORES_BY_ENV]; -} -function pluginsByName(): Record { - if (global[PLUGINS_BY_NAME] === undefined) { - global[PLUGINS_BY_NAME] = {}; - } - return global[PLUGINS_BY_NAME]; -} -function schemasByName(): Record< - string, - { schema?: z.ZodTypeAny; jsonSchema?: JSONSchema } -> { - if (global[SCHEMAS_BY_NAME] === undefined) { - global[SCHEMAS_BY_NAME] = {}; - } - return global[SCHEMAS_BY_NAME]; -} +const REGISTRY_KEY = 'genkit__REGISTRY'; /** * Type of a runnable action. @@ -82,17 +44,12 @@ export type ActionType = /** * Looks up a registry key (action type and key) in the registry. */ -export async function lookupAction< +export function lookupAction< I extends z.ZodTypeAny, O extends z.ZodTypeAny, R extends Action, >(key: string): Promise { - // If we don't see the key in the registry we try to initialize the plugin first. - const pluginName = parsePluginName(key); - if (!actionsById()[key] && pluginName) { - await initializePlugin(pluginName); - } - return actionsById()[key] as R; + return getRegistryInstance().lookupAction(key); } function parsePluginName(registryKey: string) { @@ -110,35 +67,23 @@ export function registerAction( type: ActionType, action: Action ) { - logger.info(`Registering ${type}: ${action.__action.name}`); - const key = `/${type}/${action.__action.name}`; - if (actionsById().hasOwnProperty(key)) { - logger.warn( - `WARNING: ${key} already has an entry in the registry. Overwriting.` - ); - } - actionsById()[key] = action; + return getRegistryInstance().registerAction(type, action); } type ActionsRecord = Record>; /** - * Returns all actions in the registry. + * Initialize all plugins in the registry. */ -export async function listActions(): Promise { - await initializeAllPlugins(); - return Object.assign({}, actionsById()); +export async function initializeAllPlugins() { + await getRegistryInstance().initializeAllPlugins(); } -let allPluginsInitialized = false; -export async function initializeAllPlugins() { - if (allPluginsInitialized) { - return; - } - for (const pluginName of Object.keys(pluginsByName())) { - await initializePlugin(pluginName); - } - allPluginsInitialized = true; +/** + * Returns all actions in the registry. + */ +export function listActions(): Promise { + return getRegistryInstance().listActions(); } /** @@ -148,27 +93,14 @@ export function registerTraceStore( env: string, traceStoreProvider: AsyncProvider ) { - traceStoresByEnv()[env] = traceStoreProvider; + return getRegistryInstance().registerTraceStore(env, traceStoreProvider); } -const traceStoresByEnvCache: Record> = {}; - /** * Looks up the trace store for the given environment. */ -export async function lookupTraceStore( - env: string -): Promise { - if (!traceStoresByEnv()[env]) { - return undefined; - } - const cached = traceStoresByEnvCache[env]; - if (!cached) { - const newStore = traceStoresByEnv()[env](); - traceStoresByEnvCache[env] = newStore; - return newStore; - } - return cached; +export function lookupTraceStore(env: string): Promise { + return getRegistryInstance().lookupTraceStore(env); } /** @@ -178,71 +110,48 @@ export function registerFlowStateStore( env: string, flowStateStoreProvider: AsyncProvider ) { - flowStateStoresByEnv()[env] = flowStateStoreProvider; + return getRegistryInstance().registerFlowStateStore( + env, + flowStateStoreProvider + ); } -const flowStateStoresByEnvCache: Record> = {}; /** * Looks up the flow state store for the given environment. */ export async function lookupFlowStateStore( env: string ): Promise { - if (!flowStateStoresByEnv()[env]) { - return undefined; - } - const cached = flowStateStoresByEnvCache[env]; - if (!cached) { - const newStore = flowStateStoresByEnv()[env](); - flowStateStoresByEnvCache[env] = newStore; - return newStore; - } - return cached; + return getRegistryInstance().lookupFlowStateStore(env); } /** * Registers a flow state store for the given environment. */ export function registerPluginProvider(name: string, provider: PluginProvider) { - allPluginsInitialized = false; - let cached; - let isInitialized = false; - pluginsByName()[name] = { - name: provider.name, - initializer: () => { - if (isInitialized) { - return cached; - } - cached = provider.initializer(); - isInitialized = true; - return cached; - }, - }; + return getRegistryInstance().registerPluginProvider(name, provider); } export function lookupPlugin(name: string) { - return pluginsByName()[name]; + return getRegistryInstance().lookupFlowStateStore(name); } /** - * + * Initialize plugin -- calls the plugin initialization function. */ export async function initializePlugin(name: string) { - if (pluginsByName()[name]) { - return await pluginsByName()[name].initializer(); - } - return undefined; + return getRegistryInstance().initializePlugin(name); } export function registerSchema( name: string, data: { schema?: z.ZodTypeAny; jsonSchema?: JSONSchema } ) { - schemasByName()[name] = data; + return getRegistryInstance().registerSchema(name, data); } export function lookupSchema(name: string) { - return schemasByName()[name]; + return getRegistryInstance().lookupSchema(name); } /** @@ -253,14 +162,187 @@ if (process.env.GENKIT_ENV === 'dev') { } export function __hardResetRegistryForTesting() { - delete global[ACTIONS_BY_ID]; - delete global[TRACE_STORES_BY_ENV]; - delete global[FLOW_STATE_STORES_BY_ENV]; - delete global[PLUGINS_BY_NAME]; - deleteAll(flowStateStoresByEnvCache); - deleteAll(traceStoresByEnvCache); + delete global[REGISTRY_KEY]; + global[REGISTRY_KEY] = new Registry(); +} + +export class Registry { + private actionsById: Record> = {}; + private traceStoresByEnv: Record> = {}; + private flowStateStoresByEnv: Record> = + {}; + private pluginsByName: Record = {}; + private schemasByName: Record< + string, + { schema?: z.ZodTypeAny; jsonSchema?: JSONSchema } + > = {}; + + private traceStoresByEnvCache: Record> = {}; + private flowStateStoresByEnvCache: Record> = {}; + private allPluginsInitialized = false; + + constructor(public parent?: Registry) {} + + static withCurrent() { + return new Registry(getRegistryInstance()); + } + + static withParent(parent: Registry) { + return new Registry(parent); + } + + async lookupAction< + I extends z.ZodTypeAny, + O extends z.ZodTypeAny, + R extends Action, + >(key: string): Promise { + // If we don't see the key in the registry we try to initialize the plugin first. + const pluginName = parsePluginName(key); + if (!this.actionsById[key] && pluginName) { + await this.initializePlugin(pluginName); + } + return (this.actionsById[key] as R) || this.parent?.lookupAction(key); + } + + registerAction( + type: ActionType, + action: Action + ) { + logger.info(`Registering ${type}: ${action.__action.name}`); + const key = `/${type}/${action.__action.name}`; + if (this.actionsById.hasOwnProperty(key)) { + logger.warn( + `WARNING: ${key} already has an entry in the registry. Overwriting.` + ); + } + this.actionsById[key] = action; + } + + async listActions(): Promise { + await this.initializeAllPlugins(); + return { + ...(await this.parent?.listActions()), + ...this.actionsById, + }; + } + + async initializeAllPlugins() { + if (this.allPluginsInitialized) { + return; + } + for (const pluginName of Object.keys(this.pluginsByName)) { + await initializePlugin(pluginName); + } + this.allPluginsInitialized = true; + } + + registerTraceStore( + env: string, + traceStoreProvider: AsyncProvider + ) { + this.traceStoresByEnv[env] = traceStoreProvider; + } + + async lookupTraceStore(env: string): Promise { + return ( + (await this.lookupOverlaidTraceStore(env)) || + this.parent?.lookupTraceStore(env) + ); + } + + private async lookupOverlaidTraceStore( + env: string + ): Promise { + if (!this.traceStoresByEnv[env]) { + return undefined; + } + const cached = this.traceStoresByEnvCache[env]; + if (!cached) { + const newStore = this.traceStoresByEnv[env](); + this.traceStoresByEnvCache[env] = newStore; + return newStore; + } + return cached; + } + + registerFlowStateStore( + env: string, + flowStateStoreProvider: AsyncProvider + ) { + this.flowStateStoresByEnv[env] = flowStateStoreProvider; + } + + async lookupFlowStateStore(env: string): Promise { + return ( + (await this.lookupOverlaidFlowStateStore(env)) || + this.parent?.lookupFlowStateStore(env) + ); + } + + private async lookupOverlaidFlowStateStore( + env: string + ): Promise { + if (!this.flowStateStoresByEnv[env]) { + return undefined; + } + const cached = this.flowStateStoresByEnvCache[env]; + if (!cached) { + const newStore = this.flowStateStoresByEnv[env](); + this.flowStateStoresByEnvCache[env] = newStore; + return newStore; + } + return cached; + } + + registerPluginProvider(name: string, provider: PluginProvider) { + this.allPluginsInitialized = false; + let cached; + let isInitialized = false; + this.pluginsByName[name] = { + name: provider.name, + initializer: () => { + if (isInitialized) { + return cached; + } + cached = provider.initializer(); + isInitialized = true; + return cached; + }, + }; + } + + lookupPlugin(name: string) { + return this.pluginsByName[name] || this.parent?.lookupPlugin(name); + } + + async initializePlugin(name: string) { + if (this.pluginsByName[name]) { + return await this.pluginsByName[name].initializer(); + } + return undefined; + } + + registerSchema( + name: string, + data: { schema?: z.ZodTypeAny; jsonSchema?: JSONSchema } + ) { + this.schemasByName[name] = data; + } + + lookupSchema(name: string) { + return this.schemasByName[name] || this.parent?.lookupSchema(name); + } +} + +// global regustry instance +global[REGISTRY_KEY] = new Registry(); + +/** Returns the current registry instance. */ +export function getRegistryInstance(): Registry { + return global[REGISTRY_KEY]; } -function deleteAll(map: Record) { - Object.keys(map).forEach((key) => delete map[key]); +/** Sets global registry instance. */ +export function setRegistryInstance(reg: Registry) { + global[REGISTRY_KEY] = reg; } diff --git a/js/core/tests/registry_test.ts b/js/core/tests/registry_test.ts index c969eba290..23c165c6eb 100644 --- a/js/core/tests/registry_test.ts +++ b/js/core/tests/registry_test.ts @@ -15,9 +15,10 @@ */ import assert from 'node:assert'; -import { beforeEach, describe, it } from 'node:test'; +import { afterEach, beforeEach, describe, it } from 'node:test'; import { action } from '../src/action.js'; import { + Registry, __hardResetRegistryForTesting, listActions, lookupAction, @@ -25,8 +26,9 @@ import { registerPluginProvider, } from '../src/registry.js'; -describe('registry', () => { +describe('global registry', () => { beforeEach(__hardResetRegistryForTesting); + afterEach(__hardResetRegistryForTesting); describe('listActions', () => { it('returns all registered actions', async () => { @@ -169,3 +171,207 @@ describe('registry', () => { assert.strictEqual(await lookupAction('/model/foo/something'), undefined); }); }); + +describe('registry class', () => { + var registry: Registry; + beforeEach(() => { + registry = new Registry(); + }); + + describe('listActions', () => { + it('returns all registered actions', async () => { + const fooSomethingAction = action( + { name: 'foo_something' }, + async () => null + ); + registry.registerAction('model', fooSomethingAction); + const barSomethingAction = action( + { name: 'bar_something' }, + async () => null + ); + registry.registerAction('model', barSomethingAction); + + assert.deepEqual(await registry.listActions(), { + '/model/foo_something': fooSomethingAction, + '/model/bar_something': barSomethingAction, + }); + }); + + it('returns all registered actions by plugins', async () => { + registry.registerPluginProvider('foo', { + name: 'foo', + async initializer() { + registry.registerAction('model', fooSomethingAction); + return {}; + }, + }); + const fooSomethingAction = action( + { + name: { + pluginId: 'foo', + actionId: 'something', + }, + }, + async () => null + ); + registry.registerAction('custom', fooSomethingAction); + registry.registerPluginProvider('bar', { + name: 'bar', + async initializer() { + registry.registerAction('model', barSomethingAction); + return {}; + }, + }); + const barSomethingAction = action( + { + name: { + pluginId: 'bar', + actionId: 'something', + }, + }, + async () => null + ); + registry.registerAction('custom', barSomethingAction); + + assert.deepEqual(await registry.listActions(), { + '/custom/foo/something': fooSomethingAction, + '/custom/bar/something': barSomethingAction, + }); + }); + + it('returns all registered actions, including parent', async () => { + const child = Registry.withParent(registry); + + const fooSomethingAction = action( + { name: 'foo_something' }, + async () => null + ); + registry.registerAction('model', fooSomethingAction); + const barSomethingAction = action( + { name: 'bar_something' }, + async () => null + ); + child.registerAction('model', barSomethingAction); + + assert.deepEqual(await child.listActions(), { + '/model/foo_something': fooSomethingAction, + '/model/bar_something': barSomethingAction, + }); + assert.deepEqual(await registry.listActions(), { + '/model/foo_something': fooSomethingAction, + }); + }); + }); + + describe('lookupAction', () => { + it('initializes plugin for action first', async () => { + let fooInitialized = false; + registry.registerPluginProvider('foo', { + name: 'foo', + async initializer() { + fooInitialized = true; + return {}; + }, + }); + let barInitialized = false; + registry.registerPluginProvider('bar', { + name: 'bar', + async initializer() { + barInitialized = true; + return {}; + }, + }); + + await registry.lookupAction('/model/foo/something'); + + assert.strictEqual(fooInitialized, true); + assert.strictEqual(barInitialized, false); + + await registry.lookupAction('/model/bar/something'); + + assert.strictEqual(fooInitialized, true); + assert.strictEqual(barInitialized, true); + }); + + it('returns registered action', async () => { + const fooSomethingAction = action( + { name: 'foo_something' }, + async () => null + ); + registry.registerAction('model', fooSomethingAction); + const barSomethingAction = action( + { name: 'bar_something' }, + async () => null + ); + registry.registerAction('model', barSomethingAction); + + assert.strictEqual( + await registry.lookupAction('/model/foo_something'), + fooSomethingAction + ); + assert.strictEqual( + await registry.lookupAction('/model/bar_something'), + barSomethingAction + ); + }); + + it('returns action registered by plugin', async () => { + registry.registerPluginProvider('foo', { + name: 'foo', + async initializer() { + registry.registerAction('model', somethingAction); + return {}; + }, + }); + const somethingAction = action( + { + name: { + pluginId: 'foo', + actionId: 'something', + }, + }, + async () => null + ); + + assert.strictEqual( + await registry.lookupAction('/model/foo/something'), + somethingAction + ); + }); + + it('returns undefined for unknown action', async () => { + assert.strictEqual( + await registry.lookupAction('/model/foo/something'), + undefined + ); + }); + + it('should lookup parent registry when child missing action', async () => { + const childRegistry = new Registry(registry); + + const fooAction = action({ name: 'foo' }, async () => null); + registry.registerAction('model', fooAction); + + assert.strictEqual(await registry.lookupAction('/model/foo'), fooAction); + assert.strictEqual( + await childRegistry.lookupAction('/model/foo'), + fooAction + ); + }); + + it('registration on the child registry should not modify parent', async () => { + const childRegistry = Registry.withParent(registry); + + assert.strictEqual(childRegistry.parent, registry); + + const fooAction = action({ name: 'foo' }, async () => null); + childRegistry.registerAction('model', fooAction); + + assert.strictEqual(await registry.lookupAction('/model/foo'), undefined); + assert.strictEqual( + await childRegistry.lookupAction('/model/foo'), + fooAction + ); + }); + }); +});