Flashattention #285
base: master

Conversation
This is very cool work!! Questions:

Would it be possible to also add the commands with the params used in the profiling script to do the comparison? I have access to run it on an H100 and an A100, but when running the linked profiling command it was asking for the sets of params. Edit: Nvm, it was just that there was a duplicate masking param. Would it be possible to share the code used to generate the graphs in this PR? It would make it much faster to replicate the results.
@ChrisDryden, thanks for offering to help with the H100 and A100 performance! You can see all the settings and stats I collected in this spreadsheet. Feel free to follow up as comments on GitHub/the spreadsheet, or DM us the stats on Discord. The two tests I did are shown in the graphs.
I used a chatbot to generate the graph for now 😅 but sure, I can provide a profiling script.
Thanks @karpathy
For reference: If I measure without mallocs (only time of the actual kernel) on the 3060 I get the following times:
@ChrisDryden Getting results on these GPUs would be awesome!
I was running with the A100 in the dev kernel, comparing the 4th kernel to the 6th kernel with the default params, B=1, T=8192, C=768, NH=12:

Kernel 4:

Kernel 6:

Going to run the Nvidia profiler on these in the main train loop to see what the values are that are used in the default training run. It seems that depending on the sizing of the params it can be slower or faster, but the best improvements are when the model is large. Edit: NVM, this is an all-around improvement; it appears that the majority of the time difference can be explained by the mallocs, will have to profile without that.
Thanks a lot for your profiling work @ChrisDryden, very interesting results!
Hi @kilianhae & @simonguozirui, note that we merged the cuDNN flash attention to master today, so this becomes the baseline to beat!
actually sorry, let me rephrase - new baseline to approach*. The cuDNN dependency is HEAVY (custom install, and the compile time becomes 1.5 minutes) and I would delete it if we can get 90% of the way there.
@karpathy apologies for the delay. We are getting back to this
Faster Flash Attention Implementation
Added `attention_forward6` to src/attention_forward: a fast flash attention forward pass written without any dependencies. We assume a causal mask here. This kernel fuses all attention operations, as described in FlashAttention-2, without materializing the entire attention matrix, which reduces the memory-boundedness of the attention operation.
This PR is made by @kilianhae and @simonguozirui.
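For readers unfamiliar with the fusion idea, below is a minimal, unoptimized CUDA sketch of the online-softmax pattern the kernel builds on. This is an illustration only, not the PR's `attention_forward6`: it assumes Q/K/V/O are already permuted into a (B, NH, T, d) layout, uses one thread per query row, and omits the tiling described under Implementation.

```cuda
#include <cuda_runtime.h>
#include <math.h>

// Illustrative only: one thread per (batch, head, query) row, online softmax
// over the causal key range, never materializing the T x T attention matrix.
// Assumes Q, K, V, O are laid out as (B, NH, T, d), i.e. already permuted.
__global__ void flash_forward_sketch(const float* Q, const float* K, const float* V,
                                     float* O, int B, int NH, int T, int d) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;   // flat (b, h, q) index
    if (idx >= B * NH * T) return;
    int q = idx % T;                                   // query position
    const float* Qr  = Q + (size_t)idx * d;            // this query's row
    const float* Kbh = K + (size_t)(idx / T) * T * d;  // keys for this (b, h)
    const float* Vbh = V + (size_t)(idx / T) * T * d;  // values for this (b, h)
    float* Or = O + (size_t)idx * d;

    float m = -INFINITY;                 // running max of the scores
    float l = 0.0f;                      // running softmax denominator
    for (int j = 0; j < d; j++) Or[j] = 0.0f;
    float scale = rsqrtf((float)d);
    for (int k = 0; k <= q; k++) {       // causal mask: only keys up to q
        float s = 0.0f;
        for (int j = 0; j < d; j++) s += Qr[j] * Kbh[k * d + j];
        s *= scale;
        float m_new = fmaxf(m, s);
        float corr = expf(m - m_new);    // rescale what we accumulated so far
        float p = expf(s - m_new);
        l = l * corr + p;
        for (int j = 0; j < d; j++) Or[j] = Or[j] * corr + p * Vbh[k * d + j];
        m = m_new;
    }
    for (int j = 0; j < d; j++) Or[j] /= l;  // final softmax normalization
}
```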
Implementation

Here are a few steps we took to optimize attention:

- Tiling: we tile the attention computation into blocks of size `B_r` x `B_c` held in shared memory. Tiling enables us to reduce the stalls caused by GMEM loads at the cost of more SMEM stalls.
- Register tiling: each thread computes a `TM` x `TN` tile of results in registers (see the sketch below). This lowers the amount of SMEM stalls at the cost of higher register usage, which we found to be quite robust.

This blog on optimizing CUDA MatMul kernels by @siboehm was incredibly helpful for our understanding and inspired many of our techniques.
Profiling & Performance
So far we have tested this `attention_forward6` implementation against other existing implementations. It is correct for the cases we tested, but let us know if you run into issues.

On RTX 3060 (Ampere)
TODO: @Kilian if you wanna run some more numbers on 3060
On an RTX 3060, we obtain the following times:

For smaller sequence lengths (<= 1024) we empirically found that our kernel matches the performance of attention_forward4(); on larger sequence lengths, however, it outperforms it:
On Tesla T4 (Turing)
We ran against all existing `attention_forward` implementations currently in the repo, and we found our implementation slightly faster in almost all settings. We don't know whether this behavior is specific to the T4 GPU or to the Turing architecture.

We also did some scaling tests comparing `attention_forward4` (the current fastest) vs `attention_forward6` (this PR), increasing context length until running out of memory (at seq_len = 8k). It seems that `attention_forward6` might have lower scaling overhead, but we need to test at longer sequence lengths to confirm.

On A100 (Ampere) / H100 (Hopper)
We currently do not have access to such resources. We would really appreciate, and would be curious to see, the performance on that hardware!
Comparison to PyTorch

To be fairer towards non-JIT-compiled PyTorch, we provide more detailed comparisons against PyTorch in our repo: https://github.com/kilianhae/FlashAttention.C.git. There we also provide a profiling script to compare with torch.
Known Issues & Next Steps

Currently we fix the head dimension to 64, following the GPT-2 architecture. Any multiple of 32 is supported; to use a different value, change the `d` constant in the flash attention kernel.
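For example (a hedged sketch; the exact constant name and definition style in the kernel may differ), switching head sizes is a one-line edit plus a recompile:

```cuda
// Hypothetical illustration: the head dimension is a compile-time constant.
// The PR supports any multiple of 32; edit this and recompile to change it.
#define d 64   // GPT-2 default; e.g. set to 128 for larger heads
```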
Further optimizations: