Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions js/ai/src/embedder.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
* limitations under the License.
*/

import { action, Action } from '@genkit-ai/core';
import { lookupAction, registerAction } from '@genkit-ai/core/registry';
import { Action, defineAction } from '@genkit-ai/core';
import { lookupAction } from '@genkit-ai/core/registry';
import { setCustomMetadataAttributes } from '@genkit-ai/core/tracing';
import * as z from 'zod';
import { Document, DocumentData, DocumentDataSchema } from './document.js';
Expand Down Expand Up @@ -68,8 +68,9 @@ export function defineEmbedder<
},
runner: EmbedderFn<ConfigSchema>
) {
const embedder = action(
const embedder = defineAction(
{
actionType: 'embedder',
name: options.name,
inputSchema: options.configSchema
? EmbedRequestSchema.extend({
Expand All @@ -94,7 +95,6 @@ export function defineEmbedder<
embedder as Action<typeof EmbedRequestSchema, typeof EmbedResponseSchema>,
options.configSchema
);
registerAction('embedder', ewm.__action.name, ewm);
return ewm;
}

Expand Down
10 changes: 5 additions & 5 deletions js/ai/src/evaluator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,13 @@
* limitations under the License.
*/

import { action, Action } from '@genkit-ai/core';
import { Action, defineAction } from '@genkit-ai/core';
import { logger } from '@genkit-ai/core/logging';
import { lookupAction, registerAction } from '@genkit-ai/core/registry';
import { lookupAction } from '@genkit-ai/core/registry';
import {
SPAN_TYPE_ATTR,
runInNewSpan,
setCustomMetadataAttributes,
SPAN_TYPE_ATTR,
} from '@genkit-ai/core/tracing';
import * as z from 'zod';

Expand Down Expand Up @@ -128,8 +128,9 @@ export function defineEvaluator<
options.isBilled == undefined ? true : options.isBilled;
metadata[EVALUATOR_METADATA_KEY_DISPLAY_NAME] = options.displayName;
metadata[EVALUATOR_METADATA_KEY_DEFINITION] = options.definition;
const evaluator = action(
const evaluator = defineAction(
{
actionType: 'evaluator',
name: options.name,
inputSchema: EvalRequestSchema.extend({
dataset: options.dataPointType
Expand Down Expand Up @@ -205,7 +206,6 @@ export function defineEvaluator<
options.dataPointType,
options.configSchema
);
registerAction('evaluator', evaluator.__action.name, evaluator);
return ewm;
}

Expand Down
9 changes: 4 additions & 5 deletions js/ai/src/model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,10 @@

import {
Action,
action,
defineAction,
getStreamingCallback,
StreamingCallback,
} from '@genkit-ai/core';
import { registerAction } from '@genkit-ai/core/registry';
import { toJsonSchema } from '@genkit-ai/core/schema';
import { setCustomMetadataAttributes } from '@genkit-ai/core/tracing';
import { performance } from 'node:perf_hooks';
Expand Down Expand Up @@ -270,7 +269,7 @@ export function modelWithMiddleware(
}

/**
*
* Defines a new model and adds it to the registry.
*/
export function defineModel<
CustomOptionsSchema extends z.ZodTypeAny = z.ZodTypeAny,
Expand All @@ -293,8 +292,9 @@ export function defineModel<
) => Promise<GenerateResponseData>
): ModelAction<CustomOptionsSchema> {
const label = options.label || `${options.name} GenAI model`;
const act = action(
const act = defineAction(
{
actionType: 'model',
name: options.name,
description: label,
inputSchema: GenerateRequestSchema,
Expand Down Expand Up @@ -342,7 +342,6 @@ export function defineModel<
act as ModelAction,
middleware
) as ModelAction<CustomOptionsSchema>;
registerAction('model', ma.__action.name, ma);
return ma;
}

Expand Down
2 changes: 1 addition & 1 deletion js/ai/src/prompt.ts
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ export function definePrompt<I extends z.ZodTypeAny>(
return fn(i);
}
);
registerAction('prompt', name, a);
registerAction('prompt', a);
return a as PromptAction<I>;
}

Expand Down
12 changes: 6 additions & 6 deletions js/ai/src/retriever.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
* limitations under the License.
*/

import { action, Action } from '@genkit-ai/core';
import { lookupAction, registerAction } from '@genkit-ai/core/registry';
import { Action, defineAction } from '@genkit-ai/core';
import { lookupAction } from '@genkit-ai/core/registry';
import { setCustomMetadataAttributes } from '@genkit-ai/core/tracing';
import * as z from 'zod';
import { Document, DocumentData, DocumentDataSchema } from './document.js';
Expand Down Expand Up @@ -113,8 +113,9 @@ export function defineRetriever<
},
runner: RetrieverFn<OptionsType>
) {
const retriever = action(
const retriever = defineAction(
{
actionType: 'retriever',
name: options.name,
inputSchema: options.configSchema
? RetrieverRequestSchema.extend({
Expand All @@ -139,7 +140,6 @@ export function defineRetriever<
>,
options.configSchema
);
registerAction('retriever', rwm.__action.name, rwm);
return rwm;
}

Expand All @@ -154,8 +154,9 @@ export function defineIndexer<IndexerOptions extends z.ZodTypeAny>(
},
runner: IndexerFn<IndexerOptions>
) {
const indexer = action(
const indexer = defineAction(
{
actionType: 'indexer',
name: options.name,
inputSchema: options.configSchema
? IndexerRequestSchema.extend({
Expand All @@ -180,7 +181,6 @@ export function defineIndexer<IndexerOptions extends z.ZodTypeAny>(
indexer as Action<typeof IndexerRequestSchema, z.ZodVoid>,
options.configSchema
);
registerAction('indexer', iwm.__action.name, iwm);
return iwm;
}

Expand Down
8 changes: 4 additions & 4 deletions js/ai/src/tool.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
* limitations under the License.
*/

import { Action, JSONSchema7, action } from '@genkit-ai/core';
import { lookupAction, registerAction } from '@genkit-ai/core/registry';
import { Action, defineAction, JSONSchema7 } from '@genkit-ai/core';
import { lookupAction } from '@genkit-ai/core/registry';
import { toJsonSchema } from '@genkit-ai/core/schema';
import { setCustomMetadataAttributes } from '@genkit-ai/core/tracing';
import z from 'zod';
Expand Down Expand Up @@ -117,8 +117,9 @@ export function defineTool<I extends z.ZodTypeAny, O extends z.ZodTypeAny>(
},
fn: (input: z.infer<I>) => Promise<z.infer<O>>
): ToolAction<I, O> {
const a = action(
const a = defineAction(
{
actionType: 'tool',
name,
description,
inputSchema,
Expand All @@ -132,6 +133,5 @@ export function defineTool<I extends z.ZodTypeAny, O extends z.ZodTypeAny>(
return fn(i);
}
);
registerAction('tool', name, a);
return a as ToolAction<I, O>;
}
90 changes: 75 additions & 15 deletions js/core/src/action.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,10 @@ import { JSONSchema7 } from 'json-schema';
import { AsyncLocalStorage } from 'node:async_hooks';
import { performance } from 'node:perf_hooks';
import * as z from 'zod';
import { ActionType, lookupPlugin, registerAction } from './registry.js';
import { parseSchema } from './schema.js';
import * as telemetry from './telemetry.js';
import { runInNewSpan, SPAN_TYPE_ATTR } from './tracing.js';
import { SPAN_TYPE_ATTR, runInNewSpan } from './tracing.js';

export { Status, StatusCodes, StatusSchema } from './statusTypes.js';
export { JSONSchema7 };
Expand All @@ -30,6 +31,7 @@ export interface ActionMetadata<
O extends z.ZodTypeAny,
M extends Record<string, any> = Record<string, any>,
> {
actionType?: ActionType;
name: string;
description?: string;
inputSchema?: I;
Expand All @@ -49,25 +51,40 @@ export type Action<

export type SideChannelData = Record<string, any>;

type ActionParams<
I extends z.ZodTypeAny,
O extends z.ZodTypeAny,
M extends Record<string, any> = Record<string, any>,
> = {
name:
| string
| {
pluginId: string;
actionId: string;
};
description?: string;
inputSchema?: I;
inputJsonSchema?: JSONSchema7;
outputSchema?: O;
outputJsonSchema?: JSONSchema7;
metadata?: M;
};

/**
*
* Creates an action with the provided config.
*/
export function action<
I extends z.ZodTypeAny,
O extends z.ZodTypeAny,
M extends Record<string, any> = Record<string, any>,
>(
config: {
name: string;
description?: string;
inputSchema?: I;
inputJsonSchema?: JSONSchema7;
outputSchema?: O;
outputJsonSchema?: JSONSchema7;
metadata?: M;
},
config: ActionParams<I, O, M>,
fn: (input: z.infer<I>) => Promise<z.infer<O>>
): Action<I, O> {
const actionName =
typeof config.name === 'string'
? validateActionName(config.name)
: `${validatePluginName(config.name.pluginId)}/${validateActionId(config.name.actionId)}`;
const actionFn = async (input: I) => {
input = parseSchema(input, {
schema: config.inputSchema,
Expand All @@ -76,14 +93,14 @@ export function action<
let output = await runInNewSpan(
{
metadata: {
name: config.name,
name: actionName,
},
labels: {
[SPAN_TYPE_ATTR]: 'action',
},
},
async (metadata) => {
metadata.name = config.name;
metadata.name = actionName;
metadata.input = input;
const startTimeMs = performance.now();
try {
Expand Down Expand Up @@ -111,17 +128,60 @@ export function action<
return output;
};
actionFn.__action = {
name: config.name,
name: actionName,
description: config.description,
inputSchema: config.inputSchema,
inputJsonSchema: config.inputJsonSchema,
outputSchema: config.outputSchema,
outputJsonSchema: config.outputJsonSchema,
metadata: config.metadata,
};
} as ActionMetadata<I, O, M>;
return actionFn;
}

function validateActionName(name: string) {
if (name.includes('/')) {
validatePluginName(name.split('/', 1)[0]);
validateActionId(name.substring(name.indexOf('/') + 1));
}
return name;
}

function validatePluginName(pluginId: string) {
if (!lookupPlugin(pluginId)) {
throw new Error(
`Unable to find plugin name used in the action name: ${pluginId}`
);
}
return pluginId;
}

function validateActionId(actionId: string) {
if (actionId.includes('/')) {
throw new Error(`Action name must not include slashes (/): ${actionId}`);
}
return actionId;
}

/**
* Defines an action with the given config and registers it in the registry.
*/
export function defineAction<
I extends z.ZodTypeAny,
O extends z.ZodTypeAny,
M extends Record<string, any> = Record<string, any>,
>(
config: ActionParams<I, O, M> & {
actionType: ActionType;
},
fn: (input: z.infer<I>) => Promise<z.infer<O>>
): Action<I, O> {
const act = action(config, fn);
act.__action.actionType = config.actionType;
registerAction(config.actionType, act);
return act;
}

// Streaming callback function.
export type StreamingCallback<T> = (chunk: T) => void;

Expand Down
2 changes: 1 addition & 1 deletion js/core/src/plugin.ts
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ type PluginInit = (
export type Plugin<T extends any[]> = (...args: T) => PluginProvider;

/**
*
* Defines a Genkit plugin.
*/
export function genkitPlugin<T extends PluginInit>(
pluginName: string,
Expand Down
7 changes: 5 additions & 2 deletions js/core/src/registry.ts
Original file line number Diff line number Diff line change
Expand Up @@ -98,11 +98,10 @@ function parsePluginName(registryKey: string) {
*/
export function registerAction<I extends z.ZodTypeAny, O extends z.ZodTypeAny>(
type: ActionType,
id: string,
action: Action<I, O>
) {
logger.info(`Registering ${type}: ${action.__action.name}`);
const key = `/${type}/${id}`;
const key = `/${type}/${action.__action.name}`;
if (actionsById().hasOwnProperty(key)) {
logger.warn(
`WARNING: ${key} already has an entry in the registry. Overwriting.`
Expand Down Expand Up @@ -199,6 +198,10 @@ export function registerPluginProvider(name: string, provider: PluginProvider) {
};
}

export function lookupPlugin(name: string) {
return pluginsByName()[name];
}

/**
*
*/
Expand Down
Loading