Skip to content

Commit

Permalink
fixed a comment string and added a preprocessor branch for ROCM
Browse files Browse the repository at this point in the history
Signed-off-by: weihanmines <wei.han3@amd.com>
  • Loading branch information
weihanmines committed May 25, 2022
1 parent 4f5a5db commit ba0c354
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion horovod/tensorflow/xla_mpi_ops.cc
Expand Up @@ -506,7 +506,7 @@ XLAPersistentBuffer::XLAPersistentBuffer(int device, int64_t size)
#elif HAVE_ROCM
HVD_GPU_CHECK(hipGetDevice(&restore_device));
HVD_GPU_CHECK(hipSetDevice(device));
// Simply call cudaMalloc for persistent buffer.
// Simply call hipMalloc for persistent buffer.
HVD_GPU_CHECK(hipMalloc((void**)&buffer_, size));
HVD_GPU_CHECK(hipSetDevice(restore_device));
#endif
Expand Down Expand Up @@ -597,8 +597,13 @@ void CallbackHVDAllreduceDone(gpuStream_t stream, void** /*buffers*/,
VLOG(2) << "hvd-allreduce-done - End";
}

#if HAVE_CUDA
XLA_REGISTER_CUSTOM_CALL_TARGET(CallbackHVDAllreduce, "CUDA");
XLA_REGISTER_CUSTOM_CALL_TARGET(CallbackHVDAllreduceDone, "CUDA");
#elif HAVE_ROCM
XLA_REGISTER_CUSTOM_CALL_TARGET(CallbackHVDAllreduce, "ROCM");
XLA_REGISTER_CUSTOM_CALL_TARGET(CallbackHVDAllreduceDone, "ROCM");
#endif

} // namespace
} // namespace tensorflow
Expand Down

0 comments on commit ba0c354

Please sign in to comment.