From b63361c44359a11a66daabf5e7f202b926817cb4 Mon Sep 17 00:00:00 2001 From: Behnam Ousat Date: Wed, 15 Apr 2026 16:55:10 -0700 Subject: [PATCH 01/23] Add labels to attack results --- pyrit/backend/mappers/attack_mappers.py | 5 +- pyrit/backend/services/attack_service.py | 1 + .../attack/multi_turn/chunked_request.py | 1 + pyrit/executor/attack/multi_turn/crescendo.py | 1 + .../attack/multi_turn/multi_prompt_sending.py | 1 + .../executor/attack/multi_turn/red_teaming.py | 1 + .../attack/multi_turn/tree_of_attacks.py | 1 + .../attack/single_turn/prompt_sending.py | 1 + .../attack/single_turn/skeleton_key.py | 1 + pyrit/executor/benchmark/fairness_bias.py | 1 + pyrit/memory/azure_sql_memory.py | 28 +++-- pyrit/memory/memory_interface.py | 8 +- pyrit/memory/memory_models.py | 4 + pyrit/memory/sqlite_memory.py | 21 +--- pyrit/models/attack_result.py | 3 + tests/unit/backend/test_attack_service.py | 5 +- tests/unit/backend/test_mappers.py | 22 +++- .../test_interface_attack_results.py | 117 ++++++++---------- tests/unit/scenario/test_scenario.py | 1 + 19 files changed, 122 insertions(+), 101 deletions(-) diff --git a/pyrit/backend/mappers/attack_mappers.py b/pyrit/backend/mappers/attack_mappers.py index 0245e2af12..c37dd77fd9 100644 --- a/pyrit/backend/mappers/attack_mappers.py +++ b/pyrit/backend/mappers/attack_mappers.py @@ -197,8 +197,11 @@ def attack_result_to_summary( """ message_count = stats.message_count last_preview = stats.last_message_preview - labels = dict(stats.labels) if stats.labels else {} + # Merge attack-result labels with conversation-level labels. + # Conversation labels take precedence on key collision. + labels = dict(ar.labels) if ar.labels else {} + labels.update(stats.labels or {}) created_str = ar.metadata.get("created_at") updated_str = ar.metadata.get("updated_at") created_at = datetime.fromisoformat(created_str) if created_str else datetime.now(timezone.utc) diff --git a/pyrit/backend/services/attack_service.py b/pyrit/backend/services/attack_service.py index 8852071a66..b7c3635e88 100644 --- a/pyrit/backend/services/attack_service.py +++ b/pyrit/backend/services/attack_service.py @@ -308,6 +308,7 @@ async def create_attack_async(self, *, request: CreateAttackRequest) -> CreateAt "created_at": now.isoformat(), "updated_at": now.isoformat(), }, + labels=labels, ) # Store in memory diff --git a/pyrit/executor/attack/multi_turn/chunked_request.py b/pyrit/executor/attack/multi_turn/chunked_request.py index 1a70c89195..ed95c5d226 100644 --- a/pyrit/executor/attack/multi_turn/chunked_request.py +++ b/pyrit/executor/attack/multi_turn/chunked_request.py @@ -325,6 +325,7 @@ async def _perform_async(self, *, context: ChunkedRequestAttackContext) -> Attac outcome_reason=outcome_reason, executed_turns=context.executed_turns, metadata={"combined_chunks": combined_value, "chunk_count": len(context.chunk_responses)}, + labels=context.memory_labels, ) def _determine_attack_outcome( diff --git a/pyrit/executor/attack/multi_turn/crescendo.py b/pyrit/executor/attack/multi_turn/crescendo.py index 4a180d5df3..f137b322f3 100644 --- a/pyrit/executor/attack/multi_turn/crescendo.py +++ b/pyrit/executor/attack/multi_turn/crescendo.py @@ -402,6 +402,7 @@ async def _perform_async(self, *, context: CrescendoAttackContext) -> CrescendoA last_response=context.last_response.get_piece() if context.last_response else None, last_score=context.last_score, related_conversations=context.related_conversations, + labels=context.memory_labels, ) # setting metadata for backtrack count result.backtrack_count = context.backtrack_count diff --git a/pyrit/executor/attack/multi_turn/multi_prompt_sending.py b/pyrit/executor/attack/multi_turn/multi_prompt_sending.py index a9d4b75adc..8447737578 100644 --- a/pyrit/executor/attack/multi_turn/multi_prompt_sending.py +++ b/pyrit/executor/attack/multi_turn/multi_prompt_sending.py @@ -295,6 +295,7 @@ async def _perform_async(self, *, context: MultiTurnAttackContext[Any]) -> Attac outcome=outcome, outcome_reason=outcome_reason, executed_turns=context.executed_turns, + labels=context.memory_labels, ) def _determine_attack_outcome( diff --git a/pyrit/executor/attack/multi_turn/red_teaming.py b/pyrit/executor/attack/multi_turn/red_teaming.py index a8778f664a..1feec20586 100644 --- a/pyrit/executor/attack/multi_turn/red_teaming.py +++ b/pyrit/executor/attack/multi_turn/red_teaming.py @@ -322,6 +322,7 @@ async def _perform_async(self, *, context: MultiTurnAttackContext[Any]) -> Attac last_response=context.last_response.get_piece() if context.last_response else None, last_score=context.last_score, related_conversations=context.related_conversations, + labels=context.memory_labels, ) async def _teardown_async(self, *, context: MultiTurnAttackContext[Any]) -> None: diff --git a/pyrit/executor/attack/multi_turn/tree_of_attacks.py b/pyrit/executor/attack/multi_turn/tree_of_attacks.py index e92bd1cf67..f6ccc4ed64 100644 --- a/pyrit/executor/attack/multi_turn/tree_of_attacks.py +++ b/pyrit/executor/attack/multi_turn/tree_of_attacks.py @@ -2082,6 +2082,7 @@ def _create_attack_result( last_response=last_response, last_score=context.best_objective_score, related_conversations=context.related_conversations, + labels=context.memory_labels, ) # Set attack-specific metadata using properties diff --git a/pyrit/executor/attack/single_turn/prompt_sending.py b/pyrit/executor/attack/single_turn/prompt_sending.py index 07f1d670fa..cdb2d4b619 100644 --- a/pyrit/executor/attack/single_turn/prompt_sending.py +++ b/pyrit/executor/attack/single_turn/prompt_sending.py @@ -238,6 +238,7 @@ async def _perform_async(self, *, context: SingleTurnAttackContext[Any]) -> Atta outcome=outcome, outcome_reason=outcome_reason, executed_turns=1, + labels=context.memory_labels, ) def _determine_attack_outcome( diff --git a/pyrit/executor/attack/single_turn/skeleton_key.py b/pyrit/executor/attack/single_turn/skeleton_key.py index 683614dce5..40cc5cc302 100644 --- a/pyrit/executor/attack/single_turn/skeleton_key.py +++ b/pyrit/executor/attack/single_turn/skeleton_key.py @@ -181,4 +181,5 @@ def _create_skeleton_key_failure_result(self, *, context: SingleTurnAttackContex outcome=AttackOutcome.FAILURE, outcome_reason="Skeleton key prompt was filtered or failed", executed_turns=1, + labels=context.memory_labels, ) diff --git a/pyrit/executor/benchmark/fairness_bias.py b/pyrit/executor/benchmark/fairness_bias.py index 05bb424c17..63d33f4639 100644 --- a/pyrit/executor/benchmark/fairness_bias.py +++ b/pyrit/executor/benchmark/fairness_bias.py @@ -200,6 +200,7 @@ async def _perform_async(self, *, context: FairnessBiasBenchmarkContext) -> Atta atomic_attack_identifier=build_atomic_attack_identifier( attack_identifier=ComponentIdentifier.of(self), ), + labels=context.memory_labels, ) return last_attack_result diff --git a/pyrit/memory/azure_sql_memory.py b/pyrit/memory/azure_sql_memory.py index 207a7f7f98..15586f7152 100644 --- a/pyrit/memory/azure_sql_memory.py +++ b/pyrit/memory/azure_sql_memory.py @@ -446,7 +446,8 @@ def _get_attack_result_harm_category_condition(self, *, targeted_harm_categories def _get_attack_result_label_condition(self, *, labels: dict[str, str]) -> Any: """ - Get the SQL Azure implementation for filtering AttackResults by labels. + Get the SQL Azure implementation for filtering AttackResults by labels + stored directly on the AttackResultEntry. Uses JSON_VALUE() function specific to SQL Azure with parameterized queries. @@ -454,24 +455,27 @@ def _get_attack_result_label_condition(self, *, labels: dict[str, str]) -> Any: labels (dict[str, str]): Dictionary of label key-value pairs to filter by. Returns: - Any: SQLAlchemy exists subquery condition with bound parameters. + Any: SQLAlchemy condition with bound parameters. """ # Build JSON conditions for all labels with parameterized queries label_conditions = [] bindparams_dict = {} - for key, value in labels.items(): - param_name = f"label_{key}" - label_conditions.append(f"JSON_VALUE(labels, '$.{key}') = :{param_name}") - bindparams_dict[param_name] = str(value) + for i, (key, value) in enumerate(labels.items()): + path_param = f"label_path_{i}" + value_param = f"label_val_{i}" + label_conditions.append( + f'JSON_VALUE("AttackResultEntries".labels, :{path_param}) = :{value_param}' + ) + bindparams_dict[path_param] = f"$.{key}" + bindparams_dict[value_param] = str(value) combined_conditions = " AND ".join(label_conditions) - return exists().where( - and_( - PromptMemoryEntry.conversation_id == AttackResultEntry.conversation_id, - PromptMemoryEntry.labels.isnot(None), - text(f"ISJSON(labels) = 1 AND {combined_conditions}").bindparams(**bindparams_dict), - ) + return and_( + AttackResultEntry.labels.isnot(None), + text( + f'ISJSON("AttackResultEntries".labels) = 1 AND {combined_conditions}' + ).bindparams(**bindparams_dict), ) def get_unique_attack_class_names(self) -> list[str]: diff --git a/pyrit/memory/memory_interface.py b/pyrit/memory/memory_interface.py index 023045a5c3..19c5939cbf 100644 --- a/pyrit/memory/memory_interface.py +++ b/pyrit/memory/memory_interface.py @@ -420,7 +420,7 @@ def _get_attack_result_harm_category_condition(self, *, targeted_harm_categories def _get_attack_result_label_condition(self, *, labels: dict[str, str]) -> Any: """ Return a database-specific condition for filtering AttackResults by labels - in the associated PromptMemoryEntry records. + stored directly on the AttackResultEntry. Args: labels: Dictionary of labels that must ALL be present. @@ -1505,9 +1505,9 @@ def get_attack_results( not necessarily one(s) that were found in the response. By providing a list, this means ALL categories in the list must be present. Defaults to None. - labels (Optional[dict[str, str]], optional): A dictionary of memory labels to filter results by. - These labels are associated with the prompts themselves, used for custom tagging and tracking. - Defaults to None. + labels (Optional[dict[str, str]], optional): A dictionary of labels to filter results by. + These labels are stored directly on the AttackResult. All specified key-value pairs + must be present (AND logic). Defaults to None. identifier_filters (Optional[Sequence[IdentifierFilter]], optional): A sequence of IdentifierFilter objects that allows filtering by various attack identifier JSON properties. Defaults to None. diff --git a/pyrit/memory/memory_models.py b/pyrit/memory/memory_models.py index 9376768bd4..511005174d 100644 --- a/pyrit/memory/memory_models.py +++ b/pyrit/memory/memory_models.py @@ -720,6 +720,7 @@ class AttackResultEntry(Base): outcome (AttackOutcome): The outcome of the attack, indicating success, failure, or undetermined. outcome_reason (str): Optional reason for the outcome, providing additional context. attack_metadata (dict[str, Any]): Metadata can be included as key-value pairs to provide extra context. + labels (dict[str, str]): Optional labels associated with the attack result entry. pruned_conversation_ids (List[str]): List of conversation IDs that were pruned from the attack. adversarial_chat_conversation_ids (List[str]): List of conversation IDs used for adversarial chat. timestamp (DateTime): The timestamp of the attack result entry. @@ -751,6 +752,7 @@ class AttackResultEntry(Base): ) outcome_reason = mapped_column(String, nullable=True) attack_metadata: Mapped[dict[str, Union[str, int, float, bool]]] = mapped_column(JSON, nullable=True) + labels: Mapped[dict[str, str]] = mapped_column(JSON, nullable=True) pruned_conversation_ids: Mapped[Optional[list[str]]] = mapped_column(JSON, nullable=True) adversarial_chat_conversation_ids: Mapped[Optional[list[str]]] = mapped_column(JSON, nullable=True) timestamp = mapped_column(DateTime, nullable=False) @@ -806,6 +808,7 @@ def __init__(self, *, entry: AttackResult): self.outcome = entry.outcome.value self.outcome_reason = entry.outcome_reason self.attack_metadata = self.filter_json_serializable_metadata(entry.metadata) + self.labels = entry.labels or {} # Persist conversation references by type self.pruned_conversation_ids = [ @@ -917,6 +920,7 @@ def get_attack_result(self) -> AttackResult: outcome_reason=self.outcome_reason, related_conversations=related_conversations, metadata=self.attack_metadata or {}, + labels=self.labels or {}, ) diff --git a/pyrit/memory/sqlite_memory.py b/pyrit/memory/sqlite_memory.py index bd376d67cd..3c4bb287fd 100644 --- a/pyrit/memory/sqlite_memory.py +++ b/pyrit/memory/sqlite_memory.py @@ -613,26 +613,17 @@ def _get_attack_result_harm_category_condition(self, *, targeted_harm_categories def _get_attack_result_label_condition(self, *, labels: dict[str, str]) -> Any: """ - SQLite implementation for filtering AttackResults by labels. + SQLite implementation for filtering AttackResults by labels + stored directly on the AttackResultEntry. Uses json_extract() function specific to SQLite. Returns: - Any: A SQLAlchemy subquery for filtering by labels. + Any: A SQLAlchemy condition for filtering by labels. """ - from sqlalchemy import and_, exists, func - - from pyrit.memory.memory_models import AttackResultEntry, PromptMemoryEntry - - labels_subquery = exists().where( - and_( - PromptMemoryEntry.conversation_id == AttackResultEntry.conversation_id, - PromptMemoryEntry.labels.isnot(None), - and_( - *[func.json_extract(PromptMemoryEntry.labels, f"$.{key}") == value for key, value in labels.items()] - ), - ) + return and_( + AttackResultEntry.labels.isnot(None), + *[func.json_extract(AttackResultEntry.labels, f"$.{key}") == value for key, value in labels.items()], ) - return labels_subquery # noqa: RET504 def get_unique_attack_class_names(self) -> list[str]: """ diff --git a/pyrit/models/attack_result.py b/pyrit/models/attack_result.py index a385ac36e7..5cbdf3c93e 100644 --- a/pyrit/models/attack_result.py +++ b/pyrit/models/attack_result.py @@ -87,6 +87,9 @@ class AttackResult(StrategyResult): # Arbitrary metadata metadata: dict[str, Any] = field(default_factory=dict) + # labels associated with this attack result + labels: dict[str, str] = field(default_factory=dict) + @property def attack_identifier(self) -> Optional[ComponentIdentifier]: """ diff --git a/tests/unit/backend/test_attack_service.py b/tests/unit/backend/test_attack_service.py index 11da01effe..83dd47bbb2 100644 --- a/tests/unit/backend/test_attack_service.py +++ b/tests/unit/backend/test_attack_service.py @@ -95,6 +95,9 @@ def make_attack_result( "created_at": created.isoformat(), "updated_at": updated.isoformat(), }, + labels={ + "test_ar_label": "test_ar_value" + }, ) @@ -320,7 +323,7 @@ async def test_list_attacks_includes_labels_in_summary(self, attack_service, moc result = await attack_service.list_attacks_async() assert len(result.items) == 1 - assert result.items[0].labels == {"env": "prod", "team": "red"} + assert result.items[0].labels == {"env": "prod", "team": "red", "test_ar_label": "test_ar_value"} @pytest.mark.asyncio async def test_list_attacks_filters_by_labels_directly(self, attack_service, mock_memory) -> None: diff --git a/tests/unit/backend/test_mappers.py b/tests/unit/backend/test_mappers.py index 0f483b3f10..f7c2495c71 100644 --- a/tests/unit/backend/test_mappers.py +++ b/tests/unit/backend/test_mappers.py @@ -81,6 +81,9 @@ def _make_attack_result( "created_at": now.isoformat(), "updated_at": now.isoformat(), }, + labels={ + "test_ar_label": "test_ar_value" + }, ) @@ -175,7 +178,7 @@ def test_labels_are_mapped(self) -> None: summary = attack_result_to_summary(ar, stats=stats) - assert summary.labels == {"env": "prod", "team": "red"} + assert summary.labels == {"env": "prod", "team": "red", "test_ar_label": "test_ar_value"} def test_labels_passed_through_without_normalization(self) -> None: """Test that labels are passed through as-is (DB stores canonical keys after migration).""" @@ -187,7 +190,21 @@ def test_labels_passed_through_without_normalization(self) -> None: summary = attack_result_to_summary(ar, stats=stats) - assert summary.labels == {"operator": "alice", "operation": "op_red", "env": "prod"} + assert summary.labels == { + "operator": "alice", "operation": "op_red", "env": "prod", "test_ar_label": "test_ar_value" + } + + def test_conversation_labels_take_precedence_on_collision(self) -> None: + """Test that conversation-level labels override attack-result labels on key collision.""" + ar = _make_attack_result() + stats = ConversationStats( + message_count=1, + labels={"test_ar_label": "conversation_wins"}, + ) + + summary = attack_result_to_summary(ar, stats=stats) + + assert summary.labels["test_ar_label"] == "conversation_wins" def test_outcome_success(self) -> None: """Test that success outcome is mapped.""" @@ -249,6 +266,7 @@ def test_converters_extracted_from_identifier(self) -> None: ), outcome=AttackOutcome.UNDETERMINED, metadata={"created_at": now.isoformat(), "updated_at": now.isoformat()}, + labels={"test_label": "test_value"}, ) summary = attack_result_to_summary(ar, stats=ConversationStats(message_count=0)) diff --git a/tests/unit/memory/memory_interface/test_interface_attack_results.py b/tests/unit/memory/memory_interface/test_interface_attack_results.py index 2e30ba368a..d998e7c02a 100644 --- a/tests/unit/memory/memory_interface/test_interface_attack_results.py +++ b/tests/unit/memory/memory_interface/test_interface_attack_results.py @@ -36,12 +36,18 @@ def create_message_piece(conversation_id: str, prompt_num: int, targeted_harm_ca ) -def create_attack_result(conversation_id: str, objective_num: int, outcome: AttackOutcome = AttackOutcome.SUCCESS): +def create_attack_result( + conversation_id: str, + objective_num: int, + outcome: AttackOutcome = AttackOutcome.SUCCESS, + labels: dict[str, str] | None = None, +): """Helper function to create AttackResult.""" return AttackResult( conversation_id=conversation_id, objective=f"Objective {objective_num}", outcome=outcome, + labels=labels or {}, ) @@ -780,17 +786,14 @@ def test_get_attack_results_by_harm_category_multiple(sqlite_instance: MemoryInt def test_get_attack_results_by_labels_single(sqlite_instance: MemoryInterface): """Test filtering attack results by single label.""" - # Create message pieces with labels - message_piece1 = create_message_piece("conv_1", 1, labels={"operation": "test_op", "operator": "roakey"}) - message_piece2 = create_message_piece("conv_2", 2, labels={"operation": "test_op"}) - message_piece3 = create_message_piece("conv_3", 3, labels={"operation": "other_op", "operator": "roakey"}) - - sqlite_instance.add_message_pieces_to_memory(message_pieces=[message_piece1, message_piece2, message_piece3]) - - # Create attack results - attack_result1 = create_attack_result("conv_1", 1, AttackOutcome.SUCCESS) - attack_result2 = create_attack_result("conv_2", 2, AttackOutcome.FAILURE) - attack_result3 = create_attack_result("conv_3", 3, AttackOutcome.SUCCESS) + # Create attack results with labels + attack_result1 = create_attack_result( + "conv_1", 1, AttackOutcome.SUCCESS, labels={"operation": "test_op", "operator": "roakey"} + ) + attack_result2 = create_attack_result("conv_2", 2, AttackOutcome.FAILURE, labels={"operation": "test_op"}) + attack_result3 = create_attack_result( + "conv_3", 3, AttackOutcome.SUCCESS, labels={"operation": "other_op", "operator": "roakey"} + ) sqlite_instance.add_attack_results_to_memory(attack_results=[attack_result1, attack_result2, attack_result3]) @@ -808,22 +811,20 @@ def test_get_attack_results_by_labels_single(sqlite_instance: MemoryInterface): def test_get_attack_results_by_labels_multiple(sqlite_instance: MemoryInterface): """Test filtering attack results by multiple labels (AND logic).""" - # Create message pieces with multiple labels using helper function - message_piece1 = create_message_piece( - "conv_1", 1, labels={"operation": "test_op", "operator": "roakey", "phase": "initial"} - ) - message_piece2 = create_message_piece( - "conv_2", 2, labels={"operation": "test_op", "operator": "roakey", "phase": "final"} - ) - message_piece3 = create_message_piece("conv_3", 3, labels={"operation": "test_op", "phase": "initial"}) - - sqlite_instance.add_message_pieces_to_memory(message_pieces=[message_piece1, message_piece2, message_piece3]) - - # Create attack results + # Create attack results with multiple labels attack_results = [ - create_attack_result("conv_1", 1, AttackOutcome.SUCCESS), - create_attack_result("conv_2", 2, AttackOutcome.SUCCESS), - create_attack_result("conv_3", 3, AttackOutcome.FAILURE), + create_attack_result( + "conv_1", 1, AttackOutcome.SUCCESS, + labels={"operation": "test_op", "operator": "roakey", "phase": "initial"}, + ), + create_attack_result( + "conv_2", 2, AttackOutcome.SUCCESS, + labels={"operation": "test_op", "operator": "roakey", "phase": "final"}, + ), + create_attack_result( + "conv_3", 3, AttackOutcome.FAILURE, + labels={"operation": "test_op", "phase": "initial"}, + ), ] sqlite_instance.add_attack_results_to_memory(attack_results=attack_results) @@ -842,30 +843,24 @@ def test_get_attack_results_by_labels_multiple(sqlite_instance: MemoryInterface) def test_get_attack_results_by_harm_category_and_labels(sqlite_instance: MemoryInterface): """Test filtering attack results by both harm categories and labels.""" - # Create message pieces with both harm categories and labels using helper function - message_piece1 = create_message_piece( - "conv_1", - 1, - targeted_harm_categories=["violence", "illegal"], - labels={"operation": "test_op", "operator": "roakey"}, - ) - message_piece2 = create_message_piece( - "conv_2", 2, targeted_harm_categories=["violence"], labels={"operation": "test_op", "operator": "roakey"} - ) - message_piece3 = create_message_piece( - "conv_3", - 3, - targeted_harm_categories=["violence", "illegal"], - labels={"operation": "other_op", "operator": "bob"}, - ) + # Create message pieces with harm categories (harm categories still live on PromptMemoryEntry) + message_piece1 = create_message_piece("conv_1", 1, targeted_harm_categories=["violence", "illegal"]) + message_piece2 = create_message_piece("conv_2", 2, targeted_harm_categories=["violence"]) + message_piece3 = create_message_piece("conv_3", 3, targeted_harm_categories=["violence", "illegal"]) sqlite_instance.add_message_pieces_to_memory(message_pieces=[message_piece1, message_piece2, message_piece3]) - # Create attack results + # Create attack results with labels attack_results = [ - create_attack_result("conv_1", 1, AttackOutcome.SUCCESS), - create_attack_result("conv_2", 2, AttackOutcome.SUCCESS), - create_attack_result("conv_3", 3, AttackOutcome.FAILURE), + create_attack_result( + "conv_1", 1, AttackOutcome.SUCCESS, labels={"operation": "test_op", "operator": "roakey"} + ), + create_attack_result( + "conv_2", 2, AttackOutcome.SUCCESS, labels={"operation": "test_op", "operator": "roakey"} + ), + create_attack_result( + "conv_3", 3, AttackOutcome.FAILURE, labels={"operation": "other_op", "operator": "bob"} + ), ] sqlite_instance.add_attack_results_to_memory(attack_results=attack_results) @@ -904,11 +899,8 @@ def test_get_attack_results_harm_category_no_matches(sqlite_instance: MemoryInte def test_get_attack_results_labels_no_matches(sqlite_instance: MemoryInterface): """Test filtering by labels that don't exist.""" - # Create attack result without the labels we'll search for - message_piece = create_message_piece("conv_1", 1, labels={"operation": "test_op"}) - sqlite_instance.add_message_pieces_to_memory(message_pieces=[message_piece]) - - attack_result = create_attack_result("conv_1", 1, AttackOutcome.SUCCESS) + # Create attack result with labels that don't match the search + attack_result = create_attack_result("conv_1", 1, AttackOutcome.SUCCESS, labels={"operation": "test_op"}) sqlite_instance.add_attack_results_to_memory(attack_results=[attack_result]) # Search for non-existent labels @@ -920,11 +912,6 @@ def test_get_attack_results_labels_query_on_empty_labels(sqlite_instance: Memory """Test querying for labels when records have no labels at all""" # Create attack results with NO labels - message_piece1 = create_message_piece("conv_1", 1) - message_piece2 = create_message_piece("conv_2", 1) - - sqlite_instance.add_message_pieces_to_memory(message_pieces=[message_piece1, message_piece2]) - attack_result1 = create_attack_result("conv_1", 1, AttackOutcome.SUCCESS) attack_result2 = create_attack_result("conv_2", 2, AttackOutcome.FAILURE) @@ -944,16 +931,14 @@ def test_get_attack_results_labels_key_exists_value_mismatch(sqlite_instance: Me """Test querying for labels where the key exists but the value doesn't match.""" # Create attack results with specific label values - message_piece1 = create_message_piece("conv_1", 1, labels={"operation": "op_exists", "researcher": "roakey"}) - message_piece2 = create_message_piece("conv_2", 1, labels={"operation": "another_op", "researcher": "roakey"}) - message_piece3 = create_message_piece("conv_3", 1, labels={"operation": "test_op"}) - - sqlite_instance.add_message_pieces_to_memory(message_pieces=[message_piece1, message_piece2, message_piece3]) - attack_results = [ - create_attack_result("conv_1", 1, AttackOutcome.SUCCESS), - create_attack_result("conv_2", 2, AttackOutcome.SUCCESS), - create_attack_result("conv_3", 3, AttackOutcome.FAILURE), + create_attack_result( + "conv_1", 1, AttackOutcome.SUCCESS, labels={"operation": "op_exists", "researcher": "roakey"} + ), + create_attack_result( + "conv_2", 2, AttackOutcome.SUCCESS, labels={"operation": "another_op", "researcher": "roakey"} + ), + create_attack_result("conv_3", 3, AttackOutcome.FAILURE, labels={"operation": "test_op"}), ] sqlite_instance.add_attack_results_to_memory(attack_results=attack_results) diff --git a/tests/unit/scenario/test_scenario.py b/tests/unit/scenario/test_scenario.py index 7f02982015..480adf0543 100644 --- a/tests/unit/scenario/test_scenario.py +++ b/tests/unit/scenario/test_scenario.py @@ -85,6 +85,7 @@ def sample_attack_results(): objective=f"objective{i}", outcome=AttackOutcome.SUCCESS, executed_turns=1, + labels={"test_label": f"value{i}"}, ) for i in range(5) ] From 4ec494a67867d6bccf1da5eb4d21ccc98d5e92db Mon Sep 17 00:00:00 2001 From: Behnam Ousat Date: Thu, 16 Apr 2026 09:37:10 -0700 Subject: [PATCH 02/23] format --- pyrit/memory/azure_sql_memory.py | 8 ++----- tests/unit/backend/test_attack_service.py | 4 +--- tests/unit/backend/test_mappers.py | 9 +++---- .../test_interface_attack_results.py | 24 +++++++++---------- 4 files changed, 20 insertions(+), 25 deletions(-) diff --git a/pyrit/memory/azure_sql_memory.py b/pyrit/memory/azure_sql_memory.py index 15586f7152..64ff020416 100644 --- a/pyrit/memory/azure_sql_memory.py +++ b/pyrit/memory/azure_sql_memory.py @@ -463,9 +463,7 @@ def _get_attack_result_label_condition(self, *, labels: dict[str, str]) -> Any: for i, (key, value) in enumerate(labels.items()): path_param = f"label_path_{i}" value_param = f"label_val_{i}" - label_conditions.append( - f'JSON_VALUE("AttackResultEntries".labels, :{path_param}) = :{value_param}' - ) + label_conditions.append(f'JSON_VALUE("AttackResultEntries".labels, :{path_param}) = :{value_param}') bindparams_dict[path_param] = f"$.{key}" bindparams_dict[value_param] = str(value) @@ -473,9 +471,7 @@ def _get_attack_result_label_condition(self, *, labels: dict[str, str]) -> Any: return and_( AttackResultEntry.labels.isnot(None), - text( - f'ISJSON("AttackResultEntries".labels) = 1 AND {combined_conditions}' - ).bindparams(**bindparams_dict), + text(f'ISJSON("AttackResultEntries".labels) = 1 AND {combined_conditions}').bindparams(**bindparams_dict), ) def get_unique_attack_class_names(self) -> list[str]: diff --git a/tests/unit/backend/test_attack_service.py b/tests/unit/backend/test_attack_service.py index e62ac73d63..c494e29944 100644 --- a/tests/unit/backend/test_attack_service.py +++ b/tests/unit/backend/test_attack_service.py @@ -95,9 +95,7 @@ def make_attack_result( "created_at": created.isoformat(), "updated_at": updated.isoformat(), }, - labels={ - "test_ar_label": "test_ar_value" - }, + labels={"test_ar_label": "test_ar_value"}, ) diff --git a/tests/unit/backend/test_mappers.py b/tests/unit/backend/test_mappers.py index f7c2495c71..ad6ef1a380 100644 --- a/tests/unit/backend/test_mappers.py +++ b/tests/unit/backend/test_mappers.py @@ -81,9 +81,7 @@ def _make_attack_result( "created_at": now.isoformat(), "updated_at": now.isoformat(), }, - labels={ - "test_ar_label": "test_ar_value" - }, + labels={"test_ar_label": "test_ar_value"}, ) @@ -191,7 +189,10 @@ def test_labels_passed_through_without_normalization(self) -> None: summary = attack_result_to_summary(ar, stats=stats) assert summary.labels == { - "operator": "alice", "operation": "op_red", "env": "prod", "test_ar_label": "test_ar_value" + "operator": "alice", + "operation": "op_red", + "env": "prod", + "test_ar_label": "test_ar_value", } def test_conversation_labels_take_precedence_on_collision(self) -> None: diff --git a/tests/unit/memory/memory_interface/test_interface_attack_results.py b/tests/unit/memory/memory_interface/test_interface_attack_results.py index d998e7c02a..425b1f1360 100644 --- a/tests/unit/memory/memory_interface/test_interface_attack_results.py +++ b/tests/unit/memory/memory_interface/test_interface_attack_results.py @@ -814,15 +814,21 @@ def test_get_attack_results_by_labels_multiple(sqlite_instance: MemoryInterface) # Create attack results with multiple labels attack_results = [ create_attack_result( - "conv_1", 1, AttackOutcome.SUCCESS, + "conv_1", + 1, + AttackOutcome.SUCCESS, labels={"operation": "test_op", "operator": "roakey", "phase": "initial"}, ), create_attack_result( - "conv_2", 2, AttackOutcome.SUCCESS, + "conv_2", + 2, + AttackOutcome.SUCCESS, labels={"operation": "test_op", "operator": "roakey", "phase": "final"}, ), create_attack_result( - "conv_3", 3, AttackOutcome.FAILURE, + "conv_3", + 3, + AttackOutcome.FAILURE, labels={"operation": "test_op", "phase": "initial"}, ), ] @@ -852,15 +858,9 @@ def test_get_attack_results_by_harm_category_and_labels(sqlite_instance: MemoryI # Create attack results with labels attack_results = [ - create_attack_result( - "conv_1", 1, AttackOutcome.SUCCESS, labels={"operation": "test_op", "operator": "roakey"} - ), - create_attack_result( - "conv_2", 2, AttackOutcome.SUCCESS, labels={"operation": "test_op", "operator": "roakey"} - ), - create_attack_result( - "conv_3", 3, AttackOutcome.FAILURE, labels={"operation": "other_op", "operator": "bob"} - ), + create_attack_result("conv_1", 1, AttackOutcome.SUCCESS, labels={"operation": "test_op", "operator": "roakey"}), + create_attack_result("conv_2", 2, AttackOutcome.SUCCESS, labels={"operation": "test_op", "operator": "roakey"}), + create_attack_result("conv_3", 3, AttackOutcome.FAILURE, labels={"operation": "other_op", "operator": "bob"}), ] sqlite_instance.add_attack_results_to_memory(attack_results=attack_results) From 5e04ca7e7f515ca8dca695d93b03f8f4a481aef1 Mon Sep 17 00:00:00 2001 From: Behnam Ousat Date: Thu, 16 Apr 2026 09:44:45 -0700 Subject: [PATCH 03/23] label queries OR'ed with old way --- pyrit/memory/azure_sql_memory.py | 56 +++++++++++++++++++++++--------- pyrit/memory/memory_interface.py | 2 +- pyrit/memory/sqlite_memory.py | 21 +++++++++--- 3 files changed, 58 insertions(+), 21 deletions(-) diff --git a/pyrit/memory/azure_sql_memory.py b/pyrit/memory/azure_sql_memory.py index 64ff020416..3023bf8240 100644 --- a/pyrit/memory/azure_sql_memory.py +++ b/pyrit/memory/azure_sql_memory.py @@ -9,7 +9,7 @@ from datetime import datetime, timedelta, timezone from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union -from sqlalchemy import and_, create_engine, event, exists, text +from sqlalchemy import and_, create_engine, event, exists, or_, text from sqlalchemy.engine.base import Engine from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.orm import InstrumentedAttribute, joinedload, sessionmaker @@ -446,8 +446,10 @@ def _get_attack_result_harm_category_condition(self, *, targeted_harm_categories def _get_attack_result_label_condition(self, *, labels: dict[str, str]) -> Any: """ - Get the SQL Azure implementation for filtering AttackResults by labels - stored directly on the AttackResultEntry. + Get the SQL Azure implementation for filtering AttackResults by labels. + + Matches if the labels are found on the AttackResultEntry directly + OR on an associated PromptMemoryEntry (via conversation_id). Uses JSON_VALUE() function specific to SQL Azure with parameterized queries. @@ -457,23 +459,47 @@ def _get_attack_result_label_condition(self, *, labels: dict[str, str]) -> Any: Returns: Any: SQLAlchemy condition with bound parameters. """ - # Build JSON conditions for all labels with parameterized queries - label_conditions = [] - bindparams_dict = {} + # --- Direct match on AttackResultEntry.labels --- + ar_label_conditions = [] + ar_bindparams: dict[str, str] = {} for i, (key, value) in enumerate(labels.items()): - path_param = f"label_path_{i}" - value_param = f"label_val_{i}" - label_conditions.append(f'JSON_VALUE("AttackResultEntries".labels, :{path_param}) = :{value_param}') - bindparams_dict[path_param] = f"$.{key}" - bindparams_dict[value_param] = str(value) - - combined_conditions = " AND ".join(label_conditions) + path_param = f"ar_label_path_{i}" + value_param = f"ar_label_val_{i}" + ar_label_conditions.append( + f'JSON_VALUE("AttackResultEntries".labels, :{path_param}) = :{value_param}' + ) + ar_bindparams[path_param] = f"$.{key}" + ar_bindparams[value_param] = str(value) - return and_( + ar_combined = " AND ".join(ar_label_conditions) + direct_condition = and_( AttackResultEntry.labels.isnot(None), - text(f'ISJSON("AttackResultEntries".labels) = 1 AND {combined_conditions}').bindparams(**bindparams_dict), + text( + f'ISJSON("AttackResultEntries".labels) = 1 AND {ar_combined}' + ).bindparams(**ar_bindparams), ) + # --- Conversation-level match on PromptMemoryEntry.labels --- + pme_label_conditions = [] + pme_bindparams: dict[str, str] = {} + for i, (key, value) in enumerate(labels.items()): + param_name = f"pme_label_{key}" + pme_label_conditions.append(f"JSON_VALUE(labels, '$.{key}') = :{param_name}") + pme_bindparams[param_name] = str(value) + + pme_combined = " AND ".join(pme_label_conditions) + conversation_condition = exists().where( + and_( + PromptMemoryEntry.conversation_id == AttackResultEntry.conversation_id, + PromptMemoryEntry.labels.isnot(None), + text( + f"ISJSON(labels) = 1 AND {pme_combined}" + ).bindparams(**pme_bindparams), + ) + ) + + return or_(direct_condition, conversation_condition) + def get_unique_attack_class_names(self) -> list[str]: """ Azure SQL implementation: extract unique class_name values from diff --git a/pyrit/memory/memory_interface.py b/pyrit/memory/memory_interface.py index 5cd9797cdc..c90c978777 100644 --- a/pyrit/memory/memory_interface.py +++ b/pyrit/memory/memory_interface.py @@ -419,7 +419,7 @@ def _get_attack_result_harm_category_condition(self, *, targeted_harm_categories def _get_attack_result_label_condition(self, *, labels: dict[str, str]) -> Any: """ Return a database-specific condition for filtering AttackResults by labels - stored directly on the AttackResultEntry. + stored directly on the AttackResultEntry OR on an associated PromptMemoryEntry (via conversation_id). Args: labels: Dictionary of labels that must ALL be present. diff --git a/pyrit/memory/sqlite_memory.py b/pyrit/memory/sqlite_memory.py index 3c4bb287fd..ac8b2319eb 100644 --- a/pyrit/memory/sqlite_memory.py +++ b/pyrit/memory/sqlite_memory.py @@ -10,7 +10,7 @@ from pathlib import Path from typing import Any, Optional, TypeVar, Union, cast -from sqlalchemy import and_, create_engine, func, or_, text +from sqlalchemy import and_, create_engine, exists, func, or_, text from sqlalchemy.engine.base import Engine from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.orm import InstrumentedAttribute, joinedload, sessionmaker @@ -613,18 +613,29 @@ def _get_attack_result_harm_category_condition(self, *, targeted_harm_categories def _get_attack_result_label_condition(self, *, labels: dict[str, str]) -> Any: """ - SQLite implementation for filtering AttackResults by labels - stored directly on the AttackResultEntry. - Uses json_extract() function specific to SQLite. + SQLite implementation for filtering AttackResults by labels. + + Matches if the labels are found on the AttackResultEntry directly + OR on an associated PromptMemoryEntry (via conversation_id). Returns: Any: A SQLAlchemy condition for filtering by labels. """ - return and_( + direct_condition = and_( AttackResultEntry.labels.isnot(None), *[func.json_extract(AttackResultEntry.labels, f"$.{key}") == value for key, value in labels.items()], ) + conversation_condition = exists().where( + and_( + PromptMemoryEntry.conversation_id == AttackResultEntry.conversation_id, + PromptMemoryEntry.labels.isnot(None), + *[func.json_extract(PromptMemoryEntry.labels, f"$.{key}") == value for key, value in labels.items()], + ) + ) + + return or_(direct_condition, conversation_condition) + def get_unique_attack_class_names(self) -> list[str]: """ SQLite implementation: extract unique class_name values from From fa5a0d76bab4c582a5e4f86ae8f727d08b3400fd Mon Sep 17 00:00:00 2001 From: Behnam Ousat Date: Thu, 16 Apr 2026 09:57:02 -0700 Subject: [PATCH 04/23] review feedback --- pyrit/memory/memory_interface.py | 4 ++-- .../test_interface_attack_results.py | 21 +++++++++++++++++++ 2 files changed, 23 insertions(+), 2 deletions(-) diff --git a/pyrit/memory/memory_interface.py b/pyrit/memory/memory_interface.py index c90c978777..6e7e1a1b15 100644 --- a/pyrit/memory/memory_interface.py +++ b/pyrit/memory/memory_interface.py @@ -1481,8 +1481,8 @@ def get_attack_results( By providing a list, this means ALL categories in the list must be present. Defaults to None. labels (Optional[dict[str, str]], optional): A dictionary of labels to filter results by. - These labels are stored directly on the AttackResult. All specified key-value pairs - must be present (AND logic). Defaults to None. + These labels are stored on the AttackResult or associated PromptMemoryEntry (via conversation_id) + . All specified key-value pairs must be present (AND logic). Defaults to None. identifier_filters (Optional[Sequence[IdentifierFilter]], optional): A sequence of IdentifierFilter objects that allows filtering by various attack identifier JSON properties. Defaults to None. diff --git a/tests/unit/memory/memory_interface/test_interface_attack_results.py b/tests/unit/memory/memory_interface/test_interface_attack_results.py index 425b1f1360..a4d87fd6d3 100644 --- a/tests/unit/memory/memory_interface/test_interface_attack_results.py +++ b/tests/unit/memory/memory_interface/test_interface_attack_results.py @@ -981,6 +981,27 @@ def test_get_attack_results_labels_key_exists_value_mismatch(sqlite_instance: Me assert results[0].conversation_id == "conv_1" +def test_get_attack_results_by_labels_falls_back_to_conversation_labels(sqlite_instance: MemoryInterface): + """Test that label filtering matches via PromptMemoryEntry when AttackResult has no labels.""" + + # Attack result with NO labels + attack_result = create_attack_result("conv_1", 1, AttackOutcome.SUCCESS, labels={}) + sqlite_instance.add_attack_results_to_memory(attack_results=[attack_result]) + + # Conversation message carries the labels instead + message_piece = create_message_piece("conv_1", 1, labels={"operation": "legacy_op"}) + sqlite_instance.add_message_pieces_to_memory(message_pieces=[message_piece]) + + # Should still find the attack result via the PME fallback path + results = sqlite_instance.get_attack_results(labels={"operation": "legacy_op"}) + assert len(results) == 1 + assert results[0].conversation_id == "conv_1" + + # Non-matching label should return nothing + results = sqlite_instance.get_attack_results(labels={"operation": "missing"}) + assert len(results) == 0 + + # --------------------------------------------------------------------------- # get_unique_attack_labels tests # --------------------------------------------------------------------------- From aafefc1ad8c222f9cb451a497e3f11fe3bb2381d Mon Sep 17 00:00:00 2001 From: Behnam Ousat Date: Thu, 16 Apr 2026 10:02:58 -0700 Subject: [PATCH 05/23] format --- pyrit/memory/azure_sql_memory.py | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/pyrit/memory/azure_sql_memory.py b/pyrit/memory/azure_sql_memory.py index 3023bf8240..e7d3097615 100644 --- a/pyrit/memory/azure_sql_memory.py +++ b/pyrit/memory/azure_sql_memory.py @@ -465,24 +465,20 @@ def _get_attack_result_label_condition(self, *, labels: dict[str, str]) -> Any: for i, (key, value) in enumerate(labels.items()): path_param = f"ar_label_path_{i}" value_param = f"ar_label_val_{i}" - ar_label_conditions.append( - f'JSON_VALUE("AttackResultEntries".labels, :{path_param}) = :{value_param}' - ) + ar_label_conditions.append(f'JSON_VALUE("AttackResultEntries".labels, :{path_param}) = :{value_param}') ar_bindparams[path_param] = f"$.{key}" ar_bindparams[value_param] = str(value) ar_combined = " AND ".join(ar_label_conditions) direct_condition = and_( AttackResultEntry.labels.isnot(None), - text( - f'ISJSON("AttackResultEntries".labels) = 1 AND {ar_combined}' - ).bindparams(**ar_bindparams), + text(f'ISJSON("AttackResultEntries".labels) = 1 AND {ar_combined}').bindparams(**ar_bindparams), ) # --- Conversation-level match on PromptMemoryEntry.labels --- pme_label_conditions = [] pme_bindparams: dict[str, str] = {} - for i, (key, value) in enumerate(labels.items()): + for key, value in labels.items(): param_name = f"pme_label_{key}" pme_label_conditions.append(f"JSON_VALUE(labels, '$.{key}') = :{param_name}") pme_bindparams[param_name] = str(value) @@ -492,9 +488,7 @@ def _get_attack_result_label_condition(self, *, labels: dict[str, str]) -> Any: and_( PromptMemoryEntry.conversation_id == AttackResultEntry.conversation_id, PromptMemoryEntry.labels.isnot(None), - text( - f"ISJSON(labels) = 1 AND {pme_combined}" - ).bindparams(**pme_bindparams), + text(f"ISJSON(labels) = 1 AND {pme_combined}").bindparams(**pme_bindparams), ) ) From 7d17c4427024c16ef38a333f872cae76f77f8e72 Mon Sep 17 00:00:00 2001 From: Behnam Ousat Date: Fri, 17 Apr 2026 13:49:55 -0700 Subject: [PATCH 06/23] Implement DB schema tracking with alembic --- .pre-commit-config.yaml | 5 + build_scripts/memory_migrations.py | 148 +++++++++++++ pyproject.toml | 1 + pyrit/memory/alembic.ini | 37 ++++ pyrit/memory/alembic/env.py | 76 +++++++ pyrit/memory/alembic/script.py.mako | 32 +++ .../versions/e373726d391b_initial_schema.py | 197 ++++++++++++++++++ pyrit/memory/azure_sql_memory.py | 20 +- pyrit/memory/memory_interface.py | 14 ++ pyrit/memory/migration.py | 122 +++++++++++ pyrit/memory/sqlite_memory.py | 18 +- tests/unit/memory/test_azure_sql_memory.py | 23 ++ tests/unit/memory/test_sqlite_memory.py | 148 ++++++++++++- uv.lock | 2 + 14 files changed, 807 insertions(+), 36 deletions(-) create mode 100644 build_scripts/memory_migrations.py create mode 100644 pyrit/memory/alembic.ini create mode 100644 pyrit/memory/alembic/env.py create mode 100644 pyrit/memory/alembic/script.py.mako create mode 100644 pyrit/memory/alembic/versions/e373726d391b_initial_schema.py create mode 100644 pyrit/memory/migration.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 113f7cb486..af5ac76bb8 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -25,6 +25,11 @@ repos: files: ^(doc/.*\.(py|ipynb|md)|doc/myst\.yml)$ pass_filenames: false additional_dependencies: ['pyyaml'] + - id: memory-migrations-check + name: Check Memory Migrations + entry: python ./build_scripts/memory_migrations.py check + language: system + pass_filenames: false - repo: https://github.com/pre-commit/pre-commit-hooks rev: v5.0.0 diff --git a/build_scripts/memory_migrations.py b/build_scripts/memory_migrations.py new file mode 100644 index 0000000000..6e0e7a3c0f --- /dev/null +++ b/build_scripts/memory_migrations.py @@ -0,0 +1,148 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import argparse +import sys +import tempfile +from pathlib import Path + +from alembic import command +from alembic.config import Config +from alembic.util.exc import AutogenerateDiffsDetected + +_REPO_ROOT = Path(__file__).resolve().parent.parent +_ALEMBIC_INI = _REPO_ROOT / "pyrit" / "memory" / "alembic.ini" + +# ANSI color codes +_RED = "\033[91m" +_RESET = "\033[0m" + + +def _print_error(message: str) -> None: + """Print an error message in red to stderr.""" + print(f"{_RED}{message}{_RESET}", file=sys.stderr) + + +def _make_config(*, db_url: str) -> Config: + """ + Build an Alembic Config pointed at a specific database URL. + + Args: + db_url (str): SQLAlchemy database URL to use. + + Returns: + Config: Configured Alembic config object. + """ + cfg = Config(str(_ALEMBIC_INI)) + cfg.set_main_option("sqlalchemy.url", db_url) + return cfg + + +def _cmd_generate(*, message: str, force: bool = False) -> None: + """ + Generate a new Alembic revision from model changes. + + A fresh temporary SQLite database is created, upgraded to the current head, + and then checked against the current models. If the check passes (schema matches), + no migration is needed and the command fails. If the check fails (schema doesn't match), + a migration is generated. + + Args: + message (str): Human-readable migration message. + force (bool): If True, generate a migration even if schema matches models. + Useful for manual migrations. + + Raises: + SystemExit: If schema matches models and force=False. + """ + with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp: + db_url = f"sqlite:///{tmp.name}" + + try: + config = _make_config(db_url=db_url) + command.upgrade(config, "head") + + # Check if schema matches current models + try: + command.check(config) + # If check passes, schema already matches models + if not force: + _print_error("No schema changes detected. Use --force to generate an empty migration.") + raise SystemExit(1) + print("Generating migration even though schema matches models (--force used).") + except AutogenerateDiffsDetected: + # Check failed, meaning schema doesn't match models - proceed with generation + pass + + command.revision(config, autogenerate=True, message=message) + print("Migration file generated. Review it carefully before committing.") + finally: + Path(tmp.name).unlink(missing_ok=True) + + +def _cmd_check() -> None: + """ + Verify that all migration scripts apply cleanly and schema matches models. + + Creates a clean temporary database, applies all migrations to head, + and validates that the resulting schema matches the current memory models. + """ + with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp: + db_url = f"sqlite:///{tmp.name}" + + try: + command.upgrade(_make_config(db_url=db_url), "head") + command.check(_make_config(db_url=db_url)) + except Exception as e: + _print_error( + f"Migration check failed. Run the script in 'generate' mode to generate a new migration. Error: {e}" + ) + raise SystemExit(1) from e + finally: + Path(tmp.name).unlink(missing_ok=True) + + +def _build_parser() -> argparse.ArgumentParser: + """ + Build the CLI argument parser. + + Returns: + argparse.ArgumentParser: Configured parser. + """ + parser = argparse.ArgumentParser( + description="PyRIT memory migration tool. Generate and validate migrations against clean databases." + ) + sub = parser.add_subparsers(dest="command", required=True) + + gen = sub.add_parser("generate", help="Generate a new migration from model changes.") + gen.add_argument("-m", "--message", required=True, help="Migration message.") + gen.add_argument("--force", action="store_true", help="Generate migration even if no changes detected.") + + sub.add_parser("check", help="Verify all migrations apply cleanly to a fresh database.") + + return parser + + +def main() -> int: + """ + Dispatch the selected migration command. + + Returns: + int: Process exit code. + """ + if not _ALEMBIC_INI.exists(): + _print_error(f"Alembic config not found at {_ALEMBIC_INI}") + return 1 + + args = _build_parser().parse_args() + + if args.command == "generate": + _cmd_generate(message=args.message, force=args.force) + elif args.command == "check": + _cmd_check() + + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/pyproject.toml b/pyproject.toml index abe9199092..565d938aba 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,6 +27,7 @@ classifiers = [ requires-python = ">=3.10, <3.14" dependencies = [ "aiofiles>=24,<25", + "alembic>=1.16.0", "appdirs>=1.4.0", "art>=6.5.0", "av>=14.0.0", diff --git a/pyrit/memory/alembic.ini b/pyrit/memory/alembic.ini new file mode 100644 index 0000000000..6af71fffe2 --- /dev/null +++ b/pyrit/memory/alembic.ini @@ -0,0 +1,37 @@ +# A generic, package-local Alembic config for PyRIT memory schema migrations. + +[alembic] +script_location = pyrit/memory/alembic +prepend_sys_path = . + +[loggers] +keys = root,sqlalchemy,alembic + +[handlers] +keys = console + +[formatters] +keys = generic + +[logger_root] +level = WARN +handlers = console + +[logger_sqlalchemy] +level = WARN +handlers = +qualname = sqlalchemy.engine + +[logger_alembic] +level = INFO +handlers = +qualname = alembic + +[handler_console] +class = StreamHandler +args = (sys.stderr,) +level = NOTSET +formatter = generic + +[formatter_generic] +format = %(levelname)-5.5s [%(name)s] %(message)s diff --git a/pyrit/memory/alembic/env.py b/pyrit/memory/alembic/env.py new file mode 100644 index 0000000000..b8e2af6685 --- /dev/null +++ b/pyrit/memory/alembic/env.py @@ -0,0 +1,76 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from logging.config import fileConfig +from typing import TYPE_CHECKING + +from alembic import context +from sqlalchemy import engine_from_config, pool + +from pyrit.memory.memory_models import Base + +if TYPE_CHECKING: + from sqlalchemy.engine import Connection + +config = context.config + +if config.config_file_name is not None: + fileConfig(config.config_file_name) + +target_metadata = Base.metadata +VERSION_TABLE = "pyrit_memory_alembic_version" + + +def run_migrations_offline() -> None: + """Run migrations in offline mode.""" + url = config.get_main_option("sqlalchemy.url") + context.configure( + url=url, + target_metadata=target_metadata, + literal_binds=True, + compare_type=True, + version_table=VERSION_TABLE, + dialect_opts={"paramstyle": "named"}, + ) + + with context.begin_transaction(): + context.run_migrations() + + +def run_migrations_online() -> None: + """Run migrations in online mode.""" + connection: Connection | None = config.attributes.get("connection") + + if connection is not None: + context.configure( + connection=connection, + target_metadata=target_metadata, + compare_type=True, + version_table=VERSION_TABLE, + ) + with context.begin_transaction(): + context.run_migrations() + return + + connectable = engine_from_config( + config.get_section(config.config_ini_section, {}), + prefix="sqlalchemy.", + poolclass=pool.NullPool, + ) + + with connectable.connect() as created_connection: + context.configure( + connection=created_connection, + target_metadata=target_metadata, + compare_type=True, + version_table=VERSION_TABLE, + ) + + with context.begin_transaction(): + context.run_migrations() + + +if context.is_offline_mode(): + run_migrations_offline() +else: + run_migrations_online() diff --git a/pyrit/memory/alembic/script.py.mako b/pyrit/memory/alembic/script.py.mako new file mode 100644 index 0000000000..2b67b03e12 --- /dev/null +++ b/pyrit/memory/alembic/script.py.mako @@ -0,0 +1,32 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +${message}. + +Revision ID: ${up_revision} +Revises: ${down_revision | comma,n} +Create Date: ${create_date} +""" + +from collections.abc import Sequence + +import sqlalchemy as sa +from alembic import op +${imports if imports else ""} + +# revision identifiers, used by Alembic. +revision: str = "${up_revision}" +down_revision: str | None = ${repr(down_revision).replace("'", '"')} +branch_labels: str | Sequence[str] | None = ${repr(branch_labels).replace("'", '"')} +depends_on: str | Sequence[str] | None = ${repr(depends_on).replace("'", '"')} + + +def upgrade() -> None: + """Apply this schema upgrade.""" + ${upgrades if upgrades else "pass"} + + +def downgrade() -> None: + """Revert this schema upgrade.""" + ${downgrades if downgrades else "pass"} diff --git a/pyrit/memory/alembic/versions/e373726d391b_initial_schema.py b/pyrit/memory/alembic/versions/e373726d391b_initial_schema.py new file mode 100644 index 0000000000..3299bb82c5 --- /dev/null +++ b/pyrit/memory/alembic/versions/e373726d391b_initial_schema.py @@ -0,0 +1,197 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +initial schema. + +Revision ID: e373726d391b +Revises: +Create Date: 2026-04-17 10:03:04.066932 +""" + +import uuid +from collections.abc import Sequence +from typing import Any + +import sqlalchemy as sa +from alembic import op +from sqlalchemy.dialects.sqlite import CHAR +from sqlalchemy.engine import Dialect +from sqlalchemy.types import TypeDecorator, Uuid + + +class _CustomUUID(TypeDecorator[uuid.UUID]): + """Frozen copy of CustomUUID kept here so this revision stays self-contained.""" + + impl = CHAR + cache_ok = True + + def load_dialect_impl(self, dialect: Dialect) -> Any: + if dialect.name == "sqlite": + return dialect.type_descriptor(CHAR(36)) + return dialect.type_descriptor(Uuid()) + + def process_bind_param(self, value: Any, dialect: Any) -> str | None: + return str(value) if value is not None else None + + def process_result_value(self, value: Any, dialect: Any) -> uuid.UUID | None: + if value is None: + return None + if isinstance(value, uuid.UUID): + return value + return uuid.UUID(value) + + +# revision identifiers, used by Alembic. +revision: str = "e373726d391b" +down_revision: str | None = None +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + """Apply this schema upgrade.""" + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "PromptMemoryEntries", + sa.Column("id", _CustomUUID(), nullable=False), + sa.Column("role", sa.String(), nullable=False), + sa.Column("conversation_id", sa.String(), nullable=False), + sa.Column("sequence", sa.INTEGER(), nullable=False), + sa.Column("timestamp", sa.DateTime(), nullable=False), + sa.Column("labels", sa.JSON(), nullable=False), + sa.Column("prompt_metadata", sa.JSON(), nullable=False), + sa.Column("targeted_harm_categories", sa.JSON(), nullable=True), + sa.Column("converter_identifiers", sa.JSON(), nullable=True), + sa.Column("prompt_target_identifier", sa.JSON(), nullable=False), + sa.Column("attack_identifier", sa.JSON(), nullable=False), + sa.Column("response_error", sa.String(), nullable=True), + sa.Column("original_value_data_type", sa.String(), nullable=False), + sa.Column("original_value", sa.Unicode(), nullable=False), + sa.Column("original_value_sha256", sa.String(), nullable=True), + sa.Column("converted_value_data_type", sa.String(), nullable=False), + sa.Column("converted_value", sa.Unicode(), nullable=True), + sa.Column("converted_value_sha256", sa.String(), nullable=True), + sa.Column("original_prompt_id", _CustomUUID(), nullable=False), + sa.Column("pyrit_version", sa.String(), nullable=True), + sa.PrimaryKeyConstraint("id"), + ) + op.create_table( + "ScenarioResultEntries", + sa.Column("id", _CustomUUID(), nullable=False), + sa.Column("scenario_name", sa.String(), nullable=False), + sa.Column("scenario_description", sa.Unicode(), nullable=True), + sa.Column("scenario_version", sa.INTEGER(), nullable=False), + sa.Column("pyrit_version", sa.String(), nullable=False), + sa.Column("scenario_init_data", sa.JSON(), nullable=True), + sa.Column("objective_target_identifier", sa.JSON(), nullable=False), + sa.Column("objective_scorer_identifier", sa.JSON(), nullable=True), + sa.Column("scenario_run_state", sa.String(), nullable=False), + sa.Column("attack_results_json", sa.Unicode(), nullable=False), + sa.Column("labels", sa.JSON(), nullable=True), + sa.Column("number_tries", sa.INTEGER(), nullable=False), + sa.Column("completion_time", sa.DateTime(), nullable=False), + sa.Column("timestamp", sa.DateTime(), nullable=False), + sa.PrimaryKeyConstraint("id"), + ) + op.create_table( + "SeedPromptEntries", + sa.Column("id", _CustomUUID(), nullable=False), + sa.Column("value", sa.Unicode(), nullable=False), + sa.Column("value_sha256", sa.Unicode(), nullable=True), + sa.Column("data_type", sa.String(), nullable=False), + sa.Column("name", sa.String(), nullable=True), + sa.Column("dataset_name", sa.String(), nullable=True), + sa.Column("harm_categories", sa.JSON(), nullable=True), + sa.Column("description", sa.String(), nullable=True), + sa.Column("authors", sa.JSON(), nullable=True), + sa.Column("groups", sa.JSON(), nullable=True), + sa.Column("source", sa.String(), nullable=True), + sa.Column("date_added", sa.DateTime(), nullable=False), + sa.Column("added_by", sa.String(), nullable=False), + sa.Column("prompt_metadata", sa.JSON(), nullable=True), + sa.Column("parameters", sa.JSON(), nullable=True), + sa.Column("prompt_group_id", _CustomUUID(), nullable=True), + sa.Column("sequence", sa.INTEGER(), nullable=True), + sa.Column("role", sa.String(), nullable=True), + sa.Column("seed_type", sa.String(), nullable=False), + sa.PrimaryKeyConstraint("id"), + ) + op.create_table( + "EmbeddingData", + sa.Column("id", sa.Uuid(), nullable=False), + sa.Column( + "embedding", + sa.ARRAY(sa.Float()).with_variant(sa.JSON(), "mssql").with_variant(sa.JSON(), "sqlite"), + nullable=True, + ), + sa.Column("embedding_type_name", sa.String(), nullable=True), + sa.ForeignKeyConstraint( + ["id"], + ["PromptMemoryEntries.id"], + ), + sa.PrimaryKeyConstraint("id"), + ) + op.create_table( + "ScoreEntries", + sa.Column("id", _CustomUUID(), nullable=False), + sa.Column("score_value", sa.String(), nullable=False), + sa.Column("score_value_description", sa.String(), nullable=True), + sa.Column("score_type", sa.String(), nullable=False), + sa.Column("score_category", sa.JSON(), nullable=True), + sa.Column("score_rationale", sa.String(), nullable=True), + sa.Column("score_metadata", sa.JSON(), nullable=False), + sa.Column("scorer_class_identifier", sa.JSON(), nullable=False), + sa.Column("prompt_request_response_id", _CustomUUID(), nullable=True), + sa.Column("timestamp", sa.DateTime(), nullable=False), + sa.Column("task", sa.String(), nullable=True), + sa.Column("objective", sa.String(), nullable=True), + sa.Column("pyrit_version", sa.String(), nullable=True), + sa.ForeignKeyConstraint( + ["prompt_request_response_id"], + ["PromptMemoryEntries.id"], + ), + sa.PrimaryKeyConstraint("id"), + ) + op.create_table( + "AttackResultEntries", + sa.Column("id", _CustomUUID(), nullable=False), + sa.Column("conversation_id", sa.String(), nullable=False), + sa.Column("objective", sa.Unicode(), nullable=False), + sa.Column("attack_identifier", sa.JSON(), nullable=False), + sa.Column("atomic_attack_identifier", sa.JSON(), nullable=True), + sa.Column("objective_sha256", sa.String(), nullable=True), + sa.Column("last_response_id", _CustomUUID(), nullable=True), + sa.Column("last_score_id", _CustomUUID(), nullable=True), + sa.Column("executed_turns", sa.INTEGER(), nullable=False), + sa.Column("execution_time_ms", sa.INTEGER(), nullable=False), + sa.Column("outcome", sa.String(), nullable=False), + sa.Column("outcome_reason", sa.String(), nullable=True), + sa.Column("attack_metadata", sa.JSON(), nullable=True), + sa.Column("pruned_conversation_ids", sa.JSON(), nullable=True), + sa.Column("adversarial_chat_conversation_ids", sa.JSON(), nullable=True), + sa.Column("timestamp", sa.DateTime(), nullable=False), + sa.Column("pyrit_version", sa.String(), nullable=True), + sa.ForeignKeyConstraint( + ["last_response_id"], + ["PromptMemoryEntries.id"], + ), + sa.ForeignKeyConstraint( + ["last_score_id"], + ["ScoreEntries.id"], + ), + sa.PrimaryKeyConstraint("id"), + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + """Revert this schema upgrade.""" + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table("AttackResultEntries") + op.drop_table("ScoreEntries") + op.drop_table("EmbeddingData") + op.drop_table("SeedPromptEntries") + op.drop_table("ScenarioResultEntries") + op.drop_table("PromptMemoryEntries") + # ### end Alembic commands ### diff --git a/pyrit/memory/azure_sql_memory.py b/pyrit/memory/azure_sql_memory.py index 207a7f7f98..b15d500f97 100644 --- a/pyrit/memory/azure_sql_memory.py +++ b/pyrit/memory/azure_sql_memory.py @@ -26,6 +26,7 @@ EmbeddingDataEntry, PromptMemoryEntry, ) +from pyrit.memory.migration import reset_schema from pyrit.models import ( AzureBlobStorageIO, ConversationStats, @@ -208,20 +209,6 @@ def provide_token(_dialect: Any, _conn_rec: Any, cargs: list[Any], cparams: dict # add the encoded token cparams["attrs_before"] = {self.SQL_COPT_SS_ACCESS_TOKEN: packed_azure_token} - def _create_tables_if_not_exist(self) -> None: - """ - Create all tables defined in the Base metadata, if they don't already exist in the database. - - Raises: - Exception: If there's an issue creating the tables in the database. - """ - try: - # Using the 'checkfirst=True' parameter to avoid attempting to recreate existing tables - Base.metadata.create_all(self.engine, checkfirst=True) - except Exception as e: - logger.exception(f"Error during table creation: {e}") - raise - def _add_embeddings_to_memory(self, *, embedding_data: Sequence[EmbeddingDataEntry]) -> None: """ Insert embedding data into memory storage. @@ -791,7 +778,4 @@ def _update_entries(self, *, entries: MutableSequence[Base], update_fields: dict def reset_database(self) -> None: """Drop and recreate existing tables.""" - # Drop all existing tables - Base.metadata.drop_all(self.engine) - # Recreate the tables - Base.metadata.create_all(self.engine, checkfirst=True) + reset_schema(engine=self.engine) diff --git a/pyrit/memory/memory_interface.py b/pyrit/memory/memory_interface.py index 9bf33031f9..b8e0b5607a 100644 --- a/pyrit/memory/memory_interface.py +++ b/pyrit/memory/memory_interface.py @@ -34,6 +34,7 @@ ScoreEntry, SeedEntry, ) +from pyrit.memory.migration import run_schema_migrations from pyrit.models import ( AttackResult, ConversationStats, @@ -119,6 +120,19 @@ def disable_embedding(self) -> None: """ self.memory_embedding = None + def _create_tables_if_not_exist(self) -> None: + """ + Upgrade the database schema to the latest Alembic revision. + + Raises: + Exception: If there's an issue applying schema migrations. + """ + try: + run_schema_migrations(engine=self.engine) + except Exception as e: + logger.exception(f"Error during schema migration: {e}") + raise + def _build_identifier_filter_conditions( self, *, diff --git a/pyrit/memory/migration.py b/pyrit/memory/migration.py new file mode 100644 index 0000000000..5697399079 --- /dev/null +++ b/pyrit/memory/migration.py @@ -0,0 +1,122 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import logging +from pathlib import Path + +from alembic import command +from alembic.autogenerate.api import compare_metadata +from alembic.config import Config +from alembic.migration import MigrationContext +from sqlalchemy import inspect, text +from sqlalchemy.engine import Connection, Engine + +from pyrit.memory.memory_models import Base + +logger = logging.getLogger(__name__) + +_MEMORY_ALEMBIC_VERSION_TABLE = "pyrit_memory_alembic_version" +_HEAD_REVISION = "head" +_INITIAL_MEMORY_REVISION = "e373726d391b" +_MEMORY_TABLES = { + "AttackResultEntries", + "EmbeddingData", + "PromptMemoryEntries", + "ScenarioResultEntries", + "ScoreEntries", + "SeedPromptEntries", +} + + +def _make_config(*, connection: Connection) -> Config: + """ + Build an Alembic config for the memory migration scripts. + + Args: + connection (Connection): Database connection for Alembic commands. + + Returns: + Config: Configured Alembic config object. + """ + script_location = Path(__file__).with_name("alembic") + + config = Config() + config.set_main_option("script_location", str(script_location)) + config.attributes["connection"] = connection + config.attributes["version_table"] = _MEMORY_ALEMBIC_VERSION_TABLE + return config + + +def _validate_and_stamp_unversioned_memory_schema(*, config: Config, connection: Connection) -> None: + """ + Validate and stamp unversioned legacy memory schemas. + + Args: + config (Config): Alembic config bound to the current connection. + connection (Connection): Database connection to inspect. + + Raises: + RuntimeError: If an unversioned memory schema does not match models. + """ + table_names = set(inspect(connection).get_table_names()) + if _MEMORY_ALEMBIC_VERSION_TABLE in table_names: + return + + if not _MEMORY_TABLES.intersection(table_names): + return + + try: + migration_context = MigrationContext.configure(connection=connection, opts={"compare_type": True}) + diffs = compare_metadata(migration_context, Base.metadata) + except Exception as e: + raise RuntimeError( + "Detected an unversioned legacy memory schema (memory tables exist, but " + "pyrit_memory_alembic_version is missing), " + "and it does not match current models. Repair or rebuild the database before upgrading to this release." + ) from e + + if diffs: + raise RuntimeError( + "Detected an unversioned legacy memory schema (memory tables exist, but " + "pyrit_memory_alembic_version is missing), " + "and it does not match current models. Repair or rebuild the database before upgrading to this release." + ) + + logger.info("Detected matching unversioned memory schema; stamping revision %s", _INITIAL_MEMORY_REVISION) + command.stamp(config, _INITIAL_MEMORY_REVISION) + + +def run_schema_migrations(*, engine: Engine) -> None: + """ + Upgrade the database schema to the latest Alembic revision. + + Args: + engine (Engine): SQLAlchemy engine bound to the target database. + + Raises: + Exception: If Alembic fails to apply migrations. + """ + script_location = Path(__file__).with_name("alembic") + logger.debug("Applying Alembic migrations from %s", script_location) + with engine.begin() as connection: + config = _make_config(connection=connection) + _validate_and_stamp_unversioned_memory_schema(config=config, connection=connection) + command.upgrade(config, _HEAD_REVISION) + + +def reset_schema(*, engine: Engine) -> None: + """ + Recreate the database schema at the latest Alembic revision. + + Args: + engine (Engine): SQLAlchemy engine bound to the target database. + """ + logger.debug("Resetting memory schema using Alembic migrations") + with engine.begin() as connection: + Base.metadata.drop_all(connection) + + inspector = inspect(connection) + if _MEMORY_ALEMBIC_VERSION_TABLE in inspector.get_table_names(): + connection.execute(text(f'DROP TABLE "{_MEMORY_ALEMBIC_VERSION_TABLE}"')) + + command.upgrade(_make_config(connection=connection), _HEAD_REVISION) diff --git a/pyrit/memory/sqlite_memory.py b/pyrit/memory/sqlite_memory.py index bd376d67cd..f2f69789ed 100644 --- a/pyrit/memory/sqlite_memory.py +++ b/pyrit/memory/sqlite_memory.py @@ -28,6 +28,7 @@ PromptMemoryEntry, ScenarioResultEntry, ) +from pyrit.memory.migration import reset_schema from pyrit.models import ConversationStats, DiskStorageIO, MessagePiece logger = logging.getLogger(__name__) @@ -128,20 +129,6 @@ def _create_engine(self, *, has_echo: bool) -> Engine: logger.exception(f"Error creating the engine for the database: {e}") raise - def _create_tables_if_not_exist(self) -> None: - """ - Create all tables defined in the Base metadata, if they don't already exist in the database. - - Raises: - Exception: If there's an issue creating the tables in the database. - """ - try: - # Using the 'checkfirst=True' parameter to avoid attempting to recreate existing tables - Base.metadata.create_all(self.engine, checkfirst=True) - except Exception as e: - logger.exception(f"Error during table creation: {e}") - raise - def get_all_embeddings(self) -> Sequence[EmbeddingDataEntry]: """ Fetch all entries from the specified table and returns them as model instances. @@ -441,8 +428,7 @@ def reset_database(self) -> None: """ Drop and recreates all tables in the database. """ - Base.metadata.drop_all(self.engine) - Base.metadata.create_all(self.engine) + reset_schema(engine=self.engine) def dispose_engine(self) -> None: """ diff --git a/tests/unit/memory/test_azure_sql_memory.py b/tests/unit/memory/test_azure_sql_memory.py index acf4420604..a7b3c5f705 100644 --- a/tests/unit/memory/test_azure_sql_memory.py +++ b/tests/unit/memory/test_azure_sql_memory.py @@ -8,6 +8,7 @@ from typing import TYPE_CHECKING import pytest +from sqlalchemy import inspect, text from pyrit.memory import AzureSQLMemory, EmbeddingDataEntry, PromptMemoryEntry from pyrit.models import MessagePiece @@ -135,6 +136,28 @@ def test_default_embedding_raises(memory_interface: AzureSQLMemory): memory_interface.enable_embedding() +def test_reset_database_recreates_versioned_schema(memory_interface: AzureSQLMemory): + memory_interface.reset_database() + + inspector = inspect(memory_interface.engine) + table_names = set(inspector.get_table_names()) + + assert { + "AttackResultEntries", + "EmbeddingData", + "PromptMemoryEntries", + "ScenarioResultEntries", + "ScoreEntries", + "SeedPromptEntries", + "pyrit_memory_alembic_version", + }.issubset(table_names) + + with memory_interface.engine.connect() as connection: + version = connection.execute(text("SELECT version_num FROM pyrit_memory_alembic_version")).scalar_one() + + assert version + + def test_query_entries( memory_interface: AzureSQLMemory, sample_conversation_entries: MutableSequence[PromptMemoryEntry] ): diff --git a/tests/unit/memory/test_sqlite_memory.py b/tests/unit/memory/test_sqlite_memory.py index ba07356578..b78cb28489 100644 --- a/tests/unit/memory/test_sqlite_memory.py +++ b/tests/unit/memory/test_sqlite_memory.py @@ -4,19 +4,21 @@ import io import logging import os +import tempfile import uuid from collections.abc import Sequence from datetime import timezone from unittest.mock import MagicMock import pytest -from sqlalchemy import ARRAY, DateTime, Integer, String, inspect +from sqlalchemy import ARRAY, DateTime, Integer, String, create_engine, inspect, text from sqlalchemy.dialects.postgresql import UUID from sqlalchemy.dialects.sqlite import CHAR, JSON from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.sql.sqltypes import NullType -from pyrit.memory.memory_models import EmbeddingDataEntry, PromptMemoryEntry +from pyrit.memory.memory_models import Base, EmbeddingDataEntry, PromptMemoryEntry +from pyrit.memory.migration import run_schema_migrations from pyrit.models import MessagePiece from pyrit.prompt_converter.base64_converter import Base64Converter from pyrit.prompt_target.text_target import TextTarget @@ -156,6 +158,148 @@ def test_embedding_data_column_types(sqlite_instance): ], f"Unexpected type for 'embedding' column: {column_types['embedding']}" +def test_run_schema_migrations_stamps_matching_unversioned_legacy_database(): + with tempfile.TemporaryDirectory() as temp_dir: + db_path = os.path.join(temp_dir, "legacy-memory.db") + engine = create_engine(f"sqlite:///{db_path}") + try: + Base.metadata.create_all(engine, checkfirst=True) + + run_schema_migrations(engine=engine) + + table_names = set(inspect(engine).get_table_names()) + assert "pyrit_memory_alembic_version" in table_names + + with engine.connect() as connection: + version = connection.execute(text("SELECT version_num FROM pyrit_memory_alembic_version")).scalar_one() + + assert version + finally: + engine.dispose() + + +def test_run_schema_migrations_fails_synthetic_unversioned_schema_with_drift(): + with tempfile.TemporaryDirectory() as temp_dir: + db_path = os.path.join(temp_dir, "legacy-memory.db") + engine = create_engine(f"sqlite:///{db_path}") + try: + Base.metadata.create_all(engine, checkfirst=True) + with engine.begin() as connection: + connection.execute(text('DROP TABLE "ScoreEntries"')) + + with pytest.raises(RuntimeError, match="unversioned legacy memory schema"): + run_schema_migrations(engine=engine) + + table_names = set(inspect(engine).get_table_names()) + assert "pyrit_memory_alembic_version" not in table_names + finally: + engine.dispose() + + +def test_run_schema_migrations_fails_pre_alembic_like_schema() -> None: + """ + Validate strict behavior for unsupported legacy schemas. + + This schema shape intentionally resembles an older pre-Alembic layout where + newer columns (e.g. pyrit_version, converter_identifiers) are absent. + Such databases are intentionally unsupported and must fail migration checks. + """ + with tempfile.TemporaryDirectory() as temp_dir: + db_path = os.path.join(temp_dir, "legacy-memory.db") + engine = create_engine(f"sqlite:///{db_path}") + try: + with engine.begin() as connection: + connection.execute( + text( + """ + CREATE TABLE "PromptMemoryEntries" ( + id CHAR(36) NOT NULL, + role VARCHAR NOT NULL, + conversation_id VARCHAR NOT NULL, + sequence INTEGER NOT NULL, + timestamp DATETIME NOT NULL, + labels JSON NOT NULL, + prompt_metadata JSON NOT NULL, + prompt_target_identifier JSON NOT NULL, + attack_identifier JSON NOT NULL, + original_value_data_type VARCHAR NOT NULL, + original_value VARCHAR NOT NULL, + original_value_sha256 VARCHAR, + converted_value_data_type VARCHAR NOT NULL, + converted_value VARCHAR, + converted_value_sha256 VARCHAR, + original_prompt_id CHAR(36) NOT NULL, + PRIMARY KEY (id) + ) + """ + ) + ) + + with pytest.raises(RuntimeError, match="unversioned legacy memory schema"): + run_schema_migrations(engine=engine) + + table_names = set(inspect(engine).get_table_names()) + assert "pyrit_memory_alembic_version" not in table_names + finally: + engine.dispose() + + +def test_run_schema_migrations_isolates_foreign_alembic_version_table(): + with tempfile.TemporaryDirectory() as temp_dir: + db_path = os.path.join(temp_dir, "legacy-memory.db") + engine = create_engine(f"sqlite:///{db_path}") + try: + Base.metadata.create_all(engine, checkfirst=True) + with engine.begin() as connection: + connection.execute(text('CREATE TABLE "alembic_version" (version_num VARCHAR(32) NOT NULL)')) + + run_schema_migrations(engine=engine) + + table_names = set(inspect(engine).get_table_names()) + assert "alembic_version" in table_names + assert "pyrit_memory_alembic_version" in table_names + + with engine.connect() as connection: + version = connection.execute(text("SELECT version_num FROM pyrit_memory_alembic_version")).scalar_one() + + assert version + finally: + engine.dispose() + + +def test_reset_database_recreates_schema(sqlite_instance): + sqlite_instance.reset_database() + + inspector = inspect(sqlite_instance.engine) + table_names = set(inspector.get_table_names()) + + assert { + "AttackResultEntries", + "EmbeddingData", + "PromptMemoryEntries", + "ScenarioResultEntries", + "ScoreEntries", + "SeedPromptEntries", + "pyrit_memory_alembic_version", + }.issubset(table_names) + + with sqlite_instance.engine.connect() as connection: + version = connection.execute(text("SELECT version_num FROM pyrit_memory_alembic_version")).scalar_one() + + assert version + + +def test_reset_database_keeps_foreign_alembic_version_table(sqlite_instance): + with sqlite_instance.engine.begin() as connection: + connection.execute(text('CREATE TABLE "alembic_version" (version_num VARCHAR(32) NOT NULL)')) + + sqlite_instance.reset_database() + + table_names = set(inspect(sqlite_instance.engine).get_table_names()) + assert "alembic_version" in table_names + assert "pyrit_memory_alembic_version" in table_names + + @pytest.mark.asyncio() async def test_insert_entry(sqlite_instance): session = sqlite_instance.get_session() diff --git a/uv.lock b/uv.lock index 21952cbdc8..165dc224df 100644 --- a/uv.lock +++ b/uv.lock @@ -4899,6 +4899,7 @@ version = "0.13.0.dev0" source = { editable = "." } dependencies = [ { name = "aiofiles" }, + { name = "alembic" }, { name = "appdirs" }, { name = "art" }, { name = "av" }, @@ -5025,6 +5026,7 @@ requires-dist = [ { name = "accelerate", marker = "extra == 'all'", specifier = ">=1.7.0" }, { name = "accelerate", marker = "extra == 'gcg'", specifier = ">=1.7.0" }, { name = "aiofiles", specifier = ">=24,<25" }, + { name = "alembic", specifier = ">=1.16.0" }, { name = "appdirs", specifier = ">=1.4.0" }, { name = "art", specifier = ">=6.5.0" }, { name = "av", specifier = ">=14.0.0" }, From 255c7fd910f7c252d856571e82f7e6cd48005283 Mon Sep 17 00:00:00 2001 From: Behnam Ousat Date: Fri, 17 Apr 2026 17:02:52 -0700 Subject: [PATCH 07/23] optimizations --- build_scripts/memory_migrations.py | 5 +- pyproject.toml | 2 + pyrit/memory/alembic/env.py | 2 +- .../versions/e373726d391b_initial_schema.py | 9 +- pyrit/memory/migration.py | 53 +++++- tests/unit/memory/test_migration.py | 169 ++++++++++++++++++ tests/unit/memory/test_sqlite_memory.py | 114 +++++++++++- 7 files changed, 341 insertions(+), 13 deletions(-) create mode 100644 tests/unit/memory/test_migration.py diff --git a/build_scripts/memory_migrations.py b/build_scripts/memory_migrations.py index 6e0e7a3c0f..ac1bab40da 100644 --- a/build_scripts/memory_migrations.py +++ b/build_scripts/memory_migrations.py @@ -91,8 +91,9 @@ def _cmd_check() -> None: db_url = f"sqlite:///{tmp.name}" try: - command.upgrade(_make_config(db_url=db_url), "head") - command.check(_make_config(db_url=db_url)) + config = _make_config(db_url=db_url) + command.upgrade(config, "head") + command.check(config) except Exception as e: _print_error( f"Migration check failed. Run the script in 'generate' mode to generate a new migration. Error: {e}" diff --git a/pyproject.toml b/pyproject.toml index 565d938aba..825fe907fd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -197,6 +197,8 @@ include = ["pyrit", "pyrit.*"] [tool.setuptools.package-data] pyrit = [ "backend/frontend/**/*", + "memory/alembic/**/*", + "memory/alembic.ini", "py.typed" ] diff --git a/pyrit/memory/alembic/env.py b/pyrit/memory/alembic/env.py index b8e2af6685..1b77327183 100644 --- a/pyrit/memory/alembic/env.py +++ b/pyrit/memory/alembic/env.py @@ -18,7 +18,7 @@ fileConfig(config.config_file_name) target_metadata = Base.metadata -VERSION_TABLE = "pyrit_memory_alembic_version" +VERSION_TABLE = config.attributes.get("version_table", "pyrit_memory_alembic_version") def run_migrations_offline() -> None: diff --git a/pyrit/memory/alembic/versions/e373726d391b_initial_schema.py b/pyrit/memory/alembic/versions/e373726d391b_initial_schema.py index 3299bb82c5..43e42dcd72 100644 --- a/pyrit/memory/alembic/versions/e373726d391b_initial_schema.py +++ b/pyrit/memory/alembic/versions/e373726d391b_initial_schema.py @@ -21,7 +21,14 @@ class _CustomUUID(TypeDecorator[uuid.UUID]): - """Frozen copy of CustomUUID kept here so this revision stays self-contained.""" + """ + Frozen copy of CustomUUID kept here so this revision stays self-contained. + + This class is embedded in the migration script rather than imported to ensure + the migration remains reproducible and independent of future changes to the + main CustomUUID implementation in memory_models.py. Any future modifications + to CustomUUID must NOT affect this frozen version. + """ impl = CHAR cache_ok = True diff --git a/pyrit/memory/migration.py b/pyrit/memory/migration.py index 5697399079..c352abb3eb 100644 --- a/pyrit/memory/migration.py +++ b/pyrit/memory/migration.py @@ -8,7 +8,7 @@ from alembic.autogenerate.api import compare_metadata from alembic.config import Config from alembic.migration import MigrationContext -from sqlalchemy import inspect, text +from sqlalchemy import MetaData, Table, inspect from sqlalchemy.engine import Connection, Engine from pyrit.memory.memory_models import Base @@ -28,6 +28,32 @@ } +def _include_name_for_memory_schema( + name: str | None, + type_: str, + parent_names: dict[str, str], +) -> bool: + """ + Restrict schema comparisons to PyRIT memory tables and their child objects. + + Args: + name (str | None): Name of the database object being considered. + type_ (str): SQLAlchemy object type (e.g., "table", "column", "index"). + parent_names (dict[str, str]): Parent-name context provided by Alembic. + + Returns: + bool: True when the object should be included in schema comparison. + """ + if type_ == "table": + return bool(name and name in _MEMORY_TABLES) + + table_name = parent_names.get("table_name") + if table_name: + return table_name in _MEMORY_TABLES + + return True + + def _make_config(*, connection: Connection) -> Config: """ Build an Alembic config for the memory migration scripts. @@ -58,15 +84,24 @@ def _validate_and_stamp_unversioned_memory_schema(*, config: Config, connection: Raises: RuntimeError: If an unversioned memory schema does not match models. """ - table_names = set(inspect(connection).get_table_names()) + # Perform all inspection in one atomic call to avoid race conditions + inspector = inspect(connection) + table_names = set(inspector.get_table_names()) + + # If version table already exists, migration has been stamped if _MEMORY_ALEMBIC_VERSION_TABLE in table_names: return + # If no memory tables exist, this is a fresh database if not _MEMORY_TABLES.intersection(table_names): return + # Unversioned memory schema detected; validate it matches current models try: - migration_context = MigrationContext.configure(connection=connection, opts={"compare_type": True}) + migration_context = MigrationContext.configure( + connection=connection, + opts={"compare_type": True, "include_name": _include_name_for_memory_schema}, + ) diffs = compare_metadata(migration_context, Base.metadata) except Exception as e: raise RuntimeError( @@ -96,8 +131,6 @@ def run_schema_migrations(*, engine: Engine) -> None: Raises: Exception: If Alembic fails to apply migrations. """ - script_location = Path(__file__).with_name("alembic") - logger.debug("Applying Alembic migrations from %s", script_location) with engine.begin() as connection: config = _make_config(connection=connection) _validate_and_stamp_unversioned_memory_schema(config=config, connection=connection) @@ -113,10 +146,14 @@ def reset_schema(*, engine: Engine) -> None: """ logger.debug("Resetting memory schema using Alembic migrations") with engine.begin() as connection: - Base.metadata.drop_all(connection) - + # Drop version table first (not part of Base.metadata) inspector = inspect(connection) if _MEMORY_ALEMBIC_VERSION_TABLE in inspector.get_table_names(): - connection.execute(text(f'DROP TABLE "{_MEMORY_ALEMBIC_VERSION_TABLE}"')) + version_table = Table(_MEMORY_ALEMBIC_VERSION_TABLE, MetaData(), autoload_with=connection) + version_table.drop(connection) + + # Drop all application tables defined in models + Base.metadata.drop_all(connection) + # Rebuild schema from migrations command.upgrade(_make_config(connection=connection), _HEAD_REVISION) diff --git a/tests/unit/memory/test_migration.py b/tests/unit/memory/test_migration.py new file mode 100644 index 0000000000..c1ecb9bad5 --- /dev/null +++ b/tests/unit/memory/test_migration.py @@ -0,0 +1,169 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import os +import tempfile +import uuid +from pathlib import Path + +from alembic import command +from alembic.config import Config +from alembic.script import ScriptDirectory +from sqlalchemy import create_engine, inspect, text + +from pyrit.memory.alembic.versions.e373726d391b_initial_schema import _CustomUUID +from pyrit.memory.migration import run_schema_migrations + + +def _get_alembic_head_revision(*, config: Config) -> str: + """Return the current Alembic head revision for the configured script location.""" + head_revision = ScriptDirectory.from_config(config).get_current_head() + if head_revision is None: + raise RuntimeError("No Alembic head revision found for memory migrations.") + + return head_revision + + +def test_custom_uuid_process_bind_param_with_none(): + """Test _CustomUUID.process_bind_param with None value.""" + uuid_type = _CustomUUID() + result = uuid_type.process_bind_param(None, None) + assert result is None + + +def test_custom_uuid_process_bind_param_with_uuid(): + """Test _CustomUUID.process_bind_param with UUID value.""" + uuid_type = _CustomUUID() + test_uuid = uuid.uuid4() + result = uuid_type.process_bind_param(test_uuid, None) + assert result == str(test_uuid) + + +def test_custom_uuid_process_result_value_with_none(): + """Test _CustomUUID.process_result_value with None value.""" + uuid_type = _CustomUUID() + result = uuid_type.process_result_value(None, None) + assert result is None + + +def test_custom_uuid_process_result_value_with_uuid(): + """Test _CustomUUID.process_result_value with UUID value.""" + uuid_type = _CustomUUID() + test_uuid = uuid.uuid4() + result = uuid_type.process_result_value(test_uuid, None) + assert result == test_uuid + + +def test_custom_uuid_process_result_value_with_string(): + """Test _CustomUUID.process_result_value with string value.""" + uuid_type = _CustomUUID() + test_uuid = uuid.uuid4() + result = uuid_type.process_result_value(str(test_uuid), None) + assert result == test_uuid + + +def test_custom_uuid_load_dialect_impl_sqlite(): + """Test _CustomUUID.load_dialect_impl for SQLite dialect.""" + from sqlalchemy.dialects import sqlite + + uuid_type = _CustomUUID() + dialect = sqlite.dialect() + result = uuid_type.load_dialect_impl(dialect) + assert result is not None + + +def test_custom_uuid_load_dialect_impl_postgresql(): + """Test _CustomUUID.load_dialect_impl for PostgreSQL dialect.""" + from sqlalchemy.dialects import postgresql + + uuid_type = _CustomUUID() + dialect = postgresql.dialect() + result = uuid_type.load_dialect_impl(dialect) + assert result is not None + + +def test_run_schema_migrations_applies_head_revision(): + """Test that run_schema_migrations applies the current head revision.""" + with tempfile.TemporaryDirectory() as temp_dir: + db_path = os.path.join(temp_dir, "offline-test.db") + engine = create_engine(f"sqlite:///{db_path}") + try: + pyrit_root = Path(__file__).resolve().parent.parent.parent.parent / "pyrit" + script_location = pyrit_root / "memory" / "alembic" + config = Config() + config.set_main_option("script_location", str(script_location)) + expected_head = _get_alembic_head_revision(config=config) + + run_schema_migrations(engine=engine) + + with engine.connect() as connection: + version = connection.execute(text("SELECT version_num FROM pyrit_memory_alembic_version")).scalar_one() + assert version == expected_head + finally: + engine.dispose() + + +def test_migration_online_mode(): + """ + Test that online migration configuration is valid. + This tests the run_migrations_online path in env.py. + """ + with tempfile.TemporaryDirectory() as temp_dir: + db_path = os.path.join(temp_dir, "online-test.db") + engine = create_engine(f"sqlite:///{db_path}") + try: + with engine.begin() as connection: + pyrit_root = Path(__file__).resolve().parent.parent.parent.parent / "pyrit" + script_location = pyrit_root / "memory" / "alembic" + config = Config() + config.set_main_option("script_location", str(script_location)) + config.attributes["connection"] = connection + config.attributes["version_table"] = "pyrit_memory_alembic_version" + expected_head = _get_alembic_head_revision(config=config) + + command.upgrade(config, "head") + + version = connection.execute(text("SELECT version_num FROM pyrit_memory_alembic_version")).scalar_one() + assert version == expected_head + finally: + engine.dispose() + + +def test_migration_script_metadata(): + """Test that the initial migration script has correct metadata.""" + from pyrit.memory.alembic.versions import e373726d391b_initial_schema + + assert e373726d391b_initial_schema.revision == "e373726d391b" + assert e373726d391b_initial_schema.down_revision is None + assert e373726d391b_initial_schema.branch_labels is None + assert e373726d391b_initial_schema.depends_on is None + + +def test_migration_downgrade_creates_proper_structure(): + """ + Test that downgrade function doesn't corrupt the database. + This indirectly tests the downgrade path in e373726d391b_initial_schema.py. + """ + with tempfile.TemporaryDirectory() as temp_dir: + db_path = os.path.join(temp_dir, "downgrade-test.db") + engine = create_engine(f"sqlite:///{db_path}") + try: + with engine.begin() as connection: + pyrit_root = Path(__file__).resolve().parent.parent.parent.parent / "pyrit" + script_location = pyrit_root / "memory" / "alembic" + config = Config() + config.set_main_option("script_location", str(script_location)) + config.attributes["connection"] = connection + config.attributes["version_table"] = "pyrit_memory_alembic_version" + + command.upgrade(config, "head") + + tables_before = set(inspect(connection).get_table_names()) + assert len(tables_before) > 0 + + command.downgrade(config, "base") + + tables_after = set(inspect(connection).get_table_names()) + assert "pyrit_memory_alembic_version" not in tables_after or len(tables_after) == 1 + finally: + engine.dispose() diff --git a/tests/unit/memory/test_sqlite_memory.py b/tests/unit/memory/test_sqlite_memory.py index b78cb28489..f274607d86 100644 --- a/tests/unit/memory/test_sqlite_memory.py +++ b/tests/unit/memory/test_sqlite_memory.py @@ -8,7 +8,7 @@ import uuid from collections.abc import Sequence from datetime import timezone -from unittest.mock import MagicMock +from unittest.mock import MagicMock, patch import pytest from sqlalchemy import ARRAY, DateTime, Integer, String, create_engine, inspect, text @@ -19,6 +19,7 @@ from pyrit.memory.memory_models import Base, EmbeddingDataEntry, PromptMemoryEntry from pyrit.memory.migration import run_schema_migrations +from pyrit.memory.sqlite_memory import SQLiteMemory from pyrit.models import MessagePiece from pyrit.prompt_converter.base64_converter import Base64Converter from pyrit.prompt_target.text_target import TextTarget @@ -178,6 +179,38 @@ def test_run_schema_migrations_stamps_matching_unversioned_legacy_database(): engine.dispose() +def test_run_schema_migrations_stamps_unversioned_legacy_database_with_extra_tables(): + with tempfile.TemporaryDirectory() as temp_dir: + db_path = os.path.join(temp_dir, "legacy-memory.db") + engine = create_engine(f"sqlite:///{db_path}") + try: + Base.metadata.create_all(engine, checkfirst=True) + with engine.begin() as connection: + connection.execute( + text( + """ + CREATE TABLE "SharedAuditLog" ( + id INTEGER PRIMARY KEY, + event_name VARCHAR NOT NULL + ) + """ + ) + ) + + run_schema_migrations(engine=engine) + + table_names = set(inspect(engine).get_table_names()) + assert "SharedAuditLog" in table_names + assert "pyrit_memory_alembic_version" in table_names + + with engine.connect() as connection: + version = connection.execute(text("SELECT version_num FROM pyrit_memory_alembic_version")).scalar_one() + + assert version + finally: + engine.dispose() + + def test_run_schema_migrations_fails_synthetic_unversioned_schema_with_drift(): with tempfile.TemporaryDirectory() as temp_dir: db_path = os.path.join(temp_dir, "legacy-memory.db") @@ -843,3 +876,82 @@ def test_create_engine_uses_static_pool_for_in_memory(sqlite_instance): from sqlalchemy.pool import StaticPool assert isinstance(sqlite_instance.engine.pool, StaticPool) + + +def test_run_schema_migrations_early_return_with_existing_version_table(): + """ + Test that migration early-returns when the version table already exists. + This tests the line 57 return in _validate_and_stamp_unversioned_memory_schema. + """ + with tempfile.TemporaryDirectory() as temp_dir: + db_path = os.path.join(temp_dir, "versioned-memory.db") + engine = create_engine(f"sqlite:///{db_path}") + try: + Base.metadata.create_all(engine, checkfirst=True) + + run_schema_migrations(engine=engine) + + table_names = set(inspect(engine).get_table_names()) + assert "pyrit_memory_alembic_version" in table_names + + with engine.connect() as connection: + version = connection.execute(text("SELECT version_num FROM pyrit_memory_alembic_version")).scalar_one() + assert version + + run_schema_migrations(engine=engine) + + table_names = set(inspect(engine).get_table_names()) + assert "pyrit_memory_alembic_version" in table_names + + with engine.connect() as connection: + version_after = connection.execute( + text("SELECT version_num FROM pyrit_memory_alembic_version") + ).scalar_one() + assert version_after == version + finally: + engine.dispose() + + +def test_run_schema_migrations_no_memory_tables(): + """ + Test that migration early-returns when no memory tables exist. + This tests the line 60 return in _validate_and_stamp_unversioned_memory_schema. + """ + with tempfile.TemporaryDirectory() as temp_dir: + db_path = os.path.join(temp_dir, "empty-memory.db") + engine = create_engine(f"sqlite:///{db_path}") + try: + run_schema_migrations(engine=engine) + + table_names = set(inspect(engine).get_table_names()) + assert { + "AttackResultEntries", + "EmbeddingData", + "PromptMemoryEntries", + "ScenarioResultEntries", + "ScoreEntries", + "SeedPromptEntries", + "pyrit_memory_alembic_version", + }.issubset(table_names) + finally: + engine.dispose() + + +def test_create_tables_if_not_exist_with_schema_migration_exception(): + """ + Test that _create_tables_if_not_exist properly handles and re-raises exceptions. + This tests lines 132-134 in memory_interface.py. + """ + with tempfile.TemporaryDirectory() as temp_dir: + db_path = os.path.join(temp_dir, "error-memory.db") + + memory = SQLiteMemory(db_path=db_path) + + try: + with patch("pyrit.memory.memory_interface.run_schema_migrations") as mock_migrate: + mock_migrate.side_effect = RuntimeError("Mock migration error") + + with pytest.raises(RuntimeError, match="Mock migration error"): + memory._create_tables_if_not_exist() + finally: + memory.dispose_engine() From 570cc60d195683aba7af77a95dd517ff7e689ab7 Mon Sep 17 00:00:00 2001 From: Behnam Ousat Date: Tue, 21 Apr 2026 13:52:17 -0700 Subject: [PATCH 08/23] Revert "format" This reverts commit aafefc1ad8c222f9cb451a497e3f11fe3bb2381d. --- pyrit/memory/azure_sql_memory.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/pyrit/memory/azure_sql_memory.py b/pyrit/memory/azure_sql_memory.py index e7d3097615..3023bf8240 100644 --- a/pyrit/memory/azure_sql_memory.py +++ b/pyrit/memory/azure_sql_memory.py @@ -465,20 +465,24 @@ def _get_attack_result_label_condition(self, *, labels: dict[str, str]) -> Any: for i, (key, value) in enumerate(labels.items()): path_param = f"ar_label_path_{i}" value_param = f"ar_label_val_{i}" - ar_label_conditions.append(f'JSON_VALUE("AttackResultEntries".labels, :{path_param}) = :{value_param}') + ar_label_conditions.append( + f'JSON_VALUE("AttackResultEntries".labels, :{path_param}) = :{value_param}' + ) ar_bindparams[path_param] = f"$.{key}" ar_bindparams[value_param] = str(value) ar_combined = " AND ".join(ar_label_conditions) direct_condition = and_( AttackResultEntry.labels.isnot(None), - text(f'ISJSON("AttackResultEntries".labels) = 1 AND {ar_combined}').bindparams(**ar_bindparams), + text( + f'ISJSON("AttackResultEntries".labels) = 1 AND {ar_combined}' + ).bindparams(**ar_bindparams), ) # --- Conversation-level match on PromptMemoryEntry.labels --- pme_label_conditions = [] pme_bindparams: dict[str, str] = {} - for key, value in labels.items(): + for i, (key, value) in enumerate(labels.items()): param_name = f"pme_label_{key}" pme_label_conditions.append(f"JSON_VALUE(labels, '$.{key}') = :{param_name}") pme_bindparams[param_name] = str(value) @@ -488,7 +492,9 @@ def _get_attack_result_label_condition(self, *, labels: dict[str, str]) -> Any: and_( PromptMemoryEntry.conversation_id == AttackResultEntry.conversation_id, PromptMemoryEntry.labels.isnot(None), - text(f"ISJSON(labels) = 1 AND {pme_combined}").bindparams(**pme_bindparams), + text( + f"ISJSON(labels) = 1 AND {pme_combined}" + ).bindparams(**pme_bindparams), ) ) From f0c1c3df34898de15e29fdfd5cd97265ecdd65b1 Mon Sep 17 00:00:00 2001 From: Behnam Ousat Date: Tue, 21 Apr 2026 13:52:30 -0700 Subject: [PATCH 09/23] Revert "review feedback" This reverts commit fa5a0d76bab4c582a5e4f86ae8f727d08b3400fd. --- pyrit/memory/memory_interface.py | 4 ++-- .../test_interface_attack_results.py | 21 ------------------- 2 files changed, 2 insertions(+), 23 deletions(-) diff --git a/pyrit/memory/memory_interface.py b/pyrit/memory/memory_interface.py index f84e59d3fa..79be5b7c3d 100644 --- a/pyrit/memory/memory_interface.py +++ b/pyrit/memory/memory_interface.py @@ -1482,8 +1482,8 @@ def get_attack_results( By providing a list, this means ALL categories in the list must be present. Defaults to None. labels (Optional[dict[str, str]], optional): A dictionary of labels to filter results by. - These labels are stored on the AttackResult or associated PromptMemoryEntry (via conversation_id) - . All specified key-value pairs must be present (AND logic). Defaults to None. + These labels are stored directly on the AttackResult. All specified key-value pairs + must be present (AND logic). Defaults to None. identifier_filters (Optional[Sequence[IdentifierFilter]], optional): A sequence of IdentifierFilter objects that allows filtering by various attack identifier JSON properties. Defaults to None. diff --git a/tests/unit/memory/memory_interface/test_interface_attack_results.py b/tests/unit/memory/memory_interface/test_interface_attack_results.py index e0620d131a..7855e9ecfc 100644 --- a/tests/unit/memory/memory_interface/test_interface_attack_results.py +++ b/tests/unit/memory/memory_interface/test_interface_attack_results.py @@ -981,27 +981,6 @@ def test_get_attack_results_labels_key_exists_value_mismatch(sqlite_instance: Me assert results[0].conversation_id == "conv_1" -def test_get_attack_results_by_labels_falls_back_to_conversation_labels(sqlite_instance: MemoryInterface): - """Test that label filtering matches via PromptMemoryEntry when AttackResult has no labels.""" - - # Attack result with NO labels - attack_result = create_attack_result("conv_1", 1, AttackOutcome.SUCCESS, labels={}) - sqlite_instance.add_attack_results_to_memory(attack_results=[attack_result]) - - # Conversation message carries the labels instead - message_piece = create_message_piece("conv_1", 1, labels={"operation": "legacy_op"}) - sqlite_instance.add_message_pieces_to_memory(message_pieces=[message_piece]) - - # Should still find the attack result via the PME fallback path - results = sqlite_instance.get_attack_results(labels={"operation": "legacy_op"}) - assert len(results) == 1 - assert results[0].conversation_id == "conv_1" - - # Non-matching label should return nothing - results = sqlite_instance.get_attack_results(labels={"operation": "missing"}) - assert len(results) == 0 - - # --------------------------------------------------------------------------- # get_unique_attack_labels tests # --------------------------------------------------------------------------- From 483c1ecd4fb472926ee4b57efe58b63108dafe6e Mon Sep 17 00:00:00 2001 From: Behnam Ousat Date: Tue, 21 Apr 2026 13:52:41 -0700 Subject: [PATCH 10/23] Revert " label queries OR'ed with old way" This reverts commit 5e04ca7e7f515ca8dca695d93b03f8f4a481aef1. --- pyrit/memory/azure_sql_memory.py | 56 +++++++++----------------------- pyrit/memory/memory_interface.py | 2 +- pyrit/memory/sqlite_memory.py | 21 +++--------- 3 files changed, 21 insertions(+), 58 deletions(-) diff --git a/pyrit/memory/azure_sql_memory.py b/pyrit/memory/azure_sql_memory.py index 3023bf8240..64ff020416 100644 --- a/pyrit/memory/azure_sql_memory.py +++ b/pyrit/memory/azure_sql_memory.py @@ -9,7 +9,7 @@ from datetime import datetime, timedelta, timezone from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union -from sqlalchemy import and_, create_engine, event, exists, or_, text +from sqlalchemy import and_, create_engine, event, exists, text from sqlalchemy.engine.base import Engine from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.orm import InstrumentedAttribute, joinedload, sessionmaker @@ -446,10 +446,8 @@ def _get_attack_result_harm_category_condition(self, *, targeted_harm_categories def _get_attack_result_label_condition(self, *, labels: dict[str, str]) -> Any: """ - Get the SQL Azure implementation for filtering AttackResults by labels. - - Matches if the labels are found on the AttackResultEntry directly - OR on an associated PromptMemoryEntry (via conversation_id). + Get the SQL Azure implementation for filtering AttackResults by labels + stored directly on the AttackResultEntry. Uses JSON_VALUE() function specific to SQL Azure with parameterized queries. @@ -459,47 +457,23 @@ def _get_attack_result_label_condition(self, *, labels: dict[str, str]) -> Any: Returns: Any: SQLAlchemy condition with bound parameters. """ - # --- Direct match on AttackResultEntry.labels --- - ar_label_conditions = [] - ar_bindparams: dict[str, str] = {} + # Build JSON conditions for all labels with parameterized queries + label_conditions = [] + bindparams_dict = {} for i, (key, value) in enumerate(labels.items()): - path_param = f"ar_label_path_{i}" - value_param = f"ar_label_val_{i}" - ar_label_conditions.append( - f'JSON_VALUE("AttackResultEntries".labels, :{path_param}) = :{value_param}' - ) - ar_bindparams[path_param] = f"$.{key}" - ar_bindparams[value_param] = str(value) + path_param = f"label_path_{i}" + value_param = f"label_val_{i}" + label_conditions.append(f'JSON_VALUE("AttackResultEntries".labels, :{path_param}) = :{value_param}') + bindparams_dict[path_param] = f"$.{key}" + bindparams_dict[value_param] = str(value) - ar_combined = " AND ".join(ar_label_conditions) - direct_condition = and_( - AttackResultEntry.labels.isnot(None), - text( - f'ISJSON("AttackResultEntries".labels) = 1 AND {ar_combined}' - ).bindparams(**ar_bindparams), - ) - - # --- Conversation-level match on PromptMemoryEntry.labels --- - pme_label_conditions = [] - pme_bindparams: dict[str, str] = {} - for i, (key, value) in enumerate(labels.items()): - param_name = f"pme_label_{key}" - pme_label_conditions.append(f"JSON_VALUE(labels, '$.{key}') = :{param_name}") - pme_bindparams[param_name] = str(value) + combined_conditions = " AND ".join(label_conditions) - pme_combined = " AND ".join(pme_label_conditions) - conversation_condition = exists().where( - and_( - PromptMemoryEntry.conversation_id == AttackResultEntry.conversation_id, - PromptMemoryEntry.labels.isnot(None), - text( - f"ISJSON(labels) = 1 AND {pme_combined}" - ).bindparams(**pme_bindparams), - ) + return and_( + AttackResultEntry.labels.isnot(None), + text(f'ISJSON("AttackResultEntries".labels) = 1 AND {combined_conditions}').bindparams(**bindparams_dict), ) - return or_(direct_condition, conversation_condition) - def get_unique_attack_class_names(self) -> list[str]: """ Azure SQL implementation: extract unique class_name values from diff --git a/pyrit/memory/memory_interface.py b/pyrit/memory/memory_interface.py index 79be5b7c3d..7269266697 100644 --- a/pyrit/memory/memory_interface.py +++ b/pyrit/memory/memory_interface.py @@ -420,7 +420,7 @@ def _get_attack_result_harm_category_condition(self, *, targeted_harm_categories def _get_attack_result_label_condition(self, *, labels: dict[str, str]) -> Any: """ Return a database-specific condition for filtering AttackResults by labels - stored directly on the AttackResultEntry OR on an associated PromptMemoryEntry (via conversation_id). + stored directly on the AttackResultEntry. Args: labels: Dictionary of labels that must ALL be present. diff --git a/pyrit/memory/sqlite_memory.py b/pyrit/memory/sqlite_memory.py index ac8b2319eb..3c4bb287fd 100644 --- a/pyrit/memory/sqlite_memory.py +++ b/pyrit/memory/sqlite_memory.py @@ -10,7 +10,7 @@ from pathlib import Path from typing import Any, Optional, TypeVar, Union, cast -from sqlalchemy import and_, create_engine, exists, func, or_, text +from sqlalchemy import and_, create_engine, func, or_, text from sqlalchemy.engine.base import Engine from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.orm import InstrumentedAttribute, joinedload, sessionmaker @@ -613,29 +613,18 @@ def _get_attack_result_harm_category_condition(self, *, targeted_harm_categories def _get_attack_result_label_condition(self, *, labels: dict[str, str]) -> Any: """ - SQLite implementation for filtering AttackResults by labels. - - Matches if the labels are found on the AttackResultEntry directly - OR on an associated PromptMemoryEntry (via conversation_id). + SQLite implementation for filtering AttackResults by labels + stored directly on the AttackResultEntry. + Uses json_extract() function specific to SQLite. Returns: Any: A SQLAlchemy condition for filtering by labels. """ - direct_condition = and_( + return and_( AttackResultEntry.labels.isnot(None), *[func.json_extract(AttackResultEntry.labels, f"$.{key}") == value for key, value in labels.items()], ) - conversation_condition = exists().where( - and_( - PromptMemoryEntry.conversation_id == AttackResultEntry.conversation_id, - PromptMemoryEntry.labels.isnot(None), - *[func.json_extract(PromptMemoryEntry.labels, f"$.{key}") == value for key, value in labels.items()], - ) - ) - - return or_(direct_condition, conversation_condition) - def get_unique_attack_class_names(self) -> list[str]: """ SQLite implementation: extract unique class_name values from From dc67bc62040a20e9f27c9f4e88aa4ac934cd9f41 Mon Sep 17 00:00:00 2001 From: Behnam Ousat Date: Tue, 21 Apr 2026 13:52:53 -0700 Subject: [PATCH 11/23] Revert "format" This reverts commit 4ec494a67867d6bccf1da5eb4d21ccc98d5e92db. --- pyrit/memory/azure_sql_memory.py | 8 +++++-- tests/unit/backend/test_attack_service.py | 4 +++- tests/unit/backend/test_mappers.py | 9 ++++--- .../test_interface_attack_results.py | 24 +++++++++---------- 4 files changed, 25 insertions(+), 20 deletions(-) diff --git a/pyrit/memory/azure_sql_memory.py b/pyrit/memory/azure_sql_memory.py index 64ff020416..15586f7152 100644 --- a/pyrit/memory/azure_sql_memory.py +++ b/pyrit/memory/azure_sql_memory.py @@ -463,7 +463,9 @@ def _get_attack_result_label_condition(self, *, labels: dict[str, str]) -> Any: for i, (key, value) in enumerate(labels.items()): path_param = f"label_path_{i}" value_param = f"label_val_{i}" - label_conditions.append(f'JSON_VALUE("AttackResultEntries".labels, :{path_param}) = :{value_param}') + label_conditions.append( + f'JSON_VALUE("AttackResultEntries".labels, :{path_param}) = :{value_param}' + ) bindparams_dict[path_param] = f"$.{key}" bindparams_dict[value_param] = str(value) @@ -471,7 +473,9 @@ def _get_attack_result_label_condition(self, *, labels: dict[str, str]) -> Any: return and_( AttackResultEntry.labels.isnot(None), - text(f'ISJSON("AttackResultEntries".labels) = 1 AND {combined_conditions}').bindparams(**bindparams_dict), + text( + f'ISJSON("AttackResultEntries".labels) = 1 AND {combined_conditions}' + ).bindparams(**bindparams_dict), ) def get_unique_attack_class_names(self) -> list[str]: diff --git a/tests/unit/backend/test_attack_service.py b/tests/unit/backend/test_attack_service.py index c494e29944..e62ac73d63 100644 --- a/tests/unit/backend/test_attack_service.py +++ b/tests/unit/backend/test_attack_service.py @@ -95,7 +95,9 @@ def make_attack_result( "created_at": created.isoformat(), "updated_at": updated.isoformat(), }, - labels={"test_ar_label": "test_ar_value"}, + labels={ + "test_ar_label": "test_ar_value" + }, ) diff --git a/tests/unit/backend/test_mappers.py b/tests/unit/backend/test_mappers.py index ad6ef1a380..f7c2495c71 100644 --- a/tests/unit/backend/test_mappers.py +++ b/tests/unit/backend/test_mappers.py @@ -81,7 +81,9 @@ def _make_attack_result( "created_at": now.isoformat(), "updated_at": now.isoformat(), }, - labels={"test_ar_label": "test_ar_value"}, + labels={ + "test_ar_label": "test_ar_value" + }, ) @@ -189,10 +191,7 @@ def test_labels_passed_through_without_normalization(self) -> None: summary = attack_result_to_summary(ar, stats=stats) assert summary.labels == { - "operator": "alice", - "operation": "op_red", - "env": "prod", - "test_ar_label": "test_ar_value", + "operator": "alice", "operation": "op_red", "env": "prod", "test_ar_label": "test_ar_value" } def test_conversation_labels_take_precedence_on_collision(self) -> None: diff --git a/tests/unit/memory/memory_interface/test_interface_attack_results.py b/tests/unit/memory/memory_interface/test_interface_attack_results.py index 7855e9ecfc..a34fbdbec7 100644 --- a/tests/unit/memory/memory_interface/test_interface_attack_results.py +++ b/tests/unit/memory/memory_interface/test_interface_attack_results.py @@ -814,21 +814,15 @@ def test_get_attack_results_by_labels_multiple(sqlite_instance: MemoryInterface) # Create attack results with multiple labels attack_results = [ create_attack_result( - "conv_1", - 1, - AttackOutcome.SUCCESS, + "conv_1", 1, AttackOutcome.SUCCESS, labels={"operation": "test_op", "operator": "roakey", "phase": "initial"}, ), create_attack_result( - "conv_2", - 2, - AttackOutcome.SUCCESS, + "conv_2", 2, AttackOutcome.SUCCESS, labels={"operation": "test_op", "operator": "roakey", "phase": "final"}, ), create_attack_result( - "conv_3", - 3, - AttackOutcome.FAILURE, + "conv_3", 3, AttackOutcome.FAILURE, labels={"operation": "test_op", "phase": "initial"}, ), ] @@ -858,9 +852,15 @@ def test_get_attack_results_by_harm_category_and_labels(sqlite_instance: MemoryI # Create attack results with labels attack_results = [ - create_attack_result("conv_1", 1, AttackOutcome.SUCCESS, labels={"operation": "test_op", "operator": "roakey"}), - create_attack_result("conv_2", 2, AttackOutcome.SUCCESS, labels={"operation": "test_op", "operator": "roakey"}), - create_attack_result("conv_3", 3, AttackOutcome.FAILURE, labels={"operation": "other_op", "operator": "bob"}), + create_attack_result( + "conv_1", 1, AttackOutcome.SUCCESS, labels={"operation": "test_op", "operator": "roakey"} + ), + create_attack_result( + "conv_2", 2, AttackOutcome.SUCCESS, labels={"operation": "test_op", "operator": "roakey"} + ), + create_attack_result( + "conv_3", 3, AttackOutcome.FAILURE, labels={"operation": "other_op", "operator": "bob"} + ), ] sqlite_instance.add_attack_results_to_memory(attack_results=attack_results) From 543a582fcbd9d1d7b8f078fae74c6ea3a435b06e Mon Sep 17 00:00:00 2001 From: Behnam Ousat Date: Tue, 21 Apr 2026 13:53:16 -0700 Subject: [PATCH 12/23] Revert "Add labels to attack results" This reverts commit b63361c44359a11a66daabf5e7f202b926817cb4. --- pyrit/backend/mappers/attack_mappers.py | 5 +- pyrit/backend/services/attack_service.py | 1 - .../attack/multi_turn/chunked_request.py | 1 - pyrit/executor/attack/multi_turn/crescendo.py | 1 - .../attack/multi_turn/multi_prompt_sending.py | 1 - .../executor/attack/multi_turn/red_teaming.py | 1 - .../attack/multi_turn/tree_of_attacks.py | 1 - .../attack/single_turn/prompt_sending.py | 1 - .../attack/single_turn/skeleton_key.py | 1 - pyrit/executor/benchmark/fairness_bias.py | 1 - pyrit/memory/azure_sql_memory.py | 28 ++--- pyrit/memory/memory_interface.py | 8 +- pyrit/memory/memory_models.py | 4 - pyrit/memory/sqlite_memory.py | 21 +++- pyrit/models/attack_result.py | 3 - tests/unit/backend/test_attack_service.py | 5 +- tests/unit/backend/test_mappers.py | 22 +--- .../test_interface_attack_results.py | 117 ++++++++++-------- tests/unit/scenario/test_scenario.py | 1 - 19 files changed, 101 insertions(+), 122 deletions(-) diff --git a/pyrit/backend/mappers/attack_mappers.py b/pyrit/backend/mappers/attack_mappers.py index c37dd77fd9..0245e2af12 100644 --- a/pyrit/backend/mappers/attack_mappers.py +++ b/pyrit/backend/mappers/attack_mappers.py @@ -197,11 +197,8 @@ def attack_result_to_summary( """ message_count = stats.message_count last_preview = stats.last_message_preview + labels = dict(stats.labels) if stats.labels else {} - # Merge attack-result labels with conversation-level labels. - # Conversation labels take precedence on key collision. - labels = dict(ar.labels) if ar.labels else {} - labels.update(stats.labels or {}) created_str = ar.metadata.get("created_at") updated_str = ar.metadata.get("updated_at") created_at = datetime.fromisoformat(created_str) if created_str else datetime.now(timezone.utc) diff --git a/pyrit/backend/services/attack_service.py b/pyrit/backend/services/attack_service.py index 4105bb68aa..5c8ea20b6d 100644 --- a/pyrit/backend/services/attack_service.py +++ b/pyrit/backend/services/attack_service.py @@ -308,7 +308,6 @@ async def create_attack_async(self, *, request: CreateAttackRequest) -> CreateAt "created_at": now.isoformat(), "updated_at": now.isoformat(), }, - labels=labels, ) # Store in memory diff --git a/pyrit/executor/attack/multi_turn/chunked_request.py b/pyrit/executor/attack/multi_turn/chunked_request.py index ed95c5d226..1a70c89195 100644 --- a/pyrit/executor/attack/multi_turn/chunked_request.py +++ b/pyrit/executor/attack/multi_turn/chunked_request.py @@ -325,7 +325,6 @@ async def _perform_async(self, *, context: ChunkedRequestAttackContext) -> Attac outcome_reason=outcome_reason, executed_turns=context.executed_turns, metadata={"combined_chunks": combined_value, "chunk_count": len(context.chunk_responses)}, - labels=context.memory_labels, ) def _determine_attack_outcome( diff --git a/pyrit/executor/attack/multi_turn/crescendo.py b/pyrit/executor/attack/multi_turn/crescendo.py index f137b322f3..4a180d5df3 100644 --- a/pyrit/executor/attack/multi_turn/crescendo.py +++ b/pyrit/executor/attack/multi_turn/crescendo.py @@ -402,7 +402,6 @@ async def _perform_async(self, *, context: CrescendoAttackContext) -> CrescendoA last_response=context.last_response.get_piece() if context.last_response else None, last_score=context.last_score, related_conversations=context.related_conversations, - labels=context.memory_labels, ) # setting metadata for backtrack count result.backtrack_count = context.backtrack_count diff --git a/pyrit/executor/attack/multi_turn/multi_prompt_sending.py b/pyrit/executor/attack/multi_turn/multi_prompt_sending.py index 8447737578..a9d4b75adc 100644 --- a/pyrit/executor/attack/multi_turn/multi_prompt_sending.py +++ b/pyrit/executor/attack/multi_turn/multi_prompt_sending.py @@ -295,7 +295,6 @@ async def _perform_async(self, *, context: MultiTurnAttackContext[Any]) -> Attac outcome=outcome, outcome_reason=outcome_reason, executed_turns=context.executed_turns, - labels=context.memory_labels, ) def _determine_attack_outcome( diff --git a/pyrit/executor/attack/multi_turn/red_teaming.py b/pyrit/executor/attack/multi_turn/red_teaming.py index 1feec20586..a8778f664a 100644 --- a/pyrit/executor/attack/multi_turn/red_teaming.py +++ b/pyrit/executor/attack/multi_turn/red_teaming.py @@ -322,7 +322,6 @@ async def _perform_async(self, *, context: MultiTurnAttackContext[Any]) -> Attac last_response=context.last_response.get_piece() if context.last_response else None, last_score=context.last_score, related_conversations=context.related_conversations, - labels=context.memory_labels, ) async def _teardown_async(self, *, context: MultiTurnAttackContext[Any]) -> None: diff --git a/pyrit/executor/attack/multi_turn/tree_of_attacks.py b/pyrit/executor/attack/multi_turn/tree_of_attacks.py index f6ccc4ed64..e92bd1cf67 100644 --- a/pyrit/executor/attack/multi_turn/tree_of_attacks.py +++ b/pyrit/executor/attack/multi_turn/tree_of_attacks.py @@ -2082,7 +2082,6 @@ def _create_attack_result( last_response=last_response, last_score=context.best_objective_score, related_conversations=context.related_conversations, - labels=context.memory_labels, ) # Set attack-specific metadata using properties diff --git a/pyrit/executor/attack/single_turn/prompt_sending.py b/pyrit/executor/attack/single_turn/prompt_sending.py index cdb2d4b619..07f1d670fa 100644 --- a/pyrit/executor/attack/single_turn/prompt_sending.py +++ b/pyrit/executor/attack/single_turn/prompt_sending.py @@ -238,7 +238,6 @@ async def _perform_async(self, *, context: SingleTurnAttackContext[Any]) -> Atta outcome=outcome, outcome_reason=outcome_reason, executed_turns=1, - labels=context.memory_labels, ) def _determine_attack_outcome( diff --git a/pyrit/executor/attack/single_turn/skeleton_key.py b/pyrit/executor/attack/single_turn/skeleton_key.py index 40cc5cc302..683614dce5 100644 --- a/pyrit/executor/attack/single_turn/skeleton_key.py +++ b/pyrit/executor/attack/single_turn/skeleton_key.py @@ -181,5 +181,4 @@ def _create_skeleton_key_failure_result(self, *, context: SingleTurnAttackContex outcome=AttackOutcome.FAILURE, outcome_reason="Skeleton key prompt was filtered or failed", executed_turns=1, - labels=context.memory_labels, ) diff --git a/pyrit/executor/benchmark/fairness_bias.py b/pyrit/executor/benchmark/fairness_bias.py index 63d33f4639..05bb424c17 100644 --- a/pyrit/executor/benchmark/fairness_bias.py +++ b/pyrit/executor/benchmark/fairness_bias.py @@ -200,7 +200,6 @@ async def _perform_async(self, *, context: FairnessBiasBenchmarkContext) -> Atta atomic_attack_identifier=build_atomic_attack_identifier( attack_identifier=ComponentIdentifier.of(self), ), - labels=context.memory_labels, ) return last_attack_result diff --git a/pyrit/memory/azure_sql_memory.py b/pyrit/memory/azure_sql_memory.py index 15586f7152..207a7f7f98 100644 --- a/pyrit/memory/azure_sql_memory.py +++ b/pyrit/memory/azure_sql_memory.py @@ -446,8 +446,7 @@ def _get_attack_result_harm_category_condition(self, *, targeted_harm_categories def _get_attack_result_label_condition(self, *, labels: dict[str, str]) -> Any: """ - Get the SQL Azure implementation for filtering AttackResults by labels - stored directly on the AttackResultEntry. + Get the SQL Azure implementation for filtering AttackResults by labels. Uses JSON_VALUE() function specific to SQL Azure with parameterized queries. @@ -455,27 +454,24 @@ def _get_attack_result_label_condition(self, *, labels: dict[str, str]) -> Any: labels (dict[str, str]): Dictionary of label key-value pairs to filter by. Returns: - Any: SQLAlchemy condition with bound parameters. + Any: SQLAlchemy exists subquery condition with bound parameters. """ # Build JSON conditions for all labels with parameterized queries label_conditions = [] bindparams_dict = {} - for i, (key, value) in enumerate(labels.items()): - path_param = f"label_path_{i}" - value_param = f"label_val_{i}" - label_conditions.append( - f'JSON_VALUE("AttackResultEntries".labels, :{path_param}) = :{value_param}' - ) - bindparams_dict[path_param] = f"$.{key}" - bindparams_dict[value_param] = str(value) + for key, value in labels.items(): + param_name = f"label_{key}" + label_conditions.append(f"JSON_VALUE(labels, '$.{key}') = :{param_name}") + bindparams_dict[param_name] = str(value) combined_conditions = " AND ".join(label_conditions) - return and_( - AttackResultEntry.labels.isnot(None), - text( - f'ISJSON("AttackResultEntries".labels) = 1 AND {combined_conditions}' - ).bindparams(**bindparams_dict), + return exists().where( + and_( + PromptMemoryEntry.conversation_id == AttackResultEntry.conversation_id, + PromptMemoryEntry.labels.isnot(None), + text(f"ISJSON(labels) = 1 AND {combined_conditions}").bindparams(**bindparams_dict), + ) ) def get_unique_attack_class_names(self) -> list[str]: diff --git a/pyrit/memory/memory_interface.py b/pyrit/memory/memory_interface.py index 7269266697..9bf33031f9 100644 --- a/pyrit/memory/memory_interface.py +++ b/pyrit/memory/memory_interface.py @@ -420,7 +420,7 @@ def _get_attack_result_harm_category_condition(self, *, targeted_harm_categories def _get_attack_result_label_condition(self, *, labels: dict[str, str]) -> Any: """ Return a database-specific condition for filtering AttackResults by labels - stored directly on the AttackResultEntry. + in the associated PromptMemoryEntry records. Args: labels: Dictionary of labels that must ALL be present. @@ -1481,9 +1481,9 @@ def get_attack_results( not necessarily one(s) that were found in the response. By providing a list, this means ALL categories in the list must be present. Defaults to None. - labels (Optional[dict[str, str]], optional): A dictionary of labels to filter results by. - These labels are stored directly on the AttackResult. All specified key-value pairs - must be present (AND logic). Defaults to None. + labels (Optional[dict[str, str]], optional): A dictionary of memory labels to filter results by. + These labels are associated with the prompts themselves, used for custom tagging and tracking. + Defaults to None. identifier_filters (Optional[Sequence[IdentifierFilter]], optional): A sequence of IdentifierFilter objects that allows filtering by various attack identifier JSON properties. Defaults to None. diff --git a/pyrit/memory/memory_models.py b/pyrit/memory/memory_models.py index 60fb1dca64..b34c906af6 100644 --- a/pyrit/memory/memory_models.py +++ b/pyrit/memory/memory_models.py @@ -697,7 +697,6 @@ class AttackResultEntry(Base): outcome (AttackOutcome): The outcome of the attack, indicating success, failure, or undetermined. outcome_reason (str): Optional reason for the outcome, providing additional context. attack_metadata (dict[str, Any]): Metadata can be included as key-value pairs to provide extra context. - labels (dict[str, str]): Optional labels associated with the attack result entry. pruned_conversation_ids (List[str]): List of conversation IDs that were pruned from the attack. adversarial_chat_conversation_ids (List[str]): List of conversation IDs used for adversarial chat. timestamp (DateTime): The timestamp of the attack result entry. @@ -729,7 +728,6 @@ class AttackResultEntry(Base): ) outcome_reason = mapped_column(String, nullable=True) attack_metadata: Mapped[dict[str, Union[str, int, float, bool]]] = mapped_column(JSON, nullable=True) - labels: Mapped[dict[str, str]] = mapped_column(JSON, nullable=True) pruned_conversation_ids: Mapped[Optional[list[str]]] = mapped_column(JSON, nullable=True) adversarial_chat_conversation_ids: Mapped[Optional[list[str]]] = mapped_column(JSON, nullable=True) timestamp = mapped_column(DateTime, nullable=False) @@ -785,7 +783,6 @@ def __init__(self, *, entry: AttackResult): self.outcome = entry.outcome.value self.outcome_reason = entry.outcome_reason self.attack_metadata = self.filter_json_serializable_metadata(entry.metadata) - self.labels = entry.labels or {} # Persist conversation references by type self.pruned_conversation_ids = [ @@ -897,7 +894,6 @@ def get_attack_result(self) -> AttackResult: outcome_reason=self.outcome_reason, related_conversations=related_conversations, metadata=self.attack_metadata or {}, - labels=self.labels or {}, ) diff --git a/pyrit/memory/sqlite_memory.py b/pyrit/memory/sqlite_memory.py index 3c4bb287fd..bd376d67cd 100644 --- a/pyrit/memory/sqlite_memory.py +++ b/pyrit/memory/sqlite_memory.py @@ -613,17 +613,26 @@ def _get_attack_result_harm_category_condition(self, *, targeted_harm_categories def _get_attack_result_label_condition(self, *, labels: dict[str, str]) -> Any: """ - SQLite implementation for filtering AttackResults by labels - stored directly on the AttackResultEntry. + SQLite implementation for filtering AttackResults by labels. Uses json_extract() function specific to SQLite. Returns: - Any: A SQLAlchemy condition for filtering by labels. + Any: A SQLAlchemy subquery for filtering by labels. """ - return and_( - AttackResultEntry.labels.isnot(None), - *[func.json_extract(AttackResultEntry.labels, f"$.{key}") == value for key, value in labels.items()], + from sqlalchemy import and_, exists, func + + from pyrit.memory.memory_models import AttackResultEntry, PromptMemoryEntry + + labels_subquery = exists().where( + and_( + PromptMemoryEntry.conversation_id == AttackResultEntry.conversation_id, + PromptMemoryEntry.labels.isnot(None), + and_( + *[func.json_extract(PromptMemoryEntry.labels, f"$.{key}") == value for key, value in labels.items()] + ), + ) ) + return labels_subquery # noqa: RET504 def get_unique_attack_class_names(self) -> list[str]: """ diff --git a/pyrit/models/attack_result.py b/pyrit/models/attack_result.py index 5cbdf3c93e..a385ac36e7 100644 --- a/pyrit/models/attack_result.py +++ b/pyrit/models/attack_result.py @@ -87,9 +87,6 @@ class AttackResult(StrategyResult): # Arbitrary metadata metadata: dict[str, Any] = field(default_factory=dict) - # labels associated with this attack result - labels: dict[str, str] = field(default_factory=dict) - @property def attack_identifier(self) -> Optional[ComponentIdentifier]: """ diff --git a/tests/unit/backend/test_attack_service.py b/tests/unit/backend/test_attack_service.py index e62ac73d63..3d09fe6b7f 100644 --- a/tests/unit/backend/test_attack_service.py +++ b/tests/unit/backend/test_attack_service.py @@ -95,9 +95,6 @@ def make_attack_result( "created_at": created.isoformat(), "updated_at": updated.isoformat(), }, - labels={ - "test_ar_label": "test_ar_value" - }, ) @@ -324,7 +321,7 @@ async def test_list_attacks_includes_labels_in_summary(self, attack_service, moc result = await attack_service.list_attacks_async() assert len(result.items) == 1 - assert result.items[0].labels == {"env": "prod", "team": "red", "test_ar_label": "test_ar_value"} + assert result.items[0].labels == {"env": "prod", "team": "red"} @pytest.mark.asyncio async def test_list_attacks_filters_by_labels_directly(self, attack_service, mock_memory) -> None: diff --git a/tests/unit/backend/test_mappers.py b/tests/unit/backend/test_mappers.py index f7c2495c71..0f483b3f10 100644 --- a/tests/unit/backend/test_mappers.py +++ b/tests/unit/backend/test_mappers.py @@ -81,9 +81,6 @@ def _make_attack_result( "created_at": now.isoformat(), "updated_at": now.isoformat(), }, - labels={ - "test_ar_label": "test_ar_value" - }, ) @@ -178,7 +175,7 @@ def test_labels_are_mapped(self) -> None: summary = attack_result_to_summary(ar, stats=stats) - assert summary.labels == {"env": "prod", "team": "red", "test_ar_label": "test_ar_value"} + assert summary.labels == {"env": "prod", "team": "red"} def test_labels_passed_through_without_normalization(self) -> None: """Test that labels are passed through as-is (DB stores canonical keys after migration).""" @@ -190,21 +187,7 @@ def test_labels_passed_through_without_normalization(self) -> None: summary = attack_result_to_summary(ar, stats=stats) - assert summary.labels == { - "operator": "alice", "operation": "op_red", "env": "prod", "test_ar_label": "test_ar_value" - } - - def test_conversation_labels_take_precedence_on_collision(self) -> None: - """Test that conversation-level labels override attack-result labels on key collision.""" - ar = _make_attack_result() - stats = ConversationStats( - message_count=1, - labels={"test_ar_label": "conversation_wins"}, - ) - - summary = attack_result_to_summary(ar, stats=stats) - - assert summary.labels["test_ar_label"] == "conversation_wins" + assert summary.labels == {"operator": "alice", "operation": "op_red", "env": "prod"} def test_outcome_success(self) -> None: """Test that success outcome is mapped.""" @@ -266,7 +249,6 @@ def test_converters_extracted_from_identifier(self) -> None: ), outcome=AttackOutcome.UNDETERMINED, metadata={"created_at": now.isoformat(), "updated_at": now.isoformat()}, - labels={"test_label": "test_value"}, ) summary = attack_result_to_summary(ar, stats=ConversationStats(message_count=0)) diff --git a/tests/unit/memory/memory_interface/test_interface_attack_results.py b/tests/unit/memory/memory_interface/test_interface_attack_results.py index a34fbdbec7..6a3857c570 100644 --- a/tests/unit/memory/memory_interface/test_interface_attack_results.py +++ b/tests/unit/memory/memory_interface/test_interface_attack_results.py @@ -36,18 +36,12 @@ def create_message_piece(conversation_id: str, prompt_num: int, targeted_harm_ca ) -def create_attack_result( - conversation_id: str, - objective_num: int, - outcome: AttackOutcome = AttackOutcome.SUCCESS, - labels: dict[str, str] | None = None, -): +def create_attack_result(conversation_id: str, objective_num: int, outcome: AttackOutcome = AttackOutcome.SUCCESS): """Helper function to create AttackResult.""" return AttackResult( conversation_id=conversation_id, objective=f"Objective {objective_num}", outcome=outcome, - labels=labels or {}, ) @@ -786,14 +780,17 @@ def test_get_attack_results_by_harm_category_multiple(sqlite_instance: MemoryInt def test_get_attack_results_by_labels_single(sqlite_instance: MemoryInterface): """Test filtering attack results by single label.""" - # Create attack results with labels - attack_result1 = create_attack_result( - "conv_1", 1, AttackOutcome.SUCCESS, labels={"operation": "test_op", "operator": "roakey"} - ) - attack_result2 = create_attack_result("conv_2", 2, AttackOutcome.FAILURE, labels={"operation": "test_op"}) - attack_result3 = create_attack_result( - "conv_3", 3, AttackOutcome.SUCCESS, labels={"operation": "other_op", "operator": "roakey"} - ) + # Create message pieces with labels + message_piece1 = create_message_piece("conv_1", 1, labels={"operation": "test_op", "operator": "roakey"}) + message_piece2 = create_message_piece("conv_2", 2, labels={"operation": "test_op"}) + message_piece3 = create_message_piece("conv_3", 3, labels={"operation": "other_op", "operator": "roakey"}) + + sqlite_instance.add_message_pieces_to_memory(message_pieces=[message_piece1, message_piece2, message_piece3]) + + # Create attack results + attack_result1 = create_attack_result("conv_1", 1, AttackOutcome.SUCCESS) + attack_result2 = create_attack_result("conv_2", 2, AttackOutcome.FAILURE) + attack_result3 = create_attack_result("conv_3", 3, AttackOutcome.SUCCESS) sqlite_instance.add_attack_results_to_memory(attack_results=[attack_result1, attack_result2, attack_result3]) @@ -811,20 +808,22 @@ def test_get_attack_results_by_labels_single(sqlite_instance: MemoryInterface): def test_get_attack_results_by_labels_multiple(sqlite_instance: MemoryInterface): """Test filtering attack results by multiple labels (AND logic).""" - # Create attack results with multiple labels + # Create message pieces with multiple labels using helper function + message_piece1 = create_message_piece( + "conv_1", 1, labels={"operation": "test_op", "operator": "roakey", "phase": "initial"} + ) + message_piece2 = create_message_piece( + "conv_2", 2, labels={"operation": "test_op", "operator": "roakey", "phase": "final"} + ) + message_piece3 = create_message_piece("conv_3", 3, labels={"operation": "test_op", "phase": "initial"}) + + sqlite_instance.add_message_pieces_to_memory(message_pieces=[message_piece1, message_piece2, message_piece3]) + + # Create attack results attack_results = [ - create_attack_result( - "conv_1", 1, AttackOutcome.SUCCESS, - labels={"operation": "test_op", "operator": "roakey", "phase": "initial"}, - ), - create_attack_result( - "conv_2", 2, AttackOutcome.SUCCESS, - labels={"operation": "test_op", "operator": "roakey", "phase": "final"}, - ), - create_attack_result( - "conv_3", 3, AttackOutcome.FAILURE, - labels={"operation": "test_op", "phase": "initial"}, - ), + create_attack_result("conv_1", 1, AttackOutcome.SUCCESS), + create_attack_result("conv_2", 2, AttackOutcome.SUCCESS), + create_attack_result("conv_3", 3, AttackOutcome.FAILURE), ] sqlite_instance.add_attack_results_to_memory(attack_results=attack_results) @@ -843,24 +842,30 @@ def test_get_attack_results_by_labels_multiple(sqlite_instance: MemoryInterface) def test_get_attack_results_by_harm_category_and_labels(sqlite_instance: MemoryInterface): """Test filtering attack results by both harm categories and labels.""" - # Create message pieces with harm categories (harm categories still live on PromptMemoryEntry) - message_piece1 = create_message_piece("conv_1", 1, targeted_harm_categories=["violence", "illegal"]) - message_piece2 = create_message_piece("conv_2", 2, targeted_harm_categories=["violence"]) - message_piece3 = create_message_piece("conv_3", 3, targeted_harm_categories=["violence", "illegal"]) + # Create message pieces with both harm categories and labels using helper function + message_piece1 = create_message_piece( + "conv_1", + 1, + targeted_harm_categories=["violence", "illegal"], + labels={"operation": "test_op", "operator": "roakey"}, + ) + message_piece2 = create_message_piece( + "conv_2", 2, targeted_harm_categories=["violence"], labels={"operation": "test_op", "operator": "roakey"} + ) + message_piece3 = create_message_piece( + "conv_3", + 3, + targeted_harm_categories=["violence", "illegal"], + labels={"operation": "other_op", "operator": "bob"}, + ) sqlite_instance.add_message_pieces_to_memory(message_pieces=[message_piece1, message_piece2, message_piece3]) - # Create attack results with labels + # Create attack results attack_results = [ - create_attack_result( - "conv_1", 1, AttackOutcome.SUCCESS, labels={"operation": "test_op", "operator": "roakey"} - ), - create_attack_result( - "conv_2", 2, AttackOutcome.SUCCESS, labels={"operation": "test_op", "operator": "roakey"} - ), - create_attack_result( - "conv_3", 3, AttackOutcome.FAILURE, labels={"operation": "other_op", "operator": "bob"} - ), + create_attack_result("conv_1", 1, AttackOutcome.SUCCESS), + create_attack_result("conv_2", 2, AttackOutcome.SUCCESS), + create_attack_result("conv_3", 3, AttackOutcome.FAILURE), ] sqlite_instance.add_attack_results_to_memory(attack_results=attack_results) @@ -899,8 +904,11 @@ def test_get_attack_results_harm_category_no_matches(sqlite_instance: MemoryInte def test_get_attack_results_labels_no_matches(sqlite_instance: MemoryInterface): """Test filtering by labels that don't exist.""" - # Create attack result with labels that don't match the search - attack_result = create_attack_result("conv_1", 1, AttackOutcome.SUCCESS, labels={"operation": "test_op"}) + # Create attack result without the labels we'll search for + message_piece = create_message_piece("conv_1", 1, labels={"operation": "test_op"}) + sqlite_instance.add_message_pieces_to_memory(message_pieces=[message_piece]) + + attack_result = create_attack_result("conv_1", 1, AttackOutcome.SUCCESS) sqlite_instance.add_attack_results_to_memory(attack_results=[attack_result]) # Search for non-existent labels @@ -912,6 +920,11 @@ def test_get_attack_results_labels_query_on_empty_labels(sqlite_instance: Memory """Test querying for labels when records have no labels at all""" # Create attack results with NO labels + message_piece1 = create_message_piece("conv_1", 1) + message_piece2 = create_message_piece("conv_2", 1) + + sqlite_instance.add_message_pieces_to_memory(message_pieces=[message_piece1, message_piece2]) + attack_result1 = create_attack_result("conv_1", 1, AttackOutcome.SUCCESS) attack_result2 = create_attack_result("conv_2", 2, AttackOutcome.FAILURE) @@ -931,14 +944,16 @@ def test_get_attack_results_labels_key_exists_value_mismatch(sqlite_instance: Me """Test querying for labels where the key exists but the value doesn't match.""" # Create attack results with specific label values + message_piece1 = create_message_piece("conv_1", 1, labels={"operation": "op_exists", "researcher": "roakey"}) + message_piece2 = create_message_piece("conv_2", 1, labels={"operation": "another_op", "researcher": "roakey"}) + message_piece3 = create_message_piece("conv_3", 1, labels={"operation": "test_op"}) + + sqlite_instance.add_message_pieces_to_memory(message_pieces=[message_piece1, message_piece2, message_piece3]) + attack_results = [ - create_attack_result( - "conv_1", 1, AttackOutcome.SUCCESS, labels={"operation": "op_exists", "researcher": "roakey"} - ), - create_attack_result( - "conv_2", 2, AttackOutcome.SUCCESS, labels={"operation": "another_op", "researcher": "roakey"} - ), - create_attack_result("conv_3", 3, AttackOutcome.FAILURE, labels={"operation": "test_op"}), + create_attack_result("conv_1", 1, AttackOutcome.SUCCESS), + create_attack_result("conv_2", 2, AttackOutcome.SUCCESS), + create_attack_result("conv_3", 3, AttackOutcome.FAILURE), ] sqlite_instance.add_attack_results_to_memory(attack_results=attack_results) diff --git a/tests/unit/scenario/test_scenario.py b/tests/unit/scenario/test_scenario.py index 6e0f06eb71..d1a8505c81 100644 --- a/tests/unit/scenario/test_scenario.py +++ b/tests/unit/scenario/test_scenario.py @@ -85,7 +85,6 @@ def sample_attack_results(): objective=f"objective{i}", outcome=AttackOutcome.SUCCESS, executed_turns=1, - labels={"test_label": f"value{i}"}, ) for i in range(5) ] From 568ab07983989900160945c8706a660846e06bb8 Mon Sep 17 00:00:00 2001 From: Behnam Ousat Date: Tue, 21 Apr 2026 14:26:40 -0700 Subject: [PATCH 13/23] rename functions --- pyrit/memory/azure_sql_memory.py | 6 +++--- pyrit/memory/memory_interface.py | 2 +- pyrit/memory/migration.py | 8 +++++--- pyrit/memory/sqlite_memory.py | 6 +++--- tests/unit/memory/test_sqlite_memory.py | 6 +++--- 5 files changed, 15 insertions(+), 13 deletions(-) diff --git a/pyrit/memory/azure_sql_memory.py b/pyrit/memory/azure_sql_memory.py index b15d500f97..e60827b62f 100644 --- a/pyrit/memory/azure_sql_memory.py +++ b/pyrit/memory/azure_sql_memory.py @@ -26,7 +26,7 @@ EmbeddingDataEntry, PromptMemoryEntry, ) -from pyrit.memory.migration import reset_schema +from pyrit.memory.migration import reset_database from pyrit.models import ( AzureBlobStorageIO, ConversationStats, @@ -104,7 +104,7 @@ def __init__( self._enable_azure_authorization() self.SessionFactory = sessionmaker(bind=self.engine) - self._create_tables_if_not_exist() + self._ensure_schema_is_current() super().__init__() @@ -778,4 +778,4 @@ def _update_entries(self, *, entries: MutableSequence[Base], update_fields: dict def reset_database(self) -> None: """Drop and recreate existing tables.""" - reset_schema(engine=self.engine) + reset_database(engine=self.engine) diff --git a/pyrit/memory/memory_interface.py b/pyrit/memory/memory_interface.py index b8e0b5607a..727fd8b4d6 100644 --- a/pyrit/memory/memory_interface.py +++ b/pyrit/memory/memory_interface.py @@ -120,7 +120,7 @@ def disable_embedding(self) -> None: """ self.memory_embedding = None - def _create_tables_if_not_exist(self) -> None: + def _ensure_schema_is_current(self) -> None: """ Upgrade the database schema to the latest Alembic revision. diff --git a/pyrit/memory/migration.py b/pyrit/memory/migration.py index c352abb3eb..be6af50a29 100644 --- a/pyrit/memory/migration.py +++ b/pyrit/memory/migration.py @@ -137,14 +137,16 @@ def run_schema_migrations(*, engine: Engine) -> None: command.upgrade(config, _HEAD_REVISION) -def reset_schema(*, engine: Engine) -> None: +def reset_database(*, engine: Engine) -> None: """ - Recreate the database schema at the latest Alembic revision. + Drop all tables and recreate the database schema at the latest Alembic revision. + + This destroys all existing data. Args: engine (Engine): SQLAlchemy engine bound to the target database. """ - logger.debug("Resetting memory schema using Alembic migrations") + logger.debug("Resetting database using Alembic migrations") with engine.begin() as connection: # Drop version table first (not part of Base.metadata) inspector = inspect(connection) diff --git a/pyrit/memory/sqlite_memory.py b/pyrit/memory/sqlite_memory.py index f2f69789ed..b15331047d 100644 --- a/pyrit/memory/sqlite_memory.py +++ b/pyrit/memory/sqlite_memory.py @@ -28,7 +28,7 @@ PromptMemoryEntry, ScenarioResultEntry, ) -from pyrit.memory.migration import reset_schema +from pyrit.memory.migration import reset_database from pyrit.models import ConversationStats, DiskStorageIO, MessagePiece logger = logging.getLogger(__name__) @@ -81,7 +81,7 @@ def __init__( self.engine = self._create_engine(has_echo=verbose) self.SessionFactory = sessionmaker(bind=self.engine) - self._create_tables_if_not_exist() + self._ensure_schema_is_current() def _init_storage_io(self) -> None: # Handles disk-based storage for SQLite local memory. @@ -428,7 +428,7 @@ def reset_database(self) -> None: """ Drop and recreates all tables in the database. """ - reset_schema(engine=self.engine) + reset_database(engine=self.engine) def dispose_engine(self) -> None: """ diff --git a/tests/unit/memory/test_sqlite_memory.py b/tests/unit/memory/test_sqlite_memory.py index f274607d86..7a34cd3191 100644 --- a/tests/unit/memory/test_sqlite_memory.py +++ b/tests/unit/memory/test_sqlite_memory.py @@ -937,9 +937,9 @@ def test_run_schema_migrations_no_memory_tables(): engine.dispose() -def test_create_tables_if_not_exist_with_schema_migration_exception(): +def test_ensure_schema_is_current_with_schema_migration_exception(): """ - Test that _create_tables_if_not_exist properly handles and re-raises exceptions. + Test that _ensure_schema_is_current properly handles and re-raises exceptions. This tests lines 132-134 in memory_interface.py. """ with tempfile.TemporaryDirectory() as temp_dir: @@ -952,6 +952,6 @@ def test_create_tables_if_not_exist_with_schema_migration_exception(): mock_migrate.side_effect = RuntimeError("Mock migration error") with pytest.raises(RuntimeError, match="Mock migration error"): - memory._create_tables_if_not_exist() + memory._ensure_schema_is_current() finally: memory.dispose_engine() From 7ad42b608ca00e23edb1f2ac9db80dfe8ad8e4dd Mon Sep 17 00:00:00 2001 From: Behnam Ousat Date: Tue, 21 Apr 2026 14:35:48 -0700 Subject: [PATCH 14/23] move reset to interface --- pyrit/memory/azure_sql_memory.py | 5 ----- pyrit/memory/memory_interface.py | 8 +++++++- pyrit/memory/sqlite_memory.py | 7 ------- 3 files changed, 7 insertions(+), 13 deletions(-) diff --git a/pyrit/memory/azure_sql_memory.py b/pyrit/memory/azure_sql_memory.py index e60827b62f..b3987d467d 100644 --- a/pyrit/memory/azure_sql_memory.py +++ b/pyrit/memory/azure_sql_memory.py @@ -26,7 +26,6 @@ EmbeddingDataEntry, PromptMemoryEntry, ) -from pyrit.memory.migration import reset_database from pyrit.models import ( AzureBlobStorageIO, ConversationStats, @@ -775,7 +774,3 @@ def _update_entries(self, *, entries: MutableSequence[Base], update_fields: dict session.rollback() logger.exception(f"Error updating entries: {e}") raise - - def reset_database(self) -> None: - """Drop and recreate existing tables.""" - reset_database(engine=self.engine) diff --git a/pyrit/memory/memory_interface.py b/pyrit/memory/memory_interface.py index 727fd8b4d6..7799105e14 100644 --- a/pyrit/memory/memory_interface.py +++ b/pyrit/memory/memory_interface.py @@ -34,7 +34,7 @@ ScoreEntry, SeedEntry, ) -from pyrit.memory.migration import run_schema_migrations +from pyrit.memory.migration import reset_database, run_schema_migrations from pyrit.models import ( AttackResult, ConversationStats, @@ -965,6 +965,12 @@ def update_prompt_metadata_by_conversation_id( conversation_id=conversation_id, update_fields={"prompt_metadata": prompt_metadata} ) + def reset_database(self) -> None: + """ + Drop and recreate all tables in the database. + """ + reset_database(engine=self.engine) + @abc.abstractmethod def dispose_engine(self) -> None: """ diff --git a/pyrit/memory/sqlite_memory.py b/pyrit/memory/sqlite_memory.py index b15331047d..a3e007c3e2 100644 --- a/pyrit/memory/sqlite_memory.py +++ b/pyrit/memory/sqlite_memory.py @@ -28,7 +28,6 @@ PromptMemoryEntry, ScenarioResultEntry, ) -from pyrit.memory.migration import reset_database from pyrit.models import ConversationStats, DiskStorageIO, MessagePiece logger = logging.getLogger(__name__) @@ -424,12 +423,6 @@ def get_session(self) -> Session: """ return self.SessionFactory() - def reset_database(self) -> None: - """ - Drop and recreates all tables in the database. - """ - reset_database(engine=self.engine) - def dispose_engine(self) -> None: """ Dispose the engine and close all connections. From 8e7f3383bb7425eeaab47930304641ab6e536859 Mon Sep 17 00:00:00 2001 From: Behnam Ousat Date: Tue, 21 Apr 2026 18:45:47 -0700 Subject: [PATCH 15/23] filter for files in precommit migration check --- .pre-commit-config.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index af5ac76bb8..deda27f01e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -30,6 +30,7 @@ repos: entry: python ./build_scripts/memory_migrations.py check language: system pass_filenames: false + files: ^pyrit/memory/(memory_models\.py|alembic/.*|migration\.py)$ - repo: https://github.com/pre-commit/pre-commit-hooks rev: v5.0.0 From ad9628523b531b410d40a799f3b7f99e426bf1af Mon Sep 17 00:00:00 2001 From: Behnam Ousat Date: Tue, 21 Apr 2026 20:40:04 -0700 Subject: [PATCH 16/23] update autogenerated commands comment --- pyrit/memory/alembic/versions/e373726d391b_initial_schema.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyrit/memory/alembic/versions/e373726d391b_initial_schema.py b/pyrit/memory/alembic/versions/e373726d391b_initial_schema.py index 43e42dcd72..935b077e8c 100644 --- a/pyrit/memory/alembic/versions/e373726d391b_initial_schema.py +++ b/pyrit/memory/alembic/versions/e373726d391b_initial_schema.py @@ -58,7 +58,7 @@ def process_result_value(self, value: Any, dialect: Any) -> uuid.UUID | None: def upgrade() -> None: """Apply this schema upgrade.""" - # ### commands auto generated by Alembic - please adjust! ### + # ### commands auto generated by Alembic and reviewed by author### op.create_table( "PromptMemoryEntries", sa.Column("id", _CustomUUID(), nullable=False), @@ -194,7 +194,7 @@ def upgrade() -> None: def downgrade() -> None: """Revert this schema upgrade.""" - # ### commands auto generated by Alembic - please adjust! ### + # ### commands auto generated by Alembic and reviewed by author ### op.drop_table("AttackResultEntries") op.drop_table("ScoreEntries") op.drop_table("EmbeddingData") From f6e932b8acae7bbf970c1199aae65b845f4ecf1c Mon Sep 17 00:00:00 2001 From: Behnam Ousat Date: Wed, 22 Apr 2026 09:09:44 -0700 Subject: [PATCH 17/23] cleanups --- .pyrit_conf_example | 8 ++ build_scripts/memory_migrations.py | 117 +++++++--------------------- pyrit/memory/alembic.ini | 37 --------- pyrit/memory/alembic/env.py | 80 ++++--------------- pyrit/memory/azure_sql_memory.py | 1 - pyrit/memory/migration.py | 102 +++++++++++++++++------- pyrit/memory/sqlite_memory.py | 1 - pyrit/setup/configuration_loader.py | 10 +++ pyrit/setup/initialization.py | 7 ++ 9 files changed, 141 insertions(+), 222 deletions(-) delete mode 100644 pyrit/memory/alembic.ini diff --git a/.pyrit_conf_example b/.pyrit_conf_example index e136b7a913..f66c075d1b 100644 --- a/.pyrit_conf_example +++ b/.pyrit_conf_example @@ -95,6 +95,14 @@ operation: op_trash_panda # - /path/to/.env # - /path/to/.env.local +# Schema Migration Check +# --------------------- +# If true, runs database schema migration on startup to ensure the database +# is up to date with the latest PyRIT version. +# Set to false to skip the check (e.g., for read-only access, testing, or +# when managing migrations externally). +check_schema: true + # Silent Mode # ----------- # If true, suppresses print statements during initialization. diff --git a/build_scripts/memory_migrations.py b/build_scripts/memory_migrations.py index ac1bab40da..bb37584c11 100644 --- a/build_scripts/memory_migrations.py +++ b/build_scripts/memory_migrations.py @@ -6,12 +6,11 @@ import tempfile from pathlib import Path -from alembic import command -from alembic.config import Config from alembic.util.exc import AutogenerateDiffsDetected +from sqlalchemy import create_engine +from sqlalchemy.engine import Engine -_REPO_ROOT = Path(__file__).resolve().parent.parent -_ALEMBIC_INI = _REPO_ROOT / "pyrit" / "memory" / "alembic.ini" +from pyrit.memory.migration import check_schema_migrations, generate_schema_migration, run_schema_migrations # ANSI color codes _RED = "\033[91m" @@ -23,95 +22,46 @@ def _print_error(message: str) -> None: print(f"{_RED}{message}{_RESET}", file=sys.stderr) -def _make_config(*, db_url: str) -> Config: - """ - Build an Alembic Config pointed at a specific database URL. - - Args: - db_url (str): SQLAlchemy database URL to use. - - Returns: - Config: Configured Alembic config object. - """ - cfg = Config(str(_ALEMBIC_INI)) - cfg.set_main_option("sqlalchemy.url", db_url) - return cfg +def _create_temp_engine() -> tuple[Engine, Path]: + """Create a temp SQLite database upgraded to head and return engine and path.""" + with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp: + tmp_path = Path(tmp.name) + engine = create_engine(f"sqlite:///{tmp_path}") + run_schema_migrations(engine=engine) + return engine, tmp_path def _cmd_generate(*, message: str, force: bool = False) -> None: - """ - Generate a new Alembic revision from model changes. - - A fresh temporary SQLite database is created, upgraded to the current head, - and then checked against the current models. If the check passes (schema matches), - no migration is needed and the command fails. If the check fails (schema doesn't match), - a migration is generated. - - Args: - message (str): Human-readable migration message. - force (bool): If True, generate a migration even if schema matches models. - Useful for manual migrations. - - Raises: - SystemExit: If schema matches models and force=False. - """ - with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp: - db_url = f"sqlite:///{tmp.name}" - + """Generate a new Alembic revision from model changes.""" + engine, tmp_path = _create_temp_engine() try: - config = _make_config(db_url=db_url) - command.upgrade(config, "head") - - # Check if schema matches current models - try: - command.check(config) - # If check passes, schema already matches models - if not force: - _print_error("No schema changes detected. Use --force to generate an empty migration.") - raise SystemExit(1) - print("Generating migration even though schema matches models (--force used).") - except AutogenerateDiffsDetected: - # Check failed, meaning schema doesn't match models - proceed with generation - pass - - command.revision(config, autogenerate=True, message=message) + generate_schema_migration(engine=engine, message=message, force=force) print("Migration file generated. Review it carefully before committing.") + except RuntimeError as e: + _print_error(str(e)) + raise SystemExit(1) from e finally: - Path(tmp.name).unlink(missing_ok=True) + engine.dispose() + tmp_path.unlink(missing_ok=True) def _cmd_check() -> None: - """ - Verify that all migration scripts apply cleanly and schema matches models. - - Creates a clean temporary database, applies all migrations to head, - and validates that the resulting schema matches the current memory models. - """ - with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp: - db_url = f"sqlite:///{tmp.name}" - + """Verify all migrations apply cleanly and schema matches models.""" + engine, tmp_path = _create_temp_engine() try: - config = _make_config(db_url=db_url) - command.upgrade(config, "head") - command.check(config) - except Exception as e: - _print_error( - f"Migration check failed. Run the script in 'generate' mode to generate a new migration. Error: {e}" - ) + check_schema_migrations(engine=engine) + except AutogenerateDiffsDetected as e: + _print_error(f"Migration check failed. Run 'generate' to create a migration. Error: {e}") raise SystemExit(1) from e finally: - Path(tmp.name).unlink(missing_ok=True) + engine.dispose() + tmp_path.unlink(missing_ok=True) def _build_parser() -> argparse.ArgumentParser: - """ - Build the CLI argument parser. - - Returns: - argparse.ArgumentParser: Configured parser. - """ + """Build the CLI argument parser.""" parser = argparse.ArgumentParser( - description="PyRIT memory migration tool. Generate and validate migrations against clean databases." + description="PyRIT memory migration tool. Generate and validate migrations based on the current memory models." ) sub = parser.add_subparsers(dest="command", required=True) @@ -119,22 +69,13 @@ def _build_parser() -> argparse.ArgumentParser: gen.add_argument("-m", "--message", required=True, help="Migration message.") gen.add_argument("--force", action="store_true", help="Generate migration even if no changes detected.") - sub.add_parser("check", help="Verify all migrations apply cleanly to a fresh database.") + sub.add_parser("check", help="Verify all migrations apply cleanly and add up to the current memory models.") return parser def main() -> int: - """ - Dispatch the selected migration command. - - Returns: - int: Process exit code. - """ - if not _ALEMBIC_INI.exists(): - _print_error(f"Alembic config not found at {_ALEMBIC_INI}") - return 1 - + """Dispatch the selected migration command.""" args = _build_parser().parse_args() if args.command == "generate": diff --git a/pyrit/memory/alembic.ini b/pyrit/memory/alembic.ini deleted file mode 100644 index 6af71fffe2..0000000000 --- a/pyrit/memory/alembic.ini +++ /dev/null @@ -1,37 +0,0 @@ -# A generic, package-local Alembic config for PyRIT memory schema migrations. - -[alembic] -script_location = pyrit/memory/alembic -prepend_sys_path = . - -[loggers] -keys = root,sqlalchemy,alembic - -[handlers] -keys = console - -[formatters] -keys = generic - -[logger_root] -level = WARN -handlers = console - -[logger_sqlalchemy] -level = WARN -handlers = -qualname = sqlalchemy.engine - -[logger_alembic] -level = INFO -handlers = -qualname = alembic - -[handler_console] -class = StreamHandler -args = (sys.stderr,) -level = NOTSET -formatter = generic - -[formatter_generic] -format = %(levelname)-5.5s [%(name)s] %(message)s diff --git a/pyrit/memory/alembic/env.py b/pyrit/memory/alembic/env.py index 1b77327183..672b62e794 100644 --- a/pyrit/memory/alembic/env.py +++ b/pyrit/memory/alembic/env.py @@ -1,76 +1,24 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -from logging.config import fileConfig -from typing import TYPE_CHECKING - from alembic import context -from sqlalchemy import engine_from_config, pool +from sqlalchemy.engine import Connection from pyrit.memory.memory_models import Base - -if TYPE_CHECKING: - from sqlalchemy.engine import Connection +from pyrit.memory.migration import PYRIT_MEMORY_ALEMBIC_VERSION_TABLE config = context.config - -if config.config_file_name is not None: - fileConfig(config.config_file_name) - +connection: Connection | None = config.attributes.get("connection") target_metadata = Base.metadata -VERSION_TABLE = config.attributes.get("version_table", "pyrit_memory_alembic_version") - - -def run_migrations_offline() -> None: - """Run migrations in offline mode.""" - url = config.get_main_option("sqlalchemy.url") - context.configure( - url=url, - target_metadata=target_metadata, - literal_binds=True, - compare_type=True, - version_table=VERSION_TABLE, - dialect_opts={"paramstyle": "named"}, - ) - - with context.begin_transaction(): - context.run_migrations() - - -def run_migrations_online() -> None: - """Run migrations in online mode.""" - connection: Connection | None = config.attributes.get("connection") - - if connection is not None: - context.configure( - connection=connection, - target_metadata=target_metadata, - compare_type=True, - version_table=VERSION_TABLE, - ) - with context.begin_transaction(): - context.run_migrations() - return - - connectable = engine_from_config( - config.get_section(config.config_ini_section, {}), - prefix="sqlalchemy.", - poolclass=pool.NullPool, - ) - - with connectable.connect() as created_connection: - context.configure( - connection=created_connection, - target_metadata=target_metadata, - compare_type=True, - version_table=VERSION_TABLE, - ) - - with context.begin_transaction(): - context.run_migrations() - -if context.is_offline_mode(): - run_migrations_offline() -else: - run_migrations_online() +if connection is None: + raise RuntimeError("No connection found for Alembic migration") + +context.configure( + connection=connection, + target_metadata=target_metadata, + compare_type=True, + version_table=PYRIT_MEMORY_ALEMBIC_VERSION_TABLE, +) +with context.begin_transaction(): + context.run_migrations() diff --git a/pyrit/memory/azure_sql_memory.py b/pyrit/memory/azure_sql_memory.py index b3987d467d..440eb1b760 100644 --- a/pyrit/memory/azure_sql_memory.py +++ b/pyrit/memory/azure_sql_memory.py @@ -103,7 +103,6 @@ def __init__( self._enable_azure_authorization() self.SessionFactory = sessionmaker(bind=self.engine) - self._ensure_schema_is_current() super().__init__() diff --git a/pyrit/memory/migration.py b/pyrit/memory/migration.py index be6af50a29..b67c631bce 100644 --- a/pyrit/memory/migration.py +++ b/pyrit/memory/migration.py @@ -8,6 +8,7 @@ from alembic.autogenerate.api import compare_metadata from alembic.config import Config from alembic.migration import MigrationContext +from alembic.util.exc import AutogenerateDiffsDetected from sqlalchemy import MetaData, Table, inspect from sqlalchemy.engine import Connection, Engine @@ -15,17 +16,24 @@ logger = logging.getLogger(__name__) -_MEMORY_ALEMBIC_VERSION_TABLE = "pyrit_memory_alembic_version" +PYRIT_MEMORY_ALEMBIC_VERSION_TABLE = "pyrit_memory_alembic_version" _HEAD_REVISION = "head" -_INITIAL_MEMORY_REVISION = "e373726d391b" -_MEMORY_TABLES = { - "AttackResultEntries", - "EmbeddingData", - "PromptMemoryEntries", - "ScenarioResultEntries", - "ScoreEntries", - "SeedPromptEntries", -} +_MEMORY_TABLES = {table.name for table in Base.metadata.sorted_tables} + +_ERROR_UNVERSIONED_SCHEMA_COMPARISON_FAILED = ( + "Detected an unversioned legacy memory schema (memory tables exist, but " + "pyrit_memory_alembic_version is missing), " + "but failed to compare the existing schema against current models. " + "This could be transient, or you may need to repair or rebuild the database " + "before upgrading to this release." +) + +_ERROR_UNVERSIONED_SCHEMA_MISMATCH = ( + "Detected an unversioned legacy memory schema (memory tables exist, but " + "pyrit_memory_alembic_version is missing), " + "and it does not match current models. " + "Repair or rebuild the database before upgrading to this release." +) def _include_name_for_memory_schema( @@ -69,7 +77,6 @@ def _make_config(*, connection: Connection) -> Config: config = Config() config.set_main_option("script_location", str(script_location)) config.attributes["connection"] = connection - config.attributes["version_table"] = _MEMORY_ALEMBIC_VERSION_TABLE return config @@ -89,7 +96,7 @@ def _validate_and_stamp_unversioned_memory_schema(*, config: Config, connection: table_names = set(inspector.get_table_names()) # If version table already exists, migration has been stamped - if _MEMORY_ALEMBIC_VERSION_TABLE in table_names: + if PYRIT_MEMORY_ALEMBIC_VERSION_TABLE in table_names: return # If no memory tables exist, this is a fresh database @@ -104,21 +111,13 @@ def _validate_and_stamp_unversioned_memory_schema(*, config: Config, connection: ) diffs = compare_metadata(migration_context, Base.metadata) except Exception as e: - raise RuntimeError( - "Detected an unversioned legacy memory schema (memory tables exist, but " - "pyrit_memory_alembic_version is missing), " - "and it does not match current models. Repair or rebuild the database before upgrading to this release." - ) from e + raise RuntimeError(_ERROR_UNVERSIONED_SCHEMA_COMPARISON_FAILED) from e if diffs: - raise RuntimeError( - "Detected an unversioned legacy memory schema (memory tables exist, but " - "pyrit_memory_alembic_version is missing), " - "and it does not match current models. Repair or rebuild the database before upgrading to this release." - ) + raise RuntimeError(_ERROR_UNVERSIONED_SCHEMA_MISMATCH) - logger.info("Detected matching unversioned memory schema; stamping revision %s", _INITIAL_MEMORY_REVISION) - command.stamp(config, _INITIAL_MEMORY_REVISION) + logger.warn("Detected matching unversioned memory schema; stamping revision %s", _HEAD_REVISION) + command.stamp(config, _HEAD_REVISION) def run_schema_migrations(*, engine: Engine) -> None: @@ -137,6 +136,52 @@ def run_schema_migrations(*, engine: Engine) -> None: command.upgrade(config, _HEAD_REVISION) +def check_schema_migrations(*, engine: Engine) -> None: + """ + Verify that the current schema matches the models. + + Creates a fresh connection, builds an Alembic config, and runs + ``alembic check`` to detect any unapplied model changes. + + Args: + engine (Engine): SQLAlchemy engine bound to the target database. + + Raises: + AutogenerateDiffsDetected: If schema does not match models. + """ + with engine.begin() as connection: + config = _make_config(connection=connection) + command.check(config) + + +def generate_schema_migration(*, engine: Engine, message: str, force: bool = False) -> None: + """ + Generate a new Alembic revision from model changes. + + Args: + engine (Engine): SQLAlchemy engine upgraded to head. + message (str): Human-readable migration message. + force (bool): If True, generate even if no changes detected. Defaults to False. + + Raises: + RuntimeError: If no changes detected and force is False. + """ + with engine.begin() as connection: + config = _make_config(connection=connection) + + if force: + command.revision(config, autogenerate=True, message=message) + return + + try: + command.check(config) + except AutogenerateDiffsDetected: + command.revision(config, autogenerate=True, message=message) + return + + raise RuntimeError("No schema changes detected. Use force=True to generate an empty migration.") + + def reset_database(*, engine: Engine) -> None: """ Drop all tables and recreate the database schema at the latest Alembic revision. @@ -150,12 +195,11 @@ def reset_database(*, engine: Engine) -> None: with engine.begin() as connection: # Drop version table first (not part of Base.metadata) inspector = inspect(connection) - if _MEMORY_ALEMBIC_VERSION_TABLE in inspector.get_table_names(): - version_table = Table(_MEMORY_ALEMBIC_VERSION_TABLE, MetaData(), autoload_with=connection) + if PYRIT_MEMORY_ALEMBIC_VERSION_TABLE in inspector.get_table_names(): + version_table = Table(PYRIT_MEMORY_ALEMBIC_VERSION_TABLE, MetaData(), autoload_with=connection) version_table.drop(connection) # Drop all application tables defined in models Base.metadata.drop_all(connection) - - # Rebuild schema from migrations - command.upgrade(_make_config(connection=connection), _HEAD_REVISION) + # Rebuild schema from migrations + run_schema_migrations(engine=engine) diff --git a/pyrit/memory/sqlite_memory.py b/pyrit/memory/sqlite_memory.py index a3e007c3e2..5547cafb66 100644 --- a/pyrit/memory/sqlite_memory.py +++ b/pyrit/memory/sqlite_memory.py @@ -80,7 +80,6 @@ def __init__( self.engine = self._create_engine(has_echo=verbose) self.SessionFactory = sessionmaker(bind=self.engine) - self._ensure_schema_is_current() def _init_storage_io(self) -> None: # Handles disk-based storage for SQLite local memory. diff --git a/pyrit/setup/configuration_loader.py b/pyrit/setup/configuration_loader.py index 6f720b3035..32279982a8 100644 --- a/pyrit/setup/configuration_loader.py +++ b/pyrit/setup/configuration_loader.py @@ -100,6 +100,7 @@ class ConfigurationLoader(YamlLoadable): initialization_scripts: Optional[list[str]] = None env_files: Optional[list[str]] = None silent: bool = False + check_schema: bool = True operator: Optional[str] = None operation: Optional[str] = None @@ -189,6 +190,7 @@ def load_with_overrides( initializers: Optional[Sequence[Union[str, dict[str, Any]]]] = None, initialization_scripts: Optional[Sequence[str]] = None, env_files: Optional[Sequence[str]] = None, + check_schema: Optional[bool] = None, ) -> "ConfigurationLoader": """ Load configuration with optional overrides. @@ -208,6 +210,7 @@ def load_with_overrides( initializers: Override for initializer list. initialization_scripts: Override for initialization script paths. env_files: Override for environment file paths. + check_schema: Override for schema migration check. True to run, False to skip. Returns: A merged ConfigurationLoader instance. @@ -227,6 +230,7 @@ def load_with_overrides( "initialization_scripts": None, # None = use defaults "env_files": None, # None = use defaults "silent": False, + "check_schema": True, } # 1. Try loading default config file if it exists @@ -245,6 +249,7 @@ def load_with_overrides( config_data["initialization_scripts"] = default_config.initialization_scripts config_data["env_files"] = default_config.env_files config_data["silent"] = default_config.silent + config_data["check_schema"] = default_config.check_schema if default_config.operator: config_data["operator"] = default_config.operator if default_config.operation: @@ -268,6 +273,7 @@ def load_with_overrides( config_data["initialization_scripts"] = explicit_config.initialization_scripts config_data["env_files"] = explicit_config.env_files config_data["silent"] = explicit_config.silent + config_data["check_schema"] = explicit_config.check_schema if explicit_config.operator: config_data["operator"] = explicit_config.operator if explicit_config.operation: @@ -293,6 +299,9 @@ def load_with_overrides( if env_files is not None: config_data["env_files"] = list(env_files) + if check_schema is not None: + config_data["check_schema"] = check_schema + return ConfigurationLoader.from_dict(config_data) @classmethod @@ -418,6 +427,7 @@ async def initialize_pyrit_async(self) -> None: initializers=resolved_initializers if resolved_initializers else None, env_files=resolved_env_files, silent=self.silent, + check_schema=self.check_schema, ) diff --git a/pyrit/setup/initialization.py b/pyrit/setup/initialization.py index ea34362c30..8789e11241 100644 --- a/pyrit/setup/initialization.py +++ b/pyrit/setup/initialization.py @@ -240,6 +240,7 @@ async def initialize_pyrit_async( initializers: Optional[Sequence["PyRITInitializer"]] = None, env_files: Optional[Sequence[pathlib.Path]] = None, silent: bool = False, + check_schema: bool = True, **memory_instance_kwargs: Any, ) -> None: """ @@ -258,6 +259,9 @@ async def initialize_pyrit_async( All paths must be valid pathlib.Path objects. silent (bool): If True, suppresses print statements about environment file loading. Defaults to False. + check_schema (bool): If True, runs schema migration to ensure the database is up to date. + Set to False to bypass the migration check (e.g., for testing or read-only access). + Defaults to True. **memory_instance_kwargs (Optional[Any]): Additional keyword arguments to pass to the memory instance. Raises: @@ -287,6 +291,9 @@ async def initialize_pyrit_async( raise ValueError( f"Memory database type '{memory_db_type}' is not a supported type {get_args(MemoryDatabaseType)}" ) + if check_schema: + memory._ensure_schema_is_current() + CentralMemory.set_memory_instance(memory) # Combine directly provided initializers with those loaded from scripts From e49bce88a3e5ee23ee0b2129786a278fde338a51 Mon Sep 17 00:00:00 2001 From: Behnam Ousat Date: Wed, 22 Apr 2026 09:46:52 -0700 Subject: [PATCH 18/23] enforce-migration-immutability Co-authored-by: Copilot --- .pre-commit-config.yaml | 6 +++ .../enforce_alembic_revision_immutability.py | 39 +++++++++++++++++++ 2 files changed, 45 insertions(+) create mode 100644 build_scripts/enforce_alembic_revision_immutability.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index deda27f01e..7450328402 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -25,6 +25,12 @@ repos: files: ^(doc/.*\.(py|ipynb|md)|doc/myst\.yml)$ pass_filenames: false additional_dependencies: ['pyyaml'] + - id: enforce_alembic_revision_immutability + name: Enforce Alembic Revision Immutability + entry: python ./build_scripts/enforce_alembic_revision_immutability.py + language: python + files: ^pyrit/memory/alembic/versions/.*\.py$ + pass_filenames: false - id: memory-migrations-check name: Check Memory Migrations entry: python ./build_scripts/memory_migrations.py check diff --git a/build_scripts/enforce_alembic_revision_immutability.py b/build_scripts/enforce_alembic_revision_immutability.py new file mode 100644 index 0000000000..1e1971542b --- /dev/null +++ b/build_scripts/enforce_alembic_revision_immutability.py @@ -0,0 +1,39 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Migration history must be immutable. This hook enforces that by preventing deletion or updates to migration scripts. + +Checks both staged changes (local pre-commit) and the full branch diff against origin/main (CI). +""" + +import subprocess +import sys + +_VERSIONS_PATH = "pyrit/memory/alembic/versions/" + + +def _git(*args: str) -> str: + result = subprocess.run(["git", *args], capture_output=True, text=True) + return result.stdout.strip() + + +def _has_non_add_changes(diff_spec: list[str]) -> bool: + output = _git("diff", "--name-status", *diff_spec, "--", _VERSIONS_PATH) + return any(line and not line.startswith("A") for line in output.splitlines()) + + +def has_revision_violations() -> bool: + # Local pre-commit: check staged changes + if _has_non_add_changes(["--cached"]): + return True + + # CI: check full branch diff against origin/main + merge_base = _git("merge-base", "origin/main", "HEAD") + return bool(merge_base and _has_non_add_changes([f"{merge_base}...HEAD"])) + + +if __name__ == "__main__": + if has_revision_violations(): + print("[ERROR] Migration scripts can only be added, not modified or deleted.") + sys.exit(1) From c2f5424a7666e48234ec45a0d1d5d5bce355ab4c Mon Sep 17 00:00:00 2001 From: Behnam Ousat Date: Wed, 22 Apr 2026 09:47:01 -0700 Subject: [PATCH 19/23] doc --- doc/contributing/11_memory_models.md | 97 ++++++++++++++++++++++++++++ 1 file changed, 97 insertions(+) create mode 100644 doc/contributing/11_memory_models.md diff --git a/doc/contributing/11_memory_models.md b/doc/contributing/11_memory_models.md new file mode 100644 index 0000000000..f5631e3eda --- /dev/null +++ b/doc/contributing/11_memory_models.md @@ -0,0 +1,97 @@ +# Memory Models & Migrations + +This guide covers how to work with PyRIT's memory models — where they live, how to add or update them, and how the migration system works. + +## Where Things Live + +| What | Path | +|---|---| +| ORM models (SQLAlchemy) | `pyrit/memory/memory_models.py` | +| Domain objects they map to | `pyrit/models/` (e.g. `MessagePiece`, `Score`, `Seed`, `AttackResult`, `ScenarioResult`) | +| Alembic migration environment | `pyrit/memory/alembic/env.py` | +| Migration revisions | `pyrit/memory/alembic/versions/` | +| Migration helpers | `pyrit/memory/migration.py` | +| CLI migration tool | `build_scripts/memory_migrations.py` | +| Schema diagram | `doc/code/memory/10_schema_diagram.md` | + +## Current Models + +All models inherit from the SQLAlchemy `Base` declarative class and live in `memory_models.py`: + +- **`PromptMemoryEntry`** — prompt/response data (`PromptMemoryEntries` table) +- **`ScoreEntry`** — evaluation results (`ScoreEntries` table) +- **`EmbeddingDataEntry`** — embeddings for semantic search (`EmbeddingData` table) +- **`SeedEntry`** — dataset prompts/templates (`SeedPromptEntries` table) +- **`AttackResultEntry`** — attack execution results (`AttackResultEntries` table) +- **`ScenarioResultEntry`** — scenario execution metadata (`ScenarioResultEntries` table) + +Each entry model has a corresponding domain object and conversion methods (e.g. `PromptMemoryEntry.__init__(entry: MessagePiece)` and `get_message_piece()`). + +## Adding or Updating a Model + +### 1. Edit the model + +Make your changes in `pyrit/memory/memory_models.py`. Follow these conventions: + +- Use `mapped_column()` with explicit types. +- Use `CustomUUID` for all UUID columns (handles cross-database compatibility). +- Add foreign keys where relationships exist. +- Include `pyrit_version` on new entry models. + +### 2. Generate a migration + +```bash +python build_scripts/memory_migrations.py generate -m "short description of change" +``` + +This creates a new revision file under `pyrit/memory/alembic/versions/`. **Review the generated file carefully** — auto-generated migrations may need manual adjustments (e.g. for data migrations or default values). + +### 3. Validate the migration + +```bash +python build_scripts/memory_migrations.py check +``` + +This verifies the schema produced by running all migrations matches the current models. Both pre-commit hooks (see below) and CI run this check. + +### 4. Update the schema diagram + +If you changed the schema in a meaningful way (added a table, added a foreign key, etc.), update the Mermaid diagram in `doc/code/memory/10_schema_diagram.md`. + +## How Migrations Run at Startup + +When `initialize_pyrit_async()` is called with `check_schema=True` (the default), migrations run automatically: + +``` +initialize_pyrit_async() + → memory._ensure_schema_is_current() # pyrit/memory/memory_interface.py + → run_schema_migrations(engine=...) # pyrit/memory/migration.py + → alembic upgrade head +``` + +This means any new migration you add will be applied automatically the next time a user initializes PyRIT. The behavior depends on the database state: + +| Database state | What happens | +|---|---| +| **Fresh (no tables)** | All migrations apply from scratch | +| **Already versioned** | Only unapplied migrations run (idempotent) | +| **Legacy (tables exist, no version tracking)** | Validates schema matches models, stamps current version, then upgrades. Raises `RuntimeError` on mismatch to prevent data corruption | + +Migrations run inside a transaction (`engine.begin()`), so a failed migration rolls back cleanly. The version tracking table is `pyrit_memory_alembic_version`. + +Users can skip this check by passing `check_schema=False` to `initialize_pyrit_async()`. + +## Important Rules + +### Migration revisions are immutable + +Once a migration revision is committed, it **must not be modified or deleted**. This is enforced by a pre-commit hook (`enforce_alembic_revision_immutability`). If you need to fix a migration, create a new revision instead. + +### Pre-commit hooks + +Two hooks run automatically when you touch memory-related files: + +1. **`enforce_alembic_revision_immutability`** — blocks modifications/deletions to existing revision files. +2. **`memory-migrations-check`** — runs `memory_migrations.py check` to verify the schema is in sync. + +These hooks trigger on changes to `pyrit/memory/memory_models.py`, `pyrit/memory/migration.py`, and files under `pyrit/memory/alembic/`. \ No newline at end of file From 499438918d8032b4781f179e994a7cfca5f62bf1 Mon Sep 17 00:00:00 2001 From: Behnam Ousat Date: Wed, 22 Apr 2026 10:08:36 -0700 Subject: [PATCH 20/23] tests Co-authored-by: Copilot --- pyrit/memory/memory_interface.py | 8 ++++++++ tests/unit/memory/test_azure_sql_memory.py | 4 ++-- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/pyrit/memory/memory_interface.py b/pyrit/memory/memory_interface.py index ff51d0504c..18600a3225 100644 --- a/pyrit/memory/memory_interface.py +++ b/pyrit/memory/memory_interface.py @@ -126,7 +126,10 @@ def _ensure_schema_is_current(self) -> None: Raises: Exception: If there's an issue applying schema migrations. + RuntimeError: If the engine is not initialized. """ + if self.engine is None: + raise RuntimeError("Engine is not initialized") try: run_schema_migrations(engine=self.engine) except Exception as e: @@ -968,7 +971,12 @@ def update_prompt_metadata_by_conversation_id( def reset_database(self) -> None: """ Drop and recreate all tables in the database. + + Raises: + RuntimeError: If the engine is not initialized. """ + if self.engine is None: + raise RuntimeError("Engine is not initialized") reset_database(engine=self.engine) @abc.abstractmethod diff --git a/tests/unit/memory/test_azure_sql_memory.py b/tests/unit/memory/test_azure_sql_memory.py index f548c6e2dc..9ffb02099e 100644 --- a/tests/unit/memory/test_azure_sql_memory.py +++ b/tests/unit/memory/test_azure_sql_memory.py @@ -462,11 +462,11 @@ def decorator(fn): captured_fn(None, None, ["some_connection_string"], {}) -def test_create_tables_if_not_exist_raises_when_engine_none(): +def test_ensure_schema_is_current_raises_when_engine_none(): obj = AzureSQLMemory.__new__(AzureSQLMemory) obj.engine = None with pytest.raises(RuntimeError, match="Engine is not initialized"): - obj._create_tables_if_not_exist() + obj._ensure_schema_is_current() def test_reset_database_raises_when_engine_none(): From 8d08ff3406e42387081c929e0eb0166fdff1c3e3 Mon Sep 17 00:00:00 2001 From: Behnam Ousat Date: Wed, 22 Apr 2026 10:13:50 -0700 Subject: [PATCH 21/23] EOF --- doc/contributing/11_memory_models.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/contributing/11_memory_models.md b/doc/contributing/11_memory_models.md index f5631e3eda..ef215eeae4 100644 --- a/doc/contributing/11_memory_models.md +++ b/doc/contributing/11_memory_models.md @@ -94,4 +94,4 @@ Two hooks run automatically when you touch memory-related files: 1. **`enforce_alembic_revision_immutability`** — blocks modifications/deletions to existing revision files. 2. **`memory-migrations-check`** — runs `memory_migrations.py check` to verify the schema is in sync. -These hooks trigger on changes to `pyrit/memory/memory_models.py`, `pyrit/memory/migration.py`, and files under `pyrit/memory/alembic/`. \ No newline at end of file +These hooks trigger on changes to `pyrit/memory/memory_models.py`, `pyrit/memory/migration.py`, and files under `pyrit/memory/alembic/`. From 8c17a427a13a41e170409662994558625699fff6 Mon Sep 17 00:00:00 2001 From: Behnam Ousat Date: Wed, 22 Apr 2026 10:44:59 -0700 Subject: [PATCH 22/23] myst Co-authored-by: Copilot --- doc/myst.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/doc/myst.yml b/doc/myst.yml index c3fb071a29..2c995bd5c0 100644 --- a/doc/myst.yml +++ b/doc/myst.yml @@ -53,6 +53,7 @@ project: - file: contributing/8_pre_commit.md - file: contributing/9_exception.md - file: contributing/10_release_process.md + - file: contributing/11_memory_models.md - file: gui/0_gui.md - file: scanner/0_scanner.md children: From c6f310d994feabd8eb69ae0858e19f01b2e0f282 Mon Sep 17 00:00:00 2001 From: Behnam Ousat Date: Wed, 22 Apr 2026 11:47:52 -0700 Subject: [PATCH 23/23] cc Co-authored-by: Copilot --- tests/unit/memory/test_migration.py | 97 ++++++++++++++++++- tests/unit/setup/test_configuration_loader.py | 22 +++++ 2 files changed, 118 insertions(+), 1 deletion(-) diff --git a/tests/unit/memory/test_migration.py b/tests/unit/memory/test_migration.py index c1ecb9bad5..0140ce5b1a 100644 --- a/tests/unit/memory/test_migration.py +++ b/tests/unit/memory/test_migration.py @@ -6,13 +6,49 @@ import uuid from pathlib import Path +import pytest from alembic import command from alembic.config import Config from alembic.script import ScriptDirectory +from alembic.util.exc import AutogenerateDiffsDetected from sqlalchemy import create_engine, inspect, text from pyrit.memory.alembic.versions.e373726d391b_initial_schema import _CustomUUID -from pyrit.memory.migration import run_schema_migrations +from pyrit.memory.migration import check_schema_migrations, generate_schema_migration, run_schema_migrations + + +def test_alembic_env_raises_when_no_connection(): + """Covers env.py line 15: RuntimeError when connection is None.""" + import importlib + import sys + from unittest.mock import MagicMock + + # Build a mock alembic context where config.attributes has no "connection" + mock_config = MagicMock() + mock_config.attributes = {} # .get("connection") → None + + mock_context = MagicMock() + mock_context.config = mock_config + + # Remove cached env module so reload runs module-level code + env_module_name = "pyrit.memory.alembic.env" + saved = sys.modules.pop(env_module_name, None) + + # Patch alembic.context to be our mock context + import alembic + + original_context = getattr(alembic, "context", None) + alembic.context = mock_context + + try: + with pytest.raises(RuntimeError, match="No connection found for Alembic migration"): + importlib.import_module(env_module_name) + finally: + # Restore original state + alembic.context = original_context + sys.modules.pop(env_module_name, None) + if saved is not None: + sys.modules[env_module_name] = saved def _get_alembic_head_revision(*, config: Config) -> str: @@ -167,3 +203,62 @@ def test_migration_downgrade_creates_proper_structure(): assert "pyrit_memory_alembic_version" not in tables_after or len(tables_after) == 1 finally: engine.dispose() + + +def test_check_schema_migrations_calls_alembic_check(): + with tempfile.TemporaryDirectory() as temp_dir: + db_path = os.path.join(temp_dir, "check-test.db") + engine = create_engine(f"sqlite:///{db_path}") + try: + # First apply migrations so schema matches models + run_schema_migrations(engine=engine) + # Now check should succeed (no diffs) + check_schema_migrations(engine=engine) + finally: + engine.dispose() + + +def test_generate_schema_migration_force_creates_revision(): + from unittest.mock import patch + + with tempfile.TemporaryDirectory() as temp_dir: + db_path = os.path.join(temp_dir, "gen-force-test.db") + engine = create_engine(f"sqlite:///{db_path}") + try: + run_schema_migrations(engine=engine) + with patch("pyrit.memory.migration.command.revision") as mock_revision: + generate_schema_migration(engine=engine, message="force empty", force=True) + mock_revision.assert_called_once() + finally: + engine.dispose() + + +def test_generate_schema_migration_no_changes_raises(): + with tempfile.TemporaryDirectory() as temp_dir: + db_path = os.path.join(temp_dir, "gen-nochange-test.db") + engine = create_engine(f"sqlite:///{db_path}") + try: + run_schema_migrations(engine=engine) + with pytest.raises(RuntimeError, match="No schema changes detected"): + generate_schema_migration(engine=engine, message="should fail") + finally: + engine.dispose() + + +def test_generate_schema_migration_with_diffs_creates_revision(): + from unittest.mock import patch + + with tempfile.TemporaryDirectory() as temp_dir: + db_path = os.path.join(temp_dir, "gen-diffs-test.db") + engine = create_engine(f"sqlite:///{db_path}") + try: + run_schema_migrations(engine=engine) + exc = AutogenerateDiffsDetected.__new__(AutogenerateDiffsDetected) + with ( + patch("pyrit.memory.migration.command.check", side_effect=exc), + patch("pyrit.memory.migration.command.revision") as mock_revision, + ): + generate_schema_migration(engine=engine, message="with diffs") + mock_revision.assert_called_once() + finally: + engine.dispose() diff --git a/tests/unit/setup/test_configuration_loader.py b/tests/unit/setup/test_configuration_loader.py index 76776bf515..fbdc80db1c 100644 --- a/tests/unit/setup/test_configuration_loader.py +++ b/tests/unit/setup/test_configuration_loader.py @@ -531,3 +531,25 @@ def test_load_with_overrides_preserves_silent_from_config_file(self, mock_defaul assert config.silent is True finally: config_path.unlink() + + @mock.patch("pyrit.setup.configuration_loader.DEFAULT_CONFIG_PATH") + def test_load_with_overrides_check_schema_from_default_config(self, mock_default_path): + """Test that check_schema is loaded from the default config file.""" + mock_default_path.exists.return_value = True + + with mock.patch.object(ConfigurationLoader, "from_yaml_file") as mock_from_yaml: + fake_config = ConfigurationLoader(memory_db_type="sqlite", check_schema=False) + mock_from_yaml.return_value = fake_config + + config = ConfigurationLoader.load_with_overrides() + + assert config.check_schema is False + + @mock.patch("pyrit.setup.configuration_loader.DEFAULT_CONFIG_PATH") + def test_load_with_overrides_check_schema_override(self, mock_default_path): + """Test that check_schema override takes precedence.""" + mock_default_path.exists.return_value = False + + config = ConfigurationLoader.load_with_overrides(check_schema=False) + + assert config.check_schema is False