Full BF16 including layernorms by default (minimising number of BF16 atomics) #272
Conversation
I'm assuming this is not meant to merge as-is? Would it make sense to put most of this into dev/cuda and then cherry-pick the layernorm we want to use, as usual, into train_gpt2.cu?
train_gpt2.cu (Outdated)
__global__ void layernorm_backward_kernel3(Tdinp* dinp, Tparams* dweight, Tparams* dbias,
                                           const Tdout* dout, const Trest* inp, const Tparams* weight, const Trest* mean, const Trest* rstd,
                                           int B, int T, int C) {
    extern __shared__ float shared[]; // size = 2 * C
(2*C+1) * 4?
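For context: the kernel stores two C-sized float buffers plus a uint flag in shared memory, so the launch needs to request (2*C + 1) * 4 bytes of dynamic shared memory rather than 2*C * 4. A minimal sketch of the corrected sizing at the launch site (the grid/block values are illustrative, not taken from the PR):

// Sketch: dynamic shared memory sizing for layernorm_backward_kernel3.
// Layout: C floats for dbias, C floats for dweight, then one uint flag,
// so request (2*C + 1) * sizeof(float) bytes (sizeof(uint) == sizeof(float) == 4).
const int block_size = 512;   // illustrative launch config, not from the PR
const int grid_size = 32;     // illustrative launch config, not from the PR
size_t shared_mem_size = (2 * C + 1) * sizeof(float);
layernorm_backward_kernel3<<<grid_size, block_size, shared_mem_size>>>(
    dinp, dweight, dbias, dout, inp, weight, mean, rstd, B, T, C);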
train_gpt2.cu (Outdated)
        dbias_shared[i] = 0.0f;
        dweight_shared[i] = 0.0f;
    }
    uint *tmp_flag = (uint*)(shared + C*2);
is this actually used in this kernel?
train_gpt2.cu (Outdated)
        dbias_shared[i] = 0.0f;
        dweight_shared[i] = 0.0f;
    }
    uint *tmp_flag = (uint*)(shared + C*2);
same here
Thanks! Fixed those bugs, and only kept Kernel 6 in train_gpt2.cu, adding all the kernels to /dev/cuda/. I think this might be the first /dev/cuda change with BF16, so I added support for comparing CPU and GPU data of different types (converting GPU data to the CPU type before comparison) in validate_result().
dev/cuda/common.h (Outdated)
template<class D, class T>
void validate_result(D* device_result, const T* cpu_reference, const char* name, std::size_t num_elements, T tolerance=1e-4) {
    D* out_gpu = (D*)malloc(num_elements * sizeof(T));
    cudaCheck(cudaMemcpy(out_gpu, device_result, num_elements * sizeof(T), cudaMemcpyDeviceToHost));
sizeof(D)?
Good catch, fixed!
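For reference, here is a sketch of what the helper plausibly looks like after that fix, with the host buffer sized by the device type D and each element converted to the CPU reference type T before comparing; the exact error reporting and exit behavior are assumptions, not copied from the PR:

// Sketch of the fixed mixed-type validate_result(); error handling is assumed.
// Requires <cstdio>, <cstdlib>, <cmath> and the cudaCheck macro from common.h.
template<class D, class T>
void validate_result(D* device_result, const T* cpu_reference, const char* name,
                     std::size_t num_elements, T tolerance = 1e-4) {
    D* out_gpu = (D*)malloc(num_elements * sizeof(D));  // sizeof(D), not sizeof(T)
    cudaCheck(cudaMemcpy(out_gpu, device_result, num_elements * sizeof(D),
                         cudaMemcpyDeviceToHost));
    for (std::size_t i = 0; i < num_elements; i++) {
        // convert the GPU value to the CPU reference type before comparing,
        // e.g. __nv_bfloat16 results checked against float references
        T gpu_val = (T)out_gpu[i];
        if (fabs((double)gpu_val - (double)cpu_reference[i]) > (double)tolerance) {
            printf("Mismatch of %s at %zu: CPU_ref %f vs GPU %f\n",
                   name, i, (double)cpu_reference[i], (double)gpu_val);
            free(out_gpu);
            exit(EXIT_FAILURE);
        }
    }
    free(out_gpu);
}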
I added 4 different new versions of layernorm_backward_kernel; Kernel 6 performs best.
We probably just want to integrate Kernel 6, but we might want to add all of them to /dev/cuda/ in the future. I haven't fixed the BF16 atomics in encoder_backward yet; the performance penalty there is much smaller, but it might still be worth doing manually using atomicCAS so we can implement stochastic rounding there as well.
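To illustrate the atomicCAS idea (a sketch of the general technique, not code from this PR; the helper names and the source of random bits are hypothetical): CUDA has no native atomicAdd on a single __nv_bfloat16, so one can CAS the enclosing aligned 32-bit word and apply stochastic rounding when narrowing the float sum back to BF16.

// Sketch: BF16 atomic add via atomicCAS, with stochastic rounding.
// Hypothetical helpers; the caller supplies 16 random bits per update.
#include <cuda_bf16.h>
#include <cstdint>

// Stochastically round a float to BF16: add random bits to the 16 bits
// that truncation would discard, then keep only the top 16 bits.
__device__ __nv_bfloat16 stochastic_round_bf16(float x, unsigned int rnd) {
    unsigned int bits = __float_as_uint(x);
    bits += (rnd & 0xFFFFu);     // carries up with probability = discarded fraction
    bits &= 0xFFFF0000u;         // truncate to BF16 precision
    return __float2bfloat16_rn(__uint_as_float(bits)); // exact: value is representable
}

// Atomically add a float into a BF16 location by CAS-ing the aligned
// 32-bit word that contains it.
__device__ void atomic_add_bf16_stochastic(__nv_bfloat16* addr, float val, unsigned int rnd) {
    unsigned int* base = (unsigned int*)((uintptr_t)addr & ~(uintptr_t)3);
    bool high_half = ((uintptr_t)addr & 2) != 0;
    unsigned int old_word = *base, assumed;
    do {
        assumed = old_word;
        unsigned short cur = high_half ? (unsigned short)(assumed >> 16)
                                       : (unsigned short)(assumed & 0xFFFFu);
        float sum = __bfloat162float(__ushort_as_bfloat16(cur)) + val;
        unsigned short upd = __bfloat16_as_ushort(stochastic_round_bf16(sum, rnd));
        unsigned int new_word = high_half
            ? (assumed & 0x0000FFFFu) | ((unsigned int)upd << 16)
            : (assumed & 0xFFFF0000u) | upd;
        old_word = atomicCAS(base, assumed, new_word);
    } while (old_word != assumed);  // retry if another thread changed the word
}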
Performance on my RTX 4090 is a tiny bit faster (potentially noise) than with the previous mixed FP32/BF16.