From 0197ac972e79e8bf077d9351ce8f9124c7505f29 Mon Sep 17 00:00:00 2001 From: rdheekonda Date: Wed, 22 May 2024 18:49:03 -0700 Subject: [PATCH 1/5] Add exception handling to DALLE target --- pyrit/prompt_target/dall_e_target.py | 81 ++++++++----- tests/target/test_dall_e_target.py | 163 ++++++++++++++++++++++++++- 2 files changed, 215 insertions(+), 29 deletions(-) diff --git a/pyrit/prompt_target/dall_e_target.py b/pyrit/prompt_target/dall_e_target.py index cecfe9fa82..510d6fead2 100644 --- a/pyrit/prompt_target/dall_e_target.py +++ b/pyrit/prompt_target/dall_e_target.py @@ -7,15 +7,17 @@ import asyncio from typing import Literal, Optional, Dict, Any -from openai import BadRequestError +from openai import BadRequestError, RateLimitError from pyrit.common.path import RESULTS_PATH +from pyrit.exceptions import EmptyResponseException, BadRequestException, RateLimitException, pyrit_retry from pyrit.memory.memory_interface import MemoryInterface from pyrit.models import PromptRequestResponse, data_serializer_factory from pyrit.models.literals import PromptDataType from pyrit.models.prompt_request_piece import PromptRequestPiece, PromptResponseError from pyrit.prompt_target import PromptTarget from pyrit.prompt_target.prompt_chat_target.openai_chat_target import AzureOpenAIChatTarget +from pyrit.common.constants import RETRY_MAX_NUM_ATTEMPTS logger = logging.getLogger(__name__) @@ -136,43 +138,66 @@ async def _generate_images_async(self, prompt: str, request=PromptRequestPiece) image_generation_args["style"] = self.style try: - response = await self._image_target._async_client.images.generate(**image_generation_args) - json_response = json.loads(response.model_dump_json()) - + b64_data = await self._generate_image_response_async(image_generation_args) data = data_serializer_factory(data_type="image_path") - b64_data = json_response["data"][0]["b64_json"] data.save_b64_image(data=b64_data) - prompt_text = data.value + resp_text = data.value error: PromptResponseError = "none" response_type: PromptDataType = "image_path" - except BadRequestError as e: - json_response = {"exception type": "Blocked", "data": ""} - json_response["error"] = e.body - prompt_text = "content blocked" + except BadRequestError as bre: + # Handle bad request error when content filter system detects harmful content + bad_request_exception = BadRequestException(bre.status_code, message=bre.message) + resp_text = bad_request_exception.process_exception() error = "blocked" - response_type = "text" - - except json.JSONDecodeError as e: - json_response = {"error": e, "exception type": "JSON Error"} - prompt_text = "JSON Error" - error = "processing" - response_type = "text" - - except Exception as e: - json_response = {"error": e, "exception type": "exception"} - prompt_text = "target error" - error = "unknown" - response_type = "text" + response_type = "error" + + except RateLimitError as rle: + # Handle the rate limit exception after exhausting the maximum number of retries. + rate_limit_exception = RateLimitException(rle.status_code, message=rle.message) + resp_text = rate_limit_exception.process_exception() + error = "error" + response_type = "error" + + except EmptyResponseException: + # Handle the empty response exception after exhausting the maximum number of retries. + message = f"Empty response from the target even after {RETRY_MAX_NUM_ATTEMPTS} retries." + empty_response_exception = EmptyResponseException(message=message) + resp_text = empty_response_exception.process_exception() + error = "error" + response_type = "error" return self._memory.add_response_entries_to_memory( - request=request, - response_text_pieces=[prompt_text], - response_type=response_type, - prompt_metadata=json.dumps(json_response), - error=error, + request=request, response_text_pieces=[resp_text], response_type=response_type, error=error ) + @pyrit_retry + async def _generate_image_response_async(self, image_generation_args): + """ + Asynchronously generates an image using the provided generation arguments. + + Retries the function if it raises RateLimitError (HTTP 429) or EmptyResponseException, + with a wait time between retries that follows an exponential backoff strategy. + Logs retry attempts at the INFO level and stops after a maximum number of attempts. + + Args: + image_generation_args (dict): The arguments required for image generation. + + Returns: + The generated image in base64 format. + + Raises: + RateLimitError: If the rate limit is exceeded and the maximum number of retries is exhausted. + EmptyResponseException: If the response is empty after exhausting the maximum number of retries. + """ + result = await self._image_target._async_client.images.generate(**image_generation_args) + json_response = json.loads(result.model_dump_json()) + b64_data = json_response["data"][0]["b64_json"] + # Handle empty response using retry + if not b64_data: + raise EmptyResponseException(message="The chat returned an empty response.") + return b64_data + def _validate_request(self, *, prompt_request: PromptRequestResponse) -> None: if len(prompt_request.request_pieces) != 1: raise ValueError("This target only supports a single prompt request piece.") diff --git a/tests/target/test_dall_e_target.py b/tests/target/test_dall_e_target.py index 59f9d4c849..6e9f65b8fd 100644 --- a/tests/target/test_dall_e_target.py +++ b/tests/target/test_dall_e_target.py @@ -6,11 +6,13 @@ import os import pytest +from openai import BadRequestError, RateLimitError + from pyrit.models.prompt_request_piece import PromptRequestPiece from pyrit.models import PromptRequestResponse from pyrit.prompt_target import DALLETarget - from tests.mocks import get_sample_conversations +from pyrit.common import constants @pytest.fixture @@ -54,6 +56,80 @@ async def test_send_prompt_async(mock_image, dalle_target: DALLETarget, sample_c assert resp +@pytest.mark.asyncio +async def test_send_prompt_async_empty_response( + dalle_target: DALLETarget, sample_conversations: list[PromptRequestPiece] +): + request = sample_conversations[0] + request.conversation_id = str(uuid.uuid4()) + + mock_return = MagicMock() + # make b64_json value empty to test retries when empty response was returned + mock_return.model_dump_json.return_value = '{"data": [{"b64_json": ""}]}' + dalle_target._image_target._async_client.images.generate = AsyncMock(return_value=mock_return) + constants.RETRY_MAX_NUM_ATTEMPTS = 5 + response: PromptRequestResponse = await dalle_target.send_prompt_async( + prompt_request=PromptRequestResponse([request]) + ) + assert len(response.request_pieces) == 1 + expected_error_message = '{"status_code": 204, "message": "Empty response from the target even after 5 retries."}' + assert response.request_pieces[0].converted_value == expected_error_message + assert response.request_pieces[0].converted_value_data_type == "error" + assert response.request_pieces[0].original_value == expected_error_message + assert response.request_pieces[0].original_value_data_type == "error" + assert str(constants.RETRY_MAX_NUM_ATTEMPTS) in response.request_pieces[0].converted_value + + +@pytest.mark.asyncio +async def test_send_prompt_async_rate_limit_exception( + dalle_target: DALLETarget, sample_conversations: list[PromptRequestPiece] +): + request = sample_conversations[0] + request.conversation_id = str(uuid.uuid4()) + + response = MagicMock() + response.status_code = 429 + mock_image_resp_async = AsyncMock( + side_effect=RateLimitError("Rate Limit Reached", response=response, body="Rate limit reached") + ) + setattr(dalle_target, "_generate_image_response_async", mock_image_resp_async) + + result: PromptRequestResponse = await dalle_target.send_prompt_async( + prompt_request=PromptRequestResponse([request]) + ) + assert "Rate Limit Reached" in result.request_pieces[0].converted_value + assert "Rate Limit Reached" in result.request_pieces[0].original_value + assert result.request_pieces[0].original_value_data_type == "error" + assert result.request_pieces[0].converted_value_data_type == "error" + expected_sha_256 = "7d0ed53fb1c888e3467776735ee117e328c24f1a588a5f8756ba213c9b0b84a9" + assert result.request_pieces[0].original_value_sha256 == expected_sha_256 + + +@pytest.mark.asyncio +async def test_send_prompt_async_bad_request_error( + dalle_target: DALLETarget, sample_conversations: list[PromptRequestPiece] +): + request = sample_conversations[0] + request.conversation_id = str(uuid.uuid4()) + + response = MagicMock() + response.status_code = 400 + mock_image_resp_async = AsyncMock( + side_effect=RateLimitError("Bad Request Error", response=response, body="Bad Request") + ) + setattr(dalle_target, "_generate_image_response_async", mock_image_resp_async) + + result: PromptRequestResponse = await dalle_target.send_prompt_async( + prompt_request=PromptRequestResponse([request]) + ) + assert "Bad Request Error" in result.request_pieces[0].converted_value + assert "Bad Request Error" in result.request_pieces[0].original_value + assert result.request_pieces[0].original_value_data_type == "error" + assert result.request_pieces[0].converted_value_data_type == "error" + expected_sha256 = "4e98b0da48c028f090473fe5cc71461a921465f807ae66c5f7ae9d0e9f301f77" + assert result.request_pieces[0].original_value_sha256 == expected_sha256 + + @pytest.mark.asyncio async def test_dalle_validate_request_length(dalle_target: DALLETarget, sample_conversations: list[PromptRequestPiece]): request = PromptRequestResponse(request_pieces=sample_conversations) @@ -123,3 +199,88 @@ async def test_dalle_send_prompt_adds_memory_async() -> None: await mock_dalle_target.send_prompt_async(prompt_request=request) assert mock_memory.add_request_response_to_memory.called, "Request and Response need to be added to memory" assert mock_memory.add_response_entries_to_memory.called, "Request and Response need to be added to memory" + + +@pytest.mark.asyncio +async def test_send_prompt_async_empty_response_adds_memory() -> None: + + mock_memory = MagicMock() + mock_memory.get_conversation.return_value = [] + mock_memory.add_request_response_to_memory = AsyncMock() + mock_memory.add_response_entries_to_memory = AsyncMock() + request = PromptRequestPiece( + role="user", + original_value="draw me a test picture", + ).to_prompt_request_response() + + mock_return = MagicMock() + + # b64_json with empty response + mock_return.model_dump_json.return_value = '{"data": [{"b64_json": ""}]}' + + mock_dalle_target = DALLETarget(deployment_name="test", endpoint="test", api_key="test", memory=mock_memory) + mock_dalle_target._image_target._async_client.images = MagicMock() + mock_dalle_target._image_target._async_client.images.generate = AsyncMock(return_value=mock_return) + mock_dalle_target._memory = mock_memory + response = await mock_dalle_target.send_prompt_async(prompt_request=request) + assert response is not None, "Expected a result but got None" + mock_memory.add_response_entries_to_memory.assert_called_once(), "Request and Response need to be added to memory" + + +@pytest.mark.asyncio +async def test_send_prompt_async_rate_limit_adds_memory() -> None: + + mock_memory = MagicMock() + mock_memory.get_conversation.return_value = [] + mock_memory.add_request_response_to_memory = AsyncMock() + mock_memory.add_response_entries_to_memory = AsyncMock() + request = PromptRequestPiece( + role="user", + original_value="draw me a test picture", + ).to_prompt_request_response() + + mock_dalle_target = DALLETarget(deployment_name="test", endpoint="test", api_key="test", memory=mock_memory) + mock_dalle_target._memory = mock_memory + + # mocking openai.RateLimitError + response = MagicMock() + response.status_code = 429 + mock_generate_image_response_async = AsyncMock( + side_effect=RateLimitError("Rate Limit Reached", response=response, body="Rate limit reached") + ) + setattr(mock_dalle_target, "_generate_image_response_async", mock_generate_image_response_async) + + response = await mock_dalle_target.send_prompt_async(prompt_request=request) + assert response is not None + mock_dalle_target._memory.add_request_response_to_memory.assert_called_once() + mock_dalle_target._memory.add_response_entries_to_memory.assert_called_once() + + +@pytest.mark.asyncio +async def test_send_prompt_async_bad_request_adds_memory() -> None: + + mock_memory = MagicMock() + mock_memory.get_conversation.return_value = [] + mock_memory.add_request_response_to_memory = AsyncMock() + mock_memory.add_response_entries_to_memory = AsyncMock() + request = PromptRequestPiece( + role="user", + original_value="draw me a test picture", + ).to_prompt_request_response() + + mock_dalle_target = DALLETarget(deployment_name="test", endpoint="test", api_key="test", memory=mock_memory) + mock_dalle_target._memory = mock_memory + + # mocking openai.BadRequestError + response = MagicMock() + response.status_code = 400 + mock_generate_image_response_async = AsyncMock( + side_effect=BadRequestError("Bad Request", response=response, body="Bad Request") + ) + + setattr(mock_dalle_target, "_generate_image_response_async", mock_generate_image_response_async) + + response = await mock_dalle_target.send_prompt_async(prompt_request=request) + assert response is not None + mock_dalle_target._memory.add_request_response_to_memory.assert_called_once() + mock_dalle_target._memory.add_response_entries_to_memory.assert_called_once() From 8579caa37b82fd265dc179805d0d14d79375608a Mon Sep 17 00:00:00 2001 From: rdheekonda Date: Wed, 22 May 2024 19:21:28 -0700 Subject: [PATCH 2/5] Fix mypy errors --- tests/target/test_dall_e_target.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/tests/target/test_dall_e_target.py b/tests/target/test_dall_e_target.py index 6e9f65b8fd..ca404547da 100644 --- a/tests/target/test_dall_e_target.py +++ b/tests/target/test_dall_e_target.py @@ -66,7 +66,7 @@ async def test_send_prompt_async_empty_response( mock_return = MagicMock() # make b64_json value empty to test retries when empty response was returned mock_return.model_dump_json.return_value = '{"data": [{"b64_json": ""}]}' - dalle_target._image_target._async_client.images.generate = AsyncMock(return_value=mock_return) + setattr(dalle_target._image_target._async_client.images, "generate", AsyncMock(return_value=mock_return)) constants.RETRY_MAX_NUM_ATTEMPTS = 5 response: PromptRequestResponse = await dalle_target.send_prompt_async( prompt_request=PromptRequestResponse([request]) @@ -224,7 +224,7 @@ async def test_send_prompt_async_empty_response_adds_memory() -> None: mock_dalle_target._memory = mock_memory response = await mock_dalle_target.send_prompt_async(prompt_request=request) assert response is not None, "Expected a result but got None" - mock_memory.add_response_entries_to_memory.assert_called_once(), "Request and Response need to be added to memory" + mock_memory.add_response_entries_to_memory.assert_called_once() @pytest.mark.asyncio @@ -243,10 +243,10 @@ async def test_send_prompt_async_rate_limit_adds_memory() -> None: mock_dalle_target._memory = mock_memory # mocking openai.RateLimitError - response = MagicMock() - response.status_code = 429 + mock_resp = MagicMock() + mock_resp.status_code = 429 mock_generate_image_response_async = AsyncMock( - side_effect=RateLimitError("Rate Limit Reached", response=response, body="Rate limit reached") + side_effect=RateLimitError("Rate Limit Reached", response=mock_resp, body="Rate limit reached") ) setattr(mock_dalle_target, "_generate_image_response_async", mock_generate_image_response_async) @@ -272,10 +272,10 @@ async def test_send_prompt_async_bad_request_adds_memory() -> None: mock_dalle_target._memory = mock_memory # mocking openai.BadRequestError - response = MagicMock() - response.status_code = 400 + mock_resp = MagicMock() + mock_resp.status_code = 400 mock_generate_image_response_async = AsyncMock( - side_effect=BadRequestError("Bad Request", response=response, body="Bad Request") + side_effect=BadRequestError("Bad Request", response=mock_resp, body="Bad Request") ) setattr(mock_dalle_target, "_generate_image_response_async", mock_generate_image_response_async) From 577c5d49e752a615e09386a228767b16df0bf37f Mon Sep 17 00:00:00 2001 From: rdheekonda Date: Fri, 24 May 2024 13:08:59 -0700 Subject: [PATCH 3/5] Refactored exception handling with latest guidelines and updated tests --- .../orchestrator/red_teaming_orchestrator.py | 2 +- pyrit/prompt_target/dall_e_target.py | 53 +++++++------- tests/target/test_dall_e_target.py | 70 +++++++------------ 3 files changed, 52 insertions(+), 73 deletions(-) diff --git a/pyrit/orchestrator/red_teaming_orchestrator.py b/pyrit/orchestrator/red_teaming_orchestrator.py index 6e53102796..88e17934f4 100644 --- a/pyrit/orchestrator/red_teaming_orchestrator.py +++ b/pyrit/orchestrator/red_teaming_orchestrator.py @@ -296,7 +296,7 @@ def _get_prompt_for_red_teaming_chat(self, *, feedback: str | None) -> str: logger.info(f"Using the specified initial red teaming prompt: {self._initial_red_teaming_prompt}") return self._initial_red_teaming_prompt - if last_response_from_attack_target.converted_value_data_type == "text": + if last_response_from_attack_target.converted_value_data_type in ["text", "error"]: return self._handle_text_response(last_response_from_attack_target, feedback) return self._handle_file_response(last_response_from_attack_target, feedback) diff --git a/pyrit/prompt_target/dall_e_target.py b/pyrit/prompt_target/dall_e_target.py index 510d6fead2..7249673652 100644 --- a/pyrit/prompt_target/dall_e_target.py +++ b/pyrit/prompt_target/dall_e_target.py @@ -7,17 +7,17 @@ import asyncio from typing import Literal, Optional, Dict, Any -from openai import BadRequestError, RateLimitError +from openai import BadRequestError from pyrit.common.path import RESULTS_PATH -from pyrit.exceptions import EmptyResponseException, BadRequestException, RateLimitException, pyrit_retry +from pyrit.exceptions import EmptyResponseException, pyrit_retry +from pyrit.exceptions.exception_classes import handle_bad_request_exception from pyrit.memory.memory_interface import MemoryInterface from pyrit.models import PromptRequestResponse, data_serializer_factory from pyrit.models.literals import PromptDataType -from pyrit.models.prompt_request_piece import PromptRequestPiece, PromptResponseError +from pyrit.models.prompt_request_piece import PromptRequestPiece from pyrit.prompt_target import PromptTarget from pyrit.prompt_target.prompt_chat_target.openai_chat_target import AzureOpenAIChatTarget -from pyrit.common.constants import RETRY_MAX_NUM_ATTEMPTS logger = logging.getLogger(__name__) @@ -142,34 +142,29 @@ async def _generate_images_async(self, prompt: str, request=PromptRequestPiece) data = data_serializer_factory(data_type="image_path") data.save_b64_image(data=b64_data) resp_text = data.value - error: PromptResponseError = "none" response_type: PromptDataType = "image_path" + response_entry = self._memory.add_response_entries_to_memory( + request=request, response_text_pieces=[resp_text], response_type=response_type + ) + except BadRequestError as bre: - # Handle bad request error when content filter system detects harmful content - bad_request_exception = BadRequestException(bre.status_code, message=bre.message) - resp_text = bad_request_exception.process_exception() - error = "blocked" - response_type = "error" - - except RateLimitError as rle: - # Handle the rate limit exception after exhausting the maximum number of retries. - rate_limit_exception = RateLimitException(rle.status_code, message=rle.message) - resp_text = rate_limit_exception.process_exception() - error = "error" - response_type = "error" - - except EmptyResponseException: - # Handle the empty response exception after exhausting the maximum number of retries. - message = f"Empty response from the target even after {RETRY_MAX_NUM_ATTEMPTS} retries." - empty_response_exception = EmptyResponseException(message=message) - resp_text = empty_response_exception.process_exception() - error = "error" - response_type = "error" - - return self._memory.add_response_entries_to_memory( - request=request, response_text_pieces=[resp_text], response_type=response_type, error=error - ) + # # Handle bad request error when content filter system detects harmful content + # bad_request_exception = BadRequestException(bre.status_code, message=bre.message) + # resp_text = bad_request_exception.process_exception() + # error = "blocked" + # response_type = "error" + response_entry = handle_bad_request_exception( + memory=self._memory, response_text=bre.message, request=request + ) + + except Exception as ex: + self._memory.add_response_entries_to_memory( + request=request, response_text_pieces=[str(ex)], response_type="error", error="processing" + ) + raise + + return response_entry @pyrit_retry async def _generate_image_response_async(self, image_generation_args): diff --git a/tests/target/test_dall_e_target.py b/tests/target/test_dall_e_target.py index ca404547da..f7c90fd4f3 100644 --- a/tests/target/test_dall_e_target.py +++ b/tests/target/test_dall_e_target.py @@ -4,6 +4,7 @@ from unittest.mock import patch, MagicMock, AsyncMock import uuid import os +from pyrit.exceptions.exception_classes import EmptyResponseException import pytest from openai import BadRequestError, RateLimitError @@ -68,16 +69,10 @@ async def test_send_prompt_async_empty_response( mock_return.model_dump_json.return_value = '{"data": [{"b64_json": ""}]}' setattr(dalle_target._image_target._async_client.images, "generate", AsyncMock(return_value=mock_return)) constants.RETRY_MAX_NUM_ATTEMPTS = 5 - response: PromptRequestResponse = await dalle_target.send_prompt_async( - prompt_request=PromptRequestResponse([request]) - ) - assert len(response.request_pieces) == 1 - expected_error_message = '{"status_code": 204, "message": "Empty response from the target even after 5 retries."}' - assert response.request_pieces[0].converted_value == expected_error_message - assert response.request_pieces[0].converted_value_data_type == "error" - assert response.request_pieces[0].original_value == expected_error_message - assert response.request_pieces[0].original_value_data_type == "error" - assert str(constants.RETRY_MAX_NUM_ATTEMPTS) in response.request_pieces[0].converted_value + + with pytest.raises(EmptyResponseException) as e: + await dalle_target.send_prompt_async(prompt_request=PromptRequestResponse([request])) + assert str(e.value) == "Status Code: 204, Message: The chat returned an empty response." @pytest.mark.asyncio @@ -94,15 +89,9 @@ async def test_send_prompt_async_rate_limit_exception( ) setattr(dalle_target, "_generate_image_response_async", mock_image_resp_async) - result: PromptRequestResponse = await dalle_target.send_prompt_async( - prompt_request=PromptRequestResponse([request]) - ) - assert "Rate Limit Reached" in result.request_pieces[0].converted_value - assert "Rate Limit Reached" in result.request_pieces[0].original_value - assert result.request_pieces[0].original_value_data_type == "error" - assert result.request_pieces[0].converted_value_data_type == "error" - expected_sha_256 = "7d0ed53fb1c888e3467776735ee117e328c24f1a588a5f8756ba213c9b0b84a9" - assert result.request_pieces[0].original_value_sha256 == expected_sha_256 + with pytest.raises(RateLimitError): + await dalle_target.send_prompt_async(prompt_request=PromptRequestResponse([request])) + assert mock_image_resp_async.call_count == constants.RETRY_MAX_NUM_ATTEMPTS @pytest.mark.asyncio @@ -115,19 +104,12 @@ async def test_send_prompt_async_bad_request_error( response = MagicMock() response.status_code = 400 mock_image_resp_async = AsyncMock( - side_effect=RateLimitError("Bad Request Error", response=response, body="Bad Request") + side_effect=BadRequestError("Bad Request Error", response=response, body="Bad Request") ) setattr(dalle_target, "_generate_image_response_async", mock_image_resp_async) - - result: PromptRequestResponse = await dalle_target.send_prompt_async( - prompt_request=PromptRequestResponse([request]) - ) - assert "Bad Request Error" in result.request_pieces[0].converted_value - assert "Bad Request Error" in result.request_pieces[0].original_value - assert result.request_pieces[0].original_value_data_type == "error" - assert result.request_pieces[0].converted_value_data_type == "error" - expected_sha256 = "4e98b0da48c028f090473fe5cc71461a921465f807ae66c5f7ae9d0e9f301f77" - assert result.request_pieces[0].original_value_sha256 == expected_sha256 + with pytest.raises(BadRequestError) as bre: + await dalle_target.send_prompt_async(prompt_request=PromptRequestResponse([request])) + assert str(bre.value) == "Bad Request Error" @pytest.mark.asyncio @@ -222,9 +204,11 @@ async def test_send_prompt_async_empty_response_adds_memory() -> None: mock_dalle_target._image_target._async_client.images = MagicMock() mock_dalle_target._image_target._async_client.images.generate = AsyncMock(return_value=mock_return) mock_dalle_target._memory = mock_memory - response = await mock_dalle_target.send_prompt_async(prompt_request=request) - assert response is not None, "Expected a result but got None" - mock_memory.add_response_entries_to_memory.assert_called_once() + with pytest.raises(EmptyResponseException) as e: + await mock_dalle_target.send_prompt_async(prompt_request=request) + mock_memory.add_response_entries_to_memory.assert_called_once() + assert str(e.value) == "Status Code: 204, Message: The chat returned an empty response." + # mock_memory.add_response_entries_to_memory.assert_called_once() @pytest.mark.asyncio @@ -249,11 +233,11 @@ async def test_send_prompt_async_rate_limit_adds_memory() -> None: side_effect=RateLimitError("Rate Limit Reached", response=mock_resp, body="Rate limit reached") ) setattr(mock_dalle_target, "_generate_image_response_async", mock_generate_image_response_async) - - response = await mock_dalle_target.send_prompt_async(prompt_request=request) - assert response is not None - mock_dalle_target._memory.add_request_response_to_memory.assert_called_once() - mock_dalle_target._memory.add_response_entries_to_memory.assert_called_once() + with pytest.raises(RateLimitError) as rle: + await mock_dalle_target.send_prompt_async(prompt_request=request) + mock_dalle_target._memory.add_request_response_to_memory.assert_called_once() + mock_dalle_target._memory.add_response_entries_to_memory.assert_called_once() + assert str(rle.value) == "Rate Limit Reached" @pytest.mark.asyncio @@ -279,8 +263,8 @@ async def test_send_prompt_async_bad_request_adds_memory() -> None: ) setattr(mock_dalle_target, "_generate_image_response_async", mock_generate_image_response_async) - - response = await mock_dalle_target.send_prompt_async(prompt_request=request) - assert response is not None - mock_dalle_target._memory.add_request_response_to_memory.assert_called_once() - mock_dalle_target._memory.add_response_entries_to_memory.assert_called_once() + with pytest.raises(BadRequestError) as bre: + await mock_dalle_target.send_prompt_async(prompt_request=request) + mock_dalle_target._memory.add_request_response_to_memory.assert_called_once() + mock_dalle_target._memory.add_response_entries_to_memory.assert_called_once() + assert str(bre.value) == "Bad Request" From a98223111244e53469c02b7fa962411d63b48977 Mon Sep 17 00:00:00 2001 From: rdheekonda Date: Fri, 24 May 2024 13:10:28 -0700 Subject: [PATCH 4/5] Removed commented lines --- pyrit/prompt_target/dall_e_target.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/pyrit/prompt_target/dall_e_target.py b/pyrit/prompt_target/dall_e_target.py index 7249673652..0678f61fcd 100644 --- a/pyrit/prompt_target/dall_e_target.py +++ b/pyrit/prompt_target/dall_e_target.py @@ -149,11 +149,6 @@ async def _generate_images_async(self, prompt: str, request=PromptRequestPiece) ) except BadRequestError as bre: - # # Handle bad request error when content filter system detects harmful content - # bad_request_exception = BadRequestException(bre.status_code, message=bre.message) - # resp_text = bad_request_exception.process_exception() - # error = "blocked" - # response_type = "error" response_entry = handle_bad_request_exception( memory=self._memory, response_text=bre.message, request=request ) From cb19ee6403e5bc5b2da1ab01bf04eb7397e06eb3 Mon Sep 17 00:00:00 2001 From: rdheekonda Date: Fri, 24 May 2024 13:13:53 -0700 Subject: [PATCH 5/5] Remove commented line --- tests/target/test_dall_e_target.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/target/test_dall_e_target.py b/tests/target/test_dall_e_target.py index f7c90fd4f3..9260f26fc5 100644 --- a/tests/target/test_dall_e_target.py +++ b/tests/target/test_dall_e_target.py @@ -208,7 +208,6 @@ async def test_send_prompt_async_empty_response_adds_memory() -> None: await mock_dalle_target.send_prompt_async(prompt_request=request) mock_memory.add_response_entries_to_memory.assert_called_once() assert str(e.value) == "Status Code: 204, Message: The chat returned an empty response." - # mock_memory.add_response_entries_to_memory.assert_called_once() @pytest.mark.asyncio