-
-
Notifications
You must be signed in to change notification settings - Fork 59
Description
🚀 Feature request
Allow using `TypedDict` for more precise `**kwargs` typing, as described in PEP 692.
Motivation
I want to enjoy the static typing guarantees provided by mypy for classes or functions with `TypedDict`-annotated `**kwargs`, and also use those classes in configurations parsed by `jsonargparse`.
Right now, I have two options. I can remove the annotation from `**kwargs` and hope that `jsonargparse` can inspect the code and infer the types using its heuristics. Or I can expand `**kwargs` into explicit keyword parameters, duplicating the keys that would otherwise be declared once in the `TypedDict`.
Pitch
I want the following script to work without errors:
# Reproduction for the feature request: TypedDict-annotated ``**kwargs``
# (PEP 692) used by a class that jsonargparse should be able to configure.
# Requires Python 3.11+, because NotRequired, Required and Unpack are imported
# from `typing` (older versions would need them from `typing_extensions`).
import tempfile
from dataclasses import dataclass
from pathlib import Path
from typing import NotRequired, Required, TypedDict, Unpack

import yaml
from jsonargparse import ActionConfigFile, ArgumentParser, lazy_instance


class TestDict(TypedDict):
    a: Required[int]
    """
    Test documentation.
    """
    b: NotRequired[int]


class InnerTestClass:
    def __init__(self, **kwargs: Unpack[TestDict]) -> None:
        self.a = kwargs["a"]
        # `b` is NotRequired in TestDict, so it may be absent from kwargs.
        self.b = kwargs.get("b")


@dataclass
class TestClass:
    test: InnerTestClass


def main() -> None:
    """Parse a YAML config that sets ``InnerTestClass`` init args and dump it."""
    parser = ArgumentParser(exit_on_error=False)
    parser.add_argument(
        "-c",
        "--config",
        action=ActionConfigFile,
        help="Path to a configuration file in json or yaml format.",
    )
    parser.add_class_arguments(
        TestClass,
        "test",
        fail_untyped=False,
        instantiate=True,
        sub_configs=True,
        default=lazy_instance(TestClass),
    )
    config = yaml.safe_dump(
        {
            "test": {
                "test": {
                    # When run as a script __name__ is "__main__", so this
                    # resolves to "__main__.InnerTestClass".
                    "class_path": f"{__name__}.InnerTestClass",
                    "init_args": {
                        "a": 2,
                    },
                }
            }
        }
    )
    # Write the config into a temporary *directory* rather than parsing while a
    # NamedTemporaryFile is still open: on Windows an open NamedTemporaryFile
    # cannot be reopened by name, so the parser would fail to read it.
    with tempfile.TemporaryDirectory() as tmp_dir:
        config_path = Path(tmp_dir) / "config.yaml"
        config_path.write_text(config, encoding="utf-8")
        cfg = parser.parse_args(["--config", str(config_path)])
        print(parser.dump(cfg, skip_link_targets=False, skip_none=False))


if __name__ == "__main__":
    main()
The script should print the following:
test:
  test:
    class_path: __main__.InnerTestClass
    init_args:
      a: 2
Partial Solution
The following diff (relative to commit 7874273) provides a partial solution. It has three known limitations:
- It does not account for earlier Python versions, which do not support `Unpack` without `typing_extensions`.
- It probably violates some conventions or expectations of the existing codebase.
- It may have unintended side effects.
Diff for Partial Solution
diff --git a/jsonargparse/_common.py b/jsonargparse/_common.py
index 8c12b31..eda0717 100644
--- a/jsonargparse/_common.py
+++ b/jsonargparse/_common.py
@@ -16,6 +16,7 @@ from typing import ( # type: ignore[attr-defined]
TypeVar,
Union,
_GenericAlias,
+ _UnpackGenericAlias,
)
from ._namespace import Namespace
@@ -102,6 +103,10 @@ def is_generic_class(cls) -> bool:
return isinstance(cls, _GenericAlias) and getattr(cls, "__module__", "") != "typing"
+def is_unpack_typehint(cls) -> bool:
+ return isinstance(cls, _UnpackGenericAlias)
+
+
def get_generic_origin(cls):
return cls.__origin__ if is_generic_class(cls) else cls
diff --git a/jsonargparse/_core.py b/jsonargparse/_core.py
index 9ec653b..a216c3d 100644
--- a/jsonargparse/_core.py
+++ b/jsonargparse/_core.py
@@ -1317,7 +1317,10 @@ class ArgumentParser(ParserDeprecations, ActionsContainer, ArgumentLinking, argp
keys.append(action_dest)
elif getattr(action, "jsonnet_ext_vars", False):
prev_cfg[action_dest] = value
- cfg[action_dest] = value
+ if value == inspect._empty:
+ cfg.pop(action_dest, None)
+ else:
+ cfg[action_dest] = value
return cfg[parent_key] if parent_key else cfg
def merge_config(self, cfg_from: Namespace, cfg_to: Namespace) -> Namespace:
@@ -1335,6 +1338,7 @@ class ArgumentParser(ParserDeprecations, ActionsContainer, ArgumentLinking, argp
with parser_context(parent_parser=self):
ActionTypeHint.discard_init_args_on_class_path_change(self, cfg_to, cfg_from)
ActionTypeHint.delete_init_args_required_none(cfg_from, cfg_to)
+ ActionTypeHint.delete_not_required_args(cfg_from, cfg_to)
cfg_to.update(cfg_from)
ActionTypeHint.apply_appends(self, cfg_to)
return cfg_to
diff --git a/jsonargparse/_parameter_resolvers.py b/jsonargparse/_parameter_resolvers.py
index 8279fc7..2df199b 100644
--- a/jsonargparse/_parameter_resolvers.py
+++ b/jsonargparse/_parameter_resolvers.py
@@ -20,6 +20,7 @@ from ._common import (
is_dataclass_like,
is_generic_class,
is_subclass,
+ is_unpack_typehint,
parse_logger,
)
from ._optionals import get_annotated_base_type, is_annotated, is_pydantic_model, parse_docs
@@ -28,6 +29,7 @@ from ._stubs_resolver import get_stub_types
from ._util import (
ClassFromFunctionBase,
get_import_path,
+ get_typehint_args,
get_typehint_origin,
iter_to_set_str,
unique,
@@ -328,6 +330,38 @@ def replace_generic_type_vars(params: ParamList, parent) -> None:
param.annotation = replace_type_vars(param.annotation)
+def unpack_typed_dict_kwargs(params: ParamList) -> bool:
+ kwargs_idx = get_arg_kind_index(params, kinds.VAR_KEYWORD)
+ if kwargs_idx >= 0:
+ kwargs = params.pop(kwargs_idx)
+ annotation = kwargs.annotation
+ if is_unpack_typehint(annotation):
+ annotation_args = get_typehint_args(annotation)
+ assert len(annotation_args) == 1, "Unpack requires a single type argument"
+ dict_annotations = annotation_args[0].__annotations__
+ new_params = []
+ for nm, annot in dict_annotations.items():
+ new_params.append(ParamData(
+ name=nm,
+ annotation=annot,
+ default=inspect._empty,
+ kind=inspect._ParameterKind.KEYWORD_ONLY,
+ doc=None,
+ component=kwargs.component,
+ parent=kwargs.parent,
+ origin=kwargs.origin
+ ))
+ # insert in-place
+ trailing_params = [] # expected to be empty
+ for _ in range(kwargs_idx, len(params)):
+ trailing_params.append(params.pop(kwargs_idx))
+ params.extend(new_params)
+ params.extend(trailing_params)
+ return True
+ return False
+
+
+
def add_stub_types(stubs: Optional[Dict[str, Any]], params: ParamList, component) -> None:
if not stubs:
return
@@ -848,12 +882,16 @@ class ParametersVisitor(LoggerProperty, ast.NodeVisitor):
self.component, self.parent, self.logger
)
self.replace_param_default_subclass_specs(params)
+ if unpack_typed_dict_kwargs(params):
+ kwargs_idx = -1
if args_idx >= 0 or kwargs_idx >= 0:
self.doc_params = doc_params
with mro_context(self.parent):
args, kwargs = self.get_parameters_args_and_kwargs()
params = replace_args_and_kwargs(params, args, kwargs)
add_stub_types(stubs, params, self.component)
+ # in case a typed-dict kwarg typehint is inherited
+ unpack_typed_dict_kwargs(params)
params = self.remove_ignore_parameters(params)
return params
@@ -865,6 +903,8 @@ def get_parameters_by_assumptions(
) -> ParamList:
component, parent, method_name = get_component_and_parent(function_or_class, method_name)
params, args_idx, kwargs_idx, _, stubs = get_signature_parameters_and_indexes(component, parent, logger)
+ if unpack_typed_dict_kwargs(params):
+ kwargs_idx = -1
if parent and (args_idx >= 0 or kwargs_idx >= 0):
with mro_context(parent):
@@ -875,6 +915,8 @@ def get_parameters_by_assumptions(
params = replace_args_and_kwargs(params, [], [])
add_stub_types(stubs, params, component)
+ # in case a typed-dict kwarg typehint is inherited
+ unpack_typed_dict_kwargs(params)
return params
diff --git a/jsonargparse/_signatures.py b/jsonargparse/_signatures.py
index 807a8d4..7d75e19 100644
--- a/jsonargparse/_signatures.py
+++ b/jsonargparse/_signatures.py
@@ -29,8 +29,9 @@ from ._typehints import (
callable_instances,
get_subclass_names,
is_optional,
+ not_required_types,
)
-from ._util import NoneType, get_private_kwargs, iter_to_set_str
+from ._util import NoneType, get_private_kwargs, get_typehint_origin, iter_to_set_str
from .typing import register_pydantic_type
__all__ = [
@@ -322,7 +323,7 @@ class SignatureArguments(LoggerProperty):
default = param.default
if default == inspect_empty and is_optional(annotation):
default = None
- is_required = default == inspect_empty
+ is_required = default == inspect_empty and get_typehint_origin(annotation) not in not_required_types
src = get_parameter_origins(param.component, param.parent)
skip_message = f'Skipping parameter "{name}" from "{src}" because of: '
if not fail_untyped and annotation == inspect_empty:
diff --git a/jsonargparse/_typehints.py b/jsonargparse/_typehints.py
index 50a119b..e67c14f 100644
--- a/jsonargparse/_typehints.py
+++ b/jsonargparse/_typehints.py
@@ -439,6 +439,13 @@ class ActionTypeHint(Action):
if skip_key in parser.required_args:
del val.init_args[skip_key]
+ @staticmethod
+ def delete_not_required_args(cfg_from, cfg_to):
+ for key, val in list(cfg_to.items(branches=True)):
+ if val == inspect._empty and key not in cfg_from:
+ del cfg_to[key]
+
+
@staticmethod
@contextmanager
def subclass_arg_context(parser):
@@ -587,6 +594,8 @@ class ActionTypeHint(Action):
assert ex # needed due to ruff bug that removes " as ex"
if orig_val == "-" and isinstance(getattr(ex, "parent", None), PathError):
raise ex
+ if get_typehint_origin(self._typehint) in not_required_types and val == inspect._empty:
+ ex = None
try:
if isinstance(orig_val, str):
with change_to_path_dir(config_path):
@@ -943,6 +952,7 @@ def adapt_typehints(
# TypedDict NotRequired and Required
elif typehint_origin in not_required_required_types:
assert len(subtypehints) == 1, "(Not)Required requires a single type argument"
val = adapt_typehints(val, subtypehints[0], **adapt_kwargs)
# Callable
diff --git a/jsonargparse/_util.py b/jsonargparse/_util.py
index e97ea2a..3c3d6c7 100644
--- a/jsonargparse/_util.py
+++ b/jsonargparse/_util.py
@@ -268,6 +268,10 @@ def object_path_serializer(value):
raise ValueError(f"Only possible to serialize an importable object, given {value}: {ex}") from ex
+def get_typehint_args(typehint):
+ return getattr(typehint, "__args__", tuple())
+
+
def get_typehint_origin(typehint):
if not hasattr(typehint, "__origin__"):
typehint_class = get_import_path(typehint.__class__)