Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ extend-select = [
# Temporary ignores for pyrit/ subdirectories until issue #1176
# https://github.com/Azure/PyRIT/issues/1176 is fully resolved
# TODO: Remove these ignores once the issues are fixed
"pyrit/{auxiliary_attacks,exceptions,models,prompt_converter,prompt_target,ui}/**/*.py" = ["D101", "D102", "D103", "D104", "D105", "D106", "D107", "D401", "D404", "D417", "D418", "DOC102", "DOC201", "DOC202", "DOC402", "DOC501"]
"pyrit/{auxiliary_attacks,exceptions,models,prompt_converter,ui}/**/*.py" = ["D101", "D102", "D103", "D104", "D105", "D106", "D107", "D401", "D404", "D417", "D418", "DOC102", "DOC201", "DOC202", "DOC402", "DOC501"]
"pyrit/__init__.py" = ["D104"]

[tool.ruff.lint.pydocstyle]
Expand Down
7 changes: 7 additions & 0 deletions pyrit/prompt_target/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

"""
Prompt targets for PyRIT.

Target implementations for interacting with different services and APIs,
for example sending prompts or transferring content (uploads).
"""

from pyrit.prompt_target.common.prompt_target import PromptTarget
from pyrit.prompt_target.common.prompt_chat_target import PromptChatTarget
from pyrit.prompt_target.common.utils import limit_requests_per_minute
Expand Down
26 changes: 20 additions & 6 deletions pyrit/prompt_target/azure_blob_storage_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,18 @@ def __init__(
blob_content_type: SupportedContentType = SupportedContentType.PLAIN_TEXT,
max_requests_per_minute: Optional[int] = None,
) -> None:
"""
Initialize the Azure Blob Storage target.

Args:
container_url (str, Optional): The Azure Storage container URL.
Defaults to the AZURE_STORAGE_ACCOUNT_CONTAINER_URL environment variable.
sas_token (str, Optional): The SAS token for authentication.
Defaults to the AZURE_STORAGE_ACCOUNT_SAS_TOKEN environment variable.
blob_content_type (SupportedContentType): The content type for blobs.
Defaults to PLAIN_TEXT.
max_requests_per_minute (int, Optional): Maximum number of requests per minute.
"""
self._blob_content_type: str = blob_content_type.value

self._container_url: str = default_values.get_required_value(
Expand All @@ -69,7 +80,7 @@ def __init__(

async def _create_container_client_async(self) -> None:
"""
Creates an asynchronous ContainerClient for Azure Storage. If a SAS token is provided via the
Create an asynchronous ContainerClient for Azure Storage. If a SAS token is provided via the
AZURE_STORAGE_ACCOUNT_SAS_TOKEN environment variable or the init sas_token parameter, it will be used
for authentication. Otherwise, a delegation SAS token will be created using Entra ID authentication.
"""
Expand Down Expand Up @@ -126,8 +137,13 @@ async def _upload_blob_async(self, file_name: str, data: bytes, content_type: st
logger.exception(msg=f"An unexpected error occurred: {exc}")
raise

def _parse_url(self):
"""Parses the Azure Storage Blob URL to extract components."""
def _parse_url(self) -> tuple[str, str]:
"""
Parse the Azure Storage Blob URL to extract components.

Returns:
tuple: A tuple containing the container URL and blob prefix.
"""
parsed_url = urlparse(self._container_url)
path_parts = parsed_url.path.split("/")
container_name = path_parts[1]
Expand All @@ -142,9 +158,7 @@ async def send_prompt_async(self, *, message: Message) -> list[Message]:
to the provided storage container.

Args:
normalized_prompt (str): A normalized prompt to be sent to the prompt target.
conversation_id (str): The ID of the conversation.
normalizer_id (str): ID provided by the prompt normalizer.
message (Message): A Message to be sent to the target.

Returns:
list[Message]: A list containing the response with the Blob URL.
Expand Down
69 changes: 51 additions & 18 deletions pyrit/prompt_target/azure_ml_chat_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,16 @@


class AzureMLChatTarget(PromptChatTarget):
"""
A prompt target for Azure Machine Learning chat endpoints.

This class works with most chat completion Instruct models deployed on Azure AI Machine Learning
Studio endpoints (including but not limited to: mistralai-Mixtral-8x7B-Instruct-v01,
mistralai-Mistral-7B-Instruct-v01, Phi-3.5-MoE-instruct, Phi-3-mini-4k-instruct,
Llama-3.2-3B-Instruct, and Meta-Llama-3.1-8B-Instruct).

Please create or adjust environment variables (endpoint and key) as needed for the model you are using.
"""

endpoint_uri_environment_variable: str = "AZURE_ML_MANAGED_ENDPOINT"
api_key_environment_variable: str = "AZURE_ML_KEY"
Expand All @@ -44,12 +54,7 @@ def __init__(
**param_kwargs,
) -> None:
"""
Initializes an instance of the AzureMLChatTarget class. This class works with most chat completion
Instruct models deployed on Azure AI Machine Learning Studio endpoints
(including but not limited to: mistralai-Mixtral-8x7B-Instruct-v01, mistralai-Mistral-7B-Instruct-v01,
Phi-3.5-MoE-instruct, Phi-3-mini-4k-instruct, Llama-3.2-3B-Instruct, and Meta-Llama-3.1-8B-Instruct).
Please create or adjust environment variables (endpoint and key) as needed for the
model you are using.
Initialize an instance of the AzureMLChatTarget class.

Args:
endpoint (str, Optional): The endpoint URL for the deployed Azure ML model.
Expand Down Expand Up @@ -101,7 +106,7 @@ def _set_env_configuration_vars(
api_key_environment_variable: Optional[str] = None,
) -> None:
"""
Sets the environment configuration variable names from which to pull the endpoint uri and the api key
Set the environment configuration variable names from which to pull the endpoint uri and the api key
to access the deployed Azure ML model. Use this function to set the environment variable names to
however they are named in the .env file and pull the corresponding endpoint uri and api key.
This is the recommended way to pass in a uri and key to access the model endpoint.
Expand All @@ -110,17 +115,14 @@ def _set_env_configuration_vars(
Args:
endpoint_uri_environment_variable (str, optional): The environment variable name for the endpoint uri.
api_key_environment_variable (str, optional): The environment variable name for the api key.

Returns:
None
"""
self.endpoint_uri_environment_variable = endpoint_uri_environment_variable or "AZURE_ML_MANAGED_ENDPOINT"
self.api_key_environment_variable = api_key_environment_variable or "AZURE_ML_KEY"
self._initialize_vars()

def _initialize_vars(self, endpoint: Optional[str] = None, api_key: Optional[str] = None) -> None:
"""
Sets the endpoint and key for accessing the Azure ML model. Use this function to manually
Set the endpoint and key for accessing the Azure ML model. Use this function to manually
pass in your own endpoint uri and api key. Defaults to the values in the .env file for the variables
stored in self.endpoint_uri_environment_variable and self.api_key_environment_variable (which default to
"AZURE_ML_MANAGED_ENDPOINT" and "AZURE_ML_KEY" respectively). It is recommended to set these variables
Expand All @@ -130,9 +132,6 @@ def _initialize_vars(self, endpoint: Optional[str] = None, api_key: Optional[str
Args:
endpoint (str, optional): The endpoint uri for the deployed Azure ML model.
api_key (str, optional): The API key for accessing the Azure ML endpoint.

Returns:
None
"""
self._endpoint = default_values.get_required_value(
env_var_name=self.endpoint_uri_environment_variable, passed_value=endpoint
Expand All @@ -150,8 +149,15 @@ def _set_model_parameters(
**param_kwargs,
) -> None:
"""
Sets the model parameters for generating responses, offering the option to add additional ones not
Set the model parameters for generating responses, offering the option to add additional ones not
explicitly listed.

Args:
max_new_tokens: Maximum number of new tokens to generate.
temperature: Sampling temperature for response generation.
top_p: Nucleus sampling parameter.
repetition_penalty: Penalty for repeating tokens.
**param_kwargs: Additional model parameters.
"""
self._max_new_tokens = max_new_tokens or self._max_new_tokens
self._temperature = temperature or self._temperature
Expand All @@ -162,7 +168,20 @@ def _set_model_parameters(

@limit_requests_per_minute
async def send_prompt_async(self, *, message: Message) -> list[Message]:
"""
Asynchronously send a message to the Azure ML chat target.

Args:
message (Message): The message object containing the prompt to send.

Returns:
list[Message]: A list containing the response from the prompt target.

Raises:
EmptyResponseException: If the response from the chat is empty.
RateLimitException: If the target rate limit is exceeded.
HTTPStatusError: For any other HTTP errors during the process.
"""
self._validate_request(message=message)
request = message.message_pieces[0]

Expand Down Expand Up @@ -207,7 +226,8 @@ async def _complete_chat_async(
messages (list[ChatMessage]): The chat messages objects containing the role and content.

Raises:
Exception: For any errors during the process.
EmptyResponseException: If the response from the chat is empty.
Exception: For any other errors during the process.

Returns:
str: The generated response message.
Expand All @@ -233,7 +253,15 @@ def _construct_http_body(
self,
messages: list[ChatMessage],
) -> dict:
"""Constructs the HTTP request body for the AML online endpoint."""
"""
Construct the HTTP request body for the AML online endpoint.

Args:
messages: List of chat messages to include in the request body.

Returns:
dict: The constructed HTTP request body.
"""
squashed_messages = self.chat_message_normalizer.normalize(messages)
messages_dict = [message.model_dump() for message in squashed_messages]

Expand Down Expand Up @@ -281,5 +309,10 @@ def _validate_request(self, *, message: Message) -> None:
raise ValueError(f"This target only supports text prompt input. Received: {piece_type}.")

def is_json_response_supported(self) -> bool:
"""Indicates that this target supports JSON response format."""
"""
Check if the target supports JSON as a response format.

Returns:
bool: True if JSON response is supported, False otherwise.
"""
return False
16 changes: 12 additions & 4 deletions pyrit/prompt_target/batch_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,17 @@

def _get_chunks(*args, batch_size: int):
"""
Helper function utilized during prompt batching to chunk based off of size.
Split provided lists into chunks of specified batch size.

Args:
        *args: Arguments to chunk; each argument should be a list.
batch_size (int): Batch size

    Yields:
        list: Successive chunks of the provided list(s), each containing at most batch_size items.

Raises:
ValueError: When no arguments are provided or when arguments have different lengths.
"""
if len(args) == 0:
raise ValueError("No arguments provided to chunk.")
Expand All @@ -27,8 +32,7 @@ def _get_chunks(*args, batch_size: int):

def _validate_rate_limit_parameters(prompt_target: Optional[PromptTarget], batch_size: int):
"""
Helper function to validate the constraints between Rate Limit (Requests Per Minute)
and batch size.
Validate the constraints between Rate Limit (Requests Per Minute) and batch size.

Args:
prompt_target (PromptTarget): Target to validate
Expand All @@ -52,7 +56,7 @@ async def batch_task_async(
**task_kwargs,
):
"""
Performs provided task in batches and validates parameters using helpers.
Perform provided task in batches and validate parameters using helpers.

Args:
prompt_target(PromptTarget): Target to validate
Expand All @@ -64,6 +68,10 @@ async def batch_task_async(

Returns:
responses(list): List of results from the batched function

Raises:
ValueError: When no items to batch are provided.
ValueError: When number of lists of items to batch does not match number of task arguments.
"""
responses = []

Expand Down
15 changes: 13 additions & 2 deletions pyrit/prompt_target/common/prompt_chat_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,14 @@ def __init__(
endpoint: str = "",
model_name: str = "",
) -> None:
"""
Initialize the PromptChatTarget.

Args:
max_requests_per_minute (int, Optional): Maximum number of requests per minute.
endpoint (str): The endpoint URL. Defaults to empty string.
model_name (str): The model name. Defaults to empty string.
"""
super().__init__(max_requests_per_minute=max_requests_per_minute, endpoint=endpoint, model_name=model_name)

def set_system_prompt(
Expand All @@ -37,7 +45,10 @@ def set_system_prompt(
labels: Optional[dict[str, str]] = None,
) -> None:
"""
Sets the system prompt for the prompt target. May be overridden by subclasses.
Set the system prompt for the prompt target. May be overridden by subclasses.

Raises:
RuntimeError: If the conversation already exists.
"""
messages = self._memory.get_conversation(conversation_id=conversation_id)

Expand Down Expand Up @@ -68,7 +79,7 @@ def is_json_response_supported(self) -> bool:

def is_response_format_json(self, message_piece: MessagePiece) -> bool:
"""
Checks if the response format is JSON and ensures the target supports it.
Check if the response format is JSON and ensure the target supports it.

Args:
message_piece: A MessagePiece object with a `prompt_metadata` dictionary that may
Expand Down
35 changes: 29 additions & 6 deletions pyrit/prompt_target/common/prompt_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,17 @@


class PromptTarget(abc.ABC, Identifier):
_memory: MemoryInterface

"""
A list of PromptConverters that are supported by the prompt target.
An empty list implies that the prompt target supports all converters.
Abstract base class for prompt targets.

A prompt target is a destination where prompts can be sent to interact with various services,
models, or APIs. This class defines the interface that all prompt targets must implement.
"""

_memory: MemoryInterface

#: A list of PromptConverters that are supported by the prompt target.
#: An empty list implies that the prompt target supports all converters.
supported_converters: list

def __init__(
Expand All @@ -27,6 +32,15 @@ def __init__(
endpoint: str = "",
model_name: str = "",
) -> None:
"""
Initialize the PromptTarget.

Args:
verbose (bool): Enable verbose logging. Defaults to False.
max_requests_per_minute (int, Optional): Maximum number of requests per minute.
endpoint (str): The endpoint URL. Defaults to empty string.
model_name (str): The model name. Defaults to empty string.
"""
self._memory = CentralMemory.get_memory_instance()
self._verbose = verbose
self._max_requests_per_minute = max_requests_per_minute
Expand All @@ -39,7 +53,7 @@ def __init__(
@abc.abstractmethod
async def send_prompt_async(self, *, message: Message) -> list[Message]:
"""
Sends a normalized prompt async to the prompt target.
Send a normalized prompt async to the prompt target.

Returns:
list[Message]: A list of message responses. Most targets return a single message,
Expand All @@ -49,7 +63,10 @@ async def send_prompt_async(self, *, message: Message) -> list[Message]:
@abc.abstractmethod
def _validate_request(self, *, message: Message) -> None:
"""
Validates the provided message.
Validate the provided message.

Args:
message: The message to validate.
"""

def set_model_name(self, *, model_name: str) -> None:
Expand All @@ -68,6 +85,12 @@ def dispose_db_engine(self) -> None:
self._memory.dispose_engine()

def get_identifier(self) -> dict:
"""
Get the identifier dictionary for the prompt target.

Returns:
dict: Dictionary containing the target's type, module, endpoint, and model name.
"""
public_attributes = {}
public_attributes["__type__"] = self.__class__.__name__
public_attributes["__module__"] = self.__class__.__module__
Expand Down
2 changes: 1 addition & 1 deletion pyrit/prompt_target/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def validate_top_p(top_p: Optional[float]) -> None:

def limit_requests_per_minute(func: Callable) -> Callable:
"""
A decorator to enforce rate limit of the target through setting requests per minute.
Enforce rate limit of the target through setting requests per minute.
This should be applied to all send_prompt_async() functions on PromptTarget and PromptChatTarget.

Args:
Expand Down
Loading