Skip to content
Permalink
Browse files

Added horovodrun. (#869)

* Added horovodrun

Signed-off-by: fardin <fardin@uber.com>
  • Loading branch information...
abditag2 committed Mar 18, 2019
1 parent 3598452 commit f3b118629ff2770cb1485209e8fa46daf9eb6c9a
@@ -124,6 +124,13 @@ script:
# run TensorFlow MNIST example
- docker exec ${CONTAINER} /bin/sh -c "${MPIRUN} python /horovod/examples/tensorflow_mnist.py"

# run TensorFlow MNIST example with horovodrun. For now, horovdrun does not
# support MPICH.
- |
if [[ ${MPI} != "MPICH" ]]; then
docker exec ${CONTAINER} /bin/sh -c "horovodrun -np 2 -H localhost:2 python /horovod/examples/tensorflow_mnist.py"
fi
# hack TensorFlow Eager MNIST example to be smaller
- docker exec ${CONTAINER} /bin/sh -c "sed -i \"s/dataset.take(20000/dataset.take(100/\" /horovod/examples/tensorflow_mnist_eager.py"

@@ -0,0 +1,21 @@
#!/usr/bin/env python

# Copyright 2019 Uber Technologies, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from horovod.run import run

if __name__ == '__main__':
run.run()
File renamed without changes.
No changes.
@@ -0,0 +1,152 @@
# Copyright 2019 Uber Technologies, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import threading

from horovod.run.common.util import network


class RegisterTaskToTaskAddressesRequest(object):
def __init__(self, index, task_addresses):
self.index = index
"""Task index."""

self.task_addresses = task_addresses
"""Map of interface to list of (ip, port) pairs."""


class AllTaskAddressesRequest(object):
"""Request all task addresses for a given index."""

def __init__(self, index):
self.index = index


class AllTaskAddressesResponse(object):
def __init__(self, all_task_addresses):
self.all_task_addresses = all_task_addresses
"""Map of interface to list of (ip, port) pairs."""


class BasicDriverService(network.BasicService):
def __init__(self, num_proc, name, key):
super(BasicDriverService, self).__init__(name, key)
self._num_proc = num_proc
self._all_task_addresses = {}
self._task_addresses_for_driver = {}
self._task_addresses_for_tasks = {}
self._task_host_hash_indices = {}
self._wait_cond = threading.Condition()

def _handle(self, req, client_address):
if isinstance(req, RegisterTaskRequest):
self._wait_cond.acquire()
try:
assert 0 <= req.index < self._num_proc
self._all_task_addresses[req.index] = req.task_addresses
# Just use source address for service for fast probing.
self._task_addresses_for_driver[req.index] = \
self._filter_by_ip(req.task_addresses, client_address[0])
# Make host hash -> indices map.
if req.host_hash not in self._task_host_hash_indices:
self._task_host_hash_indices[req.host_hash] = []
self._task_host_hash_indices[req.host_hash].append(req.index)
self._task_host_hash_indices[req.host_hash].sort()
finally:
self._wait_cond.notify_all()
self._wait_cond.release()
return network.AckResponse()

if isinstance(req, RegisterTaskToTaskAddressesRequest):
self._wait_cond.acquire()
try:
assert 0 <= req.index < self._num_proc
self._task_addresses_for_tasks[req.index] = req.task_addresses
finally:
self._wait_cond.notify_all()
self._wait_cond.release()
return network.AckResponse()

if isinstance(req, AllTaskAddressesRequest):
return AllTaskAddressesResponse(self._all_task_addresses[req.index])

return super(BasicDriverService, self)._handle(req, client_address)

def _filter_by_ip(self, addresses, target_ip):
for intf, intf_addresses in addresses.items():
for ip, port in intf_addresses:
if ip == target_ip:
return {intf: [(ip, port)]}

def task_addresses_for_driver(self, index):
return self._task_addresses_for_driver[index]

def task_addresses_for_tasks(self, index):
return self._task_addresses_for_tasks[index]

def task_host_hash_indices(self):
return self._task_host_hash_indices

def wait_for_initial_registration(self, timeout):
self._wait_cond.acquire()
try:
while len(self._all_task_addresses) < self._num_proc:
self._wait_cond.wait(timeout.remaining())
timeout.check_time_out_for('tasks to start')
finally:
self._wait_cond.release()

def wait_for_task_to_task_address_updates(self, timeout):
self._wait_cond.acquire()
try:
while len(self._task_addresses_for_tasks) < self._num_proc:
self._wait_cond.wait(timeout.remaining())
timeout.check_time_out_for(
'tasks to update task-to-task addresses')
finally:
self._wait_cond.release()


class RegisterTaskRequest(object):
def __init__(self, index, task_addresses, host_hash):
self.index = index
"""Task index."""

self.task_addresses = task_addresses
"""Map of interface to list of (ip, port) pairs."""

self.host_hash = host_hash
"""
Hash of the host that helps to determine which tasks
have shared memory access to each other.
"""


class BasicDriverClient(network.BasicClient):
def __init__(self, name, driver_addresses, key, match_intf=False):
super(BasicDriverClient, self).__init__(name,
driver_addresses,
key,
match_intf=match_intf)

def register_task(self, index, task_addresses, host_hash):
self._send(RegisterTaskRequest(index, task_addresses, host_hash))

def all_task_addresses(self, index):
resp = self._send(AllTaskAddressesRequest(index))
return resp.all_task_addresses

def register_task_to_task_addresses(self, index, task_addresses):
self._send(RegisterTaskToTaskAddressesRequest(index, task_addresses))
@@ -0,0 +1,155 @@
# Copyright 2019 Uber Technologies, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import threading
import time

from horovod.run.common.util import network
from horovod.run.common.util import safe_shell_exec


class RunCommandRequest(object):
def __init__(self, command, env):
self.command = command
"""Command to run."""
self.env = env
"""Environment to use."""


class CommandTerminatedRequest(object):
"""Is command execution finished?"""
pass


class CommandTerminatedResponse(object):
def __init__(self, flag):
self.flag = flag
"""Yes/no"""


class NotifyInitialRegistrationCompleteRequest(object):
"""Notification that initial task registration has completed."""
pass


class RegisterCodeResultRequest(object):
"""Register code execution results with task."""

def __init__(self, result):
self.result = result


class BasicTaskService(network.BasicService):
def __init__(self, name, key):
super(BasicTaskService, self).__init__(name, key)
self._initial_registration_complete = False
self._wait_cond = threading.Condition()
self._command_thread = None
self._fn_result = None

def _handle(self, req, client_address):
if isinstance(req, RunCommandRequest):
self._wait_cond.acquire()
try:
if self._command_thread is None:
# We only permit executing exactly one command, so this is idempotent.
self._command_thread = threading.Thread(
target=safe_shell_exec.execute,
args=(req.command, req.env))
self._command_thread.daemon = True
self._command_thread.start()
finally:
self._wait_cond.notify_all()
self._wait_cond.release()
return network.AckResponse()

if isinstance(req, NotifyInitialRegistrationCompleteRequest):
self._wait_cond.acquire()
try:
self._initial_registration_complete = True
finally:
self._wait_cond.notify_all()
self._wait_cond.release()
return network.AckResponse()

if isinstance(req, CommandTerminatedRequest):
self._wait_cond.acquire()
try:
terminated = (self._command_thread is not None and
not self._command_thread.is_alive())
finally:
self._wait_cond.release()
return CommandTerminatedResponse(terminated)

if isinstance(req, RegisterCodeResultRequest):
self._fn_result = req.result
return network.AckResponse()

return super(BasicTaskService, self)._handle(req, client_address)

def fn_result(self):
return self._fn_result

def wait_for_initial_registration(self, timeout):
self._wait_cond.acquire()
try:
while not self._initial_registration_complete:
self._wait_cond.wait(timeout.remaining())
timeout.check_time_out_for('Spark tasks to start')
finally:
self._wait_cond.release()

def wait_for_command_start(self, timeout):
self._wait_cond.acquire()
try:
while self._command_thread is None:
self._wait_cond.wait(timeout.remaining())
timeout.check_time_out_for('command to run')
finally:
self._wait_cond.release()

def wait_for_command_termination(self):
self._command_thread.join()


class BasicTaskClient(network.BasicClient):
def __init__(self, service_name, task_addresses, key, match_intf=False,
retries=3):
super(BasicTaskClient, self).__init__(service_name,
task_addresses, key,
match_intf=match_intf,
retries=retries)

def run_command(self, command, env):
self._send(RunCommandRequest(command, env))

def notify_initial_registration_complete(self):
self._send(NotifyInitialRegistrationCompleteRequest())

def command_terminated(self):
resp = self._send(CommandTerminatedRequest())
return resp.flag

def register_code_result(self, result):
self._send(RegisterCodeResultRequest(result))

def wait_for_command_termination(self, delay=1):
try:
while True:
if self.command_terminated():
break
time.sleep(delay)
except:
pass
@@ -1,4 +1,4 @@
# Copyright 2018 Uber Technologies, Inc. All Rights Reserved.
# Copyright 2019 Uber Technologies, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -1,4 +1,4 @@
# Copyright 2018 Uber Technologies, Inc. All Rights Reserved.
# Copyright 2019 Uber Technologies, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -17,7 +17,6 @@
import os
import socket


NAMESPACE_PATH = '/proc/self/ns'


Oops, something went wrong.

0 comments on commit f3b1186

Please sign in to comment.
You can’t perform that action at this time.