Skip to content

Commit

Permalink
Merge pull request #9 from adsharma/master
Browse files Browse the repository at this point in the history
dataclass decorator
  • Loading branch information
mivade committed Mar 14, 2021
2 parents 6e46d35 + 4916002 commit 84d3918
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 5 deletions.
49 changes: 44 additions & 5 deletions argparse_dataclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,26 @@
Examples
--------
Using dataclass decorator
.. code-block:: pycon
from argparse_dataclass import dataclass
@dataclass
class Opt:
x: int = 42
y: bool = False
def main():
params = Opt.parse_args()
print(params)
if __name__ == "__main__":
main()
A simple parser with flags:
Expand Down Expand Up @@ -120,7 +140,7 @@

import argparse
from contextlib import suppress
from dataclasses import is_dataclass, MISSING
from dataclasses import is_dataclass, MISSING, dataclass as real_dataclass
from typing import TypeVar

__version__ = "0.1.0"
Expand Down Expand Up @@ -152,10 +172,7 @@ def _add_dataclass_options(self) -> None:
for name, field in getattr(self._options_type, "__dataclass_fields__").items():
args = field.metadata.get("args", [f"--{name.replace('_', '-')}"])
positional = not args[0].startswith("-")
kwargs = {
"type": field.type,
"help": field.metadata.get("help", None),
}
kwargs = {"type": field.type, "help": field.metadata.get("help", None)}

if field.metadata.get("args") and not positional:
# We want to ensure that we store the argument based on the
Expand Down Expand Up @@ -186,3 +203,25 @@ def parse_args(self, *args, **kwargs) -> OptionsType:
"""Parse arguments and return as the dataclass type."""
namespace = super().parse_args(*args, **kwargs)
return self._options_type(**vars(namespace))


def dataclass(cls=None, /, *, init=True, repr=True, eq=True, order=False,
unsafe_hash=False, frozen=False):
class Inner:
parser = ArgumentParser(
real_dataclass(
cls,
init=int,
repr=repr,
eq=eq,
order=order,
unsafe_hash=unsafe_hash,
frozen=frozen,
)
)

@staticmethod
def parse_args(args=None):
return Inner.parser.parse_args(args)

return Inner
18 changes: 18 additions & 0 deletions tests/example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
#!/usr/bin/env python3

from argparse_dataclass import dataclass


@dataclass
class Opt:
x: int = 42
y: bool = False


def main():
params = Opt.parse_args()
print(params)


if __name__ == "__main__":
main()
22 changes: 22 additions & 0 deletions tests/test_basic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import unittest
from argparse_dataclass import dataclass


@dataclass
class Opt:
x: int = 42
y: bool = False


class ArgParseTests(unittest.TestCase):
def test_basic(self):
params = Opt.parse_args([])
self.assertEqual(42, params.x)
self.assertEqual(False, params.y)
params = Opt.parse_args(["--x=10", "--y"])
self.assertEqual(10, params.x)
self.assertEqual(True, params.y)


if __name__ == "__main__":
unittest.main()

0 comments on commit 84d3918

Please sign in to comment.