diff --git a/js/ai/src/formats/array.ts b/js/ai/src/formats/array.ts index ec32a5be74..2df3a4777e 100644 --- a/js/ai/src/formats/array.ts +++ b/js/ai/src/formats/array.ts @@ -18,47 +18,48 @@ import { GenkitError } from '@genkit-ai/core'; import { extractItems } from '../extract'; import type { Formatter } from './types'; -export const arrayParser: Formatter = (request) => { - if (request.output?.schema && request.output?.schema.type !== 'array') { - throw new GenkitError({ - status: 'INVALID_ARGUMENT', - message: `Must supply an 'array' schema type when using the 'items' parser format.`, - }); - } +export const arrayFormatter: Formatter = { + name: 'array', + config: { + contentType: 'application/json', + constrained: true, + }, + handler: (request) => { + if (request.output?.schema && request.output?.schema.type !== 'array') { + throw new GenkitError({ + status: 'INVALID_ARGUMENT', + message: `Must supply an 'array' schema type when using the 'items' parser format.`, + }); + } - let instructions: boolean | string = false; - if (request.output?.schema) { - instructions = `Output should be a JSON array conforming to the following schema: + let instructions: string | undefined; + if (request.output?.schema) { + instructions = `Output should be a JSON array conforming to the following schema: \`\`\` ${JSON.stringify(request.output!.schema!)} \`\`\` `; - } + } - let cursor: number = 0; + return { + parseChunk: (chunk) => { + // first, determine the cursor position from the previous chunks + const cursor = chunk.previousChunks?.length + ? extractItems(chunk.previousText).cursor + : 0; + // then, extract the items starting at that cursor + const { items } = extractItems(chunk.accumulatedText, cursor); - return { - parseChunk: (chunk, emit) => { - const { items, cursor: newCursor } = extractItems( - chunk.accumulatedText, - cursor - ); + return items; + }, - // Emit any complete items - for (const item of items) { - emit(item); - } + parseResponse: (response) => { + const { items } = extractItems(response.text, 0); + return items; + }, - // Update cursor position - cursor = newCursor; - }, - - parseResponse: (response) => { - const { items } = extractItems(response.text, 0); - return items; - }, - - instructions, - }; + instructions, + }; + }, }; diff --git a/js/ai/src/formats/enum.ts b/js/ai/src/formats/enum.ts index dd4744c28f..3aad6eb2be 100644 --- a/js/ai/src/formats/enum.ts +++ b/js/ai/src/formats/enum.ts @@ -17,24 +17,31 @@ import { GenkitError } from '@genkit-ai/core'; import type { Formatter } from './types'; -export const enumParser: Formatter = (request) => { - const schemaType = request.output?.schema?.type; - if (schemaType && schemaType !== 'string' && schemaType !== 'enum') { - throw new GenkitError({ - status: 'INVALID_ARGUMENT', - message: `Must supply a 'string' or 'enum' schema type when using the enum parser format.`, - }); - } +export const enumFormatter: Formatter = { + name: 'enum', + config: { + contentType: 'text/plain', + constrained: true, + }, + handler: (request) => { + const schemaType = request.output?.schema?.type; + if (schemaType && schemaType !== 'string' && schemaType !== 'enum') { + throw new GenkitError({ + status: 'INVALID_ARGUMENT', + message: `Must supply a 'string' or 'enum' schema type when using the enum parser format.`, + }); + } - let instructions: boolean | string = false; - if (request.output?.schema?.enum) { - instructions = `Output should be ONLY one of the following enum values. Do not output any additional information or add quotes.\n\n${request.output?.schema?.enum.map((v) => v.toString()).join('\n')}`; - } + let instructions: string | undefined; + if (request.output?.schema?.enum) { + instructions = `Output should be ONLY one of the following enum values. Do not output any additional information or add quotes.\n\n${request.output?.schema?.enum.map((v) => v.toString()).join('\n')}`; + } - return { - parseResponse: (response) => { - return response.text.trim(); - }, - instructions, - }; + return { + parseResponse: (response) => { + return response.text.trim(); + }, + instructions, + }; + }, }; diff --git a/js/ai/src/formats/index.ts b/js/ai/src/formats/index.ts index 567e38efd3..40189e4ba7 100644 --- a/js/ai/src/formats/index.ts +++ b/js/ai/src/formats/index.ts @@ -15,26 +15,20 @@ */ import { Registry } from '@genkit-ai/core/registry'; -import { arrayParser } from './array'; -import { enumParser } from './enum'; -import { jsonParser } from './json'; -import { jsonlParser } from './jsonl'; -import { textParser } from './text'; +import { arrayFormatter } from './array'; +import { enumFormatter } from './enum'; +import { jsonFormatter } from './json'; +import { jsonlFormatter } from './jsonl'; +import { textFormatter } from './text'; import { Formatter } from './types'; -export const DEFAULT_FORMATS = { - json: jsonParser, - array: arrayParser, - text: textParser, - enum: enumParser, - jsonl: jsonlParser, -}; - export function defineFormat( registry: Registry, - name: string, - formatter: Formatter + options: { name: string } & Formatter['config'], + handler: Formatter['handler'] ) { + const { name, ...config } = options; + const formatter = { config, handler }; registry.registerValue('format', name, formatter); return formatter; } @@ -53,3 +47,24 @@ export async function resolveFormat( } return arg as Formatter; } + +export const DEFAULT_FORMATS: Formatter[] = [ + jsonFormatter, + arrayFormatter, + textFormatter, + enumFormatter, + jsonlFormatter, +]; + +/** + * initializeFormats registers the default built-in formats on a registry. + */ +export function initializeFormats(registry: Registry) { + for (const format of DEFAULT_FORMATS) { + defineFormat( + registry, + { name: format.name, ...format.config }, + format.handler + ); + } +} diff --git a/js/ai/src/formats/json.ts b/js/ai/src/formats/json.ts index fe0d0bdad7..61fe6e8c1f 100644 --- a/js/ai/src/formats/json.ts +++ b/js/ai/src/formats/json.ts @@ -17,29 +17,34 @@ import { extractJson } from '../extract'; import type { Formatter } from './types'; -export const jsonParser: Formatter = (request) => { - let accumulatedText: string = ''; - let instructions: boolean | string = false; +export const jsonFormatter: Formatter = { + name: 'json', + config: { + contentType: 'application/json', + constrained: true, + }, + handler: (request) => { + let instructions: string | undefined; - if (request.output?.schema) { - instructions = `Output should be in JSON format and conform to the following schema: + if (request.output?.schema) { + instructions = `Output should be in JSON format and conform to the following schema: \`\`\` ${JSON.stringify(request.output!.schema!)} \`\`\` `; - } + } - return { - parseChunk: (chunk, emit) => { - accumulatedText = chunk.accumulatedText; - emit(extractJson(accumulatedText)); - }, + return { + parseChunk: (chunk) => { + return extractJson(chunk.accumulatedText); + }, - parseResponse: (response) => { - return extractJson(response.text); - }, + parseResponse: (response) => { + return extractJson(response.text); + }, - instructions, - }; + instructions, + }; + }, }; diff --git a/js/ai/src/formats/jsonl.ts b/js/ai/src/formats/jsonl.ts index a186faa7f9..d208987149 100644 --- a/js/ai/src/formats/jsonl.ts +++ b/js/ai/src/formats/jsonl.ts @@ -26,57 +26,75 @@ function objectLines(text: string): string[] { .filter((line) => line.startsWith('{')); } -export const jsonlParser: Formatter = (request) => { - if ( - request.output?.schema && - (request.output?.schema.type !== 'array' || - request.output?.schema.items?.type !== 'object') - ) { - throw new GenkitError({ - status: 'INVALID_ARGUMENT', - message: `Must supply an 'array' schema type containing 'object' items when using the 'jsonl' parser format.`, - }); - } +export const jsonlFormatter: Formatter = { + name: 'jsonl', + config: { + contentType: 'application/jsonl', + }, + handler: (request) => { + if ( + request.output?.schema && + (request.output?.schema.type !== 'array' || + request.output?.schema.items?.type !== 'object') + ) { + throw new GenkitError({ + status: 'INVALID_ARGUMENT', + message: `Must supply an 'array' schema type containing 'object' items when using the 'jsonl' parser format.`, + }); + } - let instructions: boolean | string = false; - if (request.output?.schema?.items) { - instructions = `Output should be JSONL format, a sequence of JSON objects (one per line). Each line should conform to the following schema: + let instructions: string | undefined; + if (request.output?.schema?.items) { + instructions = `Output should be JSONL format, a sequence of JSON objects (one per line). Each line should conform to the following schema: \`\`\` ${JSON.stringify(request.output.schema.items)} \`\`\` `; - } + } - let cursor = 0; + return { + parseChunk: (chunk) => { + const results: unknown[] = []; - return { - parseChunk: (chunk, emit) => { - const jsonLines = objectLines(chunk.accumulatedText); + const text = chunk.accumulatedText; - for (let i = cursor; i < jsonLines.length; i++) { - try { - const result = JSON5.parse(jsonLines[i]); - if (result) { - emit(result); + let startIndex = 0; + if (chunk.previousChunks?.length) { + const lastNewline = chunk.previousText.lastIndexOf('\n'); + if (lastNewline !== -1) { + startIndex = lastNewline + 1; } - } catch (e) { - cursor = i; - return; } - } - cursor = jsonLines.length; - }, + const lines = text.slice(startIndex).split('\n'); - parseResponse: (response) => { - const items = objectLines(response.text) - .map((l) => extractJson(l)) - .filter((l) => !!l); + for (const line of lines) { + const trimmed = line.trim(); + if (trimmed.startsWith('{')) { + try { + const result = JSON5.parse(trimmed); + if (result) { + results.push(result); + } + } catch (e) { + break; + } + } + } + + return results; + }, + + parseResponse: (response) => { + const items = objectLines(response.text) + .map((l) => extractJson(l)) + .filter((l) => !!l); - return items; - }, + return items; + }, - instructions, - }; + instructions, + }; + }, }; diff --git a/js/ai/src/formats/text.ts b/js/ai/src/formats/text.ts index f968693893..985b64a961 100644 --- a/js/ai/src/formats/text.ts +++ b/js/ai/src/formats/text.ts @@ -14,17 +14,22 @@ * limitations under the License. */ -import { GenerateResponse, GenerateResponseChunk } from '../generate'; import type { Formatter } from './types'; -export const textParser: Formatter = (request) => { - return { - parseChunk: (chunk: GenerateResponseChunk, emit: (chunk: any) => void) => { - emit(chunk.text); - }, +export const textFormatter: Formatter = { + name: 'text', + config: { + contentType: 'text/plain', + }, + handler: () => { + return { + parseChunk: (chunk) => { + return chunk.text; + }, - parseResponse: (response: GenerateResponse) => { - return response.text; - }, - }; + parseResponse: (response) => { + return response.text; + }, + }; + }, }; diff --git a/js/ai/src/formats/types.d.ts b/js/ai/src/formats/types.d.ts index 3f5736bc40..e8847e572b 100644 --- a/js/ai/src/formats/types.d.ts +++ b/js/ai/src/formats/types.d.ts @@ -14,16 +14,20 @@ * limitations under the License. */ -import { GenerateResponse, GenerateResponseChunk } from '../generate'; -import { GenerateRequest } from '../model'; +import { GenerateResponse, GenerateResponseChunk } from '../generate.js'; +import { ModelRequest, Part } from '../model.js'; -export interface Formatter { - (req: GenerateRequest): { - parseChunk?: ( - chunk: GenerateResponseChunk, - emit: (chunk: any) => void - ) => void; - parseResponse(response: GenerateResponse): any; - instructions?: boolean | string; +type OutputContentTypes = + | 'application/json' + | 'text/plain' + | 'application/jsonl'; + +export interface Formatter { + name: string; + config: ModelRequest['output']; + handler: (req: ModelRequest) => { + parseResponse(response: GenerateResponse): O; + parseChunk?: (chunk: GenerateResponseChunk, cursor?: CC) => CO; + instructions?: string | Part[]; }; } diff --git a/js/ai/src/generate.ts b/js/ai/src/generate.ts index bbad7cb8ab..b7b6289d40 100755 --- a/js/ai/src/generate.ts +++ b/js/ai/src/generate.ts @@ -22,282 +22,57 @@ import { z, } from '@genkit-ai/core'; import { Registry } from '@genkit-ai/core/registry'; -import { parseSchema, toJsonSchema } from '@genkit-ai/core/schema'; +import { toJsonSchema } from '@genkit-ai/core/schema'; import { DocumentData } from './document.js'; -import { extractJson } from './extract.js'; -import { generateHelper, GenerateUtilParamSchema } from './generateAction.js'; +import { generateHelper, GenerateUtilParamSchema } from './generate/action.js'; +import { GenerateResponseChunk } from './generate/chunk.js'; +import { GenerateResponse } from './generate/response.js'; import { Message } from './message.js'; import { GenerateRequest, - GenerateResponseChunkData, - GenerateResponseData, GenerationCommonConfigSchema, - GenerationUsage, MessageData, ModelAction, ModelArgument, ModelMiddleware, ModelReference, - ModelResponseData, Part, ToolDefinition, - ToolRequestPart, } from './model.js'; import { ExecutablePrompt } from './prompt.js'; import { resolveTools, ToolArgument, toToolDefinition } from './tool.js'; +export { GenerateResponse, GenerateResponseChunk }; -/** - * GenerateResponse is the result from a `generate()` call and contains one or - * more generated candidate messages. - */ -export class GenerateResponse implements ModelResponseData { - /** The generated message. */ - message?: Message; - /** The reason generation stopped for this request. */ - finishReason: ModelResponseData['finishReason']; - /** Additional information about why the model stopped generating, if any. */ - finishMessage?: string; - /** Usage information. */ - usage: GenerationUsage; - /** Provider-specific response data. */ - custom: unknown; - /** The request that generated this response. */ - request?: GenerateRequest; - - constructor(response: GenerateResponseData, request?: GenerateRequest) { - // Check for candidates in addition to message for backwards compatibility. - const generatedMessage = - response.message || response.candidates?.[0]?.message; - if (generatedMessage) { - this.message = new Message(generatedMessage); - } - this.finishReason = - response.finishReason || response.candidates?.[0]?.finishReason!; - this.finishMessage = - response.finishMessage || response.candidates?.[0]?.finishMessage; - this.usage = response.usage || {}; - this.custom = response.custom || {}; - this.request = request; - } - - private get assertMessage(): Message { - if (!this.message) - throw new Error( - 'Operation could not be completed because the response does not contain a generated message.' - ); - return this.message; - } - - /** - * Throws an error if the response does not contain valid output. - */ - assertValid(request?: GenerateRequest): void { - if (this.finishReason === 'blocked') { - throw new GenerationBlockedError( - this, - `Generation blocked${this.finishMessage ? `: ${this.finishMessage}` : '.'}` - ); - } - - if (!this.message) { - throw new GenerationResponseError( - this, - `Model did not generate a message. Finish reason: '${this.finishReason}': ${this.finishMessage}` - ); - } - - if (request?.output?.schema || this.request?.output?.schema) { - const o = this.output; - parseSchema(o, { - jsonSchema: request?.output?.schema || this.request?.output?.schema, - }); - } - } - - isValid(request?: GenerateRequest): boolean { - try { - this.assertValid(request); - return true; - } catch (e) { - return false; - } - } - - /** - * If the selected candidate's message contains a `data` part, it is returned. Otherwise, - * the `output()` method extracts the first valid JSON object or array from the text - * contained in the selected candidate's message and returns it. - * - * @param index The candidate index from which to extract output. If not provided, finds first candidate that conforms to output schema. - * @returns The structured output contained in the selected candidate. - */ - get output(): O | null { - return this.message?.output || null; - } - - /** - * Concatenates all `text` parts present in the candidate's message with no delimiter. - * @param index The candidate index from which to extract text, defaults to first candidate. - * @returns A string of all concatenated text parts. - */ - get text(): string { - return this.message?.text || ''; - } - - /** - * Returns the first detected media part in the selected candidate's message. Useful for - * extracting (for example) an image from a generation expected to create one. - * @param index The candidate index from which to extract media, defaults to first candidate. - * @returns The first detected `media` part in the candidate. - */ - get media(): { url: string; contentType?: string } | null { - return this.message?.media || null; - } - - /** - * Returns the first detected `data` part of the selected candidate's message. - * @param index The candidate index from which to extract data, defaults to first candidate. - * @returns The first `data` part detected in the candidate (if any). - */ - get data(): O | null { - return this.message?.data || null; - } - - /** - * Returns all tool request found in the candidate. - * @param index The candidate index from which to extract tool requests, defaults to first candidate. - * @returns Array of all tool request found in the candidate. - */ - get toolRequests(): ToolRequestPart[] { - return this.message?.toolRequests || []; - } - - /** - * Appends the message generated by the selected candidate to the messages already - * present in the generation request. The result of this method can be safely - * serialized to JSON for persistence in a database. - * @param index The candidate index to utilize during conversion, defaults to first candidate. - * @returns A serializable list of messages compatible with `generate({history})`. - */ - get messages(): MessageData[] { - if (!this.request) - throw new Error( - "Can't construct history for response without request reference." - ); - if (!this.message) - throw new Error( - "Can't construct history for response without generated message." - ); - return [...this.request?.messages, this.message.toJSON()]; - } - - get raw(): unknown { - return this.raw ?? this.custom; - } - - toJSON(): ModelResponseData { - const out = { - message: this.message?.toJSON(), - finishReason: this.finishReason, - finishMessage: this.finishMessage, - usage: this.usage, - custom: (this.custom as { toJSON?: () => any }).toJSON?.() || this.custom, - request: this.request, - }; - if (!out.finishMessage) delete out.finishMessage; - if (!out.request) delete out.request; - return out; - } -} - -export class GenerateResponseChunk - implements GenerateResponseChunkData -{ - /** The index of the candidate this chunk corresponds to. */ - index?: number; - /** The content generated in this chunk. */ - content: Part[]; - /** Custom model-specific data for this chunk. */ - custom?: unknown; - /** Accumulated chunks for partial output extraction. */ - accumulatedChunks?: GenerateResponseChunkData[]; - - constructor( - data: GenerateResponseChunkData, - accumulatedChunks?: GenerateResponseChunkData[] - ) { - this.index = data.index; - this.content = data.content || []; - this.custom = data.custom; - this.accumulatedChunks = accumulatedChunks; - } - - /** - * Concatenates all `text` parts present in the chunk with no delimiter. - * @returns A string of all concatenated text parts. - */ - get text(): string { - return this.content.map((part) => part.text || '').join(''); - } - - /** - * Concatenates all `text` parts of all chunks from the response thus far. - * @returns A string of all concatenated chunk text content. - */ - get accumulatedText(): string { - if (!this.accumulatedChunks) - throw new GenkitError({ - status: 'FAILED_PRECONDITION', - message: 'Cannot compose accumulated text without accumulated chunks.', - }); - - return this.accumulatedChunks - ?.map((c) => c.content.map((p) => p.text || '').join('')) - .join(''); - } - - /** - * Returns the first media part detected in the chunk. Useful for extracting - * (for example) an image from a generation expected to create one. - * @returns The first detected `media` part in the chunk. - */ - get media(): { url: string; contentType?: string } | null { - return this.content.find((part) => part.media)?.media || null; - } - - /** - * Returns the first detected `data` part of a chunk. - * @returns The first `data` part detected in the chunk (if any). - */ - get data(): T | null { - return this.content.find((part) => part.data)?.data as T | null; - } - - /** - * Returns all tool request found in this chunk. - * @returns Array of all tool request found in this chunk. - */ - get toolRequests(): ToolRequestPart[] { - return this.content.filter( - (part) => !!part.toolRequest - ) as ToolRequestPart[]; - } - - /** - * Attempts to extract the longest valid JSON substring from the accumulated chunks. - * @returns The longest valid JSON substring found in the accumulated chunks. - */ - get output(): T | null { - if (!this.accumulatedChunks) return null; - const accumulatedText = this.accumulatedChunks - .map((chunk) => chunk.content.map((part) => part.text || '').join('')) - .join(''); - return extractJson(accumulatedText, false); - } - - toJSON(): GenerateResponseChunkData { - return { index: this.index, content: this.content, custom: this.custom }; - } +export interface GenerateOptions< + O extends z.ZodTypeAny = z.ZodTypeAny, + CustomOptions extends z.ZodTypeAny = z.ZodTypeAny, +> { + /** A model name (e.g. `vertexai/gemini-1.0-pro`) or reference. */ + model?: ModelArgument; + /** The system prompt to be included in the generate request. Can be a string for a simple text prompt or one or more parts for multi-modal prompts (subject to model support). */ + system?: string | Part | Part[]; + /** The prompt for which to generate a response. Can be a string for a simple text prompt or one or more parts for multi-modal prompts. */ + prompt?: string | Part | Part[]; + /** Retrieved documents to be used as context for this generation. */ + docs?: DocumentData[]; + /** Conversation messages (history) for multi-turn prompting when supported by the underlying model. */ + messages?: (MessageData & { content: Part[] | string | (string | Part)[] })[]; + /** List of registered tool names or actions to treat as a tool for this generation if supported by the underlying model. */ + tools?: ToolArgument[]; + /** Configuration for the generation request. */ + config?: z.infer; + /** Configuration for the desired output of the request. Defaults to the model's default output if unspecified. */ + output?: { + format?: 'json' | 'text' | 'media'; + schema?: O; + jsonSchema?: any; + }; + /** When true, return tool calls for manual processing instead of automatically resolving them. */ + returnToolRequests?: boolean; + /** When provided, models supporting streaming will call the provided callback with chunks as generation progresses. */ + streamingCallback?: StreamingCallback; + /** Middleware to be used with this model call. */ + use?: ModelMiddleware[]; } export async function toGenerateRequest( @@ -349,38 +124,6 @@ export async function toGenerateRequest( return out; } -export interface GenerateOptions< - O extends z.ZodTypeAny = z.ZodTypeAny, - CustomOptions extends z.ZodTypeAny = z.ZodTypeAny, -> { - /** A model name (e.g. `vertexai/gemini-1.0-pro`) or reference. */ - model?: ModelArgument; - /** The system prompt to be included in the generate request. Can be a string for a simple text prompt or one or more parts for multi-modal prompts (subject to model support). */ - system?: string | Part | Part[]; - /** The prompt for which to generate a response. Can be a string for a simple text prompt or one or more parts for multi-modal prompts. */ - prompt?: string | Part | Part[]; - /** Retrieved documents to be used as context for this generation. */ - docs?: DocumentData[]; - /** Conversation messages (history) for multi-turn prompting when supported by the underlying model. */ - messages?: (MessageData & { content: Part[] | string | (string | Part)[] })[]; - /** List of registered tool names or actions to treat as a tool for this generation if supported by the underlying model. */ - tools?: ToolArgument[]; - /** Configuration for the generation request. */ - config?: z.infer; - /** Configuration for the desired output of the request. Defaults to the model's default output if unspecified. */ - output?: { - format?: 'json' | 'text' | 'media'; - schema?: O; - jsonSchema?: any; - }; - /** When true, return tool calls for manual processing instead of automatically resolving them. */ - returnToolRequests?: boolean; - /** When provided, models supporting streaming will call the provided callback with chunks as generation progresses. */ - streamingCallback?: StreamingCallback; - /** Middleware to be used with this model call. */ - use?: ModelMiddleware[]; -} - interface ResolvedModel { modelAction: ModelAction; config?: z.infer; diff --git a/js/ai/src/generateAction.ts b/js/ai/src/generate/action.ts similarity index 97% rename from js/ai/src/generateAction.ts rename to js/ai/src/generate/action.ts index f75a6f7619..7a563c701e 100644 --- a/js/ai/src/generateAction.ts +++ b/js/ai/src/generate/action.ts @@ -24,12 +24,12 @@ import { Registry } from '@genkit-ai/core/registry'; import { toJsonSchema } from '@genkit-ai/core/schema'; import { runInNewSpan, SPAN_TYPE_ATTR } from '@genkit-ai/core/tracing'; import * as clc from 'colorette'; -import { DocumentDataSchema } from './document.js'; +import { DocumentDataSchema } from '../document.js'; import { GenerateResponse, GenerateResponseChunk, tagAsPreamble, -} from './generate.js'; +} from '../generate.js'; import { GenerateRequest, GenerateRequestSchema, @@ -42,8 +42,8 @@ import { Role, ToolDefinitionSchema, ToolResponsePart, -} from './model.js'; -import { lookupToolByName, ToolAction, toToolDefinition } from './tool.js'; +} from '../model.js'; +import { lookupToolByName, ToolAction, toToolDefinition } from '../tool.js'; export const GenerateUtilParamSchema = z.object({ /** A model name (e.g. `vertexai/gemini-1.0-pro`). */ @@ -142,12 +142,14 @@ async function generate( streamingCallback ? (chunk: GenerateResponseChunkData) => { // Store accumulated chunk data - accumulatedChunks.push(chunk); if (streamingCallback) { streamingCallback!( - new GenerateResponseChunk(chunk, accumulatedChunks) + new GenerateResponseChunk(chunk, { + previousChunks: accumulatedChunks, + }) ); } + accumulatedChunks.push(chunk); } : undefined, async () => { diff --git a/js/ai/src/generate/chunk.ts b/js/ai/src/generate/chunk.ts new file mode 100644 index 0000000000..4cf20d67a5 --- /dev/null +++ b/js/ai/src/generate/chunk.ts @@ -0,0 +1,135 @@ +/** + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { GenkitError } from '@genkit-ai/core'; +import { extractJson } from '../extract.js'; +import { + GenerateResponseChunkData, + Part, + Role, + ToolRequestPart, +} from '../model.js'; + +export interface ChunkParser { + (chunk: GenerateResponseChunk): T; +} + +export class GenerateResponseChunk + implements GenerateResponseChunkData +{ + /** The index of the message this chunk corresponds to, starting with `0` for the first model response of the generation. */ + index?: number; + /** The role of the message this chunk corresponds to. Will always be `model` or `tool`. */ + role?: Role; + /** The content generated in this chunk. */ + content: Part[]; + /** Custom model-specific data for this chunk. */ + custom?: unknown; + /** Accumulated chunks for partial output extraction. */ + previousChunks?: GenerateResponseChunkData[]; + /** The parser to be used to parse `output` from this chunk. */ + parser?: ChunkParser; + + constructor( + data: GenerateResponseChunkData, + options?: { + previousChunks?: GenerateResponseChunkData[]; + role?: Role; + index?: number; + parser?: ChunkParser; + } + ) { + this.content = data.content || []; + this.custom = data.custom; + this.previousChunks = options?.previousChunks + ? [...options.previousChunks] + : undefined; + this.index = options?.index; + this.role = options?.role; + this.parser = options?.parser; + } + + /** + * Concatenates all `text` parts present in the chunk with no delimiter. + * @returns A string of all concatenated text parts. + */ + get text(): string { + return this.content.map((part) => part.text || '').join(''); + } + + /** + * Concatenates all `text` parts of all chunks from the response thus far. + * @returns A string of all concatenated chunk text content. + */ + get accumulatedText(): string { + return this.previousText + this.text; + } + + /** + * Concatenates all `text` parts of all preceding chunks. + */ + get previousText(): string { + if (!this.previousChunks) + throw new GenkitError({ + status: 'FAILED_PRECONDITION', + message: 'Cannot compose accumulated text without previous chunks.', + }); + + return this.previousChunks + ?.map((c) => c.content.map((p) => p.text || '').join('')) + .join(''); + } + + /** + * Returns the first media part detected in the chunk. Useful for extracting + * (for example) an image from a generation expected to create one. + * @returns The first detected `media` part in the chunk. + */ + get media(): { url: string; contentType?: string } | null { + return this.content.find((part) => part.media)?.media || null; + } + + /** + * Returns the first detected `data` part of a chunk. + * @returns The first `data` part detected in the chunk (if any). + */ + get data(): T | null { + return this.content.find((part) => part.data)?.data as T | null; + } + + /** + * Returns all tool request found in this chunk. + * @returns Array of all tool request found in this chunk. + */ + get toolRequests(): ToolRequestPart[] { + return this.content.filter( + (part) => !!part.toolRequest + ) as ToolRequestPart[]; + } + + /** + * Parses the chunk into the desired output format using the parser associated + * with the generate request, or falls back to naive JSON parsing otherwise. + */ + get output(): T | null { + if (this.parser) return this.parser(this); + return this.data || extractJson(this.accumulatedText); + } + + toJSON(): GenerateResponseChunkData { + return { content: this.content, custom: this.custom }; + } +} diff --git a/js/ai/src/generate/response.ts b/js/ai/src/generate/response.ts new file mode 100644 index 0000000000..1094c49171 --- /dev/null +++ b/js/ai/src/generate/response.ts @@ -0,0 +1,194 @@ +/** + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { parseSchema } from '@genkit-ai/core/schema'; +import { + GenerationBlockedError, + GenerationResponseError, +} from '../generate.js'; +import { Message } from '../message.js'; +import { + GenerateRequest, + GenerateResponseData, + GenerationUsage, + MessageData, + ModelResponseData, + ToolRequestPart, +} from '../model.js'; + +/** + * GenerateResponse is the result from a `generate()` call and contains one or + * more generated candidate messages. + */ +export class GenerateResponse implements ModelResponseData { + /** The generated message. */ + message?: Message; + /** The reason generation stopped for this request. */ + finishReason: ModelResponseData['finishReason']; + /** Additional information about why the model stopped generating, if any. */ + finishMessage?: string; + /** Usage information. */ + usage: GenerationUsage; + /** Provider-specific response data. */ + custom: unknown; + /** The request that generated this response. */ + request?: GenerateRequest; + + constructor(response: GenerateResponseData, request?: GenerateRequest) { + // Check for candidates in addition to message for backwards compatibility. + const generatedMessage = + response.message || response.candidates?.[0]?.message; + if (generatedMessage) { + this.message = new Message(generatedMessage); + } + this.finishReason = + response.finishReason || response.candidates?.[0]?.finishReason!; + this.finishMessage = + response.finishMessage || response.candidates?.[0]?.finishMessage; + this.usage = response.usage || {}; + this.custom = response.custom || {}; + this.request = request; + } + + private get assertMessage(): Message { + if (!this.message) + throw new Error( + 'Operation could not be completed because the response does not contain a generated message.' + ); + return this.message; + } + + /** + * Throws an error if the response does not contain valid output. + */ + assertValid(request?: GenerateRequest): void { + if (this.finishReason === 'blocked') { + throw new GenerationBlockedError( + this, + `Generation blocked${this.finishMessage ? `: ${this.finishMessage}` : '.'}` + ); + } + + if (!this.message) { + throw new GenerationResponseError( + this, + `Model did not generate a message. Finish reason: '${this.finishReason}': ${this.finishMessage}` + ); + } + + if (request?.output?.schema || this.request?.output?.schema) { + const o = this.output; + parseSchema(o, { + jsonSchema: request?.output?.schema || this.request?.output?.schema, + }); + } + } + + isValid(request?: GenerateRequest): boolean { + try { + this.assertValid(request); + return true; + } catch (e) { + return false; + } + } + + /** + * If the selected candidate's message contains a `data` part, it is returned. Otherwise, + * the `output()` method extracts the first valid JSON object or array from the text + * contained in the selected candidate's message and returns it. + * + * @param index The candidate index from which to extract output. If not provided, finds first candidate that conforms to output schema. + * @returns The structured output contained in the selected candidate. + */ + get output(): O | null { + return this.message?.output || null; + } + + /** + * Concatenates all `text` parts present in the candidate's message with no delimiter. + * @param index The candidate index from which to extract text, defaults to first candidate. + * @returns A string of all concatenated text parts. + */ + get text(): string { + return this.message?.text || ''; + } + + /** + * Returns the first detected media part in the selected candidate's message. Useful for + * extracting (for example) an image from a generation expected to create one. + * @param index The candidate index from which to extract media, defaults to first candidate. + * @returns The first detected `media` part in the candidate. + */ + get media(): { url: string; contentType?: string } | null { + return this.message?.media || null; + } + + /** + * Returns the first detected `data` part of the selected candidate's message. + * @param index The candidate index from which to extract data, defaults to first candidate. + * @returns The first `data` part detected in the candidate (if any). + */ + get data(): O | null { + return this.message?.data || null; + } + + /** + * Returns all tool request found in the candidate. + * @param index The candidate index from which to extract tool requests, defaults to first candidate. + * @returns Array of all tool request found in the candidate. + */ + get toolRequests(): ToolRequestPart[] { + return this.message?.toolRequests || []; + } + + /** + * Appends the message generated by the selected candidate to the messages already + * present in the generation request. The result of this method can be safely + * serialized to JSON for persistence in a database. + * @param index The candidate index to utilize during conversion, defaults to first candidate. + * @returns A serializable list of messages compatible with `generate({history})`. + */ + get messages(): MessageData[] { + if (!this.request) + throw new Error( + "Can't construct history for response without request reference." + ); + if (!this.message) + throw new Error( + "Can't construct history for response without generated message." + ); + return [...this.request?.messages, this.message.toJSON()]; + } + + get raw(): unknown { + return this.raw ?? this.custom; + } + + toJSON(): ModelResponseData { + const out = { + message: this.message?.toJSON(), + finishReason: this.finishReason, + finishMessage: this.finishMessage, + usage: this.usage, + custom: (this.custom as { toJSON?: () => any }).toJSON?.() || this.custom, + request: this.request, + }; + if (!out.finishMessage) delete out.finishMessage; + if (!out.request) delete out.request; + return out; + } +} diff --git a/js/ai/src/index.ts b/js/ai/src/index.ts index eb7ea3e416..5b63b9f10f 100644 --- a/js/ai/src/index.ts +++ b/js/ai/src/index.ts @@ -37,6 +37,7 @@ export { } from './evaluator.js'; export { GenerateResponse, + GenerateResponseChunk, GenerationBlockedError, GenerationResponseError, generate, diff --git a/js/ai/src/model.ts b/js/ai/src/model.ts index 4a176e70be..342cb06084 100644 --- a/js/ai/src/model.ts +++ b/js/ai/src/model.ts @@ -186,6 +186,8 @@ export type GenerationCommonConfig = typeof GenerationCommonConfigSchema; const OutputConfigSchema = z.object({ format: OutputFormatSchema.optional(), schema: z.record(z.any()).optional(), + constrained: z.boolean().optional(), + contentType: z.string().optional(), }); export type OutputConfig = z.infer; diff --git a/js/ai/tests/formats/array_test.ts b/js/ai/tests/formats/array_test.ts index bb8a8f8a58..f174d1cd0c 100644 --- a/js/ai/tests/formats/array_test.ts +++ b/js/ai/tests/formats/array_test.ts @@ -16,7 +16,7 @@ import assert from 'node:assert'; import { describe, it } from 'node:test'; -import { arrayParser } from '../../src/formats/array.js'; +import { arrayFormatter } from '../../src/formats/array.js'; import { GenerateResponse, GenerateResponseChunk } from '../../src/generate.js'; import { GenerateResponseChunkData } from '../../src/model.js'; @@ -65,22 +65,21 @@ describe('arrayFormat', () => { for (const st of streamingTests) { it(st.desc, () => { - const parser = arrayParser({ messages: [] }); + const parser = arrayFormatter.handler({ messages: [] }); const chunks: GenerateResponseChunkData[] = []; - let lastEmitted: any[] = []; + let lastCursor = 0; + for (const chunk of st.chunks) { const newChunk: GenerateResponseChunkData = { content: [{ text: chunk.text }], }; - chunks.push(newChunk); - lastEmitted = []; - const emit = (item: any) => { - lastEmitted.push(item); - }; - parser.parseChunk!(new GenerateResponseChunk(newChunk, chunks), emit); + const result = parser.parseChunk!( + new GenerateResponseChunk(newChunk, { previousChunks: chunks }) + ); + chunks.push(newChunk); - assert.deepStrictEqual(lastEmitted, chunk.want); + assert.deepStrictEqual(result, chunk.want); } }); } @@ -122,7 +121,7 @@ describe('arrayFormat', () => { for (const rt of responseTests) { it(rt.desc, () => { - const parser = arrayParser({ messages: [] }); + const parser = arrayFormatter.handler({ messages: [] }); assert.deepStrictEqual(parser.parseResponse(rt.response), rt.want); }); } @@ -153,7 +152,7 @@ describe('arrayFormat', () => { for (const et of errorTests) { it(et.desc, () => { assert.throws(() => { - arrayParser(et.request); + arrayFormatter.handler(et.request); }, et.wantError); }); } diff --git a/js/ai/tests/formats/enum_test.ts b/js/ai/tests/formats/enum_test.ts index 338921821a..283a714264 100644 --- a/js/ai/tests/formats/enum_test.ts +++ b/js/ai/tests/formats/enum_test.ts @@ -14,83 +14,70 @@ * limitations under the License. */ -import { GenkitError } from '@genkit-ai/core'; import assert from 'node:assert'; import { describe, it } from 'node:test'; -import { enumParser } from '../../src/formats/enum.js'; +import { enumFormatter } from '../../src/formats/enum.js'; import { GenerateResponse } from '../../src/generate.js'; describe('enumFormat', () => { const responseTests = [ { - desc: 'parses simple string response', + desc: 'parses simple enum value', response: new GenerateResponse({ message: { role: 'model', - content: [{ text: 'value1' }], + content: [{ text: 'VALUE1' }], }, }), - want: 'value1', + want: 'VALUE1', }, { - desc: 'trims whitespace from response', + desc: 'trims whitespace', response: new GenerateResponse({ message: { role: 'model', - content: [{ text: ' value2 \n' }], + content: [{ text: ' VALUE2\n' }], }, }), - want: 'value2', + want: 'VALUE2', }, ]; for (const rt of responseTests) { it(rt.desc, () => { - const parser = enumParser({ - messages: [], - output: { schema: { type: 'string' } }, - }); + const parser = enumFormatter.handler({ messages: [] }); assert.strictEqual(parser.parseResponse(rt.response), rt.want); }); } - it('throws error for invalid schema type', () => { - assert.throws( - () => { - enumParser({ messages: [], output: { schema: { type: 'number' } } }); + const errorTests = [ + { + desc: 'throws error for number schema type', + request: { + messages: [], + output: { + schema: { type: 'number' }, + }, }, - (err: GenkitError) => { - return ( - err.status === 'INVALID_ARGUMENT' && - err.message.includes( - `Must supply a 'string' or 'enum' schema type when using the enum parser format.` - ) - ); - } - ); - }); - - it('includes enum values in instructions when provided', () => { - const enumValues = ['option1', 'option2', 'option3']; - const parser = enumParser({ - messages: [], - output: { schema: { type: 'enum', enum: enumValues } }, - }); - - assert.match( - parser.instructions as string, - /Output should be ONLY one of the following enum values/ - ); - for (const value of enumValues) { - assert.match(parser.instructions as string, new RegExp(value)); - } - }); + wantError: /Must supply a 'string' or 'enum' schema type/, + }, + { + desc: 'throws error for array schema type', + request: { + messages: [], + output: { + schema: { type: 'array' }, + }, + }, + wantError: /Must supply a 'string' or 'enum' schema type/, + }, + ]; - it('has no instructions when no enum values provided', () => { - const parser = enumParser({ - messages: [], - output: { schema: { type: 'string' } }, + for (const et of errorTests) { + it(et.desc, () => { + assert.throws(() => { + enumFormatter.handler(et.request); + }, et.wantError); }); - assert.strictEqual(parser.instructions, false); - }); + } }); diff --git a/js/ai/tests/formats/json_test.ts b/js/ai/tests/formats/json_test.ts index f5824fa418..833ca790dd 100644 --- a/js/ai/tests/formats/json_test.ts +++ b/js/ai/tests/formats/json_test.ts @@ -16,35 +16,31 @@ import assert from 'node:assert'; import { describe, it } from 'node:test'; -import { jsonParser } from '../../src/formats/json.js'; +import { jsonFormatter } from '../../src/formats/json.js'; import { GenerateResponse, GenerateResponseChunk } from '../../src/generate.js'; import { GenerateResponseChunkData } from '../../src/model.js'; describe('jsonFormat', () => { const streamingTests = [ { - desc: 'emits partial object as it streams', + desc: 'parses complete JSON object', chunks: [ { - text: '{"name": "test', - want: { name: 'test' }, - }, - { - text: '", "value": 42}', - want: { name: 'test', value: 42 }, + text: '{"id": 1, "name": "test"}', + want: { id: 1, name: 'test' }, }, ], }, { - desc: 'handles nested objects', + desc: 'handles partial JSON', chunks: [ { - text: '{"outer": {"inner": ', - want: { outer: {} }, + text: '{"id": 1', + want: { id: 1 }, }, { - text: '"value"}}', - want: { outer: { inner: 'value' } }, + text: ', "name": "test"}', + want: { id: 1, name: 'test' }, }, ], }, @@ -52,25 +48,12 @@ describe('jsonFormat', () => { desc: 'handles preamble with code fence', chunks: [ { - text: 'Here is the JSON:\n\n```json\n{"key": ', - want: {}, + text: 'Here is the JSON:\n\n```json\n', + want: null, }, { - text: '"value"}\n```', - want: { key: 'value' }, - }, - ], - }, - { - desc: 'handles arrays', - chunks: [ - { - text: '[{"id": 1}, {"id"', - want: [{ id: 1 }, {}], - }, - { - text: ': 2}]', - want: [{ id: 1 }, { id: 2 }], + text: '{"id": 1}\n```', + want: { id: 1 }, }, ], }, @@ -78,102 +61,63 @@ describe('jsonFormat', () => { for (const st of streamingTests) { it(st.desc, () => { - const parser = jsonParser({ messages: [] }); + const parser = jsonFormatter.handler({ messages: [] }); const chunks: GenerateResponseChunkData[] = []; - let lastEmitted: any; + let lastCursor = ''; + for (const chunk of st.chunks) { const newChunk: GenerateResponseChunkData = { content: [{ text: chunk.text }], }; - chunks.push(newChunk); - lastEmitted = undefined; - const emit = (value: any) => { - lastEmitted = value; - }; - parser.parseChunk!(new GenerateResponseChunk(newChunk, chunks), emit); + const result = parser.parseChunk!( + new GenerateResponseChunk(newChunk, { previousChunks: [...chunks] }), + lastCursor + ); + chunks.push(newChunk); - assert.deepStrictEqual(lastEmitted, chunk.want); + assert.deepStrictEqual(result, chunk.want); } }); } const responseTests = [ { - desc: 'parses complete object response', - response: new GenerateResponse({ - message: { - role: 'model', - content: [{ text: '{"name": "test", "value": 42}' }], - }, - }), - want: { name: 'test', value: 42 }, - }, - { - desc: 'parses array response', + desc: 'parses complete JSON response', response: new GenerateResponse({ message: { role: 'model', - content: [{ text: '[1, 2, 3]' }], + content: [{ text: '{"id": 1, "name": "test"}' }], }, }), - want: [1, 2, 3], + want: { id: 1, name: 'test' }, }, { - desc: 'parses nested structures', + desc: 'handles empty response', response: new GenerateResponse({ message: { role: 'model', - content: [{ text: '{"outer": {"inner": [1, 2]}}' }], + content: [{ text: '' }], }, }), - want: { outer: { inner: [1, 2] } }, + want: null, }, { - desc: 'parses with preamble and code fence', + desc: 'parses JSON with preamble and code fence', response: new GenerateResponse({ message: { role: 'model', - content: [ - { text: 'Here is the JSON:\n\n```json\n{"key": "value"}\n```' }, - ], + content: [{ text: 'Here is the JSON:\n\n```json\n{"id": 1}\n```' }], }, }), - want: { key: 'value' }, + want: { id: 1 }, }, ]; for (const rt of responseTests) { it(rt.desc, () => { - const parser = jsonParser({ messages: [] }); + const parser = jsonFormatter.handler({ messages: [] }); assert.deepStrictEqual(parser.parseResponse(rt.response), rt.want); }); } - - it('includes schema in instructions when provided', () => { - const schema = { - type: 'object', - properties: { - name: { type: 'string' }, - }, - }; - const parser = jsonParser({ - messages: [], - output: { schema }, - }); - - assert.match( - parser.instructions as string, - /Output should be in JSON format/ - ); - assert.match( - parser.instructions as string, - new RegExp(JSON.stringify(schema)) - ); - }); - - it('has no instructions when no schema provided', () => { - const parser = jsonParser({ messages: [] }); - assert.strictEqual(parser.instructions, false); - }); }); diff --git a/js/ai/tests/formats/jsonl_test.ts b/js/ai/tests/formats/jsonl_test.ts index a08119be41..6d7fe138d5 100644 --- a/js/ai/tests/formats/jsonl_test.ts +++ b/js/ai/tests/formats/jsonl_test.ts @@ -16,57 +16,57 @@ import assert from 'node:assert'; import { describe, it } from 'node:test'; -import { jsonlParser } from '../../src/formats/jsonl.js'; +import { jsonlFormatter } from '../../src/formats/jsonl.js'; import { GenerateResponse, GenerateResponseChunk } from '../../src/generate.js'; import { GenerateResponseChunkData } from '../../src/model.js'; describe('jsonlFormat', () => { const streamingTests = [ { - desc: 'emits complete objects line by line', + desc: 'emits complete JSON objects as they arrive', chunks: [ { - text: '{"id": 1}\n{"id"', - want: [{ id: 1 }], + text: '{"id": 1, "name": "first"}\n', + want: [{ id: 1, name: 'first' }], }, { - text: ': 2}\n{"id": 3}', - want: [{ id: 2 }, { id: 3 }], + text: '{"id": 2, "name": "second"}\n{"id": 3', + want: [{ id: 2, name: 'second' }], + }, + { + text: ', "name": "third"}\n', + want: [{ id: 3, name: 'third' }], }, ], }, { - desc: 'handles preamble with code fence', + desc: 'handles single object', chunks: [ { - text: 'Here are the items:\n\n```jsonl\n{"id": 1', - want: [], - }, - { - text: '}\n{"id": 2}\n```', - want: [{ id: 1 }, { id: 2 }], + text: '{"id": 1, "name": "single"}\n', + want: [{ id: 1, name: 'single' }], }, ], }, { - desc: 'ignores non-object lines', + desc: 'handles preamble with code fence', chunks: [ { - text: 'Starting output:\n{"id": 1}\nsome text\n{"id": 2}', - want: [{ id: 1 }, { id: 2 }], + text: 'Here are the objects:\n\n```\n', + want: [], + }, + { + text: '{"id": 1, "name": "item"}\n```', + want: [{ id: 1, name: 'item' }], }, ], }, { - desc: 'handles objects with nested structures', + desc: 'ignores non-object lines', chunks: [ { - text: '{"user": {"name": "test"}}\n{"data": ', - want: [{ user: { name: 'test' } }], - }, - { - text: '{"values": [1,2]}}', - want: [{ data: { values: [1, 2] } }], + text: 'First object:\n{"id": 1}\nSecond object:\n{"id": 2}\n', + want: [{ id: 1 }, { id: 2 }], }, ], }, @@ -74,55 +74,53 @@ describe('jsonlFormat', () => { for (const st of streamingTests) { it(st.desc, () => { - const parser = jsonlParser({ messages: [] }); + const parser = jsonlFormatter.handler({ messages: [] }); const chunks: GenerateResponseChunkData[] = []; - let lastEmitted: any[] = []; + for (const chunk of st.chunks) { const newChunk: GenerateResponseChunkData = { content: [{ text: chunk.text }], }; - chunks.push(newChunk); - lastEmitted = []; - const emit = (item: any) => { - lastEmitted.push(item); - }; - parser.parseChunk!(new GenerateResponseChunk(newChunk, chunks), emit); + const result = parser.parseChunk!( + new GenerateResponseChunk(newChunk, { previousChunks: chunks }) + ); + chunks.push(newChunk); - assert.deepStrictEqual(lastEmitted, chunk.want); + assert.deepStrictEqual(result, chunk.want); } }); } const responseTests = [ { - desc: 'parses multiple objects', + desc: 'parses complete JSONL response', response: new GenerateResponse({ message: { role: 'model', - content: [{ text: '{"id": 1}\n{"id": 2}\n{"id": 3}' }], + content: [{ text: '{"id": 1, "name": "test"}\n{"id": 2}\n' }], }, }), - want: [{ id: 1 }, { id: 2 }, { id: 3 }], + want: [{ id: 1, name: 'test' }, { id: 2 }], }, { - desc: 'handles empty lines and non-object lines', + desc: 'handles empty response', response: new GenerateResponse({ message: { role: 'model', - content: [{ text: '\n{"id": 1}\nsome text\n{"id": 2}\n' }], + content: [{ text: '' }], }, }), - want: [{ id: 1 }, { id: 2 }], + want: [], }, { - desc: 'parses with preamble and code fence', + desc: 'parses JSONL with preamble and code fence', response: new GenerateResponse({ message: { role: 'model', content: [ { - text: 'Here are the items:\n\n```jsonl\n{"id": 1}\n{"id": 2}\n```', + text: 'Here are the objects:\n\n```\n{"id": 1}\n{"id": 2}\n```', }, ], }, @@ -133,7 +131,7 @@ describe('jsonlFormat', () => { for (const rt of responseTests) { it(rt.desc, () => { - const parser = jsonlParser({ messages: [] }); + const parser = jsonlFormatter.handler({ messages: [] }); assert.deepStrictEqual(parser.parseResponse(rt.response), rt.want); }); } @@ -154,10 +152,7 @@ describe('jsonlFormat', () => { request: { messages: [], output: { - schema: { - type: 'array', - items: { type: 'string' }, - }, + schema: { type: 'array', items: { type: 'string' } }, }, }, wantError: /Must supply an 'array' schema type containing 'object' items/, @@ -167,38 +162,8 @@ describe('jsonlFormat', () => { for (const et of errorTests) { it(et.desc, () => { assert.throws(() => { - jsonlParser(et.request); + jsonlFormatter.handler(et.request); }, et.wantError); }); } - - it('includes schema in instructions when provided', () => { - const schema = { - type: 'array', - items: { - type: 'object', - properties: { - id: { type: 'number' }, - }, - }, - }; - const parser = jsonlParser({ - messages: [], - output: { schema }, - }); - - assert.match( - parser.instructions as string, - /Output should be JSONL format/ - ); - assert.match( - parser.instructions as string, - new RegExp(JSON.stringify(schema.items)) - ); - }); - - it('has no instructions when no schema provided', () => { - const parser = jsonlParser({ messages: [] }); - assert.strictEqual(parser.instructions, false); - }); }); diff --git a/js/ai/tests/formats/text_test.ts b/js/ai/tests/formats/text_test.ts index 0e1f268042..cee7a719cd 100644 --- a/js/ai/tests/formats/text_test.ts +++ b/js/ai/tests/formats/text_test.ts @@ -16,60 +16,83 @@ import assert from 'node:assert'; import { describe, it } from 'node:test'; -import { textParser } from '../../src/formats/text.js'; +import { textFormatter } from '../../src/formats/text.js'; import { GenerateResponse, GenerateResponseChunk } from '../../src/generate.js'; import { GenerateResponseChunkData } from '../../src/model.js'; describe('textFormat', () => { const streamingTests = [ { - desc: 'emits each chunk as it comes', + desc: 'emits text chunks as they arrive', chunks: [ - { text: 'this is', want: ['this is'] }, - { text: ' a two-chunk response', want: [' a two-chunk response'] }, + { + text: 'Hello', + want: 'Hello', + }, + { + text: ' world', + want: ' world', + }, + ], + }, + { + desc: 'handles empty chunks', + chunks: [ + { + text: '', + want: '', + }, ], }, ]; for (const st of streamingTests) { it(st.desc, () => { - const parser = textParser({ messages: [] }); + const parser = textFormatter.handler({ messages: [] }); const chunks: GenerateResponseChunkData[] = []; - let lastEmitted: string[] = []; + for (const chunk of st.chunks) { const newChunk: GenerateResponseChunkData = { content: [{ text: chunk.text }], }; - chunks.push(newChunk); - lastEmitted = []; - const emit = (chunk: string) => { - lastEmitted.push(chunk); - }; - parser.parseChunk!(new GenerateResponseChunk(newChunk, chunks), emit); + const result = parser.parseChunk!( + new GenerateResponseChunk(newChunk, { previousChunks: chunks }) + ); + chunks.push(newChunk); - assert.deepStrictEqual(lastEmitted, chunk.want); + assert.strictEqual(result, chunk.want); } }); } const responseTests = [ { - desc: 'it returns the concatenated text', + desc: 'parses complete text response', + response: new GenerateResponse({ + message: { + role: 'model', + content: [{ text: 'Hello world' }], + }, + }), + want: 'Hello world', + }, + { + desc: 'handles empty response', response: new GenerateResponse({ message: { role: 'model', - content: [{ text: 'chunk one.' }, { text: 'chunk two.' }], + content: [{ text: '' }], }, }), - want: 'chunk one.chunk two.', + want: '', }, ]; for (const rt of responseTests) { it(rt.desc, () => { - const parser = textParser({ messages: [] }); - assert.deepStrictEqual(parser.parseResponse(rt.response), rt.want); + const parser = textFormatter.handler({ messages: [] }); + assert.strictEqual(parser.parseResponse(rt.response), rt.want); }); } }); diff --git a/js/ai/tests/generate/chunk_test.ts b/js/ai/tests/generate/chunk_test.ts new file mode 100644 index 0000000000..febc670a17 --- /dev/null +++ b/js/ai/tests/generate/chunk_test.ts @@ -0,0 +1,41 @@ +/** + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import assert from 'node:assert'; +import { describe, it } from 'node:test'; +import { GenerateResponseChunk } from '../../src/generate.js'; + +describe('GenerateResponseChunk', () => { + describe('text accumulation', () => { + const testChunk = new GenerateResponseChunk( + { content: [{ text: 'new' }] }, + { + previousChunks: [ + { content: [{ text: 'old1' }] }, + { content: [{ text: 'old2' }] }, + ], + } + ); + + it('#previousText should concatenate the text of previous parts', () => { + assert.strictEqual(testChunk.previousText, 'old1old2'); + }); + + it('#accumulatedText should concatenate previous with current text', () => { + assert.strictEqual(testChunk.accumulatedText, 'old1old2new'); + }); + }); +}); diff --git a/js/ai/tests/generate/generate_test.ts b/js/ai/tests/generate/generate_test.ts index da577c31a8..5e0930f915 100644 --- a/js/ai/tests/generate/generate_test.ts +++ b/js/ai/tests/generate/generate_test.ts @@ -16,249 +16,16 @@ import { z } from '@genkit-ai/core'; import { Registry } from '@genkit-ai/core/registry'; -import { toJsonSchema } from '@genkit-ai/core/schema'; import assert from 'node:assert'; import { beforeEach, describe, it } from 'node:test'; import { GenerateOptions, - GenerateResponse, - GenerateResponseChunk, - GenerationBlockedError, - GenerationResponseError, generate, toGenerateRequest, } from '../../src/generate.js'; -import { Message } from '../../src/message.js'; -import { - GenerateRequest, - GenerateResponseChunkData, - GenerateResponseData, - ModelAction, - ModelMiddleware, - defineModel, -} from '../../src/model.js'; +import { ModelAction, ModelMiddleware, defineModel } from '../../src/model.js'; import { defineTool } from '../../src/tool.js'; -describe('GenerateResponse', () => { - describe('#toJSON()', () => { - const testCases = [ - { - should: 'serialize correctly when custom is undefined', - responseData: { - message: { - role: 'model', - content: [{ text: '{"name": "Bob"}' }], - }, - finishReason: 'stop', - finishMessage: '', - usage: {}, - // No 'custom' property - }, - expectedOutput: { - message: { content: [{ text: '{"name": "Bob"}' }], role: 'model' }, - finishReason: 'stop', - usage: {}, - custom: {}, - }, - }, - ]; - - for (const test of testCases) { - it(test.should, () => { - const response = new GenerateResponse( - test.responseData as GenerateResponseData - ); - assert.deepStrictEqual(response.toJSON(), test.expectedOutput); - }); - } - }); - - describe('#output()', () => { - const testCases = [ - { - should: 'return structured data from the data part', - responseData: { - message: new Message({ - role: 'model', - content: [{ data: { name: 'Alice', age: 30 } }], - }), - finishReason: 'stop', - finishMessage: '', - usage: {}, - }, - expectedOutput: { name: 'Alice', age: 30 }, - }, - { - should: 'parse JSON from text when the data part is absent', - responseData: { - message: new Message({ - role: 'model', - content: [{ text: '{"name": "Bob"}' }], - }), - finishReason: 'stop', - finishMessage: '', - usage: {}, - }, - expectedOutput: { name: 'Bob' }, - }, - ]; - - for (const test of testCases) { - it(test.should, () => { - const response = new GenerateResponse( - test.responseData as GenerateResponseData - ); - assert.deepStrictEqual(response.output, test.expectedOutput); - }); - } - }); - - describe('#assertValid()', () => { - it('throws GenerationBlockedError if finishReason is blocked', () => { - const response = new GenerateResponse({ - finishReason: 'blocked', - finishMessage: 'Content was blocked', - }); - - assert.throws( - () => { - response.assertValid(); - }, - (err: unknown) => { - return err instanceof GenerationBlockedError; - } - ); - }); - - it('throws GenerationResponseError if no message is generated', () => { - const response = new GenerateResponse({ - finishReason: 'length', - finishMessage: 'Reached max tokens', - }); - - assert.throws( - () => { - response.assertValid(); - }, - (err: unknown) => { - return err instanceof GenerationResponseError; - } - ); - }); - - it('throws error if output does not conform to schema', () => { - const schema = z.object({ - name: z.string(), - age: z.number(), - }); - - const response = new GenerateResponse({ - message: { - role: 'model', - content: [{ text: '{"name": "John", "age": "30"}' }], - }, - finishReason: 'stop', - }); - - const request: GenerateRequest = { - messages: [], - output: { - schema: toJsonSchema({ schema }), - }, - }; - - assert.throws( - () => { - response.assertValid(request); - }, - (err: unknown) => { - return err instanceof Error && err.message.includes('must be number'); - } - ); - }); - - it('does not throw if output conforms to schema', () => { - const schema = z.object({ - name: z.string(), - age: z.number(), - }); - - const response = new GenerateResponse({ - message: { - role: 'model', - content: [{ text: '{"name": "John", "age": 30}' }], - }, - finishReason: 'stop', - }); - - const request: GenerateRequest = { - messages: [], - output: { - schema: toJsonSchema({ schema }), - }, - }; - - assert.doesNotThrow(() => { - response.assertValid(request); - }); - }); - }); - - describe('#toolRequests()', () => { - it('returns empty array if no tools requests found', () => { - const response = new GenerateResponse({ - message: new Message({ - role: 'model', - content: [{ text: '{"abc":"123"}' }], - }), - finishReason: 'stop', - }); - assert.deepStrictEqual(response.toolRequests, []); - }); - it('returns tool call if present', () => { - const toolCall = { - toolRequest: { - name: 'foo', - ref: 'abc', - input: 'banana', - }, - }; - const response = new GenerateResponse({ - message: new Message({ - role: 'model', - content: [toolCall], - }), - finishReason: 'stop', - }); - assert.deepStrictEqual(response.toolRequests, [toolCall]); - }); - it('returns all tool calls', () => { - const toolCall1 = { - toolRequest: { - name: 'foo', - ref: 'abc', - input: 'banana', - }, - }; - const toolCall2 = { - toolRequest: { - name: 'bar', - ref: 'bcd', - input: 'apple', - }, - }; - const response = new GenerateResponse({ - message: new Message({ - role: 'model', - content: [toolCall1, toolCall2], - }), - finishReason: 'stop', - }); - assert.deepStrictEqual(response.toolRequests, [toolCall1, toolCall2]); - }); - }); -}); - describe('toGenerateRequest', () => { const registry = new Registry(); // register tools @@ -448,79 +215,6 @@ describe('toGenerateRequest', () => { } }); -describe('GenerateResponseChunk', () => { - describe('#output()', () => { - const testCases = [ - { - should: 'parse ``` correctly', - accumulatedChunksTexts: ['```'], - correctJson: null, - }, - { - should: 'parse valid json correctly', - accumulatedChunksTexts: [`{"foo":"bar"}`], - correctJson: { foo: 'bar' }, - }, - { - should: 'if json invalid, return null', - accumulatedChunksTexts: [`invalid json`], - correctJson: null, - }, - { - should: 'handle missing closing brace', - accumulatedChunksTexts: [`{"foo":"bar"`], - correctJson: { foo: 'bar' }, - }, - { - should: 'handle missing closing bracket in nested object', - accumulatedChunksTexts: [`{"foo": {"bar": "baz"`], - correctJson: { foo: { bar: 'baz' } }, - }, - { - should: 'handle multiple chunks', - accumulatedChunksTexts: [`{"foo": {"bar"`, `: "baz`], - correctJson: { foo: { bar: 'baz' } }, - }, - { - should: 'handle multiple chunks with nested objects', - accumulatedChunksTexts: [`\`\`\`json{"foo": {"bar"`, `: {"baz": "qux`], - correctJson: { foo: { bar: { baz: 'qux' } } }, - }, - { - should: 'handle array nested in object', - accumulatedChunksTexts: [`{"foo": ["bar`], - correctJson: { foo: ['bar'] }, - }, - { - should: 'handle array nested in object with multiple chunks', - accumulatedChunksTexts: [`\`\`\`json{"foo": {"bar"`, `: ["baz`], - correctJson: { foo: { bar: ['baz'] } }, - }, - ]; - - for (const test of testCases) { - if (test.should) { - it(test.should, () => { - const accumulatedChunks: GenerateResponseChunkData[] = - test.accumulatedChunksTexts.map((text, index) => ({ - index, - content: [{ text }], - })); - - const chunkData = accumulatedChunks[accumulatedChunks.length - 1]; - - const responseChunk: GenerateResponseChunk = - new GenerateResponseChunk(chunkData, accumulatedChunks); - - const output = responseChunk.output; - - assert.deepStrictEqual(output, test.correctJson); - }); - } - } - }); -}); - describe('generate', () => { let registry: Registry; var echoModel: ModelAction; diff --git a/js/ai/tests/generate/response_test.ts b/js/ai/tests/generate/response_test.ts new file mode 100644 index 0000000000..698795385e --- /dev/null +++ b/js/ai/tests/generate/response_test.ts @@ -0,0 +1,247 @@ +/** + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { z } from '@genkit-ai/core'; +import { toJsonSchema } from '@genkit-ai/core/schema'; +import assert from 'node:assert'; +import { describe, it } from 'node:test'; +import { + GenerateResponse, + GenerationBlockedError, + GenerationResponseError, +} from '../../src/generate.js'; +import { Message } from '../../src/message.js'; +import { GenerateRequest, GenerateResponseData } from '../../src/model.js'; + +describe('GenerateResponse', () => { + describe('#toJSON()', () => { + const testCases = [ + { + should: 'serialize correctly when custom is undefined', + responseData: { + message: { + role: 'model', + content: [{ text: '{"name": "Bob"}' }], + }, + finishReason: 'stop', + finishMessage: '', + usage: {}, + // No 'custom' property + }, + expectedOutput: { + message: { content: [{ text: '{"name": "Bob"}' }], role: 'model' }, + finishReason: 'stop', + usage: {}, + custom: {}, + }, + }, + ]; + + for (const test of testCases) { + it(test.should, () => { + const response = new GenerateResponse( + test.responseData as GenerateResponseData + ); + assert.deepStrictEqual(response.toJSON(), test.expectedOutput); + }); + } + }); + + describe('#output()', () => { + const testCases = [ + { + should: 'return structured data from the data part', + responseData: { + message: new Message({ + role: 'model', + content: [{ data: { name: 'Alice', age: 30 } }], + }), + finishReason: 'stop', + finishMessage: '', + usage: {}, + }, + expectedOutput: { name: 'Alice', age: 30 }, + }, + { + should: 'parse JSON from text when the data part is absent', + responseData: { + message: new Message({ + role: 'model', + content: [{ text: '{"name": "Bob"}' }], + }), + finishReason: 'stop', + finishMessage: '', + usage: {}, + }, + expectedOutput: { name: 'Bob' }, + }, + ]; + + for (const test of testCases) { + it(test.should, () => { + const response = new GenerateResponse( + test.responseData as GenerateResponseData + ); + assert.deepStrictEqual(response.output, test.expectedOutput); + }); + } + }); + + describe('#assertValid()', () => { + it('throws GenerationBlockedError if finishReason is blocked', () => { + const response = new GenerateResponse({ + finishReason: 'blocked', + finishMessage: 'Content was blocked', + }); + + assert.throws( + () => { + response.assertValid(); + }, + (err: unknown) => { + return err instanceof GenerationBlockedError; + } + ); + }); + + it('throws GenerationResponseError if no message is generated', () => { + const response = new GenerateResponse({ + finishReason: 'length', + finishMessage: 'Reached max tokens', + }); + + assert.throws( + () => { + response.assertValid(); + }, + (err: unknown) => { + return err instanceof GenerationResponseError; + } + ); + }); + + it('throws error if output does not conform to schema', () => { + const schema = z.object({ + name: z.string(), + age: z.number(), + }); + + const response = new GenerateResponse({ + message: { + role: 'model', + content: [{ text: '{"name": "John", "age": "30"}' }], + }, + finishReason: 'stop', + }); + + const request: GenerateRequest = { + messages: [], + output: { + schema: toJsonSchema({ schema }), + }, + }; + + assert.throws( + () => { + response.assertValid(request); + }, + (err: unknown) => { + return err instanceof Error && err.message.includes('must be number'); + } + ); + }); + + it('does not throw if output conforms to schema', () => { + const schema = z.object({ + name: z.string(), + age: z.number(), + }); + + const response = new GenerateResponse({ + message: { + role: 'model', + content: [{ text: '{"name": "John", "age": 30}' }], + }, + finishReason: 'stop', + }); + + const request: GenerateRequest = { + messages: [], + output: { + schema: toJsonSchema({ schema }), + }, + }; + + assert.doesNotThrow(() => { + response.assertValid(request); + }); + }); + }); + + describe('#toolRequests()', () => { + it('returns empty array if no tools requests found', () => { + const response = new GenerateResponse({ + message: new Message({ + role: 'model', + content: [{ text: '{"abc":"123"}' }], + }), + finishReason: 'stop', + }); + assert.deepStrictEqual(response.toolRequests, []); + }); + it('returns tool call if present', () => { + const toolCall = { + toolRequest: { + name: 'foo', + ref: 'abc', + input: 'banana', + }, + }; + const response = new GenerateResponse({ + message: new Message({ + role: 'model', + content: [toolCall], + }), + finishReason: 'stop', + }); + assert.deepStrictEqual(response.toolRequests, [toolCall]); + }); + it('returns all tool calls', () => { + const toolCall1 = { + toolRequest: { + name: 'foo', + ref: 'abc', + input: 'banana', + }, + }; + const toolCall2 = { + toolRequest: { + name: 'bar', + ref: 'bcd', + input: 'apple', + }, + }; + const response = new GenerateResponse({ + message: new Message({ + role: 'model', + content: [toolCall1, toolCall2], + }), + finishReason: 'stop', + }); + assert.deepStrictEqual(response.toolRequests, [toolCall1, toolCall2]); + }); + }); +});