gradient clipping by global norm #315

Merged · 6 commits · May 19, 2024
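Gradient clipping by global norm treats all parameter gradients as one vector g and rescales it whenever its L2 norm exceeds a threshold c; the CUDA kernels added below compute the squared norm that this rule needs. Stated here for context (standard formulation, not copied from the PR description):

$$ g \leftarrow g \cdot \min\left(1, \frac{c}{\lVert g \rVert_2}\right) $$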
3 changes: 2 additions & 1 deletion dev/cuda/Makefile
@@ -18,7 +18,7 @@ MPI_PATHS = -I/usr/lib/x86_64-linux-gnu/openmpi/include -L/usr/lib/x86_64-linux-
$(NVCC) $(CFLAGS) $(NVCCFLAGS) $< -o $@

# Build all targets
TARGETS = adamw attention_backward attention_forward classifier_fused crossentropy_forward crossentropy_softmax_backward encoder_backward encoder_forward gelu_backward gelu_forward layernorm_backward layernorm_forward matmul_backward matmul_backward_bias matmul_forward nccl_all_reduce residual_forward softmax_forward trimat_forward fused_residual_forward
TARGETS = adamw attention_backward attention_forward classifier_fused crossentropy_forward crossentropy_softmax_backward encoder_backward encoder_forward gelu_backward gelu_forward layernorm_backward layernorm_forward matmul_backward matmul_backward_bias matmul_forward nccl_all_reduce residual_forward softmax_forward trimat_forward fused_residual_forward global_norm
all: $(TARGETS)

# Individual targets: forward pass
@@ -48,6 +48,7 @@ matmul_backward: matmul_backward.cu

# Update kernels
adamw: adamw.cu
global_norm: global_norm.cu

# NCCL communication kernels
nccl_all_reduce: nccl_all_reduce.cu
181 changes: 181 additions & 0 deletions dev/cuda/global_norm.cu
@@ -0,0 +1,181 @@
/*
Kernels for a global norm.
Global norm in this context means that we want to calculate a single norm cooperatively using all available SMs, instead
of multiple norms that can be handled by separate blocks.

Compile example:
nvcc -O3 --use_fast_math global_norm.cu -o global_norm
*/


#include <assert.h>
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>

// turn on bf16 as default, done up here for now
#define ENABLE_BF16
#include "common.h"


float global_norm_cpu(const float* data, size_t count) {
// accumulate in double so we have an accurate numerical reference; note this returns the squared L2 norm (no sqrt), matching the GPU kernels below
double acc = 0.0;
for(size_t i = 0; i < count; ++i) {
acc += (double)data[i] * (double)data[i];
}
return (float)acc;
}


template<class T>
__global__ void norm_kernel1(float* out, const T* data, size_t count) {
// we want as few atomics as possible, so each block tries to do
// the maximum amount of work (so no fixed chunk, but instead iterating
// until we run out of data), and then we reduce inside the block
// and finally have just one atomic per block.
namespace cg = cooperative_groups;
cg::thread_block block = cg::this_thread_block();
cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block);

__shared__ float block_result[32];

// out will be updated atomically from all thread blocks
size_t index = threadIdx.x + blockDim.x * blockIdx.x;
size_t grid_width = blockDim.x * gridDim.x;
float accumulator = 0.f;
for(size_t i = index; i < count; i += grid_width) {
accumulator += (float)data[i] * (float)data[i];
}
// warp-level reduce
float warp_result = cg::reduce(warp, accumulator, cg::plus<float>{});
block_result[warp.meta_group_rank()] = warp_result;
block.sync();
if(warp.meta_group_rank() == 0) {
float gather = warp.thread_rank() < warp.meta_group_size() ? block_result[warp.thread_rank()] : 0.f;
float block_sum = cg::reduce(warp, gather, cg::plus<float>{});
if(warp.thread_rank() == 0) {
atomicAdd(out, block_sum);
}
}
}



template<class T>
__global__ void norm_kernel2(float* out, const T* data, size_t count) {
// no shared memory; but one atomic per warp instead of per block
namespace cg = cooperative_groups;
cg::thread_block block = cg::this_thread_block();
cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block);

// out will be updated atomically from all thread blocks
size_t index = threadIdx.x + blockDim.x * blockIdx.x;
size_t grid_width = blockDim.x * gridDim.x;
float accumulator = 0.f;
for(size_t i = index; i < count; i += grid_width) {
accumulator += (float)data[i] * (float)data[i];
}

// warp-level reduce
float warp_result = cg::reduce(warp, accumulator, cg::plus<float>{});
// and atomic in global buffer
if(warp.thread_rank() == 0) {
atomicAdd(out, warp_result);
}
}



template<typename T>
void global_norm1(float* out, const T* values, size_t count, int block_size) {
// launch just enough blocks to fill the grid. deliberately no DIV_CEIL.
// having one block fewer than possible is a tiny performance hit, but having
// one block too many is catastrophic, since it can only start once all the other
// blocks finish. anyway, I think cuda_threads_per_SM should be a multiple of 512
// on all gpus, so the division really is going to be exact.
const int grid_size = cuda_threads_per_SM * cuda_num_SMs / block_size;
assert(grid_size > 0); // gives a better error than letting the call below fail
norm_kernel1<<<grid_size, block_size>>>(out, values, count);
cudaCheck(cudaGetLastError());
}

template<typename T>
void global_norm2(float* out, const T* values, size_t count, int block_size) {
// ditto
const int grid_size = cuda_threads_per_SM * cuda_num_SMs / block_size;
assert(grid_size > 0); // gives a better error than letting the call below fail
norm_kernel2<<<grid_size, block_size>>>(out, values, count);
cudaCheck(cudaGetLastError());
}

void global_norm(int kernel_num, float* out, const floatX* values, size_t count, int block_size) {
switch (kernel_num) {
case 1:
return global_norm1(out, values, count, block_size);
case 2:
return global_norm2(out, values, count, block_size);
}
}

int main(int argc, const char **argv) {
setup_main();

int C = 768;
int L = 12;

size_t num_params = (size_t)(C * 4*C + C*C) * 2 * L;

// create host memory of random numbers
float* inp = make_random_float(num_params);
// scale them down
for(size_t i = 0; i < num_params; ++i) {
inp[i] *= 1e-3;
}

// read kernel_num from command line
int kernel_num = 1;
if (argc > 1) {
kernel_num = atoi(argv[1]);
}
printf("Using kernel %d\n", kernel_num);

// first check the correctness of the kernel
float out = global_norm_cpu(inp, num_params);

// move to GPU
float* d_out;
floatX* d_inp;
cudaCheck(cudaMalloc(&d_out, sizeof(float)));
cudaCheck(cudaMalloc(&d_inp, num_params * sizeof(floatX)));
cudaCheck(memcpy_convert(d_inp, inp, num_params));

int block_sizes[] = {32, 64, 128, 256, 512, 768, 1024};
for (int j = 0; j < sizeof(block_sizes) / sizeof(int); j++) {
int block_size = block_sizes[j];
printf("Checking block size %d.\n", block_size);
cudaCheck(cudaMemset(d_out, 0, sizeof(float)));
global_norm(kernel_num, d_out, d_inp, num_params, block_size);
validate_result(d_out, &out, "out", 1, 1e-2f);
}

printf("All results match. Starting benchmarks.\n\n");

for (int j = 0; j < sizeof(block_sizes) / sizeof(int); j++) {
int block_size = block_sizes[j];

int repeat_times = 1000;

float elapsed_time = benchmark_kernel(repeat_times, global_norm,
kernel_num, d_out, d_inp,
num_params, block_size);
size_t memory_ops = num_params * sizeof(floatX);
float memory_bandwidth = memory_ops / elapsed_time / 1e6;

printf("block_size %4d | time %.4f ms | bandwidth %.2f GB/s\n", block_size, elapsed_time, memory_bandwidth);
}

// free memory
free(inp);
cudaCheck(cudaFree(d_out));
cudaCheck(cudaFree(d_inp));
}
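The kernels in global_norm.cu accumulate the squared L2 norm of the gradients into a single float; the clipping step itself then only needs a square root and an elementwise scale. Below is a minimal CPU-side sketch of that step; the names (clip_gradients_cpu, grad_clip) and the 1e-6f epsilon are illustrative assumptions, not the PR's actual API. In the training code, the 1.f newly threaded through gpt2_update in profile_gpt2.cu and test_gpt2.cu below presumably plays the role of the clip threshold.

```c
// Sketch of the clipping step that would consume the squared norm
// computed by the kernels above. Names and the epsilon are illustrative only.
#include <math.h>
#include <stddef.h>
#include <stdio.h>

void clip_gradients_cpu(float* grads, size_t count, float norm_squared, float grad_clip) {
    float norm = sqrtf(norm_squared);              // kernels accumulate the *squared* norm
    if (norm > grad_clip) {
        float scale = grad_clip / (norm + 1e-6f);  // small eps guards against division by zero
        for (size_t i = 0; i < count; ++i) {
            grads[i] *= scale;
        }
    }
}

int main(void) {
    float grads[4] = {3.0f, 4.0f, 0.0f, 0.0f};     // global norm = 5
    float norm_sq = 0.0f;
    for (size_t i = 0; i < 4; ++i) { norm_sq += grads[i] * grads[i]; }
    clip_gradients_cpu(grads, 4, norm_sq, 1.0f);   // clip so the norm becomes ~1
    printf("clipped[0] = %f (expected ~0.6)\n", grads[0]);
    return 0;
}
```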
2 changes: 1 addition & 1 deletion profile_gpt2.cu
@@ -54,7 +54,7 @@ int main(int argc, char *argv[]) {
gpt2_forward(&model, x, y, B, T);
gpt2_zero_grad(&model);
gpt2_backward(&model);
gpt2_update(&model, 1e-4f, 0.9f, 0.999f, 1e-8f, 0.0f, 1, &multi_gpu_config);
gpt2_update(&model, 1e-4f, 0.9f, 0.999f, 1e-8f, 0.0f, 1.f, 1, &multi_gpu_config);
cudaCheck(cudaDeviceSynchronize()); // finish all CUDA work to get correct precise timings

// free
2 changes: 1 addition & 1 deletion profile_gpt2cu.py
@@ -130,7 +130,7 @@
# the classifier part, counts only once
pass_name = "cls"
phase = "bwd"
elif "adamw" in kernel:
elif "adamw" in kernel or "global_norm" in kernel:
# encoder layer or adam
pass_name = "opt"
# before the first optimizer run, we create weight copies.
22 changes: 11 additions & 11 deletions test_gpt2.cu
@@ -274,7 +274,7 @@ int main(int argc, char *argv[]) {
allok = allok & check_tensor(tensors1[15], tensors2[15], C, "lnfb", 2e-2f);
}

gpt2_update(&model, 1e-4f, 0.9f, 0.999f, 1e-8f, 0.01f, step+1, &multi_gpu_config);
gpt2_update(&model, 1e-4f, 0.9f, 0.999f, 1e-8f, 0.01f, 1.f, step+1, &multi_gpu_config);

// print the timing information at the end
printf("step %d: loss %f (took %f ms)\n", step+1, model.mean_loss, time_elapsed_s * 1000);
@@ -283,16 +283,16 @@

// expected losses are as follows, from Python
float expected_losses[10] = {
5.270007133483887,
4.059706687927246,
3.3751230239868164,
2.8007826805114746,
2.315382242202759,
1.8490285873413086,
1.3946564197540283,
0.9991465210914612,
0.6240804195404053,
0.37651097774505615
5.2700,
4.0607,
3.3166,
2.7115,
2.1702,
1.6349,
1.1419,
0.7038,
0.3769,
0.1743
};

// compare