From 8876b6f78e77f1fcf66e2a1fffae9b49210c8642 Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Thu, 7 May 2026 16:34:12 -0700 Subject: [PATCH 1/9] feat: add error and retry tracking to AttackResult and ScenarioResult - Add RetryEvent dataclass for capturing Tenacity retry attempts - Add RetryCollector (contextvar-based) to accumulate retry events per-attack - Hook log_exception() to record retries to active collector - Add error_message, error_type, error_traceback, retry_events, total_retries fields to AttackResult and AttackResultEntry - Add error_attack_result_ids pointer to ScenarioResult/ScenarioResultEntry - Create ON_ERROR handler in attack strategy to persist failed AttackResults with error details, even when the attack crashes before returning - Wire RetryCollector lifecycle in pre/post execute event handlers - Link failed attack results to scenario result on incomplete objectives - Add RetryEventResponse, error/retry fields to REST API models (AttackSummary, ScenarioRunSummary, AtomicAttackResults) - Update service layer to populate error/retry info from persisted data - Add Alembic migration for new database columns - Add update_scenario_error_attacks() to MemoryInterface All new fields are nullable/optional for backward compatibility. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pyrit/backend/models/attacks.py | 25 +++++ pyrit/backend/models/scenarios.py | 3 + .../backend/services/scenario_run_service.py | 55 +++++++++ pyrit/exceptions/__init__.py | 10 ++ pyrit/exceptions/exceptions_helpers.py | 7 ++ pyrit/exceptions/retry_collector.py | 104 ++++++++++++++++++ pyrit/executor/attack/core/attack_strategy.py | 73 ++++++++++++ pyrit/executor/core/strategy.py | 7 +- .../a1b2c3d4e5f6_add_error_retry_fields.py | 46 ++++++++ pyrit/memory/memory_interface.py | 40 +++++++ pyrit/memory/memory_models.py | 48 ++++++++ pyrit/models/__init__.py | 1 + pyrit/models/attack_result.py | 10 ++ pyrit/models/retry_event.py | 68 ++++++++++++ pyrit/models/scenario_result.py | 5 + pyrit/scenario/core/scenario.py | 14 +++ 16 files changed, 515 insertions(+), 1 deletion(-) create mode 100644 pyrit/exceptions/retry_collector.py create mode 100644 pyrit/memory/alembic/versions/a1b2c3d4e5f6_add_error_retry_fields.py create mode 100644 pyrit/models/retry_event.py diff --git a/pyrit/backend/models/attacks.py b/pyrit/backend/models/attacks.py index 64d6269e44..e89b98358f 100644 --- a/pyrit/backend/models/attacks.py +++ b/pyrit/backend/models/attacks.py @@ -90,6 +90,20 @@ class TargetInfo(BaseModel): model_name: Optional[str] = Field(None, description="Model or deployment name") +class RetryEventResponse(BaseModel): + """A single retry attempt captured during execution.""" + + timestamp: datetime = Field(..., description="When the retry occurred") + attempt_number: int = Field(..., ge=1, description="Tenacity attempt number (1-based)") + function_name: str = Field(..., description="The retried function name") + exception_type: str = Field("", description="Exception class name") + exception_message: str = Field("", description="Exception message") + component_role: str = Field("", description="Component role from ExecutionContext") + component_name: str | None = Field(None, description="Component class name") + endpoint: str | None = Field(None, description="Target endpoint URL") + elapsed_seconds: float = Field(0.0, ge=0, description="Time since first attempt in seconds") + + class AttackSummary(BaseModel): """Summary view of an attack (for list views, omits full message content).""" @@ -121,6 +135,17 @@ class AttackSummary(BaseModel): created_at: datetime = Field(..., description="Attack creation timestamp") updated_at: datetime = Field(..., description="Last update timestamp") + # Error information + error_message: str | None = Field(None, description="Error message if the attack failed with an exception") + error_type: str | None = Field(None, description="Exception class name (e.g., 'RateLimitError')") + error_traceback: str | None = Field(None, description="Formatted traceback string") + + # Retry information + total_retries: int = Field(0, ge=0, description="Total number of retries during this attack") + retry_events: list[RetryEventResponse] | None = Field( + None, description="Detailed retry events (omitted in list views unless requested)" + ) + # ============================================================================ # Conversation Messages Response diff --git a/pyrit/backend/models/scenarios.py b/pyrit/backend/models/scenarios.py index e628020c2f..4dd8d31a84 100644 --- a/pyrit/backend/models/scenarios.py +++ b/pyrit/backend/models/scenarios.py @@ -98,6 +98,7 @@ class ScenarioRunSummary(BaseModel): created_at: datetime = Field(..., description="When the run was created") updated_at: datetime = Field(..., description="When the run status last changed") error: str | None = Field(None, description="Error message if status is FAILED") + error_type: str | None = Field(None, description="Exception class name if status is FAILED") strategies_used: list[str] = Field(default_factory=list, description="Strategy names that were executed") total_attacks: int = Field(0, ge=0, description="Total number of atomic attacks") completed_attacks: int = Field(0, ge=0, description="Number of attacks that completed") @@ -126,6 +127,8 @@ class AtomicAttackResults(BaseModel): success_count: int = Field(0, ge=0, description="Number of successful attacks") failure_count: int = Field(0, ge=0, description="Number of failed attacks") total_count: int = Field(0, ge=0, description="Total number of attack results") + total_retries: int = Field(0, ge=0, description="Sum of retries across all attacks in this group") + error_count: int = Field(0, ge=0, description="Number of attacks with errors") class ScenarioRunDetail(BaseModel): diff --git a/pyrit/backend/services/scenario_run_service.py b/pyrit/backend/services/scenario_run_service.py index 26f9b21f60..373d76c2c4 100644 --- a/pyrit/backend/services/scenario_run_service.py +++ b/pyrit/backend/services/scenario_run_service.py @@ -397,11 +397,24 @@ def _build_response_from_db(self, *, scenario_result: ScenarioResult) -> Scenari # Clean up finished active tasks after reading the error error = None + error_type = None if active is not None: error = active.error if active.task is not None and active.task.done(): del self._active_tasks[scenario_result_id] + # Fall back to persisted error from failed AttackResult + if not error and getattr(scenario_result, "error_attack_result_ids", None): + memory = CentralMemory.get_memory_instance() + error_ids = scenario_result.error_attack_result_ids + if isinstance(error_ids, list) and error_ids: + error_results = memory.get_attack_results( + attack_result_ids=error_ids[:1] + ) + if error_results: + error = error_results[0].error_message + error_type = error_results[0].error_type + status = ScenarioRunStatus(scenario_result.scenario_run_state) # Build result fields for completed runs @@ -426,6 +439,7 @@ def _build_response_from_db(self, *, scenario_result: ScenarioResult) -> Scenari created_at=scenario_result.creation_time, updated_at=scenario_result.completion_time, error=error, + error_type=error_type, strategies_used=strategies_used, total_attacks=total_attacks, completed_attacks=completed_attacks, @@ -467,6 +481,8 @@ def get_run_results(self, *, scenario_result_id: str) -> ScenarioRunDetail | Non details: list[AttackSummary] = [] success_count = 0 failure_count = 0 + group_total_retries = 0 + group_error_count = 0 for ar in attack_results: score_value = None @@ -478,6 +494,34 @@ def get_run_results(self, *, scenario_result_id: str) -> ScenarioRunDetail | Non last_response_text = str(ar.last_response) timestamp = ar.timestamp or datetime.now(timezone.utc) + + # Build retry event responses if available + retry_event_responses = None + retry_events = getattr(ar, "retry_events", None) + if isinstance(retry_events, list) and retry_events: + from pyrit.backend.models.attacks import RetryEventResponse + + retry_event_responses = [ + RetryEventResponse( + timestamp=evt.timestamp, + attempt_number=evt.attempt_number, + function_name=evt.function_name, + exception_type=evt.exception_type, + exception_message=evt.exception_message, + component_role=evt.component_role, + component_name=evt.component_name, + endpoint=evt.endpoint, + elapsed_seconds=evt.elapsed_seconds, + ) + for evt in retry_events + ] + + # Extract error/retry fields with safe defaults + ar_error_message = ar.error_message if isinstance(ar.error_message, str) else None + ar_error_type = ar.error_type if isinstance(ar.error_type, str) else None + ar_error_traceback = ar.error_traceback if isinstance(ar.error_traceback, str) else None + ar_total_retries = ar.total_retries if isinstance(ar.total_retries, int) else 0 + details.append( AttackSummary( attack_result_id=ar.attack_result_id, @@ -491,6 +535,11 @@ def get_run_results(self, *, scenario_result_id: str) -> ScenarioRunDetail | Non execution_time_ms=ar.execution_time_ms, created_at=timestamp, updated_at=timestamp, + error_message=ar_error_message, + error_type=ar_error_type, + error_traceback=ar_error_traceback, + total_retries=ar_total_retries, + retry_events=retry_event_responses, ) ) @@ -499,6 +548,10 @@ def get_run_results(self, *, scenario_result_id: str) -> ScenarioRunDetail | Non elif ar.outcome == AttackOutcome.FAILURE: failure_count += 1 + group_total_retries += ar_total_retries + if ar_error_message: + group_error_count += 1 + attacks.append( AtomicAttackResults( atomic_attack_name=attack_name, @@ -507,6 +560,8 @@ def get_run_results(self, *, scenario_result_id: str) -> ScenarioRunDetail | Non success_count=success_count, failure_count=failure_count, total_count=len(details), + total_retries=group_total_retries, + error_count=group_error_count, ) ) diff --git a/pyrit/exceptions/__init__.py b/pyrit/exceptions/__init__.py index 9da9650c2d..abd42de031 100644 --- a/pyrit/exceptions/__init__.py +++ b/pyrit/exceptions/__init__.py @@ -27,15 +27,23 @@ set_execution_context, ) from pyrit.exceptions.exceptions_helpers import remove_markdown_json +from pyrit.exceptions.retry_collector import ( + RetryCollector, + clear_retry_collector, + get_retry_collector, + set_retry_collector, +) __all__ = [ "BadRequestException", "clear_execution_context", + "clear_retry_collector", "ComponentRole", "EmptyResponseException", "ExecutionContext", "ExecutionContextManager", "get_execution_context", + "get_retry_collector", "get_retry_max_num_attempts", "handle_bad_request_exception", "InvalidJsonException", @@ -47,6 +55,8 @@ "pyrit_placeholder_retry", "RateLimitException", "remove_markdown_json", + "RetryCollector", "set_execution_context", + "set_retry_collector", "execution_context", ] diff --git a/pyrit/exceptions/exceptions_helpers.py b/pyrit/exceptions/exceptions_helpers.py index 65a7267daf..8c7a3075b6 100644 --- a/pyrit/exceptions/exceptions_helpers.py +++ b/pyrit/exceptions/exceptions_helpers.py @@ -68,6 +68,13 @@ def log_exception(retry_state: RetryCallState) -> None: f"Elapsed time: {elapsed_time} seconds. Total calls: {call_count}" ) + # Record to RetryCollector if one is active + from pyrit.exceptions.retry_collector import get_retry_collector + + collector = get_retry_collector() + if collector: + collector.record(retry_state=retry_state) + def remove_start_md_json(response_msg: str) -> str: """ diff --git a/pyrit/exceptions/retry_collector.py b/pyrit/exceptions/retry_collector.py new file mode 100644 index 0000000000..ad89d05f3b --- /dev/null +++ b/pyrit/exceptions/retry_collector.py @@ -0,0 +1,104 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +"""Contextvar-based retry event collector for capturing Tenacity retry events.""" + +from contextvars import ContextVar +from dataclasses import dataclass, field +from typing import Any, Optional + +from pyrit.models.retry_event import RetryEvent + + +@dataclass +class RetryCollector: + """Collects retry events during attack execution. + + Uses contextvar for thread/task-safe scoping. Each attack execution + creates its own collector so retry events are naturally scoped + per-objective. + """ + + events: list[RetryEvent] = field(default_factory=list) + + def record(self, *, retry_state: Any) -> None: + """Record a retry event from a Tenacity RetryCallState. + + Extracts information from the retry state and the current + ExecutionContext to build a structured RetryEvent. + + Args: + retry_state: The Tenacity RetryCallState from the after callback. + """ + import time + + from pyrit.exceptions.exception_context import get_execution_context + + # Extract basic info from retry_state + call_count = getattr(retry_state, "attempt_number", None) or 0 + start_time = getattr(retry_state, "start_time", None) + elapsed = (time.monotonic() - start_time) if start_time is not None else 0.0 + + fn = getattr(retry_state, "fn", None) + fn_name = getattr(fn, "__name__", "unknown") if fn else "unknown" + + # Extract exception info + exception_type = "" + exception_message = "" + outcome = getattr(retry_state, "outcome", None) + if outcome and getattr(outcome, "failed", False): + exc = outcome.exception() if hasattr(outcome, "exception") else None + if exc: + exception_type = type(exc).__name__ + exception_message = str(exc) + + # Extract context info + component_role = "" + component_name: str | None = None + endpoint: str | None = None + try: + exec_context = get_execution_context() + if exec_context: + component_role = exec_context.component_role.value + component_name = exec_context.component_name + endpoint = exec_context.endpoint + except Exception: + pass + + event = RetryEvent( + attempt_number=call_count, + function_name=fn_name, + exception_type=exception_type, + exception_message=exception_message, + component_role=component_role, + component_name=component_name, + endpoint=endpoint, + elapsed_seconds=round(elapsed, 3), + ) + self.events.append(event) + + +_retry_collector: ContextVar[Optional[RetryCollector]] = ContextVar("retry_collector", default=None) + + +def get_retry_collector() -> Optional[RetryCollector]: + """Get the current retry collector. + + Returns: + The active RetryCollector, or None if not set. + """ + return _retry_collector.get() + + +def set_retry_collector(collector: RetryCollector) -> None: + """Set the current retry collector. + + Args: + collector: The RetryCollector to activate. + """ + _retry_collector.set(collector) + + +def clear_retry_collector() -> None: + """Clear the current retry collector.""" + _retry_collector.set(None) diff --git a/pyrit/executor/attack/core/attack_strategy.py b/pyrit/executor/attack/core/attack_strategy.py index cbe6441c2e..f8eb45b9d1 100644 --- a/pyrit/executor/attack/core/attack_strategy.py +++ b/pyrit/executor/attack/core/attack_strategy.py @@ -6,11 +6,13 @@ import dataclasses import logging # noqa: TC003 import time +import traceback from abc import ABC from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any, ClassVar, Generic, Optional, TypeVar, Union, overload from pyrit.common.logger import logger +from pyrit.exceptions.retry_collector import RetryCollector, clear_retry_collector, get_retry_collector, set_retry_collector from pyrit.executor.attack.core.attack_parameters import AttackParameters, AttackParamsT from pyrit.executor.core import ( Strategy, @@ -134,6 +136,7 @@ def __init__(self, logger: logging.Logger = logger) -> None: self._events = { StrategyEvent.ON_PRE_EXECUTE: self._on_pre_execute, StrategyEvent.ON_POST_EXECUTE: self._on_post_execute, + StrategyEvent.ON_ERROR: self._on_error, } self._memory = CentralMemory.get_memory_instance() @@ -167,6 +170,9 @@ async def _on_pre_execute( """ Handle pre-execution logic before the attack strategy runs. + Sets up execution timing and starts a RetryCollector to capture + retry events during execution. + Args: event_data (StrategyEventData[AttackStrategyContextT, AttackStrategyResultT]): The event data containing context and result. @@ -180,6 +186,10 @@ async def _on_pre_execute( # Initialize start time for execution event_data.context.start_time = time.perf_counter() + # Start a RetryCollector to capture retry events during this attack + collector = RetryCollector() + set_retry_collector(collector) + # Log the start of the attack self._logger.info(f"Starting attack: {event_data.context.objective}") @@ -189,6 +199,8 @@ async def _on_post_execute( """ Handle post-execution logic after the attack strategy has run. + Attaches retry events to the result and persists it to memory. + Args: event_data (StrategyEventData[AttackStrategyContextT, AttackStrategyResultT]): The event data containing context and result. @@ -203,6 +215,13 @@ async def _on_post_execute( execution_time_ms = int((end_time - event_data.context.start_time) * 1000) event_data.result.execution_time_ms = execution_time_ms + # Attach collected retry events to the result + collector = get_retry_collector() + if collector and collector.events: + event_data.result.retry_events = collector.events + event_data.result.total_retries = len(collector.events) + clear_retry_collector() + self._logger.debug(f"Attack execution completed in {execution_time_ms}ms") self._log_attack_outcome(event_data.result) @@ -227,6 +246,60 @@ def _log_attack_outcome(self, result: AttackResult) -> None: self._logger.info(message) + async def _on_error( + self, event_data: StrategyEventData[AttackStrategyContextT, AttackStrategyResultT] + ) -> None: + """ + Handle error during attack execution. + + Creates a failed AttackResult with error details and any retry events + collected during execution, then persists it to memory. + + Args: + event_data (StrategyEventData[AttackStrategyContextT, AttackStrategyResultT]): The event data containing + context, result, and error. + """ + error = event_data.error + context = event_data.context + if not error or not context: + clear_retry_collector() + return + + # Collect retry events before clearing + collector = get_retry_collector() + retry_events = collector.events if collector else [] + clear_retry_collector() + + # Build a conversation_id — use context's if available, otherwise generate one + conversation_id = getattr(context, "conversation_id", None) or str(__import__("uuid").uuid4()) + + from pyrit.identifiers.atomic_attack_identifier import build_atomic_attack_identifier + + error_result = AttackResult( + conversation_id=conversation_id, + objective=context.objective, + outcome=AttackOutcome.FAILURE, + outcome_reason=f"Exception: {type(error).__name__}: {str(error)}", + labels=context.memory_labels, + related_conversations=context.related_conversations, + error_message=str(error), + error_type=type(error).__name__, + error_traceback=traceback.format_exc(), + retry_events=retry_events, + total_retries=len(retry_events), + ) + + end_time = time.perf_counter() + if context.start_time: + error_result.execution_time_ms = int((end_time - context.start_time) * 1000) + + # Store the error attack result ID on the context so scenario-level + # code can link it to the ScenarioResult + context._error_attack_result_id = error_result.attack_result_id + + self._memory.add_attack_results_to_memory(attack_results=[error_result]) + self._logger.error(f"Attack failed with {type(error).__name__}: {error}") + class AttackStrategy(Strategy[AttackStrategyContextT, AttackStrategyResultT], Identifiable, ABC): """ diff --git a/pyrit/executor/core/strategy.py b/pyrit/executor/core/strategy.py index 833ea1e072..59c9046c7d 100644 --- a/pyrit/executor/core/strategy.py +++ b/pyrit/executor/core/strategy.py @@ -361,7 +361,12 @@ async def execute_with_context_async(self, *, context: StrategyContextT) -> Stra else: error_message = f"Strategy execution failed for {self.__class__.__name__}: {str(e)}" - raise RuntimeError(error_message) from e + runtime_error = RuntimeError(error_message) + # Attach the error attack result ID if the ON_ERROR handler created one + error_attack_result_id = getattr(context, "_error_attack_result_id", None) + if error_attack_result_id: + runtime_error.error_attack_result_id = error_attack_result_id # type: ignore[attr-defined] + raise runtime_error from e async def execute_async(self, **kwargs: Any) -> StrategyResultT: """ diff --git a/pyrit/memory/alembic/versions/a1b2c3d4e5f6_add_error_retry_fields.py b/pyrit/memory/alembic/versions/a1b2c3d4e5f6_add_error_retry_fields.py new file mode 100644 index 0000000000..7e244275c0 --- /dev/null +++ b/pyrit/memory/alembic/versions/a1b2c3d4e5f6_add_error_retry_fields.py @@ -0,0 +1,46 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +add error and retry fields to attack and scenario results. + +Revision ID: a1b2c3d4e5f6 +Revises: 108a72344872 +Create Date: 2026-05-08 00:00:00.000000 +""" + +from collections.abc import Sequence + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "a1b2c3d4e5f6" +down_revision: str | None = "108a72344872" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + """Apply this schema upgrade.""" + # Error fields on AttackResultEntries + op.add_column("AttackResultEntries", sa.Column("error_message", sa.Unicode(), nullable=True)) + op.add_column("AttackResultEntries", sa.Column("error_type", sa.String(), nullable=True)) + op.add_column("AttackResultEntries", sa.Column("error_traceback", sa.Unicode(), nullable=True)) + + # Retry fields on AttackResultEntries + op.add_column("AttackResultEntries", sa.Column("retry_events_json", sa.Unicode(), nullable=True)) + op.add_column("AttackResultEntries", sa.Column("total_retries", sa.INTEGER(), nullable=True)) + + # Error pointer on ScenarioResultEntries + op.add_column("ScenarioResultEntries", sa.Column("error_attack_result_ids_json", sa.Unicode(), nullable=True)) + + +def downgrade() -> None: + """Revert this schema upgrade.""" + op.drop_column("AttackResultEntries", "error_message") + op.drop_column("AttackResultEntries", "error_type") + op.drop_column("AttackResultEntries", "error_traceback") + op.drop_column("AttackResultEntries", "retry_events_json") + op.drop_column("AttackResultEntries", "total_retries") + op.drop_column("ScenarioResultEntries", "error_attack_result_ids_json") diff --git a/pyrit/memory/memory_interface.py b/pyrit/memory/memory_interface.py index b28d05976e..80938cb74e 100644 --- a/pyrit/memory/memory_interface.py +++ b/pyrit/memory/memory_interface.py @@ -2060,6 +2060,46 @@ def update_scenario_run_state(self, *, scenario_result_id: str, scenario_run_sta ) raise + def update_scenario_error_attacks( + self, *, scenario_result_id: str, error_attack_result_ids: list[str] + ) -> bool: + """ + Update the error attack result IDs on an existing scenario result. + + This links failed AttackResults to the ScenarioResult so the REST API + can quickly find error details without scanning all attacks. + + Args: + scenario_result_id: The ID of the scenario result to update. + error_attack_result_ids: IDs of AttackResults that contain error information. + + Returns: + True if the update was successful, False otherwise. + """ + try: + scenario_results = self.get_scenario_results(scenario_result_ids=[scenario_result_id]) + + if not scenario_results: + logger.error(f"Scenario result with ID {scenario_result_id} not found in memory") + return False + + scenario_result = scenario_results[0] + scenario_result.error_attack_result_ids = error_attack_result_ids + + entry = ScenarioResultEntry(entry=scenario_result) + self._update_entry(entry) + + logger.info( + f"Updated scenario {scenario_result_id} with {len(error_attack_result_ids)} error attack result(s)" + ) + return True + + except Exception as e: + logger.exception( + f"Failed to update scenario {scenario_result_id} error attacks: {str(e)}" + ) + raise + def get_scenario_results( self, *, diff --git a/pyrit/memory/memory_models.py b/pyrit/memory/memory_models.py index 08bd0e466a..69ad955b4f 100644 --- a/pyrit/memory/memory_models.py +++ b/pyrit/memory/memory_models.py @@ -736,6 +736,15 @@ class AttackResultEntry(Base): # Nullable for backwards compatibility with existing databases pyrit_version = mapped_column(String, nullable=True) + # Error information (populated when attack fails with exception) + error_message = mapped_column(Unicode, nullable=True) + error_type = mapped_column(String, nullable=True) + error_traceback = mapped_column(Unicode, nullable=True) + + # Retry events (JSON-serialized list of RetryEvent dicts) + retry_events_json: Mapped[Optional[str]] = mapped_column(Unicode, nullable=True) + total_retries = mapped_column(INTEGER, nullable=True, default=0) + last_response: Mapped[Optional["PromptMemoryEntry"]] = relationship( "PromptMemoryEntry", foreign_keys=[last_response_id], @@ -798,6 +807,18 @@ def __init__(self, *, entry: AttackResult) -> None: self.timestamp = entry.timestamp or datetime.now(tz=timezone.utc) self.pyrit_version = pyrit.__version__ + # Error information + self.error_message = entry.error_message + self.error_type = entry.error_type + # Truncate traceback to 10KB to avoid excessive DB storage + self.error_traceback = entry.error_traceback[:10240] if entry.error_traceback else None + + # Retry events + self.retry_events_json = ( + json.dumps([evt.to_dict() for evt in entry.retry_events]) if entry.retry_events else None + ) + self.total_retries = entry.total_retries + @staticmethod def _get_id_as_uuid(obj: Any) -> Optional[uuid.UUID]: """ @@ -883,6 +904,14 @@ def get_attack_result(self) -> AttackResult: attack_identifier=ComponentIdentifier.from_dict(self.attack_identifier), ) + # Deserialize retry events from JSON + retry_events = [] + if self.retry_events_json: + from pyrit.models.retry_event import RetryEvent + + for evt_dict in json.loads(self.retry_events_json): + retry_events.append(RetryEvent.from_dict(evt_dict)) + return AttackResult( conversation_id=self.conversation_id, attack_result_id=str(self.id), @@ -898,6 +927,11 @@ def get_attack_result(self) -> AttackResult: metadata=self.attack_metadata or {}, timestamp=_ensure_utc(self.timestamp) or datetime.now(tz=timezone.utc), labels=self.labels or {}, + error_message=self.error_message, + error_type=self.error_type, + error_traceback=self.error_traceback, + retry_events=retry_events, + total_retries=self.total_retries or 0, ) @@ -959,6 +993,9 @@ class ScenarioResultEntry(Base): completion_time = mapped_column(DateTime, nullable=False) timestamp = mapped_column(DateTime, nullable=False) + # Pointer to failed attack result(s) — avoids scanning all attacks for error info + error_attack_result_ids_json: Mapped[Optional[str]] = mapped_column(Unicode, nullable=True) + def __init__(self, *, entry: ScenarioResult) -> None: """ Initialize a ScenarioResultEntry from a ScenarioResult object. @@ -1004,6 +1041,11 @@ def __init__(self, *, entry: ScenarioResult) -> None: # Serialize display_group_map if present self.display_group_map_json = json.dumps(entry._display_group_map) if entry._display_group_map else None + # Serialize error_attack_result_ids if present + self.error_attack_result_ids_json = ( + json.dumps(entry.error_attack_result_ids) if entry.error_attack_result_ids else None + ) + self.timestamp = datetime.now(tz=timezone.utc) def get_scenario_result(self) -> ScenarioResult: @@ -1045,6 +1087,11 @@ def get_scenario_result(self) -> ScenarioResult: if self.display_group_map_json: display_group_map = json.loads(self.display_group_map_json) + # Deserialize error_attack_result_ids if stored + error_attack_result_ids: list[str] | None = None + if self.error_attack_result_ids_json: + error_attack_result_ids = json.loads(self.error_attack_result_ids_json) + return ScenarioResult( id=self.id, scenario_identifier=scenario_identifier, @@ -1057,6 +1104,7 @@ def get_scenario_result(self) -> ScenarioResult: number_tries=self.number_tries, completion_time=self.completion_time, display_group_map=display_group_map, + error_attack_result_ids=error_attack_result_ids, ) def get_conversation_ids_by_attack_name(self) -> dict[str, list[str]]: diff --git a/pyrit/models/__init__.py b/pyrit/models/__init__.py index 5a85226322..73f08007c5 100644 --- a/pyrit/models/__init__.py +++ b/pyrit/models/__init__.py @@ -33,6 +33,7 @@ ) from pyrit.models.message_piece import MessagePiece, sort_message_pieces from pyrit.models.question_answering import QuestionAnsweringDataset, QuestionAnsweringEntry, QuestionChoice +from pyrit.models.retry_event import RetryEvent from pyrit.models.scenario_result import ScenarioIdentifier, ScenarioResult from pyrit.models.score import Score, ScoreType, UnvalidatedScore diff --git a/pyrit/models/attack_result.py b/pyrit/models/attack_result.py index be51668fa1..9e375f8300 100644 --- a/pyrit/models/attack_result.py +++ b/pyrit/models/attack_result.py @@ -10,6 +10,7 @@ from enum import Enum from typing import TYPE_CHECKING, Any, Optional, TypeVar +from pyrit.models.retry_event import RetryEvent from pyrit.models.strategy_result import StrategyResult if TYPE_CHECKING: @@ -94,6 +95,15 @@ class AttackResult(StrategyResult): # labels associated with this attack result labels: dict[str, str] = field(default_factory=dict) + # Error information (populated when attack fails with exception) + error_message: str | None = None + error_type: str | None = None + error_traceback: str | None = None + + # Retry tracking + retry_events: list[RetryEvent] = field(default_factory=list) + total_retries: int = 0 + @property def attack_identifier(self) -> Optional[ComponentIdentifier]: """ diff --git a/pyrit/models/retry_event.py b/pyrit/models/retry_event.py new file mode 100644 index 0000000000..f2f918cbc5 --- /dev/null +++ b/pyrit/models/retry_event.py @@ -0,0 +1,68 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +"""Data model for capturing individual retry events during execution.""" + +from dataclasses import dataclass, field +from datetime import datetime, timezone + + +@dataclass +class RetryEvent: + """A single retry attempt captured during attack execution. + + Records structured information about a Tenacity retry event, including + which component was retrying, what exception triggered the retry, and + timing information. These events are collected by a RetryCollector and + attached to AttackResult objects for persistence and REST API exposure. + """ + + timestamp: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) + attempt_number: int = 0 + function_name: str = "" + exception_type: str = "" + exception_message: str = "" + component_role: str = "" + component_name: str | None = None + endpoint: str | None = None + elapsed_seconds: float = 0.0 + + def to_dict(self) -> dict: + """Serialize to a dictionary suitable for JSON storage. + + Returns: + dict: Dictionary representation of the retry event. + """ + return { + "timestamp": self.timestamp.isoformat(), + "attempt_number": self.attempt_number, + "function_name": self.function_name, + "exception_type": self.exception_type, + "exception_message": self.exception_message, + "component_role": self.component_role, + "component_name": self.component_name, + "endpoint": self.endpoint, + "elapsed_seconds": self.elapsed_seconds, + } + + @classmethod + def from_dict(cls, data: dict) -> "RetryEvent": + """Deserialize from a dictionary. + + Args: + data: Dictionary representation of a retry event. + + Returns: + RetryEvent: Deserialized retry event. + """ + return cls( + timestamp=datetime.fromisoformat(data["timestamp"]), + attempt_number=data.get("attempt_number", 0), + function_name=data.get("function_name", ""), + exception_type=data.get("exception_type", ""), + exception_message=data.get("exception_message", ""), + component_role=data.get("component_role", ""), + component_name=data.get("component_name"), + endpoint=data.get("endpoint"), + elapsed_seconds=data.get("elapsed_seconds", 0.0), + ) diff --git a/pyrit/models/scenario_result.py b/pyrit/models/scenario_result.py index 496b1d73a9..6a657a1d36 100644 --- a/pyrit/models/scenario_result.py +++ b/pyrit/models/scenario_result.py @@ -69,6 +69,7 @@ def __init__( number_tries: int = 0, id: uuid.UUID | None = None, # noqa: A002 display_group_map: dict[str, str] | None = None, + error_attack_result_ids: list[str] | None = None, ) -> None: """ Initialize a scenario result. @@ -87,6 +88,9 @@ def __init__( display_group_map (Optional[dict[str, str]]): Optional mapping of atomic_attack_name → display group label. Used by the console printer to aggregate results for user-facing output. + error_attack_result_ids (Optional[list[str]]): IDs of AttackResults that + contain error information. Used for quick error lookup without scanning + all attack results. """ self.id = id if id is not None else uuid.uuid4() @@ -103,6 +107,7 @@ def __init__( self.completion_time = completion_time if completion_time is not None else datetime.now(timezone.utc) self.number_tries = number_tries self._display_group_map = display_group_map or {} + self.error_attack_result_ids = error_attack_result_ids or [] @property def display_group_map(self) -> dict[str, str]: diff --git a/pyrit/scenario/core/scenario.py b/pyrit/scenario/core/scenario.py index 089bed4267..34618cba1c 100644 --- a/pyrit/scenario/core/scenario.py +++ b/pyrit/scenario/core/scenario.py @@ -1095,6 +1095,20 @@ async def _execute_scenario_async(self) -> ScenarioResult: for obj, exc in atomic_results.incomplete_objectives: logger.error(f" Incomplete objective '{obj[:50]}...': {str(exc)}") + # Collect error attack result IDs from the exceptions + error_ids = [] + for _, exc in atomic_results.incomplete_objectives: + error_id = getattr(exc, "error_attack_result_id", None) + if error_id: + error_ids.append(error_id) + + # Link error attack results to the scenario result + if error_ids: + self._memory.update_scenario_error_attacks( + scenario_result_id=scenario_result_id, + error_attack_result_ids=error_ids, + ) + # Mark scenario as failed self._memory.update_scenario_run_state( scenario_result_id=scenario_result_id, From 95be13eccb7c5202df35171b2e2b3d9a2806327c Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Thu, 7 May 2026 16:40:33 -0700 Subject: [PATCH 2/9] test: add tests for error/retry tracking - RetryEvent: serialization, round-trip, defaults, missing fields - RetryCollector: contextvar lifecycle, record(), asyncio task isolation - AttackResult: error field defaults, storage, DB round-trip, truncation - ScenarioResult: error_attack_result_ids defaults and storage Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- tests/unit/exceptions/test_retry_collector.py | 112 ++++++++++++++ tests/unit/models/test_attack_result.py | 145 ++++++++++++++++++ tests/unit/models/test_retry_event.py | 123 +++++++++++++++ tests/unit/models/test_scenario_result.py | 21 +++ 4 files changed, 401 insertions(+) create mode 100644 tests/unit/exceptions/test_retry_collector.py create mode 100644 tests/unit/models/test_retry_event.py diff --git a/tests/unit/exceptions/test_retry_collector.py b/tests/unit/exceptions/test_retry_collector.py new file mode 100644 index 0000000000..14db749b65 --- /dev/null +++ b/tests/unit/exceptions/test_retry_collector.py @@ -0,0 +1,112 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import asyncio + +from pyrit.exceptions.retry_collector import ( + RetryCollector, + clear_retry_collector, + get_retry_collector, + set_retry_collector, +) + + +class TestRetryCollector: + """Tests for the RetryCollector and its contextvar helpers.""" + + def test_collector_starts_empty(self) -> None: + """A new RetryCollector has no events.""" + c = RetryCollector() + assert c.events == [] + + def test_contextvar_default_is_none(self) -> None: + """get_retry_collector returns None when no collector is set.""" + clear_retry_collector() + assert get_retry_collector() is None + + def test_set_and_get(self) -> None: + """set_retry_collector makes it visible to get_retry_collector.""" + c = RetryCollector() + set_retry_collector(c) + assert get_retry_collector() is c + clear_retry_collector() + + def test_clear(self) -> None: + """clear_retry_collector resets the contextvar to None.""" + c = RetryCollector() + set_retry_collector(c) + clear_retry_collector() + assert get_retry_collector() is None + + def test_record_extracts_exception_info(self) -> None: + """record() extracts exception type and message from retry_state.""" + from unittest.mock import MagicMock + + c = RetryCollector() + + # Build a mock retry_state matching Tenacity's RetryCallState + retry_state = MagicMock() + retry_state.attempt_number = 2 + retry_state.start_time = 0.0 + retry_state.fn = MagicMock() + retry_state.fn.__name__ = "my_function" + + # Mock outcome with .failed=True and .exception() returning a ValueError + outcome = MagicMock() + outcome.failed = True + outcome.exception.return_value = ValueError("test error") + retry_state.outcome = outcome + + c.record(retry_state=retry_state) + + assert len(c.events) == 1 + evt = c.events[0] + assert evt.attempt_number == 2 + assert evt.function_name == "my_function" + assert evt.exception_type == "ValueError" + assert evt.exception_message == "test error" + + def test_record_multiple_events(self) -> None: + """record() accumulates events.""" + from unittest.mock import MagicMock + + c = RetryCollector() + + for i in range(3): + retry_state = MagicMock() + retry_state.attempt_number = i + 1 + retry_state.start_time = 0.0 + retry_state.fn = MagicMock() + retry_state.fn.__name__ = f"fn_{i}" + outcome = MagicMock() + outcome.failed = True + outcome.exception.return_value = RuntimeError(f"error {i}") + retry_state.outcome = outcome + c.record(retry_state=retry_state) + + assert len(c.events) == 3 + assert c.events[0].function_name == "fn_0" + assert c.events[2].function_name == "fn_2" + + def test_contextvar_isolation_across_tasks(self) -> None: + """Each asyncio task gets its own contextvar value.""" + results: dict[str, bool] = {} + + async def task_a() -> None: + c = RetryCollector() + set_retry_collector(c) + await asyncio.sleep(0.01) + results["a_has_collector"] = get_retry_collector() is c + clear_retry_collector() + + async def task_b() -> None: + await asyncio.sleep(0.005) + results["b_sees_none"] = get_retry_collector() is None + + async def run() -> None: + clear_retry_collector() + await asyncio.gather(task_a(), task_b()) + + asyncio.run(run()) + assert results.get("a_has_collector") is True + assert results.get("b_sees_none") is True diff --git a/tests/unit/models/test_attack_result.py b/tests/unit/models/test_attack_result.py index 12dd64e260..874d924846 100644 --- a/tests/unit/models/test_attack_result.py +++ b/tests/unit/models/test_attack_result.py @@ -8,6 +8,7 @@ from pyrit.identifiers.atomic_attack_identifier import build_atomic_attack_identifier from pyrit.memory.memory_models import AttackResultEntry from pyrit.models.attack_result import AttackOutcome, AttackResult +from pyrit.models.retry_event import RetryEvent class TestAttackResultDeprecation: @@ -206,3 +207,147 @@ def test_naive_entry_timestamp_is_normalized_to_utc_on_hydration(self) -> None: assert hydrated.timestamp is not None assert hydrated.timestamp.tzinfo is timezone.utc assert hydrated.timestamp.replace(tzinfo=None) == datetime(2026, 4, 17, 12, 0, 0) # noqa: DTZ001 + + +class TestAttackResultErrorFields: + """Tests for the error and retry fields on AttackResult.""" + + def test_error_fields_default_to_none(self) -> None: + """AttackResult without error fields defaults to None/empty.""" + result = AttackResult(conversation_id="c1", objective="test") + assert result.error_message is None + assert result.error_type is None + assert result.error_traceback is None + assert result.retry_events == [] + assert result.total_retries == 0 + + def test_error_fields_set_correctly(self) -> None: + """AttackResult stores error fields when provided.""" + result = AttackResult( + conversation_id="c1", + objective="test", + error_message="Connection refused", + error_type="ConnectionError", + error_traceback="Traceback (most recent call last):\n ...", + total_retries=3, + ) + assert result.error_message == "Connection refused" + assert result.error_type == "ConnectionError" + assert "Traceback" in result.error_traceback + assert result.total_retries == 3 + + def test_retry_events_stored_on_result(self) -> None: + """AttackResult stores retry events.""" + events = [ + RetryEvent(attempt_number=1, function_name="fn1", exception_type="TimeoutError"), + RetryEvent(attempt_number=2, function_name="fn1", exception_type="TimeoutError"), + ] + result = AttackResult( + conversation_id="c1", + objective="test", + retry_events=events, + total_retries=2, + ) + assert len(result.retry_events) == 2 + assert result.retry_events[0].attempt_number == 1 + assert result.retry_events[1].attempt_number == 2 + + +class TestAttackResultErrorRoundTrip: + """Tests that error/retry fields survive the AttackResult -> AttackResultEntry -> AttackResult round-trip.""" + + def test_error_fields_roundtrip(self) -> None: + """Error fields are serialized to entry and deserialized back.""" + original = AttackResult( + conversation_id="c1", + objective="test", + outcome=AttackOutcome.FAILURE, + error_message="Rate limit hit", + error_type="RateLimitError", + error_traceback="Traceback...\n File ...", + total_retries=5, + ) + entry = AttackResultEntry(entry=original) + + # Verify serialized values on entry + assert entry.error_message == "Rate limit hit" + assert entry.error_type == "RateLimitError" + assert entry.error_traceback == "Traceback...\n File ..." + assert entry.total_retries == 5 + + # Deserialize back + hydrated = entry.get_attack_result() + assert hydrated.error_message == "Rate limit hit" + assert hydrated.error_type == "RateLimitError" + assert hydrated.error_traceback == "Traceback...\n File ..." + assert hydrated.total_retries == 5 + + def test_retry_events_roundtrip(self) -> None: + """Retry events are serialized to JSON and deserialized back.""" + events = [ + RetryEvent( + attempt_number=1, + function_name="send_async", + exception_type="TimeoutError", + exception_message="timed out", + component_role="target", + component_name="AzureTarget", + endpoint="https://api.azure.com", + elapsed_seconds=5.5, + ), + RetryEvent( + attempt_number=2, + function_name="send_async", + exception_type="RateLimitError", + exception_message="429", + elapsed_seconds=10.0, + ), + ] + original = AttackResult( + conversation_id="c1", + objective="test", + retry_events=events, + total_retries=2, + ) + entry = AttackResultEntry(entry=original) + assert entry.retry_events_json is not None + + hydrated = entry.get_attack_result() + assert len(hydrated.retry_events) == 2 + assert hydrated.retry_events[0].attempt_number == 1 + assert hydrated.retry_events[0].function_name == "send_async" + assert hydrated.retry_events[0].exception_type == "TimeoutError" + assert hydrated.retry_events[0].component_name == "AzureTarget" + assert hydrated.retry_events[1].attempt_number == 2 + assert hydrated.retry_events[1].exception_type == "RateLimitError" + assert hydrated.total_retries == 2 + + def test_no_error_fields_roundtrip(self) -> None: + """AttackResult without error fields round-trips cleanly.""" + original = AttackResult( + conversation_id="c1", + objective="test", + outcome=AttackOutcome.SUCCESS, + ) + entry = AttackResultEntry(entry=original) + assert entry.error_message is None + assert entry.error_type is None + assert entry.retry_events_json is None + assert entry.total_retries == 0 + + hydrated = entry.get_attack_result() + assert hydrated.error_message is None + assert hydrated.error_type is None + assert hydrated.retry_events == [] + assert hydrated.total_retries == 0 + + def test_traceback_truncation(self) -> None: + """Very long tracebacks are truncated to 10KB.""" + long_traceback = "x" * 20000 + original = AttackResult( + conversation_id="c1", + objective="test", + error_traceback=long_traceback, + ) + entry = AttackResultEntry(entry=original) + assert len(entry.error_traceback) == 10240 diff --git a/tests/unit/models/test_retry_event.py b/tests/unit/models/test_retry_event.py new file mode 100644 index 0000000000..09f20dfca7 --- /dev/null +++ b/tests/unit/models/test_retry_event.py @@ -0,0 +1,123 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from datetime import datetime, timezone + +from pyrit.models.retry_event import RetryEvent + + +class TestRetryEvent: + """Tests for the RetryEvent dataclass.""" + + def test_defaults(self) -> None: + """RetryEvent constructed with minimal args gets correct defaults.""" + evt = RetryEvent(attempt_number=1, function_name="test_fn") + assert evt.attempt_number == 1 + assert evt.function_name == "test_fn" + assert evt.exception_type == "" + assert evt.exception_message == "" + assert evt.component_role == "" + assert evt.component_name is None + assert evt.endpoint is None + assert evt.elapsed_seconds == 0.0 + assert evt.timestamp is not None + assert evt.timestamp.tzinfo is timezone.utc + + def test_full_construction(self) -> None: + """RetryEvent constructed with all args stores them correctly.""" + ts = datetime(2026, 5, 7, 12, 0, 0, tzinfo=timezone.utc) + evt = RetryEvent( + attempt_number=3, + function_name="send_prompt_async", + exception_type="RateLimitError", + exception_message="Rate limit exceeded", + component_role="objective_target", + component_name="OpenAIChatTarget", + endpoint="https://api.openai.com/v1/chat", + elapsed_seconds=5.123, + timestamp=ts, + ) + assert evt.attempt_number == 3 + assert evt.function_name == "send_prompt_async" + assert evt.exception_type == "RateLimitError" + assert evt.exception_message == "Rate limit exceeded" + assert evt.component_role == "objective_target" + assert evt.component_name == "OpenAIChatTarget" + assert evt.endpoint == "https://api.openai.com/v1/chat" + assert evt.elapsed_seconds == 5.123 + assert evt.timestamp == ts + + def test_to_dict(self) -> None: + """to_dict returns a JSON-serializable dictionary.""" + evt = RetryEvent( + attempt_number=2, + function_name="fn", + exception_type="ValueError", + exception_message="bad input", + component_role="scorer", + component_name="TFScorer", + endpoint="https://example.com", + elapsed_seconds=1.5, + ) + d = evt.to_dict() + assert d["attempt_number"] == 2 + assert d["function_name"] == "fn" + assert d["exception_type"] == "ValueError" + assert d["exception_message"] == "bad input" + assert d["component_role"] == "scorer" + assert d["component_name"] == "TFScorer" + assert d["endpoint"] == "https://example.com" + assert d["elapsed_seconds"] == 1.5 + assert "timestamp" in d + + def test_from_dict_roundtrip(self) -> None: + """from_dict correctly reconstructs a RetryEvent from to_dict output.""" + original = RetryEvent( + attempt_number=1, + function_name="call_target", + exception_type="TimeoutError", + exception_message="Request timed out", + component_role="objective_target", + component_name="AzureTarget", + endpoint="https://azure.openai.com", + elapsed_seconds=10.0, + ) + d = original.to_dict() + restored = RetryEvent.from_dict(d) + + assert restored.attempt_number == original.attempt_number + assert restored.function_name == original.function_name + assert restored.exception_type == original.exception_type + assert restored.exception_message == original.exception_message + assert restored.component_role == original.component_role + assert restored.component_name == original.component_name + assert restored.endpoint == original.endpoint + assert restored.elapsed_seconds == original.elapsed_seconds + + def test_from_dict_missing_optional_fields(self) -> None: + """from_dict handles missing optional fields gracefully.""" + d = { + "attempt_number": 1, + "function_name": "fn", + "timestamp": "2026-05-07T12:00:00+00:00", + } + evt = RetryEvent.from_dict(d) + assert evt.attempt_number == 1 + assert evt.function_name == "fn" + assert evt.exception_type == "" + assert evt.component_name is None + assert evt.endpoint is None + assert evt.elapsed_seconds == 0.0 + + def test_from_dict_timestamp_parsing(self) -> None: + """from_dict correctly parses ISO format timestamp.""" + d = { + "attempt_number": 1, + "function_name": "fn", + "timestamp": "2026-05-07T12:30:00+00:00", + } + evt = RetryEvent.from_dict(d) + assert evt.timestamp.year == 2026 + assert evt.timestamp.month == 5 + assert evt.timestamp.hour == 12 + assert evt.timestamp.minute == 30 diff --git a/tests/unit/models/test_scenario_result.py b/tests/unit/models/test_scenario_result.py index 68a1cef1a9..02af031429 100644 --- a/tests/unit/models/test_scenario_result.py +++ b/tests/unit/models/test_scenario_result.py @@ -165,3 +165,24 @@ def test_normalize_scenario_name_already_pascal(self): def test_normalize_scenario_name_mixed_case_with_underscore(self): assert ScenarioResult.normalize_scenario_name("Content_harms") == "Content_harms" + + def test_error_attack_result_ids_defaults_to_empty(self): + """error_attack_result_ids defaults to empty list.""" + sr = ScenarioResult( + scenario_identifier=_make_scenario_identifier(), + objective_target_identifier=ComponentIdentifier.from_dict({}), + attack_results={}, + objective_scorer_identifier=ComponentIdentifier.from_dict({}), + ) + assert sr.error_attack_result_ids == [] + + def test_error_attack_result_ids_stored(self): + """error_attack_result_ids are stored correctly.""" + sr = ScenarioResult( + scenario_identifier=_make_scenario_identifier(), + objective_target_identifier=ComponentIdentifier.from_dict({}), + attack_results={}, + objective_scorer_identifier=ComponentIdentifier.from_dict({}), + error_attack_result_ids=["id-1", "id-2"], + ) + assert sr.error_attack_result_ids == ["id-1", "id-2"] From b49c1f0de720b4ff504f49a523937a91aed52112 Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Mon, 11 May 2026 15:29:16 -0700 Subject: [PATCH 3/9] adding attack status --- .../src/components/History/AttackTable.tsx | 5 +++- .../components/History/HistoryFiltersBar.tsx | 2 ++ frontend/src/types/index.ts | 2 +- pyrit/analytics/result_analysis.py | 9 +++++- pyrit/backend/models/attacks.py | 4 +-- pyrit/backend/routes/attacks.py | 2 +- pyrit/backend/services/attack_service.py | 3 +- .../backend/services/scenario_run_service.py | 4 +-- pyrit/exceptions/retry_collector.py | 12 +++++--- pyrit/executor/attack/core/attack_strategy.py | 19 +++++++----- pyrit/memory/memory_interface.py | 8 ++--- pyrit/memory/memory_models.py | 2 +- pyrit/models/__init__.py | 2 +- pyrit/models/attack_result.py | 3 ++ pyrit/models/retry_event.py | 9 ++++-- tests/unit/analytics/test_result_analysis.py | 30 ++++++++++++++----- tests/unit/backend/test_attack_service.py | 13 ++++++++ 17 files changed, 89 insertions(+), 40 deletions(-) diff --git a/frontend/src/components/History/AttackTable.tsx b/frontend/src/components/History/AttackTable.tsx index 30d52998b6..16d052d7c7 100644 --- a/frontend/src/components/History/AttackTable.tsx +++ b/frontend/src/components/History/AttackTable.tsx @@ -16,6 +16,7 @@ import { CheckmarkCircleRegular, DismissCircleRegular, QuestionCircleRegular, + ErrorCircleRegular, } from '@fluentui/react-icons' import type { AttackSummary } from '../../types' import { useAttackHistoryStyles } from './AttackHistory.styles' @@ -23,12 +24,14 @@ import { useAttackHistoryStyles } from './AttackHistory.styles' const OUTCOME_ICONS: Record = { success: , failure: , + error: , undetermined: , } -const OUTCOME_COLORS: Record = { +const OUTCOME_COLORS: Record = { success: 'success', failure: 'danger', + error: 'warning', undetermined: 'informative', } diff --git a/frontend/src/components/History/HistoryFiltersBar.tsx b/frontend/src/components/History/HistoryFiltersBar.tsx index 053fa8a7d3..7165cd3960 100644 --- a/frontend/src/components/History/HistoryFiltersBar.tsx +++ b/frontend/src/components/History/HistoryFiltersBar.tsx @@ -21,6 +21,7 @@ const NO_CONVERTERS_SENTINEL = '__no_converters__' const OUTCOME_LABELS: Record = { success: 'Success', failure: 'Failure', + error: 'Error', undetermined: 'Undetermined', } @@ -200,6 +201,7 @@ export default function HistoryFiltersBar({ + | null target?: TargetInfo | null converters: string[] - outcome?: 'undetermined' | 'success' | 'failure' | null + outcome?: 'undetermined' | 'success' | 'failure' | 'error' | null last_message_preview?: string | null message_count: number related_conversation_ids: string[] diff --git a/pyrit/analytics/result_analysis.py b/pyrit/analytics/result_analysis.py index a403d1aa37..d2e998ee94 100644 --- a/pyrit/analytics/result_analysis.py +++ b/pyrit/analytics/result_analysis.py @@ -17,9 +17,10 @@ class AttackStats: successes: int failures: int undetermined: int + errors: int -def _compute_stats(successes: int, failures: int, undetermined: int) -> AttackStats: +def _compute_stats(successes: int, failures: int, undetermined: int, errors: int) -> AttackStats: total_decided = successes + failures success_rate = successes / total_decided if total_decided > 0 else None return AttackStats( @@ -28,6 +29,7 @@ def _compute_stats(successes: int, failures: int, undetermined: int) -> AttackSt successes=successes, failures=failures, undetermined=undetermined, + errors=errors, ) @@ -71,6 +73,9 @@ def analyze_results(attack_results: list[AttackResult]) -> dict[str, AttackStats elif outcome == AttackOutcome.FAILURE: overall_counts["failures"] += 1 by_type_counts[attack_type]["failures"] += 1 + elif outcome == AttackOutcome.ERROR: + overall_counts["errors"] += 1 + by_type_counts[attack_type]["errors"] += 1 else: overall_counts["undetermined"] += 1 by_type_counts[attack_type]["undetermined"] += 1 @@ -79,6 +84,7 @@ def analyze_results(attack_results: list[AttackResult]) -> dict[str, AttackStats successes=overall_counts["successes"], failures=overall_counts["failures"], undetermined=overall_counts["undetermined"], + errors=overall_counts["errors"], ) by_type_stats = { @@ -86,6 +92,7 @@ def analyze_results(attack_results: list[AttackResult]) -> dict[str, AttackStats successes=counts["successes"], failures=counts["failures"], undetermined=counts["undetermined"], + errors=counts["errors"], ) for attack_type, counts in by_type_counts.items() } diff --git a/pyrit/backend/models/attacks.py b/pyrit/backend/models/attacks.py index e89b98358f..2f98f78b7e 100644 --- a/pyrit/backend/models/attacks.py +++ b/pyrit/backend/models/attacks.py @@ -116,7 +116,7 @@ class AttackSummary(BaseModel): default_factory=list, description="Request converter class names applied in this attack" ) objective: str = Field("", description="Natural-language description of the attacker's objective") - outcome: Optional[Literal["undetermined", "success", "failure"]] = Field( + outcome: Optional[Literal["undetermined", "success", "failure", "error"]] = Field( None, description="Attack outcome (null if not yet determined)" ) outcome_reason: str | None = Field(None, description="Reason for the outcome") @@ -261,7 +261,7 @@ class CreateAttackResponse(BaseModel): class UpdateAttackRequest(BaseModel): """Request to update an attack's outcome.""" - outcome: Literal["undetermined", "success", "failure"] = Field(..., description="Updated attack outcome") + outcome: Literal["undetermined", "success", "failure", "error"] = Field(..., description="Updated attack outcome") # ============================================================================ diff --git a/pyrit/backend/routes/attacks.py b/pyrit/backend/routes/attacks.py index 83f8ece989..2737163db6 100644 --- a/pyrit/backend/routes/attacks.py +++ b/pyrit/backend/routes/attacks.py @@ -93,7 +93,7 @@ async def list_attacks( description="Filter by converter presence. true = attacks with at least one converter; " "false = attacks with no converters. Omit for no filter.", ), - outcome: Optional[Literal["undetermined", "success", "failure"]] = Query(None, description="Filter by outcome"), + outcome: Optional[Literal["undetermined", "success", "failure", "error"]] = Query(None, description="Filter by outcome"), label: Optional[list[str]] = Query( None, description="Filter by labels (format: key:value). May be specified multiple times; " diff --git a/pyrit/backend/services/attack_service.py b/pyrit/backend/services/attack_service.py index f2da9ab7e9..5cfd83a7ae 100644 --- a/pyrit/backend/services/attack_service.py +++ b/pyrit/backend/services/attack_service.py @@ -86,7 +86,7 @@ async def list_attacks_async( converter_types: Optional[Sequence[str]] = None, converter_types_match: Literal["any", "all"] = "all", has_converters: Optional[bool] = None, - outcome: Optional[Literal["undetermined", "success", "failure"]] = None, + outcome: Optional[Literal["undetermined", "success", "failure", "error"]] = None, labels: Optional[dict[str, str | Sequence[str]]] = None, min_turns: Optional[int] = None, max_turns: Optional[int] = None, @@ -370,6 +370,7 @@ async def update_attack_async( "undetermined": AttackOutcome.UNDETERMINED, "success": AttackOutcome.SUCCESS, "failure": AttackOutcome.FAILURE, + "error": AttackOutcome.ERROR, } new_outcome = outcome_map.get(request.outcome, AttackOutcome.UNDETERMINED) diff --git a/pyrit/backend/services/scenario_run_service.py b/pyrit/backend/services/scenario_run_service.py index 373d76c2c4..5eb2d97563 100644 --- a/pyrit/backend/services/scenario_run_service.py +++ b/pyrit/backend/services/scenario_run_service.py @@ -408,9 +408,7 @@ def _build_response_from_db(self, *, scenario_result: ScenarioResult) -> Scenari memory = CentralMemory.get_memory_instance() error_ids = scenario_result.error_attack_result_ids if isinstance(error_ids, list) and error_ids: - error_results = memory.get_attack_results( - attack_result_ids=error_ids[:1] - ) + error_results = memory.get_attack_results(attack_result_ids=error_ids[:1]) if error_results: error = error_results[0].error_message error_type = error_results[0].error_type diff --git a/pyrit/exceptions/retry_collector.py b/pyrit/exceptions/retry_collector.py index ad89d05f3b..789863f307 100644 --- a/pyrit/exceptions/retry_collector.py +++ b/pyrit/exceptions/retry_collector.py @@ -12,7 +12,8 @@ @dataclass class RetryCollector: - """Collects retry events during attack execution. + """ + Collects retry events during attack execution. Uses contextvar for thread/task-safe scoping. Each attack execution creates its own collector so retry events are naturally scoped @@ -22,7 +23,8 @@ class RetryCollector: events: list[RetryEvent] = field(default_factory=list) def record(self, *, retry_state: Any) -> None: - """Record a retry event from a Tenacity RetryCallState. + """ + Record a retry event from a Tenacity RetryCallState. Extracts information from the retry state and the current ExecutionContext to build a structured RetryEvent. @@ -82,7 +84,8 @@ def record(self, *, retry_state: Any) -> None: def get_retry_collector() -> Optional[RetryCollector]: - """Get the current retry collector. + """ + Get the current retry collector. Returns: The active RetryCollector, or None if not set. @@ -91,7 +94,8 @@ def get_retry_collector() -> Optional[RetryCollector]: def set_retry_collector(collector: RetryCollector) -> None: - """Set the current retry collector. + """ + Set the current retry collector. Args: collector: The RetryCollector to activate. diff --git a/pyrit/executor/attack/core/attack_strategy.py b/pyrit/executor/attack/core/attack_strategy.py index f8eb45b9d1..99dc090097 100644 --- a/pyrit/executor/attack/core/attack_strategy.py +++ b/pyrit/executor/attack/core/attack_strategy.py @@ -12,7 +12,12 @@ from typing import TYPE_CHECKING, Any, ClassVar, Generic, Optional, TypeVar, Union, overload from pyrit.common.logger import logger -from pyrit.exceptions.retry_collector import RetryCollector, clear_retry_collector, get_retry_collector, set_retry_collector +from pyrit.exceptions.retry_collector import ( + RetryCollector, + clear_retry_collector, + get_retry_collector, + set_retry_collector, +) from pyrit.executor.attack.core.attack_parameters import AttackParameters, AttackParamsT from pyrit.executor.core import ( Strategy, @@ -241,18 +246,18 @@ def _log_attack_outcome(self, result: AttackResult) -> None: message = f"{attack_name} achieved the objective. {reason}" elif result.outcome == AttackOutcome.UNDETERMINED: message = f"{attack_name} outcome is undetermined. {reason}" + elif result.outcome == AttackOutcome.ERROR: + message = f"{attack_name} failed with an error. {reason}" else: message = f"{attack_name} did not achieve the objective. {reason}" self._logger.info(message) - async def _on_error( - self, event_data: StrategyEventData[AttackStrategyContextT, AttackStrategyResultT] - ) -> None: + async def _on_error(self, event_data: StrategyEventData[AttackStrategyContextT, AttackStrategyResultT]) -> None: """ Handle error during attack execution. - Creates a failed AttackResult with error details and any retry events + Creates an error AttackResult with error details and any retry events collected during execution, then persists it to memory. Args: @@ -273,12 +278,10 @@ async def _on_error( # Build a conversation_id — use context's if available, otherwise generate one conversation_id = getattr(context, "conversation_id", None) or str(__import__("uuid").uuid4()) - from pyrit.identifiers.atomic_attack_identifier import build_atomic_attack_identifier - error_result = AttackResult( conversation_id=conversation_id, objective=context.objective, - outcome=AttackOutcome.FAILURE, + outcome=AttackOutcome.ERROR, outcome_reason=f"Exception: {type(error).__name__}: {str(error)}", labels=context.memory_labels, related_conversations=context.related_conversations, diff --git a/pyrit/memory/memory_interface.py b/pyrit/memory/memory_interface.py index 80938cb74e..ec2128cf38 100644 --- a/pyrit/memory/memory_interface.py +++ b/pyrit/memory/memory_interface.py @@ -2060,9 +2060,7 @@ def update_scenario_run_state(self, *, scenario_result_id: str, scenario_run_sta ) raise - def update_scenario_error_attacks( - self, *, scenario_result_id: str, error_attack_result_ids: list[str] - ) -> bool: + def update_scenario_error_attacks(self, *, scenario_result_id: str, error_attack_result_ids: list[str]) -> bool: """ Update the error attack result IDs on an existing scenario result. @@ -2095,9 +2093,7 @@ def update_scenario_error_attacks( return True except Exception as e: - logger.exception( - f"Failed to update scenario {scenario_result_id} error attacks: {str(e)}" - ) + logger.exception(f"Failed to update scenario {scenario_result_id} error attacks: {str(e)}") raise def get_scenario_results( diff --git a/pyrit/memory/memory_models.py b/pyrit/memory/memory_models.py index 69ad955b4f..ce8633a975 100644 --- a/pyrit/memory/memory_models.py +++ b/pyrit/memory/memory_models.py @@ -723,7 +723,7 @@ class AttackResultEntry(Base): ) executed_turns = mapped_column(INTEGER, nullable=False, default=0) execution_time_ms = mapped_column(INTEGER, nullable=False, default=0) - outcome: Mapped[Literal["success", "failure", "undetermined"]] = mapped_column( + outcome: Mapped[Literal["success", "failure", "error", "undetermined"]] = mapped_column( String, nullable=False, default="undetermined" ) outcome_reason = mapped_column(String, nullable=True) diff --git a/pyrit/models/__init__.py b/pyrit/models/__init__.py index 73f08007c5..adba529d84 100644 --- a/pyrit/models/__init__.py +++ b/pyrit/models/__init__.py @@ -115,5 +115,5 @@ "StrategyResultT", "TextDataTypeSerializer", "UnvalidatedScore", - "VideoPathDataTypeSerializer", + "VideoPathDataTypeSerializer", "RetryEvent", ] diff --git a/pyrit/models/attack_result.py b/pyrit/models/attack_result.py index 9e375f8300..8e9d98d3e9 100644 --- a/pyrit/models/attack_result.py +++ b/pyrit/models/attack_result.py @@ -38,6 +38,9 @@ class AttackOutcome(str, Enum): # The attack failed to achieve its objective FAILURE = "failure" + # The attack failed due to an infrastructure error (exception), not a defensive refusal + ERROR = "error" + # The outcome of the attack is unknown or could not be determined UNDETERMINED = "undetermined" diff --git a/pyrit/models/retry_event.py b/pyrit/models/retry_event.py index f2f918cbc5..030bba3500 100644 --- a/pyrit/models/retry_event.py +++ b/pyrit/models/retry_event.py @@ -9,7 +9,8 @@ @dataclass class RetryEvent: - """A single retry attempt captured during attack execution. + """ + A single retry attempt captured during attack execution. Records structured information about a Tenacity retry event, including which component was retrying, what exception triggered the retry, and @@ -28,7 +29,8 @@ class RetryEvent: elapsed_seconds: float = 0.0 def to_dict(self) -> dict: - """Serialize to a dictionary suitable for JSON storage. + """ + Serialize to a dictionary suitable for JSON storage. Returns: dict: Dictionary representation of the retry event. @@ -47,7 +49,8 @@ def to_dict(self) -> dict: @classmethod def from_dict(cls, data: dict) -> "RetryEvent": - """Deserialize from a dictionary. + """ + Deserialize from a dictionary. Args: data: Dictionary representation of a retry event. diff --git a/tests/unit/analytics/test_result_analysis.py b/tests/unit/analytics/test_result_analysis.py index 05d1b94f0e..e2d96b5bd4 100644 --- a/tests/unit/analytics/test_result_analysis.py +++ b/tests/unit/analytics/test_result_analysis.py @@ -42,29 +42,44 @@ def test_analyze_results_raises_on_invalid_object(): @pytest.mark.parametrize( - "outcomes, expected_successes, expected_failures, expected_undetermined, expected_rate", + "outcomes, expected_successes, expected_failures, expected_undetermined, expected_errors, expected_rate", [ # all successes - ([AttackOutcome.SUCCESS, AttackOutcome.SUCCESS], 2, 0, 0, 1.0), + ([AttackOutcome.SUCCESS, AttackOutcome.SUCCESS], 2, 0, 0, 0, 1.0), # all failures - ([AttackOutcome.FAILURE, AttackOutcome.FAILURE], 0, 2, 0, 0.0), + ([AttackOutcome.FAILURE, AttackOutcome.FAILURE], 0, 2, 0, 0, 0.0), # mixed decided - ([AttackOutcome.SUCCESS, AttackOutcome.FAILURE], 1, 1, 0, 0.5), + ([AttackOutcome.SUCCESS, AttackOutcome.FAILURE], 1, 1, 0, 0, 0.5), # include undetermined (excluded from denominator) - ([AttackOutcome.SUCCESS, AttackOutcome.UNDETERMINED], 1, 0, 1, 1.0), - ([AttackOutcome.FAILURE, AttackOutcome.UNDETERMINED], 0, 1, 1, 0.0), + ([AttackOutcome.SUCCESS, AttackOutcome.UNDETERMINED], 1, 0, 1, 0, 1.0), + ([AttackOutcome.FAILURE, AttackOutcome.UNDETERMINED], 0, 1, 1, 0, 0.0), # multiple with undetermined ( [AttackOutcome.SUCCESS, AttackOutcome.FAILURE, AttackOutcome.UNDETERMINED], 1, 1, 1, + 0, + 0.5, + ), + # error excluded from denominator (like undetermined) + ([AttackOutcome.SUCCESS, AttackOutcome.ERROR], 1, 0, 0, 1, 1.0), + ([AttackOutcome.FAILURE, AttackOutcome.ERROR], 0, 1, 0, 1, 0.0), + # all errors + ([AttackOutcome.ERROR, AttackOutcome.ERROR], 0, 0, 0, 2, None), + # mixed with error and undetermined + ( + [AttackOutcome.SUCCESS, AttackOutcome.FAILURE, AttackOutcome.ERROR, AttackOutcome.UNDETERMINED], + 1, + 1, + 1, + 1, 0.5, ), ], ) def test_overall_success_rate_parametrized( - outcomes, expected_successes, expected_failures, expected_undetermined, expected_rate + outcomes, expected_successes, expected_failures, expected_undetermined, expected_errors, expected_rate ): attacks = [make_attack(o) for o in outcomes] result = analyze_results(attacks) @@ -74,6 +89,7 @@ def test_overall_success_rate_parametrized( assert overall.successes == expected_successes assert overall.failures == expected_failures assert overall.undetermined == expected_undetermined + assert overall.errors == expected_errors assert overall.total_decided == expected_successes + expected_failures assert overall.success_rate == expected_rate diff --git a/tests/unit/backend/test_attack_service.py b/tests/unit/backend/test_attack_service.py index 084a10eefa..b44145c7d6 100644 --- a/tests/unit/backend/test_attack_service.py +++ b/tests/unit/backend/test_attack_service.py @@ -887,6 +887,19 @@ async def test_update_attack_updates_outcome_undetermined(self, attack_service, call_kwargs = mock_memory.update_attack_result_by_id.call_args[1] assert call_kwargs["update_fields"]["outcome"] == "undetermined" + async def test_update_attack_updates_outcome_error(self, attack_service, mock_memory) -> None: + """Test that update_attack maps 'error' to AttackOutcome.ERROR.""" + ar = make_attack_result(conversation_id="test-id") + mock_memory.get_attack_results.return_value = [ar] + mock_memory.get_conversation.return_value = [] + + await attack_service.update_attack_async( + attack_result_id="test-id", request=UpdateAttackRequest(outcome="error") + ) + + call_kwargs = mock_memory.update_attack_result_by_id.call_args[1] + assert call_kwargs["update_fields"]["outcome"] == "error" + async def test_update_attack_refreshes_updated_at(self, attack_service, mock_memory) -> None: """Test that update_attack refreshes the updated_at metadata.""" old_time = datetime(2024, 1, 1, tzinfo=timezone.utc) From 76370304aa01c699d9c78da373791a625a544d9e Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Mon, 11 May 2026 15:57:20 -0700 Subject: [PATCH 4/9] self pr review --- pyrit/backend/mappers/attack_mappers.py | 24 +++++++++++++++++++ pyrit/executor/attack/core/attack_strategy.py | 14 ++--------- pyrit/executor/core/strategy.py | 17 +++++++++++++ pyrit/memory/memory_interface.py | 3 ++- pyrit/models/__init__.py | 3 ++- 5 files changed, 47 insertions(+), 14 deletions(-) diff --git a/pyrit/backend/mappers/attack_mappers.py b/pyrit/backend/mappers/attack_mappers.py index 9d6ffac1ae..96e99cc933 100644 --- a/pyrit/backend/mappers/attack_mappers.py +++ b/pyrit/backend/mappers/attack_mappers.py @@ -30,6 +30,7 @@ Message, MessagePiece, MessagePieceRequest, + RetryEventResponse, Score, TargetInfo, ) @@ -232,6 +233,24 @@ def attack_result_to_summary( else None ) + # Build retry event responses if available + retry_event_responses = None + if ar.retry_events: + retry_event_responses = [ + RetryEventResponse( + timestamp=evt.timestamp, + attempt_number=evt.attempt_number, + function_name=evt.function_name, + exception_type=evt.exception_type, + exception_message=evt.exception_message, + component_role=evt.component_role, + component_name=evt.component_name, + endpoint=evt.endpoint, + elapsed_seconds=evt.elapsed_seconds, + ) + for evt in ar.retry_events + ] + return AttackSummary( attack_result_id=ar.attack_result_id, conversation_id=ar.conversation_id, @@ -246,6 +265,11 @@ def attack_result_to_summary( labels=labels, created_at=created_at, updated_at=updated_at, + error_message=ar.error_message, + error_type=ar.error_type, + error_traceback=ar.error_traceback, + total_retries=ar.total_retries, + retry_events=retry_event_responses, ) diff --git a/pyrit/executor/attack/core/attack_strategy.py b/pyrit/executor/attack/core/attack_strategy.py index 99dc090097..1fda98d0f2 100644 --- a/pyrit/executor/attack/core/attack_strategy.py +++ b/pyrit/executor/attack/core/attack_strategy.py @@ -13,10 +13,7 @@ from pyrit.common.logger import logger from pyrit.exceptions.retry_collector import ( - RetryCollector, - clear_retry_collector, get_retry_collector, - set_retry_collector, ) from pyrit.executor.attack.core.attack_parameters import AttackParameters, AttackParamsT from pyrit.executor.core import ( @@ -191,10 +188,6 @@ async def _on_pre_execute( # Initialize start time for execution event_data.context.start_time = time.perf_counter() - # Start a RetryCollector to capture retry events during this attack - collector = RetryCollector() - set_retry_collector(collector) - # Log the start of the attack self._logger.info(f"Starting attack: {event_data.context.objective}") @@ -225,7 +218,6 @@ async def _on_post_execute( if collector and collector.events: event_data.result.retry_events = collector.events event_data.result.total_retries = len(collector.events) - clear_retry_collector() self._logger.debug(f"Attack execution completed in {execution_time_ms}ms") @@ -267,13 +259,11 @@ async def _on_error(self, event_data: StrategyEventData[AttackStrategyContextT, error = event_data.error context = event_data.context if not error or not context: - clear_retry_collector() return - # Collect retry events before clearing + # Collect retry events (visible via inherited ContextVar copy) collector = get_retry_collector() retry_events = collector.events if collector else [] - clear_retry_collector() # Build a conversation_id — use context's if available, otherwise generate one conversation_id = getattr(context, "conversation_id", None) or str(__import__("uuid").uuid4()) @@ -287,7 +277,7 @@ async def _on_error(self, event_data: StrategyEventData[AttackStrategyContextT, related_conversations=context.related_conversations, error_message=str(error), error_type=type(error).__name__, - error_traceback=traceback.format_exc(), + error_traceback="".join(traceback.format_exception(type(error), error, error.__traceback__)), retry_events=retry_events, total_retries=len(retry_events), ) diff --git a/pyrit/executor/core/strategy.py b/pyrit/executor/core/strategy.py index 59c9046c7d..a3e0d67b8c 100644 --- a/pyrit/executor/core/strategy.py +++ b/pyrit/executor/core/strategy.py @@ -17,6 +17,11 @@ from pyrit.common import default_values from pyrit.common.logger import logger from pyrit.exceptions import clear_execution_context, get_execution_context +from pyrit.exceptions.retry_collector import ( + RetryCollector, + clear_retry_collector, + set_retry_collector, +) from pyrit.models import StrategyResultT if TYPE_CHECKING: @@ -328,12 +333,24 @@ async def execute_with_context_async(self, *, context: StrategyContextT) -> Stra try: async with self._execution_context(context): await self._handle_event(event=StrategyEvent.ON_PRE_EXECUTE, context=context) + + # Set up RetryCollector in the parent task so it is visible to + # Tenacity callbacks that fire during _perform_async. Event + # handlers run in child tasks (asyncio.create_task) which + # inherit a *copy* of the parent's ContextVar — setting the + # collector here ensures both the execution path and the event + # handlers can see it. + collector = RetryCollector() + set_retry_collector(collector) + result = await self._perform_async(context=context) await self._handle_event(event=StrategyEvent.ON_POST_EXECUTE, context=context, result=result) + clear_retry_collector() return result except Exception as e: # Notify error event await self._handle_event(event=StrategyEvent.ON_ERROR, context=context, error=e) + clear_retry_collector() # Build enhanced error message with execution context if available # Note: The context is preserved on exception by ExecutionContextManager diff --git a/pyrit/memory/memory_interface.py b/pyrit/memory/memory_interface.py index ec2128cf38..0cfe03df06 100644 --- a/pyrit/memory/memory_interface.py +++ b/pyrit/memory/memory_interface.py @@ -2082,7 +2082,8 @@ def update_scenario_error_attacks(self, *, scenario_result_id: str, error_attack return False scenario_result = scenario_results[0] - scenario_result.error_attack_result_ids = error_attack_result_ids + existing = scenario_result.error_attack_result_ids or [] + scenario_result.error_attack_result_ids = list(dict.fromkeys(existing + error_attack_result_ids)) entry = ScenarioResultEntry(entry=scenario_result) self._update_entry(entry) diff --git a/pyrit/models/__init__.py b/pyrit/models/__init__.py index adba529d84..f6671a8e95 100644 --- a/pyrit/models/__init__.py +++ b/pyrit/models/__init__.py @@ -115,5 +115,6 @@ "StrategyResultT", "TextDataTypeSerializer", "UnvalidatedScore", - "VideoPathDataTypeSerializer", "RetryEvent", + "VideoPathDataTypeSerializer", + "RetryEvent" ] From e74b440277b9e17ee0c296050cda3c3601fc7771 Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Mon, 11 May 2026 16:15:03 -0700 Subject: [PATCH 5/9] self pr review --- pyrit/backend/routes/attacks.py | 4 +- .../backend/services/scenario_run_service.py | 3 +- pyrit/exceptions/exceptions_helpers.py | 6 +- pyrit/exceptions/retry_collector.py | 34 ++-- pyrit/executor/attack/core/attack_strategy.py | 3 + pyrit/executor/core/strategy.py | 12 +- pyrit/memory/memory_models.py | 3 +- pyrit/models/__init__.py | 2 +- pyrit/models/attack_result.py | 2 +- .../attack/core/test_attack_strategy.py | 145 ++++++++++++++++++ 10 files changed, 180 insertions(+), 34 deletions(-) diff --git a/pyrit/backend/routes/attacks.py b/pyrit/backend/routes/attacks.py index 2737163db6..38bb6991a0 100644 --- a/pyrit/backend/routes/attacks.py +++ b/pyrit/backend/routes/attacks.py @@ -93,7 +93,9 @@ async def list_attacks( description="Filter by converter presence. true = attacks with at least one converter; " "false = attacks with no converters. Omit for no filter.", ), - outcome: Optional[Literal["undetermined", "success", "failure", "error"]] = Query(None, description="Filter by outcome"), + outcome: Optional[Literal["undetermined", "success", "failure", "error"]] = Query( + None, description="Filter by outcome" + ), label: Optional[list[str]] = Query( None, description="Filter by labels (format: key:value). May be specified multiple times; " diff --git a/pyrit/backend/services/scenario_run_service.py b/pyrit/backend/services/scenario_run_service.py index 5eb2d97563..0ee8b96d2a 100644 --- a/pyrit/backend/services/scenario_run_service.py +++ b/pyrit/backend/services/scenario_run_service.py @@ -405,10 +405,9 @@ def _build_response_from_db(self, *, scenario_result: ScenarioResult) -> Scenari # Fall back to persisted error from failed AttackResult if not error and getattr(scenario_result, "error_attack_result_ids", None): - memory = CentralMemory.get_memory_instance() error_ids = scenario_result.error_attack_result_ids if isinstance(error_ids, list) and error_ids: - error_results = memory.get_attack_results(attack_result_ids=error_ids[:1]) + error_results = self._memory.get_attack_results(attack_result_ids=error_ids[:1]) if error_results: error = error_results[0].error_message error_type = error_results[0].error_type diff --git a/pyrit/exceptions/exceptions_helpers.py b/pyrit/exceptions/exceptions_helpers.py index 8c7a3075b6..9cdef2fb8c 100644 --- a/pyrit/exceptions/exceptions_helpers.py +++ b/pyrit/exceptions/exceptions_helpers.py @@ -9,6 +9,8 @@ from tenacity import RetryCallState from pyrit.exceptions.exception_context import get_execution_context +from pyrit.exceptions.retry_collector import get_retry_collector + logger = logging.getLogger(__name__) @@ -67,10 +69,6 @@ def log_exception(retry_state: RetryCallState) -> None: f"failed with exception: {exception}.{endpoint_clause} " f"Elapsed time: {elapsed_time} seconds. Total calls: {call_count}" ) - - # Record to RetryCollector if one is active - from pyrit.exceptions.retry_collector import get_retry_collector - collector = get_retry_collector() if collector: collector.record(retry_state=retry_state) diff --git a/pyrit/exceptions/retry_collector.py b/pyrit/exceptions/retry_collector.py index 789863f307..0710d40272 100644 --- a/pyrit/exceptions/retry_collector.py +++ b/pyrit/exceptions/retry_collector.py @@ -3,10 +3,13 @@ """Contextvar-based retry event collector for capturing Tenacity retry events.""" +import time from contextvars import ContextVar from dataclasses import dataclass, field -from typing import Any, Optional +from tenacity import RetryCallState + +from pyrit.exceptions.exception_context import get_execution_context from pyrit.models.retry_event import RetryEvent @@ -22,7 +25,7 @@ class RetryCollector: events: list[RetryEvent] = field(default_factory=list) - def record(self, *, retry_state: Any) -> None: + def record(self, *, retry_state: RetryCallState) -> None: """ Record a retry event from a Tenacity RetryCallState. @@ -30,26 +33,17 @@ def record(self, *, retry_state: Any) -> None: ExecutionContext to build a structured RetryEvent. Args: - retry_state: The Tenacity RetryCallState from the after callback. + retry_state (RetryCallState): The Tenacity retry call state from the after callback. """ - import time - - from pyrit.exceptions.exception_context import get_execution_context - - # Extract basic info from retry_state - call_count = getattr(retry_state, "attempt_number", None) or 0 - start_time = getattr(retry_state, "start_time", None) - elapsed = (time.monotonic() - start_time) if start_time is not None else 0.0 - - fn = getattr(retry_state, "fn", None) - fn_name = getattr(fn, "__name__", "unknown") if fn else "unknown" + elapsed = time.monotonic() - retry_state.start_time + fn_name = retry_state.fn.__name__ if retry_state.fn is not None else "unknown" # Extract exception info exception_type = "" exception_message = "" - outcome = getattr(retry_state, "outcome", None) - if outcome and getattr(outcome, "failed", False): - exc = outcome.exception() if hasattr(outcome, "exception") else None + outcome = retry_state.outcome + if outcome is not None and outcome.failed: + exc = outcome.exception() if exc: exception_type = type(exc).__name__ exception_message = str(exc) @@ -68,7 +62,7 @@ def record(self, *, retry_state: Any) -> None: pass event = RetryEvent( - attempt_number=call_count, + attempt_number=retry_state.attempt_number, function_name=fn_name, exception_type=exception_type, exception_message=exception_message, @@ -80,10 +74,10 @@ def record(self, *, retry_state: Any) -> None: self.events.append(event) -_retry_collector: ContextVar[Optional[RetryCollector]] = ContextVar("retry_collector", default=None) +_retry_collector: ContextVar[RetryCollector | None] = ContextVar("retry_collector", default=None) -def get_retry_collector() -> Optional[RetryCollector]: +def get_retry_collector() -> RetryCollector | None: """ Get the current retry collector. diff --git a/pyrit/executor/attack/core/attack_strategy.py b/pyrit/executor/attack/core/attack_strategy.py index 1fda98d0f2..f8e45bd9dc 100644 --- a/pyrit/executor/attack/core/attack_strategy.py +++ b/pyrit/executor/attack/core/attack_strategy.py @@ -69,6 +69,9 @@ class AttackContext(StrategyContext, ABC, Generic[AttackParamsT]): _prepended_conversation_override: Optional[list[Message]] = None _memory_labels_override: Optional[dict[str, str]] = None + # Set by the ON_ERROR handler to link error AttackResults to ScenarioResults + _error_attack_result_id: str | None = None + # Convenience properties that delegate to params or overrides @property def objective(self) -> str: diff --git a/pyrit/executor/core/strategy.py b/pyrit/executor/core/strategy.py index a3e0d67b8c..3aa9ede03b 100644 --- a/pyrit/executor/core/strategy.py +++ b/pyrit/executor/core/strategy.py @@ -30,6 +30,14 @@ StrategyContextT = TypeVar("StrategyContextT", bound="StrategyContext") +class _StrategyRuntimeError(RuntimeError): + """RuntimeError subclass that carries an optional error_attack_result_id.""" + + def __init__(self, message: str, *, error_attack_result_id: str | None = None) -> None: + super().__init__(message) + self.error_attack_result_id = error_attack_result_id + + @dataclass class StrategyContext(ABC): # noqa: B024 """Base class for all strategy contexts.""" @@ -378,11 +386,9 @@ async def execute_with_context_async(self, *, context: StrategyContextT) -> Stra else: error_message = f"Strategy execution failed for {self.__class__.__name__}: {str(e)}" - runtime_error = RuntimeError(error_message) # Attach the error attack result ID if the ON_ERROR handler created one error_attack_result_id = getattr(context, "_error_attack_result_id", None) - if error_attack_result_id: - runtime_error.error_attack_result_id = error_attack_result_id # type: ignore[attr-defined] + runtime_error = _StrategyRuntimeError(error_message, error_attack_result_id=error_attack_result_id) raise runtime_error from e async def execute_async(self, **kwargs: Any) -> StrategyResultT: diff --git a/pyrit/memory/memory_models.py b/pyrit/memory/memory_models.py index ce8633a975..f376375ac9 100644 --- a/pyrit/memory/memory_models.py +++ b/pyrit/memory/memory_models.py @@ -909,8 +909,7 @@ def get_attack_result(self) -> AttackResult: if self.retry_events_json: from pyrit.models.retry_event import RetryEvent - for evt_dict in json.loads(self.retry_events_json): - retry_events.append(RetryEvent.from_dict(evt_dict)) + retry_events = [RetryEvent.from_dict(evt_dict) for evt_dict in json.loads(self.retry_events_json)] return AttackResult( conversation_id=self.conversation_id, diff --git a/pyrit/models/__init__.py b/pyrit/models/__init__.py index f6671a8e95..1093e1da02 100644 --- a/pyrit/models/__init__.py +++ b/pyrit/models/__init__.py @@ -116,5 +116,5 @@ "TextDataTypeSerializer", "UnvalidatedScore", "VideoPathDataTypeSerializer", - "RetryEvent" + "RetryEvent", ] diff --git a/pyrit/models/attack_result.py b/pyrit/models/attack_result.py index 8e9d98d3e9..123c83a918 100644 --- a/pyrit/models/attack_result.py +++ b/pyrit/models/attack_result.py @@ -10,13 +10,13 @@ from enum import Enum from typing import TYPE_CHECKING, Any, Optional, TypeVar -from pyrit.models.retry_event import RetryEvent from pyrit.models.strategy_result import StrategyResult if TYPE_CHECKING: from pyrit.identifiers.component_identifier import ComponentIdentifier from pyrit.models.conversation_reference import ConversationReference from pyrit.models.message_piece import MessagePiece + from pyrit.models.retry_event import RetryEvent from pyrit.models.score import Score from pyrit.models.conversation_reference import ConversationType diff --git a/tests/unit/executor/attack/core/test_attack_strategy.py b/tests/unit/executor/attack/core/test_attack_strategy.py index 7b6825775c..2387294449 100644 --- a/tests/unit/executor/attack/core/test_attack_strategy.py +++ b/tests/unit/executor/attack/core/test_attack_strategy.py @@ -6,6 +6,7 @@ import pytest +from pyrit.exceptions.retry_collector import RetryCollector from pyrit.executor.attack.core.attack_parameters import AttackParameters from pyrit.executor.attack.core.attack_strategy import ( AttackContext, @@ -20,6 +21,7 @@ AttackResult, Message, ) +from pyrit.models.retry_event import RetryEvent from pyrit.prompt_target import PromptTarget @@ -437,6 +439,149 @@ async def test_on_post_execute_raises_on_none_result(self, event_handler, sample with pytest.raises(ValueError, match="Attack result is None"): await event_handler.on_event(event_data) + async def test_on_post_execute_attaches_retry_events( + self, sample_attack_context, sample_attack_result, mock_memory + ): + """Test that post-execute handler attaches retry events from collector to the result""" + with patch("pyrit.memory.central_memory.CentralMemory.get_memory_instance", return_value=mock_memory): + handler = _DefaultAttackStrategyEventHandler() + + sample_attack_context.start_time = 100.0 + retry_event = RetryEvent(attempt_number=1, function_name="send_prompt_async", exception_type="RateLimitError") + + collector = RetryCollector(events=[retry_event]) + with patch( + "pyrit.executor.attack.core.attack_strategy.get_retry_collector", return_value=collector + ): + event_data = StrategyEventData( + event=StrategyEvent.ON_POST_EXECUTE, + strategy_name="TestStrategy", + strategy_id="test-id", + context=sample_attack_context, + result=sample_attack_result, + ) + await handler.on_event(event_data) + + assert sample_attack_result.retry_events == [retry_event] + assert sample_attack_result.total_retries == 1 + + async def test_on_post_execute_no_retry_events_when_collector_empty( + self, sample_attack_context, sample_attack_result, mock_memory + ): + """Test that post-execute handler does not set retry_events when collector has no events""" + with patch("pyrit.memory.central_memory.CentralMemory.get_memory_instance", return_value=mock_memory): + handler = _DefaultAttackStrategyEventHandler() + + sample_attack_context.start_time = 100.0 + collector = RetryCollector(events=[]) + with patch( + "pyrit.executor.attack.core.attack_strategy.get_retry_collector", return_value=collector + ): + event_data = StrategyEventData( + event=StrategyEvent.ON_POST_EXECUTE, + strategy_name="TestStrategy", + strategy_id="test-id", + context=sample_attack_context, + result=sample_attack_result, + ) + await handler.on_event(event_data) + + # Empty collector means the guard `if collector and collector.events` is False + assert not sample_attack_result.retry_events + assert sample_attack_result.total_retries == 0 + + async def test_on_error_attaches_retry_events(self, sample_attack_context, mock_memory): + """Test that error handler attaches collected retry events to the error AttackResult""" + with patch("pyrit.memory.central_memory.CentralMemory.get_memory_instance", return_value=mock_memory): + handler = _DefaultAttackStrategyEventHandler() + + sample_attack_context.start_time = 100.0 + retry_event = RetryEvent(attempt_number=2, function_name="send_prompt_async", exception_type="TimeoutError") + collector = RetryCollector(events=[retry_event]) + + with patch( + "pyrit.executor.attack.core.attack_strategy.get_retry_collector", return_value=collector + ): + event_data = StrategyEventData( + event=StrategyEvent.ON_ERROR, + strategy_name="TestStrategy", + strategy_id="test-id", + context=sample_attack_context, + error=RuntimeError("test error"), + ) + await handler.on_event(event_data) + + stored_result = mock_memory.add_attack_results_to_memory.call_args.kwargs["attack_results"][0] + assert stored_result.outcome == AttackOutcome.ERROR + assert stored_result.retry_events == [retry_event] + assert stored_result.total_retries == 1 + + async def test_on_error_empty_retry_events_when_no_collector(self, sample_attack_context, mock_memory): + """Test that error handler sets empty retry_events when no collector exists""" + with patch("pyrit.memory.central_memory.CentralMemory.get_memory_instance", return_value=mock_memory): + handler = _DefaultAttackStrategyEventHandler() + + sample_attack_context.start_time = 100.0 + + with patch( + "pyrit.executor.attack.core.attack_strategy.get_retry_collector", return_value=None + ): + event_data = StrategyEventData( + event=StrategyEvent.ON_ERROR, + strategy_name="TestStrategy", + strategy_id="test-id", + context=sample_attack_context, + error=RuntimeError("test error"), + ) + await handler.on_event(event_data) + + stored_result = mock_memory.add_attack_results_to_memory.call_args.kwargs["attack_results"][0] + assert stored_result.retry_events == [] + assert stored_result.total_retries == 0 + + async def test_on_error_persists_result_to_memory(self, sample_attack_context, mock_memory): + """Test that error handler creates an error AttackResult and persists it""" + with patch("pyrit.memory.central_memory.CentralMemory.get_memory_instance", return_value=mock_memory): + handler = _DefaultAttackStrategyEventHandler() + + sample_attack_context.start_time = 100.0 + error = ValueError("something broke") + + with patch( + "pyrit.executor.attack.core.attack_strategy.get_retry_collector", return_value=None + ): + event_data = StrategyEventData( + event=StrategyEvent.ON_ERROR, + strategy_name="TestStrategy", + strategy_id="test-id", + context=sample_attack_context, + error=error, + ) + with patch("time.perf_counter", return_value=100.5): + await handler.on_event(event_data) + + mock_memory.add_attack_results_to_memory.assert_called_once() + stored_result = mock_memory.add_attack_results_to_memory.call_args.kwargs["attack_results"][0] + assert stored_result.outcome == AttackOutcome.ERROR + assert stored_result.error_message == "something broke" + assert stored_result.error_type == "ValueError" + assert stored_result.execution_time_ms == 500 + + async def test_on_error_skips_when_no_error_or_context(self, mock_memory): + """Test that error handler returns early when error or context is None""" + with patch("pyrit.memory.central_memory.CentralMemory.get_memory_instance", return_value=mock_memory): + handler = _DefaultAttackStrategyEventHandler() + + event_data = StrategyEventData( + event=StrategyEvent.ON_ERROR, + strategy_name="TestStrategy", + strategy_id="test-id", + context=None, + error=RuntimeError("test"), + ) + await handler.on_event(event_data) + mock_memory.add_attack_results_to_memory.assert_not_called() + async def test_on_event_handles_other_events(self, event_handler, sample_attack_context, mock_logger): """Test that on_event handles events not in the specific handlers""" event_data = StrategyEventData( From 2213d3af1934e5d1523f081b6cb0af5a4412b1b9 Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Tue, 12 May 2026 11:17:27 -0700 Subject: [PATCH 6/9] fix: address review issues in error/retry tracking - Replace placeholder Alembic revision ID (a1b2c3d4e5f6) with proper generated hash (4f9db4b0a77f) - Swap persist-before-set order in _on_error_async so context only gets error_attack_result_id after DB write succeeds; clear stale ID at handler entry - Rewrite update_scenario_error_attacks to do read-modify-write in a single DB session - Extract retry_events_to_response() shared helper in attack_mappers, replacing duplicate code in scenario_run_service - Remove inline import from loop body in scenario_run_service - Replace __import__("uuid") hack with proper top-level import - Rename _on_error to _on_error_async per style guide Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pyrit/backend/mappers/attack_mappers.py | 46 ++++++++++++------- .../backend/services/scenario_run_service.py | 35 ++++---------- pyrit/executor/attack/core/attack_strategy.py | 18 +++++--- ...=> 4f9db4b0a77f_add_error_retry_fields.py} | 4 +- pyrit/memory/memory_interface.py | 23 ++++++---- 5 files changed, 67 insertions(+), 59 deletions(-) rename pyrit/memory/alembic/versions/{a1b2c3d4e5f6_add_error_retry_fields.py => 4f9db4b0a77f_add_error_retry_fields.py} (96%) diff --git a/pyrit/backend/mappers/attack_mappers.py b/pyrit/backend/mappers/attack_mappers.py index 96e99cc933..9fe88ecf6c 100644 --- a/pyrit/backend/mappers/attack_mappers.py +++ b/pyrit/backend/mappers/attack_mappers.py @@ -44,6 +44,7 @@ if TYPE_CHECKING: from pyrit.models.conversation_stats import ConversationStats + from pyrit.models.retry_event import RetryEvent # ============================================================================ # Domain → DTO (for API responses) @@ -182,6 +183,34 @@ def _resolve_media_url(*, value: Optional[str], data_type: str) -> Optional[str] return value +def retry_events_to_response(retry_events: list[RetryEvent] | None) -> list[RetryEventResponse] | None: + """ + Convert a list of RetryEvent domain objects to RetryEventResponse DTOs. + + Args: + retry_events: Domain retry events, or None. + + Returns: + List of RetryEventResponse DTOs, or None if the input is None or empty. + """ + if not retry_events: + return None + return [ + RetryEventResponse( + timestamp=evt.timestamp, + attempt_number=evt.attempt_number, + function_name=evt.function_name, + exception_type=evt.exception_type, + exception_message=evt.exception_message, + component_role=evt.component_role, + component_name=evt.component_name, + endpoint=evt.endpoint, + elapsed_seconds=evt.elapsed_seconds, + ) + for evt in retry_events + ] + + def attack_result_to_summary( ar: AttackResult, *, @@ -234,22 +263,7 @@ def attack_result_to_summary( ) # Build retry event responses if available - retry_event_responses = None - if ar.retry_events: - retry_event_responses = [ - RetryEventResponse( - timestamp=evt.timestamp, - attempt_number=evt.attempt_number, - function_name=evt.function_name, - exception_type=evt.exception_type, - exception_message=evt.exception_message, - component_role=evt.component_role, - component_name=evt.component_name, - endpoint=evt.endpoint, - elapsed_seconds=evt.elapsed_seconds, - ) - for evt in ar.retry_events - ] + retry_event_responses = retry_events_to_response(ar.retry_events) return AttackSummary( attack_result_id=ar.attack_result_id, diff --git a/pyrit/backend/services/scenario_run_service.py b/pyrit/backend/services/scenario_run_service.py index 0ee8b96d2a..034258c836 100644 --- a/pyrit/backend/services/scenario_run_service.py +++ b/pyrit/backend/services/scenario_run_service.py @@ -15,6 +15,7 @@ from datetime import datetime, timezone from typing import TYPE_CHECKING, Any +from pyrit.backend.mappers.attack_mappers import retry_events_to_response from pyrit.backend.models.scenarios import ( AtomicAttackResults, AttackSummary, @@ -492,32 +493,14 @@ def get_run_results(self, *, scenario_result_id: str) -> ScenarioRunDetail | Non timestamp = ar.timestamp or datetime.now(timezone.utc) - # Build retry event responses if available - retry_event_responses = None - retry_events = getattr(ar, "retry_events", None) - if isinstance(retry_events, list) and retry_events: - from pyrit.backend.models.attacks import RetryEventResponse - - retry_event_responses = [ - RetryEventResponse( - timestamp=evt.timestamp, - attempt_number=evt.attempt_number, - function_name=evt.function_name, - exception_type=evt.exception_type, - exception_message=evt.exception_message, - component_role=evt.component_role, - component_name=evt.component_name, - endpoint=evt.endpoint, - elapsed_seconds=evt.elapsed_seconds, - ) - for evt in retry_events - ] - - # Extract error/retry fields with safe defaults - ar_error_message = ar.error_message if isinstance(ar.error_message, str) else None - ar_error_type = ar.error_type if isinstance(ar.error_type, str) else None - ar_error_traceback = ar.error_traceback if isinstance(ar.error_traceback, str) else None - ar_total_retries = ar.total_retries if isinstance(ar.total_retries, int) else 0 + # Build retry event responses using the shared mapper + retry_event_responses = retry_events_to_response(ar.retry_events) + + # Extract error/retry fields + ar_error_message = ar.error_message + ar_error_type = ar.error_type + ar_error_traceback = ar.error_traceback + ar_total_retries = ar.total_retries details.append( AttackSummary( diff --git a/pyrit/executor/attack/core/attack_strategy.py b/pyrit/executor/attack/core/attack_strategy.py index f8e45bd9dc..a054ace813 100644 --- a/pyrit/executor/attack/core/attack_strategy.py +++ b/pyrit/executor/attack/core/attack_strategy.py @@ -7,6 +7,7 @@ import logging # noqa: TC003 import time import traceback +import uuid from abc import ABC from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any, ClassVar, Generic, Optional, TypeVar, Union, overload @@ -141,7 +142,7 @@ def __init__(self, logger: logging.Logger = logger) -> None: self._events = { StrategyEvent.ON_PRE_EXECUTE: self._on_pre_execute, StrategyEvent.ON_POST_EXECUTE: self._on_post_execute, - StrategyEvent.ON_ERROR: self._on_error, + StrategyEvent.ON_ERROR: self._on_error_async, } self._memory = CentralMemory.get_memory_instance() @@ -248,7 +249,9 @@ def _log_attack_outcome(self, result: AttackResult) -> None: self._logger.info(message) - async def _on_error(self, event_data: StrategyEventData[AttackStrategyContextT, AttackStrategyResultT]) -> None: + async def _on_error_async( + self, event_data: StrategyEventData[AttackStrategyContextT, AttackStrategyResultT] + ) -> None: """ Handle error during attack execution. @@ -264,12 +267,15 @@ async def _on_error(self, event_data: StrategyEventData[AttackStrategyContextT, if not error or not context: return + # Clear any stale ID from a previous execution + context._error_attack_result_id = None + # Collect retry events (visible via inherited ContextVar copy) collector = get_retry_collector() retry_events = collector.events if collector else [] # Build a conversation_id — use context's if available, otherwise generate one - conversation_id = getattr(context, "conversation_id", None) or str(__import__("uuid").uuid4()) + conversation_id = getattr(context, "conversation_id", None) or str(uuid.uuid4()) error_result = AttackResult( conversation_id=conversation_id, @@ -289,11 +295,11 @@ async def _on_error(self, event_data: StrategyEventData[AttackStrategyContextT, if context.start_time: error_result.execution_time_ms = int((end_time - context.start_time) * 1000) - # Store the error attack result ID on the context so scenario-level - # code can link it to the ScenarioResult + # Persist first, then set the ID on the context so scenario-level code + # only sees the reference if the write succeeded. + self._memory.add_attack_results_to_memory(attack_results=[error_result]) context._error_attack_result_id = error_result.attack_result_id - self._memory.add_attack_results_to_memory(attack_results=[error_result]) self._logger.error(f"Attack failed with {type(error).__name__}: {error}") diff --git a/pyrit/memory/alembic/versions/a1b2c3d4e5f6_add_error_retry_fields.py b/pyrit/memory/alembic/versions/4f9db4b0a77f_add_error_retry_fields.py similarity index 96% rename from pyrit/memory/alembic/versions/a1b2c3d4e5f6_add_error_retry_fields.py rename to pyrit/memory/alembic/versions/4f9db4b0a77f_add_error_retry_fields.py index 7e244275c0..b90361c991 100644 --- a/pyrit/memory/alembic/versions/a1b2c3d4e5f6_add_error_retry_fields.py +++ b/pyrit/memory/alembic/versions/4f9db4b0a77f_add_error_retry_fields.py @@ -4,7 +4,7 @@ """ add error and retry fields to attack and scenario results. -Revision ID: a1b2c3d4e5f6 +Revision ID: 4f9db4b0a77f Revises: 108a72344872 Create Date: 2026-05-08 00:00:00.000000 """ @@ -15,7 +15,7 @@ from alembic import op # revision identifiers, used by Alembic. -revision: str = "a1b2c3d4e5f6" +revision: str = "4f9db4b0a77f" down_revision: str | None = "108a72344872" branch_labels: str | Sequence[str] | None = None depends_on: str | Sequence[str] | None = None diff --git a/pyrit/memory/memory_interface.py b/pyrit/memory/memory_interface.py index 0cfe03df06..9ebada2543 100644 --- a/pyrit/memory/memory_interface.py +++ b/pyrit/memory/memory_interface.py @@ -2067,6 +2067,9 @@ def update_scenario_error_attacks(self, *, scenario_result_id: str, error_attack This links failed AttackResults to the ScenarioResult so the REST API can quickly find error details without scanning all attacks. + Performs the read-modify-write within a single DB session to avoid + inter-session consistency issues. + Args: scenario_result_id: The ID of the scenario result to update. error_attack_result_ids: IDs of AttackResults that contain error information. @@ -2074,19 +2077,21 @@ def update_scenario_error_attacks(self, *, scenario_result_id: str, error_attack Returns: True if the update was successful, False otherwise. """ + import json + try: - scenario_results = self.get_scenario_results(scenario_result_ids=[scenario_result_id]) + with closing(self.get_session()) as session: + entry = session.query(ScenarioResultEntry).filter_by(id=scenario_result_id).first() - if not scenario_results: - logger.error(f"Scenario result with ID {scenario_result_id} not found in memory") - return False + if not entry: + logger.error(f"Scenario result with ID {scenario_result_id} not found in memory") + return False - scenario_result = scenario_results[0] - existing = scenario_result.error_attack_result_ids or [] - scenario_result.error_attack_result_ids = list(dict.fromkeys(existing + error_attack_result_ids)) + existing: list[str] = json.loads(entry.error_attack_result_ids_json) if entry.error_attack_result_ids_json else [] + merged = list(dict.fromkeys(existing + error_attack_result_ids)) + entry.error_attack_result_ids_json = json.dumps(merged) - entry = ScenarioResultEntry(entry=scenario_result) - self._update_entry(entry) + session.commit() logger.info( f"Updated scenario {scenario_result_id} with {len(error_attack_result_ids)} error attack result(s)" From f46b82ee513e2a69051746f6d189b304c8a3ee1b Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Tue, 12 May 2026 13:31:43 -0700 Subject: [PATCH 7/9] fix: set error/retry fields on mock in scenario run service test Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../unit/backend/test_scenario_run_service.py | 25 +++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/tests/unit/backend/test_scenario_run_service.py b/tests/unit/backend/test_scenario_run_service.py index 26fa81a814..29551b3cae 100644 --- a/tests/unit/backend/test_scenario_run_service.py +++ b/tests/unit/backend/test_scenario_run_service.py @@ -294,6 +294,26 @@ def test_get_run_returns_existing_run(self, mock_memory) -> None: assert fetched.scenario_name == "foundry.red_team_agent" assert fetched.status == ScenarioRunStatus.IN_PROGRESS + def test_get_run_falls_back_to_persisted_error(self, mock_memory) -> None: + """Test that get_run extracts error from persisted error AttackResult when no active task.""" + db_result = _make_db_scenario_result(result_id="sr-fail", run_state="FAILED") + db_result.error_attack_result_ids = ["err-ar-1"] + + # Mock the error AttackResult lookup + error_ar = MagicMock() + error_ar.error_message = "Connection refused" + error_ar.error_type = "ConnectionError" + mock_memory.get_scenario_results.return_value = [db_result] + mock_memory.get_attack_results.return_value = [error_ar] + + service = ScenarioRunService() + fetched = service.get_run(scenario_result_id="sr-fail") + + assert fetched is not None + assert fetched.error == "Connection refused" + assert fetched.error_type == "ConnectionError" + mock_memory.get_attack_results.assert_called_once_with(attack_result_ids=["err-ar-1"]) + class TestScenarioRunServiceListRuns: """Tests for ScenarioRunService.list_runs.""" @@ -470,6 +490,11 @@ def test_get_results_returns_details_for_completed_run(self, mock_memory) -> Non mock_attack_result.executed_turns = 3 mock_attack_result.execution_time_ms = 1500 mock_attack_result.timestamp = None + mock_attack_result.error_message = None + mock_attack_result.error_type = None + mock_attack_result.error_traceback = None + mock_attack_result.total_retries = 0 + mock_attack_result.retry_events = [] db_result = _make_db_scenario_result( result_id="sr-123", From a52be76dd2aa6fe4307799dbb0fe1dba64916477 Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Tue, 12 May 2026 16:40:36 -0700 Subject: [PATCH 8/9] pr feedback --- .../backend/services/scenario_run_service.py | 32 +++---- pyrit/exceptions/exceptions_helpers.py | 1 - ...8a72344872_add_labels_to_attack_results.py | 35 -------- .../4f9db4b0a77f_add_error_retry_fields.py | 46 ---------- .../versions/e373726d391b_initial_schema.py | 9 ++ pyrit/memory/memory_interface.py | 89 +++++++++---------- pyrit/memory/memory_models.py | 9 ++ pyrit/models/scenario_result.py | 6 ++ pyrit/scenario/core/scenario.py | 15 ++-- tests/unit/backend/test_mappers.py | 34 +++++++ tests/unit/exceptions/test_retry_collector.py | 37 ++++++++ .../attack/core/test_attack_strategy.py | 44 +++++---- .../test_interface_scenario_results.py | 54 +++++++++++ tests/unit/memory/test_memory_models.py | 14 +++ 14 files changed, 257 insertions(+), 168 deletions(-) delete mode 100644 pyrit/memory/alembic/versions/108a72344872_add_labels_to_attack_results.py delete mode 100644 pyrit/memory/alembic/versions/4f9db4b0a77f_add_error_retry_fields.py diff --git a/pyrit/backend/services/scenario_run_service.py b/pyrit/backend/services/scenario_run_service.py index 034258c836..41843c2639 100644 --- a/pyrit/backend/services/scenario_run_service.py +++ b/pyrit/backend/services/scenario_run_service.py @@ -184,7 +184,12 @@ async def cancel_run_async(self, *, scenario_result_id: str) -> ScenarioRunSumma await asyncio.wait_for(active.task, timeout=5.0) # Persist cancelled state to DB - self._memory.update_scenario_run_state(scenario_result_id=scenario_result_id, scenario_run_state="CANCELLED") + self._memory.update_scenario_run_state( + scenario_result_id=scenario_result_id, + scenario_run_state="CANCELLED", + error_message="Run was cancelled by user", + error_type="CancelledError", + ) return self._build_response(scenario_result_id=scenario_result_id) @@ -396,22 +401,17 @@ def _build_response_from_db(self, *, scenario_result: ScenarioResult) -> Scenari scenario_result_id = str(scenario_result.id) active = self._active_tasks.get(scenario_result_id) - # Clean up finished active tasks after reading the error - error = None - error_type = None - if active is not None: + # Clean up finished active tasks + if active is not None and active.task is not None and active.task.done(): + del self._active_tasks[scenario_result_id] + + # Primary source: DB-persisted error fields + error = scenario_result.error_message + error_type = scenario_result.error_type + + # Fallback: in-memory error for in-flight tasks where DB hasn't been updated yet + if not error and active is not None: error = active.error - if active.task is not None and active.task.done(): - del self._active_tasks[scenario_result_id] - - # Fall back to persisted error from failed AttackResult - if not error and getattr(scenario_result, "error_attack_result_ids", None): - error_ids = scenario_result.error_attack_result_ids - if isinstance(error_ids, list) and error_ids: - error_results = self._memory.get_attack_results(attack_result_ids=error_ids[:1]) - if error_results: - error = error_results[0].error_message - error_type = error_results[0].error_type status = ScenarioRunStatus(scenario_result.scenario_run_state) diff --git a/pyrit/exceptions/exceptions_helpers.py b/pyrit/exceptions/exceptions_helpers.py index 9cdef2fb8c..1ee2f8e9cf 100644 --- a/pyrit/exceptions/exceptions_helpers.py +++ b/pyrit/exceptions/exceptions_helpers.py @@ -11,7 +11,6 @@ from pyrit.exceptions.exception_context import get_execution_context from pyrit.exceptions.retry_collector import get_retry_collector - logger = logging.getLogger(__name__) diff --git a/pyrit/memory/alembic/versions/108a72344872_add_labels_to_attack_results.py b/pyrit/memory/alembic/versions/108a72344872_add_labels_to_attack_results.py deleted file mode 100644 index 66b1574b93..0000000000 --- a/pyrit/memory/alembic/versions/108a72344872_add_labels_to_attack_results.py +++ /dev/null @@ -1,35 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. - -""" -add labels to Attack Results. - -Revision ID: 108a72344872 -Revises: e373726d391b -Create Date: 2026-04-27 13:47:20.711347 -""" - -from collections.abc import Sequence - -import sqlalchemy as sa -from alembic import op - -# revision identifiers, used by Alembic. -revision: str = "108a72344872" -down_revision: str | None = "e373726d391b" -branch_labels: str | Sequence[str] | None = None -depends_on: str | Sequence[str] | None = None - - -def upgrade() -> None: - """Apply this schema upgrade.""" - # ### commands auto generated by Alembic and reviewed by author ### - op.add_column("AttackResultEntries", sa.Column("labels", sa.JSON(), nullable=True)) - # ### end Alembic commands ### - - -def downgrade() -> None: - """Revert this schema upgrade.""" - # ### commands auto generated by Alembic and reviewed by author ### - op.drop_column("AttackResultEntries", "labels") - # ### end Alembic commands ### diff --git a/pyrit/memory/alembic/versions/4f9db4b0a77f_add_error_retry_fields.py b/pyrit/memory/alembic/versions/4f9db4b0a77f_add_error_retry_fields.py deleted file mode 100644 index b90361c991..0000000000 --- a/pyrit/memory/alembic/versions/4f9db4b0a77f_add_error_retry_fields.py +++ /dev/null @@ -1,46 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. - -""" -add error and retry fields to attack and scenario results. - -Revision ID: 4f9db4b0a77f -Revises: 108a72344872 -Create Date: 2026-05-08 00:00:00.000000 -""" - -from collections.abc import Sequence - -import sqlalchemy as sa -from alembic import op - -# revision identifiers, used by Alembic. -revision: str = "4f9db4b0a77f" -down_revision: str | None = "108a72344872" -branch_labels: str | Sequence[str] | None = None -depends_on: str | Sequence[str] | None = None - - -def upgrade() -> None: - """Apply this schema upgrade.""" - # Error fields on AttackResultEntries - op.add_column("AttackResultEntries", sa.Column("error_message", sa.Unicode(), nullable=True)) - op.add_column("AttackResultEntries", sa.Column("error_type", sa.String(), nullable=True)) - op.add_column("AttackResultEntries", sa.Column("error_traceback", sa.Unicode(), nullable=True)) - - # Retry fields on AttackResultEntries - op.add_column("AttackResultEntries", sa.Column("retry_events_json", sa.Unicode(), nullable=True)) - op.add_column("AttackResultEntries", sa.Column("total_retries", sa.INTEGER(), nullable=True)) - - # Error pointer on ScenarioResultEntries - op.add_column("ScenarioResultEntries", sa.Column("error_attack_result_ids_json", sa.Unicode(), nullable=True)) - - -def downgrade() -> None: - """Revert this schema upgrade.""" - op.drop_column("AttackResultEntries", "error_message") - op.drop_column("AttackResultEntries", "error_type") - op.drop_column("AttackResultEntries", "error_traceback") - op.drop_column("AttackResultEntries", "retry_events_json") - op.drop_column("AttackResultEntries", "total_retries") - op.drop_column("ScenarioResultEntries", "error_attack_result_ids_json") diff --git a/pyrit/memory/alembic/versions/e373726d391b_initial_schema.py b/pyrit/memory/alembic/versions/e373726d391b_initial_schema.py index c4dd5a8d09..2ac9a666d6 100644 --- a/pyrit/memory/alembic/versions/e373726d391b_initial_schema.py +++ b/pyrit/memory/alembic/versions/e373726d391b_initial_schema.py @@ -100,6 +100,9 @@ def upgrade() -> None: sa.Column("number_tries", sa.INTEGER(), nullable=False), sa.Column("completion_time", sa.DateTime(), nullable=False), sa.Column("timestamp", sa.DateTime(), nullable=False), + sa.Column("error_attack_result_ids_json", sa.Unicode(), nullable=True), + sa.Column("error_message", sa.Unicode(), nullable=True), + sa.Column("error_type", sa.String(), nullable=True), sa.PrimaryKeyConstraint("id"), ) op.create_table( @@ -180,6 +183,12 @@ def upgrade() -> None: sa.Column("adversarial_chat_conversation_ids", sa.JSON(), nullable=True), sa.Column("timestamp", sa.DateTime(), nullable=False), sa.Column("pyrit_version", sa.String(), nullable=True), + sa.Column("labels", sa.JSON(), nullable=True), + sa.Column("error_message", sa.Unicode(), nullable=True), + sa.Column("error_type", sa.String(), nullable=True), + sa.Column("error_traceback", sa.Unicode(), nullable=True), + sa.Column("retry_events_json", sa.Unicode(), nullable=True), + sa.Column("total_retries", sa.INTEGER(), nullable=True), sa.ForeignKeyConstraint( ["last_response_id"], ["PromptMemoryEntries.id"], diff --git a/pyrit/memory/memory_interface.py b/pyrit/memory/memory_interface.py index 9ebada2543..65d4480aaf 100644 --- a/pyrit/memory/memory_interface.py +++ b/pyrit/memory/memory_interface.py @@ -2016,7 +2016,14 @@ def add_attack_results_to_scenario( logger.exception(f"Failed to add attack results to scenario {scenario_result_id}: {str(e)}") raise - def update_scenario_run_state(self, *, scenario_result_id: str, scenario_run_state: str) -> bool: + def update_scenario_run_state( + self, + *, + scenario_result_id: str, + scenario_run_state: str, + error_message: str | None = None, + error_type: str | None = None, + ) -> None: """ Update the run state of an existing scenario result. @@ -2024,43 +2031,34 @@ def update_scenario_run_state(self, *, scenario_result_id: str, scenario_run_sta scenario_result_id (str): The ID of the scenario result to update. scenario_run_state (str): The new state for the scenario (e.g., "CREATED", "IN_PROGRESS", "COMPLETED", "FAILED"). + error_message (str | None): Optional scenario-level error message. + error_type (str | None): Optional exception class name. - Returns: - bool: True if the update was successful, False otherwise. - - Example: - >>> memory.update_scenario_run_state( - ... scenario_result_id="123e4567-e89b-12d3-a456-426614174000", - ... scenario_run_state="COMPLETED" - ... ) + Raises: + ValueError: If the scenario result is not found. """ - try: - # Retrieve current scenario result - scenario_results = self.get_scenario_results(scenario_result_ids=[scenario_result_id]) + scenario_results = self.get_scenario_results(scenario_result_ids=[scenario_result_id]) - if not scenario_results: - logger.error(f"Scenario result with ID {scenario_result_id} not found in memory") - return False + if not scenario_results: + raise ValueError(f"Scenario result with ID {scenario_result_id} not found in memory") - scenario_result = scenario_results[0] + scenario_result = scenario_results[0] - # Update the scenario run state - scenario_result.scenario_run_state = scenario_run_state # type: ignore[ty:invalid-assignment] + # Update the scenario run state + scenario_result.scenario_run_state = scenario_run_state # type: ignore[ty:invalid-assignment] - # Save updated result back to memory using update - entry = ScenarioResultEntry(entry=scenario_result) - self._update_entry(entry) + if error_message is not None: + scenario_result.error_message = error_message + if error_type is not None: + scenario_result.error_type = error_type - logger.info(f"Updated scenario {scenario_result_id} state to '{scenario_run_state}'") - return True + # Save updated result back to memory using update + entry = ScenarioResultEntry(entry=scenario_result) + self._update_entry(entry) - except Exception as e: - logger.exception( - f"Failed to update scenario {scenario_result_id} state to '{scenario_run_state}': {str(e)}" - ) - raise + logger.info(f"Updated scenario {scenario_result_id} state to '{scenario_run_state}'") - def update_scenario_error_attacks(self, *, scenario_result_id: str, error_attack_result_ids: list[str]) -> bool: + def update_scenario_error_attacks(self, *, scenario_result_id: str, error_attack_result_ids: list[str]) -> None: """ Update the error attack result IDs on an existing scenario result. @@ -2074,33 +2072,26 @@ def update_scenario_error_attacks(self, *, scenario_result_id: str, error_attack scenario_result_id: The ID of the scenario result to update. error_attack_result_ids: IDs of AttackResults that contain error information. - Returns: - True if the update was successful, False otherwise. + Raises: + ValueError: If the scenario result is not found. """ import json - try: - with closing(self.get_session()) as session: - entry = session.query(ScenarioResultEntry).filter_by(id=scenario_result_id).first() - - if not entry: - logger.error(f"Scenario result with ID {scenario_result_id} not found in memory") - return False + with closing(self.get_session()) as session: + entry = session.query(ScenarioResultEntry).filter_by(id=scenario_result_id).first() - existing: list[str] = json.loads(entry.error_attack_result_ids_json) if entry.error_attack_result_ids_json else [] - merged = list(dict.fromkeys(existing + error_attack_result_ids)) - entry.error_attack_result_ids_json = json.dumps(merged) + if not entry: + raise ValueError(f"Scenario result with ID {scenario_result_id} not found in memory") - session.commit() - - logger.info( - f"Updated scenario {scenario_result_id} with {len(error_attack_result_ids)} error attack result(s)" + existing: list[str] = ( + json.loads(entry.error_attack_result_ids_json) if entry.error_attack_result_ids_json else [] ) - return True + merged = list(dict.fromkeys(existing + error_attack_result_ids)) + entry.error_attack_result_ids_json = json.dumps(merged) - except Exception as e: - logger.exception(f"Failed to update scenario {scenario_result_id} error attacks: {str(e)}") - raise + session.commit() + + logger.info(f"Updated scenario {scenario_result_id} with {len(error_attack_result_ids)} error attack result(s)") def get_scenario_results( self, diff --git a/pyrit/memory/memory_models.py b/pyrit/memory/memory_models.py index f376375ac9..1e48c03cf5 100644 --- a/pyrit/memory/memory_models.py +++ b/pyrit/memory/memory_models.py @@ -995,6 +995,10 @@ class ScenarioResultEntry(Base): # Pointer to failed attack result(s) — avoids scanning all attacks for error info error_attack_result_ids_json: Mapped[Optional[str]] = mapped_column(Unicode, nullable=True) + # Scenario-level error info (persisted so it survives process restarts) + error_message: Mapped[Optional[str]] = mapped_column(Unicode, nullable=True) + error_type: Mapped[Optional[str]] = mapped_column(String, nullable=True) + def __init__(self, *, entry: ScenarioResult) -> None: """ Initialize a ScenarioResultEntry from a ScenarioResult object. @@ -1045,6 +1049,9 @@ def __init__(self, *, entry: ScenarioResult) -> None: json.dumps(entry.error_attack_result_ids) if entry.error_attack_result_ids else None ) + self.error_message = entry.error_message + self.error_type = entry.error_type + self.timestamp = datetime.now(tz=timezone.utc) def get_scenario_result(self) -> ScenarioResult: @@ -1104,6 +1111,8 @@ def get_scenario_result(self) -> ScenarioResult: completion_time=self.completion_time, display_group_map=display_group_map, error_attack_result_ids=error_attack_result_ids, + error_message=self.error_message, + error_type=self.error_type, ) def get_conversation_ids_by_attack_name(self) -> dict[str, list[str]]: diff --git a/pyrit/models/scenario_result.py b/pyrit/models/scenario_result.py index 6a657a1d36..88a67f5991 100644 --- a/pyrit/models/scenario_result.py +++ b/pyrit/models/scenario_result.py @@ -70,6 +70,8 @@ def __init__( id: uuid.UUID | None = None, # noqa: A002 display_group_map: dict[str, str] | None = None, error_attack_result_ids: list[str] | None = None, + error_message: str | None = None, + error_type: str | None = None, ) -> None: """ Initialize a scenario result. @@ -91,6 +93,8 @@ def __init__( error_attack_result_ids (Optional[list[str]]): IDs of AttackResults that contain error information. Used for quick error lookup without scanning all attack results. + error_message (Optional[str]): Scenario-level error message when the run fails. + error_type (Optional[str]): Exception class name when the run fails. """ self.id = id if id is not None else uuid.uuid4() @@ -108,6 +112,8 @@ def __init__( self.number_tries = number_tries self._display_group_map = display_group_map or {} self.error_attack_result_ids = error_attack_result_ids or [] + self.error_message = error_message + self.error_type = error_type @property def display_group_map(self) -> dict[str, str]: diff --git a/pyrit/scenario/core/scenario.py b/pyrit/scenario/core/scenario.py index 34618cba1c..7a30cb5a1c 100644 --- a/pyrit/scenario/core/scenario.py +++ b/pyrit/scenario/core/scenario.py @@ -1110,17 +1110,20 @@ async def _execute_scenario_async(self) -> ScenarioResult: ) # Mark scenario as failed + error_msg = ( + f"Atomic attack '{atomic_attack.atomic_attack_name}' partially failed: " + f"{incomplete_count} of {incomplete_count + completed_count} objectives incomplete. " + f"See attack results for details." + ) self._memory.update_scenario_run_state( scenario_result_id=scenario_result_id, scenario_run_state="FAILED", + error_message=error_msg, + error_type=type(atomic_results.incomplete_objectives[0][1]).__name__, ) # Raise exception with detailed information - raise ValueError( - f"Failed to execute atomic attack {i} ('{atomic_attack.atomic_attack_name}') " - f"in scenario '{self._name}': {incomplete_count} of {incomplete_count + completed_count} " - f"objectives incomplete. First failure: {atomic_results.incomplete_objectives[0][1]}" - ) from atomic_results.incomplete_objectives[0][1] + raise ValueError(error_msg) from atomic_results.incomplete_objectives[0][1] logger.info( f"Atomic attack {i}/{len(self._atomic_attacks)} completed successfully with " f"{len(atomic_results.completed_results)} results" @@ -1139,6 +1142,8 @@ async def _execute_scenario_async(self) -> ScenarioResult: self._memory.update_scenario_run_state( scenario_result_id=scenario_result_id, scenario_run_state="FAILED", + error_message=str(e), + error_type=type(e).__name__, ) raise diff --git a/tests/unit/backend/test_mappers.py b/tests/unit/backend/test_mappers.py index ec1f403c9a..d3aef9a4ed 100644 --- a/tests/unit/backend/test_mappers.py +++ b/tests/unit/backend/test_mappers.py @@ -365,6 +365,40 @@ def test_created_at_falls_back_to_now_when_both_absent(self) -> None: assert before <= summary.created_at <= after + def test_retry_events_mapped_to_response(self) -> None: + """Test that retry events on an AttackResult are mapped to RetryEventResponse DTOs.""" + from pyrit.models.retry_event import RetryEvent + + now = datetime.now(timezone.utc) + ar = _make_attack_result() + ar.retry_events = [ + RetryEvent( + timestamp=now, + attempt_number=1, + function_name="send_prompt_async", + exception_type="RateLimitError", + exception_message="rate limit exceeded", + component_role="objective_target", + component_name="OpenAIChatTarget", + endpoint="https://api.openai.com", + elapsed_seconds=1.5, + ), + ] + ar.total_retries = 1 + + stats = ConversationStats(message_count=0) + summary = attack_result_to_summary(ar, stats=stats) + + assert summary.retry_events is not None + assert len(summary.retry_events) == 1 + evt = summary.retry_events[0] + assert evt.attempt_number == 1 + assert evt.function_name == "send_prompt_async" + assert evt.exception_type == "RateLimitError" + assert evt.component_role == "objective_target" + assert evt.elapsed_seconds == 1.5 + assert summary.total_retries == 1 + """Tests for pyrit_scores_to_dto function.""" def test_maps_scores(self) -> None: diff --git a/tests/unit/exceptions/test_retry_collector.py b/tests/unit/exceptions/test_retry_collector.py index 14db749b65..bc338fcc23 100644 --- a/tests/unit/exceptions/test_retry_collector.py +++ b/tests/unit/exceptions/test_retry_collector.py @@ -110,3 +110,40 @@ async def run() -> None: asyncio.run(run()) assert results.get("a_has_collector") is True assert results.get("b_sees_none") is True + + def test_record_extracts_execution_context(self) -> None: + """record() extracts component_role, component_name, and endpoint from ExecutionContext.""" + from unittest.mock import MagicMock + + from pyrit.exceptions.exception_context import ( + ComponentRole, + ExecutionContext, + set_execution_context, + ) + + c = RetryCollector() + + ctx = ExecutionContext( + component_role=ComponentRole.OBJECTIVE_TARGET, + component_name="OpenAIChatTarget", + endpoint="https://api.openai.com", + ) + set_execution_context(ctx) + + retry_state = MagicMock() + retry_state.attempt_number = 1 + retry_state.start_time = 0.0 + retry_state.fn = MagicMock() + retry_state.fn.__name__ = "send_prompt_async" + outcome = MagicMock() + outcome.failed = True + outcome.exception.return_value = RuntimeError("timeout") + retry_state.outcome = outcome + + c.record(retry_state=retry_state) + + assert len(c.events) == 1 + evt = c.events[0] + assert evt.component_role == "objective_target" + assert evt.component_name == "OpenAIChatTarget" + assert evt.endpoint == "https://api.openai.com" diff --git a/tests/unit/executor/attack/core/test_attack_strategy.py b/tests/unit/executor/attack/core/test_attack_strategy.py index 2387294449..33370f0c9d 100644 --- a/tests/unit/executor/attack/core/test_attack_strategy.py +++ b/tests/unit/executor/attack/core/test_attack_strategy.py @@ -399,6 +399,26 @@ async def test_on_post_execute_logs_undetermined( expected_message = f"{event_handler.__class__.__name__} outcome is undetermined. Reason: Not specified" mock_logger.info.assert_called_with(expected_message) + async def test_on_post_execute_logs_error_outcome( + self, event_handler, sample_attack_context, sample_attack_result, mock_logger + ): + """Test that post-execute handler logs error outcome""" + sample_attack_result.outcome = AttackOutcome.ERROR + sample_attack_result.outcome_reason = "Connection timeout" + + event_data = StrategyEventData( + event=StrategyEvent.ON_POST_EXECUTE, + strategy_name="TestStrategy", + strategy_id="test-id", + context=sample_attack_context, + result=sample_attack_result, + ) + + await event_handler.on_event(event_data) + + expected_message = f"{event_handler.__class__.__name__} failed with an error. Reason: Connection timeout" + mock_logger.info.assert_called_with(expected_message) + async def test_on_post_execute_adds_results_to_memory(self, mock_memory): """Test that post-execute handler adds results to memory""" with patch("pyrit.memory.central_memory.CentralMemory.get_memory_instance", return_value=mock_memory): @@ -447,12 +467,12 @@ async def test_on_post_execute_attaches_retry_events( handler = _DefaultAttackStrategyEventHandler() sample_attack_context.start_time = 100.0 - retry_event = RetryEvent(attempt_number=1, function_name="send_prompt_async", exception_type="RateLimitError") + retry_event = RetryEvent( + attempt_number=1, function_name="send_prompt_async", exception_type="RateLimitError" + ) collector = RetryCollector(events=[retry_event]) - with patch( - "pyrit.executor.attack.core.attack_strategy.get_retry_collector", return_value=collector - ): + with patch("pyrit.executor.attack.core.attack_strategy.get_retry_collector", return_value=collector): event_data = StrategyEventData( event=StrategyEvent.ON_POST_EXECUTE, strategy_name="TestStrategy", @@ -474,9 +494,7 @@ async def test_on_post_execute_no_retry_events_when_collector_empty( sample_attack_context.start_time = 100.0 collector = RetryCollector(events=[]) - with patch( - "pyrit.executor.attack.core.attack_strategy.get_retry_collector", return_value=collector - ): + with patch("pyrit.executor.attack.core.attack_strategy.get_retry_collector", return_value=collector): event_data = StrategyEventData( event=StrategyEvent.ON_POST_EXECUTE, strategy_name="TestStrategy", @@ -499,9 +517,7 @@ async def test_on_error_attaches_retry_events(self, sample_attack_context, mock_ retry_event = RetryEvent(attempt_number=2, function_name="send_prompt_async", exception_type="TimeoutError") collector = RetryCollector(events=[retry_event]) - with patch( - "pyrit.executor.attack.core.attack_strategy.get_retry_collector", return_value=collector - ): + with patch("pyrit.executor.attack.core.attack_strategy.get_retry_collector", return_value=collector): event_data = StrategyEventData( event=StrategyEvent.ON_ERROR, strategy_name="TestStrategy", @@ -523,9 +539,7 @@ async def test_on_error_empty_retry_events_when_no_collector(self, sample_attack sample_attack_context.start_time = 100.0 - with patch( - "pyrit.executor.attack.core.attack_strategy.get_retry_collector", return_value=None - ): + with patch("pyrit.executor.attack.core.attack_strategy.get_retry_collector", return_value=None): event_data = StrategyEventData( event=StrategyEvent.ON_ERROR, strategy_name="TestStrategy", @@ -547,9 +561,7 @@ async def test_on_error_persists_result_to_memory(self, sample_attack_context, m sample_attack_context.start_time = 100.0 error = ValueError("something broke") - with patch( - "pyrit.executor.attack.core.attack_strategy.get_retry_collector", return_value=None - ): + with patch("pyrit.executor.attack.core.attack_strategy.get_retry_collector", return_value=None): event_data = StrategyEventData( event=StrategyEvent.ON_ERROR, strategy_name="TestStrategy", diff --git a/tests/unit/memory/memory_interface/test_interface_scenario_results.py b/tests/unit/memory/memory_interface/test_interface_scenario_results.py index 0fec3aceaa..97a240e37c 100644 --- a/tests/unit/memory/memory_interface/test_interface_scenario_results.py +++ b/tests/unit/memory/memory_interface/test_interface_scenario_results.py @@ -648,6 +648,60 @@ def test_combined_filters(sqlite_instance: MemoryInterface): assert "gpt-4" in results[0].objective_target_identifier.params["model_name"] +def test_update_scenario_error_attacks_success(sqlite_instance: MemoryInterface, sample_attack_results): + """Test successfully linking error attack result IDs to a scenario result.""" + scenario_result = create_scenario_result( + name="Error Scenario", + attack_results={"Attack1": [sample_attack_results[0]]}, + ) + sqlite_instance.add_scenario_results_to_memory(scenario_results=[scenario_result]) + + error_ids = ["error-ar-1", "error-ar-2"] + sqlite_instance.update_scenario_error_attacks( + scenario_result_id=str(scenario_result.id), + error_attack_result_ids=error_ids, + ) + + # Verify the error IDs were persisted + results = sqlite_instance.get_scenario_results(scenario_result_ids=[str(scenario_result.id)]) + assert len(results) == 1 + assert results[0].error_attack_result_ids == error_ids + + +def test_update_scenario_error_attacks_appends_to_existing(sqlite_instance: MemoryInterface, sample_attack_results): + """Test that updating error attacks appends to existing IDs without duplicates.""" + scenario_result = create_scenario_result( + name="Error Scenario", + attack_results={"Attack1": [sample_attack_results[0]]}, + ) + sqlite_instance.add_scenario_results_to_memory(scenario_results=[scenario_result]) + + # First update + sqlite_instance.update_scenario_error_attacks( + scenario_result_id=str(scenario_result.id), + error_attack_result_ids=["error-ar-1"], + ) + + # Second update with overlap and new ID + sqlite_instance.update_scenario_error_attacks( + scenario_result_id=str(scenario_result.id), + error_attack_result_ids=["error-ar-1", "error-ar-2"], + ) + + results = sqlite_instance.get_scenario_results(scenario_result_ids=[str(scenario_result.id)]) + # Should be deduplicated: ["error-ar-1", "error-ar-2"] + assert results[0].error_attack_result_ids == ["error-ar-1", "error-ar-2"] + + +def test_update_scenario_error_attacks_not_found(sqlite_instance: MemoryInterface): + """Test that updating a nonexistent scenario result raises ValueError.""" + with pytest.raises(ValueError, match="not found in memory"): + sqlite_instance.update_scenario_error_attacks( + scenario_result_id="nonexistent-id", + error_attack_result_ids=["error-ar-1"], + ) + + def test_get_scenario_results_by_target_identifier_filter_hash(sqlite_instance: MemoryInterface): """Test filtering scenario results by identifier filter.""" target_id_1 = ComponentIdentifier( diff --git a/tests/unit/memory/test_memory_models.py b/tests/unit/memory/test_memory_models.py index f143a69e98..b382bf8646 100644 --- a/tests/unit/memory/test_memory_models.py +++ b/tests/unit/memory/test_memory_models.py @@ -474,3 +474,17 @@ def test_init_with_empty_attack_results(self): entry = ScenarioResultEntry(entry=sr) conv_ids = entry.get_conversation_ids_by_attack_name() assert conv_ids == {} + + def test_roundtrip_error_attack_result_ids(self): + sr = self._make_scenario_result(error_attack_result_ids=["err-1", "err-2"]) + entry = ScenarioResultEntry(entry=sr) + assert entry.error_attack_result_ids_json is not None + recovered = entry.get_scenario_result() + assert recovered.error_attack_result_ids == ["err-1", "err-2"] + + def test_roundtrip_error_attack_result_ids_none(self): + sr = self._make_scenario_result() + entry = ScenarioResultEntry(entry=sr) + assert entry.error_attack_result_ids_json is None + recovered = entry.get_scenario_result() + assert recovered.error_attack_result_ids == [] From a31088cb863788b7138f9852ab2fd4ab1f1c6a5d Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Tue, 12 May 2026 16:56:18 -0700 Subject: [PATCH 9/9] test fix --- pyrit/backend/services/scenario_run_service.py | 7 +++++++ tests/unit/backend/test_scenario_run_service.py | 8 +++++++- 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/pyrit/backend/services/scenario_run_service.py b/pyrit/backend/services/scenario_run_service.py index 41843c2639..fdcb285c6d 100644 --- a/pyrit/backend/services/scenario_run_service.py +++ b/pyrit/backend/services/scenario_run_service.py @@ -409,6 +409,13 @@ def _build_response_from_db(self, *, scenario_result: ScenarioResult) -> Scenari error = scenario_result.error_message error_type = scenario_result.error_type + # Fallback: look up error from persisted error AttackResults + if not error and scenario_result.error_attack_result_ids: + error_ars = self._memory.get_attack_results(attack_result_ids=scenario_result.error_attack_result_ids) + if error_ars: + error = error_ars[0].error_message + error_type = error_ars[0].error_type + # Fallback: in-memory error for in-flight tasks where DB hasn't been updated yet if not error and active is not None: error = active.error diff --git a/tests/unit/backend/test_scenario_run_service.py b/tests/unit/backend/test_scenario_run_service.py index 29551b3cae..e29686599a 100644 --- a/tests/unit/backend/test_scenario_run_service.py +++ b/tests/unit/backend/test_scenario_run_service.py @@ -72,6 +72,9 @@ def _make_db_scenario_result( sr.objective_achieved_rate.return_value = 0 sr.get_display_groups.return_value = {} sr.display_group_map = {} + sr.error_message = None + sr.error_type = None + sr.error_attack_result_ids = [] return sr @@ -371,7 +374,10 @@ async def test_cancel_run_sets_cancelled_status(self, mock_all_registries) -> No result = await service.cancel_run_async(scenario_result_id=response.scenario_result_id) mock_memory.update_scenario_run_state.assert_called_once_with( - scenario_result_id=response.scenario_result_id, scenario_run_state="CANCELLED" + scenario_result_id=response.scenario_result_id, + scenario_run_state="CANCELLED", + error_message="Run was cancelled by user", + error_type="CancelledError", ) assert result is not None assert result.status == ScenarioRunStatus.CANCELLED