Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 17 additions & 1 deletion pyrit/common/data_url_converter.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from pyrit.common.deprecation import print_deprecation_message
from pyrit.models import DataTypeSerializer, data_serializer_factory

# Supported image formats for Azure OpenAI GPT-4o,
# https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/use-your-image-data
AZURE_OPENAI_GPT4O_SUPPORTED_IMAGE_FORMATS = [".jpg", ".jpeg", ".png", ".gif", ".bmp", ".tiff", ".tif"]


async def convert_local_image_to_data_url(image_path: str) -> str:
async def convert_local_image_to_data_url_async(image_path: str) -> str:
"""
Convert a local image file to a data URL encoded in base64.

Expand Down Expand Up @@ -42,3 +43,18 @@ async def convert_local_image_to_data_url(image_path: str) -> str:
# Construct the data URL, as per Azure OpenAI GPT-4 Turbo local image format
# https://learn.microsoft.com/en-us/azure/ai-services/openai/how-to/gpt-with-vision?tabs=rest%2Csystem-assigned%2Cresource#call-the-chat-completion-apis
return f"data:{mime_type};base64,{base64_encoded_data}"


async def convert_local_image_to_data_url(image_path: str) -> str:
"""
Delegate to :func:`convert_local_image_to_data_url_async` (deprecated alias).

Returns:
str: A string containing the MIME type and the base64-encoded data of the image, formatted as a data URL.
"""
print_deprecation_message(
old_item="pyrit.common.data_url_converter.convert_local_image_to_data_url",
new_item="pyrit.common.data_url_converter.convert_local_image_to_data_url_async",
removed_in="0.16.0",
)
return await convert_local_image_to_data_url_async(image_path)
13 changes: 12 additions & 1 deletion pyrit/common/display_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,15 @@

from PIL import Image

from pyrit.common.deprecation import print_deprecation_message
from pyrit.common.notebook_utils import is_in_ipython_session
from pyrit.memory import CentralMemory
from pyrit.models import AzureBlobStorageIO, DiskStorageIO, MessagePiece

logger = logging.getLogger(__name__)


async def display_image_response(response_piece: MessagePiece) -> None:
async def display_image_response_async(response_piece: MessagePiece) -> None:
"""
Display response images if running in notebook environment.

Expand Down Expand Up @@ -54,3 +55,13 @@ async def display_image_response(response_piece: MessagePiece) -> None:
display(image) # type: ignore[ty:unresolved-reference] # noqa: F821
if response_piece.response_error == "blocked":
logger.info("---\nContent blocked, cannot show a response.\n---")


async def display_image_response(response_piece: MessagePiece) -> None:
"""Delegate to :func:`display_image_response_async` (deprecated alias)."""
print_deprecation_message(
old_item="pyrit.common.display_response.display_image_response",
new_item="pyrit.common.display_response.display_image_response_async",
removed_in="0.16.0",
)
await display_image_response_async(response_piece)
71 changes: 62 additions & 9 deletions pyrit/common/download_hf_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
import httpx
from huggingface_hub import HfApi

from pyrit.common.deprecation import print_deprecation_message

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -37,7 +39,9 @@ def get_available_files(model_id: str, token: str) -> list[str]:
return []


async def download_specific_files(model_id: str, file_patterns: list[str] | None, token: str, cache_dir: Path) -> None:
async def download_specific_files_async(
model_id: str, file_patterns: list[str] | None, token: str, cache_dir: Path
) -> None:
"""
Download specific files from a Hugging Face model repository.
If file_patterns is None, downloads all files.
Expand All @@ -61,10 +65,12 @@ async def download_specific_files(model_id: str, file_patterns: list[str] | None
urls = [base_url + file for file in files_to_download]

# Download the files
await download_files(urls, token, cache_dir)
await download_files_async(urls, token, cache_dir)


async def download_chunk(url: str, headers: dict[str, str], start: int, end: int, client: httpx.AsyncClient) -> bytes:
async def download_chunk_async(
url: str, headers: dict[str, str], start: int, end: int, client: httpx.AsyncClient
) -> bytes:
"""
Download a chunk of the file with a specified byte range.

Expand All @@ -77,7 +83,7 @@ async def download_chunk(url: str, headers: dict[str, str], start: int, end: int
return response.content


async def download_file(url: str, token: str, download_dir: Path, num_splits: int) -> None:
async def download_file_async(url: str, token: str, download_dir: Path, num_splits: int) -> None:
"""Download a file in multiple segments (splits) using byte-range requests."""
headers = {"Authorization": f"Bearer {token}"}
async with httpx.AsyncClient(follow_redirects=True) as client:
Expand All @@ -95,7 +101,7 @@ async def download_file(url: str, token: str, download_dir: Path, num_splits: in
for i in range(num_splits):
start = i * chunk_size
end = start + chunk_size - 1 if i < num_splits - 1 else file_size - 1
tasks.append(download_chunk(url, headers, start, end, client))
tasks.append(download_chunk_async(url, headers, start, end, client))

# Download all chunks concurrently
chunks = await asyncio.gather(*tasks)
Expand All @@ -107,16 +113,63 @@ async def download_file(url: str, token: str, download_dir: Path, num_splits: in
logger.info(f"Downloaded {file_name} to {file_path}")


async def download_files(
async def download_files_async(
urls: list[str], token: str, download_dir: Path, num_splits: int = 3, parallel_downloads: int = 4
) -> None:
"""Download multiple files with parallel downloads and segmented downloading."""
# Limit the number of parallel downloads
semaphore = asyncio.Semaphore(parallel_downloads)

async def download_with_limit(url: str) -> None:
async def download_with_limit_async(url: str) -> None:
async with semaphore:
await download_file(url, token, download_dir, num_splits)
await download_file_async(url, token, download_dir, num_splits)

# Run downloads concurrently, but limit to parallel_downloads at a time
await asyncio.gather(*(download_with_limit(url) for url in urls))
await asyncio.gather(*(download_with_limit_async(url) for url in urls))


async def download_specific_files(model_id: str, file_patterns: list[str] | None, token: str, cache_dir: Path) -> None:
"""Delegate to :func:`download_specific_files_async` (deprecated alias)."""
print_deprecation_message(
old_item="pyrit.common.download_hf_model.download_specific_files",
new_item="pyrit.common.download_hf_model.download_specific_files_async",
removed_in="0.16.0",
)
await download_specific_files_async(model_id, file_patterns, token, cache_dir)


async def download_chunk(url: str, headers: dict[str, str], start: int, end: int, client: httpx.AsyncClient) -> bytes:
"""
Delegate to :func:`download_chunk_async` (deprecated alias).

Returns:
The content of the downloaded chunk.
"""
print_deprecation_message(
old_item="pyrit.common.download_hf_model.download_chunk",
new_item="pyrit.common.download_hf_model.download_chunk_async",
removed_in="0.16.0",
)
return await download_chunk_async(url, headers, start, end, client)


async def download_file(url: str, token: str, download_dir: Path, num_splits: int) -> None:
"""Delegate to :func:`download_file_async` (deprecated alias)."""
print_deprecation_message(
old_item="pyrit.common.download_hf_model.download_file",
new_item="pyrit.common.download_hf_model.download_file_async",
removed_in="0.16.0",
)
await download_file_async(url, token, download_dir, num_splits)


async def download_files(
urls: list[str], token: str, download_dir: Path, num_splits: int = 3, parallel_downloads: int = 4
) -> None:
"""Delegate to :func:`download_files_async` (deprecated alias)."""
print_deprecation_message(
old_item="pyrit.common.download_hf_model.download_files",
new_item="pyrit.common.download_hf_model.download_files_async",
removed_in="0.16.0",
)
await download_files_async(urls, token, download_dir, num_splits, parallel_downloads)
4 changes: 2 additions & 2 deletions pyrit/message_normalizer/chat_message_normalizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import os
from typing import TYPE_CHECKING, Any, Union

from pyrit.common.data_url_converter import convert_local_image_to_data_url
from pyrit.common.data_url_converter import convert_local_image_to_data_url_async
from pyrit.message_normalizer.message_normalizer import (
MessageListNormalizer,
MessageStringNormalizer,
Expand Down Expand Up @@ -140,7 +140,7 @@ async def _piece_to_content_dict_async(self, piece: MessagePiece) -> dict[str, A
return {"type": "text", "text": content}
if data_type == "image_path":
# Convert local image to base64 data URL
data_url = await convert_local_image_to_data_url(content)
data_url = await convert_local_image_to_data_url_async(content)
return {"type": "image_url", "image_url": {"url": data_url}}
if data_type == "audio_path":
# Convert local audio to base64 for input_audio format
Expand Down
11 changes: 8 additions & 3 deletions pyrit/prompt_target/hugging_face/hugging_face_chat_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
)

from pyrit.common import default_values
from pyrit.common.download_hf_model import download_specific_files
from pyrit.common.download_hf_model import download_specific_files_async
from pyrit.exceptions import EmptyResponseException, pyrit_target_retry
from pyrit.identifiers import ComponentIdentifier
from pyrit.models import Message, construct_response_from_request
Expand Down Expand Up @@ -280,11 +280,16 @@ async def load_model_and_tokenizer(self) -> None:
if self.necessary_files is None:
# Download all files if no specific files are provided
logger.info(f"Downloading all files for {self.model_id}...")
await download_specific_files(self.model_id or "", None, self.huggingface_token, Path(cache_dir)) # type: ignore[ty:invalid-argument-type]
await download_specific_files_async(
self.model_id or "",
None,
self.huggingface_token, # type: ignore[ty:invalid-argument-type]
Path(cache_dir),
)
else:
# Download only the necessary files
logger.info(f"Downloading specific files for {self.model_id}...")
await download_specific_files(
await download_specific_files_async(
self.model_id or "",
self.necessary_files,
self.huggingface_token, # type: ignore[ty:invalid-argument-type]
Expand Down
4 changes: 2 additions & 2 deletions pyrit/prompt_target/openai/openai_chat_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from dataclasses import replace
from typing import Any, Optional

from pyrit.common.data_url_converter import convert_local_image_to_data_url
from pyrit.common.data_url_converter import convert_local_image_to_data_url_async
from pyrit.exceptions import (
EmptyResponseException,
PyritException,
Expand Down Expand Up @@ -641,7 +641,7 @@ async def _build_chat_messages_for_multi_modal_async(
entry = {"type": "text", "text": message_piece.converted_value}
content.append(entry)
elif message_piece.converted_value_data_type == "image_path":
data_base64_encoded_url = await convert_local_image_to_data_url(message_piece.converted_value)
data_base64_encoded_url = await convert_local_image_to_data_url_async(message_piece.converted_value)
image_url_entry = {"url": data_base64_encoded_url}
entry = {"type": "image_url", "image_url": image_url_entry}
content.append(entry)
Expand Down
4 changes: 2 additions & 2 deletions pyrit/prompt_target/openai/openai_response_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from openai.types.shared import ReasoningEffort

from pyrit.common.data_url_converter import convert_local_image_to_data_url
from pyrit.common.data_url_converter import convert_local_image_to_data_url_async
from pyrit.exceptions import (
EmptyResponseException,
PyritException,
Expand Down Expand Up @@ -247,7 +247,7 @@ async def _construct_input_item_from_piece(self, piece: MessagePiece) -> dict[st
"text": piece.converted_value,
}
if piece.converted_value_data_type == "image_path":
data_url = await convert_local_image_to_data_url(piece.converted_value)
data_url = await convert_local_image_to_data_url_async(piece.converted_value)
return {"type": "input_image", "image_url": {"url": data_url}}
raise ValueError(f"Unsupported piece type for inline content: {piece.converted_value_data_type}")

Expand Down
4 changes: 2 additions & 2 deletions pyrit/prompt_target/websocket_copilot_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from websockets.exceptions import InvalidStatus

from pyrit.auth import CopilotAuthenticator, ManualCopilotAuthenticator
from pyrit.common.data_url_converter import convert_local_image_to_data_url
from pyrit.common.data_url_converter import convert_local_image_to_data_url_async
from pyrit.exceptions import (
EmptyResponseException,
pyrit_target_retry,
Expand Down Expand Up @@ -332,7 +332,7 @@ async def _process_image_piece_async(self, *, image_path: str, copilot_conversat
Returns:
dict: Message annotation structure for the uploaded image.
"""
data_url = await convert_local_image_to_data_url(image_path)
data_url = await convert_local_image_to_data_url_async(image_path)

normalized_image_path = image_path.replace("\\", "/")
file_name = pathlib.Path(normalized_image_path).name
Expand Down
12 changes: 6 additions & 6 deletions tests/unit/common/test_convert_local_image_to_data_url.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,13 @@

import pytest

from pyrit.common.data_url_converter import convert_local_image_to_data_url
from pyrit.common.data_url_converter import convert_local_image_to_data_url_async
from pyrit.memory.sqlite_memory import SQLiteMemory


async def test_convert_image_to_data_url_file_not_found():
with pytest.raises(FileNotFoundError):
await convert_local_image_to_data_url("nonexistent.jpg")
await convert_local_image_to_data_url_async("nonexistent.jpg")


async def test_convert_image_with_unsupported_extension():
Expand All @@ -23,7 +23,7 @@ async def test_convert_image_with_unsupported_extension():
assert os.path.exists(tmp_file_name)

with pytest.raises(ValueError) as exc_info:
await convert_local_image_to_data_url(tmp_file_name)
await convert_local_image_to_data_url_async(tmp_file_name)

assert "Unsupported image format" in str(exc_info.value)

Expand All @@ -36,7 +36,7 @@ async def test_convert_local_image_to_data_url_unsupported_format():
tmp_file_name = tmp_file.name
try:
with pytest.raises(ValueError) as excinfo:
await convert_local_image_to_data_url(tmp_file_name)
await convert_local_image_to_data_url_async(tmp_file_name)
assert "Unsupported image format" in str(excinfo.value)
finally:
os.remove(tmp_file_name)
Expand All @@ -45,7 +45,7 @@ async def test_convert_local_image_to_data_url_unsupported_format():
async def test_convert_local_image_to_data_url_missing_file():
# Should raise FileNotFoundError for missing file
with pytest.raises(FileNotFoundError):
await convert_local_image_to_data_url("not_a_real_file.jpg")
await convert_local_image_to_data_url_async("not_a_real_file.jpg")


@patch("os.path.exists", return_value=True)
Expand All @@ -63,7 +63,7 @@ async def test_convert_image_to_data_url_success(

assert os.path.exists(tmp_file_name)

result = await convert_local_image_to_data_url(tmp_file_name)
result = await convert_local_image_to_data_url_async(tmp_file_name)
assert "data:image/jpeg;base64,encoded_base64_string" in result

# Assertions for the mocks
Expand Down
24 changes: 21 additions & 3 deletions tests/unit/common/test_data_url_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from pyrit.common.data_url_converter import (
AZURE_OPENAI_GPT4O_SUPPORTED_IMAGE_FORMATS,
convert_local_image_to_data_url,
convert_local_image_to_data_url_async,
)


Expand All @@ -21,15 +22,15 @@ def test_supported_image_formats_contains_common_types():

async def test_convert_raises_file_not_found():
with pytest.raises(FileNotFoundError):
await convert_local_image_to_data_url("nonexistent_image.jpg")
await convert_local_image_to_data_url_async("nonexistent_image.jpg")


async def test_convert_raises_for_unsupported_format():
with NamedTemporaryFile(suffix=".svg", delete=False) as f:
tmp = f.name
try:
with pytest.raises(ValueError, match="Unsupported image format"):
await convert_local_image_to_data_url(tmp)
await convert_local_image_to_data_url_async(tmp)
finally:
os.remove(tmp)

Expand All @@ -42,7 +43,24 @@ async def test_convert_returns_data_url():
mock_serializer.read_data_base64 = AsyncMock(return_value="AAAA")

with patch("pyrit.common.data_url_converter.data_serializer_factory", return_value=mock_serializer):
result = await convert_local_image_to_data_url(tmp)
result = await convert_local_image_to_data_url_async(tmp)

assert result.startswith("data:image/png;base64,")
assert result.endswith("AAAA")
finally:
os.remove(tmp)


async def test_deprecated_alias_emits_warning_and_delegates():
with NamedTemporaryFile(suffix=".png", delete=False) as f:
tmp = f.name
try:
mock_serializer = AsyncMock()
mock_serializer.read_data_base64 = AsyncMock(return_value="AAAA")

with patch("pyrit.common.data_url_converter.data_serializer_factory", return_value=mock_serializer):
with pytest.warns(DeprecationWarning, match="convert_local_image_to_data_url"):
result = await convert_local_image_to_data_url(tmp)

assert result.startswith("data:image/png;base64,")
assert result.endswith("AAAA")
Expand Down
Loading
Loading