Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ repos:
[
types-PyYAML,
types-requests,
types-setuptools,
]
verbose: true

Expand Down
5 changes: 3 additions & 2 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ Added
<https://github.com/omni-us/jsonargparse/pull/554>`__).
- Support for ``NotRequired`` and ``Required`` annotations for ``TypedDict``
keys (`#571 <https://github.com/omni-us/jsonargparse/pull/571>`__).
- ``dataclass`` types now accept ``class_path`` and ``init_args`` as value
(`#581 <https://github.com/omni-us/jsonargparse/pull/581>`__).

Fixed
^^^^^
Expand All @@ -38,8 +40,7 @@ Changed
- Removed shtab experimental warning (`#561
<https://github.com/omni-us/jsonargparse/pull/561>`__).
- For consistency ``add_subclass_arguments`` now sets default ``None`` instead
of ``SUPPRESS`` (`lightning#20103
<https://github.com/Lightning-AI/pytorch-lightning/issue/20103>`__).
of ``SUPPRESS`` (`#568 <https://github.com/omni-us/jsonargparse/pull/568>`__).


v4.32.1 (2024-08-23)
Expand Down
12 changes: 7 additions & 5 deletions jsonargparse/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,7 @@ def _parse_common(
except Exception as ex:
self.error(str(ex), ex)

if not skip_check and not lenient_check.get():
if not skip_check:
self.check_config(cfg, skip_required=skip_required)

if not (with_meta or (with_meta is None and self._default_meta)):
Expand Down Expand Up @@ -675,8 +675,10 @@ def add_subcommands(self, required: bool = True, dest: str = "subcommand", **kwa
"""
if "description" not in kwargs:
kwargs["description"] = "For more details of each subcommand, add it as an argument followed by --help."
with parser_context(parent_parser=self, lenient_check=True):
subcommands: _ActionSubCommands = super().add_subparsers(dest=dest, **kwargs) # type: ignore[assignment]
default_config_files = self.default_config_files
self.default_config_files = []
subcommands: _ActionSubCommands = super().add_subparsers(dest=dest, **kwargs) # type: ignore[assignment]
self.default_config_files = default_config_files
if required:
self.required_args.add(dest)
subcommands._required = required # type: ignore[attr-defined]
Expand Down Expand Up @@ -1069,7 +1071,7 @@ def check_values(cfg):
continue
val = cfg[key]
if action is not None:
if val is None and skip_none:
if (val is None and skip_none) or lenient_check.get():
continue
try:
self._check_value_key(action, val, key, ccfg)
Expand Down Expand Up @@ -1397,7 +1399,7 @@ def default_config_files(self) -> List[str]:
return self._default_config_files

@default_config_files.setter
def default_config_files(self, default_config_files: Optional[List[Union[str, os.PathLike]]]):
def default_config_files(self, default_config_files: Optional[Sequence[Union[str, os.PathLike]]]):
if default_config_files is None:
self._default_config_files = []
elif isinstance(default_config_files, list) and all(
Expand Down
4 changes: 4 additions & 0 deletions jsonargparse/_typehints.py
Original file line number Diff line number Diff line change
Expand Up @@ -633,6 +633,8 @@ def get_class_parser(val_class, sub_add_kwargs=None, skip_args=0):
kwargs = dict(sub_add_kwargs) if sub_add_kwargs else {}
if skip_args:
kwargs.setdefault("skip", set()).add(skip_args)
if is_subclass_spec(kwargs.get("default")):
kwargs["default"] = kwargs["default"].get("init_args")
parser = parent_parser.get()
parser = type(parser)(exit_on_error=False, logger=parser.logger, parser_mode=parser.parser_mode)
remove_actions(parser, (ActionConfigFile, _ActionPrintConfig))
Expand Down Expand Up @@ -1010,6 +1012,8 @@ def adapt_typehints(
if serialize:
val = load_value(parser.dump(val, **dump_kwargs.get()))
elif isinstance(val, (dict, Namespace)):
if is_subclass_spec(val) and get_import_path(typehint) == val.get("class_path"):
val = val.get("init_args")
val = parser.parse_object(val, defaults=sub_defaults.get() or list_item)
elif isinstance(val, NestedArg):
prev_val = prev_val if isinstance(prev_val, Namespace) else None
Expand Down
40 changes: 40 additions & 0 deletions jsonargparse_tests/test_dataclass_like.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,6 +502,46 @@ def test_nested_generic_dataclass(parser):
assert "--x.y.g4 g4 (required, type: dict[str, union[float, bool]])" in help_str


# union mixture tests


@dataclasses.dataclass
class UnionData:
data_a: int = 1
data_b: Optional[str] = None


class UnionClass:
def __init__(self, prm_1: float, prm_2: bool):
self.prm_1 = prm_1
self.prm_2 = prm_2


@pytest.mark.parametrize(
"union_type",
[
Union[UnionData, UnionClass],
Union[UnionClass, UnionData],
],
)
def test_class_path_union_mixture_dataclass_and_class(parser, union_type):
parser.add_argument("--union", type=union_type, enable_path=True)

value = {"class_path": f"{__name__}.UnionData", "init_args": {"data_a": 2, "data_b": "x"}}
cfg = parser.parse_args([f"--union={json.dumps(value)}"])
init = parser.instantiate_classes(cfg)
assert isinstance(init.union, UnionData)
assert dataclasses.asdict(init.union) == {"data_a": 2, "data_b": "x"}
assert yaml.safe_load(parser.dump(cfg))["union"] == value["init_args"]

value = {"class_path": f"{__name__}.UnionClass", "init_args": {"prm_1": 1.2, "prm_2": False}}
cfg = parser.parse_args([f"--union={json.dumps(value)}"])
init = parser.instantiate_classes(cfg)
assert isinstance(init.union, UnionClass)
assert init.union.prm_1 == 1.2
assert yaml.safe_load(parser.dump(cfg))["union"] == value


# final classes tests


Expand Down