diff --git a/CHANGELOG.rst b/CHANGELOG.rst index b5624a22..c80d4eae 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -48,6 +48,7 @@ v4.19.0 (2022-12-27) Added ^^^^^ +- ``add_dataclass_arguments`` now supports the ``fail_untyped`` parameter - ``CLI`` now supports the ``fail_untyped`` and ``parser_class`` parameters. - ``bytes`` and ``bytearray`` registered on first use and decodes from standard Base64. diff --git a/jsonargparse/signatures.py b/jsonargparse/signatures.py index 1e3cf034..cd19e3fb 100644 --- a/jsonargparse/signatures.py +++ b/jsonargparse/signatures.py @@ -312,6 +312,7 @@ def _add_signature_parameter( kwargs['required'] = True is_subclass_typehint = False is_final_class_typehint = is_final_class(annotation) + is_pure_dataclass_typehint = is_pure_dataclass(annotation) dest = (nested_key+'.' if nested_key else '') + name args = [dest if is_required and as_positional else '--'+dest] if param.origin: @@ -326,7 +327,7 @@ def _add_signature_parameter( if annotation in {str, int, float, bool} or \ is_subclass(annotation, (str, int, float)) or \ is_final_class_typehint or \ - is_pure_dataclass(annotation): + is_pure_dataclass_typehint: kwargs['type'] = annotation elif annotation != inspect_empty: try: @@ -353,7 +354,7 @@ def _add_signature_parameter( 'sub_configs': sub_configs, 'instantiate': instantiate, } - if is_final_class_typehint: + if is_final_class_typehint or is_pure_dataclass_typehint: kwargs.update(sub_add_kwargs) action = group.add_argument(*args, **kwargs) action.sub_add_kwargs = sub_add_kwargs @@ -370,6 +371,7 @@ def add_dataclass_arguments( nested_key: str, default: Optional[Union[Type, dict]] = None, as_group: bool = True, + fail_untyped: bool = True, **kwargs ) -> List[str]: """Adds arguments from a dataclass based on its field types and docstrings. @@ -379,6 +381,7 @@ def add_dataclass_arguments( nested_key: Key for nested namespace. default: Value for defaults. Must be instance of or kwargs for theclass. as_group: Whether arguments should be added to a new argument group. + fail_untyped: Whether to raise exception if a required parameter does not have a type. Returns: The list of arguments added. @@ -413,6 +416,7 @@ def add_dataclass_arguments( nested_key, params[field.name], added_args, + fail_untyped=fail_untyped, default=defaults.get(field.name, inspect_empty), ) diff --git a/jsonargparse_tests/test_signatures.py b/jsonargparse_tests/test_signatures.py index 6a4d3cb5..35400a15 100755 --- a/jsonargparse_tests/test_signatures.py +++ b/jsonargparse_tests/test_signatures.py @@ -1571,6 +1571,32 @@ class MyDataClass: self.assertEqual({'a': 1.2, 'b': 3.4}, cfg['a2']) + def test_dataclass_fail_untyped(self): + + class MyClass: + def __init__(self, c1) -> None: + self.c1 = c1 + + @dataclasses.dataclass + class MyDataclass: + a1: MyClass + a2: str = "a2" + a3: str = "a3" + + parser = ArgumentParser(exit_on_error=False) + parser.add_argument('--cfg', type=MyDataclass, fail_untyped=False) + + with mock_module(MyDataclass, MyClass) as module: + class_path = f'"class_path": "{module}.MyClass"' + init_args = '"init_args": {"c1": 1}' + cfg = parser.parse_args(['--cfg.a1={'+class_path+', '+init_args+'}']) + cfg = parser.instantiate_classes(cfg) + self.assertIsInstance(cfg['cfg'], MyDataclass) + self.assertIsInstance(cfg['cfg'].a1, MyClass) + self.assertIsInstance(cfg['cfg'].a2, str) + self.assertIsInstance(cfg['cfg'].a3, str) + + def test_compose_dataclasses(self): @dataclasses.dataclass