Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 25 additions & 55 deletions pyrit/executor/attack/component/conversation_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,16 @@
import logging
import uuid
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Union
from typing import Dict, List, Optional

from pyrit.memory import CentralMemory
from pyrit.models import ChatMessageRole, Message, MessagePiece, Score
from pyrit.prompt_normalizer.prompt_converter_configuration import (
PromptConverterConfiguration,
)
from pyrit.prompt_normalizer.prompt_normalizer import PromptNormalizer
from pyrit.prompt_target import PromptTarget
from pyrit.prompt_target.common.prompt_chat_target import PromptChatTarget
from pyrit.prompt_target.common.prompt_target import PromptTarget

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -213,8 +213,8 @@ def set_system_prompt(
async def update_conversation_state_async(
self,
*,
target: PromptTarget,
conversation_id: str,
target: Optional[Union[PromptTarget, PromptChatTarget]] = None,
prepended_conversation: List[Message],
request_converters: Optional[List[PromptConverterConfiguration]] = None,
response_converters: Optional[List[PromptConverterConfiguration]] = None,
Expand All @@ -237,9 +237,9 @@ async def update_conversation_state_async(
and extracts per-session counters such as the current turn index.

Args:
target (PromptTarget): The target for which the conversation is being prepared.
Used to validate that prepended_conversation is compatible with the target type.
conversation_id (str): Unique identifier for the conversation to update or create.
target (Optional[Union[PromptTarget, PromptChatTarget]]): The target to set system prompts on (if
applicable).
prepended_conversation (List[Message]):
List of messages to prepend to the conversation history.
request_converters (Optional[List[PromptConverterConfiguration]]):
Expand All @@ -254,12 +254,21 @@ async def update_conversation_state_async(
messages, including turn count and last user message.

Raises:
ValueError: If `conversation_id` is empty or if the last message in a multi-turn
context is a user message (which should not be prepended).
ValueError: If `conversation_id` is empty, if the last message in a multi-turn
context is a user message (which should not be prepended), or if
prepended_conversation is provided with a non-PromptChatTarget target.
"""
if not conversation_id:
raise ValueError("conversation_id cannot be empty")

# Validate prepended_conversation compatibility with target type
# Non-chat targets do not read conversation history from memory
if prepended_conversation and not isinstance(target, PromptChatTarget):
raise ValueError(
"prepended_conversation requires target to be a PromptChatTarget. "
"Non-chat targets do not support explicit conversation history management."
)

# Initialize conversation state
state = ConversationState()
logger.debug(f"Preparing conversation with ID: {conversation_id}")
Expand Down Expand Up @@ -304,7 +313,6 @@ async def update_conversation_state_async(
request=request,
conversation_id=conversation_id,
conversation_state=state,
target=target,
max_turns=max_turns,
)

Expand Down Expand Up @@ -364,7 +372,6 @@ async def _process_prepended_message_async(
request: Message,
conversation_id: str,
conversation_state: ConversationState,
target: Optional[Union[PromptTarget, PromptChatTarget]] = None,
max_turns: Optional[int] = None,
) -> None:
"""
Expand All @@ -376,100 +383,63 @@ async def _process_prepended_message_async(
request (Message): The request containing pieces to process.
conversation_id (str): The ID of the conversation to update.
conversation_state (ConversationState): The current state of the conversation.
target (Optional[Union[PromptTarget, PromptChatTarget]]): The target to set system prompts on (if
applicable).
max_turns (Optional[int]): Maximum allowed turns for the conversation.

Raises:
ValueError: If the request is invalid or if a system prompt is provided but target doesn't support it.
"""
# Validate the request before processing
if not request or not request.message_pieces:
return

# Set the conversation ID and attack ID for each piece in the request
save_to_memory = True
for piece in request.message_pieces:
piece.conversation_id = conversation_id
piece.attack_identifier = self._attack_identifier
piece.id = uuid.uuid4()

# Process the piece based on its role
# Process the piece based on its role (validates turn count for multi-turn)
self._process_piece(
piece=piece,
conversation_state=conversation_state,
max_turns=max_turns,
target=target,
)

if ConversationManager._should_exclude_piece_from_memory(piece=piece, max_turns=max_turns):
# it is excluded, so we don't want to save it to memory
save_to_memory = False

# Add the request to memory if it was not a system piece
if save_to_memory:
self._memory.add_message_to_memory(request=request)
# Add the request to memory
self._memory.add_message_to_memory(request=request)

def _process_piece(
self,
*,
piece: MessagePiece,
conversation_state: ConversationState,
max_turns: Optional[int] = None,
target: Optional[Union[PromptTarget, PromptChatTarget]] = None,
) -> None:
"""
Process a message piece based on its role and update conversation state.

For multi-turn conversations, this validates that the turn count doesn't exceed
max_turns. Only assistant messages count as turns.

Args:
piece (MessagePiece): The piece to process.
conversation_state (ConversationState): The current state of the conversation.
max_turns (Optional[int]): Maximum allowed turns (for validation).
target (Optional[Union[PromptTarget, PromptChatTarget]]): The target to set system prompts on.

Raises:
ValueError: If max_turns would be exceeded by this piece.
ValueError: If a system prompt is provided but target doesn't support it.
"""
# Check if multiturn
is_multi_turn = max_turns is not None

# Handle system prompts (both single-turn and multi-turn)
if piece.role == "system":
if target is None:
raise ValueError("Target must be provided to handle system prompts")

if not isinstance(target, PromptChatTarget):
raise ValueError("Target must be a PromptChatTarget to set system prompts")

# Set system prompt and exclude from memory
self.set_system_prompt(
target=target,
conversation_id=piece.conversation_id,
system_prompt=piece.converted_value,
labels=piece.labels,
)

# Handle assistant messages (count turns for multi-turn only)
elif piece.role == "assistant" and is_multi_turn:
# Update turn count
# Only assistant messages count as turns
if piece.role == "assistant" and is_multi_turn:
conversation_state.turn_count += 1

# Validate against max_turns
if max_turns and conversation_state.turn_count > max_turns:
if conversation_state.turn_count > max_turns:
raise ValueError(
f"The number of turns in the prepended conversation ({conversation_state.turn_count-1}) is equal to"
+ f" or exceeds the maximum number of turns ({max_turns}), which means the"
+ " conversation will not be able to continue. Please reduce the number of turns in"
+ " the prepended conversation or increase the maximum number of turns and try again."
)

@staticmethod
def _should_exclude_piece_from_memory(*, piece: MessagePiece, max_turns: Optional[int] = None) -> bool:
    """
    Tell whether a prepended message piece must be kept out of memory.

    Only the piece's role is inspected: a piece whose role is "system" is
    excluded, every other role is kept. The rationale recorded for this rule
    is that system pieces are handed to the target via set_system_prompt,
    which adds them to memory itself, so storing them again here would
    duplicate them.

    Args:
        piece (MessagePiece): The piece whose role is checked.
        max_turns (Optional[int]): Accepted for signature compatibility;
            this check does not read it.

    Returns:
        bool: True if the piece has the "system" role, otherwise False.
    """
    role = piece.role
    return role == "system"

async def _populate_conversation_state_async(
self,
*,
Expand Down
2 changes: 1 addition & 1 deletion pyrit/executor/attack/multi_turn/tree_of_attacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,8 +279,8 @@ async def initialize_with_prepended_conversation_async(

# Add to objective target conversation (handles system prompts and memory)
await conversation_manager.update_conversation_state_async(
conversation_id=self.objective_target_conversation_id,
target=self._objective_target,
conversation_id=self.objective_target_conversation_id,
prepended_conversation=prepended_conversation,
)

Expand Down
Loading