diff --git a/.github/instructions/style-guide.instructions.md b/.github/instructions/style-guide.instructions.md index 8ddb2e609..f5f245f9f 100644 --- a/.github/instructions/style-guide.instructions.md +++ b/.github/instructions/style-guide.instructions.md @@ -95,29 +95,65 @@ def process(self, data: str) -> str: ... ``` -## Documentation Standards +## Imports + +### Placement and Organization + +Top of file, grouped: stdlib → third-party → local. + +### Deferred Imports for Performance + +Imports may be placed inside functions/methods when they pull in expensive +third-party packages (`transformers`, `azure.storage.blob`, `alembic`, `openai`, +`scipy`, `pandas`, `av`). Two cases: + +1. **CLI entry points** — defer heavy imports to after arg parsing so `--help` is instant. +2. **Internal modules** — when a method is the only consumer of a heavy package. + +```python +def main() -> int: + parsed_args = parse_args() + from pyrit.cli import frontend_core # deferred: heavy + ... + +async def _create_container_client_async(self): + from azure.storage.blob.aio import ContainerClient # deferred: heavy + ... +``` + +Guard tests in `tests/unit/cli/test_import_guards.py` enforce that key import +paths stay fast. + +### Lazy `__init__.py` Exports (PEP 562) + +Public API packages (`pyrit.prompt_target`, `pyrit.prompt_converter`, `pyrit.score`) +use `__getattr__`-based lazy loading so heavy symbols can be imported from the +package without paying the cost at package load time. See +`pyrit/prompt_target/__init__.py` for the canonical example. Rules: -### Import Placement -- **MANDATORY**: All import statements MUST be at the top of the file -- Do NOT use inline/local imports inside functions or methods -- The only exception is breaking circular import dependencies, which should be rare and documented +- Lazy names must remain in `__all__` and have a `TYPE_CHECKING` import for IDE support. +- Internal utility packages (e.g., `pyrit.common`) simply omit heavy submodules + from `__init__.py` — consumers import directly from the specific file. + +### Import Paths + +Import from the package root when the symbol is exported from `__init__.py`: ```python -# CORRECT — imports at the top of the file -from contextlib import closing -from sqlalchemy.exc import SQLAlchemyError - -def update_entry(self, entry: Base) -> None: - with closing(self.get_session()) as session: - ... - -# INCORRECT — inline import inside a function -def update_entry(self, entry: Base) -> None: - from contextlib import closing # ← WRONG, must be at top of file - with closing(self.get_session()) as session: - ... +from pyrit.prompt_target import PromptChatTarget # CORRECT +from pyrit.prompt_target.common.prompt_chat_target import PromptChatTarget # WRONG ``` +Heavy submodules not re-exported from `__init__.py` are imported directly: + +```python +from pyrit.common.net_utility import get_httpx_client +``` + +Within the same package, import from the specific file to avoid circular imports. + +## Documentation Standards + ### Docstring Format - Use Google-style docstrings - Include type information in parameter descriptions @@ -211,63 +247,6 @@ async def execute_attack_async(self, *, context: AttackContext) -> AttackResult: 5. Private methods (internal implementation) 6. Static methods and class methods at the end -### Import Organization -```python -# Standard library imports -import asyncio -import json -import logging -from dataclasses import dataclass -from enum import Enum -from pathlib import Path -from typing import Any - -# Third-party imports -import numpy as np -from tqdm import tqdm - -# Local application imports -from pyrit.attacks.base import AttackStrategy -from pyrit.models import AttackResult -from pyrit.prompt_target import PromptTarget -``` - -Unless necessary, always import at the top of the file. Don't import inside a function or method. - - -### Import paths - -Often, pyrit has specific files that can be imported. However IF you are importing from a different module than your namespace, -import from the root pyrit module if it's exposed from init. - -In the same module, importing from the specific path is usually necessary to prevent circular imports. - -- Always check __init__.py exports first - Before using a specific file path, verify if the class/function is exposed at a higher level -- Group related imports - Put all imports from the same root module together -- Use multi-line formatting for readability - When importing 3+ items from the same module, use parentheses - - -```python -# Correct -from pyrit.prompt_target import PromptChatTarget, OpenAIChatTarget - -# Correct -from pyrit.score import ( - AzureContentFilterScorer, - FloatScaleThresholdScorer, - SelfAskRefusalScorer, - TrueFalseCompositeScorer, - TrueFalseInverterScorer, - TrueFalseScoreAggregator, - TrueFalseScorer, -) - -# Incorrect (if importing from a non-target module) -from pyrit.prompt_target.common.prompt_chat_target import PromptChatTarget -from pyrit.prompt_target.openai.openai_chat_target import OpenAIChatTarget - -``` - ## Error Handling ### Specific Exceptions @@ -438,6 +417,16 @@ def process_large_dataset(self, *, file_path: Path) -> list[Result]: return [self._process_line(line) for line in lines] ``` +### Lazy Imports for Startup Performance +- When adding a new module that imports heavy third-party packages (e.g., `transformers`, + `scipy`, `PIL`, `datasets`, `av`), consider whether it is re-exported from a package + `__init__.py` that is on the CLI startup path +- If so, add it to the `_LAZY_IMPORTS` dict in that `__init__.py` instead of as an + eager top-level import (see the Import Placement section for the pattern) +- This is especially important for `pyrit/common/__init__.py`, `pyrit/prompt_target/__init__.py`, + `pyrit/prompt_converter/__init__.py`, and `pyrit/score/__init__.py` which are all on the + import path for CLI startup + ## Final Checklist Before committing code, ensure: @@ -449,6 +438,7 @@ Before committing code, ensure: - [ ] Functions are focused and under 20 lines - [ ] Error messages are helpful and specific - [ ] Code follows the import organization pattern +- [ ] New modules with heavy deps follow `__init__.py` startup guidance - [ ] No hard-coded dependencies - [ ] Complex logic is extracted to helper methods diff --git a/pyrit/cli/pyrit_scan.py b/pyrit/cli/pyrit_scan.py index 9a8ca771c..fb6cc75e3 100644 --- a/pyrit/cli/pyrit_scan.py +++ b/pyrit/cli/pyrit_scan.py @@ -14,7 +14,13 @@ from pathlib import Path from typing import Optional -from pyrit.cli import frontend_core +from pyrit.cli._cli_args import ( + ARG_HELP, + _parse_initializer_arg, + non_negative_int, + positive_int, + validate_log_level_argparse, +) def parse_args(args: Optional[list[str]] = None) -> Namespace: @@ -53,12 +59,12 @@ def parse_args(args: Optional[list[str]] = None) -> Namespace: parser.add_argument( "--config-file", type=Path, - help=frontend_core.ARG_HELP["config_file"], + help=ARG_HELP["config_file"], ) parser.add_argument( "--log-level", - type=frontend_core.validate_log_level_argparse, + type=validate_log_level_argparse, default=logging.WARNING, help="Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL) (default: WARNING)", ) @@ -91,16 +97,16 @@ def parse_args(args: Optional[list[str]] = None) -> Namespace: parser.add_argument( "--initializers", - type=frontend_core._parse_initializer_arg, + type=_parse_initializer_arg, nargs="+", - help=frontend_core.ARG_HELP["initializers"], + help=ARG_HELP["initializers"], ) parser.add_argument( "--initialization-scripts", type=str, nargs="+", - help=frontend_core.ARG_HELP["initialization_scripts"], + help=ARG_HELP["initialization_scripts"], ) parser.add_argument( @@ -109,44 +115,44 @@ def parse_args(args: Optional[list[str]] = None) -> Namespace: type=str, nargs="+", dest="scenario_strategies", - help=frontend_core.ARG_HELP["scenario_strategies"], + help=ARG_HELP["scenario_strategies"], ) parser.add_argument( "--max-concurrency", - type=frontend_core.positive_int, - help=frontend_core.ARG_HELP["max_concurrency"], + type=positive_int, + help=ARG_HELP["max_concurrency"], ) parser.add_argument( "--max-retries", - type=frontend_core.non_negative_int, - help=frontend_core.ARG_HELP["max_retries"], + type=non_negative_int, + help=ARG_HELP["max_retries"], ) parser.add_argument( "--memory-labels", type=str, - help=frontend_core.ARG_HELP["memory_labels"], + help=ARG_HELP["memory_labels"], ) parser.add_argument( "--dataset-names", type=str, nargs="+", - help=frontend_core.ARG_HELP["dataset_names"], + help=ARG_HELP["dataset_names"], ) parser.add_argument( "--max-dataset-size", - type=frontend_core.positive_int, - help=frontend_core.ARG_HELP["max_dataset_size"], + type=positive_int, + help=ARG_HELP["max_dataset_size"], ) parser.add_argument( "--target", type=str, - help=frontend_core.ARG_HELP["target"], + help=ARG_HELP["target"], ) return parser.parse_args(args) @@ -159,14 +165,17 @@ def main(args: Optional[list[str]] = None) -> int: Returns: int: Exit code (0 for success, 1 for error). """ - print("Starting PyRIT...") - sys.stdout.flush() - try: parsed_args = parse_args(args) except SystemExit as e: return e.code if isinstance(e.code, int) else 1 + print("Starting PyRIT...") + sys.stdout.flush() + + # Defer the heavy import until after arg parsing so --help is instant. + from pyrit.cli import frontend_core + # Handle list commands (don't need full context) if parsed_args.list_scenarios: # Simple context just for listing diff --git a/pyrit/common/__init__.py b/pyrit/common/__init__.py index 15dceeb8a..afd50c6e0 100644 --- a/pyrit/common/__init__.py +++ b/pyrit/common/__init__.py @@ -1,7 +1,15 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -"""Common utilities and helpers for PyRIT.""" +""" +Common utilities and helpers for PyRIT. + +Heavy submodules (data_url_converter, display_response, download_hf_model, +net_utility) are intentionally NOT re-exported here to keep ``import pyrit`` +fast. Import them directly, e.g.:: + + from pyrit.common.net_utility import get_httpx_client +""" from pyrit.common.apply_defaults import ( REQUIRED_VALUE, @@ -12,18 +20,8 @@ reset_default_values, set_default_value, ) -from pyrit.common.data_url_converter import convert_local_image_to_data_url from pyrit.common.default_values import get_non_required_value, get_required_value from pyrit.common.deprecation import print_deprecation_message -from pyrit.common.display_response import display_image_response -from pyrit.common.download_hf_model import ( - download_chunk, - download_file, - download_files, - download_specific_files, - get_available_files, -) -from pyrit.common.net_utility import get_httpx_client, make_request_and_raise_if_error_async from pyrit.common.notebook_utils import is_in_ipython_session from pyrit.common.singleton import Singleton from pyrit.common.utils import ( @@ -41,28 +39,19 @@ "apply_defaults_to_method", "combine_dict", "combine_list", - "convert_local_image_to_data_url", "DefaultValueScope", - "display_image_response", - "download_chunk", - "download_file", - "download_files", - "download_specific_files", - "get_available_files", "get_global_default_values", - "get_httpx_client", "get_kwarg_param", "get_non_required_value", "get_random_indices", "get_required_value", - "verify_and_resolve_path", "is_in_ipython_session", - "make_request_and_raise_if_error_async", + "print_deprecation_message", "REQUIRED_VALUE", "reset_default_values", "set_default_value", "Singleton", + "verify_and_resolve_path", "warn_if_set", "YamlLoadable", - "print_deprecation_message", ] diff --git a/pyrit/memory/memory_interface.py b/pyrit/memory/memory_interface.py index e3d93b918..e7ffb9382 100644 --- a/pyrit/memory/memory_interface.py +++ b/pyrit/memory/memory_interface.py @@ -18,13 +18,12 @@ from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.orm.attributes import InstrumentedAttribute +if TYPE_CHECKING: + from pyrit.memory.memory_embedding import MemoryEmbedding + from pyrit.common.deprecation import print_deprecation_message from pyrit.common.path import DB_DATA_PATH from pyrit.identifiers.identifier_filters import IdentifierFilter, IdentifierType -from pyrit.memory.memory_embedding import ( - MemoryEmbedding, - default_memory_embedding_factory, -) from pyrit.memory.memory_exporter import MemoryExporter from pyrit.memory.memory_models import ( AttackResultEntry, @@ -35,7 +34,6 @@ ScoreEntry, SeedEntry, ) -from pyrit.memory.migration import check_schema_migrations, reset_database, run_schema_migrations from pyrit.models import ( AttackResult, ConversationStats, @@ -84,7 +82,7 @@ class MemoryInterface(abc.ABC): # for backends with higher limits (e.g., Azure SQL supports 2100). _MAX_BIND_VARS: int = 500 - memory_embedding: MemoryEmbedding | None = None + memory_embedding: "MemoryEmbedding | None" = None results_storage_io: StorageIO | None = None results_path: str | None = None engine: Engine | None = None @@ -123,6 +121,8 @@ def enable_embedding(self, embedding_model: Optional[Any] = None) -> None: ValueError: If no embedding model is provided and required environment variables are not set. """ + from pyrit.memory.memory_embedding import default_memory_embedding_factory + self.memory_embedding = default_memory_embedding_factory(embedding_model=embedding_model) def disable_embedding(self) -> None: @@ -1153,6 +1153,8 @@ def _run_schema_migration(self) -> None: RuntimeError: If the engine is not initialized when required. Exception: If there is an error during schema validation or migration. """ + from pyrit.memory.migration import check_schema_migrations, run_schema_migrations + logger.info("Running schema migration.") if self.engine is None: raise RuntimeError("Engine must be initialized to run schema migrations.") @@ -1166,6 +1168,8 @@ def reset_database(self) -> None: Raises: RuntimeError: If the engine is not initialized. """ + from pyrit.memory.migration import reset_database + if self.engine is None: raise RuntimeError("Engine is not initialized") reset_database(engine=self.engine) diff --git a/pyrit/message_normalizer/chat_message_normalizer.py b/pyrit/message_normalizer/chat_message_normalizer.py index 2fa3bfc0a..e090c6ade 100644 --- a/pyrit/message_normalizer/chat_message_normalizer.py +++ b/pyrit/message_normalizer/chat_message_normalizer.py @@ -6,7 +6,7 @@ import os from typing import TYPE_CHECKING, Any, Union -from pyrit.common import convert_local_image_to_data_url +from pyrit.common.data_url_converter import convert_local_image_to_data_url from pyrit.message_normalizer.message_normalizer import ( MessageListNormalizer, MessageStringNormalizer, diff --git a/pyrit/models/storage_io.py b/pyrit/models/storage_io.py index 08142e857..4d2ed293b 100644 --- a/pyrit/models/storage_io.py +++ b/pyrit/models/storage_io.py @@ -1,20 +1,20 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. +from __future__ import annotations + import logging import os from abc import ABC, abstractmethod from enum import Enum from pathlib import Path -from typing import Optional, Union +from typing import TYPE_CHECKING, Optional, Union from urllib.parse import urlparse import aiofiles -from azure.core.exceptions import ClientAuthenticationError, ResourceNotFoundError -from azure.storage.blob import ContentSettings -from azure.storage.blob.aio import ContainerClient as AsyncContainerClient -from pyrit.auth import AzureStorageAuth +if TYPE_CHECKING: + from azure.storage.blob.aio import ContainerClient as AsyncContainerClient logger = logging.getLogger(__name__) @@ -195,6 +195,10 @@ async def _create_container_client_async(self) -> AsyncContainerClient: Returns: AsyncContainerClient: The initialized container client. """ + from azure.storage.blob.aio import ContainerClient as AsyncContainerClient + + from pyrit.auth import AzureStorageAuth + sas_token = self._sas_token if not self._sas_token: logger.info("SAS token not provided. Creating a delegation SAS token using Entra ID authentication.") @@ -218,6 +222,9 @@ async def _upload_blob_async(self, file_name: str, data: bytes, content_type: st Raises: RuntimeError: If the Azure container client is not initialized. """ + from azure.core.exceptions import ClientAuthenticationError + from azure.storage.blob import ContentSettings + content_settings = ContentSettings(content_type=f"{content_type}") logger.info(msg="\nUploading to Azure Storage as blob:\n\t" + file_name) @@ -364,6 +371,8 @@ async def path_exists(self, path: Union[Path, str]) -> bool: Returns: bool: True when the path exists. """ + from azure.core.exceptions import ResourceNotFoundError + if not self._client_async: self._client_async = await self._create_container_client_async() try: @@ -387,6 +396,8 @@ async def is_file(self, path: Union[Path, str]) -> bool: Returns: bool: True when the blob exists and has non-zero content size. """ + from azure.core.exceptions import ResourceNotFoundError + if not self._client_async: self._client_async = await self._create_container_client_async() try: diff --git a/pyrit/prompt_converter/__init__.py b/pyrit/prompt_converter/__init__.py index 74db8fe7d..0d6afb6f8 100644 --- a/pyrit/prompt_converter/__init__.py +++ b/pyrit/prompt_converter/__init__.py @@ -11,6 +11,9 @@ transformation pipelines for testing AI system robustness. """ +import importlib +from typing import TYPE_CHECKING + from pyrit.prompt_converter.add_image_text_converter import AddImageTextConverter from pyrit.prompt_converter.add_image_to_video_converter import AddImageVideoConverter from pyrit.prompt_converter.add_text_image_converter import AddTextImageConverter @@ -18,11 +21,6 @@ from pyrit.prompt_converter.ascii_art_converter import AsciiArtConverter from pyrit.prompt_converter.ask_to_decode_converter import AskToDecodeConverter from pyrit.prompt_converter.atbash_converter import AtbashConverter -from pyrit.prompt_converter.audio_echo_converter import AudioEchoConverter -from pyrit.prompt_converter.audio_frequency_converter import AudioFrequencyConverter -from pyrit.prompt_converter.audio_speed_converter import AudioSpeedConverter -from pyrit.prompt_converter.audio_volume_converter import AudioVolumeConverter -from pyrit.prompt_converter.audio_white_noise_converter import AudioWhiteNoiseConverter from pyrit.prompt_converter.azure_speech_audio_to_text_converter import AzureSpeechAudioToTextConverter from pyrit.prompt_converter.azure_speech_text_to_audio_converter import AzureSpeechTextToAudioConverter from pyrit.prompt_converter.base64_converter import Base64Converter @@ -72,7 +70,6 @@ from pyrit.prompt_converter.superscript_converter import SuperscriptConverter from pyrit.prompt_converter.template_segment_converter import TemplateSegmentConverter from pyrit.prompt_converter.tense_converter import TenseConverter -from pyrit.prompt_converter.text_jailbreak_converter import TextJailbreakConverter from pyrit.prompt_converter.text_selection_strategy import ( AllWordsSelectionStrategy, IndexSelectionStrategy, @@ -108,6 +105,36 @@ from pyrit.prompt_converter.zalgo_converter import ZalgoConverter from pyrit.prompt_converter.zero_width_converter import ZeroWidthConverter +if TYPE_CHECKING: + from pyrit.prompt_converter.audio_echo_converter import AudioEchoConverter + from pyrit.prompt_converter.audio_frequency_converter import AudioFrequencyConverter + from pyrit.prompt_converter.audio_speed_converter import AudioSpeedConverter + from pyrit.prompt_converter.audio_volume_converter import AudioVolumeConverter + from pyrit.prompt_converter.audio_white_noise_converter import AudioWhiteNoiseConverter + from pyrit.prompt_converter.text_jailbreak_converter import TextJailbreakConverter + +# Lazy imports for modules with heavy third-party dependencies (PEP 562). +# Audio converters import `scipy` which adds ~1.3s to startup. +# TextJailbreakConverter imports `pyrit.datasets` which triggers `datasets` → `pandas` (~1.6s). +_LAZY_IMPORTS: dict[str, str] = { + "AudioEchoConverter": "pyrit.prompt_converter.audio_echo_converter", + "AudioFrequencyConverter": "pyrit.prompt_converter.audio_frequency_converter", + "AudioSpeedConverter": "pyrit.prompt_converter.audio_speed_converter", + "AudioVolumeConverter": "pyrit.prompt_converter.audio_volume_converter", + "AudioWhiteNoiseConverter": "pyrit.prompt_converter.audio_white_noise_converter", + "TextJailbreakConverter": "pyrit.prompt_converter.text_jailbreak_converter", +} + + +def __getattr__(name: str) -> object: + if name in _LAZY_IMPORTS: + module = importlib.import_module(_LAZY_IMPORTS[name]) + attr = getattr(module, name) + globals()[name] = attr + return attr + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + + __all__ = [ "AddImageTextConverter", "AddImageVideoConverter", diff --git a/pyrit/prompt_target/__init__.py b/pyrit/prompt_target/__init__.py index c71dca408..de658298a 100644 --- a/pyrit/prompt_target/__init__.py +++ b/pyrit/prompt_target/__init__.py @@ -8,6 +8,9 @@ for example sending prompts or transferring content (uploads). """ +import importlib +from typing import TYPE_CHECKING + from pyrit.prompt_target.azure_blob_storage_target import AzureBlobStorageTarget from pyrit.prompt_target.azure_ml_chat_target import AzureMLChatTarget from pyrit.prompt_target.common.conversation_normalization_pipeline import ConversationNormalizationPipeline @@ -29,7 +32,6 @@ get_http_target_regex_matching_callback_function, ) from pyrit.prompt_target.http_target.httpx_api_target import HTTPXAPITarget -from pyrit.prompt_target.hugging_face.hugging_face_chat_target import HuggingFaceChatTarget from pyrit.prompt_target.hugging_face.hugging_face_endpoint_target import HuggingFaceEndpointTarget from pyrit.prompt_target.openai.openai_chat_audio_config import OpenAIChatAudioConfig from pyrit.prompt_target.openai.openai_chat_target import OpenAIChatTarget @@ -46,6 +48,25 @@ from pyrit.prompt_target.text_target import TextTarget from pyrit.prompt_target.websocket_copilot_target import WebSocketCopilotTarget +if TYPE_CHECKING: + from pyrit.prompt_target.hugging_face.hugging_face_chat_target import HuggingFaceChatTarget + +# Lazy imports for modules with heavy third-party dependencies (PEP 562). +# HuggingFaceChatTarget imports `transformers` which adds ~4s to startup. +_LAZY_IMPORTS: dict[str, str] = { + "HuggingFaceChatTarget": "pyrit.prompt_target.hugging_face.hugging_face_chat_target", +} + + +def __getattr__(name: str) -> object: + if name in _LAZY_IMPORTS: + module = importlib.import_module(_LAZY_IMPORTS[name]) + attr = getattr(module, name) + globals()[name] = attr + return attr + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + + __all__ = [ "AzureBlobStorageTarget", "AzureMLChatTarget", diff --git a/pyrit/prompt_target/openai/openai_chat_target.py b/pyrit/prompt_target/openai/openai_chat_target.py index e61dcbb7c..d60cf856d 100644 --- a/pyrit/prompt_target/openai/openai_chat_target.py +++ b/pyrit/prompt_target/openai/openai_chat_target.py @@ -8,7 +8,7 @@ from dataclasses import replace from typing import Any, Optional -from pyrit.common import convert_local_image_to_data_url +from pyrit.common.data_url_converter import convert_local_image_to_data_url from pyrit.exceptions import ( EmptyResponseException, PyritException, diff --git a/pyrit/prompt_target/openai/openai_response_target.py b/pyrit/prompt_target/openai/openai_response_target.py index 1666ccc2c..dbe71e540 100644 --- a/pyrit/prompt_target/openai/openai_response_target.py +++ b/pyrit/prompt_target/openai/openai_response_target.py @@ -14,7 +14,7 @@ from openai.types.shared import ReasoningEffort -from pyrit.common import convert_local_image_to_data_url +from pyrit.common.data_url_converter import convert_local_image_to_data_url from pyrit.exceptions import ( EmptyResponseException, PyritException, diff --git a/pyrit/prompt_target/websocket_copilot_target.py b/pyrit/prompt_target/websocket_copilot_target.py index 57565fc14..174b8e645 100644 --- a/pyrit/prompt_target/websocket_copilot_target.py +++ b/pyrit/prompt_target/websocket_copilot_target.py @@ -14,7 +14,7 @@ from websockets.exceptions import InvalidStatus from pyrit.auth import CopilotAuthenticator, ManualCopilotAuthenticator -from pyrit.common import convert_local_image_to_data_url +from pyrit.common.data_url_converter import convert_local_image_to_data_url from pyrit.exceptions import ( EmptyResponseException, pyrit_target_retry, diff --git a/pyrit/score/__init__.py b/pyrit/score/__init__.py index 24582e479..5aa0e9ac2 100644 --- a/pyrit/score/__init__.py +++ b/pyrit/score/__init__.py @@ -6,9 +6,11 @@ including harm detection, objective completion, and content classification. """ +import importlib +from typing import TYPE_CHECKING + from pyrit.score.batch_scorer import BatchScorer from pyrit.score.conversation_scorer import ConversationScorer, create_conversation_scorer -from pyrit.score.float_scale.audio_float_scale_scorer import AudioFloatScaleScorer from pyrit.score.float_scale.azure_content_filter_scorer import AzureContentFilterScorer from pyrit.score.float_scale.float_scale_score_aggregator import ( FloatScaleScoreAggregator, @@ -21,22 +23,9 @@ from pyrit.score.float_scale.self_ask_general_float_scale_scorer import SelfAskGeneralFloatScaleScorer from pyrit.score.float_scale.self_ask_likert_scorer import LikertScaleEvalFiles, LikertScalePaths, SelfAskLikertScorer from pyrit.score.float_scale.self_ask_scale_scorer import SelfAskScaleScorer -from pyrit.score.float_scale.video_float_scale_scorer import VideoFloatScaleScorer from pyrit.score.printer import ConsoleScorerPrinter, ScorerPrinter from pyrit.score.scorer import Scorer -from pyrit.score.scorer_evaluation.human_labeled_dataset import ( - HarmHumanLabeledEntry, - HumanLabeledDataset, - HumanLabeledEntry, - ObjectiveHumanLabeledEntry, -) from pyrit.score.scorer_evaluation.metrics_type import MetricsType, RegistryUpdateBehavior -from pyrit.score.scorer_evaluation.scorer_evaluator import ( - HarmScorerEvaluator, - ObjectiveScorerEvaluator, - ScorerEvalDatasetFiles, - ScorerEvaluator, -) from pyrit.score.scorer_evaluation.scorer_metrics import ( HarmScorerMetrics, ObjectiveScorerMetrics, @@ -49,7 +38,6 @@ get_all_objective_metrics, ) from pyrit.score.scorer_prompt_validator import ScorerPromptValidator -from pyrit.score.true_false.audio_true_false_scorer import AudioTrueFalseScorer from pyrit.score.true_false.decoding_scorer import DecodingScorer from pyrit.score.true_false.float_scale_threshold_scorer import FloatScaleThresholdScorer from pyrit.score.true_false.gandalf_scorer import GandalfScorer @@ -70,7 +58,52 @@ from pyrit.score.true_false.true_false_inverter_scorer import TrueFalseInverterScorer from pyrit.score.true_false.true_false_score_aggregator import TrueFalseAggregatorFunc, TrueFalseScoreAggregator from pyrit.score.true_false.true_false_scorer import TrueFalseScorer -from pyrit.score.true_false.video_true_false_scorer import VideoTrueFalseScorer + +if TYPE_CHECKING: + from pyrit.score.float_scale.audio_float_scale_scorer import AudioFloatScaleScorer + from pyrit.score.float_scale.video_float_scale_scorer import VideoFloatScaleScorer + from pyrit.score.scorer_evaluation.human_labeled_dataset import ( + HarmHumanLabeledEntry, + HumanLabeledDataset, + HumanLabeledEntry, + ObjectiveHumanLabeledEntry, + ) + from pyrit.score.scorer_evaluation.scorer_evaluator import ( + HarmScorerEvaluator, + ObjectiveScorerEvaluator, + ScorerEvalDatasetFiles, + ScorerEvaluator, + ) + from pyrit.score.true_false.audio_true_false_scorer import AudioTrueFalseScorer + from pyrit.score.true_false.video_true_false_scorer import VideoTrueFalseScorer + +# Lazy imports for modules with heavy third-party dependencies (PEP 562). +# Audio/video scorers import `av` (~1.9s), human_labeled_dataset imports `pandas` (~1.6s), +# scorer_evaluator imports `scipy.stats` (~1s). +_LAZY_IMPORTS: dict[str, str] = { + "AudioFloatScaleScorer": "pyrit.score.float_scale.audio_float_scale_scorer", + "AudioTrueFalseScorer": "pyrit.score.true_false.audio_true_false_scorer", + "VideoFloatScaleScorer": "pyrit.score.float_scale.video_float_scale_scorer", + "VideoTrueFalseScorer": "pyrit.score.true_false.video_true_false_scorer", + "HarmHumanLabeledEntry": "pyrit.score.scorer_evaluation.human_labeled_dataset", + "HumanLabeledDataset": "pyrit.score.scorer_evaluation.human_labeled_dataset", + "HumanLabeledEntry": "pyrit.score.scorer_evaluation.human_labeled_dataset", + "ObjectiveHumanLabeledEntry": "pyrit.score.scorer_evaluation.human_labeled_dataset", + "HarmScorerEvaluator": "pyrit.score.scorer_evaluation.scorer_evaluator", + "ObjectiveScorerEvaluator": "pyrit.score.scorer_evaluation.scorer_evaluator", + "ScorerEvalDatasetFiles": "pyrit.score.scorer_evaluation.scorer_evaluator", + "ScorerEvaluator": "pyrit.score.scorer_evaluation.scorer_evaluator", +} + + +def __getattr__(name: str) -> object: + if name in _LAZY_IMPORTS: + module = importlib.import_module(_LAZY_IMPORTS[name]) + attr = getattr(module, name) + globals()[name] = attr + return attr + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + __all__ = [ "AudioFloatScaleScorer", diff --git a/tests/unit/cli/test_import_guards.py b/tests/unit/cli/test_import_guards.py new file mode 100644 index 000000000..3b87b2d44 --- /dev/null +++ b/tests/unit/cli/test_import_guards.py @@ -0,0 +1,128 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Import guard tests to prevent performance regressions. + +These tests verify that importing key entry points does NOT pull in heavy +third-party modules. Each test spawns a fresh subprocess (since sys.modules +is global and sticky within a process) and checks which modules are loaded. + +If a test fails, it means someone added a top-level import that pulls in an +expensive dependency. The fix is to defer that import to point-of-use (inside +the method that actually needs it). +""" + +import subprocess +import sys + +import pytest + + +def _check_forbidden_imports(*, import_statement: str, forbidden: list[str]) -> list[str]: + """ + Run `import_statement` in a subprocess and return any forbidden modules that got loaded. + """ + code = ( + "import sys\n" + f"{import_statement}\n" + "forbidden = " + repr(forbidden) + "\n" + "loaded = [m for m in forbidden if any(k == m or k.startswith(m + '.') for k in sys.modules)]\n" + "if loaded:\n" + " print(','.join(sorted(loaded)))\n" + ) + result = subprocess.run( + [sys.executable, "-c", code], + capture_output=True, + text=True, + timeout=30, + ) + if result.returncode != 0: + pytest.fail(f"Subprocess crashed: {result.stderr}") + output = result.stdout.strip() + return output.split(",") if output else [] + + +# Heavy modules that should never be loaded during CLI arg parsing. +# This ensures `pyrit_scan --help` stays near-instant (~0.3s). +_CLI_FORBIDDEN = [ + "alembic", + "av", + "azure.storage.blob", + "httpx", + "numpy", + "openai", + "pandas", + "pydantic", + "scipy", + "sqlalchemy", + "torch", + "transformers", +] + +# Heavy modules that should not be loaded by `import pyrit` alone. +_IMPORT_PYRIT_FORBIDDEN = [ + "alembic", + "av", + "azure.storage.blob", + "openai", + "pandas", + "scipy", + "sqlalchemy", + "torch", + "transformers", +] + +# Heavy modules that should not be loaded by importing the PromptTarget base class. +_PROMPT_TARGET_FORBIDDEN = [ + "av", + "pandas", + "scipy", + "torch", + "transformers", +] + + +class TestImportGuards: + """Verify heavy modules are not eagerly loaded at key import points.""" + + def test_cli_arg_parsing_does_not_load_heavy_modules(self): + """ + Importing pyrit_scan's module-level symbols (for --help) must not + pull in any heavy third-party dependencies. + """ + loaded = _check_forbidden_imports( + import_statement="from pyrit.cli.pyrit_scan import parse_args", + forbidden=_CLI_FORBIDDEN, + ) + assert not loaded, ( + f"CLI arg parsing loaded heavy modules: {loaded}. " + f"Move these imports to point-of-use (inside a function/method)." + ) + + def test_import_pyrit_does_not_load_heavy_modules(self): + """ + `import pyrit` must stay fast and not pull in database or ML libraries. + """ + loaded = _check_forbidden_imports( + import_statement="import pyrit", + forbidden=_IMPORT_PYRIT_FORBIDDEN, + ) + assert not loaded, ( + f"`import pyrit` loaded heavy modules: {loaded}. " + f"Check pyrit/__init__.py and ensure heavy submodules are not eagerly imported." + ) + + def test_prompt_target_base_does_not_load_ml_modules(self): + """ + Importing PromptTarget must not pull in ML frameworks like transformers or av. + These are only needed by specific subclasses (HuggingFaceChatTarget, etc.). + """ + loaded = _check_forbidden_imports( + import_statement="from pyrit.prompt_target import PromptTarget", + forbidden=_PROMPT_TARGET_FORBIDDEN, + ) + assert not loaded, ( + f"PromptTarget base class loaded ML modules: {loaded}. " + f"Ensure heavy subclass imports use __getattr__ lazy loading in __init__.py." + ) diff --git a/tests/unit/common/test_convert_local_image_to_data_url.py b/tests/unit/common/test_convert_local_image_to_data_url.py index fc67dc771..26e0d9632 100644 --- a/tests/unit/common/test_convert_local_image_to_data_url.py +++ b/tests/unit/common/test_convert_local_image_to_data_url.py @@ -7,7 +7,7 @@ import pytest -from pyrit.common import convert_local_image_to_data_url +from pyrit.common.data_url_converter import convert_local_image_to_data_url from pyrit.memory.sqlite_memory import SQLiteMemory diff --git a/tests/unit/models/test_storage_io.py b/tests/unit/models/test_storage_io.py index 512b51d00..65e50228a 100644 --- a/tests/unit/models/test_storage_io.py +++ b/tests/unit/models/test_storage_io.py @@ -180,9 +180,9 @@ async def test_azure_blob_storage_io_create_container_client_uses_explicit_sas_t mock_container_client = AsyncMock() with ( - patch("pyrit.models.storage_io.AzureStorageAuth.get_sas_token", new_callable=AsyncMock) as mock_get_sas_token, + patch("pyrit.auth.AzureStorageAuth.get_sas_token", new_callable=AsyncMock) as mock_get_sas_token, patch( - "pyrit.models.storage_io.AsyncContainerClient.from_container_url", return_value=mock_container_client + "azure.storage.blob.aio.ContainerClient.from_container_url", return_value=mock_container_client ) as mock_from_container_url, ): await azure_blob_storage_io._create_container_client_async()