diff --git a/horovod/spark/lightning/remote.py b/horovod/spark/lightning/remote.py
index 2a1d77e951..fb60bc4813 100644
--- a/horovod/spark/lightning/remote.py
+++ b/horovod/spark/lightning/remote.py
@@ -121,6 +121,12 @@ def train(serialized_model):
         _val_steps_per_epoch = val_steps_per_epoch if val_steps_per_epoch else 1.0
 
         cuda_available = torch.cuda.is_available()
+        # We need to check that all ranks have the same device type for training.
+        # Horovod doesn't support heterogeneous allreduce for gradients.
+        cuda_avail_list = hvd.allgather_object(cuda_available, name='device type')
+        if hvd.rank() == 0:
+            assert cuda_avail_list.count(cuda_available) == hvd.size(), "All ranks don't have the same device type!"
+
         if cuda_available:
             # Horovod: pin GPU to local rank or the assigned GPU from spark.
             torch.cuda.set_device(_get_assigned_gpu_or_default(default=hvd.local_rank()))
diff --git a/horovod/spark/torch/remote.py b/horovod/spark/torch/remote.py
index c346f32e9b..952c9db0ab 100644
--- a/horovod/spark/torch/remote.py
+++ b/horovod/spark/torch/remote.py
@@ -123,6 +123,12 @@ def train(serialized_model, optimizer_cls, model_opt_state_serialized,
         shuffle_buffer_size = user_shuffle_buffer_size
 
         cuda_available = torch.cuda.is_available()
+        # We need to check that all ranks have the same device type for training.
+        # Horovod doesn't support heterogeneous allreduce for gradients.
+        cuda_avail_list = hvd.allgather_object(cuda_available, name='device type')
+        if hvd.rank() == 0:
+            assert cuda_avail_list.count(cuda_available) == hvd.size(), "All ranks don't have the same device type!"
+
         if cuda_available:
             # Horovod: pin GPU to local rank or the assigned GPU from spark.
             torch.cuda.set_device(_get_assigned_gpu_or_default(default=hvd.local_rank()))
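For context, the check added in both files can be exercised on its own. Below is a minimal sketch (not part of the patch) assuming Horovod's PyTorch API (hvd.init, hvd.allgather_object, hvd.rank, hvd.size); the helper name check_homogeneous_device_type is hypothetical and chosen only for illustration.

# Standalone sketch of the rank device-type consistency check (assumption:
# run under horovodrun, e.g. `horovodrun -np 4 python this_script.py`).
import horovod.torch as hvd
import torch

def check_homogeneous_device_type():
    """Abort early if some ranks see a GPU and others do not.

    Horovod's gradient allreduce assumes every rank uses the same device
    type, so a mixed CPU/GPU job would otherwise fail later in a less
    obvious way.
    """
    # Gather every rank's CUDA availability flag onto all ranks.
    cuda_available = torch.cuda.is_available()
    cuda_avail_list = hvd.allgather_object(cuda_available, name='device type')

    # Rank 0 validates; the assertion error stops the job before training.
    if hvd.rank() == 0:
        assert cuda_avail_list.count(cuda_available) == hvd.size(), \
            "All ranks don't have the same device type!"

if __name__ == '__main__':
    hvd.init()
    check_homogeneous_device_type()

Note that allgather_object is a collective call, so every rank must reach it; only the assertion itself is restricted to rank 0, mirroring the patch.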