Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ Fixed
<https://github.com/omni-us/jsonargparse/pull/603>`__).
- Custom instantiators not working for nested dependency injection (`#608
<https://github.com/omni-us/jsonargparse/pull/608>`__).
- Failure when resolving forward references from dataclass parameter types
(`#611 <https://github.com/omni-us/jsonargparse/pull/611>`__).

Changed
^^^^^^^
Expand Down
18 changes: 13 additions & 5 deletions jsonargparse/_postponed_annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,8 +259,20 @@ def type_requires_eval(typehint):
return isinstance(typehint, (str, ForwardRef))


def get_types(obj: Any, logger: Optional[logging.Logger] = None) -> dict:
def get_global_vars(obj: Any, logger: Optional[logging.Logger]) -> dict:
global_vars = vars(import_module(obj.__module__))
try:
module_source = inspect.getsource(sys.modules[obj.__module__]) if obj.__module__ in sys.modules else ""
if "TYPE_CHECKING" in module_source:
TypeCheckingVisitor().update_aliases(module_source, obj.__module__, global_vars, logger)
except Exception as ex:
if logger:
logger.debug(f"Failed to update aliases for TYPE_CHECKING blocks in {obj.__module__}", exc_info=ex)
return global_vars


def get_types(obj: Any, logger: Optional[logging.Logger] = None) -> dict:
global_vars = get_global_vars(obj, logger)
try:
types = get_type_hints(obj, global_vars)
except Exception as ex1:
Expand Down Expand Up @@ -288,10 +300,6 @@ def get_types(obj: Any, logger: Optional[logging.Logger] = None) -> dict:
ex = types
types = {}

module_source = inspect.getsource(sys.modules[obj.__module__]) if obj.__module__ in sys.modules else ""
if "TYPE_CHECKING" in module_source:
TypeCheckingVisitor().update_aliases(module_source, obj.__module__, aliases, logger)

if isinstance(node, ast.FunctionDef):
arg_asts = [(a.arg, a.annotation) for a in node.args.args + node.args.kwonlyargs]
else:
Expand Down
20 changes: 16 additions & 4 deletions jsonargparse_tests/test_postponed_annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ def function_type_checking_list(p1: List[Union["TypeCheckingClass1", TypeCheckin
def test_get_types_type_checking_list():
types = get_types(function_type_checking_list)
assert list(types.keys()) == ["p1"]
lst = "typing.List" if sys.version_info < (3, 10) else "list"
lst = "typing.List"
assert str(types["p1"]) == f"{lst}[typing.Union[{__name__}.TypeCheckingClass1, {__name__}.TypeCheckingClass2]]"


Expand All @@ -267,7 +267,7 @@ def function_type_checking_tuple(p1: Tuple[TypeCheckingClass1, "TypeCheckingClas
def test_get_types_type_checking_tuple():
types = get_types(function_type_checking_tuple)
assert list(types.keys()) == ["p1"]
tpl = "typing.Tuple" if sys.version_info < (3, 10) else "tuple"
tpl = "typing.Tuple"
assert str(types["p1"]) == f"{tpl}[{__name__}.TypeCheckingClass1, {__name__}.TypeCheckingClass2]"


Expand All @@ -278,7 +278,7 @@ def function_type_checking_type(p1: Type["TypeCheckingClass2"]):
def test_get_types_type_checking_type():
types = get_types(function_type_checking_type)
assert list(types.keys()) == ["p1"]
tpl = "typing.Type" if sys.version_info < (3, 10) else "type"
tpl = "typing.Type"
assert str(types["p1"]) == f"{tpl}[{__name__}.TypeCheckingClass2]"


Expand All @@ -289,7 +289,7 @@ def function_type_checking_dict(p1: Dict[str, Union[TypeCheckingClass1, "TypeChe
def test_get_types_type_checking_dict():
types = get_types(function_type_checking_dict)
assert list(types.keys()) == ["p1"]
dct = "typing.Dict" if sys.version_info < (3, 10) else "dict"
dct = "typing.Dict"
assert str(types["p1"]) == f"{dct}[str, typing.Union[{__name__}.TypeCheckingClass1, {__name__}.TypeCheckingClass2]]"


Expand All @@ -305,6 +305,18 @@ def test_get_types_type_checking_undefined_forward_ref(logger):
assert "NameError: Name 'Undefined' is not defined" in logs.getvalue()


@dataclasses.dataclass
class DataclassForwardRef:
p1: "int"
p2: Optional["xml.dom.Node"] = None


@pytest.mark.skipif(sys.version_info < (3, 9), reason="not working in python 3.8")
def test_get_types_type_checking_dataclass_init_forward_ref():
types = get_types(DataclassForwardRef.__init__)
assert types == {"p1": int, "p2": Optional[xml.dom.Node], "return": type(None)}


def function_source_unavailable(p1: List["TypeCheckingClass1"]):
return p1

Expand Down
Loading