diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 7245b9e0..0b49968d 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -22,6 +22,9 @@ Added - Allow providing a function with return type a class in ``class_path`` (`lightning#13613 `__) +- Automatic ``--print_shtab`` option when ``shtab`` is installed, providing + completions for many type hints without the need to modify code (`#528 + `__). Fixed ^^^^^ diff --git a/DOCUMENTATION.rst b/DOCUMENTATION.rst index 34687078..d6ce5ad6 100644 --- a/DOCUMENTATION.rst +++ b/DOCUMENTATION.rst @@ -2544,17 +2544,41 @@ function to provide an instance of the parser to :class:`.ActionParser`. Tab completion ============== -Tab completion is available for jsonargparse parsers by using the `argcomplete -`__ package. There is no need to -implement completer functions or to call :func:`argcomplete.autocomplete` since -this is done automatically by :py:meth:`.ArgumentParser.parse_args`. The only -requirement to enable tab completion is to install argcomplete either directly -or by installing jsonargparse with the ``argcomplete`` extras require as -explained in section :ref:`installation`. Then the tab completion can be enabled -`globally `__ for all -argcomplete compatible tools or for each `individual -`__ tool. A simple -``example.py`` tool would be: +Tab completion is available for jsonargparse parsers by using either the `shtab +`__ package or the `argcomplete +`__ package. + +shtab +----- + +For ``shtab`` to work, there is no need to set ``complete``/``choices`` to the +parser actions, and no need to call :func:`shtab.add_argument_to`. This is done +automatically by :py:meth:`.ArgumentParser.parse_args`. The only requirement is +to install shtab either directly or by installing jsonargparse with the +``shtab`` extras require as explained in section :ref:`installation`. + +.. note:: + + Automatic shtab support is currently experimental and subject to change. + +Once ``shtab`` is installed, parsers will automatically have the +``--print_shtab`` option that can be used to print the completion script for the +supported shells. For example in linux to enable bash completions for all users, +as root it would be used as: + +.. code-block:: bash + + # example.py --print_shtab=bash > /etc/bash_completion.d/example + +Without installing, completion scripts can be tested by sourcing or evaluating +them, for instance: + +.. code-block:: bash + + $ eval "$(example.py --print_shtab=bash)" + +The scripts work both to complete when there are choices, but also gives +instructions to the user for guidance. Take for example the parser: .. testsetup:: tab_completion @@ -2572,18 +2596,60 @@ argcomplete compatible tools or for each `individual parser.parse_args() -Then in a bash shell you can add the executable bit to the script, activate tab -completion and use it as follows: +The completions print the type of the argument, how many options are matched, +and afterward the list of choices matched up to that point. If only one option +matches, then the value is completed without printing guidance. For example: + +.. code-block:: bash + + $ example.py --bool + Expected type: Optional[bool]; 3/3 matched choices + true false null + $ example.py --bool f + $ example.py --bool false + +For the case of subclass types, the import class paths for known subclasses are +completed, both for the switch to select the class and for the corresponding +``--*.help`` switch. The ``init_args`` for known subclasses are also completed, +giving as guidance which of the subclasses accepts it. An example would be: + +.. code-block:: bash + + $ example.py --cls + Expected type: BaseClass; 3/3 matched choices + some.module.BaseClass other.module.SubclassA + other.module.SubclassB + $ example.py --cls other.module.SubclassA --cls. + --cls.param1 --cls.param2 + $ example.py --cls other.module.SubclassA --cls.param2 + Expected type: int; Accepted by subclasses: SubclassA + +argcomplete +----------- + +For ``argcompete`` to work, there is no need to implement completer functions or +to call :func:`argcomplete.autocomplete` since this is done automatically by +:py:meth:`.ArgumentParser.parse_args`. The only requirement to enable tab +completion is to install argcomplete either directly or by installing +jsonargparse with the ``argcomplete`` extras require as explained in section +:ref:`installation`. + +The tab completion can be enabled `globally +`__ for all +argcomplete compatible tools or for each `individual +`__ tool. + +Using the same ``bool`` example as shown above, activate tab completion and use +it as follows: .. code-block:: bash - $ chmod +x example.py $ eval "$(register-python-argcomplete example.py)" - $ ./example.py --bool + $ example.py --bool false null true - $ ./example.py --bool f - $ ./example.py --bool false + $ example.py --bool f + $ example.py --bool false .. _logging: diff --git a/jsonargparse/_actions.py b/jsonargparse/_actions.py index b0762e92..5403b071 100644 --- a/jsonargparse/_actions.py +++ b/jsonargparse/_actions.py @@ -12,7 +12,7 @@ from ._common import Action, get_class_instantiator, is_subclass, parser_context from ._loaders_dumpers import get_loader_exceptions, load_value from ._namespace import Namespace, NSKeyError, split_key, split_key_root -from ._optionals import FilesCompleterMethod, get_config_read_mode +from ._optionals import get_config_read_mode from ._type_checking import ArgumentParser from ._util import ( NoneType, @@ -145,7 +145,7 @@ def filter_default_actions(actions): return {k: a for k, a in actions.items() if not isinstance(a, default)} -class ActionConfigFile(Action, FilesCompleterMethod): +class ActionConfigFile(Action): """Action to indicate that an argument is a configuration file or a configuration string.""" def __init__(self, **kwargs): @@ -208,6 +208,12 @@ def apply_config(parser, cfg, dest, value) -> None: cfg[dest] = [] cfg[dest].append(cfg_path) + def completer(self, prefix, **kwargs): + from ._completions import get_files_completer + + files_completer = get_files_completer() + return sorted(files_completer(prefix, **kwargs)) + previous_config: ContextVar = ContextVar("previous_config", default=None) diff --git a/jsonargparse/_completions.py b/jsonargparse/_completions.py new file mode 100644 index 00000000..c693916a --- /dev/null +++ b/jsonargparse/_completions.py @@ -0,0 +1,298 @@ +import argparse +import inspect +import locale +import os +import re +import warnings +from collections import defaultdict +from contextlib import contextmanager, suppress +from contextvars import ContextVar +from enum import Enum +from importlib.util import find_spec +from subprocess import PIPE, Popen +from typing import List, Union + +from ._actions import ActionConfigFile, _ActionHelpClassPath, remove_actions +from ._parameter_resolvers import get_signature_parameters +from ._typehints import ( + ActionTypeHint, + callable_origin_types, + get_all_subclass_paths, + get_callable_return_type, + get_typehint_origin, + is_subclass, + type_to_str, +) +from ._util import NoneType, Path, import_object, unique + +shtab_shell: ContextVar = ContextVar("shtab_shell") +shtab_prog: ContextVar = ContextVar("shtab_prog") +shtab_preambles: ContextVar = ContextVar("shtab_preambles") + + +def handle_completions(parser): + if find_spec("argcomplete") and "_ARGCOMPLETE" in os.environ: + import argcomplete + + from ._common import parser_context + + with parser_context(load_value_mode=parser.parser_mode): + argcomplete.autocomplete(parser) + + if find_spec("shtab") and not getattr(parser, "parent_parser", None): + if not any(isinstance(action, ShtabAction) for action in parser._actions): + parser.add_argument("--print_shtab", action=ShtabAction) + + +# argcomplete + + +def get_files_completer(): + from argcomplete.completers import FilesCompleter + + return FilesCompleter() + + +def argcomplete_namespace(caller, parser, namespace): + if caller == "argcomplete": + namespace.__class__ = __import__("jsonargparse").Namespace + namespace = parser.merge_config(parser.get_defaults(skip_check=True), namespace).as_flat() + return namespace + + +def argcomplete_warn_redraw_prompt(prefix, message): + import argcomplete + + if prefix != "": + argcomplete.warn(message) + with suppress(Exception): + proc = Popen(f"ps -p {os.getppid()} -oppid=".split(), stdout=PIPE, stderr=PIPE) + stdout, _ = proc.communicate() + shell_pid = int(stdout.decode().strip()) + os.kill(shell_pid, 28) + _ = "_" if locale.getlocale()[1] != "UTF-8" else "\xa0" + return [_ + message.replace(" ", _), ""] + + +# shtab + + +class ShtabAction(argparse.Action): + def __init__( + self, + option_strings, + dest=argparse.SUPPRESS, + default=argparse.SUPPRESS, + ): + import shtab + + super().__init__( + option_strings=option_strings, + dest=dest, + default=default, + choices=shtab.SUPPORTED_SHELLS, + help="Print shtab shell completion script.", + ) + + def __call__(self, parser, namespace, shell, option_string=None): + import shtab + + warnings.warn("Automatic shtab support is experimental and subject to change.", UserWarning) + prog = norm_name(parser.prog) + assert prog + preambles = [] + if shell == "bash": + preambles = [bash_compgen_typehint.strip().replace("%s", prog)] + with prepare_actions_context(shell, prog, preambles): + shtab_prepare_actions(parser) + print(shtab.complete(parser, shell, preamble="\n".join(preambles))) + parser.exit(0) + + +@contextmanager +def prepare_actions_context(shell, prog, preambles): + token_shell = shtab_shell.set(shell) + token_prog = shtab_prog.set(prog) + token_preambles = shtab_preambles.set(preambles) + try: + yield + finally: + shtab_shell.reset(token_shell) + shtab_prog.reset(token_prog) + shtab_preambles.reset(token_preambles) + + +def norm_name(name: str) -> str: + return re.sub(r"\W+", "_", name) + + +def shtab_prepare_actions(parser) -> None: + remove_actions(parser, (ShtabAction,)) + if parser._subcommands_action: + for subparser in parser._subcommands_action._name_parser_map.values(): + shtab_prepare_actions(subparser) + for action in parser._actions: + shtab_prepare_action(action, parser) + + +def shtab_prepare_action(action, parser) -> None: + import shtab + + if action.choices or hasattr(action, "complete"): + return + + complete = None + if isinstance(action, ActionConfigFile): + complete = shtab.FILE + elif isinstance(action, ActionTypeHint): + typehint = action._typehint + if get_typehint_origin(typehint) == Union: + subtypes = [s for s in typehint.__args__ if s not in {NoneType, str, dict, list, tuple, bytes}] + if len(subtypes) == 1: + typehint = subtypes[0] + if is_subclass(typehint, Path): + if "f" in typehint._mode: + complete = shtab.FILE + elif "d" in typehint._mode: + complete = shtab.DIRECTORY + elif is_subclass(typehint, os.PathLike): + complete = shtab.FILE + if complete: + action.complete = complete + return + + choices = None + if isinstance(action, ActionTypeHint): + skip = getattr(action, "sub_add_kwargs", {}).get("skip", set()) + choices = get_typehint_choices(action._typehint, action.option_strings[0], parser, skip) + if shtab_shell.get() == "bash": + message = f"Expected type: {type_to_str(action._typehint)}" + add_bash_typehint_completion(parser, action, message, choices) + choices = None + elif isinstance(action, _ActionHelpClassPath): + choices = get_help_class_choices(action._baseclass) + if choices: + action.choices = choices + + +bash_compgen_typehint_name = "_jsonargparse_%s_compgen_typehint" +bash_compgen_typehint = """ +_jsonargparse_%%s_matched_choices() { + local TOTAL=$(echo "$1" | wc -w | tr -d " ") + if [ "$TOTAL" != 0 ]; then + local MATCH=$(echo "$2" | wc -w | tr -d " ") + printf "; $MATCH/$TOTAL matched choices" + fi +} +%(name)s() { + local MATCH=( $(IFS=" " compgen -W "$1" "$2") ) + if [ ${#MATCH[@]} = 0 ]; then + if [ "$COMP_TYPE" = 63 ]; then + MATCHED=$(_jsonargparse_%%s_matched_choices "$1" "${MATCH[*]}") + printf "%(b)s\\n$3$MATCHED\\n%(n)s" >&2 + kill -WINCH $$ + fi + else + IFS=" " compgen -W "$1" "$2" + if [ "$COMP_TYPE" = 63 ]; then + MATCHED=$(_jsonargparse_%%s_matched_choices "$1" "${MATCH[*]}") + printf "%(b)s\\n$3$MATCHED%(n)s" >&2 + fi + fi +} +""" % { + "name": bash_compgen_typehint_name, + "b": "$(tput setaf 5)", + "n": "$(tput sgr0)", +} + + +def add_bash_typehint_completion(parser, action, message, choices) -> None: + fn_typehint = norm_name(bash_compgen_typehint_name % shtab_prog.get()) + fn_name = parser.prog.replace(" [options] ", "_") + fn_name = norm_name(f"_jsonargparse_{fn_name}_{action.dest}_typehint") + fn = '%(fn_name)s(){ %(fn_typehint)s "%(choices)s" "$1" "%(message)s"; }' % { + "fn_name": fn_name, + "fn_typehint": fn_typehint, + "choices": " ".join(choices), + "message": message, + } + shtab_preambles.get().append(fn) + action.complete = {"bash": fn_name} + + +def get_typehint_choices(typehint, prefix, parser, skip, choices=None, added_subclasses=None) -> List[str]: + if choices is None: + choices = [] + if not added_subclasses: + added_subclasses = set() + if typehint is bool: + choices.extend(["true", "false"]) + elif typehint is type(None): + choices.append("null") + elif is_subclass(typehint, Enum): + choices.extend(list(typehint.__members__)) + else: + origin = get_typehint_origin(typehint) + if origin == Union: + for subtype in typehint.__args__: + if subtype in added_subclasses: + continue + get_typehint_choices(subtype, prefix, parser, skip, choices, added_subclasses) + elif ActionTypeHint.is_subclass_typehint(typehint): + added_subclasses.add(typehint) + choices.extend(add_subactions_and_get_subclass_choices(typehint, prefix, parser, skip, added_subclasses)) + elif origin in callable_origin_types: + return_type = get_callable_return_type(typehint) + if return_type and ActionTypeHint.is_subclass_typehint(return_type): + num_args = len(typehint.__args__) - 1 + skip.add(num_args) + choices.extend( + add_subactions_and_get_subclass_choices(return_type, prefix, parser, skip, added_subclasses) + ) + + return [] if choices == ["null"] else choices + + +def add_subactions_and_get_subclass_choices(typehint, prefix, parser, skip, added_subclasses) -> List[str]: + choices = [] + paths = get_all_subclass_paths(typehint) + init_args = defaultdict(list) + subclasses = defaultdict(list) + for path in paths: + choices.append(path) + cls = import_object(path) + params = get_signature_parameters(cls) + num_skip = next((s for s in skip if isinstance(s, int)), 0) + if num_skip > 0: + params = params[num_skip:] + for param in params: + if param.name not in skip: + init_args[param.name].append(param.annotation) + subclasses[param.name].append(path.rsplit(".", 1)[-1]) + + for name, subtypes in init_args.items(): + option_string = f"{prefix}.{name}" + if option_string not in parser._option_string_actions: + action = parser.add_argument(option_string) + for subtype in unique(subtypes): + subchoices = get_typehint_choices(subtype, option_string, parser, skip, None, added_subclasses) + if shtab_shell.get() == "bash": + message = f"Expected type: {type_to_str(subtype)}; " + message += f"Accepted by subclasses: {', '.join(subclasses[name])}" + add_bash_typehint_completion(parser, action, message, subchoices) + elif subchoices: + action.choices = subchoices + + return choices + + +def get_help_class_choices(typehint) -> List[str]: + choices = [] + if get_typehint_origin(typehint) == Union: + for subtype in typehint.__args__: + if inspect.isclass(subtype): + choices.extend(get_help_class_choices(subtype)) + else: + choices = get_all_subclass_paths(typehint) + return choices diff --git a/jsonargparse/_core.py b/jsonargparse/_core.py index 546c69aa..cdfb3cba 100644 --- a/jsonargparse/_core.py +++ b/jsonargparse/_core.py @@ -44,6 +44,10 @@ lenient_check, parser_context, ) +from ._completions import ( + argcomplete_namespace, + handle_completions, +) from ._deprecated import ParserDeprecations from ._formatters import DefaultHelpFormatter, empty_help, get_env_var from ._jsonnet import ActionJsonnet @@ -70,8 +74,6 @@ strip_meta, ) from ._optionals import ( - argcomplete_autocomplete, - argcomplete_namespace, fsspec_support, get_config_read_mode, import_fsspec, @@ -375,7 +377,7 @@ def parse_args( # type: ignore[override] """ skip_check = get_private_kwargs(kwargs, _skip_check=False) return_parser_if_captured(self) - argcomplete_autocomplete(self) + handle_completions(self) if args is None: args = sys.argv[1:] diff --git a/jsonargparse/_deprecated.py b/jsonargparse/_deprecated.py index f1b063aa..e0c59897 100644 --- a/jsonargparse/_deprecated.py +++ b/jsonargparse/_deprecated.py @@ -13,7 +13,6 @@ from ._common import Action from ._namespace import Namespace -from ._optionals import FilesCompleterMethod from ._type_checking import ArgumentParser __all__ = [ @@ -271,7 +270,7 @@ def __call__(self, *args, **kwargs): use as type ``List[]`` with ``enable_path=True``. """ ) -class ActionPathList(Action, FilesCompleterMethod): +class ActionPathList(Action): """Action to check and store a list of file paths read from a plain text file or stream.""" def __init__(self, mode: Optional[str] = None, rel: str = "cwd", **kwargs): diff --git a/jsonargparse/_jsonschema.py b/jsonargparse/_jsonschema.py index 565e60de..c85e33e7 100644 --- a/jsonargparse/_jsonschema.py +++ b/jsonargparse/_jsonschema.py @@ -8,7 +8,6 @@ from ._loaders_dumpers import get_loader_exceptions, load_value from ._namespace import strip_meta from ._optionals import ( - argcomplete_warn_redraw_prompt, get_jsonschema_exceptions, import_jsonschema, ) @@ -117,6 +116,8 @@ def set_defaults(validator, properties, instance, schema): def completer(self, prefix, **kwargs): """Used by argcomplete, validates value and shows expected type.""" if chr(int(os.environ["COMP_TYPE"])) == "?": + from ._completions import argcomplete_warn_redraw_prompt + try: if prefix.strip() == "": raise ValueError() diff --git a/jsonargparse/_optionals.py b/jsonargparse/_optionals.py index 32f7f1fb..a4406ce9 100644 --- a/jsonargparse/_optionals.py +++ b/jsonargparse/_optionals.py @@ -1,11 +1,9 @@ """Code related to optional dependencies.""" import inspect -import locale import os -from contextlib import contextmanager, suppress +from contextlib import contextmanager from importlib.util import find_spec -from subprocess import PIPE, Popen from typing import Optional __all__ = [ @@ -21,7 +19,6 @@ jsonnet_support = find_spec("_jsonnet") is not None url_support = find_spec("requests") is not None docstring_parser_support = find_spec("docstring_parser") is not None -argcomplete_support = find_spec("argcomplete") is not None fsspec_support = find_spec("fsspec") is not None ruyaml_support = find_spec("ruyaml") is not None omegaconf_support = find_spec("omegaconf") is not None @@ -117,12 +114,6 @@ def import_docstring_parser(importer): return docstring_parser -def import_argcomplete(importer): - with missing_package_raise("argcomplete", importer): - import argcomplete - return argcomplete - - def import_fsspec(importer): with missing_package_raise("fsspec", importer): import fsspec @@ -243,49 +234,6 @@ def get_doc_short_description(function_or_class, method_name=None, logger=None): return None -def get_files_completer(): - from argcomplete.completers import FilesCompleter - - return FilesCompleter() - - -class FilesCompleterMethod: - """Completer method for Action classes that should complete files.""" - - def completer(self, prefix, **kwargs): - files_completer = get_files_completer() - return sorted(files_completer(prefix, **kwargs)) - - -def argcomplete_autocomplete(parser): - if argcomplete_support: - argcomplete = import_argcomplete("argcomplete_autocomplete") - from ._common import parser_context - - with parser_context(load_value_mode=parser.parser_mode): - argcomplete.autocomplete(parser) - - -def argcomplete_namespace(caller, parser, namespace): - if caller == "argcomplete": - namespace.__class__ = __import__("jsonargparse").Namespace - namespace = parser.merge_config(parser.get_defaults(skip_check=True), namespace).as_flat() - return namespace - - -def argcomplete_warn_redraw_prompt(prefix, message): - argcomplete = import_argcomplete("argcomplete_warn_redraw_prompt") - if prefix != "": - argcomplete.warn(message) - with suppress(Exception): - proc = Popen(f"ps -p {os.getppid()} -oppid=".split(), stdout=PIPE, stderr=PIPE) - stdout, _ = proc.communicate() - shell_pid = int(stdout.decode().strip()) - os.kill(shell_pid, 28) - _ = "_" if locale.getlocale()[1] != "UTF-8" else "\xa0" - return [_ + message.replace(" ", _), ""] - - def get_omegaconf_loader(): """Returns a yaml loader function based on OmegaConf which supports variable interpolation.""" import io diff --git a/jsonargparse/_typehints.py b/jsonargparse/_typehints.py index e8eb308d..37a6090e 100644 --- a/jsonargparse/_typehints.py +++ b/jsonargparse/_typehints.py @@ -58,9 +58,7 @@ ) from ._namespace import Namespace from ._optionals import ( - argcomplete_warn_redraw_prompt, get_alias_target, - get_files_completer, is_alias_type, is_annotated, is_annotated_validator, @@ -635,6 +633,8 @@ def extra_help(self): def completer(self, prefix, **kwargs): """Used by argcomplete, validates value and shows expected type.""" + from ._completions import argcomplete_warn_redraw_prompt, get_files_completer + if self.choices: return [str(c) for c in self.choices] elif self._typehint == bool: diff --git a/jsonargparse_tests/conftest.py b/jsonargparse_tests/conftest.py index faf07506..26ef83b7 100644 --- a/jsonargparse_tests/conftest.py +++ b/jsonargparse_tests/conftest.py @@ -1,6 +1,7 @@ import logging import os import platform +import re from contextlib import ExitStack, contextmanager, redirect_stderr, redirect_stdout from functools import wraps from importlib.util import find_spec @@ -150,10 +151,12 @@ def source_unavailable(): yield -def get_parser_help(parser: ArgumentParser) -> str: +def get_parser_help(parser: ArgumentParser, strip=False) -> str: out = StringIO() with patch.dict(os.environ, {"COLUMNS": "200"}): parser.print_help(out) + if strip: + return re.sub(" *", " ", out.getvalue()) return out.getvalue() diff --git a/jsonargparse_tests/test_argcomplete.py b/jsonargparse_tests/test_argcomplete.py index de88c3e8..7c3a3a53 100644 --- a/jsonargparse_tests/test_argcomplete.py +++ b/jsonargparse_tests/test_argcomplete.py @@ -58,6 +58,19 @@ def complete_line(parser, value): return out.getvalue(), err.getvalue() +def test_handle_completions(parser): + parser.add_argument("--option") + with patch("argcomplete.autocomplete") as mock, patch.dict( + os.environ, + { + "_ARGCOMPLETE": "1", + }, + ): + parser.parse_args([]) + assert mock.called + assert mock.call_args[0][0] is parser + + def test_complete_nested_one_option(parser): parser.add_argument("--group1.op") out, err = complete_line(parser, "tool.py --group1") diff --git a/jsonargparse_tests/test_optionals.py b/jsonargparse_tests/test_optionals.py index 36bf1436..9b4896b2 100644 --- a/jsonargparse_tests/test_optionals.py +++ b/jsonargparse_tests/test_optionals.py @@ -4,11 +4,9 @@ from jsonargparse import get_config_read_mode, set_config_read_mode from jsonargparse._optionals import ( - argcomplete_support, docstring_parser_support, fsspec_support, get_docstring_parse_options, - import_argcomplete, import_docstring_parser, import_fsspec, import_jsonnet, @@ -109,21 +107,6 @@ def test_docstring_parse_options(): set_docstring_parse_options(attribute_docstrings="invalid") -# argcomplete support - - -@pytest.mark.skipif(not argcomplete_support, reason="argcomplete package is required") -def test_argcomplete_support_true(): - import_argcomplete("test_argcomplete_support_true") - - -@pytest.mark.skipif(argcomplete_support, reason="argcomplete package should not be installed") -def test_argcomplete_support_false(): - with pytest.raises(ImportError) as ctx: - import_argcomplete("test_argcomplete_support_false") - ctx.match("test_argcomplete_support_false") - - # fsspec support diff --git a/jsonargparse_tests/test_shtab.py b/jsonargparse_tests/test_shtab.py new file mode 100644 index 00000000..83704e31 --- /dev/null +++ b/jsonargparse_tests/test_shtab.py @@ -0,0 +1,332 @@ +import re +import subprocess +import tempfile +from enum import Enum +from importlib.util import find_spec +from os import PathLike +from pathlib import Path +from typing import Any, Callable, Optional, Union +from unittest.mock import patch +from warnings import catch_warnings + +import pytest + +from jsonargparse import ArgumentParser +from jsonargparse._completions import norm_name +from jsonargparse._typehints import type_to_str +from jsonargparse.typing import Path_drw, Path_fr +from jsonargparse_tests.conftest import get_parse_args_stdout + + +@pytest.fixture(autouse=True) +def skip_if_no_shtab(): + if not find_spec("shtab"): + pytest.skip("shtab package is required") + + +@pytest.fixture(autouse=True) +def skip_if_wsl_message(): + popen = subprocess.Popen(["bash", "-c", "echo"], stdout=subprocess.PIPE, stderr=subprocess.PIPE) + out, _ = popen.communicate() + if "Windows Subsystem for Linux has no installed distributions" in out.decode().replace("\x00", ""): + pytest.skip(out.decode().replace("\x00", "")) + + +@pytest.fixture(autouse=True) +def term_env_var(): + with patch.dict("os.environ", {"TERM": "xterm-256color", "COLUMNS": "200"}): + yield + + +@pytest.fixture +def parser() -> ArgumentParser: + return ArgumentParser(exit_on_error=False, prog="tool") + + +def get_shtab_script(parser, shell): + with catch_warnings(record=True) as w: + shtab_script = get_parse_args_stdout(parser, [f"--print_shtab={shell}"]) + assert "support is experimental" in str(w[0].message) + return shtab_script + + +def assert_bash_typehint_completions(subtests, shtab_script, completions): + if isinstance(shtab_script, ArgumentParser): + shtab_script = get_shtab_script(shtab_script, "bash") + with tempfile.TemporaryDirectory() as tmpdir: + shtab_script_path = Path(tmpdir) / "comp.sh" + shtab_script_path.write_text(shtab_script) + + for dest, typehint, word, choices, extra in completions: + typehint = type_to_str(typehint) + with subtests.test(f"{word} -> {extra}"): + sh = f'source {shtab_script_path}; COMP_TYPE=63 _jsonargparse_tool_{norm_name(dest)}_typehint "{word}"' + popen = subprocess.Popen(["bash", "-c", sh], stdout=subprocess.PIPE, stderr=subprocess.PIPE) + out, err = popen.communicate() + assert list(out.decode().split()) == choices + if extra is None: + assert f"Expected type: {typehint}" in err.decode() + elif re.match(r"^\d/\d$", extra): + assert f"Expected type: {typehint}; {extra} matched choices" in err.decode() + else: + assert f"Expected type: {typehint}; Accepted by subclasses: {extra}" in err.decode() + + +def test_bash_any(parser, subtests): + parser.add_argument("--any", type=Any) + assert_bash_typehint_completions( + subtests, + parser, + [ + ("any", Any, "", [], None), + ], + ) + + +def test_bash_bool(parser, subtests): + parser.add_argument("--bool", type=bool) + assert_bash_typehint_completions( + subtests, + parser, + [ + ("bool", bool, "", ["true", "false"], "2/2"), + ], + ) + + +def test_bash_optional_bool(parser, subtests): + parser.add_argument("--bool", type=Optional[bool]) + assert_bash_typehint_completions( + subtests, + parser, + [ + ("bool", Optional[bool], "", ["true", "false", "null"], "3/3"), + ("bool", Optional[bool], "tr", ["true"], "1/3"), + ("bool", Optional[bool], "x", [], "0/3"), + ], + ) + + +def test_bash_argument_group(parser, subtests): + group = parser.add_argument_group("Group1") + group.add_argument("--bool", type=bool) + assert_bash_typehint_completions( + subtests, + parser, + [ + ("bool", bool, "", ["true", "false"], "2/2"), + ], + ) + + +class AXEnum(Enum): + ABC = "abc" + XY = "xy" + XZ = "xz" + + +def test_bash_enum(parser, subtests): + parser.add_argument("--enum", type=AXEnum) + assert_bash_typehint_completions( + subtests, + parser, + [ + ("enum", AXEnum, "", ["ABC", "XY", "XZ"], "3/3"), + ("enum", AXEnum, "A", ["ABC"], "1/3"), + ("enum", AXEnum, "X", ["XY", "XZ"], "2/3"), + ], + ) + + +def test_bash_optional_enum(parser, subtests): + parser.add_argument("--enum", type=Optional[AXEnum]) + assert_bash_typehint_completions( + subtests, + parser, + [ + ("enum", Optional[AXEnum], "", ["ABC", "XY", "XZ", "null"], "4/4"), + ], + ) + + +def test_bash_union(parser, subtests): + typehint = Optional[Union[bool, AXEnum]] + parser.add_argument("--union", type=typehint) + assert_bash_typehint_completions( + subtests, + parser, + [ + ("union", typehint, "", ["true", "false", "ABC", "XY", "XZ", "null"], "6/6"), + ("union", typehint, "z", [], "0/6"), + ], + ) + + +def test_bash_config(parser): + parser.add_argument("--cfg", action="config") + shtab_script = get_shtab_script(parser, "bash") + assert "_cfg_COMPGEN=_shtab_compgen_files" in shtab_script + + +def test_bash_dir(parser): + parser.add_argument("--path", type=Path_drw) + shtab_script = get_shtab_script(parser, "bash") + assert "_path_COMPGEN=_shtab_compgen_dirs" in shtab_script + + +@pytest.mark.parametrize("path_type", [Path_fr, PathLike, Path, Union[PathLike, str]]) +def test_bash_file(parser, path_type): + parser.add_argument("--path", type=path_type) + shtab_script = get_shtab_script(parser, "bash") + assert "_path_COMPGEN=_shtab_compgen_files" in shtab_script + + +@pytest.mark.parametrize("path_type", [Path_fr, PathLike, Path, Union[Union[PathLike, str], dict]]) +def test_bash_optional_file(parser, path_type): + parser.add_argument("--path", type=Optional[path_type]) + shtab_script = get_shtab_script(parser, "bash") + assert "_path_COMPGEN=_shtab_compgen_files" in shtab_script + + +class Base: + def __init__(self, p1: int): + pass + + +class SubA(Base): + def __init__(self, p1: int, p2: AXEnum): + pass + + +class SubB(Base): + def __init__(self, p1: int, p3: float): + pass + + +def test_bash_subclasses_help(parser): + parser.add_argument("--cls", type=Base) + shtab_script = get_shtab_script(parser, "bash") + assert "'--cls.help' '--cls' '--cls.p1' '--cls.p2' '--cls.p3'" in shtab_script + classes = f"'{__name__}.Base' '{__name__}.SubA' '{__name__}.SubB'" + assert f"_cls_help_choices=({classes})" in shtab_script + + +def test_bash_subclasses(parser, subtests): + parser.add_argument("--cls", type=Base) + shtab_script = get_shtab_script(parser, "bash") + classes = f"{__name__}.Base {__name__}.SubA {__name__}.SubB".split() + assert_bash_typehint_completions( + subtests, + shtab_script, + [ + ("cls", Base, "", classes, "3/3"), + ("cls", Base, f"{__name__}.S", classes[1:], "2/3"), + ("cls.p1", int, "1", [], "Base, SubA, SubB"), + ("cls.p3", float, "", [], "SubB"), + ], + ) + + +class Other: + def __init__(self, o1: bool): + pass + + +def test_bash_union_subclasses(parser, subtests): + parser.add_argument("--cls", type=Union[Base, Other]) + shtab_script = get_shtab_script(parser, "bash") + assert "'--cls.help' '--cls' '--cls.p1' '--cls.p2' '--cls.p3' '--cls.o1'" in shtab_script + classes = f"'{__name__}.Base' '{__name__}.SubA' '{__name__}.SubB' '{__name__}.Other'" + assert f"_cls_help_choices=({classes})" in shtab_script + assert_bash_typehint_completions( + subtests, + shtab_script, + [ + ("cls.p1", int, "", [], "Base, SubA, SubB"), + ], + ) + + +class SupBase: + def __init__(self, s1: Base): + pass + + +class SupA(SupBase): + def __init__(self, s1: Optional[Base]): + pass + + +def test_bash_nested_subclasses(parser, subtests): + parser.add_argument("--cls", type=SupBase) + shtab_script = get_shtab_script(parser, "bash") + assert "'--cls.help' '--cls' '--cls.s1' '--cls.s1.p1' '--cls.s1.p2' '--cls.s1.p3'" in shtab_script + assert_bash_typehint_completions( + subtests, + shtab_script, + [ + ("cls.s1.p2", AXEnum, "X", ["XY", "XZ"], "SubA; 2/3 matched choices"), + ], + ) + + +def test_bash_callable_return_class(parser, subtests): + parser.add_argument("--cls", type=Callable[[int], Base]) + shtab_script = get_shtab_script(parser, "bash") + assert "_option_strings=('-h' '--help' '--cls' '--cls.p2' '--cls.p3')" in shtab_script + assert "--cls.p1" not in shtab_script + classes = f"{__name__}.Base {__name__}.SubA {__name__}.SubB".split() + assert_bash_typehint_completions( + subtests, + shtab_script, + [ + ("cls", Callable[[int], Base], "", classes, "3/3"), + ("cls.p2", AXEnum, "z", [], "SubA"), + ], + ) + + +def test_bash_subcommands(parser, subparser, subtests): + subparser.add_argument("--enum", type=AXEnum) + subparser2 = ArgumentParser() + subparser2.add_argument("--cls", type=Base) + subcommands = parser.add_subcommands() + subcommands.add_subcommand("s1", subparser) + subcommands.add_subcommand("s2", subparser2) + + help_str = get_parse_args_stdout(parser, ["--help"]) + assert "--print_shtab" in help_str + help_str = get_parse_args_stdout(parser, ["s1", "--help"]) + assert "--print_shtab" not in help_str + + shtab_script = get_shtab_script(parser, "bash") + assert "_subparsers=('s1' 's2')" in shtab_script + + assert "_s1_option_strings=('-h' '--help' '--enum')" in shtab_script + assert "_s2_option_strings=('-h' '--help' '--cls.help' '--cls' '--cls.p1' '--cls.p2' '--cls.p3')" in shtab_script + classes = f"'{__name__}.Base' '{__name__}.SubA' '{__name__}.SubB'" + assert f"_s2___cls_help_choices=({classes})" in shtab_script + + assert_bash_typehint_completions( + subtests, + shtab_script, + [ + ("s1.enum", AXEnum, "A", ["ABC"], "1/3"), + ("s2.cls.p1", int, "1", [], "Base, SubA, SubB"), + ], + ) + + +def test_zsh_script(parser): + parser.add_argument("--enum", type=Optional[AXEnum]) + parser.add_argument("--path", type=PathLike) + parser.add_argument("--cls", type=Base) + shtab_script = get_shtab_script(parser, "zsh") + assert ":enum:(ABC XY XZ null)" in shtab_script + assert ":path:_files" in shtab_script + classes = f"{__name__}.Base {__name__}.SubA {__name__}.SubB" + assert f":cls.help:({classes})" in shtab_script + assert f":cls:({classes})" in shtab_script + assert ":cls.p1:" in shtab_script + assert ":cls.p2:(ABC XY XZ)" in shtab_script + assert ":cls.p3:" in shtab_script diff --git a/jsonargparse_tests/test_typehints.py b/jsonargparse_tests/test_typehints.py index be5161fe..2ddf453e 100644 --- a/jsonargparse_tests/test_typehints.py +++ b/jsonargparse_tests/test_typehints.py @@ -263,8 +263,8 @@ def test_tuple_without_arg(parser): parser.add_argument("--tuple", type=tuple) cfg = parser.parse_args(['--tuple=[1, "a", True]']) assert (1, "a", True) == cfg.tuple - help_str = get_parser_help(parser) - assert "--tuple [ITEM,...] (type: tuple, default: null)" in help_str + help_str = get_parser_help(parser, strip=True) + assert "--tuple [ITEM,...] (type: tuple, default: null)" in help_str def test_tuples_nested(parser): @@ -296,8 +296,8 @@ def test_tuple_union(parser, tmp_cwd): pytest.raises(ArgumentError, lambda: parser.parse_args(['--tuple=[2, "a", "b", 5]'])) pytest.raises(ArgumentError, lambda: parser.parse_args(['--tuple=[2, "a"]'])) pytest.raises(ArgumentError, lambda: parser.parse_args(['--tuple={"a":1, "b":"2"}'])) - help_str = get_parser_help(parser) - assert "--tuple [ITEM,...] (type: Tuple[Union[int, EnumABC], Path_fc, NotEmptyStr], default: null)" in help_str + help_str = get_parser_help(parser, strip=True) + assert "--tuple [ITEM,...] (type: Tuple[Union[int, EnumABC], Path_fc, NotEmptyStr], default: null)" in help_str # list tests diff --git a/pyproject.toml b/pyproject.toml index 9a6e4f67..21a2259e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,7 +41,6 @@ all = [ "jsonargparse[jsonnet]", "jsonargparse[urls]", "jsonargparse[fsspec]", - "jsonargparse[argcomplete]", "jsonargparse[ruyaml]", "jsonargparse[omegaconf]", "jsonargparse[typing-extensions]", @@ -65,6 +64,9 @@ urls = [ fsspec = [ "fsspec>=0.8.4", ] +shtab = [ + "shtab>=1.7.1", +] argcomplete = [ "argcomplete>=2.0.0; python_version < '3.8'", "argcomplete>=3.3.0; python_version >= '3.8'", @@ -83,6 +85,8 @@ reconplogger = [ ] test = [ "jsonargparse[test-no-urls]", + "jsonargparse[shtab]", + "jsonargparse[argcomplete]", "types-PyYAML>=6.0.11", "types-requests>=2.28.9", "responses>=0.12.0",