Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyrit/orchestrator/red_teaming_orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,7 @@ def _get_prompt_for_red_teaming_chat(self, *, feedback: str | None) -> str:
logger.info(f"Using the specified initial red teaming prompt: {self._initial_red_teaming_prompt}")
return self._initial_red_teaming_prompt

if last_response_from_attack_target.converted_value_data_type == "text":
if last_response_from_attack_target.converted_value_data_type in ["text", "error"]:
return self._handle_text_response(last_response_from_attack_target, feedback)

return self._handle_file_response(last_response_from_attack_target, feedback)
Expand Down
81 changes: 48 additions & 33 deletions pyrit/prompt_target/dall_e_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,12 @@
from openai import BadRequestError

from pyrit.common.path import RESULTS_PATH
from pyrit.exceptions import EmptyResponseException, pyrit_retry
from pyrit.exceptions.exception_classes import handle_bad_request_exception
from pyrit.memory.memory_interface import MemoryInterface
from pyrit.models import PromptRequestResponse, data_serializer_factory
from pyrit.models.literals import PromptDataType
from pyrit.models.prompt_request_piece import PromptRequestPiece, PromptResponseError
from pyrit.models.prompt_request_piece import PromptRequestPiece
from pyrit.prompt_target import PromptTarget
from pyrit.prompt_target.prompt_chat_target.openai_chat_target import AzureOpenAIChatTarget

Expand Down Expand Up @@ -136,42 +138,55 @@ async def _generate_images_async(self, prompt: str, request=PromptRequestPiece)
image_generation_args["style"] = self.style

try:
response = await self._image_target._async_client.images.generate(**image_generation_args)
json_response = json.loads(response.model_dump_json())

b64_data = await self._generate_image_response_async(image_generation_args)
data = data_serializer_factory(data_type="image_path")
b64_data = json_response["data"][0]["b64_json"]
data.save_b64_image(data=b64_data)
prompt_text = data.value
error: PromptResponseError = "none"
resp_text = data.value
response_type: PromptDataType = "image_path"

except BadRequestError as e:
json_response = {"exception type": "Blocked", "data": ""}
json_response["error"] = e.body
prompt_text = "content blocked"
error = "blocked"
response_type = "text"

except json.JSONDecodeError as e:
json_response = {"error": e, "exception type": "JSON Error"}
prompt_text = "JSON Error"
error = "processing"
response_type = "text"

except Exception as e:
json_response = {"error": e, "exception type": "exception"}
prompt_text = "target error"
error = "unknown"
response_type = "text"
Comment thread
rdheekonda marked this conversation as resolved.

return self._memory.add_response_entries_to_memory(
request=request,
response_text_pieces=[prompt_text],
response_type=response_type,
prompt_metadata=json.dumps(json_response),
error=error,
)
response_entry = self._memory.add_response_entries_to_memory(
request=request, response_text_pieces=[resp_text], response_type=response_type
)

except BadRequestError as bre:
response_entry = handle_bad_request_exception(
memory=self._memory, response_text=bre.message, request=request
)

except Exception as ex:
self._memory.add_response_entries_to_memory(
request=request, response_text_pieces=[str(ex)], response_type="error", error="processing"
)
raise

return response_entry

@pyrit_retry
async def _generate_image_response_async(self, image_generation_args):
"""
Asynchronously generates an image using the provided generation arguments.

Retries the function if it raises RateLimitError (HTTP 429) or EmptyResponseException,
with a wait time between retries that follows an exponential backoff strategy.
Logs retry attempts at the INFO level and stops after a maximum number of attempts.

Args:
image_generation_args (dict): The arguments required for image generation.

Returns:
The generated image in base64 format.

Raises:
RateLimitError: If the rate limit is exceeded and the maximum number of retries is exhausted.
EmptyResponseException: If the response is empty after exhausting the maximum number of retries.
"""
result = await self._image_target._async_client.images.generate(**image_generation_args)
json_response = json.loads(result.model_dump_json())
b64_data = json_response["data"][0]["b64_json"]
# Handle empty response using retry
if not b64_data:
raise EmptyResponseException(message="The chat returned an empty response.")
return b64_data

def _validate_request(self, *, prompt_request: PromptRequestResponse) -> None:
if len(prompt_request.request_pieces) != 1:
Expand Down
146 changes: 145 additions & 1 deletion tests/target/test_dall_e_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,16 @@
from unittest.mock import patch, MagicMock, AsyncMock
import uuid
import os
from pyrit.exceptions.exception_classes import EmptyResponseException
import pytest

from openai import BadRequestError, RateLimitError

from pyrit.models.prompt_request_piece import PromptRequestPiece
from pyrit.models import PromptRequestResponse
from pyrit.prompt_target import DALLETarget

from tests.mocks import get_sample_conversations
from pyrit.common import constants


@pytest.fixture
Expand Down Expand Up @@ -54,6 +57,61 @@ async def test_send_prompt_async(mock_image, dalle_target: DALLETarget, sample_c
assert resp


@pytest.mark.asyncio
async def test_send_prompt_async_empty_response(
dalle_target: DALLETarget, sample_conversations: list[PromptRequestPiece]
):
request = sample_conversations[0]
request.conversation_id = str(uuid.uuid4())

mock_return = MagicMock()
# make b64_json value empty to test retries when empty response was returned
mock_return.model_dump_json.return_value = '{"data": [{"b64_json": ""}]}'
setattr(dalle_target._image_target._async_client.images, "generate", AsyncMock(return_value=mock_return))
constants.RETRY_MAX_NUM_ATTEMPTS = 5

with pytest.raises(EmptyResponseException) as e:
await dalle_target.send_prompt_async(prompt_request=PromptRequestResponse([request]))
assert str(e.value) == "Status Code: 204, Message: The chat returned an empty response."


@pytest.mark.asyncio
async def test_send_prompt_async_rate_limit_exception(
dalle_target: DALLETarget, sample_conversations: list[PromptRequestPiece]
):
request = sample_conversations[0]
request.conversation_id = str(uuid.uuid4())

response = MagicMock()
response.status_code = 429
mock_image_resp_async = AsyncMock(
side_effect=RateLimitError("Rate Limit Reached", response=response, body="Rate limit reached")
)
setattr(dalle_target, "_generate_image_response_async", mock_image_resp_async)

with pytest.raises(RateLimitError):
await dalle_target.send_prompt_async(prompt_request=PromptRequestResponse([request]))
assert mock_image_resp_async.call_count == constants.RETRY_MAX_NUM_ATTEMPTS


@pytest.mark.asyncio
async def test_send_prompt_async_bad_request_error(
dalle_target: DALLETarget, sample_conversations: list[PromptRequestPiece]
):
request = sample_conversations[0]
request.conversation_id = str(uuid.uuid4())

response = MagicMock()
response.status_code = 400
mock_image_resp_async = AsyncMock(
side_effect=BadRequestError("Bad Request Error", response=response, body="Bad Request")
)
setattr(dalle_target, "_generate_image_response_async", mock_image_resp_async)
with pytest.raises(BadRequestError) as bre:
await dalle_target.send_prompt_async(prompt_request=PromptRequestResponse([request]))
assert str(bre.value) == "Bad Request Error"


@pytest.mark.asyncio
async def test_dalle_validate_request_length(dalle_target: DALLETarget, sample_conversations: list[PromptRequestPiece]):
request = PromptRequestResponse(request_pieces=sample_conversations)
Expand Down Expand Up @@ -123,3 +181,89 @@ async def test_dalle_send_prompt_adds_memory_async() -> None:
await mock_dalle_target.send_prompt_async(prompt_request=request)
assert mock_memory.add_request_response_to_memory.called, "Request and Response need to be added to memory"
assert mock_memory.add_response_entries_to_memory.called, "Request and Response need to be added to memory"


@pytest.mark.asyncio
async def test_send_prompt_async_empty_response_adds_memory() -> None:

mock_memory = MagicMock()
mock_memory.get_conversation.return_value = []
mock_memory.add_request_response_to_memory = AsyncMock()
mock_memory.add_response_entries_to_memory = AsyncMock()
request = PromptRequestPiece(
role="user",
original_value="draw me a test picture",
).to_prompt_request_response()

mock_return = MagicMock()

# b64_json with empty response
mock_return.model_dump_json.return_value = '{"data": [{"b64_json": ""}]}'

mock_dalle_target = DALLETarget(deployment_name="test", endpoint="test", api_key="test", memory=mock_memory)
mock_dalle_target._image_target._async_client.images = MagicMock()
mock_dalle_target._image_target._async_client.images.generate = AsyncMock(return_value=mock_return)
mock_dalle_target._memory = mock_memory
with pytest.raises(EmptyResponseException) as e:
await mock_dalle_target.send_prompt_async(prompt_request=request)
mock_memory.add_response_entries_to_memory.assert_called_once()
assert str(e.value) == "Status Code: 204, Message: The chat returned an empty response."


@pytest.mark.asyncio
async def test_send_prompt_async_rate_limit_adds_memory() -> None:

mock_memory = MagicMock()
mock_memory.get_conversation.return_value = []
mock_memory.add_request_response_to_memory = AsyncMock()
mock_memory.add_response_entries_to_memory = AsyncMock()
request = PromptRequestPiece(
role="user",
original_value="draw me a test picture",
).to_prompt_request_response()

mock_dalle_target = DALLETarget(deployment_name="test", endpoint="test", api_key="test", memory=mock_memory)
mock_dalle_target._memory = mock_memory

# mocking openai.RateLimitError
mock_resp = MagicMock()
mock_resp.status_code = 429
mock_generate_image_response_async = AsyncMock(
side_effect=RateLimitError("Rate Limit Reached", response=mock_resp, body="Rate limit reached")
)
setattr(mock_dalle_target, "_generate_image_response_async", mock_generate_image_response_async)
with pytest.raises(RateLimitError) as rle:
await mock_dalle_target.send_prompt_async(prompt_request=request)
mock_dalle_target._memory.add_request_response_to_memory.assert_called_once()
mock_dalle_target._memory.add_response_entries_to_memory.assert_called_once()
assert str(rle.value) == "Rate Limit Reached"


@pytest.mark.asyncio
async def test_send_prompt_async_bad_request_adds_memory() -> None:

mock_memory = MagicMock()
mock_memory.get_conversation.return_value = []
mock_memory.add_request_response_to_memory = AsyncMock()
mock_memory.add_response_entries_to_memory = AsyncMock()
request = PromptRequestPiece(
role="user",
original_value="draw me a test picture",
).to_prompt_request_response()

mock_dalle_target = DALLETarget(deployment_name="test", endpoint="test", api_key="test", memory=mock_memory)
mock_dalle_target._memory = mock_memory

# mocking openai.BadRequestError
mock_resp = MagicMock()
mock_resp.status_code = 400
mock_generate_image_response_async = AsyncMock(
side_effect=BadRequestError("Bad Request", response=mock_resp, body="Bad Request")
)

setattr(mock_dalle_target, "_generate_image_response_async", mock_generate_image_response_async)
with pytest.raises(BadRequestError) as bre:
await mock_dalle_target.send_prompt_async(prompt_request=request)
mock_dalle_target._memory.add_request_response_to_memory.assert_called_once()
mock_dalle_target._memory.add_response_entries_to_memory.assert_called_once()
assert str(bre.value) == "Bad Request"