MmaFromSmem[A100]: Accept transposed operand A #540

Merged 6 commits into gh/danthe3rd/57/base on Nov 29, 2022

Conversation

@danthe3rd (Contributor) commented Nov 25, 2022

Stack from ghstack (oldest at bottom):

SUMMARY
Load `tmp.transpose()` directly from `tmp` in shared memory (transposing as we load), so we no longer need to store both `tmp` and `tmp.T` in shared memory.
Because we use less shared memory, we can fit bigger block sizes: going from 64x128 to 128x128 gives a ~15% perf improvement (for K > 64).
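To give intuition for the change (a minimal standalone sketch, not the PR's actual CUTLASS iterator): once a tile is staged in shared memory, it can be read both row-wise and column-wise simply by swapping indices at load time, so a second, transposed copy never has to be materialized. The kernel and tile size below are hypothetical, chosen only for illustration; as rough arithmetic, a 128x128 f16 tile occupies 32 KiB, so dropping the extra `tmp.T` copy frees that much shared memory per tile.

```
// Illustrative sketch of transpose-as-we-load (not the PR's iterator; the
// real change lives in the CUTLASS warp-level MMA-from-shared-memory path).
#include <cuda_runtime.h>

constexpr int TILE = 32;  // hypothetical tile size for this example

__global__ void read_tile_both_ways(const float* src, float* out, float* out_t) {
  // +1 padding on the inner dimension avoids shared-memory bank conflicts
  // for the column-wise (transposed) reads.
  __shared__ float tmp[TILE][TILE + 1];

  const int r = threadIdx.y;
  const int c = threadIdx.x;

  tmp[r][c] = src[r * TILE + c];  // stage the tile once; no tmp.T copy
  __syncthreads();

  out[r * TILE + c]   = tmp[r][c];  // read as tmp
  out_t[r * TILE + c] = tmp[c][r];  // read as tmp.T: transpose at load time
}

// Launch with a TILE x TILE thread block, e.g.
//   read_tile_both_ways<<<1, dim3(TILE, TILE)>>>(d_src, d_out, d_out_t);
```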

PERF TEST (A100)

BW A100 (f16)
```
[------------------------------------ attention backward (attn_bias=<class 'NoneType'>) -------------------------------------]
                                     |  57_tmpT_b516aec4[cutlass]  |  flash[flshatt]  |  vanilla  |  56_base_02bf6b4e[cutlass]
1 threads: -------------------------------------------------------------------------------------------------------------------
      f16 B=384, M=197, H=1, K=88    |              578.3          |                  |   2263.9  |              609.7        
      f16 B=384, M=197, H=1, K=80    |              547.6          |                  |   1921.9  |              572.6        
      f16 B=384, M=197, H=1, K=64    |              365.4          |        232.6     |   1808.2  |              386.0        
      f16 B=1024, M=197, H=1, K=88   |             1456.3          |                  |   5964.7  |             1539.8        
      f16 B=1024, M=197, H=1, K=80   |             1382.9          |                  |   5037.8  |             1464.9        
      f16 B=1024, M=197, H=1, K=64   |              822.1          |        576.4     |   4732.2  |              862.5        
      f16 B=512, M=197, H=1, K=80    |              695.4          |                  |   2543.6  |              730.0        
      f16 B=32, M=197, H=16, K=80    |              691.2          |                  |   2567.7  |              716.3        
      f16 B=32, M=197, H=16, K=64    |              427.5          |        296.1     |   2428.1  |              456.6        
      f16 B=32, M=197, H=16, K=128   |              858.1          |        682.6     |   4488.4  |              853.5        
      f16 B=256, M=197, H=1, K=88    |              422.0          |                  |   1528.5  |              442.7        
      f16 B=16, M=197, H=16, K=88    |              420.4          |                  |   1543.0  |              437.3        
      f16 B=16, M=197, H=16, K=64    |              217.5          |        165.2     |   1243.5  |              232.6        
      f16 B=16, M=197, H=16, K=128   |              479.8          |        385.5     |   2263.8  |              489.2        
      f16 B=1, M=4096, H=160, K=128  |            51009.8          |      54670.3     |  45924.2  |            63431.9        
      f16 B=2, M=4096, H=160, K=128  |            84491.6          |      84261.8     |           |           100393.5        
      f16 B=1, M=8192, H=160, K=128  |           201456.7          |     215540.9     |           |           251825.4        
      f16 B=2, M=8192, H=160, K=128  |           329735.0          |     330316.3     |           |           395279.2        
      f16 B=1024, M=82, H=8, K=64    |             1764.0          |       1620.9     |   3822.6  |             1857.4        
      f16 B=150, M=256, H=16, K=64   |             2021.6          |       1626.3     |   4557.3  |             2103.9        
      f16 B=64, M=256, H=12, K=64    |              699.4          |        567.8     |   1498.4  |              730.6        
      f16 B=1, M=4096, H=16, K=40    |            22788.9          |                  |   4195.6  |            23624.7        
      f16 B=1, M=16384, H=16, K=40   |           408280.7          |                  |           |           436163.8        
      f16 B=256, M=4096, H=16, K=64  |           565651.1          |     439946.4     |           |           602642.6        
      f16 B=16, M=128, H=16, K=16    |              121.9          |        139.6     |    331.3  |              121.9        
      f16 B=16, M=128, H=16, K=32    |              121.4          |        139.2     |    331.5  |              122.5        
      f16 B=16, M=128, H=16, K=64    |              121.9          |        140.0     |    369.9  |              187.9        
      f16 B=16, M=128, H=16, K=128   |              186.7          |        170.3     |    332.9  |              177.8        
      f16 B=16, M=512, H=16, K=16    |              518.4          |        322.2     |   1204.6  |              556.4        
      f16 B=16, M=512, H=16, K=32    |              602.5          |        435.1     |   1306.5  |              652.2        
      f16 B=16, M=512, H=16, K=64    |              797.0          |        704.9     |   1547.1  |              850.2        
      f16 B=16, M=512, H=16, K=128   |             1544.8          |       1584.6     |   1985.3  |             1752.3        
      f16 B=16, M=1024, H=16, K=16   |             2049.1          |       1244.7     |   4261.7  |             2239.9        
      f16 B=16, M=1024, H=16, K=32   |             2229.0          |       1620.4     |   4492.3  |             2448.0        
      f16 B=16, M=1024, H=16, K=64   |             2817.6          |       2367.6     |   4998.2  |             3041.0        
      f16 B=16, M=1024, H=16, K=128  |             5433.3          |       5638.9     |   5958.5  |             6406.4        
      f16 B=64, M=128, H=16, K=16    |              158.2          |        145.5     |    439.7  |              161.9        
      f16 B=64, M=128, H=16, K=32    |              205.2          |        212.4     |    545.2  |              206.7        
      f16 B=64, M=128, H=16, K=64    |              314.6          |        311.5     |    767.7  |              326.0        
      f16 B=64, M=128, H=16, K=128   |              651.9          |        562.8     |   1227.5  |              613.3        
      f16 B=64, M=512, H=16, K=16    |             1872.3          |       1204.0     |   4488.6  |             1985.3        
      f16 B=64, M=512, H=16, K=32    |             2185.4          |       1543.8     |   4971.7  |             2340.3        
      f16 B=64, M=512, H=16, K=64    |             2940.4          |       2421.0     |   5885.5  |             3077.9        
      f16 B=64, M=512, H=16, K=128   |             5501.3          |       5446.7     |   7711.0  |             6153.0        
      f16 B=64, M=1024, H=16, K=16   |             7318.5          |       4711.4     |  16891.1  |             7890.2        
      f16 B=64, M=1024, H=16, K=32   |             8151.5          |       5697.1     |  17885.4  |             8849.9        
      f16 B=64, M=1024, H=16, K=64   |            10477.7          |       8155.9     |  19951.2  |            11059.9        
      f16 B=64, M=1024, H=16, K=128  |            19178.4          |      19198.4     |  23794.0  |            21939.1        

Times are in microseconds (us).
```

@facebook-github-bot added the CLA Signed label Nov 25, 2022
danthe3rd pushed a commit that referenced this pull request Nov 25, 2022
ghstack-source-id: 66214435cf46227e4488190b7cf43dfd26de065e
Pull Request resolved: #540
danthe3rd pushed a commit that referenced this pull request Nov 25, 2022
ghstack-source-id: 258bcebead02ef44dc87eaf4fbd09fd6d9e3d3f5
Pull Request resolved: #540
danthe3rd pushed a commit that referenced this pull request Nov 28, 2022
ghstack-source-id: d532fc410b24d7a27bdfeb0071a6ab1d55c712a8
Pull Request resolved: #540
danthe3rd pushed a commit that referenced this pull request Nov 28, 2022
ghstack-source-id: d654cf5fdfaccf2b4597b3a3867863b1eaf02afa
Pull Request resolved: #540
danthe3rd pushed a commit that referenced this pull request Nov 28, 2022
ghstack-source-id: 7496be13903f8115557c2d0fe77b93bff69cafca
Pull Request resolved: #540
@danthe3rd marked this pull request as ready for review November 28, 2022 13:41
@codecov-commenter commented Nov 28, 2022

Codecov Report

Base: 89.79% // Head: 89.79% // No change to project coverage 👍

Coverage data is based on head (8505b8a) compared to base (059cfdf).
Patch has no changes to coverable lines.

Additional details and impacted files
```
@@                  Coverage Diff                  @@
##           gh/danthe3rd/57/base     #540   +/-   ##
=====================================================
  Coverage                 89.79%   89.79%           
=====================================================
  Files                        80       80           
  Lines                      4839     4839           
=====================================================
  Hits                       4345     4345           
  Misses                      494      494           

Flag     Coverage   Δ
Python   89.79%     <ø> (ø)
```

Flags with carried forward coverage won't be shown.


@fmassa (Contributor) left a comment

This is a great speed-up, thanks a lot Daniel!

I just have a high-level question about what was changed in one file vs. the reference implementation.

Also, I suppose our tests stress-test all the configurations needed to validate that the new iterator works well for differently sized dimensions (which are not nice multiples)?

@@ -0,0 +1,241 @@
#pragma once

Can you describe briefly what was changed in the code compared to the baseline implementation?

@danthe3rd merged commit 2c3a7da into gh/danthe3rd/57/base Nov 29, 2022
danthe3rd pushed a commit that referenced this pull request Nov 29, 2022
ghstack-source-id: 1364e7de510a82b3d21b044b9cc093be101ba510
Pull Request resolved: #540
@danthe3rd deleted the gh/danthe3rd/57/head branch November 29, 2022 17:34