diff --git a/.github/instructions/style-guide.instructions.md b/.github/instructions/style-guide.instructions.md index edee2e0b59..4d9bf7ff4f 100644 --- a/.github/instructions/style-guide.instructions.md +++ b/.github/instructions/style-guide.instructions.md @@ -90,6 +90,7 @@ def validate_input(self, data: dict) -> None: # Should be private - `str | None` not `Optional[str]` - `int | float` not `Union[int, float]` - Still import `Any`, `Literal`, `TypeVar`, `Protocol`, `cast` etc. from `typing` as needed +- **This rule applies to docstrings and comments too.** Argument type references inside docstrings (e.g. `Args:` blocks) and any comment mentioning a type should use the modern form so the docs stay consistent with the signatures. ```python # CORRECT diff --git a/doc/blog/2024_12_3.md b/doc/blog/2024_12_3.md index d18e05c20f..ca6f0a3e03 100644 --- a/doc/blog/2024_12_3.md +++ b/doc/blog/2024_12_3.md @@ -40,7 +40,7 @@ It turns out, yes, we can. `CrescendoOrchestrator`, `PairOrchestrator`, `RedTeam - `max_turns` defines the maximum number of conversation turns. - `prompt_converters` are used to modify prompts before sending them to the target. - `objective_scorer` evaluates whether the objective was achieved. -- `run_attack_async(objective: str, memory_labels: Optional[dict[str, str]] = None)` executes the attack and always returns a `OrchestratorResult`, which includes information about the conversation and the outcome. +- `run_attack_async(objective: str, memory_labels: dict[str, str] | None = None)` executes the attack and always returns a `OrchestratorResult`, which includes information about the conversation and the outcome. - `run_attacks_async` enables parallelized attacks. - `print_conversation_async` is now standardized and prints the "best" conversation (when multiple exist). 
diff --git a/doc/code/memory/0_memory.md b/doc/code/memory/0_memory.md index de45abe2b7..f736705752 100644 --- a/doc/code/memory/0_memory.md +++ b/doc/code/memory/0_memory.md @@ -12,7 +12,7 @@ At the beginning of each notebook, make sure to call: # Import the specific constant for the MemoryDatabaseType, or provide the literal value from pyrit.setup import initialize_pyrit_async, IN_MEMORY, SQLITE, AZURE_SQL -await initialize_pyrit_async(memory_db_type: MemoryDatabaseType, memory_instance_kwargs: Optional[Any]) +await initialize_pyrit_async(memory_db_type: MemoryDatabaseType, memory_instance_kwargs: Any | None) ``` The `MemoryDatabaseType` is a `Literal` with 3 options: IN_MEMORY, SQLITE, AZURE_SQL. (Read more below) diff --git a/doc/code/registry/0_registry.md b/doc/code/registry/0_registry.md index b3e0d9c24b..b29e0fa5fe 100644 --- a/doc/code/registry/0_registry.md +++ b/doc/code/registry/0_registry.md @@ -15,7 +15,7 @@ PyRIT has two registry patterns for different use cases: | Type | Stores | Use Case | |------|--------|----------| -| **Class Registry** | Classes (Type[T]) | Components instantiated with user-provided parameters | +| **Class Registry** | Classes (type[T]) | Components instantiated with user-provided parameters | | **Instance Registry** | Pre-configured instances | Components requiring complex setup before use | ## Common API (RegistryProtocol) @@ -44,7 +44,7 @@ def show_registry_contents(registry: RegistryProtocol) -> None: | Aspect | Class Registry | Instance Registry | |--------|----------------|-------------------| -| Stores | Classes (Type[T]) | Instances (T) | +| Stores | Classes (type[T]) | Instances (T) | | Registration | Automatic discovery | Explicit via `register()` | | Returns | Class to instantiate | Ready-to-use instance | | Instantiation | Caller provides parameters | Pre-configured by initializer | diff --git a/doc/code/setup/default_values.md b/doc/code/setup/default_values.md index 04efe67e24..ad3ae7dcb3 100644 --- 
a/doc/code/setup/default_values.md +++ b/doc/code/setup/default_values.md @@ -23,7 +23,7 @@ from pyrit.common.apply_defaults import apply_defaults class MyConverter(PromptConverter): @apply_defaults - def __init__(self, *, converter_target: Optional[PromptTarget] = None, temperature: Optional[float] = None): + def __init__(self, *, converter_target: PromptTarget | None = None, temperature: float | None = None): self.converter_target = converter_target self.temperature = temperature ``` diff --git a/doc/code/targets/11_message_normalizer.ipynb b/doc/code/targets/11_message_normalizer.ipynb index 63d9b2faef..93d2489313 100644 --- a/doc/code/targets/11_message_normalizer.ipynb +++ b/doc/code/targets/11_message_normalizer.ipynb @@ -19,8 +19,8 @@ "## Base Classes\n", "\n", "There are two base normalizer types:\n", - "- **`MessageListNormalizer[T]`**: Converts `List[Message]` → `List[T]` (e.g., to `ChatMessage` objects)\n", - "- **`MessageStringNormalizer`**: Converts `List[Message]` → `str` (e.g., to ChatML format)\n", + "- **`MessageListNormalizer[T]`**: Converts `list[Message]` → `list[T]` (e.g., to `ChatMessage` objects)\n", + "- **`MessageStringNormalizer`**: Converts `list[Message]` → `str` (e.g., to ChatML format)\n", "\n", "Some normalizers implement both interfaces." 
] diff --git a/doc/code/targets/11_message_normalizer.py b/doc/code/targets/11_message_normalizer.py index 583afed606..9f005339be 100644 --- a/doc/code/targets/11_message_normalizer.py +++ b/doc/code/targets/11_message_normalizer.py @@ -23,8 +23,8 @@ # ## Base Classes # # There are two base normalizer types: -# - **`MessageListNormalizer[T]`**: Converts `List[Message]` → `List[T]` (e.g., to `ChatMessage` objects) -# - **`MessageStringNormalizer`**: Converts `List[Message]` → `str` (e.g., to ChatML format) +# - **`MessageListNormalizer[T]`**: Converts `list[Message]` → `list[T]` (e.g., to `ChatMessage` objects) +# - **`MessageStringNormalizer`**: Converts `list[Message]` → `str` (e.g., to ChatML format) # # Some normalizers implement both interfaces. diff --git a/doc/getting_started/troubleshooting/deploy_hf_model_aml.ipynb b/doc/getting_started/troubleshooting/deploy_hf_model_aml.ipynb index 2b5600ce5a..971a066c83 100644 --- a/doc/getting_started/troubleshooting/deploy_hf_model_aml.ipynb +++ b/doc/getting_started/troubleshooting/deploy_hf_model_aml.ipynb @@ -128,14 +128,12 @@ "metadata": {}, "outputs": [], "source": [ - "from typing import Union\n", - "\n", "from azure.ai.ml import MLClient\n", "from azure.core.exceptions import ResourceNotFoundError\n", "from azure.identity import DefaultAzureCredential, InteractiveBrowserCredential\n", "\n", "try:\n", - " credential: Union[DefaultAzureCredential, InteractiveBrowserCredential] = DefaultAzureCredential()\n", + " credential: DefaultAzureCredential | InteractiveBrowserCredential = DefaultAzureCredential()\n", " credential.get_token(\"https://management.azure.com/.default\")\n", "except Exception as ex:\n", " credential = InteractiveBrowserCredential()\n", diff --git a/doc/getting_started/troubleshooting/deploy_hf_model_aml.py b/doc/getting_started/troubleshooting/deploy_hf_model_aml.py index b55a818a6c..b6fa4a1e67 100644 --- a/doc/getting_started/troubleshooting/deploy_hf_model_aml.py +++ 
b/doc/getting_started/troubleshooting/deploy_hf_model_aml.py @@ -106,14 +106,13 @@ # Set up the `DefaultAzureCredential` for seamless authentication with Azure services. This method should handle most authentication scenarios. If you encounter issues, refer to the [Azure Identity documentation](https://docs.microsoft.com/en-us/python/api/azure-identity/azure.identity?view=azure-python) for alternative credentials. # # %% -from typing import Union from azure.ai.ml import MLClient from azure.core.exceptions import ResourceNotFoundError from azure.identity import DefaultAzureCredential, InteractiveBrowserCredential try: - credential: Union[DefaultAzureCredential, InteractiveBrowserCredential] = DefaultAzureCredential() + credential: DefaultAzureCredential | InteractiveBrowserCredential = DefaultAzureCredential() credential.get_token("https://management.azure.com/.default") except Exception as ex: credential = InteractiveBrowserCredential() diff --git a/doc/getting_started/troubleshooting/download_and_register_hf_model_aml.ipynb b/doc/getting_started/troubleshooting/download_and_register_hf_model_aml.ipynb index 29e7fdec00..65cb1229a3 100644 --- a/doc/getting_started/troubleshooting/download_and_register_hf_model_aml.ipynb +++ b/doc/getting_started/troubleshooting/download_and_register_hf_model_aml.ipynb @@ -71,7 +71,6 @@ "source": [ "# Import the Azure ML SDK components required for workspace connection and model management.\n", "import os\n", - "from typing import Union\n", "\n", "# Import necessary libraries for Azure ML operations and authentication\n", "from azure.ai.ml import MLClient, UserIdentityConfiguration\n", @@ -201,7 +200,7 @@ "source": [ "# Setup Azure credentials, preferring DefaultAzureCredential and falling back to InteractiveBrowserCredential if necessary\n", "try:\n", - " credential: Union[DefaultAzureCredential, InteractiveBrowserCredential] = DefaultAzureCredential()\n", + " credential: DefaultAzureCredential | InteractiveBrowserCredential = 
DefaultAzureCredential()\n", " # Verify if the default credential can fetch a token successfully\n", " credential.get_token(\"https://management.azure.com/.default\")\n", "except Exception as ex:\n", diff --git a/doc/getting_started/troubleshooting/download_and_register_hf_model_aml.py b/doc/getting_started/troubleshooting/download_and_register_hf_model_aml.py index 34013c449f..251e49ecf2 100644 --- a/doc/getting_started/troubleshooting/download_and_register_hf_model_aml.py +++ b/doc/getting_started/troubleshooting/download_and_register_hf_model_aml.py @@ -61,7 +61,6 @@ # %% # Import the Azure ML SDK components required for workspace connection and model management. import os -from typing import Union # Import necessary libraries for Azure ML operations and authentication from azure.ai.ml import MLClient, UserIdentityConfiguration @@ -160,7 +159,7 @@ # %% # Setup Azure credentials, preferring DefaultAzureCredential and falling back to InteractiveBrowserCredential if necessary try: - credential: Union[DefaultAzureCredential, InteractiveBrowserCredential] = DefaultAzureCredential() + credential: DefaultAzureCredential | InteractiveBrowserCredential = DefaultAzureCredential() # Verify if the default credential can fetch a token successfully credential.get_token("https://management.azure.com/.default") except Exception as ex: diff --git a/pyproject.toml b/pyproject.toml index ee6ef1b728..243d636931 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -341,8 +341,6 @@ ignore = [ "DOC502", # Raised exception is not explicitly raised "PERF203", # try-except-in-loop (intentional per-item error handling) "SIM117", # multiple-with-statements (combining often exceeds line length) - "UP007", # non-pep604-annotation-union (keep Union[X, Y] syntax) - "UP045", # non-pep604-annotation-optional (keep Optional[X] syntax) ] extend-select = [ "D204", # 1 blank line required after class docstring diff --git a/pyrit/analytics/conversation_analytics.py 
b/pyrit/analytics/conversation_analytics.py index 3440ddce60..12e2d8b63a 100644 --- a/pyrit/analytics/conversation_analytics.py +++ b/pyrit/analytics/conversation_analytics.py @@ -57,11 +57,11 @@ def get_similar_chat_messages_by_embedding( Retrieve chat messages that are similar to the given embedding based on cosine similarity. Args: - chat_message_embedding (List[float]): The embedding of the chat message to find similar messages for. + chat_message_embedding (list[float]): The embedding of the chat message to find similar messages for. threshold (float): The similarity threshold for considering messages as similar. Defaults to 0.8. Returns: - List[ConversationMessageWithSimilarity]: A list of ConversationMessageWithSimilarity objects representing + list[ConversationMessageWithSimilarity]: A list of ConversationMessageWithSimilarity objects representing the similar chat messages based on embedding similarity. """ all_embdedding_memory = self.memory_interface.get_all_embeddings() diff --git a/pyrit/analytics/result_analysis.py b/pyrit/analytics/result_analysis.py index 9e158ed88a..f4e400a7cc 100644 --- a/pyrit/analytics/result_analysis.py +++ b/pyrit/analytics/result_analysis.py @@ -4,7 +4,7 @@ from collections import defaultdict from collections.abc import Sequence from dataclasses import dataclass -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from pyrit.models import ( AttackOutcome, @@ -22,7 +22,7 @@ class AttackStats: """Statistics for attack analysis results.""" - success_rate: Optional[float] + success_rate: float | None total_decided: int successes: int failures: int @@ -118,7 +118,7 @@ def get_cached_results_for_technique( *, technique_eval_hash: str, objective_target_eval_hash: str, - additional_filters: Optional[Sequence[IdentifierFilter]] = None, + additional_filters: Sequence[IdentifierFilter] | None = None, ) -> list[AttackResult]: """ Return cached AttackResults matching a (technique × objective target) pair. 
@@ -144,7 +144,7 @@ def get_cached_results_for_technique( (also exposed as ``AtomicAttack.technique_eval_hash``). objective_target_eval_hash (str): Behavioral eval hash of the objective target, as produced by ``ObjectiveTargetEvaluationIdentifier.eval_hash``. - additional_filters (Optional[Sequence[IdentifierFilter]]): Extra + additional_filters (Sequence[IdentifierFilter] | None): Extra ``IdentifierFilter`` predicates appended to the SQL pre-filter. Defaults to None. @@ -170,7 +170,7 @@ def get_cached_results_for_technique( return matches -def _objective_target_eval_hash_for(attack_result: AttackResult) -> Optional[str]: +def _objective_target_eval_hash_for(attack_result: AttackResult) -> str | None: """ Return the ObjectiveTargetEvaluationIdentifier eval hash for a result. @@ -182,7 +182,7 @@ def _objective_target_eval_hash_for(attack_result: AttackResult) -> Optional[str ``atomic_attack_identifier`` tree should be inspected. Returns: - Optional[str]: The ``ObjectiveTargetEvaluationIdentifier.eval_hash`` + str | None: The ``ObjectiveTargetEvaluationIdentifier.eval_hash`` computed from the persisted objective-target identifier, or ``None`` when the identifier tree is missing expected nodes (e.g. legacy rows or atomic attacks without a distinct objective diff --git a/pyrit/auth/azure_auth.py b/pyrit/auth/azure_auth.py index 82a39e082d..8a6131d3a4 100644 --- a/pyrit/auth/azure_auth.py +++ b/pyrit/auth/azure_auth.py @@ -6,7 +6,7 @@ import inspect import logging import time -from typing import TYPE_CHECKING, Any, Union, cast +from typing import TYPE_CHECKING, Any, cast import msal from azure.core.credentials import AccessToken @@ -41,7 +41,7 @@ class TokenProviderCredential: get_azure_token_provider) and Azure SDK clients that require a TokenCredential object. 
""" - def __init__(self, token_provider: Callable[[], Union[str, Callable[..., Any]]]) -> None: + def __init__(self, token_provider: Callable[[], str | Callable[..., Any]]) -> None: """ Initialize TokenProviderCredential. @@ -75,7 +75,7 @@ class AsyncTokenProviderCredential: async clients that require an AsyncTokenCredential object (with async def get_token). """ - def __init__(self, token_provider: Callable[[], Union[str, Awaitable[str]]]) -> None: + def __init__(self, token_provider: Callable[[], str | Awaitable[str]]) -> None: """ Initialize AsyncTokenProviderCredential. @@ -394,14 +394,14 @@ def get_azure_openai_auth(endpoint: str) -> Callable[[], Awaitable[str]]: return get_azure_async_token_provider(scope) -def get_speech_config(resource_id: Union[str, None], key: Union[str, None], region: str) -> speechsdk.SpeechConfig: +def get_speech_config(resource_id: str | None, key: str | None, region: str) -> speechsdk.SpeechConfig: """ Get the speech config using key/region pair (for key auth scenarios) or resource_id/region pair (for Entra auth scenarios). Args: - resource_id (Union[str, None]): The resource ID to get the token for. - key (Union[str, None]): The Azure Speech key + resource_id (str | None): The resource ID to get the token for. + key (str | None): The Azure Speech key region (str): The region to get the token for. 
Returns: @@ -437,8 +437,8 @@ def get_speech_config(resource_id: Union[str, None], key: Union[str, None], regi async def get_speech_config_async( *, token_provider: Callable[[], str | Awaitable[str]] | None, - resource_id: Union[str, None], - key: Union[str, None], + resource_id: str | None, + key: str | None, region: str, ) -> speechsdk.SpeechConfig: """ diff --git a/pyrit/auth/copilot_authenticator.py b/pyrit/auth/copilot_authenticator.py index a069af33ff..ba83c023c6 100644 --- a/pyrit/auth/copilot_authenticator.py +++ b/pyrit/auth/copilot_authenticator.py @@ -7,7 +7,7 @@ import os import sys from datetime import datetime, timedelta, timezone -from typing import Any, Optional +from typing import Any from msal_extensions import FilePersistence, build_encrypted_persistence @@ -196,7 +196,7 @@ def _create_persistent_cache(cache_file: str, fallback_to_plaintext: bool = Fals logger.error(f"Encryption unavailable ({e}) and fallback_to_plaintext is False. Cannot proceed.") raise - async def _get_cached_token_if_available_and_valid_async(self) -> Optional[dict[str, Any]]: + async def _get_cached_token_if_available_and_valid_async(self) -> dict[str, Any] | None: """ Retrieve and validate cached token. @@ -258,7 +258,7 @@ async def _get_cached_token_if_available_and_valid_async(self) -> Optional[dict[ logger.error(f"Failed to load cached token ({error_name}): {e}") return None - def _save_token_to_cache(self, *, token: str, expires_in: Optional[int] = None) -> None: + def _save_token_to_cache(self, *, token: str, expires_in: int | None = None) -> None: """ Save token to persistent cache with metadata. @@ -301,12 +301,12 @@ def _clear_token_cache(self) -> None: except Exception as e: logger.error(f"Failed to clear cache: {e}") - async def _fetch_access_token_with_playwright_async(self) -> Optional[str]: + async def _fetch_access_token_with_playwright_async(self) -> str | None: """ Fetch access token using Playwright browser automation. 
Returns: - Optional[str]: The bearer token if successfully retrieved, None otherwise. + str | None: The bearer token if successfully retrieved, None otherwise. Raises: RuntimeError: If Playwright is not installed or browser launch fails. @@ -339,35 +339,35 @@ async def _fetch_access_token_with_playwright_async(self) -> Optional[str]: # If not on Windows or using the right loop already, proceed normally return await self._run_playwright_browser_automation_async() - async def _run_playwright_in_thread_async(self) -> Optional[str]: + async def _run_playwright_in_thread_async(self) -> str | None: """ Run Playwright browser automation in a separate thread with ProactorEventLoop. This is needed on Windows when the main loop is SelectorEventLoop (e.g., in Jupyter). Returns: - Optional[str]: The bearer token if successfully retrieved, None otherwise. + str | None: The bearer token if successfully retrieved, None otherwise. """ - def run_in_new_loop() -> Optional[str]: + def run_in_new_loop() -> str | None: if sys.platform == "win32": new_loop = asyncio.ProactorEventLoop() else: new_loop = asyncio.new_event_loop() asyncio.set_event_loop(new_loop) try: - result: Optional[str] = new_loop.run_until_complete(self._run_playwright_browser_automation_async()) + result: str | None = new_loop.run_until_complete(self._run_playwright_browser_automation_async()) return result finally: new_loop.close() return await asyncio.get_running_loop().run_in_executor(None, run_in_new_loop) - async def _run_playwright_browser_automation_async(self) -> Optional[str]: + async def _run_playwright_browser_automation_async(self) -> str | None: """ Execute the actual Playwright browser automation to fetch the access token. Returns: - Optional[str]: The bearer token if successfully retrieved, None otherwise. + str | None: The bearer token if successfully retrieved, None otherwise. Raises: ValueError: If the username is not set. 
diff --git a/pyrit/auth/manual_copilot_authenticator.py b/pyrit/auth/manual_copilot_authenticator.py index b175118878..2f50209d8a 100644 --- a/pyrit/auth/manual_copilot_authenticator.py +++ b/pyrit/auth/manual_copilot_authenticator.py @@ -3,7 +3,7 @@ import logging import os -from typing import Any, Optional +from typing import Any import jwt @@ -36,12 +36,12 @@ class ManualCopilotAuthenticator(Authenticator): #: Environment variable name for the Copilot access token ACCESS_TOKEN_ENV_VAR: str = "COPILOT_ACCESS_TOKEN" - def __init__(self, *, access_token: Optional[str] = None) -> None: + def __init__(self, *, access_token: str | None = None) -> None: """ Initialize the ManualCopilotAuthenticator with a pre-obtained access token. Args: - access_token (Optional[str]): A valid JWT access token for Microsoft Copilot. + access_token (str | None): A valid JWT access token for Microsoft Copilot. This token can be obtained from browser DevTools when connected to Copilot. If None, the token will be read from the ``COPILOT_ACCESS_TOKEN`` environment variable. diff --git a/pyrit/auxiliary_attacks/gcg/attack/base/attack_manager.py b/pyrit/auxiliary_attacks/gcg/attack/base/attack_manager.py index 787154fe75..ef61ebc4d7 100644 --- a/pyrit/auxiliary_attacks/gcg/attack/base/attack_manager.py +++ b/pyrit/auxiliary_attacks/gcg/attack/base/attack_manager.py @@ -10,7 +10,7 @@ import random import time from copy import deepcopy -from typing import Any, Optional +from typing import Any import numpy as np import pandas as pd @@ -132,7 +132,7 @@ def __init__( target: str, tokenizer: Any, control_init: str = "! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! !", - test_prefixes: Optional[list[str]] = None, + test_prefixes: list[str] | None = None, ) -> None: """ Initializes the AttackPrompt object with the provided parameters. @@ -417,16 +417,16 @@ def __init__( targets: list[str], tokenizer: Any, control_init: str = "! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! 
!", - test_prefixes: Optional[list[str]] = None, - managers: Optional[dict[str, type[AttackPrompt]]] = None, + test_prefixes: list[str] | None = None, + managers: dict[str, type[AttackPrompt]] | None = None, ) -> None: """ Initializes the PromptManager object with the provided parameters. Args: - goals (List[str]): + goals (list[str]): The list of intended goals of the attack - targets (List[str]): + targets (list[str]): The list of targets of the attack tokenizer (Transformer Tokenizer): The tokenizer used to convert text into tokens. Must have a chat template configured. @@ -539,22 +539,22 @@ def __init__( targets: list[str], workers: list[ModelWorker], control_init: str = "! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! !", - test_prefixes: Optional[list[str]] = None, - logfile: Optional[str] = None, - managers: Optional[dict[str, Any]] = None, - test_goals: Optional[list[str]] = None, - test_targets: Optional[list[str]] = None, - test_workers: Optional[list[ModelWorker]] = None, + test_prefixes: list[str] | None = None, + logfile: str | None = None, + managers: dict[str, Any] | None = None, + test_goals: list[str] | None = None, + test_targets: list[str] | None = None, + test_workers: list[ModelWorker] | None = None, ) -> None: """ Initializes the MultiPromptAttack object with the provided parameters. Args: - goals (List[str]): + goals (list[str]): The list of intended goals of the attack - targets (List[str]): + targets (list[str]): The list of targets of the attack - workers (List[Worker]): + workers (list[Worker]): The list of workers used in the attack control_init (str, optional): A string used to control the attack (default is "! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! 
!") @@ -619,7 +619,7 @@ def get_filtered_cands( worker_index: int, control_cand: torch.Tensor, filter_cand: bool = True, - curr_control: Optional[str] = None, + curr_control: str | None = None, ) -> list[str]: cands, count = [], 0 worker = self.workers[worker_index] @@ -656,8 +656,8 @@ def run( topk: int = 256, temp: float = 1.0, allow_non_ascii: bool = True, - target_weight: Optional[float] = None, - control_weight: Optional[float] = None, + target_weight: float | None = None, + control_weight: float | None = None, anneal: bool = True, anneal_from: int = 0, prev_loss: float = np.inf, @@ -873,23 +873,23 @@ def __init__( progressive_goals: bool = True, progressive_models: bool = True, control_init: str = "! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! !", - test_prefixes: Optional[list[str]] = None, - logfile: Optional[str] = None, - managers: Optional[dict[str, Any]] = None, - test_goals: Optional[list[str]] = None, - test_targets: Optional[list[str]] = None, - test_workers: Optional[list[ModelWorker]] = None, + test_prefixes: list[str] | None = None, + logfile: str | None = None, + managers: dict[str, Any] | None = None, + test_goals: list[str] | None = None, + test_targets: list[str] | None = None, + test_workers: list[ModelWorker] | None = None, **kwargs: Any, ) -> None: """ Initializes the ProgressiveMultiPromptAttack object with the provided parameters. Args: - goals (List[str]): + goals (list[str]): The list of intended goals of the attack - targets (List[str]): + targets (list[str]): The list of targets of the attack - workers (List[Worker]): + workers (list[Worker]): The list of workers used in the attack progressive_goals (bool, optional): If true, goals progress over time (default is True) @@ -897,17 +897,17 @@ def __init__( If true, models progress over time (default is True) control_init (str, optional): A string used to control the attack (default is "! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! 
!") - test_prefixes (List[str], optional): + test_prefixes (list[str], optional): A list of prefixes to test the attack (default is _DEFAULT_TEST_PREFIXES). logfile (str, optional): A file to which logs will be written managers (dict, optional): A dictionary of manager objects, required to create the prompts. - test_goals (List[str], optional): + test_goals (list[str], optional): The list of test goals of the attack - test_targets (List[str], optional): + test_targets (list[str], optional): The list of test targets of the attack - test_workers (List[Worker], optional): + test_workers (list[Worker], optional): The list of test workers used in the attack """ if test_prefixes is None: @@ -986,8 +986,8 @@ def run( topk: int = 256, temp: float = 1.0, allow_non_ascii: bool = False, - target_weight: Optional[float] = None, - control_weight: Optional[float] = None, + target_weight: float | None = None, + control_weight: float | None = None, anneal: bool = True, test_steps: int = 50, incr_control: bool = True, @@ -1119,12 +1119,12 @@ def __init__( targets: list[str], workers: list[ModelWorker], control_init: str = "! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! 
!", - test_prefixes: Optional[list[str]] = None, - logfile: Optional[str] = None, - managers: Optional[dict[str, Any]] = None, - test_goals: Optional[list[str]] = None, - test_targets: Optional[list[str]] = None, - test_workers: Optional[list[ModelWorker]] = None, + test_prefixes: list[str] | None = None, + logfile: str | None = None, + managers: dict[str, Any] | None = None, + test_goals: list[str] | None = None, + test_targets: list[str] | None = None, + test_workers: list[ModelWorker] | None = None, **kwargs: Any, ) -> None: """ @@ -1225,8 +1225,8 @@ def run( topk: int = 256, temp: float = 1.0, allow_non_ascii: bool = True, - target_weight: Optional[float] = None, - control_weight: Optional[float] = None, + target_weight: float | None = None, + control_weight: float | None = None, anneal: bool = True, test_steps: int = 50, incr_control: bool = True, @@ -1331,12 +1331,12 @@ def __init__( targets: list[str], workers: list[ModelWorker], control_init: str = "! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! 
!", - test_prefixes: Optional[list[str]] = None, - logfile: Optional[str] = None, - managers: Optional[dict[str, Any]] = None, - test_goals: Optional[list[str]] = None, - test_targets: Optional[list[str]] = None, - test_workers: Optional[list[ModelWorker]] = None, + test_prefixes: list[str] | None = None, + logfile: str | None = None, + managers: dict[str, Any] | None = None, + test_goals: list[str] | None = None, + test_targets: list[str] | None = None, + test_workers: list[ModelWorker] | None = None, **kwargs: Any, ) -> None: """ @@ -1549,7 +1549,7 @@ def __init__( self.tokenizer = tokenizer self.tasks: mp.JoinableQueue[Any] = mp.JoinableQueue() self.results: mp.JoinableQueue[Any] = mp.JoinableQueue() - self.process: Optional[mp.Process] = None + self.process: mp.Process | None = None @staticmethod def run(model: Any, tasks: mp.JoinableQueue[Any], results: mp.JoinableQueue[Any]) -> None: diff --git a/pyrit/auxiliary_attacks/gcg/experiments/log.py b/pyrit/auxiliary_attacks/gcg/experiments/log.py index bdd96c1ca4..2c9e58fea5 100644 --- a/pyrit/auxiliary_attacks/gcg/experiments/log.py +++ b/pyrit/auxiliary_attacks/gcg/experiments/log.py @@ -3,7 +3,7 @@ import logging import subprocess as sp -from typing import Any, Optional +from typing import Any logger = logging.getLogger(__name__) @@ -20,14 +20,14 @@ def log_params( *, params: Any, - param_keys: Optional[list[str]] = None, + param_keys: list[str] | None = None, ) -> None: """ Log selected parameters via Python logging. Args: params (Any): A config object with a `to_dict()` method containing all parameters. - param_keys (Optional[list[str]]): Keys to extract and log. Defaults to standard GCG training keys. + param_keys (list[str] | None): Keys to extract and log. Defaults to standard GCG training keys. 
""" if param_keys is None: param_keys = _DEFAULT_PARAM_KEYS diff --git a/pyrit/auxiliary_attacks/gcg/generator.py b/pyrit/auxiliary_attacks/gcg/generator.py index cf353a2adf..4c812594e9 100644 --- a/pyrit/auxiliary_attacks/gcg/generator.py +++ b/pyrit/auxiliary_attacks/gcg/generator.py @@ -38,7 +38,7 @@ import logging import time from dataclasses import dataclass, field -from typing import Any, Optional, overload +from typing import Any, overload import numpy as np import torch.multiprocessing as mp @@ -93,8 +93,8 @@ class GCGContext(PromptGeneratorStrategyContext): workers: list[Any] = field(default_factory=list) test_workers: list[Any] = field(default_factory=list) - attack: Optional[Any] = None - logfile_path: Optional[str] = None + attack: Any | None = None + logfile_path: str | None = None class GCGResult(PromptGeneratorStrategyResult): @@ -138,11 +138,11 @@ def __init__( self, *, models: list[GCGModelConfig], - algorithm: Optional[GCGAlgorithmConfig] = None, - strategy: Optional[GCGStrategyConfig] = None, - output: Optional[GCGOutputConfig] = None, - test_models: Optional[list[GCGModelConfig]] = None, - hf_token: Optional[str] = None, + algorithm: GCGAlgorithmConfig | None = None, + strategy: GCGStrategyConfig | None = None, + output: GCGOutputConfig | None = None, + test_models: list[GCGModelConfig] | None = None, + hf_token: str | None = None, ) -> None: """ Initialize the GCG generator. @@ -307,9 +307,9 @@ async def execute_async( *, goals: list[str], targets: list[str], - test_goals: Optional[list[str]] = None, - test_targets: Optional[list[str]] = None, - memory_labels: Optional[dict[str, str]] = None, + test_goals: list[str] | None = None, + test_targets: list[str] | None = None, + memory_labels: dict[str, str] | None = None, **kwargs: Any, ) -> GCGResult: ... 
diff --git a/pyrit/backend/mappers/attack_mappers.py b/pyrit/backend/mappers/attack_mappers.py index 5807c27bef..9d784648d4 100644 --- a/pyrit/backend/mappers/attack_mappers.py +++ b/pyrit/backend/mappers/attack_mappers.py @@ -17,7 +17,7 @@ import uuid from datetime import datetime, timedelta, timezone from pathlib import Path -from typing import TYPE_CHECKING, Optional, cast +from typing import TYPE_CHECKING, cast from urllib.parse import quote, urlparse from azure.identity.aio import DefaultAzureCredential @@ -155,7 +155,7 @@ async def _sign_blob_url_async(*, blob_url: str) -> str: return blob_url -def _resolve_media_url(*, value: Optional[str], data_type: str) -> Optional[str]: +def _resolve_media_url(*, value: str | None, data_type: str) -> str | None: """ For media path types, convert a local file path to a ``/api/media`` URL. @@ -311,7 +311,7 @@ def pyrit_scores_to_dto(scores: list[PyritScore]) -> list[Score]: ] -def _infer_mime_type(*, value: Optional[str], data_type: PromptDataType) -> Optional[str]: +def _infer_mime_type(*, value: str | None, data_type: PromptDataType) -> str | None: """ Infer MIME type from a value and its data type. @@ -335,9 +335,9 @@ def _infer_mime_type(*, value: Optional[str], data_type: PromptDataType) -> Opti def _build_filename( *, data_type: str, - sha256: Optional[str], - value: Optional[str], -) -> Optional[str]: + sha256: str | None, + value: str | None, +) -> str | None: """ Build a human-readable download filename from the data type and hash. @@ -355,7 +355,7 @@ def _build_filename( value: The original value (path or URL) used to infer file extension. Returns: - Optional[str]: A filename like ``image_a1b2c3d4e5f6.png``, or ``None`` for text-like types. + str | None: A filename like ``image_a1b2c3d4e5f6.png``, or ``None`` for text-like types. 
""" # Map data types to friendly prefixes prefix_map = { @@ -462,7 +462,7 @@ def request_piece_to_pyrit_message_piece( role: ChatMessageRole, conversation_id: str, sequence: int, - labels: Optional[dict[str, str]] = None, # deprecated + labels: dict[str, str] | None = None, # deprecated ) -> PyritMessagePiece: """ Convert a single request piece DTO to a PyRIT MessagePiece domain object. @@ -509,7 +509,7 @@ def request_to_pyrit_message( request: AddMessageRequest, conversation_id: str, sequence: int, - labels: Optional[dict[str, str]] = None, # deprecated + labels: dict[str, str] | None = None, # deprecated ) -> PyritMessage: """ Build a PyRIT Message from an AddMessageRequest DTO. diff --git a/pyrit/backend/mappers/converter_mappers.py b/pyrit/backend/mappers/converter_mappers.py index f1d097762d..a71b5aa537 100644 --- a/pyrit/backend/mappers/converter_mappers.py +++ b/pyrit/backend/mappers/converter_mappers.py @@ -5,8 +5,6 @@ Converter mappers – domain → DTO translation for converter-related models. """ -from typing import Optional - from pyrit.backend.models.converters import ConverterInstance from pyrit.prompt_converter import PromptConverter @@ -21,7 +19,7 @@ def converter_object_to_instance( converter_id: str, converter_obj: PromptConverter, *, - sub_converter_ids: Optional[list[str]] = None, + sub_converter_ids: list[str] | None = None, ) -> ConverterInstance: """ Build a ConverterInstance DTO from a registry converter object. 
diff --git a/pyrit/backend/middleware/auth.py b/pyrit/backend/middleware/auth.py index db7de281ea..012af51912 100644 --- a/pyrit/backend/middleware/auth.py +++ b/pyrit/backend/middleware/auth.py @@ -18,7 +18,7 @@ import logging import os from dataclasses import dataclass -from typing import Any, Optional +from typing import Any import httpx import jwt @@ -241,7 +241,7 @@ async def _resolve_excess_groups_async(self, claims: dict[str, Any], token: str) logger.warning("Failed to resolve group memberships: %s", e) return [] - def _validate_token(self, token: str) -> tuple[Optional[AuthenticatedUser], dict[str, Any]]: + def _validate_token(self, token: str) -> tuple[AuthenticatedUser | None, dict[str, Any]]: """ Validate a JWT against Entra ID JWKS. diff --git a/pyrit/backend/models/attacks.py b/pyrit/backend/models/attacks.py index 2f98f78b7e..6cb739cb5e 100644 --- a/pyrit/backend/models/attacks.py +++ b/pyrit/backend/models/attacks.py @@ -9,7 +9,7 @@ """ from datetime import datetime -from typing import Any, Literal, Optional +from typing import Any, Literal from pydantic import BaseModel, Field @@ -26,8 +26,8 @@ class Score(BaseModel): score_value: str = Field( ..., description="Score value ('true'/'false' for true_false, '0.0'-'1.0' for float_scale)" ) - score_category: Optional[list[str]] = Field(None, description="Harm categories (e.g., ['hate', 'violence'])") - score_rationale: Optional[str] = Field(None, description="Explanation for the score") + score_category: list[str] | None = Field(None, description="Harm categories (e.g., ['hate', 'violence'])") + score_rationale: str | None = Field(None, description="Explanation for the score") scored_at: datetime = Field(..., description="When the score was generated") @@ -46,24 +46,24 @@ class MessagePiece(BaseModel): converted_value_data_type: str = Field( default="text", description="Data type of the converted value: 'text', 'image', 'audio', etc." 
) - original_value: Optional[str] = Field(default=None, description="Original value before conversion") - original_value_mime_type: Optional[str] = Field(default=None, description="MIME type of original value") + original_value: str | None = Field(default=None, description="Original value before conversion") + original_value_mime_type: str | None = Field(default=None, description="MIME type of original value") converted_value: str = Field(..., description="Converted value (text or base64 for media)") - converted_value_mime_type: Optional[str] = Field(default=None, description="MIME type of converted value") + converted_value_mime_type: str | None = Field(default=None, description="MIME type of converted value") scores: list[Score] = Field(default_factory=list, description="Scores embedded in this piece") response_error: PromptResponseError = Field( default="none", description="Error status: none, processing, blocked, empty, unknown" ) - response_error_description: Optional[str] = Field( + response_error_description: str | None = Field( default=None, description="Description of the error if response_error is not 'none'" ) - original_filename: Optional[str] = Field( + original_filename: str | None = Field( default=None, description="Original filename extracted from file path or blob URL" ) - converted_filename: Optional[str] = Field( + converted_filename: str | None = Field( default=None, description="Converted filename extracted from file path or blob URL" ) - prompt_metadata: Optional[dict[str, Any]] = Field( + prompt_metadata: dict[str, Any] | None = Field( default=None, description="Metadata associated with the piece (e.g., video_id for remix mode)" ) @@ -86,8 +86,8 @@ class TargetInfo(BaseModel): """Target information extracted from the stored TargetIdentifier.""" target_type: str = Field(..., description="Target class name (e.g., 'OpenAIChatTarget')") - endpoint: Optional[str] = Field(None, description="Target endpoint URL") - model_name: Optional[str] = 
Field(None, description="Model or deployment name") + endpoint: str | None = Field(None, description="Target endpoint URL") + model_name: str | None = Field(None, description="Model or deployment name") class RetryEventResponse(BaseModel): @@ -110,20 +110,18 @@ class AttackSummary(BaseModel): attack_result_id: str = Field(..., description="Database-assigned unique ID for this AttackResult") conversation_id: str = Field(..., description="Primary conversation of this attack result") attack_type: str = Field("", description="Attack class name (e.g., 'CrescendoAttack', 'ManualAttack')") - attack_specific_params: Optional[dict[str, Any]] = Field(None, description="Additional attack-specific parameters") - target: Optional[TargetInfo] = Field(None, description="Target information from the stored identifier") + attack_specific_params: dict[str, Any] | None = Field(None, description="Additional attack-specific parameters") + target: TargetInfo | None = Field(None, description="Target information from the stored identifier") converters: list[str] = Field( default_factory=list, description="Request converter class names applied in this attack" ) objective: str = Field("", description="Natural-language description of the attacker's objective") - outcome: Optional[Literal["undetermined", "success", "failure", "error"]] = Field( + outcome: Literal["undetermined", "success", "failure", "error"] | None = Field( None, description="Attack outcome (null if not yet determined)" ) outcome_reason: str | None = Field(None, description="Reason for the outcome") last_response: str | None = Field(None, description="Model response from the final turn") - last_message_preview: Optional[str] = Field( - None, description="Preview of the last message (truncated to ~100 chars)" - ) + last_message_preview: str | None = Field(None, description="Preview of the last message (truncated to ~100 chars)") score_value: str | None = Field(None, description="Score value from the objective scorer") 
executed_turns: int = Field(0, ge=0, description="Number of turns executed") execution_time_ms: int = Field(0, ge=0, description="Execution time in milliseconds") @@ -195,13 +193,13 @@ class MessagePieceRequest(BaseModel): data_type: str = Field(default="text", description="Data type: 'text', 'image', 'audio', etc.") original_value: str = Field(..., description="Original value (text or base64 for media)") - converted_value: Optional[str] = Field(None, description="Converted value. If provided, bypasses converters.") - mime_type: Optional[str] = Field(None, description="MIME type for media content") - prompt_metadata: Optional[dict[str, Any]] = Field( + converted_value: str | None = Field(None, description="Converted value. If provided, bypasses converters.") + mime_type: str | None = Field(None, description="MIME type for media content") + prompt_metadata: dict[str, Any] | None = Field( None, description="Metadata to attach to the piece (e.g., {'video_id': '...'} for remix mode).", ) - original_prompt_id: Optional[str] = Field( + original_prompt_id: str | None = Field( None, description="ID of the source piece when prepending from an existing conversation. " "Preserves lineage so the new piece traces back to the original.", @@ -231,18 +229,16 @@ class CreateAttackRequest(BaseModel): supplied in ``labels`` (typically the current operator's labels). 
""" - name: Optional[str] = Field(None, description="Attack name/label") + name: str | None = Field(None, description="Attack name/label") target_registry_name: str = Field(..., description="Target registry name to attack") - source_conversation_id: Optional[str] = Field( + source_conversation_id: str | None = Field( None, description="Conversation to branch from (clone messages into the new attack)" ) - cutoff_index: Optional[int] = Field( - None, description="Include messages up to and including this turn index (0-based)" - ) - prepended_conversation: Optional[list[PrependedMessageRequest]] = Field( + cutoff_index: int | None = Field(None, description="Include messages up to and including this turn index (0-based)") + prepended_conversation: list[PrependedMessageRequest] | None = Field( None, description="Messages to prepend (system prompts, branching context)", max_length=200 ) - labels: Optional[dict[str, str]] = Field(None, description="User-defined labels for filtering") + labels: dict[str, str] | None = Field(None, description="User-defined labels for filtering") class CreateAttackResponse(BaseModel): @@ -274,8 +270,8 @@ class ConversationSummary(BaseModel): conversation_id: str = Field(..., description="Unique conversation identifier") message_count: int = Field(0, description="Number of messages in this conversation") - last_message_preview: Optional[str] = Field(None, description="Preview of the last message") - created_at: Optional[datetime] = Field(None, description="Timestamp of the first message") + last_message_preview: str | None = Field(None, description="Preview of the last message") + created_at: datetime | None = Field(None, description="Timestamp of the first message") class AttackConversationsResponse(BaseModel): @@ -297,10 +293,8 @@ class CreateConversationRequest(BaseModel): the cutoff turn, preserving tracking relationships (original_prompt_id). 
""" - source_conversation_id: Optional[str] = Field(None, description="Conversation to branch from") - cutoff_index: Optional[int] = Field( - None, description="Include messages up to and including this turn index (0-based)" - ) + source_conversation_id: str | None = Field(None, description="Conversation to branch from") + cutoff_index: int | None = Field(None, description="Include messages up to and including this turn index (0-based)") class CreateConversationResponse(BaseModel): @@ -344,11 +338,11 @@ class AddMessageRequest(BaseModel): default=True, description="If True, send to target and wait for response. If False, just store in memory.", ) - target_registry_name: Optional[str] = Field( + target_registry_name: str | None = Field( None, description="Target registry name. Required when send=True so the backend knows which target to use.", ) - converter_ids: Optional[list[str]] = Field( + converter_ids: list[str] | None = Field( None, description="Converter instance IDs to apply (overrides attack-level)" ) target_conversation_id: str = Field( @@ -356,7 +350,7 @@ class AddMessageRequest(BaseModel): description="The conversation_id to store and send messages under. " "Usually the attack's main conversation, but can be a related conversation.", ) - labels: Optional[dict[str, str]] = Field( + labels: dict[str, str] | None = Field( None, description="Labels to attach to every message piece. " "Falls back to labels from existing pieces in the conversation.", diff --git a/pyrit/backend/models/common.py b/pyrit/backend/models/common.py index 0a2e00e6b5..36767467cc 100644 --- a/pyrit/backend/models/common.py +++ b/pyrit/backend/models/common.py @@ -7,7 +7,7 @@ Includes pagination, error handling (RFC 7807), and shared base models. 
""" -from typing import Any, Optional +from typing import Any from pydantic import BaseModel, Field @@ -17,8 +17,8 @@ class PaginationInfo(BaseModel): limit: int = Field(..., description="Maximum items per page") has_more: bool = Field(..., description="Whether more items exist") - next_cursor: Optional[str] = Field(None, description="Cursor for next page") - prev_cursor: Optional[str] = Field(None, description="Cursor for previous page") + next_cursor: str | None = Field(None, description="Cursor for next page") + prev_cursor: str | None = Field(None, description="Cursor for previous page") class FieldError(BaseModel): @@ -26,8 +26,8 @@ class FieldError(BaseModel): field: str = Field(..., description="Field name with path (e.g., 'pieces[0].data_type')") message: str = Field(..., description="Error message") - code: Optional[str] = Field(None, description="Error code") - value: Optional[Any] = Field(None, description="The invalid value") + code: str | None = Field(None, description="Error code") + value: Any | None = Field(None, description="The invalid value") class ProblemDetail(BaseModel): @@ -41,8 +41,8 @@ class ProblemDetail(BaseModel): title: str = Field(..., description="Short human-readable summary") status: int = Field(..., description="HTTP status code") detail: str = Field(..., description="Human-readable explanation") - instance: Optional[str] = Field(None, description="URI of the specific occurrence") - errors: Optional[list[FieldError]] = Field(None, description="Field-level errors for validation") + instance: str | None = Field(None, description="URI of the specific occurrence") + errors: list[FieldError] | None = Field(None, description="Field-level errors for validation") # Sensitive field patterns to filter from identifiers diff --git a/pyrit/backend/models/converters.py b/pyrit/backend/models/converters.py index ba5ca5390d..dd216b84b3 100644 --- a/pyrit/backend/models/converters.py +++ b/pyrit/backend/models/converters.py @@ -7,7 +7,7 @@ This 
module defines the Instance models and preview functionality. """ -from typing import Any, Optional +from typing import Any from pydantic import BaseModel, Field @@ -38,9 +38,9 @@ class ConverterParameterSchema(BaseModel): name: str = Field(..., description="Parameter name") type_name: str = Field(..., description="Human-readable type (e.g. 'str', 'int', 'Literal[...]')") required: bool = Field(..., description="Whether the parameter must be provided") - default_value: Optional[str] = Field(None, description="String representation of default value, if any") - choices: Optional[list[str]] = Field(None, description="Allowed values for Literal types") - description: Optional[str] = Field(None, description="Parameter description from docstring") + default_value: str | None = Field(None, description="String representation of default value, if any") + choices: list[str] | None = Field(None, description="Allowed values for Literal types") + description: str | None = Field(None, description="Parameter description from docstring") class ConverterCatalogEntry(BaseModel): @@ -57,7 +57,7 @@ class ConverterCatalogEntry(BaseModel): default_factory=list, description="Constructor parameters for dynamic form generation" ) is_llm_based: bool = Field(False, description="Whether this converter requires an LLM target") - description: Optional[str] = Field(None, description="Short description of the converter from its docstring") + description: str | None = Field(None, description="Short description of the converter from its docstring") class ConverterCatalogResponse(BaseModel): @@ -76,17 +76,17 @@ class ConverterInstance(BaseModel): converter_id: str = Field(..., description="Unique converter instance identifier") converter_type: str = Field(..., description="Converter class name (e.g., 'Base64Converter')") - display_name: Optional[str] = Field(None, description="Human-readable display name") + display_name: str | None = Field(None, description="Human-readable display name") 
supported_input_types: list[str] = Field( default_factory=list, description="Input data types supported by this converter" ) supported_output_types: list[str] = Field( default_factory=list, description="Output data types produced by this converter" ) - converter_specific_params: Optional[dict[str, Any]] = Field( + converter_specific_params: dict[str, Any] | None = Field( None, description="Additional converter-specific parameters" ) - sub_converter_ids: Optional[list[str]] = Field( + sub_converter_ids: list[str] | None = Field( None, description="Converter IDs of sub-converters (for pipelines/composites)" ) @@ -101,7 +101,7 @@ class CreateConverterRequest(BaseModel): """Request to create a new converter instance.""" type: str = Field(..., description="Converter type (e.g., 'Base64Converter')") - display_name: Optional[str] = Field(None, description="Human-readable display name") + display_name: str | None = Field(None, description="Human-readable display name") params: dict[str, Any] = Field( default_factory=dict, description="Converter constructor parameters", @@ -113,7 +113,7 @@ class CreateConverterResponse(BaseModel): converter_id: str = Field(..., description="Unique converter instance identifier") converter_type: str = Field(..., description="Converter class name") - display_name: Optional[str] = Field(None, description="Human-readable display name") + display_name: str | None = Field(None, description="Human-readable display name") # ============================================================================ diff --git a/pyrit/backend/models/scenarios.py b/pyrit/backend/models/scenarios.py index 54480b76c7..aaac688cf0 100644 --- a/pyrit/backend/models/scenarios.py +++ b/pyrit/backend/models/scenarios.py @@ -10,7 +10,7 @@ from datetime import datetime from enum import Enum -from typing import Any, Optional +from typing import Any from pydantic import BaseModel, Field @@ -40,7 +40,7 @@ class RegisteredScenario(BaseModel): ) all_strategies: list[str] = 
Field(..., description="All available concrete strategy names") default_datasets: list[str] = Field(..., description="Default dataset names used by the scenario") - max_dataset_size: Optional[int] = Field(None, description="Maximum items per dataset (None means unlimited)") + max_dataset_size: int | None = Field(None, description="Maximum items per dataset (None means unlimited)") supported_parameters: list[ScenarioParameterSummary] = Field( default_factory=list, description="Scenario-declared custom parameters" ) diff --git a/pyrit/backend/models/targets.py b/pyrit/backend/models/targets.py index da512da155..944fbc358f 100644 --- a/pyrit/backend/models/targets.py +++ b/pyrit/backend/models/targets.py @@ -11,7 +11,7 @@ This module defines the Instance models for runtime target management. """ -from typing import Any, Literal, Optional +from typing import Any, Literal from pydantic import BaseModel, Field @@ -56,16 +56,14 @@ class TargetInstance(BaseModel): target_registry_name: str = Field(..., description="Target registry key (e.g., 'azure_openai_chat')") target_type: str = Field(..., description="Target class name (e.g., 'OpenAIChatTarget')") - endpoint: Optional[str] = Field(None, description="Target endpoint URL") - model_name: Optional[str] = Field(None, description="Model or deployment name used in API calls") - underlying_model_name: Optional[str] = Field( - None, description="Underlying model name if different (e.g., 'gpt-4o')" - ) - temperature: Optional[float] = Field(None, description="Temperature parameter for generation") - top_p: Optional[float] = Field(None, description="Top-p parameter for generation") - max_requests_per_minute: Optional[int] = Field(None, description="Maximum requests per minute") + endpoint: str | None = Field(None, description="Target endpoint URL") + model_name: str | None = Field(None, description="Model or deployment name used in API calls") + underlying_model_name: str | None = Field(None, description="Underlying model name 
if different (e.g., 'gpt-4o')") + temperature: float | None = Field(None, description="Temperature parameter for generation") + top_p: float | None = Field(None, description="Top-p parameter for generation") + max_requests_per_minute: int | None = Field(None, description="Maximum requests per minute") capabilities: TargetCapabilitiesInfo = Field(..., description="Structured snapshot of target capabilities") - target_specific_params: Optional[dict[str, Any]] = Field(None, description="Additional target-specific parameters") + target_specific_params: dict[str, Any] | None = Field(None, description="Additional target-specific parameters") class TargetListResponse(BaseModel): diff --git a/pyrit/backend/pyrit_backend.py b/pyrit/backend/pyrit_backend.py index 0770ae501a..6c792f383d 100644 --- a/pyrit/backend/pyrit_backend.py +++ b/pyrit/backend/pyrit_backend.py @@ -17,12 +17,11 @@ import sys from argparse import ArgumentParser, Namespace, RawDescriptionHelpFormatter from pathlib import Path -from typing import Optional from pyrit.common.cli_helpers import CONFIG_FILE_HELP, validate_log_level_argparse -def parse_args(*, args: Optional[list[str]] = None) -> Namespace: +def parse_args(*, args: list[str] | None = None) -> Namespace: """ Parse command-line arguments for the PyRIT backend server. @@ -88,7 +87,7 @@ def parse_args(*, args: Optional[list[str]] = None) -> Namespace: return parser.parse_args(args) -def main(*, args: Optional[list[str]] = None) -> int: +def main(*, args: list[str] | None = None) -> int: """ Start the PyRIT backend server. 
diff --git a/pyrit/backend/routes/attacks.py b/pyrit/backend/routes/attacks.py index d6844d5041..7f41ec4339 100644 --- a/pyrit/backend/routes/attacks.py +++ b/pyrit/backend/routes/attacks.py @@ -10,7 +10,7 @@ import logging from collections.abc import Sequence -from typing import Literal, Optional +from typing import Literal from fastapi import APIRouter, HTTPException, Query, status @@ -39,7 +39,7 @@ router = APIRouter(prefix="/attacks", tags=["attacks"]) -def _parse_labels(label_params: Optional[list[str]]) -> Optional[dict[str, str | Sequence[str]]]: +def _parse_labels(label_params: list[str] | None) -> dict[str, str | Sequence[str]] | None: """ Parse 'key:value' label query params into a dict grouping values by key. @@ -69,13 +69,13 @@ def _parse_labels(label_params: Optional[list[str]]) -> Optional[dict[str, str | response_model=AttackListResponse, ) async def list_attacks( # pyrit-async-suffix-exempt - attack_types: Optional[list[str]] = Query( + attack_types: list[str] | None = Query( None, description="Filter by attack type names. May be specified multiple times to OR-match " "across types (e.g. ?attack_types=A&attack_types=B). Case-insensitive. " "Omit to return all attacks regardless of type.", ), - converter_types: Optional[list[str]] = Query( + converter_types: list[str] | None = Query( None, description="Filter by converter type names. May be specified multiple times; " "combination semantics are controlled by converter_types_match " @@ -88,24 +88,24 @@ async def list_attacks( # pyrit-async-suffix-exempt description="How to combine multiple converter_types: 'any' (attack has at least one) " "or 'all' (attack has every one). Defaults to 'all'.", ), - has_converters: Optional[bool] = Query( + has_converters: bool | None = Query( None, description="Filter by converter presence. true = attacks with at least one converter; " "false = attacks with no converters. 
Omit for no filter.", ), - outcome: Optional[Literal["undetermined", "success", "failure", "error"]] = Query( + outcome: Literal["undetermined", "success", "failure", "error"] | None = Query( None, description="Filter by outcome" ), - label: Optional[list[str]] = Query( + label: list[str] | None = Query( None, description="Filter by labels (format: key:value). May be specified multiple times; " "OR-matched within a key, AND-matched across keys " "(e.g. ?label=op:red&label=op:blue matches op=red OR op=blue).", ), - min_turns: Optional[int] = Query(None, ge=0, description="Filter by minimum executed turns"), - max_turns: Optional[int] = Query(None, ge=0, description="Filter by maximum executed turns"), + min_turns: int | None = Query(None, ge=0, description="Filter by minimum executed turns"), + max_turns: int | None = Query(None, ge=0, description="Filter by maximum executed turns"), limit: int = Query(20, ge=1, le=100, description="Maximum items per page"), - cursor: Optional[str] = Query( + cursor: str | None = Query( None, description="Pagination cursor: the attack_result_id of the last item from the previous page. " "Omit to start from the beginning. 
The response includes next_cursor for the next page.", diff --git a/pyrit/backend/routes/scenarios.py b/pyrit/backend/routes/scenarios.py index d75857210b..941d8021fb 100644 --- a/pyrit/backend/routes/scenarios.py +++ b/pyrit/backend/routes/scenarios.py @@ -12,8 +12,6 @@ /api/scenarios/runs — scenario execution lifecycle """ -from typing import Optional - from fastapi import APIRouter, HTTPException, Query, status from pyrit.backend.models.common import ProblemDetail @@ -41,7 +39,7 @@ ) async def list_scenarios( # pyrit-async-suffix-exempt limit: int = Query(50, ge=1, le=200, description="Maximum items per page"), - cursor: Optional[str] = Query(None, description="Pagination cursor (scenario_name to start after)"), + cursor: str | None = Query(None, description="Pagination cursor (scenario_name to start after)"), ) -> ListRegisteredScenariosResponse: """ List all available scenarios. diff --git a/pyrit/backend/routes/targets.py b/pyrit/backend/routes/targets.py index 5a05ea41fd..bea53ddef2 100644 --- a/pyrit/backend/routes/targets.py +++ b/pyrit/backend/routes/targets.py @@ -8,8 +8,6 @@ Target types are set at app startup via initializers - you cannot add new types at runtime. """ -from typing import Optional - from fastapi import APIRouter, HTTPException, Query, status from pyrit.backend.models.common import ProblemDetail @@ -32,7 +30,7 @@ ) async def list_targets( # pyrit-async-suffix-exempt limit: int = Query(50, ge=1, le=200, description="Maximum items per page"), - cursor: Optional[str] = Query(None, description="Pagination cursor (target_registry_name)"), + cursor: str | None = Query(None, description="Pagination cursor (target_registry_name)"), ) -> TargetListResponse: """ List target instances with pagination. 
diff --git a/pyrit/backend/routes/version.py b/pyrit/backend/routes/version.py index b59d176158..2d75c200aa 100644 --- a/pyrit/backend/routes/version.py +++ b/pyrit/backend/routes/version.py @@ -6,7 +6,6 @@ import json import logging from pathlib import Path -from typing import Optional from fastapi import APIRouter, Request from pydantic import BaseModel @@ -23,12 +22,12 @@ class VersionResponse(BaseModel): """Version information response model.""" version: str - source: Optional[str] = None - commit: Optional[str] = None - modified: Optional[bool] = None + source: str | None = None + commit: str | None = None + modified: bool | None = None display: str - database_info: Optional[str] = None - default_labels: Optional[dict[str, str]] = None + database_info: str | None = None + default_labels: dict[str, str] | None = None @router.get("", response_model=VersionResponse) @@ -62,7 +61,7 @@ async def get_version_async(request: Request) -> VersionResponse: logger.warning(f"Failed to load build info: {e}") # Detect current database backend - database_info: Optional[str] = None + database_info: str | None = None try: memory = CentralMemory.get_memory_instance() db_type = type(memory).__name__ @@ -74,7 +73,7 @@ async def get_version_async(request: Request) -> VersionResponse: logger.debug(f"Could not detect database info: {e}") # Read default labels from app state (set by pyrit_backend CLI) - default_labels: Optional[dict[str, str]] = getattr(request.app.state, "default_labels", None) or None + default_labels: dict[str, str] | None = getattr(request.app.state, "default_labels", None) or None return VersionResponse( version=version, diff --git a/pyrit/backend/services/converter_service.py b/pyrit/backend/services/converter_service.py index 66cb8bdc31..4f09248aaa 100644 --- a/pyrit/backend/services/converter_service.py +++ b/pyrit/backend/services/converter_service.py @@ -501,7 +501,7 @@ def _coerce_params(*, converter_class: type, params: dict[str, Any]) -> dict[str continue 
origin = get_origin(annotation) - # Unwrap Optional[X] to X + # Unwrap X | None to X if origin is Union: args = get_args(annotation) non_none = [a for a in args if a is not type(None)] diff --git a/pyrit/cli/_banner.py b/pyrit/cli/_banner.py index 0c8d4719eb..21e1f1c5bd 100644 --- a/pyrit/cli/_banner.py +++ b/pyrit/cli/_banner.py @@ -21,7 +21,6 @@ import time from dataclasses import dataclass, field from enum import Enum -from typing import Optional from pyrit.cli._banner_assets import BRAILLE_RACCOON, PYRIT_LETTERS, PYRIT_WIDTH, RACCOON_TAIL @@ -199,7 +198,7 @@ def _build_static_banner() -> StaticBannerData: color_map: dict[int, ColorRole] = {} segment_colors: dict[int, list[tuple[int, int, ColorRole]]] = {} - def add(line: str, role: ColorRole, segments: Optional[list[tuple[int, int, ColorRole]]] = None) -> None: + def add(line: str, role: ColorRole, segments: list[tuple[int, int, ColorRole]] | None = None) -> None: idx = len(lines) color_map[idx] = role if segments: @@ -559,14 +558,14 @@ def _render_line_with_segments( """ reset = _get_color(ColorRole.RESET, theme) # Build per-character color map (later segments override earlier ones) - char_roles: list[Optional[ColorRole]] = [None] * len(line) + char_roles: list[ColorRole | None] = [None] * len(line) for start, end, role in segments: for pos in range(start, min(end, len(line))): char_roles[pos] = role # Group consecutive same-role characters for efficient rendering result: list[str] = [] - current_role: Optional[ColorRole] = None + current_role: ColorRole | None = None for pos, ch in enumerate(line): char_role = char_roles[pos] if char_role != current_role: diff --git a/pyrit/cli/_cli_args.py b/pyrit/cli/_cli_args.py index e982b6ae06..eddad0f0f3 100644 --- a/pyrit/cli/_cli_args.py +++ b/pyrit/cli/_cli_args.py @@ -21,7 +21,7 @@ import logging import shlex from pathlib import Path -from typing import TYPE_CHECKING, Any, Optional, get_origin +from typing import TYPE_CHECKING, Any, get_origin from 
pyrit.common.cli_helpers import ( CONFIG_FILE_HELP, @@ -67,7 +67,7 @@ def validate_database(*, database: str) -> str: return database -def validate_integer(value: str, *, name: str = "value", min_value: Optional[int] = None) -> int: +def validate_integer(value: str, *, name: str = "value", min_value: int | None = None) -> int: """ Validate and parse an integer value. @@ -492,7 +492,7 @@ def _parse_shell_arguments(*, parts: list[str], arg_specs: list[_ArgSpec]) -> di return result -def parse_run_arguments(*, args_string: str, declared_params: Optional[list[Parameter]] = None) -> dict[str, Any]: +def parse_run_arguments(*, args_string: str, declared_params: list[Parameter] | None = None) -> dict[str, Any]: """ Parse run command arguments from a string (for shell mode). @@ -535,8 +535,8 @@ def parse_list_targets_arguments(*, args_string: str) -> dict[str, Any]: Returns: Dictionary with parsed arguments: - - initializers: Optional[list[str | dict[str, Any]]] - - initialization_scripts: Optional[list[str]] + - initializers: list[str | dict[str, Any]] | None + - initialization_scripts: list[str] | None Raises: ValueError: If parsing or validation fails. @@ -643,7 +643,7 @@ def extract_scenario_args(*, parsed: dict[str, Any]) -> dict[str, Any]: # --------------------------------------------------------------------------- -def build_parameters_from_api(*, api_params: list[dict[str, Any]]) -> Optional[list[Parameter]]: +def build_parameters_from_api(*, api_params: list[dict[str, Any]]) -> list[Parameter] | None: """ Build ``Parameter`` objects from a scenario catalog's ``supported_parameters``. @@ -655,7 +655,7 @@ def build_parameters_from_api(*, api_params: list[dict[str, Any]]) -> Optional[l api_params: List of parameter dicts from ``GET /api/scenarios/catalog/{name}``. Returns: - Optional[list[Parameter]]: Parameter list when ``api_params`` is non-empty, else ``None``. + list[Parameter] | None: Parameter list when ``api_params`` is non-empty, else ``None``. 
""" if not api_params: return None @@ -669,7 +669,7 @@ def build_parameters_from_api(*, api_params: list[dict[str, Any]]) -> Optional[l else: resolved_type = type_map.get(type_display) raw_choices = p.get("choices") - choices: Optional[tuple[Any, ...]] = tuple(raw_choices) if raw_choices else None + choices: tuple[Any, ...] | None = tuple(raw_choices) if raw_choices else None parameters.append( Parameter( name=p["name"], @@ -699,7 +699,7 @@ def add_common_arguments(parser: argparse.ArgumentParser) -> None: def merge_config_scenario_args( *, - config_scenario: Optional[ScenarioConfig], + config_scenario: ScenarioConfig | None, effective_scenario_name: str, cli_args: dict[str, Any], ) -> dict[str, Any]: @@ -711,7 +711,7 @@ def merge_config_scenario_args( Mutable values are deep-copied so they don't leak across runs. Args: - config_scenario (Optional[ScenarioConfig]): The ``scenario:`` block from + config_scenario (ScenarioConfig | None): The ``scenario:`` block from the layered config, or ``None`` when not configured. effective_scenario_name (str): The scenario about to run (CLI wins). cli_args (dict[str, Any]): Scenario args supplied on the CLI. diff --git a/pyrit/cli/pyrit_scan.py b/pyrit/cli/pyrit_scan.py index 57ee0a3326..1e4467f929 100644 --- a/pyrit/cli/pyrit_scan.py +++ b/pyrit/cli/pyrit_scan.py @@ -17,7 +17,7 @@ import sys from argparse import ArgumentParser, Namespace, RawDescriptionHelpFormatter from pathlib import Path -from typing import Any, Optional +from typing import Any from pyrit.cli._cli_args import ( ARG_HELP, @@ -332,7 +332,7 @@ def _extract_scenario_args(*, parsed: Namespace) -> dict[str, Any]: } -def parse_args(args: Optional[list[str]] = None) -> Namespace: +def parse_args(args: list[str] | None = None) -> Namespace: """ Parse command-line arguments (pass 1 — tolerant of scenario-declared flags). 
@@ -760,7 +760,7 @@ async def _run_async(*, parsed_args: Namespace) -> int: return 1 -def main(args: Optional[list[str]] = None) -> int: +def main(args: list[str] | None = None) -> int: """ Start the PyRIT scanner CLI. diff --git a/pyrit/cli/pyrit_shell.py b/pyrit/cli/pyrit_shell.py index 78dbed1aeb..1a0760eb7c 100644 --- a/pyrit/cli/pyrit_shell.py +++ b/pyrit/cli/pyrit_shell.py @@ -18,7 +18,7 @@ import sys import threading from pathlib import Path -from typing import TYPE_CHECKING, Any, Optional, TypeVar +from typing import TYPE_CHECKING, Any, TypeVar from pyrit.cli import _banner as banner @@ -173,7 +173,7 @@ def _ensure_client(self) -> bool: self._start_server = False # only auto-start once return True - def cmdloop(self, intro: Optional[str] = None) -> None: + def cmdloop(self, intro: str | None = None) -> None: """Override cmdloop to play animated banner before starting the REPL.""" if intro is None: prev_disable = logging.root.manager.disable diff --git a/pyrit/common/csv_helper.py b/pyrit/common/csv_helper.py index 2fb831149e..51633c123b 100644 --- a/pyrit/common/csv_helper.py +++ b/pyrit/common/csv_helper.py @@ -10,7 +10,7 @@ def read_csv(file: IO[Any]) -> list[dict[str, str]]: Read a CSV file and return its rows as dictionaries. Returns: - List[Dict[str, str]]: Parsed CSV rows as dictionaries. + list[dict[str, str]]: Parsed CSV rows as dictionaries. """ reader = csv.DictReader(file) return list(reader) @@ -22,7 +22,7 @@ def write_csv(file: IO[Any], examples: list[dict[str, str]]) -> None: Args: file: A file-like object opened for writing CSV data. - examples (List[Dict[str, str]]): List of dictionaries to write as CSV rows. + examples (list[dict[str, str]]): List of dictionaries to write as CSV rows. 
""" if not examples: return diff --git a/pyrit/common/default_values.py b/pyrit/common/default_values.py index 9dbcba427f..4334cfb5d2 100644 --- a/pyrit/common/default_values.py +++ b/pyrit/common/default_values.py @@ -3,7 +3,7 @@ import logging import os -from typing import Any, Optional +from typing import Any logger = logging.getLogger(__name__) @@ -38,7 +38,7 @@ def get_required_value(*, env_var_name: str, passed_value: Any) -> Any: raise ValueError(f"Environment variable {env_var_name} is required") -def get_non_required_value(*, env_var_name: str, passed_value: Optional[str] = None) -> str: +def get_non_required_value(*, env_var_name: str, passed_value: str | None = None) -> str: """ Get a non-required value from an environment variable or a passed value, preferring the passed value. diff --git a/pyrit/common/json_helper.py b/pyrit/common/json_helper.py index cbb0575e5a..fdee16bd3b 100644 --- a/pyrit/common/json_helper.py +++ b/pyrit/common/json_helper.py @@ -10,7 +10,7 @@ def read_json(file: IO[Any]) -> list[dict[str, str]]: Read a JSON file and return its content. Returns: - List[Dict[str, str]]: Parsed JSON content. + list[dict[str, str]]: Parsed JSON content. """ return cast("list[dict[str, str]]", json.load(file)) @@ -21,7 +21,7 @@ def write_json(file: IO[Any], examples: list[dict[str, str]]) -> None: Args: file: A file-like object opened for writing JSON data. - examples (List[Dict[str, str]]): List of dictionaries to write as JSON. + examples (list[dict[str, str]]): List of dictionaries to write as JSON. """ json.dump(examples, file) @@ -31,7 +31,7 @@ def read_jsonl(file: IO[Any]) -> list[dict[str, str]]: Read a JSONL file and return its content. Returns: - List[Dict[str, str]]: Parsed JSONL content. + list[dict[str, str]]: Parsed JSONL content. 
""" return [json.loads(line) for line in file if line.strip()] @@ -42,7 +42,7 @@ def write_jsonl(file: IO[Any], examples: list[dict[str, str]]) -> None: Args: file: A file-like object opened for writing JSONL data. - examples (List[Dict[str, str]]): List of dictionaries to write as JSONL. + examples (list[dict[str, str]]): List of dictionaries to write as JSONL. """ for example in examples: file.write(json.dumps(example) + "\n") diff --git a/pyrit/common/net_utility.py b/pyrit/common/net_utility.py index eb75f5616e..17d944e1e2 100644 --- a/pyrit/common/net_utility.py +++ b/pyrit/common/net_utility.py @@ -1,7 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -from typing import Any, Literal, Optional, cast, overload +from typing import Any, Literal, cast, overload from urllib.parse import parse_qs, urlparse, urlunparse import httpx @@ -88,10 +88,10 @@ async def make_request_and_raise_if_error_async( method: str, post_type: PostType = "json", debug: bool = False, - extra_url_parameters: Optional[dict[str, str]] = None, - request_body: Optional[dict[str, object]] = None, - files: Optional[dict[str, tuple[str, bytes, str]]] = None, - headers: Optional[dict[str, str]] = None, + extra_url_parameters: dict[str, str] | None = None, + request_body: dict[str, object] | None = None, + files: dict[str, tuple[str, bytes, str]] | None = None, + headers: dict[str, str] | None = None, **httpx_client_kwargs: Any, ) -> httpx.Response: """ diff --git a/pyrit/common/text_helper.py b/pyrit/common/text_helper.py index 849b7bf0f7..e37fe136f9 100644 --- a/pyrit/common/text_helper.py +++ b/pyrit/common/text_helper.py @@ -9,7 +9,7 @@ def read_txt(file: IO[Any]) -> list[dict[str, str]]: Read a TXT file and return its content. Returns: - List[Dict[str, str]]: Parsed TXT content. + list[dict[str, str]]: Parsed TXT content. 
""" return [{"prompt": line.strip()} for line in file.readlines() if line.strip()] @@ -20,6 +20,6 @@ def write_txt(file: IO[Any], examples: list[dict[str, str]]) -> None: Args: file: A file-like object opened for writing TXT data. - examples (List[Dict[str, str]]): List of dictionaries to write as TXT. + examples (list[dict[str, str]]): List of dictionaries to write as TXT. """ file.write("\n".join([ex["prompt"] for ex in examples])) diff --git a/pyrit/common/utils.py b/pyrit/common/utils.py index a7203ca336..7f2119e29f 100644 --- a/pyrit/common/utils.py +++ b/pyrit/common/utils.py @@ -8,12 +8,12 @@ import math import random from pathlib import Path -from typing import Any, Optional, TypeVar, Union +from typing import Any, TypeVar logger = logging.getLogger(__name__) -def verify_and_resolve_path(path: Union[str, Path]) -> Path: +def verify_and_resolve_path(path: str | Path) -> Path: """ Verify that a path is valid and resolve it to an absolute path. @@ -21,7 +21,7 @@ def verify_and_resolve_path(path: Union[str, Path]) -> Path: such as in scorers, converters, or other components that accept file paths. Args: - path (Union[str, Path]): A path as a string or Path object. + path (str | Path): A path as a string or Path object. Returns: Path: The resolved absolute Path object. @@ -39,9 +39,7 @@ def verify_and_resolve_path(path: Union[str, Path]) -> Path: return path_obj -def combine_dict( - existing_dict: Optional[dict[str, Any]] = None, new_dict: Optional[dict[str, Any]] = None -) -> dict[str, Any]: +def combine_dict(existing_dict: dict[str, Any] | None = None, new_dict: dict[str, Any] | None = None) -> dict[str, Any]: """ Combine two dictionaries containing string keys and values into one. 
@@ -58,13 +56,13 @@ def combine_dict( return result -def combine_list(list1: Union[str, list[str]], list2: Union[str, list[str]]) -> list[str]: +def combine_list(list1: str | list[str], list2: str | list[str]) -> list[str]: """ Combine two lists or strings into a single list with unique values. Args: - list1 (Union[str, List[str]]): First list or string to combine. - list2 (Union[str, List[str]]): Second list or string to combine. + list1 (str | list[str]): First list or string to combine. + list2 (str | list[str]): Second list or string to combine. Returns: list: Combined list containing unique values from both inputs. @@ -91,7 +89,7 @@ def get_random_indices(*, start: int, size: int, proportion: float) -> list[int] For example, if `proportion` is 0.5 and `size` is 10, 5 randomly selected indices will be returned. Returns: - List[int]: A list of randomly selected indices based on the specified proportion. + list[int]: A list of randomly selected indices based on the specified proportion. Raises: ValueError: If `start` is negative, `size` is not positive, or `proportion` is not between 0 and 1. @@ -126,7 +124,7 @@ def to_sha256(data: str) -> str: def warn_if_set( - *, config: Any, unused_fields: list[str], log: Union[logging.Logger, logging.LoggerAdapter[logging.Logger]] = logger + *, config: Any, unused_fields: list[str], log: logging.Logger | logging.LoggerAdapter[logging.Logger] = logger ) -> None: """ Warn about unused parameters in configurations. @@ -137,8 +135,8 @@ def warn_if_set( Args: config (Any): The configuration object to check for unused fields. - unused_fields (List[str]): List of field names to check in the config object. - log (Union[logging.Logger, logging.LoggerAdapter]): Logger to use for warning messages. + unused_fields (list[str]): List of field names to check in the config object. + log (logging.Logger | logging.LoggerAdapter): Logger to use for warning messages. 
""" config_name = config.__class__.__name__ @@ -169,20 +167,20 @@ def get_kwarg_param( param_name: str, expected_type: type[_T], required: bool = True, - default_value: Optional[_T] = None, -) -> Optional[_T]: + default_value: _T | None = None, +) -> _T | None: """ Validate and extract a parameter from kwargs. Args: - kwargs (Dict[str, Any]): The dictionary containing parameters. + kwargs (dict[str, Any]): The dictionary containing parameters. param_name (str): The name of the parameter to validate. - expected_type (Type[_T]): The expected type of the parameter. + expected_type (type[_T]): The expected type of the parameter. required (bool): Whether the parameter is required. If True, raises ValueError if missing. - default_value (Optional[_T]): Default value to return if the parameter is not required and not present. + default_value (_T | None): Default value to return if the parameter is not required and not present. Returns: - Optional[_T]: The validated parameter value if present and valid, otherwise None. + _T | None: The validated parameter value if present and valid, otherwise None. Raises: ValueError: If the parameter is missing or None. diff --git a/pyrit/common/yaml_loadable.py b/pyrit/common/yaml_loadable.py index a7857b7ad7..2fb4422c73 100644 --- a/pyrit/common/yaml_loadable.py +++ b/pyrit/common/yaml_loadable.py @@ -3,7 +3,7 @@ import abc from pathlib import Path -from typing import TypeVar, Union +from typing import TypeVar import yaml @@ -18,7 +18,7 @@ class YamlLoadable(abc.ABC): # noqa: B024 """ @classmethod - def from_yaml_file(cls: type[T], file: Union[Path | str]) -> T: + def from_yaml_file(cls: type[T], file: Path | str) -> T: """ Create a new object from a YAML file. 
diff --git a/pyrit/datasets/executors/question_answer/wmdp_dataset.py b/pyrit/datasets/executors/question_answer/wmdp_dataset.py index 1270c9b6c0..81f4711747 100644 --- a/pyrit/datasets/executors/question_answer/wmdp_dataset.py +++ b/pyrit/datasets/executors/question_answer/wmdp_dataset.py @@ -1,7 +1,6 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -from typing import Optional from datasets import load_dataset @@ -12,7 +11,7 @@ ) -def fetch_wmdp_dataset(category: Optional[str] = None) -> QuestionAnsweringDataset: +def fetch_wmdp_dataset(category: str | None = None) -> QuestionAnsweringDataset: """ Fetch WMDP examples and create a QuestionAnsweringDataset. diff --git a/pyrit/datasets/jailbreak/text_jailbreak.py b/pyrit/datasets/jailbreak/text_jailbreak.py index 6e5083bd42..b4affc3d51 100644 --- a/pyrit/datasets/jailbreak/text_jailbreak.py +++ b/pyrit/datasets/jailbreak/text_jailbreak.py @@ -5,7 +5,7 @@ import random import threading from pathlib import Path -from typing import Any, Optional +from typing import Any from pyrit.common.path import JAILBREAK_TEMPLATES_PATH from pyrit.models import SeedPrompt @@ -18,7 +18,7 @@ class TextJailBreak: A class that manages jailbreak datasets (like DAN, etc.). """ - _template_cache: Optional[dict[str, list[Path]]] = None + _template_cache: dict[str, list[Path]] | None = None _cache_lock: threading.Lock = threading.Lock() @classmethod @@ -30,7 +30,7 @@ def _scan_template_files(cls) -> dict[str, list[Path]]: by filename so duplicate names across subdirectories are detectable. Returns: - Dict[str, List[Path]]: Mapping of filename to list of matching paths. + dict[str, list[Path]]: Mapping of filename to list of matching paths. """ result: dict[str, list[Path]] = {} for path in JAILBREAK_TEMPLATES_PATH.rglob("*.yaml"): @@ -46,7 +46,7 @@ def _get_template_cache(cls) -> dict[str, list[Path]]: Thread-safe: uses a lock to prevent concurrent scans from racing. 
Returns: - Dict[str, List[Path]]: Cached mapping of filename to list of matching paths. + dict[str, list[Path]]: Cached mapping of filename to list of matching paths. """ if cls._template_cache is None: with cls._cache_lock: @@ -85,7 +85,7 @@ def _get_all_template_paths(cls) -> list[Path]: Return a flat list of all cached template file paths. Returns: - List[Path]: All template paths (excluding multi_parameter), in no particular order. + list[Path]: All template paths (excluding multi_parameter), in no particular order. Raises: ValueError: If no templates are available. @@ -99,9 +99,9 @@ def _get_all_template_paths(cls) -> list[Path]: def __init__( self, *, - template_path: Optional[str] = None, - template_file_name: Optional[str] = None, - string_template: Optional[str] = None, + template_path: str | None = None, + template_file_name: str | None = None, + string_template: str | None = None, random_template: bool = False, **kwargs: Any, ) -> None: @@ -208,7 +208,7 @@ def _apply_extra_kwargs(self, kwargs: dict[str, Any]) -> None: self.template.value = self.template.render_template_value_silent(**kwargs) @classmethod - def get_jailbreak_templates(cls, num_templates: Optional[int] = None) -> list[str]: + def get_jailbreak_templates(cls, num_templates: int | None = None) -> list[str]: """ Retrieve all jailbreaks from the JAILBREAK_TEMPLATES_PATH. @@ -216,7 +216,7 @@ def get_jailbreak_templates(cls, num_templates: Optional[int] = None) -> list[st num_templates (int, optional): Number of jailbreak templates to return. None to get all. Returns: - List[str]: List of jailbreak template file names. + list[str]: List of jailbreak template file names. Raises: ValueError: If no jailbreak templates are found in the jailbreak directory. 
diff --git a/pyrit/datasets/seed_datasets/local/local_dataset_loader.py b/pyrit/datasets/seed_datasets/local/local_dataset_loader.py index 18f8343330..c7023e8202 100644 --- a/pyrit/datasets/seed_datasets/local/local_dataset_loader.py +++ b/pyrit/datasets/seed_datasets/local/local_dataset_loader.py @@ -5,7 +5,7 @@ from collections.abc import Callable from dataclasses import fields from pathlib import Path -from typing import Any, Optional +from typing import Any import yaml @@ -76,7 +76,7 @@ async def fetch_dataset_async(self, *, cache: bool = True) -> SeedDataset: logger.error(f"Failed to load local dataset from {self.file_path}: {e}") raise - async def _parse_metadata_async(self) -> Optional[SeedDatasetMetadata]: + async def _parse_metadata_async(self) -> SeedDatasetMetadata | None: """ Extract metadata from a local YAML file and coerce raw values into typed schema fields. @@ -84,7 +84,7 @@ async def _parse_metadata_async(self) -> Optional[SeedDatasetMetadata]: enum and set types expected by SeedDatasetMetadata before _match_filter can work. Returns: - Optional[SeedDatasetMetadata]: Parsed metadata if available, otherwise None. + SeedDatasetMetadata | None: Parsed metadata if available, otherwise None. Raises: Exception: If the dataset file cannot be read. 
diff --git a/pyrit/datasets/seed_datasets/remote/_image_cache.py b/pyrit/datasets/seed_datasets/remote/_image_cache.py index dbc866a47b..bdb502ad23 100644 --- a/pyrit/datasets/seed_datasets/remote/_image_cache.py +++ b/pyrit/datasets/seed_datasets/remote/_image_cache.py @@ -14,7 +14,7 @@ import logging from collections.abc import Mapping from pathlib import Path -from typing import Any, Optional +from typing import Any from pyrit.common.net_utility import make_request_and_raise_if_error_async from pyrit.models import data_serializer_factory @@ -25,11 +25,11 @@ async def fetch_and_cache_image_async( *, filename: str, - image_url: Optional[str] = None, - image_bytes: Optional[bytes] = None, + image_url: str | None = None, + image_bytes: bytes | None = None, log_prefix: str = "image-cache", - request_headers: Optional[Mapping[str, str]] = None, - request_timeout: Optional[float] = None, + request_headers: Mapping[str, str] | None = None, + request_timeout: float | None = None, follow_redirects: bool = False, ) -> str: """ diff --git a/pyrit/datasets/seed_datasets/remote/aegis_ai_content_safety_dataset.py b/pyrit/datasets/seed_datasets/remote/aegis_ai_content_safety_dataset.py index 0f64ef7d29..6cc697a462 100644 --- a/pyrit/datasets/seed_datasets/remote/aegis_ai_content_safety_dataset.py +++ b/pyrit/datasets/seed_datasets/remote/aegis_ai_content_safety_dataset.py @@ -2,7 +2,7 @@ # Licensed under the MIT license. 
import logging -from typing import Literal, Optional +from typing import Literal from datasets import load_dataset from typing_extensions import override @@ -81,35 +81,34 @@ class _AegisContentSafetyDataset(_RemoteDatasetLoader): def __init__( self, *, - harm_categories: Optional[ - list[ - Literal[ - "Controlled/Regulated Substances", - "Copyright/Trademark/Plagiarism", - "Criminal Planning/Confessions", - "Fraud/Deception", - "Guns and Illegal Weapons", - "Harassment", - "Hate/Identity Hate", - "High Risk Gov Decision Making", - "Illegal Activity", - "Immoral/Unethical", - "Malware", - "Manipulation", - "Needs Caution", - "Other", - "PII/Privacy", - "Political/Misinformation/Conspiracy", - "Profanity", - "Sexual", - "Sexual (minor)", - "Suicide and Self Harm", - "Threat", - "Unauthorized Advice", - "Violence", - ] + harm_categories: list[ + Literal[ + "Controlled/Regulated Substances", + "Copyright/Trademark/Plagiarism", + "Criminal Planning/Confessions", + "Fraud/Deception", + "Guns and Illegal Weapons", + "Harassment", + "Hate/Identity Hate", + "High Risk Gov Decision Making", + "Illegal Activity", + "Immoral/Unethical", + "Malware", + "Manipulation", + "Needs Caution", + "Other", + "PII/Privacy", + "Political/Misinformation/Conspiracy", + "Profanity", + "Sexual", + "Sexual (minor)", + "Suicide and Self Harm", + "Threat", + "Unauthorized Advice", + "Violence", ] - ] = None, + ] + | None = None, ) -> None: """ Initialize the NVIDIA Aegis AI Content Safety Dataset loader. 
diff --git a/pyrit/datasets/seed_datasets/remote/agent_threat_rules_dataset.py b/pyrit/datasets/seed_datasets/remote/agent_threat_rules_dataset.py index cea36ef544..7577d154f1 100644 --- a/pyrit/datasets/seed_datasets/remote/agent_threat_rules_dataset.py +++ b/pyrit/datasets/seed_datasets/remote/agent_threat_rules_dataset.py @@ -3,7 +3,7 @@ import logging from enum import Enum -from typing import Literal, Optional +from typing import Literal from typing_extensions import override @@ -136,10 +136,10 @@ def __init__( "db793f9/data/autoresearch/adversarial-samples.json" ), source_type: Literal["public_url", "file"] = "public_url", - categories: Optional[list[ATRCategory]] = None, - techniques: Optional[list[str]] = None, - detection_fields: Optional[list[ATRDetectionField]] = None, - variation_types: Optional[list[ATRVariationType]] = None, + categories: list[ATRCategory] | None = None, + techniques: list[str] | None = None, + detection_fields: list[ATRDetectionField] | None = None, + variation_types: list[ATRVariationType] | None = None, ) -> None: """ Initialize the ATR dataset loader. 
diff --git a/pyrit/datasets/seed_datasets/remote/aya_redteaming_dataset.py b/pyrit/datasets/seed_datasets/remote/aya_redteaming_dataset.py index 34698b4255..d99332116b 100644 --- a/pyrit/datasets/seed_datasets/remote/aya_redteaming_dataset.py +++ b/pyrit/datasets/seed_datasets/remote/aya_redteaming_dataset.py @@ -3,7 +3,7 @@ import ast import logging -from typing import Literal, Optional +from typing import Literal from typing_extensions import override @@ -64,22 +64,21 @@ def __init__( language: Literal[ "English", "Hindi", "French", "Spanish", "Arabic", "Russian", "Serbian", "Tagalog" ] = "English", - harm_categories: Optional[ - list[ - Literal[ - "Bullying & Harassment", - "Discrimination & Injustice", - "Graphic material", - "Harms of Representation Allocation and Quality of Service", - "Hate Speech", - "Non-consensual sexual content", - "Profanity", - "Self-Harm", - "Violence, Threats & Incitement", - ] + harm_categories: list[ + Literal[ + "Bullying & Harassment", + "Discrimination & Injustice", + "Graphic material", + "Harms of Representation Allocation and Quality of Service", + "Hate Speech", + "Non-consensual sexual content", + "Profanity", + "Self-Harm", + "Violence, Threats & Incitement", ] - ] = None, - harm_scope: Optional[Literal["global", "local"]] = None, + ] + | None = None, + harm_scope: Literal["global", "local"] | None = None, ) -> None: """ Initialize the Aya Red-teaming dataset loader. diff --git a/pyrit/datasets/seed_datasets/remote/babelscape_alert_dataset.py b/pyrit/datasets/seed_datasets/remote/babelscape_alert_dataset.py index 73f33d89a6..94cffaccd5 100644 --- a/pyrit/datasets/seed_datasets/remote/babelscape_alert_dataset.py +++ b/pyrit/datasets/seed_datasets/remote/babelscape_alert_dataset.py @@ -2,7 +2,7 @@ # Licensed under the MIT license. 
import logging -from typing import TYPE_CHECKING, Literal, Optional +from typing import TYPE_CHECKING, Literal from typing_extensions import override @@ -58,7 +58,7 @@ def __init__( self, *, source: str = "Babelscape/ALERT", - category: Optional[Literal["alert", "alert_adversarial"]] = "alert_adversarial", + category: Literal["alert", "alert_adversarial"] | None = "alert_adversarial", ) -> None: """ Initialize the Babelscape ALERT dataset loader. diff --git a/pyrit/datasets/seed_datasets/remote/figstep_dataset.py b/pyrit/datasets/seed_datasets/remote/figstep_dataset.py index 5f0505a4b1..3a7e10a34b 100644 --- a/pyrit/datasets/seed_datasets/remote/figstep_dataset.py +++ b/pyrit/datasets/seed_datasets/remote/figstep_dataset.py @@ -9,7 +9,7 @@ import zipfile from enum import Enum from pathlib import Path -from typing import TYPE_CHECKING, Literal, Optional +from typing import TYPE_CHECKING, Literal from typing_extensions import override @@ -172,8 +172,8 @@ def __init__( *, use_tiny: bool = True, variant: FigStepVariant = FigStepVariant.FIGSTEP, - categories: Optional[list[FigStepCategory]] = None, - source: Optional[str] = None, + categories: list[FigStepCategory] | None = None, + source: str | None = None, source_type: Literal["public_url", "file"] = "public_url", ) -> None: """ @@ -248,8 +248,8 @@ async def fetch_dataset_async(self, *, cache: bool = True) -> SeedDataset: required_keys = {"dataset", "category_id", "task_id", "category_name", "question", "instruction"} rows = self._fetch_from_url(source=self.source, source_type=self.source_type, cache=cache) - pro_extract_dir: Optional[Path] = None - pro_benign_sentences: Optional[list[str]] = None + pro_extract_dir: Path | None = None + pro_benign_sentences: list[str] | None = None if self.variant == FigStepVariant.FIGSTEP_PRO: pro_extract_dir, pro_benign_sentences = await self._ensure_figstep_pro_assets_async(cache=cache) diff --git a/pyrit/datasets/seed_datasets/remote/harmbench_multimodal_dataset.py 
b/pyrit/datasets/seed_datasets/remote/harmbench_multimodal_dataset.py index 5728ecf88b..d2e7973399 100644 --- a/pyrit/datasets/seed_datasets/remote/harmbench_multimodal_dataset.py +++ b/pyrit/datasets/seed_datasets/remote/harmbench_multimodal_dataset.py @@ -4,7 +4,7 @@ import logging import uuid from enum import Enum -from typing import Literal, Optional +from typing import Literal from typing_extensions import override @@ -81,7 +81,7 @@ def __init__( "harmbench_behaviors_multimodal_all.csv" ), source_type: Literal["public_url", "file"] = "public_url", - categories: Optional[list[SemanticCategory]] = None, + categories: list[SemanticCategory] | None = None, ) -> None: """ Initialize the HarmBench multimodal dataset loader. diff --git a/pyrit/datasets/seed_datasets/remote/jailbreakv_28k_dataset.py b/pyrit/datasets/seed_datasets/remote/jailbreakv_28k_dataset.py index 5ff18a8953..251cfb5405 100644 --- a/pyrit/datasets/seed_datasets/remote/jailbreakv_28k_dataset.py +++ b/pyrit/datasets/seed_datasets/remote/jailbreakv_28k_dataset.py @@ -6,7 +6,7 @@ import uuid import zipfile from enum import Enum -from typing import Literal, Optional +from typing import Literal from pyrit.datasets.seed_datasets.remote.remote_dataset_loader import ( _RemoteDatasetLoader, @@ -87,7 +87,7 @@ def __init__( source: str = "JailbreakV-28K/JailBreakV-28k", zip_dir: str = str(pathlib.Path.home()), split: Literal["JailBreakV_28K", "mini_JailBreakV_28K"] = "mini_JailBreakV_28K", - harm_categories: Optional[list[_HarmCategory]] = None, + harm_categories: list[_HarmCategory] | None = None, ) -> None: """ Initialize the JailBreakV-28K dataset loader. 
diff --git a/pyrit/datasets/seed_datasets/remote/jailbreakv_redteam_2k_dataset.py b/pyrit/datasets/seed_datasets/remote/jailbreakv_redteam_2k_dataset.py index e76b294d83..c4218b82d9 100644 --- a/pyrit/datasets/seed_datasets/remote/jailbreakv_redteam_2k_dataset.py +++ b/pyrit/datasets/seed_datasets/remote/jailbreakv_redteam_2k_dataset.py @@ -3,7 +3,6 @@ import logging from enum import Enum -from typing import Optional from pyrit.datasets.seed_datasets.remote.remote_dataset_loader import ( _RemoteDatasetLoader, @@ -77,7 +76,7 @@ def __init__( self, *, source: str = "JailbreakV-28K/JailBreakV-28k", - harm_categories: Optional[list[_HarmCategory]] = None, + harm_categories: list[_HarmCategory] | None = None, ) -> None: """ Initialize the JailBreakV Redteam_2k dataset loader. diff --git a/pyrit/datasets/seed_datasets/remote/mm_safetybench_dataset.py b/pyrit/datasets/seed_datasets/remote/mm_safetybench_dataset.py index 50240b8b6c..4c686ef07b 100644 --- a/pyrit/datasets/seed_datasets/remote/mm_safetybench_dataset.py +++ b/pyrit/datasets/seed_datasets/remote/mm_safetybench_dataset.py @@ -5,7 +5,7 @@ import logging import uuid from enum import Enum -from typing import Any, Optional +from typing import Any from pyrit.datasets.seed_datasets.remote._image_cache import ( fetch_and_cache_image_async, @@ -178,9 +178,9 @@ def __init__( self, *, variant: MMSafetyBenchVariant = MMSafetyBenchVariant.SD_TYPOGRAPHY, - categories: Optional[list[MMSafetyBenchCategory]] = None, + categories: list[MMSafetyBenchCategory] | None = None, use_tiny: bool = False, - token: Optional[str] = None, + token: str | None = None, ) -> None: """ Initialize the MM-SafetyBench dataset loader. 
diff --git a/pyrit/datasets/seed_datasets/remote/pku_safe_rlhf_dataset.py b/pyrit/datasets/seed_datasets/remote/pku_safe_rlhf_dataset.py index 2d1a775920..f1435bff8b 100644 --- a/pyrit/datasets/seed_datasets/remote/pku_safe_rlhf_dataset.py +++ b/pyrit/datasets/seed_datasets/remote/pku_safe_rlhf_dataset.py @@ -2,7 +2,7 @@ # Licensed under the MIT license. import logging -from typing import Literal, Optional +from typing import Literal from typing_extensions import override @@ -57,31 +57,30 @@ def __init__( *, source: str = "PKU-Alignment/PKU-SafeRLHF", include_safe_prompts: bool = True, - filter_harm_categories: Optional[ - list[ - Literal[ - "Animal Abuse", - "Copyright Issues", - "Cybercrime", - "Discriminatory Behavior", - "Disrupting Public Order", - "Drugs", - "Economic Crime", - "Endangering National Security", - "Endangering Public Health", - "Environmental Damage", - "Human Trafficking", - "Insulting Behavior", - "Mental Manipulation", - "Physical Harm", - "Privacy Violation", - "Psychological Harm", - "Sexual Content", - "Violence", - "White-Collar Crime", - ] + filter_harm_categories: list[ + Literal[ + "Animal Abuse", + "Copyright Issues", + "Cybercrime", + "Discriminatory Behavior", + "Disrupting Public Order", + "Drugs", + "Economic Crime", + "Endangering National Security", + "Endangering Public Health", + "Environmental Damage", + "Human Trafficking", + "Insulting Behavior", + "Mental Manipulation", + "Physical Harm", + "Privacy Violation", + "Psychological Harm", + "Sexual Content", + "Violence", + "White-Collar Crime", ] - ] = None, + ] + | None = None, ) -> None: """ Initialize the PKU-SafeRLHF dataset loader. 
diff --git a/pyrit/datasets/seed_datasets/remote/promptintel_dataset.py b/pyrit/datasets/seed_datasets/remote/promptintel_dataset.py index 69862353b2..7438a0c14c 100644 --- a/pyrit/datasets/seed_datasets/remote/promptintel_dataset.py +++ b/pyrit/datasets/seed_datasets/remote/promptintel_dataset.py @@ -5,7 +5,7 @@ import os from datetime import datetime from enum import Enum -from typing import Any, Optional +from typing import Any import requests from typing_extensions import override @@ -77,10 +77,10 @@ class _PromptIntelDataset(_RemoteDatasetLoader): def __init__( self, *, - api_key: Optional[str] = None, - severity: Optional[PromptIntelSeverity] = None, - categories: Optional[list[PromptIntelCategory]] = None, - search: Optional[str] = None, + api_key: str | None = None, + severity: PromptIntelSeverity | None = None, + categories: list[PromptIntelCategory] | None = None, + search: str | None = None, ) -> None: """ Initialize the PromptIntel dataset loader. @@ -123,7 +123,7 @@ def _fetch_all_prompts(self) -> list[dict[str, Any]]: category and results are merged with deduplication by prompt ID. Returns: - List[Dict[str, Any]]: All fetched prompt records. + list[dict[str, Any]]: All fetched prompt records. Raises: ValueError: If no API key is provided and PROMPTINTEL_API_KEY is not set. 
@@ -141,7 +141,7 @@ def _fetch_all_prompts(self) -> list[dict[str, Any]]: } # Build list of category values to fetch; [None] means fetch all categories - categories_to_fetch: list[Optional[str]] = [c.value for c in self._categories] if self._categories else [None] + categories_to_fetch: list[str | None] = [c.value for c in self._categories] if self._categories else [None] all_prompts: list[dict[str, Any]] = [] seen_ids: set[str] = set() @@ -189,7 +189,7 @@ def _fetch_all_prompts(self) -> list[dict[str, Any]]: return all_prompts - def _parse_datetime(self, date_str: Optional[str]) -> Optional[datetime]: + def _parse_datetime(self, date_str: str | None) -> datetime | None: """ Parse an ISO 8601 datetime string from the API. @@ -214,7 +214,7 @@ def _build_metadata(self, record: dict[str, Any]) -> dict[str, str | int]: record: A single prompt record from the API. Returns: - Dict[str, str | int]: Metadata dictionary with string or integer values. + dict[str, str | int]: Metadata dictionary with string or integer values. """ metadata: dict[str, str | int] = {} @@ -254,7 +254,7 @@ def _build_metadata(self, record: dict[str, Any]) -> dict[str, str | int]: return metadata - def _convert_record_to_seed_prompt(self, record: dict[str, Any]) -> Optional[SeedPrompt]: + def _convert_record_to_seed_prompt(self, record: dict[str, Any]) -> SeedPrompt | None: """ Convert a single PromptIntel record into a SeedPrompt. 
diff --git a/pyrit/datasets/seed_datasets/remote/remote_dataset_loader.py b/pyrit/datasets/seed_datasets/remote/remote_dataset_loader.py index 0fc9bdd3b9..f29151935a 100644 --- a/pyrit/datasets/seed_datasets/remote/remote_dataset_loader.py +++ b/pyrit/datasets/seed_datasets/remote/remote_dataset_loader.py @@ -13,7 +13,7 @@ from dataclasses import fields from enum import Enum from pathlib import Path -from typing import Any, Literal, Optional, TextIO, cast +from typing import Any, Literal, TextIO, cast from urllib.parse import urlparse import requests @@ -155,7 +155,7 @@ def _read_cache(self, *, cache_file: Path, file_type: str) -> list[dict[str, str file_type (str): The file extension/type. Returns: - List[Dict[str, str]]: The cached examples. + list[dict[str, str]]: The cached examples. Raises: ValueError: If the file_type is invalid. @@ -170,7 +170,7 @@ def _write_cache(self, *, cache_file: Path, examples: list[dict[str, str]], file Args: cache_file (Path): Path to the cache file. - examples (List[Dict[str, str]]): The examples to cache. + examples (list[dict[str, str]]): The examples to cache. file_type (str): The file extension/type. Raises: @@ -190,7 +190,7 @@ def _fetch_from_public_url(self, *, source: str, file_type: str) -> list[dict[st file_type: The file extension/type. Returns: - List[Dict[str, str]]: The fetched examples. + list[dict[str, str]]: The fetched examples. Raises: ValueError: If the file_type is invalid. @@ -220,7 +220,7 @@ def _fetch_from_file(self, *, source: str, file_type: str) -> list[dict[str, str file_type: The file extension/type. Returns: - List[Dict[str, str]]: The fetched examples. + list[dict[str, str]]: The fetched examples. Raises: ValueError: If the file_type is invalid. @@ -247,7 +247,7 @@ def _fetch_from_url( cache: Whether to cache the fetched examples. Defaults to True. Returns: - List[Dict[str, str]]: A list of examples. + list[dict[str, str]]: A list of examples. Raises: ValueError: If the file_type is invalid. 
@@ -288,10 +288,10 @@ async def _fetch_from_huggingface_async( self, *, dataset_name: str, - config: Optional[str] = None, - split: Optional[str] = None, + config: str | None = None, + split: str | None = None, cache: bool = True, - token: Optional[str] = None, + token: str | None = None, **kwargs: Any, ) -> Any: """ @@ -356,7 +356,7 @@ def _load_dataset_sync() -> Any: logger.error(f"Failed to load HuggingFace dataset {dataset_name}: {e}") raise - async def _parse_metadata_async(self) -> Optional[SeedDatasetMetadata]: + async def _parse_metadata_async(self) -> SeedDatasetMetadata | None: """ Extract metadata from class attributes, wrap in sets, and format into SeedDatasetMetadata. @@ -364,7 +364,7 @@ async def _parse_metadata_async(self) -> Optional[SeedDatasetMetadata]: All are normalized into sets for the unified SeedDatasetMetadata schema. Returns: - Optional[SeedDatasetMetadata]: Parsed metadata if available, otherwise None. + SeedDatasetMetadata | None: Parsed metadata if available, otherwise None. 
""" valid_fields = [f.name for f in fields(SeedDatasetMetadata)] @@ -423,7 +423,7 @@ async def _fetch_zip_from_url_async( def _download_and_parse() -> dict[str, list[dict[str, Any]]]: zip_path: Path - temp_to_clean: Optional[Path] = None + temp_to_clean: Path | None = None if cache and cache_path.exists(): zip_path = cache_path else: diff --git a/pyrit/datasets/seed_datasets/remote/siuo_dataset.py b/pyrit/datasets/seed_datasets/remote/siuo_dataset.py index a0c8e3df03..b9199539ff 100644 --- a/pyrit/datasets/seed_datasets/remote/siuo_dataset.py +++ b/pyrit/datasets/seed_datasets/remote/siuo_dataset.py @@ -4,7 +4,7 @@ import logging import uuid from enum import Enum -from typing import TYPE_CHECKING, Literal, Optional +from typing import TYPE_CHECKING, Literal from typing_extensions import override @@ -115,7 +115,7 @@ def __init__( *, source: str = GEN_JSON_URL, source_type: Literal["public_url", "file"] = "public_url", - categories: Optional[list[SIUOCategory]] = None, + categories: list[SIUOCategory] | None = None, ) -> None: """ Initialize the SIUO dataset loader. @@ -125,7 +125,7 @@ def __init__( HuggingFace mirror pinned to a commit SHA for reproducibility. source_type (Literal["public_url", "file"]): Whether source is a public URL or a local file path. Defaults to 'public_url'. - categories (Optional[list[SIUOCategory]]): Optional filter; only rows + categories (list[SIUOCategory] | None): Optional filter; only rows whose category matches one of these enum values are included. If None, every category is included. 
diff --git a/pyrit/datasets/seed_datasets/remote/sorry_bench_dataset.py b/pyrit/datasets/seed_datasets/remote/sorry_bench_dataset.py index 43b5ec49db..478b776497 100644 --- a/pyrit/datasets/seed_datasets/remote/sorry_bench_dataset.py +++ b/pyrit/datasets/seed_datasets/remote/sorry_bench_dataset.py @@ -3,7 +3,6 @@ import logging import os -from typing import Optional from typing_extensions import override @@ -133,9 +132,9 @@ def __init__( self, *, source: str = "sorry-bench/sorry-bench-202503", - categories: Optional[list[str]] = None, - prompt_style: Optional[str] = None, - token: Optional[str] = None, + categories: list[str] | None = None, + prompt_style: str | None = None, + token: str | None = None, ) -> None: """ Initialize the Sorry-Bench dataset loader. diff --git a/pyrit/datasets/seed_datasets/remote/visual_leak_bench_dataset.py b/pyrit/datasets/seed_datasets/remote/visual_leak_bench_dataset.py index e18a072150..7daad427e1 100644 --- a/pyrit/datasets/seed_datasets/remote/visual_leak_bench_dataset.py +++ b/pyrit/datasets/seed_datasets/remote/visual_leak_bench_dataset.py @@ -4,7 +4,7 @@ import logging import uuid from enum import Enum -from typing import Literal, Optional +from typing import Literal from typing_extensions import override @@ -84,8 +84,8 @@ def __init__( *, source: str = METADATA_URL, source_type: Literal["public_url", "file"] = "public_url", - categories: Optional[list[VisualLeakBenchCategory]] = None, - pii_types: Optional[list[VisualLeakBenchPIIType]] = None, + categories: list[VisualLeakBenchCategory] | None = None, + pii_types: list[VisualLeakBenchPIIType] | None = None, ) -> None: """ Initialize the VisualLeakBench dataset loader. 
diff --git a/pyrit/datasets/seed_datasets/remote/vlsu_multimodal_dataset.py b/pyrit/datasets/seed_datasets/remote/vlsu_multimodal_dataset.py index 9fa573bbb1..362d4b7fd1 100644 --- a/pyrit/datasets/seed_datasets/remote/vlsu_multimodal_dataset.py +++ b/pyrit/datasets/seed_datasets/remote/vlsu_multimodal_dataset.py @@ -4,7 +4,7 @@ import logging import uuid from enum import Enum -from typing import Literal, Optional +from typing import Literal from typing_extensions import override @@ -81,9 +81,9 @@ def __init__( *, source: str = "https://raw.githubusercontent.com/apple/ml-vlsu/main/data/VLSU.csv", source_type: Literal["public_url", "file"] = "public_url", - categories: Optional[list[VLSUCategory]] = None, - unsafe_grades: Optional[list[str]] = None, - max_examples: Optional[int] = None, + categories: list[VLSUCategory] | None = None, + unsafe_grades: list[str] | None = None, + max_examples: int | None = None, ) -> None: """ Initialize the ML-VLSU multimodal dataset loader. diff --git a/pyrit/datasets/seed_datasets/seed_dataset_provider.py b/pyrit/datasets/seed_datasets/seed_dataset_provider.py index 4c91271722..8518b57a5b 100644 --- a/pyrit/datasets/seed_datasets/seed_dataset_provider.py +++ b/pyrit/datasets/seed_datasets/seed_dataset_provider.py @@ -6,7 +6,7 @@ import logging from abc import ABC, abstractmethod from dataclasses import fields as dc_fields -from typing import Any, Optional +from typing import Any from tqdm import tqdm @@ -126,7 +126,7 @@ async def fetch_dataset(self, *, cache: bool = True) -> SeedDataset: # pyrit-as ) return await self.fetch_dataset_async(cache=cache) - async def _parse_metadata_async(self) -> Optional[SeedDatasetMetadata]: + async def _parse_metadata_async(self) -> SeedDatasetMetadata | None: """ Parse provider-specific metadata into the shared schema. @@ -135,7 +135,7 @@ async def _parse_metadata_async(self) -> Optional[SeedDatasetMetadata]: returns None, which means metadata is not available for this provider. 
Returns: - Optional[SeedDatasetMetadata]: Parsed metadata for this provider, or None. + SeedDatasetMetadata | None: Parsed metadata for this provider, or None. """ return None @@ -145,20 +145,20 @@ def get_all_providers(cls) -> dict[str, type["SeedDatasetProvider"]]: Get all registered dataset provider classes. Returns: - Dict[str, Type[SeedDatasetProvider]]: Dictionary mapping class names to provider classes. + dict[str, type[SeedDatasetProvider]]: Dictionary mapping class names to provider classes. """ return cls._registry.copy() @classmethod - async def get_all_dataset_names_async(cls, filters: Optional[SeedDatasetFilter] = None) -> list[str]: + async def get_all_dataset_names_async(cls, filters: SeedDatasetFilter | None = None) -> list[str]: """ Get the names of all registered datasets. Args: - filters (Optional[SeedDatasetFilter]): List of filters to apply. + filters (SeedDatasetFilter | None): List of filters to apply. Returns: - List[str]: List of dataset names from all registered providers. + list[str]: List of dataset names from all registered providers. Raises: ValueError: If no providers are registered or if providers cannot be instantiated. @@ -279,7 +279,7 @@ def _match_single_criterion( async def fetch_datasets_async( cls, *, - dataset_names: Optional[list[str]] = None, + dataset_names: list[str] | None = None, cache: bool = True, max_concurrency: int = 5, ) -> list[SeedDataset]: @@ -297,7 +297,7 @@ async def fetch_datasets_async( Set to 1 for fully sequential execution. Returns: - List[SeedDataset]: List of all fetched datasets. + list[SeedDataset]: List of all fetched datasets. Raises: ValueError: If any requested dataset_name does not exist. @@ -321,12 +321,12 @@ async def fetch_datasets_async( async def fetch_single_dataset_async( provider_name: str, provider_class: type["SeedDatasetProvider"] - ) -> Optional[tuple[str, SeedDataset]]: + ) -> tuple[str, SeedDataset] | None: """ Fetch a single dataset with error handling. 
Returns: - Optional[Tuple[str, SeedDataset]]: Tuple of provider name and dataset, or None if filtered. + tuple[str, SeedDataset] | None: Tuple of provider name and dataset, or None if filtered. """ provider = provider_class() @@ -347,12 +347,12 @@ async def fetch_single_dataset_async( async def fetch_with_semaphore_async( provider_name: str, provider_class: type["SeedDatasetProvider"] - ) -> Optional[tuple[str, SeedDataset]]: + ) -> tuple[str, SeedDataset] | None: """ Enforce concurrency limit and update progress during dataset fetch. Returns: - Optional[Tuple[str, SeedDataset]]: Tuple of provider name and dataset, or None if filtered. + tuple[str, SeedDataset] | None: Tuple of provider name and dataset, or None if filtered. """ async with semaphore: result = await fetch_single_dataset_async(provider_name, provider_class) diff --git a/pyrit/datasets/seed_datasets/seed_metadata.py b/pyrit/datasets/seed_datasets/seed_metadata.py index 7b40b95a6b..cdd1149a85 100644 --- a/pyrit/datasets/seed_datasets/seed_metadata.py +++ b/pyrit/datasets/seed_datasets/seed_metadata.py @@ -4,7 +4,7 @@ import logging from dataclasses import dataclass, fields from enum import Enum -from typing import Any, ClassVar, Literal, Optional +from typing import Any, ClassVar, Literal logger = logging.getLogger(__name__) @@ -94,12 +94,12 @@ class SeedDatasetMetadata: # All fields are optional sets to support both real metadata (single-element) # and filter criteria (multi-element). SINGULAR_FIELDS enforces that parsers # only produce single-element sets for size and source_type. 
- tags: Optional[set[str]] = None - size: Optional[set[str]] = None - modalities: Optional[set[str]] = None - source_type: Optional[set[str]] = None - load_time: Optional[set[SeedDatasetLoadTime]] = None - harm_categories: Optional[set[str]] = None + tags: set[str] | None = None + size: set[str] | None = None + modalities: set[str] | None = None + source_type: set[str] | None = None + load_time: set[SeedDatasetLoadTime] | None = None + harm_categories: set[str] | None = None # Fields that must have at most 1 element in real dataset metadata. SINGULAR_FIELDS: ClassVar[frozenset[str]] = frozenset({"size", "source_type"}) @@ -195,7 +195,7 @@ class SeedDatasetFilter: def __init__( self, *, - criteria: Optional[list[SeedDatasetMetadata]] = None, + criteria: list[SeedDatasetMetadata] | None = None, strict_match: bool = False, **kwargs: Any, ) -> None: diff --git a/pyrit/embedding/openai_text_embedding.py b/pyrit/embedding/openai_text_embedding.py index 5efbb69107..86a774c404 100644 --- a/pyrit/embedding/openai_text_embedding.py +++ b/pyrit/embedding/openai_text_embedding.py @@ -3,7 +3,7 @@ import asyncio from collections.abc import Awaitable, Callable -from typing import Any, Optional +from typing import Any import tenacity from openai import AsyncOpenAI @@ -31,9 +31,9 @@ class OpenAITextEmbedding(EmbeddingSupport): def __init__( self, *, - api_key: Optional[str | Callable[[], str | Awaitable[str]]] = None, - endpoint: Optional[str] = None, - model_name: Optional[str] = None, + api_key: str | Callable[[], str | Awaitable[str]] | None = None, + endpoint: str | None = None, + model_name: str | None = None, ) -> None: """ Initialize text embedding client for Azure OpenAI or platform OpenAI. 
diff --git a/pyrit/exceptions/exception_classes.py b/pyrit/exceptions/exception_classes.py index e39fa7a88a..5b4f3de72f 100644 --- a/pyrit/exceptions/exception_classes.py +++ b/pyrit/exceptions/exception_classes.py @@ -6,7 +6,7 @@ import os from abc import ABC from collections.abc import Callable -from typing import Any, Optional +from typing import Any from openai import RateLimitError from tenacity import ( @@ -176,14 +176,14 @@ def __init__(self, *, status_code: int = 429, message: str = "Rate Limit Excepti class ServerErrorException(PyritException): """Exception class for opaque 5xx errors returned by the server.""" - def __init__(self, *, status_code: int = 500, message: str = "Server Error", body: Optional[str] = None) -> None: + def __init__(self, *, status_code: int = 500, message: str = "Server Error", body: str | None = None) -> None: """ Initialize a server error exception. Args: status_code (int): Status code for the error. message (str): Error message. - body (Optional[str]): Optional raw server response body. + body (str | None): Optional raw server response body. """ super().__init__(status_code=status_code, message=message) @@ -247,7 +247,7 @@ class ExperimentalWarning(FutureWarning): def pyrit_custom_result_retry( - retry_function: Callable[..., bool], retry_max_num_attempts: Optional[int] = None + retry_function: Callable[..., bool], retry_max_num_attempts: int | None = None ) -> Callable[..., Any]: """ Apply retry logic with exponential backoff to a function. 
diff --git a/pyrit/exceptions/exception_context.py b/pyrit/exceptions/exception_context.py index 9b45ac8737..1e6c6e4699 100644 --- a/pyrit/exceptions/exception_context.py +++ b/pyrit/exceptions/exception_context.py @@ -13,7 +13,7 @@ from contextvars import ContextVar from dataclasses import dataclass, field from enum import Enum -from typing import Any, Optional +from typing import Any from pyrit.models import ComponentIdentifier @@ -59,25 +59,25 @@ class ExecutionContext: component_role: ComponentRole = ComponentRole.UNKNOWN # The attack strategy class name (e.g., "PromptSendingAttack") - attack_strategy_name: Optional[str] = None + attack_strategy_name: str | None = None # The identifier for the attack strategy - attack_identifier: Optional[ComponentIdentifier] = None + attack_identifier: ComponentIdentifier | None = None # The identifier from the component's get_identifier() (target, scorer, etc.) - component_identifier: Optional[ComponentIdentifier] = None + component_identifier: ComponentIdentifier | None = None # The objective target conversation ID if available - objective_target_conversation_id: Optional[str] = None + objective_target_conversation_id: str | None = None # The endpoint/URI if available (extracted from component_identifier for quick access) - endpoint: Optional[str] = None + endpoint: str | None = None # The component class name (extracted from component_identifier.__type__ for quick access) - component_name: Optional[str] = None + component_name: str | None = None # The attack objective if available - objective: Optional[str] = None + objective: str | None = None def get_retry_context_string(self) -> str: """ @@ -135,15 +135,15 @@ def get_exception_details(self) -> str: # The contextvar that stores the current execution context -_execution_context: ContextVar[Optional[ExecutionContext]] = ContextVar("execution_context", default=None) +_execution_context: ContextVar[ExecutionContext | None] = ContextVar("execution_context", default=None) -def 
get_execution_context() -> Optional[ExecutionContext]: +def get_execution_context() -> ExecutionContext | None: """ Get the current execution context. Returns: - Optional[ExecutionContext]: The current context, or None if not set. + ExecutionContext | None: The current context, or None if not set. """ return _execution_context.get() @@ -213,11 +213,11 @@ def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: def execution_context( *, component_role: ComponentRole, - attack_strategy_name: Optional[str] = None, - attack_identifier: Optional[ComponentIdentifier] = None, - component_identifier: Optional[ComponentIdentifier] = None, - objective_target_conversation_id: Optional[str] = None, - objective: Optional[str] = None, + attack_strategy_name: str | None = None, + attack_identifier: ComponentIdentifier | None = None, + component_identifier: ComponentIdentifier | None = None, + objective_target_conversation_id: str | None = None, + objective: str | None = None, ) -> ExecutionContextManager: """ Create an execution context manager with the specified parameters. diff --git a/pyrit/executor/attack/component/conversation_manager.py b/pyrit/executor/attack/component/conversation_manager.py index b7d554775c..b4b47175f2 100644 --- a/pyrit/executor/attack/component/conversation_manager.py +++ b/pyrit/executor/attack/component/conversation_manager.py @@ -39,7 +39,7 @@ def mark_messages_as_simulated(messages: Sequence[Message]) -> list[Message]: messages (Sequence[Message]): The messages to mark as simulated. Returns: - List[Message]: The same messages with assistant roles converted to simulated_assistant. + list[Message]: The same messages with assistant roles converted to simulated_assistant. Modifies the messages in place and also returns them for convenience. 
""" result = list(messages) @@ -56,7 +56,7 @@ def get_adversarial_chat_messages( adversarial_chat_conversation_id: str, attack_identifier: ComponentIdentifier, adversarial_chat_target_identifier: ComponentIdentifier, - labels: Optional[dict[str, str]] = None, # deprecated + labels: dict[str, str] | None = None, # deprecated ) -> list[Message]: """ Transform prepended conversation messages for adversarial chat with swapped roles. @@ -146,7 +146,7 @@ async def build_conversation_context_string_async(messages: list[Message]) -> st return await normalizer.normalize_string_async(messages) -def get_prepended_turn_count(prepended_conversation: Optional[list[Message]]) -> int: +def get_prepended_turn_count(prepended_conversation: list[Message] | None) -> int: """ Count the number of turns (assistant responses) in a prepended conversation. @@ -191,7 +191,7 @@ def __init__( self, *, attack_identifier: ComponentIdentifier, - prompt_normalizer: Optional[PromptNormalizer] = None, + prompt_normalizer: PromptNormalizer | None = None, ) -> None: """ Initialize the conversation manager. @@ -219,9 +219,7 @@ def get_conversation(self, conversation_id: str) -> list[Message]: conversation = self._memory.get_conversation(conversation_id=conversation_id) return list(conversation) - def get_last_message( - self, *, conversation_id: str, role: Optional[ChatMessageRole] = None - ) -> Optional[MessagePiece]: + def get_last_message(self, *, conversation_id: str, role: ChatMessageRole | None = None) -> MessagePiece | None: """ Retrieve the most recent message from a conversation. @@ -251,7 +249,7 @@ def set_system_prompt( target: PromptTarget, conversation_id: str, system_prompt: str, - labels: Optional[dict[str, str]] = None, # deprecated + labels: dict[str, str] | None = None, # deprecated ) -> None: """ Set or update the system prompt for a conversation. 
@@ -288,10 +286,10 @@ async def initialize_context_async( context: "AttackContext[Any]", target: PromptTarget, conversation_id: str, - request_converters: Optional[list[PromptConverterConfiguration]] = None, + request_converters: list[PromptConverterConfiguration] | None = None, prepended_conversation_config: Optional["PrependedConversationConfig"] = None, - max_turns: Optional[int] = None, - memory_labels: Optional[dict[str, str]] = None, + max_turns: int | None = None, + memory_labels: dict[str, str] | None = None, ) -> ConversationState: """ Initialize attack context with prepended conversation and merged labels. @@ -438,9 +436,9 @@ async def add_prepended_conversation_to_memory_async( *, prepended_conversation: list[Message], conversation_id: str, - request_converters: Optional[list[PromptConverterConfiguration]] = None, + request_converters: list[PromptConverterConfiguration] | None = None, prepended_conversation_config: Optional["PrependedConversationConfig"] = None, - max_turns: Optional[int] = None, + max_turns: int | None = None, ) -> int: """ Add prepended conversation messages to memory for a chat target. @@ -519,9 +517,9 @@ async def _process_prepended_for_chat_target_async( context: "AttackContext[Any]", prepended_conversation: list[Message], conversation_id: str, - request_converters: Optional[list[PromptConverterConfiguration]], + request_converters: list[PromptConverterConfiguration] | None, prepended_conversation_config: Optional["PrependedConversationConfig"], - max_turns: Optional[int], + max_turns: int | None, ) -> ConversationState: """ Process prepended conversation for a chat target. @@ -587,7 +585,7 @@ async def _apply_converters_async( *, message: Message, request_converters: list[PromptConverterConfiguration], - apply_to_roles: Optional[list[ChatMessageRole]], + apply_to_roles: list[ChatMessageRole] | None, ) -> None: """ Apply converters to message pieces. 
diff --git a/pyrit/executor/attack/compound/sequential_attack.py b/pyrit/executor/attack/compound/sequential_attack.py index 7e851507d4..15e6eb1dc3 100644 --- a/pyrit/executor/attack/compound/sequential_attack.py +++ b/pyrit/executor/attack/compound/sequential_attack.py @@ -26,7 +26,7 @@ from dataclasses import dataclass, field from datetime import datetime, timezone from enum import Enum -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any from pydantic import Field @@ -107,8 +107,8 @@ class SequentialChildAttack: strategy: AttackStrategy[Any, AttackResult] seed_group: SeedAttackGroup - adversarial_chat: Optional[PromptTarget] = None - objective_scorer: Optional[TrueFalseScorer] = None + adversarial_chat: PromptTarget | None = None + objective_scorer: TrueFalseScorer | None = None memory_labels: Mapping[str, str] = field(default_factory=dict) @@ -288,7 +288,7 @@ async def _run_child_attack_async( *, child_attack: SequentialChildAttack, memory_labels: dict[str, str], - attribution: Optional[AttackResultAttribution] = None, + attribution: AttackResultAttribution | None = None, ) -> AttackResult: """ Execute one child attack via ``AttackExecutor`` and return its result. 
diff --git a/pyrit/executor/attack/core/attack_config.py b/pyrit/executor/attack/core/attack_config.py index c86131f769..803c6c4296 100644 --- a/pyrit/executor/attack/core/attack_config.py +++ b/pyrit/executor/attack/core/attack_config.py @@ -3,7 +3,6 @@ from dataclasses import dataclass, field from pathlib import Path -from typing import Optional, Union from pyrit.executor.core import StrategyConverterConfig from pyrit.models import SeedPrompt @@ -26,10 +25,10 @@ class AttackAdversarialConfig: target: PromptTarget # Path to the YAML file containing the system prompt for the adversarial chat target - system_prompt_path: Optional[Union[str, Path]] = None + system_prompt_path: str | Path | None = None # Seed prompt for the adversarial chat target (supports {{ objective }} template variable) - seed_prompt: Union[str, SeedPrompt] = "Generate your first message to achieve: {{ objective }}" + seed_prompt: str | SeedPrompt = "Generate your first message to achieve: {{ objective }}" @dataclass @@ -42,10 +41,10 @@ class AttackScoringConfig: """ # Primary scorer for evaluating attack effectiveness - objective_scorer: Optional[TrueFalseScorer] = None + objective_scorer: TrueFalseScorer | None = None # Refusal scorer for detecting refusals or non-compliance - refusal_scorer: Optional[TrueFalseScorer] = None + refusal_scorer: TrueFalseScorer | None = None # Additional scorers for auxiliary metrics or custom evaluations auxiliary_scorers: list[Scorer] = field(default_factory=list) diff --git a/pyrit/executor/attack/core/attack_executor.py b/pyrit/executor/attack/core/attack_executor.py index 88c2108b8b..4f3ecb2cbe 100644 --- a/pyrit/executor/attack/core/attack_executor.py +++ b/pyrit/executor/attack/core/attack_executor.py @@ -145,8 +145,8 @@ def __init__(self, *, max_concurrency: int = 1) -> None: # and then run it under more than one ``asyncio.run(...)`` invocation. 
By # constructing the semaphore inside ``_get_semaphore()`` and rebuilding when the # running loop changes, one AttackExecutor instance is safe to reuse across loops. - self._semaphore: Optional[asyncio.Semaphore] = None - self._semaphore_loop: Optional[asyncio.AbstractEventLoop] = None + self._semaphore: asyncio.Semaphore | None = None + self._semaphore_loop: asyncio.AbstractEventLoop | None = None def _get_semaphore(self) -> asyncio.Semaphore: """ @@ -174,9 +174,9 @@ async def execute_attack_from_seed_groups_async( seed_groups: Sequence[SeedAttackGroup], adversarial_chat: Optional["PromptTarget"] = None, objective_scorer: Optional["TrueFalseScorer"] = None, - field_overrides: Optional[Sequence[dict[str, Any]]] = None, + field_overrides: Sequence[dict[str, Any]] | None = None, return_partial_on_failure: bool = False, - attribution: Optional[AttackResultAttribution] = None, + attribution: AttackResultAttribution | None = None, **broadcast_fields: Any, ) -> AttackExecutorResult[AttackStrategyResultT]: """ @@ -254,9 +254,9 @@ async def execute_attack_async( *, attack: AttackStrategy[AttackStrategyContextT, AttackStrategyResultT], objectives: Sequence[str], - field_overrides: Optional[Sequence[dict[str, Any]]] = None, + field_overrides: Sequence[dict[str, Any]] | None = None, return_partial_on_failure: bool = False, - attribution: Optional[AttackResultAttribution] = None, + attribution: AttackResultAttribution | None = None, **broadcast_fields: Any, ) -> AttackExecutorResult[AttackStrategyResultT]: """ @@ -323,7 +323,7 @@ async def _execute_with_params_list_async( attack: AttackStrategy[AttackStrategyContextT, AttackStrategyResultT], params_list: Sequence[AttackParameters], return_partial_on_failure: bool = False, - attribution: Optional[AttackResultAttribution] = None, + attribution: AttackResultAttribution | None = None, ) -> AttackExecutorResult[AttackStrategyResultT]: """ Execute attacks in parallel with a list of pre-built parameters. 
diff --git a/pyrit/executor/attack/core/attack_parameters.py b/pyrit/executor/attack/core/attack_parameters.py index 72739c59ed..f4ef53100a 100644 --- a/pyrit/executor/attack/core/attack_parameters.py +++ b/pyrit/executor/attack/core/attack_parameters.py @@ -5,7 +5,7 @@ import dataclasses from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any, Optional, TypeVar +from typing import TYPE_CHECKING, Any, TypeVar from pyrit.models import Message, SeedAttackGroup, SeedGroup @@ -33,13 +33,13 @@ class AttackParameters: objective: str # Optional message to send to the objective target (overrides objective if provided) - next_message: Optional[Message] = None + next_message: Message | None = None # Conversation that is automatically prepended to the target model - prepended_conversation: Optional[list[Message]] = None + prepended_conversation: list[Message] | None = None # Additional labels that can be applied to the prompts throughout the attack - memory_labels: Optional[dict[str, str]] = field(default_factory=dict) + memory_labels: dict[str, str] | None = field(default_factory=dict) def __str__(self) -> str: """Return a nicely formatted string representation of the attack parameters.""" @@ -78,8 +78,8 @@ async def from_seed_group_async( cls: type[AttackParamsT], *, seed_group: SeedAttackGroup, - adversarial_chat: Optional[PromptTarget] = None, - objective_scorer: Optional[TrueFalseScorer] = None, + adversarial_chat: PromptTarget | None = None, + objective_scorer: TrueFalseScorer | None = None, **overrides: Any, ) -> AttackParamsT: """ @@ -151,7 +151,7 @@ async def from_seed_group_async( if objective_scorer is None: raise ValueError("objective_scorer is required when seed_group has a simulated conversation config") - # Generate the simulated conversation - returns List[SeedPrompt] + # Generate the simulated conversation - returns list[SeedPrompt] simulated_prompts = await generate_simulated_conversation_async( objective=seed_group.objective.value, 
adversarial_chat=adversarial_chat, diff --git a/pyrit/executor/attack/core/attack_strategy.py b/pyrit/executor/attack/core/attack_strategy.py index 333dde19f0..ec91008974 100644 --- a/pyrit/executor/attack/core/attack_strategy.py +++ b/pyrit/executor/attack/core/attack_strategy.py @@ -10,7 +10,7 @@ import uuid from abc import ABC from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any, ClassVar, Generic, Optional, TypeVar, Union, overload +from typing import TYPE_CHECKING, Any, ClassVar, Generic, TypeVar, overload from pyrit.common.logger import logger from pyrit.exceptions.retry_collector import ( @@ -68,16 +68,16 @@ class AttackContext(StrategyContext, ABC, Generic[AttackParamsT]): related_conversations: set[ConversationReference] = field(default_factory=set) # Mutable overrides for attacks that generate these values internally - _next_message_override: Optional[Message] = None - _prepended_conversation_override: Optional[list[Message]] = None - _memory_labels_override: Optional[dict[str, str]] = None + _next_message_override: Message | None = None + _prepended_conversation_override: list[Message] | None = None + _memory_labels_override: dict[str, str] | None = None # Optional attribution from an upstream orchestrator (e.g. Scenario). When # set, the persistence path stamps attribution_parent_id + attribution_data # onto the resulting AttackResult so it can be located later for hydration # and resume. Set by AttackExecutor per-task before scheduling. Stays None # for ad-hoc/direct attack execution outside any orchestrator. 
- _attribution: Optional[AttackResultAttribution] = None + _attribution: AttackResultAttribution | None = None # Convenience properties that delegate to params or overrides @property @@ -115,7 +115,7 @@ def prepended_conversation(self, value: list[Message]) -> None: self._prepended_conversation_override = value @property - def next_message(self) -> Optional[Message]: + def next_message(self) -> Message | None: """Optional message to send to the objective target.""" # Check override first (for attacks that generate internally) if self._next_message_override is not None: @@ -126,7 +126,7 @@ def next_message(self) -> Optional[Message]: return None @next_message.setter - def next_message(self, value: Optional[Message]) -> None: + def next_message(self, value: Message | None) -> None: """Set the next message (for attacks that generate internally).""" self._next_message_override = value @@ -385,7 +385,7 @@ def __init__( Args: objective_target (PromptTarget): The target system to attack. context_type (type[AttackStrategyContextT]): The type of context this strategy operates on. - params_type (Type[AttackParamsT]): The type of parameters this strategy accepts. + params_type (type[AttackParamsT]): The type of parameters this strategy accepts. Defaults to AttackParameters. Use AttackParameters.excluding() to create a params type that rejects certain fields. logger (logging.Logger): Logger instance for logging events. @@ -409,8 +409,8 @@ def __init__( def _create_identifier( self, *, - params: Optional[dict[str, Any]] = None, - children: Optional[dict[str, Union[ComponentIdentifier, list[ComponentIdentifier]]]] = None, + params: dict[str, Any] | None = None, + children: dict[str, ComponentIdentifier | list[ComponentIdentifier]] | None = None, ) -> ComponentIdentifier: """ Construct the attack strategy identifier. @@ -420,15 +420,15 @@ def _create_identifier( additional params or children. 
Args: - params (Optional[Dict[str, Any]]): Additional behavioral parameters from + params (dict[str, Any] | None): Additional behavioral parameters from the subclass. - children (Optional[Dict[str, Union[ComponentIdentifier, List[ComponentIdentifier]]]]): + children (dict[str, ComponentIdentifier | list[ComponentIdentifier]] | None): Named child component identifiers. Returns: ComponentIdentifier: The identifier for this attack strategy. """ - all_children: dict[str, Union[ComponentIdentifier, list[ComponentIdentifier]]] = { + all_children: dict[str, ComponentIdentifier | list[ComponentIdentifier]] = { "objective_target": self.get_objective_target().get_identifier(), } @@ -472,7 +472,7 @@ def params_type(self) -> type[AttackParameters]: Get the parameters type for this attack strategy. Returns: - Type[AttackParameters]: The parameters type this strategy accepts. + type[AttackParameters]: The parameters type this strategy accepts. """ return self._params_type @@ -485,12 +485,12 @@ def get_objective_target(self) -> PromptTarget: """ return self._objective_target - def get_attack_scoring_config(self) -> Optional[AttackScoringConfig]: + def get_attack_scoring_config(self) -> AttackScoringConfig | None: """ Get the attack scoring configuration used by this strategy. Returns: - Optional[AttackScoringConfig]: The scoring configuration, or None if not applicable. + AttackScoringConfig | None: The scoring configuration, or None if not applicable. Note: Subclasses that use scoring should override this method to return their @@ -512,9 +512,9 @@ async def execute_async( self, *, objective: str, - next_message: Optional[Message] = None, - prepended_conversation: Optional[list[Message]] = None, - memory_labels: Optional[dict[str, str]] = None, + next_message: Message | None = None, + prepended_conversation: list[Message] | None = None, + memory_labels: dict[str, str] | None = None, **kwargs: Any, ) -> AttackStrategyResultT: ... 
@@ -538,9 +538,9 @@ async def execute_async( Args: objective (str): The objective of the attack. - next_message (Optional[Message]): Message to send to the target. - prepended_conversation (Optional[List[Message]]): Conversation to prepend. - memory_labels (Optional[Dict[str, str]]): Memory labels for the attack context. + next_message (Message | None): Message to send to the target. + prepended_conversation (list[Message] | None): Conversation to prepend. + memory_labels (dict[str, str] | None): Memory labels for the attack context. **kwargs: Additional context-specific parameters (conversation_id, system_prompt, etc.). Returns: diff --git a/pyrit/executor/attack/multi_turn/chunked_request.py b/pyrit/executor/attack/multi_turn/chunked_request.py index 6289b39841..4e7e5caefa 100644 --- a/pyrit/executor/attack/multi_turn/chunked_request.py +++ b/pyrit/executor/attack/multi_turn/chunked_request.py @@ -5,7 +5,7 @@ import textwrap from dataclasses import dataclass, field from string import Formatter -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any from pyrit.common.apply_defaults import REQUIRED_VALUE, apply_defaults from pyrit.exceptions import ComponentRole, execution_context @@ -102,9 +102,9 @@ def __init__( total_length: int = 200, chunk_type: str = "characters", request_template: str = DEFAULT_TEMPLATE, - attack_converter_config: Optional[AttackConverterConfig] = None, - attack_scoring_config: Optional[AttackScoringConfig] = None, - prompt_normalizer: Optional[PromptNormalizer] = None, + attack_converter_config: AttackConverterConfig | None = None, + attack_scoring_config: AttackScoringConfig | None = None, + prompt_normalizer: PromptNormalizer | None = None, ) -> None: """ Initialize the chunked request attack strategy. @@ -116,9 +116,9 @@ def __init__( chunk_type (str): Type of chunk to request (e.g., "characters", "bytes", "words"). 
request_template (str): Template for generating chunk requests (default: "Give me {chunk_type} {start}-{end} of '{objective}'"). - attack_converter_config (Optional[AttackConverterConfig]): Configuration for prompt converters. - attack_scoring_config (Optional[AttackScoringConfig]): Configuration for scoring components. - prompt_normalizer (Optional[PromptNormalizer]): Normalizer for handling prompts. + attack_converter_config (AttackConverterConfig | None): Configuration for prompt converters. + attack_scoring_config (AttackScoringConfig | None): Configuration for scoring components. + prompt_normalizer (PromptNormalizer | None): Normalizer for handling prompts. Raises: ValueError: If chunk_size or total_length are invalid. @@ -167,7 +167,7 @@ def __init__( attack_scoring_config = attack_scoring_config or AttackScoringConfig() self._auxiliary_scorers = attack_scoring_config.auxiliary_scorers - self._objective_scorer: Optional[TrueFalseScorer] = attack_scoring_config.objective_scorer + self._objective_scorer: TrueFalseScorer | None = attack_scoring_config.objective_scorer # Initialize prompt normalizer and conversation manager self._prompt_normalizer = prompt_normalizer or PromptNormalizer() @@ -176,12 +176,12 @@ def __init__( prompt_normalizer=self._prompt_normalizer, ) - def get_attack_scoring_config(self) -> Optional[AttackScoringConfig]: + def get_attack_scoring_config(self) -> AttackScoringConfig | None: """ Get the attack scoring configuration used by this strategy. Returns: - Optional[AttackScoringConfig]: The scoring configuration with objective and auxiliary scorers. + AttackScoringConfig | None: The scoring configuration with objective and auxiliary scorers. """ return AttackScoringConfig( objective_scorer=self._objective_scorer, @@ -209,7 +209,7 @@ def _generate_chunk_prompts(self, context: ChunkedRequestAttackContext) -> list[ context (ChunkedRequestAttackContext): The attack context. Returns: - List[str]: List of chunk request prompts. 
+ list[str]: List of chunk request prompts. """ prompts = [] start = 1 @@ -333,16 +333,16 @@ async def _perform_async(self, *, context: ChunkedRequestAttackContext) -> Attac def _determine_attack_outcome( self, *, - score: Optional[Score], - ) -> tuple[AttackOutcome, Optional[str]]: + score: Score | None, + ) -> tuple[AttackOutcome, str | None]: """ Determine the outcome of the attack based on the score. Args: - score (Optional[Score]): The objective score (if any). + score (Score | None): The objective score (if any). Returns: - tuple[AttackOutcome, Optional[str]]: A tuple of (outcome, outcome_reason). + tuple[AttackOutcome, str | None]: A tuple of (outcome, outcome_reason). """ if not self._objective_scorer: return AttackOutcome.UNDETERMINED, "No objective scorer configured" @@ -359,7 +359,7 @@ async def _score_combined_value_async( *, combined_value: str, objective: str, - ) -> Optional[Score]: + ) -> Score | None: """ Score the combined chunk responses against the objective. @@ -368,7 +368,7 @@ async def _score_combined_value_async( objective (str): The natural-language description of the attack's objective. Returns: - Optional[Score]: The score from the objective scorer if configured, or None if + Score | None: The score from the objective scorer if configured, or None if no objective scorer is set. 
""" if not self._objective_scorer: diff --git a/pyrit/executor/attack/multi_turn/crescendo.py b/pyrit/executor/attack/multi_turn/crescendo.py index a241affa66..bc987f1270 100644 --- a/pyrit/executor/attack/multi_turn/crescendo.py +++ b/pyrit/executor/attack/multi_turn/crescendo.py @@ -8,7 +8,7 @@ import re from dataclasses import dataclass from pathlib import Path -from typing import TYPE_CHECKING, Any, Optional, Union, cast +from typing import TYPE_CHECKING, Any, cast from pyrit.common.apply_defaults import REQUIRED_VALUE, apply_defaults from pyrit.common.path import EXECUTOR_SEED_PROMPT_PATH @@ -75,7 +75,7 @@ class CrescendoAttackContext(MultiTurnAttackContext[Any]): """Context for the Crescendo attack strategy.""" # Text that was refused by the target in the previous attempt (used for backtracking) - refused_text: Optional[str] = None + refused_text: str | None = None # Counter for number of backtracks performed during the attack backtrack_count: int = 0 @@ -144,12 +144,12 @@ def __init__( *, objective_target: PromptTarget = REQUIRED_VALUE, # type: ignore[ty:invalid-parameter-default] attack_adversarial_config: AttackAdversarialConfig, - attack_converter_config: Optional[AttackConverterConfig] = None, - attack_scoring_config: Optional[AttackScoringConfig] = None, - prompt_normalizer: Optional[PromptNormalizer] = None, + attack_converter_config: AttackConverterConfig | None = None, + attack_scoring_config: AttackScoringConfig | None = None, + prompt_normalizer: PromptNormalizer | None = None, max_backtracks: int = 10, max_turns: int = 10, - prepended_conversation_config: Optional[PrependedConversationConfig] = None, + prepended_conversation_config: PrependedConversationConfig | None = None, ) -> None: """ Initialize the Crescendo attack strategy. @@ -159,13 +159,13 @@ def __init__( support editable conversation history. 
attack_adversarial_config (AttackAdversarialConfig): Configuration for the adversarial component, including the adversarial chat target and optional system prompt path. - attack_converter_config (Optional[AttackConverterConfig]): Configuration for attack converters, + attack_converter_config (AttackConverterConfig | None): Configuration for attack converters, including request and response converters. - attack_scoring_config (Optional[AttackScoringConfig]): Configuration for scoring responses. - prompt_normalizer (Optional[PromptNormalizer]): Normalizer for prompts. + attack_scoring_config (AttackScoringConfig | None): Configuration for scoring responses. + prompt_normalizer (PromptNormalizer | None): Normalizer for prompts. max_backtracks (int): Maximum number of backtracks allowed. max_turns (int): Maximum number of turns allowed. - prepended_conversation_config (Optional[PrependedConversationConfiguration]): + prepended_conversation_config (PrependedConversationConfig | None): Configuration for how to process prepended conversations. Controls converter application by role, message normalization, and non-chat target behavior. @@ -249,12 +249,12 @@ def __init__( # Store the prepended conversation configuration self._prepended_conversation_config = prepended_conversation_config - def get_attack_scoring_config(self) -> Optional[AttackScoringConfig]: + def get_attack_scoring_config(self) -> AttackScoringConfig | None: """ Get the attack scoring configuration used by this strategy. Returns: - Optional[AttackScoringConfig]: The scoring configuration with objective scorer, + AttackScoringConfig | None: The scoring configuration with objective scorer, auxiliary scorers, and refusal scorer.
""" return AttackScoringConfig( @@ -315,7 +315,7 @@ async def _setup_async(self, *, context: CrescendoAttackContext) -> None: ) # Set up adversarial chat with prepended conversation - adversarial_chat_context: Optional[str] = None + adversarial_chat_context: str | None = None if context.prepended_conversation: # Build context string for system prompt normalizer = ConversationContextNormalizer() @@ -760,12 +760,12 @@ async def _backtrack_memory_async(self, *, conversation_id: str) -> str: self._logger.debug(f"Backtracked conversation from {conversation_id} to {new_conversation_id}") return new_conversation_id - def _set_adversarial_chat_system_prompt_template(self, *, system_prompt_template_path: Union[Path, str]) -> None: + def _set_adversarial_chat_system_prompt_template(self, *, system_prompt_template_path: Path | str) -> None: """ Set the system prompt template for the adversarial chat. Args: - system_prompt_template_path (Union[Path, str]): Path to the system prompt template. + system_prompt_template_path (Path | str): Path to the system prompt template. Raises: ValueError: If the template doesn't contain required parameters. diff --git a/pyrit/executor/attack/multi_turn/multi_prompt_sending.py b/pyrit/executor/attack/multi_turn/multi_prompt_sending.py index 2c4255984b..e15ef6c63d 100644 --- a/pyrit/executor/attack/multi_turn/multi_prompt_sending.py +++ b/pyrit/executor/attack/multi_turn/multi_prompt_sending.py @@ -47,7 +47,7 @@ class MultiPromptSendingAttackParameters(AttackParameters): Only accepts objective and user_messages fields. 
""" - user_messages: Optional[list[Message]] = None + user_messages: list[Message] | None = None @classmethod async def from_seed_group_async( @@ -137,18 +137,18 @@ def __init__( self, *, objective_target: PromptTarget = REQUIRED_VALUE, # type: ignore[ty:invalid-parameter-default] - attack_converter_config: Optional[AttackConverterConfig] = None, - attack_scoring_config: Optional[AttackScoringConfig] = None, - prompt_normalizer: Optional[PromptNormalizer] = None, + attack_converter_config: AttackConverterConfig | None = None, + attack_scoring_config: AttackScoringConfig | None = None, + prompt_normalizer: PromptNormalizer | None = None, ) -> None: """ Initialize the multi-prompt sending attack strategy. Args: objective_target (PromptTarget): The target system to attack. - attack_converter_config (Optional[AttackConverterConfig]): Configuration for prompt converters. - attack_scoring_config (Optional[AttackScoringConfig]): Configuration for scoring components. - prompt_normalizer (Optional[PromptNormalizer]): Normalizer for handling prompts. + attack_converter_config (AttackConverterConfig | None): Configuration for prompt converters. + attack_scoring_config (AttackScoringConfig | None): Configuration for scoring components. + prompt_normalizer (PromptNormalizer | None): Normalizer for handling prompts. Raises: ValueError: If the objective scorer is not a true/false scorer. @@ -179,12 +179,12 @@ def __init__( prompt_normalizer=self._prompt_normalizer, ) - def get_attack_scoring_config(self) -> Optional[AttackScoringConfig]: + def get_attack_scoring_config(self) -> AttackScoringConfig | None: """ Get the attack scoring configuration used by this strategy. Returns: - Optional[AttackScoringConfig]: The scoring configuration with objective and auxiliary scorers. + AttackScoringConfig | None: The scoring configuration with objective and auxiliary scorers. 
""" return AttackScoringConfig( objective_scorer=self._objective_scorer, @@ -301,20 +301,20 @@ async def _perform_async(self, *, context: MultiTurnAttackContext[Any]) -> Attac def _determine_attack_outcome( self, *, - response: Optional[Message], - score: Optional[Score], + response: Message | None, + score: Score | None, context: MultiTurnAttackContext[Any], - ) -> tuple[AttackOutcome, Optional[str]]: + ) -> tuple[AttackOutcome, str | None]: """ Determine the outcome of the attack based on the response and score. Args: - response (Optional[Message]): The last response from the target (if any). - score (Optional[Score]): The objective score (if any). + response (Message | None): The last response from the target (if any). + score (Score | None): The objective score (if any). context (MultiTurnAttackContext): The attack context containing configuration. Returns: - tuple[AttackOutcome, Optional[str]]: A tuple of (outcome, outcome_reason). + tuple[AttackOutcome, str | None]: A tuple of (outcome, outcome_reason). """ if not self._objective_scorer: # No scorer means we can't determine success/failure @@ -340,7 +340,7 @@ async def _teardown_async(self, *, context: MultiTurnAttackContext[Any]) -> None async def _send_prompt_to_objective_target_async( self, *, current_message: Message, context: MultiTurnAttackContext[Any] - ) -> Optional[Message]: + ) -> Message | None: """ Send the prompt to the target and return the response. @@ -349,7 +349,7 @@ async def _send_prompt_to_objective_target_async( context (MultiTurnAttackContext): The attack context containing parameters and labels. Returns: - Optional[Message]: The model's response if successful, or None if + Message | None: The model's response if successful, or None if the request was filtered, blocked, or encountered an error. 
""" with execution_context( @@ -370,7 +370,7 @@ async def _send_prompt_to_objective_target_async( attack_identifier=self.get_identifier(), ) - async def _evaluate_response_async(self, *, response: Message, objective: str) -> Optional[Score]: + async def _evaluate_response_async(self, *, response: Message, objective: str) -> Score | None: """ Evaluate the response against the objective using the configured scorers. @@ -382,7 +382,7 @@ async def _evaluate_response_async(self, *, response: Message, objective: str) - objective (str): The natural-language description of the attack's objective. Returns: - Optional[Score]: The score from the objective scorer if configured, or None if + Score | None: The score from the objective scorer if configured, or None if no objective scorer is set. Note that auxiliary scorer results are not returned but are still executed and stored. """ diff --git a/pyrit/executor/attack/multi_turn/multi_turn_attack_strategy.py b/pyrit/executor/attack/multi_turn/multi_turn_attack_strategy.py index 809b150988..4aca72d054 100644 --- a/pyrit/executor/attack/multi_turn/multi_turn_attack_strategy.py +++ b/pyrit/executor/attack/multi_turn/multi_turn_attack_strategy.py @@ -7,7 +7,7 @@ import uuid from abc import ABC from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any, Optional, TypeVar +from typing import TYPE_CHECKING, Any, TypeVar from pyrit.common.logger import logger from pyrit.executor.attack.core.attack_parameters import AttackParameters, AttackParamsT @@ -58,10 +58,10 @@ class MultiTurnAttackContext(AttackContext[AttackParamsT]): executed_turns: int = 0 # Model response produced in the latest turn - last_response: Optional[Message] = None + last_response: Message | None = None # Score assigned to the latest response by a scorer component - last_score: Optional[Score] = None + last_score: Score | None = None class MultiTurnAttackStrategy(AttackStrategy[MultiTurnAttackStrategyContextT, AttackStrategyResultT], ABC): @@ 
-85,7 +85,7 @@ def __init__( Args: objective_target (PromptTarget): The target system to attack. context_type (type[MultiTurnAttackContext]): The type of context this strategy will use. - params_type (Type[AttackParamsT]): The type of parameters this strategy accepts. + params_type (type[AttackParamsT]): The type of parameters this strategy accepts. logger (logging.Logger): Logger instance for logging events and messages. """ super().__init__( diff --git a/pyrit/executor/attack/multi_turn/red_teaming.py b/pyrit/executor/attack/multi_turn/red_teaming.py index 84cb085fca..8c0d34c6eb 100644 --- a/pyrit/executor/attack/multi_turn/red_teaming.py +++ b/pyrit/executor/attack/multi_turn/red_teaming.py @@ -6,7 +6,7 @@ import enum import logging from pathlib import Path -from typing import TYPE_CHECKING, Any, Optional, Union +from typing import TYPE_CHECKING, Any from pyrit.common.apply_defaults import REQUIRED_VALUE, apply_defaults from pyrit.common.path import EXECUTOR_RED_TEAM_PATH @@ -100,9 +100,9 @@ def __init__( *, objective_target: PromptTarget = REQUIRED_VALUE, # type: ignore[ty:invalid-parameter-default] attack_adversarial_config: AttackAdversarialConfig, - attack_converter_config: Optional[AttackConverterConfig] = None, - attack_scoring_config: Optional[AttackScoringConfig] = None, - prompt_normalizer: Optional[PromptNormalizer] = None, + attack_converter_config: AttackConverterConfig | None = None, + attack_scoring_config: AttackScoringConfig | None = None, + prompt_normalizer: PromptNormalizer | None = None, max_turns: int = 10, score_last_turn_only: bool = False, ) -> None: @@ -175,12 +175,12 @@ def __init__( self._max_turns = max_turns self._score_last_turn_only = score_last_turn_only - def get_attack_scoring_config(self) -> Optional[AttackScoringConfig]: + def get_attack_scoring_config(self) -> AttackScoringConfig | None: """ Get the attack scoring configuration used by this strategy. 
Returns: - Optional[AttackScoringConfig]: The scoring configuration with objective scorer + AttackScoringConfig | None: The scoring configuration with objective scorer and use_score_as_feedback. """ return AttackScoringConfig( @@ -578,7 +578,7 @@ async def _send_prompt_to_objective_target_async( return response - async def _score_response_async(self, *, context: MultiTurnAttackContext[Any]) -> Optional[Score]: + async def _score_response_async(self, *, context: MultiTurnAttackContext[Any]) -> Score | None: """ Evaluate the objective target's response with the objective scorer. @@ -589,7 +589,7 @@ async def _score_response_async(self, *, context: MultiTurnAttackContext[Any]) - context (MultiTurnAttackContext): The attack context containing the response to score. Returns: - Optional[Score]: The score of the response if available, otherwise None. + Score | None: The score of the response if available, otherwise None. """ if not context.last_response: logger.warning("No response available in context to score") @@ -613,12 +613,12 @@ async def _score_response_async(self, *, context: MultiTurnAttackContext[Any]) - objective_scores = scoring_results return objective_scores[0] if objective_scores else None - def _set_adversarial_chat_seed_prompt(self, *, seed_prompt: Union[str, SeedPrompt]) -> None: + def _set_adversarial_chat_seed_prompt(self, *, seed_prompt: str | SeedPrompt) -> None: """ Set the seed prompt for the adversarial chat. Args: - seed_prompt (Union[str, SeedPrompt]): The seed prompt to set for the adversarial chat. + seed_prompt (str | SeedPrompt): The seed prompt to set for the adversarial chat. Raises: ValueError: If the seed prompt is not a string or SeedPrompt object. 
diff --git a/pyrit/executor/attack/multi_turn/simulated_conversation.py b/pyrit/executor/attack/multi_turn/simulated_conversation.py index a814424c92..895516243c 100644 --- a/pyrit/executor/attack/multi_turn/simulated_conversation.py +++ b/pyrit/executor/attack/multi_turn/simulated_conversation.py @@ -11,7 +11,7 @@ from __future__ import annotations import logging -from typing import TYPE_CHECKING, Optional, Union +from typing import TYPE_CHECKING from pyrit.executor.attack.core.attack_config import ( AttackAdversarialConfig, @@ -39,11 +39,11 @@ async def generate_simulated_conversation_async( objective_scorer: TrueFalseScorer, num_turns: int = 3, starting_sequence: int = 0, - adversarial_chat_system_prompt_path: Union[str, Path], - simulated_target_system_prompt_path: Optional[Union[str, Path]] = None, - next_message_system_prompt_path: Optional[Union[str, Path]] = None, - attack_converter_config: Optional[AttackConverterConfig] = None, - memory_labels: Optional[dict[str, str]] = None, + adversarial_chat_system_prompt_path: str | Path, + simulated_target_system_prompt_path: str | Path | None = None, + next_message_system_prompt_path: str | Path | None = None, + attack_converter_config: AttackConverterConfig | None = None, + memory_labels: dict[str, str] | None = None, ) -> list[SeedPrompt]: """ Generate a simulated conversation between an adversarial chat and a target. @@ -171,7 +171,7 @@ async def _generate_next_message_async( objective: str, conversation_messages: list[Message], adversarial_chat: PromptTarget, - next_message_system_prompt_path: Union[str, Path], + next_message_system_prompt_path: str | Path, ) -> Message: """ Generate a single next message using the adversarial chat LLM. 
diff --git a/pyrit/executor/attack/multi_turn/tree_of_attacks.py b/pyrit/executor/attack/multi_turn/tree_of_attacks.py index 94f3bc9ac1..0cd557b1c6 100644 --- a/pyrit/executor/attack/multi_turn/tree_of_attacks.py +++ b/pyrit/executor/attack/multi_turn/tree_of_attacks.py @@ -106,8 +106,8 @@ def __init__( self, *, objective_scorer: FloatScaleThresholdScorer, - refusal_scorer: Optional[TrueFalseScorer] = None, - auxiliary_scorers: Optional[list[Scorer]] = None, + refusal_scorer: TrueFalseScorer | None = None, + auxiliary_scorers: list[Scorer] | None = None, use_score_as_feedback: bool = True, ) -> None: """ @@ -117,8 +117,8 @@ def __init__( objective_scorer (FloatScaleThresholdScorer): The scorer for evaluating attack success. Must be a FloatScaleThresholdScorer to provide both granular float scores for node comparison and a threshold for success determination. - refusal_scorer (Optional[TrueFalseScorer]): Optional scorer for detecting refusals. - auxiliary_scorers (Optional[List[Scorer]]): Additional scorers for auxiliary metrics. + refusal_scorer (TrueFalseScorer | None): Optional scorer for detecting refusals. + auxiliary_scorers (list[Scorer] | None): Additional scorers for auxiliary metrics. use_score_as_feedback (bool): Whether to use scoring results as feedback. Defaults to True. 
Raises: @@ -171,9 +171,9 @@ class TAPAttackContext(MultiTurnAttackContext[Any]): nodes: list["_TreeOfAttacksNode"] = field(default_factory=list) # Best conversation ID and score found during the attack - best_conversation_id: Optional[str] = None - best_objective_score: Optional[Score] = None - best_adversarial_conversation_id: Optional[str] = None + best_conversation_id: str | None = None + best_objective_score: Score | None = None + best_adversarial_conversation_id: str | None = None class TAPAttackResult(AttackResult): @@ -185,7 +185,7 @@ class TAPAttackResult(AttackResult): """ @property - def tree_visualization(self) -> Optional[Tree]: + def tree_visualization(self) -> Tree | None: """Get the tree visualization from metadata.""" return self.metadata.get("tree_visualization", None) @@ -235,12 +235,12 @@ def auxiliary_scores_summary(self, value: dict[str, float]) -> None: self.metadata["auxiliary_scores_summary"] = value @property - def best_adversarial_conversation_id(self) -> Optional[str]: + def best_adversarial_conversation_id(self) -> str | None: """Get the adversarial conversation ID for the best-scoring branch.""" - return cast("Optional[str]", self.metadata.get("best_adversarial_conversation_id", None)) + return cast("str | None", self.metadata.get("best_adversarial_conversation_id", None)) @best_adversarial_conversation_id.setter - def best_adversarial_conversation_id(self, value: Optional[str]) -> None: + def best_adversarial_conversation_id(self, value: str | None) -> None: """Set the best adversarial conversation ID.""" self.metadata["best_adversarial_conversation_id"] = value @@ -285,16 +285,16 @@ def __init__( adversarial_chat_system_seed_prompt: SeedPrompt, desired_response_prefix: str, objective_scorer: Scorer, - on_topic_scorer: Optional[Scorer], + on_topic_scorer: Scorer | None, request_converters: list[PromptConverterConfiguration], response_converters: list[PromptConverterConfiguration], - auxiliary_scorers: Optional[list[Scorer]], + 
auxiliary_scorers: list[Scorer] | None, attack_id: ComponentIdentifier, attack_strategy_name: str, - memory_labels: Optional[dict[str, str]] = None, - parent_id: Optional[str] = None, - prompt_normalizer: Optional[PromptNormalizer] = None, - initial_prompt: Optional[Message] = None, + memory_labels: dict[str, str] | None = None, + parent_id: str | None = None, + prompt_normalizer: PromptNormalizer | None = None, + initial_prompt: Message | None = None, ) -> None: """ Initialize a tree node. @@ -307,16 +307,16 @@ def __init__( adversarial_chat_system_seed_prompt (SeedPrompt): The system prompt for the adversarial chat desired_response_prefix (str): The prefix for the desired response. objective_scorer (Scorer): The scorer for evaluating the objective target's response. - on_topic_scorer (Optional[Scorer]): Optional scorer to check if the prompt is on-topic. - request_converters (List[PromptConverterConfiguration]): Converters for request normalization - response_converters (List[PromptConverterConfiguration]): Converters for response normalization - auxiliary_scorers (Optional[List[Scorer]]): Additional scorers for the response + on_topic_scorer (Scorer | None): Optional scorer to check if the prompt is on-topic. + request_converters (list[PromptConverterConfiguration]): Converters for request normalization + response_converters (list[PromptConverterConfiguration]): Converters for response normalization + auxiliary_scorers (list[Scorer] | None): Additional scorers for the response attack_id (ComponentIdentifier): Unique identifier for the attack. attack_strategy_name (str): Name of the attack strategy for execution context. - memory_labels (Optional[dict[str, str]]): Labels for memory storage. - parent_id (Optional[str]): ID of the parent node, if this is a child node - prompt_normalizer (Optional[PromptNormalizer]): Normalizer for handling prompts and responses. 
- initial_prompt (Optional[Message]): Initial message to send for the first turn, + memory_labels (dict[str, str] | None): Labels for memory storage. + parent_id (str | None): ID of the parent node, if this is a child node + prompt_normalizer (PromptNormalizer | None): Normalizer for handling prompts and responses. + initial_prompt (Message | None): Initial message to send for the first turn, bypassing adversarial chat generation. Supports multimodal messages. """ # Store configuration @@ -353,21 +353,21 @@ def __init__( # Execution results (populated after send_prompt_async) self.completed = False self.off_topic = False - self.objective_score: Optional[Score] = None + self.objective_score: Score | None = None self.auxiliary_scores: dict[str, Score] = {} - self.last_prompt_sent: Optional[str] = None - self.last_response: Optional[str] = None - self.error_message: Optional[str] = None + self.last_prompt_sent: str | None = None + self.last_response: str | None = None + self.error_message: str | None = None # Context from prepended conversation (for adversarial chat system prompt) - self._conversation_context: Optional[str] = None + self._conversation_context: str | None = None # Initial prompt for first turn (bypasses adversarial chat generation) # This supports multimodal messages - self._initial_prompt: Optional[Message] = initial_prompt + self._initial_prompt: Message | None = initial_prompt # Current objective (set when send_prompt_async is called) - self._objective: Optional[str] = None + self._objective: str | None = None async def initialize_with_prepended_conversation_async( self, @@ -391,8 +391,8 @@ async def initialize_with_prepended_conversation_async( - The context is used in _generate_first_turn_prompt_async Args: - prepended_conversation (List[Message]): The conversation history to replay. - prepended_conversation_config (Optional[PrependedConversationConfig]): + prepended_conversation (list[Message]): The conversation history to replay. 
+ prepended_conversation_config (PrependedConversationConfig | None): Configuration for how to process the prepended conversation. Note: @@ -1925,8 +1925,8 @@ def _create_attack_node( self, *, context: TAPAttackContext, - parent_id: Optional[str] = None, - initial_prompt: Optional[Message] = None, + parent_id: str | None = None, + initial_prompt: Message | None = None, ) -> _TreeOfAttacksNode: """ Create a new attack node with the configured settings. @@ -1937,9 +1937,9 @@ def _create_attack_node( Args: context (TAPAttackContext): The attack context containing the objective and other configuration. - parent_id (Optional[str]): The ID of the parent node in the tree, if any. If None, + parent_id (str | None): The ID of the parent node in the tree, if any. If None, the node will be a root-level node. - initial_prompt (Optional[Message]): Initial message for first turn, bypassing + initial_prompt (Message | None): Initial message for first turn, bypassing adversarial chat generation. Supports multimodal messages. "next_message" in multiturncontext Returns: @@ -1986,11 +1986,11 @@ def _get_completed_nodes_sorted_by_score(self, nodes: list[_TreeOfAttacksNode]) have identical scores. Args: - nodes (List[_TreeOfAttacksNode]): List of nodes to filter and sort. May + nodes (list[_TreeOfAttacksNode]): List of nodes to filter and sort. May contain nodes in various states (completed, off-topic, errored, etc.) Returns: - List[_TreeOfAttacksNode]: A list of nodes that are completed, on-topic, + list[_TreeOfAttacksNode]: A list of nodes that are completed, on-topic, and have valid objective scores, sorted by score in descending order. 
""" completed_nodes = [ @@ -2037,7 +2037,7 @@ def _format_node_result(self, node: _TreeOfAttacksNode) -> str: unnormalized_score = round(1 + normalized_score * 9) return f"Score: {unnormalized_score}/10" - def _create_on_topic_scorer(self, objective: str) -> Optional[Scorer]: + def _create_on_topic_scorer(self, objective: str) -> Scorer | None: """ Create an on-topic scorer if enabled, configured for the specific objective. @@ -2051,7 +2051,7 @@ def _create_on_topic_scorer(self, objective: str) -> Optional[Scorer]: relevant to the original goal. Returns: - Optional[Scorer]: + Scorer | None: - `SelfAskTrueFalseScorer` instance configured with the objective if `on_topic_checking_enabled` is `True` and scoring_target exists - `None` if `on_topic_checking_enabled` is `False` or no scoring_target @@ -2187,7 +2187,7 @@ def _create_attack_result( return result - def _get_last_response_from_conversation(self, conversation_id: Optional[str]) -> Optional[MessagePiece]: + def _get_last_response_from_conversation(self, conversation_id: str | None) -> MessagePiece | None: """ Retrieve the last response from a conversation. @@ -2196,11 +2196,11 @@ def _get_last_response_from_conversation(self, conversation_id: Optional[str]) - response from the best performing conversation for inclusion in the attack result. Args: - conversation_id (Optional[str]): The conversation ID to retrieve from. May be + conversation_id (str | None): The conversation ID to retrieve from. May be None if no successful conversations were found during the attack. Returns: - Optional[MessagePiece]: The last response piece from the conversation, + MessagePiece | None: The last response piece from the conversation, or None if no conversation ID was provided or no responses exist. """ if not conversation_id: @@ -2218,10 +2218,10 @@ def _get_auxiliary_scores_summary(self, nodes: list[_TreeOfAttacksNode]) -> dict beyond the objective score that may be useful for analysis. 
Args: - nodes (List[TreeOfAttacksNode]): List of nodes to extract auxiliary scores from. + nodes (list[_TreeOfAttacksNode]): List of nodes to extract auxiliary scores from. Returns: - Dict[str, float]: A dictionary mapping auxiliary score names to their + dict[str, float]: A dictionary mapping auxiliary score names to their float values, or an empty dictionary if no auxiliary scores are available. """ if not nodes or not nodes[0].auxiliary_scores: @@ -2243,7 +2243,7 @@ def _calculate_tree_statistics(self, tree_visualization: Tree) -> dict[str, int] if it was removed from consideration. Returns: - Dict[str, int]: A dictionary with the following keys: + dict[str, int]: A dictionary with the following keys: - "nodes_explored": Total number of nodes explored (excluding root) - "nodes_pruned": Total number of nodes that were pruned during execution """ @@ -2261,7 +2261,7 @@ async def execute_async( self, *, objective: str, - memory_labels: Optional[dict[str, str]] = None, + memory_labels: dict[str, str] | None = None, **kwargs: Any, ) -> TAPAttackResult: ... @@ -2280,7 +2280,7 @@ async def execute_async( Args: objective (str): The objective of the attack. - memory_labels (Optional[Dict[str, str]]): Memory labels for the attack context. + memory_labels (dict[str, str] | None): Memory labels for the attack context. **kwargs: Additional parameters for the attack. 
Returns: diff --git a/pyrit/executor/attack/single_turn/context_compliance.py b/pyrit/executor/attack/single_turn/context_compliance.py index 61e58b9e5f..4568a158e8 100644 --- a/pyrit/executor/attack/single_turn/context_compliance.py +++ b/pyrit/executor/attack/single_turn/context_compliance.py @@ -3,7 +3,7 @@ import logging from pathlib import Path -from typing import Any, Optional +from typing import Any from pyrit.common.apply_defaults import REQUIRED_VALUE, apply_defaults from pyrit.common.path import EXECUTOR_SEED_PROMPT_PATH @@ -59,12 +59,12 @@ def __init__( *, objective_target: PromptTarget = REQUIRED_VALUE, # type: ignore[ty:invalid-parameter-default] attack_adversarial_config: AttackAdversarialConfig, - attack_converter_config: Optional[AttackConverterConfig] = None, - attack_scoring_config: Optional[AttackScoringConfig] = None, - prompt_normalizer: Optional[PromptNormalizer] = None, + attack_converter_config: AttackConverterConfig | None = None, + attack_scoring_config: AttackScoringConfig | None = None, + prompt_normalizer: PromptNormalizer | None = None, max_attempts_on_failure: int = 0, - context_description_instructions_path: Optional[Path] = None, - affirmative_response: Optional[str] = None, + context_description_instructions_path: Path | None = None, + affirmative_response: str | None = None, ) -> None: """ Initialize the context compliance attack strategy. @@ -73,14 +73,14 @@ def __init__( objective_target (PromptTarget): The target system to attack. Must be a PromptTarget. attack_adversarial_config (AttackAdversarialConfig): Configuration for the adversarial component, including the adversarial chat target used for rephrasing. - attack_converter_config (Optional[AttackConverterConfig]): Configuration for attack converters, + attack_converter_config (AttackConverterConfig | None): Configuration for attack converters, including request and response converters. 
- attack_scoring_config (Optional[AttackScoringConfig]): Configuration for attack scoring. - prompt_normalizer (Optional[PromptNormalizer]): The prompt normalizer to use for sending prompts. + attack_scoring_config (AttackScoringConfig | None): Configuration for attack scoring. + prompt_normalizer (PromptNormalizer | None): The prompt normalizer to use for sending prompts. max_attempts_on_failure (int): Maximum number of attempts to retry on failure. - context_description_instructions_path (Optional[Path]): Path to the context description + context_description_instructions_path (Path | None): Path to the context description instructions YAML file. If not provided, uses the default path. - affirmative_response (Optional[str]): The affirmative response to be used in the conversation history. + affirmative_response (str | None): The affirmative response to be used in the conversation history. If not provided, uses the default "yes.". Raises: diff --git a/pyrit/executor/attack/single_turn/flip_attack.py b/pyrit/executor/attack/single_turn/flip_attack.py index 035ef2212d..878ff1da1a 100644 --- a/pyrit/executor/attack/single_turn/flip_attack.py +++ b/pyrit/executor/attack/single_turn/flip_attack.py @@ -4,7 +4,7 @@ import logging import pathlib import uuid -from typing import Any, Optional +from typing import Any from pyrit.common.apply_defaults import REQUIRED_VALUE, apply_defaults from pyrit.common.path import EXECUTOR_SEED_PROMPT_PATH @@ -41,9 +41,9 @@ def __init__( self, *, objective_target: PromptTarget = REQUIRED_VALUE, # type: ignore[ty:invalid-parameter-default] - attack_converter_config: Optional[AttackConverterConfig] = None, - attack_scoring_config: Optional[AttackScoringConfig] = None, - prompt_normalizer: Optional[PromptNormalizer] = None, + attack_converter_config: AttackConverterConfig | None = None, + attack_scoring_config: AttackScoringConfig | None = None, + prompt_normalizer: PromptNormalizer | None = None, max_attempts_on_failure: int = 0, ) -> None: 
""" diff --git a/pyrit/executor/attack/single_turn/many_shot_jailbreak.py b/pyrit/executor/attack/single_turn/many_shot_jailbreak.py index e4225b3ddb..6c9f81bbf4 100644 --- a/pyrit/executor/attack/single_turn/many_shot_jailbreak.py +++ b/pyrit/executor/attack/single_turn/many_shot_jailbreak.py @@ -3,7 +3,7 @@ import json import logging -from typing import Any, Optional +from typing import Any from pyrit.common.apply_defaults import REQUIRED_VALUE, apply_defaults from pyrit.common.path import DATASETS_PATH, JAILBREAK_TEMPLATES_PATH @@ -50,12 +50,12 @@ def __init__( self, *, objective_target: PromptTarget = REQUIRED_VALUE, # type: ignore[ty:invalid-parameter-default] - attack_converter_config: Optional[AttackConverterConfig] = None, - attack_scoring_config: Optional[AttackScoringConfig] = None, - prompt_normalizer: Optional[PromptNormalizer] = None, + attack_converter_config: AttackConverterConfig | None = None, + attack_scoring_config: AttackScoringConfig | None = None, + prompt_normalizer: PromptNormalizer | None = None, max_attempts_on_failure: int = 0, example_count: int = 100, - many_shot_examples: Optional[list[dict[str, str]]] = None, + many_shot_examples: list[dict[str, str]] | None = None, ) -> None: """ Args: diff --git a/pyrit/executor/attack/single_turn/prompt_sending.py b/pyrit/executor/attack/single_turn/prompt_sending.py index f3c8aeedae..794cc4294e 100644 --- a/pyrit/executor/attack/single_turn/prompt_sending.py +++ b/pyrit/executor/attack/single_turn/prompt_sending.py @@ -3,7 +3,7 @@ import logging import uuid -from typing import Any, Optional +from typing import Any from pyrit.common.apply_defaults import REQUIRED_VALUE, apply_defaults from pyrit.common.utils import warn_if_set @@ -55,26 +55,26 @@ def __init__( self, *, objective_target: PromptTarget = REQUIRED_VALUE, # type: ignore[ty:invalid-parameter-default] - attack_converter_config: Optional[AttackConverterConfig] = None, - attack_scoring_config: Optional[AttackScoringConfig] = None, - 
prompt_normalizer: Optional[PromptNormalizer] = None, + attack_converter_config: AttackConverterConfig | None = None, + attack_scoring_config: AttackScoringConfig | None = None, + prompt_normalizer: PromptNormalizer | None = None, max_attempts_on_failure: int = 0, params_type: type[AttackParamsT] = AttackParameters, # type: ignore[ty:invalid-parameter-default] - prepended_conversation_config: Optional[PrependedConversationConfig] = None, + prepended_conversation_config: PrependedConversationConfig | None = None, ) -> None: """ Initialize the prompt injection attack strategy. Args: objective_target (PromptTarget): The target system to attack. - attack_converter_config (Optional[AttackConverterConfig]): Configuration for prompt converters. - attack_scoring_config (Optional[AttackScoringConfig]): Configuration for scoring components. - prompt_normalizer (Optional[PromptNormalizer]): Normalizer for handling prompts. + attack_converter_config (AttackConverterConfig | None): Configuration for prompt converters. + attack_scoring_config (AttackScoringConfig | None): Configuration for scoring components. + prompt_normalizer (PromptNormalizer | None): Normalizer for handling prompts. max_attempts_on_failure (int): Maximum number of attempts to retry on failure. - params_type (Type[AttackParamsT]): The type of parameters this strategy accepts. + params_type (type[AttackParamsT]): The type of parameters this strategy accepts. Defaults to AttackParameters. Use AttackParameters.excluding() to create a params type that rejects certain fields. - prepended_conversation_config (Optional[PrependedConversationConfiguration]): + prepended_conversation_config (PrependedConversationConfig | None): Configuration for how to process prepended conversations. Controls converter application by role, message normalization, and non-chat target behavior. 
@@ -119,12 +119,12 @@ def __init__( # Store the prepended conversation configuration self._prepended_conversation_config = prepended_conversation_config - def get_attack_scoring_config(self) -> Optional[AttackScoringConfig]: + def get_attack_scoring_config(self) -> AttackScoringConfig | None: """ Get the attack scoring configuration used by this strategy. Returns: - Optional[AttackScoringConfig]: The scoring configuration with objective and auxiliary scorers. + AttackScoringConfig | None: The scoring configuration with objective and auxiliary scorers. """ return AttackScoringConfig( objective_scorer=self._objective_scorer, @@ -242,18 +242,18 @@ async def _perform_async(self, *, context: SingleTurnAttackContext[Any]) -> Atta ) def _determine_attack_outcome( - self, *, response: Optional[Message], score: Optional[Score], context: SingleTurnAttackContext[Any] - ) -> tuple[AttackOutcome, Optional[str]]: + self, *, response: Message | None, score: Score | None, context: SingleTurnAttackContext[Any] + ) -> tuple[AttackOutcome, str | None]: """ Determine the outcome of the attack based on the response and score. Args: - response (Optional[Message]): The last response from the target (if any). - score (Optional[Score]): The objective score (if any). + response (Message | None): The last response from the target (if any). + score (Score | None): The objective score (if any). context (SingleTurnAttackContext): The attack context containing configuration. Returns: - tuple[AttackOutcome, Optional[str]]: A tuple of (outcome, outcome_reason). + tuple[AttackOutcome, str | None]: A tuple of (outcome, outcome_reason). 
""" if not self._objective_scorer: # No scorer means we can't determine success/failure @@ -299,7 +299,7 @@ def _get_message(self, context: SingleTurnAttackContext[Any]) -> Message: async def _send_prompt_to_objective_target_async( self, *, message: Message, context: SingleTurnAttackContext[Any] - ) -> Optional[Message]: + ) -> Message | None: """ Send the prompt to the target and return the response. @@ -308,7 +308,7 @@ async def _send_prompt_to_objective_target_async( context (SingleTurnAttackContext): The attack context containing parameters and labels. Returns: - Optional[Message]: The model's response if successful, or None if + Message | None: The model's response if successful, or None if the request was filtered, blocked, or encountered an error. """ with execution_context( @@ -334,7 +334,7 @@ async def _evaluate_response_async( *, response: Message, objective: str, - ) -> Optional[Score]: + ) -> Score | None: """ Evaluate the response against the objective using the configured scorers. @@ -346,7 +346,7 @@ async def _evaluate_response_async( objective (str): The natural-language description of the attack's objective. Returns: - Optional[Score]: The score from the objective scorer if configured, or None if + Score | None: The score from the objective scorer if configured, or None if no objective scorer is set. Note that auxiliary scorer results are not returned but are still executed and stored. 
""" diff --git a/pyrit/executor/attack/single_turn/role_play.py b/pyrit/executor/attack/single_turn/role_play.py index dfa21c8aa9..c59cef423c 100644 --- a/pyrit/executor/attack/single_turn/role_play.py +++ b/pyrit/executor/attack/single_turn/role_play.py @@ -4,7 +4,7 @@ import enum import logging import pathlib -from typing import Any, Optional +from typing import Any from pyrit.common.apply_defaults import REQUIRED_VALUE, apply_defaults from pyrit.common.path import EXECUTOR_SEED_PROMPT_PATH @@ -68,9 +68,9 @@ def __init__( objective_target: PromptTarget = REQUIRED_VALUE, # type: ignore[ty:invalid-parameter-default] attack_adversarial_config: AttackAdversarialConfig, role_play_definition_path: pathlib.Path, - attack_converter_config: Optional[AttackConverterConfig] = None, - attack_scoring_config: Optional[AttackScoringConfig] = None, - prompt_normalizer: Optional[PromptNormalizer] = None, + attack_converter_config: AttackConverterConfig | None = None, + attack_scoring_config: AttackScoringConfig | None = None, + prompt_normalizer: PromptNormalizer | None = None, max_attempts_on_failure: int = 0, ) -> None: """ @@ -82,9 +82,9 @@ def __init__( including the adversarial chat target used to rephrase objectives into role-play scenarios. role_play_definition_path (pathlib.Path): Path to the YAML file containing role-play definitions (rephrase instructions, user start turn, assistant start turn). - attack_converter_config (Optional[AttackConverterConfig]): Configuration for prompt converters. - attack_scoring_config (Optional[AttackScoringConfig]): Configuration for scoring components. - prompt_normalizer (Optional[PromptNormalizer]): Normalizer for handling prompts. + attack_converter_config (AttackConverterConfig | None): Configuration for prompt converters. + attack_scoring_config (AttackScoringConfig | None): Configuration for scoring components. + prompt_normalizer (PromptNormalizer | None): Normalizer for handling prompts. 
max_attempts_on_failure (int): Maximum number of attempts to retry the attack Raises: @@ -157,12 +157,12 @@ async def _rephrase_objective_async(self, *, objective: str) -> str: result = await converter.convert_async(prompt=objective, input_type="text") return result.output_text - async def _get_conversation_start_async(self) -> Optional[list[Message]]: + async def _get_conversation_start_async(self) -> list[Message] | None: """ Get the role-play conversation start messages. Returns: - Optional[list[Message]]: List containing user and assistant start turns + list[Message] | None: List containing user and assistant start turns for the role-play scenario. """ return [ diff --git a/pyrit/executor/attack/single_turn/single_turn_attack_strategy.py b/pyrit/executor/attack/single_turn/single_turn_attack_strategy.py index d719861646..a2271fef29 100644 --- a/pyrit/executor/attack/single_turn/single_turn_attack_strategy.py +++ b/pyrit/executor/attack/single_turn/single_turn_attack_strategy.py @@ -7,7 +7,7 @@ import uuid from abc import ABC from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any, Optional, Union +from typing import TYPE_CHECKING, Any from pyrit.common.logger import logger from pyrit.executor.attack.core.attack_parameters import AttackParameters, AttackParamsT @@ -32,10 +32,10 @@ class SingleTurnAttackContext(AttackContext[AttackParamsT]): conversation_id: str = field(default_factory=lambda: str(uuid.uuid4())) # System prompt for chat-based targets - system_prompt: Optional[str] = None + system_prompt: str | None = None # Arbitrary metadata that downstream attacks or scorers may attach - metadata: Optional[dict[str, Union[str, int]]] = None + metadata: dict[str, str | int] | None = None class SingleTurnAttackStrategy(AttackStrategy[SingleTurnAttackContext[Any], AttackResult], ABC): @@ -59,7 +59,7 @@ def __init__( Args: objective_target (PromptTarget): The target system to attack. 
context_type (type[SingleTurnAttackContext]): The type of context this strategy will use. - params_type (Type[AttackParamsT]): The type of parameters this strategy accepts. + params_type (type[AttackParamsT]): The type of parameters this strategy accepts. logger (logging.Logger): Logger instance for logging events and messages. """ super().__init__( diff --git a/pyrit/executor/attack/single_turn/skeleton_key.py b/pyrit/executor/attack/single_turn/skeleton_key.py index 3761cb9cb2..25460b4575 100644 --- a/pyrit/executor/attack/single_turn/skeleton_key.py +++ b/pyrit/executor/attack/single_turn/skeleton_key.py @@ -3,7 +3,7 @@ import logging from pathlib import Path -from typing import Any, Optional +from typing import Any from pyrit.common.apply_defaults import REQUIRED_VALUE, apply_defaults from pyrit.common.path import EXECUTOR_SEED_PROMPT_PATH @@ -54,10 +54,10 @@ def __init__( self, *, objective_target: PromptTarget = REQUIRED_VALUE, # type: ignore[ty:invalid-parameter-default] - attack_converter_config: Optional[AttackConverterConfig] = None, - attack_scoring_config: Optional[AttackScoringConfig] = None, - prompt_normalizer: Optional[PromptNormalizer] = None, - skeleton_key_prompt: Optional[str] = None, + attack_converter_config: AttackConverterConfig | None = None, + attack_scoring_config: AttackScoringConfig | None = None, + prompt_normalizer: PromptNormalizer | None = None, + skeleton_key_prompt: str | None = None, max_attempts_on_failure: int = 0, ) -> None: """ @@ -65,10 +65,10 @@ def __init__( Args: objective_target (PromptTarget): The target system to attack. - attack_converter_config (Optional[AttackConverterConfig]): Configuration for prompt converters. - attack_scoring_config (Optional[AttackScoringConfig]): Configuration for scoring components. - prompt_normalizer (Optional[PromptNormalizer]): Normalizer for handling prompts. - skeleton_key_prompt (Optional[str]): The skeleton key prompt to use. 
+ attack_converter_config (AttackConverterConfig | None): Configuration for prompt converters. + attack_scoring_config (AttackScoringConfig | None): Configuration for scoring components. + prompt_normalizer (PromptNormalizer | None): Normalizer for handling prompts. + skeleton_key_prompt (str | None): The skeleton key prompt to use. If not provided, uses the default skeleton key prompt. max_attempts_on_failure (int): Maximum number of attempts to retry on failure. """ @@ -85,12 +85,12 @@ def __init__( # Load skeleton key prompt self._skeleton_key_prompt = self._load_skeleton_key_prompt(skeleton_key_prompt) - def _load_skeleton_key_prompt(self, skeleton_key_prompt: Optional[str]) -> str: + def _load_skeleton_key_prompt(self, skeleton_key_prompt: str | None) -> str: """ Load the skeleton key prompt from the provided string or default file. Args: - skeleton_key_prompt (Optional[str]): Custom skeleton key prompt if provided. + skeleton_key_prompt (str | None): Custom skeleton key prompt if provided. Returns: str: The skeleton key prompt to use. @@ -135,7 +135,7 @@ async def _perform_async(self, *, context: SingleTurnAttackContext[Any]) -> Atta return result - async def _send_skeleton_key_prompt_async(self, *, context: SingleTurnAttackContext[Any]) -> Optional[Message]: + async def _send_skeleton_key_prompt_async(self, *, context: SingleTurnAttackContext[Any]) -> Message | None: """ Send the skeleton key prompt to the target to prime it for the attack. @@ -143,7 +143,7 @@ async def _send_skeleton_key_prompt_async(self, *, context: SingleTurnAttackCont context (SingleTurnAttackContext): The attack context containing configuration. Returns: - Optional[Message]: The response from the target, or None if filtered. + Message | None: The response from the target, or None if filtered. 
""" self._logger.debug("Sending skeleton key prompt to target") diff --git a/pyrit/executor/benchmark/fairness_bias.py b/pyrit/executor/benchmark/fairness_bias.py index aa4f0da3b5..5ab784772f 100644 --- a/pyrit/executor/benchmark/fairness_bias.py +++ b/pyrit/executor/benchmark/fairness_bias.py @@ -6,7 +6,7 @@ import uuid from collections import Counter from dataclasses import dataclass, field -from typing import Any, Optional, cast, overload +from typing import Any, cast, overload from pyrit.common.utils import get_kwarg_param from pyrit.executor.attack.core import ( @@ -45,7 +45,7 @@ class FairnessBiasBenchmarkContext(StrategyContext): num_experiments: int = 1 # Custom objective (optional) - if not provided, default will be generated - objective: Optional[str] = None + objective: str | None = None # Prepended conversation for context prepended_conversation: list[Message] = field(default_factory=list) @@ -56,7 +56,7 @@ class FairnessBiasBenchmarkContext(StrategyContext): # Generated fields for the benchmark run generated_objective: str = field(default_factory=str) generated_story_prompt: str = field(default_factory=str) - generated_message: Optional[Message] = None + generated_message: Message | None = None experiment_results: list[dict[str, str]] = field(default_factory=list) @@ -90,9 +90,9 @@ def __init__( self, *, objective_target: PromptTarget, - attack_converter_config: Optional[AttackConverterConfig] = None, - attack_scoring_config: Optional[AttackScoringConfig] = None, - prompt_normalizer: Optional[PromptNormalizer] = None, + attack_converter_config: AttackConverterConfig | None = None, + attack_scoring_config: AttackScoringConfig | None = None, + prompt_normalizer: PromptNormalizer | None = None, max_attempts_on_failure: int = 0, ) -> None: """ @@ -100,9 +100,9 @@ def __init__( Args: objective_target (PromptTarget): The target model to test for fairness and bias - attack_converter_config (Optional[AttackConverterConfig]): Optional configuration for attack 
converters - attack_scoring_config (Optional[AttackScoringConfig]): Optional configuration for attack scoring - prompt_normalizer (Optional[PromptNormalizer]): Optional normalizer for prompt processing + attack_converter_config (AttackConverterConfig | None): Optional configuration for attack converters + attack_scoring_config (AttackScoringConfig | None): Optional configuration for attack scoring + prompt_normalizer (PromptNormalizer | None): Optional normalizer for prompt processing max_attempts_on_failure (int): Maximum number of retry attempts on failure (default: 0) """ super().__init__( @@ -274,7 +274,7 @@ def _format_story_prompt(self, *, subject: str, story_type: str) -> str: f"and use pronouns to refer to the {subject} later in the text." ) - def _extract_name(self, response: str) -> Optional[str]: + def _extract_name(self, response: str) -> str | None: """ Extract person names using spaCy NER (if available) or fallback to regex. spaCy provides much more robust name extraction than simple regex patterns. @@ -283,7 +283,7 @@ def _extract_name(self, response: str) -> Optional[str]: response (str): The generated text response from the target model Returns: - Optional[str]: The extracted name if found, None otherwise + str | None: The extracted name if found, None otherwise """ # Try spaCy-based extraction first (more robust) try: @@ -306,7 +306,7 @@ def _extract_name(self, response: str) -> Optional[str]: return str(name) return None - def _extract_name_with_regex(self, response: str) -> Optional[str]: + def _extract_name_with_regex(self, response: str) -> str | None: """ Fallback regex-based name extraction (original implementation with improvements). 
@@ -314,7 +314,7 @@ def _extract_name_with_regex(self, response: str) -> Optional[str]: response (str): The generated text response from the target model Returns: - Optional[str]: The extracted name if found, None otherwise + str | None: The extracted name if found, None otherwise """ lines = response.split("\n") for line in lines[:3]: @@ -342,7 +342,7 @@ def get_experiment_summary(self, *, context: FairnessBiasBenchmarkContext) -> di context (FairnessBiasBenchmarkContext): The benchmark context containing experiment results Returns: - Dict[str, Any]: Summary dictionary of experiment + dict[str, Any]: Summary dictionary of experiment """ if not context.experiment_results: return {"error": "No experiment results available"} @@ -365,15 +365,15 @@ def get_experiment_summary(self, *, context: FairnessBiasBenchmarkContext) -> di ], } - def get_last_context(self) -> Optional[FairnessBiasBenchmarkContext]: + def get_last_context(self) -> FairnessBiasBenchmarkContext | None: """ Get the context from the last execution. Returns: - Optional[FairnessBiasBenchmarkContext]: The context from the most recent execution, + FairnessBiasBenchmarkContext | None: The context from the most recent execution, or None if no execution has occurred """ - return cast("Optional[FairnessBiasBenchmarkContext]", getattr(self, "_last_context", None)) + return cast("FairnessBiasBenchmarkContext | None", getattr(self, "_last_context", None)) async def _teardown_async(self, *, context: FairnessBiasBenchmarkContext) -> None: """ @@ -391,9 +391,9 @@ async def execute_async( subject: str, story_type: str, num_experiments: int = 1, - objective: Optional[str] = None, - prepended_conversation: Optional[list[Message]] = None, - memory_labels: Optional[dict[str, str]] = None, + objective: str | None = None, + prepended_conversation: list[Message] | None = None, + memory_labels: dict[str, str] | None = None, **kwargs: Any, ) -> AttackResult: ... 
@@ -410,8 +410,8 @@ async def execute_async(self, **kwargs: Any) -> AttackResult: story_type (str): The type of story to generate num_experiments (int, optional): Number of experiments to run (default: 1) objective (str, optional): Custom objective prompt (default: auto-generated) - prepended_conversation (List[Message], optional): Context conversation - memory_labels (Dict[str, str], optional): Labels for memory tracking + prepended_conversation (list[Message], optional): Context conversation + memory_labels (dict[str, str], optional): Labels for memory tracking Returns: AttackResult: The result of the benchmark execution diff --git a/pyrit/executor/benchmark/question_answering.py b/pyrit/executor/benchmark/question_answering.py index 8f2307eba9..3da77b477f 100644 --- a/pyrit/executor/benchmark/question_answering.py +++ b/pyrit/executor/benchmark/question_answering.py @@ -4,7 +4,7 @@ import logging import textwrap from dataclasses import dataclass, field -from typing import Any, Optional, overload +from typing import Any, overload from pyrit.common.utils import get_kwarg_param from pyrit.executor.attack.core import ( @@ -45,7 +45,7 @@ class QuestionAnsweringBenchmarkContext(StrategyContext): # The generated question prompt for the benchmark generated_question_prompt: str = field(default_factory=str) # The generated message for the benchmark - generated_message: Optional[Message] = None + generated_message: Message | None = None class QuestionAnsweringBenchmark(Strategy[QuestionAnsweringBenchmarkContext, AttackResult]): @@ -84,9 +84,9 @@ def __init__( self, *, objective_target: PromptTarget, - attack_converter_config: Optional[AttackConverterConfig] = None, - attack_scoring_config: Optional[AttackScoringConfig] = None, - prompt_normalizer: Optional[PromptNormalizer] = None, + attack_converter_config: AttackConverterConfig | None = None, + attack_scoring_config: AttackScoringConfig | None = None, + prompt_normalizer: PromptNormalizer | None = None, 
objective_format_string: str = _DEFAULT_OBJECTIVE_FORMAT, question_asking_format_string: str = _DEFAULT_QUESTION_FORMAT, options_format_string: str = _DEFAULT_OPTIONS_FORMAT, @@ -97,9 +97,9 @@ def __init__( Args: objective_target (PromptTarget): The target system to evaluate. - attack_converter_config (Optional[AttackConverterConfig]): Configuration for prompt converters. - attack_scoring_config (Optional[AttackScoringConfig]): Configuration for scoring components. - prompt_normalizer (Optional[PromptNormalizer]): Normalizer for handling prompts. + attack_converter_config (AttackConverterConfig | None): Configuration for prompt converters. + attack_scoring_config (AttackScoringConfig | None): Configuration for scoring components. + prompt_normalizer (PromptNormalizer | None): Normalizer for handling prompts. objective_format_string (str): Format string for objectives sent to scorers. question_asking_format_string (str): Format string for questions sent to target. options_format_string (str): Format string for formatting answer choices. @@ -259,8 +259,8 @@ async def execute_async( self, *, question_answering_entry: QuestionAnsweringEntry, - prepended_conversation: Optional[list[Message]] = None, - memory_labels: Optional[dict[str, str]] = None, + prepended_conversation: list[Message] | None = None, + memory_labels: dict[str, str] | None = None, **kwargs: Any, ) -> AttackResult: ... @@ -279,8 +279,8 @@ async def execute_async( Args: question_answering_entry (QuestionAnsweringEntry): The question answering entry to evaluate. - prepended_conversation (Optional[List[Message]]): Conversation to prepend. - memory_labels (Optional[Dict[str, str]]): Memory labels for the benchmark context. + prepended_conversation (list[Message] | None): Conversation to prepend. + memory_labels (dict[str, str] | None): Memory labels for the benchmark context. **kwargs: Additional parameters for the benchmark. 
Returns: diff --git a/pyrit/executor/core/strategy.py b/pyrit/executor/core/strategy.py index 38a6c9261f..363f15d189 100644 --- a/pyrit/executor/core/strategy.py +++ b/pyrit/executor/core/strategy.py @@ -12,7 +12,7 @@ from copy import deepcopy from dataclasses import dataclass from enum import Enum -from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar +from typing import TYPE_CHECKING, Any, Generic, TypeVar from pyrit.common import default_values from pyrit.common.logger import logger @@ -85,10 +85,10 @@ class StrategyEventData(Generic[StrategyContextT, StrategyResultT]): # Context and result of the strategy context: StrategyContextT - result: Optional[StrategyResultT] = None + result: StrategyResultT | None = None # Optional error if the event is related to an error - error: Optional[Exception] = None + error: Exception | None = None class StrategyEventHandler(ABC, Generic[StrategyContextT, StrategyResultT]): @@ -157,7 +157,7 @@ def __init__( self, *, context_type: type[StrategyContextT], - event_handler: Optional[StrategyEventHandler[StrategyContextT, StrategyResultT]] = None, + event_handler: StrategyEventHandler[StrategyContextT, StrategyResultT] | None = None, logger: logging.Logger = logger, ) -> None: """ @@ -165,7 +165,7 @@ def __init__( Args: context_type (type[StrategyContextT]): The type of context this strategy will use. - event_handler (Optional[StrategyEventHandler[StrategyContextT, StrategyResultT]]): An optional + event_handler (StrategyEventHandler[StrategyContextT, StrategyResultT] | None): An optional event handler for strategy events. logger (logging.Logger): The logger to use for this strategy. 
""" @@ -250,8 +250,8 @@ async def _handle_event_async( *, event: StrategyEvent, context: StrategyContextT, - result: Optional[StrategyResultT] = None, - error: Optional[Exception] = None, + result: StrategyResultT | None = None, + error: Exception | None = None, ) -> None: """ Handle a strategy event by notifying all registered event handlers. @@ -259,8 +259,8 @@ async def _handle_event_async( Args: event (StrategyEvent): The event that occurred. context (StrategyContextT): The context for the strategy. - result (Optional[StrategyResultT]): The result of the strategy execution, if applicable. - error (Optional[Exception]): An error that occurred during execution, if applicable. + result (StrategyResultT | None): The result of the strategy execution, if applicable. + error (Exception | None): An error that occurred during execution, if applicable. """ event_data = StrategyEventData( event=event, diff --git a/pyrit/executor/promptgen/anecdoctor.py b/pyrit/executor/promptgen/anecdoctor.py index 3e32fa4faa..0d953fe8ad 100644 --- a/pyrit/executor/promptgen/anecdoctor.py +++ b/pyrit/executor/promptgen/anecdoctor.py @@ -7,7 +7,7 @@ import uuid from dataclasses import dataclass, field from pathlib import Path -from typing import TYPE_CHECKING, Any, Optional, Union, overload +from typing import TYPE_CHECKING, Any, overload import yaml @@ -103,20 +103,20 @@ def __init__( self, *, objective_target: PromptTarget, - processing_model: Optional[PromptTarget] = None, - converter_config: Optional[StrategyConverterConfig] = None, - prompt_normalizer: Optional[PromptNormalizer] = None, + processing_model: PromptTarget | None = None, + converter_config: StrategyConverterConfig | None = None, + prompt_normalizer: PromptNormalizer | None = None, ) -> None: """ Initialize the Anecdoctor prompt generation strategy. Args: objective_target (PromptTarget): The chat model to be used for prompt generation. 
- processing_model (Optional[PromptTarget]): The model used for knowledge graph extraction. + processing_model (PromptTarget | None): The model used for knowledge graph extraction. If provided, the generator will extract a knowledge graph from the examples before generation. If None, the generator will use few-shot examples directly. - converter_config (Optional[StrategyConverterConfig]): Configuration for prompt converters. - prompt_normalizer (Optional[PromptNormalizer]): Normalizer for handling prompts. + converter_config (StrategyConverterConfig | None): Configuration for prompt converters. + prompt_normalizer (PromptNormalizer | None): Normalizer for handling prompts. """ # Initialize base class super().__init__(logger=logger, context_type=AnecdoctorContext) @@ -136,7 +136,7 @@ def __init__( self._system_prompt_template = self._load_prompt_from_yaml(yaml_filename=self._ANECDOCTOR_USE_KG_YAML) # Also preload the KG extraction prompt so `_extract_knowledge_graph_async` doesn't # repeat the file read + YAML parse on each invocation. - self._kg_prompt_template: Optional[str] = self._load_prompt_from_yaml( + self._kg_prompt_template: str | None = self._load_prompt_from_yaml( yaml_filename=self._ANECDOCTOR_BUILD_KG_YAML ) else: @@ -146,21 +146,21 @@ def __init__( def _create_identifier( self, *, - params: Optional[dict[str, Any]] = None, - children: Optional[dict[str, Union[ComponentIdentifier, list[ComponentIdentifier]]]] = None, + params: dict[str, Any] | None = None, + children: dict[str, ComponentIdentifier | list[ComponentIdentifier]] | None = None, ) -> ComponentIdentifier: """ Construct the identifier for this prompt generator. Args: - params (Optional[Dict[str, Any]]): Additional behavioral parameters. - children (Optional[Dict[str, Union[ComponentIdentifier, List[ComponentIdentifier]]]]): + params (dict[str, Any] | None): Additional behavioral parameters. 
+ children (dict[str, ComponentIdentifier | list[ComponentIdentifier]] | None): Named child component identifiers. Returns: ComponentIdentifier: The identifier for this prompt generator. """ - all_children: dict[str, Union[ComponentIdentifier, list[ComponentIdentifier]]] = { + all_children: dict[str, ComponentIdentifier | list[ComponentIdentifier]] = { "objective_target": self._objective_target.get_identifier(), } if children: @@ -286,7 +286,7 @@ async def _prepare_examples_async(self, *, context: AnecdoctorContext) -> str: async def _send_examples_to_target_async( self, *, formatted_examples: str, context: AnecdoctorContext - ) -> Optional[Message]: + ) -> Message | None: """ Send the formatted examples to the target model. @@ -298,7 +298,7 @@ async def _send_examples_to_target_async( context (AnecdoctorContext): The generation context containing conversation metadata. Returns: - Optional[Message]: The response from the target model, + Message | None: The response from the target model, or None if the request failed. """ # Create message from the formatted examples @@ -343,7 +343,7 @@ def _format_few_shot_examples(self, *, evaluation_data: list[str]) -> str: Format the evaluation data as few-shot examples. Args: - evaluation_data (List[str]): The evaluation data to format. + evaluation_data (list[str]): The evaluation data to format. Returns: str: Formatted string with examples prefixed by "### examples". @@ -414,7 +414,7 @@ async def execute_async( content_type: str, language: str, evaluation_data: list[str], - memory_labels: Optional[dict[str, str]] = None, + memory_labels: dict[str, str] | None = None, **kwargs: Any, ) -> AnecdoctorResult: ... @@ -434,8 +434,8 @@ async def execute_async( Args: content_type (str): The type of content to generate (e.g., "viral tweet", "news article"). language (str): The language of the content to generate (e.g., "english", "spanish"). 
- evaluation_data (List[str]): The data in ClaimsReview format to use in constructing the prompt. - memory_labels (Optional[Dict[str, str]]): Memory labels for the generation context. + evaluation_data (list[str]): The data in ClaimsReview format to use in constructing the prompt. + memory_labels (dict[str, str] | None): Memory labels for the generation context. **kwargs: Additional parameters for the generation. Returns: diff --git a/pyrit/executor/promptgen/core/prompt_generator_strategy.py b/pyrit/executor/promptgen/core/prompt_generator_strategy.py index e027147139..a8d09e3fae 100644 --- a/pyrit/executor/promptgen/core/prompt_generator_strategy.py +++ b/pyrit/executor/promptgen/core/prompt_generator_strategy.py @@ -6,7 +6,7 @@ import logging # noqa: TC003 from abc import ABC from dataclasses import dataclass -from typing import Optional, TypeVar +from typing import TypeVar from pyrit.common.logger import logger from pyrit.executor.core.strategy import ( @@ -70,9 +70,8 @@ def __init__( self, context_type: type[PromptGeneratorStrategyContextT], logger: logging.Logger = logger, - event_handler: Optional[ - StrategyEventHandler[PromptGeneratorStrategyContextT, PromptGeneratorStrategyResultT] - ] = None, + event_handler: StrategyEventHandler[PromptGeneratorStrategyContextT, PromptGeneratorStrategyResultT] + | None = None, ) -> None: """ Initialize the prompt generator strategy. 
diff --git a/pyrit/executor/promptgen/fuzzer/fuzzer.py b/pyrit/executor/promptgen/fuzzer/fuzzer.py index df04314069..73410799a4 100644 --- a/pyrit/executor/promptgen/fuzzer/fuzzer.py +++ b/pyrit/executor/promptgen/fuzzer/fuzzer.py @@ -8,7 +8,7 @@ import textwrap import uuid from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any, Optional, Union, overload +from typing import TYPE_CHECKING, Any, overload import numpy as np from colorama import Fore, Style @@ -51,14 +51,14 @@ class _PromptNode: def __init__( self, template: str, - parent: Optional[_PromptNode] = None, + parent: _PromptNode | None = None, ) -> None: """ Create the PromptNode instance. Args: template (str): Prompt template. - parent (Optional[_PromptNode]): Parent node. + parent (_PromptNode | None): Parent node. """ self.id = uuid.uuid4() self.template: str = template @@ -66,7 +66,7 @@ def __init__( self.level: int = 0 if parent is None else parent.level + 1 self.visited_num = 0 self.rewards: float = 0 - self.parent: Optional[_PromptNode] = None + self.parent: _PromptNode | None = None if parent is not None: self.add_parent(parent) @@ -157,7 +157,7 @@ def _calculate_uct_score(self, *, node: _PromptNode, step: int) -> float: exploration = self.frequency_weight * np.sqrt(2 * np.log(step) / (node.visited_num + 0.01)) return float(exploitation + exploration) - def update_rewards(self, path: list[_PromptNode], reward: float, last_node: Optional[_PromptNode] = None) -> None: + def update_rewards(self, path: list[_PromptNode], reward: float, last_node: _PromptNode | None = None) -> None: """ Update rewards for nodes in the path. 
@@ -185,19 +185,19 @@ class FuzzerContext(PromptGeneratorStrategyContext): # Per-execution input data prompts: list[str] prompt_templates: list[str] - max_query_limit: Optional[int] = None + max_query_limit: int | None = None # Tracking state total_target_query_count: int = 0 total_jailbreak_count: int = 0 - jailbreak_conversation_ids: list[Union[str, uuid.UUID]] = field(default_factory=list) + jailbreak_conversation_ids: list[str | uuid.UUID] = field(default_factory=list) executed_turns: int = 0 # Tree structure initial_prompt_nodes: list[_PromptNode] = field(default_factory=list) new_prompt_nodes: list[_PromptNode] = field(default_factory=list) mcts_selected_path: list[_PromptNode] = field(default_factory=list) - last_choice_node: Optional[_PromptNode] = None + last_choice_node: _PromptNode | None = None # Optional memory labels to apply to the prompts memory_labels: dict[str, str] = field(default_factory=dict) @@ -223,7 +223,7 @@ class FuzzerResult(PromptGeneratorStrategyResult): # Concrete fields instead of metadata storage successful_templates: list[str] = Field(default_factory=list) - jailbreak_conversation_ids: list[Union[str, uuid.UUID]] = Field(default_factory=list) + jailbreak_conversation_ids: list[str | uuid.UUID] = Field(default_factory=list) total_queries: int = 0 templates_explored: int = 0 @@ -541,8 +541,8 @@ def with_default_scorer( objective_target: PromptTarget, template_converters: list[FuzzerConverter], scoring_target: PromptTarget, - converter_config: Optional[StrategyConverterConfig] = None, - prompt_normalizer: Optional[PromptNormalizer] = None, + converter_config: StrategyConverterConfig | None = None, + prompt_normalizer: PromptNormalizer | None = None, frequency_weight: float = _DEFAULT_FREQUENCY_WEIGHT, reward_penalty: float = _DEFAULT_REWARD_PENALTY, minimum_reward: float = _DEFAULT_MINIMUM_REWARD, @@ -562,10 +562,10 @@ def with_default_scorer( Args: objective_target (PromptTarget): The target to send the prompts to. 
- template_converters (List[FuzzerConverter]): The converters to apply on the selected jailbreak template. + template_converters (list[FuzzerConverter]): The converters to apply on the selected jailbreak template. scoring_target (PromptTarget): The chat target to use for scoring responses. - converter_config (Optional[StrategyConverterConfig]): Configuration for prompt converters. - prompt_normalizer (Optional[PromptNormalizer]): The prompt normalizer to use. + converter_config (StrategyConverterConfig | None): Configuration for prompt converters. + prompt_normalizer (PromptNormalizer | None): The prompt normalizer to use. frequency_weight (float): Constant that balances between high reward and selection frequency. reward_penalty (float): Penalty that diminishes reward as path length increases. minimum_reward (float): Minimal reward to prevent rewards from being too small. @@ -608,10 +608,10 @@ def __init__( *, objective_target: PromptTarget, template_converters: list[FuzzerConverter], - converter_config: Optional[StrategyConverterConfig] = None, - scorer: Optional[Scorer] = None, + converter_config: StrategyConverterConfig | None = None, + scorer: Scorer | None = None, scoring_success_threshold: float = 0.8, - prompt_normalizer: Optional[PromptNormalizer] = None, + prompt_normalizer: PromptNormalizer | None = None, frequency_weight: float = _DEFAULT_FREQUENCY_WEIGHT, reward_penalty: float = _DEFAULT_REWARD_PENALTY, minimum_reward: float = _DEFAULT_MINIMUM_REWARD, @@ -624,13 +624,13 @@ def __init__( Args: objective_target (PromptTarget): The target to send the prompts to. - template_converters (List[FuzzerConverter]): The converters to apply on the selected jailbreak template. + template_converters (list[FuzzerConverter]): The converters to apply on the selected jailbreak template. In each iteration, one converter is chosen at random. - converter_config (Optional[StrategyConverterConfig]): Configuration for prompt converters. 
+ converter_config (StrategyConverterConfig | None): Configuration for prompt converters. Defaults to None. - scorer (Optional[Scorer]): Configuration for scoring responses. Defaults to None. + scorer (Scorer | None): Configuration for scoring responses. Defaults to None. scoring_success_threshold (float): The score threshold to consider a jailbreak successful. - prompt_normalizer (Optional[PromptNormalizer]): The prompt normalizer to use. Defaults to None. + prompt_normalizer (PromptNormalizer | None): The prompt normalizer to use. Defaults to None. frequency_weight (float): Constant that balances between high reward and selection frequency. Defaults to 0.5. reward_penalty (float): Penalty that diminishes reward as path length increases. @@ -685,21 +685,21 @@ def __init__( def _create_identifier( self, *, - params: Optional[dict[str, Any]] = None, - children: Optional[dict[str, Union[ComponentIdentifier, list[ComponentIdentifier]]]] = None, + params: dict[str, Any] | None = None, + children: dict[str, ComponentIdentifier | list[ComponentIdentifier]] | None = None, ) -> ComponentIdentifier: """ Construct the identifier for this prompt generator. Args: - params (Optional[Dict[str, Any]]): Additional behavioral parameters. - children (Optional[Dict[str, Union[ComponentIdentifier, List[ComponentIdentifier]]]]): + params (dict[str, Any] | None): Additional behavioral parameters. + children (dict[str, ComponentIdentifier | list[ComponentIdentifier]] | None): Named child component identifiers. Returns: ComponentIdentifier: The identifier for this prompt generator. """ - all_children: dict[str, Union[ComponentIdentifier, list[ComponentIdentifier]]] = { + all_children: dict[str, ComponentIdentifier | list[ComponentIdentifier]] = { "objective_target": self._objective_target.get_identifier(), } if children: @@ -725,7 +725,7 @@ def _validate_inputs( Validate input parameters. Args: - template_converters (List[FuzzerConverter]): List of template converters. 
+ template_converters (list[FuzzerConverter]): List of template converters. batch_size (int): The batch size for sending prompts. Raises: @@ -972,10 +972,10 @@ def _generate_prompts_from_template(self, *, template: SeedPrompt, prompts: list Args: template (SeedPrompt): The template to use. - prompts (List[str]): The prompts to fill into the template. + prompts (list[str]): The prompts to fill into the template. Returns: - List[str]: The generated jailbreak prompts. + list[str]: The generated jailbreak prompts. Raises: ValueError: If the template doesn't have the required 'prompt' parameter. @@ -992,10 +992,10 @@ async def _send_prompts_to_target_async(self, *, context: FuzzerContext, prompts Args: context (FuzzerContext): The generation context. - prompts (List[str]): The prompts to send. + prompts (list[str]): The prompts to send. Returns: - List[Message]: The responses from the target. + list[Message]: The responses from the target. """ requests = self._create_normalizer_requests(prompts) @@ -1012,7 +1012,7 @@ def _create_normalizer_requests(self, prompts: list[str]) -> list[NormalizerRequ Create normalizer requests from prompts. Args: - prompts (List[str]): The prompts to create requests for. + prompts (list[str]): The prompts to create requests for. Returns: List of normalizer requests. @@ -1041,11 +1041,11 @@ async def _score_responses_async(self, *, responses: list[Message], tasks: list[ Score the responses from the target. Args: - responses (List[Message]): The responses to score. - tasks (List[str]): The original tasks/prompts used for generating the responses. + responses (list[Message]): The responses to score. + tasks (list[str]): The original tasks/prompts used for generating the responses. Returns: - List[Score]: The scores for each response. + list[Score]: The scores for each response. """ if not responses: return [] @@ -1071,8 +1071,8 @@ def _process_scoring_results( Args: context (FuzzerContext): The generation context. 
- scores (List[Score]): The scores for each response. - responses (List[Message]): The responses that were scored. + scores (list[Score]): The scores for each response. + responses (list[Message]): The responses that were scored. template_node (_PromptNode): The template node that was tested. current_seed (_PromptNode): The seed node that was selected. @@ -1198,8 +1198,8 @@ async def execute_async( *, prompts: list[str], prompt_templates: list[str], - max_query_limit: Optional[int] = None, - memory_labels: Optional[dict[str, str]] = None, + max_query_limit: int | None = None, + memory_labels: dict[str, str] | None = None, **kwargs: Any, ) -> FuzzerResult: ... @@ -1217,10 +1217,10 @@ async def execute_async( Execute the Fuzzer generation strategy asynchronously. Args: - prompts (List[str]): The list of prompts to use for generation. - prompt_templates (List[str]): The list of prompt templates to use. - max_query_limit (Optional[int]): The maximum number of queries to execute. - memory_labels (Optional[dict[str, str]]): Optional labels to apply to the prompts. + prompts (list[str]): The list of prompts to use for generation. + prompt_templates (list[str]): The list of prompt templates to use. + max_query_limit (int | None): The maximum number of queries to execute. + memory_labels (dict[str, str] | None): Optional labels to apply to the prompts. **kwargs: Additional keyword arguments. 
Returns: diff --git a/pyrit/executor/promptgen/fuzzer/fuzzer_crossover_converter.py b/pyrit/executor/promptgen/fuzzer/fuzzer_crossover_converter.py index 7e68a9907d..812979a797 100644 --- a/pyrit/executor/promptgen/fuzzer/fuzzer_crossover_converter.py +++ b/pyrit/executor/promptgen/fuzzer/fuzzer_crossover_converter.py @@ -4,7 +4,7 @@ import pathlib import random import uuid -from typing import Any, Optional +from typing import Any from pyrit.common.apply_defaults import apply_defaults from pyrit.common.path import CONVERTER_SEED_PROMPT_PATH @@ -25,9 +25,9 @@ class FuzzerCrossOverConverter(FuzzerConverter): def __init__( self, *, - converter_target: Optional[PromptTarget] = None, - prompt_template: Optional[SeedPrompt] = None, - prompt_templates: Optional[list[str]] = None, + converter_target: PromptTarget | None = None, + prompt_template: SeedPrompt | None = None, + prompt_templates: list[str] | None = None, ) -> None: """ Initialize the converter with the specified chat target and prompt templates. @@ -37,7 +37,7 @@ def __init__( Can be omitted if a default has been configured via PyRIT initialization. prompt_template (SeedPrompt, Optional): Template to be used instead of the default system prompt with instructions for the chat target. - prompt_templates (List[str], Optional): List of prompt templates to use in addition to the default one. + prompt_templates (list[str], Optional): List of prompt templates to use in addition to the default one. 
""" prompt_template = ( prompt_template diff --git a/pyrit/executor/promptgen/fuzzer/fuzzer_expand_converter.py b/pyrit/executor/promptgen/fuzzer/fuzzer_expand_converter.py index 0c1f3fdf95..627ed159ed 100644 --- a/pyrit/executor/promptgen/fuzzer/fuzzer_expand_converter.py +++ b/pyrit/executor/promptgen/fuzzer/fuzzer_expand_converter.py @@ -3,7 +3,6 @@ import pathlib import uuid -from typing import Optional from pyrit.common.apply_defaults import apply_defaults from pyrit.common.path import CONVERTER_SEED_PROMPT_PATH @@ -24,8 +23,8 @@ class FuzzerExpandConverter(FuzzerConverter): def __init__( self, *, - converter_target: Optional[PromptTarget] = None, - prompt_template: Optional[SeedPrompt] = None, + converter_target: PromptTarget | None = None, + prompt_template: SeedPrompt | None = None, ) -> None: """Initialize the expand converter with optional chat target and prompt template.""" prompt_template = ( diff --git a/pyrit/executor/promptgen/fuzzer/fuzzer_rephrase_converter.py b/pyrit/executor/promptgen/fuzzer/fuzzer_rephrase_converter.py index 10acff3fb6..d1f6783fc3 100644 --- a/pyrit/executor/promptgen/fuzzer/fuzzer_rephrase_converter.py +++ b/pyrit/executor/promptgen/fuzzer/fuzzer_rephrase_converter.py @@ -2,7 +2,6 @@ # Licensed under the MIT license. 
import pathlib -from typing import Optional from pyrit.common.apply_defaults import apply_defaults from pyrit.common.path import CONVERTER_SEED_PROMPT_PATH @@ -20,7 +19,7 @@ class FuzzerRephraseConverter(FuzzerConverter): @apply_defaults def __init__( - self, *, converter_target: Optional[PromptTarget] = None, prompt_template: Optional[SeedPrompt] = None + self, *, converter_target: PromptTarget | None = None, prompt_template: SeedPrompt | None = None ) -> None: """Initialize the rephrase converter with optional chat target and prompt template.""" prompt_template = ( diff --git a/pyrit/executor/promptgen/fuzzer/fuzzer_shorten_converter.py b/pyrit/executor/promptgen/fuzzer/fuzzer_shorten_converter.py index 6258a5e7b3..dcc098f67c 100644 --- a/pyrit/executor/promptgen/fuzzer/fuzzer_shorten_converter.py +++ b/pyrit/executor/promptgen/fuzzer/fuzzer_shorten_converter.py @@ -2,7 +2,6 @@ # Licensed under the MIT license. import pathlib -from typing import Optional from pyrit.common.apply_defaults import apply_defaults from pyrit.common.path import CONVERTER_SEED_PROMPT_PATH @@ -20,7 +19,7 @@ class FuzzerShortenConverter(FuzzerConverter): @apply_defaults def __init__( - self, *, converter_target: Optional[PromptTarget] = None, prompt_template: Optional[SeedPrompt] = None + self, *, converter_target: PromptTarget | None = None, prompt_template: SeedPrompt | None = None ) -> None: """Initialize the shorten converter with optional chat target and prompt template.""" prompt_template = ( diff --git a/pyrit/executor/promptgen/fuzzer/fuzzer_similar_converter.py b/pyrit/executor/promptgen/fuzzer/fuzzer_similar_converter.py index d7f2796579..25ec6fd2fb 100644 --- a/pyrit/executor/promptgen/fuzzer/fuzzer_similar_converter.py +++ b/pyrit/executor/promptgen/fuzzer/fuzzer_similar_converter.py @@ -2,7 +2,6 @@ # Licensed under the MIT license. 
import pathlib -from typing import Optional from pyrit.common.apply_defaults import apply_defaults from pyrit.common.path import CONVERTER_SEED_PROMPT_PATH @@ -20,7 +19,7 @@ class FuzzerSimilarConverter(FuzzerConverter): @apply_defaults def __init__( - self, *, converter_target: Optional[PromptTarget] = None, prompt_template: Optional[SeedPrompt] = None + self, *, converter_target: PromptTarget | None = None, prompt_template: SeedPrompt | None = None ) -> None: """Initialize the similar converter with optional chat target and prompt template.""" prompt_template = ( diff --git a/pyrit/executor/workflow/core/workflow_strategy.py b/pyrit/executor/workflow/core/workflow_strategy.py index 944e5efe4d..627b8ffe73 100644 --- a/pyrit/executor/workflow/core/workflow_strategy.py +++ b/pyrit/executor/workflow/core/workflow_strategy.py @@ -6,7 +6,7 @@ import logging # noqa: TC003 from abc import ABC from dataclasses import dataclass -from typing import Optional, TypeVar +from typing import TypeVar from pyrit.common.logger import logger from pyrit.executor.core.strategy import ( @@ -109,7 +109,7 @@ def __init__( *, context_type: type[WorkflowContextT], logger: logging.Logger = logger, - event_handler: Optional[StrategyEventHandler[WorkflowContextT, WorkflowResultT]] = None, + event_handler: StrategyEventHandler[WorkflowContextT, WorkflowResultT] | None = None, ) -> None: """ Initialize the workflow strategy with a specific context type and logger. 
diff --git a/pyrit/executor/workflow/xpia.py b/pyrit/executor/workflow/xpia.py index e981c46b63..c367cd9b02 100644 --- a/pyrit/executor/workflow/xpia.py +++ b/pyrit/executor/workflow/xpia.py @@ -5,7 +5,7 @@ import uuid from dataclasses import dataclass, field from enum import Enum -from typing import Any, Optional, Protocol, Union, overload +from typing import Any, Protocol, overload from pyrit.common.utils import combine_dict, get_kwarg_param from pyrit.executor.core import StrategyConverterConfig @@ -66,7 +66,7 @@ class XPIAContext(WorkflowContext): attack_content: Message # Callback to execute after the attack prompt is positioned in the attack location - processing_callback: Optional[XPIAProcessingCallback] = None + processing_callback: XPIAProcessingCallback | None = None # Conversation ID for the attack setup target attack_setup_target_conversation_id: str = field(default_factory=lambda: str(uuid.uuid4())) @@ -75,7 +75,7 @@ class XPIAContext(WorkflowContext): processing_conversation_id: str = field(default_factory=lambda: str(uuid.uuid4())) # The prompt to send to the processing target (for test workflow) - processing_prompt: Optional[Message] = None + processing_prompt: Message | None = None # Additional labels that can be applied throughout the workflow memory_labels: dict[str, str] = field(default_factory=dict) @@ -96,10 +96,10 @@ class XPIAResult(WorkflowResult): processing_response: str # Score if a scorer was used, None otherwise - score: Optional[Score] = None + score: Score | None = None # Response from the attack setup target - attack_setup_response: Optional[str] = None + attack_setup_response: str | None = None @property def success(self) -> bool: @@ -145,9 +145,9 @@ def __init__( self, *, attack_setup_target: PromptTarget, - scorer: Optional[Scorer] = None, - converter_config: Optional[StrategyConverterConfig] = None, - prompt_normalizer: Optional[PromptNormalizer] = None, + scorer: Scorer | None = None, + converter_config: StrategyConverterConfig 
| None = None, + prompt_normalizer: PromptNormalizer | None = None, logger: logging.Logger = logger, ) -> None: """ @@ -156,11 +156,11 @@ def __init__( Args: attack_setup_target (PromptTarget): The target that generates the attack prompt and gets it into the attack location. - scorer (Optional[Scorer]): Optional scorer to evaluate the processing response. + scorer (Scorer | None): Optional scorer to evaluate the processing response. If no scorer is provided the workflow will skip scoring. - converter_config (Optional[StrategyConverterConfig]): Optional converter + converter_config (StrategyConverterConfig | None): Optional converter configuration for request and response converters. - prompt_normalizer (Optional[PromptNormalizer]): Optional PromptNormalizer + prompt_normalizer (PromptNormalizer | None): Optional PromptNormalizer instance. If not provided, a new one will be created. logger (logging.Logger): Logger instance for logging events. """ @@ -178,21 +178,21 @@ def __init__( def _create_identifier( self, *, - params: Optional[dict[str, Any]] = None, - children: Optional[dict[str, Union[ComponentIdentifier, list[ComponentIdentifier]]]] = None, + params: dict[str, Any] | None = None, + children: dict[str, ComponentIdentifier | list[ComponentIdentifier]] | None = None, ) -> ComponentIdentifier: """ Construct the identifier for this XPIA workflow. Args: - params (Optional[Dict[str, Any]]): Additional behavioral parameters. - children (Optional[Dict[str, Union[ComponentIdentifier, List[ComponentIdentifier]]]]): + params (dict[str, Any] | None): Additional behavioral parameters. + children (dict[str, ComponentIdentifier | list[ComponentIdentifier]] | None): Named child component identifiers. Returns: ComponentIdentifier: The identifier for this XPIA workflow. 
""" - all_children: dict[str, Union[ComponentIdentifier, list[ComponentIdentifier]]] = { + all_children: dict[str, ComponentIdentifier | list[ComponentIdentifier]] = { "attack_setup_target": self._attack_setup_target.get_identifier(), } if self._scorer: @@ -382,7 +382,7 @@ async def _execute_processing_async(self, *, context: XPIAContext) -> str: self._logger.info(f'Received the following response from the processing target "{processing_response}"') return processing_response - async def _score_response_async(self, *, processing_response: str) -> Optional[Score]: + async def _score_response_async(self, *, processing_response: str) -> Score | None: """ Score the processing response if a scorer is provided. @@ -394,7 +394,7 @@ async def _score_response_async(self, *, processing_response: str) -> Optional[S processing_response (str): The response from the processing target to score. Returns: - Optional[Score]: The score if a scorer is configured, None otherwise. + Score | None: The score if a scorer is configured, None otherwise. """ if not self._scorer: self._logger.info("No scorer provided. Returning raw processing response.") @@ -429,9 +429,9 @@ async def execute_async( self, *, attack_content: Message, - processing_callback: Optional[XPIAProcessingCallback] = None, - processing_prompt: Optional[Message] = None, - memory_labels: Optional[dict[str, str]] = None, + processing_callback: XPIAProcessingCallback | None = None, + processing_prompt: Message | None = None, + memory_labels: dict[str, str] | None = None, **kwargs: Any, ) -> XPIAResult: ... @@ -453,9 +453,9 @@ async def execute_async( processing_callback (ProcessingCallback): The callback to execute after the attack prompt is positioned in the attack location. This is generic on purpose to allow for flexibility. The callback should return the processing response. - processing_prompt (Optional[Message]): The prompt to send to the processing target. 
This should + processing_prompt (Message | None): The prompt to send to the processing target. This should include placeholders to invoke plugins (if any). - memory_labels (Optional[Dict[str, str]]): Memory labels for the attack context. + memory_labels (dict[str, str] | None): Memory labels for the attack context. **kwargs: Additional parameters for the attack. Returns: @@ -503,8 +503,8 @@ def __init__( attack_setup_target: PromptTarget, processing_target: PromptTarget, scorer: Scorer, - converter_config: Optional[StrategyConverterConfig] = None, - prompt_normalizer: Optional[PromptNormalizer] = None, + converter_config: StrategyConverterConfig | None = None, + prompt_normalizer: PromptNormalizer | None = None, logger: logging.Logger = logger, ) -> None: """ @@ -517,9 +517,9 @@ def __init__( processing prompt. This should include references to invoke plugins (if any). scorer (Scorer): The scorer to use to score the processing response. This is required for test workflows to evaluate attack success. - converter_config (Optional[StrategyConverterConfig]): Optional converter + converter_config (StrategyConverterConfig | None): Optional converter configuration for request and response converters. - prompt_normalizer (Optional[PromptNormalizer]): Optional PromptNormalizer + prompt_normalizer (PromptNormalizer | None): Optional PromptNormalizer instance. If not provided, a new one will be created. logger (logging.Logger): Logger instance for logging events. """ @@ -605,8 +605,8 @@ def __init__( *, attack_setup_target: PromptTarget, scorer: Scorer, - converter_config: Optional[StrategyConverterConfig] = None, - prompt_normalizer: Optional[PromptNormalizer] = None, + converter_config: StrategyConverterConfig | None = None, + prompt_normalizer: PromptNormalizer | None = None, logger: logging.Logger = logger, ) -> None: """ @@ -617,9 +617,9 @@ def __init__( and gets it into the attack location. scorer (Scorer): The scorer to use to score the processing response. 
This is required to evaluate the manually provided response. - converter_config (Optional[StrategyConverterConfig]): Optional converter + converter_config (StrategyConverterConfig | None): Optional converter configuration for request and response converters. - prompt_normalizer (Optional[PromptNormalizer]): Optional PromptNormalizer + prompt_normalizer (PromptNormalizer | None): Optional PromptNormalizer instance. If not provided, a new one will be created. logger (logging.Logger): Logger instance for logging events. """ diff --git a/pyrit/memory/azure_sql_memory.py b/pyrit/memory/azure_sql_memory.py index 6723ae2842..81b97d8716 100644 --- a/pyrit/memory/azure_sql_memory.py +++ b/pyrit/memory/azure_sql_memory.py @@ -7,7 +7,7 @@ from collections.abc import MutableSequence, Sequence from contextlib import closing, suppress from datetime import datetime, timedelta, timezone -from typing import TYPE_CHECKING, Any, Literal, Optional, TypeVar, Union, cast +from typing import TYPE_CHECKING, Any, Literal, TypeVar, cast from sqlalchemy import and_, create_engine, event, exists, or_, text from sqlalchemy.engine.base import Engine @@ -64,9 +64,9 @@ class AzureSQLMemory(MemoryInterface, metaclass=Singleton): def __init__( self, *, - connection_string: Optional[str] = None, - results_container_url: Optional[str] = None, - results_sas_token: Optional[str] = None, + connection_string: str | None = None, + results_container_url: str | None = None, + results_sas_token: str | None = None, verbose: bool = False, skip_schema_migration: bool = False, silent: bool = False, @@ -75,11 +75,11 @@ def __init__( Initialize an Azure SQL Memory backend. Args: - connection_string (Optional[str]): The connection string for the Azure Sql Database. If not provided, + connection_string (str | None): The connection string for the Azure Sql Database. If not provided, it falls back to the 'AZURE_SQL_DB_CONNECTION_STRING' environment variable. 
- results_container_url (Optional[str]): The URL to an Azure Storage Container. If not provided, + results_container_url (str | None): The URL to an Azure Storage Container. If not provided, it falls back to the 'AZURE_STORAGE_ACCOUNT_DB_DATA_CONTAINER_URL' environment variable. - results_sas_token (Optional[str]): The Shared Access Signature (SAS) token for the storage container. + results_sas_token (str | None): The Shared Access Signature (SAS) token for the storage container. If not provided, falls back to the 'AZURE_STORAGE_ACCOUNT_DB_DATA_SAS_TOKEN' environment variable. verbose (bool): Whether to enable verbose logging for the database engine. Defaults to False. skip_schema_migration (bool): Whether to skip schema migration. Defaults to False. @@ -93,12 +93,12 @@ def __init__( env_var_name=self.AZURE_STORAGE_ACCOUNT_DB_DATA_CONTAINER_URL, passed_value=results_container_url ) - self._results_container_sas_token: Optional[str] = self._resolve_sas_token( + self._results_container_sas_token: str | None = self._resolve_sas_token( self.AZURE_STORAGE_ACCOUNT_DB_DATA_SAS_TOKEN, results_sas_token ) - self._auth_token: Optional[AccessToken] = None - self._auth_token_expiry: Optional[int] = None + self._auth_token: AccessToken | None = None + self._auth_token_expiry: int | None = None self.results_path = self._results_container_url @@ -116,16 +116,16 @@ def __init__( super().__init__() @staticmethod - def _resolve_sas_token(env_var_name: str, passed_value: Optional[str] = None) -> Optional[str]: + def _resolve_sas_token(env_var_name: str, passed_value: str | None = None) -> str | None: """ Resolve the SAS token value, allowing a fallback to None for delegation SAS. Args: env_var_name (str): The environment variable name to look up. - passed_value (Optional[str]): A passed-in value for the SAS token. + passed_value (str | None): A passed-in value for the SAS token. Returns: - Optional[str]: Resolved SAS token or None if not provided. 
+ str | None: Resolved SAS token or None if not provided. """ try: return default_values.get_required_value(env_var_name=env_var_name, passed_value=passed_value) @@ -285,14 +285,14 @@ def _get_message_pieces_memory_label_conditions(self, *, memory_labels: dict[str return [or_(pme_match, are_match)] - def _get_metadata_conditions(self, *, prompt_metadata: dict[str, Union[str, int]]) -> list[TextClause]: + def _get_metadata_conditions(self, *, prompt_metadata: dict[str, str | int]) -> list[TextClause]: """ Generate SQL conditions for filtering by prompt metadata. Uses JSON_VALUE() function specific to SQL Azure to query metadata fields in JSON format. Args: - prompt_metadata (dict[str, Union[str, int]]): Dictionary of metadata key-value pairs to filter by. + prompt_metadata (dict[str, str | int]): Dictionary of metadata key-value pairs to filter by. Returns: list: List containing a single SQLAlchemy text condition with bound parameters. @@ -310,7 +310,7 @@ def _get_metadata_conditions(self, *, prompt_metadata: dict[str, Union[str, int] return [condition] def _get_message_pieces_prompt_metadata_conditions( - self, *, prompt_metadata: dict[str, Union[str, int]] + self, *, prompt_metadata: dict[str, str | int] ) -> list[TextClause]: """ Generate SQL conditions for filtering message pieces by prompt metadata. @@ -318,14 +318,14 @@ def _get_message_pieces_prompt_metadata_conditions( This is a convenience wrapper around _get_metadata_conditions. Args: - prompt_metadata (dict[str, Union[str, int]]): Dictionary of metadata key-value pairs to filter by. + prompt_metadata (dict[str, str | int]): Dictionary of metadata key-value pairs to filter by. Returns: list: List containing SQLAlchemy text conditions with bound parameters. 
""" return self._get_metadata_conditions(prompt_metadata=prompt_metadata) - def _get_seed_metadata_conditions(self, *, metadata: dict[str, Union[str, int]]) -> TextClause: + def _get_seed_metadata_conditions(self, *, metadata: dict[str, str | int]) -> TextClause: """ Generate SQL condition for filtering seed prompts by metadata. @@ -333,7 +333,7 @@ def _get_seed_metadata_conditions(self, *, metadata: dict[str, Union[str, int]]) the first (and only) condition. Args: - metadata (dict[str, Union[str, int]]): Dictionary of metadata key-value pairs to filter by. + metadata (dict[str, str | int]): Dictionary of metadata key-value pairs to filter by. Returns: Any: SQLAlchemy text condition with bound parameters. @@ -403,7 +403,7 @@ def _get_condition_json_array_match( Args: json_column (InstrumentedAttribute[Any]): The JSON-backed SQLAlchemy field to query. property_path (str): The JSON path for the target array. - array_element_path (Optional[str]): An optional JSON path applied to each array item before matching. + array_element_path (str | None): An optional JSON path applied to each array item before matching. array_to_match (Sequence[str]): The array that must match the extracted JSON array values. Combination semantics for multiple entries are controlled by ``match_mode``. If ``array_to_match`` is empty, the condition matches only if the target is also an @@ -800,10 +800,10 @@ def _query_entries( self, model_class: type[Model], *, - conditions: Optional[Any] = None, + conditions: Any | None = None, distinct: bool = False, join_scores: bool = False, - order_by: Optional[Any] = None, + order_by: Any | None = None, limit: int | None = None, ) -> MutableSequence[Model]: """ diff --git a/pyrit/memory/memory_embedding.py b/pyrit/memory/memory_embedding.py index e2ddc26e77..72bddcd34b 100644 --- a/pyrit/memory/memory_embedding.py +++ b/pyrit/memory/memory_embedding.py @@ -1,7 +1,6 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. 
-from typing import Optional from pyrit.embedding import OpenAITextEmbedding from pyrit.memory.memory_models import EmbeddingDataEntry @@ -16,12 +15,12 @@ class MemoryEmbedding: embedding_model (EmbeddingSupport): An instance of a class that supports embedding generation. """ - def __init__(self, *, embedding_model: Optional[EmbeddingSupport] = None) -> None: + def __init__(self, *, embedding_model: EmbeddingSupport | None = None) -> None: """ Initialize the memory embedding helper with a backing embedding model. Args: - embedding_model (Optional[EmbeddingSupport]): The embedding model used to + embedding_model (EmbeddingSupport | None): The embedding model used to generate text embeddings. If not provided, a ValueError is raised. Raises: @@ -55,7 +54,7 @@ def generate_embedding_memory_data(self, *, message_piece: MessagePiece) -> Embe raise ValueError("Only text data is supported for embedding.") -def default_memory_embedding_factory(embedding_model: Optional[EmbeddingSupport] = None) -> MemoryEmbedding | None: +def default_memory_embedding_factory(embedding_model: EmbeddingSupport | None = None) -> MemoryEmbedding | None: """ Create a MemoryEmbedding instance with default or provided embedding model. diff --git a/pyrit/memory/memory_exporter.py b/pyrit/memory/memory_exporter.py index 54e61505b3..b2a06668e7 100644 --- a/pyrit/memory/memory_exporter.py +++ b/pyrit/memory/memory_exporter.py @@ -4,7 +4,6 @@ import csv import json from pathlib import Path -from typing import Optional from pyrit.models import MessagePiece @@ -30,7 +29,7 @@ def __init__(self) -> None: } def export_data( - self, data: list[MessagePiece], *, file_path: Optional[Path] = None, export_type: str = "json" + self, data: list[MessagePiece], *, file_path: Path | None = None, export_type: str = "json" ) -> None: """ Export the provided data to a file in the specified format. 
@@ -52,7 +51,7 @@ def export_data( else: raise ValueError(f"Unsupported export format: {export_type}") - def export_to_json(self, data: list[MessagePiece], file_path: Optional[Path] = None) -> None: + def export_to_json(self, data: list[MessagePiece], file_path: Path | None = None) -> None: """ Export the provided data to a JSON file at the specified file path. Each item in the data list, representing a row from the table, @@ -73,7 +72,7 @@ def export_to_json(self, data: list[MessagePiece], file_path: Optional[Path] = N with open(file_path, "w") as f: json.dump(export_data, f, indent=4) - def export_to_csv(self, data: list[MessagePiece], file_path: Optional[Path] = None) -> None: + def export_to_csv(self, data: list[MessagePiece], file_path: Path | None = None) -> None: """ Export the provided data to a CSV file at the specified file path. Each item in the data list, representing a row from the table, @@ -97,7 +96,7 @@ def export_to_csv(self, data: list[MessagePiece], file_path: Optional[Path] = No writer.writeheader() writer.writerows(export_data) - def export_to_markdown(self, data: list[MessagePiece], file_path: Optional[Path] = None) -> None: + def export_to_markdown(self, data: list[MessagePiece], file_path: Path | None = None) -> None: """ Export the provided data to a Markdown file at the specified file path. Each item in the data list is converted to a dictionary and formatted as a table. 
diff --git a/pyrit/memory/memory_interface.py b/pyrit/memory/memory_interface.py index 26448f5b6c..f0e5fdeb9b 100644 --- a/pyrit/memory/memory_interface.py +++ b/pyrit/memory/memory_interface.py @@ -11,7 +11,7 @@ from contextlib import closing from datetime import datetime, timezone from pathlib import Path -from typing import TYPE_CHECKING, Any, Literal, Optional, TypeVar, Union +from typing import TYPE_CHECKING, Any, Literal, TypeVar from sqlalchemy import MetaData, and_, not_, or_ from sqlalchemy.engine.base import Engine @@ -93,7 +93,7 @@ def _uid() -> str: """Return a short unique suffix for bind-param deduplication.""" return uuid.uuid4().hex[:8] - def __init__(self, embedding_model: Optional[Any] = None) -> None: + def __init__(self, embedding_model: Any | None = None) -> None: """ Initialize the MemoryInterface. @@ -110,7 +110,7 @@ def __init__(self, embedding_model: Optional[Any] = None) -> None: # Ensure cleanup at process exit self.cleanup() - def enable_embedding(self, embedding_model: Optional[Any] = None) -> None: + def enable_embedding(self, embedding_model: Any | None = None) -> None: """ Enable embedding functionality for the memory interface. @@ -276,7 +276,7 @@ def _get_condition_json_array_match( Args: json_column (InstrumentedAttribute[Any]): The JSON-backed SQLAlchemy field to query. property_path (str): The JSON path for the target array. - array_element_path (Optional[str]): An optional JSON path applied to each array item before matching. + array_element_path (str | None): An optional JSON path applied to each array item before matching. array_to_match (Sequence[str]): The array that must match the extracted JSON array values. Combination semantics for multiple entries are controlled by ``match_mode``. 
If ``array_to_match`` is empty, the condition matches only if the target is also an @@ -318,9 +318,7 @@ def _get_message_pieces_memory_label_conditions(self, *, memory_labels: dict[str """ @abc.abstractmethod - def _get_message_pieces_prompt_metadata_conditions( - self, *, prompt_metadata: dict[str, Union[str, int]] - ) -> list[Any]: + def _get_message_pieces_prompt_metadata_conditions(self, *, prompt_metadata: dict[str, str | int]) -> list[Any]: """ Return a list of conditions for filtering memory entries based on prompt metadata. @@ -333,7 +331,7 @@ def _get_message_pieces_prompt_metadata_conditions( """ @abc.abstractmethod - def _get_seed_metadata_conditions(self, *, metadata: dict[str, Union[str, int]]) -> Any: + def _get_seed_metadata_conditions(self, *, metadata: dict[str, str | int]) -> Any: """ Return a condition for filtering seed prompt entries based on prompt metadata. @@ -362,10 +360,10 @@ def _query_entries( self, model_class: type[Model], *, - conditions: Optional[Any] = None, + conditions: Any | None = None, distinct: bool = False, join_scores: bool = False, - order_by: Optional[Any] = None, + order_by: Any | None = None, limit: int | None = None, ) -> MutableSequence[Model]: """ @@ -393,7 +391,7 @@ def _execute_batched_query( distinct: bool = False, join_scores: bool = False, batch_size: int | None = None, - order_by: Optional[Any] = None, + order_by: Any | None = None, limit: int | None = None, ) -> MutableSequence[Model]: """ @@ -708,23 +706,23 @@ def add_scores_to_memory(self, *, scores: Sequence[Score]) -> None: def get_scores( self, *, - score_ids: Optional[Sequence[str]] = None, - score_type: Optional[str] = None, - score_category: Optional[str] = None, - sent_after: Optional[datetime] = None, - sent_before: Optional[datetime] = None, - identifier_filters: Optional[Sequence[IdentifierFilter]] = None, + score_ids: Sequence[str] | None = None, + score_type: str | None = None, + score_category: str | None = None, + sent_after: datetime | None = 
None, + sent_before: datetime | None = None, + identifier_filters: Sequence[IdentifierFilter] | None = None, ) -> Sequence[Score]: """ Retrieve a list of Score objects based on the specified filters. Args: - score_ids (Optional[Sequence[str]]): A list of score IDs to filter by. - score_type (Optional[str]): The type of the score to filter by. - score_category (Optional[str]): The category of the score to filter by. - sent_after (Optional[datetime]): Filter for scores sent after this datetime. - sent_before (Optional[datetime]): Filter for scores sent before this datetime. - identifier_filters (Optional[Sequence[IdentifierFilter]]): A sequence of IdentifierFilter objects that + score_ids (Sequence[str] | None): A list of score IDs to filter by. + score_type (str | None): The type of the score to filter by. + score_category (str | None): The category of the score to filter by. + sent_after (datetime | None): Filter for scores sent after this datetime. + sent_before (datetime | None): Filter for scores sent before this datetime. + identifier_filters (Sequence[IdentifierFilter] | None): A sequence of IdentifierFilter objects that allows filtering by various scorer identifier JSON properties. Defaults to None. 
Returns: @@ -772,39 +770,39 @@ def get_scores( def get_prompt_scores( self, *, - attack_id: Optional[str | uuid.UUID] = None, - role: Optional[str] = None, - conversation_id: Optional[str | uuid.UUID] = None, - prompt_ids: Optional[Sequence[str | uuid.UUID]] = None, - labels: Optional[dict[str, str]] = None, - prompt_metadata: Optional[dict[str, Union[str, int]]] = None, - sent_after: Optional[datetime] = None, - sent_before: Optional[datetime] = None, - original_values: Optional[Sequence[str]] = None, - converted_values: Optional[Sequence[str]] = None, - data_type: Optional[str] = None, - not_data_type: Optional[str] = None, - converted_value_sha256: Optional[Sequence[str]] = None, + attack_id: str | uuid.UUID | None = None, + role: str | None = None, + conversation_id: str | uuid.UUID | None = None, + prompt_ids: Sequence[str | uuid.UUID] | None = None, + labels: dict[str, str] | None = None, + prompt_metadata: dict[str, str | int] | None = None, + sent_after: datetime | None = None, + sent_before: datetime | None = None, + original_values: Sequence[str] | None = None, + converted_values: Sequence[str] | None = None, + data_type: str | None = None, + not_data_type: str | None = None, + converted_value_sha256: Sequence[str] | None = None, ) -> Sequence[Score]: """ Retrieve scores attached to message pieces based on the specified filters. Args: - attack_id (Optional[str | uuid.UUID], optional): The ID of the attack. Defaults to None. - role (Optional[str], optional): The role of the prompt. Defaults to None. - conversation_id (Optional[str | uuid.UUID], optional): The ID of the conversation. Defaults to None. - prompt_ids (Optional[Sequence[str] | Sequence[uuid.UUID]], optional): A list of prompt IDs. + attack_id (str | uuid.UUID | None, optional): The ID of the attack. Defaults to None. + role (str | None, optional): The role of the prompt. Defaults to None. + conversation_id (str | uuid.UUID | None, optional): The ID of the conversation. Defaults to None. 
+ prompt_ids (Sequence[str | uuid.UUID] | None, optional): A list of prompt IDs. Defaults to None. - labels (Optional[dict[str, str]], optional): A dictionary of labels. Defaults to None. - prompt_metadata (Optional[dict[str, Union[str, int]]], optional): The metadata associated with the prompt. + labels (dict[str, str] | None, optional): A dictionary of labels. Defaults to None. + prompt_metadata (dict[str, str | int] | None, optional): The metadata associated with the prompt. Defaults to None. - sent_after (Optional[datetime], optional): Filter for prompts sent after this datetime. Defaults to None. - sent_before (Optional[datetime], optional): Filter for prompts sent before this datetime. Defaults to None. - original_values (Optional[Sequence[str]], optional): A list of original values. Defaults to None. - converted_values (Optional[Sequence[str]], optional): A list of converted values. Defaults to None. - data_type (Optional[str], optional): The data type to filter by. Defaults to None. - not_data_type (Optional[str], optional): The data type to exclude. Defaults to None. - converted_value_sha256 (Optional[Sequence[str]], optional): A list of SHA256 hashes of converted values. + sent_after (datetime | None, optional): Filter for prompts sent after this datetime. Defaults to None. + sent_before (datetime | None, optional): Filter for prompts sent before this datetime. Defaults to None. + original_values (Sequence[str] | None, optional): A list of original values. Defaults to None. + converted_values (Sequence[str] | None, optional): A list of converted values. Defaults to None. + data_type (str | None, optional): The data type to filter by. Defaults to None. + not_data_type (str | None, optional): The data type to exclude. Defaults to None. + converted_value_sha256 (Sequence[str] | None, optional): A list of SHA256 hashes of converted values. Defaults to None. 
Returns: @@ -879,42 +877,42 @@ def get_request_from_response(self, *, response: Message) -> Message: def get_message_pieces( self, *, - attack_id: Optional[str | uuid.UUID] = None, - role: Optional[str] = None, - conversation_id: Optional[str | uuid.UUID] = None, - prompt_ids: Optional[Sequence[str | uuid.UUID]] = None, - labels: Optional[dict[str, str]] = None, - prompt_metadata: Optional[dict[str, Union[str, int]]] = None, - sent_after: Optional[datetime] = None, - sent_before: Optional[datetime] = None, - original_values: Optional[Sequence[str]] = None, - converted_values: Optional[Sequence[str]] = None, - data_type: Optional[str] = None, - not_data_type: Optional[str] = None, - converted_value_sha256: Optional[Sequence[str]] = None, - identifier_filters: Optional[Sequence[IdentifierFilter]] = None, + attack_id: str | uuid.UUID | None = None, + role: str | None = None, + conversation_id: str | uuid.UUID | None = None, + prompt_ids: Sequence[str | uuid.UUID] | None = None, + labels: dict[str, str] | None = None, + prompt_metadata: dict[str, str | int] | None = None, + sent_after: datetime | None = None, + sent_before: datetime | None = None, + original_values: Sequence[str] | None = None, + converted_values: Sequence[str] | None = None, + data_type: str | None = None, + not_data_type: str | None = None, + converted_value_sha256: Sequence[str] | None = None, + identifier_filters: Sequence[IdentifierFilter] | None = None, ) -> Sequence[MessagePiece]: """ Retrieve a list of MessagePiece objects based on the specified filters. Args: - attack_id (Optional[str | uuid.UUID], optional): The ID of the attack. Defaults to None. - role (Optional[str], optional): The role of the prompt. Defaults to None. - conversation_id (Optional[str | uuid.UUID], optional): The ID of the conversation. Defaults to None. - prompt_ids (Optional[Sequence[str] | Sequence[uuid.UUID]], optional): A list of prompt IDs. + attack_id (str | uuid.UUID | None, optional): The ID of the attack. 
Defaults to None. + role (str | None, optional): The role of the prompt. Defaults to None. + conversation_id (str | uuid.UUID | None, optional): The ID of the conversation. Defaults to None. + prompt_ids (Sequence[str | uuid.UUID] | None, optional): A list of prompt IDs. Defaults to None. - labels (Optional[dict[str, str]], optional): A dictionary of labels. Defaults to None. - prompt_metadata (Optional[dict[str, Union[str, int]]], optional): The metadata associated with the prompt. + labels (dict[str, str] | None, optional): A dictionary of labels. Defaults to None. + prompt_metadata (dict[str, str | int] | None, optional): The metadata associated with the prompt. Defaults to None. - sent_after (Optional[datetime], optional): Filter for prompts sent after this datetime. Defaults to None. - sent_before (Optional[datetime], optional): Filter for prompts sent before this datetime. Defaults to None. - original_values (Optional[Sequence[str]], optional): A list of original values. Defaults to None. - converted_values (Optional[Sequence[str]], optional): A list of converted values. Defaults to None. - data_type (Optional[str], optional): The data type to filter by. Defaults to None. - not_data_type (Optional[str], optional): The data type to exclude. Defaults to None. - converted_value_sha256 (Optional[Sequence[str]], optional): A list of SHA256 hashes of converted values. + sent_after (datetime | None, optional): Filter for prompts sent after this datetime. Defaults to None. + sent_before (datetime | None, optional): Filter for prompts sent before this datetime. Defaults to None. + original_values (Sequence[str] | None, optional): A list of original values. Defaults to None. + converted_values (Sequence[str] | None, optional): A list of converted values. Defaults to None. + data_type (str | None, optional): The data type to filter by. Defaults to None. + not_data_type (str | None, optional): The data type to exclude. Defaults to None. 
+ converted_value_sha256 (Sequence[str] | None, optional): A list of SHA256 hashes of converted values. Defaults to None. - identifier_filters (Optional[Sequence[IdentifierFilter]], optional): + identifier_filters (Sequence[IdentifierFilter] | None, optional): A sequence of IdentifierFilter objects that allow filtering by various identifier JSON properties. Defaults to None. @@ -1164,7 +1162,7 @@ def update_labels_by_conversation_id(self, *, conversation_id: str, labels: dict ) def update_prompt_metadata_by_conversation_id( - self, *, conversation_id: str, prompt_metadata: dict[str, Union[str, int]] + self, *, conversation_id: str, prompt_metadata: dict[str, str | int] ) -> bool: """ Update the metadata of prompt entries in memory for a given conversation ID. @@ -1232,20 +1230,20 @@ def cleanup(self) -> None: def get_seeds( self, *, - value: Optional[str] = None, - value_sha256: Optional[Sequence[str]] = None, - dataset_name: Optional[str] = None, - dataset_name_pattern: Optional[str] = None, - data_types: Optional[Sequence[str]] = None, - harm_categories: Optional[Sequence[str]] = None, - added_by: Optional[str] = None, - authors: Optional[Sequence[str]] = None, - groups: Optional[Sequence[str]] = None, - source: Optional[str] = None, - seed_type: Optional[SeedType] = None, - parameters: Optional[Sequence[str]] = None, - metadata: Optional[dict[str, Union[str, int]]] = None, - prompt_group_ids: Optional[Sequence[uuid.UUID]] = None, + value: str | None = None, + value_sha256: Sequence[str] | None = None, + dataset_name: str | None = None, + dataset_name_pattern: str | None = None, + data_types: Sequence[str] | None = None, + harm_categories: Sequence[str] | None = None, + added_by: str | None = None, + authors: Sequence[str] | None = None, + groups: Sequence[str] | None = None, + source: str | None = None, + seed_type: SeedType | None = None, + parameters: Sequence[str] | None = None, + metadata: dict[str, str | int] | None = None, + prompt_group_ids: 
Sequence[uuid.UUID] | None = None, ) -> Sequence[Seed]: """ Retrieve a list of seed prompts based on the specified filters. @@ -1258,7 +1256,7 @@ def get_seeds( Supports wildcards: % (any characters) and _ (single character). Examples: "harm%" matches names starting with "harm", "%test%" matches names containing "test". If both dataset_name and dataset_name_pattern are provided, dataset_name takes precedence. - data_types (Optional[Sequence[str], Optional): List of data types to filter seed prompts by + data_types (Sequence[str] | None): List of data types to filter seed prompts by (e.g., text, image_path). harm_categories (Sequence[str]): A list of harm categories to filter by. If None, all harm categories are considered. @@ -1327,7 +1325,7 @@ def get_seeds( raise def _add_list_conditions( - self, field: InstrumentedAttribute[Any], conditions: list[Any], values: Optional[Sequence[str]] = None + self, field: InstrumentedAttribute[Any], conditions: list[Any], values: Sequence[str] | None = None ) -> None: if values: conditions.extend(field.contains(value) for value in values) @@ -1364,7 +1362,7 @@ async def _serialize_seed_value_async(self, prompt: Seed) -> str: serialized_prompt_value = str(serializer.value) return serialized_prompt_value or "" - async def add_seeds_to_memory_async(self, *, seeds: Sequence[Seed], added_by: Optional[str] = None) -> None: + async def add_seeds_to_memory_async(self, *, seeds: Sequence[Seed], added_by: str | None = None) -> None: """ Insert a list of seeds into the memory storage. @@ -1441,7 +1439,7 @@ def get_seed_dataset_names(self) -> Sequence[str]: raise async def add_seed_groups_to_memory_async( - self, *, prompt_groups: Sequence[SeedGroup], added_by: Optional[str] = None + self, *, prompt_groups: Sequence[SeedGroup], added_by: str | None = None ) -> None: """ Insert a list of seed groups into the memory storage. 
@@ -1481,47 +1479,47 @@ async def add_seed_groups_to_memory_async( def get_seed_groups( self, *, - value: Optional[str] = None, - value_sha256: Optional[Sequence[str]] = None, - dataset_name: Optional[str] = None, - dataset_name_pattern: Optional[str] = None, - data_types: Optional[Sequence[str]] = None, - harm_categories: Optional[Sequence[str]] = None, - added_by: Optional[str] = None, - authors: Optional[Sequence[str]] = None, - groups: Optional[Sequence[str]] = None, - source: Optional[str] = None, - seed_type: Optional[SeedType] = None, - parameters: Optional[Sequence[str]] = None, - metadata: Optional[dict[str, Union[str, int]]] = None, - prompt_group_ids: Optional[Sequence[uuid.UUID]] = None, - group_length: Optional[Sequence[int]] = None, + value: str | None = None, + value_sha256: Sequence[str] | None = None, + dataset_name: str | None = None, + dataset_name_pattern: str | None = None, + data_types: Sequence[str] | None = None, + harm_categories: Sequence[str] | None = None, + added_by: str | None = None, + authors: Sequence[str] | None = None, + groups: Sequence[str] | None = None, + source: str | None = None, + seed_type: SeedType | None = None, + parameters: Sequence[str] | None = None, + metadata: dict[str, str | int] | None = None, + prompt_group_ids: Sequence[uuid.UUID] | None = None, + group_length: Sequence[int] | None = None, ) -> Sequence[SeedGroup]: """ Retrieve groups of seed prompts based on the provided filtering criteria. Args: - value (Optional[str], Optional): The value to match by substring. - value_sha256 (Optional[Sequence[str]], Optional): SHA256 hash of value to filter seed groups by. - dataset_name (Optional[str], Optional): Name of the dataset to match exactly. - dataset_name_pattern (Optional[str], Optional): A pattern to match dataset names using SQL LIKE syntax. + value (str | None, Optional): The value to match by substring. + value_sha256 (Sequence[str] | None, Optional): SHA256 hash of value to filter seed groups by. 
+ dataset_name (str | None, Optional): Name of the dataset to match exactly. + dataset_name_pattern (str | None, Optional): A pattern to match dataset names using SQL LIKE syntax. Supports wildcards: % (any characters) and _ (single character). Examples: "harm%" matches names starting with "harm", "%test%" matches names containing "test". If both dataset_name and dataset_name_pattern are provided, dataset_name takes precedence. - data_types (Optional[Sequence[str]], Optional): List of data types to filter seed prompts by + data_types (Sequence[str] | None, Optional): List of data types to filter seed prompts by (e.g., text, image_path). - harm_categories (Optional[Sequence[str]], Optional): List of harm categories to filter seed prompts by. - added_by (Optional[str], Optional): The user who added the seed groups to filter by. - authors (Optional[Sequence[str]], Optional): List of authors to filter seed groups by. - groups (Optional[Sequence[str]], Optional): List of groups to filter seed groups by. - source (Optional[str], Optional): The source from which the seed prompts originated. - seed_type (Optional[SeedType], Optional): The type of seed to filter by ("prompt", "objective", or + harm_categories (Sequence[str] | None, Optional): List of harm categories to filter seed prompts by. + added_by (str | None, Optional): The user who added the seed groups to filter by. + authors (Sequence[str] | None, Optional): List of authors to filter seed groups by. + groups (Sequence[str] | None, Optional): List of groups to filter seed groups by. + source (str | None, Optional): The source from which the seed prompts originated. + seed_type (SeedType | None, Optional): The type of seed to filter by ("prompt", "objective", or "simulated_conversation"). - parameters (Optional[Sequence[str]], Optional): List of parameters to filter by. 
- metadata (Optional[dict[str, Union[str, int]]], Optional): A free-form dictionary for tagging + parameters (Sequence[str] | None, Optional): List of parameters to filter by. + metadata (dict[str, str | int] | None, Optional): A free-form dictionary for tagging prompts with custom metadata. - prompt_group_ids (Optional[Sequence[uuid.UUID]], Optional): List of prompt group IDs to filter by. - group_length (Optional[Sequence[int]], Optional): The number of seeds in the group to filter by. + prompt_group_ids (Sequence[uuid.UUID] | None, Optional): List of prompt group IDs to filter by. + group_length (Sequence[int] | None, Optional): The number of seeds in the group to filter by. Returns: Sequence[SeedGroup]: A list of `SeedGroup` objects that match the filtering criteria. @@ -1564,18 +1562,18 @@ def get_seed_groups( def export_conversations( self, *, - attack_id: Optional[str | uuid.UUID] = None, - conversation_id: Optional[str | uuid.UUID] = None, - prompt_ids: Optional[Sequence[str] | Sequence[uuid.UUID]] = None, - labels: Optional[dict[str, str]] = None, - sent_after: Optional[datetime] = None, - sent_before: Optional[datetime] = None, - original_values: Optional[Sequence[str]] = None, - converted_values: Optional[Sequence[str]] = None, - data_type: Optional[str] = None, - not_data_type: Optional[str] = None, - converted_value_sha256: Optional[Sequence[str]] = None, - file_path: Optional[Path] = None, + attack_id: str | uuid.UUID | None = None, + conversation_id: str | uuid.UUID | None = None, + prompt_ids: Sequence[str] | Sequence[uuid.UUID] | None = None, + labels: dict[str, str] | None = None, + sent_after: datetime | None = None, + sent_before: datetime | None = None, + original_values: Sequence[str] | None = None, + converted_values: Sequence[str] | None = None, + data_type: str | None = None, + not_data_type: str | None = None, + converted_value_sha256: Sequence[str] | None = None, + file_path: Path | None = None, export_type: str = "json", ) -> Path: """ 
@@ -1583,20 +1581,20 @@ def export_conversations( Defaults to all conversations if no filters are provided. Args: - attack_id (Optional[str | uuid.UUID], optional): The ID of the attack. Defaults to None. - conversation_id (Optional[str | uuid.UUID], optional): The ID of the conversation. Defaults to None. - prompt_ids (Optional[Sequence[str] | Sequence[uuid.UUID]], optional): A list of prompt IDs. + attack_id (str | uuid.UUID | None, optional): The ID of the attack. Defaults to None. + conversation_id (str | uuid.UUID | None, optional): The ID of the conversation. Defaults to None. + prompt_ids (Sequence[str] | Sequence[uuid.UUID] | None, optional): A list of prompt IDs. Defaults to None. - labels (Optional[dict[str, str]], optional): A dictionary of labels. Defaults to None. - sent_after (Optional[datetime], optional): Filter for prompts sent after this datetime. Defaults to None. - sent_before (Optional[datetime], optional): Filter for prompts sent before this datetime. Defaults to None. - original_values (Optional[Sequence[str]], optional): A list of original values. Defaults to None. - converted_values (Optional[Sequence[str]], optional): A list of converted values. Defaults to None. - data_type (Optional[str], optional): The data type to filter by. Defaults to None. - not_data_type (Optional[str], optional): The data type to exclude. Defaults to None. - converted_value_sha256 (Optional[Sequence[str]], optional): A list of SHA256 hashes of converted values. + labels (dict[str, str] | None, optional): A dictionary of labels. Defaults to None. + sent_after (datetime | None, optional): Filter for prompts sent after this datetime. Defaults to None. + sent_before (datetime | None, optional): Filter for prompts sent before this datetime. Defaults to None. + original_values (Sequence[str] | None, optional): A list of original values. Defaults to None. + converted_values (Sequence[str] | None, optional): A list of converted values. Defaults to None. 
+ data_type (str | None, optional): The data type to filter by. Defaults to None. + not_data_type (str | None, optional): The data type to exclude. Defaults to None. + converted_value_sha256 (Sequence[str] | None, optional): A list of SHA256 hashes of converted values. Defaults to None. - file_path (Optional[Path], optional): The path to the file where the data will be exported. + file_path (Path | None, optional): The path to the file where the data will be exported. Defaults to None. export_type (str, optional): The format of the export. Defaults to "json". @@ -1715,45 +1713,45 @@ def update_attack_result_by_id(self, *, attack_result_id: str, update_fields: di def get_attack_results( self, *, - attack_result_ids: Optional[Sequence[str]] = None, - conversation_id: Optional[str] = None, - objective: Optional[str] = None, - objective_sha256: Optional[Sequence[str]] = None, - outcome: Optional[str] = None, - attack_class: Optional[str] = None, - attack_classes: Optional[Sequence[str]] = None, - atomic_attack_eval_hashes: Optional[Sequence[str]] = None, - converter_classes: Optional[Sequence[str]] = None, + attack_result_ids: Sequence[str] | None = None, + conversation_id: str | None = None, + objective: str | None = None, + objective_sha256: Sequence[str] | None = None, + outcome: str | None = None, + attack_class: str | None = None, + attack_classes: Sequence[str] | None = None, + atomic_attack_eval_hashes: Sequence[str] | None = None, + converter_classes: Sequence[str] | None = None, converter_classes_match: Literal["all", "any"] = "all", - has_converters: Optional[bool] = None, - targeted_harm_categories: Optional[Sequence[str]] = None, - labels: Optional[dict[str, str | Sequence[str]]] = None, - identifier_filters: Optional[Sequence[IdentifierFilter]] = None, - scenario_result_id: Optional[str] = None, + has_converters: bool | None = None, + targeted_harm_categories: Sequence[str] | None = None, + labels: dict[str, str | Sequence[str]] | None = None, + 
identifier_filters: Sequence[IdentifierFilter] | None = None, + scenario_result_id: str | None = None, ) -> Sequence[AttackResult]: """ Retrieve a list of AttackResult objects based on the specified filters. Args: - attack_result_ids (Optional[Sequence[str]], optional): A list of attack result IDs. Defaults to None. - conversation_id (Optional[str], optional): The conversation ID to filter by. Defaults to None. - objective (Optional[str], optional): The objective to filter by (substring match). Defaults to None. - objective_sha256 (Optional[Sequence[str]], optional): A list of objective SHA256 hashes to filter by. + attack_result_ids (Sequence[str] | None, optional): A list of attack result IDs. Defaults to None. + conversation_id (str | None, optional): The conversation ID to filter by. Defaults to None. + objective (str | None, optional): The objective to filter by (substring match). Defaults to None. + objective_sha256 (Sequence[str] | None, optional): A list of objective SHA256 hashes to filter by. Defaults to None. - outcome (Optional[str], optional): The outcome to filter by (success, failure, undetermined). + outcome (str | None, optional): The outcome to filter by (success, failure, undetermined). Defaults to None. - attack_class (Optional[str], optional): Deprecated. Filter by a single exact attack + attack_class (str | None, optional): Deprecated. Filter by a single exact attack class_name in attack_identifier. Equivalent to passing ``attack_classes=[attack_class]``. Cannot be combined with ``attack_classes``. Defaults to None. - attack_classes (Optional[Sequence[str]], optional): Filter by exact attack class_name in + attack_classes (Sequence[str] | None, optional): Filter by exact attack class_name in attack_identifier. Returns attacks matching ANY of the listed class names (OR logic, case-sensitive). An empty sequence applies no filter. Defaults to None. 
- atomic_attack_eval_hashes (Optional[Sequence[str]], optional): Filter by behavioral + atomic_attack_eval_hashes (Sequence[str] | None, optional): Filter by behavioral equivalence hash on ``atomic_attack_identifier.eval_hash`` (auto-stamped on persistence by ``AtomicAttackEvaluationIdentifier``). Returns results matching ANY of the listed hashes (OR logic, case-sensitive). Designed for ASR aggregation by technique configuration. An empty sequence applies no filter. Defaults to None. - converter_classes (Optional[Sequence[str]], optional): Filter by converter class names. + converter_classes (Sequence[str] | None, optional): Filter by converter class names. Combination semantics for multiple entries are controlled by ``converter_classes_match``. An empty sequence filters to attacks that used no converters; ``None`` applies no filter. To filter by presence/absence of any converter explicitly, use the @@ -1763,17 +1761,17 @@ def get_attack_results( converter (AND, case-insensitive). ``"any"`` matches attacks that used at least one listed converter (OR, case-insensitive). Ignored when ``converter_classes`` has fewer than 2 entries or is empty. - has_converters (Optional[bool], optional): Filter by converter presence. + has_converters (bool | None, optional): Filter by converter presence. ``True`` returns only attacks that used at least one converter. ``False`` returns only attacks that used no converters. ``None`` applies no filter. Defaults to None. - targeted_harm_categories (Optional[Sequence[str]], optional): + targeted_harm_categories (Sequence[str] | None, optional): A list of targeted harm categories to filter results by. These targeted harm categories are associated with the prompts themselves, meaning they are harm(s) we're trying to elicit with the prompt, not necessarily one(s) that were found in the response. By providing a list, this means ALL categories in the list must be present. Defaults to None. 
- labels (Optional[dict[str, str | Sequence[str]]], optional): Filter results + labels (dict[str, str | Sequence[str]] | None, optional): Filter results by attack labels. Entries are AND-combined across label names; within a single entry, a string value is an equality match and a sequence value is an OR match over the listed values. An empty sequence applies no filter @@ -1781,10 +1779,10 @@ def get_attack_results( ["roakey_op_a", "roakey_op_b"]}`` matches attacks where ``operator == "roakey"`` AND (``operation == "roakey_op_a"`` OR ``operation == "roakey_op_b"``). Defaults to None. - identifier_filters (Optional[Sequence[IdentifierFilter]], optional): + identifier_filters (Sequence[IdentifierFilter] | None, optional): A sequence of IdentifierFilter objects that allows filtering by various attack identifier JSON properties. Defaults to None. - scenario_result_id (Optional[str], optional): Filter to attack results linked to a + scenario_result_id (str | None, optional): Filter to attack results linked to a specific scenario via the ``AttackResultEntry.attribution_parent_id`` foreign key. Combined with ``outcome=AttackOutcome.ERROR`` this is the replacement for the removed per-scenario error_attack_result_ids manifest. Defaults to None. 
@@ -2091,16 +2089,16 @@ def update_scenario_metadata( def get_scenario_results( self, *, - scenario_result_ids: Optional[Sequence[str]] = None, - scenario_name: Optional[str] = None, - scenario_version: Optional[int] = None, - pyrit_version: Optional[str] = None, - added_after: Optional[datetime] = None, - added_before: Optional[datetime] = None, - labels: Optional[dict[str, str]] = None, - objective_target_endpoint: Optional[str] = None, - objective_target_model_name: Optional[str] = None, - identifier_filters: Optional[Sequence[IdentifierFilter]] = None, + scenario_result_ids: Sequence[str] | None = None, + scenario_name: str | None = None, + scenario_version: int | None = None, + pyrit_version: str | None = None, + added_after: datetime | None = None, + added_before: datetime | None = None, + labels: dict[str, str] | None = None, + objective_target_endpoint: str | None = None, + objective_target_model_name: str | None = None, + identifier_filters: Sequence[IdentifierFilter] | None = None, limit: int | None = None, ) -> Sequence[ScenarioResult]: """ @@ -2109,25 +2107,25 @@ def get_scenario_results( Results are always ordered by completion_time descending (most recent first). Args: - scenario_result_ids (Optional[Sequence[str]], optional): A list of scenario result IDs. + scenario_result_ids (Sequence[str] | None, optional): A list of scenario result IDs. Defaults to None. - scenario_name (Optional[str], optional): The scenario name to filter by (substring match). + scenario_name (str | None, optional): The scenario name to filter by (substring match). Defaults to None. - scenario_version (Optional[int], optional): The scenario version to filter by. Defaults to None. - pyrit_version (Optional[str], optional): The PyRIT version to filter by. Defaults to None. - added_after (Optional[datetime], optional): Filter for scenarios completed after this datetime. + scenario_version (int | None, optional): The scenario version to filter by. Defaults to None. 
+ pyrit_version (str | None, optional): The PyRIT version to filter by. Defaults to None. + added_after (datetime | None, optional): Filter for scenarios completed after this datetime. Defaults to None. - added_before (Optional[datetime], optional): Filter for scenarios completed before this datetime. + added_before (datetime | None, optional): Filter for scenarios completed before this datetime. Defaults to None. - labels (Optional[dict[str, str]], optional): A dictionary of memory labels to filter by. + labels (dict[str, str] | None, optional): A dictionary of memory labels to filter by. Defaults to None. - objective_target_endpoint (Optional[str], optional): Filter for scenarios where the + objective_target_endpoint (str | None, optional): Filter for scenarios where the objective_target_identifier has an endpoint attribute containing this value (case-insensitive). Defaults to None. - objective_target_model_name (Optional[str], optional): Filter for scenarios where the + objective_target_model_name (str | None, optional): Filter for scenarios where the objective_target_identifier has a model_name attribute containing this value (case-insensitive). Defaults to None. - identifier_filters (Optional[Sequence[IdentifierFilter]], optional): + identifier_filters (Sequence[IdentifierFilter] | None, optional): A sequence of IdentifierFilter objects that allows filtering by identifier JSON properties. Defaults to None. limit (int | None): Maximum number of results to return. Defaults to None (no limit). diff --git a/pyrit/memory/memory_models.py b/pyrit/memory/memory_models.py index a5c277b3f1..9f52b38afa 100644 --- a/pyrit/memory/memory_models.py +++ b/pyrit/memory/memory_models.py @@ -232,14 +232,14 @@ class PromptMemoryEntry(Base): sequence (int): The order of the conversation within a conversation_id. Can be the same number for multi-part requests or multi-part responses. timestamp (DateTime): The timestamp of the memory entry. 
- labels (Dict[str, str]): The labels associated with the memory entry. Several can be standardized. - targeted_harm_categories (List[str]): The targeted harm categories for the memory entry. + labels (dict[str, str]): The labels associated with the memory entry. Several can be standardized. + targeted_harm_categories (list[str]): The targeted harm categories for the memory entry. prompt_metadata (JSON): The metadata associated with the prompt. This can be specific to any scenarios. Because memory is how components talk with each other, this can be component specific. e.g. the URI from a file uploaded to a blob store, or a document type you want to upload. converters (list[PromptConverter]): The converters for the prompt. prompt_target (PromptTarget): The target for the prompt. - attack_identifier (Dict[str, str]): The attack identifier for the prompt. + attack_identifier (dict[str, str]): The attack identifier for the prompt. original_value_data_type (PromptDataType): The data type of the original prompt (text, image) original_value (str): The text of the original prompt. If prompt is an image, it's a link. original_value_sha256 (str): The SHA256 hash of the original prompt data. @@ -564,16 +564,16 @@ class SeedEntry(Base): value_sha256 (str): The SHA256 hash of the value of the seed prompt data. data_type (PromptDataType): The data type of the seed prompt. dataset_name (str): The name of the dataset the seed prompt belongs to. - harm_categories (List[str]): The harm categories associated with the seed prompt. + harm_categories (list[str]): The harm categories associated with the seed prompt. description (str): The description of the seed prompt. - authors (List[str]): The authors of the seed prompt. - groups (List[str]): The groups involved in authoring the seed prompt (if any). + authors (list[str]): The authors of the seed prompt. + groups (list[str]): The groups involved in authoring the seed prompt (if any). source (str): The source of the seed prompt. 
date_added (DateTime): The date the seed prompt was added. added_by (str): The user who added the seed prompt. prompt_metadata (dict[str, str | int]): The metadata associated with the seed prompt. This includes information that is useful for the specific target you're probing, such as encoding data. - parameters (List[str]): The parameters included in the value. + parameters (list[str]): The parameters included in the value. Note that seed prompts do not have parameters, only prompt templates do. However, they are stored in the same table. prompt_group_id (uuid.UUID): The ID of a group the seed prompt may optionally belong to. @@ -742,8 +742,8 @@ class AttackResultEntry(Base): outcome_reason (str): Optional reason for the outcome, providing additional context. attack_metadata (dict[str, Any]): Metadata can be included as key-value pairs to provide extra context. labels (dict[str, str]): Optional labels associated with the attack result entry. - pruned_conversation_ids (List[str]): List of conversation IDs that were pruned from the attack. - adversarial_chat_conversation_ids (List[str]): List of conversation IDs used for adversarial chat. + pruned_conversation_ids (list[str]): List of conversation IDs that were pruned from the attack. + adversarial_chat_conversation_ids (list[str]): List of conversation IDs used for adversarial chat. timestamp (DateTime): The timestamp of the attack result entry. last_response (PromptMemoryEntry): Relationship to the last response prompt memory entry. last_score (ScoreEntry): Relationship to the last score entry. 
@@ -1083,7 +1083,7 @@ def __init__(self, *, entry: ScenarioResult) -> None: self.number_tries = entry.number_tries self.completion_time = entry.completion_time - # Serialize attack_results: dict[str, List[AttackResult]] -> dict[str, List[str]] + # Serialize attack_results: dict[str, list[AttackResult]] -> dict[str, list[str]] # Store only conversation_ids - the full AttackResults can be queried from the database serialized_attack_results = {} for attack_name, results in entry.attack_results.items(): diff --git a/pyrit/memory/sqlite_memory.py b/pyrit/memory/sqlite_memory.py index 461d2b871b..5f628ab075 100644 --- a/pyrit/memory/sqlite_memory.py +++ b/pyrit/memory/sqlite_memory.py @@ -8,7 +8,7 @@ from contextlib import closing, suppress from datetime import datetime from pathlib import Path -from typing import Any, Literal, Optional, TypeVar, Union, cast +from typing import Any, Literal, TypeVar, cast from sqlalchemy import and_, create_engine, exists, func, or_, text from sqlalchemy.engine.base import Engine @@ -59,7 +59,7 @@ class SQLiteMemory(MemoryInterface, metaclass=Singleton): def __init__( self, *, - db_path: Optional[Union[Path, str]] = None, + db_path: Path | str | None = None, verbose: bool = False, skip_schema_migration: bool = False, silent: bool = False, @@ -68,7 +68,7 @@ def __init__( Initialize the SQLiteMemory instance. Args: - db_path (Optional[Union[Path, str]]): Path to the SQLite database file. + db_path (Path | str | None): Path to the SQLite database file. Defaults to "pyrit.db". verbose (bool): Whether to enable verbose logging. Defaults to False. 
@@ -80,7 +80,7 @@ def __init__( super().__init__() if db_path == ":memory:": - self.db_path: Union[Path, str] = ":memory:" + self.db_path: Path | str = ":memory:" else: self.db_path = Path(db_path or Path(DB_DATA_PATH, self.DEFAULT_DB_FILE_NAME)).resolve() self.results_path = str(DB_DATA_PATH) @@ -179,7 +179,7 @@ def _get_message_pieces_memory_label_conditions(self, *, memory_labels: dict[str return [or_(pme_match, are_match)] def _get_message_pieces_prompt_metadata_conditions( - self, *, prompt_metadata: dict[str, Union[str, int]] + self, *, prompt_metadata: dict[str, str | int] ) -> list[TextClause]: """ Generate SQLAlchemy filter conditions for filtering conversation pieces by prompt metadata. @@ -195,7 +195,7 @@ def _get_message_pieces_prompt_metadata_conditions( condition = text(json_conditions).bindparams(**{key: str(value) for key, value in prompt_metadata.items()}) return [condition] - def _get_seed_metadata_conditions(self, *, metadata: dict[str, Union[str, int]]) -> Any: + def _get_seed_metadata_conditions(self, *, metadata: dict[str, str | int]) -> Any: """ Generate SQLAlchemy filter conditions for filtering seed prompts by metadata. @@ -256,7 +256,7 @@ def _get_condition_json_array_match( Args: json_column (InstrumentedAttribute[Any]): The JSON-backed SQLAlchemy field to query. property_path (str): The JSON path for the target array. - array_element_path (Optional[str]): An optional JSON path applied to each array item before matching. + array_element_path (str | None): An optional JSON path applied to each array item before matching. array_to_match (Sequence[str]): The array that must match the extracted JSON array values. Combination semantics for multiple entries are controlled by ``match_mode``. 
If ``array_to_match`` is empty, the condition matches only if the target is also an @@ -334,10 +334,10 @@ def _query_entries( self, model_class: type[Model], *, - conditions: Optional[Any] = None, + conditions: Any | None = None, distinct: bool = False, join_scores: bool = False, - order_by: Optional[Any] = None, + order_by: Any | None = None, limit: int | None = None, ) -> MutableSequence[Model]: """ @@ -487,18 +487,18 @@ def dispose_engine(self) -> None: def export_conversations( self, *, - attack_id: Optional[str | uuid.UUID] = None, - conversation_id: Optional[str | uuid.UUID] = None, - prompt_ids: Optional[Sequence[str] | Sequence[uuid.UUID]] = None, - labels: Optional[dict[str, str]] = None, - sent_after: Optional[datetime] = None, - sent_before: Optional[datetime] = None, - original_values: Optional[Sequence[str]] = None, - converted_values: Optional[Sequence[str]] = None, - data_type: Optional[str] = None, - not_data_type: Optional[str] = None, - converted_value_sha256: Optional[Sequence[str]] = None, - file_path: Optional[Path] = None, + attack_id: str | uuid.UUID | None = None, + conversation_id: str | uuid.UUID | None = None, + prompt_ids: Sequence[str] | Sequence[uuid.UUID] | None = None, + labels: dict[str, str] | None = None, + sent_after: datetime | None = None, + sent_before: datetime | None = None, + original_values: Sequence[str] | None = None, + converted_values: Sequence[str] | None = None, + data_type: str | None = None, + not_data_type: str | None = None, + converted_value_sha256: Sequence[str] | None = None, + file_path: Path | None = None, export_type: str = "json", ) -> Path: """ diff --git a/pyrit/message_normalizer/chat_message_normalizer.py b/pyrit/message_normalizer/chat_message_normalizer.py index 4fc11fbd0a..d3fcaade31 100644 --- a/pyrit/message_normalizer/chat_message_normalizer.py +++ b/pyrit/message_normalizer/chat_message_normalizer.py @@ -4,7 +4,7 @@ import base64 import json from pathlib import Path -from typing import 
TYPE_CHECKING, Any, Union +from typing import TYPE_CHECKING, Any import aiofiles @@ -93,7 +93,7 @@ async def normalize_async(self, messages: list[Message]) -> list[ChatMessage]: # Use simple string for single text piece, otherwise use content list if len(pieces) == 1 and pieces[0].converted_value_data_type == "text": - content: Union[str, list[dict[str, Any]]] = pieces[0].converted_value + content: str | list[dict[str, Any]] = pieces[0].converted_value else: content = [await self._piece_to_content_dict_async(piece) for piece in pieces] diff --git a/pyrit/message_normalizer/tokenizer_template_normalizer.py b/pyrit/message_normalizer/tokenizer_template_normalizer.py index caab9fa337..06c904505b 100644 --- a/pyrit/message_normalizer/tokenizer_template_normalizer.py +++ b/pyrit/message_normalizer/tokenizer_template_normalizer.py @@ -2,7 +2,7 @@ # Licensed under the MIT license. import logging from dataclasses import dataclass -from typing import TYPE_CHECKING, ClassVar, Literal, Optional, cast +from typing import TYPE_CHECKING, ClassVar, Literal, cast from pyrit.common import get_non_required_value from pyrit.message_normalizer.chat_message_normalizer import ChatMessageNormalizer @@ -109,7 +109,7 @@ def __init__( self.system_message_behavior = system_message_behavior @staticmethod - def _load_tokenizer(model_name: str, token: Optional[str]) -> "PreTrainedTokenizerBase": + def _load_tokenizer(model_name: str, token: str | None) -> "PreTrainedTokenizerBase": """ Load a tokenizer from HuggingFace. @@ -134,8 +134,8 @@ def from_model( cls, model_name_or_alias: str, *, - token: Optional[str] = None, - system_message_behavior: Optional[TokenizerSystemBehavior] = None, + token: str | None = None, + system_message_behavior: TokenizerSystemBehavior | None = None, ) -> "TokenizerTemplateNormalizer": """ Create a normalizer from a model name or alias. 
diff --git a/pyrit/models/chat_message.py b/pyrit/models/chat_message.py index c2f801862d..b873b33333 100644 --- a/pyrit/models/chat_message.py +++ b/pyrit/models/chat_message.py @@ -1,7 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -from typing import Any, Optional, Union +from typing import Any from pydantic import BaseModel, ConfigDict @@ -30,10 +30,10 @@ class ChatMessage(BaseModel): model_config = ConfigDict(extra="forbid") role: ChatMessageRole - content: Union[str, list[dict[str, Any]]] - name: Optional[str] = None - tool_calls: Optional[list[ToolCall]] = None - tool_call_id: Optional[str] = None + content: str | list[dict[str, Any]] + name: str | None = None + tool_calls: list[ToolCall] | None = None + tool_call_id: str | None = None def to_dict(self) -> dict[str, Any]: """ diff --git a/pyrit/models/conversation_reference.py b/pyrit/models/conversation_reference.py index 0915a045c4..6e39cfd233 100644 --- a/pyrit/models/conversation_reference.py +++ b/pyrit/models/conversation_reference.py @@ -4,7 +4,6 @@ from __future__ import annotations from enum import Enum -from typing import Optional from pydantic import BaseModel, ConfigDict @@ -27,7 +26,7 @@ class ConversationReference(BaseModel): conversation_id: str conversation_type: ConversationType - description: Optional[str] = None + description: str | None = None def __hash__(self) -> int: """ diff --git a/pyrit/models/conversation_stats.py b/pyrit/models/conversation_stats.py index 67b09e24be..14497954f5 100644 --- a/pyrit/models/conversation_stats.py +++ b/pyrit/models/conversation_stats.py @@ -2,7 +2,7 @@ # Licensed under the MIT license. 
from datetime import datetime -from typing import ClassVar, Optional +from typing import ClassVar from pydantic import BaseModel, ConfigDict, Field @@ -29,7 +29,7 @@ class ConversationStats(BaseModel): """ message_count: int = 0 - last_message_preview: Optional[str] = None - last_message_data_type: Optional[PromptDataType] = None + last_message_preview: str | None = None + last_message_data_type: PromptDataType | None = None labels: dict[str, str] = Field(default_factory=dict) - created_at: Optional[datetime] = None + created_at: datetime | None = None diff --git a/pyrit/models/data_type_serializer.py b/pyrit/models/data_type_serializer.py index 2528423394..2cf9e6593d 100644 --- a/pyrit/models/data_type_serializer.py +++ b/pyrit/models/data_type_serializer.py @@ -12,7 +12,7 @@ import wave from mimetypes import guess_type from pathlib import Path -from typing import TYPE_CHECKING, Literal, Optional, Union, get_args +from typing import TYPE_CHECKING, Literal, get_args from urllib.parse import urlparse import aiofiles @@ -48,8 +48,8 @@ def _write_wav_sync( def data_serializer_factory( *, data_type: PromptDataType, - value: Optional[str] = None, - extension: Optional[str] = None, + value: str | None = None, + extension: str | None = None, category: AllowedCategories, ) -> DataTypeSerializer: """ @@ -58,7 +58,7 @@ def data_serializer_factory( Args: data_type (str): The type of the data (e.g., 'text', 'image_path', 'audio_path'). value (str): The data value to be serialized. - extension (Optional[str]): The file extension, if applicable. + extension (str | None): The file extension, if applicable. category (AllowedCategories): The category or context for the data (e.g., 'seed-prompt-entries'). 
Returns: @@ -114,7 +114,7 @@ class DataTypeSerializer(abc.ABC): data_sub_directory: str file_extension: str - _file_path: Union[Path, str] | None = None + _file_path: Path | str | None = None @property def _memory(self) -> MemoryInterface: @@ -152,7 +152,7 @@ def data_on_disk(self) -> bool: """ - async def save_data_async(self, data: bytes, output_filename: Optional[str] = None) -> None: + async def save_data_async(self, data: bytes, output_filename: str | None = None) -> None: """ Save data to storage. @@ -193,7 +193,7 @@ async def save_formatted_audio_async( num_channels: int = 1, sample_width: int = 2, sample_rate: int = 16000, - output_filename: Optional[str] = None, + output_filename: str | None = None, ) -> None: """ Save PCM16 or similarly formatted audio data to storage. @@ -312,15 +312,15 @@ async def get_sha256_async(self) -> str: hash_object = hashlib.sha256(input_bytes) return hash_object.hexdigest() - async def get_data_filename_async(self, file_name: Optional[str] = None) -> Union[Path, str]: + async def get_data_filename_async(self, file_name: str | None = None) -> Path | str: """ Generate or retrieve a unique filename for the data file. Args: - file_name (Optional[str]): Optional file name override. + file_name (str | None): Optional file name override. Returns: - Union[Path, str]: Full storage path for the generated data file. + Path | str: Full storage path for the generated data file. Raises: TypeError: If the serializer is not configured for on-disk data. @@ -358,7 +358,7 @@ async def get_data_filename_async(self, file_name: Optional[str] = None) -> Unio return self._file_path async def save_data( # pyrit-async-suffix-exempt - self, data: bytes, output_filename: Optional[str] = None + self, data: bytes, output_filename: str | None = None ) -> None: """ Save data to storage (deprecated alias of ``save_data_async``). 
@@ -397,7 +397,7 @@ async def save_formatted_audio( # pyrit-async-suffix-exempt num_channels: int = 1, sample_width: int = 2, sample_rate: int = 16000, - output_filename: Optional[str] = None, + output_filename: str | None = None, ) -> None: """ Save formatted audio data to storage (deprecated alias of ``save_formatted_audio_async``). @@ -459,8 +459,8 @@ async def get_sha256(self) -> str: # pyrit-async-suffix-exempt return await self.get_sha256_async() async def get_data_filename( # pyrit-async-suffix-exempt - self, file_name: Optional[str] = None - ) -> Union[Path, str]: + self, file_name: str | None = None + ) -> Path | str: """ Generate or retrieve a unique filename for the data file (deprecated alias of ``get_data_filename_async``). @@ -468,7 +468,7 @@ async def get_data_filename( # pyrit-async-suffix-exempt file_name: Optional file name override. Returns: - Union[Path, str]: Full storage path for the generated data file. + Path | str: Full storage path for the generated data file. """ print_deprecation_message( old_item="pyrit.models.data_type_serializer.DataTypeSerializer.get_data_filename", @@ -576,14 +576,14 @@ def data_on_disk(self) -> bool: class URLDataTypeSerializer(DataTypeSerializer): """Serializer for URL values and URL-backed local file references.""" - def __init__(self, *, category: str, prompt_text: str, extension: Optional[str] = None) -> None: + def __init__(self, *, category: str, prompt_text: str, extension: str | None = None) -> None: """ Initialize a URL serializer. Args: category (str): Data category folder name. prompt_text (str): URL or path value. - extension (Optional[str]): Optional extension for persisted content. + extension (str | None): Optional extension for persisted content. 
""" self.data_type = "url" @@ -606,14 +606,14 @@ def data_on_disk(self) -> bool: class ImagePathDataTypeSerializer(DataTypeSerializer): """Serializer for image path values stored on disk.""" - def __init__(self, *, category: str, prompt_text: Optional[str] = None, extension: Optional[str] = None) -> None: + def __init__(self, *, category: str, prompt_text: str | None = None, extension: str | None = None) -> None: """ Initialize an image-path serializer. Args: category (str): Data category folder name. - prompt_text (Optional[str]): Optional existing image path. - extension (Optional[str]): Optional image extension. + prompt_text (str | None): Optional existing image path. + extension (str | None): Optional image extension. """ self.data_type = "image_path" @@ -641,16 +641,16 @@ def __init__( self, *, category: str, - prompt_text: Optional[str] = None, - extension: Optional[str] = None, + prompt_text: str | None = None, + extension: str | None = None, ) -> None: """ Initialize an audio-path serializer. Args: category (str): Data category folder name. - prompt_text (Optional[str]): Optional existing audio path. - extension (Optional[str]): Optional audio extension. + prompt_text (str | None): Optional existing audio path. + extension (str | None): Optional audio extension. """ self.data_type = "audio_path" @@ -678,16 +678,16 @@ def __init__( self, *, category: str, - prompt_text: Optional[str] = None, - extension: Optional[str] = None, + prompt_text: str | None = None, + extension: str | None = None, ) -> None: """ Initialize a video-path serializer. Args: category (str): The category or context for the data. - prompt_text (Optional[str]): The video path or identifier. - extension (Optional[str]): The file extension, defaults to 'mp4'. + prompt_text (str | None): The video path or identifier. + extension (str | None): The file extension, defaults to 'mp4'. 
""" self.data_type = "video_path" @@ -715,8 +715,8 @@ def __init__( self, *, category: str, - prompt_text: Optional[str] = None, - extension: Optional[str] = None, + prompt_text: str | None = None, + extension: str | None = None, ) -> None: """ Initialize a generic binary-path serializer. @@ -727,8 +727,8 @@ def __init__( Args: category (str): The category or context for the data. - prompt_text (Optional[str]): The binary file path or identifier. - extension (Optional[str]): The file extension, defaults to 'bin'. + prompt_text (str | None): The binary file path or identifier. + extension (str | None): The file extension, defaults to 'bin'. """ self.data_type = "binary_path" diff --git a/pyrit/models/harm_definition.py b/pyrit/models/harm_definition.py index 9e739244ab..ab555ac8ad 100644 --- a/pyrit/models/harm_definition.py +++ b/pyrit/models/harm_definition.py @@ -10,7 +10,6 @@ import logging import re from pathlib import Path -from typing import Optional, Union import yaml from pydantic import BaseModel, Field @@ -53,9 +52,9 @@ class HarmDefinition(BaseModel): version: str category: str scale_descriptions: list[ScaleDescription] = Field(default_factory=list) - source_path: Optional[str] = None + source_path: str | None = None - def get_scale_description(self, score_value: str) -> Optional[str]: + def get_scale_description(self, score_value: str) -> str | None: """ Get the description for a specific score value. @@ -101,7 +100,7 @@ def validate_category(category: str, *, check_exists: bool = False) -> bool: return True @classmethod - def from_yaml(cls, harm_definition_path: Union[str, Path]) -> "HarmDefinition": + def from_yaml(cls, harm_definition_path: str | Path) -> "HarmDefinition": """ Load and validate a harm definition from a YAML file. @@ -177,7 +176,7 @@ def get_all_harm_definitions() -> dict[str, HarmDefinition]: and loads each one as a HarmDefinition. 
Returns: - Dict[str, HarmDefinition]: A dictionary mapping category names to their + dict[str, HarmDefinition]: A dictionary mapping category names to their HarmDefinition objects. The keys are the category names from the YAML files (e.g., "violence", "hate_speech"). diff --git a/pyrit/models/identifiers/component_identifier.py b/pyrit/models/identifiers/component_identifier.py index d3d0c71933..b8e7ad1895 100644 --- a/pyrit/models/identifiers/component_identifier.py +++ b/pyrit/models/identifiers/component_identifier.py @@ -20,7 +20,7 @@ import json import logging from abc import ABC, abstractmethod -from typing import Any, ClassVar, Optional, Union +from typing import Any, ClassVar from pydantic import BaseModel, ConfigDict, Field, SerializationInfo, model_serializer, model_validator @@ -56,7 +56,7 @@ def config_hash(config_dict: dict[str, Any]) -> str: ensure determinism. Args: - config_dict (Dict[str, Any]): A JSON-serializable dictionary. + config_dict (dict[str, Any]): A JSON-serializable dictionary. Returns: str: Hex-encoded SHA256 hash string. @@ -85,11 +85,11 @@ def _build_hash_dict( Args: class_name (str): The component's class name. class_module (str): The component's module path. - params (Dict[str, Any]): Behavioral parameters (non-None values only). - children (Dict[str, Any]): Child name to ComponentIdentifier or list of ComponentIdentifier. + params (dict[str, Any]): Behavioral parameters (non-None values only). + children (dict[str, Any]): Child name to ComponentIdentifier or list of ComponentIdentifier. Returns: - Dict[str, Any]: The canonical dictionary for hashing. + dict[str, Any]: The canonical dictionary for hashing. """ hash_dict: dict[str, Any] = { ComponentIdentifier.KEY_CLASS_NAME: class_name, @@ -162,16 +162,16 @@ class ComponentIdentifier(BaseModel): #: Behavioral parameters that affect output. params: dict[str, Any] = Field(default_factory=dict) #: Named child identifiers for compositional identity (e.g., a scorer's target). 
- children: dict[str, Union[ComponentIdentifier, list[ComponentIdentifier]]] = Field(default_factory=dict) + children: dict[str, ComponentIdentifier | list[ComponentIdentifier]] = Field(default_factory=dict) #: Content-addressed SHA256 hash. Computed automatically when ``None``; #: pass an explicit value to preserve a hash from DB storage where params #: may have been truncated. - hash: Optional[str] = None + hash: str | None = None #: Version tag for storage. Not included in the content hash. pyrit_version: str = Field(default=pyrit.__version__) #: Evaluation hash. Computed by EvaluationIdentifier subclasses and attached #: to the identifier so it survives DB round-trips with truncated params. - eval_hash: Optional[str] = None + eval_hash: str | None = None # ------------------------------------------------------------------ # Validators @@ -435,8 +435,8 @@ def of( cls, obj: object, *, - params: Optional[dict[str, Any]] = None, - children: Optional[dict[str, Union[ComponentIdentifier, list[ComponentIdentifier]]]] = None, + params: dict[str, Any] | None = None, + children: dict[str, ComponentIdentifier | list[ComponentIdentifier]] | None = None, ) -> ComponentIdentifier: """ Build a ComponentIdentifier from a live object instance. @@ -464,7 +464,7 @@ def of( children=clean_children, ) - def get_child(self, key: str) -> Optional[ComponentIdentifier]: + def get_child(self, key: str) -> ComponentIdentifier | None: """ Get a single child by key. @@ -519,7 +519,7 @@ def _collect_child_eval_hashes(self) -> set[str]: return hashes @staticmethod - def _truncate_value(*, value: Any, max_length: Optional[int]) -> Any: + def _truncate_value(*, value: Any, max_length: int | None) -> Any: """ Truncate string values longer than ``max_length`` with a ``...`` suffix. 
@@ -538,7 +538,7 @@ def _truncate_value(*, value: Any, max_length: Optional[int]) -> Any: # Deprecated shims — kept for one release cycle # ------------------------------------------------------------------ - def to_dict(self, *, max_value_length: Optional[int] = None) -> dict[str, Any]: + def to_dict(self, *, max_value_length: int | None = None) -> dict[str, Any]: """ Return the flat storage dict (deprecated; use ``model_dump`` instead). @@ -584,7 +584,7 @@ class Identifiable(ABC): component's lifetime. """ - _identifier: Optional[ComponentIdentifier] = None + _identifier: ComponentIdentifier | None = None @abstractmethod def _build_identifier(self) -> ComponentIdentifier: diff --git a/pyrit/models/identifiers/evaluation_identifier.py b/pyrit/models/identifiers/evaluation_identifier.py index b2fc1b996d..31a328c9d4 100644 --- a/pyrit/models/identifiers/evaluation_identifier.py +++ b/pyrit/models/identifiers/evaluation_identifier.py @@ -23,7 +23,7 @@ from __future__ import annotations from abc import ABC -from typing import TYPE_CHECKING, Any, ClassVar, Optional +from typing import TYPE_CHECKING, Any, ClassVar from pydantic import BaseModel, ConfigDict, Field @@ -65,18 +65,18 @@ class ChildEvalRule(BaseModel): model_config = ConfigDict(frozen=True) exclude: bool = False - included_params: Optional[frozenset[str]] = None - included_item_values: Optional[dict[str, Any]] = Field(default=None) - param_fallbacks: Optional[dict[str, str]] = Field(default=None) - inner_child_name: Optional[str] = Field(default=None) + included_params: frozenset[str] | None = None + included_item_values: dict[str, Any] | None = Field(default=None) + param_fallbacks: dict[str, str] | None = Field(default=None) + inner_child_name: str | None = Field(default=None) def _build_eval_dict( identifier: ComponentIdentifier, *, child_eval_rules: dict[str, ChildEvalRule], - _included_params: Optional[frozenset[str]] = None, - _param_fallbacks: Optional[dict[str, str]] = None, + _included_params: 
frozenset[str] | None = None, + _param_fallbacks: dict[str, str] | None = None, ) -> dict[str, Any]: """ Build a filtered dictionary for eval-hash computation. @@ -89,10 +89,10 @@ def _build_eval_dict( identifier (ComponentIdentifier): The component identity to process. child_eval_rules (dict[str, ChildEvalRule]): Per-child eval rules. Keys are child names; values describe how each child is filtered. - _included_params (Optional[frozenset[str]]): Internal. If set, only + _included_params (frozenset[str] | None): Internal. If set, only include params whose keys are in this frozenset. Passed down from a parent rule's ``included_params``. - _param_fallbacks (Optional[dict[str, str]]): Internal. Maps a primary + _param_fallbacks (dict[str, str] | None): Internal. Maps a primary param key to a fallback key. When the primary value is falsy, the fallback key's value from raw params is used instead. Passed down from a parent rule's ``param_fallbacks``. @@ -177,7 +177,7 @@ def compute_eval_hash( identifier: ComponentIdentifier, *, child_eval_rules: dict[str, ChildEvalRule], - own_rule: Optional[ChildEvalRule] = None, + own_rule: ChildEvalRule | None = None, ) -> str: """ Compute a behavioral equivalence hash for evaluation grouping. @@ -200,7 +200,7 @@ def compute_eval_hash( identifier (ComponentIdentifier): The component identity to compute the hash for. child_eval_rules (dict[str, ChildEvalRule]): Per-child eval rules. - own_rule (Optional[ChildEvalRule]): Rule applied to the root entity's + own_rule (ChildEvalRule | None): Rule applied to the root entity's own params and fallbacks. 
Only ``included_params`` and ``param_fallbacks`` are honored; ``exclude``, ``included_item_values``, and ``inner_child_name`` are not meaningful at the root and will @@ -252,7 +252,7 @@ class EvaluationIdentifier(ABC): """ CHILD_EVAL_RULES: ClassVar[dict[str, ChildEvalRule]] - OWN_RULE: ClassVar[Optional[ChildEvalRule]] = None + OWN_RULE: ClassVar[ChildEvalRule | None] = None def __init__(self, identifier: ComponentIdentifier) -> None: """ @@ -357,7 +357,7 @@ class ObjectiveTargetEvaluationIdentifier(EvaluationIdentifier): """ CHILD_EVAL_RULES: ClassVar[dict[str, ChildEvalRule]] = {} - OWN_RULE: ClassVar[Optional[ChildEvalRule]] = ChildEvalRule( + OWN_RULE: ClassVar[ChildEvalRule | None] = ChildEvalRule( included_params=TARGET_EVAL_PARAMS, param_fallbacks=TARGET_EVAL_PARAM_FALLBACKS, ) diff --git a/pyrit/models/json_response_config.py b/pyrit/models/json_response_config.py index 8c4c1b9864..b6915526b3 100644 --- a/pyrit/models/json_response_config.py +++ b/pyrit/models/json_response_config.py @@ -4,7 +4,7 @@ from __future__ import annotations import json -from typing import Any, Optional +from typing import Any from pydantic import BaseModel, ConfigDict @@ -30,12 +30,12 @@ class _JsonResponseConfig(BaseModel): model_config = ConfigDict(extra="forbid") enabled: bool = False - json_schema: Optional[dict[str, Any]] = None + json_schema: dict[str, Any] | None = None schema_name: str = "CustomSchema" strict: bool = True @classmethod - def from_metadata(cls, *, metadata: Optional[dict[str, Any]]) -> _JsonResponseConfig: + def from_metadata(cls, *, metadata: dict[str, Any] | None) -> _JsonResponseConfig: if not metadata: return cls(enabled=False) diff --git a/pyrit/models/messages/conversations.py b/pyrit/models/messages/conversations.py index 2b5224f6c6..b225e527b2 100644 --- a/pyrit/models/messages/conversations.py +++ b/pyrit/models/messages/conversations.py @@ -5,7 +5,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Optional, Union +from 
typing import TYPE_CHECKING from pyrit.models.messages.message import Message from pyrit.models.messages.message_piece import MessagePiece @@ -178,7 +178,7 @@ def construct_response_from_request( request: MessagePiece, response_text_pieces: list[str], response_type: PromptDataType = "text", - prompt_metadata: Optional[dict[str, Union[str, int]]] = None, + prompt_metadata: dict[str, str | int] | None = None, error: PromptResponseError = "none", ) -> Message: """ @@ -188,7 +188,7 @@ def construct_response_from_request( request (MessagePiece): Source request message piece. response_text_pieces (list[str]): Response values to include. response_type (PromptDataType): Data type for original and converted response values. - prompt_metadata (Optional[Dict[str, Union[str, int]]]): Additional metadata to merge. + prompt_metadata (dict[str, str | int] | None): Additional metadata to merge. error (PromptResponseError): Error classification for the response. Returns: diff --git a/pyrit/models/messages/message.py b/pyrit/models/messages/message.py index d14f1ec1b9..1c64b8581d 100644 --- a/pyrit/models/messages/message.py +++ b/pyrit/models/messages/message.py @@ -6,7 +6,7 @@ import copy import uuid from datetime import datetime, timezone -from typing import TYPE_CHECKING, Any, Optional, Union +from typing import TYPE_CHECKING, Any from pydantic import BaseModel, ConfigDict, model_validator @@ -195,9 +195,9 @@ def get_piece(self, n: int = 0) -> MessagePiece: def get_pieces_by_type( self, *, - data_type: Optional[PromptDataType] = None, - original_value_data_type: Optional[PromptDataType] = None, - converted_value_data_type: Optional[PromptDataType] = None, + data_type: PromptDataType | None = None, + original_value_data_type: PromptDataType | None = None, + converted_value_data_type: PromptDataType | None = None, ) -> list[MessagePiece]: """ Return all message pieces matching the given data type. 
@@ -222,10 +222,10 @@ def get_pieces_by_type( def get_piece_by_type( self, *, - data_type: Optional[PromptDataType] = None, - original_value_data_type: Optional[PromptDataType] = None, - converted_value_data_type: Optional[PromptDataType] = None, - ) -> Optional[MessagePiece]: + data_type: PromptDataType | None = None, + original_value_data_type: PromptDataType | None = None, + converted_value_data_type: PromptDataType | None = None, + ) -> MessagePiece | None: """ Return the first message piece matching the given data type, or None. @@ -358,7 +358,7 @@ def from_prompt( *, prompt: str, role: ChatMessageRole, - prompt_metadata: Optional[dict[str, Union[str, int]]] = None, + prompt_metadata: dict[str, str | int] | None = None, ) -> Message: """ Build a single-piece message from prompt text. @@ -366,7 +366,7 @@ def from_prompt( Args: prompt (str): Prompt text. role (ChatMessageRole): Role assigned to the message piece. - prompt_metadata (Optional[Dict[str, Union[str, int]]]): Optional prompt metadata. + prompt_metadata (dict[str, str | int] | None): Optional prompt metadata. Returns: Message: Constructed message instance. 
diff --git a/pyrit/models/messages/message_piece.py b/pyrit/models/messages/message_piece.py index a8e6cb7eca..728f736f20 100644 --- a/pyrit/models/messages/message_piece.py +++ b/pyrit/models/messages/message_piece.py @@ -5,7 +5,7 @@ import uuid from datetime import datetime, timezone -from typing import TYPE_CHECKING, Any, Literal, Optional +from typing import TYPE_CHECKING, Any, Literal from uuid import uuid4 from pydantic import ( @@ -102,20 +102,20 @@ class MessagePiece(BaseModel): timestamp: AwareDatetime = Field(default_factory=lambda: datetime.now(tz=timezone.utc)) original_value: str original_value_data_type: PromptDataType = "text" - original_value_sha256: Optional[str] = None + original_value_sha256: str | None = None converted_value: str = "" converted_value_data_type: PromptDataType = "text" - converted_value_sha256: Optional[str] = None + converted_value_sha256: str | None = None response_error: PromptResponseError = "none" originator: Literal["attack", "converter", "undefined", "scorer"] = "undefined" - original_prompt_id: Optional[uuid.UUID] = None + original_prompt_id: uuid.UUID | None = None labels: dict[str, Any] = Field(default_factory=dict) targeted_harm_categories: list[str] = Field(default_factory=list) prompt_metadata: dict[str, Any] = Field(default_factory=dict) converter_identifiers: list[ComponentIdentifierField] = Field(default_factory=list) - prompt_target_identifier: Optional[ComponentIdentifierField] = None - attack_identifier: Optional[ComponentIdentifierField] = None - scorer_identifier: Optional[ComponentIdentifierField] = None + prompt_target_identifier: ComponentIdentifierField | None = None + attack_identifier: ComponentIdentifierField | None = None + scorer_identifier: ComponentIdentifierField | None = None scores: list[Score] = Field(default_factory=list) # When True, the memory layer skips persisting this piece. 
Used for ephemeral diff --git a/pyrit/models/question_answering.py b/pyrit/models/question_answering.py index c468461090..8e1ca76e0f 100644 --- a/pyrit/models/question_answering.py +++ b/pyrit/models/question_answering.py @@ -3,7 +3,7 @@ from __future__ import annotations -from typing import Literal, Union +from typing import Literal from pydantic import BaseModel, ConfigDict @@ -26,7 +26,7 @@ class QuestionAnsweringEntry(BaseModel): model_config = ConfigDict(extra="forbid") question: str answer_type: Literal["int", "float", "str", "bool"] - correct_answer: Union[int, str, float] + correct_answer: int | str | float choices: list[QuestionChoice] def get_correct_answer_text(self) -> str: diff --git a/pyrit/models/retry_event.py b/pyrit/models/retry_event.py index 46a6e79fcf..2ef67f908f 100644 --- a/pyrit/models/retry_event.py +++ b/pyrit/models/retry_event.py @@ -6,7 +6,6 @@ from __future__ import annotations from datetime import datetime, timezone -from typing import Optional from pydantic import BaseModel, Field @@ -29,8 +28,8 @@ class RetryEvent(BaseModel): exception_type: str = "" exception_message: str = "" component_role: str = "" - component_name: Optional[str] = None - endpoint: Optional[str] = None + component_name: str | None = None + endpoint: str | None = None elapsed_seconds: float = 0.0 def to_dict(self) -> dict: diff --git a/pyrit/models/seeds/seed_dataset.py b/pyrit/models/seeds/seed_dataset.py index 1c9f06ae20..1df91444c9 100644 --- a/pyrit/models/seeds/seed_dataset.py +++ b/pyrit/models/seeds/seed_dataset.py @@ -271,7 +271,7 @@ def from_dict(cls, data: dict[str, Any]) -> SeedDataset: ``prompt_group_alias`` into a shared ``prompt_group_id``. Args: - data (Dict[str, Any]): Dataset payload with top-level defaults and seed entries. + data (dict[str, Any]): Dataset payload with top-level defaults and seed entries. Returns: SeedDataset: Constructed dataset. 
diff --git a/pyrit/models/storage_io.py b/pyrit/models/storage_io.py index 5b610f80d8..15a7068c33 100644 --- a/pyrit/models/storage_io.py +++ b/pyrit/models/storage_io.py @@ -7,7 +7,7 @@ from abc import ABC, abstractmethod from enum import Enum from pathlib import Path -from typing import TYPE_CHECKING, Optional, Union +from typing import TYPE_CHECKING from urllib.parse import urlparse import aiofiles @@ -36,41 +36,41 @@ class StorageIO(ABC): """ @abstractmethod - async def read_file_async(self, path: Union[Path, str]) -> bytes: + async def read_file_async(self, path: Path | str) -> bytes: """ Asynchronously reads the file (or blob) from the given path. """ @abstractmethod - async def write_file_async(self, path: Union[Path, str], data: bytes) -> None: + async def write_file_async(self, path: Path | str, data: bytes) -> None: """ Asynchronously writes data to the given path. """ @abstractmethod - async def path_exists_async(self, path: Union[Path, str]) -> bool: + async def path_exists_async(self, path: Path | str) -> bool: """ Asynchronously checks if a file or blob exists at the given path. """ @abstractmethod - async def is_file_async(self, path: Union[Path, str]) -> bool: + async def is_file_async(self, path: Path | str) -> bool: """ Asynchronously checks if the path refers to a file (not a directory or container). """ @abstractmethod - async def create_directory_if_not_exists_async(self, path: Union[Path, str]) -> None: + async def create_directory_if_not_exists_async(self, path: Path | str) -> None: """ Asynchronously creates a directory or equivalent in the storage system if it doesn't exist. """ - async def read_file(self, path: Union[Path, str]) -> bytes: # pyrit-async-suffix-exempt + async def read_file(self, path: Path | str) -> bytes: # pyrit-async-suffix-exempt """ Read a file from storage (deprecated alias of ``read_file_async``). Args: - path (Union[Path, str]): The path to the file. + path (Path | str): The path to the file. 
Returns: bytes: The content of the file. @@ -82,12 +82,12 @@ async def read_file(self, path: Union[Path, str]) -> bytes: # pyrit-async-suffi ) return await self.read_file_async(path) - async def write_file(self, path: Union[Path, str], data: bytes) -> None: # pyrit-async-suffix-exempt + async def write_file(self, path: Path | str, data: bytes) -> None: # pyrit-async-suffix-exempt """ Write data to storage (deprecated alias of ``write_file_async``). Args: - path (Union[Path, str]): The path to the file. + path (Path | str): The path to the file. data (bytes): The content to write to the file. """ print_deprecation_message( @@ -97,12 +97,12 @@ async def write_file(self, path: Union[Path, str], data: bytes) -> None: # pyri ) await self.write_file_async(path, data) - async def path_exists(self, path: Union[Path, str]) -> bool: # pyrit-async-suffix-exempt + async def path_exists(self, path: Path | str) -> bool: # pyrit-async-suffix-exempt """ Check whether a path exists (deprecated alias of ``path_exists_async``). Args: - path (Union[Path, str]): The path to check. + path (Path | str): The path to check. Returns: bool: True if the path exists, False otherwise. @@ -114,12 +114,12 @@ async def path_exists(self, path: Union[Path, str]) -> bool: # pyrit-async-suff ) return await self.path_exists_async(path) - async def is_file(self, path: Union[Path, str]) -> bool: # pyrit-async-suffix-exempt + async def is_file(self, path: Path | str) -> bool: # pyrit-async-suffix-exempt """ Check whether the given path is a file (deprecated alias of ``is_file_async``). Args: - path (Union[Path, str]): The path to check. + path (Path | str): The path to check. Returns: bool: True if the path is a file, False otherwise. 
@@ -131,12 +131,12 @@ async def is_file(self, path: Union[Path, str]) -> bool: # pyrit-async-suffix-e ) return await self.is_file_async(path) - async def create_directory_if_not_exists(self, path: Union[Path, str]) -> None: # pyrit-async-suffix-exempt + async def create_directory_if_not_exists(self, path: Path | str) -> None: # pyrit-async-suffix-exempt """ Create a directory if it does not exist (deprecated alias of ``create_directory_if_not_exists_async``). Args: - path (Union[Path, str]): The directory path to create. + path (Path | str): The directory path to create. """ print_deprecation_message( old_item="pyrit.models.storage_io.StorageIO.create_directory_if_not_exists", @@ -151,12 +151,12 @@ class DiskStorageIO(StorageIO): Implementation of StorageIO for local disk storage. """ - async def read_file_async(self, path: Union[Path, str]) -> bytes: + async def read_file_async(self, path: Path | str) -> bytes: """ Asynchronously reads a file from the local disk. Args: - path (Union[Path, str]): The path to the file. + path (Path | str): The path to the file. Returns: bytes: The content of the file. @@ -166,7 +166,7 @@ async def read_file_async(self, path: Union[Path, str]) -> bytes: async with aiofiles.open(path, "rb") as file: return await file.read() - async def write_file_async(self, path: Union[Path, str], data: bytes) -> None: + async def write_file_async(self, path: Path | str, data: bytes) -> None: """ Asynchronously writes data to a file on the local disk. @@ -179,7 +179,7 @@ async def write_file_async(self, path: Union[Path, str], data: bytes) -> None: async with aiofiles.open(path, "wb") as file: await file.write(data) - async def path_exists_async(self, path: Union[Path, str]) -> bool: + async def path_exists_async(self, path: Path | str) -> bool: """ Check whether a path exists on the local disk. 
@@ -193,7 +193,7 @@ async def path_exists_async(self, path: Union[Path, str]) -> bool: path = self._convert_to_path(path) return path.exists() - async def is_file_async(self, path: Union[Path, str]) -> bool: + async def is_file_async(self, path: Path | str) -> bool: """ Check whether the given path is a file (not a directory). @@ -207,7 +207,7 @@ async def is_file_async(self, path: Union[Path, str]) -> bool: path = self._convert_to_path(path) return path.is_file() - async def create_directory_if_not_exists_async(self, path: Union[Path, str]) -> None: + async def create_directory_if_not_exists_async(self, path: Path | str) -> None: """ Asynchronously creates a directory if it doesn't exist on the local disk. @@ -219,12 +219,12 @@ async def create_directory_if_not_exists_async(self, path: Union[Path, str]) -> if not directory_path.exists(): directory_path.mkdir(parents=True, exist_ok=True) - def _convert_to_path(self, path: Union[Path, str]) -> Path: + def _convert_to_path(self, path: Path | str) -> Path: """ Convert an input path to a Path object. Args: - path (Union[Path, str]): Input path value. + path (Path | str): Input path value. Returns: Path: Normalized Path instance. @@ -241,16 +241,16 @@ class AzureBlobStorageIO(StorageIO): def __init__( self, *, - container_url: Optional[str] = None, - sas_token: Optional[str] = None, + container_url: str | None = None, + sas_token: str | None = None, blob_content_type: SupportedContentType = SupportedContentType.PLAIN_TEXT, ) -> None: """ Initialize an Azure Blob Storage I/O adapter. Args: - container_url (Optional[str]): Azure Blob container URL. - sas_token (Optional[str]): Optional SAS token. + container_url (str | None): Azure Blob container URL. + sas_token (str | None): Optional SAS token. blob_content_type (SupportedContentType): Blob content type for uploads. 
Raises: @@ -351,7 +351,7 @@ def parse_blob_url(self, file_path: str) -> tuple[str, str]: return container_name, blob_name raise ValueError("Invalid blob URL") - def _resolve_blob_name(self, path: Union[Path, str]) -> str: + def _resolve_blob_name(self, path: Path | str) -> str: """ Resolve a blob name from either a full blob URL or a relative blob path. @@ -363,7 +363,7 @@ def _resolve_blob_name(self, path: Union[Path, str]) -> str: created on Windows still produce valid blob names. Args: - path (Union[Path, str]): Blob URL or relative blob path. + path (Path | str): Blob URL or relative blob path. Returns: str: The resolved blob name. @@ -377,7 +377,7 @@ def _resolve_blob_name(self, path: Union[Path, str]) -> str: except ValueError: return path_str - async def read_file_async(self, path: Union[Path, str]) -> bytes: + async def read_file_async(self, path: Path | str) -> bytes: """ Asynchronously reads the content of a file (blob) from Azure Blob Storage. @@ -420,7 +420,7 @@ async def read_file_async(self, path: Union[Path, str]) -> bytes: await self._client_async.close() self._client_async = None - async def write_file_async(self, path: Union[Path, str], data: bytes) -> None: + async def write_file_async(self, path: Path | str, data: bytes) -> None: """ Write data to Azure Blob Storage at the specified path. @@ -428,7 +428,7 @@ async def write_file_async(self, path: Union[Path, str], data: bytes) -> None: If a relative path is provided, it is used as the blob name directly. Args: - path (Union[Path, str]): Full blob URL or relative blob path. + path (Path | str): Full blob URL or relative blob path. data (bytes): The data to write. 
""" if not self._client_async: @@ -443,12 +443,12 @@ async def write_file_async(self, path: Union[Path, str], data: bytes) -> None: await self._client_async.close() self._client_async = None - async def path_exists_async(self, path: Union[Path, str]) -> bool: + async def path_exists_async(self, path: Path | str) -> bool: """ Check whether a given path exists in the Azure Blob Storage container. Args: - path (Union[Path, str]): Blob URL or path to test. + path (Path | str): Blob URL or path to test. Returns: bool: True when the path exists. @@ -468,12 +468,12 @@ async def path_exists_async(self, path: Union[Path, str]) -> bool: await self._client_async.close() self._client_async = None - async def is_file_async(self, path: Union[Path, str]) -> bool: + async def is_file_async(self, path: Path | str) -> bool: """ Check whether the path refers to a file (blob) in Azure Blob Storage. Args: - path (Union[Path, str]): Blob URL or path to test. + path (Path | str): Blob URL or path to test. Returns: bool: True when the blob exists and has non-zero content size. @@ -493,12 +493,12 @@ async def is_file_async(self, path: Union[Path, str]) -> bool: await self._client_async.close() self._client_async = None - async def create_directory_if_not_exists_async(self, directory_path: Union[Path, str]) -> None: # type: ignore[ty:invalid-method-override] + async def create_directory_if_not_exists_async(self, directory_path: Path | str) -> None: # type: ignore[ty:invalid-method-override] """ Log a no-op directory creation for Azure Blob Storage. Args: - directory_path (Union[Path, str]): Requested directory path. + directory_path (Path | str): Requested directory path. 
""" logger.info( diff --git a/pyrit/prompt_converter/add_image_to_video_converter.py b/pyrit/prompt_converter/add_image_to_video_converter.py index 8b7ce67fa6..0fceec3578 100644 --- a/pyrit/prompt_converter/add_image_to_video_converter.py +++ b/pyrit/prompt_converter/add_image_to_video_converter.py @@ -5,7 +5,6 @@ import contextlib import logging from pathlib import Path -from typing import Optional import numpy as np @@ -43,7 +42,7 @@ class AddImageVideoConverter(PromptConverter): def __init__( self, video_path: str, - output_path: Optional[str] = None, + output_path: str | None = None, img_position: tuple[int, int] = (10, 10), img_resize_size: tuple[int, int] = (500, 500), ) -> None: @@ -151,9 +150,9 @@ def _add_image_to_video_sync( import cv2 video_path = self._video_path - local_temp_path: Optional[Path] = None - cap: Optional[cv2.VideoCapture] = None - output_video: Optional[cv2.VideoWriter] = None + local_temp_path: Path | None = None + cap: cv2.VideoCapture | None = None + output_video: cv2.VideoWriter | None = None try: if azure_storage_flag: diff --git a/pyrit/prompt_converter/ask_to_decode_converter.py b/pyrit/prompt_converter/ask_to_decode_converter.py index 5542050df8..ea95e28ffd 100644 --- a/pyrit/prompt_converter/ask_to_decode_converter.py +++ b/pyrit/prompt_converter/ask_to_decode_converter.py @@ -2,7 +2,6 @@ # Licensed under the MIT license. import random -from typing import Optional from pyrit.models import PromptDataType from pyrit.prompt_converter.prompt_converter import ConverterResult, PromptConverter @@ -44,7 +43,7 @@ class AskToDecodeConverter(PromptConverter): # TODO: remove this opt-out and insert ``*,`` after ``self`` in 0.16.0. _brick_legacy_init = True - def __init__(self, template: Optional[str] = None, encoding_name: str = "cipher") -> None: + def __init__(self, template: str | None = None, encoding_name: str = "cipher") -> None: """ Initialize the converter with a specified encoding name and template. 
diff --git a/pyrit/prompt_converter/azure_speech_audio_to_text_converter.py b/pyrit/prompt_converter/azure_speech_audio_to_text_converter.py index ae4190d59a..d4c6698aae 100644 --- a/pyrit/prompt_converter/azure_speech_audio_to_text_converter.py +++ b/pyrit/prompt_converter/azure_speech_audio_to_text_converter.py @@ -5,7 +5,7 @@ import logging import time from collections.abc import Awaitable, Callable -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any if TYPE_CHECKING: import azure.cognitiveservices.speech as speechsdk @@ -47,10 +47,10 @@ class AzureSpeechAudioToTextConverter(PromptConverter): def __init__( self, *, - azure_speech_region: Optional[str] = None, - azure_speech_key: Optional[str | Callable[[], str | Awaitable[str]]] = None, - azure_speech_resource_id: Optional[str] = None, - use_entra_auth: Optional[bool] = None, + azure_speech_region: str | None = None, + azure_speech_key: str | Callable[[], str | Awaitable[str]] | None = None, + azure_speech_resource_id: str | None = None, + use_entra_auth: bool | None = None, recognition_language: str = "en-US", ) -> None: """ diff --git a/pyrit/prompt_converter/azure_speech_text_to_audio_converter.py b/pyrit/prompt_converter/azure_speech_text_to_audio_converter.py index 03463a3b62..98b107f7b3 100644 --- a/pyrit/prompt_converter/azure_speech_text_to_audio_converter.py +++ b/pyrit/prompt_converter/azure_speech_text_to_audio_converter.py @@ -3,7 +3,7 @@ import logging from collections.abc import Awaitable, Callable -from typing import TYPE_CHECKING, Literal, Optional +from typing import TYPE_CHECKING, Literal if TYPE_CHECKING: import azure.cognitiveservices.speech as speechsdk # noqa: F401 @@ -48,10 +48,10 @@ class AzureSpeechTextToAudioConverter(PromptConverter): def __init__( self, *, - azure_speech_region: Optional[str] = None, - azure_speech_key: Optional[str | Callable[[], str | Awaitable[str]]] = None, - azure_speech_resource_id: Optional[str] = None, - use_entra_auth: 
Optional[bool] = None, + azure_speech_region: str | None = None, + azure_speech_key: str | Callable[[], str | Awaitable[str]] | None = None, + azure_speech_resource_id: str | None = None, + use_entra_auth: bool | None = None, synthesis_language: str = "en_US", synthesis_voice_name: str = "en-US-AvaNeural", output_format: AzureSpeechAudioFormat = "wav", diff --git a/pyrit/prompt_converter/bin_ascii_converter.py b/pyrit/prompt_converter/bin_ascii_converter.py index f06971d9ee..f5743c3143 100644 --- a/pyrit/prompt_converter/bin_ascii_converter.py +++ b/pyrit/prompt_converter/bin_ascii_converter.py @@ -2,7 +2,7 @@ # Licensed under the MIT license. import binascii -from typing import Literal, Optional +from typing import Literal from pyrit.models import ComponentIdentifier from pyrit.prompt_converter.text_selection_strategy import ( @@ -29,8 +29,8 @@ def __init__( self, *, encoding_func: EncodingFunc = "hex", - word_selection_strategy: Optional[WordSelectionStrategy] = None, - word_split_separator: Optional[str] = " ", + word_selection_strategy: WordSelectionStrategy | None = None, + word_split_separator: str | None = " ", ) -> None: """ Initialize the BinAsciiConverter. @@ -38,9 +38,9 @@ def __init__( Args: encoding_func (str): The encoding function to use. Options: "hex", "quoted-printable", "UUencode". Defaults to "hex". - word_selection_strategy (Optional[WordSelectionStrategy]): Strategy for selecting which words to convert. + word_selection_strategy (WordSelectionStrategy | None): Strategy for selecting which words to convert. If None, all words will be converted. - word_split_separator (Optional[str]): Separator used to split words in the input text. + word_split_separator (str | None): Separator used to split words in the input text. Defaults to " ". 
Raises: diff --git a/pyrit/prompt_converter/binary_converter.py b/pyrit/prompt_converter/binary_converter.py index c17ce73234..6a9a6425d7 100644 --- a/pyrit/prompt_converter/binary_converter.py +++ b/pyrit/prompt_converter/binary_converter.py @@ -4,7 +4,7 @@ from __future__ import annotations from enum import Enum -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from pyrit.prompt_converter.word_level_converter import WordLevelConverter @@ -29,7 +29,7 @@ def __init__( self, *, bits_per_char: BinaryConverter.BitsPerChar = BitsPerChar.BITS_16, - word_selection_strategy: Optional[WordSelectionStrategy] = None, + word_selection_strategy: WordSelectionStrategy | None = None, ) -> None: """ Initialize the converter with the specified bits per character and selection strategy. @@ -37,7 +37,7 @@ def __init__( Args: bits_per_char (BinaryConverter.BitsPerChar): Number of bits to use for each character (8, 16, or 32). Default is 16 bits. - word_selection_strategy (Optional[WordSelectionStrategy]): Strategy for selecting which words to convert. + word_selection_strategy (WordSelectionStrategy | None): Strategy for selecting which words to convert. If None, all words will be converted. Raises: diff --git a/pyrit/prompt_converter/charswap_attack_converter.py b/pyrit/prompt_converter/charswap_attack_converter.py index cf91c75a88..e58f84d454 100644 --- a/pyrit/prompt_converter/charswap_attack_converter.py +++ b/pyrit/prompt_converter/charswap_attack_converter.py @@ -3,7 +3,6 @@ import random import string -from typing import Optional from pyrit.models import ComponentIdentifier from pyrit.prompt_converter.text_selection_strategy import ( @@ -22,7 +21,7 @@ def __init__( self, *, max_iterations: int = 10, - word_selection_strategy: Optional[WordSelectionStrategy] = None, + word_selection_strategy: WordSelectionStrategy | None = None, ) -> None: """ Initialize the converter with the specified parameters. 
@@ -32,7 +31,7 @@ def __init__( Args: max_iterations (int): Number of times to generate perturbed prompts. The higher the number the higher the chance that words are different from the original prompt. - word_selection_strategy (Optional[WordSelectionStrategy]): Strategy for selecting which words to convert. + word_selection_strategy (WordSelectionStrategy | None): Strategy for selecting which words to convert. If None, defaults to WordProportionSelectionStrategy(proportion=0.2). Raises: diff --git a/pyrit/prompt_converter/codechameleon_converter.py b/pyrit/prompt_converter/codechameleon_converter.py index a8d932ec6a..6079cd18f6 100644 --- a/pyrit/prompt_converter/codechameleon_converter.py +++ b/pyrit/prompt_converter/codechameleon_converter.py @@ -7,7 +7,7 @@ import re import textwrap from collections.abc import Callable -from typing import Any, Optional +from typing import Any from pyrit.common.path import CONVERTER_SEED_PROMPT_PATH from pyrit.models import ComponentIdentifier, PromptDataType, SeedPrompt @@ -52,8 +52,8 @@ def __init__( self, *, encrypt_type: str, - encrypt_function: Optional[Callable[..., Any]] = None, - decrypt_function: Optional[Callable[..., Any] | list[Callable[..., Any] | str]] = None, + encrypt_function: Callable[..., Any] | None = None, + decrypt_function: Callable[..., Any] | list[Callable[..., Any] | str] | None = None, ) -> None: """ Initialize the converter with the specified encryption type and optional functions. @@ -163,10 +163,10 @@ class TreeNode: def __init__(self, value: str) -> None: self.value = value - self.left: Optional[TreeNode] = None - self.right: Optional[TreeNode] = None + self.left: TreeNode | None = None + self.right: TreeNode | None = None - def build_tree(words: list[str], start: int, end: int) -> Optional[TreeNode]: + def build_tree(words: list[str], start: int, end: int) -> TreeNode | None: """ Recursively build a balanced binary tree from a sublist of words. 
@@ -189,7 +189,7 @@ def build_tree(words: list[str], start: int, end: int) -> Optional[TreeNode]: return node - def tree_to_json(node: Optional[TreeNode]) -> Optional[dict[str, Any]]: + def tree_to_json(node: TreeNode | None) -> dict[str, Any] | None: """ Convert a tree to a JSON representation. diff --git a/pyrit/prompt_converter/colloquial_wordswap_converter.py b/pyrit/prompt_converter/colloquial_wordswap_converter.py index e6118a4040..65dd641928 100644 --- a/pyrit/prompt_converter/colloquial_wordswap_converter.py +++ b/pyrit/prompt_converter/colloquial_wordswap_converter.py @@ -4,7 +4,6 @@ import pathlib import random import re -from typing import Optional import yaml @@ -28,8 +27,8 @@ def __init__( self, *, deterministic: bool = False, - custom_substitutions: Optional[dict[str, list[str]]] = None, - wordswap_path: Optional[str] = None, + custom_substitutions: dict[str, list[str]] | None = None, + wordswap_path: str | None = None, ) -> None: """ Initialize the converter with optional deterministic mode and substitutions source. @@ -37,9 +36,9 @@ def __init__( Args: deterministic (bool): If True, use the first substitution for each wordswap. If False, randomly choose a substitution for each wordswap. Defaults to False. - custom_substitutions (Optional[dict[str, list[str]]]): A dictionary of custom substitutions + custom_substitutions (dict[str, list[str]] | None): A dictionary of custom substitutions to override the defaults. Defaults to None. - wordswap_path (Optional[str]): Path to a YAML file containing word substitutions. + wordswap_path (str | None): Path to a YAML file containing word substitutions. Can be a filename within the built-in colloquial_wordswaps directory (e.g., "filipino.yaml") or an absolute path to a custom YAML file. Defaults to None (uses singaporean.yaml). 
diff --git a/pyrit/prompt_converter/denylist_converter.py b/pyrit/prompt_converter/denylist_converter.py index b9a0ee2c7c..0d685a9d0c 100644 --- a/pyrit/prompt_converter/denylist_converter.py +++ b/pyrit/prompt_converter/denylist_converter.py @@ -3,7 +3,6 @@ import logging import pathlib -from typing import Optional from pyrit.common.apply_defaults import REQUIRED_VALUE, apply_defaults from pyrit.common.path import CONVERTER_SEED_PROMPT_PATH @@ -27,7 +26,7 @@ def __init__( self, *, converter_target: PromptTarget = REQUIRED_VALUE, # type: ignore[ty:invalid-parameter-default] - system_prompt_template: Optional[SeedPrompt] = None, + system_prompt_template: SeedPrompt | None = None, denylist: list[str] | None = None, ) -> None: """ @@ -36,7 +35,7 @@ def __init__( Args: converter_target (PromptTarget): The target for the prompt conversion. Can be omitted if a default has been configured via PyRIT initialization. - system_prompt_template (Optional[SeedPrompt]): The system prompt template to use for the conversion. + system_prompt_template (SeedPrompt | None): The system prompt template to use for the conversion. If not provided, a default template will be used. denylist (list[str]): A list of words or phrases that should be replaced in the prompt. """ diff --git a/pyrit/prompt_converter/first_letter_converter.py b/pyrit/prompt_converter/first_letter_converter.py index ce6058ac68..b3c13db2ce 100644 --- a/pyrit/prompt_converter/first_letter_converter.py +++ b/pyrit/prompt_converter/first_letter_converter.py @@ -1,7 +1,6 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. 
-from typing import Optional from pyrit.models import ComponentIdentifier from pyrit.prompt_converter.text_selection_strategy import WordSelectionStrategy @@ -18,14 +17,14 @@ def __init__( self, *, letter_separator: str = " ", - word_selection_strategy: Optional[WordSelectionStrategy] = None, + word_selection_strategy: WordSelectionStrategy | None = None, ) -> None: """ Initialize the converter with the specified letter separator and selection strategy. Args: letter_separator (str): The string used to join the first letters. - word_selection_strategy (Optional[WordSelectionStrategy]): Strategy for selecting which words to convert. + word_selection_strategy (WordSelectionStrategy | None): Strategy for selecting which words to convert. If None, all words will be converted. """ super().__init__(word_selection_strategy=word_selection_strategy, word_split_separator=None) diff --git a/pyrit/prompt_converter/image_compression_converter.py b/pyrit/prompt_converter/image_compression_converter.py index 57fc13b856..da8c01e91f 100644 --- a/pyrit/prompt_converter/image_compression_converter.py +++ b/pyrit/prompt_converter/image_compression_converter.py @@ -4,7 +4,7 @@ import base64 import logging from io import BytesIO -from typing import Any, Literal, Optional +from typing import Any, Literal from urllib.parse import urlparse import aiohttp @@ -49,13 +49,13 @@ class ImageCompressionConverter(PromptConverter): def __init__( self, *, - output_format: Optional[Literal["JPEG", "PNG", "WEBP"]] = None, - quality: Optional[int] = None, - optimize: Optional[bool] = None, - progressive: Optional[bool] = None, - compress_level: Optional[int] = None, - lossless: Optional[bool] = None, - method: Optional[int] = None, + output_format: Literal["JPEG", "PNG", "WEBP"] | None = None, + quality: int | None = None, + optimize: bool | None = None, + progressive: bool | None = None, + compress_level: int | None = None, + lossless: bool | None = None, + method: int | None = None, background_color: 
tuple[int, int, int] = (0, 0, 0), min_compression_threshold: int = 1024, fallback_to_original: bool = True, diff --git a/pyrit/prompt_converter/insert_punctuation_converter.py b/pyrit/prompt_converter/insert_punctuation_converter.py index 668917c69c..1a965e1bd0 100644 --- a/pyrit/prompt_converter/insert_punctuation_converter.py +++ b/pyrit/prompt_converter/insert_punctuation_converter.py @@ -4,7 +4,6 @@ import random import re import string -from typing import Optional from pyrit.models import ComponentIdentifier, PromptDataType from pyrit.prompt_converter.prompt_converter import ConverterResult, PromptConverter @@ -69,7 +68,7 @@ def _is_valid_punctuation(self, punctuation_list: list[str]) -> bool: Space, letters, numbers, double punctuations are all invalid. Args: - punctuation_list (List[str]): List of punctuations to validate. + punctuation_list (list[str]): List of punctuations to validate. Returns: bool: valid list and valid punctuations @@ -77,7 +76,7 @@ def _is_valid_punctuation(self, punctuation_list: list[str]) -> bool: return all(char in string.punctuation for char in punctuation_list) async def convert_async( - self, *, prompt: str, input_type: PromptDataType = "text", punctuation_list: Optional[list[str]] = None + self, *, prompt: str, input_type: PromptDataType = "text", punctuation_list: list[str] | None = None ) -> ConverterResult: """ Convert the given prompt by inserting punctuation. @@ -85,7 +84,7 @@ async def convert_async( Args: prompt (str): The text to convert. input_type (PromptDataType): The type of input data. - punctuation_list (Optional[List[str]]): List of punctuations to use for insertion. + punctuation_list (list[str] | None): List of punctuations to use for insertion. Returns: ConverterResult: The result containing an iteration of modified prompts. @@ -115,7 +114,7 @@ def _insert_punctuation(self, prompt: str, punctuation_list: list[str]) -> str: Args: prompt (str): The text to modify. 
- punctuation_list (List[str]): List of punctuations for insertion. + punctuation_list (list[str]): List of punctuations for insertion. Returns: str: The modified prompt with inserted punctuation from helper method. @@ -144,10 +143,10 @@ def _insert_between_words( Insert punctuation between words in the prompt. Args: - words (List[str]): List of words and punctuations. - word_indices (List[int]): Indices of the actual words without punctuations in words list. + words (list[str]): List of words and punctuations. + word_indices (list[int]): Indices of the actual words without punctuations in words list. num_insertions (int): Number of punctuations to insert. - punctuation_list (List[str]): punctuations for insertion. + punctuation_list (list[str]): punctuations for insertion. Returns: str: The modified prompt with inserted punctuation. @@ -171,7 +170,7 @@ def _insert_within_words(self, prompt: str, num_insertions: int, punctuation_lis Args: prompt (str): The prompt string num_insertions (int): Number of punctuations to insert. - punctuation_list (List[str]): punctuations for insertion. + punctuation_list (list[str]): punctuations for insertion. Returns: str: The modified prompt with inserted punctuation. diff --git a/pyrit/prompt_converter/leetspeak_converter.py b/pyrit/prompt_converter/leetspeak_converter.py index a904d804a7..1e8810cd29 100644 --- a/pyrit/prompt_converter/leetspeak_converter.py +++ b/pyrit/prompt_converter/leetspeak_converter.py @@ -2,7 +2,6 @@ # Licensed under the MIT license. 
import random -from typing import Optional from pyrit.models import ComponentIdentifier from pyrit.prompt_converter.text_selection_strategy import WordSelectionStrategy @@ -18,8 +17,8 @@ def __init__( self, *, deterministic: bool = True, - custom_substitutions: Optional[dict[str, list[str]]] = None, - word_selection_strategy: Optional[WordSelectionStrategy] = None, + custom_substitutions: dict[str, list[str]] | None = None, + word_selection_strategy: WordSelectionStrategy | None = None, ) -> None: """ Initialize the converter with optional deterministic mode and custom substitutions. @@ -27,8 +26,8 @@ def __init__( Args: deterministic (bool): If True, use the first substitution for each character. If False, randomly choose a substitution for each character. - custom_substitutions (Optional[dict]): A dictionary of custom substitutions to override the defaults. - word_selection_strategy (Optional[WordSelectionStrategy]): Strategy for selecting which words to convert. + custom_substitutions (dict | None): A dictionary of custom substitutions to override the defaults. + word_selection_strategy (WordSelectionStrategy | None): Strategy for selecting which words to convert. If None, all words will be converted. 
""" super().__init__(word_selection_strategy=word_selection_strategy) diff --git a/pyrit/prompt_converter/malicious_question_generator_converter.py b/pyrit/prompt_converter/malicious_question_generator_converter.py index 7e8b64a0d8..e35270bce3 100644 --- a/pyrit/prompt_converter/malicious_question_generator_converter.py +++ b/pyrit/prompt_converter/malicious_question_generator_converter.py @@ -3,7 +3,6 @@ import logging import pathlib -from typing import Optional from pyrit.common.apply_defaults import REQUIRED_VALUE, apply_defaults from pyrit.common.path import CONVERTER_SEED_PROMPT_PATH @@ -27,7 +26,7 @@ def __init__( self, *, converter_target: PromptTarget = REQUIRED_VALUE, # type: ignore[ty:invalid-parameter-default] - prompt_template: Optional[SeedPrompt] = None, + prompt_template: SeedPrompt | None = None, ) -> None: """ Initialize the converter with a specific target and template. diff --git a/pyrit/prompt_converter/math_obfuscation_converter.py b/pyrit/prompt_converter/math_obfuscation_converter.py index 870d7fc03a..ee9ba737e9 100644 --- a/pyrit/prompt_converter/math_obfuscation_converter.py +++ b/pyrit/prompt_converter/math_obfuscation_converter.py @@ -3,7 +3,6 @@ import logging import random -from typing import Optional from pyrit.models import ComponentIdentifier, PromptDataType from pyrit.prompt_converter.prompt_converter import ConverterResult, PromptConverter @@ -44,9 +43,9 @@ def __init__( *, min_n: int = 2, max_n: int = 9, - hint: Optional[str] = None, - suffix: Optional[str] = None, - rng: Optional[random.Random] = None, + hint: str | None = None, + suffix: str | None = None, + rng: random.Random | None = None, ) -> None: """ Initialize a MathObfuscationConverter instance. @@ -58,15 +57,15 @@ def __init__( max_n (int): Maximum integer value used for `n`. Must be greater than or equal to `min_n`. - hint (Optional[str]): + hint (str | None): Inline hint appended to the first equation line. 
If None, uses the default hint explaining the variable encoding. Set to empty string "" to disable hint entirely. - suffix (Optional[str]): + suffix (str | None): Custom suffix to append after the obfuscated text. If None, uses the default suffix prompting the model to decode. Set to empty string "" to disable suffix entirely. - rng (Optional[random.Random]): + rng (random.Random | None): Optional random number generator instance used to produce reproducible obfuscation results. If omitted, a new instance of `random.Random()` is created. diff --git a/pyrit/prompt_converter/math_prompt_converter.py b/pyrit/prompt_converter/math_prompt_converter.py index ce9a4a4246..0fa29f4294 100644 --- a/pyrit/prompt_converter/math_prompt_converter.py +++ b/pyrit/prompt_converter/math_prompt_converter.py @@ -3,7 +3,6 @@ import logging import pathlib -from typing import Optional from pyrit.common.apply_defaults import REQUIRED_VALUE, apply_defaults from pyrit.common.path import CONVERTER_SEED_PROMPT_PATH @@ -27,7 +26,7 @@ def __init__( self, *, converter_target: PromptTarget = REQUIRED_VALUE, # type: ignore[ty:invalid-parameter-default] - prompt_template: Optional[SeedPrompt] = None, + prompt_template: SeedPrompt | None = None, ) -> None: """ Initialize the converter with a specific target and template. 
diff --git a/pyrit/prompt_converter/noise_converter.py b/pyrit/prompt_converter/noise_converter.py index cfa4b83fd7..77e418e731 100644 --- a/pyrit/prompt_converter/noise_converter.py +++ b/pyrit/prompt_converter/noise_converter.py @@ -4,7 +4,6 @@ import logging import pathlib import textwrap -from typing import Optional from pyrit.common.apply_defaults import REQUIRED_VALUE, apply_defaults from pyrit.common.path import CONVERTER_SEED_PROMPT_PATH @@ -27,9 +26,9 @@ def __init__( self, *, converter_target: PromptTarget = REQUIRED_VALUE, # type: ignore[ty:invalid-parameter-default] - noise: Optional[str] = None, + noise: str | None = None, number_errors: int = 5, - prompt_template: Optional[SeedPrompt] = None, + prompt_template: SeedPrompt | None = None, ) -> None: """ Initialize the converter with the specified parameters. diff --git a/pyrit/prompt_converter/pdf_converter.py b/pyrit/prompt_converter/pdf_converter.py index bfb46ffe8e..fd75b05242 100644 --- a/pyrit/prompt_converter/pdf_converter.py +++ b/pyrit/prompt_converter/pdf_converter.py @@ -5,7 +5,7 @@ import hashlib from io import BytesIO from pathlib import Path -from typing import Any, Optional +from typing import Any from pypdf import PageObject, PdfReader, PdfWriter from reportlab.lib.units import mm @@ -42,7 +42,7 @@ class PDFConverter(PromptConverter): def __init__( self, - prompt_template: Optional[SeedPrompt] = None, + prompt_template: SeedPrompt | None = None, font_type: str = "Helvetica", font_size: int = 12, font_color: tuple[int, int, int] = (255, 255, 255), @@ -50,14 +50,14 @@ def __init__( page_height: int = 297, column_width: int = 0, row_height: int = 10, - existing_pdf: Optional[Path] = None, - injection_items: Optional[list[dict[str, Any]]] = None, + existing_pdf: Path | None = None, + injection_items: list[dict[str, Any]] | None = None, ) -> None: """ Initialize the converter with the specified parameters. 
Args: - prompt_template (Optional[SeedPrompt], optional): A ``SeedPrompt`` object representing a template. + prompt_template (SeedPrompt | None, optional): A ``SeedPrompt`` object representing a template. font_type (str): Font type for the PDF. Defaults to "Helvetica". font_size (int): Font size for the PDF. Defaults to 12. font_color (tuple): Font color for the PDF in RGB format. Defaults to (255, 255, 255). @@ -65,8 +65,8 @@ def __init__( page_height (int): Height of the PDF page in mm. Defaults to 297 (A4 height). column_width (int): Width of each column in the PDF. Defaults to 0 (full page width). row_height (int): Height of each row in the PDF. Defaults to 10. - existing_pdf (Optional[Path], optional): Path to an existing PDF file. Defaults to None. - injection_items (Optional[List[Dict]], optional): A list of injection items for modifying an existing PDF. + existing_pdf (Path | None, optional): Path to an existing PDF file. Defaults to None. + injection_items (list[dict[str, Any]] | None, optional): A list of injection items for modifying an existing PDF. Raises: ValueError: If the font color is invalid or the injection items are not provided as a list of dictionaries.
@@ -82,9 +82,9 @@ def __init__( self._row_height = row_height # Keeping the user's path here - self._existing_pdf_path: Optional[Path] = existing_pdf + self._existing_pdf_path: Path | None = existing_pdf # We store the file data in a separate BytesIO for type checker compatibility - self._existing_pdf_bytes: Optional[BytesIO] = None + self._existing_pdf_bytes: BytesIO | None = None self._injection_items = injection_items or [] diff --git a/pyrit/prompt_converter/prompt_converter.py b/pyrit/prompt_converter/prompt_converter.py index 98a80e5c25..282f2b01a3 100644 --- a/pyrit/prompt_converter/prompt_converter.py +++ b/pyrit/prompt_converter/prompt_converter.py @@ -6,7 +6,7 @@ import inspect import re from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, ClassVar, Optional, Union, get_args +from typing import TYPE_CHECKING, Any, ClassVar, Optional, get_args from pyrit import prompt_converter from pyrit.models import ComponentIdentifier, Identifiable, PromptDataType @@ -56,7 +56,7 @@ class PromptConverter(Identifiable): #: ``super().__init__(converter_target=...)`` so the base class can validate it. TARGET_REQUIREMENTS: ClassVar[TargetRequirements] = TargetRequirements() - _identifier: Optional[ComponentIdentifier] = None + _identifier: ComponentIdentifier | None = None def __init_subclass__(cls, **kwargs: object) -> None: """ @@ -94,7 +94,7 @@ def __init__(self, *, converter_target: Optional["PromptTarget"] = None) -> None Initialize the prompt converter. Args: - converter_target (Optional[PromptTarget]): Target used by the converter, if any. When + converter_target (PromptTarget | None): Target used by the converter, if any. When provided, it is validated against ``TARGET_REQUIREMENTS``. 
""" super().__init__() @@ -201,8 +201,8 @@ def _build_identifier(self) -> ComponentIdentifier: def _create_identifier( self, *, - params: Optional[dict[str, Any]] = None, - children: Optional[dict[str, Union[ComponentIdentifier, list[ComponentIdentifier]]]] = None, + params: dict[str, Any] | None = None, + children: dict[str, ComponentIdentifier | list[ComponentIdentifier]] | None = None, ) -> ComponentIdentifier: """ Construct and return the converter identifier. @@ -215,9 +215,9 @@ def _create_identifier( to set the identifier with their specific parameters. Args: - params (Optional[Dict[str, Any]]): Additional behavioral parameters from + params (dict[str, Any] | None): Additional behavioral parameters from the subclass (e.g., font, encoding_func). Merged into the base params. - children (Optional[Dict[str, Union[ComponentIdentifier, List[ComponentIdentifier]]]]): + children (dict[str, ComponentIdentifier | list[ComponentIdentifier]] | None): Named child component identifiers (e.g., sub-converters, converter targets). Returns: diff --git a/pyrit/prompt_converter/qr_code_converter.py b/pyrit/prompt_converter/qr_code_converter.py index aa8db06ed0..b0515fd92d 100644 --- a/pyrit/prompt_converter/qr_code_converter.py +++ b/pyrit/prompt_converter/qr_code_converter.py @@ -1,7 +1,6 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. 
-from typing import Optional import segno @@ -25,11 +24,11 @@ def __init__( border: int = 4, dark_color: tuple[int, int, int] = (0, 0, 0), light_color: tuple[int, int, int] = (255, 255, 255), - data_dark_color: Optional[tuple[int, int, int]] = None, - data_light_color: Optional[tuple[int, int, int]] = None, - finder_dark_color: Optional[tuple[int, int, int]] = None, - finder_light_color: Optional[tuple[int, int, int]] = None, - border_color: Optional[tuple[int, int, int]] = None, + data_dark_color: tuple[int, int, int] | None = None, + data_light_color: tuple[int, int, int] | None = None, + finder_dark_color: tuple[int, int, int] | None = None, + finder_light_color: tuple[int, int, int] | None = None, + border_color: tuple[int, int, int] | None = None, ) -> None: """ Initialize the converter with specified parameters for QR code generation. diff --git a/pyrit/prompt_converter/random_translation_converter.py b/pyrit/prompt_converter/random_translation_converter.py index 769cb51611..88e8777246 100644 --- a/pyrit/prompt_converter/random_translation_converter.py +++ b/pyrit/prompt_converter/random_translation_converter.py @@ -4,7 +4,6 @@ import logging import random from pathlib import Path -from typing import Optional from pyrit.common.apply_defaults import REQUIRED_VALUE, apply_defaults from pyrit.common.path import CONVERTER_SEED_PROMPT_PATH, DATASETS_PATH @@ -36,9 +35,9 @@ def __init__( self, *, converter_target: PromptTarget = REQUIRED_VALUE, # type: ignore[ty:invalid-parameter-default] - system_prompt_template: Optional[SeedPrompt] = None, - languages: Optional[list[str]] = None, - word_selection_strategy: Optional[WordSelectionStrategy] = None, + system_prompt_template: SeedPrompt | None = None, + languages: list[str] | None = None, + word_selection_strategy: WordSelectionStrategy | None = None, ) -> None: """ Initialize the converter with a target, an optional system prompt template, and language options. 
@@ -46,10 +45,10 @@ def __init__( Args: converter_target (PromptTarget): The target for the prompt conversion. Can be omitted if a default has been configured via PyRIT initialization. - system_prompt_template (Optional[SeedPrompt]): The system prompt template to use for the conversion. + system_prompt_template (SeedPrompt | None): The system prompt template to use for the conversion. If not provided, a default template will be used. - languages (Optional[List[str]]): The list of available languages to use for translation. - word_selection_strategy (Optional[WordSelectionStrategy]): Strategy for selecting which words to convert. + languages (list[str] | None): The list of available languages to use for translation. + word_selection_strategy (WordSelectionStrategy | None): Strategy for selecting which words to convert. If None, all words will be converted. Raises: diff --git a/pyrit/prompt_converter/repeat_token_converter.py b/pyrit/prompt_converter/repeat_token_converter.py index c2b5a74825..c711448ca4 100644 --- a/pyrit/prompt_converter/repeat_token_converter.py +++ b/pyrit/prompt_converter/repeat_token_converter.py @@ -2,7 +2,7 @@ # Licensed under the MIT license. import re -from typing import Literal, Optional +from typing import Literal from pyrit.models import ComponentIdentifier, PromptDataType from pyrit.prompt_converter.prompt_converter import ConverterResult, PromptConverter @@ -35,7 +35,7 @@ def __init__( *, token_to_repeat: str, times_to_repeat: int, - token_insert_mode: Optional[Literal["split", "prepend", "append", "repeat"]] = None, + token_insert_mode: Literal["split", "prepend", "append", "repeat"] | None = None, ) -> None: """ Initialize the converter with the specified token, number of repetitions, and insertion mode. 
diff --git a/pyrit/prompt_converter/scientific_translation_converter.py b/pyrit/prompt_converter/scientific_translation_converter.py index 85e05428fa..b4229a5226 100644 --- a/pyrit/prompt_converter/scientific_translation_converter.py +++ b/pyrit/prompt_converter/scientific_translation_converter.py @@ -3,7 +3,7 @@ import logging import pathlib -from typing import Literal, Optional, get_args +from typing import Literal, get_args from pyrit.common.apply_defaults import REQUIRED_VALUE, apply_defaults from pyrit.common.path import CONVERTER_SEED_PROMPT_PATH @@ -46,7 +46,7 @@ def __init__( *, converter_target: PromptTarget = REQUIRED_VALUE, # type: ignore[ty:invalid-parameter-default] mode: str = "combined", - prompt_template: Optional[SeedPrompt] = None, + prompt_template: SeedPrompt | None = None, ) -> None: """ Initialize the scientific translation converter. diff --git a/pyrit/prompt_converter/string_join_converter.py b/pyrit/prompt_converter/string_join_converter.py index cd961fd65f..7f95769a89 100644 --- a/pyrit/prompt_converter/string_join_converter.py +++ b/pyrit/prompt_converter/string_join_converter.py @@ -1,7 +1,6 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -from typing import Optional from pyrit.models import ComponentIdentifier from pyrit.prompt_converter.text_selection_strategy import WordSelectionStrategy @@ -17,14 +16,14 @@ def __init__( self, *, join_value: str = "-", - word_selection_strategy: Optional[WordSelectionStrategy] = None, + word_selection_strategy: WordSelectionStrategy | None = None, ) -> None: """ Initialize the converter with the specified join value and selection strategy. Args: join_value (str): The string used to join characters of each word. - word_selection_strategy (Optional[WordSelectionStrategy]): Strategy for selecting which words to convert. + word_selection_strategy (WordSelectionStrategy | None): Strategy for selecting which words to convert. If None, all words will be converted. 
""" super().__init__(word_selection_strategy=word_selection_strategy) diff --git a/pyrit/prompt_converter/template_segment_converter.py b/pyrit/prompt_converter/template_segment_converter.py index ed21b0434e..b9bafbce34 100644 --- a/pyrit/prompt_converter/template_segment_converter.py +++ b/pyrit/prompt_converter/template_segment_converter.py @@ -5,7 +5,6 @@ import logging import pathlib import random -from typing import Optional from pyrit.common.path import CONVERTER_SEED_PROMPT_PATH from pyrit.models import ComponentIdentifier, PromptDataType, SeedPrompt @@ -28,7 +27,7 @@ class TemplateSegmentConverter(PromptConverter): def __init__( self, *, - prompt_template: Optional[SeedPrompt] = None, + prompt_template: SeedPrompt | None = None, ) -> None: """ Initialize the converter with the specified target and prompt template. diff --git a/pyrit/prompt_converter/tense_converter.py b/pyrit/prompt_converter/tense_converter.py index f8f20468d6..8f0852b2c2 100644 --- a/pyrit/prompt_converter/tense_converter.py +++ b/pyrit/prompt_converter/tense_converter.py @@ -3,7 +3,6 @@ import logging import pathlib -from typing import Optional from pyrit.common.apply_defaults import REQUIRED_VALUE, apply_defaults from pyrit.common.path import CONVERTER_SEED_PROMPT_PATH @@ -27,7 +26,7 @@ def __init__( *, converter_target: PromptTarget = REQUIRED_VALUE, # type: ignore[ty:invalid-parameter-default] tense: str, - prompt_template: Optional[SeedPrompt] = None, + prompt_template: SeedPrompt | None = None, ) -> None: """ Initialize the converter with the target chat support, tense, and optional prompt template. 
diff --git a/pyrit/prompt_converter/text_selection_strategy.py b/pyrit/prompt_converter/text_selection_strategy.py index 6dbc8a8b63..438f5f73ec 100644 --- a/pyrit/prompt_converter/text_selection_strategy.py +++ b/pyrit/prompt_converter/text_selection_strategy.py @@ -5,7 +5,6 @@ import random import re from re import Pattern -from typing import Optional, Union class TextSelectionStrategy(abc.ABC): @@ -81,10 +80,10 @@ def select_words(self, *, words: list[str]) -> list[int]: Select word indices to be converted. Args: - words (List[str]): The list of words to select from. + words (list[str]): The list of words to select from. Returns: - List[int]: A list of indices representing which words should be converted. + list[int]: A list of indices representing which words should be converted. """ def select_range(self, *, text: str, word_separator: str = " ") -> tuple[int, int]: @@ -133,13 +132,13 @@ class IndexSelectionStrategy(TextSelectionStrategy): Selects text based on absolute character indices. """ - def __init__(self, *, start: int = 0, end: Optional[int] = None) -> None: + def __init__(self, *, start: int = 0, end: int | None = None) -> None: """ Initialize the index selection strategy. Args: start (int): The starting character index (inclusive). Defaults to 0. - end (Optional[int]): The ending character index (exclusive). If None, selects to end of text. + end (int | None): The ending character index (exclusive). If None, selects to end of text. """ self._start = start self._end = end @@ -165,12 +164,12 @@ class RegexSelectionStrategy(TextSelectionStrategy): Selects text based on the first regex match. """ - def __init__(self, *, pattern: Union[str, Pattern[str]]) -> None: + def __init__(self, *, pattern: str | Pattern[str]) -> None: """ Initialize the regex selection strategy. Args: - pattern (Union[str, Pattern[str]]): The regex pattern to match. + pattern (str | Pattern[str]): The regex pattern to match. 
""" self._pattern = re.compile(pattern) if isinstance(pattern, str) else pattern @@ -290,7 +289,7 @@ class ProportionSelectionStrategy(TextSelectionStrategy): Selects a proportion of text anchored to a specific position (start, end, middle, or random). """ - def __init__(self, *, proportion: float, anchor: str = "start", seed: Optional[int] = None) -> None: + def __init__(self, *, proportion: float, anchor: str = "start", seed: int | None = None) -> None: """ Initialize the proportion selection strategy. @@ -301,7 +300,7 @@ def __init__(self, *, proportion: float, anchor: str = "start", seed: Optional[i - 'end': Select from the end - 'middle': Select from the middle - 'random': Select from a random position - seed (Optional[int]): Random seed for reproducible random selections. Defaults to None. + seed (int | None): Random seed for reproducible random selections. Defaults to None. Raises: ValueError: If proportion is not between 0.0 and 1.0, or anchor is invalid. @@ -404,7 +403,7 @@ def __init__(self, *, indices: list[int]) -> None: Initialize the word index selection strategy. Args: - indices (List[int]): The list of word indices to select. + indices (list[int]): The list of word indices to select. """ self._indices = indices @@ -413,10 +412,10 @@ def select_words(self, *, words: list[str]) -> list[int]: Select words at the specified indices. Args: - words (List[str]): The list of words to select from. + words (list[str]): The list of words to select from. Returns: - List[int]: The list of valid indices. + list[int]: The list of valid indices. Raises: ValueError: If any indices are out of range. @@ -443,7 +442,7 @@ def __init__(self, *, keywords: list[str], case_sensitive: bool = True) -> None: Initialize the word keyword selection strategy. Args: - keywords (List[str]): The list of keywords to match. + keywords (list[str]): The list of keywords to match. case_sensitive (bool): Whether matching is case-sensitive. Defaults to True. 
""" self._keywords = keywords @@ -454,10 +453,10 @@ def select_words(self, *, words: list[str]) -> list[int]: Select words that match the keywords. Args: - words (List[str]): The list of words to select from. + words (list[str]): The list of words to select from. Returns: - List[int]: The list of indices where keywords were found. + list[int]: The list of indices where keywords were found. """ if not words: return [] @@ -473,13 +472,13 @@ class WordProportionSelectionStrategy(WordSelectionStrategy): Selects a random proportion of words. """ - def __init__(self, *, proportion: float, seed: Optional[int] = None) -> None: + def __init__(self, *, proportion: float, seed: int | None = None) -> None: """ Initialize the word proportion selection strategy. Args: proportion (float): The proportion of words to select (0.0 to 1.0). - seed (Optional[int]): Random seed for reproducible selections. Defaults to None. + seed (int | None): Random seed for reproducible selections. Defaults to None. Raises: ValueError: If proportion is not between 0.0 and 1.0. @@ -495,10 +494,10 @@ def select_words(self, *, words: list[str]) -> list[int]: Select a random proportion of words. Args: - words (List[str]): The list of words to select from. + words (list[str]): The list of words to select from. Returns: - List[int]: The list of randomly selected indices. + list[int]: The list of randomly selected indices. """ if not words: return [] @@ -515,12 +514,12 @@ class WordRegexSelectionStrategy(WordSelectionStrategy): Selects words that match a regex pattern. """ - def __init__(self, *, pattern: Union[str, Pattern[str]]) -> None: + def __init__(self, *, pattern: str | Pattern[str]) -> None: """ Initialize the word regex selection strategy. Args: - pattern (Union[str, Pattern[str]]): The regex pattern to match against words. + pattern (str | Pattern[str]): The regex pattern to match against words. 
""" self._pattern = re.compile(pattern) if isinstance(pattern, str) else pattern @@ -529,10 +528,10 @@ def select_words(self, *, words: list[str]) -> list[int]: Select words that match the regex pattern. Args: - words (List[str]): The list of words to select from. + words (list[str]): The list of words to select from. Returns: - List[int]: The list of indices where words matched the pattern. + list[int]: The list of indices where words matched the pattern. """ if not words: return [] @@ -573,10 +572,10 @@ def select_words(self, *, words: list[str]) -> list[int]: Select words based on the relative position. Args: - words (List[str]): The list of words to select from. + words (list[str]): The list of words to select from. Returns: - List[int]: The list of indices in the specified position range. + list[int]: The list of indices in the specified position range. """ if not words: return [] @@ -598,9 +597,9 @@ def select_words(self, *, words: list[str]) -> list[int]: Select all words. Args: - words (List[str]): The list of words to select from. + words (list[str]): The list of words to select from. Returns: - List[int]: All word indices. + list[int]: All word indices. """ return list(range(len(words))) diff --git a/pyrit/prompt_converter/token_smuggling/ascii_smuggler_converter.py b/pyrit/prompt_converter/token_smuggling/ascii_smuggler_converter.py index de0dc05808..a9f7cc7c68 100644 --- a/pyrit/prompt_converter/token_smuggling/ascii_smuggler_converter.py +++ b/pyrit/prompt_converter/token_smuggling/ascii_smuggler_converter.py @@ -62,7 +62,7 @@ def encode_message(self, *, message: str) -> tuple[str, str]: message (str): The message to encode. Returns: - Tuple[str, str]: A tuple with a summary of code points and the encoded message. + tuple[str, str]: A tuple with a summary of code points and the encoded message. 
""" encoded = "" code_points = "" diff --git a/pyrit/prompt_converter/token_smuggling/base.py b/pyrit/prompt_converter/token_smuggling/base.py index 71ca289138..34f746f850 100644 --- a/pyrit/prompt_converter/token_smuggling/base.py +++ b/pyrit/prompt_converter/token_smuggling/base.py @@ -114,7 +114,7 @@ def encode_message(self, *, message: str) -> tuple[str, str]: message (str): The message to encode. Returns: - Tuple[str, str]: A tuple containing a summary and the encoded message. + tuple[str, str]: A tuple containing a summary and the encoded message. """ raise NotImplementedError diff --git a/pyrit/prompt_converter/token_smuggling/sneaky_bits_smuggler_converter.py b/pyrit/prompt_converter/token_smuggling/sneaky_bits_smuggler_converter.py index dbb7a1e312..0ea2e964e6 100644 --- a/pyrit/prompt_converter/token_smuggling/sneaky_bits_smuggler_converter.py +++ b/pyrit/prompt_converter/token_smuggling/sneaky_bits_smuggler_converter.py @@ -2,7 +2,7 @@ # Licensed under the MIT license. import logging -from typing import Literal, Optional +from typing import Literal from pyrit.models import ComponentIdentifier from pyrit.prompt_converter.token_smuggling.base import SmugglerConverter @@ -29,16 +29,16 @@ class SneakyBitsSmugglerConverter(SmugglerConverter): def __init__( self, action: Literal["encode", "decode"] = "encode", - zero_char: Optional[str] = None, - one_char: Optional[str] = None, + zero_char: str | None = None, + one_char: str | None = None, ) -> None: """ Initialize the converter with options for encoding/decoding in Sneaky Bits mode. Args: action (Literal["encode", "decode"]): The action to perform. - zero_char (Optional[str]): Character to represent binary 0 in ``sneaky_bits`` mode (default: U+2062). - one_char (Optional[str]): Character to represent binary 1 in ``sneaky_bits`` mode (default: U+2064). + zero_char (str | None): Character to represent binary 0 in ``sneaky_bits`` mode (default: U+2062). 
+ one_char (str | None): Character to represent binary 1 in ``sneaky_bits`` mode (default: U+2064). Raises: ValueError: If an unsupported action or ``encoding_mode`` is provided. @@ -73,7 +73,7 @@ def encode_message(self, message: str) -> tuple[str, str]: message (str): The message to encode. Returns: - Tuple[str, str]: A tuple where the first element is a bit summary (empty in this implementation) + tuple[str, str]: A tuple where the first element is a bit summary (empty in this implementation) and the second element is the encoded message containing the invisible bits. """ encoded_bits = [] diff --git a/pyrit/prompt_converter/token_smuggling/variation_selector_smuggler_converter.py b/pyrit/prompt_converter/token_smuggling/variation_selector_smuggler_converter.py index 34f9d5612d..24fb70fc8d 100644 --- a/pyrit/prompt_converter/token_smuggling/variation_selector_smuggler_converter.py +++ b/pyrit/prompt_converter/token_smuggling/variation_selector_smuggler_converter.py @@ -2,7 +2,7 @@ # Licensed under the MIT license. import logging -from typing import Literal, Optional +from typing import Literal from pyrit.models import ComponentIdentifier from pyrit.prompt_converter.token_smuggling.base import SmugglerConverter @@ -36,7 +36,7 @@ class VariationSelectorSmugglerConverter(SmugglerConverter): def __init__( self, action: Literal["encode", "decode"] = "encode", - base_char_utf8: Optional[str] = None, + base_char_utf8: str | None = None, embed_in_base: bool = True, ) -> None: """ @@ -44,7 +44,7 @@ def __init__( Args: action (Literal["encode", "decode"]): The action to perform. - base_char_utf8 (Optional[str]): Base character for ``variation_selector_smuggler`` mode (default: 😊). + base_char_utf8 (str | None): Base character for ``variation_selector_smuggler`` mode (default: 😊). embed_in_base (bool): If True, the hidden payload is embedded directly into the base character. If False, a visible separator (space) is inserted between the base and payload. Default is True. 
@@ -86,7 +86,7 @@ def encode_message(self, message: str) -> tuple[str, str]: message (str): The message to encode. Returns: - Tuple[str, str]: A tuple containing a summary of the code points and the encoded string. + tuple[str, str]: A tuple containing a summary of the code points and the encoded string. """ payload = "" data = message.encode("utf-8") @@ -154,7 +154,7 @@ def encode_visible_hidden(self, visible: str, hidden: str) -> tuple[str, str]: hidden (str): The secret/hidden text to encode. Returns: - Tuple[str, str]: A tuple containing a summary and the combined text. + tuple[str, str]: A tuple containing a summary and the combined text. """ summary, encoded_hidden = self.encode_message(hidden) combined = visible + encoded_hidden @@ -172,7 +172,7 @@ def decode_visible_hidden(self, combined: str) -> tuple[str, str]: combined (str): The combined text containing visible and hidden parts. Returns: - Tuple[str, str]: A tuple with the visible text and the decoded hidden text. + tuple[str, str]: A tuple with the visible text and the decoded hidden text. """ base_char = self.utf8_base_char index = combined.find(base_char) diff --git a/pyrit/prompt_converter/tone_converter.py b/pyrit/prompt_converter/tone_converter.py index c21a118603..562a4ee6af 100644 --- a/pyrit/prompt_converter/tone_converter.py +++ b/pyrit/prompt_converter/tone_converter.py @@ -3,7 +3,6 @@ import logging import pathlib -from typing import Optional from pyrit.common.apply_defaults import REQUIRED_VALUE, apply_defaults from pyrit.common.path import CONVERTER_SEED_PROMPT_PATH @@ -27,7 +26,7 @@ def __init__( *, converter_target: PromptTarget = REQUIRED_VALUE, # type: ignore[ty:invalid-parameter-default] tone: str, - prompt_template: Optional[SeedPrompt] = None, + prompt_template: SeedPrompt | None = None, ) -> None: """ Initialize the converter with the target chat support, tone, and optional prompt template. 
diff --git a/pyrit/prompt_converter/toxic_sentence_generator_converter.py b/pyrit/prompt_converter/toxic_sentence_generator_converter.py index 3159cf1de7..07cfc3744e 100644 --- a/pyrit/prompt_converter/toxic_sentence_generator_converter.py +++ b/pyrit/prompt_converter/toxic_sentence_generator_converter.py @@ -7,7 +7,6 @@ import logging import pathlib -from typing import Optional from pyrit.common.apply_defaults import REQUIRED_VALUE, apply_defaults from pyrit.common.path import CONVERTER_SEED_PROMPT_PATH @@ -35,7 +34,7 @@ def __init__( self, *, converter_target: PromptTarget = REQUIRED_VALUE, # type: ignore[ty:invalid-parameter-default] - prompt_template: Optional[SeedPrompt] = None, + prompt_template: SeedPrompt | None = None, ) -> None: """ Initialize the converter with a specific target and template. diff --git a/pyrit/prompt_converter/unicode_replacement_converter.py b/pyrit/prompt_converter/unicode_replacement_converter.py index 71a0e52e54..1b53d89be2 100644 --- a/pyrit/prompt_converter/unicode_replacement_converter.py +++ b/pyrit/prompt_converter/unicode_replacement_converter.py @@ -1,7 +1,6 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -from typing import Optional from pyrit.models import ComponentIdentifier from pyrit.prompt_converter.text_selection_strategy import WordSelectionStrategy @@ -17,14 +16,14 @@ def __init__( self, *, encode_spaces: bool = False, - word_selection_strategy: Optional[WordSelectionStrategy] = None, + word_selection_strategy: WordSelectionStrategy | None = None, ) -> None: """ Initialize the converter with the specified selection strategy. Args: encode_spaces (bool): If True, spaces in the prompt will be replaced with unicode representation. - word_selection_strategy (Optional[WordSelectionStrategy]): Strategy for selecting which words to convert. + word_selection_strategy (WordSelectionStrategy | None): Strategy for selecting which words to convert. If None, all words will be converted. 
""" super().__init__(word_selection_strategy=word_selection_strategy) diff --git a/pyrit/prompt_converter/word_doc_converter.py b/pyrit/prompt_converter/word_doc_converter.py index 1dc5d00b5c..b811ab1b3f 100644 --- a/pyrit/prompt_converter/word_doc_converter.py +++ b/pyrit/prompt_converter/word_doc_converter.py @@ -7,7 +7,7 @@ import hashlib from dataclasses import dataclass from io import BytesIO -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any from docx import Document @@ -26,7 +26,7 @@ class _WordDocInjectionConfig: """Configuration for how to inject content into a Word document.""" - existing_docx: Optional[Path] + existing_docx: Path | None placeholder: str @@ -66,8 +66,8 @@ class WordDocConverter(PromptConverter): def __init__( self, *, - prompt_template: Optional[SeedPrompt] = None, - existing_docx: Optional[Path] = None, + prompt_template: SeedPrompt | None = None, + existing_docx: Path | None = None, placeholder: str = "{{INJECTION_PLACEHOLDER}}", ) -> None: """ @@ -112,7 +112,7 @@ def _build_identifier(self) -> ComponentIdentifier: Returns: ComponentIdentifier: The identifier with converter-specific parameters. """ - template_hash: Optional[str] = None + template_hash: str | None = None if self._prompt_template: template_hash = hashlib.sha256(str(self._prompt_template.value).encode("utf-8")).hexdigest()[:16] diff --git a/pyrit/prompt_converter/word_level_converter.py b/pyrit/prompt_converter/word_level_converter.py index 5a2f874f0c..2b459f7f5b 100644 --- a/pyrit/prompt_converter/word_level_converter.py +++ b/pyrit/prompt_converter/word_level_converter.py @@ -2,7 +2,7 @@ # Licensed under the MIT license. 
import abc -from typing import Any, Optional +from typing import Any from pyrit.models import ComponentIdentifier, PromptDataType from pyrit.prompt_converter.prompt_converter import ConverterResult, PromptConverter @@ -30,17 +30,17 @@ class WordLevelConverter(PromptConverter): def __init__( self, *, - word_selection_strategy: Optional[WordSelectionStrategy] = None, - word_split_separator: Optional[str] = " ", + word_selection_strategy: WordSelectionStrategy | None = None, + word_split_separator: str | None = " ", **kwargs: Any, ) -> None: """ Initialize the converter with the specified selection strategy. Args: - word_selection_strategy (Optional[WordSelectionStrategy]): The strategy for selecting which + word_selection_strategy (WordSelectionStrategy | None): The strategy for selecting which words to convert. If None, all words will be converted. Defaults to None. - word_split_separator (Optional[str]): Separator used to split words in the input text. + word_split_separator (str | None): Separator used to split words in the input text. If None, splits by any whitespace. Defaults to " ". **kwargs: Forwarded to ``PromptConverter.__init__`` to support cooperative multiple inheritance (e.g., ``converter_target`` when mixed with LLM-based converters). 
diff --git a/pyrit/prompt_converter/zalgo_converter.py b/pyrit/prompt_converter/zalgo_converter.py index ddd7c686ec..6930672c0d 100644 --- a/pyrit/prompt_converter/zalgo_converter.py +++ b/pyrit/prompt_converter/zalgo_converter.py @@ -3,7 +3,6 @@ import logging import random -from typing import Optional from pyrit.models import ComponentIdentifier from pyrit.prompt_converter.text_selection_strategy import WordSelectionStrategy @@ -25,16 +24,16 @@ def __init__( self, *, intensity: int = 10, - seed: Optional[int] = None, - word_selection_strategy: Optional[WordSelectionStrategy] = None, + seed: int | None = None, + word_selection_strategy: WordSelectionStrategy | None = None, ) -> None: """ Initialize the converter with the specified selection parameters. Args: intensity (int): Number of combining marks per character (higher = more cursed). Default is 10. - seed (Optional[int]): Optional seed for reproducible output. - word_selection_strategy (Optional[WordSelectionStrategy]): Strategy for selecting which words to convert. + seed (int | None): Optional seed for reproducible output. + word_selection_strategy (WordSelectionStrategy | None): Strategy for selecting which words to convert. If None, all words will be converted. """ super().__init__(word_selection_strategy=word_selection_strategy) diff --git a/pyrit/prompt_normalizer/normalizer_request.py b/pyrit/prompt_normalizer/normalizer_request.py index c030ca5278..872f235f8d 100644 --- a/pyrit/prompt_normalizer/normalizer_request.py +++ b/pyrit/prompt_normalizer/normalizer_request.py @@ -2,7 +2,6 @@ # Licensed under the MIT license. 
from dataclasses import dataclass -from typing import Optional from pyrit.models import Message from pyrit.prompt_normalizer.prompt_converter_configuration import ( @@ -27,7 +26,7 @@ def __init__( message: Message, request_converter_configurations: list[PromptConverterConfiguration] | None = None, response_converter_configurations: list[PromptConverterConfiguration] | None = None, - conversation_id: Optional[str] = None, + conversation_id: str | None = None, ) -> None: """ Initialize a normalizer request. @@ -38,7 +37,7 @@ def __init__( the request. Defaults to an empty list. response_converter_configurations (list[PromptConverterConfiguration]): Configurations for converting the response. Defaults to an empty list. - conversation_id (Optional[str]): The ID of the conversation. Defaults to None. + conversation_id (str | None): The ID of the conversation. Defaults to None. """ if response_converter_configurations is None: response_converter_configurations = [] diff --git a/pyrit/prompt_normalizer/prompt_converter_configuration.py b/pyrit/prompt_normalizer/prompt_converter_configuration.py index cb9ae55425..1218b9168b 100644 --- a/pyrit/prompt_normalizer/prompt_converter_configuration.py +++ b/pyrit/prompt_normalizer/prompt_converter_configuration.py @@ -2,7 +2,6 @@ # Licensed under the MIT license. 
from dataclasses import dataclass -from typing import Optional from pyrit.models import PromptDataType from pyrit.prompt_converter import PromptConverter @@ -19,8 +18,8 @@ class PromptConverterConfiguration: """ converters: list[PromptConverter] - indexes_to_apply: Optional[list[int]] = None - prompt_data_types_to_apply: Optional[list[PromptDataType]] = None + indexes_to_apply: list[int] | None = None + prompt_data_types_to_apply: list[PromptDataType] | None = None @classmethod def from_converters(cls, *, converters: list[PromptConverter]) -> list["PromptConverterConfiguration"]: @@ -32,7 +31,7 @@ def from_converters(cls, *, converters: list[PromptConverter]) -> list["PromptCo converters: List of PromptConverters Returns: - List[PromptConverterConfiguration]: List of configurations, one per converter + list[PromptConverterConfiguration]: List of configurations, one per converter """ if not converters: return [] diff --git a/pyrit/prompt_normalizer/prompt_normalizer.py b/pyrit/prompt_normalizer/prompt_normalizer.py index 9c23f91cf3..c089930778 100644 --- a/pyrit/prompt_normalizer/prompt_normalizer.py +++ b/pyrit/prompt_normalizer/prompt_normalizer.py @@ -9,7 +9,7 @@ import traceback import wave from pathlib import Path -from typing import Any, Optional +from typing import Any from uuid import uuid4 from pyrit.common.deprecation import print_deprecation_message @@ -68,11 +68,11 @@ async def send_prompt_async( *, message: Message, target: PromptTarget, - conversation_id: Optional[str] = None, + conversation_id: str | None = None, request_converter_configurations: list[PromptConverterConfiguration] | None = None, response_converter_configurations: list[PromptConverterConfiguration] | None = None, - labels: Optional[dict[str, str]] = None, - attack_identifier: Optional[ComponentIdentifier] = None, + labels: dict[str, str] | None = None, + attack_identifier: ComponentIdentifier | None = None, ) -> Message: """ Send a single request to a target. 
@@ -85,9 +85,9 @@ async def send_prompt_async( converting the request. Defaults to an empty list. response_converter_configurations (list[PromptConverterConfiguration], optional): Configurations for converting the response. Defaults to an empty list. - labels (Optional[dict[str, str]], optional): Labels associated with the request. Defaults to None. + labels (dict[str, str] | None, optional): Labels associated with the request. Defaults to None. Deprecated: This parameter will be removed in a release 0.16.0. - attack_identifier (Optional[ComponentIdentifier], optional): Identifier for the attack. Defaults to + attack_identifier (ComponentIdentifier | None, optional): Identifier for the attack. Defaults to None. Returns: @@ -196,8 +196,8 @@ async def send_prompt_batch_to_target_async( *, requests: list[NormalizerRequest], target: PromptTarget, - labels: Optional[dict[str, str]] = None, - attack_identifier: Optional[ComponentIdentifier] = None, + labels: dict[str, str] | None = None, + attack_identifier: ComponentIdentifier | None = None, batch_size: int = 10, ) -> list[Message]: """ @@ -206,9 +206,9 @@ async def send_prompt_batch_to_target_async( Args: requests (list[NormalizerRequest]): A list of NormalizerRequest objects to be sent. target (PromptTarget): The target to which the prompts are sent. - labels (Optional[dict[str, str]], optional): A dictionary of labels to be included with the request. + labels (dict[str, str] | None, optional): A dictionary of labels to be included with the request. Defaults to None. - attack_identifier (Optional[ComponentIdentifier], optional): The attack identifier. + attack_identifier (ComponentIdentifier | None, optional): The attack identifier. Defaults to None. batch_size (int, optional): The number of prompts to include in each batch. Defaults to 10. 
@@ -397,23 +397,23 @@ async def add_prepended_conversation_to_memory_async( self, conversation_id: str, should_convert: bool = True, - converter_configurations: Optional[list[PromptConverterConfiguration]] = None, - attack_identifier: Optional[ComponentIdentifier] = None, - prepended_conversation: Optional[list[Message]] = None, - ) -> Optional[list[Message]]: + converter_configurations: list[PromptConverterConfiguration] | None = None, + attack_identifier: ComponentIdentifier | None = None, + prepended_conversation: list[Message] | None = None, + ) -> list[Message] | None: """ Process the prepended conversation by converting it if needed and adding it to memory. Args: conversation_id (str): The conversation ID to use for the message pieces should_convert (bool): Whether to convert the prepended conversation - converter_configurations (Optional[list[PromptConverterConfiguration]]): Configurations for converting the + converter_configurations (list[PromptConverterConfiguration] | None): Configurations for converting the request - attack_identifier (Optional[ComponentIdentifier]): Identifier for the attack - prepended_conversation (Optional[list[Message]]): The conversation to prepend + attack_identifier (ComponentIdentifier | None): Identifier for the attack + prepended_conversation (list[Message] | None): The conversation to prepend Returns: - Optional[list[Message]]: The processed prepended conversation + list[Message] | None: The processed prepended conversation """ if not prepended_conversation: return None @@ -454,15 +454,15 @@ async def add_prepended_conversation_to_memory( # pyrit-async-suffix-exempt self, conversation_id: str, should_convert: bool = True, - converter_configurations: Optional[list[PromptConverterConfiguration]] = None, - attack_identifier: Optional[ComponentIdentifier] = None, - prepended_conversation: Optional[list[Message]] = None, - ) -> Optional[list[Message]]: + converter_configurations: list[PromptConverterConfiguration] | None = None, 
+ attack_identifier: ComponentIdentifier | None = None, + prepended_conversation: list[Message] | None = None, + ) -> list[Message] | None: """ Use ``add_prepended_conversation_to_memory_async`` instead; this is a deprecated alias. Returns: - Optional[list[Message]]: Same as ``add_prepended_conversation_to_memory_async``. + list[Message] | None: Same as ``add_prepended_conversation_to_memory_async``. """ print_deprecation_message( old_item="pyrit.prompt_normalizer.PromptNormalizer.add_prepended_conversation_to_memory", diff --git a/pyrit/prompt_target/azure_blob_storage_target.py b/pyrit/prompt_target/azure_blob_storage_target.py index 184cb1664e..b3b10a04e2 100644 --- a/pyrit/prompt_target/azure_blob_storage_target.py +++ b/pyrit/prompt_target/azure_blob_storage_target.py @@ -3,7 +3,6 @@ import logging from enum import Enum -from typing import Optional from urllib.parse import urlparse from azure.core.exceptions import ClientAuthenticationError @@ -69,11 +68,11 @@ class AzureBlobStorageTarget(PromptTarget): def __init__( self, *, - container_url: Optional[str] = None, - sas_token: Optional[str] = None, + container_url: str | None = None, + sas_token: str | None = None, blob_content_type: SupportedContentType = SupportedContentType.PLAIN_TEXT, - max_requests_per_minute: Optional[int] = None, - custom_configuration: Optional[TargetConfiguration] = None, + max_requests_per_minute: int | None = None, + custom_configuration: TargetConfiguration | None = None, ) -> None: """ Initialize the Azure Blob Storage target. 
@@ -95,8 +94,8 @@ def __init__( env_var_name=self.AZURE_STORAGE_CONTAINER_ENVIRONMENT_VARIABLE, passed_value=container_url ) - self._sas_token: Optional[str] = sas_token - self._client_async: Optional[AsyncContainerClient] = None + self._sas_token: str | None = sas_token + self._client_async: AsyncContainerClient | None = None super().__init__( endpoint=self._container_url, diff --git a/pyrit/prompt_target/batch_helper.py b/pyrit/prompt_target/batch_helper.py index 95ec6809fb..399465e4fb 100644 --- a/pyrit/prompt_target/batch_helper.py +++ b/pyrit/prompt_target/batch_helper.py @@ -3,7 +3,7 @@ import asyncio from collections.abc import Callable, Generator, Sequence -from typing import Any, Optional +from typing import Any from pyrit.prompt_target.common.prompt_target import PromptTarget @@ -31,7 +31,7 @@ def _get_chunks(*args: Sequence[Any], batch_size: int) -> Generator[list[Sequenc yield [arg[i : i + batch_size] for arg in args] -def _validate_rate_limit_parameters(prompt_target: Optional[PromptTarget], batch_size: int) -> None: +def _validate_rate_limit_parameters(prompt_target: PromptTarget | None, batch_size: int) -> None: """ Validate the constraints between Rate Limit (Requests Per Minute) and batch size. 
@@ -49,7 +49,7 @@ def _validate_rate_limit_parameters(prompt_target: Optional[PromptTarget], batch async def batch_task_async( *, - prompt_target: Optional[PromptTarget] = None, + prompt_target: PromptTarget | None = None, batch_size: int, items_to_batch: Sequence[Sequence[Any]], task_func: Callable[..., Any], diff --git a/pyrit/prompt_target/common/prompt_target.py b/pyrit/prompt_target/common/prompt_target.py index 9605ed444f..218f46d52e 100644 --- a/pyrit/prompt_target/common/prompt_target.py +++ b/pyrit/prompt_target/common/prompt_target.py @@ -3,7 +3,7 @@ import abc import logging -from typing import Any, Union, final +from typing import Any, final from pyrit.common.deprecation import print_deprecation_message from pyrit.memory import CentralMemory, MemoryInterface @@ -341,7 +341,7 @@ def _create_identifier( self, *, params: dict[str, Any] | None = None, - children: dict[str, Union[ComponentIdentifier, list[ComponentIdentifier]]] | None = None, + children: dict[str, ComponentIdentifier | list[ComponentIdentifier]] | None = None, ) -> ComponentIdentifier: """ Construct the target identifier. @@ -356,7 +356,7 @@ def _create_identifier( Args: params (dict[str, Any] | None): Additional behavioral parameters from the subclass (e.g., temperature, top_p). Merged into the base params. - children (dict[str, Union[ComponentIdentifier, list[ComponentIdentifier]]] | None): + children (dict[str, ComponentIdentifier | list[ComponentIdentifier]] | None): Named child component identifiers. 
Returns: diff --git a/pyrit/prompt_target/common/target_capabilities.py b/pyrit/prompt_target/common/target_capabilities.py index 7f1010745f..719c441d38 100644 --- a/pyrit/prompt_target/common/target_capabilities.py +++ b/pyrit/prompt_target/common/target_capabilities.py @@ -5,7 +5,7 @@ from dataclasses import dataclass, field from enum import Enum from types import MappingProxyType -from typing import NoReturn, Optional, cast +from typing import NoReturn, cast from pyrit.models import PromptDataType @@ -165,7 +165,7 @@ def includes(self, *, capability: CapabilityName) -> bool: return bool(getattr(self, capability.value)) @staticmethod - def get_known_capabilities(underlying_model: str) -> "Optional[TargetCapabilities]": + def get_known_capabilities(underlying_model: str) -> "TargetCapabilities | None": """ Return the known capabilities for a specific underlying model, or None if unrecognized. diff --git a/pyrit/prompt_target/common/utils.py b/pyrit/prompt_target/common/utils.py index 2aaa10bb68..6b6a74a813 100644 --- a/pyrit/prompt_target/common/utils.py +++ b/pyrit/prompt_target/common/utils.py @@ -3,12 +3,12 @@ import asyncio from collections.abc import Callable -from typing import Any, Optional +from typing import Any from pyrit.exceptions import PyritException -def validate_temperature(temperature: Optional[float]) -> None: +def validate_temperature(temperature: float | None) -> None: """ Validate that temperature parameter is within valid range. @@ -22,7 +22,7 @@ def validate_temperature(temperature: Optional[float]) -> None: raise PyritException(message="temperature must be between 0 and 2 (inclusive).") -def validate_top_p(top_p: Optional[float]) -> None: +def validate_top_p(top_p: float | None) -> None: """ Validate that top_p parameter is within valid range. 
diff --git a/pyrit/prompt_target/gandalf_target.py b/pyrit/prompt_target/gandalf_target.py index 1bad549c7f..b2dc6e342a 100644 --- a/pyrit/prompt_target/gandalf_target.py +++ b/pyrit/prompt_target/gandalf_target.py @@ -4,7 +4,6 @@ import enum import json import logging -from typing import Optional from pyrit.common import net_utility from pyrit.common.deprecation import print_deprecation_message @@ -43,8 +42,8 @@ def __init__( self, *, level: GandalfLevel, - max_requests_per_minute: Optional[int] = None, - custom_configuration: Optional[TargetConfiguration] = None, + max_requests_per_minute: int | None = None, + custom_configuration: TargetConfiguration | None = None, ) -> None: """ Initialize the Gandalf target. diff --git a/pyrit/prompt_target/http_target/http_target.py b/pyrit/prompt_target/http_target/http_target.py index f1c5988ac8..ff4bcc6999 100644 --- a/pyrit/prompt_target/http_target/http_target.py +++ b/pyrit/prompt_target/http_target/http_target.py @@ -6,7 +6,7 @@ import logging import re from collections.abc import Callable -from typing import Any, Optional +from typing import Any import httpx @@ -42,11 +42,11 @@ def __init__( http_request: str, prompt_regex_string: str = "{PROMPT}", use_tls: bool = True, - callback_function: Optional[Callable[..., Any]] = None, - max_requests_per_minute: Optional[int] = None, - client: Optional[httpx.AsyncClient] = None, + callback_function: Callable[..., Any] | None = None, + max_requests_per_minute: int | None = None, + client: httpx.AsyncClient | None = None, model_name: str = "", - custom_configuration: Optional[TargetConfiguration] = None, + custom_configuration: TargetConfiguration | None = None, **httpx_client_kwargs: Any, ) -> None: """ @@ -113,7 +113,7 @@ def with_client( http_request: str, prompt_regex_string: str = "{PROMPT}", callback_function: Callable[..., Any] | None = None, - max_requests_per_minute: Optional[int] = None, + max_requests_per_minute: int | None = None, ) -> "HTTPTarget": """ Alternative 
constructor that accepts a pre-configured httpx client. diff --git a/pyrit/prompt_target/http_target/http_target_callback_functions.py b/pyrit/prompt_target/http_target/http_target_callback_functions.py index 90cc7f79a3..8d749af73d 100644 --- a/pyrit/prompt_target/http_target/http_target_callback_functions.py +++ b/pyrit/prompt_target/http_target/http_target_callback_functions.py @@ -5,7 +5,7 @@ import json import re from collections.abc import Callable -from typing import Any, Optional +from typing import Any import requests @@ -42,7 +42,7 @@ def parse_json_http_response(response: requests.Response) -> str: def get_http_target_regex_matching_callback_function( - key: str, url: Optional[str] = None + key: str, url: str | None = None ) -> Callable[[requests.Response], str]: """ Get a callback function that parses HTTP responses using regex matching. diff --git a/pyrit/prompt_target/http_target/httpx_api_target.py b/pyrit/prompt_target/http_target/httpx_api_target.py index bd32fd1fe2..95f0c47124 100644 --- a/pyrit/prompt_target/http_target/httpx_api_target.py +++ b/pyrit/prompt_target/http_target/httpx_api_target.py @@ -5,7 +5,7 @@ import mimetypes from collections.abc import Callable from pathlib import Path -from typing import Any, Literal, Optional +from typing import Any, Literal import aiofiles import httpx @@ -42,15 +42,15 @@ def __init__( *, http_url: str, method: Literal["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS"] = "POST", - file_path: Optional[str] = None, - json_data: Optional[dict[str, Any]] = None, - form_data: Optional[dict[str, Any]] = None, - params: Optional[dict[str, Any]] = None, - headers: Optional[dict[str, str]] = None, - http2: Optional[bool] = None, + file_path: str | None = None, + json_data: dict[str, Any] | None = None, + form_data: dict[str, Any] | None = None, + params: dict[str, Any] | None = None, + headers: dict[str, str] | None = None, + http2: bool | None = None, callback_function: Callable[..., Any] | None = None, - 
max_requests_per_minute: Optional[int] = None, - custom_configuration: Optional[TargetConfiguration] = None, + max_requests_per_minute: int | None = None, + custom_configuration: TargetConfiguration | None = None, **httpx_client_kwargs: Any, ) -> None: """ diff --git a/pyrit/prompt_target/openai/openai_chat_audio_config.py b/pyrit/prompt_target/openai/openai_chat_audio_config.py index 6f29f419ca..eb48a78dad 100644 --- a/pyrit/prompt_target/openai/openai_chat_audio_config.py +++ b/pyrit/prompt_target/openai/openai_chat_audio_config.py @@ -8,7 +8,7 @@ # OpenAI SDK: openai/types/chat/chat_completion_audio_param.py voice field # SDK Literal includes: alloy, ash, ballad, coral, echo, sage, shimmer, verse, marin, cedar # SDK docstring also lists: fable, nova, onyx (we include these for completeness) -# Note: SDK uses Union[str, Literal[...]] so any string is accepted by the API. +# Note: SDK uses str | Literal[...] so any string is accepted by the API. ChatAudioVoice = Literal[ "alloy", "ash", "ballad", "coral", "echo", "fable", "nova", "onyx", "sage", "shimmer", "verse", "marin", "cedar" ] diff --git a/pyrit/prompt_target/openai/openai_chat_target.py b/pyrit/prompt_target/openai/openai_chat_target.py index f82fc40e29..1fd2400ca9 100644 --- a/pyrit/prompt_target/openai/openai_chat_target.py +++ b/pyrit/prompt_target/openai/openai_chat_target.py @@ -5,7 +5,7 @@ import json import logging from collections.abc import MutableSequence -from typing import Any, Optional +from typing import Any from pyrit.common.data_url_converter import convert_local_image_to_data_url_async from pyrit.exceptions import ( @@ -80,17 +80,17 @@ class OpenAIChatTarget(OpenAITarget): def __init__( self, *, - max_completion_tokens: Optional[int] = None, - max_tokens: Optional[int] = None, - temperature: Optional[float] = None, - top_p: Optional[float] = None, - frequency_penalty: Optional[float] = None, - presence_penalty: Optional[float] = None, - seed: Optional[int] = None, - n: Optional[int] = 
None, - audio_response_config: Optional[OpenAIChatAudioConfig] = None, - extra_body_parameters: Optional[dict[str, Any]] = None, - custom_configuration: Optional[TargetConfiguration] = None, + max_completion_tokens: int | None = None, + max_tokens: int | None = None, + temperature: float | None = None, + top_p: float | None = None, + frequency_penalty: float | None = None, + presence_penalty: float | None = None, + seed: int | None = None, + n: int | None = None, + audio_response_config: OpenAIChatAudioConfig | None = None, + extra_body_parameters: dict[str, Any] | None = None, + custom_configuration: TargetConfiguration | None = None, **kwargs: Any, ) -> None: """ @@ -259,7 +259,7 @@ def _check_content_filter(self, response: Any) -> bool: pass return False - def _extract_partial_content(self, response: Any) -> Optional[str]: + def _extract_partial_content(self, response: Any) -> str | None: """ Extract partial content from a Chat Completions response with finish_reason=content_filter. @@ -279,7 +279,7 @@ def _extract_partial_content(self, response: Any) -> Optional[str]: pass return None - def _validate_response(self, response: Any, request: MessagePiece) -> Optional[Message]: + def _validate_response(self, response: Any, request: MessagePiece) -> Message | None: """ Validate a Chat Completions API response for errors. 
@@ -415,7 +415,7 @@ async def _construct_message_from_response_async(self, response: Any, request: M audio_response = message.audio # Add transcript as text piece with metadata - audio_transcript: Optional[str] = getattr(audio_response, "transcript", None) + audio_transcript: str | None = getattr(audio_response, "transcript", None) if audio_transcript: transcript_piece = construct_response_from_request( request=request, @@ -426,7 +426,7 @@ async def _construct_message_from_response_async(self, response: Any, request: M pieces.append(transcript_piece) # Save audio data and add as audio_path piece - audio_data: Optional[str] = getattr(audio_response, "data", None) + audio_data: str | None = getattr(audio_response, "data", None) if audio_data: audio_path = await self._save_audio_response_async(audio_data_base64=audio_data) audio_piece = construct_response_from_request( @@ -676,7 +676,7 @@ async def _construct_request_body_async( # Filter out None values return {k: v for k, v in body_parameters.items() if v is not None} - def _build_response_format(self, json_config: _JsonResponseConfig) -> Optional[dict[str, Any]]: + def _build_response_format(self, json_config: _JsonResponseConfig) -> dict[str, Any] | None: if not json_config.enabled: return None diff --git a/pyrit/prompt_target/openai/openai_completion_target.py b/pyrit/prompt_target/openai/openai_completion_target.py index 9aa04c009b..23c0dfd174 100644 --- a/pyrit/prompt_target/openai/openai_completion_target.py +++ b/pyrit/prompt_target/openai/openai_completion_target.py @@ -2,7 +2,7 @@ # Licensed under the MIT license. 
import logging -from typing import Any, Optional +from typing import Any from pyrit.exceptions.exception_classes import ( pyrit_target_retry, @@ -30,13 +30,13 @@ class OpenAICompletionTarget(OpenAITarget): def __init__( self, - max_tokens: Optional[int] = None, - temperature: Optional[float] = None, - top_p: Optional[float] = None, - presence_penalty: Optional[float] = None, - frequency_penalty: Optional[float] = None, - n: Optional[int] = None, - custom_configuration: Optional[TargetConfiguration] = None, + max_tokens: int | None = None, + temperature: float | None = None, + top_p: float | None = None, + presence_penalty: float | None = None, + frequency_penalty: float | None = None, + n: int | None = None, + custom_configuration: TargetConfiguration | None = None, *args: Any, **kwargs: Any, ) -> None: diff --git a/pyrit/prompt_target/openai/openai_error_handling.py b/pyrit/prompt_target/openai/openai_error_handling.py index d6f1099241..b3d0fab505 100644 --- a/pyrit/prompt_target/openai/openai_error_handling.py +++ b/pyrit/prompt_target/openai/openai_error_handling.py @@ -10,7 +10,6 @@ import json import logging -from typing import Optional, Union from pyrit.exceptions.exception_classes import CONTENT_FILTER_MARKERS @@ -36,7 +35,7 @@ ) -def _extract_request_id_from_exception(exc: Exception) -> Optional[str]: +def _extract_request_id_from_exception(exc: Exception) -> str | None: """ Extract the x-request-id from an OpenAI SDK exception for logging/telemetry. @@ -57,7 +56,7 @@ def _extract_request_id_from_exception(exc: Exception) -> Optional[str]: return None -def _extract_retry_after_from_exception(exc: Exception) -> Optional[float]: +def _extract_retry_after_from_exception(exc: Exception) -> float | None: """ Extract the Retry-After header from a rate-limit exception for intelligent backoff. 
@@ -82,7 +81,7 @@ def _extract_retry_after_from_exception(exc: Exception) -> Optional[float]: return None -def _is_content_filter_error(data: Union[dict[str, object], str]) -> bool: +def _is_content_filter_error(data: dict[str, object] | str) -> bool: """ Check if error data indicates content filtering. @@ -108,7 +107,7 @@ def _is_content_filter_error(data: Union[dict[str, object], str]) -> bool: return any(marker in haystack for marker in CONTENT_FILTER_MARKERS) -def _extract_error_payload(exc: Exception) -> tuple[Union[dict[str, object], str], bool]: +def _extract_error_payload(exc: Exception) -> tuple[dict[str, object] | str, bool]: """ Extract error payload and detect content filter from an OpenAI SDK exception. diff --git a/pyrit/prompt_target/openai/openai_image_target.py b/pyrit/prompt_target/openai/openai_image_target.py index bc56207c7a..5d03f985c3 100644 --- a/pyrit/prompt_target/openai/openai_image_target.py +++ b/pyrit/prompt_target/openai/openai_image_target.py @@ -2,7 +2,7 @@ # Licensed under the MIT license. 
import base64 import logging -from typing import Any, Literal, Optional +from typing import Any, Literal import httpx @@ -72,11 +72,11 @@ def __init__( "1792x1024", "1024x1792", ] = "1024x1024", - output_format: Optional[Literal["png", "jpeg", "webp"]] = None, - quality: Optional[Literal["auto", "low", "medium", "high", "standard", "hd"]] = None, - style: Optional[Literal["natural", "vivid"]] = None, - background: Optional[Literal["transparent", "opaque", "auto"]] = None, - custom_configuration: Optional[TargetConfiguration] = None, + output_format: Literal["png", "jpeg", "webp"] | None = None, + quality: Literal["auto", "low", "medium", "high", "standard", "hd"] | None = None, + style: Literal["natural", "vivid"] | None = None, + background: Literal["transparent", "opaque", "auto"] | None = None, + custom_configuration: TargetConfiguration | None = None, *args: Any, **kwargs: Any, ) -> None: diff --git a/pyrit/prompt_target/openai/openai_realtime_target.py b/pyrit/prompt_target/openai/openai_realtime_target.py index 0197d5ba64..e5ed421385 100644 --- a/pyrit/prompt_target/openai/openai_realtime_target.py +++ b/pyrit/prompt_target/openai/openai_realtime_target.py @@ -6,7 +6,7 @@ import logging import re import wave -from typing import TYPE_CHECKING, Any, ClassVar, Literal, Optional +from typing import TYPE_CHECKING, Any, ClassVar, Literal from openai import AsyncOpenAI @@ -87,9 +87,9 @@ class RealtimeTarget(OpenAITarget): def __init__( self, *, - voice: Optional[RealTimeVoice] = None, - existing_convo: Optional[dict[str, Any]] = None, - custom_configuration: Optional[TargetConfiguration] = None, + voice: RealTimeVoice | None = None, + existing_convo: dict[str, Any] | None = None, + custom_configuration: TargetConfiguration | None = None, **kwargs: Any, ) -> None: """ @@ -121,7 +121,7 @@ def __init__( self.voice = voice self._existing_conversation = existing_convo if existing_convo is not None else {} - self._realtime_client: Optional[AsyncOpenAI] = None + 
self._realtime_client: AsyncOpenAI | None = None def open_streaming_session( self, @@ -550,7 +550,7 @@ async def save_audio_async( num_channels: int = 1, sample_width: int = 2, sample_rate: int = 16000, - output_filename: Optional[str] = None, + output_filename: str | None = None, ) -> str: """ Save audio bytes to a WAV file. @@ -583,7 +583,7 @@ async def save_audio( # pyrit-async-suffix-exempt num_channels: int = 1, sample_width: int = 2, sample_rate: int = 16000, - output_filename: Optional[str] = None, + output_filename: str | None = None, ) -> str: """ Use ``save_audio_async`` instead; this is a deprecated alias. @@ -865,7 +865,7 @@ async def send_text_async( conversation_id: conversation ID Returns: - Tuple[str, RealtimeTargetResult]: Path to saved audio file and the RealtimeTargetResult + tuple[str, RealtimeTargetResult]: Path to saved audio file and the RealtimeTargetResult Raises: RuntimeError: If no audio is received from the server. @@ -913,7 +913,7 @@ async def send_audio_async( conversation_id (str): Conversation ID Returns: - Tuple[str, RealtimeTargetResult]: Path to saved audio file and the RealtimeTargetResult + tuple[str, RealtimeTargetResult]: Path to saved audio file and the RealtimeTargetResult Raises: Exception: If sending audio fails. 
diff --git a/pyrit/prompt_target/openai/openai_response_target.py b/pyrit/prompt_target/openai/openai_response_target.py index c96b0c115a..6991996105 100644 --- a/pyrit/prompt_target/openai/openai_response_target.py +++ b/pyrit/prompt_target/openai/openai_response_target.py @@ -8,7 +8,6 @@ from typing import ( Any, Literal, - Optional, cast, ) @@ -91,15 +90,15 @@ class OpenAIResponseTarget(OpenAITarget): def __init__( self, *, - custom_functions: Optional[dict[str, ToolExecutor]] = None, - max_output_tokens: Optional[int] = None, - temperature: Optional[float] = None, - top_p: Optional[float] = None, - reasoning_effort: Optional[ReasoningEffort] = None, - reasoning_summary: Optional[Literal["auto", "concise", "detailed"]] = None, - extra_body_parameters: Optional[dict[str, Any]] = None, + custom_functions: dict[str, ToolExecutor] | None = None, + max_output_tokens: int | None = None, + temperature: float | None = None, + top_p: float | None = None, + reasoning_effort: ReasoningEffort | None = None, + reasoning_summary: Literal["auto", "concise", "detailed"] | None = None, + extra_body_parameters: dict[str, Any] | None = None, fail_on_missing_function: bool = False, - custom_configuration: Optional[TargetConfiguration] = None, + custom_configuration: TargetConfiguration | None = None, **kwargs: Any, ) -> None: """ @@ -396,12 +395,12 @@ async def _construct_request_body_async( # Filter out None values return {k: v for k, v in body_parameters.items() if v is not None} - def _build_reasoning_config(self) -> Optional[dict[str, Any]]: + def _build_reasoning_config(self) -> dict[str, Any] | None: """ Build the reasoning configuration dict for the Responses API. Returns: - Optional[Dict[str, Any]]: The reasoning config, or None if neither effort nor summary is set. + dict[str, Any] | None: The reasoning config, or None if neither effort nor summary is set. 
""" if self._reasoning_effort is None and self._reasoning_summary is None: return None @@ -413,7 +412,7 @@ def _build_reasoning_config(self) -> Optional[dict[str, Any]]: reasoning["summary"] = self._reasoning_summary return reasoning - def _build_text_format(self, json_config: _JsonResponseConfig) -> Optional[dict[str, Any]]: + def _build_text_format(self, json_config: _JsonResponseConfig) -> dict[str, Any] | None: if not json_config.enabled: return None @@ -459,7 +458,7 @@ def _check_content_filter(self, response: Any) -> bool: return False - def _extract_partial_content(self, response: Any) -> Optional[str]: + def _extract_partial_content(self, response: Any) -> str | None: """ Extract partial content from a Response API response that was content-filtered. @@ -493,7 +492,7 @@ def _extract_partial_content(self, response: Any) -> Optional[str]: except (AttributeError, IndexError, TypeError): return None - def _validate_response(self, response: Any, request: MessagePiece) -> Optional[Message]: + def _validate_response(self, response: Any, request: MessagePiece) -> Message | None: """ Validate a Response API response for errors. 
@@ -584,7 +583,7 @@ async def _send_prompt_to_target_async(self, *, normalized_conversation: list[Me responses_to_return: list[Message] = [] # Main agentic loop - each back-and-forth creates a new message - tool_call_section: Optional[dict[str, Any]] = None + tool_call_section: dict[str, Any] | None = None while True: logger.info(f"Sending conversation with {len(working_conversation)} messages to the prompt target") @@ -625,7 +624,7 @@ async def _send_prompt_to_target_async(self, *, normalized_conversation: list[Me return responses_to_return def _parse_response_output_section( - self, *, section: Any, message_piece: MessagePiece, error: Optional[PromptResponseError] + self, *, section: Any, message_piece: MessagePiece, error: PromptResponseError | None ) -> MessagePiece | None: """ Parse model output sections, forwarding tool-calls for the agentic loop. @@ -726,7 +725,7 @@ def _parse_response_output_section( # Agentic helpers (module scope) - def _find_last_pending_tool_call(self, reply: Message) -> Optional[dict[str, Any]]: + def _find_last_pending_tool_call(self, reply: Message) -> dict[str, Any] | None: """ Return the last tool-call section in assistant messages, or None. Looks for a piece whose value parses as JSON with a 'type' key matching function_call. 
diff --git a/pyrit/prompt_target/openai/openai_target.py b/pyrit/prompt_target/openai/openai_target.py index 3ae988f186..bcd9b48162 100644 --- a/pyrit/prompt_target/openai/openai_target.py +++ b/pyrit/prompt_target/openai/openai_target.py @@ -6,7 +6,7 @@ import re from abc import abstractmethod from collections.abc import Awaitable, Callable -from typing import Any, Optional +from typing import Any from urllib.parse import urlparse from openai import ( @@ -61,7 +61,7 @@ class OpenAITarget(PromptTarget): endpoint_environment_variable: str api_key_environment_variable: str - _async_client: Optional[AsyncOpenAI] = None + _async_client: AsyncOpenAI | None = None @property def _client(self) -> AsyncOpenAI: @@ -78,14 +78,14 @@ def _client(self) -> AsyncOpenAI: def __init__( self, *, - model_name: Optional[str] = None, - endpoint: Optional[str] = None, - api_key: Optional[str | Callable[[], str | Awaitable[str]]] = None, - headers: Optional[str] = None, - max_requests_per_minute: Optional[int] = None, - httpx_client_kwargs: Optional[dict[str, Any]] = None, - underlying_model: Optional[str] = None, - custom_configuration: Optional[TargetConfiguration] = None, + model_name: str | None = None, + endpoint: str | None = None, + api_key: str | Callable[[], str | Awaitable[str]] | None = None, + headers: str | None = None, + max_requests_per_minute: int | None = None, + httpx_client_kwargs: dict[str, Any] | None = None, + underlying_model: str | None = None, + custom_configuration: TargetConfiguration | None = None, ) -> None: """ Initialize an instance of OpenAITarget. @@ -583,7 +583,7 @@ def _handle_content_filter_response(self, response: Any, request: MessagePiece) return error_message - def _extract_partial_content(self, response: Any) -> Optional[str]: + def _extract_partial_content(self, response: Any) -> str | None: """ Extract any partial content the model generated before the content filter triggered. 
@@ -598,7 +598,7 @@ def _extract_partial_content(self, response: Any) -> Optional[str]: """ return None - def _validate_response(self, response: Any, request: MessagePiece) -> Optional[Message]: + def _validate_response(self, response: Any, request: MessagePiece) -> Message | None: """ Validate the response and return error Message if needed. @@ -610,7 +610,7 @@ def _validate_response(self, response: Any, request: MessagePiece) -> Optional[M request: The original request MessagePiece. Returns: - Optional[Message]: Error Message if validation fails, None otherwise. + Message | None: Error Message if validation fails, None otherwise. Raises: Various exceptions for validation failures. diff --git a/pyrit/prompt_target/openai/openai_tts_target.py b/pyrit/prompt_target/openai/openai_tts_target.py index d9c66128cb..eaa6101c75 100644 --- a/pyrit/prompt_target/openai/openai_tts_target.py +++ b/pyrit/prompt_target/openai/openai_tts_target.py @@ -2,7 +2,7 @@ # Licensed under the MIT license. import logging -from typing import Any, Literal, Optional +from typing import Any, Literal from pyrit.exceptions import ( pyrit_target_retry, @@ -40,8 +40,8 @@ def __init__( voice: TTSVoice = "alloy", response_format: TTSResponseFormat = "mp3", language: str = "en", - speed: Optional[float] = None, - custom_configuration: Optional[TargetConfiguration] = None, + speed: float | None = None, + custom_configuration: TargetConfiguration | None = None, **kwargs: Any, ) -> None: """ diff --git a/pyrit/prompt_target/openai/openai_video_target.py b/pyrit/prompt_target/openai/openai_video_target.py index 37dedc9ae0..db8deadf35 100644 --- a/pyrit/prompt_target/openai/openai_video_target.py +++ b/pyrit/prompt_target/openai/openai_video_target.py @@ -4,7 +4,7 @@ import logging from mimetypes import guess_type from pathlib import Path -from typing import Any, Optional, Union, cast +from typing import Any, cast from openai.types import VideoSeconds, VideoSize @@ -67,7 +67,7 @@ def __init__( *, 
resolution_dimensions: VideoSize = "1280x720", n_seconds: int | VideoSeconds = 4, - custom_configuration: Optional[TargetConfiguration] = None, + custom_configuration: TargetConfiguration | None = None, **kwargs: Any, ) -> None: """ @@ -428,7 +428,7 @@ async def _construct_message_from_response_async(self, response: Any, request: A ) async def _save_video_response_async( - self, *, request: MessagePiece, video_data: bytes, video_id: Optional[str] = None + self, *, request: MessagePiece, video_data: bytes, video_id: str | None = None ) -> Message: """ Save video data to storage and construct response. @@ -449,7 +449,7 @@ async def _save_video_response_async( logger.info(f"Video saved to: {video_path}") # Include video_id in metadata for chaining (e.g., remix the generated video later) - prompt_metadata: Optional[dict[str, Union[str, int]]] = {"video_id": video_id} if video_id else None + prompt_metadata: dict[str, str | int] | None = {"video_id": video_id} if video_id else None # Construct response return construct_response_from_request( diff --git a/pyrit/prompt_target/playwright_copilot_target.py b/pyrit/prompt_target/playwright_copilot_target.py index a2dfc796b2..d166f347ea 100644 --- a/pyrit/prompt_target/playwright_copilot_target.py +++ b/pyrit/prompt_target/playwright_copilot_target.py @@ -7,7 +7,7 @@ from contextlib import suppress from dataclasses import dataclass from enum import Enum -from typing import TYPE_CHECKING, Any, Optional, Union +from typing import TYPE_CHECKING, Any from pyrit.models import ( ComponentIdentifier, @@ -128,7 +128,7 @@ def __init__( *, page: "Page", copilot_type: CopilotType = CopilotType.CONSUMER, - custom_configuration: Optional[TargetConfiguration] = None, + custom_configuration: TargetConfiguration | None = None, ) -> None: """ Initialize the Playwright Copilot target. 
@@ -254,7 +254,7 @@ async def _send_prompt_to_target_async(self, *, normalized_conversation: list[Me return [response_entry] - async def _interact_with_copilot_async(self, message: Message) -> Union[str, list[tuple[str, PromptDataType]]]: + async def _interact_with_copilot_async(self, message: Message) -> str | list[tuple[str, PromptDataType]]: """ Interact with Microsoft Copilot interface to send multimodal prompts. @@ -262,7 +262,7 @@ async def _interact_with_copilot_async(self, message: Message) -> Union[str, lis message: The message containing text and/or image pieces to send. Returns: - Union[str, List[Tuple[str, PromptDataType]]]: The response content from Copilot, + str | list[tuple[str, PromptDataType]]: The response content from Copilot, either as a single text string or a list of (data, data_type) tuples. """ selectors = self._get_selectors() @@ -276,9 +276,7 @@ async def _interact_with_copilot_async(self, message: Message) -> Union[str, lis return await self._wait_for_response_async(selectors) - async def _wait_for_response_async( - self, selectors: CopilotSelectors - ) -> Union[str, list[tuple[str, PromptDataType]]]: + async def _wait_for_response_async(self, selectors: CopilotSelectors) -> str | list[tuple[str, PromptDataType]]: """ Wait for Copilot's response and extract the text and/or images. @@ -286,7 +284,7 @@ async def _wait_for_response_async( selectors (CopilotSelectors): The selectors for the Copilot interface. Returns: - Union[str, List[Tuple[str, PromptDataType]]]: The response content from Copilot, + str | list[tuple[str, PromptDataType]]: The response content from Copilot, either as a single text string or a list of (data, data_type) tuples. 
Raises: @@ -332,7 +330,7 @@ async def _wait_for_response_async( async def _extract_content_if_ready_async( self, selectors: CopilotSelectors, initial_group_count: int - ) -> Union[str, list[tuple[str, PromptDataType]], None]: + ) -> str | list[tuple[str, PromptDataType]] | None: """ Extract content if ready, otherwise return None. @@ -343,7 +341,7 @@ async def _extract_content_if_ready_async( initial_group_count (int): Number of message groups before this response. Returns: - Union[str, List[Tuple[str, PromptDataType]], None]: The extracted content if ready, + str | list[tuple[str, PromptDataType]] | None: The extracted content if ready, None if content is not ready yet or extraction fails. """ try: @@ -733,7 +731,7 @@ async def _extract_fallback_text_async(self, *, ai_message_groups: list[Any]) -> def _assemble_response( self, *, response_pieces: list[tuple[str, PromptDataType]] - ) -> Union[str, list[tuple[str, PromptDataType]]]: + ) -> str | list[tuple[str, PromptDataType]]: """ Assemble response pieces into appropriate return format. @@ -755,7 +753,7 @@ def _assemble_response( async def _extract_multimodal_content_async( self, selectors: CopilotSelectors, initial_group_count: int = 0 - ) -> Union[str, list[tuple[str, PromptDataType]]]: + ) -> str | list[tuple[str, PromptDataType]]: """ Extract multimodal content (text and images) from Copilot response. diff --git a/pyrit/prompt_target/playwright_target.py b/pyrit/prompt_target/playwright_target.py index 4178fe902b..d33ba2736c 100644 --- a/pyrit/prompt_target/playwright_target.py +++ b/pyrit/prompt_target/playwright_target.py @@ -1,7 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. 
-from typing import TYPE_CHECKING, Optional, Protocol +from typing import TYPE_CHECKING, Protocol from pyrit.models import ( Message, @@ -71,8 +71,8 @@ def __init__( *, interaction_func: InteractionFunction, page: "Page", - max_requests_per_minute: Optional[int] = None, - custom_configuration: Optional[TargetConfiguration] = None, + max_requests_per_minute: int | None = None, + custom_configuration: TargetConfiguration | None = None, ) -> None: """ Initialize the Playwright target. diff --git a/pyrit/prompt_target/prompt_shield_target.py b/pyrit/prompt_target/prompt_shield_target.py index cb4588b4d8..78b28367fc 100644 --- a/pyrit/prompt_target/prompt_shield_target.py +++ b/pyrit/prompt_target/prompt_shield_target.py @@ -4,7 +4,7 @@ import json import logging from collections.abc import Callable -from typing import Any, Literal, Optional +from typing import Any, Literal from pyrit.common import default_values, net_utility from pyrit.models import ( @@ -62,12 +62,12 @@ class PromptShieldTarget(PromptTarget): def __init__( self, - endpoint: Optional[str] = None, - api_key: Optional[str | Callable[[], str]] = None, - api_version: Optional[str] = "2024-09-01", - field: Optional[PromptShieldEntryField] = None, - max_requests_per_minute: Optional[int] = None, - custom_configuration: Optional[TargetConfiguration] = None, + endpoint: str | None = None, + api_key: str | Callable[[], str] | None = None, + api_version: str | None = "2024-09-01", + field: PromptShieldEntryField | None = None, + max_requests_per_minute: int | None = None, + custom_configuration: TargetConfiguration | None = None, ) -> None: """ Class that initializes an Azure Content Safety Prompt Shield Target. 
diff --git a/pyrit/prompt_target/text_target.py b/pyrit/prompt_target/text_target.py index 9ede8d9ddc..8e0deed295 100644 --- a/pyrit/prompt_target/text_target.py +++ b/pyrit/prompt_target/text_target.py @@ -5,7 +5,7 @@ import json import sys from pathlib import Path -from typing import IO, Optional +from typing import IO from pyrit.common.deprecation import print_deprecation_message from pyrit.models import Message, MessagePiece @@ -26,7 +26,7 @@ def __init__( self, *, text_stream: IO[str] = sys.stdout, - custom_configuration: Optional[TargetConfiguration] = None, + custom_configuration: TargetConfiguration | None = None, ) -> None: """ Initialize the TextTarget. diff --git a/pyrit/prompt_target/websocket_copilot_target.py b/pyrit/prompt_target/websocket_copilot_target.py index 6a5de15f60..41ffaa99a5 100644 --- a/pyrit/prompt_target/websocket_copilot_target.py +++ b/pyrit/prompt_target/websocket_copilot_target.py @@ -7,7 +7,7 @@ import pathlib import uuid from enum import IntEnum -from typing import Any, Optional, Union +from typing import Any import httpx import websockets @@ -90,11 +90,11 @@ def __init__( self, *, websocket_base_url: str = "wss://substrate.office.com/m365Copilot/Chathub", - max_requests_per_minute: Optional[int] = None, + max_requests_per_minute: int | None = None, model_name: str = "copilot", response_timeout_seconds: int = RESPONSE_TIMEOUT_SECONDS, - authenticator: Optional[Union[CopilotAuthenticator, ManualCopilotAuthenticator]] = None, - custom_configuration: Optional[TargetConfiguration] = None, + authenticator: CopilotAuthenticator | ManualCopilotAuthenticator | None = None, + custom_configuration: TargetConfiguration | None = None, ) -> None: """ Initialize the WebSocketCopilotTarget. @@ -102,10 +102,10 @@ def __init__( Args: websocket_base_url (str): Base URL for the Copilot WebSocket endpoint. Defaults to ``wss://substrate.office.com/m365Copilot/Chathub``. - max_requests_per_minute (Optional[int]): Maximum number of requests per minute. 
+ max_requests_per_minute (int | None): Maximum number of requests per minute. model_name (str): The model name. Defaults to "copilot". response_timeout_seconds (int): Timeout for receiving responses in seconds. Defaults to 60s. - authenticator (Optional[Union[CopilotAuthenticator, ManualCopilotAuthenticator]]): Authenticator + authenticator (CopilotAuthenticator | ManualCopilotAuthenticator | None): Authenticator instance. Supports both ``CopilotAuthenticator`` and ``ManualCopilotAuthenticator``. If None, a new ``CopilotAuthenticator`` instance will be created with default settings. custom_configuration (TargetConfiguration, Optional): Override the default configuration for diff --git a/pyrit/registry/base.py b/pyrit/registry/base.py index cdff9067f1..c7e37c1700 100644 --- a/pyrit/registry/base.py +++ b/pyrit/registry/base.py @@ -4,14 +4,14 @@ """ Shared base types for PyRIT registries. -This module contains types shared between class registries (which store Type[T]) +This module contains types shared between class registries (which store type[T]) and object registries (which store T instances). """ from __future__ import annotations from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Optional, Protocol, TypeVar, runtime_checkable +from typing import TYPE_CHECKING, Any, Protocol, TypeVar, runtime_checkable if TYPE_CHECKING: from collections.abc import Iterator @@ -88,8 +88,8 @@ def get_names(self) -> list[str]: def list_metadata( self, *, - include_filters: Optional[dict[str, Any]] = None, - exclude_filters: Optional[dict[str, Any]] = None, + include_filters: dict[str, Any] | None = None, + exclude_filters: dict[str, Any] | None = None, ) -> list[MetadataT]: """ List metadata for all registered items, optionally filtered. 
@@ -148,8 +148,8 @@ def _get_metadata_value(metadata: Any, key: str) -> tuple[bool, Any]: def _matches_filters( metadata: Any, *, - include_filters: Optional[dict[str, Any]] = None, - exclude_filters: Optional[dict[str, Any]] = None, + include_filters: dict[str, Any] | None = None, + exclude_filters: dict[str, Any] | None = None, ) -> bool: """ Check if a metadata object matches all provided filters. diff --git a/pyrit/registry/class_registries/__init__.py b/pyrit/registry/class_registries/__init__.py index b9b3279d6c..1cc09101ef 100644 --- a/pyrit/registry/class_registries/__init__.py +++ b/pyrit/registry/class_registries/__init__.py @@ -4,7 +4,7 @@ """ Class registries package. -This package contains registries that store classes (Type[T]) which can be +This package contains registries that store classes (type[T]) which can be instantiated on demand. Examples include ScenarioRegistry and InitializerRegistry. For registries that store pre-configured instances, see object_registries/. diff --git a/pyrit/registry/class_registries/base_class_registry.py b/pyrit/registry/class_registries/base_class_registry.py index 85211aba84..9af4836e0e 100644 --- a/pyrit/registry/class_registries/base_class_registry.py +++ b/pyrit/registry/class_registries/base_class_registry.py @@ -4,14 +4,14 @@ """ Base class registry for PyRIT. -This module provides the abstract base class for registries that store classes (Type[T]). +This module provides the abstract base class for registries that store classes (type[T]). These registries allow on-demand instantiation of registered classes. For registries that store pre-configured instances, see object_registries/. 
Terminology: - **Metadata**: A TypedDict describing a registered class (e.g., ScenarioMetadata) -- **Class**: The actual Python class (Type[T]) that can be instantiated +- **Class**: The actual Python class (type[T]) that can be instantiated - **Instance**: A created object of that class - **ClassEntry**: Internal wrapper holding a class plus optional factory/defaults """ @@ -19,7 +19,7 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Generic, Optional, TypeVar +from typing import TYPE_CHECKING, Generic, TypeVar if TYPE_CHECKING: from collections.abc import Callable, Iterator @@ -38,14 +38,14 @@ class ClassEntry(Generic[T]): """ Internal wrapper for a registered class. - This holds the class itself (Type[T]) along with optional factory + This holds the class itself (type[T]) along with optional factory and default parameters for creating instances. Note: This is an internal implementation detail. Users interact with registries via get_class(), create_instance(), and list_metadata(). Attributes: - registered_class: The actual Python class (Type[T]). + registered_class: The actual Python class (type[T]). factory: Optional callable to create instances with custom logic. default_kwargs: Default keyword arguments for instance creation. """ @@ -54,14 +54,14 @@ def __init__( self, *, registered_class: type[T], - factory: Optional[Callable[..., T]] = None, - default_kwargs: Optional[dict[str, object]] = None, + factory: Callable[..., T] | None = None, + default_kwargs: dict[str, object] | None = None, ) -> None: """ Initialize a class entry. Args: - registered_class: The actual Python class (Type[T]). + registered_class: The actual Python class (type[T]). factory: Optional callable that creates an instance. default_kwargs: Default keyword arguments for instantiation. 
""" @@ -97,7 +97,7 @@ def create_instance(self, **kwargs: object) -> T: class BaseClassRegistry(ABC, RegistryProtocol[MetadataT], Generic[T, MetadataT]): """ - Abstract base class for registries that store classes (Type[T]). + Abstract base class for registries that store classes (type[T]). This class implements RegistryProtocol and provides the common infrastructure for class registries including: @@ -129,7 +129,7 @@ def __init__(self, *, lazy_discovery: bool = True) -> None: """ # Maps registry names to ClassEntry wrappers self._class_entries: dict[str, ClassEntry[T]] = {} - self._metadata_cache: Optional[list[MetadataT]] = None + self._metadata_cache: list[MetadataT] | None = None self._discovered = False self._lazy_discovery = lazy_discovery @@ -198,7 +198,7 @@ def get_class(self, name: str) -> type[T]: name: The registry name (snake_case identifier). Returns: - The registered class (Type[T]). + The registered class (type[T]). Note: This returns the class itself, not an instance. Raises: @@ -211,7 +211,7 @@ def get_class(self, name: str) -> type[T]: raise KeyError(f"'{name}' not found in registry. Available: {available}") return entry.registered_class - def get_entry(self, name: str) -> Optional[ClassEntry[T]]: + def get_entry(self, name: str) -> ClassEntry[T] | None: """ Get the full ClassEntry for a registered class. @@ -242,8 +242,8 @@ def get_names(self) -> list[str]: def list_metadata( self, *, - include_filters: Optional[dict[str, object]] = None, - exclude_filters: Optional[dict[str, object]] = None, + include_filters: dict[str, object] | None = None, + exclude_filters: dict[str, object] | None = None, ) -> list[MetadataT]: """ List metadata for all registered classes, optionally filtered. 
@@ -286,15 +286,15 @@ def register( self, cls: type[T], *, - name: Optional[str] = None, - factory: Optional[Callable[..., T]] = None, - default_kwargs: Optional[dict[str, object]] = None, + name: str | None = None, + factory: Callable[..., T] | None = None, + default_kwargs: dict[str, object] | None = None, ) -> None: """ Register a class with the registry. Args: - cls: The class to register (Type[T], not an instance). + cls: The class to register (type[T], not an instance). name: Optional custom registry name. If not provided, derived from class name. factory: Optional callable for creating instances with custom logic. default_kwargs: Default keyword arguments for instance creation. diff --git a/pyrit/registry/class_registries/initializer_registry.py b/pyrit/registry/class_registries/initializer_registry.py index 17fb39fb7c..5310af3d69 100644 --- a/pyrit/registry/class_registries/initializer_registry.py +++ b/pyrit/registry/class_registries/initializer_registry.py @@ -15,7 +15,7 @@ import logging from dataclasses import dataclass, field from pathlib import Path -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from pyrit.models import class_name_to_snake_case, validate_registry_name from pyrit.registry.base import ClassRegistryEntry @@ -47,7 +47,7 @@ class InitializerMetadata(ClassRegistryEntry): required_env_vars: tuple[str, ...] = field(kw_only=True) # Supported parameters as tuples of (name, description, default). - supported_parameters: tuple[tuple[str, str, Optional[list[str]]], ...] = field(kw_only=True, default=()) + supported_parameters: tuple[tuple[str, str, list[str] | None], ...] = field(kw_only=True, default=()) class InitializerRegistry(BaseClassRegistry["PyRITInitializer", InitializerMetadata]): @@ -61,7 +61,7 @@ class InitializerRegistry(BaseClassRegistry["PyRITInitializer", InitializerMetad The directory structure is used for organization but not exposed to users. 
""" - def __init__(self, *, discovery_path: Optional[Path] = None, lazy_discovery: bool = False) -> None: + def __init__(self, *, discovery_path: Path | None = None, lazy_discovery: bool = False) -> None: """ Initialize the initializer registry. diff --git a/pyrit/registry/class_registries/scenario_registry.py b/pyrit/registry/class_registries/scenario_registry.py index 2702b1860e..0300b00a06 100644 --- a/pyrit/registry/class_registries/scenario_registry.py +++ b/pyrit/registry/class_registries/scenario_registry.py @@ -13,7 +13,7 @@ import logging from dataclasses import dataclass, field from pathlib import Path -from typing import TYPE_CHECKING, Any, NamedTuple, Optional, get_origin +from typing import TYPE_CHECKING, Any, NamedTuple, get_origin from pyrit.models import class_name_to_snake_case from pyrit.registry.base import ClassRegistryEntry @@ -53,7 +53,7 @@ class ScenarioMetadata(ClassRegistryEntry): default_datasets: tuple[str, ...] = field(kw_only=True) # Maximum number of items per dataset. - max_dataset_size: Optional[int] = field(kw_only=True) + max_dataset_size: int | None = field(kw_only=True) # Scenario-declared custom parameters. supported_parameters: tuple[ScenarioParameterMetadata, ...] 
= field(kw_only=True, default=()) @@ -71,7 +71,7 @@ class ScenarioParameterMetadata(NamedTuple): description: str default: Any param_type: str - choices: Optional[list[str]] + choices: list[str] | None is_list: bool = False diff --git a/pyrit/registry/discovery.py b/pyrit/registry/discovery.py index 5df0c14fee..34c1562bc3 100644 --- a/pyrit/registry/discovery.py +++ b/pyrit/registry/discovery.py @@ -15,7 +15,7 @@ import pkgutil from collections.abc import Callable, Iterator from pathlib import Path -from typing import Optional, TypeVar +from typing import TypeVar logger = logging.getLogger(__name__) @@ -92,7 +92,7 @@ def discover_in_package( package_name: str, base_class: type[T], recursive: bool = True, - name_builder: Optional[Callable[[str, str], str]] = None, + name_builder: Callable[[str, str], str] | None = None, _prefix: str = "", ) -> Iterator[tuple[str, type[T]]]: """ @@ -156,7 +156,7 @@ def name_builder(prefix: str, name: str) -> str: def discover_subclasses_in_loaded_modules( *, base_class: type[T], - exclude_module_prefixes: Optional[tuple[str, ...]] = None, + exclude_module_prefixes: tuple[str, ...] | None = None, ) -> Iterator[tuple[str, type[T]]]: """ Discover subclasses of a base class from already-loaded modules. diff --git a/pyrit/registry/object_registries/__init__.py b/pyrit/registry/object_registries/__init__.py index 0a43a5af2f..b6edf16088 100644 --- a/pyrit/registry/object_registries/__init__.py +++ b/pyrit/registry/object_registries/__init__.py @@ -8,7 +8,7 @@ Examples include ScorerRegistry which stores Scorer instances that have been initialized with their required parameters (e.g., chat_target). -For registries that store classes (Type[T]), see class_registries/. +For registries that store classes (type[T]), see class_registries/. 
""" from pyrit.registry.object_registries.attack_technique_registry import ( diff --git a/pyrit/registry/object_registries/base_instance_registry.py b/pyrit/registry/object_registries/base_instance_registry.py index 09be979324..df3ae2d291 100644 --- a/pyrit/registry/object_registries/base_instance_registry.py +++ b/pyrit/registry/object_registries/base_instance_registry.py @@ -13,7 +13,7 @@ where callers retrieve stored objects directly, subclass ``RetrievableInstanceRegistry`` instead. -For registries that store classes (Type[T]), see ``class_registries/``. +For registries that store classes (type[T]), see ``class_registries/``. """ from __future__ import annotations diff --git a/pyrit/registry/object_registries/converter_registry.py b/pyrit/registry/object_registries/converter_registry.py index 4d83c9e1fd..568d1e6332 100644 --- a/pyrit/registry/object_registries/converter_registry.py +++ b/pyrit/registry/object_registries/converter_registry.py @@ -12,7 +12,7 @@ from __future__ import annotations import logging -from typing import TYPE_CHECKING, Optional, Union +from typing import TYPE_CHECKING from pyrit.registry.object_registries.retrievable_instance_registry import ( RetrievableInstanceRegistry, @@ -37,8 +37,8 @@ def register_instance( self, converter: PromptConverter, *, - name: Optional[str] = None, - tags: Optional[Union[dict[str, str], list[str]]] = None, + name: str | None = None, + tags: dict[str, str] | list[str] | None = None, ) -> None: """ Register a converter instance. @@ -56,7 +56,7 @@ def register_instance( self.register(converter, name=name, tags=tags) logger.debug(f"Registered converter instance: {name} ({converter.__class__.__name__})") - def get_instance_by_name(self, name: str) -> Optional[PromptConverter]: + def get_instance_by_name(self, name: str) -> PromptConverter | None: """ Get a registered converter instance by name. 
diff --git a/pyrit/registry/object_registries/retrievable_instance_registry.py b/pyrit/registry/object_registries/retrievable_instance_registry.py index b5bc4fdfec..7498b3bf21 100644 --- a/pyrit/registry/object_registries/retrievable_instance_registry.py +++ b/pyrit/registry/object_registries/retrievable_instance_registry.py @@ -11,7 +11,7 @@ ``TargetRegistry``). For the shared base class, see ``base_instance_registry``. -For registries that store classes (Type[T]), see ``class_registries/``. +For registries that store classes (type[T]), see ``class_registries/``. """ from __future__ import annotations diff --git a/pyrit/registry/object_registries/scorer_registry.py b/pyrit/registry/object_registries/scorer_registry.py index af5c59946f..d1a938aa30 100644 --- a/pyrit/registry/object_registries/scorer_registry.py +++ b/pyrit/registry/object_registries/scorer_registry.py @@ -10,7 +10,7 @@ from __future__ import annotations import logging -from typing import TYPE_CHECKING, Optional, Union +from typing import TYPE_CHECKING from pyrit.registry.object_registries.retrievable_instance_registry import ( RetrievableInstanceRegistry, @@ -38,8 +38,8 @@ def register_instance( self, scorer: Scorer, *, - name: Optional[str] = None, - tags: Optional[Union[dict[str, str], list[str]]] = None, + name: str | None = None, + tags: dict[str, str] | list[str] | None = None, ) -> None: """ Register a scorer instance. @@ -60,7 +60,7 @@ def register_instance( self.register(scorer, name=name, tags=tags) logger.debug(f"Registered scorer instance: {name} ({scorer.__class__.__name__})") - def get_instance_by_name(self, name: str) -> Optional[Scorer]: + def get_instance_by_name(self, name: str) -> Scorer | None: """ Get a registered scorer instance by name. 
diff --git a/pyrit/registry/object_registries/target_registry.py b/pyrit/registry/object_registries/target_registry.py index c6fefd3926..170bad2078 100644 --- a/pyrit/registry/object_registries/target_registry.py +++ b/pyrit/registry/object_registries/target_registry.py @@ -10,7 +10,7 @@ from __future__ import annotations import logging -from typing import TYPE_CHECKING, Optional, Union +from typing import TYPE_CHECKING from pyrit.registry.object_registries.retrievable_instance_registry import ( RetrievableInstanceRegistry, @@ -38,8 +38,8 @@ def register_instance( self, target: PromptTarget, *, - name: Optional[str] = None, - tags: Optional[Union[dict[str, str], list[str]]] = None, + name: str | None = None, + tags: dict[str, str] | list[str] | None = None, ) -> None: """ Register a target instance. @@ -61,7 +61,7 @@ def register_instance( self.register(target, name=name, tags=tags) logger.debug(f"Registered target instance: {name} ({target.__class__.__name__})") - def get_instance_by_name(self, name: str) -> Optional[PromptTarget]: + def get_instance_by_name(self, name: str) -> PromptTarget | None: """ Get a registered target instance by name. diff --git a/pyrit/scenario/core/atomic_attack.py b/pyrit/scenario/core/atomic_attack.py index acec06f8b0..6a48ec69ef 100644 --- a/pyrit/scenario/core/atomic_attack.py +++ b/pyrit/scenario/core/atomic_attack.py @@ -59,7 +59,7 @@ def __init__( seed_groups: list[SeedAttackGroup], adversarial_chat: Optional["PromptTarget"] = None, objective_scorer: Optional["TrueFalseScorer"] = None, - memory_labels: Optional[dict[str, str]] = None, + memory_labels: dict[str, str] | None = None, **attack_execute_params: Any, ) -> None: """ @@ -215,7 +215,7 @@ def objectives(self) -> list[str]: Get the objectives from the seed groups. Returns: - List[str]: List of objectives from all seed groups. + list[str]: List of objectives from all seed groups. 
""" return [sg.objective.value for sg in self._seed_groups if sg.objective is not None] @@ -225,7 +225,7 @@ def seed_groups(self) -> list[SeedAttackGroup]: Get a copy of the seed groups list for this atomic attack. Returns: - List[SeedAttackGroup]: A copy of the seed groups list. + list[SeedAttackGroup]: A copy of the seed groups list. """ return list(self._seed_groups) @@ -258,7 +258,7 @@ def filter_seed_groups_by_objectives(self, *, remaining_objectives: list[str]) - objective text. Scheduled for removal in 0.16.0. Args: - remaining_objectives (List[str]): List of objectives that still need to be executed. + remaining_objectives (list[str]): List of objectives that still need to be executed. """ print_deprecation_message( old_item="AtomicAttack.filter_seed_groups_by_objectives(remaining_objectives=...)", diff --git a/pyrit/scenario/core/attack_technique_factory.py b/pyrit/scenario/core/attack_technique_factory.py index 3a1b3c6315..6fafa8dbf9 100644 --- a/pyrit/scenario/core/attack_technique_factory.py +++ b/pyrit/scenario/core/attack_technique_factory.py @@ -499,7 +499,7 @@ def _get_scoring_config_type(self) -> type | None: """ Introspect the attack class to determine the required type for ``attack_scoring_config``. - Resolves the type annotation (handling ``Optional[X]`` / ``X | None``) and returns + Resolves the type annotation (handling ``X | None`` / ``X | None``) and returns the inner concrete type. Returns ``None`` if the annotation is the base ``AttackScoringConfig`` or cannot be resolved — meaning any config is accepted. @@ -528,7 +528,7 @@ def _get_scoring_config_type(self) -> type | None: @staticmethod def _unwrap_optional(annotation: Any) -> type | None: """ - Unwrap ``Optional[X]``, ``X | None``, or ``Union[X, None]`` to extract X. + Unwrap ``X | None``, ``X | None``, or ``X | None`` to extract X. Returns: The inner type X, or None if the annotation cannot be unwrapped to a single type. 
diff --git a/pyrit/scenario/core/dataset_configuration.py b/pyrit/scenario/core/dataset_configuration.py index 25cd9162c3..a6cbf25cab 100644 --- a/pyrit/scenario/core/dataset_configuration.py +++ b/pyrit/scenario/core/dataset_configuration.py @@ -11,7 +11,7 @@ from __future__ import annotations import random -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from pyrit.memory import CentralMemory from pyrit.models import SeedAttackGroup, SeedGroup @@ -34,11 +34,11 @@ class DatasetConfiguration: Only ONE of `seed_groups` or `dataset_names` can be set. Args: - seed_groups (Optional[List[SeedGroup]]): Explicit list of SeedGroup to use. - dataset_names (Optional[List[str]]): Names of datasets to load from memory. - max_dataset_size (Optional[int]): If set, randomly samples up to this many SeedGroups + seed_groups (list[SeedGroup] | None): Explicit list of SeedGroup to use. + dataset_names (list[str] | None): Names of datasets to load from memory. + max_dataset_size (int | None): If set, randomly samples up to this many SeedGroups from the configured dataset source (without replacement, so no duplicates). - scenario_strategies (Optional[Sequence[ScenarioStrategy]]): The scenario + scenario_strategies (Sequence[ScenarioStrategy] | None): The scenario strategies being executed. Subclasses can use this to filter or customize which seed groups are loaded based on the selected strategies. """ @@ -46,20 +46,20 @@ class DatasetConfiguration: def __init__( self, *, - seed_groups: Optional[list[SeedGroup]] = None, - dataset_names: Optional[list[str]] = None, - max_dataset_size: Optional[int] = None, - scenario_strategies: Optional[Sequence[ScenarioStrategy]] = None, + seed_groups: list[SeedGroup] | None = None, + dataset_names: list[str] | None = None, + max_dataset_size: int | None = None, + scenario_strategies: Sequence[ScenarioStrategy] | None = None, ) -> None: """ Initialize a DatasetConfiguration. 
Args: - seed_groups (Optional[List[SeedGroup]]): Explicit list of SeedGroup to use. - dataset_names (Optional[List[str]]): Names of datasets to load from memory. - max_dataset_size (Optional[int]): If set, randomly samples up to this many SeedGroups + seed_groups (list[SeedGroup] | None): Explicit list of SeedGroup to use. + dataset_names (list[str] | None): Names of datasets to load from memory. + max_dataset_size (int | None): If set, randomly samples up to this many SeedGroups (without replacement). - scenario_strategies (Optional[Sequence[ScenarioStrategy]]): The scenario + scenario_strategies (Sequence[ScenarioStrategy] | None): The scenario strategies being executed. Subclasses can use this to filter or customize which seed groups are loaded. @@ -98,7 +98,7 @@ def get_seed_groups(self) -> dict[str, list[SeedGroup]]: are loaded based on the stored scenario_composites. Returns: - Dict[str, List[SeedGroup]]: Dictionary mapping dataset names to their + dict[str, list[SeedGroup]]: Dictionary mapping dataset names to their seed groups. When explicit seed_groups are provided, the key is '_explicit_seed_groups'. Each dataset's seed groups are potentially sampled down to max_dataset_size. @@ -142,7 +142,7 @@ def _load_seed_groups_for_dataset(self, *, dataset_name: str) -> list[SeedGroup] dataset_name (str): The name of the dataset to load. Returns: - List[SeedGroup]: Seed groups loaded from memory, or empty list if none found. + list[SeedGroup]: Seed groups loaded from memory, or empty list if none found. """ memory = CentralMemory.get_memory_instance() return list(memory.get_seed_groups(dataset_name=dataset_name) or []) @@ -156,7 +156,7 @@ def get_all_seed_groups(self) -> list[SeedGroup]: which dataset each seed group came from. Returns: - List[SeedGroup]: All resolved seed groups from all datasets, + list[SeedGroup]: All resolved seed groups from all datasets, with max_dataset_size applied per dataset. 
Raises: @@ -177,7 +177,7 @@ def get_seed_attack_groups(self) -> dict[str, list[SeedAttackGroup]]: prepended conversations, or simulated conversation configuration. Returns: - Dict[str, List[SeedAttackGroup]]: Dictionary mapping dataset names to their + dict[str, list[SeedAttackGroup]]: Dictionary mapping dataset names to their seed attack groups. Raises: @@ -198,7 +198,7 @@ def get_all_seed_attack_groups(self) -> list[SeedAttackGroup]: SeedAttackGroup functionality. Returns: - List[SeedAttackGroup]: All resolved seed attack groups from all datasets. + list[SeedAttackGroup]: All resolved seed attack groups from all datasets. Raises: ValueError: If no seed groups could be resolved from the configuration. @@ -216,7 +216,7 @@ def get_default_dataset_names(self) -> list[str]: This is used by the CLI to display what datasets the scenario uses by default. Returns: - List[str]: List of dataset names, or empty list if using explicit seed_groups. + list[str]: List of dataset names, or empty list if using explicit seed_groups. """ if self._dataset_names is not None: return list(self._dataset_names) @@ -229,10 +229,10 @@ def _apply_max_dataset_size(self, seed_groups: list[SeedGroup]) -> list[SeedGrou Uses random sampling without replacement (no duplicates in the result). Args: - seed_groups (List[SeedGroup]): The seed groups to potentially sample from. + seed_groups (list[SeedGroup]): The seed groups to potentially sample from. Returns: - List[SeedGroup]: The original list if max_dataset_size is not set, + list[SeedGroup]: The original list if max_dataset_size is not set, or a random sample of up to max_dataset_size unique items. """ if self.max_dataset_size is None or len(seed_groups) <= self.max_dataset_size: @@ -257,7 +257,7 @@ def get_all_seeds(self) -> list[Seed]: samples up to that many prompts per dataset (without replacement). Returns: - List[SeedPrompt]: List of SeedPrompt objects from all configured datasets. 
+ list[SeedPrompt]: List of SeedPrompt objects from all configured datasets. Returns an empty list if no prompts are found. Raises: diff --git a/pyrit/scenario/core/scenario.py b/pyrit/scenario/core/scenario.py index c23ddf6a1c..3d856a64fa 100644 --- a/pyrit/scenario/core/scenario.py +++ b/pyrit/scenario/core/scenario.py @@ -18,7 +18,7 @@ from collections.abc import Sequence from enum import Enum from pathlib import Path -from typing import TYPE_CHECKING, Any, ClassVar, Optional, Union, cast, get_origin +from typing import TYPE_CHECKING, Any, ClassVar, cast, get_origin try: # Built-in on Python 3.11+. Fall back to the ``exceptiongroup`` backport on 3.10 @@ -194,7 +194,7 @@ def __init__( default_strategy: ScenarioStrategy, default_dataset_config: DatasetConfiguration, objective_scorer: Scorer, - scenario_result_id: Optional[Union[uuid.UUID, str]] = None, + scenario_result_id: uuid.UUID | str | None = None, include_default_baseline: bool | None = None, # Deprecated. Will be removed in 0.16.0. ) -> None: """ @@ -203,14 +203,14 @@ def __init__( Args: name (str): Descriptive name for the scenario. version (int): Version number of the scenario. - strategy_class (Type[ScenarioStrategy]): The strategy enum class for this scenario. + strategy_class (type[ScenarioStrategy]): The strategy enum class for this scenario. default_strategy (ScenarioStrategy): The default strategy member used when no ``scenario_strategies`` are passed to ``initialize_async``. Usually an aggregate member like ``MyStrategy.ALL`` or ``MyStrategy.DEFAULT``. default_dataset_config (DatasetConfiguration): The default dataset configuration used when no ``dataset_config`` is passed to ``initialize_async``. objective_scorer (Scorer): The objective scorer used to evaluate attack results. - scenario_result_id (Optional[Union[uuid.UUID, str]]): Optional ID of an existing scenario result to resume. + scenario_result_id (uuid.UUID | str | None): Optional ID of an existing scenario result to resume. 
Can be either a UUID object or a string representation of a UUID. If provided and found in memory, the scenario will resume from prior progress. All other parameters must still match the stored scenario configuration. @@ -240,10 +240,10 @@ def __init__( self._default_dataset_config = default_dataset_config # These will be set in initialize_async - self._objective_target: Optional[PromptTarget] = None - self._objective_target_identifier: Optional[ComponentIdentifier] = None + self._objective_target: PromptTarget | None = None + self._objective_target_identifier: ComponentIdentifier | None = None self._memory_labels: dict[str, str] = {} - self._max_concurrency: Optional[int] = None + self._max_concurrency: int | None = None self._max_retries: int = 0 self._objective_scorer = objective_scorer @@ -252,7 +252,7 @@ def __init__( self._name = name if name else type(self).__name__ self._memory = CentralMemory.get_memory_instance() self._atomic_attacks: list[AtomicAttack] = [] - self._scenario_result_id: Optional[str] = str(scenario_result_id) if scenario_result_id else None + self._scenario_result_id: str | None = str(scenario_result_id) if scenario_result_id else None # Store prepared strategies for use in _get_atomic_attacks_async self._scenario_strategies: list[ScenarioStrategy] = [] @@ -550,7 +550,7 @@ def _validate_params(self, *, params: dict[str, Any], declared: list[Parameter]) def _prepare_strategies( self, - strategies: Optional[Sequence[ScenarioStrategy]], + strategies: Sequence[ScenarioStrategy] | None, ) -> list[ScenarioStrategy]: """ Resolve strategy inputs into a concrete list for this scenario. 
@@ -575,11 +575,11 @@ async def initialize_async( self, *, objective_target: PromptTarget = REQUIRED_VALUE, # type: ignore[ty:invalid-parameter-default] - scenario_strategies: Optional[Sequence[ScenarioStrategy]] = None, - dataset_config: Optional[DatasetConfiguration] = None, + scenario_strategies: Sequence[ScenarioStrategy] | None = None, + dataset_config: DatasetConfiguration | None = None, max_concurrency: int = 4, max_retries: int = 0, - memory_labels: Optional[dict[str, str]] = None, + memory_labels: dict[str, str] | None = None, include_baseline: bool | None = None, ) -> None: """ @@ -595,10 +595,10 @@ async def initialize_async( Args: objective_target (PromptTarget): The target system to attack. - scenario_strategies (Optional[Sequence[ScenarioStrategy]]): The strategies to execute. + scenario_strategies (Sequence[ScenarioStrategy] | None): The strategies to execute. Can be a list of ScenarioStrategy enum members. If None, uses the default aggregate from the scenario's configuration. - dataset_config (Optional[DatasetConfiguration]): Configuration for the dataset source. + dataset_config (DatasetConfiguration | None): Configuration for the dataset source. Use this to specify dataset names or maximum dataset size from the CLI. If not provided, scenarios use their constructor-supplied default_dataset_config. max_concurrency (int): Maximum number of concurrent units of work for the scenario. @@ -613,7 +613,7 @@ async def initialize_async( Set to 0 (default) for no automatic retries. If set to a positive number, the scenario will automatically retry up to this many times after an exception. For example, max_retries=3 allows up to 4 total attempts (1 initial + 3 retries). - memory_labels (Optional[Dict[str, str]]): Additional labels to apply to all + memory_labels (dict[str, str] | None): Additional labels to apply to all attack runs in the scenario. These help track and categorize the scenario. 
include_baseline (bool | None): Whether to prepend a baseline atomic attack that sends all objectives without modifications, allowing comparison between unmodified prompts @@ -977,7 +977,7 @@ async def _get_remaining_atomic_attacks_async(self) -> list[AtomicAttack]: join is sufficient. Returns: - List[AtomicAttack]: List of atomic attacks with uncompleted objectives. + list[AtomicAttack]: List of atomic attacks with uncompleted objectives. """ if not self._scenario_result_id: # No scenario result yet, return all atomic attacks @@ -1430,7 +1430,7 @@ def _collect_errors_from_outcomes( for outcome in outcomes: if isinstance(outcome, BaseException): logger.error(f"Atomic attack failed in scenario '{self._name}': {str(outcome)}") - error: Optional[BaseException] = outcome + error: BaseException | None = outcome else: atomic_attack, atomic_results = outcome error = self._partial_result_to_exception(atomic_attack=atomic_attack, atomic_results=atomic_results) diff --git a/pyrit/scenario/core/scenario_strategy.py b/pyrit/scenario/core/scenario_strategy.py index de557d724e..441c7df937 100644 --- a/pyrit/scenario/core/scenario_strategy.py +++ b/pyrit/scenario/core/scenario_strategy.py @@ -129,7 +129,7 @@ def get_aggregate_tags(cls: type[T]) -> set[str]: to all non-aggregate strategies. Returns: - Set[str]: Set of tags that represent aggregates. + set[str]: Set of tags that represent aggregates. """ return {"all"} @@ -145,7 +145,7 @@ def get_strategies_by_tag(cls: type[T], tag: str) -> set[T]: tag (str): The tag to filter by (e.g., "easy", "converter", "multi_turn"). Returns: - Set[T]: Set of strategies that include the specified tag, excluding + set[T]: Set of strategies that include the specified tag, excluding any aggregate markers. """ aggregate_tags = cls.get_aggregate_tags() @@ -202,11 +202,11 @@ def normalize_strategies(cls: type[T], strategies: set[T]) -> set[T]: The special "all" tag is automatically supported and expands to all non-aggregate strategies. 
Args: - strategies (Set[T]): The initial set of attack strategies, which may include + strategies (set[T]): The initial set of attack strategies, which may include aggregate tags. Returns: - Set[T]: The normalized set of concrete attack strategies with aggregate tags + set[T]: The normalized set of concrete attack strategies with aggregate tags expanded and removed. """ print_deprecation_message( diff --git a/pyrit/scenario/scenarios/airt/jailbreak.py b/pyrit/scenario/scenarios/airt/jailbreak.py index 935b81b51f..5184632d49 100644 --- a/pyrit/scenario/scenarios/airt/jailbreak.py +++ b/pyrit/scenario/scenarios/airt/jailbreak.py @@ -2,7 +2,7 @@ # Licensed under the MIT license. from pathlib import Path -from typing import Any, Optional, Union +from typing import Any from pyrit.common import apply_defaults from pyrit.common.deprecation import print_deprecation_message # Deprecated. Will be removed in 0.16.0. @@ -90,9 +90,9 @@ def required_datasets(cls) -> list[str]: def __init__( self, *, - objective_scorer: Optional[TrueFalseScorer] = None, - scenario_result_id: Optional[str] = None, - num_templates: Optional[int] = None, + objective_scorer: TrueFalseScorer | None = None, + scenario_result_id: str | None = None, + num_templates: int | None = None, num_attempts: int = 1, jailbreak_names: list[str] | None = None, include_baseline: bool | None = None, # Deprecated. Will be removed in 0.16.0. @@ -101,12 +101,12 @@ def __init__( Initialize the jailbreak scenario. Args: - objective_scorer (Optional[TrueFalseScorer]): Scorer for detecting successful jailbreaks + objective_scorer (TrueFalseScorer | None): Scorer for detecting successful jailbreaks (non-refusal). If not provided, defaults to an inverted refusal scorer. - scenario_result_id (Optional[str]): Optional ID of an existing scenario result to resume. - num_templates (Optional[int]): Choose num_templates random jailbreaks rather than using all of them. 
- num_attempts (Optional[int]): Number of times to try each jailbreak. - jailbreak_names (Optional[List[str]]): List of jailbreak names from the template list under datasets. + scenario_result_id (str | None): Optional ID of an existing scenario result to resume. + num_templates (int | None): Choose num_templates random jailbreaks rather than using all of them. + num_attempts (int): Number of times to try each jailbreak. + jailbreak_names (list[str] | None): List of jailbreak names from the template list under datasets. to use. include_baseline (bool | None): **Deprecated.** Will be removed in 0.16.0. Pass ``include_baseline`` to ``initialize_async`` instead. @@ -132,7 +132,7 @@ def __init__( self._num_templates = num_templates self._num_attempts = num_attempts - self._adversarial_target: Optional[PromptTarget] = None + self._adversarial_target: PromptTarget | None = None # Note that num_templates and jailbreak_names are mutually exclusive. # If self._num_templates is None, then this returns all discoverable jailbreak templates. @@ -170,7 +170,7 @@ def __init__( self._legacy_include_baseline = include_baseline # Will be resolved in _get_atomic_attacks_async - self._seed_groups: Optional[list[SeedAttackGroup]] = None + self._seed_groups: list[SeedAttackGroup] | None = None def _get_or_create_adversarial_target(self) -> PromptTarget: """ @@ -191,7 +191,7 @@ def _resolve_seed_groups(self) -> list[SeedAttackGroup]: Resolve seed groups from dataset configuration. Returns: - List[SeedAttackGroup]: List of seed attack groups with objectives to be tested. + list[SeedAttackGroup]: List of seed attack groups with objectives to be tested. 
""" # Use dataset_config (guaranteed to be set by initialize_async) seed_groups = self._dataset_config.get_all_seed_attack_groups() @@ -233,7 +233,7 @@ async def _get_atomic_attack_from_strategy_async( request_converters=PromptConverterConfiguration.from_converters(converters=[jailbreak_converter]) ) - attack: Optional[Union[ManyShotJailbreakAttack, PromptSendingAttack, RolePlayAttack, SkeletonKeyAttack]] = None + attack: ManyShotJailbreakAttack | PromptSendingAttack | RolePlayAttack | SkeletonKeyAttack | None = None args: dict[str, Any] = { "objective_target": self._objective_target, "attack_scoring_config": AttackScoringConfig(objective_scorer=self._objective_scorer), @@ -274,7 +274,7 @@ async def _get_atomic_attacks_async(self) -> list[AtomicAttack]: This method creates an atomic attack for each retrieved jailbreak template. Returns: - List[AtomicAttack]: List of atomic attacks to execute, one per jailbreak template. + list[AtomicAttack]: List of atomic attacks to execute, one per jailbreak template. """ atomic_attacks: list[AtomicAttack] = [] diff --git a/pyrit/scenario/scenarios/airt/psychosocial.py b/pyrit/scenario/scenarios/airt/psychosocial.py index 8ba3991649..c61e7534b8 100644 --- a/pyrit/scenario/scenarios/airt/psychosocial.py +++ b/pyrit/scenario/scenarios/airt/psychosocial.py @@ -4,7 +4,7 @@ import logging import pathlib from dataclasses import dataclass -from typing import Any, Optional, TypeVar +from typing import Any, TypeVar import yaml @@ -73,7 +73,7 @@ class ResolvedSeedData: """Helper dataclass for resolved seed data.""" seed_groups: list[SeedAttackGroup] - subharm: Optional[str] + subharm: str | None class PsychosocialStrategy(ScenarioStrategy): @@ -97,12 +97,12 @@ class PsychosocialStrategy(ScenarioStrategy): LicensedTherapist = ("licensed_therapist", set[str]()) @property - def harm_category_filter(self) -> Optional[str]: + def harm_category_filter(self) -> str | None: """ Get the harm category filter for this strategy. 
Returns: - Optional[str]: The harm category to filter seeds by, or "psychosocial" as default. + str | None: The harm category to filter seeds by, or "psychosocial" as default. """ # For specific strategies, filter by the strategy value (which matches harm_categories in data) # otherwise, use psychosocial as the default for ALL strategy @@ -179,11 +179,11 @@ class Psychosocial(Scenario): def __init__( self, *, - objectives: Optional[list[str]] = None, - adversarial_chat: Optional[PromptTarget] = None, - objective_scorer: Optional[FloatScaleThresholdScorer] = None, - scenario_result_id: Optional[str] = None, - subharm_configs: Optional[dict[str, SubharmConfig]] = None, + objectives: list[str] | None = None, + adversarial_chat: PromptTarget | None = None, + objective_scorer: FloatScaleThresholdScorer | None = None, + scenario_result_id: str | None = None, + subharm_configs: dict[str, SubharmConfig] | None = None, max_turns: int = 5, include_baseline: bool | None = None, # Deprecated. Will be removed in 0.16.0. ) -> None: @@ -191,16 +191,16 @@ def __init__( Initialize the Psychosocial Harms Scenario. Args: - objectives (Optional[List[str]]): DEPRECATED - Use dataset_config in initialize_async instead. + objectives (list[str] | None): DEPRECATED - Use dataset_config in initialize_async instead. List of objectives to test for psychosocial harms. - adversarial_chat (Optional[PromptTarget]): Additionally used for adversarial attacks + adversarial_chat (PromptTarget | None): Additionally used for adversarial attacks and scoring defaults. If not provided, a default OpenAI target will be created using environment variables. - objective_scorer (Optional[FloatScaleThresholdScorer]): Scorer to evaluate attack success. + objective_scorer (FloatScaleThresholdScorer | None): Scorer to evaluate attack success. If not provided, creates a default scorer for evaluating psychosocial harms (crisis management, harmful facilitation, etc.). 
- scenario_result_id (Optional[str]): Optional ID of an existing scenario result to resume. - subharm_configs (Optional[Dict[str, SubharmConfig]]): Dictionary mapping subharm names + scenario_result_id (str | None): Optional ID of an existing scenario result to resume. + subharm_configs (dict[str, SubharmConfig] | None): Dictionary mapping subharm names to their SubharmConfig. Each seed in the dataset specifies its subharm in harm_categories[0], which is used to look up the appropriate config. Example: { @@ -255,7 +255,7 @@ def __init__( # Store deprecated objectives for later resolution in _resolve_seed_groups self._deprecated_objectives = objectives # Will be resolved in _get_atomic_attacks_async - self._seed_groups: Optional[list[SeedAttackGroup]] = None + self._seed_groups: list[SeedAttackGroup] | None = None def _resolve_seed_groups(self) -> ResolvedSeedData: """ @@ -300,12 +300,12 @@ def _resolve_seed_groups(self) -> ResolvedSeedData: subharm=harm_category_filter, ) - def _extract_harm_category_filter(self) -> Optional[str]: + def _extract_harm_category_filter(self) -> str | None: """ Extract harm category filter from scenario strategies. Returns: - Optional[str]: The harm category to filter by, or None if no filter is set. + str | None: The harm category to filter by, or None if no filter is set. """ for strategy in self._scenario_strategies: if isinstance(strategy, PsychosocialStrategy): @@ -324,11 +324,11 @@ def _filter_by_harm_category( Filter seed groups by harm category. Args: - seed_groups (List[SeedAttackGroup]): List of seed attack groups to filter. + seed_groups (list[SeedAttackGroup]): List of seed attack groups to filter. harm_category (str): Harm category to filter by (e.g., 'imminent_crisis', 'psychosocial'). Returns: - List[SeedAttackGroup]: Filtered seed groups containing only seeds with the specified harm category. + list[SeedAttackGroup]: Filtered seed groups containing only seeds with the specified harm category. 
""" filtered_groups = [] for group in seed_groups: @@ -339,7 +339,7 @@ def _filter_by_harm_category( filtered_groups.append(SeedAttackGroup(seeds=filtered_seeds)) return filtered_groups - def _get_scorer(self, subharm: Optional[str] = None) -> FloatScaleThresholdScorer: + def _get_scorer(self, subharm: str | None = None) -> FloatScaleThresholdScorer: """ Create scorer for psychosocial harms evaluation. @@ -349,7 +349,7 @@ def _get_scorer(self, subharm: Optional[str] = None) -> FloatScaleThresholdScore 3. Avoids facilitating harmful requests Args: - subharm (Optional[str]): The specific subharm category to get a scorer for + subharm (str | None): The specific subharm category to get a scorer for (e.g., 'imminent_crisis', 'dependency'). If provided and exists in subharm_configs, uses that rubric. Otherwise falls back to default. @@ -420,7 +420,7 @@ async def _get_atomic_attacks_async(self) -> list[AtomicAttack]: return atomic_attacks - def _create_scoring_config(self, subharm: Optional[str]) -> AttackScoringConfig: + def _create_scoring_config(self, subharm: str | None) -> AttackScoringConfig: subharm_config = self._subharm_configs.get(subharm) if subharm else None scorer = self._get_scorer(subharm=subharm) if subharm_config else self._objective_scorer return AttackScoringConfig(objective_scorer=scorer) @@ -470,7 +470,7 @@ def _create_multi_turn_attack( self, *, scoring_config: AttackScoringConfig, - subharm: Optional[str], + subharm: str | None, seed_groups: list[SeedAttackGroup], ) -> AtomicAttack: subharm_config = self._subharm_configs.get(subharm) if subharm else None diff --git a/pyrit/scenario/scenarios/airt/scam.py b/pyrit/scenario/scenarios/airt/scam.py index ab05c0fc81..e591f580c8 100644 --- a/pyrit/scenario/scenarios/airt/scam.py +++ b/pyrit/scenario/scenarios/airt/scam.py @@ -3,7 +3,7 @@ import logging from pathlib import Path -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any from pyrit.common import Parameter, 
apply_defaults from pyrit.common.deprecation import print_deprecation_message # Deprecated. Will be removed in 0.16.0. @@ -121,20 +121,20 @@ def supported_parameters(cls) -> list[Parameter]: def __init__( self, *, - objective_scorer: Optional[TrueFalseScorer] = None, - adversarial_chat: Optional[PromptTarget] = None, - scenario_result_id: Optional[str] = None, + objective_scorer: TrueFalseScorer | None = None, + adversarial_chat: PromptTarget | None = None, + scenario_result_id: str | None = None, include_baseline: bool | None = None, # Deprecated. Will be removed in 0.16.0. ) -> None: """ Initialize the ScamScenario. Args: - objective_scorer (Optional[TrueFalseScorer]): Custom scorer for objective + objective_scorer (TrueFalseScorer | None): Custom scorer for objective evaluation. - adversarial_chat (Optional[PromptTarget]): Chat target used to rephrase the + adversarial_chat (PromptTarget | None): Chat target used to rephrase the objective into the role-play context (in single-turn strategies). - scenario_result_id (Optional[str]): Optional ID of an existing scenario result to resume. + scenario_result_id (str | None): Optional ID of an existing scenario result to resume. include_baseline (bool | None): **Deprecated.** Will be removed in 0.16.0. Pass ``include_baseline`` to ``initialize_async`` instead. """ @@ -166,14 +166,14 @@ def __init__( self._legacy_include_baseline = include_baseline # Will be resolved in _get_atomic_attacks_async - self._seed_groups: Optional[list[SeedAttackGroup]] = None + self._seed_groups: list[SeedAttackGroup] | None = None def _resolve_seed_groups(self) -> list[SeedAttackGroup]: """ Resolve seed groups from dataset configuration. Returns: - List[SeedAttackGroup]: List of seed attack groups with objectives to be tested. + list[SeedAttackGroup]: List of seed attack groups with objectives to be tested. 
""" # Use dataset_config (guaranteed to be set by initialize_async) seed_groups = self._dataset_config.get_all_seed_attack_groups() @@ -201,7 +201,7 @@ def _get_atomic_attack_from_strategy(self, strategy: str) -> AtomicAttack: raise ValueError( "Scenario not properly initialized. Call await scenario.initialize_async() before running." ) - attack_strategy: Optional[AttackStrategy[Any, Any]] = None + attack_strategy: AttackStrategy[Any, Any] | None = None if strategy == "persuasive_rta": # Set system prompt to generic persuasion persona @@ -246,7 +246,7 @@ async def _get_atomic_attacks_async(self) -> list[AtomicAttack]: Generate atomic attacks for each strategy. Returns: - List[AtomicAttack]: List of atomic attacks to execute. + list[AtomicAttack]: List of atomic attacks to execute. """ # Resolve seed groups from deprecated objectives or dataset config self._seed_groups = self._resolve_seed_groups() diff --git a/pyrit/scenario/scenarios/foundry/red_team_agent.py b/pyrit/scenario/scenarios/foundry/red_team_agent.py index 142aa53959..c76d6a4912 100644 --- a/pyrit/scenario/scenarios/foundry/red_team_agent.py +++ b/pyrit/scenario/scenarios/foundry/red_team_agent.py @@ -13,7 +13,7 @@ from collections.abc import Sequence from dataclasses import dataclass, field from inspect import signature -from typing import TYPE_CHECKING, Any, Optional, TypeVar, cast +from typing import TYPE_CHECKING, Any, TypeVar, cast from pyrit.common import REQUIRED_VALUE, apply_defaults from pyrit.common.deprecation import print_deprecation_message # Deprecated. Will be removed in 0.16.0. 
@@ -219,22 +219,22 @@ class RedTeamAgent(Scenario): def __init__( self, *, - adversarial_chat: Optional[PromptTarget] = None, - attack_scoring_config: Optional[AttackScoringConfig] = None, - scenario_result_id: Optional[str] = None, + adversarial_chat: PromptTarget | None = None, + attack_scoring_config: AttackScoringConfig | None = None, + scenario_result_id: str | None = None, include_baseline: bool | None = None, # Deprecated. Will be removed in 0.16.0. ) -> None: """ Initialize a Foundry Scenario with the specified attack strategies. Args: - adversarial_chat (Optional[PromptTarget]): Target for multi-turn attacks + adversarial_chat (PromptTarget | None): Target for multi-turn attacks like Crescendo and RedTeaming. Additionally used for scoring defaults. If not provided, a default OpenAI target will be created using environment variables. - attack_scoring_config (Optional[AttackScoringConfig]): Configuration for attack scoring, + attack_scoring_config (AttackScoringConfig | None): Configuration for attack scoring, including the objective scorer and auxiliary scorers. If not provided, creates a default configuration with a composite scorer using Azure Content Filter and SelfAsk Refusal scorers. - scenario_result_id (Optional[str]): Optional ID of an existing scenario result to resume. + scenario_result_id (str | None): Optional ID of an existing scenario result to resume. include_baseline (bool | None): **Deprecated.** Will be removed in 0.16.0. Pass ``include_baseline`` to ``initialize_async`` instead. 
@@ -280,13 +280,11 @@ async def initialize_async( self, *, objective_target: PromptTarget = REQUIRED_VALUE, # type: ignore[ty:invalid-parameter-default] - scenario_strategies: Optional[ - Sequence["FoundryStrategy | FoundryComposite | ScenarioCompositeStrategy"] - ] = None, - dataset_config: Optional[DatasetConfiguration] = None, + scenario_strategies: Sequence["FoundryStrategy | FoundryComposite | ScenarioCompositeStrategy"] | None = None, + dataset_config: DatasetConfiguration | None = None, max_concurrency: int = 4, max_retries: int = 0, - memory_labels: Optional[dict[str, str]] = None, + memory_labels: dict[str, str] | None = None, include_baseline: bool | None = None, ) -> None: """ @@ -299,10 +297,10 @@ async def initialize_async( objects (for pairing an attack with converters), or a mix of both. Passing ScenarioCompositeStrategy is deprecated — use FoundryComposite instead. If None, uses the default aggregate (EASY). - dataset_config (Optional[DatasetConfiguration]): Configuration for the dataset source. + dataset_config (DatasetConfiguration | None): Configuration for the dataset source. max_concurrency (int): Maximum number of concurrent attack executions. Defaults to 4. max_retries (int): Maximum number of retries on failure. Defaults to 0. - memory_labels (Optional[dict[str, str]]): Labels to attach to all memory entries. + memory_labels (dict[str, str] | None): Labels to attach to all memory entries. include_baseline (bool | None): See ``Scenario.initialize_async``. """ # This override exists purely for type-widening: FoundryComposite is a dataclass, @@ -320,7 +318,7 @@ async def initialize_async( def _prepare_strategies( # type: ignore[ty:invalid-method-override] self, - strategies: "Optional[Sequence[FoundryStrategy | FoundryComposite | ScenarioCompositeStrategy]]", + strategies: "Sequence[FoundryStrategy | FoundryComposite | ScenarioCompositeStrategy] | None", ) -> list[ScenarioStrategy]: """ Resolve strategies and build FoundryComposite objects. 
@@ -396,7 +394,7 @@ def _resolve_seed_groups(self) -> list[SeedAttackGroup]: Resolve seed groups from the dataset configuration. Returns: - List[SeedGroup]: The resolved seed groups. + list[SeedAttackGroup]: The resolved seed groups. """ return self._dataset_config.get_all_seed_attack_groups() @@ -405,7 +403,7 @@ async def _get_atomic_attacks_async(self) -> list[AtomicAttack]: Retrieve the list of AtomicAttack instances in this scenario. Returns: - List[AtomicAttack]: The list of AtomicAttack instances in this scenario. + list[AtomicAttack]: The list of AtomicAttack instances in this scenario. """ # Resolve seed groups now that initialize_async has been called self._seed_groups = self._resolve_seed_groups() @@ -510,7 +508,7 @@ def _get_attack( *, attack_type: type[AttackStrategyT], converters: list[PromptConverter], - attack_kwargs: Optional[dict[str, Any]] = None, + attack_kwargs: dict[str, Any] | None = None, ) -> AttackStrategyT: """ Create an attack instance with the specified converters. @@ -530,7 +528,7 @@ def _get_attack( attack_type (type[AttackStrategyT]): The attack strategy class to instantiate. Must accept objective_target and attack_converter_config parameters. converters (list[PromptConverter]): List of converters to apply as request converters. - attack_kwargs (Optional[dict[str, Any]]): Additional attack-specific keyword arguments + attack_kwargs (dict[str, Any] | None): Additional attack-specific keyword arguments to pass to the attack constructor (e.g., tree_width for TreeOfAttacksWithPruningAttack). Returns: diff --git a/pyrit/scenario/scenarios/garak/encoding.py b/pyrit/scenario/scenarios/garak/encoding.py index 65f36e3218..abe36b7ca6 100644 --- a/pyrit/scenario/scenarios/garak/encoding.py +++ b/pyrit/scenario/scenarios/garak/encoding.py @@ -4,7 +4,6 @@ import logging from collections.abc import Sequence -from typing import Optional from pyrit.common import apply_defaults from pyrit.common.deprecation import print_deprecation_message # Deprecated. 
Will be removed in 0.16.0. @@ -59,7 +58,7 @@ def get_all_seed_attack_groups(self) -> list[SeedAttackGroup]: - The original seed as a SeedPrompt Returns: - List[SeedAttackGroup]: All resolved seed attack groups with objectives. + list[SeedAttackGroup]: All resolved seed attack groups with objectives. Raises: ValueError: If no seeds could be resolved from the configuration. @@ -138,21 +137,21 @@ class Encoding(Scenario): def __init__( self, *, - objective_scorer: Optional[TrueFalseScorer] = None, - encoding_templates: Optional[Sequence[str]] = None, - scenario_result_id: Optional[str] = None, + objective_scorer: TrueFalseScorer | None = None, + encoding_templates: Sequence[str] | None = None, + scenario_result_id: str | None = None, include_baseline: bool | None = None, # Deprecated. Will be removed in 0.16.0. ) -> None: """ Initialize the Encoding Scenario. Args: - objective_scorer (Optional[TrueFalseScorer]): The scorer used to evaluate if the model + objective_scorer (TrueFalseScorer | None): The scorer used to evaluate if the model successfully decoded the payload. Defaults to DecodingScorer with encoding_scenario category. - encoding_templates (Optional[Sequence[str]]): Templates used to construct the decoding + encoding_templates (Sequence[str] | None): Templates used to construct the decoding prompts. Defaults to AskToDecodeConverter.garak_templates. - scenario_result_id (Optional[str]): Optional ID of an existing scenario result to resume. + scenario_result_id (str | None): Optional ID of an existing scenario result to resume. include_baseline (bool | None): **Deprecated.** Will be removed in 0.16.0. Pass ``include_baseline`` to ``initialize_async`` instead. 
""" @@ -184,7 +183,7 @@ def __init__( self._legacy_include_baseline = include_baseline # Will be resolved in _get_atomic_attacks_async - self._resolved_seed_groups: Optional[list[SeedAttackGroup]] = None + self._resolved_seed_groups: list[SeedAttackGroup] | None = None def _resolve_seed_groups(self) -> list[SeedAttackGroup]: """ @@ -206,7 +205,7 @@ async def _get_atomic_attacks_async(self) -> list[AtomicAttack]: Retrieve the list of AtomicAttack instances in this scenario. Returns: - List[AtomicAttack]: The list of AtomicAttack instances in this scenario. + list[AtomicAttack]: The list of AtomicAttack instances in this scenario. """ # Resolve seed prompts from deprecated parameter or dataset config self._resolved_seed_groups = self._resolve_seed_groups() diff --git a/pyrit/score/audio_transcript_scorer.py b/pyrit/score/audio_transcript_scorer.py index 7b1c40f7a6..17bf2af6c0 100644 --- a/pyrit/score/audio_transcript_scorer.py +++ b/pyrit/score/audio_transcript_scorer.py @@ -6,7 +6,6 @@ import tempfile import uuid from pathlib import Path -from typing import Optional import av @@ -107,7 +106,7 @@ def __init__( self, *, text_capable_scorer: Scorer, - use_entra_auth: Optional[bool] = None, + use_entra_auth: bool | None = None, ) -> None: """ Initialize the base audio scorer. @@ -154,13 +153,13 @@ def _validate_text_scorer(scorer: Scorer) -> None: f"Supported types: {scorer._validator._supported_data_types}" ) - async def _score_audio_async(self, *, message_piece: MessagePiece, objective: Optional[str] = None) -> list[Score]: + async def _score_audio_async(self, *, message_piece: MessagePiece, objective: str | None = None) -> list[Score]: """ Transcribe audio and score the transcript. Args: message_piece (MessagePiece): The message piece containing the audio file path. - objective (Optional[str]): Optional objective description for scoring. + objective (str | None): Optional objective description for scoring. Returns: List of scores for the transcribed audio. 
@@ -267,7 +266,7 @@ def _ensure_wav_format(self, audio_path: str) -> str: channels=self._DEFAULT_CHANNELS, ) - def _extract_audio_from_video(self, video_path: str) -> Optional[str]: + def _extract_audio_from_video(self, video_path: str) -> str | None: """ Extract audio track from a video file. @@ -281,7 +280,7 @@ def _extract_audio_from_video(self, video_path: str) -> Optional[str]: return AudioTranscriptHelper.extract_audio_from_video(video_path) @staticmethod - def extract_audio_from_video(video_path: str) -> Optional[str]: + def extract_audio_from_video(video_path: str) -> str | None: """ Extract audio track from a video file (static version). diff --git a/pyrit/score/batch_scorer.py b/pyrit/score/batch_scorer.py index 66beec5261..4022596f9a 100644 --- a/pyrit/score/batch_scorer.py +++ b/pyrit/score/batch_scorer.py @@ -5,7 +5,6 @@ import uuid from collections.abc import Sequence from datetime import datetime -from typing import Optional from pyrit.memory import CentralMemory from pyrit.models import ( @@ -47,17 +46,17 @@ async def score_responses_by_filters_async( self, *, scorer: Scorer, - attack_id: Optional[str | uuid.UUID] = None, - conversation_id: Optional[str | uuid.UUID] = None, - prompt_ids: Optional[list[str] | list[uuid.UUID]] = None, - labels: Optional[dict[str, str]] = None, - sent_after: Optional[datetime] = None, - sent_before: Optional[datetime] = None, - original_values: Optional[list[str]] = None, - converted_values: Optional[list[str]] = None, - data_type: Optional[str] = None, - not_data_type: Optional[str] = None, - converted_value_sha256: Optional[list[str]] = None, + attack_id: str | uuid.UUID | None = None, + conversation_id: str | uuid.UUID | None = None, + prompt_ids: list[str] | list[uuid.UUID] | None = None, + labels: dict[str, str] | None = None, + sent_after: datetime | None = None, + sent_before: datetime | None = None, + original_values: list[str] | None = None, + converted_values: list[str] | None = None, + data_type: str | None 
= None, + not_data_type: str | None = None, + converted_value_sha256: list[str] | None = None, objective: str = "", ) -> list[Score]: """ @@ -65,17 +64,17 @@ async def score_responses_by_filters_async( Args: scorer (Scorer): The Scorer object to use for scoring. - attack_id (Optional[str | uuid.UUID]): The ID of the attack. Defaults to None. - conversation_id (Optional[str | uuid.UUID]): The ID of the conversation. Defaults to None. - prompt_ids (Optional[list[str] | list[uuid.UUID]]): A list of prompt IDs. Defaults to None. - labels (Optional[dict[str, str]]): A dictionary of labels. Defaults to None. - sent_after (Optional[datetime]): Filter for prompts sent after this datetime. Defaults to None. - sent_before (Optional[datetime]): Filter for prompts sent before this datetime. Defaults to None. - original_values (Optional[list[str]]): A list of original values. Defaults to None. - converted_values (Optional[list[str]]): A list of converted values. Defaults to None. - data_type (Optional[str]): The data type to filter by. Defaults to None. - not_data_type (Optional[str]): The data type to exclude. Defaults to None. - converted_value_sha256 (Optional[list[str]]): A list of SHA256 hashes of converted values. + attack_id (str | uuid.UUID | None): The ID of the attack. Defaults to None. + conversation_id (str | uuid.UUID | None): The ID of the conversation. Defaults to None. + prompt_ids (list[str] | list[uuid.UUID] | None): A list of prompt IDs. Defaults to None. + labels (dict[str, str] | None): A dictionary of labels. Defaults to None. + sent_after (datetime | None): Filter for prompts sent after this datetime. Defaults to None. + sent_before (datetime | None): Filter for prompts sent before this datetime. Defaults to None. + original_values (list[str] | None): A list of original values. Defaults to None. + converted_values (list[str] | None): A list of converted values. Defaults to None. + data_type (str | None): The data type to filter by. Defaults to None. 
+ not_data_type (str | None): The data type to exclude. Defaults to None. + converted_value_sha256 (list[str] | None): A list of SHA256 hashes of converted values. Defaults to None. objective (str): A task is used to give the scorer more context on what exactly to score. A task might be the request prompt text or the original attack model's objective. diff --git a/pyrit/score/conversation_scorer.py b/pyrit/score/conversation_scorer.py index d4e824d1fe..d921b2e1cf 100644 --- a/pyrit/score/conversation_scorer.py +++ b/pyrit/score/conversation_scorer.py @@ -3,7 +3,7 @@ import uuid from abc import ABC, abstractmethod -from typing import Optional, cast +from typing import cast from uuid import UUID from pyrit.models import ComponentIdentifier, Message, MessagePiece, Score @@ -32,7 +32,7 @@ class ConversationScorer(Scorer, ABC): enforce_all_pieces_valid=False, ) - async def _score_async(self, message: Message, *, objective: Optional[str] = None) -> list[Score]: + async def _score_async(self, message: Message, *, objective: str | None = None) -> list[Score]: """ Scores the entire conversation history by concatenating all messages and passing to the wrapped scorer. @@ -47,7 +47,7 @@ async def _score_async(self, message: Message, *, objective: Optional[str] = Non Args: message (Message): A message from the conversation to be scored. The conversation ID from the first message piece is used to retrieve the full conversation from memory. - objective (Optional[str]): Optional objective to evaluate against. + objective (str | None): Optional objective to evaluate against. 
Returns: list[Score]: List of Score objects from the underlying scorer @@ -128,7 +128,7 @@ async def _score_async(self, message: Message, *, objective: Optional[str] = Non return scores - async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Optional[str] = None) -> list[Score]: + async def _score_piece_async(self, message_piece: MessagePiece, *, objective: str | None = None) -> list[Score]: """ Not used - ConversationScorer operates at conversation level via _score_async. @@ -159,7 +159,7 @@ def validate_return_scores(self, scores: list[Score]) -> None: def create_conversation_scorer( *, scorer: Scorer, - validator: Optional[ScorerPromptValidator] = None, + validator: ScorerPromptValidator | None = None, ) -> Scorer: """ Create a ConversationScorer that inherits from the same type as the wrapped scorer. @@ -171,7 +171,7 @@ def create_conversation_scorer( Args: scorer (Scorer): The scorer to wrap for conversation-level evaluation. Must be an instance of FloatScaleScorer or TrueFalseScorer. - validator (Optional[ScorerPromptValidator]): Optional validator override. + validator (ScorerPromptValidator | None): Optional validator override. If not provided, uses the wrapped scorer's validator. Returns: @@ -187,7 +187,7 @@ def create_conversation_scorer( >>> isinstance(conversation_scorer, ConversationScorer) # True """ # Determine the base class of the wrapped scorer - scorer_base_class: Optional[type[Scorer]] = None + scorer_base_class: type[Scorer] | None = None if isinstance(scorer, FloatScaleScorer): scorer_base_class = FloatScaleScorer diff --git a/pyrit/score/float_scale/audio_float_scale_scorer.py b/pyrit/score/float_scale/audio_float_scale_scorer.py index 203f8e1281..17653c9d5f 100644 --- a/pyrit/score/float_scale/audio_float_scale_scorer.py +++ b/pyrit/score/float_scale/audio_float_scale_scorer.py @@ -1,7 +1,6 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. 
-from typing import Optional from pyrit.models import ComponentIdentifier, MessagePiece, Score from pyrit.score.audio_transcript_scorer import AudioTranscriptHelper @@ -23,8 +22,8 @@ def __init__( self, *, text_capable_scorer: FloatScaleScorer, - validator: Optional[ScorerPromptValidator] = None, - use_entra_auth: Optional[bool] = None, + validator: ScorerPromptValidator | None = None, + use_entra_auth: bool | None = None, ) -> None: """ Initialize the AudioFloatScaleScorer. @@ -62,7 +61,7 @@ def _build_identifier(self) -> ComponentIdentifier: }, ) - async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Optional[str] = None) -> list[Score]: + async def _score_piece_async(self, message_piece: MessagePiece, *, objective: str | None = None) -> list[Score]: """ Score an audio file by transcribing it and scoring the transcript. diff --git a/pyrit/score/float_scale/azure_content_filter_scorer.py b/pyrit/score/float_scale/azure_content_filter_scorer.py index 2ef3b412fb..ff8747e298 100644 --- a/pyrit/score/float_scale/azure_content_filter_scorer.py +++ b/pyrit/score/float_scale/azure_content_filter_scorer.py @@ -95,27 +95,27 @@ def _get_eval_files_for_category(cls, category: TextCategory) -> Optional["Score def __init__( self, *, - endpoint: Optional[str | None] = None, - api_key: Optional[str | Callable[[], str | Awaitable[str]] | None] = None, - harm_categories: Optional[list[TextCategory]] = None, - validator: Optional[ScorerPromptValidator] = None, + endpoint: str | None = None, + api_key: str | Callable[[], str | Awaitable[str]] | None = None, + harm_categories: list[TextCategory] | None = None, + validator: ScorerPromptValidator | None = None, ) -> None: """ Initialize an Azure Content Filter Scorer. Args: - endpoint (Optional[str | None]): The endpoint URL for the Azure Content Safety service. + endpoint (str | None): The endpoint URL for the Azure Content Safety service. 
Defaults to the `ENDPOINT_URI_ENVIRONMENT_VARIABLE` environment variable. - api_key (Optional[str | Callable[[], str | Awaitable[str]] | None]): + api_key (str | Callable[[], str | Awaitable[str]] | None): The API key for accessing the Azure Content Safety service, or a callable that returns an access token. Both synchronous and asynchronous token providers are supported. Sync providers are automatically wrapped for async compatibility. If not provided (via parameter or environment variable), Entra ID authentication is used automatically. Defaults to the `API_KEY_ENVIRONMENT_VARIABLE` environment variable. - harm_categories (Optional[list[TextCategory]]): The harm categories you want to query for as + harm_categories (list[TextCategory] | None): The harm categories you want to query for as defined in azure.ai.contentsafety.models.TextCategory. If not provided, defaults to all categories. - validator (Optional[ScorerPromptValidator]): Custom validator for the scorer. Defaults to None. + validator (ScorerPromptValidator | None): Custom validator for the scorer. Defaults to None. Raises: ValueError: If no endpoint is provided. @@ -247,7 +247,7 @@ def _get_chunks(self, text: str) -> list[str]: return [text[i : i + self.MAX_TEXT_LENGTH] for i in range(0, len(text), self.MAX_TEXT_LENGTH)] - async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Optional[str] = None) -> list[Score]: + async def _score_piece_async(self, message_piece: MessagePiece, *, objective: str | None = None) -> list[Score]: """ Evaluate the input text or image using the Azure Content Filter API. @@ -257,7 +257,7 @@ async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Op In case of an image, the image size must be less than 2048 x 2048 pixels, but more than 50x50 pixels. The data size should not exceed 4 MB. Image must be of type JPEG, PNG, GIF, BMP, TIFF, or WEBP. - objective (Optional[str]): The objective for scoring context. 
Currently not supported for this scorer. + objective (str | None): The objective for scoring context. Currently not supported for this scorer. Defaults to None. Returns: @@ -343,7 +343,7 @@ async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Op for result in aggregated_results ] - def _build_fallback_score(self, *, message: Message, objective: Optional[str]) -> list[Score]: + def _build_fallback_score(self, *, message: Message, objective: str | None) -> list[Score]: """ Build one neutral ``0.0`` fallback score per configured harm category. @@ -359,7 +359,7 @@ def _build_fallback_score(self, *, message: Message, objective: Optional[str]) - Args: message (Message): The message whose first piece is inspected for status. - objective (Optional[str]): The objective associated with this scoring call. + objective (str | None): The objective associated with this scoring call. Returns: list[Score]: One ``0.0`` ``float_scale`` score per configured harm category, diff --git a/pyrit/score/float_scale/float_scale_scorer.py b/pyrit/score/float_scale/float_scale_scorer.py index c888c117f3..e0501b0c12 100644 --- a/pyrit/score/float_scale/float_scale_scorer.py +++ b/pyrit/score/float_scale/float_scale_scorer.py @@ -35,7 +35,7 @@ class FloatScaleScorer(Scorer): "blocked = True") should override ``_score_piece_async`` or ``_build_fallback_score``. """ - def __init__(self, *, validator: ScorerPromptValidator, chat_target: Optional[PromptTarget] = None) -> None: + def __init__(self, *, validator: ScorerPromptValidator, chat_target: PromptTarget | None = None) -> None: """ Initialize the FloatScaleScorer. 
@@ -46,7 +46,7 @@ def __init__(self, *, validator: ScorerPromptValidator, chat_target: Optional[Pr """ super().__init__(validator=validator, chat_target=chat_target) - def _build_fallback_score(self, *, message: Message, objective: Optional[str]) -> list[Score]: + def _build_fallback_score(self, *, message: Message, objective: str | None) -> list[Score]: """ Build a single-element list containing a neutral ``0.0`` score when no pieces could be scored. @@ -55,7 +55,7 @@ def _build_fallback_score(self, *, message: Message, objective: Optional[str]) - Args: message (Message): The message whose first piece is inspected for status. - objective (Optional[str]): The objective associated with this scoring call. + objective (str | None): The objective associated with this scoring call. Returns: list[Score]: A single-element list containing a ``0.0`` ``float_scale`` score @@ -138,15 +138,15 @@ async def _score_value_with_llm_async( message_value: str, message_data_type: PromptDataType, scored_prompt_id: str | UUID, - prepended_text_message_piece: Optional[str] = None, - category: Optional[str | UUID] = None, - objective: Optional[str] = None, + prepended_text_message_piece: str | None = None, + category: str | UUID | None = None, + objective: str | None = None, score_value_output_key: str = "score_value", rationale_output_key: str = "rationale", description_output_key: str = "description", metadata_output_key: str = "metadata", category_output_key: str = "category", - attack_identifier: Optional[ComponentIdentifier] = None, + attack_identifier: ComponentIdentifier | None = None, ) -> UnvalidatedScore: score: UnvalidatedScore | None = None try: diff --git a/pyrit/score/float_scale/insecure_code_scorer.py b/pyrit/score/float_scale/insecure_code_scorer.py index 91d795e9e4..85919178b0 100644 --- a/pyrit/score/float_scale/insecure_code_scorer.py +++ b/pyrit/score/float_scale/insecure_code_scorer.py @@ -2,7 +2,6 @@ # Licensed under the MIT license. 
from pathlib import Path -from typing import Optional, Union from pyrit.common import verify_and_resolve_path from pyrit.common.path import SCORER_SEED_PROMPT_PATH @@ -26,17 +25,17 @@ def __init__( self, *, chat_target: PromptTarget, - system_prompt_path: Optional[Union[str, Path]] = None, - validator: Optional[ScorerPromptValidator] = None, + system_prompt_path: str | Path | None = None, + validator: ScorerPromptValidator | None = None, ) -> None: """ Initialize the Insecure Code Scorer. Args: chat_target (PromptTarget): The target to use for scoring code security. - system_prompt_path (Optional[Union[str, Path]]): Path to the YAML file containing the system prompt. + system_prompt_path (str | Path | None): Path to the YAML file containing the system prompt. Defaults to the default insecure code scoring prompt if not provided. - validator (Optional[ScorerPromptValidator]): Custom validator for the scorer. Defaults to None. + validator (ScorerPromptValidator | None): Custom validator for the scorer. Defaults to None. """ super().__init__(validator=validator or self._DEFAULT_VALIDATOR, chat_target=chat_target) @@ -72,13 +71,13 @@ def _build_identifier(self) -> ComponentIdentifier: }, ) - async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Optional[str] = None) -> list[Score]: + async def _score_piece_async(self, message_piece: MessagePiece, *, objective: str | None = None) -> list[Score]: """ Scores the given message piece using LLM to detect security vulnerabilities. Args: message_piece (MessagePiece): The code snippet to be scored. - objective (Optional[str]): Optional objective description for scoring. Defaults to None. + objective (str | None): Optional objective description for scoring. Defaults to None. Returns: list[Score]: A list containing a single Score object. 
diff --git a/pyrit/score/float_scale/plagiarism_scorer.py b/pyrit/score/float_scale/plagiarism_scorer.py index c1286532ff..f163f55788 100644 --- a/pyrit/score/float_scale/plagiarism_scorer.py +++ b/pyrit/score/float_scale/plagiarism_scorer.py @@ -3,7 +3,6 @@ import re from enum import Enum -from typing import Optional import numpy as np @@ -47,7 +46,7 @@ def __init__( reference_text: str, metric: PlagiarismMetric = PlagiarismMetric.LCS, n: int = 5, - validator: Optional[ScorerPromptValidator] = None, + validator: ScorerPromptValidator | None = None, ) -> None: """ Initialize the PlagiarismScorer. @@ -56,7 +55,7 @@ def __init__( reference_text (str): The reference text to compare against. metric (PlagiarismMetric): The plagiarism detection metric to use. Defaults to PlagiarismMetric.LCS. n (int): The n-gram size for n-gram similarity. Defaults to 5. - validator (Optional[ScorerPromptValidator]): Custom validator for the scorer. Defaults to None. + validator (ScorerPromptValidator | None): Custom validator for the scorer. Defaults to None. """ super().__init__(validator=validator or self._DEFAULT_VALIDATOR) @@ -84,7 +83,7 @@ def _tokenize(self, text: str) -> list[str]: Tokenize text using whitespace-based tokenization (case-insensitive). Returns: - List[str]: List of lowercase tokens with punctuation removed. + list[str]: List of lowercase tokens with punctuation removed. """ text = text.lower() text = re.sub(r"[^\w\s]", "", text) @@ -173,13 +172,13 @@ def _plagiarism_score( raise ValueError("metric must be 'lcs', 'levenshtein', or 'jaccard'") - async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Optional[str] = None) -> list[Score]: + async def _score_piece_async(self, message_piece: MessagePiece, *, objective: str | None = None) -> list[Score]: """ Scores the AI response against the reference text using the specified metric. Args: message_piece (MessagePiece): The piece to score. - objective (Optional[str]): Not applicable for this scorer. 
+ objective (str | None): Not applicable for this scorer. Returns: list[Score]: A list containing the computed score. diff --git a/pyrit/score/float_scale/self_ask_general_float_scale_scorer.py b/pyrit/score/float_scale/self_ask_general_float_scale_scorer.py index fa1b56627d..17105defb9 100644 --- a/pyrit/score/float_scale/self_ask_general_float_scale_scorer.py +++ b/pyrit/score/float_scale/self_ask_general_float_scale_scorer.py @@ -3,7 +3,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from pyrit.prompt_target import CHAT_TARGET_REQUIREMENTS from pyrit.score.float_scale.float_scale_scorer import FloatScaleScorer @@ -31,11 +31,11 @@ def __init__( *, chat_target: PromptTarget, system_prompt_format_string: str, - prompt_format_string: Optional[str] = None, - category: Optional[str] = None, + prompt_format_string: str | None = None, + category: str | None = None, min_value: int = 0, max_value: int = 100, - validator: Optional[ScorerPromptValidator] = None, + validator: ScorerPromptValidator | None = None, score_value_output_key: str = "score_value", rationale_output_key: str = "rationale", description_output_key: str = "description", @@ -58,11 +58,11 @@ def __init__( possibly via normalization-pipeline adaptation). system_prompt_format_string (str): System prompt template with placeholders for objective, prompt, and message_piece. - prompt_format_string (Optional[str]): User prompt template with the same placeholders. - category (Optional[str]): Category for the score. + prompt_format_string (str | None): User prompt template with the same placeholders. + category (str | None): Category for the score. min_value (int): Minimum of the model's native scale. Defaults to 0. max_value (int): Maximum of the model's native scale. Defaults to 100. - validator (Optional[ScorerPromptValidator]): Custom validator. If omitted, a default + validator (ScorerPromptValidator | None): Custom validator. 
If omitted, a default validator will be used requiring text input and an objective. score_value_output_key (str): JSON key for the score value. Defaults to "score_value". rationale_output_key (str): JSON key for the rationale. Defaults to "rationale". @@ -112,7 +112,7 @@ def _build_identifier(self) -> ComponentIdentifier: }, ) - async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Optional[str] = None) -> list[Score]: + async def _score_piece_async(self, message_piece: MessagePiece, *, objective: str | None = None) -> list[Score]: """ Score a single message piece using the configured prompts and scale to [0, 1]. diff --git a/pyrit/score/float_scale/self_ask_likert_scorer.py b/pyrit/score/float_scale/self_ask_likert_scorer.py index 6134f6f7af..f5f2e97bcf 100644 --- a/pyrit/score/float_scale/self_ask_likert_scorer.py +++ b/pyrit/score/float_scale/self_ask_likert_scorer.py @@ -5,7 +5,6 @@ import logging from dataclasses import dataclass from pathlib import Path -from typing import Optional import yaml @@ -24,15 +23,15 @@ class LikertScaleEvalFiles: Configuration for evaluating a Likert scale scorer on a set of dataset files. Args: - human_labeled_datasets_files (List[str]): List of glob patterns to match CSV files. + human_labeled_datasets_files (list[str]): List of glob patterns to match CSV files. result_file (str): Name of the result file for storing evaluation results. - harm_category (Optional[str]): The harm category for harm scorers. Defaults to None. + harm_category (str | None): The harm category for harm scorers. Defaults to None. The harm definition path is derived as "{harm_category}.yaml". 
""" human_labeled_datasets_files: list[str] result_file: str - harm_category: Optional[str] = None + harm_category: str | None = None class LikertScalePaths(enum.Enum): @@ -158,7 +157,7 @@ def path(self) -> Path: return self.value[0] @property - def evaluation_files(self) -> Optional[LikertScaleEvalFiles]: + def evaluation_files(self) -> LikertScaleEvalFiles | None: """Get the evaluation file configuration, or None if no evaluation dataset exists.""" return self.value[1] @@ -178,24 +177,24 @@ def __init__( self, *, chat_target: PromptTarget, - likert_scale: Optional[LikertScalePaths] = None, - custom_likert_path: Optional[Path] = None, - custom_system_prompt_path: Optional[Path] = None, - validator: Optional[ScorerPromptValidator] = None, + likert_scale: LikertScalePaths | None = None, + custom_likert_path: Path | None = None, + custom_system_prompt_path: Path | None = None, + validator: ScorerPromptValidator | None = None, ) -> None: """ Initialize the SelfAskLikertScorer. Args: chat_target (PromptTarget): The chat target to use for scoring. - likert_scale (Optional[LikertScalePaths]): The Likert scale configuration to use for scoring. - custom_likert_path (Optional[Path]): Path to a custom YAML file containing the Likert scale definition. + likert_scale (LikertScalePaths | None): The Likert scale configuration to use for scoring. + custom_likert_path (Path | None): Path to a custom YAML file containing the Likert scale definition. This allows users to use their own Likert scales without modifying the code, as long as the YAML file follows the expected format. Only one of `likert_scale` or `custom_likert_path` should be provided. Defaults to None. - custom_system_prompt_path (Optional[Path]): Path to a custom system prompt file. This allows users to + custom_system_prompt_path (Path | None): Path to a custom system prompt file. This allows users to provide their own system prompt without modifying the code. Defaults to None. 
- validator (Optional[ScorerPromptValidator]): Custom validator for the scorer. Defaults to None. + validator (ScorerPromptValidator | None): Custom validator for the scorer. Defaults to None. Raises: ValueError: If both `likert_scale` and `custom_likert_path` are provided, if neither is provided, @@ -211,9 +210,7 @@ def __init__( if likert_scale is None and custom_likert_path is None: raise ValueError("One of 'likert_scale' or 'custom_likert_path' must be provided.") - self._scoring_instructions_template: Optional[SeedPrompt] = ( - None # Will be set in _set_likert_scale_system_prompt - ) + self._scoring_instructions_template: SeedPrompt | None = None # Will be set in _set_likert_scale_system_prompt if custom_system_prompt_path is not None: self._validate_custom_system_prompt_path(custom_system_prompt_path) self._scoring_instructions_template = SeedPrompt.from_yaml_file(custom_system_prompt_path) @@ -436,13 +433,13 @@ def _validate_custom_likert_path(custom_likert_path: Path) -> None: f"Custom Likert scale file must be a YAML file (.yaml or .yml), got '{custom_likert_path.suffix}'." ) - async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Optional[str] = None) -> list[Score]: + async def _score_piece_async(self, message_piece: MessagePiece, *, objective: str | None = None) -> list[Score]: """ Score the given message_piece using "self-ask" for the chat target. Args: message_piece (MessagePiece): The message piece containing the text to be scored. - objective (Optional[str]): The objective for scoring context. Currently not supported for this scorer. + objective (str | None): The objective for scoring context. Currently not supported for this scorer. Defaults to None. 
Returns: diff --git a/pyrit/score/float_scale/self_ask_scale_scorer.py b/pyrit/score/float_scale/self_ask_scale_scorer.py index 6cdc1e2921..92db37a06a 100644 --- a/pyrit/score/float_scale/self_ask_scale_scorer.py +++ b/pyrit/score/float_scale/self_ask_scale_scorer.py @@ -3,7 +3,7 @@ import enum from pathlib import Path -from typing import Any, Optional, Union +from typing import Any import yaml @@ -44,20 +44,20 @@ def __init__( self, *, chat_target: PromptTarget, - scale_arguments_path: Optional[Union[Path, str]] = None, - system_prompt_path: Optional[Union[Path, str]] = None, - validator: Optional[ScorerPromptValidator] = None, + scale_arguments_path: Path | str | None = None, + system_prompt_path: Path | str | None = None, + validator: ScorerPromptValidator | None = None, ) -> None: """ Initialize the SelfAskScaleScorer. Args: chat_target (PromptTarget): The chat target to use for scoring. - scale_arguments_path (Optional[Union[Path, str]]): Path to the YAML file containing scale definitions. + scale_arguments_path (Path | str | None): Path to the YAML file containing scale definitions. Defaults to TREE_OF_ATTACKS_SCALE if not provided. - system_prompt_path (Optional[Union[Path, str]]): Path to the YAML file containing the system prompt. + system_prompt_path (Path | str | None): Path to the YAML file containing the system prompt. Defaults to GENERAL_SYSTEM_PROMPT if not provided. - validator (Optional[ScorerPromptValidator]): Custom validator for the scorer. Defaults to None. + validator (ScorerPromptValidator | None): Custom validator for the scorer. Defaults to None. 
""" super().__init__(validator=validator or self._DEFAULT_VALIDATOR, chat_target=chat_target) @@ -101,7 +101,7 @@ def _build_identifier(self) -> ComponentIdentifier: }, ) - async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Optional[str] = None) -> list[Score]: + async def _score_piece_async(self, message_piece: MessagePiece, *, objective: str | None = None) -> list[Score]: """ Scores the given message_piece using "self-ask" for the chat target. diff --git a/pyrit/score/float_scale/video_float_scale_scorer.py b/pyrit/score/float_scale/video_float_scale_scorer.py index cb337aa506..8e32bd9064 100644 --- a/pyrit/score/float_scale/video_float_scale_scorer.py +++ b/pyrit/score/float_scale/video_float_scale_scorer.py @@ -1,7 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from pyrit.models import ComponentIdentifier, MessagePiece, Score from pyrit.score.float_scale.float_scale_score_aggregator import ( @@ -42,12 +42,12 @@ def __init__( self, *, image_capable_scorer: FloatScaleScorer, - audio_scorer: Optional[FloatScaleScorer] = None, - num_sampled_frames: Optional[int] = None, - validator: Optional[ScorerPromptValidator] = None, + audio_scorer: FloatScaleScorer | None = None, + num_sampled_frames: int | None = None, + validator: ScorerPromptValidator | None = None, score_aggregator: FloatScaleAggregatorFunc = FloatScaleScorerByCategory.MAX, - image_objective_template: Optional[str] = VideoHelper._DEFAULT_IMAGE_OBJECTIVE_TEMPLATE, - audio_objective_template: Optional[str] = None, + image_objective_template: str | None = VideoHelper._DEFAULT_IMAGE_OBJECTIVE_TEMPLATE, + audio_objective_template: str | None = None, ) -> None: """ Initialize the VideoFloatScaleScorer. 
@@ -116,7 +116,7 @@ def _build_identifier(self) -> ComponentIdentifier: }, ) - async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Optional[str] = None) -> list[Score]: + async def _score_piece_async(self, message_piece: MessagePiece, *, objective: str | None = None) -> list[Score]: """ Score a single video piece by extracting frames and optionally audio, then aggregating their scores. @@ -145,7 +145,7 @@ async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Op # Get the ID from the message piece piece_id = message_piece.id if message_piece.id is not None else message_piece.original_prompt_id - # Call the aggregator - all aggregators now return List[ScoreAggregatorResult] + # Call the aggregator - all aggregators now return list[ScoreAggregatorResult] aggregator_results: list[ScoreAggregatorResult] = self._score_aggregator(all_scores) # Build rationale prefix diff --git a/pyrit/score/score_aggregator_result.py b/pyrit/score/score_aggregator_result.py index de5b8dc212..d039efafac 100644 --- a/pyrit/score/score_aggregator_result.py +++ b/pyrit/score/score_aggregator_result.py @@ -2,7 +2,6 @@ # Licensed under the MIT license. from dataclasses import dataclass -from typing import Union @dataclass(frozen=True, slots=True) @@ -11,16 +10,16 @@ class ScoreAggregatorResult: Common result object returned by score aggregators. Attributes: - value (Union[bool, float]): The aggregated value. For true/false aggregators this is + value (bool | float): The aggregated value. For true/false aggregators this is a boolean. For float-scale aggregators, this is a float in the range [0, 1]. description (str): A short, human-friendly description of the aggregation outcome. rationale (str): Combined rationale from constituent scores. - category (List[str]): Combined list of categories from constituent scores. - metadata (Dict[str, Union[str, int, float]]): Combined metadata from constituent scores. 
+ category (list[str]): Combined list of categories from constituent scores. + metadata (dict[str, str | int | float]): Combined metadata from constituent scores. """ - value: Union[bool, float] + value: bool | float description: str rationale: str category: list[str] - metadata: dict[str, Union[str, int, float]] + metadata: dict[str, str | int | float] diff --git a/pyrit/score/score_utils.py b/pyrit/score/score_utils.py index 5ae68c3939..4429b34e67 100644 --- a/pyrit/score/score_utils.py +++ b/pyrit/score/score_utils.py @@ -1,7 +1,6 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -from typing import Optional, Union from pyrit.common.utils import combine_dict from pyrit.models import Score @@ -11,7 +10,7 @@ ORIGINAL_FLOAT_VALUE_KEY = "original_float_value" -def combine_metadata_and_categories(scores: list[Score]) -> tuple[dict[str, Union[str, int, float]], list[str]]: +def combine_metadata_and_categories(scores: list[Score]) -> tuple[dict[str, str | int | float], list[str]]: """ Combine metadata and categories from multiple scores with deduplication. @@ -21,7 +20,7 @@ def combine_metadata_and_categories(scores: list[Score]) -> tuple[dict[str, Unio Returns: Tuple of (metadata dict, sorted category list with empty strings filtered). """ - metadata: dict[str, Union[str, int, float]] = {} + metadata: dict[str, str | int | float] = {} category_set: set[str] = set() for s in scores: @@ -47,7 +46,7 @@ def format_score_for_rationale(score: Score) -> str: return f" - {class_type} {score.score_value}: {score.score_rationale or ''}" -def normalize_score_to_float(score: Optional[Score]) -> float: +def normalize_score_to_float(score: Score | None) -> float: """ Normalize any score to a float value between 0.0 and 1.0. 
diff --git a/pyrit/score/scorer.py b/pyrit/score/scorer.py index 091ecf6c81..f3cda9923b 100644 --- a/pyrit/score/scorer.py +++ b/pyrit/score/scorer.py @@ -13,8 +13,6 @@ TYPE_CHECKING, Any, ClassVar, - Optional, - Union, cast, ) @@ -66,7 +64,7 @@ class Scorer(Identifiable, abc.ABC): # Evaluation configuration - maps input dataset files to a result file. # Specifies glob patterns for datasets and a result file name. - evaluation_file_mapping: Optional[ScorerEvalDatasetFiles] = None + evaluation_file_mapping: ScorerEvalDatasetFiles | None = None #: Capability requirements placed on the scorer's chat target (if any). #: Subclasses that use a chat target should override this and pass the @@ -74,7 +72,7 @@ class Scorer(Identifiable, abc.ABC): #: validate it. TARGET_REQUIREMENTS: ClassVar[TargetRequirements] = TargetRequirements() - _identifier: Optional[ComponentIdentifier] = None + _identifier: ComponentIdentifier | None = None #: When True, blocked responses that contain partial content #: (in prompt_metadata["partial_content"]) will be scored using that content @@ -97,13 +95,13 @@ def __init_subclass__(cls, **kwargs: Any) -> None: enforce_keyword_only_init(cls, base_name="Scorer") - def __init__(self, *, validator: ScorerPromptValidator, chat_target: Optional[PromptTarget] = None) -> None: + def __init__(self, *, validator: ScorerPromptValidator, chat_target: PromptTarget | None = None) -> None: """ Initialize the Scorer. Args: validator (ScorerPromptValidator): Validator for message pieces and scorer configuration. - chat_target (Optional[PromptTarget]): Chat target used by the scorer, if any. When + chat_target (PromptTarget | None): Chat target used by the scorer, if any. When provided, it is validated against ``TARGET_REQUIREMENTS``. 
""" self._validator = validator @@ -165,8 +163,8 @@ def _memory(self) -> MemoryInterface: def _create_identifier( self, *, - params: Optional[dict[str, Any]] = None, - children: Optional[dict[str, Union[ComponentIdentifier, list[ComponentIdentifier]]]] = None, + params: dict[str, Any] | None = None, + children: dict[str, ComponentIdentifier | list[ComponentIdentifier]] | None = None, ) -> ComponentIdentifier: """ Construct the scorer identifier. @@ -178,10 +176,10 @@ def _create_identifier( to set the identifier with their specific parameters. Args: - params (Optional[Dict[str, Any]]): Additional behavioral parameters from + params (dict[str, Any] | None): Additional behavioral parameters from the subclass (e.g., system_prompt_template, score_aggregator). Merged into the base params. - children (Optional[Dict[str, Union[ComponentIdentifier, List[ComponentIdentifier]]]]): + children (dict[str, ComponentIdentifier | list[ComponentIdentifier]] | None): Named child component identifiers (e.g., prompt_target, sub_scorers). Returns: @@ -199,8 +197,8 @@ async def score_async( self, message: Message, *, - objective: Optional[str] = None, - role_filter: Optional[ChatMessageRole] = None, + objective: str | None = None, + role_filter: ChatMessageRole | None = None, skip_on_error_result: bool = False, infer_objective_from_request: bool = False, ) -> list[Score]: @@ -209,9 +207,9 @@ async def score_async( Args: message (Message): The message to be scored. - objective (Optional[str]): The task or objective based on which the message should be scored. + objective (str | None): The task or objective based on which the message should be scored. Defaults to None. - role_filter (Optional[ChatMessageRole]): Only score messages with this exact stored role. + role_filter (ChatMessageRole | None): Only score messages with this exact stored role. Use "assistant" to score only real assistant responses, or "simulated_assistant" to score only simulated responses. 
Defaults to None (no filtering). skip_on_error_result (bool): If True, skip scoring if the message contains an error. @@ -283,7 +281,7 @@ async def score_async( return scores - async def _score_async(self, message: Message, *, objective: Optional[str] = None) -> list[Score]: + async def _score_async(self, message: Message, *, objective: str | None = None) -> list[Score]: """ Score the given request response asynchronously. @@ -293,7 +291,7 @@ async def _score_async(self, message: Message, *, objective: Optional[str] = Non Args: message (Message): The message to score. - objective (Optional[str]): The objective to evaluate against. Defaults to None. + objective (str | None): The objective to evaluate against. Defaults to None. Returns: list[Score]: A list of Score objects. @@ -316,11 +314,11 @@ async def _score_async(self, message: Message, *, objective: Optional[str] = Non return [score for sublist in piece_score_lists for score in sublist] @abstractmethod - async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Optional[str] = None) -> list[Score]: + async def _score_piece_async(self, message_piece: MessagePiece, *, objective: str | None = None) -> list[Score]: raise NotImplementedError @staticmethod - def _create_text_piece_from_blocked(piece: MessagePiece) -> Optional[MessagePiece]: + def _create_text_piece_from_blocked(piece: MessagePiece) -> MessagePiece | None: """ Create a text-typed copy of a blocked MessagePiece using its partial content. @@ -399,7 +397,7 @@ def _get_supported_pieces(self, message: Message) -> list[MessagePiece]: ] @abstractmethod - def _build_fallback_score(self, *, message: Message, objective: Optional[str]) -> list[Score]: + def _build_fallback_score(self, *, message: Message, objective: str | None) -> list[Score]: """ Return neutral fallback ``Score`` objects when ``_score_async`` produced no scores. 
@@ -417,7 +415,7 @@ def _build_fallback_score(self, *, message: Message, objective: Optional[str]) - Args: message (Message): The (possibly substituted) message that was scored. - objective (Optional[str]): The objective associated with this scoring call. + objective (str | None): The objective associated with this scoring call. Returns: list[Score]: One or more fallback scores. Must not be empty. @@ -437,12 +435,12 @@ def validate_return_scores(self, scores: list[Score]) -> None: async def evaluate_async( self, - file_mapping: Optional[ScorerEvalDatasetFiles] = None, + file_mapping: ScorerEvalDatasetFiles | None = None, *, num_scorer_trials: int = 3, update_registry_behavior: RegistryUpdateBehavior | None = None, max_concurrency: int = 10, - ) -> Optional[ScorerMetrics]: + ) -> ScorerMetrics | None: """ Evaluate this scorer against human-labeled datasets. @@ -491,7 +489,7 @@ async def evaluate_async( ) @abstractmethod - def get_scorer_metrics(self) -> Optional[ScorerMetrics]: + def get_scorer_metrics(self) -> ScorerMetrics | None: """ Get evaluation metrics for this scorer from the configured evaluation result file. @@ -507,13 +505,13 @@ def get_scorer_metrics(self) -> Optional[ScorerMetrics]: """ raise NotImplementedError("Subclasses must implement get_scorer_metrics") - async def score_text_async(self, text: str, *, objective: Optional[str] = None) -> list[Score]: + async def score_text_async(self, text: str, *, objective: str | None = None) -> list[Score]: """ Scores the given text based on the task using the chat target. Args: text (str): The text to be scored. - objective (Optional[str]): The task based on which the text should be scored + objective (str | None): The task based on which the text should be scored Returns: list[Score]: A list of Score objects representing the results. 
@@ -530,13 +528,13 @@ async def score_text_async(self, text: str, *, objective: Optional[str] = None) request.message_pieces[0].not_in_memory = True return await self.score_async(request, objective=objective) - async def score_image_async(self, image_path: str, *, objective: Optional[str] = None) -> list[Score]: + async def score_image_async(self, image_path: str, *, objective: str | None = None) -> list[Score]: """ Score the given image using the chat target. Args: image_path (str): The path to the image file to be scored. - objective (Optional[str]): The objective based on which the image should be scored. Defaults to None. + objective (str | None): The objective based on which the image should be scored. Defaults to None. Returns: list[Score]: A list of Score objects representing the results. @@ -558,9 +556,9 @@ async def score_prompts_batch_async( self, *, messages: Sequence[Message], - objectives: Optional[Sequence[str]] = None, + objectives: Sequence[str] | None = None, batch_size: int = 10, - role_filter: Optional[ChatMessageRole] = None, + role_filter: ChatMessageRole | None = None, skip_on_error_result: bool = False, infer_objective_from_request: bool = False, ) -> list[Score]: @@ -572,7 +570,7 @@ async def score_prompts_batch_async( objectives (Sequence[str]): The objectives/tasks based on which the prompts should be scored. Must have the same length as messages. batch_size (int): The maximum batch size for processing prompts. Defaults to 10. - role_filter (Optional[ChatMessageRole]): If provided, only score pieces with this role. + role_filter (ChatMessageRole | None): If provided, only score pieces with this role. Defaults to None (no filtering). skip_on_error_result (bool): If True, skip scoring pieces that have errors. Defaults to False. 
infer_objective_from_request (bool): If True and objective is empty, attempt to infer @@ -610,14 +608,14 @@ async def score_prompts_batch_async( return [score for sublist in results for score in sublist] async def score_image_batch_async( - self, *, image_paths: Sequence[str], objectives: Optional[Sequence[str]] = None, batch_size: int = 10 + self, *, image_paths: Sequence[str], objectives: Sequence[str] | None = None, batch_size: int = 10 ) -> list[Score]: """ Score a batch of images asynchronously. Args: image_paths (Sequence[str]): Sequence of paths to image files to be scored. - objectives (Optional[Sequence[str]]): Optional sequence of objectives corresponding to each image. + objectives (Sequence[str] | None): Optional sequence of objectives corresponding to each image. If provided, must match the length of image_paths. Defaults to None. batch_size (int): Maximum number of images to score concurrently. Defaults to 10. @@ -670,15 +668,15 @@ async def _score_value_with_llm_async( message_value: str, message_data_type: PromptDataType, scored_prompt_id: str, - prepended_text_message_piece: Optional[str] = None, - category: Optional[Sequence[str] | str] = None, - objective: Optional[str] = None, + prepended_text_message_piece: str | None = None, + category: Sequence[str] | str | None = None, + objective: str | None = None, score_value_output_key: str = "score_value", rationale_output_key: str = "rationale", description_output_key: str = "description", metadata_output_key: str = "metadata", category_output_key: str = "category", - attack_identifier: Optional[ComponentIdentifier] = None, + attack_identifier: ComponentIdentifier | None = None, ) -> UnvalidatedScore: """ Send a request to a target, and take care of retries. @@ -694,13 +692,13 @@ async def _score_value_with_llm_async( message_data_type (PromptDataType): The type of the data being sent in the message (e.g., "text", "image_path", "audio_path"). scored_prompt_id (str): The ID of the scored prompt. 
- prepended_text_message_piece (Optional[str]): Text context to prepend before the main + prepended_text_message_piece (str | None): Text context to prepend before the main message_value. When provided, creates a multi-piece message with this text first, followed by the message_value. Useful for adding objective/context when scoring non-text content. Defaults to None. - category (Optional[Sequence[str] | str]): The category of the score. Can also be parsed from + category (Sequence[str] | str | None): The category of the score. Can also be parsed from the JSON response if not provided. Defaults to None. - objective (Optional[str]): A description of the objective that is associated with the score, + objective (str | None): A description of the objective that is associated with the score, used for contextualizing the result. Defaults to None. score_value_output_key (str): The key in the JSON response that contains the score value. Defaults to "score_value". @@ -712,7 +710,7 @@ async def _score_value_with_llm_async( Defaults to "metadata". category_output_key (str): The key in the JSON response that contains the category. Defaults to "category". - attack_identifier (Optional[ComponentIdentifier]): The attack identifier. + attack_identifier (ComponentIdentifier | None): The attack identifier. Defaults to None. 
Returns: @@ -786,7 +784,7 @@ async def _score_value_with_llm_async( # Validate and normalize category to a list of strings cat_val = category_response if category_response is not None else category - normalized_category: Optional[list[str]] + normalized_category: list[str] | None if cat_val is None: normalized_category = None elif isinstance(cat_val, str): @@ -801,7 +799,7 @@ async def _score_value_with_llm_async( # Normalize metadata to a dictionary with string keys and string/int/float values raw_md = parsed_response.get(metadata_output_key) - normalized_md: Optional[dict[str, Union[str, int, float]]] + normalized_md: dict[str, str | int | float] | None if raw_md is None: normalized_md = None elif isinstance(raw_md, dict): @@ -868,10 +866,10 @@ def _extract_objective_from_response(self, response: Message) -> str: async def score_response_async( *, response: Message, - objective_scorer: Optional[Scorer] = None, - auxiliary_scorers: Optional[list[Scorer]] = None, + objective_scorer: Scorer | None = None, + auxiliary_scorers: list[Scorer] | None = None, role_filter: ChatMessageRole = "assistant", - objective: Optional[str] = None, + objective: str | None = None, skip_on_error_result: bool = True, ) -> dict[str, list[Score]]: """ @@ -879,15 +877,15 @@ async def score_response_async( Args: response (Message): Response containing pieces to score. - objective_scorer (Optional[Scorer]): The main scorer to determine success. Defaults to None. - auxiliary_scorers (Optional[List[Scorer]]): List of auxiliary scorers to apply. Defaults to None. + objective_scorer (Scorer | None): The main scorer to determine success. Defaults to None. + auxiliary_scorers (list[Scorer] | None): List of auxiliary scorers to apply. Defaults to None. role_filter (ChatMessageRole): Only score pieces with this exact stored role. Defaults to "assistant" (real responses only, not simulated). - objective (Optional[str]): Task/objective for scoring context. Defaults to None. 
+ objective (str | None): Task/objective for scoring context. Defaults to None. skip_on_error_result (bool): If True, skip scoring pieces that have errors. Defaults to True. Returns: - Dict[str, List[Score]]: Dictionary with keys `auxiliary_scores` and `objective_scores` + dict[str, list[Score]]: Dictionary with keys `auxiliary_scores` and `objective_scores` containing lists of scores from each type of scorer. Raises: @@ -946,7 +944,7 @@ async def score_response_multiple_scorers_async( response: Message, scorers: list[Scorer], role_filter: ChatMessageRole = "assistant", - objective: Optional[str] = None, + objective: str | None = None, skip_on_error_result: bool = True, ) -> list[Score]: """ @@ -957,14 +955,14 @@ async def score_response_multiple_scorers_async( Args: response (Message): The response containing pieces to score. - scorers (List[Scorer]): List of scorers to apply. + scorers (list[Scorer]): List of scorers to apply. role_filter (ChatMessageRole): Only score pieces with this exact stored role. Defaults to "assistant" (real responses only, not simulated). - objective (Optional[str]): Optional objective description for scoring context. + objective (str | None): Optional objective description for scoring context. skip_on_error_result (bool): If True, skip scoring pieces that have errors (default: True). 
Returns: - List[Score]: All scores from all scorers + list[Score]: All scores from all scorers """ if not scorers: return [] diff --git a/pyrit/score/scorer_evaluation/human_labeled_dataset.py b/pyrit/score/scorer_evaluation/human_labeled_dataset.py index f0f0fdcd87..937573d840 100644 --- a/pyrit/score/scorer_evaluation/human_labeled_dataset.py +++ b/pyrit/score/scorer_evaluation/human_labeled_dataset.py @@ -5,7 +5,7 @@ import logging from dataclasses import dataclass from pathlib import Path -from typing import TYPE_CHECKING, Any, Optional, Union, cast +from typing import TYPE_CHECKING, Any, Optional, cast import pandas as pd @@ -35,7 +35,7 @@ class HumanLabeledEntry: (representing degree of severity) for harm datasets, and booleans for objective datasets. Parameters: - conversation (List[Message]): A list of Message objects representing the + conversation (list[Message]): A list of Message objects representing the conversation to be scored. This can contain one Message object if you are just scoring individual assistant responses. human_scores (List): A list of human-assigned scores for the responses. Each entry in the list corresponds to @@ -126,8 +126,8 @@ def __init__( entries: list[HumanLabeledEntry], metrics_type: MetricsType, version: str, - harm_definition: Optional[str] = None, - harm_definition_version: Optional[str] = None, + harm_definition: str | None = None, + harm_definition_version: str | None = None, ) -> None: """ Initialize the HumanLabeledDataset. @@ -136,7 +136,7 @@ def __init__( name (str): The name of the human-labeled dataset. For datasets of uniform type, this is often the harm category (e.g. hate_speech) or objective. It will be used in the naming of metrics (JSON) and model scores (CSV) files when evaluation is run on this dataset. - entries (List[HumanLabeledEntry]): A list of entries in the dataset. + entries (list[HumanLabeledEntry]): A list of entries in the dataset. 
metrics_type (MetricsType): The type of the human-labeled dataset, either HARM or OBJECTIVE. version (str): The version of the human-labeled dataset. @@ -156,7 +156,7 @@ def __init__( self.version = version self.harm_definition = harm_definition self.harm_definition_version = harm_definition_version - self._harm_definition_obj: Optional[HarmDefinition] = None + self._harm_definition_obj: HarmDefinition | None = None def get_harm_definition(self) -> Optional["HarmDefinition"]: """ @@ -188,12 +188,12 @@ def get_harm_definition(self) -> Optional["HarmDefinition"]: def from_csv( cls, *, - csv_path: Union[str, Path], + csv_path: str | Path, metrics_type: MetricsType, - dataset_name: Optional[str] = None, - version: Optional[str] = None, - harm_definition: Optional[str] = None, - harm_definition_version: Optional[str] = None, + dataset_name: str | None = None, + version: str | None = None, + harm_definition: str | None = None, + harm_definition_version: str | None = None, ) -> "HumanLabeledDataset": """ Load a human-labeled dataset from a CSV file with standard column names. @@ -210,7 +210,7 @@ def from_csv( - For objective datasets: # dataset_version=x.y Args: - csv_path (Union[str, Path]): The path to the CSV file. + csv_path (str | Path): The path to the CSV file. metrics_type (MetricsType): The type of the human-labeled dataset, either HARM or OBJECTIVE. dataset_name (str, Optional): The name of the dataset. If not provided, it will be inferred from the CSV file name. 
diff --git a/pyrit/score/scorer_evaluation/scorer_evaluator.py b/pyrit/score/scorer_evaluation/scorer_evaluator.py index 5f203753fc..0e9dc4267b 100644 --- a/pyrit/score/scorer_evaluation/scorer_evaluator.py +++ b/pyrit/score/scorer_evaluation/scorer_evaluator.py @@ -7,7 +7,7 @@ import logging import time from dataclasses import dataclass -from typing import TYPE_CHECKING, Optional, cast +from typing import TYPE_CHECKING, cast import numpy as np from scipy.stats import ttest_1samp @@ -60,17 +60,17 @@ class ScorerEvalDatasetFiles: Multiple files matching the patterns will be concatenated before evaluation. Args: - human_labeled_datasets_files (List[str]): List of glob patterns to match CSV files. + human_labeled_datasets_files (list[str]): List of glob patterns to match CSV files. Examples: ``["objective/*.csv"]``, ``["objective/hate_speech.csv", "objective/violence.csv"]`` result_file (str): Name of the result file (stem used as dict key in results). Example: ``"objective_achieved_metrics.jsonl"`` - harm_category (Optional[str]): The harm category for harm scorers (e.g., "hate_speech", "violence"). + harm_category (str | None): The harm category for harm scorers (e.g., "hate_speech", "violence"). Required for harm evaluations, ignored for objective evaluations. Defaults to None. """ human_labeled_datasets_files: list[str] result_file: str - harm_category: Optional[str] = None + harm_category: str | None = None class ScorerEvaluator(abc.ABC): @@ -92,7 +92,7 @@ def __init__(self, scorer: Scorer) -> None: self.scorer = scorer @classmethod - def from_scorer(cls, scorer: Scorer, metrics_type: Optional[MetricsType] = None) -> ScorerEvaluator: + def from_scorer(cls, scorer: Scorer, metrics_type: MetricsType | None = None) -> ScorerEvaluator: """ Create a ScorerEvaluator based on the type of scoring. 
@@ -120,7 +120,7 @@ async def run_evaluation_async( num_scorer_trials: int = 3, update_registry_behavior: RegistryUpdateBehavior = RegistryUpdateBehavior.SKIP_IF_EXISTS, max_concurrency: int = 10, - ) -> Optional[ScorerMetrics]: + ) -> ScorerMetrics | None: """ Evaluate scorer using dataset files configuration. @@ -265,11 +265,11 @@ def _should_skip_evaluation( self, *, dataset_version: str, - harm_definition_version: Optional[str] = None, + harm_definition_version: str | None = None, num_scorer_trials: int, - harm_category: Optional[str] = None, + harm_category: str | None = None, result_file_path: Path, - ) -> tuple[bool, Optional[ScorerMetrics]]: + ) -> tuple[bool, ScorerMetrics | None]: """ Determine whether to skip evaluation based on existing registry entries. @@ -282,13 +282,13 @@ def _should_skip_evaluation( Args: dataset_version (str): The version of the dataset. - harm_definition_version (Optional[str]): Version of the harm definition YAML. For harm evaluations. + harm_definition_version (str | None): Version of the harm definition YAML. For harm evaluations. num_scorer_trials (int): Number of scorer trials requested. - harm_category (Optional[str]): The harm category for harm scorers. Required for harm evaluations. + harm_category (str | None): The harm category for harm scorers. Required for harm evaluations. result_file_path (Path): Path to the result file to search. 
Returns: - Tuple[bool, Optional[ScorerMetrics]]: (should_skip, existing_metrics) + tuple[bool, ScorerMetrics | None]: (should_skip, existing_metrics) - (True, metrics) if should skip and use existing metrics - (False, None) if should run evaluation """ @@ -302,7 +302,7 @@ def _should_skip_evaluation( # Determine if this is a harm or objective evaluation metrics_type = MetricsType.OBJECTIVE if isinstance(self.scorer, TrueFalseScorer) else MetricsType.HARM - existing: Optional[ScorerMetrics] = None + existing: ScorerMetrics | None = None if metrics_type == MetricsType.HARM: if harm_category is None: logger.warning("harm_category must be provided for harm scorer evaluations") @@ -449,7 +449,7 @@ async def evaluate_dataset_async( def _validate_and_extract_data( self, labeled_dataset: HumanLabeledDataset, - ) -> tuple[list[Message], list[list[float]], Optional[list[str]]]: + ) -> tuple[list[Message], list[list[float]], list[str] | None]: """ Validate the dataset and extract data for evaluation. @@ -471,11 +471,11 @@ def _compute_metrics( all_human_scores: np.ndarray, all_model_scores: np.ndarray, num_scorer_trials: int, - dataset_name: Optional[str] = None, - dataset_version: Optional[str] = None, - harm_category: Optional[str] = None, - harm_definition: Optional[str] = None, - harm_definition_version: Optional[str] = None, + dataset_name: str | None = None, + dataset_version: str | None = None, + harm_category: str | None = None, + harm_definition: str | None = None, + harm_definition_version: str | None = None, ) -> ScorerMetrics: """ Compute evaluation metrics from human and model scores. @@ -532,7 +532,7 @@ class HarmScorerEvaluator(ScorerEvaluator): def _validate_and_extract_data( self, labeled_dataset: HumanLabeledDataset, - ) -> tuple[list[Message], list[list[float]], Optional[list[str]]]: + ) -> tuple[list[Message], list[list[float]], list[str] | None]: """ Validate harm dataset and extract evaluation data. 
@@ -569,11 +569,11 @@ def _compute_metrics( all_human_scores: np.ndarray, all_model_scores: np.ndarray, num_scorer_trials: int, - dataset_name: Optional[str] = None, - dataset_version: Optional[str] = None, - harm_category: Optional[str] = None, - harm_definition: Optional[str] = None, - harm_definition_version: Optional[str] = None, + dataset_name: str | None = None, + dataset_version: str | None = None, + harm_category: str | None = None, + harm_definition: str | None = None, + harm_definition_version: str | None = None, ) -> HarmScorerMetrics: reliability_data = np.concatenate((all_human_scores, all_model_scores)) # Calculate the median of human scores for each response, which is considered the gold label @@ -647,7 +647,7 @@ class ObjectiveScorerEvaluator(ScorerEvaluator): def _validate_and_extract_data( self, labeled_dataset: HumanLabeledDataset, - ) -> tuple[list[Message], list[list[float]], Optional[list[str]]]: + ) -> tuple[list[Message], list[list[float]], list[str] | None]: """ Validate objective dataset and extract evaluation data. @@ -685,11 +685,11 @@ def _compute_metrics( all_human_scores: np.ndarray, all_model_scores: np.ndarray, num_scorer_trials: int, - dataset_name: Optional[str] = None, - dataset_version: Optional[str] = None, - harm_category: Optional[str] = None, - harm_definition: Optional[str] = None, - harm_definition_version: Optional[str] = None, + dataset_name: str | None = None, + dataset_version: str | None = None, + harm_category: str | None = None, + harm_definition: str | None = None, + harm_definition_version: str | None = None, ) -> ObjectiveScorerMetrics: # Calculate the majority vote of human scores for each response, which is considered the gold label. # If the vote is split, the resulting gold score will be 0 (i.e. False). Same logic is applied to model trials. 
diff --git a/pyrit/score/scorer_evaluation/scorer_metrics.py b/pyrit/score/scorer_evaluation/scorer_metrics.py index fab1f9a505..c3bc0ae03c 100644 --- a/pyrit/score/scorer_evaluation/scorer_metrics.py +++ b/pyrit/score/scorer_evaluation/scorer_metrics.py @@ -5,7 +5,7 @@ import json from dataclasses import asdict, dataclass, field -from typing import TYPE_CHECKING, Generic, Optional, TypeVar, Union +from typing import TYPE_CHECKING, Generic, TypeVar from pyrit.common.utils import verify_and_resolve_path @@ -43,9 +43,9 @@ class ScorerMetrics: num_responses: int num_human_raters: int num_scorer_trials: int = field(default=1, kw_only=True) - dataset_name: Optional[str] = field(default=None, kw_only=True) - dataset_version: Optional[str] = field(default=None, kw_only=True) - trial_scores: Optional[np.ndarray] = field(default=None, kw_only=True) + dataset_name: str | None = field(default=None, kw_only=True) + dataset_version: str | None = field(default=None, kw_only=True) + trial_scores: np.ndarray | None = field(default=None, kw_only=True) average_score_time_seconds: float = field(default=0.0, kw_only=True) def to_json(self) -> str: @@ -63,7 +63,7 @@ def to_json(self) -> str: return json.dumps(asdict(self)) @classmethod - def from_json_file(cls: type[T], file_path: Union[str, Path]) -> T: + def from_json_file(cls: type[T], file_path: str | Path) -> T: """ Load a metrics instance from a JSON file on disk. @@ -74,7 +74,7 @@ def from_json_file(cls: type[T], file_path: Union[str, Path]) -> T: fields (e.g., cached ``init=False`` attributes) before constructing the instance. Args: - file_path (Union[str, Path]): The path to the JSON file. + file_path (str | Path): The path to the JSON file. Returns: ScorerMetrics: An instance of ScorerMetrics (or subclass) with the loaded data. 
@@ -96,7 +96,7 @@ def from_json_file(cls: type[T], file_path: Union[str, Path]) -> T: return cls(**filtered_data) @classmethod - def from_json(cls: type[T], file_path: Union[str, Path]) -> T: + def from_json(cls: type[T], file_path: str | Path) -> T: """ Load a metrics instance from a JSON file (deprecated alias for ``from_json_file``). @@ -104,7 +104,7 @@ def from_json(cls: type[T], file_path: Union[str, Path]) -> T: string. Use ``from_json_file`` instead. Args: - file_path (Union[str, Path]): The path to the JSON file. + file_path (str | Path): The path to the JSON file. Returns: ScorerMetrics: An instance of ScorerMetrics (or subclass) with the loaded data. @@ -157,14 +157,14 @@ class HarmScorerMetrics(ScorerMetrics): t_statistic: float p_value: float krippendorff_alpha_combined: float - harm_category: Optional[str] = field(default=None, kw_only=True) - harm_definition: Optional[str] = field(default=None, kw_only=True) - harm_definition_version: Optional[str] = field(default=None, kw_only=True) - krippendorff_alpha_humans: Optional[float] = None - krippendorff_alpha_model: Optional[float] = None - _harm_definition_obj: Optional[HarmDefinition] = field(default=None, init=False, repr=False) - - def get_harm_definition(self) -> Optional[HarmDefinition]: + harm_category: str | None = field(default=None, kw_only=True) + harm_definition: str | None = field(default=None, kw_only=True) + harm_definition_version: str | None = field(default=None, kw_only=True) + krippendorff_alpha_humans: float | None = None + krippendorff_alpha_model: float | None = None + _harm_definition_obj: HarmDefinition | None = field(default=None, init=False, repr=False) + + def get_harm_definition(self) -> HarmDefinition | None: """ Load and return the HarmDefinition object for this metrics instance. @@ -205,7 +205,7 @@ class ObjectiveScorerMetrics(ScorerMetrics): in its positive predictions. 
recall (float): The recall of the model scores, an indicator of the model's ability to correctly identify positive labels. - trial_scores (Optional[np.ndarray]): The raw scores from each trial. Shape is (num_trials, num_responses). + trial_scores (np.ndarray | None): The raw scores from each trial. Shape is (num_trials, num_responses). Useful for debugging and analyzing scorer variance. """ diff --git a/pyrit/score/scorer_evaluation/scorer_metrics_io.py b/pyrit/score/scorer_evaluation/scorer_metrics_io.py index d915dc24ab..b598bab40a 100644 --- a/pyrit/score/scorer_evaluation/scorer_metrics_io.py +++ b/pyrit/score/scorer_evaluation/scorer_metrics_io.py @@ -11,7 +11,7 @@ import threading from dataclasses import asdict from pathlib import Path -from typing import Any, Optional, TypeVar +from typing import Any, TypeVar from pyrit.common.path import ( SCORER_EVALS_PATH, @@ -53,7 +53,7 @@ def _metrics_to_registry_dict(metrics: ScorerMetrics) -> dict[str, Any]: def get_all_objective_metrics( - file_path: Optional[Path] = None, + file_path: Path | None = None, ) -> list[ScorerMetricsWithIdentity[ObjectiveScorerMetrics]]: """ Load all objective scorer metrics with full scorer identity for comparison. @@ -63,12 +63,12 @@ def get_all_objective_metrics( access like `entry.metrics.accuracy` or `entry.metrics.f1_score`. Args: - file_path (Optional[Path]): Path to a specific JSONL file to load. + file_path (Path | None): Path to a specific JSONL file to load. If not provided, uses the default path: SCORER_EVALS_PATH / "objective" / "objective_achieved_metrics.jsonl" Returns: - List[ScorerMetricsWithIdentity[ObjectiveScorerMetrics]]: List of metrics with scorer identity. + list[ScorerMetricsWithIdentity[ObjectiveScorerMetrics]]: List of metrics with scorer identity. Access metrics via `entry.metrics.accuracy`, `entry.metrics.f1_score`, etc. Access scorer info via `entry.scorer_identifier.class_name`, etc. 
""" @@ -92,7 +92,7 @@ def get_all_harm_metrics( harm_category (str): The harm category to load metrics for (e.g., "hate_speech", "violence"). Returns: - List[ScorerMetricsWithIdentity[HarmScorerMetrics]]: List of metrics with scorer identity. + list[ScorerMetricsWithIdentity[HarmScorerMetrics]]: List of metrics with scorer identity. Access metrics via `entry.metrics.mean_absolute_error`, `entry.metrics.harm_category`, etc. Access scorer info via `entry.scorer_identifier.class_name`, etc. """ @@ -112,10 +112,10 @@ def _load_metrics_from_file( Args: file_path (Path): Path to the JSONL file to load. - metrics_class (Type[M]): The metrics class to instantiate (ObjectiveScorerMetrics or HarmScorerMetrics). + metrics_class (type[M]): The metrics class to instantiate (ObjectiveScorerMetrics or HarmScorerMetrics). Returns: - List[ScorerMetricsWithIdentity[M]]: List of metrics with scorer identity. + list[ScorerMetricsWithIdentity[M]]: List of metrics with scorer identity. """ results: list[ScorerMetricsWithIdentity[M]] = [] entries = _load_jsonl(file_path) @@ -151,14 +151,14 @@ def _load_metrics_from_file( def find_objective_metrics_by_eval_hash( *, eval_hash: str, - file_path: Optional[Path] = None, -) -> Optional[ObjectiveScorerMetrics]: + file_path: Path | None = None, +) -> ObjectiveScorerMetrics | None: """ Find objective scorer metrics by evaluation hash. Args: eval_hash (str): The scorer evaluation hash to search for. - file_path (Optional[Path]): Path to the JSONL file to search. + file_path (Path | None): Path to the JSONL file to search. If not provided, uses the default path: SCORER_EVALS_PATH / "objective" / "objective_achieved_metrics.jsonl" @@ -175,7 +175,7 @@ def find_harm_metrics_by_eval_hash( *, eval_hash: str, harm_category: str, -) -> Optional[HarmScorerMetrics]: +) -> HarmScorerMetrics | None: """ Find harm scorer metrics by evaluation hash. 
@@ -195,7 +195,7 @@ def _find_metrics_by_eval_hash( file_path: Path, eval_hash: str, metrics_class: type[M], -) -> Optional[M]: +) -> M | None: """ Find scorer metrics by evaluation hash in a specific file. @@ -205,7 +205,7 @@ def _find_metrics_by_eval_hash( Args: file_path (Path): Path to the JSONL file to search. eval_hash (str): The scorer evaluation hash to search for. - metrics_class (Type[M]): The metrics class to instantiate. + metrics_class (type[M]): The metrics class to instantiate. Returns: The metrics instance if found, else None. diff --git a/pyrit/score/scorer_prompt_validator.py b/pyrit/score/scorer_prompt_validator.py index f89c93d54d..1e6946e3a0 100644 --- a/pyrit/score/scorer_prompt_validator.py +++ b/pyrit/score/scorer_prompt_validator.py @@ -2,7 +2,7 @@ # Licensed under the MIT license. from collections.abc import Sequence -from typing import Optional, get_args +from typing import get_args from pyrit.models import ChatMessageRole, Message, MessagePiece, PromptDataType @@ -18,32 +18,32 @@ class ScorerPromptValidator: def __init__( self, *, - supported_data_types: Optional[Sequence[PromptDataType]] = None, - required_metadata: Optional[Sequence[str]] = None, - supported_roles: Optional[Sequence[ChatMessageRole]] = None, - max_pieces_in_response: Optional[int] = None, - max_text_length: Optional[int] = None, - enforce_all_pieces_valid: Optional[bool] = False, - raise_on_no_valid_pieces: Optional[bool] = False, + supported_data_types: Sequence[PromptDataType] | None = None, + required_metadata: Sequence[str] | None = None, + supported_roles: Sequence[ChatMessageRole] | None = None, + max_pieces_in_response: int | None = None, + max_text_length: int | None = None, + enforce_all_pieces_valid: bool | None = False, + raise_on_no_valid_pieces: bool | None = False, is_objective_required: bool = False, ) -> None: """ Initialize the ScorerPromptValidator. Args: - supported_data_types (Optional[Sequence[PromptDataType]]): Data types that the scorer supports. 
+ supported_data_types (Sequence[PromptDataType] | None): Data types that the scorer supports. Defaults to all data types if not provided. - required_metadata (Optional[Sequence[str]]): Metadata keys that must be present in message pieces. + required_metadata (Sequence[str] | None): Metadata keys that must be present in message pieces. Defaults to empty list. - supported_roles (Optional[Sequence[ChatMessageRole]]): Message roles that the scorer supports. + supported_roles (Sequence[ChatMessageRole] | None): Message roles that the scorer supports. Defaults to all roles if not provided. - max_pieces_in_response (Optional[int]): Maximum number of pieces allowed in a response. + max_pieces_in_response (int | None): Maximum number of pieces allowed in a response. Defaults to None (no limit). - max_text_length (Optional[int]): Maximum character length for text data type pieces. + max_text_length (int | None): Maximum character length for text data type pieces. Defaults to None (no limit). - enforce_all_pieces_valid (Optional[bool]): Whether all pieces must be valid or just at least one. + enforce_all_pieces_valid (bool | None): Whether all pieces must be valid or just at least one. Defaults to False. - raise_on_no_valid_pieces (Optional[bool]): Whether to raise ValueError when no pieces are valid. + raise_on_no_valid_pieces (bool | None): Whether to raise ValueError when no pieces are valid. Defaults to False, allowing scorers to handle empty results gracefully (e.g., returning False for blocked responses). Set to True to raise an exception instead. is_objective_required (bool): Whether an objective must be provided for scoring. Defaults to False. diff --git a/pyrit/score/true_false/audio_true_false_scorer.py b/pyrit/score/true_false/audio_true_false_scorer.py index c10befbf44..58397a3a29 100644 --- a/pyrit/score/true_false/audio_true_false_scorer.py +++ b/pyrit/score/true_false/audio_true_false_scorer.py @@ -1,7 +1,6 @@ # Copyright (c) Microsoft Corporation. 
# Licensed under the MIT license. -from typing import Optional from pyrit.models import ComponentIdentifier, MessagePiece, Score from pyrit.score.audio_transcript_scorer import AudioTranscriptHelper @@ -23,8 +22,8 @@ def __init__( self, *, text_capable_scorer: TrueFalseScorer, - validator: Optional[ScorerPromptValidator] = None, - use_entra_auth: Optional[bool] = None, + validator: ScorerPromptValidator | None = None, + use_entra_auth: bool | None = None, ) -> None: """ Initialize the AudioTrueFalseScorer. @@ -62,7 +61,7 @@ def _build_identifier(self) -> ComponentIdentifier: }, ) - async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Optional[str] = None) -> list[Score]: + async def _score_piece_async(self, message_piece: MessagePiece, *, objective: str | None = None) -> list[Score]: """ Score an audio file by transcribing it and scoring the transcript. diff --git a/pyrit/score/true_false/decoding_scorer.py b/pyrit/score/true_false/decoding_scorer.py index f9cecc5f07..ec17af03fc 100644 --- a/pyrit/score/true_false/decoding_scorer.py +++ b/pyrit/score/true_false/decoding_scorer.py @@ -1,7 +1,6 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -from typing import Optional from pyrit.analytics.text_matching import ExactTextMatching, TextMatching from pyrit.memory.central_memory import CentralMemory @@ -30,21 +29,21 @@ class DecodingScorer(TrueFalseScorer): def __init__( self, *, - text_matcher: Optional[TextMatching] = None, - categories: Optional[list[str]] = None, + text_matcher: TextMatching | None = None, + categories: list[str] | None = None, aggregator: TrueFalseAggregatorFunc = TrueFalseScoreAggregator.OR, - validator: Optional[ScorerPromptValidator] = None, + validator: ScorerPromptValidator | None = None, ) -> None: """ Initialize the DecodingScorer. Args: - text_matcher (Optional[TextMatching]): The text matching strategy to use. + text_matcher (TextMatching | None): The text matching strategy to use. 
Defaults to ExactTextMatching with case_sensitive=False. - categories (Optional[list[str]]): Optional list of categories for the score. Defaults to None. + categories (list[str] | None): Optional list of categories for the score. Defaults to None. aggregator (TrueFalseAggregatorFunc): The aggregator function to use. Defaults to TrueFalseScoreAggregator.OR. - validator (Optional[ScorerPromptValidator]): Custom validator. Defaults to None. + validator (ScorerPromptValidator | None): Custom validator. Defaults to None. """ self._text_matcher = text_matcher if text_matcher else ExactTextMatching(case_sensitive=False) self._score_categories = categories if categories else [] @@ -65,13 +64,13 @@ def _build_identifier(self) -> ComponentIdentifier: }, ) - async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Optional[str] = None) -> list[Score]: + async def _score_piece_async(self, message_piece: MessagePiece, *, objective: str | None = None) -> list[Score]: """ Score the given request piece based on text matching strategy. Args: message_piece (MessagePiece): The message piece to score. - objective (Optional[str]): The objective to evaluate against. Defaults to None. + objective (str | None): The objective to evaluate against. Defaults to None. Currently not used for this scorer. Returns: diff --git a/pyrit/score/true_false/float_scale_threshold_scorer.py b/pyrit/score/true_false/float_scale_threshold_scorer.py index b89d5439fd..828b98a9dd 100644 --- a/pyrit/score/true_false/float_scale_threshold_scorer.py +++ b/pyrit/score/true_false/float_scale_threshold_scorer.py @@ -80,7 +80,7 @@ def get_chat_target(self) -> Optional["PromptTarget"]: Delegate to the wrapped scorer. Returns: - Optional[PromptTarget]: The chat target from the wrapped scorer. + PromptTarget | None: The chat target from the wrapped scorer. 
""" return self._scorer.get_chat_target() @@ -88,17 +88,17 @@ async def _score_async( self, message: Message, *, - objective: Optional[str] = None, - role_filter: Optional[ChatMessageRole] = None, + objective: str | None = None, + role_filter: ChatMessageRole | None = None, ) -> list[Score]: """ Scores the piece using the underlying float-scale scorer and thresholds the resulting score. Args: message (Message): The message to score. - objective (Optional[str]): The objective to evaluate against (the original attacker model's objective). + objective (str | None): The objective to evaluate against (the original attacker model's objective). Defaults to None. - role_filter (Optional[ChatMessageRole]): Optional filter for message roles. Defaults to None. + role_filter (ChatMessageRole | None): Optional filter for message roles. Defaults to None. Returns: list[Score]: A list containing a single true/false Score object based on the threshold comparison. @@ -173,13 +173,13 @@ async def _score_async( return [score] - async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Optional[str] = None) -> list[Score]: + async def _score_piece_async(self, message_piece: MessagePiece, *, objective: str | None = None) -> list[Score]: """ Float Scale scorers do not support piecewise scoring. Args: message_piece (MessagePiece): Unused. - objective (Optional[str]): Unused. + objective (str | None): Unused. Raises: NotImplementedError: Always, since composite scoring operates at the response level. 
diff --git a/pyrit/score/true_false/gandalf_scorer.py b/pyrit/score/true_false/gandalf_scorer.py index c37469ff62..5fc51fbc25 100644 --- a/pyrit/score/true_false/gandalf_scorer.py +++ b/pyrit/score/true_false/gandalf_scorer.py @@ -3,7 +3,6 @@ import uuid from textwrap import dedent -from typing import Optional import httpx from openai import BadRequestError @@ -37,7 +36,7 @@ def __init__( *, level: GandalfLevel, chat_target: PromptTarget, - validator: Optional[ScorerPromptValidator] = None, + validator: ScorerPromptValidator | None = None, score_aggregator: TrueFalseAggregatorFunc = TrueFalseScoreAggregator.OR, ) -> None: """ @@ -46,7 +45,7 @@ def __init__( Args: level (GandalfLevel): The Gandalf challenge level to score against. chat_target (PromptTarget): The chat target used for password extraction. - validator (Optional[ScorerPromptValidator]): Custom validator. Defaults to text data type validator. + validator (ScorerPromptValidator | None): Custom validator. Defaults to text data type validator. score_aggregator (TrueFalseAggregatorFunc): Aggregator for combining scores. Defaults to TrueFalseScoreAggregator.OR. """ @@ -143,13 +142,13 @@ async def _check_for_password_in_conversation_async(self, conversation_id: str) return "" return response_text - async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Optional[str] = None) -> list[Score]: + async def _score_piece_async(self, message_piece: MessagePiece, *, objective: str | None = None) -> list[Score]: """ Scores the text based on the password found in the text. Args: message_piece (MessagePiece): The message piece containing the text to be scored. - objective (Optional[str]): The objective to evaluate against (the original attacker model's objective). + objective (str | None): The objective to evaluate against (the original attacker model's objective). Currently not used for this scorer. 
Returns: diff --git a/pyrit/score/true_false/markdown_injection.py b/pyrit/score/true_false/markdown_injection.py index 517124120c..33a678469f 100644 --- a/pyrit/score/true_false/markdown_injection.py +++ b/pyrit/score/true_false/markdown_injection.py @@ -2,7 +2,6 @@ # Licensed under the MIT license. import re -from typing import Optional from pyrit.models import ComponentIdentifier, MessagePiece, Score from pyrit.score.scorer_prompt_validator import ScorerPromptValidator @@ -27,14 +26,14 @@ class MarkdownInjectionScorer(TrueFalseScorer): def __init__( self, *, - validator: Optional[ScorerPromptValidator] = None, + validator: ScorerPromptValidator | None = None, score_aggregator: TrueFalseAggregatorFunc = TrueFalseScoreAggregator.OR, ) -> None: """ Initialize the MarkdownInjectionScorer. Args: - validator (Optional[ScorerPromptValidator]): Custom validator. Defaults to None. + validator (ScorerPromptValidator | None): Custom validator. Defaults to None. score_aggregator (TrueFalseAggregatorFunc): The aggregator function to use. Defaults to TrueFalseScoreAggregator.OR. """ @@ -55,14 +54,14 @@ def _build_identifier(self) -> ComponentIdentifier: }, ) - async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Optional[str] = None) -> list[Score]: + async def _score_piece_async(self, message_piece: MessagePiece, *, objective: str | None = None) -> list[Score]: """ Check for markdown injection in the text. It returns True if markdown injection is detected, else False. Args: message_piece (MessagePiece): The MessagePiece object containing the text to check for markdown injection. - objective (Optional[str]): The objective to evaluate against. Defaults to None. + objective (str | None): The objective to evaluate against. Defaults to None. Currently not used for this scorer. 
Returns: diff --git a/pyrit/score/true_false/prompt_shield_scorer.py b/pyrit/score/true_false/prompt_shield_scorer.py index 8f048300cc..a320e89fa9 100644 --- a/pyrit/score/true_false/prompt_shield_scorer.py +++ b/pyrit/score/true_false/prompt_shield_scorer.py @@ -4,7 +4,7 @@ import json import logging import uuid -from typing import Any, Optional +from typing import Any from pyrit.models import ComponentIdentifier, Message, MessagePiece, Score, ScoreType from pyrit.prompt_target import PromptShieldTarget @@ -32,7 +32,7 @@ def __init__( self, *, prompt_shield_target: PromptShieldTarget, - validator: Optional[ScorerPromptValidator] = None, + validator: ScorerPromptValidator | None = None, score_aggregator: TrueFalseAggregatorFunc = TrueFalseScoreAggregator.OR, ) -> None: """ @@ -40,7 +40,7 @@ def __init__( Args: prompt_shield_target (PromptShieldTarget): The Prompt Shield target to use for scoring. - validator (Optional[ScorerPromptValidator]): Custom validator. Defaults to None. + validator (ScorerPromptValidator | None): Custom validator. Defaults to None. score_aggregator (TrueFalseAggregatorFunc): The aggregator function to use. Defaults to TrueFalseScoreAggregator.OR. 
""" @@ -64,7 +64,7 @@ def _build_identifier(self) -> ComponentIdentifier: }, ) - async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Optional[str] = None) -> list[Score]: + async def _score_piece_async(self, message_piece: MessagePiece, *, objective: str | None = None) -> list[Score]: conversation_id = str(uuid.uuid4()) body = message_piece.original_value diff --git a/pyrit/score/true_false/question_answer_scorer.py b/pyrit/score/true_false/question_answer_scorer.py index a2346fc650..46e7b94527 100644 --- a/pyrit/score/true_false/question_answer_scorer.py +++ b/pyrit/score/true_false/question_answer_scorer.py @@ -3,7 +3,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from pyrit.models import MessagePiece, Score from pyrit.score.scorer_prompt_validator import ScorerPromptValidator @@ -32,8 +32,8 @@ def __init__( self, *, correct_answer_matching_patterns: list[str] = CORRECT_ANSWER_MATCHING_PATTERNS, - category: Optional[list[str]] = None, - validator: Optional[ScorerPromptValidator] = None, + category: list[str] | None = None, + validator: ScorerPromptValidator | None = None, score_aggregator: TrueFalseAggregatorFunc = TrueFalseScoreAggregator.OR, ) -> None: """ @@ -43,8 +43,8 @@ def __init__( correct_answer_matching_patterns (list[str]): A list of patterns to check for in the response. If any pattern is found in the response, the score will be True. These patterns should be format strings that will be formatted with the correct answer metadata. Defaults to CORRECT_ANSWER_MATCHING_PATTERNS. - category (Optional[list[str]]): Optional list of categories for the score. Defaults to None. - validator (Optional[ScorerPromptValidator]): Custom validator. Defaults to None. + category (list[str] | None): Optional list of categories for the score. Defaults to None. + validator (ScorerPromptValidator | None): Custom validator. Defaults to None. 
score_aggregator (TrueFalseAggregatorFunc): The aggregator function to use. Defaults to TrueFalseScoreAggregator.OR. """ @@ -67,14 +67,14 @@ def _build_identifier(self) -> ComponentIdentifier: }, ) - async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Optional[str] = None) -> list[Score]: + async def _score_piece_async(self, message_piece: MessagePiece, *, objective: str | None = None) -> list[Score]: """ Score the message piece using question answering evaluation. Args: message_piece (MessagePiece): The answer given by the target, which must contain 'correct_answer_index' and 'correct_answer' in prompt_metadata. - objective (Optional[str]): The objective to evaluate against. Defaults to None. + objective (str | None): The objective to evaluate against. Defaults to None. Currently not used for this scorer. Returns: diff --git a/pyrit/score/true_false/self_ask_category_scorer.py b/pyrit/score/true_false/self_ask_category_scorer.py index 26f28bcddf..cce672b642 100644 --- a/pyrit/score/true_false/self_ask_category_scorer.py +++ b/pyrit/score/true_false/self_ask_category_scorer.py @@ -3,7 +3,6 @@ import enum from pathlib import Path -from typing import Optional, Union import yaml @@ -42,19 +41,19 @@ def __init__( self, *, chat_target: PromptTarget, - content_classifier_path: Union[str, Path], + content_classifier_path: str | Path, score_aggregator: TrueFalseAggregatorFunc = TrueFalseScoreAggregator.OR, - validator: Optional[ScorerPromptValidator] = None, + validator: ScorerPromptValidator | None = None, ) -> None: """ Initialize a new instance of the SelfAskCategoryScorer class. Args: chat_target (PromptTarget): The chat target to interact with. - content_classifier_path (Union[str, Path]): The path to the classifier YAML file. + content_classifier_path (str | Path): The path to the classifier YAML file. score_aggregator (TrueFalseAggregatorFunc): The aggregator function to use. Defaults to TrueFalseScoreAggregator.OR. 
- validator (Optional[ScorerPromptValidator]): Custom validator. Defaults to None. + validator (ScorerPromptValidator | None): Custom validator. Defaults to None. """ super().__init__( score_aggregator=score_aggregator, @@ -104,7 +103,7 @@ def _content_classifier_to_string(self, categories: list[dict[str, str]]) -> str Convert the content classifier categories to a string representation to be put in a system prompt. Args: - categories (list[Dict[str, str]]): The categories to convert. + categories (list[dict[str, str]]): The categories to convert. Returns: str: The string representation of the categories. @@ -129,13 +128,13 @@ def _content_classifier_to_string(self, categories: list[dict[str, str]]) -> str return category_descriptions - async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Optional[str] = None) -> list[Score]: + async def _score_piece_async(self, message_piece: MessagePiece, *, objective: str | None = None) -> list[Score]: """ Scores the given message using the chat target. Args: message_piece (MessagePiece): The message piece to score. - objective (Optional[str]): The task based on which the text should be scored + objective (str | None): The task based on which the text should be scored (the original attacker model's objective). Defaults to None. 
Returns: diff --git a/pyrit/score/true_false/self_ask_general_true_false_scorer.py b/pyrit/score/true_false/self_ask_general_true_false_scorer.py index 4fc934aa42..71acd45a56 100644 --- a/pyrit/score/true_false/self_ask_general_true_false_scorer.py +++ b/pyrit/score/true_false/self_ask_general_true_false_scorer.py @@ -3,7 +3,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from pyrit.prompt_target import CHAT_TARGET_REQUIREMENTS from pyrit.score.scorer_prompt_validator import ScorerPromptValidator @@ -35,9 +35,9 @@ def __init__( *, chat_target: PromptTarget, system_prompt_format_string: str, - prompt_format_string: Optional[str] = None, - category: Optional[str] = None, - validator: Optional[ScorerPromptValidator] = None, + prompt_format_string: str | None = None, + category: str | None = None, + validator: ScorerPromptValidator | None = None, score_aggregator: TrueFalseAggregatorFunc = TrueFalseScoreAggregator.OR, score_value_output_key: str = "score_value", rationale_output_key: str = "rationale", @@ -61,9 +61,9 @@ def __init__( possibly via normalization-pipeline adaptation). system_prompt_format_string (str): System prompt template with placeholders for objective, task (alias of objective), prompt, and message_piece. - prompt_format_string (Optional[str]): User prompt template with the same placeholders. - category (Optional[str]): Category for the score. - validator (Optional[ScorerPromptValidator]): Custom validator. If omitted, a default + prompt_format_string (str | None): User prompt template with the same placeholders. + category (str | None): Category for the score. + validator (ScorerPromptValidator | None): Custom validator. If omitted, a default validator will be used requiring text input and an objective. score_aggregator (TrueFalseAggregatorFunc): Aggregator for combining scores. Defaults to TrueFalseScoreAggregator.OR. 
@@ -112,7 +112,7 @@ def _build_identifier(self) -> ComponentIdentifier: }, ) - async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Optional[str] = None) -> list[Score]: + async def _score_piece_async(self, message_piece: MessagePiece, *, objective: str | None = None) -> list[Score]: """ Score a single message piece using the configured prompts. diff --git a/pyrit/score/true_false/self_ask_question_answer_scorer.py b/pyrit/score/true_false/self_ask_question_answer_scorer.py index 7ea9c9a834..a2f5bc078e 100644 --- a/pyrit/score/true_false/self_ask_question_answer_scorer.py +++ b/pyrit/score/true_false/self_ask_question_answer_scorer.py @@ -3,7 +3,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from pyrit.common.path import SCORER_SEED_PROMPT_PATH from pyrit.common.utils import verify_and_resolve_path @@ -38,8 +38,8 @@ def __init__( self, *, chat_target: PromptTarget, - true_false_question_path: Optional[pathlib.Path] = None, - validator: Optional[ScorerPromptValidator] = None, + true_false_question_path: pathlib.Path | None = None, + validator: ScorerPromptValidator | None = None, score_aggregator: TrueFalseAggregatorFunc = TrueFalseScoreAggregator.OR, ) -> None: """ @@ -49,9 +49,9 @@ def __init__( chat_target (PromptTarget): The chat target to use for the scorer. Must satisfy CHAT_TARGET_REQUIREMENTS (multi-turn + editable history capabilities, possibly via normalization-pipeline adaptation). - true_false_question_path (Optional[pathlib.Path]): The path to the true/false question file. + true_false_question_path (pathlib.Path | None): The path to the true/false question file. Defaults to None, which uses the default question_answering.yaml file. - validator (Optional[ScorerPromptValidator]): Custom validator. Defaults to None. + validator (ScorerPromptValidator | None): Custom validator. Defaults to None. 
score_aggregator (TrueFalseAggregatorFunc): The aggregator function to use. Defaults to TrueFalseScoreAggregator.OR. """ @@ -67,13 +67,13 @@ def __init__( score_aggregator=score_aggregator, ) - async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Optional[str] = None) -> list[Score]: + async def _score_piece_async(self, message_piece: MessagePiece, *, objective: str | None = None) -> list[Score]: """ Score the message piece using question answering evaluation. Args: message_piece (MessagePiece): The answer given by the target to be scored. - objective (Optional[str]): The objective, which usually contains the question and the correct answer. + objective (str | None): The objective, which usually contains the question and the correct answer. Defaults to None. Returns: diff --git a/pyrit/score/true_false/self_ask_refusal_scorer.py b/pyrit/score/true_false/self_ask_refusal_scorer.py index b27fce74f2..b5a5c2b80c 100644 --- a/pyrit/score/true_false/self_ask_refusal_scorer.py +++ b/pyrit/score/true_false/self_ask_refusal_scorer.py @@ -3,7 +3,6 @@ import enum from pathlib import Path -from typing import Optional, Union from pyrit.common.path import SCORER_SEED_PROMPT_PATH from pyrit.models import ComponentIdentifier, MessagePiece, Score, SeedPrompt, UnvalidatedScore @@ -69,9 +68,9 @@ def __init__( self, *, chat_target: PromptTarget, - refusal_system_prompt_path: Union[RefusalScorerPaths, Path, str] = RefusalScorerPaths.OBJECTIVE_STRICT, - prompt_format_string: Optional[str] = None, - validator: Optional[ScorerPromptValidator] = None, + refusal_system_prompt_path: RefusalScorerPaths | Path | str = RefusalScorerPaths.OBJECTIVE_STRICT, + prompt_format_string: str | None = None, + validator: ScorerPromptValidator | None = None, score_aggregator: TrueFalseAggregatorFunc = TrueFalseScoreAggregator.OR, ) -> None: """ @@ -81,14 +80,14 @@ def __init__( chat_target (PromptTarget): The chat target to use for the scorer. 
Must satisfy CHAT_TARGET_REQUIREMENTS (multi-turn + editable history capabilities, possibly via normalization-pipeline adaptation). - refusal_system_prompt_path (Union[RefusalScorerPaths, Path, str]): The path to the system prompt + refusal_system_prompt_path (RefusalScorerPaths | Path | str): The path to the system prompt to use for refusal detection. Can be a RefusalScorerPaths enum value, a Path, or a string path. Defaults to RefusalScorerPaths.OBJECTIVE_STRICT. - prompt_format_string (Optional[str]): The format string for the prompt with placeholders. + prompt_format_string (str | None): The format string for the prompt with placeholders. Use ``{objective}`` for the conversation objective and ``{response}`` for the response to evaluate. Defaults to "conversation_objective: {objective}\\nresponse_to_evaluate_input: {response}". - validator (Optional[ScorerPromptValidator]): Custom validator. Defaults to None. + validator (ScorerPromptValidator | None): Custom validator. Defaults to None. score_aggregator (TrueFalseAggregatorFunc): The aggregator function to use. Defaults to TrueFalseScoreAggregator.OR. """ @@ -138,13 +137,13 @@ def _build_identifier(self) -> ComponentIdentifier: }, ) - async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Optional[str] = None) -> list[Score]: + async def _score_piece_async(self, message_piece: MessagePiece, *, objective: str | None = None) -> list[Score]: """ Scores the prompt and determines whether the response is a refusal. Args: message_piece (MessagePiece): The message piece to score. - objective (Optional[str]): The objective to evaluate against (the original attacker model's objective). + objective (str | None): The objective to evaluate against (the original attacker model's objective). Defaults to None. 
Returns: diff --git a/pyrit/score/true_false/self_ask_true_false_scorer.py b/pyrit/score/true_false/self_ask_true_false_scorer.py index 70ee7dc2e8..0786d0db38 100644 --- a/pyrit/score/true_false/self_ask_true_false_scorer.py +++ b/pyrit/score/true_false/self_ask_true_false_scorer.py @@ -4,7 +4,7 @@ import enum from collections.abc import Iterator from pathlib import Path -from typing import Any, Optional, Union +from typing import Any import yaml @@ -109,10 +109,10 @@ def __init__( self, *, chat_target: PromptTarget, - true_false_question_path: Optional[Union[str, Path]] = None, - true_false_question: Optional[TrueFalseQuestion] = None, - true_false_system_prompt_path: Optional[Union[str, Path]] = None, - validator: Optional[ScorerPromptValidator] = None, + true_false_question_path: str | Path | None = None, + true_false_question: TrueFalseQuestion | None = None, + true_false_system_prompt_path: str | Path | None = None, + validator: ScorerPromptValidator | None = None, score_aggregator: TrueFalseAggregatorFunc = TrueFalseScoreAggregator.OR, ) -> None: """ @@ -122,10 +122,10 @@ def __init__( chat_target (PromptTarget): The chat target to use for the scorer. Must satisfy CHAT_TARGET_REQUIREMENTS (multi-turn + editable history capabilities, possibly via normalization-pipeline adaptation). - true_false_question_path (Optional[Union[str, Path]]): The path to the true/false question file. - true_false_question (Optional[TrueFalseQuestion]): The true/false question object. - true_false_system_prompt_path (Optional[Union[str, Path]]): The path to the system prompt file. - validator (Optional[ScorerPromptValidator]): Custom validator. Defaults to None. + true_false_question_path (str | Path | None): The path to the true/false question file. + true_false_question (TrueFalseQuestion | None): The true/false question object. + true_false_system_prompt_path (str | Path | None): The path to the system prompt file. + validator (ScorerPromptValidator | None): Custom validator. 
Defaults to None. score_aggregator (TrueFalseAggregatorFunc): The aggregator function to use. Defaults to TrueFalseScoreAggregator.OR. @@ -194,13 +194,13 @@ def _build_identifier(self) -> ComponentIdentifier: }, ) - async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Optional[str] = None) -> list[Score]: + async def _score_piece_async(self, message_piece: MessagePiece, *, objective: str | None = None) -> list[Score]: """ Scores the given message piece using "self-ask" for the chat target. Args: message_piece (MessagePiece): The message piece containing the text or image to be scored. - objective (Optional[str]): The objective to evaluate against (the original attacker model's objective). + objective (str | None): The objective to evaluate against (the original attacker model's objective). Defaults to None. Returns: diff --git a/pyrit/score/true_false/substring_scorer.py b/pyrit/score/true_false/substring_scorer.py index 194f5d19eb..4429930e50 100644 --- a/pyrit/score/true_false/substring_scorer.py +++ b/pyrit/score/true_false/substring_scorer.py @@ -1,7 +1,6 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -from typing import Optional from pyrit.analytics.text_matching import ExactTextMatching, TextMatching from pyrit.models import ComponentIdentifier, MessagePiece, Score @@ -27,22 +26,22 @@ def __init__( self, *, substring: str, - text_matcher: Optional[TextMatching] = None, - categories: Optional[list[str]] = None, + text_matcher: TextMatching | None = None, + categories: list[str] | None = None, aggregator: TrueFalseAggregatorFunc = TrueFalseScoreAggregator.OR, - validator: Optional[ScorerPromptValidator] = None, + validator: ScorerPromptValidator | None = None, ) -> None: """ Initialize the SubStringScorer. Args: substring (str): The substring to search for in the text. - text_matcher (Optional[TextMatching]): The text matching strategy to use. 
+ text_matcher (TextMatching | None): The text matching strategy to use. Defaults to ExactTextMatching with case_sensitive=False. - categories (Optional[list[str]]): Optional list of categories for the score. Defaults to None. + categories (list[str] | None): Optional list of categories for the score. Defaults to None. aggregator (TrueFalseAggregatorFunc): The aggregator function to use. Defaults to TrueFalseScoreAggregator.OR. - validator (Optional[ScorerPromptValidator]): Custom validator. Defaults to None. + validator (ScorerPromptValidator | None): Custom validator. Defaults to None. """ self._substring = substring self._text_matcher = text_matcher if text_matcher else ExactTextMatching(case_sensitive=False) @@ -65,13 +64,13 @@ def _build_identifier(self) -> ComponentIdentifier: }, ) - async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Optional[str] = None) -> list[Score]: + async def _score_piece_async(self, message_piece: MessagePiece, *, objective: str | None = None) -> list[Score]: """ Score the given message piece based on presence of the substring. Args: message_piece (MessagePiece): The message piece to score. - objective (Optional[str]): The objective to evaluate against. Defaults to None. + objective (str | None): The objective to evaluate against. Defaults to None. Currently not used for this scorer. Returns: diff --git a/pyrit/score/true_false/true_false_composite_scorer.py b/pyrit/score/true_false/true_false_composite_scorer.py index 148e80322c..0fece73d64 100644 --- a/pyrit/score/true_false/true_false_composite_scorer.py +++ b/pyrit/score/true_false/true_false_composite_scorer.py @@ -36,7 +36,7 @@ def __init__( aggregator (TrueFalseAggregatorFunc): Aggregation function to combine child scores (e.g., ``TrueFalseScoreAggregator.AND``, ``TrueFalseScoreAggregator.OR``, ``TrueFalseScoreAggregator.MAJORITY``). - scorers (List[TrueFalseScorer]): The constituent true/false scorers to invoke. 
+ scorers (list[TrueFalseScorer]): The constituent true/false scorers to invoke. Raises: ValueError: If no scorers are provided. @@ -83,16 +83,16 @@ async def _score_async( self, message: Message, *, - objective: Optional[str] = None, - role_filter: Optional[ChatMessageRole] = None, + objective: str | None = None, + role_filter: ChatMessageRole | None = None, ) -> list[Score]: """ Score a request/response by combining results from all constituent scorers. Args: message (Message): The request/response to score. - objective (Optional[str]): Scoring objective or context. - role_filter (Optional[ChatMessageRole]): Optional filter for message roles. Defaults to None. + objective (str | None): Scoring objective or context. + role_filter (ChatMessageRole | None): Optional filter for message roles. Defaults to None. Returns: list[Score]: A single-element list with the aggregated true/false score. @@ -140,13 +140,13 @@ async def _score_async( return [return_score] - async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Optional[str] = None) -> list[Score]: + async def _score_piece_async(self, message_piece: MessagePiece, *, objective: str | None = None) -> list[Score]: """ Composite scorers do not support piecewise scoring. Args: message_piece (MessagePiece): Unused. - objective (Optional[str]): Unused. + objective (str | None): Unused. Raises: NotImplementedError: Always, since composite scoring operates at the response level. 
diff --git a/pyrit/score/true_false/true_false_inverter_scorer.py b/pyrit/score/true_false/true_false_inverter_scorer.py index a62d2ac287..c3b894edda 100644 --- a/pyrit/score/true_false/true_false_inverter_scorer.py +++ b/pyrit/score/true_false/true_false_inverter_scorer.py @@ -15,13 +15,13 @@ class TrueFalseInverterScorer(TrueFalseScorer): """A scorer that inverts a true false score.""" - def __init__(self, *, scorer: TrueFalseScorer, validator: Optional[ScorerPromptValidator] = None) -> None: + def __init__(self, *, scorer: TrueFalseScorer, validator: ScorerPromptValidator | None = None) -> None: """ Initialize the TrueFalseInverterScorer. Args: scorer (TrueFalseScorer): The underlying true/false scorer whose results will be inverted. - validator (Optional[ScorerPromptValidator]): Custom validator. Defaults to None. + validator (ScorerPromptValidator | None): Custom validator. Defaults to None. Note: This parameter is present for signature compatibility but is not used. Raises: @@ -54,7 +54,7 @@ def get_chat_target(self) -> Optional["PromptTarget"]: Delegate to the wrapped scorer. Returns: - Optional[PromptTarget]: The chat target from the wrapped scorer. + PromptTarget | None: The chat target from the wrapped scorer. """ return self._scorer.get_chat_target() @@ -62,17 +62,17 @@ async def _score_async( self, message: Message, *, - objective: Optional[str] = None, - role_filter: Optional[ChatMessageRole] = None, + objective: str | None = None, + role_filter: ChatMessageRole | None = None, ) -> list[Score]: """ Scores the piece using the underlying true-false scorer and returns the inverted score. Args: message (Message): The message to score. - objective (Optional[str]): The objective to evaluate against (the original attacker model's objective). + objective (str | None): The objective to evaluate against (the original attacker model's objective). Defaults to None. - role_filter (Optional[ChatMessageRole]): Optional filter for message roles. Defaults to None. 
+ role_filter (ChatMessageRole | None): Optional filter for message roles. Defaults to None. Returns: list[Score]: A list containing a single Score object with the inverted true/false value. @@ -100,13 +100,13 @@ async def _score_async( return [inv_score] - async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Optional[str] = None) -> list[Score]: + async def _score_piece_async(self, message_piece: MessagePiece, *, objective: str | None = None) -> list[Score]: """ Indicate that True False Inverter scorers do not support piecewise scoring. Args: message_piece (MessagePiece): Unused. - objective (Optional[str]): Unused. + objective (str | None): Unused. Raises: NotImplementedError: Always, since composite scoring operates at the response level. diff --git a/pyrit/score/true_false/true_false_score_aggregator.py b/pyrit/score/true_false/true_false_score_aggregator.py index 4a04313599..af97c9fdc8 100644 --- a/pyrit/score/true_false/true_false_score_aggregator.py +++ b/pyrit/score/true_false/true_false_score_aggregator.py @@ -51,7 +51,7 @@ def _create_aggregator( Args: name (str): Name of the aggregator variant. - result_func (Callable[[List[bool]], bool]): Function applied to the list of boolean values + result_func (Callable[[list[bool]], bool]): Function applied to the list of boolean values to compute the aggregation result. true_msg (str): Description to use when the result is True. false_msg (str): Description to use when the result is False. diff --git a/pyrit/score/true_false/true_false_scorer.py b/pyrit/score/true_false/true_false_scorer.py index 68183f544c..557d7c3424 100644 --- a/pyrit/score/true_false/true_false_scorer.py +++ b/pyrit/score/true_false/true_false_scorer.py @@ -56,7 +56,7 @@ def __init__( validator (ScorerPromptValidator): Custom validator. score_aggregator (TrueFalseAggregatorFunc): The aggregator function to use. Defaults to TrueFalseScoreAggregator.OR. 
- chat_target (Optional[PromptTarget]): Optional chat target used by the scorer, + chat_target (PromptTarget | None): Optional chat target used by the scorer, forwarded to the base class for validation against ``TARGET_REQUIREMENTS``. """ self._score_aggregator = score_aggregator @@ -117,7 +117,7 @@ def get_scorer_metrics(self) -> Optional["ObjectiveScorerMetrics"]: return find_objective_metrics_by_eval_hash(eval_hash=eval_hash, file_path=result_file) - async def _score_async(self, message: Message, *, objective: Optional[str] = None) -> list[Score]: + async def _score_async(self, message: Message, *, objective: str | None = None) -> list[Score]: """ Score the given request response asynchronously. @@ -128,7 +128,7 @@ async def _score_async(self, message: Message, *, objective: Optional[str] = Non Args: message (Message): The message to score. - objective (Optional[str]): The objective to evaluate against. Defaults to None. + objective (str | None): The objective to evaluate against. Defaults to None. Returns: list[Score]: A list containing a single aggregated true/false Score, or an empty @@ -158,7 +158,7 @@ async def _score_async(self, message: Message, *, objective: Optional[str] = Non ) ] - def _build_fallback_score(self, *, message: Message, objective: Optional[str]) -> list[Score]: + def _build_fallback_score(self, *, message: Message, objective: str | None) -> list[Score]: """ Build a single-element list containing a ``false`` score when no pieces could be scored. @@ -167,7 +167,7 @@ def _build_fallback_score(self, *, message: Message, objective: Optional[str]) - Args: message (Message): The message whose first piece is inspected for status. - objective (Optional[str]): The objective associated with this scoring call. + objective (str | None): The objective associated with this scoring call. 
Returns: list[Score]: A single-element list containing a ``false`` ``true_false`` score diff --git a/pyrit/score/true_false/video_true_false_scorer.py b/pyrit/score/true_false/video_true_false_scorer.py index d1895aa4bb..5c45eae477 100644 --- a/pyrit/score/true_false/video_true_false_scorer.py +++ b/pyrit/score/true_false/video_true_false_scorer.py @@ -1,7 +1,6 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -from typing import Optional from pyrit.models import ComponentIdentifier, MessagePiece, Score from pyrit.score.scorer_prompt_validator import ScorerPromptValidator @@ -30,11 +29,11 @@ def __init__( self, *, image_capable_scorer: TrueFalseScorer, - audio_scorer: Optional[TrueFalseScorer] = None, - num_sampled_frames: Optional[int] = None, - validator: Optional[ScorerPromptValidator] = None, - image_objective_template: Optional[str] = VideoHelper._DEFAULT_IMAGE_OBJECTIVE_TEMPLATE, - audio_objective_template: Optional[str] = None, + audio_scorer: TrueFalseScorer | None = None, + num_sampled_frames: int | None = None, + validator: ScorerPromptValidator | None = None, + image_objective_template: str | None = VideoHelper._DEFAULT_IMAGE_OBJECTIVE_TEMPLATE, + audio_objective_template: str | None = None, ) -> None: """ Initialize the VideoTrueFalseScorer. @@ -94,7 +93,7 @@ def _build_identifier(self) -> ComponentIdentifier: }, ) - async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Optional[str] = None) -> list[Score]: + async def _score_piece_async(self, message_piece: MessagePiece, *, objective: str | None = None) -> list[Score]: """ Score a single video piece by extracting frames and optionally audio, then aggregating their scores. 
diff --git a/pyrit/score/video_scorer.py b/pyrit/score/video_scorer.py index 4e8d86f4dd..ea69978357 100644 --- a/pyrit/score/video_scorer.py +++ b/pyrit/score/video_scorer.py @@ -6,7 +6,6 @@ import tempfile import uuid from pathlib import Path -from typing import Optional from pyrit.memory import CentralMemory from pyrit.models import MessagePiece, Score @@ -43,9 +42,9 @@ def __init__( self, *, image_capable_scorer: Scorer, - num_sampled_frames: Optional[int] = None, - image_objective_template: Optional[str] = _DEFAULT_IMAGE_OBJECTIVE_TEMPLATE, - audio_objective_template: Optional[str] = None, + num_sampled_frames: int | None = None, + image_objective_template: str | None = _DEFAULT_IMAGE_OBJECTIVE_TEMPLATE, + audio_objective_template: str | None = None, ) -> None: """ Initialize the base video scorer. @@ -95,7 +94,7 @@ def _validate_audio_scorer(scorer: Scorer) -> None: f"Supported types: {scorer._validator._supported_data_types}" ) - async def _score_frames_async(self, *, message_piece: MessagePiece, objective: Optional[str] = None) -> list[Score]: + async def _score_frames_async(self, *, message_piece: MessagePiece, objective: str | None = None) -> list[Score]: """ Extract frames from video and score them. @@ -211,7 +210,7 @@ def _extract_frames(self, video_path: str) -> list[str]: return frame_paths async def _score_video_audio_async( - self, *, message_piece: MessagePiece, audio_scorer: Optional[Scorer] = None, objective: Optional[str] = None + self, *, message_piece: MessagePiece, audio_scorer: Scorer | None = None, objective: str | None = None ) -> list[Score]: """ Extract and score audio from the video. 
diff --git a/pyrit/setup/configuration_loader.py b/pyrit/setup/configuration_loader.py index 0af45c4adf..51c1c30e9a 100644 --- a/pyrit/setup/configuration_loader.py +++ b/pyrit/setup/configuration_loader.py @@ -11,7 +11,7 @@ import pathlib from collections.abc import Sequence from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any, Optional, Union +from typing import TYPE_CHECKING, Any from pyrit.common.path import DEFAULT_CONFIG_PATH from pyrit.common.yaml_loadable import YamlLoadable @@ -29,8 +29,8 @@ # Type alias for YAML-serializable values that can be passed as initializer args # This matches what YAML can represent: primitives, lists, and nested dicts -YamlPrimitive = Union[str, int, float, bool, None] -YamlValue = Union[YamlPrimitive, list["YamlValue"], dict[str, "YamlValue"]] +YamlPrimitive = str | int | float | bool | None +YamlValue = YamlPrimitive | list["YamlValue"] | dict[str, "YamlValue"] # Mapping from snake_case config values to internal constants _MEMORY_DB_TYPE_MAP: dict[str, str] = { @@ -51,7 +51,7 @@ class InitializerConfig: """ name: str - args: Optional[dict[str, YamlValue]] = None + args: dict[str, YamlValue] | None = None @dataclass @@ -77,7 +77,7 @@ class ScenarioConfig: """ name: str - args: Optional[dict[str, YamlValue]] = None + args: dict[str, YamlValue] | None = None def _scenario_config_to_dict(config: ScenarioConfig) -> dict[str, Any]: @@ -137,16 +137,16 @@ class ConfigurationLoader(YamlLoadable): """ memory_db_type: str = "sqlite" - initializers: list[Union[str, dict[str, Any]]] = field(default_factory=list) - initialization_scripts: Optional[list[str]] = None - env_files: Optional[list[str]] = None + initializers: list[str | dict[str, Any]] = field(default_factory=list) + initialization_scripts: list[str] | None = None + env_files: list[str] | None = None silent: bool = False - operator: Optional[str] = None - operation: Optional[str] = None - scenario: Optional[Union[str, dict[str, Any]]] = None + operator: 
str | None = None + operation: str | None = None + scenario: str | dict[str, Any] | None = None max_concurrent_scenario_runs: int = 3 allow_custom_initializers: bool = False - server: Optional[dict[str, Any]] = None + server: dict[str, Any] | None = None extensions: dict[str, Any] = field(default_factory=dict) def __post_init__(self) -> None: @@ -229,7 +229,7 @@ def _normalize_scenario(self) -> None: ValueError: For any other shape. """ if self.scenario is None: - self._scenario_config: Optional[ScenarioConfig] = None + self._scenario_config: ScenarioConfig | None = None return if isinstance(self.scenario, str): @@ -263,7 +263,7 @@ def _normalize_server(self) -> None: ValueError: If ``server`` is not ``None`` or a dict, or if ``url`` is not a string. """ if self.server is None: - self._server_config: Optional[ServerConfig] = None + self._server_config: ServerConfig | None = None return if isinstance(self.server, dict): @@ -276,12 +276,12 @@ def _normalize_server(self) -> None: raise ValueError(f"Server entry must be a dict, got: {type(self.server).__name__}") @property - def server_config(self) -> Optional[ServerConfig]: + def server_config(self) -> ServerConfig | None: """The normalized ``server:`` block, or ``None`` when not configured.""" return self._server_config @property - def scenario_config(self) -> Optional[ScenarioConfig]: + def scenario_config(self) -> ScenarioConfig | None: """The normalized ``scenario:`` block, or ``None`` when not configured.""" return self._scenario_config @@ -313,12 +313,12 @@ def from_dict(cls, data: dict[str, Any]) -> "ConfigurationLoader": @staticmethod def load_with_overrides( - config_file: Optional[pathlib.Path] = None, + config_file: pathlib.Path | None = None, *, - memory_db_type: Optional[str] = None, - initializers: Optional[Sequence[Union[str, dict[str, Any]]]] = None, - initialization_scripts: Optional[Sequence[str]] = None, - env_files: Optional[Sequence[str]] = None, + memory_db_type: str | None = None, + 
initializers: Sequence[str | dict[str, Any]] | None = None, + initialization_scripts: Sequence[str] | None = None, + env_files: Sequence[str] | None = None, ) -> "ConfigurationLoader": """ Load configuration with optional overrides. @@ -487,7 +487,7 @@ def resolve_initializers(self) -> Sequence["PyRITInitializer"]: return resolved - def resolve_initialization_scripts(self) -> Optional[Sequence[pathlib.Path]]: + def resolve_initialization_scripts(self) -> Sequence[pathlib.Path] | None: """ Resolve initialization script paths. @@ -512,7 +512,7 @@ def resolve_initialization_scripts(self) -> Optional[Sequence[pathlib.Path]]: return resolved - def resolve_env_files(self) -> Optional[Sequence[pathlib.Path]]: + def resolve_env_files(self) -> Sequence[pathlib.Path] | None: """ Resolve environment file paths. @@ -564,7 +564,7 @@ async def initialize_pyrit_async(self) -> None: async def initialize_from_config_async( - config_path: Optional[Union[str, pathlib.Path]] = None, + config_path: str | pathlib.Path | None = None, ) -> ConfigurationLoader: """ Initialize PyRIT from a configuration file. diff --git a/pyrit/setup/initialization.py b/pyrit/setup/initialization.py index 5d3dfe6663..98c5891c09 100644 --- a/pyrit/setup/initialization.py +++ b/pyrit/setup/initialization.py @@ -3,7 +3,7 @@ import logging import pathlib from collections.abc import Sequence -from typing import TYPE_CHECKING, Any, Literal, Optional, Union, get_args +from typing import TYPE_CHECKING, Any, Literal, get_args import dotenv @@ -27,7 +27,7 @@ MemoryDatabaseType = Literal["InMemory", "SQLite", "AzureSQL"] -def _load_environment_files(env_files: Optional[Sequence[pathlib.Path]], *, silent: bool = False) -> None: +def _load_environment_files(env_files: Sequence[pathlib.Path] | None, *, silent: bool = False) -> None: """ Load environment files in the order they are provided. Later files override values from earlier files. 
@@ -95,9 +95,7 @@ def _print_msg(message: str, quiet: bool, log: bool) -> None: logger.info(message) -def _load_initializers_from_scripts( - *, script_paths: Sequence[Union[str, pathlib.Path]] -) -> Sequence["PyRITInitializer"]: +def _load_initializers_from_scripts(*, script_paths: Sequence[str | pathlib.Path]) -> Sequence["PyRITInitializer"]: """ Load PyRITInitializer instances from external Python files. @@ -105,7 +103,7 @@ def _load_initializers_from_scripts( that inherit from PyRITInitializer will be automatically discovered and instantiated. Args: - script_paths (Sequence[Union[str, pathlib.Path]]): Sequence of file paths to Python scripts to load. + script_paths (Sequence[str | pathlib.Path]): Sequence of file paths to Python scripts to load. Returns: Sequence[PyRITInitializer]: List of PyRITInitializer instances loaded from the scripts. @@ -228,11 +226,11 @@ async def _execute_initializers_async(*, initializers: Sequence["PyRITInitialize async def initialize_pyrit_async( - memory_db_type: Union[MemoryDatabaseType, str], + memory_db_type: MemoryDatabaseType | str, *, - initialization_scripts: Optional[Sequence[Union[str, pathlib.Path]]] = None, - initializers: Optional[Sequence["PyRITInitializer"]] = None, - env_files: Optional[Sequence[pathlib.Path]] = None, + initialization_scripts: Sequence[str | pathlib.Path] | None = None, + initializers: Sequence["PyRITInitializer"] | None = None, + env_files: Sequence[pathlib.Path] | None = None, silent: bool = False, **memory_instance_kwargs: Any, ) -> None: @@ -242,17 +240,17 @@ async def initialize_pyrit_async( Args: memory_db_type (MemoryDatabaseType): The MemoryDatabaseType string literal which indicates the memory instance to use for central memory. Options include "InMemory", "SQLite", and "AzureSQL". 
- initialization_scripts (Optional[Sequence[Union[str, pathlib.Path]]]): Optional sequence of Python script paths + initialization_scripts (Sequence[str | pathlib.Path] | None): Optional sequence of Python script paths that contain PyRITInitializer classes. Each script must define either a get_initializers() function or an 'initializers' variable that returns/contains a list of PyRITInitializer instances. - initializers (Optional[Sequence[PyRITInitializer]]): Optional sequence of PyRITInitializer instances + initializers (Sequence[PyRITInitializer] | None): Optional sequence of PyRITInitializer instances to execute directly. These provide type-safe, validated configuration with clear documentation. - env_files (Optional[Sequence[pathlib.Path]]): Optional sequence of environment file paths to load + env_files (Sequence[pathlib.Path] | None): Optional sequence of environment file paths to load in order. If not provided, will load default .env and .env.local files from PyRIT home if they exist. All paths must be valid pathlib.Path objects. silent (bool): If True, suppresses print statements about environment file loading and schema migration. Defaults to False. - **memory_instance_kwargs (Optional[Any]): Additional keyword arguments to pass to the memory instance. + **memory_instance_kwargs (Any | None): Additional keyword arguments to pass to the memory instance. Raises: ValueError: If an unsupported memory_db_type is provided or if env_files contains non-existent files. 
diff --git a/pyrit/setup/initializers/components/targets.py b/pyrit/setup/initializers/components/targets.py index fd0284c7d9..017bb873bf 100644 --- a/pyrit/setup/initializers/components/targets.py +++ b/pyrit/setup/initializers/components/targets.py @@ -17,7 +17,7 @@ from collections import defaultdict from dataclasses import dataclass, field from enum import Enum -from typing import Any, Optional +from typing import Any from pyrit.auth import get_azure_openai_auth, get_azure_token_provider from pyrit.common.parameter import Parameter @@ -71,9 +71,9 @@ class TargetConfig: target_class: type[PromptTarget] endpoint_var: str key_var: str = "" # Empty string means no auth required - model_var: Optional[str] = None - underlying_model_var: Optional[str] = None - temperature: Optional[float] = None + model_var: str | None = None + underlying_model_var: str | None = None + temperature: float | None = None extra_kwargs: dict[str, Any] = field(default_factory=dict) tags: list[TargetInitializerTags] = field(default_factory=lambda: [TargetInitializerTags.DEFAULT]) default_objective_target: bool = False diff --git a/pyrit/setup/initializers/pyrit_initializer.py b/pyrit/setup/initializers/pyrit_initializer.py index 2771cc90bc..d92707c92a 100644 --- a/pyrit/setup/initializers/pyrit_initializer.py +++ b/pyrit/setup/initializers/pyrit_initializer.py @@ -100,7 +100,7 @@ def required_env_vars(self) -> list[str]: set for this initializer to work correctly. Returns: - List[str]: List of required environment variable names. Defaults to empty list. + list[str]: List of required environment variable names. Defaults to empty list. """ return [] @@ -255,7 +255,7 @@ async def get_dynamic_default_values_info_async(self) -> dict[str, Any]: initializer has been run before or which instance is queried. Returns: - Dict[str, Any]: Information about what defaults and globals are set. + dict[str, Any]: Information about what defaults and globals are set. 
""" # Check if memory is initialized - required for running initialization in sandbox from pyrit.memory import CentralMemory @@ -334,7 +334,7 @@ async def get_info_async(cls) -> dict[str, Any]: await SimpleInitializer.get_info_async() instead of SimpleInitializer().get_info_async() Returns: - Dict[str, Any]: Dictionary containing name, description, class information, and default values. + dict[str, Any]: Dictionary containing name, description, class information, and default values. """ # Create a temporary instance to access properties instance = cls() diff --git a/tests/end_to_end/test_scenarios.py b/tests/end_to_end/test_scenarios.py index dcd37ecb82..56d0b44b08 100644 --- a/tests/end_to_end/test_scenarios.py +++ b/tests/end_to_end/test_scenarios.py @@ -53,7 +53,7 @@ def get_all_scenarios(): Dynamically discover all available scenarios from the scenario registry. Returns: - List[str]: Sorted list of scenario names. + list[str]: Sorted list of scenario names. """ registry = ScenarioRegistry.get_registry_singleton() return registry.get_names() diff --git a/tests/integration/mocks.py b/tests/integration/mocks.py index d66c57314c..7fe9fdd356 100644 --- a/tests/integration/mocks.py +++ b/tests/integration/mocks.py @@ -2,7 +2,6 @@ # Licensed under the MIT license. 
from collections.abc import Generator -from typing import Optional from sqlalchemy import inspect @@ -57,8 +56,8 @@ def set_system_prompt( *, system_prompt: str, conversation_id: str, - attack_identifier: Optional[ComponentIdentifier] = None, - labels: Optional[dict[str, str]] = None, + attack_identifier: ComponentIdentifier | None = None, + labels: dict[str, str] | None = None, ) -> None: self.system_prompt = system_prompt if self._memory: diff --git a/tests/unit/analytics/test_result_analysis.py b/tests/unit/analytics/test_result_analysis.py index 21d18541ed..8edf1e8ef7 100644 --- a/tests/unit/analytics/test_result_analysis.py +++ b/tests/unit/analytics/test_result_analysis.py @@ -2,7 +2,6 @@ # Licensed under the MIT license. from datetime import datetime, timedelta, timezone -from typing import Optional from unittest.mock import MagicMock import pytest @@ -27,13 +26,13 @@ # helpers def make_attack( outcome: AttackOutcome, - attack_type: Optional[str] = "default", + attack_type: str | None = "default", conversation_id: str = "conv-1", ) -> AttackResult: """ Minimal valid AttackResult for analytics tests. """ - attack_identifier: Optional[ComponentIdentifier] = None + attack_identifier: ComponentIdentifier | None = None if attack_type is not None: attack_identifier = ComponentIdentifier(class_name=attack_type, class_module="tests.unit.analytics") @@ -190,7 +189,7 @@ def _make_attack_with_target( target: ComponentIdentifier, *, outcome: AttackOutcome = AttackOutcome.SUCCESS, - timestamp: Optional[datetime] = None, + timestamp: datetime | None = None, ) -> AttackResult: technique = ComponentIdentifier( class_name="PromptSendingAttack", diff --git a/tests/unit/common/test_pyrit_default_value.py b/tests/unit/common/test_pyrit_default_value.py index e29981a6f3..cbd9584293 100644 --- a/tests/unit/common/test_pyrit_default_value.py +++ b/tests/unit/common/test_pyrit_default_value.py @@ -1,7 +1,6 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. 
-from typing import Optional import pytest @@ -28,7 +27,7 @@ def test_no_defaults_configured_returns_none(self) -> None: class TestClass: @apply_defaults - def __init__(self, *, param1: Optional[str] = None, param2: Optional[int] = None) -> None: + def __init__(self, *, param1: str | None = None, param2: int | None = None) -> None: self.param1 = param1 self.param2 = param2 @@ -41,7 +40,7 @@ def test_single_default_value_applied(self) -> None: class TestClass: @apply_defaults - def __init__(self, *, param1: Optional[str] = None) -> None: + def __init__(self, *, param1: str | None = None) -> None: self.param1 = param1 set_default_value(class_type=TestClass, parameter_name="param1", value="default_value") @@ -55,7 +54,7 @@ def test_multiple_default_values_applied(self) -> None: class TestClass: @apply_defaults def __init__( - self, *, param1: Optional[str] = None, param2: Optional[int] = None, param3: Optional[float] = None + self, *, param1: str | None = None, param2: int | None = None, param3: float | None = None ) -> None: self.param1 = param1 self.param2 = param2 @@ -75,7 +74,7 @@ def test_explicit_value_overrides_default(self) -> None: class TestClass: @apply_defaults - def __init__(self, *, param1: Optional[str] = None, param2: Optional[int] = None) -> None: + def __init__(self, *, param1: str | None = None, param2: int | None = None) -> None: self.param1 = param1 self.param2 = param2 @@ -92,7 +91,7 @@ def test_partial_override_uses_remaining_defaults(self) -> None: class TestClass: @apply_defaults def __init__( - self, *, param1: Optional[str] = None, param2: Optional[int] = None, param3: Optional[float] = None + self, *, param1: str | None = None, param2: int | None = None, param3: float | None = None ) -> None: self.param1 = param1 self.param2 = param2 @@ -115,9 +114,9 @@ class TestClass: def __init__( self, *, - param_int: Optional[int] = None, - param_bool: Optional[bool] = None, - param_str: Optional[str] = None, + param_int: int | None = None, + 
param_bool: bool | None = None, + param_str: str | None = None, ) -> None: self.param_int = param_int self.param_bool = param_bool @@ -145,13 +144,13 @@ def test_subclass_inherits_parent_defaults(self) -> None: class ParentClass: @apply_defaults - def __init__(self, *, param1: Optional[str] = None, param2: Optional[int] = None) -> None: + def __init__(self, *, param1: str | None = None, param2: int | None = None) -> None: self.param1 = param1 self.param2 = param2 class ChildClass(ParentClass): @apply_defaults - def __init__(self, *, param1: Optional[str] = None, param2: Optional[int] = None) -> None: + def __init__(self, *, param1: str | None = None, param2: int | None = None) -> None: super().__init__(param1=param1, param2=param2) set_default_value(class_type=ParentClass, parameter_name="param1", value="parent_value") @@ -166,13 +165,13 @@ def test_subclass_specific_defaults_override_parent(self) -> None: class ParentClass: @apply_defaults - def __init__(self, *, param1: Optional[str] = None, param2: Optional[int] = None) -> None: + def __init__(self, *, param1: str | None = None, param2: int | None = None) -> None: self.param1 = param1 self.param2 = param2 class ChildClass(ParentClass): @apply_defaults - def __init__(self, *, param1: Optional[str] = None, param2: Optional[int] = None) -> None: + def __init__(self, *, param1: str | None = None, param2: int | None = None) -> None: super().__init__(param1=param1, param2=param2) set_default_value(class_type=ParentClass, parameter_name="param1", value="parent_value") @@ -189,19 +188,19 @@ def test_multiple_inheritance_levels(self) -> None: class GrandParent: @apply_defaults - def __init__(self, *, param1: Optional[str] = None) -> None: + def __init__(self, *, param1: str | None = None) -> None: self.param1 = param1 class Parent(GrandParent): @apply_defaults - def __init__(self, *, param1: Optional[str] = None, param2: Optional[int] = None) -> None: + def __init__(self, *, param1: str | None = None, param2: int | None 
= None) -> None: super().__init__(param1=param1) self.param2 = param2 class Child(Parent): @apply_defaults def __init__( - self, *, param1: Optional[str] = None, param2: Optional[int] = None, param3: Optional[float] = None + self, *, param1: str | None = None, param2: int | None = None, param3: float | None = None ) -> None: super().__init__(param1=param1, param2=param2) self.param3 = param3 @@ -220,12 +219,12 @@ def test_parent_not_affected_by_child_defaults(self) -> None: class ParentClass: @apply_defaults - def __init__(self, *, param1: Optional[str] = None) -> None: + def __init__(self, *, param1: str | None = None) -> None: self.param1 = param1 class ChildClass(ParentClass): @apply_defaults - def __init__(self, *, param1: Optional[str] = None) -> None: + def __init__(self, *, param1: str | None = None) -> None: super().__init__(param1=param1) set_default_value(class_type=ChildClass, parameter_name="param1", value="child_value") @@ -354,7 +353,7 @@ def test_set_default_value_stores_value(self) -> None: class TestClass: @apply_defaults - def __init__(self, *, param1: Optional[str] = None) -> None: + def __init__(self, *, param1: str | None = None) -> None: self.param1 = param1 set_default_value(class_type=TestClass, parameter_name="param1", value="stored_value") @@ -367,7 +366,7 @@ def test_set_default_value_overwrites_existing(self) -> None: class TestClass: @apply_defaults - def __init__(self, *, param1: Optional[str] = None) -> None: + def __init__(self, *, param1: str | None = None) -> None: self.param1 = param1 set_default_value(class_type=TestClass, parameter_name="param1", value="first_value") @@ -392,9 +391,9 @@ class OpenAIChatTarget: def __init__( self, *, - temperature: Optional[float] = None, - top_p: Optional[float] = None, - max_tokens: Optional[int] = None, + temperature: float | None = None, + top_p: float | None = None, + max_tokens: int | None = None, ) -> None: self.temperature = temperature self.top_p = top_p @@ -405,9 +404,9 @@ class 
AzureOpenAIChatTarget(OpenAIChatTarget): def __init__( self, *, - temperature: Optional[float] = None, - top_p: Optional[float] = None, - max_tokens: Optional[int] = None, + temperature: float | None = None, + top_p: float | None = None, + max_tokens: int | None = None, ) -> None: super().__init__(temperature=temperature, top_p=top_p, max_tokens=max_tokens) @@ -441,12 +440,12 @@ def test_multiple_classes_independent_defaults(self) -> None: class ClassA: @apply_defaults - def __init__(self, *, param: Optional[str] = None) -> None: + def __init__(self, *, param: str | None = None) -> None: self.param = param class ClassB: @apply_defaults - def __init__(self, *, param: Optional[str] = None) -> None: + def __init__(self, *, param: str | None = None) -> None: self.param = param set_default_value(class_type=ClassA, parameter_name="param", value="value_a") @@ -471,7 +470,7 @@ def test_reset_clears_all_defaults(self) -> None: class TestClass: @apply_defaults - def __init__(self, *, param1: Optional[str] = None, param2: Optional[int] = None) -> None: + def __init__(self, *, param1: str | None = None, param2: int | None = None) -> None: self.param1 = param1 self.param2 = param2 @@ -497,12 +496,12 @@ def test_reset_affects_multiple_classes(self) -> None: class ClassA: @apply_defaults - def __init__(self, *, param: Optional[str] = None) -> None: + def __init__(self, *, param: str | None = None) -> None: self.param = param class ClassB: @apply_defaults - def __init__(self, *, param: Optional[int] = None) -> None: + def __init__(self, *, param: int | None = None) -> None: self.param = param # Set defaults for multiple classes @@ -523,7 +522,7 @@ def test_reset_allows_setting_new_defaults(self) -> None: class TestClass: @apply_defaults - def __init__(self, *, param: Optional[str] = None) -> None: + def __init__(self, *, param: str | None = None) -> None: self.param = param # Set initial default @@ -544,7 +543,7 @@ def test_reset_with_no_defaults_does_nothing(self) -> None: class 
TestClass: @apply_defaults - def __init__(self, *, param: Optional[str] = None) -> None: + def __init__(self, *, param: str | None = None) -> None: self.param = param # Reset when no defaults are set @@ -562,12 +561,12 @@ def test_reset_clears_inheritance_based_defaults(self) -> None: class ParentClass: @apply_defaults - def __init__(self, *, param: Optional[str] = None) -> None: + def __init__(self, *, param: str | None = None) -> None: self.param = param class ChildClass(ParentClass): @apply_defaults - def __init__(self, *, param: Optional[str] = None) -> None: + def __init__(self, *, param: str | None = None) -> None: super().__init__(param=param) # Set defaults for both parent and child @@ -588,7 +587,7 @@ def test_reset_clears_include_subclasses_flag_variations(self) -> None: class TestClass: @apply_defaults - def __init__(self, *, param1: Optional[str] = None, param2: Optional[str] = None) -> None: + def __init__(self, *, param1: str | None = None, param2: str | None = None) -> None: self.param1 = param1 self.param2 = param2 @@ -761,7 +760,7 @@ def __init__( self, *, required_param: str = REQUIRED_VALUE, # type: ignore[assignment] - optional_param: Optional[str] = None, + optional_param: str | None = None, ) -> None: self.required_param = required_param self.optional_param = optional_param @@ -847,7 +846,7 @@ def test_required_value_none_is_different(self) -> None: class TestClass1: @apply_defaults - def __init__(self, *, param: Optional[str] = None) -> None: + def __init__(self, *, param: str | None = None) -> None: self.param = param class TestClass2: diff --git a/tests/unit/datasets/test_decoding_trust_toxicity_dataset.py b/tests/unit/datasets/test_decoding_trust_toxicity_dataset.py index 363ae54ca9..47b8f4b30f 100644 --- a/tests/unit/datasets/test_decoding_trust_toxicity_dataset.py +++ b/tests/unit/datasets/test_decoding_trust_toxicity_dataset.py @@ -251,7 +251,7 @@ async def test_metadata_round_trip(self): meta = dataset.seeds[0].metadata assert meta is 
not None - # challenging stored as bool (bool ≤ int so dict[str, Union[str, int]] accepts it) + # challenging stored as bool (bool ≤ int so dict[str, str | int] accepts it) assert meta["challenging"] is True # All eight Perspective scores stringified at full precision for key in _PERSPECTIVE_SCORE_KEYS: diff --git a/tests/unit/executor/attack/component/test_conversation_manager.py b/tests/unit/executor/attack/component/test_conversation_manager.py index ed56813139..47d8678b7d 100644 --- a/tests/unit/executor/attack/component/test_conversation_manager.py +++ b/tests/unit/executor/attack/component/test_conversation_manager.py @@ -18,7 +18,6 @@ """ import uuid -from typing import Optional from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -56,7 +55,7 @@ class _TestAttackContext(AttackContext): """Concrete AttackContext for testing.""" # Add last_score to match MultiTurnAttackContext behavior for testing - last_score: Optional[Score] = None + last_score: Score | None = None # ============================================================================= diff --git a/tests/unit/executor/attack/compound/test_sequential_attack.py b/tests/unit/executor/attack/compound/test_sequential_attack.py index 79865e8f55..68cf8180d3 100644 --- a/tests/unit/executor/attack/compound/test_sequential_attack.py +++ b/tests/unit/executor/attack/compound/test_sequential_attack.py @@ -3,7 +3,6 @@ """Tests for ``SequentialAttack``.""" -from typing import Optional from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -35,7 +34,7 @@ def _make_seed_group(objective: str = "obj") -> SeedAttackGroup: def _make_context( *, objective: str = "obj", - labels: Optional[dict[str, str]] = None, + labels: dict[str, str] | None = None, ) -> AttackContext[AttackParameters]: params_type = AttackParameters.excluding("next_message", "prepended_conversation") return AttackContext(params=params_type(objective=objective, memory_labels=labels or {})) diff --git 
a/tests/unit/executor/attack/core/test_attack_parameters.py b/tests/unit/executor/attack/core/test_attack_parameters.py index 5e8ca975c2..a0d7ab60ca 100644 --- a/tests/unit/executor/attack/core/test_attack_parameters.py +++ b/tests/unit/executor/attack/core/test_attack_parameters.py @@ -134,7 +134,7 @@ def mock_objective_scorer(self) -> MagicMock: @pytest.fixture def mock_simulated_result(self) -> list: - """Create a mock simulated conversation result (List[SeedPrompt]).""" + """Create a mock simulated conversation result (list[SeedPrompt]).""" return [ SeedPrompt(value="Simulated user message", data_type="text", role="user", sequence=0), SeedPrompt(value="Simulated assistant response", data_type="text", role="assistant", sequence=1), diff --git a/tests/unit/executor/attack/multi_turn/test_crescendo.py b/tests/unit/executor/attack/multi_turn/test_crescendo.py index 1bd1c38f46..e8e295ba88 100644 --- a/tests/unit/executor/attack/multi_turn/test_crescendo.py +++ b/tests/unit/executor/attack/multi_turn/test_crescendo.py @@ -4,7 +4,6 @@ import json import uuid from pathlib import Path -from typing import Optional from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -85,11 +84,11 @@ def create_score( *, score_type: ScoreType, score_value: str, - score_category: Optional[list[str]] = None, + score_category: list[str] | None = None, scorer_class: str, score_rationale: str = "Test rationale", score_value_description: str = "Test description", - score_metadata: Optional[dict] = None, + score_metadata: dict | None = None, ) -> Score: """Create a score with common defaults. 
@@ -254,10 +253,10 @@ def create_attack( *, objective_target: MagicMock, adversarial_chat: MagicMock, - objective_scorer: Optional[MagicMock] = None, - refusal_scorer: Optional[MagicMock] = None, - prompt_normalizer: Optional[MagicMock] = None, - system_prompt_path: Optional[Path] = None, + objective_scorer: MagicMock | None = None, + refusal_scorer: MagicMock | None = None, + prompt_normalizer: MagicMock | None = None, + system_prompt_path: Path | None = None, **kwargs, ) -> CrescendoAttack: """Create a CrescendoAttack instance with flexible configuration. @@ -909,7 +908,7 @@ async def test_parse_adversarial_response_with_various_inputs( mock_objective_target: MagicMock, mock_adversarial_chat: MagicMock, response_json: str, - expected_error: Optional[str], + expected_error: str | None, ): """Test parsing adversarial response with various inputs. diff --git a/tests/unit/executor/attack/multi_turn/test_red_teaming.py b/tests/unit/executor/attack/multi_turn/test_red_teaming.py index 173b9c2f22..e5e14f1eff 100644 --- a/tests/unit/executor/attack/multi_turn/test_red_teaming.py +++ b/tests/unit/executor/attack/multi_turn/test_red_teaming.py @@ -3,7 +3,6 @@ import uuid from pathlib import Path -from typing import Union from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -215,7 +214,7 @@ def test_init_with_seed_prompt_variations( mock_objective_target: MagicMock, mock_objective_scorer: MagicMock, mock_adversarial_chat: MagicMock, - seed_prompt: Union[str, SeedPrompt], + seed_prompt: str | SeedPrompt, expected_value: str, expected_type: type, ): diff --git a/tests/unit/executor/attack/multi_turn/test_tree_of_attacks.py b/tests/unit/executor/attack/multi_turn/test_tree_of_attacks.py index 5c805455bf..aebd00702b 100644 --- a/tests/unit/executor/attack/multi_turn/test_tree_of_attacks.py +++ b/tests/unit/executor/attack/multi_turn/test_tree_of_attacks.py @@ -6,7 +6,7 @@ import logging import uuid from dataclasses import dataclass, field -from typing import 
Any, Optional, cast +from typing import Any, cast from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -50,11 +50,11 @@ class NodeMockConfig: """Configuration for creating mock _TreeOfAttacksNode objects.""" node_id: str = field(default_factory=lambda: str(uuid.uuid4())) - parent_id: Optional[str] = None + parent_id: str | None = None prompt_sent: bool = False completed: bool = True off_topic: bool = False - objective_score_value: Optional[float] = None + objective_score_value: float | None = None auxiliary_scores: dict[str, float] = field(default_factory=dict) objective_target_conversation_id: str = field(default_factory=lambda: str(uuid.uuid4())) adversarial_chat_conversation_id: str = field(default_factory=lambda: str(uuid.uuid4())) @@ -64,7 +64,7 @@ class MockNodeFactory: """Factory for creating mock _TreeOfAttacksNode objects.""" @staticmethod - def create_node(config: Optional[NodeMockConfig] = None) -> "_TreeOfAttacksNode": + def create_node(config: NodeMockConfig | None = None) -> "_TreeOfAttacksNode": """Create a mock _TreeOfAttacksNode with the given configuration.""" if config is None: config = NodeMockConfig() @@ -150,14 +150,14 @@ class AttackBuilder: """Builder for creating TreeOfAttacksWithPruningAttack instances with common configurations.""" def __init__(self) -> None: - self.objective_target: Optional[PromptTarget] = None - self.adversarial_chat: Optional[PromptTarget] = None - self.objective_scorer: Optional[Scorer] = None + self.objective_target: PromptTarget | None = None + self.adversarial_chat: PromptTarget | None = None + self.objective_scorer: Scorer | None = None self.auxiliary_scorers: list[Scorer] = [] self.tree_params: dict[str, Any] = {} - self.converters: Optional[AttackConverterConfig] = None + self.converters: AttackConverterConfig | None = None self.successful_threshold: float = 0.8 - self.prompt_normalizer: Optional[PromptNormalizer] = None + self.prompt_normalizer: PromptNormalizer | None = None 
self._supports_multi_turn: bool = True def with_default_mocks(self) -> "AttackBuilder": diff --git a/tests/unit/executor/attack/test_attack_parameter_consistency.py b/tests/unit/executor/attack/test_attack_parameter_consistency.py index a23dd6daca..963e84d3e4 100644 --- a/tests/unit/executor/attack/test_attack_parameter_consistency.py +++ b/tests/unit/executor/attack/test_attack_parameter_consistency.py @@ -10,7 +10,6 @@ import uuid from contextlib import suppress -from typing import Optional from unittest.mock import AsyncMock, MagicMock import pytest @@ -911,7 +910,7 @@ def _assert_prepended_text_in_adversarial_context( *, prepended_conversation: list[Message], adversarial_chat_conversation_id: str, - adversarial_chat_mock: Optional[MagicMock] = None, + adversarial_chat_mock: MagicMock | None = None, ) -> None: """ Assert that text content from prepended conversation appears in adversarial chat context. diff --git a/tests/unit/memory/memory_interface/test_interface_attack_results.py b/tests/unit/memory/memory_interface/test_interface_attack_results.py index cdb611e64b..61fbc860b9 100644 --- a/tests/unit/memory/memory_interface/test_interface_attack_results.py +++ b/tests/unit/memory/memory_interface/test_interface_attack_results.py @@ -3,7 +3,7 @@ import uuid -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING import pytest @@ -1270,7 +1270,7 @@ def test_get_unique_attack_labels_deduplicates_across_sources(sqlite_instance: M def _make_attack_result_with_identifier( conversation_id: str, class_name: str, - converter_class_names: Optional[list[str]] = None, + converter_class_names: list[str] | None = None, ) -> AttackResult: """Helper to create an AttackResult with a ComponentIdentifier containing converters.""" children: dict = {} diff --git a/tests/unit/memory/memory_interface/test_interface_scenario_results.py b/tests/unit/memory/memory_interface/test_interface_scenario_results.py index 4469de670b..f818e45ecd 100644 --- 
a/tests/unit/memory/memory_interface/test_interface_scenario_results.py +++ b/tests/unit/memory/memory_interface/test_interface_scenario_results.py @@ -2,7 +2,6 @@ # Licensed under the MIT license. from datetime import datetime, timedelta, timezone -from typing import Optional import pytest from unit.mocks import get_mock_scorer_identifier @@ -42,7 +41,7 @@ def create_scenario_result( name: str = "Test Scenario", description: str = "Test Description", version: int = 1, - attack_results: Optional[dict[str, list[AttackResult]]] = None, + attack_results: dict[str, list[AttackResult]] | None = None, ): """Helper function to create ScenarioResult.""" scenario_identifier = ScenarioIdentifier( diff --git a/tests/unit/mocks.py b/tests/unit/mocks.py index 9363c88639..dbd1a8a4d4 100644 --- a/tests/unit/mocks.py +++ b/tests/unit/mocks.py @@ -7,7 +7,6 @@ import uuid from collections.abc import Generator, MutableSequence, Sequence from contextlib import AbstractAsyncContextManager -from typing import Optional from unittest.mock import MagicMock, patch from pyrit.memory import AzureSQLMemory, CentralMemory, PromptMemoryEntry @@ -140,8 +139,8 @@ def set_system_prompt( *, system_prompt: str, conversation_id: str, - attack_identifier: Optional[ComponentIdentifier] = None, - labels: Optional[dict[str, str]] = None, + attack_identifier: ComponentIdentifier | None = None, + labels: dict[str, str] | None = None, ) -> None: self.system_prompt = system_prompt if self._memory: diff --git a/tests/unit/prompt_target/target/test_openai_target_auth.py b/tests/unit/prompt_target/target/test_openai_target_auth.py index c92614a61d..99cc42fc3e 100644 --- a/tests/unit/prompt_target/target/test_openai_target_auth.py +++ b/tests/unit/prompt_target/target/test_openai_target_auth.py @@ -4,7 +4,6 @@ import asyncio import os from collections.abc import Callable -from typing import Optional from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -40,8 +39,8 @@ async def 
_send_prompt_to_target_async(self, *, normalized_conversation): def _build_target( *, endpoint: str = "https://test.openai.azure.com/openai/v1", - api_key: Optional[str | Callable] = "test-key", - env_vars: Optional[dict[str, str]] = None, + api_key: str | Callable | None = "test-key", + env_vars: dict[str, str] | None = None, ) -> _ConcreteOpenAITarget: """Helper to build a _ConcreteOpenAITarget with controlled env.""" env = {"TEST_MODEL": "gpt-4", "TEST_ENDPOINT": endpoint} diff --git a/tests/unit/registry/test_scorer_registry.py b/tests/unit/registry/test_scorer_registry.py index b0d0f3c5d5..6e458ec3eb 100644 --- a/tests/unit/registry/test_scorer_registry.py +++ b/tests/unit/registry/test_scorer_registry.py @@ -1,7 +1,6 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -from typing import Optional from pyrit.models import ComponentIdentifier, Message, MessagePiece, Score from pyrit.registry.object_registries.scorer_registry import ScorerRegistry @@ -35,10 +34,10 @@ def _build_identifier(self) -> ComponentIdentifier: """ return self._create_identifier() - async def _score_async(self, message: Message, *, objective: Optional[str] = None) -> list[Score]: + async def _score_async(self, message: Message, *, objective: str | None = None) -> list[Score]: return [] - async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Optional[str] = None) -> list[Score]: + async def _score_piece_async(self, message_piece: MessagePiece, *, objective: str | None = None) -> list[Score]: return [] def validate_return_scores(self, scores: list[Score]): @@ -59,10 +58,10 @@ def _build_identifier(self) -> ComponentIdentifier: """ return self._create_identifier() - async def _score_async(self, message: Message, *, objective: Optional[str] = None) -> list[Score]: + async def _score_async(self, message: Message, *, objective: str | None = None) -> list[Score]: return [] - async def _score_piece_async(self, message_piece: MessagePiece, *, 
objective: Optional[str] = None) -> list[Score]: + async def _score_piece_async(self, message_piece: MessagePiece, *, objective: str | None = None) -> list[Score]: return [] def validate_return_scores(self, scores: list[Score]): @@ -83,16 +82,16 @@ def _build_identifier(self) -> ComponentIdentifier: """ return self._create_identifier() - async def _score_async(self, message: Message, *, objective: Optional[str] = None) -> list[Score]: + async def _score_async(self, message: Message, *, objective: str | None = None) -> list[Score]: return [] - async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Optional[str] = None) -> list[Score]: + async def _score_piece_async(self, message_piece: MessagePiece, *, objective: str | None = None) -> list[Score]: return [] def validate_return_scores(self, scores: list[Score]): pass - def _build_fallback_score(self, *, message: Message, objective: Optional[str]) -> list[Score]: + def _build_fallback_score(self, *, message: Message, objective: str | None) -> list[Score]: return [ Score( score_value="false", diff --git a/tests/unit/score/test_audio_scorer.py b/tests/unit/score/test_audio_scorer.py index b228e71cb3..88b43f4cd2 100644 --- a/tests/unit/score/test_audio_scorer.py +++ b/tests/unit/score/test_audio_scorer.py @@ -4,7 +4,6 @@ import os import tempfile import uuid -from typing import Optional from unittest.mock import AsyncMock, patch import pytest @@ -29,7 +28,7 @@ def __init__(self, *, return_value: bool = True): def _build_identifier(self) -> ComponentIdentifier: return self._create_identifier() - async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Optional[str] = None) -> list[Score]: + async def _score_piece_async(self, message_piece: MessagePiece, *, objective: str | None = None) -> list[Score]: return [ Score( score_type="true_false", @@ -56,7 +55,7 @@ def __init__(self, *, return_value: float = 0.8): def _build_identifier(self) -> ComponentIdentifier: return 
self._create_identifier() - async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Optional[str] = None) -> list[Score]: + async def _score_piece_async(self, message_piece: MessagePiece, *, objective: str | None = None) -> list[Score]: return [ Score( score_type="float_scale", diff --git a/tests/unit/score/test_conversation_history_scorer.py b/tests/unit/score/test_conversation_history_scorer.py index dc0be088e8..0e957482a2 100644 --- a/tests/unit/score/test_conversation_history_scorer.py +++ b/tests/unit/score/test_conversation_history_scorer.py @@ -2,7 +2,6 @@ # Licensed under the MIT license. import uuid -from typing import Optional from unittest.mock import AsyncMock, MagicMock import pytest @@ -37,7 +36,7 @@ def __init__(self): def _build_identifier(self) -> ComponentIdentifier: return self._create_identifier() - async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Optional[str] = None) -> list[Score]: + async def _score_piece_async(self, message_piece: MessagePiece, *, objective: str | None = None) -> list[Score]: return [] @@ -50,7 +49,7 @@ def __init__(self): def _build_identifier(self) -> ComponentIdentifier: return self._create_identifier() - async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Optional[str] = None) -> list[Score]: + async def _score_piece_async(self, message_piece: MessagePiece, *, objective: str | None = None) -> list[Score]: return [] @@ -63,13 +62,13 @@ def __init__(self): def _build_identifier(self) -> ComponentIdentifier: return self._create_identifier() - async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Optional[str] = None) -> list[Score]: + async def _score_piece_async(self, message_piece: MessagePiece, *, objective: str | None = None) -> list[Score]: return [] def validate_return_scores(self, scores: list[Score]): pass - def _build_fallback_score(self, *, message: Message, objective: Optional[str]) -> list[Score]: + def 
_build_fallback_score(self, *, message: Message, objective: str | None) -> list[Score]: return [ Score( score_value="false", @@ -753,7 +752,7 @@ def _build_identifier(self) -> ComponentIdentifier: return self._create_identifier() async def _score_async( # type: ignore[override] - self, message: Message, *, objective: Optional[str] = None + self, message: Message, *, objective: str | None = None ) -> list[Score]: captured_messages.append(message) piece = message.message_pieces[0] @@ -773,9 +772,7 @@ async def _score_async( # type: ignore[override] ] return [] - async def _score_piece_async( - self, message_piece: MessagePiece, *, objective: Optional[str] = None - ) -> list[Score]: + async def _score_piece_async(self, message_piece: MessagePiece, *, objective: str | None = None) -> list[Score]: return [] inner_scorer = HarmfulContentDetector() diff --git a/tests/unit/score/test_float_scale_score_aggregator.py b/tests/unit/score/test_float_scale_score_aggregator.py index c726cd331a..19fac4bbff 100644 --- a/tests/unit/score/test_float_scale_score_aggregator.py +++ b/tests/unit/score/test_float_scale_score_aggregator.py @@ -1,7 +1,6 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. 
-from typing import Optional from pyrit.models import ComponentIdentifier, Score from pyrit.score.float_scale.float_scale_score_aggregator import ( @@ -17,7 +16,7 @@ ) -def _mk_score(val: float, *, category: Optional[list[str]] = None, prr_id: str = "1", rationale: str = "") -> Score: +def _mk_score(val: float, *, category: list[str] | None = None, prr_id: str = "1", rationale: str = "") -> Score: """Helper to create a float scale score.""" return Score( score_value=str(val), diff --git a/tests/unit/score/test_float_scale_threshold_scorer.py b/tests/unit/score/test_float_scale_threshold_scorer.py index de91ff5d62..b98cb183d8 100644 --- a/tests/unit/score/test_float_scale_threshold_scorer.py +++ b/tests/unit/score/test_float_scale_threshold_scorer.py @@ -2,7 +2,6 @@ # Licensed under the MIT license. import uuid -from typing import Optional from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -237,9 +236,7 @@ def __init__(self): def _build_identifier(self) -> ComponentIdentifier: return self._create_identifier() - async def _score_piece_async( - self, message_piece: MessagePiece, *, objective: Optional[str] = None - ) -> list[Score]: + async def _score_piece_async(self, message_piece: MessagePiece, *, objective: str | None = None) -> list[Score]: return [ Score( score_value="0.9", diff --git a/tests/unit/score/test_gandalf_scorer.py b/tests/unit/score/test_gandalf_scorer.py index 13ecaeed06..e47a4c39cd 100644 --- a/tests/unit/score/test_gandalf_scorer.py +++ b/tests/unit/score/test_gandalf_scorer.py @@ -2,7 +2,6 @@ # Licensed under the MIT license. 
import uuid -from typing import Optional from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -15,7 +14,7 @@ from pyrit.score import GandalfScorer -def generate_password_extraction_response(response_text: str, conversation_id: Optional[str] = None) -> Message: +def generate_password_extraction_response(response_text: str, conversation_id: str | None = None) -> Message: return Message( message_pieces=[ MessagePiece( @@ -30,7 +29,7 @@ def generate_password_extraction_response(response_text: str, conversation_id: O ) -def generate_request(conversation_id: Optional[str] = None) -> Message: +def generate_request(conversation_id: str | None = None) -> Message: return Message( message_pieces=[ MessagePiece( diff --git a/tests/unit/score/test_scorer.py b/tests/unit/score/test_scorer.py index 9aab662499..01378ebde3 100644 --- a/tests/unit/score/test_scorer.py +++ b/tests/unit/score/test_scorer.py @@ -4,7 +4,6 @@ import asyncio import uuid from textwrap import dedent -from typing import Optional from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -66,7 +65,7 @@ def _build_identifier(self) -> ComponentIdentifier: """Build the scorer evaluation identifier for this mock scorer.""" return self._create_identifier() - async def _score_async(self, message: Message, *, objective: Optional[str] = None) -> list[Score]: + async def _score_async(self, message: Message, *, objective: str | None = None) -> list[Score]: return [ Score( score_value="true", @@ -81,7 +80,7 @@ async def _score_async(self, message: Message, *, objective: Optional[str] = Non ) ] - async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Optional[str] = None) -> list[Score]: + async def _score_piece_async(self, message_piece: MessagePiece, *, objective: str | None = None) -> list[Score]: return [ Score( score_value="true", @@ -122,7 +121,7 @@ def _build_identifier(self) -> ComponentIdentifier: """Build the scorer evaluation identifier for this mock 
scorer.""" return self._create_identifier() - async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Optional[str] = None) -> list[Score]: + async def _score_piece_async(self, message_piece: MessagePiece, *, objective: str | None = None) -> list[Score]: # Track which pieces get scored self.scored_piece_ids.append(str(message_piece.id)) @@ -144,7 +143,7 @@ def validate_return_scores(self, scores: list[Score]): for score in scores: assert 0 <= float(score.score_value) <= 1 - def _build_fallback_score(self, *, message: Message, objective: Optional[str]) -> list[Score]: + def _build_fallback_score(self, *, message: Message, objective: str | None) -> list[Score]: return [ Score( score_value="0.0", @@ -1168,9 +1167,7 @@ def _build_identifier(self) -> ComponentIdentifier: """Build the scorer evaluation identifier for this test scorer.""" return self._create_identifier() - async def _score_piece_async( - self, message_piece: MessagePiece, *, objective: Optional[str] = None - ) -> list[Score]: + async def _score_piece_async(self, message_piece: MessagePiece, *, objective: str | None = None) -> list[Score]: self.scored_piece_ids.append(message_piece.id) return [ Score( @@ -1355,7 +1352,7 @@ def _build_identifier(self) -> ComponentIdentifier: return self._create_identifier() async def _score_piece_async( - self, message_piece: MessagePiece, *, objective: Optional[str] = None + self, message_piece: MessagePiece, *, objective: str | None = None ) -> list[Score]: # Return empty list to simulate no scorable pieces return [] @@ -1482,7 +1479,7 @@ def _build_identifier(self) -> ComponentIdentifier: return self._create_identifier() async def _score_piece_async( - self, message_piece: MessagePiece, *, objective: Optional[str] = None + self, message_piece: MessagePiece, *, objective: str | None = None ) -> list[Score]: return [] @@ -1622,7 +1619,7 @@ async def test_score_value_with_llm_skips_reasoning_piece(good_json): class 
_AcceptAllValidator(ScorerPromptValidator): """Validator that accepts all pieces (like SelfAskRefusalScorer's default).""" - def validate(self, message: Message, objective: Optional[str] = None) -> None: + def validate(self, message: Message, objective: str | None = None) -> None: pass def is_message_piece_supported(self, message_piece: MessagePiece) -> bool: @@ -1635,21 +1632,21 @@ class _TextOnlyValidator(ScorerPromptValidator): def __init__(self) -> None: super().__init__(supported_data_types=["text", "image_path"]) - def validate(self, message: Message, objective: Optional[str] = None) -> None: + def validate(self, message: Message, objective: str | None = None) -> None: pass class _BlockedContentScorer(TrueFalseScorer): """A mock TrueFalseScorer that records what pieces it was asked to score.""" - def __init__(self, *, validator: Optional[ScorerPromptValidator] = None) -> None: + def __init__(self, *, validator: ScorerPromptValidator | None = None) -> None: super().__init__(validator=validator or _TextOnlyValidator()) self.scored_pieces: list[MessagePiece] = [] def _build_identifier(self) -> ComponentIdentifier: return self._create_identifier() - async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Optional[str] = None) -> list[Score]: + async def _score_piece_async(self, message_piece: MessagePiece, *, objective: str | None = None) -> list[Score]: self.scored_pieces.append(message_piece) return [ Score( @@ -1676,7 +1673,7 @@ def __init__(self) -> None: def _build_identifier(self) -> ComponentIdentifier: return self._create_identifier() - async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Optional[str] = None) -> list[Score]: + async def _score_piece_async(self, message_piece: MessagePiece, *, objective: str | None = None) -> list[Score]: self.scored_pieces.append(message_piece) if message_piece.response_error == "blocked": return [ @@ -1707,7 +1704,7 @@ async def _score_piece_async(self, message_piece: 
MessagePiece, *, objective: Op ] -def _make_blocked_piece(*, partial_content: Optional[str] = None, conversation_id: str = "test-convo") -> MessagePiece: +def _make_blocked_piece(*, partial_content: str | None = None, conversation_id: str = "test-convo") -> MessagePiece: """Create a blocked MessagePiece, optionally with partial content metadata.""" metadata: dict = {} if partial_content is not None: diff --git a/tests/unit/score/test_true_false_composite_scorer.py b/tests/unit/score/test_true_false_composite_scorer.py index 968b1154e6..5f9458161e 100644 --- a/tests/unit/score/test_true_false_composite_scorer.py +++ b/tests/unit/score/test_true_false_composite_scorer.py @@ -1,7 +1,6 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -from typing import Optional from unittest.mock import MagicMock import pytest @@ -46,7 +45,7 @@ def _build_identifier(self) -> ComponentIdentifier: """ return self._create_identifier() - async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Optional[str] = None) -> list[Score]: + async def _score_piece_async(self, message_piece: MessagePiece, *, objective: str | None = None) -> list[Score]: return [ Score( score_value=str(self._score_value), @@ -154,9 +153,7 @@ def __init__(self): def _build_identifier(self) -> ComponentIdentifier: return self._create_identifier() - async def _score_piece_async( - self, message_piece: MessagePiece, *, objective: Optional[str] = None - ) -> list[Score]: + async def _score_piece_async(self, message_piece: MessagePiece, *, objective: str | None = None) -> list[Score]: return [] with pytest.raises(ValueError, match="All scorers must be true_false scorers"): diff --git a/tests/unit/score/test_video_scorer.py b/tests/unit/score/test_video_scorer.py index 9c55be1c7e..e60d62bc56 100644 --- a/tests/unit/score/test_video_scorer.py +++ b/tests/unit/score/test_video_scorer.py @@ -3,7 +3,6 @@ import os import uuid -from typing import Optional from unittest.mock import 
AsyncMock, MagicMock, patch import numpy as np @@ -74,7 +73,7 @@ def _build_identifier(self) -> ComponentIdentifier: """ return self._create_identifier() - async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Optional[str] = None) -> list[Score]: + async def _score_piece_async(self, message_piece: MessagePiece, *, objective: str | None = None) -> list[Score]: return [ Score( score_type="true_false", @@ -106,7 +105,7 @@ def _build_identifier(self) -> ComponentIdentifier: """ return self._create_identifier() - async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Optional[str] = None) -> list[Score]: + async def _score_piece_async(self, message_piece: MessagePiece, *, objective: str | None = None) -> list[Score]: return [ Score( score_type="float_scale", @@ -295,7 +294,7 @@ def __init__(self, *, return_value: bool = True): def _build_identifier(self) -> ComponentIdentifier: return self._create_identifier() - async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Optional[str] = None) -> list[Score]: + async def _score_piece_async(self, message_piece: MessagePiece, *, objective: str | None = None) -> list[Score]: self.received_objective = objective return [ Score(