Tutorial 6. fused-attention bwd very slow if D_HEAD == 128
#1975
Comments
After digging into the source code, the tutorial code does not work.
@akakakakakaa Same issue here. I wonder if we can change the backward as well so we can use a smaller block size?
@ptillet Can you provide some insights? Can we make the backward block smaller? I tried it and it passes the pytest.
Environment
Problem
To check the performance difference between FlashAttention-2 and Triton, I only modified the code to enable sequence_parallel and to be compatible with FlashAttention-2.
When I test it, everything is fine if I use D_HEAD = 64. But when I use D_HEAD = 128, the backward function shows strange results.
However, if I skip storing just one of dk or dv, I see a significant performance improvement.
It seems strange that skipping a single store to HBM makes such a huge difference in performance.
Do you have any idea why this is happening?
Thanks.
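For context, here is a minimal sketch of the dk/dv half of the backward kernel. This is not the modified tutorial code: it assumes precomputed attention probabilities `P` passed in as a buffer, and the `dk` update is a placeholder with the right shape rather than the true dS^T @ Q gradient. Only the two-accumulator / two-store structure it illustrates is relevant to the question.

```python
# Minimal sketch, not the tutorial's kernel. Assumptions that are mine:
# P holds precomputed attention probabilities, and the dk update below is a
# shape-correct placeholder, not the real gradient. The point is the pattern:
# two [BLOCK_N, D_HEAD] accumulators kept live across the whole inner loop,
# then two stores to HBM at the end.
import triton
import triton.language as tl


@triton.jit
def _bwd_dkdv_sketch(
    P, DO, Q, DK, DV,          # row-major (N_CTX, N_CTX) / (N_CTX, D_HEAD) buffers
    N_CTX, sm_scale,
    BLOCK_M: tl.constexpr,     # query tile iterated over in the inner loop
    BLOCK_N: tl.constexpr,     # key/value tile owned by this program
    D_HEAD: tl.constexpr,
):
    start_n = tl.program_id(0) * BLOCK_N
    offs_n = start_n + tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, D_HEAD)

    # With D_HEAD = 128 each accumulator is a BLOCK_N x 128 fp32 tile,
    # and both must stay live for the whole loop.
    dv = tl.zeros([BLOCK_N, D_HEAD], dtype=tl.float32)
    dk = tl.zeros([BLOCK_N, D_HEAD], dtype=tl.float32)

    for start_m in range(0, N_CTX, BLOCK_M):
        offs_m = start_m + tl.arange(0, BLOCK_M)
        p = tl.load(P + offs_m[:, None] * N_CTX + offs_n[None, :])
        do = tl.load(DO + offs_m[:, None] * D_HEAD + offs_d[None, :])
        q = tl.load(Q + offs_m[:, None] * D_HEAD + offs_d[None, :])
        dv += tl.dot(tl.trans(p), do)                # dV += P^T @ dO
        dk += sm_scale * tl.dot(tl.trans(p), q)      # placeholder dK update

    # The two stores discussed above: commenting out either one (which of
    # course gives wrong results) is the "skip storing only one of dk or dv"
    # experiment described in this issue.
    tl.store(DV + offs_n[:, None] * D_HEAD + offs_d[None, :], dv)
    tl.store(DK + offs_n[:, None] * D_HEAD + offs_d[None, :], dk)
```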
Full Code
Edit
When I use BLOCK = 32, it works very fast.
BLOCK = 16
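A quick way to compare the two block sizes on the sketch kernel above (this is not the benchmark behind the numbers in this issue; `triton.testing.do_bench` returns a time in milliseconds, and `do_` just avoids clashing with the kernel argument name):

```python
import torch
from triton.testing import do_bench

N_CTX, D_HEAD = 1024, 128
p = torch.rand(N_CTX, N_CTX, device="cuda")      # random stand-in for attention probabilities
do_ = torch.randn(N_CTX, D_HEAD, device="cuda")  # upstream gradient dO
q = torch.randn(N_CTX, D_HEAD, device="cuda")
dk, dv = torch.empty_like(q), torch.empty_like(q)

for block in (16, 32):
    ms = do_bench(lambda: _bwd_dkdv_sketch[(N_CTX // block,)](
        p, do_, q, dk, dv, N_CTX, 0.5,
        BLOCK_M=block, BLOCK_N=block, D_HEAD=D_HEAD))
    print(f"BLOCK = {block}: {ms:.3f} ms")
```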