
[CI] Test jfc4050's pr587 #606

Closed
wants to merge 8 commits

Conversation

danthe3rd
Contributor

@danthe3rd danthe3rd commented Dec 19, 2022

To run CI on #587

git fetch jfc4050 && git push oss jfc4050/attn-bias-pr-2:test_jfc4050

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Dec 19, 2022
Adds attention-bias support (including the bias gradient) and dropout support to the CUTLASS FlashAttention implementation.
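To make the change concrete, here is a minimal PyTorch sketch of the computation the kernel has to match: attention with an additive bias and dropout on the attention weights. The shapes, the `attention_reference` name, and the reading of the benchmark tuples as (batch, seq, heads, head_dim, dtype, bias shape, bias-requires-grad, dropout p) are assumptions for illustration, not taken from the PR.

```python
import math
import torch

def attention_reference(q, k, v, bias=None, p=0.0):
    # softmax(q @ k^T / sqrt(d) + bias) @ v, with dropout on the weights.
    # Because `bias` enters the graph as an ordinary tensor, autograd produces
    # its gradient automatically -- this is the "bias grad" the kernel now
    # computes in fused form.
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1])
    if bias is not None:
        scores = scores + bias
    attn = torch.nn.functional.dropout(scores.softmax(dim=-1), p=p)
    return attn @ v

# Small float32 CPU shapes for illustration only; the benchmarks above use
# torch.float16 on GPU with shapes like (batch*heads, seq, head_dim).
q = torch.randn(4, 8, 16)
k = torch.randn(4, 8, 16)
v = torch.randn(4, 8, 16)
bias = torch.randn(4, 8, 8, requires_grad=True)  # broadcast over nothing here

out = attention_reference(q, k, v, bias=bias, p=0.0)
out.sum().backward()  # populates bias.grad via autograd
```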

[-------------------------------------------- attn --------------------------------------------]
                                                                        |  reference  |  cutlass
1 threads: -------------------------------------------------------------------------------------
      (8, 512, 64, 128, torch.float16, (512, 512, 512), False, 0.0)     |     12.7    |     7.5
      (8, 512, 64, 128, torch.float16, (512, 512, 512), False, 0.5)     |     15.5    |     9.1
      (8, 512, 64, 128, torch.float16, (512, 512, 512), True, 0.0)      |     12.7    |     7.6
      (8, 512, 64, 128, torch.float16, (512, 512, 512), True, 0.5)      |     15.6    |     9.1
      (8, 512, 64, 128, torch.float16, None, False, 0.0)                |     10.1    |     6.0
      (8, 512, 64, 128, torch.float16, None, False, 0.5)                |     12.7    |     7.5
      (8, 1024, 64, 128, torch.float16, (512, 1024, 1024), False, 0.0)  |     44.3    |    29.1
      (8, 1024, 64, 128, torch.float16, (512, 1024, 1024), False, 0.5)  |     55.0    |    35.1
      (8, 1024, 64, 128, torch.float16, (512, 1024, 1024), True, 0.0)   |     45.1    |    29.4
      (8, 1024, 64, 128, torch.float16, (512, 1024, 1024), True, 0.5)   |     55.6    |    35.3
      (8, 1024, 64, 128, torch.float16, None, False, 0.0)               |     37.0    |    22.6
      (8, 1024, 64, 128, torch.float16, None, False, 0.5)               |     46.8    |    29.0

Times are in milliseconds (ms).
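The tables in this thread have the shape of `torch.utils.benchmark.Compare` output (the `1 threads:` header is its signature). As a hedged sketch of how such a comparison might be produced: the `reference_attn` stand-in below is hypothetical; the real benchmark would time the reference implementation against the CUTLASS kernel, one `description` per column.

```python
import torch
import torch.utils.benchmark as benchmark

# Hypothetical stand-in for one of the implementations being compared.
def reference_attn(q, k, v):
    return (q @ k.transpose(-2, -1)).softmax(dim=-1) @ v

q = k = v = torch.randn(4, 8, 16)

results = []
for name, fn in [("reference", reference_attn)]:
    t = benchmark.Timer(
        stmt="fn(q, k, v)",
        globals={"fn": fn, "q": q, "k": k, "v": v},
        num_threads=1,
        label="attn",              # table title
        sub_label="(4, 8, 16)",    # row key, like the shape tuples above
        description=name,          # column key ("reference" / "cutlass")
    )
    results.append(t.blocked_autorange(min_run_time=0.2))

benchmark.Compare(results).print()  # renders a table in the format above
```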

[------------------------------------------ attn-bwd ------------------------------------------]
                                                                        |  reference  |  cutlass
1 threads: -------------------------------------------------------------------------------------
      (8, 512, 64, 128, torch.float16, (512, 512, 512), False, 0.0)     |     19.3    |    24.1
      (8, 512, 64, 128, torch.float16, (512, 512, 512), False, 0.5)     |     19.4    |    24.6
      (8, 512, 64, 128, torch.float16, (512, 512, 512), True, 0.0)      |     22.3    |    28.7
      (8, 512, 64, 128, torch.float16, (512, 512, 512), True, 0.5)      |     22.4    |    29.0
      (8, 512, 64, 128, torch.float16, None, False, 0.0)                |     19.5    |    22.7
      (8, 512, 64, 128, torch.float16, None, False, 0.5)                |     19.5    |    23.4
      (8, 1024, 64, 128, torch.float16, (512, 1024, 1024), False, 0.0)  |     62.7    |    91.1
      (8, 1024, 64, 128, torch.float16, (512, 1024, 1024), False, 0.5)  |     63.4    |    93.7
      (8, 1024, 64, 128, torch.float16, (512, 1024, 1024), True, 0.0)   |     74.8    |   109.8
      (8, 1024, 64, 128, torch.float16, (512, 1024, 1024), True, 0.5)   |     75.1    |   111.1
      (8, 1024, 64, 128, torch.float16, None, False, 0.0)               |     63.2    |    85.5
      (8, 1024, 64, 128, torch.float16, None, False, 0.5)               |     64.0    |    90.1

Times are in milliseconds (ms).

BEFORE

[------------------------------------------ attn-bwd ------------------------------------------]
                                                                        |  reference  |  cutlass
1 threads: -------------------------------------------------------------------------------------
      (8, 512, 64, 64, torch.float16, (512, 512, 512), False, 0.0)      |      2.8    |     2.4
      (8, 512, 64, 64, torch.float16, (512, 512, 512), False, 0.5)      |      2.8    |     3.3
      (8, 512, 64, 64, torch.float16, (512, 512, 512), True, 0.0)       |      3.4    |     3.2
      (8, 512, 64, 64, torch.float16, (512, 512, 512), True, 0.5)       |      3.4    |     4.2
      (8, 512, 64, 64, torch.float16, None, False, 0.0)                 |      2.8    |     2.0
      (8, 512, 64, 64, torch.float16, None, False, 0.5)                 |      2.8    |     2.9
      (8, 512, 64, 128, torch.float16, (512, 512, 512), False, 0.0)     |      3.6    |     3.9
      (8, 512, 64, 128, torch.float16, (512, 512, 512), False, 0.5)     |      3.6    |     4.8
      (8, 512, 64, 128, torch.float16, (512, 512, 512), True, 0.0)      |      4.2    |     4.8
      (8, 512, 64, 128, torch.float16, (512, 512, 512), True, 0.5)      |      4.2    |     5.6
      (8, 512, 64, 128, torch.float16, None, False, 0.0)                |      3.6    |     3.4
      (8, 512, 64, 128, torch.float16, None, False, 0.5)                |      3.6    |     4.4
      (8, 1024, 64, 64, torch.float16, (512, 1024, 1024), False, 0.0)   |      9.7    |     8.8
      (8, 1024, 64, 64, torch.float16, (512, 1024, 1024), False, 0.5)   |      9.7    |    12.6
      (8, 1024, 64, 64, torch.float16, (512, 1024, 1024), True, 0.0)    |     12.0    |    12.1
      (8, 1024, 64, 64, torch.float16, (512, 1024, 1024), True, 0.5)    |     12.1    |    16.1
      (8, 1024, 64, 64, torch.float16, None, False, 0.0)                |      9.7    |     7.4
      (8, 1024, 64, 64, torch.float16, None, False, 0.5)                |      9.7    |    10.8
      (8, 1024, 64, 128, torch.float16, (512, 1024, 1024), False, 0.0)  |     11.3    |    14.0
      (8, 1024, 64, 128, torch.float16, (512, 1024, 1024), False, 0.5)  |     11.3    |    17.4
      (8, 1024, 64, 128, torch.float16, (512, 1024, 1024), True, 0.0)   |     13.6    |    17.8
      (8, 1024, 64, 128, torch.float16, (512, 1024, 1024), True, 0.5)   |     13.6    |    20.9
      (8, 1024, 64, 128, torch.float16, None, False, 0.0)               |     11.3    |    12.1
      (8, 1024, 64, 128, torch.float16, None, False, 0.5)               |     11.3    |    15.8

Times are in milliseconds (ms).

AFTER

[------------------------------------------ attn-bwd ------------------------------------------]
                                                                        |  reference  |  cutlass
1 threads: -------------------------------------------------------------------------------------
      (8, 512, 64, 64, torch.float16, (512, 512, 512), False, 0.0)      |      2.8    |     2.4
      (8, 512, 64, 64, torch.float16, (512, 512, 512), False, 0.5)      |      2.8    |     3.0
      (8, 512, 64, 64, torch.float16, (512, 512, 512), True, 0.0)       |      3.4    |     3.2
      (8, 512, 64, 64, torch.float16, (512, 512, 512), True, 0.5)       |      3.4    |     3.8
      (8, 512, 64, 64, torch.float16, None, False, 0.0)                 |      2.8    |     2.0
      (8, 512, 64, 64, torch.float16, None, False, 0.5)                 |      2.8    |     2.6
      (8, 512, 64, 128, torch.float16, (512, 512, 512), False, 0.0)     |      3.6    |     3.9
      (8, 512, 64, 128, torch.float16, (512, 512, 512), False, 0.5)     |      3.6    |     4.8
      (8, 512, 64, 128, torch.float16, (512, 512, 512), True, 0.0)      |      4.2    |     4.8
      (8, 512, 64, 128, torch.float16, (512, 512, 512), True, 0.5)      |      4.2    |     5.6
      (8, 512, 64, 128, torch.float16, None, False, 0.0)                |      3.6    |     3.4
      (8, 512, 64, 128, torch.float16, None, False, 0.5)                |      3.6    |     4.4
      (8, 1024, 64, 64, torch.float16, (512, 1024, 1024), False, 0.0)   |      9.7    |     8.8
      (8, 1024, 64, 64, torch.float16, (512, 1024, 1024), False, 0.5)   |      9.7    |    11.4
      (8, 1024, 64, 64, torch.float16, (512, 1024, 1024), True, 0.0)    |     12.0    |    12.1
      (8, 1024, 64, 64, torch.float16, (512, 1024, 1024), True, 0.5)    |     12.1    |    14.6
      (8, 1024, 64, 64, torch.float16, None, False, 0.0)                |      9.7    |     7.4
      (8, 1024, 64, 64, torch.float16, None, False, 0.5)                |      9.7    |     9.6
      (8, 1024, 64, 128, torch.float16, (512, 1024, 1024), False, 0.0)  |     11.3    |    14.1
      (8, 1024, 64, 128, torch.float16, (512, 1024, 1024), False, 0.5)  |     11.3    |    17.4
      (8, 1024, 64, 128, torch.float16, (512, 1024, 1024), True, 0.0)   |     13.6    |    17.8
      (8, 1024, 64, 128, torch.float16, (512, 1024, 1024), True, 0.5)   |     13.6    |    20.9
      (8, 1024, 64, 128, torch.float16, None, False, 0.0)               |     11.3    |    12.1
      (8, 1024, 64, 128, torch.float16, None, False, 0.5)               |     11.3    |    15.8

Times are in milliseconds (ms).
@danthe3rd danthe3rd closed this Jan 18, 2023
@danthe3rd danthe3rd deleted the test_jfc4050 branch February 13, 2023 13:31