Skip to content

Commit

Permalink
allow for overriding work_dir and artifact_dir via commandline options
Browse files Browse the repository at this point in the history
  • Loading branch information
escapewindow committed Aug 20, 2019
1 parent a120168 commit bf43453
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 59 deletions.
64 changes: 49 additions & 15 deletions scriptworker/client.py
Expand Up @@ -10,6 +10,7 @@
"""
import aiohttp
import argparse
import asyncio
import jsonschema
import logging
Expand Down Expand Up @@ -134,13 +135,13 @@ def callback(match):


def sync_main(async_main, config_path=None, default_config=None,
parser=None, parser_desc=None, commandline_args=None,
should_validate_task=True, loop_function=asyncio.get_event_loop):
"""Entry point for scripts using scriptworker.
This function sets up the basic needs for a script to run. More specifically:
* it creates the scriptworker context and initializes it with the provided config
* the path to the config file is either taken from `config_path` or from `sys.argv[1]`.
* it verifies `sys.argv` doesn't have more arguments than the config path.
* config options are either taken from `commandline_args` or from `sys.argv[1:]`
* it creates the asyncio event loop so that `async_main` can run
Args:
Expand All @@ -155,36 +156,69 @@ def sync_main(async_main, config_path=None, default_config=None,
event loop; here for testing purposes. Defaults to
``asyncio.get_event_loop``.
Raises:
ScriptWorkerException: if the deprecated `config_path` option is used
with `parser`, `parser_desc`, or `commandline_args`.
"""
context = _init_context(config_path, default_config)
if config_path:
if parser or parser_desc or commandline_args:
raise ScriptWorkerException("Deprecated `config_path` is incompatible with `parser`, `parser_desc`, and `commandline_args`!")
log.warning("`sync_main` `config_path` usage is deprecated!")
commandline_args = [config_path]
parser = parser or get_parser(parser_desc)
commandline_args = commandline_args or sys.argv[1:]
parsed_args = parser.parse_args(commandline_args)
context = _init_context(parsed_args, default_config)
_init_logging(context)
if should_validate_task:
validate_task_schema(context)
loop = loop_function()
loop.run_until_complete(_handle_asyncio_loop(async_main, context))


def _init_context(config_path=None, default_config=None):
context = ScriptContext()
def get_parser(desc=None):
"""Create a default *script argparse parser.
Args:
desc (str, optional): the description for the parser.
if config_path is None:
if len(sys.argv) != 2:
_usage()
config_path = sys.argv[1]
Returns:
argparse.Namespace: the parsed args.
"""
parser = argparse.ArgumentParser(description=desc)
parser.add_argument(
"--work-dir",
type=str,
required=False,
help="The path to use as the script working directory",
)
parser.add_argument(
"--artifact-dir",
type=str,
required=False,
help="The path to use as the artifact upload directory",
)
parser.add_argument("config_path", type=str)
return parser


def _init_context(parsed_args, default_config=None):
context = ScriptContext()

context.config = {} if default_config is None else default_config
context.config.update(load_json_or_yaml(config_path, is_path=True))
context.config.update(load_json_or_yaml(parsed_args.config_path, is_path=True))
for var in ("work_dir", "artifact_dir"):
path = getattr(parsed_args, var, None)
if path:
context.config[var] = path

context.task = get_task(context.config)

return context


def _usage():
print('Usage: {} CONFIG_FILE'.format(sys.argv[0]), file=sys.stderr)
sys.exit(1)


def _init_logging(context):
logging.basicConfig(
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
Expand Down
83 changes: 39 additions & 44 deletions scriptworker/test/test_client.py
Expand Up @@ -4,6 +4,7 @@
"""

import aiohttp
import argparse
import asyncio
import json
import logging
Expand All @@ -23,6 +24,7 @@
from scriptworker.constants import DEFAULT_CONFIG
from scriptworker.context import ScriptContext
from scriptworker.exceptions import ScriptWorkerException, ScriptWorkerTaskException, TaskVerificationError
from scriptworker.utils import makedirs

from . import noop_sync

Expand Down Expand Up @@ -246,66 +248,59 @@ def loop_function():
assert len(async_main_calls) == 1 # async_main was called once


@pytest.mark.parametrize('does_use_argv, default_config', (
(True, None),
(True, {'some_param_only_in_default': 'default_value', 'worker_type': 'default_value'}),
(False, None),
(True, {'some_param_only_in_default': 'default_value', 'worker_type': 'default_value'}),
def test_get_parser():
"""`get_parser` stores `config_path` and the optional `work_dir` and
`artifact_dir`, but raises `SystemExit` on unknown options.
"""
parser = client.get_parser()
parsed = parser.parse_args(["foo"])
assert parsed.config_path == "foo"
assert parsed.work_dir is None
assert parsed.artifact_dir is None
parsed = parser.parse_args(
["--work-dir", "work", "--artifact-dir", "artifact", "bar"]
)
assert parsed.config_path == "bar"
assert parsed.work_dir == "work"
assert parsed.artifact_dir == "artifact"
for args in ([], ["--illegal", "foo"]):
with pytest.raises(SystemExit):
parser.parse_args(args)


@pytest.mark.parametrize('default_config, dirs_in_args', (
(None, False),
(None, True),
({'some_param_only_in_default': 'default_value', 'worker_type': 'default_value'}, False),
({'some_param_only_in_default': 'default_value', 'worker_type': 'default_value'}, True),
))
def test_init_context(config, monkeypatch, mocker, does_use_argv, default_config):
copyfile(BASIC_TASK, os.path.join(config['work_dir'], "task.json"))
def test_init_context(config, monkeypatch, mocker, dirs_in_args, default_config):
expected_config = deepcopy(config)
with tempfile.NamedTemporaryFile('w+') as f:
json.dump(config, f)
f.seek(0)

kwargs = {'default_config': default_config}

if does_use_argv:
monkeypatch.setattr(sys, 'argv', ['some_binary_name', f.name])
else:
kwargs['config_path'] = f.name

context = client._init_context(**kwargs)
parsed_args = argparse.Namespace()
parsed_args.config_path = f.name
if dirs_in_args:
expected_config["work_dir"] = "{}2".format(config["work_dir"])
parsed_args.work_dir = "{}2".format(config["work_dir"])
makedirs(expected_config["work_dir"])
expected_config["artifact_dir"] = "{}2".format(config["artifact_dir"])
parsed_args.artifact_dir = "{}2".format(config["artifact_dir"])
copyfile(BASIC_TASK, os.path.join(expected_config['work_dir'], "task.json"))
context = client._init_context(parsed_args, **kwargs)

assert isinstance(context, ScriptContext)
assert context.task['this_is_a_task'] is True

expected_config = deepcopy(config)
if default_config:
expected_config['some_param_only_in_default'] = 'default_value'

assert context.config == expected_config
assert context.config['worker_type'] != 'default_value'

mock_open = mocker.patch('builtins.open')
mock_open.assert_not_called()


def test_fail_init_context(capsys, monkeypatch):
for i in range(1, 10):
if i == 2:
# expected working case
continue

argv = ['argv{}'.format(j) for j in range(i)]
monkeypatch.setattr(sys, 'argv', argv)
with pytest.raises(SystemExit):
context = client._init_context()

# XXX This prevents usage from being printed out when the test is passing. Assertions are
# done in test_usage
capsys.readouterr()


def test_usage(capsys, monkeypatch):
monkeypatch.setattr(sys, 'argv', ['my_binary'])
with pytest.raises(SystemExit):
client._usage()

captured = capsys.readouterr()
assert captured.out == ''
assert captured.err == 'Usage: my_binary CONFIG_FILE\n'


@pytest.mark.parametrize('is_verbose, log_level', (
(True, logging.DEBUG),
Expand Down

0 comments on commit bf43453

Please sign in to comment.