Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
0d26cef
ENH: enforce _async suffix on async functions via pre-commit hook
romanlutz Jun 1, 2026
1ff47a9
MAINT: Apply _async suffix to pyrit/auth (PR 2 of sweep)
romanlutz Jun 2, 2026
560dddd
MAINT: Apply _async suffix to pyrit/backend (PR 3 of sweep)
romanlutz Jun 2, 2026
66495ab
MAINT: Apply _async suffix to pyrit/cli (PR 4 of sweep)
romanlutz Jun 2, 2026
15057ef
MAINT: Apply _async suffix exempt markers to pyrit/common shims (PR 5…
romanlutz Jun 2, 2026
99bce14
MAINT: Apply _async suffix to pyrit/datasets (PR 6 of sweep)
romanlutz Jun 2, 2026
aa5dada
STYLE: Rename async methods in pyrit/executor to use _async suffix
romanlutz Jun 2, 2026
a51f8da
STYLE: Rename async methods in pyrit/memory + pyrit/message_normalize…
romanlutz Jun 2, 2026
eba977e
REFACTOR: rename pyrit.models async methods to _async suffix (PR 9: m…
romanlutz Jun 2, 2026
ccdcc71
REFACTOR: mark pyrit.output.scorer deprecation shims async-suffix-exe…
romanlutz Jun 2, 2026
8c8c4ac
REFACTOR: rename pyrit.prompt_converter private async methods to _asy…
romanlutz Jun 2, 2026
16d22e0
REFACTOR: rename pyrit.prompt_normalizer async methods to _async suff…
romanlutz Jun 2, 2026
4f1267a
FIX: rename async methods in prompt_target to add _async suffix (PR 13)
romanlutz Jun 2, 2026
5ec9333
FIX: rename async methods in prompt_target/openai to add _async suffi…
romanlutz Jun 2, 2026
74c0fc0
FIX: rename worker closure in scenario to add _async suffix (PR 15)
romanlutz Jun 2, 2026
d051e22
FIX: rename async methods in score to add _async suffix (PR 16)
romanlutz Jun 2, 2026
2cbe41c
FIX: remove async-suffix transitional baseline (PR 17, final)
romanlutz Jun 2, 2026
d598f2e
FIX: exempt async_token_provider nested closure from _async suffix rule
romanlutz Jun 2, 2026
1f4e436
Merge branch 'main' into romanlutz/romanlutz-async-suffix-sweep
romanlutz Jun 2, 2026
b03b106
Fix stale call sites of renamed _async methods from main merge
romanlutz Jun 2, 2026
f39a8c1
Add deprecation shim coverage tests for renamed async methods
romanlutz Jun 2, 2026
33ae803
Address PR review feedback on async-suffix hook
romanlutz Jun 2, 2026
9e6fafc
Merge remote-tracking branch 'origin/main' into pr/1889/romanlutz/rom…
romanlutz Jun 2, 2026
b27fa45
Merge remote-tracking branch 'origin/main' into pr/1889/romanlutz/rom…
romanlutz Jun 2, 2026
59628a4
Merge branch 'main' into romanlutz/romanlutz-async-suffix-sweep
romanlutz Jun 2, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/instructions/datasets.instructions.md
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ Each `SeedPrompt` / `SeedObjective` must carry:

## Set class-level dataset metadata when known

`_parse_metadata` on `_RemoteDatasetLoader` reads class attributes matching `SeedDatasetMetadata` fields. Declare what you can know statically as class-level constants so dataset discovery/filtering works:
`_parse_metadata_async` on `_RemoteDatasetLoader` reads class attributes matching `SeedDatasetMetadata` fields. Declare what you can know statically as class-level constants so dataset discovery/filtering works:

```python
class _MyDataset(_RemoteDatasetLoader):
Expand Down
8 changes: 8 additions & 0 deletions .github/instructions/style-guide.instructions.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ async def _read_audio_async(self, path):
### Async Functions
- **MANDATORY**: All async functions and methods MUST end with `_async` suffix
- This applies to ALL async functions; the only exceptions are the explicit exemptions listed below
- Enforced by the `check-async-suffix` pre-commit hook (`build_scripts/check_async_suffix.py`)

```python
# CORRECT
Expand All @@ -54,6 +55,13 @@ async def send_prompt(self, prompt: str) -> Message: # Missing _async suffix
...
```

**Exemptions** are limited and explicit:
- Async dunders (`__aenter__`, `__aexit__`, `__aiter__`, `__anext__`) are exempt automatically.
- A small set of framework-mandated names (`lifespan`, `dispatch`, `__call__`) is exempt
automatically; see `_FRAMEWORK_EXEMPT_NAMES` in `build_scripts/check_async_suffix.py`.
- For one-off exemptions (e.g. an external SDK protocol method) add a
`# pyrit-async-suffix-exempt` trailing comment on the `async def` line.

### Private Methods
- Private methods MUST start with underscore
- This clearly indicates internal implementation details
Expand Down
6 changes: 6 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,12 @@ repos:
language: python
files: ^pyrit/memory/alembic/versions/.*\.py$
pass_filenames: false
- id: check-async-suffix
name: Enforce _async Suffix on async def
entry: python ./build_scripts/check_async_suffix.py
language: python
files: ^pyrit/.*\.py$
pass_filenames: false
- id: memory-migrations-check
name: Check Memory Migrations
entry: python ./build_scripts/memory_migrations.py check
Expand Down
137 changes: 137 additions & 0 deletions build_scripts/check_async_suffix.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

"""
Enforce ``.github/instructions/style-guide.instructions.md`` §1: every ``async def`` in
``pyrit/`` must end with the ``_async`` suffix.

Mechanism: walk every ``pyrit/**/*.py`` file with ``ast`` and flag every ``AsyncFunctionDef``
whose name does not end in ``_async`` and is not exempted via either:

1. **Hard-coded framework exemptions** (``_FRAMEWORK_EXEMPT_NAMES``) — names whose meaning
is dictated by an external framework or by the Python data model
(e.g. ``lifespan`` for FastAPI, ``dispatch`` for Starlette middleware, ``__call__``
on Protocol classes). The set is intentionally small; one-off exemptions
should use the per-line ``# pyrit-async-suffix-exempt`` marker instead.

2. **Per-line ``# pyrit-async-suffix-exempt`` marker** on any line of the ``async def``
header (the marker is scanned across the full signature, which the formatter may
split across multiple lines). Common reasons: a deprecation shim that intentionally
keeps the old non-``_async`` name for one release cycle; a one-off external-SDK or
protocol method name.
"""

from __future__ import annotations

import ast
import sys
from pathlib import Path

# Project layout — anchor everything off the repo root (directory containing pyrit/).
_REPO_ROOT = Path(__file__).resolve().parent.parent
_SCAN_ROOTS = ("pyrit",)

# Framework-mandated names: do NOT add to this set for one-off exemptions.
# Use a per-line ``# pyrit-async-suffix-exempt`` marker instead so each exemption is
# visible at the violation site.
_FRAMEWORK_EXEMPT_NAMES: frozenset[str] = frozenset(
    {
        "lifespan",  # FastAPI app lifespan context manager
        "dispatch",  # Starlette BaseHTTPMiddleware.dispatch override
        "__call__",  # Python dunder; Protocol classes commonly define async __call__
    }
)

_NOQA_MARKER = "# pyrit-async-suffix-exempt"


def _is_violation_name(name: str) -> bool:
    """Return True if ``name`` violates the async-suffix rule.

    A name is compliant (returns False) when any of the following holds:

    * it ends with ``_async``;
    * it is an async dunder, i.e. it starts with ``__a`` AND ends with ``__``
      (``__aenter__``, ``__aexit__``, ``__aiter__``, ``__anext__``);
    * it is listed in ``_FRAMEWORK_EXEMPT_NAMES``.

    Args:
        name: The identifier of an ``async def`` function or method.

    Returns:
        True when the name must be reported, False when it is compliant or exempt.
    """
    if name.endswith("_async"):
        return False
    # Require the trailing "__" as well as the "__a" prefix: a name-mangled private
    # coroutine such as ``__authenticate`` also starts with "__a" but is not a dunder
    # and must not slip past the check.
    if name.startswith("__a") and name.endswith("__"):
        return False
    return name not in _FRAMEWORK_EXEMPT_NAMES


def _line_has_noqa(source_lines: list[str], lineno: int) -> bool:
    """Report whether the 1-based line ``lineno`` contains the exempt marker.

    Line numbers outside ``1..len(source_lines)`` are treated as not carrying
    the marker rather than raising.
    """
    if 1 <= lineno <= len(source_lines):
        return _NOQA_MARKER in source_lines[lineno - 1]
    return False


def _header_has_noqa(source_lines: list[str], node: ast.AsyncFunctionDef) -> bool:
    """Report whether any line of the ``async def`` header carries the exempt marker.

    The scanned range runs from ``node.lineno`` up to the line immediately before
    the first body statement, so a marker survives the formatter splitting a long
    signature over several lines. When the body starts on the ``def`` line itself
    (or there is no body), only the ``def`` line is inspected.
    """
    first = node.lineno
    last = first
    if node.body:
        # Never let the range end before the ``def`` line (one-line functions).
        last = max(first, node.body[0].lineno - 1)
    for lineno in range(first, last + 1):
        if _line_has_noqa(source_lines, lineno):
            return True
    return False


def _scan_file(path: Path) -> list[tuple[str, int, str]]:
    """Return ``(relative_path, line, name)`` violations in ``path``.

    ``relative_path`` is forward-slash normalized relative to the repo root so that
    violations are reported portably between Windows and Linux checkouts.

    Args:
        path: Absolute path of a Python file located under ``_REPO_ROOT``.

    Returns:
        One tuple per offending ``async def``. If the file cannot be parsed, a
        single tuple whose name is ``"<SyntaxError: ...>"`` is returned instead so
        an unparseable file cannot silently slip past the check.
    """
    source = path.read_text(encoding="utf-8")
    rel = path.relative_to(_REPO_ROOT).as_posix()
    try:
        tree = ast.parse(source, filename=str(path))
    except SyntaxError as exc:
        # Surface the parse failure as a violation. Other hooks (e.g. ruff) should
        # flag the syntax error too, but we don't rely on their ordering.
        message = f"{exc.msg} (line {exc.lineno})" if exc.lineno is not None else exc.msg
        return [(rel, exc.lineno or 0, f"<SyntaxError: {message}>")]

    # Split on "\n" only. ``str.splitlines()`` also breaks on form feed, vertical
    # tab, U+2028 and similar characters that the Python tokenizer does NOT count
    # as line breaks, which would shift these indices relative to ``node.lineno``
    # and make the exempt-marker lookup inspect the wrong lines. ``read_text``
    # has already normalized "\r\n" and "\r" to "\n".
    source_lines = source.split("\n")
    return [
        (rel, node.lineno, node.name)
        for node in ast.walk(tree)
        if isinstance(node, ast.AsyncFunctionDef)
        and _is_violation_name(node.name)
        and not _header_has_noqa(source_lines, node)
    ]


def _scan_repo() -> list[tuple[str, int, str]]:
    """Collect the violations of every ``*.py`` file under each scanned root.

    Files are visited in sorted path order within each root so the report is
    deterministic from run to run.
    """
    return [
        violation
        for root in _SCAN_ROOTS
        for py_file in sorted((_REPO_ROOT / root).rglob("*.py"))
        for violation in _scan_file(py_file)
    ]


def main() -> int:
    """Run the check over the repository and print a report.

    Returns:
        0 when no violations were found, 1 otherwise (used as the process exit code).
    """
    found = _scan_repo()
    if not found:
        return 0

    report = [
        "[ERROR] Async functions are missing the `_async` suffix "
        "(see .github/instructions/style-guide.instructions.md §1):"
    ]
    for path, line, name in found:
        # Parse failures are carried through the same tuple shape, tagged by name.
        if name.startswith("<SyntaxError"):
            report.append(f" {path}:{line}: could not parse file: {name[1:-1]}")
        else:
            report.append(f" {path}:{line}: async def {name}(...)")
    report.append("")
    report.append("Rename each function to end in `_async`, or — if the name is dictated")
    report.append("by a framework — add `# pyrit-async-suffix-exempt` at the end of the `async def` line.")

    for text in report:
        print(text)
    return 1


if __name__ == "__main__":
    sys.exit(main())
2 changes: 1 addition & 1 deletion doc/code/datasets/4_dataset_coding.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@
"\n",
" async def fetch_dataset_async(self, *, cache: bool = True) -> SeedDataset:\n",
" # Fetch from HuggingFace\n",
" data = await self._fetch_from_huggingface(\n",
" data = await self._fetch_from_huggingface_async(\n",
" dataset_name=\"apart/darkbench\",\n",
" config=\"default\",\n",
" split=\"train\",\n",
Expand Down
2 changes: 1 addition & 1 deletion doc/code/datasets/4_dataset_coding.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def dataset_name(self) -> str:

async def fetch_dataset_async(self, *, cache: bool = True) -> SeedDataset:
# Fetch from HuggingFace
data = await self._fetch_from_huggingface(
data = await self._fetch_from_huggingface_async(
dataset_name="apart/darkbench",
config="default",
split="train",
Expand Down
2 changes: 1 addition & 1 deletion doc/code/targets/realtime_target.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@
"attack = PromptSendingAttack(objective_target=target)\n",
"result = await attack.execute_with_context_async(context=context) # type: ignore\n",
"await output_attack_async(result)\n",
"await target.cleanup_target() # type: ignore"
"await target.cleanup_target_async() # type: ignore"
]
},
{
Expand Down
2 changes: 1 addition & 1 deletion doc/code/targets/realtime_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@
attack = PromptSendingAttack(objective_target=target)
result = await attack.execute_with_context_async(context=context) # type: ignore
await output_attack_async(result)
await target.cleanup_target() # type: ignore
await target.cleanup_target_async() # type: ignore

# %% [markdown]
# ## Text Conversation
Expand Down
6 changes: 3 additions & 3 deletions pyrit/auth/azure_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def __init__(self, token_provider: Callable[[], Union[str, Awaitable[str]]]) ->
"""
self._token_provider = token_provider

async def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken:
async def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken: # pyrit-async-suffix-exempt
"""
Get an access token asynchronously.

Expand All @@ -104,7 +104,7 @@ async def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken:
expires_on = int(time.time()) + 3600
return AccessToken(str(token), expires_on)

async def close(self) -> None:
async def close(self) -> None: # pyrit-async-suffix-exempt
"""No-op close for protocol compliance. The callable provider does not hold resources."""

async def __aenter__(self) -> AsyncTokenProviderCredential:
Expand Down Expand Up @@ -149,7 +149,7 @@ def ensure_async_token_provider(
" Automatically wrapping in async function for compatibility with async client."
)

async def async_token_provider() -> str:
async def async_token_provider() -> str: # pyrit-async-suffix-exempt
"""
Async wrapper for synchronous token provider.

Expand Down
47 changes: 44 additions & 3 deletions pyrit/auth/azure_storage_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
)
from azure.storage.blob.aio import BlobServiceClient

from pyrit.common.deprecation import print_deprecation_message


class AzureStorageAuth:
"""
Expand All @@ -20,7 +22,7 @@ class AzureStorageAuth:
"""

@staticmethod
async def get_user_delegation_key(blob_service_client: BlobServiceClient) -> UserDelegationKey:
async def get_user_delegation_key_async(blob_service_client: BlobServiceClient) -> UserDelegationKey:
"""
Retrieve a user delegation key valid for one day.

Expand All @@ -39,7 +41,28 @@ async def get_user_delegation_key(blob_service_client: BlobServiceClient) -> Use
)

@staticmethod
async def get_sas_token(container_url: str) -> str:
async def get_user_delegation_key(
blob_service_client: BlobServiceClient,
) -> UserDelegationKey: # pyrit-async-suffix-exempt
"""
Retrieve a user delegation key (deprecated alias of ``get_user_delegation_key_async``).

Args:
blob_service_client (BlobServiceClient): An instance of BlobServiceClient to interact
with Azure Blob Storage.

Returns:
UserDelegationKey: A user delegation key valid for one day.
"""
print_deprecation_message(
old_item="AzureStorageAuth.get_user_delegation_key",
new_item="AzureStorageAuth.get_user_delegation_key_async",
removed_in="0.16.0",
)
return await AzureStorageAuth.get_user_delegation_key_async(blob_service_client)

@staticmethod
async def get_sas_token_async(container_url: str) -> str:
"""
Generate a SAS token for the specified blob using a user delegation key.

Expand Down Expand Up @@ -72,7 +95,7 @@ async def get_sas_token(container_url: str) -> str:

try:
async with BlobServiceClient(account_url=account_url, credential=credential) as blob_service_client:
user_delegation_key = await AzureStorageAuth.get_user_delegation_key(
user_delegation_key = await AzureStorageAuth.get_user_delegation_key_async(
blob_service_client=blob_service_client
)
container_name = parsed_url.path.lstrip("/")
Expand All @@ -94,3 +117,21 @@ async def get_sas_token(container_url: str) -> str:
await credential.close()

return sas_token

@staticmethod
async def get_sas_token(container_url: str) -> str: # pyrit-async-suffix-exempt
"""
Generate a SAS token (deprecated alias of ``get_sas_token_async``).

Args:
container_url (str): The URL of the Azure Blob Storage container.

Returns:
str: The generated SAS token.
"""
print_deprecation_message(
old_item="AzureStorageAuth.get_sas_token",
new_item="AzureStorageAuth.get_sas_token_async",
removed_in="0.16.0",
)
return await AzureStorageAuth.get_sas_token_async(container_url)
Loading
Loading