diff --git a/js/ai/src/generate.ts b/js/ai/src/generate.ts
index 604fcfbf1a..c29e24fd77 100755
--- a/js/ai/src/generate.ts
+++ b/js/ai/src/generate.ts
@@ -26,6 +26,7 @@ import { parseSchema, toJsonSchema } from '@genkit-ai/core/schema';
 import { DocumentData } from './document.js';
 import { extractJson } from './extract.js';
 import { generateHelper, GenerateUtilParamSchema } from './generateAction.js';
+import { LenientMessageData, Message } from './message.js';
 import {
   GenerateRequest,
   GenerateResponseChunkData,
@@ -41,88 +42,10 @@ import {
   Part,
   ToolDefinition,
   ToolRequestPart,
-  ToolResponsePart,
 } from './model.js';
 import { ExecutablePrompt } from './prompt.js';
 import { resolveTools, ToolArgument, toToolDefinition } from './tool.js';
 
-/**
- * Message represents a single role's contribution to a generation. Each message
- * can contain multiple parts (for example text and an image), and each generation
- * can contain multiple messages.
- */
-export class Message<T = unknown> implements MessageData {
-  role: MessageData['role'];
-  content: Part[];
-
-  constructor(message: MessageData) {
-    this.role = message.role;
-    this.content = message.content;
-  }
-
-  /**
-   * If a message contains a `data` part, it is returned. Otherwise, the `output()`
-   * method extracts the first valid JSON object or array from the text contained in
-   * the message and returns it.
-   *
-   * @returns The structured output contained in the message.
-   */
-  get output(): T {
-    return this.data || extractJson<T>(this.text);
-  }
-
-  toolResponseParts(): ToolResponsePart[] {
-    const res = this.content.filter((part) => !!part.toolResponse);
-    return res as ToolResponsePart[];
-  }
-
-  /**
-   * Concatenates all `text` parts present in the message with no delimiter.
-   * @returns A string of all concatenated text parts.
-   */
-  get text(): string {
-    return this.content.map((part) => part.text || '').join('');
-  }
-
-  /**
-   * Returns the first media part detected in the message. Useful for extracting
-   * (for example) an image from a generation expected to create one.
-   * @returns The first detected `media` part in the message.
-   */
-  get media(): { url: string; contentType?: string } | null {
-    return this.content.find((part) => part.media)?.media || null;
-  }
-
-  /**
-   * Returns the first detected `data` part of a message.
-   * @returns The first `data` part detected in the message (if any).
-   */
-  get data(): T | null {
-    return this.content.find((part) => part.data)?.data as T | null;
-  }
-
-  /**
-   * Returns all tool request found in this message.
-   * @returns Array of all tool request found in this message.
-   */
-  get toolRequests(): ToolRequestPart[] {
-    return this.content.filter(
-      (part) => !!part.toolRequest
-    ) as ToolRequestPart[];
-  }
-
-  /**
-   * Converts the Message to a plain JS object.
-   * @returns Plain JS object representing the data contained in the message.
-   */
-  toJSON(): MessageData {
-    return {
-      role: this.role,
-      content: [...this.content],
-    };
-  }
-}
-
 /**
  * GenerateResponse is the result from a `generate()` call and contains one or
  * more generated candidate messages.
@@ -361,29 +284,25 @@ export class GenerateResponseChunk
   }
 }
 
-export function normalizePart(input: string | Part | Part[]): Part[] {
-  if (typeof input === 'string') {
-    return [{ text: input }];
-  } else if (Array.isArray(input)) {
-    return input;
-  } else {
-    return [input];
-  }
-}
-
 export async function toGenerateRequest(
   registry: Registry,
   options: GenerateOptions
 ): Promise<GenerateRequest> {
   const messages: MessageData[] = [];
   if (options.system) {
-    messages.push({ role: 'system', content: normalizePart(options.system) });
+    messages.push({
+      role: 'system',
+      content: Message.parseContent(options.system),
+    });
   }
   if (options.messages) {
-    messages.push(...options.messages);
+    messages.push(...options.messages.map((m) => Message.parseData(m)));
   }
   if (options.prompt) {
-    messages.push({ role: 'user', content: normalizePart(options.prompt) });
+    messages.push({
+      role: 'user',
+      content: Message.parseContent(options.prompt),
+    });
   }
   if (messages.length === 0) {
     throw new Error('at least one message is required in generate request');
@@ -427,7 +346,7 @@ export interface GenerateOptions<
   /** Retrieved documents to be used as context for this generation. */
   docs?: DocumentData[];
   /** Conversation messages (history) for multi-turn prompting when supported by the underlying model. */
-  messages?: MessageData[];
+  messages?: LenientMessageData[];
   /** List of registered tool names or actions to treat as a tool for this generation if supported by the underlying model. */
   tools?: ToolArgument[];
   /** Configuration for the generation request. */
@@ -569,7 +488,7 @@ export async function generate<
   if (resolvedOptions.system) {
     messages.push({
       role: 'system',
-      content: normalizePart(resolvedOptions.system),
+      content: Message.parseContent(resolvedOptions.system),
     });
   }
   if (resolvedOptions.messages) {
@@ -578,7 +497,7 @@ export async function generate<
   if (resolvedOptions.prompt) {
     messages.push({
       role: 'user',
-      content: normalizePart(resolvedOptions.prompt),
+      content: Message.parseContent(resolvedOptions.prompt),
     });
   }
 
diff --git a/js/ai/src/index.ts b/js/ai/src/index.ts
index e629db64e7..eb7ea3e416 100644
--- a/js/ai/src/index.ts
+++ b/js/ai/src/index.ts
@@ -39,16 +39,15 @@ export {
   GenerateResponse,
   GenerationBlockedError,
   GenerationResponseError,
-  Message,
   generate,
   generateStream,
-  normalizePart,
   tagAsPreamble,
   toGenerateRequest,
   type GenerateOptions,
   type GenerateStreamOptions,
   type GenerateStreamResponse,
 } from './generate.js';
+export { Message } from './message.js';
 export {
   GenerationCommonConfigSchema,
   MessageSchema,
diff --git a/js/ai/src/message.ts b/js/ai/src/message.ts
new file mode 100644
index 0000000000..926124dc05
--- /dev/null
+++ b/js/ai/src/message.ts
@@ -0,0 +1,165 @@
+/**
+ * Copyright 2024 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import { extractJson } from './extract.js';
+import {
+  MessageData,
+  Part,
+  ToolRequestPart,
+  ToolResponsePart,
+} from './model.js';
+
+/**
+ * A message whose `content` may be given in shorthand: a bare string, a
+ * single part, or an array mixing strings and parts. Every other field is
+ * the same as in `MessageData`.
+ */
+export type LenientMessageData = Omit<MessageData, 'content'> & {
+  content: string | Part | (string | Part)[];
+};
+
+/**
+ * Message represents a single role's contribution to a generation. Each message
+ * can contain multiple parts (for example text and an image), and each generation
+ * can contain multiple messages.
+ */
+export class Message<T = unknown> implements MessageData {
+  role: MessageData['role'];
+  content: Part[];
+  metadata?: Record<string, any>;
+
+  /**
+   * Normalizes a leniently specified message into plain `MessageData`.
+   *
+   * @param lenientMessage A bare string, or a message whose `content` may be
+   *   shorthand (see `parseContent`).
+   * @param defaultRole Role assigned when `lenientMessage` is a bare string.
+   * @returns A copy of the message with `content` expanded to `Part[]`; all
+   *   other fields are carried over unchanged.
+   */
+  static parseData(
+    lenientMessage: string | LenientMessageData,
+    defaultRole: MessageData['role'] = 'user'
+  ): MessageData {
+    if (typeof lenientMessage === 'string') {
+      return { role: defaultRole, content: [{ text: lenientMessage }] };
+    }
+    return {
+      ...lenientMessage,
+      content: Message.parseContent(lenientMessage.content),
+    };
+  }
+
+  /**
+   * Builds a `Message` from a leniently specified message.
+   *
+   * @param lenientMessage Any input accepted by `parseData`.
+   * @returns A `Message` wrapping the normalized data.
+   */
+  static parse(lenientMessage: string | LenientMessageData): Message {
+    return new Message(Message.parseData(lenientMessage));
+  }
+
+  /**
+   * Expands shorthand content into an array of parts.
+   *
+   * @param lenientPart A string, a single part, or an array mixing both.
+   * @returns An array in which every string has become a `{ text }` part.
+   */
+  static parseContent(lenientPart: string | Part | (string | Part)[]): Part[] {
+    if (typeof lenientPart === 'string') {
+      return [{ text: lenientPart }];
+    } else if (Array.isArray(lenientPart)) {
+      return lenientPart.map((p) => (typeof p === 'string' ? { text: p } : p));
+    } else {
+      return [lenientPart];
+    }
+  }
+
+  constructor(message: MessageData) {
+    this.role = message.role;
+    this.content = message.content;
+    this.metadata = message.metadata;
+  }
+
+  /**
+   * If a message contains a `data` part, it is returned. Otherwise, the
+   * `output` getter extracts the first valid JSON object or array from the
+   * text contained in the message and returns it.
+   *
+   * @returns The structured output contained in the message.
+   */
+  get output(): T {
+    return this.data || extractJson<T>(this.text);
+  }
+
+  /**
+   * Returns all tool responses found in this message.
+   * @returns Array of all tool response parts found in this message.
+   */
+  toolResponseParts(): ToolResponsePart[] {
+    const res = this.content.filter((part) => !!part.toolResponse);
+    return res as ToolResponsePart[];
+  }
+
+  /**
+   * Concatenates all `text` parts present in the message with no delimiter.
+   * @returns A string of all concatenated text parts.
+   */
+  get text(): string {
+    return this.content.map((part) => part.text || '').join('');
+  }
+
+  /**
+   * Returns the first media part detected in the message. Useful for extracting
+   * (for example) an image from a generation expected to create one.
+   * @returns The first detected `media` part in the message.
+   */
+  get media(): { url: string; contentType?: string } | null {
+    return this.content.find((part) => part.media)?.media || null;
+  }
+
+  /**
+   * Returns the first detected `data` part of a message.
+   * @returns The first `data` part detected in the message (if any).
+   */
+  get data(): T | null {
+    return this.content.find((part) => part.data)?.data as T | null;
+  }
+
+  /**
+   * Returns all tool requests found in this message.
+   * @returns Array of all tool requests found in this message.
+   */
+  get toolRequests(): ToolRequestPart[] {
+    return this.content.filter(
+      (part) => !!part.toolRequest
+    ) as ToolRequestPart[];
+  }
+
+  /**
+   * Converts the Message to a plain JS object.
+   * @returns Plain JS object representing the data contained in the message.
+   */
+  toJSON(): MessageData {
+    const out: MessageData = {
+      role: this.role,
+      content: [...this.content],
+    };
+    if (this.metadata) out.metadata = this.metadata;
+    return out;
+  }
+}
diff --git a/js/ai/tests/generate/generate_test.ts b/js/ai/tests/generate/generate_test.ts
index 9a02b0f6e1..da577c31a8 100644
--- a/js/ai/tests/generate/generate_test.ts
+++ b/js/ai/tests/generate/generate_test.ts
@@ -25,10 +25,10 @@ import {
   GenerateResponseChunk,
   GenerationBlockedError,
   GenerationResponseError,
-  Message,
   generate,
   toGenerateRequest,
 } from '../../src/generate.js';
+import { Message } from '../../src/message.js';
 import {
   GenerateRequest,
   GenerateResponseChunkData,
diff --git a/js/ai/tests/message/message_test.ts b/js/ai/tests/message/message_test.ts
new file mode 100644
index 0000000000..8046a9aade
--- /dev/null
+++ b/js/ai/tests/message/message_test.ts
@@ -0,0 +1,60 @@
+/**
+ * Copyright 2024 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import assert from 'node:assert';
+import { describe, it } from 'node:test';
+import { Message } from '../../src/message.js';
+
+describe('Message', () => {
+  describe('.parseData()', () => {
+    const testCases = [
+      {
+        desc: 'convert string to user message',
+        input: 'i am a user message',
+        want: { role: 'user', content: [{ text: 'i am a user message' }] },
+      },
+      {
+        desc: 'convert string content to Part[] content',
+        input: {
+          role: 'system',
+          content: 'i am a system message',
+          metadata: { extra: true },
+        },
+        want: {
+          role: 'system',
+          content: [{ text: 'i am a system message' }],
+          metadata: { extra: true },
+        },
+      },
+      {
+        desc: 'convert mixed string and Part content to Part[] content',
+        input: { role: 'user', content: ['hello', { text: 'world' }] },
+        want: { role: 'user', content: [{ text: 'hello' }, { text: 'world' }] },
+      },
+      {
+        desc: 'leave valid MessageData alone',
+        input: { role: 'model', content: [{ text: 'i am a model message' }] },
+        want: { role: 'model', content: [{ text: 'i am a model message' }] },
+      },
+    ];
+
+    for (const t of testCases) {
+      it(t.desc, () => {
+        assert.deepStrictEqual(Message.parseData(t.input as any), t.want);
+      });
+    }
+  });
+});
diff --git a/js/genkit/src/session.ts b/js/genkit/src/session.ts
index 667dfd6cfc..6ba305c973 100644
--- a/js/genkit/src/session.ts
+++ b/js/genkit/src/session.ts
@@ -16,8 +16,8 @@
 
 import {
   GenerateOptions,
+  Message,
   MessageData,
-  normalizePart,
   tagAsPreamble,
 } from '@genkit-ai/ai';
 import { z } from '@genkit-ai/core';
@@ -186,7 +186,7 @@ export class Session {
     if (baseOptions.system) {
       messages.push({
         role: 'system',
-        content: normalizePart(baseOptions.system),
+        content: Message.parseContent(baseOptions.system),
       });
     }
     delete baseOptions.system;
diff --git a/js/genkit/tests/generate_test.ts b/js/genkit/tests/generate_test.ts
index 61361e6093..81da064a72 100644
--- a/js/genkit/tests/generate_test.ts
+++ b/js/genkit/tests/generate_test.ts
@@ -65,7 +65,7 @@ describe('generate', () => {
         messages: [
           {
             role: 'system',
-            content: [{ text: 'talk like a pirate' }],
+            content: 'talk like a pirate',
           },
           {
             role: 'user',