diff --git a/pyrit/backend/mappers/attack_mappers.py b/pyrit/backend/mappers/attack_mappers.py index e66d9ae625..18282fe22b 100644 --- a/pyrit/backend/mappers/attack_mappers.py +++ b/pyrit/backend/mappers/attack_mappers.py @@ -537,7 +537,7 @@ def request_to_pyrit_message( ) for p in request.pieces ] - return PyritMessage(pieces) + return PyritMessage(message_pieces=pieces) # ============================================================================ diff --git a/pyrit/executor/attack/component/conversation_manager.py b/pyrit/executor/attack/component/conversation_manager.py index fafa8ea7e9..9e0fc454ba 100644 --- a/pyrit/executor/attack/component/conversation_manager.py +++ b/pyrit/executor/attack/component/conversation_manager.py @@ -481,7 +481,7 @@ async def add_prepended_conversation_to_memory_async( turn_count = 0 for i, message in enumerate(valid_messages): - message_copy = message.duplicate_message() + message_copy = message.duplicate() message_copy.set_simulated_role() diff --git a/pyrit/executor/attack/multi_turn/crescendo.py b/pyrit/executor/attack/multi_turn/crescendo.py index 50184d57a3..dfd57e2515 100644 --- a/pyrit/executor/attack/multi_turn/crescendo.py +++ b/pyrit/executor/attack/multi_turn/crescendo.py @@ -797,7 +797,7 @@ async def _generate_next_prompt_async(self, context: CrescendoAttackContext) -> if context.next_message: self._logger.debug("Using custom message, bypassing adversarial chat") # Duplicate to ensure fresh IDs (avoids conflicts if message was already in memory) - message = context.next_message.duplicate_message() + message = context.next_message.duplicate() context.next_message = None # Clear for future turns return message diff --git a/pyrit/executor/attack/multi_turn/tree_of_attacks.py b/pyrit/executor/attack/multi_turn/tree_of_attacks.py index ee21855644..04b5aef915 100644 --- a/pyrit/executor/attack/multi_turn/tree_of_attacks.py +++ b/pyrit/executor/attack/multi_turn/tree_of_attacks.py @@ -608,7 +608,7 @@ async def 
_send_initial_prompt_to_target_async(self) -> Message: self.objective_target_conversation_id = str(uuid.uuid4()) # Duplicate to ensure fresh IDs (avoids conflicts if message was already in memory) - message = self._initial_prompt.duplicate_message() + message = self._initial_prompt.duplicate() self._initial_prompt = None # Clear for future turns # Store the prompt text for reference diff --git a/pyrit/executor/attack/single_turn/prompt_sending.py b/pyrit/executor/attack/single_turn/prompt_sending.py index e7842d3d72..f3c8aeedae 100644 --- a/pyrit/executor/attack/single_turn/prompt_sending.py +++ b/pyrit/executor/attack/single_turn/prompt_sending.py @@ -293,7 +293,7 @@ def _get_message(self, context: SingleTurnAttackContext[Any]) -> Message: """ if context.next_message: # Deep copy the message to preserve all fields, then assign new IDs - return context.next_message.duplicate_message() + return context.next_message.duplicate() return Message.from_prompt(prompt=context.objective, role="user") diff --git a/pyrit/executor/promptgen/fuzzer/fuzzer_converter_base.py b/pyrit/executor/promptgen/fuzzer/fuzzer_converter_base.py index cdcc6eaa33..5b211487cc 100644 --- a/pyrit/executor/promptgen/fuzzer/fuzzer_converter_base.py +++ b/pyrit/executor/promptgen/fuzzer/fuzzer_converter_base.py @@ -90,7 +90,7 @@ async def convert_async(self, *, prompt: str, input_type: PromptDataType = "text formatted_prompt = f"===={self.template_label} BEGINS====\n{prompt}\n===={self.template_label} ENDS====" prompt_metadata: dict[str, str | int] = {"response_format": "json"} request = Message( - [ + message_pieces=[ MessagePiece( role="user", original_value=formatted_prompt, diff --git a/pyrit/executor/promptgen/fuzzer/fuzzer_crossover_converter.py b/pyrit/executor/promptgen/fuzzer/fuzzer_crossover_converter.py index be1d4ed24b..7e68a9907d 100644 --- a/pyrit/executor/promptgen/fuzzer/fuzzer_crossover_converter.py +++ b/pyrit/executor/promptgen/fuzzer/fuzzer_crossover_converter.py @@ -92,7 +92,7 @@ 
async def convert_async(self, *, prompt: str, input_type: PromptDataType = "text prompt_metadata: dict[str, str | int] = {"response_format": "json"} request = Message( - [ + message_pieces=[ MessagePiece( role="user", original_value=formatted_prompt, diff --git a/pyrit/executor/promptgen/fuzzer/fuzzer_expand_converter.py b/pyrit/executor/promptgen/fuzzer/fuzzer_expand_converter.py index 28b1801a72..0c1f3fdf95 100644 --- a/pyrit/executor/promptgen/fuzzer/fuzzer_expand_converter.py +++ b/pyrit/executor/promptgen/fuzzer/fuzzer_expand_converter.py @@ -66,7 +66,7 @@ async def convert_async(self, *, prompt: str, input_type: PromptDataType = "text prompt_metadata: dict[str, str | int] = {"response_format": "json"} request = Message( - [ + message_pieces=[ MessagePiece( role="user", original_value=formatted_prompt, diff --git a/pyrit/memory/memory_interface.py b/pyrit/memory/memory_interface.py index 4af7955162..99c76bf29f 100644 --- a/pyrit/memory/memory_interface.py +++ b/pyrit/memory/memory_interface.py @@ -1009,7 +1009,7 @@ def duplicate_messages(self, *, messages: Sequence[Message]) -> tuple[str, Seque all_pieces: list[MessagePiece] = [] for message in messages: - duplicated_message = message.duplicate_message() + duplicated_message = message.duplicate() for piece in duplicated_message.message_pieces: piece.conversation_id = new_conversation_id diff --git a/pyrit/message_normalizer/chat_message_normalizer.py b/pyrit/message_normalizer/chat_message_normalizer.py index c5d3547e80..dd336563eb 100644 --- a/pyrit/message_normalizer/chat_message_normalizer.py +++ b/pyrit/message_normalizer/chat_message_normalizer.py @@ -14,7 +14,7 @@ apply_system_message_behavior, ) from pyrit.models import ChatMessage, DataTypeSerializer, Message -from pyrit.models.message_piece import MessagePiece +from pyrit.models.messages.message_piece import MessagePiece if TYPE_CHECKING: from pyrit.models.literals import ChatMessageRole diff --git a/pyrit/models/__init__.py b/pyrit/models/__init__.py 
index 5f570638d4..cb071a0b9a 100644 --- a/pyrit/models/__init__.py +++ b/pyrit/models/__init__.py @@ -63,13 +63,16 @@ validate_registry_name, ) from pyrit.models.literals import ChatMessageRole, Modality, PromptDataType, PromptResponseError, SeedType -from pyrit.models.message import ( +from pyrit.models.messages import ( Message, + MessagePiece, construct_response_from_request, + flatten_to_message_pieces, + get_all_values, group_conversation_message_pieces_by_sequence, group_message_pieces_into_conversations, + sort_message_pieces, ) -from pyrit.models.message_piece import MessagePiece, sort_message_pieces from pyrit.models.question_answering import QuestionAnsweringDataset, QuestionAnsweringEntry, QuestionChoice from pyrit.models.retry_event import RetryEvent from pyrit.models.scenario_result import ScenarioIdentifier, ScenarioResult @@ -128,7 +131,9 @@ "EmbeddingUsageInformation", "ErrorDataTypeSerializer", "EvaluationIdentifier", + "flatten_to_message_pieces", "get_all_harm_definitions", + "get_all_values", "group_conversation_message_pieces_by_sequence", "group_message_pieces_into_conversations", "HarmDefinition", diff --git a/pyrit/models/attack_result.py b/pyrit/models/attack_result.py index 839ebe44f0..eaafc67030 100644 --- a/pyrit/models/attack_result.py +++ b/pyrit/models/attack_result.py @@ -14,7 +14,7 @@ from pyrit.models.conversation_reference import ConversationReference, ConversationType from pyrit.models.identifiers.atomic_attack_identifier import build_atomic_attack_identifier from pyrit.models.identifiers.component_identifier import ComponentIdentifier -from pyrit.models.message_piece import MessagePiece +from pyrit.models.messages.message_piece import MessagePiece from pyrit.models.retry_event import RetryEvent from pyrit.models.score import Score from pyrit.models.strategy_result import StrategyResult diff --git a/pyrit/models/message.py b/pyrit/models/message.py index 0e1ca86203..27936dc04a 100644 --- a/pyrit/models/message.py +++ 
b/pyrit/models/message.py @@ -1,601 +1,37 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -from __future__ import annotations - -import copy -import uuid -from datetime import datetime, timezone -from typing import TYPE_CHECKING, Any, Optional, Union - -from pyrit.common.deprecation import print_deprecation_message -from pyrit.common.utils import combine_dict -from pyrit.models.message_piece import MessagePiece - -if TYPE_CHECKING: - from collections.abc import MutableSequence, Sequence - - from pyrit.models.literals import ChatMessageRole, PromptDataType, PromptResponseError - - -class Message: - """ - Represents a message in a conversation, for example a prompt or a response to a prompt. - - This is a single request to a target. It can contain multiple message pieces. - """ - - def __init__(self, message_pieces: Sequence[MessagePiece], *, skip_validation: Optional[bool] = False) -> None: - """ - Initialize a Message from one or more message pieces. - - Args: - message_pieces (Sequence[MessagePiece]): Pieces belonging to the same message turn. - skip_validation (Optional[bool]): Whether to skip consistency validation. - - Raises: - ValueError: If no message pieces are provided. - - """ - if not message_pieces: - raise ValueError("Message must have at least one message piece.") - self.message_pieces = message_pieces - if not skip_validation: - self.validate() - - def get_value(self, n: int = 0) -> str: - """ - Return the converted value of the nth message piece. - - Args: - n (int): Zero-based index of the piece to read. - - Returns: - str: Converted value of the selected message piece. - - Raises: - IndexError: If the index is out of bounds. - - """ - if n >= len(self.message_pieces): - raise IndexError(f"No message piece at index {n}.") - return self.message_pieces[n].converted_value - - def get_values(self) -> list[str]: - """ - Return the converted values of all message pieces. 
- - Returns: - list[str]: Converted values for all message pieces. - - """ - return [message_piece.converted_value for message_piece in self.message_pieces] - - def get_piece(self, n: int = 0) -> MessagePiece: - """ - Return the nth message piece. - - Args: - n (int): Zero-based index of the piece to return. - - Returns: - MessagePiece: Selected message piece. - - Raises: - ValueError: If the message has no pieces. - IndexError: If the index is out of bounds. - - """ - if len(self.message_pieces) == 0: - raise ValueError("Empty message pieces.") - - if n >= len(self.message_pieces): - raise IndexError(f"No message piece at index {n}.") - - return self.message_pieces[n] - - def get_pieces_by_type( - self, - *, - data_type: Optional[PromptDataType] = None, - original_value_data_type: Optional[PromptDataType] = None, - converted_value_data_type: Optional[PromptDataType] = None, - ) -> list[MessagePiece]: - """ - Return all message pieces matching the given data type. - - Args: - data_type: Alias for converted_value_data_type (for convenience). - original_value_data_type: The original_value_data_type to filter by. - converted_value_data_type: The converted_value_data_type to filter by. - - Returns: - A list of matching MessagePiece objects (may be empty). - - """ - effective_converted = converted_value_data_type or data_type - results = self.message_pieces - if effective_converted: - results = [p for p in results if p.converted_value_data_type == effective_converted] - if original_value_data_type: - results = [p for p in results if p.original_value_data_type == original_value_data_type] - return list(results) - - def get_piece_by_type( - self, - *, - data_type: Optional[PromptDataType] = None, - original_value_data_type: Optional[PromptDataType] = None, - converted_value_data_type: Optional[PromptDataType] = None, - ) -> Optional[MessagePiece]: - """ - Return the first message piece matching the given data type, or None. 
- - Args: - data_type: Alias for converted_value_data_type (for convenience). - original_value_data_type: The original_value_data_type to filter by. - converted_value_data_type: The converted_value_data_type to filter by. - - Returns: - The first matching MessagePiece, or None if no match is found. - - """ - pieces = self.get_pieces_by_type( - data_type=data_type, - original_value_data_type=original_value_data_type, - converted_value_data_type=converted_value_data_type, - ) - return pieces[0] if pieces else None - - @property - def api_role(self) -> ChatMessageRole: - """ - Return the API-compatible role of the first message piece. - - Maps simulated_assistant to assistant for API compatibility. - All message pieces in a Message should have the same role. - - Returns: - ChatMessageRole: Role compatible with external API calls. - - Raises: - ValueError: If the message has no pieces. - - """ - if len(self.message_pieces) == 0: - raise ValueError("Empty message pieces.") - return self.message_pieces[0].api_role - - @property - def is_simulated(self) -> bool: - """ - Check if this is a simulated assistant response. - - Simulated responses come from prepended conversations or generated - simulated conversations, not from actual target responses. - """ - if len(self.message_pieces) == 0: - return False - return self.message_pieces[0].is_simulated - - @property - def conversation_id(self) -> str: - """ - Return the conversation ID of the first request piece. - - Returns: - str: Conversation identifier. - - Raises: - ValueError: If the message has no pieces. - - """ - if len(self.message_pieces) == 0: - raise ValueError("Empty message pieces.") - return self.message_pieces[0].conversation_id - - @property - def sequence(self) -> int: - """ - Return the sequence value of the first request piece. - - Returns: - int: Sequence number for the message turn. - - Raises: - ValueError: If the message has no pieces. 
- - """ - if len(self.message_pieces) == 0: - raise ValueError("Empty message pieces.") - return self.message_pieces[0].sequence - - def is_error(self) -> bool: - """ - Check whether any message piece indicates an error. - - Returns: - bool: True when any piece has a non-none error flag or error data type. - - """ - for piece in self.message_pieces: - if piece.response_error != "none" or piece.converted_value_data_type == "error": - return True - return False - - def set_response_not_in_memory(self) -> None: - """ - Mark every piece in this message as ephemeral. - - This is needed when we're scoring prompts or other things that have not been sent by PyRIT. - Ephemeral pieces are skipped by ``add_message_pieces_to_memory``. - """ - for piece in self.message_pieces: - piece.not_in_memory = True - - def set_response_not_in_database(self) -> None: - """ - Mark every piece in this message as ephemeral (DEPRECATED — use ``set_response_not_in_memory``). - """ - print_deprecation_message( - old_item="Message.set_response_not_in_database()", - new_item="Message.set_response_not_in_memory()", - removed_in="0.16.0", - ) - self.set_response_not_in_memory() - - def set_simulated_role(self) -> None: - """ - Set the role of all message pieces to simulated_assistant. - - This marks the message as coming from a simulated conversation - rather than an actual target response. - """ - for piece in self.message_pieces: - if piece.role == "assistant": - piece.role = "simulated_assistant" - - def validate(self) -> None: - """ - Validate that all message pieces are internally consistent. - - Raises: - ValueError: If piece collection is empty or contains mismatched conversation IDs, - sequence numbers, roles, or missing converted values. 
- - """ - if len(self.message_pieces) == 0: - raise ValueError("Empty message pieces.") - - conversation_id = self.message_pieces[0].conversation_id - sequence = self.message_pieces[0].sequence - role = self.message_pieces[0].role - for message_piece in self.message_pieces: - if message_piece.conversation_id != conversation_id: - raise ValueError("Conversation ID mismatch.") - - if message_piece.sequence != sequence: - raise ValueError("Inconsistent sequences within the same message entry.") - - if message_piece.converted_value is None: - raise ValueError("Converted prompt text is None.") - - if message_piece.role != role: - raise ValueError("Inconsistent roles within the same message entry.") - - def __str__(self) -> str: - """ - Return a newline-delimited string representation of message pieces. - - Returns: - str: Concatenated string representation. - - """ - return "\n".join(f"{piece.role}: {piece.converted_value}" for piece in self.message_pieces) - - def to_dict(self) -> dict[str, object]: - """ - Convert the message to a dictionary representation. - - Includes the original top-level fields ('role', 'converted_value', 'conversation_id', - 'sequence', 'converted_value_data_type') for backward compatibility, plus a 'pieces' - list containing each piece's Pydantic JSON dump — the latter is the source of truth - used by from_dict(). - - Returns: - dict[str, object]: Dictionary with 'role', 'converted_value', 'conversation_id', - 'sequence', 'converted_value_data_type', and 'pieces' keys. 
- """ - if len(self.message_pieces) == 1: - converted_value: str | list[str] = self.message_pieces[0].converted_value - converted_value_data_type: str | list[str] = self.message_pieces[0].converted_value_data_type - else: - converted_value = [piece.converted_value for piece in self.message_pieces] - converted_value_data_type = [piece.converted_value_data_type for piece in self.message_pieces] - - return { - "role": self.api_role, - "converted_value": converted_value, - "conversation_id": self.conversation_id, - "sequence": self.sequence, - "converted_value_data_type": converted_value_data_type, - "pieces": [piece.model_dump(mode="json") for piece in self.message_pieces], - } - - @classmethod - def from_dict(cls, data: dict[str, Any]) -> Message: - """ - Reconstruct a Message from a dictionary. - - Expects the format produced by to_dict(), which includes a 'pieces' key - containing a list of MessagePiece dictionaries. - - Args: - data (dict[str, Any]): Dictionary as produced by to_dict(). - - Returns: - Message: Reconstructed instance. - """ - pieces_data = data.get("pieces", []) - message_pieces = [MessagePiece.model_validate(p) for p in pieces_data] - return cls(message_pieces, skip_validation=True) - - @staticmethod - def get_all_values(messages: Sequence[Message]) -> list[str]: - """ - Return all converted values across the provided messages. - - Args: - messages (Sequence[Message]): Messages to aggregate. - - Returns: - list[str]: Flattened list of converted values. - - """ - values: list[str] = [] - for message in messages: - values.extend(message.get_values()) - return values - - @staticmethod - def flatten_to_message_pieces( - messages: Sequence[Message], - ) -> MutableSequence[MessagePiece]: - """ - Flatten messages into a single list of message pieces. - - Args: - messages (Sequence[Message]): Messages to flatten. - - Returns: - MutableSequence[MessagePiece]: Flattened message pieces. 
- - """ - if not messages: - return [] - message_pieces: MutableSequence[MessagePiece] = [] - - for response in messages: - message_pieces.extend(response.message_pieces) - - return message_pieces - - @classmethod - def from_prompt( - cls, - *, - prompt: str, - role: ChatMessageRole, - prompt_metadata: Optional[dict[str, Union[str, int]]] = None, - ) -> Message: - """ - Build a single-piece message from prompt text. - - Args: - prompt (str): Prompt text. - role (ChatMessageRole): Role assigned to the message piece. - prompt_metadata (Optional[Dict[str, Union[str, int]]]): Optional prompt metadata. - - Returns: - Message: Constructed message instance. - - """ - piece = MessagePiece(original_value=prompt, role=role, prompt_metadata=prompt_metadata or {}) - return cls(message_pieces=[piece]) - - @classmethod - def from_system_prompt(cls, system_prompt: str) -> Message: - """ - Build a message from a system prompt. - - Args: - system_prompt (str): System instruction text. - - Returns: - Message: Constructed system-role message. - - """ - return cls.from_prompt(prompt=system_prompt, role="system") - - def duplicate_message(self) -> Message: - """ - Create a deep copy of this message with new IDs and timestamp for all message pieces. - - This is useful when you need to reuse a message template but want fresh IDs - to avoid database conflicts (e.g., during retry attempts). - - The original_prompt_id is intentionally kept the same to track the origin. - Generates a new timestamp to reflect when the duplicate is created. - - Returns: - Message: A new Message with deep-copied message pieces, new IDs, and fresh timestamp. 
- - """ - new_pieces = copy.deepcopy(self.message_pieces) - new_timestamp = datetime.now(tz=timezone.utc) - for piece in new_pieces: - piece.id = uuid.uuid4() - piece.timestamp = new_timestamp - # original_prompt_id intentionally kept the same to track the origin - return Message(message_pieces=new_pieces) - - -def group_conversation_message_pieces_by_sequence( - message_pieces: Sequence[MessagePiece], -) -> MutableSequence[Message]: - """ - Group message pieces from the same conversation into messages. - - This is done using the sequence number and conversation ID. - - Args: - message_pieces (Sequence[MessagePiece]): A list of MessagePiece objects representing individual - message pieces. - - Returns: - MutableSequence[Message]: A list of Message objects representing grouped message - pieces. This is ordered by the sequence number. - - Raises: - ValueError: If the conversation ID of any message piece does not match the conversation ID of the first - message piece. - - Example: - >>> message_pieces = [ - >>> MessagePiece(conversation_id=1, sequence=1, text="Given this list of creatures, which is your - >>> favorite:"), - >>> MessagePiece(conversation_id=1, sequence=2, text="Good question!"), - >>> MessagePiece(conversation_id=1, sequence=1, text="Raccoon, Narwhal, or Sloth?"), - >>> MessagePiece(conversation_id=1, sequence=2, text="I'd have to say raccoons are my favorite!"), - >>> ] - >>> grouped_responses = group_conversation_message_pieces(message_pieces) - ... [ - ... Message(message_pieces=[ - ... MessagePiece(conversation_id=1, sequence=1, text="Given this list of creatures, which is your - ... favorite:"), - ... MessagePiece(conversation_id=1, sequence=1, text="Raccoon, Narwhal, or Sloth?") - ... ]), - ... Message(message_pieces=[ - ... MessagePiece(conversation_id=1, sequence=2, text="Good question!"), - ... MessagePiece(conversation_id=1, sequence=2, text="I'd have to say raccoons are my favorite!") - ... ]) - ... 
] - - """ - if not message_pieces: - return [] - - conversation_id = message_pieces[0].conversation_id - - conversation_by_sequence: dict[int, list[MessagePiece]] = {} - - for message_piece in message_pieces: - if message_piece.conversation_id != conversation_id: - raise ValueError( - f"All message pieces must be from the same conversation. " - f"Expected conversation_id='{conversation_id}', but found '{message_piece.conversation_id}'. " - f"If grouping pieces from multiple conversations, group by conversation_id first." - ) - - if message_piece.sequence not in conversation_by_sequence: - conversation_by_sequence[message_piece.sequence] = [] - conversation_by_sequence[message_piece.sequence].append(message_piece) - - sorted_sequences = sorted(conversation_by_sequence.keys()) - return [Message(conversation_by_sequence[seq]) for seq in sorted_sequences] - - -def group_message_pieces_into_conversations( - message_pieces: Sequence[MessagePiece], -) -> list[list[Message]]: - """ - Group message pieces from multiple conversations into separate conversation groups. - - This function first groups pieces by conversation ID, then groups each conversation's - pieces by sequence number. Each conversation is returned as a separate list of - Message objects. - - Args: - message_pieces (Sequence[MessagePiece]): A list of MessagePiece objects from - potentially different conversations. - - Returns: - list[list[Message]]: A list of conversations, where each conversation is a list - of Message objects grouped by sequence. 
- - Example: - >>> message_pieces = [ - >>> MessagePiece(conversation_id="conv1", sequence=1, text="Hello"), - >>> MessagePiece(conversation_id="conv2", sequence=1, text="Hi there"), - >>> MessagePiece(conversation_id="conv1", sequence=2, text="How are you?"), - >>> MessagePiece(conversation_id="conv2", sequence=2, text="I'm good"), - >>> ] - >>> conversations = group_message_pieces_into_conversations(message_pieces) - >>> # Returns a list of 2 conversations: - >>> # [ - >>> # [Message(seq=1), Message(seq=2)], # conv1 - >>> # [Message(seq=1), Message(seq=2)] # conv2 - >>> # ] - - """ - if not message_pieces: - return [] - - # Group pieces by conversation ID - conversations: dict[str, list[MessagePiece]] = {} - for piece in message_pieces: - conv_id = piece.conversation_id - if conv_id not in conversations: - conversations[conv_id] = [] - conversations[conv_id].append(piece) - - # For each conversation, group by sequence - result: list[list[Message]] = [] - for conv_pieces in conversations.values(): - responses = group_conversation_message_pieces_by_sequence(conv_pieces) - result.append(list(responses)) - - return result - - -def construct_response_from_request( - request: MessagePiece, - response_text_pieces: list[str], - response_type: PromptDataType = "text", - prompt_metadata: Optional[dict[str, Union[str, int]]] = None, - error: PromptResponseError = "none", -) -> Message: - """ - Construct a response message from a request message piece. - - Args: - request (MessagePiece): Source request message piece. - response_text_pieces (list[str]): Response values to include. - response_type (PromptDataType): Data type for original and converted response values. - prompt_metadata (Optional[Dict[str, Union[str, int]]]): Additional metadata to merge. - error (PromptResponseError): Error classification for the response. - - Returns: - Message: Constructed response message. 
- - """ - if request.prompt_metadata: - prompt_metadata = combine_dict(request.prompt_metadata, prompt_metadata or {}) - - return Message( - message_pieces=[ - MessagePiece( - role="assistant", - original_value=resp_text, - conversation_id=request.conversation_id, - labels=request.labels, - prompt_target_identifier=request.prompt_target_identifier, - attack_identifier=request.attack_identifier, - original_value_data_type=response_type, - converted_value_data_type=response_type, - prompt_metadata=prompt_metadata or {}, - response_error=error, - ) - for resp_text in response_text_pieces - ] - ) +""" +Backward-compatibility shim. + +``Message`` and the conversation helpers now live in ``pyrit.models.messages``. +Import from there (or from ``pyrit.models``) instead. This module re-exports the +public names so existing ``from pyrit.models.message import ...`` imports keep +working. +""" + +from typing import Any + +from pyrit.models.messages import message as _message +from pyrit.models.messages.conversations import ( + construct_response_from_request, + flatten_to_message_pieces, + get_all_values, + group_conversation_message_pieces_by_sequence, + group_message_pieces_into_conversations, +) +from pyrit.models.messages.message import Message + + +def __getattr__(name: str) -> Any: + return getattr(_message, name) + + +__all__ = [ + "Message", + "construct_response_from_request", + "flatten_to_message_pieces", + "get_all_values", + "group_conversation_message_pieces_by_sequence", + "group_message_pieces_into_conversations", +] diff --git a/pyrit/models/message_piece.py b/pyrit/models/message_piece.py index 6c99f36c28..b5d92f3036 100644 --- a/pyrit/models/message_piece.py +++ b/pyrit/models/message_piece.py @@ -1,356 +1,22 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -from __future__ import annotations +""" +Backward-compatibility shim. 
-import uuid -from datetime import datetime, timezone -from typing import TYPE_CHECKING, Annotated, Any, Literal, Optional -from uuid import uuid4 +``MessagePiece`` now lives in ``pyrit.models.messages``. Import from there (or +from ``pyrit.models``) instead. This module re-exports the public names so +existing ``from pyrit.models.message_piece import ...`` imports keep working. +""" -from pydantic import ( - AwareDatetime, - BaseModel, - BeforeValidator, - ConfigDict, - Field, - PlainSerializer, - model_validator, -) +from typing import Any -from pyrit.common.deprecation import print_deprecation_message -from pyrit.models.data_type_serializer import data_serializer_factory -from pyrit.models.identifiers.component_identifier import ComponentIdentifier -from pyrit.models.literals import ( # noqa: TC001 (runtime-required by Pydantic field annotations) - ChatMessageRole, - PromptDataType, - PromptResponseError, -) -from pyrit.models.score import Score - -if TYPE_CHECKING: - from pyrit.models.message import Message - - -# Deprecated kwargs whose presence in ``MessagePiece(...)`` should emit a -# ``DeprecationWarning``. Each entry is ``(kwarg_name, removed_in)``. Kept here -# (rather than embedded in the validator body) to make the deprecation surface -# easy to read and update. -# -# These can be deleted entirely once their ``removed_in`` releases ship — the -# Pydantic field definitions and ``extra="forbid"`` config will then reject -# the kwargs naturally. -_DEPRECATED_KWARGS: tuple[tuple[str, str], ...] = ( - ("labels", "0.16.0"), - ("scorer_identifier", "0.15.0"), - ("scores", "0.15.0"), - ("targeted_harm_categories", "0.15.0"), -) - - -# Annotated alias that round-trips identifier fields through the flat dict -# storage shape. ``ComponentIdentifier`` is a Pydantic model with a custom -# flat serializer; ``Score`` is still a plain class needing ``from_dict`` / -# ``to_dict``. Drop the ``Score`` alias once it becomes a Pydantic model. 
-ComponentIdentifierField = Annotated[ - ComponentIdentifier, - BeforeValidator(lambda v: ComponentIdentifier.model_validate(v) if isinstance(v, dict) else v), - PlainSerializer(lambda v: v.model_dump() if v is not None else None, return_type=Optional[dict]), -] - -ScoreField = Annotated[ - Score, - BeforeValidator(lambda v: Score.from_dict(v) if isinstance(v, dict) else v), - PlainSerializer(lambda v: v.to_dict(), return_type=dict), -] +from pyrit.models.messages import message_piece as _message_piece +from pyrit.models.messages.message_piece import MessagePiece, sort_message_pieces def __getattr__(name: str) -> Any: - """ - Lazily resolve deprecated module-level aliases. - - Args: - name: The attribute name being accessed. - - Returns: - The resolved alias (currently only ``Originator``). - - Raises: - AttributeError: If ``name`` is not a known deprecated alias. - """ - if name == "Originator": - print_deprecation_message( - old_item="pyrit.models.message_piece.Originator", - new_item=( - "inline Literal['attack', 'converter', 'undefined', 'scorer'] " - "(the type alias is being removed; the originator field itself is " - "deprecated and will be removed in 0.15.0)" - ), - removed_in="0.15.0", - ) - return Literal["attack", "converter", "undefined", "scorer"] - raise AttributeError(f"module {__name__!r} has no attribute {name!r}") - - -class MessagePiece(BaseModel): - """ - A single piece of a message exchanged with a target. - - Targets that accept multimodal input (e.g., text + image) are represented - as a list of ``MessagePiece`` instances grouped under one - ``Message``. 
- """ - - model_config = ConfigDict( - arbitrary_types_allowed=True, - extra="forbid", - validate_assignment=False, - ) - - id: uuid.UUID = Field(default_factory=uuid4) # noqa: A003 - role: ChatMessageRole - conversation_id: str = Field(default_factory=lambda: str(uuid4())) - sequence: int = -1 - timestamp: AwareDatetime = Field(default_factory=lambda: datetime.now(tz=timezone.utc)) - original_value: str - original_value_data_type: PromptDataType = "text" - original_value_sha256: Optional[str] = None - converted_value: str = "" - converted_value_data_type: PromptDataType = "text" - converted_value_sha256: Optional[str] = None - response_error: PromptResponseError = "none" - originator: Literal["attack", "converter", "undefined", "scorer"] = "undefined" - original_prompt_id: Optional[uuid.UUID] = None - labels: dict[str, Any] = Field(default_factory=dict) - targeted_harm_categories: list[str] = Field(default_factory=list) - prompt_metadata: dict[str, Any] = Field(default_factory=dict) - converter_identifiers: list[ComponentIdentifierField] = Field(default_factory=list) - prompt_target_identifier: Optional[ComponentIdentifierField] = None - attack_identifier: Optional[ComponentIdentifierField] = None - scorer_identifier: Optional[ComponentIdentifierField] = None - scores: list[ScoreField] = Field(default_factory=list) - - # When True, the memory layer skips persisting this piece. Used for ephemeral - # pieces a scorer creates to score arbitrary content; ``exclude=True`` keeps - # the flag out of JSON / memory schema serialization. Named ``not_in_memory`` - # to match PyRIT's ``add_*_to_memory`` API verbs. 
- not_in_memory: bool = Field(default=False, exclude=True) - - # ------------------------------------------------------------------ # - # Validators - # ------------------------------------------------------------------ # - @model_validator(mode="before") - @classmethod - def _warn_on_deprecated_kwargs(cls, data: Any) -> Any: - """ - Emit DeprecationWarning for each deprecated kwarg explicitly passed. - - Returns: - The (unchanged) input ``data`` so validation can continue. - """ - if not isinstance(data, dict): - return data - for kwarg, removed_in in _DEPRECATED_KWARGS: - if data.get(kwarg) is not None: - print_deprecation_message( - old_item=f"MessagePiece(..., {kwarg}=...)", - new_item="MessagePiece(...)", - removed_in=removed_in, - ) - # ``originator`` is special: only warn when the caller explicitly - # opts into a non-default value. - if data.get("originator", "undefined") != "undefined": - print_deprecation_message( - old_item="MessagePiece(..., originator=...)", - new_item="MessagePiece(...)", - removed_in="0.15.0", - ) - return data - - @model_validator(mode="before") - @classmethod - def _mirror_original_to_converted(cls, data: Any) -> Any: - """ - When ``converted_value`` / ``converted_value_data_type`` aren't supplied, mirror the originals. - - Returns: - The input ``data`` with mirrored converted fields applied. - """ - if not isinstance(data, dict): - return data - if not data.get("converted_value") and "original_value" in data: - data["converted_value"] = data["original_value"] - if not data.get("converted_value_data_type") and "original_value_data_type" in data: - data["converted_value_data_type"] = data["original_value_data_type"] - return data - - @model_validator(mode="after") - def _set_original_prompt_id_default(self) -> MessagePiece: - """ - Enforce invariant: ``original_prompt_id == id`` for non-duplicate pieces. - - Returns: - ``self`` (with ``original_prompt_id`` populated when previously ``None``). 
- """ - if self.original_prompt_id is None: - self.original_prompt_id = self.id - return self - - # ------------------------------------------------------------------ # - # Public API - # ------------------------------------------------------------------ # - @property - def api_role(self) -> ChatMessageRole: - """ - Role to use for API calls. - - Maps ``simulated_assistant`` to ``assistant`` for API compatibility. - Use this property when sending messages to external APIs. - """ - return "assistant" if self.role == "simulated_assistant" else self.role - - @property - def is_simulated(self) -> bool: - """Whether this piece represents a simulated assistant response.""" - return self.role == "simulated_assistant" - - def to_message(self) -> Message: - """ - Wrap this piece in a single-piece ``Message``. - - Returns: - A new ``Message`` containing only this piece. - """ - # Deferred import: ``pyrit.models.message`` imports ``MessagePiece`` at - # module load, so a top-level import here would deadlock the cycle. - from pyrit.models.message import Message - - return Message([self]) - - def copy_lineage_from(self, *, source: MessagePiece) -> None: - """ - Copy lineage metadata from ``source`` onto this piece. - - Lineage fields are the metadata that tie a piece back to its originating - conversation, attack, and target. Mutable containers (``labels``, - ``prompt_metadata``) are shallow-copied so that mutations on one piece - do not affect others. - - Args: - source: The piece whose lineage will be copied onto ``self``. - """ - self.conversation_id = source.conversation_id - self.labels = dict(source.labels) - self.attack_identifier = source.attack_identifier - self.prompt_target_identifier = source.prompt_target_identifier - self.prompt_metadata = dict(source.prompt_metadata) - - def has_error(self) -> bool: - """ - Return ``True`` when ``response_error`` is not ``"none"``. - - Returns: - ``True`` if the piece carries any non-``"none"`` error code. 
- """ - return self.response_error != "none" - - def is_blocked(self) -> bool: - """ - Return ``True`` when ``response_error`` is ``"blocked"``. - - Returns: - ``True`` if the response was blocked by the target / content filter. - """ - return self.response_error == "blocked" - - # ------------------------------------------------------------------ # - # Deprecated method shims (removed in 0.16.0) - # ------------------------------------------------------------------ # - def to_dict(self) -> dict[str, Any]: - """ - Return a JSON-mode dict representation (DEPRECATED — use ``model_dump``). - - Returns: - A JSON-mode dict representation of the piece (same as - ``self.model_dump(mode="json")``). - """ - print_deprecation_message( - old_item="MessagePiece.to_dict()", - new_item='MessagePiece.model_dump(mode="json")', - removed_in="0.16.0", - ) - return self.model_dump(mode="json") - - @classmethod - def from_dict(cls, data: dict[str, Any]) -> MessagePiece: - """ - Construct a MessagePiece from a dict (DEPRECATED — use ``model_validate``). - - Args: - data: A dict matching the MessagePiece field schema. - - Returns: - A new ``MessagePiece`` (same as ``cls.model_validate(data)``). - """ - print_deprecation_message( - old_item="MessagePiece.from_dict()", - new_item="MessagePiece.model_validate()", - removed_in="0.16.0", - ) - return cls.model_validate(data) - - def set_piece_not_in_database(self) -> None: - """ - Mark this piece as ephemeral (DEPRECATED — set ``not_in_memory`` directly). - - Example:: - - piece.not_in_memory = True - """ - print_deprecation_message( - old_item="MessagePiece.set_piece_not_in_database()", - new_item="MessagePiece.not_in_memory = True", - removed_in="0.16.0", - ) - self.not_in_memory = True - - async def set_sha256_values_async(self) -> None: - """ - Compute SHA256 hash values for original and converted payloads. - - Async because blob payloads may need to be fetched. Must be called - explicitly after construction. 
- """ - original_serializer = data_serializer_factory( - category="prompt-memory-entries", - data_type=self.original_value_data_type, - value=self.original_value, - ) - self.original_value_sha256 = await original_serializer.get_sha256() - - converted_serializer = data_serializer_factory( - category="prompt-memory-entries", - data_type=self.converted_value_data_type, - value=self.converted_value, - ) - self.converted_value_sha256 = await converted_serializer.get_sha256() - - -def sort_message_pieces(message_pieces: list[MessagePiece]) -> list[MessagePiece]: - """ - Group by ``conversation_id``, ordering by earliest timestamp then ``sequence``. - - Conversations are ordered by their earliest piece's timestamp; pieces - within a conversation are ordered by ``sequence``. + return getattr(_message_piece, name) - Args: - message_pieces: The pieces to sort. Not mutated. - Returns: - A new list containing the same pieces in deterministic order. - """ - earliest_timestamps = { - convo_id: min(x.timestamp for x in message_pieces if x.conversation_id == convo_id) - for convo_id in {x.conversation_id for x in message_pieces} - } - return sorted(message_pieces, key=lambda x: (earliest_timestamps[x.conversation_id], x.conversation_id, x.sequence)) +__all__ = ["MessagePiece", "sort_message_pieces"] diff --git a/pyrit/models/messages/__init__.py b/pyrit/models/messages/__init__.py new file mode 100644 index 0000000000..fca91f47ba --- /dev/null +++ b/pyrit/models/messages/__init__.py @@ -0,0 +1,31 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Messages module - message types and helpers for PyRIT. + +- MessagePiece: A single piece of a message exchanged with a target. +- Message: One request/response to a target, made up of one or more pieces. +- conversations: Free functions that operate on collections of messages/pieces. 
+""" + +from pyrit.models.messages.conversations import ( + construct_response_from_request, + flatten_to_message_pieces, + get_all_values, + group_conversation_message_pieces_by_sequence, + group_message_pieces_into_conversations, +) +from pyrit.models.messages.message import Message +from pyrit.models.messages.message_piece import MessagePiece, sort_message_pieces + +__all__ = [ + "Message", + "MessagePiece", + "construct_response_from_request", + "flatten_to_message_pieces", + "get_all_values", + "group_conversation_message_pieces_by_sequence", + "group_message_pieces_into_conversations", + "sort_message_pieces", +] diff --git a/pyrit/models/messages/conversations.py b/pyrit/models/messages/conversations.py new file mode 100644 index 0000000000..2b5224f6c6 --- /dev/null +++ b/pyrit/models/messages/conversations.py @@ -0,0 +1,217 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +"""Helpers that operate on collections of ``Message`` / ``MessagePiece``.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Optional, Union + +from pyrit.models.messages.message import Message +from pyrit.models.messages.message_piece import MessagePiece + +if TYPE_CHECKING: + from collections.abc import MutableSequence, Sequence + + from pyrit.models.literals import PromptDataType, PromptResponseError + + +def get_all_values(messages: Sequence[Message]) -> list[str]: + """ + Return all converted values across the provided messages. + + Args: + messages (Sequence[Message]): Messages to aggregate. + + Returns: + list[str]: Flattened list of converted values. + + """ + values: list[str] = [] + for message in messages: + values.extend(message.get_values()) + return values + + +def flatten_to_message_pieces( + messages: Sequence[Message], +) -> MutableSequence[MessagePiece]: + """ + Flatten messages into a single list of message pieces. + + Args: + messages (Sequence[Message]): Messages to flatten. 
+ + Returns: + MutableSequence[MessagePiece]: Flattened message pieces. + + """ + if not messages: + return [] + message_pieces: MutableSequence[MessagePiece] = [] + + for response in messages: + message_pieces.extend(response.message_pieces) + + return message_pieces + + +def group_conversation_message_pieces_by_sequence( + message_pieces: Sequence[MessagePiece], +) -> MutableSequence[Message]: + """ + Group message pieces from the same conversation into messages. + + This is done using the sequence number and conversation ID. + + Args: + message_pieces (Sequence[MessagePiece]): A list of MessagePiece objects representing individual + message pieces. + + Returns: + MutableSequence[Message]: A list of Message objects representing grouped message + pieces. This is ordered by the sequence number. + + Raises: + ValueError: If the conversation ID of any message piece does not match the conversation ID of the first + message piece. + + Example: + >>> message_pieces = [ + >>> MessagePiece(conversation_id=1, sequence=1, text="Given this list of creatures, which is your + >>> favorite:"), + >>> MessagePiece(conversation_id=1, sequence=2, text="Good question!"), + >>> MessagePiece(conversation_id=1, sequence=1, text="Raccoon, Narwhal, or Sloth?"), + >>> MessagePiece(conversation_id=1, sequence=2, text="I'd have to say raccoons are my favorite!"), + >>> ] + >>> grouped_responses = group_conversation_message_pieces_by_sequence(message_pieces) + ... [ + ... Message(message_pieces=[ + ... MessagePiece(conversation_id=1, sequence=1, text="Given this list of creatures, which is your + ... favorite:"), + ... MessagePiece(conversation_id=1, sequence=1, text="Raccoon, Narwhal, or Sloth?") + ... ]), + ... Message(message_pieces=[ + ... MessagePiece(conversation_id=1, sequence=2, text="Good question!"), + ... MessagePiece(conversation_id=1, sequence=2, text="I'd have to say raccoons are my favorite!") + ... ]) + ... 
] + + """ + if not message_pieces: + return [] + + conversation_id = message_pieces[0].conversation_id + + conversation_by_sequence: dict[int, list[MessagePiece]] = {} + + for message_piece in message_pieces: + if message_piece.conversation_id != conversation_id: + raise ValueError( + f"All message pieces must be from the same conversation. " + f"Expected conversation_id='{conversation_id}', but found '{message_piece.conversation_id}'. " + f"If grouping pieces from multiple conversations, group by conversation_id first." + ) + + if message_piece.sequence not in conversation_by_sequence: + conversation_by_sequence[message_piece.sequence] = [] + conversation_by_sequence[message_piece.sequence].append(message_piece) + + sorted_sequences = sorted(conversation_by_sequence.keys()) + return [Message(message_pieces=conversation_by_sequence[seq]) for seq in sorted_sequences] + + +def group_message_pieces_into_conversations( + message_pieces: Sequence[MessagePiece], +) -> list[list[Message]]: + """ + Group message pieces from multiple conversations into separate conversation groups. + + This function first groups pieces by conversation ID, then groups each conversation's + pieces by sequence number. Each conversation is returned as a separate list of + Message objects. + + Args: + message_pieces (Sequence[MessagePiece]): A list of MessagePiece objects from + potentially different conversations. + + Returns: + list[list[Message]]: A list of conversations, where each conversation is a list + of Message objects grouped by sequence. 
+ + Example: + >>> message_pieces = [ + >>> MessagePiece(conversation_id="conv1", sequence=1, text="Hello"), + >>> MessagePiece(conversation_id="conv2", sequence=1, text="Hi there"), + >>> MessagePiece(conversation_id="conv1", sequence=2, text="How are you?"), + >>> MessagePiece(conversation_id="conv2", sequence=2, text="I'm good"), + >>> ] + >>> conversations = group_message_pieces_into_conversations(message_pieces) + >>> # Returns a list of 2 conversations: + >>> # [ + >>> # [Message(seq=1), Message(seq=2)], # conv1 + >>> # [Message(seq=1), Message(seq=2)] # conv2 + >>> # ] + + """ + if not message_pieces: + return [] + + # Group pieces by conversation ID + conversations: dict[str, list[MessagePiece]] = {} + for piece in message_pieces: + conv_id = piece.conversation_id + if conv_id not in conversations: + conversations[conv_id] = [] + conversations[conv_id].append(piece) + + # For each conversation, group by sequence + result: list[list[Message]] = [] + for conv_pieces in conversations.values(): + responses = group_conversation_message_pieces_by_sequence(conv_pieces) + result.append(list(responses)) + + return result + + +def construct_response_from_request( + request: MessagePiece, + response_text_pieces: list[str], + response_type: PromptDataType = "text", + prompt_metadata: Optional[dict[str, Union[str, int]]] = None, + error: PromptResponseError = "none", +) -> Message: + """ + Construct a response message from a request message piece. + + Args: + request (MessagePiece): Source request message piece. + response_text_pieces (list[str]): Response values to include. + response_type (PromptDataType): Data type for original and converted response values. + prompt_metadata (Optional[Dict[str, Union[str, int]]]): Additional metadata to merge. + error (PromptResponseError): Error classification for the response. + + Returns: + Message: Constructed response message. 
+ + """ + if request.prompt_metadata: + prompt_metadata = {**request.prompt_metadata, **(prompt_metadata or {})} + + return Message( + message_pieces=[ + MessagePiece( + role="assistant", + original_value=resp_text, + conversation_id=request.conversation_id, + labels=request.labels, + prompt_target_identifier=request.prompt_target_identifier, + attack_identifier=request.attack_identifier, + original_value_data_type=response_type, + converted_value_data_type=response_type, + prompt_metadata=prompt_metadata or {}, + response_error=error, + ) + for resp_text in response_text_pieces + ] + ) diff --git a/pyrit/models/messages/message.py b/pyrit/models/messages/message.py new file mode 100644 index 0000000000..d14f1ec1b9 --- /dev/null +++ b/pyrit/models/messages/message.py @@ -0,0 +1,535 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from __future__ import annotations + +import copy +import uuid +from datetime import datetime, timezone +from typing import TYPE_CHECKING, Any, Optional, Union + +from pydantic import BaseModel, ConfigDict, model_validator + +from pyrit.common.deprecation import print_deprecation_message +from pyrit.models.messages.message_piece import MessagePiece + +if TYPE_CHECKING: + from collections.abc import MutableSequence, Sequence + + from pyrit.models.literals import ChatMessageRole, PromptDataType + + +class Message(BaseModel): + """ + Represents a message in a conversation, for example a prompt or a response to a prompt. + + This is a single request to a target. It can contain multiple message pieces. + """ + + model_config = ConfigDict( + arbitrary_types_allowed=True, + extra="forbid", + validate_assignment=False, + ) + + message_pieces: list[MessagePiece] + + def __init__(self, *args: Any, **data: Any) -> None: + """ + Initialize a Message from one or more message pieces. 
+ + Supports the canonical keyword form ``Message(message_pieces=[...])`` as + well as two deprecated forms that emit a ``DeprecationWarning``: + + - positional construction ``Message([piece, ...])`` + - the ``skip_validation`` keyword (now a no-op; validation always runs) + + Raises: + TypeError: If more than one positional argument is supplied. + ValueError: If no message pieces are provided (via validation). + """ + if args: + if len(args) > 1: + raise TypeError(f"Message() takes at most 1 positional argument but {len(args)} were given.") + print_deprecation_message( + old_item="Message(message_pieces) (positional)", + new_item="Message(message_pieces=...)", + removed_in="0.16.0", + ) + data["message_pieces"] = args[0] + if "skip_validation" in data: + data.pop("skip_validation") + print_deprecation_message( + old_item="Message(..., skip_validation=...)", + new_item="Message(message_pieces=...)", + removed_in="0.16.0", + ) + super().__init__(**data) + + # ------------------------------------------------------------------ # + # Validators + # ------------------------------------------------------------------ # + @model_validator(mode="before") + @classmethod + def _rewrite_legacy_dict(cls, data: Any) -> Any: + """ + Accept the legacy ``to_dict()`` payload shape during ``model_validate``. + + The legacy dict carries top-level convenience fields plus a ``pieces`` + list. Under ``extra="forbid"`` those extra keys would be rejected, so + collapse the payload down to ``{"message_pieces": [...]}``. + + Returns: + The normalized input ``data``. + """ + if isinstance(data, dict) and "pieces" in data and "message_pieces" not in data: + return {"message_pieces": data["pieces"]} + return data + + @model_validator(mode="after") + def _validate_after(self) -> Message: + """ + Enforce internal consistency of the message pieces after construction. + + Returns: + ``self``. 
+ """ + self._validate_invariants() + return self + + def _validate_invariants(self) -> None: + """ + Check that all message pieces are internally consistent. + + Raises: + ValueError: If the piece collection is empty or contains mismatched conversation IDs, + sequence numbers, roles, or missing converted values. + """ + if len(self.message_pieces) == 0: + raise ValueError("Message must have at least one message piece.") + + conversation_id = self.message_pieces[0].conversation_id + sequence = self.message_pieces[0].sequence + role = self.message_pieces[0].role + for message_piece in self.message_pieces: + if message_piece.conversation_id != conversation_id: + raise ValueError("Conversation ID mismatch.") + + if message_piece.sequence != sequence: + raise ValueError("Inconsistent sequences within the same message entry.") + + if message_piece.converted_value is None: + raise ValueError("Converted prompt text is None.") + + if message_piece.role != role: + raise ValueError("Inconsistent roles within the same message entry.") + + def validate(self) -> None: + """ + Validate that all message pieces are internally consistent. + + Retained as a public instance method because callers invoke + ``message.validate()`` directly. Shadows the deprecated + ``BaseModel.validate`` classmethod. + + Raises: + ValueError: If piece collection is empty or contains mismatched conversation IDs, + sequence numbers, roles, or missing converted values. + """ + self._validate_invariants() + + # ------------------------------------------------------------------ # + # Public API + # ------------------------------------------------------------------ # + def get_value(self, n: int = 0) -> str: + """ + Return the converted value of the nth message piece. + + Args: + n (int): Zero-based index of the piece to read. + + Returns: + str: Converted value of the selected message piece. + + Raises: + IndexError: If the index is out of bounds. 
+ + """ + if n >= len(self.message_pieces): + raise IndexError(f"No message piece at index {n}.") + return self.message_pieces[n].converted_value + + def get_values(self) -> list[str]: + """ + Return the converted values of all message pieces. + + Returns: + list[str]: Converted values for all message pieces. + + """ + return [message_piece.converted_value for message_piece in self.message_pieces] + + def get_piece(self, n: int = 0) -> MessagePiece: + """ + Return the nth message piece. + + Args: + n (int): Zero-based index of the piece to return. + + Returns: + MessagePiece: Selected message piece. + + Raises: + ValueError: If the message has no pieces. + IndexError: If the index is out of bounds. + + """ + if len(self.message_pieces) == 0: + raise ValueError("Empty message pieces.") + + if n >= len(self.message_pieces): + raise IndexError(f"No message piece at index {n}.") + + return self.message_pieces[n] + + def get_pieces_by_type( + self, + *, + data_type: Optional[PromptDataType] = None, + original_value_data_type: Optional[PromptDataType] = None, + converted_value_data_type: Optional[PromptDataType] = None, + ) -> list[MessagePiece]: + """ + Return all message pieces matching the given data type. + + Args: + data_type: Alias for converted_value_data_type (for convenience). + original_value_data_type: The original_value_data_type to filter by. + converted_value_data_type: The converted_value_data_type to filter by. + + Returns: + A list of matching MessagePiece objects (may be empty). 
+ + """ + effective_converted = converted_value_data_type or data_type + results = self.message_pieces + if effective_converted: + results = [p for p in results if p.converted_value_data_type == effective_converted] + if original_value_data_type: + results = [p for p in results if p.original_value_data_type == original_value_data_type] + return list(results) + + def get_piece_by_type( + self, + *, + data_type: Optional[PromptDataType] = None, + original_value_data_type: Optional[PromptDataType] = None, + converted_value_data_type: Optional[PromptDataType] = None, + ) -> Optional[MessagePiece]: + """ + Return the first message piece matching the given data type, or None. + + Args: + data_type: Alias for converted_value_data_type (for convenience). + original_value_data_type: The original_value_data_type to filter by. + converted_value_data_type: The converted_value_data_type to filter by. + + Returns: + The first matching MessagePiece, or None if no match is found. + + """ + pieces = self.get_pieces_by_type( + data_type=data_type, + original_value_data_type=original_value_data_type, + converted_value_data_type=converted_value_data_type, + ) + return pieces[0] if pieces else None + + @property + def api_role(self) -> ChatMessageRole: + """ + Return the API-compatible role of the first message piece. + + Maps simulated_assistant to assistant for API compatibility. + All message pieces in a Message should have the same role. + + Returns: + ChatMessageRole: Role compatible with external API calls. + + Raises: + ValueError: If the message has no pieces. + + """ + if len(self.message_pieces) == 0: + raise ValueError("Empty message pieces.") + return self.message_pieces[0].api_role + + @property + def is_simulated(self) -> bool: + """ + Check if this is a simulated assistant response. + + Simulated responses come from prepended conversations or generated + simulated conversations, not from actual target responses. 
+ """ + if len(self.message_pieces) == 0: + return False + return self.message_pieces[0].is_simulated + + @property + def conversation_id(self) -> str: + """ + Return the conversation ID of the first request piece. + + Returns: + str: Conversation identifier. + + Raises: + ValueError: If the message has no pieces. + + """ + if len(self.message_pieces) == 0: + raise ValueError("Empty message pieces.") + return self.message_pieces[0].conversation_id + + @property + def sequence(self) -> int: + """ + Return the sequence value of the first request piece. + + Returns: + int: Sequence number for the message turn. + + Raises: + ValueError: If the message has no pieces. + + """ + if len(self.message_pieces) == 0: + raise ValueError("Empty message pieces.") + return self.message_pieces[0].sequence + + def is_error(self) -> bool: + """ + Check whether any message piece indicates an error. + + Returns: + bool: True when any piece has a non-none error flag or error data type. + + """ + for piece in self.message_pieces: + if piece.response_error != "none" or piece.converted_value_data_type == "error": + return True + return False + + def set_response_not_in_memory(self) -> None: + """ + Mark every piece in this message as ephemeral. + + This is needed when we're scoring prompts or other things that have not been sent by PyRIT. + Ephemeral pieces are skipped by ``add_message_pieces_to_memory``. + """ + for piece in self.message_pieces: + piece.not_in_memory = True + + def set_simulated_role(self) -> None: + """ + Set the role of all message pieces to simulated_assistant. + + This marks the message as coming from a simulated conversation + rather than an actual target response. + """ + for piece in self.message_pieces: + if piece.role == "assistant": + piece.role = "simulated_assistant" + + def __str__(self) -> str: + """ + Return a newline-delimited string representation of message pieces. + + Returns: + str: Concatenated string representation. 
+ + """ + return "\n".join(f"{piece.role}: {piece.converted_value}" for piece in self.message_pieces) + + @classmethod + def from_prompt( + cls, + *, + prompt: str, + role: ChatMessageRole, + prompt_metadata: Optional[dict[str, Union[str, int]]] = None, + ) -> Message: + """ + Build a single-piece message from prompt text. + + Args: + prompt (str): Prompt text. + role (ChatMessageRole): Role assigned to the message piece. + prompt_metadata (Optional[Dict[str, Union[str, int]]]): Optional prompt metadata. + + Returns: + Message: Constructed message instance. + + """ + piece = MessagePiece(original_value=prompt, role=role, prompt_metadata=prompt_metadata or {}) + return cls(message_pieces=[piece]) + + @classmethod + def from_system_prompt(cls, system_prompt: str) -> Message: + """ + Build a message from a system prompt. + + Args: + system_prompt (str): System instruction text. + + Returns: + Message: Constructed system-role message. + + """ + return cls.from_prompt(prompt=system_prompt, role="system") + + def duplicate(self) -> Message: + """ + Create a deep copy of this message with new IDs and timestamp for all message pieces. + + This is useful when you need to reuse a message template but want fresh IDs + to avoid database conflicts (e.g., during retry attempts). + + The original_prompt_id is intentionally kept the same to track the origin. + Generates a new timestamp to reflect when the duplicate is created. + + Returns: + Message: A new Message with deep-copied message pieces, new IDs, and fresh timestamp. 
+ + """ + new_pieces = copy.deepcopy(list(self.message_pieces)) + new_timestamp = datetime.now(tz=timezone.utc) + for piece in new_pieces: + piece.id = uuid.uuid4() + piece.timestamp = new_timestamp + # original_prompt_id intentionally kept the same to track the origin + return Message(message_pieces=new_pieces) + + # ------------------------------------------------------------------ # + # Deprecated method shims (removed in 0.16.0) + # ------------------------------------------------------------------ # + def set_response_not_in_database(self) -> None: + """ + Mark every piece in this message as ephemeral (DEPRECATED — use ``set_response_not_in_memory``). + """ + print_deprecation_message( + old_item="Message.set_response_not_in_database()", + new_item="Message.set_response_not_in_memory()", + removed_in="0.16.0", + ) + self.set_response_not_in_memory() + + def duplicate_message(self) -> Message: + """ + Create a deep copy of this message (DEPRECATED — use ``duplicate``). + + Returns: + Message: A new Message with deep-copied pieces, new IDs, and fresh timestamp. + """ + print_deprecation_message( + old_item="Message.duplicate_message()", + new_item="Message.duplicate()", + removed_in="0.16.0", + ) + return self.duplicate() + + def to_dict(self) -> dict[str, object]: + """ + Convert the message to a dictionary representation (DEPRECATED — use ``model_dump``). + + Includes the original top-level fields ('role', 'converted_value', 'conversation_id', + 'sequence', 'converted_value_data_type') for backward compatibility, plus a 'pieces' + list containing each piece's Pydantic JSON dump. + + Returns: + dict[str, object]: Dictionary with 'role', 'converted_value', 'conversation_id', + 'sequence', 'converted_value_data_type', and 'pieces' keys. 
+ """ + print_deprecation_message( + old_item="Message.to_dict()", + new_item='Message.model_dump(mode="json")', + removed_in="0.16.0", + ) + if len(self.message_pieces) == 1: + converted_value: str | list[str] = self.message_pieces[0].converted_value + converted_value_data_type: str | list[str] = self.message_pieces[0].converted_value_data_type + else: + converted_value = [piece.converted_value for piece in self.message_pieces] + converted_value_data_type = [piece.converted_value_data_type for piece in self.message_pieces] + + return { + "role": self.api_role, + "converted_value": converted_value, + "conversation_id": self.conversation_id, + "sequence": self.sequence, + "converted_value_data_type": converted_value_data_type, + "pieces": [piece.model_dump(mode="json") for piece in self.message_pieces], + } + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> Message: + """ + Reconstruct a Message from a dictionary (DEPRECATED — use ``model_validate``). + + Args: + data (dict[str, Any]): Dictionary as produced by ``to_dict()``. + + Returns: + Message: Reconstructed instance. + """ + print_deprecation_message( + old_item="Message.from_dict()", + new_item="Message.model_validate()", + removed_in="0.16.0", + ) + return cls.model_validate(data) + + @staticmethod + def get_all_values(messages: Sequence[Message]) -> list[str]: + """ + Return all converted values across the provided messages (DEPRECATED — use the module function). + + Args: + messages (Sequence[Message]): Messages to aggregate. + + Returns: + list[str]: Flattened list of converted values. 
+ + """ + print_deprecation_message( + old_item="Message.get_all_values()", + new_item="pyrit.models.get_all_values()", + removed_in="0.16.0", + ) + from pyrit.models.messages.conversations import get_all_values as _get_all_values + + return _get_all_values(messages) + + @staticmethod + def flatten_to_message_pieces( + messages: Sequence[Message], + ) -> MutableSequence[MessagePiece]: + """ + Flatten messages into a single list of message pieces (DEPRECATED — use the module function). + + Args: + messages (Sequence[Message]): Messages to flatten. + + Returns: + MutableSequence[MessagePiece]: Flattened message pieces. + + """ + print_deprecation_message( + old_item="Message.flatten_to_message_pieces()", + new_item="pyrit.models.flatten_to_message_pieces()", + removed_in="0.16.0", + ) + from pyrit.models.messages.conversations import flatten_to_message_pieces as _flatten + + return _flatten(messages) diff --git a/pyrit/models/messages/message_piece.py b/pyrit/models/messages/message_piece.py new file mode 100644 index 0000000000..98f5d9b026 --- /dev/null +++ b/pyrit/models/messages/message_piece.py @@ -0,0 +1,356 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
+ +from __future__ import annotations + +import uuid +from datetime import datetime, timezone +from typing import TYPE_CHECKING, Annotated, Any, Literal, Optional +from uuid import uuid4 + +from pydantic import ( + AwareDatetime, + BaseModel, + BeforeValidator, + ConfigDict, + Field, + PlainSerializer, + model_validator, +) + +from pyrit.common.deprecation import print_deprecation_message +from pyrit.models.data_type_serializer import data_serializer_factory +from pyrit.models.identifiers.component_identifier import ComponentIdentifier +from pyrit.models.literals import ( # noqa: TC001 (runtime-required by Pydantic field annotations) + ChatMessageRole, + PromptDataType, + PromptResponseError, +) +from pyrit.models.score import Score + +if TYPE_CHECKING: + from pyrit.models.messages.message import Message + + +# Deprecated kwargs whose presence in ``MessagePiece(...)`` should emit a +# ``DeprecationWarning``. Each entry is ``(kwarg_name, removed_in)``. Kept here +# (rather than embedded in the validator body) to make the deprecation surface +# easy to read and update. +# +# These can be deleted entirely once their ``removed_in`` releases ship — the +# Pydantic field definitions and ``extra="forbid"`` config will then reject +# the kwargs naturally. +_DEPRECATED_KWARGS: tuple[tuple[str, str], ...] = ( + ("labels", "0.16.0"), + ("scorer_identifier", "0.15.0"), + ("scores", "0.15.0"), + ("targeted_harm_categories", "0.15.0"), +) + + +# Annotated alias that round-trips identifier fields through the flat dict +# storage shape. ``ComponentIdentifier`` is a Pydantic model with a custom +# flat serializer; ``Score`` is still a plain class needing ``from_dict`` / +# ``to_dict``. Drop the ``Score`` alias once it becomes a Pydantic model. 
+ComponentIdentifierField = Annotated[ + ComponentIdentifier, + BeforeValidator(lambda v: ComponentIdentifier.model_validate(v) if isinstance(v, dict) else v), + PlainSerializer(lambda v: v.model_dump() if v is not None else None, return_type=Optional[dict]), +] + +ScoreField = Annotated[ + Score, + BeforeValidator(lambda v: Score.from_dict(v) if isinstance(v, dict) else v), + PlainSerializer(lambda v: v.to_dict(), return_type=dict), +] + + +def __getattr__(name: str) -> Any: + """ + Lazily resolve deprecated module-level aliases. + + Args: + name: The attribute name being accessed. + + Returns: + The resolved alias (currently only ``Originator``). + + Raises: + AttributeError: If ``name`` is not a known deprecated alias. + """ + if name == "Originator": + print_deprecation_message( + old_item="pyrit.models.message_piece.Originator", + new_item=( + "inline Literal['attack', 'converter', 'undefined', 'scorer'] " + "(the type alias is being removed; the originator field itself is " + "deprecated and will be removed in 0.15.0)" + ), + removed_in="0.15.0", + ) + return Literal["attack", "converter", "undefined", "scorer"] + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + + +class MessagePiece(BaseModel): + """ + A single piece of a message exchanged with a target. + + Targets that accept multimodal input (e.g., text + image) are represented + as a list of ``MessagePiece`` instances grouped under one + ``Message``. 
+ """ + + model_config = ConfigDict( + arbitrary_types_allowed=True, + extra="forbid", + validate_assignment=False, + ) + + id: uuid.UUID = Field(default_factory=uuid4) # noqa: A003 + role: ChatMessageRole + conversation_id: str = Field(default_factory=lambda: str(uuid4())) + sequence: int = -1 + timestamp: AwareDatetime = Field(default_factory=lambda: datetime.now(tz=timezone.utc)) + original_value: str + original_value_data_type: PromptDataType = "text" + original_value_sha256: Optional[str] = None + converted_value: str = "" + converted_value_data_type: PromptDataType = "text" + converted_value_sha256: Optional[str] = None + response_error: PromptResponseError = "none" + originator: Literal["attack", "converter", "undefined", "scorer"] = "undefined" + original_prompt_id: Optional[uuid.UUID] = None + labels: dict[str, Any] = Field(default_factory=dict) + targeted_harm_categories: list[str] = Field(default_factory=list) + prompt_metadata: dict[str, Any] = Field(default_factory=dict) + converter_identifiers: list[ComponentIdentifierField] = Field(default_factory=list) + prompt_target_identifier: Optional[ComponentIdentifierField] = None + attack_identifier: Optional[ComponentIdentifierField] = None + scorer_identifier: Optional[ComponentIdentifierField] = None + scores: list[ScoreField] = Field(default_factory=list) + + # When True, the memory layer skips persisting this piece. Used for ephemeral + # pieces a scorer creates to score arbitrary content; ``exclude=True`` keeps + # the flag out of JSON / memory schema serialization. Named ``not_in_memory`` + # to match PyRIT's ``add_*_to_memory`` API verbs. 
+ not_in_memory: bool = Field(default=False, exclude=True) + + # ------------------------------------------------------------------ # + # Validators + # ------------------------------------------------------------------ # + @model_validator(mode="before") + @classmethod + def _warn_on_deprecated_kwargs(cls, data: Any) -> Any: + """ + Emit DeprecationWarning for each deprecated kwarg explicitly passed. + + Returns: + The (unchanged) input ``data`` so validation can continue. + """ + if not isinstance(data, dict): + return data + for kwarg, removed_in in _DEPRECATED_KWARGS: + if data.get(kwarg) is not None: + print_deprecation_message( + old_item=f"MessagePiece(..., {kwarg}=...)", + new_item="MessagePiece(...)", + removed_in=removed_in, + ) + # ``originator`` is special: only warn when the caller explicitly + # opts into a non-default value. + if data.get("originator", "undefined") != "undefined": + print_deprecation_message( + old_item="MessagePiece(..., originator=...)", + new_item="MessagePiece(...)", + removed_in="0.15.0", + ) + return data + + @model_validator(mode="before") + @classmethod + def _mirror_original_to_converted(cls, data: Any) -> Any: + """ + When ``converted_value`` / ``converted_value_data_type`` aren't supplied, mirror the originals. + + Returns: + The input ``data`` with mirrored converted fields applied. + """ + if not isinstance(data, dict): + return data + if not data.get("converted_value") and "original_value" in data: + data["converted_value"] = data["original_value"] + if not data.get("converted_value_data_type") and "original_value_data_type" in data: + data["converted_value_data_type"] = data["original_value_data_type"] + return data + + @model_validator(mode="after") + def _set_original_prompt_id_default(self) -> MessagePiece: + """ + Enforce invariant: ``original_prompt_id == id`` for non-duplicate pieces. + + Returns: + ``self`` (with ``original_prompt_id`` populated when previously ``None``). 
+ """ + if self.original_prompt_id is None: + self.original_prompt_id = self.id + return self + + # ------------------------------------------------------------------ # + # Public API + # ------------------------------------------------------------------ # + @property + def api_role(self) -> ChatMessageRole: + """ + Role to use for API calls. + + Maps ``simulated_assistant`` to ``assistant`` for API compatibility. + Use this property when sending messages to external APIs. + """ + return "assistant" if self.role == "simulated_assistant" else self.role + + @property + def is_simulated(self) -> bool: + """Whether this piece represents a simulated assistant response.""" + return self.role == "simulated_assistant" + + def to_message(self) -> Message: + """ + Wrap this piece in a single-piece ``Message``. + + Returns: + A new ``Message`` containing only this piece. + """ + # Deferred import: ``pyrit.models.messages.message`` imports ``MessagePiece`` + # at module load, so a top-level import here would deadlock the cycle. + from pyrit.models.messages.message import Message + + return Message(message_pieces=[self]) + + def copy_lineage_from(self, *, source: MessagePiece) -> None: + """ + Copy lineage metadata from ``source`` onto this piece. + + Lineage fields are the metadata that tie a piece back to its originating + conversation, attack, and target. Mutable containers (``labels``, + ``prompt_metadata``) are shallow-copied so that mutations on one piece + do not affect others. + + Args: + source: The piece whose lineage will be copied onto ``self``. + """ + self.conversation_id = source.conversation_id + self.labels = dict(source.labels) + self.attack_identifier = source.attack_identifier + self.prompt_target_identifier = source.prompt_target_identifier + self.prompt_metadata = dict(source.prompt_metadata) + + def has_error(self) -> bool: + """ + Return ``True`` when ``response_error`` is not ``"none"``. 
+ + Returns: + ``True`` if the piece carries any non-``"none"`` error code. + """ + return self.response_error != "none" + + def is_blocked(self) -> bool: + """ + Return ``True`` when ``response_error`` is ``"blocked"``. + + Returns: + ``True`` if the response was blocked by the target / content filter. + """ + return self.response_error == "blocked" + + # ------------------------------------------------------------------ # + # Deprecated method shims (removed in 0.16.0) + # ------------------------------------------------------------------ # + def to_dict(self) -> dict[str, Any]: + """ + Return a JSON-mode dict representation (DEPRECATED — use ``model_dump``). + + Returns: + A JSON-mode dict representation of the piece (same as + ``self.model_dump(mode="json")``). + """ + print_deprecation_message( + old_item="MessagePiece.to_dict()", + new_item='MessagePiece.model_dump(mode="json")', + removed_in="0.16.0", + ) + return self.model_dump(mode="json") + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> MessagePiece: + """ + Construct a MessagePiece from a dict (DEPRECATED — use ``model_validate``). + + Args: + data: A dict matching the MessagePiece field schema. + + Returns: + A new ``MessagePiece`` (same as ``cls.model_validate(data)``). + """ + print_deprecation_message( + old_item="MessagePiece.from_dict()", + new_item="MessagePiece.model_validate()", + removed_in="0.16.0", + ) + return cls.model_validate(data) + + def set_piece_not_in_database(self) -> None: + """ + Mark this piece as ephemeral (DEPRECATED — set ``not_in_memory`` directly). + + Example:: + + piece.not_in_memory = True + """ + print_deprecation_message( + old_item="MessagePiece.set_piece_not_in_database()", + new_item="MessagePiece.not_in_memory = True", + removed_in="0.16.0", + ) + self.not_in_memory = True + + async def set_sha256_values_async(self) -> None: + """ + Compute SHA256 hash values for original and converted payloads. + + Async because blob payloads may need to be fetched. 
Must be called + explicitly after construction. + """ + original_serializer = data_serializer_factory( + category="prompt-memory-entries", + data_type=self.original_value_data_type, + value=self.original_value, + ) + self.original_value_sha256 = await original_serializer.get_sha256() + + converted_serializer = data_serializer_factory( + category="prompt-memory-entries", + data_type=self.converted_value_data_type, + value=self.converted_value, + ) + self.converted_value_sha256 = await converted_serializer.get_sha256() + + +def sort_message_pieces(message_pieces: list[MessagePiece]) -> list[MessagePiece]: + """ + Group by ``conversation_id``, ordering by earliest timestamp then ``sequence``. + + Conversations are ordered by their earliest piece's timestamp; pieces + within a conversation are ordered by ``sequence``. + + Args: + message_pieces: The pieces to sort. Not mutated. + + Returns: + A new list containing the same pieces in deterministic order. + """ + earliest_timestamps = { + convo_id: min(x.timestamp for x in message_pieces if x.conversation_id == convo_id) + for convo_id in {x.conversation_id for x in message_pieces} + } + return sorted(message_pieces, key=lambda x: (earliest_timestamps[x.conversation_id], x.conversation_id, x.sequence)) diff --git a/pyrit/models/seeds/seed_group.py b/pyrit/models/seeds/seed_group.py index 6e76d0a171..579ba44360 100644 --- a/pyrit/models/seeds/seed_group.py +++ b/pyrit/models/seeds/seed_group.py @@ -16,8 +16,8 @@ from typing import TYPE_CHECKING, Any, Optional, Union from pyrit.common.yaml_loadable import YamlLoadable -from pyrit.models.message import Message -from pyrit.models.message_piece import MessagePiece +from pyrit.models.messages.message import Message +from pyrit.models.messages.message_piece import MessagePiece from pyrit.models.seeds.seed import Seed from pyrit.models.seeds.seed_objective import SeedObjective from pyrit.models.seeds.seed_prompt import SeedPrompt diff --git 
a/pyrit/prompt_converter/llm_generic_text_converter.py b/pyrit/prompt_converter/llm_generic_text_converter.py index 2c9fbe99c9..e1422d68e7 100644 --- a/pyrit/prompt_converter/llm_generic_text_converter.py +++ b/pyrit/prompt_converter/llm_generic_text_converter.py @@ -168,7 +168,7 @@ async def convert_async(self, *, prompt: str, input_type: PromptDataType = "text ) request = Message( - [ + message_pieces=[ MessagePiece( role="user", original_value=prompt, diff --git a/pyrit/prompt_target/common/discover_target_capabilities.py b/pyrit/prompt_target/common/discover_target_capabilities.py index 45600e6009..55acbd227b 100644 --- a/pyrit/prompt_target/common/discover_target_capabilities.py +++ b/pyrit/prompt_target/common/discover_target_capabilities.py @@ -321,14 +321,14 @@ async def _probe_system_prompt_async(target: PromptTarget, timeout_s: float, ret prompt_metadata=_probe_metadata(), ) try: - target._memory.add_message_to_memory(request=Message([system_piece])) + target._memory.add_message_to_memory(request=Message(message_pieces=[system_piece])) except Exception as exc: logger.debug("System-prompt probe could not seed system message: %s", exc) return False user_piece = _user_text_piece(value="hi", conversation_id=conversation_id) return await _send_and_check_async( target=target, - message=Message([user_piece]), + message=Message(message_pieces=[user_piece]), timeout_s=timeout_s, retries=retries, label="System-prompt probe", @@ -356,7 +356,7 @@ async def _probe_multi_message_pieces_async(target: PromptTarget, timeout_s: flo ] return await _send_and_check_async( target=target, - message=Message(pieces), + message=Message(message_pieces=pieces), timeout_s=timeout_s, retries=retries, label="Multi-message-pieces probe", @@ -395,13 +395,17 @@ async def _probe_multi_turn_async(target: PromptTarget, timeout_s: float, retrie conversation_id = _new_conversation_id() first = _user_text_piece(value="My favorite color is blue.", conversation_id=conversation_id) if not await 
_send_and_check_async( - target=target, message=Message([first]), timeout_s=timeout_s, retries=retries, label="Multi-turn probe (turn 1)" + target=target, + message=Message(message_pieces=[first]), + timeout_s=timeout_s, + retries=retries, + label="Multi-turn probe (turn 1)", ): return False # Seed memory so the second send sees real prior history. try: - target._memory.add_message_to_memory(request=Message([first])) + target._memory.add_message_to_memory(request=Message(message_pieces=[first])) assistant_reply = MessagePiece( role="assistant", original_value="Got it.", @@ -417,7 +421,7 @@ async def _probe_multi_turn_async(target: PromptTarget, timeout_s: float, retrie second = _user_text_piece(value="What did I just tell you?", conversation_id=conversation_id) return await _send_and_check_async( target=target, - message=Message([second]), + message=Message(message_pieces=[second]), timeout_s=timeout_s, retries=retries, label="Multi-turn probe (turn 2)", @@ -452,7 +456,11 @@ async def _probe_json_output_async(target: PromptTarget, timeout_s: float, retri prompt_metadata=_probe_metadata({"response_format": "json"}), ) return await _send_and_check_async( - target=target, message=Message([piece]), timeout_s=timeout_s, retries=retries, label="JSON-output probe" + target=target, + message=Message(message_pieces=[piece]), + timeout_s=timeout_s, + retries=retries, + label="JSON-output probe", ) @@ -495,7 +503,11 @@ async def _probe_json_schema_async(target: PromptTarget, timeout_s: float, retri ), ) return await _send_and_check_async( - target=target, message=Message([piece]), timeout_s=timeout_s, retries=retries, label="JSON-schema probe" + target=target, + message=Message(message_pieces=[piece]), + timeout_s=timeout_s, + retries=retries, + label="JSON-schema probe", ) @@ -836,4 +848,4 @@ def _create_test_message( ) ) - return Message(pieces) + return Message(message_pieces=pieces) diff --git a/pyrit/prompt_target/openai/openai_response_target.py 
b/pyrit/prompt_target/openai/openai_response_target.py index 745fb18e31..ff7a2fc4e2 100644 --- a/pyrit/prompt_target/openai/openai_response_target.py +++ b/pyrit/prompt_target/openai/openai_response_target.py @@ -614,7 +614,7 @@ async def _send_prompt_to_target_async(self, *, normalized_conversation: list[Me # Create a new message with the tool output tool_piece = self._make_tool_piece(tool_output, tool_call_section["call_id"], reference_piece=message_piece) - tool_message = Message(message_pieces=[tool_piece], skip_validation=True) + tool_message = Message(message_pieces=[tool_piece]) # Add tool output message to conversation and responses list working_conversation.append(tool_message) diff --git a/pyrit/score/scorer.py b/pyrit/score/scorer.py index c06d85fbb0..0d45786551 100644 --- a/pyrit/score/scorer.py +++ b/pyrit/score/scorer.py @@ -746,7 +746,7 @@ async def _score_value_with_llm( ) ) - scorer_llm_request = Message(message_pieces) + scorer_llm_request = Message(message_pieces=message_pieces) try: response = await prompt_target.send_prompt_async(message=scorer_llm_request) except Exception as ex: diff --git a/pyrit/score/true_false/gandalf_scorer.py b/pyrit/score/true_false/gandalf_scorer.py index bf81759a2e..6af8a963be 100644 --- a/pyrit/score/true_false/gandalf_scorer.py +++ b/pyrit/score/true_false/gandalf_scorer.py @@ -120,7 +120,7 @@ async def _check_for_password_in_conversation(self, conversation_id: str) -> str conversation_as_text += "\n" request = Message( - [ + message_pieces=[ MessagePiece( role="user", original_value_data_type="text", diff --git a/pyrit/score/true_false/prompt_shield_scorer.py b/pyrit/score/true_false/prompt_shield_scorer.py index ee55f87054..8f048300cc 100644 --- a/pyrit/score/true_false/prompt_shield_scorer.py +++ b/pyrit/score/true_false/prompt_shield_scorer.py @@ -70,7 +70,7 @@ async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Op body = message_piece.original_value request = Message( - [ + 
message_pieces=[ MessagePiece( role="user", original_value=body, diff --git a/tests/unit/models/test_attack_result.py b/tests/unit/models/test_attack_result.py index 3d78c210e5..47465d0299 100644 --- a/tests/unit/models/test_attack_result.py +++ b/tests/unit/models/test_attack_result.py @@ -8,7 +8,7 @@ from pyrit.models import ComponentIdentifier, build_atomic_attack_identifier from pyrit.models.attack_result import AttackOutcome, AttackResult from pyrit.models.conversation_reference import ConversationReference, ConversationType -from pyrit.models.message_piece import MessagePiece +from pyrit.models.messages.message_piece import MessagePiece from pyrit.models.retry_event import RetryEvent from pyrit.models.score import Score diff --git a/tests/unit/models/test_import_boundary.py b/tests/unit/models/test_import_boundary.py index bb4b6ca9e9..bc6342d8d8 100644 --- a/tests/unit/models/test_import_boundary.py +++ b/tests/unit/models/test_import_boundary.py @@ -45,9 +45,6 @@ # violations not in this list fail the test; entries that no longer match # source also fail. 
KNOWN_TOP_LEVEL_VIOLATIONS: dict[str, dict[str, str]] = { - "pyrit.models.message": { - "pyrit.common.utils": "phase-4", - }, "pyrit.models.harm_definition": { "pyrit.common.path": "phase-1", }, diff --git a/tests/unit/models/test_message.py b/tests/unit/models/test_message.py index 595aefa2f6..abc3e104b7 100644 --- a/tests/unit/models/test_message.py +++ b/tests/unit/models/test_message.py @@ -355,3 +355,113 @@ def test_set_response_not_in_database_emits_warning_and_delegates(self) -> None: msgs = [w for w in caught if issubclass(w.category, DeprecationWarning)] assert any("set_response_not_in_database" in str(m.message) for m in msgs) assert piece.not_in_memory is True + + +class TestMessagePydanticShape: + """Tests for the Pydantic v2 BaseModel behavior of Message.""" + + def test_keyword_construction_does_not_warn(self) -> None: + import warnings as _warnings + + piece = MessagePiece(role="user", original_value="hi", conversation_id="c") + with _warnings.catch_warnings(record=True) as caught: + _warnings.simplefilter("always") + Message(message_pieces=[piece]) + assert not [w for w in caught if issubclass(w.category, DeprecationWarning)] + + def test_positional_construction_warns_and_works(self) -> None: + import warnings as _warnings + + piece = MessagePiece(role="user", original_value="hi", conversation_id="c") + with _warnings.catch_warnings(record=True) as caught: + _warnings.simplefilter("always") + message = Message([piece]) + assert message.message_pieces == [piece] + assert any(issubclass(w.category, DeprecationWarning) and "positional" in str(w.message) for w in caught) + + def test_too_many_positional_args_raises(self) -> None: + piece = MessagePiece(role="user", original_value="hi", conversation_id="c") + with pytest.raises(TypeError, match="at most 1 positional argument"): + Message([piece], [piece]) + + def test_skip_validation_kwarg_is_deprecated_noop(self) -> None: + import warnings as _warnings + + piece = MessagePiece(role="user", 
original_value="hi", conversation_id="c") + with _warnings.catch_warnings(record=True) as caught: + _warnings.simplefilter("always") + message = Message(message_pieces=[piece], skip_validation=True) + assert message.message_pieces == [piece] + assert any(issubclass(w.category, DeprecationWarning) and "skip_validation" in str(w.message) for w in caught) + + def test_model_validate_canonical_shape(self) -> None: + piece = MessagePiece(role="user", original_value="hi", conversation_id="c") + message = Message.model_validate({"message_pieces": [piece.model_dump()]}) + assert message.get_value() == "hi" + + def test_model_validate_legacy_dict_shape(self) -> None: + original = Message.from_prompt(prompt="legacy hello", role="user") + rebuilt = Message.model_validate(original.to_dict()) + assert rebuilt.get_value() == "legacy hello" + + def test_value_equality(self, message_pieces: list[MessagePiece]) -> None: + assert Message(message_pieces=message_pieces) == Message(message_pieces=message_pieces) + + def test_membership_uses_value_equality(self, message_pieces: list[MessagePiece]) -> None: + a = Message(message_pieces=message_pieces) + b = Message(message_pieces=message_pieces) + assert a in [b] + + def test_validate_instance_method_still_callable(self, message: Message) -> None: + message.validate() + message.message_pieces = [] + with pytest.raises(ValueError, match="at least one message piece"): + message.validate() + + def test_duplicate_creates_new_ids_and_deep_copy(self, message: Message) -> None: + duplicated = message.duplicate() + original_ids = {p.id for p in message.message_pieces} + duplicated_ids = {p.id for p in duplicated.message_pieces} + assert original_ids.isdisjoint(duplicated_ids) + duplicated.message_pieces[0].original_value = "changed" + assert message.message_pieces[0].original_value == "First piece" + + def test_to_dict_keeps_legacy_keys_while_model_dump_is_canonical(self) -> None: + message = Message.from_prompt(prompt="hi", role="user") + with 
pytest.warns(DeprecationWarning): + legacy = message.to_dict() + assert set(legacy) == { + "role", + "converted_value", + "conversation_id", + "sequence", + "converted_value_data_type", + "pieces", + } + assert set(message.model_dump()) == {"message_pieces"} + + +class TestMessageModuleLayout: + """Lock in the messages-package layout and its backward-compatible re-exports.""" + + def test_conversation_helpers_live_in_conversations_module(self) -> None: + from pyrit.models import messages + from pyrit.models.messages import conversations + + for name in ( + "get_all_values", + "flatten_to_message_pieces", + "group_conversation_message_pieces_by_sequence", + "group_message_pieces_into_conversations", + "construct_response_from_request", + ): + assert getattr(conversations, name) is getattr(messages, name) + + def test_legacy_module_paths_reexport_same_objects(self) -> None: + import pyrit.models.message as legacy_message + import pyrit.models.message_piece as legacy_message_piece + from pyrit.models.messages.message import Message as PackagedMessage + from pyrit.models.messages.message_piece import MessagePiece as PackagedMessagePiece + + assert legacy_message.Message is PackagedMessage + assert legacy_message_piece.MessagePiece is PackagedMessagePiece diff --git a/tests/unit/prompt_target/test_discover_target_capabilities.py b/tests/unit/prompt_target/test_discover_target_capabilities.py index 38ade07477..dfe9deef99 100644 --- a/tests/unit/prompt_target/test_discover_target_capabilities.py +++ b/tests/unit/prompt_target/test_discover_target_capabilities.py @@ -874,11 +874,8 @@ async def test_response_with_no_pieces_treated_as_failure(self) -> None: """Responses whose Messages have no pieces must also be rejected.""" target = MockPromptTarget() target._send_prompt_to_target_async = AsyncMock( # type: ignore[method-assign] - return_value=[Message.__new__(Message)] + return_value=[Message.model_construct(message_pieces=[])] ) - # Bypass __init__ to construct a Message 
with no pieces (Message.__init__ rejects empty). - empty_msg = target._send_prompt_to_target_async.return_value[0] - empty_msg.message_pieces = [] result = await _discover_capability_flags_async( target=target, @@ -891,8 +888,7 @@ async def test_mixed_empty_message_in_response_treated_as_failure(self) -> None: """Any empty Message in a multi-message response must cause the probe to fail.""" target = MockPromptTarget() ok = _ok_response()[0] - empty = Message.__new__(Message) - empty.message_pieces = [] + empty = Message.model_construct(message_pieces=[]) target._send_prompt_to_target_async = AsyncMock(return_value=[ok, empty]) # type: ignore[method-assign] result = await _discover_capability_flags_async(