From dcb2ce58d244aa187abfa83d794c577a771a1425 Mon Sep 17 00:00:00 2001 From: Mauricio Villegas <5780272+mauvilsa@users.noreply.github.com> Date: Thu, 31 Oct 2024 07:12:02 +0100 Subject: [PATCH 1/2] Fix failures when resolving forward references from dataclass parameter types (#609). --- CHANGELOG.rst | 2 ++ jsonargparse/_postponed_annotations.py | 16 ++++++++++----- .../test_postponed_annotations.py | 20 +++++++++++++++---- 3 files changed, 29 insertions(+), 9 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 489125ef..29769c74 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -31,6 +31,8 @@ Fixed `__). - Custom instantiators not working for nested dependency injection (`#608 `__). +- Failure when resolving forward references from dataclass parameter types + (`#??? `__). Changed ^^^^^^^ diff --git a/jsonargparse/_postponed_annotations.py b/jsonargparse/_postponed_annotations.py index 5dbdffd4..bd4df689 100644 --- a/jsonargparse/_postponed_annotations.py +++ b/jsonargparse/_postponed_annotations.py @@ -4,6 +4,7 @@ import sys import textwrap from collections import namedtuple +from contextlib import suppress from copy import deepcopy from dataclasses import is_dataclass from importlib import import_module @@ -259,8 +260,17 @@ def type_requires_eval(typehint): return isinstance(typehint, (str, ForwardRef)) -def get_types(obj: Any, logger: Optional[logging.Logger] = None) -> dict: +def get_global_vars(obj: Any, logger: Optional[logging.Logger]) -> dict: global_vars = vars(import_module(obj.__module__)) + with suppress(Exception): + module_source = inspect.getsource(sys.modules[obj.__module__]) if obj.__module__ in sys.modules else "" + if "TYPE_CHECKING" in module_source: + TypeCheckingVisitor().update_aliases(module_source, obj.__module__, global_vars, logger) + return global_vars + + +def get_types(obj: Any, logger: Optional[logging.Logger] = None) -> dict: + global_vars = get_global_vars(obj, logger) try: types = get_type_hints(obj, global_vars) except Exception as ex1: @@ -288,10 +298,6 @@ def get_types(obj: Any, logger: Optional[logging.Logger] = None) -> dict: ex = types types = {} - module_source = inspect.getsource(sys.modules[obj.__module__]) if obj.__module__ in sys.modules else "" - if "TYPE_CHECKING" in module_source: - TypeCheckingVisitor().update_aliases(module_source, obj.__module__, aliases, logger) - if isinstance(node, ast.FunctionDef): arg_asts = [(a.arg, a.annotation) for a in node.args.args + node.args.kwonlyargs] else: diff --git a/jsonargparse_tests/test_postponed_annotations.py b/jsonargparse_tests/test_postponed_annotations.py index 05d7d907..afbdccb4 100644 --- a/jsonargparse_tests/test_postponed_annotations.py +++ b/jsonargparse_tests/test_postponed_annotations.py @@ -256,7 +256,7 @@ def function_type_checking_list(p1: List[Union["TypeCheckingClass1", TypeCheckin def test_get_types_type_checking_list(): types = get_types(function_type_checking_list) assert list(types.keys()) == ["p1"] - lst = "typing.List" if sys.version_info < (3, 10) else "list" + lst = "typing.List" assert str(types["p1"]) == f"{lst}[typing.Union[{__name__}.TypeCheckingClass1, {__name__}.TypeCheckingClass2]]" @@ -267,7 +267,7 @@ def function_type_checking_tuple(p1: Tuple[TypeCheckingClass1, "TypeCheckingClas def test_get_types_type_checking_tuple(): types = get_types(function_type_checking_tuple) assert list(types.keys()) == ["p1"] - tpl = "typing.Tuple" if sys.version_info < (3, 10) else "tuple" + tpl = "typing.Tuple" assert str(types["p1"]) == f"{tpl}[{__name__}.TypeCheckingClass1, {__name__}.TypeCheckingClass2]" @@ -278,7 +278,7 @@ def function_type_checking_type(p1: Type["TypeCheckingClass2"]): def test_get_types_type_checking_type(): types = get_types(function_type_checking_type) assert list(types.keys()) == ["p1"] - tpl = "typing.Type" if sys.version_info < (3, 10) else "type" + tpl = "typing.Type" assert str(types["p1"]) == f"{tpl}[{__name__}.TypeCheckingClass2]" @@ -289,7 +289,7 @@ def function_type_checking_dict(p1: Dict[str, Union[TypeCheckingClass1, "TypeChe def test_get_types_type_checking_dict(): types = get_types(function_type_checking_dict) assert list(types.keys()) == ["p1"] - dct = "typing.Dict" if sys.version_info < (3, 10) else "dict" + dct = "typing.Dict" assert str(types["p1"]) == f"{dct}[str, typing.Union[{__name__}.TypeCheckingClass1, {__name__}.TypeCheckingClass2]]" @@ -305,6 +305,18 @@ def test_get_types_type_checking_undefined_forward_ref(logger): assert "NameError: Name 'Undefined' is not defined" in logs.getvalue() +@dataclasses.dataclass +class DataclassForwardRef: + p1: "int" + p2: Optional["xml.dom.Node"] = None + + +@pytest.mark.skipif(sys.version_info < (3, 9), reason="not working in python 3.8") +def test_get_types_type_checking_dataclass_init_forward_ref(): + types = get_types(DataclassForwardRef.__init__) + assert types == {"p1": int, "p2": Optional[xml.dom.Node], "return": type(None)} + + def function_source_unavailable(p1: List["TypeCheckingClass1"]): return p1 From 25a8ccce82c84338b0b39bac479588a1f789c56e Mon Sep 17 00:00:00 2001 From: Mauricio Villegas <5780272+mauvilsa@users.noreply.github.com> Date: Thu, 31 Oct 2024 07:19:09 +0100 Subject: [PATCH 2/2] Update changelog and add logging --- CHANGELOG.rst | 2 +- jsonargparse/_postponed_annotations.py | 6 ++++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 29769c74..84bb8c6b 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -32,7 +32,7 @@ Fixed - Custom instantiators not working for nested dependency injection (`#608 `__). - Failure when resolving forward references from dataclass parameter types - (`#??? `__). + (`#611 `__). Changed ^^^^^^^ diff --git a/jsonargparse/_postponed_annotations.py b/jsonargparse/_postponed_annotations.py index bd4df689..48136e5a 100644 --- a/jsonargparse/_postponed_annotations.py +++ b/jsonargparse/_postponed_annotations.py @@ -4,7 +4,6 @@ import sys import textwrap from collections import namedtuple -from contextlib import suppress from copy import deepcopy from dataclasses import is_dataclass from importlib import import_module @@ -262,10 +261,13 @@ def type_requires_eval(typehint): def get_global_vars(obj: Any, logger: Optional[logging.Logger]) -> dict: global_vars = vars(import_module(obj.__module__)) - with suppress(Exception): + try: module_source = inspect.getsource(sys.modules[obj.__module__]) if obj.__module__ in sys.modules else "" if "TYPE_CHECKING" in module_source: TypeCheckingVisitor().update_aliases(module_source, obj.__module__, global_vars, logger) + except Exception as ex: + if logger: + logger.debug(f"Failed to update aliases for TYPE_CHECKING blocks in {obj.__module__}", exc_info=ex) return global_vars