From 5f72273f1e3cf7bd7ce873f5b6217c9018bc6937 Mon Sep 17 00:00:00 2001 From: Enrico Minack Date: Wed, 16 Feb 2022 22:25:53 +0100 Subject: [PATCH] Fix indices in initial task-to-task registration Signed-off-by: Enrico Minack --- horovod/spark/runner.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/horovod/spark/runner.py b/horovod/spark/runner.py index 06232a728c..c74abb484c 100644 --- a/horovod/spark/runner.py +++ b/horovod/spark/runner.py @@ -167,21 +167,23 @@ def _notify_and_register_task_addresses(driver, settings, notify=True): if settings.verbose >= 2: logging.info('Initial Spark task registration is complete.') - def notify_and_register(index): - task_client = task_service.SparkTaskClient(index, - driver.task_addresses_for_driver(index), + task_indices = driver.task_indices() + task_pairs = zip(task_indices, task_indices[1:] + task_indices[0:1]) + + def notify_and_register(task_index, next_task_index): + task_client = task_service.SparkTaskClient(task_index, + driver.task_addresses_for_driver(task_index), settings.key, settings.verbose) if notify: task_client.notify_initial_registration_complete() - next_task_index = (index + 1) % settings.num_proc next_task_addresses = driver.all_task_addresses(next_task_index) task_to_task_addresses = task_client.get_task_addresses_for_task(next_task_index, next_task_addresses) driver.register_task_to_task_addresses(next_task_index, task_to_task_addresses) - for index in driver.task_indices(): - in_thread(notify_and_register, (index,)) + for task_index, next_task_index in task_pairs: + in_thread(notify_and_register, (task_index, next_task_index)) driver.wait_for_task_to_task_address_updates(settings.start_timeout)