From 3309ea4b1e5ecc1c7c40f45d342eda8ea9d918da Mon Sep 17 00:00:00 2001 From: Mauricio Villegas <5780272+mauvilsa@users.noreply.github.com> Date: Tue, 21 May 2024 07:29:04 +0200 Subject: [PATCH] Directly providing a dict with parameters or a single parameter to a subclass or callable with class return now implicitly tries using the base class as class_path if not abstract. --- CHANGELOG.rst | 3 +++ DOCUMENTATION.rst | 10 ++++++++++ jsonargparse/_core.py | 5 ++++- jsonargparse/_typehints.py | 8 ++++++++ jsonargparse_tests/test_link_arguments.py | 8 ++++++-- jsonargparse_tests/test_subclasses.py | 24 ++++++++++++++++------- jsonargparse_tests/test_typehints.py | 12 +++++++++++- 7 files changed, 59 insertions(+), 11 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 6635a646..8e616714 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -19,6 +19,9 @@ Added ^^^^^ - Support for ``TypedDict`` (`#457 `__). +- Directly providing a dict with parameters or a single parameter to a subclass + or callable with class return now implicitly tries using the base class as + ``class_path`` if not abstract. Fixed ^^^^^ diff --git a/DOCUMENTATION.rst b/DOCUMENTATION.rst index f4027450..d11af8d7 100644 --- a/DOCUMENTATION.rst +++ b/DOCUMENTATION.rst @@ -2013,6 +2013,16 @@ been imported before parsing. Abstract classes and private classes (module or name starting with ``'_'``) are not considered. All the subclasses resolvable by its name can be seen in the general help ``python tool.py --help``. +When the base class is not abstract, the ``class_path`` can be omitted, by +giving directly ``init_args``, for example: + +.. code-block:: bash + + python tool.py --calendar.firstweekday 2 + +would implicitly use ``calendar.Calendar`` as the class path. + + Default values -------------- diff --git a/jsonargparse/_core.py b/jsonargparse/_core.py index 703a724f..0e845b48 100644 --- a/jsonargparse/_core.py +++ b/jsonargparse/_core.py @@ -313,7 +313,10 @@ def _parse_common( _ActionPrintConfig.print_config_if_requested(self, cfg) with parser_context(parent_parser=self): - ActionLink.apply_parsing_links(self, cfg) + try: + ActionLink.apply_parsing_links(self, cfg) + except Exception as ex: + self.error(str(ex), ex) if not skip_check and not lenient_check.get(): self.check_config(cfg, skip_required=skip_required) diff --git a/jsonargparse/_typehints.py b/jsonargparse/_typehints.py index c2ad0ba9..f90c44e4 100644 --- a/jsonargparse/_typehints.py +++ b/jsonargparse/_typehints.py @@ -865,6 +865,11 @@ def adapt_typehints( else: raise ImportError(f"Unexpected import object {val_obj}") if isinstance(val, (dict, Namespace, NestedArg)): + if prev_val is None: + return_type = get_callable_return_type(typehint) + if return_type and not inspect.isabstract(return_type): + with suppress(ValueError): + prev_val = Namespace(class_path=get_import_path(return_type)) val = subclass_spec_as_namespace(val, prev_val) if not is_subclass_spec(val): raise ImportError( @@ -920,6 +925,9 @@ def adapt_typehints( return val val_input = val + if prev_val is None and not inspect.isabstract(typehint): + with suppress(ValueError): + prev_val = Namespace(class_path=get_import_path(typehint)) val = subclass_spec_as_namespace(val, prev_val) if not is_subclass_spec(val): raise_unexpected_value( diff --git a/jsonargparse_tests/test_link_arguments.py b/jsonargparse_tests/test_link_arguments.py index b62ef119..54f27445 100644 --- a/jsonargparse_tests/test_link_arguments.py +++ b/jsonargparse_tests/test_link_arguments.py @@ -2,6 +2,7 @@ from calendar import Calendar, TextCalendar from dataclasses import dataclass +from importlib.util import find_spec from typing import Any, Callable, List, Mapping, Optional, Union import pytest @@ -51,7 +52,7 @@ def to_str(value): subcommands.add_subcommand("sub", subparser) with subtests.test("parse_args"): - with pytest.raises(ValueError) as ctx: + with pytest.raises(ArgumentError) as ctx: parser.parse_args(["sub"]) ctx.match("Call to compute_fn of link 'to_str.*failed: value is empty") @@ -111,7 +112,10 @@ def test_on_parse_compute_fn_subclass_spec(parser, subtests): parser.set_defaults(cal1=None) with pytest.raises(ArgumentError) as ctx: parser.parse_args(["--cal1.firstweekday=-"]) - ctx.match('Parser key "cal1"') + if find_spec("typeshed_client"): + ctx.match('Parser key "cal1"') + else: + ctx.match("Call to compute_fn of link") class ClassA: diff --git a/jsonargparse_tests/test_subclasses.py b/jsonargparse_tests/test_subclasses.py index 3f2f6d4e..fb2275bf 100644 --- a/jsonargparse_tests/test_subclasses.py +++ b/jsonargparse_tests/test_subclasses.py @@ -246,13 +246,6 @@ def test_subclass_init_args_without_class_path(parser): assert cfg.cal3.init_args == Namespace(firstweekday=5) -def test_subclass_init_args_without_class_path_error(parser): - parser.add_subclass_arguments(Calendar, "cal1") - with pytest.raises(ArgumentError) as ctx: - parser.parse_args(["--cal1.init_args.firstweekday=4"]) - ctx.match("class path given previously") - - def test_subclass_init_args_without_class_path_dict(parser): parser.add_argument("--cfg", action=ActionConfigFile) parser.add_argument("--cal", type=Calendar) @@ -1500,6 +1493,23 @@ def test_subclass_help_not_subclass(parser): ctx.match("is not a subclass of") +class Implicit: + def __init__(self, a: int = 1, b: str = ""): + pass + + +def test_subclass_implicit_class_path(parser): + parser.add_argument("--implicit", type=Implicit) + cfg = parser.parse_args(['--implicit={"a": 2, "b": "x"}']) + assert cfg.implicit.class_path == f"{__name__}.Implicit" + assert cfg.implicit.init_args == Namespace(a=2, b="x") + cfg = parser.parse_args(["--implicit.a=3"]) + assert cfg.implicit.init_args == Namespace(a=3, b="") + with pytest.raises(ArgumentError) as ctx: + parser.parse_args(['--implicit={"c": null}']) + ctx.match('No action for key "c" to check its value') + + # error messages tests diff --git a/jsonargparse_tests/test_typehints.py b/jsonargparse_tests/test_typehints.py index 5e6ce9cb..e4de2ec7 100644 --- a/jsonargparse_tests/test_typehints.py +++ b/jsonargparse_tests/test_typehints.py @@ -754,6 +754,16 @@ def test_callable_args_return_type_class(parser, subtests): assert f"{__name__}.{name}" in help_str +def test_callable_return_type_class_implicit_class_path(parser): + parser.add_argument("--optimizer", type=Callable[[List[float]], Optimizer]) + cfg = parser.parse_args(['--optimizer={"lr": 0.5}']) + assert cfg.optimizer.class_path == f"{__name__}.Optimizer" + assert cfg.optimizer.init_args == Namespace(lr=0.5, momentum=0.0) + cfg = parser.parse_args(["--optimizer.momentum=0.2"]) + assert cfg.optimizer.class_path == f"{__name__}.Optimizer" + assert cfg.optimizer.init_args == Namespace(lr=0.001, momentum=0.2) + + def test_callable_multiple_args_return_type_class(parser, subtests): parser.add_argument("--optimizer", type=Callable[[List[float], float], Optimizer], default=SGD) @@ -924,7 +934,7 @@ def __init__( self.activation = activation -def test_callable_zero_args_return_type_class(parser): # , subtests): +def test_callable_zero_args_return_type_class(parser): parser.add_class_arguments(Model, "model") cfg = parser.parse_args([]) assert cfg.model.activation == Namespace(