
Commit 293b1ae

Added safety header oaig. Ruff formatting (#18)
* Added safety header oaig. Ruff formatting
* updated safety header name
* additional test for agents
1 parent b2d7a81 commit 293b1ae


55 files changed (+604, -881 lines)
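Most of the hunks visible in this excerpt are mechanical Ruff reformatting: multi-line calls collapsed onto single lines and trailing commas added. The safety-header change named in the commit title lives in files not shown here. As a rough, hypothetical sketch of how a client-level header is typically attached (the header name and value below are placeholders, not taken from this diff), the OpenAI Python client accepts a default_headers mapping:

    # Illustration only: the header name and value are assumed, not taken from this diff.
    # AsyncOpenAI reads OPENAI_API_KEY from the environment and forwards
    # default_headers on every request it makes.
    from openai import AsyncOpenAI

    client = AsyncOpenAI(
        default_headers={"X-Safety-Identifier": "oaig"},  # placeholder safety header
    )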

examples/basic/azure_implementation.py

Lines changed: 1 addition & 3 deletions
@@ -54,9 +54,7 @@
 }


-async def process_input(
-    guardrails_client: GuardrailsAsyncAzureOpenAI, user_input: str
-) -> None:
+async def process_input(guardrails_client: GuardrailsAsyncAzureOpenAI, user_input: str) -> None:
     """Process user input with complete response validation using GuardrailsClient."""
     try:
         # Use GuardrailsClient to handle all guardrail checks and LLM calls

examples/basic/custom_context.py

Lines changed: 4 additions & 8 deletions
@@ -26,14 +26,15 @@
                     "system_prompt_details": "Check if the text contains any math problems.",
                 },
             },
-        ]
-    }
+        ],
+    },
 }


 async def main() -> None:
     # Use Ollama for guardrail LLM checks
     from openai import AsyncOpenAI
+
     guardrail_llm = AsyncOpenAI(
         base_url="http://127.0.0.1:11434/v1/",  # Ollama endpoint
         api_key="ollama",
@@ -50,10 +51,7 @@ async def main() -> None:
     while True:
         try:
             user_input = input("Enter a message: ")
-            response = await client.chat.completions.create(
-                model="gpt-4.1-nano",
-                messages=[{"role": "user", "content": user_input}]
-            )
+            response = await client.chat.completions.create(model="gpt-4.1-nano", messages=[{"role": "user", "content": user_input}])
             print("Assistant:", response.llm_response.choices[0].message.content)
         except EOFError:
             break
@@ -65,5 +63,3 @@ async def main() -> None:

 if __name__ == "__main__":
     asyncio.run(main())
-
-

examples/basic/hello_world.py

Lines changed: 4 additions & 12 deletions
@@ -49,15 +49,11 @@ async def process_input(
         previous_response_id=response_id,
     )

-    console.print(
-        f"\nAssistant output: {response.llm_response.output_text}", end="\n\n"
-    )
+    console.print(f"\nAssistant output: {response.llm_response.output_text}", end="\n\n")

     # Show guardrail results if any were run
     if response.guardrail_results.all_results:
-        console.print(
-            f"[dim]Guardrails checked: {len(response.guardrail_results.all_results)}[/dim]"
-        )
+        console.print(f"[dim]Guardrails checked: {len(response.guardrail_results.all_results)}[/dim]")

     return response.llm_response.id

@@ -76,16 +72,12 @@ async def main() -> None:
     while True:
         try:
             user_input = input("Enter a message: ")
-            response_id = await process_input(
-                guardrails_client, user_input, response_id
-            )
+            response_id = await process_input(guardrails_client, user_input, response_id)
         except EOFError:
             break
         except GuardrailTripwireTriggered as exc:
             stage_name = exc.guardrail_result.info.get("stage_name", "unknown")
-            console.print(
-                f"\n🛑 [bold red]Guardrail triggered in stage '{stage_name}'![/bold red]"
-            )
+            console.print(f"\n🛑 [bold red]Guardrail triggered in stage '{stage_name}'![/bold red]")
             console.print(
                 Panel(
                     str(exc.guardrail_result),

examples/basic/local_model.py

Lines changed: 2 additions & 6 deletions
@@ -80,12 +80,8 @@ async def main() -> None:
             break
         except GuardrailTripwireTriggered as exc:
             stage_name = exc.guardrail_result.info.get("stage_name", "unknown")
-            guardrail_name = exc.guardrail_result.info.get(
-                "guardrail_name", "unknown"
-            )
-            console.print(
-                f"\n🛑 [bold red]Guardrail '{guardrail_name}' triggered in stage '{stage_name}'![/bold red]"
-            )
+            guardrail_name = exc.guardrail_result.info.get("guardrail_name", "unknown")
+            console.print(f"\n🛑 [bold red]Guardrail '{guardrail_name}' triggered in stage '{stage_name}'![/bold red]")
             console.print(
                 Panel(
                     str(exc.guardrail_result),

examples/basic/multi_bundle.py

Lines changed: 3 additions & 9 deletions
@@ -73,9 +73,7 @@ async def process_input(

     # Get the response ID from the final chunk
     response_id_to_return = None
-    if hasattr(chunk.llm_response, "response") and hasattr(
-        chunk.llm_response.response, "id"
-    ):
+    if hasattr(chunk.llm_response, "response") and hasattr(chunk.llm_response.response, "id"):
         response_id_to_return = chunk.llm_response.response.id

     return response_id_to_return
@@ -98,16 +96,12 @@ async def main() -> None:
     while True:
         try:
             prompt = input("Enter a message: ")
-            response_id = await process_input(
-                guardrails_client, prompt, response_id
-            )
+            response_id = await process_input(guardrails_client, prompt, response_id)
         except (EOFError, KeyboardInterrupt):
             break
         except GuardrailTripwireTriggered as exc:
             stage_name = exc.guardrail_result.info.get("stage_name", "unknown")
-            guardrail_name = exc.guardrail_result.info.get(
-                "guardrail_name", "unknown"
-            )
+            guardrail_name = exc.guardrail_result.info.get("guardrail_name", "unknown")
             console.print(
                 f"🛑 Guardrail '{guardrail_name}' triggered in stage '{stage_name}'!",
                 style="bold red",

examples/basic/multiturn_chat_with_alignment.py

Lines changed: 22 additions & 42 deletions
@@ -43,9 +43,7 @@ def get_weather(location: str, unit: str = "celsius") -> dict[str, str | int]:
     }


-def get_flights(
-    origin: str, destination: str, date: str
-) -> dict[str, list[dict[str, str]]]:
+def get_flights(origin: str, destination: str, date: str) -> dict[str, list[dict[str, str]]]:
     flights = [
         {"flight": "GA123", "depart": f"{date} 08:00", "arrive": f"{date} 12:30"},
         {"flight": "GA456", "depart": f"{date} 15:45", "arrive": f"{date} 20:10"},
@@ -160,9 +158,7 @@ def _stage_lines(stage_name: str, stage_results: Iterable) -> list[str]:
         # Header with status and confidence
         lines.append(f"[bold]{stage_name.upper()}[/bold] · {name} · {status}")
         if confidence != "N/A":
-            lines.append(
-                f" 📊 Confidence: {confidence} (threshold: {info.get('threshold', 'N/A')})"
-            )
+            lines.append(f" 📊 Confidence: {confidence} (threshold: {info.get('threshold', 'N/A')})")

         # Prompt injection detection-specific details
         if name == "Prompt Injection Detection":
@@ -176,9 +172,7 @@ def _stage_lines(stage_name: str, stage_results: Iterable) -> list[str]:

             # Add interpretation
             if r.tripwire_triggered:
-                lines.append(
-                    " ⚠️ PROMPT INJECTION DETECTED: Action does not serve user's goal!"
-                )
+                lines.append(" ⚠️ PROMPT INJECTION DETECTED: Action does not serve user's goal!")
             else:
                 lines.append(" ✨ ALIGNED: Action serves user's goal")
         else:
@@ -235,9 +229,7 @@ async def main(malicious: bool = False) -> None:
         messages.append({"role": "user", "content": user_input})

         try:
-            resp = await client.chat.completions.create(
-                model="gpt-4.1-nano", messages=messages, tools=tools
-            )
+            resp = await client.chat.completions.create(model="gpt-4.1-nano", messages=messages, tools=tools)
             print_guardrail_results("initial", resp)
             choice = resp.llm_response.choices[0]
             message = choice.message
@@ -246,12 +238,12 @@ async def main(malicious: bool = False) -> None:
             info = getattr(e, "guardrail_result", None)
             info = info.info if info else {}
             lines = [
-                f"Guardrail: {info.get('guardrail_name','Unknown')}",
-                f"Stage: {info.get('stage_name','unknown')}",
-                f"User goal: {info.get('user_goal','N/A')}",
-                f"Action: {info.get('action','N/A')}",
-                f"Observation: {info.get('observation','N/A')}",
-                f"Confidence: {info.get('confidence','N/A')}",
+                f"Guardrail: {info.get('guardrail_name', 'Unknown')}",
+                f"Stage: {info.get('stage_name', 'unknown')}",
+                f"User goal: {info.get('user_goal', 'N/A')}",
+                f"Action: {info.get('action', 'N/A')}",
+                f"Observation: {info.get('observation', 'N/A')}",
+                f"Confidence: {info.get('confidence', 'N/A')}",
             ]
             console.print(
                 Panel(
@@ -292,12 +284,8 @@ async def main(malicious: bool = False) -> None:

                # Malicious injection test mode
                if malicious:
-                    console.print(
-                        "[yellow]⚠️ MALICIOUS TEST: Injecting unrelated sensitive data into function output[/yellow]"
-                    )
-                    console.print(
-                        "[yellow] This should trigger the Prompt Injection Detection guardrail as misaligned![/yellow]"
-                    )
+                    console.print("[yellow]⚠️ MALICIOUS TEST: Injecting unrelated sensitive data into function output[/yellow]")
+                    console.print("[yellow] This should trigger the Prompt Injection Detection guardrail as misaligned![/yellow]")
                    result = {
                        **result,
                        "bank_account": "1234567890",
@@ -319,17 +307,13 @@ async def main(malicious: bool = False) -> None:
                        "role": "tool",
                        "tool_call_id": call.id,
                        "name": fname,
-                        "content": json.dumps(
-                            {"error": f"Unknown function: {fname}"}
-                        ),
+                        "content": json.dumps({"error": f"Unknown function: {fname}"}),
                    }
                )

            # Final call
            try:
-                resp = await client.chat.completions.create(
-                    model="gpt-4.1-nano", messages=messages, tools=tools
-                )
+                resp = await client.chat.completions.create(model="gpt-4.1-nano", messages=messages, tools=tools)

                print_guardrail_results("final", resp)
                final_message = resp.llm_response.choices[0].message
@@ -342,19 +326,17 @@ async def main(malicious: bool = False) -> None:
                )

                # Add final assistant response to conversation
-                messages.append(
-                    {"role": "assistant", "content": final_message.content}
-                )
+                messages.append({"role": "assistant", "content": final_message.content})
            except GuardrailTripwireTriggered as e:
                info = getattr(e, "guardrail_result", None)
                info = info.info if info else {}
                lines = [
-                    f"Guardrail: {info.get('guardrail_name','Unknown')}",
-                    f"Stage: {info.get('stage_name','unknown')}",
-                    f"User goal: {info.get('user_goal','N/A')}",
-                    f"Action: {info.get('action','N/A')}",
-                    f"Observation: {info.get('observation','N/A')}",
-                    f"Confidence: {info.get('confidence','N/A')}",
+                    f"Guardrail: {info.get('guardrail_name', 'Unknown')}",
+                    f"Stage: {info.get('stage_name', 'unknown')}",
+                    f"User goal: {info.get('user_goal', 'N/A')}",
+                    f"Action: {info.get('action', 'N/A')}",
+                    f"Observation: {info.get('observation', 'N/A')}",
+                    f"Confidence: {info.get('confidence', 'N/A')}",
                ]
                console.print(
                    Panel(
@@ -380,9 +362,7 @@ async def main(malicious: bool = False) -> None:


 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(
-        description="Chat Completions with Prompt Injection Detection guardrails"
-    )
+    parser = argparse.ArgumentParser(description="Chat Completions with Prompt Injection Detection guardrails")
     parser.add_argument(
         "--malicious",
         action="store_true",

examples/basic/pii_mask_example.py

Lines changed: 5 additions & 17 deletions
@@ -42,9 +42,7 @@
     },
     "input": {
         "version": 1,
-        "guardrails": [
-            {"name": "Moderation", "config": {"categories": ["hate", "violence"]}}
-        ],
+        "guardrails": [{"name": "Moderation", "config": {"categories": ["hate", "violence"]}}],
         "config": {"concurrency": 5, "suppress_tripwire": False},
     },
     "output": {
@@ -98,9 +96,7 @@ async def process_input(
         # Show PII masking information if detected in pre-flight
         if response.guardrail_results.preflight:
             for result in response.guardrail_results.preflight:
-                if result.info.get(
-                    "guardrail_name"
-                ) == "Contains PII" and result.info.get("pii_detected", False):
+                if result.info.get("guardrail_name") == "Contains PII" and result.info.get("pii_detected", False):
                     detected_entities = result.info.get("detected_entities", {})
                     masked_text = result.info.get("checked_text", user_input)

@@ -118,9 +114,7 @@ async def process_input(
         # Show if PII was detected in output
         if response.guardrail_results.output:
             for result in response.guardrail_results.output:
-                if result.info.get(
-                    "guardrail_name"
-                ) == "Contains PII" and result.info.get("pii_detected", False):
+                if result.info.get("guardrail_name") == "Contains PII" and result.info.get("pii_detected", False):
                     detected_entities = result.info.get("detected_entities", {})
                     console.print(
                         Panel(
@@ -134,14 +128,8 @@ async def process_input(
     except GuardrailTripwireTriggered as exc:
         stage_name = exc.guardrail_result.info.get("stage_name", "unknown")
         guardrail_name = exc.guardrail_result.info.get("guardrail_name", "unknown")
-        console.print(
-            f"[bold red]Guardrail '{guardrail_name}' triggered in stage '{stage_name}'![/bold red]"
-        )
-        console.print(
-            Panel(
-                str(exc.guardrail_result), title="Guardrail Result", border_style="red"
-            )
-        )
+        console.print(f"[bold red]Guardrail '{guardrail_name}' triggered in stage '{stage_name}'![/bold red]")
+        console.print(Panel(str(exc.guardrail_result), title="Guardrail Result", border_style="red"))
         raise

examples/basic/structured_outputs_example.py

Lines changed: 5 additions & 7 deletions
@@ -10,6 +10,7 @@
 # Define a simple Pydantic model for structured output
 class UserInfo(BaseModel):
     """User information extracted from text."""
+
     name: str = Field(description="Full name of the user")
     age: int = Field(description="Age of the user")
     email: str = Field(description="Email address of the user")
@@ -22,21 +23,18 @@ class UserInfo(BaseModel):
         "version": 1,
         "guardrails": [
             {"name": "Moderation", "config": {"categories": ["hate", "violence"]}},
-        ]
-    }
+        ],
+    },
 }


 async def extract_user_info(guardrails_client: GuardrailsAsyncOpenAI, text: str) -> UserInfo:
     """Extract user information using responses_parse with structured output."""
     try:
         response = await guardrails_client.responses.parse(
-            input=[
-                {"role": "system", "content": "Extract user information from the provided text."},
-                {"role": "user", "content": text}
-            ],
+            input=[{"role": "system", "content": "Extract user information from the provided text."}, {"role": "user", "content": text}],
             model="gpt-4.1-nano",
-            text_format=UserInfo
+            text_format=UserInfo,
         )

         # Access the parsed structured output

examples/basic/suppress_tripwire.py

Lines changed: 4 additions & 12 deletions
@@ -55,9 +55,7 @@ async def process_input(
         for result in response.guardrail_results.all_results:
             guardrail_name = result.info.get("guardrail_name", "Unknown Guardrail")
             if result.tripwire_triggered:
-                console.print(
-                    f"[bold yellow]Guardrail '{guardrail_name}' triggered![/bold yellow]"
-                )
+                console.print(f"[bold yellow]Guardrail '{guardrail_name}' triggered![/bold yellow]")
                 console.print(
                     Panel(
                         str(result),
@@ -66,15 +64,11 @@ async def process_input(
                     )
                 )
             else:
-                console.print(
-                    f"[bold green]Guardrail '{guardrail_name}' passed.[/bold green]"
-                )
+                console.print(f"[bold green]Guardrail '{guardrail_name}' passed.[/bold green]")
         else:
             console.print("[bold green]No guardrails triggered.[/bold green]")

-        console.print(
-            f"\n[bold blue]Assistant output:[/bold blue] {response.llm_response.output_text}\n"
-        )
+        console.print(f"\n[bold blue]Assistant output:[/bold blue] {response.llm_response.output_text}\n")
         return response.llm_response.id

     except Exception as e:
@@ -95,9 +89,7 @@ async def main() -> None:
             user_input = input("Enter a message: ")
         except EOFError:
             break
-        response_id = await process_input(
-            guardrails_client, user_input, response_id
-        )
+        response_id = await process_input(guardrails_client, user_input, response_id)


 if __name__ == "__main__":
