diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 9ebc945b..4a13a540 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -56,6 +56,7 @@ repos: [ types-PyYAML, types-requests, + types-setuptools, ] verbose: true diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 2b076e4a..0eb388d6 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -21,6 +21,8 @@ Added `__). - Support for ``NotRequired`` and ``Required`` annotations for ``TypedDict`` keys (`#571 `__). +- ``dataclass`` types now accept ``class_path`` and ``init_args`` as value + (`#581 `__). Fixed ^^^^^ @@ -38,8 +40,7 @@ Changed - Removed shtab experimental warning (`#561 `__). - For consistency ``add_subclass_arguments`` now sets default ``None`` instead - of ``SUPPRESS`` (`lightning#20103 - `__). + of ``SUPPRESS`` (`#568 `__). v4.32.1 (2024-08-23) diff --git a/jsonargparse/_core.py b/jsonargparse/_core.py index 9ec653b1..2bee8523 100644 --- a/jsonargparse/_core.py +++ b/jsonargparse/_core.py @@ -321,7 +321,7 @@ def _parse_common( except Exception as ex: self.error(str(ex), ex) - if not skip_check and not lenient_check.get(): + if not skip_check: self.check_config(cfg, skip_required=skip_required) if not (with_meta or (with_meta is None and self._default_meta)): @@ -675,8 +675,10 @@ def add_subcommands(self, required: bool = True, dest: str = "subcommand", **kwa """ if "description" not in kwargs: kwargs["description"] = "For more details of each subcommand, add it as an argument followed by --help." - with parser_context(parent_parser=self, lenient_check=True): - subcommands: _ActionSubCommands = super().add_subparsers(dest=dest, **kwargs) # type: ignore[assignment] + default_config_files = self.default_config_files + self.default_config_files = [] + subcommands: _ActionSubCommands = super().add_subparsers(dest=dest, **kwargs) # type: ignore[assignment] + self.default_config_files = default_config_files if required: self.required_args.add(dest) subcommands._required = required # type: ignore[attr-defined] @@ -1069,7 +1071,7 @@ def check_values(cfg): continue val = cfg[key] if action is not None: - if val is None and skip_none: + if (val is None and skip_none) or lenient_check.get(): continue try: self._check_value_key(action, val, key, ccfg) @@ -1397,7 +1399,7 @@ def default_config_files(self) -> List[str]: return self._default_config_files @default_config_files.setter - def default_config_files(self, default_config_files: Optional[List[Union[str, os.PathLike]]]): + def default_config_files(self, default_config_files: Optional[Sequence[Union[str, os.PathLike]]]): if default_config_files is None: self._default_config_files = [] elif isinstance(default_config_files, list) and all( diff --git a/jsonargparse/_typehints.py b/jsonargparse/_typehints.py index 3a2f8548..249159af 100644 --- a/jsonargparse/_typehints.py +++ b/jsonargparse/_typehints.py @@ -633,6 +633,8 @@ def get_class_parser(val_class, sub_add_kwargs=None, skip_args=0): kwargs = dict(sub_add_kwargs) if sub_add_kwargs else {} if skip_args: kwargs.setdefault("skip", set()).add(skip_args) + if is_subclass_spec(kwargs.get("default")): + kwargs["default"] = kwargs["default"].get("init_args") parser = parent_parser.get() parser = type(parser)(exit_on_error=False, logger=parser.logger, parser_mode=parser.parser_mode) remove_actions(parser, (ActionConfigFile, _ActionPrintConfig)) @@ -1010,6 +1012,8 @@ def adapt_typehints( if serialize: val = load_value(parser.dump(val, **dump_kwargs.get())) elif isinstance(val, (dict, Namespace)): + if is_subclass_spec(val) and get_import_path(typehint) == val.get("class_path"): + val = val.get("init_args") val = parser.parse_object(val, defaults=sub_defaults.get() or list_item) elif isinstance(val, NestedArg): prev_val = prev_val if isinstance(prev_val, Namespace) else None diff --git a/jsonargparse_tests/test_dataclass_like.py b/jsonargparse_tests/test_dataclass_like.py index 818eb3b2..67a820eb 100644 --- a/jsonargparse_tests/test_dataclass_like.py +++ b/jsonargparse_tests/test_dataclass_like.py @@ -502,6 +502,46 @@ def test_nested_generic_dataclass(parser): assert "--x.y.g4 g4 (required, type: dict[str, union[float, bool]])" in help_str +# union mixture tests + + +@dataclasses.dataclass +class UnionData: + data_a: int = 1 + data_b: Optional[str] = None + + +class UnionClass: + def __init__(self, prm_1: float, prm_2: bool): + self.prm_1 = prm_1 + self.prm_2 = prm_2 + + +@pytest.mark.parametrize( + "union_type", + [ + Union[UnionData, UnionClass], + Union[UnionClass, UnionData], + ], +) +def test_class_path_union_mixture_dataclass_and_class(parser, union_type): + parser.add_argument("--union", type=union_type, enable_path=True) + + value = {"class_path": f"{__name__}.UnionData", "init_args": {"data_a": 2, "data_b": "x"}} + cfg = parser.parse_args([f"--union={json.dumps(value)}"]) + init = parser.instantiate_classes(cfg) + assert isinstance(init.union, UnionData) + assert dataclasses.asdict(init.union) == {"data_a": 2, "data_b": "x"} + assert yaml.safe_load(parser.dump(cfg))["union"] == value["init_args"] + + value = {"class_path": f"{__name__}.UnionClass", "init_args": {"prm_1": 1.2, "prm_2": False}} + cfg = parser.parse_args([f"--union={json.dumps(value)}"]) + init = parser.instantiate_classes(cfg) + assert isinstance(init.union, UnionClass) + assert init.union.prm_1 == 1.2 + assert yaml.safe_load(parser.dump(cfg))["union"] == value + + # final classes tests