Skip to content

Commit

Permalink
Simplified **kwargs arguments in CLI
Browse files Browse the repository at this point in the history
  • Loading branch information
frthjf committed May 25, 2024
1 parent 26b89d2 commit 839163c
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 57 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

# Unreleased

- Simplified `**kwargs` arguments in CLI
- Adds save/load_attribute helper
- Adds `@cachable` decorator utility

Expand Down
62 changes: 29 additions & 33 deletions src/machinable/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,34 +16,34 @@ def parse(args: List) -> tuple:
dotlist = []
version = []

def _push(_elements, _dotlist, _version):
if len(dotlist) > 0:
_version.append(
OmegaConf.to_container(OmegaConf.from_dotlist(_dotlist))
)

if len(_version) > 0:
if len(_elements) > 0:
_elements[-1].extend(_version)
def _parse_dotlist():
if len(dotlist) == 0:
return
_ver = {}
for k, v in OmegaConf.to_container(
OmegaConf.from_dotlist(dotlist)
).items():
if k.startswith("**"):
kwargs[-1][k[2:]] = v
else:
_elements.append(_version)
_ver[k] = v
version.append(_ver)

def _push():
if len(version) == 0:
return

if len(elements) > 0:
elements[-1].extend(version)
else:
elements.append(version)

for arg in args:
if arg.startswith("~"):
# version
if len(dotlist) > 0:
# parse preceding dotlist
version.append(
OmegaConf.to_container(OmegaConf.from_dotlist(dotlist))
)
dotlist = []
_parse_dotlist()
dotlist = []
version.append(arg)
elif arg.startswith("**kwargs="):
kwargs.append(
OmegaConf.to_container(OmegaConf.from_dotlist([arg[2:]]))[
"kwargs"
]
)
elif arg.startswith("--"):
# method
methods.append((len(elements), arg[2:]))
Expand All @@ -52,23 +52,19 @@ def _push(_elements, _dotlist, _version):
dotlist.append(arg)
else:
# module
_push(elements, dotlist, version)
_parse_dotlist()
_push()
dotlist = []
version = []
# auto-complete `.project` -> `interface.project`
if arg.startswith("."):
arg = "interface" + arg
elements.append([arg])
if len(elements) - 1 > len(kwargs):
kwargs.append({})
if len(elements) - 1 != len(kwargs):
raise ValueError(f"Multiple **kwargs for {arg}")

_push(elements, dotlist, version)
if len(elements) > len(kwargs):
kwargs.append({})
if len(elements) != len(kwargs):
raise ValueError(f"Multiple **kwargs for last interface")
kwargs.append({})

_parse_dotlist()
_push()

return elements, kwargs, methods


Expand Down
37 changes: 13 additions & 24 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def test_cli_main(capfd, tmp_storage):
[
"get",
"machinable.execution",
"**kwargs={'resources': {'a': 1}}",
"**resources={'a': 1}",
"hello",
"name=there",
"--resources",
Expand All @@ -48,41 +48,30 @@ def test_cli_main(capfd, tmp_storage):
out, err = capfd.readouterr()
assert out == "{'a': 1}\n"

with pytest.raises(ValueError):
main(
[
"get",
"machinable.execution",
"**kwargs={'resources': {'a': 1}}",
"**kwargs={}",
"hello",
"name=there",
"**kwargs={'test': 'me'}" "--resources",
]
)
main(
[
"get",
"machinable.execution",
"**resources={'a': 1}",
"**resources={'a': 2}",
"--__model__",
]
)
out, err = capfd.readouterr()
assert "resources={'a': 2}" in out

out, err = capfd.readouterr()
main(
[
"get",
"machinable.execution",
"**kwargs={'resources': {'a': 1}}",
"**resources={'a': 1}",
"--__model__",
]
)
out, err = capfd.readouterr()
assert "resources={'a': 1}" in out

with pytest.raises(ValueError):
main(
[
"get",
"machinable.execution",
"**kwargs={'resources': {'a': 1}}",
"**kwargs={}",
]
)

# help
assert main([]) == 0
assert main(["help"]) == 0
Expand Down

0 comments on commit 839163c

Please sign in to comment.