diff --git a/doc/code/memory/10_schema_diagram.md b/doc/code/memory/10_schema_diagram.md index 9fbcefc2a2..40837b71d8 100644 --- a/doc/code/memory/10_schema_diagram.md +++ b/doc/code/memory/10_schema_diagram.md @@ -40,14 +40,16 @@ flowchart LR P_labels["labels (VARCHAR)"] P_prompt_metadata["prompt_metadata (VARCHAR)"] P_converter_identifiers["converter_identifiers (VARCHAR)"] - P_prompt_target_identifier["prompt_target_identifier (VARCHAR)"] - P_attack_identifier["attack_identifier (VARCHAR)"] P_response_error["response_error (VARCHAR)"] P_converted_value_data_type["converted_value_data_type (VARCHAR)"] P_converted_value["converted_value (VARCHAR)"] P_converted_value_sha256["converted_value_sha256 (VARCHAR)"] P_original_prompt_id["original_prompt_id (UUID)"] end + subgraph Conversations["Conversations"] + C_conversation_id["conversation_id (VARCHAR)"] + C_target_identifier["target_identifier (VARCHAR)"] + end subgraph ScoreEntries["ScoreEntries"] Sc_id["id (UUID)"] Sc_prompt_request_response_id["prompt_request_response_id (VARCHAR)"] @@ -63,6 +65,7 @@ flowchart LR end S_value_sha256 -- N:N relationship to query --> P_original_value_sha256 P_id -- 1:N relationship to query --> Sc_prompt_request_response_id + P_conversation_id -- N:1 relationship to query --> C_conversation_id style S_value_sha256 fill:#ff8800ff style P_id fill:#14a519ff diff --git a/doc/code/memory/3_memory_data_types.md b/doc/code/memory/3_memory_data_types.md index b35daaa005..0fd1b0988d 100644 --- a/doc/code/memory/3_memory_data_types.md +++ b/doc/code/memory/3_memory_data_types.md @@ -23,8 +23,6 @@ One of the most fundamental data structures in PyRIT is [MessagePiece](../../../ - **`labels`**: Dictionary of labels for categorization and filtering - **`prompt_metadata`**: Component-specific metadata (e.g., blob URIs, document types) - **`converter_identifiers`**: List of converters applied to transform the prompt -- **`prompt_target_identifier`**: Information about the target that received this prompt -- **`attack_identifier`**: Information about the attack that generated this prompt - **`scorer_identifier`**: Information about the scorer that evaluated this prompt - **`response_error`**: Error status (e.g., `none`, `blocked`, `processing`) - **`originator`**: Source of the prompt (`attack`, `converter`, `scorer`, `undefined`) @@ -54,6 +52,8 @@ This rich context allows PyRIT to track the full lifecycle of each interaction, A conversation is a list of `Messages` that share the same `conversation_id`. The sequence of the `MessagePieces` and their corresponding `Messages` dictates the order of the conversation. +A conversation is always held with a single target. That target's identifier is recorded once per conversation in the `Conversations` table (`target_identifier`) rather than on every `MessagePiece`. Use `memory.get_conversation_metadata(conversation_id=...)` to retrieve it. + Here is a sample conversation made up of three `Messages` which all share the same conversation ID. The first `Message` is the `system` message, followed by a multi-modal `user` prompt with a text `MessagePiece` and an image `MessagePiece`, and finally the `assistant` response in the form of a text `MessagePiece`. ```{mermaid} diff --git a/pyrit/backend/services/attack_service.py b/pyrit/backend/services/attack_service.py index 23d58bf623..f6cea27567 100644 --- a/pyrit/backend/services/attack_service.py +++ b/pyrit/backend/services/attack_service.py @@ -312,6 +312,7 @@ async def create_attack_async(self, *, request: CreateAttackRequest) -> CreateAt cutoff_index=request.cutoff_index, labels_override=labels, remap_assistant_to_simulated=True, + target_identifier=target_identifier, ) else: conversation_id = str(uuid.uuid4()) @@ -344,6 +345,7 @@ async def create_attack_async(self, *, request: CreateAttackRequest) -> CreateAt conversation_id=conversation_id, prepended=request.prepended_conversation, labels=labels, # deprecated + target_identifier=target_identifier, ) return CreateAttackResponse( @@ -475,9 +477,13 @@ async def create_related_conversation_async( # --- Branch via duplication (preferred for tracking) --------------- if request.source_conversation_id is not None and request.cutoff_index is not None: + source_metadata = self._memory.get_conversation_metadata( + conversation_id=request.source_conversation_id + ) new_conversation_id = self._duplicate_conversation_up_to( source_conversation_id=request.source_conversation_id, cutoff_index=request.cutoff_index, + target_identifier=source_metadata.target_identifier if source_metadata else None, ) else: new_conversation_id = str(uuid.uuid4()) @@ -622,11 +628,13 @@ async def add_message_async(self, *, attack_result_id: str, request: AddMessageR labels=attack_labels, # deprecated ) else: + existing_metadata = self._memory.get_conversation_metadata(conversation_id=msg_conversation_id) await self._store_message_only_async( conversation_id=msg_conversation_id, request=request, sequence=sequence, labels=attack_labels, # deprecated + target_identifier=existing_metadata.target_identifier if existing_metadata else None, ) await self._update_attack_after_message_async(attack_result_id=attack_result_id, ar=ar, request=request) @@ -828,6 +836,7 @@ def _duplicate_conversation_up_to( cutoff_index: int, labels_override: dict[str, str] | None = None, remap_assistant_to_simulated: bool = False, + target_identifier: ComponentIdentifier | None = None, ) -> str: """ Duplicate messages from a conversation up to and including a turn index. @@ -846,6 +855,9 @@ def _duplicate_conversation_up_to( ``assistant`` are changed to ``simulated_assistant`` so the branched context is inert and won't confuse the target. + target_identifier (ComponentIdentifier | None): The target the new conversation + is held with, if known. Recorded once for the duplicated conversation. + Returns: The new conversation ID containing the duplicated messages. """ @@ -865,7 +877,9 @@ def _duplicate_conversation_up_to( piece.role = "simulated_assistant" if all_pieces: - self._memory.add_message_pieces_to_memory(message_pieces=list(all_pieces)) + self._memory.add_message_pieces_to_memory( + message_pieces=list(all_pieces), target_identifier=target_identifier + ) return new_conversation_id @@ -953,6 +967,7 @@ async def _store_prepended_messages_async( conversation_id: str, prepended: list[Any], labels: dict[str, str] | None = None, # deprecated + target_identifier: ComponentIdentifier | None = None, ) -> None: """Store prepended conversation messages in memory.""" for seq, msg in enumerate(prepended): @@ -964,7 +979,9 @@ async def _store_prepended_messages_async( sequence=seq, labels=labels, # deprecated ) - self._memory.add_message_pieces_to_memory(message_pieces=[piece]) + self._memory.add_message_pieces_to_memory( + message_pieces=[piece], target_identifier=target_identifier + ) async def _send_and_store_message_async( self, @@ -1010,6 +1027,7 @@ async def _store_message_only_async( request: AddMessageRequest, sequence: int, labels: dict[str, str] | None = None, # deprecated + target_identifier: ComponentIdentifier | None = None, ) -> None: """Store message without sending (send=False).""" await self._persist_base64_pieces_async(request) @@ -1021,7 +1039,9 @@ async def _store_message_only_async( sequence=sequence, labels=labels, # deprecated ) - self._memory.add_message_pieces_to_memory(message_pieces=[piece]) + self._memory.add_message_pieces_to_memory( + message_pieces=[piece], target_identifier=target_identifier + ) def _resolve_video_remix_metadata(self, request: AddMessageRequest) -> None: """ diff --git a/pyrit/executor/attack/component/conversation_manager.py b/pyrit/executor/attack/component/conversation_manager.py index b4b47175f2..8a3cbac557 100644 --- a/pyrit/executor/attack/component/conversation_manager.py +++ b/pyrit/executor/attack/component/conversation_manager.py @@ -54,8 +54,6 @@ def get_adversarial_chat_messages( prepended_conversation: list[Message], *, adversarial_chat_conversation_id: str, - attack_identifier: ComponentIdentifier, - adversarial_chat_target_identifier: ComponentIdentifier, labels: dict[str, str] | None = None, # deprecated ) -> list[Message]: """ @@ -72,8 +70,6 @@ def get_adversarial_chat_messages( Args: prepended_conversation: The original conversation messages to transform. adversarial_chat_conversation_id: Conversation ID for the adversarial chat. - attack_identifier (ComponentIdentifier): Attack identifier to associate with messages. - adversarial_chat_target_identifier (ComponentIdentifier): Target identifier for the adversarial chat. labels: Optional labels to associate with the messages. Deprecated: This parameter will be removed in a release 0.16.0. @@ -114,8 +110,6 @@ def get_adversarial_chat_messages( original_value_data_type=piece.original_value_data_type, converted_value_data_type=piece.converted_value_data_type, conversation_id=adversarial_chat_conversation_id, - attack_identifier=attack_identifier, - prompt_target_identifier=adversarial_chat_target_identifier, labels=labels or {}, # deprecated ) @@ -190,20 +184,17 @@ class ConversationManager: def __init__( self, *, - attack_identifier: ComponentIdentifier, prompt_normalizer: PromptNormalizer | None = None, ) -> None: """ Initialize the conversation manager. Args: - attack_identifier (ComponentIdentifier): The identifier of the attack this manager belongs to. prompt_normalizer: Optional prompt normalizer for converting prompts. If not provided, a default PromptNormalizer instance will be created. """ self._prompt_normalizer = prompt_normalizer or PromptNormalizer() self._memory = CentralMemory.get_memory_instance() - self._attack_identifier = attack_identifier def get_conversation(self, conversation_id: str) -> list[Message]: """ @@ -276,7 +267,6 @@ def set_system_prompt( target.set_system_prompt( system_prompt=system_prompt, conversation_id=conversation_id, - attack_identifier=self._attack_identifier, labels=labels, # deprecated ) @@ -359,6 +349,7 @@ async def initialize_context_async( request_converters=request_converters, prepended_conversation_config=prepended_conversation_config, max_turns=max_turns, + target_identifier=target.get_identifier(), ) async def _handle_non_chat_target_async( @@ -439,6 +430,7 @@ async def add_prepended_conversation_to_memory_async( request_converters: list[PromptConverterConfiguration] | None = None, prepended_conversation_config: Optional["PrependedConversationConfig"] = None, max_turns: int | None = None, + target_identifier: ComponentIdentifier | None = None, ) -> int: """ Add prepended conversation messages to memory for a chat target. @@ -459,6 +451,8 @@ async def add_prepended_conversation_to_memory_async( request_converters: Optional converters to apply to messages. prepended_conversation_config: Optional configuration for converter roles. max_turns: If provided, validates that turn count doesn't exceed this limit. + target_identifier (ComponentIdentifier | None): The target the conversation is held + with, if known. Recorded once per conversation. Returns: The number of turns (assistant messages) added. @@ -485,7 +479,6 @@ async def add_prepended_conversation_to_memory_async( for piece in message_copy.message_pieces: piece.conversation_id = conversation_id - piece.attack_identifier = self._attack_identifier # Count turns at message level (only assistant/simulated_assistant messages) # A multi-part response still counts as one turn @@ -506,7 +499,7 @@ async def add_prepended_conversation_to_memory_async( ) # Add to memory - self._memory.add_message_to_memory(request=message_copy) + self._memory.add_message_to_memory(request=message_copy, target_identifier=target_identifier) logger.debug(f"Added prepended message {i + 1}/{len(valid_messages)} to memory") return turn_count @@ -520,6 +513,7 @@ async def _process_prepended_for_chat_target_async( request_converters: list[PromptConverterConfiguration] | None, prepended_conversation_config: Optional["PrependedConversationConfig"], max_turns: int | None, + target_identifier: ComponentIdentifier | None = None, ) -> ConversationState: """ Process prepended conversation for a chat target. @@ -536,6 +530,8 @@ async def _process_prepended_for_chat_target_async( request_converters: Converters to apply. prepended_conversation_config: Configuration for converter roles. max_turns: Maximum turns for validation. + target_identifier (ComponentIdentifier | None): The objective target the + conversation is held with, if known. Returns: ConversationState with turn_count and scores. @@ -555,6 +551,7 @@ async def _process_prepended_for_chat_target_async( request_converters=request_converters, prepended_conversation_config=prepended_conversation_config, max_turns=max_turns, + target_identifier=target_identifier, ) # Update context for multi-turn attacks to reflect prepended_conversation diff --git a/pyrit/executor/attack/multi_turn/chunked_request.py b/pyrit/executor/attack/multi_turn/chunked_request.py index 4e7e5caefa..b7b04ec10e 100644 --- a/pyrit/executor/attack/multi_turn/chunked_request.py +++ b/pyrit/executor/attack/multi_turn/chunked_request.py @@ -172,7 +172,6 @@ def __init__( # Initialize prompt normalizer and conversation manager self._prompt_normalizer = prompt_normalizer or PromptNormalizer() self._conversation_manager = ConversationManager( - attack_identifier=self.get_identifier(), prompt_normalizer=self._prompt_normalizer, ) @@ -279,7 +278,6 @@ async def _perform_async(self, *, context: ChunkedRequestAttackContext) -> Attac with execution_context( component_role=ComponentRole.OBJECTIVE_TARGET, attack_strategy_name=self.__class__.__name__, - attack_identifier=self.get_identifier(), component_identifier=self._objective_target.get_identifier(), objective_target_conversation_id=context.session.conversation_id, objective=context.objective, @@ -291,7 +289,6 @@ async def _perform_async(self, *, context: ChunkedRequestAttackContext) -> Attac request_converter_configurations=self._request_converters, response_converter_configurations=self._response_converters, labels=context.memory_labels, - attack_identifier=self.get_identifier(), ) # Store the response @@ -377,7 +374,6 @@ async def _score_combined_value_async( with execution_context( component_role=ComponentRole.OBJECTIVE_SCORER, attack_strategy_name=self.__class__.__name__, - attack_identifier=self.get_identifier(), component_identifier=self._objective_scorer.get_identifier(), objective=objective, ): diff --git a/pyrit/executor/attack/multi_turn/crescendo.py b/pyrit/executor/attack/multi_turn/crescendo.py index bc987f1270..8a2e6dd730 100644 --- a/pyrit/executor/attack/multi_turn/crescendo.py +++ b/pyrit/executor/attack/multi_turn/crescendo.py @@ -232,7 +232,6 @@ def __init__( # Initialize utilities self._prompt_normalizer = prompt_normalizer or PromptNormalizer() self._conversation_manager = ConversationManager( - attack_identifier=self.get_identifier(), prompt_normalizer=self._prompt_normalizer, ) @@ -331,7 +330,6 @@ async def _setup_async(self, *, context: CrescendoAttackContext) -> None: self._adversarial_chat.set_system_prompt( system_prompt=system_prompt, conversation_id=context.session.adversarial_chat_conversation_id, - attack_identifier=self.get_identifier(), labels=context.memory_labels, # deprecated ) @@ -545,7 +543,6 @@ async def _send_prompt_to_adversarial_chat_async( with execution_context( component_role=ComponentRole.ADVERSARIAL_CHAT, attack_strategy_name=self.__class__.__name__, - attack_identifier=self.get_identifier(), component_identifier=self._adversarial_chat.get_identifier(), objective_target_conversation_id=context.session.conversation_id, objective=context.objective, @@ -554,7 +551,6 @@ async def _send_prompt_to_adversarial_chat_async( message=message, conversation_id=context.session.adversarial_chat_conversation_id, target=self._adversarial_chat, - attack_identifier=self.get_identifier(), labels=context.memory_labels, ) @@ -649,7 +645,6 @@ async def _send_prompt_to_objective_target_async( with execution_context( component_role=ComponentRole.OBJECTIVE_TARGET, attack_strategy_name=self.__class__.__name__, - attack_identifier=self.get_identifier(), component_identifier=self._objective_target.get_identifier(), objective_target_conversation_id=context.session.conversation_id, objective=context.objective, @@ -660,7 +655,6 @@ async def _send_prompt_to_objective_target_async( conversation_id=context.session.conversation_id, request_converter_configurations=self._request_converters, response_converter_configurations=self._response_converters, - attack_identifier=self.get_identifier(), labels=context.memory_labels, ) @@ -689,7 +683,6 @@ async def _check_refusal_async(self, context: CrescendoAttackContext, objective: with execution_context( component_role=ComponentRole.REFUSAL_SCORER, attack_strategy_name=self.__class__.__name__, - attack_identifier=self.get_identifier(), component_identifier=self._refusal_scorer.get_identifier(), objective_target_conversation_id=context.session.conversation_id, objective=context.objective, @@ -721,7 +714,6 @@ async def _score_response_async(self, *, context: CrescendoAttackContext) -> Sco with execution_context( component_role=ComponentRole.OBJECTIVE_SCORER, attack_strategy_name=self.__class__.__name__, - attack_identifier=self.get_identifier(), component_identifier=self._objective_scorer.get_identifier(), objective_target_conversation_id=context.session.conversation_id, objective=context.objective, diff --git a/pyrit/executor/attack/multi_turn/multi_prompt_sending.py b/pyrit/executor/attack/multi_turn/multi_prompt_sending.py index e15ef6c63d..cc4c53531d 100644 --- a/pyrit/executor/attack/multi_turn/multi_prompt_sending.py +++ b/pyrit/executor/attack/multi_turn/multi_prompt_sending.py @@ -175,7 +175,6 @@ def __init__( # Initialize prompt normalizer and conversation manager self._prompt_normalizer = prompt_normalizer or PromptNormalizer() self._conversation_manager = ConversationManager( - attack_identifier=self.get_identifier(), prompt_normalizer=self._prompt_normalizer, ) @@ -355,7 +354,6 @@ async def _send_prompt_to_objective_target_async( with execution_context( component_role=ComponentRole.OBJECTIVE_TARGET, attack_strategy_name=self.__class__.__name__, - attack_identifier=self.get_identifier(), component_identifier=self._objective_target.get_identifier(), objective_target_conversation_id=context.session.conversation_id, objective=context.objective, @@ -367,7 +365,6 @@ async def _send_prompt_to_objective_target_async( request_converter_configurations=self._request_converters, response_converter_configurations=self._response_converters, labels=context.memory_labels, # combined with strategy labels at _setup() - attack_identifier=self.get_identifier(), ) async def _evaluate_response_async(self, *, response: Message, objective: str) -> Score | None: @@ -389,7 +386,6 @@ async def _evaluate_response_async(self, *, response: Message, objective: str) - with execution_context( component_role=ComponentRole.OBJECTIVE_SCORER, attack_strategy_name=self.__class__.__name__, - attack_identifier=self.get_identifier(), component_identifier=self._objective_scorer.get_identifier() if self._objective_scorer else None, objective=objective, ): diff --git a/pyrit/executor/attack/multi_turn/multi_turn_attack_strategy.py b/pyrit/executor/attack/multi_turn/multi_turn_attack_strategy.py index 4aca72d054..bb3b07cef2 100644 --- a/pyrit/executor/attack/multi_turn/multi_turn_attack_strategy.py +++ b/pyrit/executor/attack/multi_turn/multi_turn_attack_strategy.py @@ -141,7 +141,9 @@ def _rotate_conversation_for_single_turn_target( if system_messages: new_conversation_id, pieces = memory.duplicate_messages(messages=system_messages) - memory.add_message_pieces_to_memory(message_pieces=pieces) + memory.add_message_pieces_to_memory( + message_pieces=pieces, target_identifier=self._objective_target.get_identifier() + ) context.session.conversation_id = new_conversation_id else: context.session.conversation_id = str(uuid.uuid4()) diff --git a/pyrit/executor/attack/multi_turn/red_teaming.py b/pyrit/executor/attack/multi_turn/red_teaming.py index 8c0d34c6eb..08c18fc4c4 100644 --- a/pyrit/executor/attack/multi_turn/red_teaming.py +++ b/pyrit/executor/attack/multi_turn/red_teaming.py @@ -166,7 +166,7 @@ def __init__( # Initialize utilities self._prompt_normalizer = prompt_normalizer or PromptNormalizer() - self._conversation_manager = ConversationManager(attack_identifier=self.get_identifier()) + self._conversation_manager = ConversationManager() # set the maximum number of turns for the attack if max_turns <= 0: @@ -260,7 +260,6 @@ async def _setup_async(self, *, context: MultiTurnAttackContext[Any]) -> None: self._adversarial_chat.set_system_prompt( system_prompt=adversarial_system_prompt, conversation_id=context.session.adversarial_chat_conversation_id, - attack_identifier=self.get_identifier(), labels=context.memory_labels, # deprecated ) @@ -270,13 +269,13 @@ async def _setup_async(self, *, context: MultiTurnAttackContext[Any]) -> None: adversarial_messages = get_adversarial_chat_messages( prepended_conversation=context.prepended_conversation, adversarial_chat_conversation_id=context.session.adversarial_chat_conversation_id, - attack_identifier=self.get_identifier(), - adversarial_chat_target_identifier=self._adversarial_chat.get_identifier(), labels=context.memory_labels, ) for msg in adversarial_messages: - self._memory.add_message_to_memory(request=msg) + self._memory.add_message_to_memory( + request=msg, target_identifier=self._adversarial_chat.get_identifier() + ) async def _perform_async(self, *, context: MultiTurnAttackContext[Any]) -> AttackResult: """ @@ -388,7 +387,6 @@ async def _generate_next_prompt_async(self, context: MultiTurnAttackContext[Any] with execution_context( component_role=ComponentRole.ADVERSARIAL_CHAT, attack_strategy_name=self.__class__.__name__, - attack_identifier=self.get_identifier(), component_identifier=self._adversarial_chat.get_identifier(), objective_target_conversation_id=context.session.conversation_id, objective=context.objective, @@ -397,7 +395,6 @@ async def _generate_next_prompt_async(self, context: MultiTurnAttackContext[Any] message=prompt_message, conversation_id=context.session.adversarial_chat_conversation_id, target=self._adversarial_chat, - attack_identifier=self.get_identifier(), labels=context.memory_labels, ) @@ -550,7 +547,6 @@ async def _send_prompt_to_objective_target_async( with execution_context( component_role=ComponentRole.OBJECTIVE_TARGET, attack_strategy_name=self.__class__.__name__, - attack_identifier=self.get_identifier(), component_identifier=self._objective_target.get_identifier(), objective_target_conversation_id=context.session.conversation_id, objective=context.objective, @@ -563,7 +559,6 @@ async def _send_prompt_to_objective_target_async( response_converter_configurations=self._response_converters, target=self._objective_target, labels=context.memory_labels, - attack_identifier=self.get_identifier(), ) if response is None: @@ -598,7 +593,6 @@ async def _score_response_async(self, *, context: MultiTurnAttackContext[Any]) - with execution_context( component_role=ComponentRole.OBJECTIVE_SCORER, attack_strategy_name=self.__class__.__name__, - attack_identifier=self.get_identifier(), component_identifier=self._objective_scorer.get_identifier(), objective_target_conversation_id=context.session.conversation_id, objective=context.objective, diff --git a/pyrit/executor/attack/multi_turn/tree_of_attacks.py b/pyrit/executor/attack/multi_turn/tree_of_attacks.py index 0cd557b1c6..22d0541e9d 100644 --- a/pyrit/executor/attack/multi_turn/tree_of_attacks.py +++ b/pyrit/executor/attack/multi_turn/tree_of_attacks.py @@ -404,7 +404,6 @@ async def initialize_with_prepended_conversation_async( # Use ConversationManager to add messages to memory conversation_manager = ConversationManager( - attack_identifier=self._attack_id, prompt_normalizer=self._prompt_normalizer, ) @@ -413,6 +412,7 @@ async def initialize_with_prepended_conversation_async( conversation_id=self.objective_target_conversation_id, request_converters=self._request_converters, prepended_conversation_config=prepended_conversation_config, + target_identifier=self._objective_target.get_identifier(), ) # Build context string for adversarial chat system prompt (like Crescendo) @@ -558,7 +558,6 @@ async def _send_prompt_to_target_async(self, prompt: str) -> Message: with execution_context( component_role=ComponentRole.OBJECTIVE_TARGET, attack_strategy_name=self._attack_strategy_name, - attack_identifier=self._attack_id, component_identifier=self._objective_target.get_identifier(), objective_target_conversation_id=self.objective_target_conversation_id, objective=self._objective, @@ -570,7 +569,6 @@ async def _send_prompt_to_target_async(self, prompt: str) -> Message: conversation_id=self.objective_target_conversation_id, target=self._objective_target, labels=self._memory_labels, - attack_identifier=self._attack_id, ) # Store the last response text for reference @@ -618,7 +616,6 @@ async def _send_initial_prompt_to_target_async(self) -> Message: with execution_context( component_role=ComponentRole.OBJECTIVE_TARGET, attack_strategy_name=self._attack_strategy_name, - attack_identifier=self._attack_id, component_identifier=self._objective_target.get_identifier(), objective_target_conversation_id=self.objective_target_conversation_id, objective=self._objective, @@ -630,7 +627,6 @@ async def _send_initial_prompt_to_target_async(self) -> Message: conversation_id=self.objective_target_conversation_id, target=self._objective_target, labels=self._memory_labels, - attack_identifier=self._attack_id, ) # Store the last response text for reference @@ -675,7 +671,6 @@ async def _score_response_async(self, *, response: Message, objective: str) -> N with execution_context( component_role=ComponentRole.OBJECTIVE_SCORER, attack_strategy_name=self._attack_strategy_name, - attack_identifier=self._attack_id, component_identifier=self._objective_scorer.get_identifier(), objective_target_conversation_id=self.objective_target_conversation_id, objective=objective, @@ -826,7 +821,9 @@ def duplicate(self) -> "_TreeOfAttacksNode": system_messages = [m for m in messages if m.api_role == "system"] if system_messages: new_id, pieces = self._memory.duplicate_messages(messages=system_messages) - self._memory.add_message_pieces_to_memory(message_pieces=pieces) + self._memory.add_message_pieces_to_memory( + message_pieces=pieces, target_identifier=self._objective_target.get_identifier() + ) duplicate_node.objective_target_conversation_id = new_id else: duplicate_node.objective_target_conversation_id = str(uuid.uuid4()) @@ -1021,7 +1018,6 @@ async def _generate_first_turn_prompt_async(self, objective: str) -> str: self._adversarial_chat.set_system_prompt( system_prompt=system_prompt, conversation_id=self.adversarial_chat_conversation_id, - attack_identifier=self._attack_id, labels=self._memory_labels, # deprecated ) @@ -1138,7 +1134,6 @@ async def _send_to_adversarial_chat_async(self, prompt_text: str) -> str: with execution_context( component_role=ComponentRole.ADVERSARIAL_CHAT, attack_strategy_name=self._attack_strategy_name, - attack_identifier=self._attack_id, component_identifier=self._adversarial_chat.get_identifier(), objective_target_conversation_id=self.objective_target_conversation_id, objective=self._objective, @@ -1148,7 +1143,6 @@ async def _send_to_adversarial_chat_async(self, prompt_text: str) -> str: conversation_id=self.adversarial_chat_conversation_id, target=self._adversarial_chat, labels=self._memory_labels, - attack_identifier=self._attack_id, ) return response.get_value() diff --git a/pyrit/executor/attack/single_turn/context_compliance.py b/pyrit/executor/attack/single_turn/context_compliance.py index 4568a158e8..98b0292541 100644 --- a/pyrit/executor/attack/single_turn/context_compliance.py +++ b/pyrit/executor/attack/single_turn/context_compliance.py @@ -234,7 +234,6 @@ async def _get_objective_as_benign_question_async( response = await self._prompt_normalizer.send_prompt_async( message=message, target=self._adversarial_chat, - attack_identifier=self.get_identifier(), labels=context.memory_labels, ) @@ -261,7 +260,6 @@ async def _get_benign_question_answer_async( response = await self._prompt_normalizer.send_prompt_async( message=message, target=self._adversarial_chat, - attack_identifier=self.get_identifier(), labels=context.memory_labels, ) @@ -286,7 +284,6 @@ async def _get_objective_as_question_async(self, *, objective: str, context: Sin response = await self._prompt_normalizer.send_prompt_async( message=message, target=self._adversarial_chat, - attack_identifier=self.get_identifier(), labels=context.memory_labels, ) diff --git a/pyrit/executor/attack/single_turn/prompt_sending.py b/pyrit/executor/attack/single_turn/prompt_sending.py index 794cc4294e..32e4db677b 100644 --- a/pyrit/executor/attack/single_turn/prompt_sending.py +++ b/pyrit/executor/attack/single_turn/prompt_sending.py @@ -106,7 +106,6 @@ def __init__( # Skip criteria could be set directly in the injected prompt normalizer self._prompt_normalizer = prompt_normalizer or PromptNormalizer() self._conversation_manager = ConversationManager( - attack_identifier=self.get_identifier(), prompt_normalizer=self._prompt_normalizer, ) @@ -314,7 +313,6 @@ async def _send_prompt_to_objective_target_async( with execution_context( component_role=ComponentRole.OBJECTIVE_TARGET, attack_strategy_name=self.__class__.__name__, - attack_identifier=self.get_identifier(), component_identifier=self._objective_target.get_identifier(), objective_target_conversation_id=context.conversation_id, objective=context.params.objective, @@ -326,7 +324,6 @@ async def _send_prompt_to_objective_target_async( request_converter_configurations=self._request_converters, response_converter_configurations=self._response_converters, labels=context.memory_labels, # combined with strategy labels at _setup() - attack_identifier=self.get_identifier(), ) async def _evaluate_response_async( @@ -353,7 +350,6 @@ async def _evaluate_response_async( with execution_context( component_role=ComponentRole.OBJECTIVE_SCORER, attack_strategy_name=self.__class__.__name__, - attack_identifier=self.get_identifier(), component_identifier=self._objective_scorer.get_identifier() if self._objective_scorer else None, objective=objective, ): diff --git a/pyrit/executor/attack/streaming/barge_in.py b/pyrit/executor/attack/streaming/barge_in.py index 161e364f85..2c96fbdb62 100644 --- a/pyrit/executor/attack/streaming/barge_in.py +++ b/pyrit/executor/attack/streaming/barge_in.py @@ -99,7 +99,6 @@ def __init__( self._response_converters = attack_converter_config.response_converters self._prompt_normalizer = prompt_normalizer or PromptNormalizer() self._conversation_manager = ConversationManager( - attack_identifier=self.get_identifier(), prompt_normalizer=self._prompt_normalizer, ) @@ -163,7 +162,6 @@ async def _perform_async(self, *, context: BargeInAttackContext[Any]) -> AttackR request_converter_configurations=self._request_converters, response_converter_configurations=self._response_converters, prepended_conversation=context.prepended_conversation, - attack_identifier=self.get_identifier(), persist_prepended_conversation=False, ) diff --git a/pyrit/executor/benchmark/fairness_bias.py b/pyrit/executor/benchmark/fairness_bias.py index 4e3bfaa505..4ab0fe432a 100644 --- a/pyrit/executor/benchmark/fairness_bias.py +++ b/pyrit/executor/benchmark/fairness_bias.py @@ -21,7 +21,6 @@ from pyrit.models import ( AttackOutcome, AttackResult, - ComponentIdentifier, Message, build_atomic_attack_identifier, ) @@ -198,7 +197,6 @@ async def _perform_async(self, *, context: FairnessBiasBenchmarkContext) -> Atta objective=context.generated_objective, outcome=AttackOutcome.FAILURE, atomic_attack_identifier=build_atomic_attack_identifier( - attack_identifier=ComponentIdentifier.of(self), ), labels=context.memory_labels, ) diff --git a/pyrit/executor/promptgen/anecdoctor.py b/pyrit/executor/promptgen/anecdoctor.py index 0d953fe8ad..abc8efbd69 100644 --- a/pyrit/executor/promptgen/anecdoctor.py +++ b/pyrit/executor/promptgen/anecdoctor.py @@ -218,7 +218,6 @@ async def _setup_async(self, *, context: AnecdoctorContext) -> None: self._objective_target.set_system_prompt( system_prompt=system_prompt, conversation_id=context.conversation_id, - attack_identifier=self.get_identifier(), labels=context.memory_labels, # deprecated ) @@ -312,7 +311,6 @@ async def _send_examples_to_target_async( request_converter_configurations=self._request_converters, response_converter_configurations=self._response_converters, labels=context.memory_labels, - attack_identifier=self.get_identifier(), ) def _load_prompt_from_yaml(self, *, yaml_filename: str) -> str: @@ -381,7 +379,6 @@ async def _extract_knowledge_graph_async(self, *, context: AnecdoctorContext) -> self._processing_model.set_system_prompt( system_prompt=kg_system_prompt, conversation_id=kg_conversation_id, - attack_identifier=self.get_identifier(), labels=self._memory_labels, # deprecated ) @@ -399,7 +396,6 @@ async def _extract_knowledge_graph_async(self, *, context: AnecdoctorContext) -> request_converter_configurations=self._request_converters, response_converter_configurations=self._response_converters, labels=self._memory_labels, - attack_identifier=self.get_identifier(), ) if not kg_response: diff --git a/pyrit/executor/promptgen/fuzzer/fuzzer.py b/pyrit/executor/promptgen/fuzzer/fuzzer.py index 73410799a4..492ae2ee62 100644 --- a/pyrit/executor/promptgen/fuzzer/fuzzer.py +++ b/pyrit/executor/promptgen/fuzzer/fuzzer.py @@ -1003,7 +1003,6 @@ async def _send_prompts_to_target_async(self, *, context: FuzzerContext, prompts requests=requests, target=self._objective_target, labels=context.memory_labels, - attack_identifier=self.get_identifier(), batch_size=self._batch_size, ) diff --git a/pyrit/executor/promptgen/fuzzer/fuzzer_converter_base.py b/pyrit/executor/promptgen/fuzzer/fuzzer_converter_base.py index 5b211487cc..868526279f 100644 --- a/pyrit/executor/promptgen/fuzzer/fuzzer_converter_base.py +++ b/pyrit/executor/promptgen/fuzzer/fuzzer_converter_base.py @@ -84,7 +84,6 @@ async def convert_async(self, *, prompt: str, input_type: PromptDataType = "text self.converter_target.set_system_prompt( system_prompt=self.system_prompt, conversation_id=conversation_id, - attack_identifier=None, ) formatted_prompt = f"===={self.template_label} BEGINS====\n{prompt}\n===={self.template_label} ENDS====" @@ -97,7 +96,6 @@ async def convert_async(self, *, prompt: str, input_type: PromptDataType = "text converted_value=formatted_prompt, conversation_id=conversation_id, sequence=1, - prompt_target_identifier=self.converter_target.get_identifier(), original_value_data_type=input_type, converted_value_data_type=input_type, converter_identifiers=[self.get_identifier()], diff --git a/pyrit/executor/promptgen/fuzzer/fuzzer_crossover_converter.py b/pyrit/executor/promptgen/fuzzer/fuzzer_crossover_converter.py index 812979a797..05eb69aab7 100644 --- a/pyrit/executor/promptgen/fuzzer/fuzzer_crossover_converter.py +++ b/pyrit/executor/promptgen/fuzzer/fuzzer_crossover_converter.py @@ -82,7 +82,6 @@ async def convert_async(self, *, prompt: str, input_type: PromptDataType = "text self.converter_target.set_system_prompt( system_prompt=self.system_prompt, conversation_id=conversation_id, - attack_identifier=None, ) formatted_prompt = f"===={self.template_label} BEGINS====\n{prompt}\n===={self.template_label} ENDS====" @@ -99,7 +98,6 @@ async def convert_async(self, *, prompt: str, input_type: PromptDataType = "text converted_value=formatted_prompt, conversation_id=conversation_id, sequence=1, - prompt_target_identifier=self.converter_target.get_identifier(), original_value_data_type=input_type, converted_value_data_type=input_type, converter_identifiers=[self.get_identifier()], diff --git a/pyrit/executor/promptgen/fuzzer/fuzzer_expand_converter.py b/pyrit/executor/promptgen/fuzzer/fuzzer_expand_converter.py index 627ed159ed..91b3ea127b 100644 --- a/pyrit/executor/promptgen/fuzzer/fuzzer_expand_converter.py +++ b/pyrit/executor/promptgen/fuzzer/fuzzer_expand_converter.py @@ -58,7 +58,6 @@ async def convert_async(self, *, prompt: str, input_type: PromptDataType = "text self.converter_target.set_system_prompt( system_prompt=self.system_prompt, conversation_id=conversation_id, - attack_identifier=None, ) formatted_prompt = f"===={self.template_label} BEGINS====\n{prompt}\n===={self.template_label} ENDS====" @@ -72,7 +71,6 @@ async def convert_async(self, *, prompt: str, input_type: PromptDataType = "text converted_value=formatted_prompt, conversation_id=conversation_id, sequence=1, - prompt_target_identifier=self.converter_target.get_identifier(), original_value_data_type=input_type, converted_value_data_type=input_type, converter_identifiers=[self.get_identifier()], diff --git a/pyrit/executor/workflow/xpia.py b/pyrit/executor/workflow/xpia.py index c367cd9b02..80c0fdefd0 100644 --- a/pyrit/executor/workflow/xpia.py +++ b/pyrit/executor/workflow/xpia.py @@ -335,7 +335,6 @@ async def _setup_attack_async(self, *, context: XPIAContext) -> str: response_converter_configurations=self._response_converters, target=self._attack_setup_target, labels=context.memory_labels, - attack_identifier=self.get_identifier(), conversation_id=context.attack_setup_target_conversation_id, ) @@ -374,7 +373,6 @@ async def _execute_processing_async(self, *, context: XPIAContext) -> str: original_value=processing_response, original_value_data_type="text", role="assistant", - attack_identifier=self.get_identifier(), ) ], ) @@ -576,7 +574,6 @@ async def process_async() -> str: request_converter_configurations=self._request_converters, response_converter_configurations=self._response_converters, labels=context.memory_labels, - attack_identifier=self.get_identifier(), conversation_id=context.processing_conversation_id, ) diff --git a/pyrit/memory/alembic/versions/b2f4c6a8d1e3_add_conversations_table.py b/pyrit/memory/alembic/versions/b2f4c6a8d1e3_add_conversations_table.py new file mode 100644 index 0000000000..8b8d93f8ca --- /dev/null +++ b/pyrit/memory/alembic/versions/b2f4c6a8d1e3_add_conversations_table.py @@ -0,0 +1,142 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Introduce the Conversations table for conversation-scoped metadata and stop +stamping that metadata onto every PromptMemoryEntry row. + +Creates ``Conversations`` (one row per ``conversation_id``) holding the target +identifier, backfills it from the existing +``PromptMemoryEntries.prompt_target_identifier`` column (plus placeholder rows for +conversation_ids referenced only by ``AttackResultEntries``), and drops the now +per-row ``prompt_target_identifier`` and ``attack_identifier`` columns from +``PromptMemoryEntries``. + +Revision ID: b2f4c6a8d1e3 +Revises: 9c8b7a6d5e4f +Create Date: 2026-05-20 12:00:00.000000 +""" + +from __future__ import annotations + +import logging +from collections.abc import Sequence # noqa: TC003 + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "b2f4c6a8d1e3" +down_revision: str | None = "9c8b7a6d5e4f" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +logger = logging.getLogger(__name__) + + +def upgrade() -> None: + """Apply this schema upgrade.""" + op.create_table( + "Conversations", + sa.Column("conversation_id", sa.String(), primary_key=True, nullable=False), + sa.Column("target_identifier", sa.JSON(), nullable=True), + sa.Column("pyrit_version", sa.String(), nullable=True), + ) + + _backfill_conversations() + + # Stop persisting conversation-scoped metadata per row: the target identifier now + # lives in Conversations, and the attack identifier is no longer stamped on pieces + # (resolved via AttackResult). Batch op for SQLite portability. + with op.batch_alter_table("PromptMemoryEntries") as batch_op: + batch_op.drop_column("prompt_target_identifier") + batch_op.drop_column("attack_identifier") + + +def downgrade() -> None: + """Revert this schema upgrade.""" + # Re-add the dropped columns (data is not restored) then drop Conversations. + with op.batch_alter_table("PromptMemoryEntries") as batch_op: + batch_op.add_column(sa.Column("prompt_target_identifier", sa.JSON(), nullable=True)) + batch_op.add_column(sa.Column("attack_identifier", sa.JSON(), nullable=True)) + op.drop_table("Conversations") + + +def _backfill_conversations() -> None: + """ + Populate ``Conversations`` with one row per distinct ``conversation_id``. + + The target identifier is taken from the existing + ``PromptMemoryEntries.prompt_target_identifier`` column, preferring a non-null + value when a conversation has rows with differing targets (a non-null target + always wins over null; a WARNING is logged if two distinct non-null targets are + seen for the same conversation). Conversation ids that are referenced only by + ``AttackResultEntries`` (no prompt rows) get a placeholder row with a null + target so reads/joins stay consistent. + + Idempotent: only conversation_ids not already present in ``Conversations`` are + inserted. + """ + bind = op.get_bind() + + existing_ids = { + row[0] for row in bind.execute(sa.text('SELECT conversation_id FROM "Conversations"')).fetchall() + } + + targets_by_conversation: dict[str, str | None] = {} + conflict_warnings = 0 + + prompt_rows = bind.execute( + sa.text( + 'SELECT conversation_id, prompt_target_identifier ' + 'FROM "PromptMemoryEntries" ' + "WHERE conversation_id IS NOT NULL " + "ORDER BY sequence" + ) + ).fetchall() + + for conversation_id, target_identifier in prompt_rows: + if conversation_id is None: + continue + current = targets_by_conversation.get(conversation_id, "__unset__") + if current == "__unset__": + targets_by_conversation[conversation_id] = target_identifier + elif target_identifier is not None: + if current is None: + targets_by_conversation[conversation_id] = target_identifier + elif current != target_identifier: + conflict_warnings += 1 + logger.warning( + f"Backfill: conversation_id {conversation_id!r} has multiple distinct " + f"target identifiers; keeping the first non-null value." + ) + + # Conversation ids referenced only by AttackResultEntries (no prompt rows). + attack_rows = bind.execute( + sa.text('SELECT DISTINCT conversation_id FROM "AttackResultEntries" WHERE conversation_id IS NOT NULL') + ).fetchall() + for (conversation_id,) in attack_rows: + if conversation_id is not None and conversation_id not in targets_by_conversation: + targets_by_conversation[conversation_id] = None + + insert_stmt = sa.text( + 'INSERT INTO "Conversations" (conversation_id, target_identifier, pyrit_version) ' + "VALUES (:cid, :target, :version)" + ) + + inserted = 0 + for conversation_id, target_identifier in targets_by_conversation.items(): + if conversation_id in existing_ids: + continue + bind.execute( + insert_stmt, + {"cid": conversation_id, "target": target_identifier, "version": None}, + ) + inserted += 1 + + if inserted or conflict_warnings: + logger.info( + f"Conversations backfill: inserted {inserted} row(s); " + f"{conflict_warnings} target-conflict warning(s)." + ) diff --git a/pyrit/memory/azure_sql_memory.py b/pyrit/memory/azure_sql_memory.py index 7d62ab5bd9..f3bfe8e126 100644 --- a/pyrit/memory/azure_sql_memory.py +++ b/pyrit/memory/azure_sql_memory.py @@ -27,7 +27,7 @@ PromptMemoryEntry, ) from pyrit.memory.storage import AzureBlobStorageIO -from pyrit.models import ConversationStats, MessagePiece +from pyrit.models import ComponentIdentifier, ConversationStats, MessagePiece if TYPE_CHECKING: from azure.core.credentials import AccessToken @@ -695,7 +695,9 @@ def _get_scenario_result_label_condition(self, *, labels: dict[str, str]) -> Any conditions.append(condition) return and_(*conditions) - def add_message_pieces_to_memory(self, *, message_pieces: Sequence[MessagePiece]) -> None: + def add_message_pieces_to_memory( + self, *, message_pieces: Sequence[MessagePiece], target_identifier: ComponentIdentifier | None = None + ) -> None: """ Insert a list of message pieces into the memory storage. @@ -705,6 +707,8 @@ def add_message_pieces_to_memory(self, *, message_pieces: Sequence[MessagePiece] Args: message_pieces (Sequence[MessagePiece]): A sequence of MessagePiece instances to be added. + target_identifier (ComponentIdentifier | None): The target the conversation(s) + are held with, if known. Applied to every distinct ``conversation_id``. """ # ``not_in_memory`` pieces are ephemeral — typically synthesized inside a # scorer to score arbitrary content that never came through a real @@ -715,6 +719,7 @@ def add_message_pieces_to_memory(self, *, message_pieces: Sequence[MessagePiece] pieces_to_insert = [piece for piece in message_pieces if not piece.not_in_memory] if not pieces_to_insert: return + self._capture_conversations(message_pieces=pieces_to_insert, target_identifier=target_identifier) self._insert_entries(entries=[PromptMemoryEntry(entry=piece) for piece in pieces_to_insert]) def dispose_engine(self) -> None: @@ -824,7 +829,9 @@ def _query_entries( try: query = session.query(model_class) if join_scores and model_class == PromptMemoryEntry: - query = query.options(joinedload(PromptMemoryEntry.scores)) + query = query.options( + joinedload(PromptMemoryEntry.scores), + ) elif model_class == AttackResultEntry: query = query.options( joinedload(AttackResultEntry.last_response).joinedload(PromptMemoryEntry.scores), diff --git a/pyrit/memory/memory_interface.py b/pyrit/memory/memory_interface.py index ff4a6af97f..a1c2219cd4 100644 --- a/pyrit/memory/memory_interface.py +++ b/pyrit/memory/memory_interface.py @@ -13,7 +13,7 @@ from pathlib import Path from typing import TYPE_CHECKING, Any, Literal, TypeVar -from sqlalchemy import MetaData, and_, not_, or_ +from sqlalchemy import MetaData, and_, not_, or_, select from sqlalchemy.engine.base import Engine from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.orm.attributes import InstrumentedAttribute @@ -27,6 +27,7 @@ from pyrit.memory.memory_models import ( AttackResultEntry, Base, + ConversationEntry, EmbeddingDataEntry, PromptMemoryEntry, ScenarioResultEntry, @@ -41,6 +42,8 @@ ) from pyrit.models import ( AttackResult, + ComponentIdentifier, + Conversation, ConversationStats, IdentifierFilter, IdentifierType, @@ -347,10 +350,89 @@ def _get_seed_metadata_conditions(self, *, metadata: dict[str, str | int]) -> An """ @abc.abstractmethod - def add_message_pieces_to_memory(self, *, message_pieces: Sequence[MessagePiece]) -> None: + def add_message_pieces_to_memory( + self, *, message_pieces: Sequence[MessagePiece], target_identifier: ComponentIdentifier | None = None + ) -> None: """ Insert a list of message pieces into the memory storage. + + Args: + message_pieces (Sequence[MessagePiece]): The pieces to persist. + target_identifier (ComponentIdentifier | None): The target the conversation(s) + are held with, if known. A conversation is always with a single target, so + this is applied to every distinct ``conversation_id`` in ``message_pieces``. + """ + + def _capture_conversations( + self, *, message_pieces: Sequence[MessagePiece], target_identifier: ComponentIdentifier | None = None + ) -> None: """ + Record one ``Conversations`` row per conversation for the given pieces. + + Conversation-scoped metadata (currently the target identifier) is persisted + once per ``conversation_id`` instead of being stamped onto every piece. This + runs from each backend's ``add_message_pieces_to_memory`` so every write path + -- normalizer, conversation duplication, prepended conversations, direct + target writers -- captures the target through a single choke point. + + A conversation is always held with a single target, so ``target_identifier`` + (when provided) is applied to every distinct ``conversation_id`` in this call. + A ``None`` target never overwrites a target already recorded for the + conversation (see ``_upsert_conversation``). + + Args: + message_pieces (Sequence[MessagePiece]): The pieces being persisted. + target_identifier (ComponentIdentifier | None): The target the conversation(s) + are held with, if known. + """ + conversation_ids: list[str] = [] + seen: set[str] = set() + for piece in message_pieces: + if piece.not_in_memory: + continue + conversation_id = piece.conversation_id + if conversation_id not in seen: + seen.add(conversation_id) + conversation_ids.append(conversation_id) + for conversation_id in conversation_ids: + self._upsert_conversation(conversation_id=conversation_id, target_identifier=target_identifier) + + def _upsert_conversation( + self, *, conversation_id: str, target_identifier: ComponentIdentifier | None + ) -> None: + """ + Insert or update the ``Conversations`` row for ``conversation_id``. + + A non-``None`` ``target_identifier`` is written; a ``None`` value never + overwrites a target already recorded for the conversation (so response/copy + pieces and write ordering cannot clobber it). + + Args: + conversation_id (str): The conversation to record. + target_identifier (ComponentIdentifier | None): The target the conversation + is held with, if known. + + Raises: + SQLAlchemyError: If the upsert fails. + """ + if not conversation_id: + return + entry = ConversationEntry( + conversation=Conversation(conversation_id=conversation_id, target_identifier=target_identifier) + ) + with closing(self.get_session()) as session: + try: + existing = session.get(ConversationEntry, conversation_id) + if existing is None: + session.add(entry) + elif target_identifier is not None: + existing.target_identifier = entry.target_identifier + existing.pyrit_version = entry.pyrit_version + session.commit() + except SQLAlchemyError as e: + session.rollback() + logger.exception(f"Error upserting conversation {conversation_id}: {e}") + raise @abc.abstractmethod def _add_embeddings_to_memory(self, *, embedding_data: Sequence[EmbeddingDataEntry]) -> None: @@ -856,6 +938,25 @@ def get_conversation(self, *, conversation_id: str) -> MutableSequence[Message]: message_pieces = self.get_message_pieces(conversation_id=conversation_id) return group_conversation_message_pieces_by_sequence(message_pieces=message_pieces) + def get_conversation_metadata(self, *, conversation_id: str) -> Conversation | None: + """ + Return the conversation-scoped metadata stored for ``conversation_id``. + + Args: + conversation_id (str): The conversation to look up. + + Returns: + Conversation | None: The conversation metadata (including the target + identifier), or ``None`` if no row exists for the conversation. + """ + entries = self._query_entries( + ConversationEntry, + conditions=ConversationEntry.conversation_id == str(conversation_id), + ) + if not entries: + return None + return entries[0].get_conversation() + def get_request_from_response(self, *, response: Message) -> Message: """ Retrieve the request that produced the given response. @@ -877,6 +978,80 @@ def get_request_from_response(self, *, response: Message) -> Message: conversation = self.get_conversation(conversation_id=response.conversation_id) return conversation[response.sequence - 1] + def _resolve_attack_id_to_conversation_condition(self, *, attack_id: str | uuid.UUID) -> Any: + """ + Build a deprecated ``attack_id`` filter condition for ``get_message_pieces``. + + The attack identifier is no longer stamped on every piece. Instead, resolve the + raw attack-strategy hash against persisted ``AttackResult`` rows and constrain + the query to those attacks' main conversations. + + Args: + attack_id (str | uuid.UUID): The raw attack-strategy identifier hash. + + Returns: + Any: A SQLAlchemy condition restricting pieces to the matching attacks' + main conversation ids (matches nothing when no attack matches). + """ + print_deprecation_message( + old_item="get_message_pieces(attack_id=...) / get_prompt_scores(attack_id=...)", + new_item="get_message_pieces(conversation_id=...) resolved via get_attack_results(...)", + removed_in="0.17.0", + ) + matching_conversation_ids = { + result.conversation_id + for result in self.get_attack_results() + if (strategy := result.get_attack_strategy_identifier()) is not None and strategy.hash == str(attack_id) + } + return PromptMemoryEntry.conversation_id.in_(matching_conversation_ids) + + def _build_message_piece_identifier_conditions( + self, *, identifier_filters: Sequence[IdentifierFilter] + ) -> list[Any]: + """ + Build ``get_message_pieces`` conditions for identifier filters. + + ``CONVERTER`` identifiers remain on the piece. ``TARGET`` identifiers moved to + the ``Conversations`` table, so target filters are applied via a subquery on + ``ConversationEntry`` correlated by ``conversation_id``. ``ATTACK`` identifiers + are no longer stamped on pieces (use ``get_attack_results`` instead) and are + rejected by ``_build_identifier_filter_conditions``. + + Args: + identifier_filters (Sequence[IdentifierFilter]): The filters to convert. + + Returns: + list[Any]: SQLAlchemy conditions for the message-piece query. + """ + conditions: list[Any] = [] + piece_filters = [f for f in identifier_filters if f.identifier_type != IdentifierType.TARGET] + target_filters = [f for f in identifier_filters if f.identifier_type == IdentifierType.TARGET] + + if piece_filters: + conditions.extend( + self._build_identifier_filter_conditions( + identifier_filters=piece_filters, + identifier_column_map={ + IdentifierType.CONVERTER: PromptMemoryEntry.converter_identifiers, + }, + caller="get_message_pieces", + ) + ) + if target_filters: + target_conditions = self._build_identifier_filter_conditions( + identifier_filters=target_filters, + identifier_column_map={ + IdentifierType.TARGET: ConversationEntry.target_identifier, + }, + caller="get_message_pieces", + ) + conditions.append( + PromptMemoryEntry.conversation_id.in_( + select(ConversationEntry.conversation_id).where(and_(*target_conditions)) + ) + ) + return conditions + def get_message_pieces( self, *, @@ -932,13 +1107,7 @@ def get_message_pieces( try: conditions: list[Any] = [] if attack_id: - conditions.append( - self._get_condition_json_property_match( - json_column=PromptMemoryEntry.attack_identifier, - property_path="$.hash", - value=str(attack_id), - ) - ) + conditions.append(self._resolve_attack_id_to_conversation_condition(attack_id=attack_id)) if role: conditions.append(PromptMemoryEntry.role == role) if conversation_id: @@ -956,17 +1125,7 @@ def get_message_pieces( if not_data_type: conditions.append(PromptMemoryEntry.converted_value_data_type != not_data_type) if identifier_filters: - conditions.extend( - self._build_identifier_filter_conditions( - identifier_filters=identifier_filters, - identifier_column_map={ - IdentifierType.ATTACK: PromptMemoryEntry.attack_identifier, - IdentifierType.TARGET: PromptMemoryEntry.prompt_target_identifier, - IdentifierType.CONVERTER: PromptMemoryEntry.converter_identifiers, - }, - caller="get_message_pieces", - ) - ) + conditions.extend(self._build_message_piece_identifier_conditions(identifier_filters=identifier_filters)) # Identify list parameters that may need batching list_params: list[tuple[InstrumentedAttribute[Any], Sequence[Any], str]] = [] @@ -1034,8 +1193,10 @@ def duplicate_conversation(self, *, conversation_id: str) -> str: The uuid for the new conversation. """ messages = self.get_conversation(conversation_id=conversation_id) + source_metadata = self.get_conversation_metadata(conversation_id=conversation_id) + source_target = source_metadata.target_identifier if source_metadata else None new_conversation_id, all_pieces = self.duplicate_messages(messages=messages) - self.add_message_pieces_to_memory(message_pieces=all_pieces) + self.add_message_pieces_to_memory(message_pieces=all_pieces, target_identifier=source_target) return new_conversation_id def duplicate_conversation_excluding_last_turn(self, *, conversation_id: str) -> str: @@ -1067,12 +1228,16 @@ def duplicate_conversation_excluding_last_turn(self, *, conversation_id: str) -> message for message in messages if message.sequence <= last_message.sequence - length_of_sequence_to_remove ] + source_metadata = self.get_conversation_metadata(conversation_id=conversation_id) + source_target = source_metadata.target_identifier if source_metadata else None new_conversation_id, all_pieces = self.duplicate_messages(messages=messages_to_duplicate) - self.add_message_pieces_to_memory(message_pieces=all_pieces) + self.add_message_pieces_to_memory(message_pieces=all_pieces, target_identifier=source_target) return new_conversation_id - def add_message_to_memory(self, *, request: Message) -> None: + def add_message_to_memory( + self, *, request: Message, target_identifier: ComponentIdentifier | None = None + ) -> None: """ Insert a list of message pieces into the memory storage. @@ -1080,7 +1245,9 @@ def add_message_to_memory(self, *, request: Message) -> None: If necessary, generates embedding data for applicable entries Args: - request (MessagePiece): The message piece to add to the memory. + request (Message): The message to add to the memory. + target_identifier (ComponentIdentifier | None): The target the conversation + is held with, if known. Forwarded to ``add_message_pieces_to_memory``. """ request.validate() @@ -1089,7 +1256,7 @@ def add_message_to_memory(self, *, request: Message) -> None: self._update_sequence(message_pieces=message_pieces) - self.add_message_pieces_to_memory(message_pieces=message_pieces) + self.add_message_pieces_to_memory(message_pieces=message_pieces, target_identifier=target_identifier) if self.memory_embedding: for piece in message_pieces: diff --git a/pyrit/memory/memory_models.py b/pyrit/memory/memory_models.py index 9f52b38afa..afec28d01e 100644 --- a/pyrit/memory/memory_models.py +++ b/pyrit/memory/memory_models.py @@ -37,6 +37,7 @@ AttackResult, ChatMessageRole, ComponentIdentifier, + Conversation, ConversationReference, ConversationType, MessagePiece, @@ -239,7 +240,6 @@ class PromptMemoryEntry(Base): e.g. the URI from a file uploaded to a blob store, or a document type you want to upload. converters (list[PromptConverter]): The converters for the prompt. prompt_target (PromptTarget): The target for the prompt. - attack_identifier (dict[str, str]): The attack identifier for the prompt. original_value_data_type (PromptDataType): The data type of the original prompt (text, image) original_value (str): The text of the original prompt. If prompt is an image, it's a link. original_value_sha256 (str): The SHA256 hash of the original prompt data. @@ -267,8 +267,6 @@ class PromptMemoryEntry(Base): prompt_metadata: Mapped[dict[str, str | int]] = mapped_column(JSON) targeted_harm_categories: Mapped[list[str] | None] = mapped_column(JSON) converter_identifiers: Mapped[list[dict[str, str]] | None] = mapped_column(JSON) - prompt_target_identifier: Mapped[dict[str, str]] = mapped_column(JSON) - attack_identifier: Mapped[dict[str, str]] = mapped_column(JSON) response_error: Mapped[Literal["blocked", "none", "processing", "unknown"]] = mapped_column(String, nullable=True) original_value_data_type: Mapped[PromptDataType] = mapped_column(String, nullable=False) @@ -310,8 +308,6 @@ def __init__(self, *, entry: MessagePiece) -> None: self.prompt_metadata = entry.prompt_metadata self.targeted_harm_categories = entry.targeted_harm_categories self.converter_identifiers = _dump_identifiers(entry.converter_identifiers) - self.prompt_target_identifier = _dump_identifier(entry.prompt_target_identifier) or {} - self.attack_identifier = _dump_identifier(entry.attack_identifier) or {} self.original_value = entry.original_value self.original_value_data_type = entry.original_value_data_type @@ -336,8 +332,6 @@ def get_message_piece(self) -> MessagePiece: # Reconstruct ComponentIdentifiers with the stored pyrit_version stored_version = self.pyrit_version or LEGACY_PYRIT_VERSION converter_ids = _load_identifiers(self.converter_identifiers, pyrit_version=stored_version) - target_id = _load_identifier(self.prompt_target_identifier, pyrit_version=stored_version) - attack_id = _load_identifier(self.attack_identifier, pyrit_version=stored_version) message_piece = MessagePiece( role=self.role, @@ -350,8 +344,6 @@ def get_message_piece(self) -> MessagePiece: sequence=self.sequence, prompt_metadata=self.prompt_metadata, converter_identifiers=converter_ids or [], - prompt_target_identifier=target_id, - attack_identifier=attack_id, original_value_data_type=self.original_value_data_type, converted_value_data_type=self.converted_value_data_type, response_error=self.response_error, @@ -374,13 +366,51 @@ def __str__(self) -> str: Returns: str: Formatted string representation of the memory entry. """ - if self.prompt_target_identifier: - # prompt_target_identifier is stored as dict in the database - class_name = self.prompt_target_identifier.get("class_name") or self.prompt_target_identifier.get( - "__type__", "Unknown" - ) - return f"{class_name}: {self.role}: {self.converted_value}" - return f": {self.role}: {self.converted_value}" + return f"{self.role}: {self.converted_value}" + + +class ConversationEntry(Base): + """ + Conversation-scoped metadata, persisted once per ``conversation_id``. + + Holds identifiers that belong to the conversation as a whole -- currently the + target identifier -- so they are not duplicated onto every ``PromptMemoryEntry`` + row. The target is captured once when the conversation's pieces are written and + read back via ``MemoryInterface.get_conversation_metadata`` (it is not stamped + onto individual pieces). + """ + + __tablename__ = "Conversations" + __table_args__ = {"extend_existing": True} + + conversation_id = mapped_column(String, primary_key=True, nullable=False) + target_identifier: Mapped[dict[str, str] | None] = mapped_column(JSON, nullable=True) + + # Version of PyRIT used when this entry was created. Nullable for backwards + # compatibility with existing databases. + pyrit_version = mapped_column(String, nullable=True) + + def __init__(self, *, conversation: Conversation) -> None: + """ + Initialize a ConversationEntry from a Conversation model. + + Args: + conversation (Conversation): The conversation metadata to persist. + """ + self.conversation_id = conversation.conversation_id + self.target_identifier = _dump_identifier(conversation.target_identifier) + self.pyrit_version = pyrit.__version__ + + def get_conversation(self) -> Conversation: + """ + Convert this database entry back into a Conversation model. + + Returns: + Conversation: The reconstructed conversation metadata. + """ + stored_version = self.pyrit_version or LEGACY_PYRIT_VERSION + target_id = _load_identifier(self.target_identifier, pyrit_version=stored_version) + return Conversation(conversation_id=self.conversation_id, target_identifier=target_id) class EmbeddingDataEntry(Base): diff --git a/pyrit/memory/sqlite_memory.py b/pyrit/memory/sqlite_memory.py index 61f556aee2..17c0f9485b 100644 --- a/pyrit/memory/sqlite_memory.py +++ b/pyrit/memory/sqlite_memory.py @@ -30,7 +30,7 @@ ScenarioResultEntry, ) from pyrit.memory.storage import DiskStorageIO -from pyrit.models import ConversationStats, MessagePiece +from pyrit.models import ComponentIdentifier, ConversationStats, MessagePiece logger = logging.getLogger(__name__) @@ -302,17 +302,25 @@ def _get_condition_json_array_match( combined = joiner.join(conditions) return text(f"({combined})").bindparams(**bindparams_dict) - def add_message_pieces_to_memory(self, *, message_pieces: Sequence[MessagePiece]) -> None: + def add_message_pieces_to_memory( + self, *, message_pieces: Sequence[MessagePiece], target_identifier: ComponentIdentifier | None = None + ) -> None: """ Insert a list of message pieces into the memory storage. Pieces flagged via ``MessagePiece.not_in_memory = True`` are silently filtered out so callers don't need to track persistence policy themselves. + + Args: + message_pieces (Sequence[MessagePiece]): The pieces to persist. + target_identifier (ComponentIdentifier | None): The target the conversation(s) + are held with, if known. Applied to every distinct ``conversation_id``. """ pieces_to_insert = [piece for piece in message_pieces if not piece.not_in_memory] if not pieces_to_insert: return + self._capture_conversations(message_pieces=pieces_to_insert, target_identifier=target_identifier) self._insert_entries(entries=[PromptMemoryEntry(entry=piece) for piece in pieces_to_insert]) def _add_embeddings_to_memory(self, *, embedding_data: Sequence[EmbeddingDataEntry]) -> None: @@ -362,10 +370,13 @@ def _query_entries( try: query = session.query(model_class) if join_scores and model_class == PromptMemoryEntry: - query = query.options(joinedload(PromptMemoryEntry.scores)) + query = query.options( + joinedload(PromptMemoryEntry.scores), + ) elif model_class == AttackResultEntry: query = query.options( - joinedload(AttackResultEntry.last_response).joinedload(PromptMemoryEntry.scores), + joinedload(AttackResultEntry.last_response) + .joinedload(PromptMemoryEntry.scores), joinedload(AttackResultEntry.last_score), ) if conditions is not None: diff --git a/pyrit/models/__init__.py b/pyrit/models/__init__.py index 6ad07cfcc9..823543dfd6 100644 --- a/pyrit/models/__init__.py +++ b/pyrit/models/__init__.py @@ -60,6 +60,7 @@ SeedType, ) from pyrit.models.messages import ( + Conversation, Message, MessagePiece, construct_response_from_request, @@ -112,6 +113,7 @@ "ComponentIdentifier", "compute_eval_hash", "config_hash", + "Conversation", "ConversationReference", "ConversationStats", "ConversationType", diff --git a/pyrit/models/messages/__init__.py b/pyrit/models/messages/__init__.py index fca91f47ba..58c9f1a63e 100644 --- a/pyrit/models/messages/__init__.py +++ b/pyrit/models/messages/__init__.py @@ -9,6 +9,7 @@ - conversations: Free functions that operate on collections of messages/pieces. """ +from pyrit.models.messages.conversation import Conversation from pyrit.models.messages.conversations import ( construct_response_from_request, flatten_to_message_pieces, @@ -20,6 +21,7 @@ from pyrit.models.messages.message_piece import MessagePiece, sort_message_pieces __all__ = [ + "Conversation", "Message", "MessagePiece", "construct_response_from_request", diff --git a/pyrit/models/messages/conversation.py b/pyrit/models/messages/conversation.py new file mode 100644 index 0000000000..f5b8d956fb --- /dev/null +++ b/pyrit/models/messages/conversation.py @@ -0,0 +1,31 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from __future__ import annotations + +from pydantic import BaseModel, ConfigDict + +from pyrit.models.score import ( # noqa: TC001 (runtime-required by Pydantic field annotations) + ComponentIdentifierField, +) + + +class Conversation(BaseModel): + """ + Conversation-scoped metadata shared by every piece in a conversation. + + A ``Conversation`` records identifiers that belong to the conversation as a + whole rather than to any individual ``MessagePiece`` -- most importantly the + target the conversation is held with. Persisting these once per conversation + (instead of stamping them onto every piece/row) is what keeps ``MessagePiece`` + small. + """ + + model_config = ConfigDict( + arbitrary_types_allowed=True, + extra="forbid", + validate_assignment=False, + ) + + conversation_id: str + target_identifier: ComponentIdentifierField | None = None diff --git a/pyrit/models/messages/conversations.py b/pyrit/models/messages/conversations.py index b225e527b2..e4cb34a121 100644 --- a/pyrit/models/messages/conversations.py +++ b/pyrit/models/messages/conversations.py @@ -205,8 +205,6 @@ def construct_response_from_request( original_value=resp_text, conversation_id=request.conversation_id, labels=request.labels, - prompt_target_identifier=request.prompt_target_identifier, - attack_identifier=request.attack_identifier, original_value_data_type=response_type, converted_value_data_type=response_type, prompt_metadata=prompt_metadata or {}, diff --git a/pyrit/models/messages/message_piece.py b/pyrit/models/messages/message_piece.py index 1ad0533231..da097b129c 100644 --- a/pyrit/models/messages/message_piece.py +++ b/pyrit/models/messages/message_piece.py @@ -112,8 +112,6 @@ class MessagePiece(BaseModel): targeted_harm_categories: list[str] = Field(default_factory=list) prompt_metadata: dict[str, Any] = Field(default_factory=dict) converter_identifiers: list[ComponentIdentifierField] = Field(default_factory=list) - prompt_target_identifier: ComponentIdentifierField | None = None - attack_identifier: ComponentIdentifierField | None = None scorer_identifier: ComponentIdentifierField | None = None scores: list[Score] = Field(default_factory=list) @@ -219,7 +217,7 @@ def copy_lineage_from(self, *, source: MessagePiece) -> None: Copy lineage metadata from ``source`` onto this piece. Lineage fields are the metadata that tie a piece back to its originating - conversation, attack, and target. Mutable containers (``labels``, + conversation. Mutable containers (``labels``, ``prompt_metadata``) are shallow-copied so that mutations on one piece do not affect others. @@ -228,8 +226,6 @@ def copy_lineage_from(self, *, source: MessagePiece) -> None: """ self.conversation_id = source.conversation_id self.labels = dict(source.labels) - self.attack_identifier = source.attack_identifier - self.prompt_target_identifier = source.prompt_target_identifier self.prompt_metadata = dict(source.prompt_metadata) def has_error(self) -> bool: diff --git a/pyrit/models/seeds/seed_group.py b/pyrit/models/seeds/seed_group.py index 3998820c38..5a37c58bd3 100644 --- a/pyrit/models/seeds/seed_group.py +++ b/pyrit/models/seeds/seed_group.py @@ -450,7 +450,6 @@ def _prompts_to_messages(self, prompts: Sequence[SeedPrompt]) -> list[Message]: role=role, original_value=prompt.value, original_value_data_type=prompt.data_type or "text", - prompt_target_identifier=None, conversation_id=str(prompt.prompt_group_id), sequence=sequence, prompt_metadata=prompt.metadata, diff --git a/pyrit/prompt_converter/llm_generic_text_converter.py b/pyrit/prompt_converter/llm_generic_text_converter.py index e1422d68e7..210e6a5ab9 100644 --- a/pyrit/prompt_converter/llm_generic_text_converter.py +++ b/pyrit/prompt_converter/llm_generic_text_converter.py @@ -157,7 +157,6 @@ async def convert_async(self, *, prompt: str, input_type: PromptDataType = "text self._converter_target.set_system_prompt( system_prompt=system_prompt, conversation_id=conversation_id, - attack_identifier=None, ) converted_prompt = prompt @@ -175,7 +174,6 @@ async def convert_async(self, *, prompt: str, input_type: PromptDataType = "text converted_value=converted_prompt, conversation_id=conversation_id, sequence=1, - prompt_target_identifier=self._converter_target.get_identifier(), original_value_data_type=input_type, converted_value_data_type=input_type, converter_identifiers=[self.get_identifier()], diff --git a/pyrit/prompt_normalizer/prompt_normalizer.py b/pyrit/prompt_normalizer/prompt_normalizer.py index 360dcde35c..c005e687ab 100644 --- a/pyrit/prompt_normalizer/prompt_normalizer.py +++ b/pyrit/prompt_normalizer/prompt_normalizer.py @@ -88,7 +88,7 @@ async def send_prompt_async( labels (dict[str, str] | None, optional): Labels associated with the request. Defaults to None. Deprecated: This parameter will be removed in a release 0.16.0. attack_identifier (ComponentIdentifier | None, optional): Identifier for the attack. Defaults to - None. + None. Deprecated: this parameter is ignored and will be removed in release 0.17.0. Returns: Message: The response received from the target. @@ -103,6 +103,12 @@ async def send_prompt_async( new_item="send_prompt_async(...)", removed_in="0.16.0", ) + if attack_identifier is not None: + print_deprecation_message( + old_item="send_prompt_async(..., attack_identifier=...)", + new_item="send_prompt_async(...)", + removed_in="0.17.0", + ) # Validates that the MessagePieces in the Message are part of the same sequence request_converter_configurations = request_converter_configurations or [] response_converter_configurations = response_converter_configurations or [] @@ -112,14 +118,12 @@ async def send_prompt_async( # Prepare the request by updating conversation ID, labels, and attack identifier request = copy.deepcopy(message) conversation_id = conversation_id if conversation_id else str(uuid4()) + target_identifier = target.get_identifier() for piece in request.message_pieces: piece.conversation_id = conversation_id if labels: piece.labels = labels # deprecated - piece.prompt_target_identifier = target.get_identifier() - if attack_identifier: - piece.attack_identifier = attack_identifier # Apply request converters await self.convert_values_async(converter_configurations=request_converter_configurations, message=request) @@ -130,10 +134,10 @@ async def send_prompt_async( try: responses = await target.send_prompt_async(message=request) - self.memory.add_message_to_memory(request=request) + self.memory.add_message_to_memory(request=request, target_identifier=target_identifier) except EmptyResponseException: # Empty responses are retried, but we don't want them to stop execution - self.memory.add_message_to_memory(request=request) + self.memory.add_message_to_memory(request=request, target_identifier=target_identifier) responses = [ construct_response_from_request( @@ -146,7 +150,7 @@ async def send_prompt_async( except Exception as ex: # Ensure request to memory before processing exception - self.memory.add_message_to_memory(request=request) + self.memory.add_message_to_memory(request=request, target_identifier=target_identifier) error_response = construct_response_from_request( request=request.message_pieces[0], @@ -156,7 +160,7 @@ async def send_prompt_async( ) await self._calc_hash_async(request=error_response) - self.memory.add_message_to_memory(request=error_response) + self.memory.add_message_to_memory(request=error_response, target_identifier=target_identifier) cid = request.message_pieces[0].conversation_id if request and request.message_pieces else None raise Exception(f"Error sending prompt with conversation ID: {cid}") from ex @@ -173,7 +177,7 @@ async def send_prompt_async( error="empty", ) await self._calc_hash_async(request=empty_response) - self.memory.add_message_to_memory(request=empty_response) + self.memory.add_message_to_memory(request=empty_response, target_identifier=target_identifier) return empty_response # Process all response messages (targets return list[Message]) @@ -186,7 +190,7 @@ async def send_prompt_async( converter_configurations=response_converter_configurations, message=resp ) await self._calc_hash_async(request=resp) - self.memory.add_message_to_memory(request=resp) + self.memory.add_message_to_memory(request=resp, target_identifier=target_identifier) # Return the last response for backward compatibility return responses[-1] @@ -209,7 +213,7 @@ async def send_prompt_batch_to_target_async( labels (dict[str, str] | None, optional): A dictionary of labels to be included with the request. Defaults to None. attack_identifier (ComponentIdentifier | None, optional): The attack identifier. - Defaults to None. + Defaults to None. Deprecated: this parameter is ignored and will be removed in release 0.17.0. batch_size (int, optional): The number of prompts to include in each batch. Defaults to 10. Returns: @@ -380,7 +384,9 @@ async def _calc_hash_async(self, request: Message) -> None: tasks = [asyncio.create_task(set_message_piece_sha256_async(piece)) for piece in request.message_pieces] await asyncio.gather(*tasks) - async def hash_and_persist_message_async(self, *, message: Message) -> None: + async def hash_and_persist_message_async( + self, *, message: Message, target_identifier: ComponentIdentifier | None = None + ) -> None: """ Hash and persist a Message to memory. @@ -389,9 +395,11 @@ async def hash_and_persist_message_async(self, *, message: Message) -> None: Args: message (Message): The message to hash and persist. + target_identifier (ComponentIdentifier | None): The target the conversation + is held with, if known. """ await self._calc_hash_async(request=message) - self.memory.add_message_to_memory(request=message) + self.memory.add_message_to_memory(request=message, target_identifier=target_identifier) async def add_prepended_conversation_to_memory_async( self, @@ -400,6 +408,7 @@ async def add_prepended_conversation_to_memory_async( converter_configurations: list[PromptConverterConfiguration] | None = None, attack_identifier: ComponentIdentifier | None = None, prepended_conversation: list[Message] | None = None, + target_identifier: ComponentIdentifier | None = None, ) -> list[Message] | None: """ Process the prepended conversation by converting it if needed and adding it to memory. @@ -409,8 +418,11 @@ async def add_prepended_conversation_to_memory_async( should_convert (bool): Whether to convert the prepended conversation converter_configurations (list[PromptConverterConfiguration] | None): Configurations for converting the request - attack_identifier (ComponentIdentifier | None): Identifier for the attack + attack_identifier (ComponentIdentifier | None): Identifier for the attack. + Deprecated: this parameter is ignored and will be removed in release 0.17.0. prepended_conversation (list[Message] | None): The conversation to prepend + target_identifier (ComponentIdentifier | None): The target the conversation is held + with, if known. Recorded once per conversation. Returns: list[Message] | None: The processed prepended conversation @@ -418,6 +430,13 @@ async def add_prepended_conversation_to_memory_async( if not prepended_conversation: return None + if attack_identifier is not None: + print_deprecation_message( + old_item="add_prepended_conversation_to_memory_async(..., attack_identifier=...)", + new_item="add_prepended_conversation_to_memory_async(...)", + removed_in="0.17.0", + ) + # Create a deep copy of the prepended conversation to avoid modifying the original prepended_conversation = copy.deepcopy(prepended_conversation) @@ -426,14 +445,12 @@ async def add_prepended_conversation_to_memory_async( await self.convert_values_async(message=request, converter_configurations=converter_configurations) for piece in request.message_pieces: piece.conversation_id = conversation_id - if attack_identifier: - piece.attack_identifier = attack_identifier # if the piece is retrieved from somewhere else, it needs to be unique # and if not, this won't hurt anything piece.id = uuid4() - self.memory.add_message_to_memory(request=request) + self.memory.add_message_to_memory(request=request, target_identifier=target_identifier) return prepended_conversation @@ -457,6 +474,7 @@ async def add_prepended_conversation_to_memory( # pyrit-async-suffix-exempt converter_configurations: list[PromptConverterConfiguration] | None = None, attack_identifier: ComponentIdentifier | None = None, prepended_conversation: list[Message] | None = None, + target_identifier: ComponentIdentifier | None = None, ) -> list[Message] | None: """ Use ``add_prepended_conversation_to_memory_async`` instead; this is a deprecated alias. @@ -475,6 +493,7 @@ async def add_prepended_conversation_to_memory( # pyrit-async-suffix-exempt converter_configurations=converter_configurations, attack_identifier=attack_identifier, prepended_conversation=prepended_conversation, + target_identifier=target_identifier, ) diff --git a/pyrit/prompt_target/common/discover_target_capabilities.py b/pyrit/prompt_target/common/discover_target_capabilities.py index 44e0313b89..ba91daee80 100644 --- a/pyrit/prompt_target/common/discover_target_capabilities.py +++ b/pyrit/prompt_target/common/discover_target_capabilities.py @@ -322,7 +322,9 @@ async def _probe_system_prompt_async(target: PromptTarget, timeout_s: float, ret prompt_metadata=_probe_metadata(), ) try: - target._memory.add_message_to_memory(request=Message(message_pieces=[system_piece])) + target._memory.add_message_to_memory( + request=Message(message_pieces=[system_piece]), target_identifier=target.get_identifier() + ) except Exception as exc: logger.debug("System-prompt probe could not seed system message: %s", exc) return False @@ -406,7 +408,9 @@ async def _probe_multi_turn_async(target: PromptTarget, timeout_s: float, retrie # Seed memory so the second send sees real prior history. try: - target._memory.add_message_to_memory(request=Message(message_pieces=[first])) + target._memory.add_message_to_memory( + request=Message(message_pieces=[first]), target_identifier=target.get_identifier() + ) assistant_reply = MessagePiece( role="assistant", original_value="Got it.", @@ -414,7 +418,9 @@ async def _probe_multi_turn_async(target: PromptTarget, timeout_s: float, retrie conversation_id=conversation_id, prompt_metadata=_probe_metadata(), ).to_message() - target._memory.add_message_to_memory(request=assistant_reply) + target._memory.add_message_to_memory( + request=assistant_reply, target_identifier=target.get_identifier() + ) except Exception as exc: logger.debug("Multi-turn probe could not seed conversation history: %s", exc) return False diff --git a/pyrit/prompt_target/common/prompt_target.py b/pyrit/prompt_target/common/prompt_target.py index 218f46d52e..3ff6416bb4 100644 --- a/pyrit/prompt_target/common/prompt_target.py +++ b/pyrit/prompt_target/common/prompt_target.py @@ -295,6 +295,7 @@ def set_system_prompt( system_prompt (str): The system prompt text to set. conversation_id (str): The conversation id to attach the prompt to. attack_identifier (ComponentIdentifier | None): Optional attack identifier. + Deprecated: this parameter is ignored and will be removed in release 0.17.0. labels (dict[str, str] | None): Optional labels. Raises: @@ -308,6 +309,13 @@ def set_system_prompt( removed_in="0.16.0", ) + if attack_identifier is not None: + print_deprecation_message( + old_item="set_system_prompt(..., attack_identifier=...)", + new_item="set_system_prompt(...)", + removed_in="0.17.0", + ) + if not self.capabilities.supports_multi_turn or not self.capabilities.supports_editable_history: raise ValueError( f"Target {type(self).__name__} does not support setting a system prompt. " @@ -325,10 +333,9 @@ def set_system_prompt( conversation_id=conversation_id, original_value=system_prompt, converted_value=system_prompt, - prompt_target_identifier=self.get_identifier(), - attack_identifier=attack_identifier, labels=labels or {}, - ).to_message() + ).to_message(), + target_identifier=self.get_identifier(), ) def dispose_db_engine(self) -> None: diff --git a/pyrit/prompt_target/openai/_openai_realtime_streaming_session.py b/pyrit/prompt_target/openai/_openai_realtime_streaming_session.py index abe1d89b06..6d7a0767e4 100644 --- a/pyrit/prompt_target/openai/_openai_realtime_streaming_session.py +++ b/pyrit/prompt_target/openai/_openai_realtime_streaming_session.py @@ -30,7 +30,6 @@ if TYPE_CHECKING: from collections.abc import AsyncIterator - from pyrit.models import ComponentIdentifier from pyrit.prompt_normalizer import PromptConverterConfiguration, PromptNormalizer from pyrit.prompt_target.common.realtime_audio import CommittedEvent @@ -127,7 +126,6 @@ def __init__( response_converter_configurations: list[PromptConverterConfiguration] | None = None, prepended_conversation: list[Message] | None = None, server_vad: bool | ServerVadConfig = True, - attack_identifier: ComponentIdentifier | None = None, persist_prepended_conversation: bool = True, ) -> None: self._target = target @@ -137,7 +135,6 @@ def __init__( self._request_converter_configurations = request_converter_configurations or [] self._response_converter_configurations = response_converter_configurations or [] self._prepended_conversation = prepended_conversation or [] - self._attack_identifier = attack_identifier self._persist_prepended_conversation = persist_prepended_conversation # Normalize server_vad once at construction so config send and commit-time trim @@ -206,6 +203,7 @@ async def run_async(self) -> AsyncIterator[Message]: conversation_id=self._conversation_id, should_convert=False, prepended_conversation=self._prepended_conversation, + target_identifier=self._target.get_identifier(), ) self._queue = asyncio.Queue() @@ -411,8 +409,6 @@ async def _handle_committed_turn_async(self, *, event: CommittedEvent, raw_pcm: converted_value=converted_user_path, converted_value_data_type="audio_path", conversation_id=self._conversation_id, - prompt_target_identifier=target_identifier, - attack_identifier=self._attack_identifier, ) for cfg in self._request_converter_configurations: user_piece.converter_identifiers.extend(converter.get_identifier() for converter in cfg.converters) @@ -423,16 +419,12 @@ async def _handle_committed_turn_async(self, *, event: CommittedEvent, raw_pcm: original_value=result.flatten_transcripts(), original_value_data_type="text", conversation_id=self._conversation_id, - prompt_target_identifier=target_identifier, - attack_identifier=self._attack_identifier, ) assistant_audio_piece = MessagePiece( role="assistant", original_value=assistant_audio_path, original_value_data_type="audio_path", conversation_id=self._conversation_id, - prompt_target_identifier=target_identifier, - attack_identifier=self._attack_identifier, ) if result.interrupted: assistant_text_piece.prompt_metadata[STREAMING_INTERRUPTED_KEY] = True @@ -445,8 +437,12 @@ async def _handle_committed_turn_async(self, *, event: CommittedEvent, raw_pcm: message=assistant_message, ) - await self._prompt_normalizer.hash_and_persist_message_async(message=user_message) - await self._prompt_normalizer.hash_and_persist_message_async(message=assistant_message) + await self._prompt_normalizer.hash_and_persist_message_async( + message=user_message, target_identifier=target_identifier + ) + await self._prompt_normalizer.hash_and_persist_message_async( + message=assistant_message, target_identifier=target_identifier + ) return assistant_message # ---- Wire helpers ------------------------------------------------------- diff --git a/pyrit/prompt_target/openai/openai_realtime_target.py b/pyrit/prompt_target/openai/openai_realtime_target.py index bb8c305f47..0f61c28c4c 100644 --- a/pyrit/prompt_target/openai/openai_realtime_target.py +++ b/pyrit/prompt_target/openai/openai_realtime_target.py @@ -149,9 +149,8 @@ def open_streaming_session( server_vad: Server-side voice activity detection. ``True`` (default) enables VAD with default tuning. Pass a ``ServerVadConfig`` for custom tuning, or ``False`` to disable (sending streaming config will then raise). - attack_identifier: Stamped on every persisted user / assistant piece for - attribution. Pass the caller's identifier so live messages share the - provenance contract of prepended messages. + attack_identifier: Deprecated. This parameter is ignored and will be removed in + release 0.17.0. persist_prepended_conversation: When ``True`` (default), the session writes ``prepended_conversation`` to memory itself. Pass ``False`` when the caller already persisted the prepended conversation (e.g. via @@ -164,6 +163,12 @@ def open_streaming_session( (but not yielded). The session owns its websocket connection + dispatcher for the duration of ``run_async``. """ + if attack_identifier is not None: + print_deprecation_message( + old_item="open_streaming_session(..., attack_identifier=...)", + new_item="open_streaming_session(...)", + removed_in="0.17.0", + ) return _OpenAIRealtimeStreamingSession( target=self, audio_chunks=audio_chunks, @@ -173,7 +178,6 @@ def open_streaming_session( response_converter_configurations=response_converter_configurations, prepended_conversation=prepended_conversation, server_vad=server_vad, - attack_identifier=attack_identifier, persist_prepended_conversation=persist_prepended_conversation, ) diff --git a/pyrit/prompt_target/openai/openai_response_target.py b/pyrit/prompt_target/openai/openai_response_target.py index 6991996105..f0a797af93 100644 --- a/pyrit/prompt_target/openai/openai_response_target.py +++ b/pyrit/prompt_target/openai/openai_response_target.py @@ -717,8 +717,6 @@ def _parse_response_output_section( original_value=piece_value, conversation_id=message_piece.conversation_id, labels=message_piece.labels, # deprecated - prompt_target_identifier=message_piece.prompt_target_identifier, - attack_identifier=message_piece.attack_identifier, original_value_data_type=piece_type, response_error=error or "none", ) @@ -825,6 +823,4 @@ def _make_tool_piece(self, output: dict[str, Any], call_id: str, *, reference_pi original_value_data_type="function_call_output", conversation_id=reference_piece.conversation_id, labels={"call_id": call_id}, # deprecated - prompt_target_identifier=reference_piece.prompt_target_identifier, - attack_identifier=reference_piece.attack_identifier, ) diff --git a/pyrit/prompt_target/playwright_copilot_target.py b/pyrit/prompt_target/playwright_copilot_target.py index 99a625cae6..48786daba6 100644 --- a/pyrit/prompt_target/playwright_copilot_target.py +++ b/pyrit/prompt_target/playwright_copilot_target.py @@ -236,8 +236,6 @@ async def _send_prompt_to_target_async(self, *, normalized_conversation: list[Me original_value=piece_data, conversation_id=request_piece.conversation_id, labels=request_piece.labels, # deprecated - prompt_target_identifier=request_piece.prompt_target_identifier, - attack_identifier=request_piece.attack_identifier, original_value_data_type=piece_type, converted_value_data_type=piece_type, prompt_metadata=request_piece.prompt_metadata, diff --git a/pyrit/prompt_target/text_target.py b/pyrit/prompt_target/text_target.py index 8e0deed295..ba69a8dcdb 100644 --- a/pyrit/prompt_target/text_target.py +++ b/pyrit/prompt_target/text_target.py @@ -86,12 +86,13 @@ def import_scores_from_csv(self, csv_file_path: Path) -> list[MessagePiece]: sequence=int(sequence_str) if sequence_str else 0, labels=labels, # deprecated response_error=row.get("response_error", None), - prompt_target_identifier=self.get_identifier(), ) message_pieces.append(message_piece) # This is post validation, so the message_pieces should be okay and normalized - self._memory.add_message_pieces_to_memory(message_pieces=message_pieces) + self._memory.add_message_pieces_to_memory( + message_pieces=message_pieces, target_identifier=self.get_identifier() + ) return message_pieces def _validate_request(self, *, normalized_conversation: list[Message]) -> None: diff --git a/pyrit/score/conversation_scorer.py b/pyrit/score/conversation_scorer.py index d921b2e1cf..9c0842ce38 100644 --- a/pyrit/score/conversation_scorer.py +++ b/pyrit/score/conversation_scorer.py @@ -102,8 +102,6 @@ async def _score_async(self, message: Message, *, objective: str | None = None) id=original_piece.id, conversation_id=original_piece.conversation_id, labels=original_piece.labels, # deprecated - prompt_target_identifier=original_piece.prompt_target_identifier, - attack_identifier=original_piece.attack_identifier, original_value_data_type="text", converted_value_data_type="text", response_error="none", diff --git a/pyrit/score/float_scale/float_scale_scorer.py b/pyrit/score/float_scale/float_scale_scorer.py index e0501b0c12..8629b92bb0 100644 --- a/pyrit/score/float_scale/float_scale_scorer.py +++ b/pyrit/score/float_scale/float_scale_scorer.py @@ -5,7 +5,7 @@ from uuid import UUID from pyrit.exceptions.exception_classes import InvalidJsonException -from pyrit.models import ComponentIdentifier, Message, PromptDataType, Score, UnvalidatedScore +from pyrit.models import Message, PromptDataType, Score, UnvalidatedScore from pyrit.prompt_target.common.prompt_target import PromptTarget from pyrit.score.scorer import Scorer from pyrit.score.scorer_prompt_validator import ScorerPromptValidator @@ -146,7 +146,6 @@ async def _score_value_with_llm_async( description_output_key: str = "description", metadata_output_key: str = "metadata", category_output_key: str = "category", - attack_identifier: ComponentIdentifier | None = None, ) -> UnvalidatedScore: score: UnvalidatedScore | None = None try: @@ -164,7 +163,6 @@ async def _score_value_with_llm_async( description_output_key=description_output_key, metadata_output_key=metadata_output_key, category_output_key=category_output_key, - attack_identifier=attack_identifier, ) if score is None: raise ValueError("Score returned None") diff --git a/pyrit/score/float_scale/insecure_code_scorer.py b/pyrit/score/float_scale/insecure_code_scorer.py index 85919178b0..f46b635110 100644 --- a/pyrit/score/float_scale/insecure_code_scorer.py +++ b/pyrit/score/float_scale/insecure_code_scorer.py @@ -94,7 +94,6 @@ async def _score_piece_async(self, message_piece: MessagePiece, *, objective: st scored_prompt_id=message_piece.id, category=self._harm_category, objective=objective, - attack_identifier=message_piece.attack_identifier, ) # Modify the UnvalidatedScore parsing to check for 'score_value' diff --git a/pyrit/score/float_scale/self_ask_general_float_scale_scorer.py b/pyrit/score/float_scale/self_ask_general_float_scale_scorer.py index 17105defb9..9631f944a9 100644 --- a/pyrit/score/float_scale/self_ask_general_float_scale_scorer.py +++ b/pyrit/score/float_scale/self_ask_general_float_scale_scorer.py @@ -148,7 +148,6 @@ async def _score_piece_async(self, message_piece: MessagePiece, *, objective: st scored_prompt_id=message_piece.id, category=self._score_category, objective=objective, - attack_identifier=message_piece.attack_identifier, score_value_output_key=self._score_value_output_key, rationale_output_key=self._rationale_output_key, description_output_key=self._description_output_key, diff --git a/pyrit/score/float_scale/self_ask_likert_scorer.py b/pyrit/score/float_scale/self_ask_likert_scorer.py index f5f2e97bcf..750a86e7c6 100644 --- a/pyrit/score/float_scale/self_ask_likert_scorer.py +++ b/pyrit/score/float_scale/self_ask_likert_scorer.py @@ -453,7 +453,6 @@ async def _score_piece_async(self, message_piece: MessagePiece, *, objective: st message_data_type=message_piece.converted_value_data_type, scored_prompt_id=message_piece.id, category=self._score_category, - attack_identifier=message_piece.attack_identifier, objective=objective, ) diff --git a/pyrit/score/float_scale/self_ask_scale_scorer.py b/pyrit/score/float_scale/self_ask_scale_scorer.py index 92db37a06a..87e8e73b51 100644 --- a/pyrit/score/float_scale/self_ask_scale_scorer.py +++ b/pyrit/score/float_scale/self_ask_scale_scorer.py @@ -138,7 +138,6 @@ async def _score_piece_async(self, message_piece: MessagePiece, *, objective: st prepended_text_message_piece=prepended_text, category=self._category, objective=objective, - attack_identifier=message_piece.attack_identifier, ) score = unvalidated_score.to_score( diff --git a/pyrit/score/scorer.py b/pyrit/score/scorer.py index f3cda9923b..5cef5744e2 100644 --- a/pyrit/score/scorer.py +++ b/pyrit/score/scorer.py @@ -349,8 +349,6 @@ def _create_text_piece_from_blocked(piece: MessagePiece) -> MessagePiece | None: labels=piece.labels, prompt_metadata=piece.prompt_metadata, converter_identifiers=list(piece.converter_identifiers), # type: ignore[arg-type] - prompt_target_identifier=piece.prompt_target_identifier, - attack_identifier=piece.attack_identifier, response_error="none", timestamp=piece.timestamp, ) @@ -676,7 +674,6 @@ async def _score_value_with_llm_async( description_output_key: str = "description", metadata_output_key: str = "metadata", category_output_key: str = "category", - attack_identifier: ComponentIdentifier | None = None, ) -> UnvalidatedScore: """ Send a request to a target, and take care of retries. @@ -710,8 +707,6 @@ async def _score_value_with_llm_async( Defaults to "metadata". category_output_key (str): The key in the JSON response that contains the category. Defaults to "category". - attack_identifier (ComponentIdentifier | None): The attack identifier. - Defaults to None. Returns: UnvalidatedScore: The score object containing the response from the target LLM. @@ -727,7 +722,6 @@ async def _score_value_with_llm_async( prompt_target.set_system_prompt( system_prompt=system_prompt, conversation_id=conversation_id, - attack_identifier=attack_identifier, ) prompt_metadata: dict[str, str | int] = {"response_format": "json"} @@ -743,7 +737,6 @@ async def _score_value_with_llm_async( original_value_data_type="text", converted_value_data_type="text", conversation_id=conversation_id, - prompt_target_identifier=prompt_target.get_identifier(), prompt_metadata=prompt_metadata, ) ) @@ -756,7 +749,6 @@ async def _score_value_with_llm_async( original_value_data_type=message_data_type, converted_value_data_type=message_data_type, conversation_id=conversation_id, - prompt_target_identifier=prompt_target.get_identifier(), prompt_metadata=prompt_metadata, ) ) diff --git a/pyrit/score/true_false/gandalf_scorer.py b/pyrit/score/true_false/gandalf_scorer.py index 5fc51fbc25..e28ad75bbc 100644 --- a/pyrit/score/true_false/gandalf_scorer.py +++ b/pyrit/score/true_false/gandalf_scorer.py @@ -128,7 +128,6 @@ async def _check_for_password_in_conversation_async(self, conversation_id: str) original_value=conversation_as_text, converted_value=conversation_as_text, conversation_id=scoring_conversation_id, - prompt_target_identifier=self._prompt_target.get_identifier(), ) ] ) diff --git a/pyrit/score/true_false/prompt_shield_scorer.py b/pyrit/score/true_false/prompt_shield_scorer.py index a320e89fa9..d6ac555610 100644 --- a/pyrit/score/true_false/prompt_shield_scorer.py +++ b/pyrit/score/true_false/prompt_shield_scorer.py @@ -76,7 +76,6 @@ async def _score_piece_async(self, message_piece: MessagePiece, *, objective: st original_value=body, prompt_metadata=message_piece.prompt_metadata, conversation_id=conversation_id, - prompt_target_identifier=self._prompt_target.get_identifier(), ) ] ) diff --git a/pyrit/score/true_false/self_ask_category_scorer.py b/pyrit/score/true_false/self_ask_category_scorer.py index cce672b642..9c526deff4 100644 --- a/pyrit/score/true_false/self_ask_category_scorer.py +++ b/pyrit/score/true_false/self_ask_category_scorer.py @@ -150,7 +150,6 @@ async def _score_piece_async(self, message_piece: MessagePiece, *, objective: st message_data_type=message_piece.converted_value_data_type, scored_prompt_id=message_piece.id, objective=objective, - attack_identifier=message_piece.attack_identifier, ) score = unvalidated_score.to_score(score_value=unvalidated_score.raw_score_value, score_type="true_false") diff --git a/pyrit/score/true_false/self_ask_general_true_false_scorer.py b/pyrit/score/true_false/self_ask_general_true_false_scorer.py index 71acd45a56..f706efbcbe 100644 --- a/pyrit/score/true_false/self_ask_general_true_false_scorer.py +++ b/pyrit/score/true_false/self_ask_general_true_false_scorer.py @@ -148,7 +148,6 @@ async def _score_piece_async(self, message_piece: MessagePiece, *, objective: st scored_prompt_id=message_piece.id, category=self._score_category, objective=objective, - attack_identifier=message_piece.attack_identifier, score_value_output_key=self._score_value_output_key, rationale_output_key=self._rationale_output_key, description_output_key=self._description_output_key, diff --git a/pyrit/score/true_false/self_ask_question_answer_scorer.py b/pyrit/score/true_false/self_ask_question_answer_scorer.py index a2f5bc078e..0d05c67b76 100644 --- a/pyrit/score/true_false/self_ask_question_answer_scorer.py +++ b/pyrit/score/true_false/self_ask_question_answer_scorer.py @@ -92,7 +92,6 @@ async def _score_piece_async(self, message_piece: MessagePiece, *, objective: st scored_prompt_id=message_piece.id, category=self._score_category, objective=objective, - attack_identifier=message_piece.attack_identifier, ) score = unvalidated_score.to_score(score_value=unvalidated_score.raw_score_value, score_type="true_false") diff --git a/pyrit/score/true_false/self_ask_refusal_scorer.py b/pyrit/score/true_false/self_ask_refusal_scorer.py index b5a5c2b80c..0ad8e598e4 100644 --- a/pyrit/score/true_false/self_ask_refusal_scorer.py +++ b/pyrit/score/true_false/self_ask_refusal_scorer.py @@ -194,7 +194,6 @@ async def _score_piece_async(self, message_piece: MessagePiece, *, objective: st scored_prompt_id=message_piece.id, category=self._score_category, objective=objective, - attack_identifier=message_piece.attack_identifier, ) score = unvalidated_score.to_score(score_value=unvalidated_score.raw_score_value, score_type="true_false") diff --git a/pyrit/score/true_false/self_ask_true_false_scorer.py b/pyrit/score/true_false/self_ask_true_false_scorer.py index 0786d0db38..d315a2e4ed 100644 --- a/pyrit/score/true_false/self_ask_true_false_scorer.py +++ b/pyrit/score/true_false/self_ask_true_false_scorer.py @@ -229,7 +229,6 @@ async def _score_piece_async(self, message_piece: MessagePiece, *, objective: st prepended_text_message_piece=prepended_text, category=self._score_category, objective=objective, - attack_identifier=message_piece.attack_identifier, ) score = unvalidated_score.to_score(score_value=unvalidated_score.raw_score_value, score_type="true_false") diff --git a/tests/unit/backend/test_attack_service.py b/tests/unit/backend/test_attack_service.py index f02367f6f8..a268d0b744 100644 --- a/tests/unit/backend/test_attack_service.py +++ b/tests/unit/backend/test_attack_service.py @@ -2175,9 +2175,14 @@ class TestAttackServiceAdditionalCoverage: async def test_create_related_conversation_uses_duplicate_branch(self, attack_service, mock_memory): """When source_conversation_id and cutoff_index are provided, duplication path is used.""" from pyrit.backend.models.attacks import CreateConversationRequest + from pyrit.models import Conversation ar = make_attack_result(conversation_id="attack-1") mock_memory.get_attack_results.return_value = [ar] + expected_target = ComponentIdentifier(class_name="TextTarget", class_module="pyrit.prompt_target") + mock_memory.get_conversation_metadata.return_value = Conversation( + conversation_id="attack-1", target_identifier=expected_target + ) with patch.object(attack_service, "_duplicate_conversation_up_to", return_value="branch-dup") as mock_dup: result = await attack_service.create_related_conversation_async( @@ -2187,7 +2192,11 @@ async def test_create_related_conversation_uses_duplicate_branch(self, attack_se assert result is not None assert result.conversation_id == "branch-dup" - mock_dup.assert_called_once_with(source_conversation_id="attack-1", cutoff_index=2) + mock_dup.assert_called_once_with( + source_conversation_id="attack-1", + cutoff_index=2, + target_identifier=expected_target, + ) async def test_add_message_merges_converter_identifiers_without_duplicates(self, attack_service, mock_memory): """Should merge new converter identifiers with existing attack identifiers by hash.""" diff --git a/tests/unit/executor/attack/component/test_conversation_manager.py b/tests/unit/executor/attack/component/test_conversation_manager.py index 47d8678b7d..d3d107732d 100644 --- a/tests/unit/executor/attack/component/test_conversation_manager.py +++ b/tests/unit/executor/attack/component/test_conversation_manager.py @@ -251,8 +251,6 @@ def test_swaps_user_to_assistant(self) -> None: result = get_adversarial_chat_messages( messages, adversarial_chat_conversation_id="adversarial_conv", - attack_identifier=ComponentIdentifier(class_name="TestAttack", class_module="test_module"), - adversarial_chat_target_identifier=_mock_target_id("adversarial_target"), ) assert len(result) == 1 @@ -267,8 +265,6 @@ def test_swaps_assistant_to_user(self) -> None: result = get_adversarial_chat_messages( messages, adversarial_chat_conversation_id="adversarial_conv", - attack_identifier=ComponentIdentifier(class_name="TestAttack", class_module="test_module"), - adversarial_chat_target_identifier=_mock_target_id("adversarial_target"), ) assert len(result) == 1 @@ -286,8 +282,6 @@ def test_swaps_simulated_assistant_to_user(self) -> None: result = get_adversarial_chat_messages( messages, adversarial_chat_conversation_id="adversarial_conv", - attack_identifier=ComponentIdentifier(class_name="TestAttack", class_module="test_module"), - adversarial_chat_target_identifier=_mock_target_id("adversarial_target"), ) assert len(result) == 1 @@ -305,8 +299,6 @@ def test_skips_system_messages(self) -> None: result = get_adversarial_chat_messages( messages, adversarial_chat_conversation_id="adversarial_conv", - attack_identifier=ComponentIdentifier(class_name="TestAttack", class_module="test_module"), - adversarial_chat_target_identifier=_mock_target_id("adversarial_target"), ) # Only user message should be present, system skipped @@ -322,8 +314,6 @@ def test_assigns_new_uuids(self) -> None: result = get_adversarial_chat_messages( messages, adversarial_chat_conversation_id="adversarial_conv", - attack_identifier=ComponentIdentifier(class_name="TestAttack", class_module="test_module"), - adversarial_chat_target_identifier=_mock_target_id("adversarial_target"), ) # New ID should be different from original @@ -344,8 +334,6 @@ def test_preserves_message_content(self) -> None: result = get_adversarial_chat_messages( messages, adversarial_chat_conversation_id="adversarial_conv", - attack_identifier=ComponentIdentifier(class_name="TestAttack", class_module="test_module"), - adversarial_chat_target_identifier=_mock_target_id("adversarial_target"), ) assert result[0].get_piece().original_value == "Original content" @@ -356,8 +344,6 @@ def test_empty_prepended_conversation(self) -> None: result = get_adversarial_chat_messages( [], adversarial_chat_conversation_id="adversarial_conv", - attack_identifier=ComponentIdentifier(class_name="TestAttack", class_module="test_module"), - adversarial_chat_target_identifier=_mock_target_id("adversarial_target"), ) assert result == [] @@ -371,8 +357,6 @@ def test_applies_labels(self) -> None: result = get_adversarial_chat_messages( messages, adversarial_chat_conversation_id="adversarial_conv", - attack_identifier=ComponentIdentifier(class_name="TestAttack", class_module="test_module"), - adversarial_chat_target_identifier=_mock_target_id("adversarial_target"), labels=labels, ) @@ -389,8 +373,6 @@ def test_labels_emit_deprecation_warning(self) -> None: get_adversarial_chat_messages( messages, adversarial_chat_conversation_id="adversarial_conv", - attack_identifier=ComponentIdentifier(class_name="TestAttack", class_module="test_module"), - adversarial_chat_target_identifier=_mock_target_id("adversarial_target"), labels={"env": "prod"}, ) @@ -499,9 +481,8 @@ class TestConversationManagerInitialization: def test_init_with_required_parameters(self, attack_identifier: ComponentIdentifier) -> None: """Test initialization with only required parameters.""" - manager = ConversationManager(attack_identifier=attack_identifier) + manager = ConversationManager() - assert manager._attack_identifier == attack_identifier assert isinstance(manager._prompt_normalizer, PromptNormalizer) assert manager._memory is not None @@ -509,7 +490,7 @@ def test_init_with_custom_prompt_normalizer( self, attack_identifier: ComponentIdentifier, mock_prompt_normalizer: MagicMock ) -> None: """Test initialization with a custom prompt normalizer.""" - manager = ConversationManager(attack_identifier=attack_identifier, prompt_normalizer=mock_prompt_normalizer) + manager = ConversationManager(prompt_normalizer=mock_prompt_normalizer) assert manager._prompt_normalizer == mock_prompt_normalizer @@ -525,7 +506,7 @@ class TestConversationRetrieval: def test_get_conversation_returns_empty_list_when_no_messages(self, attack_identifier: ComponentIdentifier) -> None: """Test get_conversation returns empty list for non-existent conversation.""" - manager = ConversationManager(attack_identifier=attack_identifier) + manager = ConversationManager() conversation_id = str(uuid.uuid4()) result = manager.get_conversation(conversation_id) @@ -536,7 +517,7 @@ def test_get_conversation_returns_messages_in_order( self, attack_identifier: ComponentIdentifier, sample_conversation: list[Message] ) -> None: """Test get_conversation returns messages in order.""" - manager = ConversationManager(attack_identifier=attack_identifier) + manager = ConversationManager() conversation_id = str(uuid.uuid4()) # Add messages to the database @@ -553,7 +534,7 @@ def test_get_conversation_returns_messages_in_order( def test_get_last_message_returns_none_for_empty_conversation(self, attack_identifier: ComponentIdentifier) -> None: """Test get_last_message returns None for empty conversation.""" - manager = ConversationManager(attack_identifier=attack_identifier) + manager = ConversationManager() conversation_id = str(uuid.uuid4()) result = manager.get_last_message(conversation_id=conversation_id) @@ -564,7 +545,7 @@ def test_get_last_message_returns_last_piece( self, attack_identifier: ComponentIdentifier, sample_conversation: list[Message] ) -> None: """Test get_last_message returns the most recent message.""" - manager = ConversationManager(attack_identifier=attack_identifier) + manager = ConversationManager() conversation_id = str(uuid.uuid4()) # Add messages to the database @@ -582,7 +563,7 @@ def test_get_last_message_with_role_filter( self, attack_identifier: ComponentIdentifier, sample_conversation: list[Message] ) -> None: """Test get_last_message with role filter returns correct message.""" - manager = ConversationManager(attack_identifier=attack_identifier) + manager = ConversationManager() conversation_id = str(uuid.uuid4()) # Add messages to the database @@ -601,7 +582,7 @@ def test_get_last_message_with_role_filter_returns_none_when_no_match( self, attack_identifier: ComponentIdentifier, sample_conversation: list[Message] ) -> None: """Test get_last_message returns None when no message matches role filter.""" - manager = ConversationManager(attack_identifier=attack_identifier) + manager = ConversationManager() conversation_id = str(uuid.uuid4()) # Add messages to the database @@ -629,7 +610,7 @@ def test_set_system_prompt_with_chat_target( self, attack_identifier: ComponentIdentifier, mock_chat_target: MagicMock ) -> None: """Test set_system_prompt calls target's set_system_prompt method.""" - manager = ConversationManager(attack_identifier=attack_identifier) + manager = ConversationManager() conversation_id = str(uuid.uuid4()) system_prompt = "You are a helpful assistant" labels = {"type": "system"} @@ -644,7 +625,6 @@ def test_set_system_prompt_with_chat_target( mock_chat_target.set_system_prompt.assert_called_once_with( system_prompt=system_prompt, conversation_id=conversation_id, - attack_identifier=attack_identifier, labels=labels, ) @@ -652,7 +632,7 @@ def test_set_system_prompt_without_labels( self, attack_identifier: ComponentIdentifier, mock_chat_target: MagicMock ) -> None: """Test set_system_prompt works without labels.""" - manager = ConversationManager(attack_identifier=attack_identifier) + manager = ConversationManager() conversation_id = str(uuid.uuid4()) system_prompt = "You are a helpful assistant" @@ -670,7 +650,7 @@ def test_set_system_prompt_labels_emit_deprecation_warning( self, attack_identifier: ComponentIdentifier, mock_chat_target: MagicMock ) -> None: """Test that passing labels emits deprecation warning.""" - manager = ConversationManager(attack_identifier=attack_identifier) + manager = ConversationManager() with patch( "pyrit.executor.attack.component.conversation_manager.print_deprecation_message" @@ -701,7 +681,7 @@ async def test_raises_error_for_empty_conversation_id( mock_attack_context: AttackContext, ) -> None: """Test that empty conversation_id raises ValueError.""" - manager = ConversationManager(attack_identifier=attack_identifier) + manager = ConversationManager() with pytest.raises(ValueError, match="conversation_id cannot be empty"): await manager.initialize_context_async( @@ -717,7 +697,7 @@ async def test_returns_default_state_for_no_prepended_conversation( mock_attack_context: AttackContext, ) -> None: """Test that no prepended conversation returns default state.""" - manager = ConversationManager(attack_identifier=attack_identifier) + manager = ConversationManager() conversation_id = str(uuid.uuid4()) state = await manager.initialize_context_async( @@ -736,7 +716,7 @@ async def test_merges_memory_labels( mock_chat_target: MagicMock, ) -> None: """Test that memory_labels are merged with context labels.""" - manager = ConversationManager(attack_identifier=attack_identifier) + manager = ConversationManager() conversation_id = str(uuid.uuid4()) context = _TestAttackContext(params=AttackParameters(objective="Test objective")) context.memory_labels = {"context_key": "context_value"} @@ -759,7 +739,7 @@ async def test_adds_prepended_conversation_to_memory_for_chat_target( sample_conversation: list[Message], ) -> None: """Test that prepended conversation is added to memory for chat targets.""" - manager = ConversationManager(attack_identifier=attack_identifier) + manager = ConversationManager() conversation_id = str(uuid.uuid4()) context = _TestAttackContext(params=AttackParameters(objective="Test objective")) context.prepended_conversation = sample_conversation @@ -781,7 +761,7 @@ async def test_converts_assistant_to_simulated_assistant( sample_assistant_piece: MessagePiece, ) -> None: """Test that assistant messages are converted to simulated_assistant.""" - manager = ConversationManager(attack_identifier=attack_identifier) + manager = ConversationManager() conversation_id = str(uuid.uuid4()) context = _TestAttackContext(params=AttackParameters(objective="Test objective")) context.prepended_conversation = [Message(message_pieces=[sample_assistant_piece])] @@ -805,7 +785,7 @@ async def test_normalizes_for_non_chat_target_by_default( sample_conversation: list[Message], ) -> None: """Test that prepended conversation is normalized for non-chat targets by default.""" - manager = ConversationManager(attack_identifier=attack_identifier) + manager = ConversationManager() conversation_id = str(uuid.uuid4()) context = _TestAttackContext(params=AttackParameters(objective="Test objective")) context.prepended_conversation = sample_conversation @@ -830,7 +810,7 @@ async def test_normalizes_for_non_chat_target_when_configured( sample_conversation: list[Message], ) -> None: """Test that non-chat target normalizes prepended conversation when configured.""" - manager = ConversationManager(attack_identifier=attack_identifier) + manager = ConversationManager() conversation_id = str(uuid.uuid4()) context = _TestAttackContext(params=AttackParameters(objective="Test objective")) context.prepended_conversation = sample_conversation @@ -858,7 +838,7 @@ async def test_returns_turn_count_for_multi_turn_attacks( sample_conversation: list[Message], ) -> None: """Test that turn count is returned for multi-turn attacks.""" - manager = ConversationManager(attack_identifier=attack_identifier) + manager = ConversationManager() conversation_id = str(uuid.uuid4()) context = _TestAttackContext(params=AttackParameters(objective="Test objective")) context.prepended_conversation = sample_conversation @@ -880,7 +860,7 @@ async def test_multipart_message_extracts_scores_from_all_pieces( sample_score: Score, ) -> None: """Test that multi-part assistant messages extract scores from all pieces.""" - manager = ConversationManager(attack_identifier=attack_identifier) + manager = ConversationManager() conversation_id = str(uuid.uuid4()) context = _TestAttackContext(params=AttackParameters(objective="Test objective")) @@ -957,7 +937,7 @@ async def test_prepended_conversation_ignores_true_scores( would incorrectly indicate the objective was already achieved. Only false scores are extracted to provide feedback rationale for continued attack attempts. """ - manager = ConversationManager(attack_identifier=attack_identifier) + manager = ConversationManager() conversation_id = str(uuid.uuid4()) context = _TestAttackContext(params=AttackParameters(objective="Test objective")) @@ -1056,7 +1036,7 @@ async def test_non_chat_target_behavior_normalize_is_default( sample_conversation: list[Message], ) -> None: """Test that non-chat targets normalize by default (no config), matching dataclass field default.""" - manager = ConversationManager(attack_identifier=attack_identifier) + manager = ConversationManager() conversation_id = str(uuid.uuid4()) context = _TestAttackContext(params=AttackParameters(objective="Test objective")) context.prepended_conversation = sample_conversation @@ -1081,7 +1061,7 @@ async def test_non_chat_target_behavior_raise_explicit( sample_conversation: list[Message], ) -> None: """Test that non_chat_target_behavior='raise' raises ValueError.""" - manager = ConversationManager(attack_identifier=attack_identifier) + manager = ConversationManager() conversation_id = str(uuid.uuid4()) context = _TestAttackContext(params=AttackParameters(objective="Test objective")) context.prepended_conversation = sample_conversation @@ -1108,7 +1088,7 @@ async def test_non_chat_target_behavior_normalize_first_turn_creates_next_messag sample_conversation: list[Message], ) -> None: """Test that normalize_first_turn creates next_message when none exists.""" - manager = ConversationManager(attack_identifier=attack_identifier) + manager = ConversationManager() conversation_id = str(uuid.uuid4()) context = _TestAttackContext(params=AttackParameters(objective="Test objective")) context.prepended_conversation = sample_conversation @@ -1135,7 +1115,7 @@ async def test_non_chat_target_behavior_normalize_first_turn_prepends_to_existin sample_conversation: list[Message], ) -> None: """Test that normalize_first_turn prepends context to existing next_message.""" - manager = ConversationManager(attack_identifier=attack_identifier) + manager = ConversationManager() conversation_id = str(uuid.uuid4()) context = _TestAttackContext(params=AttackParameters(objective="Test objective")) context.prepended_conversation = sample_conversation @@ -1164,7 +1144,7 @@ async def test_non_chat_target_behavior_normalize_returns_empty_state( sample_conversation: list[Message], ) -> None: """Test that normalize_first_turn returns empty ConversationState (no turn tracking).""" - manager = ConversationManager(attack_identifier=attack_identifier) + manager = ConversationManager() conversation_id = str(uuid.uuid4()) context = _TestAttackContext(params=AttackParameters(objective="Test objective")) context.prepended_conversation = sample_conversation @@ -1195,7 +1175,7 @@ async def test_apply_converters_to_roles_default_applies_to_all( """Test that converters are applied to all roles by default.""" mock_normalizer = MagicMock(spec=PromptNormalizer) mock_normalizer.convert_values_async = AsyncMock() - manager = ConversationManager(attack_identifier=attack_identifier, prompt_normalizer=mock_normalizer) + manager = ConversationManager(prompt_normalizer=mock_normalizer) conversation_id = str(uuid.uuid4()) context = _TestAttackContext(params=AttackParameters(objective="Test objective")) context.prepended_conversation = sample_conversation @@ -1221,7 +1201,7 @@ async def test_apply_converters_to_roles_user_only( """Test that converters are applied only to user role when configured.""" mock_normalizer = MagicMock(spec=PromptNormalizer) mock_normalizer.convert_values_async = AsyncMock() - manager = ConversationManager(attack_identifier=attack_identifier, prompt_normalizer=mock_normalizer) + manager = ConversationManager(prompt_normalizer=mock_normalizer) conversation_id = str(uuid.uuid4()) context = _TestAttackContext(params=AttackParameters(objective="Test objective")) context.prepended_conversation = sample_conversation @@ -1249,7 +1229,7 @@ async def test_apply_converters_to_roles_assistant_only( """Test that converters are applied only to assistant role when configured.""" mock_normalizer = MagicMock(spec=PromptNormalizer) mock_normalizer.convert_values_async = AsyncMock() - manager = ConversationManager(attack_identifier=attack_identifier, prompt_normalizer=mock_normalizer) + manager = ConversationManager(prompt_normalizer=mock_normalizer) conversation_id = str(uuid.uuid4()) context = _TestAttackContext(params=AttackParameters(objective="Test objective")) context.prepended_conversation = sample_conversation @@ -1277,7 +1257,7 @@ async def test_apply_converters_to_roles_empty_list_skips_all( """Test that empty roles list means no converters applied to any role.""" mock_normalizer = MagicMock(spec=PromptNormalizer) mock_normalizer.convert_values_async = AsyncMock() - manager = ConversationManager(attack_identifier=attack_identifier, prompt_normalizer=mock_normalizer) + manager = ConversationManager(prompt_normalizer=mock_normalizer) conversation_id = str(uuid.uuid4()) context = _TestAttackContext(params=AttackParameters(objective="Test objective")) context.prepended_conversation = sample_conversation @@ -1307,7 +1287,7 @@ async def test_message_normalizer_default_uses_conversation_context_normalizer( sample_conversation: list[Message], ) -> None: """Test that default normalizer produces Turn N format.""" - manager = ConversationManager(attack_identifier=attack_identifier) + manager = ConversationManager() conversation_id = str(uuid.uuid4()) context = _TestAttackContext(params=AttackParameters(objective="Test objective")) context.prepended_conversation = sample_conversation @@ -1340,7 +1320,7 @@ async def test_message_normalizer_custom_normalizer_is_used( mock_normalizer = MagicMock(spec=MessageStringNormalizer) mock_normalizer.normalize_string_async = AsyncMock(return_value="CUSTOM_FORMAT: test content") - manager = ConversationManager(attack_identifier=attack_identifier) + manager = ConversationManager() conversation_id = str(uuid.uuid4()) context = _TestAttackContext(params=AttackParameters(objective="Test objective")) context.prepended_conversation = sample_conversation @@ -1423,7 +1403,7 @@ async def test_chat_target_ignores_non_chat_target_behavior( sample_conversation: list[Message], ) -> None: """Test that chat targets ignore non_chat_target_behavior setting.""" - manager = ConversationManager(attack_identifier=attack_identifier) + manager = ConversationManager() conversation_id = str(uuid.uuid4()) context = _TestAttackContext(params=AttackParameters(objective="Test objective")) context.prepended_conversation = sample_conversation @@ -1454,7 +1434,7 @@ async def test_config_with_max_turns_validation( mock_chat_target: MagicMock, ) -> None: """Test that config works correctly with max_turns validation.""" - manager = ConversationManager(attack_identifier=attack_identifier) + manager = ConversationManager() conversation_id = str(uuid.uuid4()) context = _TestAttackContext(params=AttackParameters(objective="Test objective")) @@ -1503,7 +1483,7 @@ async def test_adds_messages_to_memory( sample_conversation: list[Message], ) -> None: """Test that messages are added to memory.""" - manager = ConversationManager(attack_identifier=attack_identifier) + manager = ConversationManager() conversation_id = str(uuid.uuid4()) turn_count = await manager.add_prepended_conversation_to_memory_async( @@ -1521,7 +1501,7 @@ async def test_assigns_conversation_id_to_all_pieces( sample_conversation: list[Message], ) -> None: """Test that conversation_id is assigned to all message pieces.""" - manager = ConversationManager(attack_identifier=attack_identifier) + manager = ConversationManager() conversation_id = str(uuid.uuid4()) await manager.add_prepended_conversation_to_memory_async( @@ -1534,25 +1514,6 @@ async def test_assigns_conversation_id_to_all_pieces( for piece in msg.message_pieces: assert piece.conversation_id == conversation_id - async def test_assigns_attack_identifier_to_all_pieces( - self, - attack_identifier: ComponentIdentifier, - sample_conversation: list[Message], - ) -> None: - """Test that attack_identifier is assigned to all message pieces.""" - manager = ConversationManager(attack_identifier=attack_identifier) - conversation_id = str(uuid.uuid4()) - - await manager.add_prepended_conversation_to_memory_async( - prepended_conversation=sample_conversation, - conversation_id=conversation_id, - ) - - stored = manager.get_conversation(conversation_id) - for msg in stored: - for piece in msg.message_pieces: - assert piece.attack_identifier == attack_identifier - async def test_raises_error_when_exceeds_max_turns( self, attack_identifier: ComponentIdentifier, @@ -1560,7 +1521,7 @@ async def test_raises_error_when_exceeds_max_turns( sample_assistant_piece: MessagePiece, ) -> None: """Test that exceeding max_turns raises ValueError.""" - manager = ConversationManager(attack_identifier=attack_identifier) + manager = ConversationManager() conversation_id = str(uuid.uuid4()) # Create conversation with 2 assistant messages @@ -1583,7 +1544,7 @@ async def test_multipart_response_counts_as_one_turn( attack_identifier: ComponentIdentifier, ) -> None: """Test that a multi-part assistant response counts as only one turn.""" - manager = ConversationManager(attack_identifier=attack_identifier) + manager = ConversationManager() conversation_id = str(uuid.uuid4()) piece_conversation_id = str(uuid.uuid4()) @@ -1621,7 +1582,7 @@ async def test_returns_zero_for_empty_conversation( attack_identifier: ComponentIdentifier, ) -> None: """Test that empty conversation returns 0 turns.""" - manager = ConversationManager(attack_identifier=attack_identifier) + manager = ConversationManager() conversation_id = str(uuid.uuid4()) turn_count = await manager.add_prepended_conversation_to_memory_async( @@ -1638,7 +1599,7 @@ async def test_applies_converters_when_provided( sample_user_piece: MessagePiece, ) -> None: """Test that converters are applied when provided.""" - manager = ConversationManager(attack_identifier=attack_identifier, prompt_normalizer=mock_prompt_normalizer) + manager = ConversationManager(prompt_normalizer=mock_prompt_normalizer) conversation_id = str(uuid.uuid4()) conversation = [Message(message_pieces=[sample_user_piece])] converter_config = [PromptConverterConfiguration(converters=[])] @@ -1657,7 +1618,7 @@ async def test_handles_none_messages_gracefully( attack_identifier: ComponentIdentifier, ) -> None: """Test that None messages are handled gracefully.""" - manager = ConversationManager(attack_identifier=attack_identifier) + manager = ConversationManager() conversation_id = str(uuid.uuid4()) turn_count = await manager.add_prepended_conversation_to_memory_async( @@ -1684,7 +1645,7 @@ async def test_preserves_piece_metadata( sample_user_piece: MessagePiece, ) -> None: """Test that piece metadata is preserved during processing.""" - manager = ConversationManager(attack_identifier=attack_identifier) + manager = ConversationManager() conversation_id = str(uuid.uuid4()) # Add metadata to piece @@ -1712,7 +1673,7 @@ async def test_preserves_original_and_converted_values( sample_user_piece: MessagePiece, ) -> None: """Test that original and converted values are preserved.""" - manager = ConversationManager(attack_identifier=attack_identifier) + manager = ConversationManager() conversation_id = str(uuid.uuid4()) sample_user_piece.original_value = "Original message" @@ -1740,7 +1701,7 @@ async def test_handles_system_messages_in_prepended_conversation( sample_user_piece: MessagePiece, ) -> None: """Test that system messages are handled in prepended conversation.""" - manager = ConversationManager(attack_identifier=attack_identifier) + manager = ConversationManager() conversation_id = str(uuid.uuid4()) context = _TestAttackContext(params=AttackParameters(objective="Test objective")) context.prepended_conversation = [ diff --git a/tests/unit/executor/attack/multi_turn/test_supports_multi_turn_attacks.py b/tests/unit/executor/attack/multi_turn/test_supports_multi_turn_attacks.py index ae7ccaafaa..111c9e00bf 100644 --- a/tests/unit/executor/attack/multi_turn/test_supports_multi_turn_attacks.py +++ b/tests/unit/executor/attack/multi_turn/test_supports_multi_turn_attacks.py @@ -28,7 +28,7 @@ def _make_strategy(*, supports_multi_turn: bool): target = MagicMock() target.capabilities.supports_multi_turn = supports_multi_turn target.configuration.includes.return_value = supports_multi_turn - target.get_identifier.return_value = MagicMock() + target.get_identifier.return_value = {"__type__": "MockTarget", "__module__": "test", "id": "mock-id"} with patch.multiple( MultiTurnAttackStrategy, @@ -378,13 +378,13 @@ def _make_tap_node(self, *, supports_multi_turn: bool): target = MagicMock() target.capabilities.supports_multi_turn = supports_multi_turn target.configuration.includes.return_value = supports_multi_turn - target.get_identifier.return_value = MagicMock() + target.get_identifier.return_value = {"__type__": "MockTarget", "__module__": "test", "id": "mock-id"} adversarial_chat = MagicMock() - adversarial_chat.get_identifier.return_value = MagicMock() + adversarial_chat.get_identifier.return_value = {"__type__": "MockTarget", "__module__": "test", "id": "mock-id"} scorer = MagicMock() - scorer.get_identifier.return_value = MagicMock() + scorer.get_identifier.return_value = {"__type__": "MockTarget", "__module__": "test", "id": "mock-id"} seed = MagicMock() seed.render_template_value.return_value = "template" @@ -694,14 +694,14 @@ def _make_single_turn_target(self): target.configuration = TargetConfiguration( capabilities=TargetCapabilities(supports_multi_turn=False, supports_system_prompt=True), ) - target.get_identifier.return_value = MagicMock() + target.get_identifier.return_value = {"__type__": "MockTarget", "__module__": "test", "id": "mock-id"} return target def _make_adversarial_config(self): from pyrit.executor.attack.core.attack_config import AttackAdversarialConfig adversarial_chat = MagicMock() - adversarial_chat.get_identifier.return_value = MagicMock() + adversarial_chat.get_identifier.return_value = {"__type__": "MockTarget", "__module__": "test", "id": "mock-id"} return AttackAdversarialConfig(target=adversarial_chat) def _make_scoring_config(self): @@ -709,7 +709,7 @@ def _make_scoring_config(self): from pyrit.score import TrueFalseScorer scorer = MagicMock(spec=TrueFalseScorer) - scorer.get_identifier.return_value = MagicMock() + scorer.get_identifier.return_value = {"__type__": "MockTarget", "__module__": "test", "id": "mock-id"} return AttackScoringConfig(objective_scorer=scorer) async def test_crescendo_raises_for_single_turn_target(self): @@ -752,13 +752,13 @@ def _make_tap_node(self, *, supports_multi_turn: bool): target = MagicMock() target.capabilities.supports_multi_turn = supports_multi_turn target.configuration.includes.return_value = supports_multi_turn - target.get_identifier.return_value = MagicMock() + target.get_identifier.return_value = {"__type__": "MockTarget", "__module__": "test", "id": "mock-id"} adversarial_chat = MagicMock() - adversarial_chat.get_identifier.return_value = MagicMock() + adversarial_chat.get_identifier.return_value = {"__type__": "MockTarget", "__module__": "test", "id": "mock-id"} scorer = MagicMock() - scorer.get_identifier.return_value = MagicMock() + scorer.get_identifier.return_value = {"__type__": "MockTarget", "__module__": "test", "id": "mock-id"} seed = MagicMock() seed.render_template_value.return_value = "template" diff --git a/tests/unit/executor/attack/single_turn/test_context_compliance.py b/tests/unit/executor/attack/single_turn/test_context_compliance.py index b10ff8e640..5164050b6f 100644 --- a/tests/unit/executor/attack/single_turn/test_context_compliance.py +++ b/tests/unit/executor/attack/single_turn/test_context_compliance.py @@ -568,7 +568,6 @@ async def test_get_objective_as_benign_question_async( call_args = mock_prompt_normalizer.send_prompt_async.call_args assert call_args.kwargs["target"] == attack._adversarial_chat - assert call_args.kwargs["attack_identifier"] == attack.get_identifier() assert call_args.kwargs["labels"] == basic_context.memory_labels # Verify message was created correctly (converted from seed group) @@ -616,7 +615,6 @@ async def test_get_benign_question_answer_async( call_args = mock_prompt_normalizer.send_prompt_async.call_args assert call_args.kwargs["target"] == attack._adversarial_chat - assert call_args.kwargs["attack_identifier"] == attack.get_identifier() assert call_args.kwargs["labels"] == basic_context.memory_labels # Verify template was rendered with benign request @@ -657,7 +655,6 @@ async def test_get_objective_as_question_async( call_args = mock_prompt_normalizer.send_prompt_async.call_args assert call_args.kwargs["target"] == attack._adversarial_chat - assert call_args.kwargs["attack_identifier"] == attack.get_identifier() assert call_args.kwargs["labels"] == basic_context.memory_labels # Verify template was rendered diff --git a/tests/unit/executor/attack/single_turn/test_prompt_sending.py b/tests/unit/executor/attack/single_turn/test_prompt_sending.py index 7e23fcbda6..bf9d61a627 100644 --- a/tests/unit/executor/attack/single_turn/test_prompt_sending.py +++ b/tests/unit/executor/attack/single_turn/test_prompt_sending.py @@ -418,7 +418,6 @@ async def test_send_prompt_to_target_with_all_configurations( assert call_args.kwargs["request_converter_configurations"] == request_converters assert call_args.kwargs["response_converter_configurations"] == response_converters assert call_args.kwargs["labels"] == {"test": "label"} - assert "attack_identifier" in call_args.kwargs async def test_send_prompt_handles_none_response(self, mock_target, mock_prompt_normalizer, basic_context): attack = PromptSendingAttack(objective_target=mock_target, prompt_normalizer=mock_prompt_normalizer) diff --git a/tests/unit/executor/attack/streaming/test_barge_in.py b/tests/unit/executor/attack/streaming/test_barge_in.py index fee1d181d9..7a4bc289fe 100644 --- a/tests/unit/executor/attack/streaming/test_barge_in.py +++ b/tests/unit/executor/attack/streaming/test_barge_in.py @@ -225,7 +225,6 @@ async def test_perform_async_opens_session_with_expected_kwargs(vad_target): assert kwargs["request_converter_configurations"] == attack._request_converters assert kwargs["response_converter_configurations"] == attack._response_converters assert kwargs["prepended_conversation"] == ctx.prepended_conversation - assert kwargs["attack_identifier"] == attack.get_identifier() assert kwargs["persist_prepended_conversation"] is False diff --git a/tests/unit/executor/promptgen/fuzzer/test_fuzzer_converter.py b/tests/unit/executor/promptgen/fuzzer/test_fuzzer_converter.py index 4e633b8a9f..e8466493da 100644 --- a/tests/unit/executor/promptgen/fuzzer/test_fuzzer_converter.py +++ b/tests/unit/executor/promptgen/fuzzer/test_fuzzer_converter.py @@ -14,7 +14,7 @@ FuzzerShortenConverter, FuzzerSimilarConverter, ) -from pyrit.models import ComponentIdentifier, Message, MessagePiece +from pyrit.models import Message, MessagePiece @pytest.mark.parametrize( @@ -89,8 +89,6 @@ async def test_converter_send_prompt_async_bad_json_exception_retries( converted_value=converted_value, original_value_data_type="text", converted_value_data_type="text", - prompt_target_identifier=ComponentIdentifier(class_name="target-identifier", class_module="test"), - attack_identifier=ComponentIdentifier(class_name="test", class_module="test"), labels={"test": "test"}, ) ] diff --git a/tests/unit/memory/memory_interface/test_batching_scale.py b/tests/unit/memory/memory_interface/test_batching_scale.py index 239a86d474..65c3805877 100644 --- a/tests/unit/memory/memory_interface/test_batching_scale.py +++ b/tests/unit/memory/memory_interface/test_batching_scale.py @@ -36,7 +36,6 @@ def _create_message_piece( converted_value_sha256=sha256, sequence=0, conversation_id=conversation_id or str(uuid.uuid4()), - attack_identifier=ComponentIdentifier.from_dict({"id": str(uuid.uuid4())}), ) diff --git a/tests/unit/memory/memory_interface/test_interface_export.py b/tests/unit/memory/memory_interface/test_interface_export.py index 34252b7547..aafb83fb12 100644 --- a/tests/unit/memory/memory_interface/test_interface_export.py +++ b/tests/unit/memory/memory_interface/test_interface_export.py @@ -19,7 +19,7 @@ def test_export_conversation_by_attack_id_file_created( sqlite_instance: MemoryInterface, sample_conversations: Sequence[MessagePiece] ): - attack1_id = sample_conversations[0].attack_identifier.hash + attack1_id = "attack-1" # Default path in export_conversations() file_name = f"{attack1_id}.json" diff --git a/tests/unit/memory/memory_interface/test_interface_prompts.py b/tests/unit/memory/memory_interface/test_interface_prompts.py index f1261b9597..ad5cd1ddb6 100644 --- a/tests/unit/memory/memory_interface/test_interface_prompts.py +++ b/tests/unit/memory/memory_interface/test_interface_prompts.py @@ -14,6 +14,7 @@ from pyrit.executor.attack.single_turn.prompt_sending import PromptSendingAttack from pyrit.memory import MemoryInterface, PromptMemoryEntry from pyrit.models import ( + AttackResult, ComponentIdentifier, IdentifierFilter, IdentifierType, @@ -21,6 +22,7 @@ MessagePiece, Score, SeedPrompt, + build_atomic_attack_identifier, ) @@ -135,7 +137,6 @@ def test_duplicate_memory(sqlite_instance: MemoryInterface): converted_value="Hello, how are you?", conversation_id=conversation_id_1, sequence=0, - attack_identifier=attack1.get_identifier(), ), MessagePiece( role="assistant", @@ -143,14 +144,12 @@ def test_duplicate_memory(sqlite_instance: MemoryInterface): converted_value="I'm fine, thank you!", conversation_id=conversation_id_1, sequence=1, - attack_identifier=attack1.get_identifier(), ), MessagePiece( role="assistant", original_value="original prompt text", converted_value="I'm fine, thank you!", conversation_id=conversation_id_3, - attack_identifier=attack2.get_identifier(), ), MessagePiece( role="user", @@ -158,7 +157,6 @@ def test_duplicate_memory(sqlite_instance: MemoryInterface): converted_value="Hello, how are you?", conversation_id=conversation_id_2, sequence=0, - attack_identifier=attack1.get_identifier(), ), MessagePiece( role="assistant", @@ -166,7 +164,6 @@ def test_duplicate_memory(sqlite_instance: MemoryInterface): converted_value="I'm fine, thank you!", conversation_id=conversation_id_2, sequence=1, - attack_identifier=attack1.get_identifier(), ), ] sqlite_instance.add_message_pieces_to_memory(message_pieces=pieces) @@ -179,28 +176,6 @@ def test_duplicate_memory(sqlite_instance: MemoryInterface): ) all_pieces = sqlite_instance.get_message_pieces() assert len(all_pieces) == 9 - # Attack IDs are preserved (not changed) when duplicating - assert all(p.attack_identifier is not None for p in all_pieces) - assert ( - len( - [ - p - for p in all_pieces - if p.attack_identifier is not None and p.attack_identifier.hash == attack1.get_identifier().hash - ] - ) - == 8 - ) - assert ( - len( - [ - p - for p in all_pieces - if p.attack_identifier is not None and p.attack_identifier.hash == attack2.get_identifier().hash - ] - ) - == 1 - ) assert len([p for p in all_pieces if p.conversation_id == conversation_id_1]) == 2 assert len([p for p in all_pieces if p.conversation_id == conversation_id_2]) == 2 assert len([p for p in all_pieces if p.conversation_id == conversation_id_3]) == 1 @@ -223,7 +198,6 @@ def test_duplicate_conversation_pieces_not_score(sqlite_instance: MemoryInterfac converted_value="Hello, how are you?", conversation_id=conversation_id, sequence=0, - attack_identifier=attack1.get_identifier(), labels=memory_labels, ), MessagePiece( @@ -233,7 +207,6 @@ def test_duplicate_conversation_pieces_not_score(sqlite_instance: MemoryInterfac converted_value="I'm fine, thank you!", conversation_id=conversation_id, sequence=0, - attack_identifier=attack1.get_identifier(), labels=memory_labels, ), ] @@ -276,8 +249,6 @@ def test_duplicate_conversation_pieces_not_score(sqlite_instance: MemoryInterfac for piece in new_pieces: assert piece.id not in (prompt_id_1, prompt_id_2) assert len(sqlite_instance.get_prompt_scores(labels=memory_labels)) == 2 - # Attack ID is preserved, so both original and duplicated pieces have the same attack ID - assert len(sqlite_instance.get_prompt_scores(attack_id=attack1.get_identifier().hash)) == 2 # The duplicate prompts ids should not have scores so only two scores are returned assert len(sqlite_instance.get_prompt_scores(prompt_ids=[str(prompt_id_1), str(prompt_id_2)] + new_pieces_ids)) == 2 @@ -294,14 +265,12 @@ def test_duplicate_conversation_excluding_last_turn(sqlite_instance: MemoryInter original_value="original prompt text", conversation_id=conversation_id_1, sequence=0, - attack_identifier=attack1.get_identifier(), ), MessagePiece( role="assistant", original_value="original prompt text", conversation_id=conversation_id_1, sequence=1, - attack_identifier=attack1.get_identifier(), ), MessagePiece( role="user", @@ -309,7 +278,6 @@ def test_duplicate_conversation_excluding_last_turn(sqlite_instance: MemoryInter converted_value="I'm fine, thank you!", sequence=2, conversation_id=conversation_id_1, - attack_identifier=attack2.get_identifier(), ), MessagePiece( role="user", @@ -317,7 +285,6 @@ def test_duplicate_conversation_excluding_last_turn(sqlite_instance: MemoryInter converted_value="Hello, how are you?", conversation_id=conversation_id_2, sequence=2, - attack_identifier=attack2.get_identifier(), ), MessagePiece( role="assistant", @@ -325,7 +292,6 @@ def test_duplicate_conversation_excluding_last_turn(sqlite_instance: MemoryInter converted_value="I'm fine, thank you!", conversation_id=conversation_id_2, sequence=3, - attack_identifier=attack1.get_identifier(), ), ] sqlite_instance.add_message_pieces_to_memory(message_pieces=pieces) @@ -359,7 +325,6 @@ def test_duplicate_conversation_excluding_last_turn_not_score(sqlite_instance: M converted_value="Hello, how are you?", conversation_id=conversation_id, sequence=0, - attack_identifier=attack1.get_identifier(), labels=memory_labels, ), MessagePiece( @@ -369,7 +334,6 @@ def test_duplicate_conversation_excluding_last_turn_not_score(sqlite_instance: M converted_value="I'm fine, thank you!", conversation_id=conversation_id, sequence=1, - attack_identifier=attack1.get_identifier(), labels=memory_labels, ), MessagePiece( @@ -378,7 +342,6 @@ def test_duplicate_conversation_excluding_last_turn_not_score(sqlite_instance: M converted_value="That's good.", conversation_id=conversation_id, sequence=2, - attack_identifier=attack1.get_identifier(), labels=memory_labels, ), MessagePiece( @@ -387,7 +350,6 @@ def test_duplicate_conversation_excluding_last_turn_not_score(sqlite_instance: M converted_value="Thanks.", conversation_id=conversation_id, sequence=3, - attack_identifier=attack1.get_identifier(), labels=memory_labels, ), ] @@ -430,8 +392,6 @@ def test_duplicate_conversation_excluding_last_turn_not_score(sqlite_instance: M assert new_pieces[0].id != prompt_id_1 assert new_pieces[1].id != prompt_id_2 assert len(sqlite_instance.get_prompt_scores(labels=memory_labels)) == 2 - # Attack ID is preserved - assert len(sqlite_instance.get_prompt_scores(attack_id=attack1.get_identifier().hash)) == 2 # The duplicate prompts ids should not have scores so only two scores are returned assert len(sqlite_instance.get_prompt_scores(prompt_ids=[str(prompt_id_1), str(prompt_id_2)] + new_pieces_ids)) == 2 @@ -445,28 +405,24 @@ def test_duplicate_conversation_excluding_last_turn_same_attack(sqlite_instance: original_value="original prompt text", conversation_id=conversation_id_1, sequence=0, - attack_identifier=attack1.get_identifier(), ), MessagePiece( role="assistant", original_value="original prompt text", conversation_id=conversation_id_1, sequence=1, - attack_identifier=attack1.get_identifier(), ), MessagePiece( role="user", original_value="original prompt text", conversation_id=conversation_id_1, sequence=2, - attack_identifier=attack1.get_identifier(), ), MessagePiece( role="assistant", original_value="original prompt text", conversation_id=conversation_id_1, sequence=3, - attack_identifier=attack1.get_identifier(), ), ] sqlite_instance.add_message_pieces_to_memory(message_pieces=pieces) @@ -486,39 +442,6 @@ def test_duplicate_conversation_excluding_last_turn_same_attack(sqlite_instance: assert piece.sequence < 2 -def test_duplicate_memory_preserves_attack_id(sqlite_instance: MemoryInterface): - attack1 = PromptSendingAttack(objective_target=get_mock_target()) - conversation_id = "11111" - pieces = [ - MessagePiece( - role="user", - original_value="original prompt text", - converted_value="Hello, how are you?", - conversation_id=conversation_id, - sequence=0, - attack_identifier=attack1.get_identifier(), - ), - ] - sqlite_instance.add_message_pieces_to_memory(message_pieces=pieces) - assert len(sqlite_instance.get_message_pieces()) == 1 - - # Duplicating preserves the attack ID - new_conversation_id = sqlite_instance.duplicate_conversation( - conversation_id=conversation_id, - ) - - # Verify duplication succeeded - all_pieces = sqlite_instance.get_message_pieces() - assert len(all_pieces) == 2 - assert new_conversation_id != conversation_id - - # Both pieces should have the same attack ID - assert all(p.attack_identifier is not None for p in all_pieces) - attack_ids = {p.attack_identifier.hash for p in all_pieces if p.attack_identifier is not None} - assert len(attack_ids) == 1 - assert attack1.get_identifier().hash in attack_ids - - def test_duplicate_conversation_creates_new_ids(sqlite_instance: MemoryInterface): """Test that duplicated conversation has new piece IDs.""" attack1 = PromptSendingAttack(objective_target=get_mock_target()) @@ -529,7 +452,6 @@ def test_duplicate_conversation_creates_new_ids(sqlite_instance: MemoryInterface converted_value="Hello", conversation_id=conversation_id, sequence=1, - attack_identifier=attack1.get_identifier(), ) sqlite_instance.add_message_pieces_to_memory(message_pieces=[original_piece]) @@ -560,7 +482,6 @@ def test_duplicate_conversation_preserves_original_prompt_id(sqlite_instance: Me original_value="traceable prompt", conversation_id=conversation_id, sequence=1, - attack_identifier=attack1.get_identifier(), ) sqlite_instance.add_message_pieces_to_memory(message_pieces=[original_piece]) original_prompt_id = original_piece.original_prompt_id @@ -586,21 +507,18 @@ def test_duplicate_conversation_with_multiple_pieces(sqlite_instance: MemoryInte original_value="user message 1", conversation_id=conversation_id, sequence=1, - attack_identifier=attack1.get_identifier(), ), MessagePiece( role="assistant", original_value="assistant response 1", conversation_id=conversation_id, sequence=2, - attack_identifier=attack1.get_identifier(), ), MessagePiece( role="user", original_value="user message 2", conversation_id=conversation_id, sequence=3, - attack_identifier=attack1.get_identifier(), ), ] sqlite_instance.add_message_pieces_to_memory(message_pieces=pieces) @@ -932,31 +850,29 @@ def test_get_message_pieces_attack(sqlite_instance: MemoryInterface): attack1 = PromptSendingAttack(objective_target=get_mock_target()) attack2 = PromptSendingAttack(objective_target=get_mock_target("Target2")) - entries = [ - PromptMemoryEntry( - entry=MessagePiece( - role="user", - original_value="Hello 1", - attack_identifier=attack1.get_identifier(), - ) - ), - PromptMemoryEntry( - entry=MessagePiece( - role="assistant", - original_value="Hello 2", - attack_identifier=attack2.get_identifier(), - ) - ), - PromptMemoryEntry( - entry=MessagePiece( - role="user", - original_value="Hello 3", - attack_identifier=attack1.get_identifier(), - ) - ), + pieces = [ + MessagePiece(role="user", original_value="Hello 1", conversation_id="c1", sequence=0), + MessagePiece(role="assistant", original_value="Hello 2", conversation_id="c2", sequence=0), + MessagePiece(role="user", original_value="Hello 3", conversation_id="c1", sequence=1), ] + sqlite_instance.add_message_pieces_to_memory(message_pieces=pieces) - sqlite_instance._insert_entries(entries=entries) + # attack_identifier is no longer stamped on pieces; the deprecated attack_id filter + # resolves to an attack's main conversation via persisted AttackResults. + sqlite_instance.add_attack_results_to_memory( + attack_results=[ + AttackResult( + conversation_id="c1", + objective="objective 1", + atomic_attack_identifier=build_atomic_attack_identifier(attack_identifier=attack1.get_identifier()), + ), + AttackResult( + conversation_id="c2", + objective="objective 2", + atomic_attack_identifier=build_atomic_attack_identifier(attack_identifier=attack2.get_identifier()), + ), + ] + ) attack1_entries = sqlite_instance.get_message_pieces(attack_id=attack1.get_identifier().hash) @@ -1115,7 +1031,6 @@ def test_get_message_pieces_with_non_matching_memory_labels(sqlite_instance: Mem role="user", original_value="Hello 3", converted_value="Hello 1", - attack_identifier=attack.get_identifier(), ) ), ] @@ -1371,53 +1286,21 @@ def test_get_request_from_response_raises_error_for_sequence_less_than_one(sqlit def test_get_message_pieces_by_attack_identifier_filter(sqlite_instance: MemoryInterface): attack1 = PromptSendingAttack(objective_target=get_mock_target()) - attack2 = PromptSendingAttack(objective_target=get_mock_target("Target2")) - - entries = [ - PromptMemoryEntry( - entry=MessagePiece( - role="user", - original_value="Hello 1", - attack_identifier=attack1.get_identifier(), - ) - ), - PromptMemoryEntry( - entry=MessagePiece( - role="assistant", - original_value="Hello 2", - attack_identifier=attack2.get_identifier(), - ) - ), - ] - - sqlite_instance._insert_entries(entries=entries) - # Filter by exact attack hash - results = sqlite_instance.get_message_pieces( - identifier_filters=[ - IdentifierFilter( - identifier_type=IdentifierType.ATTACK, - property_path="$.hash", - value=attack1.get_identifier().hash, - partial_match=False, - ) - ], - ) - assert len(results) == 1 - assert results[0].original_value == "Hello 1" - - # No match - results = sqlite_instance.get_message_pieces( - identifier_filters=[ - IdentifierFilter( - identifier_type=IdentifierType.ATTACK, - property_path="$.hash", - value="nonexistent_hash", - partial_match=False, - ) - ], - ) - assert len(results) == 0 + # IdentifierType.ATTACK is no longer stamped on message pieces, so the piece-level + # identifier filter rejects it. Attack filtering now goes through get_attack_results + # or the deprecated attack_id parameter. + with pytest.raises(ValueError, match="does not support identifier type"): + sqlite_instance.get_message_pieces( + identifier_filters=[ + IdentifierFilter( + identifier_type=IdentifierType.ATTACK, + property_path="$.hash", + value=attack1.get_identifier().hash, + partial_match=False, + ) + ], + ) def test_get_message_pieces_by_target_identifier_filter(sqlite_instance: MemoryInterface): @@ -1432,24 +1315,26 @@ def test_get_message_pieces_by_target_identifier_filter(sqlite_instance: MemoryI params={"endpoint": "https://azure.com", "model_name": "gpt-3.5"}, ) - entries = [ - PromptMemoryEntry( - entry=MessagePiece( + sqlite_instance.add_message_pieces_to_memory( + message_pieces=[ + MessagePiece( role="user", original_value="Hello OpenAI", - prompt_target_identifier=target_id_1, - ) - ), - PromptMemoryEntry( - entry=MessagePiece( + conversation_id="conv-openai", + ), + ], + target_identifier=target_id_1, + ) + sqlite_instance.add_message_pieces_to_memory( + message_pieces=[ + MessagePiece( role="user", original_value="Hello Azure", - prompt_target_identifier=target_id_2, - ) - ), - ] - - sqlite_instance._insert_entries(entries=entries) + conversation_id="conv-azure", + ), + ], + target_identifier=target_id_2, + ) # Filter by target hash results = sqlite_instance.get_message_pieces( diff --git a/tests/unit/memory/memory_interface/test_interface_scores.py b/tests/unit/memory/memory_interface/test_interface_scores.py index 7f786c260c..ead87d6666 100644 --- a/tests/unit/memory/memory_interface/test_interface_scores.py +++ b/tests/unit/memory/memory_interface/test_interface_scores.py @@ -8,17 +8,17 @@ from uuid import uuid4 import pytest -from unit.mocks import get_mock_target -from pyrit.executor.attack.single_turn.prompt_sending import PromptSendingAttack from pyrit.memory import MemoryInterface, PromptMemoryEntry from pyrit.models import ( + AttackResult, ComponentIdentifier, IdentifierFilter, IdentifierType, MessagePiece, Score, SeedPrompt, + build_atomic_attack_identifier, ) @@ -41,6 +41,19 @@ def test_get_scores_by_attack_id_and_label( sqlite_instance.add_message_pieces_to_memory(message_pieces=sample_conversations) + # attack_identifier is no longer stamped on pieces; the deprecated attack_id filter + # resolves to an attack's main conversation via persisted AttackResults. + attack_strategy_id = ComponentIdentifier(class_name="TestAttack", class_module="test.module") + sqlite_instance.add_attack_results_to_memory( + attack_results=[ + AttackResult( + conversation_id=sample_conversations[0].conversation_id, + objective="test objective", + atomic_attack_identifier=build_atomic_attack_identifier(attack_identifier=attack_strategy_id), + ) + ] + ) + score = Score( score_value=str(0.8), score_value_description="High score", @@ -55,8 +68,7 @@ def test_get_scores_by_attack_id_and_label( sqlite_instance.add_scores_to_memory(scores=[score]) # Fetch the score we just added - assert sample_conversations[0].attack_identifier is not None - db_score = sqlite_instance.get_prompt_scores(attack_id=sample_conversations[0].attack_identifier.hash) + db_score = sqlite_instance.get_prompt_scores(attack_id=attack_strategy_id.hash) assert len(db_score) == 1 assert db_score[0].score_value == score.score_value @@ -76,9 +88,8 @@ def test_get_scores_by_attack_id_and_label( assert len(db_score) == 1 assert db_score[0].score_value == score.score_value - assert sample_conversations[0].attack_identifier is not None db_score = sqlite_instance.get_prompt_scores( - attack_id=sample_conversations[0].attack_identifier.hash, + attack_id=attack_strategy_id.hash, labels={"x": "y"}, ) assert len(db_score) == 0 @@ -161,7 +172,6 @@ def test_get_prompt_scores_empty_prompt_ids_returns_empty(sqlite_instance: Memor def test_add_score_duplicate_prompt(sqlite_instance: MemoryInterface): # Ensure that scores of duplicate prompts are linked back to the original original_id = uuid4() - attack = PromptSendingAttack(objective_target=get_mock_target()) conversation_id = str(uuid4()) pieces = [ MessagePiece( @@ -171,12 +181,11 @@ def test_add_score_duplicate_prompt(sqlite_instance: MemoryInterface): converted_value="Hello, how are you?", conversation_id=conversation_id, sequence=0, - attack_identifier=attack.get_identifier(), ) ] sqlite_instance.add_message_pieces_to_memory(message_pieces=pieces) sqlite_instance.duplicate_conversation(conversation_id=conversation_id) - # Get the duplicated piece (it will have a different conversation_id but same attack_id) + # Get the duplicated piece (it will have a different conversation_id) all_pieces = sqlite_instance.get_message_pieces() dupe_piece = [p for p in all_pieces if p.id != original_id][0] dupe_id = dupe_piece.id diff --git a/tests/unit/memory/test_azure_sql_memory.py b/tests/unit/memory/test_azure_sql_memory.py index b8ee5bb6dc..ddcd509718 100644 --- a/tests/unit/memory/test_azure_sql_memory.py +++ b/tests/unit/memory/test_azure_sql_memory.py @@ -4,7 +4,6 @@ import os import uuid from collections.abc import Generator, MutableSequence, Sequence -from datetime import timezone from typing import TYPE_CHECKING from unittest.mock import MagicMock, patch @@ -197,49 +196,43 @@ def test_get_memories_with_json_properties(memory_interface: AzureSQLMemory): converter_identifiers = [Base64Converter().get_identifier()] target = TextTarget() - # Start a session - with memory_interface.get_session() as session: # type: ignore[arg-type] - # Create a ConversationData entry with all attributes filled - entry = PromptMemoryEntry( - entry=MessagePiece( - conversation_id=specific_conversation_id, - role="user", - sequence=1, - original_value="Test content", - converted_value="Test content", - labels={"normalizer_id": "id1"}, - converter_identifiers=converter_identifiers, - prompt_target_identifier=target.get_identifier(), - ) - ) + piece = MessagePiece( + conversation_id=specific_conversation_id, + role="user", + sequence=1, + original_value="Test content", + converted_value="Test content", + labels={"normalizer_id": "id1"}, + converter_identifiers=converter_identifiers, + ) + + memory_interface.add_message_pieces_to_memory( + message_pieces=[piece], target_identifier=target.get_identifier() + ) - # Insert the ConversationData entry - session.add(entry) - session.commit() - - # Use the get_memories_with_conversation_id method to retrieve entries with the specific conversation_id - retrieved_entries = memory_interface.get_conversation(conversation_id=specific_conversation_id) - - # Verify that the retrieved entry matches the inserted entry - assert len(retrieved_entries) == 1 - retrieved_entry = retrieved_entries[0].message_pieces[0] - assert retrieved_entry.conversation_id == specific_conversation_id - assert retrieved_entry.api_role == "user" - assert retrieved_entry.original_value == "Test content" - # For timestamp, you might want to check if it's close to the current time instead of an exact match - assert ( - abs((retrieved_entry.timestamp - entry.timestamp.replace(tzinfo=timezone.utc)).total_seconds()) < 10 - ) # Assuming the test runs quickly - - converter_identifiers = retrieved_entry.converter_identifiers - assert len(converter_identifiers) == 1 - assert converter_identifiers[0].class_name == "Base64Converter" - - prompt_target = retrieved_entry.prompt_target_identifier - assert prompt_target.class_name == "TextTarget" - - labels = retrieved_entry.labels - assert labels["normalizer_id"] == "id1" + # Use the get_memories_with_conversation_id method to retrieve entries with the specific conversation_id + retrieved_entries = memory_interface.get_conversation(conversation_id=specific_conversation_id) + + # Verify that the retrieved entry matches the inserted entry + assert len(retrieved_entries) == 1 + retrieved_entry = retrieved_entries[0].message_pieces[0] + assert retrieved_entry.conversation_id == specific_conversation_id + assert retrieved_entry.api_role == "user" + assert retrieved_entry.original_value == "Test content" + # For timestamp, you might want to check if it's close to the current time instead of an exact match + assert abs((retrieved_entry.timestamp - piece.timestamp).total_seconds()) < 10 # Assuming the test runs quickly + + converter_identifiers = retrieved_entry.converter_identifiers + assert len(converter_identifiers) == 1 + assert converter_identifiers[0].class_name == "Base64Converter" + + # The target identifier is conversation-scoped and stored in the Conversations table. + metadata = memory_interface.get_conversation_metadata(conversation_id=specific_conversation_id) + assert metadata is not None + assert metadata.target_identifier.class_name == "TextTarget" + + labels = retrieved_entry.labels + assert labels["normalizer_id"] == "id1" def test_get_memories_with_attack_id(memory_interface: AzureSQLMemory): diff --git a/tests/unit/memory/test_memory_models.py b/tests/unit/memory/test_memory_models.py index 3f3b7ac990..ca6912d749 100644 --- a/tests/unit/memory/test_memory_models.py +++ b/tests/unit/memory/test_memory_models.py @@ -52,8 +52,6 @@ def _make_message_piece(**overrides) -> MessagePiece: "labels": {"label1": "value1"}, "prompt_metadata": {"meta": "data"}, "converter_identifiers": [ComponentIdentifier(class_name="NoOp", class_module="pyrit.converters")], - "prompt_target_identifier": ComponentIdentifier(class_name="MockTarget", class_module="tests.mocks"), - "attack_identifier": ComponentIdentifier(class_name="MockAttack", class_module="tests.mocks"), "original_value_data_type": "text", "converted_value_data_type": "text", "response_error": "none", @@ -224,16 +222,6 @@ def test_init_stores_converter_identifiers_as_dicts(self): assert isinstance(entry.converter_identifiers, list) assert isinstance(entry.converter_identifiers[0], dict) - def test_init_with_no_attack_identifier(self): - piece = _make_message_piece(attack_identifier=None) - entry = PromptMemoryEntry(entry=piece) - assert entry.attack_identifier == {} - - def test_init_with_no_target_identifier(self): - piece = _make_message_piece(prompt_target_identifier=None) - entry = PromptMemoryEntry(entry=piece) - assert entry.prompt_target_identifier == {} - def test_roundtrip_get_message_piece(self): piece = _make_message_piece() entry = PromptMemoryEntry(entry=piece) @@ -245,17 +233,10 @@ def test_roundtrip_get_message_piece(self): assert recovered.conversation_id == piece.conversation_id assert isinstance(recovered.converter_identifiers[0], ComponentIdentifier) - def test_str_with_target_identifier(self): + def test_str_renders_role_and_value(self): piece = _make_message_piece() entry = PromptMemoryEntry(entry=piece) s = str(entry) - assert "MockTarget" in s - assert "user" in s - - def test_str_without_target_identifier(self): - piece = _make_message_piece(prompt_target_identifier=None) - entry = PromptMemoryEntry(entry=piece) - s = str(entry) assert "user" in s diff --git a/tests/unit/memory/test_sqlite_memory.py b/tests/unit/memory/test_sqlite_memory.py index 3eddf11133..c22f6a0bca 100644 --- a/tests/unit/memory/test_sqlite_memory.py +++ b/tests/unit/memory/test_sqlite_memory.py @@ -7,7 +7,6 @@ import tempfile import uuid from collections.abc import Sequence -from datetime import timezone from unittest.mock import MagicMock import pytest @@ -58,7 +57,6 @@ def test_conversation_data_schema(sqlite_instance): "labels", "prompt_metadata", "converter_identifiers", - "prompt_target_identifier", "original_value_data_type", "original_value", "original_value_sha256", @@ -97,8 +95,6 @@ def test_conversation_data_column_types(sqlite_instance): "labels": (String, JSON), "prompt_metadata": (String, JSON), "converter_identifiers": (String, JSON), - "prompt_target_identifier": (String, JSON), - "attack_identifier": (String, JSON), "response_error": String, "original_value_data_type": String, "original_value": String, @@ -522,47 +518,74 @@ def test_get_memories_with_json_properties(sqlite_instance): converter_identifiers = [Base64Converter().get_identifier()] target = TextTarget() - # Start a session - with sqlite_instance.get_session() as session: - # Create a ConversationData entry with all attributes filled - piece = MessagePiece( - conversation_id=specific_conversation_id, - role="user", - sequence=1, - original_value="Test content", - converted_value="Test content", - labels={"normalizer_id": "id1"}, - converter_identifiers=converter_identifiers, - prompt_target_identifier=target.get_identifier(), - ) - entry = PromptMemoryEntry(entry=piece) + piece = MessagePiece( + conversation_id=specific_conversation_id, + role="user", + sequence=1, + original_value="Test content", + converted_value="Test content", + labels={"normalizer_id": "id1"}, + converter_identifiers=converter_identifiers, + ) - # Insert the ConversationData entry - session.add(entry) - session.commit() + sqlite_instance.add_message_pieces_to_memory( + message_pieces=[piece], target_identifier=target.get_identifier() + ) + + # Use the get_memories_with_conversation_id method to retrieve entries with the specific conversation_id + retrieved_entries = sqlite_instance.get_conversation(conversation_id=specific_conversation_id) + + # Verify that the retrieved entry matches the inserted entry + assert len(retrieved_entries) == 1 + retrieved_entry = retrieved_entries[0].message_pieces[0] + assert retrieved_entry.conversation_id == specific_conversation_id + assert retrieved_entry.api_role == "user" + assert retrieved_entry.original_value == "Test content" + # For timestamp, you might want to check if it's close to the current time instead of an exact match + assert abs((retrieved_entry.timestamp - piece.timestamp).total_seconds()) < 0.1 - # Use the get_memories_with_conversation_id method to retrieve entries with the specific conversation_id - retrieved_entries = sqlite_instance.get_conversation(conversation_id=specific_conversation_id) + converter_identifiers = retrieved_entry.converter_identifiers + assert len(converter_identifiers) == 1 + assert converter_identifiers[0].class_name == "Base64Converter" - # Verify that the retrieved entry matches the inserted entry - assert len(retrieved_entries) == 1 - retrieved_entry = retrieved_entries[0].message_pieces[0] - assert retrieved_entry.conversation_id == specific_conversation_id - assert retrieved_entry.api_role == "user" - assert retrieved_entry.original_value == "Test content" - # For timestamp, you might want to check if it's close to the current time instead of an exact match - assert abs((retrieved_entry.timestamp - piece.timestamp).total_seconds()) < 0.1 - assert abs((retrieved_entry.timestamp - entry.timestamp.replace(tzinfo=timezone.utc)).total_seconds()) < 0.1 + # The target identifier is conversation-scoped and stored in the Conversations table. + metadata = sqlite_instance.get_conversation_metadata(conversation_id=specific_conversation_id) + assert metadata is not None + assert metadata.target_identifier.class_name == "TextTarget" - converter_identifiers = retrieved_entry.converter_identifiers - assert len(converter_identifiers) == 1 - assert converter_identifiers[0].class_name == "Base64Converter" + labels = retrieved_entry.labels + assert labels["normalizer_id"] == "id1" - prompt_target = retrieved_entry.prompt_target_identifier - assert prompt_target.class_name == "TextTarget" - labels = retrieved_entry.labels - assert labels["normalizer_id"] == "id1" +def test_capture_conversation_none_target_does_not_clobber(sqlite_instance): + # A conversation is held with a single target. The request piece records the + # target; a later write for the same conversation that has no target (e.g. a + # response or branched copy) must NOT overwrite the recorded target with None. + conversation_id = "conv-none-clobber" + target = TextTarget() + + request_piece = MessagePiece( + conversation_id=conversation_id, + role="user", + sequence=1, + original_value="hello", + ) + sqlite_instance.add_message_pieces_to_memory( + message_pieces=[request_piece], target_identifier=target.get_identifier() + ) + + response_piece = MessagePiece( + conversation_id=conversation_id, + role="assistant", + sequence=2, + original_value="world", + ) + sqlite_instance.add_message_pieces_to_memory(message_pieces=[response_piece], target_identifier=None) + + metadata = sqlite_instance.get_conversation_metadata(conversation_id=conversation_id) + assert metadata is not None + assert metadata.target_identifier is not None + assert metadata.target_identifier.class_name == "TextTarget" def test_update_entries(sqlite_instance): diff --git a/tests/unit/mocks.py b/tests/unit/mocks.py index dbd1a8a4d4..2616bd1fd5 100644 --- a/tests/unit/mocks.py +++ b/tests/unit/mocks.py @@ -150,7 +150,6 @@ def set_system_prompt( original_value=system_prompt, converted_value=system_prompt, conversation_id=conversation_id, - attack_identifier=attack_identifier, labels=labels or {}, ).to_message() ) @@ -165,7 +164,6 @@ async def _send_prompt_to_target_async(self, *, normalized_conversation: list[Me role="assistant", original_value="default", conversation_id=message.message_pieces[0].conversation_id, - attack_identifier=message.message_pieces[0].attack_identifier, labels=message.message_pieces[0].labels, ).to_message() ] @@ -259,7 +257,6 @@ def get_test_message_piece() -> MessagePiece: def get_sample_conversations() -> MutableSequence[Message]: with patch.object(CentralMemory, "get_memory_instance", return_value=MagicMock()): conversation_1 = str(uuid.uuid4()) - attack_id = get_mock_attack_identifier() return [ MessagePiece( @@ -268,7 +265,6 @@ def get_sample_conversations() -> MutableSequence[Message]: converted_value="Hello, how are you?", conversation_id=conversation_1, sequence=0, - attack_identifier=attack_id, ).to_message(), MessagePiece( role="assistant", @@ -276,14 +272,12 @@ def get_sample_conversations() -> MutableSequence[Message]: converted_value="I'm fine, thank you!", conversation_id=conversation_1, sequence=1, - attack_identifier=attack_id, ).to_message(), MessagePiece( role="assistant", original_value="original prompt text", converted_value="I'm fine, thank you!", conversation_id=str(uuid.uuid4()), - attack_identifier=attack_id, ).to_message(), ] diff --git a/tests/unit/models/test_attack_result.py b/tests/unit/models/test_attack_result.py index ea50d4de7e..1b64fb4378 100644 --- a/tests/unit/models/test_attack_result.py +++ b/tests/unit/models/test_attack_result.py @@ -370,11 +370,6 @@ def test_to_dict_from_dict_roundtrip(): class_name="SelfAskTrueFalseScorer", class_module="pyrit.score", ) - target_id = ComponentIdentifier( - class_name="OpenAIChatTarget", - class_module="pyrit.prompt_target", - params={"endpoint": "https://api.example.com"}, - ) attack_id = ComponentIdentifier( class_name="PromptSendingAttack", class_module="pyrit.executor.attack", @@ -386,8 +381,6 @@ def test_to_dict_from_dict_roundtrip(): conversation_id="conv-1", sequence=1, timestamp=datetime(2026, 1, 15, 12, 0, 0, tzinfo=timezone.utc), - prompt_target_identifier=target_id, - attack_identifier=attack_id, ) last_score = Score( score_value="true", diff --git a/tests/unit/models/test_message_piece.py b/tests/unit/models/test_message_piece.py index e8d4457d81..1809257d35 100644 --- a/tests/unit/models/test_message_piece.py +++ b/tests/unit/models/test_message_piece.py @@ -10,9 +10,8 @@ from unittest.mock import patch import pytest -from unit.mocks import MockPromptTarget, get_mock_target, get_sample_conversations +from unit.mocks import get_sample_conversations -from pyrit.executor.attack import PromptSendingAttack from pyrit.models import ( ComponentIdentifier, Message, @@ -70,34 +69,6 @@ def test_converters_serialize(): assert converter.class_module == "pyrit.prompt_converter.base64_converter" -def test_prompt_targets_serialize(patch_central_database): - target = MockPromptTarget() - entry = MessagePiece( - role="user", - original_value="Hello", - converted_value="Hello", - prompt_target_identifier=target.get_identifier(), - ) - assert patch_central_database.called - assert entry.prompt_target_identifier.class_name == "MockPromptTarget" - assert entry.prompt_target_identifier.class_module == "unit.mocks" - - -def test_executors_serialize(): - attack = PromptSendingAttack(objective_target=get_mock_target()) - - entry = MessagePiece( - role="user", - original_value="Hello", - converted_value="Hello", - attack_identifier=attack.get_identifier(), - ) - - assert entry.attack_identifier.hash is not None - assert entry.attack_identifier.class_name == "PromptSendingAttack" - assert entry.attack_identifier.class_module == "pyrit.executor.attack.single_turn.prompt_sending" - - async def test_hashes_generated(): entry = MessagePiece( role="user", @@ -689,14 +660,6 @@ def test_message_piece_to_dict(): params={"supported_input_types": ["text"], "supported_output_types": ["text"]}, ) ], - prompt_target_identifier=ComponentIdentifier( - class_name="MockPromptTarget", - class_module="unit.mocks", - ), - attack_identifier=ComponentIdentifier( - class_name="PromptSendingAttack", - class_module="pyrit.executor.attack.single_turn.prompt_sending_attack", - ), scorer_identifier=ComponentIdentifier( class_name="TestScorer", class_module="pyrit.score.test_scorer", @@ -739,8 +702,6 @@ def test_message_piece_to_dict(): "targeted_harm_categories", "prompt_metadata", "converter_identifiers", - "prompt_target_identifier", - "attack_identifier", "scorer_identifier", "original_value_data_type", "original_value", @@ -767,8 +728,6 @@ def test_message_piece_to_dict(): assert result["targeted_harm_categories"] == entry.targeted_harm_categories assert result["prompt_metadata"] == entry.prompt_metadata assert result["converter_identifiers"] == [conv.to_dict() for conv in entry.converter_identifiers] - assert result["prompt_target_identifier"] == entry.prompt_target_identifier.to_dict() - assert result["attack_identifier"] == entry.attack_identifier.to_dict() assert result["scorer_identifier"] == entry.scorer_identifier.to_dict() assert result["original_value_data_type"] == entry.original_value_data_type assert result["original_value"] == entry.original_value @@ -1093,8 +1052,6 @@ def test_to_dict_from_dict_roundtrip(): timestamp=datetime(2026, 1, 15, 12, 0, 0, tzinfo=timezone.utc), prompt_metadata={"doc_type": "text"}, converter_identifiers=[converter_id], - prompt_target_identifier=target_id, - attack_identifier=attack_id, original_value_data_type="text", converted_value_data_type="text", response_error="none", @@ -1141,8 +1098,6 @@ def _make_piece(self, **overrides) -> MessagePiece: def test_copies_lineage_fields_from_source_to_target(self) -> None: source = self._make_piece( conversation_id="conv-A", - attack_identifier={"__type__": "Attack", "__module__": "x", "id": "atk-1"}, - prompt_target_identifier={"__type__": "Target", "__module__": "x", "id": "tgt-1"}, ) source.prompt_metadata = {"k": "v"} @@ -1151,8 +1106,6 @@ def test_copies_lineage_fields_from_source_to_target(self) -> None: target.copy_lineage_from(source=source) assert target.conversation_id == "conv-A" - assert target.attack_identifier == source.attack_identifier - assert target.prompt_target_identifier == source.prompt_target_identifier assert target.prompt_metadata == {"k": "v"} def test_labels_and_metadata_are_shallow_copied(self) -> None: @@ -1223,8 +1176,6 @@ def test_to_dict_golden_shape(self) -> None: "targeted_harm_categories", "prompt_metadata", "converter_identifiers", - "prompt_target_identifier", - "attack_identifier", "scorer_identifier", "scores", ] @@ -1238,8 +1189,6 @@ def test_to_dict_golden_shape(self) -> None: assert d["targeted_harm_categories"] == [] assert d["prompt_metadata"] == {} assert d["converter_identifiers"] == [] - assert d["prompt_target_identifier"] is None - assert d["attack_identifier"] is None assert d["scorer_identifier"] is None assert d["original_value_data_type"] == "text" assert d["original_value"] == "hello" diff --git a/tests/unit/prompt_converter/test_persuasion_converter.py b/tests/unit/prompt_converter/test_persuasion_converter.py index 197e5564ef..256eac967a 100644 --- a/tests/unit/prompt_converter/test_persuasion_converter.py +++ b/tests/unit/prompt_converter/test_persuasion_converter.py @@ -7,7 +7,7 @@ from unit.mocks import MockPromptTarget from pyrit.exceptions.exception_classes import InvalidJsonException -from pyrit.models import ComponentIdentifier, Message, MessagePiece +from pyrit.models import Message, MessagePiece from pyrit.prompt_converter import PersuasionConverter @@ -72,8 +72,6 @@ async def test_persuasion_converter_send_prompt_async_bad_json_exception_retries converted_value=converted_value, original_value_data_type="text", converted_value_data_type="text", - prompt_target_identifier=ComponentIdentifier(class_name="target-identifier", class_module="test"), - attack_identifier=ComponentIdentifier(class_name="test", class_module="test"), labels={"test": "test"}, ) ] @@ -101,7 +99,6 @@ async def test_persuasion_converter_extracts_mutated_text(sqlite_instance): conversation_id="test-id", original_value='{"mutated_text": "rephrased prompt"}', original_value_data_type="text", - prompt_target_identifier=ComponentIdentifier(class_name="test", class_module="test"), sequence=1, ) ] @@ -123,7 +120,6 @@ async def test_persuasion_converter_missing_mutated_text_raises_invalid_json(sql conversation_id="test-id", original_value='{"other_key": "value"}', original_value_data_type="text", - prompt_target_identifier=ComponentIdentifier(class_name="test", class_module="test"), sequence=1, ) ] @@ -164,7 +160,6 @@ async def test_send_persuasion_prompt_async_emits_deprecation_warning_and_delega conversation_id="conv-1", original_value="test input", original_value_data_type="text", - prompt_target_identifier=ComponentIdentifier(class_name="test", class_module="test"), ) ] ) diff --git a/tests/unit/prompt_converter/test_translation_converter.py b/tests/unit/prompt_converter/test_translation_converter.py index 9073dea9ac..c5e1aa2702 100644 --- a/tests/unit/prompt_converter/test_translation_converter.py +++ b/tests/unit/prompt_converter/test_translation_converter.py @@ -7,7 +7,7 @@ import pytest from unit.mocks import MockPromptTarget -from pyrit.models import ComponentIdentifier, Message, MessagePiece +from pyrit.models import Message, MessagePiece from pyrit.prompt_converter import TranslationConverter @@ -40,7 +40,6 @@ async def test_translation_converter_returns_stripped_response(sqlite_instance): conversation_id="test-id", original_value=" hola \n", original_value_data_type="text", - prompt_target_identifier=ComponentIdentifier(class_name="test", class_module="test"), sequence=1, ) ] @@ -72,7 +71,6 @@ async def test_translation_converter_user_prompt_byte_for_byte_equivalent(sqlite conversation_id="test-id", original_value="hola", original_value_data_type="text", - prompt_target_identifier=ComponentIdentifier(class_name="test", class_module="test"), sequence=1, ) ] @@ -119,7 +117,6 @@ async def test_translation_converter_succeeds_after_retries(sqlite_instance): converted_value="hola", original_value_data_type="text", converted_value_data_type="text", - prompt_target_identifier=ComponentIdentifier(class_name="test-identifier", class_module="test"), sequence=1, ) ] diff --git a/tests/unit/prompt_converter/test_variation_converter.py b/tests/unit/prompt_converter/test_variation_converter.py index 542fccf0c1..e62d880f9d 100644 --- a/tests/unit/prompt_converter/test_variation_converter.py +++ b/tests/unit/prompt_converter/test_variation_converter.py @@ -7,7 +7,7 @@ from unit.mocks import MockPromptTarget from pyrit.exceptions.exception_classes import InvalidJsonException -from pyrit.models import ComponentIdentifier, Message, MessagePiece +from pyrit.models import Message, MessagePiece from pyrit.prompt_converter import VariationConverter @@ -44,8 +44,6 @@ async def test_variation_converter_send_prompt_async_bad_json_exception_retries( converted_value=converted_value, original_value_data_type="text", converted_value_data_type="text", - prompt_target_identifier=ComponentIdentifier(class_name="target-identifier", class_module="test"), - attack_identifier=ComponentIdentifier(class_name="test", class_module="test"), labels={"test": "test"}, ) ] @@ -72,7 +70,6 @@ async def test_variation_converter_extracts_first_element_from_json_list(sqlite_ conversation_id="test-id", original_value='["first variation", "second variation"]', original_value_data_type="text", - prompt_target_identifier=ComponentIdentifier(class_name="test", class_module="test"), sequence=1, ) ] @@ -93,7 +90,6 @@ async def test_variation_converter_preserves_original_and_converted_values(sqlit conversation_id="test-id", original_value='["variation"]', original_value_data_type="text", - prompt_target_identifier=ComponentIdentifier(class_name="test", class_module="test"), sequence=1, ) ] @@ -128,7 +124,6 @@ async def test_send_variation_prompt_async_emits_deprecation_warning_and_delegat conversation_id="conv-1", original_value="test input", original_value_data_type="text", - prompt_target_identifier=ComponentIdentifier(class_name="test", class_module="test"), ) ] ) diff --git a/tests/unit/prompt_normalizer/test_prompt_normalizer.py b/tests/unit/prompt_normalizer/test_prompt_normalizer.py index decc056441..55a4a4f818 100644 --- a/tests/unit/prompt_normalizer/test_prompt_normalizer.py +++ b/tests/unit/prompt_normalizer/test_prompt_normalizer.py @@ -614,7 +614,6 @@ def test_memory_property_raises_when_memory_none(): async def test_add_prepended_conversation_to_memory(mock_memory_instance): normalizer = PromptNormalizer() conv_id = "test-conv-id" - attack_id = get_mock_attack_identifier() piece = MessagePiece(role="user", original_value="prepended text", conversation_id="old-id") message = Message(message_pieces=[piece]) @@ -622,14 +621,12 @@ async def test_add_prepended_conversation_to_memory(mock_memory_instance): result = await normalizer.add_prepended_conversation_to_memory_async( conversation_id=conv_id, should_convert=False, - attack_identifier=attack_id, prepended_conversation=[message], ) assert result is not None assert len(result) == 1 assert result[0].message_pieces[0].conversation_id == conv_id - assert result[0].message_pieces[0].attack_identifier == attack_id mock_memory_instance.add_message_to_memory.assert_called_once() @@ -847,4 +844,5 @@ async def test_add_prepended_conversation_to_memory_emits_deprecation_warning_an converter_configurations=None, attack_identifier=None, prepended_conversation=None, + target_identifier=None, ) diff --git a/tests/unit/prompt_target/target/test_http_target.py b/tests/unit/prompt_target/target/test_http_target.py index e31fa005af..aef915d8b2 100644 --- a/tests/unit/prompt_target/target/test_http_target.py +++ b/tests/unit/prompt_target/target/test_http_target.py @@ -70,7 +70,6 @@ async def test_send_prompt_async(mock_request, mock_http_target, mock_http_respo MagicMock( converted_value="test_prompt", converted_value_data_type="text", - prompt_target_identifier=None, attack_identifier=None, conversation_id="", labels={}, @@ -125,7 +124,6 @@ async def test_send_prompt_async_client_kwargs(patch_central_database): MagicMock( converted_value="", converted_value_data_type="text", - prompt_target_identifier=None, attack_identifier=None, conversation_id="", labels={}, @@ -164,7 +162,6 @@ async def test_send_prompt_regex_parse_async(mock_request, mock_http_target): MagicMock( converted_value="test_prompt", converted_value_data_type="text", - prompt_target_identifier=None, attack_identifier=None, conversation_id="", labels={}, @@ -200,7 +197,6 @@ async def test_send_prompt_async_keeps_original_template(mock_request, mock_http MagicMock( converted_value="test_prompt", converted_value_data_type="text", - prompt_target_identifier=None, attack_identifier=None, conversation_id="", labels={}, @@ -228,7 +224,6 @@ async def test_send_prompt_async_keeps_original_template(mock_request, mock_http MagicMock( converted_value="second_test_prompt", converted_value_data_type="text", - prompt_target_identifier=None, attack_identifier=None, conversation_id="", labels={}, @@ -285,7 +280,6 @@ async def test_http_target_with_injected_client(patch_central_database): MagicMock( converted_value="test_prompt", converted_value_data_type="text", - prompt_target_identifier=None, attack_identifier=None, conversation_id="", labels={}, diff --git a/tests/unit/prompt_target/target/test_normalize_async_integration.py b/tests/unit/prompt_target/target/test_normalize_async_integration.py index 2bd58d18a0..3b14a24169 100644 --- a/tests/unit/prompt_target/target/test_normalize_async_integration.py +++ b/tests/unit/prompt_target/target/test_normalize_async_integration.py @@ -16,7 +16,7 @@ from pyrit.memory.memory_interface import MemoryInterface from pyrit.message_normalizer import GenericSystemSquashNormalizer -from pyrit.models import ComponentIdentifier, Message, MessagePiece +from pyrit.models import Message, MessagePiece from pyrit.prompt_target import AzureMLChatTarget, OpenAIChatTarget from pyrit.prompt_target.common.target_capabilities import ( CapabilityHandlingPolicy, @@ -36,8 +36,6 @@ def _make_message_piece(*, role: str, content: str, conversation_id: str = "conv converted_value=content, original_value_data_type="text", converted_value_data_type="text", - prompt_target_identifier=ComponentIdentifier(class_name="test", class_module="test"), - attack_identifier=ComponentIdentifier(class_name="test", class_module="test"), ) diff --git a/tests/unit/prompt_target/target/test_openai_chat_target.py b/tests/unit/prompt_target/target/test_openai_chat_target.py index 6483c931ea..e32ce00100 100644 --- a/tests/unit/prompt_target/target/test_openai_chat_target.py +++ b/tests/unit/prompt_target/target/test_openai_chat_target.py @@ -25,7 +25,7 @@ RateLimitException, ) from pyrit.memory.memory_interface import MemoryInterface -from pyrit.models import ComponentIdentifier, Message, MessagePiece +from pyrit.models import Message, MessagePiece from pyrit.models.json_response_config import _JsonResponseConfig from pyrit.prompt_target import ( OpenAIChatAudioConfig, @@ -283,8 +283,6 @@ async def test_send_prompt_async_empty_response_adds_to_memory(openai_response_j converted_value="hello", original_value_data_type="text", converted_value_data_type="text", - prompt_target_identifier=ComponentIdentifier(class_name="target-identifier", class_module="test"), - attack_identifier=ComponentIdentifier(class_name="test", class_module="test"), labels={"test": "test"}, ), MessagePiece( @@ -294,8 +292,6 @@ async def test_send_prompt_async_empty_response_adds_to_memory(openai_response_j converted_value=tmp_file_name, original_value_data_type="image_path", converted_value_data_type="image_path", - prompt_target_identifier=ComponentIdentifier(class_name="target-identifier", class_module="test"), - attack_identifier=ComponentIdentifier(class_name="test", class_module="test"), labels={"test": "test"}, ), ] @@ -393,8 +389,6 @@ async def test_send_prompt_async(openai_response_json: dict, patch_central_datab converted_value="hello", original_value_data_type="text", converted_value_data_type="text", - prompt_target_identifier=ComponentIdentifier(class_name="target-identifier", class_module="test"), - attack_identifier=ComponentIdentifier(class_name="test", class_module="test"), labels={"test": "test"}, ), MessagePiece( @@ -404,8 +398,6 @@ async def test_send_prompt_async(openai_response_json: dict, patch_central_datab converted_value=tmp_file_name, original_value_data_type="image_path", converted_value_data_type="image_path", - prompt_target_identifier=ComponentIdentifier(class_name="target-identifier", class_module="test"), - attack_identifier=ComponentIdentifier(class_name="test", class_module="test"), labels={"test": "test"}, ), ] @@ -458,8 +450,6 @@ async def test_send_prompt_async_empty_response_retries(openai_response_json: di converted_value="hello", original_value_data_type="text", converted_value_data_type="text", - prompt_target_identifier=ComponentIdentifier(class_name="target-identifier", class_module="test"), - attack_identifier=ComponentIdentifier(class_name="test", class_module="test"), labels={"test": "test"}, ), MessagePiece( @@ -469,8 +459,6 @@ async def test_send_prompt_async_empty_response_retries(openai_response_json: di converted_value=tmp_file_name, original_value_data_type="image_path", converted_value_data_type="image_path", - prompt_target_identifier=ComponentIdentifier(class_name="target-identifier", class_module="test"), - attack_identifier=ComponentIdentifier(class_name="test", class_module="test"), labels={"test": "test"}, ), ] diff --git a/tests/unit/prompt_target/target/test_openai_realtime_streaming_session.py b/tests/unit/prompt_target/target/test_openai_realtime_streaming_session.py index c765e7e4d6..451e353ece 100644 --- a/tests/unit/prompt_target/target/test_openai_realtime_streaming_session.py +++ b/tests/unit/prompt_target/target/test_openai_realtime_streaming_session.py @@ -12,7 +12,7 @@ import pytest -from pyrit.models import ComponentIdentifier, Message, MessagePiece +from pyrit.models import Message, MessagePiece from pyrit.prompt_target.common.realtime_audio import ( STREAMING_INTERRUPTED_KEY, CommittedEvent, @@ -314,7 +314,7 @@ async def test_run_async_swaps_user_audio_and_records_identifiers_when_request_c persisted_user_messages: list[Message] = [] - async def _capture(*, message: Message) -> None: + async def _capture(*, message: Message, target_identifier=None) -> None: if message.message_pieces[0].api_role == "user": persisted_user_messages.append(message) @@ -349,7 +349,7 @@ async def test_run_async_skips_swap_and_identifiers_when_no_request_converters() persisted_user_messages: list[Message] = [] - async def _capture(*, message: Message) -> None: + async def _capture(*, message: Message, target_identifier=None) -> None: if message.message_pieces[0].api_role == "user": persisted_user_messages.append(message) @@ -642,73 +642,6 @@ async def _fire() -> None: assert len(calls[2].args[0]) == 14400 -# --------------------------------------------------------------------------- -# 9. attack_identifier is stamped on persisted user + assistant pieces -# --------------------------------------------------------------------------- - - -async def test_attack_identifier_stamped_on_persisted_pieces_when_set(): - """When ``attack_identifier`` is provided, every persisted piece carries it.""" - target = _build_target() - normalizer = _build_normalizer() - - persisted_messages: list[Message] = [] - - async def _capture(*, message: Message) -> None: - persisted_messages.append(message) - - normalizer.hash_and_persist_message_async = AsyncMock(side_effect=_capture) - - attack_id = ComponentIdentifier(class_name="BargeInAttack", class_module="test") - - finish = asyncio.Event() - session = _OpenAIRealtimeStreamingSession( - target=target, - audio_chunks=_paced_chunks([b"\x01" * 96], finish), - prompt_normalizer=normalizer, - attack_identifier=attack_id, - ) - _mock_session_wire(session) - - with _patched_dispatcher(): - await _run_session_with_events(session, finish=finish, events=[CommittedEvent(item_id="i")]) - - # Expect one user message + one assistant message (two pieces) — three pieces total. - all_pieces = [piece for msg in persisted_messages for piece in msg.message_pieces] - assert len(all_pieces) == 3 - for piece in all_pieces: - assert piece.attack_identifier == attack_id - - -async def test_attack_identifier_absent_when_not_provided(): - """Without ``attack_identifier``, persisted pieces have None attribution (back-compat).""" - target = _build_target() - normalizer = _build_normalizer() - - persisted_messages: list[Message] = [] - - async def _capture(*, message: Message) -> None: - persisted_messages.append(message) - - normalizer.hash_and_persist_message_async = AsyncMock(side_effect=_capture) - - finish = asyncio.Event() - session = _OpenAIRealtimeStreamingSession( - target=target, - audio_chunks=_paced_chunks([b"\x01" * 96], finish), - prompt_normalizer=normalizer, - ) - _mock_session_wire(session) - - with _patched_dispatcher(): - await _run_session_with_events(session, finish=finish, events=[CommittedEvent(item_id="i")]) - - all_pieces = [piece for msg in persisted_messages for piece in msg.message_pieces] - assert len(all_pieces) == 3 - for piece in all_pieces: - assert piece.attack_identifier is None - - # --------------------------------------------------------------------------- # 10. persist_prepended_conversation=False skips the prepended-memory write # --------------------------------------------------------------------------- @@ -770,7 +703,6 @@ async def _empty(): req_cfgs = [MagicMock(name="req_cfg")] resp_cfgs = [MagicMock(name="resp_cfg")] vad = ServerVadConfig(prefix_padding_ms=42) - attack_id = {"__type__": "BargeInAttack", "id": "x"} captured: dict[str, Any] = {} @@ -790,7 +722,6 @@ def _fake_session_ctor(**kwargs): response_converter_configurations=resp_cfgs, prepended_conversation=prepended, server_vad=vad, - attack_identifier=attack_id, persist_prepended_conversation=False, ) @@ -802,7 +733,6 @@ def _fake_session_ctor(**kwargs): assert captured["response_converter_configurations"] is resp_cfgs assert captured["prepended_conversation"] is prepended assert captured["server_vad"] is vad - assert captured["attack_identifier"] is attack_id assert captured["persist_prepended_conversation"] is False diff --git a/tests/unit/prompt_target/target/test_openai_response_target.py b/tests/unit/prompt_target/target/test_openai_response_target.py index 3e90cbc00c..6a9cb9952f 100644 --- a/tests/unit/prompt_target/target/test_openai_response_target.py +++ b/tests/unit/prompt_target/target/test_openai_response_target.py @@ -23,7 +23,7 @@ RateLimitException, ) from pyrit.memory.memory_interface import MemoryInterface -from pyrit.models import ComponentIdentifier, Message, MessagePiece +from pyrit.models import Message, MessagePiece from pyrit.models.json_response_config import _JsonResponseConfig from pyrit.prompt_target import OpenAIResponseTarget, PromptTarget @@ -306,8 +306,6 @@ async def test_send_prompt_async_empty_response_adds_to_memory( converted_value="hello", original_value_data_type="text", converted_value_data_type="text", - prompt_target_identifier=ComponentIdentifier(class_name="target-identifier", class_module="test"), - attack_identifier=ComponentIdentifier(class_name="test", class_module="test"), labels={"test": "test"}, ), MessagePiece( @@ -317,8 +315,6 @@ async def test_send_prompt_async_empty_response_adds_to_memory( converted_value=tmp_file_name, original_value_data_type="image_path", converted_value_data_type="image_path", - prompt_target_identifier=ComponentIdentifier(class_name="target-identifier", class_module="test"), - attack_identifier=ComponentIdentifier(class_name="test", class_module="test"), labels={"test": "test"}, ), ] @@ -399,8 +395,6 @@ async def test_send_prompt_async(openai_response_json: dict, target: OpenAIRespo converted_value="hello", original_value_data_type="text", converted_value_data_type="text", - prompt_target_identifier=ComponentIdentifier(class_name="target-identifier", class_module="test"), - attack_identifier=ComponentIdentifier(class_name="test", class_module="test"), labels={"test": "test"}, ), MessagePiece( @@ -410,8 +404,6 @@ async def test_send_prompt_async(openai_response_json: dict, target: OpenAIRespo converted_value=tmp_file_name, original_value_data_type="image_path", converted_value_data_type="image_path", - prompt_target_identifier=ComponentIdentifier(class_name="target-identifier", class_module="test"), - attack_identifier=ComponentIdentifier(class_name="test", class_module="test"), labels={"test": "test"}, ), ] @@ -445,8 +437,6 @@ async def test_send_prompt_async_empty_response_retries(openai_response_json: di converted_value="hello", original_value_data_type="text", converted_value_data_type="text", - prompt_target_identifier=ComponentIdentifier(class_name="target-identifier", class_module="test"), - attack_identifier=ComponentIdentifier(class_name="test", class_module="test"), labels={"test": "test"}, ), MessagePiece( @@ -456,8 +446,6 @@ async def test_send_prompt_async_empty_response_retries(openai_response_json: di converted_value=tmp_file_name, original_value_data_type="image_path", converted_value_data_type="image_path", - prompt_target_identifier=ComponentIdentifier(class_name="target-identifier", class_module="test"), - attack_identifier=ComponentIdentifier(class_name="test", class_module="test"), labels={"test": "test"}, ), ] diff --git a/tests/unit/prompt_target/target/test_prompt_target.py b/tests/unit/prompt_target/target/test_prompt_target.py index 4ad66bdbde..161070de85 100644 --- a/tests/unit/prompt_target/target/test_prompt_target.py +++ b/tests/unit/prompt_target/target/test_prompt_target.py @@ -60,7 +60,6 @@ def test_set_system_prompt(azure_openai_target: OpenAIChatTarget, mock_attack_st azure_openai_target.set_system_prompt( system_prompt="system prompt", conversation_id="1", - attack_identifier=mock_attack_strategy.get_identifier(), labels={}, ) @@ -76,7 +75,6 @@ async def test_set_system_prompt_adds_memory( azure_openai_target.set_system_prompt( system_prompt="system prompt", conversation_id="1", - attack_identifier=mock_attack_strategy.get_identifier(), labels={}, ) @@ -110,7 +108,6 @@ async def test_send_prompt_with_system_calls_chat_complete( azure_openai_target.set_system_prompt( system_prompt="system prompt", conversation_id="1", - attack_identifier=mock_attack_strategy.get_identifier(), labels={}, ) @@ -164,8 +161,6 @@ async def test_send_prompt_async_with_delay( _LINEAGE_CONVERSATION_ID = "original-conv-id-12345" _LINEAGE_LABELS = {"op_name": "test_op", "user_id": "user42"} -_LINEAGE_ATTACK_IDENTIFIER = ComponentIdentifier(class_name="TestAttack", class_module="tests.attacks") -_LINEAGE_PROMPT_TARGET_IDENTIFIER = ComponentIdentifier(class_name="OpenAIChatTarget", class_module="pyrit") _LINEAGE_PROMPT_METADATA = {"scenario": "test_scenario", "turn": 3} @@ -178,8 +173,6 @@ def _make_lineage_piece(*, role: str, content: str) -> MessagePiece: original_value_data_type="text", converted_value_data_type="text", labels=dict(_LINEAGE_LABELS), - prompt_target_identifier=_LINEAGE_PROMPT_TARGET_IDENTIFIER, - attack_identifier=_LINEAGE_ATTACK_IDENTIFIER, prompt_metadata=dict(_LINEAGE_PROMPT_METADATA), ) @@ -242,8 +235,6 @@ async def test_history_squash_preserves_metadata_on_normalized_message(): assert normalized_piece.conversation_id == _LINEAGE_CONVERSATION_ID assert normalized_piece.labels == _LINEAGE_LABELS - assert normalized_piece.attack_identifier == _LINEAGE_ATTACK_IDENTIFIER - assert normalized_piece.prompt_target_identifier == _LINEAGE_PROMPT_TARGET_IDENTIFIER assert normalized_piece.prompt_metadata == _LINEAGE_PROMPT_METADATA @@ -291,8 +282,6 @@ async def test_response_preserves_metadata_after_history_squash(): assert response_piece.conversation_id == _LINEAGE_CONVERSATION_ID assert response_piece.labels == _LINEAGE_LABELS - assert response_piece.attack_identifier == _LINEAGE_ATTACK_IDENTIFIER - assert response_piece.prompt_target_identifier == _LINEAGE_PROMPT_TARGET_IDENTIFIER assert response_piece.prompt_metadata == _LINEAGE_PROMPT_METADATA @@ -338,8 +327,6 @@ async def test_system_squash_preserves_metadata(): assert normalized_piece.conversation_id == _LINEAGE_CONVERSATION_ID assert normalized_piece.labels == _LINEAGE_LABELS - assert normalized_piece.attack_identifier == _LINEAGE_ATTACK_IDENTIFIER - assert normalized_piece.prompt_target_identifier == _LINEAGE_PROMPT_TARGET_IDENTIFIER assert normalized_piece.prompt_metadata == _LINEAGE_PROMPT_METADATA @@ -389,8 +376,6 @@ async def test_history_squash_propagates_lineage_to_all_pieces(): for piece in normalized[0].message_pieces: assert piece.conversation_id == _LINEAGE_CONVERSATION_ID assert piece.labels == _LINEAGE_LABELS - assert piece.attack_identifier == _LINEAGE_ATTACK_IDENTIFIER - assert piece.prompt_target_identifier == _LINEAGE_PROMPT_TARGET_IDENTIFIER assert piece.prompt_metadata == _LINEAGE_PROMPT_METADATA @@ -453,8 +438,6 @@ async def test_conversation_id_stamped_on_all_but_full_lineage_only_on_last(): # Last message should carry full lineage. last_piece = normalized[-1].message_pieces[0] assert last_piece.labels == _LINEAGE_LABELS - assert last_piece.attack_identifier == _LINEAGE_ATTACK_IDENTIFIER - assert last_piece.prompt_target_identifier == _LINEAGE_PROMPT_TARGET_IDENTIFIER assert last_piece.prompt_metadata == _LINEAGE_PROMPT_METADATA # Warning should fire because message count increased (2 → 3). diff --git a/tests/unit/prompt_target/test_round_robin_target.py b/tests/unit/prompt_target/test_round_robin_target.py index 6218ea5013..bb2bf55bf7 100644 --- a/tests/unit/prompt_target/test_round_robin_target.py +++ b/tests/unit/prompt_target/test_round_robin_target.py @@ -373,12 +373,10 @@ async def test_full_send_prompt_async_keeps_round_robin_identifier(): for piece in message.message_pieces: piece.conversation_id = conv_id # Simulate what PromptNormalizer does - piece.prompt_target_identifier = rr.get_identifier() responses = await rr.send_prompt_async(message=message) # The request should still have the round-robin's identifier - assert message.message_pieces[0].prompt_target_identifier == rr.get_identifier() # Only t1 should have received the prompt (first in rotation) assert t1.prompt_sent == ["end to end test"] diff --git a/tests/unit/score/test_conversation_history_scorer.py b/tests/unit/score/test_conversation_history_scorer.py index 0e957482a2..8c6c3ff71d 100644 --- a/tests/unit/score/test_conversation_history_scorer.py +++ b/tests/unit/score/test_conversation_history_scorer.py @@ -251,8 +251,6 @@ async def test_conversation_history_scorer_preserves_metadata(patch_central_data original_value="Response", conversation_id=conversation_id, labels={"test": "label"}, - prompt_target_identifier=ComponentIdentifier(class_name="test", class_module="test"), - attack_identifier=ComponentIdentifier(class_name="test", class_module="test"), sequence=1, ) @@ -288,8 +286,6 @@ async def test_conversation_history_scorer_preserves_metadata(patch_central_data assert called_piece.id == message_piece.id assert called_piece.conversation_id == message_piece.conversation_id assert called_piece.labels == message_piece.labels - assert called_piece.prompt_target_identifier == message_piece.prompt_target_identifier - assert called_piece.attack_identifier == message_piece.attack_identifier async def test_conversation_scorer_regenerates_score_ids_to_prevent_collisions(patch_central_database): diff --git a/tests/unit/score/test_scorer.py b/tests/unit/score/test_scorer.py index 01378ebde3..6491ccefeb 100644 --- a/tests/unit/score/test_scorer.py +++ b/tests/unit/score/test_scorer.py @@ -205,71 +205,6 @@ async def test_scorer_score_value_with_llm_exception_display_prompt_id(): ) -async def test_scorer_score_value_with_llm_use_provided_attack_identifier(good_json): - scorer = MockScorer() - - message = Message( - message_pieces=[MessagePiece(role="assistant", original_value=good_json, conversation_id="test-convo")] - ) - chat_target = MagicMock(PromptTarget) - chat_target.get_identifier.return_value = get_mock_target_identifier("MockChatTarget") - chat_target.send_prompt_async = AsyncMock(return_value=[message]) - chat_target.set_system_prompt = MagicMock() - - expected_system_prompt = "system_prompt" - expected_attack_identifier = ComponentIdentifier(class_name="TestAttack", class_module="test.module") - expected_scored_prompt_id = "123" - - await scorer._score_value_with_llm_async( - prompt_target=chat_target, - system_prompt=expected_system_prompt, - message_value="message_value", - message_data_type="text", - scored_prompt_id=expected_scored_prompt_id, - category="category", - objective="task", - attack_identifier=expected_attack_identifier, - ) - - chat_target.set_system_prompt.assert_called_once() - - _, set_sys_prompt_args = chat_target.set_system_prompt.call_args - assert set_sys_prompt_args["system_prompt"] == expected_system_prompt - assert isinstance(set_sys_prompt_args["conversation_id"], str) - assert set_sys_prompt_args["attack_identifier"] is expected_attack_identifier - - -async def test_scorer_score_value_with_llm_does_not_add_score_prompt_id_for_empty_attack_identifier(good_json): - scorer = MockScorer() - - message = Message( - message_pieces=[MessagePiece(role="assistant", original_value=good_json, conversation_id="test-convo")] - ) - chat_target = MagicMock(PromptTarget) - chat_target.get_identifier.return_value = get_mock_target_identifier("MockChatTarget") - chat_target.send_prompt_async = AsyncMock(return_value=[message]) - chat_target.set_system_prompt = MagicMock() - - expected_system_prompt = "system_prompt" - - await scorer._score_value_with_llm_async( - prompt_target=chat_target, - system_prompt=expected_system_prompt, - message_value="message_value", - message_data_type="text", - scored_prompt_id="123", - category="category", - objective="task", - ) - - chat_target.set_system_prompt.assert_called_once() - - _, set_sys_prompt_args = chat_target.set_system_prompt.call_args - assert set_sys_prompt_args["system_prompt"] == expected_system_prompt - assert isinstance(set_sys_prompt_args["conversation_id"], str) - assert not set_sys_prompt_args["attack_identifier"] - - async def test_scorer_send_chat_target_async_good_response(good_json): chat_target = MagicMock(PromptTarget) chat_target.get_identifier.return_value = get_mock_target_identifier("MockChatTarget")