From 115aebe2ec757ce1f21f37d9b049f7dd76a26878 Mon Sep 17 00:00:00 2001 From: Extra Small Date: Sun, 10 May 2026 17:04:08 -0700 Subject: [PATCH] feat: preserve unknown config keys in extensions --- pyrit/setup/configuration_loader.py | 11 +++++++- tests/unit/setup/test_configuration_loader.py | 27 +++++++++++++++++++ 2 files changed, 37 insertions(+), 1 deletion(-) diff --git a/pyrit/setup/configuration_loader.py b/pyrit/setup/configuration_loader.py index 184a235f65..7b279b773f 100644 --- a/pyrit/setup/configuration_loader.py +++ b/pyrit/setup/configuration_loader.py @@ -132,6 +132,7 @@ class ConfigurationLoader(YamlLoadable): operator: Optional[str] = None operation: Optional[str] = None scenario: Optional[Union[str, dict[str, Any]]] = None + extensions: dict[str, Any] = field(default_factory=dict) def __post_init__(self) -> None: """Validate and normalize the configuration after loading.""" @@ -254,7 +255,15 @@ def from_dict(cls, data: dict[str, Any]) -> "ConfigurationLoader": """ # Filter out None values only - empty lists are meaningful ("load nothing") filtered_data = {k: v for k, v in data.items() if v is not None} - return cls(**filtered_data) + known_fields = set(cls.__dataclass_fields__.keys()) + known_data = {k: v for k, v in filtered_data.items() if k in known_fields and k != "extensions"} + extra_data = {k: v for k, v in filtered_data.items() if k not in known_fields} + if "extensions" in filtered_data: + extensions = filtered_data["extensions"] + if not isinstance(extensions, dict): + raise ValueError(f"ConfigurationLoader.extensions must be a dict. Got: {type(extensions).__name__}") + extra_data = {**extensions, **extra_data} + return cls(**known_data, extensions=extra_data) @staticmethod def load_with_overrides( diff --git a/tests/unit/setup/test_configuration_loader.py b/tests/unit/setup/test_configuration_loader.py index 74caf3ad3c..356007bccc 100644 --- a/tests/unit/setup/test_configuration_loader.py +++ b/tests/unit/setup/test_configuration_loader.py @@ -165,6 +165,33 @@ def test_from_dict_filters_none_values(self): assert config.memory_db_type == "sqlite" # Normalized to snake_case assert config.initializers == [] # Uses default, not None + def test_from_dict_preserves_unknown_top_level_keys_in_extensions(self): + """Unknown top-level keys should be preserved for downstream tooling.""" + data = { + "memory_db_type": "in_memory", + "targets": [{"name": "x"}], + "scan_modes": {"default": {"threshold": 0.8}}, + } + config = ConfigurationLoader.from_dict(data) + assert config.memory_db_type == "in_memory" + assert config.extensions == { + "targets": [{"name": "x"}], + "scan_modes": {"default": {"threshold": 0.8}}, + } + + def test_from_dict_merges_explicit_extensions_with_unknown_keys(self): + data = { + "memory_db_type": "sqlite", + "extensions": {"team": "red"}, + "targets": [{"name": "x"}], + } + config = ConfigurationLoader.from_dict(data) + assert config.extensions == {"team": "red", "targets": [{"name": "x"}]} + + def test_from_dict_rejects_non_dict_extensions(self): + with pytest.raises(ValueError, match="extensions must be a dict"): + ConfigurationLoader.from_dict({"extensions": ["not", "a", "dict"]}) + def test_from_yaml_file(self): """Test loading configuration from a YAML file.""" yaml_content = """