diff --git a/horovod/run/elastic/driver.py b/horovod/run/elastic/driver.py index 88333ce51b..7e81bf712e 100644 --- a/horovod/run/elastic/driver.py +++ b/horovod/run/elastic/driver.py @@ -105,6 +105,9 @@ def register_worker_server(self, host, slot, addresses, secret_key): self._worker_clients[(host, slot)] = WorkerNotificationClient( addresses, secret_key, self._verbose) + def get_worker_client(self, slot_info): + return self._worker_clients.get((slot_info.hostname, slot_info.local_rank)) + def record_ready(self, host, slot): self._worker_registry.record_ready(host, slot) @@ -118,11 +121,18 @@ def get_slot_info(self, host, slot): return self._host_assignments[host][slot] if self.has_rank_assignment(host, slot) \ else hosts.INVALID_SLOT_INFO + def get_coordinator_info(self): + return self._rank_assignments.get(0) + def has_rank_assignment(self, host, slot): if self._host_manager.is_blacklisted(host): return False return host in self._host_assignments and len(self._host_assignments[host]) > slot + @property + def host_assignments(self): + return self._host_assignments + def wait_for_available_slots(self, min_np, min_hosts=1): extra_message = ' An elastic job also requires that at least two hosts ' \ 'are available to resolve compatible network interfaces. If you know which interfaces ' \ @@ -176,20 +186,22 @@ def _discover_hosts(self): self._shutdown.wait(DISCOVER_HOSTS_FREQUENCY_SECS) def _notify_workers_host_changes(self, current_hosts): - # Assignments are required to be stable via contract - next_host_assignments, _ = self._get_host_assignments(current_hosts) - if next_host_assignments == self._host_assignments: + next_host_assignments = {} + if current_hosts.count_available_slots() >= self._min_np: + # Assignments are required to be stable via contract + next_host_assignments, _ = self._get_host_assignments(current_hosts) + + if next_host_assignments == self.host_assignments: # Skip notifying workers when host changes would not result in changes of host assignments logging.debug('no host assignment changes, skipping notifications') return - coordinator_slot_info = self._rank_assignments.get(0) + coordinator_slot_info = self.get_coordinator_info() if not coordinator_slot_info: logging.debug('no coordinator info, skipping notifications') return - coordinator_client = self._worker_clients.get((coordinator_slot_info.hostname, - coordinator_slot_info.local_rank)) + coordinator_client = self.get_worker_client(coordinator_slot_info) if not coordinator_client: logging.debug('no coordinator client, skipping notifications') return diff --git a/test/test_elastic_driver.py b/test/test_elastic_driver.py index a4f928365b..06054762b5 100644 --- a/test/test_elastic_driver.py +++ b/test/test_elastic_driver.py @@ -25,7 +25,7 @@ import pytest from horovod.run.util import network -from horovod.run.elastic.discovery import FixedHosts, HostDiscovery, HostManager +from horovod.run.elastic.discovery import FixedHosts, HostManager from horovod.run.elastic.driver import ElasticDriver from horovod.run.elastic.rendezvous import create_rendezvous_handler from horovod.run.elastic.worker import WorkerNotificationManager @@ -40,17 +40,11 @@ def wait_for_one(events): time.sleep(0.01) -class HostDiscoverySequence(HostDiscovery): - def __init__(self, sequence): - self.iterator = iter(sequence) - self.end = sequence[-1] - super(HostDiscoverySequence, self).__init__() - - def find_available_hosts_and_slots(self): - try: - return self.iterator.next() - except: - return self.end +def sequence(lst): + for v in lst: + yield v + while True: + yield lst[-1] class ElasticDriverTests(unittest.TestCase): @@ -173,16 +167,25 @@ def exec_command(slot_info, events): assert updated_slot_info.cross_rank == slot_info.cross_rank, rank @mock.patch('horovod.run.elastic.driver.DISCOVER_HOSTS_FREQUENCY_SECS', 0.01) - def test_wait_for_available_slots(self): + @mock.patch('horovod.run.elastic.driver.ElasticDriver.get_coordinator_info') + @mock.patch('horovod.run.elastic.driver.ElasticDriver.get_worker_client') + def test_wait_for_available_slots(self, mock_get_worker_client, mock_get_coordinator_info): """Tests that driver blocks until the min number of slots are available.""" slots = [{'host-1': 4}, {'host-1': 4, 'host-2': 8}, {'host-1': 4, 'host-2': 8, 'host-3': 4}] - discovery = HostDiscoverySequence(slots) + mock_discovery = mock.Mock() + mock_discovery.find_available_hosts_and_slots.side_effect = sequence(slots) - driver = ElasticDriver(mock.Mock(), discovery, min_np=2, max_np=12) - driver.wait_for_available_slots(min_np=10) - assert driver._host_manager.current_hosts.count_available_slots() >= 10 + driver = ElasticDriver(mock.Mock(), mock_discovery, min_np=8, max_np=20) + driver.wait_for_available_slots(min_np=16) + assert driver._host_manager.current_hosts.count_available_slots() >= 16 + driver.stop() + + # Notify coordinator 2 times, as the first time we are below min_np and the existing host assignments + # are empty + assert mock_get_worker_client.call_count == 2 + assert mock_get_coordinator_info.call_count == 2 @mock.patch('horovod.run.elastic.driver.DISCOVER_HOSTS_FREQUENCY_SECS', 0.01) def test_wait_for_min_hosts(self): @@ -190,13 +193,15 @@ def test_wait_for_min_hosts(self): slots = [{'host-1': 4}, {'host-1': 4, 'host-2': 8}, {'host-1': 4, 'host-2': 8, 'host-3': 4}] - discovery = HostDiscoverySequence(slots) + mock_discovery = mock.Mock() + mock_discovery.find_available_hosts_and_slots.side_effect = sequence(slots) - driver = ElasticDriver(mock.Mock(), discovery, min_np=2, max_np=12) + driver = ElasticDriver(mock.Mock(), mock_discovery, min_np=2, max_np=12) driver.wait_for_available_slots(min_np=2, min_hosts=2) # Even though we only needed 2 slots, because we also needed 2 hosts, we will at least 12 slots total assert driver._host_manager.current_hosts.count_available_slots() >= 12 + driver.stop() def test_all_workers_fail(self): """Tests that training fails when all workers fail.""" @@ -361,6 +366,30 @@ def exec_command(slot_info, events): rendezvous.stop_server() + @mock.patch('horovod.run.elastic.driver.DISCOVER_HOSTS_FREQUENCY_SECS', 0.01) + @mock.patch('horovod.run.elastic.driver.ElasticDriver.host_assignments') + @mock.patch('horovod.run.elastic.driver.ElasticDriver.get_coordinator_info') + @mock.patch('horovod.run.elastic.driver.ElasticDriver.get_worker_client') + def test_send_notifications_without_assignments(self, mock_get_worker_client, mock_get_coordinator_info, + mock_host_assignments): + """Tests that notifications are still sent correctly even if host assignments cannot be generated.""" + slots = [{'host-1': 8, 'host-2': 4}, + {'host-1': 8, 'host-2': 4}, + {'host-2': 4}, + {'host-2': 4}, + {'host-2': 4, 'host-3': 12}] + discovery = mock.Mock() + discovery.find_available_hosts_and_slots.side_effect = sequence(slots) + + driver = ElasticDriver(mock.Mock(), discovery, min_np=8, max_np=12) + driver.wait_for_available_slots(min_np=16) + driver.stop() + + # On the second call, we should see the number of slots dip below the minimum, but we still want to ensure + # we notify workers every time there is a change, so in total we should observe 3 calls. + assert mock_get_worker_client.call_count == 3 + assert mock_get_coordinator_info.call_count == 3 + def test_order_available_hosts(self): """Tests the order is preserved for host assignment as available hosts are updated.""" # This will be a set in practice, but use a list here to guarantee order. @@ -440,10 +469,23 @@ def test_shutdown_on_initial_discovery_failure(self): discovery = mock.Mock() discovery.find_available_hosts_and_slots.side_effect = RuntimeError() - driver = ElasticDriver(mock.Mock(), discovery, min_np=2, max_np=4) - with pytest.raises(RuntimeError): - driver.wait_for_available_slots(min_np=2) - assert driver.finished() + discover_hosts = ElasticDriver._discover_hosts + + def wrapped_discover_hosts(obj): + try: + discover_hosts(obj) + except RuntimeError: + # Suppress the error message from the background discovery thread to clean up unit tests + pass + + try: + ElasticDriver._discover_hosts = wrapped_discover_hosts + driver = ElasticDriver(mock.Mock(), discovery, min_np=2, max_np=4) + with pytest.raises(RuntimeError): + driver.wait_for_available_slots(min_np=2) + assert driver.finished() + finally: + ElasticDriver._discover_hosts = discover_hosts if __name__ == "__main__":