Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions js/ai/src/generate.ts
Original file line number Diff line number Diff line change
Expand Up @@ -588,6 +588,9 @@ async function resolveFullToolNames(
if (await registry.lookupAction(`/tool/${name}`)) {
return [`/tool/${name}`];
}
if (await registry.lookupAction(`/tool.v2/${name}`)) {
return [`/tool.v2/${name}`];
}
if (await registry.lookupAction(`/prompt/${name}`)) {
return [`/prompt/${name}`];
}
Expand Down
38 changes: 29 additions & 9 deletions js/ai/src/generate/resolve-tool-requests.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
* limitations under the License.
*/

import { GenkitError, stripUndefinedProps } from '@genkit-ai/core';
import { GenkitError, stripUndefinedProps, z } from '@genkit-ai/core';
import { logger } from '@genkit-ai/core/logging';
import type { Registry } from '@genkit-ai/core/registry';
import type {
Expand All @@ -25,8 +25,10 @@ import type {
ToolRequestPart,
ToolResponsePart,
} from '../model.js';
import { ToolResponse } from '../parts.js';
import { isPromptAction } from '../prompt.js';
import {
MultipartToolResponseSchema,
ToolInterruptError,
isToolRequest,
resolveTools,
Expand Down Expand Up @@ -120,15 +122,33 @@ export async function resolveToolRequest(
// otherwise, execute the tool and catch interrupts
try {
const output = await tool(part.toolRequest.input, toRunOptions(part));
const response = stripUndefinedProps({
toolResponse: {
name: part.toolRequest.name,
ref: part.toolRequest.ref,
output,
},
});
if (tool.__action.actionType === 'tool.v2') {
const multipartResponse = output as z.infer<
typeof MultipartToolResponseSchema
>;
const strategy = multipartResponse.fallbackOutput ? 'fallback' : 'both';
const response = stripUndefinedProps({
toolResponse: {
name: part.toolRequest.name,
ref: part.toolRequest.ref,
output: multipartResponse.output || multipartResponse.fallbackOutput,
content: multipartResponse.content,
payloadStrategy: strategy,
} as ToolResponse,
});

return { response };
return { response };
} else {
const response = stripUndefinedProps({
toolResponse: {
name: part.toolRequest.name,
ref: part.toolRequest.ref,
output,
},
});

return { response };
}
} catch (e) {
if (
e instanceof ToolInterruptError ||
Expand Down
179 changes: 140 additions & 39 deletions js/ai/src/tool.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@

import {
action,
ActionFnArg,
assertUnstable,
defineAction,
isAction,
stripUndefinedProps,
z,
Expand All @@ -29,11 +29,12 @@ import {
import type { Registry } from '@genkit-ai/core/registry';
import { parseSchema, toJsonSchema } from '@genkit-ai/core/schema';
import { setCustomMetadataAttributes } from '@genkit-ai/core/tracing';
import type {
Part,
ToolDefinition,
ToolRequestPart,
ToolResponsePart,
import {
PartSchema,
type Part,
type ToolDefinition,
type ToolRequestPart,
type ToolResponsePart,
} from './model.js';
import { isExecutablePrompt, type ExecutablePrompt } from './prompt.js';

Expand Down Expand Up @@ -100,6 +101,26 @@ export type ToolAction<
};
};

/**
* An action with a `tool.v2` type.
*/
export type MultipartToolAction<
I extends z.ZodTypeAny = z.ZodTypeAny,
O extends z.ZodTypeAny = z.ZodTypeAny,
> = Action<
I,
typeof MultipartToolResponseSchema,
z.ZodTypeAny,
ToolRunOptions
> &
Resumable<I, O> & {
__action: {
metadata: {
type: 'tool.v2';
};
};
};

/**
* A dynamic action with a `tool` type. Dynamic tools are detached actions -- not associated with any registry.
*/
Expand Down Expand Up @@ -218,6 +239,7 @@ export async function lookupToolByName(
const tool =
(await registry.lookupAction(name)) ||
(await registry.lookupAction(`/tool/${name}`)) ||
(await registry.lookupAction(`/tool.v2/${name}`)) ||
(await registry.lookupAction(`/prompt/${name}`)) ||
(await registry.lookupAction(`/dynamic-action-provider/${name}`));
if (!tool) {
Expand Down Expand Up @@ -258,7 +280,7 @@ export function toToolDefinition(
return out;
}

export interface ToolFnOptions {
export interface ToolFnOptions extends ActionFnArg<never> {
/**
* A function that can be called during tool execution that will result in the tool
* getting interrupted (immediately) and tool request returned to the upstream caller.
Expand All @@ -273,32 +295,42 @@ export type ToolFn<I extends z.ZodTypeAny, O extends z.ZodTypeAny> = (
ctx: ToolFnOptions & ToolRunOptions
) => Promise<z.infer<O>>;

export type MultipartToolFn<I extends z.ZodTypeAny, O extends z.ZodTypeAny> = (
input: z.infer<I>,
ctx: ToolFnOptions & ToolRunOptions
) => Promise<{
output?: z.infer<O>;
fallbackOutput?: z.infer<O>;
content?: Part[];
}>;

export function defineTool<I extends z.ZodTypeAny, O extends z.ZodTypeAny>(
registry: Registry,
config: { multipart: true } & ToolConfig<I, O>,
fn?: ToolFn<I, O>
): MultipartToolAction<I, O>;
export function defineTool<I extends z.ZodTypeAny, O extends z.ZodTypeAny>(
registry: Registry,
config: ToolConfig<I, O>,
fn?: ToolFn<I, O>
): ToolAction<I, O>;

/**
* Defines a tool.
*
* A tool is an action that can be passed to a model to be called automatically if it so chooses.
*/
export function defineTool<I extends z.ZodTypeAny, O extends z.ZodTypeAny>(
registry: Registry,
config: ToolConfig<I, O>,
fn: ToolFn<I, O>
): ToolAction<I, O> {
const a = defineAction(
registry,
{
...config,
actionType: 'tool',
metadata: { ...(config.metadata || {}), type: 'tool' },
},
(i, runOptions) => {
return fn(i, {
...runOptions,
context: { ...runOptions.context },
interrupt: interruptTool(registry),
});
}
);
implementTool(a as ToolAction<I, O>, config, registry);
config: { multipart?: true } & ToolConfig<I, O>,
fn?: ToolFn<I, O> | MultipartToolFn<I, O>
): ToolAction<I, O> | MultipartToolAction<I, O> {
const a = tool(config, fn);
registry.registerAction(config.multipart ? 'tool.v2' : 'tool', a);
if (!config.multipart) {
// For non-multipart tools, we register a v2 tool action as well
registry.registerAction('tool.v2', basicToolV2(config, fn as ToolFn<I, O>));
}
return a as ToolAction<I, O>;
}

Expand Down Expand Up @@ -432,27 +464,30 @@ function interruptTool(registry?: Registry) {
};
}

/**
* Defines a dynamic tool. Dynamic tools are just like regular tools but will not be registered in the
* Genkit registry and can be defined dynamically at runtime.
*/
export function tool<I extends z.ZodTypeAny, O extends z.ZodTypeAny>(
config: { multipart: true } & ToolConfig<I, O>,
fn?: ToolFn<I, O>
): MultipartToolAction<I, O>;
export function tool<I extends z.ZodTypeAny, O extends z.ZodTypeAny>(
config: ToolConfig<I, O>,
fn?: ToolFn<I, O>
): ToolAction<I, O> {
return dynamicTool(config, fn);
}
): ToolAction<I, O>;

/**
* Defines a dynamic tool. Dynamic tools are just like regular tools but will not be registered in the
* Genkit registry and can be defined dynamically at runtime.
*
* @deprecated renamed to {@link tool}.
*/
export function dynamicTool<I extends z.ZodTypeAny, O extends z.ZodTypeAny>(
export function tool<I extends z.ZodTypeAny, O extends z.ZodTypeAny>(
config: { multipart?: true } & ToolConfig<I, O>,
fn?: ToolFn<I, O> | MultipartToolFn<I, O>
): ToolAction<I, O> | MultipartToolAction<I, O> {
return config.multipart ? multipartTool(config, fn) : basicTool(config, fn);
}

function basicTool<I extends z.ZodTypeAny, O extends z.ZodTypeAny>(
config: ToolConfig<I, O>,
fn?: ToolFn<I, O>
): DynamicToolAction<I, O> {
): ToolAction<I, O> {
const a = action(
{
...config,
Expand All @@ -470,8 +505,74 @@ export function dynamicTool<I extends z.ZodTypeAny, O extends z.ZodTypeAny>(
}
return interrupt();
}
) as DynamicToolAction<I, O>;
) as ToolAction<I, O>;
implementTool(a, config);
return a;
}

function basicToolV2<I extends z.ZodTypeAny, O extends z.ZodTypeAny>(
config: ToolConfig<I, O>,
fn?: ToolFn<I, O>
): MultipartToolAction<I, O> {
return multipartTool(config, async (input, ctx) => {
if (!fn) {
const interrupt = interruptTool(ctx.registry);
return interrupt();
}
return {
output: await fn(input, ctx),
};
});
}

export const MultipartToolResponseSchema = z.object({
output: z.any().optional(),
fallbackOutput: z.any().optional(),
content: z.array(PartSchema).optional(),
});

function multipartTool<I extends z.ZodTypeAny, O extends z.ZodTypeAny>(
config: ToolConfig<I, O>,
fn?: MultipartToolFn<I, O>
): MultipartToolAction<I, O> {
const a = action(
{
...config,
outputSchema: MultipartToolResponseSchema,
actionType: 'tool.v2',
metadata: {
...(config.metadata || {}),
type: 'tool.v2',
tool: { multipart: true },
},
},
(i, runOptions) => {
const interrupt = interruptTool(runOptions.registry);
if (fn) {
return fn(i, {
...runOptions,
context: { ...runOptions.context },
interrupt,
});
}
return interrupt();
}
) as MultipartToolAction<I, O>;
implementTool(a as any, config);
a.attach = (_: Registry) => a;
return a;
}

/**
* Defines a dynamic tool. Dynamic tools are just like regular tools but will not be registered in the
* Genkit registry and can be defined dynamically at runtime.
*
* @deprecated renamed to {@link tool}.
*/
export function dynamicTool<I extends z.ZodTypeAny, O extends z.ZodTypeAny>(
config: ToolConfig<I, O>,
fn?: ToolFn<I, O>
): DynamicToolAction<I, O> {
const t = basicTool(config, fn) as DynamicToolAction<I, O>;
t.attach = (_: Registry) => t;
return t;
}
Loading