Skip to content

Commit

Permalink
init pr
Browse files Browse the repository at this point in the history
Signed-off-by: WeichenXu <weichen.xu@databricks.com>
  • Loading branch information
WeichenXu123 committed Aug 22, 2019
1 parent 1df0938 commit 67947d0
Show file tree
Hide file tree
Showing 9 changed files with 249 additions and 56 deletions.
30 changes: 30 additions & 0 deletions bin/horovod_launch_exec_func
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
#!/usr/bin/env python

# Copyright 2019 Uber Technologies, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import cloudpickle
import sys
from horovod.run.rendezvous.http_client import read_data_from_kvstore, put_data_into_kvstore

if __name__ == '__main__':
    # Expected invocation: <script> <driver_addr> <run_func_server_port>.
    # Unpacking raises ValueError if the argument count is wrong.
    _, driver_addr, run_func_server_port_str = sys.argv
    run_func_server_port = int(run_func_server_port_str)

    # Fetch the pickled callable stored under scope 'runfunc', key 'func'.
    pickled_func = read_data_from_kvstore(driver_addr, run_func_server_port, 'runfunc', 'func')

    # NOTE(security): unpickling executes arbitrary code. The payload is
    # whatever the server at driver_addr returns, so that endpoint must be
    # trusted.
    func = cloudpickle.loads(pickled_func)
    ret_val = func()

    # Publish every result except None. A plain truthiness test would
    # silently drop legitimate falsy results such as 0, False, '' or [].
    if ret_val is not None:
        pickled_ret_val = cloudpickle.dumps(ret_val)
        put_data_into_kvstore(driver_addr, run_func_server_port,
                              'runfunc', 'result', pickled_ret_val)
2 changes: 1 addition & 1 deletion examples/tensorflow_mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def main(_):
hvd.BroadcastGlobalVariablesHook(0),

# Horovod: adjust number of steps based on number of GPUs.
tf.train.StopAtStepHook(last_step=20000 // hvd.size()),
tf.train.StopAtStepHook(last_step=200 // hvd.size()),

tf.train.LoggingTensorHook(tensors={'step': global_step, 'loss': loss},
every_n_iter=10),
Expand Down
2 changes: 2 additions & 0 deletions horovod/run/common/util/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,5 @@ def __init__(self, verbose=0, ssh_port=None, key=None, timeout=None,
self.num_proc = num_proc
self.hosts = hosts
self.command = command
self.run_func_mode = False

14 changes: 2 additions & 12 deletions horovod/run/gloo_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ def set_event_on_sigterm(signum, frame):
block_until_all_done=True)


def gloo_run(settings, remote_host_names, common_intfs):
def gloo_run(settings, remote_host_names, common_intfs, driver_ip):
# allocate processes into slots
host_alloc_plan = _allocate(settings.hosts, settings.num_proc)

Expand All @@ -205,17 +205,7 @@ def gloo_run(settings, remote_host_names, common_intfs):
# Start rendezvous server and get port that it is listening
global_rendezv_port = global_rendezv.start_server(host_alloc_plan)

# get the server IPv4 address
iface = list(common_intfs)[0]
server_ip = None
for addr in net_if_addrs()[iface]:
if addr.family == AF_INET:
server_ip = addr.address

if not server_ip:
raise RuntimeError(
'Cannot find an IPv4 address of the common interface.')

run_command = (
'HOROVOD_GLOO_RENDEZVOUS_ADDR={addr} '
'HOROVOD_GLOO_RENDEZVOUS_PORT={port} '
Expand All @@ -224,7 +214,7 @@ def gloo_run(settings, remote_host_names, common_intfs):
'HOROVOD_IFACE={iface} '
'NCCL_SOCKET_IFNAME={common_intfs} '
'{command}' # expect a lot of environment variables
.format(addr=server_ip,
.format(addr=driver_ip,
port=global_rendezv_port,
iface=iface, # TODO: add multiple ifaces in future
common_intfs=','.join(common_intfs),
Expand Down
6 changes: 5 additions & 1 deletion horovod/run/mpi_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,4 +101,8 @@ def mpi_run(settings, common_intfs):
if settings.verbose >= 2:
print(mpirun_command)
# Execute the mpirun command.
os.execve('/bin/sh', ['/bin/sh', '-c', mpirun_command], env)
if settings.run_func_mode:
safe_shell_exec.execute(['/bin/sh', '-c', mpirun_command], env=env)
else:
os.execve('/bin/sh', ['/bin/sh', '-c', mpirun_command], env)

48 changes: 48 additions & 0 deletions horovod/run/rendezvous/http_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# Copyright 2019 Uber Technologies, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
import sys
import base64
if sys.version < '3':
from urllib2 import urlopen
from urllib2 import Request
from urllib2 import HTTPError, URLError
else:
from urllib.request import urlopen
from urllib.request import Request
from urllib.error import HTTPError, URLError


def read_data_from_kvstore(addr, port, scope, key):
    """Fetch one value from the KVStore HTTP server and base64-decode it.

    Issues a GET request to ``http://{addr}:{port}/{scope}/{key}``.

    Args:
        addr: Host name or IP address of the server.
        port: Port of the server; it is formatted with ``str()``.
        scope: First path component of the request URL.
        key: Second path component of the request URL.

    Returns:
        The response body after base64 decoding.

    Raises:
        RuntimeError: If the request fails with ``HTTPError`` or
            ``URLError``. The original exception is the second element
            of the raised error's ``args``.
    """
    url = "http://{addr}:{port}/{scope}/{key}".format(
        addr=addr, port=str(port), scope=scope, key=key
    )
    try:
        resp = urlopen(Request(url))
        try:
            return base64.b64decode(resp.read())
        finally:
            # Release the connection promptly instead of waiting for the
            # response object to be garbage collected.
            resp.close()
    except (HTTPError, URLError) as e:
        raise RuntimeError("Read data from KVStore server failed.", e)


def put_data_into_kvstore(addr, port, scope, key, value):
    """Store one value on the KVStore HTTP server.

    Issues a PUT request to ``http://{addr}:{port}/{scope}/{key}`` whose
    body is ``value`` encoded as base64.

    Args:
        addr: Host name or IP address of the server.
        port: Port of the server; it is formatted with ``str()``.
        scope: First path component of the request URL.
        key: Second path component of the request URL.
        value: Bytes to store; passed to ``base64.b64encode``.

    Raises:
        RuntimeError: If the request fails with ``HTTPError`` or
            ``URLError``. The original exception is the second element
            of the raised error's ``args``.
    """
    url = "http://{addr}:{port}/{scope}/{key}".format(
        addr=addr, port=str(port), scope=scope, key=key
    )
    try:
        req = Request(url, data=base64.b64encode(value))
        req.get_method = lambda: "PUT"  # for urllib2 compatibility
        resp = urlopen(req)
        # The reply body is not needed; close it so the socket is released
        # promptly rather than whenever the object is garbage collected.
        resp.close()
    except (HTTPError, URLError) as e:
        raise RuntimeError("Put data into KVStore server failed.", e)
78 changes: 64 additions & 14 deletions horovod/run/rendezvous/http_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
OK = 200


class RendezvousHandler(SimpleHTTPServer.SimpleHTTPRequestHandler):
class KVStoreHandler(SimpleHTTPServer.SimpleHTTPRequestHandler):
# Set timeout
timeout = SINGLE_REQUEST_TIMEOUT

Expand All @@ -39,7 +39,7 @@ def do_GET(self):
paths = self.path.split('/')
if len(paths) < 3:
print(
'Rendezvous ERROR: Invalid request path: {path}.'.format(
'KVStore ERROR: Invalid request path: {path}.'.format(
path=self.path))
self.send_status_code(BAD_REQUEST)
return
Expand All @@ -61,7 +61,7 @@ def do_PUT(self):
paths = self.path.split('/')
if len(paths) < 3:
print(
'Rendezvous ERROR: Invalid request path: {path}.'.format(
'KVStore ERROR: Invalid request path: {path}.'.format(
path=self.path))
self.send_status_code(BAD_REQUEST)
return
Expand All @@ -75,7 +75,7 @@ def do_PUT(self):
except socket.timeout:
if self.server.verbose:
print(
'Rendezvous ERROR: Timeout when receiving {content_bytes} '
'KVStore ERROR: Timeout when receiving {content_bytes} '
'bytes, aborting this incomplete request.' .format(
content_bytes=content_length))

Expand All @@ -91,6 +91,18 @@ def do_PUT(self):

self.send_status_code(OK)

# Reply with a header-only response that carries the given HTTP status code.
def send_status_code(self, status_code):
self.send_response(status_code)
# A Content-Length of 0 announces that no response body follows.
self.send_header("Content-Length", 0)
self.end_headers()

# Override this function to prevent SimpleHTTPServer from printing every
# request out. The arguments are accepted and deliberately ignored, so
# nothing is logged for any request.
def log_message(self, format, *args):
pass


class RendezvousHandler(KVStoreHandler):
# Override DELETE handler
def do_DELETE(self):
paths = self.path.split('/')
Expand All @@ -108,16 +120,6 @@ def do_DELETE(self):

self.send_status_code(OK)

def send_status_code(self, status_code):
self.send_response(status_code)
self.send_header("Content-Length", 0)
self.end_headers()

# Override this function to prevent SimpleHTTPServer printing every
# request out.
def log_message(self, format, *args):
pass


class RendezvousHTTPServer(BaseHTTPServer.HTTPServer, object):
def __init__(self, addr, handler, verbose):
Expand Down Expand Up @@ -199,6 +201,54 @@ def listen_loop(self):
while self.httpd.should_continue():
self.httpd.handle_request()

self.httpd.server_close()

if self.verbose:
print('Rendezvous INFO: Rendezvous finishes.')
# Because this thread is daemonized, no need to join.


class KVStoreHTTPServer(BaseHTTPServer.HTTPServer, object):
    """HTTP server that carries the key-value store state on the instance."""

    def __init__(self, addr, handler, verbose):
        """Create the server and attach an empty store to it.

        Args:
            addr: Server address, forwarded unchanged to the base class.
            handler: Request handler class, forwarded unchanged to the
                base class.
            verbose: Verbosity setting, kept on the instance as
                ``self.verbose``.
        """
        super(KVStoreHTTPServer, self).__init__(addr, handler)

        self.verbose = verbose

        # The store itself starts out empty. The lock is created next to it;
        # presumably request handlers hold it while touching the cache
        # (the handler bodies are not visible here, so confirm there).
        self.cache = {}
        self.cache_lock = threading.Lock()


class KVStoreServer(object):
    """Runs a KVStoreHTTPServer on a background daemon thread.

    Inherits from ``object`` explicitly so that it is a new-style class on
    Python 2 as well, like the other server classes in this module.
    """

    def __init__(self, verbose):
        """Record the verbosity; no server exists until start_server()."""
        self.httpd = None
        self.listen_thread = None
        self.verbose = verbose

    def start_server(self):
        """Create the HTTP server and serve requests in the background.

        The server is obtained from ``find_port``, which is handed a factory
        that builds a ``KVStoreHTTPServer`` for a candidate address and which
        returns the server together with its port.

        Returns:
            The port reported by ``find_port``.
        """
        self.httpd, port = find_port(
            lambda addr: KVStoreHTTPServer(
                addr, KVStoreHandler, self.verbose))

        # The thread is daemonized so it never keeps the process alive.
        self.listen_thread = threading.Thread(
            target=self.httpd.serve_forever)
        self.listen_thread.daemon = True
        self.listen_thread.start()

        if self.verbose:
            print('KVStoreServer INFO: KVStore server started. Listen on port ' + str(port))

        return port

    def shutdown_server(self):
        """Stop the serving loop and close the listening socket.

        Does nothing if the server was never started.
        """
        if self.httpd is None:
            # start_server() was never called, so there is nothing to stop.
            return

        self.httpd.shutdown()

        self.httpd.server_close()

        if self.verbose:
            print('KVStoreServer INFO: KVStore server finishes.')
        # Because the listen thread is daemonized, no need to join.

Loading

0 comments on commit 67947d0

Please sign in to comment.