diff --git a/docs/errors/no_new_actions_at_runtime.md b/docs/errors/no_new_actions_at_runtime.md new file mode 100644 index 0000000000..cfed39eac7 --- /dev/null +++ b/docs/errors/no_new_actions_at_runtime.md @@ -0,0 +1,22 @@ +# No new actions at runtime error + +Defining new actions at runtime is not allowed. + +✅ DO: + +```ts +const prompt = defineDotprompt({...}) + +const flow = defineFlow({...}, async (input) => { + await prompt.generate(...); +}) +``` + +❌ DON'T: + +```ts +const flow = defineFlow({...}, async (input) => { + const prompt = defineDotprompt({...}) + prompt.generate(...); +}) +``` diff --git a/js/core/src/action.ts b/js/core/src/action.ts index a78c8a724f..182995b8fd 100644 --- a/js/core/src/action.ts +++ b/js/core/src/action.ts @@ -18,7 +18,12 @@ import { JSONSchema7 } from 'json-schema'; import { AsyncLocalStorage } from 'node:async_hooks'; import { performance } from 'node:perf_hooks'; import * as z from 'zod'; -import { ActionType, lookupPlugin, registerAction } from './registry.js'; +import { + ActionType, + initializeAllPlugins, + lookupPlugin, + registerAction, +} from './registry.js'; import { parseSchema } from './schema.js'; import * as telemetry from './telemetry.js'; import { @@ -216,9 +221,16 @@ export function defineAction< }, fn: (input: z.infer) => Promise> ): Action { - const act = action(config, (i: I): Promise> => { + if (isInRuntimeContext()) { + throw new Error( + 'Cannot define new actions at runtime.\n' + + 'See: https://github.com/firebase/genkit/blob/main/docs/errors/no_new_actions_at_runtime.md' + ); + } + const act = action(config, async (i: I): Promise> => { setCustomMetadataAttributes({ subtype: config.actionType }); - return fn(i); + await initializeAllPlugins(); + return await runInActionRuntimeContext(() => fn(i)); }); act.__action.actionType = config.actionType; registerAction(config.actionType, act); @@ -252,3 +264,19 @@ export function getStreamingCallback(): StreamingCallback | undefined { } return cb; } + +const runtimeCtxAls = new AsyncLocalStorage(); + +/** + * Checks whether the caller is currently in the runtime context of an action. + */ +export function isInRuntimeContext() { + return !!runtimeCtxAls.getStore(); +} + +/** + * Execute the provided function in the action runtime context. + */ +export function runInActionRuntimeContext(fn: () => R) { + return runtimeCtxAls.run('runtime', fn); +} diff --git a/js/core/src/plugin.ts b/js/core/src/plugin.ts index a21a3a5f3d..97955ce002 100644 --- a/js/core/src/plugin.ts +++ b/js/core/src/plugin.ts @@ -15,7 +15,7 @@ */ import { z } from 'zod'; -import { Action } from './action.js'; +import { Action, isInRuntimeContext } from './action.js'; import { FlowStateStore } from './flowTypes.js'; import { LoggerConfig, TelemetryConfig } from './telemetryTypes.js'; import { TraceStore } from './tracing.js'; @@ -60,6 +60,12 @@ export function genkitPlugin( pluginName: string, initFn: T ): Plugin> { + if (isInRuntimeContext()) { + throw new Error( + 'Cannot define new plugins at runtime.\n' + + 'See: https://github.com/firebase/genkit/blob/main/docs/errors/no_new_actions_at_runtime.md' + ); + } return (...args: Parameters) => ({ name: pluginName, initializer: async () => { diff --git a/js/core/src/registry.ts b/js/core/src/registry.ts index 94647c92be..b46b636baf 100644 --- a/js/core/src/registry.ts +++ b/js/core/src/registry.ts @@ -126,10 +126,19 @@ type ActionsRecord = Record>; * Returns all actions in the registry. */ export async function listActions(): Promise { + await initializeAllPlugins(); + return Object.assign({}, actionsById()); +} + +let allPluginsInitialized = false; +export async function initializeAllPlugins() { + if (allPluginsInitialized) { + return; + } for (const pluginName of Object.keys(pluginsByName())) { await initializePlugin(pluginName); } - return Object.assign({}, actionsById()); + allPluginsInitialized = true; } /** @@ -195,14 +204,17 @@ export async function lookupFlowStateStore( * Registers a flow state store for the given environment. */ export function registerPluginProvider(name: string, provider: PluginProvider) { + allPluginsInitialized = false; let cached; + let isInitialized = false; pluginsByName()[name] = { name: provider.name, initializer: () => { - if (cached) { + if (isInitialized) { return cached; } cached = provider.initializer(); + isInitialized = true; return cached; }, }; diff --git a/js/flow/src/flow.ts b/js/flow/src/flow.ts index 8d318a1c33..1eb51fb7b3 100644 --- a/js/flow/src/flow.ts +++ b/js/flow/src/flow.ts @@ -28,6 +28,7 @@ import { StreamingCallback, } from '@genkit-ai/core'; import { logger } from '@genkit-ai/core/logging'; +import { initializeAllPlugins } from '@genkit-ai/core/registry'; import { toJsonSchema } from '@genkit-ai/core/schema'; import { newTrace, @@ -391,6 +392,7 @@ export class Flow< labels: Record | undefined ) { const startTimeMs = performance.now(); + await initializeAllPlugins(); await runWithActiveContext(ctx, async () => { let traceContext; if (ctx.state.traceContext) { diff --git a/js/flow/src/utils.ts b/js/flow/src/utils.ts index 774d81c4cd..4110195175 100644 --- a/js/flow/src/utils.ts +++ b/js/flow/src/utils.ts @@ -14,6 +14,7 @@ * limitations under the License. */ +import { runInActionRuntimeContext } from '@genkit-ai/core'; import { AsyncLocalStorage } from 'node:async_hooks'; import { v4 as uuidv4 } from 'uuid'; import z from 'zod'; @@ -45,7 +46,7 @@ export function runWithActiveContext( ctx: Context, fn: () => R ) { - return ctxAsyncLocalStorage.run(ctx, fn); + return ctxAsyncLocalStorage.run(ctx, () => runInActionRuntimeContext(fn)); } /**