diff --git a/src/__tests__/unit/checks/topical-alignment.test.ts b/src/__tests__/unit/checks/topical-alignment.test.ts
index 1324dca..7705826 100644
--- a/src/__tests__/unit/checks/topical-alignment.test.ts
+++ b/src/__tests__/unit/checks/topical-alignment.test.ts
@@ -2,13 +2,20 @@
  * Tests for the topical alignment guardrail.
  */

-import { describe, it, expect, vi, afterEach } from 'vitest';
+import { describe, it, expect, vi, beforeEach } from 'vitest';
+import { GuardrailLLMContext } from '../../../types';

-const buildFullPromptMock = vi.fn((prompt: string) => `FULL:${prompt}`);
+const createLLMCheckFnMock = vi.fn(() => 'mocked-guardrail');
 const registerMock = vi.fn();

 vi.mock('../../../checks/llm-base', () => ({
-  buildFullPrompt: buildFullPromptMock,
+  createLLMCheckFn: createLLMCheckFnMock,
+  LLMConfig: {
+    omit: vi.fn(() => ({
+      extend: vi.fn(() => ({})),
+    })),
+  },
+  LLMOutput: {},
 }));

 vi.mock('../../../registry', () => ({
@@ -17,9 +24,23 @@ vi.mock('../../../registry', () => ({
   },
 }));

-describe('topicalAlignmentCheck', () => {
-  afterEach(() => {
-    buildFullPromptMock.mockClear();
+describe('topicalAlignment guardrail', () => {
+  beforeEach(() => {
+    registerMock.mockClear();
+    createLLMCheckFnMock.mockClear();
+  });
+
+  it('is created via createLLMCheckFn', async () => {
+    const { topicalAlignment } = await import('../../../checks/topical-alignment');
+
+    expect(topicalAlignment).toBe('mocked-guardrail');
+    expect(createLLMCheckFnMock).toHaveBeenCalled();
+  });
+});
+
+describe('topicalAlignment integration tests', () => {
+  beforeEach(() => {
+    vi.resetModules();
   });

   interface TopicalAlignmentConfig {
@@ -28,12 +49,6 @@ describe('topicalAlignmentCheck', () => {
     system_prompt_details: string;
   }

-  const config: TopicalAlignmentConfig = {
-    model: 'gpt-topic',
-    confidence_threshold: 0.6,
-    system_prompt_details: 'Stay on topic about finance.',
-  };
-
   interface MockLLMResponse {
     choices: Array<{
       message: {
@@ -42,8 +57,13 @@ describe('topicalAlignmentCheck', () => {
     }>;
   }

-  const makeCtx = (response: MockLLMResponse) => {
-    const create = vi.fn().mockResolvedValue(response);
+  const makeCtx = (response: MockLLMResponse, capturedParams?: { value?: unknown }) => {
+    const create = vi.fn().mockImplementation((params) => {
+      if (capturedParams) {
+        capturedParams.value = params;
+      }
+      return Promise.resolve(response);
+    });
     return {
       ctx: {
         guardrailLlm: {
@@ -52,83 +72,208 @@ describe('topicalAlignmentCheck', () => {
             create,
           },
         },
+        baseURL: 'https://api.openai.com/v1',
       },
-      },
+      } as GuardrailLLMContext,
       create,
     };
   };

-  it('triggers when LLM flags off-topic content above threshold', async () => {
-    const { topicalAlignmentCheck } = await import('../../../checks/topical-alignment');
-    const { ctx, create } = makeCtx({
-      choices: [
-        {
-          message: {
-            content: JSON.stringify({ flagged: true, confidence: 0.8 }),
+  it('triggers when LLM flags off-topic content above threshold with gpt-4', async () => {
+    vi.doUnmock('../../../checks/llm-base');
+    vi.doUnmock('../../../checks/topical-alignment');
+
+    const { topicalAlignment } = await import('../../../checks/topical-alignment');
+    const capturedParams: { value?: unknown } = {};
+    const { ctx, create } = makeCtx(
+      {
+        choices: [
+          {
+            message: {
+              content: JSON.stringify({ flagged: true, confidence: 0.8 }),
+            },
           },
-      ],
-    });
+        ],
+      },
+      capturedParams
+    );

-    const result = await topicalAlignmentCheck(ctx, 'Discussing sports', config);
+    const config: TopicalAlignmentConfig = {
+      model: 'gpt-4',
+      confidence_threshold: 0.7,
+      system_prompt_details: 'Stay on topic about finance.',
+    };

-    expect(buildFullPromptMock).toHaveBeenCalled();
-    expect(create).toHaveBeenCalledWith({
-      messages: [
-        { role: 'system', content: expect.stringContaining('Stay on topic about finance.') },
-        { role: 'user', content: 'Discussing sports' },
-      ],
-      model: 'gpt-topic',
-      temperature: 0.0,
-      response_format: { type: 'json_object' },
-    });
+    const result = await topicalAlignment(ctx, 'Discussing sports', config);
+
+    expect(create).toHaveBeenCalled();
+    const params = capturedParams.value as Record<string, unknown>;
+    expect(params.model).toBe('gpt-4');
+    expect(params.temperature).toBe(0.0); // gpt-4 uses temperature 0
+    expect(params.response_format).toEqual({ type: 'json_object' });

     expect(result.tripwireTriggered).toBe(true);
     expect(result.info?.flagged).toBe(true);
     expect(result.info?.confidence).toBe(0.8);
   });

-  it('returns failure info when no content is returned', async () => {
-    const { topicalAlignmentCheck } = await import('../../../checks/topical-alignment');
+  it('uses temperature 1.0 for gpt-5 models (which do not support temperature 0)', async () => {
+    vi.doUnmock('../../../checks/llm-base');
+    vi.doUnmock('../../../checks/topical-alignment');
+
+    const { topicalAlignment } = await import('../../../checks/topical-alignment');
+    const capturedParams: { value?: unknown } = {};
+    const { ctx, create } = makeCtx(
+      {
+        choices: [
+          {
+            message: {
+              content: JSON.stringify({ flagged: false, confidence: 0.2 }),
+            },
+          },
+        ],
+      },
+      capturedParams
+    );
+
+    const config: TopicalAlignmentConfig = {
+      model: 'gpt-5',
+      confidence_threshold: 0.7,
+      system_prompt_details: 'Stay on topic about technology.',
+    };
+
+    const result = await topicalAlignment(ctx, 'Discussing AI and ML', config);
+
+    expect(create).toHaveBeenCalled();
+    const params = capturedParams.value as Record<string, unknown>;
+    expect(params.model).toBe('gpt-5');
+    expect(params.temperature).toBe(1.0); // gpt-5 uses temperature 1.0, not 0
+    expect(params.response_format).toEqual({ type: 'json_object' });
+    expect(result.tripwireTriggered).toBe(false);
+  });
+
+  it('works with gpt-4o model', async () => {
+    vi.doUnmock('../../../checks/llm-base');
+    vi.doUnmock('../../../checks/topical-alignment');
+
+    const { topicalAlignment } = await import('../../../checks/topical-alignment');
+    const capturedParams: { value?: unknown } = {};
+    const { ctx, create } = makeCtx(
+      {
+        choices: [
+          {
+            message: {
+              content: JSON.stringify({ flagged: true, confidence: 0.9 }),
+            },
+          },
+        ],
+      },
+      capturedParams
+    );
+
+    const config: TopicalAlignmentConfig = {
+      model: 'gpt-4o',
+      confidence_threshold: 0.8,
+      system_prompt_details: 'Stay on topic about healthcare.',
+    };
+
+    const result = await topicalAlignment(ctx, 'Talking about cars', config);
+
+    expect(create).toHaveBeenCalled();
+    const params = capturedParams.value as Record<string, unknown>;
+    expect(params.model).toBe('gpt-4o');
+    expect(params.temperature).toBe(0.0); // gpt-4o uses temperature 0
+    expect(result.tripwireTriggered).toBe(true);
+  });
+
+  it('works with gpt-3.5-turbo model', async () => {
+    vi.doUnmock('../../../checks/llm-base');
+    vi.doUnmock('../../../checks/topical-alignment');
+
+    const { topicalAlignment } = await import('../../../checks/topical-alignment');
+    const capturedParams: { value?: unknown } = {};
+    const { ctx, create } = makeCtx(
+      {
+        choices: [
+          {
+            message: {
+              content: JSON.stringify({ flagged: false, confidence: 0.3 }),
+            },
+          },
+        ],
+      },
+      capturedParams
+    );
+
+    const config: TopicalAlignmentConfig = {
+      model: 'gpt-3.5-turbo',
+      confidence_threshold: 0.7,
+      system_prompt_details: 'Stay on topic about education.',
+    };
+
+    const result = await topicalAlignment(ctx, 'Discussing teaching methods', config);
+
+    expect(create).toHaveBeenCalled();
+    const params = capturedParams.value as Record<string, unknown>;
+    expect(params.model).toBe('gpt-3.5-turbo');
+    expect(params.temperature).toBe(0.0);
+    expect(result.tripwireTriggered).toBe(false);
+  });
+
+  it('does not trigger when confidence is below threshold', async () => {
+    vi.doUnmock('../../../checks/llm-base');
+    vi.doUnmock('../../../checks/topical-alignment');
+
+    const { topicalAlignment } = await import('../../../checks/topical-alignment');
     const { ctx } = makeCtx({
-      choices: [{ message: { content: '' } }],
+      choices: [
+        {
+          message: {
+            content: JSON.stringify({ flagged: true, confidence: 0.5 }),
+          },
+        },
+      ],
     });

-    const consoleSpy = vi.spyOn(console, 'error').mockImplementation(() => {});
-    const result = await topicalAlignmentCheck(ctx, 'Hi', config);
+    const config: TopicalAlignmentConfig = {
+      model: 'gpt-4',
+      confidence_threshold: 0.7,
+      system_prompt_details: 'Stay on topic about finance.',
+    };

-    consoleSpy.mockRestore();
+    const result = await topicalAlignment(ctx, 'Maybe off topic', config);

     expect(result.tripwireTriggered).toBe(false);
-    expect(result.info?.error).toBeDefined();
+    expect(result.info?.flagged).toBe(true);
+    expect(result.info?.confidence).toBe(0.5);
   });

-  it('handles unexpected errors gracefully', async () => {
+  it('handles execution failures gracefully', async () => {
+    vi.doUnmock('../../../checks/llm-base');
+    vi.doUnmock('../../../checks/topical-alignment');
+
     const consoleSpy = vi.spyOn(console, 'error').mockImplementation(() => {});
-    const { topicalAlignmentCheck } = await import('../../../checks/topical-alignment');
+    const { topicalAlignment } = await import('../../../checks/topical-alignment');
     const ctx = {
       guardrailLlm: {
         chat: {
           completions: {
-            create: vi.fn().mockRejectedValue(new Error('timeout')),
+            create: vi.fn().mockRejectedValue(new Error('API timeout')),
           },
         },
+        baseURL: 'https://api.openai.com/v1',
       },
+    } as GuardrailLLMContext;
+
+    const config: TopicalAlignmentConfig = {
+      model: 'gpt-4',
+      confidence_threshold: 0.7,
+      system_prompt_details: 'Stay on topic about finance.',
     };

-    interface MockContext {
-      guardrailLlm: {
-        chat: {
-          completions: {
-            create: ReturnType<typeof vi.fn>;
-          };
-        };
-      };
-    }
-
-    const result = await topicalAlignmentCheck(ctx as MockContext, 'Test', config);
+    const result = await topicalAlignment(ctx, 'Test text', config);

     expect(result.tripwireTriggered).toBe(false);
-    expect(result.info?.error).toContain('timeout');
+    expect(result.executionFailed).toBe(true);

     consoleSpy.mockRestore();
   });
 });
diff --git a/src/__tests__/unit/checks/user-defined-llm.test.ts b/src/__tests__/unit/checks/user-defined-llm.test.ts
index 0b367e1..44ba367 100644
--- a/src/__tests__/unit/checks/user-defined-llm.test.ts
+++ b/src/__tests__/unit/checks/user-defined-llm.test.ts
@@ -2,115 +2,311 @@
  * Tests for the user-defined LLM guardrail.
  */

-import { describe, it, expect, vi } from 'vitest';
-import {
-  userDefinedLLMCheck,
-  UserDefinedConfig,
-  UserDefinedContext,
-} from '../../../checks/user-defined-llm';
-
-const makeCtx = () => {
-  const create = vi.fn();
-  const ctx: UserDefinedContext = {
-    guardrailLlm: {
-      chat: {
-        completions: {
-          create,
-        },
-      },
-    },
-  };
-  return { ctx, create };
-};
+import { describe, it, expect, vi, beforeEach } from 'vitest';
+import { GuardrailLLMContext } from '../../../types';
+
+const createLLMCheckFnMock = vi.fn(() => 'mocked-guardrail');
+const registerMock = vi.fn();
+
+vi.mock('../../../checks/llm-base', () => ({
+  createLLMCheckFn: createLLMCheckFnMock,
+  LLMConfig: {
+    omit: vi.fn(() => ({
+      extend: vi.fn(() => ({})),
+    })),
+  },
+  LLMOutput: {
+    extend: vi.fn(() => ({})),
+  },
+}));

-const config = UserDefinedConfig.parse({
-  model: 'gpt-test',
-  confidence_threshold: 0.7,
-  system_prompt_details: 'Only allow positive comments.',
+vi.mock('../../../registry', () => ({
+  defaultSpecRegistry: {
+    register: registerMock,
+  },
+}));
+
+describe('userDefinedLLM guardrail', () => {
+  beforeEach(() => {
+    registerMock.mockClear();
+    createLLMCheckFnMock.mockClear();
+  });
+
+  it('is created via createLLMCheckFn', async () => {
+    const { userDefinedLLM } = await import('../../../checks/user-defined-llm');
+
+    expect(userDefinedLLM).toBe('mocked-guardrail');
+    expect(createLLMCheckFnMock).toHaveBeenCalled();
+  });
 });

-describe('userDefinedLLMCheck', () => {
-  it('triggers tripwire when flagged above threshold from JSON response', async () => {
-    const { ctx, create } = makeCtx();
-    create.mockResolvedValue({
-      choices: [
-        {
-          message: {
-            content: JSON.stringify({ flagged: true, confidence: 0.95, reason: 'negative tone' }),
+describe('userDefinedLLM integration tests', () => {
+  beforeEach(() => {
+    vi.resetModules();
+  });
+
+  interface UserDefinedConfig {
+    model: string;
+    confidence_threshold: number;
+    system_prompt_details: string;
+  }
+
+  interface MockLLMResponse {
+    choices: Array<{
+      message: {
+        content: string;
+      };
+    }>;
+  }
+
+  const makeCtx = (response: MockLLMResponse, capturedParams?: { value?: unknown }) => {
+    const create = vi.fn().mockImplementation((params) => {
+      if (capturedParams) {
+        capturedParams.value = params;
+      }
+      return Promise.resolve(response);
+    });
+    return {
+      ctx: {
+        guardrailLlm: {
+          chat: {
+            completions: {
+              create,
+            },
           },
+          baseURL: 'https://api.openai.com/v1',
         },
-      ],
-    });
+      } as unknown as GuardrailLLMContext,
+      create,
+    };
+  };

-    const result = await userDefinedLLMCheck(ctx, 'This is bad.', config);
+  it('triggers tripwire when flagged above threshold with gpt-4', async () => {
+    vi.doUnmock('../../../checks/llm-base');
+    vi.doUnmock('../../../checks/user-defined-llm');

-    expect(create).toHaveBeenCalledWith({
-      messages: [
-        { role: 'system', content: expect.stringContaining('Only allow positive comments.') },
-        { role: 'user', content: 'This is bad.' },
-      ],
-      model: 'gpt-test',
-      temperature: 0.0,
-      response_format: { type: 'json_object' },
-      safety_identifier: 'openai-guardrails-js',
-    });
+    const { userDefinedLLM } = await import('../../../checks/user-defined-llm');
+    const capturedParams: { value?: unknown } = {};
+    const { ctx, create } = makeCtx(
+      {
+        choices: [
+          {
+            message: {
+              content: JSON.stringify({ flagged: true, confidence: 0.95, reason: 'negative tone' }),
+            },
+          },
+        ],
+      },
+      capturedParams
+    );
+
+    const config: UserDefinedConfig = {
+      model: 'gpt-4',
+      confidence_threshold: 0.7,
+      system_prompt_details: 'Only allow positive comments.',
+    };
+
+    const result = await userDefinedLLM(ctx, 'This is bad.', config);
+
+    expect(create).toHaveBeenCalled();
+    const params = capturedParams.value as Record<string, unknown>;
+    expect(params.model).toBe('gpt-4');
+    expect(params.temperature).toBe(0.0);
+    expect(params.response_format).toEqual({ type: 'json_object' });

     expect(result.tripwireTriggered).toBe(true);
     expect(result.info?.flagged).toBe(true);
     expect(result.info?.confidence).toBe(0.95);
-    expect(result.info?.reason).toBe('negative tone');
   });

-  it('falls back to text parsing when response_format is unsupported', async () => {
-    const { ctx, create } = makeCtx();
-    interface OpenAIError extends Error {
-      error: {
-        param: string;
-        code?: string;
-        message?: string;
-      };
-    }
-
-    const errorObj = new Error('format not supported') as OpenAIError;
-    errorObj.error = { param: 'response_format' };
-    create.mockRejectedValueOnce(errorObj);
-    create.mockResolvedValueOnce({
+  it('uses temperature 1.0 for gpt-5 models', async () => {
+    vi.doUnmock('../../../checks/llm-base');
+    vi.doUnmock('../../../checks/user-defined-llm');
+
+    const { userDefinedLLM } = await import('../../../checks/user-defined-llm');
+    const capturedParams: { value?: unknown } = {};
+    const { ctx, create } = makeCtx(
+      {
+        choices: [
+          {
+            message: {
+              content: JSON.stringify({ flagged: false, confidence: 0.2 }),
+            },
+          },
+        ],
+      },
+      capturedParams
+    );
+
+    const config: UserDefinedConfig = {
+      model: 'gpt-5',
+      confidence_threshold: 0.7,
+      system_prompt_details: 'Only allow technical content.',
+    };
+
+    const result = await userDefinedLLM(ctx, 'This is technical content.', config);
+
+    expect(create).toHaveBeenCalled();
+    const params = capturedParams.value as Record<string, unknown>;
+    expect(params.model).toBe('gpt-5');
+    expect(params.temperature).toBe(1.0); // gpt-5 uses temperature 1.0
+    expect(params.response_format).toEqual({ type: 'json_object' });
+    expect(result.tripwireTriggered).toBe(false);
+  });
+
+  it('works with gpt-4o model', async () => {
+    vi.doUnmock('../../../checks/llm-base');
+    vi.doUnmock('../../../checks/user-defined-llm');
+
+    const { userDefinedLLM } = await import('../../../checks/user-defined-llm');
+    const capturedParams: { value?: unknown } = {};
+    const { ctx, create } = makeCtx(
+      {
+        choices: [
+          {
+            message: {
+              content: JSON.stringify({ flagged: true, confidence: 0.9 }),
+            },
+          },
+        ],
+      },
+      capturedParams
+    );
+
+    const config: UserDefinedConfig = {
+      model: 'gpt-4o',
+      confidence_threshold: 0.8,
+      system_prompt_details: 'Flag inappropriate language.',
+    };
+
+    const result = await userDefinedLLM(ctx, 'Bad words here', config);
+
+    expect(create).toHaveBeenCalled();
+    const params = capturedParams.value as Record<string, unknown>;
+    expect(params.model).toBe('gpt-4o');
+    expect(params.temperature).toBe(0.0);
+    expect(result.tripwireTriggered).toBe(true);
+  });
+
+  it('works with gpt-3.5-turbo model', async () => {
+    vi.doUnmock('../../../checks/llm-base');
+    vi.doUnmock('../../../checks/user-defined-llm');
+
+    const { userDefinedLLM } = await import('../../../checks/user-defined-llm');
+    const capturedParams: { value?: unknown } = {};
+    const { ctx, create } = makeCtx(
+      {
+        choices: [
+          {
+            message: {
+              content: JSON.stringify({ flagged: false, confidence: 0.1 }),
+            },
+          },
+        ],
+      },
+      capturedParams
+    );
+
+    const config: UserDefinedConfig = {
+      model: 'gpt-3.5-turbo',
+      confidence_threshold: 0.7,
+      system_prompt_details: 'Check for spam.',
+    };
+
+    const result = await userDefinedLLM(ctx, 'Normal message', config);
+
+    expect(create).toHaveBeenCalled();
+    const params = capturedParams.value as Record<string, unknown>;
+    expect(params.model).toBe('gpt-3.5-turbo');
+    expect(params.temperature).toBe(0.0);
+    expect(result.tripwireTriggered).toBe(false);
+  });
+
+  it('does not trigger when confidence is below threshold', async () => {
+    vi.doUnmock('../../../checks/llm-base');
+    vi.doUnmock('../../../checks/user-defined-llm');
+
+    const { userDefinedLLM } = await import('../../../checks/user-defined-llm');
+    const { ctx } = makeCtx({
       choices: [
         {
           message: {
-            content: 'flagged: false, confidence: 0.4, reason: "acceptable"',
+            content: JSON.stringify({ flagged: true, confidence: 0.5 }),
           },
         },
       ],
     });

-    const result = await userDefinedLLMCheck(ctx, 'All good here.', config);
+    const config: UserDefinedConfig = {
+      model: 'gpt-4',
+      confidence_threshold: 0.7,
+      system_prompt_details: 'Custom check.',
+    };
+
+    const result = await userDefinedLLM(ctx, 'Maybe problematic', config);

-    expect(create).toHaveBeenCalledTimes(2);
     expect(result.tripwireTriggered).toBe(false);
-    expect(result.info?.flagged).toBe(false);
-    expect(result.info?.confidence).toBe(0.4);
-    expect(result.info?.reason).toBe('acceptable');
+    expect(result.info?.flagged).toBe(true);
+    expect(result.info?.confidence).toBe(0.5);
   });

-  it('returns execution failure metadata when other errors occur', async () => {
-    const { ctx, create } = makeCtx();
-    create.mockRejectedValueOnce(new Error('network down'));
+  it('handles execution failures gracefully', async () => {
+    vi.doUnmock('../../../checks/llm-base');
+    vi.doUnmock('../../../checks/user-defined-llm');
+
+    const consoleSpy = vi.spyOn(console, 'error').mockImplementation(() => {});
+    const { userDefinedLLM } = await import('../../../checks/user-defined-llm');
+    const ctx = {
+      guardrailLlm: {
+        chat: {
+          completions: {
+            create: vi.fn().mockRejectedValue(new Error('Network error')),
+          },
+        },
+        baseURL: 'https://api.openai.com/v1',
+      },
+    } as unknown as GuardrailLLMContext;
+
+    const config: UserDefinedConfig = {
+      model: 'gpt-4',
+      confidence_threshold: 0.7,
+      system_prompt_details: 'Custom check.',
+    };

-    const result = await userDefinedLLMCheck(ctx, 'Hello', config);
+    const result = await userDefinedLLM(ctx, 'Test text', config);

     expect(result.tripwireTriggered).toBe(false);
     expect(result.executionFailed).toBe(true);
-    expect(result.info?.error_message).toContain('network down');
+    consoleSpy.mockRestore();
   });

-  it('handles missing content gracefully', async () => {
-    const { ctx, create } = makeCtx();
-    create.mockResolvedValue({ choices: [{ message: {} }] });
+  it('supports optional reason field in output', async () => {
+    vi.doUnmock('../../../checks/llm-base');
+    vi.doUnmock('../../../checks/user-defined-llm');

-    const result = await userDefinedLLMCheck(ctx, 'Test', config);
+    const { userDefinedLLM } = await import('../../../checks/user-defined-llm');
+    const { ctx } = makeCtx({
+      choices: [
+        {
+          message: {
+            content: JSON.stringify({
+              flagged: true,
+              confidence: 0.9,
+              reason: 'Contains profanity',
+            }),
+          },
+        },
+      ],
+    });

-    expect(result.tripwireTriggered).toBe(false);
-    expect(result.executionFailed).toBe(true);
-    expect(result.info?.error_message).toBe('No response content from LLM');
+    const config: UserDefinedConfig = {
+      model: 'gpt-4',
+      confidence_threshold: 0.7,
+      system_prompt_details: 'Flag profanity.',
+    };
+
+    const result = await userDefinedLLM(ctx, 'Bad words', config);
+
+    expect(result.tripwireTriggered).toBe(true);
+    expect(result.info?.reason).toBe('Contains profanity');
   });
 });
diff --git a/src/checks/topical-alignment.ts b/src/checks/topical-alignment.ts
index 12b0bff..0e72da6 100644
--- a/src/checks/topical-alignment.ts
+++ b/src/checks/topical-alignment.ts
@@ -7,22 +7,17 @@
  */

 import { z } from 'zod';
-import { CheckFn, GuardrailResult } from '../types';
-import { defaultSpecRegistry } from '../registry';
-import { buildFullPrompt } from './llm-base';
+import { CheckFn, GuardrailLLMContext } from '../types';
+import { LLMConfig, LLMOutput, createLLMCheckFn } from './llm-base';

 /**
  * Configuration for topical alignment guardrail.
  *
  * Extends LLMConfig with a required business scope for content checks.
  */
-export const TopicalAlignmentConfig = z.object({
-  /** The LLM model to use for content checking */
-  model: z.string(),
-  /** Minimum confidence score (0.0 to 1.0) required to trigger the guardrail. Defaults to 0.7. */
-  confidence_threshold: z.number().min(0.0).max(1.0).default(0.7),
+export const TopicalAlignmentConfig = LLMConfig.omit({ system_prompt_details: true }).extend({
   /** Description of the allowed business scope or on-topic context */
-  system_prompt_details: z.string(),
+  system_prompt_details: z.string().describe('Description of the allowed business scope or on-topic context'),
 });

 export type TopicalAlignmentConfig = z.infer<typeof TopicalAlignmentConfig>;
@@ -30,22 +25,12 @@ export type TopicalAlignmentConfig = z.infer<typeof TopicalAlignmentConfig>;
 /**
  * Context requirements for the topical alignment guardrail.
  */
-export const TopicalAlignmentContext = z.object({
-  /** OpenAI client for LLM operations */
-  guardrailLlm: z.any(),
-});
-
-export type TopicalAlignmentContext = z.infer<typeof TopicalAlignmentContext>;
+export type TopicalAlignmentContext = GuardrailLLMContext;

 /**
  * Output schema for topical alignment analysis.
  */
-export const TopicalAlignmentOutput = z.object({
-  /** Whether the content was flagged as off-topic */
-  flagged: z.boolean(),
-  /** Confidence score (0.0 to 1.0) that the input is off-topic */
-  confidence: z.number().min(0.0).max(1.0),
-});
+export const TopicalAlignmentOutput = LLMOutput;

 export type TopicalAlignmentOutput = z.infer<typeof TopicalAlignmentOutput>;

@@ -62,86 +47,14 @@ that strays from the allowed topics.`;
 /**
  * Topical alignment guardrail.
  *
- * Checks that the content stays within the defined business scope.
- *
- * @param ctx Guardrail context containing the LLM client.
- * @param data Text to analyze for topical alignment.
- * @param config Configuration for topical alignment detection.
- * @returns GuardrailResult containing topical alignment analysis with flagged status
- *   and confidence score.
+ * Checks that the content stays within the defined business scope using
+ * an LLM to analyze text against a defined context.
  */
-export const topicalAlignmentCheck: CheckFn<
-  TopicalAlignmentContext,
-  string,
-  TopicalAlignmentConfig
-> = async (ctx, data, config): Promise<GuardrailResult> => {
-  try {
-    // Render the system prompt with business scope details
-    const renderedSystemPrompt = SYSTEM_PROMPT.replace(
-      '{system_prompt_details}',
-      config.system_prompt_details
-    );
-
-    // Use buildFullPrompt to ensure "json" is included for OpenAI's response_format requirement
-    const fullPrompt = buildFullPrompt(renderedSystemPrompt);
-
-    // Use the OpenAI API to analyze the text
-    const response = await ctx.guardrailLlm.chat.completions.create({
-      messages: [
-        { role: 'system', content: fullPrompt },
-        { role: 'user', content: data },
-      ],
-      model: config.model,
-      temperature: 0.0,
-      response_format: { type: 'json_object' },
-    });
-
-    const content = response.choices[0]?.message?.content;
-    if (!content) {
-      throw new Error('No response content from LLM');
-    }
-
-    // Parse the JSON response
-    const analysis: TopicalAlignmentOutput = JSON.parse(content);
-
-    // Determine if tripwire should be triggered
-    const isTrigger = analysis.flagged && analysis.confidence >= config.confidence_threshold;
-
-    return {
-      tripwireTriggered: isTrigger,
-      info: {
-        checked_text: data, // Alignment doesn't modify the text
-        guardrail_name: 'Off Topic Content',
-        ...analysis,
-        threshold: config.confidence_threshold,
-        business_scope: config.system_prompt_details,
-      },
-    };
-  } catch (error) {
-    // Log unexpected errors and return safe default
-    console.error('Unexpected error in topical alignment detection:', error);
-    return {
-      tripwireTriggered: false,
-      info: {
-        checked_text: data, // Return original text on error
-        guardrail_name: 'Off Topic Content',
-        flagged: false,
-        confidence: 0.0,
-        threshold: config.confidence_threshold,
-        business_scope: config.system_prompt_details,
-        error: String(error),
-      },
-    };
-  }
-};
-
-// Auto-register this guardrail with the default registry
-defaultSpecRegistry.register(
-  'Off Topic Prompts',
-  topicalAlignmentCheck,
-  'Checks that the content stays within the defined business scope',
-  'text/plain',
-  TopicalAlignmentConfig as z.ZodType,
-  TopicalAlignmentContext,
-  { engine: 'llm' }
-);
+export const topicalAlignment: CheckFn<TopicalAlignmentContext, string, TopicalAlignmentConfig> =
+  createLLMCheckFn(
+    'Off Topic Prompts',
+    'Checks that the content stays within the defined business scope',
+    SYSTEM_PROMPT,
+    TopicalAlignmentOutput,
+    TopicalAlignmentConfig as unknown as typeof LLMConfig
+  ) as CheckFn<TopicalAlignmentContext, string, TopicalAlignmentConfig>;
diff --git a/src/checks/user-defined-llm.ts b/src/checks/user-defined-llm.ts
index 54aed25..5ec0d8d 100644
--- a/src/checks/user-defined-llm.ts
+++ b/src/checks/user-defined-llm.ts
@@ -7,22 +7,17 @@
  */

 import { z } from 'zod';
-import { CheckFn, GuardrailResult } from '../types';
-import { defaultSpecRegistry } from '../registry';
-import { SAFETY_IDENTIFIER, supportsSafetyIdentifier } from '../utils/safety-identifier';
+import { CheckFn, GuardrailLLMContext } from '../types';
+import { LLMConfig, LLMOutput, createLLMCheckFn } from './llm-base';

 /**
  * Configuration schema for user-defined LLM moderation checks.
  *
  * Extends the base LLMConfig with a required field for custom prompt details.
  */
-export const UserDefinedConfig = z.object({
-  /** The LLM model to use for content checking */
-  model: z.string(),
-  /** Minimum confidence score (0.0 to 1.0) required to trigger the guardrail. Defaults to 0.7. */
-  confidence_threshold: z.number().min(0.0).max(1.0).default(0.7),
+export const UserDefinedConfig = LLMConfig.omit({ system_prompt_details: true }).extend({
   /** Free-form instructions describing content moderation requirements */
-  system_prompt_details: z.string(),
+  system_prompt_details: z.string().describe('Free-form instructions describing content moderation requirements'),
 });

 export type UserDefinedConfig = z.infer<typeof UserDefinedConfig>;
@@ -30,21 +25,12 @@ export type UserDefinedConfig = z.infer<typeof UserDefinedConfig>;
 /**
  * Context requirements for the user-defined LLM guardrail.
  */
-export const UserDefinedContext = z.object({
-  /** OpenAI client for LLM operations */
-  guardrailLlm: z.any(),
-});
-
-export type UserDefinedContext = z.infer<typeof UserDefinedContext>;
+export type UserDefinedContext = GuardrailLLMContext;

 /**
  * Output schema for user-defined LLM analysis.
  */
-export const UserDefinedOutput = z.object({
-  /** Whether the content was flagged according to the custom criteria */
-  flagged: z.boolean(),
-  /** Confidence score (0.0 to 1.0) that the input violates the custom criteria */
-  confidence: z.number().min(0.0).max(1.0),
+export const UserDefinedOutput = LLMOutput.extend({
   /** Optional reason for the flagging decision */
   reason: z.string().optional(),
 });
@@ -58,168 +44,19 @@ const SYSTEM_PROMPT = `# Instructions

 {system_prompt_details}

-Flag any content that matches the criteria defined above.
-
-Respond with a JSON object containing:
-- "flagged": boolean indicating if content violates the criteria
-- "confidence": number between 0.0 and 1.0 indicating confidence in the assessment`;
+Flag any content that matches the criteria defined above.`;

 /**
  * User-defined LLM guardrail.
  *
  * Runs a user-defined guardrail based on a custom system prompt.
  * Allows for flexible content moderation based on specific requirements.
- *
- * @param ctx Guardrail context containing the LLM client.
- * @param data Text to analyze according to custom criteria.
- * @param config Configuration with custom system prompt details.
- * @returns GuardrailResult containing analysis with flagged status
- *   and confidence score.
  */
-export const userDefinedLLMCheck: CheckFn<UserDefinedContext, string, UserDefinedConfig> = async (
-  ctx,
-  data,
-  config
-): Promise<GuardrailResult> => {
-  try {
-    // Render the system prompt with custom details
-    const renderedSystemPrompt = SYSTEM_PROMPT.replace(
-      '{system_prompt_details}',
-      config.system_prompt_details
-    );
-
-    // Use the OpenAI API to analyze the text
-    // Try with JSON response format first, fall back to text if not supported
-    let response;
-    try {
-      // Build API call parameters
-      const params: Record<string, unknown> = {
-        messages: [
-          { role: 'system', content: renderedSystemPrompt },
-          { role: 'user', content: data },
-        ],
-        model: config.model,
-        temperature: 0.0,
-        response_format: { type: 'json_object' },
-      };
-
-      // Only include safety_identifier for official OpenAI API (not Azure or local providers)
-      if (supportsSafetyIdentifier(ctx.guardrailLlm)) {
-        // @ts-ignore - safety_identifier is not defined in OpenAI types yet
-        params.safety_identifier = SAFETY_IDENTIFIER;
-      }
-
-      response = await ctx.guardrailLlm.chat.completions.create(params);
-    } catch (error: unknown) {
-      // If JSON response format is not supported, try without it
-      if (error && typeof error === 'object' && 'error' in error &&
-          (error as { error?: { param?: string } }).error?.param === 'response_format') {
-        // Build fallback parameters without response_format
-        const fallbackParams: Record<string, unknown> = {
-          messages: [
-            { role: 'system', content: renderedSystemPrompt },
-            { role: 'user', content: data },
-          ],
-          model: config.model,
-          temperature: 0.0,
-        };
-
-        // Only include safety_identifier for official OpenAI API (not Azure or local providers)
-        if (supportsSafetyIdentifier(ctx.guardrailLlm)) {
-          // @ts-ignore - safety_identifier is not defined in OpenAI types yet
-          fallbackParams.safety_identifier = SAFETY_IDENTIFIER;
-        }
-
-        response = await ctx.guardrailLlm.chat.completions.create(fallbackParams);
-      } else {
-        // Return error information instead of re-throwing
-        return {
-          tripwireTriggered: false,
-          executionFailed: true,
-          originalException: error instanceof Error ? error : new Error(String(error)),
-          info: {
-            checked_text: data,
-            error_message: String(error),
-            flagged: false,
-            confidence: 0.0,
-          },
-        };
-      }
-    }
-
-    const content = response.choices[0]?.message?.content;
-    if (!content) {
-      return {
-        tripwireTriggered: false,
-        executionFailed: true,
-        originalException: new Error('No response content from LLM'),
-        info: {
-          checked_text: data,
-          error_message: 'No response content from LLM',
-          flagged: false,
-          confidence: 0.0,
-        },
-      };
-    }
-
-    // Parse the response - try JSON first, fall back to text parsing
-    let analysis: UserDefinedOutput;
-    try {
-      analysis = JSON.parse(content);
-    } catch {
-      // If JSON parsing fails, try to extract information from text response
-      // Look for patterns like "flagged: true/false" and "confidence: 0.8"
-      const flaggedMatch = content.match(/flagged:\s*(true|false)/i);
-      const confidenceMatch = content.match(/confidence:\s*([0-9.]+)/i);
-      const reasonMatch = content.match(/reason:\s*"([^"]+)"/i);
-
-      analysis = {
-        flagged: flaggedMatch ? flaggedMatch[1].toLowerCase() === 'true' : false,
-        confidence: confidenceMatch ? parseFloat(confidenceMatch[1]) : 0.0,
-        reason: reasonMatch ? reasonMatch[1] : 'Could not parse response format',
-      };
-    }
-
-    // Determine if tripwire should be triggered
-    const isTrigger = analysis.flagged && analysis.confidence >= config.confidence_threshold;
-
-    return {
-      tripwireTriggered: isTrigger,
-      info: {
-        checked_text: data, // Custom check doesn't modify the text
-        guardrail_name: 'Custom Prompt Check',
-        ...analysis,
-        threshold: config.confidence_threshold,
-        custom_prompt: config.system_prompt_details,
-      },
-    };
-  } catch (error) {
-    // Log unexpected errors and return safe default
-    console.error('Unexpected error in user-defined LLM check:', error);
-    return {
-      tripwireTriggered: false,
-      executionFailed: true,
-      originalException: error instanceof Error ? error : new Error(String(error)),
-      info: {
-        checked_text: data, // Return original text on error
-        guardrail_name: 'Custom Prompt Check',
-        flagged: false,
-        confidence: 0.0,
-        threshold: config.confidence_threshold,
-        custom_prompt: config.system_prompt_details,
-        error: String(error),
-      },
-    };
-  }
-};
-
-// Auto-register this guardrail with the default registry
-defaultSpecRegistry.register(
-  'Custom Prompt Check',
-  userDefinedLLMCheck,
-  'User-defined LLM guardrail for custom content moderation',
-  'text/plain',
-  UserDefinedConfig as z.ZodType,
-  UserDefinedContext,
-  { engine: 'llm' }
-);
+export const userDefinedLLM: CheckFn<UserDefinedContext, string, UserDefinedConfig> =
+  createLLMCheckFn(
+    'Custom Prompt Check',
+    'User-defined LLM guardrail for custom content moderation',
+    SYSTEM_PROMPT,
+    UserDefinedOutput,
+    UserDefinedConfig as unknown as typeof LLMConfig
+  ) as CheckFn<UserDefinedContext, string, UserDefinedConfig>;