diff --git a/pyrit/backend/services/scenario_run_service.py b/pyrit/backend/services/scenario_run_service.py index 9a36bde86e..b770ac9722 100644 --- a/pyrit/backend/services/scenario_run_service.py +++ b/pyrit/backend/services/scenario_run_service.py @@ -469,24 +469,30 @@ def _resolve_converter_modifiers(self, *, modifiers: list[str], token: str) -> l async def _initialize_scenario_async(self, *, request: RunScenarioRequest, init_kwargs: dict[str, Any]) -> Scenario: """ - Instantiate the scenario and call initialize_async. + Build and initialize the scenario via the registry. + + Delegates the full create + set-parameters + initialize lifecycle to + ``ScenarioRegistry.create_and_initialize_async`` so the registry owns + scenario creation and initialization. The run-specific common parameters + (target, strategies, dataset config, concurrency) are resolved by + ``_build_init_kwargs`` and forwarded as ``init_kwargs``. Args: request: The run request (for scenario_name, scenario_params, and scenario_result_id). - init_kwargs: The kwargs to pass to scenario.initialize_async. + init_kwargs: The resolved common parameters to pass to + scenario.initialize_async. Returns: The fully initialized Scenario instance ready for run_async. """ - constructor_kwargs: dict[str, Any] = {} - if request.scenario_result_id: - constructor_kwargs["scenario_result_id"] = request.scenario_result_id scenario_registry = ScenarioRegistry.get_registry_singleton() - scenario = scenario_registry.create_instance(request.scenario_name, **constructor_kwargs) - scenario.set_params_from_args(args=request.scenario_params or {}) - await scenario.initialize_async(**init_kwargs) - return scenario + return await scenario_registry.create_and_initialize_async( + request.scenario_name, + scenario_params=request.scenario_params or {}, + scenario_result_id=request.scenario_result_id or None, + **init_kwargs, + ) async def _execute_run_async(self, *, scenario_result_id: str) -> None: """ @@ -579,8 +585,8 @@ def _build_response_from_db(self, *, scenario_result: ScenarioResult) -> Scenari return ScenarioRunSummary( scenario_result_id=scenario_result_id, - scenario_name=scenario_result.scenario_identifier.name, - scenario_version=scenario_result.scenario_identifier.version, + scenario_name=scenario_result.scenario_name, + scenario_version=scenario_result.scenario_version, status=status, created_at=scenario_result.creation_time, updated_at=scenario_result.completion_time or scenario_result.creation_time, diff --git a/pyrit/backend/services/scenario_service.py b/pyrit/backend/services/scenario_service.py index 3bc1fdce6f..66f0e4f0cd 100644 --- a/pyrit/backend/services/scenario_service.py +++ b/pyrit/backend/services/scenario_service.py @@ -36,7 +36,6 @@ def _metadata_to_registered_scenario(metadata: ScenarioMetadata) -> RegisteredSc aggregate_strategies=list(metadata.aggregate_strategies), all_strategies=list(metadata.all_strategies), default_datasets=list(metadata.default_datasets), - max_dataset_size=metadata.max_dataset_size, supported_parameters=list(metadata.supported_parameters), ) @@ -68,7 +67,7 @@ async def list_scenarios_async( Returns: ScenarioListResponse with paginated scenario summaries. """ - all_metadata = self._registry.list_metadata() + all_metadata = self._registry.get_all_registered_class_metadata() all_summaries = [_metadata_to_registered_scenario(m) for m in all_metadata] page, has_more = self._paginate(items=all_summaries, cursor=cursor, limit=limit) @@ -76,7 +75,12 @@ async def list_scenarios_async( return ListRegisteredScenariosResponse( items=page, - pagination=PaginationInfo(limit=limit, has_more=has_more, next_cursor=next_cursor, prev_cursor=cursor), + pagination=PaginationInfo( + limit=limit, + has_more=has_more, + next_cursor=next_cursor, + prev_cursor=cursor, + ), ) async def get_scenario_async(self, *, scenario_name: str) -> RegisteredScenario | None: @@ -89,10 +93,9 @@ async def get_scenario_async(self, *, scenario_name: str) -> RegisteredScenario Returns: ScenarioSummary if found, None otherwise. """ - all_metadata = self._registry.list_metadata() - for metadata in all_metadata: - if metadata.registry_name == scenario_name: - return _metadata_to_registered_scenario(metadata) + metadata = self._registry.get_registered_class_metadata(scenario_name) + if metadata is not None: + return _metadata_to_registered_scenario(metadata) return None @staticmethod diff --git a/pyrit/cli/_output.py b/pyrit/cli/_output.py index 70986c4ec1..3b80160b92 100644 --- a/pyrit/cli/_output.py +++ b/pyrit/cli/_output.py @@ -106,8 +106,7 @@ def print_scenario_list(*, items: list[RegisteredScenario]) -> None: if sc.default_strategy: print(f" Default Strategy: {sc.default_strategy}") if sc.default_datasets: - suffix = f", max {sc.max_dataset_size} per dataset" if sc.max_dataset_size else "" - print(f" Default Datasets ({len(sc.default_datasets)}{suffix}):") + print(f" Default Datasets ({len(sc.default_datasets)}):") print(_wrap(text=", ".join(sc.default_datasets), indent=" ")) if sc.supported_parameters: print(" Supported Parameters:") diff --git a/pyrit/memory/alembic/versions/d4e6f8a0b2c4_rename_scenario_init_data_to_scenario_identifier.py b/pyrit/memory/alembic/versions/d4e6f8a0b2c4_rename_scenario_init_data_to_scenario_identifier.py new file mode 100644 index 0000000000..71f39169fd --- /dev/null +++ b/pyrit/memory/alembic/versions/d4e6f8a0b2c4_rename_scenario_init_data_to_scenario_identifier.py @@ -0,0 +1,52 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Rename ScenarioResultEntries.scenario_init_data to scenario_identifier. + +The scenario result now stores a single canonical ``ScenarioIdentifier`` in +place of the loose ``scenario_init_data`` blob. Rename the column and make it +non-nullable to match the model. + +Revision ID: d4e6f8a0b2c4 +Revises: c3d5e7f9a1b2 +Create Date: 2026-07-02 10:35:00.000000 +""" + +from collections.abc import Sequence + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "d4e6f8a0b2c4" +down_revision: str | None = "c3d5e7f9a1b2" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + """Apply this schema upgrade.""" + # SQLite does not support ALTER COLUMN in place; batch_alter_table recreates + # the table so the rename and NOT NULL change are portable across SQLite and + # Azure SQL. + with op.batch_alter_table("ScenarioResultEntries") as batch_op: + batch_op.alter_column( + "scenario_init_data", + new_column_name="scenario_identifier", + existing_type=sa.JSON(), + existing_nullable=True, + nullable=False, + ) + + +def downgrade() -> None: + """Revert this schema upgrade.""" + with op.batch_alter_table("ScenarioResultEntries") as batch_op: + batch_op.alter_column( + "scenario_identifier", + new_column_name="scenario_init_data", + existing_type=sa.JSON(), + existing_nullable=False, + nullable=True, + ) diff --git a/pyrit/memory/memory_models.py b/pyrit/memory/memory_models.py index 0cf4e3782d..a7cfdbf552 100644 --- a/pyrit/memory/memory_models.py +++ b/pyrit/memory/memory_models.py @@ -45,6 +45,7 @@ EvaluationIdentifier, MessagePiece, PromptDataType, + ScenarioEvaluationIdentifier, ScenarioIdentifier, ScenarioResult, ScenarioRunState, @@ -1102,7 +1103,8 @@ class ScenarioResultEntry(Base): scenario_description (str): Optional detailed description of the scenario. scenario_version (int): Version number of the scenario definition (default: 1). pyrit_version (str): Version of PyRIT framework used during scenario execution. - scenario_init_data (dict): Optional initialization parameters used to configure the scenario. + scenario_identifier (dict): Canonical scenario identity (class name, version, + techniques, datasets, resolved params, objective target / scorer children). objective_target_identifier (dict): Identifier for the target being evaluated in the scenario. objective_scorer_identifier (dict): Optional identifier for the scorer used to evaluate results. scenario_run_state (str): Current execution state of the scenario @@ -1130,7 +1132,9 @@ class ScenarioResultEntry(Base): scenario_description = mapped_column(Unicode, nullable=True) scenario_version = mapped_column(INTEGER, nullable=False, default=1) pyrit_version = mapped_column(String, nullable=False) - scenario_init_data: Mapped[dict[str, Any] | None] = mapped_column(JSON, nullable=True) + #: Canonical scenario identity (class name, version, techniques, datasets, + #: resolved params, objective target / scorer children) with its eval hash. + scenario_identifier: Mapped[dict[str, Any]] = mapped_column(JSON, nullable=False) objective_target_identifier: Mapped[dict[str, Any]] = mapped_column(JSON, nullable=False) objective_scorer_identifier: Mapped[dict[str, Any] | None] = mapped_column(JSON, nullable=True) scenario_run_state: Mapped[str] = mapped_column(String, nullable=False, default="CREATED") @@ -1161,24 +1165,32 @@ def __init__(self, *, entry: ScenarioResult) -> None: entry (ScenarioResult): The scenario result object to convert into a database entry. """ self.id = entry.id - self.scenario_name = entry.scenario_identifier.name - self.scenario_description = entry.scenario_identifier.description - self.scenario_version = entry.scenario_identifier.version - self.pyrit_version = entry.scenario_identifier.pyrit_version - self.scenario_init_data = entry.scenario_identifier.init_data + self.scenario_name = entry.scenario_name + self.scenario_description = entry.scenario_description + self.scenario_version = entry.scenario_version + self.pyrit_version = entry.pyrit_version + + # Stamp the canonical scenario identifier's eval_hash fresh and store it. + # The denormalized target / scorer columns are populated from the same + # identifier for DB-level filtering (never a value trusted from storage). + scenario_identifier = entry.scenario_identifier.with_eval_hash( + ScenarioEvaluationIdentifier(entry.scenario_identifier).eval_hash + ) + self.scenario_identifier = scenario_identifier.model_dump() + # Convert ComponentIdentifier to dict for JSON storage + target_identifier = entry.objective_target_identifier self.objective_target_identifier = ( # type: ignore[ty:invalid-assignment] - entry.objective_target_identifier.model_dump() if entry.objective_target_identifier else None + target_identifier.model_dump() if target_identifier else None ) # Always recompute eval_hash before dumping so the stored JSON carries the # freshly computed value for DB-level filtering (never a value from storage). - if entry.objective_scorer_identifier: - entry.objective_scorer_identifier = entry.objective_scorer_identifier.with_eval_hash( - ScorerEvaluationIdentifier(entry.objective_scorer_identifier).eval_hash + scorer_identifier = entry.objective_scorer_identifier + if scorer_identifier: + scorer_identifier = scorer_identifier.with_eval_hash( + ScorerEvaluationIdentifier(scorer_identifier).eval_hash ) - self.objective_scorer_identifier = ( - entry.objective_scorer_identifier.model_dump() if entry.objective_scorer_identifier else None - ) + self.objective_scorer_identifier = scorer_identifier.model_dump() if scorer_identifier else None self.scenario_run_state = entry.scenario_run_state.value self.labels = entry.labels self.number_tries = entry.number_tries @@ -1211,29 +1223,22 @@ def get_scenario_result(self) -> ScenarioResult: Returns: ScenarioResult object with scenario metadata but empty attack_results """ - # Recreate ScenarioIdentifier with the stored pyrit_version + # The canonical scenario identity (name / version / techniques / datasets / + # params / target / scorer children) is stored as one JSON column and + # reconstructed here as a typed ScenarioIdentifier. eval_hash is recomputed + # on reload (never trusted from storage). The denormalized target / scorer + # columns exist only for DB-level filtering, so they aren't read back here. stored_version = self.pyrit_version or LEGACY_PYRIT_VERSION - scenario_identifier = ScenarioIdentifier( - name=self.scenario_name, - description=self.scenario_description or "", - scenario_version=self.scenario_version, - init_data=self.scenario_init_data, - pyrit_version=stored_version, - ) # Return empty attack_results - will be populated by memory_interface attack_results: dict[str, list[AttackResult]] = {} - # Convert dict back to ComponentIdentifier with the stored pyrit_version; - # eval_hash is recomputed on reload via ScorerEvaluationIdentifier. - scorer_identifier = _load_identifier( - self.objective_scorer_identifier, - pyrit_version=stored_version, - eval_identifier_cls=ScorerEvaluationIdentifier, + base_identifier = ComponentIdentifier.model_validate( + {**self.scenario_identifier, "pyrit_version": stored_version} + ) + scenario_identifier = ScenarioIdentifier.from_component_identifier( + base_identifier.with_eval_hash(ScenarioEvaluationIdentifier(base_identifier).eval_hash) ) - - # Convert dict back to ComponentIdentifier for reconstruction - target_identifier = _load_identifier(self.objective_target_identifier) # Deserialize display_group_map if stored display_group_map: dict[str, str] | None = None @@ -1243,9 +1248,8 @@ def get_scenario_result(self) -> ScenarioResult: return ScenarioResult( id=self.id, scenario_identifier=scenario_identifier, - objective_target_identifier=target_identifier, + scenario_description=self.scenario_description or "", attack_results=attack_results, - objective_scorer_identifier=scorer_identifier, scenario_run_state=ScenarioRunState(self.scenario_run_state), labels=self.labels or {}, creation_time=self.timestamp, diff --git a/pyrit/models/__init__.py b/pyrit/models/__init__.py index a0ab2296a9..a5ad9523bb 100644 --- a/pyrit/models/__init__.py +++ b/pyrit/models/__init__.py @@ -41,6 +41,8 @@ IdentifierFilter, IdentifierType, ObjectiveTargetEvaluationIdentifier, + ScenarioEvaluationIdentifier, + ScenarioIdentifier, ScorerEvaluationIdentifier, ScorerIdentifier, SeedIdentifier, @@ -95,7 +97,7 @@ ) from pyrit.models.question_answering import QuestionAnsweringDataset, QuestionAnsweringEntry, QuestionChoice from pyrit.models.results.attack_result import AttackOutcome, AttackResult, AttackResultT -from pyrit.models.results.scenario_result import ScenarioIdentifier, ScenarioResult, ScenarioRunState +from pyrit.models.results.scenario_result import ScenarioResult, ScenarioRunState from pyrit.models.results.strategy_result import StrategyResult, StrategyResultT from pyrit.models.retry_event import RetryEvent from pyrit.models.score import Score, ScoreType, UnvalidatedScore @@ -194,6 +196,7 @@ "ScaleDescription", "Score", "ScoreType", + "ScenarioEvaluationIdentifier", "ScorerEvaluationIdentifier", "ScorerIdentifier", "ScenarioIdentifier", diff --git a/pyrit/models/catalog/scenario.py b/pyrit/models/catalog/scenario.py index 6268a6bf82..5c510fab15 100644 --- a/pyrit/models/catalog/scenario.py +++ b/pyrit/models/catalog/scenario.py @@ -34,7 +34,6 @@ class RegisteredScenario(BaseModel): ) all_strategies: list[str] = Field(..., description="All available concrete strategy names") default_datasets: list[str] = Field(..., description="Default dataset names used by the scenario") - max_dataset_size: int | None = Field(None, description="Maximum items per dataset (None means unlimited)") supported_parameters: list[Parameter] = Field( default_factory=list, description="Scenario-declared custom parameters" ) diff --git a/pyrit/models/identifiers/__init__.py b/pyrit/models/identifiers/__init__.py index c6dd9278fc..1fa03b2319 100644 --- a/pyrit/models/identifiers/__init__.py +++ b/pyrit/models/identifiers/__init__.py @@ -23,6 +23,7 @@ ChildEvalRule, EvaluationIdentifier, ObjectiveTargetEvaluationIdentifier, + ScenarioEvaluationIdentifier, ScorerEvaluationIdentifier, compute_eval_hash, compute_inner_attack_eval_hash, @@ -31,6 +32,7 @@ from pyrit.models.identifiers.evaluation_markers import EvalMarker, Evaluate, Exclude, Include, Unwrap from pyrit.models.identifiers.identifier_filters import IdentifierFilter, IdentifierType from pyrit.models.identifiers.param_markers import Param, ParamMarker +from pyrit.models.identifiers.scenario_identifier import ScenarioIdentifier from pyrit.models.identifiers.scorer_identifier import ScorerIdentifier from pyrit.models.identifiers.seed_identifier import SeedIdentifier from pyrit.models.identifiers.target_identifier import TargetIdentifier @@ -57,8 +59,10 @@ "REGISTRY_NAME_PATTERN", "Param", "ParamMarker", + "ScenarioEvaluationIdentifier", "ScorerEvaluationIdentifier", "ScorerIdentifier", + "ScenarioIdentifier", "SeedIdentifier", "snake_case_to_class_name", "TARGET_EVAL_PARAM_FALLBACKS", diff --git a/pyrit/models/identifiers/evaluation_identifier.py b/pyrit/models/identifiers/evaluation_identifier.py index 73823fe6cd..eda0400c3d 100644 --- a/pyrit/models/identifiers/evaluation_identifier.py +++ b/pyrit/models/identifiers/evaluation_identifier.py @@ -18,6 +18,8 @@ * ``AtomicAttackEvaluationIdentifier`` — attack-domain concrete subclass. * ``ObjectiveTargetEvaluationIdentifier`` — leaf-target subclass used by the analytics layer to key cached results by behavioral target configuration. +* ``ScenarioEvaluationIdentifier`` — scenario-domain concrete subclass used to + key a scenario run's behavioral identity (for resume drift detection). """ from __future__ import annotations @@ -30,6 +32,7 @@ from pyrit.models.identifiers.attack_identifier import AttackIdentifier from pyrit.models.identifiers.component_identifier import ComponentIdentifier, config_hash from pyrit.models.identifiers.evaluation_markers import EvalMarker, Exclude, Include, Unwrap +from pyrit.models.identifiers.scenario_identifier import ScenarioIdentifier from pyrit.models.identifiers.scorer_identifier import ScorerIdentifier from pyrit.models.identifiers.target_identifier import TargetIdentifier @@ -531,6 +534,21 @@ class ObjectiveTargetEvaluationIdentifier(EvaluationIdentifier): EVAL_ROOT: ClassVar[type[ComponentIdentifier] | None] = TargetIdentifier +class ScenarioEvaluationIdentifier(EvaluationIdentifier): + """ + Evaluation identity for scenarios. + + Rules are derived from ``ScenarioIdentifier``'s field markers: the definition + ``version`` and resolved ``techniques`` / ``datasets`` feed the hash, the + resolved scenario ``params`` are included, and the ``objective_target`` / + ``objective_scorer`` children contribute their full behavioral projection. + Two runs of the same scenario definition with the same configuration produce + the same eval hash, which backs resume drift detection. + """ + + EVAL_ROOT: ClassVar[type[ComponentIdentifier] | None] = ScenarioIdentifier + + def compute_inner_attack_eval_hash(*, attack: AttackStrategy[Any, Any]) -> str: """ Predict the eval hash the executor will stamp on persisted child rows diff --git a/pyrit/models/identifiers/scenario_identifier.py b/pyrit/models/identifiers/scenario_identifier.py new file mode 100644 index 0000000000..62b9b3ca65 --- /dev/null +++ b/pyrit/models/identifiers/scenario_identifier.py @@ -0,0 +1,62 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +"""Strongly-typed projection of a scenario's identifier.""" + +from __future__ import annotations + +from typing import Annotated, ClassVar + +from pydantic import Field + +from pyrit.models.identifiers.component_identifier import ComponentIdentifier +from pyrit.models.identifiers.evaluation_markers import Evaluate +from pyrit.models.identifiers.param_markers import Param +from pyrit.models.identifiers.scorer_identifier import ( # noqa: TC001 + ScorerIdentifier, # runtime-required by Pydantic field annotations +) +from pyrit.models.identifiers.target_identifier import ( # noqa: TC001 + TargetIdentifier, # runtime-required by Pydantic field annotations +) +from pyrit.models.parameter import ComponentType + + +class ScenarioIdentifier(ComponentIdentifier): + """ + Strongly-typed projection of a ``Scenario``'s ``ComponentIdentifier``. + + Like the sibling projections (``TargetIdentifier`` / ``ScorerIdentifier``), + this is produced by the scenario registry when a scenario is built. It is also + the canonical per-run identity carried on the ``ScenarioResult`` aggregate and + persisted with it: the scenario class name (``class_name``), definition + ``version``, resolved ``techniques`` / ``datasets``, the resolved scenario + ``params``, and the ``objective_target`` / ``objective_scorer`` child + references all live here rather than as separate denormalized fields. Its eval + hash (via ``ScenarioEvaluationIdentifier``) backs resume drift detection. + + Promotes the scenario's behavioral identity to typed ``params`` fields that + feed both the content and eval hash: the definition ``version`` and the + resolved ``techniques`` / ``datasets`` the scenario runs (a v1 vs a v2, or a + different technique / dataset selection, is a different identity). The two + run-resolved reference slots — ``objective_target`` (a ``PromptTarget``) and + ``objective_scorer`` (a ``Scorer``) — are promoted children the registry + resolves by name from the target / scorer registries when building a scenario. + """ + + component_type: ClassVar[ComponentType] = ComponentType.SCENARIO + + #: Scenario definition version. Behavioral identity (a v1 and a v2 of the same + #: scenario are different identities); not a constructor input. + version: Annotated[int | None, Evaluate.Include(), Param.Exclude()] = None + #: Resolved technique names the scenario runs. Behavioral identity; not a + #: constructor input (the registry populates it from the selected strategies). + techniques: Annotated[list[str] | None, Evaluate.Include(), Param.Exclude()] = None + #: Resolved dataset names the scenario runs. Behavioral identity; not a + #: constructor input (the registry populates it from the dataset config). + datasets: Annotated[list[str] | None, Evaluate.Include(), Param.Exclude()] = None + #: Target the scenario attacks. Run-resolved reference resolved by name from + #: the target registry. + objective_target: Annotated[TargetIdentifier | None, Evaluate.Include(), Param.Include()] = Field(default=None) + #: Primary scorer the scenario evaluates with. Run-resolved reference resolved + #: by name from the scorer registry. + objective_scorer: Annotated[ScorerIdentifier | None, Evaluate.Include(), Param.Include()] = Field(default=None) diff --git a/pyrit/models/parameter.py b/pyrit/models/parameter.py index 6f016e4f1f..82c2759c91 100644 --- a/pyrit/models/parameter.py +++ b/pyrit/models/parameter.py @@ -25,12 +25,14 @@ class ComponentType(str, Enum): Each member maps one-to-one to a registry singleton that resolves references of that family by name (``TARGET`` → ``TargetRegistry``, ``CONVERTER`` → - ``ConverterRegistry``, ``SCORER`` → ``ScorerRegistry``). + ``ConverterRegistry``, ``SCORER`` → ``ScorerRegistry``, ``SCENARIO`` → + ``ScenarioRegistry``). """ TARGET = "target" CONVERTER = "converter" SCORER = "scorer" + SCENARIO = "scenario" class ParameterDestination(str, Enum): diff --git a/pyrit/models/results/__init__.py b/pyrit/models/results/__init__.py index b57cb1ef37..4332080b4f 100644 --- a/pyrit/models/results/__init__.py +++ b/pyrit/models/results/__init__.py @@ -12,9 +12,9 @@ - ScenarioRunState: Lifecycle state of a scenario run. """ +from pyrit.models.identifiers.scenario_identifier import ScenarioIdentifier from pyrit.models.results.attack_result import AttackOutcome, AttackResult, AttackResultT from pyrit.models.results.scenario_result import ( - ScenarioIdentifier, ScenarioResult, ScenarioRunState, ) diff --git a/pyrit/models/results/scenario_result.py b/pyrit/models/results/scenario_result.py index 7e3b632ce4..6719f1affd 100644 --- a/pyrit/models/results/scenario_result.py +++ b/pyrit/models/results/scenario_result.py @@ -9,71 +9,34 @@ from enum import Enum from typing import Any -from pydantic import BaseModel, ConfigDict, Field +from pydantic import BaseModel, ConfigDict, Field, computed_field, model_validator -import pyrit from pyrit.common.deprecation import print_deprecation_message -from pyrit.models.identifiers.component_identifier import ( # noqa: TC001 (runtime-required by Pydantic field annotations) - ComponentIdentifier, -) + +# Runtime-required by Pydantic field / computed-field annotations. +from pyrit.models.identifiers.scenario_identifier import ScenarioIdentifier # noqa: TC001 +from pyrit.models.identifiers.scorer_identifier import ScorerIdentifier # noqa: TC001 +from pyrit.models.identifiers.target_identifier import TargetIdentifier # noqa: TC001 from pyrit.models.results.attack_result import AttackOutcome, AttackResult logger = logging.getLogger(__name__) -class ScenarioIdentifier(BaseModel): - """ - Identifier describing the executed scenario. - """ - - model_config = ConfigDict(extra="forbid", populate_by_name=True) - - #: Name of the scenario. - name: str - #: Description of the scenario. - description: str = "" - #: Version of the scenario. Accepts the legacy ``scenario_version`` kwarg/wire key. - version: int = Field(default=1, alias="scenario_version") - #: PyRIT version string. Defaults to the current installed version. - pyrit_version: str = Field(default=pyrit.__version__) - #: Optional initialization data. - init_data: dict[str, Any] | None = None - - def to_dict(self) -> dict[str, Any]: - """ - Serialize to a JSON-compatible dictionary. - - Deprecated: use ``model_dump(by_alias=True)`` instead. - - Returns: - dict[str, Any]: Serialized payload. - """ - print_deprecation_message( - old_item="ScenarioIdentifier.to_dict()", - new_item="ScenarioIdentifier.model_dump(by_alias=True)", - removed_in="0.16.0", - ) - return self.model_dump(by_alias=True) - - @classmethod - def from_dict(cls, data: dict[str, Any]) -> ScenarioIdentifier: - """ - Reconstruct a ScenarioIdentifier from a dictionary. - - Deprecated: use ``model_validate(...)`` instead. +__all__ = ["ScenarioResult", "ScenarioRunState"] - Args: - data (dict[str, Any]): Dictionary as produced by ``model_dump(by_alias=True)``. - Returns: - ScenarioIdentifier: Reconstructed instance. - """ - print_deprecation_message( - old_item="ScenarioIdentifier.from_dict(...)", - new_item="ScenarioIdentifier.model_validate(...)", - removed_in="0.16.0", - ) - return cls.model_validate(data) +#: Denormalized identity fields exposed as ``@computed_field`` projections of +#: ``scenario_identifier``. They appear in ``model_dump`` output but are not +#: settable inputs, so they are dropped when reconstructing from a dump. +_COMPUTED_IDENTITY_FIELDS = frozenset( + { + "scenario_name", + "scenario_version", + "pyrit_version", + "objective_target_identifier", + "objective_scorer_identifier", + } +) class ScenarioRunState(str, Enum): @@ -104,12 +67,14 @@ class ScenarioResult(BaseModel): #: Scenario result ID. id: uuid.UUID = Field(default_factory=uuid.uuid4) - #: Identifier for the executed scenario. + #: Canonical scenario identity for this run. Carries the scenario class name, + #: definition version, resolved techniques / datasets, the resolved scenario + #: params, and the ``objective_target`` / ``objective_scorer`` child references. + #: Its eval hash backs resume drift detection. scenario_identifier: ScenarioIdentifier - #: Target identifier. - objective_target_identifier: ComponentIdentifier | None - #: Objective scorer identifier, or None if the scenario has no objective scorer. - objective_scorer_identifier: ComponentIdentifier | None + #: Human-readable scenario description (the scenario class docstring). Display / + #: catalog metadata snapshotted on the result — not part of scenario identity. + scenario_description: str = "" #: Results grouped by atomic attack name. attack_results: dict[str, list[AttackResult]] #: Current scenario run state. @@ -137,6 +102,59 @@ class ScenarioResult(BaseModel): #: scenario operates on. Keys are not part of any public contract and may evolve. metadata: dict[str, Any] = Field(default_factory=dict) + @model_validator(mode="before") + @classmethod + def _drop_computed_identity_fields(cls, data: Any) -> Any: + """ + Ignore denormalized computed identity fields when reconstructing from a dump. + + ``scenario_name`` / ``scenario_version`` / ``pyrit_version`` / + ``objective_target_identifier`` / ``objective_scorer_identifier`` are + ``@computed_field`` projections of ``scenario_identifier`` that show up in + ``model_dump`` output but are not settable inputs. Dropping them lets + ``model_validate(model_dump(...))`` round-trip under ``extra="forbid"``. + + Args: + data (Any): Raw input passed to validation (a dict when reconstructing from a dump). + + Returns: + Any: The input with computed identity keys removed when it is a dict; otherwise unchanged. + """ + if isinstance(data, dict): + return {key: value for key, value in data.items() if key not in _COMPUTED_IDENTITY_FIELDS} + return data + + @computed_field # type: ignore[prop-decorator] + @property + def scenario_name(self) -> str: + """Scenario class name (e.g. ``"ContentHarms"``), delegated to the identifier.""" + return self.scenario_identifier.class_name + + @computed_field # type: ignore[prop-decorator] + @property + def scenario_version(self) -> int: + """Scenario definition version, delegated to the identifier (defaults to 1).""" + version = self.scenario_identifier.version + return version if version is not None else 1 + + @computed_field # type: ignore[prop-decorator] + @property + def pyrit_version(self) -> str: + """PyRIT version the scenario ran under, delegated to the identifier.""" + return self.scenario_identifier.pyrit_version + + @computed_field # type: ignore[prop-decorator] + @property + def objective_target_identifier(self) -> TargetIdentifier | None: + """Target the scenario attacks, delegated to the identifier.""" + return self.scenario_identifier.objective_target + + @computed_field # type: ignore[prop-decorator] + @property + def objective_scorer_identifier(self) -> ScorerIdentifier | None: + """Primary scorer the scenario evaluates with, delegated to the identifier.""" + return self.scenario_identifier.objective_scorer + def get_strategies_used(self) -> list[str]: """ Get the list of strategies used in this scenario. @@ -275,7 +293,9 @@ def normalize_scenario_name(scenario_name: str) -> str: If the input is already in PascalCase or doesn't match the snake_case pattern, it is returned unchanged. - This is the inverse of ScenarioRegistry._class_name_to_scenario_name(). + This is the inverse of the snake_case registry-name conversion + (``class_name_to_snake_case``) applied to scenario class names during + discovery. Args: scenario_name (str): The scenario name to normalize. diff --git a/pyrit/output/scenario_result/pretty.py b/pyrit/output/scenario_result/pretty.py index d8654c0bfd..7f570fec20 100644 --- a/pyrit/output/scenario_result/pretty.py +++ b/pyrit/output/scenario_result/pretty.py @@ -97,7 +97,7 @@ def _render_header(self, result: ScenarioResult) -> str: lines: list[str] = [] lines.append("\n") lines.append(self._format_colored("=" * self._width, Fore.CYAN)) - header_text = f"📊 SCENARIO RESULTS: {result.scenario_identifier.name}" + header_text = f"📊 SCENARIO RESULTS: {result.scenario_name}" lines.append(self._format_colored(header_text.center(self._width), Style.BRIGHT, Fore.CYAN)) lines.append(self._format_colored("=" * self._width, Fore.CYAN)) return "".join(lines) @@ -154,25 +154,17 @@ async def render_async(self, result: ScenarioResult) -> str: lines.append(self._render_section_header("Scenario Information")) lines.append(self._format_colored(f"{self._indent}📋 Scenario Details", Style.BRIGHT)) - lines.append(self._format_colored(f"{self._indent * 2}• Name: {result.scenario_identifier.name}", Fore.CYAN)) + lines.append(self._format_colored(f"{self._indent * 2}• Name: {result.scenario_name}", Fore.CYAN)) lines.append( - self._format_colored( - f"{self._indent * 2}• Scenario Version: {result.scenario_identifier.version}", Fore.CYAN - ) - ) - lines.append( - self._format_colored( - f"{self._indent * 2}• PyRIT Version: {result.scenario_identifier.pyrit_version}", Fore.CYAN - ) + self._format_colored(f"{self._indent * 2}• Scenario Version: {result.scenario_version}", Fore.CYAN) ) + lines.append(self._format_colored(f"{self._indent * 2}• PyRIT Version: {result.pyrit_version}", Fore.CYAN)) - if result.scenario_identifier.description: + if result.scenario_description: lines.append(self._format_colored(f"{self._indent * 2}• Description:", Fore.CYAN)) desc_indent = self._indent * 4 available_width = 120 - len(desc_indent) - wrapped_lines = textwrap.wrap( - result.scenario_identifier.description, width=available_width, break_long_words=False - ) + wrapped_lines = textwrap.wrap(result.scenario_description, width=available_width, break_long_words=False) lines.extend(self._format_colored(f"{desc_indent}{line}", Fore.CYAN) for line in wrapped_lines) lines.append("\n") diff --git a/pyrit/registry/__init__.py b/pyrit/registry/__init__.py index 8694c9c6c3..906b23ba7b 100644 --- a/pyrit/registry/__init__.py +++ b/pyrit/registry/__init__.py @@ -9,14 +9,14 @@ ClassEntry, InitializerMetadata, InitializerRegistry, - ScenarioMetadata, - ScenarioRegistry, ) from pyrit.registry.components import ( AttackTechniqueMetadata, AttackTechniqueRegistry, ConverterMetadata, ConverterRegistry, + ScenarioMetadata, + ScenarioRegistry, ScorerMetadata, ScorerRegistry, TargetMetadata, diff --git a/pyrit/registry/class_registries/__init__.py b/pyrit/registry/class_registries/__init__.py index fac5c0ebdf..1314305b96 100644 --- a/pyrit/registry/class_registries/__init__.py +++ b/pyrit/registry/class_registries/__init__.py @@ -18,16 +18,10 @@ InitializerMetadata, InitializerRegistry, ) -from pyrit.registry.class_registries.scenario_registry import ( - ScenarioMetadata, - ScenarioRegistry, -) __all__ = [ "BaseClassRegistry", "ClassEntry", - "ScenarioRegistry", - "ScenarioMetadata", "InitializerRegistry", "InitializerMetadata", ] diff --git a/pyrit/registry/class_registries/scenario_registry.py b/pyrit/registry/class_registries/scenario_registry.py deleted file mode 100644 index ef5a839d5e..0000000000 --- a/pyrit/registry/class_registries/scenario_registry.py +++ /dev/null @@ -1,235 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. - -""" -Scenario registry for discovering and managing PyRIT scenarios. - -This module provides a unified registry for discovering all available Scenario subclasses -from the pyrit.scenario.scenarios module and from user-defined initialization scripts. -""" - -from __future__ import annotations - -import logging -from dataclasses import dataclass, field -from pathlib import Path -from typing import TYPE_CHECKING - -from pyrit.models import class_name_to_snake_case -from pyrit.registry.base import ClassRegistryEntry -from pyrit.registry.class_registries.base_class_registry import ( - BaseClassRegistry, - ClassEntry, -) -from pyrit.registry.discovery import ( - discover_in_package, - discover_subclasses_in_loaded_modules, -) - -if TYPE_CHECKING: - from pyrit.models import Parameter - from pyrit.scenario.core import Scenario - -logger = logging.getLogger(__name__) - - -@dataclass(frozen=True) -class ScenarioMetadata(ClassRegistryEntry): - """ - Metadata describing a registered Scenario class. - - Use get_class() to get the actual class. - """ - - # The default strategy name (e.g., "single_turn") - default_strategy: str = field(kw_only=True) - - # All available strategy names for this scenario. - all_strategies: tuple[str, ...] = field(kw_only=True) - - # Aggregate strategies that combine multiple attack approaches. - aggregate_strategies: tuple[str, ...] = field(kw_only=True) - - # Default dataset names used by this scenario. - default_datasets: tuple[str, ...] = field(kw_only=True) - - # Maximum number of items per dataset. - max_dataset_size: int | None = field(kw_only=True) - - # Scenario-declared custom parameters. - supported_parameters: tuple[Parameter, ...] = field(kw_only=True, default=()) - - -class ScenarioRegistry(BaseClassRegistry["Scenario", ScenarioMetadata]): - """ - Registry for discovering and managing available scenario classes. - - This class discovers all Scenario subclasses from: - 1. Built-in scenarios in pyrit.scenario.scenarios module - 2. User-defined scenarios from initialization scripts (set via globals) - - Scenarios are identified by their dotted name (e.g., "garak.encoding", "foundry.red_team_agent"). - """ - - def __init__(self, *, lazy_discovery: bool = True) -> None: - """ - Initialize the scenario registry. - - Args: - lazy_discovery: If True, discovery is deferred until first access. - Defaults to True for performance. - """ - super().__init__(lazy_discovery=lazy_discovery) - - def _discover(self) -> None: - """Discover all built-in scenarios from pyrit.scenario.scenarios module.""" - self._discover_builtin_scenarios() - - def _discover_builtin_scenarios(self) -> None: - """ - Discover all built-in scenarios from pyrit.scenario.scenarios module. - - This method dynamically imports all modules in the scenarios package - and registers any Scenario subclasses found. - """ - from pyrit.scenario.core import Scenario - - try: - import pyrit.scenario.scenarios as scenarios_package - - # Get the path to the scenarios package - package_file = scenarios_package.__file__ - if package_file is None: - if hasattr(scenarios_package, "__path__"): - package_path = Path(scenarios_package.__path__[0]) - else: - logger.error("Cannot determine scenarios package location") - return - else: - package_path = Path(package_file).parent - - # Discover scenarios using the shared discovery utility - # Use ``package_name.module_name`` as the registry name - for registry_name, scenario_class in discover_in_package( - package_path=package_path, - package_name="pyrit.scenario.scenarios", - base_class=Scenario, - recursive=True, - ): - # Skip deprecated alias classes - doc = (scenario_class.__doc__ or "").strip() - if doc.startswith("Deprecated alias"): - logger.debug(f"Skipping deprecated alias: {scenario_class.__name__}") - continue - - # Skip re-exported aliases: if the class was defined in a different - # module than the one being discovered, it's an alias (e.g., - # ContentHarms in content_harms.py is really RapidResponse from - # rapid_response.py). - class_module = getattr(scenario_class, "__module__", "") - expected_module_suffix = registry_name.replace(".", "/") - if not class_module.endswith(registry_name.replace("/", ".")): - # Build the full expected module name for comparison - expected_module = f"pyrit.scenario.scenarios.{registry_name.replace('/', '.')}" - if class_module != expected_module: - logger.debug( - f"Skipping alias '{scenario_class.__name__}' in '{registry_name}' " - f"(defined in {class_module})" - ) - continue - - # Check for registry key collision - if registry_name in self._class_entries: - logger.warning( - f"Scenario registry name collision: '{registry_name}' " - f"conflicts with an already-registered scenario. Original " - f"scenario is kept: {self._class_entries[registry_name].registered_class.__name__}" - ) - continue - - entry = ClassEntry(registered_class=scenario_class) - self._class_entries[registry_name] = entry - logger.debug(f"Registered built-in scenario: {registry_name} ({scenario_class.__name__})") - - except Exception as e: - logger.error(f"Failed to discover built-in scenarios: {e}") - - def discover_user_scenarios(self) -> None: - """ - Discover user-defined scenarios from global variables. - - After initialization scripts are executed, they may define Scenario subclasses - and store them in globals. This method searches for such classes. - - User scenarios will override built-in scenarios with the same name. - """ - from pyrit.scenario.core import Scenario - - try: - for _, scenario_class in discover_subclasses_in_loaded_modules(base_class=Scenario): - # Check if this is a user-defined class (not from pyrit.scenario.scenarios) - if not scenario_class.__module__.startswith("pyrit.scenario.scenarios"): - # Convert class name to snake_case for scenario name - registry_name = class_name_to_snake_case(scenario_class.__name__, suffix="Scenario") - entry = ClassEntry(registered_class=scenario_class) - self._class_entries[registry_name] = entry - logger.info(f"Registered user-defined scenario: {registry_name} ({scenario_class.__name__})") - - except Exception as e: - logger.debug(f"Failed to discover user scenarios: {e}") - - def _build_metadata(self, name: str, entry: ClassEntry[Scenario]) -> ScenarioMetadata: - """ - Build metadata for a Scenario class. - - Instantiates the scenario with no arguments and reads the strategy/dataset - configuration off the instance. Every registered scenario MUST be no-arg - instantiable (defer required-input validation to ``initialize_async`` or - ``_get_atomic_attacks_async``); otherwise this raises ``TypeError``. - - Args: - name: The registry name of the scenario. - entry: The ClassEntry containing the scenario class. - - Returns: - ScenarioMetadata describing the scenario class. - - Raises: - TypeError: If ``scenario_class()`` cannot be called with no arguments. - """ - scenario_class = entry.registered_class - - description = entry.get_description(fallback="No description available") - - supported_parameters = tuple(scenario_class.supported_parameters()) - - try: - instance = scenario_class() # type: ignore[ty:missing-argument] - except TypeError as exc: - raise TypeError( - f"Scenario {scenario_class.__module__}.{scenario_class.__name__} (registered as " - f"{name!r}) must be instantiable with no arguments so the registry can introspect " - f"its strategies and default dataset config. Make all constructor parameters " - f"optional (defaulting to None) and defer required-input validation to " - f"initialize_async() or _get_atomic_attacks_async(). Original error: {exc}" - ) from exc - - strategy_class = instance._strategy_class - default_strategy_value = instance._default_strategy.value - all_strategies = tuple(s.value for s in strategy_class.get_all_strategies()) - aggregate_strategies = tuple(s.value for s in strategy_class.get_aggregate_strategies()) - default_datasets = tuple(instance._default_dataset_config.dataset_names) - max_dataset_size = instance._default_dataset_config.max_dataset_size - - return ScenarioMetadata( - class_name=scenario_class.__name__, - class_module=scenario_class.__module__, - class_description=description, - registry_name=name, - default_strategy=default_strategy_value, - all_strategies=all_strategies, - aggregate_strategies=aggregate_strategies, - default_datasets=default_datasets, - max_dataset_size=max_dataset_size, - supported_parameters=supported_parameters, - ) diff --git a/pyrit/registry/components/__init__.py b/pyrit/registry/components/__init__.py index 012bb96ce6..3f4e98f3f2 100644 --- a/pyrit/registry/components/__init__.py +++ b/pyrit/registry/components/__init__.py @@ -22,6 +22,10 @@ ConverterMetadata, ConverterRegistry, ) +from pyrit.registry.components.scenario_registry import ( + ScenarioMetadata, + ScenarioRegistry, +) from pyrit.registry.components.scorer_registry import ( ScorerMetadata, ScorerRegistry, @@ -38,6 +42,8 @@ "ConverterMetadata", "ScorerRegistry", "ScorerMetadata", + "ScenarioRegistry", + "ScenarioMetadata", "TargetRegistry", "TargetMetadata", ] diff --git a/pyrit/registry/components/attack_technique_registry.py b/pyrit/registry/components/attack_technique_registry.py index a5466c114f..97f9ce8a04 100644 --- a/pyrit/registry/components/attack_technique_registry.py +++ b/pyrit/registry/components/attack_technique_registry.py @@ -253,4 +253,7 @@ def register_from_factories( tags=tags, ) - logger.debug("Technique registration complete (%d total in registry)", len(self.instances)) + logger.debug( + "Technique registration complete (%d total in registry)", + len(self.instances), + ) diff --git a/pyrit/registry/components/scenario_registry.py b/pyrit/registry/components/scenario_registry.py new file mode 100644 index 0000000000..2f28b87fe0 --- /dev/null +++ b/pyrit/registry/components/scenario_registry.py @@ -0,0 +1,210 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Scenario registry for discovering and managing PyRIT scenarios. + +A ``Registry`` for ``Scenario`` classes that discovers all available subclasses +from the ``pyrit.scenario.scenarios`` package and from user-defined initialization +scripts. Like the other component registries it is a unified ``Registry``: it owns +a validated class catalog and builds instances via ``create_instance``. Its +buildable classes are keyed by **dotted registry name** (e.g. ``garak.encoding``) +rather than by class name, so ``_discover``/``_get_registry_name`` are overridden. +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any + +from pyrit.models import class_name_to_snake_case +from pyrit.models.identifiers.scenario_identifier import ScenarioIdentifier +from pyrit.registry.base import ClassRegistryEntry +from pyrit.registry.registry import Registry + +if TYPE_CHECKING: + from types import ModuleType + + from pyrit.models import Parameter + from pyrit.models.identifiers.component_identifier import ComponentIdentifier + from pyrit.scenario.core import Scenario + + +@dataclass(frozen=True) +class ScenarioMetadata(ClassRegistryEntry): + """ + Metadata describing a registered Scenario class. + + Use get_class() to get the actual class. + """ + + # The default strategy name (e.g., "single_turn") + default_strategy: str = field(kw_only=True) + + # All available strategy names for this scenario. + all_strategies: tuple[str, ...] = field(kw_only=True) + + # Aggregate strategies that combine multiple attack approaches. + aggregate_strategies: tuple[str, ...] = field(kw_only=True) + + # Default dataset names used by this scenario. + default_datasets: tuple[str, ...] = field(kw_only=True) + + # Scenario-declared custom parameters. + supported_parameters: tuple[Parameter, ...] = field(kw_only=True, default=()) + + +class ScenarioRegistry(Registry["Scenario", ScenarioMetadata]): + """ + Registry for discovering and managing available scenario classes. + + Discovers every concrete ``Scenario`` subclass under ``pyrit.scenario.scenarios`` + via the unified ``Registry`` base (recursive subclass enumeration). Unlike the + component registries, scenarios are keyed by their **dotted module path** (e.g. + ``"garak.encoding"``, ``"foundry.red_team_agent"``) rather than class name, so + only ``_get_registry_name`` and ``_build_metadata`` are customized. + """ + + _DISCOVERY_PACKAGE = "pyrit.scenario.scenarios" + + def _identifier_type(self) -> type[ComponentIdentifier] | None: + """Return ``ScenarioIdentifier`` so ``Param.*`` markers drive derivation.""" + return ScenarioIdentifier + + def _metadata_class(self) -> type[ScenarioMetadata]: + """Return the concrete metadata dataclass this registry builds.""" + return ScenarioMetadata + + def _base_type(self) -> type[Scenario]: + """Return the ``Scenario`` base class, imported lazily.""" + from pyrit.scenario.core import Scenario + + return Scenario + + def _discovery_package(self) -> ModuleType: + """Return the ``pyrit.scenario.scenarios`` package scanned for scenario classes.""" + import pyrit.scenario.scenarios as scenarios_package + + return scenarios_package + + def _get_registry_name(self, cls: type[Scenario]) -> str: + """ + Key scenarios by their dotted module path (e.g. ``"airt.rapid_response"``). + + The path is the scenario module's location relative to + ``pyrit.scenario.scenarios``. Scenarios discovered outside that package + (e.g. user-defined ones) fall back to a suffix-stripped snake_case class name. + + Args: + cls (type[Scenario]): The scenario class. + + Returns: + str: The dotted registry name. + """ + module = cls.__module__ or "" + prefix = f"{self._DISCOVERY_PACKAGE}." + if module.startswith(prefix): + relative = module[len(prefix) :] + if relative: + return relative + return class_name_to_snake_case(cls.__name__, suffix="Scenario") + + def _build_metadata(self, name: str, cls: type[Scenario]) -> ScenarioMetadata: + """ + Build metadata for a Scenario class. + + Instantiates the scenario with no arguments and reads the strategy/dataset + configuration off the instance. Every registered scenario MUST be no-arg + instantiable (defer required-input validation to ``initialize_async`` or + ``_get_atomic_attacks_async``); otherwise this raises ``TypeError``. + + Args: + name: The registry name of the scenario. + cls: The scenario class to describe. + + Returns: + ScenarioMetadata describing the scenario class. + + Raises: + TypeError: If ``cls()`` cannot be called with no arguments. + """ + description = ClassRegistryEntry.description_from_docstring(cls, fallback="No description available") + + supported_parameters = tuple(cls.supported_parameters()) + + try: + instance = cls() # type: ignore[ty:missing-argument] + except TypeError as exc: + raise TypeError( + f"Scenario {cls.__module__}.{cls.__name__} (registered as " + f"{name!r}) must be instantiable with no arguments so the registry can introspect " + f"its strategies and default dataset config. Make all constructor parameters " + f"optional (defaulting to None) and defer required-input validation to " + f"initialize_async() or _get_atomic_attacks_async(). Original error: {exc}" + ) from exc + + strategy_class = instance._strategy_class + default_strategy_value = instance._default_strategy.value + all_strategies = tuple(s.value for s in strategy_class.get_all_strategies()) + aggregate_strategies = tuple(s.value for s in strategy_class.get_aggregate_strategies()) + default_datasets = tuple(instance._default_dataset_config.dataset_names) + + return ScenarioMetadata( + class_name=cls.__name__, + class_module=cls.__module__, + class_description=description, + registry_name=name, + default_strategy=default_strategy_value, + all_strategies=all_strategies, + aggregate_strategies=aggregate_strategies, + default_datasets=default_datasets, + supported_parameters=supported_parameters, + ) + + async def create_and_initialize_async( + self, + name: str, + *, + scenario_params: dict[str, Any] | None = None, + scenario_result_id: str | None = None, + **initialize_kwargs: Any, + ) -> Scenario: + """ + Build, parameterize, and initialize a scenario in one call. + + This is the canonical entry point for producing a run-ready ``Scenario``: + the registry — not the caller — owns the full lifecycle. + + 1. **create** the scenario via ``create_instance`` (seeding + ``scenario_result_id`` when resuming an existing run), + 2. **set parameters** — the scenario-specific declared parameters (from + ``supported_parameters()``) are coerced/validated/injected via + ``Scenario.set_params_from_args``, + 3. **initialize** — the run-resolved common parameters (``objective_target``, + ``scenario_strategies``, ``dataset_config``, ``max_concurrency``, + ``max_retries``, ``memory_labels``, ``include_baseline``) are forwarded + to ``Scenario.initialize_async``. + + Prefer this over manually chaining ``create_instance`` + + ``set_params_from_args`` + ``initialize_async``. + + Args: + name (str): The registry name of the scenario (e.g. ``"foundry.red_team_agent"``). + scenario_params (dict[str, Any] | None): Scenario-specific declared + parameters to set before initialization. Defaults to an empty mapping. + scenario_result_id (str | None): Existing scenario-result id to resume, + or ``None`` to start a fresh run. + **initialize_kwargs (Any): Run-resolved common parameters forwarded to + ``Scenario.initialize_async`` (notably ``objective_target``). + + Returns: + Scenario: The fully initialized scenario, ready for ``run_async``. + """ + constructor_kwargs: dict[str, Any] = {} + if scenario_result_id: + constructor_kwargs["scenario_result_id"] = scenario_result_id + + scenario = self.create_instance(name, **constructor_kwargs) + scenario.set_params_from_args(args=scenario_params or {}) + await scenario.initialize_async(**initialize_kwargs) + return scenario diff --git a/pyrit/registry/instance_registry.py b/pyrit/registry/instance_registry.py index c163570abf..b5c9ed8b98 100644 --- a/pyrit/registry/instance_registry.py +++ b/pyrit/registry/instance_registry.py @@ -196,7 +196,9 @@ def _resolve_instance_type(self) -> type | None: return resolved @staticmethod - def _normalize_tags(tags: dict[str, str] | list[str] | None = None) -> dict[str, str]: + def _normalize_tags( + tags: dict[str, str] | list[str] | None = None, + ) -> dict[str, str]: """ Normalize tags into a ``dict[str, str]``. diff --git a/pyrit/registry/object_registries/base_instance_registry.py b/pyrit/registry/object_registries/base_instance_registry.py index 1310cb2753..39b53ba1c5 100644 --- a/pyrit/registry/object_registries/base_instance_registry.py +++ b/pyrit/registry/object_registries/base_instance_registry.py @@ -105,7 +105,9 @@ def reset_instance(cls) -> None: del cls._instances[cls] @staticmethod - def _normalize_tags(tags: dict[str, str] | list[str] | None = None) -> dict[str, str]: + def _normalize_tags( + tags: dict[str, str] | list[str] | None = None, + ) -> dict[str, str]: """ Normalize tags into a ``dict[str, str]``. diff --git a/pyrit/registry/registry.py b/pyrit/registry/registry.py index b1ba20992f..c5d1154a80 100644 --- a/pyrit/registry/registry.py +++ b/pyrit/registry/registry.py @@ -213,8 +213,8 @@ def _base_type(self) -> type[T]: """ Return the domain base class to discover (e.g. ``PromptTarget``), imported lazily. - Used by the default ``_discover`` to filter the package's exports, and by - instance-holding registries to constrain their ``instances`` container. + Used by the default ``_discover`` to enumerate the base's concrete subclasses, + and by instance-holding registries to constrain their ``instances`` container. Importing lazily keeps the heavy domain package out of module load so the registry's lazy discovery is preserved. @@ -230,7 +230,10 @@ def _base_type(self) -> type[T]: def _discovery_package(self) -> ModuleType: """ - Return the package whose ``__all__`` the default ``_discover`` scans. + Return the package whose concrete subclasses the default ``_discover`` registers. + + Importing this package must load its component modules so the base type's + subclasses exist in memory for enumeration. Returns: ModuleType: The domain package (e.g. ``pyrit.prompt_target``). @@ -244,30 +247,77 @@ def _discovery_package(self) -> ModuleType: def _discover(self) -> None: """ - Populate the catalog from the domain package. - - Scans ``_discovery_package().__all__`` and registers every concrete subclass - of ``_base_type()`` (skipping the base itself and abstract classes), keyed by - class name via ``register_class``. Registries with bespoke discovery override - this method instead of supplying ``_base_type``/``_discovery_package``. + Populate the catalog with every concrete subclass of the domain base. + + Imports ``_discovery_package()`` and enumerates the recursive subclasses of + ``_base_type()`` in memory, registering every concrete (non-abstract) one + whose module lives under the discovery package. This finds classes by *type* + rather than by walking the filesystem, so it is unaffected by packages that + omit ``__init__.py``. Any names the package exports lazily (PEP 562 + ``__getattr__``) are materialized first so their classes are loaded before + enumeration; packages without ``__all__`` (e.g. scenarios) are already fully + imported, so that step is a no-op for them. Deprecated alias classes + (docstring starting with ``"Deprecated alias"``) are skipped. Names come from + ``_get_registry_name`` (class name by default; dotted paths for scenarios). + Registries with bespoke discovery override this method instead of supplying + ``_base_type``/``_discovery_package``. """ package = self._discovery_package() base = self._base_type() - for name in getattr(package, "__all__", []): - cls = getattr(package, name, None) - if cls is None or not isinstance(cls, type): - continue - # Guard against entries that aren't genuine classes. A test elsewhere in the - # suite may patch a package export with a mock (e.g. ``autospec``/``spec=type``) - # that reports ``isinstance(cls, type) is True`` yet makes ``issubclass`` raise - # ``TypeError``; skip anything that isn't a real subclass of the base. + package_name = package.__name__ + package_prefix = f"{package_name}." + # Materialize lazily-exported classes so they are loaded before enumeration. + # A lazy import backed by an optional dependency may fail; skip it rather + # than fail the whole discovery (the class cannot be built without the dep). + for exported_name in getattr(package, "__all__", ()): try: - if not issubclass(cls, base) or cls is base or inspect.isabstract(cls): - continue - except TypeError: + getattr(package, exported_name) + except Exception as exc: + logger.debug(f"Skipping lazily-exported '{exported_name}': {exc}") + for cls in self._iter_concrete_subclasses(base): + module = cls.__module__ or "" + if module != package_name and not module.startswith(package_prefix): + continue + if (cls.__doc__ or "").strip().startswith("Deprecated alias"): + logger.debug(f"Skipping deprecated alias: {cls.__name__}") + continue + name = self._get_registry_name(cls) + existing = self._classes.get(name) + if existing is not None and existing is not cls: + logger.warning( + f"{base.__name__} registry name collision: '{name}' conflicts with an " + f"already-registered class. Keeping {existing.__name__}, skipping {cls.__name__}." + ) + continue + self.register_class(cls, name=name) + logger.debug(f"Registered {base.__name__} class: {name} ({cls.__name__})") + + @staticmethod + def _iter_concrete_subclasses(base: type[T]) -> list[type[T]]: + """ + Return every non-abstract subclass of ``base`` currently loaded in memory. + + Walks ``__subclasses__()`` recursively (deduplicating) and drops abstract + classes. Results are sorted by ``(module, qualified name)`` for deterministic + registration order. + + Args: + base (type[T]): The domain base class to enumerate subclasses of. + + Returns: + list[type[T]]: The concrete subclasses, in a stable order. + """ + discovered: dict[int, type[T]] = {} + stack = list(base.__subclasses__()) + while stack: + cls = stack.pop() + if id(cls) in discovered: continue - self.register_class(cls) - logger.debug(f"Registered {base.__name__} class: {cls.__name__}") + discovered[id(cls)] = cls + stack.extend(cls.__subclasses__()) + concrete = [cls for cls in discovered.values() if not inspect.isabstract(cls)] + concrete.sort(key=lambda c: (c.__module__ or "", c.__qualname__)) + return concrete @abstractmethod def _metadata_class(self) -> type[MetadataT]: diff --git a/pyrit/registry/resolution.py b/pyrit/registry/resolution.py index 937bec1b5a..b786834e8e 100644 --- a/pyrit/registry/resolution.py +++ b/pyrit/registry/resolution.py @@ -2,11 +2,12 @@ # Licensed under the MIT license. """ -The constructor <-> ``Parameter`` contract bridge for PyRIT registries. +The ``Parameter`` contract bridge for PyRIT registries. -This module is the single place that translates between a component class's -``__init__`` and the declarative ``Parameter`` contract carried by its domain -identifier. It has three responsibilities: +This module is the single place that translates raw arguments into ready values +against the declarative ``Parameter`` contract, whether that contract is derived +from a class ``__init__`` or declared explicitly by a component. It has three +responsibilities: - **Derive** (``derive_parameters``): read the constructor signature, enriched by the identifier's ``Param.*`` build markers, into a ``list[Parameter]``. A @@ -14,12 +15,18 @@ included field typed as a child identifier, e.g. ``TargetIdentifier``) becomes a registry **reference**; every other parameter becomes a plain value parameter whose ``param_type`` is the annotation with ``Optional[X]`` reduced to ``X``. -- **Resolve** (``resolve_constructor_args``): derive the contract for a class - and turn a flat dict of raw arguments into constructor-ready keyword arguments — - coercing simple string values via ``Parameter.coerce_value`` and resolving - registry-reference parameters by name from the owning domain's registry. -- **Present** (``display_choices``): project a constrained-scalar ``param_type`` - into its allowed-value display tuple. +- **Resolve from a constructor** (``resolve_constructor_args``): derive the + contract for a class and turn a flat dict of raw arguments into + constructor-ready keyword arguments — coercing simple string values via + ``Parameter.coerce_value`` and resolving registry-reference parameters by name + from the owning domain's registry. Defaults are left to the constructor. +- **Resolve from a declared list** (``resolve_declared_params``): the sibling for + a component that declares an explicit ``list[Parameter]`` (e.g. a scenario's + ``supported_parameters()``). It has no references, coerces every supplied + value, and materializes every declared default so the result is a complete + param bag. Both resolve functions delegate the actual coercion/validation to + the ``Parameter`` model — that is the one shared kernel; they differ only in + where the contract comes from and how defaults are handled. The identifier is the declarative blueprint; this module is where the registry reads and applies it. It performs no eager heavy imports and never imports @@ -28,6 +35,7 @@ from __future__ import annotations +import copy import inspect import re import types @@ -147,7 +155,10 @@ def derive_parameters(*, cls: type, identifier_type: type[ComponentIdentifier] | for name, param in sig.parameters.items(): if name in _SKIPPED_PARAM_NAMES: continue - if param.kind in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD): + if param.kind in ( + inspect.Parameter.VAR_POSITIONAL, + inspect.Parameter.VAR_KEYWORD, + ): continue annotation = param.annotation @@ -166,7 +177,14 @@ def derive_parameters(*, cls: type, identifier_type: type[ComponentIdentifier] | ) else: param_type = None if annotation is inspect.Parameter.empty else _unwrap_optional(annotation) - parameters.append(Parameter(name=name, description=description, default=default, param_type=param_type)) + parameters.append( + Parameter( + name=name, + description=description, + default=default, + param_type=param_type, + ) + ) return parameters @@ -188,7 +206,9 @@ def get_names(self) -> list[str]: ... -def _registry_getter_for_component_type(component_type: ComponentType) -> Callable[[], _NamedInstanceRegistry] | None: +def _registry_getter_for_component_type( + component_type: ComponentType, +) -> Callable[[], _NamedInstanceRegistry] | None: """ Return the getter for the instance registry that resolves a component family. @@ -205,7 +225,11 @@ def _registry_getter_for_component_type(component_type: ComponentType) -> Callab Callable[[], _NamedInstanceRegistry] | None: The registry getter, or None when no registry is wired for ``component_type``. """ - from pyrit.registry.components import ConverterRegistry, ScorerRegistry, TargetRegistry + from pyrit.registry.components import ( + ConverterRegistry, + ScorerRegistry, + TargetRegistry, + ) registry_classes = { ComponentType.TARGET: TargetRegistry, @@ -313,7 +337,10 @@ def _resolve_registry_reference( def resolve_constructor_args( - *, cls: type, raw_args: dict[str, Any], identifier_type: type[ComponentIdentifier] | None = None + *, + cls: type, + raw_args: dict[str, Any], + identifier_type: type[ComponentIdentifier] | None = None, ) -> dict[str, Any]: """ Resolve a flat argument dict into constructor-ready keyword arguments. @@ -371,3 +398,136 @@ def resolve_constructor_args( resolved[name] = value return resolved + + +# --------------------------------------------------------------------------- +# Resolve (declared list): raw args -> fully-materialized declared-parameter dict +# --------------------------------------------------------------------------- + + +def resolve_declared_params( + *, + declared: list[Parameter], + raw_args: dict[str, Any], + owner: str, +) -> dict[str, Any]: + """ + Resolve ``raw_args`` against an explicit declared-parameter contract. + + The declared-list sibling of ``resolve_constructor_args``. Both translate a + flat dict of raw arguments into ready values against the ``Parameter`` + contract, delegating the actual coercion/validation to the ``Parameter`` + model; they differ only in where the contract comes from and how it is + consumed: + + - ``resolve_constructor_args`` derives the contract from a class ``__init__``, + resolves registry references, coerces string values, and returns the kwargs + subset for ``cls(**resolved)`` (the constructor supplies defaults). + - ``resolve_declared_params`` takes an explicit ``list[Parameter]`` (e.g. a + scenario's ``supported_parameters()``), has no references, coerces every + supplied value, and **materializes every declared default** so the returned + dict is a complete param bag. Params declared without a default land as + ``None`` so callers can rely on ``params[name]`` never raising ``KeyError``. + + Args: + declared (list[Parameter]): The declaration snapshot to validate against. + raw_args (dict[str, Any]): Map of parameter name to raw value. Keys with + ``None`` values are treated as absent (YAML ``null``). + owner (str): Human-readable owner label used to prefix error messages, + e.g. ``"Scenario 'FoundryScenario'"``. + + Returns: + dict[str, Any]: Fully-materialized parameter dict. + + Raises: + ValueError: Invalid declaration, unknown parameter, coercion failure, or + value not in ``choices``. + """ + _validate_declarations(declared=declared, owner=owner) + + declared_by_name = {param.name: param for param in declared} + + # None values are treated as absent so YAML `key: null` falls through to defaults. + supplied = {name: value for name, value in raw_args.items() if value is not None} + + coerced: dict[str, Any] = {} + for name, raw_value in supplied.items(): + param = declared_by_name.get(name) + if param is None: + # Stash unknowns so _reject_undeclared_params can list them all at once. + coerced[name] = raw_value + continue + coerced[name] = param.coerce_value(raw_value) + + _reject_undeclared_params(params=coerced, declared=declared, owner=owner) + + for param in declared: + if param.name in coerced: + continue + # Materialize every declared param so callers can rely on + # ``params[name]`` never raising ``KeyError``. Params declared without an + # explicit default land as None, and the owner raises a domain-specific + # error at run time if it cannot proceed. + coerced[param.name] = copy.deepcopy(param.coerce_value(param.default)) if param.default is not None else None + + return coerced + + +def _validate_declarations(*, declared: list[Parameter], owner: str) -> None: + """ + Validate a declared-parameter snapshot for author mistakes. + + Args: + declared (list[Parameter]): The declaration snapshot. + owner (str): Owner label used to prefix error messages. + + Raises: + ValueError: If declarations contain duplicate names, an unsupported + ``param_type``, or a default that fails coercion (including + membership for a constrained scalar). + """ + seen: set[str] = set() + for param in declared: + if param.name in seen: + raise ValueError(f"{owner} declares duplicate parameter name '{param.name}'.") + seen.add(param.name) + + try: + param.validate() + except ValueError as exc: + raise ValueError(f"{owner} {exc}") from exc + + if param.default is not None: + try: + param.coerce_value(param.default) + except ValueError as exc: + raise ValueError(f"{owner} parameter '{param.name}' has an invalid default: {exc}") from exc + + +def _reject_undeclared_params(*, params: dict[str, Any], declared: list[Parameter], owner: str) -> None: + """ + Raise if ``params`` contains any key not in the ``declared`` snapshot. + + Specific to the declared-parameter path (``resolve_declared_params``): it + reports every undeclared key at once. The constructor path + (``resolve_constructor_args``) rejects unknown arguments inline instead. + + Args: + params (dict[str, Any]): Coerced (declared names) or raw (unknown) values. + declared (list[Parameter]): Declaration snapshot from the caller. + owner (str): Owner label used to prefix error messages. + + Raises: + ValueError: If any keys in ``params`` are not declared. + """ + declared_names = {param.name for param in declared} + unknown = sorted(set(params.keys()) - declared_names) + if unknown: + raise ValueError( + f"{owner} received unknown parameter(s): {', '.join(unknown)}. " + f"Supported parameters: " + f"{', '.join(sorted(declared_names)) if declared_names else 'none'}." + ) + + +# --------------------------------------------------------------------------- diff --git a/pyrit/scenario/core/scenario.py b/pyrit/scenario/core/scenario.py index 752cabe2b6..de59d2c004 100644 --- a/pyrit/scenario/core/scenario.py +++ b/pyrit/scenario/core/scenario.py @@ -9,8 +9,6 @@ """ import asyncio -import copy -import json import logging import uuid from abc import ABC @@ -37,6 +35,7 @@ from pyrit.models import ( AttackOutcome, AttackResult, + ScenarioEvaluationIdentifier, ScenarioIdentifier, ScenarioResult, ScenarioRunState, @@ -46,6 +45,9 @@ from pyrit.prompt_target import PromptTarget from pyrit.prompt_target.common.target_requirements import TargetRequirements from pyrit.registry import ScorerRegistry +from pyrit.registry.resolution import ( + resolve_declared_params, +) from pyrit.scenario.core.atomic_attack import AtomicAttack from pyrit.scenario.core.dataset_configuration import DatasetAttackConfiguration from pyrit.scenario.core.matrix_atomic_attack_builder import ( @@ -73,6 +75,13 @@ logger = logging.getLogger(__name__) +#: Param names a scenario must not declare via ``supported_parameters()``. These +#: collide with promoted identity fields on ``ScenarioIdentifier`` and would be +#: silently overwritten during identifier promotion. Only ``version`` is reserved +#: today; a scenario's definition version is owned by the identifier, not a param. +_RESERVED_SCENARIO_PARAM_NAMES: frozenset[str] = frozenset({"version"}) + + class BaselineAttackPolicy(Enum): """ Declares how a scenario type treats the default baseline atomic attack. @@ -94,56 +103,6 @@ class BaselineAttackPolicy(Enum): Forbidden = "forbidden" -def _assert_json_serializable(*, params: dict[str, Any]) -> None: - """ - Raise if any value in ``params`` cannot round-trip through JSON. - - Stage 5 stores ``params`` on ``ScenarioIdentifier.init_data`` for resume - validation; the underlying memory column is JSON. Catching unserializable - values here gives a clear error rather than a database failure. - - Args: - params (dict[str, Any]): Effective parameters to validate. - - Raises: - ValueError: If any value is not JSON-serializable. - """ - try: - json.dumps(params) - except (TypeError, ValueError) as exc: - raise ValueError( - f"Scenario params contain a non-JSON-serializable value (cannot persist for resume): {exc}. " - f"Use only JSON-safe types (str, int, float, bool, list, dict, None) for scenario parameters." - ) from exc - - -def _format_param_key_diff(*, stored: dict[str, Any], current: dict[str, Any]) -> str: - """ - Render the set-level difference between two param dicts as a short string. - - Lists only key names (no values) so secrets or large blobs in scenario - parameters do not leak into logs. - - Args: - stored (dict[str, Any]): Persisted params from the previous run. - current (dict[str, Any]): Effective params for the current run. - - Returns: - str: A short summary like ``"added: x, y; removed: z; changed: max_turns"``. - """ - parts: list[str] = [] - added = sorted(set(current) - set(stored)) - removed = sorted(set(stored) - set(current)) - changed = sorted(k for k in set(stored) & set(current) if stored[k] != current[k]) - if added: - parts.append(f"added: {', '.join(added)}") - if removed: - parts.append(f"removed: {', '.join(removed)}") - if changed: - parts.append(f"changed: {', '.join(changed)}") - return "; ".join(parts) if parts else "no diff details" - - class Scenario(ABC): # noqa: B024 - retained for subclass type-checking even without abstract methods """ Groups and executes multiple AtomicAttack instances sequentially. @@ -239,9 +198,13 @@ def __init__( description = ClassRegistryEntry.description_from_docstring(self.__class__) - self._identifier = ScenarioIdentifier( - name=type(self).__name__, scenario_version=version, description=description - ) + # The scenario identifier is the canonical per-run identity: the scenario + # registry produces it and it is persisted on the ScenarioResult (carrying + # class name / version / resolved techniques / datasets / params and the + # objective_target / objective_scorer references). The display description + # and pyrit_version ride alongside it on the ScenarioResult. + self._version = version + self._description = description # Store strategy configuration for use in initialize_async self._strategy_class = strategy_class @@ -277,9 +240,9 @@ def __init__( # Maps atomic_attack_name → display_group for user-facing aggregation self._display_group_map: dict[str, str] = {} - # Custom parameters: declared via supported_parameters(), populated via set_params_from_args(). + # Declared via supported_parameters(); resolved/populated by the registry + # helper (pyrit.registry.resolution). Subclasses read it in _get_atomic_attacks_async. self.params: dict[str, Any] = {} - self._declarations_validated: bool = False # Resolved effective baseline inclusion for the current run. Set in initialize_async # before _get_atomic_attacks_async is awaited so overrides can read it. @@ -439,9 +402,11 @@ def set_params_from_args(self, *, args: dict[str, Any]) -> None: """ Populate ``self.params`` from merged CLI / config arguments. - Coerces each value to its declared ``param_type``, validates, and - materializes declared defaults for params not in ``args``. Every - declared parameter is guaranteed a key in ``self.params`` after this + The scenario only **declares** its parameters via ``supported_parameters()``; + the coerce / validate / inject-defaults *mapping* is owned by the registry + layer (``pyrit.registry.resolution.resolve_declared_params``) so there is a + single implementation shared by the programmatic, CLI, and registry paths. + Every declared parameter is guaranteed a key in ``self.params`` after this call; params without a declared default land as ``None``. Args: @@ -451,94 +416,22 @@ def set_params_from_args(self, *, args: dict[str, Any]) -> None: Raises: ValueError: Invalid declaration, unknown parameter, coercion - failure, or value not in ``choices``. + failure, value not in ``choices``, or a declared parameter using + a reserved scenario identity name (e.g. ``version``). """ declared = list(self.supported_parameters()) - if not self._declarations_validated: - self._validate_declarations(declared=declared) - self._declarations_validated = True - - declared_by_name = {p.name: p for p in declared} - - # None values are treated as absent so YAML `key: null` falls through to defaults. - supplied = {name: value for name, value in args.items() if value is not None} - - coerced: dict[str, Any] = {} - for name, raw_value in supplied.items(): - param = declared_by_name.get(name) - if param is None: - # Stash unknowns so _validate_params can list them all at once. - coerced[name] = raw_value - continue - coerced[name] = param.coerce_value(raw_value) - - self._validate_params(params=coerced, declared=declared) - - for param in declared: - if param.name in coerced: - continue - # Materialize every declared param so scenarios can rely on - # ``self.params[name]`` never raising ``KeyError``. Params declared - # without an explicit default land as None, and the scenario raises - # a domain-specific error at run time if it cannot proceed. - coerced[param.name] = ( - copy.deepcopy(param.coerce_value(param.default)) if param.default is not None else None - ) - - self.params = coerced - - def _validate_declarations(self, *, declared: list[Parameter]) -> None: - """ - Validate the scenario's parameter declarations on first use. - - Args: - declared (list[Parameter]): The ``supported_parameters()`` snapshot. - - Raises: - ValueError: If declarations contain duplicate names, an - unsupported ``param_type``, or a default that fails coercion - (including membership for a constrained scalar). - """ - seen: set[str] = set() - for param in declared: - if param.name in seen: - raise ValueError(f"Scenario '{type(self).__name__}' declares duplicate parameter name '{param.name}'.") - seen.add(param.name) - - try: - param.validate() - except ValueError as exc: - raise ValueError(f"Scenario '{type(self).__name__}' {exc}") from exc - - if param.default is not None: - try: - param.coerce_value(param.default) - except ValueError as exc: - raise ValueError( - f"Scenario '{type(self).__name__}' parameter '{param.name}' has an invalid default: {exc}" - ) from exc - - def _validate_params(self, *, params: dict[str, Any], declared: list[Parameter]) -> None: - """ - Validate supplied params against the scenario's declarations. - - Args: - params (dict[str, Any]): Coerced (declared names) or raw (unknown) values. - declared (list[Parameter]): Declarations snapshot from the caller, so - the whole call sees one consistent view. - - Raises: - ValueError: If any keys in ``params`` are not declared. - """ - declared_names = {p.name for p in declared} - - unknown = sorted(set(params.keys()) - declared_names) - if unknown: + reserved = sorted({p.name for p in declared} & _RESERVED_SCENARIO_PARAM_NAMES) + if reserved: raise ValueError( - f"Scenario '{type(self).__name__}' received unknown parameter(s): {', '.join(unknown)}. " - f"Supported parameters: " - f"{', '.join(sorted(declared_names)) if declared_names else 'none'}." + f"Scenario '{type(self).__name__}' declares reserved parameter(s) {reserved}; " + "these names are owned by the scenario identity and cannot be scenario params. " + "Rename the parameter." ) + self.params = resolve_declared_params( + declared=declared, + raw_args=args, + owner=f"Scenario '{type(self).__name__}'", + ) def _prepare_strategies( self, @@ -668,12 +561,11 @@ async def initialize_async( self._scenario_strategies = self._prepare_strategies(scenario_strategies) self._strategy_converters = strategy_converters or {} - # Materialize declared defaults for programmatic callers that skip the - # explicit set_params_from_args step. Frontend-driven flows already - # call it (which sets _declarations_validated=True), so this is a no-op - # in that path. - if not self._declarations_validated: - self.set_params_from_args(args={}) + # Resolve declared parameters through the single registry-owned path, + # materializing defaults for programmatic callers that skipped an explicit + # set_params_from_args. Re-resolving an already-resolved bag is idempotent, + # so the registry- and CLI-driven flows converge here without divergence. + self.set_params_from_args(args=self.params) self._atomic_attacks = await self._get_atomic_attacks_async() @@ -695,12 +587,10 @@ async def initialize_async( seed_groups = await self._dataset_config.get_seed_attack_groups_async() self._atomic_attacks.insert(0, self._build_baseline_atomic_attack(seed_groups=seed_groups)) - # Snapshot params onto the identifier before the resume branch so the identifier - # is fully populated regardless of which branch we take. Deep-copy avoids sharing - # mutable state with self.params. - params_snapshot = copy.deepcopy(self.params) - _assert_json_serializable(params=params_snapshot) - self._identifier.init_data = params_snapshot + # Build the canonical scenario identifier once params/strategies/datasets + # are resolved, so both the resume check and the new-result branch share the + # same identity (and its eval hash). + scenario_identifier = self._build_scenario_identifier() # Check if we're resuming an existing scenario. Any divergence is a hard error # rather than a silent restart, so the original progress isn't orphaned without @@ -714,7 +604,7 @@ async def initialize_async( f"Drop scenario_result_id to start a new scenario." ) - self._validate_stored_scenario(stored_result=existing_results[0]) + self._validate_stored_scenario(stored_result=existing_results[0], current_identifier=scenario_identifier) self._apply_persisted_objectives(stored_result=existing_results[0]) return # Valid resume - skip creating new scenario result @@ -727,9 +617,8 @@ async def initialize_async( } result = ScenarioResult( - scenario_identifier=self._identifier, - objective_target_identifier=self._objective_target_identifier, - objective_scorer_identifier=self._objective_scorer_identifier, + scenario_identifier=scenario_identifier, + scenario_description=self._description, labels=self._memory_labels, attack_results=attack_results, scenario_run_state=ScenarioRunState.CREATED, @@ -842,48 +731,64 @@ def _build_baseline_atomic_attack(self, *, seed_groups: list[SeedAttackGroup]) - memory_labels=self._memory_labels, ) - def _validate_stored_scenario(self, *, stored_result: ScenarioResult) -> None: + def _build_scenario_identifier(self) -> ScenarioIdentifier: """ - Validate that a stored scenario result exactly matches the current scenario configuration. + Build the canonical ``ScenarioIdentifier`` for the current run. + + Combines the definition version, the resolved technique / dataset + selection, the resolved scenario params, and the objective target / scorer + references into one identity whose eval hash backs resume drift detection. + + Returns: + ScenarioIdentifier: The identifier describing this scenario run. + """ + techniques = sorted({s.value for s in self._scenario_strategies}) + datasets = list(self._dataset_config.dataset_names) + return ScenarioIdentifier.of( + self, + params=self.params, + version=self._version, + techniques=techniques, + datasets=datasets, + objective_target=self._objective_target_identifier, + objective_scorer=self._objective_scorer_identifier, + ) + + def _validate_stored_scenario( + self, *, stored_result: ScenarioResult, current_identifier: ScenarioIdentifier + ) -> None: + """ + Validate that a stored scenario result matches the current configuration. Resume is opt-in via ``scenario_result_id``; any divergence from the stored result is treated as user error rather than a silent restart, since the - original progress would otherwise be orphaned without warning. + original progress would otherwise be orphaned without warning. Divergence is + detected by comparing behavioral eval hashes: the scenario class name / + module, version, resolved techniques / datasets, params, and objective + target / scorer all feed the hash, so a mismatch means either a different + scenario or a changed configuration. Args: stored_result (ScenarioResult): The scenario result retrieved from memory. + current_identifier (ScenarioIdentifier): Identifier for the current run. Raises: - ValueError: If the stored scenario name, version, or parameters do not - match the current configuration. + ValueError: If the stored scenario identity does not match the current one. """ - stored_name = stored_result.scenario_identifier.name - stored_version = stored_result.scenario_identifier.version - - if stored_name != self._identifier.name: - raise ValueError( - f"Scenario result id '{self._scenario_result_id}' belongs to scenario '{stored_name}' " - f"but current scenario is '{self._identifier.name}'. " - f"Drop scenario_result_id to start a new scenario." - ) - - if stored_version != self._identifier.version: - raise ValueError( - f"Scenario result id '{self._scenario_result_id}' was created with " - f"{self._identifier.name} version {stored_version} but current version is " - f"{self._identifier.version}. Drop scenario_result_id to start a new scenario." - ) - - # Treat None (legacy result without persisted params) as empty. Compare both sides - # post-JSON-roundtrip so types that the memory column rewrites (tuple → list, non-str - # dict keys → str) don't surface as false mismatches under param_type=None. - stored_params = stored_result.scenario_identifier.init_data or {} - current_params_normalized = json.loads(json.dumps(self.params)) - if stored_params != current_params_normalized: - diff = _format_param_key_diff(stored=stored_params, current=current_params_normalized) + # Compare behavioral eval hashes. The stored eval_hash is never trusted; + # ScenarioEvaluationIdentifier recomputes it from the stored identifier's + # class / params / children, matching how the current identifier is hashed. + # class_name and class_module both feed the hash, so this also catches a + # scenario_result_id that belongs to an entirely different scenario. + stored_eval_hash = ScenarioEvaluationIdentifier(stored_result.scenario_identifier).eval_hash + current_eval_hash = ScenarioEvaluationIdentifier(current_identifier).eval_hash + + if stored_eval_hash != current_eval_hash: raise ValueError( - f"Scenario result id '{self._scenario_result_id}' has mismatched parameters ({diff}). " - f"Drop scenario_result_id to start a new scenario, or pass matching parameters to resume." + f"Scenario result id '{self._scenario_result_id}' does not match the current " + f"'{type(self).__name__}' configuration (a different scenario, or its version, " + f"techniques, datasets, parameters, or objective target / scorer changed). " + f"Drop scenario_result_id to start a new scenario, or pass matching configuration to resume." ) logger.info( @@ -1141,7 +1046,7 @@ async def run_async(self) -> ScenarioResult: Example: >>> result = await scenario.run_async() - >>> print(f"Scenario: {result.scenario_identifier.name}") + >>> print(f"Scenario: {result.scenario_name}") >>> print(f"Total results: {len(result.attack_results)}") """ if not self._atomic_attacks: diff --git a/pyrit/setup/initializers/scenarios/load_default_datasets.py b/pyrit/setup/initializers/scenarios/load_default_datasets.py index a5e8d383b5..b22f3909af 100644 --- a/pyrit/setup/initializers/scenarios/load_default_datasets.py +++ b/pyrit/setup/initializers/scenarios/load_default_datasets.py @@ -35,15 +35,13 @@ def execution_order(self) -> int: @property def description(self) -> str: """A description of this initializer.""" - return textwrap.dedent( - """ + return textwrap.dedent(""" This configuration uses the DatasetLoader to load default datasets into memory. This will enable all scenarios to run. Datasets can be customized in memory. Note: if you are using persistent memory, avoid calling this every time as datasets can take time to load. - """ - ).strip() + """).strip() @property def required_env_vars(self) -> list[str]: @@ -56,7 +54,7 @@ async def initialize_async(self) -> None: all_default_datasets: list[str] = [] - for metadata in registry.list_metadata(): + for metadata in registry.get_all_registered_class_metadata(): datasets = list(metadata.default_datasets) all_default_datasets.extend(datasets) logger.info(f"Scenario '{metadata.registry_name}' uses datasets: {datasets}") diff --git a/pyrit/setup/initializers/scenarios/preload_scenario_metadata.py b/pyrit/setup/initializers/scenarios/preload_scenario_metadata.py index 72a591c9d9..ba92ee4f1d 100644 --- a/pyrit/setup/initializers/scenarios/preload_scenario_metadata.py +++ b/pyrit/setup/initializers/scenarios/preload_scenario_metadata.py @@ -6,7 +6,7 @@ Each registered ``Scenario`` is instantiated once so the registry can read the strategy class, default strategy, and default dataset configuration off the -instance. The results are cached on ``BaseClassRegistry._metadata_cache``; the +instance. The results are cached on ``Registry._metadata_cache``; the first ``--list-scenarios`` / GUI call is then a cache hit. Per-scenario instantiation failures surface loudly here at startup rather than later. """ @@ -25,5 +25,5 @@ class PreloadScenarioMetadata(PyRITInitializer): async def initialize_async(self) -> None: """Warm the scenario metadata cache.""" registry = ScenarioRegistry.get_registry_singleton() - metadata = registry.list_metadata() + metadata = registry.get_all_registered_class_metadata() logger.info("Preloaded metadata for %d scenarios", len(metadata)) diff --git a/tests/end_to_end/test_scenarios.py b/tests/end_to_end/test_scenarios.py index dcafd66296..22b991a175 100644 --- a/tests/end_to_end/test_scenarios.py +++ b/tests/end_to_end/test_scenarios.py @@ -34,7 +34,7 @@ #: Per-scenario override map for initializers. A scenario absent here falls back #: to ``DEFAULT_INITIALIZERS``. Keys use the dotted registry name -#: (``.``) returned by ``ScenarioRegistry.get_names()``. +#: (``.``) returned by ``ScenarioRegistry.get_class_names()``. SCENARIO_INITIALIZERS: dict[str, list[str]] = {} #: Per-scenario extra CLI args appended after the standard flag block. Keys use @@ -56,7 +56,7 @@ def get_all_scenarios(): list[str]: Sorted list of scenario names. """ registry = ScenarioRegistry.get_registry_singleton() - return registry.get_names() + return registry.get_class_names() def _initializers_for(scenario_name: str) -> list[str]: diff --git a/tests/integration/memory/test_azure_sql_memory_integration.py b/tests/integration/memory/test_azure_sql_memory_integration.py index 716fa26c51..ff4a749cf8 100644 --- a/tests/integration/memory/test_azure_sql_memory_integration.py +++ b/tests/integration/memory/test_azure_sql_memory_integration.py @@ -8,6 +8,7 @@ import numpy as np from sqlalchemy.exc import SQLAlchemyError +from unit.mocks import make_scenario_result from pyrit.memory import AzureSQLMemory from pyrit.memory.memory_models import ( @@ -21,8 +22,6 @@ AttackResult, ComponentIdentifier, MessagePiece, - ScenarioIdentifier, - ScenarioResult, SeedPrompt, ) @@ -194,9 +193,17 @@ async def test_get_seeds_with_metadata_filter(azuresql_instance: AzureSQLMemory) value2 = np.random.randint(0, 10000) # Use unique seed values to avoid deduplication - sp1 = SeedPrompt(value=f"sp1-{test_id}", data_type="text", metadata={"key1": value1}, added_by=test_id) + sp1 = SeedPrompt( + value=f"sp1-{test_id}", + data_type="text", + metadata={"key1": value1}, + added_by=test_id, + ) sp2 = SeedPrompt( - value=f"sp2-{test_id}", data_type="text", metadata={"key1": value2, "key2": value1}, added_by=test_id + value=f"sp2-{test_id}", + data_type="text", + metadata={"key1": value2, "key2": value1}, + added_by=test_id, ) # Use public async API method @@ -256,7 +263,11 @@ async def test_get_attack_results_by_labels(azuresql_instance: AzureSQLMemory): role="user", original_value="Test 1", converted_value="Test 1", - labels={"op_id": f"op123_{test_id}", "category": "test", "priority": "high"}, + labels={ + "op_id": f"op123_{test_id}", + "category": "test", + "priority": "high", + }, ) piece2 = MessagePiece( conversation_id=conversation_ids[1], @@ -319,7 +330,9 @@ async def test_get_attack_results_by_labels(azuresql_instance: AzureSQLMemory): assert len(results) == 0 -async def test_scenario_result_scorer_identifier_roundtrip(azuresql_instance: AzureSQLMemory): +async def test_scenario_result_scorer_identifier_roundtrip( + azuresql_instance: AzureSQLMemory, +): """ Integration test for storing and retrieving objective_scorer_identifier in ScenarioResult. @@ -336,11 +349,9 @@ async def test_scenario_result_scorer_identifier_roundtrip(azuresql_instance: Az ) # Create scenario with scorer identifier - scenario = ScenarioResult( - scenario_identifier=ScenarioIdentifier( - name=f"Scorer Test Scenario {test_id}", - scenario_version=1, - ), + scenario = make_scenario_result( + scenario_name=f"Scorer Test Scenario {test_id}", + scenario_version=1, objective_target_identifier=ComponentIdentifier.from_dict( {"endpoint": f"https://test-{test_id}.example.com"} ), @@ -380,22 +391,30 @@ async def test_get_scenario_results_by_labels(azuresql_instance: AzureSQLMemory) with cleanup_scenario_data(azuresql_instance, test_id): # Create scenario results with labels scorer_id = get_test_scorer_identifier() - scenario1 = ScenarioResult( - scenario_identifier=ScenarioIdentifier(name=f"Test Scenario 1 {test_id}", scenario_version=1), + scenario1 = make_scenario_result( + scenario_name=f"Test Scenario 1 {test_id}", + scenario_version=1, objective_target_identifier=ComponentIdentifier.from_dict({"endpoint": "https://api.openai.com"}), attack_results={}, objective_scorer_identifier=scorer_id, - labels={"environment": "test", "priority": "high", "team": "red", "test_id": test_id}, + labels={ + "environment": "test", + "priority": "high", + "team": "red", + "test_id": test_id, + }, ) - scenario2 = ScenarioResult( - scenario_identifier=ScenarioIdentifier(name=f"Test Scenario 2 {test_id}", scenario_version=1), + scenario2 = make_scenario_result( + scenario_name=f"Test Scenario 2 {test_id}", + scenario_version=1, objective_target_identifier=ComponentIdentifier.from_dict({"endpoint": "https://api.azure.com"}), attack_results={}, objective_scorer_identifier=scorer_id, labels={"environment": "test", "priority": "high", "test_id": test_id}, ) - scenario3 = ScenarioResult( - scenario_identifier=ScenarioIdentifier(name=f"Test Scenario 3 {test_id}", scenario_version=1), + scenario3 = make_scenario_result( + scenario_name=f"Test Scenario 3 {test_id}", + scenario_version=1, objective_target_identifier=ComponentIdentifier.from_dict({"endpoint": "https://api.anthropic.com"}), attack_results={}, objective_scorer_identifier=scorer_id, @@ -407,7 +426,7 @@ async def test_get_scenario_results_by_labels(azuresql_instance: AzureSQLMemory) # Test filtering by single label results = azuresql_instance.get_scenario_results(labels={"environment": "test", "test_id": test_id}) assert len(results) == 2 - names = {r.scenario_identifier.name for r in results} + names = {r.scenario_name for r in results} assert f"Test Scenario 1 {test_id}" in names assert f"Test Scenario 2 {test_id}" in names @@ -421,14 +440,16 @@ async def test_get_scenario_results_by_labels(azuresql_instance: AzureSQLMemory) labels={"environment": "test", "team": "red", "test_id": test_id} ) assert len(results) == 1 - assert results[0].scenario_identifier.name == f"Test Scenario 1 {test_id}" + assert results[0].scenario_name == f"Test Scenario 1 {test_id}" # Test filtering with no matches results = azuresql_instance.get_scenario_results(labels={"environment": "staging", "test_id": test_id}) assert len(results) == 0 -async def test_get_scenario_results_by_target_endpoint(azuresql_instance: AzureSQLMemory): +async def test_get_scenario_results_by_target_endpoint( + azuresql_instance: AzureSQLMemory, +): """ Integration test for SQL Azure case-insensitive endpoint filtering. @@ -441,32 +462,36 @@ async def test_get_scenario_results_by_target_endpoint(azuresql_instance: AzureS with cleanup_scenario_data_by_field(azuresql_instance, test_id, "endpoint"): # Create scenario results with different endpoints scorer_id = get_test_scorer_identifier() - scenario1 = ScenarioResult( - scenario_identifier=ScenarioIdentifier(name=f"OpenAI Test {test_id}", scenario_version=1), + scenario1 = make_scenario_result( + scenario_name=f"OpenAI Test {test_id}", + scenario_version=1, objective_target_identifier=ComponentIdentifier.from_dict( {"endpoint": f"https://api-{test_id}.openai.com/v1/chat"} ), attack_results={}, objective_scorer_identifier=scorer_id, ) - scenario2 = ScenarioResult( - scenario_identifier=ScenarioIdentifier(name=f"Azure OpenAI Test {test_id}", scenario_version=1), + scenario2 = make_scenario_result( + scenario_name=f"Azure OpenAI Test {test_id}", + scenario_version=1, objective_target_identifier=ComponentIdentifier.from_dict( {"endpoint": f"https://myresource-{test_id}.openai.azure.com/openai"} ), attack_results={}, objective_scorer_identifier=scorer_id, ) - scenario3 = ScenarioResult( - scenario_identifier=ScenarioIdentifier(name=f"Anthropic Test {test_id}", scenario_version=1), + scenario3 = make_scenario_result( + scenario_name=f"Anthropic Test {test_id}", + scenario_version=1, objective_target_identifier=ComponentIdentifier.from_dict( {"endpoint": f"https://api-{test_id}.anthropic.com/v1/messages"} ), attack_results={}, objective_scorer_identifier=scorer_id, ) - scenario4 = ScenarioResult( - scenario_identifier=ScenarioIdentifier(name=f"Azure Other {test_id}", scenario_version=1), + scenario4 = make_scenario_result( + scenario_name=f"Azure Other {test_id}", + scenario_version=1, objective_target_identifier=ComponentIdentifier.from_dict( {"endpoint": f"https://myresource-{test_id}.cognitiveservices.azure.com"} ), @@ -485,27 +510,29 @@ async def test_get_scenario_results_by_target_endpoint(azuresql_instance: AzureS results = azuresql_instance.get_scenario_results(objective_target_endpoint=f"{test_id}.openai") assert len(results) == 2 - names = {r.scenario_identifier.name for r in results} + names = {r.scenario_name for r in results} assert f"OpenAI Test {test_id}" in names assert f"Azure OpenAI Test {test_id}" in names # Test case-insensitive with AZURE results = azuresql_instance.get_scenario_results(objective_target_endpoint=f"{test_id}.openai.AZURE") assert len(results) == 1 - assert results[0].scenario_identifier.name == f"Azure OpenAI Test {test_id}" + assert results[0].scenario_name == f"Azure OpenAI Test {test_id}" # Test anthropic results = azuresql_instance.get_scenario_results(objective_target_endpoint=f"{test_id}.AnThRoPiC") assert len(results) == 1 - assert results[0].scenario_identifier.name == f"Anthropic Test {test_id}" + assert results[0].scenario_name == f"Anthropic Test {test_id}" # Test cognitiveservices results = azuresql_instance.get_scenario_results(objective_target_endpoint=f"{test_id}.cognitiveservices") assert len(results) == 1 - assert results[0].scenario_identifier.name == f"Azure Other {test_id}" + assert results[0].scenario_name == f"Azure Other {test_id}" -async def test_get_scenario_results_by_target_model_name(azuresql_instance: AzureSQLMemory): +async def test_get_scenario_results_by_target_model_name( + azuresql_instance: AzureSQLMemory, +): """ Integration test for SQL Azure case-insensitive model name filtering. @@ -518,26 +545,30 @@ async def test_get_scenario_results_by_target_model_name(azuresql_instance: Azur with cleanup_scenario_data_by_field(azuresql_instance, test_id, "model_name"): # Create scenario results with different model names scorer_id = get_test_scorer_identifier() - scenario1 = ScenarioResult( - scenario_identifier=ScenarioIdentifier(name=f"GPT-4 Test {test_id}", scenario_version=1), + scenario1 = make_scenario_result( + scenario_name=f"GPT-4 Test {test_id}", + scenario_version=1, objective_target_identifier=ComponentIdentifier.from_dict({"model_name": f"gpt-4-turbo-{test_id}"}), attack_results={}, objective_scorer_identifier=scorer_id, ) - scenario2 = ScenarioResult( - scenario_identifier=ScenarioIdentifier(name=f"GPT-4 Omni Test {test_id}", scenario_version=1), + scenario2 = make_scenario_result( + scenario_name=f"GPT-4 Omni Test {test_id}", + scenario_version=1, objective_target_identifier=ComponentIdentifier.from_dict({"model_name": f"gpt-4o-{test_id}"}), attack_results={}, objective_scorer_identifier=scorer_id, ) - scenario3 = ScenarioResult( - scenario_identifier=ScenarioIdentifier(name=f"GPT-3.5 Test {test_id}", scenario_version=1), + scenario3 = make_scenario_result( + scenario_name=f"GPT-3.5 Test {test_id}", + scenario_version=1, objective_target_identifier=ComponentIdentifier.from_dict({"model_name": f"gpt-3.5-turbo-{test_id}"}), attack_results={}, objective_scorer_identifier=scorer_id, ) - scenario4 = ScenarioResult( - scenario_identifier=ScenarioIdentifier(name=f"Claude Test {test_id}", scenario_version=1), + scenario4 = make_scenario_result( + scenario_name=f"Claude Test {test_id}", + scenario_version=1, objective_target_identifier=ComponentIdentifier.from_dict({"model_name": f"claude-3-opus-{test_id}"}), attack_results={}, objective_scorer_identifier=scorer_id, @@ -552,27 +583,27 @@ async def test_get_scenario_results_by_target_model_name(azuresql_instance: Azur # Test case-insensitive substring matching - gpt with test_id results = azuresql_instance.get_scenario_results(objective_target_model_name=f"gpt-4-turbo-{test_id}") assert len(results) == 1 - assert results[0].scenario_identifier.name == f"GPT-4 Test {test_id}" + assert results[0].scenario_name == f"GPT-4 Test {test_id}" # Test case-insensitive substring matching - GPT-4 (uppercase) results = azuresql_instance.get_scenario_results(objective_target_model_name=f"GPT-4o-{test_id}") assert len(results) == 1 - assert results[0].scenario_identifier.name == f"GPT-4 Omni Test {test_id}" + assert results[0].scenario_name == f"GPT-4 Omni Test {test_id}" # Test substring in the middle - version number results = azuresql_instance.get_scenario_results(objective_target_model_name=f"3.5-turbo-{test_id}") assert len(results) == 1 - assert results[0].scenario_identifier.name == f"GPT-3.5 Test {test_id}" + assert results[0].scenario_name == f"GPT-3.5 Test {test_id}" # Test case-insensitive with different model family results = azuresql_instance.get_scenario_results(objective_target_model_name=f"CLAUDE-3-opus-{test_id}") assert len(results) == 1 - assert results[0].scenario_identifier.name == f"Claude Test {test_id}" + assert results[0].scenario_name == f"Claude Test {test_id}" # Test turbo suffix with test_id results = azuresql_instance.get_scenario_results(objective_target_model_name=f"turbo-{test_id}") assert len(results) == 2 - names = {r.scenario_identifier.name for r in results} + names = {r.scenario_name for r in results} assert f"GPT-4 Test {test_id}" in names assert f"GPT-3.5 Test {test_id}" in names @@ -592,10 +623,10 @@ async def test_get_scenario_results_combined_filters(azuresql_instance: AzureSQL with cleanup_scenario_data(azuresql_instance, test_id): # Create scenario results with various attributes scorer_id = get_test_scorer_identifier() - scenario1 = ScenarioResult( - scenario_identifier=ScenarioIdentifier( - name=f"Production Test {test_id}", scenario_version=1, pyrit_version="0.4.0" - ), + scenario1 = make_scenario_result( + scenario_name=f"Production Test {test_id}", + scenario_version=1, + pyrit_version="0.4.0", objective_target_identifier=ComponentIdentifier.from_dict( { "endpoint": f"https://api-{test_id}.openai.com", @@ -607,10 +638,10 @@ async def test_get_scenario_results_combined_filters(azuresql_instance: AzureSQL labels={"environment": "prod", "priority": "high", "test_id": test_id}, completion_time=now, ) - scenario2 = ScenarioResult( - scenario_identifier=ScenarioIdentifier( - name=f"Test Environment {test_id}", scenario_version=1, pyrit_version="0.4.0" - ), + scenario2 = make_scenario_result( + scenario_name=f"Test Environment {test_id}", + scenario_version=1, + pyrit_version="0.4.0", objective_target_identifier=ComponentIdentifier.from_dict( { "endpoint": f"https://test-{test_id}.openai.com", @@ -622,10 +653,10 @@ async def test_get_scenario_results_combined_filters(azuresql_instance: AzureSQL labels={"environment": "test", "priority": "low", "test_id": test_id}, completion_time=yesterday, ) - scenario3 = ScenarioResult( - scenario_identifier=ScenarioIdentifier( - name=f"Old Version Test {test_id}", scenario_version=1, pyrit_version="0.3.0" - ), + scenario3 = make_scenario_result( + scenario_name=f"Old Version Test {test_id}", + scenario_version=1, + pyrit_version="0.3.0", objective_target_identifier=ComponentIdentifier.from_dict( { "endpoint": f"https://api-{test_id}.openai.com", @@ -642,10 +673,11 @@ async def test_get_scenario_results_combined_filters(azuresql_instance: AzureSQL # Test combining name filter with labels results = azuresql_instance.get_scenario_results( - scenario_name=f"Test Environment {test_id}", labels={"environment": "test", "test_id": test_id} + scenario_name=f"Test Environment {test_id}", + labels={"environment": "test", "test_id": test_id}, ) assert len(results) == 1 - assert results[0].scenario_identifier.name == f"Test Environment {test_id}" + assert results[0].scenario_name == f"Test Environment {test_id}" # Test combining endpoint, model name, and labels results = azuresql_instance.get_scenario_results( @@ -654,12 +686,14 @@ async def test_get_scenario_results_combined_filters(azuresql_instance: AzureSQL labels={"priority": "high", "test_id": test_id}, ) assert len(results) == 1 - assert results[0].scenario_identifier.name == f"Production Test {test_id}" + assert results[0].scenario_name == f"Production Test {test_id}" # Test combining version and time filters with test_id # Add 1 second buffer to account for SQL Server datetime precision differences results = azuresql_instance.get_scenario_results( - pyrit_version="0.4.0", added_before=now + timedelta(seconds=1), labels={"test_id": test_id} + pyrit_version="0.4.0", + added_before=now + timedelta(seconds=1), + labels={"test_id": test_id}, ) assert len(results) == 2 @@ -672,10 +706,11 @@ async def test_get_scenario_results_combined_filters(azuresql_instance: AzureSQL labels={"environment": "prod", "priority": "high", "test_id": test_id}, ) assert len(results) == 1 - assert results[0].scenario_identifier.name == f"Production Test {test_id}" + assert results[0].scenario_name == f"Production Test {test_id}" # Test combining filters with no matches results = azuresql_instance.get_scenario_results( - objective_target_endpoint=f"api-{test_id}", labels={"environment": "staging", "test_id": test_id} + objective_target_endpoint=f"api-{test_id}", + labels={"environment": "staging", "test_id": test_id}, ) assert len(results) == 0 diff --git a/tests/unit/backend/test_scenario_run_routes.py b/tests/unit/backend/test_scenario_run_routes.py index 1942efc04c..32897a8a66 100644 --- a/tests/unit/backend/test_scenario_run_routes.py +++ b/tests/unit/backend/test_scenario_run_routes.py @@ -17,6 +17,7 @@ from pyrit.backend.models.scenarios import ScenarioRunListResponse from pyrit.models import ScenarioRunState from pyrit.models.catalog.scenario import ScenarioRunSummary +from unit.mocks import make_scenario_result @pytest.fixture @@ -229,7 +230,7 @@ class TestGetScenarioRunResultsRoute: def test_get_results_returns_200(self, client: TestClient) -> None: """Test that getting results of a completed run returns 200.""" - from pyrit.models import AttackOutcome, AttackResult, ComponentIdentifier, ScenarioIdentifier, ScenarioResult + from pyrit.models import AttackOutcome, AttackResult, ComponentIdentifier attack = AttackResult( conversation_id="conv-1", @@ -239,8 +240,9 @@ def test_get_results_returns_200(self, client: TestClient) -> None: execution_time_ms=100, timestamp=datetime(2025, 1, 1, tzinfo=timezone.utc), ) - scenario_result = ScenarioResult( - scenario_identifier=ScenarioIdentifier(name="foundry.red_team_agent", description="Foundry red-team agent"), + scenario_result = make_scenario_result( + scenario_name="foundry.red_team_agent", + scenario_description="Foundry red-team agent", objective_target_identifier=ComponentIdentifier.from_dict( {"__type__": "FakeTarget", "__module__": "test.mod", "params": {}} ), @@ -258,7 +260,7 @@ def test_get_results_returns_200(self, client: TestClient) -> None: assert response.status_code == status.HTTP_200_OK data = response.json() - assert data["scenario_identifier"]["name"] == "foundry.red_team_agent" + assert data["scenario_name"] == "foundry.red_team_agent" assert "base64_attack" in data["attack_results"] def test_get_results_not_found_returns_404(self, client: TestClient) -> None: diff --git a/tests/unit/backend/test_scenario_run_service.py b/tests/unit/backend/test_scenario_run_service.py index d00bdff28c..89180bfd80 100644 --- a/tests/unit/backend/test_scenario_run_service.py +++ b/tests/unit/backend/test_scenario_run_service.py @@ -88,8 +88,8 @@ def _make_db_scenario_result( """Create a mock ScenarioResult as returned by CentralMemory.""" sr = MagicMock() sr.id = result_id - sr.scenario_identifier.name = scenario_name - sr.scenario_identifier.version = 1 + sr.scenario_name = scenario_name + sr.scenario_version = 1 sr.scenario_run_state = run_state sr.get_strategies_used.return_value = [] sr.attack_results = attack_results or {} @@ -132,6 +132,7 @@ def mock_all_registries(mock_memory): mock_sr = MagicMock() mock_sr.get_class.return_value = mock_scenario_class mock_sr.create_instance.return_value = mock_scenario_instance + mock_sr.create_and_initialize_async = AsyncMock(return_value=mock_scenario_instance) mock_tr = MagicMock() mock_tr.instances.get.return_value = MagicMock() @@ -285,7 +286,7 @@ def _lookup(name): service = ScenarioRunService() await service.start_run_async(request=_make_request(strategies=["strat_a", "strat_b"])) - init_call = scenario_instance.initialize_async.await_args + init_call = mock_all_registries["scenario_registry"].create_and_initialize_async.await_args assert init_call.kwargs["scenario_strategies"] == [strategy_a, strategy_b] async def test_start_run_max_dataset_size_uses_default_config(self, mock_all_registries) -> None: @@ -300,7 +301,7 @@ async def test_start_run_max_dataset_size_uses_default_config(self, mock_all_reg # max_dataset_size on the default config was overridden assert default_config.max_dataset_size == 5 - init_call = scenario_instance.initialize_async.await_args + init_call = mock_all_registries["scenario_registry"].create_and_initialize_async.await_args assert init_call.kwargs["dataset_config"] is default_config async def test_start_run_dataset_names_preserves_subclass_config_type(self, mock_all_registries) -> None: @@ -323,7 +324,7 @@ class _MarkerDatasetConfiguration(DatasetConfiguration): service = ScenarioRunService() await service.start_run_async(request=_make_request(dataset_names=["custom_a", "custom_b"], max_dataset_size=3)) - init_call = scenario_instance.initialize_async.await_args + init_call = mock_all_registries["scenario_registry"].create_and_initialize_async.await_args built_config = init_call.kwargs["dataset_config"] # Type is preserved (this is the regression assertion) @@ -349,7 +350,7 @@ class _MarkerDatasetConfiguration(DatasetConfiguration): service = ScenarioRunService() await service.start_run_async(request=_make_request(dataset_names=["only_this"])) - init_call = scenario_instance.initialize_async.await_args + init_call = mock_all_registries["scenario_registry"].create_and_initialize_async.await_args built_config = init_call.kwargs["dataset_config"] assert type(built_config) is _MarkerDatasetConfiguration assert built_config.dataset_names == ["only_this"] @@ -375,7 +376,7 @@ def __init__(self, *, required_extra: str, **kwargs: Any) -> None: with caplog.at_level("WARNING", logger=_svc_mod.logger.name): await service.start_run_async(request=_make_request(dataset_names=["custom"])) - init_call = scenario_instance.initialize_async.await_args + init_call = mock_all_registries["scenario_registry"].create_and_initialize_async.await_args built_config = init_call.kwargs["dataset_config"] # Fallback is the generic base class, not the subclass @@ -431,7 +432,9 @@ class _MarkerDatasetConfiguration(DatasetConfiguration): service = ScenarioRunService() await service.start_run_async(request=_make_request(dataset_names=["a", "b"], max_dataset_size=7)) - built_config = scenario_instance.initialize_async.await_args.kwargs["dataset_config"] + built_config = mock_all_registries["scenario_registry"].create_and_initialize_async.await_args.kwargs[ + "dataset_config" + ] assert type(built_config) is _MarkerDatasetConfiguration assert built_config.dataset_names == ["a", "b"] assert built_config.max_dataset_size == 7 @@ -440,17 +443,18 @@ async def test_start_run_exceeds_concurrent_limit(self, mock_all_registries) -> """Test that exceeding concurrent run limit raises ValueError.""" service = ScenarioRunService() scenario_instance = mock_all_registries["scenario_instance"] + mock_sr = mock_all_registries["scenario_registry"] # Each call needs a unique scenario_result_id call_count = 0 - original_init = scenario_instance.initialize_async - async def _set_unique_id(**kwargs: object) -> None: + async def _set_unique_id(*args: object, **kwargs: object) -> object: nonlocal call_count call_count += 1 scenario_instance._scenario_result_id = f"sr-uuid-{call_count}" + return scenario_instance - scenario_instance.initialize_async = AsyncMock(side_effect=_set_unique_id) + mock_sr.create_and_initialize_async = AsyncMock(side_effect=_set_unique_id) # Fill up to the limit for _ in range(_DEFAULT_MAX_CONCURRENT_RUNS): @@ -474,25 +478,27 @@ async def test_start_run_runs_initializers(self, mock_all_registries) -> None: assert mock_init_instance.initialize_async.await_count == 2 async def test_start_run_passes_scenario_result_id_for_resume(self, mock_all_registries) -> None: - """Test that scenario_result_id is passed to the registry constructor for resumption.""" + """Test that scenario_result_id is passed to the registry for resumption.""" service = ScenarioRunService() mock_sr = mock_all_registries["scenario_registry"] response = await service.start_run_async(request=_make_request(scenario_result_id="existing-result-uuid")) assert response.status == ScenarioRunState.IN_PROGRESS - mock_sr.create_instance.assert_called_once_with( - "foundry.red_team_agent", scenario_result_id="existing-result-uuid" - ) + call = mock_sr.create_and_initialize_async.await_args + assert call.args[0] == "foundry.red_team_agent" + assert call.kwargs["scenario_result_id"] == "existing-result-uuid" async def test_start_run_omits_scenario_result_id_when_none(self, mock_all_registries) -> None: - """Test that scenario_result_id is not passed to the registry constructor when not provided.""" + """Test that scenario_result_id is None when not provided in the request.""" service = ScenarioRunService() mock_sr = mock_all_registries["scenario_registry"] await service.start_run_async(request=_make_request()) - mock_sr.create_instance.assert_called_once_with("foundry.red_team_agent") + call = mock_sr.create_and_initialize_async.await_args + assert call.args[0] == "foundry.red_team_agent" + assert call.kwargs["scenario_result_id"] is None class TestScenarioRunServiceGetRun: @@ -904,7 +910,7 @@ def test_unknown_base_strategy_raises(self, mock_memory) -> None: ) async def test_start_run_forwards_strategy_converters(self, mock_all_registries) -> None: - """A converter token is resolved and forwarded to ``initialize_async`` as ``strategy_converters``.""" + """A converter token is resolved and forwarded through the registry as ``strategy_converters``.""" conv = MagicMock(spec=PromptConverter) scenario_instance = mock_all_registries["scenario_instance"] scenario_instance._strategy_class = _StubStrategy @@ -913,6 +919,6 @@ async def test_start_run_forwards_strategy_converters(self, mock_all_registries) with _patch_converter_registry({"translation_spanish": conv}): await service.start_run_async(request=_make_request(strategies=["role_play:converter.translation_spanish"])) - init_call = scenario_instance.initialize_async.await_args + init_call = mock_all_registries["scenario_registry"].create_and_initialize_async.await_args assert init_call.kwargs["scenario_strategies"] == [_StubStrategy.ROLE_PLAY] assert init_call.kwargs["strategy_converters"] == {"role_play": [conv]} diff --git a/tests/unit/backend/test_scenario_service.py b/tests/unit/backend/test_scenario_service.py index cef5a45966..57489df81e 100644 --- a/tests/unit/backend/test_scenario_service.py +++ b/tests/unit/backend/test_scenario_service.py @@ -14,9 +14,13 @@ from pyrit.backend.main import app from pyrit.backend.models.common import PaginationInfo -from pyrit.backend.models.scenarios import ListRegisteredScenariosResponse, RegisteredScenario -from pyrit.backend.services.scenario_service import ScenarioService, get_scenario_service +from pyrit.backend.models.scenarios import ListRegisteredScenariosResponse +from pyrit.backend.services.scenario_service import ( + ScenarioService, + get_scenario_service, +) from pyrit.models import Parameter +from pyrit.models.catalog.scenario import RegisteredScenario from pyrit.registry import ScenarioMetadata @@ -43,7 +47,6 @@ def _make_scenario_metadata( all_strategies: tuple[str, ...] = ("role_play", "many_shot"), aggregate_strategies: tuple[str, ...] = ("all", "default"), default_datasets: tuple[str, ...] = ("test_dataset",), - max_dataset_size: int | None = None, ) -> ScenarioMetadata: """Create a ScenarioMetadata instance for testing.""" return ScenarioMetadata( @@ -55,7 +58,6 @@ def _make_scenario_metadata( all_strategies=all_strategies, aggregate_strategies=aggregate_strategies, default_datasets=default_datasets, - max_dataset_size=max_dataset_size, ) @@ -72,7 +74,7 @@ async def test_list_scenarios_returns_empty_when_no_scenarios(self) -> None: with patch.object(ScenarioService, "__init__", lambda self: None): service = ScenarioService() service._registry = MagicMock() - service._registry.list_metadata.return_value = [] + service._registry.get_all_registered_class_metadata.return_value = [] result = await service.list_scenarios_async() @@ -86,7 +88,7 @@ async def test_list_scenarios_returns_scenarios_from_registry(self) -> None: with patch.object(ScenarioService, "__init__", lambda self: None): service = ScenarioService() service._registry = MagicMock() - service._registry.list_metadata.return_value = [metadata] + service._registry.get_all_registered_class_metadata.return_value = [metadata] result = await service.list_scenarios_async() @@ -98,7 +100,6 @@ async def test_list_scenarios_returns_scenarios_from_registry(self) -> None: assert result.items[0].aggregate_strategies == ["all", "default"] assert result.items[0].all_strategies == ["role_play", "many_shot"] assert result.items[0].default_datasets == ["test_dataset"] - assert result.items[0].max_dataset_size is None async def test_list_scenarios_paginates_with_limit(self) -> None: """Test that list respects the limit parameter.""" @@ -109,7 +110,7 @@ async def test_list_scenarios_paginates_with_limit(self) -> None: with patch.object(ScenarioService, "__init__", lambda self: None): service = ScenarioService() service._registry = MagicMock() - service._registry.list_metadata.return_value = metadata_list + service._registry.get_all_registered_class_metadata.return_value = metadata_list result = await service.list_scenarios_async(limit=3) @@ -126,7 +127,7 @@ async def test_list_scenarios_paginates_with_cursor(self) -> None: with patch.object(ScenarioService, "__init__", lambda self: None): service = ScenarioService() service._registry = MagicMock() - service._registry.list_metadata.return_value = metadata_list + service._registry.get_all_registered_class_metadata.return_value = metadata_list result = await service.list_scenarios_async(limit=2, cursor="test.scenario_1") @@ -144,7 +145,7 @@ async def test_list_scenarios_last_page_has_more_false(self) -> None: with patch.object(ScenarioService, "__init__", lambda self: None): service = ScenarioService() service._registry = MagicMock() - service._registry.list_metadata.return_value = metadata_list + service._registry.get_all_registered_class_metadata.return_value = metadata_list result = await service.list_scenarios_async(limit=5) @@ -152,19 +153,6 @@ async def test_list_scenarios_last_page_has_more_false(self) -> None: assert result.pagination.has_more is False assert result.pagination.next_cursor is None - async def test_list_scenarios_includes_max_dataset_size(self) -> None: - """Test that max_dataset_size is included in response.""" - metadata = _make_scenario_metadata(max_dataset_size=10) - - with patch.object(ScenarioService, "__init__", lambda self: None): - service = ScenarioService() - service._registry = MagicMock() - service._registry.list_metadata.return_value = [metadata] - - result = await service.list_scenarios_async() - - assert result.items[0].max_dataset_size == 10 - class TestScenarioServiceGetScenario: """Tests for ScenarioService.get_scenario_async.""" @@ -176,7 +164,7 @@ async def test_get_scenario_returns_matching_scenario(self) -> None: with patch.object(ScenarioService, "__init__", lambda self: None): service = ScenarioService() service._registry = MagicMock() - service._registry.list_metadata.return_value = [metadata] + service._registry.get_registered_class_metadata.return_value = metadata result = await service.get_scenario_async(scenario_name="foundry.red_team_agent") @@ -188,7 +176,7 @@ async def test_get_scenario_returns_none_for_missing(self) -> None: with patch.object(ScenarioService, "__init__", lambda self: None): service = ScenarioService() service._registry = MagicMock() - service._registry.list_metadata.return_value = [] + service._registry.get_registered_class_metadata.return_value = None result = await service.get_scenario_async(scenario_name="nonexistent") @@ -232,7 +220,6 @@ def test_list_scenarios_with_items(self, client: TestClient) -> None: aggregate_strategies=["all", "default"], all_strategies=["role_play", "many_shot"], default_datasets=["airt_hate"], - max_dataset_size=10, ) with patch("pyrit.backend.routes.scenarios.get_scenario_service") as mock_get_service: @@ -257,7 +244,6 @@ def test_list_scenarios_with_items(self, client: TestClient) -> None: assert item["aggregate_strategies"] == ["all", "default"] assert item["all_strategies"] == ["role_play", "many_shot"] assert item["default_datasets"] == ["airt_hate"] - assert item["max_dataset_size"] == 10 def test_list_scenarios_passes_pagination_params(self, client: TestClient) -> None: """Test that pagination params are forwarded to service.""" @@ -286,7 +272,6 @@ def test_get_scenario_returns_200(self, client: TestClient) -> None: aggregate_strategies=["all"], all_strategies=["role_play"], default_datasets=["airt_hate"], - max_dataset_size=None, ) with patch("pyrit.backend.routes.scenarios.get_scenario_service") as mock_get_service: @@ -321,7 +306,6 @@ def test_get_scenario_with_dotted_name(self, client: TestClient) -> None: aggregate_strategies=["all"], all_strategies=["base64", "rot13"], default_datasets=[], - max_dataset_size=None, ) with patch("pyrit.backend.routes.scenarios.get_scenario_service") as mock_get_service: @@ -355,7 +339,6 @@ async def test_list_scenarios_includes_supported_parameters(self) -> None: all_strategies=("role_play",), aggregate_strategies=("all",), default_datasets=("test_dataset",), - max_dataset_size=None, supported_parameters=( Parameter( name="max_turns", @@ -375,7 +358,7 @@ async def test_list_scenarios_includes_supported_parameters(self) -> None: with patch.object(ScenarioService, "__init__", lambda self: None): service = ScenarioService() service._registry = MagicMock() - service._registry.list_metadata.return_value = [metadata] + service._registry.get_all_registered_class_metadata.return_value = [metadata] result = await service.list_scenarios_async() @@ -404,7 +387,7 @@ async def test_scenario_with_no_parameters_has_empty_list(self) -> None: with patch.object(ScenarioService, "__init__", lambda self: None): service = ScenarioService() service._registry = MagicMock() - service._registry.list_metadata.return_value = [metadata] + service._registry.get_all_registered_class_metadata.return_value = [metadata] result = await service.list_scenarios_async() @@ -421,7 +404,6 @@ async def test_supported_parameters_with_none_default(self) -> None: all_strategies=("all",), aggregate_strategies=("all",), default_datasets=(), - max_dataset_size=None, supported_parameters=( Parameter( name="optional_param", @@ -435,7 +417,7 @@ async def test_supported_parameters_with_none_default(self) -> None: with patch.object(ScenarioService, "__init__", lambda self: None): service = ScenarioService() service._registry = MagicMock() - service._registry.list_metadata.return_value = [metadata] + service._registry.get_all_registered_class_metadata.return_value = [metadata] result = await service.list_scenarios_async() diff --git a/tests/unit/cli/test_api_client.py b/tests/unit/cli/test_api_client.py index 92add6622f..63a1de07a3 100644 --- a/tests/unit/cli/test_api_client.py +++ b/tests/unit/cli/test_api_client.py @@ -21,6 +21,7 @@ TargetCapabilitiesInfo, TargetInstance, ) +from unit.mocks import make_scenario_result @pytest.fixture() @@ -60,7 +61,6 @@ def _scenario_payload(*, scenario_name: str = "s1") -> dict: "aggregate_strategies": [], "all_strategies": ["single_turn"], "default_datasets": [], - "max_dataset_size": None, "supported_parameters": [], } @@ -145,7 +145,9 @@ async def test_async_context_manager_passes_custom_request_timeout(mock_httpx_cl fake_async_client_cls.assert_called_once_with(base_url="http://localhost:8000", timeout=120.0) -async def test_async_context_manager_uses_default_when_request_timeout_is_none(mock_httpx_client): +async def test_async_context_manager_uses_default_when_request_timeout_is_none( + mock_httpx_client, +): c = PyRITApiClient(base_url="http://localhost:8000", request_timeout=None) fake_async_client_cls = MagicMock(return_value=mock_httpx_client) with patch("httpx.AsyncClient", fake_async_client_cls): @@ -358,10 +360,10 @@ async def test_get_scenario_run_async_wraps_connect_error(client, mock_httpx_cli async def test_get_scenario_run_results_async(client, mock_httpx_client): # Build a minimal ScenarioResult.to_dict() payload that from_dict can deserialize. - from pyrit.models import ScenarioIdentifier, ScenarioResult, ScenarioRunState + from pyrit.models import ScenarioResult, ScenarioRunState - scenario_result = ScenarioResult( - scenario_identifier=ScenarioIdentifier(name="x"), + scenario_result = make_scenario_result( + scenario_name="x", objective_target_identifier=None, objective_scorer_identifier=None, attack_results={}, diff --git a/tests/unit/cli/test_output.py b/tests/unit/cli/test_output.py index 9cac48fc2b..8561e01f24 100644 --- a/tests/unit/cli/test_output.py +++ b/tests/unit/cli/test_output.py @@ -22,6 +22,7 @@ TargetCapabilitiesInfo, TargetInstance, ) +from unit.mocks import make_scenario_result # --------------------------------------------------------------------------- # Typed-object factory helpers @@ -37,7 +38,6 @@ def _make_scenario(**overrides) -> RegisteredScenario: "aggregate_strategies": [], "all_strategies": [], "default_datasets": [], - "max_dataset_size": None, "supported_parameters": [], } defaults.update(overrides) @@ -170,7 +170,6 @@ def test_print_scenario_list_full(capsys): all_strategies=["s1", "s2", "s3"], default_strategy="s1", default_datasets=["d1", "d2"], - max_dataset_size=50, supported_parameters=[ Parameter( name="max_turns", @@ -195,7 +194,7 @@ def test_print_scenario_list_full(capsys): assert "single_turn" in captured.out assert "Available Strategies (3)" in captured.out assert "Default Strategy: s1" in captured.out - assert "Default Datasets (2, max 50 per dataset)" in captured.out + assert "Default Datasets (2)" in captured.out assert "Supported Parameters" in captured.out assert "max_turns" in captured.out assert "mode" in captured.out @@ -477,11 +476,8 @@ async def test_print_scenario_result_async_accepts_real_scenario_result(): AttackOutcome, AttackResult, ComponentIdentifier, - ScenarioIdentifier, - ScenarioResult, ) - identifier = ScenarioIdentifier(name="test.scenario", description="A test") target_identifier = ComponentIdentifier.model_validate( {"__type__": "FakeTarget", "__module__": "test.mod", "params": {}} ) @@ -493,8 +489,9 @@ async def test_print_scenario_result_async_accepts_real_scenario_result(): execution_time_ms=150, timestamp=datetime(2025, 1, 1, tzinfo=timezone.utc), ) - scenario_result = ScenarioResult( - scenario_identifier=identifier, + scenario_result = make_scenario_result( + scenario_name="test.scenario", + scenario_description="A test", objective_target_identifier=target_identifier, objective_scorer_identifier=None, attack_results={"strat_a": [attack]}, diff --git a/tests/unit/cli/test_pyrit_scan.py b/tests/unit/cli/test_pyrit_scan.py index 0b2cfa58e5..7ec86a0814 100644 --- a/tests/unit/cli/test_pyrit_scan.py +++ b/tests/unit/cli/test_pyrit_scan.py @@ -14,6 +14,7 @@ from pyrit.cli import _config_reader as pyrit_scan_config_reader from pyrit.cli import pyrit_scan from pyrit.models import Parameter +from unit.mocks import make_scenario_result def _sp(*, name, description="", default=None, param_type="str", choices=None, is_list=False) -> Parameter: @@ -168,7 +169,12 @@ def test_no_scenario_keys_returns_empty(self): def test_scenario_keys_extracted_with_prefix_stripped(self): result = pyrit_scan._extract_scenario_args( - parsed=Namespace(scenario_name="x", config_file=None, scenario__max_turns=10, scenario__mode="fast") + parsed=Namespace( + scenario_name="x", + config_file=None, + scenario__max_turns=10, + scenario__mode="fast", + ) ) assert result == {"max_turns": 10, "mode": "fast"} @@ -181,8 +187,6 @@ def _make_scenario_result(): AttackOutcome, AttackResult, ComponentIdentifier, - ScenarioIdentifier, - ScenarioResult, ScenarioRunState, ) @@ -194,8 +198,9 @@ def _make_scenario_result(): execution_time_ms=10, timestamp=datetime(2025, 1, 1, tzinfo=timezone.utc), ) - return ScenarioResult( - scenario_identifier=ScenarioIdentifier(name="test_scenario", description="A test"), + return make_scenario_result( + scenario_name="test_scenario", + scenario_description="A test", objective_target_identifier=ComponentIdentifier.model_validate( {"__type__": "FakeTarget", "__module__": "test.mod", "params": {}} ), @@ -234,7 +239,6 @@ def _mock_api_client(): aggregate_strategies=[], all_strategies=[], default_datasets=[], - max_dataset_size=None, supported_parameters=[], ) client.start_scenario_run_async.return_value = ScenarioRunSummary( @@ -275,7 +279,11 @@ def _mock_api_client(): class TestMain: """Tests for main function (thin REST client).""" - @patch("pyrit.cli._server_launcher.ServerLauncher.probe_health_async", new_callable=AsyncMock, return_value=True) + @patch( + "pyrit.cli._server_launcher.ServerLauncher.probe_health_async", + new_callable=AsyncMock, + return_value=True, + ) @patch("pyrit.cli.api_client.PyRITApiClient") def test_main_list_scenarios(self, mock_client_class, mock_probe): """Test main with --list-scenarios flag.""" @@ -287,7 +295,11 @@ def test_main_list_scenarios(self, mock_client_class, mock_probe): assert result == 0 mock_client.list_scenarios_async.assert_awaited_once() - @patch("pyrit.cli._server_launcher.ServerLauncher.probe_health_async", new_callable=AsyncMock, return_value=True) + @patch( + "pyrit.cli._server_launcher.ServerLauncher.probe_health_async", + new_callable=AsyncMock, + return_value=True, + ) @patch("pyrit.cli.api_client.PyRITApiClient") def test_main_list_initializers(self, mock_client_class, mock_probe): """Test main with --list-initializers flag.""" @@ -299,7 +311,11 @@ def test_main_list_initializers(self, mock_client_class, mock_probe): assert result == 0 mock_client.list_initializers_async.assert_awaited_once() - @patch("pyrit.cli._server_launcher.ServerLauncher.probe_health_async", new_callable=AsyncMock, return_value=True) + @patch( + "pyrit.cli._server_launcher.ServerLauncher.probe_health_async", + new_callable=AsyncMock, + return_value=True, + ) @patch("pyrit.cli.api_client.PyRITApiClient") def test_main_list_targets(self, mock_client_class, mock_probe): """Test main with --list-targets flag.""" @@ -340,7 +356,11 @@ def test_main_no_args_shows_help(self): result = pyrit_scan.main([]) assert result == 0 # shows help and exits - @patch("pyrit.cli._server_launcher.ServerLauncher.probe_health_async", new_callable=AsyncMock, return_value=True) + @patch( + "pyrit.cli._server_launcher.ServerLauncher.probe_health_async", + new_callable=AsyncMock, + return_value=True, + ) @patch("pyrit.cli.api_client.PyRITApiClient") @patch("pyrit.cli._output.print_scenario_result_async", new_callable=AsyncMock) def test_main_run_scenario(self, _mock_print, mock_client_class, mock_probe): @@ -354,7 +374,11 @@ def test_main_run_scenario(self, _mock_print, mock_client_class, mock_probe): mock_client.get_scenario_async.assert_awaited_once() mock_client.start_scenario_run_async.assert_awaited_once() - @patch("pyrit.cli._server_launcher.ServerLauncher.probe_health_async", new_callable=AsyncMock, return_value=True) + @patch( + "pyrit.cli._server_launcher.ServerLauncher.probe_health_async", + new_callable=AsyncMock, + return_value=True, + ) @patch("pyrit.cli.api_client.PyRITApiClient") @patch("pyrit.cli._output.print_scenario_result_async", new_callable=AsyncMock) def test_main_run_scenario_with_initializers(self, _mock_print, mock_client_class, mock_probe): @@ -369,7 +393,11 @@ def test_main_run_scenario_with_initializers(self, _mock_print, mock_client_clas request = call_kwargs["request"] assert request.initializers == ["target", "datasets"] - @patch("pyrit.cli._server_launcher.ServerLauncher.probe_health_async", new_callable=AsyncMock, return_value=False) + @patch( + "pyrit.cli._server_launcher.ServerLauncher.probe_health_async", + new_callable=AsyncMock, + return_value=False, + ) def test_main_server_not_available(self, mock_probe, capsys): """Test main when server is not available.""" result = pyrit_scan.main(["--list-scenarios"]) @@ -382,13 +410,21 @@ def test_main_malformed_config_is_hard_error(self, tmp_path, capsys): """A malformed --config-file should fail loudly, not silently use defaults.""" bad = tmp_path / "bad.yaml" bad.write_text(": :\nnot yaml: [unbalanced\n", encoding="utf-8") - with patch.object(pyrit_scan_config_reader, "_DEFAULT_CONFIG_FILE", tmp_path / "missing_default.yaml"): + with patch.object( + pyrit_scan_config_reader, + "_DEFAULT_CONFIG_FILE", + tmp_path / "missing_default.yaml", + ): result = pyrit_scan.main(["--list-scenarios", "--config-file", str(bad)]) assert result == 1 assert "not valid YAML" in capsys.readouterr().err - @patch("pyrit.cli._server_launcher.ServerLauncher.probe_health_async", new_callable=AsyncMock, return_value=False) + @patch( + "pyrit.cli._server_launcher.ServerLauncher.probe_health_async", + new_callable=AsyncMock, + return_value=False, + ) def test_main_stop_server(self, mock_probe, capsys): """Test main with --stop-server.""" result = pyrit_scan.main(["--stop-server"]) @@ -397,7 +433,11 @@ def test_main_stop_server(self, mock_probe, capsys): captured = capsys.readouterr() assert "No server running" in captured.out - @patch("pyrit.cli._server_launcher.ServerLauncher.probe_health_async", new_callable=AsyncMock, return_value=True) + @patch( + "pyrit.cli._server_launcher.ServerLauncher.probe_health_async", + new_callable=AsyncMock, + return_value=True, + ) @patch("pyrit.cli.api_client.PyRITApiClient") def test_main_scenario_not_found(self, mock_client_class, mock_probe, capsys): """Test main when scenario is not found on server.""" @@ -411,7 +451,11 @@ def test_main_scenario_not_found(self, mock_client_class, mock_probe, capsys): captured = capsys.readouterr() assert "not found" in captured.out - @patch("pyrit.cli._server_launcher.ServerLauncher.probe_health_async", new_callable=AsyncMock, return_value=True) + @patch( + "pyrit.cli._server_launcher.ServerLauncher.probe_health_async", + new_callable=AsyncMock, + return_value=True, + ) @patch("pyrit.cli.api_client.PyRITApiClient") def test_main_failed_scenario(self, mock_client_class, mock_probe): """Test main when scenario run fails.""" @@ -553,7 +597,10 @@ class TestBuildRunRequest: def test_includes_initializer_args(self): parsed = Namespace( target="t", - initializers=[{"name": "openai_target", "args": {"model": "gpt-4"}}, "datasets"], + initializers=[ + {"name": "openai_target", "args": {"model": "gpt-4"}}, + "datasets", + ], scenario_strategies=None, max_concurrency=None, max_retries=None, @@ -839,7 +886,11 @@ def test_main_no_args_prints_help_and_exits_zero(self, capsys): captured = capsys.readouterr() assert "PyRIT Scanner" in captured.out or "usage" in captured.out.lower() - @patch("pyrit.cli._server_launcher.ServerLauncher.probe_health_async", new_callable=AsyncMock, return_value=True) + @patch( + "pyrit.cli._server_launcher.ServerLauncher.probe_health_async", + new_callable=AsyncMock, + return_value=True, + ) @patch("pyrit.cli.api_client.PyRITApiClient") def test_main_scenario_not_found_lists_available(self, mock_client_class, _mock_probe, capsys): from pyrit.models.catalog import RegisteredScenario @@ -855,7 +906,6 @@ def test_main_scenario_not_found_lists_available(self, mock_client_class, _mock_ aggregate_strategies=[], all_strategies=[], default_datasets=[], - max_dataset_size=None, ), RegisteredScenario( scenario_name="alt_b", @@ -865,7 +915,6 @@ def test_main_scenario_not_found_lists_available(self, mock_client_class, _mock_ aggregate_strategies=[], all_strategies=[], default_datasets=[], - max_dataset_size=None, ), ] mock_client_class.return_value = mock_client @@ -876,7 +925,11 @@ def test_main_scenario_not_found_lists_available(self, mock_client_class, _mock_ assert "alt_a" in captured.out assert "alt_b" in captured.out - @patch("pyrit.cli._server_launcher.ServerLauncher.probe_health_async", new_callable=AsyncMock, return_value=True) + @patch( + "pyrit.cli._server_launcher.ServerLauncher.probe_health_async", + new_callable=AsyncMock, + return_value=True, + ) @patch("pyrit.cli.api_client.PyRITApiClient") def test_main_start_scenario_failure(self, mock_client_class, _mock_probe, capsys): mock_client = _mock_api_client() @@ -888,7 +941,11 @@ def test_main_start_scenario_failure(self, mock_client_class, _mock_probe, capsy captured = capsys.readouterr() assert "server full" in captured.out - @patch("pyrit.cli._server_launcher.ServerLauncher.probe_health_async", new_callable=AsyncMock, return_value=True) + @patch( + "pyrit.cli._server_launcher.ServerLauncher.probe_health_async", + new_callable=AsyncMock, + return_value=True, + ) @patch("pyrit.cli.api_client.PyRITApiClient") def test_main_run_results_failure_is_hard_error(self, mock_client_class, _mock_probe, capsys): mock_client = _mock_api_client() @@ -905,7 +962,11 @@ def test_main_run_results_failure_is_hard_error(self, mock_client_class, _mock_p # The summary printer should still be used as a fallback for context. assert "test_scenario" in captured.out - @patch("pyrit.cli._server_launcher.ServerLauncher.probe_health_async", new_callable=AsyncMock, return_value=True) + @patch( + "pyrit.cli._server_launcher.ServerLauncher.probe_health_async", + new_callable=AsyncMock, + return_value=True, + ) @patch("pyrit.cli.api_client.PyRITApiClient") def test_main_start_server_only_prints_url_and_returns_zero(self, mock_client_class, _mock_probe, capsys): result = pyrit_scan.main(["--start-server"]) @@ -936,7 +997,11 @@ def test_main_stop_server_when_process_cannot_be_identified(self, _stop_mock, _m out = capsys.readouterr().out assert "could not identify" in out - @patch("pyrit.cli._server_launcher.ServerLauncher.probe_health_async", new_callable=AsyncMock, return_value=True) + @patch( + "pyrit.cli._server_launcher.ServerLauncher.probe_health_async", + new_callable=AsyncMock, + return_value=True, + ) @patch("pyrit.cli.api_client.PyRITApiClient") def test_main_add_initializer_missing_file(self, mock_client_class, _mock_probe, capsys, tmp_path): mock_client = _mock_api_client() @@ -947,7 +1012,11 @@ def test_main_add_initializer_missing_file(self, mock_client_class, _mock_probe, assert result == 1 assert "File not found" in capsys.readouterr().out - @patch("pyrit.cli._server_launcher.ServerLauncher.probe_health_async", new_callable=AsyncMock, return_value=True) + @patch( + "pyrit.cli._server_launcher.ServerLauncher.probe_health_async", + new_callable=AsyncMock, + return_value=True, + ) @patch("pyrit.cli.api_client.PyRITApiClient") def test_main_add_initializer_success(self, mock_client_class, _mock_probe, capsys, tmp_path): mock_client = _mock_api_client() @@ -962,7 +1031,11 @@ def test_main_add_initializer_success(self, mock_client_class, _mock_probe, caps assert "Registered initializer 'myinit'" in capsys.readouterr().out mock_client.register_initializer_async.assert_awaited_once() - @patch("pyrit.cli._server_launcher.ServerLauncher.probe_health_async", new_callable=AsyncMock, return_value=True) + @patch( + "pyrit.cli._server_launcher.ServerLauncher.probe_health_async", + new_callable=AsyncMock, + return_value=True, + ) @patch("pyrit.cli.api_client.PyRITApiClient") def test_main_add_initializer_server_disabled(self, mock_client_class, _mock_probe, capsys, tmp_path): from pyrit.cli.api_client import ServerNotAvailableError @@ -1020,7 +1093,6 @@ def _build_mock_client(supported_params=None, status="COMPLETED"): aggregate_strategies=[], all_strategies=[], default_datasets=[], - max_dataset_size=None, ) ] client.get_scenario_async.return_value = RegisteredScenario( @@ -1031,7 +1103,6 @@ def _build_mock_client(supported_params=None, status="COMPLETED"): aggregate_strategies=[], all_strategies=[], default_datasets=[], - max_dataset_size=None, supported_parameters=typed_params, ) client.start_scenario_run_async.return_value = ScenarioRunSummary( @@ -1057,7 +1128,11 @@ def _build_mock_client(supported_params=None, status="COMPLETED"): client.__aexit__ = AsyncMock(return_value=None) return client - @patch("pyrit.cli._server_launcher.ServerLauncher.probe_health_async", new_callable=AsyncMock, return_value=True) + @patch( + "pyrit.cli._server_launcher.ServerLauncher.probe_health_async", + new_callable=AsyncMock, + return_value=True, + ) @patch("pyrit.cli.api_client.PyRITApiClient") @patch("pyrit.cli._output.print_scenario_result_async", new_callable=AsyncMock) @patch("pyrit.cli._output.print_scenario_run_progress") @@ -1071,7 +1146,11 @@ def test_scenario_declared_flag_is_forwarded(self, _mock_prog, _mock_print, mock sent_request = client.start_scenario_run_async.call_args.kwargs["request"] assert sent_request.scenario_params == {"max_turns": "7"} - @patch("pyrit.cli._server_launcher.ServerLauncher.probe_health_async", new_callable=AsyncMock, return_value=True) + @patch( + "pyrit.cli._server_launcher.ServerLauncher.probe_health_async", + new_callable=AsyncMock, + return_value=True, + ) @patch("pyrit.cli.api_client.PyRITApiClient") @patch("pyrit.cli._output.print_scenario_result_async", new_callable=AsyncMock) @patch("pyrit.cli._output.print_scenario_run_progress") @@ -1081,7 +1160,12 @@ def test_typed_scenario_flags_are_forwarded_as_typed_values( client = self._build_mock_client( supported_params=[ {"name": "dry_run", "description": "...", "param_type": "bool"}, - {"name": "sample_ids", "description": "...", "param_type": "list[int]", "is_list": True}, + { + "name": "sample_ids", + "description": "...", + "param_type": "list[int]", + "is_list": True, + }, ] ) mock_client_class.return_value = client @@ -1092,7 +1176,11 @@ def test_typed_scenario_flags_are_forwarded_as_typed_values( sent_request = client.start_scenario_run_async.call_args.kwargs["request"] assert sent_request.scenario_params == {"dry_run": True, "sample_ids": [1, 2]} - @patch("pyrit.cli._server_launcher.ServerLauncher.probe_health_async", new_callable=AsyncMock, return_value=True) + @patch( + "pyrit.cli._server_launcher.ServerLauncher.probe_health_async", + new_callable=AsyncMock, + return_value=True, + ) @patch("pyrit.cli.api_client.PyRITApiClient") @patch("pyrit.cli._output.print_scenario_result_async", new_callable=AsyncMock) @patch("pyrit.cli._output.print_scenario_run_progress") @@ -1105,7 +1193,11 @@ def test_unknown_flag_after_valid_scenario_errors(self, _mock_prog, _mock_print, assert result == 1 client.start_scenario_run_async.assert_not_called() - @patch("pyrit.cli._server_launcher.ServerLauncher.probe_health_async", new_callable=AsyncMock, return_value=True) + @patch( + "pyrit.cli._server_launcher.ServerLauncher.probe_health_async", + new_callable=AsyncMock, + return_value=True, + ) @patch("pyrit.cli.api_client.PyRITApiClient") @patch("pyrit.cli._output.print_scenario_result_async", new_callable=AsyncMock) @patch("pyrit.cli._output.print_scenario_run_progress") diff --git a/tests/unit/cli/test_pyrit_shell.py b/tests/unit/cli/test_pyrit_shell.py index 7e6450a40c..f13c237ec3 100644 --- a/tests/unit/cli/test_pyrit_shell.py +++ b/tests/unit/cli/test_pyrit_shell.py @@ -12,6 +12,7 @@ from pyrit.cli import pyrit_shell from pyrit.models import Parameter +from unit.mocks import make_scenario_result def _sp(*, name, description="", default=None, param_type="str", choices=None, is_list=False) -> Parameter: @@ -51,7 +52,6 @@ def mock_api_client(): aggregate_strategies=[], all_strategies=[], default_datasets=[], - max_dataset_size=None, supported_parameters=[], ) client.close_async = AsyncMock() @@ -66,7 +66,6 @@ def mock_api_client(): aggregate_strategies=kw.get("aggregate_strategies", []), all_strategies=kw.get("all_strategies", []), default_datasets=kw.get("default_datasets", []), - max_dataset_size=kw.get("max_dataset_size", None), supported_parameters=kw.get("supported_parameters", []), ) # Suppress unused-import warning for datetime/timezone helpers used by tests. @@ -278,7 +277,10 @@ def test_main_parses_server_url(self): mock_shell = MagicMock() mock_shell_class.return_value = mock_shell - with patch("sys.argv", ["pyrit_shell", "--server-url", "http://remote:9000", "--no-animation"]): + with patch( + "sys.argv", + ["pyrit_shell", "--server-url", "http://remote:9000", "--no-animation"], + ): pyrit_shell.main() mock_shell_class.assert_called_once() @@ -344,7 +346,10 @@ def test_explicit_server_url_wins(self): def test_falls_back_to_config_reader(self, tmp_path): s = pyrit_shell.PyRITShell(no_animation=True) - with patch("pyrit.cli._config_reader.read_server_url", return_value="http://from-cfg:8000"): + with patch( + "pyrit.cli._config_reader.read_server_url", + return_value="http://from-cfg:8000", + ): assert s._resolve_base_url() == "http://from-cfg:8000" def test_default_when_config_returns_none(self): @@ -364,7 +369,10 @@ def test_start_server_launches_when_not_running(self): new_callable=AsyncMock, return_value=False, ), - patch("pyrit.cli._server_launcher.ServerLauncher.start_async", new_callable=AsyncMock) as mock_start, + patch( + "pyrit.cli._server_launcher.ServerLauncher.start_async", + new_callable=AsyncMock, + ) as mock_start, patch("pyrit.cli.api_client.PyRITApiClient") as mock_client_class, ): mock_start.return_value = "http://localhost:8000" @@ -493,10 +501,10 @@ def _run_payload(status="COMPLETED"): @staticmethod def _empty_scenario_result(): """Build a minimal ScenarioResult for use as get_scenario_run_results_async return.""" - from pyrit.models import ScenarioIdentifier, ScenarioResult, ScenarioRunState + from pyrit.models import ScenarioRunState - return ScenarioResult( - scenario_identifier=ScenarioIdentifier(name="foo"), + return make_scenario_result( + scenario_name="foo", objective_target_identifier=None, objective_scorer_identifier=None, attack_results={}, @@ -659,11 +667,11 @@ def test_scenario_history_error(self, shell, capsys): class TestPrintScenarioAndHelp: def test_print_scenario_success(self, shell): - from pyrit.models import ScenarioIdentifier, ScenarioResult, ScenarioRunState + from pyrit.models import ScenarioRunState s, client = shell - empty_result = ScenarioResult( - scenario_identifier=ScenarioIdentifier(name="foo"), + empty_result = make_scenario_result( + scenario_name="foo", objective_target_identifier=None, objective_scorer_identifier=None, attack_results={}, @@ -719,7 +727,10 @@ def test_start_server_launch_success(self): new_callable=AsyncMock, return_value=False, ), - patch("pyrit.cli._server_launcher.ServerLauncher.start_async", new_callable=AsyncMock) as mock_start, + patch( + "pyrit.cli._server_launcher.ServerLauncher.start_async", + new_callable=AsyncMock, + ) as mock_start, patch("pyrit.cli.api_client.PyRITApiClient") as mock_client_class, ): mock_start.return_value = "http://localhost:8000" @@ -740,7 +751,10 @@ def test_start_server_launch_replaces_existing_client(self): new_callable=AsyncMock, return_value=False, ), - patch("pyrit.cli._server_launcher.ServerLauncher.start_async", new_callable=AsyncMock) as mock_start, + patch( + "pyrit.cli._server_launcher.ServerLauncher.start_async", + new_callable=AsyncMock, + ) as mock_start, patch("pyrit.cli.api_client.PyRITApiClient") as mock_client_class, ): mock_start.return_value = "http://localhost:8000" @@ -916,7 +930,10 @@ def test_shell_choices_rejected_before_request(self, shell, capsys): class TestSplitInitializerPaths: def test_posix_splits_on_whitespace(self): with patch.object(pyrit_shell.os, "name", "posix"): - assert pyrit_shell._split_initializer_paths("/a/one.py /b/two.py") == ["/a/one.py", "/b/two.py"] + assert pyrit_shell._split_initializer_paths("/a/one.py /b/two.py") == [ + "/a/one.py", + "/b/two.py", + ] def test_posix_respects_quotes_with_spaces(self): with patch.object(pyrit_shell.os, "name", "posix"): diff --git a/tests/unit/memory/memory_interface/test_batching_scale.py b/tests/unit/memory/memory_interface/test_batching_scale.py index 65c3805877..a1402d6bde 100644 --- a/tests/unit/memory/memory_interface/test_batching_scale.py +++ b/tests/unit/memory/memory_interface/test_batching_scale.py @@ -11,9 +11,11 @@ import uuid from unittest.mock import patch +from unit.mocks import make_scenario_result + from pyrit.memory import MemoryInterface from pyrit.memory.memory_models import PromptMemoryEntry -from pyrit.models import AttackResult, ComponentIdentifier, MessagePiece, ScenarioIdentifier, ScenarioResult, Score +from pyrit.models import AttackResult, ComponentIdentifier, MessagePiece, ScenarioResult, Score # Use the class attribute for the batch limit in tests _MAX_BIND_VARS = MemoryInterface._MAX_BIND_VARS @@ -58,13 +60,10 @@ def _create_scenario_result( attack_results: dict[str, list[AttackResult]] | None = None, ) -> ScenarioResult: """Create a sample scenario result for testing.""" - return ScenarioResult( - scenario_identifier=ScenarioIdentifier( - name=name, - description="test", - scenario_version=1, - init_data={}, - ), + return make_scenario_result( + scenario_name=name, + scenario_description="test", + scenario_version=1, objective_target_identifier=ComponentIdentifier(class_name="TestTarget", class_module="test"), attack_results=attack_results or {}, objective_scorer_identifier=ComponentIdentifier(class_name="TestScorer", class_module="test"), diff --git a/tests/unit/memory/memory_interface/test_interface_scenario_results.py b/tests/unit/memory/memory_interface/test_interface_scenario_results.py index f818e45ecd..d1dae1ac38 100644 --- a/tests/unit/memory/memory_interface/test_interface_scenario_results.py +++ b/tests/unit/memory/memory_interface/test_interface_scenario_results.py @@ -4,7 +4,7 @@ from datetime import datetime, timedelta, timezone import pytest -from unit.mocks import get_mock_scorer_identifier +from unit.mocks import get_mock_scorer_identifier, make_scenario_result from pyrit.memory import MemoryInterface from pyrit.models import ( @@ -13,8 +13,6 @@ ComponentIdentifier, IdentifierFilter, IdentifierType, - ScenarioIdentifier, - ScenarioResult, ) @@ -44,13 +42,6 @@ def create_scenario_result( attack_results: dict[str, list[AttackResult]] | None = None, ): """Helper function to create ScenarioResult.""" - scenario_identifier = ScenarioIdentifier( - name=name, - description=description, - scenario_version=version, - init_data={"test_key": "test_value"}, - ) - if attack_results is None: attack_results = {} @@ -59,8 +50,10 @@ def create_scenario_result( class_module="tests.unit.memory", ) - return ScenarioResult( - scenario_identifier=scenario_identifier, + return make_scenario_result( + scenario_name=name, + scenario_version=version, + scenario_description=description, objective_target_identifier=ComponentIdentifier(class_name="test_target", class_module="test"), attack_results=attack_results, objective_scorer_identifier=scorer_identifier, @@ -92,7 +85,7 @@ def test_add_and_retrieve_scenario_results(sqlite_instance: MemoryInterface, sam assert len(all_scenarios) == 2 # Verify the data was stored correctly - scenario_names = {scenario.scenario_identifier.name for scenario in all_scenarios} + scenario_names = {scenario.scenario_name for scenario in all_scenarios} assert scenario_names == {"Scenario 1", "Scenario 2"} @@ -112,7 +105,7 @@ def test_filter_by_name(sqlite_instance: MemoryInterface, sample_attack_results) # Query by name substring results = sqlite_instance.get_scenario_results(scenario_name="Test") assert len(results) == 1 - assert results[0].scenario_identifier.name == "Test Scenario Alpha" + assert results[0].scenario_name == "Test Scenario Alpha" def test_filter_by_version(sqlite_instance: MemoryInterface, sample_attack_results): @@ -133,7 +126,7 @@ def test_filter_by_version(sqlite_instance: MemoryInterface, sample_attack_resul # Query by version results = sqlite_instance.get_scenario_results(scenario_version=2) assert len(results) == 1 - assert results[0].scenario_identifier.version == 2 + assert results[0].scenario_version == 2 def test_filter_by_ids(sqlite_instance: MemoryInterface, sample_attack_results): @@ -148,7 +141,7 @@ def test_filter_by_ids(sqlite_instance: MemoryInterface, sample_attack_results): # Query by ID using the scenario result's id results = sqlite_instance.get_scenario_results(scenario_result_ids=[str(scenario_result1.id)]) assert len(results) == 1 - assert results[0].scenario_identifier.name == "Scenario 1" + assert results[0].scenario_name == "Scenario 1" assert results[0].id == scenario_result1.id @@ -165,7 +158,10 @@ def test_attack_results_populated_correctly(sqlite_instance: MemoryInterface): sid = scenario_result.id attack_result1 = _make_attack_result_for_scenario( - scenario_result_id=sid, atomic_attack_name="PromptInjection", objective_index=0, conversation_id="conv_1" + scenario_result_id=sid, + atomic_attack_name="PromptInjection", + objective_index=0, + conversation_id="conv_1", ) attack_result2 = _make_attack_result_for_scenario( scenario_result_id=sid, @@ -175,7 +171,10 @@ def test_attack_results_populated_correctly(sqlite_instance: MemoryInterface): outcome=AttackOutcome.FAILURE, ) attack_result3 = _make_attack_result_for_scenario( - scenario_result_id=sid, atomic_attack_name="Crescendo", objective_index=0, conversation_id="conv_3" + scenario_result_id=sid, + atomic_attack_name="Crescendo", + objective_index=0, + conversation_id="conv_3", ) sqlite_instance.add_attack_results_to_memory(attack_results=[attack_result1, attack_result2, attack_result3]) @@ -210,7 +209,10 @@ def test_attack_order_preserved(sqlite_instance: MemoryInterface): # Insert in a specific order; hydration must surface them in the same order. attack_results = [ _make_attack_result_for_scenario( - scenario_result_id=sid, atomic_attack_name="Attack1", objective_index=i, conversation_id=f"conv_{i}" + scenario_result_id=sid, + atomic_attack_name="Attack1", + objective_index=i, + conversation_id=f"conv_{i}", ) for i in range(5) ] @@ -265,22 +267,20 @@ def test_preserves_metadata(sqlite_instance: MemoryInterface): """Test that scenario metadata is preserved correctly.""" # Create scenario result with metadata - scenario_identifier = ScenarioIdentifier( - name="Metadata Test Scenario", - description="A test scenario with metadata", - scenario_version=3, - init_data={"param1": "value1", "param2": 42}, - ) - scorer_identifier = ComponentIdentifier( class_name="TestScorer", class_module="test.module", ) - scenario_result = ScenarioResult( - scenario_identifier=scenario_identifier, + scenario_result = make_scenario_result( + scenario_name="Metadata Test Scenario", + scenario_version=3, + scenario_description="A test scenario with metadata", + params={"param1": "value1", "param2": 42}, objective_target_identifier=ComponentIdentifier( - class_name="test_target", class_module="test", params={"endpoint": "https://example.com"} + class_name="test_target", + class_module="test", + params={"endpoint": "https://example.com"}, ), attack_results={}, objective_scorer_identifier=scorer_identifier, @@ -292,10 +292,11 @@ def test_preserves_metadata(sqlite_instance: MemoryInterface): assert len(results) == 1 retrieved = results[0] - assert retrieved.scenario_identifier.name == "Metadata Test Scenario" - assert retrieved.scenario_identifier.description == "A test scenario with metadata" - assert retrieved.scenario_identifier.version == 3 - assert retrieved.scenario_identifier.init_data == {"param1": "value1", "param2": 42} + assert retrieved.scenario_name == "Metadata Test Scenario" + assert retrieved.scenario_description == "A test scenario with metadata" + assert retrieved.scenario_version == 3 + assert retrieved.scenario_identifier.params["param1"] == "value1" + assert retrieved.scenario_identifier.params["param2"] == 42 assert retrieved.objective_target_identifier.params["endpoint"] == "https://example.com" # objective_scorer_identifier is now a ComponentIdentifier, check its properties assert retrieved.objective_scorer_identifier.class_name == "TestScorer" @@ -333,9 +334,9 @@ def test_multiple_scenarios_with_attacks(sqlite_instance: MemoryInterface): # Verify each scenario has the correct attack results for result in results: - if result.scenario_identifier.name == "Scenario 1": + if result.scenario_name == "Scenario 1": assert len(result.attack_results["Attack1"]) == 5 - elif result.scenario_identifier.name == "Scenario 2": + elif result.scenario_name == "Scenario 2": assert len(result.attack_results["Attack2"]) == 3 @@ -358,16 +359,16 @@ def test_filter_by_name_and_version(sqlite_instance: MemoryInterface): # Query with both filters results = sqlite_instance.get_scenario_results(scenario_name="Test", scenario_version=2) assert len(results) == 1 - assert results[0].scenario_identifier.name == "Test Scenario" - assert results[0].scenario_identifier.version == 2 + assert results[0].scenario_name == "Test Scenario" + assert results[0].scenario_version == 2 def test_filter_by_labels(sqlite_instance: MemoryInterface, sample_attack_results): """Test scenario results with labels.""" # Create scenario with labels - scenario_identifier = ScenarioIdentifier(name="Labeled Scenario", scenario_version=1) - scenario_result = ScenarioResult( - scenario_identifier=scenario_identifier, + scenario_result = make_scenario_result( + scenario_name="Labeled Scenario", + scenario_version=1, objective_target_identifier=ComponentIdentifier(class_name="test_target", class_module="test"), attack_results={"Attack1": [sample_attack_results[0]]}, labels={"environment": "testing", "team": "red-team"}, @@ -389,18 +390,18 @@ def test_filter_by_multiple_labels(sqlite_instance: MemoryInterface): sqlite_instance.add_attack_results_to_memory(attack_results=[attack_result1, attack_result2]) # Create scenarios with different labels - scenario1_identifier = ScenarioIdentifier(name="Scenario 1", scenario_version=1) - scenario1 = ScenarioResult( - scenario_identifier=scenario1_identifier, + scenario1 = make_scenario_result( + scenario_name="Scenario 1", + scenario_version=1, objective_target_identifier=ComponentIdentifier(class_name="test_target", class_module="test"), attack_results={"Attack1": [attack_result1]}, labels={"environment": "testing", "team": "red-team"}, objective_scorer_identifier=get_mock_scorer_identifier(), ) - scenario2_identifier = ScenarioIdentifier(name="Scenario 2", scenario_version=1) - scenario2 = ScenarioResult( - scenario_identifier=scenario2_identifier, + scenario2 = make_scenario_result( + scenario_name="Scenario 2", + scenario_version=1, objective_target_identifier=ComponentIdentifier(class_name="test_target", class_module="test"), attack_results={"Attack2": [attack_result2]}, labels={"environment": "production", "team": "red-team"}, @@ -411,7 +412,7 @@ def test_filter_by_multiple_labels(sqlite_instance: MemoryInterface): # Query requiring both labels to match results = sqlite_instance.get_scenario_results(labels={"environment": "testing", "team": "red-team"}) assert len(results) == 1 - assert results[0].scenario_identifier.name == "Scenario 1" + assert results[0].scenario_name == "Scenario 1" def test_filter_by_completion_time(sqlite_instance: MemoryInterface): @@ -427,27 +428,27 @@ def test_filter_by_completion_time(sqlite_instance: MemoryInterface): yesterday = now - timedelta(days=1) last_week = now - timedelta(days=7) - scenario1_identifier = ScenarioIdentifier(name="Recent Scenario", scenario_version=1) - scenario1 = ScenarioResult( - scenario_identifier=scenario1_identifier, + scenario1 = make_scenario_result( + scenario_name="Recent Scenario", + scenario_version=1, objective_target_identifier=ComponentIdentifier(class_name="test_target", class_module="test"), attack_results={"Attack1": [attack_result1]}, completion_time=now, objective_scorer_identifier=get_mock_scorer_identifier(), ) - scenario2_identifier = ScenarioIdentifier(name="Yesterday Scenario", scenario_version=1) - scenario2 = ScenarioResult( - scenario_identifier=scenario2_identifier, + scenario2 = make_scenario_result( + scenario_name="Yesterday Scenario", + scenario_version=1, objective_target_identifier=ComponentIdentifier(class_name="test_target", class_module="test"), attack_results={"Attack2": [attack_result2]}, completion_time=yesterday, objective_scorer_identifier=get_mock_scorer_identifier(), ) - scenario3_identifier = ScenarioIdentifier(name="Old Scenario", scenario_version=1) - scenario3 = ScenarioResult( - scenario_identifier=scenario3_identifier, + scenario3 = make_scenario_result( + scenario_name="Old Scenario", + scenario_version=1, objective_target_identifier=ComponentIdentifier(class_name="test_target", class_module="test"), attack_results={"Attack3": [attack_result3]}, completion_time=last_week, @@ -458,14 +459,14 @@ def test_filter_by_completion_time(sqlite_instance: MemoryInterface): # Query scenarios after yesterday results = sqlite_instance.get_scenario_results(added_after=yesterday) assert len(results) == 2 - result_names = {r.scenario_identifier.name for r in results} + result_names = {r.scenario_name for r in results} assert "Recent Scenario" in result_names assert "Yesterday Scenario" in result_names # Query scenarios before yesterday results = sqlite_instance.get_scenario_results(added_before=yesterday) assert len(results) == 2 - result_names = {r.scenario_identifier.name for r in results} + result_names = {r.scenario_name for r in results} assert "Yesterday Scenario" in result_names assert "Old Scenario" in result_names @@ -478,17 +479,19 @@ def test_filter_by_pyrit_version(sqlite_instance: MemoryInterface): sqlite_instance.add_attack_results_to_memory(attack_results=[attack_result1, attack_result2]) # Create scenarios with different PyRIT versions - scenario1_identifier = ScenarioIdentifier(name="Old Version Scenario", scenario_version=1, pyrit_version="0.4.0") - scenario1 = ScenarioResult( - scenario_identifier=scenario1_identifier, + scenario1 = make_scenario_result( + scenario_name="Old Version Scenario", + scenario_version=1, + pyrit_version="0.4.0", objective_target_identifier=ComponentIdentifier(class_name="test_target", class_module="test"), attack_results={"Attack1": [attack_result1]}, objective_scorer_identifier=get_mock_scorer_identifier(), ) - scenario2_identifier = ScenarioIdentifier(name="New Version Scenario", scenario_version=1, pyrit_version="0.5.0") - scenario2 = ScenarioResult( - scenario_identifier=scenario2_identifier, + scenario2 = make_scenario_result( + scenario_name="New Version Scenario", + scenario_version=1, + pyrit_version="0.5.0", objective_target_identifier=ComponentIdentifier(class_name="test_target", class_module="test"), attack_results={"Attack2": [attack_result2]}, objective_scorer_identifier=get_mock_scorer_identifier(), @@ -498,8 +501,8 @@ def test_filter_by_pyrit_version(sqlite_instance: MemoryInterface): # Query by PyRIT version results = sqlite_instance.get_scenario_results(pyrit_version="0.5.0") assert len(results) == 1 - assert results[0].scenario_identifier.name == "New Version Scenario" - assert results[0].scenario_identifier.pyrit_version == "0.5.0" + assert results[0].scenario_name == "New Version Scenario" + assert results[0].pyrit_version == "0.5.0" def test_filter_by_target_endpoint(sqlite_instance: MemoryInterface): @@ -511,29 +514,33 @@ def test_filter_by_target_endpoint(sqlite_instance: MemoryInterface): sqlite_instance.add_attack_results_to_memory(attack_results=[attack_result1, attack_result2, attack_result3]) # Create scenarios with different target endpoints - scenario1_identifier = ScenarioIdentifier(name="Azure Scenario", scenario_version=1) - scenario1 = ScenarioResult( - scenario_identifier=scenario1_identifier, + scenario1 = make_scenario_result( + scenario_name="Azure Scenario", + scenario_version=1, objective_target_identifier=ComponentIdentifier( - class_name="OpenAI", class_module="test", params={"endpoint": "https://myresource.openai.azure.com"} + class_name="OpenAI", + class_module="test", + params={"endpoint": "https://myresource.openai.azure.com"}, ), attack_results={"Attack1": [attack_result1]}, objective_scorer_identifier=get_mock_scorer_identifier(), ) - scenario2_identifier = ScenarioIdentifier(name="OpenAI Scenario", scenario_version=1) - scenario2 = ScenarioResult( - scenario_identifier=scenario2_identifier, + scenario2 = make_scenario_result( + scenario_name="OpenAI Scenario", + scenario_version=1, objective_target_identifier=ComponentIdentifier( - class_name="OpenAI", class_module="test", params={"endpoint": "https://api.openai.com/v1"} + class_name="OpenAI", + class_module="test", + params={"endpoint": "https://api.openai.com/v1"}, ), attack_results={"Attack2": [attack_result2]}, objective_scorer_identifier=get_mock_scorer_identifier(), ) - scenario3_identifier = ScenarioIdentifier(name="No Endpoint Scenario", scenario_version=1) - scenario3 = ScenarioResult( - scenario_identifier=scenario3_identifier, + scenario3 = make_scenario_result( + scenario_name="No Endpoint Scenario", + scenario_version=1, objective_target_identifier=ComponentIdentifier(class_name="Local", class_module="test"), attack_results={"Attack3": [attack_result3]}, objective_scorer_identifier=get_mock_scorer_identifier(), @@ -543,12 +550,12 @@ def test_filter_by_target_endpoint(sqlite_instance: MemoryInterface): # Query by endpoint (case-insensitive substring match) results = sqlite_instance.get_scenario_results(objective_target_endpoint="azure") assert len(results) == 1 - assert results[0].scenario_identifier.name == "Azure Scenario" + assert results[0].scenario_name == "Azure Scenario" # Query for OpenAI endpoints results = sqlite_instance.get_scenario_results(objective_target_endpoint="openai") assert len(results) == 2 - result_names = {r.scenario_identifier.name for r in results} + result_names = {r.scenario_name for r in results} assert "Azure Scenario" in result_names assert "OpenAI Scenario" in result_names @@ -562,19 +569,21 @@ def test_filter_by_target_model_name(sqlite_instance: MemoryInterface): sqlite_instance.add_attack_results_to_memory(attack_results=[attack_result1, attack_result2, attack_result3]) # Create scenarios with different model names - scenario1_identifier = ScenarioIdentifier(name="GPT-4 Scenario", scenario_version=1) - scenario1 = ScenarioResult( - scenario_identifier=scenario1_identifier, + scenario1 = make_scenario_result( + scenario_name="GPT-4 Scenario", + scenario_version=1, objective_target_identifier=ComponentIdentifier( - class_name="OpenAI", class_module="test", params={"model_name": "gpt-4-0613"} + class_name="OpenAI", + class_module="test", + params={"model_name": "gpt-4-0613"}, ), attack_results={"Attack1": [attack_result1]}, objective_scorer_identifier=get_mock_scorer_identifier(), ) - scenario2_identifier = ScenarioIdentifier(name="GPT-4o Scenario", scenario_version=1) - scenario2 = ScenarioResult( - scenario_identifier=scenario2_identifier, + scenario2 = make_scenario_result( + scenario_name="GPT-4o Scenario", + scenario_version=1, objective_target_identifier=ComponentIdentifier( class_name="OpenAI", class_module="test", params={"model_name": "gpt-4o"} ), @@ -582,11 +591,13 @@ def test_filter_by_target_model_name(sqlite_instance: MemoryInterface): objective_scorer_identifier=get_mock_scorer_identifier(), ) - scenario3_identifier = ScenarioIdentifier(name="GPT-3.5 Scenario", scenario_version=1) - scenario3 = ScenarioResult( - scenario_identifier=scenario3_identifier, + scenario3 = make_scenario_result( + scenario_name="GPT-3.5 Scenario", + scenario_version=1, objective_target_identifier=ComponentIdentifier( - class_name="OpenAI", class_module="test", params={"model_name": "gpt-3.5-turbo"} + class_name="OpenAI", + class_module="test", + params={"model_name": "gpt-3.5-turbo"}, ), attack_results={"Attack3": [attack_result3]}, objective_scorer_identifier=get_mock_scorer_identifier(), @@ -596,14 +607,14 @@ def test_filter_by_target_model_name(sqlite_instance: MemoryInterface): # Query by model name (case-insensitive substring match) results = sqlite_instance.get_scenario_results(objective_target_model_name="gpt-4") assert len(results) == 2 - result_names = {r.scenario_identifier.name for r in results} + result_names = {r.scenario_name for r in results} assert "GPT-4 Scenario" in result_names assert "GPT-4o Scenario" in result_names # Query for GPT-3.5 results = sqlite_instance.get_scenario_results(objective_target_model_name="3.5") assert len(results) == 1 - assert results[0].scenario_identifier.name == "GPT-3.5 Scenario" + assert results[0].scenario_name == "GPT-3.5 Scenario" def test_combined_filters(sqlite_instance: MemoryInterface): @@ -617,9 +628,10 @@ def test_combined_filters(sqlite_instance: MemoryInterface): now = datetime.now(timezone.utc) yesterday = now - timedelta(days=1) - scenario1_identifier = ScenarioIdentifier(name="Test Scenario", scenario_version=1, pyrit_version="0.5.0") - scenario1 = ScenarioResult( - scenario_identifier=scenario1_identifier, + scenario1 = make_scenario_result( + scenario_name="Test Scenario", + scenario_version=1, + pyrit_version="0.5.0", objective_target_identifier=ComponentIdentifier( class_name="OpenAI", class_module="test", @@ -631,9 +643,10 @@ def test_combined_filters(sqlite_instance: MemoryInterface): objective_scorer_identifier=get_mock_scorer_identifier(), ) - scenario2_identifier = ScenarioIdentifier(name="Test Scenario", scenario_version=1, pyrit_version="0.4.0") - scenario2 = ScenarioResult( - scenario_identifier=scenario2_identifier, + scenario2 = make_scenario_result( + scenario_name="Test Scenario", + scenario_version=1, + pyrit_version="0.4.0", objective_target_identifier=ComponentIdentifier( class_name="Azure", class_module="test", @@ -654,7 +667,7 @@ def test_combined_filters(sqlite_instance: MemoryInterface): labels={"environment": "testing"}, ) assert len(results) == 1 - assert results[0].scenario_identifier.pyrit_version == "0.5.0" + assert results[0].pyrit_version == "0.5.0" assert "gpt-4" in results[0].objective_target_identifier.params["model_name"] @@ -684,7 +697,9 @@ def _make_attack_result_for_scenario( ) -def test_get_scenario_results_loads_attack_results_via_foreign_key(sqlite_instance: MemoryInterface): +def test_get_scenario_results_loads_attack_results_via_foreign_key( + sqlite_instance: MemoryInterface, +): """When AttackResultEntry rows carry the attribution_parent_id foreign key, hydration picks them up directly — without needing the legacy attack_results_json manifest. This is the path that makes mid-AtomicAttack @@ -710,7 +725,9 @@ def test_get_scenario_results_loads_attack_results_via_foreign_key(sqlite_instan assert [r.conversation_id for r in result.attack_results["b"]] == ["conv-b-0"] -def test_get_attack_results_filters_by_scenario_result_id(sqlite_instance: MemoryInterface): +def test_get_attack_results_filters_by_scenario_result_id( + sqlite_instance: MemoryInterface, +): """get_attack_results gains a scenario_result_id filter — replaces the removed error_attack_result_ids_json lookup path.""" scenario_result = create_scenario_result(name="Filter Scenario") @@ -729,7 +746,10 @@ def test_get_attack_results_filters_by_scenario_result_id(sqlite_instance: Memor sqlite_instance.add_attack_results_to_memory(attack_results=[ok, err, unrelated]) all_for_scenario = sqlite_instance.get_attack_results(scenario_result_id=str(sid)) - assert {r.conversation_id for r in all_for_scenario} == {ok.conversation_id, err.conversation_id} + assert {r.conversation_id for r in all_for_scenario} == { + ok.conversation_id, + err.conversation_id, + } only_errors = sqlite_instance.get_attack_results( scenario_result_id=str(sid), @@ -738,7 +758,9 @@ def test_get_attack_results_filters_by_scenario_result_id(sqlite_instance: Memor assert [r.conversation_id for r in only_errors] == [err.conversation_id] -def test_delete_scenario_sets_attack_result_foreign_key_to_null(sqlite_instance: MemoryInterface): +def test_delete_scenario_sets_attack_result_foreign_key_to_null( + sqlite_instance: MemoryInterface, +): """ON DELETE SET NULL: deleting the parent ScenarioResultEntry nulls the attribution_parent_id foreign key on its linked AttackResultEntries but the AttackResultEntries survive (attribution_data is retained as @@ -776,7 +798,9 @@ def test_delete_scenario_sets_attack_result_foreign_key_to_null(sqlite_instance: assert entry.attribution_data == {"parent_collection": "a"} -def test_update_scenario_run_state_targeted_update_preserves_manifest(sqlite_instance: MemoryInterface): +def test_update_scenario_run_state_targeted_update_preserves_manifest( + sqlite_instance: MemoryInterface, +): """update_scenario_run_state must be a targeted UPDATE — it must not re-serialize the whole row and clobber the manifest column during the deprecation window.""" @@ -801,7 +825,9 @@ def test_update_scenario_run_state_targeted_update_preserves_manifest(sqlite_ins assert hydrated.error_type == "RuntimeError" -def test_get_scenario_results_by_target_identifier_filter_hash(sqlite_instance: MemoryInterface): +def test_get_scenario_results_by_target_identifier_filter_hash( + sqlite_instance: MemoryInterface, +): """Test filtering scenario results by identifier filter.""" target_id_1 = ComponentIdentifier( class_name="OpenAI", @@ -818,14 +844,16 @@ def test_get_scenario_results_by_target_identifier_filter_hash(sqlite_instance: attack_result2 = create_attack_result("conv_2", "Objective 2") sqlite_instance.add_attack_results_to_memory(attack_results=[attack_result1, attack_result2]) - scenario1 = ScenarioResult( - scenario_identifier=ScenarioIdentifier(name="Scenario OpenAI", scenario_version=1), + scenario1 = make_scenario_result( + scenario_name="Scenario OpenAI", + scenario_version=1, objective_target_identifier=target_id_1, attack_results={"Attack1": [attack_result1]}, objective_scorer_identifier=get_mock_scorer_identifier(), ) - scenario2 = ScenarioResult( - scenario_identifier=ScenarioIdentifier(name="Scenario Azure", scenario_version=1), + scenario2 = make_scenario_result( + scenario_name="Scenario Azure", + scenario_version=1, objective_target_identifier=target_id_2, attack_results={"Attack2": [attack_result2]}, objective_scorer_identifier=get_mock_scorer_identifier(), @@ -844,10 +872,12 @@ def test_get_scenario_results_by_target_identifier_filter_hash(sqlite_instance: ], ) assert len(results) == 1 - assert results[0].scenario_identifier.name == "Scenario OpenAI" + assert results[0].scenario_name == "Scenario OpenAI" -def test_get_scenario_results_by_target_identifier_filter_endpoint(sqlite_instance: MemoryInterface): +def test_get_scenario_results_by_target_identifier_filter_endpoint( + sqlite_instance: MemoryInterface, +): """Test filtering scenario results by identifier filter with endpoint.""" target_id_1 = ComponentIdentifier( class_name="OpenAI", @@ -864,14 +894,16 @@ def test_get_scenario_results_by_target_identifier_filter_endpoint(sqlite_instan attack_result2 = create_attack_result("conv_2", "Objective 2") sqlite_instance.add_attack_results_to_memory(attack_results=[attack_result1, attack_result2]) - scenario1 = ScenarioResult( - scenario_identifier=ScenarioIdentifier(name="Scenario OpenAI", scenario_version=1), + scenario1 = make_scenario_result( + scenario_name="Scenario OpenAI", + scenario_version=1, objective_target_identifier=target_id_1, attack_results={"Attack1": [attack_result1]}, objective_scorer_identifier=get_mock_scorer_identifier(), ) - scenario2 = ScenarioResult( - scenario_identifier=ScenarioIdentifier(name="Scenario Azure", scenario_version=1), + scenario2 = make_scenario_result( + scenario_name="Scenario Azure", + scenario_version=1, objective_target_identifier=target_id_2, attack_results={"Attack2": [attack_result2]}, objective_scorer_identifier=get_mock_scorer_identifier(), @@ -890,16 +922,19 @@ def test_get_scenario_results_by_target_identifier_filter_endpoint(sqlite_instan ], ) assert len(results) == 1 - assert results[0].scenario_identifier.name == "Scenario OpenAI" + assert results[0].scenario_name == "Scenario OpenAI" -def test_get_scenario_results_by_target_identifier_filter_no_match(sqlite_instance: MemoryInterface): +def test_get_scenario_results_by_target_identifier_filter_no_match( + sqlite_instance: MemoryInterface, +): """Test that TargetIdentifierFilter returns empty when nothing matches.""" attack_result1 = create_attack_result("conv_1", "Objective 1") sqlite_instance.add_attack_results_to_memory(attack_results=[attack_result1]) - scenario1 = ScenarioResult( - scenario_identifier=ScenarioIdentifier(name="Test Scenario", scenario_version=1), + scenario1 = make_scenario_result( + scenario_name="Test Scenario", + scenario_version=1, objective_target_identifier=ComponentIdentifier( class_name="OpenAI", class_module="test", diff --git a/tests/unit/memory/test_memory_models.py b/tests/unit/memory/test_memory_models.py index a2b244fc1f..7cf45abbe5 100644 --- a/tests/unit/memory/test_memory_models.py +++ b/tests/unit/memory/test_memory_models.py @@ -28,13 +28,13 @@ ConversationReference, ConversationType, MessagePiece, - ScenarioIdentifier, ScenarioResult, Score, SeedObjective, SeedPrompt, SeedSimulatedConversation, ) +from unit.mocks import make_scenario_result # --------------------------------------------------------------------------- # Helpers @@ -397,7 +397,10 @@ def test_roundtrip_seed_objective_strips_reserved_key(self): name="obj1", dataset_name="ds", added_by="tester", - metadata={SEED_RESPONSE_JSON_SCHEMA_METADATA_KEY: "sneaky", "owned": "by-caller"}, + metadata={ + SEED_RESPONSE_JSON_SCHEMA_METADATA_KEY: "sneaky", + "owned": "by-caller", + }, ) entry = SeedEntry(entry=obj) assert SEED_RESPONSE_JSON_SCHEMA_METADATA_KEY not in entry.prompt_metadata @@ -414,7 +417,10 @@ def test_roundtrip_seed_simulated_conversation_strips_reserved_key(self): num_turns=3, adversarial_chat_system_prompt_path="/path/to/adversarial.yaml", simulated_target_system_prompt_path="/path/to/target.yaml", - metadata={SEED_RESPONSE_JSON_SCHEMA_METADATA_KEY: "sneaky", "owned": "by-caller"}, + metadata={ + SEED_RESPONSE_JSON_SCHEMA_METADATA_KEY: "sneaky", + "owned": "by-caller", + }, ) entry = SeedEntry(entry=config) assert SEED_RESPONSE_JSON_SCHEMA_METADATA_KEY not in entry.prompt_metadata @@ -546,7 +552,8 @@ def test_get_attack_result_prefers_atomic_over_stale_attack_identifier(self): class TestScenarioResultEntry: def _make_scenario_result(self, **overrides) -> ScenarioResult: defaults = { - "scenario_identifier": ScenarioIdentifier(name="test_scenario", description="desc"), + "scenario_name": "test_scenario", + "scenario_description": "desc", "objective_target_identifier": ComponentIdentifier(class_name="MockTarget", class_module="tests.mocks"), "attack_results": {}, "objective_scorer_identifier": ComponentIdentifier(class_name="MockScorer", class_module="pyrit.score"), @@ -556,7 +563,7 @@ def _make_scenario_result(self, **overrides) -> ScenarioResult: "completion_time": datetime.now(tz=timezone.utc), } defaults.update(overrides) - return ScenarioResult(**defaults) + return make_scenario_result(**defaults) def test_init_from_scenario_result(self): sr = self._make_scenario_result() @@ -570,7 +577,7 @@ def test_roundtrip_get_scenario_result(self): sr = self._make_scenario_result() entry = ScenarioResultEntry(entry=sr) recovered = entry.get_scenario_result() - assert recovered.scenario_identifier.name == "test_scenario" + assert recovered.scenario_name == "test_scenario" assert recovered.scenario_run_state == "COMPLETED" # attack_results should be empty after roundtrip (populated by memory_interface) assert recovered.attack_results == {} diff --git a/tests/unit/memory/test_migration.py b/tests/unit/memory/test_migration.py index 2054621438..f927ee6e32 100644 --- a/tests/unit/memory/test_migration.py +++ b/tests/unit/memory/test_migration.py @@ -253,9 +253,10 @@ def _seed_pre_migration_scenario(connection, *, scenario_id, manifest_json): text( 'INSERT INTO "ScenarioResultEntries" ' "(id, scenario_name, scenario_description, scenario_version, pyrit_version, " - "objective_target_identifier, scenario_run_state, attack_results_json, " + "objective_target_identifier, scenario_init_data, scenario_run_state, attack_results_json, " "number_tries, completion_time, timestamp) " - "VALUES (:id, :name, '', 1, '0.14.0.dev0', '{}', 'COMPLETED', :manifest, 0, '2026-05-18', '2026-05-18')" + "VALUES (:id, :name, '', 1, '0.14.0.dev0', '{}', '{}', 'COMPLETED', :manifest, 0, " + "'2026-05-18', '2026-05-18')" ), {"id": scenario_id, "name": "Backfill Test", "manifest": manifest_json}, ) @@ -419,9 +420,9 @@ def test_backfill_is_idempotent_and_does_not_clobber_existing_linkage(): text( 'INSERT INTO "ScenarioResultEntries" ' "(id, scenario_name, scenario_description, scenario_version, pyrit_version, " - "objective_target_identifier, scenario_run_state, attack_results_json, " + "objective_target_identifier, scenario_init_data, scenario_run_state, attack_results_json, " "number_tries, completion_time, timestamp) " - "VALUES (:id, 'Other', '', 1, '0.14.0.dev0', '{}', 'COMPLETED', :manifest, 0, " + "VALUES (:id, 'Other', '', 1, '0.14.0.dev0', '{}', '{}', 'COMPLETED', :manifest, 0, " "'2026-05-18', '2026-05-18')" ), {"id": str(uuid.uuid4()), "manifest": json.dumps({"x": ["conv-shared"]})}, diff --git a/tests/unit/mocks.py b/tests/unit/mocks.py index 2616bd1fd5..8e64fd9a51 100644 --- a/tests/unit/mocks.py +++ b/tests/unit/mocks.py @@ -7,13 +7,88 @@ import uuid from collections.abc import Generator, MutableSequence, Sequence from contextlib import AbstractAsyncContextManager +from typing import Any from unittest.mock import MagicMock, patch from pyrit.memory import AzureSQLMemory, CentralMemory, PromptMemoryEntry -from pyrit.models import ComponentIdentifier, Message, MessagePiece +from pyrit.models import ( + ComponentIdentifier, + Message, + MessagePiece, + ScenarioIdentifier, + ScenarioResult, +) from pyrit.prompt_target import PromptTarget, TargetCapabilities, TargetConfiguration, limit_requests_per_minute +def make_scenario_identifier( + *, + scenario_name: str = "TestScenario", + scenario_module: str = "tests.unit.mocks", + version: int = 1, + objective_target: ComponentIdentifier | None = None, + objective_scorer: ComponentIdentifier | None = None, + techniques: list[str] | None = None, + datasets: list[str] | None = None, + params: dict[str, Any] | None = None, + pyrit_version: str | None = None, +) -> ScenarioIdentifier: + """ + Build a ``ScenarioIdentifier`` for tests. + + Mirrors what ``Scenario._build_scenario_identifier`` produces so tests can + construct a ``ScenarioResult`` without a live scenario. + """ + extra: dict[str, Any] = {} + if pyrit_version is not None: + extra["pyrit_version"] = pyrit_version + return ScenarioIdentifier( + class_name=scenario_name, + class_module=scenario_module, + version=version, + techniques=techniques, + datasets=datasets, + params=dict(params) if params else {}, + objective_target=objective_target, + objective_scorer=objective_scorer, + **extra, + ) + + +def make_scenario_result( + *, + scenario_name: str = "TestScenario", + scenario_version: int = 1, + objective_target_identifier: ComponentIdentifier | None = None, + objective_scorer_identifier: ComponentIdentifier | None = None, + techniques: list[str] | None = None, + datasets: list[str] | None = None, + params: dict[str, Any] | None = None, + pyrit_version: str | None = None, + **kwargs: Any, +) -> ScenarioResult: + """ + Build a ``ScenarioResult`` for tests from flat identity kwargs. + + The identity kwargs (``scenario_name`` / ``scenario_version`` / + ``objective_target_identifier`` / ``objective_scorer_identifier`` / + ``techniques`` / ``datasets`` / ``params`` / ``pyrit_version``) are folded + into a ``ScenarioIdentifier``; all other kwargs pass through to + ``ScenarioResult``. + """ + identifier = make_scenario_identifier( + scenario_name=scenario_name, + version=scenario_version, + objective_target=objective_target_identifier, + objective_scorer=objective_scorer_identifier, + techniques=techniques, + datasets=datasets, + params=params, + pyrit_version=pyrit_version, + ) + return ScenarioResult(scenario_identifier=identifier, **kwargs) + + def get_mock_scorer_identifier() -> ComponentIdentifier: """ Returns a mock ComponentIdentifier for use in tests where the specific diff --git a/tests/unit/models/test_scenario_result.py b/tests/unit/models/test_scenario_result.py index ab2da79e8c..d35424b16f 100644 --- a/tests/unit/models/test_scenario_result.py +++ b/tests/unit/models/test_scenario_result.py @@ -6,22 +6,15 @@ import pytest -import pyrit from pyrit.models import ( ComponentIdentifier, ConversationReference, ConversationType, - ScenarioIdentifier, ScenarioResult, ) from pyrit.models.results.attack_result import AttackOutcome, AttackResult from pyrit.models.retry_event import RetryEvent - - -def _make_scenario_identifier(**kwargs): - defaults = {"name": "TestScenario", "description": "A test", "scenario_version": 1} - defaults.update(kwargs) - return ScenarioIdentifier(**defaults) +from tests.unit.mocks import make_scenario_result def _make_component_identifier_dict(class_name="TestTarget"): @@ -36,53 +29,27 @@ def _make_attack_result(*, objective="test objective", outcome=AttackOutcome.SUC ) -class TestScenarioIdentifier: - def test_init_basic(self): - si = ScenarioIdentifier(name="MySc") - assert si.name == "MySc" - assert si.description == "" - assert si.version == 1 - assert si.init_data is None - - def test_init_with_all_params(self): - si = ScenarioIdentifier( - name="MySc", - description="desc", - scenario_version=2, - init_data={"key": "val"}, - pyrit_version="1.0.0", - ) - assert si.version == 2 - assert si.init_data == {"key": "val"} - assert si.pyrit_version == "1.0.0" - - def test_init_default_pyrit_version(self): - si = ScenarioIdentifier(name="X") - assert si.pyrit_version == pyrit.__version__ - - class TestScenarioResult: def test_init_basic(self): - si = _make_scenario_identifier() target_id = _make_component_identifier_dict() scorer_id = _make_component_identifier_dict("TestScorer") - result = ScenarioResult( - scenario_identifier=si, + result = make_scenario_result( + scenario_name="TestScenario", objective_target_identifier=target_id, attack_results={"strat1": []}, objective_scorer_identifier=scorer_id, ) - assert result.scenario_identifier is si + assert result.scenario_name == "TestScenario" + assert result.scenario_version == 1 assert result.scenario_run_state == "CREATED" assert result.labels == {} assert result.number_tries == 0 assert isinstance(result.id, uuid.UUID) def test_init_with_explicit_id(self): - si = _make_scenario_identifier() explicit_id = uuid.uuid4() - result = ScenarioResult( - scenario_identifier=si, + result = make_scenario_result( + scenario_name="TestScenario", objective_target_identifier=ComponentIdentifier.from_dict({}), attack_results={}, objective_scorer_identifier=ComponentIdentifier.from_dict({}), @@ -91,9 +58,8 @@ def test_init_with_explicit_id(self): assert result.id == explicit_id def test_get_strategies_used(self): - si = _make_scenario_identifier() - result = ScenarioResult( - scenario_identifier=si, + result = make_scenario_result( + scenario_name="TestScenario", objective_target_identifier=ComponentIdentifier.from_dict({}), attack_results={"crescendo": [], "flip": []}, objective_scorer_identifier=ComponentIdentifier.from_dict({}), @@ -105,8 +71,8 @@ def test_get_objectives_all(self): ar1 = _make_attack_result(objective="obj1") ar2 = _make_attack_result(objective="obj2") ar3 = _make_attack_result(objective="obj1") - result = ScenarioResult( - scenario_identifier=_make_scenario_identifier(), + result = make_scenario_result( + scenario_name="TestScenario", objective_target_identifier=ComponentIdentifier.from_dict({}), attack_results={"s1": [ar1, ar3], "s2": [ar2]}, objective_scorer_identifier=ComponentIdentifier.from_dict({}), @@ -117,8 +83,8 @@ def test_get_objectives_all(self): def test_get_objectives_by_attack_name(self): ar1 = _make_attack_result(objective="obj1") ar2 = _make_attack_result(objective="obj2") - result = ScenarioResult( - scenario_identifier=_make_scenario_identifier(), + result = make_scenario_result( + scenario_name="TestScenario", objective_target_identifier=ComponentIdentifier.from_dict({}), attack_results={"s1": [ar1], "s2": [ar2]}, objective_scorer_identifier=ComponentIdentifier.from_dict({}), @@ -133,8 +99,8 @@ def test_objective_achieved_rate_all(self): _make_attack_result(outcome=AttackOutcome.SUCCESS), _make_attack_result(outcome=AttackOutcome.UNDETERMINED), ] - sr = ScenarioResult( - scenario_identifier=_make_scenario_identifier(), + sr = make_scenario_result( + scenario_name="TestScenario", objective_target_identifier=ComponentIdentifier.from_dict({}), attack_results={"s1": results}, objective_scorer_identifier=ComponentIdentifier.from_dict({}), @@ -142,8 +108,8 @@ def test_objective_achieved_rate_all(self): assert sr.objective_achieved_rate() == 50 def test_objective_achieved_rate_empty(self): - sr = ScenarioResult( - scenario_identifier=_make_scenario_identifier(), + sr = make_scenario_result( + scenario_name="TestScenario", objective_target_identifier=ComponentIdentifier.from_dict({}), attack_results={"s1": []}, objective_scorer_identifier=ComponentIdentifier.from_dict({}), @@ -151,8 +117,8 @@ def test_objective_achieved_rate_empty(self): assert sr.objective_achieved_rate() == 0 def test_objective_achieved_rate_by_name(self): - sr = ScenarioResult( - scenario_identifier=_make_scenario_identifier(), + sr = make_scenario_result( + scenario_name="TestScenario", objective_target_identifier=ComponentIdentifier.from_dict({}), attack_results={ "s1": [_make_attack_result(outcome=AttackOutcome.SUCCESS)], @@ -176,8 +142,8 @@ def test_normalize_scenario_name_mixed_case_with_underscore(self): def test_error_attack_result_ids_defaults_to_empty(self): """error_attack_result_ids defaults to empty list.""" - sr = ScenarioResult( - scenario_identifier=_make_scenario_identifier(), + sr = make_scenario_result( + scenario_name="TestScenario", objective_target_identifier=ComponentIdentifier.from_dict({}), attack_results={}, objective_scorer_identifier=ComponentIdentifier.from_dict({}), @@ -186,8 +152,8 @@ def test_error_attack_result_ids_defaults_to_empty(self): def test_error_attack_result_ids_stored(self): """error_attack_result_ids are stored correctly.""" - sr = ScenarioResult( - scenario_identifier=_make_scenario_identifier(), + sr = make_scenario_result( + scenario_name="TestScenario", objective_target_identifier=ComponentIdentifier.from_dict({}), attack_results={}, objective_scorer_identifier=ComponentIdentifier.from_dict({}), @@ -196,25 +162,7 @@ def test_error_attack_result_ids_stored(self): assert sr.error_attack_result_ids == ["id-1", "id-2"] -def test_scenario_identifier_to_dict_from_dict_roundtrip(): - original = ScenarioIdentifier( - name="ContentHarms", - description="Tests content harm scenarios", - scenario_version=3, - init_data={"max_turns": 5, "strategy": "crescendo"}, - pyrit_version="0.14.0", - ) - roundtripped = ScenarioIdentifier.from_dict(original.to_dict()) - assert original.to_dict() == roundtripped.to_dict() - - def test_scenario_result_to_dict_from_dict_roundtrip(): - scenario_id = ScenarioIdentifier( - name="ContentHarms", - description="Tests content harm scenarios", - scenario_version=2, - pyrit_version="0.14.0", - ) target_id = ComponentIdentifier( class_name="OpenAIChatTarget", class_module="pyrit.prompt_target", @@ -253,9 +201,11 @@ def test_scenario_result_to_dict_from_dict_roundtrip(): ], total_retries=1, ) - original = ScenarioResult( + original = make_scenario_result( id=uuid.UUID("12345678-1234-1234-1234-123456789abc"), - scenario_identifier=scenario_id, + scenario_name="ContentHarms", + scenario_version=2, + pyrit_version="0.14.0", objective_target_identifier=target_id, objective_scorer_identifier=scorer_id, scenario_run_state="COMPLETED", @@ -271,31 +221,19 @@ def test_scenario_result_to_dict_from_dict_roundtrip(): ) roundtripped = ScenarioResult.from_dict(original.to_dict()) assert original.to_dict() == roundtripped.to_dict() - # The nested identifier must preserve the legacy ``scenario_version`` wire key. - assert "scenario_version" in original.to_dict()["scenario_identifier"] - assert "version" not in original.to_dict()["scenario_identifier"] - - -def test_scenario_identifier_from_dict_missing_pyrit_version_uses_current(): - """A payload missing pyrit_version now resolves to the current version via the Pydantic default.""" - data = { - "name": "Legacy", - "description": "loaded from older payload", - "scenario_version": 1, - "init_data": None, - # pyrit_version intentionally absent - } - identifier = ScenarioIdentifier.from_dict(data) - assert identifier.pyrit_version == pyrit.__version__ + # Identity facts round-trip as denormalized flat scalars on the result. + assert original.to_dict()["scenario_name"] == "ContentHarms" + assert original.to_dict()["scenario_version"] == 2 + assert original.to_dict()["pyrit_version"] == "0.14.0" def test_scenario_result_from_dict_preserves_missing_completion_time(): """An in-progress scenario serialized without completion_time should round-trip with completion_time=None.""" - scenario_id = ScenarioIdentifier(name="Test", scenario_version=1, pyrit_version="0.14.0") target_id = ComponentIdentifier(class_name="OpenAIChatTarget", class_module="pyrit.prompt_target") - original = ScenarioResult( - scenario_identifier=scenario_id, + original = make_scenario_result( + scenario_name="Test", + pyrit_version="0.14.0", objective_target_identifier=target_id, objective_scorer_identifier=None, attack_results={}, @@ -308,18 +246,10 @@ def test_scenario_result_from_dict_preserves_missing_completion_time(): assert roundtripped.scenario_run_state == "IN_PROGRESS" -def test_scenario_identifier_to_dict_from_dict_emit_deprecation_warnings(): - identifier = ScenarioIdentifier(name="Test", scenario_version=1, pyrit_version="0.14.0") - with pytest.warns(DeprecationWarning): - payload = identifier.to_dict() - with pytest.warns(DeprecationWarning): - ScenarioIdentifier.from_dict(payload) - - def test_scenario_result_to_dict_from_dict_emit_deprecation_warnings(): - scenario_id = ScenarioIdentifier(name="Test", scenario_version=1, pyrit_version="0.14.0") - result = ScenarioResult( - scenario_identifier=scenario_id, + result = make_scenario_result( + scenario_name="Test", + pyrit_version="0.14.0", objective_target_identifier=ComponentIdentifier.from_dict({}), objective_scorer_identifier=None, attack_results={}, @@ -331,9 +261,9 @@ def test_scenario_result_to_dict_from_dict_emit_deprecation_warnings(): def test_scenario_result_display_group_map_is_public_field(): - scenario_id = ScenarioIdentifier(name="Test", scenario_version=1, pyrit_version="0.14.0") - result = ScenarioResult( - scenario_identifier=scenario_id, + result = make_scenario_result( + scenario_name="Test", + pyrit_version="0.14.0", objective_target_identifier=ComponentIdentifier.from_dict({}), objective_scorer_identifier=None, attack_results={"crescendo": []}, diff --git a/tests/unit/output/scenario_result/test_pretty.py b/tests/unit/output/scenario_result/test_pretty.py index f1ba89c431..45671d74ad 100644 --- a/tests/unit/output/scenario_result/test_pretty.py +++ b/tests/unit/output/scenario_result/test_pretty.py @@ -4,15 +4,17 @@ import uuid import pytest +from unit.mocks import make_scenario_result -from pyrit.models import AttackOutcome, AttackResult, ComponentIdentifier, ScenarioIdentifier, ScenarioResult +from pyrit.models import ( + AttackOutcome, + AttackResult, + ComponentIdentifier, + ScenarioResult, +) from pyrit.output.scenario_result.pretty import PrettyScenarioResultMemoryPrinter -def _scenario_identifier(*, name: str = "TestScenario", description: str = "") -> ScenarioIdentifier: - return ScenarioIdentifier(name=name, description=description, scenario_version=1, pyrit_version="1.0.0") - - def _target_identifier(**params) -> ComponentIdentifier: return ComponentIdentifier(class_name="MockTarget", class_module="tests", params=params) @@ -29,8 +31,11 @@ def _scenario_result( objective_scorer_identifier: ComponentIdentifier | None = None, display_group_map: dict[str, str] | None = None, ) -> ScenarioResult: - return ScenarioResult( - scenario_identifier=_scenario_identifier(description=description), + return make_scenario_result( + scenario_name="TestScenario", + scenario_version=1, + pyrit_version="1.0.0", + scenario_description=description, objective_target_identifier=_target_identifier(**(target_params or {})), attack_results=attack_results or {"strategy_a": [_attack_result()]}, objective_scorer_identifier=objective_scorer_identifier, @@ -77,8 +82,10 @@ async def test_write_async_renders_full_summary(printer, capsys): async def test_write_async_with_unknown_target_when_no_params(printer, capsys): - result = ScenarioResult( - scenario_identifier=_scenario_identifier(), + result = make_scenario_result( + scenario_name="TestScenario", + scenario_version=1, + pyrit_version="1.0.0", objective_target_identifier=ComponentIdentifier.from_dict({}), attack_results={"s": []}, objective_scorer_identifier=None, @@ -101,7 +108,9 @@ async def fake_render_async(*, scorer_identifier, harm_category=None): assert "[scorer-render-output]" in capsys.readouterr().out -async def test_write_async_raises_when_scorer_identifier_present_without_scorer_printer(patch_central_database): +async def test_write_async_raises_when_scorer_identifier_present_without_scorer_printer( + patch_central_database, +): printer = PrettyScenarioResultMemoryPrinter(enable_colors=False) printer._scorer_printer = None result = _scenario_result(objective_scorer_identifier=_target_identifier()) @@ -114,7 +123,10 @@ async def test_write_async_raises_when_scorer_identifier_present_without_scorer_ [ (100, [AttackOutcome.SUCCESS, AttackOutcome.SUCCESS]), # >=75 RED band (50, [AttackOutcome.SUCCESS, AttackOutcome.FAILURE]), # >=50 YELLOW band - (33, [AttackOutcome.SUCCESS, AttackOutcome.FAILURE, AttackOutcome.FAILURE]), # >=25 CYAN band + ( + 33, + [AttackOutcome.SUCCESS, AttackOutcome.FAILURE, AttackOutcome.FAILURE], + ), # >=25 CYAN band (0, [AttackOutcome.FAILURE]), # <25 GREEN band ], ) @@ -207,7 +219,11 @@ async def test_write_async_sort_is_stable_for_ties(patch_central_database, capsy ) await sorting_printer.write_async(result) # Tied 100% groups retain their original relative order; 0% group goes last. - assert _group_order(capsys.readouterr().out) == ["first_success", "second_success", "fail"] + assert _group_order(capsys.readouterr().out) == [ + "first_success", + "second_success", + "fail", + ] # --- deprecated alias --- diff --git a/tests/unit/registry/test_registry.py b/tests/unit/registry/test_registry.py index e5fa51105c..2a59870eb8 100644 --- a/tests/unit/registry/test_registry.py +++ b/tests/unit/registry/test_registry.py @@ -240,10 +240,11 @@ def _metadata_class(self) -> type[ClassRegistryEntry]: def test_discover_skips_spec_type_mock_exports(): - # A foreign test may patch a discovery-package export with a ``MagicMock(spec=type)`` - # that reports ``isinstance(obj, type) is True`` yet makes ``issubclass`` raise - # ``TypeError``. Default discovery must skip it rather than blow up the whole catalog. - package = ModuleType("_fake_widget_package") + # Type-based discovery enumerates concrete subclasses of the base that live under + # the discovery package. A foreign export leaked into ``__all__`` (e.g. a + # ``MagicMock(spec=type)``) is materialized by the lazy-export force-load but is + # not a real subclass, so it must be ignored rather than blow up the catalog. + package = ModuleType(_ConcreteWidget.__module__) package.__all__ = ["_ConcreteWidget", "_LeakedMock"] package._ConcreteWidget = _ConcreteWidget package._LeakedMock = MagicMock(spec=type) diff --git a/tests/unit/registry/test_resolution.py b/tests/unit/registry/test_resolution.py index a37e0a3694..6857a04fd0 100644 --- a/tests/unit/registry/test_resolution.py +++ b/tests/unit/registry/test_resolution.py @@ -15,8 +15,9 @@ from pyrit.models.identifiers import ConverterIdentifier, TargetIdentifier from pyrit.models.parameter import ComponentType from pyrit.prompt_target import PromptTarget -from pyrit.registry.components import TargetRegistry +from pyrit.registry.components import ConverterRegistry, ScorerRegistry, TargetRegistry from pyrit.registry.resolution import ( + _registry_getter_for_component_type, derive_parameters, display_choices, resolve_constructor_args, @@ -276,3 +277,37 @@ def test_module_has_no_backend_dependency() -> None: elif isinstance(node, ast.ImportFrom) and node.module: imported_modules.append(node.module) assert not any(name.startswith("pyrit.backend") for name in imported_modules) + + +# Component families whose references resolve by name, and the registry each maps to. +# Kept in the test (not imported from resolution.py) so it is an independent spec: +# the test fails if the production mapping drifts from this expectation. +_RESOLVABLE_COMPONENT_REGISTRIES = { + ComponentType.TARGET: TargetRegistry, + ComponentType.CONVERTER: ConverterRegistry, + ComponentType.SCORER: ScorerRegistry, +} +# Scenarios are created by name, never referenced by name inside another component, +# so they are deliberately not wired for reference resolution. +_NON_RESOLVABLE_COMPONENT_TYPES = {ComponentType.SCENARIO} + + +def test_every_component_type_is_classified() -> None: + # Guard against silently adding a ComponentType without deciding whether its + # references resolve by name. A new member forces an update here (and to the + # resolution map), rather than failing only at build time. + classified = set(_RESOLVABLE_COMPONENT_REGISTRIES) | _NON_RESOLVABLE_COMPONENT_TYPES + assert set(ComponentType) == classified + + +@pytest.mark.parametrize("component_type", list(_RESOLVABLE_COMPONENT_REGISTRIES)) +def test_resolvable_component_type_maps_to_its_registry(component_type: ComponentType) -> None: + getter = _registry_getter_for_component_type(component_type) + assert getter is not None + expected_registry = _RESOLVABLE_COMPONENT_REGISTRIES[component_type] + assert getter() is expected_registry.get_registry_singleton().instances + + +@pytest.mark.parametrize("component_type", sorted(_NON_RESOLVABLE_COMPONENT_TYPES)) +def test_non_resolvable_component_type_has_no_registry(component_type: ComponentType) -> None: + assert _registry_getter_for_component_type(component_type) is None diff --git a/tests/unit/registry/test_scenario_registry.py b/tests/unit/registry/test_scenario_registry.py index 263e6d60f2..7c2b519579 100644 --- a/tests/unit/registry/test_scenario_registry.py +++ b/tests/unit/registry/test_scenario_registry.py @@ -1,12 +1,13 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -"""Tests for ScenarioRegistry._build_metadata.""" +"""Tests for ScenarioRegistry._build_metadata and create_and_initialize_async.""" + +from unittest.mock import AsyncMock, MagicMock import pytest -from pyrit.registry.class_registries.base_class_registry import ClassEntry -from pyrit.registry.class_registries.scenario_registry import ScenarioRegistry +from pyrit.registry.components.scenario_registry import ScenarioRegistry class _NotNoArgScenario: @@ -23,7 +24,44 @@ def __init__(self, *, required_arg) -> None: def test_build_metadata_raises_when_scenario_requires_constructor_args() -> None: """Scenarios that cannot be instantiated with no args must surface a clear error.""" registry = ScenarioRegistry() - entry = ClassEntry(registered_class=_NotNoArgScenario) with pytest.raises(TypeError, match="must be instantiable with no arguments"): - registry._build_metadata("not_no_arg", entry) + registry._build_metadata("not_no_arg", _NotNoArgScenario) + + +async def test_create_and_initialize_async_creates_sets_params_and_initializes() -> None: + """The registry owns build + set-params + initialize and returns the scenario.""" + registry = ScenarioRegistry() + + scenario = MagicMock() + scenario.initialize_async = AsyncMock() + target = MagicMock() + + registry.create_instance = MagicMock(return_value=scenario) # type: ignore[method-assign] + + result = await registry.create_and_initialize_async( + "my.scenario", + scenario_params={"foo": "bar"}, + scenario_result_id="sr-1", + objective_target=target, + max_concurrency=2, + ) + + assert result is scenario + registry.create_instance.assert_called_once_with("my.scenario", scenario_result_id="sr-1") + scenario.set_params_from_args.assert_called_once_with(args={"foo": "bar"}) + scenario.initialize_async.assert_awaited_once_with(objective_target=target, max_concurrency=2) + + +async def test_create_and_initialize_async_omits_result_id_when_none() -> None: + """When no scenario_result_id is supplied, it is not forwarded to the constructor.""" + registry = ScenarioRegistry() + + scenario = MagicMock() + scenario.initialize_async = AsyncMock() + registry.create_instance = MagicMock(return_value=scenario) # type: ignore[method-assign] + + await registry.create_and_initialize_async("my.scenario", objective_target=MagicMock()) + + registry.create_instance.assert_called_once_with("my.scenario") + scenario.set_params_from_args.assert_called_once_with(args={}) diff --git a/tests/unit/scenario/core/test_scenario.py b/tests/unit/scenario/core/test_scenario.py index d8366efe3c..276184c3e7 100644 --- a/tests/unit/scenario/core/test_scenario.py +++ b/tests/unit/scenario/core/test_scenario.py @@ -17,9 +17,20 @@ from pyrit.executor.attack.core import AttackExecutorResult from pyrit.memory import CentralMemory from pyrit.models import AttackOutcome, AttackResult, ComponentIdentifier -from pyrit.scenario import DatasetAttackConfiguration, DatasetConfiguration, ScenarioIdentifier, ScenarioResult -from pyrit.scenario.core import AtomicAttack, BaselineAttackPolicy, Scenario, ScenarioStrategy +from pyrit.scenario import ( + DatasetAttackConfiguration, + DatasetConfiguration, + ScenarioIdentifier, + ScenarioResult, +) +from pyrit.scenario.core import ( + AtomicAttack, + BaselineAttackPolicy, + Scenario, + ScenarioStrategy, +) from pyrit.score import Scorer +from tests.unit.mocks import make_scenario_identifier, make_scenario_result # Reusable test scorer identifier _TEST_SCORER_ID = ComponentIdentifier( @@ -179,24 +190,22 @@ def test_init_with_valid_params(self, mock_objective_target): ) assert scenario.name == "Test Scenario" - assert scenario._identifier.name == "ConcreteScenario" - assert scenario._identifier.version == 1 + assert scenario._version == 1 + assert scenario._description == "Concrete implementation of Scenario for testing." assert scenario._memory_labels == {} assert scenario._max_concurrency is None assert scenario._max_retries == 0 # Default value assert scenario.atomic_attack_count == 0 # Not initialized yet - def test_init_creates_scenario_identifier(self, mock_objective_target): - """Test that initialization creates a proper ScenarioIdentifier.""" + def test_init_stores_scenario_version_and_description(self, mock_objective_target): + """Test that initialization stores run metadata used by ScenarioResult.""" scenario = ConcreteScenario( name="Test Scenario", version=3, ) - assert isinstance(scenario._identifier, ScenarioIdentifier) - assert scenario._identifier.name == "ConcreteScenario" - assert scenario._identifier.version == 3 - assert scenario._identifier.pyrit_version is not None + assert scenario._version == 3 + assert scenario._description == "Concrete implementation of Scenario for testing." def test_init_with_empty_attack_strategies(self, mock_objective_target): """Test that initialization works without attack_strategies.""" @@ -474,10 +483,9 @@ async def test_run_async_returns_scenario_result_with_identifier( result = await scenario.run_async() assert isinstance(result, ScenarioResult) - assert isinstance(result.scenario_identifier, ScenarioIdentifier) - assert result.scenario_identifier.name == "ConcreteScenario" - assert result.scenario_identifier.version == 5 - assert result.scenario_identifier.pyrit_version is not None + assert result.scenario_name == "ConcreteScenario" + assert result.scenario_version == 5 + assert result.pyrit_version is not None assert result.get_strategies_used() == [ "attack_run_1", "attack_run_2", @@ -567,15 +575,19 @@ class TestScenarioResult: def test_scenario_result_initialization(self, sample_attack_results): """Test ScenarioResult initialization.""" - identifier = ScenarioIdentifier(name="Test", scenario_version=1) - result = ScenarioResult( - scenario_identifier=identifier, + result = make_scenario_result( + scenario_name="Test", + scenario_version=1, objective_target_identifier=ComponentIdentifier(class_name="TestTarget", class_module="test"), - attack_results={"base64": sample_attack_results[:3], "rot13": sample_attack_results[3:]}, + attack_results={ + "base64": sample_attack_results[:3], + "rot13": sample_attack_results[3:], + }, objective_scorer_identifier=_TEST_SCORER_ID, ) - assert result.scenario_identifier == identifier + assert result.scenario_name == "Test" + assert result.scenario_version == 1 assert result.get_strategies_used() == ["base64", "rot13"] assert len(result.attack_results) == 2 assert len(result.attack_results["base64"]) == 3 @@ -583,9 +595,9 @@ def test_scenario_result_initialization(self, sample_attack_results): def test_scenario_result_with_empty_results(self): """Test ScenarioResult with empty attack results.""" - identifier = ScenarioIdentifier(name="TestScenario", scenario_version=1) - result = ScenarioResult( - scenario_identifier=identifier, + result = make_scenario_result( + scenario_name="TestScenario", + scenario_version=1, objective_target_identifier=ComponentIdentifier( class_name="TestTarget", class_module="test", @@ -599,11 +611,10 @@ def test_scenario_result_with_empty_results(self): def test_scenario_result_objective_achieved_rate(self, sample_attack_results): """Test objective_achieved_rate calculation.""" - identifier = ScenarioIdentifier(name="Test", scenario_version=1) - # All successful - result = ScenarioResult( - scenario_identifier=identifier, + result = make_scenario_result( + scenario_name="Test", + scenario_version=1, objective_target_identifier=ComponentIdentifier( class_name="TestTarget", class_module="test", @@ -628,8 +639,9 @@ def test_scenario_result_objective_achieved_rate(self, sample_attack_results): executed_turns=1, ), ] - result2 = ScenarioResult( - scenario_identifier=identifier, + result2 = make_scenario_result( + scenario_name="Test", + scenario_version=1, objective_target_identifier=ComponentIdentifier( class_name="TestTarget", class_module="test", @@ -642,29 +654,30 @@ def test_scenario_result_objective_achieved_rate(self, sample_attack_results): @pytest.mark.usefixtures("patch_central_database") class TestScenarioIdentifier: - """Tests for ScenarioIdentifier class.""" + """Tests for ScenarioIdentifier registry projection.""" def test_scenario_identifier_initialization(self): - """Test ScenarioIdentifier initialization.""" - identifier = ScenarioIdentifier(name="TestScenario", scenario_version=2) - - assert identifier.name == "TestScenario" - assert identifier.version == 2 - assert identifier.pyrit_version is not None - - def test_scenario_identifier_with_custom_pyrit_version(self): - """Test ScenarioIdentifier initialization sets pyrit version automatically.""" - identifier = ScenarioIdentifier(name="TestScenario", scenario_version=1) + """Test ScenarioIdentifier projection initialization.""" + identifier = ScenarioIdentifier( + class_name="TestScenario", + class_module="tests.unit.scenario.core.test_scenario", + version=2, + ) - assert identifier.pyrit_version is not None - assert identifier.name == "TestScenario" + assert identifier.class_name == "TestScenario" + assert identifier.class_module == "tests.unit.scenario.core.test_scenario" - def test_scenario_identifier_with_init_data(self): - """Test ScenarioIdentifier with init_data.""" - init_data = {"param1": "value1", "param2": 42} - identifier = ScenarioIdentifier(name="TestScenario", scenario_version=1, init_data=init_data) + def test_scenario_identifier_accepts_registry_projection_fields(self): + """Test ScenarioIdentifier stores registry projection metadata.""" + identifier = ScenarioIdentifier( + class_name="TestScenario", + class_module="tests.unit.scenario.core.test_scenario", + techniques=["baseline"], + datasets=["harmful_content"], + ) - assert identifier.init_data == init_data + assert identifier.techniques == ["baseline"] + assert identifier.datasets == ["harmful_content"] def create_mock_truefalse_scorer(): @@ -947,7 +960,10 @@ async def _get_atomic_attacks_async(self): ) ] if self._include_baseline: - atomic_attacks.insert(0, self._build_baseline_atomic_attack(seed_groups=all_seed_groups)) + atomic_attacks.insert( + 0, + self._build_baseline_atomic_attack(seed_groups=all_seed_groups), + ) return atomic_attacks # Two distinct samples wired up. A buggy implementation with a second @@ -1015,7 +1031,8 @@ def test_raises_when_scorer_is_none(self, mock_objective_target): @pytest.mark.usefixtures("patch_central_database") class TestBaselineEmissionDeprecationRescue: """Deprecation rescue (removed in 0.16.0): overrides that don't emit baseline get a - DeprecationWarning + auto-injected baseline so they keep working during the migration.""" + DeprecationWarning + auto-injected baseline so they keep working during the migration. + """ @staticmethod def _dataset_config(): @@ -1076,40 +1093,40 @@ class TestValidateStoredScenario: def _make_scenario(self, *, name: str = "TestScenario", version: int = 1) -> ConcreteScenario: scenario = ConcreteScenario(name=name, version=version) scenario._scenario_result_id = "test-result-id" - # _validate_stored_scenario now also checks params scenario.params = {} return scenario def test_passes_when_name_and_version_match(self): - """Valid match does not raise.""" + """Valid match (identical eval hash) does not raise.""" scenario = self._make_scenario(name="TestScenario", version=2) - stored_result = MagicMock(spec=ScenarioResult) - stored_result.scenario_identifier = ScenarioIdentifier(name="ConcreteScenario", scenario_version=2) - stored_result.scenario_run_state = "CREATED" + current = make_scenario_identifier(scenario_name="ConcreteScenario", version=2) + stored_result = make_scenario_result( + scenario_name="ConcreteScenario", scenario_version=2, scenario_run_state="CREATED", attack_results={} + ) # Should not raise - scenario._validate_stored_scenario(stored_result=stored_result) + scenario._validate_stored_scenario(stored_result=stored_result, current_identifier=current) def test_raises_when_name_mismatches(self): """Mismatched name raises ValueError.""" scenario = self._make_scenario(name="TestScenario", version=1) - stored_result = MagicMock(spec=ScenarioResult) - stored_result.scenario_identifier = ScenarioIdentifier(name="DifferentScenario", scenario_version=1) + current = make_scenario_identifier(scenario_name="ConcreteScenario", version=1) + stored_result = make_scenario_result(scenario_name="DifferentScenario", scenario_version=1, attack_results={}) - with pytest.raises(ValueError, match="belongs to scenario 'DifferentScenario'"): - scenario._validate_stored_scenario(stored_result=stored_result) + with pytest.raises(ValueError, match="does not match the current"): + scenario._validate_stored_scenario(stored_result=stored_result, current_identifier=current) def test_raises_when_version_mismatches(self): - """Mismatched version raises ValueError.""" + """Mismatched version changes the eval hash and raises ValueError.""" scenario = self._make_scenario(name="TestScenario", version=2) - stored_result = MagicMock(spec=ScenarioResult) - stored_result.scenario_identifier = ScenarioIdentifier(name="ConcreteScenario", scenario_version=99) + current = make_scenario_identifier(scenario_name="ConcreteScenario", version=2) + stored_result = make_scenario_result(scenario_name="ConcreteScenario", scenario_version=99, attack_results={}) - with pytest.raises(ValueError, match="version 99 but current version is 2"): - scenario._validate_stored_scenario(stored_result=stored_result) + with pytest.raises(ValueError, match="does not match the current"): + scenario._validate_stored_scenario(stored_result=stored_result, current_identifier=current) @pytest.mark.usefixtures("patch_central_database") @@ -1227,7 +1244,10 @@ async def run_async(*, executor, **kwargs): atomic_attack=mock_atomic_attacks[idx], ) save_attack_results_to_memory([sample_attack_results[idx]]) - return AttackExecutorResult(completed_results=[sample_attack_results[idx]], incomplete_objectives=[]) + return AttackExecutorResult( + completed_results=[sample_attack_results[idx]], + incomplete_objectives=[], + ) return AsyncMock(side_effect=run_async) @@ -1276,7 +1296,10 @@ async def run_async(*args, **kwargs): atomic_attack=mock_atomic_attacks[idx], ) save_attack_results_to_memory([sample_attack_results[idx]]) - return AttackExecutorResult(completed_results=[sample_attack_results[idx]], incomplete_objectives=[]) + return AttackExecutorResult( + completed_results=[sample_attack_results[idx]], + incomplete_objectives=[], + ) return AsyncMock(side_effect=run_async) diff --git a/tests/unit/scenario/core/test_scenario_parameters.py b/tests/unit/scenario/core/test_scenario_parameters.py index 03256fc806..cb4bb0b953 100644 --- a/tests/unit/scenario/core/test_scenario_parameters.py +++ b/tests/unit/scenario/core/test_scenario_parameters.py @@ -19,8 +19,8 @@ def _make_scenario(*, declared_params: list[Parameter]) -> Scenario: """Build a minimal Scenario subclass that declares the given parameters. - Each test gets its own subclass so ``_declarations_validated`` state never - leaks across tests. + Each test gets its own subclass so declared-parameter state never leaks + across tests. """ params_to_declare = declared_params @@ -290,6 +290,12 @@ def test_unknown_params_listed_together(self) -> None: with pytest.raises(ValueError, match="bogus1, bogus2"): scenario.set_params_from_args(args={"bogus1": "a", "bogus2": "b"}) + def test_reserved_version_param_raises(self) -> None: + """A scenario cannot declare a param named ``version`` (owned by the identity).""" + scenario = _make_scenario(declared_params=[Parameter(name="version", description="d", param_type=int)]) + with pytest.raises(ValueError, match="reserved parameter"): + scenario.set_params_from_args(args={}) + @pytest.mark.usefixtures("patch_central_database") class TestDeclarationValidation: @@ -398,49 +404,61 @@ def test_none_value_with_no_default_materializes_as_none(self) -> None: @pytest.mark.usefixtures("patch_central_database") class TestResumeParameterValidation: - """Tests for Stage 5 resume validation against persisted scenario params.""" + """Tests for resume validation against a persisted scenario identifier (eval-hash based).""" - @staticmethod - def _make_stored_result(*, scenario_name: str, version: int, init_data): - """Build a minimal ScenarioResult with a controlled identifier for resume tests.""" - from pyrit.models import ScenarioIdentifier, ScenarioResult + _TARGET_ID = ComponentIdentifier(class_name="MockTarget", class_module="tests.unit.scenarios") - identifier = ScenarioIdentifier( - name=scenario_name, - description="", + @classmethod + def _make_stored_result(cls, *, scenario_name: str, version: int, params): + """Build a minimal ScenarioResult with a controlled scenario identifier for resume tests.""" + from tests.unit.mocks import make_scenario_result + + return make_scenario_result( + scenario_name=scenario_name, scenario_version=version, - init_data=init_data, - ) - target_id = ComponentIdentifier(class_name="MockTarget", class_module="tests.unit.scenarios") - return ScenarioResult( - scenario_identifier=identifier, - objective_target_identifier=target_id, + params=params, + objective_target_identifier=cls._TARGET_ID, objective_scorer_identifier=_TEST_SCORER_ID, labels={}, attack_results={}, scenario_run_state="CREATED", ) + @classmethod + def _current_identifier(cls, *, scenario, version: int = 1, params): + """Build the identifier that mirrors the current run for the given scenario.""" + from tests.unit.mocks import make_scenario_identifier + + return make_scenario_identifier( + scenario_name=type(scenario).__name__, + version=version, + params=params, + objective_target=cls._TARGET_ID, + objective_scorer=_TEST_SCORER_ID, + ) + def test_matching_params_returns_none(self) -> None: scenario = _make_scenario( declared_params=[Parameter(name="max_turns", description="d", param_type=int, default=5)] ) scenario.set_params_from_args(args={"max_turns": 10}) - stored = self._make_stored_result(scenario_name=type(scenario).__name__, version=1, init_data={"max_turns": 10}) + stored = self._make_stored_result(scenario_name=type(scenario).__name__, version=1, params={"max_turns": 10}) + current = self._current_identifier(scenario=scenario, params={"max_turns": 10}) # Match path: returns None and does not raise. - assert scenario._validate_stored_scenario(stored_result=stored) is None + assert scenario._validate_stored_scenario(stored_result=stored, current_identifier=current) is None - def test_changed_param_raises_with_diff(self) -> None: + def test_changed_param_raises_without_leaking_values(self) -> None: scenario = _make_scenario( declared_params=[Parameter(name="max_turns", description="d", param_type=int, default=5)] ) scenario.set_params_from_args(args={"max_turns": 10}) - stored = self._make_stored_result(scenario_name=type(scenario).__name__, version=1, init_data={"max_turns": 5}) - with pytest.raises(ValueError, match="mismatched parameters .*changed: max_turns") as exc_info: - scenario._validate_stored_scenario(stored_result=stored) - # Diff names the key but never the values (no leak). + stored = self._make_stored_result(scenario_name=type(scenario).__name__, version=1, params={"max_turns": 5}) + current = self._current_identifier(scenario=scenario, params={"max_turns": 10}) + with pytest.raises(ValueError, match="does not match the current") as exc_info: + scenario._validate_stored_scenario(stored_result=stored, current_identifier=current) + # Generic drift message never leaks the differing param values. assert "10" not in str(exc_info.value) assert "stored=5" not in str(exc_info.value) @@ -453,71 +471,73 @@ def test_added_param_raises(self) -> None: ) scenario.set_params_from_args(args={}) - stored = self._make_stored_result(scenario_name=type(scenario).__name__, version=1, init_data={"max_turns": 5}) - with pytest.raises(ValueError, match="added: mode"): - scenario._validate_stored_scenario(stored_result=stored) - - def test_legacy_init_data_none_matches_empty_params(self) -> None: - """A pre-Stage-5 stored result has init_data=None; treat as empty for back-compat.""" - scenario = _make_scenario(declared_params=[]) - scenario.set_params_from_args(args={}) - - stored = self._make_stored_result(scenario_name=type(scenario).__name__, version=1, init_data=None) - assert scenario._validate_stored_scenario(stored_result=stored) is None - - def test_legacy_init_data_none_mismatches_populated_params(self) -> None: - scenario = _make_scenario( - declared_params=[Parameter(name="max_turns", description="d", param_type=int, default=5)] - ) - scenario.set_params_from_args(args={"max_turns": 7}) - - stored = self._make_stored_result(scenario_name=type(scenario).__name__, version=1, init_data=None) - with pytest.raises(ValueError, match="added: max_turns"): - scenario._validate_stored_scenario(stored_result=stored) + stored = self._make_stored_result(scenario_name=type(scenario).__name__, version=1, params={"max_turns": 5}) + current = self._current_identifier(scenario=scenario, params={"max_turns": 5, "mode": "fast"}) + with pytest.raises(ValueError, match="does not match the current"): + scenario._validate_stored_scenario(stored_result=stored, current_identifier=current) def test_resume_normalizes_json_drift_for_passthrough_tuples(self) -> None: """A tuple value under param_type=None matches a stored list (post-JSON round-trip).""" scenario = _make_scenario(declared_params=[Parameter(name="weights", description="d")]) scenario.set_params_from_args(args={"weights": (0.5, 0.5)}) - # init_data after a real DB round-trip would be a list, not a tuple. The fix - # normalizes both sides through json.loads(json.dumps(...)) before comparing. + # A stored value after a real DB round-trip would be a list, not a tuple. The + # eval hash normalizes both sides through JSON before comparing. stored = self._make_stored_result( - scenario_name=type(scenario).__name__, version=1, init_data={"weights": [0.5, 0.5]} + scenario_name=type(scenario).__name__, version=1, params={"weights": [0.5, 0.5]} ) - assert scenario._validate_stored_scenario(stored_result=stored) is None + current = self._current_identifier(scenario=scenario, params={"weights": (0.5, 0.5)}) + assert scenario._validate_stored_scenario(stored_result=stored, current_identifier=current) is None def test_name_mismatch_raises(self) -> None: scenario = _make_scenario(declared_params=[]) scenario.set_params_from_args(args={}) - stored = self._make_stored_result(scenario_name="OtherScenario", version=1, init_data={}) - with pytest.raises(ValueError, match="belongs to scenario 'OtherScenario'"): - scenario._validate_stored_scenario(stored_result=stored) + stored = self._make_stored_result(scenario_name="OtherScenario", version=1, params={}) + current = self._current_identifier(scenario=scenario, params={}) + with pytest.raises(ValueError, match="does not match the current"): + scenario._validate_stored_scenario(stored_result=stored, current_identifier=current) def test_version_mismatch_raises(self) -> None: scenario = _make_scenario(declared_params=[]) scenario.set_params_from_args(args={}) - stored = self._make_stored_result(scenario_name=type(scenario).__name__, version=999, init_data={}) - with pytest.raises(ValueError, match="version 999 but current version is 1"): - scenario._validate_stored_scenario(stored_result=stored) + stored = self._make_stored_result(scenario_name=type(scenario).__name__, version=999, params={}) + current = self._current_identifier(scenario=scenario, version=1, params={}) + with pytest.raises(ValueError, match="does not match the current"): + scenario._validate_stored_scenario(stored_result=stored, current_identifier=current) @pytest.mark.usefixtures("patch_central_database") class TestParamPersistenceJsonSafety: - """Tests for the JSON-serializability check before persisting params.""" + """Params flow into the scenario identifier, which enforces JSON-serializable values.""" + + @staticmethod + def _mock_target() -> MagicMock: + target = MagicMock() + target.get_identifier.return_value = ComponentIdentifier(class_name="MockTarget", class_module="test") + return target - def test_json_safe_scalar_passes(self) -> None: - from pyrit.scenario.core.scenario import _assert_json_serializable + async def test_json_safe_params_persist_on_init(self) -> None: + scenario = _make_scenario( + declared_params=[Parameter(name="max_turns", description="d", param_type=int, default=5)] + ) + scenario.set_params_from_args(args={"max_turns": 10}) + + await scenario.initialize_async(objective_target=self._mock_target()) - _assert_json_serializable(params={"max_turns": 5, "mode": "fast", "datasets": ["a", "b"]}) + stored = scenario._memory.get_scenario_results(scenario_result_ids=[scenario._scenario_result_id])[0] + assert stored.scenario_identifier.params["max_turns"] == 10 - def test_non_json_safe_value_raises(self) -> None: - from pyrit.scenario.core.scenario import _assert_json_serializable + async def test_non_json_safe_value_raises(self) -> None: + from pydantic import ValidationError class _NotJsonable: pass - with pytest.raises(ValueError, match="non-JSON-serializable"): - _assert_json_serializable(params={"x": _NotJsonable()}) + # param_type=None passes the raw value straight through set_params_from_args. + scenario = _make_scenario(declared_params=[Parameter(name="blob", description="d")]) + scenario.set_params_from_args(args={"blob": _NotJsonable()}) + + with pytest.raises(ValidationError): + await scenario.initialize_async(objective_target=self._mock_target()) diff --git a/tests/unit/setup/test_load_default_datasets.py b/tests/unit/setup/test_load_default_datasets.py index b453525026..9a2d58acf7 100644 --- a/tests/unit/setup/test_load_default_datasets.py +++ b/tests/unit/setup/test_load_default_datasets.py @@ -64,7 +64,7 @@ async def test_initialize_async_no_scenarios(self) -> None: """Test initialization when no scenarios are registered.""" initializer = LoadDefaultDatasets() - with patch.object(ScenarioRegistry, "list_metadata", return_value=[]): + with patch.object(ScenarioRegistry, "get_all_registered_class_metadata", return_value=[]): with patch.object(SeedDatasetProvider, "fetch_datasets_async", new_callable=AsyncMock) as mock_fetch: with patch.object(CentralMemory, "get_memory_instance") as mock_memory: mock_memory_instance = MagicMock() @@ -82,7 +82,7 @@ async def test_initialize_async_with_scenarios(self) -> None: metadata = [_FakeMetadata(registry_name="mock_scenario", default_datasets=("dataset1", "dataset2"))] - with patch.object(ScenarioRegistry, "list_metadata", return_value=metadata): + with patch.object(ScenarioRegistry, "get_all_registered_class_metadata", return_value=metadata): with patch.object(SeedDatasetProvider, "fetch_datasets_async", new_callable=AsyncMock) as mock_fetch: mock_dataset1 = MagicMock(spec=SeedDataset) mock_dataset2 = MagicMock(spec=SeedDataset) @@ -112,7 +112,7 @@ async def test_initialize_async_deduplicates_datasets(self) -> None: _FakeMetadata(registry_name="scenario2", default_datasets=("dataset2", "dataset3")), ] - with patch.object(ScenarioRegistry, "list_metadata", return_value=metadata): + with patch.object(ScenarioRegistry, "get_all_registered_class_metadata", return_value=metadata): with patch.object(SeedDatasetProvider, "fetch_datasets_async", new_callable=AsyncMock) as mock_fetch: mock_fetch.return_value = [] @@ -157,7 +157,7 @@ async def test_all_required_datasets_available_in_seed_provider(self, populated_ ): registry = ScenarioRegistry.get_registry_singleton() registry._metadata_cache = None # force rebuild under the patch - metadata_list = list(registry.list_metadata()) + metadata_list = list(registry.get_all_registered_class_metadata()) missing_datasets: list[str] = [] for metadata in metadata_list: @@ -178,7 +178,7 @@ async def test_initialize_async_empty_dataset_list(self) -> None: metadata = [_FakeMetadata(registry_name="empty_scenario", default_datasets=())] - with patch.object(ScenarioRegistry, "list_metadata", return_value=metadata): + with patch.object(ScenarioRegistry, "get_all_registered_class_metadata", return_value=metadata): with patch.object(SeedDatasetProvider, "fetch_datasets_async", new_callable=AsyncMock) as mock_fetch: with patch.object(CentralMemory, "get_memory_instance") as mock_memory: mock_memory_instance = MagicMock() diff --git a/tests/unit/setup/test_preload_scenario_metadata.py b/tests/unit/setup/test_preload_scenario_metadata.py index d736d628ad..52ae274c6d 100644 --- a/tests/unit/setup/test_preload_scenario_metadata.py +++ b/tests/unit/setup/test_preload_scenario_metadata.py @@ -7,19 +7,25 @@ import pytest -from pyrit.setup.initializers.scenarios.preload_scenario_metadata import PreloadScenarioMetadata +from pyrit.setup.initializers.scenarios.preload_scenario_metadata import ( + PreloadScenarioMetadata, +) class TestPreloadScenarioMetadata: """Tests for PreloadScenarioMetadata.initialize_async.""" @pytest.mark.asyncio - async def test_initialize_async_calls_list_metadata(self) -> None: - """``initialize_async`` should fetch the registry and call ``list_metadata`` to warm the cache.""" + async def test_initialize_async_warms_metadata_cache(self) -> None: + """``initialize_async`` should fetch the registry and warm the metadata cache.""" initializer = PreloadScenarioMetadata() mock_registry = MagicMock() - mock_registry.list_metadata.return_value = [MagicMock(), MagicMock(), MagicMock()] + mock_registry.get_all_registered_class_metadata.return_value = [ + MagicMock(), + MagicMock(), + MagicMock(), + ] with patch( "pyrit.setup.initializers.scenarios.preload_scenario_metadata.ScenarioRegistry.get_registry_singleton", @@ -27,15 +33,15 @@ async def test_initialize_async_calls_list_metadata(self) -> None: ): await initializer.initialize_async() - mock_registry.list_metadata.assert_called_once_with() + mock_registry.get_all_registered_class_metadata.assert_called_once_with() @pytest.mark.asyncio async def test_initialize_async_propagates_registry_errors(self) -> None: - """If a scenario fails to instantiate, ``list_metadata`` raises and the initializer surfaces it.""" + """If a scenario fails to instantiate, metadata building raises and the initializer surfaces it.""" initializer = PreloadScenarioMetadata() mock_registry = MagicMock() - mock_registry.list_metadata.side_effect = TypeError("scenario X is not no-arg instantiable") + mock_registry.get_all_registered_class_metadata.side_effect = TypeError("scenario X is not no-arg instantiable") with patch( "pyrit.setup.initializers.scenarios.preload_scenario_metadata.ScenarioRegistry.get_registry_singleton",