Skip to content

Commit

Permalink
Do not call _get_host_assignments when we have insufficient slots
Browse files Browse the repository at this point in the history
Signed-off-by: Travis Addair <taddair@uber.com>
  • Loading branch information
tgaddair committed May 13, 2020
1 parent 4506a1d commit f4d7c0b
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 29 deletions.
24 changes: 18 additions & 6 deletions horovod/run/elastic/driver.py
Expand Up @@ -105,6 +105,9 @@ def register_worker_server(self, host, slot, addresses, secret_key):
self._worker_clients[(host, slot)] = WorkerNotificationClient(
addresses, secret_key, self._verbose)

def get_worker_client(self, slot_info):
    """Look up the notification client registered for the worker in ``slot_info``.

    Returns the client keyed by (hostname, local_rank), or None when no
    worker on that slot has registered a server yet.
    """
    key = (slot_info.hostname, slot_info.local_rank)
    return self._worker_clients.get(key)

def record_ready(self, host, slot):
    """Mark the worker occupying ``slot`` on ``host`` as ready in the registry."""
    registry = self._worker_registry
    registry.record_ready(host, slot)

Expand All @@ -118,11 +121,18 @@ def get_slot_info(self, host, slot):
return self._host_assignments[host][slot] if self.has_rank_assignment(host, slot) \
else hosts.INVALID_SLOT_INFO

def get_coordinator_info(self):
    """Return the slot assignment for global rank 0 (the coordinator).

    Returns None when no rank assignments have been made yet.
    """
    coordinator_rank = 0
    return self._rank_assignments.get(coordinator_rank)

def has_rank_assignment(self, host, slot):
    """Return True when ``slot`` on ``host`` has a rank assigned to it.

    Blacklisted hosts never have valid assignments, regardless of the
    current assignment table.
    """
    if self._host_manager.is_blacklisted(host):
        return False
    assigned_slots = self._host_assignments.get(host)
    return assigned_slots is not None and slot < len(assigned_slots)

@property
def host_assignments(self):
    """Mapping of hostname to the list of slots assigned on that host.

    Indexed as ``host_assignments[host][slot]`` (see ``get_slot_info``).
    Returned by reference, not copied — callers must not mutate it.
    """
    return self._host_assignments

def wait_for_available_slots(self, min_np, min_hosts=1):
extra_message = ' An elastic job also requires that at least two hosts ' \
'are available to resolve compatible network interfaces. If you know which interfaces ' \
Expand Down Expand Up @@ -176,20 +186,22 @@ def _discover_hosts(self):
self._shutdown.wait(DISCOVER_HOSTS_FREQUENCY_SECS)

def _notify_workers_host_changes(self, current_hosts):
# Assignments are required to be stable via contract
next_host_assignments, _ = self._get_host_assignments(current_hosts)
if next_host_assignments == self._host_assignments:
next_host_assignments = {}
if current_hosts.count_available_slots() >= self._min_np:
# Assignments are required to be stable via contract
next_host_assignments, _ = self._get_host_assignments(current_hosts)

if next_host_assignments == self.host_assignments:
# Skip notifying workers when host changes would not result in changes of host assignments
logging.debug('no host assignment changes, skipping notifications')
return

coordinator_slot_info = self._rank_assignments.get(0)
coordinator_slot_info = self.get_coordinator_info()
if not coordinator_slot_info:
logging.debug('no coordinator info, skipping notifications')
return

coordinator_client = self._worker_clients.get((coordinator_slot_info.hostname,
coordinator_slot_info.local_rank))
coordinator_client = self.get_worker_client(coordinator_slot_info)
if not coordinator_client:
logging.debug('no coordinator client, skipping notifications')
return
Expand Down
88 changes: 65 additions & 23 deletions test/test_elastic_driver.py
Expand Up @@ -25,7 +25,7 @@
import pytest

from horovod.run.util import network
from horovod.run.elastic.discovery import FixedHosts, HostDiscovery, HostManager
from horovod.run.elastic.discovery import FixedHosts, HostManager
from horovod.run.elastic.driver import ElasticDriver
from horovod.run.elastic.rendezvous import create_rendezvous_handler
from horovod.run.elastic.worker import WorkerNotificationManager
Expand All @@ -40,17 +40,11 @@ def wait_for_one(events):
time.sleep(0.01)


class HostDiscoverySequence(HostDiscovery):
    """Test discovery that replays a fixed sequence of host/slot mappings.

    Each call to ``find_available_hosts_and_slots`` returns the next element
    of ``sequence``; once exhausted, the final element is returned forever.
    """

    def __init__(self, sequence):
        self.iterator = iter(sequence)
        self.end = sequence[-1]
        super(HostDiscoverySequence, self).__init__()

    def find_available_hosts_and_slots(self):
        try:
            # BUG FIX: the original called ``self.iterator.next()``, which is
            # Python 2 only — under Python 3 it raises AttributeError on every
            # call, and the bare ``except:`` silently swallowed it, so the
            # sequence was never consumed and ``self.end`` was always returned.
            return next(self.iterator)
        except StopIteration:
            # Sequence exhausted: keep reporting the last known state.
            return self.end
def sequence(lst):
    """Yield each element of ``lst`` in order, then repeat the last one forever."""
    yield from lst
    last = lst[-1]
    while True:
        yield last


class ElasticDriverTests(unittest.TestCase):
Expand Down Expand Up @@ -173,30 +167,41 @@ def exec_command(slot_info, events):
assert updated_slot_info.cross_rank == slot_info.cross_rank, rank

@mock.patch('horovod.run.elastic.driver.DISCOVER_HOSTS_FREQUENCY_SECS', 0.01)
def test_wait_for_available_slots(self):
@mock.patch('horovod.run.elastic.driver.ElasticDriver.get_coordinator_info')
@mock.patch('horovod.run.elastic.driver.ElasticDriver.get_worker_client')
def test_wait_for_available_slots(self, mock_get_worker_client, mock_get_coordinator_info):
"""Tests that driver blocks until the min number of slots are available."""
slots = [{'host-1': 4},
{'host-1': 4, 'host-2': 8},
{'host-1': 4, 'host-2': 8, 'host-3': 4}]
discovery = HostDiscoverySequence(slots)
mock_discovery = mock.Mock()
mock_discovery.find_available_hosts_and_slots.side_effect = sequence(slots)

driver = ElasticDriver(mock.Mock(), discovery, min_np=2, max_np=12)
driver.wait_for_available_slots(min_np=10)
assert driver._host_manager.current_hosts.count_available_slots() >= 10
driver = ElasticDriver(mock.Mock(), mock_discovery, min_np=8, max_np=20)
driver.wait_for_available_slots(min_np=16)
assert driver._host_manager.current_hosts.count_available_slots() >= 16
driver.stop()

# Notify coordinator 2 times, as the first time we are below min_np and the existing host assignments
# are empty
assert mock_get_worker_client.call_count == 2
assert mock_get_coordinator_info.call_count == 2

@mock.patch('horovod.run.elastic.driver.DISCOVER_HOSTS_FREQUENCY_SECS', 0.01)
def test_wait_for_min_hosts(self):
"""Tests that driver blocks until the min number of hosts and slots are available."""
slots = [{'host-1': 4},
{'host-1': 4, 'host-2': 8},
{'host-1': 4, 'host-2': 8, 'host-3': 4}]
discovery = HostDiscoverySequence(slots)
mock_discovery = mock.Mock()
mock_discovery.find_available_hosts_and_slots.side_effect = sequence(slots)

driver = ElasticDriver(mock.Mock(), discovery, min_np=2, max_np=12)
driver = ElasticDriver(mock.Mock(), mock_discovery, min_np=2, max_np=12)
driver.wait_for_available_slots(min_np=2, min_hosts=2)

# Even though we only needed 2 slots, because we also needed 2 hosts, we will have at least 12 slots total
assert driver._host_manager.current_hosts.count_available_slots() >= 12
driver.stop()

def test_all_workers_fail(self):
"""Tests that training fails when all workers fail."""
Expand Down Expand Up @@ -361,6 +366,30 @@ def exec_command(slot_info, events):

rendezvous.stop_server()

@mock.patch('horovod.run.elastic.driver.DISCOVER_HOSTS_FREQUENCY_SECS', 0.01)
@mock.patch('horovod.run.elastic.driver.ElasticDriver.host_assignments')
@mock.patch('horovod.run.elastic.driver.ElasticDriver.get_coordinator_info')
@mock.patch('horovod.run.elastic.driver.ElasticDriver.get_worker_client')
def test_send_notifications_without_assignments(self, mock_get_worker_client, mock_get_coordinator_info,
                                                mock_host_assignments):
    """Tests that notifications are still sent correctly even if host assignments cannot be generated."""
    # Discovery reports: 12 slots, 12 slots, then a dip to 4 (below the
    # driver's min_np=8, so no host assignments can be generated), then a
    # recovery to 16 so wait_for_available_slots(min_np=16) can return.
    slots = [{'host-1': 8, 'host-2': 4},
             {'host-1': 8, 'host-2': 4},
             {'host-2': 4},
             {'host-2': 4},
             {'host-2': 4, 'host-3': 12}]
    discovery = mock.Mock()
    # sequence(...) keeps yielding the last element once the list runs out,
    # so the discovery thread can poll indefinitely.
    discovery.find_available_hosts_and_slots.side_effect = sequence(slots)

    driver = ElasticDriver(mock.Mock(), discovery, min_np=8, max_np=12)
    driver.wait_for_available_slots(min_np=16)
    driver.stop()

    # On the second call, we should see the number of slots dip below the minimum, but we still want to ensure
    # we notify workers every time there is a change, so in total we should observe 3 calls.
    assert mock_get_worker_client.call_count == 3
    assert mock_get_coordinator_info.call_count == 3

def test_order_available_hosts(self):
"""Tests the order is preserved for host assignment as available hosts are updated."""
# This will be a set in practice, but use a list here to guarantee order.
Expand Down Expand Up @@ -440,10 +469,23 @@ def test_shutdown_on_initial_discovery_failure(self):
discovery = mock.Mock()
discovery.find_available_hosts_and_slots.side_effect = RuntimeError()

driver = ElasticDriver(mock.Mock(), discovery, min_np=2, max_np=4)
with pytest.raises(RuntimeError):
driver.wait_for_available_slots(min_np=2)
assert driver.finished()
discover_hosts = ElasticDriver._discover_hosts

def wrapped_discover_hosts(obj):
try:
discover_hosts(obj)
except RuntimeError:
# Suppress the error message from the background discovery thread to clean up unit tests
pass

try:
ElasticDriver._discover_hosts = wrapped_discover_hosts
driver = ElasticDriver(mock.Mock(), discovery, min_np=2, max_np=4)
with pytest.raises(RuntimeError):
driver.wait_for_available_slots(min_np=2)
assert driver.finished()
finally:
ElasticDriver._discover_hosts = discover_hosts


if __name__ == "__main__":
Expand Down

0 comments on commit f4d7c0b

Please sign in to comment.