Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 27 additions & 3 deletions pyrit/exceptions/exception_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,17 @@
from typing import Callable

from pyrit.common.constants import RETRY_WAIT_MIN_SECONDS, RETRY_WAIT_MAX_SECONDS, RETRY_MAX_NUM_ATTEMPTS
from pyrit.memory.memory_interface import MemoryInterface
from pyrit.models.prompt_request_piece import PromptRequestPiece
from pyrit.models.prompt_request_response import PromptRequestResponse


logger = logging.getLogger(__name__)


class PyritException(Exception, ABC):
    """Abstract base class for all PyRIT exceptions.

    Stores an HTTP-style status code alongside a human-readable message so that
    subclasses (e.g. BadRequestException, RateLimitException) share a uniform
    string representation: ``Status Code: <code>, Message: <message>``.
    """

    def __init__(self, status_code: int = 500, *, message: str = "An error occurred"):
        # ``message`` is keyword-only so call sites stay explicit; the default
        # status code 500 marks a generic server-side failure.
        self.status_code = status_code
        self.message = message
        super().__init__(f"Status Code: {status_code}, Message: {message}")
Expand All @@ -35,14 +39,14 @@ class BadRequestException(PyritException):
"""Exception class for bad client requests."""

def __init__(self, status_code: int = 400, *, message: str = "Bad Request"):
super().__init__(status_code, message)
super().__init__(status_code, message=message)


class RateLimitException(PyritException):
    """Exception class for rate limit (HTTP 429-style) errors."""

    def __init__(self, status_code: int = 429, *, message: str = "Rate Limit Exception"):
        # Forward message as a keyword to match the keyword-only base signature.
        super().__init__(status_code, message=message)


class EmptyResponseException(BadRequestException):
    """Exception raised when the target returns an empty (204 No Content) response."""

    def __init__(self, status_code: int = 204, *, message: str = "No Content"):
        super().__init__(status_code=status_code, message=message)


def handle_bad_request_exception(
    memory: MemoryInterface, response_text: str, request: PromptRequestPiece
) -> PromptRequestResponse:
    """Record a bad-request error from the target in memory.

    If the response indicates a content-filter block, the serialized exception
    is stored as a "blocked" error entry and returned. Any other bad request is
    stored as a "processing" error and the original exception is propagated.

    NOTE(review): this helper must be invoked from inside an ``except`` block —
    the bare ``raise`` below re-raises the currently active exception.
    """
    if "content_filter" not in response_text:
        # Unrelated processing failure: persist it, then let the caller's
        # exception continue to propagate.
        memory.add_response_entries_to_memory(
            request=request, response_text_pieces=[response_text], response_type="error", error="processing"
        )
        raise

    # Harmful content was blocked by the content-filter system: store the
    # serialized exception and hand the resulting memory entry back.
    blocked_exception = BadRequestException(400, message=response_text)
    serialized_error = blocked_exception.process_exception()
    return memory.add_response_entries_to_memory(
        request=request, response_text_pieces=[serialized_error], response_type="error", error="blocked"
    )


def pyrit_retry(func: Callable) -> Callable:
"""
A decorator to apply retry logic with exponential backoff to a function.
Expand Down
4 changes: 2 additions & 2 deletions pyrit/models/literals.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@
"""
The type of the error in the prompt response
blocked: blocked by an external filter e.g. Azure Filters
model: the model refused to answer or request e.g. "I'm sorry..."
none: no exception is raised
processing: there is an exception thrown unrelated to the query
unknown: the type of error is unknown
"""
PromptResponseError = Literal["none", "blocked", "error", "model", "processing", "unknown"]
PromptResponseError = Literal["blocked", "none", "processing", "unknown"]

ChatMessageRole = Literal["system", "user", "assistant"]
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,19 @@
import json

from openai import AsyncAzureOpenAI
from openai import BadRequestError, RateLimitError
from openai import BadRequestError
from openai.types.chat import ChatCompletion


from pyrit.auth.azure_auth import get_token_provider_from_default_azure_credential
from pyrit.common import default_values
from pyrit.exceptions.exception_classes import PyritException, handle_bad_request_exception
from pyrit.memory import MemoryInterface
from pyrit.models import ChatMessageListContent
from pyrit.models import PromptRequestResponse, PromptRequestPiece
from pyrit.models.data_type_serializer import data_serializer_factory, DataTypeSerializer
from pyrit.prompt_target import PromptChatTarget
from pyrit.exceptions import EmptyResponseException, BadRequestException, RateLimitException, pyrit_retry
from pyrit.common.constants import RETRY_MAX_NUM_ATTEMPTS
from pyrit.exceptions import EmptyResponseException, pyrit_retry

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -239,27 +239,14 @@ async def send_prompt_async(self, *, prompt_request: PromptRequestResponse) -> P
request=request, response_text_pieces=[resp_text]
)
except BadRequestError as bre:
# Handle bad request error when content filter system detects harmful content
bad_request_exception = BadRequestException(bre.status_code, message=bre.message)
resp_text = bad_request_exception.process_exception()
response_entry = self._memory.add_response_entries_to_memory(
request=request, response_text_pieces=[resp_text], response_type="error", error="blocked"
)
except RateLimitError as rle:
# Handle the rate limit exception after exhausting the maximum number of retries.
rate_limit_exception = RateLimitException(rle.status_code, message=rle.message)
resp_text = rate_limit_exception.process_exception()
response_entry = self._memory.add_response_entries_to_memory(
request=request, response_text_pieces=[resp_text], response_type="error", error="error"
response_entry = handle_bad_request_exception(
memory=self._memory, response_text=bre.message, request=request
)
except EmptyResponseException:
# Handle the empty response exception after exhausting the maximum number of retries.
message = f"Empty response from the target even after {RETRY_MAX_NUM_ATTEMPTS} retries."
empty_response_exception = EmptyResponseException(message=message)
resp_text = empty_response_exception.process_exception()
response_entry = self._memory.add_response_entries_to_memory(
request=request, response_text_pieces=[resp_text], response_type="error", error="error"
except Exception as ex:
self._memory.add_response_entries_to_memory(
request=request, response_text_pieces=[str(ex)], response_type="error", error="processing"
)
raise
Comment thread
rdheekonda marked this conversation as resolved.

return response_entry

Expand Down Expand Up @@ -328,10 +315,9 @@ async def _complete_chat_async(
# Handle empty response
if not extracted_response:
raise EmptyResponseException(message="The chat returned an empty response.")
elif finish_reason == "content_filter":
message = response.choices[0]
content_filter_exception = BadRequestException(message=str(message))
extracted_response = content_filter_exception.process_exception()
else:
raise PyritException(message=f"Unknown finish_reason {finish_reason}")

return extracted_response

def _validate_request(self, *, prompt_request: PromptRequestResponse) -> None:
Expand Down
4 changes: 2 additions & 2 deletions tests/exceptions/test_exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,14 @@


def test_pyrit_exception_initialization():
    # message is keyword-only on PyritException; verify attributes and str().
    ex = PyritException(500, message="Internal Server Error")
    assert ex.status_code == 500
    assert ex.message == "Internal Server Error"
    assert str(ex) == "Status Code: 500, Message: Internal Server Error"


def test_pyrit_exception_process_exception(caplog):
    # process_exception logs the error and returns it serialized as JSON.
    ex = PyritException(500, message="Internal Server Error")
    with caplog.at_level(logging.ERROR):
        result = ex.process_exception()
    assert json.loads(result) == {"status_code": 500, "message": "Internal Server Error"}
Expand Down
90 changes: 40 additions & 50 deletions tests/target/test_azure_openai_gptv_chat_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
from openai.types.chat.chat_completion import Choice
from openai import BadRequestError, RateLimitError

from pyrit.exceptions.exception_classes import EmptyResponseException
from pyrit.memory.duckdb_memory import DuckDBMemory
from pyrit.memory.memory_interface import MemoryInterface
from pyrit.models.prompt_request_piece import PromptRequestPiece
from pyrit.models.prompt_request_response import PromptRequestResponse
from pyrit.prompt_target import AzureOpenAIGPTVChatTarget
Expand All @@ -28,6 +31,7 @@ def azure_gptv_chat_engine() -> AzureOpenAIGPTVChatTarget:
endpoint="https://mock.azure.com/",
api_key="mock-api-key",
api_version="some_version",
memory=DuckDBMemory(db_path=":memory:"),
)


Expand Down Expand Up @@ -282,16 +286,14 @@ async def test_send_prompt_async_empty_response_adds_to_memory(
):
with patch("openai.resources.chat.AsyncCompletions.create", new_callable=AsyncMock) as mock_create:
mock_create.return_value = azure_openai_mock_return
response: PromptRequestResponse = await azure_gptv_chat_engine.send_prompt_async(
prompt_request=prompt_req_resp
)

azure_gptv_chat_engine._memory.get_conversation.assert_called_once_with(conversation_id="12345679")
azure_gptv_chat_engine._memory.add_request_response_to_memory.assert_called_once_with(
request=prompt_req_resp
)
azure_gptv_chat_engine._memory.add_response_entries_to_memory.assert_called_once()
assert response is not None, "Expected a result but got None"
with pytest.raises(EmptyResponseException) as e:
await azure_gptv_chat_engine.send_prompt_async(prompt_request=prompt_req_resp)
azure_gptv_chat_engine._memory.get_conversation.assert_called_once_with(conversation_id="12345679")
azure_gptv_chat_engine._memory.add_request_response_to_memory.assert_called_once_with(
request=prompt_req_resp
)
azure_gptv_chat_engine._memory.add_response_entries_to_memory.assert_called_once()
assert str(e.value) == "Status Code: 204, Message: The chat returned an empty response."


@pytest.mark.asyncio
Expand All @@ -313,11 +315,13 @@ async def test_send_prompt_async_rate_limit_exception_adds_to_memory(azure_gptv_
request_pieces=[PromptRequestPiece(role="user", conversation_id="123", original_value="Hello")]
)

result = await azure_gptv_chat_engine.send_prompt_async(prompt_request=prompt_request)
assert result is not None
azure_gptv_chat_engine._memory.get_conversation.assert_called_once_with(conversation_id="123")
azure_gptv_chat_engine._memory.add_request_response_to_memory.assert_called_once_with(request=prompt_request)
azure_gptv_chat_engine._memory.add_response_entries_to_memory.assert_called_once()
with pytest.raises(RateLimitError) as rle:
await azure_gptv_chat_engine.send_prompt_async(prompt_request=prompt_request)
azure_gptv_chat_engine._memory.get_conversation.assert_called_once_with(conversation_id="123")
azure_gptv_chat_engine._memory.add_request_response_to_memory.assert_called_once_with(request=prompt_request)
azure_gptv_chat_engine._memory.add_response_entries_to_memory.assert_called_once()

assert str(rle.value) == "Rate Limit Reached"


@pytest.mark.asyncio
Expand All @@ -339,11 +343,13 @@ async def test_send_prompt_async_bad_request_error_adds_to_memory(azure_gptv_cha
request_pieces=[PromptRequestPiece(role="user", conversation_id="123", original_value="Hello")]
)

result = await azure_gptv_chat_engine.send_prompt_async(prompt_request=prompt_request)
assert result is not None
azure_gptv_chat_engine._memory.get_conversation.assert_called_once_with(conversation_id="123")
azure_gptv_chat_engine._memory.add_request_response_to_memory.assert_called_once_with(request=prompt_request)
azure_gptv_chat_engine._memory.add_response_entries_to_memory.assert_called_once()
with pytest.raises(BadRequestError) as bre:
await azure_gptv_chat_engine.send_prompt_async(prompt_request=prompt_request)
azure_gptv_chat_engine._memory.get_conversation.assert_called_once_with(conversation_id="123")
azure_gptv_chat_engine._memory.add_request_response_to_memory.assert_called_once_with(request=prompt_request)
azure_gptv_chat_engine._memory.add_response_entries_to_memory.assert_called_once()

assert str(bre.value) == "Bad Request"


@pytest.mark.asyncio
Expand Down Expand Up @@ -393,7 +399,7 @@ async def test_send_prompt_async(


@pytest.mark.asyncio
async def test_send_prompt_async_empty_response(
async def test_send_prompt_async_empty_response_retries(
azure_openai_mock_return: ChatCompletion, azure_gptv_chat_engine: AzureOpenAIGPTVChatTarget
):
with NamedTemporaryFile(suffix=".jpg", delete=False) as tmp_file:
Expand Down Expand Up @@ -433,23 +439,16 @@ async def test_send_prompt_async_empty_response(
with patch("openai.resources.chat.AsyncCompletions.create", new_callable=AsyncMock) as mock_create:
mock_create.return_value = azure_openai_mock_return
constants.RETRY_MAX_NUM_ATTEMPTS = 5
response: PromptRequestResponse = await azure_gptv_chat_engine.send_prompt_async(
prompt_request=prompt_req_resp
)
assert len(response.request_pieces) == 1
expected_error_message = (
'{"status_code": 204, "message": "Empty response from the target even after 5 retries."}'
)
assert response.request_pieces[0].converted_value == expected_error_message
assert response.request_pieces[0].converted_value_data_type == "error"
assert response.request_pieces[0].original_value == expected_error_message
assert response.request_pieces[0].original_value_data_type == "error"
assert str(constants.RETRY_MAX_NUM_ATTEMPTS) in response.request_pieces[0].converted_value
os.remove(tmp_file_name)
azure_gptv_chat_engine._memory = MagicMock(MemoryInterface)

with pytest.raises(EmptyResponseException):
await azure_gptv_chat_engine.send_prompt_async(prompt_request=prompt_req_resp)

assert mock_create.call_count == constants.RETRY_MAX_NUM_ATTEMPTS


@pytest.mark.asyncio
async def test_send_prompt_async_rate_limit_exception(azure_gptv_chat_engine: AzureOpenAIGPTVChatTarget):
async def test_send_prompt_async_rate_limit_exception_retries(azure_gptv_chat_engine: AzureOpenAIGPTVChatTarget):

response = MagicMock()
response.status_code = 429
Expand All @@ -461,13 +460,9 @@ async def test_send_prompt_async_rate_limit_exception(azure_gptv_chat_engine: Az
request_pieces=[PromptRequestPiece(role="user", conversation_id="12345", original_value="Hello")]
)

result = await azure_gptv_chat_engine.send_prompt_async(prompt_request=prompt_request)
assert "Rate Limit Reached" in result.request_pieces[0].converted_value
assert "Rate Limit Reached" in result.request_pieces[0].original_value
assert result.request_pieces[0].original_value_data_type == "error"
assert result.request_pieces[0].converted_value_data_type == "error"
expected_sha_256 = "7d0ed53fb1c888e3467776735ee117e328c24f1a588a5f8756ba213c9b0b84a9"
assert result.request_pieces[0].original_value_sha256 == expected_sha_256
with pytest.raises(RateLimitError):
await azure_gptv_chat_engine.send_prompt_async(prompt_request=prompt_request)
assert mock_complete_chat_async.call_count == constants.RETRY_MAX_NUM_ATTEMPTS


@pytest.mark.asyncio
Expand All @@ -483,14 +478,9 @@ async def test_send_prompt_async_bad_request_error(azure_gptv_chat_engine: Azure
prompt_request = PromptRequestResponse(
request_pieces=[PromptRequestPiece(role="user", conversation_id="1236748", original_value="Hello")]
)

result = await azure_gptv_chat_engine.send_prompt_async(prompt_request=prompt_request)
assert "Bad Request Error" in result.request_pieces[0].converted_value
assert "Bad Request Error" in result.request_pieces[0].original_value
assert result.request_pieces[0].original_value_data_type == "error"
assert result.request_pieces[0].converted_value_data_type == "error"
expected_sha256 = "4e98b0da48c028f090473fe5cc71461a921465f807ae66c5f7ae9d0e9f301f77"
assert result.request_pieces[0].original_value_sha256 == expected_sha256
with pytest.raises(BadRequestError) as bre:
await azure_gptv_chat_engine.send_prompt_async(prompt_request=prompt_request)
assert str(bre.value) == "Bad Request Error"


def test_parse_chat_completion_successful(azure_gptv_chat_engine: AzureOpenAIGPTVChatTarget):
Expand Down