
Flashattention #285

Open
wants to merge 6 commits into master

Conversation

kilianhae

Faster Flash Attention Implementation

Added attention_forward6 to src/attention_forward: a fast flash attention forward pass written without any external dependencies. A causal mask is assumed.

This kernel fuses all attention operations, as described in FlashAttention-2, without materializing the entire attention matrix, which reduces the memory-boundedness of the attention operation.
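For reference, the fusion relies on the standard online-softmax recurrence from the FlashAttention papers (with the $1/\sqrt{d}$ scaling folded into the scores): as each tile of scores $S^{(j)} = Q K^{(j)\top}$ is processed, the running row-wise max $m$, normalizer $\ell$, and unnormalized output $O$ are rescaled in place instead of being recomputed:

$$
m^{(j)} = \max\big(m^{(j-1)},\, \mathrm{rowmax}(S^{(j)})\big), \qquad \tilde{P}^{(j)} = e^{S^{(j)} - m^{(j)}}
$$

$$
\ell^{(j)} = e^{m^{(j-1)} - m^{(j)}}\,\ell^{(j-1)} + \mathrm{rowsum}\big(\tilde{P}^{(j)}\big), \qquad O^{(j)} = e^{m^{(j-1)} - m^{(j)}}\,O^{(j-1)} + \tilde{P}^{(j)} V^{(j)}
$$

with a single row-wise normalization $O = O^{(N)} / \ell^{(N)}$ at the end, so only a $(B_r \times B_c)$ tile of scores ever lives on-chip at a time.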

This PR is made by @kilianhae and @simonguozirui.

Implementation

Here are the main steps we took to optimize attention (a simplified, illustrative sketch of the fused loop follows this list):

  1. The computation is first decomposed by tiling the attention computation into (B_r x B_c) blocks that map onto the streaming multiprocessors. Tiling lets us reduce the stalls caused by GMEM loads at the cost of more SMEM stalls.
  2. To exploit data locality, each thread is responsible for calculating a chunk of attention values of size (TM x TN). This lowers the number of SMEM stalls at the cost of higher register usage, a trade-off we found to be quite robust.
  3. We use sequence-level parallelism over the blocks of Q, as done in FlashAttention-2.
  4. Masked attention values are never computed.
  5. We don't cache tiles of Q in SMEM, in order to reduce SMEM usage and increase occupancy.
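As a rough illustration (not the PR's kernel), here is a deliberately simplified fused causal-attention forward in CUDA: one thread per query row, a serial online softmax over K/V, and the head dimension fixed to 64. It omits the (B_r x B_c) tiling, the (TM x TN) thread coarsening, and all SMEM staging described above; the names and the assumed (B*NH, T, d) layout are illustrative only.

```cuda
// Simplified illustrative sketch -- NOT the attention_forward6 kernel.
// One thread handles one query row; q, k, v, out are assumed to be float32 in
// (B*NH, T, D) layout with D = head dimension = 64.
#include <cuda_runtime.h>
#include <math.h>

#define D 64

__global__ void fused_causal_attention_rowwise(
    float* out, const float* q, const float* k, const float* v, int T) {
    int row = blockIdx.x * blockDim.x + threadIdx.x;  // query index within (b, h)
    long bh = blockIdx.y;                             // fused batch*head index
    if (row >= T) return;

    const float* qrow = q + (bh * T + row) * D;
    const float* kb   = k + bh * T * D;
    const float* vb   = v + bh * T * D;
    float*       orow = out + (bh * T + row) * D;

    float scale = 1.0f / sqrtf((float)D);
    float m = -INFINITY;    // running row max
    float l = 0.0f;         // running softmax denominator
    float acc[D] = {0.0f};  // unnormalized output accumulator

    // causal mask: keys with j > row are simply never visited (no wasted work)
    for (int j = 0; j <= row; j++) {
        float s = 0.0f;
        for (int i = 0; i < D; i++) s += qrow[i] * kb[j * D + i];
        s *= scale;
        float m_new = fmaxf(m, s);
        float corr  = expf(m - m_new);  // rescale everything accumulated so far
        float p     = expf(s - m_new);
        l = l * corr + p;
        for (int i = 0; i < D; i++) acc[i] = acc[i] * corr + p * vb[j * D + i];
        m = m_new;
    }
    for (int i = 0; i < D; i++) orow[i] = acc[i] / l;  // normalize once at the end
}

// Illustrative launch geometry: x tiles the T query rows, y indexes (batch, head):
//   dim3 grid((T + 255) / 256, B * NH);
//   fused_causal_attention_rowwise<<<grid, 256>>>(d_out, d_q, d_k, d_v, T);
```

The actual kernel instead stages K/V tiles in SMEM and lets each thread accumulate a (TM x TN) patch of the score tile, but the rescale-and-accumulate structure of the inner loop is the same.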

This blog post on optimizing a CUDA matmul kernel by @siboehm was incredibly helpful for our understanding and inspired many of our techniques.

Profiling & Performance

So far we have tested this attention_forward6 implementation against the other existing implementations. It is correct for the cases we tested, but let us know if you run into issues.

On RTX 3060 (Ampere)
TODO: @Kilian if you wanna run some more numbers on 3060
On a RTX 3060, we obtain the following times:

  • B = 8, T = 1024, C = 768, NH = 12: 7.343360 ms
  • B = 6, T = 4096, C = 768, NH = 12: 61.922577 ms
  • B = 1, T = 8192, C = 768, NH = 12: 37.887924 ms

For smaller sequence lengths (<= 1024) we empirically found that our kernel matches the performance of attention_forward4(); on larger sequence lengths it outperforms it:

  • T = 4096: 61.922577 ms (ours) vs 84.654724 ms (attention_forward4)
  • T = 8192: 37.887924 ms (ours) vs 57.476559 ms (attention_forward4)

On Tesla T4 (Turing)

We ran against all existing attention_forward implementations currently in the repo, and we found our implementation slightly faster in almost all settings. We don't know whether this behavior is specific to the T4 GPU or to the Turing architecture.
[Figure: runtime comparison of all attention_forward kernels on a Tesla T4]

We also ran a scaling test comparing attention_forward4 (currently the fastest) against attention_forward6 (this PR), increasing the context length until running out of memory (at seq_len = 8k). It seems that attention_forward6 might have lower scaling overhead, but we would need to test longer sequence lengths to confirm.
[Figure: attention_forward4 vs attention_forward6 runtime as sequence length increases, up to 8k, on a Tesla T4]

On A100 (Ampere) / H100 (Hopper)
We currently do not have access to such hardware. We would really appreciate, and would be curious to see, performance numbers on those GPUs!

Comparison to PyTorch

To be fairer to non-JIT-compiled PyTorch, we provide more detailed comparisons against PyTorch in our repo: https://github.com/kilianhae/FlashAttention.C.git, which also includes a profiling script to compare with torch.

Known Issues & Next Steps

Currently we fix the head dimension to 64, following the GPT-2 architecture. Any multiple of 32 is supported; to change it, you need to change the d constant in the flashattention kernel (see the illustrative snippet below).
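An illustrative example of that constant (the exact name and placement in the kernel may differ slightly):

```cuda
// inside the flashattention kernel (illustrative; d must equal C / NH):
const int d = 64;  // head dimension; any multiple of 32 should work, e.g. 96 or 128
```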

Further optimizations:

  • Reduce SMEM complexity: possibly reload V tiles instead of caching them fully over the hidden dimension, at the cost of more GMEM loads.
  • Enable larger coarsening, with TM, TN > 4, to further reduce SMEM loads.
  • Autotune the tile sizes at the various levels for different hardware; also test tiles of B_r = B_c = 64 on GPUs with sufficient SMEM.

@karpathy
Owner

This is very cool work!! Questions:

  • there are mallocs inside the kernel launch; I'm guessing in the actual implementation we'd treat these as buffers and make them part of the Activations (?).
  • this is only the forward pass. Are there intermediate tensors we'd want to emit for the backward pass later? Can this forward pass be combined with our current backward pass? (I think the answer is no, right?). In other words we'd have to wait for the backward() before we can push this to mainline, is that right?

@ChrisDryden
Contributor

ChrisDryden commented Apr 29, 2024

Would it be possible to also add the commands with the params used in the profiling script to do the comparison? I have access to an H100 and an A100, but when running the linked profiling command it was asking for the set of params.

Edit: Nvm, it was just that there was a duplicate masking param.

Would it be possible to share the code used to generate the graphs in this PR? It would make it much faster to replicate the results.

@simonguozirui

@ChrisDryden, thanks for offering to help with the H100 and A100 performance numbers!
I didn't have a script. I just manually edited the dimensions (B, T, C, NH) in the main function of dev/cuda/attention_forward.cu, compiled, and ran it to observe the runtime.
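For reference, the relevant lines look roughly like this (an illustrative fragment of that main function, with the default dimensions shown; this is not a new script):

```c
// near the top of main() in dev/cuda/attention_forward.cu (illustrative fragment)
int B = 8;    // batch size
int T = 1024; // sequence length -- bump to 2048 / 4096 / 8192 for the scaling runs
int C = 768;  // channels
int NH = 12;  // number of heads
// then recompile with nvcc as usual and rerun to print the per-kernel timings
```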

You can see all the settings and stats I collected in this spreadsheet. Feel free to follow up as comments on Github/spreadsheet or DM us the stats on discord.

Two tests I did, as shown in the graphs:

  • Use the default dimensions, run all the current attention_forward implementations, and compare their times.
  • Increase the sequence length to see scaling performance.

We could do more experiments and comparisons, and we are super interested to see how it does on various hardware platforms.

I used a chatbot to generate the graph for now 😅 but sure I can provide a profiling script.

@kilianhae
Author

Thanks @karpathy

  • Yes, the mallocs are just because of how the test data here is generated and formatted. In prod, we would use these as caches for the activations (see the sketch below this list).
  • With my current understanding, I think this is not compatible with the current backward pass: because we never materialize S or A, we need a backward kernel that relies on recomputation from shared memory. But I will double-check this and report back tomorrow!
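A hypothetical sketch of how the malloc point could look once integrated (names made up; the real version would hang these buffers off the existing activation allocations):

```cuda
#include <cuda_runtime.h>

// Hypothetical sketch: allocate the launcher's temporaries once at init and
// reuse them on every forward pass, instead of cudaMalloc'ing per call.
typedef struct {
    float* qkvr;  // permuted q, k, v, each laid out as (B, NH, T, d)
    float* stats; // per-row softmax statistics (l and m), kept for a future backward
} FlashScratch;

void flash_scratch_init(FlashScratch* s, int B, int NH, int T, int d) {
    cudaMalloc((void**)&s->qkvr,  (size_t)3 * B * NH * T * d * sizeof(float));
    cudaMalloc((void**)&s->stats, (size_t)2 * B * NH * T * sizeof(float));
}

void flash_scratch_free(FlashScratch* s) {
    cudaFree(s->qkvr);
    cudaFree(s->stats);
}
```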

For reference: if I measure without the mallocs (i.e. only the time of the actual kernel) on the 3060, I get the following times:

  • B = 8, T = 1024, C = 768, NH = 12: 3.8 ms
  • B = 6, T = 4096, C = 768, NH = 12: 44 ms
  • B = 1, T = 8192, C = 768, NH = 12: 29.5 ms

@ChrisDryden Getting results on these GPUs would be awesome!
I just pushed a modification to our repo: https://github.com/kilianhae/FlashAttention.C.git, which on my GPU gives an additional 20% speedup.

@ChrisDryden
Contributor

ChrisDryden commented Apr 29, 2024

I was running the dev kernel on the A100, comparing the 4th kernel to the 6th kernel with the default params.

With B=1 and T=8192, C = 768, NH = 12

Kernel 4

block_size   32 | time 2.889739 ms
block_size   64 | time 2.787523 ms
block_size  128 | time 2.758677 ms
block_size  256 | time 2.776935 ms
block_size  512 | time 2.778962 ms

Kernel 6

block_size   32 | time 4.011267 ms
block_size   64 | time 3.897487 ms
block_size  128 | time 3.869420 ms
block_size  256 | time 3.903386 ms
block_size  512 | time 3.942431 ms

With B=1 and T=8192, C = 768, NH = 12

block_size   32 | time 22.012802 ms
block_size   64 | time 21.934755 ms
block_size  128 | time 21.952288 ms
block_size  256 | time 21.965374 ms
block_size  512 | time 21.991333 ms
block_size   32 | time 14.037394 ms
block_size   64 | time 13.913477 ms
block_size  128 | time 13.851557 ms
block_size  256 | time 13.843405 ms
block_size  512 | time 13.856246 ms

Going to run the Nvidia profiler on these in the main train loop to see what values we're using in the default training run. It seems that, depending on the sizing of the params, it can be slower or faster, but the best improvements are when the model is large.

Edit: NVM, this is an all-around improvement; it appears that the majority of the time difference can be explained by the mallocs. Will have to profile without that.

@kilianhae
Author

Thanks a lot for your profiling work @ChrisDryden, very interesting results!
I just pushed the new feature (about +20% on the 3060) into the codebase. If you are profiling the kernels, it's best to pull the codebase first :)

@karpathy
Owner

karpathy commented May 1, 2024

Hi @kilianhae & @simonguozirui, note that we merged the cuDNN flash attention to master here today:

#323

so this becomes the baseline to beat!

@karpathy
Owner

karpathy commented May 1, 2024

actually sorry let me rephrase - new baseline to approach*. The cuDNN dependency is HEAVY (custom install and the compile time becomes 1.5 minutes) and I would delete it if we can get 90% of the way there.

@simonguozirui

@karpathy apologies for the delay. We are getting back to this:

  • Re comparing with the cuDNN flash attention: agreed that the cuDNN dependency is heavy. As for comparing our kernel vs the cuDNN flash attention, ours is in FP32 while the cuDNN version uses FP16. We will likely write an FP16 version to have a fair comparison, and we are trying to find Ampere-architecture GPUs to profile on.
  • Re malloc: we will try to get rid of the mallocs in the kernel launch; we currently malloc to rearrange the tensors to satisfy the test format.
  • Re backward pass: @by2299 and we will be working on this. We might need to save some intermediate tensors (Q, K, V, O, l, m) for the backward pass. We will see if we can combine it with the current backward pass. Will keep you posted.
