Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion pyrit/setup/configuration_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ class ConfigurationLoader(YamlLoadable):
operator: Optional[str] = None
operation: Optional[str] = None
scenario: Optional[Union[str, dict[str, Any]]] = None
extensions: dict[str, Any] = field(default_factory=dict)

def __post_init__(self) -> None:
"""Validate and normalize the configuration after loading."""
Expand Down Expand Up @@ -254,7 +255,15 @@ def from_dict(cls, data: dict[str, Any]) -> "ConfigurationLoader":
"""
# Filter out None values only - empty lists are meaningful ("load nothing")
filtered_data = {k: v for k, v in data.items() if v is not None}
return cls(**filtered_data)
known_fields = set(cls.__dataclass_fields__.keys())
known_data = {k: v for k, v in filtered_data.items() if k in known_fields and k != "extensions"}
extra_data = {k: v for k, v in filtered_data.items() if k not in known_fields}
if "extensions" in filtered_data:
extensions = filtered_data["extensions"]
if not isinstance(extensions, dict):
raise ValueError(f"ConfigurationLoader.extensions must be a dict. Got: {type(extensions).__name__}")
extra_data = {**extensions, **extra_data}
return cls(**known_data, extensions=extra_data)

@staticmethod
def load_with_overrides(
Expand Down
27 changes: 27 additions & 0 deletions tests/unit/setup/test_configuration_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,33 @@ def test_from_dict_filters_none_values(self):
assert config.memory_db_type == "sqlite" # Normalized to snake_case
assert config.initializers == [] # Uses default, not None

def test_from_dict_preserves_unknown_top_level_keys_in_extensions(self):
"""Unknown top-level keys should be preserved for downstream tooling."""
data = {
"memory_db_type": "in_memory",
"targets": [{"name": "x"}],
"scan_modes": {"default": {"threshold": 0.8}},
}
config = ConfigurationLoader.from_dict(data)
assert config.memory_db_type == "in_memory"
assert config.extensions == {
"targets": [{"name": "x"}],
"scan_modes": {"default": {"threshold": 0.8}},
}

def test_from_dict_merges_explicit_extensions_with_unknown_keys(self):
data = {
"memory_db_type": "sqlite",
"extensions": {"team": "red"},
"targets": [{"name": "x"}],
}
config = ConfigurationLoader.from_dict(data)
assert config.extensions == {"team": "red", "targets": [{"name": "x"}]}

def test_from_dict_rejects_non_dict_extensions(self):
with pytest.raises(ValueError, match="extensions must be a dict"):
ConfigurationLoader.from_dict({"extensions": ["not", "a", "dict"]})

def test_from_yaml_file(self):
"""Test loading configuration from a YAML file."""
yaml_content = """
Expand Down