diff --git a/pyrit/exceptions/exception_classes.py b/pyrit/exceptions/exception_classes.py index 7fc2ae92f5..f5b4509fc2 100644 --- a/pyrit/exceptions/exception_classes.py +++ b/pyrit/exceptions/exception_classes.py @@ -10,13 +10,17 @@ from typing import Callable from pyrit.common.constants import RETRY_WAIT_MIN_SECONDS, RETRY_WAIT_MAX_SECONDS, RETRY_MAX_NUM_ATTEMPTS +from pyrit.memory.memory_interface import MemoryInterface +from pyrit.models.prompt_request_piece import PromptRequestPiece +from pyrit.models.prompt_request_response import PromptRequestResponse + logger = logging.getLogger(__name__) class PyritException(Exception, ABC): - def __init__(self, status_code, message): + def __init__(self, status_code=500, *, message: str = "An error occurred"): self.status_code = status_code self.message = message super().__init__(f"Status Code: {status_code}, Message: {message}") @@ -35,14 +39,14 @@ class BadRequestException(PyritException): """Exception class for bad client requests.""" def __init__(self, status_code: int = 400, *, message: str = "Bad Request"): - super().__init__(status_code, message) + super().__init__(status_code, message=message) class RateLimitException(PyritException): """Exception class for authentication errors.""" def __init__(self, status_code: int = 429, *, message: str = "Rate Limit Exception"): - super().__init__(status_code, message) + super().__init__(status_code, message=message) class EmptyResponseException(BadRequestException): @@ -52,6 +56,26 @@ def __init__(self, status_code: int = 204, *, message: str = "No Content"): super().__init__(status_code=status_code, message=message) +def handle_bad_request_exception( + memory: MemoryInterface, response_text: str, request: PromptRequestPiece +) -> PromptRequestResponse: + + if "content_filter" in response_text: + # Handle bad request error when content filter system detects harmful content + bad_request_exception = BadRequestException(400, message=response_text) + resp_text = 
bad_request_exception.process_exception() + response_entry = memory.add_response_entries_to_memory( + request=request, response_text_pieces=[resp_text], response_type="error", error="blocked" + ) + else: + memory.add_response_entries_to_memory( + request=request, response_text_pieces=[response_text], response_type="error", error="processing" + ) + raise + + return response_entry + + def pyrit_retry(func: Callable) -> Callable: """ A decorator to apply retry logic with exponential backoff to a function. diff --git a/pyrit/models/literals.py b/pyrit/models/literals.py index eb1b5d660e..6d2506009d 100644 --- a/pyrit/models/literals.py +++ b/pyrit/models/literals.py @@ -9,10 +9,10 @@ """ The type of the error in the prompt response blocked: blocked by an external filter e.g. Azure Filters -model: the model refused to answer or request e.g. "I'm sorry..." +none: no exception is raised processing: there is an exception thrown unrelated to the query unknown: the type of error is unknown """ -PromptResponseError = Literal["none", "blocked", "error", "model", "processing", "unknown"] +PromptResponseError = Literal["blocked", "none", "processing", "unknown"] ChatMessageRole = Literal["system", "user", "assistant"] diff --git a/pyrit/prompt_target/prompt_chat_target/azure_openai_gptv_chat_target.py b/pyrit/prompt_target/prompt_chat_target/azure_openai_gptv_chat_target.py index 3f9a4c859e..eae981ef6a 100644 --- a/pyrit/prompt_target/prompt_chat_target/azure_openai_gptv_chat_target.py +++ b/pyrit/prompt_target/prompt_chat_target/azure_openai_gptv_chat_target.py @@ -7,19 +7,19 @@ import json from openai import AsyncAzureOpenAI -from openai import BadRequestError, RateLimitError +from openai import BadRequestError from openai.types.chat import ChatCompletion from pyrit.auth.azure_auth import get_token_provider_from_default_azure_credential from pyrit.common import default_values +from pyrit.exceptions.exception_classes import PyritException, handle_bad_request_exception from 
pyrit.memory import MemoryInterface from pyrit.models import ChatMessageListContent from pyrit.models import PromptRequestResponse, PromptRequestPiece from pyrit.models.data_type_serializer import data_serializer_factory, DataTypeSerializer from pyrit.prompt_target import PromptChatTarget -from pyrit.exceptions import EmptyResponseException, BadRequestException, RateLimitException, pyrit_retry -from pyrit.common.constants import RETRY_MAX_NUM_ATTEMPTS +from pyrit.exceptions import EmptyResponseException, pyrit_retry logger = logging.getLogger(__name__) @@ -239,27 +239,14 @@ async def send_prompt_async(self, *, prompt_request: PromptRequestResponse) -> P request=request, response_text_pieces=[resp_text] ) except BadRequestError as bre: - # Handle bad request error when content filter system detects harmful content - bad_request_exception = BadRequestException(bre.status_code, message=bre.message) - resp_text = bad_request_exception.process_exception() - response_entry = self._memory.add_response_entries_to_memory( - request=request, response_text_pieces=[resp_text], response_type="error", error="blocked" - ) - except RateLimitError as rle: - # Handle the rate limit exception after exhausting the maximum number of retries. - rate_limit_exception = RateLimitException(rle.status_code, message=rle.message) - resp_text = rate_limit_exception.process_exception() - response_entry = self._memory.add_response_entries_to_memory( - request=request, response_text_pieces=[resp_text], response_type="error", error="error" + response_entry = handle_bad_request_exception( + memory=self._memory, response_text=bre.message, request=request ) - except EmptyResponseException: - # Handle the empty response exception after exhausting the maximum number of retries. - message = f"Empty response from the target even after {RETRY_MAX_NUM_ATTEMPTS} retries." 
- empty_response_exception = EmptyResponseException(message=message) - resp_text = empty_response_exception.process_exception() - response_entry = self._memory.add_response_entries_to_memory( - request=request, response_text_pieces=[resp_text], response_type="error", error="error" + except Exception as ex: + self._memory.add_response_entries_to_memory( + request=request, response_text_pieces=[str(ex)], response_type="error", error="processing" ) + raise return response_entry @@ -328,10 +315,9 @@ async def _complete_chat_async( # Handle empty response if not extracted_response: raise EmptyResponseException(message="The chat returned an empty response.") - elif finish_reason == "content_filter": - message = response.choices[0] - content_filter_exception = BadRequestException(message=str(message)) - extracted_response = content_filter_exception.process_exception() + else: + raise PyritException(message=f"Unknown finish_reason {finish_reason}") + return extracted_response def _validate_request(self, *, prompt_request: PromptRequestResponse) -> None: diff --git a/tests/exceptions/test_exceptions.py b/tests/exceptions/test_exceptions.py index d6c128e5c8..13c5d9a9ad 100644 --- a/tests/exceptions/test_exceptions.py +++ b/tests/exceptions/test_exceptions.py @@ -7,14 +7,14 @@ def test_pyrit_exception_initialization(): - ex = PyritException(500, "Internal Server Error") + ex = PyritException(500, message="Internal Server Error") assert ex.status_code == 500 assert ex.message == "Internal Server Error" assert str(ex) == "Status Code: 500, Message: Internal Server Error" def test_pyrit_exception_process_exception(caplog): - ex = PyritException(500, "Internal Server Error") + ex = PyritException(500, message="Internal Server Error") with caplog.at_level(logging.ERROR): result = ex.process_exception() assert json.loads(result) == {"status_code": 500, "message": "Internal Server Error"} diff --git a/tests/target/test_azure_openai_gptv_chat_target.py 
b/tests/target/test_azure_openai_gptv_chat_target.py index 2a742f7c02..23207a92af 100644 --- a/tests/target/test_azure_openai_gptv_chat_target.py +++ b/tests/target/test_azure_openai_gptv_chat_target.py @@ -12,6 +12,9 @@ from openai.types.chat.chat_completion import Choice from openai import BadRequestError, RateLimitError +from pyrit.exceptions.exception_classes import EmptyResponseException +from pyrit.memory.duckdb_memory import DuckDBMemory +from pyrit.memory.memory_interface import MemoryInterface from pyrit.models.prompt_request_piece import PromptRequestPiece from pyrit.models.prompt_request_response import PromptRequestResponse from pyrit.prompt_target import AzureOpenAIGPTVChatTarget @@ -28,6 +31,7 @@ def azure_gptv_chat_engine() -> AzureOpenAIGPTVChatTarget: endpoint="https://mock.azure.com/", api_key="mock-api-key", api_version="some_version", + memory=DuckDBMemory(db_path=":memory:"), ) @@ -282,16 +286,14 @@ async def test_send_prompt_async_empty_response_adds_to_memory( ): with patch("openai.resources.chat.AsyncCompletions.create", new_callable=AsyncMock) as mock_create: mock_create.return_value = azure_openai_mock_return - response: PromptRequestResponse = await azure_gptv_chat_engine.send_prompt_async( - prompt_request=prompt_req_resp - ) - - azure_gptv_chat_engine._memory.get_conversation.assert_called_once_with(conversation_id="12345679") - azure_gptv_chat_engine._memory.add_request_response_to_memory.assert_called_once_with( - request=prompt_req_resp - ) - azure_gptv_chat_engine._memory.add_response_entries_to_memory.assert_called_once() - assert response is not None, "Expected a result but got None" + with pytest.raises(EmptyResponseException) as e: + await azure_gptv_chat_engine.send_prompt_async(prompt_request=prompt_req_resp) + azure_gptv_chat_engine._memory.get_conversation.assert_called_once_with(conversation_id="12345679") + azure_gptv_chat_engine._memory.add_request_response_to_memory.assert_called_once_with( + request=prompt_req_resp + ) 
+ azure_gptv_chat_engine._memory.add_response_entries_to_memory.assert_called_once() + assert str(e.value) == "Status Code: 204, Message: The chat returned an empty response." @pytest.mark.asyncio @@ -313,11 +315,13 @@ async def test_send_prompt_async_rate_limit_exception_adds_to_memory(azure_gptv_ request_pieces=[PromptRequestPiece(role="user", conversation_id="123", original_value="Hello")] ) - result = await azure_gptv_chat_engine.send_prompt_async(prompt_request=prompt_request) - assert result is not None - azure_gptv_chat_engine._memory.get_conversation.assert_called_once_with(conversation_id="123") - azure_gptv_chat_engine._memory.add_request_response_to_memory.assert_called_once_with(request=prompt_request) - azure_gptv_chat_engine._memory.add_response_entries_to_memory.assert_called_once() + with pytest.raises(RateLimitError) as rle: + await azure_gptv_chat_engine.send_prompt_async(prompt_request=prompt_request) + azure_gptv_chat_engine._memory.get_conversation.assert_called_once_with(conversation_id="123") + azure_gptv_chat_engine._memory.add_request_response_to_memory.assert_called_once_with(request=prompt_request) + azure_gptv_chat_engine._memory.add_response_entries_to_memory.assert_called_once() + + assert str(rle.value) == "Rate Limit Reached" @pytest.mark.asyncio @@ -339,11 +343,13 @@ async def test_send_prompt_async_bad_request_error_adds_to_memory(azure_gptv_cha request_pieces=[PromptRequestPiece(role="user", conversation_id="123", original_value="Hello")] ) - result = await azure_gptv_chat_engine.send_prompt_async(prompt_request=prompt_request) - assert result is not None - azure_gptv_chat_engine._memory.get_conversation.assert_called_once_with(conversation_id="123") - azure_gptv_chat_engine._memory.add_request_response_to_memory.assert_called_once_with(request=prompt_request) - azure_gptv_chat_engine._memory.add_response_entries_to_memory.assert_called_once() + with pytest.raises(BadRequestError) as bre: + await 
azure_gptv_chat_engine.send_prompt_async(prompt_request=prompt_request) + azure_gptv_chat_engine._memory.get_conversation.assert_called_once_with(conversation_id="123") + azure_gptv_chat_engine._memory.add_request_response_to_memory.assert_called_once_with(request=prompt_request) + azure_gptv_chat_engine._memory.add_response_entries_to_memory.assert_called_once() + + assert str(bre.value) == "Bad Request" @pytest.mark.asyncio @@ -393,7 +399,7 @@ async def test_send_prompt_async( @pytest.mark.asyncio -async def test_send_prompt_async_empty_response( +async def test_send_prompt_async_empty_response_retries( azure_openai_mock_return: ChatCompletion, azure_gptv_chat_engine: AzureOpenAIGPTVChatTarget ): with NamedTemporaryFile(suffix=".jpg", delete=False) as tmp_file: @@ -433,23 +439,16 @@ async def test_send_prompt_async_empty_response( with patch("openai.resources.chat.AsyncCompletions.create", new_callable=AsyncMock) as mock_create: mock_create.return_value = azure_openai_mock_return constants.RETRY_MAX_NUM_ATTEMPTS = 5 - response: PromptRequestResponse = await azure_gptv_chat_engine.send_prompt_async( - prompt_request=prompt_req_resp - ) - assert len(response.request_pieces) == 1 - expected_error_message = ( - '{"status_code": 204, "message": "Empty response from the target even after 5 retries."}' - ) - assert response.request_pieces[0].converted_value == expected_error_message - assert response.request_pieces[0].converted_value_data_type == "error" - assert response.request_pieces[0].original_value == expected_error_message - assert response.request_pieces[0].original_value_data_type == "error" - assert str(constants.RETRY_MAX_NUM_ATTEMPTS) in response.request_pieces[0].converted_value - os.remove(tmp_file_name) + azure_gptv_chat_engine._memory = MagicMock(MemoryInterface) + + with pytest.raises(EmptyResponseException): + await azure_gptv_chat_engine.send_prompt_async(prompt_request=prompt_req_resp) + + assert mock_create.call_count == 
constants.RETRY_MAX_NUM_ATTEMPTS @pytest.mark.asyncio -async def test_send_prompt_async_rate_limit_exception(azure_gptv_chat_engine: AzureOpenAIGPTVChatTarget): +async def test_send_prompt_async_rate_limit_exception_retries(azure_gptv_chat_engine: AzureOpenAIGPTVChatTarget): response = MagicMock() response.status_code = 429 @@ -461,13 +460,9 @@ async def test_send_prompt_async_rate_limit_exception(azure_gptv_chat_engine: Az request_pieces=[PromptRequestPiece(role="user", conversation_id="12345", original_value="Hello")] ) - result = await azure_gptv_chat_engine.send_prompt_async(prompt_request=prompt_request) - assert "Rate Limit Reached" in result.request_pieces[0].converted_value - assert "Rate Limit Reached" in result.request_pieces[0].original_value - assert result.request_pieces[0].original_value_data_type == "error" - assert result.request_pieces[0].converted_value_data_type == "error" - expected_sha_256 = "7d0ed53fb1c888e3467776735ee117e328c24f1a588a5f8756ba213c9b0b84a9" - assert result.request_pieces[0].original_value_sha256 == expected_sha_256 + with pytest.raises(RateLimitError): + await azure_gptv_chat_engine.send_prompt_async(prompt_request=prompt_request) + assert mock_complete_chat_async.call_count == constants.RETRY_MAX_NUM_ATTEMPTS @pytest.mark.asyncio @@ -483,14 +478,9 @@ async def test_send_prompt_async_bad_request_error(azure_gptv_chat_engine: Azure prompt_request = PromptRequestResponse( request_pieces=[PromptRequestPiece(role="user", conversation_id="1236748", original_value="Hello")] ) - - result = await azure_gptv_chat_engine.send_prompt_async(prompt_request=prompt_request) - assert "Bad Request Error" in result.request_pieces[0].converted_value - assert "Bad Request Error" in result.request_pieces[0].original_value - assert result.request_pieces[0].original_value_data_type == "error" - assert result.request_pieces[0].converted_value_data_type == "error" - expected_sha256 = "4e98b0da48c028f090473fe5cc71461a921465f807ae66c5f7ae9d0e9f301f77" - 
assert result.request_pieces[0].original_value_sha256 == expected_sha256 + with pytest.raises(BadRequestError) as bre: + await azure_gptv_chat_engine.send_prompt_async(prompt_request=prompt_request) + assert str(bre.value) == "Bad Request Error" def test_parse_chat_completion_successful(azure_gptv_chat_engine: AzureOpenAIGPTVChatTarget):