Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions js/ai/src/generate/action.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,13 @@

import {
getStreamingCallback,
Middleware,
runWithStreamingCallback,
z,
} from '@genkit-ai/core';
import { logger } from '@genkit-ai/core/logging';
import { Registry } from '@genkit-ai/core/registry';
import { toJsonSchema } from '@genkit-ai/core/schema';
import { runInNewSpan, SPAN_TYPE_ATTR } from '@genkit-ai/core/tracing';
import { SPAN_TYPE_ATTR, runInNewSpan } from '@genkit-ai/core/tracing';
import * as clc from 'colorette';
import { DocumentDataSchema } from '../document.js';
import { resolveFormat } from '../formats/index.js';
Expand All @@ -40,13 +39,14 @@ import {
GenerateResponseData,
MessageData,
MessageSchema,
ModelMiddleware,
Part,
resolveModel,
Role,
ToolDefinitionSchema,
ToolResponsePart,
resolveModel,
} from '../model.js';
import { resolveTools, ToolAction, toToolDefinition } from '../tool.js';
import { ToolAction, resolveTools, toToolDefinition } from '../tool.js';

export const GenerateUtilParamSchema = z.object({
/** A model name (e.g. `vertexai/gemini-1.0-pro`). */
Expand Down Expand Up @@ -78,7 +78,7 @@ export const GenerateUtilParamSchema = z.object({
export async function generateHelper(
registry: Registry,
input: z.infer<typeof GenerateUtilParamSchema>,
middleware?: Middleware[]
middleware?: ModelMiddleware[]
): Promise<GenerateResponseData> {
// do tracing
return await runInNewSpan(
Expand All @@ -103,7 +103,7 @@ export async function generateHelper(
async function generate(
registry: Registry,
rawRequest: z.infer<typeof GenerateUtilParamSchema>,
middleware?: Middleware[]
middleware?: ModelMiddleware[]
): Promise<GenerateResponseData> {
const { modelAction: model } = await resolveModel(registry, rawRequest.model);
if (model.__action.metadata?.model.stage === 'deprecated') {
Expand Down
6 changes: 3 additions & 3 deletions js/ai/src/model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ import {
defineAction,
GenkitError,
getStreamingCallback,
Middleware,
SimpleMiddleware,
StreamingCallback,
z,
} from '@genkit-ai/core';
Expand Down Expand Up @@ -299,12 +299,12 @@ export type ModelAction<
> = Action<
typeof GenerateRequestSchema,
typeof GenerateResponseSchema,
{ model: ModelInfo }
typeof GenerateResponseChunkSchema
> & {
__configSchema: CustomOptionsSchema;
};

export type ModelMiddleware = Middleware<
export type ModelMiddleware = SimpleMiddleware<
z.infer<typeof GenerateRequestSchema>,
z.infer<typeof GenerateResponseSchema>
>;
Expand Down
4 changes: 3 additions & 1 deletion js/ai/src/prompt.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import {
import {
GenerateRequest,
GenerateRequestSchema,
GenerateResponseChunkSchema,
ModelArgument,
} from './model.js';
import { ToolAction } from './tool.js';
Expand All @@ -36,7 +37,8 @@ export type PromptFn<

export type PromptAction<I extends z.ZodTypeAny = z.ZodTypeAny> = Action<
I,
typeof GenerateRequestSchema
typeof GenerateRequestSchema,
typeof GenerateResponseChunkSchema
> & {
__action: {
metadata: {
Expand Down
6 changes: 1 addition & 5 deletions js/ai/src/reranker.ts
Original file line number Diff line number Diff line change
Expand Up @@ -78,11 +78,7 @@ export const RerankerInfoSchema = z.object({
export type RerankerInfo = z.infer<typeof RerankerInfoSchema>;

export type RerankerAction<CustomOptions extends z.ZodTypeAny = z.ZodTypeAny> =
Action<
typeof RerankerRequestSchema,
typeof RerankerResponseSchema,
{ model: RerankerInfo }
> & {
Action<typeof RerankerRequestSchema, typeof RerankerResponseSchema> & {
__configSchema?: CustomOptions;
};

Expand Down
6 changes: 1 addition & 5 deletions js/ai/src/retriever.ts
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,7 @@ export const RetrieverInfoSchema = z.object({
export type RetrieverInfo = z.infer<typeof RetrieverInfoSchema>;

export type RetrieverAction<CustomOptions extends z.ZodTypeAny = z.ZodTypeAny> =
Action<
typeof RetrieverRequestSchema,
typeof RetrieverResponseSchema,
{ model: RetrieverInfo }
> & {
Action<typeof RetrieverRequestSchema, typeof RetrieverResponseSchema> & {
__configSchema?: CustomOptions;
};

Expand Down
154 changes: 116 additions & 38 deletions js/core/src/action.ts
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ export { JSONSchema7 };
export interface ActionMetadata<
I extends z.ZodTypeAny,
O extends z.ZodTypeAny,
M extends Record<string, any> = Record<string, any>,
S extends z.ZodTypeAny,
> {
actionType?: ActionType;
name: string;
Expand All @@ -43,7 +43,8 @@ export interface ActionMetadata<
inputJsonSchema?: JSONSchema7;
outputSchema?: O;
outputJsonSchema?: JSONSchema7;
metadata?: M;
streamSchema?: S;
metadata?: Record<string, any>;
}

/**
Expand All @@ -57,16 +58,52 @@ export interface ActionResult<O> {
};
}

/**
* Options (side channel) data to pass to the model.
*/
export interface ActionRunOptions<S> {
/**
* Streaming callback (optional).
*/
onChunk?: StreamingCallback<S>;

/**
* Additional runtime context data (ex. auth context data).
*/
context?: any;
}

/**
* Options (side channel) data to pass to the model.
*/
export interface ActionFnArg<S> {
/**
* Streaming callback (optional).
*/
sendChunk: StreamingCallback<S>;

/**
* Additional runtime context data (ex. auth context data).
*/
context?: any;
}

/**
* Self-describing, validating, observable, locally and remotely callable function.
*/
export type Action<
I extends z.ZodTypeAny = z.ZodTypeAny,
O extends z.ZodTypeAny = z.ZodTypeAny,
M extends Record<string, any> = Record<string, any>,
> = ((input: z.infer<I>) => Promise<z.infer<O>>) & {
__action: ActionMetadata<I, O, M>;
run(input: z.infer<I>): Promise<ActionResult<z.infer<O>>>;
S extends z.ZodTypeAny = z.ZodTypeAny,
> = ((
input: z.infer<I>,
options?: ActionRunOptions<S>
) => Promise<z.infer<O>>) & {
__action: ActionMetadata<I, O, S>;
run(
input: z.infer<I>,
options?: ActionRunOptions<z.infer<S>>
): Promise<ActionResult<z.infer<O>>>;
};

/**
Expand All @@ -75,7 +112,7 @@ export type Action<
type ActionParams<
I extends z.ZodTypeAny,
O extends z.ZodTypeAny,
M extends Record<string, any> = Record<string, any>,
S extends z.ZodTypeAny = z.ZodTypeAny,
> = {
name:
| string
Expand All @@ -88,49 +125,80 @@ type ActionParams<
inputJsonSchema?: JSONSchema7;
outputSchema?: O;
outputJsonSchema?: JSONSchema7;
metadata?: M;
use?: Middleware<z.infer<I>, z.infer<O>>[];
metadata?: Record<string, any>;
use?: Middleware<z.infer<I>, z.infer<O>, z.infer<S>>[];
streamingSchema?: S;
};

export type SimpleMiddleware<I = any, O = any> = (
req: I,
next: (req?: I) => Promise<O>
) => Promise<O>;

export type MiddlewareWithOptions<I = any, O = any, S = any> = (
req: I,
options: ActionRunOptions<S> | undefined,
next: (req?: I, options?: ActionRunOptions<S>) => Promise<O>
) => Promise<O>;

/**
* Middleware function for actions.
*/
export interface Middleware<I = any, O = any> {
(req: I, next: (req?: I) => Promise<O>): Promise<O>;
}
export type Middleware<I = any, O = any, S = any> =
| SimpleMiddleware<I, O>
| MiddlewareWithOptions<I, O, S>;

/**
* Creates an action with provided middleware.
*/
export function actionWithMiddleware<
I extends z.ZodTypeAny,
O extends z.ZodTypeAny,
M extends Record<string, any> = Record<string, any>,
S extends z.ZodTypeAny = z.ZodTypeAny,
>(
action: Action<I, O, M>,
middleware: Middleware<z.infer<I>, z.infer<O>>[]
): Action<I, O, M> {
action: Action<I, O, S>,
middleware: Middleware<z.infer<I>, z.infer<O>, z.infer<S>>[]
): Action<I, O, S> {
const wrapped = (async (req: z.infer<I>) => {
return (await wrapped.run(req)).result;
}) as Action<I, O, M>;
}) as Action<I, O, S>;
wrapped.__action = action.__action;
wrapped.run = async (req: z.infer<I>): Promise<ActionResult<z.infer<O>>> => {
wrapped.run = async (
req: z.infer<I>,
options?: ActionRunOptions<z.infer<S>>
): Promise<ActionResult<z.infer<O>>> => {
let telemetry;
const dispatch = async (index: number, req: z.infer<I>) => {
const dispatch = async (
index: number,
req: z.infer<I>,
opts?: ActionRunOptions<z.infer<S>>
) => {
if (index === middleware.length) {
// end of the chain, call the original model action
const result = await action.run(req);
const result = await action.run(req, opts);
telemetry = result.telemetry;
return result.result;
}

const currentMiddleware = middleware[index];
return currentMiddleware(req, async (modifiedReq) =>
dispatch(index + 1, modifiedReq || req)
);
if (currentMiddleware.length === 3) {
return (currentMiddleware as MiddlewareWithOptions<I, O, z.infer<S>>)(
req,
opts,
async (modifiedReq, modifiedOptions) =>
dispatch(index + 1, modifiedReq || req, modifiedOptions || opts)
);
} else if (currentMiddleware.length === 2) {
return (currentMiddleware as SimpleMiddleware<I, O>)(
req,
async (modifiedReq) => dispatch(index + 1, modifiedReq || req, opts)
);
} else {
throw new Error('unspported middleware function shape');
}
};

return { result: await dispatch(0, req), telemetry };
return { result: await dispatch(0, req, options), telemetry };
};
return wrapped;
}
Expand All @@ -141,17 +209,20 @@ export function actionWithMiddleware<
export function action<
I extends z.ZodTypeAny,
O extends z.ZodTypeAny,
M extends Record<string, any> = Record<string, any>,
S extends z.ZodTypeAny = z.ZodTypeAny,
>(
config: ActionParams<I, O, M>,
fn: (input: z.infer<I>) => Promise<z.infer<O>>
): Action<I, O> {
config: ActionParams<I, O, S>,
fn: (
input: z.infer<I>,
options: ActionFnArg<z.infer<S>>
) => Promise<z.infer<O>>
): Action<I, O, z.infer<S>> {
const actionName =
typeof config.name === 'string'
? config.name
: `${config.name.pluginId}/${config.name.actionId}`;
const actionFn = async (input: I) => {
return (await actionFn.run(input)).result;
const actionFn = async (input: I, options?: ActionRunOptions<z.infer<S>>) => {
return (await actionFn.run(input, options)).result;
};
actionFn.__action = {
name: actionName,
Expand All @@ -161,9 +232,10 @@ export function action<
outputSchema: config.outputSchema,
outputJsonSchema: config.outputJsonSchema,
metadata: config.metadata,
} as ActionMetadata<I, O, M>;
} as ActionMetadata<I, O, S>;
actionFn.run = async (
input: z.infer<I>
input: z.infer<I>,
options?: ActionRunOptions<z.infer<S>>
): Promise<ActionResult<z.infer<O>>> => {
input = parseSchema(input, {
schema: config.inputSchema,
Expand All @@ -184,7 +256,10 @@ export function action<
metadata.name = actionName;
metadata.input = input;

const output = await fn(input);
const output = await fn(input, {
context: options?.context,
sendChunk: options?.onChunk ?? ((c) => {}),
});

metadata.output = JSON.stringify(output);
return output;
Expand Down Expand Up @@ -239,13 +314,16 @@ function validateActionId(actionId: string) {
export function defineAction<
I extends z.ZodTypeAny,
O extends z.ZodTypeAny,
M extends Record<string, any> = Record<string, any>,
S extends z.ZodTypeAny = z.ZodTypeAny,
>(
registry: Registry,
config: ActionParams<I, O, M> & {
config: ActionParams<I, O, S> & {
actionType: ActionType;
},
fn: (input: z.infer<I>) => Promise<z.infer<O>>
fn: (
input: z.infer<I>,
options: ActionFnArg<z.infer<S>>
) => Promise<z.infer<O>>
): Action<I, O> {
if (isInRuntimeContext()) {
throw new Error(
Expand All @@ -258,10 +336,10 @@ export function defineAction<
} else {
validateActionId(config.name.actionId);
}
const act = action(config, async (i: I): Promise<z.infer<O>> => {
const act = action(config, async (i: I, options): Promise<z.infer<O>> => {
setCustomMetadataAttributes({ subtype: config.actionType });
await registry.initializeAllPlugins();
return await runInActionRuntimeContext(() => fn(i));
return await runInActionRuntimeContext(() => fn(i, options));
});
act.__action.actionType = config.actionType;
registry.registerAction(config.actionType, act);
Expand Down
Loading
Loading