Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 84 additions & 0 deletions src/__tests__/unit/agents.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,90 @@ describe('GuardrailAgent', () => {
expect(typeof result.tripwireTriggered).toBe('boolean');
});

// Regression test: when the agent input is a multi-turn conversation (an array of
// message items) the input guardrail must receive only the text of the most recent
// user message, with its text parts joined by a single space.
it('passes the latest user message text to guardrails for conversation inputs', async () => {
  // NOTE(review): presumably required so that client construction inside
  // GuardrailAgent.create does not fail; the guardrail itself is mocked below.
  process.env.OPENAI_API_KEY = 'test';

  const config = {
    version: 1,
    input: {
      version: 1,
      guardrails: [{ name: 'Moderation', config: {} }],
    },
  };

  const { instantiateGuardrails } = await import('../../runtime');

  // Element type of the array that instantiateGuardrails resolves to. Derived from the
  // function's RETURN type; inspecting its first parameter instead would not describe
  // the mocked value at all.
  type InstantiatedGuardrail = Awaited<ReturnType<typeof instantiateGuardrails>>[number];

  // Spy standing in for the guardrail's run(ctx, data); it never trips.
  const runSpy = vi.fn().mockResolvedValue({
    tripwireTriggered: false,
    info: { guardrail_name: 'Moderation' },
  });

  const moderationGuardrail = {
    definition: {
      name: 'Moderation',
      description: 'Moderation guardrail',
      mediaType: 'text/plain',
      configSchema: z.object({}),
      checkFn: vi.fn(),
      metadata: {},
      ctxRequirements: z.object({}),
      schema: () => ({}),
      instantiate: vi.fn(),
    },
    config: {},
    run: runSpy,
  } as unknown as InstantiatedGuardrail;

  vi.mocked(instantiateGuardrails).mockImplementationOnce(() =>
    Promise.resolve([moderationGuardrail])
  );

  const agent = (await GuardrailAgent.create(
    config,
    'Conversation Agent',
    'Handle multi-turn conversations'
  )) as MockAgent;

  // Only the execute() entry point of the generated input guardrail is exercised here.
  const guardrail = agent.inputGuardrails[0] as unknown as {
    execute: (args: { input: unknown; context?: unknown }) => Promise<{
      outputInfo: Record<string, unknown>;
      tripwireTriggered: boolean;
    }>;
  };

  // Two user turns: only the second (latest) one may reach the guardrail.
  const conversation = [
    { role: 'system', content: 'You are helpful.' },
    { role: 'user', content: [{ type: 'input_text', text: 'First question?' }] },
    { role: 'assistant', content: [{ type: 'output_text', text: 'An answer.' }] },
    {
      role: 'user',
      content: [
        { type: 'input_text', text: 'Latest user message' },
        { type: 'input_text', text: 'with additional context.' },
      ],
    },
  ];
  const expectedText = 'Latest user message with additional context.';

  const result = await guardrail.execute({ input: conversation, context: {} });

  // The guardrail ran exactly once, on the flattened text of the latest user message.
  expect(runSpy).toHaveBeenCalledTimes(1);
  const [ctxArgRaw, dataArg] = runSpy.mock.calls[0] as [unknown, string];
  const ctxArg = ctxArgRaw as { getConversationHistory?: () => unknown[] };
  expect(dataArg).toBe(expectedText);
  expect(typeof ctxArg.getConversationHistory).toBe('function');

  // The context exposes the conversation, whose last entry carries the same text.
  const history = ctxArg.getConversationHistory?.() as Array<{ content?: unknown }> | undefined;
  expect(Array.isArray(history)).toBe(true);
  expect(history && history[history.length - 1]?.content).toBe(expectedText);

  expect(result.tripwireTriggered).toBe(false);
  expect(result.outputInfo.input).toBe(expectedText);
});

it('should handle guardrail execution errors based on raiseGuardrailErrors setting', async () => {
process.env.OPENAI_API_KEY = 'test';
const config = {
Expand Down
195 changes: 179 additions & 16 deletions src/agents.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ import type {
InputGuardrailFunctionArgs,
OutputGuardrailFunctionArgs,
} from '@openai/agents-core';
import { GuardrailLLMContext, GuardrailResult, TextOnlyContent, ContentPart } from './types';
import { ContentUtils } from './utils/content';
import { GuardrailLLMContext, GuardrailResult, TextOnlyContent } from './types';
import { TEXT_CONTENT_TYPES } from './utils/content';
import {
loadPipelineBundles,
instantiateGuardrails,
Expand Down Expand Up @@ -250,6 +250,180 @@ function ensureGuardrailContext(
} as GuardrailLLMContext;
}

// Set view of the shared TEXT_CONTENT_TYPES list, used for membership checks on a
// content part's `type` when deciding whether its `text` may be extracted.
const TEXTUAL_CONTENT_TYPES = new Set<string>(TEXT_CONTENT_TYPES);
// Upper bound on recursion depth while walking nested content; extraction helpers
// return an empty string once this depth is exceeded.
const MAX_CONTENT_EXTRACTION_DEPTH = 10;

/**
 * Recursively collect text from an arbitrary content value.
 *
 * - Strings are returned trimmed.
 * - Arrays are flattened: every element is extracted in turn and the non-empty
 *   results are joined with a single space.
 * - Objects contribute their own `text` field when it is a string and the type
 *   filter allows it; otherwise their `content` field is searched instead.
 *
 * @param value Arbitrary content value (string, array, or object) to inspect.
 * @param depth Current recursion depth; extraction stops once it exceeds
 *   MAX_CONTENT_EXTRACTION_DEPTH, which also bounds circular structures.
 * @param filterByType When true, an object's `text` is used only if the object has no
 *   string `type` or its `type` is one of TEXTUAL_CONTENT_TYPES.
 * @returns The extracted text, or an empty string when no text is found.
 */
function extractTextFromValue(value: unknown, depth: number, filterByType: boolean): string {
  if (depth > MAX_CONTENT_EXTRACTION_DEPTH) return '';
  if (typeof value === 'string') return value.trim();

  if (Array.isArray(value)) {
    return value
      .map((element) => extractTextFromValue(element, depth + 1, filterByType))
      .filter((fragment) => fragment.length > 0)
      .join(' ')
      .trim();
  }

  // Anything that is neither a string, an array, nor a plain object carries no text.
  if (typeof value !== 'object' || value === null) return '';

  const node = value as Record<string, unknown>;
  const declaredType = typeof node.type === 'string' ? node.type : null;

  const directText = node.text;
  if (typeof directText === 'string') {
    // An untyped part is always accepted; a typed part must be a recognized text type.
    const typeAllowsText = declaredType === null || TEXTUAL_CONTENT_TYPES.has(declaredType);
    if (!filterByType || typeAllowsText) {
      return directText.trim();
    }
  }

  // Either there was no direct text or it was rejected by the type filter:
  // look inside the nested `content` instead.
  const nestedContent = node.content;
  if (nestedContent == null) return '';
  return extractTextFromValue(nestedContent, depth + 1, filterByType);
}

/**
 * Pull text out of structured content parts, such as a message's `content` field.
 *
 * Delegates to extractTextFromValue with type filtering switched on, so parts that
 * declare a `type` outside TEXTUAL_CONTENT_TYPES (images, tools, etc.) do not
 * contribute their `text`.
 *
 * @param content Content value to inspect (string, array of parts, or object).
 * @param depth Recursion depth to start from; defaults to 0.
 * @returns The extracted text, or an empty string when none is found.
 */
function extractTextFromContentParts(content: unknown, depth = 0): string {
  const onlyRecognizedTextTypes = true;
  return extractTextFromValue(content, depth, onlyRecognizedTextTypes);
}

/**
 * Extract text from a single message entry.
 *
 * Handles strings, arrays of content parts, or message-like objects that contain a
 * `content` collection or a plain `text` field.
 *
 * @param entry Message entry to inspect; may be any value.
 * @param depth Current recursion depth; an empty string is returned once it exceeds
 *   MAX_CONTENT_EXTRACTION_DEPTH.
 * @returns The trimmed text found in the entry, or an empty string when there is none.
 */
function extractTextFromMessageEntry(entry: unknown, depth = 0): string {
// Depth guard: stop descending once the nesting limit has been passed.
if (depth > MAX_CONTENT_EXTRACTION_DEPTH) {
return '';
}

// null and undefined carry no text.
if (entry == null) {
return '';
}

if (typeof entry === 'string') {
return entry.trim();
}

// A bare array is treated as a list of content parts (type-filtered extraction).
if (Array.isArray(entry)) {
return extractTextFromContentParts(entry, depth + 1);
}

if (typeof entry === 'object') {
const record = entry as Record<string, unknown>;

// Prefer the structured `content` field; only a non-empty result is returned here.
if (record.content !== undefined) {
const contentText = extractTextFromContentParts(record.content, depth + 1);
if (contentText) {
return contentText;
}
}

// Otherwise use a plain string `text` field, regardless of any `type` on the entry.
if (typeof record.text === 'string') {
return record.text.trim();
}
}

Copy link

Copilot AI Nov 10, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nitpick] The inline comment explains the parameter but doesn't clarify when this fallback path is reached. Consider adding a brief comment above this line explaining that this is a last-resort extraction attempt for object structures that don't match standard message patterns.

Suggested change
// Last-resort extraction: if entry does not match standard message patterns,
// attempt to extract text from any value type.

Copilot uses AI. Check for mistakes.
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nit]

// Last resort: reached only for non-string primitives, or for objects whose `content`
// produced no type-filtered text and that have no string `text` field. Retry with the
// type filter disabled so text under unrecognized part types can still be recovered.
return extractTextFromValue(entry, depth + 1, false /* allow all types when falling back */);
}

/**
 * Resolve the most recent user-authored text from raw agent input.
 *
 * Supported shapes:
 * - string: returned trimmed;
 * - number / boolean / bigint: converted with String();
 * - array: walked from the last item to the first, returning the first non-empty
 *   result, where plain strings count as text and objects count only when their
 *   `role` is 'user';
 * - single object: a 'user' message is extracted as a message entry; failing that,
 *   its `content` is searched, and finally a string `text` field is used.
 *
 * @param input Raw agent input of any shape.
 * @returns The extracted text, or an empty string when nothing usable is found.
 */
function extractTextFromAgentInput(input: unknown): string {
  if (input === null || input === undefined) {
    return '';
  }

  if (typeof input === 'string') {
    return input.trim();
  }

  if (typeof input === 'number' || typeof input === 'boolean' || typeof input === 'bigint') {
    return String(input);
  }

  // Symbols, functions and anything else that is not an object carry no text.
  if (typeof input !== 'object') {
    return '';
  }

  if (Array.isArray(input)) {
    // Walk backwards so the newest entry wins.
    let cursor = input.length;
    while (cursor-- > 0) {
      const item: unknown = input[cursor];

      if (typeof item === 'string') {
        const trimmed = item.trim();
        if (trimmed) {
          return trimmed;
        }
        continue;
      }

      if (typeof item !== 'object' || item === null) {
        continue;
      }

      // Only user-authored messages are considered; other roles are skipped.
      if ((item as Record<string, unknown>).role !== 'user') {
        continue;
      }

      const userText = extractTextFromMessageEntry(item);
      if (userText) {
        return userText;
      }
    }
    return '';
  }

  const message = input as Record<string, unknown>;

  if (message.role === 'user') {
    const userText = extractTextFromMessageEntry(message);
    if (userText) {
      return userText;
    }
  }

  if (message.content != null) {
    const contentText = extractTextFromContentParts(message.content);
    if (contentText) {
      return contentText;
    }
  }

  return typeof message.text === 'string' ? message.text.trim() : '';
}

function extractLatestUserText(history: NormalizedConversationEntry[]): string {
for (let i = history.length - 1; i >= 0; i -= 1) {
const entry = history[i];
Expand All @@ -261,20 +435,9 @@ function extractLatestUserText(history: NormalizedConversationEntry[]): string {
}

function resolveInputText(input: unknown, history: NormalizedConversationEntry[]): string {
if (typeof input === 'string') {
return input;
}

if (input && typeof input === 'object' && 'content' in (input as Record<string, unknown>)) {
const content = (input as { content: string | ContentPart[] }).content;
const message = {
role: 'user',
content,
};
const extracted = ContentUtils.extractTextFromMessage(message);
if (extracted) {
return extracted;
}
const directText = extractTextFromAgentInput(input);
if (directText) {
return directText;
}

return extractLatestUserText(history);
Expand Down
8 changes: 4 additions & 4 deletions src/utils/content.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,15 @@

import { Message, ContentPart, TextContentPart, TextOnlyMessageArray } from '../types';

/** Content-part `type` values that are treated as text. */
export const TEXT_CONTENT_TYPES = ['input_text', 'text', 'output_text', 'summary_text'] as const;
// Set form of the list above for membership checks against an arbitrary string.
const TEXT_CONTENT_TYPES_SET = new Set<string>(TEXT_CONTENT_TYPES);

export class ContentUtils {
// Clear: what types are considered text
private static readonly TEXT_TYPES = ['input_text', 'text', 'output_text', 'summary_text'] as const;

/**
* Check if a content part is text-based.
*/
static isText(part: ContentPart): boolean {
return this.TEXT_TYPES.includes(part.type as typeof this.TEXT_TYPES[number]);
return typeof part.type === 'string' && TEXT_CONTENT_TYPES_SET.has(part.type);
}

/**
Expand Down
Loading