diff --git a/js/ai/src/generate.ts b/js/ai/src/generate.ts index ab83b50129..f1f84bfb8e 100755 --- a/js/ai/src/generate.ts +++ b/js/ai/src/generate.ts @@ -22,6 +22,7 @@ import { runWithContext, runWithStreamingCallback, sentinelNoopStreamingCallback, + stripUndefinedProps, z, } from '@genkit-ai/core'; import { Channel } from '@genkit-ai/core/async'; @@ -33,7 +34,7 @@ import { resolveFormat, resolveInstructions, } from './formats/index.js'; -import { GenerateUtilParamSchema, generateHelper } from './generate/action.js'; +import { GenerateActionOptions, generateHelper } from './generate/action.js'; import { GenerateResponseChunk } from './generate/chunk.js'; import { GenerateResponse } from './generate/response.js'; import { Message } from './message.js'; @@ -139,11 +140,12 @@ export interface GenerateOptions< context?: ActionContext; } -function applyResumeOption( +/** Amends message history to handle `resume` arguments. Returns the amended history. */ +async function applyResumeOption( options: GenerateOptions, messages: MessageData[] -): MessageData[] { - if (!options.resume) return []; +): Promise { + if (!options.resume) return messages; if ( messages.at(-1)?.role !== 'model' || !messages @@ -159,14 +161,22 @@ function applyResumeOption( const toolRequests = lastModelMessage.content.filter((p) => !!p.toolRequest); const pendingResponses: ToolResponsePart[] = toolRequests - .filter((t) => !!t.metadata?.pendingToolResponse) - .map((t) => ({ - toolResponse: t.metadata!.pendingToolResponse, - })) as ToolResponsePart[]; + .filter((t) => !!t.metadata?.pendingOutput) + .map((t) => + stripUndefinedProps({ + toolResponse: { + name: t.toolRequest!.name, + ref: t.toolRequest!.ref, + output: t.metadata!.pendingOutput, + }, + metadata: { source: 'pending' }, + }) + ) as ToolResponsePart[]; const reply = Array.isArray(options.resume.reply) ? options.resume.reply : [options.resume.reply]; + const message: MessageData = { role: 'tool', content: [...pendingResponses, ...reply], @@ -174,14 +184,14 @@ function applyResumeOption( resume: options.resume.metadata || true, }, }; - return [message]; + return [...messages, message]; } export async function toGenerateRequest( registry: Registry, options: GenerateOptions ): Promise { - const messages: MessageData[] = []; + let messages: MessageData[] = []; if (options.system) { messages.push({ role: 'system', @@ -192,7 +202,7 @@ export async function toGenerateRequest( messages.push(...options.messages.map((m) => Message.parseData(m))); } // resuming from interrupts occurs after message history but before user prompt - messages.push(...applyResumeOption(options, messages)); + messages = await applyResumeOption(options, messages); if (options.prompt) { messages.push({ role: 'user', @@ -346,12 +356,21 @@ export async function generate< jsonSchema: resolvedOptions.output?.jsonSchema, }); + // If is schema is set but format is not explicitly set, default to `json` format. + if (resolvedOptions.output?.schema && !resolvedOptions.output?.format) { + resolvedOptions.output.format = 'json'; + } const resolvedFormat = await resolveFormat(registry, resolvedOptions.output); + const instructions = resolveInstructions( + resolvedFormat, + resolvedSchema, + resolvedOptions?.output?.instructions + ); - const params: z.infer = { + const params: GenerateActionOptions = { model: resolvedModel.modelAction.__action.name, docs: resolvedOptions.docs, - messages: messages, + messages: injectInstructions(messages, instructions), tools, toolChoice: resolvedOptions.toolChoice, config: { @@ -371,15 +390,14 @@ export async function generate< registry, stripNoop(resolvedOptions.onChunk ?? resolvedOptions.streamingCallback), async () => { - const generateFn = () => - generateHelper(registry, { - rawRequest: params, - middleware: resolvedOptions.use, - }); const response = await runWithContext( registry, resolvedOptions.context, - generateFn + () => + generateHelper(registry, { + rawRequest: params, + middleware: resolvedOptions.use, + }) ); const request = await toGenerateRequest(registry, { ...resolvedOptions, diff --git a/js/ai/src/generate/action.ts b/js/ai/src/generate/action.ts index 9701285ce9..67d4a3202e 100644 --- a/js/ai/src/generate/action.ts +++ b/js/ai/src/generate/action.ts @@ -15,16 +15,15 @@ */ import { - GenkitError, getStreamingCallback, runWithStreamingCallback, + stripUndefinedProps, z, } from '@genkit-ai/core'; import { logger } from '@genkit-ai/core/logging'; import { Registry } from '@genkit-ai/core/registry'; import { toJsonSchema } from '@genkit-ai/core/schema'; import { SPAN_TYPE_ATTR, runInNewSpan } from '@genkit-ai/core/tracing'; -import * as clc from 'colorette'; import { DocumentDataSchema } from '../document.js'; import { injectInstructions, @@ -43,7 +42,6 @@ import { GenerateRequestSchema, GenerateResponseChunkData, GenerateResponseData, - MessageData, MessageSchema, ModelAction, ModelInfo, @@ -52,17 +50,15 @@ import { Part, Role, ToolDefinitionSchema, - ToolResponsePart, resolveModel, } from '../model.js'; +import { ToolAction, resolveTools, toToolDefinition } from '../tool.js'; import { - ToolAction, - ToolInterruptError, - resolveTools, - toToolDefinition, -} from '../tool.js'; + assertValidToolNames, + resolveToolRequests, +} from './resolve-tool-requests.js'; -export const GenerateUtilParamSchema = z.object({ +export const GenerateActionOptionsSchema = z.object({ /** A model name (e.g. `vertexai/gemini-1.0-pro`). */ model: z.string(), /** Retrieved documents to be used as context for this generation. */ @@ -89,6 +85,7 @@ export const GenerateUtilParamSchema = z.object({ /** Maximum number of tool call iterations that can be performed in a single generate call (default 5). */ maxTurns: z.number().optional(), }); +export type GenerateActionOptions = z.infer; /** * Encapsulates all generate logic. This is similar to `generateAction` except not an action and can take middleware. @@ -96,7 +93,7 @@ export const GenerateUtilParamSchema = z.object({ export async function generateHelper( registry: Registry, options: { - rawRequest: z.infer; + rawRequest: GenerateActionOptions; middleware?: ModelMiddleware[]; currentTurn?: number; messageIndex?: number; @@ -130,115 +127,140 @@ export async function generateHelper( ); } -async function generate( +/** Take the raw request and resolve tools, model, and format into their registry action counterparts. */ +async function resolveParameters( registry: Registry, - options: { - rawRequest: z.infer; - middleware: ModelMiddleware[] | undefined; - currentTurn: number; - messageIndex: number; - } -): Promise { - const { modelAction: model } = await resolveModel( - registry, - options.rawRequest.model - ); - if (model.__action.metadata?.model.stage === 'deprecated') { - logger.warn( - `${clc.bold(clc.yellow('Warning:'))} ` + - `Model '${model.__action.name}' is deprecated and may be removed in a future release.` - ); - } - - const tools = await resolveTools(registry, options.rawRequest.tools); - - const resolvedSchema = toJsonSchema({ - jsonSchema: options.rawRequest.output?.jsonSchema, - }); + request: GenerateActionOptions +) { + const [model, tools, format] = await Promise.all([ + resolveModel(registry, request.model, { warnDeprecated: true }).then( + (r) => r.modelAction + ), + resolveTools(registry, request.tools), + resolveFormat(registry, request.output), + ]); + return { model, tools, format }; +} +/** Given a raw request and a formatter, apply the formatter's logic and instructions to the request. */ +function applyFormat( + rawRequest: GenerateActionOptions, + resolvedFormat?: Formatter +) { + const outRequest = { ...rawRequest }; // If is schema is set but format is not explicitly set, default to `json` format. - if ( - options.rawRequest.output?.jsonSchema && - !options.rawRequest.output?.format - ) { - options.rawRequest.output.format = 'json'; + if (rawRequest.output?.jsonSchema && !rawRequest.output?.format) { + outRequest.output = { ...rawRequest.output, format: 'json' }; } - const resolvedFormat = await resolveFormat( - registry, - options.rawRequest.output - ); + const instructions = resolveInstructions( resolvedFormat, - resolvedSchema, - options.rawRequest?.output?.instructions + outRequest.output?.jsonSchema, + outRequest?.output?.instructions ); + if (resolvedFormat) { - options.rawRequest.messages = injectInstructions( - options.rawRequest.messages, - instructions - ); - options.rawRequest.output = { + outRequest.messages = injectInstructions(outRequest.messages, instructions); + outRequest.output = { // use output config from the format ...resolvedFormat.config, // if anything is set explicitly, use that - ...options.rawRequest.output, + ...outRequest.output, }; } - // Create a lookup of tool names with namespaces stripped to original names - const toolMap = tools.reduce>((acc, tool) => { - const name = tool.__action.name; - const shortName = name.substring(name.lastIndexOf('/') + 1); - if (acc[shortName]) { - throw new GenkitError({ - status: 'INVALID_ARGUMENT', - message: `Cannot provide two tools with the same name: '${name}' and '${acc[shortName]}'`, - }); - } - acc[shortName] = tool; - return acc; - }, {}); + return outRequest; +} + +function applyTransferPreamble( + rawRequest: GenerateActionOptions, + transferPreamble?: GenerateActionOptions +): GenerateActionOptions { + if (!transferPreamble) { + return rawRequest; + } + + return stripUndefinedProps({ + ...rawRequest, + messages: [ + ...tagAsPreamble(transferPreamble.messages!)!, + ...rawRequest.messages.filter((m) => !m.metadata?.preamble), + ], + toolChoice: transferPreamble.toolChoice || rawRequest.toolChoice, + tools: transferPreamble.tools || rawRequest.tools, + }); +} + +async function generate( + registry: Registry, + { + rawRequest, + middleware, + currentTurn, + messageIndex, + }: { + rawRequest: GenerateActionOptions; + middleware: ModelMiddleware[] | undefined; + currentTurn: number; + messageIndex: number; + } +): Promise { + const { model, tools, format } = await resolveParameters( + registry, + rawRequest + ); + rawRequest = applyFormat(rawRequest, format); + + // check to make sure we don't have overlapping tool names *before* generation + await assertValidToolNames(tools); const request = await actionToGenerateRequest( - options.rawRequest, + rawRequest, tools, - resolvedFormat, + format, model ); - const accumulatedChunks: GenerateResponseChunkData[] = []; + const previousChunks: GenerateResponseChunkData[] = []; + + let chunkRole: Role = 'model'; + // convenience method to create a full chunk from role and data, append the chunk + // to the previousChunks array, and increment the message index as needed + const makeChunk = ( + role: Role, + chunk: GenerateResponseChunkData + ): GenerateResponseChunk => { + if (role !== chunkRole) messageIndex++; + chunkRole = role; + + const prevToSend = [...previousChunks]; + previousChunks.push(chunk); + + return new GenerateResponseChunk(chunk, { + index: messageIndex, + role, + previousChunks: prevToSend, + parser: format?.handler(request.output?.schema).parseChunk, + }); + }; const streamingCallback = getStreamingCallback(registry); const response = await runWithStreamingCallback( registry, - streamingCallback - ? (chunk: GenerateResponseChunkData) => { - // Store accumulated chunk data - if (streamingCallback) { - streamingCallback!( - new GenerateResponseChunk(chunk, { - index: options.messageIndex, - role: 'model', - previousChunks: accumulatedChunks, - parser: resolvedFormat?.handler(request.output?.schema) - .parseChunk, - }) - ); - } - accumulatedChunks.push(chunk); - } - : undefined, + streamingCallback && + ((chunk: GenerateResponseChunkData) => + streamingCallback(makeChunk('model', chunk))), async () => { const dispatch = async ( index: number, req: z.infer ) => { - if (!options.middleware || index === options.middleware.length) { + if (!middleware || index === middleware.length) { // end of the chain, call the original model action return await model(req); } - const currentMiddleware = options.middleware[index]; + const currentMiddleware = middleware[index]; return currentMiddleware(req, async (modifiedReq) => dispatch(index + 1, modifiedReq || req) ); @@ -246,24 +268,26 @@ async function generate( return new GenerateResponse(await dispatch(0, request), { request, - parser: resolvedFormat?.handler(request.output?.schema).parseMessage, + parser: format?.handler(request.output?.schema).parseMessage, }); } ); // Throw an error if the response is not usable. response.assertValid(); - const message = response.message!; // would have thrown if no message + const generatedMessage = response.message!; // would have thrown if no message - const toolCalls = message.content.filter((part) => !!part.toolRequest); - if (options.rawRequest.returnToolRequests || toolCalls.length === 0) { - if (toolCalls.length === 0) { - response.assertValidSchema(request); - } + const toolRequests = generatedMessage.content.filter( + (part) => !!part.toolRequest + ); + + if (rawRequest.returnToolRequests || toolRequests.length === 0) { + if (toolRequests.length === 0) response.assertValidSchema(request); return response.toJSON(); } - const maxIterations = options.rawRequest.maxTurns ?? 5; - if (options.currentTurn + 1 > maxIterations) { + + const maxIterations = rawRequest.maxTurns ?? 5; + if (currentTurn + 1 > maxIterations) { throw new GenerationResponseError( response, `Exceeded maximum tool call iterations (${maxIterations})`, @@ -272,132 +296,43 @@ async function generate( ); } - const toolResponses: ToolResponsePart[] = []; - let messages: MessageData[] = [...request.messages, message]; - let newTools = options.rawRequest.tools; - let newToolChoice = options.rawRequest.toolChoice; - let interruptedParts: Part[] = []; - let pendingToolRequests: Part[] = []; - for (const part of toolCalls) { - if (!part.toolRequest) { - throw Error( - 'Tool request expected but not provided in tool request part' - ); - } - const tool = toolMap[part.toolRequest?.name]; - if (!tool) { - throw Error(`Tool ${part.toolRequest?.name} not found`); - } - if ((tool.__action.metadata.type as string) === 'prompt') { - try { - const newPreamble = await tool(part.toolRequest?.input); - toolResponses.push({ - toolResponse: { - name: part.toolRequest.name, - ref: part.toolRequest.ref, - output: `transferred to ${part.toolRequest.name}`, - }, - }); - // swap out the preamble - messages = [ - ...tagAsPreamble(newPreamble.messages)!, - ...messages.filter((m) => !m?.metadata?.preamble), - ]; - newTools = newPreamble.tools; - newToolChoice = newPreamble.toolChoice; - } catch (e) { - if (e instanceof ToolInterruptError) { - logger.debug(`interrupted tool ${part.toolRequest?.name}`); - part.metadata = { ...part.metadata, interrupt: e.metadata || true }; - interruptedParts.push(part); - } else { - throw e; - } - } - } else { - try { - const toolOutput = await tool(part.toolRequest?.input); - toolResponses.push({ - toolResponse: { - name: part.toolRequest.name, - ref: part.toolRequest.ref, - output: toolOutput, - }, - }); - // we prep these in case any other tool gets interrupted. - pendingToolRequests.push({ - ...part, - metadata: { - ...part.metadata, - pendingToolResponse: { - name: part.toolRequest.name, - ref: part.toolRequest.ref, - output: toolOutput, - }, - }, - }); - } catch (e) { - if (e instanceof ToolInterruptError) { - logger.debug(`interrupted tool ${part.toolRequest?.name}`); - part.metadata = { ...part.metadata, interrupt: e.metadata || true }; - interruptedParts.push(part); - } else { - throw e; - } - } - } - } - options.messageIndex++; - const nextRequest = { - ...options.rawRequest, - messages: [ - ...messages, - { - role: 'tool', - content: toolResponses, - }, - ] as MessageData[], - tools: newTools, - toolCoice: newToolChoice, - }; - // stream out the tool responses - streamingCallback?.( - new GenerateResponseChunk( - { - content: toolResponses, - }, - { - index: options.messageIndex, - role: 'model', - previousChunks: accumulatedChunks, - parser: resolvedFormat?.handler(request.output?.schema).parseChunk, - } - ) - ); - if (interruptedParts.length > 0) { - const nonToolParts = - (response.message?.content.filter((c) => !c.toolRequest) as Part[]) || []; + const { revisedModelMessage, toolMessage, transferPreamble } = + await resolveToolRequests(registry, rawRequest, generatedMessage); + + // if an interrupt message is returned, stop the tool loop and return a response + if (revisedModelMessage) { return { ...response.toJSON(), finishReason: 'interrupted', - message: { - role: 'model', - content: nonToolParts - .concat(pendingToolRequests) - .concat(interruptedParts), - }, + finishMessage: 'One or more tool calls resulted in interrupts.', + message: revisedModelMessage, }; } + + // if the loop will continue, stream out the tool response message... + streamingCallback?.( + makeChunk('tool', { + content: toolMessage!.content, + }) + ); + + let nextRequest = { + ...rawRequest, + messages: [...rawRequest.messages, generatedMessage, toolMessage!], + }; + nextRequest = applyTransferPreamble(nextRequest, transferPreamble); + + // then recursively call for another loop return await generateHelper(registry, { rawRequest: nextRequest, - middleware: options.middleware, - currentTurn: options.currentTurn + 1, - messageIndex: options.messageIndex + 1, + middleware: middleware, + currentTurn: currentTurn + 1, + messageIndex: messageIndex + 1, }); } async function actionToGenerateRequest( - options: z.infer, + options: GenerateActionOptions, resolvedTools: ToolAction[] | undefined, resolvedFormat: Formatter | undefined, model: ModelAction diff --git a/js/ai/src/generate/resolve-tool-requests.ts b/js/ai/src/generate/resolve-tool-requests.ts new file mode 100644 index 0000000000..7376fad501 --- /dev/null +++ b/js/ai/src/generate/resolve-tool-requests.ts @@ -0,0 +1,153 @@ +/** + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { GenkitError, stripUndefinedProps } from '@genkit-ai/core'; +import { logger } from '@genkit-ai/core/logging'; +import { Registry } from '@genkit-ai/core/registry'; +import { MessageData, ToolResponsePart } from '../model.js'; +import { isPromptAction } from '../prompt.js'; +import { ToolAction, ToolInterruptError, resolveTools } from '../tool.js'; +import { GenerateActionOptions } from './action.js'; + +export function toToolMap(tools: ToolAction[]): Record { + assertValidToolNames(tools); + const out: Record = {}; + for (const tool of tools) { + const name = tool.__action.name; + const shortName = name.substring(name.lastIndexOf('/') + 1); + out[shortName] = tool; + } + return out; +} + +/** Ensures that each tool has a unique name. */ +export function assertValidToolNames(tools: ToolAction[]) { + const nameMap: Record = {}; + for (const tool of tools) { + const name = tool.__action.name; + const shortName = name.substring(name.lastIndexOf('/') + 1); + if (nameMap[shortName]) { + throw new GenkitError({ + status: 'INVALID_ARGUMENT', + message: `Cannot provide two tools with the same name: '${name}' and '${nameMap[shortName]}'`, + }); + } + nameMap[shortName] = name; + } +} + +/** + * resolveToolRequests is responsible for executing the tools requested by the model for a single turn. it + * returns either a toolMessage to append or a revisedModelMessage when an interrupt occurs, and a transferPreamble + * if a prompt tool is called + */ +export async function resolveToolRequests( + registry: Registry, + rawRequest: GenerateActionOptions, + generatedMessage: MessageData +): Promise<{ + revisedModelMessage?: MessageData; + toolMessage?: MessageData; + transferPreamble?: GenerateActionOptions; +}> { + const toolMap = toToolMap(await resolveTools(registry, rawRequest.tools)); + + const responseParts: ToolResponsePart[] = []; + let hasInterrupts: boolean = false; + let transferPreamble: GenerateActionOptions | undefined; + + const revisedModelMessage = { + ...generatedMessage, + content: [...generatedMessage.content], + }; + + await Promise.all( + revisedModelMessage.content.map(async (part, i) => { + if (!part.toolRequest) return; // skip non-tool-request parts + + const tool = toolMap[part.toolRequest.name]; + if (!tool) { + throw new GenkitError({ + status: 'NOT_FOUND', + message: `Tool ${part.toolRequest.name} not found`, + detail: { request: rawRequest }, + }); + } + + // if it's a prompt action, go ahead and render the preamble + if (isPromptAction(tool)) { + if (transferPreamble) + throw new GenkitError({ + status: 'INVALID_ARGUMENT', + message: `Model attempted to transfer to multiple prompt tools.`, + }); + transferPreamble = await tool(part.toolRequest.input); + responseParts.push({ + toolResponse: { + name: part.toolRequest.name, + ref: part.toolRequest.ref, + output: `transferred to ${part.toolRequest.name}`, + }, + }); + return; + } + + // otherwise, execute the tool and catch interrupts + try { + const output = await tool(part.toolRequest.input, {}); + const responsePart = stripUndefinedProps({ + toolResponse: { + name: part.toolRequest.name, + ref: part.toolRequest.ref, + output, + }, + }); + + revisedModelMessage.content.splice(i, 1, { + ...part, + metadata: { + ...part.metadata, + pendingOutput: responsePart.toolResponse.output, + }, + }); + responseParts.push(responsePart); + } catch (e) { + if (e instanceof ToolInterruptError) { + logger.debug( + `tool '${toolMap[part.toolRequest?.name].__action.name}' triggered an interrupt${e.metadata ? `: ${JSON.stringify(e.metadata)}` : ''}` + ); + revisedModelMessage.content.splice(i, 1, { + toolRequest: part.toolRequest, + metadata: { ...part.metadata, interrupt: e.metadata || true }, + }); + hasInterrupts = true; + return; + } + + throw e; + } + }) + ); + + if (hasInterrupts) { + return { revisedModelMessage }; + } + + return { + toolMessage: { role: 'tool', content: responseParts }, + transferPreamble, + }; +} diff --git a/js/ai/src/model.ts b/js/ai/src/model.ts index ce15adaa3c..f23e902065 100644 --- a/js/ai/src/model.ts +++ b/js/ai/src/model.ts @@ -23,6 +23,7 @@ import { StreamingCallback, z, } from '@genkit-ai/core'; +import { logger } from '@genkit-ai/core/logging'; import { Registry } from '@genkit-ai/core/registry'; import { toJsonSchema } from '@genkit-ai/core/schema'; import { performance } from 'node:perf_hooks'; @@ -633,7 +634,8 @@ export interface ResolvedModel< export async function resolveModel( registry: Registry, - model: ModelArgument | undefined + model: ModelArgument | undefined, + options?: { warnDeprecated?: boolean } ): Promise> { let out: ResolvedModel; let modelId: string; @@ -670,9 +672,18 @@ export async function resolveModel( if (!out.modelAction) { throw new GenkitError({ status: 'NOT_FOUND', - message: `Model ${modelId} not found`, + message: `Model '${modelId}' not found`, }); } + if ( + options?.warnDeprecated && + out.modelAction.__action.metadata?.model?.stage === 'deprecated' + ) { + logger.warn( + `Model '${out.modelAction.__action.name}' is deprecated and may be removed in a future release.` + ); + } + return out; } diff --git a/js/ai/src/prompt.ts b/js/ai/src/prompt.ts index abfe82e6d6..c649198f38 100644 --- a/js/ai/src/prompt.ts +++ b/js/ai/src/prompt.ts @@ -74,6 +74,10 @@ export type PromptAction = Action< __executablePrompt: ExecutablePrompt; }; +export function isPromptAction(action: Action): action is PromptAction { + return action.__action.metadata?.type === 'prompt'; +} + /** * Prompt action. */ diff --git a/js/ai/src/tool.ts b/js/ai/src/tool.ts index da91335600..a39cc96da0 100644 --- a/js/ai/src/tool.ts +++ b/js/ai/src/tool.ts @@ -46,7 +46,12 @@ export type ToolAction< * it exists. */ reply( + /** The interrupt tool request to which you want to respond. */ interrupt: ToolRequestPart, + /** + * The data with which you want to respond. Must conform to a tool's output schema or an + * interrupt's input schema. + **/ replyData: z.infer, options?: { metadata?: Record } ): ToolResponsePart; diff --git a/js/ai/tests/generate/generate_test.ts b/js/ai/tests/generate/generate_test.ts index 872c40007c..ac5399a835 100644 --- a/js/ai/tests/generate/generate_test.ts +++ b/js/ai/tests/generate/generate_test.ts @@ -290,17 +290,13 @@ describe('toGenerateRequest', () => { { toolRequest: { name: 'p1', ref: '1', input: { one: '1' } }, metadata: { - pendingToolResponse: { name: 'p1', ref: '1', output: 'done' }, + pendingOutput: 'done', }, }, { toolRequest: { name: 'p2', ref: '2', input: { one: '1' } }, metadata: { - pendingToolResponse: { - name: 'p2', - ref: '2', - output: 'done2', - }, + pendingOutput: 'done2', }, }, { @@ -338,17 +334,13 @@ describe('toGenerateRequest', () => { { toolRequest: { name: 'p1', ref: '1', input: { one: '1' } }, metadata: { - pendingToolResponse: { name: 'p1', ref: '1', output: 'done' }, + pendingOutput: 'done', }, }, { toolRequest: { name: 'p2', ref: '2', input: { one: '1' } }, metadata: { - pendingToolResponse: { - name: 'p2', - ref: '2', - output: 'done2', - }, + pendingOutput: 'done2', }, }, { @@ -371,8 +363,14 @@ describe('toGenerateRequest', () => { resume: true, }, content: [ - { toolResponse: { name: 'p1', ref: '1', output: 'done' } }, - { toolResponse: { name: 'p2', ref: '2', output: 'done2' } }, + { + toolResponse: { name: 'p1', ref: '1', output: 'done' }, + metadata: { source: 'pending' }, + }, + { + toolResponse: { name: 'p2', ref: '2', output: 'done2' }, + metadata: { source: 'pending' }, + }, { toolResponse: { name: 'i1', ref: '3', output: 'done3' } }, { toolResponse: { name: 'i2', ref: '4', output: 'done4' } }, ], @@ -391,10 +389,11 @@ describe('toGenerateRequest', () => { { name: 'GenkitError', status: test.throws } ); } else { - assert.deepStrictEqual( - await toGenerateRequest(registry, test.prompt as GenerateOptions), - test.expectedOutput + const actualOutput = await toGenerateRequest( + registry, + test.prompt as GenerateOptions ); + assert.deepStrictEqual(actualOutput, test.expectedOutput); } }); } diff --git a/js/ai/tests/prompt/prompt_test.ts b/js/ai/tests/prompt/prompt_test.ts index f2db66e3af..214f454be4 100644 --- a/js/ai/tests/prompt/prompt_test.ts +++ b/js/ai/tests/prompt/prompt_test.ts @@ -16,19 +16,19 @@ import { ActionContext, runWithContext, z } from '@genkit-ai/core'; import { Registry } from '@genkit-ai/core/registry'; +import { toJsonSchema } from '@genkit-ai/core/schema'; import assert from 'node:assert'; import { beforeEach, describe, it } from 'node:test'; -import { toJsonSchema } from '../../../core/src/schema'; -import { Document } from '../../lib/document'; -import { GenerateOptions } from '../../lib/index'; -import { Session } from '../../lib/session'; +import { Document } from '../../src/document.js'; +import { GenerateOptions } from '../../src/index.js'; import { ModelAction, defineModel } from '../../src/model'; import { PromptConfig, PromptGenerateOptions, definePrompt, -} from '../../src/prompt'; -import { defineTool } from '../../src/tool'; +} from '../../src/prompt.js'; +import { Session } from '../../src/session.js'; +import { defineTool } from '../../src/tool.js'; describe('prompt', () => { let registry; diff --git a/js/genkit/tests/generate_test.ts b/js/genkit/tests/generate_test.ts index c6337d6abc..81285862df 100644 --- a/js/genkit/tests/generate_test.ts +++ b/js/genkit/tests/generate_test.ts @@ -218,9 +218,7 @@ describe('generate', () => { // nothing } }, - (e: Error) => { - return e.message.includes('Model modelNotFound not found'); - } + { status: 'NOT_FOUND' } ); }); @@ -589,7 +587,7 @@ describe('generate', () => { }, ], index: 1, - role: 'model', + role: 'tool', }, { content: [{ text: 'done' }], @@ -688,13 +686,18 @@ describe('generate', () => { assert.strictEqual(reqCounter, 1); assert.deepStrictEqual(response.toolRequests, [ { + toolRequest: { + input: {}, + name: 'interruptingTool', + ref: 'ref123', + }, metadata: { - pendingToolResponse: { - name: 'simpleTool', - output: 'response: foo', - ref: 'ref123', + interrupt: { + confirm: 'is it a banana?', }, }, + }, + { toolRequest: { input: { name: 'foo', @@ -702,17 +705,8 @@ describe('generate', () => { name: 'simpleTool', ref: 'ref123', }, - }, - { - toolRequest: { - input: {}, - name: 'interruptingTool', - ref: 'ref123', - }, metadata: { - interrupt: { - confirm: 'is it a banana?', - }, + pendingOutput: 'response: foo', }, }, ]); @@ -724,12 +718,17 @@ describe('generate', () => { }, { metadata: { - pendingToolResponse: { - name: 'simpleTool', - output: 'response: foo', - ref: 'ref123', + interrupt: { + confirm: 'is it a banana?', }, }, + toolRequest: { + input: {}, + name: 'interruptingTool', + ref: 'ref123', + }, + }, + { toolRequest: { input: { name: 'foo', @@ -737,17 +736,8 @@ describe('generate', () => { name: 'simpleTool', ref: 'ref123', }, - }, - { metadata: { - interrupt: { - confirm: 'is it a banana?', - }, - }, - toolRequest: { - input: {}, - name: 'interruptingTool', - ref: 'ref123', + pendingOutput: 'response: foo', }, }, ], diff --git a/js/genkit/tests/prompts_test.ts b/js/genkit/tests/prompts_test.ts index 13622cda60..6a19d7f99c 100644 --- a/js/genkit/tests/prompts_test.ts +++ b/js/genkit/tests/prompts_test.ts @@ -601,7 +601,7 @@ describe('definePrompt - dotprompt', () => { const response = hi({ name: 'Genkit' }); await assert.rejects(response, { - message: 'NOT_FOUND: Model modelThatDoesNotExist not found', + message: "NOT_FOUND: Model 'modelThatDoesNotExist' not found", }); }); });