diff --git a/.env_example b/.env_example index 3765bdbbb..fdf9d715e 100644 --- a/.env_example +++ b/.env_example @@ -267,6 +267,7 @@ AZURE_CONTENT_SAFETY_API_KEY="xxxxx" AZURE_CONTENT_SAFETY_API_ENDPOINT="https://xxxxx.cognitiveservices.azure.com/" HUGGINGFACE_TOKEN="hf_xxxxxxx" +HUGGINGFACE_ENDPOINT="https://router.huggingface.co/v1" GOOGLE_GEMINI_ENDPOINT = "https://generativelanguage.googleapis.com/v1beta/openai" GOOGLE_GEMINI_API_KEY = "xxxxx" diff --git a/doc/code/targets/1_openai_chat_target.ipynb b/doc/code/targets/1_openai_chat_target.ipynb index e41952b6c..796810538 100644 --- a/doc/code/targets/1_openai_chat_target.ipynb +++ b/doc/code/targets/1_openai_chat_target.ipynb @@ -353,7 +353,7 @@ "source": [ "## OpenAI Configuration\n", "\n", - "All `OpenAITarget`s can communicate to [Azure OpenAI (AOAI)](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference), [OpenAI](https://platform.openai.com/docs/api-reference/introduction), or other compatible endpoints (e.g., Ollama, Groq).\n", + "All `OpenAITarget`s can communicate to [Azure OpenAI (AOAI)](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference), [OpenAI](https://platform.openai.com/docs/api-reference/introduction), or other compatible endpoints (e.g., Ollama, Groq, HuggingFace).\n", "\n", "The `OpenAIChatTarget` is built to be as cross-compatible as we can make it, while still being as flexible as we can make it by exposing functionality via parameters.\n", "\n", diff --git a/doc/code/targets/1_openai_chat_target.py b/doc/code/targets/1_openai_chat_target.py index 26390695a..bf1f4dbe4 100644 --- a/doc/code/targets/1_openai_chat_target.py +++ b/doc/code/targets/1_openai_chat_target.py @@ -178,7 +178,7 @@ # %% [markdown] # ## OpenAI Configuration # -# All `OpenAITarget`s can communicate to [Azure OpenAI (AOAI)](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference), [OpenAI](https://platform.openai.com/docs/api-reference/introduction), or other compatible endpoints (e.g., Ollama, Groq). +# All `OpenAITarget`s can communicate to [Azure OpenAI (AOAI)](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference), [OpenAI](https://platform.openai.com/docs/api-reference/introduction), or other compatible endpoints (e.g., Ollama, Groq, HuggingFace). # # The `OpenAIChatTarget` is built to be as cross-compatible as we can make it, while still being as flexible as we can make it by exposing functionality via parameters. # diff --git a/doc/getting_started/configuration.md b/doc/getting_started/configuration.md index 2916237de..b46ef813e 100644 --- a/doc/getting_started/configuration.md +++ b/doc/getting_started/configuration.md @@ -37,7 +37,7 @@ from pyrit.setup.initializers import SimpleInitializer await initialize_pyrit_async(memory_db_type="InMemory", initializers=[SimpleInitializer()]) ``` -This gives you an in-memory database and default converter/scorer config — enough to run most notebooks and examples. Replace the endpoint/key/model for your provider (Azure, Ollama, Groq, etc.). +This gives you an in-memory database and default converter/scorer config — enough to run most notebooks and examples. Replace the endpoint/key/model for your provider (Azure, Ollama, Groq, HuggingFace, etc.). ## For Persistent Setup @@ -50,7 +50,7 @@ For anything beyond a quick test — especially `pyrit_scan`, scenarios, and rep :link: ./populating_secrets **Set Up Your .env File** -Create `~/.pyrit/.env` with your provider credentials. Tabbed examples for OpenAI, Azure, Ollama, Groq, and more. 
+Create `~/.pyrit/.env` with your provider credentials. Tabbed examples for OpenAI, Azure, Ollama, Groq, HuggingFace, and more. :::: ::::{card} 📄 Configuration File (Recommended) diff --git a/doc/getting_started/populating_secrets.md b/doc/getting_started/populating_secrets.md index ecea978be..6c784bfef 100644 --- a/doc/getting_started/populating_secrets.md +++ b/doc/getting_started/populating_secrets.md @@ -63,6 +63,16 @@ OPENAI_CHAT_MODEL="llama3-8b-8192" Get your API key from [console.groq.com](https://console.groq.com/). ::: +:::{tab-item} HuggingFace +```bash +OPENAI_CHAT_ENDPOINT="https://router.huggingface.co/v1" +OPENAI_CHAT_KEY="hf_your-token-here" +OPENAI_CHAT_MODEL="meta-llama/Llama-3.1-8B-Instruct" +``` + +Get your token from [huggingface.co/docs/hub/security-tokens](https://huggingface.co/docs/hub/security-tokens). Browse available models at [huggingface.co/models](https://huggingface.co/models). +::: + :::{tab-item} OpenRouter ```bash OPENAI_CHAT_ENDPOINT="https://openrouter.ai/api/v1" diff --git a/doc/index.md b/doc/index.md index 853f578d2..1f7cf4003 100644 --- a/doc/index.md +++ b/doc/index.md @@ -135,7 +135,7 @@ After installing, configure PyRIT with your AI endpoint credentials and initiali :link: getting_started/populating_secrets **Set Up Your .env File** -Create `~/.pyrit/.env` with your provider credentials. Tabbed examples for OpenAI, Azure, Ollama, Groq, and more. +Create `~/.pyrit/.env` with your provider credentials. Tabbed examples for OpenAI, Azure, Ollama, Groq, HuggingFace, and more. :::: ::::{card} 📄 Config File (Recommended) diff --git a/pyrit/prompt_target/hugging_face/hugging_face_chat_target.py b/pyrit/prompt_target/hugging_face/hugging_face_chat_target.py index 71a24b2c8..193429d02 100644 --- a/pyrit/prompt_target/hugging_face/hugging_face_chat_target.py +++ b/pyrit/prompt_target/hugging_face/hugging_face_chat_target.py @@ -2,10 +2,12 @@ # Licensed under the MIT license. import asyncio +import json import logging import os +import warnings from pathlib import Path -from typing import Any, Optional, cast +from typing import Any, cast from transformers import ( AutoModelForCausalLM, @@ -55,46 +57,59 @@ class HuggingFaceChatTarget(PromptChatTarget): def __init__( self, *, - model_id: Optional[str] = None, - model_path: Optional[str] = None, - hf_access_token: Optional[str] = None, + model_id: str | None = None, + model_path: str | None = None, + hf_access_token: str | None = None, use_cuda: bool = False, tensor_format: str = "pt", - necessary_files: Optional[list[str]] = None, + necessary_files: list[str] | None = None, max_new_tokens: int = 20, temperature: float = 1.0, top_p: float = 1.0, + top_k: int | None = None, + do_sample: bool | None = None, + repetition_penalty: float | None = None, + random_seed: int | None = None, skip_special_tokens: bool = True, trust_remote_code: bool = False, - device_map: Optional[str] = None, - torch_dtype: Optional[Any] = None, - attn_implementation: Optional[str] = None, - max_requests_per_minute: Optional[int] = None, - custom_configuration: Optional[TargetConfiguration] = None, - custom_capabilities: Optional[TargetCapabilities] = None, + device_map: str | None = None, + torch_dtype: Any | None = None, + attn_implementation: str | None = None, + max_requests_per_minute: int | None = None, + custom_configuration: TargetConfiguration | None = None, + custom_capabilities: TargetCapabilities | None = None, ) -> None: """ Initialize the HuggingFaceChatTarget. Args: - model_id (Optional[str]): The Hugging Face model ID. 
Either model_id or model_path must be provided. - model_path (Optional[str]): Path to a local model. Either model_id or model_path must be provided. - hf_access_token (Optional[str]): Hugging Face access token for authentication. + model_id (str | None): The Hugging Face model ID. Either model_id or model_path must be provided. + model_path (str | None): Path to a local model. Either model_id or model_path must be provided. + hf_access_token (str | None): Hugging Face access token for authentication. use_cuda (bool): Whether to use CUDA for GPU acceleration. Defaults to False. tensor_format (str): The tensor format. Defaults to "pt". - necessary_files (Optional[list]): List of necessary model files to download. + necessary_files (list[str] | None): List of necessary model files to download. max_new_tokens (int): Maximum number of new tokens to generate. Defaults to 20. temperature (float): Sampling temperature. Defaults to 1.0. top_p (float): Nucleus sampling probability. Defaults to 1.0. + top_k (int | None): Top-K sampling parameter. Only used when do_sample is True. + Defaults to None (uses model default, typically 50). + do_sample (bool | None): Whether to use sampling instead of greedy decoding. When None, + sampling is automatically enabled if temperature, top_p, or top_k suggest + non-greedy decoding. Defaults to None. + repetition_penalty (float | None): Penalty for repeating tokens. Values > 1.0 discourage + repetition. Defaults to None (uses model default, typically 1.0). + random_seed (int | None): Random seed for deterministic generation. When set, calls + torch.manual_seed() at construction time. Defaults to None. skip_special_tokens (bool): Whether to skip special tokens. Defaults to True. trust_remote_code (bool): Whether to trust remote code execution. Defaults to False. - device_map (Optional[str]): Device mapping strategy. - torch_dtype (Optional[torch.dtype]): Torch data type for model weights. - attn_implementation (Optional[str]): Attention implementation type. - max_requests_per_minute (Optional[int]): The maximum number of requests per minute. Defaults to None. - custom_configuration (Optional[TargetConfiguration]): Override the default configuration for this target - instance. Defaults to None - custom_capabilities (TargetCapabilities, Optional): **Deprecated.** Use + device_map (str | None): Device mapping strategy. + torch_dtype (Any | None): Torch data type for model weights. + attn_implementation (str | None): Attention implementation type. + max_requests_per_minute (int | None): The maximum number of requests per minute. Defaults to None. + custom_configuration (TargetConfiguration | None): Override the default configuration for this target + instance. Defaults to None. + custom_capabilities (TargetCapabilities | None): **Deprecated.** Use ``custom_configuration`` instead. Will be removed in v0.14.0. 
Raises: @@ -148,8 +163,17 @@ def __init__( self.max_new_tokens = max_new_tokens self._temperature = temperature self._top_p = top_p + self._top_k = top_k + self._do_sample = do_sample + self._repetition_penalty = repetition_penalty + self._random_seed = random_seed self.skip_special_tokens = skip_special_tokens + self._warn_if_sampling_params_without_do_sample() + + self._generation_params = self._build_generation_params() + self._seed_rng() + if self.use_cuda and not torch.cuda.is_available(): raise RuntimeError("CUDA requested but not available.") @@ -166,6 +190,10 @@ def _build_identifier(self) -> ComponentIdentifier: params={ "temperature": self._temperature, "top_p": self._top_p, + "top_k": self._top_k, + "do_sample": self._do_sample, + "repetition_penalty": self._repetition_penalty, + "random_seed": self._random_seed, "max_new_tokens": self.max_new_tokens, "skip_special_tokens": self.skip_special_tokens, "use_cuda": self.use_cuda, @@ -300,6 +328,9 @@ async def _send_prompt_to_target_async(self, *, normalized_conversation: list[Me """ Send a normalized prompt asynchronously to the HuggingFace model. + Builds the full chat history (system, user, assistant turns) from the normalized + conversation and passes it through the model's chat template. + Args: normalized_conversation (list[Message]): The full conversation (history + current message) after running the normalization @@ -310,21 +341,15 @@ async def _send_prompt_to_target_async(self, *, normalized_conversation: list[Me Raises: EmptyResponseException: If the model generates an empty response. - Exception: If any error occurs during inference. """ - # Load the model and tokenizer using the encapsulated method await self.load_model_and_tokenizer_task - message = normalized_conversation[-1] - request = message.message_pieces[0] - prompt_template = request.converted_value + request = normalized_conversation[-1].message_pieces[0] - logger.info(f"Sending the following prompt to the HuggingFace model: {prompt_template}") + messages = self._build_chat_messages(normalized_conversation=normalized_conversation) - # Prepare the input messages using chat templates - messages = [{"role": "user", "content": prompt_template}] + logger.info(f"Sending the following messages to the HuggingFace model: {messages}") - # Apply chat template via the _apply_chat_template method tokenized_chat = self._apply_chat_template(messages) input_ids = tokenized_chat["input_ids"] attention_mask = tokenized_chat["attention_mask"] @@ -332,28 +357,21 @@ async def _send_prompt_to_target_async(self, *, normalized_conversation: list[Me logger.info(f"Tokenized chat: {input_ids}") try: - # Ensure model is on the correct device (should already be the case from `load_model_and_tokenizer`) + # Ensure model is on the correct device (should already be, but safeguard for device changes) self.model.to(self.device) - # Record the length of the input tokens to later extract only the generated tokens + # Record input length to extract only newly generated tokens input_length = input_ids.shape[-1] - # Generate the response + generate_kwargs = {"input_ids": input_ids, "attention_mask": attention_mask, **self._generation_params} + logger.info("Generating response from model...") - generated_ids = self.model.generate( - input_ids=input_ids, - attention_mask=attention_mask, - max_new_tokens=self.max_new_tokens, - temperature=self._temperature, - top_p=self._top_p, - ) + generated_ids = self.model.generate(**generate_kwargs) - logger.info(f"Generated IDs: {generated_ids}") # Log the 
generated IDs
+        logger.info(f"Generated IDs: {generated_ids}")
 
-        # Extract the assistant's response by slicing the generated tokens after the input tokens
         generated_tokens = generated_ids[0][input_length:]
 
-        # Decode the assistant's response from the generated token IDs
         assistant_response = cast(
             "str",
             self.tokenizer.decode(generated_tokens, skip_special_tokens=self.skip_special_tokens),
@@ -366,10 +384,15 @@
 
         model_identifier = self.model_id or self.model_path
 
+        effective_config = self._get_effective_generation_config()
+
         response = construct_response_from_request(
             request=request,
             response_text_pieces=[assistant_response],
-            prompt_metadata={"model_id": model_identifier or ""},
+            prompt_metadata={
+                "model_id": model_identifier or "",
+                "effective_generation_config": json.dumps(effective_config, default=str),
+            },
         )
         return [response]
@@ -377,6 +400,121 @@
             logger.error(f"Error occurred during inference: {e}")
             raise
 
+    def _build_chat_messages(self, *, normalized_conversation: list[Message]) -> list[dict[str, str]]:
+        """
+        Build a list of chat message dicts from the full normalized conversation.
+
+        Includes system, user, and assistant messages from the conversation history
+        so that the model's chat template receives the complete context.
+
+        Args:
+            normalized_conversation (list[Message]): The full normalized conversation.
+
+        Returns:
+            list[dict[str, str]]: Messages formatted for the chat template.
+        """
+        messages: list[dict[str, str]] = []
+        for msg in normalized_conversation:
+            piece = msg.message_pieces[0]
+            role = piece.api_role
+            content = piece.converted_value or ""
+            messages.append({"role": role, "content": content})
+        return messages
+
+    def set_random_seed(self, random_seed: int) -> None:
+        """
+        Set a new random seed and immediately re-seed the RNG.
+
+        Allows re-seeding between conversations or experiments for controlled
+        reproducibility. The initial seed (if any) is applied once at construction
+        time; call this method to change it later.
+
+        Args:
+            random_seed (int): The random seed value.
+        """
+        self._random_seed = random_seed
+        self._seed_rng()
+
+    def _build_generation_params(self) -> dict[str, Any]:
+        """
+        Build the static generation parameters dict.
+
+        Computed once at init. Only includes optional parameters when they
+        are explicitly set (not None), allowing the model's own
+        generation_config defaults to apply otherwise.
+
+        Returns:
+            dict[str, Any]: Static keyword arguments for model.generate().
+        """
+        params: dict[str, Any] = {
+            "max_new_tokens": self.max_new_tokens,
+            "temperature": self._temperature,
+            "top_p": self._top_p,
+        }
+        if self._top_k is not None:
+            params["top_k"] = self._top_k
+        if self._do_sample is not None:
+            params["do_sample"] = self._do_sample
+        if self._repetition_penalty is not None:
+            params["repetition_penalty"] = self._repetition_penalty
+        return params
+
+    def _seed_rng(self) -> None:
+        """
+        Seed the random number generators for deterministic generation.
+
+        When ``self._random_seed`` is set, seeds both the CPU and CUDA RNGs. The
+        seed is applied once at construction time and again whenever
+        ``set_random_seed()`` is called, enabling reproducible results when all
+        other generation parameters are held constant.
+
+        Note:
+            This sets global torch RNG state. Concurrent generation calls in
+            the same process may interfere with determinism.
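+
+        Example:
+            Re-seeding to reproduce a sampled generation (a sketch; ``msg`` is
+            assumed to be a previously built ``Message``)::
+
+                target = HuggingFaceChatTarget(model_id="gpt2", do_sample=True, random_seed=42)
+                first = await target.send_prompt_async(message=msg)
+                target.set_random_seed(42)  # restore the same RNG state
+                second = await target.send_prompt_async(message=msg)  # expected to match ``first``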
+ """ + if self._random_seed is not None: + import torch + + torch.manual_seed(self._random_seed) + if self.use_cuda: + torch.cuda.manual_seed_all(self._random_seed) + + def _get_effective_generation_config(self) -> dict[str, Any]: + """ + Return the effective generation parameters that were used for the last call. + + Combines the model's own generation_config with the explicit overrides from + this target instance, so that the stored metadata reflects what actually ran. + + Returns: + dict[str, Any]: Merged generation configuration. + """ + effective: dict[str, Any] = {} + if hasattr(self.model, "generation_config"): + effective = self.model.generation_config.to_dict() + + effective.update(self._generation_params) + if self._random_seed is not None: + effective["random_seed"] = self._random_seed + return effective + + def _warn_if_sampling_params_without_do_sample(self) -> None: + """ + Emit a warning when sampling parameters are set but do_sample is not explicitly True. + + Sampling-specific parameters (temperature != 1.0, top_p != 1.0, top_k) are + ignored by HuggingFace's generate() unless do_sample=True. This helps users + avoid silent misconfiguration. + """ + has_sampling_override = self._temperature != 1.0 or self._top_p != 1.0 or self._top_k is not None + if has_sampling_override and self._do_sample is not True: + warnings.warn( + "Sampling parameters (temperature, top_p, top_k) are set but do_sample is not True. " + "HuggingFace ignores these parameters during greedy decoding. " + "Set do_sample=True to enable sampling.", + UserWarning, + stacklevel=3, + ) + def _apply_chat_template(self, messages: list[dict[str, str]]) -> Any: """ Apply the chat template to the input messages and tokenize them. diff --git a/pyrit/prompt_target/hugging_face/hugging_face_endpoint_target.py b/pyrit/prompt_target/hugging_face/hugging_face_endpoint_target.py index da47fa34a..c3ae04bf6 100644 --- a/pyrit/prompt_target/hugging_face/hugging_face_endpoint_target.py +++ b/pyrit/prompt_target/hugging_face/hugging_face_endpoint_target.py @@ -2,8 +2,9 @@ # Licensed under the MIT license. import logging -from typing import Optional +import warnings +from pyrit.common.deprecation import print_deprecation_message from pyrit.common.net_utility import make_request_and_raise_if_error_async from pyrit.identifiers import ComponentIdentifier from pyrit.models import Message, construct_response_from_request @@ -19,7 +20,10 @@ class HuggingFaceEndpointTarget(PromptTarget): """ The HuggingFaceEndpointTarget interacts with HuggingFace models hosted on cloud endpoints. - Inherits from PromptTarget to comply with the current design standards. + .. deprecated:: 0.13.0 + Use ``OpenAIChatTarget`` with ``endpoint="https://router.huggingface.co/v1"`` + and ``api_key=HUGGINGFACE_TOKEN`` instead. The HuggingFace Inference Providers API + is OpenAI-compatible, making this target redundant. Will be removed in v0.15.0. 
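+
+    A minimal migration sketch (the model name is illustrative; any model served
+    by the HuggingFace router works)::
+
+        import os
+
+        from pyrit.prompt_target import OpenAIChatTarget
+
+        target = OpenAIChatTarget(
+            endpoint="https://router.huggingface.co/v1",
+            api_key=os.environ["HUGGINGFACE_TOKEN"],
+            model_name="meta-llama/Llama-3.1-8B-Instruct",
+        )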
""" def __init__( @@ -31,10 +35,13 @@ def __init__( max_tokens: int = 400, temperature: float = 1.0, top_p: float = 1.0, - max_requests_per_minute: Optional[int] = None, + top_k: int | None = None, + do_sample: bool | None = None, + repetition_penalty: float | None = None, + max_requests_per_minute: int | None = None, verbose: bool = False, - custom_configuration: Optional[TargetConfiguration] = None, - custom_capabilities: Optional[TargetCapabilities] = None, + custom_configuration: TargetConfiguration | None = None, + custom_capabilities: TargetCapabilities | None = None, ) -> None: """ Initialize the HuggingFaceEndpointTarget with API credentials and model parameters. @@ -43,15 +50,27 @@ def __init__( hf_token (str): The Hugging Face token for authenticating with the Hugging Face endpoint. endpoint (str): The endpoint URL for the Hugging Face model. model_id (str): The model ID to be used at the endpoint. - max_tokens (int, Optional): The maximum number of tokens to generate. Defaults to 400. - temperature (float, Optional): The sampling temperature to use. Defaults to 1.0. - top_p (float, Optional): The cumulative probability for nucleus sampling. Defaults to 1.0. - max_requests_per_minute (Optional[int]): The maximum number of requests per minute. Defaults to None. - verbose (bool, Optional): Flag to enable verbose logging. Defaults to False. - custom_configuration (Optional[TargetConfiguration]): Custom configuration for this target instance. - custom_capabilities (TargetCapabilities, Optional): **Deprecated.** Use + max_tokens (int): The maximum number of tokens to generate. Defaults to 400. + temperature (float): The sampling temperature to use. Defaults to 1.0. + top_p (float): The cumulative probability for nucleus sampling. Defaults to 1.0. + top_k (int | None): Top-K sampling parameter. Only used when do_sample is True. + Defaults to None (uses model default). + do_sample (bool | None): Whether to use sampling instead of greedy decoding. + Defaults to None. + repetition_penalty (float | None): Penalty for repeating tokens. Values > 1.0 + discourage repetition. Defaults to None (uses model default). + max_requests_per_minute (int | None): The maximum number of requests per minute. Defaults to None. + verbose (bool): Flag to enable verbose logging. Defaults to False. + custom_configuration (TargetConfiguration | None): Custom configuration for this target instance. + custom_capabilities (TargetCapabilities | None): **Deprecated.** Use ``custom_configuration`` instead. Will be removed in v0.14.0. 
""" + print_deprecation_message( + old_item=HuggingFaceEndpointTarget, + new_item="OpenAIChatTarget with endpoint='https://router.huggingface.co/v1'", + removed_in="v0.15.0", + ) + super().__init__( max_requests_per_minute=max_requests_per_minute, verbose=verbose, @@ -70,6 +89,11 @@ def __init__( self.max_tokens = max_tokens self._temperature = temperature self._top_p = top_p + self._top_k = top_k + self._do_sample = do_sample + self._repetition_penalty = repetition_penalty + + self._warn_if_sampling_params_without_do_sample() def _build_identifier(self) -> ComponentIdentifier: """ @@ -82,6 +106,9 @@ def _build_identifier(self) -> ComponentIdentifier: params={ "temperature": self._temperature, "top_p": self._top_p, + "top_k": self._top_k, + "do_sample": self._do_sample, + "repetition_penalty": self._repetition_penalty, "max_tokens": self.max_tokens, }, ) @@ -106,13 +133,20 @@ async def _send_prompt_to_target_async(self, *, normalized_conversation: list[Me message = normalized_conversation[-1] request = message.message_pieces[0] headers = {"Authorization": f"Bearer {self.hf_token}"} + parameters: dict[str, object] = { + "max_tokens": self.max_tokens, + "temperature": self._temperature, + "top_p": self._top_p, + } + if self._top_k is not None: + parameters["top_k"] = self._top_k + if self._do_sample is not None: + parameters["do_sample"] = self._do_sample + if self._repetition_penalty is not None: + parameters["repetition_penalty"] = self._repetition_penalty payload: dict[str, object] = { "inputs": request.converted_value, - "parameters": { - "max_tokens": self.max_tokens, - "temperature": self._temperature, - "top_p": self._top_p, - }, + "parameters": parameters, } logger.info(f"Sending the following prompt to the cloud endpoint: {request.converted_value}") @@ -161,3 +195,20 @@ def _validate_request(self, *, normalized_conversation: list[Message]) -> None: n_pieces = len(message.message_pieces) if n_pieces != 1: raise ValueError(f"This target only supports a single message piece. Received: {n_pieces} pieces.") + + def _warn_if_sampling_params_without_do_sample(self) -> None: + """ + Emit a warning when sampling parameters are set but do_sample is not explicitly True. + + Sampling-specific parameters (temperature != 1.0, top_p != 1.0, top_k) are + ignored by HuggingFace unless do_sample=True. + """ + has_sampling_override = self._temperature != 1.0 or self._top_p != 1.0 or self._top_k is not None + if has_sampling_override and self._do_sample is not True: + warnings.warn( + "Sampling parameters (temperature, top_p, top_k) are set but do_sample is not True. " + "HuggingFace ignores these parameters during greedy decoding. " + "Set do_sample=True to enable sampling.", + UserWarning, + stacklevel=3, + ) diff --git a/tests/integration/targets/test_hugging_face_integration.py b/tests/integration/targets/test_hugging_face_integration.py new file mode 100644 index 000000000..184fcbfc4 --- /dev/null +++ b/tests/integration/targets/test_hugging_face_integration.py @@ -0,0 +1,159 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Integration tests for HuggingFace via OpenAI-compatible targets. + +The HuggingFace Inference Providers API is OpenAI-compatible, so +OpenAIChatTarget and OpenAIResponseTarget work directly with +the HUGGINGFACE_ENDPOINT and HUGGINGFACE_TOKEN env vars. + +Requires (loaded from .env by initialize_pyrit_async): +- HUGGINGFACE_TOKEN: HuggingFace API token +- HUGGINGFACE_ENDPOINT: HuggingFace router URL (e.g. 
https://router.huggingface.co/v1) +""" + +import os + +import pytest + +from pyrit.models import MessagePiece +from pyrit.prompt_target import OpenAIChatTarget, OpenAIResponseTarget + +DEFAULT_HF_MODEL = "meta-llama/Llama-3.1-8B-Instruct" + + +@pytest.fixture() +def hf_token(): + token = os.environ.get("HUGGINGFACE_TOKEN") + if not token: + pytest.skip("HUGGINGFACE_TOKEN environment variable is not set") + return token + + +@pytest.fixture() +def hf_endpoint(): + endpoint = os.environ.get("HUGGINGFACE_ENDPOINT") + if not endpoint: + pytest.skip("HUGGINGFACE_ENDPOINT environment variable is not set") + return endpoint + + +@pytest.fixture() +def hf_chat_target(hf_token, hf_endpoint, sqlite_instance) -> OpenAIChatTarget: + return OpenAIChatTarget( + endpoint=hf_endpoint, + api_key=hf_token, + model_name=DEFAULT_HF_MODEL, + max_tokens=30, + ) + + +@pytest.fixture() +def hf_response_target(hf_token, hf_endpoint, sqlite_instance) -> OpenAIResponseTarget: + return OpenAIResponseTarget( + endpoint=hf_endpoint, + api_key=hf_token, + model_name=DEFAULT_HF_MODEL, + max_output_tokens=30, + ) + + +# ============================================================================ +# Chat Completions API (/v1/chat/completions) +# ============================================================================ + + +@pytest.mark.run_only_if_all_tests +@pytest.mark.asyncio +async def test_chat_completion_basic(hf_chat_target): + """Verify a simple prompt returns a non-empty response via the HF router.""" + msg = MessagePiece(role="user", original_value="What is 2+2? Answer with just the number.").to_message() + response = await hf_chat_target.send_prompt_async(message=msg) + + assert response is not None + assert len(response) >= 1 + text = response[0].message_pieces[0].original_value + assert isinstance(text, str) + assert len(text) > 0 + + +@pytest.mark.run_only_if_all_tests +@pytest.mark.asyncio +async def test_chat_completion_with_temperature(hf_token, hf_endpoint, sqlite_instance): + """Verify temperature param is accepted by the HF router.""" + target = OpenAIChatTarget( + endpoint=hf_endpoint, + api_key=hf_token, + model_name=DEFAULT_HF_MODEL, + max_tokens=30, + temperature=0.7, + ) + + msg = MessagePiece(role="user", original_value="Say hello in one word.").to_message() + response = await target.send_prompt_async(message=msg) + + assert response is not None + assert len(response) >= 1 + text = response[0].message_pieces[0].original_value + assert isinstance(text, str) + assert len(text) > 0 + + +@pytest.mark.run_only_if_all_tests +@pytest.mark.asyncio +async def test_chat_completion_identifier(hf_chat_target): + """Verify the component identifier reflects the HF endpoint and model.""" + identifier = hf_chat_target.get_identifier() + assert "router.huggingface.co" in identifier.params["endpoint"] + assert identifier.params["model_name"] == DEFAULT_HF_MODEL + + +# ============================================================================ +# Responses API (/v1/responses) +# ============================================================================ + + +@pytest.mark.run_only_if_all_tests +@pytest.mark.asyncio +async def test_response_api_basic(hf_response_target): + """Verify a simple prompt returns a non-empty response via the Responses API.""" + msg = MessagePiece(role="user", original_value="What is 2+2? 
Answer with just the number.").to_message() + response = await hf_response_target.send_prompt_async(message=msg) + + assert response is not None + assert len(response) >= 1 + text = response[0].message_pieces[0].original_value + assert isinstance(text, str) + assert len(text) > 0 + + +@pytest.mark.run_only_if_all_tests +@pytest.mark.asyncio +async def test_response_api_with_temperature(hf_token, hf_endpoint, sqlite_instance): + """Verify temperature param is accepted by the Responses API on HF.""" + target = OpenAIResponseTarget( + endpoint=hf_endpoint, + api_key=hf_token, + model_name=DEFAULT_HF_MODEL, + max_output_tokens=30, + temperature=0.7, + ) + + msg = MessagePiece(role="user", original_value="Say hello in one word.").to_message() + response = await target.send_prompt_async(message=msg) + + assert response is not None + assert len(response) >= 1 + text = response[0].message_pieces[0].original_value + assert isinstance(text, str) + assert len(text) > 0 + + +@pytest.mark.run_only_if_all_tests +@pytest.mark.asyncio +async def test_response_api_identifier(hf_response_target): + """Verify the component identifier reflects the HF endpoint and model.""" + identifier = hf_response_target.get_identifier() + assert "router.huggingface.co" in identifier.params["endpoint"] + assert identifier.params["model_name"] == DEFAULT_HF_MODEL diff --git a/tests/unit/prompt_target/target/test_hugging_face_endpoint_target.py b/tests/unit/prompt_target/target/test_hugging_face_endpoint_target.py index 3b4daadff..72a58173f 100644 --- a/tests/unit/prompt_target/target/test_hugging_face_endpoint_target.py +++ b/tests/unit/prompt_target/target/test_hugging_face_endpoint_target.py @@ -1,12 +1,18 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. +from unittest.mock import AsyncMock, MagicMock, patch + import pytest +from pyrit.models import Message, MessagePiece from pyrit.prompt_target.hugging_face.hugging_face_endpoint_target import ( HuggingFaceEndpointTarget, ) +# HuggingFaceEndpointTarget emits a DeprecationWarning on construction +pytestmark = pytest.mark.filterwarnings("ignore::DeprecationWarning") + @pytest.fixture def hugging_face_endpoint_target(patch_central_database) -> HuggingFaceEndpointTarget: @@ -84,3 +90,245 @@ def test_valid_temperature_and_top_p(patch_central_database): ) assert target._temperature == 1.5 assert target._top_p == 0.9 + + +def test_identifier_includes_generation_params(): + """New generation params (top_k, do_sample, repetition_penalty) appear in the identifier.""" + target = HuggingFaceEndpointTarget( + hf_token="test_token", + endpoint="https://api-inference.huggingface.co/models/test-model", + model_id="test-model", + top_k=40, + do_sample=True, + repetition_penalty=1.2, + ) + identifier = target.get_identifier() + assert identifier.params["top_k"] == 40 + assert identifier.params["do_sample"] is True + assert identifier.params["repetition_penalty"] == 1.2 + + +def test_identifier_excludes_none_generation_params(): + """None-valued generation params are excluded from the identifier.""" + target = HuggingFaceEndpointTarget( + hf_token="test_token", + endpoint="https://api-inference.huggingface.co/models/test-model", + model_id="test-model", + ) + identifier = target.get_identifier() + assert "top_k" not in identifier.params + assert "do_sample" not in identifier.params + assert "repetition_penalty" not in identifier.params + + +def test_sampling_params_without_do_sample_warns(): + """Setting temperature != 1.0 without do_sample=True emits a warning.""" 
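+    # Construction alone should trigger the warning. The module-level pytestmark
+    # suppresses the separate DeprecationWarning, so only the UserWarning is checked.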
+ with pytest.warns(UserWarning, match="do_sample is not True"): + HuggingFaceEndpointTarget( + hf_token="test_token", + endpoint="https://api-inference.huggingface.co/models/test-model", + model_id="test-model", + temperature=0.7, + ) + + +def test_sampling_params_with_do_sample_no_warning(): + """Setting temperature != 1.0 with do_sample=True does not warn.""" + import warnings as _warnings + + with _warnings.catch_warnings(): + _warnings.simplefilter("error", UserWarning) + HuggingFaceEndpointTarget( + hf_token="test_token", + endpoint="https://api-inference.huggingface.co/models/test-model", + model_id="test-model", + temperature=0.7, + do_sample=True, + ) + + +@pytest.mark.filterwarnings("default::DeprecationWarning") +def test_init_emits_deprecation_warning(): + """HuggingFaceEndpointTarget emits a DeprecationWarning on construction.""" + with pytest.warns(DeprecationWarning, match="deprecated and will be removed"): + HuggingFaceEndpointTarget( + hf_token="test_token", + endpoint="https://api-inference.huggingface.co/models/test-model", + model_id="test-model", + ) + + +def _make_user_message(text: str) -> Message: + """Helper to create a single-piece user Message.""" + return Message( + message_pieces=[ + MessagePiece( + role="user", + original_value=text, + converted_value=text, + converted_value_data_type="text", + ) + ] + ) + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("patch_central_database") +async def test_send_prompt_async_list_response(): + """Verify send_prompt_async handles a list response from the HF API.""" + target = HuggingFaceEndpointTarget( + hf_token="test_token", + endpoint="https://api-inference.huggingface.co/models/test-model", + model_id="test-model", + ) + + mock_response = MagicMock() + mock_response.json.return_value = [{"generated_text": "Hello from HF"}] + + with patch( + "pyrit.prompt_target.hugging_face.hugging_face_endpoint_target.make_request_and_raise_if_error_async", + new_callable=AsyncMock, + return_value=mock_response, + ): + message = _make_user_message("test prompt") + response = await target.send_prompt_async(message=message) + + assert len(response) == 1 + assert response[0].message_pieces[0].original_value == "Hello from HF" + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("patch_central_database") +async def test_send_prompt_async_dict_response(): + """Verify send_prompt_async handles a dict response from the HF API.""" + target = HuggingFaceEndpointTarget( + hf_token="test_token", + endpoint="https://api-inference.huggingface.co/models/test-model", + model_id="test-model", + ) + + mock_response = MagicMock() + mock_response.json.return_value = {"generated_text": "Dict response"} + + with patch( + "pyrit.prompt_target.hugging_face.hugging_face_endpoint_target.make_request_and_raise_if_error_async", + new_callable=AsyncMock, + return_value=mock_response, + ): + message = _make_user_message("test prompt") + response = await target.send_prompt_async(message=message) + + assert len(response) == 1 + assert response[0].message_pieces[0].original_value == "Dict response" + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("patch_central_database") +async def test_send_prompt_async_passes_optional_params_in_payload(): + """Verify optional generation params are included in the HTTP payload.""" + target = HuggingFaceEndpointTarget( + hf_token="test_token", + endpoint="https://api-inference.huggingface.co/models/test-model", + model_id="test-model", + top_k=40, + do_sample=True, + repetition_penalty=1.2, + ) + + mock_response = MagicMock() + 
mock_response.json.return_value = [{"generated_text": "response"}] + + with patch( + "pyrit.prompt_target.hugging_face.hugging_face_endpoint_target.make_request_and_raise_if_error_async", + new_callable=AsyncMock, + return_value=mock_response, + ) as mock_request: + message = _make_user_message("test prompt") + await target.send_prompt_async(message=message) + + call_kwargs = mock_request.call_args[1] + params = call_kwargs["request_body"]["parameters"] + assert params["top_k"] == 40 + assert params["do_sample"] is True + assert params["repetition_penalty"] == 1.2 + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("patch_central_database") +async def test_send_prompt_async_omits_none_params_from_payload(): + """Verify None-valued optional params are not in the HTTP payload.""" + target = HuggingFaceEndpointTarget( + hf_token="test_token", + endpoint="https://api-inference.huggingface.co/models/test-model", + model_id="test-model", + ) + + mock_response = MagicMock() + mock_response.json.return_value = [{"generated_text": "response"}] + + with patch( + "pyrit.prompt_target.hugging_face.hugging_face_endpoint_target.make_request_and_raise_if_error_async", + new_callable=AsyncMock, + return_value=mock_response, + ) as mock_request: + message = _make_user_message("test prompt") + await target.send_prompt_async(message=message) + + call_kwargs = mock_request.call_args[1] + params = call_kwargs["request_body"]["parameters"] + assert "top_k" not in params + assert "do_sample" not in params + assert "repetition_penalty" not in params + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("patch_central_database") +async def test_send_prompt_async_metadata_contains_model_id(): + """Verify prompt_metadata includes the model_id.""" + target = HuggingFaceEndpointTarget( + hf_token="test_token", + endpoint="https://api-inference.huggingface.co/models/test-model", + model_id="test-model", + ) + + mock_response = MagicMock() + mock_response.json.return_value = [{"generated_text": "response"}] + + with patch( + "pyrit.prompt_target.hugging_face.hugging_face_endpoint_target.make_request_and_raise_if_error_async", + new_callable=AsyncMock, + return_value=mock_response, + ): + message = _make_user_message("test prompt") + response = await target.send_prompt_async(message=message) + + metadata = response[0].message_pieces[0].prompt_metadata + assert metadata["model_id"] == "test-model" + + +def test_validate_request_rejects_multiple_pieces(): + """Verify _validate_request raises for messages with multiple pieces.""" + target = HuggingFaceEndpointTarget( + hf_token="test_token", + endpoint="https://api-inference.huggingface.co/models/test-model", + model_id="test-model", + ) + + piece1 = MessagePiece( + role="user", + original_value="first", + converted_value="first", + converted_value_data_type="text", + conversation_id="conv1", + ) + piece2 = MessagePiece( + role="user", + original_value="second", + converted_value="second", + converted_value_data_type="text", + conversation_id="conv1", + ) + message = Message(message_pieces=[piece1, piece2]) + + with pytest.raises(ValueError, match="single message piece"): + target._validate_request(normalized_conversation=[message]) diff --git a/tests/unit/prompt_target/target/test_huggingface_chat_target.py b/tests/unit/prompt_target/target/test_huggingface_chat_target.py index 4ad6b118a..33a1025c6 100644 --- a/tests/unit/prompt_target/target/test_huggingface_chat_target.py +++ b/tests/unit/prompt_target/target/test_huggingface_chat_target.py @@ -1,6 +1,7 @@ # Copyright (c) 
Microsoft Corporation. # Licensed under the MIT license. +import json from asyncio import Task from unittest.mock import AsyncMock, MagicMock, patch @@ -322,3 +323,242 @@ async def test_hugging_face_chat_sets_endpoint_and_rate_limit(patch_central_data # HuggingFaceChatTarget doesn't set an endpoint (it's local), so it should be empty assert not identifier.params.get("endpoint") assert target._max_requests_per_minute == 30 + + +@pytest.mark.skipif(not is_torch_installed(), reason="torch is not installed") +def test_identifier_includes_generation_params(): + """New generation params (top_k, do_sample, repetition_penalty, random_seed) appear in the identifier.""" + target = HuggingFaceChatTarget( + model_id="test_model", + use_cuda=False, + top_k=40, + do_sample=True, + repetition_penalty=1.2, + random_seed=42, + temperature=0.7, + ) + identifier = target.get_identifier() + assert identifier.params["top_k"] == 40 + assert identifier.params["do_sample"] is True + assert identifier.params["repetition_penalty"] == 1.2 + assert identifier.params["random_seed"] == 42 + assert identifier.params["temperature"] == 0.7 + + +@pytest.mark.skipif(not is_torch_installed(), reason="torch is not installed") +def test_identifier_excludes_none_generation_params(): + """None-valued generation params are excluded from the identifier (backward compatibility).""" + target = HuggingFaceChatTarget( + model_id="test_model", + use_cuda=False, + ) + identifier = target.get_identifier() + assert "top_k" not in identifier.params + assert "do_sample" not in identifier.params + assert "repetition_penalty" not in identifier.params + assert "random_seed" not in identifier.params + + +@pytest.mark.skipif(not is_torch_installed(), reason="torch is not installed") +@pytest.mark.asyncio +@pytest.mark.usefixtures("patch_central_database") +async def test_generate_passes_new_params(): + """Verify top_k, do_sample, repetition_penalty are forwarded to model.generate().""" + target = HuggingFaceChatTarget( + model_id="test_model", + use_cuda=False, + top_k=40, + do_sample=True, + repetition_penalty=1.2, + ) + await target.load_model_and_tokenizer() + + message_piece = MessagePiece( + role="user", + original_value="test prompt", + converted_value="test prompt", + converted_value_data_type="text", + ) + message = Message(message_pieces=[message_piece]) + await target.send_prompt_async(message=message) + + call_kwargs = target.model.generate.call_args[1] + assert call_kwargs["top_k"] == 40 + assert call_kwargs["do_sample"] is True + assert call_kwargs["repetition_penalty"] == 1.2 + + +@pytest.mark.skipif(not is_torch_installed(), reason="torch is not installed") +@pytest.mark.asyncio +@pytest.mark.usefixtures("patch_central_database") +async def test_generate_omits_none_params(): + """When optional params are None, they should not be passed to model.generate().""" + target = HuggingFaceChatTarget( + model_id="test_model", + use_cuda=False, + ) + await target.load_model_and_tokenizer() + + message_piece = MessagePiece( + role="user", + original_value="test prompt", + converted_value="test prompt", + converted_value_data_type="text", + ) + message = Message(message_pieces=[message_piece]) + await target.send_prompt_async(message=message) + + call_kwargs = target.model.generate.call_args[1] + assert "top_k" not in call_kwargs + assert "do_sample" not in call_kwargs + assert "repetition_penalty" not in call_kwargs + + +@pytest.mark.skipif(not is_torch_installed(), reason="torch is not installed") +def 
test_random_seed_calls_manual_seed_at_init(): + """When random_seed is set, torch.manual_seed is called during construction.""" + with patch("torch.manual_seed") as mock_manual_seed: + HuggingFaceChatTarget( + model_id="test_model", + use_cuda=False, + random_seed=42, + ) + mock_manual_seed.assert_called_once_with(42) + + +@pytest.mark.skipif(not is_torch_installed(), reason="torch is not installed") +def test_no_random_seed_does_not_call_manual_seed(): + """When random_seed is None, torch.manual_seed is not called.""" + with patch("torch.manual_seed") as mock_manual_seed: + HuggingFaceChatTarget( + model_id="test_model", + use_cuda=False, + ) + mock_manual_seed.assert_not_called() + + +@pytest.mark.skipif(not is_torch_installed(), reason="torch is not installed") +def test_set_random_seed_reseeds_rng(): + """Calling set_random_seed updates the seed and immediately re-seeds the RNG.""" + target = HuggingFaceChatTarget( + model_id="test_model", + use_cuda=False, + ) + with patch("torch.manual_seed") as mock_manual_seed: + target.set_random_seed(99) + mock_manual_seed.assert_called_once_with(99) + assert target._random_seed == 99 + + +@pytest.mark.skipif(not is_torch_installed(), reason="torch is not installed") +def test_sampling_params_without_do_sample_warns(): + """Setting temperature != 1.0 without do_sample=True emits a warning.""" + with pytest.warns(UserWarning, match="do_sample is not True"): + HuggingFaceChatTarget( + model_id="test_model", + use_cuda=False, + temperature=0.7, + ) + + +@pytest.mark.skipif(not is_torch_installed(), reason="torch is not installed") +def test_sampling_params_with_do_sample_no_warning(): + """Setting temperature != 1.0 with do_sample=True does not warn.""" + import warnings + + with warnings.catch_warnings(): + warnings.simplefilter("error", UserWarning) + HuggingFaceChatTarget( + model_id="test_model", + use_cuda=False, + temperature=0.7, + do_sample=True, + ) + + +@pytest.mark.skipif(not is_torch_installed(), reason="torch is not installed") +def test_default_params_no_warning(): + """Default parameters (temperature=1.0, top_p=1.0) do not trigger warning.""" + import warnings + + with warnings.catch_warnings(): + warnings.simplefilter("error", UserWarning) + HuggingFaceChatTarget( + model_id="test_model", + use_cuda=False, + ) + + +@pytest.mark.skipif(not is_torch_installed(), reason="torch is not installed") +@pytest.mark.asyncio +@pytest.mark.usefixtures("patch_central_database") +async def test_full_conversation_sent_to_chat_template(): + """Verify system and user messages from the full conversation are sent to the chat template.""" + target = HuggingFaceChatTarget(model_id="test_model", use_cuda=False) + await target.load_model_and_tokenizer() + + system_piece = MessagePiece( + role="system", + original_value="You are a helpful assistant.", + converted_value="You are a helpful assistant.", + converted_value_data_type="text", + conversation_id="conv1", + sequence=0, + ) + user_piece = MessagePiece( + role="user", + original_value="Hello", + converted_value="Hello", + converted_value_data_type="text", + conversation_id="conv1", + sequence=1, + ) + system_msg = Message(message_pieces=[system_piece]) + user_msg = Message(message_pieces=[user_piece]) + + with patch.object(target, "_apply_chat_template", wraps=target._apply_chat_template) as mock_template: + await target._send_prompt_to_target_async(normalized_conversation=[system_msg, user_msg]) + + call_args = mock_template.call_args[0][0] + assert len(call_args) == 2 + assert call_args[0] == 
{"role": "system", "content": "You are a helpful assistant."} + assert call_args[1] == {"role": "user", "content": "Hello"} + + +@pytest.mark.skipif(not is_torch_installed(), reason="torch is not installed") +@pytest.mark.asyncio +@pytest.mark.usefixtures("patch_central_database") +async def test_effective_generation_config_in_metadata(): + """Verify effective generation config is stored in response prompt_metadata.""" + target = HuggingFaceChatTarget( + model_id="test_model", + use_cuda=False, + top_k=40, + do_sample=True, + random_seed=42, + ) + await target.load_model_and_tokenizer() + + # Mock generation_config on the model + mock_gen_config = MagicMock() + mock_gen_config.to_dict.return_value = {"eos_token_id": 2, "bos_token_id": 1} + target.model.generation_config = mock_gen_config + + message_piece = MessagePiece( + role="user", + original_value="test", + converted_value="test", + converted_value_data_type="text", + ) + message = Message(message_pieces=[message_piece]) + + response = await target.send_prompt_async(message=message) + metadata = response[0].message_pieces[0].prompt_metadata + effective_config = json.loads(metadata["effective_generation_config"]) + + assert effective_config["top_k"] == 40 + assert effective_config["do_sample"] is True + assert effective_config["random_seed"] == 42 + assert effective_config["temperature"] == 1.0 + # Model defaults should also be present + assert effective_config["eos_token_id"] == 2