diff --git a/src/__tests__/unit/chat-resources.test.ts b/src/__tests__/unit/chat-resources.test.ts
index 594d0a7..f92a7ba 100644
--- a/src/__tests__/unit/chat-resources.test.ts
+++ b/src/__tests__/unit/chat-resources.test.ts
@@ -95,7 +95,7 @@ describe('Chat resource', () => {
       messages,
       model: 'gpt-4',
       stream: false,
-      safety_identifier: 'oai-guardrails-ts',
+      safety_identifier: 'openai-guardrails-js',
     });
     expect(client.handleLlmResponse).toHaveBeenCalledWith(
       { id: 'chat-response' },
@@ -156,7 +156,7 @@ describe('Responses resource', () => {
       model: 'gpt-4o',
       stream: false,
       tools: undefined,
-      safety_identifier: 'oai-guardrails-ts',
+      safety_identifier: 'openai-guardrails-js',
     });
     expect(client.handleLlmResponse).toHaveBeenCalledWith(
       { id: 'responses-api' },
diff --git a/src/__tests__/unit/checks/moderation-secret-keys.test.ts b/src/__tests__/unit/checks/moderation-secret-keys.test.ts
index 7808312..de4d5a4 100644
--- a/src/__tests__/unit/checks/moderation-secret-keys.test.ts
+++ b/src/__tests__/unit/checks/moderation-secret-keys.test.ts
@@ -48,6 +48,7 @@ describe('moderation guardrail', () => {
     expect(createMock).toHaveBeenCalledWith({
       model: 'omni-moderation-latest',
       input: 'bad content',
+      safety_identifier: 'openai-guardrails-js',
     });
     expect(result.tripwireTriggered).toBe(true);
     expect(result.info?.flagged_categories).toEqual([Category.HATE]);
diff --git a/src/__tests__/unit/checks/user-defined-llm.test.ts b/src/__tests__/unit/checks/user-defined-llm.test.ts
index fd9291f..0b367e1 100644
--- a/src/__tests__/unit/checks/user-defined-llm.test.ts
+++ b/src/__tests__/unit/checks/user-defined-llm.test.ts
@@ -52,6 +52,7 @@ describe('userDefinedLLMCheck', () => {
       model: 'gpt-test',
       temperature: 0.0,
       response_format: { type: 'json_object' },
+      safety_identifier: 'openai-guardrails-js',
     });
     expect(result.tripwireTriggered).toBe(true);
     expect(result.info?.flagged).toBe(true);
diff --git a/src/__tests__/unit/utils/safety-identifier.test.ts b/src/__tests__/unit/utils/safety-identifier.test.ts
new file mode 100644
index 0000000..c42e6b3
--- /dev/null
+++ b/src/__tests__/unit/utils/safety-identifier.test.ts
@@ -0,0 +1,128 @@
+/**
+ * Unit tests for safety identifier utilities.
+ *
+ * These tests verify the detection logic for determining whether a client
+ * supports the safety_identifier parameter in OpenAI API calls.
+ */
+
+import { describe, it, expect } from 'vitest';
+import { supportsSafetyIdentifier, SAFETY_IDENTIFIER } from '../../../utils/safety-identifier';
+
+describe('Safety Identifier utilities', () => {
+  describe('SAFETY_IDENTIFIER constant', () => {
+    it('should have the correct value', () => {
+      expect(SAFETY_IDENTIFIER).toBe('openai-guardrails-js');
+    });
+  });
+
+  describe('supportsSafetyIdentifier', () => {
+    it('should return true for official OpenAI client with default baseURL', () => {
+      // Mock an official OpenAI client (no custom baseURL)
+      const mockClient = {
+        constructor: { name: 'OpenAI' },
+        baseURL: undefined,
+      };
+
+      expect(supportsSafetyIdentifier(mockClient)).toBe(true);
+    });
+
+    it('should return true for OpenAI client with explicit api.openai.com baseURL', () => {
+      const mockClient = {
+        constructor: { name: 'OpenAI' },
+        baseURL: 'https://api.openai.com/v1',
+      };
+
+      expect(supportsSafetyIdentifier(mockClient)).toBe(true);
+    });
+
+    it('should return false for Azure OpenAI client', () => {
+      const mockClient = {
+        constructor: { name: 'AzureOpenAI' },
+        baseURL: 'https://example.openai.azure.com/v1',
+      };
+
+      expect(supportsSafetyIdentifier(mockClient)).toBe(false);
+    });
+
+    it('should return false for AsyncAzureOpenAI client', () => {
+      const mockClient = {
+        constructor: { name: 'AsyncAzureOpenAI' },
+        baseURL: 'https://example.openai.azure.com/v1',
+      };
+
+      expect(supportsSafetyIdentifier(mockClient)).toBe(false);
+    });
+
+    it('should return false for local model with custom baseURL (Ollama)', () => {
+      const mockClient = {
+        constructor: { name: 'OpenAI' },
+        baseURL: 'http://localhost:11434/v1',
+      };
+
+      expect(supportsSafetyIdentifier(mockClient)).toBe(false);
+    });
+
+    it('should return false for alternative OpenAI-compatible provider', () => {
+      const mockClient = {
+        constructor: { name: 'OpenAI' },
+        baseURL: 'https://api.together.xyz/v1',
+      };
+
+      expect(supportsSafetyIdentifier(mockClient)).toBe(false);
+    });
+
+    it('should return false for vLLM server', () => {
+      const mockClient = {
+        constructor: { name: 'OpenAI' },
+        baseURL: 'http://localhost:8000/v1',
+      };
+
+      expect(supportsSafetyIdentifier(mockClient)).toBe(false);
+    });
+
+    it('should return false for null client', () => {
+      expect(supportsSafetyIdentifier(null)).toBe(false);
+    });
+
+    it('should return false for undefined client', () => {
+      expect(supportsSafetyIdentifier(undefined)).toBe(false);
+    });
+
+    it('should return false for non-object client', () => {
+      expect(supportsSafetyIdentifier('not an object')).toBe(false);
+      expect(supportsSafetyIdentifier(123)).toBe(false);
+    });
+
+    it('should check _client.baseURL if baseURL is not directly accessible', () => {
+      const mockClient = {
+        constructor: { name: 'OpenAI' },
+        _client: {
+          baseURL: 'http://localhost:11434/v1',
+        },
+      };
+
+      expect(supportsSafetyIdentifier(mockClient)).toBe(false);
+    });
+
+    it('should check _baseURL if baseURL and _client.baseURL are not accessible', () => {
+      const mockClient = {
+        constructor: { name: 'OpenAI' },
+        _baseURL: 'http://localhost:11434/v1',
+      };
+
+      expect(supportsSafetyIdentifier(mockClient)).toBe(false);
+    });
+
+    it('should return true when api.openai.com is found via _client.baseURL', () => {
+      const mockClient = {
+        constructor: { name: 'OpenAI' },
+        _client: {
+          baseURL: 'https://api.openai.com/v1',
+        },
+      };
+
+      expect(supportsSafetyIdentifier(mockClient)).toBe(true);
+    });
+  });
+});
+
diff --git a/src/checks/llm-base.ts b/src/checks/llm-base.ts
index 16fb8d3..9d1e2ef 100644
--- a/src/checks/llm-base.ts
+++ b/src/checks/llm-base.ts
@@ -11,6 +11,7 @@ import { z } from 'zod';
 import { OpenAI } from 'openai';
 import { CheckFn, GuardrailResult, GuardrailLLMContext } from '../types';
 import { defaultSpecRegistry } from '../registry';
+import { SAFETY_IDENTIFIER, supportsSafetyIdentifier } from '../utils/safety-identifier';
 
 /**
  * Configuration schema for LLM-based content checks.
@@ -195,7 +196,8 @@ export async function runLLM(
     temperature = 1.0;
   }
 
-  const response = await client.chat.completions.create({
+  // Build API call parameters
+  const params: Record<string, unknown> = {
     messages: [
       { role: 'system', content: fullPrompt },
       { role: 'user', content: `# Text\n\n${text}` },
@@ -203,7 +205,16 @@
     model: model,
     temperature: temperature,
     response_format: { type: 'json_object' },
-  });
+  };
+
+  // Only include safety_identifier for official OpenAI API (not Azure or local providers)
+  if (supportsSafetyIdentifier(client)) {
+    // @ts-ignore - safety_identifier is not defined in OpenAI types yet
+    params.safety_identifier = SAFETY_IDENTIFIER;
+  }
+
+  // @ts-ignore - safety_identifier is not in the OpenAI types yet
+  const response = await client.chat.completions.create(params);
 
   const result = response.choices[0]?.message?.content;
   if (!result) {
diff --git a/src/checks/moderation.ts b/src/checks/moderation.ts
index 0357b80..47c224d 100644
--- a/src/checks/moderation.ts
+++ b/src/checks/moderation.ts
@@ -22,6 +22,7 @@ import { z } from 'zod';
 import { CheckFn, GuardrailResult } from '../types';
 import { defaultSpecRegistry } from '../registry';
 import OpenAI from 'openai';
+import { SAFETY_IDENTIFIER, supportsSafetyIdentifier } from '../utils/safety-identifier';
 
 /**
  * Enumeration of supported moderation categories.
@@ -78,6 +79,42 @@ export const ModerationContext = z.object({
 
 export type ModerationContext = z.infer<typeof ModerationContext>;
 
+/**
+ * Check if an error is a 404 Not Found error from the OpenAI API.
+ *
+ * @param error The error to check
+ * @returns True if the error is a 404 error
+ */
+function isNotFoundError(error: unknown): boolean {
+  return !!(error && typeof error === 'object' && 'status' in error && error.status === 404);
+}
+
+/**
+ * Call the OpenAI moderation API.
+ *
+ * @param client The OpenAI client to use
+ * @param data The text to analyze
+ * @returns The moderation API response
+ */
+function callModerationAPI(
+  client: OpenAI,
+  data: string
+): ReturnType<typeof client.moderations.create> {
+  const params: Record<string, unknown> = {
+    model: 'omni-moderation-latest',
+    input: data,
+  };
+
+  // Only include safety_identifier for official OpenAI API (not Azure or local providers)
+  if (supportsSafetyIdentifier(client)) {
+    // @ts-ignore - safety_identifier is not defined in OpenAI types yet
+    params.safety_identifier = SAFETY_IDENTIFIER;
+  }
+
+  // @ts-ignore - safety_identifier is not in the OpenAI types yet
+  return client.moderations.create(params);
+}
+
 /**
  * Guardrail check_fn to flag disallowed content categories using OpenAI moderation API.
  *
@@ -102,39 +139,55 @@ export const moderationCheck: CheckFn
   const configObj = config as Record<string, unknown>;
   const categories = (configObj.categories as string[]) || Object.values(Category);
 
-  // Reuse provided client only if it targets the official OpenAI API.
-  const reuseClientIfOpenAI = (context: unknown): OpenAI | null => {
-    try {
-      const contextObj = context as Record<string, unknown>;
-      const candidate = contextObj?.guardrailLlm;
-      if (!candidate || typeof candidate !== 'object') return null;
-      if (!(candidate instanceof OpenAI)) return null;
-
-      const candidateObj = candidate as unknown as Record<string, unknown>;
-      const baseURL: string | undefined =
-        (candidateObj.baseURL as string) ??
-        ((candidateObj._client as Record<string, unknown>)?.baseURL as string) ??
-        (candidateObj._baseURL as string);
-
-      if (
-        baseURL === undefined ||
-        (typeof baseURL === 'string' && baseURL.includes('api.openai.com'))
-      ) {
-        return candidate as OpenAI;
-      }
-      return null;
-    } catch {
-      return null;
-    }
-  };
-
-  const client = reuseClientIfOpenAI(ctx) ?? new OpenAI();
+  // Get client from context if available
+  let client: OpenAI | null = null;
+  if (ctx) {
+    const contextObj = ctx as Record<string, unknown>;
+    const candidate = contextObj.guardrailLlm;
+    if (candidate && candidate instanceof OpenAI) {
+      client = candidate;
+    }
+  }
 
   try {
-    const resp = await client.moderations.create({
-      model: 'omni-moderation-latest',
-      input: data,
-    });
+    // Try the context client first, fall back if moderation endpoint doesn't exist
+    let resp: Awaited<ReturnType<typeof callModerationAPI>>;
+    if (client !== null) {
+      try {
+        resp = await callModerationAPI(client, data);
+      } catch (error) {
+
+        // Moderation endpoint doesn't exist on this provider (e.g., third-party)
+        // Fall back to the OpenAI client
+        if (isNotFoundError(error)) {
+          try {
+            resp = await callModerationAPI(new OpenAI(), data);
+          } catch (fallbackError) {
+            // If fallback fails, provide a helpful error message
+            const errorMessage = fallbackError instanceof Error
+              ? fallbackError.message
+              : String(fallbackError);
+
+            // Check if it's an API key error
+            if (errorMessage.includes('api_key') || errorMessage.includes('OPENAI_API_KEY')) {
+              return {
+                tripwireTriggered: false,
+                info: {
+                  checked_text: data,
+                  error: 'Moderation API requires OpenAI API key. Set OPENAI_API_KEY environment variable or pass a client with valid credentials.',
+                },
+              };
+            }
+            throw fallbackError;
+          }
+        } else {
+          throw error;
+        }
+      }
+    } else {
+      // No context client, use fallback
+      resp = await callModerationAPI(new OpenAI(), data);
+    }
 
     const results = resp.results || [];
     if (!results.length) {
diff --git a/src/checks/user-defined-llm.ts b/src/checks/user-defined-llm.ts
index c1f96b2..54aed25 100644
--- a/src/checks/user-defined-llm.ts
+++ b/src/checks/user-defined-llm.ts
@@ -9,6 +9,7 @@
 import { z } from 'zod';
 import { CheckFn, GuardrailResult } from '../types';
 import { defaultSpecRegistry } from '../registry';
+import { SAFETY_IDENTIFIER, supportsSafetyIdentifier } from '../utils/safety-identifier';
 
 /**
  * Configuration schema for user-defined LLM moderation checks.
@@ -91,7 +92,8 @@ export const userDefinedLLMCheck: CheckFn
 
   let response;
   try {
-    response = await ctx.guardrailLlm.chat.completions.create({
+    // Build API call parameters
+    const params: Record<string, unknown> = {
       messages: [
         { role: 'system', content: renderedSystemPrompt },
         { role: 'user', content: data },
@@ -99,19 +101,36 @@ export const userDefinedLLMCheck: CheckFn
-      response = await ctx.guardrailLlm.chat.completions.create({
+      const fallbackParams: Record<string, unknown> = {
         messages: [
           { role: 'system', content: renderedSystemPrompt },
           { role: 'user', content: data },
         ],
         model: config.model,
         temperature: 0.0,
-      });
+      };
+
+      // Only include safety_identifier for official OpenAI API (not Azure or local providers)
+      if (supportsSafetyIdentifier(ctx.guardrailLlm)) {
+        // @ts-ignore - safety_identifier is not defined in OpenAI types yet
+        fallbackParams.safety_identifier = SAFETY_IDENTIFIER;
+      }
+
+      response = await ctx.guardrailLlm.chat.completions.create(fallbackParams);
     } else {
       // Return error information instead of re-throwing
       return {
diff --git a/src/resources/chat/chat.ts b/src/resources/chat/chat.ts
index ee0f6f2..8ddc5d7 100644
--- a/src/resources/chat/chat.ts
+++ b/src/resources/chat/chat.ts
@@ -6,6 +6,7 @@
 import { OpenAI } from 'openai';
 import { GuardrailsBaseClient, GuardrailsResponse } from '../../base-client';
 import { Message } from '../../types';
+import { SAFETY_IDENTIFIER, supportsSafetyIdentifier } from '../../utils/safety-identifier';
 
 // Note: We need to filter out non-text content since guardrails only work with text
 // The existing extractLatestUserTextMessage method expects TextOnlyMessageArray
@@ -82,6 +83,24 @@ export class ChatCompletions {
     );
 
     // Run input guardrails and LLM call concurrently
+    // Access protected _resourceClient - necessary for external resource classes
+    // eslint-disable-next-line @typescript-eslint/no-explicit-any
+    const resourceClient = (this.client as any)._resourceClient;
+
+    // Build API call parameters
+    const apiParams: Record<string, unknown> = {
+      messages: modifiedMessages,
+      model,
+      stream,
+      ...kwargs,
+    };
+
+    // Only include safety_identifier for official OpenAI API (not Azure or local providers)
+    if (supportsSafetyIdentifier(resourceClient)) {
+      // @ts-ignore - safety_identifier is not defined in OpenAI types yet
+      apiParams.safety_identifier = SAFETY_IDENTIFIER;
+    }
+
     const [inputResults, llmResponse] = await Promise.all([
       this.client.runStageGuardrails(
         'input',
@@ -90,16 +109,7 @@
         suppressTripwire,
         this.client.raiseGuardrailErrors
       ),
-      // Access protected _resourceClient - necessary for external resource classes
-      // eslint-disable-next-line @typescript-eslint/no-explicit-any
-      (this.client as any)._resourceClient.chat.completions.create({
-        messages: modifiedMessages,
-        model,
-        stream,
-        ...kwargs,
-        // @ts-ignore - safety_identifier is not defined in OpenAI types yet
-        safety_identifier: 'oai-guardrails-ts',
-      }),
+      resourceClient.chat.completions.create(apiParams),
     ]);
 
     // Handle streaming vs non-streaming
diff --git a/src/resources/responses/responses.ts b/src/resources/responses/responses.ts
index 1bb65af..1fcd321 100644
--- a/src/resources/responses/responses.ts
+++ b/src/resources/responses/responses.ts
@@ -6,6 +6,7 @@ import { OpenAI } from 'openai';
 import { GuardrailsBaseClient, GuardrailsResponse } from '../../base-client';
 import { Message } from '../../types';
 import { mergeConversationWithItems } from '../../utils/conversation';
+import { SAFETY_IDENTIFIER, supportsSafetyIdentifier } from '../../utils/safety-identifier';
 
 /**
  * Responses API with guardrails.
@@ -85,6 +86,24 @@ export class Responses {
     );
 
     // Input guardrails and LLM call concurrently
+    // eslint-disable-next-line @typescript-eslint/no-explicit-any
+    const resourceClient = (this.client as any)._resourceClient;
+
+    // Build API call parameters
+    const apiParams: Record<string, unknown> = {
+      input: modifiedInput,
+      model,
+      stream,
+      tools,
+      ...kwargs,
+    };
+
+    // Only include safety_identifier for official OpenAI API (not Azure or local providers)
+    if (supportsSafetyIdentifier(resourceClient)) {
+      // @ts-ignore - safety_identifier is not defined in OpenAI types yet
+      apiParams.safety_identifier = SAFETY_IDENTIFIER;
+    }
+
     const [inputResults, llmResponse] = await Promise.all([
       this.client.runStageGuardrails(
         'input',
@@ -93,16 +112,7 @@
         suppressTripwire,
         this.client.raiseGuardrailErrors
       ),
-      // eslint-disable-next-line @typescript-eslint/no-explicit-any
-      (this.client as any)._resourceClient.responses.create({
-        input: modifiedInput,
-        model,
-        stream,
-        tools,
-        ...kwargs,
-        // @ts-ignore - safety_identifier is not defined in OpenAI types yet
-        safety_identifier: 'oai-guardrails-ts',
-      }),
+      resourceClient.responses.create(apiParams),
     ]);
 
     // Handle streaming vs non-streaming
diff --git a/src/utils/index.ts b/src/utils/index.ts
index ed7274e..80a1b25 100644
--- a/src/utils/index.ts
+++ b/src/utils/index.ts
@@ -50,3 +50,6 @@ export {
 
 // OpenAI vector store utilities
 export { createOpenAIVectorStoreFromPath, OpenAIVectorStoreConfig } from './openai-vector-store';
+
+// Safety identifier utilities
+export { SAFETY_IDENTIFIER, supportsSafetyIdentifier } from './safety-identifier';
diff --git a/src/utils/safety-identifier.ts b/src/utils/safety-identifier.ts
new file mode 100644
index 0000000..57604a3
--- /dev/null
+++ b/src/utils/safety-identifier.ts
@@ -0,0 +1,67 @@
+/**
+ * OpenAI safety identifier utilities.
+ *
+ * This module provides utilities for handling the OpenAI safety_identifier parameter,
+ * which is used to track guardrails library usage for monitoring and abuse detection.
+ *
+ * The safety identifier is only supported by the official OpenAI API and should not
+ * be sent to Azure OpenAI or other OpenAI-compatible providers.
+ */
+
+import OpenAI from 'openai';
+
+/**
+ * OpenAI safety identifier for tracking guardrails library usage.
+ */
+export const SAFETY_IDENTIFIER = 'openai-guardrails-js';
+
+/**
+ * Check if the client supports the safety_identifier parameter.
+ *
+ * Only the official OpenAI API supports this parameter.
+ * Azure OpenAI and local/alternative providers (Ollama, vLLM, etc.) do not.
+ *
+ * @param client The OpenAI client instance to check
+ * @returns true if safety_identifier should be included in API calls, false otherwise
+ *
+ * @example
+ * ```typescript
+ * import OpenAI from 'openai';
+ * import { supportsSafetyIdentifier } from './safety-identifier';
+ *
+ * const client = new OpenAI();
+ * console.log(supportsSafetyIdentifier(client)); // true
+ *
+ * const localClient = new OpenAI({ baseURL: 'http://localhost:11434' });
+ * console.log(supportsSafetyIdentifier(localClient)); // false
+ * ```
+ */
+export function supportsSafetyIdentifier(client: OpenAI | unknown): boolean {
+  if (!client || typeof client !== 'object') {
+    return false;
+  }
+
+  // Check if this is an Azure OpenAI client by checking the constructor name
+  const constructorName = client.constructor?.name;
+  if (constructorName === 'AzureOpenAI' || constructorName === 'AsyncAzureOpenAI') {
+    return false;
+  }
+
+  // Check if using a custom baseURL (local or alternative provider)
+  // Try multiple ways to access baseURL as the internal structure may vary
+  const clientObj = client as Record<string, unknown>;
+  const baseURL: string | undefined =
+    (clientObj.baseURL as string) ??
+    ((clientObj._client as Record<string, unknown>)?.baseURL as string) ??
+    (clientObj._baseURL as string);
+
+  if (baseURL !== undefined && baseURL !== null) {
+    const baseURLStr = String(baseURL);
+    // Only official OpenAI API endpoints support safety_identifier
+    return baseURLStr.includes('api.openai.com');
+  }
+
+  // Default OpenAI client (no custom baseURL) supports it
+  return true;
+}
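
Reviewer note (not part of the diff): the same conditional-attachment pattern used throughout this change can be applied by downstream callers. A minimal sketch follows; `buildChatParams` is a hypothetical helper name, and the untyped params object mirrors the diff's workaround for safety_identifier not yet being in the published OpenAI types.

```typescript
import OpenAI from 'openai';
import { SAFETY_IDENTIFIER, supportsSafetyIdentifier } from './utils/safety-identifier';

// Hypothetical helper: build chat-completion params, attaching the safety
// identifier only when the client targets the official OpenAI API
// (not Azure OpenAI, Ollama, vLLM, or other compatible providers).
function buildChatParams(client: OpenAI, model: string, userText: string): Record<string, unknown> {
  const params: Record<string, unknown> = {
    model,
    messages: [{ role: 'user', content: userText }],
  };
  if (supportsSafetyIdentifier(client)) {
    // safety_identifier is not yet part of the published OpenAI types
    params.safety_identifier = SAFETY_IDENTIFIER;
  }
  return params;
}

// Usage: the official client gets the identifier; a local client does not.
const official = new OpenAI({ apiKey: process.env.OPENAI_API_KEY });
const local = new OpenAI({ baseURL: 'http://localhost:11434/v1', apiKey: 'ollama' });
console.log('safety_identifier' in buildChatParams(official, 'gpt-4o', 'hi')); // true
console.log('safety_identifier' in buildChatParams(local, 'llama3', 'hi')); // false
```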