From 3455ed7825aa44cc47bedab66892f66f911b0728 Mon Sep 17 00:00:00 2001 From: jcal-15 <13785195+jcal-15@users.noreply.github.com> Date: Mon, 9 Dec 2024 22:50:04 -0800 Subject: [PATCH] Adding option to support keeping underscores in argument names --- argparse_dataclass.py | 10 ++++++++-- tests/test_argumentparser.py | 14 ++++++++++---- 2 files changed, 18 insertions(+), 6 deletions(-) diff --git a/argparse_dataclass.py b/argparse_dataclass.py index 9f65be8..8ac550c 100644 --- a/argparse_dataclass.py +++ b/argparse_dataclass.py @@ -334,7 +334,7 @@ def _add_dataclass_options( raise TypeError("cls must be a dataclass") for field in fields(options_class): - args = field.metadata.get("args", [f"--{field.name.replace('_', '-')}"]) + args = field.metadata.get("args", [f"--{_get_arg_name(field)}"]) positional = not args[0].startswith("-") kwargs = { "type": field.metadata.get("type", field.type), @@ -448,7 +448,7 @@ def _handle_bool_type(field: Field, args: list, kwargs: dict): if field.default is True: kwargs["action"] = "store_false" if "args" not in field.metadata: - args[0] = f"--no-{field.name.replace('_', '-')}" + args[0] = f"--no-{_get_arg_name(field)}" kwargs["dest"] = field.name elif field.metadata.get("required") is True: kwargs["action"] = BooleanOptionalAction @@ -479,6 +479,12 @@ def _handle_argument_group( group.add_argument(*args, **kwargs) +def _get_arg_name(field: Field): + if field.metadata.get("keep_underscores", False): + return field.name + return field.name.replace("_", "-") + + class ArgumentParser(argparse.ArgumentParser, Generic[OptionsType]): """Command line argument parser that derives its options from a dataclass. diff --git a/tests/test_argumentparser.py b/tests/test_argumentparser.py index 14127ec..70482fe 100644 --- a/tests/test_argumentparser.py +++ b/tests/test_argumentparser.py @@ -142,10 +142,6 @@ class Options: params = ArgumentParser(Options).parse_args(["--name", "john doe"]) self.assertEqual(params.name, "John Doe") - @unittest.skipIf( - sys.version_info[:2] == (3, 6), - "Python 3.6 does not have datetime.fromisoformat()", - ) def test_default_factory(self): @dataclass class Parameters: @@ -302,6 +298,16 @@ class Options: self.assertRaises(ValueError, lambda: ArgumentParser(Options)) + def test_keep_underscores(self): + @dataclass + class Args: + num_of_foo: int = field(metadata={"keep_underscores": True}) + is_fun: bool = field(default=True, metadata={"keep_underscores": True}) + + params = ArgumentParser(Args).parse_args(["--num_of_foo=10", "--no-is_fun"]) + self.assertEqual(10, params.num_of_foo) + self.assertFalse(params.is_fun) + if __name__ == "__main__": unittest.main()