diff --git a/pyrit/datasets/seed_datasets/remote/__init__.py b/pyrit/datasets/seed_datasets/remote/__init__.py index 284afac5a1..ba271096ce 100644 --- a/pyrit/datasets/seed_datasets/remote/__init__.py +++ b/pyrit/datasets/seed_datasets/remote/__init__.py @@ -39,6 +39,9 @@ from pyrit.datasets.seed_datasets.remote.darkbench_dataset import ( _DarkBenchDataset, ) # noqa: F401 +from pyrit.datasets.seed_datasets.remote.decoding_trust_toxicity_dataset import ( + _DecodingTrustToxicityDataset, +) # noqa: F401 from pyrit.datasets.seed_datasets.remote.equitymedqa_dataset import ( _EquityMedQADataset, ) # noqa: F401 @@ -161,6 +164,7 @@ "ComicJailbreakTemplateConfig", "_DangerousQADataset", "_DarkBenchDataset", + "_DecodingTrustToxicityDataset", "_EquityMedQADataset", "_ForbiddenQuestionsDataset", "_HarmBenchDataset", diff --git a/pyrit/datasets/seed_datasets/remote/decoding_trust_toxicity_dataset.py b/pyrit/datasets/seed_datasets/remote/decoding_trust_toxicity_dataset.py new file mode 100644 index 0000000000..36806aa7ad --- /dev/null +++ b/pyrit/datasets/seed_datasets/remote/decoding_trust_toxicity_dataset.py @@ -0,0 +1,268 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import logging +from typing import Any, Literal, cast + +from pyrit.datasets.seed_datasets.remote.remote_dataset_loader import ( + _RemoteDatasetLoader, +) +from pyrit.models import SeedDataset, SeedPrompt + +logger = logging.getLogger(__name__) + + +# Pinned commit SHA of AI-secure/DecodingTrust `main` (2024-09-16). +# Pinning prevents silent upstream changes from altering the prompt set. +_DECODING_TRUST_COMMIT = "161ae8321ced62f45fcd9ceb412e05b47c603cd4" +_BASE_URL = ( + f"https://raw.githubusercontent.com/AI-secure/DecodingTrust/{_DECODING_TRUST_COMMIT}/data/toxicity/user_prompts/" +) +_TOXIC_URL = f"{_BASE_URL}toxic.jsonl" +_NONTOXIC_URL = f"{_BASE_URL}nontoxic.jsonl" + +# Mapping from Perspective API score keys (as published in the DecodingTrust JSONL) +# to harm_categories strings exposed on each SeedPrompt. +_PERSPECTIVE_SCORE_KEYS: tuple[str, ...] = ( + "toxicity", + "severe_toxicity", + "profanity", + "sexually_explicit", + "identity_attack", + "threat", + "insult", + "flirtation", +) + + +class _DecodingTrustToxicityDataset(_RemoteDatasetLoader): + """ + Loader for the Toxicity perspective of the DecodingTrust benchmark. + + DecodingTrust [1] evaluates LLM trustworthiness across eight perspectives; + the Toxicity perspective publishes ~2.4k prompts sampled from + RealToxicityPrompts [2] and re-scored with the Perspective API. Each record + contains a plain-text prompt fragment, eight Perspective-API scores + (toxicity, severe_toxicity, profanity, sexually_explicit, identity_attack, + threat, insult, flirtation) and a ``challenging`` flag for the + high-difficulty subset. + + The published data lives in two files at + ``AI-secure/DecodingTrust/data/toxicity/user_prompts/`` — ``toxic.jsonl`` + (1196 records, high toxicity) and ``nontoxic.jsonl`` (1200 records, low + toxicity). The loader fetches the source files at runtime from + ``raw.githubusercontent.com`` (no redistribution) at a pinned commit SHA. + + Per-prompt ``harm_categories`` are derived from the Perspective scores at + construction time using ``toxicity_threshold``: any score + ``>= toxicity_threshold`` adds the corresponding category. This avoids + guessing where the source provides no label. + + References: + - Wang et al., 2023. *DecodingTrust: A Comprehensive Assessment of + Trustworthiness in GPT Models.* https://arxiv.org/abs/2306.11698 + - Gehman et al., 2020. *RealToxicityPrompts: Evaluating Neural Toxic + Degeneration in Language Models.* https://arxiv.org/abs/2009.11462 + - https://github.com/AI-secure/DecodingTrust + + License: + DecodingTrust is distributed under CC BY-SA 4.0. PyRIT fetches the + prompts at runtime and does not redistribute them. Full attribution to + the DecodingTrust and RealToxicityPrompts authors is recorded on every + ``SeedPrompt`` produced. + + Warning: + The toxic subset contains profanity, sexual content, identity attacks + and other content that may be offensive. Use only for authorised + red-teaming and safety research. + """ + + # Class-level metadata picked up by _RemoteDatasetLoader._parse_metadata. + # See pyrit/datasets/seed_datasets/seed_metadata.py for the schema. + modalities: list[str] = ["text"] + size: str = "large" # 1196 (default toxic) – 2396 (subset="all") + tags: set[str] = {"default", "safety", "toxicity"} + + _AUTHORS: tuple[str, ...] = ( + # DecodingTrust authors + "Boxin Wang", + "Weixin Chen", + "Hengzhi Pei", + "Chulin Xie", + "Mintong Kang", + "Chenhui Zhang", + "Chejian Xu", + "Zidi Xiong", + "Ritik Dutta", + "Rylan Schaeffer", + "Sang T. Truong", + "Simran Arora", + "Mantas Mazeika", + "Dan Hendrycks", + "Zinan Lin", + "Yu Cheng", + "Sanmi Koyejo", + "Dawn Song", + "Bo Li", + # RealToxicityPrompts authors (DecodingTrust subsamples from RTP) + "Samuel Gehman", + "Suchin Gururangan", + "Maarten Sap", + "Yejin Choi", + "Noah A. Smith", + ) + + _GROUPS: tuple[str, ...] = ( + "University of Illinois Urbana-Champaign", + "Stanford University", + "University of California, Berkeley", + "Center for AI Safety", + "Microsoft Research", + "Allen Institute for Artificial Intelligence", + "University of Washington", + "Paul G. Allen School of Computer Science & Engineering", + ) + + _DESCRIPTION = ( + "Toxicity perspective of the DecodingTrust benchmark (Wang et al., 2023). " + "Prompts are sampled from RealToxicityPrompts (Gehman et al., 2020) and " + "re-scored with the Perspective API. Each record carries eight Perspective " + "scores (toxicity, severe_toxicity, profanity, sexually_explicit, " + "identity_attack, threat, insult, flirtation) and a `challenging` flag. " + "Per-prompt harm_categories are derived from these scores at load time." + ) + + def __init__( + self, + *, + subset: Literal["toxic", "nontoxic", "all"] = "toxic", + challenging_only: bool = False, + toxicity_threshold: float = 0.5, + ) -> None: + """ + Initialize the DecodingTrust Toxicity dataset loader. + + Args: + subset: Which subset of the toxicity prompts to load. ``"toxic"`` + (default) loads the ~1.2k high-toxicity prompts, ``"nontoxic"`` + loads the ~1.2k low-toxicity prompts, ``"all"`` loads both. + The default matches the red-teaming use case + (cf. microsoft/PyRIT#1798). + challenging_only: If True, retain only records flagged + ``challenging: true`` in the source — the adversarial subset + emphasised by the DecodingTrust authors. + toxicity_threshold: Perspective score at or above which a category + is added to a SeedPrompt's ``harm_categories``. Applied + independently to each of the eight Perspective scores. + """ + self.subset: Literal["toxic", "nontoxic", "all"] = subset + self.challenging_only = challenging_only + self.toxicity_threshold = toxicity_threshold + + @property + def dataset_name(self) -> str: + """Return the dataset name.""" + return "decoding_trust_toxicity" + + async def fetch_dataset_async(self, *, cache: bool = True) -> SeedDataset: + """ + Fetch the DecodingTrust Toxicity prompts and return them as a SeedDataset. + + Args: + cache: Whether to cache the fetched JSONL files locally. Defaults to True. + + Returns: + SeedDataset: A SeedDataset whose seeds are the selected toxicity prompts. + + Raises: + ValueError: If the source JSONL contains items that are not dicts. + """ + logger.info(f"Loading DecodingTrust Toxicity subset={self.subset!r} from {_BASE_URL}") + + urls: list[str] = [] + if self.subset in ("toxic", "all"): + urls.append(_TOXIC_URL) + if self.subset in ("nontoxic", "all"): + urls.append(_NONTOXIC_URL) + + records: list[tuple[str, dict[str, Any]]] = [] + for url in urls: + raw = self._fetch_from_url(source=url, source_type="public_url", cache=cache) + records.extend((url, item) for item in cast("list[Any]", raw)) + + seed_prompts = self._records_to_seed_prompts(records=records) + logger.info(f"Loaded {len(seed_prompts)} prompts from DecodingTrust Toxicity") + return SeedDataset(seeds=seed_prompts, dataset_name=self.dataset_name) + + def _records_to_seed_prompts(self, *, records: list[tuple[str, dict[str, Any]]]) -> list[SeedPrompt]: + """ + Convert raw JSONL records into SeedPrompt instances. + + Args: + records: List of ``(source_url, record_dict)`` pairs straight from the + JSONL files. The ``source_url`` becomes the per-prompt ``source``. + + Returns: + List of SeedPrompt objects, one per record that passes filters and + yields a non-empty ``prompt.text``. + + Raises: + ValueError: If any record is not a dict (i.e. the source is malformed). + """ + seed_prompts: list[SeedPrompt] = [] + for source_url, item in records: + if not isinstance(item, dict): + raise ValueError( + f"Expected DecodingTrust toxicity records to be JSON objects, got {type(item).__name__}: {item!r}" + ) + + if self.challenging_only and not item.get("challenging"): + continue + + prompt_obj = item.get("prompt") or {} + if not isinstance(prompt_obj, dict): + logger.warning(f"Skipping record with non-dict 'prompt' field (type={type(prompt_obj).__name__})") + continue + + text = prompt_obj.get("text") + if not isinstance(text, str) or not text: + logger.warning("Skipping record with missing or empty 'prompt.text'") + continue + + seed_prompts.append( + SeedPrompt( + value=text, + data_type="text", + dataset_name=self.dataset_name, + harm_categories=self._derive_harm_categories(prompt_scores=prompt_obj), + description=self._DESCRIPTION, + source=source_url, + authors=list(self._AUTHORS), + groups=list(self._GROUPS), + ) + ) + return seed_prompts + + def _derive_harm_categories(self, *, prompt_scores: dict[str, Any]) -> list[str]: + """ + Map Perspective API scores on a single prompt to a list of harm categories. + + Each Perspective score whose value is ``>= self.toxicity_threshold`` + contributes its key as a harm category. Non-numeric or missing scores + are silently skipped — the source occasionally omits dimensions. + + Args: + prompt_scores: The ``prompt`` sub-dict from a DecodingTrust toxicity + record, containing the eight Perspective scores plus ``text``. + + Returns: + Sorted list of harm category strings. Sorted for deterministic + output (tests, snapshots). + """ + categories: list[str] = [] + for key in _PERSPECTIVE_SCORE_KEYS: + score = prompt_scores.get(key) + if isinstance(score, (int, float)) and score >= self.toxicity_threshold: + categories.append(key) + categories.sort() + return categories diff --git a/tests/unit/datasets/test_decoding_trust_toxicity_dataset.py b/tests/unit/datasets/test_decoding_trust_toxicity_dataset.py new file mode 100644 index 0000000000..fbf4df984b --- /dev/null +++ b/tests/unit/datasets/test_decoding_trust_toxicity_dataset.py @@ -0,0 +1,213 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from unittest.mock import patch + +import pytest + +from pyrit.datasets.seed_datasets.remote.decoding_trust_toxicity_dataset import ( + _NONTOXIC_URL, + _TOXIC_URL, + _DecodingTrustToxicityDataset, +) +from pyrit.models import SeedDataset, SeedPrompt + + +def _record( + *, + text: str, + challenging: bool = True, + toxicity: float = 0.9, + severe_toxicity: float = 0.7, + profanity: float = 0.8, + sexually_explicit: float = 0.2, + identity_attack: float = 0.1, + threat: float = 0.1, + insult: float = 0.6, + flirtation: float = 0.1, +) -> dict: + """Build a DecodingTrust-shaped record for use in test fixtures.""" + return { + "filename": "fake.txt", + "begin": 0, + "end": len(text), + "challenging": challenging, + "prompt": { + "text": text, + "toxicity": toxicity, + "severe_toxicity": severe_toxicity, + "profanity": profanity, + "sexually_explicit": sexually_explicit, + "identity_attack": identity_attack, + "threat": threat, + "insult": insult, + "flirtation": flirtation, + }, + "continuation": {"text": " continuation"}, + } + + +@pytest.fixture +def toxic_records(): + return [ + _record(text="toxic one", toxicity=0.91, profanity=0.85, insult=0.7), + _record(text="toxic two", toxicity=0.95, profanity=0.4, insult=0.3), + ] + + +@pytest.fixture +def nontoxic_records(): + return [ + _record( + text="nontoxic one", + challenging=False, + toxicity=0.1, + severe_toxicity=0.05, + profanity=0.05, + sexually_explicit=0.05, + identity_attack=0.05, + threat=0.05, + insult=0.05, + flirtation=0.05, + ), + ] + + +class TestDecodingTrustToxicityDataset: + """Test the DecodingTrust Toxicity dataset loader.""" + + async def test_default_subset_is_toxic_only(self, toxic_records): + """Default constructor must fetch only toxic.jsonl (maintainer-requested default).""" + loader = _DecodingTrustToxicityDataset() + + with patch.object(loader, "_fetch_from_url", return_value=toxic_records) as mock_fetch: + dataset = await loader.fetch_dataset_async() + + assert mock_fetch.call_count == 1 + mock_fetch.assert_called_once_with(source=_TOXIC_URL, source_type="public_url", cache=True) + assert isinstance(dataset, SeedDataset) + assert len(dataset.seeds) == 2 + + async def test_subset_nontoxic_only(self, nontoxic_records): + """subset='nontoxic' fetches only the nontoxic URL.""" + loader = _DecodingTrustToxicityDataset(subset="nontoxic") + + with patch.object(loader, "_fetch_from_url", return_value=nontoxic_records) as mock_fetch: + dataset = await loader.fetch_dataset_async() + + mock_fetch.assert_called_once_with(source=_NONTOXIC_URL, source_type="public_url", cache=True) + assert len(dataset.seeds) == 1 + + async def test_subset_all_fetches_both_urls(self, toxic_records, nontoxic_records): + """subset='all' fetches both URLs and concatenates.""" + loader = _DecodingTrustToxicityDataset(subset="all") + + with patch.object( + loader, + "_fetch_from_url", + side_effect=[toxic_records, nontoxic_records], + ) as mock_fetch: + dataset = await loader.fetch_dataset_async() + + assert mock_fetch.call_count == 2 + assert {call.kwargs["source"] for call in mock_fetch.call_args_list} == {_TOXIC_URL, _NONTOXIC_URL} + assert len(dataset.seeds) == len(toxic_records) + len(nontoxic_records) + + async def test_harm_category_mapping_threshold(self, toxic_records): + """Each Perspective score >= threshold contributes its key to harm_categories.""" + loader = _DecodingTrustToxicityDataset(toxicity_threshold=0.5) + + with patch.object(loader, "_fetch_from_url", return_value=toxic_records): + dataset = await loader.fetch_dataset_async() + + first = dataset.seeds[0] + assert isinstance(first, SeedPrompt) + # toxicity=0.91, severe_toxicity=0.7, profanity=0.85, insult=0.7 → all in + # sexually_explicit=0.2, identity_attack=0.1, threat=0.1, flirtation=0.1 → out + assert first.harm_categories is not None + assert sorted(first.harm_categories) == ["insult", "profanity", "severe_toxicity", "toxicity"] + # Second record has lower profanity/insult — only toxicity & severe_toxicity stay + second = dataset.seeds[1] + assert second.harm_categories is not None + assert sorted(second.harm_categories) == ["severe_toxicity", "toxicity"] + + async def test_high_threshold_yields_empty_categories(self, nontoxic_records): + """At a very high threshold, nontoxic prompts get empty harm_categories.""" + loader = _DecodingTrustToxicityDataset(subset="nontoxic", toxicity_threshold=0.99) + + with patch.object(loader, "_fetch_from_url", return_value=nontoxic_records): + dataset = await loader.fetch_dataset_async() + + assert dataset.seeds[0].harm_categories == [] + + async def test_challenging_only_filter(self): + """challenging_only=True drops records with challenging=False.""" + records = [ + _record(text="keep me", challenging=True), + _record(text="drop me", challenging=False), + ] + loader = _DecodingTrustToxicityDataset(challenging_only=True) + + with patch.object(loader, "_fetch_from_url", return_value=records): + dataset = await loader.fetch_dataset_async() + + assert [seed.value for seed in dataset.seeds] == ["keep me"] + + async def test_skips_records_missing_prompt_text(self): + """Records with missing or empty prompt.text are skipped, not failed.""" + records = [ + _record(text="ok"), + {"challenging": True, "prompt": {}}, # missing text entirely + {"challenging": True, "prompt": {"text": ""}}, # empty string + {"challenging": True, "prompt": "not a dict"}, # malformed prompt field + ] + loader = _DecodingTrustToxicityDataset() + + with patch.object(loader, "_fetch_from_url", return_value=records): + dataset = await loader.fetch_dataset_async() + + assert [seed.value for seed in dataset.seeds] == ["ok"] + + async def test_raises_on_non_dict_record(self): + """Top-level non-dict records (string, number, list) are a hard error.""" + loader = _DecodingTrustToxicityDataset() + + with patch.object(loader, "_fetch_from_url", return_value=["this should be a dict"]): + with pytest.raises(ValueError, match="JSON objects"): + await loader.fetch_dataset_async() + + async def test_per_seed_metadata(self, toxic_records): + """Every SeedPrompt carries dataset_name, source, description, authors, groups.""" + loader = _DecodingTrustToxicityDataset() + + with patch.object(loader, "_fetch_from_url", return_value=toxic_records): + dataset = await loader.fetch_dataset_async() + + seed = dataset.seeds[0] + assert seed.dataset_name == "decoding_trust_toxicity" + assert seed.source == _TOXIC_URL + assert seed.data_type == "text" + assert seed.description is not None and "DecodingTrust" in seed.description + assert seed.authors is not None and "Boxin Wang" in seed.authors # DecodingTrust lead + assert "Samuel Gehman" in seed.authors # RealToxicityPrompts lead + assert seed.groups is not None and "University of Illinois Urbana-Champaign" in seed.groups + + def test_dataset_name(self): + """dataset_name property returns the canonical id.""" + assert _DecodingTrustToxicityDataset().dataset_name == "decoding_trust_toxicity" + + def test_default_source_is_pinned_commit(self): + """The fetcher URLs must reference a pinned commit SHA, not main.""" + from pyrit.datasets.seed_datasets.remote import decoding_trust_toxicity_dataset as mod + + assert mod._DECODING_TRUST_COMMIT == "161ae8321ced62f45fcd9ceb412e05b47c603cd4" + assert mod._DECODING_TRUST_COMMIT in _TOXIC_URL + assert mod._DECODING_TRUST_COMMIT in _NONTOXIC_URL + assert _TOXIC_URL.endswith("/data/toxicity/user_prompts/toxic.jsonl") + assert _NONTOXIC_URL.endswith("/data/toxicity/user_prompts/nontoxic.jsonl") + + def test_class_level_metadata(self): + """modalities/size/tags are set for the metadata parser.""" + assert _DecodingTrustToxicityDataset.modalities == ["text"] + assert _DecodingTrustToxicityDataset.size == "large" + assert _DecodingTrustToxicityDataset.tags == {"default", "safety", "toxicity"}