diff --git a/paddle/fluid/distributed/collective/process_group_nccl.cc b/paddle/fluid/distributed/collective/process_group_nccl.cc index 7ffe00b8cd824..89f5dcb222e63 100644 --- a/paddle/fluid/distributed/collective/process_group_nccl.cc +++ b/paddle/fluid/distributed/collective/process_group_nccl.cc @@ -44,7 +44,7 @@ ProcessGroupNCCL::NCCLTask::NCCLTask(const Place& place, bool sync_op, bool use_calc_stream) : TaskStream(rank, comm_type, sync_op, use_calc_stream), - comm_event_(place), + comm_event_(place, platform::GenerateDeviceEventFlag()), task_place_(place) {} ProcessGroupNCCL::NCCLTask::~NCCLTask() = default; @@ -506,7 +506,9 @@ void ProcessGroupNCCL::CreateNCCLEnvCache(const Place& place, auto nccl_comm_ctx = this->GetCommContext(); comm_ctx->set_nccl_comm(nccl_comm_ctx->GetNcclComm()); - place_to_calc_event_.emplace(place_key, place); + place_to_calc_event_.emplace( + place_key, + platform::DeviceEvent(place, platform::GenerateDeviceEventFlag())); place_to_calc_ctx_.emplace(place_key, calc_ctx); place_to_comm_ctx_.emplace(place_key, std::move(comm_ctx)); @@ -592,7 +594,7 @@ ProcessGroupNCCL::NCCLTask::NCCLTask( CommType CommType, const std::vector& inputs) : TaskStream(rank, inputs, CommType), - comm_event_(places[0]), + comm_event_(places[0], platform::GenerateDeviceEventFlag()), task_place_(places[0]) {} // create NCCLManager cache for places_key @@ -636,7 +638,9 @@ void ProcessGroupNCCL::CreateNCCLManagerCache( GroupEnd(); // TODO(sunyilun): for compatibility, will be removed later - place_to_calc_event_.emplace(places_key, places[0]); + place_to_calc_event_.emplace( + places_key, + platform::DeviceEvent(places[0], platform::GenerateDeviceEventFlag())); place_to_calc_ctx_.emplace( places_key, static_cast( diff --git a/paddle/fluid/platform/device_event_base.h b/paddle/fluid/platform/device_event_base.h index e2de1e5a9abe3..03fd7d4bb13f0 100644 --- a/paddle/fluid/platform/device_event_base.h +++ b/paddle/fluid/platform/device_event_base.h @@ -55,7 +55,7 @@ enum EventStatus { class DeviceEvent { public: - explicit DeviceEvent(const platform::Place& place, unsigned int flag = 0) + explicit DeviceEvent(const platform::Place& place, unsigned int flag) : event_(), place_(place), flag_(flag) { type_id_ = DeviceTypeToId(platform::Place2DeviceType(place)); PADDLE_ENFORCE_LT(type_id_, diff --git a/paddle/fluid/platform/device_event_test.cc b/paddle/fluid/platform/device_event_test.cc index 7dfacc66437ae..75c0e65352f52 100644 --- a/paddle/fluid/platform/device_event_test.cc +++ b/paddle/fluid/platform/device_event_test.cc @@ -37,7 +37,7 @@ TEST(DeviceEvent, CUDA) { ASSERT_NE(context, nullptr); // case 1. test for event_creator - DeviceEvent event(place); + DeviceEvent event(place, paddle::platform::GenerateDeviceEventFlag()); ASSERT_NE(event.GetEvent().get(), nullptr); bool status = event.Query(); ASSERT_EQ(status, true); @@ -85,8 +85,9 @@ TEST(DeviceEvent, CUDA) { auto* context = static_cast(pool.Get(place)); ASSERT_NE(context, nullptr); + // case 1. test for event_creator - DeviceEvent event(place); + DeviceEvent event(place, paddle::platform::GenerateDeviceEventFlag()); ASSERT_NE(event.GetEvent().get(), nullptr); bool status = event.Query(); ASSERT_EQ(status, true); @@ -127,7 +128,7 @@ TEST(DeviceEvent, CUDA) { TEST(DeviceEvent, CPU) { using paddle::platform::CPUPlace; auto place = CPUPlace(); - DeviceEvent event(place); + DeviceEvent event(place, paddle::platform::GenerateDeviceEventFlag()); auto& pool = DeviceContextPool::Instance(); auto* context = pool.Get(place);