-
Notifications
You must be signed in to change notification settings - Fork 3.8k
/
cuda_utils.hu
293 lines (241 loc) · 9.46 KB
/
cuda_utils.hu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
/*!
* Copyright (c) 2020-2021 IBM Corporation, Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See LICENSE file in the project root for license information.
*/
#ifndef LIGHTGBM_CUDA_CUDA_UTILS_H_
#define LIGHTGBM_CUDA_CUDA_UTILS_H_
#ifdef USE_CUDA
#include <cuda.h>
#include <cuda_runtime.h>
#include <stdio.h>
#include <nccl.h>
#include <LightGBM/utils/log.h>
#include <LightGBM/meta.h>
#include <algorithm>
#include <vector>
#include <cmath>
namespace LightGBM {
// 64-bit unsigned integer alias; presumably chosen to match the
// `unsigned long long` overload of CUDA's atomicAdd — confirm against users.
using atomic_add_long_t = unsigned long long;
/*!
 * \brief Checks the cudaError_t produced by \p ans and reports failures via gpuAssert,
 *        tagged with the file/line where this macro is expanded.
 * \note Wrapped in do { ... } while (0) so that `CUDASUCCESS_OR_FATAL(x);` is a single
 *       statement and can be used as the body of an if/else without braces.
 */
#define CUDASUCCESS_OR_FATAL(ans) do { gpuAssert((ans), __FILE__, __LINE__); } while (0)
/*!
 * \brief Reports a failed CUDA runtime call.
 * \param code Result of the CUDA call; cudaSuccess is ignored.
 * \param file Source file of the failing call (used in the message).
 * \param line Source line of the failing call (used in the message).
 * \param abort If true, exit the process with \p code after logging.
 */
inline void gpuAssert(cudaError_t code, const char *file, int line, bool abort = true) {
  if (code == cudaSuccess) {
    return;  // nothing to report
  }
  LightGBM::Log::Fatal("[CUDA] %s %s %d\n", cudaGetErrorString(code), file, line);
  // NOTE(review): Log::Fatal is expected to throw, which would make this a
  // fallback that only runs if it ever returns — confirm against log.h.
  if (abort) {
    exit(code);
  }
}
/*!
 * \brief Like CUDASUCCESS_OR_FATAL, but reports the caller-supplied location: the
 *        expansion site must have variables named `file` and `line` in scope.
 * \note Wrapped in do { ... } while (0) so it is a single statement (safe in if/else).
 */
#define CUDASUCCESS_OR_FATAL_OUTER(ans) do { gpuAssert((ans), file, line); } while (0)
/*!
 * \brief Evaluates the NCCL call \p cmd exactly once; if it does not return ncclSuccess,
 *        prints the call site and NCCL error string to stdout and exits the process.
 * \note The result is held in a macro-private variable so that caller variables used
 *       inside \p cmd (e.g. one named `r`) are not shadowed by the macro's own local.
 *       Unlike CUDASUCCESS_OR_FATAL this does not go through Log::Fatal.
 */
#define NCCLCHECK(cmd) do { \
  ncclResult_t nccl_check_result_ = (cmd); \
  if (nccl_check_result_ != ncclSuccess) { \
    printf("Failed, NCCL error %s:%d '%s'\n", \
           __FILE__, __LINE__, ncclGetErrorString(nccl_check_result_)); \
    exit(EXIT_FAILURE); \
  } \
} while (0)
// Selects CUDA device `gpu_device_id` (defined out of line; presumably wraps
// cudaSetDevice — confirm). `file`/`line` identify the call site for error reports.
void SetCUDADevice(int gpu_device_id, const char* file, int line);
// Returns the current CUDA device id (defined out of line; presumably wraps
// cudaGetDevice — confirm). `file`/`line` identify the call site for error reports.
int GetCUDADevice(const char* file, int line);
/*!
 * \brief Allocates device memory for \p size elements of type T with cudaMalloc.
 * \param out_ptr Receives the device pointer; left untouched if the allocation check fails.
 * \param size Number of elements (not bytes).
 * \param file, line Call site reported if cudaMalloc fails.
 * \note cudaMalloc does not clear the memory; contents are uninitialized.
 */
template <typename T>
void AllocateCUDAMemory(T** out_ptr, size_t size, const char* file, const int line) {
  const size_t num_bytes = size * sizeof(T);
  // cudaMalloc writes through a void**, so allocate into an untyped temporary
  // and publish it to the caller only after the result has been checked.
  void* raw = nullptr;
  CUDASUCCESS_OR_FATAL_OUTER(cudaMalloc(&raw, num_bytes));
  *out_ptr = static_cast<T*>(raw);
}
/*!
 * \brief Copies \p size elements (not bytes) of type T from host memory to device memory
 *        using cudaMemcpy (the non-Async variant).
 * \param dst_ptr Device destination.
 * \param src_ptr Host source.
 * \param file, line Call site reported if the copy fails.
 */
template <typename T>
void CopyFromHostToCUDADevice(T* dst_ptr, const T* src_ptr, size_t size, const char* file, const int line) {
  const size_t byte_count = sizeof(T) * size;
  CUDASUCCESS_OR_FATAL_OUTER(cudaMemcpy(static_cast<void*>(dst_ptr), static_cast<const void*>(src_ptr),
                                        byte_count, cudaMemcpyHostToDevice));
}
/*!
 * \brief Allocates \p size elements on the device and fills them from host memory.
 * \param dst_ptr Receives the newly allocated device pointer.
 * \param src_ptr Host source of \p size elements.
 * \param file, line Call site reported on failure.
 */
template <typename T>
void InitCUDAMemoryFromHostMemory(T** dst_ptr, const T* src_ptr, size_t size, const char* file, const int line) {
  AllocateCUDAMemory<T>(dst_ptr, size, file, line);
  T* const device_buffer = *dst_ptr;
  CopyFromHostToCUDADevice<T>(device_buffer, src_ptr, size, file, line);
}
/*!
 * \brief Copies \p size elements (not bytes) of type T from device memory to host memory
 *        using cudaMemcpy (the non-Async variant).
 * \param dst_ptr Host destination.
 * \param src_ptr Device source.
 * \param file, line Call site reported if the copy fails.
 */
template <typename T>
void CopyFromCUDADeviceToHost(T* dst_ptr, const T* src_ptr, size_t size, const char* file, const int line) {
  const size_t byte_count = sizeof(T) * size;
  CUDASUCCESS_OR_FATAL_OUTER(cudaMemcpy(static_cast<void*>(dst_ptr), static_cast<const void*>(src_ptr),
                                        byte_count, cudaMemcpyDeviceToHost));
}
/*!
 * \brief Enqueues a copy of \p size elements of type T from device memory to host memory on \p stream.
 * \param dst_ptr Host destination; synchronize \p stream before reading it.
 * \param src_ptr Device source.
 * \param stream Stream the copy is enqueued on.
 * \param file, line Call site reported if enqueuing the copy fails.
 */
template <typename T>
void CopyFromCUDADeviceToHostAsync(T* dst_ptr, const T* src_ptr, size_t size, cudaStream_t stream, const char* file, const int line) {
  const size_t byte_count = sizeof(T) * size;
  CUDASUCCESS_OR_FATAL_OUTER(cudaMemcpyAsync(static_cast<void*>(dst_ptr), static_cast<const void*>(src_ptr),
                                             byte_count, cudaMemcpyDeviceToHost, stream));
}
/*!
 * \brief Copies \p size elements (not bytes) of type T between two device buffers
 *        using cudaMemcpy (the non-Async variant).
 * \param dst_ptr Device destination.
 * \param src_ptr Device source.
 * \param file, line Call site reported if the copy fails.
 */
template <typename T>
void CopyFromCUDADeviceToCUDADevice(T* dst_ptr, const T* src_ptr, size_t size, const char* file, const int line) {
  const size_t byte_count = sizeof(T) * size;
  CUDASUCCESS_OR_FATAL_OUTER(cudaMemcpy(static_cast<void*>(dst_ptr), static_cast<const void*>(src_ptr),
                                        byte_count, cudaMemcpyDeviceToDevice));
}
/*!
 * \brief Enqueues a copy of \p size elements (not bytes) of type T between two device buffers.
 * \param dst_ptr Device destination.
 * \param src_ptr Device source.
 * \param file, line Call site reported if enqueuing the copy fails.
 * \param stream Stream to enqueue the copy on. Defaults to 0 (the default stream), so
 *        existing five-argument calls keep their previous behavior.
 * \note Asynchronous with respect to the host: synchronize \p stream before relying on \p dst_ptr.
 */
template <typename T>
void CopyFromCUDADeviceToCUDADeviceAsync(T* dst_ptr, const T* src_ptr, size_t size, const char* file, const int line,
                                         cudaStream_t stream = 0) {
  void* void_dst_ptr = reinterpret_cast<void*>(dst_ptr);
  const void* void_src_ptr = reinterpret_cast<const void*>(src_ptr);
  size_t size_in_bytes = size * sizeof(T);
  CUDASUCCESS_OR_FATAL_OUTER(cudaMemcpyAsync(void_dst_ptr, void_src_ptr, size_in_bytes, cudaMemcpyDeviceToDevice, stream));
}
// Synchronizes the current device (defined out of line; presumably wraps
// cudaDeviceSynchronize — confirm). `file`/`line` identify the call site for error reports.
void SynchronizeCUDADevice(const char* file, const int line);
// Synchronizes `cuda_stream` (defined out of line; presumably wraps
// cudaStreamSynchronize — confirm). `file`/`line` identify the call site for error reports.
void SynchronizeCUDAStream(cudaStream_t cuda_stream, const char* file, const int line);
/*!
 * \brief Fills \p size elements of type T at \p dst_ptr with cudaMemset, then calls
 *        SynchronizeCUDADevice before returning.
 * \param value Byte value: cudaMemset sets every byte, not every element, so this only
 *        yields element value \p value for 0 or single-byte T.
 * \param file, line Call site reported on failure.
 */
template <typename T>
void SetCUDAMemory(T* dst_ptr, int value, size_t size, const char* file, const int line) {
  void* const target = static_cast<void*>(dst_ptr);
  const size_t num_bytes = sizeof(T) * size;
  CUDASUCCESS_OR_FATAL_OUTER(cudaMemset(target, value, num_bytes));
  // Presumably a device-wide wait so the fill is complete when this returns
  // (SynchronizeCUDADevice is defined out of line — confirm).
  SynchronizeCUDADevice(file, line);
}
/*!
 * \brief Frees the device buffer pointed to by *ptr (if any) and resets *ptr to nullptr.
 * \param ptr Address of the device pointer; a null *ptr makes this a no-op.
 * \param file, line Call site reported if cudaFree fails.
 */
template <typename T>
void DeallocateCUDAMemory(T** ptr, const char* file, const int line) {
  if (*ptr == nullptr) {
    return;  // nothing allocated
  }
  CUDASUCCESS_OR_FATAL_OUTER(cudaFree(static_cast<void*>(*ptr)));
  *ptr = nullptr;
}
// Reports the most recent CUDA runtime error (defined out of line; presumably
// based on cudaGetLastError — confirm).
void PrintLastCUDAError();
/*!
 * \brief Owning, resizable array of T stored in CUDA device memory.
 *
 * Copying performs a deep device-to-device copy and moving transfers ownership,
 * so each device buffer is freed exactly once (by DeallocateCUDAMemory).
 */
template <typename T>
class CUDAVector {
 public:
  /*! \brief Creates an empty vector that owns no device memory. */
  CUDAVector() : data_(nullptr), size_(0) {}

  /*!
   * \brief Allocates device memory for \p size elements.
   * \note Contents are uninitialized (plain cudaMalloc via AllocateCUDAMemory).
   */
  explicit CUDAVector(size_t size) : data_(nullptr), size_(size) {
    AllocateCUDAMemory<T>(&data_, size_, __FILE__, __LINE__);
  }

  /*! \brief Deep copy: allocates a new device buffer and copies \p other's elements into it. */
  CUDAVector(const CUDAVector<T>& other) : data_(nullptr), size_(0) {
    if (other.size_ > 0 && other.data_ != nullptr) {
      AllocateCUDAMemory<T>(&data_, other.size_, __FILE__, __LINE__);
      size_ = other.size_;
      CopyFromCUDADeviceToCUDADevice<T>(data_, other.data_, other.size_, __FILE__, __LINE__);
    }
  }

  /*! \brief Takes ownership of \p other's buffer and leaves \p other empty. */
  CUDAVector(CUDAVector<T>&& other) noexcept : data_(other.data_), size_(other.size_) {
    other.data_ = nullptr;
    other.size_ = 0;
  }

  /*! \brief Deep-copy assignment (copy-and-swap; self-assignment is a no-op). */
  CUDAVector<T>& operator=(const CUDAVector<T>& other) {
    if (this != &other) {
      CUDAVector<T> copy(other);
      SwapWith(&copy);
    }
    return *this;
  }

  /*!
   * \brief Move assignment: exchanges buffers with \p other, whose destructor (or Clear)
   *        then releases this vector's previous buffer.
   */
  CUDAVector<T>& operator=(CUDAVector<T>&& other) noexcept {
    SwapWith(&other);
    return *this;
  }

  /*!
   * \brief Changes the element count, preserving the first min(old, new) elements.
   * \note Reallocates whenever the size changes, so RawData() pointers are invalidated;
   *       new trailing elements are uninitialized. Resizing to 0 releases the buffer.
   */
  void Resize(size_t size) {
    if (size == size_) {
      return;
    }
    if (size == 0) {
      Clear();
      return;
    }
    T* new_data = nullptr;
    AllocateCUDAMemory<T>(&new_data, size, __FILE__, __LINE__);
    if (size_ > 0 && data_ != nullptr) {
      const size_t size_for_old_content = std::min<size_t>(size_, size);
      CopyFromCUDADeviceToCUDADevice<T>(new_data, data_, size_for_old_content, __FILE__, __LINE__);
    }
    DeallocateCUDAMemory<T>(&data_, __FILE__, __LINE__);
    data_ = new_data;
    size_ = size;
  }

  /*! \brief Resizes to host_vector.size() and copies its contents to the device. */
  void InitFromHostVector(const std::vector<T>& host_vector) {
    Resize(host_vector.size());
    if (!host_vector.empty()) {
      CopyFromHostToCUDADevice<T>(data_, host_vector.data(), host_vector.size(), __FILE__, __LINE__);
    }
  }

  /*! \brief Releases the device buffer (if any) and sets the size to 0. */
  void Clear() {
    // DeallocateCUDAMemory is a no-op for nullptr and resets data_ to nullptr.
    DeallocateCUDAMemory<T>(&data_, __FILE__, __LINE__);
    size_ = 0;
  }

  /*!
   * \brief Appends \p len elements read from \p values.
   * \param values Copied with cudaMemcpyDeviceToDevice, so it must point to device memory.
   * \note Reallocates the buffer, so previously obtained RawData() pointers become invalid.
   */
  void PushBack(const T* values, size_t len) {
    if (len == 0) {
      return;  // nothing to append; avoid a pointless reallocation
    }
    T* new_data = nullptr;
    AllocateCUDAMemory<T>(&new_data, size_ + len, __FILE__, __LINE__);
    if (size_ > 0 && data_ != nullptr) {
      CopyFromCUDADeviceToCUDADevice<T>(new_data, data_, size_, __FILE__, __LINE__);
    }
    // Copy the new elements before freeing data_, in case `values` points into it.
    CopyFromCUDADeviceToCUDADevice<T>(new_data + size_, values, len, __FILE__, __LINE__);
    DeallocateCUDAMemory<T>(&data_, __FILE__, __LINE__);
    size_ += len;
    data_ = new_data;
  }

  /*! \brief Number of elements. */
  size_t Size() const {
    return size_;
  }

  ~CUDAVector() {
    DeallocateCUDAMemory<T>(&data_, __FILE__, __LINE__);
  }

  /*! \brief Copies all elements back into a new host std::vector. */
  std::vector<T> ToHost() const {
    std::vector<T> host_vector(size_);
    if (size_ > 0 && data_ != nullptr) {
      CopyFromCUDADeviceToHost(host_vector.data(), data_, size_, __FILE__, __LINE__);
    }
    return host_vector;
  }

  /*! \brief Mutable device pointer (nullptr when empty). */
  T* RawData() const {
    return data_;
  }

  /*!
   * \brief Fills the buffer with cudaMemset via SetCUDAMemory.
   * \param value Byte value: every byte is set, so this only yields element value
   *        \p value for 0 or single-byte T.
   */
  void SetValue(int value) {
    SetCUDAMemory<T>(data_, value, size_, __FILE__, __LINE__);
  }

  /*! \brief Read-only device pointer (nullptr when empty). */
  const T* RawDataReadOnly() const {
    return data_;
  }

 private:
  /*! \brief Exchanges buffer pointer and size with \p other. */
  void SwapWith(CUDAVector<T>* other) noexcept {
    T* const other_data = other->data_;
    const size_t other_size = other->size_;
    other->data_ = data_;
    other->size_ = size_;
    data_ = other_data;
    size_ = other_size;
  }

  /*! \brief Device buffer owned by this vector, or nullptr. */
  T* data_;
  /*! \brief Number of elements in data_. */
  size_t size_;
};
/*!
 * \brief Device-side logarithm that returns -infinity instead of evaluating
 *        std::log for inputs that are not strictly positive (including NaN).
 */
template <typename T>
static __device__ T SafeLog(T x) {
  if (!(x > 0)) {
    return -INFINITY;
  }
  return std::log(x);
}
/*!
 * \brief Base class storing an NCCL communicator together with rank/device ids and
 *        the global number of data points, for use by derived classes.
 */
class NCCLInfo {
 public:
  NCCLInfo() {}

  /*!
   * \brief Virtual because this class is polymorphic (SetNCCLInfo is virtual), so a
   *        derived object deleted through an NCCLInfo* is destroyed completely.
   */
  virtual ~NCCLInfo() {}

  /*!
   * \brief Stores the given NCCL configuration in the protected members.
   * \param nccl_communicator Communicator handle; stored as-is, this class never destroys it.
   * \param nccl_gpu_rank Rank within the communicator (presumably global — confirm).
   * \param local_gpu_rank Rank of the GPU within this process/node (presumably — confirm).
   * \param gpu_device_id CUDA device id.
   * \param global_num_data Total number of data points across all ranks (presumably — confirm).
   */
  virtual void SetNCCLInfo(
    ncclComm_t nccl_communicator,
    int nccl_gpu_rank,
    int local_gpu_rank,
    int gpu_device_id,
    data_size_t global_num_data) {
    nccl_communicator_ = nccl_communicator;
    nccl_gpu_rank_ = nccl_gpu_rank;
    local_gpu_rank_ = local_gpu_rank;
    gpu_device_id_ = gpu_device_id;
    global_num_data_ = global_num_data;
  }

 protected:
  ncclComm_t nccl_communicator_ = nullptr;
  int nccl_gpu_rank_ = -1;
  int local_gpu_rank_ = -1;
  int gpu_device_id_ = -1;
  // Not set by SetNCCLInfo. NOTE(review): presumably assigned by derived classes — confirm.
  int num_gpu_in_node_ = 0;
  data_size_t global_num_data_ = 0;
};
// Creates and returns a CUDA stream (defined out of line; stream flags not visible here).
cudaStream_t CUDAStreamCreate();
// Destroys a stream obtained from CUDAStreamCreate (defined out of line).
void CUDAStreamDestroy(cudaStream_t cuda_stream);
// Begin/end an NCCL group of collective calls (defined out of line; presumably
// wrap ncclGroupStart/ncclGroupEnd — confirm).
void NCCLGroupStart();
void NCCLGroupEnd();
/*!
 * \brief Enqueues an NCCL all-reduce of \p count elements on \p stream.
 * \param send_buffer Device input buffer.
 * \param recv_buffer Device output buffer (may equal send_buffer for an in-place reduction).
 * \note Does not synchronize: the result is ready only after \p stream completes.
 *       A failing NCCL call exits the process via NCCLCHECK.
 */
template <typename T>
void NCCLAllReduce(const T* send_buffer, T* recv_buffer, size_t count, ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm, cudaStream_t stream) {
  const void* const src = static_cast<const void*>(send_buffer);
  void* const dst = static_cast<void*>(recv_buffer);
  NCCLCHECK(ncclAllReduce(src, dst, count, datatype, op, comm, stream));
}
/*!
 * \brief Blocking NCCL all-reduce of \p count elements: runs the collective on a
 *        temporary stream, waits for it, and destroys the stream.
 * \note Creates and destroys a stream on every call; prefer the overload taking a
 *       stream in frequently executed code.
 */
template <typename T>
void NCCLAllReduce(const T* send_buffer, T* recv_buffer, size_t count, ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm) {
  cudaStream_t temp_stream;
  CUDASUCCESS_OR_FATAL(cudaStreamCreate(&temp_stream));
  const void* const src = static_cast<const void*>(send_buffer);
  void* const dst = static_cast<void*>(recv_buffer);
  NCCLCHECK(ncclAllReduce(src, dst, count, datatype, op, comm, temp_stream));
  // Block until the collective finishes, then release the temporary stream.
  CUDASUCCESS_OR_FATAL(cudaStreamSynchronize(temp_stream));
  CUDASUCCESS_OR_FATAL(cudaStreamDestroy(temp_stream));
}
/*!
 * \brief All-reduces a single host scalar across the NCCL communicator using \p stream.
 *
 * The upload, the collective and the download are all enqueued on \p stream, and
 * the stream is synchronized before the result is returned. Correctness therefore
 * does not depend on implicit synchronization between \p stream and the default
 * stream, which does not occur for non-blocking streams or with per-thread default
 * streams.
 * \return The reduced value.
 */
template <typename T>
T NCCLAllReduce(T send_value, ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm, cudaStream_t stream) {
  CUDAVector<T> send_buffer(1);
  // Stream-ordered upload; send_value stays alive until the stream is synchronized below.
  CUDASUCCESS_OR_FATAL(cudaMemcpyAsync(send_buffer.RawData(), &send_value, sizeof(T),
                                       cudaMemcpyHostToDevice, stream));
  // In-place reduction: send and receive buffers are the same device element.
  NCCLAllReduce<T>(send_buffer.RawDataReadOnly(), send_buffer.RawData(), 1, datatype, op, comm, stream);
  T recv_value = 0;
  CopyFromCUDADeviceToHostAsync<T>(&recv_value, send_buffer.RawDataReadOnly(), 1, stream, __FILE__, __LINE__);
  // Wait for the download before recv_value is read and before send_buffer is freed.
  SynchronizeCUDAStream(stream, __FILE__, __LINE__);
  return recv_value;
}
/*!
 * \brief All-reduces a single host scalar across the NCCL communicator and returns
 *        the reduced value.
 *
 * Stages the value in a one-element device buffer, reduces it in place with the
 * blocking (temporary-stream) pointer overload, and copies the result back to the host.
 */
template <typename T>
T NCCLAllReduce(T send_value, ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm) {
  CUDAVector<T> device_scalar(1);
  T* const device_ptr = device_scalar.RawData();
  CopyFromHostToCUDADevice<T>(device_ptr, &send_value, 1, __FILE__, __LINE__);
  NCCLAllReduce<T>(device_scalar.RawDataReadOnly(), device_ptr, 1, datatype, op, comm);
  T reduced = 0;
  CopyFromCUDADeviceToHost<T>(&reduced, device_scalar.RawDataReadOnly(), 1, __FILE__, __LINE__);
  return reduced;
}
} // namespace LightGBM
#endif // USE_CUDA
#endif // LIGHTGBM_CUDA_CUDA_UTILS_H_