Skip to content

Add support for PEP 692: **kwargs: typing.Unpack[TypedDict] #579

@a-gardner1

Description

@a-gardner1

🚀 Feature request

Allow using a TypedDict for more precise **kwargs typing, as described in PEP 692.

Motivation

I want to get static typing guarantees from mypy for classes or functions whose **kwargs are annotated with a TypedDict, and still use those classes in configurations parsed by jsonargparse.

Right now, I have two options. I can remove the annotation from **kwargs and hope that jsonargparse can inspect the code and infer the types using its heuristics. Or I can expand **kwargs into explicit keyword parameters, duplicating the keys that the TypedDict already declares.

Pitch

I want the following script to work without errors. It requires Python 3.11+, because it imports Unpack, Required and NotRequired from typing.

import tempfile
from dataclasses import dataclass
from pathlib import Path
from typing import Any, NotRequired, Required, TypeVar, TypedDict, Unpack

import jsonargparse
import yaml
from jsonargparse import ActionConfigFile, ArgumentParser, lazy_instance

if __name__ == '__main__':
    # TypedDict describing the keyword arguments accepted by InnerTestClass.
    # The string under "a" is an attribute docstring that jsonargparse would
    # ideally surface as the help text of the generated "a" option.
    class TestDict(TypedDict):
        a: Required[int]
        """
        Test documentation.
        """
        b: NotRequired[int]

    class InnerTestClass:
        # PEP 692: Unpack[TestDict] spells out exactly which keyword arguments
        # __init__ accepts, so a type checker such as mypy can verify callers.
        def __init__(self, **kwargs: Unpack[TestDict]) -> None:
            self.a = kwargs['a']
            # "b" is NotRequired, so it may be absent from kwargs.
            self.b = kwargs.get('b')

    @dataclass
    class TestClass:
        test: InnerTestClass

    parser = ArgumentParser(exit_on_error=False)
    parser.add_argument(
        "-c",
        "--config",
        action=ActionConfigFile,
        help="Path to a configuration file in json or yaml format.")
    parser.add_class_arguments(
        TestClass,
        "test",
        fail_untyped=False,
        instantiate=True,
        sub_configs=True,
        default=lazy_instance(TestClass),
    )

    # Only "a" is provided; "b" is intentionally omitted because, as a
    # NotRequired key, it should not be needed.
    config = yaml.safe_dump({
        "test": {
            "test": {
                "class_path": f"{__name__}.InnerTestClass",
                "init_args": {"a": 2},
            },
        },
    })

    # Write the config into a temporary directory instead of handing the
    # parser the name of a still-open NamedTemporaryFile: reopening an open
    # NamedTemporaryFile by name is not supported on Windows. The encoding is
    # explicit so the result does not depend on the platform default.
    with tempfile.TemporaryDirectory() as tmp_dir:
        config_path = Path(tmp_dir) / "config.yaml"
        config_path.write_text(config, encoding="utf-8")
        cfg = parser.parse_args(["--config", str(config_path)])

    print(parser.dump(cfg, skip_link_targets=False, skip_none=False))

The script should print the following:

test:
  test:
    class_path: __main__.InnerTestClass
    init_args:
      a: 2

Partial Solution

The following diff, against commit 7874273, provides a partial solution.
It does not account for Python versions earlier than 3.11, where Unpack is only available from typing_extensions. It also probably violates some conventions or expectations of the existing codebase.
It may have unintended side effects as well.

Diff for Partial Solution
diff --git a/jsonargparse/_common.py b/jsonargparse/_common.py
index 8c12b31..eda0717 100644
--- a/jsonargparse/_common.py
+++ b/jsonargparse/_common.py
@@ -16,6 +16,7 @@ from typing import (  # type: ignore[attr-defined]
     TypeVar,
     Union,
     _GenericAlias,
+    _UnpackGenericAlias,
 )
 
 from ._namespace import Namespace
@@ -102,6 +103,10 @@ def is_generic_class(cls) -> bool:
     return isinstance(cls, _GenericAlias) and getattr(cls, "__module__", "") != "typing"
 
 
+def is_unpack_typehint(cls) -> bool:
+    return isinstance(cls, _UnpackGenericAlias)
+
+
 def get_generic_origin(cls):
     return cls.__origin__ if is_generic_class(cls) else cls
 
diff --git a/jsonargparse/_core.py b/jsonargparse/_core.py
index 9ec653b..a216c3d 100644
--- a/jsonargparse/_core.py
+++ b/jsonargparse/_core.py
@@ -1317,7 +1317,10 @@ class ArgumentParser(ParserDeprecations, ActionsContainer, ArgumentLinking, argp
                 keys.append(action_dest)
             elif getattr(action, "jsonnet_ext_vars", False):
                 prev_cfg[action_dest] = value
-            cfg[action_dest] = value
+            if value == inspect._empty:
+                cfg.pop(action_dest, None)
+            else:
+                cfg[action_dest] = value
         return cfg[parent_key] if parent_key else cfg
 
     def merge_config(self, cfg_from: Namespace, cfg_to: Namespace) -> Namespace:
@@ -1335,6 +1338,7 @@ class ArgumentParser(ParserDeprecations, ActionsContainer, ArgumentLinking, argp
         with parser_context(parent_parser=self):
             ActionTypeHint.discard_init_args_on_class_path_change(self, cfg_to, cfg_from)
         ActionTypeHint.delete_init_args_required_none(cfg_from, cfg_to)
+        ActionTypeHint.delete_not_required_args(cfg_from, cfg_to)
         cfg_to.update(cfg_from)
         ActionTypeHint.apply_appends(self, cfg_to)
         return cfg_to
diff --git a/jsonargparse/_parameter_resolvers.py b/jsonargparse/_parameter_resolvers.py
index 8279fc7..2df199b 100644
--- a/jsonargparse/_parameter_resolvers.py
+++ b/jsonargparse/_parameter_resolvers.py
@@ -20,6 +20,7 @@ from ._common import (
     is_dataclass_like,
     is_generic_class,
     is_subclass,
+    is_unpack_typehint,
     parse_logger,
 )
 from ._optionals import get_annotated_base_type, is_annotated, is_pydantic_model, parse_docs
@@ -28,6 +29,7 @@ from ._stubs_resolver import get_stub_types
 from ._util import (
     ClassFromFunctionBase,
     get_import_path,
+    get_typehint_args,
     get_typehint_origin,
     iter_to_set_str,
     unique,
@@ -328,6 +330,38 @@ def replace_generic_type_vars(params: ParamList, parent) -> None:
             param.annotation = replace_type_vars(param.annotation)
 
 
+def unpack_typed_dict_kwargs(params: ParamList) -> bool:
+    kwargs_idx = get_arg_kind_index(params, kinds.VAR_KEYWORD)
+    if kwargs_idx >= 0:
+        kwargs = params.pop(kwargs_idx)
+        annotation = kwargs.annotation
+        if is_unpack_typehint(annotation):
+            annotation_args = get_typehint_args(annotation)
+            assert len(annotation_args) == 1, "Unpack requires a single type argument"
+            dict_annotations = annotation_args[0].__annotations__
+            new_params = []
+            for nm, annot in dict_annotations.items():
+                new_params.append(ParamData(
+                    name=nm,
+                    annotation=annot,
+                    default=inspect._empty,
+                    kind=inspect._ParameterKind.KEYWORD_ONLY,
+                    doc=None,
+                    component=kwargs.component,
+                    parent=kwargs.parent,
+                    origin=kwargs.origin
+                ))
+            # insert in-place
+            trailing_params = []  # expected to be empty
+            for _ in range(kwargs_idx, len(params)):
+                trailing_params.append(params.pop(kwargs_idx))
+            params.extend(new_params)
+            params.extend(trailing_params)
+            return True
+    return False
+
+
+
 def add_stub_types(stubs: Optional[Dict[str, Any]], params: ParamList, component) -> None:
     if not stubs:
         return
@@ -848,12 +882,16 @@ class ParametersVisitor(LoggerProperty, ast.NodeVisitor):
             self.component, self.parent, self.logger
         )
         self.replace_param_default_subclass_specs(params)
+        if unpack_typed_dict_kwargs(params):
+            kwargs_idx = -1
         if args_idx >= 0 or kwargs_idx >= 0:
             self.doc_params = doc_params
             with mro_context(self.parent):
                 args, kwargs = self.get_parameters_args_and_kwargs()
             params = replace_args_and_kwargs(params, args, kwargs)
         add_stub_types(stubs, params, self.component)
+        # in case a typed-dict kwarg typehint is inherited
+        unpack_typed_dict_kwargs(params)
         params = self.remove_ignore_parameters(params)
         return params
 
@@ -865,6 +903,8 @@ def get_parameters_by_assumptions(
 ) -> ParamList:
     component, parent, method_name = get_component_and_parent(function_or_class, method_name)
     params, args_idx, kwargs_idx, _, stubs = get_signature_parameters_and_indexes(component, parent, logger)
+    if unpack_typed_dict_kwargs(params):
+        kwargs_idx = -1
 
     if parent and (args_idx >= 0 or kwargs_idx >= 0):
         with mro_context(parent):
@@ -875,6 +915,8 @@ def get_parameters_by_assumptions(
 
     params = replace_args_and_kwargs(params, [], [])
     add_stub_types(stubs, params, component)
+    # in case a typed-dict kwarg typehint is inherited
+    unpack_typed_dict_kwargs(params)
     return params
 
 
diff --git a/jsonargparse/_signatures.py b/jsonargparse/_signatures.py
index 807a8d4..7d75e19 100644
--- a/jsonargparse/_signatures.py
+++ b/jsonargparse/_signatures.py
@@ -29,8 +29,9 @@ from ._typehints import (
     callable_instances,
     get_subclass_names,
     is_optional,
+    not_required_types,
 )
-from ._util import NoneType, get_private_kwargs, iter_to_set_str
+from ._util import NoneType, get_private_kwargs, get_typehint_origin, iter_to_set_str
 from .typing import register_pydantic_type
 
 __all__ = [
@@ -322,7 +323,7 @@ class SignatureArguments(LoggerProperty):
             default = param.default
             if default == inspect_empty and is_optional(annotation):
                 default = None
-        is_required = default == inspect_empty
+        is_required = default == inspect_empty and get_typehint_origin(annotation) not in not_required_types
         src = get_parameter_origins(param.component, param.parent)
         skip_message = f'Skipping parameter "{name}" from "{src}" because of: '
         if not fail_untyped and annotation == inspect_empty:
diff --git a/jsonargparse/_typehints.py b/jsonargparse/_typehints.py
index 50a119b..e67c14f 100644
--- a/jsonargparse/_typehints.py
+++ b/jsonargparse/_typehints.py
@@ -439,6 +439,13 @@ class ActionTypeHint(Action):
                         if skip_key in parser.required_args:
                             del val.init_args[skip_key]
 
+    @staticmethod
+    def delete_not_required_args(cfg_from, cfg_to):
+        for key, val in list(cfg_to.items(branches=True)):
+            if val == inspect._empty and key not in cfg_from:
+                del cfg_to[key]
+
+
     @staticmethod
     @contextmanager
     def subclass_arg_context(parser):
@@ -587,6 +594,8 @@ class ActionTypeHint(Action):
                     assert ex  # needed due to ruff bug that removes " as ex"
                     if orig_val == "-" and isinstance(getattr(ex, "parent", None), PathError):
                         raise ex
+                    if get_typehint_origin(self._typehint) in not_required_types and val == inspect._empty:
+                        ex = None
                     try:
                         if isinstance(orig_val, str):
                             with change_to_path_dir(config_path):
@@ -943,6 +952,7 @@ def adapt_typehints(
     # TypedDict NotRequired and Required
     elif typehint_origin in not_required_required_types:
         assert len(subtypehints) == 1, "(Not)Required requires a single type argument"
         val = adapt_typehints(val, subtypehints[0], **adapt_kwargs)
 
     # Callable
diff --git a/jsonargparse/_util.py b/jsonargparse/_util.py
index e97ea2a..3c3d6c7 100644
--- a/jsonargparse/_util.py
+++ b/jsonargparse/_util.py
@@ -268,6 +268,10 @@ def object_path_serializer(value):
         raise ValueError(f"Only possible to serialize an importable object, given {value}: {ex}") from ex
 
 
+def get_typehint_args(typehint):
+    return getattr(typehint, "__args__", tuple())
+
+
 def get_typehint_origin(typehint):
     if not hasattr(typehint, "__origin__"):
         typehint_class = get_import_path(typehint.__class__)

Metadata

Metadata

Assignees

No one assigned

    Labels

    enhancementNew feature or request

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions