Zero 2 #593

Open
wants to merge 17 commits into master
19 changes: 18 additions & 1 deletion llmc/cuda_utils.cuh
@@ -136,7 +136,7 @@ __global__ void copy_and_cast_kernel(Td* dst, const Ts* src, size_t n, ptrdiff_t
int idx = blockIdx.x * blockDim.x + threadIdx.x;
// need to try grid stride looping for more perf later
if (idx < n) {
- dst[idx + stride_dst * blockIdx.y] = cast_value<Td, Ts>(src[idx + stride_src * blockIdx.y]);
+ dst[idx + stride_dst * blockIdx.y] = cast_value<Td>(src[idx + stride_src * blockIdx.y]);
}
}
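For context on the one-line change above: cast_value takes the destination and source types as template parameters, and dropping the explicit Ts lets it be deduced from the argument. A minimal sketch of that pattern follows; the helper below is a stand-in for illustration, not the repository's actual definition.

#include <cstdio>

// Stand-in for a cast helper with destination and source template parameters.
template<typename Td, typename Ts>
Td cast_value(Ts val) {
    return (Td)val;   // a real implementation may round or use type-specific intrinsics
}

int main() {
    double x = 3.75;
    // Ts is deduced from the argument, so only the destination type must be spelled out:
    float a = cast_value<float>(x);           // Td = float, Ts deduced as double
    float b = cast_value<float, double>(x);   // equivalent, with Ts written explicitly
    printf("%f %f\n", a, b);
    return 0;
}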

@@ -260,4 +260,21 @@ __device__ __forceinline__ void stochastic_rounding(float in, float *out, unsign
*out = in; // dummy function for when floatX is float (FP32 mode)
}

// Add two (potentially low-precision) vectors of size `n` together using stochastic rounding
template<class T>
__global__ void vector_add(T* dst, const T* src, size_t n, unsigned seed) {
using t128 = Packed128<T>;
assert(n % t128::size == 0);
ptrdiff_t idx = ((ptrdiff_t)blockIdx.x * blockDim.x + threadIdx.x) * t128::size;
if (idx < n) {
t128 src_v = load128cs(src + idx);
t128 dst_v = load128cs(dst + idx);
for(int k = 0; k < t128::size; ++k) {
float sum = (float)dst_v[k] + (float)src_v[k];
stochastic_rounding(sum, &dst_v[k], seed + idx);
}
store128cs(dst + idx, dst_v);
}
}

Contributor: what benefit do we have from stochastic rounding here?

Contributor Author: if you do lots of gradient accumulation, you will incur more and more error, because you end up adding small new gradients to the buffer of large accumulated gradients. With stochastic rounding, we at least stay correct in expectation, and will not systematically ignore small changes.
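To make the author's point concrete, here is a small, self-contained host-side sketch (not part of the PR) that emulates bf16 rounding with bit masking; the helper names and constants are made up for illustration, and it handles positive values only. Deterministic round-to-nearest gets stuck when a small gradient is repeatedly added to a large accumulator, while stochastic rounding stays correct in expectation.

#include <cstdint>
#include <cstdio>
#include <cstring>
#include <random>

// Round-to-nearest-even to bf16 precision (keep only the top 16 bits of the float).
static float to_bf16_rn(float x) {
    uint32_t u; std::memcpy(&u, &x, 4);
    uint32_t lsb = (u >> 16) & 1u;
    u += 0x7FFFu + lsb;
    u &= 0xFFFF0000u;
    std::memcpy(&x, &u, 4);
    return x;
}

// Stochastic rounding: add random bits below the cut, then truncate (positive inputs only).
static float to_bf16_stochastic(float x, uint32_t rnd) {
    uint32_t u; std::memcpy(&u, &x, 4);
    u += (rnd & 0xFFFFu);
    u &= 0xFFFF0000u;
    std::memcpy(&x, &u, 4);
    return x;
}

int main() {
    std::mt19937 gen(42);
    float acc_rn = 1024.0f, acc_sr = 1024.0f;   // large accumulated gradient
    const float tiny = 0.001f;                  // small new gradient contribution
    const int steps = 100000;
    for (int i = 0; i < steps; ++i) {
        acc_rn = to_bf16_rn(acc_rn + tiny);                 // deterministic rounding
        acc_sr = to_bf16_stochastic(acc_sr + tiny, gen());  // stochastic rounding
    }
    // Exact answer is about 1024 + 100 = 1124. Round-to-nearest stays stuck at 1024,
    // because 1024 + 0.001 rounds back to 1024 in bf16 (the ulp there is 8); stochastic
    // rounding is unbiased and lands near the right value on average.
    printf("exact: %.1f  round-to-nearest: %.1f  stochastic: %.1f\n",
           1024.0f + steps * tiny, acc_rn, acc_sr);
    return 0;
}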

#endif
54 changes: 41 additions & 13 deletions llmc/zero.cuh
@@ -506,27 +506,40 @@ ShardInfo multi_gpu_get_shard_offset(size_t elements, const MultiGpuConfig* conf
}
}

// Block the NCCL stream until computations on compute_stream are done.
void nccl_wait_on_compute(MultiGpuConfig* config, cudaStream_t compute_stream) {
// mark an event on the compute stream, and immediately wait on this in the nccl stream
// this means that the nccl stream won't start executing before all compute kernels that
// have been submitted before this point have finished.
// by using an event instead of cudaStreamSynchronize, we avoid having to synchronize the host, and
// can enqueue new work to the GPU right away.
#ifdef MULTI_GPU
cudaCheck(cudaEventRecord(config->compute_nccl_sync, compute_stream));
cudaCheck(cudaStreamWaitEvent(config->nccl_stream, config->compute_nccl_sync));
#endif
}

void compute_wait_on_nccl(MultiGpuConfig* config, cudaStream_t compute_stream) {
// mark an event on the nccl stream, and immediately wait on this in the compute stream
#ifdef MULTI_GPU
cudaCheck(cudaEventRecord(config->compute_nccl_sync, config->nccl_stream));
cudaCheck(cudaStreamWaitEvent(compute_stream, config->compute_nccl_sync));
#endif
}
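For readers less familiar with this event-based pattern, here is a minimal, self-contained sketch (not from the PR) showing how one stream is made to wait on another without blocking the host; the produce/consume kernels and the stream names are illustrative stand-ins for llm.c's compute kernels and NCCL stream, and error checking is omitted for brevity.

#include <cuda_runtime.h>
#include <cstdio>

// Illustrative kernels standing in for "compute" and "communication" work.
__global__ void produce(float* buf, int n) { int i = blockIdx.x * blockDim.x + threadIdx.x; if (i < n) buf[i] = 1.0f; }
__global__ void consume(float* buf, int n) { int i = blockIdx.x * blockDim.x + threadIdx.x; if (i < n) buf[i] *= 2.0f; }

int main() {
    const int n = 1 << 20;
    float* buf;
    cudaMalloc(&buf, n * sizeof(float));

    cudaStream_t compute_stream, side_stream;   // side_stream plays the role of the NCCL stream
    cudaEvent_t sync_event;
    cudaStreamCreate(&compute_stream);
    cudaStreamCreate(&side_stream);
    cudaEventCreateWithFlags(&sync_event, cudaEventDisableTiming);

    // Work submitted to the compute stream.
    produce<<<(n + 255) / 256, 256, 0, compute_stream>>>(buf, n);

    // The "nccl_wait_on_compute" idea: record an event on the compute stream and make the
    // side stream wait on it, without synchronizing the host.
    cudaEventRecord(sync_event, compute_stream);
    cudaStreamWaitEvent(side_stream, sync_event, 0);

    // This kernel on the side stream will not start before produce() has finished.
    consume<<<(n + 255) / 256, 256, 0, side_stream>>>(buf, n);

    // The "compute_wait_on_nccl" idea: the reverse dependency, back onto the compute stream.
    cudaEventRecord(sync_event, side_stream);
    cudaStreamWaitEvent(compute_stream, sync_event, 0);

    cudaStreamSynchronize(compute_stream);
    cudaStreamSynchronize(side_stream);
    printf("done\n");
    cudaFree(buf);
    return 0;
}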

// Aggregate multiple pointers in an NCCL group.
// This can work either as an all-reduce (i.e., no ZeRO), or a reduce-scatter (ZeRO 1/2).
// The awkward `(&pointers)[N]` syntax ensures we are capturing the parameters as sized arrays, so that it becomes impossible
// to call this function if pointers and pointers_sizes do not match.
template<int N>
void multi_gpu_async_reduce_gradient(
floatX* const (&pointers)[N], const size_t (&pointers_sizes)[N],
- MultiGpuConfig* config, cudaStream_t compute_stream) {
+ MultiGpuConfig* config) {
if (config->num_processes == 1) {
return; // no multi-GPU, just exit.
}

#ifdef MULTI_GPU
NVTX_RANGE_FN();
- // mark an event on the compute stream, and immediately wait on this in the nccl stream
- // this means that the nccl stream won't start executing before all compute kernels that
- // have been submitted before this point have finished.
- // by using an event instead of cudaSyncStream, we avoid having to synchronize the host, and
- // can enqueue new work to the GPU right away.
- cudaCheck(cudaEventRecord(config->compute_nccl_sync, compute_stream));
- cudaCheck(cudaStreamWaitEvent(config->nccl_stream, config->compute_nccl_sync));
ncclCheck(ncclGroupStart()); // NCCL group: aggregate all pointers in a single NCCL GPU kernel.
for (int i = 0; i < N; ++i) {
if(config->zero_stage == 0) {
@@ -536,7 +549,7 @@ void multi_gpu_async_reduce_gradient(
ncclFloatX, ncclAvg,
config->nccl_comm, config->nccl_stream
));
- } else if(config->zero_stage == 1) {
+ } else if(config->zero_stage == 1 || config->zero_stage == 2) {
assert(pointers_sizes[i] % config->num_processes == 0);
size_t shard_size = pointers_sizes[i] / config->num_processes;
ptrdiff_t shard_offset = (ptrdiff_t)shard_size * config->process_rank;
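The reduce-scatter call itself is collapsed in the diff above. For orientation only, this is the general shape of an in-place ncclReduceScatter over these shard variables; it is a sketch of the NCCL API, not necessarily the PR's exact call.

// Each rank ends up with the average of its own shard, written in place at its offset.
// NCCL's in-place convention requires recvbuff == sendbuff + rank * recvcount,
// which is exactly what shard_offset = shard_size * process_rank provides.
ncclCheck(ncclReduceScatter(
    pointers[i],                   // send: the full-size gradient buffer
    pointers[i] + shard_offset,    // recv: this rank's shard of the reduced result
    shard_size, ncclFloatX, ncclAvg,
    config->nccl_comm, config->nccl_stream));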
@@ -562,18 +575,18 @@ void set_zero_configs(MultiGpuConfig* config, int zero_stage, size_t total_param
if (zero_stage == 0) {
printf0("| Zero Optimization is disabled |\n");
}
- else if (zero_stage == 1) {
+ else if (zero_stage == 1 || zero_stage == 2) {
if (total_parameters % config->num_processes != 0) {
printf0("| Zero Optimization is disabled, Can't equally partition parameters |\n");
config->zero_stage = 0;
}
else {
- config->zero_stage = 1;
+ config->zero_stage = zero_stage;
config->shard_num_parameters = total_parameters / config->num_processes;
}
}
else{
printf0("| Disabling Zero Optimization, Zero Stage2 and Stage3 are not yet supported |\n");
printf0("| Disabling Zero Optimization, Zero Stage3 is not yet supported |\n");
config->zero_stage = 0;
}
}
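A quick worked example of the sharding arithmetic above, with purely hypothetical numbers that are not from the PR:

#include <stdio.h>
int main(void) {
    // hypothetical values, purely illustrative
    size_t total_parameters = 1200000;
    int num_processes = 8, zero_stage = 2;
    if ((zero_stage == 1 || zero_stage == 2) && total_parameters % num_processes == 0) {
        // each rank owns an equal shard of the optimizer state
        printf("shard_num_parameters per rank = %zu\n", total_parameters / num_processes); // 150000
    } else {
        printf("falling back to zero_stage = 0\n");  // non-divisible counts or stage 3
    }
    return 0;
}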
@@ -593,5 +606,20 @@ float multi_gpu_cpu_float_sum(float value, MultiGpuConfig* config) {
#endif
}

template<int N>
void zero2_accumulate_grad(floatX* const (&dst)[N], floatX* const (&src)[N], const size_t (&nelem)[N], int layer, unsigned seed, MultiGpuConfig* config) {
#ifdef MULTI_GPU
cudaStream_t stream = config->nccl_stream;
for(int i = 0; i < N; ++i) {
size_t n = nelem[i] / multi_gpu_config.num_processes;
vector_add<<<CEIL_DIV(n, 512), 512, 0, stream>>>(dst[i] + layer * n,
src[i] + multi_gpu_config.process_rank * n,
n, seed + i);
cudaCheck(cudaGetLastError());
cudaCheck(cudaMemsetAsync(src[i], 0, nelem[i] * sizeof(floatX), stream));
}
#endif
}
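To make the indexing in zero2_accumulate_grad concrete, here is a CPU-only, self-contained illustration with made-up sizes; it is not the PR's code. It mirrors the layout implied by dst[i] + layer * n (one shard slot per layer in the accumulation buffer), src[i] + process_rank * n (this rank's shard of the freshly reduced gradients), and the trailing memset of the scratch buffer.

#include <stdio.h>
#include <stdlib.h>

int main(void) {
    const int num_processes = 4, process_rank = 2;   // hypothetical values
    const int num_layers = 3;
    const size_t nelem = 16;                         // full gradient size, divisible by num_processes
    const size_t n = nelem / num_processes;          // shard size = 4

    float *src = calloc(nelem, sizeof(float));            // freshly reduced gradients (scratch)
    float *dst = calloc(num_layers * n, sizeof(float));   // per-layer shard accumulator

    for (int layer = 0; layer < num_layers; ++layer) {
        for (size_t j = 0; j < nelem; ++j) src[j] = (float)(layer + 1);  // pretend reduction result
        // accumulate this rank's shard of `src` into the slot for `layer` in `dst`
        for (size_t k = 0; k < n; ++k)
            dst[layer * n + k] += src[process_rank * n + k];
        for (size_t j = 0; j < nelem; ++j) src[j] = 0.0f;   // mirrors the cudaMemsetAsync
    }
    for (int layer = 0; layer < num_layers; ++layer)
        printf("layer %d shard starts at dst[%zu] = %.0f\n", layer, (size_t)layer * n, dst[layer * n]);
    free(src); free(dst);
    return 0;
}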

#endif
