Optimised version of fused classifier + bugfixes(?) #150
Conversation
oops, this PR now conflicts because I merged the other one. Sounds good, agree it is ok to skip += here, but I think it should come with a comment on this subtle issue. Please add a brief one, or I'm also ok adding it after.
dev/cuda/classifier_fused.cu
Outdated
@@ -169,6 +259,13 @@ void fused_classifier(int kernel_num, float* dlogits, float* losses,
        case 1:
            fused_classifier1(dlogits, losses, logits, dlosses, targets, B, T, V, block_size);
            break;
        case 2:
            if((V % 4) != 0) {
I'd move this inside fused_classifer2
at the top?
Removed it after ngc92's changes that split V and P (padded V) so it's implicitly guaranteed by the function declaration.
Force-pushed 1cb5cf8 to b225501
// so even small L2s get some hits on the 2nd read of the same thread
for (int i = (V/4) + (threadIdx.x - blockDim.x); i >= 0; i -= blockDim.x) {
    float4 v = x_vec4[i];
    #pragma unroll
I think this is not quite correct. While we assume V to be a multiple of 4, so that addresses are nicely aligned, the padding part still contains garbage, so we shouldn't take it into account when calculating the sum and max.
Agreed, added a simple bounds check (even if we could theoretically memset it once and make sure nothing modifies it, not worth the risk of bugs...)
Should be ready to be merged again now (at least to /dev/cuda/)! I made it more flexible so it can now also work as a forward-only option: pass a pointer to "probs" and make "dlogits" NULL instead. Combined with the bounds checking, this makes the code a bit longer than I'd like, with a bit more overhead in terms of GPU instructions, but given it's massively DRAM-limited, it seems worth it to have only one function doing everything?
This is a faster version of the cool new kernel from #117 (still /dev/cuda/ only). The biggest difference is that it is optimised to do one row per 1024-wide block rather than per 32-wide warp, which massively reduces the amount of data in flight. This lets the 2nd read basically always hit in the L2 cache, saving a huge amount of DRAM bandwidth. It also uses 128-bit loads, which implicitly requires V to be a multiple of 4.
However, I believe the validation in classifier_fused.cu is currently not working correctly, possibly because the distribution of inputs results in all probabilities being ~zero, which is something I think ngc92 has a fix for.
Finally, I changed both the old and new kernels to write to the gradient tensor rather than modify it in place (read-modify-write), which would cost us an extra DRAM access. I don't think accumulating gradients is required for this part of the network. All performance numbers below are with this change applied, including for Kernel1.
Performance:
Fused Kernel1 (old, block size invariant): ~6ms
Fused Kernel2 (new, 1024 threads per block): ~3.7ms
Forward Softmax Kernel7 (512 threads per block): 5.7ms (current default used in train_gpt2.cu)
Forward Softmax Kernel7 (1024 threads per block): ~3.7ms (optimal block size, got it wrong in my original PR for kernel7)