Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions js/ai/src/generate.ts
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ import { z } from 'zod';
import { DocumentData } from './document.js';
import { extractJson } from './extract.js';
import {
generateAction,
generateHelper,
GenerateUtilParamSchema,
inferRoleFromParts,
} from './generateAction.js';
Expand All @@ -41,6 +41,7 @@ import {
MessageData,
ModelAction,
ModelArgument,
ModelMiddleware,
ModelReference,
Part,
ToolDefinition,
Expand Down Expand Up @@ -490,6 +491,8 @@ export interface GenerateOptions<
returnToolRequests?: boolean;
/** When provided, models supporting streaming will call the provided callback with chunks as generation progresses. */
streamingCallback?: StreamingCallback<GenerateResponseChunk>;
/** Middleware to be used with this model call. */
use?: ModelMiddleware[];
}

async function resolveModel(options: GenerateOptions): Promise<ModelAction> {
Expand Down Expand Up @@ -612,7 +615,7 @@ export async function generate<
resolvedOptions.streamingCallback,
async () =>
new GenerateResponse<O>(
await generateAction(params),
await generateHelper(params, resolvedOptions.use),
await toGenerateRequest(resolvedOptions)
)
);
Expand Down
286 changes: 171 additions & 115 deletions js/ai/src/generateAction.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import {
Action,
defineAction,
getStreamingCallback,
Middleware,
runWithStreamingCallback,
} from '@genkit-ai/core';
import { lookupAction } from '@genkit-ai/core/registry';
Expand All @@ -26,6 +27,7 @@ import {
toJsonSchema,
validateSchema,
} from '@genkit-ai/core/schema';
import { runInNewSpan, SPAN_TYPE_ATTR } from '@genkit-ai/core/tracing';
import { z } from 'zod';
import { DocumentDataSchema } from './document.js';
import {
Expand All @@ -37,7 +39,9 @@ import {
import {
CandidateData,
GenerateRequest,
GenerateRequestSchema,
GenerateResponseChunkData,
GenerateResponseData,
GenerateResponseSchema,
MessageData,
MessageSchema,
Expand Down Expand Up @@ -85,141 +89,193 @@ export const generateAction = defineAction(
inputSchema: GenerateUtilParamSchema,
outputSchema: GenerateResponseSchema,
},
async (input) => {
const model = (await lookupAction(`/model/${input.model}`)) as ModelAction;
if (!model) {
throw new Error(`Model ${input.model} not found`);
async (input) => generate(input)
);

/**
* Encapsulates all generate logic. This is similar to `generateAction` except not an action and can take middleware.
*/
export async function generateHelper(
input: z.infer<typeof GenerateUtilParamSchema>,
middleware?: Middleware[]
): Promise<GenerateResponseData> {
// do tracing
return await runInNewSpan(
{
metadata: {
name: 'generate',
},
labels: {
[SPAN_TYPE_ATTR]: 'helper',
},
},
async (metadata) => {
metadata.name = 'generate';
metadata.input = input;
const output = await generate(input, middleware);
metadata.output = JSON.stringify(output);
return output;
}
);
}

let tools: ToolAction[] | undefined;
if (input.tools?.length) {
if (!model.__action.metadata?.model.supports?.tools) {
throw new Error(
`Model ${input.model} does not support tools, but some tools were supplied to generate(). Please call generate() without tools if you would like to use this model.`
);
}
tools = await Promise.all(
input.tools.map(async (toolRef) => {
if (typeof toolRef === 'string') {
const tool = (await lookupAction(toolRef)) as ToolAction;
if (!tool) {
throw new Error(`Tool ${toolRef} not found`);
}
return tool;
}
throw '';
})
async function generate(
input: z.infer<typeof GenerateUtilParamSchema>,
middleware?: Middleware[]
): Promise<GenerateResponseData> {
const model = (await lookupAction(`/model/${input.model}`)) as ModelAction;
if (!model) {
throw new Error(`Model ${input.model} not found`);
}

let tools: ToolAction[] | undefined;
if (input.tools?.length) {
if (!model.__action.metadata?.model.supports?.tools) {
throw new Error(
`Model ${input.model} does not support tools, but some tools were supplied to generate(). Please call generate() without tools if you would like to use this model.`
);
}
tools = await Promise.all(
input.tools.map(async (toolRef) => {
if (typeof toolRef === 'string') {
const tool = (await lookupAction(toolRef)) as ToolAction;
if (!tool) {
throw new Error(`Tool ${toolRef} not found`);
}
return tool;
}
throw '';
})
);
}

const request = await actionToGenerateRequest(input, tools);
const request = await actionToGenerateRequest(input, tools);

const accumulatedChunks: GenerateResponseChunkData[] = [];
const accumulatedChunks: GenerateResponseChunkData[] = [];

const streamingCallback = getStreamingCallback();
const response = await runWithStreamingCallback(
streamingCallback
? (chunk: GenerateResponseChunkData) => {
// Store accumulated chunk data
accumulatedChunks.push(chunk);
if (streamingCallback) {
streamingCallback!(
new GenerateResponseChunk(chunk, accumulatedChunks)
);
}
const streamingCallback = getStreamingCallback();
const response = await runWithStreamingCallback(
streamingCallback
? (chunk: GenerateResponseChunkData) => {
// Store accumulated chunk data
accumulatedChunks.push(chunk);
if (streamingCallback) {
streamingCallback!(
new GenerateResponseChunk(chunk, accumulatedChunks)
);
}
: undefined,
async () => new GenerateResponse(await model(request))
);
}
: undefined,
async () => {
const dispatch = async (
index: number,
req: z.infer<typeof GenerateRequestSchema>
) => {
if (!middleware || index === middleware.length) {
// end of the chain, call the original model action
return await model(req);
}

// throw NoValidCandidates if all candidates are blocked or
if (
!response.candidates.some((c) =>
['stop', 'length'].includes(c.finishReason)
)
) {
throw new NoValidCandidatesError({
message: `All candidates returned finishReason issues: ${JSON.stringify(response.candidates.map((c) => c.finishReason))}`,
response,
});
const currentMiddleware = middleware[index];
return currentMiddleware(req, async (modifiedReq) =>
dispatch(index + 1, modifiedReq || req)
);
};

return new GenerateResponse(await dispatch(0, request));
}
);

if (input.output?.jsonSchema && !response.toolRequests()?.length) {
// find a candidate with valid output schema
const candidateErrors = response.candidates.map((c) => {
// don't validate messages that have no text or data
if (c.text() === '' && c.data() === null) return null;
// throw NoValidCandidates if all candidates are blocked or
if (
!response.candidates.some((c) =>
['stop', 'length'].includes(c.finishReason)
)
) {
throw new NoValidCandidatesError({
message: `All candidates returned finishReason issues: ${JSON.stringify(response.candidates.map((c) => c.finishReason))}`,
response,
});
}

try {
parseSchema(c.output(), {
jsonSchema: input.output?.jsonSchema,
});
return null;
} catch (e) {
return e as Error;
}
});
// if all candidates have a non-null error...
if (candidateErrors.every((c) => !!c)) {
throw new NoValidCandidatesError({
message: `Generation resulted in no candidates matching provided output schema.${candidateErrors.map((e, i) => `\n\nCandidate[${i}] ${e!.toString()}`)}`,
response,
detail: {
candidateErrors: candidateErrors,
},
if (input.output?.jsonSchema && !response.toolRequests()?.length) {
// find a candidate with valid output schema
const candidateErrors = response.candidates.map((c) => {
// don't validate messages that have no text or data
if (c.text() === '' && c.data() === null) return null;

try {
parseSchema(c.output(), {
jsonSchema: input.output?.jsonSchema,
});
return null;
} catch (e) {
return e as Error;
}
});
// if all candidates have a non-null error...
if (candidateErrors.every((c) => !!c)) {
throw new NoValidCandidatesError({
message: `Generation resulted in no candidates matching provided output schema.${candidateErrors.map((e, i) => `\n\nCandidate[${i}] ${e!.toString()}`)}`,
response,
detail: {
candidateErrors: candidateErrors,
},
});
}
}

// Pick the first valid candidate.
let selected: Candidate<any> | undefined;
for (const candidate of response.candidates) {
if (isValidCandidate(candidate, tools || [])) {
selected = candidate;
break;
}
// Pick the first valid candidate.
let selected: Candidate<any> | undefined;
for (const candidate of response.candidates) {
if (isValidCandidate(candidate, tools || [])) {
selected = candidate;
break;
}
}

if (!selected) {
throw new Error('No valid candidates found');
}
if (!selected) {
throw new NoValidCandidatesError({
message: 'No valid candidates found',
response,
});
}

const toolCalls = selected.message.content.filter(
(part) => !!part.toolRequest
);
if (input.returnToolRequests || toolCalls.length === 0) {
return response.toJSON();
}
const toolResponses: ToolResponsePart[] = await Promise.all(
toolCalls.map(async (part) => {
if (!part.toolRequest) {
throw Error(
'Tool request expected but not provided in tool request part'
);
}
const tool = tools?.find(
(tool) => tool.__action.name === part.toolRequest?.name
);
if (!tool) {
throw Error('Tool not found');
}
return {
toolResponse: {
name: part.toolRequest.name,
ref: part.toolRequest.ref,
output: await tool(part.toolRequest?.input),
},
};
})
);
const nextRequest = {
...input,
history: [...request.messages, selected.message],
prompt: toolResponses,
};
return await generateAction(nextRequest);
const toolCalls = selected.message.content.filter(
(part) => !!part.toolRequest
);
if (input.returnToolRequests || toolCalls.length === 0) {
return response.toJSON();
}
);
const toolResponses: ToolResponsePart[] = await Promise.all(
toolCalls.map(async (part) => {
if (!part.toolRequest) {
throw Error(
'Tool request expected but not provided in tool request part'
);
}
const tool = tools?.find(
(tool) => tool.__action.name === part.toolRequest?.name
);
if (!tool) {
throw Error('Tool not found');
}
return {
toolResponse: {
name: part.toolRequest.name,
ref: part.toolRequest.ref,
output: await tool(part.toolRequest?.input),
},
};
})
);
const nextRequest = {
...input,
history: [...request.messages, selected.message],
prompt: toolResponses,
};
return await generateHelper(nextRequest, middleware);
}

async function actionToGenerateRequest(
options: z.infer<typeof GenerateUtilParamSchema>,
Expand Down
1 change: 1 addition & 0 deletions js/ai/src/model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,7 @@ export function defineModel<
configSchema?: CustomOptionsSchema;
/** Descriptive name for this model e.g. 'Google AI - Gemini Pro'. */
label?: string;
/** Middleware to be used with this model. */
use?: ModelMiddleware[];
},
runner: (
Expand Down
Loading