diff --git a/horovod/run/gloo_run.py b/horovod/run/gloo_run.py index 25def0e176..2c19158885 100644 --- a/horovod/run/gloo_run.py +++ b/horovod/run/gloo_run.py @@ -141,10 +141,7 @@ def _exec_command(_command, _index, event_): 'message: {message}'.format(message=e)) return 0 - if settings.ssh_port: - ssh_port_arg = "-p {ssh_port}".format(ssh_port=settings.ssh_port) - else: - ssh_port_arg = "" + ssh_port_arg = '-p {ssh_port}'.format(ssh_port=settings.ssh_port) if settings.ssh_port else '' # Create a event for communication between threads event = threading.Event() @@ -171,7 +168,7 @@ def set_event_on_sigterm(signum, frame): local_command = '{horovod_env} {env} {run_command}' .format( horovod_env=horovod_rendez_env, env=' '.join(['%s=%s' % (key, quote(value)) for key, value in env.items() - if env_util.is_exportable(key)]), + if env_util.is_exportable(key)]), run_command=_run_command) if host_name not in remote_host_names: @@ -181,7 +178,8 @@ def set_event_on_sigterm(signum, frame): '{local_command}'.format( host=host_name, ssh_port_arg=ssh_port_arg, - local_command=quote(local_command) + local_command=quote('cd {pwd} >& /dev/null ; {local_command}' + .format(pwd=os.getcwd(), local_command=local_command)) ) args_list.append([command, alloc_info.rank, event])