Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions doc/code/memory/10_schema_diagram.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,14 +40,16 @@ flowchart LR
P_labels["labels (VARCHAR)"]
P_prompt_metadata["prompt_metadata (VARCHAR)"]
P_converter_identifiers["converter_identifiers (VARCHAR)"]
P_prompt_target_identifier["prompt_target_identifier (VARCHAR)"]
P_attack_identifier["attack_identifier (VARCHAR)"]
P_response_error["response_error (VARCHAR)"]
P_converted_value_data_type["converted_value_data_type (VARCHAR)"]
P_converted_value["converted_value (VARCHAR)"]
P_converted_value_sha256["converted_value_sha256 (VARCHAR)"]
P_original_prompt_id["original_prompt_id (UUID)"]
end
subgraph Conversations["Conversations"]
C_conversation_id["conversation_id (VARCHAR)"]
C_target_identifier["target_identifier (VARCHAR)"]
end
subgraph ScoreEntries["ScoreEntries"]
Sc_id["id (UUID)"]
Sc_prompt_request_response_id["prompt_request_response_id (VARCHAR)"]
Expand All @@ -63,6 +65,7 @@ flowchart LR
end
S_value_sha256 -- N:N relationship to query --> P_original_value_sha256
P_id -- 1:N relationship to query --> Sc_prompt_request_response_id
P_conversation_id -- N:1 relationship to query --> C_conversation_id

style S_value_sha256 fill:#ff8800ff
style P_id fill:#14a519ff
Expand Down
4 changes: 2 additions & 2 deletions doc/code/memory/3_memory_data_types.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@ One of the most fundamental data structures in PyRIT is [MessagePiece](../../../
- **`labels`**: Dictionary of labels for categorization and filtering
- **`prompt_metadata`**: Component-specific metadata (e.g., blob URIs, document types)
- **`converter_identifiers`**: List of converters applied to transform the prompt
- **`prompt_target_identifier`**: Information about the target that received this prompt
- **`attack_identifier`**: Information about the attack that generated this prompt
- **`scorer_identifier`**: Information about the scorer that evaluated this prompt
- **`response_error`**: Error status (e.g., `none`, `blocked`, `processing`)
- **`originator`**: Source of the prompt (`attack`, `converter`, `scorer`, `undefined`)
Expand Down Expand Up @@ -54,6 +52,8 @@ This rich context allows PyRIT to track the full lifecycle of each interaction,

A conversation is a list of `Messages` that share the same `conversation_id`. The sequence of the `MessagePieces` and their corresponding `Messages` dictates the order of the conversation.

A conversation is always held with a single target. That target's identifier is recorded once per conversation in the `Conversations` table (`target_identifier`) rather than on every `MessagePiece`. Use `memory.get_conversation_metadata(conversation_id=...)` to retrieve it.

Here is a sample conversation made up of three `Messages` which all share the same conversation ID. The first `Message` is the `system` message, followed by a multi-modal `user` prompt with a text `MessagePiece` and an image `MessagePiece`, and finally the `assistant` response in the form of a text `MessagePiece`.

```{mermaid}
Expand Down
26 changes: 23 additions & 3 deletions pyrit/backend/services/attack_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,7 @@ async def create_attack_async(self, *, request: CreateAttackRequest) -> CreateAt
cutoff_index=request.cutoff_index,
labels_override=labels,
remap_assistant_to_simulated=True,
target_identifier=target_identifier,
)
else:
conversation_id = str(uuid.uuid4())
Expand Down Expand Up @@ -344,6 +345,7 @@ async def create_attack_async(self, *, request: CreateAttackRequest) -> CreateAt
conversation_id=conversation_id,
prepended=request.prepended_conversation,
labels=labels, # deprecated
target_identifier=target_identifier,
)

return CreateAttackResponse(
Expand Down Expand Up @@ -475,9 +477,13 @@ async def create_related_conversation_async(

# --- Branch via duplication (preferred for tracking) ---------------
if request.source_conversation_id is not None and request.cutoff_index is not None:
source_metadata = self._memory.get_conversation_metadata(
conversation_id=request.source_conversation_id
)
new_conversation_id = self._duplicate_conversation_up_to(
source_conversation_id=request.source_conversation_id,
cutoff_index=request.cutoff_index,
target_identifier=source_metadata.target_identifier if source_metadata else None,
)
else:
new_conversation_id = str(uuid.uuid4())
Expand Down Expand Up @@ -622,11 +628,13 @@ async def add_message_async(self, *, attack_result_id: str, request: AddMessageR
labels=attack_labels, # deprecated
)
else:
existing_metadata = self._memory.get_conversation_metadata(conversation_id=msg_conversation_id)
await self._store_message_only_async(
conversation_id=msg_conversation_id,
request=request,
sequence=sequence,
labels=attack_labels, # deprecated
target_identifier=existing_metadata.target_identifier if existing_metadata else None,
)

await self._update_attack_after_message_async(attack_result_id=attack_result_id, ar=ar, request=request)
Expand Down Expand Up @@ -828,6 +836,7 @@ def _duplicate_conversation_up_to(
cutoff_index: int,
labels_override: dict[str, str] | None = None,
remap_assistant_to_simulated: bool = False,
target_identifier: ComponentIdentifier | None = None,
) -> str:
"""
Duplicate messages from a conversation up to and including a turn index.
Expand All @@ -846,6 +855,9 @@ def _duplicate_conversation_up_to(
``assistant`` are changed to ``simulated_assistant`` so the
branched context is inert and won't confuse the target.

target_identifier (ComponentIdentifier | None): The target the new conversation
is held with, if known. Recorded once for the duplicated conversation.

Returns:
The new conversation ID containing the duplicated messages.
"""
Expand All @@ -865,7 +877,9 @@ def _duplicate_conversation_up_to(
piece.role = "simulated_assistant"

if all_pieces:
self._memory.add_message_pieces_to_memory(message_pieces=list(all_pieces))
self._memory.add_message_pieces_to_memory(
message_pieces=list(all_pieces), target_identifier=target_identifier
)

return new_conversation_id

Expand Down Expand Up @@ -953,6 +967,7 @@ async def _store_prepended_messages_async(
conversation_id: str,
prepended: list[Any],
labels: dict[str, str] | None = None, # deprecated
target_identifier: ComponentIdentifier | None = None,
) -> None:
"""Store prepended conversation messages in memory."""
for seq, msg in enumerate(prepended):
Expand All @@ -964,7 +979,9 @@ async def _store_prepended_messages_async(
sequence=seq,
labels=labels, # deprecated
)
self._memory.add_message_pieces_to_memory(message_pieces=[piece])
self._memory.add_message_pieces_to_memory(
message_pieces=[piece], target_identifier=target_identifier
)

async def _send_and_store_message_async(
self,
Expand Down Expand Up @@ -1010,6 +1027,7 @@ async def _store_message_only_async(
request: AddMessageRequest,
sequence: int,
labels: dict[str, str] | None = None, # deprecated
target_identifier: ComponentIdentifier | None = None,
) -> None:
"""Store message without sending (send=False)."""
await self._persist_base64_pieces_async(request)
Expand All @@ -1021,7 +1039,9 @@ async def _store_message_only_async(
sequence=sequence,
labels=labels, # deprecated
)
self._memory.add_message_pieces_to_memory(message_pieces=[piece])
self._memory.add_message_pieces_to_memory(
message_pieces=[piece], target_identifier=target_identifier
)

def _resolve_video_remix_metadata(self, request: AddMessageRequest) -> None:
"""
Expand Down
21 changes: 9 additions & 12 deletions pyrit/executor/attack/component/conversation_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,6 @@ def get_adversarial_chat_messages(
prepended_conversation: list[Message],
*,
adversarial_chat_conversation_id: str,
attack_identifier: ComponentIdentifier,
adversarial_chat_target_identifier: ComponentIdentifier,
labels: dict[str, str] | None = None, # deprecated
) -> list[Message]:
"""
Expand All @@ -72,8 +70,6 @@ def get_adversarial_chat_messages(
Args:
prepended_conversation: The original conversation messages to transform.
adversarial_chat_conversation_id: Conversation ID for the adversarial chat.
attack_identifier (ComponentIdentifier): Attack identifier to associate with messages.
adversarial_chat_target_identifier (ComponentIdentifier): Target identifier for the adversarial chat.
labels: Optional labels to associate with the messages.
Deprecated: This parameter will be removed in a release 0.16.0.

Expand Down Expand Up @@ -114,8 +110,6 @@ def get_adversarial_chat_messages(
original_value_data_type=piece.original_value_data_type,
converted_value_data_type=piece.converted_value_data_type,
conversation_id=adversarial_chat_conversation_id,
attack_identifier=attack_identifier,
prompt_target_identifier=adversarial_chat_target_identifier,
labels=labels or {}, # deprecated
)

Expand Down Expand Up @@ -190,20 +184,17 @@ class ConversationManager:
def __init__(
self,
*,
attack_identifier: ComponentIdentifier,
prompt_normalizer: PromptNormalizer | None = None,
) -> None:
"""
Initialize the conversation manager.

Args:
attack_identifier (ComponentIdentifier): The identifier of the attack this manager belongs to.
prompt_normalizer: Optional prompt normalizer for converting prompts.
If not provided, a default PromptNormalizer instance will be created.
"""
self._prompt_normalizer = prompt_normalizer or PromptNormalizer()
self._memory = CentralMemory.get_memory_instance()
self._attack_identifier = attack_identifier

def get_conversation(self, conversation_id: str) -> list[Message]:
"""
Expand Down Expand Up @@ -276,7 +267,6 @@ def set_system_prompt(
target.set_system_prompt(
system_prompt=system_prompt,
conversation_id=conversation_id,
attack_identifier=self._attack_identifier,
labels=labels, # deprecated
)

Expand Down Expand Up @@ -359,6 +349,7 @@ async def initialize_context_async(
request_converters=request_converters,
prepended_conversation_config=prepended_conversation_config,
max_turns=max_turns,
target_identifier=target.get_identifier(),
)

async def _handle_non_chat_target_async(
Expand Down Expand Up @@ -439,6 +430,7 @@ async def add_prepended_conversation_to_memory_async(
request_converters: list[PromptConverterConfiguration] | None = None,
prepended_conversation_config: Optional["PrependedConversationConfig"] = None,
max_turns: int | None = None,
target_identifier: ComponentIdentifier | None = None,
) -> int:
"""
Add prepended conversation messages to memory for a chat target.
Expand All @@ -459,6 +451,8 @@ async def add_prepended_conversation_to_memory_async(
request_converters: Optional converters to apply to messages.
prepended_conversation_config: Optional configuration for converter roles.
max_turns: If provided, validates that turn count doesn't exceed this limit.
target_identifier (ComponentIdentifier | None): The target the conversation is held
with, if known. Recorded once per conversation.

Returns:
The number of turns (assistant messages) added.
Expand All @@ -485,7 +479,6 @@ async def add_prepended_conversation_to_memory_async(

for piece in message_copy.message_pieces:
piece.conversation_id = conversation_id
piece.attack_identifier = self._attack_identifier

# Count turns at message level (only assistant/simulated_assistant messages)
# A multi-part response still counts as one turn
Expand All @@ -506,7 +499,7 @@ async def add_prepended_conversation_to_memory_async(
)

# Add to memory
self._memory.add_message_to_memory(request=message_copy)
self._memory.add_message_to_memory(request=message_copy, target_identifier=target_identifier)
logger.debug(f"Added prepended message {i + 1}/{len(valid_messages)} to memory")

return turn_count
Expand All @@ -520,6 +513,7 @@ async def _process_prepended_for_chat_target_async(
request_converters: list[PromptConverterConfiguration] | None,
prepended_conversation_config: Optional["PrependedConversationConfig"],
max_turns: int | None,
target_identifier: ComponentIdentifier | None = None,
) -> ConversationState:
"""
Process prepended conversation for a chat target.
Expand All @@ -536,6 +530,8 @@ async def _process_prepended_for_chat_target_async(
request_converters: Converters to apply.
prepended_conversation_config: Configuration for converter roles.
max_turns: Maximum turns for validation.
target_identifier (ComponentIdentifier | None): The objective target the
conversation is held with, if known.

Returns:
ConversationState with turn_count and scores.
Expand All @@ -555,6 +551,7 @@ async def _process_prepended_for_chat_target_async(
request_converters=request_converters,
prepended_conversation_config=prepended_conversation_config,
max_turns=max_turns,
target_identifier=target_identifier,
)

# Update context for multi-turn attacks to reflect prepended_conversation
Expand Down
4 changes: 0 additions & 4 deletions pyrit/executor/attack/multi_turn/chunked_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,6 @@ def __init__(
# Initialize prompt normalizer and conversation manager
self._prompt_normalizer = prompt_normalizer or PromptNormalizer()
self._conversation_manager = ConversationManager(
attack_identifier=self.get_identifier(),
prompt_normalizer=self._prompt_normalizer,
)

Expand Down Expand Up @@ -279,7 +278,6 @@ async def _perform_async(self, *, context: ChunkedRequestAttackContext) -> Attac
with execution_context(
component_role=ComponentRole.OBJECTIVE_TARGET,
attack_strategy_name=self.__class__.__name__,
attack_identifier=self.get_identifier(),
component_identifier=self._objective_target.get_identifier(),
objective_target_conversation_id=context.session.conversation_id,
objective=context.objective,
Expand All @@ -291,7 +289,6 @@ async def _perform_async(self, *, context: ChunkedRequestAttackContext) -> Attac
request_converter_configurations=self._request_converters,
response_converter_configurations=self._response_converters,
labels=context.memory_labels,
attack_identifier=self.get_identifier(),
)

# Store the response
Expand Down Expand Up @@ -377,7 +374,6 @@ async def _score_combined_value_async(
with execution_context(
component_role=ComponentRole.OBJECTIVE_SCORER,
attack_strategy_name=self.__class__.__name__,
attack_identifier=self.get_identifier(),
component_identifier=self._objective_scorer.get_identifier(),
objective=objective,
):
Expand Down
8 changes: 0 additions & 8 deletions pyrit/executor/attack/multi_turn/crescendo.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,6 @@ def __init__(
# Initialize utilities
self._prompt_normalizer = prompt_normalizer or PromptNormalizer()
self._conversation_manager = ConversationManager(
attack_identifier=self.get_identifier(),
prompt_normalizer=self._prompt_normalizer,
)

Expand Down Expand Up @@ -331,7 +330,6 @@ async def _setup_async(self, *, context: CrescendoAttackContext) -> None:
self._adversarial_chat.set_system_prompt(
system_prompt=system_prompt,
conversation_id=context.session.adversarial_chat_conversation_id,
attack_identifier=self.get_identifier(),
labels=context.memory_labels, # deprecated
)

Expand Down Expand Up @@ -545,7 +543,6 @@ async def _send_prompt_to_adversarial_chat_async(
with execution_context(
component_role=ComponentRole.ADVERSARIAL_CHAT,
attack_strategy_name=self.__class__.__name__,
attack_identifier=self.get_identifier(),
component_identifier=self._adversarial_chat.get_identifier(),
objective_target_conversation_id=context.session.conversation_id,
objective=context.objective,
Expand All @@ -554,7 +551,6 @@ async def _send_prompt_to_adversarial_chat_async(
message=message,
conversation_id=context.session.adversarial_chat_conversation_id,
target=self._adversarial_chat,
attack_identifier=self.get_identifier(),
labels=context.memory_labels,
)

Expand Down Expand Up @@ -649,7 +645,6 @@ async def _send_prompt_to_objective_target_async(
with execution_context(
component_role=ComponentRole.OBJECTIVE_TARGET,
attack_strategy_name=self.__class__.__name__,
attack_identifier=self.get_identifier(),
component_identifier=self._objective_target.get_identifier(),
objective_target_conversation_id=context.session.conversation_id,
objective=context.objective,
Expand All @@ -660,7 +655,6 @@ async def _send_prompt_to_objective_target_async(
conversation_id=context.session.conversation_id,
request_converter_configurations=self._request_converters,
response_converter_configurations=self._response_converters,
attack_identifier=self.get_identifier(),
labels=context.memory_labels,
)

Expand Down Expand Up @@ -689,7 +683,6 @@ async def _check_refusal_async(self, context: CrescendoAttackContext, objective:
with execution_context(
component_role=ComponentRole.REFUSAL_SCORER,
attack_strategy_name=self.__class__.__name__,
attack_identifier=self.get_identifier(),
component_identifier=self._refusal_scorer.get_identifier(),
objective_target_conversation_id=context.session.conversation_id,
objective=context.objective,
Expand Down Expand Up @@ -721,7 +714,6 @@ async def _score_response_async(self, *, context: CrescendoAttackContext) -> Sco
with execution_context(
component_role=ComponentRole.OBJECTIVE_SCORER,
attack_strategy_name=self.__class__.__name__,
attack_identifier=self.get_identifier(),
component_identifier=self._objective_scorer.get_identifier(),
objective_target_conversation_id=context.session.conversation_id,
objective=context.objective,
Expand Down
Loading
Loading