diff --git a/src/__tests__/unit/agents.test.ts b/src/__tests__/unit/agents.test.ts index 4486b34..58a9791 100644 --- a/src/__tests__/unit/agents.test.ts +++ b/src/__tests__/unit/agents.test.ts @@ -365,6 +365,90 @@ describe('GuardrailAgent', () => { expect(typeof result.tripwireTriggered).toBe('boolean'); }); + it('passes the latest user message text to guardrails for conversation inputs', async () => { + process.env.OPENAI_API_KEY = 'test'; + const config = { + version: 1, + input: { + version: 1, + guardrails: [{ name: 'Moderation', config: {} }], + }, + }; + + const { instantiateGuardrails } = await import('../../runtime'); + const runSpy = vi.fn().mockResolvedValue({ + tripwireTriggered: false, + info: { guardrail_name: 'Moderation' }, + }); + + vi.mocked(instantiateGuardrails).mockImplementationOnce(() => + Promise.resolve([ + { + definition: { + name: 'Moderation', + description: 'Moderation guardrail', + mediaType: 'text/plain', + configSchema: z.object({}), + checkFn: vi.fn(), + metadata: {}, + ctxRequirements: z.object({}), + schema: () => ({}), + instantiate: vi.fn(), + }, + config: {}, + run: runSpy, + } as unknown as Parameters[0] extends Promise + ? T extends readonly (infer U)[] + ? U + : never + : never, + ]) + ); + + const agent = (await GuardrailAgent.create( + config, + 'Conversation Agent', + 'Handle multi-turn conversations' + )) as MockAgent; + + const guardrail = agent.inputGuardrails[0] as unknown as { + execute: (args: { input: unknown; context?: unknown }) => Promise<{ + outputInfo: Record; + tripwireTriggered: boolean; + }>; + }; + + const conversation = [ + { role: 'system', content: 'You are helpful.' }, + { role: 'user', content: [{ type: 'input_text', text: 'First question?' }] }, + { role: 'assistant', content: [{ type: 'output_text', text: 'An answer.' }] }, + { + role: 'user', + content: [ + { type: 'input_text', text: 'Latest user message' }, + { type: 'input_text', text: 'with additional context.' }, + ], + }, + ]; + + const result = await guardrail.execute({ input: conversation, context: {} }); + + expect(runSpy).toHaveBeenCalledTimes(1); + const [ctxArgRaw, dataArg] = runSpy.mock.calls[0] as [unknown, string]; + const ctxArg = ctxArgRaw as { getConversationHistory?: () => unknown[] }; + expect(dataArg).toBe('Latest user message with additional context.'); + expect(typeof ctxArg.getConversationHistory).toBe('function'); + + const history = ctxArg.getConversationHistory?.() as Array<{ content?: unknown }> | undefined; + expect(Array.isArray(history)).toBe(true); + expect(history && history[history.length - 1]?.content).toBe( + 'Latest user message with additional context.' + ); + + expect(result.tripwireTriggered).toBe(false); + expect(result.outputInfo.input).toBe('Latest user message with additional context.'); + }); + it('should handle guardrail execution errors based on raiseGuardrailErrors setting', async () => { process.env.OPENAI_API_KEY = 'test'; const config = { diff --git a/src/agents.ts b/src/agents.ts index 43fafdd..95d897d 100644 --- a/src/agents.ts +++ b/src/agents.ts @@ -13,8 +13,8 @@ import type { InputGuardrailFunctionArgs, OutputGuardrailFunctionArgs, } from '@openai/agents-core'; -import { GuardrailLLMContext, GuardrailResult, TextOnlyContent, ContentPart } from './types'; -import { ContentUtils } from './utils/content'; +import { GuardrailLLMContext, GuardrailResult, TextOnlyContent } from './types'; +import { TEXT_CONTENT_TYPES } from './utils/content'; import { loadPipelineBundles, instantiateGuardrails, @@ -250,6 +250,180 @@ function ensureGuardrailContext( } as GuardrailLLMContext; } +const TEXTUAL_CONTENT_TYPES = new Set(TEXT_CONTENT_TYPES); +const MAX_CONTENT_EXTRACTION_DEPTH = 10; + +/** + * Extract text from any nested content value with optional type filtering. + * + * @param value Arbitrary content value (string, array, or object) to inspect. + * @param depth Current recursion depth, used to guard against circular structures. + * @param filterByType When true, only content parts with recognized text types are returned. + * @returns The extracted text, or an empty string when no text is found. + */ +function extractTextFromValue(value: unknown, depth: number, filterByType: boolean): string { + if (depth > MAX_CONTENT_EXTRACTION_DEPTH) { + return ''; + } + + if (typeof value === 'string') { + return value.trim(); + } + + if (Array.isArray(value)) { + const parts: string[] = []; + for (const item of value) { + const text = extractTextFromValue(item, depth + 1, filterByType); + if (text) { + parts.push(text); + } + } + return parts.join(' ').trim(); + } + + if (value && typeof value === 'object') { + const record = value as Record; + const typeValue = typeof record.type === 'string' ? record.type : null; + const isRecognizedTextType = typeValue ? TEXTUAL_CONTENT_TYPES.has(typeValue) : false; + + if (typeof record.text === 'string') { + if (!filterByType || isRecognizedTextType || typeValue === null) { + return record.text.trim(); + } + } + + const contentValue = record.content; + // If a direct text field was skipped due to type filtering, fall back to nested content. + if (contentValue != null) { + const nested = extractTextFromValue(contentValue, depth + 1, filterByType); + if (nested) { + return nested; + } + } + } + + return ''; +} + +/** + * Extract text from structured content parts (e.g., the `content` field on a message). + * + * Only textual content-part types enumerated in TEXTUAL_CONTENT_TYPES are considered so + * that non-text modalities (images, tools, etc.) remain ignored. + */ +function extractTextFromContentParts(content: unknown, depth = 0): string { + return extractTextFromValue(content, depth, true); +} + +/** + * Extract text from a single message entry. + * + * Handles strings, arrays of content parts, or message-like objects that contain a + * `content` collection or a plain `text` field. + */ +function extractTextFromMessageEntry(entry: unknown, depth = 0): string { + if (depth > MAX_CONTENT_EXTRACTION_DEPTH) { + return ''; + } + + if (entry == null) { + return ''; + } + + if (typeof entry === 'string') { + return entry.trim(); + } + + if (Array.isArray(entry)) { + return extractTextFromContentParts(entry, depth + 1); + } + + if (typeof entry === 'object') { + const record = entry as Record; + + if (record.content !== undefined) { + const contentText = extractTextFromContentParts(record.content, depth + 1); + if (contentText) { + return contentText; + } + } + + if (typeof record.text === 'string') { + return record.text.trim(); + } + } + + return extractTextFromValue(entry, depth + 1, false /* allow all types when falling back */); +} + +/** + * Extract the latest user-authored text from raw agent input. + * + * Accepts strings, message objects, or arrays of mixed items. Arrays are scanned + * from newest to oldest, returning the first user-role message with textual content. + */ +function extractTextFromAgentInput(input: unknown): string { + if (input == null) { + return ''; + } + + if (typeof input === 'string') { + return input.trim(); + } + + if (Array.isArray(input)) { + for (let idx = input.length - 1; idx >= 0; idx -= 1) { + const candidate = input[idx]; + if (candidate && typeof candidate === 'object') { + const record = candidate as Record; + if (record.role === 'user') { + const text = extractTextFromMessageEntry(candidate); + if (text) { + return text; + } + } + } else if (typeof candidate === 'string') { + const text = candidate.trim(); + if (text) { + return text; + } + } + } + return ''; + } + + if (input && typeof input === 'object') { + const record = input as Record; + if (record.role === 'user') { + const text = extractTextFromMessageEntry(record); + if (text) { + return text; + } + } + + if (record.content != null) { + const contentText = extractTextFromContentParts(record.content); + if (contentText) { + return contentText; + } + } + + if (typeof record.text === 'string') { + return record.text.trim(); + } + } + + if ( + typeof input === 'number' || + typeof input === 'boolean' || + typeof input === 'bigint' + ) { + return String(input); + } + + return ''; +} + function extractLatestUserText(history: NormalizedConversationEntry[]): string { for (let i = history.length - 1; i >= 0; i -= 1) { const entry = history[i]; @@ -261,20 +435,9 @@ function extractLatestUserText(history: NormalizedConversationEntry[]): string { } function resolveInputText(input: unknown, history: NormalizedConversationEntry[]): string { - if (typeof input === 'string') { - return input; - } - - if (input && typeof input === 'object' && 'content' in (input as Record)) { - const content = (input as { content: string | ContentPart[] }).content; - const message = { - role: 'user', - content, - }; - const extracted = ContentUtils.extractTextFromMessage(message); - if (extracted) { - return extracted; - } + const directText = extractTextFromAgentInput(input); + if (directText) { + return directText; } return extractLatestUserText(history); diff --git a/src/utils/content.ts b/src/utils/content.ts index a5e74ad..9cafd32 100644 --- a/src/utils/content.ts +++ b/src/utils/content.ts @@ -7,15 +7,15 @@ import { Message, ContentPart, TextContentPart, TextOnlyMessageArray } from '../types'; +export const TEXT_CONTENT_TYPES = ['input_text', 'text', 'output_text', 'summary_text'] as const; +const TEXT_CONTENT_TYPES_SET = new Set(TEXT_CONTENT_TYPES); + export class ContentUtils { - // Clear: what types are considered text - private static readonly TEXT_TYPES = ['input_text', 'text', 'output_text', 'summary_text'] as const; - /** * Check if a content part is text-based. */ static isText(part: ContentPart): boolean { - return this.TEXT_TYPES.includes(part.type as typeof this.TEXT_TYPES[number]); + return typeof part.type === 'string' && TEXT_CONTENT_TYPES_SET.has(part.type); } /**