Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions pyrit/common/download_hf_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

import asyncio
import logging
import os
from pathlib import Path

import aiofiles
Expand Down Expand Up @@ -47,7 +46,7 @@ async def download_specific_files_async(
Download specific files from a Hugging Face model repository.
If file_patterns is None, downloads all files.
"""
os.makedirs(cache_dir, exist_ok=True)
cache_dir.mkdir(parents=True, exist_ok=True)

available_files = get_available_files(model_id, token)
# If no file patterns are provided, download all available files
Expand Down
4 changes: 2 additions & 2 deletions pyrit/message_normalizer/chat_message_normalizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import base64
import json
import os
from pathlib import Path
from typing import TYPE_CHECKING, Any, Union

import aiofiles
Expand Down Expand Up @@ -175,7 +175,7 @@ async def _convert_audio_to_input_audio_async(self, audio_path: str) -> dict[str
audio_format = SUPPORTED_AUDIO_FORMATS[ext]

# Read and encode the audio file
if not os.path.isfile(audio_path):
if not Path(audio_path).is_file():
raise FileNotFoundError(f"Audio file not found: {audio_path}")

async with aiofiles.open(audio_path, "rb") as f:
Expand Down
7 changes: 3 additions & 4 deletions pyrit/models/data_type_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import asyncio
import base64
import hashlib
import os
import time
import wave
from mimetypes import guess_type
Expand Down Expand Up @@ -226,7 +225,7 @@ async def save_formatted_audio(
if self._memory.results_storage_io is None:
raise RuntimeError("self._memory.results_storage_io is not initialized")
await self._memory.results_storage_io.write_file(file_path, audio_data)
os.remove(local_temp_path)
local_temp_path.unlink()
Comment thread
romanlutz marked this conversation as resolved.

# If local, we can just save straight to disk and do not need to delete temp file after
else:
Expand Down Expand Up @@ -367,8 +366,8 @@ def get_extension(file_path: str) -> str | None:
str | None: File extension (including dot) or None if unavailable.

"""
_, ext = os.path.splitext(file_path)
return ext if ext else None
ext = Path(file_path).suffix
return ext or None

@staticmethod
def get_mime_type(file_path: str) -> str | None:
Expand Down
14 changes: 10 additions & 4 deletions pyrit/models/seeds/seed_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
from __future__ import annotations

import logging
import os
from dataclasses import dataclass, field
from pathlib import Path
from typing import TYPE_CHECKING, Optional, Union

from tinytag import TinyTag
Expand All @@ -21,7 +21,6 @@
if TYPE_CHECKING:
import uuid
from collections.abc import Sequence
from pathlib import Path

from pyrit.models import Message
from pyrit.models.literals import ChatMessageRole, PromptDataType
Expand Down Expand Up @@ -65,8 +64,15 @@ def __post_init__(self) -> None:
if not self.data_type:
# If data_type is not provided, infer it from the value
# Note: Does not assign 'error' or 'url' implicitly
if os.path.isfile(self.value):
_, ext = os.path.splitext(self.value)
# Guard against OSError / ValueError so values that aren't valid path
# strings (too long, null bytes, etc.) are treated as text, matching
# the prior os.path.isfile semantics.
try:
is_file = Path(self.value).is_file()
except (OSError, ValueError):
is_file = False
if is_file:
ext = Path(self.value).suffix
ext = ext.lstrip(".").lower()
if ext in ["mp4", "avi", "mov", "mkv", "ogv", "flv", "wmv", "webm"]:
self.data_type = "video_path"
Expand Down
3 changes: 1 addition & 2 deletions pyrit/prompt_converter/add_image_to_video_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import asyncio
import contextlib
import logging
import os
from pathlib import Path
from typing import Optional

Expand Down Expand Up @@ -220,7 +219,7 @@ def _add_image_to_video_sync(
with contextlib.suppress(cv2.error):
cv2.destroyAllWindows() # Not available in headless OpenCV builds
if azure_storage_flag and local_temp_path is not None:
os.remove(local_temp_path)
local_temp_path.unlink()

async def convert_async(self, *, prompt: str, input_type: PromptDataType = "image_path") -> ConverterResult:
"""
Expand Down
4 changes: 2 additions & 2 deletions pyrit/prompt_target/common/discover_target_capabilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,11 @@
import asyncio
import json
import logging
import os
import uuid
from collections.abc import Awaitable, Callable, Iterable, Iterator
from contextlib import contextmanager
from dataclasses import replace
from pathlib import Path

from pyrit.common.path import DATASETS_PATH
from pyrit.models import Message, MessagePiece, PromptDataType
Expand Down Expand Up @@ -835,7 +835,7 @@ def _create_test_message(
asset_path = test_assets.get(modality)
if asset_path is None:
raise ValueError(f"No test asset configured for modality '{modality}'.")
if not os.path.isfile(asset_path):
if not Path(asset_path).is_file():
raise FileNotFoundError(f"Test asset for modality '{modality}' not found at: {asset_path}")

pieces.append(
Expand Down
10 changes: 5 additions & 5 deletions pyrit/prompt_target/http_target/httpx_api_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@

import logging
import mimetypes
import os
from collections.abc import Callable
from pathlib import Path
from typing import Any, Literal, Optional

import aiofiles
Expand Down Expand Up @@ -128,10 +128,10 @@ async def _send_prompt_to_target_async(self, *, normalized_conversation: list[Me
# If user didn't set file_path, see if the PDF path is in converted_value
if not self.file_path:
possible_path = message_piece.converted_value
if isinstance(possible_path, str) and os.path.exists(possible_path):
if isinstance(possible_path, str) and Path(possible_path).exists():
logger.info(f"HTTPXApiTarget: auto-using file_path from {possible_path}")
self.file_path = possible_path
elif not os.path.exists(self.file_path):
elif not Path(self.file_path).exists():
raise FileNotFoundError(f"File not found: {self.file_path}")

if not self.http_url:
Expand All @@ -141,9 +141,9 @@ async def _send_prompt_to_target_async(self, *, normalized_conversation: list[Me

async with httpx.AsyncClient(http2=http2_version, **self.httpx_client_kwargs) as client:
try:
if self.file_path and os.path.exists(self.file_path):
if self.file_path and Path(self.file_path).exists():
# Handle file upload (only for POST & PUT)
filename = os.path.basename(self.file_path)
filename = Path(self.file_path).name
mime_type = mimetypes.guess_type(filename)[0] or "application/octet-stream"

async with aiofiles.open(self.file_path, "rb") as fp:
Expand Down
15 changes: 7 additions & 8 deletions pyrit/prompt_target/hugging_face/hugging_face_chat_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import asyncio
import json
import logging
import os
import warnings
from pathlib import Path
from typing import Any, cast
Expand Down Expand Up @@ -264,12 +263,12 @@ async def load_model_and_tokenizer(self) -> None:
self._load_from_path(self.model_path, **optional_model_kwargs)
else:
# Define the default Hugging Face cache directory
cache_dir = os.path.join(
os.path.expanduser("~"),
".cache",
"huggingface",
"hub",
f"models--{(self.model_id or '').replace('/', '--')}",
cache_dir = (
Path.home()
/ ".cache"
/ "huggingface"
/ "hub"
/ f"models--{(self.model_id or '').replace('/', '--')}"
)

if self.necessary_files is None:
Expand All @@ -279,7 +278,7 @@ async def load_model_and_tokenizer(self) -> None:
self.model_id or "",
None,
self.huggingface_token, # type: ignore[ty:invalid-argument-type]
Path(cache_dir),
cache_dir,
)
else:
# Download only the necessary files
Expand Down
4 changes: 2 additions & 2 deletions pyrit/prompt_target/openai/openai_video_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
# Licensed under the MIT license.

import logging
import os
from mimetypes import guess_type
from pathlib import Path
from typing import Any, Optional, Union, cast

from openai.types import VideoSeconds, VideoSize
Expand Down Expand Up @@ -327,7 +327,7 @@ async def _prepare_image_input_async(self, *, image_piece: MessagePiece) -> tupl
f"Supported formats: {', '.join(self.SUPPORTED_IMAGE_FORMATS)}"
)

filename = os.path.basename(image_path)
filename = Path(image_path).name
return (filename, image_bytes, mime_type)

async def _remix_and_poll_async(self, *, video_id: str, prompt: str) -> Any:
Expand Down
12 changes: 6 additions & 6 deletions pyrit/score/audio_transcript_scorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@

import asyncio
import logging
import os
import tempfile
import uuid
from pathlib import Path
from typing import Optional

import av
Expand Down Expand Up @@ -171,7 +171,7 @@ async def _score_audio_async(self, *, message_piece: MessagePiece, objective: Op
"""
audio_path = message_piece.converted_value

if not os.path.exists(audio_path):
if not Path(audio_path).exists():
raise FileNotFoundError(f"Audio file not found: {audio_path}")

# Transcribe audio to text
Expand Down Expand Up @@ -230,10 +230,10 @@ async def _transcribe_audio_async(self, audio_path: str) -> str:
logger.info(f"Audio transcription: WAV file path = {wav_path}")

# Check if WAV file exists and has content
if not os.path.exists(wav_path):
if not Path(wav_path).exists():
raise FileNotFoundError(f"WAV file does not exist at {wav_path}")

file_size = os.path.getsize(wav_path)
file_size = Path(wav_path).stat().st_size
logger.info(f"Audio transcription: WAV file size = {file_size} bytes")

try:
Expand All @@ -247,8 +247,8 @@ async def _transcribe_audio_async(self, audio_path: str) -> str:
raise
finally:
# Clean up temporary WAV file if it exists (ie for scoring audio from videos)
if wav_path != audio_path and os.path.exists(wav_path):
os.unlink(wav_path)
if wav_path != audio_path:
Path(wav_path).unlink(missing_ok=True)

def _ensure_wav_format(self, audio_path: str) -> str:
"""
Expand Down
8 changes: 4 additions & 4 deletions pyrit/score/video_scorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
# Licensed under the MIT license.

import logging
import os
import random
import tempfile
import uuid
from pathlib import Path
from typing import Optional

from pyrit.memory import CentralMemory
Expand Down Expand Up @@ -112,7 +112,7 @@ async def _score_frames_async(self, *, message_piece: MessagePiece, objective: O
"""
video_path = message_piece.converted_value

if not os.path.exists(video_path):
if not Path(video_path).exists():
raise FileNotFoundError(f"Video file not found: {video_path}")

# Extract frames from video
Expand Down Expand Up @@ -281,5 +281,5 @@ async def _score_video_audio_async(

finally:
# Clean up temporary audio file on success
if should_cleanup and audio_path and os.path.exists(audio_path):
os.unlink(audio_path)
if should_cleanup and audio_path:
Path(audio_path).unlink(missing_ok=True)
24 changes: 24 additions & 0 deletions tests/unit/models/test_data_type_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -525,3 +525,27 @@ async def test_get_data_filename_uses_db_data_path_when_results_path_falsy():
result_str = str(result).replace("\\", "/")
assert "/fallback/db_data" in result_str
assert result_str.endswith(".png")


async def test_save_formatted_audio_azure_storage_unlinks_local_temp(tmp_path):
"""save_formatted_audio cleans up the local temp WAV after writing to Azure storage."""
from pyrit.models import data_serializer_factory as factory

serializer = factory(category="prompt-memory-entries", data_type="audio_path")
mock_memory = MagicMock()
mock_storage_io = AsyncMock()
mock_memory.results_storage_io = mock_storage_io
azure_url = "https://account.blob.core.windows.net/container/audio/test.wav"

with (
patch.object(type(serializer), "_memory", new_callable=PropertyMock, return_value=mock_memory),
patch.object(serializer, "get_data_filename", new_callable=AsyncMock, return_value=azure_url),
patch("pyrit.models.data_type_serializer.DB_DATA_PATH", tmp_path),
):
await serializer.save_formatted_audio(data=b"\x00\x01\x02\x03")

# The local temp file written via wave.open should have been unlinked after upload.
assert not (tmp_path / "temp_audio.wav").exists()
mock_storage_io.write_file.assert_awaited_once()
assert mock_storage_io.write_file.call_args[0][0] == azure_url
assert serializer.value == azure_url
37 changes: 27 additions & 10 deletions tests/unit/models/test_seed_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# Licensed under the MIT license.

import uuid
from unittest.mock import MagicMock, patch
from unittest.mock import MagicMock

import pytest

Expand Down Expand Up @@ -49,18 +49,35 @@ def test_seed_prompt_infers_text_data_type():
(".png", "image_path"),
],
)
@patch("os.path.isfile", return_value=True)
def test_seed_prompt_infers_data_type_from_extension(mock_isfile, extension, expected_type):
with patch("os.path.splitext", return_value=("/path/file", extension)):
sp = SeedPrompt(value=f"/path/file{extension}")
assert sp.data_type == expected_type
def test_seed_prompt_infers_data_type_from_extension(tmp_path, extension, expected_type):
file_path = tmp_path / f"file{extension}"
file_path.touch()
sp = SeedPrompt(value=str(file_path))
assert sp.data_type == expected_type


@patch("os.path.isfile", return_value=True)
@patch("os.path.splitext", return_value=("/path/file", ".xyz"))
def test_seed_prompt_unknown_file_extension_raises(mock_splitext, mock_isfile):
def test_seed_prompt_unknown_file_extension_raises(tmp_path):
file_path = tmp_path / "file.xyz"
file_path.touch()
with pytest.raises(ValueError, match="Unable to infer data_type"):
SeedPrompt(value="/path/file.xyz")
SeedPrompt(value=str(file_path))


def test_seed_prompt_infers_text_for_value_exceeding_path_name_limit():
# Values longer than the filesystem name limit must be treated as text.
# Path(value).is_file() can raise OSError (ENAMETOOLONG) on Linux/macOS,
# whereas os.path.isfile silently returned False. The inference logic
# must preserve the prior behavior so long-form text values (e.g. an
# academic paper used as a jailbreak template) don't crash construction.
long_value = "JOURNAL OF ARTIFICIAL INTELLIGENCE SAFETY RESEARCH " * 100
sp = SeedPrompt(value=long_value)
assert sp.data_type == "text"


def test_seed_prompt_infers_text_for_value_with_null_byte():
# Null bytes raise ValueError inside pathlib; treat as text rather than crashing.
sp = SeedPrompt(value="some text with \x00 embedded null")
assert sp.data_type == "text"


def test_seed_prompt_explicit_data_type_not_overridden():
Expand Down
Loading
Loading