diff --git a/pyrit/orchestrator/red_teaming_orchestrator.py b/pyrit/orchestrator/red_teaming_orchestrator.py index 6e53102796..88e17934f4 100644 --- a/pyrit/orchestrator/red_teaming_orchestrator.py +++ b/pyrit/orchestrator/red_teaming_orchestrator.py @@ -296,7 +296,7 @@ def _get_prompt_for_red_teaming_chat(self, *, feedback: str | None) -> str: logger.info(f"Using the specified initial red teaming prompt: {self._initial_red_teaming_prompt}") return self._initial_red_teaming_prompt - if last_response_from_attack_target.converted_value_data_type == "text": + if last_response_from_attack_target.converted_value_data_type in ["text", "error"]: return self._handle_text_response(last_response_from_attack_target, feedback) return self._handle_file_response(last_response_from_attack_target, feedback) diff --git a/pyrit/prompt_target/dall_e_target.py b/pyrit/prompt_target/dall_e_target.py index cecfe9fa82..0678f61fcd 100644 --- a/pyrit/prompt_target/dall_e_target.py +++ b/pyrit/prompt_target/dall_e_target.py @@ -10,10 +10,12 @@ from openai import BadRequestError from pyrit.common.path import RESULTS_PATH +from pyrit.exceptions import EmptyResponseException, pyrit_retry +from pyrit.exceptions.exception_classes import handle_bad_request_exception from pyrit.memory.memory_interface import MemoryInterface from pyrit.models import PromptRequestResponse, data_serializer_factory from pyrit.models.literals import PromptDataType -from pyrit.models.prompt_request_piece import PromptRequestPiece, PromptResponseError +from pyrit.models.prompt_request_piece import PromptRequestPiece from pyrit.prompt_target import PromptTarget from pyrit.prompt_target.prompt_chat_target.openai_chat_target import AzureOpenAIChatTarget @@ -136,42 +138,55 @@ async def _generate_images_async(self, prompt: str, request=PromptRequestPiece) image_generation_args["style"] = self.style try: - response = await self._image_target._async_client.images.generate(**image_generation_args) - json_response = json.loads(response.model_dump_json()) - + b64_data = await self._generate_image_response_async(image_generation_args) data = data_serializer_factory(data_type="image_path") - b64_data = json_response["data"][0]["b64_json"] data.save_b64_image(data=b64_data) - prompt_text = data.value - error: PromptResponseError = "none" + resp_text = data.value response_type: PromptDataType = "image_path" - except BadRequestError as e: - json_response = {"exception type": "Blocked", "data": ""} - json_response["error"] = e.body - prompt_text = "content blocked" - error = "blocked" - response_type = "text" - - except json.JSONDecodeError as e: - json_response = {"error": e, "exception type": "JSON Error"} - prompt_text = "JSON Error" - error = "processing" - response_type = "text" - - except Exception as e: - json_response = {"error": e, "exception type": "exception"} - prompt_text = "target error" - error = "unknown" - response_type = "text" - - return self._memory.add_response_entries_to_memory( - request=request, - response_text_pieces=[prompt_text], - response_type=response_type, - prompt_metadata=json.dumps(json_response), - error=error, - ) + response_entry = self._memory.add_response_entries_to_memory( + request=request, response_text_pieces=[resp_text], response_type=response_type + ) + + except BadRequestError as bre: + response_entry = handle_bad_request_exception( + memory=self._memory, response_text=bre.message, request=request + ) + + except Exception as ex: + self._memory.add_response_entries_to_memory( + request=request, response_text_pieces=[str(ex)], response_type="error", error="processing" + ) + raise + + return response_entry + + @pyrit_retry + async def _generate_image_response_async(self, image_generation_args): + """ + Asynchronously generates an image using the provided generation arguments. + + Retries the function if it raises RateLimitError (HTTP 429) or EmptyResponseException, + with a wait time between retries that follows an exponential backoff strategy. + Logs retry attempts at the INFO level and stops after a maximum number of attempts. + + Args: + image_generation_args (dict): The arguments required for image generation. + + Returns: + The generated image in base64 format. + + Raises: + RateLimitError: If the rate limit is exceeded and the maximum number of retries is exhausted. + EmptyResponseException: If the response is empty after exhausting the maximum number of retries. + """ + result = await self._image_target._async_client.images.generate(**image_generation_args) + json_response = json.loads(result.model_dump_json()) + b64_data = json_response["data"][0]["b64_json"] + # Handle empty response using retry + if not b64_data: + raise EmptyResponseException(message="The chat returned an empty response.") + return b64_data def _validate_request(self, *, prompt_request: PromptRequestResponse) -> None: if len(prompt_request.request_pieces) != 1: diff --git a/tests/target/test_dall_e_target.py b/tests/target/test_dall_e_target.py index 59f9d4c849..9260f26fc5 100644 --- a/tests/target/test_dall_e_target.py +++ b/tests/target/test_dall_e_target.py @@ -4,13 +4,16 @@ from unittest.mock import patch, MagicMock, AsyncMock import uuid import os +from pyrit.exceptions.exception_classes import EmptyResponseException import pytest +from openai import BadRequestError, RateLimitError + from pyrit.models.prompt_request_piece import PromptRequestPiece from pyrit.models import PromptRequestResponse from pyrit.prompt_target import DALLETarget - from tests.mocks import get_sample_conversations +from pyrit.common import constants @pytest.fixture @@ -54,6 +57,61 @@ async def test_send_prompt_async(mock_image, dalle_target: DALLETarget, sample_c assert resp +@pytest.mark.asyncio +async def test_send_prompt_async_empty_response( + dalle_target: DALLETarget, sample_conversations: list[PromptRequestPiece] +): + request = sample_conversations[0] + request.conversation_id = str(uuid.uuid4()) + + mock_return = MagicMock() + # make b64_json value empty to test retries when empty response was returned + mock_return.model_dump_json.return_value = '{"data": [{"b64_json": ""}]}' + setattr(dalle_target._image_target._async_client.images, "generate", AsyncMock(return_value=mock_return)) + constants.RETRY_MAX_NUM_ATTEMPTS = 5 + + with pytest.raises(EmptyResponseException) as e: + await dalle_target.send_prompt_async(prompt_request=PromptRequestResponse([request])) + assert str(e.value) == "Status Code: 204, Message: The chat returned an empty response." + + +@pytest.mark.asyncio +async def test_send_prompt_async_rate_limit_exception( + dalle_target: DALLETarget, sample_conversations: list[PromptRequestPiece] +): + request = sample_conversations[0] + request.conversation_id = str(uuid.uuid4()) + + response = MagicMock() + response.status_code = 429 + mock_image_resp_async = AsyncMock( + side_effect=RateLimitError("Rate Limit Reached", response=response, body="Rate limit reached") + ) + setattr(dalle_target, "_generate_image_response_async", mock_image_resp_async) + + with pytest.raises(RateLimitError): + await dalle_target.send_prompt_async(prompt_request=PromptRequestResponse([request])) + assert mock_image_resp_async.call_count == constants.RETRY_MAX_NUM_ATTEMPTS + + +@pytest.mark.asyncio +async def test_send_prompt_async_bad_request_error( + dalle_target: DALLETarget, sample_conversations: list[PromptRequestPiece] +): + request = sample_conversations[0] + request.conversation_id = str(uuid.uuid4()) + + response = MagicMock() + response.status_code = 400 + mock_image_resp_async = AsyncMock( + side_effect=BadRequestError("Bad Request Error", response=response, body="Bad Request") + ) + setattr(dalle_target, "_generate_image_response_async", mock_image_resp_async) + with pytest.raises(BadRequestError) as bre: + await dalle_target.send_prompt_async(prompt_request=PromptRequestResponse([request])) + assert str(bre.value) == "Bad Request Error" + + @pytest.mark.asyncio async def test_dalle_validate_request_length(dalle_target: DALLETarget, sample_conversations: list[PromptRequestPiece]): request = PromptRequestResponse(request_pieces=sample_conversations) @@ -123,3 +181,89 @@ async def test_dalle_send_prompt_adds_memory_async() -> None: await mock_dalle_target.send_prompt_async(prompt_request=request) assert mock_memory.add_request_response_to_memory.called, "Request and Response need to be added to memory" assert mock_memory.add_response_entries_to_memory.called, "Request and Response need to be added to memory" + + +@pytest.mark.asyncio +async def test_send_prompt_async_empty_response_adds_memory() -> None: + + mock_memory = MagicMock() + mock_memory.get_conversation.return_value = [] + mock_memory.add_request_response_to_memory = AsyncMock() + mock_memory.add_response_entries_to_memory = AsyncMock() + request = PromptRequestPiece( + role="user", + original_value="draw me a test picture", + ).to_prompt_request_response() + + mock_return = MagicMock() + + # b64_json with empty response + mock_return.model_dump_json.return_value = '{"data": [{"b64_json": ""}]}' + + mock_dalle_target = DALLETarget(deployment_name="test", endpoint="test", api_key="test", memory=mock_memory) + mock_dalle_target._image_target._async_client.images = MagicMock() + mock_dalle_target._image_target._async_client.images.generate = AsyncMock(return_value=mock_return) + mock_dalle_target._memory = mock_memory + with pytest.raises(EmptyResponseException) as e: + await mock_dalle_target.send_prompt_async(prompt_request=request) + mock_memory.add_response_entries_to_memory.assert_called_once() + assert str(e.value) == "Status Code: 204, Message: The chat returned an empty response." + + +@pytest.mark.asyncio +async def test_send_prompt_async_rate_limit_adds_memory() -> None: + + mock_memory = MagicMock() + mock_memory.get_conversation.return_value = [] + mock_memory.add_request_response_to_memory = AsyncMock() + mock_memory.add_response_entries_to_memory = AsyncMock() + request = PromptRequestPiece( + role="user", + original_value="draw me a test picture", + ).to_prompt_request_response() + + mock_dalle_target = DALLETarget(deployment_name="test", endpoint="test", api_key="test", memory=mock_memory) + mock_dalle_target._memory = mock_memory + + # mocking openai.RateLimitError + mock_resp = MagicMock() + mock_resp.status_code = 429 + mock_generate_image_response_async = AsyncMock( + side_effect=RateLimitError("Rate Limit Reached", response=mock_resp, body="Rate limit reached") + ) + setattr(mock_dalle_target, "_generate_image_response_async", mock_generate_image_response_async) + with pytest.raises(RateLimitError) as rle: + await mock_dalle_target.send_prompt_async(prompt_request=request) + mock_dalle_target._memory.add_request_response_to_memory.assert_called_once() + mock_dalle_target._memory.add_response_entries_to_memory.assert_called_once() + assert str(rle.value) == "Rate Limit Reached" + + +@pytest.mark.asyncio +async def test_send_prompt_async_bad_request_adds_memory() -> None: + + mock_memory = MagicMock() + mock_memory.get_conversation.return_value = [] + mock_memory.add_request_response_to_memory = AsyncMock() + mock_memory.add_response_entries_to_memory = AsyncMock() + request = PromptRequestPiece( + role="user", + original_value="draw me a test picture", + ).to_prompt_request_response() + + mock_dalle_target = DALLETarget(deployment_name="test", endpoint="test", api_key="test", memory=mock_memory) + mock_dalle_target._memory = mock_memory + + # mocking openai.BadRequestError + mock_resp = MagicMock() + mock_resp.status_code = 400 + mock_generate_image_response_async = AsyncMock( + side_effect=BadRequestError("Bad Request", response=mock_resp, body="Bad Request") + ) + + setattr(mock_dalle_target, "_generate_image_response_async", mock_generate_image_response_async) + with pytest.raises(BadRequestError) as bre: + await mock_dalle_target.send_prompt_async(prompt_request=request) + mock_dalle_target._memory.add_request_response_to_memory.assert_called_once() + mock_dalle_target._memory.add_response_entries_to_memory.assert_called_once() + assert str(bre.value) == "Bad Request"