Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,14 @@ Fixed
- Required and optional ``TypedDict`` keys are now correctly inferred when one
inherits one ``TypedDict`` from another with different totality (`#597
<https://github.com/omni-us/jsonargparse/pull/597>`__).
- Callables that return class not considering previous values (`#603
<https://github.com/omni-us/jsonargparse/pull/603>`__).

Changed
^^^^^^^
- Callables that return class with class default now normalizes the default to
a subclass spec with ``class_path`` (`#603
<https://github.com/omni-us/jsonargparse/pull/603>`__).


v4.33.2 (2024-10-07)
Expand Down
1 change: 0 additions & 1 deletion jsonargparse/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1336,7 +1336,6 @@ def merge_config(self, cfg_from: Namespace, cfg_to: Namespace) -> Namespace:
cfg_to = cfg_to.clone()
with parser_context(parent_parser=self):
ActionTypeHint.discard_init_args_on_class_path_change(self, cfg_to, cfg_from)
ActionTypeHint.delete_init_args_required_none(cfg_from, cfg_to)
ActionTypeHint.delete_not_required_args(cfg_from, cfg_to)
cfg_to.update(cfg_from)
ActionTypeHint.apply_appends(self, cfg_to)
Expand Down
20 changes: 3 additions & 17 deletions jsonargparse/_typehints.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,8 @@ def normalize_default(self, default):
default_type = type(default)
if not is_subclass(default_type, UnknownDefault) and self.is_subclass_typehint(default_type):
raise ValueError("Subclass types require as default either a dict with class_path or a lazy instance.")
elif ActionTypeHint.is_return_subclass_typehint(self._typehint) and inspect.isclass(default):
default = {"class_path": get_import_path(default)}
return default

@staticmethod
Expand Down Expand Up @@ -430,21 +432,6 @@ def discard_init_args_on_class_path_change(parser_or_action, prev_cfg, cfg):
keys = keys[: num + 1] + [k for k in keys[num + 1 :] if not k.startswith(key + ".")]
num += 1

@staticmethod
def delete_init_args_required_none(cfg_from, cfg_to):
for key, val in cfg_from.items(branches=True):
if isinstance(val, Namespace) and val.get("class_path") and val.get("init_args"):
skip_keys = [
k
for k, v in val.init_args.__dict__.items()
if v is None and cfg_to.get(f"{key}.init_args.{k}") is not None
]
if skip_keys:
parser = ActionTypeHint.get_class_parser(val.class_path)
for skip_key in skip_keys:
if skip_key in parser.required_args:
del val.init_args[skip_key]

@staticmethod
def delete_not_required_args(cfg_from, cfg_to):
for key, val in list(cfg_to.items(branches=True)):
Expand Down Expand Up @@ -1021,6 +1008,7 @@ def adapt_typehints(
sub_add_kwargs,
skip_args=num_partial_args,
partial_classes=partial_classes,
prev_val=prev_val,
)
except (ImportError, AttributeError, ArgumentError) as ex:
raise_unexpected_value(f"Type {typehint} expects a function or a callable class: {ex}", val, ex)
Expand Down Expand Up @@ -1168,8 +1156,6 @@ def subclass_spec_as_namespace(val, prev_val=None):
val = Namespace({root_key: val})
if isinstance(prev_val, str):
prev_val = Namespace(class_path=prev_val)
elif inspect.isclass(prev_val):
prev_val = Namespace(class_path=get_import_path(prev_val))
if isinstance(val, dict):
val = Namespace(val)
if "init_args" in val and isinstance(val["init_args"], dict):
Expand Down
37 changes: 37 additions & 0 deletions jsonargparse_tests/test_typehints.py
Original file line number Diff line number Diff line change
Expand Up @@ -940,6 +940,7 @@ def test_callable_args_return_type_class(parser, subtests):

with subtests.test("default"):
cfg = parser.get_defaults()
assert cfg.optimizer.class_path == f"{__name__}.SGD"
init = parser.instantiate_classes(cfg)
optimizer = init.optimizer([0.1, 2, 3])
assert isinstance(optimizer, SGD)
Expand Down Expand Up @@ -1226,6 +1227,42 @@ def test_callable_return_class_required_arg_from_default(parser):
assert cfg.model.init_args.scheduler.init_args == Namespace(monitor="acc", factor=0.5)


class ModelListCallableReturnClass:
def __init__(
self,
schedulers: List[Callable[[Optimizer], Union[StepLR, ReduceLROnPlateau]]] = [],
):
self.schedulers = schedulers


def test_list_callable_return_class(parser):
parser.add_argument("--cfg", action="config")
parser.add_argument("--model", type=ModelListCallableReturnClass)

config = {
"model": {
"class_path": f"{__name__}.ModelListCallableReturnClass",
"init_args": {
"schedulers": [
{
"class_path": f"{__name__}.StepLR",
},
{
"class_path": f"{__name__}.ReduceLROnPlateau",
"init_args": {
"factor": 0.5,
},
},
],
},
},
}

cfg = parser.parse_args([f"--cfg={config}", "--model.schedulers.monitor=val/mAP50"])
assert cfg.model.init_args.schedulers[1].class_path == f"{__name__}.ReduceLROnPlateau"
assert cfg.model.init_args.schedulers[1].init_args == Namespace(monitor="val/mAP50", factor=0.5)


# lazy_instance tests


Expand Down
Loading