From 357e8e301a582374b92cd7157cf4b061d6f43454 Mon Sep 17 00:00:00 2001 From: Roman Lutz Date: Mon, 1 Jun 2026 11:24:58 -0700 Subject: [PATCH 1/4] MAINT: Migrate os.path.* to pathlib.Path across pyrit/ Audit found ~22 os.path.* call sites across 11 pyrit/ source files, plus a few paired sibling APIs (os.makedirs/os.remove/os.unlink). This change does a mechanical pass to the pathlib equivalents: - os.path.exists -> Path(...).exists() - os.path.isfile -> Path(...).is_file() - os.path.splitext(p)[1] -> Path(p).suffix - os.path.basename -> Path(...).name - os.path.join(os.path.expanduser('~'), ...) -> Path.home() / ... - os.path.getsize -> Path(...).stat().st_size - os.makedirs(p, exist_ok=True) -> Path(p).mkdir(parents=True, exist_ok=True) - os.remove / os.unlink -> Path(...).unlink() Public parameter and attribute types remain str / Optional[str]; Path(...) conversion happens at the boundary so callers passing strings continue to work unchanged. In tests/unit/models/test_seed_prompt.py, two tests previously patched os.path.isfile / os.path.splitext on the source module. Those mocks would be inert after the migration, so they were replaced with real files created via the tmp_path fixture. Ruff PTH ruleset is intentionally left disabled in this PR; enabling it is a recommended follow-up now that pyrit/ is clean. Verification: - ruff check pyrit/ tests/unit/models/test_seed_prompt.py: passed - ruff format --check (same scope): passed - pytest tests/unit/{models/test_seed_prompt.py, models/test_data_type_serializer.py, score, message_normalizer, common, prompt_target, prompt_converter}: 3245 passed, 81 skipped Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pyrit/common/download_hf_model.py | 3 +-- .../chat_message_normalizer.py | 4 ++-- pyrit/models/data_type_serializer.py | 7 +++---- pyrit/models/seeds/seed_prompt.py | 7 +++---- .../add_image_to_video_converter.py | 3 +-- .../common/discover_target_capabilities.py | 4 ++-- .../http_target/httpx_api_target.py | 10 +++++----- .../hugging_face/hugging_face_chat_target.py | 15 +++++++------- .../openai/openai_video_target.py | 4 ++-- pyrit/score/audio_transcript_scorer.py | 12 +++++------ pyrit/score/video_scorer.py | 8 ++++---- tests/unit/models/test_seed_prompt.py | 20 +++++++++---------- 12 files changed, 46 insertions(+), 51 deletions(-) diff --git a/pyrit/common/download_hf_model.py b/pyrit/common/download_hf_model.py index c34ccb7aaf..a6c9467da5 100644 --- a/pyrit/common/download_hf_model.py +++ b/pyrit/common/download_hf_model.py @@ -3,7 +3,6 @@ import asyncio import logging -import os from pathlib import Path import httpx @@ -46,7 +45,7 @@ async def download_specific_files_async( Download specific files from a Hugging Face model repository. If file_patterns is None, downloads all files. """ - os.makedirs(cache_dir, exist_ok=True) + cache_dir.mkdir(parents=True, exist_ok=True) available_files = get_available_files(model_id, token) # If no file patterns are provided, download all available files diff --git a/pyrit/message_normalizer/chat_message_normalizer.py b/pyrit/message_normalizer/chat_message_normalizer.py index c5d3547e80..c9f0698165 100644 --- a/pyrit/message_normalizer/chat_message_normalizer.py +++ b/pyrit/message_normalizer/chat_message_normalizer.py @@ -3,7 +3,7 @@ import base64 import json -import os +from pathlib import Path from typing import TYPE_CHECKING, Any, Union from pyrit.common.data_url_converter import convert_local_image_to_data_url_async @@ -173,7 +173,7 @@ async def _convert_audio_to_input_audio(self, audio_path: str) -> dict[str, Any] audio_format = SUPPORTED_AUDIO_FORMATS[ext] # Read and encode the audio file - if not os.path.isfile(audio_path): + if not Path(audio_path).is_file(): raise FileNotFoundError(f"Audio file not found: {audio_path}") with open(audio_path, "rb") as f: diff --git a/pyrit/models/data_type_serializer.py b/pyrit/models/data_type_serializer.py index 578efca5cc..0ace78f111 100644 --- a/pyrit/models/data_type_serializer.py +++ b/pyrit/models/data_type_serializer.py @@ -6,7 +6,6 @@ import abc import base64 import hashlib -import os import time import wave from mimetypes import guess_type @@ -206,7 +205,7 @@ async def save_formatted_audio( if self._memory.results_storage_io is None: raise RuntimeError("self._memory.results_storage_io is not initialized") await self._memory.results_storage_io.write_file(file_path, audio_data) - os.remove(local_temp_path) + local_temp_path.unlink() # If local, we can just save straight to disk and do not need to delete temp file after else: @@ -344,8 +343,8 @@ def get_extension(file_path: str) -> str | None: str | None: File extension (including dot) or None if unavailable. """ - _, ext = os.path.splitext(file_path) - return ext if ext else None + ext = Path(file_path).suffix + return ext or None @staticmethod def get_mime_type(file_path: str) -> str | None: diff --git a/pyrit/models/seeds/seed_prompt.py b/pyrit/models/seeds/seed_prompt.py index d2b867c105..6047c0e0c9 100644 --- a/pyrit/models/seeds/seed_prompt.py +++ b/pyrit/models/seeds/seed_prompt.py @@ -8,8 +8,8 @@ from __future__ import annotations import logging -import os from dataclasses import dataclass, field +from pathlib import Path from typing import TYPE_CHECKING, Optional, Union from tinytag import TinyTag @@ -21,7 +21,6 @@ if TYPE_CHECKING: import uuid from collections.abc import Sequence - from pathlib import Path from pyrit.models import Message from pyrit.models.literals import ChatMessageRole, PromptDataType @@ -65,8 +64,8 @@ def __post_init__(self) -> None: if not self.data_type: # If data_type is not provided, infer it from the value # Note: Does not assign 'error' or 'url' implicitly - if os.path.isfile(self.value): - _, ext = os.path.splitext(self.value) + if Path(self.value).is_file(): + ext = Path(self.value).suffix ext = ext.lstrip(".").lower() if ext in ["mp4", "avi", "mov", "mkv", "ogv", "flv", "wmv", "webm"]: self.data_type = "video_path" diff --git a/pyrit/prompt_converter/add_image_to_video_converter.py b/pyrit/prompt_converter/add_image_to_video_converter.py index d13b333184..14a237ea19 100644 --- a/pyrit/prompt_converter/add_image_to_video_converter.py +++ b/pyrit/prompt_converter/add_image_to_video_converter.py @@ -3,7 +3,6 @@ import contextlib import logging -import os from pathlib import Path from typing import Optional @@ -183,7 +182,7 @@ async def _add_image_to_video(self, image_path: str, output_path: str) -> str: with contextlib.suppress(cv2.error): cv2.destroyAllWindows() # Not available in headless OpenCV builds if azure_storage_flag: - os.remove(local_temp_path) # type: ignore[ty:possibly-unresolved-reference] + local_temp_path.unlink() # type: ignore[ty:possibly-unresolved-reference] logger.info(f"Video saved as {output_path}") diff --git a/pyrit/prompt_target/common/discover_target_capabilities.py b/pyrit/prompt_target/common/discover_target_capabilities.py index 45600e6009..cce6db7f67 100644 --- a/pyrit/prompt_target/common/discover_target_capabilities.py +++ b/pyrit/prompt_target/common/discover_target_capabilities.py @@ -38,11 +38,11 @@ import asyncio import json import logging -import os import uuid from collections.abc import Awaitable, Callable, Iterable, Iterator from contextlib import contextmanager from dataclasses import replace +from pathlib import Path from pyrit.common.path import DATASETS_PATH from pyrit.models import Message, MessagePiece, PromptDataType @@ -823,7 +823,7 @@ def _create_test_message( asset_path = test_assets.get(modality) if asset_path is None: raise ValueError(f"No test asset configured for modality '{modality}'.") - if not os.path.isfile(asset_path): + if not Path(asset_path).is_file(): raise FileNotFoundError(f"Test asset for modality '{modality}' not found at: {asset_path}") pieces.append( diff --git a/pyrit/prompt_target/http_target/httpx_api_target.py b/pyrit/prompt_target/http_target/httpx_api_target.py index f6af5b15d4..a5da118514 100644 --- a/pyrit/prompt_target/http_target/httpx_api_target.py +++ b/pyrit/prompt_target/http_target/httpx_api_target.py @@ -3,8 +3,8 @@ import logging import mimetypes -import os from collections.abc import Callable +from pathlib import Path from typing import Any, Literal, Optional import httpx @@ -127,10 +127,10 @@ async def _send_prompt_to_target_async(self, *, normalized_conversation: list[Me # If user didn't set file_path, see if the PDF path is in converted_value if not self.file_path: possible_path = message_piece.converted_value - if isinstance(possible_path, str) and os.path.exists(possible_path): + if isinstance(possible_path, str) and Path(possible_path).exists(): logger.info(f"HTTPXApiTarget: auto-using file_path from {possible_path}") self.file_path = possible_path - elif not os.path.exists(self.file_path): + elif not Path(self.file_path).exists(): raise FileNotFoundError(f"File not found: {self.file_path}") if not self.http_url: @@ -140,9 +140,9 @@ async def _send_prompt_to_target_async(self, *, normalized_conversation: list[Me async with httpx.AsyncClient(http2=http2_version, **self.httpx_client_kwargs) as client: try: - if self.file_path and os.path.exists(self.file_path): + if self.file_path and Path(self.file_path).exists(): # Handle file upload (only for POST & PUT) - filename = os.path.basename(self.file_path) + filename = Path(self.file_path).name mime_type = mimetypes.guess_type(filename)[0] or "application/octet-stream" with open(self.file_path, "rb") as fp: diff --git a/pyrit/prompt_target/hugging_face/hugging_face_chat_target.py b/pyrit/prompt_target/hugging_face/hugging_face_chat_target.py index f2d62be82a..f8e9292286 100644 --- a/pyrit/prompt_target/hugging_face/hugging_face_chat_target.py +++ b/pyrit/prompt_target/hugging_face/hugging_face_chat_target.py @@ -4,7 +4,6 @@ import asyncio import json import logging -import os import warnings from pathlib import Path from typing import Any, cast @@ -265,12 +264,12 @@ async def load_model_and_tokenizer(self) -> None: self._load_from_path(self.model_path, **optional_model_kwargs) else: # Define the default Hugging Face cache directory - cache_dir = os.path.join( - os.path.expanduser("~"), - ".cache", - "huggingface", - "hub", - f"models--{(self.model_id or '').replace('/', '--')}", + cache_dir = ( + Path.home() + / ".cache" + / "huggingface" + / "hub" + / f"models--{(self.model_id or '').replace('/', '--')}" ) if self.necessary_files is None: @@ -280,7 +279,7 @@ async def load_model_and_tokenizer(self) -> None: self.model_id or "", None, self.huggingface_token, # type: ignore[ty:invalid-argument-type] - Path(cache_dir), + cache_dir, ) else: # Download only the necessary files diff --git a/pyrit/prompt_target/openai/openai_video_target.py b/pyrit/prompt_target/openai/openai_video_target.py index 544c1e8733..43ad24b879 100644 --- a/pyrit/prompt_target/openai/openai_video_target.py +++ b/pyrit/prompt_target/openai/openai_video_target.py @@ -2,8 +2,8 @@ # Licensed under the MIT license. import logging -import os from mimetypes import guess_type +from pathlib import Path from typing import Any, Optional, Union, cast from openai.types import VideoSeconds, VideoSize @@ -327,7 +327,7 @@ async def _prepare_image_input_async(self, *, image_piece: MessagePiece) -> tupl f"Supported formats: {', '.join(self.SUPPORTED_IMAGE_FORMATS)}" ) - filename = os.path.basename(image_path) + filename = Path(image_path).name return (filename, image_bytes, mime_type) async def _remix_and_poll_async(self, *, video_id: str, prompt: str) -> Any: diff --git a/pyrit/score/audio_transcript_scorer.py b/pyrit/score/audio_transcript_scorer.py index 25343c3a77..02eecb5c91 100644 --- a/pyrit/score/audio_transcript_scorer.py +++ b/pyrit/score/audio_transcript_scorer.py @@ -2,9 +2,9 @@ # Licensed under the MIT license. import logging -import os import tempfile import uuid +from pathlib import Path from typing import Optional import av @@ -170,7 +170,7 @@ async def _score_audio_async(self, *, message_piece: MessagePiece, objective: Op """ audio_path = message_piece.converted_value - if not os.path.exists(audio_path): + if not Path(audio_path).exists(): raise FileNotFoundError(f"Audio file not found: {audio_path}") # Transcribe audio to text @@ -229,10 +229,10 @@ async def _transcribe_audio_async(self, audio_path: str) -> str: logger.info(f"Audio transcription: WAV file path = {wav_path}") # Check if WAV file exists and has content - if not os.path.exists(wav_path): + if not Path(wav_path).exists(): raise FileNotFoundError(f"WAV file does not exist at {wav_path}") - file_size = os.path.getsize(wav_path) + file_size = Path(wav_path).stat().st_size logger.info(f"Audio transcription: WAV file size = {file_size} bytes") try: @@ -246,8 +246,8 @@ async def _transcribe_audio_async(self, audio_path: str) -> str: raise finally: # Clean up temporary WAV file if it exists (ie for scoring audio from videos) - if wav_path != audio_path and os.path.exists(wav_path): - os.unlink(wav_path) + if wav_path != audio_path and Path(wav_path).exists(): + Path(wav_path).unlink() def _ensure_wav_format(self, audio_path: str) -> str: """ diff --git a/pyrit/score/video_scorer.py b/pyrit/score/video_scorer.py index 2450105345..917e5b766a 100644 --- a/pyrit/score/video_scorer.py +++ b/pyrit/score/video_scorer.py @@ -2,10 +2,10 @@ # Licensed under the MIT license. import logging -import os import random import tempfile import uuid +from pathlib import Path from typing import Optional from pyrit.memory import CentralMemory @@ -112,7 +112,7 @@ async def _score_frames_async(self, *, message_piece: MessagePiece, objective: O """ video_path = message_piece.converted_value - if not os.path.exists(video_path): + if not Path(video_path).exists(): raise FileNotFoundError(f"Video file not found: {video_path}") # Extract frames from video @@ -281,5 +281,5 @@ async def _score_video_audio_async( finally: # Clean up temporary audio file on success - if should_cleanup and audio_path and os.path.exists(audio_path): - os.unlink(audio_path) + if should_cleanup and audio_path and Path(audio_path).exists(): + Path(audio_path).unlink() diff --git a/tests/unit/models/test_seed_prompt.py b/tests/unit/models/test_seed_prompt.py index 90f83ccd52..48ae2c353f 100644 --- a/tests/unit/models/test_seed_prompt.py +++ b/tests/unit/models/test_seed_prompt.py @@ -2,7 +2,7 @@ # Licensed under the MIT license. import uuid -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock import pytest @@ -49,18 +49,18 @@ def test_seed_prompt_infers_text_data_type(): (".png", "image_path"), ], ) -@patch("os.path.isfile", return_value=True) -def test_seed_prompt_infers_data_type_from_extension(mock_isfile, extension, expected_type): - with patch("os.path.splitext", return_value=("/path/file", extension)): - sp = SeedPrompt(value=f"/path/file{extension}") - assert sp.data_type == expected_type +def test_seed_prompt_infers_data_type_from_extension(tmp_path, extension, expected_type): + file_path = tmp_path / f"file{extension}" + file_path.touch() + sp = SeedPrompt(value=str(file_path)) + assert sp.data_type == expected_type -@patch("os.path.isfile", return_value=True) -@patch("os.path.splitext", return_value=("/path/file", ".xyz")) -def test_seed_prompt_unknown_file_extension_raises(mock_splitext, mock_isfile): +def test_seed_prompt_unknown_file_extension_raises(tmp_path): + file_path = tmp_path / "file.xyz" + file_path.touch() with pytest.raises(ValueError, match="Unable to infer data_type"): - SeedPrompt(value="/path/file.xyz") + SeedPrompt(value=str(file_path)) def test_seed_prompt_explicit_data_type_not_overridden(): From 11e25f480a4ca06af973c9529f9b44176120cfbe Mon Sep 17 00:00:00 2001 From: Roman Lutz Date: Mon, 1 Jun 2026 12:39:55 -0700 Subject: [PATCH 2/4] MAINT: Treat overlong/invalid SeedPrompt values as text Path(value).is_file() raises OSError (ENAMETOOLONG: Errno 36 on Linux, Errno 63 on macOS) for strings exceeding the filesystem's name limit, and raises ValueError for embedded null bytes. The previous os.path.isfile(value) silently returned False in both cases, so SeedPrompt construction worked for long text-only values such as the academic-paper jailbreak template at pyrit/datasets/jailbreak/templates/Arth_Singh/context_flood_academic.yaml. The pathlib migration regressed this, causing TestJailbreakInitialization / TestJailbreakAttackGeneration to fail on macos-3.11 and the Linux coverage job. Wrap the is_file() probe in try/except (OSError, ValueError) so the inference falls through to data_type='text', preserving the prior os.path.isfile semantics. Add two unit tests in tests/unit/models/test_seed_prompt.py covering the long-value and null-byte cases. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pyrit/models/seeds/seed_prompt.py | 9 ++++++++- tests/unit/models/test_seed_prompt.py | 17 +++++++++++++++++ 2 files changed, 25 insertions(+), 1 deletion(-) diff --git a/pyrit/models/seeds/seed_prompt.py b/pyrit/models/seeds/seed_prompt.py index 6047c0e0c9..fa6b9b59db 100644 --- a/pyrit/models/seeds/seed_prompt.py +++ b/pyrit/models/seeds/seed_prompt.py @@ -64,7 +64,14 @@ def __post_init__(self) -> None: if not self.data_type: # If data_type is not provided, infer it from the value # Note: Does not assign 'error' or 'url' implicitly - if Path(self.value).is_file(): + # Guard against OSError / ValueError so values that aren't valid path + # strings (too long, null bytes, etc.) are treated as text, matching + # the prior os.path.isfile semantics. + try: + is_file = Path(self.value).is_file() + except (OSError, ValueError): + is_file = False + if is_file: ext = Path(self.value).suffix ext = ext.lstrip(".").lower() if ext in ["mp4", "avi", "mov", "mkv", "ogv", "flv", "wmv", "webm"]: diff --git a/tests/unit/models/test_seed_prompt.py b/tests/unit/models/test_seed_prompt.py index 48ae2c353f..e5f584883d 100644 --- a/tests/unit/models/test_seed_prompt.py +++ b/tests/unit/models/test_seed_prompt.py @@ -63,6 +63,23 @@ def test_seed_prompt_unknown_file_extension_raises(tmp_path): SeedPrompt(value=str(file_path)) +def test_seed_prompt_infers_text_for_value_exceeding_path_name_limit(): + # Values longer than the filesystem name limit must be treated as text. + # Path(value).is_file() can raise OSError (ENAMETOOLONG) on Linux/macOS, + # whereas os.path.isfile silently returned False. The inference logic + # must preserve the prior behavior so long-form text values (e.g. an + # academic paper used as a jailbreak template) don't crash construction. + long_value = "JOURNAL OF ARTIFICIAL INTELLIGENCE SAFETY RESEARCH " * 100 + sp = SeedPrompt(value=long_value) + assert sp.data_type == "text" + + +def test_seed_prompt_infers_text_for_value_with_null_byte(): + # Null bytes raise ValueError inside pathlib; treat as text rather than crashing. + sp = SeedPrompt(value="some text with \x00 embedded null") + assert sp.data_type == "text" + + def test_seed_prompt_explicit_data_type_not_overridden(): sp = SeedPrompt(value="some text", data_type="text") assert sp.data_type == "text" From 680a6586641f0556c3ec8c9af93f71e1905e3f25 Mon Sep 17 00:00:00 2001 From: Roman Lutz Date: Mon, 1 Jun 2026 13:24:48 -0700 Subject: [PATCH 3/4] TEST: Add targeted coverage for pathlib migration cleanup branches The diff_cover gate (>=90% on changed lines) was failing because four lines in finally-style cleanup branches added during the os.path -> pathlib migration were not exercised by the existing test suite: - pyrit/models/data_type_serializer.py:208 (local_temp_path.unlink() after Azure storage upload in save_formatted_audio) - pyrit/prompt_converter/add_image_to_video_converter.py:185 (local_temp_path.unlink() in azure_storage_flag finally branch) - pyrit/score/audio_transcript_scorer.py:250 (Path(wav_path).unlink() cleanup when _ensure_wav_format produced a converted temp WAV) - pyrit/score/video_scorer.py:285 (Path(audio_path).unlink() cleanup after successful audio scoring in _score_video_audio_async) Added four tests that patch DB_DATA_PATH / extract_audio_from_video / _ensure_wav_format so the cleanup branches run against real temp files and verify the files are unlinked. Diff coverage now reports 98% on this PR (target: >=90%). Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../unit/models/test_data_type_serializer.py | 24 +++++++++++ .../test_add_image_video_converter.py | 42 +++++++++++++++++++ tests/unit/score/test_audio_scorer.py | 29 +++++++++++++ tests/unit/score/test_video_scorer.py | 22 ++++++++++ 4 files changed, 117 insertions(+) diff --git a/tests/unit/models/test_data_type_serializer.py b/tests/unit/models/test_data_type_serializer.py index d710afd830..b45d0256b1 100644 --- a/tests/unit/models/test_data_type_serializer.py +++ b/tests/unit/models/test_data_type_serializer.py @@ -426,3 +426,27 @@ async def test_get_data_filename_uses_db_data_path_when_results_path_falsy(): result_str = str(result).replace("\\", "/") assert "/fallback/db_data" in result_str assert result_str.endswith(".png") + + +async def test_save_formatted_audio_azure_storage_unlinks_local_temp(tmp_path): + """save_formatted_audio cleans up the local temp WAV after writing to Azure storage.""" + from pyrit.models import data_serializer_factory as factory + + serializer = factory(category="prompt-memory-entries", data_type="audio_path") + mock_memory = MagicMock() + mock_storage_io = AsyncMock() + mock_memory.results_storage_io = mock_storage_io + azure_url = "https://account.blob.core.windows.net/container/audio/test.wav" + + with ( + patch.object(type(serializer), "_memory", new_callable=PropertyMock, return_value=mock_memory), + patch.object(serializer, "get_data_filename", new_callable=AsyncMock, return_value=azure_url), + patch("pyrit.models.data_type_serializer.DB_DATA_PATH", tmp_path), + ): + await serializer.save_formatted_audio(data=b"\x00\x01\x02\x03") + + # The local temp file written via wave.open should have been unlinked after upload. + assert not (tmp_path / "temp_audio.wav").exists() + mock_storage_io.write_file.assert_awaited_once() + assert mock_storage_io.write_file.call_args[0][0] == azure_url + assert serializer.value == azure_url diff --git a/tests/unit/prompt_converter/test_add_image_video_converter.py b/tests/unit/prompt_converter/test_add_image_video_converter.py index a9e55a1045..acea7861cd 100644 --- a/tests/unit/prompt_converter/test_add_image_video_converter.py +++ b/tests/unit/prompt_converter/test_add_image_video_converter.py @@ -123,3 +123,45 @@ def factory_side_effect(*, category, data_type, value): ): with pytest.raises(ValueError, match="Failed to decode overlay image"): await converter._add_image_to_video(image_path="fake_image.png", output_path=output_path) + + +@pytest.mark.skipif(not is_opencv_installed(), reason="opencv is not installed") +async def test_add_image_to_video_azure_storage_unlinks_local_temp( + tmp_path, video_converter_sample_video, video_converter_sample_image +): + """When the video is in Azure storage, the downloaded local temp file is unlinked after processing.""" + from unittest.mock import AsyncMock, patch + + output_path = str(tmp_path / "output_video.mp4") + converter = AddImageVideoConverter(video_path=video_converter_sample_video, output_path=output_path) + + with open(video_converter_sample_video, "rb") as f: + video_bytes = f.read() + with open(video_converter_sample_image, "rb") as f: + image_bytes = f.read() + + mock_image_serializer = AsyncMock() + mock_image_serializer.read_data = AsyncMock(return_value=image_bytes) + mock_image_serializer._is_azure_storage_url = lambda x: False + + mock_video_serializer = AsyncMock() + mock_video_serializer.read_data = AsyncMock(return_value=video_bytes) + # Flag this video as living in Azure storage so the cleanup branch runs. + mock_video_serializer._is_azure_storage_url = lambda x: True + + def factory_side_effect(*, category, data_type, value): + if data_type == "image_path": + return mock_image_serializer + return mock_video_serializer + + with ( + patch( + "pyrit.prompt_converter.add_image_to_video_converter.data_serializer_factory", + side_effect=factory_side_effect, + ), + patch("pyrit.prompt_converter.add_image_to_video_converter.DB_DATA_PATH", tmp_path), + ): + await converter._add_image_to_video(image_path=video_converter_sample_image, output_path=output_path) + + # The local copy of the Azure-stored video should be removed by the cleanup branch. + assert not (tmp_path / "temp_video.mp4").exists() diff --git a/tests/unit/score/test_audio_scorer.py b/tests/unit/score/test_audio_scorer.py index 4162377920..4289bec083 100644 --- a/tests/unit/score/test_audio_scorer.py +++ b/tests/unit/score/test_audio_scorer.py @@ -261,6 +261,35 @@ async def test_transcribe_audio_async_creates_converter(self, audio_message_piec mock_cls.assert_called_once() mock_converter.convert_async.assert_called_once() + async def test_transcribe_audio_async_unlinks_converted_wav(self, audio_message_piece, tmp_path): + """When _ensure_wav_format produces a different temp WAV, that file is cleaned up in finally.""" + from pyrit.score.audio_transcript_scorer import AudioTranscriptHelper + + text_scorer = MockTextTrueFalseScorer() + helper = AudioTranscriptHelper(text_capable_scorer=text_scorer) + + # Create a real temporary WAV distinct from audio_message_piece.converted_value + converted_wav = tmp_path / "converted.wav" + converted_wav.write_bytes(b"fake wav content") + + mock_converter = AsyncMock() + mock_result = AsyncMock() + mock_result.output_text = "transcribed text" + mock_converter.convert_async.return_value = mock_result + + with ( + patch.object(helper, "_ensure_wav_format", return_value=str(converted_wav)), + patch( + "pyrit.score.audio_transcript_scorer.AzureSpeechAudioToTextConverter", + return_value=mock_converter, + ), + ): + result = await helper._transcribe_audio_async(audio_message_piece.converted_value) + + assert result == "transcribed text" + # The converted temp WAV (different from the original audio path) should be deleted. + assert not converted_wav.exists() + class TestPyAVAudioConversion: """Tests for PyAV audio conversion functions""" diff --git a/tests/unit/score/test_video_scorer.py b/tests/unit/score/test_video_scorer.py index ddd5b383d7..7b8106fee5 100644 --- a/tests/unit/score/test_video_scorer.py +++ b/tests/unit/score/test_video_scorer.py @@ -335,6 +335,28 @@ async def test_video_true_false_scorer_with_audio_scorer(video_converter_sample_ assert "visual" in scores[0].score_rationale.lower() or "audio" in scores[0].score_rationale.lower() +@pytest.mark.skipif(not is_opencv_installed(), reason="opencv is not installed") +async def test_video_audio_scorer_cleans_up_extracted_audio(tmp_path, video_converter_sample_video): + """_score_video_audio_async unlinks the temp audio file after successful scoring.""" + image_scorer = MockTrueFalseScorer(return_value=True) + audio_scorer = MockAudioTrueFalseScorer(return_value=True) + + # Create a real temp audio file that should be deleted by the cleanup branch. + extracted_audio = tmp_path / "extracted_audio.wav" + extracted_audio.write_bytes(b"fake audio bytes") + + with patch.object(AudioTranscriptHelper, "extract_audio_from_video", return_value=str(extracted_audio)): + scorer = VideoTrueFalseScorer( + image_capable_scorer=image_scorer, + audio_scorer=audio_scorer, + num_sampled_frames=3, + ) + + await scorer._score_piece_async(video_converter_sample_video) + + assert not extracted_audio.exists() + + @pytest.mark.skipif(not is_opencv_installed(), reason="opencv is not installed") async def test_video_scorer_and_aggregation_both_true(video_converter_sample_video): """Test AND aggregation when both visual and audio scores are true""" From 314f2c2a1986f4c946f0891115b40baec1899c9f Mon Sep 17 00:00:00 2001 From: Roman Lutz Date: Tue, 2 Jun 2026 13:26:50 -0700 Subject: [PATCH 4/4] MAINT: Collapse exists()+unlink() into unlink(missing_ok=True) Applies @jsong468's review suggestions on PR #1877: the two temp-file cleanup branches in audio_transcript_scorer.py and video_scorer.py were probing Path.exists() before calling Path.unlink(). Collapse both to Path(...).unlink(missing_ok=True), which is race-safe (no TOCTOU gap between the existence check and the unlink) and a single syscall. Behavior is unchanged: if the file is absent the call is a no-op; if present it is removed. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pyrit/score/audio_transcript_scorer.py | 4 ++-- pyrit/score/video_scorer.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/pyrit/score/audio_transcript_scorer.py b/pyrit/score/audio_transcript_scorer.py index 02eecb5c91..8667e76c72 100644 --- a/pyrit/score/audio_transcript_scorer.py +++ b/pyrit/score/audio_transcript_scorer.py @@ -246,8 +246,8 @@ async def _transcribe_audio_async(self, audio_path: str) -> str: raise finally: # Clean up temporary WAV file if it exists (ie for scoring audio from videos) - if wav_path != audio_path and Path(wav_path).exists(): - Path(wav_path).unlink() + if wav_path != audio_path: + Path(wav_path).unlink(missing_ok=True) def _ensure_wav_format(self, audio_path: str) -> str: """ diff --git a/pyrit/score/video_scorer.py b/pyrit/score/video_scorer.py index 917e5b766a..e27a772f23 100644 --- a/pyrit/score/video_scorer.py +++ b/pyrit/score/video_scorer.py @@ -281,5 +281,5 @@ async def _score_video_audio_async( finally: # Clean up temporary audio file on success - if should_cleanup and audio_path and Path(audio_path).exists(): - Path(audio_path).unlink() + if should_cleanup and audio_path: + Path(audio_path).unlink(missing_ok=True)