-
Notifications
You must be signed in to change notification settings - Fork 16
Handle sync guardrail calls to avoid awaitable error #21
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -34,6 +34,20 @@ def __init__(self, content: str | None) -> None: | |
| self.chat = SimpleNamespace(completions=_FakeCompletions(content)) | ||
|
|
||
|
|
||
| class _FakeSyncCompletions: | ||
| def __init__(self, content: str | None) -> None: | ||
| self._content = content | ||
|
|
||
| def create(self, **kwargs: Any) -> Any: | ||
| _ = kwargs | ||
|
Comment on lines
+41
to
+42
|
||
| return SimpleNamespace(choices=[SimpleNamespace(message=SimpleNamespace(content=self._content))]) | ||
|
|
||
|
|
||
| class _FakeSyncClient: | ||
| def __init__(self, content: str | None) -> None: | ||
| self.chat = SimpleNamespace(completions=_FakeSyncCompletions(content)) | ||
|
|
||
|
|
||
| def test_strip_json_code_fence_removes_wrapping() -> None: | ||
| """Valid JSON code fences should be removed.""" | ||
| fenced = """```json | ||
|
|
@@ -64,6 +78,23 @@ async def test_run_llm_returns_valid_output() -> None: | |
| assert result.flagged is True and result.confidence == 0.9 # noqa: S101 | ||
|
|
||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_run_llm_supports_sync_clients() -> None: | ||
| """run_llm should invoke synchronous clients without awaiting them.""" | ||
| client = _FakeSyncClient('{"flagged": false, "confidence": 0.25}') | ||
|
|
||
| result = await run_llm( | ||
| text="General text", | ||
| system_prompt="Assess text.", | ||
| client=client, # type: ignore[arg-type] | ||
| model="gpt-test", | ||
| output_model=LLMOutput, | ||
| ) | ||
|
|
||
| assert isinstance(result, LLMOutput) # noqa: S101 | ||
| assert result.flagged is False and result.confidence == 0.25 # noqa: S101 | ||
|
|
||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_run_llm_handles_content_filter_error(monkeypatch: pytest.MonkeyPatch) -> None: | ||
| """Content filter errors should return LLMErrorOutput with flagged=True.""" | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||
|---|---|---|---|---|
|
|
@@ -120,3 +120,21 @@ async def failing_llm(*_args: Any, **_kwargs: Any) -> PromptInjectionDetectionOu | |||
|
|
||||
| assert result.tripwire_triggered is False # noqa: S101 | ||||
| assert "Error during prompt injection detection check" in result.info["observation"] # noqa: S101 | ||||
|
|
||||
|
|
||||
| @pytest.mark.asyncio | ||||
| async def test_prompt_injection_detection_llm_supports_sync_responses() -> None: | ||||
| """Underlying responses.parse may be synchronous for some clients.""" | ||||
| analysis = PromptInjectionDetectionOutput(flagged=True, confidence=0.4, observation="Action summary") | ||||
|
|
||||
| class _SyncResponses: | ||||
| def parse(self, **kwargs: Any) -> Any: | ||||
| _ = kwargs | ||||
|
||||
| _ = kwargs |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is just in a test, I think it's fine
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The bare
except Exceptionis too broad. Consider catchingImportErrorspecifically since this is handling an optional dependency import.