From 1058347da4bb6e6e98052b6e44c9e8fa8f76372d Mon Sep 17 00:00:00 2001 From: Mauricio Villegas <5780272+mauvilsa@users.noreply.github.com> Date: Thu, 24 Oct 2024 21:50:39 +0200 Subject: [PATCH 1/4] Fix missing use of prev vals for callables returning class and normalize class default as spec. --- CHANGELOG.rst | 8 ++++++++ jsonargparse/_typehints.py | 5 ++++- jsonargparse_tests/test_typehints.py | 1 + 3 files changed, 13 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 9d94c570..aa2738a1 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -27,6 +27,14 @@ Fixed - Required and optional ``TypedDict`` keys are now correctly inferred when one inherits one ``TypedDict`` from another with different totality (`#597 `__). +- Callables that return class not considering previous values (`#??? + `__). + +Changed +^^^^^^^ +- Callables that return class with class default now normalizes the default to + a subclass spec with ``class_path`` (`#??? + `__). v4.33.2 (2024-10-07) diff --git a/jsonargparse/_typehints.py b/jsonargparse/_typehints.py index 14f095cb..9d19226d 100644 --- a/jsonargparse/_typehints.py +++ b/jsonargparse/_typehints.py @@ -274,6 +274,8 @@ def normalize_default(self, default): default_type = type(default) if not is_subclass(default_type, UnknownDefault) and self.is_subclass_typehint(default_type): raise ValueError("Subclass types require as default either a dict with class_path or a lazy instance.") + elif ActionTypeHint.is_return_subclass_typehint(self._typehint) and inspect.isclass(default): + default = {"class_path": get_import_path(default)} return default @staticmethod @@ -407,7 +409,7 @@ def parse_argv_item(arg_string): @staticmethod def discard_init_args_on_class_path_change(parser_or_action, prev_cfg, cfg): - if isinstance(prev_cfg, dict): + if isinstance(prev_cfg, dict) or cfg is None: return keys = list(prev_cfg.keys(branches=True)) num = 0 @@ -1021,6 +1023,7 @@ def adapt_typehints( sub_add_kwargs, skip_args=num_partial_args, partial_classes=partial_classes, + prev_val=prev_val, ) except (ImportError, AttributeError, ArgumentError) as ex: raise_unexpected_value(f"Type {typehint} expects a function or a callable class: {ex}", val, ex) diff --git a/jsonargparse_tests/test_typehints.py b/jsonargparse_tests/test_typehints.py index b5dfb063..825e77be 100644 --- a/jsonargparse_tests/test_typehints.py +++ b/jsonargparse_tests/test_typehints.py @@ -940,6 +940,7 @@ def test_callable_args_return_type_class(parser, subtests): with subtests.test("default"): cfg = parser.get_defaults() + assert cfg.optimizer.class_path == f"{__name__}.SGD" init = parser.instantiate_classes(cfg) optimizer = init.optimizer([0.1, 2, 3]) assert isinstance(optimizer, SGD) From 2d15ea632680f4fa7ec9f31120bf9966658eef05 Mon Sep 17 00:00:00 2001 From: Mauricio Villegas <5780272+mauvilsa@users.noreply.github.com> Date: Thu, 24 Oct 2024 22:32:35 +0200 Subject: [PATCH 2/4] Update changelog --- CHANGELOG.rst | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index aa2738a1..c5db0c5f 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -27,14 +27,14 @@ Fixed - Required and optional ``TypedDict`` keys are now correctly inferred when one inherits one ``TypedDict`` from another with different totality (`#597 `__). -- Callables that return class not considering previous values (`#??? - `__). +- Callables that return class not considering previous values (`#603 + `__). Changed ^^^^^^^ - Callables that return class with class default now normalizes the default to - a subclass spec with ``class_path`` (`#??? - `__). + a subclass spec with ``class_path`` (`#603 + `__). v4.33.2 (2024-10-07) From 4f295df94101418e6efc728f069769cb6e085d08 Mon Sep 17 00:00:00 2001 From: Mauricio Villegas <5780272+mauvilsa@users.noreply.github.com> Date: Fri, 25 Oct 2024 11:25:10 +0200 Subject: [PATCH 3/4] Remove no longer used code and add unit test. --- jsonargparse/_core.py | 1 - jsonargparse/_typehints.py | 17 ------------- jsonargparse_tests/test_typehints.py | 36 ++++++++++++++++++++++++++++ 3 files changed, 36 insertions(+), 18 deletions(-) diff --git a/jsonargparse/_core.py b/jsonargparse/_core.py index 063b623f..ce89003e 100644 --- a/jsonargparse/_core.py +++ b/jsonargparse/_core.py @@ -1336,7 +1336,6 @@ def merge_config(self, cfg_from: Namespace, cfg_to: Namespace) -> Namespace: cfg_to = cfg_to.clone() with parser_context(parent_parser=self): ActionTypeHint.discard_init_args_on_class_path_change(self, cfg_to, cfg_from) - ActionTypeHint.delete_init_args_required_none(cfg_from, cfg_to) ActionTypeHint.delete_not_required_args(cfg_from, cfg_to) cfg_to.update(cfg_from) ActionTypeHint.apply_appends(self, cfg_to) diff --git a/jsonargparse/_typehints.py b/jsonargparse/_typehints.py index 9d19226d..e645ffc8 100644 --- a/jsonargparse/_typehints.py +++ b/jsonargparse/_typehints.py @@ -432,21 +432,6 @@ def discard_init_args_on_class_path_change(parser_or_action, prev_cfg, cfg): keys = keys[: num + 1] + [k for k in keys[num + 1 :] if not k.startswith(key + ".")] num += 1 - @staticmethod - def delete_init_args_required_none(cfg_from, cfg_to): - for key, val in cfg_from.items(branches=True): - if isinstance(val, Namespace) and val.get("class_path") and val.get("init_args"): - skip_keys = [ - k - for k, v in val.init_args.__dict__.items() - if v is None and cfg_to.get(f"{key}.init_args.{k}") is not None - ] - if skip_keys: - parser = ActionTypeHint.get_class_parser(val.class_path) - for skip_key in skip_keys: - if skip_key in parser.required_args: - del val.init_args[skip_key] - @staticmethod def delete_not_required_args(cfg_from, cfg_to): for key, val in list(cfg_to.items(branches=True)): @@ -1171,8 +1156,6 @@ def subclass_spec_as_namespace(val, prev_val=None): val = Namespace({root_key: val}) if isinstance(prev_val, str): prev_val = Namespace(class_path=prev_val) - elif inspect.isclass(prev_val): - prev_val = Namespace(class_path=get_import_path(prev_val)) if isinstance(val, dict): val = Namespace(val) if "init_args" in val and isinstance(val["init_args"], dict): diff --git a/jsonargparse_tests/test_typehints.py b/jsonargparse_tests/test_typehints.py index 825e77be..51354d1d 100644 --- a/jsonargparse_tests/test_typehints.py +++ b/jsonargparse_tests/test_typehints.py @@ -1227,6 +1227,42 @@ def test_callable_return_class_required_arg_from_default(parser): assert cfg.model.init_args.scheduler.init_args == Namespace(monitor="acc", factor=0.5) +class ModelListCallableReturnClass: + def __init__( + self, + schedulers: List[Callable[[Optimizer], Union[StepLR, ReduceLROnPlateau]]] = [], + ): + self.schedulers = schedulers + + +def test_list_callable_return_class(parser): + parser.add_argument("--cfg", action="config") + parser.add_argument("--model", type=ModelListCallableReturnClass) + + config = { + "model": { + "class_path": f"{__name__}.ModelListCallableReturnClass", + "init_args": { + "schedulers": [ + { + "class_path": f"{__name__}.StepLR", + }, + { + "class_path": f"{__name__}.ReduceLROnPlateau", + "init_args": { + "factor": 0.5, + }, + }, + ], + }, + }, + } + + cfg = parser.parse_args([f"--cfg={config}", "--model.schedulers.monitor=val/mAP50"]) + assert cfg.model.init_args.schedulers[1].class_path == f"{__name__}.ReduceLROnPlateau" + assert cfg.model.init_args.schedulers[1].init_args == Namespace(monitor="val/mAP50", factor=0.5) + + # lazy_instance tests From c05edc97299f504c2c50ff218ae9461bc94422c7 Mon Sep 17 00:00:00 2001 From: Mauricio Villegas <5780272+mauvilsa@users.noreply.github.com> Date: Fri, 25 Oct 2024 22:26:57 +0200 Subject: [PATCH 4/4] Revert unnecessary change --- jsonargparse/_typehints.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jsonargparse/_typehints.py b/jsonargparse/_typehints.py index e645ffc8..effbeae3 100644 --- a/jsonargparse/_typehints.py +++ b/jsonargparse/_typehints.py @@ -409,7 +409,7 @@ def parse_argv_item(arg_string): @staticmethod def discard_init_args_on_class_path_change(parser_or_action, prev_cfg, cfg): - if isinstance(prev_cfg, dict) or cfg is None: + if isinstance(prev_cfg, dict): return keys = list(prev_cfg.keys(branches=True)) num = 0