Skip to content

Commit

Permalink
platform string changed in XLA backend
Browse files Browse the repository at this point in the history
  • Loading branch information
weihanmines committed Apr 27, 2022
1 parent 19e9b48 commit 8502607
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions horovod/tensorflow/xla_mpi_ops.cc
Expand Up @@ -19,6 +19,7 @@
#include <thread>
#include <unordered_map>


#if TENSORFLOW_VERSION >= 2006000000

#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
Expand Down Expand Up @@ -570,7 +571,6 @@ XLA_REGISTER_CUSTOM_CALL_TARGET(CallbackHVDAllreduceDone, "CUDA");
} // namespace tensorflow
} // namespace horovod

#endif // TENSORFLOW_VERSION >= 2006000000
#endif // HAVE_CUDA
#if HAVE_ROCM

Expand Down Expand Up @@ -1095,11 +1095,12 @@ void CallbackHVDAllreduceDone(hipStream_t stream, void** /*buffers*/,
VLOG(2) << "hvd-allreduce-done - End";
}

XLA_REGISTER_CUSTOM_CALL_TARGET(CallbackHVDAllreduce, "ROCm");
XLA_REGISTER_CUSTOM_CALL_TARGET(CallbackHVDAllreduceDone, "ROCm");
XLA_REGISTER_CUSTOM_CALL_TARGET(CallbackHVDAllreduce, "ROCM");
XLA_REGISTER_CUSTOM_CALL_TARGET(CallbackHVDAllreduceDone, "ROCM");

} // namespace
} // namespace tensorflow
} // namespace horovod
#endif //HAVE_ROCM
#endif // HAVE_GPU
#endif // TENSORFLOW_VERSION >= 2006000000

0 comments on commit 8502607

Please sign in to comment.