diff --git a/frontend/src/components/History/AttackTable.tsx b/frontend/src/components/History/AttackTable.tsx index 30d52998b6..16d052d7c7 100644 --- a/frontend/src/components/History/AttackTable.tsx +++ b/frontend/src/components/History/AttackTable.tsx @@ -16,6 +16,7 @@ import { CheckmarkCircleRegular, DismissCircleRegular, QuestionCircleRegular, + ErrorCircleRegular, } from '@fluentui/react-icons' import type { AttackSummary } from '../../types' import { useAttackHistoryStyles } from './AttackHistory.styles' @@ -23,12 +24,14 @@ import { useAttackHistoryStyles } from './AttackHistory.styles' const OUTCOME_ICONS: Record = { success: , failure: , + error: , undetermined: , } -const OUTCOME_COLORS: Record = { +const OUTCOME_COLORS: Record = { success: 'success', failure: 'danger', + error: 'warning', undetermined: 'informative', } diff --git a/frontend/src/components/History/HistoryFiltersBar.tsx b/frontend/src/components/History/HistoryFiltersBar.tsx index 053fa8a7d3..7165cd3960 100644 --- a/frontend/src/components/History/HistoryFiltersBar.tsx +++ b/frontend/src/components/History/HistoryFiltersBar.tsx @@ -21,6 +21,7 @@ const NO_CONVERTERS_SENTINEL = '__no_converters__' const OUTCOME_LABELS: Record = { success: 'Success', failure: 'Failure', + error: 'Error', undetermined: 'Undetermined', } @@ -200,6 +201,7 @@ export default function HistoryFiltersBar({ + | null target?: TargetInfo | null converters: string[] - outcome?: 'undetermined' | 'success' | 'failure' | null + outcome?: 'undetermined' | 'success' | 'failure' | 'error' | null last_message_preview?: string | null message_count: number related_conversation_ids: string[] diff --git a/pyrit/analytics/result_analysis.py b/pyrit/analytics/result_analysis.py index a403d1aa37..d2e998ee94 100644 --- a/pyrit/analytics/result_analysis.py +++ b/pyrit/analytics/result_analysis.py @@ -17,9 +17,10 @@ class AttackStats: successes: int failures: int undetermined: int + errors: int -def _compute_stats(successes: int, failures: int, undetermined: int) -> AttackStats: +def _compute_stats(successes: int, failures: int, undetermined: int, errors: int) -> AttackStats: total_decided = successes + failures success_rate = successes / total_decided if total_decided > 0 else None return AttackStats( @@ -28,6 +29,7 @@ def _compute_stats(successes: int, failures: int, undetermined: int) -> AttackSt successes=successes, failures=failures, undetermined=undetermined, + errors=errors, ) @@ -71,6 +73,9 @@ def analyze_results(attack_results: list[AttackResult]) -> dict[str, AttackStats elif outcome == AttackOutcome.FAILURE: overall_counts["failures"] += 1 by_type_counts[attack_type]["failures"] += 1 + elif outcome == AttackOutcome.ERROR: + overall_counts["errors"] += 1 + by_type_counts[attack_type]["errors"] += 1 else: overall_counts["undetermined"] += 1 by_type_counts[attack_type]["undetermined"] += 1 @@ -79,6 +84,7 @@ def analyze_results(attack_results: list[AttackResult]) -> dict[str, AttackStats successes=overall_counts["successes"], failures=overall_counts["failures"], undetermined=overall_counts["undetermined"], + errors=overall_counts["errors"], ) by_type_stats = { @@ -86,6 +92,7 @@ def analyze_results(attack_results: list[AttackResult]) -> dict[str, AttackStats successes=counts["successes"], failures=counts["failures"], undetermined=counts["undetermined"], + errors=counts["errors"], ) for attack_type, counts in by_type_counts.items() } diff --git a/pyrit/backend/mappers/attack_mappers.py b/pyrit/backend/mappers/attack_mappers.py index 9d6ffac1ae..9fe88ecf6c 100644 --- a/pyrit/backend/mappers/attack_mappers.py +++ b/pyrit/backend/mappers/attack_mappers.py @@ -30,6 +30,7 @@ Message, MessagePiece, MessagePieceRequest, + RetryEventResponse, Score, TargetInfo, ) @@ -43,6 +44,7 @@ if TYPE_CHECKING: from pyrit.models.conversation_stats import ConversationStats + from pyrit.models.retry_event import RetryEvent # ============================================================================ # Domain → DTO (for API responses) @@ -181,6 +183,34 @@ def _resolve_media_url(*, value: Optional[str], data_type: str) -> Optional[str] return value +def retry_events_to_response(retry_events: list[RetryEvent] | None) -> list[RetryEventResponse] | None: + """ + Convert a list of RetryEvent domain objects to RetryEventResponse DTOs. + + Args: + retry_events: Domain retry events, or None. + + Returns: + List of RetryEventResponse DTOs, or None if the input is None or empty. + """ + if not retry_events: + return None + return [ + RetryEventResponse( + timestamp=evt.timestamp, + attempt_number=evt.attempt_number, + function_name=evt.function_name, + exception_type=evt.exception_type, + exception_message=evt.exception_message, + component_role=evt.component_role, + component_name=evt.component_name, + endpoint=evt.endpoint, + elapsed_seconds=evt.elapsed_seconds, + ) + for evt in retry_events + ] + + def attack_result_to_summary( ar: AttackResult, *, @@ -232,6 +262,9 @@ def attack_result_to_summary( else None ) + # Build retry event responses if available + retry_event_responses = retry_events_to_response(ar.retry_events) + return AttackSummary( attack_result_id=ar.attack_result_id, conversation_id=ar.conversation_id, @@ -246,6 +279,11 @@ def attack_result_to_summary( labels=labels, created_at=created_at, updated_at=updated_at, + error_message=ar.error_message, + error_type=ar.error_type, + error_traceback=ar.error_traceback, + total_retries=ar.total_retries, + retry_events=retry_event_responses, ) diff --git a/pyrit/backend/models/attacks.py b/pyrit/backend/models/attacks.py index 64d6269e44..2f98f78b7e 100644 --- a/pyrit/backend/models/attacks.py +++ b/pyrit/backend/models/attacks.py @@ -90,6 +90,20 @@ class TargetInfo(BaseModel): model_name: Optional[str] = Field(None, description="Model or deployment name") +class RetryEventResponse(BaseModel): + """A single retry attempt captured during execution.""" + + timestamp: datetime = Field(..., description="When the retry occurred") + attempt_number: int = Field(..., ge=1, description="Tenacity attempt number (1-based)") + function_name: str = Field(..., description="The retried function name") + exception_type: str = Field("", description="Exception class name") + exception_message: str = Field("", description="Exception message") + component_role: str = Field("", description="Component role from ExecutionContext") + component_name: str | None = Field(None, description="Component class name") + endpoint: str | None = Field(None, description="Target endpoint URL") + elapsed_seconds: float = Field(0.0, ge=0, description="Time since first attempt in seconds") + + class AttackSummary(BaseModel): """Summary view of an attack (for list views, omits full message content).""" @@ -102,7 +116,7 @@ class AttackSummary(BaseModel): default_factory=list, description="Request converter class names applied in this attack" ) objective: str = Field("", description="Natural-language description of the attacker's objective") - outcome: Optional[Literal["undetermined", "success", "failure"]] = Field( + outcome: Optional[Literal["undetermined", "success", "failure", "error"]] = Field( None, description="Attack outcome (null if not yet determined)" ) outcome_reason: str | None = Field(None, description="Reason for the outcome") @@ -121,6 +135,17 @@ class AttackSummary(BaseModel): created_at: datetime = Field(..., description="Attack creation timestamp") updated_at: datetime = Field(..., description="Last update timestamp") + # Error information + error_message: str | None = Field(None, description="Error message if the attack failed with an exception") + error_type: str | None = Field(None, description="Exception class name (e.g., 'RateLimitError')") + error_traceback: str | None = Field(None, description="Formatted traceback string") + + # Retry information + total_retries: int = Field(0, ge=0, description="Total number of retries during this attack") + retry_events: list[RetryEventResponse] | None = Field( + None, description="Detailed retry events (omitted in list views unless requested)" + ) + # ============================================================================ # Conversation Messages Response @@ -236,7 +261,7 @@ class CreateAttackResponse(BaseModel): class UpdateAttackRequest(BaseModel): """Request to update an attack's outcome.""" - outcome: Literal["undetermined", "success", "failure"] = Field(..., description="Updated attack outcome") + outcome: Literal["undetermined", "success", "failure", "error"] = Field(..., description="Updated attack outcome") # ============================================================================ diff --git a/pyrit/backend/models/scenarios.py b/pyrit/backend/models/scenarios.py index e628020c2f..4dd8d31a84 100644 --- a/pyrit/backend/models/scenarios.py +++ b/pyrit/backend/models/scenarios.py @@ -98,6 +98,7 @@ class ScenarioRunSummary(BaseModel): created_at: datetime = Field(..., description="When the run was created") updated_at: datetime = Field(..., description="When the run status last changed") error: str | None = Field(None, description="Error message if status is FAILED") + error_type: str | None = Field(None, description="Exception class name if status is FAILED") strategies_used: list[str] = Field(default_factory=list, description="Strategy names that were executed") total_attacks: int = Field(0, ge=0, description="Total number of atomic attacks") completed_attacks: int = Field(0, ge=0, description="Number of attacks that completed") @@ -126,6 +127,8 @@ class AtomicAttackResults(BaseModel): success_count: int = Field(0, ge=0, description="Number of successful attacks") failure_count: int = Field(0, ge=0, description="Number of failed attacks") total_count: int = Field(0, ge=0, description="Total number of attack results") + total_retries: int = Field(0, ge=0, description="Sum of retries across all attacks in this group") + error_count: int = Field(0, ge=0, description="Number of attacks with errors") class ScenarioRunDetail(BaseModel): diff --git a/pyrit/backend/routes/attacks.py b/pyrit/backend/routes/attacks.py index 83f8ece989..38bb6991a0 100644 --- a/pyrit/backend/routes/attacks.py +++ b/pyrit/backend/routes/attacks.py @@ -93,7 +93,9 @@ async def list_attacks( description="Filter by converter presence. true = attacks with at least one converter; " "false = attacks with no converters. Omit for no filter.", ), - outcome: Optional[Literal["undetermined", "success", "failure"]] = Query(None, description="Filter by outcome"), + outcome: Optional[Literal["undetermined", "success", "failure", "error"]] = Query( + None, description="Filter by outcome" + ), label: Optional[list[str]] = Query( None, description="Filter by labels (format: key:value). May be specified multiple times; " diff --git a/pyrit/backend/services/attack_service.py b/pyrit/backend/services/attack_service.py index f2da9ab7e9..5cfd83a7ae 100644 --- a/pyrit/backend/services/attack_service.py +++ b/pyrit/backend/services/attack_service.py @@ -86,7 +86,7 @@ async def list_attacks_async( converter_types: Optional[Sequence[str]] = None, converter_types_match: Literal["any", "all"] = "all", has_converters: Optional[bool] = None, - outcome: Optional[Literal["undetermined", "success", "failure"]] = None, + outcome: Optional[Literal["undetermined", "success", "failure", "error"]] = None, labels: Optional[dict[str, str | Sequence[str]]] = None, min_turns: Optional[int] = None, max_turns: Optional[int] = None, @@ -370,6 +370,7 @@ async def update_attack_async( "undetermined": AttackOutcome.UNDETERMINED, "success": AttackOutcome.SUCCESS, "failure": AttackOutcome.FAILURE, + "error": AttackOutcome.ERROR, } new_outcome = outcome_map.get(request.outcome, AttackOutcome.UNDETERMINED) diff --git a/pyrit/backend/services/scenario_run_service.py b/pyrit/backend/services/scenario_run_service.py index 26f9b21f60..fdcb285c6d 100644 --- a/pyrit/backend/services/scenario_run_service.py +++ b/pyrit/backend/services/scenario_run_service.py @@ -15,6 +15,7 @@ from datetime import datetime, timezone from typing import TYPE_CHECKING, Any +from pyrit.backend.mappers.attack_mappers import retry_events_to_response from pyrit.backend.models.scenarios import ( AtomicAttackResults, AttackSummary, @@ -183,7 +184,12 @@ async def cancel_run_async(self, *, scenario_result_id: str) -> ScenarioRunSumma await asyncio.wait_for(active.task, timeout=5.0) # Persist cancelled state to DB - self._memory.update_scenario_run_state(scenario_result_id=scenario_result_id, scenario_run_state="CANCELLED") + self._memory.update_scenario_run_state( + scenario_result_id=scenario_result_id, + scenario_run_state="CANCELLED", + error_message="Run was cancelled by user", + error_type="CancelledError", + ) return self._build_response(scenario_result_id=scenario_result_id) @@ -395,12 +401,24 @@ def _build_response_from_db(self, *, scenario_result: ScenarioResult) -> Scenari scenario_result_id = str(scenario_result.id) active = self._active_tasks.get(scenario_result_id) - # Clean up finished active tasks after reading the error - error = None - if active is not None: + # Clean up finished active tasks + if active is not None and active.task is not None and active.task.done(): + del self._active_tasks[scenario_result_id] + + # Primary source: DB-persisted error fields + error = scenario_result.error_message + error_type = scenario_result.error_type + + # Fallback: look up error from persisted error AttackResults + if not error and scenario_result.error_attack_result_ids: + error_ars = self._memory.get_attack_results(attack_result_ids=scenario_result.error_attack_result_ids) + if error_ars: + error = error_ars[0].error_message + error_type = error_ars[0].error_type + + # Fallback: in-memory error for in-flight tasks where DB hasn't been updated yet + if not error and active is not None: error = active.error - if active.task is not None and active.task.done(): - del self._active_tasks[scenario_result_id] status = ScenarioRunStatus(scenario_result.scenario_run_state) @@ -426,6 +444,7 @@ def _build_response_from_db(self, *, scenario_result: ScenarioResult) -> Scenari created_at=scenario_result.creation_time, updated_at=scenario_result.completion_time, error=error, + error_type=error_type, strategies_used=strategies_used, total_attacks=total_attacks, completed_attacks=completed_attacks, @@ -467,6 +486,8 @@ def get_run_results(self, *, scenario_result_id: str) -> ScenarioRunDetail | Non details: list[AttackSummary] = [] success_count = 0 failure_count = 0 + group_total_retries = 0 + group_error_count = 0 for ar in attack_results: score_value = None @@ -478,6 +499,16 @@ def get_run_results(self, *, scenario_result_id: str) -> ScenarioRunDetail | Non last_response_text = str(ar.last_response) timestamp = ar.timestamp or datetime.now(timezone.utc) + + # Build retry event responses using the shared mapper + retry_event_responses = retry_events_to_response(ar.retry_events) + + # Extract error/retry fields + ar_error_message = ar.error_message + ar_error_type = ar.error_type + ar_error_traceback = ar.error_traceback + ar_total_retries = ar.total_retries + details.append( AttackSummary( attack_result_id=ar.attack_result_id, @@ -491,6 +522,11 @@ def get_run_results(self, *, scenario_result_id: str) -> ScenarioRunDetail | Non execution_time_ms=ar.execution_time_ms, created_at=timestamp, updated_at=timestamp, + error_message=ar_error_message, + error_type=ar_error_type, + error_traceback=ar_error_traceback, + total_retries=ar_total_retries, + retry_events=retry_event_responses, ) ) @@ -499,6 +535,10 @@ def get_run_results(self, *, scenario_result_id: str) -> ScenarioRunDetail | Non elif ar.outcome == AttackOutcome.FAILURE: failure_count += 1 + group_total_retries += ar_total_retries + if ar_error_message: + group_error_count += 1 + attacks.append( AtomicAttackResults( atomic_attack_name=attack_name, @@ -507,6 +547,8 @@ def get_run_results(self, *, scenario_result_id: str) -> ScenarioRunDetail | Non success_count=success_count, failure_count=failure_count, total_count=len(details), + total_retries=group_total_retries, + error_count=group_error_count, ) ) diff --git a/pyrit/exceptions/__init__.py b/pyrit/exceptions/__init__.py index 9da9650c2d..abd42de031 100644 --- a/pyrit/exceptions/__init__.py +++ b/pyrit/exceptions/__init__.py @@ -27,15 +27,23 @@ set_execution_context, ) from pyrit.exceptions.exceptions_helpers import remove_markdown_json +from pyrit.exceptions.retry_collector import ( + RetryCollector, + clear_retry_collector, + get_retry_collector, + set_retry_collector, +) __all__ = [ "BadRequestException", "clear_execution_context", + "clear_retry_collector", "ComponentRole", "EmptyResponseException", "ExecutionContext", "ExecutionContextManager", "get_execution_context", + "get_retry_collector", "get_retry_max_num_attempts", "handle_bad_request_exception", "InvalidJsonException", @@ -47,6 +55,8 @@ "pyrit_placeholder_retry", "RateLimitException", "remove_markdown_json", + "RetryCollector", "set_execution_context", + "set_retry_collector", "execution_context", ] diff --git a/pyrit/exceptions/exceptions_helpers.py b/pyrit/exceptions/exceptions_helpers.py index 65a7267daf..1ee2f8e9cf 100644 --- a/pyrit/exceptions/exceptions_helpers.py +++ b/pyrit/exceptions/exceptions_helpers.py @@ -9,6 +9,7 @@ from tenacity import RetryCallState from pyrit.exceptions.exception_context import get_execution_context +from pyrit.exceptions.retry_collector import get_retry_collector logger = logging.getLogger(__name__) @@ -67,6 +68,9 @@ def log_exception(retry_state: RetryCallState) -> None: f"failed with exception: {exception}.{endpoint_clause} " f"Elapsed time: {elapsed_time} seconds. Total calls: {call_count}" ) + collector = get_retry_collector() + if collector: + collector.record(retry_state=retry_state) def remove_start_md_json(response_msg: str) -> str: diff --git a/pyrit/exceptions/retry_collector.py b/pyrit/exceptions/retry_collector.py new file mode 100644 index 0000000000..0710d40272 --- /dev/null +++ b/pyrit/exceptions/retry_collector.py @@ -0,0 +1,102 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +"""Contextvar-based retry event collector for capturing Tenacity retry events.""" + +import time +from contextvars import ContextVar +from dataclasses import dataclass, field + +from tenacity import RetryCallState + +from pyrit.exceptions.exception_context import get_execution_context +from pyrit.models.retry_event import RetryEvent + + +@dataclass +class RetryCollector: + """ + Collects retry events during attack execution. + + Uses contextvar for thread/task-safe scoping. Each attack execution + creates its own collector so retry events are naturally scoped + per-objective. + """ + + events: list[RetryEvent] = field(default_factory=list) + + def record(self, *, retry_state: RetryCallState) -> None: + """ + Record a retry event from a Tenacity RetryCallState. + + Extracts information from the retry state and the current + ExecutionContext to build a structured RetryEvent. + + Args: + retry_state (RetryCallState): The Tenacity retry call state from the after callback. + """ + elapsed = time.monotonic() - retry_state.start_time + fn_name = retry_state.fn.__name__ if retry_state.fn is not None else "unknown" + + # Extract exception info + exception_type = "" + exception_message = "" + outcome = retry_state.outcome + if outcome is not None and outcome.failed: + exc = outcome.exception() + if exc: + exception_type = type(exc).__name__ + exception_message = str(exc) + + # Extract context info + component_role = "" + component_name: str | None = None + endpoint: str | None = None + try: + exec_context = get_execution_context() + if exec_context: + component_role = exec_context.component_role.value + component_name = exec_context.component_name + endpoint = exec_context.endpoint + except Exception: + pass + + event = RetryEvent( + attempt_number=retry_state.attempt_number, + function_name=fn_name, + exception_type=exception_type, + exception_message=exception_message, + component_role=component_role, + component_name=component_name, + endpoint=endpoint, + elapsed_seconds=round(elapsed, 3), + ) + self.events.append(event) + + +_retry_collector: ContextVar[RetryCollector | None] = ContextVar("retry_collector", default=None) + + +def get_retry_collector() -> RetryCollector | None: + """ + Get the current retry collector. + + Returns: + The active RetryCollector, or None if not set. + """ + return _retry_collector.get() + + +def set_retry_collector(collector: RetryCollector) -> None: + """ + Set the current retry collector. + + Args: + collector: The RetryCollector to activate. + """ + _retry_collector.set(collector) + + +def clear_retry_collector() -> None: + """Clear the current retry collector.""" + _retry_collector.set(None) diff --git a/pyrit/executor/attack/core/attack_strategy.py b/pyrit/executor/attack/core/attack_strategy.py index cbe6441c2e..a054ace813 100644 --- a/pyrit/executor/attack/core/attack_strategy.py +++ b/pyrit/executor/attack/core/attack_strategy.py @@ -6,11 +6,16 @@ import dataclasses import logging # noqa: TC003 import time +import traceback +import uuid from abc import ABC from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any, ClassVar, Generic, Optional, TypeVar, Union, overload from pyrit.common.logger import logger +from pyrit.exceptions.retry_collector import ( + get_retry_collector, +) from pyrit.executor.attack.core.attack_parameters import AttackParameters, AttackParamsT from pyrit.executor.core import ( Strategy, @@ -65,6 +70,9 @@ class AttackContext(StrategyContext, ABC, Generic[AttackParamsT]): _prepended_conversation_override: Optional[list[Message]] = None _memory_labels_override: Optional[dict[str, str]] = None + # Set by the ON_ERROR handler to link error AttackResults to ScenarioResults + _error_attack_result_id: str | None = None + # Convenience properties that delegate to params or overrides @property def objective(self) -> str: @@ -134,6 +142,7 @@ def __init__(self, logger: logging.Logger = logger) -> None: self._events = { StrategyEvent.ON_PRE_EXECUTE: self._on_pre_execute, StrategyEvent.ON_POST_EXECUTE: self._on_post_execute, + StrategyEvent.ON_ERROR: self._on_error_async, } self._memory = CentralMemory.get_memory_instance() @@ -167,6 +176,9 @@ async def _on_pre_execute( """ Handle pre-execution logic before the attack strategy runs. + Sets up execution timing and starts a RetryCollector to capture + retry events during execution. + Args: event_data (StrategyEventData[AttackStrategyContextT, AttackStrategyResultT]): The event data containing context and result. @@ -189,6 +201,8 @@ async def _on_post_execute( """ Handle post-execution logic after the attack strategy has run. + Attaches retry events to the result and persists it to memory. + Args: event_data (StrategyEventData[AttackStrategyContextT, AttackStrategyResultT]): The event data containing context and result. @@ -203,6 +217,12 @@ async def _on_post_execute( execution_time_ms = int((end_time - event_data.context.start_time) * 1000) event_data.result.execution_time_ms = execution_time_ms + # Attach collected retry events to the result + collector = get_retry_collector() + if collector and collector.events: + event_data.result.retry_events = collector.events + event_data.result.total_retries = len(collector.events) + self._logger.debug(f"Attack execution completed in {execution_time_ms}ms") self._log_attack_outcome(event_data.result) @@ -222,11 +242,66 @@ def _log_attack_outcome(self, result: AttackResult) -> None: message = f"{attack_name} achieved the objective. {reason}" elif result.outcome == AttackOutcome.UNDETERMINED: message = f"{attack_name} outcome is undetermined. {reason}" + elif result.outcome == AttackOutcome.ERROR: + message = f"{attack_name} failed with an error. {reason}" else: message = f"{attack_name} did not achieve the objective. {reason}" self._logger.info(message) + async def _on_error_async( + self, event_data: StrategyEventData[AttackStrategyContextT, AttackStrategyResultT] + ) -> None: + """ + Handle error during attack execution. + + Creates an error AttackResult with error details and any retry events + collected during execution, then persists it to memory. + + Args: + event_data (StrategyEventData[AttackStrategyContextT, AttackStrategyResultT]): The event data containing + context, result, and error. + """ + error = event_data.error + context = event_data.context + if not error or not context: + return + + # Clear any stale ID from a previous execution + context._error_attack_result_id = None + + # Collect retry events (visible via inherited ContextVar copy) + collector = get_retry_collector() + retry_events = collector.events if collector else [] + + # Build a conversation_id — use context's if available, otherwise generate one + conversation_id = getattr(context, "conversation_id", None) or str(uuid.uuid4()) + + error_result = AttackResult( + conversation_id=conversation_id, + objective=context.objective, + outcome=AttackOutcome.ERROR, + outcome_reason=f"Exception: {type(error).__name__}: {str(error)}", + labels=context.memory_labels, + related_conversations=context.related_conversations, + error_message=str(error), + error_type=type(error).__name__, + error_traceback="".join(traceback.format_exception(type(error), error, error.__traceback__)), + retry_events=retry_events, + total_retries=len(retry_events), + ) + + end_time = time.perf_counter() + if context.start_time: + error_result.execution_time_ms = int((end_time - context.start_time) * 1000) + + # Persist first, then set the ID on the context so scenario-level code + # only sees the reference if the write succeeded. + self._memory.add_attack_results_to_memory(attack_results=[error_result]) + context._error_attack_result_id = error_result.attack_result_id + + self._logger.error(f"Attack failed with {type(error).__name__}: {error}") + class AttackStrategy(Strategy[AttackStrategyContextT, AttackStrategyResultT], Identifiable, ABC): """ diff --git a/pyrit/executor/core/strategy.py b/pyrit/executor/core/strategy.py index 833ea1e072..3aa9ede03b 100644 --- a/pyrit/executor/core/strategy.py +++ b/pyrit/executor/core/strategy.py @@ -17,6 +17,11 @@ from pyrit.common import default_values from pyrit.common.logger import logger from pyrit.exceptions import clear_execution_context, get_execution_context +from pyrit.exceptions.retry_collector import ( + RetryCollector, + clear_retry_collector, + set_retry_collector, +) from pyrit.models import StrategyResultT if TYPE_CHECKING: @@ -25,6 +30,14 @@ StrategyContextT = TypeVar("StrategyContextT", bound="StrategyContext") +class _StrategyRuntimeError(RuntimeError): + """RuntimeError subclass that carries an optional error_attack_result_id.""" + + def __init__(self, message: str, *, error_attack_result_id: str | None = None) -> None: + super().__init__(message) + self.error_attack_result_id = error_attack_result_id + + @dataclass class StrategyContext(ABC): # noqa: B024 """Base class for all strategy contexts.""" @@ -328,12 +341,24 @@ async def execute_with_context_async(self, *, context: StrategyContextT) -> Stra try: async with self._execution_context(context): await self._handle_event(event=StrategyEvent.ON_PRE_EXECUTE, context=context) + + # Set up RetryCollector in the parent task so it is visible to + # Tenacity callbacks that fire during _perform_async. Event + # handlers run in child tasks (asyncio.create_task) which + # inherit a *copy* of the parent's ContextVar — setting the + # collector here ensures both the execution path and the event + # handlers can see it. + collector = RetryCollector() + set_retry_collector(collector) + result = await self._perform_async(context=context) await self._handle_event(event=StrategyEvent.ON_POST_EXECUTE, context=context, result=result) + clear_retry_collector() return result except Exception as e: # Notify error event await self._handle_event(event=StrategyEvent.ON_ERROR, context=context, error=e) + clear_retry_collector() # Build enhanced error message with execution context if available # Note: The context is preserved on exception by ExecutionContextManager @@ -361,7 +386,10 @@ async def execute_with_context_async(self, *, context: StrategyContextT) -> Stra else: error_message = f"Strategy execution failed for {self.__class__.__name__}: {str(e)}" - raise RuntimeError(error_message) from e + # Attach the error attack result ID if the ON_ERROR handler created one + error_attack_result_id = getattr(context, "_error_attack_result_id", None) + runtime_error = _StrategyRuntimeError(error_message, error_attack_result_id=error_attack_result_id) + raise runtime_error from e async def execute_async(self, **kwargs: Any) -> StrategyResultT: """ diff --git a/pyrit/memory/alembic/versions/108a72344872_add_labels_to_attack_results.py b/pyrit/memory/alembic/versions/108a72344872_add_labels_to_attack_results.py deleted file mode 100644 index 66b1574b93..0000000000 --- a/pyrit/memory/alembic/versions/108a72344872_add_labels_to_attack_results.py +++ /dev/null @@ -1,35 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. - -""" -add labels to Attack Results. - -Revision ID: 108a72344872 -Revises: e373726d391b -Create Date: 2026-04-27 13:47:20.711347 -""" - -from collections.abc import Sequence - -import sqlalchemy as sa -from alembic import op - -# revision identifiers, used by Alembic. -revision: str = "108a72344872" -down_revision: str | None = "e373726d391b" -branch_labels: str | Sequence[str] | None = None -depends_on: str | Sequence[str] | None = None - - -def upgrade() -> None: - """Apply this schema upgrade.""" - # ### commands auto generated by Alembic and reviewed by author ### - op.add_column("AttackResultEntries", sa.Column("labels", sa.JSON(), nullable=True)) - # ### end Alembic commands ### - - -def downgrade() -> None: - """Revert this schema upgrade.""" - # ### commands auto generated by Alembic and reviewed by author ### - op.drop_column("AttackResultEntries", "labels") - # ### end Alembic commands ### diff --git a/pyrit/memory/alembic/versions/e373726d391b_initial_schema.py b/pyrit/memory/alembic/versions/e373726d391b_initial_schema.py index c4dd5a8d09..2ac9a666d6 100644 --- a/pyrit/memory/alembic/versions/e373726d391b_initial_schema.py +++ b/pyrit/memory/alembic/versions/e373726d391b_initial_schema.py @@ -100,6 +100,9 @@ def upgrade() -> None: sa.Column("number_tries", sa.INTEGER(), nullable=False), sa.Column("completion_time", sa.DateTime(), nullable=False), sa.Column("timestamp", sa.DateTime(), nullable=False), + sa.Column("error_attack_result_ids_json", sa.Unicode(), nullable=True), + sa.Column("error_message", sa.Unicode(), nullable=True), + sa.Column("error_type", sa.String(), nullable=True), sa.PrimaryKeyConstraint("id"), ) op.create_table( @@ -180,6 +183,12 @@ def upgrade() -> None: sa.Column("adversarial_chat_conversation_ids", sa.JSON(), nullable=True), sa.Column("timestamp", sa.DateTime(), nullable=False), sa.Column("pyrit_version", sa.String(), nullable=True), + sa.Column("labels", sa.JSON(), nullable=True), + sa.Column("error_message", sa.Unicode(), nullable=True), + sa.Column("error_type", sa.String(), nullable=True), + sa.Column("error_traceback", sa.Unicode(), nullable=True), + sa.Column("retry_events_json", sa.Unicode(), nullable=True), + sa.Column("total_retries", sa.INTEGER(), nullable=True), sa.ForeignKeyConstraint( ["last_response_id"], ["PromptMemoryEntries.id"], diff --git a/pyrit/memory/memory_interface.py b/pyrit/memory/memory_interface.py index b28d05976e..65d4480aaf 100644 --- a/pyrit/memory/memory_interface.py +++ b/pyrit/memory/memory_interface.py @@ -2016,7 +2016,14 @@ def add_attack_results_to_scenario( logger.exception(f"Failed to add attack results to scenario {scenario_result_id}: {str(e)}") raise - def update_scenario_run_state(self, *, scenario_result_id: str, scenario_run_state: str) -> bool: + def update_scenario_run_state( + self, + *, + scenario_result_id: str, + scenario_run_state: str, + error_message: str | None = None, + error_type: str | None = None, + ) -> None: """ Update the run state of an existing scenario result. @@ -2024,41 +2031,67 @@ def update_scenario_run_state(self, *, scenario_result_id: str, scenario_run_sta scenario_result_id (str): The ID of the scenario result to update. scenario_run_state (str): The new state for the scenario (e.g., "CREATED", "IN_PROGRESS", "COMPLETED", "FAILED"). + error_message (str | None): Optional scenario-level error message. + error_type (str | None): Optional exception class name. - Returns: - bool: True if the update was successful, False otherwise. + Raises: + ValueError: If the scenario result is not found. + """ + scenario_results = self.get_scenario_results(scenario_result_ids=[scenario_result_id]) - Example: - >>> memory.update_scenario_run_state( - ... scenario_result_id="123e4567-e89b-12d3-a456-426614174000", - ... scenario_run_state="COMPLETED" - ... ) + if not scenario_results: + raise ValueError(f"Scenario result with ID {scenario_result_id} not found in memory") + + scenario_result = scenario_results[0] + + # Update the scenario run state + scenario_result.scenario_run_state = scenario_run_state # type: ignore[ty:invalid-assignment] + + if error_message is not None: + scenario_result.error_message = error_message + if error_type is not None: + scenario_result.error_type = error_type + + # Save updated result back to memory using update + entry = ScenarioResultEntry(entry=scenario_result) + self._update_entry(entry) + + logger.info(f"Updated scenario {scenario_result_id} state to '{scenario_run_state}'") + + def update_scenario_error_attacks(self, *, scenario_result_id: str, error_attack_result_ids: list[str]) -> None: """ - try: - # Retrieve current scenario result - scenario_results = self.get_scenario_results(scenario_result_ids=[scenario_result_id]) + Update the error attack result IDs on an existing scenario result. - if not scenario_results: - logger.error(f"Scenario result with ID {scenario_result_id} not found in memory") - return False + This links failed AttackResults to the ScenarioResult so the REST API + can quickly find error details without scanning all attacks. - scenario_result = scenario_results[0] + Performs the read-modify-write within a single DB session to avoid + inter-session consistency issues. - # Update the scenario run state - scenario_result.scenario_run_state = scenario_run_state # type: ignore[ty:invalid-assignment] + Args: + scenario_result_id: The ID of the scenario result to update. + error_attack_result_ids: IDs of AttackResults that contain error information. - # Save updated result back to memory using update - entry = ScenarioResultEntry(entry=scenario_result) - self._update_entry(entry) + Raises: + ValueError: If the scenario result is not found. + """ + import json - logger.info(f"Updated scenario {scenario_result_id} state to '{scenario_run_state}'") - return True + with closing(self.get_session()) as session: + entry = session.query(ScenarioResultEntry).filter_by(id=scenario_result_id).first() - except Exception as e: - logger.exception( - f"Failed to update scenario {scenario_result_id} state to '{scenario_run_state}': {str(e)}" + if not entry: + raise ValueError(f"Scenario result with ID {scenario_result_id} not found in memory") + + existing: list[str] = ( + json.loads(entry.error_attack_result_ids_json) if entry.error_attack_result_ids_json else [] ) - raise + merged = list(dict.fromkeys(existing + error_attack_result_ids)) + entry.error_attack_result_ids_json = json.dumps(merged) + + session.commit() + + logger.info(f"Updated scenario {scenario_result_id} with {len(error_attack_result_ids)} error attack result(s)") def get_scenario_results( self, diff --git a/pyrit/memory/memory_models.py b/pyrit/memory/memory_models.py index 08bd0e466a..1e48c03cf5 100644 --- a/pyrit/memory/memory_models.py +++ b/pyrit/memory/memory_models.py @@ -723,7 +723,7 @@ class AttackResultEntry(Base): ) executed_turns = mapped_column(INTEGER, nullable=False, default=0) execution_time_ms = mapped_column(INTEGER, nullable=False, default=0) - outcome: Mapped[Literal["success", "failure", "undetermined"]] = mapped_column( + outcome: Mapped[Literal["success", "failure", "error", "undetermined"]] = mapped_column( String, nullable=False, default="undetermined" ) outcome_reason = mapped_column(String, nullable=True) @@ -736,6 +736,15 @@ class AttackResultEntry(Base): # Nullable for backwards compatibility with existing databases pyrit_version = mapped_column(String, nullable=True) + # Error information (populated when attack fails with exception) + error_message = mapped_column(Unicode, nullable=True) + error_type = mapped_column(String, nullable=True) + error_traceback = mapped_column(Unicode, nullable=True) + + # Retry events (JSON-serialized list of RetryEvent dicts) + retry_events_json: Mapped[Optional[str]] = mapped_column(Unicode, nullable=True) + total_retries = mapped_column(INTEGER, nullable=True, default=0) + last_response: Mapped[Optional["PromptMemoryEntry"]] = relationship( "PromptMemoryEntry", foreign_keys=[last_response_id], @@ -798,6 +807,18 @@ def __init__(self, *, entry: AttackResult) -> None: self.timestamp = entry.timestamp or datetime.now(tz=timezone.utc) self.pyrit_version = pyrit.__version__ + # Error information + self.error_message = entry.error_message + self.error_type = entry.error_type + # Truncate traceback to 10KB to avoid excessive DB storage + self.error_traceback = entry.error_traceback[:10240] if entry.error_traceback else None + + # Retry events + self.retry_events_json = ( + json.dumps([evt.to_dict() for evt in entry.retry_events]) if entry.retry_events else None + ) + self.total_retries = entry.total_retries + @staticmethod def _get_id_as_uuid(obj: Any) -> Optional[uuid.UUID]: """ @@ -883,6 +904,13 @@ def get_attack_result(self) -> AttackResult: attack_identifier=ComponentIdentifier.from_dict(self.attack_identifier), ) + # Deserialize retry events from JSON + retry_events = [] + if self.retry_events_json: + from pyrit.models.retry_event import RetryEvent + + retry_events = [RetryEvent.from_dict(evt_dict) for evt_dict in json.loads(self.retry_events_json)] + return AttackResult( conversation_id=self.conversation_id, attack_result_id=str(self.id), @@ -898,6 +926,11 @@ def get_attack_result(self) -> AttackResult: metadata=self.attack_metadata or {}, timestamp=_ensure_utc(self.timestamp) or datetime.now(tz=timezone.utc), labels=self.labels or {}, + error_message=self.error_message, + error_type=self.error_type, + error_traceback=self.error_traceback, + retry_events=retry_events, + total_retries=self.total_retries or 0, ) @@ -959,6 +992,13 @@ class ScenarioResultEntry(Base): completion_time = mapped_column(DateTime, nullable=False) timestamp = mapped_column(DateTime, nullable=False) + # Pointer to failed attack result(s) — avoids scanning all attacks for error info + error_attack_result_ids_json: Mapped[Optional[str]] = mapped_column(Unicode, nullable=True) + + # Scenario-level error info (persisted so it survives process restarts) + error_message: Mapped[Optional[str]] = mapped_column(Unicode, nullable=True) + error_type: Mapped[Optional[str]] = mapped_column(String, nullable=True) + def __init__(self, *, entry: ScenarioResult) -> None: """ Initialize a ScenarioResultEntry from a ScenarioResult object. @@ -1004,6 +1044,14 @@ def __init__(self, *, entry: ScenarioResult) -> None: # Serialize display_group_map if present self.display_group_map_json = json.dumps(entry._display_group_map) if entry._display_group_map else None + # Serialize error_attack_result_ids if present + self.error_attack_result_ids_json = ( + json.dumps(entry.error_attack_result_ids) if entry.error_attack_result_ids else None + ) + + self.error_message = entry.error_message + self.error_type = entry.error_type + self.timestamp = datetime.now(tz=timezone.utc) def get_scenario_result(self) -> ScenarioResult: @@ -1045,6 +1093,11 @@ def get_scenario_result(self) -> ScenarioResult: if self.display_group_map_json: display_group_map = json.loads(self.display_group_map_json) + # Deserialize error_attack_result_ids if stored + error_attack_result_ids: list[str] | None = None + if self.error_attack_result_ids_json: + error_attack_result_ids = json.loads(self.error_attack_result_ids_json) + return ScenarioResult( id=self.id, scenario_identifier=scenario_identifier, @@ -1057,6 +1110,9 @@ def get_scenario_result(self) -> ScenarioResult: number_tries=self.number_tries, completion_time=self.completion_time, display_group_map=display_group_map, + error_attack_result_ids=error_attack_result_ids, + error_message=self.error_message, + error_type=self.error_type, ) def get_conversation_ids_by_attack_name(self) -> dict[str, list[str]]: diff --git a/pyrit/models/__init__.py b/pyrit/models/__init__.py index 5a85226322..1093e1da02 100644 --- a/pyrit/models/__init__.py +++ b/pyrit/models/__init__.py @@ -33,6 +33,7 @@ ) from pyrit.models.message_piece import MessagePiece, sort_message_pieces from pyrit.models.question_answering import QuestionAnsweringDataset, QuestionAnsweringEntry, QuestionChoice +from pyrit.models.retry_event import RetryEvent from pyrit.models.scenario_result import ScenarioIdentifier, ScenarioResult from pyrit.models.score import Score, ScoreType, UnvalidatedScore @@ -115,4 +116,5 @@ "TextDataTypeSerializer", "UnvalidatedScore", "VideoPathDataTypeSerializer", + "RetryEvent", ] diff --git a/pyrit/models/attack_result.py b/pyrit/models/attack_result.py index be51668fa1..123c83a918 100644 --- a/pyrit/models/attack_result.py +++ b/pyrit/models/attack_result.py @@ -16,6 +16,7 @@ from pyrit.identifiers.component_identifier import ComponentIdentifier from pyrit.models.conversation_reference import ConversationReference from pyrit.models.message_piece import MessagePiece + from pyrit.models.retry_event import RetryEvent from pyrit.models.score import Score from pyrit.models.conversation_reference import ConversationType @@ -37,6 +38,9 @@ class AttackOutcome(str, Enum): # The attack failed to achieve its objective FAILURE = "failure" + # The attack failed due to an infrastructure error (exception), not a defensive refusal + ERROR = "error" + # The outcome of the attack is unknown or could not be determined UNDETERMINED = "undetermined" @@ -94,6 +98,15 @@ class AttackResult(StrategyResult): # labels associated with this attack result labels: dict[str, str] = field(default_factory=dict) + # Error information (populated when attack fails with exception) + error_message: str | None = None + error_type: str | None = None + error_traceback: str | None = None + + # Retry tracking + retry_events: list[RetryEvent] = field(default_factory=list) + total_retries: int = 0 + @property def attack_identifier(self) -> Optional[ComponentIdentifier]: """ diff --git a/pyrit/models/retry_event.py b/pyrit/models/retry_event.py new file mode 100644 index 0000000000..030bba3500 --- /dev/null +++ b/pyrit/models/retry_event.py @@ -0,0 +1,71 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +"""Data model for capturing individual retry events during execution.""" + +from dataclasses import dataclass, field +from datetime import datetime, timezone + + +@dataclass +class RetryEvent: + """ + A single retry attempt captured during attack execution. + + Records structured information about a Tenacity retry event, including + which component was retrying, what exception triggered the retry, and + timing information. These events are collected by a RetryCollector and + attached to AttackResult objects for persistence and REST API exposure. + """ + + timestamp: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) + attempt_number: int = 0 + function_name: str = "" + exception_type: str = "" + exception_message: str = "" + component_role: str = "" + component_name: str | None = None + endpoint: str | None = None + elapsed_seconds: float = 0.0 + + def to_dict(self) -> dict: + """ + Serialize to a dictionary suitable for JSON storage. + + Returns: + dict: Dictionary representation of the retry event. + """ + return { + "timestamp": self.timestamp.isoformat(), + "attempt_number": self.attempt_number, + "function_name": self.function_name, + "exception_type": self.exception_type, + "exception_message": self.exception_message, + "component_role": self.component_role, + "component_name": self.component_name, + "endpoint": self.endpoint, + "elapsed_seconds": self.elapsed_seconds, + } + + @classmethod + def from_dict(cls, data: dict) -> "RetryEvent": + """ + Deserialize from a dictionary. + + Args: + data: Dictionary representation of a retry event. + + Returns: + RetryEvent: Deserialized retry event. + """ + return cls( + timestamp=datetime.fromisoformat(data["timestamp"]), + attempt_number=data.get("attempt_number", 0), + function_name=data.get("function_name", ""), + exception_type=data.get("exception_type", ""), + exception_message=data.get("exception_message", ""), + component_role=data.get("component_role", ""), + component_name=data.get("component_name"), + endpoint=data.get("endpoint"), + elapsed_seconds=data.get("elapsed_seconds", 0.0), + ) diff --git a/pyrit/models/scenario_result.py b/pyrit/models/scenario_result.py index 496b1d73a9..88a67f5991 100644 --- a/pyrit/models/scenario_result.py +++ b/pyrit/models/scenario_result.py @@ -69,6 +69,9 @@ def __init__( number_tries: int = 0, id: uuid.UUID | None = None, # noqa: A002 display_group_map: dict[str, str] | None = None, + error_attack_result_ids: list[str] | None = None, + error_message: str | None = None, + error_type: str | None = None, ) -> None: """ Initialize a scenario result. @@ -87,6 +90,11 @@ def __init__( display_group_map (Optional[dict[str, str]]): Optional mapping of atomic_attack_name → display group label. Used by the console printer to aggregate results for user-facing output. + error_attack_result_ids (Optional[list[str]]): IDs of AttackResults that + contain error information. Used for quick error lookup without scanning + all attack results. + error_message (Optional[str]): Scenario-level error message when the run fails. + error_type (Optional[str]): Exception class name when the run fails. """ self.id = id if id is not None else uuid.uuid4() @@ -103,6 +111,9 @@ def __init__( self.completion_time = completion_time if completion_time is not None else datetime.now(timezone.utc) self.number_tries = number_tries self._display_group_map = display_group_map or {} + self.error_attack_result_ids = error_attack_result_ids or [] + self.error_message = error_message + self.error_type = error_type @property def display_group_map(self) -> dict[str, str]: diff --git a/pyrit/scenario/core/scenario.py b/pyrit/scenario/core/scenario.py index 089bed4267..7a30cb5a1c 100644 --- a/pyrit/scenario/core/scenario.py +++ b/pyrit/scenario/core/scenario.py @@ -1095,18 +1095,35 @@ async def _execute_scenario_async(self) -> ScenarioResult: for obj, exc in atomic_results.incomplete_objectives: logger.error(f" Incomplete objective '{obj[:50]}...': {str(exc)}") + # Collect error attack result IDs from the exceptions + error_ids = [] + for _, exc in atomic_results.incomplete_objectives: + error_id = getattr(exc, "error_attack_result_id", None) + if error_id: + error_ids.append(error_id) + + # Link error attack results to the scenario result + if error_ids: + self._memory.update_scenario_error_attacks( + scenario_result_id=scenario_result_id, + error_attack_result_ids=error_ids, + ) + # Mark scenario as failed + error_msg = ( + f"Atomic attack '{atomic_attack.atomic_attack_name}' partially failed: " + f"{incomplete_count} of {incomplete_count + completed_count} objectives incomplete. " + f"See attack results for details." + ) self._memory.update_scenario_run_state( scenario_result_id=scenario_result_id, scenario_run_state="FAILED", + error_message=error_msg, + error_type=type(atomic_results.incomplete_objectives[0][1]).__name__, ) # Raise exception with detailed information - raise ValueError( - f"Failed to execute atomic attack {i} ('{atomic_attack.atomic_attack_name}') " - f"in scenario '{self._name}': {incomplete_count} of {incomplete_count + completed_count} " - f"objectives incomplete. First failure: {atomic_results.incomplete_objectives[0][1]}" - ) from atomic_results.incomplete_objectives[0][1] + raise ValueError(error_msg) from atomic_results.incomplete_objectives[0][1] logger.info( f"Atomic attack {i}/{len(self._atomic_attacks)} completed successfully with " f"{len(atomic_results.completed_results)} results" @@ -1125,6 +1142,8 @@ async def _execute_scenario_async(self) -> ScenarioResult: self._memory.update_scenario_run_state( scenario_result_id=scenario_result_id, scenario_run_state="FAILED", + error_message=str(e), + error_type=type(e).__name__, ) raise diff --git a/tests/unit/analytics/test_result_analysis.py b/tests/unit/analytics/test_result_analysis.py index 05d1b94f0e..e2d96b5bd4 100644 --- a/tests/unit/analytics/test_result_analysis.py +++ b/tests/unit/analytics/test_result_analysis.py @@ -42,29 +42,44 @@ def test_analyze_results_raises_on_invalid_object(): @pytest.mark.parametrize( - "outcomes, expected_successes, expected_failures, expected_undetermined, expected_rate", + "outcomes, expected_successes, expected_failures, expected_undetermined, expected_errors, expected_rate", [ # all successes - ([AttackOutcome.SUCCESS, AttackOutcome.SUCCESS], 2, 0, 0, 1.0), + ([AttackOutcome.SUCCESS, AttackOutcome.SUCCESS], 2, 0, 0, 0, 1.0), # all failures - ([AttackOutcome.FAILURE, AttackOutcome.FAILURE], 0, 2, 0, 0.0), + ([AttackOutcome.FAILURE, AttackOutcome.FAILURE], 0, 2, 0, 0, 0.0), # mixed decided - ([AttackOutcome.SUCCESS, AttackOutcome.FAILURE], 1, 1, 0, 0.5), + ([AttackOutcome.SUCCESS, AttackOutcome.FAILURE], 1, 1, 0, 0, 0.5), # include undetermined (excluded from denominator) - ([AttackOutcome.SUCCESS, AttackOutcome.UNDETERMINED], 1, 0, 1, 1.0), - ([AttackOutcome.FAILURE, AttackOutcome.UNDETERMINED], 0, 1, 1, 0.0), + ([AttackOutcome.SUCCESS, AttackOutcome.UNDETERMINED], 1, 0, 1, 0, 1.0), + ([AttackOutcome.FAILURE, AttackOutcome.UNDETERMINED], 0, 1, 1, 0, 0.0), # multiple with undetermined ( [AttackOutcome.SUCCESS, AttackOutcome.FAILURE, AttackOutcome.UNDETERMINED], 1, 1, 1, + 0, + 0.5, + ), + # error excluded from denominator (like undetermined) + ([AttackOutcome.SUCCESS, AttackOutcome.ERROR], 1, 0, 0, 1, 1.0), + ([AttackOutcome.FAILURE, AttackOutcome.ERROR], 0, 1, 0, 1, 0.0), + # all errors + ([AttackOutcome.ERROR, AttackOutcome.ERROR], 0, 0, 0, 2, None), + # mixed with error and undetermined + ( + [AttackOutcome.SUCCESS, AttackOutcome.FAILURE, AttackOutcome.ERROR, AttackOutcome.UNDETERMINED], + 1, + 1, + 1, + 1, 0.5, ), ], ) def test_overall_success_rate_parametrized( - outcomes, expected_successes, expected_failures, expected_undetermined, expected_rate + outcomes, expected_successes, expected_failures, expected_undetermined, expected_errors, expected_rate ): attacks = [make_attack(o) for o in outcomes] result = analyze_results(attacks) @@ -74,6 +89,7 @@ def test_overall_success_rate_parametrized( assert overall.successes == expected_successes assert overall.failures == expected_failures assert overall.undetermined == expected_undetermined + assert overall.errors == expected_errors assert overall.total_decided == expected_successes + expected_failures assert overall.success_rate == expected_rate diff --git a/tests/unit/backend/test_attack_service.py b/tests/unit/backend/test_attack_service.py index 084a10eefa..b44145c7d6 100644 --- a/tests/unit/backend/test_attack_service.py +++ b/tests/unit/backend/test_attack_service.py @@ -887,6 +887,19 @@ async def test_update_attack_updates_outcome_undetermined(self, attack_service, call_kwargs = mock_memory.update_attack_result_by_id.call_args[1] assert call_kwargs["update_fields"]["outcome"] == "undetermined" + async def test_update_attack_updates_outcome_error(self, attack_service, mock_memory) -> None: + """Test that update_attack maps 'error' to AttackOutcome.ERROR.""" + ar = make_attack_result(conversation_id="test-id") + mock_memory.get_attack_results.return_value = [ar] + mock_memory.get_conversation.return_value = [] + + await attack_service.update_attack_async( + attack_result_id="test-id", request=UpdateAttackRequest(outcome="error") + ) + + call_kwargs = mock_memory.update_attack_result_by_id.call_args[1] + assert call_kwargs["update_fields"]["outcome"] == "error" + async def test_update_attack_refreshes_updated_at(self, attack_service, mock_memory) -> None: """Test that update_attack refreshes the updated_at metadata.""" old_time = datetime(2024, 1, 1, tzinfo=timezone.utc) diff --git a/tests/unit/backend/test_mappers.py b/tests/unit/backend/test_mappers.py index ec1f403c9a..d3aef9a4ed 100644 --- a/tests/unit/backend/test_mappers.py +++ b/tests/unit/backend/test_mappers.py @@ -365,6 +365,40 @@ def test_created_at_falls_back_to_now_when_both_absent(self) -> None: assert before <= summary.created_at <= after + def test_retry_events_mapped_to_response(self) -> None: + """Test that retry events on an AttackResult are mapped to RetryEventResponse DTOs.""" + from pyrit.models.retry_event import RetryEvent + + now = datetime.now(timezone.utc) + ar = _make_attack_result() + ar.retry_events = [ + RetryEvent( + timestamp=now, + attempt_number=1, + function_name="send_prompt_async", + exception_type="RateLimitError", + exception_message="rate limit exceeded", + component_role="objective_target", + component_name="OpenAIChatTarget", + endpoint="https://api.openai.com", + elapsed_seconds=1.5, + ), + ] + ar.total_retries = 1 + + stats = ConversationStats(message_count=0) + summary = attack_result_to_summary(ar, stats=stats) + + assert summary.retry_events is not None + assert len(summary.retry_events) == 1 + evt = summary.retry_events[0] + assert evt.attempt_number == 1 + assert evt.function_name == "send_prompt_async" + assert evt.exception_type == "RateLimitError" + assert evt.component_role == "objective_target" + assert evt.elapsed_seconds == 1.5 + assert summary.total_retries == 1 + """Tests for pyrit_scores_to_dto function.""" def test_maps_scores(self) -> None: diff --git a/tests/unit/backend/test_scenario_run_service.py b/tests/unit/backend/test_scenario_run_service.py index 26fa81a814..e29686599a 100644 --- a/tests/unit/backend/test_scenario_run_service.py +++ b/tests/unit/backend/test_scenario_run_service.py @@ -72,6 +72,9 @@ def _make_db_scenario_result( sr.objective_achieved_rate.return_value = 0 sr.get_display_groups.return_value = {} sr.display_group_map = {} + sr.error_message = None + sr.error_type = None + sr.error_attack_result_ids = [] return sr @@ -294,6 +297,26 @@ def test_get_run_returns_existing_run(self, mock_memory) -> None: assert fetched.scenario_name == "foundry.red_team_agent" assert fetched.status == ScenarioRunStatus.IN_PROGRESS + def test_get_run_falls_back_to_persisted_error(self, mock_memory) -> None: + """Test that get_run extracts error from persisted error AttackResult when no active task.""" + db_result = _make_db_scenario_result(result_id="sr-fail", run_state="FAILED") + db_result.error_attack_result_ids = ["err-ar-1"] + + # Mock the error AttackResult lookup + error_ar = MagicMock() + error_ar.error_message = "Connection refused" + error_ar.error_type = "ConnectionError" + mock_memory.get_scenario_results.return_value = [db_result] + mock_memory.get_attack_results.return_value = [error_ar] + + service = ScenarioRunService() + fetched = service.get_run(scenario_result_id="sr-fail") + + assert fetched is not None + assert fetched.error == "Connection refused" + assert fetched.error_type == "ConnectionError" + mock_memory.get_attack_results.assert_called_once_with(attack_result_ids=["err-ar-1"]) + class TestScenarioRunServiceListRuns: """Tests for ScenarioRunService.list_runs.""" @@ -351,7 +374,10 @@ async def test_cancel_run_sets_cancelled_status(self, mock_all_registries) -> No result = await service.cancel_run_async(scenario_result_id=response.scenario_result_id) mock_memory.update_scenario_run_state.assert_called_once_with( - scenario_result_id=response.scenario_result_id, scenario_run_state="CANCELLED" + scenario_result_id=response.scenario_result_id, + scenario_run_state="CANCELLED", + error_message="Run was cancelled by user", + error_type="CancelledError", ) assert result is not None assert result.status == ScenarioRunStatus.CANCELLED @@ -470,6 +496,11 @@ def test_get_results_returns_details_for_completed_run(self, mock_memory) -> Non mock_attack_result.executed_turns = 3 mock_attack_result.execution_time_ms = 1500 mock_attack_result.timestamp = None + mock_attack_result.error_message = None + mock_attack_result.error_type = None + mock_attack_result.error_traceback = None + mock_attack_result.total_retries = 0 + mock_attack_result.retry_events = [] db_result = _make_db_scenario_result( result_id="sr-123", diff --git a/tests/unit/exceptions/test_retry_collector.py b/tests/unit/exceptions/test_retry_collector.py new file mode 100644 index 0000000000..bc338fcc23 --- /dev/null +++ b/tests/unit/exceptions/test_retry_collector.py @@ -0,0 +1,149 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import asyncio + +from pyrit.exceptions.retry_collector import ( + RetryCollector, + clear_retry_collector, + get_retry_collector, + set_retry_collector, +) + + +class TestRetryCollector: + """Tests for the RetryCollector and its contextvar helpers.""" + + def test_collector_starts_empty(self) -> None: + """A new RetryCollector has no events.""" + c = RetryCollector() + assert c.events == [] + + def test_contextvar_default_is_none(self) -> None: + """get_retry_collector returns None when no collector is set.""" + clear_retry_collector() + assert get_retry_collector() is None + + def test_set_and_get(self) -> None: + """set_retry_collector makes it visible to get_retry_collector.""" + c = RetryCollector() + set_retry_collector(c) + assert get_retry_collector() is c + clear_retry_collector() + + def test_clear(self) -> None: + """clear_retry_collector resets the contextvar to None.""" + c = RetryCollector() + set_retry_collector(c) + clear_retry_collector() + assert get_retry_collector() is None + + def test_record_extracts_exception_info(self) -> None: + """record() extracts exception type and message from retry_state.""" + from unittest.mock import MagicMock + + c = RetryCollector() + + # Build a mock retry_state matching Tenacity's RetryCallState + retry_state = MagicMock() + retry_state.attempt_number = 2 + retry_state.start_time = 0.0 + retry_state.fn = MagicMock() + retry_state.fn.__name__ = "my_function" + + # Mock outcome with .failed=True and .exception() returning a ValueError + outcome = MagicMock() + outcome.failed = True + outcome.exception.return_value = ValueError("test error") + retry_state.outcome = outcome + + c.record(retry_state=retry_state) + + assert len(c.events) == 1 + evt = c.events[0] + assert evt.attempt_number == 2 + assert evt.function_name == "my_function" + assert evt.exception_type == "ValueError" + assert evt.exception_message == "test error" + + def test_record_multiple_events(self) -> None: + """record() accumulates events.""" + from unittest.mock import MagicMock + + c = RetryCollector() + + for i in range(3): + retry_state = MagicMock() + retry_state.attempt_number = i + 1 + retry_state.start_time = 0.0 + retry_state.fn = MagicMock() + retry_state.fn.__name__ = f"fn_{i}" + outcome = MagicMock() + outcome.failed = True + outcome.exception.return_value = RuntimeError(f"error {i}") + retry_state.outcome = outcome + c.record(retry_state=retry_state) + + assert len(c.events) == 3 + assert c.events[0].function_name == "fn_0" + assert c.events[2].function_name == "fn_2" + + def test_contextvar_isolation_across_tasks(self) -> None: + """Each asyncio task gets its own contextvar value.""" + results: dict[str, bool] = {} + + async def task_a() -> None: + c = RetryCollector() + set_retry_collector(c) + await asyncio.sleep(0.01) + results["a_has_collector"] = get_retry_collector() is c + clear_retry_collector() + + async def task_b() -> None: + await asyncio.sleep(0.005) + results["b_sees_none"] = get_retry_collector() is None + + async def run() -> None: + clear_retry_collector() + await asyncio.gather(task_a(), task_b()) + + asyncio.run(run()) + assert results.get("a_has_collector") is True + assert results.get("b_sees_none") is True + + def test_record_extracts_execution_context(self) -> None: + """record() extracts component_role, component_name, and endpoint from ExecutionContext.""" + from unittest.mock import MagicMock + + from pyrit.exceptions.exception_context import ( + ComponentRole, + ExecutionContext, + set_execution_context, + ) + + c = RetryCollector() + + ctx = ExecutionContext( + component_role=ComponentRole.OBJECTIVE_TARGET, + component_name="OpenAIChatTarget", + endpoint="https://api.openai.com", + ) + set_execution_context(ctx) + + retry_state = MagicMock() + retry_state.attempt_number = 1 + retry_state.start_time = 0.0 + retry_state.fn = MagicMock() + retry_state.fn.__name__ = "send_prompt_async" + outcome = MagicMock() + outcome.failed = True + outcome.exception.return_value = RuntimeError("timeout") + retry_state.outcome = outcome + + c.record(retry_state=retry_state) + + assert len(c.events) == 1 + evt = c.events[0] + assert evt.component_role == "objective_target" + assert evt.component_name == "OpenAIChatTarget" + assert evt.endpoint == "https://api.openai.com" diff --git a/tests/unit/executor/attack/core/test_attack_strategy.py b/tests/unit/executor/attack/core/test_attack_strategy.py index 7b6825775c..33370f0c9d 100644 --- a/tests/unit/executor/attack/core/test_attack_strategy.py +++ b/tests/unit/executor/attack/core/test_attack_strategy.py @@ -6,6 +6,7 @@ import pytest +from pyrit.exceptions.retry_collector import RetryCollector from pyrit.executor.attack.core.attack_parameters import AttackParameters from pyrit.executor.attack.core.attack_strategy import ( AttackContext, @@ -20,6 +21,7 @@ AttackResult, Message, ) +from pyrit.models.retry_event import RetryEvent from pyrit.prompt_target import PromptTarget @@ -397,6 +399,26 @@ async def test_on_post_execute_logs_undetermined( expected_message = f"{event_handler.__class__.__name__} outcome is undetermined. Reason: Not specified" mock_logger.info.assert_called_with(expected_message) + async def test_on_post_execute_logs_error_outcome( + self, event_handler, sample_attack_context, sample_attack_result, mock_logger + ): + """Test that post-execute handler logs error outcome""" + sample_attack_result.outcome = AttackOutcome.ERROR + sample_attack_result.outcome_reason = "Connection timeout" + + event_data = StrategyEventData( + event=StrategyEvent.ON_POST_EXECUTE, + strategy_name="TestStrategy", + strategy_id="test-id", + context=sample_attack_context, + result=sample_attack_result, + ) + + await event_handler.on_event(event_data) + + expected_message = f"{event_handler.__class__.__name__} failed with an error. Reason: Connection timeout" + mock_logger.info.assert_called_with(expected_message) + async def test_on_post_execute_adds_results_to_memory(self, mock_memory): """Test that post-execute handler adds results to memory""" with patch("pyrit.memory.central_memory.CentralMemory.get_memory_instance", return_value=mock_memory): @@ -437,6 +459,141 @@ async def test_on_post_execute_raises_on_none_result(self, event_handler, sample with pytest.raises(ValueError, match="Attack result is None"): await event_handler.on_event(event_data) + async def test_on_post_execute_attaches_retry_events( + self, sample_attack_context, sample_attack_result, mock_memory + ): + """Test that post-execute handler attaches retry events from collector to the result""" + with patch("pyrit.memory.central_memory.CentralMemory.get_memory_instance", return_value=mock_memory): + handler = _DefaultAttackStrategyEventHandler() + + sample_attack_context.start_time = 100.0 + retry_event = RetryEvent( + attempt_number=1, function_name="send_prompt_async", exception_type="RateLimitError" + ) + + collector = RetryCollector(events=[retry_event]) + with patch("pyrit.executor.attack.core.attack_strategy.get_retry_collector", return_value=collector): + event_data = StrategyEventData( + event=StrategyEvent.ON_POST_EXECUTE, + strategy_name="TestStrategy", + strategy_id="test-id", + context=sample_attack_context, + result=sample_attack_result, + ) + await handler.on_event(event_data) + + assert sample_attack_result.retry_events == [retry_event] + assert sample_attack_result.total_retries == 1 + + async def test_on_post_execute_no_retry_events_when_collector_empty( + self, sample_attack_context, sample_attack_result, mock_memory + ): + """Test that post-execute handler does not set retry_events when collector has no events""" + with patch("pyrit.memory.central_memory.CentralMemory.get_memory_instance", return_value=mock_memory): + handler = _DefaultAttackStrategyEventHandler() + + sample_attack_context.start_time = 100.0 + collector = RetryCollector(events=[]) + with patch("pyrit.executor.attack.core.attack_strategy.get_retry_collector", return_value=collector): + event_data = StrategyEventData( + event=StrategyEvent.ON_POST_EXECUTE, + strategy_name="TestStrategy", + strategy_id="test-id", + context=sample_attack_context, + result=sample_attack_result, + ) + await handler.on_event(event_data) + + # Empty collector means the guard `if collector and collector.events` is False + assert not sample_attack_result.retry_events + assert sample_attack_result.total_retries == 0 + + async def test_on_error_attaches_retry_events(self, sample_attack_context, mock_memory): + """Test that error handler attaches collected retry events to the error AttackResult""" + with patch("pyrit.memory.central_memory.CentralMemory.get_memory_instance", return_value=mock_memory): + handler = _DefaultAttackStrategyEventHandler() + + sample_attack_context.start_time = 100.0 + retry_event = RetryEvent(attempt_number=2, function_name="send_prompt_async", exception_type="TimeoutError") + collector = RetryCollector(events=[retry_event]) + + with patch("pyrit.executor.attack.core.attack_strategy.get_retry_collector", return_value=collector): + event_data = StrategyEventData( + event=StrategyEvent.ON_ERROR, + strategy_name="TestStrategy", + strategy_id="test-id", + context=sample_attack_context, + error=RuntimeError("test error"), + ) + await handler.on_event(event_data) + + stored_result = mock_memory.add_attack_results_to_memory.call_args.kwargs["attack_results"][0] + assert stored_result.outcome == AttackOutcome.ERROR + assert stored_result.retry_events == [retry_event] + assert stored_result.total_retries == 1 + + async def test_on_error_empty_retry_events_when_no_collector(self, sample_attack_context, mock_memory): + """Test that error handler sets empty retry_events when no collector exists""" + with patch("pyrit.memory.central_memory.CentralMemory.get_memory_instance", return_value=mock_memory): + handler = _DefaultAttackStrategyEventHandler() + + sample_attack_context.start_time = 100.0 + + with patch("pyrit.executor.attack.core.attack_strategy.get_retry_collector", return_value=None): + event_data = StrategyEventData( + event=StrategyEvent.ON_ERROR, + strategy_name="TestStrategy", + strategy_id="test-id", + context=sample_attack_context, + error=RuntimeError("test error"), + ) + await handler.on_event(event_data) + + stored_result = mock_memory.add_attack_results_to_memory.call_args.kwargs["attack_results"][0] + assert stored_result.retry_events == [] + assert stored_result.total_retries == 0 + + async def test_on_error_persists_result_to_memory(self, sample_attack_context, mock_memory): + """Test that error handler creates an error AttackResult and persists it""" + with patch("pyrit.memory.central_memory.CentralMemory.get_memory_instance", return_value=mock_memory): + handler = _DefaultAttackStrategyEventHandler() + + sample_attack_context.start_time = 100.0 + error = ValueError("something broke") + + with patch("pyrit.executor.attack.core.attack_strategy.get_retry_collector", return_value=None): + event_data = StrategyEventData( + event=StrategyEvent.ON_ERROR, + strategy_name="TestStrategy", + strategy_id="test-id", + context=sample_attack_context, + error=error, + ) + with patch("time.perf_counter", return_value=100.5): + await handler.on_event(event_data) + + mock_memory.add_attack_results_to_memory.assert_called_once() + stored_result = mock_memory.add_attack_results_to_memory.call_args.kwargs["attack_results"][0] + assert stored_result.outcome == AttackOutcome.ERROR + assert stored_result.error_message == "something broke" + assert stored_result.error_type == "ValueError" + assert stored_result.execution_time_ms == 500 + + async def test_on_error_skips_when_no_error_or_context(self, mock_memory): + """Test that error handler returns early when error or context is None""" + with patch("pyrit.memory.central_memory.CentralMemory.get_memory_instance", return_value=mock_memory): + handler = _DefaultAttackStrategyEventHandler() + + event_data = StrategyEventData( + event=StrategyEvent.ON_ERROR, + strategy_name="TestStrategy", + strategy_id="test-id", + context=None, + error=RuntimeError("test"), + ) + await handler.on_event(event_data) + mock_memory.add_attack_results_to_memory.assert_not_called() + async def test_on_event_handles_other_events(self, event_handler, sample_attack_context, mock_logger): """Test that on_event handles events not in the specific handlers""" event_data = StrategyEventData( diff --git a/tests/unit/memory/memory_interface/test_interface_scenario_results.py b/tests/unit/memory/memory_interface/test_interface_scenario_results.py index 0fec3aceaa..97a240e37c 100644 --- a/tests/unit/memory/memory_interface/test_interface_scenario_results.py +++ b/tests/unit/memory/memory_interface/test_interface_scenario_results.py @@ -648,6 +648,60 @@ def test_combined_filters(sqlite_instance: MemoryInterface): assert "gpt-4" in results[0].objective_target_identifier.params["model_name"] +def test_update_scenario_error_attacks_success(sqlite_instance: MemoryInterface, sample_attack_results): + """Test successfully linking error attack result IDs to a scenario result.""" + scenario_result = create_scenario_result( + name="Error Scenario", + attack_results={"Attack1": [sample_attack_results[0]]}, + ) + sqlite_instance.add_scenario_results_to_memory(scenario_results=[scenario_result]) + + error_ids = ["error-ar-1", "error-ar-2"] + sqlite_instance.update_scenario_error_attacks( + scenario_result_id=str(scenario_result.id), + error_attack_result_ids=error_ids, + ) + + # Verify the error IDs were persisted + results = sqlite_instance.get_scenario_results(scenario_result_ids=[str(scenario_result.id)]) + assert len(results) == 1 + assert results[0].error_attack_result_ids == error_ids + + +def test_update_scenario_error_attacks_appends_to_existing(sqlite_instance: MemoryInterface, sample_attack_results): + """Test that updating error attacks appends to existing IDs without duplicates.""" + scenario_result = create_scenario_result( + name="Error Scenario", + attack_results={"Attack1": [sample_attack_results[0]]}, + ) + sqlite_instance.add_scenario_results_to_memory(scenario_results=[scenario_result]) + + # First update + sqlite_instance.update_scenario_error_attacks( + scenario_result_id=str(scenario_result.id), + error_attack_result_ids=["error-ar-1"], + ) + + # Second update with overlap and new ID + sqlite_instance.update_scenario_error_attacks( + scenario_result_id=str(scenario_result.id), + error_attack_result_ids=["error-ar-1", "error-ar-2"], + ) + + results = sqlite_instance.get_scenario_results(scenario_result_ids=[str(scenario_result.id)]) + # Should be deduplicated: ["error-ar-1", "error-ar-2"] + assert results[0].error_attack_result_ids == ["error-ar-1", "error-ar-2"] + + +def test_update_scenario_error_attacks_not_found(sqlite_instance: MemoryInterface): + """Test that updating a nonexistent scenario result raises ValueError.""" + with pytest.raises(ValueError, match="not found in memory"): + sqlite_instance.update_scenario_error_attacks( + scenario_result_id="nonexistent-id", + error_attack_result_ids=["error-ar-1"], + ) + + def test_get_scenario_results_by_target_identifier_filter_hash(sqlite_instance: MemoryInterface): """Test filtering scenario results by identifier filter.""" target_id_1 = ComponentIdentifier( diff --git a/tests/unit/memory/test_memory_models.py b/tests/unit/memory/test_memory_models.py index f143a69e98..b382bf8646 100644 --- a/tests/unit/memory/test_memory_models.py +++ b/tests/unit/memory/test_memory_models.py @@ -474,3 +474,17 @@ def test_init_with_empty_attack_results(self): entry = ScenarioResultEntry(entry=sr) conv_ids = entry.get_conversation_ids_by_attack_name() assert conv_ids == {} + + def test_roundtrip_error_attack_result_ids(self): + sr = self._make_scenario_result(error_attack_result_ids=["err-1", "err-2"]) + entry = ScenarioResultEntry(entry=sr) + assert entry.error_attack_result_ids_json is not None + recovered = entry.get_scenario_result() + assert recovered.error_attack_result_ids == ["err-1", "err-2"] + + def test_roundtrip_error_attack_result_ids_none(self): + sr = self._make_scenario_result() + entry = ScenarioResultEntry(entry=sr) + assert entry.error_attack_result_ids_json is None + recovered = entry.get_scenario_result() + assert recovered.error_attack_result_ids == [] diff --git a/tests/unit/models/test_attack_result.py b/tests/unit/models/test_attack_result.py index 12dd64e260..874d924846 100644 --- a/tests/unit/models/test_attack_result.py +++ b/tests/unit/models/test_attack_result.py @@ -8,6 +8,7 @@ from pyrit.identifiers.atomic_attack_identifier import build_atomic_attack_identifier from pyrit.memory.memory_models import AttackResultEntry from pyrit.models.attack_result import AttackOutcome, AttackResult +from pyrit.models.retry_event import RetryEvent class TestAttackResultDeprecation: @@ -206,3 +207,147 @@ def test_naive_entry_timestamp_is_normalized_to_utc_on_hydration(self) -> None: assert hydrated.timestamp is not None assert hydrated.timestamp.tzinfo is timezone.utc assert hydrated.timestamp.replace(tzinfo=None) == datetime(2026, 4, 17, 12, 0, 0) # noqa: DTZ001 + + +class TestAttackResultErrorFields: + """Tests for the error and retry fields on AttackResult.""" + + def test_error_fields_default_to_none(self) -> None: + """AttackResult without error fields defaults to None/empty.""" + result = AttackResult(conversation_id="c1", objective="test") + assert result.error_message is None + assert result.error_type is None + assert result.error_traceback is None + assert result.retry_events == [] + assert result.total_retries == 0 + + def test_error_fields_set_correctly(self) -> None: + """AttackResult stores error fields when provided.""" + result = AttackResult( + conversation_id="c1", + objective="test", + error_message="Connection refused", + error_type="ConnectionError", + error_traceback="Traceback (most recent call last):\n ...", + total_retries=3, + ) + assert result.error_message == "Connection refused" + assert result.error_type == "ConnectionError" + assert "Traceback" in result.error_traceback + assert result.total_retries == 3 + + def test_retry_events_stored_on_result(self) -> None: + """AttackResult stores retry events.""" + events = [ + RetryEvent(attempt_number=1, function_name="fn1", exception_type="TimeoutError"), + RetryEvent(attempt_number=2, function_name="fn1", exception_type="TimeoutError"), + ] + result = AttackResult( + conversation_id="c1", + objective="test", + retry_events=events, + total_retries=2, + ) + assert len(result.retry_events) == 2 + assert result.retry_events[0].attempt_number == 1 + assert result.retry_events[1].attempt_number == 2 + + +class TestAttackResultErrorRoundTrip: + """Tests that error/retry fields survive the AttackResult -> AttackResultEntry -> AttackResult round-trip.""" + + def test_error_fields_roundtrip(self) -> None: + """Error fields are serialized to entry and deserialized back.""" + original = AttackResult( + conversation_id="c1", + objective="test", + outcome=AttackOutcome.FAILURE, + error_message="Rate limit hit", + error_type="RateLimitError", + error_traceback="Traceback...\n File ...", + total_retries=5, + ) + entry = AttackResultEntry(entry=original) + + # Verify serialized values on entry + assert entry.error_message == "Rate limit hit" + assert entry.error_type == "RateLimitError" + assert entry.error_traceback == "Traceback...\n File ..." + assert entry.total_retries == 5 + + # Deserialize back + hydrated = entry.get_attack_result() + assert hydrated.error_message == "Rate limit hit" + assert hydrated.error_type == "RateLimitError" + assert hydrated.error_traceback == "Traceback...\n File ..." + assert hydrated.total_retries == 5 + + def test_retry_events_roundtrip(self) -> None: + """Retry events are serialized to JSON and deserialized back.""" + events = [ + RetryEvent( + attempt_number=1, + function_name="send_async", + exception_type="TimeoutError", + exception_message="timed out", + component_role="target", + component_name="AzureTarget", + endpoint="https://api.azure.com", + elapsed_seconds=5.5, + ), + RetryEvent( + attempt_number=2, + function_name="send_async", + exception_type="RateLimitError", + exception_message="429", + elapsed_seconds=10.0, + ), + ] + original = AttackResult( + conversation_id="c1", + objective="test", + retry_events=events, + total_retries=2, + ) + entry = AttackResultEntry(entry=original) + assert entry.retry_events_json is not None + + hydrated = entry.get_attack_result() + assert len(hydrated.retry_events) == 2 + assert hydrated.retry_events[0].attempt_number == 1 + assert hydrated.retry_events[0].function_name == "send_async" + assert hydrated.retry_events[0].exception_type == "TimeoutError" + assert hydrated.retry_events[0].component_name == "AzureTarget" + assert hydrated.retry_events[1].attempt_number == 2 + assert hydrated.retry_events[1].exception_type == "RateLimitError" + assert hydrated.total_retries == 2 + + def test_no_error_fields_roundtrip(self) -> None: + """AttackResult without error fields round-trips cleanly.""" + original = AttackResult( + conversation_id="c1", + objective="test", + outcome=AttackOutcome.SUCCESS, + ) + entry = AttackResultEntry(entry=original) + assert entry.error_message is None + assert entry.error_type is None + assert entry.retry_events_json is None + assert entry.total_retries == 0 + + hydrated = entry.get_attack_result() + assert hydrated.error_message is None + assert hydrated.error_type is None + assert hydrated.retry_events == [] + assert hydrated.total_retries == 0 + + def test_traceback_truncation(self) -> None: + """Very long tracebacks are truncated to 10KB.""" + long_traceback = "x" * 20000 + original = AttackResult( + conversation_id="c1", + objective="test", + error_traceback=long_traceback, + ) + entry = AttackResultEntry(entry=original) + assert len(entry.error_traceback) == 10240 diff --git a/tests/unit/models/test_retry_event.py b/tests/unit/models/test_retry_event.py new file mode 100644 index 0000000000..09f20dfca7 --- /dev/null +++ b/tests/unit/models/test_retry_event.py @@ -0,0 +1,123 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from datetime import datetime, timezone + +from pyrit.models.retry_event import RetryEvent + + +class TestRetryEvent: + """Tests for the RetryEvent dataclass.""" + + def test_defaults(self) -> None: + """RetryEvent constructed with minimal args gets correct defaults.""" + evt = RetryEvent(attempt_number=1, function_name="test_fn") + assert evt.attempt_number == 1 + assert evt.function_name == "test_fn" + assert evt.exception_type == "" + assert evt.exception_message == "" + assert evt.component_role == "" + assert evt.component_name is None + assert evt.endpoint is None + assert evt.elapsed_seconds == 0.0 + assert evt.timestamp is not None + assert evt.timestamp.tzinfo is timezone.utc + + def test_full_construction(self) -> None: + """RetryEvent constructed with all args stores them correctly.""" + ts = datetime(2026, 5, 7, 12, 0, 0, tzinfo=timezone.utc) + evt = RetryEvent( + attempt_number=3, + function_name="send_prompt_async", + exception_type="RateLimitError", + exception_message="Rate limit exceeded", + component_role="objective_target", + component_name="OpenAIChatTarget", + endpoint="https://api.openai.com/v1/chat", + elapsed_seconds=5.123, + timestamp=ts, + ) + assert evt.attempt_number == 3 + assert evt.function_name == "send_prompt_async" + assert evt.exception_type == "RateLimitError" + assert evt.exception_message == "Rate limit exceeded" + assert evt.component_role == "objective_target" + assert evt.component_name == "OpenAIChatTarget" + assert evt.endpoint == "https://api.openai.com/v1/chat" + assert evt.elapsed_seconds == 5.123 + assert evt.timestamp == ts + + def test_to_dict(self) -> None: + """to_dict returns a JSON-serializable dictionary.""" + evt = RetryEvent( + attempt_number=2, + function_name="fn", + exception_type="ValueError", + exception_message="bad input", + component_role="scorer", + component_name="TFScorer", + endpoint="https://example.com", + elapsed_seconds=1.5, + ) + d = evt.to_dict() + assert d["attempt_number"] == 2 + assert d["function_name"] == "fn" + assert d["exception_type"] == "ValueError" + assert d["exception_message"] == "bad input" + assert d["component_role"] == "scorer" + assert d["component_name"] == "TFScorer" + assert d["endpoint"] == "https://example.com" + assert d["elapsed_seconds"] == 1.5 + assert "timestamp" in d + + def test_from_dict_roundtrip(self) -> None: + """from_dict correctly reconstructs a RetryEvent from to_dict output.""" + original = RetryEvent( + attempt_number=1, + function_name="call_target", + exception_type="TimeoutError", + exception_message="Request timed out", + component_role="objective_target", + component_name="AzureTarget", + endpoint="https://azure.openai.com", + elapsed_seconds=10.0, + ) + d = original.to_dict() + restored = RetryEvent.from_dict(d) + + assert restored.attempt_number == original.attempt_number + assert restored.function_name == original.function_name + assert restored.exception_type == original.exception_type + assert restored.exception_message == original.exception_message + assert restored.component_role == original.component_role + assert restored.component_name == original.component_name + assert restored.endpoint == original.endpoint + assert restored.elapsed_seconds == original.elapsed_seconds + + def test_from_dict_missing_optional_fields(self) -> None: + """from_dict handles missing optional fields gracefully.""" + d = { + "attempt_number": 1, + "function_name": "fn", + "timestamp": "2026-05-07T12:00:00+00:00", + } + evt = RetryEvent.from_dict(d) + assert evt.attempt_number == 1 + assert evt.function_name == "fn" + assert evt.exception_type == "" + assert evt.component_name is None + assert evt.endpoint is None + assert evt.elapsed_seconds == 0.0 + + def test_from_dict_timestamp_parsing(self) -> None: + """from_dict correctly parses ISO format timestamp.""" + d = { + "attempt_number": 1, + "function_name": "fn", + "timestamp": "2026-05-07T12:30:00+00:00", + } + evt = RetryEvent.from_dict(d) + assert evt.timestamp.year == 2026 + assert evt.timestamp.month == 5 + assert evt.timestamp.hour == 12 + assert evt.timestamp.minute == 30 diff --git a/tests/unit/models/test_scenario_result.py b/tests/unit/models/test_scenario_result.py index 68a1cef1a9..02af031429 100644 --- a/tests/unit/models/test_scenario_result.py +++ b/tests/unit/models/test_scenario_result.py @@ -165,3 +165,24 @@ def test_normalize_scenario_name_already_pascal(self): def test_normalize_scenario_name_mixed_case_with_underscore(self): assert ScenarioResult.normalize_scenario_name("Content_harms") == "Content_harms" + + def test_error_attack_result_ids_defaults_to_empty(self): + """error_attack_result_ids defaults to empty list.""" + sr = ScenarioResult( + scenario_identifier=_make_scenario_identifier(), + objective_target_identifier=ComponentIdentifier.from_dict({}), + attack_results={}, + objective_scorer_identifier=ComponentIdentifier.from_dict({}), + ) + assert sr.error_attack_result_ids == [] + + def test_error_attack_result_ids_stored(self): + """error_attack_result_ids are stored correctly.""" + sr = ScenarioResult( + scenario_identifier=_make_scenario_identifier(), + objective_target_identifier=ComponentIdentifier.from_dict({}), + attack_results={}, + objective_scorer_identifier=ComponentIdentifier.from_dict({}), + error_attack_result_ids=["id-1", "id-2"], + ) + assert sr.error_attack_result_ids == ["id-1", "id-2"]