-
Notifications
You must be signed in to change notification settings - Fork 27.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[examples] Generate argparsers from type hints on dataclasses (#3669)
* [examples] Generate argparsers from type hints on dataclasses * [HfArgumentParser] way simpler API * Restore run_language_modeling.py for easier diff * [HfArgumentParser] final tweaks from code review
- Loading branch information
Showing
5 changed files
with
357 additions
and
137 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,113 @@ | ||
import dataclasses | ||
from argparse import ArgumentParser | ||
from enum import Enum | ||
from typing import Any, Iterable, NewType, Tuple, Union | ||
|
||
|
||
DataClass = NewType("DataClass", Any) | ||
DataClassType = NewType("DataClassType", Any) | ||
|
||
|
||
class HfArgumentParser(ArgumentParser): | ||
""" | ||
This subclass of `argparse.ArgumentParser` uses type hints on dataclasses | ||
to generate arguments. | ||
The class is designed to play well with the native argparse. In particular, | ||
you can add more (non-dataclass backed) arguments to the parser after initialization | ||
and you'll get the output back after parsing as an additional namespace. | ||
""" | ||
|
||
dataclass_types: Iterable[DataClassType] | ||
|
||
def __init__(self, dataclass_types: Union[DataClassType, Iterable[DataClassType]], **kwargs): | ||
""" | ||
Args: | ||
dataclass_types: | ||
Dataclass type, or list of dataclass types for which we will "fill" instances | ||
with the parsed args. | ||
kwargs: | ||
(Optional) Passed to `argparse.ArgumentParser()` in the regular way. | ||
""" | ||
super().__init__(**kwargs) | ||
if dataclasses.is_dataclass(dataclass_types): | ||
dataclass_types = [dataclass_types] | ||
self.dataclass_types = dataclass_types | ||
for dtype in self.dataclass_types: | ||
self._add_dataclass_arguments(dtype) | ||
|
||
def _add_dataclass_arguments(self, dtype: DataClassType): | ||
for field in dataclasses.fields(dtype): | ||
field_name = f"--{field.name}" | ||
kwargs = field.metadata.copy() | ||
# field.metadata is not used at all by Data Classes, | ||
# it is provided as a third-party extension mechanism. | ||
if isinstance(field.type, str): | ||
raise ImportError( | ||
"This implementation is not compatible with Postponed Evaluation of Annotations (PEP 563)," | ||
"which can be opted in from Python 3.7 with `from __future__ import annotations`." | ||
"We will add compatibility when Python 3.9 is released." | ||
) | ||
typestring = str(field.type) | ||
for x in (int, float, str): | ||
if typestring == f"typing.Union[{x.__name__}, NoneType]": | ||
field.type = x | ||
if isinstance(field.type, type) and issubclass(field.type, Enum): | ||
kwargs["choices"] = list(field.type) | ||
kwargs["type"] = field.type | ||
if field.default is not dataclasses.MISSING: | ||
kwargs["default"] = field.default | ||
elif field.type is bool: | ||
kwargs["action"] = "store_false" if field.default is True else "store_true" | ||
if field.default is True: | ||
field_name = f"--no-{field.name}" | ||
kwargs["dest"] = field.name | ||
else: | ||
kwargs["type"] = field.type | ||
if field.default is not dataclasses.MISSING: | ||
kwargs["default"] = field.default | ||
else: | ||
kwargs["required"] = True | ||
self.add_argument(field_name, **kwargs) | ||
|
||
def parse_args_into_dataclasses(self, args=None, return_remaining_strings=False) -> Tuple[DataClass, ...]: | ||
""" | ||
Parse command-line args into instances of the specified dataclass types. | ||
This relies on argparse's `ArgumentParser.parse_known_args`. | ||
See the doc at: | ||
docs.python.org/3.7/library/argparse.html#argparse.ArgumentParser.parse_args | ||
Args: | ||
args: | ||
List of strings to parse. The default is taken from sys.argv. | ||
(same as argparse.ArgumentParser) | ||
return_remaining_strings: | ||
If true, also return a list of remaining argument strings. | ||
Returns: | ||
Tuple consisting of: | ||
- the dataclass instances in the same order as they | ||
were passed to the initializer.abspath | ||
- if applicable, an additional namespace for more | ||
(non-dataclass backed) arguments added to the parser | ||
after initialization. | ||
- The potential list of remaining argument strings. | ||
(same as argparse.ArgumentParser.parse_known_args) | ||
""" | ||
namespace, remaining_args = self.parse_known_args(args=args) | ||
outputs = [] | ||
for dtype in self.dataclass_types: | ||
keys = {f.name for f in dataclasses.fields(dtype)} | ||
inputs = {k: v for k, v in vars(namespace).items() if k in keys} | ||
for k in keys: | ||
delattr(namespace, k) | ||
obj = dtype(**inputs) | ||
outputs.append(obj) | ||
if len(namespace.__dict__) > 0: | ||
# additional namespace. | ||
outputs.append(namespace) | ||
if return_remaining_strings: | ||
return (*outputs, remaining_args) | ||
else: | ||
return (*outputs,) |
Oops, something went wrong.