From 2634dbe412c112c863cfaec8cb62a16c784d2498 Mon Sep 17 00:00:00 2001 From: rlundeen2 Date: Wed, 22 May 2024 18:43:42 -0700 Subject: [PATCH 1/3] MAINT: Updating GPT-V to use new guidance --- pyrit/exceptions/exception_classes.py | 23 +++++++++++ pyrit/models/literals.py | 3 +- .../azure_openai_gptv_chat_target.py | 35 +++++----------- .../test_azure_openai_gptv_chat_target.py | 40 +++++++++---------- 4 files changed, 53 insertions(+), 48 deletions(-) diff --git a/pyrit/exceptions/exception_classes.py b/pyrit/exceptions/exception_classes.py index 7fc2ae92f5..b2155c0b93 100644 --- a/pyrit/exceptions/exception_classes.py +++ b/pyrit/exceptions/exception_classes.py @@ -10,6 +10,8 @@ from typing import Callable from pyrit.common.constants import RETRY_WAIT_MIN_SECONDS, RETRY_WAIT_MAX_SECONDS, RETRY_MAX_NUM_ATTEMPTS +from pyrit.memory.memory_interface import MemoryInterface +from pyrit.models.prompt_request_piece import PromptRequestPiece logger = logging.getLogger(__name__) @@ -52,6 +54,27 @@ def __init__(self, status_code: int = 204, *, message: str = "No Content"): super().__init__(status_code=status_code, message=message) +def handle_bad_request_exception( + memory: MemoryInterface, + response_text: str, + request: PromptRequestPiece + ) -> PromptRequestPiece: + + if "content_filter" in response_text: + # Handle bad request error when content filter system detects harmful content + bad_request_exception = BadRequestException(400, message=response_text) + resp_text = bad_request_exception.process_exception() + response_entry = memory.add_response_entries_to_memory( + request=request, response_text_pieces=[resp_text], response_type="error", error="blocked" + ) + else: + memory.add_response_entries_to_memory( + request=request, response_text_pieces=[resp_text], response_type="error", error="processing" + ) + raise + + return response_entry + def pyrit_retry(func: Callable) -> Callable: """ A decorator to apply retry logic with exponential backoff to a function. 
diff --git a/pyrit/models/literals.py b/pyrit/models/literals.py index eb1b5d660e..ede9a13276 100644 --- a/pyrit/models/literals.py +++ b/pyrit/models/literals.py @@ -9,10 +9,9 @@ """ The type of the error in the prompt response blocked: blocked by an external filter e.g. Azure Filters -model: the model refused to answer or request e.g. "I'm sorry..." processing: there is an exception thrown unrelated to the query unknown: the type of error is unknown """ -PromptResponseError = Literal["none", "blocked", "error", "model", "processing", "unknown"] +PromptResponseError = Literal["blocked", "processing", "unknown"] ChatMessageRole = Literal["system", "user", "assistant"] diff --git a/pyrit/prompt_target/prompt_chat_target/azure_openai_gptv_chat_target.py b/pyrit/prompt_target/prompt_chat_target/azure_openai_gptv_chat_target.py index 3f9a4c859e..25b597639d 100644 --- a/pyrit/prompt_target/prompt_chat_target/azure_openai_gptv_chat_target.py +++ b/pyrit/prompt_target/prompt_chat_target/azure_openai_gptv_chat_target.py @@ -13,6 +13,7 @@ from pyrit.auth.azure_auth import get_token_provider_from_default_azure_credential from pyrit.common import default_values +from pyrit.exceptions.exception_classes import PyritException, handle_bad_request_exception from pyrit.memory import MemoryInterface from pyrit.models import ChatMessageListContent from pyrit.models import PromptRequestResponse, PromptRequestPiece @@ -239,27 +240,14 @@ async def send_prompt_async(self, *, prompt_request: PromptRequestResponse) -> P request=request, response_text_pieces=[resp_text] ) except BadRequestError as bre: - # Handle bad request error when content filter system detects harmful content - bad_request_exception = BadRequestException(bre.status_code, message=bre.message) - resp_text = bad_request_exception.process_exception() - response_entry = self._memory.add_response_entries_to_memory( - request=request, response_text_pieces=[resp_text], response_type="error", error="blocked" - ) - except 
RateLimitError as rle: - # Handle the rate limit exception after exhausting the maximum number of retries. - rate_limit_exception = RateLimitException(rle.status_code, message=rle.message) - resp_text = rate_limit_exception.process_exception() - response_entry = self._memory.add_response_entries_to_memory( - request=request, response_text_pieces=[resp_text], response_type="error", error="error" - ) - except EmptyResponseException: - # Handle the empty response exception after exhausting the maximum number of retries. - message = f"Empty response from the target even after {RETRY_MAX_NUM_ATTEMPTS} retries." - empty_response_exception = EmptyResponseException(message=message) - resp_text = empty_response_exception.process_exception() - response_entry = self._memory.add_response_entries_to_memory( - request=request, response_text_pieces=[resp_text], response_type="error", error="error" + response_entry = handle_bad_request_exception( + memory=self._memory, response_text=bre.message, request=request ) + except Exception as ex: + self._memory.add_response_entries_to_memory( + request=request, response_text_pieces=[ex], response_type="error", error="processing" + ) + raise return response_entry @@ -328,10 +316,9 @@ async def _complete_chat_async( # Handle empty response if not extracted_response: raise EmptyResponseException(message="The chat returned an empty response.") - elif finish_reason == "content_filter": - message = response.choices[0] - content_filter_exception = BadRequestException(message=str(message)) - extracted_response = content_filter_exception.process_exception() + else: + raise PyritException(f"Unknown finish_reason {finish_reason}") + return extracted_response def _validate_request(self, *, prompt_request: PromptRequestResponse) -> None: diff --git a/tests/target/test_azure_openai_gptv_chat_target.py b/tests/target/test_azure_openai_gptv_chat_target.py index 2a742f7c02..710e5dae70 100644 --- a/tests/target/test_azure_openai_gptv_chat_target.py +++ 
b/tests/target/test_azure_openai_gptv_chat_target.py @@ -12,6 +12,9 @@ from openai.types.chat.chat_completion import Choice from openai import BadRequestError, RateLimitError +from pyrit.exceptions.exception_classes import EmptyResponseException +from pyrit.memory.duckdb_memory import DuckDBMemory +from pyrit.memory.memory_interface import MemoryInterface from pyrit.models.prompt_request_piece import PromptRequestPiece from pyrit.models.prompt_request_response import PromptRequestResponse from pyrit.prompt_target import AzureOpenAIGPTVChatTarget @@ -28,6 +31,7 @@ def azure_gptv_chat_engine() -> AzureOpenAIGPTVChatTarget: endpoint="https://mock.azure.com/", api_key="mock-api-key", api_version="some_version", + memory=DuckDBMemory(db_path=":memory:"), ) @@ -393,7 +397,7 @@ async def test_send_prompt_async( @pytest.mark.asyncio -async def test_send_prompt_async_empty_response( +async def test_send_prompt_async_empty_response_retries( azure_openai_mock_return: ChatCompletion, azure_gptv_chat_engine: AzureOpenAIGPTVChatTarget ): with NamedTemporaryFile(suffix=".jpg", delete=False) as tmp_file: @@ -433,23 +437,18 @@ async def test_send_prompt_async_empty_response( with patch("openai.resources.chat.AsyncCompletions.create", new_callable=AsyncMock) as mock_create: mock_create.return_value = azure_openai_mock_return constants.RETRY_MAX_NUM_ATTEMPTS = 5 - response: PromptRequestResponse = await azure_gptv_chat_engine.send_prompt_async( - prompt_request=prompt_req_resp - ) - assert len(response.request_pieces) == 1 - expected_error_message = ( - '{"status_code": 204, "message": "Empty response from the target even after 5 retries."}' - ) - assert response.request_pieces[0].converted_value == expected_error_message - assert response.request_pieces[0].converted_value_data_type == "error" - assert response.request_pieces[0].original_value == expected_error_message - assert response.request_pieces[0].original_value_data_type == "error" - assert 
str(constants.RETRY_MAX_NUM_ATTEMPTS) in response.request_pieces[0].converted_value - os.remove(tmp_file_name) + azure_gptv_chat_engine._memory = MagicMock(MemoryInterface) + + with pytest.raises(EmptyResponseException): + await azure_gptv_chat_engine.send_prompt_async( + prompt_request=prompt_req_resp + ) + + assert mock_create.call_count == constants.RETRY_MAX_NUM_ATTEMPTS @pytest.mark.asyncio -async def test_send_prompt_async_rate_limit_exception(azure_gptv_chat_engine: AzureOpenAIGPTVChatTarget): +async def test_send_prompt_async_rate_limit_exception_retries(azure_gptv_chat_engine: AzureOpenAIGPTVChatTarget): response = MagicMock() response.status_code = 429 @@ -461,13 +460,10 @@ async def test_send_prompt_async_rate_limit_exception(azure_gptv_chat_engine: Az request_pieces=[PromptRequestPiece(role="user", conversation_id="12345", original_value="Hello")] ) - result = await azure_gptv_chat_engine.send_prompt_async(prompt_request=prompt_request) - assert "Rate Limit Reached" in result.request_pieces[0].converted_value - assert "Rate Limit Reached" in result.request_pieces[0].original_value - assert result.request_pieces[0].original_value_data_type == "error" - assert result.request_pieces[0].converted_value_data_type == "error" - expected_sha_256 = "7d0ed53fb1c888e3467776735ee117e328c24f1a588a5f8756ba213c9b0b84a9" - assert result.request_pieces[0].original_value_sha256 == expected_sha_256 + with pytest.raises(RateLimitError): + await azure_gptv_chat_engine.send_prompt_async(prompt_request=prompt_request) + + assert mock_complete_chat_async.call_count == constants.RETRY_MAX_NUM_ATTEMPTS @pytest.mark.asyncio From 326b5608554351fba714936b8abcf67a1fc1c0cc Mon Sep 17 00:00:00 2001 From: rdheekonda Date: Thu, 23 May 2024 10:24:12 -0700 Subject: [PATCH 2/3] Fix literals to include none type, fix unit tests, add unit tests for literals --- pyrit/exceptions/exception_classes.py | 17 +++--- pyrit/models/literals.py | 3 +- .../azure_openai_gptv_chat_target.py | 13 ++-- 
tests/exceptions/test_exceptions.py | 4 +- .../test_azure_openai_gptv_chat_target.py | 60 +++++++++---------- 5 files changed, 46 insertions(+), 51 deletions(-) diff --git a/pyrit/exceptions/exception_classes.py b/pyrit/exceptions/exception_classes.py index b2155c0b93..8e1ff49f8d 100644 --- a/pyrit/exceptions/exception_classes.py +++ b/pyrit/exceptions/exception_classes.py @@ -12,13 +12,15 @@ from pyrit.common.constants import RETRY_WAIT_MIN_SECONDS, RETRY_WAIT_MAX_SECONDS, RETRY_MAX_NUM_ATTEMPTS from pyrit.memory.memory_interface import MemoryInterface from pyrit.models.prompt_request_piece import PromptRequestPiece +from pyrit.models.prompt_request_response import PromptRequestResponse + logger = logging.getLogger(__name__) class PyritException(Exception, ABC): - def __init__(self, status_code, message): + def __init__(self, status_code=500, *, message: str="An error occured"): self.status_code = status_code self.message = message super().__init__(f"Status Code: {status_code}, Message: {message}") @@ -37,14 +39,14 @@ class BadRequestException(PyritException): """Exception class for bad client requests.""" def __init__(self, status_code: int = 400, *, message: str = "Bad Request"): - super().__init__(status_code, message) + super().__init__(status_code, message=message) class RateLimitException(PyritException): """Exception class for authentication errors.""" def __init__(self, status_code: int = 429, *, message: str = "Rate Limit Exception"): - super().__init__(status_code, message) + super().__init__(status_code, message=message) class EmptyResponseException(BadRequestException): @@ -55,10 +57,8 @@ def __init__(self, status_code: int = 204, *, message: str = "No Content"): def handle_bad_request_exception( - memory: MemoryInterface, - response_text: str, - request: PromptRequestPiece - ) -> PromptRequestPiece: + memory: MemoryInterface, response_text: str, request: PromptRequestPiece +) -> PromptRequestResponse: if "content_filter" in response_text: # Handle bad 
request error when content filter system detects harmful content @@ -69,12 +69,13 @@ def handle_bad_request_exception( ) else: memory.add_response_entries_to_memory( - request=request, response_text_pieces=[resp_text], response_type="error", error="processing" + request=request, response_text_pieces=[response_text], response_type="error", error="processing" ) raise return response_entry + def pyrit_retry(func: Callable) -> Callable: """ A decorator to apply retry logic with exponential backoff to a function. diff --git a/pyrit/models/literals.py b/pyrit/models/literals.py index ede9a13276..6d2506009d 100644 --- a/pyrit/models/literals.py +++ b/pyrit/models/literals.py @@ -9,9 +9,10 @@ """ The type of the error in the prompt response blocked: blocked by an external filter e.g. Azure Filters +none: no exception is raised processing: there is an exception thrown unrelated to the query unknown: the type of error is unknown """ -PromptResponseError = Literal["blocked", "processing", "unknown"] +PromptResponseError = Literal["blocked", "none", "processing", "unknown"] ChatMessageRole = Literal["system", "user", "assistant"] diff --git a/pyrit/prompt_target/prompt_chat_target/azure_openai_gptv_chat_target.py b/pyrit/prompt_target/prompt_chat_target/azure_openai_gptv_chat_target.py index 25b597639d..eae981ef6a 100644 --- a/pyrit/prompt_target/prompt_chat_target/azure_openai_gptv_chat_target.py +++ b/pyrit/prompt_target/prompt_chat_target/azure_openai_gptv_chat_target.py @@ -7,7 +7,7 @@ import json from openai import AsyncAzureOpenAI -from openai import BadRequestError, RateLimitError +from openai import BadRequestError from openai.types.chat import ChatCompletion @@ -19,8 +19,7 @@ from pyrit.models import PromptRequestResponse, PromptRequestPiece from pyrit.models.data_type_serializer import data_serializer_factory, DataTypeSerializer from pyrit.prompt_target import PromptChatTarget -from pyrit.exceptions import EmptyResponseException, BadRequestException, 
RateLimitException, pyrit_retry -from pyrit.common.constants import RETRY_MAX_NUM_ATTEMPTS +from pyrit.exceptions import EmptyResponseException, pyrit_retry logger = logging.getLogger(__name__) @@ -245,8 +244,8 @@ async def send_prompt_async(self, *, prompt_request: PromptRequestResponse) -> P ) except Exception as ex: self._memory.add_response_entries_to_memory( - request=request, response_text_pieces=[ex], response_type="error", error="processing" - ) + request=request, response_text_pieces=[str(ex)], response_type="error", error="processing" + ) raise return response_entry @@ -317,8 +316,8 @@ async def _complete_chat_async( if not extracted_response: raise EmptyResponseException(message="The chat returned an empty response.") else: - raise PyritException(f"Unknown finish_reason {finish_reason}") - + raise PyritException(message=f"Unknown finish_reason {finish_reason}") + return extracted_response def _validate_request(self, *, prompt_request: PromptRequestResponse) -> None: diff --git a/tests/exceptions/test_exceptions.py b/tests/exceptions/test_exceptions.py index d6c128e5c8..13c5d9a9ad 100644 --- a/tests/exceptions/test_exceptions.py +++ b/tests/exceptions/test_exceptions.py @@ -7,14 +7,14 @@ def test_pyrit_exception_initialization(): - ex = PyritException(500, "Internal Server Error") + ex = PyritException(500, message="Internal Server Error") assert ex.status_code == 500 assert ex.message == "Internal Server Error" assert str(ex) == "Status Code: 500, Message: Internal Server Error" def test_pyrit_exception_process_exception(caplog): - ex = PyritException(500, "Internal Server Error") + ex = PyritException(500, message="Internal Server Error") with caplog.at_level(logging.ERROR): result = ex.process_exception() assert json.loads(result) == {"status_code": 500, "message": "Internal Server Error"} diff --git a/tests/target/test_azure_openai_gptv_chat_target.py b/tests/target/test_azure_openai_gptv_chat_target.py index 710e5dae70..23207a92af 100644 --- 
a/tests/target/test_azure_openai_gptv_chat_target.py +++ b/tests/target/test_azure_openai_gptv_chat_target.py @@ -286,16 +286,14 @@ async def test_send_prompt_async_empty_response_adds_to_memory( ): with patch("openai.resources.chat.AsyncCompletions.create", new_callable=AsyncMock) as mock_create: mock_create.return_value = azure_openai_mock_return - response: PromptRequestResponse = await azure_gptv_chat_engine.send_prompt_async( - prompt_request=prompt_req_resp - ) - - azure_gptv_chat_engine._memory.get_conversation.assert_called_once_with(conversation_id="12345679") - azure_gptv_chat_engine._memory.add_request_response_to_memory.assert_called_once_with( - request=prompt_req_resp - ) - azure_gptv_chat_engine._memory.add_response_entries_to_memory.assert_called_once() - assert response is not None, "Expected a result but got None" + with pytest.raises(EmptyResponseException) as e: + await azure_gptv_chat_engine.send_prompt_async(prompt_request=prompt_req_resp) + azure_gptv_chat_engine._memory.get_conversation.assert_called_once_with(conversation_id="12345679") + azure_gptv_chat_engine._memory.add_request_response_to_memory.assert_called_once_with( + request=prompt_req_resp + ) + azure_gptv_chat_engine._memory.add_response_entries_to_memory.assert_called_once() + assert str(e.value) == "Status Code: 204, Message: The chat returned an empty response." 
@pytest.mark.asyncio @@ -317,11 +315,13 @@ async def test_send_prompt_async_rate_limit_exception_adds_to_memory(azure_gptv_ request_pieces=[PromptRequestPiece(role="user", conversation_id="123", original_value="Hello")] ) - result = await azure_gptv_chat_engine.send_prompt_async(prompt_request=prompt_request) - assert result is not None - azure_gptv_chat_engine._memory.get_conversation.assert_called_once_with(conversation_id="123") - azure_gptv_chat_engine._memory.add_request_response_to_memory.assert_called_once_with(request=prompt_request) - azure_gptv_chat_engine._memory.add_response_entries_to_memory.assert_called_once() + with pytest.raises(RateLimitError) as rle: + await azure_gptv_chat_engine.send_prompt_async(prompt_request=prompt_request) + azure_gptv_chat_engine._memory.get_conversation.assert_called_once_with(conversation_id="123") + azure_gptv_chat_engine._memory.add_request_response_to_memory.assert_called_once_with(request=prompt_request) + azure_gptv_chat_engine._memory.add_response_entries_to_memory.assert_called_once() + + assert str(rle.value) == "Rate Limit Reached" @pytest.mark.asyncio @@ -343,11 +343,13 @@ async def test_send_prompt_async_bad_request_error_adds_to_memory(azure_gptv_cha request_pieces=[PromptRequestPiece(role="user", conversation_id="123", original_value="Hello")] ) - result = await azure_gptv_chat_engine.send_prompt_async(prompt_request=prompt_request) - assert result is not None - azure_gptv_chat_engine._memory.get_conversation.assert_called_once_with(conversation_id="123") - azure_gptv_chat_engine._memory.add_request_response_to_memory.assert_called_once_with(request=prompt_request) - azure_gptv_chat_engine._memory.add_response_entries_to_memory.assert_called_once() + with pytest.raises(BadRequestError) as bre: + await azure_gptv_chat_engine.send_prompt_async(prompt_request=prompt_request) + azure_gptv_chat_engine._memory.get_conversation.assert_called_once_with(conversation_id="123") + 
azure_gptv_chat_engine._memory.add_request_response_to_memory.assert_called_once_with(request=prompt_request) + azure_gptv_chat_engine._memory.add_response_entries_to_memory.assert_called_once() + + assert str(bre.value) == "Bad Request" @pytest.mark.asyncio @@ -440,9 +442,7 @@ async def test_send_prompt_async_empty_response_retries( azure_gptv_chat_engine._memory = MagicMock(MemoryInterface) with pytest.raises(EmptyResponseException): - await azure_gptv_chat_engine.send_prompt_async( - prompt_request=prompt_req_resp - ) + await azure_gptv_chat_engine.send_prompt_async(prompt_request=prompt_req_resp) assert mock_create.call_count == constants.RETRY_MAX_NUM_ATTEMPTS @@ -462,8 +462,7 @@ async def test_send_prompt_async_rate_limit_exception_retries(azure_gptv_chat_en with pytest.raises(RateLimitError): await azure_gptv_chat_engine.send_prompt_async(prompt_request=prompt_request) - - assert mock_complete_chat_async.call_count == constants.RETRY_MAX_NUM_ATTEMPTS + assert mock_complete_chat_async.call_count == constants.RETRY_MAX_NUM_ATTEMPTS @pytest.mark.asyncio @@ -479,14 +478,9 @@ async def test_send_prompt_async_bad_request_error(azure_gptv_chat_engine: Azure prompt_request = PromptRequestResponse( request_pieces=[PromptRequestPiece(role="user", conversation_id="1236748", original_value="Hello")] ) - - result = await azure_gptv_chat_engine.send_prompt_async(prompt_request=prompt_request) - assert "Bad Request Error" in result.request_pieces[0].converted_value - assert "Bad Request Error" in result.request_pieces[0].original_value - assert result.request_pieces[0].original_value_data_type == "error" - assert result.request_pieces[0].converted_value_data_type == "error" - expected_sha256 = "4e98b0da48c028f090473fe5cc71461a921465f807ae66c5f7ae9d0e9f301f77" - assert result.request_pieces[0].original_value_sha256 == expected_sha256 + with pytest.raises(BadRequestError) as bre: + await azure_gptv_chat_engine.send_prompt_async(prompt_request=prompt_request) + assert 
str(bre.value) == "Bad Request Error" def test_parse_chat_completion_successful(azure_gptv_chat_engine: AzureOpenAIGPTVChatTarget): From fd387fc1f1562e57d507d3280ec1fd2d4f1894c9 Mon Sep 17 00:00:00 2001 From: rdheekonda Date: Thu, 23 May 2024 10:31:19 -0700 Subject: [PATCH 3/3] fix mypy --- pyrit/exceptions/exception_classes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyrit/exceptions/exception_classes.py b/pyrit/exceptions/exception_classes.py index 8e1ff49f8d..f5b4509fc2 100644 --- a/pyrit/exceptions/exception_classes.py +++ b/pyrit/exceptions/exception_classes.py @@ -20,7 +20,7 @@ class PyritException(Exception, ABC): - def __init__(self, status_code=500, *, message: str="An error occured"): + def __init__(self, status_code=500, *, message: str = "An error occurred"): self.status_code = status_code self.message = message super().__init__(f"Status Code: {status_code}, Message: {message}")