-
-
Notifications
You must be signed in to change notification settings - Fork 8.7k
/
nccl_device_communicator.cu
225 lines (201 loc) · 7.07 KB
/
nccl_device_communicator.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
/*!
* Copyright 2023 XGBoost contributors
*/
#if defined(XGBOOST_USE_NCCL)
#include "nccl_device_communicator.cuh"

#include <algorithm>  // for lexicographical_compare, sort, unique
#include <cstddef>    // for size_t
#include <cstdint>    // for uint64_t, int32_t
#include <iterator>   // for distance
#include <numeric>    // for accumulate
#include <vector>     // for vector
namespace xgboost {
namespace collective {
/**
 * \brief Construct an NCCL-backed device communicator for the calling worker.
 *
 * \param device_ordinal CUDA device ordinal used by this worker; a negative value is fatal.
 * \param needs_sync     Stored in needs_sync_. BitwiseAllReduce synchronizes cuda_stream_ between
 *                       the NCCL all-gather and the local reduction when this is set.
 *
 * With a single worker no NCCL communicator or stream is created, and the member collectives
 * return immediately. Otherwise the device UUIDs of all ranks are exchanged, and the constructor
 * fails if two ranks report the same device. It then creates the NCCL communicator and a
 * dedicated CUDA stream.
 */
NcclDeviceCommunicator::NcclDeviceCommunicator(int device_ordinal, bool needs_sync)
    : device_ordinal_{device_ordinal},
      needs_sync_{needs_sync},
      world_size_{GetWorldSize()},
      rank_{GetRank()} {
  if (device_ordinal_ < 0) {
    LOG(FATAL) << "Invalid device ordinal: " << device_ordinal_;
  }
  if (world_size_ == 1) {
    return;
  }

  // Each rank writes into its own kUuidLength-wide slot of a zero-initialised buffer. Summing
  // across ranks therefore leaves every rank with the concatenation of all slots.
  // GetCudaUUID presumably reports the UUID of this worker's CUDA device.
  std::vector<uint64_t> uuids(world_size_ * kUuidLength, 0);
  auto s_uuid = xgboost::common::Span<uint64_t>{uuids.data(), uuids.size()};
  auto s_this_uuid = s_uuid.subspan(rank_ * kUuidLength, kUuidLength);
  GetCudaUUID(s_this_uuid);
  // TODO(rongou): replace this with allgather.
  Allreduce(uuids.data(), uuids.size(), DataType::kUInt64, Operation::kSum);

  // View the flat buffer as one fixed-extent span per rank.
  using UuidSpan = xgboost::common::Span<uint64_t, kUuidLength>;
  std::vector<UuidSpan> converted(world_size_);
  size_t j = 0;
  for (size_t i = 0; i < uuids.size(); i += kUuidLength) {
    converted[j] = UuidSpan{uuids.data() + i, kUuidLength};
    j++;
  }

  // std::unique only collapses *adjacent* equal elements. Sort first, so that two ranks sharing a
  // device are detected even when they are not neighbours in rank order. Only the span views are
  // reordered; `uuids` and `s_this_uuid` are left untouched.
  std::sort(converted.begin(), converted.end(), [](UuidSpan const &lhs, UuidSpan const &rhs) {
    return std::lexicographical_compare(lhs.cbegin(), lhs.cend(), rhs.cbegin(), rhs.cend());
  });
  auto iter = std::unique(converted.begin(), converted.end());
  auto n_uniques = std::distance(converted.begin(), iter);
  CHECK_EQ(n_uniques, world_size_)
      << "Multiple processes within communication group running on same CUDA "
      << "device is not supported. " << PrintUUID(s_this_uuid) << "\n";

  nccl_unique_id_ = GetUniqueId();
  dh::safe_cuda(cudaSetDevice(device_ordinal_));
  dh::safe_nccl(ncclCommInitRank(&nccl_comm_, world_size_, nccl_unique_id_, rank_));
  dh::safe_cuda(cudaStreamCreate(&cuda_stream_));
}
/**
 * \brief Tear down the CUDA stream and NCCL communicator created by the constructor.
 *
 * For a single-worker group nothing was created, so nothing is released and no statistics are
 * printed. Otherwise the stream is destroyed first, then the communicator. When debug logging is
 * enabled, the AllReduce call count and the MiB passed through AllReduce are reported.
 */
NcclDeviceCommunicator::~NcclDeviceCommunicator() {
  if (world_size_ != 1) {
    if (cuda_stream_) {
      dh::safe_cuda(cudaStreamDestroy(cuda_stream_));
    }
    if (nccl_comm_) {
      dh::safe_nccl(ncclCommDestroy(nccl_comm_));
    }
    bool const print_stats =
        xgboost::ConsoleLogger::ShouldLog(xgboost::ConsoleLogger::LV::kDebug);
    if (print_stats) {
      LOG(CONSOLE) << "======== NCCL Statistics========";
      LOG(CONSOLE) << "AllReduce calls: " << allreduce_calls_;
      LOG(CONSOLE) << "AllReduce total MiB communicated: " << allreduce_bytes_ / 1048576;
    }
  }
}
namespace {
/**
 * \brief Map an XGBoost collective DataType onto the matching NCCL data type.
 *
 * Any enumerator without a mapping triggers LOG(FATAL). If that does not abort, ncclInt8 is
 * returned as a placeholder.
 */
ncclDataType_t GetNcclDataType(DataType const &data_type) {
  switch (data_type) {
    case DataType::kInt8:
      return ncclInt8;
    case DataType::kUInt8:
      return ncclUint8;
    case DataType::kInt32:
      return ncclInt32;
    case DataType::kUInt32:
      return ncclUint32;
    case DataType::kInt64:
      return ncclInt64;
    case DataType::kUInt64:
      return ncclUint64;
    case DataType::kFloat:
      return ncclFloat;
    case DataType::kDouble:
      return ncclDouble;
    default:
      break;
  }
  LOG(FATAL) << "Unknown data type.";
  return ncclInt8;
}
/**
 * \brief True for the bitwise reductions (AND / OR / XOR).
 *
 * NCCL has no native reduction operator for these, so AllReduce routes them through
 * BitwiseAllReduce instead.
 */
bool IsBitwiseOp(Operation const &op) {
  switch (op) {
    case Operation::kBitwiseAND:
    case Operation::kBitwiseOR:
    case Operation::kBitwiseXOR:
      return true;
    default:
      return false;
  }
}
/**
 * \brief Map an arithmetic reduction (max / min / sum) onto the NCCL reduction operator.
 *
 * Any other operation triggers LOG(FATAL). If that does not abort, ncclMax is returned as a
 * placeholder.
 */
ncclRedOp_t GetNcclRedOp(Operation const &op) {
  switch (op) {
    case Operation::kMax:
      return ncclMax;
    case Operation::kMin:
      return ncclMin;
    case Operation::kSum:
      return ncclSum;
    default:
      break;
  }
  LOG(FATAL) << "Unsupported reduce operation.";
  return ncclMax;
}
/**
 * \brief Fold per-rank byte segments into out_buffer with a bitwise functor, on `stream`.
 *
 * \param out_buffer    Destination holding `size` bytes.
 * \param device_buffer `world_size` consecutive segments of `size` bytes each (rank r's segment
 *                      starts at offset r * size).
 * \param func          Binary functor applied byte by byte, e.g. thrust::bit_and<char>.
 * \param world_size    Number of segments in device_buffer.
 * \param size          Bytes per segment.
 * \param stream        Stream the work is enqueued on.
 *
 * For every byte index idx in [0, size), out_buffer[idx] becomes the fold of device_buffer's
 * byte idx across all segments, starting from segment 0.
 */
template <typename Func>
void RunBitwiseAllreduce(char *out_buffer, char const *device_buffer, Func func, int world_size,
                         std::size_t size, cudaStream_t stream) {
  dh::LaunchN(size, stream, [=] __device__(std::size_t idx) {
    char acc = device_buffer[idx];
    char const *segment = device_buffer + size;
    for (int r = 1; r < world_size; ++r, segment += size) {
      acc = func(acc, segment[idx]);
    }
    out_buffer[idx] = acc;
  });
}
} // anonymous namespace
/**
 * \brief Bitwise AND / OR / XOR all-reduce, emulated because NCCL lacks these operators.
 *
 * \param send_receive_buffer Device buffer of `count` elements of `data_type`; it is overwritten
 *                            in place with the reduced result.
 * \param count               Number of elements.
 * \param data_type           Element type; only its byte size matters for the reduction.
 * \param op                  Must be one of the bitwise operations, otherwise LOG(FATAL).
 *
 * Every rank's buffer is all-gathered into a temporary. The gathered segments are then folded
 * byte by byte on cuda_stream_.
 */
void NcclDeviceCommunicator::BitwiseAllReduce(void *send_receive_buffer, std::size_t count,
                                              DataType data_type, Operation op) {
  auto const n_bytes = count * GetTypeSize(data_type);
  // Scratch space holding one n_bytes segment per rank.
  // NOTE(review): `gathered` is released when this function returns, possibly before the
  // reduction kernel queued on cuda_stream_ has finished. Confirm the caching allocator makes
  // that safe.
  dh::caching_device_vector<char> gathered(n_bytes * world_size_);
  auto *gathered_ptr = gathered.data().get();

  // Step 1: collect every rank's data.
  dh::safe_nccl(ncclAllGather(send_receive_buffer, gathered_ptr, count, GetNcclDataType(data_type),
                              nccl_comm_, cuda_stream_));
  if (needs_sync_) {
    dh::safe_cuda(cudaStreamSynchronize(cuda_stream_));
  }

  // Step 2: reduce the gathered segments locally back into the caller's buffer.
  auto *result = static_cast<char *>(send_receive_buffer);
  if (op == Operation::kBitwiseAND) {
    RunBitwiseAllreduce(result, gathered_ptr, thrust::bit_and<char>(), world_size_, n_bytes,
                        cuda_stream_);
  } else if (op == Operation::kBitwiseOR) {
    RunBitwiseAllreduce(result, gathered_ptr, thrust::bit_or<char>(), world_size_, n_bytes,
                        cuda_stream_);
  } else if (op == Operation::kBitwiseXOR) {
    RunBitwiseAllreduce(result, gathered_ptr, thrust::bit_xor<char>(), world_size_, n_bytes,
                        cuda_stream_);
  } else {
    LOG(FATAL) << "Not a bitwise reduce operation.";
  }
}
/**
 * \brief In-place all-reduce of `count` elements of `data_type` across all ranks.
 *
 * A no-op for a single worker. Bitwise operations go through BitwiseAllReduce; everything else
 * goes to ncclAllReduce on cuda_stream_. Each call updates the call and byte counters reported by
 * the destructor.
 */
void NcclDeviceCommunicator::AllReduce(void *send_receive_buffer, std::size_t count,
                                       DataType data_type, Operation op) {
  if (world_size_ != 1) {
    dh::safe_cuda(cudaSetDevice(device_ordinal_));
    if (!IsBitwiseOp(op)) {
      dh::safe_nccl(ncclAllReduce(send_receive_buffer, send_receive_buffer, count,
                                  GetNcclDataType(data_type), GetNcclRedOp(op), nccl_comm_,
                                  cuda_stream_));
    } else {
      BitwiseAllReduce(send_receive_buffer, count, data_type, op);
    }
    allreduce_bytes_ += count * GetTypeSize(data_type);
    ++allreduce_calls_;
  }
}
/**
 * \brief Variable-length all-gather: concatenate every rank's byte buffer, in rank order.
 *
 * \param send_buffer    Device pointer to this rank's `length_bytes` bytes.
 * \param length_bytes   Number of bytes this rank contributes.
 * \param segments       Output; resized to world_size_, where entry i is rank i's byte count.
 * \param receive_buffer Output; resized to the sum of all segments. Rank i's data starts at the
 *                       sum of the segments before it.
 *
 * For a single worker this returns immediately, without touching *segments or *receive_buffer.
 * The broadcasts are enqueued on cuda_stream_; call Synchronize() before reading the result on
 * the host.
 */
void NcclDeviceCommunicator::AllGatherV(void const *send_buffer, size_t length_bytes,
                                        std::vector<std::size_t> *segments,
                                        dh::caching_device_vector<char> *receive_buffer) {
  if (world_size_ == 1) {
    return;
  }
  dh::safe_cuda(cudaSetDevice(device_ordinal_));

  // Exchange segment lengths. Each rank fills only its own slot of a zeroed vector, so a
  // max-reduction gives every rank the full list. The vector is sent as kUInt64, which is only
  // valid when std::size_t is 64 bits wide.
  static_assert(sizeof(std::size_t) == sizeof(uint64_t),
                "AllGatherV exchanges std::size_t lengths as kUInt64.");
  segments->clear();
  segments->resize(world_size_, 0);
  segments->at(rank_) = length_bytes;
  Allreduce(segments->data(), segments->size(), DataType::kUInt64, Operation::kMax);

  // Seed with std::size_t: `0UL` would make the accumulator `unsigned long`, which is 32 bits on
  // LLP64 platforms (64-bit Windows) and silently truncates totals of 4 GiB or more.
  auto const total_bytes =
      std::accumulate(segments->cbegin(), segments->cend(), static_cast<std::size_t>(0));
  receive_buffer->resize(total_bytes);

  // Rank i acts as broadcast root for its own segment. All broadcasts are issued as one NCCL group.
  std::size_t offset = 0;
  dh::safe_nccl(ncclGroupStart());
  for (int32_t i = 0; i < world_size_; ++i) {
    std::size_t const as_bytes = segments->at(i);
    dh::safe_nccl(ncclBroadcast(send_buffer, receive_buffer->data().get() + offset, as_bytes,
                                ncclChar, i, nccl_comm_, cuda_stream_));
    offset += as_bytes;
  }
  dh::safe_nccl(ncclGroupEnd());
}
/**
 * \brief Block the host until all work queued on cuda_stream_ has completed.
 *
 * A no-op for a single worker, since no stream was created in that case.
 */
void NcclDeviceCommunicator::Synchronize() {
  if (world_size_ != 1) {
    dh::safe_cuda(cudaSetDevice(device_ordinal_));
    dh::safe_cuda(cudaStreamSynchronize(cuda_stream_));
  }
}
} // namespace collective
} // namespace xgboost
#endif