1 change: 1 addition & 0 deletions tests/integration/converter/test_notebooks_converter.py
@@ -13,6 +13,7 @@
 nb_directory_path = pathlib.Path(path.DOCS_CODE_PATH, "converters").resolve()
 
 skipped_files = [
+    "2_audio_converters.ipynb",  # requires Azure Speech API key
     "7_human_converter.ipynb",  # requires human input
 ]
 
6 changes: 5 additions & 1 deletion tests/integration/datasets/test_notebooks_datasets.py
@@ -12,10 +12,14 @@
 
 nb_directory_path = pathlib.Path(path.DOCS_CODE_PATH, "datasets").resolve()
 
+skipped_files = [
+    "2_seed_programming.ipynb",  # requires OpenAI API credentials
+]
+
 
 @pytest.mark.parametrize(
     "file_name",
-    [file for file in os.listdir(nb_directory_path) if file.endswith(".ipynb")],
+    [file for file in os.listdir(nb_directory_path) if file.endswith(".ipynb") and file not in skipped_files],
 )
 def test_execute_notebooks(file_name):
     nb_path = pathlib.Path(nb_directory_path, file_name).resolve()
@@ -108,7 +108,8 @@ async def _fetch_dataset(self, *, cache=True):
 
         return type(f"_Mock_{name}", (_RemoteDatasetLoader,), attrs)
 
-    def test_filter_matches_correct_remote_provider(self):
+    @pytest.mark.asyncio
+    async def test_filter_matches_correct_remote_provider(self):
         """Filter by size returns only providers that match."""
         large_cls = self._make_remote_provider_cls(
             name="large_ds",
@@ -130,12 +131,13 @@ def test_filter_matches_correct_remote_provider(self):
             {"Large": large_cls, "Small": small_cls},
             clear=True,
         ):
-            names = SeedDatasetProvider.get_all_dataset_names_async(
-                filters=SeedDatasetFilter(sizes=["large"]),
+            names = await SeedDatasetProvider.get_all_dataset_names_async(
+                filters=SeedDatasetFilter(size={"large"}),
             )
             assert names == ["large_ds"]
 
-    def test_filter_all_tag_returns_everything(self):
+    @pytest.mark.asyncio
+    async def test_filter_all_tag_returns_everything(self):
         """tags={'all'} bypasses filtering and returns every provider."""
         cls1 = self._make_remote_provider_cls(
             name="ds_a",
@@ -157,12 +159,13 @@ def test_filter_all_tag_returns_everything(self):
             {"A": cls1, "B": cls2},
             clear=True,
         ):
-            names = SeedDatasetProvider.get_all_dataset_names_async(
+            names = await SeedDatasetProvider.get_all_dataset_names_async(
                 filters=SeedDatasetFilter(tags={"all"}),
             )
             assert sorted(names) == ["ds_a", "ds_b"]
 
-    def test_multi_axis_filter(self):
+    @pytest.mark.asyncio
+    async def test_multi_axis_filter(self):
         """Multiple filter axes are ANDed together."""
         cls1 = self._make_remote_provider_cls(
             name="text_large",
@@ -184,10 +187,10 @@ def test_multi_axis_filter(self):
             {"TL": cls1, "IL": cls2},
             clear=True,
         ):
-            names = SeedDatasetProvider.get_all_dataset_names_async(
+            names = await SeedDatasetProvider.get_all_dataset_names_async(
                 filters=SeedDatasetFilter(
-                    sizes=["large"],
-                    modalities=["text"],
+                    size={"large"},
+                    modalities={"text"},
                 ),
             )
             assert names == ["text_large"]
@@ -217,7 +220,8 @@ def init_fn(self):
             {"__init__": make_init(yaml_path), "should_register": False, "__module__": __name__},
         )
 
-    def test_local_filter_by_size(self, tmp_path):
+    @pytest.mark.asyncio
+    async def test_local_filter_by_size(self, tmp_path):
         """Local YAML with size metadata is correctly coerced and filtered."""
         large_yaml = tmp_path / "large_ds.prompt"
         large_yaml.write_text(
@@ -252,14 +256,15 @@ def test_local_filter_by_size(self, tmp_path):
             {"Large": large_cls, "Small": small_cls},
             clear=True,
         ):
-            names = SeedDatasetProvider.get_all_dataset_names_async(
-                filters=SeedDatasetFilter(sizes=["large"]),
+            names = await SeedDatasetProvider.get_all_dataset_names_async(
+                filters=SeedDatasetFilter(size={"large"}),
             )
             # dataset_name falls back to file stem when SeedDataset.from_yaml_file
             # rejects extra keys like "size" during __init__ pre-loading
             assert names == ["large_ds"]
 
-    def test_local_filter_by_tags(self, tmp_path):
+    @pytest.mark.asyncio
+    async def test_local_filter_by_tags(self, tmp_path):
         """Local YAML tags (list) are coerced to set for intersection."""
         yaml_path = tmp_path / "tagged.prompt"
         yaml_path.write_text(
@@ -284,17 +289,18 @@ def test_local_filter_by_tags(self, tmp_path):
         ):
             # dataset_name falls back to file stem ("tagged") when
             # SeedDataset.from_yaml_file rejects extra keys like "tags"
-            matched = SeedDatasetProvider.get_all_dataset_names_async(
+            matched = await SeedDatasetProvider.get_all_dataset_names_async(
                 filters=SeedDatasetFilter(tags={"safety"}),
             )
             assert matched == ["tagged"]
 
-            not_matched = SeedDatasetProvider.get_all_dataset_names_async(
+            not_matched = await SeedDatasetProvider.get_all_dataset_names_async(
                 filters=SeedDatasetFilter(tags={"unrelated"}),
             )
             assert not_matched == []
 
-    def test_local_no_metadata_skipped(self, tmp_path):
+    @pytest.mark.asyncio
+    async def test_local_no_metadata_skipped(self, tmp_path):
         """Local YAML without metadata fields is skipped when filters are provided."""
         yaml_path = tmp_path / "bare.prompt"
         yaml_path.write_text(
@@ -313,11 +319,11 @@ def test_local_no_metadata_skipped(self, tmp_path):
             clear=True,
         ):
             # Without filters, the dataset is included
-            all_names = SeedDatasetProvider.get_all_dataset_names_async()
+            all_names = await SeedDatasetProvider.get_all_dataset_names_async()
             assert "bare_local" in all_names
 
             # With filters, it's skipped (no metadata to match against)
-            filtered = SeedDatasetProvider.get_all_dataset_names_async(
+            filtered = await SeedDatasetProvider.get_all_dataset_names_async(
                 filters=SeedDatasetFilter(tags={"safety"}),
             )
             assert filtered == []
@@ -392,8 +398,8 @@ async def test_user_discovers_and_fetches_filtered_dataset(self, tmp_path):
             clear=True,
         ):
             # --- Step 1: User filters by harm_categories ---
-            names = SeedDatasetProvider.get_all_dataset_names_async(
-                filters=SeedDatasetFilter(harm_categories=["cybercrime"]),
+            names = await SeedDatasetProvider.get_all_dataset_names_async(
+                filters=SeedDatasetFilter(harm_categories={"cybercrime"}),
             )
             assert len(names) == 1
             dataset_name = names[0]
@@ -410,9 +416,9 @@ async def test_user_discovers_and_fetches_filtered_dataset(self, tmp_path):
 
             # --- Step 3: User inspects metadata ---
             provider = matching_cls()
-            metadata = provider._parse_metadata()
+            metadata = await provider._parse_metadata()
             assert metadata is not None
-            assert metadata.harm_categories == ["cybercrime"]
+            assert metadata.harm_categories == {"cybercrime"}
 
     @pytest.mark.asyncio
     async def test_user_fetches_unfiltered(self, tmp_path):
@@ -447,7 +453,7 @@ async def test_user_fetches_unfiltered(self, tmp_path):
             {"One": cls1, "Two": cls2},
             clear=True,
         ):
-            names = SeedDatasetProvider.get_all_dataset_names_async()
+            names = await SeedDatasetProvider.get_all_dataset_names_async()
             assert len(names) == 2
 
             datasets = await SeedDatasetProvider.fetch_datasets_async()
@@ -480,7 +486,8 @@ def init_fn(self):
             {"__init__": make_init(yaml_path), "should_register": False, "__module__": __name__},
         )
 
-    def test_all_tag_includes_datasets_without_metadata(self, tmp_path):
+    @pytest.mark.asyncio
+    async def test_all_tag_includes_datasets_without_metadata(self, tmp_path):
         """
         A dataset whose YAML has no metadata fields at all is normally
         skipped when filters are present. tags={'all'} overrides that.
@@ -502,18 +509,19 @@ def test_all_tag_includes_datasets_without_metadata(self, tmp_path):
             clear=True,
         ):
             # Normal filter skips it
-            filtered = SeedDatasetProvider.get_all_dataset_names_async(
+            filtered = await SeedDatasetProvider.get_all_dataset_names_async(
                 filters=SeedDatasetFilter(tags={"safety"}),
             )
             assert filtered == []
 
             # 'all' includes it
-            all_names = SeedDatasetProvider.get_all_dataset_names_async(
+            all_names = await SeedDatasetProvider.get_all_dataset_names_async(
                 filters=SeedDatasetFilter(tags={"all"}),
             )
             assert "bare_dataset" in all_names
 
-    def test_all_tag_ignores_other_filter_axes(self, tmp_path):
+    @pytest.mark.asyncio
+    async def test_all_tag_ignores_other_filter_axes(self, tmp_path):
         """
         tags={'all'} returns everything even when other filter axes
         would exclude datasets.
@@ -538,18 +546,19 @@ def test_all_tag_ignores_other_filter_axes(self, tmp_path):
             clear=True,
         ):
             # Size filter alone would exclude it
-            size_filtered = SeedDatasetProvider.get_all_dataset_names_async(
-                filters=SeedDatasetFilter(sizes=["large"]),
+            size_filtered = await SeedDatasetProvider.get_all_dataset_names_async(
+                filters=SeedDatasetFilter(size={"large"}),
             )
             assert size_filtered == []
 
             # 'all' tag overrides the size filter
-            all_names = SeedDatasetProvider.get_all_dataset_names_async(
-                filters=SeedDatasetFilter(tags={"all"}, sizes=["large"]),
+            all_names = await SeedDatasetProvider.get_all_dataset_names_async(
+                filters=SeedDatasetFilter(tags={"all"}, size={"large"}),
             )
             assert "small" in all_names
 
-    def test_all_tag_with_mixed_metadata_and_bare_datasets(self, tmp_path):
+    @pytest.mark.asyncio
+    async def test_all_tag_with_mixed_metadata_and_bare_datasets(self, tmp_path):
         """
         With a mix of metadata-rich and metadata-bare datasets,
         tags={'all'} returns all of them.
@@ -585,7 +594,7 @@ def test_all_tag_with_mixed_metadata_and_bare_datasets(self, tmp_path):
             {"Rich": rich_cls, "Bare": bare_cls},
             clear=True,
         ):
-            all_names = SeedDatasetProvider.get_all_dataset_names_async(
+            all_names = await SeedDatasetProvider.get_all_dataset_names_async(
                 filters=SeedDatasetFilter(tags={"all"}),
             )
             assert len(all_names) == 2
@@ -634,33 +643,27 @@ async def test_harmbench_discoverable_via_filter(self):
         assert "harmbench" in names_by_harm
 
     @pytest.mark.asyncio
-    async def test_harmbench_loads_and_stores_in_memory(self):
+    async def test_harmbench_loads_and_stores_in_memory(self, sqlite_instance):
         """HarmBench can be fetched and stored in memory for scenario use."""
-        from pyrit.memory import CentralMemory
-        from pyrit.setup import initialize_pyrit_async
-
-        await initialize_pyrit_async(memory_db_type="InMemory")
-
         datasets = await SeedDatasetProvider.fetch_datasets_async(
             dataset_names=["harmbench"],
         )
         assert len(datasets) == 1
         assert datasets[0].dataset_name == "harmbench"
         assert len(datasets[0].seeds) > 0
 
-        memory = CentralMemory.get_memory_instance()
-        await memory.add_seed_datasets_to_memory_async(
+        await sqlite_instance.add_seed_datasets_to_memory_async(
            datasets=datasets,
             added_by="test",
         )
 
         # Verify seeds are queryable from memory (this is what scenarios do)
-        seed_groups = memory.get_seed_groups(dataset_name="harmbench")
+        seed_groups = sqlite_instance.get_seed_groups(dataset_name="harmbench")
         assert seed_groups is not None
         assert len(list(seed_groups)) > 0
 
     @pytest.mark.asyncio
-    async def test_red_team_agent_initializes_with_harmbench(self):
+    async def test_red_team_agent_initializes_with_harmbench(self, sqlite_instance):
         """
         RedTeamAgent can initialize with harmbench dataset loaded in memory.
 
@@ -671,23 +674,18 @@ async def test_red_team_agent_initializes_with_harmbench(self):
         from unittest.mock import MagicMock
 
         from pyrit.executor.attack.core.attack_config import AttackScoringConfig
-        from pyrit.memory import CentralMemory
         from pyrit.prompt_target import TextTarget
         from pyrit.scenario.scenarios.foundry.red_team_agent import (
             FoundryStrategy,
             RedTeamAgent,
         )
         from pyrit.score.true_false.true_false_scorer import TrueFalseScorer
-        from pyrit.setup import initialize_pyrit_async
 
-        await initialize_pyrit_async(memory_db_type="InMemory")
-
         # Load harmbench into memory
         datasets = await SeedDatasetProvider.fetch_datasets_async(
             dataset_names=["harmbench"],
         )
-        memory = CentralMemory.get_memory_instance()
-        await memory.add_seed_datasets_to_memory_async(
+        await sqlite_instance.add_seed_datasets_to_memory_async(
             datasets=datasets,
             added_by="test",
         )
1 change: 1 addition & 0 deletions tests/integration/embeddings/test_openai_embedding.py
@@ -9,6 +9,7 @@
 from pyrit.embedding import OpenAITextEmbedding
 
 
+@pytest.mark.run_only_if_all_tests
 @pytest.mark.parametrize(
     "endpoint_env,key_env,model_env",
     [
1 change: 1 addition & 0 deletions tests/integration/memory/test_notebooks_memory.py
@@ -15,6 +15,7 @@
 skipped_files = [
     "6_azure_sql_memory.ipynb",  # todo: requires Azure SQL setup, remove following completion of #4001
     "7_azure_sql_memory_attacks.ipynb",  # todo: remove following completion of #4001
+    "embeddings.ipynb",  # requires OpenAI embedding API key
 ]
 
 
@@ -81,6 +81,7 @@ async def test_azure_content_filter_scorer_long_text_chunking_integration(memory
     assert max_score > 0, "text should have > 0 score"
 
 
+@pytest.mark.run_only_if_all_tests
 @pytest.mark.asyncio
 async def test_azure_content_filter_scorer_image_with_api_key(memory) -> None:
     """
@@ -109,6 +110,7 @@ async def test_azure_content_filter_scorer_image_with_api_key(memory) -> None:
     assert max_score < 0.5, "Architecture diagram should have low harm scores"
 
 
+@pytest.mark.run_only_if_all_tests
 @pytest.mark.asyncio
 async def test_azure_content_filter_scorer_text_with_api_key(memory) -> None:
     """
1 change: 1 addition & 0 deletions tests/integration/score/test_scorer_notebooks.py
@@ -14,6 +14,7 @@
 
 skipped_files = [
     "5_human_in_the_loop_scorer.ipynb",  # requires human input
+    "prompt_shield_scorer.ipynb",  # requires Azure Content Safety API key
 ]
 
 
1 change: 1 addition & 0 deletions tests/integration/targets/test_notebooks_targets.py
@@ -13,6 +13,7 @@
 nb_directory_path = pathlib.Path(path.DOCS_CODE_PATH, "targets").resolve()
 
 skipped_files = [
+    "4_openai_video_target.ipynb",  # requires OpenAI video API key
     "10_1_playwright_target.ipynb",  # Playwright installation takes too long
     "10_2_playwright_target_copilot.ipynb",  # Playwright installation takes too long, plus requires M365 account
     "10_3_websocket_copilot_target.ipynb",  # WebSocket Copilot target requires manually pasting tokens