diff --git a/docs/benchmarking/alignment_roc_curves.png b/docs/benchmarking/alignment_roc_curves.png index 449783f..01631cd 100644 Binary files a/docs/benchmarking/alignment_roc_curves.png and b/docs/benchmarking/alignment_roc_curves.png differ diff --git a/docs/ref/checks/prompt_injection_detection.md b/docs/ref/checks/prompt_injection_detection.md index edb475c..9c3709f 100644 --- a/docs/ref/checks/prompt_injection_detection.md +++ b/docs/ref/checks/prompt_injection_detection.md @@ -67,8 +67,14 @@ Returns a `GuardrailResult` with the following `info` dictionary: "confidence": 0.1, "threshold": 0.7, "user_goal": "What's the weather in Tokyo?", - "action": "get_weather(location='Tokyo')", - "checked_text": "Original input text" + "action": [ + { + "type": "function_call", + "name": "get_weather", + "arguments": "{\"location\": \"Tokyo\"}" + } + ], + "checked_text": "[{\"role\": \"user\", \"content\": \"What is the weather in Tokyo?\"}]" } ``` @@ -77,18 +83,18 @@ Returns a `GuardrailResult` with the following `info` dictionary: - **`confidence`**: Confidence score (0.0 to 1.0) that the action is misaligned - **`threshold`**: The confidence threshold that was configured - **`user_goal`**: The tracked user intent from conversation -- **`action`**: The specific action being evaluated -- **`checked_text`**: Original input text +- **`action`**: The list of function calls or tool outputs analyzed for alignment +- **`checked_text`**: Serialized conversation history inspected during analysis ## Benchmark Results ### Dataset Description -This benchmark evaluates model performance on a synthetic dataset of agent conversation traces: +This benchmark evaluates model performance on agent conversation traces: -- **Dataset size**: 1,000 samples with 500 positive cases (50% prevalence) -- **Data type**: Internal synthetic dataset simulating realistic agent traces -- **Test scenarios**: Multi-turn conversations with function calls and tool outputs +- **Synthetic dataset**: 1,000 samples with 500 positive cases (50% prevalence) simulating realistic agent traces +- **AgentDojo dataset**: 1,046 samples from AgentDojo's workspace, travel, banking, and Slack suite combined with the "important_instructions" attack (949 positive cases, 97 negative samples) +- **Test scenarios**: Multi-turn conversations with function calls and tool outputs across realistic workplace domains - **Misalignment examples**: Unrelated function calls, harmful operations, and data leakage **Example of misaligned conversation:** @@ -107,12 +113,12 @@ This benchmark evaluates model performance on a synthetic dataset of agent conve | Model | ROC AUC | Prec@R=0.80 | Prec@R=0.90 | Prec@R=0.95 | Recall@FPR=0.01 | |---------------|---------|-------------|-------------|-------------|-----------------| -| gpt-5 | 0.9997 | 1.000 | 1.000 | 1.000 | 0.998 | -| gpt-5-mini | 0.9998 | 1.000 | 1.000 | 0.998 | 0.998 | -| gpt-5-nano | 0.9987 | 0.996 | 0.996 | 0.996 | 0.996 | -| gpt-4.1 | 0.9990 | 1.000 | 1.000 | 1.000 | 0.998 | -| gpt-4.1-mini (default) | 0.9930 | 1.000 | 1.000 | 1.000 | 0.986 | -| gpt-4.1-nano | 0.9431 | 0.982 | 0.845 | 0.695 | 0.000 | +| gpt-5 | 0.9604 | 0.998 | 0.995 | 0.963 | 0.431 | +| gpt-5-mini | 0.9796 | 0.999 | 0.999 | 0.966 | 0.000 | +| gpt-5-nano | 0.8651 | 0.963 | 0.963 | 0.951 | 0.056 | +| gpt-4.1 | 0.9846 | 0.998 | 0.998 | 0.998 | 0.000 | +| gpt-4.1-mini (default) | 0.9728 | 0.995 | 0.995 | 0.995 | 0.000 | +| gpt-4.1-nano | 0.8677 | 0.974 | 0.974 | 0.974 | 0.000 | **Notes:** diff --git a/docs/ref/types-typescript.md b/docs/ref/types-typescript.md index 6c96511..b320c33 100644 --- a/docs/ref/types-typescript.md +++ b/docs/ref/types-typescript.md @@ -17,8 +17,6 @@ Context interface providing access to the OpenAI client used by guardrails. ```typescript export interface GuardrailLLMContextWithHistory extends GuardrailLLMContext { getConversationHistory(): any[]; - getInjectionLastCheckedIndex(): number; - updateInjectionLastCheckedIndex(index: number): void; } ``` @@ -60,4 +58,3 @@ export type TCfg = object; For the full source, see [src/types.ts](https://github.com/openai/openai-guardrails-js/blob/main/src/types.ts) in the repository. - diff --git a/src/__tests__/unit/prompt_injection_detection.test.ts b/src/__tests__/unit/prompt_injection_detection.test.ts index 20151cd..ecfae67 100644 --- a/src/__tests__/unit/prompt_injection_detection.test.ts +++ b/src/__tests__/unit/prompt_injection_detection.test.ts @@ -52,8 +52,6 @@ describe('Prompt Injection Detection Check', () => { output: '{"temperature": 22, "condition": "sunny"}', }, ], - getInjectionLastCheckedIndex: () => 0, - updateInjectionLastCheckedIndex: () => {}, }; }); @@ -74,7 +72,6 @@ describe('Prompt Injection Detection Check', () => { const contextWithOnlyUserMessages = { ...mockContext, getConversationHistory: () => [{ role: 'user', content: 'Hello there!' }], - getInjectionLastCheckedIndex: () => 0, }; const result = await promptInjectionDetectionCheck( @@ -84,14 +81,13 @@ describe('Prompt Injection Detection Check', () => { ); expect(result.tripwireTriggered).toBe(false); - expect(result.info.observation).toBe('No function calls or function call outputs to evaluate'); + expect(result.info.observation).toBe('No actionable tool messages to evaluate'); }); it('should return skip result when no LLM actions', async () => { const contextWithNoLLMActions = { ...mockContext, getConversationHistory: () => [{ role: 'user', content: 'Hello there!' }], - getInjectionLastCheckedIndex: () => 1, // Already checked all messages }; const result = await promptInjectionDetectionCheck( @@ -101,7 +97,7 @@ describe('Prompt Injection Detection Check', () => { ); expect(result.tripwireTriggered).toBe(false); - expect(result.info.observation).toBe('No function calls or function call outputs to evaluate'); + expect(result.info.observation).toBe('No actionable tool messages to evaluate'); }); it('should extract user intent correctly', async () => { @@ -123,6 +119,6 @@ describe('Prompt Injection Detection Check', () => { const result = await promptInjectionDetectionCheck(contextWithError, 'test data', config); expect(result.tripwireTriggered).toBe(false); - expect(result.info.observation).toContain('Error during prompt injection detection check'); + expect(result.info.observation).toBe('No conversation history available'); }); }); diff --git a/src/base-client.ts b/src/base-client.ts index 80a062b..c03e4d5 100644 --- a/src/base-client.ts +++ b/src/base-client.ts @@ -109,7 +109,6 @@ export abstract class GuardrailsBaseClient { protected guardrails!: StageGuardrails; protected context!: GuardrailLLMContext; protected _resourceClient!: OpenAI; - protected _injectionLastCheckedIndex: number = 0; public raiseGuardrailErrors: boolean = false; /** @@ -527,14 +526,8 @@ export abstract class GuardrailsBaseClient { guardrailLlm: this.context.guardrailLlm, // Add conversation history methods getConversationHistory: () => conversationHistory, - getInjectionLastCheckedIndex: () => this._injectionLastCheckedIndex, - updateInjectionLastCheckedIndex: (newIndex: number) => { - this._injectionLastCheckedIndex = newIndex; - }, } as GuardrailLLMContext & { getConversationHistory(): any[]; - getInjectionLastCheckedIndex(): number; - updateInjectionLastCheckedIndex(index: number): void; }; } diff --git a/src/checks/prompt_injection_detection.ts b/src/checks/prompt_injection_detection.ts index a6f28be..31c6768 100644 --- a/src/checks/prompt_injection_detection.ts +++ b/src/checks/prompt_injection_detection.ts @@ -25,14 +25,10 @@ */ import { z } from 'zod'; -import { - CheckFn, - GuardrailResult, - GuardrailLLMContext, - GuardrailLLMContextWithHistory, -} from '../types'; +import { CheckFn, GuardrailResult, GuardrailLLMContext, GuardrailLLMContextWithHistory } from '../types'; import { defaultSpecRegistry } from '../registry'; import { LLMConfig, LLMOutput, runLLM } from './llm-base'; +import { parseConversationInput, POSSIBLE_CONVERSATION_KEYS } from '../utils/conversation'; /** * Configuration schema for the prompt injection detection guardrail. @@ -76,39 +72,60 @@ export type PromptInjectionDetectionOutput = z.infer = async (ctx, data, config): Promise => { try { - // Get conversation history and incremental checking state - const conversationHistory = ctx.getConversationHistory(); - if (!conversationHistory || conversationHistory.length === 0) { + const conversationHistory = safeGetConversationHistory(ctx); + const parsedDataMessages = parseConversationInput(data); + if (conversationHistory.length === 0 && parsedDataMessages.length === 0) { return createSkipResult( 'No conversation history available', config.confidence_threshold, - data + JSON.stringify([]) ); } - // Get incremental prompt injection detection checking state - const lastCheckedIndex = ctx.getInjectionLastCheckedIndex(); - - // Parse only new conversation data since last check - const { user_intent, new_llm_actions: initial_llm_actions } = parseConversationHistory( + const { recentMessages, actionableMessages, userIntent } = prepareConversationSlice( conversationHistory, - lastCheckedIndex + parsedDataMessages ); - // If no new actions found in conversation history, try parsing current response data - let new_llm_actions = initial_llm_actions; - if (new_llm_actions.length === 0) { - new_llm_actions = tryParseCurrentResponse(data); - } + const userGoalText = formatUserGoal(userIntent); + const checkedText = JSON.stringify(recentMessages, null, 2); - if (!new_llm_actions || new_llm_actions.length === 0 || !user_intent.most_recent_message) { + if (!userIntent.most_recent_message) { return createSkipResult( - 'No function calls or function call outputs to evaluate', + 'No LLM actions or user intent to evaluate', config.confidence_threshold, - data, - user_intent.most_recent_message || 'N/A', - new_llm_actions + checkedText, + userGoalText, + actionableMessages, + recentMessages ); } - // Format user context for analysis - let userGoalText: string; - if (user_intent.previous_context.length > 0) { - const contextText = user_intent.previous_context.map((msg) => `- ${msg}`).join('\n'); - userGoalText = `Most recent request: ${user_intent.most_recent_message} - -Previous context: -${contextText}`; - } else { - userGoalText = user_intent.most_recent_message; - } - - // Skip if the only new action is a user message (we don't check user alignment with their own goals) - if (new_llm_actions.length === 1 && new_llm_actions[0]?.role === 'user') { - ctx.updateInjectionLastCheckedIndex(conversationHistory.length); + if (actionableMessages.length === 0) { return createSkipResult( - 'Skipping check: only new action is a user message', + 'No actionable tool messages to evaluate', config.confidence_threshold, - data, + checkedText, userGoalText, - new_llm_actions + actionableMessages, + recentMessages ); } - // Format for LLM analysis - const analysisPrompt = `${PROMPT_INJECTION_DETECTION_CHECK_PROMPT} - -**User's goal:** ${userGoalText} -**LLM action:** ${JSON.stringify(new_llm_actions)}`; + const analysisPrompt = buildAnalysisPrompt(userGoalText, recentMessages, actionableMessages); - // Call LLM for analysis const analysis = await callPromptInjectionDetectionLLM(ctx, analysisPrompt, config); - // Update the last checked index now that we've successfully analyzed - ctx.updateInjectionLastCheckedIndex(conversationHistory.length); - - // Determine if tripwire should trigger const isMisaligned = analysis.flagged && analysis.confidence >= config.confidence_threshold; return { @@ -227,8 +211,9 @@ ${contextText}`; confidence: analysis.confidence, threshold: config.confidence_threshold, user_goal: userGoalText, - action: new_llm_actions, - checked_text: JSON.stringify(conversationHistory), + action: actionableMessages, + recent_messages: recentMessages, + checked_text: checkedText, }, }; } catch (error) { @@ -240,97 +225,102 @@ ${contextText}`; } }; -/** - * Parse conversation data incrementally, only analyzing new LLM actions. - * - * @param conversationHistory Full conversation history - * @param lastCheckedIndex Index of the last message we checked - * @returns Parsed conversation data with user intent and new LLM actions - */ -function parseConversationHistory( +function safeGetConversationHistory(ctx: PromptInjectionDetectionContext): any[] { + try { + const history = ctx.getConversationHistory(); + if (Array.isArray(history)) { + return history; + } + } catch { + // Fall through to empty array when conversation history is unavailable + } + return []; +} + +function prepareConversationSlice( conversationHistory: any[], - lastCheckedIndex: number -): ParsedConversation { - // Always get full user intent context for proper analysis - const user_intent = extractUserIntentFromMessages(conversationHistory); + parsedDataMessages: any[] +): { recentMessages: any[]; actionableMessages: any[]; userIntent: UserIntentDict } { + const historyMessages = Array.isArray(conversationHistory) ? conversationHistory : []; + const datasetMessages = Array.isArray(parsedDataMessages) ? parsedDataMessages : []; + + const sourceMessages = historyMessages.length > 0 ? historyMessages : datasetMessages; + let userIntent = extractUserIntentFromMessages(sourceMessages); + + let recentMessages = sliceMessagesAfterLatestUser(sourceMessages); + let actionableMessages = extractActionableMessages(recentMessages); + + if (actionableMessages.length === 0 && datasetMessages.length > 0 && historyMessages.length > 0) { + recentMessages = sliceMessagesAfterLatestUser(datasetMessages); + actionableMessages = extractActionableMessages(recentMessages); + if (!userIntent.most_recent_message) { + userIntent = extractUserIntentFromMessages(datasetMessages); + } + } + + return { recentMessages, actionableMessages, userIntent }; +} - // Get only new LLM actions since the last check - let new_llm_actions: any[]; - if (lastCheckedIndex >= conversationHistory.length) { - // No new actions since last check - new_llm_actions = []; - } else { - // Get actions from where we left off - const all_new_actions = conversationHistory.slice(lastCheckedIndex); +function sliceMessagesAfterLatestUser(messages: any[]): any[] { + if (!Array.isArray(messages) || messages.length === 0) { + return []; + } - // Filter to only include function calls and outputs (skip user/assistant text) - new_llm_actions = all_new_actions.filter(isFunctionCallOrOutput); + const lastUserIndex = findLastUserIndex(messages); + if (lastUserIndex >= 0) { + return messages.slice(lastUserIndex + 1); } - return { user_intent, new_llm_actions }; + return messages.slice(); } -/** - * Check if an action is a function call or function output that should be analyzed. - * - * @param action Action object to check - * @returns True if action should be analyzed for alignment - */ -function isFunctionCallOrOutput(action: any): boolean { - if (typeof action !== 'object' || action === null) { - return false; +function findLastUserIndex(messages: any[]): number { + for (let i = messages.length - 1; i >= 0; i -= 1) { + if (isUserMessageEntry(messages[i])) { + return i; + } } + return -1; +} - // Responses API formats - if (action.type === 'function_call' || action.type === 'function_call_output') { - return true; +function isUserMessageEntry(entry: any, seen: Set = new Set()): boolean { + if (!entry || typeof entry !== 'object') { + return false; } - // Chat completions API formats - if (action.role === 'assistant' && action.tool_calls?.length > 0) { - return true; // Assistant message with tool calls - } - if (action.role === 'tool') { - return true; // Tool response message + if (seen.has(entry)) { + return false; } + seen.add(entry); - return false; // Skip user messages, assistant text, etc. -} + if (entry.role === 'user') { + return true; + } -/** - * Extract text content from various message content formats. - * - * @param content Message content (string, array, or other) - * @returns Extracted text string - */ -function extractContentText(content: any): string { - if (typeof content === 'string') { - return content; + if (entry.type === 'message' && entry.role === 'user') { + return true; } - if (Array.isArray(content)) { - // For responses API format with content parts - return content - .filter((part) => part?.type === 'input_text' && typeof part.text === 'string') - .map((part) => part.text) - .join(' '); + + for (const key of NESTED_MESSAGE_KEYS) { + const value = (entry as Record)[key]; + if (Array.isArray(value)) { + for (const item of value) { + if (isUserMessageEntry(item, seen)) { + return true; + } + } + } } - return String(content || ''); + + return false; } -/** - * Extract user intent with full context from a list of messages. - * - * @param messages List of conversation messages - * @returns User intent dictionary with most recent message and previous context - */ function extractUserIntentFromMessages(messages: any[]): UserIntentDict { const userMessages: string[] = []; + const visited = new Set(); - // Extract all user messages in chronological order - for (const msg of messages) { - if (msg?.role === 'user') { - userMessages.push(extractContentText(msg.content)); - } + for (const message of messages) { + collectUserMessages(message, userMessages, visited); } if (userMessages.length === 0) { @@ -343,22 +333,171 @@ function extractUserIntentFromMessages(messages: any[]): UserIntentDict { }; } -/** - * Create result for skipped alignment checks (errors, no data, etc.). - * - * @param observation Description of why the check was skipped - * @param threshold Confidence threshold - * @param data Original data - * @param userGoal User goal (optional) - * @param action Action that was analyzed (optional) - * @returns GuardrailResult for skipped check - */ +function collectUserMessages(value: any, collected: string[], visited: Set): void { + if (!value || typeof value !== 'object') { + return; + } + + if (visited.has(value)) { + return; + } + visited.add(value); + + if (value.role === 'user') { + const text = extractUserMessageText(value); + if (text) { + collected.push(text); + } + } + + for (const key of NESTED_MESSAGE_KEYS) { + const nestedValue = (value as Record)[key]; + if (Array.isArray(nestedValue)) { + for (const item of nestedValue) { + collectUserMessages(item, collected, visited); + } + } + } +} + +function extractUserMessageText(message: any): string { + if (typeof message === 'string') { + return message; + } + + if (!message || typeof message !== 'object') { + return ''; + } + + if (typeof message.content === 'string') { + return message.content; + } + + if (Array.isArray(message.content)) { + const contentText = collectTextFromContent(message.content); + if (contentText) { + return contentText; + } + } + + if (typeof message.text === 'string') { + return message.text; + } + + if (typeof message.value === 'string') { + return message.value; + } + + return ''; +} + +function collectTextFromContent(content: any[]): string { + const parts: string[] = []; + + for (const item of content) { + if (item == null) { + continue; + } + + if (typeof item === 'string') { + if (item.trim().length > 0) { + parts.push(item.trim()); + } + continue; + } + + if (typeof item !== 'object') { + continue; + } + + if (typeof (item as { text?: string }).text === 'string') { + parts.push((item as { text: string }).text); + continue; + } + + if (typeof (item as { content?: string }).content === 'string') { + parts.push((item as { content: string }).content); + continue; + } + + if (Array.isArray((item as { content?: any[] }).content)) { + const nested = collectTextFromContent((item as { content: any[] }).content); + if (nested) { + parts.push(nested); + } + continue; + } + } + + return parts.join(' ').trim(); +} + +function extractActionableMessages(messages: any[]): any[] { + if (!Array.isArray(messages)) { + return []; + } + return messages.filter((message) => isActionableMessage(message)); +} + +function isActionableMessage(message: any, seen: Set = new Set()): boolean { + if (!message || typeof message !== 'object') { + return false; + } + + if (seen.has(message)) { + return false; + } + seen.add(message); + + if ( + message.type === 'function_call' || + message.type === 'function_call_output' || + message.type === 'tool_call' || + message.type === 'tool_result' + ) { + return true; + } + + if (message.role === 'assistant' && Array.isArray(message.tool_calls) && message.tool_calls.length > 0) { + return true; + } + + if (message.role === 'tool') { + return true; + } + + const content = (message as Record).content; + if (Array.isArray(content)) { + for (const part of content) { + if ( + part && + typeof part === 'object' && + ['tool_use', 'function_call', 'tool_result', 'tool_call', 'function_call_output'].includes( + (part as { type?: string }).type ?? '' + ) + ) { + return true; + } + } + } + + for (const key of NESTED_MESSAGE_KEYS) { + const nested = (message as Record)[key]; + if (Array.isArray(nested) && nested.some((item) => isActionableMessage(item, seen))) { + return true; + } + } + + return false; +} + function createSkipResult( observation: string, threshold: number, - data: string, + checkedText: string, userGoal: string = 'N/A', - action: any = null + action: any[] = [], + recentMessages: any[] = [] ): GuardrailResult { return { tripwireTriggered: false, @@ -369,28 +508,51 @@ function createSkipResult( confidence: 0.0, threshold, user_goal: userGoal, - action: action || [], - checked_text: data, + action: action ?? [], + recent_messages: recentMessages, + checked_text: checkedText, }, }; } -/** - * Try to parse current response data for tool calls (fallback mechanism). - * - * @param data Response data that might contain JSON - * @returns Array of actions found, empty if none - */ -function tryParseCurrentResponse(data: string): any[] { - try { - const currentResponse = JSON.parse(data); - if (currentResponse?.choices?.[0]?.message?.tool_calls?.length > 0) { - return [currentResponse.choices[0].message]; - } - } catch { - // data is not JSON, ignore +function formatUserGoal(userIntent: UserIntentDict): string { + if (!userIntent.most_recent_message) { + return 'N/A'; } - return []; + + if (userIntent.previous_context.length === 0) { + return userIntent.most_recent_message; + } + + const contextText = userIntent.previous_context.map((msg) => `- ${msg}`).join('\n'); + return `Most recent request: ${userIntent.most_recent_message} + +Previous context: +${contextText}`; +} + +function buildAnalysisPrompt( + userGoalText: string, + recentMessages: any[], + actionableMessages: any[] +): string { + const recentMessagesText = + recentMessages.length > 0 ? JSON.stringify(recentMessages, null, 2) : '[]'; + const actionableMessagesText = + actionableMessages.length > 0 ? JSON.stringify(actionableMessages, null, 2) : '[]'; + + return `${PROMPT_INJECTION_DETECTION_CHECK_PROMPT} + +${STRICT_JSON_INSTRUCTION} + +Most recent user goal: +${userGoalText} + +Recent conversation after latest user turn: +${recentMessagesText} + +LLM actions to evaluate: +${actionableMessagesText}`; } /** diff --git a/src/evals/core/async-engine.ts b/src/evals/core/async-engine.ts index 627f32e..a170c36 100644 --- a/src/evals/core/async-engine.ts +++ b/src/evals/core/async-engine.ts @@ -6,7 +6,9 @@ */ import { Context, RunEngine, Sample, SampleResult } from './types'; -import { ConfiguredGuardrail, runGuardrails } from '../../runtime'; +import { ConfiguredGuardrail } from '../../runtime'; +import { GuardrailLLMContextWithHistory, GuardrailResult } from '../../types'; +import { parseConversationInput } from '../../utils/conversation'; /** * Runs guardrail evaluations asynchronously. @@ -67,55 +69,143 @@ export class AsyncRunEngine implements RunEngine { * @returns Evaluation result for the sample */ private async evaluateSample(context: Context, sample: Sample): Promise { - try { - // Use the actual guardrail configurations that were loaded - const bundle = { - guardrails: this.guardrails.map((g) => ({ - name: g.definition.name, - config: g.config, - })), - }; - - const results = await runGuardrails(sample.data, bundle, context); + const triggered: Record = {}; + const details: Record = {}; - const triggered: Record = {}; - const details: Record = {}; - - // Initialize all guardrails as not triggered - for (const name of this.guardrailNames) { - triggered[name] = false; - } + for (const name of this.guardrailNames) { + triggered[name] = false; + } - // Process results - for (let i = 0; i < results.length; i++) { - const result = results[i]; - const name = this.guardrailNames[i] || 'unknown'; - triggered[name] = result.tripwireTriggered; - if (result.info) { - details[name] = result.info; + try { + for (let i = 0; i < this.guardrails.length; i += 1) { + const guardrail = this.guardrails[i]; + const name = this.guardrailNames[i] || guardrail.definition.name || 'unknown'; + + try { + const result = await this.runGuardrailWithIncrementalSupport( + context, + guardrail, + sample.data + ); + + triggered[name] = result.tripwireTriggered; + if (result.info) { + details[name] = result.info; + } + } catch (guardrailError) { + console.error(`Error running guardrail ${name} on sample ${sample.id}:`, guardrailError); + triggered[name] = false; + details[name] = { + error: guardrailError instanceof Error ? guardrailError.message : String(guardrailError), + }; } } - + } catch (error) { + console.error(`Error evaluating sample ${sample.id}:`, error); return { id: sample.id, expectedTriggers: sample.expectedTriggers, triggered, - details, + details: { + ...details, + error: error instanceof Error ? error.message : String(error), + }, }; - } catch (error) { - console.error(`Error evaluating sample ${sample.id}:`, error); + } + + return { + id: sample.id, + expectedTriggers: sample.expectedTriggers, + triggered, + details, + }; + } + + private async runGuardrailWithIncrementalSupport( + context: Context, + guardrail: ConfiguredGuardrail, + sampleData: string + ): Promise { + if (this.isPromptInjectionGuardrail(guardrail)) { + return await this.runPromptInjectionIncremental(context, guardrail, sampleData); + } + + return await guardrail.run(context as any, sampleData); + } + + private isPromptInjectionGuardrail(guardrail: ConfiguredGuardrail): boolean { + const normalized = (guardrail.definition.name ?? '') + .replace(/\s+/g, ' ') + .trim() + .toLowerCase(); + return normalized === 'prompt injection detection'; + } + + private async runPromptInjectionIncremental( + context: Context, + guardrail: ConfiguredGuardrail, + sampleData: string + ): Promise { + const conversation = parseConversationInput(sampleData); + + if (conversation.length === 0) { + const guardrailContext = this.createPromptInjectionContext(context, []); + return await guardrail.run(guardrailContext as GuardrailLLMContextWithHistory, sampleData); + } + + let finalResult: GuardrailResult | null = null; - const triggered: Record = {}; - for (const name of this.guardrailNames) { - triggered[name] = false; + for (let turnIndex = 0; turnIndex < conversation.length; turnIndex += 1) { + const historySlice = conversation.slice(0, turnIndex + 1); + const guardrailContext = this.createPromptInjectionContext( + context, + historySlice + ); + const serializedHistory = safeStringify(historySlice, sampleData); + + const result = await guardrail.run( + guardrailContext as GuardrailLLMContextWithHistory, + serializedHistory + ); + + finalResult = result; + + if (result.tripwireTriggered) { + break; } + } + if (!finalResult) { return { - id: sample.id, - expectedTriggers: sample.expectedTriggers, - triggered, - details: { error: error instanceof Error ? error.message : String(error) }, + tripwireTriggered: false, + info: { + guardrail_name: guardrail.definition.name, + observation: 'No conversation turns evaluated', + flagged: false, + confidence: 0.0, + checked_text: sampleData, + }, }; } + + return finalResult; + } + + private createPromptInjectionContext( + context: Context, + conversationHistory: any[] + ): GuardrailLLMContextWithHistory { + return { + guardrailLlm: context.guardrailLlm, + getConversationHistory: () => conversationHistory, + }; + } +} + +function safeStringify(value: unknown, fallback: string): string { + try { + return JSON.stringify(value, null, 2); + } catch { + return fallback; } } diff --git a/src/types.ts b/src/types.ts index bdc9c15..5bd5a9d 100644 --- a/src/types.ts +++ b/src/types.ts @@ -29,10 +29,6 @@ export interface GuardrailLLMContext { export interface GuardrailLLMContextWithHistory extends GuardrailLLMContext { /** Get the full conversation history */ getConversationHistory(): any[]; - /** Get the index of the last message that was checked for prompt injection detection */ - getInjectionLastCheckedIndex(): number; - /** Update the index of the last message that was checked for prompt injection detection */ - updateInjectionLastCheckedIndex(index: number): void; } /** diff --git a/src/utils/conversation.ts b/src/utils/conversation.ts new file mode 100644 index 0000000..3964e39 --- /dev/null +++ b/src/utils/conversation.ts @@ -0,0 +1,59 @@ +const POSSIBLE_CONVERSATION_KEYS = [ + 'messages', + 'conversation', + 'conversation_history', + 'conversationHistory', + 'recent_messages', + 'recentMessages', + 'turns', + 'output', + 'outputs', +] as const; + +/** + * Parse conversation-like input into a flat list of message objects. + * + * Accepts raw JSON strings, arrays, or objects that embed conversation arrays under + * several common keys. Returns an empty array when no conversation data is found. + */ +export function parseConversationInput(rawInput: unknown): any[] { + if (Array.isArray(rawInput)) { + return rawInput; + } + + if (rawInput == null) { + return []; + } + + if (typeof rawInput === 'string') { + const trimmed = rawInput.trim(); + if (!trimmed) { + return []; + } + try { + const parsed = JSON.parse(trimmed); + return parseConversationInput(parsed); + } catch { + return []; + } + } + + if (typeof rawInput === 'object') { + for (const key of POSSIBLE_CONVERSATION_KEYS) { + const value = (rawInput as Record)[key]; + if (Array.isArray(value)) { + return value; + } + if (value && typeof value === 'object') { + const nested = parseConversationInput(value); + if (nested.length > 0) { + return nested; + } + } + } + } + + return []; +} + +export { POSSIBLE_CONVERSATION_KEYS };