diff --git a/CHANGELOG.md b/CHANGELOG.md index e6f0f2fc..d3143aba 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,7 @@ # Unreleased +- Simplified `**kwargs` arguments in CLI - Adds save/load_attribute helper - Adds `@cachable` decorator utility diff --git a/src/machinable/cli.py b/src/machinable/cli.py index 4faa002e..777cc1c1 100644 --- a/src/machinable/cli.py +++ b/src/machinable/cli.py @@ -16,34 +16,34 @@ def parse(args: List) -> tuple: dotlist = [] version = [] - def _push(_elements, _dotlist, _version): - if len(dotlist) > 0: - _version.append( - OmegaConf.to_container(OmegaConf.from_dotlist(_dotlist)) - ) - - if len(_version) > 0: - if len(_elements) > 0: - _elements[-1].extend(_version) + def _parse_dotlist(): + if len(dotlist) == 0: + return + _ver = {} + for k, v in OmegaConf.to_container( + OmegaConf.from_dotlist(dotlist) + ).items(): + if k.startswith("**"): + kwargs[-1][k[2:]] = v else: - _elements.append(_version) + _ver[k] = v + version.append(_ver) + + def _push(): + if len(version) == 0: + return + + if len(elements) > 0: + elements[-1].extend(version) + else: + elements.append(version) for arg in args: if arg.startswith("~"): # version - if len(dotlist) > 0: - # parse preceding dotlist - version.append( - OmegaConf.to_container(OmegaConf.from_dotlist(dotlist)) - ) - dotlist = [] + _parse_dotlist() + dotlist = [] version.append(arg) - elif arg.startswith("**kwargs="): - kwargs.append( - OmegaConf.to_container(OmegaConf.from_dotlist([arg[2:]]))[ - "kwargs" - ] - ) elif arg.startswith("--"): # method methods.append((len(elements), arg[2:])) @@ -52,23 +52,19 @@ def _push(_elements, _dotlist, _version): dotlist.append(arg) else: # module - _push(elements, dotlist, version) + _parse_dotlist() + _push() dotlist = [] version = [] # auto-complete `.project` -> `interface.project` if arg.startswith("."): arg = "interface" + arg elements.append([arg]) - if len(elements) - 1 > len(kwargs): - kwargs.append({}) - if len(elements) - 1 != len(kwargs): - raise ValueError(f"Multiple **kwargs for {arg}") - - _push(elements, dotlist, version) - if len(elements) > len(kwargs): - kwargs.append({}) - if len(elements) != len(kwargs): - raise ValueError(f"Multiple **kwargs for last interface") + kwargs.append({}) + + _parse_dotlist() + _push() + return elements, kwargs, methods diff --git a/tests/test_cli.py b/tests/test_cli.py index 6ad5463f..9af4f200 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -39,7 +39,7 @@ def test_cli_main(capfd, tmp_storage): [ "get", "machinable.execution", - "**kwargs={'resources': {'a': 1}}", + "**resources={'a': 1}", "hello", "name=there", "--resources", @@ -48,41 +48,30 @@ def test_cli_main(capfd, tmp_storage): out, err = capfd.readouterr() assert out == "{'a': 1}\n" - with pytest.raises(ValueError): - main( - [ - "get", - "machinable.execution", - "**kwargs={'resources': {'a': 1}}", - "**kwargs={}", - "hello", - "name=there", - "**kwargs={'test': 'me'}" "--resources", - ] - ) + main( + [ + "get", + "machinable.execution", + "**resources={'a': 1}", + "**resources={'a': 2}", + "--__model__", + ] + ) + out, err = capfd.readouterr() + assert "resources={'a': 2}" in out out, err = capfd.readouterr() main( [ "get", "machinable.execution", - "**kwargs={'resources': {'a': 1}}", + "**resources={'a': 1}", "--__model__", ] ) out, err = capfd.readouterr() assert "resources={'a': 1}" in out - with pytest.raises(ValueError): - main( - [ - "get", - "machinable.execution", - "**kwargs={'resources': {'a': 1}}", - "**kwargs={}", - ] - ) - # help assert main([]) == 0 assert main(["help"]) == 0