Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ Fixed
<https://github.com/omni-us/jsonargparse/issues/516>`__).
- ``--print_config`` failing in some cases (`#517
<https://github.com/omni-us/jsonargparse/issues/517>`__).
- Callable that returns class not using required parameter default from lambda
(`#523 <https://github.com/omni-us/jsonargparse/pull/523>`__).


v4.29.0 (2024-05-24)
Expand Down
12 changes: 7 additions & 5 deletions jsonargparse/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1326,12 +1326,14 @@ def merge_config(self, cfg_from: Namespace, cfg_to: Namespace) -> Namespace:
Returns:
A new object with the merged configuration.
"""
cfg = cfg_to.clone()
cfg_from = cfg_from.clone()
cfg_to = cfg_to.clone()
with parser_context(parent_parser=self):
ActionTypeHint.discard_init_args_on_class_path_change(self, cfg, cfg_from)
cfg.update(cfg_from)
ActionTypeHint.apply_appends(self, cfg)
return cfg
ActionTypeHint.discard_init_args_on_class_path_change(self, cfg_to, cfg_from)
ActionTypeHint.delete_init_args_required_none(cfg_from, cfg_to)
cfg_to.update(cfg_from)
ActionTypeHint.apply_appends(self, cfg_to)
return cfg_to

def _check_value_key(self, action: argparse.Action, value: Any, key: str, cfg: Optional[Namespace]) -> Any:
"""Checks the value for a given action.
Expand Down
15 changes: 15 additions & 0 deletions jsonargparse/_typehints.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,6 +396,21 @@ def discard_init_args_on_class_path_change(parser_or_action, prev_cfg, cfg):
keys = keys[: num + 1] + [k for k in keys[num + 1 :] if not k.startswith(key + ".")]
num += 1

@staticmethod
def delete_init_args_required_none(cfg_from, cfg_to):
for key, val in cfg_from.items(branches=True):
if isinstance(val, Namespace) and val.get("class_path") and val.get("init_args"):
skip_keys = [
k
for k, v in val.init_args.__dict__.items()
if v is None and cfg_to.get(f"{key}.init_args.{k}") is not None
]
if skip_keys:
parser = ActionTypeHint.get_class_parser(val.class_path)
for skip_key in skip_keys:
if skip_key in parser.required_args:
del val.init_args[skip_key]

@staticmethod
@contextmanager
def subclass_arg_context(parser):
Expand Down
39 changes: 37 additions & 2 deletions jsonargparse_tests/test_typehints.py
Original file line number Diff line number Diff line change
Expand Up @@ -816,9 +816,10 @@ def __init__(self, optimizer: Optimizer, last_epoch: int = -1):


class ReduceLROnPlateau:
def __init__(self, optimizer: Optimizer, monitor: str):
def __init__(self, optimizer: Optimizer, monitor: str, factor: float = 0.1):
self.optimizer = optimizer
self.monitor = monitor
self.factor = factor


def test_callable_args_return_type_union_of_classes(parser, subtests):
Expand Down Expand Up @@ -846,7 +847,7 @@ def test_callable_args_return_type_union_of_classes(parser, subtests):
}
cfg = parser.parse_args([f"--scheduler={value}"])
assert f"{__name__}.ReduceLROnPlateau" == cfg.scheduler.class_path
assert Namespace(monitor="loss") == cfg.scheduler.init_args
assert Namespace(monitor="loss", factor=0.1) == cfg.scheduler.init_args
init = parser.instantiate_classes(cfg)
scheduler = init.scheduler(optimizer)
assert isinstance(scheduler, ReduceLROnPlateau)
Expand Down Expand Up @@ -948,6 +949,40 @@ def test_callable_zero_args_return_type_class(parser):
assert activation.negative_slope == 0.05


class ModelRequiredCallableArg:
def __init__(
self,
scheduler: Callable[[Optimizer], ReduceLROnPlateau] = lambda o: ReduceLROnPlateau(o, monitor="acc"),
):
self.scheduler = scheduler


def test_callable_return_class_required_arg_from_default(parser):
parser.add_argument("--cfg", action="config")
parser.add_argument("--model", type=ModelRequiredCallableArg)

cfg = parser.parse_args(["--model=ModelRequiredCallableArg"])
assert cfg.model.init_args.scheduler.class_path == f"{__name__}.ReduceLROnPlateau"
assert cfg.model.init_args.scheduler.init_args == Namespace(monitor="acc", factor=0.1)

config = {
"model": {
"class_path": f"{__name__}.ModelRequiredCallableArg",
"init_args": {
"scheduler": {
"class_path": f"{__name__}.ReduceLROnPlateau",
"init_args": {
"factor": 0.5,
},
},
},
}
}
cfg = parser.parse_args([f"--cfg={config}"])
assert cfg.model.init_args.scheduler.class_path == f"{__name__}.ReduceLROnPlateau"
assert cfg.model.init_args.scheduler.init_args == Namespace(monitor="acc", factor=0.5)


# lazy_instance tests


Expand Down