diff --git a/horovod/tensorflow/xla_mpi_ops.cc b/horovod/tensorflow/xla_mpi_ops.cc index a52baa4e50..d2ea00c5d4 100644 --- a/horovod/tensorflow/xla_mpi_ops.cc +++ b/horovod/tensorflow/xla_mpi_ops.cc @@ -506,7 +506,7 @@ XLAPersistentBuffer::XLAPersistentBuffer(int device, int64_t size) #elif HAVE_ROCM HVD_GPU_CHECK(hipGetDevice(&restore_device)); HVD_GPU_CHECK(hipSetDevice(device)); - // Simply call cudaMalloc for persistent buffer. + // Simply call hipMalloc for persistent buffer. HVD_GPU_CHECK(hipMalloc((void**)&buffer_, size)); HVD_GPU_CHECK(hipSetDevice(restore_device)); #endif @@ -597,8 +597,13 @@ void CallbackHVDAllreduceDone(gpuStream_t stream, void** /*buffers*/, VLOG(2) << "hvd-allreduce-done - End"; } +#if HAVE_CUDA XLA_REGISTER_CUSTOM_CALL_TARGET(CallbackHVDAllreduce, "CUDA"); XLA_REGISTER_CUSTOM_CALL_TARGET(CallbackHVDAllreduceDone, "CUDA"); +#elif HAVE_ROCM +XLA_REGISTER_CUSTOM_CALL_TARGET(CallbackHVDAllreduce, "ROCM"); +XLA_REGISTER_CUSTOM_CALL_TARGET(CallbackHVDAllreduceDone, "ROCM"); +#endif } // namespace } // namespace tensorflow