MemEff: Accumulate in f32 for bw #467

Merged on Dec 6, 2022 (17 commits)

@danthe3rd (Contributor) commented Oct 6, 2022

Stack from ghstack (oldest at bottom):

PERFORMANCE

This makes performance worse in f16 :(
But I think we need it for numerical stability
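
To illustrate the stability concern, here is a minimal pure-Python sketch of a dot product over f16 inputs whose running sum is kept either in f16 or in f32. It is not the kernel code from this PR: `round_to` and `dot` are hypothetical helpers, and rounding is emulated with the standard `struct` module.

```python
import struct


def round_to(x: float, fmt: str) -> float:
    """Round a Python float to IEEE half ("<e") or single ("<f") precision."""
    return struct.unpack(fmt, struct.pack(fmt, x))[0]


def dot(a, b, acc_fmt: str) -> float:
    """Dot product of f16 inputs; the running sum is rounded to acc_fmt after each add."""
    acc = 0.0
    for x, y in zip(a, b):
        # The product of two f16 values is exact in float64, so only the
        # accumulation step loses precision in this emulation.
        prod = round_to(x, "<e") * round_to(y, "<e")
        acc = round_to(acc + prod, acc_fmt)
    # The result is stored as f16 in both cases.
    return round_to(acc, "<e")


ones = [1.0] * 4096
print("f16 accumulator:", dot(ones, ones, "<e"))
print("f32 accumulator:", dot(ones, ones, "<f"))
```

With an f16 accumulator the sum stops growing at 2048, because 2049 is not representable in f16 and rounds back to 2048. With an f32 accumulator the same inputs give 4096, which is exact.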

bw P100/V100 (f32/f16)
[---------------------------------------- attention backward (attn_bias=<class 'NoneType'>) ----------------------------------------]                                                    
                                                         |  48_accf32_6f8e2f15  |  56_base_02bf6b4e  |  vanilla   |  57_tmpT_b516aec4
1 threads: --------------------------------------------------------------------------------------------------------------------------
  (Quadro_GP100)          f16 B=384, M=197, H=1, K=88    |         6289.6       |        6936.0      |    2183.6  |        6940.0    
                          f32 B=384, M=197, H=1, K=88    |         8793.3       |        9446.6      |    2175.1  |        9429.2    
                          f16 B=384, M=197, H=1, K=80    |         5989.9       |        6596.2      |    2146.8  |        6608.6    
                          f32 B=384, M=197, H=1, K=80    |         8427.1       |        8993.1      |    2134.0  |        9030.5    
                          f16 B=384, M=197, H=1, K=64    |         3347.0       |        3527.4      |    1799.3  |        3538.3    
                          f32 B=384, M=197, H=1, K=64    |         5563.1       |        5984.2      |    1801.6  |        5980.9    
                          f16 B=1024, M=197, H=1, K=88   |        15680.5       |       17424.8      |    5671.9  |       17452.6    
                          f32 B=1024, M=197, H=1, K=88   |        23784.2       |       25542.8      |    5664.0  |       25578.6    
                          f16 B=1024, M=197, H=1, K=80   |        14935.5       |       16581.3      |    5559.5  |       16587.3    
                          f32 B=1024, M=197, H=1, K=80   |        22767.6       |       24354.1      |    5550.7  |       24362.0    
                          f16 B=1024, M=197, H=1, K=64   |         8331.1       |        8671.5      |    4644.9  |        8695.3    
                          f32 B=1024, M=197, H=1, K=64   |        15061.6       |       16148.8      |    4650.0  |       16201.4    
                          f16 B=512, M=197, H=1, K=80    |         7594.7       |        8306.6      |    2824.6  |        8336.5    
                          f32 B=512, M=197, H=1, K=80    |        11857.0       |       12660.2      |    2807.4  |       12709.2    
                          f16 B=32, M=197, H=16, K=80    |         7779.8       |        8342.0      |    2820.8  |        8393.5    
                          f32 B=32, M=197, H=16, K=80    |        11846.5       |       12487.8      |    2806.0  |       12549.1    
                          f16 B=32, M=197, H=16, K=64    |         4258.4       |        4445.6      |    2374.0  |        4461.2    
                          f32 B=32, M=197, H=16, K=64    |         7828.9       |        8434.9      |    2376.0  |        8472.2    
                          f16 B=32, M=197, H=16, K=128   |         9025.1       |        9671.9      |    3159.3  |        9707.2    
                          f32 B=32, M=197, H=16, K=128   |        14139.8       |       14920.7      |    3157.7  |       14928.9    
                          f16 B=256, M=197, H=1, K=88    |         4608.9       |        5118.8      |    1478.4  |        5119.2    
                          f32 B=256, M=197, H=1, K=88    |         6174.0       |        6644.3      |    1477.6  |        6642.7    
                          f16 B=16, M=197, H=16, K=88    |         4618.4       |        5073.4      |    1479.5  |        5062.9    
                          f32 B=16, M=197, H=16, K=88    |         6114.2       |        6471.7      |    1474.9  |        6478.7    
                          f16 B=16, M=197, H=16, K=64    |         2490.3       |        2557.5      |    1225.9  |        2550.0    
                          f32 B=16, M=197, H=16, K=64    |         3918.9       |        4208.2      |    1227.0  |        4195.7    
                          f16 B=16, M=197, H=16, K=128   |         5210.0       |        5649.3      |    1635.2  |        5648.1    
                          f32 B=16, M=197, H=16, K=128   |         7103.1       |        7451.6      |    1645.5  |        7445.1    
                          f16 B=1, M=4096, H=160, K=128  |      1014229.1       |     1106182.6      |            |     1108040.6    
                          f32 B=1, M=4096, H=160, K=128  |      1258173.2       |     1243183.8      |            |     1241548.7    
                          f16 B=2, M=4096, H=160, K=128  |      1642279.2       |     1753736.9      |            |     1771655.3    
                          f32 B=2, M=4096, H=160, K=128  |      2505435.4       |     2477353.1      |            |     2473773.4    
                          f16 B=1, M=8192, H=160, K=128  |      4050128.4       |     4415962.1      |            |     4428194.9    
                          f32 B=1, M=8192, H=160, K=128  |      5042352.6       |     4970582.0      |            |     4965069.9    
                          f16 B=2, M=8192, H=160, K=128  |      6600732.3       |     7026378.5      |            |     7068051.8    
                          f16 B=1024, M=82, H=8, K=64    |        21572.8       |       22531.6      |    9059.1  |       22418.4    
                          f32 B=1024, M=82, H=8, K=64    |        38178.4       |       45927.2      |    9070.3  |       45708.0    
                          f16 B=150, M=256, H=16, K=64   |        21436.5       |       21927.4      |   12938.5  |       22001.0    
                          f32 B=150, M=256, H=16, K=64   |        33024.2       |       33196.3      |   13249.2  |       33199.6    
                          f16 B=64, M=256, H=12, K=64    |         6869.8       |        7048.6      |    4200.6  |        7073.6    
                          f32 B=64, M=256, H=12, K=64    |        10719.9       |       10832.1      |    4271.3  |       10843.6    
                          f16 B=1, M=4096, H=16, K=40    |       134722.6       |      145429.8      |   20587.0  |      143743.8    
                          f32 B=1, M=4096, H=16, K=40    |       143015.1       |      147850.6      |   20625.8  |      148272.8    
                          f16 B=1, M=16384, H=16, K=40   |      2149850.4       |     2323732.4      |            |     2301489.4    
                          f32 B=1, M=16384, H=16, K=40   |      2286478.1       |     2369812.4      |            |     2375911.8    
                          f16 B=16, M=128, H=16, K=16    |          497.2       |         502.9      |     623.8  |         503.6    
                          f32 B=16, M=128, H=16, K=16    |          573.5       |         609.8      |     617.5  |         611.6    
                          f16 B=16, M=128, H=16, K=32    |          563.9       |         573.1      |     624.2  |         575.2    
                          f32 B=16, M=128, H=16, K=32    |          661.5       |         702.2      |     620.4  |         703.2    
                          f16 B=16, M=128, H=16, K=64    |          708.9       |         722.6      |     619.9  |         724.3    
                          f32 B=16, M=128, H=16, K=64    |          916.2       |         953.5      |     618.6  |         953.6    
                          f16 B=16, M=128, H=16, K=128   |         1465.8       |        1542.8      |     624.2  |        1545.7    
                          f32 B=16, M=128, H=16, K=128   |         1829.1       |        1872.3      |     616.2  |        1873.9    
                          f16 B=16, M=128, H=16, K=256   |         3796.3       |        4002.5      |    1010.7  |        4008.4    
                          f32 B=16, M=128, H=16, K=256   |         3838.8       |        3957.1      |    1203.6  |        3951.8    
                          f16 B=16, M=512, H=16, K=16    |         7680.4       |        7775.1      |    4848.7  |        7752.8    
                          f32 B=16, M=512, H=16, K=16    |         9402.0       |        9926.1      |    4926.7  |        9914.5    
                          f16 B=16, M=512, H=16, K=32    |         8897.5       |        8999.0      |    5055.9  |        9003.9    
                          f32 B=16, M=512, H=16, K=32    |        10762.1       |       11320.9      |    5065.4  |       11341.8    
                          f16 B=16, M=512, H=16, K=64    |        10936.9       |       11214.8      |    5484.9  |       11254.4    
                          f32 B=16, M=512, H=16, K=64    |        15091.9       |       15210.3      |    5552.0  |       15196.3    
                          f16 B=16, M=512, H=16, K=128   |        23491.0       |       25234.8      |    7317.1  |       25300.3    
                          f32 B=16, M=512, H=16, K=128   |        30524.6       |       30320.0      |    7487.0  |       30200.6    
                          f16 B=16, M=512, H=16, K=256   |        50389.7       |       54525.2      |   14015.3  |       54931.1    
                          f32 B=16, M=512, H=16, K=256   |        62155.6       |       61046.4      |   14285.5  |       61081.1    
                          f16 B=16, M=1024, H=16, K=16   |        31289.9       |       31951.9      |   18778.4  |       32044.9    
                          f32 B=16, M=1024, H=16, K=16   |        37744.4       |       39586.6      |   18929.9  |       39739.9    
                          f16 B=16, M=1024, H=16, K=32   |        35770.8       |       36651.1      |   19620.2  |       36909.0    
                          f32 B=16, M=1024, H=16, K=32   |        43211.5       |       45544.5      |   19518.4  |       45664.8    
                          f16 B=16, M=1024, H=16, K=64   |        43865.1       |       44864.8      |   21286.0  |       45345.7    
                          f32 B=16, M=1024, H=16, K=64   |        60710.5       |       60966.9      |   21634.1  |       60921.9    
                          f16 B=16, M=1024, H=16, K=128  |        94633.7       |      101796.0      |   28502.1  |      102871.6    
                          f32 B=16, M=1024, H=16, K=128  |       124093.6       |      122196.9      |   28520.3  |      122043.0    
                          f16 B=16, M=1024, H=16, K=256  |       194780.7       |      212303.0      |   55419.1  |      214643.5    
                          f32 B=16, M=1024, H=16, K=256  |       250799.1       |      245196.2      |   55634.0  |      245534.2    
                          f16 B=64, M=128, H=16, K=16    |         1658.0       |        1661.4      |    1331.0  |        1662.7    
                          f32 B=64, M=128, H=16, K=16    |         2129.2       |        2266.2      |    1371.8  |        2269.4    
                          f16 B=64, M=128, H=16, K=32    |         1904.3       |        1903.5      |    1384.2  |        1906.4    
                          f32 B=64, M=128, H=16, K=32    |         2496.5       |        2643.4      |    1445.1  |        2639.8    
                          f16 B=64, M=128, H=16, K=64    |         2393.1       |        2432.3      |    1505.2  |        2437.8    
                          f32 B=64, M=128, H=16, K=64    |         3471.1       |        3561.2      |    1590.0  |        3565.7    
                          f16 B=64, M=128, H=16, K=128   |         4969.1       |        5266.3      |    1988.4  |        5272.4    
                          f32 B=64, M=128, H=16, K=128   |         6880.3       |        7040.6      |    2121.0  |        7024.3    
                          f16 B=64, M=128, H=16, K=256   |        12635.3       |       13334.7      |    3859.5  |       13350.7    
                          f32 B=64, M=128, H=16, K=256   |        14185.7       |       14514.4      |    4553.8  |       14503.1    
                          f16 B=64, M=512, H=16, K=16    |        26278.4       |       26189.9      |   18916.0  |       26110.0    
                          f32 B=64, M=512, H=16, K=16    |        35128.2       |       37075.9      |   19191.4  |       37174.5    
                          f16 B=64, M=512, H=16, K=32    |        30414.2       |       30938.2      |   19734.7  |       31071.7    
                          f32 B=64, M=512, H=16, K=32    |        40483.2       |       42922.5      |   19843.9  |       42868.5    
                          f16 B=64, M=512, H=16, K=64    |        37248.2       |       38179.7      |   21640.1  |       38327.0    
                          f32 B=64, M=512, H=16, K=64    |        57666.8       |       57675.3      |   21909.1  |       57697.2    
                          f16 B=64, M=512, H=16, K=128   |        80113.6       |       86165.7      |   28765.5  |       86368.3    
                          f32 B=64, M=512, H=16, K=128   |       115672.5       |      115161.0      |   28910.3  |      115320.0    
                          f16 B=64, M=512, H=16, K=256   |       169250.0       |      183791.8      |   56315.7  |      183735.0    
                          f32 B=64, M=512, H=16, K=256   |       236594.6       |      233093.6      |   56853.4  |      233170.2    
                          f16 B=64, M=1024, H=16, K=16   |       106022.3       |      109588.2      |   74303.6  |      109410.4    
                          f32 B=64, M=1024, H=16, K=16   |       141241.8       |      148854.1      |            |      149651.5    
                          f16 B=64, M=1024, H=16, K=32   |       120899.1       |      125044.6      |   77828.6  |      125716.9    
                          f32 B=64, M=1024, H=16, K=32   |       162478.5       |      173906.1      |            |      173216.4    
                          f16 B=64, M=1024, H=16, K=64   |       149044.1       |      152290.6      |   85821.6  |      152748.5    
                          f32 B=64, M=1024, H=16, K=64   |       233195.9       |      231479.6      |            |      231533.9    
                          f16 B=64, M=1024, H=16, K=128  |       319761.5       |      344076.5      |  113579.4  |      345058.5    
                          f32 B=64, M=1024, H=16, K=128  |       470172.3       |      466330.7      |            |      463507.4    
                          f16 B=64, M=1024, H=16, K=256  |       658362.1       |      717070.2      |            |      723057.4    
                          f32 B=64, M=1024, H=16, K=256  |       955624.0       |      935114.3      |            |      935945.4    
  (Tesla_V100_SXM2_16GB)  f16 B=384, M=197, H=1, K=88    |         1811.3       |        1686.3      |    1375.2  |        1699.6    
                          f32 B=384, M=197, H=1, K=88    |         4315.7       |        4665.1      |    2257.1  |        4663.1    
                          f16 B=384, M=197, H=1, K=80    |         1733.0       |        1616.3      |    1281.7  |        1619.8    
                          f32 B=384, M=197, H=1, K=80    |         3965.0       |        4226.4      |    2171.7  |        4228.3    
                          f16 B=384, M=197, H=1, K=64    |         1135.2       |        1084.4      |    1043.8  |        1083.0    
                          f32 B=384, M=197, H=1, K=64    |         2673.2       |        2883.0      |    1744.5  |        2878.2    
                          f16 B=1024, M=197, H=1, K=88   |         4721.0       |        4396.6      |    3725.3  |        4404.1    
                          f32 B=1024, M=197, H=1, K=88   |        10531.3       |       11443.8      |    6106.5  |       11464.2    
                          f16 B=1024, M=197, H=1, K=80   |         4520.2       |        4216.2      |    3329.2  |        4223.7    
                          f32 B=1024, M=197, H=1, K=80   |         9573.5       |       10301.4      |    5757.6  |       10305.3    
                          f16 B=1024, M=197, H=1, K=64   |         2788.1       |        2660.1      |    2674.6  |        2663.5    
                          f32 B=1024, M=197, H=1, K=64   |         6556.7       |        7102.4      |    4516.6  |        7096.8    
                          f16 B=512, M=197, H=1, K=80    |         2377.2       |        2228.1      |    1685.0  |        2231.4    
                          f32 B=512, M=197, H=1, K=80    |         5259.3       |        5639.7      |    2887.8  |        5636.6    
                          f16 B=32, M=197, H=16, K=80    |         2403.0       |        2201.1      |    1798.2  |        2204.9    
                          f32 B=32, M=197, H=16, K=80    |         5402.3       |        5667.7      |    3046.3  |        5662.4    
                          f16 B=32, M=197, H=16, K=64    |         1552.7       |        1486.3      |    1451.6  |        1485.2    
                          f32 B=32, M=197, H=16, K=64    |         3622.5       |        3911.0      |    2427.0  |        3912.9    
                          f16 B=32, M=197, H=16, K=128   |         2776.6       |        2611.9      |    2211.3  |        2613.3    
                          f32 B=32, M=197, H=16, K=128   |         6647.9       |        7082.3      |    4088.5  |        7104.8    
                          f16 B=256, M=197, H=1, K=88    |         1357.5       |        1285.5      |     941.3  |        1287.6    
                          f32 B=256, M=197, H=1, K=88    |         2874.1       |        3085.1      |    1543.2  |        3090.1    
                          f16 B=16, M=197, H=16, K=88    |         1349.1       |        1264.5      |     964.7  |        1263.1    
                          f32 B=16, M=197, H=16, K=88    |         2803.8       |        2967.0      |    1647.4  |        2972.0    
                          f16 B=16, M=197, H=16, K=64    |          765.5       |         728.4      |     765.3  |         731.0    
                          f32 B=16, M=197, H=16, K=64    |         1834.1       |        1969.7      |    1282.4  |        1974.4    
                          f16 B=16, M=197, H=16, K=128   |         1509.1       |        1432.6      |    1139.1  |        1433.9    
                          f32 B=16, M=197, H=16, K=128   |         3406.5       |        3606.2      |    2048.6  |        3613.3    
                          f16 B=1, M=4096, H=160, K=128  |       168807.1       |      148652.9      |            |      149343.8    
                          f32 B=1, M=4096, H=160, K=128  |       549864.6       |      586699.9      |            |      585699.9    
                          f16 B=2, M=4096, H=160, K=128  |       339010.5       |      298827.7      |            |      298808.4    
                          f32 B=2, M=4096, H=160, K=128  |      1106963.8       |     1176218.3      |            |     1179173.2    
                          f16 B=1, M=8192, H=160, K=128  |       679742.4       |      594323.2      |            |      595580.1    
                          f32 B=1, M=8192, H=160, K=128  |      2195491.3       |     2340248.4      |            |     2343505.7    
                          f16 B=2, M=8192, H=160, K=128  |      1364983.7       |     1193787.8      |            |     1192596.1    
                          f16 B=1024, M=82, H=8, K=64    |         9052.8       |        8762.6      |    5804.2  |        8757.8    
                          f32 B=1024, M=82, H=8, K=64    |        14726.3       |       16270.6      |   11059.7  |       16215.9    
                          f16 B=150, M=256, H=16, K=64   |         5662.9       |        5519.4      |    7557.8  |        5524.5    
                          f32 B=150, M=256, H=16, K=64   |        16700.3       |       17612.7      |   16426.0  |       17640.4    
                          f16 B=64, M=256, H=12, K=64    |         1849.1       |        1793.1      |    2383.5  |        1798.7    
                          f32 B=64, M=256, H=12, K=64    |         5451.5       |        5766.4      |    4975.3  |        5775.8    
                          f16 B=1, M=4096, H=16, K=40    |        47263.3       |       47850.4      |    8315.6  |       47777.2    
                          f32 B=1, M=4096, H=16, K=40    |       113099.4       |      113164.5      |   19536.0  |      113930.0    
                          f16 B=1, M=16384, H=16, K=40   |       757091.8       |      770365.7      |            |      765401.3    
                          f32 B=1, M=16384, H=16, K=40   |      1806827.7       |     1816302.6      |            |     1819162.1    
                          f16 B=16, M=128, H=16, K=16    |          219.2       |         218.5      |     480.6  |         231.8    
                          f32 B=16, M=128, H=16, K=16    |          301.2       |         308.2      |     498.9  |         307.9    
                          f16 B=16, M=128, H=16, K=32    |          227.6       |         215.4      |     473.8  |         220.1    
                          f32 B=16, M=128, H=16, K=32    |          395.8       |         401.0      |     455.1  |         400.8    
                          f16 B=16, M=128, H=16, K=64    |          225.6       |         214.8      |     510.1  |         229.9    
                          f32 B=16, M=128, H=16, K=64    |          561.7       |         583.1      |     598.9  |         581.2    
                          f16 B=16, M=128, H=16, K=128   |          404.6       |         392.0      |     524.0  |         394.8    
                          f32 B=16, M=128, H=16, K=128   |         1103.0       |        1140.8      |    1015.4  |        1142.0    
                          f16 B=16, M=128, H=16, K=256   |         1045.3       |        1049.1      |     889.6  |        1047.4    
                          f32 B=16, M=128, H=16, K=256   |         2181.3       |        2270.9      |    1869.2  |        2265.7    
                          f16 B=16, M=512, H=16, K=16    |         1731.5       |        1585.7      |    1908.5  |        1586.2    
                          f32 B=16, M=512, H=16, K=16    |         4513.8       |        4696.7      |    4222.1  |        4695.6    
                          f16 B=16, M=512, H=16, K=32    |         1942.2       |        1823.4      |    2086.3  |        1809.7    
                          f32 B=16, M=512, H=16, K=32    |         5596.1       |        5819.4      |    4588.1  |        5833.6    
                          f16 B=16, M=512, H=16, K=64    |         2450.2       |        2340.8      |    2580.6  |        2353.2    
                          f32 B=16, M=512, H=16, K=64    |         7619.5       |        7853.1      |    5536.8  |        7875.1    
                          f16 B=16, M=512, H=16, K=128   |         4884.8       |        4473.5      |    3388.7  |        4487.9    
                          f32 B=16, M=512, H=16, K=128   |        15010.0       |       15557.1      |    8979.4  |       15513.1    
                          f16 B=16, M=512, H=16, K=256   |        12973.9       |       11134.7      |    5418.2  |       11106.3    
                          f32 B=16, M=512, H=16, K=256   |        29751.1       |       30979.1      |   16856.3  |       31012.8    
                          f16 B=16, M=1024, H=16, K=16   |         6802.6       |        6185.1      |    6996.0  |        6191.0    
                          f32 B=16, M=1024, H=16, K=16   |        18098.2       |       18954.5      |   16129.1  |       19166.8    
                          f16 B=16, M=1024, H=16, K=32   |         7531.9       |        7065.0      |    7436.0  |        7067.5    
                          f32 B=16, M=1024, H=16, K=32   |        21999.9       |       22901.5      |   17040.0  |       22899.4    
                          f16 B=16, M=1024, H=16, K=64   |         9312.0       |        8854.3      |    8605.6  |        8863.9    
                          f32 B=16, M=1024, H=16, K=64   |        29837.7       |       30895.1      |   20355.3  |       30878.5    
                          f16 B=16, M=1024, H=16, K=128  |        18979.0       |       16951.0      |   10561.3  |       16995.2    
                          f32 B=16, M=1024, H=16, K=128  |        58738.3       |       60861.2      |   33427.0  |       60599.8    
                          f16 B=16, M=1024, H=16, K=256  |        49681.9       |       41833.3      |   17329.5  |       41921.6    
                          f32 B=16, M=1024, H=16, K=256  |       117362.4       |      121004.8      |   60515.8  |      122046.9    
                          f16 B=64, M=128, H=16, K=16    |          432.2       |         411.1      |     642.1  |         411.3    
                          f32 B=64, M=128, H=16, K=16    |         1028.9       |        1057.9      |    1233.4  |        1056.6    
                          f16 B=64, M=128, H=16, K=32    |          522.5       |         500.7      |     813.6  |         499.7    
                          f32 B=64, M=128, H=16, K=32    |         1403.5       |        1443.4      |    1535.6  |        1443.6    
                          f16 B=64, M=128, H=16, K=64    |          750.0       |         739.8      |    1185.6  |         741.0    
                          f32 B=64, M=128, H=16, K=64    |         2013.9       |        2110.1      |    2156.8  |        2105.4    
                          f16 B=64, M=128, H=16, K=128   |         1421.2       |        1387.9      |    1915.5  |        1388.3    
                          f32 B=64, M=128, H=16, K=128   |         3946.5       |        4156.7      |    3780.1  |        4156.3    
                          f16 B=64, M=128, H=16, K=256   |         3811.5       |        3810.6      |    3448.7  |        3811.0    
                          f32 B=64, M=128, H=16, K=256   |         7983.9       |        8432.7      |    7304.0  |        8415.5    
                          f16 B=64, M=512, H=16, K=16    |         6157.4       |        5523.9      |    7461.8  |        5528.1    
                          f32 B=64, M=512, H=16, K=16    |        16118.2       |       17004.7      |   16651.7  |       16984.5    
                          f16 B=64, M=512, H=16, K=32    |         6985.1       |        6483.1      |    8278.5  |        6495.1    
                          f32 B=64, M=512, H=16, K=32    |        20471.9       |       21230.7      |   18420.5  |       21269.0    
                          f16 B=64, M=512, H=16, K=64    |         9045.8       |        8633.5      |   10337.1  |        8674.1    
                          f32 B=64, M=512, H=16, K=64    |        27675.0       |       29282.8      |   22883.7  |       29136.9    
                          f16 B=64, M=512, H=16, K=128   |        17594.2       |       15805.9      |   14700.6  |       15788.2    
                          f32 B=64, M=512, H=16, K=128   |        54612.4       |       57815.7      |   39974.3  |       57951.8    
                          f16 B=64, M=512, H=16, K=256   |        47452.6       |       40093.5      |   27087.5  |       40175.6    
                          f32 B=64, M=512, H=16, K=256   |       108880.3       |      115953.1      |   77794.2  |      115951.0    
                          f16 B=64, M=1024, H=16, K=16   |        24369.0       |       21533.0      |   28448.6  |       21556.6    
                          f32 B=64, M=1024, H=16, K=16   |        64649.7       |       68791.8      |            |       68407.4    
                          f16 B=64, M=1024, H=16, K=32   |        27143.1       |       25683.7      |   30252.7  |       25727.9    
                          f32 B=64, M=1024, H=16, K=32   |        79967.5       |       83351.5      |            |       83084.7    
                          f16 B=64, M=1024, H=16, K=64   |        34667.0       |       32592.7      |   36991.2  |       32659.8    
                          f32 B=64, M=1024, H=16, K=64   |       108282.2       |      113858.1      |            |      114286.0    
                          f16 B=64, M=1024, H=16, K=128  |        68519.5       |       59757.3      |   48834.4  |       59817.7    
                          f32 B=64, M=1024, H=16, K=128  |       215465.9       |      227335.8      |            |      227204.0    
                          f16 B=64, M=1024, H=16, K=256  |       183070.4       |      150960.1      |            |      150947.1    
                          f32 B=64, M=1024, H=16, K=256  |       425832.9       |      453717.2      |            |      453349.2    

Times are in microseconds (us).
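
One way to read the table above is to compare the `48_accf32` column against `56_base` row by row. The sketch below does that for a single shape (f16, B=384, M=197, H=1, K=88) using timings copied from the table; `relative_change` is a hypothetical helper, not part of the benchmark script.

```python
# Times in microseconds, copied from the "f16 B=384, M=197, H=1, K=88" rows above.
rows = {
    "Quadro_GP100": {"accf32": 6289.6, "base": 6936.0},
    "Tesla_V100_SXM2_16GB": {"accf32": 1811.3, "base": 1686.3},
}


def relative_change(new_us: float, base_us: float) -> float:
    """Percentage change of new_us relative to base_us (negative means faster)."""
    return (new_us - base_us) / base_us * 100.0


for gpu, t in rows.items():
    change = relative_change(t["accf32"], t["base"])
    print(f"{gpu}: {change:+.1f}% vs base")
```

For this shape, f32 accumulation is about 9% faster than base on the GP100 and about 7% slower on the V100.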
bw A100 (f32/f16)
[----------------------------------------------------- attention backward (attn_bias=<class 'NoneType'>) -----------------------------------------------------]                                                       
                                     |  48_accf32_69654fdb[cutlass]  |  flash[flshatt]  |  vanilla   |  56_base_02bf6b4e[cutlass]  |  57_tmpT_b516aec4[cutlass]
1 threads: ----------------------------------------------------------------------------------------------------------------------------------------------------
      f16 B=384, M=197, H=1, K=88    |               613.4           |                  |    2264.9  |              612.8          |              578.3        
      f32 B=384, M=197, H=1, K=88    |              2335.2           |                  |    1843.0  |             2438.4          |             2425.1        
      f16 B=384, M=197, H=1, K=80    |               583.6           |                  |    1922.9  |              577.6          |              548.4        
      f32 B=384, M=197, H=1, K=80    |              2241.3           |                  |    1787.8  |             2333.2          |             2333.1        
      f16 B=384, M=197, H=1, K=64    |               405.6           |        232.5     |    1809.8  |              386.4          |              366.1        
      f32 B=384, M=197, H=1, K=64    |              1259.7           |                  |    1675.6  |             1309.6          |             1316.8        
      f16 B=1024, M=197, H=1, K=88   |              1538.2           |                  |    5964.7  |             1550.8          |             1454.4        
      f32 B=1024, M=197, H=1, K=88   |              6031.4           |                  |    4559.5  |             6325.4          |             6332.7        
      f16 B=1024, M=197, H=1, K=80   |              1458.8           |                  |    5038.2  |             1463.8          |             1379.0        
      f32 B=1024, M=197, H=1, K=80   |              5786.2           |                  |    4412.6  |             6059.0          |             6079.9        
      f16 B=1024, M=197, H=1, K=64   |               929.5           |        575.9     |    4735.6  |              862.1          |              821.6        
      f32 B=1024, M=197, H=1, K=64   |              3289.9           |                  |    4119.9  |             3434.4          |             3441.4        
      f16 B=512, M=197, H=1, K=80    |               744.5           |                  |    2544.7  |              735.3          |              695.8        
      f32 B=512, M=197, H=1, K=80    |              2889.4           |                  |    2286.1  |             3008.8          |             3029.0        
      f16 B=32, M=197, H=16, K=80    |               741.6           |                  |    2569.0  |              723.3          |              693.4        
      f32 B=32, M=197, H=16, K=80    |              2878.6           |                  |    2355.3  |             3003.0          |             3024.1        
      f16 B=32, M=197, H=16, K=64    |               478.5           |        295.7     |    2429.2  |              456.1          |              426.7        
      f32 B=32, M=197, H=16, K=64    |              1784.4           |                  |    2196.9  |             1863.7          |             1860.2        
      f16 B=32, M=197, H=16, K=128   |               887.0           |        682.3     |    4492.9  |              857.5          |              853.8        
      f32 B=32, M=197, H=16, K=128   |              3546.6           |                  |    2807.3  |             3734.9          |             3737.1        
      f16 B=256, M=197, H=1, K=88    |               445.1           |                  |    1528.9  |              445.0          |              422.2        
      f32 B=256, M=197, H=1, K=88    |              1678.8           |                  |    1207.6  |             1746.6          |             1752.8        
      f16 B=16, M=197, H=16, K=88    |               441.5           |                  |    1544.2  |              437.3          |              419.5        
      f32 B=16, M=197, H=16, K=88    |              1668.3           |                  |    1250.4  |             1742.7          |             1746.1        
      f16 B=16, M=197, H=16, K=64    |               247.4           |        165.6     |    1242.5  |              233.2          |              217.5        
      f32 B=16, M=197, H=16, K=64    |              1051.1           |                  |    1125.3  |             1099.2          |             1096.6        
      f16 B=16, M=197, H=16, K=128   |               498.4           |        386.2     |    2264.5  |              488.0          |              480.5        
      f32 B=16, M=197, H=16, K=128   |              1950.2           |                  |    1446.7  |             2039.0          |             2028.6        
      f16 B=1, M=4096, H=160, K=128  |             55915.0           |      54620.4     |   45909.5  |            63407.6          |            51298.8        
      f32 B=1, M=4096, H=160, K=128  |            238514.6           |                  |            |           232677.5          |           232672.5        
      f16 B=2, M=4096, H=160, K=128  |             93612.0           |      84238.0     |            |           100433.1          |            84858.1        
      f32 B=2, M=4096, H=160, K=128  |            375037.4           |                  |            |           364234.1          |           364407.2        
      f16 B=1, M=8192, H=160, K=128  |            223261.8           |     215499.8     |            |           251806.6          |           202133.7        
      f32 B=1, M=8192, H=160, K=128  |            946708.9           |                  |            |           924986.2          |           924988.9        
      f16 B=2, M=8192, H=160, K=128  |            367969.2           |     330092.8     |            |           395881.3          |           332000.6        
      f32 B=2, M=8192, H=160, K=128  |           1492031.4           |                  |            |          1448691.1          |          1449146.1        
      f16 B=1024, M=82, H=8, K=64    |              1890.2           |       1620.3     |    3819.7  |             1861.3          |             1764.6        
      f32 B=1024, M=82, H=8, K=64    |              8428.2           |                  |    8735.1  |             8831.3          |             8867.4        
      f16 B=150, M=256, H=16, K=64   |              2292.3           |       1625.3     |    4555.9  |             2109.8          |             2019.4        
      f32 B=150, M=256, H=16, K=64   |              6252.4           |                  |   12948.1  |             6288.9          |             6281.3        
      f16 B=64, M=256, H=12, K=64    |               782.2           |        567.4     |    1498.0  |              731.2          |              699.7        
      f32 B=64, M=256, H=12, K=64    |              2141.2           |                  |    4266.6  |             2160.6          |             2160.2        
      f16 B=1, M=4096, H=16, K=40    |             23504.2           |                  |    4196.1  |            23699.0          |            23008.9        
      f32 B=1, M=4096, H=16, K=40    |             73699.5           |                  |   17755.3  |            73261.8          |            73078.5        
      f16 B=1, M=16384, H=16, K=40   |            391408.9           |                  |            |           439777.3          |           407653.7        
      f32 B=1, M=16384, H=16, K=40   |           1196173.6           |                  |            |          1181547.9          |          1181625.1        
      f16 B=256, M=4096, H=16, K=64  |            733221.8           |     439627.5     |            |           603237.1          |           565905.8        
      f16 B=16, M=128, H=16, K=16    |               130.0           |        113.1     |     265.2  |              125.5          |              124.4        
      f32 B=16, M=128, H=16, K=16    |               161.5           |                  |     373.1  |              160.6          |              162.8        
      f16 B=16, M=128, H=16, K=32    |               125.8           |        111.5     |     263.8  |              122.1          |              125.8        
      f32 B=16, M=128, H=16, K=32    |               189.8           |                  |     412.6  |              196.3          |              196.2        
      f16 B=16, M=128, H=16, K=64    |               126.0           |        112.4     |     265.8  |              120.8          |              125.7        
      f32 B=16, M=128, H=16, K=64    |               272.1           |                  |     498.7  |              285.9          |              283.3        
      f16 B=16, M=128, H=16, K=128   |               181.3           |        158.4     |     298.5  |              178.0          |              186.5        
      f32 B=16, M=128, H=16, K=128   |               509.6           |                  |     673.9  |              521.1          |              521.5        
      f16 B=16, M=128, H=16, K=256   |               774.0           |                  |     541.4  |              775.5          |              757.0        
      f32 B=16, M=128, H=16, K=256   |               975.2           |                  |    1162.6  |              994.5          |              994.5        
      f16 B=16, M=512, H=16, K=16    |               621.0           |        322.6     |    1204.9  |              555.0          |              519.9        
      f32 B=16, M=512, H=16, K=16    |              2148.0           |                  |    4414.4  |             2178.9          |             2180.4        
      f16 B=16, M=512, H=16, K=32    |               709.9           |        435.6     |    1306.2  |              653.1          |              602.8        
      f32 B=16, M=512, H=16, K=32    |              2335.6           |                  |    4640.7  |             2336.0          |             2336.3        
      f16 B=16, M=512, H=16, K=64    |               917.8           |        702.7     |    1545.8  |              849.9          |              797.4        
      f32 B=16, M=512, H=16, K=64    |              2965.5           |                  |    5125.0  |             2986.9          |             2988.3        
      f16 B=16, M=512, H=16, K=128   |              1644.0           |       1584.1     |    1983.4  |             1757.0          |             1548.0        
      f32 B=16, M=512, H=16, K=128   |              6152.7           |                  |    6099.1  |             6067.7          |             6068.6        
      f16 B=16, M=512, H=16, K=256   |              8178.3           |                  |    2899.1  |             7895.5          |             7977.4        
      f32 B=16, M=512, H=16, K=256   |             11894.0           |                  |   10639.5  |            11635.0          |            11624.7        
      f16 B=16, M=1024, H=16, K=16   |              2420.5           |       1240.8     |    4259.2  |             2234.6          |             2048.9        
      f32 B=16, M=1024, H=16, K=16   |              8467.2           |                  |   16650.4  |             8512.2          |             8510.2        
      f16 B=16, M=1024, H=16, K=32   |              2675.7           |       1618.9     |    4491.2  |             2441.3          |             2230.0        
      f32 B=16, M=1024, H=16, K=32   |              9010.0           |                  |   17301.0  |             9012.1          |             9015.7        
      f16 B=16, M=1024, H=16, K=64   |              3328.4           |       2370.3     |    4994.5  |             3032.2          |             2820.0        
      f32 B=16, M=1024, H=16, K=64   |             11566.8           |                  |   18714.2  |            11494.1          |            11492.8        
      f16 B=16, M=1024, H=16, K=128  |              5867.9           |       5632.5     |    5952.8  |             6401.1          |             5440.8        
      f32 B=16, M=1024, H=16, K=128  |             23345.4           |                  |   21523.7  |            22859.1          |            22870.3        
      f16 B=16, M=1024, H=16, K=256  |             30619.2           |                  |    7893.1  |            29884.9          |            29060.4        
      f32 B=16, M=1024, H=16, K=256  |             45211.4           |                  |   38093.0  |            43435.3          |            43423.8        
      f16 B=64, M=128, H=16, K=16    |               159.6           |        145.2     |     439.9  |              161.2          |              167.0        
      f32 B=64, M=128, H=16, K=16    |               493.4           |                  |    1270.0  |              502.7          |              503.2        
      f16 B=64, M=128, H=16, K=32    |               208.5           |        212.1     |     545.3  |              206.2          |              204.9        
      f32 B=64, M=128, H=16, K=32    |               601.2           |                  |    1427.0  |              610.5          |              610.6        
      f16 B=64, M=128, H=16, K=64    |               329.0           |        310.8     |     766.0  |              327.2          |              314.7        
      f32 B=64, M=128, H=16, K=64    |               867.7           |                  |    1743.5  |              889.2          |              889.0        
      f16 B=64, M=128, H=16, K=128   |               635.5           |        562.1     |    1226.7  |              613.3          |              650.7        
      f32 B=64, M=128, H=16, K=128   |              1774.4           |                  |    2386.5  |             1800.0          |             1799.6        
      f16 B=64, M=128, H=16, K=256   |              2839.8           |                  |    2122.8  |             2765.2          |             2821.9        
      f32 B=64, M=128, H=16, K=256   |              3419.2           |                  |    4320.7  |             3458.8          |             3457.0        
      f16 B=64, M=512, H=16, K=16    |              2316.2           |       1202.3     |    4487.1  |             1983.9          |             1868.4        
      f32 B=64, M=512, H=16, K=16    |              6686.0           |                  |   16991.5  |             6709.9          |             6713.2        
      f16 B=64, M=512, H=16, K=32    |              2701.7           |       1541.9     |    4975.9  |             2346.4          |             2184.3        
      f32 B=64, M=512, H=16, K=32    |              7460.5           |                  |   17859.9  |             7429.6          |             7438.0        
      f16 B=64, M=512, H=16, K=64    |              3461.2           |       2418.2     |    5886.1  |             3083.1          |             2942.6        
      f32 B=64, M=512, H=16, K=64    |              9553.4           |                  |   19768.6  |             9526.5          |             9516.6        
      f16 B=64, M=512, H=16, K=128   |              5875.5           |       5443.1     |    7711.0  |             6141.3          |             5513.6        
      f32 B=64, M=512, H=16, K=128   |             21317.0           |                  |   23651.2  |            20931.9          |            20932.3        
      f16 B=64, M=512, H=16, K=256   |             31238.2           |                  |   11490.6  |            28626.8          |            28954.0        
      f32 B=64, M=512, H=16, K=256   |             41124.8           |                  |   42468.0  |            39562.8          |            39579.5        
      f16 B=64, M=1024, H=16, K=16   |              9142.8           |       4707.2     |   16882.8  |             7887.0          |             7314.5        
      f32 B=64, M=1024, H=16, K=16   |             26512.3           |                  |   66311.7  |            26497.5          |            26459.8        
      f16 B=64, M=1024, H=16, K=32   |             10420.8           |       5698.3     |   17875.0  |             8851.4          |             7974.0        
      f32 B=64, M=1024, H=16, K=32   |             28300.0           |                  |   69088.7  |            28102.1          |            28098.2        
      f16 B=64, M=1024, H=16, K=64   |             12948.5           |       8119.0     |   19944.3  |            11064.2          |            10477.6        
      f32 B=64, M=1024, H=16, K=64   |             35600.6           |                  |   74762.8  |            35316.4          |            35361.2        
      f16 B=64, M=1024, H=16, K=128  |             20820.5           |      19184.2     |   23699.3  |            21954.8          |            19220.0        
      f32 B=64, M=1024, H=16, K=128  |             80800.3           |                  |   86003.8  |            78521.9          |            78393.9        
      f16 B=64, M=1024, H=16, K=256  |            114411.1           |                  |   32958.3  |           103287.2          |           104304.2        
      f32 B=64, M=1024, H=16, K=256  |            155731.5           |                  |  153011.0  |           148071.6          |           148165.6        

Times are in microseconds (us).
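To compare two columns of a table like the one above, a row can be split on `|` and the timings turned into a relative change. This is only a sketch: `parse_row` and `relative_change` are hypothetical helper names, not part of the benchmark tooling, and choosing the first and last timing columns is an assumption made for illustration. The row is copied from the table.

```python
from typing import List, Optional, Tuple


def parse_row(line: str) -> Tuple[str, List[Optional[float]]]:
    """Split one benchmark row into its label and per-column timings (us).

    Empty cells (configurations that were not measured) become None.
    """
    label, *cells = [cell.strip() for cell in line.split("|")]
    return label, [float(cell) if cell else None for cell in cells]


def relative_change(new: float, old: float) -> float:
    """Signed relative change of `new` with respect to `old` (negative = faster)."""
    return (new - old) / old


# One row copied from the table above; timings are in microseconds.
row = (
    "f16 B=384, M=197, H=1, K=64    |               405.6           |"
    "        232.5     |    1809.8  |              386.4          |"
    "              366.1        "
)

label, timings = parse_row(row)
# Assumption: compare the first and the last timing column of the row.
change = relative_change(timings[-1], timings[0])
print(f"{label}: {change:+.1%}")
```

A negative value means the last column is faster than the first for that configuration. Rows with an empty cell in either column would need to be skipped before calling `relative_change`.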

@facebook-github-bot facebook-github-bot added the CLA Signed label Oct 6, 2022
danthe3rd pushed a commit that referenced this pull request Oct 6, 2022
ghstack-source-id: aff4f021abbadcf565a40f952eea873c7e5d3f09
Pull Request resolved: #467
danthe3rd pushed a commit that referenced this pull request Oct 7, 2022
ghstack-source-id: e35fdbfbadbdb88090dd7632719b6c3f071568a0
Pull Request resolved: #467
@danthe3rd danthe3rd changed the title bwaccf32: Accumulate in f32 for bw [WIP] bwaccf32: Accumulate in f32 for bw Oct 7, 2022
danthe3rd pushed a commit that referenced this pull request Nov 28, 2022
ghstack-source-id: 64b7aef3cdfba4dcaea31c651bb663a6421faec0
Pull Request resolved: #467
danthe3rd pushed a commit that referenced this pull request Nov 28, 2022
ghstack-source-id: 48369de3f8b94eb3c190ac2b0a1b3ddf6003e5ff
Pull Request resolved: #467
danthe3rd pushed a commit that referenced this pull request Nov 29, 2022
ghstack-source-id: cded45d653d00d9147d31099040a18b66403b064
Pull Request resolved: #467
@codecov-commenter

codecov-commenter commented Nov 29, 2022

Codecov Report

Base: 89.79% // Head: 89.79% // No change to project coverage 👍

Coverage data is based on head (7538362) compared to base (5447130).
Patch has no changes to coverable lines.

Additional details and impacted files
@@                  Coverage Diff                  @@
##           gh/danthe3rd/48/base     #467   +/-   ##
=====================================================
  Coverage                 89.79%   89.79%           
=====================================================
  Files                        80       80           
  Lines                      4839     4839           
=====================================================
  Hits                       4345     4345           
  Misses                      494      494           
Flag Coverage Δ
Python 89.79% <ø> (ø)


danthe3rd pushed a commit that referenced this pull request Nov 29, 2022
ghstack-source-id: 037be211463d6b313f7c86804f6a89799f355d9d
Pull Request resolved: #467
@danthe3rd danthe3rd changed the title [WIP] bwaccf32: Accumulate in f32 for bw MemEff: Accumulate in f32 for bw Nov 29, 2022
danthe3rd pushed a commit that referenced this pull request Nov 29, 2022
ghstack-source-id: d13ff3e407510b98ea4daf9cf38a74d966fa4d59
Pull Request resolved: #467
**PERFORMANCE**

<details>
<summary>bw P100/V100 (f32/f16)</summary>

```
[---------------------------------------- attention backward (attn_bias=<class 'NoneType'>) ----------------------------------------]                                                    
                                                         |  48_accf32_6f8e2f15  |  56_base_02bf6b4e  |  vanilla   |  57_tmpT_b516aec4
1 threads: --------------------------------------------------------------------------------------------------------------------------
  (Quadro_GP100)          f16 B=384, M=197, H=1, K=88    |         6289.6       |        6936.0      |    2183.6  |        6940.0    
                          f32 B=384, M=197, H=1, K=88    |         8793.3       |        9446.6      |    2175.1  |        9429.2    
                          f16 B=384, M=197, H=1, K=80    |         5989.9       |        6596.2      |    2146.8  |        6608.6    
                          f32 B=384, M=197, H=1, K=80    |         8427.1       |        8993.1      |    2134.0  |        9030.5    
                          f16 B=384, M=197, H=1, K=64    |         3347.0       |        3527.4      |    1799.3  |        3538.3    
                          f32 B=384, M=197, H=1, K=64    |         5563.1       |        5984.2      |    1801.6  |        5980.9    
                          f16 B=1024, M=197, H=1, K=88   |        15680.5       |       17424.8      |    5671.9  |       17452.6    
                          f32 B=1024, M=197, H=1, K=88   |        23784.2       |       25542.8      |    5664.0  |       25578.6    
                          f16 B=1024, M=197, H=1, K=80   |        14935.5       |       16581.3      |    5559.5  |       16587.3    
                          f32 B=1024, M=197, H=1, K=80   |        22767.6       |       24354.1      |    5550.7  |       24362.0    
                          f16 B=1024, M=197, H=1, K=64   |         8331.1       |        8671.5      |    4644.9  |        8695.3    
                          f32 B=1024, M=197, H=1, K=64   |        15061.6       |       16148.8      |    4650.0  |       16201.4    
                          f16 B=512, M=197, H=1, K=80    |         7594.7       |        8306.6      |    2824.6  |        8336.5    
                          f32 B=512, M=197, H=1, K=80    |        11857.0       |       12660.2      |    2807.4  |       12709.2    
                          f16 B=32, M=197, H=16, K=80    |         7779.8       |        8342.0      |    2820.8  |        8393.5    
                          f32 B=32, M=197, H=16, K=80    |        11846.5       |       12487.8      |    2806.0  |       12549.1    
                          f16 B=32, M=197, H=16, K=64    |         4258.4       |        4445.6      |    2374.0  |        4461.2    
                          f32 B=32, M=197, H=16, K=64    |         7828.9       |        8434.9      |    2376.0  |        8472.2    
                          f16 B=32, M=197, H=16, K=128   |         9025.1       |        9671.9      |    3159.3  |        9707.2    
                          f32 B=32, M=197, H=16, K=128   |        14139.8       |       14920.7      |    3157.7  |       14928.9    
                          f16 B=256, M=197, H=1, K=88    |         4608.9       |        5118.8      |    1478.4  |        5119.2    
                          f32 B=256, M=197, H=1, K=88    |         6174.0       |        6644.3      |    1477.6  |        6642.7    
                          f16 B=16, M=197, H=16, K=88    |         4618.4       |        5073.4      |    1479.5  |        5062.9    
                          f32 B=16, M=197, H=16, K=88    |         6114.2       |        6471.7      |    1474.9  |        6478.7    
                          f16 B=16, M=197, H=16, K=64    |         2490.3       |        2557.5      |    1225.9  |        2550.0    
                          f32 B=16, M=197, H=16, K=64    |         3918.9       |        4208.2      |    1227.0  |        4195.7    
                          f16 B=16, M=197, H=16, K=128   |         5210.0       |        5649.3      |    1635.2  |        5648.1    
                          f32 B=16, M=197, H=16, K=128   |         7103.1       |        7451.6      |    1645.5  |        7445.1    
                          f16 B=1, M=4096, H=160, K=128  |      1014229.1       |     1106182.6      |            |     1108040.6    
                          f32 B=1, M=4096, H=160, K=128  |      1258173.2       |     1243183.8      |            |     1241548.7    
                          f16 B=2, M=4096, H=160, K=128  |      1642279.2       |     1753736.9      |            |     1771655.3    
                          f32 B=2, M=4096, H=160, K=128  |      2505435.4       |     2477353.1      |            |     2473773.4    
                          f16 B=1, M=8192, H=160, K=128  |      4050128.4       |     4415962.1      |            |     4428194.9    
                          f32 B=1, M=8192, H=160, K=128  |      5042352.6       |     4970582.0      |            |     4965069.9    
                          f16 B=2, M=8192, H=160, K=128  |      6600732.3       |     7026378.5      |            |     7068051.8    
                          f16 B=1024, M=82, H=8, K=64    |        21572.8       |       22531.6      |    9059.1  |       22418.4    
                          f32 B=1024, M=82, H=8, K=64    |        38178.4       |       45927.2      |    9070.3  |       45708.0    
                          f16 B=150, M=256, H=16, K=64   |        21436.5       |       21927.4      |   12938.5  |       22001.0    
                          f32 B=150, M=256, H=16, K=64   |        33024.2       |       33196.3      |   13249.2  |       33199.6    
                          f16 B=64, M=256, H=12, K=64    |         6869.8       |        7048.6      |    4200.6  |        7073.6    
                          f32 B=64, M=256, H=12, K=64    |        10719.9       |       10832.1      |    4271.3  |       10843.6    
                          f16 B=1, M=4096, H=16, K=40    |       134722.6       |      145429.8      |   20587.0  |      143743.8    
                          f32 B=1, M=4096, H=16, K=40    |       143015.1       |      147850.6      |   20625.8  |      148272.8    
                          f16 B=1, M=16384, H=16, K=40   |      2149850.4       |     2323732.4      |            |     2301489.4    
                          f32 B=1, M=16384, H=16, K=40   |      2286478.1       |     2369812.4      |            |     2375911.8    
                          f16 B=16, M=128, H=16, K=16    |          497.2       |         502.9      |     623.8  |         503.6    
                          f32 B=16, M=128, H=16, K=16    |          573.5       |         609.8      |     617.5  |         611.6    
                          f16 B=16, M=128, H=16, K=32    |          563.9       |         573.1      |     624.2  |         575.2    
                          f32 B=16, M=128, H=16, K=32    |          661.5       |         702.2      |     620.4  |         703.2    
                          f16 B=16, M=128, H=16, K=64    |          708.9       |         722.6      |     619.9  |         724.3    
                          f32 B=16, M=128, H=16, K=64    |          916.2       |         953.5      |     618.6  |         953.6    
                          f16 B=16, M=128, H=16, K=128   |         1465.8       |        1542.8      |     624.2  |        1545.7    
                          f32 B=16, M=128, H=16, K=128   |         1829.1       |        1872.3      |     616.2  |        1873.9    
                          f16 B=16, M=128, H=16, K=256   |         3796.3       |        4002.5      |    1010.7  |        4008.4    
                          f32 B=16, M=128, H=16, K=256   |         3838.8       |        3957.1      |    1203.6  |        3951.8    
                          f16 B=16, M=512, H=16, K=16    |         7680.4       |        7775.1      |    4848.7  |        7752.8    
                          f32 B=16, M=512, H=16, K=16    |         9402.0       |        9926.1      |    4926.7  |        9914.5    
                          f16 B=16, M=512, H=16, K=32    |         8897.5       |        8999.0      |    5055.9  |        9003.9    
                          f32 B=16, M=512, H=16, K=32    |        10762.1       |       11320.9      |    5065.4  |       11341.8    
                          f16 B=16, M=512, H=16, K=64    |        10936.9       |       11214.8      |    5484.9  |       11254.4    
                          f32 B=16, M=512, H=16, K=64    |        15091.9       |       15210.3      |    5552.0  |       15196.3    
                          f16 B=16, M=512, H=16, K=128   |        23491.0       |       25234.8      |    7317.1  |       25300.3    
                          f32 B=16, M=512, H=16, K=128   |        30524.6       |       30320.0      |    7487.0  |       30200.6    
                          f16 B=16, M=512, H=16, K=256   |        50389.7       |       54525.2      |   14015.3  |       54931.1    
                          f32 B=16, M=512, H=16, K=256   |        62155.6       |       61046.4      |   14285.5  |       61081.1    
                          f16 B=16, M=1024, H=16, K=16   |        31289.9       |       31951.9      |   18778.4  |       32044.9    
                          f32 B=16, M=1024, H=16, K=16   |        37744.4       |       39586.6      |   18929.9  |       39739.9    
                          f16 B=16, M=1024, H=16, K=32   |        35770.8       |       36651.1      |   19620.2  |       36909.0    
                          f32 B=16, M=1024, H=16, K=32   |        43211.5       |       45544.5      |   19518.4  |       45664.8    
                          f16 B=16, M=1024, H=16, K=64   |        43865.1       |       44864.8      |   21286.0  |       45345.7    
                          f32 B=16, M=1024, H=16, K=64   |        60710.5       |       60966.9      |   21634.1  |       60921.9    
                          f16 B=16, M=1024, H=16, K=128  |        94633.7       |      101796.0      |   28502.1  |      102871.6    
                          f32 B=16, M=1024, H=16, K=128  |       124093.6       |      122196.9      |   28520.3  |      122043.0    
                          f16 B=16, M=1024, H=16, K=256  |       194780.7       |      212303.0      |   55419.1  |      214643.5    
                          f32 B=16, M=1024, H=16, K=256  |       250799.1       |      245196.2      |   55634.0  |      245534.2    
                          f16 B=64, M=128, H=16, K=16    |         1658.0       |        1661.4      |    1331.0  |        1662.7    
                          f32 B=64, M=128, H=16, K=16    |         2129.2       |        2266.2      |    1371.8  |        2269.4    
                          f16 B=64, M=128, H=16, K=32    |         1904.3       |        1903.5      |    1384.2  |        1906.4    
                          f32 B=64, M=128, H=16, K=32    |         2496.5       |        2643.4      |    1445.1  |        2639.8    
                          f16 B=64, M=128, H=16, K=64    |         2393.1       |        2432.3      |    1505.2  |        2437.8    
                          f32 B=64, M=128, H=16, K=64    |         3471.1       |        3561.2      |    1590.0  |        3565.7    
                          f16 B=64, M=128, H=16, K=128   |         4969.1       |        5266.3      |    1988.4  |        5272.4    
                          f32 B=64, M=128, H=16, K=128   |         6880.3       |        7040.6      |    2121.0  |        7024.3    
                          f16 B=64, M=128, H=16, K=256   |        12635.3       |       13334.7      |    3859.5  |       13350.7    
                          f32 B=64, M=128, H=16, K=256   |        14185.7       |       14514.4      |    4553.8  |       14503.1    
                          f16 B=64, M=512, H=16, K=16    |        26278.4       |       26189.9      |   18916.0  |       26110.0    
                          f32 B=64, M=512, H=16, K=16    |        35128.2       |       37075.9      |   19191.4  |       37174.5    
                          f16 B=64, M=512, H=16, K=32    |        30414.2       |       30938.2      |   19734.7  |       31071.7    
                          f32 B=64, M=512, H=16, K=32    |        40483.2       |       42922.5      |   19843.9  |       42868.5    
                          f16 B=64, M=512, H=16, K=64    |        37248.2       |       38179.7      |   21640.1  |       38327.0    
                          f32 B=64, M=512, H=16, K=64    |        57666.8       |       57675.3      |   21909.1  |       57697.2    
                          f16 B=64, M=512, H=16, K=128   |        80113.6       |       86165.7      |   28765.5  |       86368.3    
                          f32 B=64, M=512, H=16, K=128   |       115672.5       |      115161.0      |   28910.3  |      115320.0    
                          f16 B=64, M=512, H=16, K=256   |       169250.0       |      183791.8      |   56315.7  |      183735.0    
                          f32 B=64, M=512, H=16, K=256   |       236594.6       |      233093.6      |   56853.4  |      233170.2    
                          f16 B=64, M=1024, H=16, K=16   |       106022.3       |      109588.2      |   74303.6  |      109410.4    
                          f32 B=64, M=1024, H=16, K=16   |       141241.8       |      148854.1      |            |      149651.5    
                          f16 B=64, M=1024, H=16, K=32   |       120899.1       |      125044.6      |   77828.6  |      125716.9    
                          f32 B=64, M=1024, H=16, K=32   |       162478.5       |      173906.1      |            |      173216.4    
                          f16 B=64, M=1024, H=16, K=64   |       149044.1       |      152290.6      |   85821.6  |      152748.5    
                          f32 B=64, M=1024, H=16, K=64   |       233195.9       |      231479.6      |            |      231533.9    
                          f16 B=64, M=1024, H=16, K=128  |       319761.5       |      344076.5      |  113579.4  |      345058.5    
                          f32 B=64, M=1024, H=16, K=128  |       470172.3       |      466330.7      |            |      463507.4    
                          f16 B=64, M=1024, H=16, K=256  |       658362.1       |      717070.2      |            |      723057.4    
                          f32 B=64, M=1024, H=16, K=256  |       955624.0       |      935114.3      |            |      935945.4    
  (Tesla_V100_SXM2_16GB)  f16 B=384, M=197, H=1, K=88    |         1811.3       |        1686.3      |    1375.2  |        1699.6    
                          f32 B=384, M=197, H=1, K=88    |         4315.7       |        4665.1      |    2257.1  |        4663.1    
                          f16 B=384, M=197, H=1, K=80    |         1733.0       |        1616.3      |    1281.7  |        1619.8    
                          f32 B=384, M=197, H=1, K=80    |         3965.0       |        4226.4      |    2171.7  |        4228.3    
                          f16 B=384, M=197, H=1, K=64    |         1135.2       |        1084.4      |    1043.8  |        1083.0    
                          f32 B=384, M=197, H=1, K=64    |         2673.2       |        2883.0      |    1744.5  |        2878.2    
                          f16 B=1024, M=197, H=1, K=88   |         4721.0       |        4396.6      |    3725.3  |        4404.1    
                          f32 B=1024, M=197, H=1, K=88   |        10531.3       |       11443.8      |    6106.5  |       11464.2    
                          f16 B=1024, M=197, H=1, K=80   |         4520.2       |        4216.2      |    3329.2  |        4223.7    
                          f32 B=1024, M=197, H=1, K=80   |         9573.5       |       10301.4      |    5757.6  |       10305.3    
                          f16 B=1024, M=197, H=1, K=64   |         2788.1       |        2660.1      |    2674.6  |        2663.5    
                          f32 B=1024, M=197, H=1, K=64   |         6556.7       |        7102.4      |    4516.6  |        7096.8    
                          f16 B=512, M=197, H=1, K=80    |         2377.2       |        2228.1      |    1685.0  |        2231.4    
                          f32 B=512, M=197, H=1, K=80    |         5259.3       |        5639.7      |    2887.8  |        5636.6    
                          f16 B=32, M=197, H=16, K=80    |         2403.0       |        2201.1      |    1798.2  |        2204.9    
                          f32 B=32, M=197, H=16, K=80    |         5402.3       |        5667.7      |    3046.3  |        5662.4    
                          f16 B=32, M=197, H=16, K=64    |         1552.7       |        1486.3      |    1451.6  |        1485.2    
                          f32 B=32, M=197, H=16, K=64    |         3622.5       |        3911.0      |    2427.0  |        3912.9    
                          f16 B=32, M=197, H=16, K=128   |         2776.6       |        2611.9      |    2211.3  |        2613.3    
                          f32 B=32, M=197, H=16, K=128   |         6647.9       |        7082.3      |    4088.5  |        7104.8    
                          f16 B=256, M=197, H=1, K=88    |         1357.5       |        1285.5      |     941.3  |        1287.6    
                          f32 B=256, M=197, H=1, K=88    |         2874.1       |        3085.1      |    1543.2  |        3090.1    
                          f16 B=16, M=197, H=16, K=88    |         1349.1       |        1264.5      |     964.7  |        1263.1    
                          f32 B=16, M=197, H=16, K=88    |         2803.8       |        2967.0      |    1647.4  |        2972.0    
                          f16 B=16, M=197, H=16, K=64    |          765.5       |         728.4      |     765.3  |         731.0    
                          f32 B=16, M=197, H=16, K=64    |         1834.1       |        1969.7      |    1282.4  |        1974.4    
                          f16 B=16, M=197, H=16, K=128   |         1509.1       |        1432.6      |    1139.1  |        1433.9    
                          f32 B=16, M=197, H=16, K=128   |         3406.5       |        3606.2      |    2048.6  |        3613.3    
                          f16 B=1, M=4096, H=160, K=128  |       168807.1       |      148652.9      |            |      149343.8    
                          f32 B=1, M=4096, H=160, K=128  |       549864.6       |      586699.9      |            |      585699.9    
                          f16 B=2, M=4096, H=160, K=128  |       339010.5       |      298827.7      |            |      298808.4    
                          f32 B=2, M=4096, H=160, K=128  |      1106963.8       |     1176218.3      |            |     1179173.2    
                          f16 B=1, M=8192, H=160, K=128  |       679742.4       |      594323.2      |            |      595580.1    
                          f32 B=1, M=8192, H=160, K=128  |      2195491.3       |     2340248.4      |            |     2343505.7    
                          f16 B=2, M=8192, H=160, K=128  |      1364983.7       |     1193787.8      |            |     1192596.1    
                          f16 B=1024, M=82, H=8, K=64    |         9052.8       |        8762.6      |    5804.2  |        8757.8    
                          f32 B=1024, M=82, H=8, K=64    |        14726.3       |       16270.6      |   11059.7  |       16215.9    
                          f16 B=150, M=256, H=16, K=64   |         5662.9       |        5519.4      |    7557.8  |        5524.5    
                          f32 B=150, M=256, H=16, K=64   |        16700.3       |       17612.7      |   16426.0  |       17640.4    
                          f16 B=64, M=256, H=12, K=64    |         1849.1       |        1793.1      |    2383.5  |        1798.7    
                          f32 B=64, M=256, H=12, K=64    |         5451.5       |        5766.4      |    4975.3  |        5775.8    
                          f16 B=1, M=4096, H=16, K=40    |        47263.3       |       47850.4      |    8315.6  |       47777.2    
                          f32 B=1, M=4096, H=16, K=40    |       113099.4       |      113164.5      |   19536.0  |      113930.0    
                          f16 B=1, M=16384, H=16, K=40   |       757091.8       |      770365.7      |            |      765401.3    
                          f32 B=1, M=16384, H=16, K=40   |      1806827.7       |     1816302.6      |            |     1819162.1    
                          f16 B=16, M=128, H=16, K=16    |          219.2       |         218.5      |     480.6  |         231.8    
                          f32 B=16, M=128, H=16, K=16    |          301.2       |         308.2      |     498.9  |         307.9    
                          f16 B=16, M=128, H=16, K=32    |          227.6       |         215.4      |     473.8  |         220.1    
                          f32 B=16, M=128, H=16, K=32    |          395.8       |         401.0      |     455.1  |         400.8    
                          f16 B=16, M=128, H=16, K=64    |          225.6       |         214.8      |     510.1  |         229.9    
                          f32 B=16, M=128, H=16, K=64    |          561.7       |         583.1      |     598.9  |         581.2    
                          f16 B=16, M=128, H=16, K=128   |          404.6       |         392.0      |     524.0  |         394.8    
                          f32 B=16, M=128, H=16, K=128   |         1103.0       |        1140.8      |    1015.4  |        1142.0    
                          f16 B=16, M=128, H=16, K=256   |         1045.3       |        1049.1      |     889.6  |        1047.4    
                          f32 B=16, M=128, H=16, K=256   |         2181.3       |        2270.9      |    1869.2  |        2265.7    
                          f16 B=16, M=512, H=16, K=16    |         1731.5       |        1585.7      |    1908.5  |        1586.2    
                          f32 B=16, M=512, H=16, K=16    |         4513.8       |        4696.7      |    4222.1  |        4695.6    
                          f16 B=16, M=512, H=16, K=32    |         1942.2       |        1823.4      |    2086.3  |        1809.7    
                          f32 B=16, M=512, H=16, K=32    |         5596.1       |        5819.4      |    4588.1  |        5833.6    
                          f16 B=16, M=512, H=16, K=64    |         2450.2       |        2340.8      |    2580.6  |        2353.2    
                          f32 B=16, M=512, H=16, K=64    |         7619.5       |        7853.1      |    5536.8  |        7875.1    
                          f16 B=16, M=512, H=16, K=128   |         4884.8       |        4473.5      |    3388.7  |        4487.9    
                          f32 B=16, M=512, H=16, K=128   |        15010.0       |       15557.1      |    8979.4  |       15513.1    
                          f16 B=16, M=512, H=16, K=256   |        12973.9       |       11134.7      |    5418.2  |       11106.3    
                          f32 B=16, M=512, H=16, K=256   |        29751.1       |       30979.1      |   16856.3  |       31012.8    
                          f16 B=16, M=1024, H=16, K=16   |         6802.6       |        6185.1      |    6996.0  |        6191.0    
                          f32 B=16, M=1024, H=16, K=16   |        18098.2       |       18954.5      |   16129.1  |       19166.8    
                          f16 B=16, M=1024, H=16, K=32   |         7531.9       |        7065.0      |    7436.0  |        7067.5    
                          f32 B=16, M=1024, H=16, K=32   |        21999.9       |       22901.5      |   17040.0  |       22899.4    
                          f16 B=16, M=1024, H=16, K=64   |         9312.0       |        8854.3      |    8605.6  |        8863.9    
                          f32 B=16, M=1024, H=16, K=64   |        29837.7       |       30895.1      |   20355.3  |       30878.5    
                          f16 B=16, M=1024, H=16, K=128  |        18979.0       |       16951.0      |   10561.3  |       16995.2    
                          f32 B=16, M=1024, H=16, K=128  |        58738.3       |       60861.2      |   33427.0  |       60599.8    
                          f16 B=16, M=1024, H=16, K=256  |        49681.9       |       41833.3      |   17329.5  |       41921.6    
                          f32 B=16, M=1024, H=16, K=256  |       117362.4       |      121004.8      |   60515.8  |      122046.9    
                          f16 B=64, M=128, H=16, K=16    |          432.2       |         411.1      |     642.1  |         411.3    
                          f32 B=64, M=128, H=16, K=16    |         1028.9       |        1057.9      |    1233.4  |        1056.6    
                          f16 B=64, M=128, H=16, K=32    |          522.5       |         500.7      |     813.6  |         499.7    
                          f32 B=64, M=128, H=16, K=32    |         1403.5       |        1443.4      |    1535.6  |        1443.6    
                          f16 B=64, M=128, H=16, K=64    |          750.0       |         739.8      |    1185.6  |         741.0    
                          f32 B=64, M=128, H=16, K=64    |         2013.9       |        2110.1      |    2156.8  |        2105.4    
                          f16 B=64, M=128, H=16, K=128   |         1421.2       |        1387.9      |    1915.5  |        1388.3    
                          f32 B=64, M=128, H=16, K=128   |         3946.5       |        4156.7      |    3780.1  |        4156.3    
                          f16 B=64, M=128, H=16, K=256   |         3811.5       |        3810.6      |    3448.7  |        3811.0    
                          f32 B=64, M=128, H=16, K=256   |         7983.9       |        8432.7      |    7304.0  |        8415.5    
                          f16 B=64, M=512, H=16, K=16    |         6157.4       |        5523.9      |    7461.8  |        5528.1    
                          f32 B=64, M=512, H=16, K=16    |        16118.2       |       17004.7      |   16651.7  |       16984.5    
                          f16 B=64, M=512, H=16, K=32    |         6985.1       |        6483.1      |    8278.5  |        6495.1    
                          f32 B=64, M=512, H=16, K=32    |        20471.9       |       21230.7      |   18420.5  |       21269.0    
                          f16 B=64, M=512, H=16, K=64    |         9045.8       |        8633.5      |   10337.1  |        8674.1    
                          f32 B=64, M=512, H=16, K=64    |        27675.0       |       29282.8      |   22883.7  |       29136.9    
                          f16 B=64, M=512, H=16, K=128   |        17594.2       |       15805.9      |   14700.6  |       15788.2    
                          f32 B=64, M=512, H=16, K=128   |        54612.4       |       57815.7      |   39974.3  |       57951.8    
                          f16 B=64, M=512, H=16, K=256   |        47452.6       |       40093.5      |   27087.5  |       40175.6    
                          f32 B=64, M=512, H=16, K=256   |       108880.3       |      115953.1      |   77794.2  |      115951.0    
                          f16 B=64, M=1024, H=16, K=16   |        24369.0       |       21533.0      |   28448.6  |       21556.6    
                          f32 B=64, M=1024, H=16, K=16   |        64649.7       |       68791.8      |            |       68407.4    
                          f16 B=64, M=1024, H=16, K=32   |        27143.1       |       25683.7      |   30252.7  |       25727.9    
                          f32 B=64, M=1024, H=16, K=32   |        79967.5       |       83351.5      |            |       83084.7    
                          f16 B=64, M=1024, H=16, K=64   |        34667.0       |       32592.7      |   36991.2  |       32659.8    
                          f32 B=64, M=1024, H=16, K=64   |       108282.2       |      113858.1      |            |      114286.0    
                          f16 B=64, M=1024, H=16, K=128  |        68519.5       |       59757.3      |   48834.4  |       59817.7    
                          f32 B=64, M=1024, H=16, K=128  |       215465.9       |      227335.8      |            |      227204.0    
                          f16 B=64, M=1024, H=16, K=256  |       183070.4       |      150960.1      |            |      150947.1    
                          f32 B=64, M=1024, H=16, K=256  |       425832.9       |      453717.2      |            |      453349.2    

Times are in microseconds (us).
```
</details>

<details>
<summary>bw A100 (f32/f16)</summary>

```
[----------------------------------------------------- attention backward (attn_bias=<class 'NoneType'>) -----------------------------------------------------]                                                       
                                     |  48_accf32_69654fdb[cutlass]  |  flash[flshatt]  |  vanilla   |  56_base_02bf6b4e[cutlass]  |  57_tmpT_b516aec4[cutlass]
1 threads: ----------------------------------------------------------------------------------------------------------------------------------------------------
      f16 B=384, M=197, H=1, K=88    |               613.4           |                  |    2264.9  |              612.8          |              578.3        
      f32 B=384, M=197, H=1, K=88    |              2335.2           |                  |    1843.0  |             2438.4          |             2425.1        
      f16 B=384, M=197, H=1, K=80    |               583.6           |                  |    1922.9  |              577.6          |              548.4        
      f32 B=384, M=197, H=1, K=80    |              2241.3           |                  |    1787.8  |             2333.2          |             2333.1        
      f16 B=384, M=197, H=1, K=64    |               405.6           |        232.5     |    1809.8  |              386.4          |              366.1        
      f32 B=384, M=197, H=1, K=64    |              1259.7           |                  |    1675.6  |             1309.6          |             1316.8        
      f16 B=1024, M=197, H=1, K=88   |              1538.2           |                  |    5964.7  |             1550.8          |             1454.4        
      f32 B=1024, M=197, H=1, K=88   |              6031.4           |                  |    4559.5  |             6325.4          |             6332.7        
      f16 B=1024, M=197, H=1, K=80   |              1458.8           |                  |    5038.2  |             1463.8          |             1379.0        
      f32 B=1024, M=197, H=1, K=80   |              5786.2           |                  |    4412.6  |             6059.0          |             6079.9        
      f16 B=1024, M=197, H=1, K=64   |               929.5           |        575.9     |    4735.6  |              862.1          |              821.6        
      f32 B=1024, M=197, H=1, K=64   |              3289.9           |                  |    4119.9  |             3434.4          |             3441.4        
      f16 B=512, M=197, H=1, K=80    |               744.5           |                  |    2544.7  |              735.3          |              695.8        
      f32 B=512, M=197, H=1, K=80    |              2889.4           |                  |    2286.1  |             3008.8          |             3029.0        
      f16 B=32, M=197, H=16, K=80    |               741.6           |                  |    2569.0  |              723.3          |              693.4        
      f32 B=32, M=197, H=16, K=80    |              2878.6           |                  |    2355.3  |             3003.0          |             3024.1        
      f16 B=32, M=197, H=16, K=64    |               478.5           |        295.7     |    2429.2  |              456.1          |              426.7        
      f32 B=32, M=197, H=16, K=64    |              1784.4           |                  |    2196.9  |             1863.7          |             1860.2        
      f16 B=32, M=197, H=16, K=128   |               887.0           |        682.3     |    4492.9  |              857.5          |              853.8        
      f32 B=32, M=197, H=16, K=128   |              3546.6           |                  |    2807.3  |             3734.9          |             3737.1        
      f16 B=256, M=197, H=1, K=88    |               445.1           |                  |    1528.9  |              445.0          |              422.2        
      f32 B=256, M=197, H=1, K=88    |              1678.8           |                  |    1207.6  |             1746.6          |             1752.8        
      f16 B=16, M=197, H=16, K=88    |               441.5           |                  |    1544.2  |              437.3          |              419.5        
      f32 B=16, M=197, H=16, K=88    |              1668.3           |                  |    1250.4  |             1742.7          |             1746.1        
      f16 B=16, M=197, H=16, K=64    |               247.4           |        165.6     |    1242.5  |              233.2          |              217.5        
      f32 B=16, M=197, H=16, K=64    |              1051.1           |                  |    1125.3  |             1099.2          |             1096.6        
      f16 B=16, M=197, H=16, K=128   |               498.4           |        386.2     |    2264.5  |              488.0          |              480.5        
      f32 B=16, M=197, H=16, K=128   |              1950.2           |                  |    1446.7  |             2039.0          |             2028.6        
      f16 B=1, M=4096, H=160, K=128  |             55915.0           |      54620.4     |   45909.5  |            63407.6          |            51298.8        
      f32 B=1, M=4096, H=160, K=128  |            238514.6           |                  |            |           232677.5          |           232672.5        
      f16 B=2, M=4096, H=160, K=128  |             93612.0           |      84238.0     |            |           100433.1          |            84858.1        
      f32 B=2, M=4096, H=160, K=128  |            375037.4           |                  |            |           364234.1          |           364407.2        
      f16 B=1, M=8192, H=160, K=128  |            223261.8           |     215499.8     |            |           251806.6          |           202133.7        
      f32 B=1, M=8192, H=160, K=128  |            946708.9           |                  |            |           924986.2          |           924988.9        
      f16 B=2, M=8192, H=160, K=128  |            367969.2           |     330092.8     |            |           395881.3          |           332000.6        
      f32 B=2, M=8192, H=160, K=128  |           1492031.4           |                  |            |          1448691.1          |          1449146.1        
      f16 B=1024, M=82, H=8, K=64    |              1890.2           |       1620.3     |    3819.7  |             1861.3          |             1764.6        
      f32 B=1024, M=82, H=8, K=64    |              8428.2           |                  |    8735.1  |             8831.3          |             8867.4        
      f16 B=150, M=256, H=16, K=64   |              2292.3           |       1625.3     |    4555.9  |             2109.8          |             2019.4        
      f32 B=150, M=256, H=16, K=64   |              6252.4           |                  |   12948.1  |             6288.9          |             6281.3        
      f16 B=64, M=256, H=12, K=64    |               782.2           |        567.4     |    1498.0  |              731.2          |              699.7        
      f32 B=64, M=256, H=12, K=64    |              2141.2           |                  |    4266.6  |             2160.6          |             2160.2        
      f16 B=1, M=4096, H=16, K=40    |             23504.2           |                  |    4196.1  |            23699.0          |            23008.9        
      f32 B=1, M=4096, H=16, K=40    |             73699.5           |                  |   17755.3  |            73261.8          |            73078.5        
      f16 B=1, M=16384, H=16, K=40   |            391408.9           |                  |            |           439777.3          |           407653.7        
      f32 B=1, M=16384, H=16, K=40   |           1196173.6           |                  |            |          1181547.9          |          1181625.1        
      f16 B=256, M=4096, H=16, K=64  |            733221.8           |     439627.5     |            |           603237.1          |           565905.8        
      f16 B=16, M=128, H=16, K=16    |               130.0           |        113.1     |     265.2  |              125.5          |              124.4        
      f32 B=16, M=128, H=16, K=16    |               161.5           |                  |     373.1  |              160.6          |              162.8        
      f16 B=16, M=128, H=16, K=32    |               125.8           |        111.5     |     263.8  |              122.1          |              125.8        
      f32 B=16, M=128, H=16, K=32    |               189.8           |                  |     412.6  |              196.3          |              196.2        
      f16 B=16, M=128, H=16, K=64    |               126.0           |        112.4     |     265.8  |              120.8          |              125.7        
      f32 B=16, M=128, H=16, K=64    |               272.1           |                  |     498.7  |              285.9          |              283.3        
      f16 B=16, M=128, H=16, K=128   |               181.3           |        158.4     |     298.5  |              178.0          |              186.5        
      f32 B=16, M=128, H=16, K=128   |               509.6           |                  |     673.9  |              521.1          |              521.5        
      f16 B=16, M=128, H=16, K=256   |               774.0           |                  |     541.4  |              775.5          |              757.0        
      f32 B=16, M=128, H=16, K=256   |               975.2           |                  |    1162.6  |              994.5          |              994.5        
      f16 B=16, M=512, H=16, K=16    |               621.0           |        322.6     |    1204.9  |              555.0          |              519.9        
      f32 B=16, M=512, H=16, K=16    |              2148.0           |                  |    4414.4  |             2178.9          |             2180.4        
      f16 B=16, M=512, H=16, K=32    |               709.9           |        435.6     |    1306.2  |              653.1          |              602.8        
      f32 B=16, M=512, H=16, K=32    |              2335.6           |                  |    4640.7  |             2336.0          |             2336.3        
      f16 B=16, M=512, H=16, K=64    |               917.8           |        702.7     |    1545.8  |              849.9          |              797.4        
      f32 B=16, M=512, H=16, K=64    |              2965.5           |                  |    5125.0  |             2986.9          |             2988.3        
      f16 B=16, M=512, H=16, K=128   |              1644.0           |       1584.1     |    1983.4  |             1757.0          |             1548.0        
      f32 B=16, M=512, H=16, K=128   |              6152.7           |                  |    6099.1  |             6067.7          |             6068.6        
      f16 B=16, M=512, H=16, K=256   |              8178.3           |                  |    2899.1  |             7895.5          |             7977.4        
      f32 B=16, M=512, H=16, K=256   |             11894.0           |                  |   10639.5  |            11635.0          |            11624.7        
      f16 B=16, M=1024, H=16, K=16   |              2420.5           |       1240.8     |    4259.2  |             2234.6          |             2048.9        
      f32 B=16, M=1024, H=16, K=16   |              8467.2           |                  |   16650.4  |             8512.2          |             8510.2        
      f16 B=16, M=1024, H=16, K=32   |              2675.7           |       1618.9     |    4491.2  |             2441.3          |             2230.0        
      f32 B=16, M=1024, H=16, K=32   |              9010.0           |                  |   17301.0  |             9012.1          |             9015.7        
      f16 B=16, M=1024, H=16, K=64   |              3328.4           |       2370.3     |    4994.5  |             3032.2          |             2820.0        
      f32 B=16, M=1024, H=16, K=64   |             11566.8           |                  |   18714.2  |            11494.1          |            11492.8        
      f16 B=16, M=1024, H=16, K=128  |              5867.9           |       5632.5     |    5952.8  |             6401.1          |             5440.8        
      f32 B=16, M=1024, H=16, K=128  |             23345.4           |                  |   21523.7  |            22859.1          |            22870.3        
      f16 B=16, M=1024, H=16, K=256  |             30619.2           |                  |    7893.1  |            29884.9          |            29060.4        
      f32 B=16, M=1024, H=16, K=256  |             45211.4           |                  |   38093.0  |            43435.3          |            43423.8        
      f16 B=64, M=128, H=16, K=16    |               159.6           |        145.2     |     439.9  |              161.2          |              167.0        
      f32 B=64, M=128, H=16, K=16    |               493.4           |                  |    1270.0  |              502.7          |              503.2        
      f16 B=64, M=128, H=16, K=32    |               208.5           |        212.1     |     545.3  |              206.2          |              204.9        
      f32 B=64, M=128, H=16, K=32    |               601.2           |                  |    1427.0  |              610.5          |              610.6        
      f16 B=64, M=128, H=16, K=64    |               329.0           |        310.8     |     766.0  |              327.2          |              314.7        
      f32 B=64, M=128, H=16, K=64    |               867.7           |                  |    1743.5  |              889.2          |              889.0        
      f16 B=64, M=128, H=16, K=128   |               635.5           |        562.1     |    1226.7  |              613.3          |              650.7        
      f32 B=64, M=128, H=16, K=128   |              1774.4           |                  |    2386.5  |             1800.0          |             1799.6        
      f16 B=64, M=128, H=16, K=256   |              2839.8           |                  |    2122.8  |             2765.2          |             2821.9        
      f32 B=64, M=128, H=16, K=256   |              3419.2           |                  |    4320.7  |             3458.8          |             3457.0        
      f16 B=64, M=512, H=16, K=16    |              2316.2           |       1202.3     |    4487.1  |             1983.9          |             1868.4        
      f32 B=64, M=512, H=16, K=16    |              6686.0           |                  |   16991.5  |             6709.9          |             6713.2        
      f16 B=64, M=512, H=16, K=32    |              2701.7           |       1541.9     |    4975.9  |             2346.4          |             2184.3        
      f32 B=64, M=512, H=16, K=32    |              7460.5           |                  |   17859.9  |             7429.6          |             7438.0        
      f16 B=64, M=512, H=16, K=64    |              3461.2           |       2418.2     |    5886.1  |             3083.1          |             2942.6        
      f32 B=64, M=512, H=16, K=64    |              9553.4           |                  |   19768.6  |             9526.5          |             9516.6        
      f16 B=64, M=512, H=16, K=128   |              5875.5           |       5443.1     |    7711.0  |             6141.3          |             5513.6        
      f32 B=64, M=512, H=16, K=128   |             21317.0           |                  |   23651.2  |            20931.9          |            20932.3        
      f16 B=64, M=512, H=16, K=256   |             31238.2           |                  |   11490.6  |            28626.8          |            28954.0        
      f32 B=64, M=512, H=16, K=256   |             41124.8           |                  |   42468.0  |            39562.8          |            39579.5        
      f16 B=64, M=1024, H=16, K=16   |              9142.8           |       4707.2     |   16882.8  |             7887.0          |             7314.5        
      f32 B=64, M=1024, H=16, K=16   |             26512.3           |                  |   66311.7  |            26497.5          |            26459.8        
      f16 B=64, M=1024, H=16, K=32   |             10420.8           |       5698.3     |   17875.0  |             8851.4          |             7974.0        
      f32 B=64, M=1024, H=16, K=32   |             28300.0           |                  |   69088.7  |            28102.1          |            28098.2        
      f16 B=64, M=1024, H=16, K=64   |             12948.5           |       8119.0     |   19944.3  |            11064.2          |            10477.6        
      f32 B=64, M=1024, H=16, K=64   |             35600.6           |                  |   74762.8  |            35316.4          |            35361.2        
      f16 B=64, M=1024, H=16, K=128  |             20820.5           |      19184.2     |   23699.3  |            21954.8          |            19220.0        
      f32 B=64, M=1024, H=16, K=128  |             80800.3           |                  |   86003.8  |            78521.9          |            78393.9        
      f16 B=64, M=1024, H=16, K=256  |            114411.1           |                  |   32958.3  |           103287.2          |           104304.2        
      f32 B=64, M=1024, H=16, K=256  |            155731.5           |                  |  153011.0  |           148071.6          |           148165.6        

Times are in microseconds (us).
```
</details>
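
To read the slowdown off these tables, divide the `48_accf32` time by the `56_base` time for the same row. A minimal sketch using three rows copied from the A100 table above (the dictionary layout and variable names are illustrative and not part of the benchmark script):

```python
# Times in microseconds, copied from the A100 table above:
# (48_accf32 column, 56_base column) for B=64, M=1024, H=16.
rows = {
    "f16 K=64": (12948.5, 11064.2),
    "f32 K=64": (35600.6, 35316.4),
    "f16 K=256": (114411.1, 103287.2),
}

# A ratio above 1 means the f32-accumulation kernel is slower than the baseline.
slowdown = {name: acc / base for name, (acc, base) in rows.items()}

for name, ratio in slowdown.items():
    print(f"{name}: {ratio:.3f}x baseline time")
```

For these three rows the f16 cases come out roughly 11% to 17% slower than the baseline, while the f32 case moves by under 1%.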


danthe3rd pushed a commit that referenced this pull request Nov 29, 2022
ghstack-source-id: 02ddda11d414d63ce8d3693c60e6bf4430c2a5d9
Pull Request resolved: #467
danthe3rd marked this pull request as ready for review November 29, 2022 13:07
**PERFORMANCE**

This makes performance worse in f16 :(
But I think we need it for numerical stability.
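
A minimal sketch of why the accumulator precision matters, assuming NumPy on the CPU rather than the kernel changed in this PR. The `accumulate` helper is hypothetical and only mimics a running sum kept in a given dtype. With an f16 accumulator the total stops growing once it reaches 2048, because 2048 + 1 is not representable in f16 and rounds back to 2048.

```python
import numpy as np


def accumulate(values, acc_dtype):
    """Sum `values` one element at a time, keeping the running total in `acc_dtype`.

    Hypothetical helper for illustration only; it is not code from this PR.
    """
    acc = acc_dtype(0)
    for v in values:
        # Cast the input to the accumulator dtype, add, and round back to it.
        acc = acc_dtype(acc + acc_dtype(v))
    return acc


# 4096 f16 inputs, each exactly 1.0, so the exact sum is 4096.
values = np.ones(4096, dtype=np.float16)

sum_f16 = accumulate(values, np.float16)  # running total kept in f16
sum_f32 = accumulate(values, np.float32)  # running total kept in f32

print("f16 accumulator:", float(sum_f16))
print("f32 accumulator:", float(sum_f32))
```

The inputs stay in f16 in both cases. Only the dtype of the running total changes, which is the trade-off the benchmarks below measure.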

<details>
<summary>bw P100/V100 (f32/f16)</summary>

```
[---------------------------------------- attention backward (attn_bias=<class 'NoneType'>) ----------------------------------------]                                                    
                                                         |  48_accf32_6f8e2f15  |  56_base_02bf6b4e  |  vanilla   |  57_tmpT_b516aec4
1 threads: --------------------------------------------------------------------------------------------------------------------------
  (Quadro_GP100)          f16 B=384, M=197, H=1, K=88    |         6289.6       |        6936.0      |    2183.6  |        6940.0    
                          f32 B=384, M=197, H=1, K=88    |         8793.3       |        9446.6      |    2175.1  |        9429.2    
                          f16 B=384, M=197, H=1, K=80    |         5989.9       |        6596.2      |    2146.8  |        6608.6    
                          f32 B=384, M=197, H=1, K=80    |         8427.1       |        8993.1      |    2134.0  |        9030.5    
                          f16 B=384, M=197, H=1, K=64    |         3347.0       |        3527.4      |    1799.3  |        3538.3    
                          f32 B=384, M=197, H=1, K=64    |         5563.1       |        5984.2      |    1801.6  |        5980.9    
                          f16 B=1024, M=197, H=1, K=88   |        15680.5       |       17424.8      |    5671.9  |       17452.6    
                          f32 B=1024, M=197, H=1, K=88   |        23784.2       |       25542.8      |    5664.0  |       25578.6    
                          f16 B=1024, M=197, H=1, K=80   |        14935.5       |       16581.3      |    5559.5  |       16587.3    
                          f32 B=1024, M=197, H=1, K=80   |        22767.6       |       24354.1      |    5550.7  |       24362.0    
                          f16 B=1024, M=197, H=1, K=64   |         8331.1       |        8671.5      |    4644.9  |        8695.3    
                          f32 B=1024, M=197, H=1, K=64   |        15061.6       |       16148.8      |    4650.0  |       16201.4    
                          f16 B=512, M=197, H=1, K=80    |         7594.7       |        8306.6      |    2824.6  |        8336.5    
                          f32 B=512, M=197, H=1, K=80    |        11857.0       |       12660.2      |    2807.4  |       12709.2    
                          f16 B=32, M=197, H=16, K=80    |         7779.8       |        8342.0      |    2820.8  |        8393.5    
                          f32 B=32, M=197, H=16, K=80    |        11846.5       |       12487.8      |    2806.0  |       12549.1    
                          f16 B=32, M=197, H=16, K=64    |         4258.4       |        4445.6      |    2374.0  |        4461.2    
                          f32 B=32, M=197, H=16, K=64    |         7828.9       |        8434.9      |    2376.0  |        8472.2    
                          f16 B=32, M=197, H=16, K=128   |         9025.1       |        9671.9      |    3159.3  |        9707.2    
                          f32 B=32, M=197, H=16, K=128   |        14139.8       |       14920.7      |    3157.7  |       14928.9    
                          f16 B=256, M=197, H=1, K=88    |         4608.9       |        5118.8      |    1478.4  |        5119.2    
                          f32 B=256, M=197, H=1, K=88    |         6174.0       |        6644.3      |    1477.6  |        6642.7    
                          f16 B=16, M=197, H=16, K=88    |         4618.4       |        5073.4      |    1479.5  |        5062.9    
                          f32 B=16, M=197, H=16, K=88    |         6114.2       |        6471.7      |    1474.9  |        6478.7    
                          f16 B=16, M=197, H=16, K=64    |         2490.3       |        2557.5      |    1225.9  |        2550.0    
                          f32 B=16, M=197, H=16, K=64    |         3918.9       |        4208.2      |    1227.0  |        4195.7    
                          f16 B=16, M=197, H=16, K=128   |         5210.0       |        5649.3      |    1635.2  |        5648.1    
                          f32 B=16, M=197, H=16, K=128   |         7103.1       |        7451.6      |    1645.5  |        7445.1    
                          f16 B=1, M=4096, H=160, K=128  |      1014229.1       |     1106182.6      |            |     1108040.6    
                          f32 B=1, M=4096, H=160, K=128  |      1258173.2       |     1243183.8      |            |     1241548.7    
                          f16 B=2, M=4096, H=160, K=128  |      1642279.2       |     1753736.9      |            |     1771655.3    
                          f32 B=2, M=4096, H=160, K=128  |      2505435.4       |     2477353.1      |            |     2473773.4    
                          f16 B=1, M=8192, H=160, K=128  |      4050128.4       |     4415962.1      |            |     4428194.9    
                          f32 B=1, M=8192, H=160, K=128  |      5042352.6       |     4970582.0      |            |     4965069.9    
                          f16 B=2, M=8192, H=160, K=128  |      6600732.3       |     7026378.5      |            |     7068051.8    
                          f16 B=1024, M=82, H=8, K=64    |        21572.8       |       22531.6      |    9059.1  |       22418.4    
                          f32 B=1024, M=82, H=8, K=64    |        38178.4       |       45927.2      |    9070.3  |       45708.0    
                          f16 B=150, M=256, H=16, K=64   |        21436.5       |       21927.4      |   12938.5  |       22001.0    
                          f32 B=150, M=256, H=16, K=64   |        33024.2       |       33196.3      |   13249.2  |       33199.6    
                          f16 B=64, M=256, H=12, K=64    |         6869.8       |        7048.6      |    4200.6  |        7073.6    
                          f32 B=64, M=256, H=12, K=64    |        10719.9       |       10832.1      |    4271.3  |       10843.6    
                          f16 B=1, M=4096, H=16, K=40    |       134722.6       |      145429.8      |   20587.0  |      143743.8    
                          f32 B=1, M=4096, H=16, K=40    |       143015.1       |      147850.6      |   20625.8  |      148272.8    
                          f16 B=1, M=16384, H=16, K=40   |      2149850.4       |     2323732.4      |            |     2301489.4    
                          f32 B=1, M=16384, H=16, K=40   |      2286478.1       |     2369812.4      |            |     2375911.8    
                          f16 B=16, M=128, H=16, K=16    |          497.2       |         502.9      |     623.8  |         503.6    
                          f32 B=16, M=128, H=16, K=16    |          573.5       |         609.8      |     617.5  |         611.6    
                          f16 B=16, M=128, H=16, K=32    |          563.9       |         573.1      |     624.2  |         575.2    
                          f32 B=16, M=128, H=16, K=32    |          661.5       |         702.2      |     620.4  |         703.2    
                          f16 B=16, M=128, H=16, K=64    |          708.9       |         722.6      |     619.9  |         724.3    
                          f32 B=16, M=128, H=16, K=64    |          916.2       |         953.5      |     618.6  |         953.6    
                          f16 B=16, M=128, H=16, K=128   |         1465.8       |        1542.8      |     624.2  |        1545.7    
                          f32 B=16, M=128, H=16, K=128   |         1829.1       |        1872.3      |     616.2  |        1873.9    
                          f16 B=16, M=128, H=16, K=256   |         3796.3       |        4002.5      |    1010.7  |        4008.4    
                          f32 B=16, M=128, H=16, K=256   |         3838.8       |        3957.1      |    1203.6  |        3951.8    
                          f16 B=16, M=512, H=16, K=16    |         7680.4       |        7775.1      |    4848.7  |        7752.8    
                          f32 B=16, M=512, H=16, K=16    |         9402.0       |        9926.1      |    4926.7  |        9914.5    
                          f16 B=16, M=512, H=16, K=32    |         8897.5       |        8999.0      |    5055.9  |        9003.9    
                          f32 B=16, M=512, H=16, K=32    |        10762.1       |       11320.9      |    5065.4  |       11341.8    
                          f16 B=16, M=512, H=16, K=64    |        10936.9       |       11214.8      |    5484.9  |       11254.4    
                          f32 B=16, M=512, H=16, K=64    |        15091.9       |       15210.3      |    5552.0  |       15196.3    
                          f16 B=16, M=512, H=16, K=128   |        23491.0       |       25234.8      |    7317.1  |       25300.3    
                          f32 B=16, M=512, H=16, K=128   |        30524.6       |       30320.0      |    7487.0  |       30200.6    
                          f16 B=16, M=512, H=16, K=256   |        50389.7       |       54525.2      |   14015.3  |       54931.1    
                          f32 B=16, M=512, H=16, K=256   |        62155.6       |       61046.4      |   14285.5  |       61081.1    
                          f16 B=16, M=1024, H=16, K=16   |        31289.9       |       31951.9      |   18778.4  |       32044.9    
                          f32 B=16, M=1024, H=16, K=16   |        37744.4       |       39586.6      |   18929.9  |       39739.9    
                          f16 B=16, M=1024, H=16, K=32   |        35770.8       |       36651.1      |   19620.2  |       36909.0    
                          f32 B=16, M=1024, H=16, K=32   |        43211.5       |       45544.5      |   19518.4  |       45664.8    
                          f16 B=16, M=1024, H=16, K=64   |        43865.1       |       44864.8      |   21286.0  |       45345.7    
                          f32 B=16, M=1024, H=16, K=64   |        60710.5       |       60966.9      |   21634.1  |       60921.9    
                          f16 B=16, M=1024, H=16, K=128  |        94633.7       |      101796.0      |   28502.1  |      102871.6    
                          f32 B=16, M=1024, H=16, K=128  |       124093.6       |      122196.9      |   28520.3  |      122043.0    
                          f16 B=16, M=1024, H=16, K=256  |       194780.7       |      212303.0      |   55419.1  |      214643.5    
                          f32 B=16, M=1024, H=16, K=256  |       250799.1       |      245196.2      |   55634.0  |      245534.2    
                          f16 B=64, M=128, H=16, K=16    |         1658.0       |        1661.4      |    1331.0  |        1662.7    
                          f32 B=64, M=128, H=16, K=16    |         2129.2       |        2266.2      |    1371.8  |        2269.4    
                          f16 B=64, M=128, H=16, K=32    |         1904.3       |        1903.5      |    1384.2  |        1906.4    
                          f32 B=64, M=128, H=16, K=32    |         2496.5       |        2643.4      |    1445.1  |        2639.8    
                          f16 B=64, M=128, H=16, K=64    |         2393.1       |        2432.3      |    1505.2  |        2437.8    
                          f32 B=64, M=128, H=16, K=64    |         3471.1       |        3561.2      |    1590.0  |        3565.7    
                          f16 B=64, M=128, H=16, K=128   |         4969.1       |        5266.3      |    1988.4  |        5272.4    
                          f32 B=64, M=128, H=16, K=128   |         6880.3       |        7040.6      |    2121.0  |        7024.3    
                          f16 B=64, M=128, H=16, K=256   |        12635.3       |       13334.7      |    3859.5  |       13350.7    
                          f32 B=64, M=128, H=16, K=256   |        14185.7       |       14514.4      |    4553.8  |       14503.1    
                          f16 B=64, M=512, H=16, K=16    |        26278.4       |       26189.9      |   18916.0  |       26110.0    
                          f32 B=64, M=512, H=16, K=16    |        35128.2       |       37075.9      |   19191.4  |       37174.5    
                          f16 B=64, M=512, H=16, K=32    |        30414.2       |       30938.2      |   19734.7  |       31071.7    
                          f32 B=64, M=512, H=16, K=32    |        40483.2       |       42922.5      |   19843.9  |       42868.5    
                          f16 B=64, M=512, H=16, K=64    |        37248.2       |       38179.7      |   21640.1  |       38327.0    
                          f32 B=64, M=512, H=16, K=64    |        57666.8       |       57675.3      |   21909.1  |       57697.2    
                          f16 B=64, M=512, H=16, K=128   |        80113.6       |       86165.7      |   28765.5  |       86368.3    
                          f32 B=64, M=512, H=16, K=128   |       115672.5       |      115161.0      |   28910.3  |      115320.0    
                          f16 B=64, M=512, H=16, K=256   |       169250.0       |      183791.8      |   56315.7  |      183735.0    
                          f32 B=64, M=512, H=16, K=256   |       236594.6       |      233093.6      |   56853.4  |      233170.2    
                          f16 B=64, M=1024, H=16, K=16   |       106022.3       |      109588.2      |   74303.6  |      109410.4    
                          f32 B=64, M=1024, H=16, K=16   |       141241.8       |      148854.1      |            |      149651.5    
                          f16 B=64, M=1024, H=16, K=32   |       120899.1       |      125044.6      |   77828.6  |      125716.9    
                          f32 B=64, M=1024, H=16, K=32   |       162478.5       |      173906.1      |            |      173216.4    
                          f16 B=64, M=1024, H=16, K=64   |       149044.1       |      152290.6      |   85821.6  |      152748.5    
                          f32 B=64, M=1024, H=16, K=64   |       233195.9       |      231479.6      |            |      231533.9    
                          f16 B=64, M=1024, H=16, K=128  |       319761.5       |      344076.5      |  113579.4  |      345058.5    
                          f32 B=64, M=1024, H=16, K=128  |       470172.3       |      466330.7      |            |      463507.4    
                          f16 B=64, M=1024, H=16, K=256  |       658362.1       |      717070.2      |            |      723057.4    
                          f32 B=64, M=1024, H=16, K=256  |       955624.0       |      935114.3      |            |      935945.4    
  (Tesla_V100_SXM2_16GB)  f16 B=384, M=197, H=1, K=88    |         1811.3       |        1686.3      |    1375.2  |        1699.6    
                          f32 B=384, M=197, H=1, K=88    |         4315.7       |        4665.1      |    2257.1  |        4663.1    
                          f16 B=384, M=197, H=1, K=80    |         1733.0       |        1616.3      |    1281.7  |        1619.8    
                          f32 B=384, M=197, H=1, K=80    |         3965.0       |        4226.4      |    2171.7  |        4228.3    
                          f16 B=384, M=197, H=1, K=64    |         1135.2       |        1084.4      |    1043.8  |        1083.0    
                          f32 B=384, M=197, H=1, K=64    |         2673.2       |        2883.0      |    1744.5  |        2878.2    
                          f16 B=1024, M=197, H=1, K=88   |         4721.0       |        4396.6      |    3725.3  |        4404.1    
                          f32 B=1024, M=197, H=1, K=88   |        10531.3       |       11443.8      |    6106.5  |       11464.2    
                          f16 B=1024, M=197, H=1, K=80   |         4520.2       |        4216.2      |    3329.2  |        4223.7    
                          f32 B=1024, M=197, H=1, K=80   |         9573.5       |       10301.4      |    5757.6  |       10305.3    
                          f16 B=1024, M=197, H=1, K=64   |         2788.1       |        2660.1      |    2674.6  |        2663.5    
                          f32 B=1024, M=197, H=1, K=64   |         6556.7       |        7102.4      |    4516.6  |        7096.8    
                          f16 B=512, M=197, H=1, K=80    |         2377.2       |        2228.1      |    1685.0  |        2231.4    
                          f32 B=512, M=197, H=1, K=80    |         5259.3       |        5639.7      |    2887.8  |        5636.6    
                          f16 B=32, M=197, H=16, K=80    |         2403.0       |        2201.1      |    1798.2  |        2204.9    
                          f32 B=32, M=197, H=16, K=80    |         5402.3       |        5667.7      |    3046.3  |        5662.4    
                          f16 B=32, M=197, H=16, K=64    |         1552.7       |        1486.3      |    1451.6  |        1485.2    
                          f32 B=32, M=197, H=16, K=64    |         3622.5       |        3911.0      |    2427.0  |        3912.9    
                          f16 B=32, M=197, H=16, K=128   |         2776.6       |        2611.9      |    2211.3  |        2613.3    
                          f32 B=32, M=197, H=16, K=128   |         6647.9       |        7082.3      |    4088.5  |        7104.8    
                          f16 B=256, M=197, H=1, K=88    |         1357.5       |        1285.5      |     941.3  |        1287.6    
                          f32 B=256, M=197, H=1, K=88    |         2874.1       |        3085.1      |    1543.2  |        3090.1    
                          f16 B=16, M=197, H=16, K=88    |         1349.1       |        1264.5      |     964.7  |        1263.1    
                          f32 B=16, M=197, H=16, K=88    |         2803.8       |        2967.0      |    1647.4  |        2972.0    
                          f16 B=16, M=197, H=16, K=64    |          765.5       |         728.4      |     765.3  |         731.0    
                          f32 B=16, M=197, H=16, K=64    |         1834.1       |        1969.7      |    1282.4  |        1974.4    
                          f16 B=16, M=197, H=16, K=128   |         1509.1       |        1432.6      |    1139.1  |        1433.9    
                          f32 B=16, M=197, H=16, K=128   |         3406.5       |        3606.2      |    2048.6  |        3613.3    
                          f16 B=1, M=4096, H=160, K=128  |       168807.1       |      148652.9      |            |      149343.8    
                          f32 B=1, M=4096, H=160, K=128  |       549864.6       |      586699.9      |            |      585699.9    
                          f16 B=2, M=4096, H=160, K=128  |       339010.5       |      298827.7      |            |      298808.4    
                          f32 B=2, M=4096, H=160, K=128  |      1106963.8       |     1176218.3      |            |     1179173.2    
                          f16 B=1, M=8192, H=160, K=128  |       679742.4       |      594323.2      |            |      595580.1    
                          f32 B=1, M=8192, H=160, K=128  |      2195491.3       |     2340248.4      |            |     2343505.7    
                          f16 B=2, M=8192, H=160, K=128  |      1364983.7       |     1193787.8      |            |     1192596.1    
                          f16 B=1024, M=82, H=8, K=64    |         9052.8       |        8762.6      |    5804.2  |        8757.8    
                          f32 B=1024, M=82, H=8, K=64    |        14726.3       |       16270.6      |   11059.7  |       16215.9    
                          f16 B=150, M=256, H=16, K=64   |         5662.9       |        5519.4      |    7557.8  |        5524.5    
                          f32 B=150, M=256, H=16, K=64   |        16700.3       |       17612.7      |   16426.0  |       17640.4    
                          f16 B=64, M=256, H=12, K=64    |         1849.1       |        1793.1      |    2383.5  |        1798.7    
                          f32 B=64, M=256, H=12, K=64    |         5451.5       |        5766.4      |    4975.3  |        5775.8    
                          f16 B=1, M=4096, H=16, K=40    |        47263.3       |       47850.4      |    8315.6  |       47777.2    
                          f32 B=1, M=4096, H=16, K=40    |       113099.4       |      113164.5      |   19536.0  |      113930.0    
                          f16 B=1, M=16384, H=16, K=40   |       757091.8       |      770365.7      |            |      765401.3    
                          f32 B=1, M=16384, H=16, K=40   |      1806827.7       |     1816302.6      |            |     1819162.1    
                          f16 B=16, M=128, H=16, K=16    |          219.2       |         218.5      |     480.6  |         231.8    
                          f32 B=16, M=128, H=16, K=16    |          301.2       |         308.2      |     498.9  |         307.9    
                          f16 B=16, M=128, H=16, K=32    |          227.6       |         215.4      |     473.8  |         220.1    
                          f32 B=16, M=128, H=16, K=32    |          395.8       |         401.0      |     455.1  |         400.8    
                          f16 B=16, M=128, H=16, K=64    |          225.6       |         214.8      |     510.1  |         229.9    
                          f32 B=16, M=128, H=16, K=64    |          561.7       |         583.1      |     598.9  |         581.2    
                          f16 B=16, M=128, H=16, K=128   |          404.6       |         392.0      |     524.0  |         394.8    
                          f32 B=16, M=128, H=16, K=128   |         1103.0       |        1140.8      |    1015.4  |        1142.0    
                          f16 B=16, M=128, H=16, K=256   |         1045.3       |        1049.1      |     889.6  |        1047.4    
                          f32 B=16, M=128, H=16, K=256   |         2181.3       |        2270.9      |    1869.2  |        2265.7    
                          f16 B=16, M=512, H=16, K=16    |         1731.5       |        1585.7      |    1908.5  |        1586.2    
                          f32 B=16, M=512, H=16, K=16    |         4513.8       |        4696.7      |    4222.1  |        4695.6    
                          f16 B=16, M=512, H=16, K=32    |         1942.2       |        1823.4      |    2086.3  |        1809.7    
                          f32 B=16, M=512, H=16, K=32    |         5596.1       |        5819.4      |    4588.1  |        5833.6    
                          f16 B=16, M=512, H=16, K=64    |         2450.2       |        2340.8      |    2580.6  |        2353.2    
                          f32 B=16, M=512, H=16, K=64    |         7619.5       |        7853.1      |    5536.8  |        7875.1    
                          f16 B=16, M=512, H=16, K=128   |         4884.8       |        4473.5      |    3388.7  |        4487.9    
                          f32 B=16, M=512, H=16, K=128   |        15010.0       |       15557.1      |    8979.4  |       15513.1    
                          f16 B=16, M=512, H=16, K=256   |        12973.9       |       11134.7      |    5418.2  |       11106.3    
                          f32 B=16, M=512, H=16, K=256   |        29751.1       |       30979.1      |   16856.3  |       31012.8    
                          f16 B=16, M=1024, H=16, K=16   |         6802.6       |        6185.1      |    6996.0  |        6191.0    
                          f32 B=16, M=1024, H=16, K=16   |        18098.2       |       18954.5      |   16129.1  |       19166.8    
                          f16 B=16, M=1024, H=16, K=32   |         7531.9       |        7065.0      |    7436.0  |        7067.5    
                          f32 B=16, M=1024, H=16, K=32   |        21999.9       |       22901.5      |   17040.0  |       22899.4    
                          f16 B=16, M=1024, H=16, K=64   |         9312.0       |        8854.3      |    8605.6  |        8863.9    
                          f32 B=16, M=1024, H=16, K=64   |        29837.7       |       30895.1      |   20355.3  |       30878.5    
                          f16 B=16, M=1024, H=16, K=128  |        18979.0       |       16951.0      |   10561.3  |       16995.2    
                          f32 B=16, M=1024, H=16, K=128  |        58738.3       |       60861.2      |   33427.0  |       60599.8    
                          f16 B=16, M=1024, H=16, K=256  |        49681.9       |       41833.3      |   17329.5  |       41921.6    
                          f32 B=16, M=1024, H=16, K=256  |       117362.4       |      121004.8      |   60515.8  |      122046.9    
                          f16 B=64, M=128, H=16, K=16    |          432.2       |         411.1      |     642.1  |         411.3    
                          f32 B=64, M=128, H=16, K=16    |         1028.9       |        1057.9      |    1233.4  |        1056.6    
                          f16 B=64, M=128, H=16, K=32    |          522.5       |         500.7      |     813.6  |         499.7    
                          f32 B=64, M=128, H=16, K=32    |         1403.5       |        1443.4      |    1535.6  |        1443.6    
                          f16 B=64, M=128, H=16, K=64    |          750.0       |         739.8      |    1185.6  |         741.0    
                          f32 B=64, M=128, H=16, K=64    |         2013.9       |        2110.1      |    2156.8  |        2105.4    
                          f16 B=64, M=128, H=16, K=128   |         1421.2       |        1387.9      |    1915.5  |        1388.3    
                          f32 B=64, M=128, H=16, K=128   |         3946.5       |        4156.7      |    3780.1  |        4156.3    
                          f16 B=64, M=128, H=16, K=256   |         3811.5       |        3810.6      |    3448.7  |        3811.0    
                          f32 B=64, M=128, H=16, K=256   |         7983.9       |        8432.7      |    7304.0  |        8415.5    
                          f16 B=64, M=512, H=16, K=16    |         6157.4       |        5523.9      |    7461.8  |        5528.1    
                          f32 B=64, M=512, H=16, K=16    |        16118.2       |       17004.7      |   16651.7  |       16984.5    
                          f16 B=64, M=512, H=16, K=32    |         6985.1       |        6483.1      |    8278.5  |        6495.1    
                          f32 B=64, M=512, H=16, K=32    |        20471.9       |       21230.7      |   18420.5  |       21269.0    
                          f16 B=64, M=512, H=16, K=64    |         9045.8       |        8633.5      |   10337.1  |        8674.1    
                          f32 B=64, M=512, H=16, K=64    |        27675.0       |       29282.8      |   22883.7  |       29136.9    
                          f16 B=64, M=512, H=16, K=128   |        17594.2       |       15805.9      |   14700.6  |       15788.2    
                          f32 B=64, M=512, H=16, K=128   |        54612.4       |       57815.7      |   39974.3  |       57951.8    
                          f16 B=64, M=512, H=16, K=256   |        47452.6       |       40093.5      |   27087.5  |       40175.6    
                          f32 B=64, M=512, H=16, K=256   |       108880.3       |      115953.1      |   77794.2  |      115951.0    
                          f16 B=64, M=1024, H=16, K=16   |        24369.0       |       21533.0      |   28448.6  |       21556.6    
                          f32 B=64, M=1024, H=16, K=16   |        64649.7       |       68791.8      |            |       68407.4    
                          f16 B=64, M=1024, H=16, K=32   |        27143.1       |       25683.7      |   30252.7  |       25727.9    
                          f32 B=64, M=1024, H=16, K=32   |        79967.5       |       83351.5      |            |       83084.7    
                          f16 B=64, M=1024, H=16, K=64   |        34667.0       |       32592.7      |   36991.2  |       32659.8    
                          f32 B=64, M=1024, H=16, K=64   |       108282.2       |      113858.1      |            |      114286.0    
                          f16 B=64, M=1024, H=16, K=128  |        68519.5       |       59757.3      |   48834.4  |       59817.7    
                          f32 B=64, M=1024, H=16, K=128  |       215465.9       |      227335.8      |            |      227204.0    
                          f16 B=64, M=1024, H=16, K=256  |       183070.4       |      150960.1      |            |      150947.1    
                          f32 B=64, M=1024, H=16, K=256  |       425832.9       |      453717.2      |            |      453349.2    

Times are in microseconds (us).
```
</details>
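
One way to read the columns against each other is to divide the `48_accf32` time by the `56_base` time for the same row. The sketch below does that for one f16 row per GPU, copied from the table above. `relative_time` is a hypothetical helper, not part of the benchmark script.

```python
def relative_time(candidate_us: float, baseline_us: float) -> float:
    """Return candidate / baseline; values above 1.0 mean the candidate is slower."""
    if baseline_us <= 0:
        raise ValueError("baseline time must be positive")
    return candidate_us / baseline_us


# Rows copied from the table above: f16 B=64, M=1024, H=16, K=256 (times in us).
rows = {
    "Quadro_GP100": {"accf32": 658362.1, "base": 717070.2},
    "Tesla_V100_SXM2_16GB": {"accf32": 183070.4, "base": 150960.1},
}

ratios = {gpu: relative_time(t["accf32"], t["base"]) for gpu, t in rows.items()}
for gpu, ratio in ratios.items():
    print(f"{gpu}: accf32 takes {ratio:.2f}x the base time")
```

On this row the f32-accumulation build is faster than base on the GP100 and slower on the V100.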

<details>
<summary>bw A100 (f32/f16)</summary>

```
[----------------------------------------------------- attention backward (attn_bias=<class 'NoneType'>) -----------------------------------------------------]                                                       
                                     |  48_accf32_69654fdb[cutlass]  |  flash[flshatt]  |  vanilla   |  56_base_02bf6b4e[cutlass]  |  57_tmpT_b516aec4[cutlass]
1 threads: ----------------------------------------------------------------------------------------------------------------------------------------------------
      f16 B=384, M=197, H=1, K=88    |               613.4           |                  |    2264.9  |              612.8          |              578.3        
      f32 B=384, M=197, H=1, K=88    |              2335.2           |                  |    1843.0  |             2438.4          |             2425.1        
      f16 B=384, M=197, H=1, K=80    |               583.6           |                  |    1922.9  |              577.6          |              548.4        
      f32 B=384, M=197, H=1, K=80    |              2241.3           |                  |    1787.8  |             2333.2          |             2333.1        
      f16 B=384, M=197, H=1, K=64    |               405.6           |        232.5     |    1809.8  |              386.4          |              366.1        
      f32 B=384, M=197, H=1, K=64    |              1259.7           |                  |    1675.6  |             1309.6          |             1316.8        
      f16 B=1024, M=197, H=1, K=88   |              1538.2           |                  |    5964.7  |             1550.8          |             1454.4        
      f32 B=1024, M=197, H=1, K=88   |              6031.4           |                  |    4559.5  |             6325.4          |             6332.7        
      f16 B=1024, M=197, H=1, K=80   |              1458.8           |                  |    5038.2  |             1463.8          |             1379.0        
      f32 B=1024, M=197, H=1, K=80   |              5786.2           |                  |    4412.6  |             6059.0          |             6079.9        
      f16 B=1024, M=197, H=1, K=64   |               929.5           |        575.9     |    4735.6  |              862.1          |              821.6        
      f32 B=1024, M=197, H=1, K=64   |              3289.9           |                  |    4119.9  |             3434.4          |             3441.4        
      f16 B=512, M=197, H=1, K=80    |               744.5           |                  |    2544.7  |              735.3          |              695.8        
      f32 B=512, M=197, H=1, K=80    |              2889.4           |                  |    2286.1  |             3008.8          |             3029.0        
      f16 B=32, M=197, H=16, K=80    |               741.6           |                  |    2569.0  |              723.3          |              693.4        
      f32 B=32, M=197, H=16, K=80    |              2878.6           |                  |    2355.3  |             3003.0          |             3024.1        
      f16 B=32, M=197, H=16, K=64    |               478.5           |        295.7     |    2429.2  |              456.1          |              426.7        
      f32 B=32, M=197, H=16, K=64    |              1784.4           |                  |    2196.9  |             1863.7          |             1860.2        
      f16 B=32, M=197, H=16, K=128   |               887.0           |        682.3     |    4492.9  |              857.5          |              853.8        
      f32 B=32, M=197, H=16, K=128   |              3546.6           |                  |    2807.3  |             3734.9          |             3737.1        
      f16 B=256, M=197, H=1, K=88    |               445.1           |                  |    1528.9  |              445.0          |              422.2        
      f32 B=256, M=197, H=1, K=88    |              1678.8           |                  |    1207.6  |             1746.6          |             1752.8        
      f16 B=16, M=197, H=16, K=88    |               441.5           |                  |    1544.2  |              437.3          |              419.5        
      f32 B=16, M=197, H=16, K=88    |              1668.3           |                  |    1250.4  |             1742.7          |             1746.1        
      f16 B=16, M=197, H=16, K=64    |               247.4           |        165.6     |    1242.5  |              233.2          |              217.5        
      f32 B=16, M=197, H=16, K=64    |              1051.1           |                  |    1125.3  |             1099.2          |             1096.6        
      f16 B=16, M=197, H=16, K=128   |               498.4           |        386.2     |    2264.5  |              488.0          |              480.5        
      f32 B=16, M=197, H=16, K=128   |              1950.2           |                  |    1446.7  |             2039.0          |             2028.6        
      f16 B=1, M=4096, H=160, K=128  |             55915.0           |      54620.4     |   45909.5  |            63407.6          |            51298.8        
      f32 B=1, M=4096, H=160, K=128  |            238514.6           |                  |            |           232677.5          |           232672.5        
      f16 B=2, M=4096, H=160, K=128  |             93612.0           |      84238.0     |            |           100433.1          |            84858.1        
      f32 B=2, M=4096, H=160, K=128  |            375037.4           |                  |            |           364234.1          |           364407.2        
      f16 B=1, M=8192, H=160, K=128  |            223261.8           |     215499.8     |            |           251806.6          |           202133.7        
      f32 B=1, M=8192, H=160, K=128  |            946708.9           |                  |            |           924986.2          |           924988.9        
      f16 B=2, M=8192, H=160, K=128  |            367969.2           |     330092.8     |            |           395881.3          |           332000.6        
      f32 B=2, M=8192, H=160, K=128  |           1492031.4           |                  |            |          1448691.1          |          1449146.1        
      f16 B=1024, M=82, H=8, K=64    |              1890.2           |       1620.3     |    3819.7  |             1861.3          |             1764.6        
      f32 B=1024, M=82, H=8, K=64    |              8428.2           |                  |    8735.1  |             8831.3          |             8867.4        
      f16 B=150, M=256, H=16, K=64   |              2292.3           |       1625.3     |    4555.9  |             2109.8          |             2019.4        
      f32 B=150, M=256, H=16, K=64   |              6252.4           |                  |   12948.1  |             6288.9          |             6281.3        
      f16 B=64, M=256, H=12, K=64    |               782.2           |        567.4     |    1498.0  |              731.2          |              699.7        
      f32 B=64, M=256, H=12, K=64    |              2141.2           |                  |    4266.6  |             2160.6          |             2160.2        
      f16 B=1, M=4096, H=16, K=40    |             23504.2           |                  |    4196.1  |            23699.0          |            23008.9        
      f32 B=1, M=4096, H=16, K=40    |             73699.5           |                  |   17755.3  |            73261.8          |            73078.5        
      f16 B=1, M=16384, H=16, K=40   |            391408.9           |                  |            |           439777.3          |           407653.7        
      f32 B=1, M=16384, H=16, K=40   |           1196173.6           |                  |            |          1181547.9          |          1181625.1        
      f16 B=256, M=4096, H=16, K=64  |            733221.8           |     439627.5     |            |           603237.1          |           565905.8        
      f16 B=16, M=128, H=16, K=16    |               130.0           |        113.1     |     265.2  |              125.5          |              124.4        
      f32 B=16, M=128, H=16, K=16    |               161.5           |                  |     373.1  |              160.6          |              162.8        
      f16 B=16, M=128, H=16, K=32    |               125.8           |        111.5     |     263.8  |              122.1          |              125.8        
      f32 B=16, M=128, H=16, K=32    |               189.8           |                  |     412.6  |              196.3          |              196.2        
      f16 B=16, M=128, H=16, K=64    |               126.0           |        112.4     |     265.8  |              120.8          |              125.7        
      f32 B=16, M=128, H=16, K=64    |               272.1           |                  |     498.7  |              285.9          |              283.3        
      f16 B=16, M=128, H=16, K=128   |               181.3           |        158.4     |     298.5  |              178.0          |              186.5        
      f32 B=16, M=128, H=16, K=128   |               509.6           |                  |     673.9  |              521.1          |              521.5        
      f16 B=16, M=128, H=16, K=256   |               774.0           |                  |     541.4  |              775.5          |              757.0        
      f32 B=16, M=128, H=16, K=256   |               975.2           |                  |    1162.6  |              994.5          |              994.5        
      f16 B=16, M=512, H=16, K=16    |               621.0           |        322.6     |    1204.9  |              555.0          |              519.9        
      f32 B=16, M=512, H=16, K=16    |              2148.0           |                  |    4414.4  |             2178.9          |             2180.4        
      f16 B=16, M=512, H=16, K=32    |               709.9           |        435.6     |    1306.2  |              653.1          |              602.8        
      f32 B=16, M=512, H=16, K=32    |              2335.6           |                  |    4640.7  |             2336.0          |             2336.3        
      f16 B=16, M=512, H=16, K=64    |               917.8           |        702.7     |    1545.8  |              849.9          |              797.4        
      f32 B=16, M=512, H=16, K=64    |              2965.5           |                  |    5125.0  |             2986.9          |             2988.3        
      f16 B=16, M=512, H=16, K=128   |              1644.0           |       1584.1     |    1983.4  |             1757.0          |             1548.0        
      f32 B=16, M=512, H=16, K=128   |              6152.7           |                  |    6099.1  |             6067.7          |             6068.6        
      f16 B=16, M=512, H=16, K=256   |              8178.3           |                  |    2899.1  |             7895.5          |             7977.4        
      f32 B=16, M=512, H=16, K=256   |             11894.0           |                  |   10639.5  |            11635.0          |            11624.7        
      f16 B=16, M=1024, H=16, K=16   |              2420.5           |       1240.8     |    4259.2  |             2234.6          |             2048.9        
      f32 B=16, M=1024, H=16, K=16   |              8467.2           |                  |   16650.4  |             8512.2          |             8510.2        
      f16 B=16, M=1024, H=16, K=32   |              2675.7           |       1618.9     |    4491.2  |             2441.3          |             2230.0        
      f32 B=16, M=1024, H=16, K=32   |              9010.0           |                  |   17301.0  |             9012.1          |             9015.7        
      f16 B=16, M=1024, H=16, K=64   |              3328.4           |       2370.3     |    4994.5  |             3032.2          |             2820.0        
      f32 B=16, M=1024, H=16, K=64   |             11566.8           |                  |   18714.2  |            11494.1          |            11492.8        
      f16 B=16, M=1024, H=16, K=128  |              5867.9           |       5632.5     |    5952.8  |             6401.1          |             5440.8        
      f32 B=16, M=1024, H=16, K=128  |             23345.4           |                  |   21523.7  |            22859.1          |            22870.3        
      f16 B=16, M=1024, H=16, K=256  |             30619.2           |                  |    7893.1  |            29884.9          |            29060.4        
      f32 B=16, M=1024, H=16, K=256  |             45211.4           |                  |   38093.0  |            43435.3          |            43423.8        
      f16 B=64, M=128, H=16, K=16    |               159.6           |        145.2     |     439.9  |              161.2          |              167.0        
      f32 B=64, M=128, H=16, K=16    |               493.4           |                  |    1270.0  |              502.7          |              503.2        
      f16 B=64, M=128, H=16, K=32    |               208.5           |        212.1     |     545.3  |              206.2          |              204.9        
      f32 B=64, M=128, H=16, K=32    |               601.2           |                  |    1427.0  |              610.5          |              610.6        
      f16 B=64, M=128, H=16, K=64    |               329.0           |        310.8     |     766.0  |              327.2          |              314.7        
      f32 B=64, M=128, H=16, K=64    |               867.7           |                  |    1743.5  |              889.2          |              889.0        
      f16 B=64, M=128, H=16, K=128   |               635.5           |        562.1     |    1226.7  |              613.3          |              650.7        
      f32 B=64, M=128, H=16, K=128   |              1774.4           |                  |    2386.5  |             1800.0          |             1799.6        
      f16 B=64, M=128, H=16, K=256   |              2839.8           |                  |    2122.8  |             2765.2          |             2821.9        
      f32 B=64, M=128, H=16, K=256   |              3419.2           |                  |    4320.7  |             3458.8          |             3457.0        
      f16 B=64, M=512, H=16, K=16    |              2316.2           |       1202.3     |    4487.1  |             1983.9          |             1868.4        
      f32 B=64, M=512, H=16, K=16    |              6686.0           |                  |   16991.5  |             6709.9          |             6713.2        
      f16 B=64, M=512, H=16, K=32    |              2701.7           |       1541.9     |    4975.9  |             2346.4          |             2184.3        
      f32 B=64, M=512, H=16, K=32    |              7460.5           |                  |   17859.9  |             7429.6          |             7438.0        
      f16 B=64, M=512, H=16, K=64    |              3461.2           |       2418.2     |    5886.1  |             3083.1          |             2942.6        
      f32 B=64, M=512, H=16, K=64    |              9553.4           |                  |   19768.6  |             9526.5          |             9516.6        
      f16 B=64, M=512, H=16, K=128   |              5875.5           |       5443.1     |    7711.0  |             6141.3          |             5513.6        
      f32 B=64, M=512, H=16, K=128   |             21317.0           |                  |   23651.2  |            20931.9          |            20932.3        
      f16 B=64, M=512, H=16, K=256   |             31238.2           |                  |   11490.6  |            28626.8          |            28954.0        
      f32 B=64, M=512, H=16, K=256   |             41124.8           |                  |   42468.0  |            39562.8          |            39579.5        
      f16 B=64, M=1024, H=16, K=16   |              9142.8           |       4707.2     |   16882.8  |             7887.0          |             7314.5        
      f32 B=64, M=1024, H=16, K=16   |             26512.3           |                  |   66311.7  |            26497.5          |            26459.8        
      f16 B=64, M=1024, H=16, K=32   |             10420.8           |       5698.3     |   17875.0  |             8851.4          |             7974.0        
      f32 B=64, M=1024, H=16, K=32   |             28300.0           |                  |   69088.7  |            28102.1          |            28098.2        
      f16 B=64, M=1024, H=16, K=64   |             12948.5           |       8119.0     |   19944.3  |            11064.2          |            10477.6        
      f32 B=64, M=1024, H=16, K=64   |             35600.6           |                  |   74762.8  |            35316.4          |            35361.2        
      f16 B=64, M=1024, H=16, K=128  |             20820.5           |      19184.2     |   23699.3  |            21954.8          |            19220.0        
      f32 B=64, M=1024, H=16, K=128  |             80800.3           |                  |   86003.8  |            78521.9          |            78393.9        
      f16 B=64, M=1024, H=16, K=256  |            114411.1           |                  |   32958.3  |           103287.2          |           104304.2        
      f32 B=64, M=1024, H=16, K=256  |            155731.5           |                  |  153011.0  |           148071.6          |           148165.6        

Times are in microseconds (us).
```
</details>


[ghstack-poisoned]
danthe3rd pushed a commit that referenced this pull request Nov 29, 2022
ghstack-source-id: 0dee5d90397264d6fd96a9a22b4a1788e9de425c
Pull Request resolved: #467

@fmassa left a comment

Thanks for this Daniel!

Could you maybe adapt the tests to show that the backward pass now has better numerics with this change?

Also, would it make sense to split the PR out of this stack so that we can get the other PRs merged?
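
A numerics check of the kind requested above could be sketched as follows. This is a minimal NumPy illustration and not the xformers kernel or its test suite: `dot_accumulate` is a hypothetical helper, and the vector length and random inputs are arbitrary assumptions. It compares a dot product of f16 inputs accumulated in f16 against the same dot product accumulated in f32, using an f64 result as the reference.

```python
import numpy as np


def dot_accumulate(a, b, acc_dtype):
    """Dot product of two f16 vectors, keeping the running sum in acc_dtype.

    Hypothetical helper for illustration only.
    """
    acc = acc_dtype(0)
    for x, y in zip(a, b):
        # Each product is formed and added in the accumulator's precision.
        acc = acc_dtype(acc + acc_dtype(x) * acc_dtype(y))
    return float(acc)


rng = np.random.default_rng(0)
a = rng.random(4096).astype(np.float16)
b = rng.random(4096).astype(np.float16)

# f64 reference computed from the same f16 inputs.
reference = float(np.dot(a.astype(np.float64), b.astype(np.float64)))

err_f16 = abs(dot_accumulate(a, b, np.float16) - reference)
err_f32 = abs(dot_accumulate(a, b, np.float32) - reference)

print(f"abs error, f16 accumulator: {err_f16:.4f}")
print(f"abs error, f32 accumulator: {err_f32:.4f}")
```

A real test would presumably run the backward kernel itself and compare its gradients against a higher-precision reference with a tighter tolerance than before; the scalar loop here only shows why the accumulator type matters.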

**PERFORMANCE**

This makes performance worse in f16 :(
But I think we need it for stability
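
To put a number on the cost, the relative change can be computed from a few rows of the table below. This is only a sketch: it assumes the `48_accf32_6f8e2f15` column is the build with f32 accumulation and `56_base_02bf6b4e` is the baseline, and `relative_change` is a hypothetical helper. The timings are copied from the `B=384, M=197, H=1, K=88` rows.

```python
# (label, accf32 time in us, base time in us), copied from the table below.
# Assumption: "48_accf32_..." is this change and "56_base_..." is the baseline.
rows = [
    ("GP100 f16", 6289.6, 6936.0),
    ("GP100 f32", 8793.3, 9446.6),
    ("V100 f16", 1811.3, 1686.3),
    ("V100 f32", 4315.7, 4665.1),
]


def relative_change(new_us, base_us):
    """Fractional change in runtime; positive means the new build is slower."""
    return (new_us - base_us) / base_us


changes = {label: relative_change(new, base) for label, new, base in rows}

for label, change in changes.items():
    print(f"{label}: {change:+.1%}")
```

For these four rows, only the V100 f16 entry comes out positive (slower); the same calculation can be repeated for any other shape in the table.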

<details>
<summary>bw P100/V100 (f32/f16)</summary>

```
[---------------------------------------- attention backward (attn_bias=<class 'NoneType'>) ----------------------------------------]                                                    
                                                         |  48_accf32_6f8e2f15  |  56_base_02bf6b4e  |  vanilla   |  57_tmpT_b516aec4
1 threads: --------------------------------------------------------------------------------------------------------------------------
  (Quadro_GP100)          f16 B=384, M=197, H=1, K=88    |         6289.6       |        6936.0      |    2183.6  |        6940.0    
                          f32 B=384, M=197, H=1, K=88    |         8793.3       |        9446.6      |    2175.1  |        9429.2    
                          f16 B=384, M=197, H=1, K=80    |         5989.9       |        6596.2      |    2146.8  |        6608.6    
                          f32 B=384, M=197, H=1, K=80    |         8427.1       |        8993.1      |    2134.0  |        9030.5    
                          f16 B=384, M=197, H=1, K=64    |         3347.0       |        3527.4      |    1799.3  |        3538.3    
                          f32 B=384, M=197, H=1, K=64    |         5563.1       |        5984.2      |    1801.6  |        5980.9    
                          f16 B=1024, M=197, H=1, K=88   |        15680.5       |       17424.8      |    5671.9  |       17452.6    
                          f32 B=1024, M=197, H=1, K=88   |        23784.2       |       25542.8      |    5664.0  |       25578.6    
                          f16 B=1024, M=197, H=1, K=80   |        14935.5       |       16581.3      |    5559.5  |       16587.3    
                          f32 B=1024, M=197, H=1, K=80   |        22767.6       |       24354.1      |    5550.7  |       24362.0    
                          f16 B=1024, M=197, H=1, K=64   |         8331.1       |        8671.5      |    4644.9  |        8695.3    
                          f32 B=1024, M=197, H=1, K=64   |        15061.6       |       16148.8      |    4650.0  |       16201.4    
                          f16 B=512, M=197, H=1, K=80    |         7594.7       |        8306.6      |    2824.6  |        8336.5    
                          f32 B=512, M=197, H=1, K=80    |        11857.0       |       12660.2      |    2807.4  |       12709.2    
                          f16 B=32, M=197, H=16, K=80    |         7779.8       |        8342.0      |    2820.8  |        8393.5    
                          f32 B=32, M=197, H=16, K=80    |        11846.5       |       12487.8      |    2806.0  |       12549.1    
                          f16 B=32, M=197, H=16, K=64    |         4258.4       |        4445.6      |    2374.0  |        4461.2    
                          f32 B=32, M=197, H=16, K=64    |         7828.9       |        8434.9      |    2376.0  |        8472.2    
                          f16 B=32, M=197, H=16, K=128   |         9025.1       |        9671.9      |    3159.3  |        9707.2    
                          f32 B=32, M=197, H=16, K=128   |        14139.8       |       14920.7      |    3157.7  |       14928.9    
                          f16 B=256, M=197, H=1, K=88    |         4608.9       |        5118.8      |    1478.4  |        5119.2    
                          f32 B=256, M=197, H=1, K=88    |         6174.0       |        6644.3      |    1477.6  |        6642.7    
                          f16 B=16, M=197, H=16, K=88    |         4618.4       |        5073.4      |    1479.5  |        5062.9    
                          f32 B=16, M=197, H=16, K=88    |         6114.2       |        6471.7      |    1474.9  |        6478.7    
                          f16 B=16, M=197, H=16, K=64    |         2490.3       |        2557.5      |    1225.9  |        2550.0    
                          f32 B=16, M=197, H=16, K=64    |         3918.9       |        4208.2      |    1227.0  |        4195.7    
                          f16 B=16, M=197, H=16, K=128   |         5210.0       |        5649.3      |    1635.2  |        5648.1    
                          f32 B=16, M=197, H=16, K=128   |         7103.1       |        7451.6      |    1645.5  |        7445.1    
                          f16 B=1, M=4096, H=160, K=128  |      1014229.1       |     1106182.6      |            |     1108040.6    
                          f32 B=1, M=4096, H=160, K=128  |      1258173.2       |     1243183.8      |            |     1241548.7    
                          f16 B=2, M=4096, H=160, K=128  |      1642279.2       |     1753736.9      |            |     1771655.3    
                          f32 B=2, M=4096, H=160, K=128  |      2505435.4       |     2477353.1      |            |     2473773.4    
                          f16 B=1, M=8192, H=160, K=128  |      4050128.4       |     4415962.1      |            |     4428194.9    
                          f32 B=1, M=8192, H=160, K=128  |      5042352.6       |     4970582.0      |            |     4965069.9    
                          f16 B=2, M=8192, H=160, K=128  |      6600732.3       |     7026378.5      |            |     7068051.8    
                          f16 B=1024, M=82, H=8, K=64    |        21572.8       |       22531.6      |    9059.1  |       22418.4    
                          f32 B=1024, M=82, H=8, K=64    |        38178.4       |       45927.2      |    9070.3  |       45708.0    
                          f16 B=150, M=256, H=16, K=64   |        21436.5       |       21927.4      |   12938.5  |       22001.0    
                          f32 B=150, M=256, H=16, K=64   |        33024.2       |       33196.3      |   13249.2  |       33199.6    
                          f16 B=64, M=256, H=12, K=64    |         6869.8       |        7048.6      |    4200.6  |        7073.6    
                          f32 B=64, M=256, H=12, K=64    |        10719.9       |       10832.1      |    4271.3  |       10843.6    
                          f16 B=1, M=4096, H=16, K=40    |       134722.6       |      145429.8      |   20587.0  |      143743.8    
                          f32 B=1, M=4096, H=16, K=40    |       143015.1       |      147850.6      |   20625.8  |      148272.8    
                          f16 B=1, M=16384, H=16, K=40   |      2149850.4       |     2323732.4      |            |     2301489.4    
                          f32 B=1, M=16384, H=16, K=40   |      2286478.1       |     2369812.4      |            |     2375911.8    
                          f16 B=16, M=128, H=16, K=16    |          497.2       |         502.9      |     623.8  |         503.6    
                          f32 B=16, M=128, H=16, K=16    |          573.5       |         609.8      |     617.5  |         611.6    
                          f16 B=16, M=128, H=16, K=32    |          563.9       |         573.1      |     624.2  |         575.2    
                          f32 B=16, M=128, H=16, K=32    |          661.5       |         702.2      |     620.4  |         703.2    
                          f16 B=16, M=128, H=16, K=64    |          708.9       |         722.6      |     619.9  |         724.3    
                          f32 B=16, M=128, H=16, K=64    |          916.2       |         953.5      |     618.6  |         953.6    
                          f16 B=16, M=128, H=16, K=128   |         1465.8       |        1542.8      |     624.2  |        1545.7    
                          f32 B=16, M=128, H=16, K=128   |         1829.1       |        1872.3      |     616.2  |        1873.9    
                          f16 B=16, M=128, H=16, K=256   |         3796.3       |        4002.5      |    1010.7  |        4008.4    
                          f32 B=16, M=128, H=16, K=256   |         3838.8       |        3957.1      |    1203.6  |        3951.8    
                          f16 B=16, M=512, H=16, K=16    |         7680.4       |        7775.1      |    4848.7  |        7752.8    
                          f32 B=16, M=512, H=16, K=16    |         9402.0       |        9926.1      |    4926.7  |        9914.5    
                          f16 B=16, M=512, H=16, K=32    |         8897.5       |        8999.0      |    5055.9  |        9003.9    
                          f32 B=16, M=512, H=16, K=32    |        10762.1       |       11320.9      |    5065.4  |       11341.8    
                          f16 B=16, M=512, H=16, K=64    |        10936.9       |       11214.8      |    5484.9  |       11254.4    
                          f32 B=16, M=512, H=16, K=64    |        15091.9       |       15210.3      |    5552.0  |       15196.3    
                          f16 B=16, M=512, H=16, K=128   |        23491.0       |       25234.8      |    7317.1  |       25300.3    
                          f32 B=16, M=512, H=16, K=128   |        30524.6       |       30320.0      |    7487.0  |       30200.6    
                          f16 B=16, M=512, H=16, K=256   |        50389.7       |       54525.2      |   14015.3  |       54931.1    
                          f32 B=16, M=512, H=16, K=256   |        62155.6       |       61046.4      |   14285.5  |       61081.1    
                          f16 B=16, M=1024, H=16, K=16   |        31289.9       |       31951.9      |   18778.4  |       32044.9    
                          f32 B=16, M=1024, H=16, K=16   |        37744.4       |       39586.6      |   18929.9  |       39739.9    
                          f16 B=16, M=1024, H=16, K=32   |        35770.8       |       36651.1      |   19620.2  |       36909.0    
                          f32 B=16, M=1024, H=16, K=32   |        43211.5       |       45544.5      |   19518.4  |       45664.8    
                          f16 B=16, M=1024, H=16, K=64   |        43865.1       |       44864.8      |   21286.0  |       45345.7    
                          f32 B=16, M=1024, H=16, K=64   |        60710.5       |       60966.9      |   21634.1  |       60921.9    
                          f16 B=16, M=1024, H=16, K=128  |        94633.7       |      101796.0      |   28502.1  |      102871.6    
                          f32 B=16, M=1024, H=16, K=128  |       124093.6       |      122196.9      |   28520.3  |      122043.0    
                          f16 B=16, M=1024, H=16, K=256  |       194780.7       |      212303.0      |   55419.1  |      214643.5    
                          f32 B=16, M=1024, H=16, K=256  |       250799.1       |      245196.2      |   55634.0  |      245534.2    
                          f16 B=64, M=128, H=16, K=16    |         1658.0       |        1661.4      |    1331.0  |        1662.7    
                          f32 B=64, M=128, H=16, K=16    |         2129.2       |        2266.2      |    1371.8  |        2269.4    
                          f16 B=64, M=128, H=16, K=32    |         1904.3       |        1903.5      |    1384.2  |        1906.4    
                          f32 B=64, M=128, H=16, K=32    |         2496.5       |        2643.4      |    1445.1  |        2639.8    
                          f16 B=64, M=128, H=16, K=64    |         2393.1       |        2432.3      |    1505.2  |        2437.8    
                          f32 B=64, M=128, H=16, K=64    |         3471.1       |        3561.2      |    1590.0  |        3565.7    
                          f16 B=64, M=128, H=16, K=128   |         4969.1       |        5266.3      |    1988.4  |        5272.4    
                          f32 B=64, M=128, H=16, K=128   |         6880.3       |        7040.6      |    2121.0  |        7024.3    
                          f16 B=64, M=128, H=16, K=256   |        12635.3       |       13334.7      |    3859.5  |       13350.7    
                          f32 B=64, M=128, H=16, K=256   |        14185.7       |       14514.4      |    4553.8  |       14503.1    
                          f16 B=64, M=512, H=16, K=16    |        26278.4       |       26189.9      |   18916.0  |       26110.0    
                          f32 B=64, M=512, H=16, K=16    |        35128.2       |       37075.9      |   19191.4  |       37174.5    
                          f16 B=64, M=512, H=16, K=32    |        30414.2       |       30938.2      |   19734.7  |       31071.7    
                          f32 B=64, M=512, H=16, K=32    |        40483.2       |       42922.5      |   19843.9  |       42868.5    
                          f16 B=64, M=512, H=16, K=64    |        37248.2       |       38179.7      |   21640.1  |       38327.0    
                          f32 B=64, M=512, H=16, K=64    |        57666.8       |       57675.3      |   21909.1  |       57697.2    
                          f16 B=64, M=512, H=16, K=128   |        80113.6       |       86165.7      |   28765.5  |       86368.3    
                          f32 B=64, M=512, H=16, K=128   |       115672.5       |      115161.0      |   28910.3  |      115320.0    
                          f16 B=64, M=512, H=16, K=256   |       169250.0       |      183791.8      |   56315.7  |      183735.0    
                          f32 B=64, M=512, H=16, K=256   |       236594.6       |      233093.6      |   56853.4  |      233170.2    
                          f16 B=64, M=1024, H=16, K=16   |       106022.3       |      109588.2      |   74303.6  |      109410.4    
                          f32 B=64, M=1024, H=16, K=16   |       141241.8       |      148854.1      |            |      149651.5    
                          f16 B=64, M=1024, H=16, K=32   |       120899.1       |      125044.6      |   77828.6  |      125716.9    
                          f32 B=64, M=1024, H=16, K=32   |       162478.5       |      173906.1      |            |      173216.4    
                          f16 B=64, M=1024, H=16, K=64   |       149044.1       |      152290.6      |   85821.6  |      152748.5    
                          f32 B=64, M=1024, H=16, K=64   |       233195.9       |      231479.6      |            |      231533.9    
                          f16 B=64, M=1024, H=16, K=128  |       319761.5       |      344076.5      |  113579.4  |      345058.5    
                          f32 B=64, M=1024, H=16, K=128  |       470172.3       |      466330.7      |            |      463507.4    
                          f16 B=64, M=1024, H=16, K=256  |       658362.1       |      717070.2      |            |      723057.4    
                          f32 B=64, M=1024, H=16, K=256  |       955624.0       |      935114.3      |            |      935945.4    
  (Tesla_V100_SXM2_16GB)  f16 B=384, M=197, H=1, K=88    |         1811.3       |        1686.3      |    1375.2  |        1699.6    
                          f32 B=384, M=197, H=1, K=88    |         4315.7       |        4665.1      |    2257.1  |        4663.1    
                          f16 B=384, M=197, H=1, K=80    |         1733.0       |        1616.3      |    1281.7  |        1619.8    
                          f32 B=384, M=197, H=1, K=80    |         3965.0       |        4226.4      |    2171.7  |        4228.3    
                          f16 B=384, M=197, H=1, K=64    |         1135.2       |        1084.4      |    1043.8  |        1083.0    
                          f32 B=384, M=197, H=1, K=64    |         2673.2       |        2883.0      |    1744.5  |        2878.2    
                          f16 B=1024, M=197, H=1, K=88   |         4721.0       |        4396.6      |    3725.3  |        4404.1    
                          f32 B=1024, M=197, H=1, K=88   |        10531.3       |       11443.8      |    6106.5  |       11464.2    
                          f16 B=1024, M=197, H=1, K=80   |         4520.2       |        4216.2      |    3329.2  |        4223.7    
                          f32 B=1024, M=197, H=1, K=80   |         9573.5       |       10301.4      |    5757.6  |       10305.3    
                          f16 B=1024, M=197, H=1, K=64   |         2788.1       |        2660.1      |    2674.6  |        2663.5    
                          f32 B=1024, M=197, H=1, K=64   |         6556.7       |        7102.4      |    4516.6  |        7096.8    
                          f16 B=512, M=197, H=1, K=80    |         2377.2       |        2228.1      |    1685.0  |        2231.4    
                          f32 B=512, M=197, H=1, K=80    |         5259.3       |        5639.7      |    2887.8  |        5636.6    
                          f16 B=32, M=197, H=16, K=80    |         2403.0       |        2201.1      |    1798.2  |        2204.9    
                          f32 B=32, M=197, H=16, K=80    |         5402.3       |        5667.7      |    3046.3  |        5662.4    
                          f16 B=32, M=197, H=16, K=64    |         1552.7       |        1486.3      |    1451.6  |        1485.2    
                          f32 B=32, M=197, H=16, K=64    |         3622.5       |        3911.0      |    2427.0  |        3912.9    
                          f16 B=32, M=197, H=16, K=128   |         2776.6       |        2611.9      |    2211.3  |        2613.3    
                          f32 B=32, M=197, H=16, K=128   |         6647.9       |        7082.3      |    4088.5  |        7104.8    
                          f16 B=256, M=197, H=1, K=88    |         1357.5       |        1285.5      |     941.3  |        1287.6    
                          f32 B=256, M=197, H=1, K=88    |         2874.1       |        3085.1      |    1543.2  |        3090.1    
                          f16 B=16, M=197, H=16, K=88    |         1349.1       |        1264.5      |     964.7  |        1263.1    
                          f32 B=16, M=197, H=16, K=88    |         2803.8       |        2967.0      |    1647.4  |        2972.0    
                          f16 B=16, M=197, H=16, K=64    |          765.5       |         728.4      |     765.3  |         731.0    
                          f32 B=16, M=197, H=16, K=64    |         1834.1       |        1969.7      |    1282.4  |        1974.4    
                          f16 B=16, M=197, H=16, K=128   |         1509.1       |        1432.6      |    1139.1  |        1433.9    
                          f32 B=16, M=197, H=16, K=128   |         3406.5       |        3606.2      |    2048.6  |        3613.3    
                          f16 B=1, M=4096, H=160, K=128  |       168807.1       |      148652.9      |            |      149343.8    
                          f32 B=1, M=4096, H=160, K=128  |       549864.6       |      586699.9      |            |      585699.9    
                          f16 B=2, M=4096, H=160, K=128  |       339010.5       |      298827.7      |            |      298808.4    
                          f32 B=2, M=4096, H=160, K=128  |      1106963.8       |     1176218.3      |            |     1179173.2    
                          f16 B=1, M=8192, H=160, K=128  |       679742.4       |      594323.2      |            |      595580.1    
                          f32 B=1, M=8192, H=160, K=128  |      2195491.3       |     2340248.4      |            |     2343505.7    
                          f16 B=2, M=8192, H=160, K=128  |      1364983.7       |     1193787.8      |            |     1192596.1    
                          f16 B=1024, M=82, H=8, K=64    |         9052.8       |        8762.6      |    5804.2  |        8757.8    
                          f32 B=1024, M=82, H=8, K=64    |        14726.3       |       16270.6      |   11059.7  |       16215.9    
                          f16 B=150, M=256, H=16, K=64   |         5662.9       |        5519.4      |    7557.8  |        5524.5    
                          f32 B=150, M=256, H=16, K=64   |        16700.3       |       17612.7      |   16426.0  |       17640.4    
                          f16 B=64, M=256, H=12, K=64    |         1849.1       |        1793.1      |    2383.5  |        1798.7    
                          f32 B=64, M=256, H=12, K=64    |         5451.5       |        5766.4      |    4975.3  |        5775.8    
                          f16 B=1, M=4096, H=16, K=40    |        47263.3       |       47850.4      |    8315.6  |       47777.2    
                          f32 B=1, M=4096, H=16, K=40    |       113099.4       |      113164.5      |   19536.0  |      113930.0    
                          f16 B=1, M=16384, H=16, K=40   |       757091.8       |      770365.7      |            |      765401.3    
                          f32 B=1, M=16384, H=16, K=40   |      1806827.7       |     1816302.6      |            |     1819162.1    
                          f16 B=16, M=128, H=16, K=16    |          219.2       |         218.5      |     480.6  |         231.8    
                          f32 B=16, M=128, H=16, K=16    |          301.2       |         308.2      |     498.9  |         307.9    
                          f16 B=16, M=128, H=16, K=32    |          227.6       |         215.4      |     473.8  |         220.1    
                          f32 B=16, M=128, H=16, K=32    |          395.8       |         401.0      |     455.1  |         400.8    
                          f16 B=16, M=128, H=16, K=64    |          225.6       |         214.8      |     510.1  |         229.9    
                          f32 B=16, M=128, H=16, K=64    |          561.7       |         583.1      |     598.9  |         581.2    
                          f16 B=16, M=128, H=16, K=128   |          404.6       |         392.0      |     524.0  |         394.8    
                          f32 B=16, M=128, H=16, K=128   |         1103.0       |        1140.8      |    1015.4  |        1142.0    
                          f16 B=16, M=128, H=16, K=256   |         1045.3       |        1049.1      |     889.6  |        1047.4    
                          f32 B=16, M=128, H=16, K=256   |         2181.3       |        2270.9      |    1869.2  |        2265.7    
                          f16 B=16, M=512, H=16, K=16    |         1731.5       |        1585.7      |    1908.5  |        1586.2    
                          f32 B=16, M=512, H=16, K=16    |         4513.8       |        4696.7      |    4222.1  |        4695.6    
                          f16 B=16, M=512, H=16, K=32    |         1942.2       |        1823.4      |    2086.3  |        1809.7    
                          f32 B=16, M=512, H=16, K=32    |         5596.1       |        5819.4      |    4588.1  |        5833.6    
                          f16 B=16, M=512, H=16, K=64    |         2450.2       |        2340.8      |    2580.6  |        2353.2    
                          f32 B=16, M=512, H=16, K=64    |         7619.5       |        7853.1      |    5536.8  |        7875.1    
                          f16 B=16, M=512, H=16, K=128   |         4884.8       |        4473.5      |    3388.7  |        4487.9    
                          f32 B=16, M=512, H=16, K=128   |        15010.0       |       15557.1      |    8979.4  |       15513.1    
                          f16 B=16, M=512, H=16, K=256   |        12973.9       |       11134.7      |    5418.2  |       11106.3    
                          f32 B=16, M=512, H=16, K=256   |        29751.1       |       30979.1      |   16856.3  |       31012.8    
                          f16 B=16, M=1024, H=16, K=16   |         6802.6       |        6185.1      |    6996.0  |        6191.0    
                          f32 B=16, M=1024, H=16, K=16   |        18098.2       |       18954.5      |   16129.1  |       19166.8    
                          f16 B=16, M=1024, H=16, K=32   |         7531.9       |        7065.0      |    7436.0  |        7067.5    
                          f32 B=16, M=1024, H=16, K=32   |        21999.9       |       22901.5      |   17040.0  |       22899.4    
                          f16 B=16, M=1024, H=16, K=64   |         9312.0       |        8854.3      |    8605.6  |        8863.9    
                          f32 B=16, M=1024, H=16, K=64   |        29837.7       |       30895.1      |   20355.3  |       30878.5    
                          f16 B=16, M=1024, H=16, K=128  |        18979.0       |       16951.0      |   10561.3  |       16995.2    
                          f32 B=16, M=1024, H=16, K=128  |        58738.3       |       60861.2      |   33427.0  |       60599.8    
                          f16 B=16, M=1024, H=16, K=256  |        49681.9       |       41833.3      |   17329.5  |       41921.6    
                          f32 B=16, M=1024, H=16, K=256  |       117362.4       |      121004.8      |   60515.8  |      122046.9    
                          f16 B=64, M=128, H=16, K=16    |          432.2       |         411.1      |     642.1  |         411.3    
                          f32 B=64, M=128, H=16, K=16    |         1028.9       |        1057.9      |    1233.4  |        1056.6    
                          f16 B=64, M=128, H=16, K=32    |          522.5       |         500.7      |     813.6  |         499.7    
                          f32 B=64, M=128, H=16, K=32    |         1403.5       |        1443.4      |    1535.6  |        1443.6    
                          f16 B=64, M=128, H=16, K=64    |          750.0       |         739.8      |    1185.6  |         741.0    
                          f32 B=64, M=128, H=16, K=64    |         2013.9       |        2110.1      |    2156.8  |        2105.4    
                          f16 B=64, M=128, H=16, K=128   |         1421.2       |        1387.9      |    1915.5  |        1388.3    
                          f32 B=64, M=128, H=16, K=128   |         3946.5       |        4156.7      |    3780.1  |        4156.3    
                          f16 B=64, M=128, H=16, K=256   |         3811.5       |        3810.6      |    3448.7  |        3811.0    
                          f32 B=64, M=128, H=16, K=256   |         7983.9       |        8432.7      |    7304.0  |        8415.5    
                          f16 B=64, M=512, H=16, K=16    |         6157.4       |        5523.9      |    7461.8  |        5528.1    
                          f32 B=64, M=512, H=16, K=16    |        16118.2       |       17004.7      |   16651.7  |       16984.5    
                          f16 B=64, M=512, H=16, K=32    |         6985.1       |        6483.1      |    8278.5  |        6495.1    
                          f32 B=64, M=512, H=16, K=32    |        20471.9       |       21230.7      |   18420.5  |       21269.0    
                          f16 B=64, M=512, H=16, K=64    |         9045.8       |        8633.5      |   10337.1  |        8674.1    
                          f32 B=64, M=512, H=16, K=64    |        27675.0       |       29282.8      |   22883.7  |       29136.9    
                          f16 B=64, M=512, H=16, K=128   |        17594.2       |       15805.9      |   14700.6  |       15788.2    
                          f32 B=64, M=512, H=16, K=128   |        54612.4       |       57815.7      |   39974.3  |       57951.8    
                          f16 B=64, M=512, H=16, K=256   |        47452.6       |       40093.5      |   27087.5  |       40175.6    
                          f32 B=64, M=512, H=16, K=256   |       108880.3       |      115953.1      |   77794.2  |      115951.0    
                          f16 B=64, M=1024, H=16, K=16   |        24369.0       |       21533.0      |   28448.6  |       21556.6    
                          f32 B=64, M=1024, H=16, K=16   |        64649.7       |       68791.8      |            |       68407.4    
                          f16 B=64, M=1024, H=16, K=32   |        27143.1       |       25683.7      |   30252.7  |       25727.9    
                          f32 B=64, M=1024, H=16, K=32   |        79967.5       |       83351.5      |            |       83084.7    
                          f16 B=64, M=1024, H=16, K=64   |        34667.0       |       32592.7      |   36991.2  |       32659.8    
                          f32 B=64, M=1024, H=16, K=64   |       108282.2       |      113858.1      |            |      114286.0    
                          f16 B=64, M=1024, H=16, K=128  |        68519.5       |       59757.3      |   48834.4  |       59817.7    
                          f32 B=64, M=1024, H=16, K=128  |       215465.9       |      227335.8      |            |      227204.0    
                          f16 B=64, M=1024, H=16, K=256  |       183070.4       |      150960.1      |            |      150947.1    
                          f32 B=64, M=1024, H=16, K=256  |       425832.9       |      453717.2      |            |      453349.2    

Times are in microseconds (us).
```
</details>

<details>
<summary>bw A100 (f32/f16)</summary>

```
[----------------------------------------------------- attention backward (attn_bias=<class 'NoneType'>) -----------------------------------------------------]                                                       
                                     |  48_accf32_69654fdb[cutlass]  |  flash[flshatt]  |  vanilla   |  56_base_02bf6b4e[cutlass]  |  57_tmpT_b516aec4[cutlass]
1 threads: ----------------------------------------------------------------------------------------------------------------------------------------------------
      f16 B=384, M=197, H=1, K=88    |               613.4           |                  |    2264.9  |              612.8          |              578.3        
      f32 B=384, M=197, H=1, K=88    |              2335.2           |                  |    1843.0  |             2438.4          |             2425.1        
      f16 B=384, M=197, H=1, K=80    |               583.6           |                  |    1922.9  |              577.6          |              548.4        
      f32 B=384, M=197, H=1, K=80    |              2241.3           |                  |    1787.8  |             2333.2          |             2333.1        
      f16 B=384, M=197, H=1, K=64    |               405.6           |        232.5     |    1809.8  |              386.4          |              366.1        
      f32 B=384, M=197, H=1, K=64    |              1259.7           |                  |    1675.6  |             1309.6          |             1316.8        
      f16 B=1024, M=197, H=1, K=88   |              1538.2           |                  |    5964.7  |             1550.8          |             1454.4        
      f32 B=1024, M=197, H=1, K=88   |              6031.4           |                  |    4559.5  |             6325.4          |             6332.7        
      f16 B=1024, M=197, H=1, K=80   |              1458.8           |                  |    5038.2  |             1463.8          |             1379.0        
      f32 B=1024, M=197, H=1, K=80   |              5786.2           |                  |    4412.6  |             6059.0          |             6079.9        
      f16 B=1024, M=197, H=1, K=64   |               929.5           |        575.9     |    4735.6  |              862.1          |              821.6        
      f32 B=1024, M=197, H=1, K=64   |              3289.9           |                  |    4119.9  |             3434.4          |             3441.4        
      f16 B=512, M=197, H=1, K=80    |               744.5           |                  |    2544.7  |              735.3          |              695.8        
      f32 B=512, M=197, H=1, K=80    |              2889.4           |                  |    2286.1  |             3008.8          |             3029.0        
      f16 B=32, M=197, H=16, K=80    |               741.6           |                  |    2569.0  |              723.3          |              693.4        
      f32 B=32, M=197, H=16, K=80    |              2878.6           |                  |    2355.3  |             3003.0          |             3024.1        
      f16 B=32, M=197, H=16, K=64    |               478.5           |        295.7     |    2429.2  |              456.1          |              426.7        
      f32 B=32, M=197, H=16, K=64    |              1784.4           |                  |    2196.9  |             1863.7          |             1860.2        
      f16 B=32, M=197, H=16, K=128   |               887.0           |        682.3     |    4492.9  |              857.5          |              853.8        
      f32 B=32, M=197, H=16, K=128   |              3546.6           |                  |    2807.3  |             3734.9          |             3737.1        
      f16 B=256, M=197, H=1, K=88    |               445.1           |                  |    1528.9  |              445.0          |              422.2        
      f32 B=256, M=197, H=1, K=88    |              1678.8           |                  |    1207.6  |             1746.6          |             1752.8        
      f16 B=16, M=197, H=16, K=88    |               441.5           |                  |    1544.2  |              437.3          |              419.5        
      f32 B=16, M=197, H=16, K=88    |              1668.3           |                  |    1250.4  |             1742.7          |             1746.1        
      f16 B=16, M=197, H=16, K=64    |               247.4           |        165.6     |    1242.5  |              233.2          |              217.5        
      f32 B=16, M=197, H=16, K=64    |              1051.1           |                  |    1125.3  |             1099.2          |             1096.6        
      f16 B=16, M=197, H=16, K=128   |               498.4           |        386.2     |    2264.5  |              488.0          |              480.5        
      f32 B=16, M=197, H=16, K=128   |              1950.2           |                  |    1446.7  |             2039.0          |             2028.6        
      f16 B=1, M=4096, H=160, K=128  |             55915.0           |      54620.4     |   45909.5  |            63407.6          |            51298.8        
      f32 B=1, M=4096, H=160, K=128  |            238514.6           |                  |            |           232677.5          |           232672.5        
      f16 B=2, M=4096, H=160, K=128  |             93612.0           |      84238.0     |            |           100433.1          |            84858.1        
      f32 B=2, M=4096, H=160, K=128  |            375037.4           |                  |            |           364234.1          |           364407.2        
      f16 B=1, M=8192, H=160, K=128  |            223261.8           |     215499.8     |            |           251806.6          |           202133.7        
      f32 B=1, M=8192, H=160, K=128  |            946708.9           |                  |            |           924986.2          |           924988.9        
      f16 B=2, M=8192, H=160, K=128  |            367969.2           |     330092.8     |            |           395881.3          |           332000.6        
      f32 B=2, M=8192, H=160, K=128  |           1492031.4           |                  |            |          1448691.1          |          1449146.1        
      f16 B=1024, M=82, H=8, K=64    |              1890.2           |       1620.3     |    3819.7  |             1861.3          |             1764.6        
      f32 B=1024, M=82, H=8, K=64    |              8428.2           |                  |    8735.1  |             8831.3          |             8867.4        
      f16 B=150, M=256, H=16, K=64   |              2292.3           |       1625.3     |    4555.9  |             2109.8          |             2019.4        
      f32 B=150, M=256, H=16, K=64   |              6252.4           |                  |   12948.1  |             6288.9          |             6281.3        
      f16 B=64, M=256, H=12, K=64    |               782.2           |        567.4     |    1498.0  |              731.2          |              699.7        
      f32 B=64, M=256, H=12, K=64    |              2141.2           |                  |    4266.6  |             2160.6          |             2160.2        
      f16 B=1, M=4096, H=16, K=40    |             23504.2           |                  |    4196.1  |            23699.0          |            23008.9        
      f32 B=1, M=4096, H=16, K=40    |             73699.5           |                  |   17755.3  |            73261.8          |            73078.5        
      f16 B=1, M=16384, H=16, K=40   |            391408.9           |                  |            |           439777.3          |           407653.7        
      f32 B=1, M=16384, H=16, K=40   |           1196173.6           |                  |            |          1181547.9          |          1181625.1        
      f16 B=256, M=4096, H=16, K=64  |            733221.8           |     439627.5     |            |           603237.1          |           565905.8        
      f16 B=16, M=128, H=16, K=16    |               130.0           |        113.1     |     265.2  |              125.5          |              124.4        
      f32 B=16, M=128, H=16, K=16    |               161.5           |                  |     373.1  |              160.6          |              162.8        
      f16 B=16, M=128, H=16, K=32    |               125.8           |        111.5     |     263.8  |              122.1          |              125.8        
      f32 B=16, M=128, H=16, K=32    |               189.8           |                  |     412.6  |              196.3          |              196.2        
      f16 B=16, M=128, H=16, K=64    |               126.0           |        112.4     |     265.8  |              120.8          |              125.7        
      f32 B=16, M=128, H=16, K=64    |               272.1           |                  |     498.7  |              285.9          |              283.3        
      f16 B=16, M=128, H=16, K=128   |               181.3           |        158.4     |     298.5  |              178.0          |              186.5        
      f32 B=16, M=128, H=16, K=128   |               509.6           |                  |     673.9  |              521.1          |              521.5        
      f16 B=16, M=128, H=16, K=256   |               774.0           |                  |     541.4  |              775.5          |              757.0        
      f32 B=16, M=128, H=16, K=256   |               975.2           |                  |    1162.6  |              994.5          |              994.5        
      f16 B=16, M=512, H=16, K=16    |               621.0           |        322.6     |    1204.9  |              555.0          |              519.9        
      f32 B=16, M=512, H=16, K=16    |              2148.0           |                  |    4414.4  |             2178.9          |             2180.4        
      f16 B=16, M=512, H=16, K=32    |               709.9           |        435.6     |    1306.2  |              653.1          |              602.8        
      f32 B=16, M=512, H=16, K=32    |              2335.6           |                  |    4640.7  |             2336.0          |             2336.3        
      f16 B=16, M=512, H=16, K=64    |               917.8           |        702.7     |    1545.8  |              849.9          |              797.4        
      f32 B=16, M=512, H=16, K=64    |              2965.5           |                  |    5125.0  |             2986.9          |             2988.3        
      f16 B=16, M=512, H=16, K=128   |              1644.0           |       1584.1     |    1983.4  |             1757.0          |             1548.0        
      f32 B=16, M=512, H=16, K=128   |              6152.7           |                  |    6099.1  |             6067.7          |             6068.6        
      f16 B=16, M=512, H=16, K=256   |              8178.3           |                  |    2899.1  |             7895.5          |             7977.4        
      f32 B=16, M=512, H=16, K=256   |             11894.0           |                  |   10639.5  |            11635.0          |            11624.7        
      f16 B=16, M=1024, H=16, K=16   |              2420.5           |       1240.8     |    4259.2  |             2234.6          |             2048.9        
      f32 B=16, M=1024, H=16, K=16   |              8467.2           |                  |   16650.4  |             8512.2          |             8510.2        
      f16 B=16, M=1024, H=16, K=32   |              2675.7           |       1618.9     |    4491.2  |             2441.3          |             2230.0        
      f32 B=16, M=1024, H=16, K=32   |              9010.0           |                  |   17301.0  |             9012.1          |             9015.7        
      f16 B=16, M=1024, H=16, K=64   |              3328.4           |       2370.3     |    4994.5  |             3032.2          |             2820.0        
      f32 B=16, M=1024, H=16, K=64   |             11566.8           |                  |   18714.2  |            11494.1          |            11492.8        
      f16 B=16, M=1024, H=16, K=128  |              5867.9           |       5632.5     |    5952.8  |             6401.1          |             5440.8        
      f32 B=16, M=1024, H=16, K=128  |             23345.4           |                  |   21523.7  |            22859.1          |            22870.3        
      f16 B=16, M=1024, H=16, K=256  |             30619.2           |                  |    7893.1  |            29884.9          |            29060.4        
      f32 B=16, M=1024, H=16, K=256  |             45211.4           |                  |   38093.0  |            43435.3          |            43423.8        
      f16 B=64, M=128, H=16, K=16    |               159.6           |        145.2     |     439.9  |              161.2          |              167.0        
      f32 B=64, M=128, H=16, K=16    |               493.4           |                  |    1270.0  |              502.7          |              503.2        
      f16 B=64, M=128, H=16, K=32    |               208.5           |        212.1     |     545.3  |              206.2          |              204.9        
      f32 B=64, M=128, H=16, K=32    |               601.2           |                  |    1427.0  |              610.5          |              610.6        
      f16 B=64, M=128, H=16, K=64    |               329.0           |        310.8     |     766.0  |              327.2          |              314.7        
      f32 B=64, M=128, H=16, K=64    |               867.7           |                  |    1743.5  |              889.2          |              889.0        
      f16 B=64, M=128, H=16, K=128   |               635.5           |        562.1     |    1226.7  |              613.3          |              650.7        
      f32 B=64, M=128, H=16, K=128   |              1774.4           |                  |    2386.5  |             1800.0          |             1799.6        
      f16 B=64, M=128, H=16, K=256   |              2839.8           |                  |    2122.8  |             2765.2          |             2821.9        
      f32 B=64, M=128, H=16, K=256   |              3419.2           |                  |    4320.7  |             3458.8          |             3457.0        
      f16 B=64, M=512, H=16, K=16    |              2316.2           |       1202.3     |    4487.1  |             1983.9          |             1868.4        
      f32 B=64, M=512, H=16, K=16    |              6686.0           |                  |   16991.5  |             6709.9          |             6713.2        
      f16 B=64, M=512, H=16, K=32    |              2701.7           |       1541.9     |    4975.9  |             2346.4          |             2184.3        
      f32 B=64, M=512, H=16, K=32    |              7460.5           |                  |   17859.9  |             7429.6          |             7438.0        
      f16 B=64, M=512, H=16, K=64    |              3461.2           |       2418.2     |    5886.1  |             3083.1          |             2942.6        
      f32 B=64, M=512, H=16, K=64    |              9553.4           |                  |   19768.6  |             9526.5          |             9516.6        
      f16 B=64, M=512, H=16, K=128   |              5875.5           |       5443.1     |    7711.0  |             6141.3          |             5513.6        
      f32 B=64, M=512, H=16, K=128   |             21317.0           |                  |   23651.2  |            20931.9          |            20932.3        
      f16 B=64, M=512, H=16, K=256   |             31238.2           |                  |   11490.6  |            28626.8          |            28954.0        
      f32 B=64, M=512, H=16, K=256   |             41124.8           |                  |   42468.0  |            39562.8          |            39579.5        
      f16 B=64, M=1024, H=16, K=16   |              9142.8           |       4707.2     |   16882.8  |             7887.0          |             7314.5        
      f32 B=64, M=1024, H=16, K=16   |             26512.3           |                  |   66311.7  |            26497.5          |            26459.8        
      f16 B=64, M=1024, H=16, K=32   |             10420.8           |       5698.3     |   17875.0  |             8851.4          |             7974.0        
      f32 B=64, M=1024, H=16, K=32   |             28300.0           |                  |   69088.7  |            28102.1          |            28098.2        
      f16 B=64, M=1024, H=16, K=64   |             12948.5           |       8119.0     |   19944.3  |            11064.2          |            10477.6        
      f32 B=64, M=1024, H=16, K=64   |             35600.6           |                  |   74762.8  |            35316.4          |            35361.2        
      f16 B=64, M=1024, H=16, K=128  |             20820.5           |      19184.2     |   23699.3  |            21954.8          |            19220.0        
      f32 B=64, M=1024, H=16, K=128  |             80800.3           |                  |   86003.8  |            78521.9          |            78393.9        
      f16 B=64, M=1024, H=16, K=256  |            114411.1           |                  |   32958.3  |           103287.2          |           104304.2        
      f32 B=64, M=1024, H=16, K=256  |            155731.5           |                  |  153011.0  |           148071.6          |           148165.6        

Times are in microseconds (us).
```
</details>


[ghstack-poisoned]
danthe3rd pushed a commit that referenced this pull request Nov 29, 2022
ghstack-source-id: 4759ee88b471d7e10e40c72d9b9d4a441c82aeaf
Pull Request resolved: #467
**PERFORMANCE**

This makes performance worse in f16 for many shapes :(
But I think we need f32 accumulation for numerical stability
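
To illustrate the stability concern, here is a minimal pure-Python sketch. It is not the kernel code from this PR. It emulates a dot product whose running sum is rounded to f16 or f32 after every add, using `struct` to do the rounding. The helper names (`round_to`, `dot`) and the input values are hypothetical and chosen only for the demonstration.

```python
import struct


def round_to(fmt: str, x: float) -> float:
    """Round a Python float to the nearest value representable in `fmt`
    ("e" = IEEE half precision, "f" = IEEE single precision)."""
    return struct.unpack(fmt, struct.pack(fmt, x))[0]


def dot(a, b, acc_fmt: str) -> float:
    """Dot product whose running sum is rounded to `acc_fmt` after every add."""
    acc = 0.0
    for x, y in zip(a, b):
        acc = round_to(acc_fmt, acc + x * y)
    return acc


# Hypothetical inputs: both operands are already f16 values, as in an f16 kernel.
n = 4096
a = [round_to("e", 0.01)] * n
b = [1.0] * n

reference = sum(x * y for x, y in zip(a, b))  # float64 reference
err_f16 = abs(dot(a, b, "e") - reference)
err_f32 = abs(dot(a, b, "f") - reference)
print(f"f16 accumulator error: {err_f16:.4f}")
print(f"f32 accumulator error: {err_f32:.6f}")
```

With these inputs the f16 running sum stops growing once each increment falls below half the spacing between neighbouring f16 values, so its error is orders of magnitude larger than with an f32 accumulator. Real attention gradients are less adversarial than this, but long reductions over `M` hit the same effect.
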

<details>
<summary>bw P100/V100 (f32/f16)</summary>

```
[---------------------------------------- attention backward (attn_bias=<class 'NoneType'>) ----------------------------------------]                                                    
                                                         |  48_accf32_6f8e2f15  |  56_base_02bf6b4e  |  vanilla   |  57_tmpT_b516aec4
1 threads: --------------------------------------------------------------------------------------------------------------------------
  (Quadro_GP100)          f16 B=384, M=197, H=1, K=88    |         6289.6       |        6936.0      |    2183.6  |        6940.0    
                          f32 B=384, M=197, H=1, K=88    |         8793.3       |        9446.6      |    2175.1  |        9429.2    
                          f16 B=384, M=197, H=1, K=80    |         5989.9       |        6596.2      |    2146.8  |        6608.6    
                          f32 B=384, M=197, H=1, K=80    |         8427.1       |        8993.1      |    2134.0  |        9030.5    
                          f16 B=384, M=197, H=1, K=64    |         3347.0       |        3527.4      |    1799.3  |        3538.3    
                          f32 B=384, M=197, H=1, K=64    |         5563.1       |        5984.2      |    1801.6  |        5980.9    
                          f16 B=1024, M=197, H=1, K=88   |        15680.5       |       17424.8      |    5671.9  |       17452.6    
                          f32 B=1024, M=197, H=1, K=88   |        23784.2       |       25542.8      |    5664.0  |       25578.6    
                          f16 B=1024, M=197, H=1, K=80   |        14935.5       |       16581.3      |    5559.5  |       16587.3    
                          f32 B=1024, M=197, H=1, K=80   |        22767.6       |       24354.1      |    5550.7  |       24362.0    
                          f16 B=1024, M=197, H=1, K=64   |         8331.1       |        8671.5      |    4644.9  |        8695.3    
                          f32 B=1024, M=197, H=1, K=64   |        15061.6       |       16148.8      |    4650.0  |       16201.4    
                          f16 B=512, M=197, H=1, K=80    |         7594.7       |        8306.6      |    2824.6  |        8336.5    
                          f32 B=512, M=197, H=1, K=80    |        11857.0       |       12660.2      |    2807.4  |       12709.2    
                          f16 B=32, M=197, H=16, K=80    |         7779.8       |        8342.0      |    2820.8  |        8393.5    
                          f32 B=32, M=197, H=16, K=80    |        11846.5       |       12487.8      |    2806.0  |       12549.1    
                          f16 B=32, M=197, H=16, K=64    |         4258.4       |        4445.6      |    2374.0  |        4461.2    
                          f32 B=32, M=197, H=16, K=64    |         7828.9       |        8434.9      |    2376.0  |        8472.2    
                          f16 B=32, M=197, H=16, K=128   |         9025.1       |        9671.9      |    3159.3  |        9707.2    
                          f32 B=32, M=197, H=16, K=128   |        14139.8       |       14920.7      |    3157.7  |       14928.9    
                          f16 B=256, M=197, H=1, K=88    |         4608.9       |        5118.8      |    1478.4  |        5119.2    
                          f32 B=256, M=197, H=1, K=88    |         6174.0       |        6644.3      |    1477.6  |        6642.7    
                          f16 B=16, M=197, H=16, K=88    |         4618.4       |        5073.4      |    1479.5  |        5062.9    
                          f32 B=16, M=197, H=16, K=88    |         6114.2       |        6471.7      |    1474.9  |        6478.7    
                          f16 B=16, M=197, H=16, K=64    |         2490.3       |        2557.5      |    1225.9  |        2550.0    
                          f32 B=16, M=197, H=16, K=64    |         3918.9       |        4208.2      |    1227.0  |        4195.7    
                          f16 B=16, M=197, H=16, K=128   |         5210.0       |        5649.3      |    1635.2  |        5648.1    
                          f32 B=16, M=197, H=16, K=128   |         7103.1       |        7451.6      |    1645.5  |        7445.1    
                          f16 B=1, M=4096, H=160, K=128  |      1014229.1       |     1106182.6      |            |     1108040.6    
                          f32 B=1, M=4096, H=160, K=128  |      1258173.2       |     1243183.8      |            |     1241548.7    
                          f16 B=2, M=4096, H=160, K=128  |      1642279.2       |     1753736.9      |            |     1771655.3    
                          f32 B=2, M=4096, H=160, K=128  |      2505435.4       |     2477353.1      |            |     2473773.4    
                          f16 B=1, M=8192, H=160, K=128  |      4050128.4       |     4415962.1      |            |     4428194.9    
                          f32 B=1, M=8192, H=160, K=128  |      5042352.6       |     4970582.0      |            |     4965069.9    
                          f16 B=2, M=8192, H=160, K=128  |      6600732.3       |     7026378.5      |            |     7068051.8    
                          f16 B=1024, M=82, H=8, K=64    |        21572.8       |       22531.6      |    9059.1  |       22418.4    
                          f32 B=1024, M=82, H=8, K=64    |        38178.4       |       45927.2      |    9070.3  |       45708.0    
                          f16 B=150, M=256, H=16, K=64   |        21436.5       |       21927.4      |   12938.5  |       22001.0    
                          f32 B=150, M=256, H=16, K=64   |        33024.2       |       33196.3      |   13249.2  |       33199.6    
                          f16 B=64, M=256, H=12, K=64    |         6869.8       |        7048.6      |    4200.6  |        7073.6    
                          f32 B=64, M=256, H=12, K=64    |        10719.9       |       10832.1      |    4271.3  |       10843.6    
                          f16 B=1, M=4096, H=16, K=40    |       134722.6       |      145429.8      |   20587.0  |      143743.8    
                          f32 B=1, M=4096, H=16, K=40    |       143015.1       |      147850.6      |   20625.8  |      148272.8    
                          f16 B=1, M=16384, H=16, K=40   |      2149850.4       |     2323732.4      |            |     2301489.4    
                          f32 B=1, M=16384, H=16, K=40   |      2286478.1       |     2369812.4      |            |     2375911.8    
                          f16 B=16, M=128, H=16, K=16    |          497.2       |         502.9      |     623.8  |         503.6    
                          f32 B=16, M=128, H=16, K=16    |          573.5       |         609.8      |     617.5  |         611.6    
                          f16 B=16, M=128, H=16, K=32    |          563.9       |         573.1      |     624.2  |         575.2    
                          f32 B=16, M=128, H=16, K=32    |          661.5       |         702.2      |     620.4  |         703.2    
                          f16 B=16, M=128, H=16, K=64    |          708.9       |         722.6      |     619.9  |         724.3    
                          f32 B=16, M=128, H=16, K=64    |          916.2       |         953.5      |     618.6  |         953.6    
                          f16 B=16, M=128, H=16, K=128   |         1465.8       |        1542.8      |     624.2  |        1545.7    
                          f32 B=16, M=128, H=16, K=128   |         1829.1       |        1872.3      |     616.2  |        1873.9    
                          f16 B=16, M=128, H=16, K=256   |         3796.3       |        4002.5      |    1010.7  |        4008.4    
                          f32 B=16, M=128, H=16, K=256   |         3838.8       |        3957.1      |    1203.6  |        3951.8    
                          f16 B=16, M=512, H=16, K=16    |         7680.4       |        7775.1      |    4848.7  |        7752.8    
                          f32 B=16, M=512, H=16, K=16    |         9402.0       |        9926.1      |    4926.7  |        9914.5    
                          f16 B=16, M=512, H=16, K=32    |         8897.5       |        8999.0      |    5055.9  |        9003.9    
                          f32 B=16, M=512, H=16, K=32    |        10762.1       |       11320.9      |    5065.4  |       11341.8    
                          f16 B=16, M=512, H=16, K=64    |        10936.9       |       11214.8      |    5484.9  |       11254.4    
                          f32 B=16, M=512, H=16, K=64    |        15091.9       |       15210.3      |    5552.0  |       15196.3    
                          f16 B=16, M=512, H=16, K=128   |        23491.0       |       25234.8      |    7317.1  |       25300.3    
                          f32 B=16, M=512, H=16, K=128   |        30524.6       |       30320.0      |    7487.0  |       30200.6    
                          f16 B=16, M=512, H=16, K=256   |        50389.7       |       54525.2      |   14015.3  |       54931.1    
                          f32 B=16, M=512, H=16, K=256   |        62155.6       |       61046.4      |   14285.5  |       61081.1    
                          f16 B=16, M=1024, H=16, K=16   |        31289.9       |       31951.9      |   18778.4  |       32044.9    
                          f32 B=16, M=1024, H=16, K=16   |        37744.4       |       39586.6      |   18929.9  |       39739.9    
                          f16 B=16, M=1024, H=16, K=32   |        35770.8       |       36651.1      |   19620.2  |       36909.0    
                          f32 B=16, M=1024, H=16, K=32   |        43211.5       |       45544.5      |   19518.4  |       45664.8    
                          f16 B=16, M=1024, H=16, K=64   |        43865.1       |       44864.8      |   21286.0  |       45345.7    
                          f32 B=16, M=1024, H=16, K=64   |        60710.5       |       60966.9      |   21634.1  |       60921.9    
                          f16 B=16, M=1024, H=16, K=128  |        94633.7       |      101796.0      |   28502.1  |      102871.6    
                          f32 B=16, M=1024, H=16, K=128  |       124093.6       |      122196.9      |   28520.3  |      122043.0    
                          f16 B=16, M=1024, H=16, K=256  |       194780.7       |      212303.0      |   55419.1  |      214643.5    
                          f32 B=16, M=1024, H=16, K=256  |       250799.1       |      245196.2      |   55634.0  |      245534.2    
                          f16 B=64, M=128, H=16, K=16    |         1658.0       |        1661.4      |    1331.0  |        1662.7    
                          f32 B=64, M=128, H=16, K=16    |         2129.2       |        2266.2      |    1371.8  |        2269.4    
                          f16 B=64, M=128, H=16, K=32    |         1904.3       |        1903.5      |    1384.2  |        1906.4    
                          f32 B=64, M=128, H=16, K=32    |         2496.5       |        2643.4      |    1445.1  |        2639.8    
                          f16 B=64, M=128, H=16, K=64    |         2393.1       |        2432.3      |    1505.2  |        2437.8    
                          f32 B=64, M=128, H=16, K=64    |         3471.1       |        3561.2      |    1590.0  |        3565.7    
                          f16 B=64, M=128, H=16, K=128   |         4969.1       |        5266.3      |    1988.4  |        5272.4    
                          f32 B=64, M=128, H=16, K=128   |         6880.3       |        7040.6      |    2121.0  |        7024.3    
                          f16 B=64, M=128, H=16, K=256   |        12635.3       |       13334.7      |    3859.5  |       13350.7    
                          f32 B=64, M=128, H=16, K=256   |        14185.7       |       14514.4      |    4553.8  |       14503.1    
                          f16 B=64, M=512, H=16, K=16    |        26278.4       |       26189.9      |   18916.0  |       26110.0    
                          f32 B=64, M=512, H=16, K=16    |        35128.2       |       37075.9      |   19191.4  |       37174.5    
                          f16 B=64, M=512, H=16, K=32    |        30414.2       |       30938.2      |   19734.7  |       31071.7    
                          f32 B=64, M=512, H=16, K=32    |        40483.2       |       42922.5      |   19843.9  |       42868.5    
                          f16 B=64, M=512, H=16, K=64    |        37248.2       |       38179.7      |   21640.1  |       38327.0    
                          f32 B=64, M=512, H=16, K=64    |        57666.8       |       57675.3      |   21909.1  |       57697.2    
                          f16 B=64, M=512, H=16, K=128   |        80113.6       |       86165.7      |   28765.5  |       86368.3    
                          f32 B=64, M=512, H=16, K=128   |       115672.5       |      115161.0      |   28910.3  |      115320.0    
                          f16 B=64, M=512, H=16, K=256   |       169250.0       |      183791.8      |   56315.7  |      183735.0    
                          f32 B=64, M=512, H=16, K=256   |       236594.6       |      233093.6      |   56853.4  |      233170.2    
                          f16 B=64, M=1024, H=16, K=16   |       106022.3       |      109588.2      |   74303.6  |      109410.4    
                          f32 B=64, M=1024, H=16, K=16   |       141241.8       |      148854.1      |            |      149651.5    
                          f16 B=64, M=1024, H=16, K=32   |       120899.1       |      125044.6      |   77828.6  |      125716.9    
                          f32 B=64, M=1024, H=16, K=32   |       162478.5       |      173906.1      |            |      173216.4    
                          f16 B=64, M=1024, H=16, K=64   |       149044.1       |      152290.6      |   85821.6  |      152748.5    
                          f32 B=64, M=1024, H=16, K=64   |       233195.9       |      231479.6      |            |      231533.9    
                          f16 B=64, M=1024, H=16, K=128  |       319761.5       |      344076.5      |  113579.4  |      345058.5    
                          f32 B=64, M=1024, H=16, K=128  |       470172.3       |      466330.7      |            |      463507.4    
                          f16 B=64, M=1024, H=16, K=256  |       658362.1       |      717070.2      |            |      723057.4    
                          f32 B=64, M=1024, H=16, K=256  |       955624.0       |      935114.3      |            |      935945.4    
  (Tesla_V100_SXM2_16GB)  f16 B=384, M=197, H=1, K=88    |         1811.3       |        1686.3      |    1375.2  |        1699.6    
                          f32 B=384, M=197, H=1, K=88    |         4315.7       |        4665.1      |    2257.1  |        4663.1    
                          f16 B=384, M=197, H=1, K=80    |         1733.0       |        1616.3      |    1281.7  |        1619.8    
                          f32 B=384, M=197, H=1, K=80    |         3965.0       |        4226.4      |    2171.7  |        4228.3    
                          f16 B=384, M=197, H=1, K=64    |         1135.2       |        1084.4      |    1043.8  |        1083.0    
                          f32 B=384, M=197, H=1, K=64    |         2673.2       |        2883.0      |    1744.5  |        2878.2    
                          f16 B=1024, M=197, H=1, K=88   |         4721.0       |        4396.6      |    3725.3  |        4404.1    
                          f32 B=1024, M=197, H=1, K=88   |        10531.3       |       11443.8      |    6106.5  |       11464.2    
                          f16 B=1024, M=197, H=1, K=80   |         4520.2       |        4216.2      |    3329.2  |        4223.7    
                          f32 B=1024, M=197, H=1, K=80   |         9573.5       |       10301.4      |    5757.6  |       10305.3    
                          f16 B=1024, M=197, H=1, K=64   |         2788.1       |        2660.1      |    2674.6  |        2663.5    
                          f32 B=1024, M=197, H=1, K=64   |         6556.7       |        7102.4      |    4516.6  |        7096.8    
                          f16 B=512, M=197, H=1, K=80    |         2377.2       |        2228.1      |    1685.0  |        2231.4    
                          f32 B=512, M=197, H=1, K=80    |         5259.3       |        5639.7      |    2887.8  |        5636.6    
                          f16 B=32, M=197, H=16, K=80    |         2403.0       |        2201.1      |    1798.2  |        2204.9    
                          f32 B=32, M=197, H=16, K=80    |         5402.3       |        5667.7      |    3046.3  |        5662.4    
                          f16 B=32, M=197, H=16, K=64    |         1552.7       |        1486.3      |    1451.6  |        1485.2    
                          f32 B=32, M=197, H=16, K=64    |         3622.5       |        3911.0      |    2427.0  |        3912.9    
                          f16 B=32, M=197, H=16, K=128   |         2776.6       |        2611.9      |    2211.3  |        2613.3    
                          f32 B=32, M=197, H=16, K=128   |         6647.9       |        7082.3      |    4088.5  |        7104.8    
                          f16 B=256, M=197, H=1, K=88    |         1357.5       |        1285.5      |     941.3  |        1287.6    
                          f32 B=256, M=197, H=1, K=88    |         2874.1       |        3085.1      |    1543.2  |        3090.1    
                          f16 B=16, M=197, H=16, K=88    |         1349.1       |        1264.5      |     964.7  |        1263.1    
                          f32 B=16, M=197, H=16, K=88    |         2803.8       |        2967.0      |    1647.4  |        2972.0    
                          f16 B=16, M=197, H=16, K=64    |          765.5       |         728.4      |     765.3  |         731.0    
                          f32 B=16, M=197, H=16, K=64    |         1834.1       |        1969.7      |    1282.4  |        1974.4    
                          f16 B=16, M=197, H=16, K=128   |         1509.1       |        1432.6      |    1139.1  |        1433.9    
                          f32 B=16, M=197, H=16, K=128   |         3406.5       |        3606.2      |    2048.6  |        3613.3    
                          f16 B=1, M=4096, H=160, K=128  |       168807.1       |      148652.9      |            |      149343.8    
                          f32 B=1, M=4096, H=160, K=128  |       549864.6       |      586699.9      |            |      585699.9    
                          f16 B=2, M=4096, H=160, K=128  |       339010.5       |      298827.7      |            |      298808.4    
                          f32 B=2, M=4096, H=160, K=128  |      1106963.8       |     1176218.3      |            |     1179173.2    
                          f16 B=1, M=8192, H=160, K=128  |       679742.4       |      594323.2      |            |      595580.1    
                          f32 B=1, M=8192, H=160, K=128  |      2195491.3       |     2340248.4      |            |     2343505.7    
                          f16 B=2, M=8192, H=160, K=128  |      1364983.7       |     1193787.8      |            |     1192596.1    
                          f16 B=1024, M=82, H=8, K=64    |         9052.8       |        8762.6      |    5804.2  |        8757.8    
                          f32 B=1024, M=82, H=8, K=64    |        14726.3       |       16270.6      |   11059.7  |       16215.9    
                          f16 B=150, M=256, H=16, K=64   |         5662.9       |        5519.4      |    7557.8  |        5524.5    
                          f32 B=150, M=256, H=16, K=64   |        16700.3       |       17612.7      |   16426.0  |       17640.4    
                          f16 B=64, M=256, H=12, K=64    |         1849.1       |        1793.1      |    2383.5  |        1798.7    
                          f32 B=64, M=256, H=12, K=64    |         5451.5       |        5766.4      |    4975.3  |        5775.8    
                          f16 B=1, M=4096, H=16, K=40    |        47263.3       |       47850.4      |    8315.6  |       47777.2    
                          f32 B=1, M=4096, H=16, K=40    |       113099.4       |      113164.5      |   19536.0  |      113930.0    
                          f16 B=1, M=16384, H=16, K=40   |       757091.8       |      770365.7      |            |      765401.3    
                          f32 B=1, M=16384, H=16, K=40   |      1806827.7       |     1816302.6      |            |     1819162.1    
                          f16 B=16, M=128, H=16, K=16    |          219.2       |         218.5      |     480.6  |         231.8    
                          f32 B=16, M=128, H=16, K=16    |          301.2       |         308.2      |     498.9  |         307.9    
                          f16 B=16, M=128, H=16, K=32    |          227.6       |         215.4      |     473.8  |         220.1    
                          f32 B=16, M=128, H=16, K=32    |          395.8       |         401.0      |     455.1  |         400.8    
                          f16 B=16, M=128, H=16, K=64    |          225.6       |         214.8      |     510.1  |         229.9    
                          f32 B=16, M=128, H=16, K=64    |          561.7       |         583.1      |     598.9  |         581.2    
                          f16 B=16, M=128, H=16, K=128   |          404.6       |         392.0      |     524.0  |         394.8    
                          f32 B=16, M=128, H=16, K=128   |         1103.0       |        1140.8      |    1015.4  |        1142.0    
                          f16 B=16, M=128, H=16, K=256   |         1045.3       |        1049.1      |     889.6  |        1047.4    
                          f32 B=16, M=128, H=16, K=256   |         2181.3       |        2270.9      |    1869.2  |        2265.7    
                          f16 B=16, M=512, H=16, K=16    |         1731.5       |        1585.7      |    1908.5  |        1586.2    
                          f32 B=16, M=512, H=16, K=16    |         4513.8       |        4696.7      |    4222.1  |        4695.6    
                          f16 B=16, M=512, H=16, K=32    |         1942.2       |        1823.4      |    2086.3  |        1809.7    
                          f32 B=16, M=512, H=16, K=32    |         5596.1       |        5819.4      |    4588.1  |        5833.6    
                          f16 B=16, M=512, H=16, K=64    |         2450.2       |        2340.8      |    2580.6  |        2353.2    
                          f32 B=16, M=512, H=16, K=64    |         7619.5       |        7853.1      |    5536.8  |        7875.1    
                          f16 B=16, M=512, H=16, K=128   |         4884.8       |        4473.5      |    3388.7  |        4487.9    
                          f32 B=16, M=512, H=16, K=128   |        15010.0       |       15557.1      |    8979.4  |       15513.1    
                          f16 B=16, M=512, H=16, K=256   |        12973.9       |       11134.7      |    5418.2  |       11106.3    
                          f32 B=16, M=512, H=16, K=256   |        29751.1       |       30979.1      |   16856.3  |       31012.8    
                          f16 B=16, M=1024, H=16, K=16   |         6802.6       |        6185.1      |    6996.0  |        6191.0    
                          f32 B=16, M=1024, H=16, K=16   |        18098.2       |       18954.5      |   16129.1  |       19166.8    
                          f16 B=16, M=1024, H=16, K=32   |         7531.9       |        7065.0      |    7436.0  |        7067.5    
                          f32 B=16, M=1024, H=16, K=32   |        21999.9       |       22901.5      |   17040.0  |       22899.4    
                          f16 B=16, M=1024, H=16, K=64   |         9312.0       |        8854.3      |    8605.6  |        8863.9    
                          f32 B=16, M=1024, H=16, K=64   |        29837.7       |       30895.1      |   20355.3  |       30878.5    
                          f16 B=16, M=1024, H=16, K=128  |        18979.0       |       16951.0      |   10561.3  |       16995.2    
                          f32 B=16, M=1024, H=16, K=128  |        58738.3       |       60861.2      |   33427.0  |       60599.8    
                          f16 B=16, M=1024, H=16, K=256  |        49681.9       |       41833.3      |   17329.5  |       41921.6    
                          f32 B=16, M=1024, H=16, K=256  |       117362.4       |      121004.8      |   60515.8  |      122046.9    
                          f16 B=64, M=128, H=16, K=16    |          432.2       |         411.1      |     642.1  |         411.3    
                          f32 B=64, M=128, H=16, K=16    |         1028.9       |        1057.9      |    1233.4  |        1056.6    
                          f16 B=64, M=128, H=16, K=32    |          522.5       |         500.7      |     813.6  |         499.7    
                          f32 B=64, M=128, H=16, K=32    |         1403.5       |        1443.4      |    1535.6  |        1443.6    
                          f16 B=64, M=128, H=16, K=64    |          750.0       |         739.8      |    1185.6  |         741.0    
                          f32 B=64, M=128, H=16, K=64    |         2013.9       |        2110.1      |    2156.8  |        2105.4    
                          f16 B=64, M=128, H=16, K=128   |         1421.2       |        1387.9      |    1915.5  |        1388.3    
                          f32 B=64, M=128, H=16, K=128   |         3946.5       |        4156.7      |    3780.1  |        4156.3    
                          f16 B=64, M=128, H=16, K=256   |         3811.5       |        3810.6      |    3448.7  |        3811.0    
                          f32 B=64, M=128, H=16, K=256   |         7983.9       |        8432.7      |    7304.0  |        8415.5    
                          f16 B=64, M=512, H=16, K=16    |         6157.4       |        5523.9      |    7461.8  |        5528.1    
                          f32 B=64, M=512, H=16, K=16    |        16118.2       |       17004.7      |   16651.7  |       16984.5    
                          f16 B=64, M=512, H=16, K=32    |         6985.1       |        6483.1      |    8278.5  |        6495.1    
                          f32 B=64, M=512, H=16, K=32    |        20471.9       |       21230.7      |   18420.5  |       21269.0    
                          f16 B=64, M=512, H=16, K=64    |         9045.8       |        8633.5      |   10337.1  |        8674.1    
                          f32 B=64, M=512, H=16, K=64    |        27675.0       |       29282.8      |   22883.7  |       29136.9    
                          f16 B=64, M=512, H=16, K=128   |        17594.2       |       15805.9      |   14700.6  |       15788.2    
                          f32 B=64, M=512, H=16, K=128   |        54612.4       |       57815.7      |   39974.3  |       57951.8    
                          f16 B=64, M=512, H=16, K=256   |        47452.6       |       40093.5      |   27087.5  |       40175.6    
                          f32 B=64, M=512, H=16, K=256   |       108880.3       |      115953.1      |   77794.2  |      115951.0    
                          f16 B=64, M=1024, H=16, K=16   |        24369.0       |       21533.0      |   28448.6  |       21556.6    
                          f32 B=64, M=1024, H=16, K=16   |        64649.7       |       68791.8      |            |       68407.4    
                          f16 B=64, M=1024, H=16, K=32   |        27143.1       |       25683.7      |   30252.7  |       25727.9    
                          f32 B=64, M=1024, H=16, K=32   |        79967.5       |       83351.5      |            |       83084.7    
                          f16 B=64, M=1024, H=16, K=64   |        34667.0       |       32592.7      |   36991.2  |       32659.8    
                          f32 B=64, M=1024, H=16, K=64   |       108282.2       |      113858.1      |            |      114286.0    
                          f16 B=64, M=1024, H=16, K=128  |        68519.5       |       59757.3      |   48834.4  |       59817.7    
                          f32 B=64, M=1024, H=16, K=128  |       215465.9       |      227335.8      |            |      227204.0    
                          f16 B=64, M=1024, H=16, K=256  |       183070.4       |      150960.1      |            |      150947.1    
                          f32 B=64, M=1024, H=16, K=256  |       425832.9       |      453717.2      |            |      453349.2    

Times are in microseconds (us).
```
</details>
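
To read the table above, it can help to turn pairs of columns into relative changes. This is a small sketch, assuming `48_accf32` is the f32-accumulation variant from this PR and `56_base` is the baseline it is compared against. That reading is inferred from the column names and is not stated explicitly. The three rows are copied from the table, and `relative_change_pct` is a hypothetical helper.

```python
# (device, dtype) -> (accf32_us, base_us) for B=384, M=197, H=1, K=88,
# copied from the benchmark table. Times are in microseconds; lower is faster.
ROWS = {
    ("Quadro_GP100", "f16"): (6289.6, 6936.0),
    ("Tesla_V100_SXM2_16GB", "f16"): (1811.3, 1686.3),
    ("Tesla_V100_SXM2_16GB", "f32"): (4315.7, 4665.1),
}

def relative_change_pct(new_us: float, base_us: float) -> float:
    """Percent change in runtime; positive means `new` is slower than `base`."""
    return (new_us - base_us) / base_us * 100.0

changes = {key: relative_change_pct(*times) for key, times in ROWS.items()}
for (device, dtype), pct in changes.items():
    print(f"{device} {dtype}: {pct:+.1f}%")
```

For this shape, the V100 f16 row comes out roughly 7% slower with f32 accumulation, while the GP100 f16 and V100 f32 rows come out faster.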

<details>
<summary>bw A100 (f32/f16)</summary>

```
[----------------------------------------------------- attention backward (attn_bias=<class 'NoneType'>) -----------------------------------------------------]                                                       
                                     |  48_accf32_69654fdb[cutlass]  |  flash[flshatt]  |  vanilla   |  56_base_02bf6b4e[cutlass]  |  57_tmpT_b516aec4[cutlass]
1 threads: ----------------------------------------------------------------------------------------------------------------------------------------------------
      f16 B=384, M=197, H=1, K=88    |               613.4           |                  |    2264.9  |              612.8          |              578.3        
      f32 B=384, M=197, H=1, K=88    |              2335.2           |                  |    1843.0  |             2438.4          |             2425.1        
      f16 B=384, M=197, H=1, K=80    |               583.6           |                  |    1922.9  |              577.6          |              548.4        
      f32 B=384, M=197, H=1, K=80    |              2241.3           |                  |    1787.8  |             2333.2          |             2333.1        
      f16 B=384, M=197, H=1, K=64    |               405.6           |        232.5     |    1809.8  |              386.4          |              366.1        
      f32 B=384, M=197, H=1, K=64    |              1259.7           |                  |    1675.6  |             1309.6          |             1316.8        
      f16 B=1024, M=197, H=1, K=88   |              1538.2           |                  |    5964.7  |             1550.8          |             1454.4        
      f32 B=1024, M=197, H=1, K=88   |              6031.4           |                  |    4559.5  |             6325.4          |             6332.7        
      f16 B=1024, M=197, H=1, K=80   |              1458.8           |                  |    5038.2  |             1463.8          |             1379.0        
      f32 B=1024, M=197, H=1, K=80   |              5786.2           |                  |    4412.6  |             6059.0          |             6079.9        
      f16 B=1024, M=197, H=1, K=64   |               929.5           |        575.9     |    4735.6  |              862.1          |              821.6        
      f32 B=1024, M=197, H=1, K=64   |              3289.9           |                  |    4119.9  |             3434.4          |             3441.4        
      f16 B=512, M=197, H=1, K=80    |               744.5           |                  |    2544.7  |              735.3          |              695.8        
      f32 B=512, M=197, H=1, K=80    |              2889.4           |                  |    2286.1  |             3008.8          |             3029.0        
      f16 B=32, M=197, H=16, K=80    |               741.6           |                  |    2569.0  |              723.3          |              693.4        
      f32 B=32, M=197, H=16, K=80    |              2878.6           |                  |    2355.3  |             3003.0          |             3024.1        
      f16 B=32, M=197, H=16, K=64    |               478.5           |        295.7     |    2429.2  |              456.1          |              426.7        
      f32 B=32, M=197, H=16, K=64    |              1784.4           |                  |    2196.9  |             1863.7          |             1860.2        
      f16 B=32, M=197, H=16, K=128   |               887.0           |        682.3     |    4492.9  |              857.5          |              853.8        
      f32 B=32, M=197, H=16, K=128   |              3546.6           |                  |    2807.3  |             3734.9          |             3737.1        
      f16 B=256, M=197, H=1, K=88    |               445.1           |                  |    1528.9  |              445.0          |              422.2        
      f32 B=256, M=197, H=1, K=88    |              1678.8           |                  |    1207.6  |             1746.6          |             1752.8        
      f16 B=16, M=197, H=16, K=88    |               441.5           |                  |    1544.2  |              437.3          |              419.5        
      f32 B=16, M=197, H=16, K=88    |              1668.3           |                  |    1250.4  |             1742.7          |             1746.1        
      f16 B=16, M=197, H=16, K=64    |               247.4           |        165.6     |    1242.5  |              233.2          |              217.5        
      f32 B=16, M=197, H=16, K=64    |              1051.1           |                  |    1125.3  |             1099.2          |             1096.6        
      f16 B=16, M=197, H=16, K=128   |               498.4           |        386.2     |    2264.5  |              488.0          |              480.5        
      f32 B=16, M=197, H=16, K=128   |              1950.2           |                  |    1446.7  |             2039.0          |             2028.6        
      f16 B=1, M=4096, H=160, K=128  |             55915.0           |      54620.4     |   45909.5  |            63407.6          |            51298.8        
      f32 B=1, M=4096, H=160, K=128  |            238514.6           |                  |            |           232677.5          |           232672.5        
      f16 B=2, M=4096, H=160, K=128  |             93612.0           |      84238.0     |            |           100433.1          |            84858.1        
      f32 B=2, M=4096, H=160, K=128  |            375037.4           |                  |            |           364234.1          |           364407.2        
      f16 B=1, M=8192, H=160, K=128  |            223261.8           |     215499.8     |            |           251806.6          |           202133.7        
      f32 B=1, M=8192, H=160, K=128  |            946708.9           |                  |            |           924986.2          |           924988.9        
      f16 B=2, M=8192, H=160, K=128  |            367969.2           |     330092.8     |            |           395881.3          |           332000.6        
      f32 B=2, M=8192, H=160, K=128  |           1492031.4           |                  |            |          1448691.1          |          1449146.1        
      f16 B=1024, M=82, H=8, K=64    |              1890.2           |       1620.3     |    3819.7  |             1861.3          |             1764.6        
      f32 B=1024, M=82, H=8, K=64    |              8428.2           |                  |    8735.1  |             8831.3          |             8867.4        
      f16 B=150, M=256, H=16, K=64   |              2292.3           |       1625.3     |    4555.9  |             2109.8          |             2019.4        
      f32 B=150, M=256, H=16, K=64   |              6252.4           |                  |   12948.1  |             6288.9          |             6281.3        
      f16 B=64, M=256, H=12, K=64    |               782.2           |        567.4     |    1498.0  |              731.2          |              699.7        
      f32 B=64, M=256, H=12, K=64    |              2141.2           |                  |    4266.6  |             2160.6          |             2160.2        
      f16 B=1, M=4096, H=16, K=40    |             23504.2           |                  |    4196.1  |            23699.0          |            23008.9        
      f32 B=1, M=4096, H=16, K=40    |             73699.5           |                  |   17755.3  |            73261.8          |            73078.5        
      f16 B=1, M=16384, H=16, K=40   |            391408.9           |                  |            |           439777.3          |           407653.7        
      f32 B=1, M=16384, H=16, K=40   |           1196173.6           |                  |            |          1181547.9          |          1181625.1        
      f16 B=256, M=4096, H=16, K=64  |            733221.8           |     439627.5     |            |           603237.1          |           565905.8        
      f16 B=16, M=128, H=16, K=16    |               130.0           |        113.1     |     265.2  |              125.5          |              124.4        
      f32 B=16, M=128, H=16, K=16    |               161.5           |                  |     373.1  |              160.6          |              162.8        
      f16 B=16, M=128, H=16, K=32    |               125.8           |        111.5     |     263.8  |              122.1          |              125.8        
      f32 B=16, M=128, H=16, K=32    |               189.8           |                  |     412.6  |              196.3          |              196.2        
      f16 B=16, M=128, H=16, K=64    |               126.0           |        112.4     |     265.8  |              120.8          |              125.7        
      f32 B=16, M=128, H=16, K=64    |               272.1           |                  |     498.7  |              285.9          |              283.3        
      f16 B=16, M=128, H=16, K=128   |               181.3           |        158.4     |     298.5  |              178.0          |              186.5        
      f32 B=16, M=128, H=16, K=128   |               509.6           |                  |     673.9  |              521.1          |              521.5        
      f16 B=16, M=128, H=16, K=256   |               774.0           |                  |     541.4  |              775.5          |              757.0        
      f32 B=16, M=128, H=16, K=256   |               975.2           |                  |    1162.6  |              994.5          |              994.5        
      f16 B=16, M=512, H=16, K=16    |               621.0           |        322.6     |    1204.9  |              555.0          |              519.9        
      f32 B=16, M=512, H=16, K=16    |              2148.0           |                  |    4414.4  |             2178.9          |             2180.4        
      f16 B=16, M=512, H=16, K=32    |               709.9           |        435.6     |    1306.2  |              653.1          |              602.8        
      f32 B=16, M=512, H=16, K=32    |              2335.6           |                  |    4640.7  |             2336.0          |             2336.3        
      f16 B=16, M=512, H=16, K=64    |               917.8           |        702.7     |    1545.8  |              849.9          |              797.4        
      f32 B=16, M=512, H=16, K=64    |              2965.5           |                  |    5125.0  |             2986.9          |             2988.3        
      f16 B=16, M=512, H=16, K=128   |              1644.0           |       1584.1     |    1983.4  |             1757.0          |             1548.0        
      f32 B=16, M=512, H=16, K=128   |              6152.7           |                  |    6099.1  |             6067.7          |             6068.6        
      f16 B=16, M=512, H=16, K=256   |              8178.3           |                  |    2899.1  |             7895.5          |             7977.4        
      f32 B=16, M=512, H=16, K=256   |             11894.0           |                  |   10639.5  |            11635.0          |            11624.7        
      f16 B=16, M=1024, H=16, K=16   |              2420.5           |       1240.8     |    4259.2  |             2234.6          |             2048.9        
      f32 B=16, M=1024, H=16, K=16   |              8467.2           |                  |   16650.4  |             8512.2          |             8510.2        
      f16 B=16, M=1024, H=16, K=32   |              2675.7           |       1618.9     |    4491.2  |             2441.3          |             2230.0        
      f32 B=16, M=1024, H=16, K=32   |              9010.0           |                  |   17301.0  |             9012.1          |             9015.7        
      f16 B=16, M=1024, H=16, K=64   |              3328.4           |       2370.3     |    4994.5  |             3032.2          |             2820.0        
      f32 B=16, M=1024, H=16, K=64   |             11566.8           |                  |   18714.2  |            11494.1          |            11492.8        
      f16 B=16, M=1024, H=16, K=128  |              5867.9           |       5632.5     |    5952.8  |             6401.1          |             5440.8        
      f32 B=16, M=1024, H=16, K=128  |             23345.4           |                  |   21523.7  |            22859.1          |            22870.3        
      f16 B=16, M=1024, H=16, K=256  |             30619.2           |                  |    7893.1  |            29884.9          |            29060.4        
      f32 B=16, M=1024, H=16, K=256  |             45211.4           |                  |   38093.0  |            43435.3          |            43423.8        
      f16 B=64, M=128, H=16, K=16    |               159.6           |        145.2     |     439.9  |              161.2          |              167.0        
      f32 B=64, M=128, H=16, K=16    |               493.4           |                  |    1270.0  |              502.7          |              503.2        
      f16 B=64, M=128, H=16, K=32    |               208.5           |        212.1     |     545.3  |              206.2          |              204.9        
      f32 B=64, M=128, H=16, K=32    |               601.2           |                  |    1427.0  |              610.5          |              610.6        
      f16 B=64, M=128, H=16, K=64    |               329.0           |        310.8     |     766.0  |              327.2          |              314.7        
      f32 B=64, M=128, H=16, K=64    |               867.7           |                  |    1743.5  |              889.2          |              889.0        
      f16 B=64, M=128, H=16, K=128   |               635.5           |        562.1     |    1226.7  |              613.3          |              650.7        
      f32 B=64, M=128, H=16, K=128   |              1774.4           |                  |    2386.5  |             1800.0          |             1799.6        
      f16 B=64, M=128, H=16, K=256   |              2839.8           |                  |    2122.8  |             2765.2          |             2821.9        
      f32 B=64, M=128, H=16, K=256   |              3419.2           |                  |    4320.7  |             3458.8          |             3457.0        
      f16 B=64, M=512, H=16, K=16    |              2316.2           |       1202.3     |    4487.1  |             1983.9          |             1868.4        
      f32 B=64, M=512, H=16, K=16    |              6686.0           |                  |   16991.5  |             6709.9          |             6713.2        
      f16 B=64, M=512, H=16, K=32    |              2701.7           |       1541.9     |    4975.9  |             2346.4          |             2184.3        
      f32 B=64, M=512, H=16, K=32    |              7460.5           |                  |   17859.9  |             7429.6          |             7438.0        
      f16 B=64, M=512, H=16, K=64    |              3461.2           |       2418.2     |    5886.1  |             3083.1          |             2942.6        
      f32 B=64, M=512, H=16, K=64    |              9553.4           |                  |   19768.6  |             9526.5          |             9516.6        
      f16 B=64, M=512, H=16, K=128   |              5875.5           |       5443.1     |    7711.0  |             6141.3          |             5513.6        
      f32 B=64, M=512, H=16, K=128   |             21317.0           |                  |   23651.2  |            20931.9          |            20932.3        
      f16 B=64, M=512, H=16, K=256   |             31238.2           |                  |   11490.6  |            28626.8          |            28954.0        
      f32 B=64, M=512, H=16, K=256   |             41124.8           |                  |   42468.0  |            39562.8          |            39579.5        
      f16 B=64, M=1024, H=16, K=16   |              9142.8           |       4707.2     |   16882.8  |             7887.0          |             7314.5        
      f32 B=64, M=1024, H=16, K=16   |             26512.3           |                  |   66311.7  |            26497.5          |            26459.8        
      f16 B=64, M=1024, H=16, K=32   |             10420.8           |       5698.3     |   17875.0  |             8851.4          |             7974.0        
      f32 B=64, M=1024, H=16, K=32   |             28300.0           |                  |   69088.7  |            28102.1          |            28098.2        
      f16 B=64, M=1024, H=16, K=64   |             12948.5           |       8119.0     |   19944.3  |            11064.2          |            10477.6        
      f32 B=64, M=1024, H=16, K=64   |             35600.6           |                  |   74762.8  |            35316.4          |            35361.2        
      f16 B=64, M=1024, H=16, K=128  |             20820.5           |      19184.2     |   23699.3  |            21954.8          |            19220.0        
      f32 B=64, M=1024, H=16, K=128  |             80800.3           |                  |   86003.8  |            78521.9          |            78393.9        
      f16 B=64, M=1024, H=16, K=256  |            114411.1           |                  |   32958.3  |           103287.2          |           104304.2        
      f32 B=64, M=1024, H=16, K=256  |            155731.5           |                  |  153011.0  |           148071.6          |           148165.6        

Times are in microseconds (us).
```
</details>


danthe3rd pushed a commit that referenced this pull request Nov 29, 2022
ghstack-source-id: 29c8918b52a893018f19cebeb5367bdd6202405b
Pull Request resolved: #467
**PERFORMANCE**

This makes the backward pass slower in f16 :(
But I think we need to accumulate in f32 for numerical stability
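
To illustrate the stability concern, here is a minimal sketch of how a running sum kept in f16 can drift from the true value while the same sum kept in f32 does not. This is not the kernel code from this PR: `round_to` and `accumulate` are hypothetical helpers that emulate reduced-precision accumulators with the standard-library `struct` module, and the input (4096 copies of a value near 0.01) is chosen purely for illustration.

```python
import struct


def round_to(fmt: str, value: float) -> float:
    """Round a Python float to the nearest value representable in `fmt`.

    'e' is IEEE half precision (f16), 'f' is IEEE single precision (f32).
    """
    return struct.unpack(fmt, struct.pack(fmt, value))[0]


def accumulate(values, acc_fmt: str) -> float:
    """Sum `values` sequentially, rounding the running total to `acc_fmt` after every add."""
    acc = 0.0
    for v in values:
        acc = round_to(acc_fmt, acc + v)
    return acc


# Inputs are f16 in both cases: 4096 copies of the f16 value nearest to 0.01.
x = round_to("e", 0.01)
values = [x] * 4096
reference = x * len(values)  # float64 reference, about 40.97

sum_f16 = accumulate(values, "e")  # accumulator kept in f16
sum_f32 = accumulate(values, "f")  # accumulator kept in f32

print(f"reference={reference:.4f}  f16 acc={sum_f16:.4f}  f32 acc={sum_f32:.4f}")
```

In this toy case the f16 accumulator stops growing once the gap between adjacent f16 values exceeds twice the addend, so it ends well short of the reference, while the f32 accumulator tracks it closely. The same effect is the motivation for paying the extra cost of f32 accumulation even when inputs and outputs stay in f16.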

<details>
<summary>bw P100/V100 (f32/f16)</summary>

```
[---------------------------------------- attention backward (attn_bias=<class 'NoneType'>) ----------------------------------------]                                                    
                                                         |  48_accf32_6f8e2f15  |  56_base_02bf6b4e  |  vanilla   |  57_tmpT_b516aec4
1 threads: --------------------------------------------------------------------------------------------------------------------------
  (Quadro_GP100)          f16 B=384, M=197, H=1, K=88    |         6289.6       |        6936.0      |    2183.6  |        6940.0    
                          f32 B=384, M=197, H=1, K=88    |         8793.3       |        9446.6      |    2175.1  |        9429.2    
                          f16 B=384, M=197, H=1, K=80    |         5989.9       |        6596.2      |    2146.8  |        6608.6    
                          f32 B=384, M=197, H=1, K=80    |         8427.1       |        8993.1      |    2134.0  |        9030.5    
                          f16 B=384, M=197, H=1, K=64    |         3347.0       |        3527.4      |    1799.3  |        3538.3    
                          f32 B=384, M=197, H=1, K=64    |         5563.1       |        5984.2      |    1801.6  |        5980.9    
                          f16 B=1024, M=197, H=1, K=88   |        15680.5       |       17424.8      |    5671.9  |       17452.6    
                          f32 B=1024, M=197, H=1, K=88   |        23784.2       |       25542.8      |    5664.0  |       25578.6    
                          f16 B=1024, M=197, H=1, K=80   |        14935.5       |       16581.3      |    5559.5  |       16587.3    
                          f32 B=1024, M=197, H=1, K=80   |        22767.6       |       24354.1      |    5550.7  |       24362.0    
                          f16 B=1024, M=197, H=1, K=64   |         8331.1       |        8671.5      |    4644.9  |        8695.3    
                          f32 B=1024, M=197, H=1, K=64   |        15061.6       |       16148.8      |    4650.0  |       16201.4    
                          f16 B=512, M=197, H=1, K=80    |         7594.7       |        8306.6      |    2824.6  |        8336.5    
                          f32 B=512, M=197, H=1, K=80    |        11857.0       |       12660.2      |    2807.4  |       12709.2    
                          f16 B=32, M=197, H=16, K=80    |         7779.8       |        8342.0      |    2820.8  |        8393.5    
                          f32 B=32, M=197, H=16, K=80    |        11846.5       |       12487.8      |    2806.0  |       12549.1    
                          f16 B=32, M=197, H=16, K=64    |         4258.4       |        4445.6      |    2374.0  |        4461.2    
                          f32 B=32, M=197, H=16, K=64    |         7828.9       |        8434.9      |    2376.0  |        8472.2    
                          f16 B=32, M=197, H=16, K=128   |         9025.1       |        9671.9      |    3159.3  |        9707.2    
                          f32 B=32, M=197, H=16, K=128   |        14139.8       |       14920.7      |    3157.7  |       14928.9    
                          f16 B=256, M=197, H=1, K=88    |         4608.9       |        5118.8      |    1478.4  |        5119.2    
                          f32 B=256, M=197, H=1, K=88    |         6174.0       |        6644.3      |    1477.6  |        6642.7    
                          f16 B=16, M=197, H=16, K=88    |         4618.4       |        5073.4      |    1479.5  |        5062.9    
                          f32 B=16, M=197, H=16, K=88    |         6114.2       |        6471.7      |    1474.9  |        6478.7    
                          f16 B=16, M=197, H=16, K=64    |         2490.3       |        2557.5      |    1225.9  |        2550.0    
                          f32 B=16, M=197, H=16, K=64    |         3918.9       |        4208.2      |    1227.0  |        4195.7    
                          f16 B=16, M=197, H=16, K=128   |         5210.0       |        5649.3      |    1635.2  |        5648.1    
                          f32 B=16, M=197, H=16, K=128   |         7103.1       |        7451.6      |    1645.5  |        7445.1    
                          f16 B=1, M=4096, H=160, K=128  |      1014229.1       |     1106182.6      |            |     1108040.6    
                          f32 B=1, M=4096, H=160, K=128  |      1258173.2       |     1243183.8      |            |     1241548.7    
                          f16 B=2, M=4096, H=160, K=128  |      1642279.2       |     1753736.9      |            |     1771655.3    
                          f32 B=2, M=4096, H=160, K=128  |      2505435.4       |     2477353.1      |            |     2473773.4    
                          f16 B=1, M=8192, H=160, K=128  |      4050128.4       |     4415962.1      |            |     4428194.9    
                          f32 B=1, M=8192, H=160, K=128  |      5042352.6       |     4970582.0      |            |     4965069.9    
                          f16 B=2, M=8192, H=160, K=128  |      6600732.3       |     7026378.5      |            |     7068051.8    
                          f16 B=1024, M=82, H=8, K=64    |        21572.8       |       22531.6      |    9059.1  |       22418.4    
                          f32 B=1024, M=82, H=8, K=64    |        38178.4       |       45927.2      |    9070.3  |       45708.0    
                          f16 B=150, M=256, H=16, K=64   |        21436.5       |       21927.4      |   12938.5  |       22001.0    
                          f32 B=150, M=256, H=16, K=64   |        33024.2       |       33196.3      |   13249.2  |       33199.6    
                          f16 B=64, M=256, H=12, K=64    |         6869.8       |        7048.6      |    4200.6  |        7073.6    
                          f32 B=64, M=256, H=12, K=64    |        10719.9       |       10832.1      |    4271.3  |       10843.6    
                          f16 B=1, M=4096, H=16, K=40    |       134722.6       |      145429.8      |   20587.0  |      143743.8    
                          f32 B=1, M=4096, H=16, K=40    |       143015.1       |      147850.6      |   20625.8  |      148272.8    
                          f16 B=1, M=16384, H=16, K=40   |      2149850.4       |     2323732.4      |            |     2301489.4    
                          f32 B=1, M=16384, H=16, K=40   |      2286478.1       |     2369812.4      |            |     2375911.8    
                          f16 B=16, M=128, H=16, K=16    |          497.2       |         502.9      |     623.8  |         503.6    
                          f32 B=16, M=128, H=16, K=16    |          573.5       |         609.8      |     617.5  |         611.6    
                          f16 B=16, M=128, H=16, K=32    |          563.9       |         573.1      |     624.2  |         575.2    
                          f32 B=16, M=128, H=16, K=32    |          661.5       |         702.2      |     620.4  |         703.2    
                          f16 B=16, M=128, H=16, K=64    |          708.9       |         722.6      |     619.9  |         724.3    
                          f32 B=16, M=128, H=16, K=64    |          916.2       |         953.5      |     618.6  |         953.6    
                          f16 B=16, M=128, H=16, K=128   |         1465.8       |        1542.8      |     624.2  |        1545.7    
                          f32 B=16, M=128, H=16, K=128   |         1829.1       |        1872.3      |     616.2  |        1873.9    
                          f16 B=16, M=128, H=16, K=256   |         3796.3       |        4002.5      |    1010.7  |        4008.4    
                          f32 B=16, M=128, H=16, K=256   |         3838.8       |        3957.1      |    1203.6  |        3951.8    
                          f16 B=16, M=512, H=16, K=16    |         7680.4       |        7775.1      |    4848.7  |        7752.8    
                          f32 B=16, M=512, H=16, K=16    |         9402.0       |        9926.1      |    4926.7  |        9914.5    
                          f16 B=16, M=512, H=16, K=32    |         8897.5       |        8999.0      |    5055.9  |        9003.9    
                          f32 B=16, M=512, H=16, K=32    |        10762.1       |       11320.9      |    5065.4  |       11341.8    
                          f16 B=16, M=512, H=16, K=64    |        10936.9       |       11214.8      |    5484.9  |       11254.4    
                          f32 B=16, M=512, H=16, K=64    |        15091.9       |       15210.3      |    5552.0  |       15196.3    
                          f16 B=16, M=512, H=16, K=128   |        23491.0       |       25234.8      |    7317.1  |       25300.3    
                          f32 B=16, M=512, H=16, K=128   |        30524.6       |       30320.0      |    7487.0  |       30200.6    
                          f16 B=16, M=512, H=16, K=256   |        50389.7       |       54525.2      |   14015.3  |       54931.1    
                          f32 B=16, M=512, H=16, K=256   |        62155.6       |       61046.4      |   14285.5  |       61081.1    
                          f16 B=16, M=1024, H=16, K=16   |        31289.9       |       31951.9      |   18778.4  |       32044.9    
                          f32 B=16, M=1024, H=16, K=16   |        37744.4       |       39586.6      |   18929.9  |       39739.9    
                          f16 B=16, M=1024, H=16, K=32   |        35770.8       |       36651.1      |   19620.2  |       36909.0    
                          f32 B=16, M=1024, H=16, K=32   |        43211.5       |       45544.5      |   19518.4  |       45664.8    
                          f16 B=16, M=1024, H=16, K=64   |        43865.1       |       44864.8      |   21286.0  |       45345.7    
                          f32 B=16, M=1024, H=16, K=64   |        60710.5       |       60966.9      |   21634.1  |       60921.9    
                          f16 B=16, M=1024, H=16, K=128  |        94633.7       |      101796.0      |   28502.1  |      102871.6    
                          f32 B=16, M=1024, H=16, K=128  |       124093.6       |      122196.9      |   28520.3  |      122043.0    
                          f16 B=16, M=1024, H=16, K=256  |       194780.7       |      212303.0      |   55419.1  |      214643.5    
                          f32 B=16, M=1024, H=16, K=256  |       250799.1       |      245196.2      |   55634.0  |      245534.2    
                          f16 B=64, M=128, H=16, K=16    |         1658.0       |        1661.4      |    1331.0  |        1662.7    
                          f32 B=64, M=128, H=16, K=16    |         2129.2       |        2266.2      |    1371.8  |        2269.4    
                          f16 B=64, M=128, H=16, K=32    |         1904.3       |        1903.5      |    1384.2  |        1906.4    
                          f32 B=64, M=128, H=16, K=32    |         2496.5       |        2643.4      |    1445.1  |        2639.8    
                          f16 B=64, M=128, H=16, K=64    |         2393.1       |        2432.3      |    1505.2  |        2437.8    
                          f32 B=64, M=128, H=16, K=64    |         3471.1       |        3561.2      |    1590.0  |        3565.7    
                          f16 B=64, M=128, H=16, K=128   |         4969.1       |        5266.3      |    1988.4  |        5272.4    
                          f32 B=64, M=128, H=16, K=128   |         6880.3       |        7040.6      |    2121.0  |        7024.3    
                          f16 B=64, M=128, H=16, K=256   |        12635.3       |       13334.7      |    3859.5  |       13350.7    
                          f32 B=64, M=128, H=16, K=256   |        14185.7       |       14514.4      |    4553.8  |       14503.1    
                          f16 B=64, M=512, H=16, K=16    |        26278.4       |       26189.9      |   18916.0  |       26110.0    
                          f32 B=64, M=512, H=16, K=16    |        35128.2       |       37075.9      |   19191.4  |       37174.5    
                          f16 B=64, M=512, H=16, K=32    |        30414.2       |       30938.2      |   19734.7  |       31071.7    
                          f32 B=64, M=512, H=16, K=32    |        40483.2       |       42922.5      |   19843.9  |       42868.5    
                          f16 B=64, M=512, H=16, K=64    |        37248.2       |       38179.7      |   21640.1  |       38327.0    
                          f32 B=64, M=512, H=16, K=64    |        57666.8       |       57675.3      |   21909.1  |       57697.2    
                          f16 B=64, M=512, H=16, K=128   |        80113.6       |       86165.7      |   28765.5  |       86368.3    
                          f32 B=64, M=512, H=16, K=128   |       115672.5       |      115161.0      |   28910.3  |      115320.0    
                          f16 B=64, M=512, H=16, K=256   |       169250.0       |      183791.8      |   56315.7  |      183735.0    
                          f32 B=64, M=512, H=16, K=256   |       236594.6       |      233093.6      |   56853.4  |      233170.2    
                          f16 B=64, M=1024, H=16, K=16   |       106022.3       |      109588.2      |   74303.6  |      109410.4    
                          f32 B=64, M=1024, H=16, K=16   |       141241.8       |      148854.1      |            |      149651.5    
                          f16 B=64, M=1024, H=16, K=32   |       120899.1       |      125044.6      |   77828.6  |      125716.9    
                          f32 B=64, M=1024, H=16, K=32   |       162478.5       |      173906.1      |            |      173216.4    
                          f16 B=64, M=1024, H=16, K=64   |       149044.1       |      152290.6      |   85821.6  |      152748.5    
                          f32 B=64, M=1024, H=16, K=64   |       233195.9       |      231479.6      |            |      231533.9    
                          f16 B=64, M=1024, H=16, K=128  |       319761.5       |      344076.5      |  113579.4  |      345058.5    
                          f32 B=64, M=1024, H=16, K=128  |       470172.3       |      466330.7      |            |      463507.4    
                          f16 B=64, M=1024, H=16, K=256  |       658362.1       |      717070.2      |            |      723057.4    
                          f32 B=64, M=1024, H=16, K=256  |       955624.0       |      935114.3      |            |      935945.4    
  (Tesla_V100_SXM2_16GB)  f16 B=384, M=197, H=1, K=88    |         1811.3       |        1686.3      |    1375.2  |        1699.6    
                          f32 B=384, M=197, H=1, K=88    |         4315.7       |        4665.1      |    2257.1  |        4663.1    
                          f16 B=384, M=197, H=1, K=80    |         1733.0       |        1616.3      |    1281.7  |        1619.8    
                          f32 B=384, M=197, H=1, K=80    |         3965.0       |        4226.4      |    2171.7  |        4228.3    
                          f16 B=384, M=197, H=1, K=64    |         1135.2       |        1084.4      |    1043.8  |        1083.0    
                          f32 B=384, M=197, H=1, K=64    |         2673.2       |        2883.0      |    1744.5  |        2878.2    
                          f16 B=1024, M=197, H=1, K=88   |         4721.0       |        4396.6      |    3725.3  |        4404.1    
                          f32 B=1024, M=197, H=1, K=88   |        10531.3       |       11443.8      |    6106.5  |       11464.2    
                          f16 B=1024, M=197, H=1, K=80   |         4520.2       |        4216.2      |    3329.2  |        4223.7    
                          f32 B=1024, M=197, H=1, K=80   |         9573.5       |       10301.4      |    5757.6  |       10305.3    
                          f16 B=1024, M=197, H=1, K=64   |         2788.1       |        2660.1      |    2674.6  |        2663.5    
                          f32 B=1024, M=197, H=1, K=64   |         6556.7       |        7102.4      |    4516.6  |        7096.8    
                          f16 B=512, M=197, H=1, K=80    |         2377.2       |        2228.1      |    1685.0  |        2231.4    
                          f32 B=512, M=197, H=1, K=80    |         5259.3       |        5639.7      |    2887.8  |        5636.6    
                          f16 B=32, M=197, H=16, K=80    |         2403.0       |        2201.1      |    1798.2  |        2204.9    
                          f32 B=32, M=197, H=16, K=80    |         5402.3       |        5667.7      |    3046.3  |        5662.4    
                          f16 B=32, M=197, H=16, K=64    |         1552.7       |        1486.3      |    1451.6  |        1485.2    
                          f32 B=32, M=197, H=16, K=64    |         3622.5       |        3911.0      |    2427.0  |        3912.9    
                          f16 B=32, M=197, H=16, K=128   |         2776.6       |        2611.9      |    2211.3  |        2613.3    
                          f32 B=32, M=197, H=16, K=128   |         6647.9       |        7082.3      |    4088.5  |        7104.8    
                          f16 B=256, M=197, H=1, K=88    |         1357.5       |        1285.5      |     941.3  |        1287.6    
                          f32 B=256, M=197, H=1, K=88    |         2874.1       |        3085.1      |    1543.2  |        3090.1    
                          f16 B=16, M=197, H=16, K=88    |         1349.1       |        1264.5      |     964.7  |        1263.1    
                          f32 B=16, M=197, H=16, K=88    |         2803.8       |        2967.0      |    1647.4  |        2972.0    
                          f16 B=16, M=197, H=16, K=64    |          765.5       |         728.4      |     765.3  |         731.0    
                          f32 B=16, M=197, H=16, K=64    |         1834.1       |        1969.7      |    1282.4  |        1974.4    
                          f16 B=16, M=197, H=16, K=128   |         1509.1       |        1432.6      |    1139.1  |        1433.9    
                          f32 B=16, M=197, H=16, K=128   |         3406.5       |        3606.2      |    2048.6  |        3613.3    
                          f16 B=1, M=4096, H=160, K=128  |       168807.1       |      148652.9      |            |      149343.8    
                          f32 B=1, M=4096, H=160, K=128  |       549864.6       |      586699.9      |            |      585699.9    
                          f16 B=2, M=4096, H=160, K=128  |       339010.5       |      298827.7      |            |      298808.4    
                          f32 B=2, M=4096, H=160, K=128  |      1106963.8       |     1176218.3      |            |     1179173.2    
                          f16 B=1, M=8192, H=160, K=128  |       679742.4       |      594323.2      |            |      595580.1    
                          f32 B=1, M=8192, H=160, K=128  |      2195491.3       |     2340248.4      |            |     2343505.7    
                          f16 B=2, M=8192, H=160, K=128  |      1364983.7       |     1193787.8      |            |     1192596.1    
                          f16 B=1024, M=82, H=8, K=64    |         9052.8       |        8762.6      |    5804.2  |        8757.8    
                          f32 B=1024, M=82, H=8, K=64    |        14726.3       |       16270.6      |   11059.7  |       16215.9    
                          f16 B=150, M=256, H=16, K=64   |         5662.9       |        5519.4      |    7557.8  |        5524.5    
                          f32 B=150, M=256, H=16, K=64   |        16700.3       |       17612.7      |   16426.0  |       17640.4    
                          f16 B=64, M=256, H=12, K=64    |         1849.1       |        1793.1      |    2383.5  |        1798.7    
                          f32 B=64, M=256, H=12, K=64    |         5451.5       |        5766.4      |    4975.3  |        5775.8    
                          f16 B=1, M=4096, H=16, K=40    |        47263.3       |       47850.4      |    8315.6  |       47777.2    
                          f32 B=1, M=4096, H=16, K=40    |       113099.4       |      113164.5      |   19536.0  |      113930.0    
                          f16 B=1, M=16384, H=16, K=40   |       757091.8       |      770365.7      |            |      765401.3    
                          f32 B=1, M=16384, H=16, K=40   |      1806827.7       |     1816302.6      |            |     1819162.1    
                          f16 B=16, M=128, H=16, K=16    |          219.2       |         218.5      |     480.6  |         231.8    
                          f32 B=16, M=128, H=16, K=16    |          301.2       |         308.2      |     498.9  |         307.9    
                          f16 B=16, M=128, H=16, K=32    |          227.6       |         215.4      |     473.8  |         220.1    
                          f32 B=16, M=128, H=16, K=32    |          395.8       |         401.0      |     455.1  |         400.8    
                          f16 B=16, M=128, H=16, K=64    |          225.6       |         214.8      |     510.1  |         229.9    
                          f32 B=16, M=128, H=16, K=64    |          561.7       |         583.1      |     598.9  |         581.2    
                          f16 B=16, M=128, H=16, K=128   |          404.6       |         392.0      |     524.0  |         394.8    
                          f32 B=16, M=128, H=16, K=128   |         1103.0       |        1140.8      |    1015.4  |        1142.0    
                          f16 B=16, M=128, H=16, K=256   |         1045.3       |        1049.1      |     889.6  |        1047.4    
                          f32 B=16, M=128, H=16, K=256   |         2181.3       |        2270.9      |    1869.2  |        2265.7    
                          f16 B=16, M=512, H=16, K=16    |         1731.5       |        1585.7      |    1908.5  |        1586.2    
                          f32 B=16, M=512, H=16, K=16    |         4513.8       |        4696.7      |    4222.1  |        4695.6    
                          f16 B=16, M=512, H=16, K=32    |         1942.2       |        1823.4      |    2086.3  |        1809.7    
                          f32 B=16, M=512, H=16, K=32    |         5596.1       |        5819.4      |    4588.1  |        5833.6    
                          f16 B=16, M=512, H=16, K=64    |         2450.2       |        2340.8      |    2580.6  |        2353.2    
                          f32 B=16, M=512, H=16, K=64    |         7619.5       |        7853.1      |    5536.8  |        7875.1    
                          f16 B=16, M=512, H=16, K=128   |         4884.8       |        4473.5      |    3388.7  |        4487.9    
                          f32 B=16, M=512, H=16, K=128   |        15010.0       |       15557.1      |    8979.4  |       15513.1    
                          f16 B=16, M=512, H=16, K=256   |        12973.9       |       11134.7      |    5418.2  |       11106.3    
                          f32 B=16, M=512, H=16, K=256   |        29751.1       |       30979.1      |   16856.3  |       31012.8    
                          f16 B=16, M=1024, H=16, K=16   |         6802.6       |        6185.1      |    6996.0  |        6191.0    
                          f32 B=16, M=1024, H=16, K=16   |        18098.2       |       18954.5      |   16129.1  |       19166.8    
                          f16 B=16, M=1024, H=16, K=32   |         7531.9       |        7065.0      |    7436.0  |        7067.5    
                          f32 B=16, M=1024, H=16, K=32   |        21999.9       |       22901.5      |   17040.0  |       22899.4    
                          f16 B=16, M=1024, H=16, K=64   |         9312.0       |        8854.3      |    8605.6  |        8863.9    
                          f32 B=16, M=1024, H=16, K=64   |        29837.7       |       30895.1      |   20355.3  |       30878.5    
                          f16 B=16, M=1024, H=16, K=128  |        18979.0       |       16951.0      |   10561.3  |       16995.2    
                          f32 B=16, M=1024, H=16, K=128  |        58738.3       |       60861.2      |   33427.0  |       60599.8    
                          f16 B=16, M=1024, H=16, K=256  |        49681.9       |       41833.3      |   17329.5  |       41921.6    
                          f32 B=16, M=1024, H=16, K=256  |       117362.4       |      121004.8      |   60515.8  |      122046.9    
                          f16 B=64, M=128, H=16, K=16    |          432.2       |         411.1      |     642.1  |         411.3    
                          f32 B=64, M=128, H=16, K=16    |         1028.9       |        1057.9      |    1233.4  |        1056.6    
                          f16 B=64, M=128, H=16, K=32    |          522.5       |         500.7      |     813.6  |         499.7    
                          f32 B=64, M=128, H=16, K=32    |         1403.5       |        1443.4      |    1535.6  |        1443.6    
                          f16 B=64, M=128, H=16, K=64    |          750.0       |         739.8      |    1185.6  |         741.0    
                          f32 B=64, M=128, H=16, K=64    |         2013.9       |        2110.1      |    2156.8  |        2105.4    
                          f16 B=64, M=128, H=16, K=128   |         1421.2       |        1387.9      |    1915.5  |        1388.3    
                          f32 B=64, M=128, H=16, K=128   |         3946.5       |        4156.7      |    3780.1  |        4156.3    
                          f16 B=64, M=128, H=16, K=256   |         3811.5       |        3810.6      |    3448.7  |        3811.0    
                          f32 B=64, M=128, H=16, K=256   |         7983.9       |        8432.7      |    7304.0  |        8415.5    
                          f16 B=64, M=512, H=16, K=16    |         6157.4       |        5523.9      |    7461.8  |        5528.1    
                          f32 B=64, M=512, H=16, K=16    |        16118.2       |       17004.7      |   16651.7  |       16984.5    
                          f16 B=64, M=512, H=16, K=32    |         6985.1       |        6483.1      |    8278.5  |        6495.1    
                          f32 B=64, M=512, H=16, K=32    |        20471.9       |       21230.7      |   18420.5  |       21269.0    
                          f16 B=64, M=512, H=16, K=64    |         9045.8       |        8633.5      |   10337.1  |        8674.1    
                          f32 B=64, M=512, H=16, K=64    |        27675.0       |       29282.8      |   22883.7  |       29136.9    
                          f16 B=64, M=512, H=16, K=128   |        17594.2       |       15805.9      |   14700.6  |       15788.2    
                          f32 B=64, M=512, H=16, K=128   |        54612.4       |       57815.7      |   39974.3  |       57951.8    
                          f16 B=64, M=512, H=16, K=256   |        47452.6       |       40093.5      |   27087.5  |       40175.6    
                          f32 B=64, M=512, H=16, K=256   |       108880.3       |      115953.1      |   77794.2  |      115951.0    
                          f16 B=64, M=1024, H=16, K=16   |        24369.0       |       21533.0      |   28448.6  |       21556.6    
                          f32 B=64, M=1024, H=16, K=16   |        64649.7       |       68791.8      |            |       68407.4    
                          f16 B=64, M=1024, H=16, K=32   |        27143.1       |       25683.7      |   30252.7  |       25727.9    
                          f32 B=64, M=1024, H=16, K=32   |        79967.5       |       83351.5      |            |       83084.7    
                          f16 B=64, M=1024, H=16, K=64   |        34667.0       |       32592.7      |   36991.2  |       32659.8    
                          f32 B=64, M=1024, H=16, K=64   |       108282.2       |      113858.1      |            |      114286.0    
                          f16 B=64, M=1024, H=16, K=128  |        68519.5       |       59757.3      |   48834.4  |       59817.7    
                          f32 B=64, M=1024, H=16, K=128  |       215465.9       |      227335.8      |            |      227204.0    
                          f16 B=64, M=1024, H=16, K=256  |       183070.4       |      150960.1      |            |      150947.1    
                          f32 B=64, M=1024, H=16, K=256  |       425832.9       |      453717.2      |            |      453349.2    

Times are in microseconds (us).
```
</details>

<details>
<summary>bw A100 (f32/f16)</summary>

```
[----------------------------------------------------- attention backward (attn_bias=<class 'NoneType'>) -----------------------------------------------------]                                                       
                                     |  48_accf32_69654fdb[cutlass]  |  flash[flshatt]  |  vanilla   |  56_base_02bf6b4e[cutlass]  |  57_tmpT_b516aec4[cutlass]
1 threads: ----------------------------------------------------------------------------------------------------------------------------------------------------
      f16 B=384, M=197, H=1, K=88    |               613.4           |                  |    2264.9  |              612.8          |              578.3        
      f32 B=384, M=197, H=1, K=88    |              2335.2           |                  |    1843.0  |             2438.4          |             2425.1        
      f16 B=384, M=197, H=1, K=80    |               583.6           |                  |    1922.9  |              577.6          |              548.4        
      f32 B=384, M=197, H=1, K=80    |              2241.3           |                  |    1787.8  |             2333.2          |             2333.1        
      f16 B=384, M=197, H=1, K=64    |               405.6           |        232.5     |    1809.8  |              386.4          |              366.1        
      f32 B=384, M=197, H=1, K=64    |              1259.7           |                  |    1675.6  |             1309.6          |             1316.8        
      f16 B=1024, M=197, H=1, K=88   |              1538.2           |                  |    5964.7  |             1550.8          |             1454.4        
      f32 B=1024, M=197, H=1, K=88   |              6031.4           |                  |    4559.5  |             6325.4          |             6332.7        
      f16 B=1024, M=197, H=1, K=80   |              1458.8           |                  |    5038.2  |             1463.8          |             1379.0        
      f32 B=1024, M=197, H=1, K=80   |              5786.2           |                  |    4412.6  |             6059.0          |             6079.9        
      f16 B=1024, M=197, H=1, K=64   |               929.5           |        575.9     |    4735.6  |              862.1          |              821.6        
      f32 B=1024, M=197, H=1, K=64   |              3289.9           |                  |    4119.9  |             3434.4          |             3441.4        
      f16 B=512, M=197, H=1, K=80    |               744.5           |                  |    2544.7  |              735.3          |              695.8        
      f32 B=512, M=197, H=1, K=80    |              2889.4           |                  |    2286.1  |             3008.8          |             3029.0        
      f16 B=32, M=197, H=16, K=80    |               741.6           |                  |    2569.0  |              723.3          |              693.4        
      f32 B=32, M=197, H=16, K=80    |              2878.6           |                  |    2355.3  |             3003.0          |             3024.1        
      f16 B=32, M=197, H=16, K=64    |               478.5           |        295.7     |    2429.2  |              456.1          |              426.7        
      f32 B=32, M=197, H=16, K=64    |              1784.4           |                  |    2196.9  |             1863.7          |             1860.2        
      f16 B=32, M=197, H=16, K=128   |               887.0           |        682.3     |    4492.9  |              857.5          |              853.8        
      f32 B=32, M=197, H=16, K=128   |              3546.6           |                  |    2807.3  |             3734.9          |             3737.1        
      f16 B=256, M=197, H=1, K=88    |               445.1           |                  |    1528.9  |              445.0          |              422.2        
      f32 B=256, M=197, H=1, K=88    |              1678.8           |                  |    1207.6  |             1746.6          |             1752.8        
      f16 B=16, M=197, H=16, K=88    |               441.5           |                  |    1544.2  |              437.3          |              419.5        
      f32 B=16, M=197, H=16, K=88    |              1668.3           |                  |    1250.4  |             1742.7          |             1746.1        
      f16 B=16, M=197, H=16, K=64    |               247.4           |        165.6     |    1242.5  |              233.2          |              217.5        
      f32 B=16, M=197, H=16, K=64    |              1051.1           |                  |    1125.3  |             1099.2          |             1096.6        
      f16 B=16, M=197, H=16, K=128   |               498.4           |        386.2     |    2264.5  |              488.0          |              480.5        
      f32 B=16, M=197, H=16, K=128   |              1950.2           |                  |    1446.7  |             2039.0          |             2028.6        
      f16 B=1, M=4096, H=160, K=128  |             55915.0           |      54620.4     |   45909.5  |            63407.6          |            51298.8        
      f32 B=1, M=4096, H=160, K=128  |            238514.6           |                  |            |           232677.5          |           232672.5        
      f16 B=2, M=4096, H=160, K=128  |             93612.0           |      84238.0     |            |           100433.1          |            84858.1        
      f32 B=2, M=4096, H=160, K=128  |            375037.4           |                  |            |           364234.1          |           364407.2        
      f16 B=1, M=8192, H=160, K=128  |            223261.8           |     215499.8     |            |           251806.6          |           202133.7        
      f32 B=1, M=8192, H=160, K=128  |            946708.9           |                  |            |           924986.2          |           924988.9        
      f16 B=2, M=8192, H=160, K=128  |            367969.2           |     330092.8     |            |           395881.3          |           332000.6        
      f32 B=2, M=8192, H=160, K=128  |           1492031.4           |                  |            |          1448691.1          |          1449146.1        
      f16 B=1024, M=82, H=8, K=64    |              1890.2           |       1620.3     |    3819.7  |             1861.3          |             1764.6        
      f32 B=1024, M=82, H=8, K=64    |              8428.2           |                  |    8735.1  |             8831.3          |             8867.4        
      f16 B=150, M=256, H=16, K=64   |              2292.3           |       1625.3     |    4555.9  |             2109.8          |             2019.4        
      f32 B=150, M=256, H=16, K=64   |              6252.4           |                  |   12948.1  |             6288.9          |             6281.3        
      f16 B=64, M=256, H=12, K=64    |               782.2           |        567.4     |    1498.0  |              731.2          |              699.7        
      f32 B=64, M=256, H=12, K=64    |              2141.2           |                  |    4266.6  |             2160.6          |             2160.2        
      f16 B=1, M=4096, H=16, K=40    |             23504.2           |                  |    4196.1  |            23699.0          |            23008.9        
      f32 B=1, M=4096, H=16, K=40    |             73699.5           |                  |   17755.3  |            73261.8          |            73078.5        
      f16 B=1, M=16384, H=16, K=40   |            391408.9           |                  |            |           439777.3          |           407653.7        
      f32 B=1, M=16384, H=16, K=40   |           1196173.6           |                  |            |          1181547.9          |          1181625.1        
      f16 B=256, M=4096, H=16, K=64  |            733221.8           |     439627.5     |            |           603237.1          |           565905.8        
      f16 B=16, M=128, H=16, K=16    |               130.0           |        113.1     |     265.2  |              125.5          |              124.4        
      f32 B=16, M=128, H=16, K=16    |               161.5           |                  |     373.1  |              160.6          |              162.8        
      f16 B=16, M=128, H=16, K=32    |               125.8           |        111.5     |     263.8  |              122.1          |              125.8        
      f32 B=16, M=128, H=16, K=32    |               189.8           |                  |     412.6  |              196.3          |              196.2        
      f16 B=16, M=128, H=16, K=64    |               126.0           |        112.4     |     265.8  |              120.8          |              125.7        
      f32 B=16, M=128, H=16, K=64    |               272.1           |                  |     498.7  |              285.9          |              283.3        
      f16 B=16, M=128, H=16, K=128   |               181.3           |        158.4     |     298.5  |              178.0          |              186.5        
      f32 B=16, M=128, H=16, K=128   |               509.6           |                  |     673.9  |              521.1          |              521.5        
      f16 B=16, M=128, H=16, K=256   |               774.0           |                  |     541.4  |              775.5          |              757.0        
      f32 B=16, M=128, H=16, K=256   |               975.2           |                  |    1162.6  |              994.5          |              994.5        
      f16 B=16, M=512, H=16, K=16    |               621.0           |        322.6     |    1204.9  |              555.0          |              519.9        
      f32 B=16, M=512, H=16, K=16    |              2148.0           |                  |    4414.4  |             2178.9          |             2180.4        
      f16 B=16, M=512, H=16, K=32    |               709.9           |        435.6     |    1306.2  |              653.1          |              602.8        
      f32 B=16, M=512, H=16, K=32    |              2335.6           |                  |    4640.7  |             2336.0          |             2336.3        
      f16 B=16, M=512, H=16, K=64    |               917.8           |        702.7     |    1545.8  |              849.9          |              797.4        
      f32 B=16, M=512, H=16, K=64    |              2965.5           |                  |    5125.0  |             2986.9          |             2988.3        
      f16 B=16, M=512, H=16, K=128   |              1644.0           |       1584.1     |    1983.4  |             1757.0          |             1548.0        
      f32 B=16, M=512, H=16, K=128   |              6152.7           |                  |    6099.1  |             6067.7          |             6068.6        
      f16 B=16, M=512, H=16, K=256   |              8178.3           |                  |    2899.1  |             7895.5          |             7977.4        
      f32 B=16, M=512, H=16, K=256   |             11894.0           |                  |   10639.5  |            11635.0          |            11624.7        
      f16 B=16, M=1024, H=16, K=16   |              2420.5           |       1240.8     |    4259.2  |             2234.6          |             2048.9        
      f32 B=16, M=1024, H=16, K=16   |              8467.2           |                  |   16650.4  |             8512.2          |             8510.2        
      f16 B=16, M=1024, H=16, K=32   |              2675.7           |       1618.9     |    4491.2  |             2441.3          |             2230.0        
      f32 B=16, M=1024, H=16, K=32   |              9010.0           |                  |   17301.0  |             9012.1          |             9015.7        
      f16 B=16, M=1024, H=16, K=64   |              3328.4           |       2370.3     |    4994.5  |             3032.2          |             2820.0        
      f32 B=16, M=1024, H=16, K=64   |             11566.8           |                  |   18714.2  |            11494.1          |            11492.8        
      f16 B=16, M=1024, H=16, K=128  |              5867.9           |       5632.5     |    5952.8  |             6401.1          |             5440.8        
      f32 B=16, M=1024, H=16, K=128  |             23345.4           |                  |   21523.7  |            22859.1          |            22870.3        
      f16 B=16, M=1024, H=16, K=256  |             30619.2           |                  |    7893.1  |            29884.9          |            29060.4        
      f32 B=16, M=1024, H=16, K=256  |             45211.4           |                  |   38093.0  |            43435.3          |            43423.8        
      f16 B=64, M=128, H=16, K=16    |               159.6           |        145.2     |     439.9  |              161.2          |              167.0        
      f32 B=64, M=128, H=16, K=16    |               493.4           |                  |    1270.0  |              502.7          |              503.2        
      f16 B=64, M=128, H=16, K=32    |               208.5           |        212.1     |     545.3  |              206.2          |              204.9        
      f32 B=64, M=128, H=16, K=32    |               601.2           |                  |    1427.0  |              610.5          |              610.6        
      f16 B=64, M=128, H=16, K=64    |               329.0           |        310.8     |     766.0  |              327.2          |              314.7        
      f32 B=64, M=128, H=16, K=64    |               867.7           |                  |    1743.5  |              889.2          |              889.0        
      f16 B=64, M=128, H=16, K=128   |               635.5           |        562.1     |    1226.7  |              613.3          |              650.7        
      f32 B=64, M=128, H=16, K=128   |              1774.4           |                  |    2386.5  |             1800.0          |             1799.6        
      f16 B=64, M=128, H=16, K=256   |              2839.8           |                  |    2122.8  |             2765.2          |             2821.9        
      f32 B=64, M=128, H=16, K=256   |              3419.2           |                  |    4320.7  |             3458.8          |             3457.0        
      f16 B=64, M=512, H=16, K=16    |              2316.2           |       1202.3     |    4487.1  |             1983.9          |             1868.4        
      f32 B=64, M=512, H=16, K=16    |              6686.0           |                  |   16991.5  |             6709.9          |             6713.2        
      f16 B=64, M=512, H=16, K=32    |              2701.7           |       1541.9     |    4975.9  |             2346.4          |             2184.3        
      f32 B=64, M=512, H=16, K=32    |              7460.5           |                  |   17859.9  |             7429.6          |             7438.0        
      f16 B=64, M=512, H=16, K=64    |              3461.2           |       2418.2     |    5886.1  |             3083.1          |             2942.6        
      f32 B=64, M=512, H=16, K=64    |              9553.4           |                  |   19768.6  |             9526.5          |             9516.6        
      f16 B=64, M=512, H=16, K=128   |              5875.5           |       5443.1     |    7711.0  |             6141.3          |             5513.6        
      f32 B=64, M=512, H=16, K=128   |             21317.0           |                  |   23651.2  |            20931.9          |            20932.3        
      f16 B=64, M=512, H=16, K=256   |             31238.2           |                  |   11490.6  |            28626.8          |            28954.0        
      f32 B=64, M=512, H=16, K=256   |             41124.8           |                  |   42468.0  |            39562.8          |            39579.5        
      f16 B=64, M=1024, H=16, K=16   |              9142.8           |       4707.2     |   16882.8  |             7887.0          |             7314.5        
      f32 B=64, M=1024, H=16, K=16   |             26512.3           |                  |   66311.7  |            26497.5          |            26459.8        
      f16 B=64, M=1024, H=16, K=32   |             10420.8           |       5698.3     |   17875.0  |             8851.4          |             7974.0        
      f32 B=64, M=1024, H=16, K=32   |             28300.0           |                  |   69088.7  |            28102.1          |            28098.2        
      f16 B=64, M=1024, H=16, K=64   |             12948.5           |       8119.0     |   19944.3  |            11064.2          |            10477.6        
      f32 B=64, M=1024, H=16, K=64   |             35600.6           |                  |   74762.8  |            35316.4          |            35361.2        
      f16 B=64, M=1024, H=16, K=128  |             20820.5           |      19184.2     |   23699.3  |            21954.8          |            19220.0        
      f32 B=64, M=1024, H=16, K=128  |             80800.3           |                  |   86003.8  |            78521.9          |            78393.9        
      f16 B=64, M=1024, H=16, K=256  |            114411.1           |                  |   32958.3  |           103287.2          |           104304.2        
      f32 B=64, M=1024, H=16, K=256  |            155731.5           |                  |  153011.0  |           148071.6          |           148165.6        

Times are in microseconds (us).
```
</details>
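
To read the A100 table in relative terms, here is a minimal sketch that turns a few rows into a percentage change against `56_base`. The three rows are copied from the f16 lines of the table above. The helper name `relative_change` is made up for illustration and is not part of the benchmark scripts.

```python
# (label, 48_accf32 time, 56_base time) in microseconds, copied from the A100 table.
rows = [
    ("f16 B=64, M=512, H=16, K=16", 2316.2, 1983.9),
    ("f16 B=64, M=1024, H=16, K=64", 12948.5, 11064.2),
    ("f16 B=64, M=1024, H=16, K=128", 20820.5, 21954.8),
]


def relative_change(new_us, base_us):
    """Fractional change in runtime; positive means slower than the baseline."""
    return new_us / base_us - 1.0


changes = {label: relative_change(acc, base) for label, acc, base in rows}
for label, change in changes.items():
    print(f"{label}: {change:+.1%} vs 56_base")
```

A positive value means the f32-accumulation kernel is slower than the baseline for that shape. In this small sample the K=16 and K=64 rows come out slower and the K=128 row comes out faster, so the f16 regression is not uniform across shapes.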


[ghstack-poisoned]
danthe3rd pushed a commit that referenced this pull request Nov 29, 2022
ghstack-source-id: 18c8ba5fd2ce14141a4a39a1e6900707d17c19f0
Pull Request resolved: #467
**PERFORMANCE**

This makes f16 performance worse :(
But I think we need it for numerical stability
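
As a rough illustration of the stability concern, here is a toy sketch of a long sum accumulated in f16 versus f32. This is not the kernel code: it is plain Python that emulates rounding with the standard `struct` module, and the helper names (`to_f16`, `to_f32`, `accumulate`) and the input values are assumptions chosen for the example.

```python
import struct


def to_f16(x):
    """Round a Python float to the nearest IEEE half-precision value."""
    return struct.unpack("e", struct.pack("e", x))[0]


def to_f32(x):
    """Round a Python float to the nearest IEEE single-precision value."""
    return struct.unpack("f", struct.pack("f", x))[0]


def accumulate(values, round_fn):
    """Sum values, rounding the accumulator to the given precision after each add."""
    acc = 0.0
    for v in values:
        acc = round_fn(acc + v)
    return acc


# Hypothetical inputs: 4096 identical small f16 values, like one long reduction.
values = [to_f16(0.01)] * 4096

exact = sum(values)                    # float64 reference
sum_f16 = accumulate(values, to_f16)   # accumulator kept in f16
sum_f32 = accumulate(values, to_f32)   # accumulator kept in f32

print(f"reference: {exact}")
print(f"f16 accumulator: {sum_f16} (error {abs(sum_f16 - exact)})")
print(f"f32 accumulator: {sum_f32} (error {abs(sum_f32 - exact)})")
```

Once the f16 accumulator is large enough that one addend is less than half the spacing between representable values, further adds round away and the sum stalls. The f32 accumulator stays close to the reference for the same f16 inputs.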

<details>
<summary>bw P100/V100 (f32/f16)</summary>

```
[---------------------------------------- attention backward (attn_bias=<class 'NoneType'>) ----------------------------------------]                                                    
                                                         |  48_accf32_6f8e2f15  |  56_base_02bf6b4e  |  vanilla   |  57_tmpT_b516aec4
1 threads: --------------------------------------------------------------------------------------------------------------------------
  (Quadro_GP100)          f16 B=384, M=197, H=1, K=88    |         6289.6       |        6936.0      |    2183.6  |        6940.0    
                          f32 B=384, M=197, H=1, K=88    |         8793.3       |        9446.6      |    2175.1  |        9429.2    
                          f16 B=384, M=197, H=1, K=80    |         5989.9       |        6596.2      |    2146.8  |        6608.6    
                          f32 B=384, M=197, H=1, K=80    |         8427.1       |        8993.1      |    2134.0  |        9030.5    
                          f16 B=384, M=197, H=1, K=64    |         3347.0       |        3527.4      |    1799.3  |        3538.3    
                          f32 B=384, M=197, H=1, K=64    |         5563.1       |        5984.2      |    1801.6  |        5980.9    
                          f16 B=1024, M=197, H=1, K=88   |        15680.5       |       17424.8      |    5671.9  |       17452.6    
                          f32 B=1024, M=197, H=1, K=88   |        23784.2       |       25542.8      |    5664.0  |       25578.6    
                          f16 B=1024, M=197, H=1, K=80   |        14935.5       |       16581.3      |    5559.5  |       16587.3    
                          f32 B=1024, M=197, H=1, K=80   |        22767.6       |       24354.1      |    5550.7  |       24362.0    
                          f16 B=1024, M=197, H=1, K=64   |         8331.1       |        8671.5      |    4644.9  |        8695.3    
                          f32 B=1024, M=197, H=1, K=64   |        15061.6       |       16148.8      |    4650.0  |       16201.4    
                          f16 B=512, M=197, H=1, K=80    |         7594.7       |        8306.6      |    2824.6  |        8336.5    
                          f32 B=512, M=197, H=1, K=80    |        11857.0       |       12660.2      |    2807.4  |       12709.2    
                          f16 B=32, M=197, H=16, K=80    |         7779.8       |        8342.0      |    2820.8  |        8393.5    
                          f32 B=32, M=197, H=16, K=80    |        11846.5       |       12487.8      |    2806.0  |       12549.1    
                          f16 B=32, M=197, H=16, K=64    |         4258.4       |        4445.6      |    2374.0  |        4461.2    
                          f32 B=32, M=197, H=16, K=64    |         7828.9       |        8434.9      |    2376.0  |        8472.2    
                          f16 B=32, M=197, H=16, K=128   |         9025.1       |        9671.9      |    3159.3  |        9707.2    
                          f32 B=32, M=197, H=16, K=128   |        14139.8       |       14920.7      |    3157.7  |       14928.9    
                          f16 B=256, M=197, H=1, K=88    |         4608.9       |        5118.8      |    1478.4  |        5119.2    
                          f32 B=256, M=197, H=1, K=88    |         6174.0       |        6644.3      |    1477.6  |        6642.7    
                          f16 B=16, M=197, H=16, K=88    |         4618.4       |        5073.4      |    1479.5  |        5062.9    
                          f32 B=16, M=197, H=16, K=88    |         6114.2       |        6471.7      |    1474.9  |        6478.7    
                          f16 B=16, M=197, H=16, K=64    |         2490.3       |        2557.5      |    1225.9  |        2550.0    
                          f32 B=16, M=197, H=16, K=64    |         3918.9       |        4208.2      |    1227.0  |        4195.7    
                          f16 B=16, M=197, H=16, K=128   |         5210.0       |        5649.3      |    1635.2  |        5648.1    
                          f32 B=16, M=197, H=16, K=128   |         7103.1       |        7451.6      |    1645.5  |        7445.1    
                          f16 B=1, M=4096, H=160, K=128  |      1014229.1       |     1106182.6      |            |     1108040.6    
                          f32 B=1, M=4096, H=160, K=128  |      1258173.2       |     1243183.8      |            |     1241548.7    
                          f16 B=2, M=4096, H=160, K=128  |      1642279.2       |     1753736.9      |            |     1771655.3    
                          f32 B=2, M=4096, H=160, K=128  |      2505435.4       |     2477353.1      |            |     2473773.4    
                          f16 B=1, M=8192, H=160, K=128  |      4050128.4       |     4415962.1      |            |     4428194.9    
                          f32 B=1, M=8192, H=160, K=128  |      5042352.6       |     4970582.0      |            |     4965069.9    
                          f16 B=2, M=8192, H=160, K=128  |      6600732.3       |     7026378.5      |            |     7068051.8    
                          f16 B=1024, M=82, H=8, K=64    |        21572.8       |       22531.6      |    9059.1  |       22418.4    
                          f32 B=1024, M=82, H=8, K=64    |        38178.4       |       45927.2      |    9070.3  |       45708.0    
                          f16 B=150, M=256, H=16, K=64   |        21436.5       |       21927.4      |   12938.5  |       22001.0    
                          f32 B=150, M=256, H=16, K=64   |        33024.2       |       33196.3      |   13249.2  |       33199.6    
                          f16 B=64, M=256, H=12, K=64    |         6869.8       |        7048.6      |    4200.6  |        7073.6    
                          f32 B=64, M=256, H=12, K=64    |        10719.9       |       10832.1      |    4271.3  |       10843.6    
                          f16 B=1, M=4096, H=16, K=40    |       134722.6       |      145429.8      |   20587.0  |      143743.8    
                          f32 B=1, M=4096, H=16, K=40    |       143015.1       |      147850.6      |   20625.8  |      148272.8    
                          f16 B=1, M=16384, H=16, K=40   |      2149850.4       |     2323732.4      |            |     2301489.4    
                          f32 B=1, M=16384, H=16, K=40   |      2286478.1       |     2369812.4      |            |     2375911.8    
                          f16 B=16, M=128, H=16, K=16    |          497.2       |         502.9      |     623.8  |         503.6    
                          f32 B=16, M=128, H=16, K=16    |          573.5       |         609.8      |     617.5  |         611.6    
                          f16 B=16, M=128, H=16, K=32    |          563.9       |         573.1      |     624.2  |         575.2    
                          f32 B=16, M=128, H=16, K=32    |          661.5       |         702.2      |     620.4  |         703.2    
                          f16 B=16, M=128, H=16, K=64    |          708.9       |         722.6      |     619.9  |         724.3    
                          f32 B=16, M=128, H=16, K=64    |          916.2       |         953.5      |     618.6  |         953.6    
                          f16 B=16, M=128, H=16, K=128   |         1465.8       |        1542.8      |     624.2  |        1545.7    
                          f32 B=16, M=128, H=16, K=128   |         1829.1       |        1872.3      |     616.2  |        1873.9    
                          f16 B=16, M=128, H=16, K=256   |         3796.3       |        4002.5      |    1010.7  |        4008.4    
                          f32 B=16, M=128, H=16, K=256   |         3838.8       |        3957.1      |    1203.6  |        3951.8    
                          f16 B=16, M=512, H=16, K=16    |         7680.4       |        7775.1      |    4848.7  |        7752.8    
                          f32 B=16, M=512, H=16, K=16    |         9402.0       |        9926.1      |    4926.7  |        9914.5    
                          f16 B=16, M=512, H=16, K=32    |         8897.5       |        8999.0      |    5055.9  |        9003.9    
                          f32 B=16, M=512, H=16, K=32    |        10762.1       |       11320.9      |    5065.4  |       11341.8    
                          f16 B=16, M=512, H=16, K=64    |        10936.9       |       11214.8      |    5484.9  |       11254.4    
                          f32 B=16, M=512, H=16, K=64    |        15091.9       |       15210.3      |    5552.0  |       15196.3    
                          f16 B=16, M=512, H=16, K=128   |        23491.0       |       25234.8      |    7317.1  |       25300.3    
                          f32 B=16, M=512, H=16, K=128   |        30524.6       |       30320.0      |    7487.0  |       30200.6    
                          f16 B=16, M=512, H=16, K=256   |        50389.7       |       54525.2      |   14015.3  |       54931.1    
                          f32 B=16, M=512, H=16, K=256   |        62155.6       |       61046.4      |   14285.5  |       61081.1    
                          f16 B=16, M=1024, H=16, K=16   |        31289.9       |       31951.9      |   18778.4  |       32044.9    
                          f32 B=16, M=1024, H=16, K=16   |        37744.4       |       39586.6      |   18929.9  |       39739.9    
                          f16 B=16, M=1024, H=16, K=32   |        35770.8       |       36651.1      |   19620.2  |       36909.0    
                          f32 B=16, M=1024, H=16, K=32   |        43211.5       |       45544.5      |   19518.4  |       45664.8    
                          f16 B=16, M=1024, H=16, K=64   |        43865.1       |       44864.8      |   21286.0  |       45345.7    
                          f32 B=16, M=1024, H=16, K=64   |        60710.5       |       60966.9      |   21634.1  |       60921.9    
                          f16 B=16, M=1024, H=16, K=128  |        94633.7       |      101796.0      |   28502.1  |      102871.6    
                          f32 B=16, M=1024, H=16, K=128  |       124093.6       |      122196.9      |   28520.3  |      122043.0    
                          f16 B=16, M=1024, H=16, K=256  |       194780.7       |      212303.0      |   55419.1  |      214643.5    
                          f32 B=16, M=1024, H=16, K=256  |       250799.1       |      245196.2      |   55634.0  |      245534.2    
                          f16 B=64, M=128, H=16, K=16    |         1658.0       |        1661.4      |    1331.0  |        1662.7    
                          f32 B=64, M=128, H=16, K=16    |         2129.2       |        2266.2      |    1371.8  |        2269.4    
                          f16 B=64, M=128, H=16, K=32    |         1904.3       |        1903.5      |    1384.2  |        1906.4    
                          f32 B=64, M=128, H=16, K=32    |         2496.5       |        2643.4      |    1445.1  |        2639.8    
                          f16 B=64, M=128, H=16, K=64    |         2393.1       |        2432.3      |    1505.2  |        2437.8    
                          f32 B=64, M=128, H=16, K=64    |         3471.1       |        3561.2      |    1590.0  |        3565.7    
                          f16 B=64, M=128, H=16, K=128   |         4969.1       |        5266.3      |    1988.4  |        5272.4    
                          f32 B=64, M=128, H=16, K=128   |         6880.3       |        7040.6      |    2121.0  |        7024.3    
                          f16 B=64, M=128, H=16, K=256   |        12635.3       |       13334.7      |    3859.5  |       13350.7    
                          f32 B=64, M=128, H=16, K=256   |        14185.7       |       14514.4      |    4553.8  |       14503.1    
                          f16 B=64, M=512, H=16, K=16    |        26278.4       |       26189.9      |   18916.0  |       26110.0    
                          f32 B=64, M=512, H=16, K=16    |        35128.2       |       37075.9      |   19191.4  |       37174.5    
                          f16 B=64, M=512, H=16, K=32    |        30414.2       |       30938.2      |   19734.7  |       31071.7    
                          f32 B=64, M=512, H=16, K=32    |        40483.2       |       42922.5      |   19843.9  |       42868.5    
                          f16 B=64, M=512, H=16, K=64    |        37248.2       |       38179.7      |   21640.1  |       38327.0    
                          f32 B=64, M=512, H=16, K=64    |        57666.8       |       57675.3      |   21909.1  |       57697.2    
                          f16 B=64, M=512, H=16, K=128   |        80113.6       |       86165.7      |   28765.5  |       86368.3    
                          f32 B=64, M=512, H=16, K=128   |       115672.5       |      115161.0      |   28910.3  |      115320.0    
                          f16 B=64, M=512, H=16, K=256   |       169250.0       |      183791.8      |   56315.7  |      183735.0    
                          f32 B=64, M=512, H=16, K=256   |       236594.6       |      233093.6      |   56853.4  |      233170.2    
                          f16 B=64, M=1024, H=16, K=16   |       106022.3       |      109588.2      |   74303.6  |      109410.4    
                          f32 B=64, M=1024, H=16, K=16   |       141241.8       |      148854.1      |            |      149651.5    
                          f16 B=64, M=1024, H=16, K=32   |       120899.1       |      125044.6      |   77828.6  |      125716.9    
                          f32 B=64, M=1024, H=16, K=32   |       162478.5       |      173906.1      |            |      173216.4    
                          f16 B=64, M=1024, H=16, K=64   |       149044.1       |      152290.6      |   85821.6  |      152748.5    
                          f32 B=64, M=1024, H=16, K=64   |       233195.9       |      231479.6      |            |      231533.9    
                          f16 B=64, M=1024, H=16, K=128  |       319761.5       |      344076.5      |  113579.4  |      345058.5    
                          f32 B=64, M=1024, H=16, K=128  |       470172.3       |      466330.7      |            |      463507.4    
                          f16 B=64, M=1024, H=16, K=256  |       658362.1       |      717070.2      |            |      723057.4    
                          f32 B=64, M=1024, H=16, K=256  |       955624.0       |      935114.3      |            |      935945.4    
  (Tesla_V100_SXM2_16GB)  f16 B=384, M=197, H=1, K=88    |         1811.3       |        1686.3      |    1375.2  |        1699.6    
                          f32 B=384, M=197, H=1, K=88    |         4315.7       |        4665.1      |    2257.1  |        4663.1    
                          f16 B=384, M=197, H=1, K=80    |         1733.0       |        1616.3      |    1281.7  |        1619.8    
                          f32 B=384, M=197, H=1, K=80    |         3965.0       |        4226.4      |    2171.7  |        4228.3    
                          f16 B=384, M=197, H=1, K=64    |         1135.2       |        1084.4      |    1043.8  |        1083.0    
                          f32 B=384, M=197, H=1, K=64    |         2673.2       |        2883.0      |    1744.5  |        2878.2    
                          f16 B=1024, M=197, H=1, K=88   |         4721.0       |        4396.6      |    3725.3  |        4404.1    
                          f32 B=1024, M=197, H=1, K=88   |        10531.3       |       11443.8      |    6106.5  |       11464.2    
                          f16 B=1024, M=197, H=1, K=80   |         4520.2       |        4216.2      |    3329.2  |        4223.7    
                          f32 B=1024, M=197, H=1, K=80   |         9573.5       |       10301.4      |    5757.6  |       10305.3    
                          f16 B=1024, M=197, H=1, K=64   |         2788.1       |        2660.1      |    2674.6  |        2663.5    
                          f32 B=1024, M=197, H=1, K=64   |         6556.7       |        7102.4      |    4516.6  |        7096.8    
                          f16 B=512, M=197, H=1, K=80    |         2377.2       |        2228.1      |    1685.0  |        2231.4    
                          f32 B=512, M=197, H=1, K=80    |         5259.3       |        5639.7      |    2887.8  |        5636.6    
                          f16 B=32, M=197, H=16, K=80    |         2403.0       |        2201.1      |    1798.2  |        2204.9    
                          f32 B=32, M=197, H=16, K=80    |         5402.3       |        5667.7      |    3046.3  |        5662.4    
                          f16 B=32, M=197, H=16, K=64    |         1552.7       |        1486.3      |    1451.6  |        1485.2    
                          f32 B=32, M=197, H=16, K=64    |         3622.5       |        3911.0      |    2427.0  |        3912.9    
                          f16 B=32, M=197, H=16, K=128   |         2776.6       |        2611.9      |    2211.3  |        2613.3    
                          f32 B=32, M=197, H=16, K=128   |         6647.9       |        7082.3      |    4088.5  |        7104.8    
                          f16 B=256, M=197, H=1, K=88    |         1357.5       |        1285.5      |     941.3  |        1287.6    
                          f32 B=256, M=197, H=1, K=88    |         2874.1       |        3085.1      |    1543.2  |        3090.1    
                          f16 B=16, M=197, H=16, K=88    |         1349.1       |        1264.5      |     964.7  |        1263.1    
                          f32 B=16, M=197, H=16, K=88    |         2803.8       |        2967.0      |    1647.4  |        2972.0    
                          f16 B=16, M=197, H=16, K=64    |          765.5       |         728.4      |     765.3  |         731.0    
                          f32 B=16, M=197, H=16, K=64    |         1834.1       |        1969.7      |    1282.4  |        1974.4    
                          f16 B=16, M=197, H=16, K=128   |         1509.1       |        1432.6      |    1139.1  |        1433.9    
                          f32 B=16, M=197, H=16, K=128   |         3406.5       |        3606.2      |    2048.6  |        3613.3    
                          f16 B=1, M=4096, H=160, K=128  |       168807.1       |      148652.9      |            |      149343.8    
                          f32 B=1, M=4096, H=160, K=128  |       549864.6       |      586699.9      |            |      585699.9    
                          f16 B=2, M=4096, H=160, K=128  |       339010.5       |      298827.7      |            |      298808.4    
                          f32 B=2, M=4096, H=160, K=128  |      1106963.8       |     1176218.3      |            |     1179173.2    
                          f16 B=1, M=8192, H=160, K=128  |       679742.4       |      594323.2      |            |      595580.1    
                          f32 B=1, M=8192, H=160, K=128  |      2195491.3       |     2340248.4      |            |     2343505.7    
                          f16 B=2, M=8192, H=160, K=128  |      1364983.7       |     1193787.8      |            |     1192596.1    
                          f16 B=1024, M=82, H=8, K=64    |         9052.8       |        8762.6      |    5804.2  |        8757.8    
                          f32 B=1024, M=82, H=8, K=64    |        14726.3       |       16270.6      |   11059.7  |       16215.9    
                          f16 B=150, M=256, H=16, K=64   |         5662.9       |        5519.4      |    7557.8  |        5524.5    
                          f32 B=150, M=256, H=16, K=64   |        16700.3       |       17612.7      |   16426.0  |       17640.4    
                          f16 B=64, M=256, H=12, K=64    |         1849.1       |        1793.1      |    2383.5  |        1798.7    
                          f32 B=64, M=256, H=12, K=64    |         5451.5       |        5766.4      |    4975.3  |        5775.8    
                          f16 B=1, M=4096, H=16, K=40    |        47263.3       |       47850.4      |    8315.6  |       47777.2    
                          f32 B=1, M=4096, H=16, K=40    |       113099.4       |      113164.5      |   19536.0  |      113930.0    
                          f16 B=1, M=16384, H=16, K=40   |       757091.8       |      770365.7      |            |      765401.3    
                          f32 B=1, M=16384, H=16, K=40   |      1806827.7       |     1816302.6      |            |     1819162.1    
                          f16 B=16, M=128, H=16, K=16    |          219.2       |         218.5      |     480.6  |         231.8    
                          f32 B=16, M=128, H=16, K=16    |          301.2       |         308.2      |     498.9  |         307.9    
                          f16 B=16, M=128, H=16, K=32    |          227.6       |         215.4      |     473.8  |         220.1    
                          f32 B=16, M=128, H=16, K=32    |          395.8       |         401.0      |     455.1  |         400.8    
                          f16 B=16, M=128, H=16, K=64    |          225.6       |         214.8      |     510.1  |         229.9    
                          f32 B=16, M=128, H=16, K=64    |          561.7       |         583.1      |     598.9  |         581.2    
                          f16 B=16, M=128, H=16, K=128   |          404.6       |         392.0      |     524.0  |         394.8    
                          f32 B=16, M=128, H=16, K=128   |         1103.0       |        1140.8      |    1015.4  |        1142.0    
                          f16 B=16, M=128, H=16, K=256   |         1045.3       |        1049.1      |     889.6  |        1047.4    
                          f32 B=16, M=128, H=16, K=256   |         2181.3       |        2270.9      |    1869.2  |        2265.7    
                          f16 B=16, M=512, H=16, K=16    |         1731.5       |        1585.7      |    1908.5  |        1586.2    
                          f32 B=16, M=512, H=16, K=16    |         4513.8       |        4696.7      |    4222.1  |        4695.6    
                          f16 B=16, M=512, H=16, K=32    |         1942.2       |        1823.4      |    2086.3  |        1809.7    
                          f32 B=16, M=512, H=16, K=32    |         5596.1       |        5819.4      |    4588.1  |        5833.6    
                          f16 B=16, M=512, H=16, K=64    |         2450.2       |        2340.8      |    2580.6  |        2353.2    
                          f32 B=16, M=512, H=16, K=64    |         7619.5       |        7853.1      |    5536.8  |        7875.1    
                          f16 B=16, M=512, H=16, K=128   |         4884.8       |        4473.5      |    3388.7  |        4487.9    
                          f32 B=16, M=512, H=16, K=128   |        15010.0       |       15557.1      |    8979.4  |       15513.1    
                          f16 B=16, M=512, H=16, K=256   |        12973.9       |       11134.7      |    5418.2  |       11106.3    
                          f32 B=16, M=512, H=16, K=256   |        29751.1       |       30979.1      |   16856.3  |       31012.8    
                          f16 B=16, M=1024, H=16, K=16   |         6802.6       |        6185.1      |    6996.0  |        6191.0    
                          f32 B=16, M=1024, H=16, K=16   |        18098.2       |       18954.5      |   16129.1  |       19166.8    
                          f16 B=16, M=1024, H=16, K=32   |         7531.9       |        7065.0      |    7436.0  |        7067.5    
                          f32 B=16, M=1024, H=16, K=32   |        21999.9       |       22901.5      |   17040.0  |       22899.4    
                          f16 B=16, M=1024, H=16, K=64   |         9312.0       |        8854.3      |    8605.6  |        8863.9    
                          f32 B=16, M=1024, H=16, K=64   |        29837.7       |       30895.1      |   20355.3  |       30878.5    
                          f16 B=16, M=1024, H=16, K=128  |        18979.0       |       16951.0      |   10561.3  |       16995.2    
                          f32 B=16, M=1024, H=16, K=128  |        58738.3       |       60861.2      |   33427.0  |       60599.8    
                          f16 B=16, M=1024, H=16, K=256  |        49681.9       |       41833.3      |   17329.5  |       41921.6    
                          f32 B=16, M=1024, H=16, K=256  |       117362.4       |      121004.8      |   60515.8  |      122046.9    
                          f16 B=64, M=128, H=16, K=16    |          432.2       |         411.1      |     642.1  |         411.3    
                          f32 B=64, M=128, H=16, K=16    |         1028.9       |        1057.9      |    1233.4  |        1056.6    
                          f16 B=64, M=128, H=16, K=32    |          522.5       |         500.7      |     813.6  |         499.7    
                          f32 B=64, M=128, H=16, K=32    |         1403.5       |        1443.4      |    1535.6  |        1443.6    
                          f16 B=64, M=128, H=16, K=64    |          750.0       |         739.8      |    1185.6  |         741.0    
                          f32 B=64, M=128, H=16, K=64    |         2013.9       |        2110.1      |    2156.8  |        2105.4    
                          f16 B=64, M=128, H=16, K=128   |         1421.2       |        1387.9      |    1915.5  |        1388.3    
                          f32 B=64, M=128, H=16, K=128   |         3946.5       |        4156.7      |    3780.1  |        4156.3    
                          f16 B=64, M=128, H=16, K=256   |         3811.5       |        3810.6      |    3448.7  |        3811.0    
                          f32 B=64, M=128, H=16, K=256   |         7983.9       |        8432.7      |    7304.0  |        8415.5    
                          f16 B=64, M=512, H=16, K=16    |         6157.4       |        5523.9      |    7461.8  |        5528.1    
                          f32 B=64, M=512, H=16, K=16    |        16118.2       |       17004.7      |   16651.7  |       16984.5    
                          f16 B=64, M=512, H=16, K=32    |         6985.1       |        6483.1      |    8278.5  |        6495.1    
                          f32 B=64, M=512, H=16, K=32    |        20471.9       |       21230.7      |   18420.5  |       21269.0    
                          f16 B=64, M=512, H=16, K=64    |         9045.8       |        8633.5      |   10337.1  |        8674.1    
                          f32 B=64, M=512, H=16, K=64    |        27675.0       |       29282.8      |   22883.7  |       29136.9    
                          f16 B=64, M=512, H=16, K=128   |        17594.2       |       15805.9      |   14700.6  |       15788.2    
                          f32 B=64, M=512, H=16, K=128   |        54612.4       |       57815.7      |   39974.3  |       57951.8    
                          f16 B=64, M=512, H=16, K=256   |        47452.6       |       40093.5      |   27087.5  |       40175.6    
                          f32 B=64, M=512, H=16, K=256   |       108880.3       |      115953.1      |   77794.2  |      115951.0    
                          f16 B=64, M=1024, H=16, K=16   |        24369.0       |       21533.0      |   28448.6  |       21556.6    
                          f32 B=64, M=1024, H=16, K=16   |        64649.7       |       68791.8      |            |       68407.4    
                          f16 B=64, M=1024, H=16, K=32   |        27143.1       |       25683.7      |   30252.7  |       25727.9    
                          f32 B=64, M=1024, H=16, K=32   |        79967.5       |       83351.5      |            |       83084.7    
                          f16 B=64, M=1024, H=16, K=64   |        34667.0       |       32592.7      |   36991.2  |       32659.8    
                          f32 B=64, M=1024, H=16, K=64   |       108282.2       |      113858.1      |            |      114286.0    
                          f16 B=64, M=1024, H=16, K=128  |        68519.5       |       59757.3      |   48834.4  |       59817.7    
                          f32 B=64, M=1024, H=16, K=128  |       215465.9       |      227335.8      |            |      227204.0    
                          f16 B=64, M=1024, H=16, K=256  |       183070.4       |      150960.1      |            |      150947.1    
                          f32 B=64, M=1024, H=16, K=256  |       425832.9       |      453717.2      |            |      453349.2    

Times are in microseconds (us).
```
</details>
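
To read the table above as relative changes rather than raw timings, the rows can be parsed with a few lines of Python. This is only a sketch and not part of this PR: `parse_rows` and `relative_change` are hypothetical helper names, and the two sample rows are copied from the V100 section above (columns: `48_accf32`, `56_base`, `vanilla`, `57_tmpT`).

```python
import re

# Two rows copied verbatim from the V100 table; times are in microseconds.
# Column order: 48_accf32 | 56_base | vanilla | 57_tmpT
SAMPLE = """\
f16 B=64, M=1024, H=16, K=256  |       183070.4       |      150960.1      |            |      150947.1
f32 B=64, M=1024, H=16, K=256  |       425832.9       |      453717.2      |            |      453349.2
"""


def parse_rows(text):
    """Yield (label, [time or None, ...]) for each data row of the table."""
    for line in text.splitlines():
        if "|" not in line:
            continue  # separator lines such as "1 threads: ----"
        label, *cells = [c.strip() for c in line.split("|")]
        # Drop a leading "(GPU_name)" prefix when the row starts a new device.
        label = re.sub(r"^\(.*?\)\s*", "", label)
        try:
            # Empty cells (benchmark not run for that shape) become None.
            times = [float(c) if c else None for c in cells]
        except ValueError:
            continue  # header row with column names
        yield label, times


def relative_change(text, new_col=0, base_col=1):
    """Return {label: new/base - 1}; positive values mean `new` is slower."""
    out = {}
    for label, times in parse_rows(text):
        new, base = times[new_col], times[base_col]
        if new is not None and base is not None:
            out[label] = new / base - 1.0
    return out


if __name__ == "__main__":
    for label, change in relative_change(SAMPLE).items():
        print(f"{label}: {change:+.1%}")
```

On these two rows, the f32-accumulation build comes out roughly 21% slower than the base in f16 and roughly 6% faster in f32.
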

<details>
<summary>bw A100 (f32/f16)</summary>

```
[----------------------------------------------------- attention backward (attn_bias=<class 'NoneType'>) -----------------------------------------------------]                                                       
                                     |  48_accf32_69654fdb[cutlass]  |  flash[flshatt]  |  vanilla   |  56_base_02bf6b4e[cutlass]  |  57_tmpT_b516aec4[cutlass]
1 threads: ----------------------------------------------------------------------------------------------------------------------------------------------------
      f16 B=384, M=197, H=1, K=88    |               613.4           |                  |    2264.9  |              612.8          |              578.3        
      f32 B=384, M=197, H=1, K=88    |              2335.2           |                  |    1843.0  |             2438.4          |             2425.1        
      f16 B=384, M=197, H=1, K=80    |               583.6           |                  |    1922.9  |              577.6          |              548.4        
      f32 B=384, M=197, H=1, K=80    |              2241.3           |                  |    1787.8  |             2333.2          |             2333.1        
      f16 B=384, M=197, H=1, K=64    |               405.6           |        232.5     |    1809.8  |              386.4          |              366.1        
      f32 B=384, M=197, H=1, K=64    |              1259.7           |                  |    1675.6  |             1309.6          |             1316.8        
      f16 B=1024, M=197, H=1, K=88   |              1538.2           |                  |    5964.7  |             1550.8          |             1454.4        
      f32 B=1024, M=197, H=1, K=88   |              6031.4           |                  |    4559.5  |             6325.4          |             6332.7        
      f16 B=1024, M=197, H=1, K=80   |              1458.8           |                  |    5038.2  |             1463.8          |             1379.0        
      f32 B=1024, M=197, H=1, K=80   |              5786.2           |                  |    4412.6  |             6059.0          |             6079.9        
      f16 B=1024, M=197, H=1, K=64   |               929.5           |        575.9     |    4735.6  |              862.1          |              821.6        
      f32 B=1024, M=197, H=1, K=64   |              3289.9           |                  |    4119.9  |             3434.4          |             3441.4        
      f16 B=512, M=197, H=1, K=80    |               744.5           |                  |    2544.7  |              735.3          |              695.8        
      f32 B=512, M=197, H=1, K=80    |              2889.4           |                  |    2286.1  |             3008.8          |             3029.0        
      f16 B=32, M=197, H=16, K=80    |               741.6           |                  |    2569.0  |              723.3          |              693.4        
      f32 B=32, M=197, H=16, K=80    |              2878.6           |                  |    2355.3  |             3003.0          |             3024.1        
      f16 B=32, M=197, H=16, K=64    |               478.5           |        295.7     |    2429.2  |              456.1          |              426.7        
      f32 B=32, M=197, H=16, K=64    |              1784.4           |                  |    2196.9  |             1863.7          |             1860.2        
      f16 B=32, M=197, H=16, K=128   |               887.0           |        682.3     |    4492.9  |              857.5          |              853.8        
      f32 B=32, M=197, H=16, K=128   |              3546.6           |                  |    2807.3  |             3734.9          |             3737.1        
      f16 B=256, M=197, H=1, K=88    |               445.1           |                  |    1528.9  |              445.0          |              422.2        
      f32 B=256, M=197, H=1, K=88    |              1678.8           |                  |    1207.6  |             1746.6          |             1752.8        
      f16 B=16, M=197, H=16, K=88    |               441.5           |                  |    1544.2  |              437.3          |              419.5        
      f32 B=16, M=197, H=16, K=88    |              1668.3           |                  |    1250.4  |             1742.7          |             1746.1        
      f16 B=16, M=197, H=16, K=64    |               247.4           |        165.6     |    1242.5  |              233.2          |              217.5        
      f32 B=16, M=197, H=16, K=64    |              1051.1           |                  |    1125.3  |             1099.2          |             1096.6        
      f16 B=16, M=197, H=16, K=128   |               498.4           |        386.2     |    2264.5  |              488.0          |              480.5        
      f32 B=16, M=197, H=16, K=128   |              1950.2           |                  |    1446.7  |             2039.0          |             2028.6        
      f16 B=1, M=4096, H=160, K=128  |             55915.0           |      54620.4     |   45909.5  |            63407.6          |            51298.8        
      f32 B=1, M=4096, H=160, K=128  |            238514.6           |                  |            |           232677.5          |           232672.5        
      f16 B=2, M=4096, H=160, K=128  |             93612.0           |      84238.0     |            |           100433.1          |            84858.1        
      f32 B=2, M=4096, H=160, K=128  |            375037.4           |                  |            |           364234.1          |           364407.2        
      f16 B=1, M=8192, H=160, K=128  |            223261.8           |     215499.8     |            |           251806.6          |           202133.7        
      f32 B=1, M=8192, H=160, K=128  |            946708.9           |                  |            |           924986.2          |           924988.9        
      f16 B=2, M=8192, H=160, K=128  |            367969.2           |     330092.8     |            |           395881.3          |           332000.6        
      f32 B=2, M=8192, H=160, K=128  |           1492031.4           |                  |            |          1448691.1          |          1449146.1        
      f16 B=1024, M=82, H=8, K=64    |              1890.2           |       1620.3     |    3819.7  |             1861.3          |             1764.6        
      f32 B=1024, M=82, H=8, K=64    |              8428.2           |                  |    8735.1  |             8831.3          |             8867.4        
      f16 B=150, M=256, H=16, K=64   |              2292.3           |       1625.3     |    4555.9  |             2109.8          |             2019.4        
      f32 B=150, M=256, H=16, K=64   |              6252.4           |                  |   12948.1  |             6288.9          |             6281.3        
      f16 B=64, M=256, H=12, K=64    |               782.2           |        567.4     |    1498.0  |              731.2          |              699.7        
      f32 B=64, M=256, H=12, K=64    |              2141.2           |                  |    4266.6  |             2160.6          |             2160.2        
      f16 B=1, M=4096, H=16, K=40    |             23504.2           |                  |    4196.1  |            23699.0          |            23008.9        
      f32 B=1, M=4096, H=16, K=40    |             73699.5           |                  |   17755.3  |            73261.8          |            73078.5        
      f16 B=1, M=16384, H=16, K=40   |            391408.9           |                  |            |           439777.3          |           407653.7        
      f32 B=1, M=16384, H=16, K=40   |           1196173.6           |                  |            |          1181547.9          |          1181625.1        
      f16 B=256, M=4096, H=16, K=64  |            733221.8           |     439627.5     |            |           603237.1          |           565905.8        
      f16 B=16, M=128, H=16, K=16    |               130.0           |        113.1     |     265.2  |              125.5          |              124.4        
      f32 B=16, M=128, H=16, K=16    |               161.5           |                  |     373.1  |              160.6          |              162.8        
      f16 B=16, M=128, H=16, K=32    |               125.8           |        111.5     |     263.8  |              122.1          |              125.8        
      f32 B=16, M=128, H=16, K=32    |               189.8           |                  |     412.6  |              196.3          |              196.2        
      f16 B=16, M=128, H=16, K=64    |               126.0           |        112.4     |     265.8  |              120.8          |              125.7        
      f32 B=16, M=128, H=16, K=64    |               272.1           |                  |     498.7  |              285.9          |              283.3        
      f16 B=16, M=128, H=16, K=128   |               181.3           |        158.4     |     298.5  |              178.0          |              186.5        
      f32 B=16, M=128, H=16, K=128   |               509.6           |                  |     673.9  |              521.1          |              521.5        
      f16 B=16, M=128, H=16, K=256   |               774.0           |                  |     541.4  |              775.5          |              757.0        
      f32 B=16, M=128, H=16, K=256   |               975.2           |                  |    1162.6  |              994.5          |              994.5        
      f16 B=16, M=512, H=16, K=16    |               621.0           |        322.6     |    1204.9  |              555.0          |              519.9        
      f32 B=16, M=512, H=16, K=16    |              2148.0           |                  |    4414.4  |             2178.9          |             2180.4        
      f16 B=16, M=512, H=16, K=32    |               709.9           |        435.6     |    1306.2  |              653.1          |              602.8        
      f32 B=16, M=512, H=16, K=32    |              2335.6           |                  |    4640.7  |             2336.0          |             2336.3        
      f16 B=16, M=512, H=16, K=64    |               917.8           |        702.7     |    1545.8  |              849.9          |              797.4        
      f32 B=16, M=512, H=16, K=64    |              2965.5           |                  |    5125.0  |             2986.9          |             2988.3        
      f16 B=16, M=512, H=16, K=128   |              1644.0           |       1584.1     |    1983.4  |             1757.0          |             1548.0        
      f32 B=16, M=512, H=16, K=128   |              6152.7           |                  |    6099.1  |             6067.7          |             6068.6        
      f16 B=16, M=512, H=16, K=256   |              8178.3           |                  |    2899.1  |             7895.5          |             7977.4        
      f32 B=16, M=512, H=16, K=256   |             11894.0           |                  |   10639.5  |            11635.0          |            11624.7        
      f16 B=16, M=1024, H=16, K=16   |              2420.5           |       1240.8     |    4259.2  |             2234.6          |             2048.9        
      f32 B=16, M=1024, H=16, K=16   |              8467.2           |                  |   16650.4  |             8512.2          |             8510.2        
      f16 B=16, M=1024, H=16, K=32   |              2675.7           |       1618.9     |    4491.2  |             2441.3          |             2230.0        
      f32 B=16, M=1024, H=16, K=32   |              9010.0           |                  |   17301.0  |             9012.1          |             9015.7        
      f16 B=16, M=1024, H=16, K=64   |              3328.4           |       2370.3     |    4994.5  |             3032.2          |             2820.0        
      f32 B=16, M=1024, H=16, K=64   |             11566.8           |                  |   18714.2  |            11494.1          |            11492.8        
      f16 B=16, M=1024, H=16, K=128  |              5867.9           |       5632.5     |    5952.8  |             6401.1          |             5440.8        
      f32 B=16, M=1024, H=16, K=128  |             23345.4           |                  |   21523.7  |            22859.1          |            22870.3        
      f16 B=16, M=1024, H=16, K=256  |             30619.2           |                  |    7893.1  |            29884.9          |            29060.4        
      f32 B=16, M=1024, H=16, K=256  |             45211.4           |                  |   38093.0  |            43435.3          |            43423.8        
      f16 B=64, M=128, H=16, K=16    |               159.6           |        145.2     |     439.9  |              161.2          |              167.0        
      f32 B=64, M=128, H=16, K=16    |               493.4           |                  |    1270.0  |              502.7          |              503.2        
      f16 B=64, M=128, H=16, K=32    |               208.5           |        212.1     |     545.3  |              206.2          |              204.9        
      f32 B=64, M=128, H=16, K=32    |               601.2           |                  |    1427.0  |              610.5          |              610.6        
      f16 B=64, M=128, H=16, K=64    |               329.0           |        310.8     |     766.0  |              327.2          |              314.7        
      f32 B=64, M=128, H=16, K=64    |               867.7           |                  |    1743.5  |              889.2          |              889.0        
      f16 B=64, M=128, H=16, K=128   |               635.5           |        562.1     |    1226.7  |              613.3          |              650.7        
      f32 B=64, M=128, H=16, K=128   |              1774.4           |                  |    2386.5  |             1800.0          |             1799.6        
      f16 B=64, M=128, H=16, K=256   |              2839.8           |                  |    2122.8  |             2765.2          |             2821.9        
      f32 B=64, M=128, H=16, K=256   |              3419.2           |                  |    4320.7  |             3458.8          |             3457.0        
      f16 B=64, M=512, H=16, K=16    |              2316.2           |       1202.3     |    4487.1  |             1983.9          |             1868.4        
      f32 B=64, M=512, H=16, K=16    |              6686.0           |                  |   16991.5  |             6709.9          |             6713.2        
      f16 B=64, M=512, H=16, K=32    |              2701.7           |       1541.9     |    4975.9  |             2346.4          |             2184.3        
      f32 B=64, M=512, H=16, K=32    |              7460.5           |                  |   17859.9  |             7429.6          |             7438.0        
      f16 B=64, M=512, H=16, K=64    |              3461.2           |       2418.2     |    5886.1  |             3083.1          |             2942.6        
      f32 B=64, M=512, H=16, K=64    |              9553.4           |                  |   19768.6  |             9526.5          |             9516.6        
      f16 B=64, M=512, H=16, K=128   |              5875.5           |       5443.1     |    7711.0  |             6141.3          |             5513.6        
      f32 B=64, M=512, H=16, K=128   |             21317.0           |                  |   23651.2  |            20931.9          |            20932.3        
      f16 B=64, M=512, H=16, K=256   |             31238.2           |                  |   11490.6  |            28626.8          |            28954.0        
      f32 B=64, M=512, H=16, K=256   |             41124.8           |                  |   42468.0  |            39562.8          |            39579.5        
      f16 B=64, M=1024, H=16, K=16   |              9142.8           |       4707.2     |   16882.8  |             7887.0          |             7314.5        
      f32 B=64, M=1024, H=16, K=16   |             26512.3           |                  |   66311.7  |            26497.5          |            26459.8        
      f16 B=64, M=1024, H=16, K=32   |             10420.8           |       5698.3     |   17875.0  |             8851.4          |             7974.0        
      f32 B=64, M=1024, H=16, K=32   |             28300.0           |                  |   69088.7  |            28102.1          |            28098.2        
      f16 B=64, M=1024, H=16, K=64   |             12948.5           |       8119.0     |   19944.3  |            11064.2          |            10477.6        
      f32 B=64, M=1024, H=16, K=64   |             35600.6           |                  |   74762.8  |            35316.4          |            35361.2        
      f16 B=64, M=1024, H=16, K=128  |             20820.5           |      19184.2     |   23699.3  |            21954.8          |            19220.0        
      f32 B=64, M=1024, H=16, K=128  |             80800.3           |                  |   86003.8  |            78521.9          |            78393.9        
      f16 B=64, M=1024, H=16, K=256  |            114411.1           |                  |   32958.3  |           103287.2          |           104304.2        
      f32 B=64, M=1024, H=16, K=256  |            155731.5           |                  |  153011.0  |           148071.6          |           148165.6        

Times are in microseconds (us).
```
</details>


danthe3rd pushed a commit that referenced this pull request Nov 29, 2022
ghstack-source-id: 5af56eef32bfba5defa0673145be8c293b78c76f
Pull Request resolved: #467
**PERFORMANCE**

Accumulating in f32 makes performance worse in f16 :(
But I think we need it for numerical stability
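
To illustrate the stability concern, here is a minimal sketch (not the kernel code from this PR) of how a running sum stalls when the accumulator itself is f16. It emulates f16 and f32 rounding with Python's `struct` module. The helper names `round_f16`, `round_f32` and `accumulate`, and the choice of 4096 terms of 0.01, are assumptions made for illustration only.

```python
import math
import struct


def round_f16(x: float) -> float:
    """Round a Python float to the nearest IEEE half-precision value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def round_f32(x: float) -> float:
    """Round a Python float to the nearest IEEE single-precision value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def accumulate(values, round_fn):
    """Sum `values`, rounding the accumulator after every addition."""
    acc = 0.0
    for v in values:
        acc = round_fn(acc + v)
    return acc


# The inputs are f16 values in both cases; only the accumulator type differs.
values = [round_f16(0.01)] * 4096

exact = math.fsum(values)                 # reference sum in double precision
sum_f16 = accumulate(values, round_f16)   # f16 inputs, f16 accumulator
sum_f32 = accumulate(values, round_f32)   # f16 inputs, f32 accumulator

print(f"exact   = {exact:.4f}")
print(f"f16 acc = {sum_f16:.4f} (abs error {abs(sum_f16 - exact):.4f})")
print(f"f32 acc = {sum_f32:.4f} (abs error {abs(sum_f32 - exact):.6f})")
```

Once the f16 accumulator grows large enough that each addend is smaller than half the spacing between neighbouring f16 values, further additions round away and the sum stops growing. The f32 accumulator has 13 more significand bits, so it does not hit that point in this example.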

<details>
<summary>bw P100/V100 (f32/f16)</summary>

```
[---------------------------------------- attention backward (attn_bias=<class 'NoneType'>) ----------------------------------------]                                                    
                                                         |  48_accf32_6f8e2f15  |  56_base_02bf6b4e  |  vanilla   |  57_tmpT_b516aec4
1 threads: --------------------------------------------------------------------------------------------------------------------------
  (Quadro_GP100)          f16 B=384, M=197, H=1, K=88    |         6289.6       |        6936.0      |    2183.6  |        6940.0    
                          f32 B=384, M=197, H=1, K=88    |         8793.3       |        9446.6      |    2175.1  |        9429.2    
                          f16 B=384, M=197, H=1, K=80    |         5989.9       |        6596.2      |    2146.8  |        6608.6    
                          f32 B=384, M=197, H=1, K=80    |         8427.1       |        8993.1      |    2134.0  |        9030.5    
                          f16 B=384, M=197, H=1, K=64    |         3347.0       |        3527.4      |    1799.3  |        3538.3    
                          f32 B=384, M=197, H=1, K=64    |         5563.1       |        5984.2      |    1801.6  |        5980.9    
                          f16 B=1024, M=197, H=1, K=88   |        15680.5       |       17424.8      |    5671.9  |       17452.6    
                          f32 B=1024, M=197, H=1, K=88   |        23784.2       |       25542.8      |    5664.0  |       25578.6    
                          f16 B=1024, M=197, H=1, K=80   |        14935.5       |       16581.3      |    5559.5  |       16587.3    
                          f32 B=1024, M=197, H=1, K=80   |        22767.6       |       24354.1      |    5550.7  |       24362.0    
                          f16 B=1024, M=197, H=1, K=64   |         8331.1       |        8671.5      |    4644.9  |        8695.3    
                          f32 B=1024, M=197, H=1, K=64   |        15061.6       |       16148.8      |    4650.0  |       16201.4    
                          f16 B=512, M=197, H=1, K=80    |         7594.7       |        8306.6      |    2824.6  |        8336.5    
                          f32 B=512, M=197, H=1, K=80    |        11857.0       |       12660.2      |    2807.4  |       12709.2    
                          f16 B=32, M=197, H=16, K=80    |         7779.8       |        8342.0      |    2820.8  |        8393.5    
                          f32 B=32, M=197, H=16, K=80    |        11846.5       |       12487.8      |    2806.0  |       12549.1    
                          f16 B=32, M=197, H=16, K=64    |         4258.4       |        4445.6      |    2374.0  |        4461.2    
                          f32 B=32, M=197, H=16, K=64    |         7828.9       |        8434.9      |    2376.0  |        8472.2    
                          f16 B=32, M=197, H=16, K=128   |         9025.1       |        9671.9      |    3159.3  |        9707.2    
                          f32 B=32, M=197, H=16, K=128   |        14139.8       |       14920.7      |    3157.7  |       14928.9    
                          f16 B=256, M=197, H=1, K=88    |         4608.9       |        5118.8      |    1478.4  |        5119.2    
                          f32 B=256, M=197, H=1, K=88    |         6174.0       |        6644.3      |    1477.6  |        6642.7    
                          f16 B=16, M=197, H=16, K=88    |         4618.4       |        5073.4      |    1479.5  |        5062.9    
                          f32 B=16, M=197, H=16, K=88    |         6114.2       |        6471.7      |    1474.9  |        6478.7    
                          f16 B=16, M=197, H=16, K=64    |         2490.3       |        2557.5      |    1225.9  |        2550.0    
                          f32 B=16, M=197, H=16, K=64    |         3918.9       |        4208.2      |    1227.0  |        4195.7    
                          f16 B=16, M=197, H=16, K=128   |         5210.0       |        5649.3      |    1635.2  |        5648.1    
                          f32 B=16, M=197, H=16, K=128   |         7103.1       |        7451.6      |    1645.5  |        7445.1    
                          f16 B=1, M=4096, H=160, K=128  |      1014229.1       |     1106182.6      |            |     1108040.6    
                          f32 B=1, M=4096, H=160, K=128  |      1258173.2       |     1243183.8      |            |     1241548.7    
                          f16 B=2, M=4096, H=160, K=128  |      1642279.2       |     1753736.9      |            |     1771655.3    
                          f32 B=2, M=4096, H=160, K=128  |      2505435.4       |     2477353.1      |            |     2473773.4    
                          f16 B=1, M=8192, H=160, K=128  |      4050128.4       |     4415962.1      |            |     4428194.9    
                          f32 B=1, M=8192, H=160, K=128  |      5042352.6       |     4970582.0      |            |     4965069.9    
                          f16 B=2, M=8192, H=160, K=128  |      6600732.3       |     7026378.5      |            |     7068051.8    
                          f16 B=1024, M=82, H=8, K=64    |        21572.8       |       22531.6      |    9059.1  |       22418.4    
                          f32 B=1024, M=82, H=8, K=64    |        38178.4       |       45927.2      |    9070.3  |       45708.0    
                          f16 B=150, M=256, H=16, K=64   |        21436.5       |       21927.4      |   12938.5  |       22001.0    
                          f32 B=150, M=256, H=16, K=64   |        33024.2       |       33196.3      |   13249.2  |       33199.6    
                          f16 B=64, M=256, H=12, K=64    |         6869.8       |        7048.6      |    4200.6  |        7073.6    
                          f32 B=64, M=256, H=12, K=64    |        10719.9       |       10832.1      |    4271.3  |       10843.6    
                          f16 B=1, M=4096, H=16, K=40    |       134722.6       |      145429.8      |   20587.0  |      143743.8    
                          f32 B=1, M=4096, H=16, K=40    |       143015.1       |      147850.6      |   20625.8  |      148272.8    
                          f16 B=1, M=16384, H=16, K=40   |      2149850.4       |     2323732.4      |            |     2301489.4    
                          f32 B=1, M=16384, H=16, K=40   |      2286478.1       |     2369812.4      |            |     2375911.8    
                          f16 B=16, M=128, H=16, K=16    |          497.2       |         502.9      |     623.8  |         503.6    
                          f32 B=16, M=128, H=16, K=16    |          573.5       |         609.8      |     617.5  |         611.6    
                          f16 B=16, M=128, H=16, K=32    |          563.9       |         573.1      |     624.2  |         575.2    
                          f32 B=16, M=128, H=16, K=32    |          661.5       |         702.2      |     620.4  |         703.2    
                          f16 B=16, M=128, H=16, K=64    |          708.9       |         722.6      |     619.9  |         724.3    
                          f32 B=16, M=128, H=16, K=64    |          916.2       |         953.5      |     618.6  |         953.6    
                          f16 B=16, M=128, H=16, K=128   |         1465.8       |        1542.8      |     624.2  |        1545.7    
                          f32 B=16, M=128, H=16, K=128   |         1829.1       |        1872.3      |     616.2  |        1873.9    
                          f16 B=16, M=128, H=16, K=256   |         3796.3       |        4002.5      |    1010.7  |        4008.4    
                          f32 B=16, M=128, H=16, K=256   |         3838.8       |        3957.1      |    1203.6  |        3951.8    
                          f16 B=16, M=512, H=16, K=16    |         7680.4       |        7775.1      |    4848.7  |        7752.8    
                          f32 B=16, M=512, H=16, K=16    |         9402.0       |        9926.1      |    4926.7  |        9914.5    
                          f16 B=16, M=512, H=16, K=32    |         8897.5       |        8999.0      |    5055.9  |        9003.9    
                          f32 B=16, M=512, H=16, K=32    |        10762.1       |       11320.9      |    5065.4  |       11341.8    
                          f16 B=16, M=512, H=16, K=64    |        10936.9       |       11214.8      |    5484.9  |       11254.4    
                          f32 B=16, M=512, H=16, K=64    |        15091.9       |       15210.3      |    5552.0  |       15196.3    
                          f16 B=16, M=512, H=16, K=128   |        23491.0       |       25234.8      |    7317.1  |       25300.3    
                          f32 B=16, M=512, H=16, K=128   |        30524.6       |       30320.0      |    7487.0  |       30200.6    
                          f16 B=16, M=512, H=16, K=256   |        50389.7       |       54525.2      |   14015.3  |       54931.1    
                          f32 B=16, M=512, H=16, K=256   |        62155.6       |       61046.4      |   14285.5  |       61081.1    
                          f16 B=16, M=1024, H=16, K=16   |        31289.9       |       31951.9      |   18778.4  |       32044.9    
                          f32 B=16, M=1024, H=16, K=16   |        37744.4       |       39586.6      |   18929.9  |       39739.9    
                          f16 B=16, M=1024, H=16, K=32   |        35770.8       |       36651.1      |   19620.2  |       36909.0    
                          f32 B=16, M=1024, H=16, K=32   |        43211.5       |       45544.5      |   19518.4  |       45664.8    
                          f16 B=16, M=1024, H=16, K=64   |        43865.1       |       44864.8      |   21286.0  |       45345.7    
                          f32 B=16, M=1024, H=16, K=64   |        60710.5       |       60966.9      |   21634.1  |       60921.9    
                          f16 B=16, M=1024, H=16, K=128  |        94633.7       |      101796.0      |   28502.1  |      102871.6    
                          f32 B=16, M=1024, H=16, K=128  |       124093.6       |      122196.9      |   28520.3  |      122043.0    
                          f16 B=16, M=1024, H=16, K=256  |       194780.7       |      212303.0      |   55419.1  |      214643.5    
                          f32 B=16, M=1024, H=16, K=256  |       250799.1       |      245196.2      |   55634.0  |      245534.2    
                          f16 B=64, M=128, H=16, K=16    |         1658.0       |        1661.4      |    1331.0  |        1662.7    
                          f32 B=64, M=128, H=16, K=16    |         2129.2       |        2266.2      |    1371.8  |        2269.4    
                          f16 B=64, M=128, H=16, K=32    |         1904.3       |        1903.5      |    1384.2  |        1906.4    
                          f32 B=64, M=128, H=16, K=32    |         2496.5       |        2643.4      |    1445.1  |        2639.8    
                          f16 B=64, M=128, H=16, K=64    |         2393.1       |        2432.3      |    1505.2  |        2437.8    
                          f32 B=64, M=128, H=16, K=64    |         3471.1       |        3561.2      |    1590.0  |        3565.7    
                          f16 B=64, M=128, H=16, K=128   |         4969.1       |        5266.3      |    1988.4  |        5272.4    
                          f32 B=64, M=128, H=16, K=128   |         6880.3       |        7040.6      |    2121.0  |        7024.3    
                          f16 B=64, M=128, H=16, K=256   |        12635.3       |       13334.7      |    3859.5  |       13350.7    
                          f32 B=64, M=128, H=16, K=256   |        14185.7       |       14514.4      |    4553.8  |       14503.1    
                          f16 B=64, M=512, H=16, K=16    |        26278.4       |       26189.9      |   18916.0  |       26110.0    
                          f32 B=64, M=512, H=16, K=16    |        35128.2       |       37075.9      |   19191.4  |       37174.5    
                          f16 B=64, M=512, H=16, K=32    |        30414.2       |       30938.2      |   19734.7  |       31071.7    
                          f32 B=64, M=512, H=16, K=32    |        40483.2       |       42922.5      |   19843.9  |       42868.5    
                          f16 B=64, M=512, H=16, K=64    |        37248.2       |       38179.7      |   21640.1  |       38327.0    
                          f32 B=64, M=512, H=16, K=64    |        57666.8       |       57675.3      |   21909.1  |       57697.2    
                          f16 B=64, M=512, H=16, K=128   |        80113.6       |       86165.7      |   28765.5  |       86368.3    
                          f32 B=64, M=512, H=16, K=128   |       115672.5       |      115161.0      |   28910.3  |      115320.0    
                          f16 B=64, M=512, H=16, K=256   |       169250.0       |      183791.8      |   56315.7  |      183735.0    
                          f32 B=64, M=512, H=16, K=256   |       236594.6       |      233093.6      |   56853.4  |      233170.2    
                          f16 B=64, M=1024, H=16, K=16   |       106022.3       |      109588.2      |   74303.6  |      109410.4    
                          f32 B=64, M=1024, H=16, K=16   |       141241.8       |      148854.1      |            |      149651.5    
                          f16 B=64, M=1024, H=16, K=32   |       120899.1       |      125044.6      |   77828.6  |      125716.9    
                          f32 B=64, M=1024, H=16, K=32   |       162478.5       |      173906.1      |            |      173216.4    
                          f16 B=64, M=1024, H=16, K=64   |       149044.1       |      152290.6      |   85821.6  |      152748.5    
                          f32 B=64, M=1024, H=16, K=64   |       233195.9       |      231479.6      |            |      231533.9    
                          f16 B=64, M=1024, H=16, K=128  |       319761.5       |      344076.5      |  113579.4  |      345058.5    
                          f32 B=64, M=1024, H=16, K=128  |       470172.3       |      466330.7      |            |      463507.4    
                          f16 B=64, M=1024, H=16, K=256  |       658362.1       |      717070.2      |            |      723057.4    
                          f32 B=64, M=1024, H=16, K=256  |       955624.0       |      935114.3      |            |      935945.4    
  (Tesla_V100_SXM2_16GB)  f16 B=384, M=197, H=1, K=88    |         1811.3       |        1686.3      |    1375.2  |        1699.6    
                          f32 B=384, M=197, H=1, K=88    |         4315.7       |        4665.1      |    2257.1  |        4663.1    
                          f16 B=384, M=197, H=1, K=80    |         1733.0       |        1616.3      |    1281.7  |        1619.8    
                          f32 B=384, M=197, H=1, K=80    |         3965.0       |        4226.4      |    2171.7  |        4228.3    
                          f16 B=384, M=197, H=1, K=64    |         1135.2       |        1084.4      |    1043.8  |        1083.0    
                          f32 B=384, M=197, H=1, K=64    |         2673.2       |        2883.0      |    1744.5  |        2878.2    
                          f16 B=1024, M=197, H=1, K=88   |         4721.0       |        4396.6      |    3725.3  |        4404.1    
                          f32 B=1024, M=197, H=1, K=88   |        10531.3       |       11443.8      |    6106.5  |       11464.2    
                          f16 B=1024, M=197, H=1, K=80   |         4520.2       |        4216.2      |    3329.2  |        4223.7    
                          f32 B=1024, M=197, H=1, K=80   |         9573.5       |       10301.4      |    5757.6  |       10305.3    
                          f16 B=1024, M=197, H=1, K=64   |         2788.1       |        2660.1      |    2674.6  |        2663.5    
                          f32 B=1024, M=197, H=1, K=64   |         6556.7       |        7102.4      |    4516.6  |        7096.8    
                          f16 B=512, M=197, H=1, K=80    |         2377.2       |        2228.1      |    1685.0  |        2231.4    
                          f32 B=512, M=197, H=1, K=80    |         5259.3       |        5639.7      |    2887.8  |        5636.6    
                          f16 B=32, M=197, H=16, K=80    |         2403.0       |        2201.1      |    1798.2  |        2204.9    
                          f32 B=32, M=197, H=16, K=80    |         5402.3       |        5667.7      |    3046.3  |        5662.4    
                          f16 B=32, M=197, H=16, K=64    |         1552.7       |        1486.3      |    1451.6  |        1485.2    
                          f32 B=32, M=197, H=16, K=64    |         3622.5       |        3911.0      |    2427.0  |        3912.9    
                          f16 B=32, M=197, H=16, K=128   |         2776.6       |        2611.9      |    2211.3  |        2613.3    
                          f32 B=32, M=197, H=16, K=128   |         6647.9       |        7082.3      |    4088.5  |        7104.8    
                          f16 B=256, M=197, H=1, K=88    |         1357.5       |        1285.5      |     941.3  |        1287.6    
                          f32 B=256, M=197, H=1, K=88    |         2874.1       |        3085.1      |    1543.2  |        3090.1    
                          f16 B=16, M=197, H=16, K=88    |         1349.1       |        1264.5      |     964.7  |        1263.1    
                          f32 B=16, M=197, H=16, K=88    |         2803.8       |        2967.0      |    1647.4  |        2972.0    
                          f16 B=16, M=197, H=16, K=64    |          765.5       |         728.4      |     765.3  |         731.0    
                          f32 B=16, M=197, H=16, K=64    |         1834.1       |        1969.7      |    1282.4  |        1974.4    
                          f16 B=16, M=197, H=16, K=128   |         1509.1       |        1432.6      |    1139.1  |        1433.9    
                          f32 B=16, M=197, H=16, K=128   |         3406.5       |        3606.2      |    2048.6  |        3613.3    
                          f16 B=1, M=4096, H=160, K=128  |       168807.1       |      148652.9      |            |      149343.8    
                          f32 B=1, M=4096, H=160, K=128  |       549864.6       |      586699.9      |            |      585699.9    
                          f16 B=2, M=4096, H=160, K=128  |       339010.5       |      298827.7      |            |      298808.4    
                          f32 B=2, M=4096, H=160, K=128  |      1106963.8       |     1176218.3      |            |     1179173.2    
                          f16 B=1, M=8192, H=160, K=128  |       679742.4       |      594323.2      |            |      595580.1    
                          f32 B=1, M=8192, H=160, K=128  |      2195491.3       |     2340248.4      |            |     2343505.7    
                          f16 B=2, M=8192, H=160, K=128  |      1364983.7       |     1193787.8      |            |     1192596.1    
                          f16 B=1024, M=82, H=8, K=64    |         9052.8       |        8762.6      |    5804.2  |        8757.8    
                          f32 B=1024, M=82, H=8, K=64    |        14726.3       |       16270.6      |   11059.7  |       16215.9    
                          f16 B=150, M=256, H=16, K=64   |         5662.9       |        5519.4      |    7557.8  |        5524.5    
                          f32 B=150, M=256, H=16, K=64   |        16700.3       |       17612.7      |   16426.0  |       17640.4    
                          f16 B=64, M=256, H=12, K=64    |         1849.1       |        1793.1      |    2383.5  |        1798.7    
                          f32 B=64, M=256, H=12, K=64    |         5451.5       |        5766.4      |    4975.3  |        5775.8    
                          f16 B=1, M=4096, H=16, K=40    |        47263.3       |       47850.4      |    8315.6  |       47777.2    
                          f32 B=1, M=4096, H=16, K=40    |       113099.4       |      113164.5      |   19536.0  |      113930.0    
                          f16 B=1, M=16384, H=16, K=40   |       757091.8       |      770365.7      |            |      765401.3    
                          f32 B=1, M=16384, H=16, K=40   |      1806827.7       |     1816302.6      |            |     1819162.1    
                          f16 B=16, M=128, H=16, K=16    |          219.2       |         218.5      |     480.6  |         231.8    
                          f32 B=16, M=128, H=16, K=16    |          301.2       |         308.2      |     498.9  |         307.9    
                          f16 B=16, M=128, H=16, K=32    |          227.6       |         215.4      |     473.8  |         220.1    
                          f32 B=16, M=128, H=16, K=32    |          395.8       |         401.0      |     455.1  |         400.8    
                          f16 B=16, M=128, H=16, K=64    |          225.6       |         214.8      |     510.1  |         229.9    
                          f32 B=16, M=128, H=16, K=64    |          561.7       |         583.1      |     598.9  |         581.2    
                          f16 B=16, M=128, H=16, K=128   |          404.6       |         392.0      |     524.0  |         394.8    
                          f32 B=16, M=128, H=16, K=128   |         1103.0       |        1140.8      |    1015.4  |        1142.0    
                          f16 B=16, M=128, H=16, K=256   |         1045.3       |        1049.1      |     889.6  |        1047.4    
                          f32 B=16, M=128, H=16, K=256   |         2181.3       |        2270.9      |    1869.2  |        2265.7    
                          f16 B=16, M=512, H=16, K=16    |         1731.5       |        1585.7      |    1908.5  |        1586.2    
                          f32 B=16, M=512, H=16, K=16    |         4513.8       |        4696.7      |    4222.1  |        4695.6    
                          f16 B=16, M=512, H=16, K=32    |         1942.2       |        1823.4      |    2086.3  |        1809.7    
                          f32 B=16, M=512, H=16, K=32    |         5596.1       |        5819.4      |    4588.1  |        5833.6    
                          f16 B=16, M=512, H=16, K=64    |         2450.2       |        2340.8      |    2580.6  |        2353.2    
                          f32 B=16, M=512, H=16, K=64    |         7619.5       |        7853.1      |    5536.8  |        7875.1    
                          f16 B=16, M=512, H=16, K=128   |         4884.8       |        4473.5      |    3388.7  |        4487.9    
                          f32 B=16, M=512, H=16, K=128   |        15010.0       |       15557.1      |    8979.4  |       15513.1    
                          f16 B=16, M=512, H=16, K=256   |        12973.9       |       11134.7      |    5418.2  |       11106.3    
                          f32 B=16, M=512, H=16, K=256   |        29751.1       |       30979.1      |   16856.3  |       31012.8    
                          f16 B=16, M=1024, H=16, K=16   |         6802.6       |        6185.1      |    6996.0  |        6191.0    
                          f32 B=16, M=1024, H=16, K=16   |        18098.2       |       18954.5      |   16129.1  |       19166.8    
                          f16 B=16, M=1024, H=16, K=32   |         7531.9       |        7065.0      |    7436.0  |        7067.5    
                          f32 B=16, M=1024, H=16, K=32   |        21999.9       |       22901.5      |   17040.0  |       22899.4    
                          f16 B=16, M=1024, H=16, K=64   |         9312.0       |        8854.3      |    8605.6  |        8863.9    
                          f32 B=16, M=1024, H=16, K=64   |        29837.7       |       30895.1      |   20355.3  |       30878.5    
                          f16 B=16, M=1024, H=16, K=128  |        18979.0       |       16951.0      |   10561.3  |       16995.2    
                          f32 B=16, M=1024, H=16, K=128  |        58738.3       |       60861.2      |   33427.0  |       60599.8    
                          f16 B=16, M=1024, H=16, K=256  |        49681.9       |       41833.3      |   17329.5  |       41921.6    
                          f32 B=16, M=1024, H=16, K=256  |       117362.4       |      121004.8      |   60515.8  |      122046.9    
                          f16 B=64, M=128, H=16, K=16    |          432.2       |         411.1      |     642.1  |         411.3    
                          f32 B=64, M=128, H=16, K=16    |         1028.9       |        1057.9      |    1233.4  |        1056.6    
                          f16 B=64, M=128, H=16, K=32    |          522.5       |         500.7      |     813.6  |         499.7    
                          f32 B=64, M=128, H=16, K=32    |         1403.5       |        1443.4      |    1535.6  |        1443.6    
                          f16 B=64, M=128, H=16, K=64    |          750.0       |         739.8      |    1185.6  |         741.0    
                          f32 B=64, M=128, H=16, K=64    |         2013.9       |        2110.1      |    2156.8  |        2105.4    
                          f16 B=64, M=128, H=16, K=128   |         1421.2       |        1387.9      |    1915.5  |        1388.3    
                          f32 B=64, M=128, H=16, K=128   |         3946.5       |        4156.7      |    3780.1  |        4156.3    
                          f16 B=64, M=128, H=16, K=256   |         3811.5       |        3810.6      |    3448.7  |        3811.0    
                          f32 B=64, M=128, H=16, K=256   |         7983.9       |        8432.7      |    7304.0  |        8415.5    
                          f16 B=64, M=512, H=16, K=16    |         6157.4       |        5523.9      |    7461.8  |        5528.1    
                          f32 B=64, M=512, H=16, K=16    |        16118.2       |       17004.7      |   16651.7  |       16984.5    
                          f16 B=64, M=512, H=16, K=32    |         6985.1       |        6483.1      |    8278.5  |        6495.1    
                          f32 B=64, M=512, H=16, K=32    |        20471.9       |       21230.7      |   18420.5  |       21269.0    
                          f16 B=64, M=512, H=16, K=64    |         9045.8       |        8633.5      |   10337.1  |        8674.1    
                          f32 B=64, M=512, H=16, K=64    |        27675.0       |       29282.8      |   22883.7  |       29136.9    
                          f16 B=64, M=512, H=16, K=128   |        17594.2       |       15805.9      |   14700.6  |       15788.2    
                          f32 B=64, M=512, H=16, K=128   |        54612.4       |       57815.7      |   39974.3  |       57951.8    
                          f16 B=64, M=512, H=16, K=256   |        47452.6       |       40093.5      |   27087.5  |       40175.6    
                          f32 B=64, M=512, H=16, K=256   |       108880.3       |      115953.1      |   77794.2  |      115951.0    
                          f16 B=64, M=1024, H=16, K=16   |        24369.0       |       21533.0      |   28448.6  |       21556.6    
                          f32 B=64, M=1024, H=16, K=16   |        64649.7       |       68791.8      |            |       68407.4    
                          f16 B=64, M=1024, H=16, K=32   |        27143.1       |       25683.7      |   30252.7  |       25727.9    
                          f32 B=64, M=1024, H=16, K=32   |        79967.5       |       83351.5      |            |       83084.7    
                          f16 B=64, M=1024, H=16, K=64   |        34667.0       |       32592.7      |   36991.2  |       32659.8    
                          f32 B=64, M=1024, H=16, K=64   |       108282.2       |      113858.1      |            |      114286.0    
                          f16 B=64, M=1024, H=16, K=128  |        68519.5       |       59757.3      |   48834.4  |       59817.7    
                          f32 B=64, M=1024, H=16, K=128  |       215465.9       |      227335.8      |            |      227204.0    
                          f16 B=64, M=1024, H=16, K=256  |       183070.4       |      150960.1      |            |      150947.1    
                          f32 B=64, M=1024, H=16, K=256  |       425832.9       |      453717.2      |            |      453349.2    

Times are in microseconds (us).
```
</details>

<details>
<summary>bw A100 (f32/f16)</summary>

```
[----------------------------------------------------- attention backward (attn_bias=<class 'NoneType'>) -----------------------------------------------------]                                                       
                                     |  48_accf32_69654fdb[cutlass]  |  flash[flshatt]  |  vanilla   |  56_base_02bf6b4e[cutlass]  |  57_tmpT_b516aec4[cutlass]
1 threads: ----------------------------------------------------------------------------------------------------------------------------------------------------
      f16 B=384, M=197, H=1, K=88    |               613.4           |                  |    2264.9  |              612.8          |              578.3        
      f32 B=384, M=197, H=1, K=88    |              2335.2           |                  |    1843.0  |             2438.4          |             2425.1        
      f16 B=384, M=197, H=1, K=80    |               583.6           |                  |    1922.9  |              577.6          |              548.4        
      f32 B=384, M=197, H=1, K=80    |              2241.3           |                  |    1787.8  |             2333.2          |             2333.1        
      f16 B=384, M=197, H=1, K=64    |               405.6           |        232.5     |    1809.8  |              386.4          |              366.1        
      f32 B=384, M=197, H=1, K=64    |              1259.7           |                  |    1675.6  |             1309.6          |             1316.8        
      f16 B=1024, M=197, H=1, K=88   |              1538.2           |                  |    5964.7  |             1550.8          |             1454.4        
      f32 B=1024, M=197, H=1, K=88   |              6031.4           |                  |    4559.5  |             6325.4          |             6332.7        
      f16 B=1024, M=197, H=1, K=80   |              1458.8           |                  |    5038.2  |             1463.8          |             1379.0        
      f32 B=1024, M=197, H=1, K=80   |              5786.2           |                  |    4412.6  |             6059.0          |             6079.9        
      f16 B=1024, M=197, H=1, K=64   |               929.5           |        575.9     |    4735.6  |              862.1          |              821.6        
      f32 B=1024, M=197, H=1, K=64   |              3289.9           |                  |    4119.9  |             3434.4          |             3441.4        
      f16 B=512, M=197, H=1, K=80    |               744.5           |                  |    2544.7  |              735.3          |              695.8        
      f32 B=512, M=197, H=1, K=80    |              2889.4           |                  |    2286.1  |             3008.8          |             3029.0        
      f16 B=32, M=197, H=16, K=80    |               741.6           |                  |    2569.0  |              723.3          |              693.4        
      f32 B=32, M=197, H=16, K=80    |              2878.6           |                  |    2355.3  |             3003.0          |             3024.1        
      f16 B=32, M=197, H=16, K=64    |               478.5           |        295.7     |    2429.2  |              456.1          |              426.7        
      f32 B=32, M=197, H=16, K=64    |              1784.4           |                  |    2196.9  |             1863.7          |             1860.2        
      f16 B=32, M=197, H=16, K=128   |               887.0           |        682.3     |    4492.9  |              857.5          |              853.8        
      f32 B=32, M=197, H=16, K=128   |              3546.6           |                  |    2807.3  |             3734.9          |             3737.1        
      f16 B=256, M=197, H=1, K=88    |               445.1           |                  |    1528.9  |              445.0          |              422.2        
      f32 B=256, M=197, H=1, K=88    |              1678.8           |                  |    1207.6  |             1746.6          |             1752.8        
      f16 B=16, M=197, H=16, K=88    |               441.5           |                  |    1544.2  |              437.3          |              419.5        
      f32 B=16, M=197, H=16, K=88    |              1668.3           |                  |    1250.4  |             1742.7          |             1746.1        
      f16 B=16, M=197, H=16, K=64    |               247.4           |        165.6     |    1242.5  |              233.2          |              217.5        
      f32 B=16, M=197, H=16, K=64    |              1051.1           |                  |    1125.3  |             1099.2          |             1096.6        
      f16 B=16, M=197, H=16, K=128   |               498.4           |        386.2     |    2264.5  |              488.0          |              480.5        
      f32 B=16, M=197, H=16, K=128   |              1950.2           |                  |    1446.7  |             2039.0          |             2028.6        
      f16 B=1, M=4096, H=160, K=128  |             55915.0           |      54620.4     |   45909.5  |            63407.6          |            51298.8        
      f32 B=1, M=4096, H=160, K=128  |            238514.6           |                  |            |           232677.5          |           232672.5        
      f16 B=2, M=4096, H=160, K=128  |             93612.0           |      84238.0     |            |           100433.1          |            84858.1        
      f32 B=2, M=4096, H=160, K=128  |            375037.4           |                  |            |           364234.1          |           364407.2        
      f16 B=1, M=8192, H=160, K=128  |            223261.8           |     215499.8     |            |           251806.6          |           202133.7        
      f32 B=1, M=8192, H=160, K=128  |            946708.9           |                  |            |           924986.2          |           924988.9        
      f16 B=2, M=8192, H=160, K=128  |            367969.2           |     330092.8     |            |           395881.3          |           332000.6        
      f32 B=2, M=8192, H=160, K=128  |           1492031.4           |                  |            |          1448691.1          |          1449146.1        
      f16 B=1024, M=82, H=8, K=64    |              1890.2           |       1620.3     |    3819.7  |             1861.3          |             1764.6        
      f32 B=1024, M=82, H=8, K=64    |              8428.2           |                  |    8735.1  |             8831.3          |             8867.4        
      f16 B=150, M=256, H=16, K=64   |              2292.3           |       1625.3     |    4555.9  |             2109.8          |             2019.4        
      f32 B=150, M=256, H=16, K=64   |              6252.4           |                  |   12948.1  |             6288.9          |             6281.3        
      f16 B=64, M=256, H=12, K=64    |               782.2           |        567.4     |    1498.0  |              731.2          |              699.7        
      f32 B=64, M=256, H=12, K=64    |              2141.2           |                  |    4266.6  |             2160.6          |             2160.2        
      f16 B=1, M=4096, H=16, K=40    |             23504.2           |                  |    4196.1  |            23699.0          |            23008.9        
      f32 B=1, M=4096, H=16, K=40    |             73699.5           |                  |   17755.3  |            73261.8          |            73078.5        
      f16 B=1, M=16384, H=16, K=40   |            391408.9           |                  |            |           439777.3          |           407653.7        
      f32 B=1, M=16384, H=16, K=40   |           1196173.6           |                  |            |          1181547.9          |          1181625.1        
      f16 B=256, M=4096, H=16, K=64  |            733221.8           |     439627.5     |            |           603237.1          |           565905.8        
      f16 B=16, M=128, H=16, K=16    |               130.0           |        113.1     |     265.2  |              125.5          |              124.4        
      f32 B=16, M=128, H=16, K=16    |               161.5           |                  |     373.1  |              160.6          |              162.8        
      f16 B=16, M=128, H=16, K=32    |               125.8           |        111.5     |     263.8  |              122.1          |              125.8        
      f32 B=16, M=128, H=16, K=32    |               189.8           |                  |     412.6  |              196.3          |              196.2        
      f16 B=16, M=128, H=16, K=64    |               126.0           |        112.4     |     265.8  |              120.8          |              125.7        
      f32 B=16, M=128, H=16, K=64    |               272.1           |                  |     498.7  |              285.9          |              283.3        
      f16 B=16, M=128, H=16, K=128   |               181.3           |        158.4     |     298.5  |              178.0          |              186.5        
      f32 B=16, M=128, H=16, K=128   |               509.6           |                  |     673.9  |              521.1          |              521.5        
      f16 B=16, M=128, H=16, K=256   |               774.0           |                  |     541.4  |              775.5          |              757.0        
      f32 B=16, M=128, H=16, K=256   |               975.2           |                  |    1162.6  |              994.5          |              994.5        
      f16 B=16, M=512, H=16, K=16    |               621.0           |        322.6     |    1204.9  |              555.0          |              519.9        
      f32 B=16, M=512, H=16, K=16    |              2148.0           |                  |    4414.4  |             2178.9          |             2180.4        
      f16 B=16, M=512, H=16, K=32    |               709.9           |        435.6     |    1306.2  |              653.1          |              602.8        
      f32 B=16, M=512, H=16, K=32    |              2335.6           |                  |    4640.7  |             2336.0          |             2336.3        
      f16 B=16, M=512, H=16, K=64    |               917.8           |        702.7     |    1545.8  |              849.9          |              797.4        
      f32 B=16, M=512, H=16, K=64    |              2965.5           |                  |    5125.0  |             2986.9          |             2988.3        
      f16 B=16, M=512, H=16, K=128   |              1644.0           |       1584.1     |    1983.4  |             1757.0          |             1548.0        
      f32 B=16, M=512, H=16, K=128   |              6152.7           |                  |    6099.1  |             6067.7          |             6068.6        
      f16 B=16, M=512, H=16, K=256   |              8178.3           |                  |    2899.1  |             7895.5          |             7977.4        
      f32 B=16, M=512, H=16, K=256   |             11894.0           |                  |   10639.5  |            11635.0          |            11624.7        
      f16 B=16, M=1024, H=16, K=16   |              2420.5           |       1240.8     |    4259.2  |             2234.6          |             2048.9        
      f32 B=16, M=1024, H=16, K=16   |              8467.2           |                  |   16650.4  |             8512.2          |             8510.2        
      f16 B=16, M=1024, H=16, K=32   |              2675.7           |       1618.9     |    4491.2  |             2441.3          |             2230.0        
      f32 B=16, M=1024, H=16, K=32   |              9010.0           |                  |   17301.0  |             9012.1          |             9015.7        
      f16 B=16, M=1024, H=16, K=64   |              3328.4           |       2370.3     |    4994.5  |             3032.2          |             2820.0        
      f32 B=16, M=1024, H=16, K=64   |             11566.8           |                  |   18714.2  |            11494.1          |            11492.8        
      f16 B=16, M=1024, H=16, K=128  |              5867.9           |       5632.5     |    5952.8  |             6401.1          |             5440.8        
      f32 B=16, M=1024, H=16, K=128  |             23345.4           |                  |   21523.7  |            22859.1          |            22870.3        
      f16 B=16, M=1024, H=16, K=256  |             30619.2           |                  |    7893.1  |            29884.9          |            29060.4        
      f32 B=16, M=1024, H=16, K=256  |             45211.4           |                  |   38093.0  |            43435.3          |            43423.8        
      f16 B=64, M=128, H=16, K=16    |               159.6           |        145.2     |     439.9  |              161.2          |              167.0        
      f32 B=64, M=128, H=16, K=16    |               493.4           |                  |    1270.0  |              502.7          |              503.2        
      f16 B=64, M=128, H=16, K=32    |               208.5           |        212.1     |     545.3  |              206.2          |              204.9        
      f32 B=64, M=128, H=16, K=32    |               601.2           |                  |    1427.0  |              610.5          |              610.6        
      f16 B=64, M=128, H=16, K=64    |               329.0           |        310.8     |     766.0  |              327.2          |              314.7        
      f32 B=64, M=128, H=16, K=64    |               867.7           |                  |    1743.5  |              889.2          |              889.0        
      f16 B=64, M=128, H=16, K=128   |               635.5           |        562.1     |    1226.7  |              613.3          |              650.7        
      f32 B=64, M=128, H=16, K=128   |              1774.4           |                  |    2386.5  |             1800.0          |             1799.6        
      f16 B=64, M=128, H=16, K=256   |              2839.8           |                  |    2122.8  |             2765.2          |             2821.9        
      f32 B=64, M=128, H=16, K=256   |              3419.2           |                  |    4320.7  |             3458.8          |             3457.0        
      f16 B=64, M=512, H=16, K=16    |              2316.2           |       1202.3     |    4487.1  |             1983.9          |             1868.4        
      f32 B=64, M=512, H=16, K=16    |              6686.0           |                  |   16991.5  |             6709.9          |             6713.2        
      f16 B=64, M=512, H=16, K=32    |              2701.7           |       1541.9     |    4975.9  |             2346.4          |             2184.3        
      f32 B=64, M=512, H=16, K=32    |              7460.5           |                  |   17859.9  |             7429.6          |             7438.0        
      f16 B=64, M=512, H=16, K=64    |              3461.2           |       2418.2     |    5886.1  |             3083.1          |             2942.6        
      f32 B=64, M=512, H=16, K=64    |              9553.4           |                  |   19768.6  |             9526.5          |             9516.6        
      f16 B=64, M=512, H=16, K=128   |              5875.5           |       5443.1     |    7711.0  |             6141.3          |             5513.6        
      f32 B=64, M=512, H=16, K=128   |             21317.0           |                  |   23651.2  |            20931.9          |            20932.3        
      f16 B=64, M=512, H=16, K=256   |             31238.2           |                  |   11490.6  |            28626.8          |            28954.0        
      f32 B=64, M=512, H=16, K=256   |             41124.8           |                  |   42468.0  |            39562.8          |            39579.5        
      f16 B=64, M=1024, H=16, K=16   |              9142.8           |       4707.2     |   16882.8  |             7887.0          |             7314.5        
      f32 B=64, M=1024, H=16, K=16   |             26512.3           |                  |   66311.7  |            26497.5          |            26459.8        
      f16 B=64, M=1024, H=16, K=32   |             10420.8           |       5698.3     |   17875.0  |             8851.4          |             7974.0        
      f32 B=64, M=1024, H=16, K=32   |             28300.0           |                  |   69088.7  |            28102.1          |            28098.2        
      f16 B=64, M=1024, H=16, K=64   |             12948.5           |       8119.0     |   19944.3  |            11064.2          |            10477.6        
      f32 B=64, M=1024, H=16, K=64   |             35600.6           |                  |   74762.8  |            35316.4          |            35361.2        
      f16 B=64, M=1024, H=16, K=128  |             20820.5           |      19184.2     |   23699.3  |            21954.8          |            19220.0        
      f32 B=64, M=1024, H=16, K=128  |             80800.3           |                  |   86003.8  |            78521.9          |            78393.9        
      f16 B=64, M=1024, H=16, K=256  |            114411.1           |                  |   32958.3  |           103287.2          |           104304.2        
      f32 B=64, M=1024, H=16, K=256  |            155731.5           |                  |  153011.0  |           148071.6          |           148165.6        

Times are in microseconds (us).
```
</details>
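
To read the tables above as relative slowdowns rather than raw times, rows can be parsed and compared column against column. This is a minimal sketch and not part of the PR: `parse_row` and `slowdown` are hypothetical helpers, and the three rows are copied from the A100 table with whitespace collapsed.

```python
# Column names as they appear in the A100 table header.
HEADER = ("config | 48_accf32_69654fdb[cutlass] | flash[flshatt] | vanilla | "
          "56_base_02bf6b4e[cutlass] | 57_tmpT_b516aec4[cutlass]")

# Rows copied from the A100 table; an empty cell means no measurement.
ROWS = [
    "f16 B=16, M=128, H=16, K=16   | 130.0    | 113.1 | 265.2    | 125.5    | 124.4",
    "f16 B=64, M=1024, H=16, K=256 | 114411.1 |       | 32958.3  | 103287.2 | 104304.2",
    "f32 B=64, M=1024, H=16, K=256 | 155731.5 |       | 153011.0 | 148071.6 | 148165.6",
]

ACC = "48_accf32_69654fdb[cutlass]"   # this PR (f32 accumulation)
BASE = "56_base_02bf6b4e[cutlass]"    # baseline it is compared against


def parse_row(header, row):
    """Split one table row into (config, {column name: time in us or None})."""
    names = [c.strip() for c in header.split("|")][1:]
    cells = [c.strip() for c in row.split("|")]
    config, values = cells[0], cells[1:]
    return config, {n: (float(v) if v else None) for n, v in zip(names, values)}


def slowdown(times, candidate, baseline):
    """Return candidate / baseline, or None if either measurement is missing."""
    a, b = times[candidate], times[baseline]
    if a is None or b is None:
        return None
    return a / b


for row in ROWS:
    config, times = parse_row(HEADER, row)
    ratio = slowdown(times, ACC, BASE)
    print(f"{config}: accf32 / base = {ratio:.3f}")
```

A ratio above 1 means the f32-accumulation kernel is slower than the baseline for that shape. For the f16 `B=64, M=1024, H=16, K=256` row this works out to roughly 11% slower.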


danthe3rd pushed a commit that referenced this pull request Dec 6, 2022
ghstack-source-id: b7b4fe67e89e46e691615a42676b04c3bc0779a6
Pull Request resolved: #467
**PERFORMANCE**

This makes f16 performance worse :(
But I think we need it for numerical stability
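
As a general illustration of why the accumulator precision matters (this is not the kernel code from this PR), the sketch below sums 4096 ones while rounding the running total to f16 or to f32 after every addition. It emulates the rounding with the standard-library `struct` module; `round_to` and `accumulate` are hypothetical helper names.

```python
import struct


def round_to(fmt, x):
    """Round a Python float to the nearest value representable in `fmt`.

    "<e" is IEEE half precision (f16), "<f" is IEEE single precision (f32).
    """
    return struct.unpack(fmt, struct.pack(fmt, x))[0]


def accumulate(values, acc_fmt):
    """Sum f16 inputs, rounding the running total to `acc_fmt` after each add."""
    total = 0.0
    for v in values:
        v16 = round_to("<e", v)                 # inputs are stored in f16
        total = round_to(acc_fmt, total + v16)  # accumulator precision varies
    return total


values = [1.0] * 4096  # e.g. one term per key for a sequence of length 4096

acc_f16 = accumulate(values, "<e")
acc_f32 = accumulate(values, "<f")

print("f16 accumulator:", acc_f16)  # 2048.0
print("f32 accumulator:", acc_f32)  # 4096.0
```

The f16 accumulator stalls at 2048: f16 has an 11-bit significand, so 2048 + 1 is not representable and rounds back to 2048. The f32 accumulator gives the exact sum. Real gradients are not all ones, but long reductions over the sequence dimension lose low-order bits in the same way when the running sum is kept in f16.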

<details>
<summary>bw P100/V100 (f32/f16)</summary>

```
[---------------------------------------- attention backward (attn_bias=<class 'NoneType'>) ----------------------------------------]                                                    
                                                         |  48_accf32_6f8e2f15  |  56_base_02bf6b4e  |  vanilla   |  57_tmpT_b516aec4
1 threads: --------------------------------------------------------------------------------------------------------------------------
  (Quadro_GP100)          f16 B=384, M=197, H=1, K=88    |         6289.6       |        6936.0      |    2183.6  |        6940.0    
                          f32 B=384, M=197, H=1, K=88    |         8793.3       |        9446.6      |    2175.1  |        9429.2    
                          f16 B=384, M=197, H=1, K=80    |         5989.9       |        6596.2      |    2146.8  |        6608.6    
                          f32 B=384, M=197, H=1, K=80    |         8427.1       |        8993.1      |    2134.0  |        9030.5    
                          f16 B=384, M=197, H=1, K=64    |         3347.0       |        3527.4      |    1799.3  |        3538.3    
                          f32 B=384, M=197, H=1, K=64    |         5563.1       |        5984.2      |    1801.6  |        5980.9    
                          f16 B=1024, M=197, H=1, K=88   |        15680.5       |       17424.8      |    5671.9  |       17452.6    
                          f32 B=1024, M=197, H=1, K=88   |        23784.2       |       25542.8      |    5664.0  |       25578.6    
                          f16 B=1024, M=197, H=1, K=80   |        14935.5       |       16581.3      |    5559.5  |       16587.3    
                          f32 B=1024, M=197, H=1, K=80   |        22767.6       |       24354.1      |    5550.7  |       24362.0    
                          f16 B=1024, M=197, H=1, K=64   |         8331.1       |        8671.5      |    4644.9  |        8695.3    
                          f32 B=1024, M=197, H=1, K=64   |        15061.6       |       16148.8      |    4650.0  |       16201.4    
                          f16 B=512, M=197, H=1, K=80    |         7594.7       |        8306.6      |    2824.6  |        8336.5    
                          f32 B=512, M=197, H=1, K=80    |        11857.0       |       12660.2      |    2807.4  |       12709.2    
                          f16 B=32, M=197, H=16, K=80    |         7779.8       |        8342.0      |    2820.8  |        8393.5    
                          f32 B=32, M=197, H=16, K=80    |        11846.5       |       12487.8      |    2806.0  |       12549.1    
                          f16 B=32, M=197, H=16, K=64    |         4258.4       |        4445.6      |    2374.0  |        4461.2    
                          f32 B=32, M=197, H=16, K=64    |         7828.9       |        8434.9      |    2376.0  |        8472.2    
                          f16 B=32, M=197, H=16, K=128   |         9025.1       |        9671.9      |    3159.3  |        9707.2    
                          f32 B=32, M=197, H=16, K=128   |        14139.8       |       14920.7      |    3157.7  |       14928.9    
                          f16 B=256, M=197, H=1, K=88    |         4608.9       |        5118.8      |    1478.4  |        5119.2    
                          f32 B=256, M=197, H=1, K=88    |         6174.0       |        6644.3      |    1477.6  |        6642.7    
                          f16 B=16, M=197, H=16, K=88    |         4618.4       |        5073.4      |    1479.5  |        5062.9    
                          f32 B=16, M=197, H=16, K=88    |         6114.2       |        6471.7      |    1474.9  |        6478.7    
                          f16 B=16, M=197, H=16, K=64    |         2490.3       |        2557.5      |    1225.9  |        2550.0    
                          f32 B=16, M=197, H=16, K=64    |         3918.9       |        4208.2      |    1227.0  |        4195.7    
                          f16 B=16, M=197, H=16, K=128   |         5210.0       |        5649.3      |    1635.2  |        5648.1    
                          f32 B=16, M=197, H=16, K=128   |         7103.1       |        7451.6      |    1645.5  |        7445.1    
                          f16 B=1, M=4096, H=160, K=128  |      1014229.1       |     1106182.6      |            |     1108040.6    
                          f32 B=1, M=4096, H=160, K=128  |      1258173.2       |     1243183.8      |            |     1241548.7    
                          f16 B=2, M=4096, H=160, K=128  |      1642279.2       |     1753736.9      |            |     1771655.3    
                          f32 B=2, M=4096, H=160, K=128  |      2505435.4       |     2477353.1      |            |     2473773.4    
                          f16 B=1, M=8192, H=160, K=128  |      4050128.4       |     4415962.1      |            |     4428194.9    
                          f32 B=1, M=8192, H=160, K=128  |      5042352.6       |     4970582.0      |            |     4965069.9    
                          f16 B=2, M=8192, H=160, K=128  |      6600732.3       |     7026378.5      |            |     7068051.8    
                          f16 B=1024, M=82, H=8, K=64    |        21572.8       |       22531.6      |    9059.1  |       22418.4    
                          f32 B=1024, M=82, H=8, K=64    |        38178.4       |       45927.2      |    9070.3  |       45708.0    
                          f16 B=150, M=256, H=16, K=64   |        21436.5       |       21927.4      |   12938.5  |       22001.0    
                          f32 B=150, M=256, H=16, K=64   |        33024.2       |       33196.3      |   13249.2  |       33199.6    
                          f16 B=64, M=256, H=12, K=64    |         6869.8       |        7048.6      |    4200.6  |        7073.6    
                          f32 B=64, M=256, H=12, K=64    |        10719.9       |       10832.1      |    4271.3  |       10843.6    
                          f16 B=1, M=4096, H=16, K=40    |       134722.6       |      145429.8      |   20587.0  |      143743.8    
                          f32 B=1, M=4096, H=16, K=40    |       143015.1       |      147850.6      |   20625.8  |      148272.8    
                          f16 B=1, M=16384, H=16, K=40   |      2149850.4       |     2323732.4      |            |     2301489.4    
                          f32 B=1, M=16384, H=16, K=40   |      2286478.1       |     2369812.4      |            |     2375911.8    
                          f16 B=16, M=128, H=16, K=16    |          497.2       |         502.9      |     623.8  |         503.6    
                          f32 B=16, M=128, H=16, K=16    |          573.5       |         609.8      |     617.5  |         611.6    
                          f16 B=16, M=128, H=16, K=32    |          563.9       |         573.1      |     624.2  |         575.2    
                          f32 B=16, M=128, H=16, K=32    |          661.5       |         702.2      |     620.4  |         703.2    
                          f16 B=16, M=128, H=16, K=64    |          708.9       |         722.6      |     619.9  |         724.3    
                          f32 B=16, M=128, H=16, K=64    |          916.2       |         953.5      |     618.6  |         953.6    
                          f16 B=16, M=128, H=16, K=128   |         1465.8       |        1542.8      |     624.2  |        1545.7    
                          f32 B=16, M=128, H=16, K=128   |         1829.1       |        1872.3      |     616.2  |        1873.9    
                          f16 B=16, M=128, H=16, K=256   |         3796.3       |        4002.5      |    1010.7  |        4008.4    
                          f32 B=16, M=128, H=16, K=256   |         3838.8       |        3957.1      |    1203.6  |        3951.8    
                          f16 B=16, M=512, H=16, K=16    |         7680.4       |        7775.1      |    4848.7  |        7752.8    
                          f32 B=16, M=512, H=16, K=16    |         9402.0       |        9926.1      |    4926.7  |        9914.5    
                          f16 B=16, M=512, H=16, K=32    |         8897.5       |        8999.0      |    5055.9  |        9003.9    
                          f32 B=16, M=512, H=16, K=32    |        10762.1       |       11320.9      |    5065.4  |       11341.8    
                          f16 B=16, M=512, H=16, K=64    |        10936.9       |       11214.8      |    5484.9  |       11254.4    
                          f32 B=16, M=512, H=16, K=64    |        15091.9       |       15210.3      |    5552.0  |       15196.3    
                          f16 B=16, M=512, H=16, K=128   |        23491.0       |       25234.8      |    7317.1  |       25300.3    
                          f32 B=16, M=512, H=16, K=128   |        30524.6       |       30320.0      |    7487.0  |       30200.6    
                          f16 B=16, M=512, H=16, K=256   |        50389.7       |       54525.2      |   14015.3  |       54931.1    
                          f32 B=16, M=512, H=16, K=256   |        62155.6       |       61046.4      |   14285.5  |       61081.1    
                          f16 B=16, M=1024, H=16, K=16   |        31289.9       |       31951.9      |   18778.4  |       32044.9    
                          f32 B=16, M=1024, H=16, K=16   |        37744.4       |       39586.6      |   18929.9  |       39739.9    
                          f16 B=16, M=1024, H=16, K=32   |        35770.8       |       36651.1      |   19620.2  |       36909.0    
                          f32 B=16, M=1024, H=16, K=32   |        43211.5       |       45544.5      |   19518.4  |       45664.8    
                          f16 B=16, M=1024, H=16, K=64   |        43865.1       |       44864.8      |   21286.0  |       45345.7    
                          f32 B=16, M=1024, H=16, K=64   |        60710.5       |       60966.9      |   21634.1  |       60921.9    
                          f16 B=16, M=1024, H=16, K=128  |        94633.7       |      101796.0      |   28502.1  |      102871.6    
                          f32 B=16, M=1024, H=16, K=128  |       124093.6       |      122196.9      |   28520.3  |      122043.0    
                          f16 B=16, M=1024, H=16, K=256  |       194780.7       |      212303.0      |   55419.1  |      214643.5    
                          f32 B=16, M=1024, H=16, K=256  |       250799.1       |      245196.2      |   55634.0  |      245534.2    
                          f16 B=64, M=128, H=16, K=16    |         1658.0       |        1661.4      |    1331.0  |        1662.7    
                          f32 B=64, M=128, H=16, K=16    |         2129.2       |        2266.2      |    1371.8  |        2269.4    
                          f16 B=64, M=128, H=16, K=32    |         1904.3       |        1903.5      |    1384.2  |        1906.4    
                          f32 B=64, M=128, H=16, K=32    |         2496.5       |        2643.4      |    1445.1  |        2639.8    
                          f16 B=64, M=128, H=16, K=64    |         2393.1       |        2432.3      |    1505.2  |        2437.8    
                          f32 B=64, M=128, H=16, K=64    |         3471.1       |        3561.2      |    1590.0  |        3565.7    
                          f16 B=64, M=128, H=16, K=128   |         4969.1       |        5266.3      |    1988.4  |        5272.4    
                          f32 B=64, M=128, H=16, K=128   |         6880.3       |        7040.6      |    2121.0  |        7024.3    
                          f16 B=64, M=128, H=16, K=256   |        12635.3       |       13334.7      |    3859.5  |       13350.7    
                          f32 B=64, M=128, H=16, K=256   |        14185.7       |       14514.4      |    4553.8  |       14503.1    
                          f16 B=64, M=512, H=16, K=16    |        26278.4       |       26189.9      |   18916.0  |       26110.0    
                          f32 B=64, M=512, H=16, K=16    |        35128.2       |       37075.9      |   19191.4  |       37174.5    
                          f16 B=64, M=512, H=16, K=32    |        30414.2       |       30938.2      |   19734.7  |       31071.7    
                          f32 B=64, M=512, H=16, K=32    |        40483.2       |       42922.5      |   19843.9  |       42868.5    
                          f16 B=64, M=512, H=16, K=64    |        37248.2       |       38179.7      |   21640.1  |       38327.0    
                          f32 B=64, M=512, H=16, K=64    |        57666.8       |       57675.3      |   21909.1  |       57697.2    
                          f16 B=64, M=512, H=16, K=128   |        80113.6       |       86165.7      |   28765.5  |       86368.3    
                          f32 B=64, M=512, H=16, K=128   |       115672.5       |      115161.0      |   28910.3  |      115320.0    
                          f16 B=64, M=512, H=16, K=256   |       169250.0       |      183791.8      |   56315.7  |      183735.0    
                          f32 B=64, M=512, H=16, K=256   |       236594.6       |      233093.6      |   56853.4  |      233170.2    
                          f16 B=64, M=1024, H=16, K=16   |       106022.3       |      109588.2      |   74303.6  |      109410.4    
                          f32 B=64, M=1024, H=16, K=16   |       141241.8       |      148854.1      |            |      149651.5    
                          f16 B=64, M=1024, H=16, K=32   |       120899.1       |      125044.6      |   77828.6  |      125716.9    
                          f32 B=64, M=1024, H=16, K=32   |       162478.5       |      173906.1      |            |      173216.4    
                          f16 B=64, M=1024, H=16, K=64   |       149044.1       |      152290.6      |   85821.6  |      152748.5    
                          f32 B=64, M=1024, H=16, K=64   |       233195.9       |      231479.6      |            |      231533.9    
                          f16 B=64, M=1024, H=16, K=128  |       319761.5       |      344076.5      |  113579.4  |      345058.5    
                          f32 B=64, M=1024, H=16, K=128  |       470172.3       |      466330.7      |            |      463507.4    
                          f16 B=64, M=1024, H=16, K=256  |       658362.1       |      717070.2      |            |      723057.4    
                          f32 B=64, M=1024, H=16, K=256  |       955624.0       |      935114.3      |            |      935945.4    
  (Tesla_V100_SXM2_16GB)  f16 B=384, M=197, H=1, K=88    |         1811.3       |        1686.3      |    1375.2  |        1699.6    
                          f32 B=384, M=197, H=1, K=88    |         4315.7       |        4665.1      |    2257.1  |        4663.1    
                          f16 B=384, M=197, H=1, K=80    |         1733.0       |        1616.3      |    1281.7  |        1619.8    
                          f32 B=384, M=197, H=1, K=80    |         3965.0       |        4226.4      |    2171.7  |        4228.3    
                          f16 B=384, M=197, H=1, K=64    |         1135.2       |        1084.4      |    1043.8  |        1083.0    
                          f32 B=384, M=197, H=1, K=64    |         2673.2       |        2883.0      |    1744.5  |        2878.2    
                          f16 B=1024, M=197, H=1, K=88   |         4721.0       |        4396.6      |    3725.3  |        4404.1    
                          f32 B=1024, M=197, H=1, K=88   |        10531.3       |       11443.8      |    6106.5  |       11464.2    
                          f16 B=1024, M=197, H=1, K=80   |         4520.2       |        4216.2      |    3329.2  |        4223.7    
                          f32 B=1024, M=197, H=1, K=80   |         9573.5       |       10301.4      |    5757.6  |       10305.3    
                          f16 B=1024, M=197, H=1, K=64   |         2788.1       |        2660.1      |    2674.6  |        2663.5    
                          f32 B=1024, M=197, H=1, K=64   |         6556.7       |        7102.4      |    4516.6  |        7096.8    
                          f16 B=512, M=197, H=1, K=80    |         2377.2       |        2228.1      |    1685.0  |        2231.4    
                          f32 B=512, M=197, H=1, K=80    |         5259.3       |        5639.7      |    2887.8  |        5636.6    
                          f16 B=32, M=197, H=16, K=80    |         2403.0       |        2201.1      |    1798.2  |        2204.9    
                          f32 B=32, M=197, H=16, K=80    |         5402.3       |        5667.7      |    3046.3  |        5662.4    
                          f16 B=32, M=197, H=16, K=64    |         1552.7       |        1486.3      |    1451.6  |        1485.2    
                          f32 B=32, M=197, H=16, K=64    |         3622.5       |        3911.0      |    2427.0  |        3912.9    
                          f16 B=32, M=197, H=16, K=128   |         2776.6       |        2611.9      |    2211.3  |        2613.3    
                          f32 B=32, M=197, H=16, K=128   |         6647.9       |        7082.3      |    4088.5  |        7104.8    
                          f16 B=256, M=197, H=1, K=88    |         1357.5       |        1285.5      |     941.3  |        1287.6    
                          f32 B=256, M=197, H=1, K=88    |         2874.1       |        3085.1      |    1543.2  |        3090.1    
                          f16 B=16, M=197, H=16, K=88    |         1349.1       |        1264.5      |     964.7  |        1263.1    
                          f32 B=16, M=197, H=16, K=88    |         2803.8       |        2967.0      |    1647.4  |        2972.0    
                          f16 B=16, M=197, H=16, K=64    |          765.5       |         728.4      |     765.3  |         731.0    
                          f32 B=16, M=197, H=16, K=64    |         1834.1       |        1969.7      |    1282.4  |        1974.4    
                          f16 B=16, M=197, H=16, K=128   |         1509.1       |        1432.6      |    1139.1  |        1433.9    
                          f32 B=16, M=197, H=16, K=128   |         3406.5       |        3606.2      |    2048.6  |        3613.3    
                          f16 B=1, M=4096, H=160, K=128  |       168807.1       |      148652.9      |            |      149343.8    
                          f32 B=1, M=4096, H=160, K=128  |       549864.6       |      586699.9      |            |      585699.9    
                          f16 B=2, M=4096, H=160, K=128  |       339010.5       |      298827.7      |            |      298808.4    
                          f32 B=2, M=4096, H=160, K=128  |      1106963.8       |     1176218.3      |            |     1179173.2    
                          f16 B=1, M=8192, H=160, K=128  |       679742.4       |      594323.2      |            |      595580.1    
                          f32 B=1, M=8192, H=160, K=128  |      2195491.3       |     2340248.4      |            |     2343505.7    
                          f16 B=2, M=8192, H=160, K=128  |      1364983.7       |     1193787.8      |            |     1192596.1    
                          f16 B=1024, M=82, H=8, K=64    |         9052.8       |        8762.6      |    5804.2  |        8757.8    
                          f32 B=1024, M=82, H=8, K=64    |        14726.3       |       16270.6      |   11059.7  |       16215.9    
                          f16 B=150, M=256, H=16, K=64   |         5662.9       |        5519.4      |    7557.8  |        5524.5    
                          f32 B=150, M=256, H=16, K=64   |        16700.3       |       17612.7      |   16426.0  |       17640.4    
                          f16 B=64, M=256, H=12, K=64    |         1849.1       |        1793.1      |    2383.5  |        1798.7    
                          f32 B=64, M=256, H=12, K=64    |         5451.5       |        5766.4      |    4975.3  |        5775.8    
                          f16 B=1, M=4096, H=16, K=40    |        47263.3       |       47850.4      |    8315.6  |       47777.2    
                          f32 B=1, M=4096, H=16, K=40    |       113099.4       |      113164.5      |   19536.0  |      113930.0    
                          f16 B=1, M=16384, H=16, K=40   |       757091.8       |      770365.7      |            |      765401.3    
                          f32 B=1, M=16384, H=16, K=40   |      1806827.7       |     1816302.6      |            |     1819162.1    
                          f16 B=16, M=128, H=16, K=16    |          219.2       |         218.5      |     480.6  |         231.8    
                          f32 B=16, M=128, H=16, K=16    |          301.2       |         308.2      |     498.9  |         307.9    
                          f16 B=16, M=128, H=16, K=32    |          227.6       |         215.4      |     473.8  |         220.1    
                          f32 B=16, M=128, H=16, K=32    |          395.8       |         401.0      |     455.1  |         400.8    
                          f16 B=16, M=128, H=16, K=64    |          225.6       |         214.8      |     510.1  |         229.9    
                          f32 B=16, M=128, H=16, K=64    |          561.7       |         583.1      |     598.9  |         581.2    
                          f16 B=16, M=128, H=16, K=128   |          404.6       |         392.0      |     524.0  |         394.8    
                          f32 B=16, M=128, H=16, K=128   |         1103.0       |        1140.8      |    1015.4  |        1142.0    
                          f16 B=16, M=128, H=16, K=256   |         1045.3       |        1049.1      |     889.6  |        1047.4    
                          f32 B=16, M=128, H=16, K=256   |         2181.3       |        2270.9      |    1869.2  |        2265.7    
                          f16 B=16, M=512, H=16, K=16    |         1731.5       |        1585.7      |    1908.5  |        1586.2    
                          f32 B=16, M=512, H=16, K=16    |         4513.8       |        4696.7      |    4222.1  |        4695.6    
                          f16 B=16, M=512, H=16, K=32    |         1942.2       |        1823.4      |    2086.3  |        1809.7    
                          f32 B=16, M=512, H=16, K=32    |         5596.1       |        5819.4      |    4588.1  |        5833.6    
                          f16 B=16, M=512, H=16, K=64    |         2450.2       |        2340.8      |    2580.6  |        2353.2    
                          f32 B=16, M=512, H=16, K=64    |         7619.5       |        7853.1      |    5536.8  |        7875.1    
                          f16 B=16, M=512, H=16, K=128   |         4884.8       |        4473.5      |    3388.7  |        4487.9    
                          f32 B=16, M=512, H=16, K=128   |        15010.0       |       15557.1      |    8979.4  |       15513.1    
                          f16 B=16, M=512, H=16, K=256   |        12973.9       |       11134.7      |    5418.2  |       11106.3    
                          f32 B=16, M=512, H=16, K=256   |        29751.1       |       30979.1      |   16856.3  |       31012.8    
                          f16 B=16, M=1024, H=16, K=16   |         6802.6       |        6185.1      |    6996.0  |        6191.0    
                          f32 B=16, M=1024, H=16, K=16   |        18098.2       |       18954.5      |   16129.1  |       19166.8    
                          f16 B=16, M=1024, H=16, K=32   |         7531.9       |        7065.0      |    7436.0  |        7067.5    
                          f32 B=16, M=1024, H=16, K=32   |        21999.9       |       22901.5      |   17040.0  |       22899.4    
                          f16 B=16, M=1024, H=16, K=64   |         9312.0       |        8854.3      |    8605.6  |        8863.9    
                          f32 B=16, M=1024, H=16, K=64   |        29837.7       |       30895.1      |   20355.3  |       30878.5    
                          f16 B=16, M=1024, H=16, K=128  |        18979.0       |       16951.0      |   10561.3  |       16995.2    
                          f32 B=16, M=1024, H=16, K=128  |        58738.3       |       60861.2      |   33427.0  |       60599.8    
                          f16 B=16, M=1024, H=16, K=256  |        49681.9       |       41833.3      |   17329.5  |       41921.6    
                          f32 B=16, M=1024, H=16, K=256  |       117362.4       |      121004.8      |   60515.8  |      122046.9    
                          f16 B=64, M=128, H=16, K=16    |          432.2       |         411.1      |     642.1  |         411.3    
                          f32 B=64, M=128, H=16, K=16    |         1028.9       |        1057.9      |    1233.4  |        1056.6    
                          f16 B=64, M=128, H=16, K=32    |          522.5       |         500.7      |     813.6  |         499.7    
                          f32 B=64, M=128, H=16, K=32    |         1403.5       |        1443.4      |    1535.6  |        1443.6    
                          f16 B=64, M=128, H=16, K=64    |          750.0       |         739.8      |    1185.6  |         741.0    
                          f32 B=64, M=128, H=16, K=64    |         2013.9       |        2110.1      |    2156.8  |        2105.4    
                          f16 B=64, M=128, H=16, K=128   |         1421.2       |        1387.9      |    1915.5  |        1388.3    
                          f32 B=64, M=128, H=16, K=128   |         3946.5       |        4156.7      |    3780.1  |        4156.3    
                          f16 B=64, M=128, H=16, K=256   |         3811.5       |        3810.6      |    3448.7  |        3811.0    
                          f32 B=64, M=128, H=16, K=256   |         7983.9       |        8432.7      |    7304.0  |        8415.5    
                          f16 B=64, M=512, H=16, K=16    |         6157.4       |        5523.9      |    7461.8  |        5528.1    
                          f32 B=64, M=512, H=16, K=16    |        16118.2       |       17004.7      |   16651.7  |       16984.5    
                          f16 B=64, M=512, H=16, K=32    |         6985.1       |        6483.1      |    8278.5  |        6495.1    
                          f32 B=64, M=512, H=16, K=32    |        20471.9       |       21230.7      |   18420.5  |       21269.0    
                          f16 B=64, M=512, H=16, K=64    |         9045.8       |        8633.5      |   10337.1  |        8674.1    
                          f32 B=64, M=512, H=16, K=64    |        27675.0       |       29282.8      |   22883.7  |       29136.9    
                          f16 B=64, M=512, H=16, K=128   |        17594.2       |       15805.9      |   14700.6  |       15788.2    
                          f32 B=64, M=512, H=16, K=128   |        54612.4       |       57815.7      |   39974.3  |       57951.8    
                          f16 B=64, M=512, H=16, K=256   |        47452.6       |       40093.5      |   27087.5  |       40175.6    
                          f32 B=64, M=512, H=16, K=256   |       108880.3       |      115953.1      |   77794.2  |      115951.0    
                          f16 B=64, M=1024, H=16, K=16   |        24369.0       |       21533.0      |   28448.6  |       21556.6    
                          f32 B=64, M=1024, H=16, K=16   |        64649.7       |       68791.8      |            |       68407.4    
                          f16 B=64, M=1024, H=16, K=32   |        27143.1       |       25683.7      |   30252.7  |       25727.9    
                          f32 B=64, M=1024, H=16, K=32   |        79967.5       |       83351.5      |            |       83084.7    
                          f16 B=64, M=1024, H=16, K=64   |        34667.0       |       32592.7      |   36991.2  |       32659.8    
                          f32 B=64, M=1024, H=16, K=64   |       108282.2       |      113858.1      |            |      114286.0    
                          f16 B=64, M=1024, H=16, K=128  |        68519.5       |       59757.3      |   48834.4  |       59817.7    
                          f32 B=64, M=1024, H=16, K=128  |       215465.9       |      227335.8      |            |      227204.0    
                          f16 B=64, M=1024, H=16, K=256  |       183070.4       |      150960.1      |            |      150947.1    
                          f32 B=64, M=1024, H=16, K=256  |       425832.9       |      453717.2      |            |      453349.2    

Times are in microseconds (us).
```
</details>
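
To read the f16 regression off the table above without eyeballing it, here is a minimal sketch that parses a few rows and prints the change of the `48_accf32` column relative to `56_base`. The three rows are copied verbatim from the Tesla V100 section. The helper names `parse_row` and `relative_change` are hypothetical and not part of this PR or of any benchmark tooling.

```python
def parse_row(line):
    """Split one benchmark row into (label, [time_us or None, ...])."""
    label, *cells = [part.strip() for part in line.split("|")]
    # An empty cell means that configuration has no timing in that column.
    times = [float(cell) if cell else None for cell in cells]
    return label, times


def relative_change(new, base):
    """Fractional change of `new` relative to `base` (positive = slower)."""
    return (new - base) / base


# Column order in the P100/V100 table: 48_accf32, 56_base, vanilla, 57_tmpT.
rows = [
    "f16 B=16, M=512, H=16, K=256   |        12973.9       |       11134.7      |    5418.2  |       11106.3    ",
    "f32 B=16, M=512, H=16, K=256   |        29751.1       |       30979.1      |   16856.3  |       31012.8    ",
    "f16 B=64, M=1024, H=16, K=256  |       183070.4       |      150960.1      |            |      150947.1    ",
]

for line in rows:
    label, times = parse_row(line)
    accf32, base = times[0], times[1]
    print(f"{label}: {relative_change(accf32, base):+.1%}")
```

On these rows the f16 entries come out slower with f32 accumulation, while the f32 entry is slightly faster, which matches the summary at the top of this description.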

<details>
<summary>bw A100 (f32/f16)</summary>

```
[----------------------------------------------------- attention backward (attn_bias=<class 'NoneType'>) -----------------------------------------------------]                                                       
                                     |  48_accf32_69654fdb[cutlass]  |  flash[flshatt]  |  vanilla   |  56_base_02bf6b4e[cutlass]  |  57_tmpT_b516aec4[cutlass]
1 threads: ----------------------------------------------------------------------------------------------------------------------------------------------------
      f16 B=384, M=197, H=1, K=88    |               613.4           |                  |    2264.9  |              612.8          |              578.3        
      f32 B=384, M=197, H=1, K=88    |              2335.2           |                  |    1843.0  |             2438.4          |             2425.1        
      f16 B=384, M=197, H=1, K=80    |               583.6           |                  |    1922.9  |              577.6          |              548.4        
      f32 B=384, M=197, H=1, K=80    |              2241.3           |                  |    1787.8  |             2333.2          |             2333.1        
      f16 B=384, M=197, H=1, K=64    |               405.6           |        232.5     |    1809.8  |              386.4          |              366.1        
      f32 B=384, M=197, H=1, K=64    |              1259.7           |                  |    1675.6  |             1309.6          |             1316.8        
      f16 B=1024, M=197, H=1, K=88   |              1538.2           |                  |    5964.7  |             1550.8          |             1454.4        
      f32 B=1024, M=197, H=1, K=88   |              6031.4           |                  |    4559.5  |             6325.4          |             6332.7        
      f16 B=1024, M=197, H=1, K=80   |              1458.8           |                  |    5038.2  |             1463.8          |             1379.0        
      f32 B=1024, M=197, H=1, K=80   |              5786.2           |                  |    4412.6  |             6059.0          |             6079.9        
      f16 B=1024, M=197, H=1, K=64   |               929.5           |        575.9     |    4735.6  |              862.1          |              821.6        
      f32 B=1024, M=197, H=1, K=64   |              3289.9           |                  |    4119.9  |             3434.4          |             3441.4        
      f16 B=512, M=197, H=1, K=80    |               744.5           |                  |    2544.7  |              735.3          |              695.8        
      f32 B=512, M=197, H=1, K=80    |              2889.4           |                  |    2286.1  |             3008.8          |             3029.0        
      f16 B=32, M=197, H=16, K=80    |               741.6           |                  |    2569.0  |              723.3          |              693.4        
      f32 B=32, M=197, H=16, K=80    |              2878.6           |                  |    2355.3  |             3003.0          |             3024.1        
      f16 B=32, M=197, H=16, K=64    |               478.5           |        295.7     |    2429.2  |              456.1          |              426.7        
      f32 B=32, M=197, H=16, K=64    |              1784.4           |                  |    2196.9  |             1863.7          |             1860.2        
      f16 B=32, M=197, H=16, K=128   |               887.0           |        682.3     |    4492.9  |              857.5          |              853.8        
      f32 B=32, M=197, H=16, K=128   |              3546.6           |                  |    2807.3  |             3734.9          |             3737.1        
      f16 B=256, M=197, H=1, K=88    |               445.1           |                  |    1528.9  |              445.0          |              422.2        
      f32 B=256, M=197, H=1, K=88    |              1678.8           |                  |    1207.6  |             1746.6          |             1752.8        
      f16 B=16, M=197, H=16, K=88    |               441.5           |                  |    1544.2  |              437.3          |              419.5        
      f32 B=16, M=197, H=16, K=88    |              1668.3           |                  |    1250.4  |             1742.7          |             1746.1        
      f16 B=16, M=197, H=16, K=64    |               247.4           |        165.6     |    1242.5  |              233.2          |              217.5        
      f32 B=16, M=197, H=16, K=64    |              1051.1           |                  |    1125.3  |             1099.2          |             1096.6        
      f16 B=16, M=197, H=16, K=128   |               498.4           |        386.2     |    2264.5  |              488.0          |              480.5        
      f32 B=16, M=197, H=16, K=128   |              1950.2           |                  |    1446.7  |             2039.0          |             2028.6        
      f16 B=1, M=4096, H=160, K=128  |             55915.0           |      54620.4     |   45909.5  |            63407.6          |            51298.8        
      f32 B=1, M=4096, H=160, K=128  |            238514.6           |                  |            |           232677.5          |           232672.5        
      f16 B=2, M=4096, H=160, K=128  |             93612.0           |      84238.0     |            |           100433.1          |            84858.1        
      f32 B=2, M=4096, H=160, K=128  |            375037.4           |                  |            |           364234.1          |           364407.2        
      f16 B=1, M=8192, H=160, K=128  |            223261.8           |     215499.8     |            |           251806.6          |           202133.7        
      f32 B=1, M=8192, H=160, K=128  |            946708.9           |                  |            |           924986.2          |           924988.9        
      f16 B=2, M=8192, H=160, K=128  |            367969.2           |     330092.8     |            |           395881.3          |           332000.6        
      f32 B=2, M=8192, H=160, K=128  |           1492031.4           |                  |            |          1448691.1          |          1449146.1        
      f16 B=1024, M=82, H=8, K=64    |              1890.2           |       1620.3     |    3819.7  |             1861.3          |             1764.6        
      f32 B=1024, M=82, H=8, K=64    |              8428.2           |                  |    8735.1  |             8831.3          |             8867.4        
      f16 B=150, M=256, H=16, K=64   |              2292.3           |       1625.3     |    4555.9  |             2109.8          |             2019.4        
      f32 B=150, M=256, H=16, K=64   |              6252.4           |                  |   12948.1  |             6288.9          |             6281.3        
      f16 B=64, M=256, H=12, K=64    |               782.2           |        567.4     |    1498.0  |              731.2          |              699.7        
      f32 B=64, M=256, H=12, K=64    |              2141.2           |                  |    4266.6  |             2160.6          |             2160.2        
      f16 B=1, M=4096, H=16, K=40    |             23504.2           |                  |    4196.1  |            23699.0          |            23008.9        
      f32 B=1, M=4096, H=16, K=40    |             73699.5           |                  |   17755.3  |            73261.8          |            73078.5        
      f16 B=1, M=16384, H=16, K=40   |            391408.9           |                  |            |           439777.3          |           407653.7        
      f32 B=1, M=16384, H=16, K=40   |           1196173.6           |                  |            |          1181547.9          |          1181625.1        
      f16 B=256, M=4096, H=16, K=64  |            733221.8           |     439627.5     |            |           603237.1          |           565905.8        
      f16 B=16, M=128, H=16, K=16    |               130.0           |        113.1     |     265.2  |              125.5          |              124.4        
      f32 B=16, M=128, H=16, K=16    |               161.5           |                  |     373.1  |              160.6          |              162.8        
      f16 B=16, M=128, H=16, K=32    |               125.8           |        111.5     |     263.8  |              122.1          |              125.8        
      f32 B=16, M=128, H=16, K=32    |               189.8           |                  |     412.6  |              196.3          |              196.2        
      f16 B=16, M=128, H=16, K=64    |               126.0           |        112.4     |     265.8  |              120.8          |              125.7        
      f32 B=16, M=128, H=16, K=64    |               272.1           |                  |     498.7  |              285.9          |              283.3        
      f16 B=16, M=128, H=16, K=128   |               181.3           |        158.4     |     298.5  |              178.0          |              186.5        
      f32 B=16, M=128, H=16, K=128   |               509.6           |                  |     673.9  |              521.1          |              521.5        
      f16 B=16, M=128, H=16, K=256   |               774.0           |                  |     541.4  |              775.5          |              757.0        
      f32 B=16, M=128, H=16, K=256   |               975.2           |                  |    1162.6  |              994.5          |              994.5        
      f16 B=16, M=512, H=16, K=16    |               621.0           |        322.6     |    1204.9  |              555.0          |              519.9        
      f32 B=16, M=512, H=16, K=16    |              2148.0           |                  |    4414.4  |             2178.9          |             2180.4        
      f16 B=16, M=512, H=16, K=32    |               709.9           |        435.6     |    1306.2  |              653.1          |              602.8        
      f32 B=16, M=512, H=16, K=32    |              2335.6           |                  |    4640.7  |             2336.0          |             2336.3        
      f16 B=16, M=512, H=16, K=64    |               917.8           |        702.7     |    1545.8  |              849.9          |              797.4        
      f32 B=16, M=512, H=16, K=64    |              2965.5           |                  |    5125.0  |             2986.9          |             2988.3        
      f16 B=16, M=512, H=16, K=128   |              1644.0           |       1584.1     |    1983.4  |             1757.0          |             1548.0        
      f32 B=16, M=512, H=16, K=128   |              6152.7           |                  |    6099.1  |             6067.7          |             6068.6        
      f16 B=16, M=512, H=16, K=256   |              8178.3           |                  |    2899.1  |             7895.5          |             7977.4        
      f32 B=16, M=512, H=16, K=256   |             11894.0           |                  |   10639.5  |            11635.0          |            11624.7        
      f16 B=16, M=1024, H=16, K=16   |              2420.5           |       1240.8     |    4259.2  |             2234.6          |             2048.9        
      f32 B=16, M=1024, H=16, K=16   |              8467.2           |                  |   16650.4  |             8512.2          |             8510.2        
      f16 B=16, M=1024, H=16, K=32   |              2675.7           |       1618.9     |    4491.2  |             2441.3          |             2230.0        
      f32 B=16, M=1024, H=16, K=32   |              9010.0           |                  |   17301.0  |             9012.1          |             9015.7        
      f16 B=16, M=1024, H=16, K=64   |              3328.4           |       2370.3     |    4994.5  |             3032.2          |             2820.0        
      f32 B=16, M=1024, H=16, K=64   |             11566.8           |                  |   18714.2  |            11494.1          |            11492.8        
      f16 B=16, M=1024, H=16, K=128  |              5867.9           |       5632.5     |    5952.8  |             6401.1          |             5440.8        
      f32 B=16, M=1024, H=16, K=128  |             23345.4           |                  |   21523.7  |            22859.1          |            22870.3        
      f16 B=16, M=1024, H=16, K=256  |             30619.2           |                  |    7893.1  |            29884.9          |            29060.4        
      f32 B=16, M=1024, H=16, K=256  |             45211.4           |                  |   38093.0  |            43435.3          |            43423.8        
      f16 B=64, M=128, H=16, K=16    |               159.6           |        145.2     |     439.9  |              161.2          |              167.0        
      f32 B=64, M=128, H=16, K=16    |               493.4           |                  |    1270.0  |              502.7          |              503.2        
      f16 B=64, M=128, H=16, K=32    |               208.5           |        212.1     |     545.3  |              206.2          |              204.9        
      f32 B=64, M=128, H=16, K=32    |               601.2           |                  |    1427.0  |              610.5          |              610.6        
      f16 B=64, M=128, H=16, K=64    |               329.0           |        310.8     |     766.0  |              327.2          |              314.7        
      f32 B=64, M=128, H=16, K=64    |               867.7           |                  |    1743.5  |              889.2          |              889.0        
      f16 B=64, M=128, H=16, K=128   |               635.5           |        562.1     |    1226.7  |              613.3          |              650.7        
      f32 B=64, M=128, H=16, K=128   |              1774.4           |                  |    2386.5  |             1800.0          |             1799.6        
      f16 B=64, M=128, H=16, K=256   |              2839.8           |                  |    2122.8  |             2765.2          |             2821.9        
      f32 B=64, M=128, H=16, K=256   |              3419.2           |                  |    4320.7  |             3458.8          |             3457.0        
      f16 B=64, M=512, H=16, K=16    |              2316.2           |       1202.3     |    4487.1  |             1983.9          |             1868.4        
      f32 B=64, M=512, H=16, K=16    |              6686.0           |                  |   16991.5  |             6709.9          |             6713.2        
      f16 B=64, M=512, H=16, K=32    |              2701.7           |       1541.9     |    4975.9  |             2346.4          |             2184.3        
      f32 B=64, M=512, H=16, K=32    |              7460.5           |                  |   17859.9  |             7429.6          |             7438.0        
      f16 B=64, M=512, H=16, K=64    |              3461.2           |       2418.2     |    5886.1  |             3083.1          |             2942.6        
      f32 B=64, M=512, H=16, K=64    |              9553.4           |                  |   19768.6  |             9526.5          |             9516.6        
      f16 B=64, M=512, H=16, K=128   |              5875.5           |       5443.1     |    7711.0  |             6141.3          |             5513.6        
      f32 B=64, M=512, H=16, K=128   |             21317.0           |                  |   23651.2  |            20931.9          |            20932.3        
      f16 B=64, M=512, H=16, K=256   |             31238.2           |                  |   11490.6  |            28626.8          |            28954.0        
      f32 B=64, M=512, H=16, K=256   |             41124.8           |                  |   42468.0  |            39562.8          |            39579.5        
      f16 B=64, M=1024, H=16, K=16   |              9142.8           |       4707.2     |   16882.8  |             7887.0          |             7314.5        
      f32 B=64, M=1024, H=16, K=16   |             26512.3           |                  |   66311.7  |            26497.5          |            26459.8        
      f16 B=64, M=1024, H=16, K=32   |             10420.8           |       5698.3     |   17875.0  |             8851.4          |             7974.0        
      f32 B=64, M=1024, H=16, K=32   |             28300.0           |                  |   69088.7  |            28102.1          |            28098.2        
      f16 B=64, M=1024, H=16, K=64   |             12948.5           |       8119.0     |   19944.3  |            11064.2          |            10477.6        
      f32 B=64, M=1024, H=16, K=64   |             35600.6           |                  |   74762.8  |            35316.4          |            35361.2        
      f16 B=64, M=1024, H=16, K=128  |             20820.5           |      19184.2     |   23699.3  |            21954.8          |            19220.0        
      f32 B=64, M=1024, H=16, K=128  |             80800.3           |                  |   86003.8  |            78521.9          |            78393.9        
      f16 B=64, M=1024, H=16, K=256  |            114411.1           |                  |   32958.3  |           103287.2          |           104304.2        
      f32 B=64, M=1024, H=16, K=256  |            155731.5           |                  |  153011.0  |           148071.6          |           148165.6        

Times are in microseconds (us).
```
</details>


[ghstack-poisoned]
danthe3rd pushed a commit that referenced this pull request Dec 6, 2022
ghstack-source-id: 165c9a375a81b8c14ee7490c6f590d2f68b23df0
Pull Request resolved: #467
**PERFORMANCE**

This makes performance worse in f16 :( (in the P100/V100 numbers below, the slowdown shows up mostly on V100)
But I think we need it for numerical stability
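
For context on the stability point, here is a minimal sketch of why the accumulator dtype matters. It is not the kernel code from this PR: it is pure Python that emulates f16/f32 rounding with `struct`, and the helper names (`to_f16`, `to_f32`, `dot_acc_f16`, `dot_acc_f32`) are made up for illustration. Inputs are f16 in both cases; only the dtype of the running sum changes.

```python
import struct


def to_f16(x: float) -> float:
    """Round a Python float to the nearest IEEE half-precision value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def to_f32(x: float) -> float:
    """Round a Python float to the nearest IEEE single-precision value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def dot_acc_f16(a, b):
    """Dot product of f16 inputs with the running sum kept in f16."""
    acc = 0.0
    for x, y in zip(a, b):
        # The sum is rounded back to f16 after every step.
        acc = to_f16(acc + to_f16(x) * to_f16(y))
    return acc


def dot_acc_f32(a, b):
    """Dot product of f16 inputs with the running sum kept in f32."""
    acc = 0.0
    for x, y in zip(a, b):
        # The sum is rounded to f32 after every step.
        acc = to_f32(acc + to_f16(x) * to_f16(y))
    # Only the final result is rounded down to f16.
    return to_f16(acc)


# Hypothetical reduction length, picked to match the M=4096 rows in the tables.
M = 4096
ones = [1.0] * M

print(dot_acc_f16(ones, ones))  # stalls once the f16 spacing exceeds 1.0
print(dot_acc_f32(ones, ones))  # reaches the exact sum
```

With an f16 accumulator the sum stops growing at 2048, because the spacing between adjacent f16 values there is 2 and adding 1.0 rounds back down. The f32 accumulator reaches 4096, which is itself representable in f16. The gradient reductions in the backward pass run over comparably long dimensions, which is presumably where the instability comes from.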

<details>
<summary>bw P100/V100 (f32/f16)</summary>

```
[---------------------------------------- attention backward (attn_bias=<class 'NoneType'>) ----------------------------------------]                                                    
                                                         |  48_accf32_6f8e2f15  |  56_base_02bf6b4e  |  vanilla   |  57_tmpT_b516aec4
1 threads: --------------------------------------------------------------------------------------------------------------------------
  (Quadro_GP100)          f16 B=384, M=197, H=1, K=88    |         6289.6       |        6936.0      |    2183.6  |        6940.0    
                          f32 B=384, M=197, H=1, K=88    |         8793.3       |        9446.6      |    2175.1  |        9429.2    
                          f16 B=384, M=197, H=1, K=80    |         5989.9       |        6596.2      |    2146.8  |        6608.6    
                          f32 B=384, M=197, H=1, K=80    |         8427.1       |        8993.1      |    2134.0  |        9030.5    
                          f16 B=384, M=197, H=1, K=64    |         3347.0       |        3527.4      |    1799.3  |        3538.3    
                          f32 B=384, M=197, H=1, K=64    |         5563.1       |        5984.2      |    1801.6  |        5980.9    
                          f16 B=1024, M=197, H=1, K=88   |        15680.5       |       17424.8      |    5671.9  |       17452.6    
                          f32 B=1024, M=197, H=1, K=88   |        23784.2       |       25542.8      |    5664.0  |       25578.6    
                          f16 B=1024, M=197, H=1, K=80   |        14935.5       |       16581.3      |    5559.5  |       16587.3    
                          f32 B=1024, M=197, H=1, K=80   |        22767.6       |       24354.1      |    5550.7  |       24362.0    
                          f16 B=1024, M=197, H=1, K=64   |         8331.1       |        8671.5      |    4644.9  |        8695.3    
                          f32 B=1024, M=197, H=1, K=64   |        15061.6       |       16148.8      |    4650.0  |       16201.4    
                          f16 B=512, M=197, H=1, K=80    |         7594.7       |        8306.6      |    2824.6  |        8336.5    
                          f32 B=512, M=197, H=1, K=80    |        11857.0       |       12660.2      |    2807.4  |       12709.2    
                          f16 B=32, M=197, H=16, K=80    |         7779.8       |        8342.0      |    2820.8  |        8393.5    
                          f32 B=32, M=197, H=16, K=80    |        11846.5       |       12487.8      |    2806.0  |       12549.1    
                          f16 B=32, M=197, H=16, K=64    |         4258.4       |        4445.6      |    2374.0  |        4461.2    
                          f32 B=32, M=197, H=16, K=64    |         7828.9       |        8434.9      |    2376.0  |        8472.2    
                          f16 B=32, M=197, H=16, K=128   |         9025.1       |        9671.9      |    3159.3  |        9707.2    
                          f32 B=32, M=197, H=16, K=128   |        14139.8       |       14920.7      |    3157.7  |       14928.9    
                          f16 B=256, M=197, H=1, K=88    |         4608.9       |        5118.8      |    1478.4  |        5119.2    
                          f32 B=256, M=197, H=1, K=88    |         6174.0       |        6644.3      |    1477.6  |        6642.7    
                          f16 B=16, M=197, H=16, K=88    |         4618.4       |        5073.4      |    1479.5  |        5062.9    
                          f32 B=16, M=197, H=16, K=88    |         6114.2       |        6471.7      |    1474.9  |        6478.7    
                          f16 B=16, M=197, H=16, K=64    |         2490.3       |        2557.5      |    1225.9  |        2550.0    
                          f32 B=16, M=197, H=16, K=64    |         3918.9       |        4208.2      |    1227.0  |        4195.7    
                          f16 B=16, M=197, H=16, K=128   |         5210.0       |        5649.3      |    1635.2  |        5648.1    
                          f32 B=16, M=197, H=16, K=128   |         7103.1       |        7451.6      |    1645.5  |        7445.1    
                          f16 B=1, M=4096, H=160, K=128  |      1014229.1       |     1106182.6      |            |     1108040.6    
                          f32 B=1, M=4096, H=160, K=128  |      1258173.2       |     1243183.8      |            |     1241548.7    
                          f16 B=2, M=4096, H=160, K=128  |      1642279.2       |     1753736.9      |            |     1771655.3    
                          f32 B=2, M=4096, H=160, K=128  |      2505435.4       |     2477353.1      |            |     2473773.4    
                          f16 B=1, M=8192, H=160, K=128  |      4050128.4       |     4415962.1      |            |     4428194.9    
                          f32 B=1, M=8192, H=160, K=128  |      5042352.6       |     4970582.0      |            |     4965069.9    
                          f16 B=2, M=8192, H=160, K=128  |      6600732.3       |     7026378.5      |            |     7068051.8    
                          f16 B=1024, M=82, H=8, K=64    |        21572.8       |       22531.6      |    9059.1  |       22418.4    
                          f32 B=1024, M=82, H=8, K=64    |        38178.4       |       45927.2      |    9070.3  |       45708.0    
                          f16 B=150, M=256, H=16, K=64   |        21436.5       |       21927.4      |   12938.5  |       22001.0    
                          f32 B=150, M=256, H=16, K=64   |        33024.2       |       33196.3      |   13249.2  |       33199.6    
                          f16 B=64, M=256, H=12, K=64    |         6869.8       |        7048.6      |    4200.6  |        7073.6    
                          f32 B=64, M=256, H=12, K=64    |        10719.9       |       10832.1      |    4271.3  |       10843.6    
                          f16 B=1, M=4096, H=16, K=40    |       134722.6       |      145429.8      |   20587.0  |      143743.8    
                          f32 B=1, M=4096, H=16, K=40    |       143015.1       |      147850.6      |   20625.8  |      148272.8    
                          f16 B=1, M=16384, H=16, K=40   |      2149850.4       |     2323732.4      |            |     2301489.4    
                          f32 B=1, M=16384, H=16, K=40   |      2286478.1       |     2369812.4      |            |     2375911.8    
                          f16 B=16, M=128, H=16, K=16    |          497.2       |         502.9      |     623.8  |         503.6    
                          f32 B=16, M=128, H=16, K=16    |          573.5       |         609.8      |     617.5  |         611.6    
                          f16 B=16, M=128, H=16, K=32    |          563.9       |         573.1      |     624.2  |         575.2    
                          f32 B=16, M=128, H=16, K=32    |          661.5       |         702.2      |     620.4  |         703.2    
                          f16 B=16, M=128, H=16, K=64    |          708.9       |         722.6      |     619.9  |         724.3    
                          f32 B=16, M=128, H=16, K=64    |          916.2       |         953.5      |     618.6  |         953.6    
                          f16 B=16, M=128, H=16, K=128   |         1465.8       |        1542.8      |     624.2  |        1545.7    
                          f32 B=16, M=128, H=16, K=128   |         1829.1       |        1872.3      |     616.2  |        1873.9    
                          f16 B=16, M=128, H=16, K=256   |         3796.3       |        4002.5      |    1010.7  |        4008.4    
                          f32 B=16, M=128, H=16, K=256   |         3838.8       |        3957.1      |    1203.6  |        3951.8    
                          f16 B=16, M=512, H=16, K=16    |         7680.4       |        7775.1      |    4848.7  |        7752.8    
                          f32 B=16, M=512, H=16, K=16    |         9402.0       |        9926.1      |    4926.7  |        9914.5    
                          f16 B=16, M=512, H=16, K=32    |         8897.5       |        8999.0      |    5055.9  |        9003.9    
                          f32 B=16, M=512, H=16, K=32    |        10762.1       |       11320.9      |    5065.4  |       11341.8    
                          f16 B=16, M=512, H=16, K=64    |        10936.9       |       11214.8      |    5484.9  |       11254.4    
                          f32 B=16, M=512, H=16, K=64    |        15091.9       |       15210.3      |    5552.0  |       15196.3    
                          f16 B=16, M=512, H=16, K=128   |        23491.0       |       25234.8      |    7317.1  |       25300.3    
                          f32 B=16, M=512, H=16, K=128   |        30524.6       |       30320.0      |    7487.0  |       30200.6    
                          f16 B=16, M=512, H=16, K=256   |        50389.7       |       54525.2      |   14015.3  |       54931.1    
                          f32 B=16, M=512, H=16, K=256   |        62155.6       |       61046.4      |   14285.5  |       61081.1    
                          f16 B=16, M=1024, H=16, K=16   |        31289.9       |       31951.9      |   18778.4  |       32044.9    
                          f32 B=16, M=1024, H=16, K=16   |        37744.4       |       39586.6      |   18929.9  |       39739.9    
                          f16 B=16, M=1024, H=16, K=32   |        35770.8       |       36651.1      |   19620.2  |       36909.0    
                          f32 B=16, M=1024, H=16, K=32   |        43211.5       |       45544.5      |   19518.4  |       45664.8    
                          f16 B=16, M=1024, H=16, K=64   |        43865.1       |       44864.8      |   21286.0  |       45345.7    
                          f32 B=16, M=1024, H=16, K=64   |        60710.5       |       60966.9      |   21634.1  |       60921.9    
                          f16 B=16, M=1024, H=16, K=128  |        94633.7       |      101796.0      |   28502.1  |      102871.6    
                          f32 B=16, M=1024, H=16, K=128  |       124093.6       |      122196.9      |   28520.3  |      122043.0    
                          f16 B=16, M=1024, H=16, K=256  |       194780.7       |      212303.0      |   55419.1  |      214643.5    
                          f32 B=16, M=1024, H=16, K=256  |       250799.1       |      245196.2      |   55634.0  |      245534.2    
                          f16 B=64, M=128, H=16, K=16    |         1658.0       |        1661.4      |    1331.0  |        1662.7    
                          f32 B=64, M=128, H=16, K=16    |         2129.2       |        2266.2      |    1371.8  |        2269.4    
                          f16 B=64, M=128, H=16, K=32    |         1904.3       |        1903.5      |    1384.2  |        1906.4    
                          f32 B=64, M=128, H=16, K=32    |         2496.5       |        2643.4      |    1445.1  |        2639.8    
                          f16 B=64, M=128, H=16, K=64    |         2393.1       |        2432.3      |    1505.2  |        2437.8    
                          f32 B=64, M=128, H=16, K=64    |         3471.1       |        3561.2      |    1590.0  |        3565.7    
                          f16 B=64, M=128, H=16, K=128   |         4969.1       |        5266.3      |    1988.4  |        5272.4    
                          f32 B=64, M=128, H=16, K=128   |         6880.3       |        7040.6      |    2121.0  |        7024.3    
                          f16 B=64, M=128, H=16, K=256   |        12635.3       |       13334.7      |    3859.5  |       13350.7    
                          f32 B=64, M=128, H=16, K=256   |        14185.7       |       14514.4      |    4553.8  |       14503.1    
                          f16 B=64, M=512, H=16, K=16    |        26278.4       |       26189.9      |   18916.0  |       26110.0    
                          f32 B=64, M=512, H=16, K=16    |        35128.2       |       37075.9      |   19191.4  |       37174.5    
                          f16 B=64, M=512, H=16, K=32    |        30414.2       |       30938.2      |   19734.7  |       31071.7    
                          f32 B=64, M=512, H=16, K=32    |        40483.2       |       42922.5      |   19843.9  |       42868.5    
                          f16 B=64, M=512, H=16, K=64    |        37248.2       |       38179.7      |   21640.1  |       38327.0    
                          f32 B=64, M=512, H=16, K=64    |        57666.8       |       57675.3      |   21909.1  |       57697.2    
                          f16 B=64, M=512, H=16, K=128   |        80113.6       |       86165.7      |   28765.5  |       86368.3    
                          f32 B=64, M=512, H=16, K=128   |       115672.5       |      115161.0      |   28910.3  |      115320.0    
                          f16 B=64, M=512, H=16, K=256   |       169250.0       |      183791.8      |   56315.7  |      183735.0    
                          f32 B=64, M=512, H=16, K=256   |       236594.6       |      233093.6      |   56853.4  |      233170.2    
                          f16 B=64, M=1024, H=16, K=16   |       106022.3       |      109588.2      |   74303.6  |      109410.4    
                          f32 B=64, M=1024, H=16, K=16   |       141241.8       |      148854.1      |            |      149651.5    
                          f16 B=64, M=1024, H=16, K=32   |       120899.1       |      125044.6      |   77828.6  |      125716.9    
                          f32 B=64, M=1024, H=16, K=32   |       162478.5       |      173906.1      |            |      173216.4    
                          f16 B=64, M=1024, H=16, K=64   |       149044.1       |      152290.6      |   85821.6  |      152748.5    
                          f32 B=64, M=1024, H=16, K=64   |       233195.9       |      231479.6      |            |      231533.9    
                          f16 B=64, M=1024, H=16, K=128  |       319761.5       |      344076.5      |  113579.4  |      345058.5    
                          f32 B=64, M=1024, H=16, K=128  |       470172.3       |      466330.7      |            |      463507.4    
                          f16 B=64, M=1024, H=16, K=256  |       658362.1       |      717070.2      |            |      723057.4    
                          f32 B=64, M=1024, H=16, K=256  |       955624.0       |      935114.3      |            |      935945.4    
  (Tesla_V100_SXM2_16GB)  f16 B=384, M=197, H=1, K=88    |         1811.3       |        1686.3      |    1375.2  |        1699.6    
                          f32 B=384, M=197, H=1, K=88    |         4315.7       |        4665.1      |    2257.1  |        4663.1    
                          f16 B=384, M=197, H=1, K=80    |         1733.0       |        1616.3      |    1281.7  |        1619.8    
                          f32 B=384, M=197, H=1, K=80    |         3965.0       |        4226.4      |    2171.7  |        4228.3    
                          f16 B=384, M=197, H=1, K=64    |         1135.2       |        1084.4      |    1043.8  |        1083.0    
                          f32 B=384, M=197, H=1, K=64    |         2673.2       |        2883.0      |    1744.5  |        2878.2    
                          f16 B=1024, M=197, H=1, K=88   |         4721.0       |        4396.6      |    3725.3  |        4404.1    
                          f32 B=1024, M=197, H=1, K=88   |        10531.3       |       11443.8      |    6106.5  |       11464.2    
                          f16 B=1024, M=197, H=1, K=80   |         4520.2       |        4216.2      |    3329.2  |        4223.7    
                          f32 B=1024, M=197, H=1, K=80   |         9573.5       |       10301.4      |    5757.6  |       10305.3    
                          f16 B=1024, M=197, H=1, K=64   |         2788.1       |        2660.1      |    2674.6  |        2663.5    
                          f32 B=1024, M=197, H=1, K=64   |         6556.7       |        7102.4      |    4516.6  |        7096.8    
                          f16 B=512, M=197, H=1, K=80    |         2377.2       |        2228.1      |    1685.0  |        2231.4    
                          f32 B=512, M=197, H=1, K=80    |         5259.3       |        5639.7      |    2887.8  |        5636.6    
                          f16 B=32, M=197, H=16, K=80    |         2403.0       |        2201.1      |    1798.2  |        2204.9    
                          f32 B=32, M=197, H=16, K=80    |         5402.3       |        5667.7      |    3046.3  |        5662.4    
                          f16 B=32, M=197, H=16, K=64    |         1552.7       |        1486.3      |    1451.6  |        1485.2    
                          f32 B=32, M=197, H=16, K=64    |         3622.5       |        3911.0      |    2427.0  |        3912.9    
                          f16 B=32, M=197, H=16, K=128   |         2776.6       |        2611.9      |    2211.3  |        2613.3    
                          f32 B=32, M=197, H=16, K=128   |         6647.9       |        7082.3      |    4088.5  |        7104.8    
                          f16 B=256, M=197, H=1, K=88    |         1357.5       |        1285.5      |     941.3  |        1287.6    
                          f32 B=256, M=197, H=1, K=88    |         2874.1       |        3085.1      |    1543.2  |        3090.1    
                          f16 B=16, M=197, H=16, K=88    |         1349.1       |        1264.5      |     964.7  |        1263.1    
                          f32 B=16, M=197, H=16, K=88    |         2803.8       |        2967.0      |    1647.4  |        2972.0    
                          f16 B=16, M=197, H=16, K=64    |          765.5       |         728.4      |     765.3  |         731.0    
                          f32 B=16, M=197, H=16, K=64    |         1834.1       |        1969.7      |    1282.4  |        1974.4    
                          f16 B=16, M=197, H=16, K=128   |         1509.1       |        1432.6      |    1139.1  |        1433.9    
                          f32 B=16, M=197, H=16, K=128   |         3406.5       |        3606.2      |    2048.6  |        3613.3    
                          f16 B=1, M=4096, H=160, K=128  |       168807.1       |      148652.9      |            |      149343.8    
                          f32 B=1, M=4096, H=160, K=128  |       549864.6       |      586699.9      |            |      585699.9    
                          f16 B=2, M=4096, H=160, K=128  |       339010.5       |      298827.7      |            |      298808.4    
                          f32 B=2, M=4096, H=160, K=128  |      1106963.8       |     1176218.3      |            |     1179173.2    
                          f16 B=1, M=8192, H=160, K=128  |       679742.4       |      594323.2      |            |      595580.1    
                          f32 B=1, M=8192, H=160, K=128  |      2195491.3       |     2340248.4      |            |     2343505.7    
                          f16 B=2, M=8192, H=160, K=128  |      1364983.7       |     1193787.8      |            |     1192596.1    
                          f16 B=1024, M=82, H=8, K=64    |         9052.8       |        8762.6      |    5804.2  |        8757.8    
                          f32 B=1024, M=82, H=8, K=64    |        14726.3       |       16270.6      |   11059.7  |       16215.9    
                          f16 B=150, M=256, H=16, K=64   |         5662.9       |        5519.4      |    7557.8  |        5524.5    
                          f32 B=150, M=256, H=16, K=64   |        16700.3       |       17612.7      |   16426.0  |       17640.4    
                          f16 B=64, M=256, H=12, K=64    |         1849.1       |        1793.1      |    2383.5  |        1798.7    
                          f32 B=64, M=256, H=12, K=64    |         5451.5       |        5766.4      |    4975.3  |        5775.8    
                          f16 B=1, M=4096, H=16, K=40    |        47263.3       |       47850.4      |    8315.6  |       47777.2    
                          f32 B=1, M=4096, H=16, K=40    |       113099.4       |      113164.5      |   19536.0  |      113930.0    
                          f16 B=1, M=16384, H=16, K=40   |       757091.8       |      770365.7      |            |      765401.3    
                          f32 B=1, M=16384, H=16, K=40   |      1806827.7       |     1816302.6      |            |     1819162.1    
                          f16 B=16, M=128, H=16, K=16    |          219.2       |         218.5      |     480.6  |         231.8    
                          f32 B=16, M=128, H=16, K=16    |          301.2       |         308.2      |     498.9  |         307.9    
                          f16 B=16, M=128, H=16, K=32    |          227.6       |         215.4      |     473.8  |         220.1    
                          f32 B=16, M=128, H=16, K=32    |          395.8       |         401.0      |     455.1  |         400.8    
                          f16 B=16, M=128, H=16, K=64    |          225.6       |         214.8      |     510.1  |         229.9    
                          f32 B=16, M=128, H=16, K=64    |          561.7       |         583.1      |     598.9  |         581.2    
                          f16 B=16, M=128, H=16, K=128   |          404.6       |         392.0      |     524.0  |         394.8    
                          f32 B=16, M=128, H=16, K=128   |         1103.0       |        1140.8      |    1015.4  |        1142.0    
                          f16 B=16, M=128, H=16, K=256   |         1045.3       |        1049.1      |     889.6  |        1047.4    
                          f32 B=16, M=128, H=16, K=256   |         2181.3       |        2270.9      |    1869.2  |        2265.7    
                          f16 B=16, M=512, H=16, K=16    |         1731.5       |        1585.7      |    1908.5  |        1586.2    
                          f32 B=16, M=512, H=16, K=16    |         4513.8       |        4696.7      |    4222.1  |        4695.6    
                          f16 B=16, M=512, H=16, K=32    |         1942.2       |        1823.4      |    2086.3  |        1809.7    
                          f32 B=16, M=512, H=16, K=32    |         5596.1       |        5819.4      |    4588.1  |        5833.6    
                          f16 B=16, M=512, H=16, K=64    |         2450.2       |        2340.8      |    2580.6  |        2353.2    
                          f32 B=16, M=512, H=16, K=64    |         7619.5       |        7853.1      |    5536.8  |        7875.1    
                          f16 B=16, M=512, H=16, K=128   |         4884.8       |        4473.5      |    3388.7  |        4487.9    
                          f32 B=16, M=512, H=16, K=128   |        15010.0       |       15557.1      |    8979.4  |       15513.1    
                          f16 B=16, M=512, H=16, K=256   |        12973.9       |       11134.7      |    5418.2  |       11106.3    
                          f32 B=16, M=512, H=16, K=256   |        29751.1       |       30979.1      |   16856.3  |       31012.8    
                          f16 B=16, M=1024, H=16, K=16   |         6802.6       |        6185.1      |    6996.0  |        6191.0    
                          f32 B=16, M=1024, H=16, K=16   |        18098.2       |       18954.5      |   16129.1  |       19166.8    
                          f16 B=16, M=1024, H=16, K=32   |         7531.9       |        7065.0      |    7436.0  |        7067.5    
                          f32 B=16, M=1024, H=16, K=32   |        21999.9       |       22901.5      |   17040.0  |       22899.4    
                          f16 B=16, M=1024, H=16, K=64   |         9312.0       |        8854.3      |    8605.6  |        8863.9    
                          f32 B=16, M=1024, H=16, K=64   |        29837.7       |       30895.1      |   20355.3  |       30878.5    
                          f16 B=16, M=1024, H=16, K=128  |        18979.0       |       16951.0      |   10561.3  |       16995.2    
                          f32 B=16, M=1024, H=16, K=128  |        58738.3       |       60861.2      |   33427.0  |       60599.8    
                          f16 B=16, M=1024, H=16, K=256  |        49681.9       |       41833.3      |   17329.5  |       41921.6    
                          f32 B=16, M=1024, H=16, K=256  |       117362.4       |      121004.8      |   60515.8  |      122046.9    
                          f16 B=64, M=128, H=16, K=16    |          432.2       |         411.1      |     642.1  |         411.3    
                          f32 B=64, M=128, H=16, K=16    |         1028.9       |        1057.9      |    1233.4  |        1056.6    
                          f16 B=64, M=128, H=16, K=32    |          522.5       |         500.7      |     813.6  |         499.7    
                          f32 B=64, M=128, H=16, K=32    |         1403.5       |        1443.4      |    1535.6  |        1443.6    
                          f16 B=64, M=128, H=16, K=64    |          750.0       |         739.8      |    1185.6  |         741.0    
                          f32 B=64, M=128, H=16, K=64    |         2013.9       |        2110.1      |    2156.8  |        2105.4    
                          f16 B=64, M=128, H=16, K=128   |         1421.2       |        1387.9      |    1915.5  |        1388.3    
                          f32 B=64, M=128, H=16, K=128   |         3946.5       |        4156.7      |    3780.1  |        4156.3    
                          f16 B=64, M=128, H=16, K=256   |         3811.5       |        3810.6      |    3448.7  |        3811.0    
                          f32 B=64, M=128, H=16, K=256   |         7983.9       |        8432.7      |    7304.0  |        8415.5    
                          f16 B=64, M=512, H=16, K=16    |         6157.4       |        5523.9      |    7461.8  |        5528.1    
                          f32 B=64, M=512, H=16, K=16    |        16118.2       |       17004.7      |   16651.7  |       16984.5    
                          f16 B=64, M=512, H=16, K=32    |         6985.1       |        6483.1      |    8278.5  |        6495.1    
                          f32 B=64, M=512, H=16, K=32    |        20471.9       |       21230.7      |   18420.5  |       21269.0    
                          f16 B=64, M=512, H=16, K=64    |         9045.8       |        8633.5      |   10337.1  |        8674.1    
                          f32 B=64, M=512, H=16, K=64    |        27675.0       |       29282.8      |   22883.7  |       29136.9    
                          f16 B=64, M=512, H=16, K=128   |        17594.2       |       15805.9      |   14700.6  |       15788.2    
                          f32 B=64, M=512, H=16, K=128   |        54612.4       |       57815.7      |   39974.3  |       57951.8    
                          f16 B=64, M=512, H=16, K=256   |        47452.6       |       40093.5      |   27087.5  |       40175.6    
                          f32 B=64, M=512, H=16, K=256   |       108880.3       |      115953.1      |   77794.2  |      115951.0    
                          f16 B=64, M=1024, H=16, K=16   |        24369.0       |       21533.0      |   28448.6  |       21556.6    
                          f32 B=64, M=1024, H=16, K=16   |        64649.7       |       68791.8      |            |       68407.4    
                          f16 B=64, M=1024, H=16, K=32   |        27143.1       |       25683.7      |   30252.7  |       25727.9    
                          f32 B=64, M=1024, H=16, K=32   |        79967.5       |       83351.5      |            |       83084.7    
                          f16 B=64, M=1024, H=16, K=64   |        34667.0       |       32592.7      |   36991.2  |       32659.8    
                          f32 B=64, M=1024, H=16, K=64   |       108282.2       |      113858.1      |            |      114286.0    
                          f16 B=64, M=1024, H=16, K=128  |        68519.5       |       59757.3      |   48834.4  |       59817.7    
                          f32 B=64, M=1024, H=16, K=128  |       215465.9       |      227335.8      |            |      227204.0    
                          f16 B=64, M=1024, H=16, K=256  |       183070.4       |      150960.1      |            |      150947.1    
                          f32 B=64, M=1024, H=16, K=256  |       425832.9       |      453717.2      |            |      453349.2    

Times are in microseconds (us).
```
</details>

<details>
<summary>bw A100 (f32/f16)</summary>

```
[----------------------------------------------------- attention backward (attn_bias=<class 'NoneType'>) -----------------------------------------------------]                                                       
                                     |  48_accf32_69654fdb[cutlass]  |  flash[flshatt]  |  vanilla   |  56_base_02bf6b4e[cutlass]  |  57_tmpT_b516aec4[cutlass]
1 threads: ----------------------------------------------------------------------------------------------------------------------------------------------------
      f16 B=384, M=197, H=1, K=88    |               613.4           |                  |    2264.9  |              612.8          |              578.3        
      f32 B=384, M=197, H=1, K=88    |              2335.2           |                  |    1843.0  |             2438.4          |             2425.1        
      f16 B=384, M=197, H=1, K=80    |               583.6           |                  |    1922.9  |              577.6          |              548.4        
      f32 B=384, M=197, H=1, K=80    |              2241.3           |                  |    1787.8  |             2333.2          |             2333.1        
      f16 B=384, M=197, H=1, K=64    |               405.6           |        232.5     |    1809.8  |              386.4          |              366.1        
      f32 B=384, M=197, H=1, K=64    |              1259.7           |                  |    1675.6  |             1309.6          |             1316.8        
      f16 B=1024, M=197, H=1, K=88   |              1538.2           |                  |    5964.7  |             1550.8          |             1454.4        
      f32 B=1024, M=197, H=1, K=88   |              6031.4           |                  |    4559.5  |             6325.4          |             6332.7        
      f16 B=1024, M=197, H=1, K=80   |              1458.8           |                  |    5038.2  |             1463.8          |             1379.0        
      f32 B=1024, M=197, H=1, K=80   |              5786.2           |                  |    4412.6  |             6059.0          |             6079.9        
      f16 B=1024, M=197, H=1, K=64   |               929.5           |        575.9     |    4735.6  |              862.1          |              821.6        
      f32 B=1024, M=197, H=1, K=64   |              3289.9           |                  |    4119.9  |             3434.4          |             3441.4        
      f16 B=512, M=197, H=1, K=80    |               744.5           |                  |    2544.7  |              735.3          |              695.8        
      f32 B=512, M=197, H=1, K=80    |              2889.4           |                  |    2286.1  |             3008.8          |             3029.0        
      f16 B=32, M=197, H=16, K=80    |               741.6           |                  |    2569.0  |              723.3          |              693.4        
      f32 B=32, M=197, H=16, K=80    |              2878.6           |                  |    2355.3  |             3003.0          |             3024.1        
      f16 B=32, M=197, H=16, K=64    |               478.5           |        295.7     |    2429.2  |              456.1          |              426.7        
      f32 B=32, M=197, H=16, K=64    |              1784.4           |                  |    2196.9  |             1863.7          |             1860.2        
      f16 B=32, M=197, H=16, K=128   |               887.0           |        682.3     |    4492.9  |              857.5          |              853.8        
      f32 B=32, M=197, H=16, K=128   |              3546.6           |                  |    2807.3  |             3734.9          |             3737.1        
      f16 B=256, M=197, H=1, K=88    |               445.1           |                  |    1528.9  |              445.0          |              422.2        
      f32 B=256, M=197, H=1, K=88    |              1678.8           |                  |    1207.6  |             1746.6          |             1752.8        
      f16 B=16, M=197, H=16, K=88    |               441.5           |                  |    1544.2  |              437.3          |              419.5        
      f32 B=16, M=197, H=16, K=88    |              1668.3           |                  |    1250.4  |             1742.7          |             1746.1        
      f16 B=16, M=197, H=16, K=64    |               247.4           |        165.6     |    1242.5  |              233.2          |              217.5        
      f32 B=16, M=197, H=16, K=64    |              1051.1           |                  |    1125.3  |             1099.2          |             1096.6        
      f16 B=16, M=197, H=16, K=128   |               498.4           |        386.2     |    2264.5  |              488.0          |              480.5        
      f32 B=16, M=197, H=16, K=128   |              1950.2           |                  |    1446.7  |             2039.0          |             2028.6        
      f16 B=1, M=4096, H=160, K=128  |             55915.0           |      54620.4     |   45909.5  |            63407.6          |            51298.8        
      f32 B=1, M=4096, H=160, K=128  |            238514.6           |                  |            |           232677.5          |           232672.5        
      f16 B=2, M=4096, H=160, K=128  |             93612.0           |      84238.0     |            |           100433.1          |            84858.1        
      f32 B=2, M=4096, H=160, K=128  |            375037.4           |                  |            |           364234.1          |           364407.2        
      f16 B=1, M=8192, H=160, K=128  |            223261.8           |     215499.8     |            |           251806.6          |           202133.7        
      f32 B=1, M=8192, H=160, K=128  |            946708.9           |                  |            |           924986.2          |           924988.9        
      f16 B=2, M=8192, H=160, K=128  |            367969.2           |     330092.8     |            |           395881.3          |           332000.6        
      f32 B=2, M=8192, H=160, K=128  |           1492031.4           |                  |            |          1448691.1          |          1449146.1        
      f16 B=1024, M=82, H=8, K=64    |              1890.2           |       1620.3     |    3819.7  |             1861.3          |             1764.6        
      f32 B=1024, M=82, H=8, K=64    |              8428.2           |                  |    8735.1  |             8831.3          |             8867.4        
      f16 B=150, M=256, H=16, K=64   |              2292.3           |       1625.3     |    4555.9  |             2109.8          |             2019.4        
      f32 B=150, M=256, H=16, K=64   |              6252.4           |                  |   12948.1  |             6288.9          |             6281.3        
      f16 B=64, M=256, H=12, K=64    |               782.2           |        567.4     |    1498.0  |              731.2          |              699.7        
      f32 B=64, M=256, H=12, K=64    |              2141.2           |                  |    4266.6  |             2160.6          |             2160.2        
      f16 B=1, M=4096, H=16, K=40    |             23504.2           |                  |    4196.1  |            23699.0          |            23008.9        
      f32 B=1, M=4096, H=16, K=40    |             73699.5           |                  |   17755.3  |            73261.8          |            73078.5        
      f16 B=1, M=16384, H=16, K=40   |            391408.9           |                  |            |           439777.3          |           407653.7        
      f32 B=1, M=16384, H=16, K=40   |           1196173.6           |                  |            |          1181547.9          |          1181625.1        
      f16 B=256, M=4096, H=16, K=64  |            733221.8           |     439627.5     |            |           603237.1          |           565905.8        
      f16 B=16, M=128, H=16, K=16    |               130.0           |        113.1     |     265.2  |              125.5          |              124.4        
      f32 B=16, M=128, H=16, K=16    |               161.5           |                  |     373.1  |              160.6          |              162.8        
      f16 B=16, M=128, H=16, K=32    |               125.8           |        111.5     |     263.8  |              122.1          |              125.8        
      f32 B=16, M=128, H=16, K=32    |               189.8           |                  |     412.6  |              196.3          |              196.2        
      f16 B=16, M=128, H=16, K=64    |               126.0           |        112.4     |     265.8  |              120.8          |              125.7        
      f32 B=16, M=128, H=16, K=64    |               272.1           |                  |     498.7  |              285.9          |              283.3        
      f16 B=16, M=128, H=16, K=128   |               181.3           |        158.4     |     298.5  |              178.0          |              186.5        
      f32 B=16, M=128, H=16, K=128   |               509.6           |                  |     673.9  |              521.1          |              521.5        
      f16 B=16, M=128, H=16, K=256   |               774.0           |                  |     541.4  |              775.5          |              757.0        
      f32 B=16, M=128, H=16, K=256   |               975.2           |                  |    1162.6  |              994.5          |              994.5        
      f16 B=16, M=512, H=16, K=16    |               621.0           |        322.6     |    1204.9  |              555.0          |              519.9        
      f32 B=16, M=512, H=16, K=16    |              2148.0           |                  |    4414.4  |             2178.9          |             2180.4        
      f16 B=16, M=512, H=16, K=32    |               709.9           |        435.6     |    1306.2  |              653.1          |              602.8        
      f32 B=16, M=512, H=16, K=32    |              2335.6           |                  |    4640.7  |             2336.0          |             2336.3        
      f16 B=16, M=512, H=16, K=64    |               917.8           |        702.7     |    1545.8  |              849.9          |              797.4        
      f32 B=16, M=512, H=16, K=64    |              2965.5           |                  |    5125.0  |             2986.9          |             2988.3        
      f16 B=16, M=512, H=16, K=128   |              1644.0           |       1584.1     |    1983.4  |             1757.0          |             1548.0        
      f32 B=16, M=512, H=16, K=128   |              6152.7           |                  |    6099.1  |             6067.7          |             6068.6        
      f16 B=16, M=512, H=16, K=256   |              8178.3           |                  |    2899.1  |             7895.5          |             7977.4        
      f32 B=16, M=512, H=16, K=256   |             11894.0           |                  |   10639.5  |            11635.0          |            11624.7        
      f16 B=16, M=1024, H=16, K=16   |              2420.5           |       1240.8     |    4259.2  |             2234.6          |             2048.9        
      f32 B=16, M=1024, H=16, K=16   |              8467.2           |                  |   16650.4  |             8512.2          |             8510.2        
      f16 B=16, M=1024, H=16, K=32   |              2675.7           |       1618.9     |    4491.2  |             2441.3          |             2230.0        
      f32 B=16, M=1024, H=16, K=32   |              9010.0           |                  |   17301.0  |             9012.1          |             9015.7        
      f16 B=16, M=1024, H=16, K=64   |              3328.4           |       2370.3     |    4994.5  |             3032.2          |             2820.0        
      f32 B=16, M=1024, H=16, K=64   |             11566.8           |                  |   18714.2  |            11494.1          |            11492.8        
      f16 B=16, M=1024, H=16, K=128  |              5867.9           |       5632.5     |    5952.8  |             6401.1          |             5440.8        
      f32 B=16, M=1024, H=16, K=128  |             23345.4           |                  |   21523.7  |            22859.1          |            22870.3        
      f16 B=16, M=1024, H=16, K=256  |             30619.2           |                  |    7893.1  |            29884.9          |            29060.4        
      f32 B=16, M=1024, H=16, K=256  |             45211.4           |                  |   38093.0  |            43435.3          |            43423.8        
      f16 B=64, M=128, H=16, K=16    |               159.6           |        145.2     |     439.9  |              161.2          |              167.0        
      f32 B=64, M=128, H=16, K=16    |               493.4           |                  |    1270.0  |              502.7          |              503.2        
      f16 B=64, M=128, H=16, K=32    |               208.5           |        212.1     |     545.3  |              206.2          |              204.9        
      f32 B=64, M=128, H=16, K=32    |               601.2           |                  |    1427.0  |              610.5          |              610.6        
      f16 B=64, M=128, H=16, K=64    |               329.0           |        310.8     |     766.0  |              327.2          |              314.7        
      f32 B=64, M=128, H=16, K=64    |               867.7           |                  |    1743.5  |              889.2          |              889.0        
      f16 B=64, M=128, H=16, K=128   |               635.5           |        562.1     |    1226.7  |              613.3          |              650.7        
      f32 B=64, M=128, H=16, K=128   |              1774.4           |                  |    2386.5  |             1800.0          |             1799.6        
      f16 B=64, M=128, H=16, K=256   |              2839.8           |                  |    2122.8  |             2765.2          |             2821.9        
      f32 B=64, M=128, H=16, K=256   |              3419.2           |                  |    4320.7  |             3458.8          |             3457.0        
      f16 B=64, M=512, H=16, K=16    |              2316.2           |       1202.3     |    4487.1  |             1983.9          |             1868.4        
      f32 B=64, M=512, H=16, K=16    |              6686.0           |                  |   16991.5  |             6709.9          |             6713.2        
      f16 B=64, M=512, H=16, K=32    |              2701.7           |       1541.9     |    4975.9  |             2346.4          |             2184.3        
      f32 B=64, M=512, H=16, K=32    |              7460.5           |                  |   17859.9  |             7429.6          |             7438.0        
      f16 B=64, M=512, H=16, K=64    |              3461.2           |       2418.2     |    5886.1  |             3083.1          |             2942.6        
      f32 B=64, M=512, H=16, K=64    |              9553.4           |                  |   19768.6  |             9526.5          |             9516.6        
      f16 B=64, M=512, H=16, K=128   |              5875.5           |       5443.1     |    7711.0  |             6141.3          |             5513.6        
      f32 B=64, M=512, H=16, K=128   |             21317.0           |                  |   23651.2  |            20931.9          |            20932.3        
      f16 B=64, M=512, H=16, K=256   |             31238.2           |                  |   11490.6  |            28626.8          |            28954.0        
      f32 B=64, M=512, H=16, K=256   |             41124.8           |                  |   42468.0  |            39562.8          |            39579.5        
      f16 B=64, M=1024, H=16, K=16   |              9142.8           |       4707.2     |   16882.8  |             7887.0          |             7314.5        
      f32 B=64, M=1024, H=16, K=16   |             26512.3           |                  |   66311.7  |            26497.5          |            26459.8        
      f16 B=64, M=1024, H=16, K=32   |             10420.8           |       5698.3     |   17875.0  |             8851.4          |             7974.0        
      f32 B=64, M=1024, H=16, K=32   |             28300.0           |                  |   69088.7  |            28102.1          |            28098.2        
      f16 B=64, M=1024, H=16, K=64   |             12948.5           |       8119.0     |   19944.3  |            11064.2          |            10477.6        
      f32 B=64, M=1024, H=16, K=64   |             35600.6           |                  |   74762.8  |            35316.4          |            35361.2        
      f16 B=64, M=1024, H=16, K=128  |             20820.5           |      19184.2     |   23699.3  |            21954.8          |            19220.0        
      f32 B=64, M=1024, H=16, K=128  |             80800.3           |                  |   86003.8  |            78521.9          |            78393.9        
      f16 B=64, M=1024, H=16, K=256  |            114411.1           |                  |   32958.3  |           103287.2          |           104304.2        
      f32 B=64, M=1024, H=16, K=256  |            155731.5           |                  |  153011.0  |           148071.6          |           148165.6        

Times are in microseconds (us).
```
</details>
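
To put rough numbers on the f16 regression, here is a small sketch that compares the `48_accf32` column against the `56_base` column for three rows copied from the A100 table above. The row selection and the `percent_change` helper are illustrative choices, not part of the benchmark script.

```python
# (label, accf32 time in us, base time in us), copied from the A100 table above
rows = [
    ("f16 B=16, M=512, H=16, K=16", 621.0, 555.0),
    ("f16 B=64, M=1024, H=16, K=64", 12948.5, 11064.2),
    ("f32 B=64, M=1024, H=16, K=64", 35600.6, 35316.4),
]


def percent_change(new_us, base_us):
    """Relative change of new_us versus base_us, in percent (positive = slower)."""
    return 100.0 * (new_us - base_us) / base_us


changes = {label: percent_change(new, base) for label, new, base in rows}
for label, pct in changes.items():
    print(f"{label}: {pct:+.1f}%")
```

For these rows the f16 cases slow down by roughly 12-17%, while the f32 case moves by less than 1%. Other rows differ, so this is a spot check rather than a summary of the whole table.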


[ghstack-poisoned]
danthe3rd pushed a commit that referenced this pull request Dec 6, 2022
ghstack-source-id: a7de549161fc563736f803189c103c8e5c7545e7
Pull Request resolved: #467
Contributor

@fmassa fmassa left a comment

Thanks and sorry for the delay!

Let's get this merged!

**PERFORMANCE**

This makes f16 performance worse :(
But I think we need it for numerical stability.
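
As a toy illustration of why the accumulator precision matters (this is not the kernel code, just a sketch of the general floating-point effect), the snippet below sums 4096 ones while rounding the running total to f16 or to f32 after every add. The helpers `round_to` and `accumulate` are hypothetical names; rounding is emulated with the standard-library `struct` formats `"e"` (IEEE half) and `"f"` (IEEE single).

```python
import struct


def round_to(fmt, x):
    """Round a Python float to the IEEE format named by fmt ('e' = f16, 'f' = f32)."""
    return struct.unpack(fmt, struct.pack(fmt, x))[0]


def accumulate(values, fmt):
    """Add values one by one, rounding the running sum to fmt after every add."""
    acc = 0.0
    for v in values:
        acc = round_to(fmt, acc + v)
    return acc


values = [1.0] * 4096  # every input is exactly representable in f16
sum_f16 = accumulate(values, "e")
sum_f32 = accumulate(values, "f")
print(sum_f16, sum_f32)  # prints 2048.0 4096.0
```

The f16 accumulator stalls at 2048 because representable f16 values are spaced 2 apart in [2048, 4096), so adding 1 rounds back to the same value. An f32 accumulator gets the exact result from the same f16 inputs.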

<details>
<summary>bw P100/V100 (f32/f16)</summary>

```
[---------------------------------------- attention backward (attn_bias=<class 'NoneType'>) ----------------------------------------]                                                    
                                                         |  48_accf32_6f8e2f15  |  56_base_02bf6b4e  |  vanilla   |  57_tmpT_b516aec4
1 threads: --------------------------------------------------------------------------------------------------------------------------
  (Quadro_GP100)          f16 B=384, M=197, H=1, K=88    |         6289.6       |        6936.0      |    2183.6  |        6940.0    
                          f32 B=384, M=197, H=1, K=88    |         8793.3       |        9446.6      |    2175.1  |        9429.2    
                          f16 B=384, M=197, H=1, K=80    |         5989.9       |        6596.2      |    2146.8  |        6608.6    
                          f32 B=384, M=197, H=1, K=80    |         8427.1       |        8993.1      |    2134.0  |        9030.5    
                          f16 B=384, M=197, H=1, K=64    |         3347.0       |        3527.4      |    1799.3  |        3538.3    
                          f32 B=384, M=197, H=1, K=64    |         5563.1       |        5984.2      |    1801.6  |        5980.9    
                          f16 B=1024, M=197, H=1, K=88   |        15680.5       |       17424.8      |    5671.9  |       17452.6    
                          f32 B=1024, M=197, H=1, K=88   |        23784.2       |       25542.8      |    5664.0  |       25578.6    
                          f16 B=1024, M=197, H=1, K=80   |        14935.5       |       16581.3      |    5559.5  |       16587.3    
                          f32 B=1024, M=197, H=1, K=80   |        22767.6       |       24354.1      |    5550.7  |       24362.0    
                          f16 B=1024, M=197, H=1, K=64   |         8331.1       |        8671.5      |    4644.9  |        8695.3    
                          f32 B=1024, M=197, H=1, K=64   |        15061.6       |       16148.8      |    4650.0  |       16201.4    
                          f16 B=512, M=197, H=1, K=80    |         7594.7       |        8306.6      |    2824.6  |        8336.5    
                          f32 B=512, M=197, H=1, K=80    |        11857.0       |       12660.2      |    2807.4  |       12709.2    
                          f16 B=32, M=197, H=16, K=80    |         7779.8       |        8342.0      |    2820.8  |        8393.5    
                          f32 B=32, M=197, H=16, K=80    |        11846.5       |       12487.8      |    2806.0  |       12549.1    
                          f16 B=32, M=197, H=16, K=64    |         4258.4       |        4445.6      |    2374.0  |        4461.2    
                          f32 B=32, M=197, H=16, K=64    |         7828.9       |        8434.9      |    2376.0  |        8472.2    
                          f16 B=32, M=197, H=16, K=128   |         9025.1       |        9671.9      |    3159.3  |        9707.2    
                          f32 B=32, M=197, H=16, K=128   |        14139.8       |       14920.7      |    3157.7  |       14928.9    
                          f16 B=256, M=197, H=1, K=88    |         4608.9       |        5118.8      |    1478.4  |        5119.2    
                          f32 B=256, M=197, H=1, K=88    |         6174.0       |        6644.3      |    1477.6  |        6642.7    
                          f16 B=16, M=197, H=16, K=88    |         4618.4       |        5073.4      |    1479.5  |        5062.9    
                          f32 B=16, M=197, H=16, K=88    |         6114.2       |        6471.7      |    1474.9  |        6478.7    
                          f16 B=16, M=197, H=16, K=64    |         2490.3       |        2557.5      |    1225.9  |        2550.0    
                          f32 B=16, M=197, H=16, K=64    |         3918.9       |        4208.2      |    1227.0  |        4195.7    
                          f16 B=16, M=197, H=16, K=128   |         5210.0       |        5649.3      |    1635.2  |        5648.1    
                          f32 B=16, M=197, H=16, K=128   |         7103.1       |        7451.6      |    1645.5  |        7445.1    
                          f16 B=1, M=4096, H=160, K=128  |      1014229.1       |     1106182.6      |            |     1108040.6    
                          f32 B=1, M=4096, H=160, K=128  |      1258173.2       |     1243183.8      |            |     1241548.7    
                          f16 B=2, M=4096, H=160, K=128  |      1642279.2       |     1753736.9      |            |     1771655.3    
                          f32 B=2, M=4096, H=160, K=128  |      2505435.4       |     2477353.1      |            |     2473773.4    
                          f16 B=1, M=8192, H=160, K=128  |      4050128.4       |     4415962.1      |            |     4428194.9    
                          f32 B=1, M=8192, H=160, K=128  |      5042352.6       |     4970582.0      |            |     4965069.9    
                          f16 B=2, M=8192, H=160, K=128  |      6600732.3       |     7026378.5      |            |     7068051.8    
                          f16 B=1024, M=82, H=8, K=64    |        21572.8       |       22531.6      |    9059.1  |       22418.4    
                          f32 B=1024, M=82, H=8, K=64    |        38178.4       |       45927.2      |    9070.3  |       45708.0    
                          f16 B=150, M=256, H=16, K=64   |        21436.5       |       21927.4      |   12938.5  |       22001.0    
                          f32 B=150, M=256, H=16, K=64   |        33024.2       |       33196.3      |   13249.2  |       33199.6    
                          f16 B=64, M=256, H=12, K=64    |         6869.8       |        7048.6      |    4200.6  |        7073.6    
                          f32 B=64, M=256, H=12, K=64    |        10719.9       |       10832.1      |    4271.3  |       10843.6    
                          f16 B=1, M=4096, H=16, K=40    |       134722.6       |      145429.8      |   20587.0  |      143743.8    
                          f32 B=1, M=4096, H=16, K=40    |       143015.1       |      147850.6      |   20625.8  |      148272.8    
                          f16 B=1, M=16384, H=16, K=40   |      2149850.4       |     2323732.4      |            |     2301489.4    
                          f32 B=1, M=16384, H=16, K=40   |      2286478.1       |     2369812.4      |            |     2375911.8    
                          f16 B=16, M=128, H=16, K=16    |          497.2       |         502.9      |     623.8  |         503.6    
                          f32 B=16, M=128, H=16, K=16    |          573.5       |         609.8      |     617.5  |         611.6    
                          f16 B=16, M=128, H=16, K=32    |          563.9       |         573.1      |     624.2  |         575.2    
                          f32 B=16, M=128, H=16, K=32    |          661.5       |         702.2      |     620.4  |         703.2    
                          f16 B=16, M=128, H=16, K=64    |          708.9       |         722.6      |     619.9  |         724.3    
                          f32 B=16, M=128, H=16, K=64    |          916.2       |         953.5      |     618.6  |         953.6    
                          f16 B=16, M=128, H=16, K=128   |         1465.8       |        1542.8      |     624.2  |        1545.7    
                          f32 B=16, M=128, H=16, K=128   |         1829.1       |        1872.3      |     616.2  |        1873.9    
                          f16 B=16, M=128, H=16, K=256   |         3796.3       |        4002.5      |    1010.7  |        4008.4    
                          f32 B=16, M=128, H=16, K=256   |         3838.8       |        3957.1      |    1203.6  |        3951.8    
                          f16 B=16, M=512, H=16, K=16    |         7680.4       |        7775.1      |    4848.7  |        7752.8    
                          f32 B=16, M=512, H=16, K=16    |         9402.0       |        9926.1      |    4926.7  |        9914.5    
                          f16 B=16, M=512, H=16, K=32    |         8897.5       |        8999.0      |    5055.9  |        9003.9    
                          f32 B=16, M=512, H=16, K=32    |        10762.1       |       11320.9      |    5065.4  |       11341.8    
                          f16 B=16, M=512, H=16, K=64    |        10936.9       |       11214.8      |    5484.9  |       11254.4    
                          f32 B=16, M=512, H=16, K=64    |        15091.9       |       15210.3      |    5552.0  |       15196.3    
                          f16 B=16, M=512, H=16, K=128   |        23491.0       |       25234.8      |    7317.1  |       25300.3    
                          f32 B=16, M=512, H=16, K=128   |        30524.6       |       30320.0      |    7487.0  |       30200.6    
                          f16 B=16, M=512, H=16, K=256   |        50389.7       |       54525.2      |   14015.3  |       54931.1    
                          f32 B=16, M=512, H=16, K=256   |        62155.6       |       61046.4      |   14285.5  |       61081.1    
                          f16 B=16, M=1024, H=16, K=16   |        31289.9       |       31951.9      |   18778.4  |       32044.9    
                          f32 B=16, M=1024, H=16, K=16   |        37744.4       |       39586.6      |   18929.9  |       39739.9    
                          f16 B=16, M=1024, H=16, K=32   |        35770.8       |       36651.1      |   19620.2  |       36909.0    
                          f32 B=16, M=1024, H=16, K=32   |        43211.5       |       45544.5      |   19518.4  |       45664.8    
                          f16 B=16, M=1024, H=16, K=64   |        43865.1       |       44864.8      |   21286.0  |       45345.7    
                          f32 B=16, M=1024, H=16, K=64   |        60710.5       |       60966.9      |   21634.1  |       60921.9    
                          f16 B=16, M=1024, H=16, K=128  |        94633.7       |      101796.0      |   28502.1  |      102871.6    
                          f32 B=16, M=1024, H=16, K=128  |       124093.6       |      122196.9      |   28520.3  |      122043.0    
                          f16 B=16, M=1024, H=16, K=256  |       194780.7       |      212303.0      |   55419.1  |      214643.5    
                          f32 B=16, M=1024, H=16, K=256  |       250799.1       |      245196.2      |   55634.0  |      245534.2    
                          f16 B=64, M=128, H=16, K=16    |         1658.0       |        1661.4      |    1331.0  |        1662.7    
                          f32 B=64, M=128, H=16, K=16    |         2129.2       |        2266.2      |    1371.8  |        2269.4    
                          f16 B=64, M=128, H=16, K=32    |         1904.3       |        1903.5      |    1384.2  |        1906.4    
                          f32 B=64, M=128, H=16, K=32    |         2496.5       |        2643.4      |    1445.1  |        2639.8    
                          f16 B=64, M=128, H=16, K=64    |         2393.1       |        2432.3      |    1505.2  |        2437.8    
                          f32 B=64, M=128, H=16, K=64    |         3471.1       |        3561.2      |    1590.0  |        3565.7    
                          f16 B=64, M=128, H=16, K=128   |         4969.1       |        5266.3      |    1988.4  |        5272.4    
                          f32 B=64, M=128, H=16, K=128   |         6880.3       |        7040.6      |    2121.0  |        7024.3    
                          f16 B=64, M=128, H=16, K=256   |        12635.3       |       13334.7      |    3859.5  |       13350.7    
                          f32 B=64, M=128, H=16, K=256   |        14185.7       |       14514.4      |    4553.8  |       14503.1    
                          f16 B=64, M=512, H=16, K=16    |        26278.4       |       26189.9      |   18916.0  |       26110.0    
                          f32 B=64, M=512, H=16, K=16    |        35128.2       |       37075.9      |   19191.4  |       37174.5    
                          f16 B=64, M=512, H=16, K=32    |        30414.2       |       30938.2      |   19734.7  |       31071.7    
                          f32 B=64, M=512, H=16, K=32    |        40483.2       |       42922.5      |   19843.9  |       42868.5    
                          f16 B=64, M=512, H=16, K=64    |        37248.2       |       38179.7      |   21640.1  |       38327.0    
                          f32 B=64, M=512, H=16, K=64    |        57666.8       |       57675.3      |   21909.1  |       57697.2    
                          f16 B=64, M=512, H=16, K=128   |        80113.6       |       86165.7      |   28765.5  |       86368.3    
                          f32 B=64, M=512, H=16, K=128   |       115672.5       |      115161.0      |   28910.3  |      115320.0    
                          f16 B=64, M=512, H=16, K=256   |       169250.0       |      183791.8      |   56315.7  |      183735.0    
                          f32 B=64, M=512, H=16, K=256   |       236594.6       |      233093.6      |   56853.4  |      233170.2    
                          f16 B=64, M=1024, H=16, K=16   |       106022.3       |      109588.2      |   74303.6  |      109410.4    
                          f32 B=64, M=1024, H=16, K=16   |       141241.8       |      148854.1      |            |      149651.5    
                          f16 B=64, M=1024, H=16, K=32   |       120899.1       |      125044.6      |   77828.6  |      125716.9    
                          f32 B=64, M=1024, H=16, K=32   |       162478.5       |      173906.1      |            |      173216.4    
                          f16 B=64, M=1024, H=16, K=64   |       149044.1       |      152290.6      |   85821.6  |      152748.5    
                          f32 B=64, M=1024, H=16, K=64   |       233195.9       |      231479.6      |            |      231533.9    
                          f16 B=64, M=1024, H=16, K=128  |       319761.5       |      344076.5      |  113579.4  |      345058.5    
                          f32 B=64, M=1024, H=16, K=128  |       470172.3       |      466330.7      |            |      463507.4    
                          f16 B=64, M=1024, H=16, K=256  |       658362.1       |      717070.2      |            |      723057.4    
                          f32 B=64, M=1024, H=16, K=256  |       955624.0       |      935114.3      |            |      935945.4    
  (Tesla_V100_SXM2_16GB)  f16 B=384, M=197, H=1, K=88    |         1811.3       |        1686.3      |    1375.2  |        1699.6    
                          f32 B=384, M=197, H=1, K=88    |         4315.7       |        4665.1      |    2257.1  |        4663.1    
                          f16 B=384, M=197, H=1, K=80    |         1733.0       |        1616.3      |    1281.7  |        1619.8    
                          f32 B=384, M=197, H=1, K=80    |         3965.0       |        4226.4      |    2171.7  |        4228.3    
                          f16 B=384, M=197, H=1, K=64    |         1135.2       |        1084.4      |    1043.8  |        1083.0    
                          f32 B=384, M=197, H=1, K=64    |         2673.2       |        2883.0      |    1744.5  |        2878.2    
                          f16 B=1024, M=197, H=1, K=88   |         4721.0       |        4396.6      |    3725.3  |        4404.1    
                          f32 B=1024, M=197, H=1, K=88   |        10531.3       |       11443.8      |    6106.5  |       11464.2    
                          f16 B=1024, M=197, H=1, K=80   |         4520.2       |        4216.2      |    3329.2  |        4223.7    
                          f32 B=1024, M=197, H=1, K=80   |         9573.5       |       10301.4      |    5757.6  |       10305.3    
                          f16 B=1024, M=197, H=1, K=64   |         2788.1       |        2660.1      |    2674.6  |        2663.5    
                          f32 B=1024, M=197, H=1, K=64   |         6556.7       |        7102.4      |    4516.6  |        7096.8    
                          f16 B=512, M=197, H=1, K=80    |         2377.2       |        2228.1      |    1685.0  |        2231.4    
                          f32 B=512, M=197, H=1, K=80    |         5259.3       |        5639.7      |    2887.8  |        5636.6    
                          f16 B=32, M=197, H=16, K=80    |         2403.0       |        2201.1      |    1798.2  |        2204.9    
                          f32 B=32, M=197, H=16, K=80    |         5402.3       |        5667.7      |    3046.3  |        5662.4    
                          f16 B=32, M=197, H=16, K=64    |         1552.7       |        1486.3      |    1451.6  |        1485.2    
                          f32 B=32, M=197, H=16, K=64    |         3622.5       |        3911.0      |    2427.0  |        3912.9    
                          f16 B=32, M=197, H=16, K=128   |         2776.6       |        2611.9      |    2211.3  |        2613.3    
                          f32 B=32, M=197, H=16, K=128   |         6647.9       |        7082.3      |    4088.5  |        7104.8    
                          f16 B=256, M=197, H=1, K=88    |         1357.5       |        1285.5      |     941.3  |        1287.6    
                          f32 B=256, M=197, H=1, K=88    |         2874.1       |        3085.1      |    1543.2  |        3090.1    
                          f16 B=16, M=197, H=16, K=88    |         1349.1       |        1264.5      |     964.7  |        1263.1    
                          f32 B=16, M=197, H=16, K=88    |         2803.8       |        2967.0      |    1647.4  |        2972.0    
                          f16 B=16, M=197, H=16, K=64    |          765.5       |         728.4      |     765.3  |         731.0    
                          f32 B=16, M=197, H=16, K=64    |         1834.1       |        1969.7      |    1282.4  |        1974.4    
                          f16 B=16, M=197, H=16, K=128   |         1509.1       |        1432.6      |    1139.1  |        1433.9    
                          f32 B=16, M=197, H=16, K=128   |         3406.5       |        3606.2      |    2048.6  |        3613.3    
                          f16 B=1, M=4096, H=160, K=128  |       168807.1       |      148652.9      |            |      149343.8    
                          f32 B=1, M=4096, H=160, K=128  |       549864.6       |      586699.9      |            |      585699.9    
                          f16 B=2, M=4096, H=160, K=128  |       339010.5       |      298827.7      |            |      298808.4    
                          f32 B=2, M=4096, H=160, K=128  |      1106963.8       |     1176218.3      |            |     1179173.2    
                          f16 B=1, M=8192, H=160, K=128  |       679742.4       |      594323.2      |            |      595580.1    
                          f32 B=1, M=8192, H=160, K=128  |      2195491.3       |     2340248.4      |            |     2343505.7    
                          f16 B=2, M=8192, H=160, K=128  |      1364983.7       |     1193787.8      |            |     1192596.1    
                          f16 B=1024, M=82, H=8, K=64    |         9052.8       |        8762.6      |    5804.2  |        8757.8    
                          f32 B=1024, M=82, H=8, K=64    |        14726.3       |       16270.6      |   11059.7  |       16215.9    
                          f16 B=150, M=256, H=16, K=64   |         5662.9       |        5519.4      |    7557.8  |        5524.5    
                          f32 B=150, M=256, H=16, K=64   |        16700.3       |       17612.7      |   16426.0  |       17640.4    
                          f16 B=64, M=256, H=12, K=64    |         1849.1       |        1793.1      |    2383.5  |        1798.7    
                          f32 B=64, M=256, H=12, K=64    |         5451.5       |        5766.4      |    4975.3  |        5775.8    
                          f16 B=1, M=4096, H=16, K=40    |        47263.3       |       47850.4      |    8315.6  |       47777.2    
                          f32 B=1, M=4096, H=16, K=40    |       113099.4       |      113164.5      |   19536.0  |      113930.0    
                          f16 B=1, M=16384, H=16, K=40   |       757091.8       |      770365.7      |            |      765401.3    
                          f32 B=1, M=16384, H=16, K=40   |      1806827.7       |     1816302.6      |            |     1819162.1    
                          f16 B=16, M=128, H=16, K=16    |          219.2       |         218.5      |     480.6  |         231.8    
                          f32 B=16, M=128, H=16, K=16    |          301.2       |         308.2      |     498.9  |         307.9    
                          f16 B=16, M=128, H=16, K=32    |          227.6       |         215.4      |     473.8  |         220.1    
                          f32 B=16, M=128, H=16, K=32    |          395.8       |         401.0      |     455.1  |         400.8    
                          f16 B=16, M=128, H=16, K=64    |          225.6       |         214.8      |     510.1  |         229.9    
                          f32 B=16, M=128, H=16, K=64    |          561.7       |         583.1      |     598.9  |         581.2    
                          f16 B=16, M=128, H=16, K=128   |          404.6       |         392.0      |     524.0  |         394.8    
                          f32 B=16, M=128, H=16, K=128   |         1103.0       |        1140.8      |    1015.4  |        1142.0    
                          f16 B=16, M=128, H=16, K=256   |         1045.3       |        1049.1      |     889.6  |        1047.4    
                          f32 B=16, M=128, H=16, K=256   |         2181.3       |        2270.9      |    1869.2  |        2265.7    
                          f16 B=16, M=512, H=16, K=16    |         1731.5       |        1585.7      |    1908.5  |        1586.2    
                          f32 B=16, M=512, H=16, K=16    |         4513.8       |        4696.7      |    4222.1  |        4695.6    
                          f16 B=16, M=512, H=16, K=32    |         1942.2       |        1823.4      |    2086.3  |        1809.7    
                          f32 B=16, M=512, H=16, K=32    |         5596.1       |        5819.4      |    4588.1  |        5833.6    
                          f16 B=16, M=512, H=16, K=64    |         2450.2       |        2340.8      |    2580.6  |        2353.2    
                          f32 B=16, M=512, H=16, K=64    |         7619.5       |        7853.1      |    5536.8  |        7875.1    
                          f16 B=16, M=512, H=16, K=128   |         4884.8       |        4473.5      |    3388.7  |        4487.9    
                          f32 B=16, M=512, H=16, K=128   |        15010.0       |       15557.1      |    8979.4  |       15513.1    
                          f16 B=16, M=512, H=16, K=256   |        12973.9       |       11134.7      |    5418.2  |       11106.3    
                          f32 B=16, M=512, H=16, K=256   |        29751.1       |       30979.1      |   16856.3  |       31012.8    
                          f16 B=16, M=1024, H=16, K=16   |         6802.6       |        6185.1      |    6996.0  |        6191.0    
                          f32 B=16, M=1024, H=16, K=16   |        18098.2       |       18954.5      |   16129.1  |       19166.8    
                          f16 B=16, M=1024, H=16, K=32   |         7531.9       |        7065.0      |    7436.0  |        7067.5    
                          f32 B=16, M=1024, H=16, K=32   |        21999.9       |       22901.5      |   17040.0  |       22899.4    
                          f16 B=16, M=1024, H=16, K=64   |         9312.0       |        8854.3      |    8605.6  |        8863.9    
                          f32 B=16, M=1024, H=16, K=64   |        29837.7       |       30895.1      |   20355.3  |       30878.5    
                          f16 B=16, M=1024, H=16, K=128  |        18979.0       |       16951.0      |   10561.3  |       16995.2    
                          f32 B=16, M=1024, H=16, K=128  |        58738.3       |       60861.2      |   33427.0  |       60599.8    
                          f16 B=16, M=1024, H=16, K=256  |        49681.9       |       41833.3      |   17329.5  |       41921.6    
                          f32 B=16, M=1024, H=16, K=256  |       117362.4       |      121004.8      |   60515.8  |      122046.9    
                          f16 B=64, M=128, H=16, K=16    |          432.2       |         411.1      |     642.1  |         411.3    
                          f32 B=64, M=128, H=16, K=16    |         1028.9       |        1057.9      |    1233.4  |        1056.6    
                          f16 B=64, M=128, H=16, K=32    |          522.5       |         500.7      |     813.6  |         499.7    
                          f32 B=64, M=128, H=16, K=32    |         1403.5       |        1443.4      |    1535.6  |        1443.6    
                          f16 B=64, M=128, H=16, K=64    |          750.0       |         739.8      |    1185.6  |         741.0    
                          f32 B=64, M=128, H=16, K=64    |         2013.9       |        2110.1      |    2156.8  |        2105.4    
                          f16 B=64, M=128, H=16, K=128   |         1421.2       |        1387.9      |    1915.5  |        1388.3    
                          f32 B=64, M=128, H=16, K=128   |         3946.5       |        4156.7      |    3780.1  |        4156.3    
                          f16 B=64, M=128, H=16, K=256   |         3811.5       |        3810.6      |    3448.7  |        3811.0    
                          f32 B=64, M=128, H=16, K=256   |         7983.9       |        8432.7      |    7304.0  |        8415.5    
                          f16 B=64, M=512, H=16, K=16    |         6157.4       |        5523.9      |    7461.8  |        5528.1    
                          f32 B=64, M=512, H=16, K=16    |        16118.2       |       17004.7      |   16651.7  |       16984.5    
                          f16 B=64, M=512, H=16, K=32    |         6985.1       |        6483.1      |    8278.5  |        6495.1    
                          f32 B=64, M=512, H=16, K=32    |        20471.9       |       21230.7      |   18420.5  |       21269.0    
                          f16 B=64, M=512, H=16, K=64    |         9045.8       |        8633.5      |   10337.1  |        8674.1    
                          f32 B=64, M=512, H=16, K=64    |        27675.0       |       29282.8      |   22883.7  |       29136.9    
                          f16 B=64, M=512, H=16, K=128   |        17594.2       |       15805.9      |   14700.6  |       15788.2    
                          f32 B=64, M=512, H=16, K=128   |        54612.4       |       57815.7      |   39974.3  |       57951.8    
                          f16 B=64, M=512, H=16, K=256   |        47452.6       |       40093.5      |   27087.5  |       40175.6    
                          f32 B=64, M=512, H=16, K=256   |       108880.3       |      115953.1      |   77794.2  |      115951.0    
                          f16 B=64, M=1024, H=16, K=16   |        24369.0       |       21533.0      |   28448.6  |       21556.6    
                          f32 B=64, M=1024, H=16, K=16   |        64649.7       |       68791.8      |            |       68407.4    
                          f16 B=64, M=1024, H=16, K=32   |        27143.1       |       25683.7      |   30252.7  |       25727.9    
                          f32 B=64, M=1024, H=16, K=32   |        79967.5       |       83351.5      |            |       83084.7    
                          f16 B=64, M=1024, H=16, K=64   |        34667.0       |       32592.7      |   36991.2  |       32659.8    
                          f32 B=64, M=1024, H=16, K=64   |       108282.2       |      113858.1      |            |      114286.0    
                          f16 B=64, M=1024, H=16, K=128  |        68519.5       |       59757.3      |   48834.4  |       59817.7    
                          f32 B=64, M=1024, H=16, K=128  |       215465.9       |      227335.8      |            |      227204.0    
                          f16 B=64, M=1024, H=16, K=256  |       183070.4       |      150960.1      |            |      150947.1    
                          f32 B=64, M=1024, H=16, K=256  |       425832.9       |      453717.2      |            |      453349.2    

Times are in microseconds (us).
```
</details>
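
To read the f16 regression off a table like the one above without eyeballing every row, a small helper can turn a row into a slowdown ratio. This is only a sketch: `parse_row` and `slowdown` are hypothetical helpers, not part of this PR, and they assume the P100/V100 column order shown in the table header (accf32, base, vanilla, tmpT). The two sample rows are copied from the V100 section.

```python
def parse_row(line):
    """Split one benchmark row into (label, [times]); empty cells become None."""
    cells = [c.strip() for c in line.split("|")]
    label, values = cells[0], cells[1:]
    return label, [float(v) if v else None for v in values]


def slowdown(line, new_col=0, base_col=1):
    """Return the new/base time ratio; a value above 1 means the new kernel is slower.

    Assumes column 0 is the accf32 build and column 1 is the base build.
    """
    _, times = parse_row(line)
    return times[new_col] / times[base_col]


# Two V100 rows copied verbatim from the table (accf32 | base | vanilla | tmpT).
row_f16 = "f16 B=16, M=512, H=16, K=256   |        12973.9       |       11134.7      |    5418.2  |       11106.3    "
row_f32 = "f32 B=16, M=512, H=16, K=256   |        29751.1       |       30979.1      |   16856.3  |       31012.8    "

print(f"f16: {slowdown(row_f16):.3f}x, f32: {slowdown(row_f32):.3f}x")
```

For this shape on V100, the f16 ratio is above 1 (slower with f32 accumulation) while the f32 ratio is slightly below 1. Rows with a blank `vanilla` cell parse to `None` in that position, so they can still be compared on the other columns.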

<details>
<summary>bw A100 (f32/f16)</summary>

```
[----------------------------------------------------- attention backward (attn_bias=<class 'NoneType'>) -----------------------------------------------------]                                                       
                                     |  48_accf32_69654fdb[cutlass]  |  flash[flshatt]  |  vanilla   |  56_base_02bf6b4e[cutlass]  |  57_tmpT_b516aec4[cutlass]
1 threads: ----------------------------------------------------------------------------------------------------------------------------------------------------
      f16 B=384, M=197, H=1, K=88    |               613.4           |                  |    2264.9  |              612.8          |              578.3        
      f32 B=384, M=197, H=1, K=88    |              2335.2           |                  |    1843.0  |             2438.4          |             2425.1        
      f16 B=384, M=197, H=1, K=80    |               583.6           |                  |    1922.9  |              577.6          |              548.4        
      f32 B=384, M=197, H=1, K=80    |              2241.3           |                  |    1787.8  |             2333.2          |             2333.1        
      f16 B=384, M=197, H=1, K=64    |               405.6           |        232.5     |    1809.8  |              386.4          |              366.1        
      f32 B=384, M=197, H=1, K=64    |              1259.7           |                  |    1675.6  |             1309.6          |             1316.8        
      f16 B=1024, M=197, H=1, K=88   |              1538.2           |                  |    5964.7  |             1550.8          |             1454.4        
      f32 B=1024, M=197, H=1, K=88   |              6031.4           |                  |    4559.5  |             6325.4          |             6332.7        
      f16 B=1024, M=197, H=1, K=80   |              1458.8           |                  |    5038.2  |             1463.8          |             1379.0        
      f32 B=1024, M=197, H=1, K=80   |              5786.2           |                  |    4412.6  |             6059.0          |             6079.9        
      f16 B=1024, M=197, H=1, K=64   |               929.5           |        575.9     |    4735.6  |              862.1          |              821.6        
      f32 B=1024, M=197, H=1, K=64   |              3289.9           |                  |    4119.9  |             3434.4          |             3441.4        
      f16 B=512, M=197, H=1, K=80    |               744.5           |                  |    2544.7  |              735.3          |              695.8        
      f32 B=512, M=197, H=1, K=80    |              2889.4           |                  |    2286.1  |             3008.8          |             3029.0        
      f16 B=32, M=197, H=16, K=80    |               741.6           |                  |    2569.0  |              723.3          |              693.4        
      f32 B=32, M=197, H=16, K=80    |              2878.6           |                  |    2355.3  |             3003.0          |             3024.1        
      f16 B=32, M=197, H=16, K=64    |               478.5           |        295.7     |    2429.2  |              456.1          |              426.7        
      f32 B=32, M=197, H=16, K=64    |              1784.4           |                  |    2196.9  |             1863.7          |             1860.2        
      f16 B=32, M=197, H=16, K=128   |               887.0           |        682.3     |    4492.9  |              857.5          |              853.8        
      f32 B=32, M=197, H=16, K=128   |              3546.6           |                  |    2807.3  |             3734.9          |             3737.1        
      f16 B=256, M=197, H=1, K=88    |               445.1           |                  |    1528.9  |              445.0          |              422.2        
      f32 B=256, M=197, H=1, K=88    |              1678.8           |                  |    1207.6  |             1746.6          |             1752.8        
      f16 B=16, M=197, H=16, K=88    |               441.5           |                  |    1544.2  |              437.3          |              419.5        
      f32 B=16, M=197, H=16, K=88    |              1668.3           |                  |    1250.4  |             1742.7          |             1746.1        
      f16 B=16, M=197, H=16, K=64    |               247.4           |        165.6     |    1242.5  |              233.2          |              217.5        
      f32 B=16, M=197, H=16, K=64    |              1051.1           |                  |    1125.3  |             1099.2          |             1096.6        
      f16 B=16, M=197, H=16, K=128   |               498.4           |        386.2     |    2264.5  |              488.0          |              480.5        
      f32 B=16, M=197, H=16, K=128   |              1950.2           |                  |    1446.7  |             2039.0          |             2028.6        
      f16 B=1, M=4096, H=160, K=128  |             55915.0           |      54620.4     |   45909.5  |            63407.6          |            51298.8        
      f32 B=1, M=4096, H=160, K=128  |            238514.6           |                  |            |           232677.5          |           232672.5        
      f16 B=2, M=4096, H=160, K=128  |             93612.0           |      84238.0     |            |           100433.1          |            84858.1        
      f32 B=2, M=4096, H=160, K=128  |            375037.4           |                  |            |           364234.1          |           364407.2        
      f16 B=1, M=8192, H=160, K=128  |            223261.8           |     215499.8     |            |           251806.6          |           202133.7        
      f32 B=1, M=8192, H=160, K=128  |            946708.9           |                  |            |           924986.2          |           924988.9        
      f16 B=2, M=8192, H=160, K=128  |            367969.2           |     330092.8     |            |           395881.3          |           332000.6        
      f32 B=2, M=8192, H=160, K=128  |           1492031.4           |                  |            |          1448691.1          |          1449146.1        
      f16 B=1024, M=82, H=8, K=64    |              1890.2           |       1620.3     |    3819.7  |             1861.3          |             1764.6        
      f32 B=1024, M=82, H=8, K=64    |              8428.2           |                  |    8735.1  |             8831.3          |             8867.4        
      f16 B=150, M=256, H=16, K=64   |              2292.3           |       1625.3     |    4555.9  |             2109.8          |             2019.4        
      f32 B=150, M=256, H=16, K=64   |              6252.4           |                  |   12948.1  |             6288.9          |             6281.3        
      f16 B=64, M=256, H=12, K=64    |               782.2           |        567.4     |    1498.0  |              731.2          |              699.7        
      f32 B=64, M=256, H=12, K=64    |              2141.2           |                  |    4266.6  |             2160.6          |             2160.2        
      f16 B=1, M=4096, H=16, K=40    |             23504.2           |                  |    4196.1  |            23699.0          |            23008.9        
      f32 B=1, M=4096, H=16, K=40    |             73699.5           |                  |   17755.3  |            73261.8          |            73078.5        
      f16 B=1, M=16384, H=16, K=40   |            391408.9           |                  |            |           439777.3          |           407653.7        
      f32 B=1, M=16384, H=16, K=40   |           1196173.6           |                  |            |          1181547.9          |          1181625.1        
      f16 B=256, M=4096, H=16, K=64  |            733221.8           |     439627.5     |            |           603237.1          |           565905.8        
      f16 B=16, M=128, H=16, K=16    |               130.0           |        113.1     |     265.2  |              125.5          |              124.4        
      f32 B=16, M=128, H=16, K=16    |               161.5           |                  |     373.1  |              160.6          |              162.8        
      f16 B=16, M=128, H=16, K=32    |               125.8           |        111.5     |     263.8  |              122.1          |              125.8        
      f32 B=16, M=128, H=16, K=32    |               189.8           |                  |     412.6  |              196.3          |              196.2        
      f16 B=16, M=128, H=16, K=64    |               126.0           |        112.4     |     265.8  |              120.8          |              125.7        
      f32 B=16, M=128, H=16, K=64    |               272.1           |                  |     498.7  |              285.9          |              283.3        
      f16 B=16, M=128, H=16, K=128   |               181.3           |        158.4     |     298.5  |              178.0          |              186.5        
      f32 B=16, M=128, H=16, K=128   |               509.6           |                  |     673.9  |              521.1          |              521.5        
      f16 B=16, M=128, H=16, K=256   |               774.0           |                  |     541.4  |              775.5          |              757.0        
      f32 B=16, M=128, H=16, K=256   |               975.2           |                  |    1162.6  |              994.5          |              994.5        
      f16 B=16, M=512, H=16, K=16    |               621.0           |        322.6     |    1204.9  |              555.0          |              519.9        
      f32 B=16, M=512, H=16, K=16    |              2148.0           |                  |    4414.4  |             2178.9          |             2180.4        
      f16 B=16, M=512, H=16, K=32    |               709.9           |        435.6     |    1306.2  |              653.1          |              602.8        
      f32 B=16, M=512, H=16, K=32    |              2335.6           |                  |    4640.7  |             2336.0          |             2336.3        
      f16 B=16, M=512, H=16, K=64    |               917.8           |        702.7     |    1545.8  |              849.9          |              797.4        
      f32 B=16, M=512, H=16, K=64    |              2965.5           |                  |    5125.0  |             2986.9          |             2988.3        
      f16 B=16, M=512, H=16, K=128   |              1644.0           |       1584.1     |    1983.4  |             1757.0          |             1548.0        
      f32 B=16, M=512, H=16, K=128   |              6152.7           |                  |    6099.1  |             6067.7          |             6068.6        
      f16 B=16, M=512, H=16, K=256   |              8178.3           |                  |    2899.1  |             7895.5          |             7977.4        
      f32 B=16, M=512, H=16, K=256   |             11894.0           |                  |   10639.5  |            11635.0          |            11624.7        
      f16 B=16, M=1024, H=16, K=16   |              2420.5           |       1240.8     |    4259.2  |             2234.6          |             2048.9        
      f32 B=16, M=1024, H=16, K=16   |              8467.2           |                  |   16650.4  |             8512.2          |             8510.2        
      f16 B=16, M=1024, H=16, K=32   |              2675.7           |       1618.9     |    4491.2  |             2441.3          |             2230.0        
      f32 B=16, M=1024, H=16, K=32   |              9010.0           |                  |   17301.0  |             9012.1          |             9015.7        
      f16 B=16, M=1024, H=16, K=64   |              3328.4           |       2370.3     |    4994.5  |             3032.2          |             2820.0        
      f32 B=16, M=1024, H=16, K=64   |             11566.8           |                  |   18714.2  |            11494.1          |            11492.8        
      f16 B=16, M=1024, H=16, K=128  |              5867.9           |       5632.5     |    5952.8  |             6401.1          |             5440.8        
      f32 B=16, M=1024, H=16, K=128  |             23345.4           |                  |   21523.7  |            22859.1          |            22870.3        
      f16 B=16, M=1024, H=16, K=256  |             30619.2           |                  |    7893.1  |            29884.9          |            29060.4        
      f32 B=16, M=1024, H=16, K=256  |             45211.4           |                  |   38093.0  |            43435.3          |            43423.8        
      f16 B=64, M=128, H=16, K=16    |               159.6           |        145.2     |     439.9  |              161.2          |              167.0        
      f32 B=64, M=128, H=16, K=16    |               493.4           |                  |    1270.0  |              502.7          |              503.2        
      f16 B=64, M=128, H=16, K=32    |               208.5           |        212.1     |     545.3  |              206.2          |              204.9        
      f32 B=64, M=128, H=16, K=32    |               601.2           |                  |    1427.0  |              610.5          |              610.6        
      f16 B=64, M=128, H=16, K=64    |               329.0           |        310.8     |     766.0  |              327.2          |              314.7        
      f32 B=64, M=128, H=16, K=64    |               867.7           |                  |    1743.5  |              889.2          |              889.0        
      f16 B=64, M=128, H=16, K=128   |               635.5           |        562.1     |    1226.7  |              613.3          |              650.7        
      f32 B=64, M=128, H=16, K=128   |              1774.4           |                  |    2386.5  |             1800.0          |             1799.6        
      f16 B=64, M=128, H=16, K=256   |              2839.8           |                  |    2122.8  |             2765.2          |             2821.9        
      f32 B=64, M=128, H=16, K=256   |              3419.2           |                  |    4320.7  |             3458.8          |             3457.0        
      f16 B=64, M=512, H=16, K=16    |              2316.2           |       1202.3     |    4487.1  |             1983.9          |             1868.4        
      f32 B=64, M=512, H=16, K=16    |              6686.0           |                  |   16991.5  |             6709.9          |             6713.2        
      f16 B=64, M=512, H=16, K=32    |              2701.7           |       1541.9     |    4975.9  |             2346.4          |             2184.3        
      f32 B=64, M=512, H=16, K=32    |              7460.5           |                  |   17859.9  |             7429.6          |             7438.0        
      f16 B=64, M=512, H=16, K=64    |              3461.2           |       2418.2     |    5886.1  |             3083.1          |             2942.6        
      f32 B=64, M=512, H=16, K=64    |              9553.4           |                  |   19768.6  |             9526.5          |             9516.6        
      f16 B=64, M=512, H=16, K=128   |              5875.5           |       5443.1     |    7711.0  |             6141.3          |             5513.6        
      f32 B=64, M=512, H=16, K=128   |             21317.0           |                  |   23651.2  |            20931.9          |            20932.3        
      f16 B=64, M=512, H=16, K=256   |             31238.2           |                  |   11490.6  |            28626.8          |            28954.0        
      f32 B=64, M=512, H=16, K=256   |             41124.8           |                  |   42468.0  |            39562.8          |            39579.5        
      f16 B=64, M=1024, H=16, K=16   |              9142.8           |       4707.2     |   16882.8  |             7887.0          |             7314.5        
      f32 B=64, M=1024, H=16, K=16   |             26512.3           |                  |   66311.7  |            26497.5          |            26459.8        
      f16 B=64, M=1024, H=16, K=32   |             10420.8           |       5698.3     |   17875.0  |             8851.4          |             7974.0        
      f32 B=64, M=1024, H=16, K=32   |             28300.0           |                  |   69088.7  |            28102.1          |            28098.2        
      f16 B=64, M=1024, H=16, K=64   |             12948.5           |       8119.0     |   19944.3  |            11064.2          |            10477.6        
      f32 B=64, M=1024, H=16, K=64   |             35600.6           |                  |   74762.8  |            35316.4          |            35361.2        
      f16 B=64, M=1024, H=16, K=128  |             20820.5           |      19184.2     |   23699.3  |            21954.8          |            19220.0        
      f32 B=64, M=1024, H=16, K=128  |             80800.3           |                  |   86003.8  |            78521.9          |            78393.9        
      f16 B=64, M=1024, H=16, K=256  |            114411.1           |                  |   32958.3  |           103287.2          |           104304.2        
      f32 B=64, M=1024, H=16, K=256  |            155731.5           |                  |  153011.0  |           148071.6          |           148165.6        

Times are in microseconds (us).
```
</details>


danthe3rd pushed a commit that referenced this pull request Dec 6, 2022
ghstack-source-id: f713589d43273c6785ba6e3ae92e0974ef8ccfba
Pull Request resolved: #467
@danthe3rd merged commit e331e5b into gh/danthe3rd/48/base on Dec 6, 2022
@danthe3rd deleted the gh/danthe3rd/48/head branch on December 6, 2022 14:00