cuDNN Flash Attention Forward & Backwards BF16 (+35% performance) #322

Merged: 13 commits into karpathy:master on May 1, 2024

Conversation

@ademeure (Contributor) commented on May 1, 2024

RTX 4090 with BF16 and batch size of 24:

  • Baseline: 232.37ms (~106K tokens/s)
  • cuDNN: 170.77ms (~144K tokens/s) ==> +35% performance!
  • Compile time: Priceless(TM) (goes from ~2.7s to ~48.7s; it's a big dependency and part of the reason PyTorch is so big!)

In the future we'd ideally implement our own version of this and avoid requiring cuDNN for maximum performance, but for now it lets the GPU go brrrrrrrrrrrrrrrrr! :)

This is currently on by default via #define ENABLE_CUDNN at the top of train_gpt2.cu; it should probably become a Makefile flag and be off by default. Using cudnn-backend directly instead of cudnn-frontend could potentially reduce compile times, but that would be a lot of work, and the frontend is what NVIDIA recommends these days.
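
To make the switch concrete, here is a minimal sketch of the compile-time dispatch pattern; the stub names below are placeholders for illustration, not the exact kernels/signatures in the PR, and the Makefile flag is the hypothetical replacement for the #define:

```c
// Sketch of the ENABLE_CUDNN compile-time dispatch (placeholder names, not the PR's exact code).
// A Makefile flag would just pass -DENABLE_CUDNN instead of the #define at the top of train_gpt2.cu:
//   with cuDNN:    cc -DENABLE_CUDNN demo.c -o demo
//   without cuDNN: cc demo.c -o demo
#include <stdio.h>

static void attention_forward_cudnn(void)  { printf("cuDNN flash attention path\n"); }
static void attention_forward_manual(void) { printf("hand-written kernel path\n"); }

void attention_forward_dispatch(void) {
#ifdef ENABLE_CUDNN
    attention_forward_cudnn();   // fused flash attention; avoids materializing the full attention matrix
#else
    attention_forward_manual();  // existing kernels, larger activation memory
#endif
}

int main(void) {
    attention_forward_dispatch();
    return 0;
}
```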

There are 11 "#if(n)def ENABLE_CUDNN" lines in train_gpt2.cu:

  1. 5 of them are to reduce the size of the memory allocations as much as possible (these are quite unfortunate...)
  2. 2 of them are for creating & destroying the handle and freeing the workspace memory if any was allocated
  3. 1 in the middle of the file for the new functions (should this be in a separate file?)
  4. 1 at the top of the file for the handle etc. (should this be in a separate file?)
  5. 1 in gpt2_forward()
  6. 1 in gpt2_backward()

There are also 2 ifdefs in each of test_gpt2.cu and profile_gpt2.cu, just to create/destroy the handle and the workspace memory.
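
For context, a rough sketch of what those handle/workspace create/destroy ifdefs guard, assuming global and function names chosen here for illustration (the PR's actual names may differ):

```c
// Sketch of the cuDNN handle + workspace lifecycle guarded by ENABLE_CUDNN (names assumed).
#include <stdio.h>
#include <stdlib.h>
#include <cuda_runtime.h>
#include <cudnn.h>

#ifdef ENABLE_CUDNN
static cudnnHandle_t cudnn_handle;       // created once, reused by forward & backward
static void* cudnn_workspace = NULL;     // scratch memory cuDNN may request for attention
static size_t cudnn_workspace_size = 0;

void create_cudnn(void) {
    cudnnStatus_t st = cudnnCreate(&cudnn_handle);
    if (st != CUDNN_STATUS_SUCCESS) {
        fprintf(stderr, "cudnnCreate failed: %s\n", cudnnGetErrorString(st));
        exit(EXIT_FAILURE);
    }
}

void destroy_cudnn(void) {
    if (cudnn_workspace != NULL) {       // only allocated lazily if a graph asked for one
        cudaFree(cudnn_workspace);
        cudnn_workspace = NULL;
        cudnn_workspace_size = 0;
    }
    cudnnDestroy(cudnn_handle);
}
#endif
```

With something along these lines, test_gpt2.cu and profile_gpt2.cu presumably only need to call the create/destroy pair around their runs, which matches the 2 ifdefs mentioned above.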

The /dev/cuda/attention_backward.cu implementation is currently missing (only the forward pass is in /dev/cuda/), but it should be easy for someone else to add if needed, and hopefully it's not a blocker for integrating this.

@karpathy karpathy merged commit 1147983 into karpathy:master May 1, 2024
3 of 6 checks passed