diff --git a/pyrit/common/download_hf_model.py b/pyrit/common/download_hf_model.py index 84d0105ee5..ad32222112 100644 --- a/pyrit/common/download_hf_model.py +++ b/pyrit/common/download_hf_model.py @@ -3,7 +3,6 @@ import asyncio import logging -import os from pathlib import Path import aiofiles @@ -47,7 +46,7 @@ async def download_specific_files_async( Download specific files from a Hugging Face model repository. If file_patterns is None, downloads all files. """ - os.makedirs(cache_dir, exist_ok=True) + cache_dir.mkdir(parents=True, exist_ok=True) available_files = get_available_files(model_id, token) # If no file patterns are provided, download all available files diff --git a/pyrit/message_normalizer/chat_message_normalizer.py b/pyrit/message_normalizer/chat_message_normalizer.py index 8e2011d961..495aa44dd0 100644 --- a/pyrit/message_normalizer/chat_message_normalizer.py +++ b/pyrit/message_normalizer/chat_message_normalizer.py @@ -3,7 +3,7 @@ import base64 import json -import os +from pathlib import Path from typing import TYPE_CHECKING, Any, Union import aiofiles @@ -175,7 +175,7 @@ async def _convert_audio_to_input_audio_async(self, audio_path: str) -> dict[str audio_format = SUPPORTED_AUDIO_FORMATS[ext] # Read and encode the audio file - if not os.path.isfile(audio_path): + if not Path(audio_path).is_file(): raise FileNotFoundError(f"Audio file not found: {audio_path}") async with aiofiles.open(audio_path, "rb") as f: diff --git a/pyrit/models/data_type_serializer.py b/pyrit/models/data_type_serializer.py index 07066bc267..bd83f44561 100644 --- a/pyrit/models/data_type_serializer.py +++ b/pyrit/models/data_type_serializer.py @@ -7,7 +7,6 @@ import asyncio import base64 import hashlib -import os import time import wave from mimetypes import guess_type @@ -226,7 +225,7 @@ async def save_formatted_audio( if self._memory.results_storage_io is None: raise RuntimeError("self._memory.results_storage_io is not initialized") await self._memory.results_storage_io.write_file(file_path, audio_data) - os.remove(local_temp_path) + local_temp_path.unlink() # If local, we can just save straight to disk and do not need to delete temp file after else: @@ -367,8 +366,8 @@ def get_extension(file_path: str) -> str | None: str | None: File extension (including dot) or None if unavailable. """ - _, ext = os.path.splitext(file_path) - return ext if ext else None + ext = Path(file_path).suffix + return ext or None @staticmethod def get_mime_type(file_path: str) -> str | None: diff --git a/pyrit/models/seeds/seed_prompt.py b/pyrit/models/seeds/seed_prompt.py index d2b867c105..fa6b9b59db 100644 --- a/pyrit/models/seeds/seed_prompt.py +++ b/pyrit/models/seeds/seed_prompt.py @@ -8,8 +8,8 @@ from __future__ import annotations import logging -import os from dataclasses import dataclass, field +from pathlib import Path from typing import TYPE_CHECKING, Optional, Union from tinytag import TinyTag @@ -21,7 +21,6 @@ if TYPE_CHECKING: import uuid from collections.abc import Sequence - from pathlib import Path from pyrit.models import Message from pyrit.models.literals import ChatMessageRole, PromptDataType @@ -65,8 +64,15 @@ def __post_init__(self) -> None: if not self.data_type: # If data_type is not provided, infer it from the value # Note: Does not assign 'error' or 'url' implicitly - if os.path.isfile(self.value): - _, ext = os.path.splitext(self.value) + # Guard against OSError / ValueError so values that aren't valid path + # strings (too long, null bytes, etc.) are treated as text, matching + # the prior os.path.isfile semantics. + try: + is_file = Path(self.value).is_file() + except (OSError, ValueError): + is_file = False + if is_file: + ext = Path(self.value).suffix ext = ext.lstrip(".").lower() if ext in ["mp4", "avi", "mov", "mkv", "ogv", "flv", "wmv", "webm"]: self.data_type = "video_path" diff --git a/pyrit/prompt_converter/add_image_to_video_converter.py b/pyrit/prompt_converter/add_image_to_video_converter.py index afb5051e64..d64973870e 100644 --- a/pyrit/prompt_converter/add_image_to_video_converter.py +++ b/pyrit/prompt_converter/add_image_to_video_converter.py @@ -4,7 +4,6 @@ import asyncio import contextlib import logging -import os from pathlib import Path from typing import Optional @@ -220,7 +219,7 @@ def _add_image_to_video_sync( with contextlib.suppress(cv2.error): cv2.destroyAllWindows() # Not available in headless OpenCV builds if azure_storage_flag and local_temp_path is not None: - os.remove(local_temp_path) + local_temp_path.unlink() async def convert_async(self, *, prompt: str, input_type: PromptDataType = "image_path") -> ConverterResult: """ diff --git a/pyrit/prompt_target/common/discover_target_capabilities.py b/pyrit/prompt_target/common/discover_target_capabilities.py index 55acbd227b..b5f42cb4cc 100644 --- a/pyrit/prompt_target/common/discover_target_capabilities.py +++ b/pyrit/prompt_target/common/discover_target_capabilities.py @@ -38,11 +38,11 @@ import asyncio import json import logging -import os import uuid from collections.abc import Awaitable, Callable, Iterable, Iterator from contextlib import contextmanager from dataclasses import replace +from pathlib import Path from pyrit.common.path import DATASETS_PATH from pyrit.models import Message, MessagePiece, PromptDataType @@ -835,7 +835,7 @@ def _create_test_message( asset_path = test_assets.get(modality) if asset_path is None: raise ValueError(f"No test asset configured for modality '{modality}'.") - if not os.path.isfile(asset_path): + if not Path(asset_path).is_file(): raise FileNotFoundError(f"Test asset for modality '{modality}' not found at: {asset_path}") pieces.append( diff --git a/pyrit/prompt_target/http_target/httpx_api_target.py b/pyrit/prompt_target/http_target/httpx_api_target.py index 9ea9ebbe21..bd32fd1fe2 100644 --- a/pyrit/prompt_target/http_target/httpx_api_target.py +++ b/pyrit/prompt_target/http_target/httpx_api_target.py @@ -3,8 +3,8 @@ import logging import mimetypes -import os from collections.abc import Callable +from pathlib import Path from typing import Any, Literal, Optional import aiofiles @@ -128,10 +128,10 @@ async def _send_prompt_to_target_async(self, *, normalized_conversation: list[Me # If user didn't set file_path, see if the PDF path is in converted_value if not self.file_path: possible_path = message_piece.converted_value - if isinstance(possible_path, str) and os.path.exists(possible_path): + if isinstance(possible_path, str) and Path(possible_path).exists(): logger.info(f"HTTPXApiTarget: auto-using file_path from {possible_path}") self.file_path = possible_path - elif not os.path.exists(self.file_path): + elif not Path(self.file_path).exists(): raise FileNotFoundError(f"File not found: {self.file_path}") if not self.http_url: @@ -141,9 +141,9 @@ async def _send_prompt_to_target_async(self, *, normalized_conversation: list[Me async with httpx.AsyncClient(http2=http2_version, **self.httpx_client_kwargs) as client: try: - if self.file_path and os.path.exists(self.file_path): + if self.file_path and Path(self.file_path).exists(): # Handle file upload (only for POST & PUT) - filename = os.path.basename(self.file_path) + filename = Path(self.file_path).name mime_type = mimetypes.guess_type(filename)[0] or "application/octet-stream" async with aiofiles.open(self.file_path, "rb") as fp: diff --git a/pyrit/prompt_target/hugging_face/hugging_face_chat_target.py b/pyrit/prompt_target/hugging_face/hugging_face_chat_target.py index 3c08fe3388..84aec25c29 100644 --- a/pyrit/prompt_target/hugging_face/hugging_face_chat_target.py +++ b/pyrit/prompt_target/hugging_face/hugging_face_chat_target.py @@ -4,7 +4,6 @@ import asyncio import json import logging -import os import warnings from pathlib import Path from typing import Any, cast @@ -264,12 +263,12 @@ async def load_model_and_tokenizer(self) -> None: self._load_from_path(self.model_path, **optional_model_kwargs) else: # Define the default Hugging Face cache directory - cache_dir = os.path.join( - os.path.expanduser("~"), - ".cache", - "huggingface", - "hub", - f"models--{(self.model_id or '').replace('/', '--')}", + cache_dir = ( + Path.home() + / ".cache" + / "huggingface" + / "hub" + / f"models--{(self.model_id or '').replace('/', '--')}" ) if self.necessary_files is None: @@ -279,7 +278,7 @@ async def load_model_and_tokenizer(self) -> None: self.model_id or "", None, self.huggingface_token, # type: ignore[ty:invalid-argument-type] - Path(cache_dir), + cache_dir, ) else: # Download only the necessary files diff --git a/pyrit/prompt_target/openai/openai_video_target.py b/pyrit/prompt_target/openai/openai_video_target.py index 1b5304918e..9b10bd0c66 100644 --- a/pyrit/prompt_target/openai/openai_video_target.py +++ b/pyrit/prompt_target/openai/openai_video_target.py @@ -2,8 +2,8 @@ # Licensed under the MIT license. import logging -import os from mimetypes import guess_type +from pathlib import Path from typing import Any, Optional, Union, cast from openai.types import VideoSeconds, VideoSize @@ -327,7 +327,7 @@ async def _prepare_image_input_async(self, *, image_piece: MessagePiece) -> tupl f"Supported formats: {', '.join(self.SUPPORTED_IMAGE_FORMATS)}" ) - filename = os.path.basename(image_path) + filename = Path(image_path).name return (filename, image_bytes, mime_type) async def _remix_and_poll_async(self, *, video_id: str, prompt: str) -> Any: diff --git a/pyrit/score/audio_transcript_scorer.py b/pyrit/score/audio_transcript_scorer.py index d207193981..0f5ba30bbd 100644 --- a/pyrit/score/audio_transcript_scorer.py +++ b/pyrit/score/audio_transcript_scorer.py @@ -3,9 +3,9 @@ import asyncio import logging -import os import tempfile import uuid +from pathlib import Path from typing import Optional import av @@ -171,7 +171,7 @@ async def _score_audio_async(self, *, message_piece: MessagePiece, objective: Op """ audio_path = message_piece.converted_value - if not os.path.exists(audio_path): + if not Path(audio_path).exists(): raise FileNotFoundError(f"Audio file not found: {audio_path}") # Transcribe audio to text @@ -230,10 +230,10 @@ async def _transcribe_audio_async(self, audio_path: str) -> str: logger.info(f"Audio transcription: WAV file path = {wav_path}") # Check if WAV file exists and has content - if not os.path.exists(wav_path): + if not Path(wav_path).exists(): raise FileNotFoundError(f"WAV file does not exist at {wav_path}") - file_size = os.path.getsize(wav_path) + file_size = Path(wav_path).stat().st_size logger.info(f"Audio transcription: WAV file size = {file_size} bytes") try: @@ -247,8 +247,8 @@ async def _transcribe_audio_async(self, audio_path: str) -> str: raise finally: # Clean up temporary WAV file if it exists (ie for scoring audio from videos) - if wav_path != audio_path and os.path.exists(wav_path): - os.unlink(wav_path) + if wav_path != audio_path: + Path(wav_path).unlink(missing_ok=True) def _ensure_wav_format(self, audio_path: str) -> str: """ diff --git a/pyrit/score/video_scorer.py b/pyrit/score/video_scorer.py index 469018e367..400d2522a9 100644 --- a/pyrit/score/video_scorer.py +++ b/pyrit/score/video_scorer.py @@ -2,10 +2,10 @@ # Licensed under the MIT license. import logging -import os import random import tempfile import uuid +from pathlib import Path from typing import Optional from pyrit.memory import CentralMemory @@ -112,7 +112,7 @@ async def _score_frames_async(self, *, message_piece: MessagePiece, objective: O """ video_path = message_piece.converted_value - if not os.path.exists(video_path): + if not Path(video_path).exists(): raise FileNotFoundError(f"Video file not found: {video_path}") # Extract frames from video @@ -281,5 +281,5 @@ async def _score_video_audio_async( finally: # Clean up temporary audio file on success - if should_cleanup and audio_path and os.path.exists(audio_path): - os.unlink(audio_path) + if should_cleanup and audio_path: + Path(audio_path).unlink(missing_ok=True) diff --git a/tests/unit/models/test_data_type_serializer.py b/tests/unit/models/test_data_type_serializer.py index 75d7fef195..2fe1911abf 100644 --- a/tests/unit/models/test_data_type_serializer.py +++ b/tests/unit/models/test_data_type_serializer.py @@ -525,3 +525,27 @@ async def test_get_data_filename_uses_db_data_path_when_results_path_falsy(): result_str = str(result).replace("\\", "/") assert "/fallback/db_data" in result_str assert result_str.endswith(".png") + + +async def test_save_formatted_audio_azure_storage_unlinks_local_temp(tmp_path): + """save_formatted_audio cleans up the local temp WAV after writing to Azure storage.""" + from pyrit.models import data_serializer_factory as factory + + serializer = factory(category="prompt-memory-entries", data_type="audio_path") + mock_memory = MagicMock() + mock_storage_io = AsyncMock() + mock_memory.results_storage_io = mock_storage_io + azure_url = "https://account.blob.core.windows.net/container/audio/test.wav" + + with ( + patch.object(type(serializer), "_memory", new_callable=PropertyMock, return_value=mock_memory), + patch.object(serializer, "get_data_filename", new_callable=AsyncMock, return_value=azure_url), + patch("pyrit.models.data_type_serializer.DB_DATA_PATH", tmp_path), + ): + await serializer.save_formatted_audio(data=b"\x00\x01\x02\x03") + + # The local temp file written via wave.open should have been unlinked after upload. + assert not (tmp_path / "temp_audio.wav").exists() + mock_storage_io.write_file.assert_awaited_once() + assert mock_storage_io.write_file.call_args[0][0] == azure_url + assert serializer.value == azure_url diff --git a/tests/unit/models/test_seed_prompt.py b/tests/unit/models/test_seed_prompt.py index 90f83ccd52..e5f584883d 100644 --- a/tests/unit/models/test_seed_prompt.py +++ b/tests/unit/models/test_seed_prompt.py @@ -2,7 +2,7 @@ # Licensed under the MIT license. import uuid -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock import pytest @@ -49,18 +49,35 @@ def test_seed_prompt_infers_text_data_type(): (".png", "image_path"), ], ) -@patch("os.path.isfile", return_value=True) -def test_seed_prompt_infers_data_type_from_extension(mock_isfile, extension, expected_type): - with patch("os.path.splitext", return_value=("/path/file", extension)): - sp = SeedPrompt(value=f"/path/file{extension}") - assert sp.data_type == expected_type +def test_seed_prompt_infers_data_type_from_extension(tmp_path, extension, expected_type): + file_path = tmp_path / f"file{extension}" + file_path.touch() + sp = SeedPrompt(value=str(file_path)) + assert sp.data_type == expected_type -@patch("os.path.isfile", return_value=True) -@patch("os.path.splitext", return_value=("/path/file", ".xyz")) -def test_seed_prompt_unknown_file_extension_raises(mock_splitext, mock_isfile): +def test_seed_prompt_unknown_file_extension_raises(tmp_path): + file_path = tmp_path / "file.xyz" + file_path.touch() with pytest.raises(ValueError, match="Unable to infer data_type"): - SeedPrompt(value="/path/file.xyz") + SeedPrompt(value=str(file_path)) + + +def test_seed_prompt_infers_text_for_value_exceeding_path_name_limit(): + # Values longer than the filesystem name limit must be treated as text. + # Path(value).is_file() can raise OSError (ENAMETOOLONG) on Linux/macOS, + # whereas os.path.isfile silently returned False. The inference logic + # must preserve the prior behavior so long-form text values (e.g. an + # academic paper used as a jailbreak template) don't crash construction. + long_value = "JOURNAL OF ARTIFICIAL INTELLIGENCE SAFETY RESEARCH " * 100 + sp = SeedPrompt(value=long_value) + assert sp.data_type == "text" + + +def test_seed_prompt_infers_text_for_value_with_null_byte(): + # Null bytes raise ValueError inside pathlib; treat as text rather than crashing. + sp = SeedPrompt(value="some text with \x00 embedded null") + assert sp.data_type == "text" def test_seed_prompt_explicit_data_type_not_overridden(): diff --git a/tests/unit/prompt_converter/test_add_image_video_converter.py b/tests/unit/prompt_converter/test_add_image_video_converter.py index a9e55a1045..acea7861cd 100644 --- a/tests/unit/prompt_converter/test_add_image_video_converter.py +++ b/tests/unit/prompt_converter/test_add_image_video_converter.py @@ -123,3 +123,45 @@ def factory_side_effect(*, category, data_type, value): ): with pytest.raises(ValueError, match="Failed to decode overlay image"): await converter._add_image_to_video(image_path="fake_image.png", output_path=output_path) + + +@pytest.mark.skipif(not is_opencv_installed(), reason="opencv is not installed") +async def test_add_image_to_video_azure_storage_unlinks_local_temp( + tmp_path, video_converter_sample_video, video_converter_sample_image +): + """When the video is in Azure storage, the downloaded local temp file is unlinked after processing.""" + from unittest.mock import AsyncMock, patch + + output_path = str(tmp_path / "output_video.mp4") + converter = AddImageVideoConverter(video_path=video_converter_sample_video, output_path=output_path) + + with open(video_converter_sample_video, "rb") as f: + video_bytes = f.read() + with open(video_converter_sample_image, "rb") as f: + image_bytes = f.read() + + mock_image_serializer = AsyncMock() + mock_image_serializer.read_data = AsyncMock(return_value=image_bytes) + mock_image_serializer._is_azure_storage_url = lambda x: False + + mock_video_serializer = AsyncMock() + mock_video_serializer.read_data = AsyncMock(return_value=video_bytes) + # Flag this video as living in Azure storage so the cleanup branch runs. + mock_video_serializer._is_azure_storage_url = lambda x: True + + def factory_side_effect(*, category, data_type, value): + if data_type == "image_path": + return mock_image_serializer + return mock_video_serializer + + with ( + patch( + "pyrit.prompt_converter.add_image_to_video_converter.data_serializer_factory", + side_effect=factory_side_effect, + ), + patch("pyrit.prompt_converter.add_image_to_video_converter.DB_DATA_PATH", tmp_path), + ): + await converter._add_image_to_video(image_path=video_converter_sample_image, output_path=output_path) + + # The local copy of the Azure-stored video should be removed by the cleanup branch. + assert not (tmp_path / "temp_video.mp4").exists() diff --git a/tests/unit/score/test_audio_scorer.py b/tests/unit/score/test_audio_scorer.py index 2de656bf43..3b0e51abb4 100644 --- a/tests/unit/score/test_audio_scorer.py +++ b/tests/unit/score/test_audio_scorer.py @@ -260,6 +260,35 @@ async def test_transcribe_audio_async_creates_converter(self, audio_message_piec mock_cls.assert_called_once() mock_converter.convert_async.assert_called_once() + async def test_transcribe_audio_async_unlinks_converted_wav(self, audio_message_piece, tmp_path): + """When _ensure_wav_format produces a different temp WAV, that file is cleaned up in finally.""" + from pyrit.score.audio_transcript_scorer import AudioTranscriptHelper + + text_scorer = MockTextTrueFalseScorer() + helper = AudioTranscriptHelper(text_capable_scorer=text_scorer) + + # Create a real temporary WAV distinct from audio_message_piece.converted_value + converted_wav = tmp_path / "converted.wav" + converted_wav.write_bytes(b"fake wav content") + + mock_converter = AsyncMock() + mock_result = AsyncMock() + mock_result.output_text = "transcribed text" + mock_converter.convert_async.return_value = mock_result + + with ( + patch.object(helper, "_ensure_wav_format", return_value=str(converted_wav)), + patch( + "pyrit.score.audio_transcript_scorer.AzureSpeechAudioToTextConverter", + return_value=mock_converter, + ), + ): + result = await helper._transcribe_audio_async(audio_message_piece.converted_value) + + assert result == "transcribed text" + # The converted temp WAV (different from the original audio path) should be deleted. + assert not converted_wav.exists() + class TestPyAVAudioConversion: """Tests for PyAV audio conversion functions""" diff --git a/tests/unit/score/test_video_scorer.py b/tests/unit/score/test_video_scorer.py index 506b10f7dd..4c5d4d3524 100644 --- a/tests/unit/score/test_video_scorer.py +++ b/tests/unit/score/test_video_scorer.py @@ -334,6 +334,28 @@ async def test_video_true_false_scorer_with_audio_scorer(video_converter_sample_ assert "visual" in scores[0].score_rationale.lower() or "audio" in scores[0].score_rationale.lower() +@pytest.mark.skipif(not is_opencv_installed(), reason="opencv is not installed") +async def test_video_audio_scorer_cleans_up_extracted_audio(tmp_path, video_converter_sample_video): + """_score_video_audio_async unlinks the temp audio file after successful scoring.""" + image_scorer = MockTrueFalseScorer(return_value=True) + audio_scorer = MockAudioTrueFalseScorer(return_value=True) + + # Create a real temp audio file that should be deleted by the cleanup branch. + extracted_audio = tmp_path / "extracted_audio.wav" + extracted_audio.write_bytes(b"fake audio bytes") + + with patch.object(AudioTranscriptHelper, "extract_audio_from_video", return_value=str(extracted_audio)): + scorer = VideoTrueFalseScorer( + image_capable_scorer=image_scorer, + audio_scorer=audio_scorer, + num_sampled_frames=3, + ) + + await scorer._score_piece_async(video_converter_sample_video) + + assert not extracted_audio.exists() + + @pytest.mark.skipif(not is_opencv_installed(), reason="opencv is not installed") async def test_video_scorer_and_aggregation_both_true(video_converter_sample_video): """Test AND aggregation when both visual and audio scores are true"""