Skip to content

Commit

Permalink
refactor: update TypeAction handling VarArg
Browse files Browse the repository at this point in the history
  • Loading branch information
jnoortheen committed Nov 3, 2020
1 parent 545cb80 commit 351979f
Show file tree
Hide file tree
Showing 7 changed files with 58 additions and 80 deletions.
55 changes: 34 additions & 21 deletions arger/funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,17 +151,29 @@ def __init__(self, func: tp.Optional[tp.Callable]):
for param in sign.parameters.values():
param_doc = docstr.params.get(param.name)
if param.default is param.empty:
# todo: handle VarArg
# if param.kind == inspect.Parameter.VAR_POSITIONAL:
# annot = tp_utils.VarArg(annot)
# elif param.kind == inspect.Parameter.VAR_KEYWORD:
# annot = tp_utils.VarKw(annot)
self.args[param.name] = create_argument(param, param_doc)
else:
self.args[param.name] = create_option(
param, param_doc, option_generator
)

def dispatch(self, ns: argparse.Namespace) -> tp.Any:
if self.fn:
kwargs = {}
args = []
sign = inspect.signature(self.fn)
for arg_name in self.args:
val = getattr(ns, arg_name)
param = sign.parameters[arg_name]
if param.kind == param.POSITIONAL_ONLY:
args.append(val)
elif param.kind == param.VAR_POSITIONAL:
args.extend(val)
else:
kwargs[arg_name] = val
return self.fn(*args, **kwargs)
return None


def create_option(
param: inspect.Parameter,
Expand All @@ -185,19 +197,19 @@ def create_option(


def create_argument(param: inspect.Parameter, doc: tp.Optional[ParamDocTp]) -> Argument:
arg = Argument(help=doc.doc if doc else "")
kwargs = dict(help=doc.doc if doc else "")
if param.kind == inspect.Parameter.VAR_POSITIONAL:
kwargs.setdefault("nargs", "*")
if param.annotation is not _EMPTY:
kwargs.setdefault("type", param.annotation)
arg = Argument(**kwargs)
arg.set_dest(param.name, param.annotation)
return arg


def get_nargs(typ: Any) -> Tuple[Any, Union[int, str]]:
inner = tp_utils.unpack_type(typ)
if (
tp_utils.is_tuple(typ)
and typ != tuple
and not isinstance(typ, (tp_utils.VarKw, tp_utils.VarArg))
and tp_utils.get_inner_args(typ)
):
if tp_utils.is_tuple(typ) and typ != tuple and tp_utils.get_inner_args(typ):
args = tp_utils.get_inner_args(typ)
inner = inner if len(set(args)) == 1 else str
return inner, '+' if (... in args) else len(args)
Expand All @@ -210,31 +222,32 @@ class TypeAction(argparse.Action):
def __init__(self, *args, **kwargs):
typ = kwargs.pop("type", _EMPTY)
self.orig_type = typ
self.is_iterable = tp_utils.is_seq_container(typ)
self.is_enum = False
if typ is not _EMPTY:
origin = tp_utils.get_origin(typ)
if tp_utils.is_iterable(origin):
if self.is_iterable:
origin, kwargs["nargs"] = get_nargs(typ)

if tp_utils.is_enum(origin):
kwargs.setdefault("choices", [e.name for e in origin])
origin = str
self.is_enum = True

kwargs["type"] = origin
super().__init__(*args, **kwargs)

def set_attr(self, namespace, vals):
if self.orig_type is _EMPTY:
val = vals
else:
val = tp_utils.cast(self.orig_type, vals)
setattr(namespace, self.dest, val)
def cast_value(self, vals):
if self.is_iterable or self.is_enum:
return tp_utils.cast(self.orig_type, vals)
return vals

def __call__(self, parser, namespace, values, option_string=None):
if tp_utils.is_iterable(self.orig_type):
if self.is_iterable:
items = getattr(namespace, self.dest, ()) or ()
items = list(items)
items.extend(values)
vals = items
else:
vals = values
self.set_attr(namespace, vals)
setattr(namespace, self.dest, self.cast_value(vals))
18 changes: 1 addition & 17 deletions arger/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@

from arger.funcs import ParsedFunc

from .typing_utils import VarArg

CMD_TITLE = "commands"
LEVEL = '__level__'
FUNC_PREFIX = '__func_'
Expand Down Expand Up @@ -53,7 +51,7 @@ def __init__(
self._add_args(_level)

def _add_args(self, level: int):
self.set_defaults(**{f'{FUNC_PREFIX}{level}': self.dispatch, LEVEL: level})
self.set_defaults(**{f'{FUNC_PREFIX}{level}': self.func.dispatch, LEVEL: level})

for arg_name, arg in self.func.args.items():
if arg_name.startswith(
Expand All @@ -62,20 +60,6 @@ def _add_args(self, level: int):
continue
arg.add(self)

def dispatch(self, ns: ap.Namespace) -> tp.Any:
if self.func.fn:
kwargs = {}
args = []
for arg_name, arg_type in self.func.args.items():
val = getattr(ns, arg_name)
if isinstance(arg_type.kwargs.get('type'), VarArg):
args = val
else:
kwargs[arg_name] = val
# todo: use inspect.signature.bind to bind kwargs and args to respective names
return self.func.fn(*args, **kwargs)
return None

def run(self, *args: str, capture_sys=True) -> ap.Namespace:
"""Parse cli and dispatch functions.
Expand Down
40 changes: 7 additions & 33 deletions arger/typing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def define_old_types():


def get_origin(tp):
"""Return the python class for the GenericAlias. Dict->dict, List->list..."""
origin = _get_origin(tp)

if not NEW_TYPING and hasattr(tp, '__name__'):
Expand Down Expand Up @@ -68,14 +69,14 @@ def unpack_type(tp, default=str) -> Any:
Returns:
type inside the container type
"""
if get_inner_args(tp):
inner_tp = getattr(tp, ARGS)
if inner_tp and str(inner_tp[0]) not in {'~T', 'typing.Any'}:
return inner_tp[0]
args = get_inner_args(tp)
if args:
if str(args[0]) not in {'~T', 'typing.Any'}:
return args[0]
return default


def is_iterable(tp):
def is_seq_container(tp):
origin = get_origin(tp)
return origin in {list, tuple, set, frozenset}

Expand All @@ -94,7 +95,7 @@ def cast(tp, val) -> Any:
if is_enum(origin):
return origin[val]

if is_iterable(origin):
if is_seq_container(origin):
val = origin(val)
args = get_inner_args(tp)
if (
Expand All @@ -112,30 +113,3 @@ def cast(tp, val) -> Any:


T = TypeVar('T')


class VarArg:
"""Represent variadic arguent."""

__origin__: Any = tuple
__args__: Any = ()

def __init__(self, tp):
self.__args__ = (tp, ...)

def __repr__(self):
tp = self.__args__[0]
tp = getattr(tp, "__name__", tp)
return f"{self.__class__.__name__}[{tp}]"

def __eq__(self, other):
return repr(self) == repr(other)

def __hash__(self):
return hash(repr(self))


class VarKw(VarArg):
"""Represent variadic keyword argument."""

__original__ = dict
5 changes: 3 additions & 2 deletions tests/test_args_opts/conftest.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
# pylint: disable = redefined-outer-name

import inspect

import pytest
Expand All @@ -13,12 +15,11 @@ def param_doc(hlp=''):


@pytest.fixture
def parameter(name, tp=_EMPTY, default=_EMPTY):
def parameter(name, tp):
return inspect.Parameter(
name,
inspect.Parameter.POSITIONAL_OR_KEYWORD,
annotation=tp,
default=default,
)


Expand Down
3 changes: 0 additions & 3 deletions tests/test_args_opts/test_arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@

import pytest

from arger.typing_utils import VarArg


@pytest.fixture
def parser(add_arger, argument):
Expand All @@ -29,7 +27,6 @@ class Num(Enum):
('enum', Num, 'one', Num.one),
('enum', Num, 'two', Num.two),
# container types
('vargs', VarArg, '1 2 3', ('1', '2', '3')),
('a_tuple', tuple, '1 2 3', ('1', '2', '3')),
('a_tuple', Tuple[int, ...], '1 2 3', (1, 2, 3)),
('a_tuple', Tuple[str, ...], '1 2 3', ('1', '2', '3')),
Expand Down
10 changes: 10 additions & 0 deletions tests/test_args_opts/test_options.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,18 @@
import inspect
from decimal import Decimal

import pytest


@pytest.fixture
def parameter(name, default):
return inspect.Parameter(
name,
inspect.Parameter.POSITIONAL_OR_KEYWORD,
default=default,
)


@pytest.fixture
def parser(add_arger, option):
return add_arger(option)
Expand Down
7 changes: 3 additions & 4 deletions tests/test_parser_funcs.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from arger.funcs import ParsedFunc, TypeAction
from arger.typing_utils import VarArg

from .utils import _reprint

Expand All @@ -18,15 +17,15 @@ def test_parse_function():
assert (
docs.description == "Example function with types documented in the docstring."
)
exp_args = ["param1", "param2", "args", "kw1", "kw2"]
exp_args = ["param1", "param2", "kw1", "kw2", "args"]
assert list(docs.args) == exp_args

exp_flags = [
("param1",),
("param2",),
("args",),
("-k", "--kw1"),
("-w", "--kw2"),
("args",),
]
exp_kwargs = [
{
Expand All @@ -39,7 +38,6 @@ def test_parse_function():
"help": "The second parameter.",
"type": str,
},
{"action": TypeAction, "help": "", "type": VarArg(int)},
{
"action": TypeAction,
"default": None,
Expand All @@ -52,6 +50,7 @@ def test_parse_function():
"dest": "kw2",
"help": "",
},
{"action": TypeAction, "help": "", "type": int, "nargs": "*"},
]

for idx, arg in enumerate(docs.args.values()):
Expand Down

0 comments on commit 351979f

Please sign in to comment.