Skip to content

Commit

Permalink
Fix indices in initial task-to-task registration
Browse files (browse the repository at this point in the history)
Signed-off-by: Enrico Minack <github@enrico.minack.dev>
  • Loading branch information
EnricoMi committed Feb 16, 2022
1 parent 046c071 commit 5f72273
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions horovod/spark/runner.py
Expand Up @@ -167,21 +167,23 @@ def _notify_and_register_task_addresses(driver, settings, notify=True):
     if settings.verbose >= 2:
         logging.info('Initial Spark task registration is complete.')

-    def notify_and_register(index):
-        task_client = task_service.SparkTaskClient(index,
-                                                   driver.task_addresses_for_driver(index),
+    task_indices = driver.task_indices()
+    task_pairs = zip(task_indices, task_indices[1:] + task_indices[0:1])
+
+    def notify_and_register(task_index, next_task_index):
+        task_client = task_service.SparkTaskClient(task_index,
+                                                   driver.task_addresses_for_driver(task_index),
                                                    settings.key, settings.verbose)

         if notify:
             task_client.notify_initial_registration_complete()

-        next_task_index = (index + 1) % settings.num_proc
         next_task_addresses = driver.all_task_addresses(next_task_index)
         task_to_task_addresses = task_client.get_task_addresses_for_task(next_task_index, next_task_addresses)
         driver.register_task_to_task_addresses(next_task_index, task_to_task_addresses)

-    for index in driver.task_indices():
-        in_thread(notify_and_register, (index,))
+    for task_index, next_task_index in task_pairs:
+        in_thread(notify_and_register, (task_index, next_task_index))

     driver.wait_for_task_to_task_address_updates(settings.start_timeout)

Expand Down

0 comments on commit 5f72273

Please sign in to comment.