diff --git a/horovod/tensorflow/xla_mpi_ops.cc b/horovod/tensorflow/xla_mpi_ops.cc index 7c06232584..5894999d03 100644 --- a/horovod/tensorflow/xla_mpi_ops.cc +++ b/horovod/tensorflow/xla_mpi_ops.cc @@ -19,6 +19,7 @@ #include #include + #if TENSORFLOW_VERSION >= 2006000000 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" @@ -570,7 +571,6 @@ XLA_REGISTER_CUSTOM_CALL_TARGET(CallbackHVDAllreduceDone, "CUDA"); } // namespace tensorflow } // namespace horovod -#endif // TENSORFLOW_VERSION >= 2006000000 #endif // HAVE_CUDA #if HAVE_ROCM @@ -1095,11 +1095,12 @@ void CallbackHVDAllreduceDone(hipStream_t stream, void** /*buffers*/, VLOG(2) << "hvd-allreduce-done - End"; } -XLA_REGISTER_CUSTOM_CALL_TARGET(CallbackHVDAllreduce, "ROCm"); -XLA_REGISTER_CUSTOM_CALL_TARGET(CallbackHVDAllreduceDone, "ROCm"); +XLA_REGISTER_CUSTOM_CALL_TARGET(CallbackHVDAllreduce, "ROCM"); +XLA_REGISTER_CUSTOM_CALL_TARGET(CallbackHVDAllreduceDone, "ROCM"); } // namespace } // namespace tensorflow } // namespace horovod #endif //HAVE_ROCM #endif // HAVE_GPU +#endif // TENSORFLOW_VERSION >= 2006000000