From b11300c6f1e3b79d539754992e42f74bda7ab397 Mon Sep 17 00:00:00 2001 From: Roman Lutz Date: Thu, 23 Apr 2026 22:28:52 -0700 Subject: [PATCH] FIX VLGuard review fixes: document subcategory mapping, move import to top - Add category-to-subcategory mapping in VLGuardSubcategory docstring - Move huggingface_hub import to top of file (transitive dep of datasets) - Update test mock path accordingly Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../seed_datasets/remote/vlguard_dataset.py | 13 ++++++++++--- tests/unit/datasets/test_vlguard_dataset.py | 2 +- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/pyrit/datasets/seed_datasets/remote/vlguard_dataset.py b/pyrit/datasets/seed_datasets/remote/vlguard_dataset.py index db0f7aa76..8dc1f5303 100644 --- a/pyrit/datasets/seed_datasets/remote/vlguard_dataset.py +++ b/pyrit/datasets/seed_datasets/remote/vlguard_dataset.py @@ -9,6 +9,8 @@ from enum import Enum from pathlib import Path +from huggingface_hub import hf_hub_download + from pyrit.common.path import DB_DATA_PATH from pyrit.datasets.seed_datasets.remote.remote_dataset_loader import ( _RemoteDatasetLoader, @@ -37,7 +39,14 @@ class VLGuardCategory(Enum): class VLGuardSubcategory(Enum): - """Subcategories in the VLGuard dataset, nested under the main categories.""" + """ + Subcategories in the VLGuard dataset. Each subcategory belongs to a specific category. + + privacy: personal data + risky behavior: professional advice, political, sexually explicit, violence + deception: disinformation + discrimination: sex, race, other + """ PERSONAL_DATA = "personal data" PROFESSIONAL_ADVICE = "professional advice" @@ -263,8 +272,6 @@ async def _download_dataset_files_async(self, *, cache: bool = True) -> tuple[li Returns: tuple[list[dict], Path]: Tuple of (metadata list, image directory path). """ - from huggingface_hub import hf_hub_download - cache_dir = DB_DATA_PATH / "seed-prompt-entries" / "vlguard" cache_dir.mkdir(parents=True, exist_ok=True) diff --git a/tests/unit/datasets/test_vlguard_dataset.py b/tests/unit/datasets/test_vlguard_dataset.py index 29ad7eb85..0ef8bbad0 100644 --- a/tests/unit/datasets/test_vlguard_dataset.py +++ b/tests/unit/datasets/test_vlguard_dataset.py @@ -409,7 +409,7 @@ def mock_hf_download(*, repo_id, filename, repo_type, local_dir, token): with ( patch("pyrit.datasets.seed_datasets.remote.vlguard_dataset.DB_DATA_PATH", tmp_path), - patch("huggingface_hub.hf_hub_download", side_effect=mock_hf_download), + patch("pyrit.datasets.seed_datasets.remote.vlguard_dataset.hf_hub_download", side_effect=mock_hf_download), ): metadata, result_dir = await loader._download_dataset_files_async(cache=False)