Skip to content

Commit

Permalink
Allow forwarding arguments to submitted script (#36)
Browse files Browse the repository at this point in the history
Forward all remaining arguments passed to `dask-yarn submit ...` on to the
submitted script.
  • Loading branch information
jcrist committed Nov 2, 2018
1 parent 404d571 commit ba9e07c
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 17 deletions.
18 changes: 15 additions & 3 deletions dask_yarn/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,13 @@ def _action_max_length(self):
def _action_max_length(self, value):
pass

def _format_args(self, action, default_metavar):
    """Render the usage placeholder for an action's arguments.

    Actions declared with ``nargs=argparse.REMAINDER`` are rendered as
    ``[METAVAR...]``. Every other action is delegated unchanged to the
    parent formatter's ``_format_args``.
    """
    metavar_for = self._metavar_formatter(action, default_metavar)
    if action.nargs != argparse.REMAINDER:
        # Not a remainder action: keep the inherited formatting.
        return super(_Formatter, self)._format_args(action, default_metavar)
    return '[%s...]' % metavar_for(1)


class _VersionAction(argparse.Action):
def __init__(self, option_strings, version=None, dest=argparse.SUPPRESS,
Expand Down Expand Up @@ -130,6 +137,8 @@ def _parse_submit_kwargs(**kwargs):
@subcommand(entry_subs,
'submit', 'Submit a Dask application to a YARN cluster',
arg("script", help="Path to a python script to run on the client"),
arg("args", nargs=argparse.REMAINDER,
help="Any additional arguments to forward to `script`"),
arg("--name", help="The application name"),
arg("--queue", help="The queue to deploy to"),
arg("--tags",
Expand Down Expand Up @@ -307,15 +316,18 @@ def run():

@subcommand(services.subs,
'client', 'Start a Dask client process',
arg("script", help="Path to a Python script to run."))
def client(script):
arg("script", help="Path to a Python script to run."),
arg("args", nargs=argparse.REMAINDER,
help="Any additional arguments to forward to `script`"))
def client(script, args=None):
app = skein.ApplicationClient.from_current()
args = args or []

if not os.path.exists(script):
raise ValueError("%r doesn't exist" % script)

try:
subprocess.check_call([sys.executable, script])
subprocess.check_call([sys.executable, script] + args)
succeeded = True
retcode = 0
except subprocess.CalledProcessError as exc:
Expand Down
32 changes: 18 additions & 14 deletions dask_yarn/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def _make_specification(**kwargs):
return spec


def _make_submit_specification(script, **kwargs):
def _make_submit_specification(script, args=(), **kwargs):
client_vcores = lookup(kwargs, 'client_vcores', 'yarn.client.vcores')
client_memory = lookup(kwargs, 'client_memory', 'yarn.client.memory')
client_env = lookup(kwargs, 'client_env', 'yarn.client.env')
Expand All @@ -120,19 +120,23 @@ def _make_submit_specification(script, **kwargs):
environment = spec.services['dask.worker'].files['environment']

script_name = os.path.basename(script)
client = skein.Service(instances=1,
resources=skein.Resources(
vcores=client_vcores,
memory=client_memory
),
max_restarts=0,
depends=['dask.scheduler'],
files={'environment': environment,
script_name: script},
env=client_env,
commands=['source environment/bin/activate',
'dask-yarn services client %s' % script_name])
spec.services['dask.client'] = client

# The command strings below are run through a shell (note the ``source``
# line), so the script name and every forwarded argument must be quoted.
# Otherwise arguments containing spaces, globs, ``$VAR``, ``;`` etc. would
# be word-split or interpreted instead of reaching the script verbatim.
try:
    from shlex import quote as _quote
except ImportError:  # Python 2 fallback
    from pipes import quote as _quote

# Build the client command from individually quoted parts; with no extra
# arguments this is just ``dask-yarn services client <script>``.
client_command = 'dask-yarn services client ' + ' '.join(
    _quote(part) for part in [script_name] + list(args)
)

spec.services['dask.client'] = skein.Service(
    instances=1,
    resources=skein.Resources(
        vcores=client_vcores,
        memory=client_memory
    ),
    # The client script is run once; it is never restarted on failure.
    max_restarts=0,
    depends=['dask.scheduler'],
    # Ship the packaged environment and the user script with the service.
    files={'environment': environment,
           script_name: script},
    env=client_env,
    commands=[
        'source environment/bin/activate',
        client_command
    ]
)
return spec


Expand Down
34 changes: 34 additions & 0 deletions dask_yarn/tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,40 @@ def test_cli_submit_and_status(script_kind, final_status, searchtxt,
assert searchtxt in logs


def test_cli_submit_with_args(tmpdir, conda_env, skein_client, capfd):
    """Check that arguments after the script path are forwarded to the script.

    A small script asserting ``sys.argv[1:] == ["a", "b", "c"]`` is submitted
    with ``a b c`` appended to the command line. The test then waits for the
    application to complete and looks for the script's output in the logs.
    """
    script_lines = [
        'import sys',
        'args = sys.argv[1:]',
        'assert args == ["a", "b", "c"]',
        'print("Done!")',
    ]

    script_path = str(tmpdir.join('script.py'))
    with open(script_path, 'w') as handle:
        handle.write('\n'.join(script_lines))

    # NOTE(review): the application name says "submit-and-status" rather than
    # matching this test's name -- possibly copied; confirm it is intentional.
    command = ' '.join([
        'submit',
        '--name test-cli-submit-and-status',
        '--environment %s' % conda_env,
        '--worker-count 0',
        '--scheduler-memory 256MiB',
        '--scheduler-vcores 1',
        '--client-memory 128MiB',
        '--client-vcores 1',
        '%s a b c' % script_path,
    ])
    run_command(command)

    captured_out, captured_err = capfd.readouterr()
    # Logs go to err
    assert 'INFO' in captured_err
    # stdout is expected to hold a single line: the application id.
    app_id = captured_out.strip()
    assert '\n' not in app_id

    with ensure_shutdown(skein_client, app_id, status='SUCCEEDED'):
        # Wait for app to start
        skein_client.connect(app_id)
        wait_for_completion(skein_client, app_id, timeout=60)

    # The script prints "Done!" only after its argv assertion has passed.
    logs = get_logs(app_id)
    assert 'Done!' in logs


def test_cli_kill(tmpdir, conda_env, skein_client, capfd):
script_path = str(tmpdir.join('script.py'))
with open(script_path, 'w') as fil:
Expand Down

0 comments on commit ba9e07c

Please sign in to comment.