Use cutlass for memory-efficient attention #362

Merged: 59 commits merged into main on Aug 25, 2022

Conversation

@fmassa (Contributor) commented Aug 10, 2022

What does this PR do?

This (massive) PR adds a number of improvements to memory-efficient attention that have been developed over the last few months.
It contains:

  • fp16 / fp32 implementations based on CUTLASS, which support A100 / V100 / P100 GPUs, for both forward and backward
  • attention bias and dropout for the original fp32 implementation
  • dispatching to FlashAttention for the cases supported by FlashAttention (a short usage sketch follows below)
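
For reference, a minimal usage sketch of the new op. Shapes and the dropout value are illustrative; the call mirrors the benchmark script quoted later in this thread, and the causal mask uses the xformers.ops.LowerTriangularMask type exercised in the tests.

```python
import torch
import xformers.ops

# Shapes follow the benchmark convention in this PR: (B, M, K)
B, M, K = 32, 512, 64
q = torch.randn(B, M, K, device="cuda", dtype=torch.half)
k = torch.randn(B, M, K, device="cuda", dtype=torch.half)
v = torch.randn(B, M, K, device="cuda", dtype=torch.half)

# Default dispatch picks a suitable kernel (FlashAttention where supported,
# otherwise the CUTLASS-based one).
out = xformers.ops.memory_efficient_attention(q, k, v)

# Causal masking and dropout, as exercised by the benchmarks below.
out = xformers.ops.memory_efficient_attention(
    q, k, v, attn_bias=xformers.ops.LowerTriangularMask(), p=0.1
)
```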

A100

For the configurations below and the forward pass on fp16, the CUTLASS-based kernels are on average 31% faster than vanilla PyTorch (10% faster at the median), and 5% slower than FlashAttention on average (1% faster at the median).

For the backward pass, there is still room for improvement for the CUTLASS-based kernels: they are 15% slower than vanilla PyTorch on average (7% slower at the median), and 55% slower than FlashAttention on both average and median.

A detailed breakdown can be found below.

CUTLASS-based kernels
[---------- attention (attn_bias=<class 'NoneType'>) ----------]
                                        |  optimized  |  vanilla
1 threads: -----------------------------------------------------
      f16 fwd_gen B=32, M=128, K=16     |      30.8   |     63.5
      f32 fwd_gen B=32, M=128, K=16     |      30.7   |     55.8
      f16 fwd_gen B=32, M=128, K=32     |      30.8   |     63.6
      f32 fwd_gen B=32, M=128, K=32     |      30.6   |     55.1
      f16 fwd_gen B=32, M=128, K=64     |      30.9   |     63.4
      f32 fwd_gen B=32, M=128, K=64     |      30.5   |     58.5
      f16 fwd_gen B=32, M=128, K=128    |      30.9   |     63.1
      f32 fwd_gen B=32, M=128, K=128    |      31.5   |     70.6
      f16 fwd_gen B=32, M=512, K=16     |      66.0   |     70.8
      f32 fwd_gen B=32, M=512, K=16     |     183.1   |    285.5
      f16 fwd_gen B=32, M=512, K=32     |      66.7   |     71.7
      f32 fwd_gen B=32, M=512, K=32     |     185.1   |    302.7
      f16 fwd_gen B=32, M=512, K=64     |      71.6   |     76.3
      f32 fwd_gen B=32, M=512, K=64     |     209.3   |    333.3
      f16 fwd_gen B=32, M=512, K=128    |      84.6   |     88.6
      f32 fwd_gen B=32, M=512, K=128    |     256.0   |    395.2
      f16 fwd_gen B=32, M=1024, K=16    |     236.0   |    259.6
      f32 fwd_gen B=32, M=1024, K=16    |     649.6   |    886.8
      f16 fwd_gen B=32, M=1024, K=32    |     238.3   |    260.5
      f32 fwd_gen B=32, M=1024, K=32    |     654.4   |    945.1
      f16 fwd_gen B=32, M=1024, K=64    |     254.6   |    271.0
      f32 fwd_gen B=32, M=1024, K=64    |     739.2   |   1062.5
      f16 fwd_gen B=32, M=1024, K=128   |     297.5   |    302.9
      f32 fwd_gen B=32, M=1024, K=128   |     909.3   |   1289.5
      f16 fwd_gen B=256, M=128, K=16    |      37.8   |     64.0
      f32 fwd_gen B=256, M=128, K=16    |      92.3   |    126.8
      f16 fwd_gen B=256, M=128, K=32    |      38.9   |     63.6
      f32 fwd_gen B=256, M=128, K=32    |      93.9   |    136.0
      f16 fwd_gen B=256, M=128, K=64    |      42.6   |     63.7
      f32 fwd_gen B=256, M=128, K=64    |     105.4   |    157.1
      f16 fwd_gen B=256, M=128, K=128   |      53.4   |     63.5
      f32 fwd_gen B=256, M=128, K=128   |     131.5   |    215.3
      f16 fwd_gen B=256, M=512, K=16    |     423.1   |    480.8
      f32 fwd_gen B=256, M=512, K=16    |    1221.0   |   1570.4
      f16 fwd_gen B=256, M=512, K=32    |     429.2   |    501.9
      f32 fwd_gen B=256, M=512, K=32    |    1242.6   |   1685.7
      f16 fwd_gen B=256, M=512, K=64    |     474.4   |    545.6
      f32 fwd_gen B=256, M=512, K=64    |    1414.5   |   1929.8
      f16 fwd_gen B=256, M=512, K=128   |     572.1   |    622.0
      f32 fwd_gen B=256, M=512, K=128   |    1727.6   |   2423.7
      f16 fwd_gen B=256, M=1024, K=16   |    1619.3   |   1787.7
      f32 fwd_gen B=256, M=1024, K=16   |    4838.7   |   5997.2
      f16 fwd_gen B=256, M=1024, K=32   |    1641.3   |   1821.6
      f32 fwd_gen B=256, M=1024, K=32   |    4886.2   |   6414.1
      f16 fwd_gen B=256, M=1024, K=64   |    1780.2   |   1924.6
      f32 fwd_gen B=256, M=1024, K=64   |    5525.0   |   7355.2
      f16 fwd_gen B=256, M=1024, K=128  |    2121.6   |   2158.1
      f32 fwd_gen B=256, M=1024, K=128  |    6805.0   |   9216.5

Times are in microseconds (us).

[----- attention backward (attn_bias=<class 'NoneType'>) ------]
                                        |  optimized  |  vanilla
1 threads: -----------------------------------------------------
      f16 fwd_gen B=32, M=128, K=16     |     133.4   |    169.4
      f32 fwd_gen B=32, M=128, K=16     |     101.5   |    152.2
      f16 fwd_gen B=32, M=128, K=32     |     132.0   |    168.2
      f32 fwd_gen B=32, M=128, K=32     |     100.9   |    152.9
      f16 fwd_gen B=32, M=128, K=64     |     133.9   |    169.1
      f32 fwd_gen B=32, M=128, K=64     |     100.4   |    152.2
      f16 fwd_gen B=32, M=128, K=128    |     133.5   |    168.3
      f32 fwd_gen B=32, M=128, K=128    |     144.4   |    152.0
      f16 fwd_gen B=32, M=512, K=16     |     427.5   |    167.5
      f32 fwd_gen B=32, M=512, K=16     |     890.0   |    795.7
      f16 fwd_gen B=32, M=512, K=32     |     469.6   |    168.0
      f32 fwd_gen B=32, M=512, K=32     |     921.8   |    819.8
      f16 fwd_gen B=32, M=512, K=64     |     539.3   |    183.1
      f32 fwd_gen B=32, M=512, K=64     |    1110.2   |    859.2
      f16 fwd_gen B=32, M=512, K=128    |     953.4   |    225.7
      f32 fwd_gen B=32, M=512, K=128    |    2195.1   |    949.0
      f16 fwd_gen B=32, M=1024, K=16    |    1726.4   |    576.7
      f32 fwd_gen B=32, M=1024, K=16    |    3491.4   |   2450.4
      f16 fwd_gen B=32, M=1024, K=32    |    1785.4   |    594.5
      f32 fwd_gen B=32, M=1024, K=32    |    3627.2   |   2517.2
      f16 fwd_gen B=32, M=1024, K=64    |    2007.7   |    633.2
      f32 fwd_gen B=32, M=1024, K=64    |    4636.5   |   2658.0
      f16 fwd_gen B=32, M=1024, K=128   |    4227.3   |    731.3
      f32 fwd_gen B=32, M=1024, K=128   |    9188.3   |   2954.6
      f16 fwd_gen B=256, M=128, K=16    |     131.9   |    169.1
      f32 fwd_gen B=256, M=128, K=16    |     148.0   |    351.4
      f16 fwd_gen B=256, M=128, K=32    |     133.3   |    169.6
      f32 fwd_gen B=256, M=128, K=32    |     168.3   |    377.2
      f16 fwd_gen B=256, M=128, K=64    |     139.7   |    167.5
      f32 fwd_gen B=256, M=128, K=64    |     237.0   |    426.2
      f16 fwd_gen B=256, M=128, K=128   |     269.0   |    208.3
      f32 fwd_gen B=256, M=128, K=128   |     487.0   |    551.3
      f16 fwd_gen B=256, M=512, K=16    |     751.6   |   1141.8
      f32 fwd_gen B=256, M=512, K=16    |    2073.1   |   4320.2
      f16 fwd_gen B=256, M=512, K=32    |     927.7   |   1201.1
      f32 fwd_gen B=256, M=512, K=32    |    2340.4   |   4502.3
      f16 fwd_gen B=256, M=512, K=64    |    1381.8   |   1340.3
      f32 fwd_gen B=256, M=512, K=64    |    3010.8   |   4875.3
      f16 fwd_gen B=256, M=512, K=128   |    2691.6   |   1637.9
      f32 fwd_gen B=256, M=512, K=128   |    6341.2   |   5623.2
      f16 fwd_gen B=256, M=1024, K=16   |    2950.9   |   4136.5
      f32 fwd_gen B=256, M=1024, K=16   |    8616.0   |  16471.3
      f16 fwd_gen B=256, M=1024, K=32   |    3768.2   |   4289.4
      f32 fwd_gen B=256, M=1024, K=32   |    9121.6   |  17027.4
      f16 fwd_gen B=256, M=1024, K=64   |    4932.1   |   4590.7
      f32 fwd_gen B=256, M=1024, K=64   |   11903.2   |  18193.3
      f16 fwd_gen B=256, M=1024, K=128  |    9700.4   |   5259.7
      f32 fwd_gen B=256, M=1024, K=128  |   24108.2   |  20571.2

Times are in microseconds (us).
FlashAttention-based kernels
[---------- attention (attn_bias=<class 'NoneType'>) ----------]
                                        |  optimized  |  vanilla
1 threads: -----------------------------------------------------
      f16 flshatt B=32, M=128, K=16     |      31.2   |     63.3
      b16 flshatt B=32, M=128, K=16     |      31.4   |     63.2
      f16 flshatt B=32, M=128, K=32     |      31.2   |     63.4
      b16 flshatt B=32, M=128, K=32     |      30.6   |     60.7
      f16 flshatt B=32, M=128, K=64     |      30.9   |     63.2
      b16 flshatt B=32, M=128, K=64     |      31.6   |     63.2
      f16 flshatt B=32, M=128, K=128    |      31.7   |     63.1
      b16 flshatt B=32, M=128, K=128    |      31.8   |     62.9
      f16 flshatt B=32, M=512, K=16     |      77.8   |     71.1
      b16 flshatt B=32, M=512, K=16     |      77.8   |     72.9
      f16 flshatt B=32, M=512, K=32     |      71.2   |     72.3
      b16 flshatt B=32, M=512, K=32     |      71.2   |     74.7
      f16 flshatt B=32, M=512, K=64     |      90.9   |     76.2
      b16 flshatt B=32, M=512, K=64     |      90.9   |     78.5
      f16 flshatt B=32, M=512, K=128    |     172.4   |     89.3
      b16 flshatt B=32, M=512, K=128    |     172.4   |     91.1
      f16 flshatt B=32, M=1024, K=16    |     296.9   |    259.9
      b16 flshatt B=32, M=1024, K=16    |     296.9   |    265.1
      f16 flshatt B=32, M=1024, K=32    |     274.2   |    260.6
      b16 flshatt B=32, M=1024, K=32    |     274.2   |    265.5
      f16 flshatt B=32, M=1024, K=64    |     354.7   |    272.3
      b16 flshatt B=32, M=1024, K=64    |     354.7   |    277.4
      f16 flshatt B=32, M=1024, K=128   |     675.4   |    303.7
      b16 flshatt B=32, M=1024, K=128   |     675.4   |    309.2
      f16 flshatt B=256, M=128, K=16    |      31.3   |     63.6
      b16 flshatt B=256, M=128, K=16    |      31.7   |     63.1
      f16 flshatt B=256, M=128, K=32    |      31.5   |     63.0
      b16 flshatt B=256, M=128, K=32    |      31.5   |     62.6
      f16 flshatt B=256, M=128, K=64    |      31.7   |     63.4
      b16 flshatt B=256, M=128, K=64    |      31.8   |     63.3
      f16 flshatt B=256, M=128, K=128   |      39.4   |     63.2
      b16 flshatt B=256, M=128, K=128   |      39.8   |     63.2
      f16 flshatt B=256, M=512, K=16    |     112.5   |    482.2
      b16 flshatt B=256, M=512, K=16    |     112.5   |    481.0
      f16 flshatt B=256, M=512, K=32    |     122.8   |    502.9
      b16 flshatt B=256, M=512, K=32    |     122.7   |    502.0
      f16 flshatt B=256, M=512, K=64    |     234.6   |    561.3
      b16 flshatt B=256, M=512, K=64    |    4231.9   |    563.6
      f16 flshatt B=256, M=512, K=128   |     578.9   |    634.6
      b16 flshatt B=256, M=512, K=128   |     885.2   |    621.9
      f16 flshatt B=256, M=1024, K=16   |     435.9   |   1788.5
      b16 flshatt B=256, M=1024, K=16   |     472.3   |   1847.4
      f16 flshatt B=256, M=1024, K=32   |     496.1   |   1818.3
      b16 flshatt B=256, M=1024, K=32   |     479.3   |   1980.7
      f16 flshatt B=256, M=1024, K=64   |     877.3   |   1923.1
      b16 flshatt B=256, M=1024, K=64   |     879.7   |   1987.2
      f16 flshatt B=256, M=1024, K=128  |    2192.6   |   2149.6
      b16 flshatt B=256, M=1024, K=128  |    2191.4   |   2207.2

Times are in microseconds (us).

[----- attention backward (attn_bias=<class 'NoneType'>) ------]
                                        |  optimized  |  vanilla
1 threads: -----------------------------------------------------
      f16 flshatt B=32, M=128, K=16     |      80.7   |    168.2
      b16 flshatt B=32, M=128, K=16     |      80.1   |    166.8
      f16 flshatt B=32, M=128, K=32     |      81.0   |    170.6
      b16 flshatt B=32, M=128, K=32     |      81.2   |    166.8
      f16 flshatt B=32, M=128, K=64     |      80.4   |    167.4
      b16 flshatt B=32, M=128, K=64     |      80.5   |    164.9
      f16 flshatt B=32, M=128, K=128    |      82.3   |    169.5
      b16 flshatt B=32, M=128, K=128    |      82.7   |    167.9
      f16 flshatt B=32, M=512, K=16     |      95.7   |    168.5
      b16 flshatt B=32, M=512, K=16     |      96.8   |    166.0
      f16 flshatt B=32, M=512, K=32     |     121.1   |    168.1
      b16 flshatt B=32, M=512, K=32     |     122.0   |    170.6
      f16 flshatt B=32, M=512, K=64     |     191.1   |    185.0
      b16 flshatt B=32, M=512, K=64     |     192.0   |    186.9
      f16 flshatt B=32, M=512, K=128    |     413.3   |    223.5
      b16 flshatt B=32, M=512, K=128    |     416.1   |    225.7
      f16 flshatt B=32, M=1024, K=16    |     331.1   |    576.8
      b16 flshatt B=32, M=1024, K=16    |     332.7   |    585.4
      f16 flshatt B=32, M=1024, K=32    |     434.2   |    596.9
      b16 flshatt B=32, M=1024, K=32    |     433.8   |    603.4
      f16 flshatt B=32, M=1024, K=64    |     679.4   |    633.3
      b16 flshatt B=32, M=1024, K=64    |     682.4   |    639.8
      f16 flshatt B=32, M=1024, K=128   |    1709.8   |    732.5
      b16 flshatt B=32, M=1024, K=128   |    1702.0   |    739.3
      f16 flshatt B=256, M=128, K=16    |      80.8   |    167.6
      b16 flshatt B=256, M=128, K=16    |      81.1   |    167.6
      f16 flshatt B=256, M=128, K=32    |      80.5   |    167.5
      b16 flshatt B=256, M=128, K=32    |      80.5   |    167.1
      f16 flshatt B=256, M=128, K=64    |      81.2   |    167.4
      b16 flshatt B=256, M=128, K=64    |      80.5   |    165.8
      f16 flshatt B=256, M=128, K=128   |     139.0   |    208.0
      b16 flshatt B=256, M=128, K=128   |     140.7   |    210.5
      f16 flshatt B=256, M=512, K=16    |     289.4   |   1140.7
      b16 flshatt B=256, M=512, K=16    |     291.9   |   1143.4
      f16 flshatt B=256, M=512, K=32    |     381.8   |   1202.4
      b16 flshatt B=256, M=512, K=32    |     384.7   |   1205.6
      f16 flshatt B=256, M=512, K=64    |     678.3   |   1342.4
      b16 flshatt B=256, M=512, K=64    |     681.1   |   1345.0
      f16 flshatt B=256, M=512, K=128   |    1525.1   |   1637.1
      b16 flshatt B=256, M=512, K=128   |    1527.4   |   1639.6
      f16 flshatt B=256, M=1024, K=16   |    1021.7   |   4150.8
      b16 flshatt B=256, M=1024, K=16   |    1027.2   |   4178.5
      f16 flshatt B=256, M=1024, K=32   |    1487.1   |   4298.6
      b16 flshatt B=256, M=1024, K=32   |    1490.5   |   4326.1
      f16 flshatt B=256, M=1024, K=64   |    2316.2   |   4598.2
      b16 flshatt B=256, M=1024, K=64   |    2323.0   |   4629.3
      f16 flshatt B=256, M=1024, K=128  |    5559.5   |   5264.1
      b16 flshatt B=256, M=1024, K=128  |    5555.5   |   5297.1

Times are in microseconds (us).
FlashAttention and lower triangular
[ attention (attn_bias=<class 'xformers.ops.LowerTriangularMask'>) ]
                                        |  optimized  |  vanilla
1 threads: -----------------------------------------------------
      f16 flshatt B=32, M=128, K=16     |      31.6   |     70.5
      b16 flshatt B=32, M=128, K=16     |      31.2   |     70.0
      f16 flshatt B=32, M=128, K=32     |      30.6   |     68.4
      b16 flshatt B=32, M=128, K=32     |      30.1   |     67.7
      f16 flshatt B=32, M=128, K=64     |      31.6   |     70.6
      b16 flshatt B=32, M=128, K=64     |      31.6   |     70.3
      f16 flshatt B=32, M=128, K=128    |      31.1   |     70.3
      b16 flshatt B=32, M=128, K=128    |      31.5   |     69.6
      f16 flshatt B=32, M=512, K=16     |      60.3   |    106.4
      b16 flshatt B=32, M=512, K=16     |      60.3   |    109.9
      f16 flshatt B=32, M=512, K=32     |      57.6   |    108.9
      b16 flshatt B=32, M=512, K=32     |      57.6   |    112.0
      f16 flshatt B=32, M=512, K=64     |      71.8   |    114.0
      b16 flshatt B=32, M=512, K=64     |      71.7   |    117.5
      f16 flshatt B=32, M=512, K=128    |     133.1   |    132.0
      b16 flshatt B=32, M=512, K=128    |     133.3   |    135.0
      f16 flshatt B=32, M=1024, K=16    |     189.2   |    439.8
      b16 flshatt B=32, M=1024, K=16    |     189.2   |    456.7
      f16 flshatt B=32, M=1024, K=32    |     181.8   |    442.7
      b16 flshatt B=32, M=1024, K=32    |     181.8   |    459.9
      f16 flshatt B=32, M=1024, K=64    |     233.5   |    453.4
      b16 flshatt B=32, M=1024, K=64    |     233.5   |    471.3
      f16 flshatt B=32, M=1024, K=128   |     423.1   |    478.7
      b16 flshatt B=32, M=1024, K=128   |     423.1   |    497.5
      f16 flshatt B=256, M=128, K=16    |      31.7   |     69.8
      b16 flshatt B=256, M=128, K=16    |      31.5   |     69.8
      f16 flshatt B=256, M=128, K=32    |      31.5   |     70.3
      b16 flshatt B=256, M=128, K=32    |      31.5   |     69.8
      f16 flshatt B=256, M=128, K=64    |      31.6   |     73.0
      b16 flshatt B=256, M=128, K=64    |      31.7   |     74.8
      f16 flshatt B=256, M=128, K=128   |      36.3   |     90.8
      b16 flshatt B=256, M=128, K=128   |      36.6   |     93.7
      f16 flshatt B=256, M=512, K=16    |     126.8   |    841.4
      b16 flshatt B=256, M=512, K=16    |     126.9   |    858.1
      f16 flshatt B=256, M=512, K=32    |     123.7   |    854.7
      b16 flshatt B=256, M=512, K=32    |     123.6   |    871.5
      f16 flshatt B=256, M=512, K=64    |     170.2   |    926.3
      b16 flshatt B=256, M=512, K=64    |     170.0   |    923.2
      f16 flshatt B=256, M=512, K=128   |     460.3   |   2030.6
      b16 flshatt B=256, M=512, K=128   |     410.2   |    963.8
      f16 flshatt B=256, M=1024, K=16   |     406.3   |   3234.7
      b16 flshatt B=256, M=1024, K=16   |     406.3   |   3404.4
      f16 flshatt B=256, M=1024, K=32   |     418.9   |   3283.7
      b16 flshatt B=256, M=1024, K=32   |     418.9   |   3788.2
      f16 flshatt B=256, M=1024, K=64   |     571.6   |   3353.9
      b16 flshatt B=256, M=1024, K=64   |     571.6   |   3521.9
      f16 flshatt B=256, M=1024, K=128  |    1368.0   |   3514.0
      b16 flshatt B=256, M=1024, K=128  |    1368.9   |   3684.2

Times are in microseconds (us).

[ attention backward (attn_bias=<class 'xformers.ops.LowerTriangularMask'>) ]
                                        |  optimized  |  vanilla
1 threads: -----------------------------------------------------
      f16 flshatt B=32, M=128, K=16     |      80.3   |    154.3
      b16 flshatt B=32, M=128, K=16     |      80.8   |    152.2
      f16 flshatt B=32, M=128, K=32     |      80.5   |    152.8
      b16 flshatt B=32, M=128, K=32     |      79.3   |    151.9
      f16 flshatt B=32, M=128, K=64     |      81.0   |    155.1
      b16 flshatt B=32, M=128, K=64     |      81.5   |    152.9
      f16 flshatt B=32, M=128, K=128    |      81.6   |    153.2
      b16 flshatt B=32, M=128, K=128    |      82.1   |    152.1
      f16 flshatt B=32, M=512, K=16     |      81.3   |    181.6
      b16 flshatt B=32, M=512, K=16     |      81.1   |    165.1
      f16 flshatt B=32, M=512, K=32     |      98.7   |    168.4
      b16 flshatt B=32, M=512, K=32     |      99.4   |    171.0
      f16 flshatt B=32, M=512, K=64     |     148.3   |    184.1
      b16 flshatt B=32, M=512, K=64     |     150.0   |    187.0
      f16 flshatt B=32, M=512, K=128    |     278.8   |    227.7
      b16 flshatt B=32, M=512, K=128    |     280.7   |    227.4
      f16 flshatt B=32, M=1024, K=16    |     225.1   |    576.4
      b16 flshatt B=32, M=1024, K=16    |     227.2   |    581.5
      f16 flshatt B=32, M=1024, K=32    |     287.3   |    596.0
      b16 flshatt B=32, M=1024, K=32    |     288.0   |    600.8
      f16 flshatt B=32, M=1024, K=64    |     456.9   |    630.6
      b16 flshatt B=32, M=1024, K=64    |     458.4   |    636.0
      f16 flshatt B=32, M=1024, K=128   |    1002.1   |    730.3
      b16 flshatt B=32, M=1024, K=128   |    1004.2   |    737.1
      f16 flshatt B=256, M=128, K=16    |      80.7   |    153.2
      b16 flshatt B=256, M=128, K=16    |      80.8   |    152.2
      f16 flshatt B=256, M=128, K=32    |      95.8   |    153.7
      b16 flshatt B=256, M=128, K=32    |      80.6   |    151.8
      f16 flshatt B=256, M=128, K=64    |      80.3   |    152.5
      b16 flshatt B=256, M=128, K=64    |      80.6   |    151.5
      f16 flshatt B=256, M=128, K=128   |     142.9   |    208.8
      b16 flshatt B=256, M=128, K=128   |     143.7   |    211.8
      f16 flshatt B=256, M=512, K=16    |     238.6   |   1137.8
      b16 flshatt B=256, M=512, K=16    |     241.2   |   1139.5
      f16 flshatt B=256, M=512, K=32    |     313.1   |   1200.8
      b16 flshatt B=256, M=512, K=32    |     316.2   |   1203.2
      f16 flshatt B=256, M=512, K=64    |     525.4   |   1343.3
      b16 flshatt B=256, M=512, K=64    |     530.5   |   1347.2
      f16 flshatt B=256, M=512, K=128   |    1067.6   |   1637.4
      b16 flshatt B=256, M=512, K=128   |    1069.3   |   1640.3
      f16 flshatt B=256, M=1024, K=16   |     705.3   |   4132.8
      b16 flshatt B=256, M=1024, K=16   |     710.1   |   4164.3
      f16 flshatt B=256, M=1024, K=32   |     968.1   |   4285.2
      b16 flshatt B=256, M=1024, K=32   |     969.9   |   4314.8
      f16 flshatt B=256, M=1024, K=64   |    1580.0   |   4579.8
      b16 flshatt B=256, M=1024, K=64   |    1583.7   |   4612.7
      f16 flshatt B=256, M=1024, K=128  |    3528.0   |   5248.4
      b16 flshatt B=256, M=1024, K=128  |    3532.3   |   5276.5

Times are in microseconds (us).

V100

FlashAttention is not supported on V100, so in this case we only compare against the baseline PyTorch implementation, on both fp16 and fp32.
For the configurations below and fp16 on the forward pass, the CUTLASS-based implementation is 25% faster on average than the vanilla implementation (5% slower at the median).
For fp32, it is 13% faster on average (4% slower at the median).

For the backward pass and fp16, the CUTLASS-based implementation is 19% slower on average (15% slower at the median).
For fp32, it is 27% slower on average (30% slower at the median).

CUTLASS-based kernels
[---------- attention (attn_bias=<class 'NoneType'>) ----------]
                                        |  optimized  |  vanilla
1 threads: -----------------------------------------------------
      f16 fwd_gen B=32, M=128, K=16     |      44.1   |    108.9
      f32 fwd_gen B=32, M=128, K=16     |      43.6   |    100.6
      f16 fwd_gen B=32, M=128, K=32     |      42.2   |    109.0
      f32 fwd_gen B=32, M=128, K=32     |      43.1   |    100.0
      f16 fwd_gen B=32, M=128, K=64     |      42.2   |    107.2
      f32 fwd_gen B=32, M=128, K=64     |      43.6   |    100.0
      f16 fwd_gen B=32, M=128, K=128    |      42.0   |    101.4
      f32 fwd_gen B=32, M=128, K=128    |      42.1   |     99.2
      f16 fwd_gen B=32, M=512, K=16     |     139.4   |    112.3
      f32 fwd_gen B=32, M=512, K=16     |     292.7   |    237.5
      f16 fwd_gen B=32, M=512, K=32     |     142.0   |    112.8
      f32 fwd_gen B=32, M=512, K=32     |     339.1   |    251.3
      f16 fwd_gen B=32, M=512, K=64     |     156.9   |    129.3
      f32 fwd_gen B=32, M=512, K=64     |     397.0   |    310.7
      f16 fwd_gen B=32, M=512, K=128    |     193.5   |    155.7
      f32 fwd_gen B=32, M=512, K=128    |     499.3   |    487.1
      f16 fwd_gen B=32, M=1024, K=16    |     492.8   |    461.7
      f32 fwd_gen B=32, M=1024, K=16    |    1064.4   |    759.7
      f16 fwd_gen B=32, M=1024, K=32    |     501.2   |    475.8
      f32 fwd_gen B=32, M=1024, K=32    |    1203.1   |    814.6
      f16 fwd_gen B=32, M=1024, K=64    |     560.5   |    534.5
      f32 fwd_gen B=32, M=1024, K=64    |    1408.7   |   1112.4
      f16 fwd_gen B=32, M=1024, K=128   |     687.9   |    571.4
      f32 fwd_gen B=32, M=1024, K=128   |    1948.3   |   1779.3
      f16 fwd_gen B=256, M=128, K=16    |      66.1   |    111.2
      f32 fwd_gen B=256, M=128, K=16    |     137.7   |    118.3
      f16 fwd_gen B=256, M=128, K=32    |      68.6   |    103.5
      f32 fwd_gen B=256, M=128, K=32    |     152.2   |    143.0
      f16 fwd_gen B=256, M=128, K=64    |      82.6   |    110.7
      f32 fwd_gen B=256, M=128, K=64    |     182.7   |    199.1
      f16 fwd_gen B=256, M=128, K=128   |     109.3   |    134.5
      f32 fwd_gen B=256, M=128, K=128   |     248.2   |    328.4
      f16 fwd_gen B=256, M=512, K=16    |     901.6   |    798.2
      f32 fwd_gen B=256, M=512, K=16    |    1988.3   |   1782.5
      f16 fwd_gen B=256, M=512, K=32    |     923.2   |    854.7
      f32 fwd_gen B=256, M=512, K=32    |    2194.5   |   1932.7
      f16 fwd_gen B=256, M=512, K=64    |    1036.1   |    990.3
      f32 fwd_gen B=256, M=512, K=64    |    2601.0   |   2229.9
      f16 fwd_gen B=256, M=512, K=128   |    1305.1   |   1144.1
      f32 fwd_gen B=256, M=512, K=128   |    3616.3   |   3581.4
      f16 fwd_gen B=256, M=1024, K=16   |    3480.3   |   3383.2
      f32 fwd_gen B=256, M=1024, K=16   |    7666.0   |   7007.6
      f16 fwd_gen B=256, M=1024, K=32   |    3538.4   |   3490.3
      f32 fwd_gen B=256, M=1024, K=32   |    8487.8   |   7690.9
      f16 fwd_gen B=256, M=1024, K=64   |    3965.6   |   3871.7
      f32 fwd_gen B=256, M=1024, K=64   |   10166.5   |   8852.5
      f16 fwd_gen B=256, M=1024, K=128  |    5022.8   |   4271.8
      f32 fwd_gen B=256, M=1024, K=128  |   14267.9   |  14285.4

Times are in microseconds (us).

[----- attention backward (attn_bias=<class 'NoneType'>) ------]
                                        |  optimized  |  vanilla
1 threads: -----------------------------------------------------
      f16 fwd_gen B=32, M=128, K=16     |     204.4   |    268.2
      f32 fwd_gen B=32, M=128, K=16     |     154.3   |    232.3
      f16 fwd_gen B=32, M=128, K=32     |     194.4   |    266.5
      f32 fwd_gen B=32, M=128, K=32     |     154.7   |    243.6
      f16 fwd_gen B=32, M=128, K=64     |     204.6   |    279.9
      f32 fwd_gen B=32, M=128, K=64     |     176.5   |    238.3
      f16 fwd_gen B=32, M=128, K=128    |     236.0   |    281.8
      f32 fwd_gen B=32, M=128, K=128    |     270.5   |    271.2
      f16 fwd_gen B=32, M=512, K=16     |     628.8   |    262.4
      f32 fwd_gen B=32, M=512, K=16     |    1431.3   |    569.6
      f16 fwd_gen B=32, M=512, K=32     |     661.6   |    272.6
      f32 fwd_gen B=32, M=512, K=32     |    1776.3   |    593.9
      f16 fwd_gen B=32, M=512, K=64     |     798.2   |    305.4
      f32 fwd_gen B=32, M=512, K=64     |    2148.0   |    708.0
      f16 fwd_gen B=32, M=512, K=128    |    1431.7   |    370.9
      f32 fwd_gen B=32, M=512, K=128    |    4320.2   |   1091.8
      f16 fwd_gen B=32, M=1024, K=16    |    2463.0   |    955.4
      f32 fwd_gen B=32, M=1024, K=16    |    6304.7   |   1946.5
      f16 fwd_gen B=32, M=1024, K=32    |    2609.4   |    974.3
      f32 fwd_gen B=32, M=1024, K=32    |    7121.5   |   2021.5
      f16 fwd_gen B=32, M=1024, K=64    |    3033.9   |   1060.2
      f32 fwd_gen B=32, M=1024, K=64    |    8529.3   |   2519.4
      f16 fwd_gen B=32, M=1024, K=128   |    5405.4   |   1254.1
      f32 fwd_gen B=32, M=1024, K=128   |   17318.0   |   3945.0
      f16 fwd_gen B=256, M=128, K=16    |     202.5   |    256.2
      f32 fwd_gen B=256, M=128, K=16    |     284.0   |    311.9
      f16 fwd_gen B=256, M=128, K=32    |     212.3   |    294.0
      f32 fwd_gen B=256, M=128, K=32    |     370.5   |    368.3
      f16 fwd_gen B=256, M=128, K=64    |     274.1   |    343.0
      f32 fwd_gen B=256, M=128, K=64    |     534.3   |    494.0
      f16 fwd_gen B=256, M=128, K=128   |     524.6   |    361.8
      f32 fwd_gen B=256, M=128, K=128   |    1058.4   |    800.2
      f16 fwd_gen B=256, M=512, K=16    |    1683.3   |   1819.4
      f32 fwd_gen B=256, M=512, K=16    |    4390.3   |   4136.6
      f16 fwd_gen B=256, M=512, K=32    |    1887.2   |   1942.6
      f32 fwd_gen B=256, M=512, K=32    |    5303.8   |   4388.2
      f16 fwd_gen B=256, M=512, K=64    |    2572.6   |   2250.1
      f32 fwd_gen B=256, M=512, K=64    |    7328.7   |   5068.2
      f16 fwd_gen B=256, M=512, K=128   |    5291.6   |   2733.4
      f32 fwd_gen B=256, M=512, K=128   |   14765.4   |   7850.2
      f16 fwd_gen B=256, M=1024, K=16   |    6468.6   |   6828.6
      f32 fwd_gen B=256, M=1024, K=16   |   17351.1   |  15883.1
      f16 fwd_gen B=256, M=1024, K=32   |    6959.6   |   7119.4
      f32 fwd_gen B=256, M=1024, K=32   |   20910.6   |  16698.8
      f16 fwd_gen B=256, M=1024, K=64   |    9290.5   |   8016.1
      f32 fwd_gen B=256, M=1024, K=64   |   28622.6   |  19072.8
      f16 fwd_gen B=256, M=1024, K=128  |   19679.1   |   9201.1
      f32 fwd_gen B=256, M=1024, K=128  |   57831.5   |  29781.9

Times are in microseconds (us).

P100

For the configurations below, as before, we only compare against the vanilla PyTorch implementation, since FlashAttention doesn't support P100 GPUs.

For the forward pass on fp16, the CUTLASS-based kernels are 18% slower on average (22% slower at the median), while for fp32 they are 10% slower on average (13% slower at the median).

For the backward pass on fp16, the CUTLASS-based kernels are 40% slower on average (45% slower at the median), while on fp32 they are 33% slower on average (38% slower at the median).

CUTLASS-based kernels
[---------- attention (attn_bias=<class 'NoneType'>) ----------]
                                        |  optimized  |  vanilla
1 threads: -----------------------------------------------------
      f16 fwd_gen B=32, M=128, K=16     |      48.0   |     66.8
      f32 fwd_gen B=32, M=128, K=16     |      51.2   |     64.9
      f16 fwd_gen B=32, M=128, K=32     |      52.8   |     67.0
      f32 fwd_gen B=32, M=128, K=32     |      54.8   |     64.7
      f16 fwd_gen B=32, M=128, K=64     |      62.3   |     67.1
      f32 fwd_gen B=32, M=128, K=64     |      61.9   |     64.6
      f16 fwd_gen B=32, M=128, K=128    |      81.4   |     74.7
      f32 fwd_gen B=32, M=128, K=128    |      80.6   |     83.4
      f16 fwd_gen B=32, M=512, K=16     |     549.6   |    440.2
      f32 fwd_gen B=32, M=512, K=16     |     540.3   |    501.5
      f16 fwd_gen B=32, M=512, K=32     |     602.7   |    472.2
      f32 fwd_gen B=32, M=512, K=32     |     588.4   |    531.5
      f16 fwd_gen B=32, M=512, K=64     |     721.4   |    549.9
      f32 fwd_gen B=32, M=512, K=64     |     693.9   |    607.8
      f16 fwd_gen B=32, M=512, K=128    |     999.7   |    692.6
      f32 fwd_gen B=32, M=512, K=128    |     950.7   |    763.1
      f16 fwd_gen B=32, M=1024, K=16    |    2049.2   |   1645.9
      f32 fwd_gen B=32, M=1024, K=16    |    2068.0   |   1838.5
      f16 fwd_gen B=32, M=1024, K=32    |    2243.7   |   1762.3
      f32 fwd_gen B=32, M=1024, K=32    |    2268.1   |   1954.8
      f16 fwd_gen B=32, M=1024, K=64    |    2798.1   |   2041.4
      f32 fwd_gen B=32, M=1024, K=64    |    2729.5   |   2240.3
      f16 fwd_gen B=32, M=1024, K=128   |    3850.7   |   2583.4
      f32 fwd_gen B=32, M=1024, K=128   |    3787.0   |   2838.0
      f16 fwd_gen B=256, M=128, K=16    |     264.5   |    223.4
      f32 fwd_gen B=256, M=128, K=16    |     279.5   |    260.9
      f16 fwd_gen B=256, M=128, K=32    |     293.0   |    247.6
      f32 fwd_gen B=256, M=128, K=32    |     307.4   |    290.8
      f16 fwd_gen B=256, M=128, K=64    |     361.6   |    294.8
      f32 fwd_gen B=256, M=128, K=64    |     368.3   |    344.7
      f16 fwd_gen B=256, M=128, K=128   |     492.7   |    383.5
      f32 fwd_gen B=256, M=128, K=128   |     488.4   |    456.2
      f16 fwd_gen B=256, M=512, K=16    |    3881.8   |   3047.2
      f32 fwd_gen B=256, M=512, K=16    |    4007.8   |   3489.1
      f16 fwd_gen B=256, M=512, K=32    |    4311.8   |   3298.3
      f32 fwd_gen B=256, M=512, K=32    |    4393.1   |   3734.9
      f16 fwd_gen B=256, M=512, K=64    |    5352.4   |   3897.1
      f32 fwd_gen B=256, M=512, K=64    |    5259.5   |   4345.0
      f16 fwd_gen B=256, M=512, K=128   |    7360.6   |   4995.8
      f32 fwd_gen B=256, M=512, K=128   |    7223.3   |   5589.1
      f16 fwd_gen B=256, M=1024, K=16   |   15248.0   |  11920.2
      f32 fwd_gen B=256, M=1024, K=16   |   15679.0   |  13463.3
      f16 fwd_gen B=256, M=1024, K=32   |   17105.5   |  12905.6
      f32 fwd_gen B=256, M=1024, K=32   |   17286.6   |  14516.5
      f16 fwd_gen B=256, M=1024, K=64   |   21107.6   |  15209.6
      f32 fwd_gen B=256, M=1024, K=64   |   20530.7   |  16900.9
      f16 fwd_gen B=256, M=1024, K=128  |   29082.0   |  19486.8
      f32 fwd_gen B=256, M=1024, K=128  |   28934.5   |  21414.6

Times are in microseconds (us).

[----- attention backward (attn_bias=<class 'NoneType'>) ------]
                                        |  optimized  |  vanilla
1 threads: -----------------------------------------------------
      f16 fwd_gen B=32, M=128, K=16     |      180.2  |    168.5
      f32 fwd_gen B=32, M=128, K=16     |      130.3  |    161.4
      f16 fwd_gen B=32, M=128, K=32     |      175.9  |    163.2
      f32 fwd_gen B=32, M=128, K=32     |      143.8  |    158.9
      f16 fwd_gen B=32, M=128, K=64     |      200.5  |    162.8
      f32 fwd_gen B=32, M=128, K=64     |      186.3  |    158.6
      f16 fwd_gen B=32, M=128, K=128    |      377.5  |    178.6
      f32 fwd_gen B=32, M=128, K=128    |      370.6  |    213.2
      f16 fwd_gen B=32, M=512, K=16     |     2066.7  |    995.0
      f32 fwd_gen B=32, M=512, K=16     |     1859.0  |   1180.0
      f16 fwd_gen B=32, M=512, K=32     |     2257.3  |   1035.9
      f32 fwd_gen B=32, M=512, K=32     |     2203.7  |   1227.8
      f16 fwd_gen B=32, M=512, K=64     |     2837.3  |   1129.6
      f32 fwd_gen B=32, M=512, K=64     |     2695.2  |   1342.2
      f16 fwd_gen B=32, M=512, K=128    |     5419.5  |   1506.8
      f32 fwd_gen B=32, M=512, K=128    |     5630.2  |   1754.6
      f16 fwd_gen B=32, M=1024, K=16    |     8195.9  |   3618.1
      f32 fwd_gen B=32, M=1024, K=16    |     8014.2  |   4361.1
      f16 fwd_gen B=32, M=1024, K=32    |     9565.6  |   3748.1
      f32 fwd_gen B=32, M=1024, K=32    |     8757.0  |   4486.6
      f16 fwd_gen B=32, M=1024, K=64    |    11165.5  |   4091.3
      f32 fwd_gen B=32, M=1024, K=64    |    10562.8  |   4934.7
      f16 fwd_gen B=32, M=1024, K=128   |    21419.3  |   5293.5
      f32 fwd_gen B=32, M=1024, K=128   |    22485.3  |   6307.4
      f16 fwd_gen B=256, M=128, K=16    |      572.4  |    508.9
      f32 fwd_gen B=256, M=128, K=16    |      613.8  |    623.8
      f16 fwd_gen B=256, M=128, K=32    |      681.4  |    552.6
      f32 fwd_gen B=256, M=128, K=32    |      754.1  |    695.7
      f16 fwd_gen B=256, M=128, K=64    |      925.7  |    666.1
      f32 fwd_gen B=256, M=128, K=64    |     1055.3  |    845.2
      f16 fwd_gen B=256, M=128, K=128   |     1772.8  |    972.0
      f32 fwd_gen B=256, M=128, K=128   |     2131.6  |   1233.4
      f16 fwd_gen B=256, M=512, K=16    |     8528.7  |   6674.2
      f32 fwd_gen B=256, M=512, K=16    |     9735.7  |   8177.5
      f16 fwd_gen B=256, M=512, K=32    |     9684.0  |   7039.2
      f32 fwd_gen B=256, M=512, K=32    |    10968.4  |   8761.3
      f16 fwd_gen B=256, M=512, K=64    |    12595.3  |   7896.5
      f32 fwd_gen B=256, M=512, K=64    |    15158.3  |   9801.4
      f16 fwd_gen B=256, M=512, K=128   |    25114.7  |  10617.4
      f32 fwd_gen B=256, M=512, K=128   |    31365.2  |  13021.0
      f16 fwd_gen B=256, M=1024, K=16   |    34012.5  |  25756.9
      f32 fwd_gen B=256, M=1024, K=16   |    39049.4  |  31883.0
      f16 fwd_gen B=256, M=1024, K=32   |    38586.7  |  27136.1
      f32 fwd_gen B=256, M=1024, K=32   |    44221.7  |  33717.1
      f16 fwd_gen B=256, M=1024, K=64   |    50089.4  |  29991.6
      f32 fwd_gen B=256, M=1024, K=64   |    59531.6  |  37107.7
      f16 fwd_gen B=256, M=1024, K=128  |   100267.5  |  39293.2
      f32 fwd_gen B=256, M=1024, K=128  |   123949.9  |  47268.9

Times are in microseconds (us).

cc @blefaudeux @danthe3rd @tridao

fmassa and others added 30 commits August 10, 2022 08:48
* Add attention bias in memory-efficient attention

* Add gradient for attn_mask support

* Add CPU implementation

* clang-format

* Add benchmark scripts

* Add extra loop in benchmarks

* Move zeros array out of helper function

* clang-format
* Merge compute_scaling_coeffs and update_scaling_coeffs into a single function

It wasn't needed to break it in two functions to begin with

* Add CUDA implementation for dropout

* clang-format

* Make p be drop probability

* Only CUDA supports dropout

* Add benchmarks

* Remove unused variables

* Fix test

* Cleanups and comments
@@ -53,6 +53,9 @@ There are two ways you can install xFormers locally:

```bash
git clone git@github.com:facebookresearch/xformers.git
git submodule update --init --recursive
```
Contributor

nice, I was checking that when I saw that xformers now has two submodules, perfect. Thanks

DEFAULT_ARCHS_LIST = ""
if cuda_version > 1100:
    DEFAULT_ARCHS_LIST = "7.5;8.0;8.6"
elif cuda_version >= 1100:
Contributor

nit, but cuda_version == 1100 in that case, right ?

num = 10 * int(arch[0]) + int(arch[2])
# Need at least 7.5
if num < 75:
    continue
Contributor

could we print out some warnings here (or in the main setup), to recap what's being built and possibly why? I feel like a lot of issues could be raised around that, with the build process silently skipping FlashAttention because of an old cuda version and users not seeing it

Contributor Author

Good idea, I'll add some log messages

But in general, we need to improve the packaging of xformers, especially now that a lot of hardware-specific kernels are being used. @bottler might look into improving this
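
Sketching what such a message could look like (hypothetical, not the actual setup.py change; the function and variable names are illustrative, only the capability check mirrors the snippet above):

```python
import warnings

def filter_flash_archs(archs):
    # archs are strings like "7.5", "8.0", "8.6"
    kept, skipped = [], []
    for arch in archs:
        num = 10 * int(arch[0]) + int(arch[2])
        if num < 75:  # FlashAttention needs compute capability >= 7.5
            skipped.append(arch)
            continue
        kept.append(arch)
    if skipped:
        warnings.warn(
            f"FlashAttention kernels will not be built for archs {skipped} "
            f"(compute capability >= 7.5 required); building for {kept}"
        )
    return kept
```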

_devices = ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]


def ref_attention(q, k, v):
def assert_allclose(
Contributor

nit, but could this be moved into some utils? It feels like it could be used in a few places already, beyond this PR

@pytest.mark.parametrize(
    "attn_bias_type", [None, xformers.ops.LowerTriangularMask, torch.Tensor]
)
@pytest.mark.parametrize("k_len", [5, 6, 32, 128])
Contributor

nice test cascade ! that's some serious coverage

Contributor Author

Yeah, there are a lot of combinations being tested now. The tests are no longer instantaneous (~1 min?), but it's not too bad I think

    dtype,
    op: xformers.ops.MemoryEfficientAttentionOp,
):
    scale = 3
Contributor

for my understanding, how is this scale chosen ?

Contributor Author

This is just to stress-test the numerics of MHA a bit more. I could have left it as 1, but a larger scale pushes the values in the query @ key.T part higher, so that for larger K dimensions we could hit overflows if the softmax is not done properly.
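
To illustrate the point (not code from this PR, just a toy sketch assuming a CUDA device): with a larger scale the query @ key.T logits grow, and a naive fp16 softmax that exponentiates them directly overflows, while the usual max-subtraction stays finite.

```python
import torch

torch.manual_seed(0)
scale = 3
q = scale * torch.randn(128, 128, dtype=torch.half, device="cuda")
k = scale * torch.randn(128, 128, dtype=torch.half, device="cuda")

logits = q @ k.t()  # entries easily exceed exp()'s fp16 range

# Naive softmax: exp() overflows to inf, and inf / inf gives nan
naive = logits.exp() / logits.exp().sum(-1, keepdim=True)

# Stable softmax: subtract the row-wise max before exponentiating
stable = (logits - logits.max(-1, keepdim=True).values).softmax(-1)

print(torch.isnan(naive).any(), torch.isnan(stable).any())  # nan in naive, none in stable
```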


grad_out = torch.ones_like(query)
if grad_out_contiguous is False:
    grad_out = torch.tensor([1.0], device=device)[None, None, :].expand_as(query)
Contributor

(updated) that works indeed, not super intuitive to me but the .expand_as() call is the one which breaks the contiguity, interesting. I would have done something like .transpose(0,1).contiguous().transpose(0,1), curious about your take @fmassa, how did you think of that formulation ?

Contributor Author

The expand_as appears very often when quick-testing the backward, as it's in the gradient of .sum(). So doing op(inputs).sum().backward() yields gradients which are expanded tensors, which is a particular case of non-contiguous tensors. Given that the kernel for now just calls .contiguous() on the gradients, any non-contiguous tensor is fine to exercise this codepath.
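
A small sketch of that (not from this PR): the gradient produced by .sum().backward() is an expanded, stride-0 tensor, which is exactly the kind of non-contiguous grad_out the test builds explicitly.

```python
import torch

q = torch.randn(8, 128, 32, requires_grad=True)
out = q * 2.0

# Inspect the gradient flowing into `out`: sum's backward expands a scalar
# ones tensor to the output shape, so it is non-contiguous (all strides 0).
out.register_hook(lambda g: print("grad_out contiguous?", g.is_contiguous()))
out.sum().backward()  # prints: grad_out contiguous? False

# The test reproduces the same kind of grad_out by hand:
grad_out = torch.tensor([1.0])[None, None, :].expand_as(q)
print(grad_out.is_contiguous())  # False
```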

mask = torch.ops.xformers._temp_dropout(mask, p)
masks.append(mask.clone().cpu())
masks = torch.stack(masks, dim=0)
p_value = binom_test(masks.sum(), masks.numel(), p=keep_prob)
Contributor

nice, seems much better than these tests (my bad), thanks for this very thorough take


import xformers.ops

torch.backends.cuda.matmul.allow_tf32 = False
Contributor

not directly related, but did you get to test out the perf effect of tf32 accumulation on A100 ? asking just in case to learn a bit more

Contributor Author

The CUTLASS-based kernels do use a trick that benefits from tf32 while keeping fp32 numerics (by performing 3 matmuls in tf32 to leverage tensor cores). The trick (which was present in the CUTLASS examples) is to decompose an fp32 tensor as fp32 = fp32_high_bits + fp32_low_bits, where fp32_high_bits and fp32_low_bits are representable in tf32, so that the multiplication can be approximated by 3 tf32 matmuls (dropping the low_bits * low_bits part).

The implementation that uses only a single tf32 instruction is not implemented yet, but we were thinking it could be exposed by reading the info from torch.backends.cuda.matmul.allow_tf32 and dispatching to different kernels.
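
For illustration, a rough PyTorch-level sketch of that decomposition (this only emulates the numerics of the 3xTF32 trick, it is not the kernel code; the bit mask assumes tf32 keeps the top 10 of fp32's 23 mantissa bits):

```python
import torch

def round_to_tf32(x: torch.Tensor) -> torch.Tensor:
    # Truncate the 13 lowest mantissa bits of fp32, keeping tf32 precision.
    return (x.view(torch.int32) & ~0x1FFF).view(torch.float32)

a = torch.randn(256, 256)
b = torch.randn(256, 256)

a_hi = round_to_tf32(a)
a_lo = a - a_hi
b_hi = round_to_tf32(b)
b_lo = b - b_hi

# Three tf32-precision matmuls approximate the full fp32 product;
# only the tiny a_lo @ b_lo term is dropped.
approx = a_hi @ b_hi + a_hi @ b_lo + a_lo @ b_hi
exact = a @ b
print((approx - exact).abs().max())  # error from the dropped lo @ lo term only
```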

sub_label = f"{dtype_str} {op.NAME} B={B}, M={M}, K={K}"

if True:
    r = xformers.ops.memory_efficient_attention(q, q, q, attn_bias, op=op).float()
Contributor

nit: I would be curious to get numbers on self vs. non-self attention, if anything I think that benchmarking your work in the self attention case (while very relevant for vision of course) sells you short, since there's more opportunity for the GPU to have a hot cache (and the vanilla computation is IO bottlenecked, so will benefit a lot from that)

q.grad = None
del r, rr, grad

out = xformers.ops.memory_efficient_attention(q, q, q, attn_bias, p, op=op)
Contributor

same as above (self vs. non self), maybe that the numbers end up being similar but I would be curious to see that for real. Guess is that you would slightly increase the gap vs. pytorch for non self-attention

@@ -142,6 +146,36 @@ __device__ void compute_dot(
}
}

/*
Contributor

nit: useful ?

Contributor Author

Not used, I was planning on doing some refactorings to make things simpler but didn't finish it and kept it there to eventually go back to it. I could just remove it though

int64_t N,
scalar_t p,
int64_t col_offset) {
// strategy: initialize the rng so that each element in the attention
Contributor

nice, dropout/triton is doing the same


// we will always sample 4 random floats at a time
// as it's more efficient
constexpr int kSampled = 4;
Contributor

same in triton, I'm guessing that this is HW dependent (big random word cut into pieces ?)

Contributor Author

Yeah, it turns out that curand always generates 4 floats at a time internally, even if you call curand_random instead of curand_random4, and thus getting only 1 float was much more expensive

@ngimel

It shouldn't be much more expensive, because curand saves unused randoms and yields them on future calls, without going through generation.

Contributor Author

@ngimel indeed, but in the current setup we have to reset the state of the philox generation very often (almost for every element), and that accounts for the slowdown: we would reset the state (expensive), generate 4 floats and only use 1.

The strategy here is to reset the state only once every 4 elements in the output

// guarantees than by doing it properly, but is much faster
curand_init(
    std::get<0>(seeds), offset, std::get<1>(seeds) + delta, &state);
// curand_init(std::get<0>(seeds) + (offset << 8) + std::get<1>(seeds), 0,
Contributor

old API I suppose, still needed around ?

Contributor Author

that should be removed actually, was a workaround to get faster generation back when it was slower (but had fewer guarantees regarding randomness).

if (index >= M)
break;

auto out_i = reinterpret_cast<vec_t*>(output[batch_idx][index].data());
Contributor

nit: would 'const vec_t *' make sense here and where this pattern appears, read only ? may not be idiomatic cuda, in c++ it's quite typical in some codebases to be a little strict around that

// if (l < end_iter) {
{
  for (; l < end_iter; l += step) {
    for (int jj = 0; jj < kBlockSizeQ; jj++) {
Contributor

I think that I wrote the same in the previous PR related to that @fmassa, but it feels crazy that this is the best cuda has to offer to init s_delta.. I remember you wrote that was the gist of it, not really a question here, just thinking out loud


cudaStream_t stream = at::cuda::getCurrentCUDAStream();

constexpr int WARP_SIZE = 4;
Contributor

nit, but this looks like it's the number of warps, and warp_size is typically used (even in this PR, see next file) to describe the number of threads in a warp, right ? not super important but for consistency's sake

Contributor Author

Indeed. I'll clean this up in a follow-up PR; I have some changes laid out that I'll be pushing soon

@MarkusRabe

@fmassa Yessss!! You just made my day 🙏🎉

Drinks on me next time y'all are in SF!

@MarkusRabe can come too 😆

@lucidrains Just let me know the time and place.

@blefaudeux (Contributor) left a comment

looks great, minor cosmetic or understanding questions, but that's mostly for my sake..

The cutlass part feels very different, almost like another language and hard to review for me. I do like that it seems to be one level up in terms of abstraction, but I would be hard pressed to find bugs in there.. Would it help to have other eyes on it ? (@ngimel ?)

Thanks for all this work @fmassa @danthe3rd (and @tridao of course), I hope and think that it can be super impactful, game changer for the attention mechanism

@lucidrains

@MarkusRabe shoot me an email!

your old email at Saarland no longer works (tried to email you some time ago)

@lw921014

We were able to train a variety of models, including ViT, DINO and DETR, on both bf16 and float32 without issues.

hi @fmassa, I see a precision loss when training a swin-t model. Did you test it?

@danthe3rd (Contributor)

hi @fmassa, I see a precision loss when training a swin-t model. Did you test it?

Hi, can you describe more precisely your setup? (GPU, head dimension, data type, options like causality)
Depending on this, it will be dispatched to different kernels

@lw921014

Hi, can you describe more precisely your setup? (GPU, head dimension, data type, options like causality)

GPU: A100
head dim: 32
data type: fp16
seq len: 49
head size: 6
causality: false

* Fast again on V100

* Fix correctness - missing syncthreads

* Get rid of AttentionInfo

Co-authored-by: danthe3rd <danthe3rd@users.noreply.github.com>
@fmassa (Contributor Author) commented Aug 24, 2022

Hi @lw921014

For Swin-Transformer, you need to add a relative positional embedding, which requires adding an attention bias in the attention matrix, and this configuration is not exposed yet in the setup you mentioned.

If you remove the relative positional embedding, you'll indeed see a loss in accuracy

fmassa merged commit 5958b18 into main on Aug 25, 2022
fmassa deleted the pull_internal_changes branch on Aug 25, 2022 at 08:24