diff --git a/pyrit/prompt_target/__init__.py b/pyrit/prompt_target/__init__.py index 234056c32..c71dca408 100644 --- a/pyrit/prompt_target/__init__.py +++ b/pyrit/prompt_target/__init__.py @@ -20,6 +20,7 @@ UnsupportedCapabilityBehavior, ) from pyrit.prompt_target.common.target_configuration import TargetConfiguration +from pyrit.prompt_target.common.target_requirements import TargetRequirements from pyrit.prompt_target.common.utils import limit_requests_per_minute from pyrit.prompt_target.gandalf_target import GandalfLevel, GandalfTarget from pyrit.prompt_target.http_target.http_target import HTTPTarget @@ -77,6 +78,7 @@ "RealtimeTarget", "TargetCapabilities", "TargetConfiguration", + "TargetRequirements", "UnsupportedCapabilityBehavior", "TextTarget", "WebSocketCopilotTarget", diff --git a/pyrit/prompt_target/common/target_requirements.py b/pyrit/prompt_target/common/target_requirements.py new file mode 100644 index 000000000..95182b47b --- /dev/null +++ b/pyrit/prompt_target/common/target_requirements.py @@ -0,0 +1,54 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from pyrit.prompt_target.common.target_capabilities import CapabilityName + from pyrit.prompt_target.common.target_configuration import TargetConfiguration + + +@dataclass(frozen=True) +class TargetRequirements: + """ + Declarative description of what a consumer (attack, converter, scorer) + requires from a target. + + Consumers define their requirements once and validate them against a + ``TargetConfiguration`` at construction time. This replaces ad-hoc + ``isinstance`` checks and scattered capability branching. + """ + + # The set of capabilities the consumer requires. + required_capabilities: frozenset[CapabilityName] = field(default_factory=frozenset) + + def validate(self, *, configuration: TargetConfiguration) -> None: + """ + Validate that the target configuration can satisfy all requirements. + + Iterates over every required capability and delegates to + ``TargetConfiguration.ensure_can_handle``, which checks native support + first and then consults the handling policy. All violations are + collected and reported in a single ``ValueError``. + + Args: + configuration (TargetConfiguration): The target configuration to validate against. + + Raises: + ValueError: If any required capability is missing and the policy + does not allow adaptation. + """ + errors: list[str] = [] + for capability in sorted(self.required_capabilities, key=lambda c: c.value): + try: + configuration.ensure_can_handle(capability=capability) + except ValueError as exc: + errors.append(str(exc)) + if errors: + raise ValueError( + f"Target does not satisfy {len(errors)} required capability(ies):\n" + + "\n".join(f" - {e}" for e in errors) + ) diff --git a/tests/unit/target/test_target_requirements.py b/tests/unit/target/test_target_requirements.py new file mode 100644 index 000000000..002ccf086 --- /dev/null +++ b/tests/unit/target/test_target_requirements.py @@ -0,0 +1,131 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import pytest + +from pyrit.prompt_target.common.target_capabilities import ( + CapabilityHandlingPolicy, + CapabilityName, + TargetCapabilities, + UnsupportedCapabilityBehavior, +) +from pyrit.prompt_target.common.target_configuration import TargetConfiguration +from pyrit.prompt_target.common.target_requirements import TargetRequirements + + +@pytest.fixture +def adapt_all_policy(): + return CapabilityHandlingPolicy( + behaviors={ + CapabilityName.SYSTEM_PROMPT: UnsupportedCapabilityBehavior.ADAPT, + CapabilityName.MULTI_TURN: UnsupportedCapabilityBehavior.ADAPT, + CapabilityName.JSON_SCHEMA: UnsupportedCapabilityBehavior.RAISE, + CapabilityName.JSON_OUTPUT: UnsupportedCapabilityBehavior.RAISE, + CapabilityName.MULTI_MESSAGE_PIECES: UnsupportedCapabilityBehavior.RAISE, + CapabilityName.EDITABLE_HISTORY: UnsupportedCapabilityBehavior.RAISE, + } + ) + + +# --------------------------------------------------------------------------- +# Construction +# --------------------------------------------------------------------------- + + +def test_init_default_has_empty_capabilities(): + reqs = TargetRequirements() + assert reqs.required_capabilities == frozenset() + + +def test_init_with_capabilities(): + reqs = TargetRequirements( + required_capabilities=frozenset({CapabilityName.MULTI_TURN, CapabilityName.SYSTEM_PROMPT}) + ) + assert CapabilityName.MULTI_TURN in reqs.required_capabilities + assert CapabilityName.SYSTEM_PROMPT in reqs.required_capabilities + + +# --------------------------------------------------------------------------- +# validate — all pass +# --------------------------------------------------------------------------- + + +def test_validate_passes_when_target_supports_all_natively(): + caps = TargetCapabilities(supports_multi_turn=True, supports_system_prompt=True) + config = TargetConfiguration(capabilities=caps) + reqs = TargetRequirements( + required_capabilities=frozenset({CapabilityName.MULTI_TURN, CapabilityName.SYSTEM_PROMPT}) + ) + reqs.validate(configuration=config) + + +def test_validate_passes_when_policy_is_adapt(adapt_all_policy): + caps = TargetCapabilities(supports_multi_turn=False, supports_system_prompt=False) + config = TargetConfiguration(capabilities=caps, policy=adapt_all_policy) + reqs = TargetRequirements( + required_capabilities=frozenset({CapabilityName.MULTI_TURN, CapabilityName.SYSTEM_PROMPT}) + ) + reqs.validate(configuration=config) + + +def test_validate_passes_with_empty_requirements(): + caps = TargetCapabilities(supports_multi_turn=True, supports_system_prompt=True) + config = TargetConfiguration(capabilities=caps) + reqs = TargetRequirements() + reqs.validate(configuration=config) + + +# --------------------------------------------------------------------------- +# validate — failures +# --------------------------------------------------------------------------- + + +def test_validate_raises_when_capability_missing_and_no_policy(): + # EDITABLE_HISTORY has no normalizer and no handling policy — validate raises. + caps = TargetCapabilities(supports_editable_history=False, supports_multi_turn=True, supports_system_prompt=True) + config = TargetConfiguration(capabilities=caps) + reqs = TargetRequirements(required_capabilities=frozenset({CapabilityName.EDITABLE_HISTORY})) + with pytest.raises(ValueError, match="supports_editable_history"): + reqs.validate(configuration=config) + + +def test_validate_raises_when_capability_missing_and_policy_raise(adapt_all_policy): + # json_output is missing and the policy is RAISE — validate raises. + caps = TargetCapabilities(supports_multi_turn=False, supports_system_prompt=False, supports_json_output=False) + config = TargetConfiguration(capabilities=caps, policy=adapt_all_policy) + reqs = TargetRequirements(required_capabilities=frozenset({CapabilityName.JSON_OUTPUT})) + with pytest.raises(ValueError, match="supports_json_output"): + reqs.validate(configuration=config) + + +def test_validate_collects_all_unsatisfied_capabilities(adapt_all_policy): + """When multiple capabilities are missing, validate reports all violations.""" + caps = TargetCapabilities( + supports_multi_turn=False, + supports_system_prompt=False, + supports_json_output=False, + supports_editable_history=False, + ) + config = TargetConfiguration(capabilities=caps, policy=adapt_all_policy) + # json_output => RAISE, editable_history => no policy (raises) + reqs = TargetRequirements( + required_capabilities=frozenset({CapabilityName.JSON_OUTPUT, CapabilityName.EDITABLE_HISTORY}) + ) + with pytest.raises(ValueError, match="2 required capability") as exc_info: + reqs.validate(configuration=config) + assert "supports_json_output" in str(exc_info.value) + assert "supports_editable_history" in str(exc_info.value) + + +def test_validate_mixed_adapt_and_raise(adapt_all_policy): + """One capability adapts but another raises — validate should raise.""" + caps = TargetCapabilities(supports_multi_turn=False, supports_system_prompt=False, supports_json_output=False) + config = TargetConfiguration(capabilities=caps, policy=adapt_all_policy) + # multi_turn and system_prompt => ADAPT (OK), json_output => RAISE (fail) + reqs = TargetRequirements( + required_capabilities=frozenset( + {CapabilityName.MULTI_TURN, CapabilityName.SYSTEM_PROMPT, CapabilityName.JSON_OUTPUT} + ) + ) + with pytest.raises(ValueError, match="supports_json_output"): + reqs.validate(configuration=config)