Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 37 additions & 19 deletions js/ai/src/generate.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import {
runWithContext,
runWithStreamingCallback,
sentinelNoopStreamingCallback,
stripUndefinedProps,
z,
} from '@genkit-ai/core';
import { Channel } from '@genkit-ai/core/async';
Expand All @@ -33,7 +34,7 @@ import {
resolveFormat,
resolveInstructions,
} from './formats/index.js';
import { GenerateUtilParamSchema, generateHelper } from './generate/action.js';
import { GenerateActionOptions, generateHelper } from './generate/action.js';
import { GenerateResponseChunk } from './generate/chunk.js';
import { GenerateResponse } from './generate/response.js';
import { Message } from './message.js';
Expand Down Expand Up @@ -139,11 +140,12 @@ export interface GenerateOptions<
context?: ActionContext;
}

function applyResumeOption(
/** Amends message history to handle `resume` arguments. Returns the amended history. */
async function applyResumeOption(
options: GenerateOptions,
messages: MessageData[]
): MessageData[] {
if (!options.resume) return [];
): Promise<MessageData[]> {
if (!options.resume) return messages;
if (
messages.at(-1)?.role !== 'model' ||
!messages
Expand All @@ -159,29 +161,37 @@ function applyResumeOption(
const toolRequests = lastModelMessage.content.filter((p) => !!p.toolRequest);

const pendingResponses: ToolResponsePart[] = toolRequests
.filter((t) => !!t.metadata?.pendingToolResponse)
.map((t) => ({
toolResponse: t.metadata!.pendingToolResponse,
})) as ToolResponsePart[];
.filter((t) => !!t.metadata?.pendingOutput)
.map((t) =>
stripUndefinedProps({
toolResponse: {
name: t.toolRequest!.name,
ref: t.toolRequest!.ref,
output: t.metadata!.pendingOutput,
},
metadata: { source: 'pending' },
})
) as ToolResponsePart[];

const reply = Array.isArray(options.resume.reply)
? options.resume.reply
: [options.resume.reply];

const message: MessageData = {
role: 'tool',
content: [...pendingResponses, ...reply],
metadata: {
resume: options.resume.metadata || true,
},
};
return [message];
return [...messages, message];
}

export async function toGenerateRequest(
registry: Registry,
options: GenerateOptions
): Promise<GenerateRequest> {
const messages: MessageData[] = [];
let messages: MessageData[] = [];
if (options.system) {
messages.push({
role: 'system',
Expand All @@ -192,7 +202,7 @@ export async function toGenerateRequest(
messages.push(...options.messages.map((m) => Message.parseData(m)));
}
// resuming from interrupts occurs after message history but before user prompt
messages.push(...applyResumeOption(options, messages));
messages = await applyResumeOption(options, messages);
if (options.prompt) {
messages.push({
role: 'user',
Expand Down Expand Up @@ -346,12 +356,21 @@ export async function generate<
jsonSchema: resolvedOptions.output?.jsonSchema,
});

// If a schema is set but the format is not explicitly set, default to the `json` format.
if (resolvedOptions.output?.schema && !resolvedOptions.output?.format) {
resolvedOptions.output.format = 'json';
}
const resolvedFormat = await resolveFormat(registry, resolvedOptions.output);
const instructions = resolveInstructions(
resolvedFormat,
resolvedSchema,
resolvedOptions?.output?.instructions
);

const params: z.infer<typeof GenerateUtilParamSchema> = {
const params: GenerateActionOptions = {
model: resolvedModel.modelAction.__action.name,
docs: resolvedOptions.docs,
messages: messages,
messages: injectInstructions(messages, instructions),
tools,
toolChoice: resolvedOptions.toolChoice,
config: {
Expand All @@ -371,15 +390,14 @@ export async function generate<
registry,
stripNoop(resolvedOptions.onChunk ?? resolvedOptions.streamingCallback),
async () => {
const generateFn = () =>
generateHelper(registry, {
rawRequest: params,
middleware: resolvedOptions.use,
});
const response = await runWithContext(
registry,
resolvedOptions.context,
generateFn
() =>
generateHelper(registry, {
rawRequest: params,
middleware: resolvedOptions.use,
})
);
const request = await toGenerateRequest(registry, {
...resolvedOptions,
Expand Down
Loading
Loading