
Optimised version of fused classifier + bugfixes(?) #150

Merged
merged 6 commits into from Apr 18, 2024

Conversation


@ademeure ademeure commented Apr 16, 2024

This is a faster version of the cool new kernel from #117 (still /dev/cuda/ only). The biggest difference is that it is optimised to process one row per 1024-wide block rather than one row per 32-wide warp, which massively reduces the amount of data in flight. That lets the 2nd read of each row basically always hit in the L2 cache, saving a huge amount of DRAM bandwidth. It also uses 128-bit loads, which implicitly requires V to be a multiple of 4.

However, I believe the validation in classifier_fused.cu is currently not working correctly, possibly because the distribution of inputs results in all probabilities being ~zero; I think ngc92 has a fix for this.

Finally, I changed both the old and new kernels to write to the gradient tensor rather than modify it (read-modify-write), which would cost us an extra DRAM access. I don't think accumulating gradients is required for this part of the network. All performance numbers below include this change, including for Kernel 1.

Performance:
- Fused Kernel1 (old, block size invariant): ~6ms
- Fused Kernel2 (new, 1024 threads per block): ~3.7ms
- Forward Softmax Kernel7 (512 threads per block): ~5.7ms (current default used in train_gpt2.cu)
- Forward Softmax Kernel7 (1024 threads per block): ~3.7ms (optimal block size; I got it wrong in my original PR for kernel7)

@karpathy
Owner

Oops, this PR now conflicts because I merged the other one. Sounds good; I agree it is ok to skip += here, but I think it should come with a comment on this subtle issue. Please add a brief one, or I'm also ok adding it after merging.

@@ -169,6 +259,13 @@ void fused_classifier(int kernel_num, float* dlogits, float* losses,
        case 1:
            fused_classifier1(dlogits, losses, logits, dlosses, targets, B, T, V, block_size);
            break;
        case 2:
            if((V % 4) != 0) {
@karpathy (Owner):

I'd move this inside fused_classifier2 at the top?

@ademeure (Contributor, Author):

Removed it after ngc92's changes that split V and P (padded V), so it's implicitly guaranteed by the function declaration.
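The V/P split being referred to can be illustrated with a small helper along these lines (the name is hypothetical, not the upstream code): rounding the vocab size up to a multiple of 4 means every row can be read with 128-bit (float4) loads, so no runtime `V % 4` check is needed, at the cost of padding elements in `[V, P)`:

```c
// Illustrative helper: round V up to the next multiple of 4 so rows can be
// read as float4 vectors; indices in [V, P) are padding, not real logits.
int padded_vocab(int V) {
    return (V + 3) & ~3;
}
```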

// so even small L2s get some hits on the 2nd read of the same thread
for (int i = (V/4) + (threadIdx.x - blockDim.x); i >= 0; i -= blockDim.x) {
float4 v = x_vec4[i];
#pragma unroll
A contributor commented:

I think this is not quite correct. While we assume V to be a multiple of 4, so that addresses are nicely aligned, the padding part still contains garbage, so we shouldn't take it into account when calculating the sum and max.

@ademeure (Contributor, Author):

Agreed; added a simple bounds check (even if we could theoretically memset the padding once and make sure nothing modifies it, that's not worth the risk of bugs...)
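The fix under discussion, sketched on the CPU with illustrative names: each float4 covers indices `4*i .. 4*i+3` of a row that is P floats long, but only the first V elements are real logits, so any element index at or beyond V must be skipped in the reductions:

```c
#include <float.h>

// CPU sketch of the padded-row max reduction (illustrative, not the kernel):
// iterate vector-by-vector as the float4 loop does, but bounds-check each
// element so garbage in the padding region [V, P) never affects the result.
float row_max_padded(const float* row, int V, int P) {
    float maxval = -FLT_MAX;
    for (int i = 0; i < P / 4; i++) {          // one iteration per float4
        for (int k = 0; k < 4; k++) {
            int idx = 4 * i + k;
            if (idx < V && row[idx] > maxval) { // skip padding garbage
                maxval = row[idx];
            }
        }
    }
    return maxval;
}
```

The same guard applies to the sum-of-exponentials pass; without it, a single large garbage value in the padding would corrupt both the max and the softmax normalizer.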

@ademeure (Contributor, Author) commented:

Should be ready to be merged again now (at least to /dev/cuda/)!

I made it more flexible so it can now also work as a forward-only option by passing a pointer to "probs" and making "dlogits" NULL instead.

Combined with the bounds checking, it makes the code a bit longer than I'd like, with a bit more overhead in terms of GPU instructions, but given that the kernel is massively DRAM limited, it seems maybe worth it to have only one function doing everything?
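The dual-mode dispatch described above can be sketched as a per-element write step (function and parameter names are hypothetical, for illustration only): with a non-NULL "dlogits" the function writes the cross-entropy gradient, and with "dlogits" NULL it writes the probability instead, so one kernel serves both backward and forward-only use:

```c
#include <stddef.h>

// Illustrative dual-mode write step: given the already-computed softmax
// probability for element i, write either the gradient (backward mode) or
// the probability itself (forward-only mode, dlogits == NULL).
void write_output(float* dlogits, float* probs, int i,
                  float prob, int is_target) {
    if (dlogits != NULL) {
        dlogits[i] = prob - (is_target ? 1.0f : 0.0f);  // backward: prob - indicator
    } else {
        probs[i] = prob;                                 // forward-only: just probs
    }
}
```

The branch is per element, but as the comment above notes, the kernel is DRAM-bandwidth bound, so a few extra instructions are cheap relative to the memory traffic saved by not having two separate kernels.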

@karpathy karpathy merged commit 2d2f1df into karpathy:master Apr 18, 2024
3 participants