From bf43453af470b085a97efb6074985f949ab6fbca Mon Sep 17 00:00:00 2001 From: Aki Sasaki Date: Mon, 19 Aug 2019 17:53:09 -0700 Subject: [PATCH] allow for overriding work_dir and artifact_dir via commandline options --- scriptworker/client.py | 64 ++++++++++++++++++------ scriptworker/test/test_client.py | 83 +++++++++++++++----------------- 2 files changed, 88 insertions(+), 59 deletions(-) diff --git a/scriptworker/client.py b/scriptworker/client.py index debf976d..b7a9716d 100644 --- a/scriptworker/client.py +++ b/scriptworker/client.py @@ -10,6 +10,7 @@ """ import aiohttp +import argparse import asyncio import jsonschema import logging @@ -134,13 +135,13 @@ def callback(match): def sync_main(async_main, config_path=None, default_config=None, + parser=None, parser_desc=None, commandline_args=None, should_validate_task=True, loop_function=asyncio.get_event_loop): """Entry point for scripts using scriptworker. This function sets up the basic needs for a script to run. More specifically: * it creates the scriptworker context and initializes it with the provided config - * the path to the config file is either taken from `config_path` or from `sys.argv[1]`. - * it verifies `sys.argv` doesn't have more arguments than the config path. + * config options are either taken from `commandline_args` or from `sys.argv[1:]` * it creates the asyncio event loop so that `async_main` can run Args: @@ -155,8 +156,20 @@ def sync_main(async_main, config_path=None, default_config=None, event loop; here for testing purposes. Defaults to ``asyncio.get_event_loop``. + Raises: + ScriptWorkerException: if the deprecated `config_path` option is used + with `parser`, `parser_desc`, or `commandline_args`. + """ - context = _init_context(config_path, default_config) + if config_path: + if parser or parser_desc or commandline_args: + raise ScriptWorkerException("Deprecated `config_path` is incompatible with `parser`, `parser_desc`, and `commandline_args`!") + log.warning("`sync_main` `config_path` usage is deprecated!") + commandline_args = [config_path] + parser = parser or get_parser(parser_desc) + commandline_args = commandline_args or sys.argv[1:] + parsed_args = parser.parse_args(commandline_args) + context = _init_context(parsed_args, default_config) _init_logging(context) if should_validate_task: validate_task_schema(context) @@ -164,27 +177,48 @@ def sync_main(async_main, config_path=None, default_config=None, loop.run_until_complete(_handle_asyncio_loop(async_main, context)) -def _init_context(config_path=None, default_config=None): - context = ScriptContext() +def get_parser(desc=None): + """Create a default *script argparse parser. + + Args: + desc (str, optional): the description for the parser. - if config_path is None: - if len(sys.argv) != 2: - _usage() - config_path = sys.argv[1] + Returns: + argparse.Namespace: the parsed args. + + """ + parser = argparse.ArgumentParser(description=desc) + parser.add_argument( + "--work-dir", + type=str, + required=False, + help="The path to use as the script working directory", + ) + parser.add_argument( + "--artifact-dir", + type=str, + required=False, + help="The path to use as the artifact upload directory", + ) + parser.add_argument("config_path", type=str) + return parser + + +def _init_context(parsed_args, default_config=None): + context = ScriptContext() context.config = {} if default_config is None else default_config - context.config.update(load_json_or_yaml(config_path, is_path=True)) + context.config.update(load_json_or_yaml(parsed_args.config_path, is_path=True)) + for var in ("work_dir", "artifact_dir"): + path = getattr(parsed_args, var, None) + if path: + context.config[var] = path context.task = get_task(context.config) return context -def _usage(): - print('Usage: {} CONFIG_FILE'.format(sys.argv[0]), file=sys.stderr) - sys.exit(1) - - def _init_logging(context): logging.basicConfig( format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', diff --git a/scriptworker/test/test_client.py b/scriptworker/test/test_client.py index e2fec0bc..4be2056f 100644 --- a/scriptworker/test/test_client.py +++ b/scriptworker/test/test_client.py @@ -4,6 +4,7 @@ """ import aiohttp +import argparse import asyncio import json import logging @@ -23,6 +24,7 @@ from scriptworker.constants import DEFAULT_CONFIG from scriptworker.context import ScriptContext from scriptworker.exceptions import ScriptWorkerException, ScriptWorkerTaskException, TaskVerificationError +from scriptworker.utils import makedirs from . import noop_sync @@ -246,66 +248,59 @@ def loop_function(): assert len(async_main_calls) == 1 # async_main was called once -@pytest.mark.parametrize('does_use_argv, default_config', ( - (True, None), - (True, {'some_param_only_in_default': 'default_value', 'worker_type': 'default_value'}), - (False, None), - (True, {'some_param_only_in_default': 'default_value', 'worker_type': 'default_value'}), +def test_get_parser(): + """`get_parser` stores `config_path` and the optional `work_dir` and + `artifact_dir`, but raises `SystemExit` on unknown options. + """ + parser = client.get_parser() + parsed = parser.parse_args(["foo"]) + assert parsed.config_path == "foo" + assert parsed.work_dir is None + assert parsed.artifact_dir is None + parsed = parser.parse_args( + ["--work-dir", "work", "--artifact-dir", "artifact", "bar"] + ) + assert parsed.config_path == "bar" + assert parsed.work_dir == "work" + assert parsed.artifact_dir == "artifact" + for args in ([], ["--illegal", "foo"]): + with pytest.raises(SystemExit): + parser.parse_args(args) + + +@pytest.mark.parametrize('default_config, dirs_in_args', ( + (None, False), + (None, True), + ({'some_param_only_in_default': 'default_value', 'worker_type': 'default_value'}, False), + ({'some_param_only_in_default': 'default_value', 'worker_type': 'default_value'}, True), )) -def test_init_context(config, monkeypatch, mocker, does_use_argv, default_config): - copyfile(BASIC_TASK, os.path.join(config['work_dir'], "task.json")) +def test_init_context(config, monkeypatch, mocker, dirs_in_args, default_config): + expected_config = deepcopy(config) with tempfile.NamedTemporaryFile('w+') as f: json.dump(config, f) f.seek(0) kwargs = {'default_config': default_config} - - if does_use_argv: - monkeypatch.setattr(sys, 'argv', ['some_binary_name', f.name]) - else: - kwargs['config_path'] = f.name - - context = client._init_context(**kwargs) + parsed_args = argparse.Namespace() + parsed_args.config_path = f.name + if dirs_in_args: + expected_config["work_dir"] = "{}2".format(config["work_dir"]) + parsed_args.work_dir = "{}2".format(config["work_dir"]) + makedirs(expected_config["work_dir"]) + expected_config["artifact_dir"] = "{}2".format(config["artifact_dir"]) + parsed_args.artifact_dir = "{}2".format(config["artifact_dir"]) + copyfile(BASIC_TASK, os.path.join(expected_config['work_dir'], "task.json")) + context = client._init_context(parsed_args, **kwargs) assert isinstance(context, ScriptContext) assert context.task['this_is_a_task'] is True - expected_config = deepcopy(config) if default_config: expected_config['some_param_only_in_default'] = 'default_value' assert context.config == expected_config assert context.config['worker_type'] != 'default_value' - mock_open = mocker.patch('builtins.open') - mock_open.assert_not_called() - - -def test_fail_init_context(capsys, monkeypatch): - for i in range(1, 10): - if i == 2: - # expected working case - continue - - argv = ['argv{}'.format(j) for j in range(i)] - monkeypatch.setattr(sys, 'argv', argv) - with pytest.raises(SystemExit): - context = client._init_context() - - # XXX This prevents usage from being printed out when the test is passing. Assertions are - # done in test_usage - capsys.readouterr() - - -def test_usage(capsys, monkeypatch): - monkeypatch.setattr(sys, 'argv', ['my_binary']) - with pytest.raises(SystemExit): - client._usage() - - captured = capsys.readouterr() - assert captured.out == '' - assert captured.err == 'Usage: my_binary CONFIG_FILE\n' - @pytest.mark.parametrize('is_verbose, log_level', ( (True, logging.DEBUG),