1 change: 1 addition & 0 deletions .github/instructions/style-guide.instructions.md
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ def validate_input(self, data: dict) -> None: # Should be private
- `str | None` not `Optional[str]`
- `int | float` not `Union[int, float]`
- Still import `Any`, `Literal`, `TypeVar`, `Protocol`, `cast` etc. from `typing` as needed
- **This rule applies to docstrings and comments too.** Argument type references inside docstrings (e.g. `Args:` blocks) and any comment mentioning a type should use the modern form so the docs stay consistent with the signatures.

```python
# CORRECT
Expand Down
2 changes: 1 addition & 1 deletion doc/blog/2024_12_3.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ It turns out, yes, we can. `CrescendoOrchestrator`, `PairOrchestrator`, `RedTeam
- `max_turns` defines the maximum number of conversation turns.
- `prompt_converters` are used to modify prompts before sending them to the target.
- `objective_scorer` evaluates whether the objective was achieved.
- `run_attack_async(objective: str, memory_labels: Optional[dict[str, str]] = None)` executes the attack and always returns a `OrchestratorResult`, which includes information about the conversation and the outcome.
- `run_attack_async(objective: str, memory_labels: dict[str, str] | None = None)` executes the attack and always returns a `OrchestratorResult`, which includes information about the conversation and the outcome.
- `run_attacks_async` enables parallelized attacks.
- `print_conversation_async` is now standardized and prints the "best" conversation (when multiple exist).

Expand Down
2 changes: 1 addition & 1 deletion doc/code/memory/0_memory.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ At the beginning of each notebook, make sure to call:
# Import the specific constant for the MemoryDatabaseType, or provide the literal value
from pyrit.setup import initialize_pyrit_async, IN_MEMORY, SQLITE, AZURE_SQL

await initialize_pyrit_async(memory_db_type: MemoryDatabaseType, memory_instance_kwargs: Optional[Any])
await initialize_pyrit_async(memory_db_type: MemoryDatabaseType, memory_instance_kwargs: Any | None)
```

The `MemoryDatabaseType` is a `Literal` with 3 options: IN_MEMORY, SQLITE, AZURE_SQL. (Read more below)
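
For readers unfamiliar with `Literal` aliases, the following is a minimal sketch of how a three-option type like this could be declared and checked at runtime. The string values and the `check_db_type` helper are assumptions made for illustration only; the actual definitions in `pyrit.setup` may differ.

```python
from typing import Literal, get_args

# Hypothetical stand-ins for the constants exported by pyrit.setup;
# the real string values may differ.
IN_MEMORY = "InMemory"
SQLITE = "SQLite"
AZURE_SQL = "AzureSQL"

# A Literal alias restricts a parameter to a fixed set of values.
MemoryDatabaseType = Literal["InMemory", "SQLite", "AzureSQL"]


def check_db_type(value: str) -> str:
    """Return value if it is one of the allowed literals, else raise ValueError."""
    allowed = get_args(MemoryDatabaseType)
    if value not in allowed:
        raise ValueError(f"memory_db_type must be one of {allowed}, got {value!r}")
    return value
```

A type checker rejects any other string at the call site, while `get_args` allows the same check at runtime.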
Expand Down
4 changes: 2 additions & 2 deletions doc/code/registry/0_registry.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ PyRIT has two registry patterns for different use cases:

| Type | Stores | Use Case |
|------|--------|----------|
| **Class Registry** | Classes (Type[T]) | Components instantiated with user-provided parameters |
| **Class Registry** | Classes (type[T]) | Components instantiated with user-provided parameters |
| **Instance Registry** | Pre-configured instances | Components requiring complex setup before use |

## Common API (RegistryProtocol)
Expand Down Expand Up @@ -44,7 +44,7 @@ def show_registry_contents(registry: RegistryProtocol) -> None:

| Aspect | Class Registry | Instance Registry |
|--------|----------------|-------------------|
| Stores | Classes (Type[T]) | Instances (T) |
| Stores | Classes (type[T]) | Instances (T) |
| Registration | Automatic discovery | Explicit via `register()` |
| Returns | Class to instantiate | Ready-to-use instance |
| Instantiation | Caller provides parameters | Pre-configured by initializer |
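
The distinction in the table can be sketched as follows. This is an illustrative toy, not PyRIT's registry code: `ClassRegistry`, `InstanceRegistry`, `add_class`, and `Greeter` are hypothetical names, and the manual `add_class` call stands in for the automatic discovery the real class registry performs.

```python
from typing import Generic, TypeVar

T = TypeVar("T")


class ClassRegistry(Generic[T]):
    """Hypothetical class registry: stores type[T]; the caller instantiates."""

    def __init__(self) -> None:
        self._classes: dict[str, type[T]] = {}

    def add_class(self, name: str, cls: type[T]) -> None:
        # Stand-in for automatic discovery.
        self._classes[name] = cls

    def get_class(self, name: str) -> type[T]:
        return self._classes[name]


class InstanceRegistry(Generic[T]):
    """Hypothetical instance registry: stores ready-to-use T objects."""

    def __init__(self) -> None:
        self._instances: dict[str, T] = {}

    def register(self, name: str, instance: T) -> None:
        self._instances[name] = instance

    def get(self, name: str) -> T:
        return self._instances[name]


class Greeter:
    def __init__(self, greeting: str) -> None:
        self.greeting = greeting


classes: ClassRegistry[Greeter] = ClassRegistry()
classes.add_class("greeter", Greeter)
# The caller provides the constructor parameters.
greeter = classes.get_class("greeter")(greeting="hello")

instances: InstanceRegistry[Greeter] = InstanceRegistry()
# The instance is configured once, before registration.
instances.register("default", Greeter(greeting="hi"))
ready = instances.get("default")
```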
Expand Down
2 changes: 1 addition & 1 deletion doc/code/setup/default_values.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ from pyrit.common.apply_defaults import apply_defaults

class MyConverter(PromptConverter):
@apply_defaults
def __init__(self, *, converter_target: Optional[PromptTarget] = None, temperature: Optional[float] = None):
def __init__(self, *, converter_target: PromptTarget | None = None, temperature: float | None = None):
self.converter_target = converter_target
self.temperature = temperature
```
Expand Down
4 changes: 2 additions & 2 deletions doc/code/targets/11_message_normalizer.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
"## Base Classes\n",
"\n",
"There are two base normalizer types:\n",
"- **`MessageListNormalizer[T]`**: Converts `List[Message]` → `List[T]` (e.g., to `ChatMessage` objects)\n",
"- **`MessageStringNormalizer`**: Converts `List[Message]` → `str` (e.g., to ChatML format)\n",
"- **`MessageListNormalizer[T]`**: Converts `list[Message]` → `list[T]` (e.g., to `ChatMessage` objects)\n",
"- **`MessageStringNormalizer`**: Converts `list[Message]` → `str` (e.g., to ChatML format)\n",
"\n",
"Some normalizers implement both interfaces."
]
Expand Down
4 changes: 2 additions & 2 deletions doc/code/targets/11_message_normalizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@
# ## Base Classes
#
# There are two base normalizer types:
# - **`MessageListNormalizer[T]`**: Converts `List[Message]` → `List[T]` (e.g., to `ChatMessage` objects)
# - **`MessageStringNormalizer`**: Converts `List[Message]` → `str` (e.g., to ChatML format)
# - **`MessageListNormalizer[T]`**: Converts `list[Message]` → `list[T]` (e.g., to `ChatMessage` objects)
# - **`MessageStringNormalizer`**: Converts `list[Message]` → `str` (e.g., to ChatML format)
#
# Some normalizers implement both interfaces.
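
As a rough illustration of the two base types described in this hunk, the sketch below defines simplified, synchronous versions and one class that implements both. The `Message` dataclass, the method names `normalize` and `normalize_string`, and `SimpleNormalizer` are assumptions for illustration; the real PyRIT classes may use different names and signatures.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Generic, TypeVar

T = TypeVar("T")


@dataclass
class Message:
    """Hypothetical stand-in for PyRIT's Message model."""

    role: str
    content: str


class MessageListNormalizer(ABC, Generic[T]):
    """Converts list[Message] into list[T]."""

    @abstractmethod
    def normalize(self, messages: list[Message]) -> list[T]: ...


class MessageStringNormalizer(ABC):
    """Converts list[Message] into a single string."""

    @abstractmethod
    def normalize_string(self, messages: list[Message]) -> str: ...


class SimpleNormalizer(MessageListNormalizer[dict[str, str]], MessageStringNormalizer):
    """Implements both interfaces: dict output and a ChatML-style string."""

    def normalize(self, messages: list[Message]) -> list[dict[str, str]]:
        return [{"role": m.role, "content": m.content} for m in messages]

    def normalize_string(self, messages: list[Message]) -> str:
        return "\n".join(f"<|im_start|>{m.role}\n{m.content}<|im_end|>" for m in messages)
```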

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -128,14 +128,12 @@
"metadata": {},
"outputs": [],
"source": [
"from typing import Union\n",
"\n",
"from azure.ai.ml import MLClient\n",
"from azure.core.exceptions import ResourceNotFoundError\n",
"from azure.identity import DefaultAzureCredential, InteractiveBrowserCredential\n",
"\n",
"try:\n",
" credential: Union[DefaultAzureCredential, InteractiveBrowserCredential] = DefaultAzureCredential()\n",
" credential: DefaultAzureCredential | InteractiveBrowserCredential = DefaultAzureCredential()\n",
" credential.get_token(\"https://management.azure.com/.default\")\n",
"except Exception as ex:\n",
" credential = InteractiveBrowserCredential()\n",
Expand Down
3 changes: 1 addition & 2 deletions doc/getting_started/troubleshooting/deploy_hf_model_aml.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,14 +106,13 @@
# Set up the `DefaultAzureCredential` for seamless authentication with Azure services. This method should handle most authentication scenarios. If you encounter issues, refer to the [Azure Identity documentation](https://docs.microsoft.com/en-us/python/api/azure-identity/azure.identity?view=azure-python) for alternative credentials.
#
# %%
from typing import Union

from azure.ai.ml import MLClient
from azure.core.exceptions import ResourceNotFoundError
from azure.identity import DefaultAzureCredential, InteractiveBrowserCredential

try:
credential: Union[DefaultAzureCredential, InteractiveBrowserCredential] = DefaultAzureCredential()
credential: DefaultAzureCredential | InteractiveBrowserCredential = DefaultAzureCredential()
credential.get_token("https://management.azure.com/.default")
except Exception as ex:
credential = InteractiveBrowserCredential()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,6 @@
"source": [
"# Import the Azure ML SDK components required for workspace connection and model management.\n",
"import os\n",
"from typing import Union\n",
"\n",
"# Import necessary libraries for Azure ML operations and authentication\n",
"from azure.ai.ml import MLClient, UserIdentityConfiguration\n",
Expand Down Expand Up @@ -201,7 +200,7 @@
"source": [
"# Setup Azure credentials, preferring DefaultAzureCredential and falling back to InteractiveBrowserCredential if necessary\n",
"try:\n",
" credential: Union[DefaultAzureCredential, InteractiveBrowserCredential] = DefaultAzureCredential()\n",
" credential: DefaultAzureCredential | InteractiveBrowserCredential = DefaultAzureCredential()\n",
" # Verify if the default credential can fetch a token successfully\n",
" credential.get_token(\"https://management.azure.com/.default\")\n",
"except Exception as ex:\n",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@
# %%
# Import the Azure ML SDK components required for workspace connection and model management.
import os
from typing import Union

# Import necessary libraries for Azure ML operations and authentication
from azure.ai.ml import MLClient, UserIdentityConfiguration
Expand Down Expand Up @@ -160,7 +159,7 @@
# %%
# Setup Azure credentials, preferring DefaultAzureCredential and falling back to InteractiveBrowserCredential if necessary
try:
credential: Union[DefaultAzureCredential, InteractiveBrowserCredential] = DefaultAzureCredential()
credential: DefaultAzureCredential | InteractiveBrowserCredential = DefaultAzureCredential()
# Verify if the default credential can fetch a token successfully
credential.get_token("https://management.azure.com/.default")
except Exception as ex:
Expand Down
2 changes: 0 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -341,8 +341,6 @@ ignore = [
"DOC502", # Raised exception is not explicitly raised
"PERF203", # try-except-in-loop (intentional per-item error handling)
"SIM117", # multiple-with-statements (combining often exceeds line length)
"UP007", # non-pep604-annotation-union (keep Union[X, Y] syntax)
"UP045", # non-pep604-annotation-optional (keep Optional[X] syntax)
]
extend-select = [
"D204", # 1 blank line required after class docstring
Expand Down
4 changes: 2 additions & 2 deletions pyrit/analytics/conversation_analytics.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,11 @@ def get_similar_chat_messages_by_embedding(
Retrieve chat messages that are similar to the given embedding based on cosine similarity.

Args:
chat_message_embedding (List[float]): The embedding of the chat message to find similar messages for.
chat_message_embedding (list[float]): The embedding of the chat message to find similar messages for.
threshold (float): The similarity threshold for considering messages as similar. Defaults to 0.8.

Returns:
List[ConversationMessageWithSimilarity]: A list of ConversationMessageWithSimilarity objects representing
list[ConversationMessageWithSimilarity]: A list of ConversationMessageWithSimilarity objects representing
the similar chat messages based on embedding similarity.
"""
all_embdedding_memory = self.memory_interface.get_all_embeddings()
Expand Down
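
The docstring in this hunk describes selecting messages whose embeddings meet a cosine-similarity threshold. A minimal standard-library sketch of that comparison follows. The `cosine_similarity` and `filter_similar` helpers are hypothetical and are not the code in `conversation_analytics.py`; only the default threshold of 0.8 comes from the docstring.

```python
import math


def cosine_similarity(a: list[float], b: list[float]) -> float:
    """Cosine of the angle between two equal-length vectors."""
    if len(a) != len(b):
        raise ValueError("Embeddings must have the same length.")
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        # Treat a zero vector as similar to nothing.
        return 0.0
    return dot / (norm_a * norm_b)


def filter_similar(
    query: list[float],
    candidates: dict[str, list[float]],
    threshold: float = 0.8,
) -> list[tuple[str, float]]:
    """Return (name, score) pairs at or above threshold, best match first."""
    scored = [(name, cosine_similarity(query, emb)) for name, emb in candidates.items()]
    kept = [(name, score) for name, score in scored if score >= threshold]
    return sorted(kept, key=lambda pair: pair[1], reverse=True)
```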
12 changes: 6 additions & 6 deletions pyrit/analytics/result_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from collections import defaultdict
from collections.abc import Sequence
from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING

from pyrit.models import (
AttackOutcome,
Expand All @@ -22,7 +22,7 @@
class AttackStats:
"""Statistics for attack analysis results."""

success_rate: Optional[float]
success_rate: float | None
total_decided: int
successes: int
failures: int
Expand Down Expand Up @@ -118,7 +118,7 @@ def get_cached_results_for_technique(
*,
technique_eval_hash: str,
objective_target_eval_hash: str,
additional_filters: Optional[Sequence[IdentifierFilter]] = None,
additional_filters: Sequence[IdentifierFilter] | None = None,
) -> list[AttackResult]:
"""
Return cached AttackResults matching a (technique × objective target) pair.
Expand All @@ -144,7 +144,7 @@ def get_cached_results_for_technique(
(also exposed as ``AtomicAttack.technique_eval_hash``).
objective_target_eval_hash (str): Behavioral eval hash of the objective
target, as produced by ``ObjectiveTargetEvaluationIdentifier.eval_hash``.
additional_filters (Optional[Sequence[IdentifierFilter]]): Extra
additional_filters (Sequence[IdentifierFilter] | None): Extra
``IdentifierFilter`` predicates appended to the SQL pre-filter.
Defaults to None.

Expand All @@ -170,7 +170,7 @@ def get_cached_results_for_technique(
return matches


def _objective_target_eval_hash_for(attack_result: AttackResult) -> Optional[str]:
def _objective_target_eval_hash_for(attack_result: AttackResult) -> str | None:
"""
Return the ObjectiveTargetEvaluationIdentifier eval hash for a result.

Expand All @@ -182,7 +182,7 @@ def _objective_target_eval_hash_for(attack_result: AttackResult) -> Optional[str
``atomic_attack_identifier`` tree should be inspected.

Returns:
Optional[str]: The ``ObjectiveTargetEvaluationIdentifier.eval_hash``
str | None: The ``ObjectiveTargetEvaluationIdentifier.eval_hash``
computed from the persisted objective-target identifier, or
``None`` when the identifier tree is missing expected nodes
(e.g. legacy rows or atomic attacks without a distinct objective
Expand Down
16 changes: 8 additions & 8 deletions pyrit/auth/azure_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import inspect
import logging
import time
from typing import TYPE_CHECKING, Any, Union, cast
from typing import TYPE_CHECKING, Any, cast

import msal
from azure.core.credentials import AccessToken
Expand Down Expand Up @@ -41,7 +41,7 @@ class TokenProviderCredential:
get_azure_token_provider) and Azure SDK clients that require a TokenCredential object.
"""

def __init__(self, token_provider: Callable[[], Union[str, Callable[..., Any]]]) -> None:
def __init__(self, token_provider: Callable[[], str | Callable[..., Any]]) -> None:
"""
Initialize TokenProviderCredential.

Expand Down Expand Up @@ -75,7 +75,7 @@ class AsyncTokenProviderCredential:
async clients that require an AsyncTokenCredential object (with async def get_token).
"""

def __init__(self, token_provider: Callable[[], Union[str, Awaitable[str]]]) -> None:
def __init__(self, token_provider: Callable[[], str | Awaitable[str]]) -> None:
"""
Initialize AsyncTokenProviderCredential.

Expand Down Expand Up @@ -394,14 +394,14 @@ def get_azure_openai_auth(endpoint: str) -> Callable[[], Awaitable[str]]:
return get_azure_async_token_provider(scope)


def get_speech_config(resource_id: Union[str, None], key: Union[str, None], region: str) -> speechsdk.SpeechConfig:
def get_speech_config(resource_id: str | None, key: str | None, region: str) -> speechsdk.SpeechConfig:
"""
Get the speech config using key/region pair (for key auth scenarios) or resource_id/region pair
(for Entra auth scenarios).

Args:
resource_id (Union[str, None]): The resource ID to get the token for.
key (Union[str, None]): The Azure Speech key
resource_id (str | None): The resource ID to get the token for.
key (str | None): The Azure Speech key
region (str): The region to get the token for.

Returns:
Expand Down Expand Up @@ -437,8 +437,8 @@ def get_speech_config(resource_id: Union[str, None], key: Union[str, None], regi
async def get_speech_config_async(
*,
token_provider: Callable[[], str | Awaitable[str]] | None,
resource_id: Union[str, None],
key: Union[str, None],
resource_id: str | None,
key: str | None,
region: str,
) -> speechsdk.SpeechConfig:
"""
Expand Down
22 changes: 11 additions & 11 deletions pyrit/auth/copilot_authenticator.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import os
import sys
from datetime import datetime, timedelta, timezone
from typing import Any, Optional
from typing import Any

from msal_extensions import FilePersistence, build_encrypted_persistence

Expand Down Expand Up @@ -196,7 +196,7 @@ def _create_persistent_cache(cache_file: str, fallback_to_plaintext: bool = Fals
logger.error(f"Encryption unavailable ({e}) and fallback_to_plaintext is False. Cannot proceed.")
raise

async def _get_cached_token_if_available_and_valid_async(self) -> Optional[dict[str, Any]]:
async def _get_cached_token_if_available_and_valid_async(self) -> dict[str, Any] | None:
"""
Retrieve and validate cached token.

Expand Down Expand Up @@ -258,7 +258,7 @@ async def _get_cached_token_if_available_and_valid_async(self) -> Optional[dict[
logger.error(f"Failed to load cached token ({error_name}): {e}")
return None

def _save_token_to_cache(self, *, token: str, expires_in: Optional[int] = None) -> None:
def _save_token_to_cache(self, *, token: str, expires_in: int | None = None) -> None:
"""
Save token to persistent cache with metadata.

Expand Down Expand Up @@ -301,12 +301,12 @@ def _clear_token_cache(self) -> None:
except Exception as e:
logger.error(f"Failed to clear cache: {e}")

async def _fetch_access_token_with_playwright_async(self) -> Optional[str]:
async def _fetch_access_token_with_playwright_async(self) -> str | None:
"""
Fetch access token using Playwright browser automation.

Returns:
Optional[str]: The bearer token if successfully retrieved, None otherwise.
str | None: The bearer token if successfully retrieved, None otherwise.

Raises:
RuntimeError: If Playwright is not installed or browser launch fails.
Expand Down Expand Up @@ -339,35 +339,35 @@ async def _fetch_access_token_with_playwright_async(self) -> Optional[str]:
# If not on Windows or using the right loop already, proceed normally
return await self._run_playwright_browser_automation_async()

async def _run_playwright_in_thread_async(self) -> Optional[str]:
async def _run_playwright_in_thread_async(self) -> str | None:
"""
Run Playwright browser automation in a separate thread with ProactorEventLoop.
This is needed on Windows when the main loop is SelectorEventLoop (e.g., in Jupyter).

Returns:
Optional[str]: The bearer token if successfully retrieved, None otherwise.
str | None: The bearer token if successfully retrieved, None otherwise.
"""

def run_in_new_loop() -> Optional[str]:
def run_in_new_loop() -> str | None:
if sys.platform == "win32":
new_loop = asyncio.ProactorEventLoop()
else:
new_loop = asyncio.new_event_loop()
asyncio.set_event_loop(new_loop)
try:
result: Optional[str] = new_loop.run_until_complete(self._run_playwright_browser_automation_async())
result: str | None = new_loop.run_until_complete(self._run_playwright_browser_automation_async())
return result
finally:
new_loop.close()

return await asyncio.get_running_loop().run_in_executor(None, run_in_new_loop)

async def _run_playwright_browser_automation_async(self) -> Optional[str]:
async def _run_playwright_browser_automation_async(self) -> str | None:
"""
Execute the actual Playwright browser automation to fetch the access token.

Returns:
Optional[str]: The bearer token if successfully retrieved, None otherwise.
str | None: The bearer token if successfully retrieved, None otherwise.

Raises:
ValueError: If the username is not set.
Expand Down
6 changes: 3 additions & 3 deletions pyrit/auth/manual_copilot_authenticator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import logging
import os
from typing import Any, Optional
from typing import Any

import jwt

Expand Down Expand Up @@ -36,12 +36,12 @@ class ManualCopilotAuthenticator(Authenticator):
#: Environment variable name for the Copilot access token
ACCESS_TOKEN_ENV_VAR: str = "COPILOT_ACCESS_TOKEN"

def __init__(self, *, access_token: Optional[str] = None) -> None:
def __init__(self, *, access_token: str | None = None) -> None:
"""
Initialize the ManualCopilotAuthenticator with a pre-obtained access token.

Args:
access_token (Optional[str]): A valid JWT access token for Microsoft Copilot.
access_token (str | None): A valid JWT access token for Microsoft Copilot.
This token can be obtained from browser DevTools when connected to Copilot.
If None, the token will be read from the ``COPILOT_ACCESS_TOKEN`` environment variable.

Expand Down