diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 1c6c82b7..9089ceee 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -21,6 +21,8 @@ Fixed `__). - Linking entire dataclasses on instantiation not working (`#746 `__). +- Introspection of postponed annotations from jax not working (`#749 + `__). v4.40.1 (2025-07-24) @@ -441,9 +443,8 @@ Added - Allow adding config argument with ``action="config"`` avoiding need to import action class (`#512 `__). -- Allow providing a function with return type a class in ``class_path`` - (`lightning#13613 - `__). +- Allow providing a function with return type a class in ``class_path`` (`#513 + `__). - Automatic ``--print_shtab`` option when ``shtab`` is installed, providing completions for many type hints without the need to modify code (`#528 `__). diff --git a/jsonargparse/_postponed_annotations.py b/jsonargparse/_postponed_annotations.py index 48136e5a..646a4821 100644 --- a/jsonargparse/_postponed_annotations.py +++ b/jsonargparse/_postponed_annotations.py @@ -260,7 +260,10 @@ def type_requires_eval(typehint): def get_global_vars(obj: Any, logger: Optional[logging.Logger]) -> dict: - global_vars = vars(import_module(obj.__module__)) + global_vars = obj.__globals__.copy() if hasattr(obj, "__globals__") else {} + for key, value in vars(import_module(obj.__module__)).items(): # needed for pydantic-v1 + if key not in global_vars: + global_vars[key] = value try: module_source = inspect.getsource(sys.modules[obj.__module__]) if obj.__module__ in sys.modules else "" if "TYPE_CHECKING" in module_source: @@ -349,7 +352,7 @@ def evaluate_postponed_annotations(params, component, parent, logger): def get_return_type(component, logger=None): return_type = inspect.signature(component).return_annotation if type_requires_eval(return_type): - global_vars = vars(import_module(component.__module__)) + global_vars = get_global_vars(component, logger) try: return_type = get_type_hints(component, global_vars)["return"] if isinstance(return_type, ForwardRef): diff --git a/jsonargparse_tests/test_postponed_annotations.py b/jsonargparse_tests/test_postponed_annotations.py index c7af175a..4eab551f 100644 --- a/jsonargparse_tests/test_postponed_annotations.py +++ b/jsonargparse_tests/test_postponed_annotations.py @@ -313,6 +313,8 @@ class DataclassForwardRef: @pytest.mark.skipif(sys.version_info < (3, 9), reason="not working in python 3.8") def test_get_types_type_checking_dataclass_init_forward_ref(): + import xml.dom + types = get_types(DataclassForwardRef.__init__) assert types == {"p1": int, "p2": Optional[xml.dom.Node], "return": type(None)}