Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions argparse_dataclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,7 @@ def _add_dataclass_options(
raise TypeError("cls must be a dataclass")

for field in fields(options_class):
args = field.metadata.get("args", [f"--{field.name.replace('_', '-')}"])
args = field.metadata.get("args", [f"--{_get_arg_name(field)}"])
positional = not args[0].startswith("-")
kwargs = {
"type": field.metadata.get("type", field.type),
Expand Down Expand Up @@ -448,7 +448,7 @@ def _handle_bool_type(field: Field, args: list, kwargs: dict):
if field.default is True:
kwargs["action"] = "store_false"
if "args" not in field.metadata:
args[0] = f"--no-{field.name.replace('_', '-')}"
args[0] = f"--no-{_get_arg_name(field)}"
kwargs["dest"] = field.name
elif field.metadata.get("required") is True:
kwargs["action"] = BooleanOptionalAction
Expand Down Expand Up @@ -479,6 +479,12 @@ def _handle_argument_group(
group.add_argument(*args, **kwargs)


def _get_arg_name(field: Field):
if field.metadata.get("keep_underscores", False):
return field.name
return field.name.replace("_", "-")


class ArgumentParser(argparse.ArgumentParser, Generic[OptionsType]):
"""Command line argument parser that derives its options from a dataclass.

Expand Down
14 changes: 10 additions & 4 deletions tests/test_argumentparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,10 +142,6 @@ class Options:
params = ArgumentParser(Options).parse_args(["--name", "john doe"])
self.assertEqual(params.name, "John Doe")

@unittest.skipIf(
sys.version_info[:2] == (3, 6),
"Python 3.6 does not have datetime.fromisoformat()",
)
def test_default_factory(self):
@dataclass
class Parameters:
Expand Down Expand Up @@ -302,6 +298,16 @@ class Options:

self.assertRaises(ValueError, lambda: ArgumentParser(Options))

def test_keep_underscores(self):
@dataclass
class Args:
num_of_foo: int = field(metadata={"keep_underscores": True})
is_fun: bool = field(default=True, metadata={"keep_underscores": True})

params = ArgumentParser(Args).parse_args(["--num_of_foo=10", "--no-is_fun"])
self.assertEqual(10, params.num_of_foo)
self.assertFalse(params.is_fun)


if __name__ == "__main__":
unittest.main()
Loading