diff --git a/js/ai/src/generate.ts b/js/ai/src/generate.ts index 3d55762992..a2abfe8a91 100755 --- a/js/ai/src/generate.ts +++ b/js/ai/src/generate.ts @@ -588,6 +588,9 @@ async function resolveFullToolNames( if (await registry.lookupAction(`/tool/${name}`)) { return [`/tool/${name}`]; } + if (await registry.lookupAction(`/tool.v2/${name}`)) { + return [`/tool.v2/${name}`]; + } if (await registry.lookupAction(`/prompt/${name}`)) { return [`/prompt/${name}`]; } diff --git a/js/ai/src/generate/resolve-tool-requests.ts b/js/ai/src/generate/resolve-tool-requests.ts index 7faf3a8c21..0fe23a7b1f 100644 --- a/js/ai/src/generate/resolve-tool-requests.ts +++ b/js/ai/src/generate/resolve-tool-requests.ts @@ -14,7 +14,7 @@ * limitations under the License. */ -import { GenkitError, stripUndefinedProps } from '@genkit-ai/core'; +import { GenkitError, stripUndefinedProps, z } from '@genkit-ai/core'; import { logger } from '@genkit-ai/core/logging'; import type { Registry } from '@genkit-ai/core/registry'; import type { @@ -25,8 +25,10 @@ import type { ToolRequestPart, ToolResponsePart, } from '../model.js'; +import { ToolResponse } from '../parts.js'; import { isPromptAction } from '../prompt.js'; import { + MultipartToolResponseSchema, ToolInterruptError, isToolRequest, resolveTools, @@ -120,15 +122,33 @@ export async function resolveToolRequest( // otherwise, execute the tool and catch interrupts try { const output = await tool(part.toolRequest.input, toRunOptions(part)); - const response = stripUndefinedProps({ - toolResponse: { - name: part.toolRequest.name, - ref: part.toolRequest.ref, - output, - }, - }); + if (tool.__action.actionType === 'tool.v2') { + const multipartResponse = output as z.infer< + typeof MultipartToolResponseSchema + >; + const strategy = multipartResponse.fallbackOutput ? 'fallback' : 'both'; + const response = stripUndefinedProps({ + toolResponse: { + name: part.toolRequest.name, + ref: part.toolRequest.ref, + output: multipartResponse.output || multipartResponse.fallbackOutput, + content: multipartResponse.content, + payloadStrategy: strategy, + } as ToolResponse, + }); - return { response }; + return { response }; + } else { + const response = stripUndefinedProps({ + toolResponse: { + name: part.toolRequest.name, + ref: part.toolRequest.ref, + output, + }, + }); + + return { response }; + } } catch (e) { if ( e instanceof ToolInterruptError || diff --git a/js/ai/src/tool.ts b/js/ai/src/tool.ts index 19c9665676..f3c79573fa 100644 --- a/js/ai/src/tool.ts +++ b/js/ai/src/tool.ts @@ -16,8 +16,8 @@ import { action, + ActionFnArg, assertUnstable, - defineAction, isAction, stripUndefinedProps, z, @@ -29,11 +29,12 @@ import { import type { Registry } from '@genkit-ai/core/registry'; import { parseSchema, toJsonSchema } from '@genkit-ai/core/schema'; import { setCustomMetadataAttributes } from '@genkit-ai/core/tracing'; -import type { - Part, - ToolDefinition, - ToolRequestPart, - ToolResponsePart, +import { + PartSchema, + type Part, + type ToolDefinition, + type ToolRequestPart, + type ToolResponsePart, } from './model.js'; import { isExecutablePrompt, type ExecutablePrompt } from './prompt.js'; @@ -100,6 +101,26 @@ export type ToolAction< }; }; +/** + * An action with a `tool.v2` type. + */ +export type MultipartToolAction< + I extends z.ZodTypeAny = z.ZodTypeAny, + O extends z.ZodTypeAny = z.ZodTypeAny, +> = Action< + I, + typeof MultipartToolResponseSchema, + z.ZodTypeAny, + ToolRunOptions +> & + Resumable & { + __action: { + metadata: { + type: 'tool.v2'; + }; + }; + }; + /** * A dynamic action with a `tool` type. Dynamic tools are detached actions -- not associated with any registry. */ @@ -218,6 +239,7 @@ export async function lookupToolByName( const tool = (await registry.lookupAction(name)) || (await registry.lookupAction(`/tool/${name}`)) || + (await registry.lookupAction(`/tool.v2/${name}`)) || (await registry.lookupAction(`/prompt/${name}`)) || (await registry.lookupAction(`/dynamic-action-provider/${name}`)); if (!tool) { @@ -258,7 +280,7 @@ export function toToolDefinition( return out; } -export interface ToolFnOptions { +export interface ToolFnOptions extends ActionFnArg { /** * A function that can be called during tool execution that will result in the tool * getting interrupted (immediately) and tool request returned to the upstream caller. @@ -273,6 +295,26 @@ export type ToolFn = ( ctx: ToolFnOptions & ToolRunOptions ) => Promise>; +export type MultipartToolFn = ( + input: z.infer, + ctx: ToolFnOptions & ToolRunOptions +) => Promise<{ + output?: z.infer; + fallbackOutput?: z.infer; + content?: Part[]; +}>; + +export function defineTool( + registry: Registry, + config: { multipart: true } & ToolConfig, + fn?: ToolFn +): MultipartToolAction; +export function defineTool( + registry: Registry, + config: ToolConfig, + fn?: ToolFn +): ToolAction; + /** * Defines a tool. * @@ -280,25 +322,15 @@ export type ToolFn = ( */ export function defineTool( registry: Registry, - config: ToolConfig, - fn: ToolFn -): ToolAction { - const a = defineAction( - registry, - { - ...config, - actionType: 'tool', - metadata: { ...(config.metadata || {}), type: 'tool' }, - }, - (i, runOptions) => { - return fn(i, { - ...runOptions, - context: { ...runOptions.context }, - interrupt: interruptTool(registry), - }); - } - ); - implementTool(a as ToolAction, config, registry); + config: { multipart?: true } & ToolConfig, + fn?: ToolFn | MultipartToolFn +): ToolAction | MultipartToolAction { + const a = tool(config, fn); + registry.registerAction(config.multipart ? 'tool.v2' : 'tool', a); + if (!config.multipart) { + // For non-multipart tools, we register a v2 tool action as well + registry.registerAction('tool.v2', basicToolV2(config, fn as ToolFn)); + } return a as ToolAction; } @@ -432,27 +464,30 @@ function interruptTool(registry?: Registry) { }; } -/** - * Defines a dynamic tool. Dynamic tools are just like regular tools but will not be registered in the - * Genkit registry and can be defined dynamically at runtime. - */ +export function tool( + config: { multipart: true } & ToolConfig, + fn?: ToolFn +): MultipartToolAction; export function tool( config: ToolConfig, fn?: ToolFn -): ToolAction { - return dynamicTool(config, fn); -} +): ToolAction; /** * Defines a dynamic tool. Dynamic tools are just like regular tools but will not be registered in the * Genkit registry and can be defined dynamically at runtime. - * - * @deprecated renamed to {@link tool}. */ -export function dynamicTool( +export function tool( + config: { multipart?: true } & ToolConfig, + fn?: ToolFn | MultipartToolFn +): ToolAction | MultipartToolAction { + return config.multipart ? multipartTool(config, fn) : basicTool(config, fn); +} + +function basicTool( config: ToolConfig, fn?: ToolFn -): DynamicToolAction { +): ToolAction { const a = action( { ...config, @@ -470,8 +505,74 @@ export function dynamicTool( } return interrupt(); } - ) as DynamicToolAction; + ) as ToolAction; + implementTool(a, config); + return a; +} + +function basicToolV2( + config: ToolConfig, + fn?: ToolFn +): MultipartToolAction { + return multipartTool(config, async (input, ctx) => { + if (!fn) { + const interrupt = interruptTool(ctx.registry); + return interrupt(); + } + return { + output: await fn(input, ctx), + }; + }); +} + +export const MultipartToolResponseSchema = z.object({ + output: z.any().optional(), + fallbackOutput: z.any().optional(), + content: z.array(PartSchema).optional(), +}); + +function multipartTool( + config: ToolConfig, + fn?: MultipartToolFn +): MultipartToolAction { + const a = action( + { + ...config, + outputSchema: MultipartToolResponseSchema, + actionType: 'tool.v2', + metadata: { + ...(config.metadata || {}), + type: 'tool.v2', + tool: { multipart: true }, + }, + }, + (i, runOptions) => { + const interrupt = interruptTool(runOptions.registry); + if (fn) { + return fn(i, { + ...runOptions, + context: { ...runOptions.context }, + interrupt, + }); + } + return interrupt(); + } + ) as MultipartToolAction; implementTool(a as any, config); - a.attach = (_: Registry) => a; return a; } + +/** + * Defines a dynamic tool. Dynamic tools are just like regular tools but will not be registered in the + * Genkit registry and can be defined dynamically at runtime. + * + * @deprecated renamed to {@link tool}. + */ +export function dynamicTool( + config: ToolConfig, + fn?: ToolFn +): DynamicToolAction { + const t = basicTool(config, fn) as DynamicToolAction; + t.attach = (_: Registry) => t; + return t; +} diff --git a/js/ai/tests/generate/generate_test.ts b/js/ai/tests/generate/generate_test.ts index 4abd89038e..089edf170e 100644 --- a/js/ai/tests/generate/generate_test.ts +++ b/js/ai/tests/generate/generate_test.ts @@ -602,4 +602,192 @@ describe('generate', () => { ['Testing default step name', 'Testing default step name'] ); }); + + it('handles multipart tool responses', async () => { + defineTool( + registry, + { + name: 'multiTool', + description: 'a tool with multiple parts', + multipart: true, + }, + async () => { + return { + output: 'main output', + content: [{ text: 'part 1' }], + }; + } + ); + + let requestCount = 0; + defineModel( + registry, + { name: 'multi-tool-model', supports: { tools: true } }, + async (input) => { + requestCount++; + return { + message: { + role: 'model', + content: [ + requestCount == 1 + ? { + toolRequest: { + name: 'multiTool', + input: {}, + }, + } + : { text: 'done' }, + ], + }, + finishReason: 'stop', + }; + } + ); + + const response = await generate(registry, { + model: 'multi-tool-model', + prompt: 'go', + tools: ['multiTool'], + }); + assert.deepStrictEqual(response.messages, [ + { + role: 'user', + content: [ + { + text: 'go', + }, + ], + }, + { + role: 'model', + content: [ + { + toolRequest: { + name: 'multiTool', + input: {}, + }, + }, + ], + }, + { + role: 'tool', + content: [ + { + toolResponse: { + name: 'multiTool', + output: 'main output', + content: [ + { + text: 'part 1', + }, + ], + payloadStrategy: 'both', + }, + }, + ], + }, + { + role: 'model', + content: [ + { + text: 'done', + }, + ], + }, + ]); + }); + + it('handles fallback tool responses', async () => { + defineTool( + registry, + { + name: 'fallbackTool', + description: 'a tool with fallback output', + multipart: true, + }, + async () => { + return { + fallbackOutput: 'fallback output', + content: [{ text: 'part 1' }], + }; + } + ); + + let requestCount = 0; + defineModel( + registry, + { name: 'fallback-tool-model', supports: { tools: true } }, + async (input) => { + requestCount++; + return { + message: { + role: 'model', + content: [ + requestCount == 1 + ? { + toolRequest: { + name: 'fallbackTool', + input: {}, + }, + } + : { text: 'done' }, + ], + }, + finishReason: 'stop', + }; + } + ); + + const response = await generate(registry, { + model: 'fallback-tool-model', + prompt: 'go', + tools: ['fallbackTool'], + }); + assert.deepStrictEqual(response.messages, [ + { + role: 'user', + content: [ + { + text: 'go', + }, + ], + }, + { + role: 'model', + content: [ + { + toolRequest: { + name: 'fallbackTool', + input: {}, + }, + }, + ], + }, + { + role: 'tool', + content: [ + { + toolResponse: { + name: 'fallbackTool', + output: 'fallback output', + content: [ + { + text: 'part 1', + }, + ], + payloadStrategy: 'fallback', + }, + }, + ], + }, + { + role: 'model', + content: [ + { + text: 'done', + }, + ], + }, + ]); + }); }); diff --git a/js/ai/tests/tool_test.ts b/js/ai/tests/tool_test.ts index 74a0194e91..2bb71dccca 100644 --- a/js/ai/tests/tool_test.ts +++ b/js/ai/tests/tool_test.ts @@ -107,6 +107,46 @@ describe('defineInterrupt', () => { type: 'string', }); }); + + describe('multipart tools', () => { + it('should define a multipart tool', async () => { + const t = defineTool( + registry, + { name: 'test', description: 'test', multipart: true }, + async () => { + return { + output: 'main output', + content: [{ text: 'part 1' }], + }; + } + ); + assert.equal(t.__action.metadata.type, 'tool.v2'); + assert.equal(t.__action.actionType, 'tool.v2'); + const result = await t({}); + assert.deepStrictEqual(result, { + output: 'main output', + content: [{ text: 'part 1' }], + }); + }); + + it('should handle fallback output', async () => { + const t = defineTool( + registry, + { name: 'test', description: 'test', multipart: true }, + async () => { + return { + fallbackOutput: 'fallback', + content: [{ text: 'part 1' }], + }; + } + ); + const result = await t({}); + assert.deepStrictEqual(result, { + fallbackOutput: 'fallback', + content: [{ text: 'part 1' }], + }); + }); + }); }); describe('defineTool', () => { @@ -267,4 +307,32 @@ describe('defineTool', () => { ); }); }); + + it('should register a v1 tool as v2 as well', async () => { + defineTool(registry, { name: 'test', description: 'test' }, async () => {}); + assert.ok(await registry.lookupAction('/tool/test')); + assert.ok(await registry.lookupAction('/tool.v2/test')); + }); + + it('should only register a multipart tool as v2', async () => { + defineTool( + registry, + { name: 'test', description: 'test', multipart: true }, + async () => {} + ); + assert.ok(await registry.lookupAction('/tool.v2/test')); + assert.equal(await registry.lookupAction('/tool/test'), undefined); + }); + + it('should wrap v1 tool output when called as v2', async () => { + defineTool( + registry, + { name: 'test', description: 'test' }, + async () => 'foo' + ); + const action = await registry.lookupAction('/tool.v2/test'); + assert.ok(action); + const result = await action!({}); + assert.deepStrictEqual(result, { output: 'foo' }); + }); }); diff --git a/js/core/src/registry.ts b/js/core/src/registry.ts index fbe128a241..a758a047d1 100644 --- a/js/core/src/registry.ts +++ b/js/core/src/registry.ts @@ -53,6 +53,7 @@ const ACTION_TYPES = [ 'reranker', 'retriever', 'tool', + 'tool.v2', 'util', 'resource', ] as const; diff --git a/js/genkit/src/genkit.ts b/js/genkit/src/genkit.ts index 040d640113..2cd34a4c7a 100644 --- a/js/genkit/src/genkit.ts +++ b/js/genkit/src/genkit.ts @@ -97,7 +97,12 @@ import { type RetrieverFn, type SimpleRetrieverOptions, } from '@genkit-ai/ai/retriever'; -import { dynamicTool, type ToolFn } from '@genkit-ai/ai/tool'; +import { + dynamicTool, + type MultipartToolAction, + type MultipartToolFn, + type ToolFn, +} from '@genkit-ai/ai/tool'; import { ActionFnArg, GenkitError, @@ -222,6 +227,16 @@ export class Genkit implements HasRegistry { return flow; } + /** + * Defines and registers a tool that can return multiple parts of content. + * + * Tools can be passed to models by name or value during `generate` calls to be called automatically based on the prompt and situation. + */ + defineTool( + config: { multipart: true } & ToolConfig, + fn: MultipartToolFn + ): MultipartToolAction; + /** * Defines and registers a tool. * @@ -230,8 +245,13 @@ export class Genkit implements HasRegistry { defineTool( config: ToolConfig, fn: ToolFn - ): ToolAction { - return defineTool(this.registry, config, fn); + ): ToolAction; + + defineTool( + config: ({ multipart?: true } & ToolConfig) | string, + fn: ToolFn | MultipartToolFn + ): ToolAction | MultipartToolAction { + return defineTool(this.registry, config as any, fn as any); } /** diff --git a/js/genkit/tests/generate_test.ts b/js/genkit/tests/generate_test.ts index 9244851fa6..8114117f46 100644 --- a/js/genkit/tests/generate_test.ts +++ b/js/genkit/tests/generate_test.ts @@ -814,6 +814,58 @@ describe('generate', () => { assert.strictEqual(text, '{"foo":"bar a@b.c"}'); }); + it('calls the multipart tool', async () => { + const t = ai.defineTool( + { name: 'testTool', description: 'description', multipart: true }, + async () => ({ + output: 'tool called', + content: [{ text: 'part 1' }], + }) + ); + + // first response is a tool call, the subsequent responses are just text response from agent b. + let reqCounter = 0; + pm.handleResponse = async (req, sc) => { + return { + message: { + role: 'model', + content: [ + reqCounter++ === 0 + ? { + toolRequest: { + name: 'testTool', + input: {}, + ref: 'ref123', + }, + } + : { text: 'done' }, + ], + }, + }; + }; + + const { text, messages } = await ai.generate({ + prompt: 'call the tool', + tools: [t], + }); + + assert.strictEqual(text, 'done'); + assert.strictEqual(messages.length, 4); + const toolMessage = messages[2]; + assert.strictEqual(toolMessage.role, 'tool'); + assert.deepStrictEqual(toolMessage.content, [ + { + toolResponse: { + name: 'testTool', + ref: 'ref123', + output: 'tool called', + content: [{ text: 'part 1' }], + payloadStrategy: 'both', + }, + }, + ]); + }); + it('streams the tool responses', async () => { ai.defineTool( { name: 'testTool', description: 'description' },