Skip to content
Permalink
Browse files

Add `horovod.run.run` to make horovod notebook friendly (#1307)

Signed-off-by: WeichenXu <weichen.xu@databricks.com>
  • Loading branch information
WeichenXu123 authored and tgaddair committed Oct 24, 2019
1 parent bbf09d7 commit 9fc256d2d54143451683b4876090ec659a1fdc36
@@ -109,10 +109,20 @@ run_all() {
exclude_keras_if_needed="| sed 's/[a-z_]*keras[a-z_.]*//g'"
fi

local exclude_interactiverun="| sed 's/test_interactiverun.py//g'"

# pytests have 4x GPU use cases and require a separate queue
run_test "${test}" "${pytest_queue}" \
":pytest: Run PyTests (${test})" \
"bash -c \"cd /horovod/test && (echo test_*.py ${exclude_keras_if_needed} | xargs -n 1 \\\$(cat /mpirun_command) pytest -v --capture=no)\""
"bash -c \"cd /horovod/test && (echo test_*.py ${exclude_keras_if_needed} ${exclude_interactiverun} | xargs -n 1 \\\$(cat /mpirun_command) pytest -v --capture=no)\""

# Run test_interactiverun.py
if [[ ${test} != *"mpich"* ]]; then
# TODO: support mpich
run_test "${test}" "${queue}" \
":pytest: Run PyTests test_interactiverun (${test})" \
"bash -c \"cd /horovod/test && pytest -v --capture=no test_interactiverun.py\""
fi

# Legacy TensorFlow tests
if [[ ${test} != *"tf2_"* ]]; then
@@ -180,9 +190,11 @@ run_gloo() {
exclude_spark_if_needed="| sed 's/[a-z_]*spark[a-z_.]*//g'"
fi

local exclude_interactiverun="| sed 's/test_interactiverun.py//g'"

run_test "${test}" "${pytest_queue}" \
":pytest: Run PyTests (${test})" \
"bash -c \"cd /horovod/test && (echo test_*.py ${exclude_spark_if_needed} | xargs -n 1 horovodrun -np 2 -H localhost:2 --gloo pytest -v --capture=no)\""
"bash -c \"cd /horovod/test && (echo test_*.py ${exclude_spark_if_needed} ${exclude_interactiverun} | xargs -n 1 horovodrun -np 2 -H localhost:2 --gloo pytest -v --capture=no)\""

run_test "${test}" "${queue}" \
":muscle: Test Keras MNIST (${test})" \
@@ -15,7 +15,7 @@
# limitations under the License.
# ==============================================================================

from horovod.run import run
from horovod.run.run import run_commandline

if __name__ == '__main__':
run.run()
run_commandline()
@@ -2,3 +2,4 @@ git+https://github.com/sphinx-doc/sphinx@2.0
sphinxcontrib-napoleon
alabaster
nbsphinx
pyyaml
@@ -0,0 +1,16 @@
# Copyright 2019 Uber Technologies, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

from .run import run
@@ -14,6 +14,7 @@
# ==============================================================================

import re
import os

LOG_LEVEL_STR = ['FATAL', 'ERROR', 'WARNING', 'INFO', 'DEBUG', 'TRACE']

@@ -23,3 +24,22 @@

def is_exportable(v):
return not any(re.match(r, v) for r in IGNORE_REGEXES)


def get_env_rank_and_size():
rank_env = ['HOROVOD_RANK', 'OMPI_COMM_WORLD_RANK', 'PMI_RANK']
size_env = ['HOROVOD_SIZE', 'OMPI_COMM_WORLD_SIZE', 'PMI_SIZE']

for rank_var, size_var in zip(rank_env, size_env):
rank = os.environ.get(rank_var)
size = os.environ.get(size_var)
if rank is not None and size is not None:
return int(rank), int(size)
elif rank is not None or size is not None:
raise RuntimeError(
'Could not determine process rank and size: only one of {} and {} '
'found in environment'.format(rank_var, size_var))

# Default to rank zero and size one if there are no environment variables
return 0, 1

@@ -18,7 +18,7 @@ class Settings(object):

def __init__(self, verbose=0, ssh_port=None, extra_mpi_args=None, key=None, timeout=None,
num_hosts=None, num_proc=None, hosts=None, output_filename=None,
command=None):
run_func_mode=None):
"""
:param verbose: level of verbosity
:type verbose: int
@@ -39,8 +39,8 @@ def __init__(self, verbose=0, ssh_port=None, extra_mpi_args=None, key=None, time
:type hosts: string
:param output_filename: optional filename to redirect stdout / stderr by process
:try output_filename: string
:param command: number of horovod processes (-np)
:type num_proc: int
:param run_func_mode: whether it is run function mode
:type run_func_mode: boolean
"""
self.verbose = verbose
self.ssh_port = ssh_port
@@ -51,4 +51,5 @@ def __init__(self, verbose=0, ssh_port=None, extra_mpi_args=None, key=None, time
self.num_proc = num_proc
self.hosts = hosts
self.output_filename = output_filename
self.command = command
self.run_func_mode = run_func_mode

@@ -22,16 +22,13 @@
import threading
import time

from psutil import net_if_addrs
from socket import AF_INET

try:
from shlex import quote
except ImportError:
from pipes import quote

from horovod.run.common.util import env as env_util, safe_shell_exec
from horovod.run.rendezvous.http_server import RendezvousServer
from horovod.run.http.http_server import RendezvousServer
from horovod.run.util import threads


@@ -262,7 +259,7 @@ def set_event_on_sigterm(signum, frame):
.format(name=name, code=exit_code))


def gloo_run(settings, remote_host_names, common_intfs, env):
def gloo_run(settings, remote_host_names, common_intfs, env, server_ip, command):
# allocate processes into slots
host_alloc_plan = _allocate(settings.hosts, settings.num_proc)

@@ -271,16 +268,7 @@ def gloo_run(settings, remote_host_names, common_intfs, env):
# Start rendezvous server and get port that it is listening
global_rendezv_port = global_rendezv.start_server(host_alloc_plan)

# get the server IPv4 address
iface = list(common_intfs)[0]
server_ip = None
for addr in net_if_addrs()[iface]:
if addr.family == AF_INET:
server_ip = addr.address

if not server_ip:
raise RuntimeError(
'Cannot find an IPv4 address of the common interface.')

run_command = (
'HOROVOD_GLOO_RENDEZVOUS_ADDR={addr} '
@@ -294,7 +282,7 @@ def gloo_run(settings, remote_host_names, common_intfs, env):
port=global_rendezv_port,
iface=iface, # TODO: add multiple ifaces in future
common_intfs=','.join(common_intfs),
command=' '.join(quote(par) for par in settings.command)))
command=' '.join(quote(par) for par in command)))

_launch_jobs(settings, env, host_alloc_plan, remote_host_names, run_command)
return
File renamed without changes.
@@ -0,0 +1,50 @@
# Copyright 2019 Uber Technologies, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
import sys
import base64
from distutils.version import LooseVersion
if LooseVersion(sys.version) < LooseVersion('3.0.0'):
from urllib2 import urlopen
from urllib2 import Request
from urllib2 import HTTPError, URLError
else:
from urllib.request import urlopen
from urllib.request import Request
from urllib.error import HTTPError, URLError


def read_data_from_kvstore(addr, port, scope, key):
try:
url = "http://{addr}:{port}/{scope}/{key}".format(
addr=addr, port=str(port), scope=scope, key=key
)
req = Request(url)
resp = urlopen(req)
# TODO: remove base64 encoding because base64 is not efficient
return base64.b64decode(resp.read())
except (HTTPError, URLError) as e:
raise RuntimeError("Read data from KVStore server failed.", e)


def put_data_into_kvstore(addr, port, scope, key, value):
try:
url = "http://{addr}:{port}/{scope}/{key}".format(
addr=addr, port=str(port), scope=scope, key=key
)
req = Request(url, data=base64.b64encode(value))
req.get_method = lambda: "PUT" # for urllib2 compatibility
urlopen(req)
except (HTTPError, URLError) as e:
raise RuntimeError("Put data input KVStore server failed.", e)
@@ -30,7 +30,7 @@
OK = 200


class RendezvousHandler(SimpleHTTPServer.SimpleHTTPRequestHandler):
class KVStoreHandler(SimpleHTTPServer.SimpleHTTPRequestHandler):
# Set timeout
timeout = SINGLE_REQUEST_TIMEOUT

@@ -39,7 +39,7 @@ def do_GET(self):
paths = self.path.split('/')
if len(paths) < 3:
print(
'Rendezvous ERROR: Invalid request path: {path}.'.format(
'KVStore ERROR: Invalid request path: {path}.'.format(
path=self.path))
self.send_status_code(BAD_REQUEST)
return
@@ -61,7 +61,7 @@ def do_PUT(self):
paths = self.path.split('/')
if len(paths) < 3:
print(
'Rendezvous ERROR: Invalid request path: {path}.'.format(
'KVStore ERROR: Invalid request path: {path}.'.format(
path=self.path))
self.send_status_code(BAD_REQUEST)
return
@@ -75,7 +75,7 @@ def do_PUT(self):
except socket.timeout:
if self.server.verbose:
print(
'Rendezvous ERROR: Timeout when receiving {content_bytes} '
'KVStore ERROR: Timeout when receiving {content_bytes} '
'bytes, aborting this incomplete request.' .format(
content_bytes=content_length))

@@ -91,6 +91,18 @@ def do_PUT(self):

self.send_status_code(OK)

def send_status_code(self, status_code):
self.send_response(status_code)
self.send_header("Content-Length", 0)
self.end_headers()

# Override this function to prevent SimpleHTTPServer printing every
# request out.
def log_message(self, format, *args):
pass


class RendezvousHandler(KVStoreHandler):
# Override DELETE handler
def do_DELETE(self):
paths = self.path.split('/')
@@ -108,16 +120,6 @@ def do_DELETE(self):

self.send_status_code(OK)

def send_status_code(self, status_code):
self.send_response(status_code)
self.send_header("Content-Length", 0)
self.end_headers()

# Override this function to prevent SimpleHTTPServer printing every
# request out.
def log_message(self, format, *args):
pass


class RendezvousHTTPServer(BaseHTTPServer.HTTPServer, object):
def __init__(self, addr, handler, verbose):
@@ -199,6 +201,52 @@ def listen_loop(self):
while self.httpd.should_continue():
self.httpd.handle_request()

self.httpd.server_close()

if self.verbose:
print('Rendezvous INFO: Rendezvous finishes.')
# Because this thread is daemonized, no need to join.


class KVStoreHTTPServer(BaseHTTPServer.HTTPServer, object):
def __init__(self, addr, handler, verbose):
super(KVStoreHTTPServer, self).__init__(addr, handler)

# Cache that provides the store
self.cache_lock = threading.Lock()
self.cache = {}

self.verbose = verbose


class KVStoreServer:
def __init__(self, verbose):
self.httpd = None
self.listen_thread = None
self.verbose = verbose

# KVStore server finds a available port, create http socket,
# and start listening loop to handle request
def start_server(self):
self.httpd, port = find_port(
lambda addr: KVStoreHTTPServer(
addr, KVStoreHandler, self.verbose))

self.listen_thread = threading.Thread(
target=lambda: self.httpd.serve_forever())
self.listen_thread.daemon = True
self.listen_thread.start()

if self.verbose:
print('KVStoreServer INFO: KVStore server started. Listen on port ' + str(port))

return port

def shutdown_server(self):
self.httpd.shutdown()

self.httpd.server_close()

if self.verbose:
print('KVStoreServer INFO: KVStore server finishes.')
# Because this thread is daemonized, no need to join.

0 comments on commit 9fc256d

Please sign in to comment.
You can’t perform that action at this time.