4 changes: 2 additions & 2 deletions src/__tests__/unit/chat-resources.test.ts
@@ -95,7 +95,7 @@ describe('Chat resource', () => {
messages,
model: 'gpt-4',
stream: false,
- safety_identifier: 'oai-guardrails-ts',
+ safety_identifier: 'openai-guardrails-js',
});
expect(client.handleLlmResponse).toHaveBeenCalledWith(
{ id: 'chat-response' },
@@ -156,7 +156,7 @@ describe('Responses resource', () => {
model: 'gpt-4o',
stream: false,
tools: undefined,
- safety_identifier: 'oai-guardrails-ts',
+ safety_identifier: 'openai-guardrails-js',
});
expect(client.handleLlmResponse).toHaveBeenCalledWith(
{ id: 'responses-api' },
1 change: 1 addition & 0 deletions src/__tests__/unit/checks/moderation-secret-keys.test.ts
@@ -48,6 +48,7 @@ describe('moderation guardrail', () => {
expect(createMock).toHaveBeenCalledWith({
model: 'omni-moderation-latest',
input: 'bad content',
+ safety_identifier: 'openai-guardrails-js',
});
expect(result.tripwireTriggered).toBe(true);
expect(result.info?.flagged_categories).toEqual([Category.HATE]);
1 change: 1 addition & 0 deletions src/__tests__/unit/checks/user-defined-llm.test.ts
@@ -52,6 +52,7 @@ describe('userDefinedLLMCheck', () => {
model: 'gpt-test',
temperature: 0.0,
response_format: { type: 'json_object' },
+ safety_identifier: 'openai-guardrails-js',
});
expect(result.tripwireTriggered).toBe(true);
expect(result.info?.flagged).toBe(true);
128 changes: 128 additions & 0 deletions src/__tests__/unit/utils/safety-identifier.test.ts
@@ -0,0 +1,128 @@
/**
* Unit tests for safety identifier utilities.
*
* These tests verify the detection logic for determining whether a client
* supports the safety_identifier parameter in OpenAI API calls.
*/

import { describe, it, expect } from 'vitest';
import { supportsSafetyIdentifier, SAFETY_IDENTIFIER } from '../../../utils/safety-identifier';

describe('Safety Identifier utilities', () => {
describe('SAFETY_IDENTIFIER constant', () => {
it('should have the correct value', () => {
expect(SAFETY_IDENTIFIER).toBe('openai-guardrails-js');
});
});

describe('supportsSafetyIdentifier', () => {
it('should return true for official OpenAI client with default baseURL', () => {
// Mock an official OpenAI client (no custom baseURL)
const mockClient = {
constructor: { name: 'OpenAI' },
baseURL: undefined,
};

expect(supportsSafetyIdentifier(mockClient)).toBe(true);
});

it('should return true for OpenAI client with explicit api.openai.com baseURL', () => {
const mockClient = {
constructor: { name: 'OpenAI' },
baseURL: 'https://api.openai.com/v1',
};

expect(supportsSafetyIdentifier(mockClient)).toBe(true);
});

it('should return false for Azure OpenAI client', () => {
const mockClient = {
constructor: { name: 'AzureOpenAI' },
baseURL: 'https://example.openai.azure.com/v1',
};

expect(supportsSafetyIdentifier(mockClient)).toBe(false);
});

it('should return false for AsyncAzureOpenAI client', () => {
const mockClient = {
constructor: { name: 'AsyncAzureOpenAI' },
baseURL: 'https://example.openai.azure.com/v1',
};

expect(supportsSafetyIdentifier(mockClient)).toBe(false);
});

it('should return false for local model with custom baseURL (Ollama)', () => {
const mockClient = {
constructor: { name: 'OpenAI' },
baseURL: 'http://localhost:11434/v1',
};

expect(supportsSafetyIdentifier(mockClient)).toBe(false);
});

it('should return false for alternative OpenAI-compatible provider', () => {
const mockClient = {
constructor: { name: 'OpenAI' },
baseURL: 'https://api.together.xyz/v1',
};

expect(supportsSafetyIdentifier(mockClient)).toBe(false);
});

it('should return false for vLLM server', () => {
const mockClient = {
constructor: { name: 'OpenAI' },
baseURL: 'http://localhost:8000/v1',
};

expect(supportsSafetyIdentifier(mockClient)).toBe(false);
});

it('should return false for null client', () => {
expect(supportsSafetyIdentifier(null)).toBe(false);
});

it('should return false for undefined client', () => {
expect(supportsSafetyIdentifier(undefined)).toBe(false);
});

it('should return false for non-object client', () => {
expect(supportsSafetyIdentifier('not an object')).toBe(false);
expect(supportsSafetyIdentifier(123)).toBe(false);
});

it('should check _client.baseURL if baseURL is not directly accessible', () => {
const mockClient = {
constructor: { name: 'OpenAI' },
_client: {
baseURL: 'http://localhost:11434/v1',
},
};

expect(supportsSafetyIdentifier(mockClient)).toBe(false);
});

it('should check _baseURL if baseURL and _client.baseURL are not accessible', () => {
const mockClient = {
constructor: { name: 'OpenAI' },
_baseURL: 'http://localhost:11434/v1',
};

expect(supportsSafetyIdentifier(mockClient)).toBe(false);
});

it('should return true when api.openai.com is found via _client.baseURL', () => {
const mockClient = {
constructor: { name: 'OpenAI' },
_client: {
baseURL: 'https://api.openai.com/v1',
},
};

expect(supportsSafetyIdentifier(mockClient)).toBe(true);
});
});
});

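The module under test, src/utils/safety-identifier.ts, is not expanded in this diff. A minimal sketch consistent with the behavior these tests pin down might look like the following; any detail beyond what the tests exercise (exact property casts, error handling) is an assumption, not the PR's actual source.

// Sketch only: inferred from the tests above, not the module's actual source.
export const SAFETY_IDENTIFIER = 'openai-guardrails-js';

export function supportsSafetyIdentifier(client: unknown): boolean {
  // Reject null, undefined, and non-object values outright.
  if (!client || typeof client !== 'object') return false;

  const obj = client as Record<string, unknown>;

  // Azure clients never accept safety_identifier.
  const ctorName = (obj.constructor as { name?: string } | undefined)?.name;
  if (ctorName === 'AzureOpenAI' || ctorName === 'AsyncAzureOpenAI') return false;

  // Resolve the base URL from the locations the tests exercise:
  // baseURL, then _client.baseURL, then _baseURL.
  const baseURL =
    (obj.baseURL as string | undefined) ??
    ((obj._client as Record<string, unknown> | undefined)?.baseURL as string | undefined) ??
    (obj._baseURL as string | undefined);

  // An unset baseURL means the default endpoint (api.openai.com);
  // any custom host (Azure, Ollama, vLLM, Together, ...) opts out.
  return baseURL === undefined || baseURL.includes('api.openai.com');
}
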
15 changes: 13 additions & 2 deletions src/checks/llm-base.ts
@@ -11,6 +11,7 @@ import { z } from 'zod';
import { OpenAI } from 'openai';
import { CheckFn, GuardrailResult, GuardrailLLMContext } from '../types';
import { defaultSpecRegistry } from '../registry';
+ import { SAFETY_IDENTIFIER, supportsSafetyIdentifier } from '../utils/safety-identifier';

/**
* Configuration schema for LLM-based content checks.
@@ -195,15 +196,25 @@ export async function runLLM(
temperature = 1.0;
}

- const response = await client.chat.completions.create({
+ // Build API call parameters
+ const params: Record<string, unknown> = {
messages: [
{ role: 'system', content: fullPrompt },
{ role: 'user', content: `# Text\n\n${text}` },
],
model: model,
temperature: temperature,
response_format: { type: 'json_object' },
- });
+ };
+
+ // Only include safety_identifier for official OpenAI API (not Azure or local providers)
+ if (supportsSafetyIdentifier(client)) {
+ // @ts-ignore - safety_identifier is not defined in OpenAI types yet
+ params.safety_identifier = SAFETY_IDENTIFIER;
+ }
+
+ // @ts-ignore - safety_identifier is not in the OpenAI types yet
+ const response = await client.chat.completions.create(params);

const result = response.choices[0]?.message?.content;
if (!result) {
113 changes: 83 additions & 30 deletions src/checks/moderation.ts
@@ -22,6 +22,7 @@ import { z } from 'zod';
import { CheckFn, GuardrailResult } from '../types';
import { defaultSpecRegistry } from '../registry';
import OpenAI from 'openai';
+ import { SAFETY_IDENTIFIER, supportsSafetyIdentifier } from '../utils/safety-identifier';

/**
* Enumeration of supported moderation categories.
@@ -78,6 +79,42 @@ export const ModerationContext = z.object({

export type ModerationContext = z.infer<typeof ModerationContext>;

+ /**
+ * Check if an error is a 404 Not Found error from the OpenAI API.
+ *
+ * @param error The error to check
+ * @returns True if the error is a 404 error
+ */
+ function isNotFoundError(error: unknown): boolean {
+ return !!(error && typeof error === 'object' && 'status' in error && error.status === 404);
+ }
+
+ /**
+ * Call the OpenAI moderation API.
+ *
+ * @param client The OpenAI client to use
+ * @param data The text to analyze
+ * @returns The moderation API response
+ */
+ function callModerationAPI(
+ client: OpenAI,
+ data: string
+ ): ReturnType<OpenAI['moderations']['create']> {
+ const params: Record<string, unknown> = {
+ model: 'omni-moderation-latest',
+ input: data,
+ };
+
+ // Only include safety_identifier for official OpenAI API (not Azure or local providers)
+ if (supportsSafetyIdentifier(client)) {
+ // @ts-ignore - safety_identifier is not defined in OpenAI types yet
+ params.safety_identifier = SAFETY_IDENTIFIER;
+ }
+
+ // @ts-ignore - safety_identifier is not in the OpenAI types yet
+ return client.moderations.create(params);
+ }
+
/**
* Guardrail check_fn to flag disallowed content categories using OpenAI moderation API.
*
@@ -102,39 +139,55 @@ export const moderationCheck: CheckFn<ModerationContext, string, ModerationConfi
const configObj = actualConfig as Record<string, unknown>;
const categories = (configObj.categories as string[]) || Object.values(Category);

- // Reuse provided client only if it targets the official OpenAI API.
- const reuseClientIfOpenAI = (context: unknown): OpenAI | null => {
- try {
- const contextObj = context as Record<string, unknown>;
- const candidate = contextObj?.guardrailLlm;
- if (!candidate || typeof candidate !== 'object') return null;
- if (!(candidate instanceof OpenAI)) return null;
-
- const candidateObj = candidate as unknown as Record<string, unknown>;
- const baseURL: string | undefined =
- (candidateObj.baseURL as string) ??
- ((candidateObj._client as Record<string, unknown>)?.baseURL as string) ??
- (candidateObj._baseURL as string);
-
- if (
- baseURL === undefined ||
- (typeof baseURL === 'string' && baseURL.includes('api.openai.com'))
- ) {
- return candidate as OpenAI;
- }
- return null;
- } catch {
- return null;
- }
- };
-
- const client = reuseClientIfOpenAI(ctx) ?? new OpenAI();
+ // Get client from context if available
+ let client: OpenAI | null = null;
+ if (ctx) {
+ const contextObj = ctx as Record<string, unknown>;
+ const candidate = contextObj.guardrailLlm;
+ if (candidate && candidate instanceof OpenAI) {
+ client = candidate;
+ }
+ }

try {
- const resp = await client.moderations.create({
- model: 'omni-moderation-latest',
- input: data,
- });
+ // Try the context client first, fall back if moderation endpoint doesn't exist
+ let resp: Awaited<ReturnType<typeof callModerationAPI>>;
+ if (client !== null) {
+ try {
+ resp = await callModerationAPI(client, data);
+ } catch (error) {
+
[Inline review comment — Copilot AI, Oct 31, 2025]
[nitpick] Empty line inside catch block should be removed to improve code cleanliness.
Suggested change: remove the empty line above.
+ // Moderation endpoint doesn't exist on this provider (e.g., third-party)
+ // Fall back to the OpenAI client
+ if (isNotFoundError(error)) {
+ try {
+ resp = await callModerationAPI(new OpenAI(), data);
+ } catch (fallbackError) {
+ // If fallback fails, provide a helpful error message
+ const errorMessage = fallbackError instanceof Error
+ ? fallbackError.message
+ : String(fallbackError);
+
+ // Check if it's an API key error
+ if (errorMessage.includes('api_key') || errorMessage.includes('OPENAI_API_KEY')) {
+ return {
+ tripwireTriggered: false,
+ info: {
+ checked_text: data,
+ error: 'Moderation API requires OpenAI API key. Set OPENAI_API_KEY environment variable or pass a client with valid credentials.',
+ },
+ };
+ }
+ throw fallbackError;
+ }
+ } else {
+ throw error;
+ }
+ }
+ } else {
+ // No context client, use fallback
+ resp = await callModerationAPI(new OpenAI(), data);
+ }

const results = resp.results || [];
if (!results.length) {
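Taken together, the new moderation path reduces to the decision flow sketched below. This is an illustrative condensation, not code from the PR: getModerationResponse is a hypothetical helper name, the API-key error handling is omitted for brevity, and callModerationAPI / isNotFoundError are the helpers introduced in the diff above.

import OpenAI from 'openai';
// callModerationAPI and isNotFoundError are the helpers added in this PR.

async function getModerationResponse(client: OpenAI | null, data: string) {
  if (client === null) {
    // No usable client on the guardrail context: use a default OpenAI client.
    return callModerationAPI(new OpenAI(), data);
  }
  try {
    // Prefer the caller's client, whatever provider it targets.
    return await callModerationAPI(client, data);
  } catch (error) {
    // Only a 404 (provider has no /moderations endpoint) triggers the
    // fallback to api.openai.com; any other error propagates.
    if (!isNotFoundError(error)) throw error;
    return callModerationAPI(new OpenAI(), data);
  }
}
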
27 changes: 23 additions & 4 deletions src/checks/user-defined-llm.ts
@@ -9,6 +9,7 @@
import { z } from 'zod';
import { CheckFn, GuardrailResult } from '../types';
import { defaultSpecRegistry } from '../registry';
+ import { SAFETY_IDENTIFIER, supportsSafetyIdentifier } from '../utils/safety-identifier';

/**
* Configuration schema for user-defined LLM moderation checks.
@@ -91,27 +92,45 @@ export const userDefinedLLMCheck: CheckFn<UserDefinedContext, string, UserDefine
// Try with JSON response format first, fall back to text if not supported
let response;
try {
- response = await ctx.guardrailLlm.chat.completions.create({
+ // Build API call parameters
+ const params: Record<string, unknown> = {
messages: [
{ role: 'system', content: renderedSystemPrompt },
{ role: 'user', content: data },
],
model: config.model,
temperature: 0.0,
response_format: { type: 'json_object' },
- });
+ };
+
+ // Only include safety_identifier for official OpenAI API (not Azure or local providers)
+ if (supportsSafetyIdentifier(ctx.guardrailLlm)) {
+ // @ts-ignore - safety_identifier is not defined in OpenAI types yet
+ params.safety_identifier = SAFETY_IDENTIFIER;
+ }
+
+ response = await ctx.guardrailLlm.chat.completions.create(params);
} catch (error: unknown) {
// If JSON response format is not supported, try without it
if (error && typeof error === 'object' && 'error' in error &&
(error as { error?: { param?: string } }).error?.param === 'response_format') {
- response = await ctx.guardrailLlm.chat.completions.create({
+ // Build fallback parameters without response_format
+ const fallbackParams: Record<string, unknown> = {
messages: [
{ role: 'system', content: renderedSystemPrompt },
{ role: 'user', content: data },
],
model: config.model,
temperature: 0.0,
- });
+ };
+
+ // Only include safety_identifier for official OpenAI API (not Azure or local providers)
+ if (supportsSafetyIdentifier(ctx.guardrailLlm)) {
+ // @ts-ignore - safety_identifier is not defined in OpenAI types yet
+ fallbackParams.safety_identifier = SAFETY_IDENTIFIER;
+ }
+
+ response = await ctx.guardrailLlm.chat.completions.create(fallbackParams);
} else {
// Return error information instead of re-throwing
return {