[CUDA] Add SparseAttention operator for Phi-3-small #20216
Merged
Conversation
tianleiwu force-pushed the tlwu/sparse_attention branch from b929f67 to daa1787 on April 5, 2024 23:32
tianleiwu force-pushed the tlwu/sparse_attention branch from 5a1d115 to bb5796d on April 24, 2024 23:33
onnxruntime/contrib_ops/cuda/sparse/sparse_attention_trition/sparse_attention_triton.py
tianleiwu requested review from aciddelgado, wangyems, kunal-vaishnavi, yufenglee and apsonawane on April 29, 2024 04:05
wangyems reviewed on Apr 29, 2024
wangyems reviewed on Apr 30, 2024
...ntrib_ops/cuda/sparse/sparse_attention_v1/sparse_attention_v1_bf16_m16_0_n64_0_d64_2_sm80.cc
tianleiwu force-pushed the tlwu/sparse_attention branch from 50b6eee to 348781d on April 30, 2024 05:51
wangyems approved these changes on Apr 30, 2024
tianleiwu added a commit that referenced this pull request on May 2, 2024
### Description

Follow-up of #20216 to add a kernel for sm=75 (GPUs like T4, GeForce RTX 2080, GeForce GTX 1650 Ti, NVIDIA TITAN RTX, RTX 4000, etc.)

- [x] Add kernel for sm=75
- [x] Update dispatch code to call a different kernel based on sm.
- [x] Update compile script to use num_stages=2 instead of 3 for sm=75
- [x] Refactor test script and add tests for bfloat16.
- [x] Fix performance test of token generation (previously we did not concatenate past_key)
- [x] Fix debug build
- [x] Run performance test and update numbers.

For sm=70, the v1 kernel can be compiled, but there is an error when compiling the v2 kernel, so it is skipped in this pull request.

Performance test on T4 GPU (using a Standard_NC4as_T4_v3 Azure VM) with `batch_size=4, num_heads=32, max_seq_len=8192, head_size=128, sparse_block_size=64, local_blocks=16, vert_stride=8, num_layout=8`.

We compare sparse attention to the corresponding GQA with dense causal attention. Note that dense GQA needs more computation since no sparsity is used. TORCH-GQA uses a naive implementation (using cuSPARSE Block-SpMM could be faster).

```
prompt-sm75-batch4-head32-d128-local16-vert8-torch.float16:
   sequence_length   TORCH-GQA  ORT-GQA-Dense  ORT-SparseAtt
1             32.0    0.184173       2.994347       0.089064
2             64.0    0.303300       3.023986       0.107418
3            128.0    0.887795       3.073728       0.174213
4            256.0    2.797654       3.246899       0.357869
5            512.0   10.055048       3.814039       0.893903
6           1024.0   37.849937       5.818439       2.658720
7           2048.0  148.641785      13.638480       7.202690
8           4096.0         OOM      43.556847      17.680954
9           8192.0         OOM     161.628540      44.336670

token-sm75-batch4-head32-d128-local16-vert8-torch.float16:
   past_sequence_length  TORCH-GQA  ORT-GQA-Dense  ORT-SparseAtt
1                  32.0   0.110353       2.996305       0.137509
2                  64.0   0.145088       3.006860       0.165424
3                 128.0   0.219500       3.036448       0.192001
4                 256.0   0.347496       3.071341       0.249125
5                 512.0   0.595842       3.135225       0.398726
6                1024.0   1.081216       3.261110       0.612744
7                2048.0   2.060307       3.515578       0.685670
8                4096.0        OOM       4.022986       0.819707
9                8191.0        OOM       5.024528       1.072912
```

### Motivation and Context

To run inference of Phi-3-small on T4 GPU.
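As a rough illustration of the dispatch described above, the minimal sketch below picks Triton compile settings by compute capability. The function name and the error handling are assumptions for illustration only; only the num_stages=2 for sm=75 versus 3 otherwise, and the skipped sm=70, come from the commit message.

```python
def choose_num_stages(sm: int) -> int:
    """Illustrative sketch: pick Triton num_stages by GPU compute capability (sm)."""
    # Per the change above: sm=75 (e.g. T4) uses num_stages=2; newer GPUs keep 3.
    # sm=70 is skipped in this pull request, so treat anything below 75 as unsupported.
    if sm < 75:
        raise ValueError(f"sm={sm} is not supported by the sparse attention kernels")
    return 2 if sm == 75 else 3


assert choose_num_stages(75) == 2   # T4
assert choose_num_stages(80) == 3   # A100
```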
yihonglyu pushed a commit that referenced this pull request on May 4, 2024
### Description

Add a CUDA implementation of block sparse attention for Phi-3-small.

Block sparse attention was proposed in [Sparse Transformers](https://arxiv.org/pdf/1904.10509) by OpenAI, and also adopted in [BigBird](https://arxiv.org/pdf/2007.14062) with a different sparse layout. In Phi-3-small, the sparse layout is static and works with unidirectional (causal) attention. Compared to dense attention, block sparsity speeds up both training and inference, and it can save memory and thus support longer context lengths.

- [x] Add operator spec and shape inference
- [x] Symbolic shape inference
- [x] Refactor GroupQueryAttention to expose common kernels for kv cache concatenation, q/k/v transpose, etc.
- [x] Add CUDA kernel to convert block mask to CSR format
- [x] Add CUDA kernel to generate position ids
- [x] Add compile script and template files to convert Triton kernels to cubin and a dispatcher.
- [x] Add Triton kernel v1 for prompt
- [x] Add Triton kernel v2 for token generation, with padding support
- [x] Update IO Binding Helper to allow buffer sharing.
- [x] Test relevance
- [x] Test performance

### Performance

Test on A100-SXM4-80GB with `batch_size=4, num_heads=32, max_seq_len=8192, head_size=128, sparse_block_size=64, local_blocks=16, vert_stride=8, num_layout=8`.

We compare sparse attention to the corresponding GQA with a local attention window of size 1024, or GQA with dense causal attention.

Average latency in milliseconds (for the fused attention kernel used in prompt prefilling):

seq_len | GQA-Dense | GQA-Local | SparseAttention
-- | -- | -- | --
64 | 0.0465 | 0.0722 | 0.0641
128 | 0.0618 | 0.0787 | 0.0672
256 | 0.1086 | 0.1076 | 0.0943
512 | 0.2535 | 0.2487 | 0.1676
1024 | 0.7042 | 0.7050 | 0.3800
2048 | 2.4125 | 1.9316 | 0.8966
4096 | 8.9346 | 4.5699 | 2.1129
8192 | 40.5401 | 10.3508 | 5.1748

Average latency in milliseconds (for the fused attention kernel used in token generation):

past_seq_len | GQA-Dense | GQA-Local | SparseAttention
-- | -- | -- | --
64 | 0.0186 | 0.0186 | 0.0870
128 | 0.0408 | 0.0466 | 0.1165
256 | 0.0530 | 0.0592 | 0.0988
512 | 0.0445 | 0.0447 | 0.1150
1024 | 0.0634 | 0.0640 | 0.1454
2048 | 0.1027 | 0.0637 | 0.1589
4096 | 0.1789 | 0.0631 | 0.1806
8192 | 0.3288 | 0.0655 | 0.2146

We can see that the kernel for token generation still has room to improve.

#### Limitations

Only right-side padding and unidirectional attention are supported. The following are not supported in the first version:
(1) Packed mode like PackedMultiHeadAttention, where padding has been removed from the input.
(2) Paged attention.
(3) Bidirectional attention.
(4) GPU compute capability other than 8.0, 8.6 and 8.9.
(5) Left-side padding.
Some of these limitations will be removed in the future (possibly in a new operator).
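One task above converts the block mask to CSR format with a CUDA kernel. Purely as an illustration of the target layout, here is a minimal NumPy sketch of the same conversion on the CPU; the function name and signature are hypothetical and not the kernel's actual interface.

```python
import numpy as np

def block_mask_to_csr(block_mask: np.ndarray):
    """Convert a 2D boolean block mask to CSR: row_ptr has num_rows + 1 entries,
    col_idx lists the key-block indices each query block attends to."""
    num_rows, _ = block_mask.shape
    row_ptr = np.zeros(num_rows + 1, dtype=np.int32)
    cols = []
    for r in range(num_rows):
        nz = np.nonzero(block_mask[r])[0]
        cols.append(nz)
        row_ptr[r + 1] = row_ptr[r] + len(nz)
    col_idx = np.concatenate(cols).astype(np.int32) if cols else np.zeros(0, np.int32)
    return row_ptr, col_idx

# Example with a tiny causal mask over 4 blocks.
mask = np.tril(np.ones((4, 4), dtype=bool))
row_ptr, col_idx = block_mask_to_csr(mask)
print(row_ptr)   # [ 0  1  3  6 10]
print(col_idx)   # [0 0 1 0 1 2 0 1 2 3]
```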
yihonglyu pushed a commit that referenced this pull request on May 4, 2024
tianleiwu added a commit that referenced this pull request on May 5, 2024
### Description

Follow-up of #20216 to add sparse attention kernels compiled by Triton for H100 (sm=90).

- [x] Refine sparse attention v1 kernel compilation (remove some combinations)
- [x] Compile v1 kernels
- [x] Compile kernels for H100
- [x] Run performance tests

### Performance

Test setting: `batch_size=4, num_heads=32, max_seq_len=8192, head_size=128, sparse_block_size=64, local_blocks=16, vert_stride=8, num_layout=8`.

We compare sparse attention to the corresponding GQA with a local attention window of size 1024, or GQA with dense causal attention. Note that ORT-GQA-Dense does more computation than ORT-SparseAtt, while ORT-GQA-Local does less computation (no vertical strides) than ORT-SparseAtt. They are added for reference; it is not a fair comparison, but it shows the benefit of sparsity versus dense attention.

Example results on an Azure Standard_ND96isr_H100_v5 VM with an NVIDIA H100-80GB-HBM3 GPU (sm=90):

```
prompt-sm90-batch4-head32-d128-local16-vert8-torch.float16:
   sequence_length  TORCH-GQA  ORT-GQA-Dense  ORT-GQA-Local  ORT-SparseAtt
0             16.0   0.079877       0.006362       0.006403       0.042758
1             32.0   0.086920       0.016404       0.016686       0.044183
2             64.0   0.090727       0.020429       0.020409       0.045343
3            128.0   0.128148       0.032009       0.031984       0.051516
4            256.0   0.323933       0.074110       0.073920       0.068308
5            512.0   1.021856       0.162167       0.161951       0.109226
6           1024.0   3.596002       0.452629       0.452780       0.231653
7           2048.0  13.865088       1.499534       1.195749       0.515488
8           4096.0   0.000000       5.454785       2.669682       1.163233
9           8192.0   0.000000      22.068159       6.018604       2.772873

token-sm90-batch4-head32-d128-local16-vert8-torch.float16:
   past_sequence_length  TORCH-GQA  ORT-GQA-Dense  ORT-GQA-Local  ORT-SparseAtt
0                  16.0   0.104460       0.012652       0.012661       0.069549
1                  32.0   0.113866       0.012776       0.012765       0.069024
2                  64.0   0.124600       0.016791       0.012672       0.069397
3                 128.0   0.108658       0.017900       0.018294       0.074844
4                 256.0   0.115463       0.029409       0.029608       0.078911
5                 512.0   0.149824       0.033968       0.033701       0.092998
6                1024.0   0.234050       0.042930       0.042951       0.116920
7                2048.0   0.390695       0.061462       0.043008       0.121555
8                4096.0   0.000000       0.097505       0.042948       0.134757
9                8191.0   0.000000       0.165861       0.043542       0.158796
```

The following might help performance at short sequence lengths, but needs an operator spec update: fall back to flash attention when total_sequence_length < local_blocks * block_size.
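The fallback idea mentioned above is just a condition on sequence length. A minimal sketch is shown below; the function name and default values are hypothetical, since the operator spec change has not landed.

```python
def should_use_dense_fallback(total_sequence_length: int,
                              local_blocks: int = 16,
                              block_size: int = 64) -> bool:
    # When the whole sequence fits inside the local window, the block-sparse
    # layout is effectively dense, so a dense flash attention kernel could be used.
    return total_sequence_length < local_blocks * block_size


assert should_use_dense_fallback(512)        # 512 < 16 * 64 = 1024
assert not should_use_dense_fallback(2048)   # long sequences keep the sparse kernel
```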
TedThemistokleous pushed a commit to TedThemistokleous/onnxruntime that referenced this pull request on May 7, 2024
TedThemistokleous pushed a commit to TedThemistokleous/onnxruntime that referenced this pull request on May 7, 2024
TedThemistokleous pushed a commit to TedThemistokleous/onnxruntime that referenced this pull request on May 7, 2024
yihonglyu pushed a commit that referenced this pull request on May 9, 2024
Labels
cherry-picked (Cherry-picked for a cherrypicks branch)
rel-merged (Cherrypicks merged into release)
release:1.18.0
triage:approved (Approved for cherrypicks for release)
Description
Add a CUDA implementation of block sparse attention for Phi-3-small.
Block sparse attention was proposed in [Sparse Transformers](https://arxiv.org/pdf/1904.10509) by OpenAI, and also adopted in [BigBird](https://arxiv.org/pdf/2007.14062) with a different sparse layout.
In Phi-3-small, the sparse layout is static and works with unidirectional (causal) attention.
Compared to dense attention, block sparsity speeds up both training and inference, and it can save memory and thus support longer context lengths.
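To illustrate the static layout, here is a minimal NumPy sketch that builds a causal block-sparse mask from a local window plus a vertical stride, in the spirit of the Sparse Transformers pattern referenced above. The function name and the exact stride rule are illustrative assumptions, not the operator's actual API; the parameter values echo the `local_blocks=16, vert_stride=8, sparse_block_size=64` settings used in the performance tests below.

```python
import numpy as np

def build_block_sparse_mask(num_blocks: int, local_blocks: int, vert_stride: int) -> np.ndarray:
    """Illustrative causal block-sparse layout: each query block attends to the
    previous `local_blocks` blocks plus every `vert_stride`-th earlier block."""
    q = np.arange(num_blocks)[:, None]   # query block index
    k = np.arange(num_blocks)[None, :]   # key block index
    causal = k <= q
    local = (q - k) < local_blocks
    vertical = (k + 1) % vert_stride == 0
    return causal & (local | vertical)

# Example: 8192 tokens with sparse_block_size=64 gives 128 blocks per row.
mask = build_block_sparse_mask(num_blocks=128, local_blocks=16, vert_stride=8)
print(mask.sum(), "of", mask.size, "block pairs are attended")
```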
Tasks to be done in the next pull request:
Performance
Test on A100-SXM4-80GB with `batch_size=4, num_heads=32, max_seq_len=8192, head_size=128, sparse_block_size=64, local_blocks=16, vert_stride=8, num_layout=8`.
We compare sparse attention to the corresponding GQA with a local attention window of size 1024, or GQA with dense causal attention. Note that `ORT-GQA-Dense` does more computation than `ORT-SparseAtt`, while `ORT-GQA-Local` does less computation (no vertical strides) than `ORT-SparseAtt`. They are added for reference; it is not a fair comparison, but it shows the benefit of sparsity versus dense attention.

Average latency in milliseconds (for the fused attention kernel used in prompt prefilling):

seq_len | GQA-Dense | GQA-Local | SparseAttention
-- | -- | -- | --
64 | 0.0465 | 0.0722 | 0.0641
128 | 0.0618 | 0.0787 | 0.0672
256 | 0.1086 | 0.1076 | 0.0943
512 | 0.2535 | 0.2487 | 0.1676
1024 | 0.7042 | 0.7050 | 0.3800
2048 | 2.4125 | 1.9316 | 0.8966
4096 | 8.9346 | 4.5699 | 2.1129
8192 | 40.5401 | 10.3508 | 5.1748

Average latency in milliseconds (for the fused attention kernel used in token generation):

past_seq_len | GQA-Dense | GQA-Local | SparseAttention
-- | -- | -- | --
64 | 0.0186 | 0.0186 | 0.0870
128 | 0.0408 | 0.0466 | 0.1165
256 | 0.0530 | 0.0592 | 0.0988
512 | 0.0445 | 0.0447 | 0.1150
1024 | 0.0634 | 0.0640 | 0.1454
2048 | 0.1027 | 0.0637 | 0.1589
4096 | 0.1789 | 0.0631 | 0.1806
8192 | 0.3288 | 0.0655 | 0.2146

We can see that the kernel for token generation still has room to improve.
Limitations
Only right-side padding and unidirectional attention are supported.
The following are not supported in the first version:
(1) Packed mode like PackedMultiHeadAttention, where padding has been removed from the input.
(2) Paged attention.
(3) Bidirectional attention.
(4) GPU compute capability other than 7.5, 8.0, 8.6 and 8.9.
(5) Left-side padding.
Some of these limitations will be removed in the future (possibly in a new operator).
Motivation and Context