diff --git a/pyrit/memory/azure_sql_memory.py b/pyrit/memory/azure_sql_memory.py index d36cf42a2..a34b88191 100644 --- a/pyrit/memory/azure_sql_memory.py +++ b/pyrit/memory/azure_sql_memory.py @@ -54,6 +54,9 @@ class AzureSQLMemory(MemoryInterface, metaclass=Singleton): TOKEN_URL = "https://database.windows.net/.default" # The token URL for any Azure SQL database AZURE_SQL_DB_CONNECTION_STRING = "AZURE_SQL_DB_CONNECTION_STRING" + # Azure SQL supports up to 2100 parameters per statement + _MAX_BIND_VARS: int = 2000 + # Azure Storage Account Container datasets and results environment variables AZURE_STORAGE_ACCOUNT_DB_DATA_CONTAINER_URL: str = "AZURE_STORAGE_ACCOUNT_DB_DATA_CONTAINER_URL" AZURE_STORAGE_ACCOUNT_DB_DATA_SAS_TOKEN: str = "AZURE_STORAGE_ACCOUNT_DB_DATA_SAS_TOKEN" diff --git a/pyrit/memory/memory_interface.py b/pyrit/memory/memory_interface.py index a971d3879..8ccec656f 100644 --- a/pyrit/memory/memory_interface.py +++ b/pyrit/memory/memory_interface.py @@ -71,6 +71,11 @@ class MemoryInterface(abc.ABC): such as files, databases, or cloud storage services. """ + # Maximum number of bind variables per SQL statement. + # Conservative default based on SQLite's limit of 999. Subclasses can override + # for backends with higher limits (e.g., Azure SQL supports 2100). + _MAX_BIND_VARS: int = 500 + memory_embedding: MemoryEmbedding | None = None results_storage_io: StorageIO | None = None results_path: str | None = None @@ -349,6 +354,149 @@ def _query_entries( List of model instances representing the rows fetched from the table. """ + def _execute_batched_query( + self, + model_class: type[Model], + *, + batch_column: InstrumentedAttribute[Any], + batch_values: Sequence[Any], + other_conditions: list[Any] | None = None, + distinct: bool = False, + join_scores: bool = False, + batch_size: int | None = None, + ) -> MutableSequence[Model]: + """ + Execute queries in batches to avoid exceeding database bind variable limits. 
+ + SQLite and other databases have per-statement parameter limits. This method + executes separate queries for each batch of values and merges the results. + + Args: + model_class: The SQLAlchemy model class to query. + batch_column: The column to batch the IN condition on. + batch_values: The values to filter by (will be batched). + other_conditions: Additional SQLAlchemy conditions to include in each query. + distinct: Whether to return distinct rows only. + join_scores: Whether to join the scores table. + batch_size: Override for the number of values per batch. + Defaults to ``_MAX_BIND_VARS`` when not specified. + + Returns: + MutableSequence[Model]: Merged and deduplicated results from all batched queries. + """ + if other_conditions is None: + other_conditions = [] + + effective_size = batch_size if batch_size is not None else self._MAX_BIND_VARS + + # If values fit in one batch, execute a single query + if len(batch_values) <= effective_size: + conditions = other_conditions + [batch_column.in_(batch_values)] + return self._query_entries( + model_class, + conditions=and_(*conditions) if conditions else None, + distinct=distinct, + join_scores=join_scores, + ) + + # Execute multiple separate queries and merge results + all_results: MutableSequence[Model] = [] + seen_ids: set[str] = set() + + for i in range(0, len(batch_values), effective_size): + batch = batch_values[i : i + effective_size] + conditions = other_conditions + [batch_column.in_(batch)] + + results = self._query_entries( + model_class, + conditions=and_(*conditions) if conditions else None, + distinct=distinct, + join_scores=join_scores, + ) + + # Deduplicate by primary key (id) + for result in results: + result_id = getattr(result, "id", None) + if result_id is not None: + id_str = str(result_id) + if id_str not in seen_ids: + seen_ids.add(id_str) + all_results.append(result) + else: + all_results.append(result) + + return all_results + + def _query_with_list_params( + self, + model_class: 
type[Model], + *, + conditions: list[Any], + list_params: list[tuple[InstrumentedAttribute[Any], Sequence[Any], str]], + join_scores: bool = False, + ) -> MutableSequence[Model]: + """ + Execute a query with list-based IN filters, batching when lists exceed bind limits. + + Splits list parameters into "small" (fit in one query) and "large" (need batching). + Small params are added to the SQL conditions directly. The first large param is + batched via ``_execute_batched_query``; any remaining large params are filtered + in Python after fetching. + + The effective batch size is reduced to account for bind variables contributed by + small params, preventing cumulative overflow of the per-statement limit. + + Args: + model_class: The SQLAlchemy model class to query. + conditions: Base conditions (scalar filters) to include in every query. + list_params: List of (column, values, attr_name) tuples for IN-clause filters. + join_scores: Whether to join the scores table. + + Returns: + MutableSequence[Model]: Query results with all filters applied. 
+ """ + if not list_params: + return self._query_entries( + model_class, + conditions=and_(*conditions) if conditions else None, + join_scores=join_scores, + ) + + large_params = [(col, vals, name) for col, vals, name in list_params if len(vals) > self._MAX_BIND_VARS] + small_params = [(col, vals, name) for col, vals, name in list_params if len(vals) <= self._MAX_BIND_VARS] + + small_param_binds = sum(len(vals) for _, vals, _ in small_params) + for col, vals, _ in small_params: + conditions.append(col.in_(vals)) + + if not large_params: + return self._query_entries( + model_class, + conditions=and_(*conditions) if conditions else None, + join_scores=join_scores, + ) + + batch_col, batch_vals, _ = large_params[0] + other_large_params = large_params[1:] + + # Reduce batch size to account for bind variables already used by small params + effective_batch_size = max(1, self._MAX_BIND_VARS - small_param_binds) + + results = self._execute_batched_query( + model_class, + batch_column=batch_col, + batch_values=batch_vals, + other_conditions=conditions, + join_scores=join_scores, + batch_size=effective_batch_size, + ) + + for _col, vals, attr_name in other_large_params: + vals_set = set(vals) + results = [e for e in results if getattr(e, attr_name, None) in vals_set] + + return results + @abc.abstractmethod def _insert_entry(self, entry: Base) -> None: """ @@ -525,10 +673,11 @@ def get_scores( Returns: Sequence[Score]: A list of Score objects that match the specified filters. 
""" + if score_ids is not None and len(score_ids) == 0: + return [] + conditions: list[Any] = [] - if score_ids: - conditions.append(ScoreEntry.id.in_(score_ids)) if score_type: conditions.append(ScoreEntry.score_type == score_type) if score_category: @@ -546,11 +695,22 @@ def get_scores( ) ) + # Handle score_ids with batched queries if needed + if score_ids: + entries = self._execute_batched_query( + ScoreEntry, + batch_column=ScoreEntry.id, + batch_values=list(score_ids), + other_conditions=conditions, + ) + return [entry.get_score() for entry in entries] + + # No score_ids specified - use regular query if not conditions: return [] - entries: Sequence[ScoreEntry] = self._query_entries(ScoreEntry, conditions=and_(*conditions)) - return [entry.get_score() for entry in entries] + score_entries: Sequence[ScoreEntry] = self._query_entries(ScoreEntry, conditions=and_(*conditions)) + return [entry.get_score() for entry in score_entries] def get_prompt_scores( self, @@ -708,55 +868,63 @@ def get_message_pieces( Exception: If there is an error retrieving the prompts, an exception is logged and an empty list is returned. 
""" - conditions = [] - if attack_id: - conditions.append( - self._get_condition_json_property_match( - json_column=PromptMemoryEntry.attack_identifier, - property_path="$.hash", - value=str(attack_id), + try: + conditions: list[Any] = [] + if attack_id: + conditions.append( + self._get_condition_json_property_match( + json_column=PromptMemoryEntry.attack_identifier, + property_path="$.hash", + value=str(attack_id), + ) ) - ) - if role: - conditions.append(PromptMemoryEntry.role == role) - if conversation_id: - conditions.append(PromptMemoryEntry.conversation_id == str(conversation_id)) - if prompt_ids: - prompt_ids = [str(pi) for pi in prompt_ids] - conditions.append(PromptMemoryEntry.id.in_(prompt_ids)) - if labels: - conditions.extend(self._get_message_pieces_memory_label_conditions(memory_labels=labels)) - if prompt_metadata: - conditions.extend(self._get_message_pieces_prompt_metadata_conditions(prompt_metadata=prompt_metadata)) - if sent_after: - conditions.append(PromptMemoryEntry.timestamp >= sent_after) - if sent_before: - conditions.append(PromptMemoryEntry.timestamp <= sent_before) - if original_values: - conditions.append(PromptMemoryEntry.original_value.in_(original_values)) - if converted_values: - conditions.append(PromptMemoryEntry.converted_value.in_(converted_values)) - if data_type: - conditions.append(PromptMemoryEntry.converted_value_data_type == data_type) - if not_data_type: - conditions.append(PromptMemoryEntry.converted_value_data_type != not_data_type) - if converted_value_sha256: - conditions.append(PromptMemoryEntry.converted_value_sha256.in_(converted_value_sha256)) - if identifier_filters: - conditions.extend( - self._build_identifier_filter_conditions( - identifier_filters=identifier_filters, - identifier_column_map={ - IdentifierType.ATTACK: PromptMemoryEntry.attack_identifier, - IdentifierType.TARGET: PromptMemoryEntry.prompt_target_identifier, - IdentifierType.CONVERTER: PromptMemoryEntry.converter_identifiers, - }, - 
caller="get_message_pieces", + if role: + conditions.append(PromptMemoryEntry.role == role) + if conversation_id: + conditions.append(PromptMemoryEntry.conversation_id == str(conversation_id)) + if labels: + conditions.extend(self._get_message_pieces_memory_label_conditions(memory_labels=labels)) + if prompt_metadata: + conditions.extend(self._get_message_pieces_prompt_metadata_conditions(prompt_metadata=prompt_metadata)) + if sent_after: + conditions.append(PromptMemoryEntry.timestamp >= sent_after) + if sent_before: + conditions.append(PromptMemoryEntry.timestamp <= sent_before) + if data_type: + conditions.append(PromptMemoryEntry.converted_value_data_type == data_type) + if not_data_type: + conditions.append(PromptMemoryEntry.converted_value_data_type != not_data_type) + if identifier_filters: + conditions.extend( + self._build_identifier_filter_conditions( + identifier_filters=identifier_filters, + identifier_column_map={ + IdentifierType.ATTACK: PromptMemoryEntry.attack_identifier, + IdentifierType.TARGET: PromptMemoryEntry.prompt_target_identifier, + IdentifierType.CONVERTER: PromptMemoryEntry.converter_identifiers, + }, + caller="get_message_pieces", + ) ) - ) - try: - memory_entries: Sequence[PromptMemoryEntry] = self._query_entries( - PromptMemoryEntry, conditions=and_(*conditions) if conditions else None, join_scores=True + + # Identify list parameters that may need batching + list_params: list[tuple[InstrumentedAttribute[Any], Sequence[Any], str]] = [] + if prompt_ids: + list_params.append((PromptMemoryEntry.id, [str(pi) for pi in prompt_ids], "id")) + if original_values: + list_params.append((PromptMemoryEntry.original_value, list(original_values), "original_value")) + if converted_values: + list_params.append((PromptMemoryEntry.converted_value, list(converted_values), "converted_value")) + if converted_value_sha256: + list_params.append( + (PromptMemoryEntry.converted_value_sha256, list(converted_value_sha256), "converted_value_sha256") + ) + + 
memory_entries = self._query_with_list_params( + PromptMemoryEntry, + conditions=conditions, + list_params=list_params, + join_scores=True, ) message_pieces = [memory_entry.get_message_piece() for memory_entry in memory_entries] return sort_message_pieces(message_pieces=message_pieces) @@ -1519,20 +1687,18 @@ def get_attack_results( Returns: Sequence[AttackResult]: A list of AttackResult objects that match the specified filters. """ - conditions: list[ColumnElement[bool]] = [] + # Handle empty list cases + if attack_result_ids is not None and len(attack_result_ids) == 0: + return [] + if objective_sha256 is not None and len(objective_sha256) == 0: + return [] - if attack_result_ids is not None: - if len(attack_result_ids) == 0: - # Empty list means no results - return [] - conditions.append(AttackResultEntry.id.in_(attack_result_ids)) + # Build non-list conditions + conditions: list[ColumnElement[bool]] = [] if conversation_id: conditions.append(AttackResultEntry.conversation_id == conversation_id) if objective: conditions.append(AttackResultEntry.objective.contains(objective)) - - if objective_sha256: - conditions.append(AttackResultEntry.objective_sha256.in_(objective_sha256)) if outcome: conditions.append(AttackResultEntry.outcome == outcome) @@ -1565,13 +1731,10 @@ def get_attack_results( DeprecationWarning, stacklevel=2, ) - # Use database-specific JSON query method conditions.append( self._get_attack_result_harm_category_condition(targeted_harm_categories=targeted_harm_categories) ) - if labels: - # Use database-specific JSON query method conditions.append(self._get_attack_result_label_condition(labels=labels)) if identifier_filters: @@ -1584,21 +1747,39 @@ def get_attack_results( ) try: - entries: Sequence[AttackResultEntry] = self._query_entries( - AttackResultEntry, conditions=and_(*conditions) if conditions else None + list_params: list[tuple[InstrumentedAttribute[Any], Sequence[Any], str]] = [] + if attack_result_ids: + 
list_params.append((AttackResultEntry.id, list(attack_result_ids), "id")) + if objective_sha256: + list_params.append((AttackResultEntry.objective_sha256, list(objective_sha256), "objective_sha256")) + + entries = self._query_with_list_params( + AttackResultEntry, + conditions=conditions, + list_params=list_params, ) - # Deduplicate by conversation_id — when duplicate rows exist - # (legacy bug), keep only the newest entry per conversation_id. - seen: dict[str, AttackResultEntry] = {} - for entry in entries: - prev = seen.get(entry.conversation_id) - if prev is None or entry.timestamp > prev.timestamp: - seen[entry.conversation_id] = entry - return [entry.get_attack_result() for entry in seen.values()] + return self._dedup_attack_entries(entries) except Exception as e: logger.exception(f"Failed to retrieve attack results with error {e}") raise + @staticmethod + def _dedup_attack_entries(entries: Sequence[AttackResultEntry]) -> list[AttackResult]: + """ + Deduplicate AttackResultEntry rows by conversation_id and convert to AttackResult. + + When duplicate rows exist (legacy bug), keeps only the newest entry per conversation_id. + + Returns: + list[AttackResult]: Deduplicated attack results. + """ + seen: dict[str, AttackResultEntry] = {} + for entry in entries: + prev = seen.get(entry.conversation_id) + if prev is None or entry.timestamp > prev.timestamp: + seen[entry.conversation_id] = entry + return [entry.get_attack_result() for entry in seen.values()] + def get_unique_attack_labels(self) -> dict[str, list[str]]: """ Return all unique label key-value pairs across attack results. @@ -1794,18 +1975,12 @@ def get_scenario_results( Returns: Sequence[ScenarioResult]: A list of ScenarioResult objects that match the specified filters. 
""" - conditions: list[ColumnElement[bool]] = [] + if scenario_result_ids is not None and len(scenario_result_ids) == 0: + return [] - if scenario_result_ids is not None: - if len(scenario_result_ids) == 0: - # Empty list means no results - return [] - conditions.append(ScenarioResultEntry.id.in_(scenario_result_ids)) + conditions: list[ColumnElement[bool]] = [] if scenario_name: - # Normalize CLI snake_case names (e.g., "foundry" or "content_harms") - # to class names (e.g., "Foundry" or "ContentHarms") - # This allows users to query with either format normalized_name = ScenarioResult.normalize_scenario_name(scenario_name) conditions.append(ScenarioResultEntry.scenario_name.contains(normalized_name)) @@ -1822,11 +1997,9 @@ def get_scenario_results( conditions.append(ScenarioResultEntry.completion_time <= added_before) if labels: - # Use database-specific JSON query method conditions.append(self._get_scenario_result_label_condition(labels=labels)) if objective_target_endpoint: - # Use database-specific JSON query method conditions.append( self._get_condition_json_property_match( json_column=ScenarioResultEntry.objective_target_identifier, @@ -1837,7 +2010,6 @@ def get_scenario_results( ) if objective_target_model_name: - # Use database-specific JSON query method conditions.append( self._get_condition_json_property_match( json_column=ScenarioResultEntry.objective_target_identifier, @@ -1860,9 +2032,16 @@ def get_scenario_results( ) try: - entries: Sequence[ScenarioResultEntry] = self._query_entries( - ScenarioResultEntry, conditions=and_(*conditions) if conditions else None - ) + # Handle scenario_result_ids with batched queries if needed + if scenario_result_ids: + entries = self._execute_batched_query( + ScenarioResultEntry, + batch_column=ScenarioResultEntry.id, + batch_values=list(scenario_result_ids), + other_conditions=conditions, + ) + else: + entries = self._query_entries(ScenarioResultEntry, conditions=and_(*conditions) if conditions else None) # Convert 
entries to ScenarioResults and populate attack_results efficiently scenario_results = [] @@ -1877,12 +2056,12 @@ def get_scenario_results( for conv_ids in conversation_ids_by_attack.values(): all_conversation_ids.extend(conv_ids) - # Query all AttackResults in a single batch if there are any + # Query all AttackResults using batched queries if needed if all_conversation_ids: - # Build condition to query multiple conversation IDs at once - attack_conditions = [AttackResultEntry.conversation_id.in_(all_conversation_ids)] - attack_entries: Sequence[AttackResultEntry] = self._query_entries( - AttackResultEntry, conditions=and_(*attack_conditions) + attack_entries = self._execute_batched_query( + AttackResultEntry, + batch_column=AttackResultEntry.conversation_id, + batch_values=all_conversation_ids, ) # Build a dict for quick lookup diff --git a/tests/unit/memory/memory_interface/test_batching_scale.py b/tests/unit/memory/memory_interface/test_batching_scale.py new file mode 100644 index 000000000..c19a1e6ed --- /dev/null +++ b/tests/unit/memory/memory_interface/test_batching_scale.py @@ -0,0 +1,586 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Tests for batching functionality to handle large numbers of IDs. +This addresses the scaling bug where methods like get_scores_by_prompt_ids +fail when querying with many IDs due to SQLite bind variable limits. 
+""" + +import hashlib +import uuid +from unittest.mock import patch + +from pyrit.memory import MemoryInterface +from pyrit.models import AttackResult, MessagePiece, Score + +# Use the class attribute for the batch limit in tests +_MAX_BIND_VARS = MemoryInterface._MAX_BIND_VARS + + +def _create_message_piece( + conversation_id: str | None = None, + original_value: str = "test message", + role: str = "user", +) -> MessagePiece: + """Create a sample message piece for testing.""" + converted_value = original_value + # Compute SHA256 for converted_value so filtering by sha256 works + sha256 = hashlib.sha256(converted_value.encode("utf-8")).hexdigest() + return MessagePiece( + id=str(uuid.uuid4()), + role=role, + original_value=original_value, + converted_value=converted_value, + converted_value_sha256=sha256, + sequence=0, + conversation_id=conversation_id or str(uuid.uuid4()), + labels={"test": "label"}, + attack_identifier={"id": str(uuid.uuid4())}, + ) + + +def _create_score(message_piece_id: str) -> Score: + """Create a sample score for testing.""" + return Score( + score_value="0.5", + score_value_description="test score", + score_type="float_scale", + score_category=["test"], + score_rationale="test rationale", + score_metadata={}, + scorer_class_identifier={"__type__": "TestScorer"}, + message_piece_id=message_piece_id, + ) + + +class TestBatchingScale: + """Tests for batching when querying with many IDs.""" + + def test_get_message_pieces_with_many_prompt_ids(self, sqlite_instance: MemoryInterface): + """Test that get_message_pieces works with more IDs than the batch limit.""" + # Create more message pieces than the batch limit + num_pieces = _MAX_BIND_VARS + 100 + pieces = [_create_message_piece() for _ in range(num_pieces)] + + # Add to memory + sqlite_instance.add_message_pieces_to_memory(message_pieces=pieces) + + # Query with all IDs - this should work with batching + all_ids = [piece.id for piece in pieces] + results = 
sqlite_instance.get_message_pieces(prompt_ids=all_ids) + + assert len(results) == num_pieces, f"Expected {num_pieces} results, got {len(results)}" + + def test_get_message_pieces_with_exact_batch_size(self, sqlite_instance: MemoryInterface): + """Test that get_message_pieces works with exactly the batch limit.""" + num_pieces = _MAX_BIND_VARS + pieces = [_create_message_piece() for _ in range(num_pieces)] + + sqlite_instance.add_message_pieces_to_memory(message_pieces=pieces) + + all_ids = [piece.id for piece in pieces] + results = sqlite_instance.get_message_pieces(prompt_ids=all_ids) + + assert len(results) == num_pieces + + def test_get_message_pieces_with_double_batch_size(self, sqlite_instance: MemoryInterface): + """Test that get_message_pieces works with double the batch limit.""" + num_pieces = _MAX_BIND_VARS * 2 + pieces = [_create_message_piece() for _ in range(num_pieces)] + + sqlite_instance.add_message_pieces_to_memory(message_pieces=pieces) + + all_ids = [piece.id for piece in pieces] + results = sqlite_instance.get_message_pieces(prompt_ids=all_ids) + + assert len(results) == num_pieces + + def test_get_scores_with_many_score_ids(self, sqlite_instance: MemoryInterface): + """Test that get_scores works with more IDs than the batch limit.""" + # Create message pieces first (scores need to reference them) + num_scores = _MAX_BIND_VARS + 100 + pieces = [_create_message_piece() for _ in range(num_scores)] + sqlite_instance.add_message_pieces_to_memory(message_pieces=pieces) + + # Create and add scores + scores = [_create_score(str(piece.id)) for piece in pieces] + sqlite_instance.add_scores_to_memory(scores=scores) + + # Query with all score IDs - this should work with batching + all_score_ids = [str(score.id) for score in scores] + results = sqlite_instance.get_scores(score_ids=all_score_ids) + + assert len(results) == num_scores, f"Expected {num_scores} results, got {len(results)}" + + def test_get_prompt_scores_with_many_prompt_ids(self, 
sqlite_instance: MemoryInterface): + """Test that get_prompt_scores works with more prompt IDs than the batch limit.""" + # Create message pieces + num_pieces = _MAX_BIND_VARS + 50 + pieces = [_create_message_piece() for _ in range(num_pieces)] + sqlite_instance.add_message_pieces_to_memory(message_pieces=pieces) + + # Create and add scores for half of them + num_scores = num_pieces // 2 + scores = [_create_score(str(pieces[i].id)) for i in range(num_scores)] + sqlite_instance.add_scores_to_memory(scores=scores) + + # Query with all prompt IDs - should return scores for pieces that have them + all_prompt_ids = [piece.id for piece in pieces] + results = sqlite_instance.get_prompt_scores(prompt_ids=all_prompt_ids) + + assert len(results) == num_scores, f"Expected {num_scores} results, got {len(results)}" + + def test_get_message_pieces_batching_preserves_other_filters(self, sqlite_instance: MemoryInterface): + """Test that batching still applies other filter conditions correctly.""" + # Create pieces with different roles + num_pieces = _MAX_BIND_VARS + 50 + user_pieces = [_create_message_piece(role="user") for _ in range(num_pieces)] + + assistant_pieces = [_create_message_piece(role="assistant") for _ in range(50)] + + all_pieces = user_pieces + assistant_pieces + sqlite_instance.add_message_pieces_to_memory(message_pieces=all_pieces) + + # Query with all IDs but filter by role + all_ids = [piece.id for piece in all_pieces] + results = sqlite_instance.get_message_pieces(prompt_ids=all_ids, role="user") + + assert len(results) == num_pieces, f"Expected {num_pieces} user pieces, got {len(results)}" + + def test_get_message_pieces_small_list_still_works(self, sqlite_instance: MemoryInterface): + """Test that small ID lists (under batch limit) still work correctly.""" + num_pieces = 10 + pieces = [_create_message_piece() for _ in range(num_pieces)] + sqlite_instance.add_message_pieces_to_memory(message_pieces=pieces) + + all_ids = [piece.id for piece in pieces] + 
results = sqlite_instance.get_message_pieces(prompt_ids=all_ids) + + assert len(results) == num_pieces + + def test_get_message_pieces_with_many_original_values(self, sqlite_instance: MemoryInterface): + """Test that get_message_pieces works with many original_values exceeding batch limit.""" + num_pieces = _MAX_BIND_VARS + 100 + # Create pieces with unique original values + pieces = [_create_message_piece(original_value=f"unique_value_{i}") for i in range(num_pieces)] + sqlite_instance.add_message_pieces_to_memory(message_pieces=pieces) + + # Query with all original values + all_values = [piece.original_value for piece in pieces] + results = sqlite_instance.get_message_pieces(original_values=all_values) + + assert len(results) == num_pieces, f"Expected {num_pieces} results, got {len(results)}" + + def test_get_message_pieces_with_many_converted_value_sha256(self, sqlite_instance: MemoryInterface): + """Test that get_message_pieces works with many converted_value_sha256 exceeding batch limit.""" + num_pieces = _MAX_BIND_VARS + 100 + pieces = [_create_message_piece(original_value=f"unique_value_{i}") for i in range(num_pieces)] + sqlite_instance.add_message_pieces_to_memory(message_pieces=pieces) + + # Get SHA256 hashes from stored pieces + stored_pieces = sqlite_instance.get_message_pieces() + all_hashes = [piece.converted_value_sha256 for piece in stored_pieces if piece.converted_value_sha256] + + assert len(all_hashes) > _MAX_BIND_VARS, "Test setup failed: not enough hashes to trigger batching" + results = sqlite_instance.get_message_pieces(converted_value_sha256=all_hashes) + assert len(results) == len(all_hashes) + + def test_get_message_pieces_combines_filters_correctly(self, sqlite_instance: MemoryInterface): + """Test that multiple filters can be combined (e.g., prompt_ids AND role).""" + # Create message pieces with different roles + num_pieces = 50 + user_pieces = [_create_message_piece(role="user") for _ in range(num_pieces)] + + assistant_pieces = 
[_create_message_piece(role="assistant") for _ in range(num_pieces)] + + all_pieces = user_pieces + assistant_pieces + sqlite_instance.add_message_pieces_to_memory(message_pieces=all_pieces) + + # Query with both prompt_ids AND role filter + user_ids = [piece.id for piece in user_pieces] + results = sqlite_instance.get_message_pieces(prompt_ids=user_ids, role="user") + + # Should return only user pieces (intersection of both filters) + assert len(results) == num_pieces + assert all(r.get_role_for_storage() == "user" for r in results) + + # Query with role filter and a subset of IDs + subset_ids = user_ids[:10] + results = sqlite_instance.get_message_pieces(prompt_ids=subset_ids, role="user") + assert len(results) == 10 + + def test_get_message_pieces_multiple_large_params_simultaneously(self, sqlite_instance: MemoryInterface): + """Test batching with multiple parameters exceeding batch limit simultaneously.""" + # Create enough pieces to exceed batch limit with unique values + num_pieces = _MAX_BIND_VARS + 200 + pieces = [_create_message_piece(original_value=f"original_value_{i}") for i in range(num_pieces)] + + # Add to memory + sqlite_instance.add_message_pieces_to_memory(message_pieces=pieces) + + # Get all stored pieces to extract their IDs and SHA256 hashes + stored_pieces = sqlite_instance.get_message_pieces() + assert len(stored_pieces) >= num_pieces + + # Extract multiple large parameter lists + all_ids = [piece.id for piece in stored_pieces[:num_pieces]] + all_original_values = [piece.original_value for piece in stored_pieces[:num_pieces]] + all_sha256 = [piece.converted_value_sha256 for piece in stored_pieces[:num_pieces]] + + # Query with multiple large parameters simultaneously + # This tests that ALL parameters are batched correctly, not just one + results = sqlite_instance.get_message_pieces( + prompt_ids=all_ids, + original_values=all_original_values, + converted_value_sha256=all_sha256, + ) + + # Should return all pieces that match ALL conditions 
(intersection)
+        assert len(results) == num_pieces, (
+            f"Expected {num_pieces} results when filtering with multiple large parameters, got {len(results)}"
+        )
+
+        # Verify all returned pieces match all filter criteria
+        result_ids = {r.id for r in results}
+        result_original_values = {r.original_value for r in results}
+        result_sha256 = {r.converted_value_sha256 for r in results}
+
+        assert result_ids == set(all_ids), "Returned IDs don't match filter"
+        assert result_original_values == set(all_original_values), "Returned original_values don't match filter"
+        assert result_sha256 == set(all_sha256), "Returned SHA256 hashes don't match filter"
+
+    def test_get_message_pieces_multiple_batched_params_with_query_spy(self, sqlite_instance: MemoryInterface):
+        """Test that batching executes multiple separate queries and merges results correctly."""
+        # Create pieces exceeding batch limit
+        num_pieces = _MAX_BIND_VARS + 100
+        pieces = [_create_message_piece(original_value=f"value_{i}") for i in range(num_pieces)]
+        sqlite_instance.add_message_pieces_to_memory(message_pieces=pieces)
+
+        # Get stored pieces
+        stored_pieces = sqlite_instance.get_message_pieces()
+        all_ids = [piece.id for piece in stored_pieces[:num_pieces]]
+        all_original_values = [piece.original_value for piece in stored_pieces[:num_pieces]]
+
+        # Mock _query_entries to track how it's called
+        original_query = sqlite_instance._query_entries
+        call_count = 0
+
+        def spy_query(*args, **kwargs):
+            nonlocal call_count
+            call_count += 1
+            return original_query(*args, **kwargs)
+
+        with patch.object(sqlite_instance, "_query_entries", side_effect=spy_query):
+            results = sqlite_instance.get_message_pieces(prompt_ids=all_ids, original_values=all_original_values)
+
+        # Should get all results despite batching
+        assert len(results) == num_pieces
+
+        # With the new batching approach, multiple separate queries should be executed
+        # when the primary batch parameter exceeds _MAX_BIND_VARS
+        # Expected: ceil(num_pieces / _MAX_BIND_VARS) = 2 queries
+        expected_min_calls = (num_pieces + _MAX_BIND_VARS - 1) // _MAX_BIND_VARS
+        assert call_count >= expected_min_calls, (
+            f"Expected at least {expected_min_calls} separate queries for {num_pieces} items, "
+            f"but only got {call_count} calls"
+        )
+
+    def test_get_message_pieces_triple_large_params_preserves_intersection(self, sqlite_instance: MemoryInterface):
+        """Test that filtering with 3 large parameter lists returns correct intersection."""
+        # Create a large set of pieces
+        total_pieces = _MAX_BIND_VARS + 150
+        pieces = [
+            _create_message_piece(conversation_id=str(uuid.uuid4()), original_value=f"content_{i}")
+            for i in range(total_pieces)
+        ]
+        sqlite_instance.add_message_pieces_to_memory(message_pieces=pieces)
+
+        # Get stored pieces
+        stored_pieces = sqlite_instance.get_message_pieces()
+
+        # Create three overlapping large filter lists
+        # List 1: All IDs
+        filter_ids = [p.id for p in stored_pieces[:total_pieces]]
+
+        # List 2: All original values
+        filter_original_values = [p.original_value for p in stored_pieces[:total_pieces]]
+
+        # List 3: Subset of SHA256 hashes (to test intersection)
+        subset_size = _MAX_BIND_VARS + 50
+        filter_sha256 = [p.converted_value_sha256 for p in stored_pieces[:subset_size]]
+
+        # Query with all three large parameters
+        results = sqlite_instance.get_message_pieces(
+            prompt_ids=filter_ids,
+            original_values=filter_original_values,
+            converted_value_sha256=filter_sha256,
+        )
+
+        # Should return only the intersection (subset_size items)
+        assert len(results) == subset_size, f"Expected {subset_size} results from intersection, got {len(results)}"
+
+        # Verify all results have SHA256 in the filter list
+        result_sha256 = {r.converted_value_sha256 for r in results}
+        assert result_sha256.issubset(set(filter_sha256)), "Results contain unexpected SHA256 values"
+
+
+class TestExecuteBatchedQuery:
+    """Tests for the _execute_batched_query helper method."""
+
+    def test_execute_batched_query_small_list_single_query(self, sqlite_instance: MemoryInterface):
+        """Test that small lists execute a single query."""
+        # Create a small number of pieces (under batch limit)
+        num_pieces = 10
+        pieces = [_create_message_piece() for _ in range(num_pieces)]
+        sqlite_instance.add_message_pieces_to_memory(message_pieces=pieces)
+
+        # Track query calls
+        original_query = sqlite_instance._query_entries
+        call_count = 0
+
+        def spy_query(*args, **kwargs):
+            nonlocal call_count
+            call_count += 1
+            return original_query(*args, **kwargs)
+
+        with patch.object(sqlite_instance, "_query_entries", side_effect=spy_query):
+            all_ids = [piece.id for piece in pieces]
+            results = sqlite_instance.get_message_pieces(prompt_ids=all_ids)
+
+        # Should be a single query for small lists
+        assert call_count == 1
+        assert len(results) == num_pieces
+
+    def test_execute_batched_query_large_list_multiple_queries(self, sqlite_instance: MemoryInterface):
+        """Test that large lists execute multiple separate queries."""
+        # Create pieces exceeding batch limit
+        num_pieces = _MAX_BIND_VARS * 3  # 3 batches needed
+        pieces = [_create_message_piece() for _ in range(num_pieces)]
+        sqlite_instance.add_message_pieces_to_memory(message_pieces=pieces)
+
+        # Track query calls
+        original_query = sqlite_instance._query_entries
+        call_count = 0
+
+        def spy_query(*args, **kwargs):
+            nonlocal call_count
+            call_count += 1
+            return original_query(*args, **kwargs)
+
+        with patch.object(sqlite_instance, "_query_entries", side_effect=spy_query):
+            all_ids = [piece.id for piece in pieces]
+            results = sqlite_instance.get_message_pieces(prompt_ids=all_ids)
+
+        # Should execute 3 separate queries (one per batch)
+        assert call_count == 3, f"Expected 3 queries for 3 batches, got {call_count}"
+        assert len(results) == num_pieces
+
+    def test_execute_batched_query_deduplicates_results(self, sqlite_instance: MemoryInterface):
+        """Test that batched queries properly deduplicate results."""
+        # Create pieces
+        num_pieces = 50
+        pieces = [_create_message_piece() for _ in range(num_pieces)]
+        sqlite_instance.add_message_pieces_to_memory(message_pieces=pieces)
+
+        # Query with the same IDs repeated (should still return unique results).
+        # Each ID is passed twice so the deduplication path is actually exercised.
+        all_ids = [piece.id for piece in pieces]
+        duplicated_ids = all_ids + all_ids
+        results = sqlite_instance.get_message_pieces(prompt_ids=duplicated_ids)
+
+        assert len(results) == num_pieces
+        # Verify no duplicates
+        result_ids = [r.id for r in results]
+        assert len(result_ids) == len(set(result_ids)), "Results contain duplicate entries"
+
+    def test_execute_batched_query_exact_batch_boundary(self, sqlite_instance: MemoryInterface):
+        """Test querying with exactly the batch limit (edge case)."""
+        num_pieces = _MAX_BIND_VARS
+        pieces = [_create_message_piece() for _ in range(num_pieces)]
+        sqlite_instance.add_message_pieces_to_memory(message_pieces=pieces)
+
+        # Track query calls
+        original_query = sqlite_instance._query_entries
+        call_count = 0
+
+        def spy_query(*args, **kwargs):
+            nonlocal call_count
+            call_count += 1
+            return original_query(*args, **kwargs)
+
+        with patch.object(sqlite_instance, "_query_entries", side_effect=spy_query):
+            all_ids = [piece.id for piece in pieces]
+            results = sqlite_instance.get_message_pieces(prompt_ids=all_ids)
+
+        # Exactly at the limit should still be a single query
+        assert call_count == 1, f"Expected 1 query at exact batch limit, got {call_count}"
+        assert len(results) == num_pieces
+
+    def test_batching_with_scores_exceeds_limit(self, sqlite_instance: MemoryInterface):
+        """Test that get_scores handles large numbers of score IDs correctly."""
+        # Create message pieces and scores exceeding the limit
+        num_items = _MAX_BIND_VARS * 2 + 50
+        pieces = [_create_message_piece() for _ in range(num_items)]
+        sqlite_instance.add_message_pieces_to_memory(message_pieces=pieces)
+
+        scores = [_create_score(str(piece.id)) for piece in pieces]
+        sqlite_instance.add_scores_to_memory(scores=scores)
+
+        # Query with all score IDs
+        all_score_ids = [str(score.id) for score in scores]
+
+        # Track query calls
+        original_query = sqlite_instance._query_entries
+        call_count = 0
+
+        def spy_query(*args, **kwargs):
+            nonlocal call_count
+            call_count += 1
+            return original_query(*args, **kwargs)
+
+        with patch.object(sqlite_instance, "_query_entries", side_effect=spy_query):
+            results = sqlite_instance.get_scores(score_ids=all_score_ids)
+
+        # Should execute multiple queries
+        expected_calls = (num_items + _MAX_BIND_VARS - 1) // _MAX_BIND_VARS
+        assert call_count == expected_calls, f"Expected {expected_calls} queries, got {call_count}"
+        assert len(results) == num_items
+
+
+def _create_attack_result(conversation_id: str | None = None, objective: str = "test objective") -> AttackResult:
+    """Create a sample attack result for testing."""
+    return AttackResult(
+        conversation_id=conversation_id or str(uuid.uuid4()),
+        objective=objective,
+    )
+
+
+class TestAttackResultBatching:
+    """Tests for batching in get_attack_results."""
+
+    def test_get_attack_results_with_many_ids(self, sqlite_instance: MemoryInterface):
+        """Test that get_attack_results works with more IDs than the batch limit."""
+        num_results = _MAX_BIND_VARS + 100
+        results = [_create_attack_result(objective=f"objective_{i}") for i in range(num_results)]
+        sqlite_instance.add_attack_results_to_memory(attack_results=results)
+
+        all_ids = [r.attack_result_id for r in results]
+        fetched = sqlite_instance.get_attack_results(attack_result_ids=all_ids)
+
+        assert len(fetched) == num_results, f"Expected {num_results} results, got {len(fetched)}"
+
+    def test_get_attack_results_with_many_objective_sha256(self, sqlite_instance: MemoryInterface):
+        """Test that get_attack_results works with many objective_sha256 values exceeding batch limit."""
+        num_results = _MAX_BIND_VARS + 100
+        results = [_create_attack_result(objective=f"unique_objective_{i}") for i in range(num_results)]
+        sqlite_instance.add_attack_results_to_memory(attack_results=results)
+
+        all_sha256 = [hashlib.sha256(f"unique_objective_{i}".encode()).hexdigest() for i in range(num_results)]
+        fetched = sqlite_instance.get_attack_results(objective_sha256=all_sha256)
+
+        assert len(fetched) == num_results, f"Expected {num_results} results, got {len(fetched)}"
+
+    def test_get_attack_results_small_list_still_works(self, sqlite_instance: MemoryInterface):
+        """Test that small ID lists (under batch limit) still work correctly."""
+        num_results = 10
+        results = [_create_attack_result(objective=f"objective_{i}") for i in range(num_results)]
+        sqlite_instance.add_attack_results_to_memory(attack_results=results)
+
+        all_ids = [r.attack_result_id for r in results]
+        fetched = sqlite_instance.get_attack_results(attack_result_ids=all_ids)
+
+        assert len(fetched) == num_results
+
+    def test_get_attack_results_empty_list_returns_empty(self, sqlite_instance: MemoryInterface):
+        """Test that explicit empty list returns empty results."""
+        results = [_create_attack_result() for _ in range(5)]
+        sqlite_instance.add_attack_results_to_memory(attack_results=results)
+
+        fetched = sqlite_instance.get_attack_results(attack_result_ids=[])
+        assert fetched == []
+
+        fetched = sqlite_instance.get_attack_results(objective_sha256=[])
+        assert fetched == []
+
+
+class TestScoresEmptyList:
+    """Tests for get_scores empty list handling."""
+
+    def test_get_scores_empty_list_returns_empty(self, sqlite_instance: MemoryInterface):
+        """Test that get_scores with explicit empty score_ids returns empty results."""
+        pieces = [_create_message_piece() for _ in range(3)]
+        sqlite_instance.add_message_pieces_to_memory(message_pieces=pieces)
+        scores = [_create_score(str(p.id)) for p in pieces]
+        sqlite_instance.add_scores_to_memory(scores=scores)
+
+        fetched = sqlite_instance.get_scores(score_ids=[])
+        assert fetched == []
+
+
+class TestEffectiveBatchSize:
+    """Tests for effective batch size reduction when small + large params are combined."""
+
+    def test_batch_size_reduced_by_small_params(self, sqlite_instance: MemoryInterface):
+        """Test that batch size is reduced when small IN-clause params consume bind variables."""
+        # Create pieces with unique values so we can filter by both prompt_ids and original_values
+        num_pieces = _MAX_BIND_VARS + 100
+        pieces = [_create_message_piece(original_value=f"value_{i}") for i in range(num_pieces)]
+        sqlite_instance.add_message_pieces_to_memory(message_pieces=pieces)
+
+        # small param: original_values with 200 items (under _MAX_BIND_VARS)
+        small_original_values = [f"value_{i}" for i in range(200)]
+        # large param: prompt_ids with all IDs (over _MAX_BIND_VARS)
+        all_ids = [piece.id for piece in pieces]
+
+        original_query = sqlite_instance._query_entries
+        call_count = 0
+
+        def spy_query(*args, **kwargs):
+            nonlocal call_count
+            call_count += 1
+            return original_query(*args, **kwargs)
+
+        with patch.object(sqlite_instance, "_query_entries", side_effect=spy_query):
+            results = sqlite_instance.get_message_pieces(
+                prompt_ids=all_ids,
+                original_values=small_original_values,
+            )
+
+        # With 200 small param binds, effective_batch_size = 500 - 200 = 300
+        # num_pieces = 600, so ceil(600 / 300) = 2 batches
+        effective_batch_size = _MAX_BIND_VARS - 200
+        expected_calls = (num_pieces + effective_batch_size - 1) // effective_batch_size
+        assert call_count == expected_calls, (
+            f"Expected {expected_calls} queries (effective_batch={effective_batch_size}), got {call_count}"
+        )
+        # Results should be the intersection: only pieces whose original_value is in small_original_values
+        assert len(results) == 200
+
+    def test_custom_batch_size_on_execute_batched_query(self, sqlite_instance: MemoryInterface):
+        """Test that _execute_batched_query respects custom batch_size parameter."""
+        num_pieces = 100
+        pieces = [_create_message_piece() for _ in range(num_pieces)]
+        sqlite_instance.add_message_pieces_to_memory(message_pieces=pieces)
+
+        from pyrit.memory.memory_models import PromptMemoryEntry
+
+        original_query = sqlite_instance._query_entries
+        call_count = 0
+
+        def spy_query(*args, **kwargs):
+            nonlocal call_count
+            call_count += 1
+            return original_query(*args, **kwargs)
+
+        all_ids = [piece.id for piece in pieces]
+        with patch.object(sqlite_instance, "_query_entries", side_effect=spy_query):
+            results = sqlite_instance._execute_batched_query(
+                PromptMemoryEntry,
+                batch_column=PromptMemoryEntry.id,
+                batch_values=all_ids,
+                batch_size=30,
+            )
+
+        # 100 items / batch_size 30 = ceil(100/30) = 4 batches
+        assert call_count == 4, f"Expected 4 queries with batch_size=30, got {call_count}"
+        assert len(results) == num_pieces