Skip to content

Commit

Permalink
Disallow extra nonexistent fields (#290)
Browse files Browse the repository at this point in the history
* Do not allow extra nonexistent fields

* Add test case

* Raise a RuntimeError instead of AssertionError

---------

Co-authored-by: Fabrice Normandin <fabrice.normandin@gmail.com>
  • Loading branch information
anivegesana and lebrice committed Nov 13, 2023
1 parent 5214954 commit f28c90d
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 0 deletions.
6 changes: 6 additions & 0 deletions simple_parsing/wrappers/dataclass_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,17 +295,23 @@ def set_default(self, value: DataclassT | dict | None):
self._default = value
if field_default_values is None:
return
unknown_names = set(field_default_values)
for field_wrapper in self.fields:
if field_wrapper.name not in field_default_values:
continue
# Manually set the default value for this argument.
field_default_value = field_default_values[field_wrapper.name]
field_wrapper.set_default(field_default_value)
unknown_names.remove(field_wrapper.name)
for nested_dataclass_wrapper in self._children:
if nested_dataclass_wrapper.name not in field_default_values:
continue
field_default_value = field_default_values[nested_dataclass_wrapper.name]
nested_dataclass_wrapper.set_default(field_default_value)
unknown_names.remove(nested_dataclass_wrapper.name)
unknown_names.discard("_type_")
if unknown_names:
raise RuntimeError(f"{sorted(unknown_names)} are not fields of {self.dataclass} at path {self.dest!r}!")

@property
def title(self) -> str:
Expand Down
20 changes: 20 additions & 0 deletions test/test_set_defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,26 @@ def test_set_defaults_from_file(tmp_path: Path):
assert args.foo == saved_config


def test_set_broken_defaults_from_file(tmp_path: Path):
parser = ArgumentParser()
parser.add_arguments(Foo, dest="foo")

saved_config = Foo(a=456, b="HOLA")
config_path = tmp_path / "broken_foo.yaml"
broken_yaml = to_dict(saved_config)
broken_yaml["i_do_not_exist"] = 3
with open(config_path, "w") as f:
yaml.dump({"foo": broken_yaml}, f)

with pytest.raises(
RuntimeError,
match=(
r"\['i_do_not_exist'\] are not fields of <class 'test.test_set_defaults.Foo'> at path 'foo'!"
),
):
parser.set_defaults(config_path)


def test_set_defaults_from_file_without_root(tmp_path: Path):
"""test that set_defaults accepts the fields of the dataclass directly, when the parser has
nested_mode=NestedMode.WITHOUT_ROOT.
Expand Down

0 comments on commit f28c90d

Please sign in to comment.