
[feat] cutlass FlashAttention bias+dropout support #587

Merged
merged 8 commits on Jan 18, 2023

Conversation

jfc4050 (Contributor) commented Dec 13, 2022

What does this PR do?

Adds support for attention bias, bias gradient, and dropout to CUTLASS FlashAttention.

i'm mostly new to CUDA programming and totally new to CUTLASS so please let me know if i'm doing anything that doesn't make sense, is slow, or is otherwise weird :)

one note: i noticed the CPU and pure CUDA implementations also support attention bias, but they expect the bias to be the same across queries. bias here is implemented to accept different values along rows.

TODOs

  • general optimization - in particular, storing bias gradients is quite slow (benchmarks below). i've done some light profiling/optimization so far, but there's hopefully room for more improvements
  • correctness + performance testing on SM80 - i only have a T4 (SM75) to work with, will try to find an A100 to do some testing with
  • [probably another PR] bias types like ALIBI could probably be implemented more efficiently without a tensor bias (computed on the fly instead of having to load a large matrix from global memory). starting with this as a more general solution (see the sketch after this list)
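
for context, a minimal sketch (mine, not from this PR; alibi_tensor_bias is a hypothetical helper) of materializing ALiBi as a tensor bias in the general form described above. the slope formula follows the ALiBi paper and assumes the head count is a power of two:

import torch

def alibi_tensor_bias(n_heads: int, seq_len: int) -> torch.Tensor:
    # geometric per-head slopes from the ALiBi paper (assumes n_heads is a power of two)
    slopes = torch.tensor([2.0 ** (-8.0 * (h + 1) / n_heads) for h in range(n_heads)])
    pos = torch.arange(seq_len)
    rel = pos[None, :] - pos[:, None]  # rel[i, j] = j - i, negative below the diagonal
    # (n_heads, seq_len, seq_len); entries above the diagonal would normally
    # be removed by a causal mask anyway
    return slopes[:, None, None] * rel[None, :, :]

a kernel that computed slopes[h] * (j - i) on the fly would avoid streaming this O(n_heads * seq_len^2) tensor from global memory, which is what the TODO above refers to.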

there's a bug in the backward benchmarks that prevents me from using them to test performance (it doesn't seem to be caused by my PR, as the same thing happens on main). will see if i can find a fix.

Traceback (most recent call last):
  File "/home/ubuntu/ws/xformers/xformers/benchmarks/benchmark_mem_eff_attention.py", line 343, in <module>
    benchmark_main_helper(benchmark_backward, CASES, min_run_time=min_run_time)
  File "/home/ubuntu/ws/xformers/xformers/benchmarks/utils.py", line 436, in benchmark_main_helper
    measurement = benchmark_object.blocked_autorange(
  File "/home/ubuntu/.conda/envs/xformers/lib/python3.10/site-packages/torch/utils/benchmark/utils/timer.py", line 394, in blocked_autorange
    number = self._estimate_block_size(min_run_time)
  File "/home/ubuntu/.conda/envs/xformers/lib/python3.10/site-packages/torch/utils/benchmark/utils/timer.py", line 311, in _estimate_block_size
    time_taken = self._timeit(number)
  File "/home/ubuntu/.conda/envs/xformers/lib/python3.10/site-packages/torch/utils/benchmark/utils/timer.py", line 256, in _timeit
    return max(self._timer.timeit(number), 1e-9)
  File "/home/ubuntu/.conda/envs/xformers/lib/python3.10/timeit.py", line 178, in timeit
    timing = self.inner(it, self.timer)
  File "<timeit-src>", line 6, in inner
  File "/home/ubuntu/.conda/envs/xformers/lib/python3.10/site-packages/torch/_tensor.py", line 396, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "/home/ubuntu/.conda/envs/xformers/lib/python3.10/site-packages/torch/autograd/__init__.py", line 173, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/home/ubuntu/.conda/envs/xformers/lib/python3.10/site-packages/torch/autograd/function.py", line 253, in apply
    return user_fn(self, *args)
  File "/home/ubuntu/.conda/envs/xformers/lib/python3.10/site-packages/torch/autograd/function.py", line 399, in wrapper
    outputs = fn(ctx, *args)
  File "/home/ubuntu/ws/xformers/xformers/ops/fmha/__init__.py", line 96, in backward
    query, key, value, out, lse, rng_state, attn_bias_tensor = ctx.saved_tensors
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [384, 1, 224]] is at version 1; expected version 0 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).
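
(for reference, this class of error can be reproduced standalone - a minimal example, unrelated to the benchmark code itself:)

import torch

x = torch.randn(4, requires_grad=True)
y = x.exp()        # exp() saves its output tensor for the backward pass
y.mul_(2)          # the in-place op bumps the saved tensor's version counter
y.sum().backward() # RuntimeError: one of the variables needed for gradient
                   # computation has been modified by an inplace operation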

but here are some results from the benchmarks i wrote for myself. the labels take the form:

(batch_sz, seq_len, n_heads, head_dim, dtype, bias_shape, bias_requires_grad, dropout_p)
[-------------------------------------------- attn --------------------------------------------]
                                                                        |  reference  |  cutlass
1 threads: -------------------------------------------------------------------------------------
      (8, 512, 64, 128, torch.float16, (512, 512, 512), False, 0.0)     |     12.4    |     7.4
      (8, 512, 64, 128, torch.float16, (512, 512, 512), False, 0.5)     |     15.4    |     9.0
      (8, 512, 64, 128, torch.float16, (512, 512, 512), True, 0.0)      |     12.5    |     7.5
      (8, 512, 64, 128, torch.float16, (512, 512, 512), True, 0.5)      |     15.5    |     9.1
      (8, 512, 64, 128, torch.float16, None, False, 0.0)                |     10.1    |     5.9
      (8, 512, 64, 128, torch.float16, None, False, 0.5)                |     12.6    |     7.4
      (8, 1024, 64, 128, torch.float16, (512, 1024, 1024), False, 0.0)  |     44.0    |    28.8
      (8, 1024, 64, 128, torch.float16, (512, 1024, 1024), False, 0.5)  |     55.0    |    34.7
      (8, 1024, 64, 128, torch.float16, (512, 1024, 1024), True, 0.0)   |     44.6    |    28.9
      (8, 1024, 64, 128, torch.float16, (512, 1024, 1024), True, 0.5)   |     55.0    |    34.8
      (8, 1024, 64, 128, torch.float16, None, False, 0.0)               |     36.5    |    22.2
      (8, 1024, 64, 128, torch.float16, None, False, 0.5)               |     45.2    |    28.5

Times are in milliseconds (ms).

[------------------------------------------ attn-bwd ------------------------------------------]
                                                                        |  reference  |  cutlass
1 threads: -------------------------------------------------------------------------------------
      (8, 512, 64, 128, torch.float16, (512, 512, 512), False, 0.0)     |     19.2    |    24.0
      (8, 512, 64, 128, torch.float16, (512, 512, 512), False, 0.5)     |     19.2    |    24.5
      (8, 512, 64, 128, torch.float16, (512, 512, 512), True, 0.0)      |     22.2    |    28.7
      (8, 512, 64, 128, torch.float16, (512, 512, 512), True, 0.5)      |     22.2    |    29.0
      (8, 512, 64, 128, torch.float16, None, False, 0.0)                |     19.3    |    22.7
      (8, 512, 64, 128, torch.float16, None, False, 0.5)                |     19.3    |    23.3
      (8, 1024, 64, 128, torch.float16, (512, 1024, 1024), False, 0.0)  |     62.6    |    90.9
      (8, 1024, 64, 128, torch.float16, (512, 1024, 1024), False, 0.5)  |     62.6    |    93.3
      (8, 1024, 64, 128, torch.float16, (512, 1024, 1024), True, 0.0)   |     74.5    |   109.5
      (8, 1024, 64, 128, torch.float16, (512, 1024, 1024), True, 0.5)   |     74.8    |   110.5
      (8, 1024, 64, 128, torch.float16, None, False, 0.0)               |     62.7    |    85.2
      (8, 1024, 64, 128, torch.float16, None, False, 0.5)               |     62.7    |    87.5
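
(my benchmark script isn't included here, but a single case from these tables could be reproduced with something along these lines - the (B, M, H, K) input layout and exact shapes are assumptions on my part:)

import torch
import torch.utils.benchmark as benchmark
import xformers.ops as xops

# hypothetical case: (8, 512, 64, 128, torch.float16, (512, 512, 512), False, 0.5)
B, M, H, K, p = 8, 512, 64, 128, 0.5
q, k, v = (torch.randn(B, M, H, K, device="cuda", dtype=torch.half) for _ in range(3))
bias = torch.randn(B * H, M, M, device="cuda", dtype=torch.half)  # (batch_sz * n_heads, M, N)

timer = benchmark.Timer(
    stmt="xops.memory_efficient_attention(q, k, v, attn_bias=bias, p=p)",
    globals=dict(xops=xops, q=q, k=k, v=v, bias=bias, p=p),
)
print(timer.blocked_autorange(min_run_time=1.0))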

Before submitting

  • [Y] Did you have fun?
    • Make sure you had fun coding 🙃
  • [Y] Did you read the contributor guideline?
  • [N] Was this discussed/approved via a Github issue? (no need for typos, doc improvements)
  • [N/A] Did you make sure to update the docs?
  • [Y] Did you write any new necessary tests?
  • [Y] Did you update the changelog? (if needed)

PR review

Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in Github issues there's a high chance it will not be merged.

facebook-github-bot (Contributor):

Hi @jfc4050!

Thank you for your pull request and welcome to our community.

Action Required

In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

Process

In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (eg your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with CLA signed. The tagging process may take up to 1 hour after signing. Please give it that time before contacting us about it.

If you have received this in error or have any questions, please contact us at cla@meta.com. Thanks!

danthe3rd (Contributor) commented Dec 14, 2022

Hi @jfc4050

First, that's a pretty impressive PR with significant changes, which must have taken a lot of effort :o
Having attention bias (and, to a lesser extent, dropout) is something we've wanted to implement for some time but didn't have time to work on, so this contribution is very welcome!
I'll try to take a look at this in the next few days - as it's a pretty big chunk.

The main things that will matter (in order of importance) for me would be:
(1) Is there any performance regression for the users who don't use a mask/dropout (mostly worried for V100/A100 - I'll run benchmarks there)
(2) The forward part has been upstreamed to CUTLASS, and @hwu36 is working on optimising it. It would be great to avoid too much divergence between both files; we need to figure out something.
(3) Ease of use (eg compile time + binary size): this doubles the total number of attention kernels with the additional Dropout specialisation
(4) Compatibility with older hardware (eg not just Sm75+)

Once again, thanks a lot for putting all of this work there!

EDIT: We seem to have some build errors on Windows that would need to be addressed (see CI)

danthe3rd (Contributor) left a comment

Did a first pass on the code - a few questions and a few comments :)
Overall looks really clean, and you made the effort to do things properly

Comment on lines 243 to 269
if (bias->dim() == 2) { // (n_queries, n_keys)
  TORCH_INTERNAL_ASSERT(bias->size(0) == M);
  TORCH_INTERNAL_ASSERT(bias->size(1) == N);

  ASSIGN_CHECK_OVERFLOW(p.bias_strideB, 0);
  ASSIGN_CHECK_OVERFLOW(p.bias_strideH, 0);
  ASSIGN_CHECK_OVERFLOW(p.bias_strideM, bias->stride(0));

  if (bias_requires_grad) {
    ASSIGN_CHECK_OVERFLOW(p.gB_strideB, 0);
    ASSIGN_CHECK_OVERFLOW(p.gB_strideH, 0);
    ASSIGN_CHECK_OVERFLOW(p.gB_strideM, grad_bias.stride(0));
  }
} else if (bias->dim() == 3) { // (batch_sz * n_heads, n_queries, n_keys)
  TORCH_INTERNAL_ASSERT(bias->size(0) == B * nH);
  TORCH_INTERNAL_ASSERT(bias->size(1) == M);
  TORCH_INTERNAL_ASSERT(bias->size(2) == N);

  ASSIGN_CHECK_OVERFLOW(p.bias_strideB, nH * bias->stride(0));
  ASSIGN_CHECK_OVERFLOW(p.bias_strideH, bias->stride(0));
  ASSIGN_CHECK_OVERFLOW(p.bias_strideM, bias->stride(1));

  if (bias_requires_grad) {
    ASSIGN_CHECK_OVERFLOW(p.gB_strideB, nH * grad_bias.stride(0));
    ASSIGN_CHECK_OVERFLOW(p.gB_strideH, grad_bias.stride(0));
    ASSIGN_CHECK_OVERFLOW(p.gB_strideM, grad_bias.stride(1));
  }
danthe3rd (Contributor):

nit: we could assume the bias always has dimension 4 for simplicity, the user could torch.expand it if necessary (which will set strides to 0)

jfc4050 (Contributor, Author):

awesome, will do

jfc4050 (Contributor, Author):

qq: the tests use 3 dims. would you prefer i change the tests (not sure if other implementations can handle 4 dims), or handle cases for 3 and 4?

danthe3rd (Contributor):

Let's only support dim=4 in the kernel, and expand properly in python before we call the C++ code.
Later we can simplify that to accept masks with or without head dimension (because we accept q/k/v in both BMHK and BMK shapes, so it makes sense to be coherent for the mask as well)
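
(a quick illustration of the torch.expand suggestion - a sketch, not code from this PR. broadcasting a 2-D bias up to 4 dims costs no extra memory, since the expanded dimensions get stride 0:)

import torch

B, H, M, N = 8, 64, 512, 512  # hypothetical sizes
bias = torch.randn(M, N, dtype=torch.half)

bias4 = bias.unsqueeze(0).unsqueeze(0).expand(B, H, M, N)
print(bias4.shape)    # torch.Size([8, 64, 512, 512])
print(bias4.stride()) # (0, 0, 512, 1) -- no copy; batch/head strides are 0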

jfc4050 (Contributor, Author):

ended up being simpler imo to do in the c++ layer. doing it in the python layer means we don't have easy access to the shape variables, and the bias gradient ends up with the wrong shape unless that gets reshaped too. it also messed with autograd a bit: the view from unsqueezing/expanding/reshaping doesn't require grad, so we'd have to either add it to the autograd graph or add another flag to the function signature to indicate that it should compute grad

you can see what it looks like in 296e6fa, lmk if you'd still prefer to do it in python layer and we can revert/redo

Comment on lines 294 to 308
if (bias->dim() == 2) { // (n_queries, n_keys)
  TORCH_CHECK(bias->size(0) == M);
  TORCH_CHECK(bias->size(1) == N);

  ASSIGN_CHECK_OVERFLOW(p.bias_strideB, 0);
  ASSIGN_CHECK_OVERFLOW(p.bias_strideH, 0);
  ASSIGN_CHECK_OVERFLOW(p.bias_strideM, bias->stride(0));
} else if (bias->dim() == 3) { // (batch_sz * n_heads, n_queries, n_keys)
  TORCH_CHECK(bias->size(0) == B * num_heads);
  TORCH_CHECK(bias->size(1) == M);
  TORCH_CHECK(bias->size(2) == N);

  ASSIGN_CHECK_OVERFLOW(p.bias_strideB, num_heads * bias->stride(0));
  ASSIGN_CHECK_OVERFLOW(p.bias_strideH, bias->stride(0));
  ASSIGN_CHECK_OVERFLOW(p.bias_strideM, bias->stride(1));
danthe3rd (Contributor):

same here, let's assume dim=4

typename DefaultGemm::Mma,
typename MatmulQK::AccumulatorSharedStorage>;

using DefaultMmaFromSmem = typename std::conditional<
danthe3rd (Contributor):

we should avoid using std::* in this code, as we want to upstream it to CUTLASS, where it needs to build with nvrtc. I believe there should be the same functionality in platform::conditional

jfc4050 (Contributor, Author):

oh got it. i'll check for other uses of std elsewhere

xformers/version.py (outdated; resolved)
jfc4050 (Contributor, Author) commented Dec 14, 2022

> you made the effort to do things properly

that's a relief to hear 😅

    ref = ref_attention(query, key, value, attn_bias, mask, p)
    assert_allclose(out, ref, atol=2e-4), f"{(out - ref).abs().max()}"

    num_trials = 1000
-   p_val_tol = 0.0001
+   p_val_tol = 1e-6
jfc4050 (Contributor, Author):

forgot to mention this. dropped p_val_tol from 1e-4 -> 1e-6. otherwise the cutlass dropout mask fails the second binomial test.

took a look at the results it was outputting and it doesn't seem outlandish or anything. for example this test

test_dropout[cutlassF-33-32-2-32-0.7-42]

fails with one of the 2048 elements of the mask ending up with 248/1000 keeps (p=0.7), resulting in a p-value of 3.9172e-05.

looks like the other implementation uses a new subsequence every 4 elements, which might have better independence guarantees, but it's unlikely to be as performant as the way the CUTLASS dropout implementation is done now

what's your take on this? maybe we can soften the constraint by requiring only a percentage of the elements to pass the test?
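
(for reference, the shape of the check as i read it from this thread - a sketch, not the actual xformers test, with torch.rand standing in for recovering the kernel's dropout mask:)

import torch
from scipy.stats import binomtest

num_trials, p, p_val_tol = 1000, 0.7, 1e-6  # numbers from the failing case above
n_elements = 2048

keep_counts = torch.zeros(n_elements, dtype=torch.long)
for seed in range(num_trials):
    torch.manual_seed(seed)
    keep_counts += (torch.rand(n_elements) > p).long()  # stand-in for the kernel's mask

# each element's keep count should look Binomial(num_trials, 1 - p); the failing
# element had 248/1000 keeps -> p-value ~3.9e-5, below the old 1e-4 tolerance
for count in keep_counts.tolist():
    assert binomtest(count, num_trials, 1 - p).pvalue > p_val_tol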

danthe3rd (Contributor):

cc @fmassa who implemented this test and the smallK kernel which also supports dropout

danthe3rd mentioned this pull request Dec 15, 2022
danthe3rd (Contributor) left a comment

EDIT: The errors were unrelated to your code - now I can run it on f32 properly

@@ -546,7 +546,7 @@ def test_logsumexp(op_device_dtype_B_Mq_Mkv_H_K_Kv):
    )

    _out, lse = xformers.ops.memory_efficient_attention_forward_requires_grad(
-       query, key, value
+       query, key, value, op=op
danthe3rd (Contributor):

Good catch!

@@ -88,6 +102,7 @@ struct AttentionKernel {
    scalar_t* query_ptr; // [num_queries, num_heads, head_dim]
    scalar_t* key_ptr; // [num_keys, num_heads, head_dim]
    scalar_t* value_ptr; // [num_keys, num_heads, head_dim_value]
+   scalar_t* attn_bias_ptr = nullptr; // [num_heads, num_queries, num_keys]
danthe3rd (Contributor):

(You might not know the answer, but) I'm wondering how much we could gain by having:
(1) The bias already in the right format (eg if you know you want to use this bias only with MHA, you could store it in a format easy to load from gmem directly, without having to go through shared-memory - this format would be different depending on the kernel running tho)
(2) A boolean mask or some datatype with even lower precision

jfc4050 (Contributor, Author):

  1. i looked briefly into loading the bias directly from gmem by using different threadmaps in the predicated tile iterator, rather than a different input format for the bias, but wasn't able to get the loaded fragment to match the accumulator fragment. i'm not too sure about having a different bias format - it could be tricky + require the user to understand internals, because the elements of the accumulator tile each thread ends up with, and the way they are ordered in the fragment, depend on architecture and MMA configuration. @hwu36 might know more

  2. i'd imagine this would be an improvement for use cases that don't need a floating-point bias: less memory traffic, and since each thread loads 128 bits at a time, fewer bits -> fewer loads. it would require more template specializations though
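
(on 2: today a boolean mask has to be materialized as an additive float bias before being handed to the kernel - a sketch:)

import torch

M, N = 512, 512  # hypothetical sizes
keep = torch.rand(M, N) > 0.5             # one bit of information per entry...
bias = torch.zeros(M, N, dtype=torch.half)
bias.masked_fill_(~keep, float("-inf"))   # ...stored and streamed as 16-bit floats

a dedicated boolean-mask path could cut that memory traffic substantially (up to 16x if the bits were packed), at the cost of the extra template specializations noted above.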

hwu36 commented Dec 15, 2022

Would you please let me know more about this bias? Is it after the first gemm or the second gemm? Is it a 2D matrix or a 1D vector? If it is a vector, does every row have different values, or every column?

The accumulator layout of one 1688 tensor core is like

t0 t0 t1 t1 t2 t2 t3 t3
t4 t4 t5 t5 t6 t6 t7 t7
t8 t8 t9 t9 t10 t10 t11 t11
...
...
t28 t28 t29 t29 t30 t30 t31 t31
t0 t0 t1 t1 t2 t2 t3 t3
t4 t4 t5 t5 t6 t6 t7 t7
t8 t8 t9 t9 t10 t10 t11 t11
...
...
t28 t28 t29 t29 t30 t30 t31 t31

Then you need to add different offsets such as threadblock offset and warp offset.

jfc4050 (Contributor, Author):

applied after the 1st GEMM (Q @ K.T), and it's a 2D matrix

hwu36:

if it is a 2D matrix, using shared memory may not be bad. if you load directly, you can only load 2 elements at a time (32 bits for fp16), which is not the most efficient use of the memory BW. if you transform through shared memory, you can load 128 bits of data at a time to fully use the memory BW

},
[&](int accum_m) {});
}

// Mask out last if causal
if (p.causal && p.num_keys - iter_key_start <= kKeysPerBlock) {
danthe3rd (Contributor):

Interesting - we can still have causal masking + custom mask. We would need to find a way to expose that properly in a follow-up PR

jfc4050 (Contributor, Author):

yep. i think that for users who need both additive bias and unidirectionality, using causal masking in addition to attention bias might be faster than just using attention bias with some of the values replaced with -inf, since the kernel can use the knowledge that it's causal to avoid some unnecessary compute. if that's true it might be unintuitive to some users though
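
(in user-facing terms, the two options being compared - a sketch assuming the tensor-bias support from this PR; LowerTriangularMask is the existing causal-mask type, imported here from xformers.ops.fmha.common as named in the benchmarks below:)

import torch
import xformers.ops as xops
from xformers.ops.fmha.common import LowerTriangularMask

B, M, H, K = 8, 512, 64, 128  # hypothetical sizes
q, k, v = (torch.randn(B, M, H, K, device="cuda", dtype=torch.half) for _ in range(3))

# (a) dedicated causal mask: the kernel can skip fully-masked key blocks
out_causal = xops.memory_efficient_attention(q, k, v, attn_bias=LowerTriangularMask())

# (b) causality folded into an additive tensor bias: every block is still computed,
#     with -inf added above the diagonal
neg_inf = torch.zeros(M, M, device="cuda", dtype=torch.half)
neg_inf.masked_fill_(torch.ones(M, M, device="cuda").triu(1).bool(), float("-inf"))
out_bias = xops.memory_efficient_attention(q, k, v, attn_bias=neg_inf)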

danthe3rd (Contributor) commented Dec 15, 2022

Benchmarks as of 0062b22:

A100 fw - maybe 5% slowdown on f32, similar perf on f16
[---------------- attention (attn_bias=<class 'NoneType'>) ----------------]                                                                                                                                                                                                                                
                                     |  pr587_0062  |    main    |   eager  
1 threads: -----------------------------------------------------------------
      f16 B=384, M=197, H=1, K=88    |      125.0   |     121.3  |     851.8
      f32 B=384, M=197, H=1, K=88    |      463.3   |     447.4  |     719.5
      f16 B=384, M=197, H=1, K=80    |      116.2   |     112.2  |     742.2
      f32 B=384, M=197, H=1, K=80    |      459.6   |     443.5  |     692.4
      f16 B=384, M=197, H=1, K=64    |       87.8   |      86.7  |     679.8
      f32 B=384, M=197, H=1, K=64    |      281.1   |     268.7  |     640.5
      f16 B=1024, M=197, H=1, K=88   |      314.1   |     306.6  |    2221.3
      f32 B=1024, M=197, H=1, K=88   |     1220.4   |    1168.2  |    1765.2
      f16 B=1024, M=197, H=1, K=80   |      294.5   |     285.5  |    1927.2
      f32 B=1024, M=197, H=1, K=80   |     1212.9   |    1160.4  |    1700.6
      f16 B=1024, M=197, H=1, K=64   |      212.5   |     208.7  |    1770.5
      f32 B=1024, M=197, H=1, K=64   |      689.2   |     661.7  |    1568.6
      f16 B=512, M=197, H=1, K=80    |      153.0   |     148.2  |     979.2
      f32 B=512, M=197, H=1, K=80    |      614.9   |     592.4  |     881.7
      f16 B=32, M=197, H=16, K=80    |      153.9   |     149.2  |    1064.0
      f32 B=32, M=197, H=16, K=80    |      618.0   |     594.5  |    1051.0
      f16 B=32, M=197, H=16, K=64    |      113.4   |     111.7  |     979.9
      f32 B=32, M=197, H=16, K=64    |      355.1   |     341.4  |     949.7
      f16 B=32, M=197, H=16, K=128   |      168.2   |     162.5  |    1717.6
      f32 B=32, M=197, H=16, K=128   |      683.9   |     661.8  |    1324.7
      f16 B=256, M=197, H=1, K=88    |       87.3   |      85.0  |     577.0
      f32 B=256, M=197, H=1, K=88    |      318.7   |     306.9  |     477.8
      f16 B=16, M=197, H=16, K=88    |       89.0   |      86.2  |     628.7
      f32 B=16, M=197, H=16, K=88    |      320.7   |     308.4  |     572.9
      f16 B=16, M=197, H=16, K=64    |       65.0   |      61.1  |     506.9
      f32 B=16, M=197, H=16, K=64    |      195.9   |     186.9  |     494.9
      f16 B=16, M=197, H=16, K=128   |       90.2   |      87.0  |     877.3
      f32 B=16, M=197, H=16, K=128   |      352.8   |     339.6  |     692.8
      f16 B=1, M=4096, H=160, K=128  |    15244.1   |   14853.7  |   21510.6
      f32 B=1, M=4096, H=160, K=128  |    57973.4   |   56691.3  |   91407.5
      f16 B=2, M=4096, H=160, K=128  |    30414.8   |   29633.5  |   43562.2
      f32 B=2, M=4096, H=160, K=128  |   115750.4   |  113237.0  |          
      f16 B=1, M=8192, H=160, K=128  |    60840.4   |   59240.6  |          
      f32 B=1, M=8192, H=160, K=128  |   232306.8   |  226792.5  |          
      f16 B=2, M=8192, H=160, K=128  |   121618.1   |  118414.6  |          
      f32 B=2, M=8192, H=160, K=128  |   465145.6   |  453321.1  |          
      f16 B=1024, M=82, H=8, K=64    |      476.0   |     446.3  |    1785.2
      f32 B=1024, M=82, H=8, K=64    |     1430.9   |    1314.3  |    3753.8
      f16 B=150, M=256, H=16, K=64   |      512.3   |     503.1  |    1964.8
      f32 B=150, M=256, H=16, K=64   |     1691.3   |    1602.7  |    5294.3
      f16 B=64, M=256, H=12, K=64    |      172.2   |     169.8  |     664.2
      f32 B=64, M=256, H=12, K=64    |      559.7   |     533.0  |    1760.4
      f16 B=1, M=4096, H=16, K=40    |      857.3   |     870.6  |    1958.3
      f32 B=1, M=4096, H=16, K=40    |     2865.9   |    2765.6  |    6914.2
      f16 B=1, M=16384, H=16, K=40   |    12168.3   |   12346.3  |   30460.6
      f32 B=1, M=16384, H=16, K=40   |    41930.8   |   40533.2  |  123282.2
      f16 B=256, M=4096, H=16, K=64  |   181232.1   |  183488.9  |          
      f32 B=256, M=4096, H=16, K=64  |   665659.8   |  642960.3  |          
      f16 B=16, M=128, H=16, K=16    |       57.9   |      60.4  |     155.3
      f32 B=16, M=128, H=16, K=16    |       58.5   |      60.7  |     151.5
      f16 B=16, M=128, H=16, K=32    |       57.8   |      60.7  |     154.5
      f32 B=16, M=128, H=16, K=32    |       58.7   |      60.8  |     175.9
      f16 B=16, M=128, H=16, K=64    |       57.9   |      60.6  |     155.4
      f32 B=16, M=128, H=16, K=64    |       66.0   |      62.2  |     217.8
      f16 B=16, M=128, H=16, K=128   |       58.1   |      60.2  |     155.4
      f32 B=16, M=128, H=16, K=128   |      124.1   |     117.2  |     311.9
      f16 B=16, M=128, H=16, K=256   |       78.6   |      76.5  |     253.0
      f32 B=16, M=128, H=16, K=256   |      224.9   |     210.4  |     534.7
      f16 B=16, M=512, H=16, K=16    |      174.8   |     173.5  |     516.1
      f32 B=16, M=512, H=16, K=16    |      576.9   |     553.6  |    1629.1
      f16 B=16, M=512, H=16, K=32    |      182.5   |     180.5  |     564.7
      f32 B=16, M=512, H=16, K=32    |      586.2   |     562.4  |    1781.3
      f16 B=16, M=512, H=16, K=64    |      209.7   |     207.9  |     673.7
      f32 B=16, M=512, H=16, K=64    |      709.2   |     678.4  |    2107.5
      f16 B=16, M=512, H=16, K=128   |      363.0   |     353.8  |     856.0
      f32 B=16, M=512, H=16, K=128   |     1539.8   |    1493.4  |    2760.0
      f16 B=16, M=512, H=16, K=256   |      819.3   |     812.9  |    1230.8
      f32 B=16, M=512, H=16, K=256   |     3121.5   |    2984.7  |    4922.2
      f16 B=16, M=1024, H=16, K=16   |      667.3   |     662.9  |    1857.2
      f32 B=16, M=1024, H=16, K=16   |     2209.8   |    2130.1  |    6088.4
      f16 B=16, M=1024, H=16, K=32   |      673.6   |     669.0  |    1951.1
      f32 B=16, M=1024, H=16, K=32   |     2233.2   |    2150.0  |    6591.0
      f16 B=16, M=1024, H=16, K=64   |      762.9   |     766.7  |    2192.7
      f32 B=16, M=1024, H=16, K=64   |     2701.4   |    2597.3  |    7687.2
      f16 B=16, M=1024, H=16, K=128  |     1346.0   |    1311.5  |    2618.4
      f32 B=16, M=1024, H=16, K=128  |     5926.2   |    5772.7  |    9881.6
      f16 B=16, M=1024, H=16, K=256  |     3120.2   |    3096.7  |    3436.4
      f32 B=16, M=1024, H=16, K=256  |    12178.1   |   11659.9  |   17896.5
      f16 B=64, M=128, H=16, K=16    |       58.2   |      61.0  |     203.5
      f32 B=64, M=128, H=16, K=16    |      166.8   |     157.7  |     489.0
      f16 B=64, M=128, H=16, K=32    |       59.3   |      61.2  |     250.3
      f32 B=64, M=128, H=16, K=32    |      174.7   |     165.0  |     574.8
      f16 B=64, M=128, H=16, K=64    |       74.9   |      73.2  |     349.0
      f32 B=64, M=128, H=16, K=64    |      215.3   |     203.6  |     737.7
      f16 B=64, M=128, H=16, K=128   |      131.5   |     129.1  |     534.2
      f32 B=64, M=128, H=16, K=128   |      449.1   |     427.7  |    1070.0
      f16 B=64, M=128, H=16, K=256   |      257.1   |     253.4  |     914.4
      f32 B=64, M=128, H=16, K=256   |      833.6   |     791.8  |    1921.3
      f16 B=64, M=512, H=16, K=16    |      683.6   |     676.5  |    1893.0
      f32 B=64, M=512, H=16, K=16    |     2246.6   |    2161.4  |    6273.1
      f16 B=64, M=512, H=16, K=32    |      693.6   |     684.8  |    2096.1
      f32 B=64, M=512, H=16, K=32    |     2279.6   |    2190.8  |    6822.3
      f16 B=64, M=512, H=16, K=64    |      787.5   |     787.1  |    2495.9
      f32 B=64, M=512, H=16, K=64    |     2748.7   |    2635.4  |    8108.3
      f16 B=64, M=512, H=16, K=128   |     1416.0   |    1380.4  |    3264.1
      f32 B=64, M=512, H=16, K=128   |     6080.3   |    5895.7  |   10731.7
      f16 B=64, M=512, H=16, K=256   |     3214.8   |    3181.7  |    4760.3
      f32 B=64, M=512, H=16, K=256   |    12433.4   |   11825.6  |   19614.2
      f16 B=64, M=1024, H=16, K=16   |     2611.2   |    2593.0  |    7293.8
      f32 B=64, M=1024, H=16, K=16   |     8726.2   |    8424.4  |   24237.0
      f16 B=64, M=1024, H=16, K=32   |     2632.2   |    2612.0  |    7675.3
      f32 B=64, M=1024, H=16, K=32   |     8812.5   |    8494.5  |   26228.7
      f16 B=64, M=1024, H=16, K=64   |     2980.8   |    2991.8  |    8669.3
      f32 B=64, M=1024, H=16, K=64   |    10665.4   |   10271.7  |   30612.5
      f16 B=64, M=1024, H=16, K=128  |     5393.5   |    5266.6  |   10313.9
      f32 B=64, M=1024, H=16, K=128  |    23526.3   |   22940.7  |   39341.6
      f16 B=64, M=1024, H=16, K=256  |    12266.5   |   12354.8  |   13558.9
      f32 B=64, M=1024, H=16, K=256  |    48552.5   |   46488.2  |   71460.8

Times are in microseconds (us).

[ attention (attn_bias=<class 'xformers.ops.fmha.common.LowerTriangularMask'>) ]
                                     |  pr587_0062  |    main    |   eager  
1 threads: -----------------------------------------------------------------
      f16 B=384, M=197, H=1, K=88    |       95.1   |      92.0  |     918.3
      f32 B=384, M=197, H=1, K=88    |      336.6   |     321.6  |     779.4
      f16 B=384, M=197, H=1, K=80    |       90.4   |      86.8  |     812.3
      f32 B=384, M=197, H=1, K=80    |      333.1   |     318.7  |     752.1
      f16 B=384, M=197, H=1, K=64    |       67.2   |      65.3  |     749.9
      f32 B=384, M=197, H=1, K=64    |      203.3   |     191.7  |     705.0
      f16 B=1024, M=197, H=1, K=88   |      231.4   |     223.7  |    2389.0
      f32 B=1024, M=197, H=1, K=88   |      855.0   |     818.3  |    1907.8
      f16 B=1024, M=197, H=1, K=80   |      219.2   |     211.0  |    2105.1
      f32 B=1024, M=197, H=1, K=80   |      846.8   |     810.7  |    1843.8
      f16 B=1024, M=197, H=1, K=64   |      153.0   |     149.8  |    1945.4
      f32 B=1024, M=197, H=1, K=64   |      484.7   |     457.8  |    1719.9
      f16 B=512, M=197, H=1, K=80    |      116.2   |     111.6  |    1071.6
      f32 B=512, M=197, H=1, K=80    |      435.9   |     417.0  |     961.4
      f16 B=32, M=197, H=16, K=80    |      116.6   |     112.3  |    1154.2
      f32 B=32, M=197, H=16, K=80    |      437.2   |     418.4  |    1131.0
      f16 B=32, M=197, H=16, K=64    |       85.3   |      83.3  |    1068.3
      f32 B=32, M=197, H=16, K=64    |      260.1   |     245.5  |    1027.6
      f16 B=32, M=197, H=16, K=128   |      127.3   |     123.0  |    1802.6
      f32 B=32, M=197, H=16, K=128   |      490.9   |     467.4  |    1403.9
      f16 B=256, M=197, H=1, K=88    |       67.6   |      65.3  |     622.5
      f32 B=256, M=197, H=1, K=88    |      232.2   |     222.5  |     523.1
      f16 B=16, M=197, H=16, K=88    |       67.8   |      65.6  |     674.0
      f32 B=16, M=197, H=16, K=88    |      233.0   |     223.4  |     618.4
      f16 B=16, M=197, H=16, K=64    |       64.5   |      60.7  |     552.6
      f32 B=16, M=197, H=16, K=64    |      147.1   |     139.1  |     542.7
      f16 B=16, M=197, H=16, K=128   |       70.7   |      68.2  |     923.9
      f32 B=16, M=197, H=16, K=128   |      258.8   |     246.7  |     737.2
      f16 B=1, M=4096, H=160, K=128  |     7808.8   |    7595.2  |   38531.2
      f32 B=1, M=4096, H=160, K=128  |    30015.9   |   29331.8  |  109017.4
      f16 B=2, M=4096, H=160, K=128  |    15492.1   |   15094.0  |   78631.6
      f32 B=2, M=4096, H=160, K=128  |    59694.2   |   58316.8  |          
      f16 B=1, M=8192, H=160, K=128  |    30803.3   |   29983.0  |          
      f32 B=1, M=8192, H=160, K=128  |   117655.3   |  115038.2  |          
      f16 B=2, M=8192, H=160, K=128  |    61373.3   |   59734.6  |          
      f32 B=2, M=8192, H=160, K=128  |   234708.5   |  229411.1  |          
      f16 B=1024, M=82, H=8, K=64    |      381.5   |     370.1  |    1997.2
      f32 B=1024, M=82, H=8, K=64    |     1145.7   |    1063.2  |    3984.2
      f16 B=150, M=256, H=16, K=64   |      358.8   |     352.6  |    2604.2
      f32 B=150, M=256, H=16, K=64   |     1139.8   |    1071.0  |    5906.0
      f16 B=64, M=256, H=12, K=64    |      124.8   |     122.8  |     873.6
      f32 B=64, M=256, H=12, K=64    |      388.6   |     365.7  |    1963.7
      f16 B=1, M=4096, H=16, K=40    |      529.2   |     538.7  |    3768.9
      f32 B=1, M=4096, H=16, K=40    |     1699.7   |    1652.4  |    8888.6
      f16 B=1, M=16384, H=16, K=40   |     6541.3   |    6637.7  |   59019.0
      f32 B=1, M=16384, H=16, K=40   |    22413.7   |   21682.4  |          
      f16 B=256, M=4096, H=16, K=64  |    93055.1   |   94157.2  |          
      f32 B=256, M=4096, H=16, K=64  |   340536.6   |  328323.8  |          
      f16 B=16, M=128, H=16, K=16    |       57.8   |      60.9  |     161.1
      f32 B=16, M=128, H=16, K=16    |       58.0   |      61.0  |     171.2
      f16 B=16, M=128, H=16, K=32    |       57.9   |      60.7  |     160.9
      f32 B=16, M=128, H=16, K=32    |       58.7   |      60.6  |     192.6
      f16 B=16, M=128, H=16, K=64    |       57.6   |      60.8  |     161.7
      f32 B=16, M=128, H=16, K=64    |       63.6   |      61.0  |     244.1
      f16 B=16, M=128, H=16, K=128   |       57.8   |      60.4  |     166.6
      f32 B=16, M=128, H=16, K=128   |      114.1   |     108.0  |     341.2
      f16 B=16, M=128, H=16, K=256   |       70.1   |      68.1  |     283.8
      f32 B=16, M=128, H=16, K=256   |      201.6   |     190.6  |     562.6
      f16 B=16, M=512, H=16, K=16    |      117.4   |     115.7  |     824.6
      f32 B=16, M=512, H=16, K=16    |      368.9   |     352.6  |    2052.2
      f16 B=16, M=512, H=16, K=32    |      122.3   |     120.0  |     869.4
      f32 B=16, M=512, H=16, K=32    |      377.7   |     360.1  |    2129.5
      f16 B=16, M=512, H=16, K=64    |      140.4   |     139.4  |     961.3
      f32 B=16, M=512, H=16, K=64    |      454.7   |     433.3  |    2362.8
      f16 B=16, M=512, H=16, K=128   |      238.1   |     230.7  |    1131.0
      f32 B=16, M=512, H=16, K=128   |      971.1   |     936.1  |    2994.9
      f16 B=16, M=512, H=16, K=256   |      514.0   |     509.1  |    1503.4
      f32 B=16, M=512, H=16, K=256   |     1971.1   |    1875.0  |    5159.1
      f16 B=16, M=1024, H=16, K=16   |      389.8   |     385.9  |    3008.8
      f32 B=16, M=1024, H=16, K=16   |     1259.5   |    1211.4  |    8011.4
      f16 B=16, M=1024, H=16, K=32   |      394.5   |     390.2  |    3109.4
      f32 B=16, M=1024, H=16, K=32   |     1275.1   |    1224.9  |    8143.3
      f16 B=16, M=1024, H=16, K=64   |      447.3   |     448.2  |    3304.5
      f32 B=16, M=1024, H=16, K=64   |     1540.9   |    1477.3  |    8691.8
      f16 B=16, M=1024, H=16, K=128  |      788.6   |     765.3  |    3685.4
      f32 B=16, M=1024, H=16, K=128  |     3354.7   |    3257.0  |   10811.9
      f16 B=16, M=1024, H=16, K=256  |     1761.0   |    1749.4  |    4464.5
      f32 B=16, M=1024, H=16, K=256  |     6908.9   |    6593.8  |   18779.5
      f16 B=64, M=128, H=16, K=16    |       57.0   |      60.8  |     277.1
      f32 B=64, M=128, H=16, K=16    |      138.8   |     130.5  |     600.0
      f16 B=64, M=128, H=16, K=32    |       57.7   |      60.9  |     332.3
      f32 B=64, M=128, H=16, K=32    |      146.7   |     138.2  |     672.4
      f16 B=64, M=128, H=16, K=64    |       67.7   |      66.1  |     426.4
      f32 B=64, M=128, H=16, K=64    |      179.3   |     169.3  |     819.9
      f16 B=64, M=128, H=16, K=128   |      117.1   |     114.5  |     607.5
      f32 B=64, M=128, H=16, K=128   |      404.7   |     383.1  |    1139.5
      f16 B=64, M=128, H=16, K=256   |      232.0   |     228.8  |     988.8
      f32 B=64, M=128, H=16, K=256   |      734.4   |     695.7  |    1991.6
      f16 B=64, M=512, H=16, K=16    |      419.4   |     412.5  |    3068.8
      f32 B=64, M=512, H=16, K=16    |     1332.0   |    1274.7  |    7904.7
      f16 B=64, M=512, H=16, K=32    |      426.5   |     419.1  |    3256.0
      f32 B=64, M=512, H=16, K=32    |     1360.2   |    1297.8  |    8190.3
      f16 B=64, M=512, H=16, K=64    |      490.8   |     485.2  |    3628.3
      f32 B=64, M=512, H=16, K=64    |     1638.6   |    1560.6  |    9073.3
      f16 B=64, M=512, H=16, K=128   |      898.9   |     873.2  |    4331.2
      f32 B=64, M=512, H=16, K=128   |     3732.3   |    3597.2  |   11582.6
      f16 B=64, M=512, H=16, K=256   |     1968.1   |    1955.4  |    5832.9
      f32 B=64, M=512, H=16, K=256   |     7635.7   |    7274.0  |   20462.0
      f16 B=64, M=1024, H=16, K=16   |     1455.2   |    1441.0  |   11836.9
      f32 B=64, M=1024, H=16, K=16   |     4773.1   |    4594.4  |   31900.4
      f16 B=64, M=1024, H=16, K=32   |     1471.2   |    1454.8  |   12237.2
      f32 B=64, M=1024, H=16, K=32   |     4834.2   |    4644.8  |   32409.2
      f16 B=64, M=1024, H=16, K=64   |     1666.1   |    1668.8  |   13076.3
      f32 B=64, M=1024, H=16, K=64   |     5841.9   |    5601.8  |   34561.1
      f16 B=64, M=1024, H=16, K=128  |     3077.4   |    2981.2  |   14547.6
      f32 B=64, M=1024, H=16, K=128  |    13149.3   |   12773.4  |   42996.6
      f16 B=64, M=1024, H=16, K=256  |     6988.3   |    6942.8  |   17618.8
      f32 B=64, M=1024, H=16, K=256  |    27189.1   |   25945.4  |   74993.3

Times are in microseconds (us).
A100 bw - roughly equivalent, somehow faster now with `causal=True` on f16 (?)
[------------ attention backward (attn_bias=<class 'NoneType'>) ------------]                                                                                                                                                                                                                               
                                     |  pr587_0062  |     main    |  vanilla 
1 threads: ------------------------------------------------------------------
      f16 B=384, M=197, H=1, K=88    |      716.7   |      651.2  |    2260.9
      f32 B=384, M=197, H=1, K=88    |     2371.8   |     2330.6  |    1841.9
      f16 B=384, M=197, H=1, K=80    |      688.7   |      621.9  |    1916.9
      f32 B=384, M=197, H=1, K=80    |     2263.2   |     2229.2  |    1785.9
      f16 B=384, M=197, H=1, K=64    |      423.7   |      459.5  |    1808.0
      f32 B=384, M=197, H=1, K=64    |     1282.9   |     1262.5  |    1673.1
      f16 B=1024, M=197, H=1, K=88   |     1819.1   |     1609.8  |    5941.3
      f32 B=1024, M=197, H=1, K=88   |     6129.9   |     6051.2  |    4553.4
      f16 B=1024, M=197, H=1, K=80   |     1731.1   |     1536.3  |    5022.0
      f32 B=1024, M=197, H=1, K=80   |     5850.0   |     5778.6  |    4405.0
      f16 B=1024, M=197, H=1, K=64   |      965.3   |     1037.3  |    4732.1
      f32 B=1024, M=197, H=1, K=64   |     3345.0   |     3295.5  |    4112.7
      f16 B=512, M=197, H=1, K=80    |      876.9   |      785.1  |    2533.9
      f32 B=512, M=197, H=1, K=80    |     2899.7   |     2857.5  |    2283.9
      f16 B=32, M=197, H=16, K=80    |      875.9   |      787.2  |    2568.0
      f32 B=32, M=197, H=16, K=80    |     2895.1   |     2834.4  |    2351.5
      f16 B=32, M=197, H=16, K=64    |      496.2   |      538.7  |    2430.3
      f32 B=32, M=197, H=16, K=64    |     1821.3   |     1777.6  |    2195.1
      f16 B=32, M=197, H=16, K=128   |     1035.3   |      928.2  |    4486.7
      f32 B=32, M=197, H=16, K=128   |     3596.3   |     3544.5  |    2803.4
      f16 B=256, M=197, H=1, K=88    |      515.3   |      477.7  |    1521.7
      f32 B=256, M=197, H=1, K=88    |     1700.3   |     1675.8  |    1206.7
      f16 B=16, M=197, H=16, K=88    |      513.2   |      473.4  |    1539.3
      f32 B=16, M=197, H=16, K=88    |     1691.6   |     1664.6  |    1249.0
      f16 B=16, M=197, H=16, K=64    |      253.0   |      276.0  |    1242.9
      f32 B=16, M=197, H=16, K=64    |     1075.4   |     1060.9  |    1124.1
      f16 B=16, M=197, H=16, K=128   |      575.2   |      526.5  |    2266.7
      f32 B=16, M=197, H=16, K=128   |     1961.6   |     1935.8  |    1444.7
      f16 B=1, M=4096, H=160, K=128  |    62894.9   |    67019.0  |   46384.8
      f32 B=1, M=4096, H=160, K=128  |   237534.2   |   222376.4  |          
      f16 B=2, M=4096, H=160, K=128  |   106164.4   |   110240.7  |          
      f32 B=2, M=4096, H=160, K=128  |   374929.4   |   351572.8  |          
      f16 B=1, M=8192, H=160, K=128  |   245856.4   |   267465.8  |          
      f32 B=1, M=8192, H=160, K=128  |   942689.4   |   881885.6  |          
      f16 B=2, M=8192, H=160, K=128  |   419550.7   |   433848.1  |          
      f32 B=2, M=8192, H=160, K=128  |  1490911.6   |  1398395.3  |          
      f16 B=1024, M=82, H=8, K=64    |     2039.9   |     2111.2  |    3823.5
      f32 B=1024, M=82, H=8, K=64    |     8516.3   |     8376.5  |    8720.2
      f16 B=150, M=256, H=16, K=64   |     2341.4   |     2537.9  |    4560.5
      f32 B=150, M=256, H=16, K=64   |     6266.0   |     6269.7  |   12921.7
      f16 B=64, M=256, H=12, K=64    |      794.6   |      875.9  |    1499.4
      f32 B=64, M=256, H=12, K=64    |     2149.5   |     2153.6  |    4260.9
      f16 B=1, M=4096, H=16, K=40    |    23841.0   |    25712.6  |    4235.3
      f32 B=1, M=4096, H=16, K=40    |    73752.6   |    73180.5  |   17706.1
      f16 B=1, M=16384, H=16, K=40   |   397392.7   |   430370.9  |          
      f32 B=1, M=16384, H=16, K=40   |  1197343.6   |  1187422.5  |          
      f16 B=256, M=4096, H=16, K=64  |   742700.6   |   801632.2  |          
      f16 B=16, M=128, H=16, K=16    |      207.9   |      189.3  |     306.7
      f32 B=16, M=128, H=16, K=16    |      248.0   |      231.0  |     373.0
      f16 B=16, M=128, H=16, K=32    |      203.0   |      182.4  |     302.1
      f32 B=16, M=128, H=16, K=32    |      246.3   |      226.6  |     413.2
      f16 B=16, M=128, H=16, K=64    |      202.8   |      182.9  |     301.7
      f32 B=16, M=128, H=16, K=64    |      277.7   |      273.2  |     499.2
      f16 B=16, M=128, H=16, K=128   |      200.5   |      209.6  |     304.6
      f32 B=16, M=128, H=16, K=128   |      510.1   |      488.3  |     672.2
      f16 B=16, M=128, H=16, K=256   |      786.2   |      777.0  |     544.9
      f32 B=16, M=128, H=16, K=256   |      974.6   |      937.4  |    1162.6
      f16 B=16, M=512, H=16, K=16    |      640.5   |      713.5  |    1203.7
      f32 B=16, M=512, H=16, K=16    |     2173.5   |     2150.9  |    4409.0
      f16 B=16, M=512, H=16, K=32    |      723.6   |      805.8  |    1306.9
      f32 B=16, M=512, H=16, K=32    |     2354.7   |     2343.7  |    4633.5
      f16 B=16, M=512, H=16, K=64    |      927.5   |     1019.2  |    1544.1
      f32 B=16, M=512, H=16, K=64    |     2990.9   |     2981.1  |    5115.9
      f16 B=16, M=512, H=16, K=128   |     1842.4   |     1958.5  |    1984.9
      f32 B=16, M=512, H=16, K=128   |     6131.2   |     5800.4  |    6086.4
      f16 B=16, M=512, H=16, K=256   |     8430.5   |     8490.1  |    2902.9
      f32 B=16, M=512, H=16, K=256   |    11834.2   |    11313.2  |   10617.2
      f16 B=16, M=1024, H=16, K=16   |     2477.5   |     2809.0  |    4262.6
      f32 B=16, M=1024, H=16, K=16   |     8526.8   |     8520.4  |   16608.1
      f16 B=16, M=1024, H=16, K=32   |     2736.0   |     3086.4  |    4485.7
      f32 B=16, M=1024, H=16, K=32   |     9032.9   |     9040.3  |   17262.9
      f16 B=16, M=1024, H=16, K=64   |     3361.6   |     3721.9  |    4991.7
      f32 B=16, M=1024, H=16, K=64   |    11625.7   |    11677.5  |   18670.5
      f16 B=16, M=1024, H=16, K=128  |     6566.2   |     7003.0  |    5949.2
      f32 B=16, M=1024, H=16, K=128  |    23315.6   |    21954.0  |   21480.0
      f16 B=16, M=1024, H=16, K=256  |    31674.1   |    32062.5  |    7897.9
      f32 B=16, M=1024, H=16, K=256  |    45039.1   |    42840.9  |   37951.9
      f16 B=64, M=128, H=16, K=16    |      200.6   |      184.8  |     439.3
      f32 B=64, M=128, H=16, K=16    |      497.2   |      495.2  |    1268.7
      f16 B=64, M=128, H=16, K=32    |      262.4   |      241.0  |     545.3
      f32 B=64, M=128, H=16, K=32    |      604.9   |      603.1  |    1425.5
      f16 B=64, M=128, H=16, K=64    |      334.6   |      369.2  |     767.2
      f32 B=64, M=128, H=16, K=64    |      873.3   |      871.9  |    1743.4
      f16 B=64, M=128, H=16, K=128   |      698.0   |      723.9  |    1228.2
      f32 B=64, M=128, H=16, K=128   |     1771.2   |     1699.5  |    2383.6
      f16 B=64, M=128, H=16, K=256   |     2850.2   |     2888.7  |    2129.9
      f32 B=64, M=128, H=16, K=256   |     3415.2   |     3289.3  |    4314.9
      f16 B=64, M=512, H=16, K=16    |     2385.3   |     2629.4  |    4486.1
      f32 B=64, M=512, H=16, K=16    |     6698.9   |     6719.5  |   16963.1
      f16 B=64, M=512, H=16, K=32    |     2751.9   |     3005.5  |    4975.9
      f32 B=64, M=512, H=16, K=32    |     7491.6   |     7497.9  |   17823.2
      f16 B=64, M=512, H=16, K=64    |     3533.0   |     3876.7  |    5893.6
      f32 B=64, M=512, H=16, K=64    |     9617.1   |     9634.2  |   19731.2
      f16 B=64, M=512, H=16, K=128   |     6635.8   |     6871.8  |    7707.9
      f32 B=64, M=512, H=16, K=128   |    21317.5   |    20087.6  |   23584.0
      f16 B=64, M=512, H=16, K=256   |    31162.9   |    30844.5  |   11501.6
      f32 B=64, M=512, H=16, K=256   |    40918.1   |    38994.0  |   42386.4
      f16 B=64, M=1024, H=16, K=16   |     9388.2   |    10399.8  |   16846.5
      f32 B=64, M=1024, H=16, K=16   |    26568.8   |    26744.9  |   66205.9
      f16 B=64, M=1024, H=16, K=32   |    10683.5   |    11750.7  |   17866.1
      f32 B=64, M=1024, H=16, K=32   |    28430.2   |    28477.5  |   68832.9
      f16 B=64, M=1024, H=16, K=64   |    13117.1   |    14436.5  |   19915.5
      f32 B=64, M=1024, H=16, K=64   |    35834.8   |    35988.5  |   74463.8
      f16 B=64, M=1024, H=16, K=128  |    23610.9   |    24519.0  |   23742.3
      f32 B=64, M=1024, H=16, K=128  |    80716.8   |    75406.6  |   85733.5
      f16 B=64, M=1024, H=16, K=256  |   114888.4   |   115626.1  |   32765.2
      f32 B=64, M=1024, H=16, K=256  |   155081.1   |   147906.3  |  152428.4

Times are in microseconds (us).

[ attention backward (attn_bias=<class 'xformers.ops.fmha.common.LowerTriangularMask'>) ]
                                     |  pr587_0062  |    main    |  vanilla 
1 threads: -----------------------------------------------------------------
      f16 B=384, M=197, H=1, K=88    |      565.6   |     527.5  |    2261.1
      f32 B=384, M=197, H=1, K=88    |     1853.5   |    1791.2  |    1841.0
      f16 B=384, M=197, H=1, K=80    |      538.6   |     501.6  |    1915.5
      f32 B=384, M=197, H=1, K=80    |     1787.3   |    1722.7  |    1786.5
      f16 B=384, M=197, H=1, K=64    |      284.1   |     325.0  |    1810.7
      f32 B=384, M=197, H=1, K=64    |      979.8   |     974.5  |    1674.6
      f16 B=1024, M=197, H=1, K=88   |     1425.0   |    1302.1  |    5939.6
      f32 B=1024, M=197, H=1, K=88   |     4696.6   |    4595.7  |    4552.5
      f16 B=1024, M=197, H=1, K=80   |     1360.8   |    1237.0  |    5019.9
      f32 B=1024, M=197, H=1, K=80   |     4521.6   |    4435.3  |    4406.9
      f16 B=1024, M=197, H=1, K=64   |      645.9   |     725.8  |    4732.0
      f32 B=1024, M=197, H=1, K=64   |     2605.8   |    2548.6  |    4112.1
      f16 B=512, M=197, H=1, K=80    |      688.9   |     634.7  |    2535.1
      f32 B=512, M=197, H=1, K=80    |     2234.7   |    2198.3  |    2283.7
      f16 B=32, M=197, H=16, K=80    |      695.4   |     640.3  |    2570.0
      f32 B=32, M=197, H=16, K=80    |     2227.4   |    2196.0  |    2351.8
      f16 B=32, M=197, H=16, K=64    |      339.2   |     378.3  |    2428.0
      f32 B=32, M=197, H=16, K=64    |     1330.9   |    1371.3  |    2193.7
      f16 B=32, M=197, H=16, K=128   |      832.6   |     765.2  |    4489.1
      f32 B=32, M=197, H=16, K=128   |     2724.8   |    2684.6  |    2802.6
      f16 B=256, M=197, H=1, K=88    |      406.9   |     386.0  |    1526.5
      f32 B=256, M=197, H=1, K=88    |     1315.9   |    1291.3  |    1210.1
      f16 B=16, M=197, H=16, K=88    |      407.3   |     385.3  |    1539.7
      f32 B=16, M=197, H=16, K=88    |     1307.9   |    1282.6  |    1251.5
      f16 B=16, M=197, H=16, K=64    |      200.6   |     192.4  |    1243.9
      f32 B=16, M=197, H=16, K=64    |      812.8   |     809.9  |    1126.4
      f16 B=16, M=197, H=16, K=128   |      460.6   |     431.8  |    2268.8
      f32 B=16, M=197, H=16, K=128   |     1547.5   |    1484.0  |    1445.7
      f16 B=1, M=4096, H=160, K=128  |    33562.5   |   36002.6  |   46369.0
      f32 B=1, M=4096, H=160, K=128  |   123681.2   |  117500.2  |          
      f16 B=2, M=4096, H=160, K=128  |    56569.5   |   58882.2  |          
      f32 B=2, M=4096, H=160, K=128  |   196332.7   |  185992.3  |          
      f16 B=1, M=8192, H=160, K=128  |   128776.5   |  138568.8  |          
      f32 B=1, M=8192, H=160, K=128  |   482544.1   |  455362.6  |          
      f16 B=2, M=8192, H=160, K=128  |   217573.6   |  225203.4  |          
      f32 B=2, M=8192, H=160, K=128  |   763604.3   |  722918.1  |          
      f16 B=1024, M=82, H=8, K=64    |     1653.0   |    1726.6  |    3822.7
      f32 B=1024, M=82, H=8, K=64    |     7709.0   |    7623.2  |    8710.4
      f16 B=150, M=256, H=16, K=64   |     1651.4   |    1826.8  |    4561.7
      f32 B=150, M=256, H=16, K=64   |     4488.8   |    4477.1  |   12926.9
      f16 B=64, M=256, H=12, K=64    |      568.4   |     632.6  |    1500.9
      f32 B=64, M=256, H=12, K=64    |     1539.8   |    1534.2  |    4260.9
      f16 B=1, M=4096, H=16, K=40    |    11162.6   |   12100.5  |    4237.0
      f32 B=1, M=4096, H=16, K=40    |    35687.6   |   35281.9  |   17692.4
      f16 B=1, M=16384, H=16, K=40   |   198363.5   |  221542.6  |          
      f32 B=1, M=16384, H=16, K=40   |   597947.5   |  592061.1  |          
      f16 B=256, M=4096, H=16, K=64  |   389118.2   |  424073.1  |          
      f16 B=16, M=128, H=16, K=16    |      202.8   |     183.9  |     289.0
      f32 B=16, M=128, H=16, K=16    |      245.7   |     227.3  |     373.5
      f16 B=16, M=128, H=16, K=32    |      204.2   |     182.8  |     286.4
      f32 B=16, M=128, H=16, K=32    |      243.5   |     227.3  |     415.3
      f16 B=16, M=128, H=16, K=64    |      202.3   |     184.3  |     287.4
      f32 B=16, M=128, H=16, K=64    |      241.6   |     231.1  |     502.4
      f16 B=16, M=128, H=16, K=128   |      200.4   |     210.2  |     301.0
      f32 B=16, M=128, H=16, K=128   |      509.8   |     489.2  |     679.4
      f16 B=16, M=128, H=16, K=256   |      790.9   |     777.2  |     555.4
      f32 B=16, M=128, H=16, K=256   |      975.4   |     939.9  |    1163.0
      f16 B=16, M=512, H=16, K=16    |      360.4   |     413.9  |    1199.9
      f32 B=16, M=512, H=16, K=16    |     1261.1   |    1242.3  |    4408.4
      f16 B=16, M=512, H=16, K=32    |      424.1   |     484.5  |    1305.6
      f32 B=16, M=512, H=16, K=32    |     1412.7   |    1400.3  |    4633.7
      f16 B=16, M=512, H=16, K=64    |      577.0   |     641.4  |    1544.3
      f32 B=16, M=512, H=16, K=64    |     1850.4   |    1833.5  |    5117.6
      f16 B=16, M=512, H=16, K=128   |     1286.1   |    1375.8  |    1986.1
      f32 B=16, M=512, H=16, K=128   |     4045.2   |    3852.5  |    6086.9
      f16 B=16, M=512, H=16, K=256   |     5719.9   |    5757.7  |    2903.4
      f32 B=16, M=512, H=16, K=256   |     7844.9   |    7501.2  |   10619.7
      f16 B=16, M=1024, H=16, K=16   |     1317.8   |    1522.2  |    4256.2
      f32 B=16, M=1024, H=16, K=16   |     4591.7   |    4583.4  |   16612.4
      f16 B=16, M=1024, H=16, K=32   |     1486.2   |    1702.7  |    4478.4
      f32 B=16, M=1024, H=16, K=32   |     4971.2   |    4968.6  |   17261.8
      f16 B=16, M=1024, H=16, K=64   |     1914.7   |    2123.9  |    4987.3
      f32 B=16, M=1024, H=16, K=64   |     6380.9   |    6376.7  |   18674.5
      f16 B=16, M=1024, H=16, K=128  |     4028.0   |    4296.1  |    5947.0
      f32 B=16, M=1024, H=16, K=128  |    13653.7   |   12937.8  |   21481.2
      f16 B=16, M=1024, H=16, K=256  |    18675.2   |   19016.0  |    7896.4
      f32 B=16, M=1024, H=16, K=256  |    26325.6   |   25151.9  |   37929.3
      f16 B=64, M=128, H=16, K=16    |      200.7   |     184.6  |     440.2
      f32 B=64, M=128, H=16, K=16    |      405.3   |     402.4  |    1270.6
      f16 B=64, M=128, H=16, K=32    |      228.4   |     204.3  |     545.0
      f32 B=64, M=128, H=16, K=32    |      512.3   |     508.2  |    1427.3
      f16 B=64, M=128, H=16, K=64    |      288.0   |     312.9  |     773.7
      f32 B=64, M=128, H=16, K=64    |      741.6   |     737.1  |    1743.0
      f16 B=64, M=128, H=16, K=128   |      703.0   |     723.6  |    1226.1
      f32 B=64, M=128, H=16, K=128   |     1774.8   |    1703.0  |    2383.4
      f16 B=64, M=128, H=16, K=256   |     2854.0   |    2888.6  |    2129.0
      f32 B=64, M=128, H=16, K=256   |     3410.0   |    3294.6  |    4314.2
      f16 B=64, M=512, H=16, K=16    |     1315.1   |    1522.1  |    4483.1
      f32 B=64, M=512, H=16, K=16    |     3871.6   |    3864.4  |   16965.0
      f16 B=64, M=512, H=16, K=32    |     1609.5   |    1810.4  |    4972.8
      f32 B=64, M=512, H=16, K=32    |     4508.8   |    4501.9  |   17822.1
      f16 B=64, M=512, H=16, K=64    |     2225.9   |    2484.9  |    5891.3
      f32 B=64, M=512, H=16, K=64    |     5978.6   |    5975.5  |   19736.5
      f16 B=64, M=512, H=16, K=128   |     4688.2   |    4853.1  |    7704.5
      f32 B=64, M=512, H=16, K=128   |    14127.6   |   13458.0  |   23594.8
      f16 B=64, M=512, H=16, K=256   |    21160.4   |   21087.3  |   11491.5
      f32 B=64, M=512, H=16, K=256   |    27188.5   |   25985.1  |   42300.0
      f16 B=64, M=1024, H=16, K=16   |     4880.6   |    5585.3  |   16841.1
      f32 B=64, M=1024, H=16, K=16   |    14349.5   |   14389.0  |   66224.9
      f16 B=64, M=1024, H=16, K=32   |     5786.9   |    6465.5  |   17853.1
      f32 B=64, M=1024, H=16, K=32   |    15835.9   |   15835.2  |   68841.8
      f16 B=64, M=1024, H=16, K=64   |     7456.0   |    8286.4  |   19909.4
      f32 B=64, M=1024, H=16, K=64   |    20260.5   |   20341.4  |   74454.6
      f16 B=64, M=1024, H=16, K=128  |    14640.3   |   15119.5  |   23731.7
      f32 B=64, M=1024, H=16, K=128  |    47175.5   |   44702.4  |   85699.6
      f16 B=64, M=1024, H=16, K=256  |    68925.8   |   69018.2  |   32542.6
      f32 B=64, M=1024, H=16, K=256  |    91071.8   |   87053.4  |  152328.7

Times are in microseconds (us).
P100/V100 fw
[------------------- attention (attn_bias=<class 'NoneType'>) -------------------]
                                                         |     main    |   eager  
1 threads: -----------------------------------------------------------------------
  (Quadro_GP100)          f16 B=384, M=197, H=1, K=88    |     1510.0  |    1397.8
                          f32 B=384, M=197, H=1, K=88    |     1434.3  |    1509.2
                          f16 B=384, M=197, H=1, K=80    |     1457.3  |    1347.1
                          f32 B=384, M=197, H=1, K=80    |     1388.8  |    1459.8
                          f16 B=384, M=197, H=1, K=64    |     1079.7  |    1245.4
                          f32 B=384, M=197, H=1, K=64    |      988.8  |    1344.6
                          f16 B=1024, M=197, H=1, K=88   |     4065.9  |    3705.0
                          f32 B=1024, M=197, H=1, K=88   |     3873.3  |    3997.2
                          f16 B=1024, M=197, H=1, K=80   |     3904.2  |    3603.3
                          f32 B=1024, M=197, H=1, K=80   |     3784.7  |    3880.4
                          f16 B=1024, M=197, H=1, K=64   |     2880.6  |    3349.5
                          f32 B=1024, M=197, H=1, K=64   |     2634.3  |    3593.6
                          f16 B=512, M=197, H=1, K=80    |     1984.0  |    1846.5
                          f32 B=512, M=197, H=1, K=80    |     1924.1  |    1975.0
                          f16 B=32, M=197, H=16, K=80    |     2028.9  |    2209.5
                          f32 B=32, M=197, H=16, K=80    |     1958.0  |    2388.4
                          f16 B=32, M=197, H=16, K=64    |     1477.0  |    1990.0
                          f32 B=32, M=197, H=16, K=64    |     1350.7  |    2163.4
                          f16 B=32, M=197, H=16, K=128   |     2530.0  |    2878.7
                          f32 B=32, M=197, H=16, K=128   |     2429.7  |    3099.2
                          f16 B=256, M=197, H=1, K=88    |     1070.5  |     990.1
                          f32 B=256, M=197, H=1, K=88    |     1024.0  |    1053.3
                          f16 B=16, M=197, H=16, K=88    |     1072.3  |    1194.3
                          f32 B=16, M=197, H=16, K=88    |     1029.1  |    1282.4
                          f16 B=16, M=197, H=16, K=64    |      764.4  |    1028.8
                          f32 B=16, M=197, H=16, K=64    |      696.7  |    1108.2
                          f16 B=16, M=197, H=16, K=128   |     1293.0  |    1470.2
                          f32 B=16, M=197, H=16, K=128   |     1230.2  |    1592.4
                          f16 B=1, M=4096, H=160, K=128  |   252185.1  |  201705.3
                          f32 B=1, M=4096, H=160, K=128  |   241164.8  |          
                          f16 B=2, M=4096, H=160, K=128  |   500526.4  |          
                          f32 B=2, M=4096, H=160, K=128  |   485637.0  |          
                          f16 B=1, M=8192, H=160, K=128  |  1015719.4  |          
                          f32 B=1, M=8192, H=160, K=128  |   996319.1  |          
                          f16 B=2, M=8192, H=160, K=128  |  2037663.8  |          
                          f32 B=2, M=8192, H=160, K=128  |  1997234.2  |          
                          f16 B=1024, M=82, H=8, K=64    |     5752.5  |    8562.6
                          f32 B=1024, M=82, H=8, K=64    |     5329.1  |    9109.9
                          f16 B=150, M=256, H=16, K=64   |     7729.5  |   11332.5
                          f32 B=150, M=256, H=16, K=64   |     7054.9  |   12679.5
                          f16 B=64, M=256, H=12, K=64    |     2511.9  |    3674.3
                          f32 B=64, M=256, H=12, K=64    |     2308.0  |    4095.5
                          f16 B=1, M=4096, H=16, K=40    |    11231.9  |   14588.4
                          f32 B=1, M=4096, H=16, K=40    |     9929.1  |   17735.1
                          f16 B=1, M=16384, H=16, K=40   |   170210.6  |          
                          f32 B=1, M=16384, H=16, K=40   |   155516.2  |          
                          f16 B=256, M=4096, H=16, K=64  |  3252023.7  |          
                          f16 B=16, M=128, H=16, K=16    |      150.9  |     265.4
                          f32 B=16, M=128, H=16, K=16    |      141.5  |     300.4
                          f16 B=16, M=128, H=16, K=32    |      179.2  |     311.5
                          f32 B=16, M=128, H=16, K=32    |      166.3  |     357.9
                          f16 B=16, M=128, H=16, K=64    |      231.8  |     414.1
                          f32 B=16, M=128, H=16, K=64    |      227.0  |     462.4
                          f16 B=16, M=128, H=16, K=128   |      437.0  |     599.3
                          f32 B=16, M=128, H=16, K=128   |      452.5  |     685.6
                          f16 B=16, M=128, H=16, K=256   |      835.5  |    1116.1
                          f32 B=16, M=128, H=16, K=256   |      899.6  |    1358.8
                          f16 B=16, M=512, H=16, K=16    |     2150.5  |    3183.3
                          f32 B=16, M=512, H=16, K=16    |     1960.3  |    3646.8
                          f16 B=16, M=512, H=16, K=32    |     2548.4  |    3576.9
                          f32 B=16, M=512, H=16, K=32    |     2256.6  |    4017.2
                          f16 B=16, M=512, H=16, K=64    |     3323.5  |    4343.4
                          f32 B=16, M=512, H=16, K=64    |     2990.8  |    4862.4
                          f16 B=16, M=512, H=16, K=128   |     6408.5  |    5901.5
                          f32 B=16, M=512, H=16, K=128   |     6033.5  |    6559.6
                          f16 B=16, M=512, H=16, K=256   |    12975.4  |   10746.3
                          f32 B=16, M=512, H=16, K=256   |    12683.6  |   11748.1
                          f16 B=16, M=1024, H=16, K=16   |     8366.1  |   12316.8
                          f32 B=16, M=1024, H=16, K=16   |     7490.7  |   13839.4
                          f16 B=16, M=1024, H=16, K=32   |     9894.1  |   13713.1
                          f32 B=16, M=1024, H=16, K=32   |     8854.9  |   15169.7
                          f16 B=16, M=1024, H=16, K=64   |    12859.7  |   16233.0
                          f32 B=16, M=1024, H=16, K=64   |    11548.4  |   18025.9
                          f16 B=16, M=1024, H=16, K=128  |    25314.9  |   21504.5
                          f32 B=16, M=1024, H=16, K=128  |    23507.9  |   23364.1
                          f16 B=16, M=1024, H=16, K=256  |    51510.7  |   38749.2
                          f32 B=16, M=1024, H=16, K=256  |    50812.9  |   41896.2
                          f16 B=64, M=128, H=16, K=16    |      565.0  |     961.0
                          f32 B=64, M=128, H=16, K=16    |      524.3  |    1091.2
                          f16 B=64, M=128, H=16, K=32    |      665.8  |    1138.3
                          f32 B=64, M=128, H=16, K=32    |      613.0  |    1300.2
                          f16 B=64, M=128, H=16, K=64    |      866.5  |    1524.1
                          f32 B=64, M=128, H=16, K=64    |      825.7  |    1723.0
                          f16 B=64, M=128, H=16, K=128   |     1682.9  |    2270.2
                          f32 B=64, M=128, H=16, K=128   |     1707.2  |    2591.8
                          f16 B=64, M=128, H=16, K=256   |     3262.1  |    4236.3
                          f32 B=64, M=128, H=16, K=256   |     3443.2  |    5390.7
                          f16 B=64, M=512, H=16, K=16    |     8419.1  |   12452.5
                          f32 B=64, M=512, H=16, K=16    |     7496.2  |   14249.4
                          f16 B=64, M=512, H=16, K=32    |     9984.7  |   14043.7
                          f32 B=64, M=512, H=16, K=32    |     8902.3  |   15898.0
                          f16 B=64, M=512, H=16, K=64    |    13064.5  |   17211.0
                          f32 B=64, M=512, H=16, K=64    |    11695.4  |   19154.0
                          f16 B=64, M=512, H=16, K=128   |    25396.8  |   23241.0
                          f32 B=64, M=512, H=16, K=128   |    24079.8  |   25745.1
                          f16 B=64, M=512, H=16, K=256   |    51471.0  |   43083.7
                          f32 B=64, M=512, H=16, K=256   |    50596.1  |   47045.4
                          f16 B=64, M=1024, H=16, K=16   |    32927.4  |   49295.4
                          f32 B=64, M=1024, H=16, K=16   |    29545.4  |   55489.4
                          f16 B=64, M=1024, H=16, K=32   |    39128.1  |   54820.3
                          f32 B=64, M=1024, H=16, K=32   |    34791.5  |   59901.5
                          f16 B=64, M=1024, H=16, K=64   |    51089.9  |   65794.8
                          f32 B=64, M=1024, H=16, K=64   |    45600.0  |   72379.6
                          f16 B=64, M=1024, H=16, K=128  |   100577.7  |   85978.0
                          f32 B=64, M=1024, H=16, K=128  |    94634.7  |   94077.1
                          f16 B=64, M=1024, H=16, K=256  |   205416.2  |  156833.3
                          f32 B=64, M=1024, H=16, K=256  |   204171.7  |          
  (Tesla_V100_SXM2_16GB)  f16 B=384, M=197, H=1, K=88    |      280.5  |     550.9
                          f32 B=384, M=197, H=1, K=88    |      780.4  |     888.6
                          f16 B=384, M=197, H=1, K=80    |      270.8  |     526.0
                          f32 B=384, M=197, H=1, K=80    |      755.2  |     857.3
                          f16 B=384, M=197, H=1, K=64    |      191.4  |     418.9
                          f32 B=384, M=197, H=1, K=64    |      567.8  |     714.1
                          f16 B=1024, M=197, H=1, K=88   |      742.6  |    1425.3
                          f32 B=1024, M=197, H=1, K=88   |     2075.7  |    2352.3
                          f16 B=1024, M=197, H=1, K=80   |      711.3  |    1360.5
                          f32 B=1024, M=197, H=1, K=80   |     1994.1  |    2264.5
                          f16 B=1024, M=197, H=1, K=64   |      498.9  |    1075.2
                          f32 B=1024, M=197, H=1, K=64   |     1493.6  |    1866.1
                          f16 B=512, M=197, H=1, K=80    |      359.6  |     690.8
                          f32 B=512, M=197, H=1, K=80    |     1011.7  |    1147.9
                          f16 B=32, M=197, H=16, K=80    |      374.6  |     841.7
                          f32 B=32, M=197, H=16, K=80    |     1031.9  |    1383.7
                          f16 B=32, M=197, H=16, K=64    |      255.3  |     675.1
                          f32 B=32, M=197, H=16, K=64    |      763.7  |    1133.3
                          f16 B=32, M=197, H=16, K=128   |      425.2  |    1030.2
                          f32 B=32, M=197, H=16, K=128   |     1314.9  |    1902.7
                          f16 B=256, M=197, H=1, K=88    |      193.7  |     378.2
                          f32 B=256, M=197, H=1, K=88    |      540.4  |     612.0
                          f16 B=16, M=197, H=16, K=88    |      196.0  |     461.6
                          f32 B=16, M=197, H=16, K=88    |      542.0  |     746.9
                          f16 B=16, M=197, H=16, K=64    |      171.2  |     429.8
                          f32 B=16, M=197, H=16, K=64    |      397.3  |     609.0
                          f16 B=16, M=197, H=16, K=128   |      218.1  |     536.7
                          f32 B=16, M=197, H=16, K=128   |      673.3  |     966.1
                          f16 B=1, M=4096, H=160, K=128  |    35200.8  |   44509.1
                          f32 B=1, M=4096, H=160, K=128  |   139848.7  |          
                          f16 B=2, M=4096, H=160, K=128  |    70836.1  |          
                          f32 B=2, M=4096, H=160, K=128  |   279696.9  |          
                          f16 B=1, M=8192, H=160, K=128  |   144262.0  |          
                          f32 B=1, M=8192, H=160, K=128  |   578509.8  |          
                          f16 B=2, M=8192, H=160, K=128  |   289830.3  |          
                          f32 B=2, M=8192, H=160, K=128  |  1163174.7  |          
                          f16 B=1024, M=82, H=8, K=64    |     1095.4  |    2496.5
                          f32 B=1024, M=82, H=8, K=64    |     2915.9  |    4424.4
                          f16 B=150, M=256, H=16, K=64   |     1295.8  |    3131.0
                          f32 B=150, M=256, H=16, K=64   |     4037.3  |    6739.0
                          f16 B=64, M=256, H=12, K=64    |      416.9  |    1025.5
                          f32 B=64, M=256, H=12, K=64    |     1304.1  |    2254.9
                          f16 B=1, M=4096, H=16, K=40    |     1996.3  |    4067.3
                          f32 B=1, M=4096, H=16, K=40    |     5851.3  |    8249.3
                          f16 B=1, M=16384, H=16, K=40   |    29646.1  |          
                          f32 B=1, M=16384, H=16, K=40   |    86661.6  |          
                          f16 B=256, M=4096, H=16, K=64  |   462010.1  |          
                          f16 B=16, M=128, H=16, K=16    |      148.0  |     356.6
                          f32 B=16, M=128, H=16, K=16    |      144.8  |     357.5
                          f16 B=16, M=128, H=16, K=32    |      147.6  |     345.4
                          f32 B=16, M=128, H=16, K=32    |      149.3  |     335.8
                          f16 B=16, M=128, H=16, K=64    |      143.7  |     346.2
                          f32 B=16, M=128, H=16, K=64    |      151.9  |     350.6
                          f16 B=16, M=128, H=16, K=128   |      144.9  |     349.5
                          f32 B=16, M=128, H=16, K=128   |      241.8  |     462.1
                          f16 B=16, M=128, H=16, K=256   |      167.3  |     367.1
                          f32 B=16, M=128, H=16, K=256   |      477.2  |     862.7
                          f16 B=16, M=512, H=16, K=16    |      399.8  |     845.5
                          f32 B=16, M=512, H=16, K=16    |     1088.7  |    1830.8
                          f16 B=16, M=512, H=16, K=32    |      418.4  |     937.4
                          f32 B=16, M=512, H=16, K=32    |     1264.9  |    2095.9
                          f16 B=16, M=512, H=16, K=64    |      511.8  |    1143.3
                          f32 B=16, M=512, H=16, K=64    |     1681.4  |    2501.4
                          f16 B=16, M=512, H=16, K=128   |      965.5  |    1458.1
                          f32 B=16, M=512, H=16, K=128   |     3330.1  |    4066.9
                          f16 B=16, M=512, H=16, K=256   |     2761.3  |    2316.0
                          f32 B=16, M=512, H=16, K=256   |     7204.6  |    7218.3
                          f16 B=16, M=1024, H=16, K=16   |     1537.0  |    3433.3
                          f32 B=16, M=1024, H=16, K=16   |     4195.9  |    7165.5
                          f16 B=16, M=1024, H=16, K=32   |     1589.4  |    3673.3
                          f32 B=16, M=1024, H=16, K=32   |     4958.9  |    7966.2
                          f16 B=16, M=1024, H=16, K=64   |     1945.7  |    4154.1
                          f32 B=16, M=1024, H=16, K=64   |     6528.0  |    9394.8
                          f16 B=16, M=1024, H=16, K=128  |     3621.1  |    4855.3
                          f32 B=16, M=1024, H=16, K=128  |    13031.0  |   15398.9
                          f16 B=16, M=1024, H=16, K=256  |    11071.2  |    7464.6
                          f32 B=16, M=1024, H=16, K=256  |    28171.8  |   26984.5
                          f16 B=64, M=128, H=16, K=16    |      139.2  |     364.1
                          f32 B=64, M=128, H=16, K=16    |      288.3  |     517.2
                          f16 B=64, M=128, H=16, K=32    |      172.0  |     355.8
                          f32 B=64, M=128, H=16, K=32    |      339.6  |     657.9
                          f16 B=64, M=128, H=16, K=64    |      170.1  |     492.8
                          f32 B=64, M=128, H=16, K=64    |      465.8  |     935.6
                          f16 B=64, M=128, H=16, K=128   |      323.9  |     765.4
                          f32 B=64, M=128, H=16, K=128   |      915.7  |    1600.0
                          f16 B=64, M=128, H=16, K=256   |      624.2  |    1326.2
                          f32 B=64, M=128, H=16, K=256   |     1825.8  |    2868.1
                          f16 B=64, M=512, H=16, K=16    |     1576.0  |    3237.9
                          f32 B=64, M=512, H=16, K=16    |     4193.3  |    7437.9
                          f16 B=64, M=512, H=16, K=32    |     1649.4  |    3603.9
                          f32 B=64, M=512, H=16, K=32    |     4969.9  |    8514.8
                          f16 B=64, M=512, H=16, K=64    |     2015.4  |    4469.6
                          f32 B=64, M=512, H=16, K=64    |     6611.6  |   10198.4
                          f16 B=64, M=512, H=16, K=128   |     3794.5  |    5712.6
                          f32 B=64, M=512, H=16, K=128   |    13156.5  |   16556.8
                          f16 B=64, M=512, H=16, K=256   |    10970.2  |    9144.8
                          f32 B=64, M=512, H=16, K=256   |    28069.9  |   30192.7
                          f16 B=64, M=1024, H=16, K=16   |     6039.1  |   14015.6
                          f32 B=64, M=1024, H=16, K=16   |    16379.4  |   28076.7
                          f16 B=64, M=1024, H=16, K=32   |     6222.1  |   14556.7
                          f32 B=64, M=1024, H=16, K=32   |    19341.7  |   30747.7
                          f16 B=64, M=1024, H=16, K=64   |     7476.1  |   17005.3
                          f32 B=64, M=1024, H=16, K=64   |    25863.7  |   38337.4
                          f16 B=64, M=1024, H=16, K=128  |    14205.3  |   19211.8
                          f32 B=64, M=1024, H=16, K=128  |    51596.0  |   60226.9
                          f16 B=64, M=1024, H=16, K=256  |    43848.4  |   29946.7
                          f32 B=64, M=1024, H=16, K=256  |   111474.6  |          

Times are in microseconds (us).

[- attention (attn_bias=<class 'xformers.ops.fmha.common.LowerTriangularMask'>) -]
                                                         |     main    |   eager  
1 threads: -----------------------------------------------------------------------
  (Quadro_GP100)          f16 B=384, M=197, H=1, K=88    |     1063.0  |    1641.4
                          f32 B=384, M=197, H=1, K=88    |     1040.8  |    1722.3
                          f16 B=384, M=197, H=1, K=80    |     1020.5  |    1588.0
                          f32 B=384, M=197, H=1, K=80    |     1005.1  |    1675.2
                          f16 B=384, M=197, H=1, K=64    |      758.3  |    1493.9
                          f32 B=384, M=197, H=1, K=64    |      683.5  |    1570.4
                          f16 B=1024, M=197, H=1, K=88   |     2817.5  |    4363.1
                          f32 B=1024, M=197, H=1, K=88   |     2779.7  |    4576.2
                          f16 B=1024, M=197, H=1, K=80   |     2722.3  |    4233.4
                          f32 B=1024, M=197, H=1, K=80   |     2694.0  |    4471.1
                          f16 B=1024, M=197, H=1, K=64   |     1997.4  |    4000.6
                          f32 B=1024, M=197, H=1, K=64   |     1808.7  |    4174.4
                          f16 B=512, M=197, H=1, K=80    |     1386.9  |    2166.3
                          f32 B=512, M=197, H=1, K=80    |     1371.1  |    2283.7
                          f16 B=32, M=197, H=16, K=80    |     1395.5  |    2543.1
                          f32 B=32, M=197, H=16, K=80    |     1383.4  |    2674.6
                          f16 B=32, M=197, H=16, K=64    |     1030.3  |    2327.5
                          f32 B=32, M=197, H=16, K=64    |      933.1  |    2465.0
                          f16 B=32, M=197, H=16, K=128   |     1760.5  |    3212.4
                          f32 B=32, M=197, H=16, K=128   |     1745.0  |    3387.1
                          f16 B=256, M=197, H=1, K=88    |      753.8  |    1151.7
                          f32 B=256, M=197, H=1, K=88    |      729.1  |    1206.3
                          f16 B=16, M=197, H=16, K=88    |      751.0  |    1361.1
                          f32 B=16, M=197, H=16, K=88    |      738.2  |    1446.1
                          f16 B=16, M=197, H=16, K=64    |      538.5  |    1198.4
                          f32 B=16, M=197, H=16, K=64    |      490.5  |    1269.8
                          f16 B=16, M=197, H=16, K=128   |      915.2  |    1650.7
                          f32 B=16, M=197, H=16, K=128   |      890.3  |    1730.6
                          f16 B=1, M=4096, H=160, K=128  |   128716.3  |          
                          f32 B=1, M=4096, H=160, K=128  |   123542.7  |          
                          f16 B=2, M=4096, H=160, K=128  |   256642.2  |          
                          f32 B=2, M=4096, H=160, K=128  |   245869.1  |          
                          f16 B=1, M=8192, H=160, K=128  |   507149.2  |          
                          f32 B=1, M=8192, H=160, K=128  |   501763.7  |          
                          f16 B=2, M=8192, H=160, K=128  |  1012807.4  |          
                          f32 B=2, M=8192, H=160, K=128  |  1008124.8  |          
                          f16 B=1024, M=82, H=8, K=64    |     4756.3  |    9552.2
                          f32 B=1024, M=82, H=8, K=64    |     4346.6  |   10218.6
                          f16 B=150, M=256, H=16, K=64   |     5001.0  |   13734.1
                          f32 B=150, M=256, H=16, K=64   |     4644.7  |   14799.1
                          f16 B=64, M=256, H=12, K=64    |     1640.0  |    4390.8
                          f32 B=64, M=256, H=12, K=64    |     1532.0  |    4743.6
                          f16 B=1, M=4096, H=16, K=40    |     5921.3  |   18371.5
                          f32 B=1, M=4096, H=16, K=40    |     5263.0  |   22564.1
                          f16 B=1, M=16384, H=16, K=40   |    88312.9  |          
                          f32 B=1, M=16384, H=16, K=40   |    81014.7  |          
                          f16 B=256, M=4096, H=16, K=64  |  1657734.4  |          
                          f16 B=16, M=128, H=16, K=16    |      126.8  |     339.5
                          f32 B=16, M=128, H=16, K=16    |      120.6  |     373.7
                          f16 B=16, M=128, H=16, K=32    |      147.7  |     381.4
                          f32 B=16, M=128, H=16, K=32    |      136.6  |     426.2
                          f16 B=16, M=128, H=16, K=64    |      189.5  |     477.6
                          f32 B=16, M=128, H=16, K=64    |      184.7  |     530.4
                          f16 B=16, M=128, H=16, K=128   |      367.7  |     667.8
                          f32 B=16, M=128, H=16, K=128   |      374.2  |     749.5
                          f16 B=16, M=128, H=16, K=256   |      701.9  |    1173.9
                          f32 B=16, M=128, H=16, K=256   |      714.6  |    1419.3
                          f16 B=16, M=512, H=16, K=16    |     1272.5  |    4180.7
                          f32 B=16, M=512, H=16, K=16    |     1148.0  |    4759.7
                          f16 B=16, M=512, H=16, K=32    |     1502.7  |    4573.0
                          f32 B=16, M=512, H=16, K=32    |     1350.9  |    5028.5
                          f16 B=16, M=512, H=16, K=64    |     1925.4  |    5352.1
                          f32 B=16, M=512, H=16, K=64    |     1770.0  |    5705.0
                          f16 B=16, M=512, H=16, K=128   |     3844.9  |    6844.9
                          f32 B=16, M=512, H=16, K=128   |     3692.4  |    7410.2
                          f16 B=16, M=512, H=16, K=256   |     7677.2  |   11763.5
                          f32 B=16, M=512, H=16, K=256   |     7520.1  |   12469.9
                          f16 B=16, M=1024, H=16, K=16   |     4571.1  |   16594.1
                          f32 B=16, M=1024, H=16, K=16   |     4116.6  |   19868.4
                          f16 B=16, M=1024, H=16, K=32   |     5390.4  |   17848.7
                          f32 B=16, M=1024, H=16, K=32   |     4856.3  |   20576.0
                          f16 B=16, M=1024, H=16, K=64   |     7004.4  |   20169.9
                          f32 B=16, M=1024, H=16, K=64   |     6332.3  |   22698.9
                          f16 B=16, M=1024, H=16, K=128  |    13939.1  |   25708.1
                          f32 B=16, M=1024, H=16, K=128  |    13245.0  |   28058.6
                          f16 B=16, M=1024, H=16, K=256  |    28277.4  |   43267.2
                          f32 B=16, M=1024, H=16, K=256  |    27680.0  |   46053.4
                          f16 B=64, M=128, H=16, K=16    |      454.9  |    1231.9
                          f32 B=64, M=128, H=16, K=16    |      418.2  |    1369.6
                          f16 B=64, M=128, H=16, K=32    |      531.9  |    1408.3
                          f32 B=64, M=128, H=16, K=32    |      494.8  |    1570.4
                          f16 B=64, M=128, H=16, K=64    |      689.3  |    1777.4
                          f32 B=64, M=128, H=16, K=64    |      649.5  |    1968.9
                          f16 B=64, M=128, H=16, K=128   |     1417.5  |    2529.1
                          f32 B=64, M=128, H=16, K=128   |     1437.2  |    2821.2
                          f16 B=64, M=128, H=16, K=256   |     2680.2  |    4523.1
                          f32 B=64, M=128, H=16, K=256   |     2752.2  |    5524.1
                          f16 B=64, M=512, H=16, K=16    |     4880.6  |   16602.6
                          f32 B=64, M=512, H=16, K=16    |     4416.0  |   18778.5
                          f16 B=64, M=512, H=16, K=32    |     5769.7  |   18066.7
                          f32 B=64, M=512, H=16, K=32    |     5192.4  |   19832.4
                          f16 B=64, M=512, H=16, K=64    |     7508.1  |   21208.7
                          f32 B=64, M=512, H=16, K=64    |     6824.1  |   22507.2
                          f16 B=64, M=512, H=16, K=128   |    15071.8  |   27390.6
                          f32 B=64, M=512, H=16, K=128   |    14547.0  |   29146.1
                          f16 B=64, M=512, H=16, K=256   |    30054.6  |   46986.3
                          f32 B=64, M=512, H=16, K=256   |    29807.5  |   50467.8
                          f16 B=64, M=1024, H=16, K=16   |    17881.9  |   67019.4
                          f32 B=64, M=1024, H=16, K=16   |    16107.3  |   79155.4
                          f16 B=64, M=1024, H=16, K=32   |    21207.7  |   72235.5
                          f32 B=64, M=1024, H=16, K=32   |    19099.1  |   82273.1
                          f16 B=64, M=1024, H=16, K=64   |    27819.1  |   83763.2
                          f32 B=64, M=1024, H=16, K=64   |    24771.3  |   91611.2
                          f16 B=64, M=1024, H=16, K=128  |    55023.8  |  104599.3
                          f32 B=64, M=1024, H=16, K=128  |    52263.7  |          
                          f16 B=64, M=1024, H=16, K=256  |   111388.9  |  175503.2
                          f32 B=64, M=1024, H=16, K=256  |   109982.0  |          
  (Tesla_V100_SXM2_16GB)  f16 B=384, M=197, H=1, K=88    |      210.6  |     643.6
                          f32 B=384, M=197, H=1, K=88    |      545.5  |    1029.5
                          f16 B=384, M=197, H=1, K=80    |      202.4  |     623.9
                          f32 B=384, M=197, H=1, K=80    |      526.7  |     999.5
                          f16 B=384, M=197, H=1, K=64    |      144.6  |     515.1
                          f32 B=384, M=197, H=1, K=64    |      398.2  |     857.8
                          f16 B=1024, M=197, H=1, K=88   |      552.3  |    1661.8
                          f32 B=1024, M=197, H=1, K=88   |     1438.1  |    2720.2
                          f16 B=1024, M=197, H=1, K=80   |      530.0  |    1605.3
                          f32 B=1024, M=197, H=1, K=80   |     1378.9  |    2626.9
                          f16 B=1024, M=197, H=1, K=64   |      368.1  |    1316.1
                          f32 B=1024, M=197, H=1, K=64   |     1024.9  |    2240.6
                          f16 B=512, M=197, H=1, K=80    |      270.4  |     823.3
                          f32 B=512, M=197, H=1, K=80    |      704.4  |    1332.7
                          f16 B=32, M=197, H=16, K=80    |      278.4  |     972.7
                          f32 B=32, M=197, H=16, K=80    |      720.9  |    1568.6
                          f16 B=32, M=197, H=16, K=64    |      192.9  |     801.1
                          f32 B=32, M=197, H=16, K=64    |      529.1  |    1325.7
                          f16 B=32, M=197, H=16, K=128   |      318.0  |    1146.5
                          f32 B=32, M=197, H=16, K=128   |      917.5  |    2045.7
                          f16 B=256, M=197, H=1, K=88    |      146.0  |     442.7
                          f32 B=256, M=197, H=1, K=88    |      381.0  |     706.5
                          f16 B=16, M=197, H=16, K=88    |      146.8  |     526.1
                          f32 B=16, M=197, H=16, K=88    |      383.3  |     839.9
                          f16 B=16, M=197, H=16, K=64    |      151.9  |     439.0
                          f32 B=16, M=197, H=16, K=64    |      280.0  |     707.9
                          f16 B=16, M=197, H=16, K=128   |      164.5  |     598.4
                          f32 B=16, M=197, H=16, K=128   |      472.0  |    1044.9
                          f16 B=1, M=4096, H=160, K=128  |    18344.5  |          
                          f32 B=1, M=4096, H=160, K=128  |    71344.4  |          
                          f16 B=2, M=4096, H=160, K=128  |    36563.5  |          
                          f32 B=2, M=4096, H=160, K=128  |   141996.6  |          
                          f16 B=1, M=8192, H=160, K=128  |    72900.7  |          
                          f32 B=1, M=8192, H=160, K=128  |   284758.3  |          
                          f16 B=2, M=8192, H=160, K=128  |   145710.6  |          
                          f32 B=2, M=8192, H=160, K=128  |   568170.0  |          
                          f16 B=1024, M=82, H=8, K=64    |      961.6  |    2798.5
                          f32 B=1024, M=82, H=8, K=64    |     2373.1  |    4951.5
                          f16 B=150, M=256, H=16, K=64   |      926.2  |    3821.3
                          f32 B=150, M=256, H=16, K=64   |     2624.9  |    7918.7
                          f16 B=64, M=256, H=12, K=64    |      303.3  |    1254.6
                          f32 B=64, M=256, H=12, K=64    |      858.8  |    2645.0
                          f16 B=1, M=4096, H=16, K=40    |     1072.2  |    5996.5
                          f32 B=1, M=4096, H=16, K=40    |     3053.7  |   11513.7
                          f16 B=1, M=16384, H=16, K=40   |    15403.6  |          
                          f32 B=1, M=16384, H=16, K=40   |    44929.7  |          
                          f16 B=256, M=4096, H=16, K=64  |   240596.0  |          
                          f16 B=16, M=128, H=16, K=16    |      151.9  |     351.0
                          f32 B=16, M=128, H=16, K=16    |      148.8  |     344.7
                          f16 B=16, M=128, H=16, K=32    |      177.0  |     394.4
                          f32 B=16, M=128, H=16, K=32    |      184.5  |     340.0
                          f16 B=16, M=128, H=16, K=64    |      145.4  |     351.3
                          f32 B=16, M=128, H=16, K=64    |      171.4  |     352.3
                          f16 B=16, M=128, H=16, K=128   |      171.6  |     345.0
                          f32 B=16, M=128, H=16, K=128   |      196.7  |     503.1
                          f16 B=16, M=128, H=16, K=256   |      156.8  |     388.2
                          f32 B=16, M=128, H=16, K=256   |      388.2  |     900.9
                          f16 B=16, M=512, H=16, K=16    |      261.1  |    1191.4
                          f32 B=16, M=512, H=16, K=16    |      642.7  |    2520.3
                          f16 B=16, M=512, H=16, K=32    |      268.5  |    1287.8
                          f32 B=16, M=512, H=16, K=32    |      748.9  |    2752.7
                          f16 B=16, M=512, H=16, K=64    |      327.2  |    1437.7
                          f32 B=16, M=512, H=16, K=64    |     1004.8  |    3000.1
                          f16 B=16, M=512, H=16, K=128   |      614.1  |    1757.2
                          f32 B=16, M=512, H=16, K=128   |     2001.5  |    4532.4
                          f16 B=16, M=512, H=16, K=256   |     1525.0  |    2612.2
                          f32 B=16, M=512, H=16, K=256   |     4174.8  |    7621.3
                          f16 B=16, M=1024, H=16, K=16   |      876.6  |    5043.8
                          f32 B=16, M=1024, H=16, K=16   |     2286.3  |   10962.5
                          f16 B=16, M=1024, H=16, K=32   |      908.2  |    5179.5
                          f32 B=16, M=1024, H=16, K=32   |     2674.6  |   11440.7
                          f16 B=16, M=1024, H=16, K=64   |     1100.6  |    5472.6
                          f32 B=16, M=1024, H=16, K=64   |     3600.6  |   12196.1
                          f16 B=16, M=1024, H=16, K=128  |     2062.6  |    6217.9
                          f32 B=16, M=1024, H=16, K=128  |     7193.1  |   17906.8
                          f16 B=16, M=1024, H=16, K=256  |     5683.6  |    8804.5
                          f32 B=16, M=1024, H=16, K=256  |    15334.5  |   29569.0
                          f16 B=64, M=128, H=16, K=16    |      148.6  |     381.3
                          f32 B=64, M=128, H=16, K=16    |      230.5  |     689.7
                          f16 B=64, M=128, H=16, K=32    |      144.0  |     451.7
                          f32 B=64, M=128, H=16, K=32    |      270.6  |     813.7
                          f16 B=64, M=128, H=16, K=64    |      153.4  |     587.2
                          f32 B=64, M=128, H=16, K=64    |      370.7  |    1085.4
                          f16 B=64, M=128, H=16, K=128   |      284.4  |     862.4
                          f32 B=64, M=128, H=16, K=128   |      740.5  |    1720.8
                          f16 B=64, M=128, H=16, K=256   |      529.9  |    1408.8
                          f32 B=64, M=128, H=16, K=256   |     1460.1  |    2982.0
                          f16 B=64, M=512, H=16, K=16    |      987.3  |    4603.3
                          f32 B=64, M=512, H=16, K=16    |     2408.6  |   10136.8
                          f16 B=64, M=512, H=16, K=32    |     1027.3  |    4933.2
                          f32 B=64, M=512, H=16, K=32    |     2862.2  |   11037.5
                          f16 B=64, M=512, H=16, K=64    |     1255.5  |    5622.9
                          f32 B=64, M=512, H=16, K=64    |     3861.0  |   12125.9
                          f16 B=64, M=512, H=16, K=128   |     2405.2  |    6901.4
                          f32 B=64, M=512, H=16, K=128   |     7843.0  |   18276.2
                          f16 B=64, M=512, H=16, K=256   |     5983.3  |   10359.0
                          f32 B=64, M=512, H=16, K=256   |    16371.9  |   31659.4
                          f16 B=64, M=1024, H=16, K=16   |     3408.7  |   20126.8
                          f32 B=64, M=1024, H=16, K=16   |     8819.8  |   43940.0
                          f16 B=64, M=1024, H=16, K=32   |     3531.5  |   20639.3
                          f32 B=64, M=1024, H=16, K=32   |    10498.0  |   46144.9
                          f16 B=64, M=1024, H=16, K=64   |     4277.0  |   22159.0
                          f32 B=64, M=1024, H=16, K=64   |    14147.1  |   50392.1
                          f16 B=64, M=1024, H=16, K=128  |     8174.7  |   24830.6
                          f32 B=64, M=1024, H=16, K=128  |    28411.4  |          
                          f16 B=64, M=1024, H=16, K=256  |    22481.4  |   35353.6
                          f32 B=64, M=1024, H=16, K=256  |    60469.3  |          

Times are in microseconds (us).
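
The `LowerTriangularMask` variants above pass a mask object instead of a dense bias tensor. A hedged sketch of that call path (import path taken verbatim from the table header and may differ by version):

```python
import torch
import xformers.ops as xops
from xformers.ops.fmha.common import LowerTriangularMask

q = torch.randn(16, 512, 16, 64, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)

# Passing the mask object rather than materializing an [M, M] tensor
# lets the kernel apply causal masking without loading a bias matrix
# from global memory.
out = xops.memory_efficient_attention(q, k, v, attn_bias=LowerTriangularMask())
```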
P100/V100 backward
[-------------- attention backward (attn_bias=<class 'NoneType'>) ---------------]
                                                         |     main    |  vanilla 
1 threads: -----------------------------------------------------------------------
  (Quadro_GP100)          f16 B=384, M=197, H=1, K=88    |     6662.4  |    3591.8
                          f32 B=384, M=197, H=1, K=88    |     9584.0  |    4337.0
                          f16 B=384, M=197, H=1, K=80    |     6192.6  |    3437.6
                          f32 B=384, M=197, H=1, K=80    |     9158.2  |    4107.5
                          f16 B=384, M=197, H=1, K=64    |     3518.0  |    2927.8
                          f32 B=384, M=197, H=1, K=64    |     6136.2  |    3451.8
                          f16 B=1024, M=197, H=1, K=88   |    16310.1  |    9852.8
                          f32 B=1024, M=197, H=1, K=88   |    25756.1  |   12151.9
                          f16 B=1024, M=197, H=1, K=80   |    15522.3  |    9330.4
                          f32 B=1024, M=197, H=1, K=80   |    24578.1  |   11356.0
                          f16 B=1024, M=197, H=1, K=64   |     8935.0  |    7719.5
                          f32 B=1024, M=197, H=1, K=64   |    16599.7  |    9475.8
                          f16 B=512, M=197, H=1, K=80    |     7927.7  |    4632.1
                          f32 B=512, M=197, H=1, K=80    |    12806.3  |    5525.4
                          f16 B=32, M=197, H=16, K=80    |     8129.2  |    4891.3
                          f32 B=32, M=197, H=16, K=80    |    12774.1  |    5811.9
                          f16 B=32, M=197, H=16, K=64    |     4506.7  |    4068.1
                          f32 B=32, M=197, H=16, K=64    |     8681.2  |    4838.7
                          f16 B=32, M=197, H=16, K=128   |     9626.9  |    5991.5
                          f32 B=32, M=197, H=16, K=128   |    15544.2  |    7540.1
                          f16 B=256, M=197, H=1, K=88    |     4769.0  |    2451.1
                          f32 B=256, M=197, H=1, K=88    |     6682.9  |    2906.0
                          f16 B=16, M=197, H=16, K=88    |     4781.1  |    2549.4
                          f32 B=16, M=197, H=16, K=88    |     6629.7  |    3063.5
                          f16 B=16, M=197, H=16, K=64    |     2609.3  |    2042.1
                          f32 B=16, M=197, H=16, K=64    |     4322.7  |    2445.7
                          f16 B=16, M=197, H=16, K=128   |     5432.6  |    3014.7
                          f32 B=16, M=197, H=16, K=128   |     7794.2  |    3670.1
                          f16 B=1, M=4096, H=160, K=128  |  1033138.6  |          
                          f32 B=1, M=4096, H=160, K=128  |  1264717.2  |          
                          f16 B=2, M=4096, H=160, K=128  |  1689231.2  |          
                          f32 B=2, M=4096, H=160, K=128  |  2511754.6  |          
                          f16 B=1, M=8192, H=160, K=128  |  4110718.8  |          
                          f32 B=1, M=8192, H=160, K=128  |  5051277.9  |          
                          f16 B=2, M=8192, H=160, K=128  |  6751365.7  |          
                          f16 B=1024, M=82, H=8, K=64    |    22967.1  |   18046.4
                          f32 B=1024, M=82, H=8, K=64    |    43698.8  |   22978.7
                          f16 B=150, M=256, H=16, K=64   |    23440.9  |   24551.6
                          f32 B=150, M=256, H=16, K=64   |    37480.4  |   32205.0
                          f16 B=64, M=256, H=12, K=64    |     7491.8  |    7716.8
                          f32 B=64, M=256, H=12, K=64    |    12214.8  |    9890.6
                          f16 B=1, M=4096, H=16, K=40    |   135707.0  |   29317.2
                          f32 B=1, M=4096, H=16, K=40    |   145042.0  |   37192.7
                          f16 B=1, M=16384, H=16, K=40   |  2150814.2  |          
                          f32 B=1, M=16384, H=16, K=40   |  2295614.2  |          
                          f16 B=16, M=128, H=16, K=16    |      517.6  |     572.7
                          f32 B=16, M=128, H=16, K=16    |      652.2  |     691.7
                          f16 B=16, M=128, H=16, K=32    |      601.6  |     677.2
                          f32 B=16, M=128, H=16, K=32    |      813.6  |     828.2
                          f16 B=16, M=128, H=16, K=64    |      778.9  |     891.9
                          f32 B=16, M=128, H=16, K=64    |     1163.4  |    1088.5
                          f16 B=16, M=128, H=16, K=128   |     1607.0  |    1337.7
                          f32 B=16, M=128, H=16, K=128   |     2259.3  |    1666.7
                          f16 B=16, M=128, H=16, K=256   |     4062.0  |    2507.3
                          f32 B=16, M=128, H=16, K=256   |     4647.4  |    3356.5
                          f16 B=16, M=512, H=16, K=16    |     7866.8  |    6958.1
                          f32 B=16, M=512, H=16, K=16    |     9792.9  |    8610.8
                          f16 B=16, M=512, H=16, K=32    |     9111.2  |    7500.4
                          f32 B=16, M=512, H=16, K=32    |    11388.9  |    9295.5
                          f16 B=16, M=512, H=16, K=64    |    11402.2  |    8911.2
                          f32 B=16, M=512, H=16, K=64    |    16094.5  |   11084.9
                          f16 B=16, M=512, H=16, K=128   |    24449.4  |   12629.6
                          f32 B=16, M=512, H=16, K=128   |    32234.3  |   15264.6
                          f16 B=16, M=512, H=16, K=256   |    52619.0  |   23373.4
                          f32 B=16, M=512, H=16, K=256   |    65241.9  |   27094.9
                          f16 B=16, M=1024, H=16, K=16   |    31510.4  |   26565.4
                          f32 B=16, M=1024, H=16, K=16   |    38369.3  |   32614.4
                          f16 B=16, M=1024, H=16, K=32   |    36294.3  |   28420.7
                          f32 B=16, M=1024, H=16, K=32   |    44377.6  |   35432.3
                          f16 B=16, M=1024, H=16, K=64   |    45366.8  |   32269.4
                          f32 B=16, M=1024, H=16, K=64   |    62745.1  |   39776.9
                          f16 B=16, M=1024, H=16, K=128  |    99353.4  |   43627.4
                          f32 B=16, M=1024, H=16, K=128  |   127366.1  |   51474.4
                          f16 B=16, M=1024, H=16, K=256  |   204810.4  |   81201.6
                          f32 B=16, M=1024, H=16, K=256  |   258126.0  |   92288.2
                          f16 B=64, M=128, H=16, K=16    |     1730.2  |    2117.6
                          f32 B=64, M=128, H=16, K=16    |     2428.4  |    2576.5
                          f16 B=64, M=128, H=16, K=32    |     2070.4  |    2487.9
                          f32 B=64, M=128, H=16, K=32    |     3084.5  |    3078.0
                          f16 B=64, M=128, H=16, K=64    |     2718.5  |    3317.9
                          f32 B=64, M=128, H=16, K=64    |     4421.9  |    4237.9
                          f16 B=64, M=128, H=16, K=128   |     5646.9  |    5284.4
                          f32 B=64, M=128, H=16, K=128   |     8635.8  |    6958.5
                          f16 B=64, M=128, H=16, K=256   |    13961.0  |   10316.2
                          f32 B=64, M=128, H=16, K=256   |    17417.3  |   13584.2
                          f16 B=64, M=512, H=16, K=16    |    26936.5  |   27427.8
                          f32 B=64, M=512, H=16, K=16    |    36403.9  |   33753.3
                          f16 B=64, M=512, H=16, K=32    |    31542.1  |   30266.4
                          f32 B=64, M=512, H=16, K=32    |    42935.1  |   37398.0
                          f16 B=64, M=512, H=16, K=64    |    39718.3  |   36109.8
                          f32 B=64, M=512, H=16, K=64    |    61577.7  |   43677.3
                          f16 B=64, M=512, H=16, K=128   |    86608.8  |   51294.6
                          f32 B=64, M=512, H=16, K=128   |   123085.0  |   61843.3
                          f16 B=64, M=512, H=16, K=256   |   179902.4  |   99364.5
                          f32 B=64, M=512, H=16, K=256   |   250051.9  |  111501.9
                          f16 B=64, M=1024, H=16, K=16   |   107724.9  |  106757.7
                          f32 B=64, M=1024, H=16, K=16   |   144482.4  |          
                          f16 B=64, M=1024, H=16, K=32   |   124733.4  |  114732.4
                          f32 B=64, M=1024, H=16, K=32   |   168876.8  |          
                          f16 B=64, M=1024, H=16, K=64   |   157059.1  |  131304.1
                          f32 B=64, M=1024, H=16, K=64   |   241476.9  |          
                          f16 B=64, M=1024, H=16, K=128  |   334298.9  |  179659.1
                          f32 B=64, M=1024, H=16, K=128  |   483706.8  |          
                          f16 B=64, M=1024, H=16, K=256  |   692904.0  |          
                          f32 B=64, M=1024, H=16, K=256  |   982044.1  |          
  (Tesla_V100_SXM2_16GB)  f16 B=384, M=197, H=1, K=88    |     1809.0  |    1374.5
                          f32 B=384, M=197, H=1, K=88    |     4340.1  |    2247.7
                          f16 B=384, M=197, H=1, K=80    |     1732.4  |    1282.2
                          f32 B=384, M=197, H=1, K=80    |     3974.2  |    2163.4
                          f16 B=384, M=197, H=1, K=64    |     1134.9  |    1044.1
                          f32 B=384, M=197, H=1, K=64    |     2689.9  |    1741.7
                          f16 B=1024, M=197, H=1, K=88   |     4707.0  |    3724.7
                          f32 B=1024, M=197, H=1, K=88   |    10546.7  |    6061.7
                          f16 B=1024, M=197, H=1, K=80   |     4523.2  |    3330.1
                          f32 B=1024, M=197, H=1, K=80   |     9609.2  |    5719.5
                          f16 B=1024, M=197, H=1, K=64   |     2799.2  |    2675.3
                          f32 B=1024, M=197, H=1, K=64   |     6586.8  |    4507.3
                          f16 B=512, M=197, H=1, K=80    |     2380.2  |    1684.0
                          f32 B=512, M=197, H=1, K=80    |     5267.1  |    2874.5
                          f16 B=32, M=197, H=16, K=80    |     2393.9  |    1800.3
                          f32 B=32, M=197, H=16, K=80    |     5392.4  |    3029.9
                          f16 B=32, M=197, H=16, K=64    |     1558.5  |    1450.9
                          f32 B=32, M=197, H=16, K=64    |     3636.3  |    2410.5
                          f16 B=32, M=197, H=16, K=128   |     2782.4  |    2211.6
                          f32 B=32, M=197, H=16, K=128   |     6643.3  |    4061.8
                          f16 B=256, M=197, H=1, K=88    |     1357.9  |     947.7
                          f32 B=256, M=197, H=1, K=88    |     2884.8  |    1533.7
                          f16 B=16, M=197, H=16, K=88    |     1346.8  |     970.9
                          f32 B=16, M=197, H=16, K=88    |     2801.2  |    1629.0
                          f16 B=16, M=197, H=16, K=64    |      766.2  |     931.7
                          f32 B=16, M=197, H=16, K=64    |     1838.3  |    1287.4
                          f16 B=16, M=197, H=16, K=128   |     1513.8  |    1135.0
                          f32 B=16, M=197, H=16, K=128   |     3407.8  |    2034.9
                          f16 B=1, M=4096, H=160, K=128  |   169073.2  |          
                          f32 B=1, M=4096, H=160, K=128  |   550508.2  |          
                          f16 B=2, M=4096, H=160, K=128  |   340149.9  |          
                          f32 B=2, M=4096, H=160, K=128  |  1102674.8  |          
                          f16 B=1, M=8192, H=160, K=128  |   681002.9  |          
                          f32 B=1, M=8192, H=160, K=128  |  2200639.4  |          
                          f16 B=2, M=8192, H=160, K=128  |  1364914.7  |          
                          f16 B=1024, M=82, H=8, K=64    |     9059.6  |    5802.6
                          f32 B=1024, M=82, H=8, K=64    |    14694.7  |   11037.4
                          f16 B=150, M=256, H=16, K=64   |     5693.1  |    7563.8
                          f32 B=150, M=256, H=16, K=64   |    16696.3  |   16305.4
                          f16 B=64, M=256, H=12, K=64    |     1852.1  |    2386.4
                          f32 B=64, M=256, H=12, K=64    |     5462.4  |    4969.2
                          f16 B=1, M=4096, H=16, K=40    |    47164.3  |    8362.1
                          f32 B=1, M=4096, H=16, K=40    |   113058.1  |   19476.1
                          f16 B=1, M=16384, H=16, K=40   |   759023.3  |          
                          f32 B=1, M=16384, H=16, K=40   |  1804493.4  |          
                          f16 B=16, M=128, H=16, K=16    |      476.6  |     712.1
                          f32 B=16, M=128, H=16, K=16    |      619.0  |     651.7
                          f16 B=16, M=128, H=16, K=32    |      445.6  |     776.2
                          f32 B=16, M=128, H=16, K=32    |      555.9  |     662.4
                          f16 B=16, M=128, H=16, K=64    |      517.8  |     680.9
                          f32 B=16, M=128, H=16, K=64    |      601.7  |     736.3
                          f16 B=16, M=128, H=16, K=128   |      451.1  |     686.1
                          f32 B=16, M=128, H=16, K=128   |     1105.7  |    1007.0
                          f16 B=16, M=128, H=16, K=256   |     1049.9  |     888.0
                          f32 B=16, M=128, H=16, K=256   |     2192.3  |    1855.9
                          f16 B=16, M=512, H=16, K=16    |     1731.3  |    1896.6
                          f32 B=16, M=512, H=16, K=16    |     4476.5  |    4249.5
                          f16 B=16, M=512, H=16, K=32    |     1948.7  |    2095.1
                          f32 B=16, M=512, H=16, K=32    |     5679.8  |    4600.4
                          f16 B=16, M=512, H=16, K=64    |     2448.1  |    2577.1
                          f32 B=16, M=512, H=16, K=64    |     7617.6  |    5491.4
                          f16 B=16, M=512, H=16, K=128   |     4891.9  |    3380.6
                          f32 B=16, M=512, H=16, K=128   |    15084.0  |    8860.2
                          f16 B=16, M=512, H=16, K=256   |    12952.6  |    5381.5
                          f32 B=16, M=512, H=16, K=256   |    29870.0  |   16766.3
                          f16 B=16, M=1024, H=16, K=16   |     6817.0  |    6986.0
                          f32 B=16, M=1024, H=16, K=16   |    18132.2  |   16098.9
                          f16 B=16, M=1024, H=16, K=32   |     7568.5  |    7399.8
                          f32 B=16, M=1024, H=16, K=32   |    22038.8  |   17093.0
                          f16 B=16, M=1024, H=16, K=64   |     9320.6  |    8623.2
                          f32 B=16, M=1024, H=16, K=64   |    29998.6  |   20238.1
                          f16 B=16, M=1024, H=16, K=128  |    18972.4  |   10503.1
                          f32 B=16, M=1024, H=16, K=128  |    58953.5  |   33141.1
                          f16 B=16, M=1024, H=16, K=256  |    49804.3  |   17122.0
                          f32 B=16, M=1024, H=16, K=256  |   116887.9  |   60004.3
                          f16 B=64, M=128, H=16, K=16    |      509.3  |     673.2
                          f32 B=64, M=128, H=16, K=16    |     1029.9  |    1234.0
                          f16 B=64, M=128, H=16, K=32    |      546.7  |     813.7
                          f32 B=64, M=128, H=16, K=32    |     1408.3  |    1533.5
                          f16 B=64, M=128, H=16, K=64    |      745.2  |    1186.2
                          f32 B=64, M=128, H=16, K=64    |     2019.2  |    2154.9
                          f16 B=64, M=128, H=16, K=128   |     1417.3  |    1916.9
                          f32 B=64, M=128, H=16, K=128   |     3950.5  |    3779.3
                          f16 B=64, M=128, H=16, K=256   |     3808.4  |    3450.8
                          f32 B=64, M=128, H=16, K=256   |     7983.2  |    7252.1
                          f16 B=64, M=512, H=16, K=16    |     6187.6  |    7461.6
                          f32 B=64, M=512, H=16, K=16    |    16328.7  |   16558.3
                          f16 B=64, M=512, H=16, K=32    |     7026.3  |    8314.3
                          f32 B=64, M=512, H=16, K=32    |    20583.0  |   18328.0
                          f16 B=64, M=512, H=16, K=64    |     9087.4  |   10425.2
                          f32 B=64, M=512, H=16, K=64    |    27696.9  |   22791.7
                          f16 B=64, M=512, H=16, K=128   |    17574.7  |   14673.6
                          f32 B=64, M=512, H=16, K=128   |    54678.0  |   39872.5
                          f16 B=64, M=512, H=16, K=256   |    47507.7  |   26896.4
                          f32 B=64, M=512, H=16, K=256   |   109608.7  |   75908.0
                          f16 B=64, M=1024, H=16, K=16   |    24447.9  |   28512.3
                          f32 B=64, M=1024, H=16, K=16   |    65064.6  |          
                          f16 B=64, M=1024, H=16, K=32   |    27254.5  |   30504.4
                          f32 B=64, M=1024, H=16, K=32   |    80142.4  |          
                          f16 B=64, M=1024, H=16, K=64   |    34677.9  |   37021.6
                          f32 B=64, M=1024, H=16, K=64   |   108919.4  |          
                          f16 B=64, M=1024, H=16, K=128  |    68389.8  |   49203.3
                          f32 B=64, M=1024, H=16, K=128  |   214535.3  |          
                          f16 B=64, M=1024, H=16, K=256  |   183195.8  |          
                          f32 B=64, M=1024, H=16, K=256  |   425804.3  |          

Times are in microseconds (us).
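
The `vanilla` column is a plain eager PyTorch attention run through autograd. The sketch below is an assumption about what that baseline looks like (the `vanilla_attention` helper is hypothetical, not the benchmark's exact code):

```python
import math
import torch

def vanilla_attention(q, k, v, attn_bias=None):
    # q, k, v: [B, M, H, K]; fold heads into the batch dimension
    B, M, H, K = q.shape
    q, k, v = (t.transpose(1, 2).reshape(B * H, M, K) for t in (q, k, v))
    scores = q @ k.transpose(-2, -1) / math.sqrt(K)
    if attn_bias is not None:
        scores = scores + attn_bias.reshape(B * H, M, M)
    out = scores.softmax(dim=-1) @ v
    return out.reshape(B, H, M, K).transpose(1, 2)

q = torch.randn(16, 512, 16, 64, device="cuda", requires_grad=True)
k = torch.randn_like(q).requires_grad_()
v = torch.randn_like(q).requires_grad_()

# The backward tables time this full forward + backward pass.
vanilla_attention(q, k, v).sum().backward()
```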

[ attention backward (attn_bias=<class 'xformers.ops.fmha.common.LowerTriangularMask'>) ]
                                                         |     main    |  vanilla 
1 threads: -----------------------------------------------------------------------
  (Quadro_GP100)          f16 B=384, M=197, H=1, K=88    |     4252.9  |    3568.0
                          f32 B=384, M=197, H=1, K=88    |     6516.4  |    4266.8
                          f16 B=384, M=197, H=1, K=80    |     4024.4  |    3422.3
                          f32 B=384, M=197, H=1, K=80    |     6216.6  |    4078.8
                          f16 B=384, M=197, H=1, K=64    |     2367.5  |    2914.9
                          f32 B=384, M=197, H=1, K=64    |     4350.7  |    3435.6
                          f16 B=1024, M=197, H=1, K=88   |    10541.5  |    9757.7
                          f32 B=1024, M=197, H=1, K=88   |    17610.3  |   12024.4
                          f16 B=1024, M=197, H=1, K=80   |     9913.7  |    9288.2
                          f32 B=1024, M=197, H=1, K=80   |    16804.2  |   11179.8
                          f16 B=1024, M=197, H=1, K=64   |     5802.8  |    7663.7
                          f32 B=1024, M=197, H=1, K=64   |    11778.8  |    9444.2
                          f16 B=512, M=197, H=1, K=80    |     5037.0  |    4611.7
                          f32 B=512, M=197, H=1, K=80    |     8749.3  |    5465.3
                          f16 B=32, M=197, H=16, K=80    |     5118.0  |    4819.7
                          f32 B=32, M=197, H=16, K=80    |     8713.1  |    5732.0
                          f16 B=32, M=197, H=16, K=64    |     2979.9  |    4031.1
                          f32 B=32, M=197, H=16, K=64    |     6085.4  |    4790.2
                          f16 B=32, M=197, H=16, K=128   |     6053.1  |    5955.1
                          f32 B=32, M=197, H=16, K=128   |    10682.1  |    7341.6
                          f16 B=256, M=197, H=1, K=88    |     3074.8  |    2440.7
                          f32 B=256, M=197, H=1, K=88    |     4561.4  |    2860.9
                          f16 B=16, M=197, H=16, K=88    |     3082.1  |    2523.8
                          f32 B=16, M=197, H=16, K=88    |     4517.3  |    3017.4
                          f16 B=16, M=197, H=16, K=64    |     1774.8  |    2029.5
                          f32 B=16, M=197, H=16, K=64    |     3061.0  |    2429.9
                          f16 B=16, M=197, H=16, K=128   |     3489.2  |    3004.0
                          f32 B=16, M=197, H=16, K=128   |     5322.8  |    3613.6
                          f16 B=1, M=4096, H=160, K=128  |   533256.0  |          
                          f32 B=1, M=4096, H=160, K=128  |   644655.4  |          
                          f16 B=2, M=4096, H=160, K=128  |   868892.7  |          
                          f32 B=2, M=4096, H=160, K=128  |  1281898.4  |          
                          f16 B=1, M=8192, H=160, K=128  |  2085529.2  |          
                          f32 B=1, M=8192, H=160, K=128  |  2548207.5  |          
                          f16 B=2, M=8192, H=160, K=128  |  3437383.9  |          
                          f16 B=1024, M=82, H=8, K=64    |    20984.0  |   18067.6
                          f32 B=1024, M=82, H=8, K=64    |    37605.3  |   22806.3
                          f16 B=150, M=256, H=16, K=64   |    15327.7  |   24392.8
                          f32 B=150, M=256, H=16, K=64   |    24784.3  |   31835.0
                          f16 B=64, M=256, H=12, K=64    |     4922.7  |    7678.4
                          f32 B=64, M=256, H=12, K=64    |     8068.8  |    9808.9
                          f16 B=1, M=4096, H=16, K=40    |    69262.6  |   29178.9
                          f32 B=1, M=4096, H=16, K=40    |    73372.5  |   37290.0
                          f16 B=1, M=16384, H=16, K=40   |  1082724.4  |          
                          f32 B=1, M=16384, H=16, K=40   |  1156356.8  |          
                          f16 B=16, M=128, H=16, K=16    |      403.8  |     573.1
                          f32 B=16, M=128, H=16, K=16    |      514.4  |     693.1
                          f16 B=16, M=128, H=16, K=32    |      454.5  |     670.3
                          f32 B=16, M=128, H=16, K=32    |      642.2  |     821.0
                          f16 B=16, M=128, H=16, K=64    |      613.9  |     885.4
                          f32 B=16, M=128, H=16, K=64    |      922.2  |    1080.5
                          f16 B=16, M=128, H=16, K=128   |     1239.4  |    1329.3
                          f32 B=16, M=128, H=16, K=128   |     1777.7  |    1662.2
                          f16 B=16, M=128, H=16, K=256   |     3354.0  |    2500.6
                          f32 B=16, M=128, H=16, K=256   |     3651.8  |    3291.1
                          f16 B=16, M=512, H=16, K=16    |     4427.2  |    6857.0
                          f32 B=16, M=512, H=16, K=16    |     5531.2  |    8419.4
                          f16 B=16, M=512, H=16, K=32    |     5193.8  |    7481.4
                          f32 B=16, M=512, H=16, K=32    |     6608.2  |    9166.2
                          f16 B=16, M=512, H=16, K=64    |     6536.6  |    8855.1
                          f32 B=16, M=512, H=16, K=64    |     9423.3  |   10849.5
                          f16 B=16, M=512, H=16, K=128   |    13962.6  |   12345.1
                          f32 B=16, M=512, H=16, K=128   |    18679.2  |   15003.2
                          f16 B=16, M=512, H=16, K=256   |    31425.8  |   23147.7
                          f32 B=16, M=512, H=16, K=256   |    37686.5  |   26873.0
                          f16 B=16, M=1024, H=16, K=16   |    16928.6  |   26395.9
                          f32 B=16, M=1024, H=16, K=16   |    20647.1  |   32762.1
                          f16 B=16, M=1024, H=16, K=32   |    19584.6  |   28100.2
                          f32 B=16, M=1024, H=16, K=32   |    24153.9  |   35231.2
                          f16 B=16, M=1024, H=16, K=64   |    24358.8  |   31949.6
                          f32 B=16, M=1024, H=16, K=64   |    34135.4  |   39247.4
                          f16 B=16, M=1024, H=16, K=128  |    52553.6  |   42857.2
                          f32 B=16, M=1024, H=16, K=128  |    68490.7  |   50818.2
                          f16 B=16, M=1024, H=16, K=256  |   113179.1  |   79246.2
                          f32 B=16, M=1024, H=16, K=256  |   138958.7  |   90470.1
                          f16 B=64, M=128, H=16, K=16    |     1313.4  |    2093.5
                          f32 B=64, M=128, H=16, K=16    |     1912.5  |    2551.8
                          f16 B=64, M=128, H=16, K=32    |     1605.9  |    2479.4
                          f32 B=64, M=128, H=16, K=32    |     2424.6  |    3055.7
                          f16 B=64, M=128, H=16, K=64    |     2135.5  |    3306.4
                          f32 B=64, M=128, H=16, K=64    |     3512.2  |    4184.7
                          f16 B=64, M=128, H=16, K=128   |     4349.4  |    5250.9
                          f32 B=64, M=128, H=16, K=128   |     6734.9  |    6860.6
                          f16 B=64, M=128, H=16, K=256   |    11412.6  |   10225.4
                          f32 B=64, M=128, H=16, K=256   |    13715.3  |   13386.5
                          f16 B=64, M=512, H=16, K=16    |    15298.3  |   27164.6
                          f32 B=64, M=512, H=16, K=16    |    20818.0  |   33373.3
                          f16 B=64, M=512, H=16, K=32    |    18168.1  |   29831.6
                          f32 B=64, M=512, H=16, K=32    |    25124.1  |   37340.7
                          f16 B=64, M=512, H=16, K=64    |    22989.5  |   35792.1
                          f32 B=64, M=512, H=16, K=64    |    36008.3  |   43156.8
                          f16 B=64, M=512, H=16, K=128   |    48567.6  |   50699.7
                          f32 B=64, M=512, H=16, K=128   |    70832.3  |   60779.7
                          f16 B=64, M=512, H=16, K=256   |   107837.9  |   97016.9
                          f32 B=64, M=512, H=16, K=256   |   144816.6  |  109729.1
                          f16 B=64, M=1024, H=16, K=16   |    57449.7  |  105449.7
                          f32 B=64, M=1024, H=16, K=16   |    77196.8  |          
                          f16 B=64, M=1024, H=16, K=32   |    67469.2  |  113569.9
                          f32 B=64, M=1024, H=16, K=32   |    92452.9  |          
                          f16 B=64, M=1024, H=16, K=64   |    85985.5  |  129934.5
                          f32 B=64, M=1024, H=16, K=64   |   131842.4  |          
                          f16 B=64, M=1024, H=16, K=128  |   180353.2  |  176085.7
                          f32 B=64, M=1024, H=16, K=128  |   261345.7  |          
                          f16 B=64, M=1024, H=16, K=256  |   380736.0  |          
                          f32 B=64, M=1024, H=16, K=256  |   530228.5  |          
  (Tesla_V100_SXM2_16GB)  f16 B=384, M=197, H=1, K=88    |     1502.5  |    1373.5
                          f32 B=384, M=197, H=1, K=88    |     2885.5  |    2234.0
                          f16 B=384, M=197, H=1, K=80    |     1440.5  |    1282.0
                          f32 B=384, M=197, H=1, K=80    |     2606.4  |    2148.4
                          f16 B=384, M=197, H=1, K=64    |      824.4  |    1044.6
                          f32 B=384, M=197, H=1, K=64    |     1888.7  |    1737.6
                          f16 B=1024, M=197, H=1, K=88   |     3916.4  |    3731.4
                          f32 B=1024, M=197, H=1, K=88   |     7123.7  |    6025.3
                          f16 B=1024, M=197, H=1, K=80   |     3751.2  |    3329.3
                          f32 B=1024, M=197, H=1, K=80   |     6440.6  |    5673.2
                          f16 B=1024, M=197, H=1, K=64   |     2033.2  |    2674.8
                          f32 B=1024, M=197, H=1, K=64   |     4637.9  |    4491.2
                          f16 B=512, M=197, H=1, K=80    |     1980.8  |    1678.7
                          f32 B=512, M=197, H=1, K=80    |     3457.6  |    2856.9
                          f16 B=32, M=197, H=16, K=80    |     1972.6  |    1799.2
                          f32 B=32, M=197, H=16, K=80    |     3514.2  |    3015.5
                          f16 B=32, M=197, H=16, K=64    |     1119.0  |    1450.0
                          f32 B=32, M=197, H=16, K=64    |     2554.1  |    2415.0
                          f16 B=32, M=197, H=16, K=128   |     2290.4  |    2213.7
                          f32 B=32, M=197, H=16, K=128   |     4547.6  |    4023.2
                          f16 B=256, M=197, H=1, K=88    |     1138.3  |     941.5
                          f32 B=256, M=197, H=1, K=88    |     1926.5  |    1521.0
                          f16 B=16, M=197, H=16, K=88    |     1116.3  |     970.6
                          f32 B=16, M=197, H=16, K=88    |     1862.4  |    1606.0
                          f16 B=16, M=197, H=16, K=64    |      557.9  |     795.7
                          f32 B=16, M=197, H=16, K=64    |     1275.8  |    1285.0
                          f16 B=16, M=197, H=16, K=128   |     1245.8  |    1136.1
                          f32 B=16, M=197, H=16, K=128   |     2292.3  |    2016.9
                          f16 B=1, M=4096, H=160, K=128  |    87327.9  |          
                          f32 B=1, M=4096, H=160, K=128  |   281487.9  |          
                          f16 B=2, M=4096, H=160, K=128  |   176796.0  |          
                          f32 B=2, M=4096, H=160, K=128  |   564008.1  |          
                          f16 B=1, M=8192, H=160, K=128  |   347318.8  |          
                          f32 B=1, M=8192, H=160, K=128  |  1111934.5  |          
                          f16 B=2, M=8192, H=160, K=128  |   696927.1  |          
                          f16 B=1024, M=82, H=8, K=64    |     7899.4  |    5814.6
                          f32 B=1024, M=82, H=8, K=64    |    12697.1  |   11005.3
                          f16 B=150, M=256, H=16, K=64   |     3963.8  |    7590.3
                          f32 B=150, M=256, H=16, K=64   |    11184.3  |   16357.7
                          f16 B=64, M=256, H=12, K=64    |     1293.5  |    2385.3
                          f32 B=64, M=256, H=12, K=64    |     3633.0  |    4970.4
                          f16 B=1, M=4096, H=16, K=40    |    24253.3  |    8364.6
                          f32 B=1, M=4096, H=16, K=40    |    57122.2  |   19517.0
                          f16 B=1, M=16384, H=16, K=40   |   386768.6  |          
                          f32 B=1, M=16384, H=16, K=40   |   909207.0  |          
                          f16 B=16, M=128, H=16, K=16    |      500.4  |     633.7
                          f32 B=16, M=128, H=16, K=16    |      546.9  |     610.3
                          f16 B=16, M=128, H=16, K=32    |      575.3  |     670.2
                          f32 B=16, M=128, H=16, K=32    |      519.1  |     618.9
                          f16 B=16, M=128, H=16, K=64    |      461.2  |     648.9
                          f32 B=16, M=128, H=16, K=64    |      575.0  |     615.2
                          f16 B=16, M=128, H=16, K=128   |      515.3  |     690.1
                          f32 B=16, M=128, H=16, K=128   |      875.9  |    1006.7
                          f16 B=16, M=128, H=16, K=256   |     1052.4  |     888.7
                          f32 B=16, M=128, H=16, K=256   |     1740.9  |    1854.9
                          f16 B=16, M=512, H=16, K=16    |     1015.1  |    1918.9
                          f32 B=16, M=512, H=16, K=16    |     2540.9  |    4288.5
                          f16 B=16, M=512, H=16, K=32    |     1158.3  |    2128.9
                          f32 B=16, M=512, H=16, K=32    |     3260.8  |    4634.9
                          f16 B=16, M=512, H=16, K=64    |     1490.7  |    2560.9
                          f32 B=16, M=512, H=16, K=64    |     4449.8  |    5479.5
                          f16 B=16, M=512, H=16, K=128   |     3212.9  |    3377.7
                          f32 B=16, M=512, H=16, K=128   |     8759.7  |    8724.9
                          f16 B=16, M=512, H=16, K=256   |     8505.5  |    5348.4
                          f32 B=16, M=512, H=16, K=256   |    17494.2  |   16621.6
                          f16 B=16, M=1024, H=16, K=16   |     3717.3  |    7286.8
                          f32 B=16, M=1024, H=16, K=16   |     9676.0  |   16131.3
                          f16 B=16, M=1024, H=16, K=32   |     4170.7  |    7662.0
                          f32 B=16, M=1024, H=16, K=32   |    12001.4  |   17151.3
                          f16 B=16, M=1024, H=16, K=64   |     5203.6  |    8637.0
                          f32 B=16, M=1024, H=16, K=64   |    16305.8  |   19853.1
                          f16 B=16, M=1024, H=16, K=128  |    11030.4  |   10478.4
                          f32 B=16, M=1024, H=16, K=128  |    32050.9  |   32589.2
                          f16 B=16, M=1024, H=16, K=256  |    28891.7  |   16874.2
                          f32 B=16, M=1024, H=16, K=256  |    63284.2  |   58763.0
                          f16 B=64, M=128, H=16, K=16    |      508.8  |     651.3
                          f32 B=64, M=128, H=16, K=16    |      795.8  |    1240.6
                          f16 B=64, M=128, H=16, K=32    |      469.8  |     814.0
                          f32 B=64, M=128, H=16, K=32    |     1115.5  |    1530.4
                          f16 B=64, M=128, H=16, K=64    |      623.9  |    1185.2
                          f32 B=64, M=128, H=16, K=64    |     1612.9  |    2154.2
                          f16 B=64, M=128, H=16, K=128   |     1419.7  |    1918.1
                          f32 B=64, M=128, H=16, K=128   |     3171.7  |    3761.3
                          f16 B=64, M=128, H=16, K=256   |     3810.6  |    3445.0
                          f32 B=64, M=128, H=16, K=256   |     6416.3  |    7248.7
                          f16 B=64, M=512, H=16, K=16    |     3596.6  |    7506.6
                          f32 B=64, M=512, H=16, K=16    |     9262.7  |   16742.1
                          f16 B=64, M=512, H=16, K=32    |     4189.1  |    8365.1
                          f32 B=64, M=512, H=16, K=32    |    11914.3  |   18455.8
                          f16 B=64, M=512, H=16, K=64    |     5510.7  |   10294.1
                          f32 B=64, M=512, H=16, K=64    |    16312.6  |   22721.7
                          f16 B=64, M=512, H=16, K=128   |    11544.8  |   14667.4
                          f32 B=64, M=512, H=16, K=128   |    31957.9  |   39608.1
                          f16 B=64, M=512, H=16, K=256   |    31179.8  |   26641.1
                          f32 B=64, M=512, H=16, K=256   |    63495.1  |   74597.9
                          f16 B=64, M=1024, H=16, K=16   |    13245.5  |   29173.0
                          f32 B=64, M=1024, H=16, K=16   |    34907.6  |          
                          f16 B=64, M=1024, H=16, K=32   |    15063.4  |   31256.2
                          f32 B=64, M=1024, H=16, K=32   |    43456.3  |          
                          f16 B=64, M=1024, H=16, K=64   |    19322.4  |   37280.5
                          f32 B=64, M=1024, H=16, K=64   |    58997.7  |          
                          f16 B=64, M=1024, H=16, K=128  |    39771.6  |   49110.0
                          f32 B=64, M=1024, H=16, K=128  |   116715.2  |          
                          f16 B=64, M=1024, H=16, K=256  |   106294.7  |          
                          f32 B=64, M=1024, H=16, K=256  |   231929.0  |          

Times are in microseconds (us).

TESTS

Looks like tests are failing on A100:

A100 backward test
$ CUDA_LAUNCH_BLOCKING=1 python -m pytest /scratch/XXXX/xformers/tests/test_mem_eff_attention.py -k "test_backward" -x -s -v --pdb
=========================================================================================================================================== test session starts ============================================================================================================================================
platform linux -- Python 3.10.8, pytest-7.2.0, pluggy-1.0.0 -- /scratch/XXXX/lxformers/bin/python
cachedir: .pytest_cache
rootdir: /scratch/XXXX/xformers
plugins: mpi-0.4, timeout-1.4.2, hydra-core-1.2.0, cov-2.10.0, typeguard-2.13.3
collected 34016 items / 16368 deselected / 17648 selected                                                                                                                                                                                                                                                  

tests/test_mem_eff_attention.py::test_backward[cutlassB-cuda-torch.bfloat16-1-32-32-1-32-32-False-attn_bias_cfg0-BMK] PASSED
tests/test_mem_eff_attention.py::test_backward[cutlassB-cuda-torch.bfloat16-1-32-32-1-32-32-False-attn_bias_cfg0-BMHK] PASSED
tests/test_mem_eff_attention.py::test_backward[cutlassB-cuda-torch.bfloat16-1-32-32-1-32-32-False-attn_bias_cfg1-BMK] PASSED
tests/test_mem_eff_attention.py::test_backward[cutlassB-cuda-torch.bfloat16-1-32-32-1-32-32-False-attn_bias_cfg1-BMHK] PASSED
tests/test_mem_eff_attention.py::test_backward[cutlassB-cuda-torch.bfloat16-1-32-32-1-32-32-False-attn_bias_cfg2-BMK] FAILED
>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> traceback >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>

op_device_dtype_B_Mq_Mkv_H_K_Kv = (<class 'xformers.ops.fmha.cutlass.BwOp'>, 'cuda', torch.bfloat16, 1, 32, 32, ...), grad_out_contiguous = False, attn_bias_cfg = (<class 'torch.Tensor'>, True), fmt = 'BMK'

    @pytest.mark.parametrize("fmt", ["BMK", "BMHK"])
    @pytest.mark.parametrize(
        "attn_bias_cfg",  # (type(bias), bias.requires_grad)
        [
            (None, False),
            (xformers.ops.LowerTriangularMask, False),
            (torch.Tensor, True),
            (torch.Tensor, False),
        ],
    )
    @pytest.mark.parametrize("grad_out_contiguous", [False, True])
    @pytest.mark.parametrize(
        "op_device_dtype_B_Mq_Mkv_H_K_Kv",
        _opBW_device_dtype_B_Mq_Mkv_H_K_Kv,
        ids=_opBW_device_dtype_B_Mq_Mkv_H_K_Kv_ids,
    )
    def test_backward(
        op_device_dtype_B_Mq_Mkv_H_K_Kv,
        grad_out_contiguous,
        attn_bias_cfg,
        fmt,
    ):
        attn_bias_type, attn_bias_requires_grad = attn_bias_cfg
        (
            op_bw,
            device,
            dtype,
            batch_size,
            q_len,
            kv_len,
            h,
            k,
            kv,
        ) = op_device_dtype_B_Mq_Mkv_H_K_Kv
        query, key, value, attn_bias = create_tensors(
            *op_device_dtype_B_Mq_Mkv_H_K_Kv,
            attn_bias_type=attn_bias_type,
            fmt=fmt,
        )
        op_fw = (
            sample_random_supported_fw(
                fmha.Inputs(query=query, key=key, value=value, attn_bias=attn_bias),
                seed=q_len * kv + kv_len * k,
            )
            if op_bw != fmha.cutlass.BwOp
            else fmha.cutlass.FwOp
        )
        qkv = None
    
        if (
            fmt == "BMHK"
            and query.shape[3] == value.shape[3]
            and query.shape[1] == value.shape[1]
        ):
            qkv = torch.stack([query, key, value], 2)
            qkv.requires_grad_(True)
            # bm3hk -> 3 x bmhk
            query, key, value = xformers.ops.unbind(qkv, 2)
            assert not query.is_contiguous()
    
        query.requires_grad_(True)
        key.requires_grad_(True)
        value.requires_grad_(True)
        if isinstance(attn_bias, torch.Tensor):
            attn_bias.requires_grad_(attn_bias_requires_grad)
    
        if not op_bw.supports(fmha.Inputs(query, key, value, attn_bias)):
            pytest.skip("inputs not supported")
    
        out = xformers.ops.memory_efficient_attention(
            query, key, value, attn_bias, op=(op_fw, op_bw)
        )
    
        grad_out = torch.ones_like(out)
        if grad_out_contiguous is False:
            grad_out = torch.tensor([1.0], dtype=query.dtype, device=device)[
                None, None, :
            ].expand_as(out)
    
        out.backward(grad_out)
        del out
    
        if qkv is None and op_bw == fmha.cutlass.BwOp:
            assert query.stride() == query.grad.stride()
    
        grads = []
        if qkv is None:
            grads = [query.grad, key.grad, value.grad]
            query.grad = None
            key.grad = None
            value.grad = None
        else:
            grads = [qkv.grad]
            qkv.grad = None
        if attn_bias_requires_grad:
            grads.append(attn_bias.grad)
            attn_bias.grad = None
    
        ref = ref_attention(query, key, value, attn_bias)
        ref.backward(grad_out)
        del grad_out
        del ref
    
        atol = op_bw.ERROR_ATOL[dtype]
        rtol = op_bw.ERROR_RTOL[dtype]
    
        grads_ref = []
        grads_name = []
        if qkv is None:
            assert isinstance(query.grad, torch.Tensor)
            assert isinstance(key.grad, torch.Tensor)
            assert isinstance(value.grad, torch.Tensor)
            grads_ref = [query.grad, key.grad, value.grad]
            grads_name = ["query", "key", "value"]
        else:
            assert isinstance(qkv.grad, torch.Tensor)
            grads_ref = [qkv.grad]
            grads_name = ["qkv"]
    
        if attn_bias_requires_grad:
            assert isinstance(attn_bias.grad, torch.Tensor)
            grads_ref.append(attn_bias.grad)
            grads_name.append("bias")
    
        del query
        del key
        del value
        del qkv
    
        for name, calc_grad, ref_grad in zip(grads_name, grads, grads_ref):
>           assert_allclose(
                calc_grad,
                ref_grad,
                msg=f"{op_fw.NAME}+{op_bw.NAME}:{name}",
                atol=atol,
                rtol=rtol,
            )

tests/test_mem_eff_attention.py:686: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

out = tensor([[[ 2.8519e+12, -1.1141e-11,  1.0522e+29,  ...,  1.6022e-03,
           1.8014e+16,  1.2360e-03],
         [-1....2461e-01, -2.8712e+24,  ..., -1.0156e+00,
          -4.9720e+18, -4.2969e-01]]], device='cuda:0', dtype=torch.bfloat16)
ref = tensor([[[ 4.5013e-04,  6.3324e-04,  3.0365e-03,  ..., -7.0190e-03,
          -2.4414e-03,  7.0953e-04],
         [ 1....2734e-01,  8.0078e-01,  ..., -1.9336e-01,
          -8.5938e-01,  4.5508e-01]]], device='cuda:0', dtype=torch.bfloat16), msg = 'cutlassF+cutlassB:query', atol = 0.7
rtol = 0.1

    def assert_allclose(
        out: torch.Tensor,
        ref: torch.Tensor,
        msg: str = "failed",
        atol: float = 1e-8,
        rtol: float = 1e-5,
    ) -> None:
        assert out.shape == ref.shape
        flatten_diff = ((out - ref).abs() - atol - ref.abs() * rtol).flatten()
        max_pos = flatten_diff.argmax()
        max_diff = flatten_diff[max_pos]
        num_different = torch.count_nonzero(flatten_diff > 0)
        percentage = num_different / flatten_diff.numel()
        del flatten_diff
>       assert torch.allclose(out, ref, rtol=rtol, atol=atol), (
            f"{msg}: "
            f"out={out.flatten()[max_pos]} and ref={ref.flatten()[max_pos]} (diff={max_diff} > 0)"
            f"/ atol={atol}, rtol={rtol}"
            f"/ total failing elements: {num_different}, percentage={percentage}"
        )
E       AssertionError: cutlassF+cutlassB:query: out=-8.048060130728983e+35 and ref=0.67578125 (diff=8.048060130728983e+35 > 0)/ atol=0.7, rtol=0.1/ total failing elements: 684, percentage=0.66796875
E       assert False
E        +  where False = <built-in method allclose of type object at 0x7f52fa29f200>(tensor([[[ 2.8519e+12, -1.1141e-11,  1.0522e+29,  ...,  1.6022e-03,\n           1.8014e+16,  1.2360e-03],\n         [-1.9241e+12, -2.0385e-05, -5.5535e+24,  ..., -3.0884e-02,\n          -4.3580e+20,  2.0996e-02],\n         [-8.3317e+16,  6.5484e-11, -2.2849e+26,  ...,  1.2398e-04,\n          -1.9922e+22,  9.3460e-05],\n         ...,\n         [ 2.6828e+14, -1.3097e-09, -1.6712e+28,  ..., -2.5558e-04,\n          -3.6605e+19, -1.9646e-04],\n         [ 4.5425e+20, -2.4214e-07,  2.8531e+26,  ..., -6.9618e-05,\n          -3.5654e+23, -2.8729e-05],\n         [ 4.6043e+22, -2.2461e-01, -2.8712e+24,  ..., -1.0156e+00,\n          -4.9720e+18, -4.2969e-01]]], device='cuda:0', dtype=torch.bfloat16), tensor([[[ 4.5013e-04,  6.3324e-04,  3.0365e-03,  ..., -7.0190e-03,\n          -2.4414e-03,  7.0953e-04],\n         [ 1.8234e-03, -3.3417e-03,  1.3855e-02,  ..., -3.4668e-02,\n          -2.4719e-03,  3.7231e-03],\n         [-5.8984e-01, -2.3535e-01, -8.5547e-01,  ..., -7.6172e-01,\n           4.0625e-01, -2.9883e-01],\n         ...,\n         [ 6.2012e-02,  1.9238e-01,  2.0508e-01,  ..., -4.2725e-02,\n          -5.3516e-01,  8.9062e-01],\n         [ 5.0000e-01,  3.9219e+00,  4.0430e-01,  ...,  2.3594e+00,\n           9.0625e-01, -2.8750e+00],\n         [-3.8672e-01, -5.2734e-01,  8.0078e-01,  ..., -1.9336e-01,\n          -8.5938e-01,  4.5508e-01]]], device='cuda:0', dtype=torch.bfloat16), rtol=0.1, atol=0.7)
E        +    where <built-in method allclose of type object at 0x7f52fa29f200> = torch.allclose

tests/test_mem_eff_attention.py:182: AssertionError
**EDIT**: The only test configurations that fail are the ones where we require a gradient for the bias.

They seem to pass on V100/P100 :)
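
To re-run just that failing subset, here's a minimal sketch (a hypothetical invocation, relying on the parametrization above, where attn_bias_cfg2 is the (torch.Tensor, True) case, i.e. a tensor bias with requires_grad=True):

import pytest

# select only the backward tests whose bias requires a gradient
# (the attn_bias_cfg2 parametrization above); -x stops at the first failure
pytest.main([
    "tests/test_mem_eff_attention.py",
    "-k", "test_backward and attn_bias_cfg2",
    "-x",
])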

@jfc4050
Contributor Author

jfc4050 commented Dec 15, 2022

The only test configurations that fail are the ones where we require a gradient for the bias
They seem to pass on V100/P100 :)

cool! that's about all i could hope for on A100. i'll see if i can find an A100 today or tomorrow and start debugging

@danthe3rd
Contributor

Thanks a lot for addressing / answering the comments! As the code looks fairly clean, I'm happy to merge once the following conditions are met:
(1) No performance regression in the previously supported cases (e.g. no dropout + no mask) - this seems to be the case
(2) Tests pass on A100/V100/P100 - looks like there are still some fixes needed for A100

Also, a heads-up: we will move all of the C++ files around (as in #579). This might create git conflicts as you rebase, but we won't change the content of the files you are touching.

Contributor

@danthe3rd danthe3rd left a comment


Tests seem to pass on my A100! Congratulations!
I just opened a draft PR with these changes to get our CI to test them:
#606

int8_t gQKV_strideM_multiplier; // 3 for packed, 1 otherwise

// dropout
bool use_dropout;
Contributor

do we need this additional variable? Can't we just compare "dropout_prob != 0"?

Contributor Author

@jfc4050 jfc4050 Dec 19, 2022


i'll see if i can get rid of it - it was done this way because i was worried about dropout_prob != 0 being a floating point comparison

EDIT: if you're only talking about backward, we can definitely get rid of it since it's only used once to dispatch, and we can use std::fpclassify there

Comment on lines 589 to 591
kPreloadMmas && kApplyDropout ?
cutlass::const_min(2, DefaultConfig::kStages) :
DefaultConfig::kStages, // Stages
Contributor

When I last tested, it seemed that cutlass::gemm::threadblock::DefaultMma would match the Pipelined implementation on A100 when using kStages=2, instead of the Mma implementation (cc @hwu36). This makes performance much worse on A100 when dropout is enabled - but let's keep it like this for now, we can address that in a later PR (I also have plans to reduce shmem usage, so this might no longer be needed in the future).

I also see that you are using cutlass::gemm::kernel::DefaultGemm now, which might not have this issue

Contributor Author

from looking at it, it does seem to select the pipelined implementation. here are some benchmarks to show the impact

[---------------------------- attention backward (attn_bias=<class 'NoneType'>) -----------------------------]
                                                                            |  optimized[flshattB]  |  vanilla
1 threads: ---------------------------------------------------------------------------------------------------
      b16 B=8, M=512, H=64, K=64, p=0.0, BiasT=NoneType, BiasGrad=False     |           1.1         |     2.9
      b16 B=8, M=512, H=64, K=64, p=0.3, BiasT=NoneType, BiasGrad=False     |           1.2         |     3.4
      b16 B=8, M=512, H=64, K=128, p=0.0, BiasT=NoneType, BiasGrad=False    |           2.6         |     3.9
      b16 B=8, M=512, H=64, K=128, p=0.3, BiasT=NoneType, BiasGrad=False    |           2.7         |     4.3
      b16 B=8, M=1024, H=64, K=64, p=0.0, BiasT=NoneType, BiasGrad=False    |           4.1         |    10.0
      b16 B=8, M=1024, H=64, K=64, p=0.3, BiasT=NoneType, BiasGrad=False    |           4.4         |    12.0
      b16 B=8, M=1024, H=64, K=128, p=0.0, BiasT=NoneType, BiasGrad=False   |           9.1         |    12.0
      b16 B=8, M=1024, H=64, K=128, p=0.3, BiasT=NoneType, BiasGrad=False   |           9.4         |    13.9
      b16 B=16, M=512, H=64, K=64, p=0.0, BiasT=NoneType, BiasGrad=False    |           2.2         |     5.7
      b16 B=16, M=512, H=64, K=64, p=0.3, BiasT=NoneType, BiasGrad=False    |           2.4         |     6.7
      b16 B=16, M=512, H=64, K=128, p=0.0, BiasT=NoneType, BiasGrad=False   |           5.1         |     7.7
      b16 B=16, M=512, H=64, K=128, p=0.3, BiasT=NoneType, BiasGrad=False   |           5.3         |     8.7
      b16 B=16, M=1024, H=64, K=64, p=0.0, BiasT=NoneType, BiasGrad=False   |           8.2         |    20.1
      b16 B=16, M=1024, H=64, K=64, p=0.3, BiasT=NoneType, BiasGrad=False   |           8.7         |    23.9
      b16 B=16, M=1024, H=64, K=128, p=0.0, BiasT=NoneType, BiasGrad=False  |          18.2         |    24.0
      b16 B=16, M=1024, H=64, K=128, p=0.3, BiasT=NoneType, BiasGrad=False  |          18.8         |    27.8

Times are in milliseconds (ms).

[------------------------- attention backward (attn_bias=<class 'torch.Tensor'>) --------------------------]
                                                                          |  optimized[cutlassB]  |  vanilla
1 threads: -------------------------------------------------------------------------------------------------
      b16 B=8, M=512, H=64, K=64, p=0.0, BiasT=Tensor, BiasGrad=False     |           2.1         |     2.9
      b16 B=8, M=512, H=64, K=64, p=0.0, BiasT=Tensor, BiasGrad=True      |           2.1         |     2.9
      b16 B=8, M=512, H=64, K=64, p=0.3, BiasT=Tensor, BiasGrad=False     |           3.1         |     3.4
      b16 B=8, M=512, H=64, K=64, p=0.3, BiasT=Tensor, BiasGrad=True      |           3.1         |     3.4
      b16 B=8, M=512, H=64, K=128, p=0.0, BiasT=Tensor, BiasGrad=False    |           3.6         |     3.9
      b16 B=8, M=512, H=64, K=128, p=0.0, BiasT=Tensor, BiasGrad=True     |           3.6         |     3.9
      b16 B=8, M=512, H=64, K=128, p=0.3, BiasT=Tensor, BiasGrad=False    |           4.5         |     4.3
      b16 B=8, M=512, H=64, K=128, p=0.3, BiasT=Tensor, BiasGrad=True     |           4.5         |     4.3
      b16 B=8, M=1024, H=64, K=64, p=0.0, BiasT=Tensor, BiasGrad=False    |           7.6         |    10.0
      b16 B=8, M=1024, H=64, K=64, p=0.0, BiasT=Tensor, BiasGrad=True     |           7.6         |    10.0
      b16 B=8, M=1024, H=64, K=64, p=0.3, BiasT=Tensor, BiasGrad=False    |          11.7         |    11.9
      b16 B=8, M=1024, H=64, K=64, p=0.3, BiasT=Tensor, BiasGrad=True     |          11.7         |    11.9
      b16 B=8, M=1024, H=64, K=128, p=0.0, BiasT=Tensor, BiasGrad=False   |          13.1         |    12.0
      b16 B=8, M=1024, H=64, K=128, p=0.0, BiasT=Tensor, BiasGrad=True    |          13.1         |    12.0
      b16 B=8, M=1024, H=64, K=128, p=0.3, BiasT=Tensor, BiasGrad=False   |          16.4         |    13.9
      b16 B=8, M=1024, H=64, K=128, p=0.3, BiasT=Tensor, BiasGrad=True    |          16.4         |    13.9
      b16 B=16, M=512, H=64, K=64, p=0.0, BiasT=Tensor, BiasGrad=False    |           3.9         |     5.7
      b16 B=16, M=512, H=64, K=64, p=0.0, BiasT=Tensor, BiasGrad=True     |           4.0         |     5.7
      b16 B=16, M=512, H=64, K=64, p=0.3, BiasT=Tensor, BiasGrad=False    |           6.0         |     6.7
      b16 B=16, M=512, H=64, K=64, p=0.3, BiasT=Tensor, BiasGrad=True     |           6.0         |     6.7
      b16 B=16, M=512, H=64, K=128, p=0.0, BiasT=Tensor, BiasGrad=False   |           7.3         |     7.7
      b16 B=16, M=512, H=64, K=128, p=0.0, BiasT=Tensor, BiasGrad=True    |           7.3         |     7.7
      b16 B=16, M=512, H=64, K=128, p=0.3, BiasT=Tensor, BiasGrad=False   |           8.9         |     8.7
      b16 B=16, M=512, H=64, K=128, p=0.3, BiasT=Tensor, BiasGrad=True    |           8.9         |     8.7
      b16 B=16, M=1024, H=64, K=64, p=0.0, BiasT=Tensor, BiasGrad=False   |          14.8         |    20.1
      b16 B=16, M=1024, H=64, K=64, p=0.0, BiasT=Tensor, BiasGrad=True    |          14.7         |    20.0
      b16 B=16, M=1024, H=64, K=64, p=0.3, BiasT=Tensor, BiasGrad=False   |          23.1         |    23.9
      b16 B=16, M=1024, H=64, K=64, p=0.3, BiasT=Tensor, BiasGrad=True    |          23.1         |    23.9
      b16 B=16, M=1024, H=64, K=128, p=0.0, BiasT=Tensor, BiasGrad=False  |          25.9         |    24.0
      b16 B=16, M=1024, H=64, K=128, p=0.0, BiasT=Tensor, BiasGrad=True   |          25.9         |    24.0
      b16 B=16, M=1024, H=64, K=128, p=0.3, BiasT=Tensor, BiasGrad=False  |          32.4         |    27.8
      b16 B=16, M=1024, H=64, K=128, p=0.3, BiasT=Tensor, BiasGrad=True   |          32.4         |    27.8

Times are in milliseconds (ms).

Contributor Author

just curious, what are the plans to cut shmem usage?

Contributor

Which GPU did you run these benchmarks on? It's ~20-50% slower depending on the cases - might be worth investigating later if you want.

just curious, what are the plans to cut shmem usage?

We are loading Q, K, dO from global memory twice. We could reuse the shared memory to avoid that.
Some illustration of the BW pass:
(1) Q@K, dO@V, gradV matmuls [diagram]
(2) gradQ, gradK [diagram]
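
For orientation, a plain PyTorch sketch of these matmuls (ignoring the attention scaling factor and the log-sum-exp bookkeeping the real kernel uses; names are illustrative) - each of Q, K and dO appears in two of the products, which is what the re-loading refers to:

import torch

def attention_bwd_matmuls(Q, K, V, dO):
    # pass (1): Q@K, dO@V and gradV
    S = Q @ K.transpose(-2, -1)                    # Q@K
    P = S.softmax(dim=-1)
    dP = dO @ V.transpose(-2, -1)                  # dO@V
    gradV = P.transpose(-2, -1) @ dO               # gradV
    # softmax backward: dS = P * (dP - rowsum(dP * P))
    dS = P * (dP - (dP * P).sum(dim=-1, keepdim=True))
    # pass (2): gradQ and gradK reuse K and Q
    gradQ = dS @ K
    gradK = dS.transpose(-2, -1) @ Q
    return gradQ, gradK, gradV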

Contributor

@danthe3rd danthe3rd Dec 21, 2022

A100 Shmem occupancy - current situation (max ~160kb)

[shared-memory occupancy chart]

A100 Shmem occupancy - after potential changes (max ~131kb)

(1) setting kNumStages=4 so the entire block of matrices fits in shared memory
(2) loading dO/K from the same shared-memory location (rather than re-loading them from global memory)
I also expect this to be faster

[shared-memory occupancy chart]

I don't have immediate plans to work on this, but let me know if you would like to contribute :)

Contributor

Just saw your question in NVIDIA/cutlass#744, I believe it's related to this :) Let's do that as part of a different PR if possible to make things easier to review.
Also, as a heads-up for this PR: I won't be available next week, but hopefully we can get this merged early January!

Contributor Author

yes it was, thanks for the hints there! yes, we can definitely put it in another PR - just wanted to start experimenting.

and sounds good, thanks for the quick reviews so far, happy holidays!

adds attn bias (including bias grad) and dropout support to CUTLASS
flashattn implementation

[-------------------------------------------- attn --------------------------------------------]
                                                                        |  reference  |  cutlass
1 threads: -------------------------------------------------------------------------------------
      (8, 512, 64, 128, torch.float16, (512, 512, 512), False, 0.0)     |     12.7    |     7.5
      (8, 512, 64, 128, torch.float16, (512, 512, 512), False, 0.5)     |     15.5    |     9.1
      (8, 512, 64, 128, torch.float16, (512, 512, 512), True, 0.0)      |     12.7    |     7.6
      (8, 512, 64, 128, torch.float16, (512, 512, 512), True, 0.5)      |     15.6    |     9.1
      (8, 512, 64, 128, torch.float16, None, False, 0.0)                |     10.1    |     6.0
      (8, 512, 64, 128, torch.float16, None, False, 0.5)                |     12.7    |     7.5
      (8, 1024, 64, 128, torch.float16, (512, 1024, 1024), False, 0.0)  |     44.3    |    29.1
      (8, 1024, 64, 128, torch.float16, (512, 1024, 1024), False, 0.5)  |     55.0    |    35.1
      (8, 1024, 64, 128, torch.float16, (512, 1024, 1024), True, 0.0)   |     45.1    |    29.4
      (8, 1024, 64, 128, torch.float16, (512, 1024, 1024), True, 0.5)   |     55.6    |    35.3
      (8, 1024, 64, 128, torch.float16, None, False, 0.0)               |     37.0    |    22.6
      (8, 1024, 64, 128, torch.float16, None, False, 0.5)               |     46.8    |    29.0

Times are in milliseconds (ms).

[------------------------------------------ attn-bwd ------------------------------------------]
                                                                        |  reference  |  cutlass
1 threads: -------------------------------------------------------------------------------------
      (8, 512, 64, 128, torch.float16, (512, 512, 512), False, 0.0)     |     19.3    |    24.1
      (8, 512, 64, 128, torch.float16, (512, 512, 512), False, 0.5)     |     19.4    |    24.6
      (8, 512, 64, 128, torch.float16, (512, 512, 512), True, 0.0)      |     22.3    |    28.7
      (8, 512, 64, 128, torch.float16, (512, 512, 512), True, 0.5)      |     22.4    |    29.0
      (8, 512, 64, 128, torch.float16, None, False, 0.0)                |     19.5    |    22.7
      (8, 512, 64, 128, torch.float16, None, False, 0.5)                |     19.5    |    23.4
      (8, 1024, 64, 128, torch.float16, (512, 1024, 1024), False, 0.0)  |     62.7    |    91.1
      (8, 1024, 64, 128, torch.float16, (512, 1024, 1024), False, 0.5)  |     63.4    |    93.7
      (8, 1024, 64, 128, torch.float16, (512, 1024, 1024), True, 0.0)   |     74.8    |   109.8
      (8, 1024, 64, 128, torch.float16, (512, 1024, 1024), True, 0.5)   |     75.1    |   111.1
      (8, 1024, 64, 128, torch.float16, None, False, 0.0)               |     63.2    |    85.5
      (8, 1024, 64, 128, torch.float16, None, False, 0.5)               |     64.0    |    90.1
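
For reference, a minimal sketch of how the new functionality is exercised through the public API (mirroring the memory_efficient_attention call in the tests above; the 3-d BMK layout, the shapes, and the dropout probability here are illustrative, not the benchmark configs):

import torch
import xformers.ops

B, M, K = 8 * 64, 512, 128  # batch and heads flattened into B (BMK layout)
q = torch.randn([B, M, K], device="cuda", dtype=torch.float16, requires_grad=True)
k = torch.randn([B, M, K], device="cuda", dtype=torch.float16, requires_grad=True)
v = torch.randn([B, M, K], device="cuda", dtype=torch.float16, requires_grad=True)
# tensor bias added to the attention scores before the softmax;
# requires_grad=True exercises the new bias-gradient path in the backward
bias = torch.randn([B, M, M], device="cuda", dtype=torch.float16, requires_grad=True)

out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=bias, p=0.5)
out.sum().backward()
assert bias.grad is not None  # dBias is produced alongside dQ/dK/dV
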
@jfc4050
Contributor Author

jfc4050 commented Dec 21, 2022

rebased and addressed the format/lint issues. hoping CI will pass now 🙏 (minus the Windows build, will look at that next)

BEFORE

[------------------------------------------ attn-bwd ------------------------------------------]
                                                                        |  reference  |  cutlass
1 threads: -------------------------------------------------------------------------------------
      (8, 512, 64, 64, torch.float16, (512, 512, 512), False, 0.0)      |      2.8    |     2.4
      (8, 512, 64, 64, torch.float16, (512, 512, 512), False, 0.5)      |      2.8    |     3.3
      (8, 512, 64, 64, torch.float16, (512, 512, 512), True, 0.0)       |      3.4    |     3.2
      (8, 512, 64, 64, torch.float16, (512, 512, 512), True, 0.5)       |      3.4    |     4.2
      (8, 512, 64, 64, torch.float16, None, False, 0.0)                 |      2.8    |     2.0
      (8, 512, 64, 64, torch.float16, None, False, 0.5)                 |      2.8    |     2.9
      (8, 512, 64, 128, torch.float16, (512, 512, 512), False, 0.0)     |      3.6    |     3.9
      (8, 512, 64, 128, torch.float16, (512, 512, 512), False, 0.5)     |      3.6    |     4.8
      (8, 512, 64, 128, torch.float16, (512, 512, 512), True, 0.0)      |      4.2    |     4.8
      (8, 512, 64, 128, torch.float16, (512, 512, 512), True, 0.5)      |      4.2    |     5.6
      (8, 512, 64, 128, torch.float16, None, False, 0.0)                |      3.6    |     3.4
      (8, 512, 64, 128, torch.float16, None, False, 0.5)                |      3.6    |     4.4
      (8, 1024, 64, 64, torch.float16, (512, 1024, 1024), False, 0.0)   |      9.7    |     8.8
      (8, 1024, 64, 64, torch.float16, (512, 1024, 1024), False, 0.5)   |      9.7    |    12.6
      (8, 1024, 64, 64, torch.float16, (512, 1024, 1024), True, 0.0)    |     12.0    |    12.1
      (8, 1024, 64, 64, torch.float16, (512, 1024, 1024), True, 0.5)    |     12.1    |    16.1
      (8, 1024, 64, 64, torch.float16, None, False, 0.0)                |      9.7    |     7.4
      (8, 1024, 64, 64, torch.float16, None, False, 0.5)                |      9.7    |    10.8
      (8, 1024, 64, 128, torch.float16, (512, 1024, 1024), False, 0.0)  |     11.3    |    14.0
      (8, 1024, 64, 128, torch.float16, (512, 1024, 1024), False, 0.5)  |     11.3    |    17.4
      (8, 1024, 64, 128, torch.float16, (512, 1024, 1024), True, 0.0)   |     13.6    |    17.8
      (8, 1024, 64, 128, torch.float16, (512, 1024, 1024), True, 0.5)   |     13.6    |    20.9
      (8, 1024, 64, 128, torch.float16, None, False, 0.0)               |     11.3    |    12.1
      (8, 1024, 64, 128, torch.float16, None, False, 0.5)               |     11.3    |    15.8

AFTER

[------------------------------------------ attn-bwd ------------------------------------------]
                                                                        |  reference  |  cutlass
1 threads: -------------------------------------------------------------------------------------
      (8, 512, 64, 64, torch.float16, (512, 512, 512), False, 0.0)      |      2.8    |     2.4
      (8, 512, 64, 64, torch.float16, (512, 512, 512), False, 0.5)      |      2.8    |     3.0
      (8, 512, 64, 64, torch.float16, (512, 512, 512), True, 0.0)       |      3.4    |     3.2
      (8, 512, 64, 64, torch.float16, (512, 512, 512), True, 0.5)       |      3.4    |     3.8
      (8, 512, 64, 64, torch.float16, None, False, 0.0)                 |      2.8    |     2.0
      (8, 512, 64, 64, torch.float16, None, False, 0.5)                 |      2.8    |     2.6
      (8, 512, 64, 128, torch.float16, (512, 512, 512), False, 0.0)     |      3.6    |     3.9
      (8, 512, 64, 128, torch.float16, (512, 512, 512), False, 0.5)     |      3.6    |     4.8
      (8, 512, 64, 128, torch.float16, (512, 512, 512), True, 0.0)      |      4.2    |     4.8
      (8, 512, 64, 128, torch.float16, (512, 512, 512), True, 0.5)      |      4.2    |     5.6
      (8, 512, 64, 128, torch.float16, None, False, 0.0)                |      3.6    |     3.4
      (8, 512, 64, 128, torch.float16, None, False, 0.5)                |      3.6    |     4.4
      (8, 1024, 64, 64, torch.float16, (512, 1024, 1024), False, 0.0)   |      9.7    |     8.8
      (8, 1024, 64, 64, torch.float16, (512, 1024, 1024), False, 0.5)   |      9.7    |    11.4
      (8, 1024, 64, 64, torch.float16, (512, 1024, 1024), True, 0.0)    |     12.0    |    12.1
      (8, 1024, 64, 64, torch.float16, (512, 1024, 1024), True, 0.5)    |     12.1    |    14.6
      (8, 1024, 64, 64, torch.float16, None, False, 0.0)                |      9.7    |     7.4
      (8, 1024, 64, 64, torch.float16, None, False, 0.5)                |      9.7    |     9.6
      (8, 1024, 64, 128, torch.float16, (512, 1024, 1024), False, 0.0)  |     11.3    |    14.1
      (8, 1024, 64, 128, torch.float16, (512, 1024, 1024), False, 0.5)  |     11.3    |    17.4
      (8, 1024, 64, 128, torch.float16, (512, 1024, 1024), True, 0.0)   |     13.6    |    17.8
      (8, 1024, 64, 128, torch.float16, (512, 1024, 1024), True, 0.5)   |     13.6    |    20.9
      (8, 1024, 64, 128, torch.float16, None, False, 0.0)               |     11.3    |    12.1
      (8, 1024, 64, 128, torch.float16, None, False, 0.5)               |     11.3    |    15.8
@jfc4050
Contributor Author

jfc4050 commented Dec 30, 2022

fixed the Windows build, should be good to rerun the tests when you get back.

i'm prototyping some of the shared memory changes you suggested, but can't guarantee i'll end up finishing it and opening a PR.

@jfc4050
Contributor Author

jfc4050 commented Dec 30, 2022

not sure what to do about the "Could not find a usable config.yml, you may have revoked the CircleCI OAuth app." error.

cpu_tests_py38 seems to be failing because of a formatting issue in an unrelated file. i can fix it anyway.

edit: seems like the #include order in the swiglu file was intentional - the build fails after allowing the formatter to reorder them

@danthe3rd
Contributor

The CI looks good to me! Thanks for fixing the Windows build :) will merge early next week once I re-run the benchmarks on A100/V100 to ensure recent changes didn't bring any regression.
I'll also send you a Slack invite when I come back so we can communicate more directly - in case you want to contribute more, and also to learn more about your use case :)

@jfc4050
Contributor Author

jfc4050 commented Dec 30, 2022

awesome! and that sounds good to me. in general we are using this for distributed training of large language models with ZeRO-style data parallelism. in PyTorch, running close to the GPU memory limit really messes with the performance of collectives because of the way the caching allocator works. we can definitely discuss more over Slack if you are interested in the details.

@danthe3rd
Contributor

danthe3rd commented Jan 3, 2023

Perf measurements look good. I'm just a bit worried about the Stable Diffusion case on V100 (f16 B=1, M=16384, H=16, K=40) - do you have any idea what might cause that?
cc @fmassa what do you think about the perf?

A100

A100 fw
[----------------- attention (attn_bias=<class 'NoneType'>) -----------------]                                                                                                                                                                                                                               
                                     |  pr587_d1b0fa  |    main    |   eager  
1 threads: -------------------------------------------------------------------
      f16 B=384, M=197, H=1, K=88    |       124.8    |     121.3  |     851.8
      f32 B=384, M=197, H=1, K=88    |       463.3    |     447.4  |     719.5
      f16 B=384, M=197, H=1, K=80    |       116.1    |     112.2  |     742.2
      f32 B=384, M=197, H=1, K=80    |       459.3    |     443.5  |     692.4
      f16 B=384, M=197, H=1, K=64    |        87.9    |      86.7  |     679.8
      f32 B=384, M=197, H=1, K=64    |       281.3    |     268.7  |     640.5
      f16 B=1024, M=197, H=1, K=88   |       315.1    |     306.6  |    2221.3
      f32 B=1024, M=197, H=1, K=88   |      1220.9    |    1168.2  |    1765.2
      f16 B=1024, M=197, H=1, K=80   |       294.5    |     285.5  |    1927.2
      f32 B=1024, M=197, H=1, K=80   |      1212.6    |    1160.4  |    1700.6
      f16 B=1024, M=197, H=1, K=64   |       212.8    |     208.7  |    1770.5
      f32 B=1024, M=197, H=1, K=64   |       689.1    |     661.7  |    1568.6
      f16 B=512, M=197, H=1, K=80    |       153.3    |     148.2  |     979.2
      f32 B=512, M=197, H=1, K=80    |       615.0    |     592.4  |     881.7
      f16 B=32, M=197, H=16, K=80    |       154.2    |     149.2  |    1064.0
      f32 B=32, M=197, H=16, K=80    |       617.9    |     594.5  |    1051.0
      f16 B=32, M=197, H=16, K=64    |       113.5    |     111.7  |     979.9
      f32 B=32, M=197, H=16, K=64    |       354.9    |     341.4  |     949.7
      f16 B=32, M=197, H=16, K=128   |       168.1    |     162.5  |    1717.6
      f32 B=32, M=197, H=16, K=128   |       684.3    |     661.8  |    1324.7
      f16 B=256, M=197, H=1, K=88    |        87.6    |      85.0  |     577.0
      f32 B=256, M=197, H=1, K=88    |       318.8    |     306.9  |     477.8
      f16 B=16, M=197, H=16, K=88    |        88.9    |      86.2  |     628.7
      f32 B=16, M=197, H=16, K=88    |       321.1    |     308.4  |     572.9
      f16 B=16, M=197, H=16, K=64    |        66.4    |      61.1  |     506.9
      f32 B=16, M=197, H=16, K=64    |       196.3    |     186.9  |     494.9
      f16 B=16, M=197, H=16, K=128   |        89.8    |      87.0  |     877.3
      f32 B=16, M=197, H=16, K=128   |       352.9    |     339.6  |     692.8
      f16 B=1, M=4096, H=160, K=128  |     15245.8    |   14853.7  |   21510.6
      f32 B=1, M=4096, H=160, K=128  |     57963.3    |   56691.3  |   91407.5
      f16 B=2, M=4096, H=160, K=128  |     30427.9    |   29633.5  |   43562.2
      f32 B=2, M=4096, H=160, K=128  |    115784.8    |  113237.0  |          
      f16 B=1, M=8192, H=160, K=128  |     60820.3    |   59240.6  |          
      f32 B=1, M=8192, H=160, K=128  |    232806.2    |  226792.5  |          
      f16 B=2, M=8192, H=160, K=128  |    121599.1    |  118414.6  |          
      f32 B=2, M=8192, H=160, K=128  |    466094.3    |  453321.1  |          
      f16 B=1024, M=82, H=8, K=64    |       475.4    |     446.3  |    1785.2
      f32 B=1024, M=82, H=8, K=64    |      1428.2    |    1314.3  |    3753.8
      f16 B=150, M=256, H=16, K=64   |       513.2    |     503.1  |    1964.8
      f32 B=150, M=256, H=16, K=64   |      1690.7    |    1602.7  |    5294.3
      f16 B=64, M=256, H=12, K=64    |       171.8    |     169.8  |     664.2
      f32 B=64, M=256, H=12, K=64    |       559.4    |     533.0  |    1760.4
      f16 B=1, M=4096, H=16, K=40    |       857.0    |     870.6  |    1958.3
      f32 B=1, M=4096, H=16, K=40    |      2866.1    |    2765.6  |    6914.2
      f16 B=1, M=16384, H=16, K=40   |     12159.2    |   12346.3  |   30460.6
      f32 B=1, M=16384, H=16, K=40   |     41947.4    |   40533.2  |  123282.2
      f16 B=256, M=4096, H=16, K=64  |    181204.2    |  183488.9  |          
      f32 B=256, M=4096, H=16, K=64  |    665520.1    |  642960.3  |          
      f16 B=16, M=128, H=16, K=16    |        59.4    |      60.4  |     155.3
      f32 B=16, M=128, H=16, K=16    |        60.0    |      60.7  |     151.5
      f16 B=16, M=128, H=16, K=32    |        59.4    |      60.7  |     154.5
      f32 B=16, M=128, H=16, K=32    |        60.5    |      60.8  |     175.9
      f16 B=16, M=128, H=16, K=64    |        59.3    |      60.6  |     155.4
      f32 B=16, M=128, H=16, K=64    |        66.0    |      62.2  |     217.8
      f16 B=16, M=128, H=16, K=128   |        59.4    |      60.2  |     155.4
      f32 B=16, M=128, H=16, K=128   |       124.3    |     117.2  |     311.9
      f16 B=16, M=128, H=16, K=256   |        78.3    |      76.5  |     253.0
      f32 B=16, M=128, H=16, K=256   |       224.7    |     210.4  |     534.7
      f16 B=16, M=512, H=16, K=16    |       174.8    |     173.5  |     516.1
      f32 B=16, M=512, H=16, K=16    |       576.9    |     553.6  |    1629.1
      f16 B=16, M=512, H=16, K=32    |       183.1    |     180.5  |     564.7
      f32 B=16, M=512, H=16, K=32    |       586.1    |     562.4  |    1781.3
      f16 B=16, M=512, H=16, K=64    |       209.7    |     207.9  |     673.7
      f32 B=16, M=512, H=16, K=64    |       708.4    |     678.4  |    2107.5
      f16 B=16, M=512, H=16, K=128   |       364.5    |     353.8  |     856.0
      f32 B=16, M=512, H=16, K=128   |      1540.1    |    1493.4  |    2760.0
      f16 B=16, M=512, H=16, K=256   |       817.7    |     812.9  |    1230.8
      f32 B=16, M=512, H=16, K=256   |      3124.4    |    2984.7  |    4922.2
      f16 B=16, M=1024, H=16, K=16   |       667.4    |     662.9  |    1857.2
      f32 B=16, M=1024, H=16, K=16   |      2210.5    |    2130.1  |    6088.4
      f16 B=16, M=1024, H=16, K=32   |       673.7    |     669.0  |    1951.1
      f32 B=16, M=1024, H=16, K=32   |      2234.2    |    2150.0  |    6591.0
      f16 B=16, M=1024, H=16, K=64   |       762.5    |     766.7  |    2192.7
      f32 B=16, M=1024, H=16, K=64   |      2701.8    |    2597.3  |    7687.2
      f16 B=16, M=1024, H=16, K=128  |      1345.5    |    1311.5  |    2618.4
      f32 B=16, M=1024, H=16, K=128  |      5922.5    |    5772.7  |    9881.6
      f16 B=16, M=1024, H=16, K=256  |      3132.5    |    3096.7  |    3436.4
      f32 B=16, M=1024, H=16, K=256  |     12166.5    |   11659.9  |   17896.5
      f16 B=64, M=128, H=16, K=16    |        65.8    |      61.0  |     203.5
      f32 B=64, M=128, H=16, K=16    |       166.7    |     157.7  |     489.0
      f16 B=64, M=128, H=16, K=32    |        66.0    |      61.2  |     250.3
      f32 B=64, M=128, H=16, K=32    |       174.1    |     165.0  |     574.8
      f16 B=64, M=128, H=16, K=64    |        75.0    |      73.2  |     349.0
      f32 B=64, M=128, H=16, K=64    |       215.2    |     203.6  |     737.7
      f16 B=64, M=128, H=16, K=128   |       130.9    |     129.1  |     534.2
      f32 B=64, M=128, H=16, K=128   |       449.9    |     427.7  |    1070.0
      f16 B=64, M=128, H=16, K=256   |       258.2    |     253.4  |     914.4
      f32 B=64, M=128, H=16, K=256   |       833.4    |     791.8  |    1921.3
      f16 B=64, M=512, H=16, K=16    |       683.8    |     676.5  |    1893.0
      f32 B=64, M=512, H=16, K=16    |      2247.4    |    2161.4  |    6273.1
      f16 B=64, M=512, H=16, K=32    |       693.3    |     684.8  |    2096.1
      f32 B=64, M=512, H=16, K=32    |      2280.4    |    2190.8  |    6822.3
      f16 B=64, M=512, H=16, K=64    |       794.9    |     787.1  |    2495.9
      f32 B=64, M=512, H=16, K=64    |      2748.7    |    2635.4  |    8108.3
      f16 B=64, M=512, H=16, K=128   |      1418.1    |    1380.4  |    3264.1
      f32 B=64, M=512, H=16, K=128   |      6082.3    |    5895.7  |   10731.7
      f16 B=64, M=512, H=16, K=256   |      3227.0    |    3181.7  |    4760.3
      f32 B=64, M=512, H=16, K=256   |     12406.5    |   11825.6  |   19614.2
      f16 B=64, M=1024, H=16, K=16   |      2611.3    |    2593.0  |    7293.8
      f32 B=64, M=1024, H=16, K=16   |      8722.5    |    8424.4  |   24237.0
      f16 B=64, M=1024, H=16, K=32   |      2634.5    |    2612.0  |    7675.3
      f32 B=64, M=1024, H=16, K=32   |      8811.5    |    8494.5  |   26228.7
      f16 B=64, M=1024, H=16, K=64   |      2993.2    |    2991.8  |    8669.3
      f32 B=64, M=1024, H=16, K=64   |     10667.6    |   10271.7  |   30612.5
      f16 B=64, M=1024, H=16, K=128  |      5398.9    |    5266.6  |   10313.9
      f32 B=64, M=1024, H=16, K=128  |     23517.9    |   22940.7  |   39341.6
      f16 B=64, M=1024, H=16, K=256  |     12272.2    |   12354.8  |   13558.9
      f32 B=64, M=1024, H=16, K=256  |     48531.7    |   46488.2  |   71460.8

Times are in microseconds (us).

[ attention (attn_bias=<class 'xformers.ops.fmha.common.LowerTriangularMask'>) ]
                                     |  pr587_d1b0fa  |    main    |   eager  
1 threads: -------------------------------------------------------------------
      f16 B=384, M=197, H=1, K=88    |        94.9    |      92.0  |     918.3
      f32 B=384, M=197, H=1, K=88    |       336.1    |     321.6  |     779.4
      f16 B=384, M=197, H=1, K=80    |        90.1    |      86.8  |     812.3
      f32 B=384, M=197, H=1, K=80    |       333.0    |     318.7  |     752.1
      f16 B=384, M=197, H=1, K=64    |        66.6    |      65.3  |     749.9
      f32 B=384, M=197, H=1, K=64    |       202.6    |     191.7  |     705.0
      f16 B=1024, M=197, H=1, K=88   |       230.8    |     223.7  |    2389.0
      f32 B=1024, M=197, H=1, K=88   |       854.2    |     818.3  |    1907.8
      f16 B=1024, M=197, H=1, K=80   |       218.9    |     211.0  |    2105.1
      f32 B=1024, M=197, H=1, K=80   |       847.1    |     810.7  |    1843.8
      f16 B=1024, M=197, H=1, K=64   |       153.0    |     149.8  |    1945.4
      f32 B=1024, M=197, H=1, K=64   |       484.8    |     457.8  |    1719.9
      f16 B=512, M=197, H=1, K=80    |       115.9    |     111.6  |    1071.6
      f32 B=512, M=197, H=1, K=80    |       435.8    |     417.0  |     961.4
      f16 B=32, M=197, H=16, K=80    |       116.6    |     112.3  |    1154.2
      f32 B=32, M=197, H=16, K=80    |       437.3    |     418.4  |    1131.0
      f16 B=32, M=197, H=16, K=64    |        85.2    |      83.3  |    1068.3
      f32 B=32, M=197, H=16, K=64    |       260.0    |     245.5  |    1027.6
      f16 B=32, M=197, H=16, K=128   |       127.1    |     123.0  |    1802.6
      f32 B=32, M=197, H=16, K=128   |       490.0    |     467.4  |    1403.9
      f16 B=256, M=197, H=1, K=88    |        67.5    |      65.3  |     622.5
      f32 B=256, M=197, H=1, K=88    |       232.3    |     222.5  |     523.1
      f16 B=16, M=197, H=16, K=88    |        67.7    |      65.6  |     674.0
      f32 B=16, M=197, H=16, K=88    |       233.1    |     223.4  |     618.4
      f16 B=16, M=197, H=16, K=64    |        66.1    |      60.7  |     552.6
      f32 B=16, M=197, H=16, K=64    |       146.9    |     139.1  |     542.7
      f16 B=16, M=197, H=16, K=128   |        70.7    |      68.2  |     923.9
      f32 B=16, M=197, H=16, K=128   |       258.9    |     246.7  |     737.2
      f16 B=1, M=4096, H=160, K=128  |      7807.9    |    7595.2  |   38531.2
      f32 B=1, M=4096, H=160, K=128  |     29992.9    |   29331.8  |  109017.4
      f16 B=2, M=4096, H=160, K=128  |     15494.7    |   15094.0  |   78631.6
      f32 B=2, M=4096, H=160, K=128  |     59693.4    |   58316.8  |          
      f16 B=1, M=8192, H=160, K=128  |     30770.7    |   29983.0  |          
      f32 B=1, M=8192, H=160, K=128  |    117905.3    |  115038.2  |          
      f16 B=2, M=8192, H=160, K=128  |     61289.6    |   59734.6  |          
      f32 B=2, M=8192, H=160, K=128  |    234813.8    |  229411.1  |          
      f16 B=1024, M=82, H=8, K=64    |       381.6    |     370.1  |    1997.2
      f32 B=1024, M=82, H=8, K=64    |      1144.6    |    1063.2  |    3984.2
      f16 B=150, M=256, H=16, K=64   |       360.1    |     352.6  |    2604.2
      f32 B=150, M=256, H=16, K=64   |      1140.1    |    1071.0  |    5906.0
      f16 B=64, M=256, H=12, K=64    |       124.8    |     122.8  |     873.6
      f32 B=64, M=256, H=12, K=64    |       387.9    |     365.7  |    1963.7
      f16 B=1, M=4096, H=16, K=40    |       528.2    |     538.7  |    3768.9
      f32 B=1, M=4096, H=16, K=40    |      1706.0    |    1652.4  |    8888.6
      f16 B=1, M=16384, H=16, K=40   |      6541.9    |    6637.7  |   59019.0
      f32 B=1, M=16384, H=16, K=40   |     22448.0    |   21682.4  |          
      f16 B=256, M=4096, H=16, K=64  |     94016.8    |   94157.2  |          
      f32 B=256, M=4096, H=16, K=64  |    340504.8    |  328323.8  |          
      f16 B=16, M=128, H=16, K=16    |        59.2    |      60.9  |     161.1
      f32 B=16, M=128, H=16, K=16    |        60.1    |      61.0  |     171.2
      f16 B=16, M=128, H=16, K=32    |        58.9    |      60.7  |     160.9
      f32 B=16, M=128, H=16, K=32    |        60.6    |      60.6  |     192.6
      f16 B=16, M=128, H=16, K=64    |        59.1    |      60.8  |     161.7
      f32 B=16, M=128, H=16, K=64    |        63.6    |      61.0  |     244.1
      f16 B=16, M=128, H=16, K=128   |        59.3    |      60.4  |     166.6
      f32 B=16, M=128, H=16, K=128   |       114.2    |     108.0  |     341.2
      f16 B=16, M=128, H=16, K=256   |        69.5    |      68.1  |     283.8
      f32 B=16, M=128, H=16, K=256   |       201.7    |     190.6  |     562.6
      f16 B=16, M=512, H=16, K=16    |       117.4    |     115.7  |     824.6
      f32 B=16, M=512, H=16, K=16    |       368.9    |     352.6  |    2052.2
      f16 B=16, M=512, H=16, K=32    |       122.4    |     120.0  |     869.4
      f32 B=16, M=512, H=16, K=32    |       377.7    |     360.1  |    2129.5
      f16 B=16, M=512, H=16, K=64    |       140.3    |     139.4  |     961.3
      f32 B=16, M=512, H=16, K=64    |       454.6    |     433.3  |    2362.8
      f16 B=16, M=512, H=16, K=128   |       238.1    |     230.7  |    1131.0
      f32 B=16, M=512, H=16, K=128   |       971.3    |     936.1  |    2994.9
      f16 B=16, M=512, H=16, K=256   |       513.7    |     509.1  |    1503.4
      f32 B=16, M=512, H=16, K=256   |      1970.8    |    1875.0  |    5159.1
      f16 B=16, M=1024, H=16, K=16   |       390.0    |     385.9  |    3008.8
      f32 B=16, M=1024, H=16, K=16   |      1259.6    |    1211.4  |    8011.4
      f16 B=16, M=1024, H=16, K=32   |       394.5    |     390.2  |    3109.4
      f32 B=16, M=1024, H=16, K=32   |      1276.9    |    1224.9  |    8143.3
      f16 B=16, M=1024, H=16, K=64   |       447.4    |     448.2  |    3304.5
      f32 B=16, M=1024, H=16, K=64   |      1540.8    |    1477.3  |    8691.8
      f16 B=16, M=1024, H=16, K=128  |       790.5    |     765.3  |    3685.4
      f32 B=16, M=1024, H=16, K=128  |      3354.8    |    3257.0  |   10811.9
      f16 B=16, M=1024, H=16, K=256  |      1774.3    |    1749.4  |    4464.5
      f32 B=16, M=1024, H=16, K=256  |      6910.8    |    6593.8  |   18779.5
      f16 B=64, M=128, H=16, K=16    |        66.3    |      60.8  |     277.1
      f32 B=64, M=128, H=16, K=16    |       139.0    |     130.5  |     600.0
      f16 B=64, M=128, H=16, K=32    |        65.4    |      60.9  |     332.3
      f32 B=64, M=128, H=16, K=32    |       146.7    |     138.2  |     672.4
      f16 B=64, M=128, H=16, K=64    |        67.6    |      66.1  |     426.4
      f32 B=64, M=128, H=16, K=64    |       179.4    |     169.3  |     819.9
      f16 B=64, M=128, H=16, K=128   |       117.3    |     114.5  |     607.5
      f32 B=64, M=128, H=16, K=128   |       405.0    |     383.1  |    1139.5
      f16 B=64, M=128, H=16, K=256   |       230.7    |     228.8  |     988.8
      f32 B=64, M=128, H=16, K=256   |       733.5    |     695.7  |    1991.6
      f16 B=64, M=512, H=16, K=16    |       419.6    |     412.5  |    3068.8
      f32 B=64, M=512, H=16, K=16    |      1332.7    |    1274.7  |    7904.7
      f16 B=64, M=512, H=16, K=32    |       426.8    |     419.1  |    3256.0
      f32 B=64, M=512, H=16, K=32    |      1360.6    |    1297.8  |    8190.3
      f16 B=64, M=512, H=16, K=64    |       492.7    |     485.2  |    3628.3
      f32 B=64, M=512, H=16, K=64    |      1639.1    |    1560.6  |    9073.3
      f16 B=64, M=512, H=16, K=128   |       908.9    |     873.2  |    4331.2
      f32 B=64, M=512, H=16, K=128   |      3735.3    |    3597.2  |   11582.6
      f16 B=64, M=512, H=16, K=256   |      1967.6    |    1955.4  |    5832.9
      f32 B=64, M=512, H=16, K=256   |      7622.6    |    7274.0  |   20462.0
      f16 B=64, M=1024, H=16, K=16   |      1455.4    |    1441.0  |   11836.9
      f32 B=64, M=1024, H=16, K=16   |      4773.9    |    4594.4  |   31900.4
      f16 B=64, M=1024, H=16, K=32   |      1471.9    |    1454.8  |   12237.2
      f32 B=64, M=1024, H=16, K=32   |      4836.4    |    4644.8  |   32409.2
      f16 B=64, M=1024, H=16, K=64   |      1679.2    |    1668.8  |   13076.3
      f32 B=64, M=1024, H=16, K=64   |      5843.4    |    5601.8  |   34561.1
      f16 B=64, M=1024, H=16, K=128  |      3090.0    |    2981.2  |   14547.6
      f32 B=64, M=1024, H=16, K=128  |     13157.7    |   12773.4  |   42996.6
      f16 B=64, M=1024, H=16, K=256  |      7031.8    |    6942.8  |   17618.8
      f32 B=64, M=1024, H=16, K=256  |     27170.5    |   25945.4  |   74993.3

Times are in microseconds (us).
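
Before the backward tables, a quick illustration of the functionality being benchmarked. This is a hedged sketch only: the tensor-bias shape and its broadcasting over batch*heads are assumptions, and whether a plain tensor is accepted directly as `attn_bias` depends on the xformers version.

```python
# Hedged sketch of the feature added by this PR: a dense tensor bias
# with per-row values, plus dropout, through the CUTLASS kernels.
import torch
import xformers.ops as xops

B, M, H, K = 16, 512, 16, 64
q, k, v = [
    torch.randn(B * H, M, K, device="cuda", dtype=torch.float16)
    for _ in range(3)
]
# One bias value per (query, key) position for every batch*head slice;
# values are allowed to differ along rows (queries).
bias = torch.randn(B * H, M, M, device="cuda", dtype=torch.float16)

out = xops.memory_efficient_attention(q, k, v, attn_bias=bias, p=0.1)
```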
A100 bw
[------------- attention backward (attn_bias=<class 'NoneType'>) -------------]
                                     |  pr587_d1b0fa  |     main    |  vanilla 
1 threads: --------------------------------------------------------------------
      f16 B=384, M=197, H=1, K=88    |       717.6    |      651.2  |    2260.9
      f32 B=384, M=197, H=1, K=88    |      2362.6    |     2330.6  |    1841.9
      f16 B=384, M=197, H=1, K=80    |       686.1    |      621.9  |    1916.9
      f32 B=384, M=197, H=1, K=80    |      2265.1    |     2229.2  |    1785.9
      f16 B=384, M=197, H=1, K=64    |       422.9    |      459.5  |    1808.0
      f32 B=384, M=197, H=1, K=64    |      1280.9    |     1262.5  |    1673.1
      f16 B=1024, M=197, H=1, K=88   |      1812.5    |     1609.8  |    5941.3
      f32 B=1024, M=197, H=1, K=88   |      6124.5    |     6051.2  |    4553.4
      f16 B=1024, M=197, H=1, K=80   |      1728.0    |     1536.3  |    5022.0
      f32 B=1024, M=197, H=1, K=80   |      5863.8    |     5778.6  |    4405.0
      f16 B=1024, M=197, H=1, K=64   |       968.8    |     1037.3  |    4732.1
      f32 B=1024, M=197, H=1, K=64   |      3340.3    |     3295.5  |    4112.7
      f16 B=512, M=197, H=1, K=80    |       877.5    |      785.1  |    2533.9
      f32 B=512, M=197, H=1, K=80    |      2914.2    |     2857.5  |    2283.9
      f16 B=32, M=197, H=16, K=80    |       878.1    |      787.2  |    2568.0
      f32 B=32, M=197, H=16, K=80    |      2933.3    |     2834.4  |    2351.5
      f16 B=32, M=197, H=16, K=64    |       496.0    |      538.7  |    2430.3
      f32 B=32, M=197, H=16, K=64    |      1797.2    |     1777.6  |    2195.1
      f16 B=32, M=197, H=16, K=128   |      1035.3    |      928.2  |    4486.7
      f32 B=32, M=197, H=16, K=128   |      3594.9    |     3544.5  |    2803.4
      f16 B=256, M=197, H=1, K=88    |       515.2    |      477.7  |    1521.7
      f32 B=256, M=197, H=1, K=88    |      1700.9    |     1675.8  |    1206.7
      f16 B=16, M=197, H=16, K=88    |       514.8    |      473.4  |    1539.3
      f32 B=16, M=197, H=16, K=88    |      1689.0    |     1664.6  |    1249.0
      f16 B=16, M=197, H=16, K=64    |       253.2    |      276.0  |    1242.9
      f32 B=16, M=197, H=16, K=64    |      1070.2    |     1060.9  |    1124.1
      f16 B=16, M=197, H=16, K=128   |       575.4    |      526.5  |    2266.7
      f32 B=16, M=197, H=16, K=128   |      1960.4    |     1935.8  |    1444.7
      f16 B=1, M=4096, H=160, K=128  |     62997.4    |    67019.0  |   46384.8
      f32 B=1, M=4096, H=160, K=128  |    237718.0    |   222376.4  |          
      f16 B=2, M=4096, H=160, K=128  |    106238.3    |   110240.7  |          
      f32 B=2, M=4096, H=160, K=128  |    374967.4    |   351572.8  |          
      f16 B=1, M=8192, H=160, K=128  |    246711.2    |   267465.8  |          
      f32 B=1, M=8192, H=160, K=128  |    943042.3    |   881885.6  |          
      f16 B=2, M=8192, H=160, K=128  |    419967.4    |   433848.1  |          
      f32 B=2, M=8192, H=160, K=128  |   1490850.6    |  1398395.3  |          
      f16 B=1024, M=82, H=8, K=64    |      2009.6    |     2111.2  |    3823.5
      f32 B=1024, M=82, H=8, K=64    |      8479.3    |     8376.5  |    8720.2
      f16 B=150, M=256, H=16, K=64   |      2335.7    |     2537.9  |    4560.5
      f32 B=150, M=256, H=16, K=64   |      6255.7    |     6269.7  |   12921.7
      f16 B=64, M=256, H=12, K=64    |       789.1    |      875.9  |    1499.4
      f32 B=64, M=256, H=12, K=64    |      2149.2    |     2153.6  |    4260.9
      f16 B=1, M=4096, H=16, K=40    |     23998.0    |    25712.6  |    4235.3
      f32 B=1, M=4096, H=16, K=40    |     73606.0    |    73180.5  |   17706.1
      f16 B=1, M=16384, H=16, K=40   |    396513.0    |   430370.9  |          
      f32 B=1, M=16384, H=16, K=40   |   1195031.7    |  1187422.5  |          
      f16 B=256, M=4096, H=16, K=64  |    742711.8    |   801632.2  |          
      f16 B=16, M=128, H=16, K=16    |       242.7    |      189.3  |     306.7
      f32 B=16, M=128, H=16, K=16    |       291.9    |      231.0  |     373.0
      f16 B=16, M=128, H=16, K=32    |       244.8    |      182.4  |     302.1
      f32 B=16, M=128, H=16, K=32    |       292.2    |      226.6  |     413.2
      f16 B=16, M=128, H=16, K=64    |       243.5    |      182.9  |     301.7
      f32 B=16, M=128, H=16, K=64    |       285.9    |      273.2  |     499.2
      f16 B=16, M=128, H=16, K=128   |       241.5    |      209.6  |     304.6
      f32 B=16, M=128, H=16, K=128   |       510.0    |      488.3  |     672.2
      f16 B=16, M=128, H=16, K=256   |       752.0    |      777.0  |     544.9
      f32 B=16, M=128, H=16, K=256   |       975.4    |      937.4  |    1162.6
      f16 B=16, M=512, H=16, K=16    |       640.8    |      713.5  |    1203.7
      f32 B=16, M=512, H=16, K=16    |      2162.6    |     2150.9  |    4409.0
      f16 B=16, M=512, H=16, K=32    |       721.8    |      805.8  |    1306.9
      f32 B=16, M=512, H=16, K=32    |      2352.4    |     2343.7  |    4633.5
      f16 B=16, M=512, H=16, K=64    |       928.7    |     1019.2  |    1544.1
      f32 B=16, M=512, H=16, K=64    |      2991.9    |     2981.1  |    5115.9
      f16 B=16, M=512, H=16, K=128   |      1843.7    |     1958.5  |    1984.9
      f32 B=16, M=512, H=16, K=128   |      6134.5    |     5800.4  |    6086.4
      f16 B=16, M=512, H=16, K=256   |      8277.8    |     8490.1  |    2902.9
      f32 B=16, M=512, H=16, K=256   |     11843.6    |    11313.2  |   10617.2
      f16 B=16, M=1024, H=16, K=16   |      2476.9    |     2809.0  |    4262.6
      f32 B=16, M=1024, H=16, K=16   |      8565.4    |     8520.4  |   16608.1
      f16 B=16, M=1024, H=16, K=32   |      2722.0    |     3086.4  |    4485.7
      f32 B=16, M=1024, H=16, K=32   |      9034.9    |     9040.3  |   17262.9
      f16 B=16, M=1024, H=16, K=64   |      3371.4    |     3721.9  |    4991.7
      f32 B=16, M=1024, H=16, K=64   |     11629.5    |    11677.5  |   18670.5
      f16 B=16, M=1024, H=16, K=128  |      6575.3    |     7003.0  |    5949.2
      f32 B=16, M=1024, H=16, K=128  |     23319.6    |    21954.0  |   21480.0
      f16 B=16, M=1024, H=16, K=256  |     30759.2    |    32062.5  |    7897.9
      f32 B=16, M=1024, H=16, K=256  |     45035.5    |    42840.9  |   37951.9
      f16 B=64, M=128, H=16, K=16    |       247.4    |      184.8  |     439.3
      f32 B=64, M=128, H=16, K=16    |       496.8    |      495.2  |    1268.7
      f16 B=64, M=128, H=16, K=32    |       245.4    |      241.0  |     545.3
      f32 B=64, M=128, H=16, K=32    |       604.2    |      603.1  |    1425.5
      f16 B=64, M=128, H=16, K=64    |       336.0    |      369.2  |     767.2
      f32 B=64, M=128, H=16, K=64    |       872.9    |      871.9  |    1743.4
      f16 B=64, M=128, H=16, K=128   |       697.7    |      723.9  |    1228.2
      f32 B=64, M=128, H=16, K=128   |      1770.0    |     1699.5  |    2383.6
      f16 B=64, M=128, H=16, K=256   |      2788.9    |     2888.7  |    2129.9
      f32 B=64, M=128, H=16, K=256   |      3404.7    |     3289.3  |    4314.9
      f16 B=64, M=512, H=16, K=16    |      2381.7    |     2629.4  |    4486.1
      f32 B=64, M=512, H=16, K=16    |      6693.1    |     6719.5  |   16963.1
      f16 B=64, M=512, H=16, K=32    |      2753.7    |     3005.5  |    4975.9
      f32 B=64, M=512, H=16, K=32    |      7485.2    |     7497.9  |   17823.2
      f16 B=64, M=512, H=16, K=64    |      3541.3    |     3876.7  |    5893.6
      f32 B=64, M=512, H=16, K=64    |      9614.1    |     9634.2  |   19731.2
      f16 B=64, M=512, H=16, K=128   |      6648.7    |     6871.8  |    7707.9
      f32 B=64, M=512, H=16, K=128   |     21311.3    |    20087.6  |   23584.0
      f16 B=64, M=512, H=16, K=256   |     30093.1    |    30844.5  |   11501.6
      f32 B=64, M=512, H=16, K=256   |     40917.9    |    38994.0  |   42386.4
      f16 B=64, M=1024, H=16, K=16   |      9376.5    |    10399.8  |   16846.5
      f32 B=64, M=1024, H=16, K=16   |     26600.9    |    26744.9  |   66205.9
      f16 B=64, M=1024, H=16, K=32   |     10675.9    |    11750.7  |   17866.1
      f32 B=64, M=1024, H=16, K=32   |     28444.3    |    28477.5  |   68832.9
      f16 B=64, M=1024, H=16, K=64   |     13108.4    |    14436.5  |   19915.5
      f32 B=64, M=1024, H=16, K=64   |     35859.2    |    35988.5  |   74463.8
      f16 B=64, M=1024, H=16, K=128  |     23631.7    |    24519.0  |   23742.3
      f32 B=64, M=1024, H=16, K=128  |     80305.7    |    75406.6  |   85733.5
      f16 B=64, M=1024, H=16, K=256  |    111204.0    |   115626.1  |   32765.2
      f32 B=64, M=1024, H=16, K=256  |    154761.0    |   147906.3  |  152428.4

Times are in microseconds (us).

[ attention backward (attn_bias=<class 'xformers.ops.fmha.common.LowerTriangularMask'>) ]
                                     |  pr587_d1b0fa  |    main    |  vanilla 
1 threads: -------------------------------------------------------------------
      f16 B=384, M=197, H=1, K=88    |       561.9    |     527.5  |    2261.1
      f32 B=384, M=197, H=1, K=88    |      1846.8    |    1791.2  |    1841.0
      f16 B=384, M=197, H=1, K=80    |       534.6    |     501.6  |    1915.5
      f32 B=384, M=197, H=1, K=80    |      1785.6    |    1722.7  |    1786.5
      f16 B=384, M=197, H=1, K=64    |       283.9    |     325.0  |    1810.7
      f32 B=384, M=197, H=1, K=64    |       982.1    |     974.5  |    1674.6
      f16 B=1024, M=197, H=1, K=88   |      1409.3    |    1302.1  |    5939.6
      f32 B=1024, M=197, H=1, K=88   |      4695.6    |    4595.7  |    4552.5
      f16 B=1024, M=197, H=1, K=80   |      1341.1    |    1237.0  |    5019.9
      f32 B=1024, M=197, H=1, K=80   |      4540.5    |    4435.3  |    4406.9
      f16 B=1024, M=197, H=1, K=64   |       645.5    |     725.8  |    4732.0
      f32 B=1024, M=197, H=1, K=64   |      2583.7    |    2548.6  |    4112.1
      f16 B=512, M=197, H=1, K=80    |       682.2    |     634.7  |    2535.1
      f32 B=512, M=197, H=1, K=80    |      2244.9    |    2198.3  |    2283.7
      f16 B=32, M=197, H=16, K=80    |       691.3    |     640.3  |    2570.0
      f32 B=32, M=197, H=16, K=80    |      2231.3    |    2196.0  |    2351.8
      f16 B=32, M=197, H=16, K=64    |       341.6    |     378.3  |    2428.0
      f32 B=32, M=197, H=16, K=64    |      1359.5    |    1371.3  |    2193.7
      f16 B=32, M=197, H=16, K=128   |       829.4    |     765.2  |    4489.1
      f32 B=32, M=197, H=16, K=128   |      2727.5    |    2684.6  |    2802.6
      f16 B=256, M=197, H=1, K=88    |       407.3    |     386.0  |    1526.5
      f32 B=256, M=197, H=1, K=88    |      1316.5    |    1291.3  |    1210.1
      f16 B=16, M=197, H=16, K=88    |       407.1    |     385.3  |    1539.7
      f32 B=16, M=197, H=16, K=88    |      1310.8    |    1282.6  |    1251.5
      f16 B=16, M=197, H=16, K=64    |       247.9    |     192.4  |    1243.9
      f32 B=16, M=197, H=16, K=64    |       791.2    |     809.9  |    1126.4
      f16 B=16, M=197, H=16, K=128   |       460.9    |     431.8  |    2268.8
      f32 B=16, M=197, H=16, K=128   |      1548.2    |    1484.0  |    1445.7
      f16 B=1, M=4096, H=160, K=128  |     33584.4    |   36002.6  |   46369.0
      f32 B=1, M=4096, H=160, K=128  |    124322.3    |  117500.2  |          
      f16 B=2, M=4096, H=160, K=128  |     56632.1    |   58882.2  |          
      f32 B=2, M=4096, H=160, K=128  |    196299.6    |  185992.3  |          
      f16 B=1, M=8192, H=160, K=128  |    129289.4    |  138568.8  |          
      f32 B=1, M=8192, H=160, K=128  |    482534.0    |  455362.6  |          
      f16 B=2, M=8192, H=160, K=128  |    217759.5    |  225203.4  |          
      f32 B=2, M=8192, H=160, K=128  |    763277.0    |  722918.1  |          
      f16 B=1024, M=82, H=8, K=64    |      1648.7    |    1726.6  |    3822.7
      f32 B=1024, M=82, H=8, K=64    |      7663.7    |    7623.2  |    8710.4
      f16 B=150, M=256, H=16, K=64   |      1643.0    |    1826.8  |    4561.7
      f32 B=150, M=256, H=16, K=64   |      4484.1    |    4477.1  |   12926.9
      f16 B=64, M=256, H=12, K=64    |       566.3    |     632.6  |    1500.9
      f32 B=64, M=256, H=12, K=64    |      1536.6    |    1534.2  |    4260.9
      f16 B=1, M=4096, H=16, K=40    |     11154.9    |   12100.5  |    4237.0
      f32 B=1, M=4096, H=16, K=40    |     35628.1    |   35281.9  |   17692.4
      f16 B=1, M=16384, H=16, K=40   |    199237.0    |  221542.6  |          
      f32 B=1, M=16384, H=16, K=40   |    597761.9    |  592061.1  |          
      f16 B=256, M=4096, H=16, K=64  |    388909.8    |  424073.1  |          
      f16 B=16, M=128, H=16, K=16    |       243.5    |     183.9  |     289.0
      f32 B=16, M=128, H=16, K=16    |       287.6    |     227.3  |     373.5
      f16 B=16, M=128, H=16, K=32    |       243.0    |     182.8  |     286.4
      f32 B=16, M=128, H=16, K=32    |       292.0    |     227.3  |     415.3
      f16 B=16, M=128, H=16, K=64    |       241.0    |     184.3  |     287.4
      f32 B=16, M=128, H=16, K=64    |       287.9    |     231.1  |     502.4
      f16 B=16, M=128, H=16, K=128   |       245.0    |     210.2  |     301.0
      f32 B=16, M=128, H=16, K=128   |       510.4    |     489.2  |     679.4
      f16 B=16, M=128, H=16, K=256   |       750.4    |     777.2  |     555.4
      f32 B=16, M=128, H=16, K=256   |       975.4    |     939.9  |    1163.0
      f16 B=16, M=512, H=16, K=16    |       360.7    |     413.9  |    1199.9
      f32 B=16, M=512, H=16, K=16    |      1260.2    |    1242.3  |    4408.4
      f16 B=16, M=512, H=16, K=32    |       425.0    |     484.5  |    1305.6
      f32 B=16, M=512, H=16, K=32    |      1410.3    |    1400.3  |    4633.7
      f16 B=16, M=512, H=16, K=64    |       578.3    |     641.4  |    1544.3
      f32 B=16, M=512, H=16, K=64    |      1846.5    |    1833.5  |    5117.6
      f16 B=16, M=512, H=16, K=128   |      1289.4    |    1375.8  |    1986.1
      f32 B=16, M=512, H=16, K=128   |      4047.1    |    3852.5  |    6086.9
      f16 B=16, M=512, H=16, K=256   |      5688.5    |    5757.7  |    2903.4
      f32 B=16, M=512, H=16, K=256   |      7822.0    |    7501.2  |   10619.7
      f16 B=16, M=1024, H=16, K=16   |      1319.9    |    1522.2  |    4256.2
      f32 B=16, M=1024, H=16, K=16   |      4607.4    |    4583.4  |   16612.4
      f16 B=16, M=1024, H=16, K=32   |      1491.7    |    1702.7  |    4478.4
      f32 B=16, M=1024, H=16, K=32   |      5000.8    |    4968.6  |   17261.8
      f16 B=16, M=1024, H=16, K=64   |      1915.2    |    2123.9  |    4987.3
      f32 B=16, M=1024, H=16, K=64   |      6433.8    |    6376.7  |   18674.5
      f16 B=16, M=1024, H=16, K=128  |      4031.5    |    4296.1  |    5947.0
      f32 B=16, M=1024, H=16, K=128  |     13639.1    |   12937.8  |   21481.2
      f16 B=16, M=1024, H=16, K=256  |     18687.2    |   19016.0  |    7896.4
      f32 B=16, M=1024, H=16, K=256  |     26324.2    |   25151.9  |   37929.3
      f16 B=64, M=128, H=16, K=16    |       246.0    |     184.6  |     440.2
      f32 B=64, M=128, H=16, K=16    |       406.2    |     402.4  |    1270.6
      f16 B=64, M=128, H=16, K=32    |       238.1    |     204.3  |     545.0
      f32 B=64, M=128, H=16, K=32    |       511.8    |     508.2  |    1427.3
      f16 B=64, M=128, H=16, K=64    |       288.2    |     312.9  |     773.7
      f32 B=64, M=128, H=16, K=64    |       737.5    |     737.1  |    1743.0
      f16 B=64, M=128, H=16, K=128   |       698.3    |     723.6  |    1226.1
      f32 B=64, M=128, H=16, K=128   |      1771.7    |    1703.0  |    2383.4
      f16 B=64, M=128, H=16, K=256   |      2788.5    |    2888.6  |    2129.0
      f32 B=64, M=128, H=16, K=256   |      3410.2    |    3294.6  |    4314.2
      f16 B=64, M=512, H=16, K=16    |      1313.0    |    1522.1  |    4483.1
      f32 B=64, M=512, H=16, K=16    |      3873.0    |    3864.4  |   16965.0
      f16 B=64, M=512, H=16, K=32    |      1612.2    |    1810.4  |    4972.8
      f32 B=64, M=512, H=16, K=32    |      4512.1    |    4501.9  |   17822.1
      f16 B=64, M=512, H=16, K=64    |      2226.3    |    2484.9  |    5891.3
      f32 B=64, M=512, H=16, K=64    |      5969.3    |    5975.5  |   19736.5
      f16 B=64, M=512, H=16, K=128   |      4692.3    |    4853.1  |    7704.5
      f32 B=64, M=512, H=16, K=128   |     14180.2    |   13458.0  |   23594.8
      f16 B=64, M=512, H=16, K=256   |     20576.9    |   21087.3  |   11491.5
      f32 B=64, M=512, H=16, K=256   |     27090.0    |   25985.1  |   42300.0
      f16 B=64, M=1024, H=16, K=16   |      4906.8    |    5585.3  |   16841.1
      f32 B=64, M=1024, H=16, K=16   |     14364.6    |   14389.0  |   66224.9
      f16 B=64, M=1024, H=16, K=32   |      5754.5    |    6465.5  |   17853.1
      f32 B=64, M=1024, H=16, K=32   |     15832.9    |   15835.2  |   68841.8
      f16 B=64, M=1024, H=16, K=64   |      7421.6    |    8286.4  |   19909.4
      f32 B=64, M=1024, H=16, K=64   |     20266.3    |   20341.4  |   74454.6
      f16 B=64, M=1024, H=16, K=128  |     14653.6    |   15119.5  |   23731.7
      f32 B=64, M=1024, H=16, K=128  |     47294.1    |   44702.4  |   85699.6
      f16 B=64, M=1024, H=16, K=256  |     67373.1    |   69018.2  |   32542.6
      f32 B=64, M=1024, H=16, K=256  |     90539.4    |   87053.4  |  152328.7

Times are in microseconds (us).
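
The backward tables time only the gradient computation. A minimal sketch of how one entry could be measured is below; reusing a single forward output with `retain_graph=True` is an illustrative shortcut, and the `LowerTriangularMask` import path is the one printed in the table headers, which may differ in other versions.

```python
# Minimal sketch: time the backward pass for one entry, e.g.
# "f16 B=16, M=512, H=16, K=64" with a lower-triangular mask.
import torch
from torch.utils import benchmark
import xformers.ops as xops
from xformers.ops.fmha.common import LowerTriangularMask

B, M, H, K = 16, 512, 16, 64
q, k, v = [
    torch.randn(B * H, M, K, device="cuda", dtype=torch.float16,
                requires_grad=True)
    for _ in range(3)
]

out = xops.memory_efficient_attention(q, k, v, attn_bias=LowerTriangularMask())
grad = torch.ones_like(out)

timer = benchmark.Timer(
    # gradients accumulate across runs, which is fine for timing purposes
    stmt="out.backward(grad, retain_graph=True)",
    globals={"out": out, "grad": grad},
    label="attention backward (LowerTriangularMask)",
    sub_label=f"f16 B={B}, M={M}, H={H}, K={K}",
)
print(timer.blocked_autorange(min_run_time=0.5))
```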

V100/P100

V100/P100 fw
[---------------------------- attention (attn_bias=<class 'NoneType'>) ---------------------------]
                                                         |  pr587_d1b0fa  |     main    |   eager  
1 threads: ----------------------------------------------------------------------------------------
  (Quadro_GP100)          f16 B=384, M=197, H=1, K=88    |      1556.8    |     1510.0  |    1397.8
                          f32 B=384, M=197, H=1, K=88    |      1505.3    |     1434.3  |    1509.2
                          f16 B=384, M=197, H=1, K=80    |      1485.5    |     1457.3  |    1347.1
                          f32 B=384, M=197, H=1, K=80    |      1437.6    |     1388.8  |    1459.8
                          f16 B=384, M=197, H=1, K=64    |      1111.2    |     1079.7  |    1245.4
                          f32 B=384, M=197, H=1, K=64    |      1025.4    |      988.8  |    1344.6
                          f16 B=1024, M=197, H=1, K=88   |      4104.5    |     4065.9  |    3705.0
                          f32 B=1024, M=197, H=1, K=88   |      4021.5    |     3873.3  |    3997.2
                          f16 B=1024, M=197, H=1, K=80   |      3950.2    |     3904.2  |    3603.3
                          f32 B=1024, M=197, H=1, K=80   |      3859.7    |     3784.7  |    3880.4
                          f16 B=1024, M=197, H=1, K=64   |      2913.2    |     2880.6  |    3349.5
                          f32 B=1024, M=197, H=1, K=64   |      2717.0    |     2634.3  |    3593.6
                          f16 B=512, M=197, H=1, K=80    |      2000.9    |     1984.0  |    1846.5
                          f32 B=512, M=197, H=1, K=80    |      1956.6    |     1924.1  |    1975.0
                          f16 B=32, M=197, H=16, K=80    |      2000.1    |     2028.9  |    2209.5
                          f32 B=32, M=197, H=16, K=80    |      1970.5    |     1958.0  |    2388.4
                          f16 B=32, M=197, H=16, K=64    |      1498.6    |     1477.0  |    1990.0
                          f32 B=32, M=197, H=16, K=64    |      1377.9    |     1350.7  |    2163.4
                          f16 B=32, M=197, H=16, K=128   |      2527.6    |     2530.0  |    2878.7
                          f32 B=32, M=197, H=16, K=128   |      2470.2    |     2429.7  |    3099.2
                          f16 B=256, M=197, H=1, K=88    |      1068.9    |     1070.5  |     990.1
                          f32 B=256, M=197, H=1, K=88    |      1037.8    |     1024.0  |    1053.3
                          f16 B=16, M=197, H=16, K=88    |      1069.7    |     1072.3  |    1194.3
                          f32 B=16, M=197, H=16, K=88    |      1046.0    |     1029.1  |    1282.4
                          f16 B=16, M=197, H=16, K=64    |       772.6    |      764.4  |    1028.8
                          f32 B=16, M=197, H=16, K=64    |       713.2    |      696.7  |    1108.2
                          f16 B=16, M=197, H=16, K=128   |      1287.8    |     1293.0  |    1470.2
                          f32 B=16, M=197, H=16, K=128   |      1251.6    |     1230.2  |    1592.4
                          f16 B=1, M=4096, H=160, K=128  |    252819.7    |   252185.1  |  201705.3
                          f32 B=1, M=4096, H=160, K=128  |    244127.7    |   241164.8  |          
                          f16 B=2, M=4096, H=160, K=128  |    501082.6    |   500526.4  |          
                          f32 B=2, M=4096, H=160, K=128  |    491881.1    |   485637.0  |          
                          f16 B=1, M=8192, H=160, K=128  |   1023464.1    |  1015719.4  |          
                          f32 B=1, M=8192, H=160, K=128  |   1014686.4    |   996319.1  |          
                          f16 B=2, M=8192, H=160, K=128  |   2047446.1    |  2037663.8  |          
                          f32 B=2, M=8192, H=160, K=128  |   2036186.5    |  1997234.2  |          
                          f16 B=1024, M=82, H=8, K=64    |      5891.7    |     5752.5  |    8562.6
                          f32 B=1024, M=82, H=8, K=64    |      5493.1    |     5329.1  |    9109.9
                          f16 B=150, M=256, H=16, K=64   |      7810.6    |     7729.5  |   11332.5
                          f32 B=150, M=256, H=16, K=64   |      7301.2    |     7054.9  |   12679.5
                          f16 B=64, M=256, H=12, K=64    |      2548.6    |     2511.9  |    3674.3
                          f32 B=64, M=256, H=12, K=64    |      2371.9    |     2308.0  |    4095.5
                          f16 B=1, M=4096, H=16, K=40    |     11293.4    |    11231.9  |   14588.4
                          f32 B=1, M=4096, H=16, K=40    |     10351.9    |     9929.1  |   17735.1
                          f16 B=1, M=16384, H=16, K=40   |    172958.1    |   170210.6  |          
                          f32 B=1, M=16384, H=16, K=40   |    159533.5    |   155516.2  |          
                          f16 B=256, M=4096, H=16, K=64  |   3272011.4    |  3252023.7  |          
                          f16 B=16, M=128, H=16, K=16    |       152.5    |      150.9  |     265.4
                          f32 B=16, M=128, H=16, K=16    |       145.2    |      141.5  |     300.4
                          f16 B=16, M=128, H=16, K=32    |       182.3    |      179.2  |     311.5
                          f32 B=16, M=128, H=16, K=32    |       168.6    |      166.3  |     357.9
                          f16 B=16, M=128, H=16, K=64    |       236.3    |      231.8  |     414.1
                          f32 B=16, M=128, H=16, K=64    |       231.4    |      227.0  |     462.4
                          f16 B=16, M=128, H=16, K=128   |       439.2    |      437.0  |     599.3
                          f32 B=16, M=128, H=16, K=128   |       459.8    |      452.5  |     685.6
                          f16 B=16, M=128, H=16, K=256   |       841.2    |      835.5  |    1116.1
                          f32 B=16, M=128, H=16, K=256   |       898.0    |      899.6  |    1358.8
                          f16 B=16, M=512, H=16, K=16    |      2168.0    |     2150.5  |    3183.3
                          f32 B=16, M=512, H=16, K=16    |      1956.2    |     1960.3  |    3646.8
                          f16 B=16, M=512, H=16, K=32    |      2562.7    |     2548.4  |    3576.9
                          f32 B=16, M=512, H=16, K=32    |      2311.2    |     2256.6  |    4017.2
                          f16 B=16, M=512, H=16, K=64    |      3355.9    |     3323.5  |    4343.4
                          f32 B=16, M=512, H=16, K=64    |      3079.5    |     2990.8  |    4862.4
                          f16 B=16, M=512, H=16, K=128   |      6508.2    |     6408.5  |    5901.5
                          f32 B=16, M=512, H=16, K=128   |      6141.6    |     6033.5  |    6559.6
                          f16 B=16, M=512, H=16, K=256   |     13089.9    |    12975.4  |   10746.3
                          f32 B=16, M=512, H=16, K=256   |     12765.7    |    12683.6  |   11748.1
                          f16 B=16, M=1024, H=16, K=16   |      8416.0    |     8366.1  |   12316.8
                          f32 B=16, M=1024, H=16, K=16   |      7545.0    |     7490.7  |   13839.4
                          f16 B=16, M=1024, H=16, K=32   |     10006.6    |     9894.1  |   13713.1
                          f32 B=16, M=1024, H=16, K=32   |      9061.8    |     8854.9  |   15169.7
                          f16 B=16, M=1024, H=16, K=64   |     13122.4    |    12859.7  |   16233.0
                          f32 B=16, M=1024, H=16, K=64   |     11880.8    |    11548.4  |   18025.9
                          f16 B=16, M=1024, H=16, K=128  |     25576.0    |    25314.9  |   21504.5
                          f32 B=16, M=1024, H=16, K=128  |     23925.9    |    23507.9  |   23364.1
                          f16 B=16, M=1024, H=16, K=256  |     51924.3    |    51510.7  |   38749.2
                          f32 B=16, M=1024, H=16, K=256  |     51078.2    |    50812.9  |   41896.2
                          f16 B=64, M=128, H=16, K=16    |       574.8    |      565.0  |     961.0
                          f32 B=64, M=128, H=16, K=16    |       534.4    |      524.3  |    1091.2
                          f16 B=64, M=128, H=16, K=32    |       674.2    |      665.8  |    1138.3
                          f32 B=64, M=128, H=16, K=32    |       631.2    |      613.0  |    1300.2
                          f16 B=64, M=128, H=16, K=64    |       887.0    |      866.5  |    1524.1
                          f32 B=64, M=128, H=16, K=64    |       843.9    |      825.7  |    1723.0
                          f16 B=64, M=128, H=16, K=128   |      1715.6    |     1682.9  |    2270.2
                          f32 B=64, M=128, H=16, K=128   |      1725.6    |     1707.2  |    2591.8
                          f16 B=64, M=128, H=16, K=256   |      3270.8    |     3262.1  |    4236.3
                          f32 B=64, M=128, H=16, K=256   |      3434.7    |     3443.2  |    5390.7
                          f16 B=64, M=512, H=16, K=16    |      8465.0    |     8419.1  |   12452.5
                          f32 B=64, M=512, H=16, K=16    |      7659.0    |     7496.2  |   14249.4
                          f16 B=64, M=512, H=16, K=32    |      9961.1    |     9984.7  |   14043.7
                          f32 B=64, M=512, H=16, K=32    |      9102.9    |     8902.3  |   15898.0
                          f16 B=64, M=512, H=16, K=64    |     13108.9    |    13064.5  |   17211.0
                          f32 B=64, M=512, H=16, K=64    |     11988.9    |    11695.4  |   19154.0
                          f16 B=64, M=512, H=16, K=128   |     25713.3    |    25396.8  |   23241.0
                          f32 B=64, M=512, H=16, K=128   |     24509.1    |    24079.8  |   25745.1
                          f16 B=64, M=512, H=16, K=256   |     51856.7    |    51471.0  |   43083.7
                          f32 B=64, M=512, H=16, K=256   |     50690.1    |    50596.1  |   47045.4
                          f16 B=64, M=1024, H=16, K=16   |     33369.6    |    32927.4  |   49295.4
                          f32 B=64, M=1024, H=16, K=16   |     30192.1    |    29545.4  |   55489.4
                          f16 B=64, M=1024, H=16, K=32   |     39164.9    |    39128.1  |   54820.3
                          f32 B=64, M=1024, H=16, K=32   |     35632.6    |    34791.5  |   59901.5
                          f16 B=64, M=1024, H=16, K=64   |     51700.2    |    51089.9  |   65794.8
                          f32 B=64, M=1024, H=16, K=64   |     46797.3    |    45600.0  |   72379.6
                          f16 B=64, M=1024, H=16, K=128  |    102438.1    |   100577.7  |   85978.0
                          f32 B=64, M=1024, H=16, K=128  |     97144.1    |    94634.7  |   94077.1
                          f16 B=64, M=1024, H=16, K=256  |    206433.0    |   205416.2  |  156833.3
                          f32 B=64, M=1024, H=16, K=256  |    204046.1    |   204171.7  |          
  (Tesla_V100_SXM2_16GB)  f16 B=384, M=197, H=1, K=88    |       309.1    |      280.5  |     550.9
                          f32 B=384, M=197, H=1, K=88    |       809.1    |      780.4  |     888.6
                          f16 B=384, M=197, H=1, K=80    |       293.6    |      270.8  |     526.0
                          f32 B=384, M=197, H=1, K=80    |       777.6    |      755.2  |     857.3
                          f16 B=384, M=197, H=1, K=64    |       229.4    |      191.4  |     418.9
                          f32 B=384, M=197, H=1, K=64    |       587.6    |      567.8  |     714.1
                          f16 B=1024, M=197, H=1, K=88   |       816.2    |      742.6  |    1425.3
                          f32 B=1024, M=197, H=1, K=88   |      2132.1    |     2075.7  |    2352.3
                          f16 B=1024, M=197, H=1, K=80   |       779.6    |      711.3  |    1360.5
                          f32 B=1024, M=197, H=1, K=80   |      2051.8    |     1994.1  |    2264.5
                          f16 B=1024, M=197, H=1, K=64   |       587.1    |      498.9  |    1075.2
                          f32 B=1024, M=197, H=1, K=64   |      1532.9    |     1493.6  |    1866.1
                          f16 B=512, M=197, H=1, K=80    |       391.4    |      359.6  |     690.8
                          f32 B=512, M=197, H=1, K=80    |      1037.3    |     1011.7  |    1147.9
                          f16 B=32, M=197, H=16, K=80    |       404.6    |      374.6  |     841.7
                          f32 B=32, M=197, H=16, K=80    |      1047.9    |     1031.9  |    1383.7
                          f16 B=32, M=197, H=16, K=64    |       305.8    |      255.3  |     675.1
                          f32 B=32, M=197, H=16, K=64    |       789.5    |      763.7  |    1133.3
                          f16 B=32, M=197, H=16, K=128   |       466.0    |      425.2  |    1030.2
                          f32 B=32, M=197, H=16, K=128   |      1353.7    |     1314.9  |    1902.7
                          f16 B=256, M=197, H=1, K=88    |       208.9    |      193.7  |     378.2
                          f32 B=256, M=197, H=1, K=88    |       551.6    |      540.4  |     612.0
                          f16 B=16, M=197, H=16, K=88    |       209.7    |      196.0  |     461.6
                          f32 B=16, M=197, H=16, K=88    |       553.3    |      542.0  |     746.9
                          f16 B=16, M=197, H=16, K=64    |       162.0    |      171.2  |     429.8
                          f32 B=16, M=197, H=16, K=64    |       411.5    |      397.3  |     609.0
                          f16 B=16, M=197, H=16, K=128   |       237.4    |      218.1  |     536.7
                          f32 B=16, M=197, H=16, K=128   |       687.1    |      673.3  |     966.1
                          f16 B=1, M=4096, H=160, K=128  |     37632.1    |    35200.8  |   44509.1
                          f32 B=1, M=4096, H=160, K=128  |    140941.0    |   139848.7  |          
                          f16 B=2, M=4096, H=160, K=128  |     75731.1    |    70836.1  |          
                          f32 B=2, M=4096, H=160, K=128  |    278838.3    |   279696.9  |          
                          f16 B=1, M=8192, H=160, K=128  |    153637.8    |   144262.0  |          
                          f32 B=1, M=8192, H=160, K=128  |    577397.4    |   578509.8  |          
                          f16 B=2, M=8192, H=160, K=128  |    307287.9    |   289830.3  |          
                          f32 B=2, M=8192, H=160, K=128  |   1159578.7    |  1163174.7  |          
                          f16 B=1024, M=82, H=8, K=64    |      1337.0    |     1095.4  |    2496.5
                          f32 B=1024, M=82, H=8, K=64    |      2987.2    |     2915.9  |    4424.4
                          f16 B=150, M=256, H=16, K=64   |      1541.8    |     1295.8  |    3131.0
                          f32 B=150, M=256, H=16, K=64   |      4111.0    |     4037.3  |    6739.0
                          f16 B=64, M=256, H=12, K=64    |       493.5    |      416.9  |    1025.5
                          f32 B=64, M=256, H=12, K=64    |      1333.6    |     1304.1  |    2254.9
                          f16 B=1, M=4096, H=16, K=40    |      2210.2    |     1996.3  |    4067.3
                          f32 B=1, M=4096, H=16, K=40    |      5951.6    |     5851.3  |    8249.3
                          f16 B=1, M=16384, H=16, K=40   |     37370.7    |    29646.1  |          
                          f32 B=1, M=16384, H=16, K=40   |     89387.1    |    86661.6  |          
                          f16 B=256, M=4096, H=16, K=64  |    531524.0    |   462010.1  |          
                          f16 B=16, M=128, H=16, K=16    |       129.1    |      148.0  |     356.6
                          f32 B=16, M=128, H=16, K=16    |       126.0    |      144.8  |     357.5
                          f16 B=16, M=128, H=16, K=32    |       125.4    |      147.6  |     345.4
                          f32 B=16, M=128, H=16, K=32    |       129.3    |      149.3  |     335.8
                          f16 B=16, M=128, H=16, K=64    |       129.0    |      143.7  |     346.2
                          f32 B=16, M=128, H=16, K=64    |       134.1    |      151.9  |     350.6
                          f16 B=16, M=128, H=16, K=128   |       130.2    |      144.9  |     349.5
                          f32 B=16, M=128, H=16, K=128   |       247.0    |      241.8  |     462.1
                          f16 B=16, M=128, H=16, K=256   |       177.3    |      167.3  |     367.1
                          f32 B=16, M=128, H=16, K=256   |       480.4    |      477.2  |     862.7
                          f16 B=16, M=512, H=16, K=16    |       444.3    |      399.8  |     845.5
                          f32 B=16, M=512, H=16, K=16    |      1105.1    |     1088.7  |    1830.8
                          f16 B=16, M=512, H=16, K=32    |       468.8    |      418.4  |     937.4
                          f32 B=16, M=512, H=16, K=32    |      1299.8    |     1264.9  |    2095.9
                          f16 B=16, M=512, H=16, K=64    |       588.2    |      511.8  |    1143.3
                          f32 B=16, M=512, H=16, K=64    |      1718.7    |     1681.4  |    2501.4
                          f16 B=16, M=512, H=16, K=128   |      1038.0    |      965.5  |    1458.1
                          f32 B=16, M=512, H=16, K=128   |      3376.2    |     3330.1  |    4066.9
                          f16 B=16, M=512, H=16, K=256   |      2865.5    |     2761.3  |    2316.0
                          f32 B=16, M=512, H=16, K=256   |      7149.8    |     7204.6  |    7218.3
                          f16 B=16, M=1024, H=16, K=16   |      1693.2    |     1537.0  |    3433.3
                          f32 B=16, M=1024, H=16, K=16   |      4279.4    |     4195.9  |    7165.5
                          f16 B=16, M=1024, H=16, K=32   |      1761.9    |     1589.4  |    3673.3
                          f32 B=16, M=1024, H=16, K=32   |      5043.3    |     4958.9  |    7966.2
                          f16 B=16, M=1024, H=16, K=64   |      2186.1    |     1945.7  |    4154.1
                          f32 B=16, M=1024, H=16, K=64   |      6681.1    |     6528.0  |    9394.8
                          f16 B=16, M=1024, H=16, K=128  |      3791.6    |     3621.1  |    4855.3
                          f32 B=16, M=1024, H=16, K=128  |     13205.2    |    13031.0  |   15398.9
                          f16 B=16, M=1024, H=16, K=256  |     11482.8    |    11071.2  |    7464.6
                          f32 B=16, M=1024, H=16, K=256  |     28146.0    |    28171.8  |   26984.5
                          f16 B=64, M=128, H=16, K=16    |       140.8    |      139.2  |     364.1
                          f32 B=64, M=128, H=16, K=16    |       293.9    |      288.3  |     517.2
                          f16 B=64, M=128, H=16, K=32    |       156.1    |      172.0  |     355.8
                          f32 B=64, M=128, H=16, K=32    |       349.6    |      339.6  |     657.9
                          f16 B=64, M=128, H=16, K=64    |       207.6    |      170.1  |     492.8
                          f32 B=64, M=128, H=16, K=64    |       477.3    |      465.8  |     935.6
                          f16 B=64, M=128, H=16, K=128   |       352.3    |      323.9  |     765.4
                          f32 B=64, M=128, H=16, K=128   |       936.8    |      915.7  |    1600.0
                          f16 B=64, M=128, H=16, K=256   |       656.1    |      624.2  |    1326.2
                          f32 B=64, M=128, H=16, K=256   |      1807.8    |     1825.8  |    2868.1
                          f16 B=64, M=512, H=16, K=16    |      1753.8    |     1576.0  |    3237.9
                          f32 B=64, M=512, H=16, K=16    |      4252.9    |     4193.3  |    7437.9
                          f16 B=64, M=512, H=16, K=32    |      1851.1    |     1649.4  |    3603.9
                          f32 B=64, M=512, H=16, K=32    |      5050.7    |     4969.9  |    8514.8
                          f16 B=64, M=512, H=16, K=64    |      2312.4    |     2015.4  |    4469.6
                          f32 B=64, M=512, H=16, K=64    |      6782.0    |     6611.6  |   10198.4
                          f16 B=64, M=512, H=16, K=128   |      4052.7    |     3794.5  |    5712.6
                          f32 B=64, M=512, H=16, K=128   |     13385.4    |    13156.5  |   16556.8
                          f16 B=64, M=512, H=16, K=256   |     11463.1    |    10970.2  |    9144.8
                          f32 B=64, M=512, H=16, K=256   |     28055.3    |    28069.9  |   30192.7
                          f16 B=64, M=1024, H=16, K=16   |      6624.4    |     6039.1  |   14015.6
                          f32 B=64, M=1024, H=16, K=16   |     16636.9    |    16379.4  |   28076.7
                          f16 B=64, M=1024, H=16, K=32   |      6869.0    |     6222.1  |   14556.7
                          f32 B=64, M=1024, H=16, K=32   |     19795.5    |    19341.7  |   30747.7
                          f16 B=64, M=1024, H=16, K=64   |      8579.0    |     7476.1  |   17005.3
                          f32 B=64, M=1024, H=16, K=64   |     26451.1    |    25863.7  |   38337.4
                          f16 B=64, M=1024, H=16, K=128  |     15099.4    |    14205.3  |   19211.8
                          f32 B=64, M=1024, H=16, K=128  |     52582.2    |    51596.0  |   60226.9
                          f16 B=64, M=1024, H=16, K=256  |     45330.1    |    43848.4  |   29946.7
                          f32 B=64, M=1024, H=16, K=256  |    111758.3    |   111474.6  |          

Times are in microseconds (us).

[---------- attention (attn_bias=<class 'xformers.ops.fmha.common.LowerTriangularMask'>) ---------]
                                                         |  pr587_d1b0fa  |     main    |   eager  
1 threads: ----------------------------------------------------------------------------------------
  (Quadro_GP100)          f16 B=384, M=197, H=1, K=88    |      1088.5    |     1063.0  |    1641.4
                          f32 B=384, M=197, H=1, K=88    |      1081.9    |     1040.8  |    1722.3
                          f16 B=384, M=197, H=1, K=80    |      1043.8    |     1020.5  |    1588.0
                          f32 B=384, M=197, H=1, K=80    |      1034.5    |     1005.1  |    1675.2
                          f16 B=384, M=197, H=1, K=64    |       778.6    |      758.3  |    1493.9
                          f32 B=384, M=197, H=1, K=64    |       711.3    |      683.5  |    1570.4
                          f16 B=1024, M=197, H=1, K=88   |      2872.5    |     2817.5  |    4363.1
                          f32 B=1024, M=197, H=1, K=88   |      2845.6    |     2779.7  |    4576.2
                          f16 B=1024, M=197, H=1, K=80   |      2741.3    |     2722.3  |    4233.4
                          f32 B=1024, M=197, H=1, K=80   |      2732.4    |     2694.0  |    4471.1
                          f16 B=1024, M=197, H=1, K=64   |      2027.3    |     1997.4  |    4000.6
                          f32 B=1024, M=197, H=1, K=64   |      1847.7    |     1808.7  |    4174.4
                          f16 B=512, M=197, H=1, K=80    |      1394.9    |     1386.9  |    2166.3
                          f32 B=512, M=197, H=1, K=80    |      1391.1    |     1371.1  |    2283.7
                          f16 B=32, M=197, H=16, K=80    |      1401.9    |     1395.5  |    2543.1
                          f32 B=32, M=197, H=16, K=80    |      1398.9    |     1383.4  |    2674.6
                          f16 B=32, M=197, H=16, K=64    |      1032.7    |     1030.3  |    2327.5
                          f32 B=32, M=197, H=16, K=64    |       952.6    |      933.1  |    2465.0
                          f16 B=32, M=197, H=16, K=128   |      1772.0    |     1760.5  |    3212.4
                          f32 B=32, M=197, H=16, K=128   |      1759.8    |     1745.0  |    3387.1
                          f16 B=256, M=197, H=1, K=88    |       755.4    |      753.8  |    1151.7
                          f32 B=256, M=197, H=1, K=88    |       741.2    |      729.1  |    1206.3
                          f16 B=16, M=197, H=16, K=88    |       753.5    |      751.0  |    1361.1
                          f32 B=16, M=197, H=16, K=88    |       746.3    |      738.2  |    1446.1
                          f16 B=16, M=197, H=16, K=64    |       540.8    |      538.5  |    1198.4
                          f32 B=16, M=197, H=16, K=64    |       498.6    |      490.5  |    1269.8
                          f16 B=16, M=197, H=16, K=128   |       909.4    |      915.2  |    1650.7
                          f32 B=16, M=197, H=16, K=128   |       900.3    |      890.3  |    1730.6
                          f16 B=1, M=4096, H=160, K=128  |    130120.1    |   128716.3  |          
                          f32 B=1, M=4096, H=160, K=128  |    125155.8    |   123542.7  |          
                          f16 B=2, M=4096, H=160, K=128  |    259203.0    |   256642.2  |          
                          f32 B=2, M=4096, H=160, K=128  |    249488.9    |   245869.1  |          
                          f16 B=1, M=8192, H=160, K=128  |    516633.1    |   507149.2  |          
                          f32 B=1, M=8192, H=160, K=128  |    510608.0    |   501763.7  |          
                          f16 B=2, M=8192, H=160, K=128  |   1024965.8    |  1012807.4  |          
                          f32 B=2, M=8192, H=160, K=128  |   1020200.6    |  1008124.8  |          
                          f16 B=1024, M=82, H=8, K=64    |      4822.8    |     4756.3  |    9552.2
                          f32 B=1024, M=82, H=8, K=64    |      4456.2    |     4346.6  |   10218.6
                          f16 B=150, M=256, H=16, K=64   |      5085.5    |     5001.0  |   13734.1
                          f32 B=150, M=256, H=16, K=64   |      4782.9    |     4644.7  |   14799.1
                          f16 B=64, M=256, H=12, K=64    |      1668.3    |     1640.0  |    4390.8
                          f32 B=64, M=256, H=12, K=64    |      1565.1    |     1532.0  |    4743.6
                          f16 B=1, M=4096, H=16, K=40    |      6009.6    |     5921.3  |   18371.5
                          f32 B=1, M=4096, H=16, K=40    |      5391.8    |     5263.0  |   22564.1
                          f16 B=1, M=16384, H=16, K=40   |     89752.5    |    88312.9  |          
                          f32 B=1, M=16384, H=16, K=40   |     82567.0    |    81014.7  |          
                          f16 B=256, M=4096, H=16, K=64  |   1669724.7    |  1657734.4  |          
                          f16 B=16, M=128, H=16, K=16    |       129.7    |      126.8  |     339.5
                          f32 B=16, M=128, H=16, K=16    |       120.2    |      120.6  |     373.7
                          f16 B=16, M=128, H=16, K=32    |       150.7    |      147.7  |     381.4
                          f32 B=16, M=128, H=16, K=32    |       139.8    |      136.6  |     426.2
                          f16 B=16, M=128, H=16, K=64    |       193.1    |      189.5  |     477.6
                          f32 B=16, M=128, H=16, K=64    |       186.8    |      184.7  |     530.4
                          f16 B=16, M=128, H=16, K=128   |       374.4    |      367.7  |     667.8
                          f32 B=16, M=128, H=16, K=128   |       379.2    |      374.2  |     749.5
                          f16 B=16, M=128, H=16, K=256   |       698.1    |      701.9  |    1173.9
                          f32 B=16, M=128, H=16, K=256   |       714.7    |      714.6  |    1419.3
                          f16 B=16, M=512, H=16, K=16    |      1282.5    |     1272.5  |    4180.7
                          f32 B=16, M=512, H=16, K=16    |      1162.0    |     1148.0  |    4759.7
                          f16 B=16, M=512, H=16, K=32    |      1521.2    |     1502.7  |    4573.0
                          f32 B=16, M=512, H=16, K=32    |      1369.7    |     1350.9  |    5028.5
                          f16 B=16, M=512, H=16, K=64    |      1966.7    |     1925.4  |    5352.1
                          f32 B=16, M=512, H=16, K=64    |      1820.9    |     1770.0  |    5705.0
                          f16 B=16, M=512, H=16, K=128   |      3898.7    |     3844.9  |    6844.9
                          f32 B=16, M=512, H=16, K=128   |      3793.0    |     3692.4  |    7410.2
                          f16 B=16, M=512, H=16, K=256   |      7639.9    |     7677.2  |   11763.5
                          f32 B=16, M=512, H=16, K=256   |      7529.5    |     7520.1  |   12469.9
                          f16 B=16, M=1024, H=16, K=16   |      4613.0    |     4571.1  |   16594.1
                          f32 B=16, M=1024, H=16, K=16   |      4152.4    |     4116.6  |   19868.4
                          f16 B=16, M=1024, H=16, K=32   |      5469.5    |     5390.4  |   17848.7
                          f32 B=16, M=1024, H=16, K=32   |      4969.4    |     4856.3  |   20576.0
                          f16 B=16, M=1024, H=16, K=64   |      7185.3    |     7004.4  |   20169.9
                          f32 B=16, M=1024, H=16, K=64   |      6506.2    |     6332.3  |   22698.9
                          f16 B=16, M=1024, H=16, K=128  |     14201.5    |    13939.1  |   25708.1
                          f32 B=16, M=1024, H=16, K=128  |     13513.2    |    13245.0  |   28058.6
                          f16 B=16, M=1024, H=16, K=256  |     28280.6    |    28277.4  |   43267.2
                          f32 B=16, M=1024, H=16, K=256  |     27734.1    |    27680.0  |   46053.4
                          f16 B=64, M=128, H=16, K=16    |       466.0    |      454.9  |    1231.9
                          f32 B=64, M=128, H=16, K=16    |       431.1    |      418.2  |    1369.6
                          f16 B=64, M=128, H=16, K=32    |       546.5    |      531.9  |    1408.3
                          f32 B=64, M=128, H=16, K=32    |       503.7    |      494.8  |    1570.4
                          f16 B=64, M=128, H=16, K=64    |       702.4    |      689.3  |    1777.4
                          f32 B=64, M=128, H=16, K=64    |       670.4    |      649.5  |    1968.9
                          f16 B=64, M=128, H=16, K=128   |      1437.1    |     1417.5  |    2529.1
                          f32 B=64, M=128, H=16, K=128   |      1474.0    |     1437.2  |    2821.2
                          f16 B=64, M=128, H=16, K=256   |      2691.9    |     2680.2  |    4523.1
                          f32 B=64, M=128, H=16, K=256   |      2732.6    |     2752.2  |    5524.1
                          f16 B=64, M=512, H=16, K=16    |      4925.9    |     4880.6  |   16602.6
                          f32 B=64, M=512, H=16, K=16    |      4507.4    |     4416.0  |   18778.5
                          f16 B=64, M=512, H=16, K=32    |      5814.8    |     5769.7  |   18066.7
                          f32 B=64, M=512, H=16, K=32    |      5333.0    |     5192.4  |   19832.4
                          f16 B=64, M=512, H=16, K=64    |      7628.9    |     7508.1  |   21208.7
                          f32 B=64, M=512, H=16, K=64    |      7050.8    |     6824.1  |   22507.2
                          f16 B=64, M=512, H=16, K=128   |     15358.9    |    15071.8  |   27390.6
                          f32 B=64, M=512, H=16, K=128   |     14997.9    |    14547.0  |   29146.1
                          f16 B=64, M=512, H=16, K=256   |     30033.7    |    30054.6  |   46986.3
                          f32 B=64, M=512, H=16, K=256   |     29639.1    |    29807.5  |   50467.8
                          f16 B=64, M=1024, H=16, K=16   |     18163.2    |    17881.9  |   67019.4
                          f32 B=64, M=1024, H=16, K=16   |     16432.6    |    16107.3  |   79155.4
                          f16 B=64, M=1024, H=16, K=32   |     21446.6    |    21207.7  |   72235.5
                          f32 B=64, M=1024, H=16, K=32   |     19618.2    |    19099.1  |   82273.1
                          f16 B=64, M=1024, H=16, K=64   |     28035.7    |    27819.1  |   83763.2
                          f32 B=64, M=1024, H=16, K=64   |     25409.8    |    24771.3  |   91611.2
                          f16 B=64, M=1024, H=16, K=128  |     55941.3    |    55023.8  |  104599.3
                          f32 B=64, M=1024, H=16, K=128  |     53833.8    |    52263.7  |          
                          f16 B=64, M=1024, H=16, K=256  |    111941.1    |   111388.9  |  175503.2
                          f32 B=64, M=1024, H=16, K=256  |    110710.1    |   109982.0  |          
  (Tesla_V100_SXM2_16GB)  f16 B=384, M=197, H=1, K=88    |       232.7    |      210.6  |     643.6
                          f32 B=384, M=197, H=1, K=88    |       565.9    |      545.5  |    1029.5
                          f16 B=384, M=197, H=1, K=80    |       224.5    |      202.4  |     623.9
                          f32 B=384, M=197, H=1, K=80    |       544.6    |      526.7  |     999.5
                          f16 B=384, M=197, H=1, K=64    |       168.3    |      144.6  |     515.1
                          f32 B=384, M=197, H=1, K=64    |       411.5    |      398.2  |     857.8
                          f16 B=1024, M=197, H=1, K=88   |       613.9    |      552.3  |    1661.8
                          f32 B=1024, M=197, H=1, K=88   |      1475.2    |     1438.1  |    2720.2
                          f16 B=1024, M=197, H=1, K=80   |       588.1    |      530.0  |    1605.3
                          f32 B=1024, M=197, H=1, K=80   |      1422.2    |     1378.9  |    2626.9
                          f16 B=1024, M=197, H=1, K=64   |       428.4    |      368.1  |    1316.1
                          f32 B=1024, M=197, H=1, K=64   |      1050.0    |     1024.9  |    2240.6
                          f16 B=512, M=197, H=1, K=80    |       297.9    |      270.4  |     823.3
                          f32 B=512, M=197, H=1, K=80    |       722.2    |      704.4  |    1332.7
                          f16 B=32, M=197, H=16, K=80    |       305.2    |      278.4  |     972.7
                          f32 B=32, M=197, H=16, K=80    |       735.6    |      720.9  |    1568.6
                          f16 B=32, M=197, H=16, K=64    |       227.2    |      192.9  |     801.1
                          f32 B=32, M=197, H=16, K=64    |       542.1    |      529.1  |    1325.7
                          f16 B=32, M=197, H=16, K=128   |       351.3    |      318.0  |    1146.5
                          f32 B=32, M=197, H=16, K=128   |       939.7    |      917.5  |    2045.7
                          f16 B=256, M=197, H=1, K=88    |       159.2    |      146.0  |     442.7
                          f32 B=256, M=197, H=1, K=88    |       390.0    |      381.0  |     706.5
                          f16 B=16, M=197, H=16, K=88    |       159.6    |      146.8  |     526.1
                          f32 B=16, M=197, H=16, K=88    |       392.7    |      383.3  |     839.9
                          f16 B=16, M=197, H=16, K=64    |       130.4    |      151.9  |     439.0
                          f32 B=16, M=197, H=16, K=64    |       286.9    |      280.0  |     707.9
                          f16 B=16, M=197, H=16, K=128   |       181.3    |      164.5  |     598.4
                          f32 B=16, M=197, H=16, K=128   |       484.0    |      472.0  |    1044.9
                          f16 B=1, M=4096, H=160, K=128  |     19554.6    |    18344.5  |          
                          f32 B=1, M=4096, H=160, K=128  |     70685.5    |    71344.4  |          
                          f16 B=2, M=4096, H=160, K=128  |     39009.3    |    36563.5  |          
                          f32 B=2, M=4096, H=160, K=128  |    141880.7    |   141996.6  |          
                          f16 B=1, M=8192, H=160, K=128  |     77765.8    |    72900.7  |          
                          f32 B=1, M=8192, H=160, K=128  |    284179.1    |   284758.3  |          
                          f16 B=2, M=8192, H=160, K=128  |    155292.4    |   145710.6  |          
                          f32 B=2, M=8192, H=160, K=128  |    569032.7    |   568170.0  |          
                          f16 B=1024, M=82, H=8, K=64    |      1142.2    |      961.6  |    2798.5
                          f32 B=1024, M=82, H=8, K=64    |      2435.5    |     2373.1  |    4951.5
                          f16 B=150, M=256, H=16, K=64   |      1106.5    |      926.2  |    3821.3
                          f32 B=150, M=256, H=16, K=64   |      2665.6    |     2624.9  |    7918.7
                          f16 B=64, M=256, H=12, K=64    |       362.8    |      303.3  |    1254.6
                          f32 B=64, M=256, H=12, K=64    |       877.9    |      858.8  |    2645.0
                          f16 B=1, M=4096, H=16, K=40    |      1222.8    |     1072.2  |    5996.5
                          f32 B=1, M=4096, H=16, K=40    |      3119.6    |     3053.7  |   11513.7
                          f16 B=1, M=16384, H=16, K=40   |     19529.5    |    15403.6  |          
                          f32 B=1, M=16384, H=16, K=40   |     45840.5    |    44929.7  |          
                          f16 B=256, M=4096, H=16, K=64  |    272770.3    |   240596.0  |          
                          f16 B=16, M=128, H=16, K=16    |       130.3    |      151.9  |     351.0
                          f32 B=16, M=128, H=16, K=16    |       126.6    |      148.8  |     344.7
                          f16 B=16, M=128, H=16, K=32    |       128.7    |      177.0  |     394.4
                          f32 B=16, M=128, H=16, K=32    |       128.6    |      184.5  |     340.0
                          f16 B=16, M=128, H=16, K=64    |       126.5    |      145.4  |     351.3
                          f32 B=16, M=128, H=16, K=64    |       127.4    |      171.4  |     352.3
                          f16 B=16, M=128, H=16, K=128   |       130.2    |      171.6  |     345.0
                          f32 B=16, M=128, H=16, K=128   |       203.7    |      196.7  |     503.1
                          f16 B=16, M=128, H=16, K=256   |       149.8    |      156.8  |     388.2
                          f32 B=16, M=128, H=16, K=256   |       389.7    |      388.2  |     900.9
                          f16 B=16, M=512, H=16, K=16    |       289.8    |      261.1  |    1191.4
                          f32 B=16, M=512, H=16, K=16    |       654.5    |      642.7  |    2520.3
                          f16 B=16, M=512, H=16, K=32    |       305.2    |      268.5  |    1287.8
                          f32 B=16, M=512, H=16, K=32    |       769.2    |      748.9  |    2752.7
                          f16 B=16, M=512, H=16, K=64    |       387.4    |      327.2  |    1437.7
                          f32 B=16, M=512, H=16, K=64    |      1032.8    |     1004.8  |    3000.1
                          f16 B=16, M=512, H=16, K=128   |       670.9    |      614.1  |    1757.2
                          f32 B=16, M=512, H=16, K=128   |      2046.9    |     2001.5  |    4532.4
                          f16 B=16, M=512, H=16, K=256   |      1609.7    |     1525.0  |    2612.2
                          f32 B=16, M=512, H=16, K=256   |      4149.4    |     4174.8  |    7621.3
                          f16 B=16, M=1024, H=16, K=16   |       969.8    |      876.6  |    5043.8
                          f32 B=16, M=1024, H=16, K=16   |      2321.6    |     2286.3  |   10962.5
                          f16 B=16, M=1024, H=16, K=32   |      1012.8    |      908.2  |    5179.5
                          f32 B=16, M=1024, H=16, K=32   |      2739.2    |     2674.6  |   11440.7
                          f16 B=16, M=1024, H=16, K=64   |      1269.0    |     1100.6  |    5472.6
                          f32 B=16, M=1024, H=16, K=64   |      3686.7    |     3600.6  |   12196.1
                          f16 B=16, M=1024, H=16, K=128  |      2212.9    |     2062.6  |    6217.9
                          f32 B=16, M=1024, H=16, K=128  |      7289.8    |     7193.1  |   17906.8
                          f16 B=16, M=1024, H=16, K=256  |      5994.7    |     5683.6  |    8804.5
                          f32 B=16, M=1024, H=16, K=256  |     15270.4    |    15334.5  |   29569.0
                          f16 B=64, M=128, H=16, K=16    |       137.1    |      148.6  |     381.3
                          f32 B=64, M=128, H=16, K=16    |       236.4    |      230.5  |     689.7
                          f16 B=64, M=128, H=16, K=32    |       139.1    |      144.0  |     451.7
                          f32 B=64, M=128, H=16, K=32    |       279.3    |      270.6  |     813.7
                          f16 B=64, M=128, H=16, K=64    |       176.7    |      153.4  |     587.2
                          f32 B=64, M=128, H=16, K=64    |       375.8    |      370.7  |    1085.4
                          f16 B=64, M=128, H=16, K=128   |       307.3    |      284.4  |     862.4
                          f32 B=64, M=128, H=16, K=128   |       764.7    |      740.5  |    1720.8
                          f16 B=64, M=128, H=16, K=256   |       563.5    |      529.9  |    1408.8
                          f32 B=64, M=128, H=16, K=256   |      1463.0    |     1460.1  |    2982.0
                          f16 B=64, M=512, H=16, K=16    |      1103.8    |      987.3  |    4603.3
                          f32 B=64, M=512, H=16, K=16    |      2448.3    |     2408.6  |   10136.8
                          f16 B=64, M=512, H=16, K=32    |      1171.5    |     1027.3  |    4933.2
                          f32 B=64, M=512, H=16, K=32    |      2936.2    |     2862.2  |   11037.5
                          f16 B=64, M=512, H=16, K=64    |      1486.7    |     1255.5  |    5622.9
                          f32 B=64, M=512, H=16, K=64    |      3935.4    |     3861.0  |   12125.9
                          f16 B=64, M=512, H=16, K=128   |      2636.8    |     2405.2  |    6901.4
                          f32 B=64, M=512, H=16, K=128   |      7993.4    |     7843.0  |   18276.2
                          f16 B=64, M=512, H=16, K=256   |      6320.3    |     5983.3  |   10359.0
                          f32 B=64, M=512, H=16, K=256   |     16307.9    |    16371.9  |   31659.4
                          f16 B=64, M=1024, H=16, K=16   |      3766.5    |     3408.7  |   20126.8
                          f32 B=64, M=1024, H=16, K=16   |      8982.6    |     8819.8  |   43940.0
                          f16 B=64, M=1024, H=16, K=32   |      3957.5    |     3531.5  |   20639.3
                          f32 B=64, M=1024, H=16, K=32   |     10728.4    |    10498.0  |   46144.9
                          f16 B=64, M=1024, H=16, K=64   |      4944.6    |     4277.0  |   22159.0
                          f32 B=64, M=1024, H=16, K=64   |     14440.2    |    14147.1  |   50392.1
                          f16 B=64, M=1024, H=16, K=128  |      8797.3    |     8174.7  |   24830.6
                          f32 B=64, M=1024, H=16, K=128  |     28974.2    |    28411.4  |          
                          f16 B=64, M=1024, H=16, K=256  |     23887.4    |    22481.4  |   35353.6
                          f32 B=64, M=1024, H=16, K=256  |     60551.0    |    60469.3  |          

Times are in microseconds (us).
V100/P100 backward
[----------------------- attention backward (attn_bias=<class 'NoneType'>) -----------------------] 
                                                         |  pr587_d1b0fa  |     main    |  vanilla 
1 threads: ----------------------------------------------------------------------------------------
  (Quadro_GP100)          f16 B=384, M=197, H=1, K=88    |      6649.2    |     6662.4  |    3591.8
                          f32 B=384, M=197, H=1, K=88    |      9755.5    |     9584.0  |    4337.0
                          f16 B=384, M=197, H=1, K=80    |      6301.1    |     6192.6  |    3437.6
                          f32 B=384, M=197, H=1, K=80    |      9321.0    |     9158.2  |    4107.5
                          f16 B=384, M=197, H=1, K=64    |      3600.8    |     3518.0  |    2927.8
                          f32 B=384, M=197, H=1, K=64    |      6319.8    |     6136.2  |    3451.8
                          f16 B=1024, M=197, H=1, K=88   |     16382.6    |    16310.1  |    9852.8
                          f32 B=1024, M=197, H=1, K=88   |     26252.9    |    25756.1  |   12151.9
                          f16 B=1024, M=197, H=1, K=80   |     15587.4    |    15522.3  |    9330.4
                          f32 B=1024, M=197, H=1, K=80   |     25072.9    |    24578.1  |   11356.0
                          f16 B=1024, M=197, H=1, K=64   |      8996.6    |     8935.0  |    7719.5
                          f32 B=1024, M=197, H=1, K=64   |     17056.3    |    16599.7  |    9475.8
                          f16 B=512, M=197, H=1, K=80    |      7968.5    |     7927.7  |    4632.1
                          f32 B=512, M=197, H=1, K=80    |     13051.1    |    12806.3  |    5525.4
                          f16 B=32, M=197, H=16, K=80    |      8153.0    |     8129.2  |    4891.3
                          f32 B=32, M=197, H=16, K=80    |     13046.0    |    12774.1  |    5811.9
                          f16 B=32, M=197, H=16, K=64    |      4564.6    |     4506.7  |    4068.1
                          f32 B=32, M=197, H=16, K=64    |      8907.0    |     8681.2  |    4838.7
                          f16 B=32, M=197, H=16, K=128   |      9555.2    |     9626.9  |    5991.5
                          f32 B=32, M=197, H=16, K=128   |     15811.5    |    15544.2  |    7540.1
                          f16 B=256, M=197, H=1, K=88    |      4806.2    |     4769.0  |    2451.1
                          f32 B=256, M=197, H=1, K=88    |      6853.3    |     6682.9  |    2906.0
                          f16 B=16, M=197, H=16, K=88    |      4804.7    |     4781.1  |    2549.4
                          f32 B=16, M=197, H=16, K=88    |      6771.8    |     6629.7  |    3063.5
                          f16 B=16, M=197, H=16, K=64    |      2626.1    |     2609.3  |    2042.1
                          f32 B=16, M=197, H=16, K=64    |      4435.4    |     4322.7  |    2445.7
                          f16 B=16, M=197, H=16, K=128   |      5473.2    |     5432.6  |    3014.7
                          f32 B=16, M=197, H=16, K=128   |      8056.4    |     7794.2  |    3670.1
                          f16 B=1, M=4096, H=160, K=128  |   1040725.6    |  1033138.6  |          
                          f32 B=1, M=4096, H=160, K=128  |   1263982.9    |  1264717.2  |          
                          f16 B=2, M=4096, H=160, K=128  |   1693711.8    |  1689231.2  |          
                          f32 B=2, M=4096, H=160, K=128  |   2513370.6    |  2511754.6  |          
                          f16 B=1, M=8192, H=160, K=128  |   4147490.4    |  4110718.8  |          
                          f32 B=1, M=8192, H=160, K=128  |   5047073.2    |  5051277.9  |          
                          f16 B=2, M=8192, H=160, K=128  |   6774949.8    |  6751365.7  |          
                          f16 B=1024, M=82, H=8, K=64    |     22068.6    |    22967.1  |   18046.4
                          f32 B=1024, M=82, H=8, K=64    |     47745.6    |    43698.8  |   22978.7
                          f16 B=150, M=256, H=16, K=64   |     23965.7    |    23440.9  |   24551.6
                          f32 B=150, M=256, H=16, K=64   |     37510.4    |    37480.4  |   32205.0
                          f16 B=64, M=256, H=12, K=64    |      7592.6    |     7491.8  |    7716.8
                          f32 B=64, M=256, H=12, K=64    |     12175.9    |    12214.8  |    9890.6
                          f16 B=1, M=4096, H=16, K=40    |    137900.8    |   135707.0  |   29317.2
                          f32 B=1, M=4096, H=16, K=40    |    144544.5    |   145042.0  |   37192.7
                          f16 B=1, M=16384, H=16, K=40   |   2201518.0    |  2150814.2  |          
                          f32 B=1, M=16384, H=16, K=40   |   2314553.1    |  2295614.2  |          
                          f16 B=16, M=128, H=16, K=16    |       521.6    |      517.6  |     572.7
                          f32 B=16, M=128, H=16, K=16    |       656.5    |      652.2  |     691.7
                          f16 B=16, M=128, H=16, K=32    |       606.5    |      601.6  |     677.2
                          f32 B=16, M=128, H=16, K=32    |       816.9    |      813.6  |     828.2
                          f16 B=16, M=128, H=16, K=64    |       788.8    |      778.9  |     891.9
                          f32 B=16, M=128, H=16, K=64    |      1160.3    |     1163.4  |    1088.5
                          f16 B=16, M=128, H=16, K=128   |      1614.2    |     1607.0  |    1337.7
                          f32 B=16, M=128, H=16, K=128   |      2250.5    |     2259.3  |    1666.7
                          f16 B=16, M=128, H=16, K=256   |      4094.0    |     4062.0  |    2507.3
                          f32 B=16, M=128, H=16, K=256   |      4644.0    |     4647.4  |    3356.5
                          f16 B=16, M=512, H=16, K=16    |      7888.2    |     7866.8  |    6958.1
                          f32 B=16, M=512, H=16, K=16    |      9732.6    |     9792.9  |    8610.8
                          f16 B=16, M=512, H=16, K=32    |      9125.4    |     9111.2  |    7500.4
                          f32 B=16, M=512, H=16, K=32    |     11366.6    |    11388.9  |    9295.5
                          f16 B=16, M=512, H=16, K=64    |     11362.6    |    11402.2  |    8911.2
                          f32 B=16, M=512, H=16, K=64    |     16133.1    |    16094.5  |   11084.9
                          f16 B=16, M=512, H=16, K=128   |     24268.8    |    24449.4  |   12629.6
                          f32 B=16, M=512, H=16, K=128   |     32198.8    |    32234.3  |   15264.6
                          f16 B=16, M=512, H=16, K=256   |     52845.5    |    52619.0  |   23373.4
                          f32 B=16, M=512, H=16, K=256   |     65542.0    |    65241.9  |   27094.9
                          f16 B=16, M=1024, H=16, K=16   |     31862.8    |    31510.4  |   26565.4
                          f32 B=16, M=1024, H=16, K=16   |     38406.7    |    38369.3  |   32614.4
                          f16 B=16, M=1024, H=16, K=32   |     36278.0    |    36294.3  |   28420.7
                          f32 B=16, M=1024, H=16, K=32   |     44660.8    |    44377.6  |   35432.3
                          f16 B=16, M=1024, H=16, K=64   |     45482.3    |    45366.8  |   32269.4
                          f32 B=16, M=1024, H=16, K=64   |     62820.0    |    62745.1  |   39776.9
                          f16 B=16, M=1024, H=16, K=128  |     99333.6    |    99353.4  |   43627.4
                          f32 B=16, M=1024, H=16, K=128  |    127208.8    |   127366.1  |   51474.4
                          f16 B=16, M=1024, H=16, K=256  |    205069.4    |   204810.4  |   81201.6
                          f32 B=16, M=1024, H=16, K=256  |    257239.2    |   258126.0  |   92288.2
                          f16 B=64, M=128, H=16, K=16    |      1748.9    |     1730.2  |    2117.6
                          f32 B=64, M=128, H=16, K=16    |      2428.1    |     2428.4  |    2576.5
                          f16 B=64, M=128, H=16, K=32    |      2069.1    |     2070.4  |    2487.9
                          f32 B=64, M=128, H=16, K=32    |      3082.5    |     3084.5  |    3078.0
                          f16 B=64, M=128, H=16, K=64    |      2722.0    |     2718.5  |    3317.9
                          f32 B=64, M=128, H=16, K=64    |      4423.2    |     4421.9  |    4237.9
                          f16 B=64, M=128, H=16, K=128   |      5634.6    |     5646.9  |    5284.4
                          f32 B=64, M=128, H=16, K=128   |      8567.7    |     8635.8  |    6958.5
                          f16 B=64, M=128, H=16, K=256   |     13974.5    |    13961.0  |   10316.2
                          f32 B=64, M=128, H=16, K=256   |     17341.4    |    17417.3  |   13584.2
                          f16 B=64, M=512, H=16, K=16    |     26968.8    |    26936.5  |   27427.8
                          f32 B=64, M=512, H=16, K=16    |     36557.4    |    36403.9  |   33753.3
                          f16 B=64, M=512, H=16, K=32    |     31473.7    |    31542.1  |   30266.4
                          f32 B=64, M=512, H=16, K=32    |     43069.9    |    42935.1  |   37398.0
                          f16 B=64, M=512, H=16, K=64    |     39785.8    |    39718.3  |   36109.8
                          f32 B=64, M=512, H=16, K=64    |     61687.2    |    61577.7  |   43677.3
                          f16 B=64, M=512, H=16, K=128   |     85037.4    |    86608.8  |   51294.6
                          f32 B=64, M=512, H=16, K=128   |    123019.4    |   123085.0  |   61843.3
                          f16 B=64, M=512, H=16, K=256   |    182002.3    |   179902.4  |   99364.5
                          f32 B=64, M=512, H=16, K=256   |    250282.2    |   250051.9  |  111501.9
                          f16 B=64, M=1024, H=16, K=16   |    108675.3    |   107724.9  |  106757.7
                          f32 B=64, M=1024, H=16, K=16   |    144402.4    |   144482.4  |          
                          f16 B=64, M=1024, H=16, K=32   |    125648.2    |   124733.4  |  114732.4
                          f32 B=64, M=1024, H=16, K=32   |    168998.9    |   168876.8  |          
                          f16 B=64, M=1024, H=16, K=64   |    157730.3    |   157059.1  |  131304.1
                          f32 B=64, M=1024, H=16, K=64   |    241102.4    |   241476.9  |          
                          f16 B=64, M=1024, H=16, K=128  |    335655.4    |   334298.9  |  179659.1
                          f32 B=64, M=1024, H=16, K=128  |    480850.5    |   483706.8  |          
                          f16 B=64, M=1024, H=16, K=256  |    695390.1    |   692904.0  |          
                          f32 B=64, M=1024, H=16, K=256  |    983258.7    |   982044.1  |          
  (Tesla_V100_SXM2_16GB)  f16 B=384, M=197, H=1, K=88    |      1876.9    |     1809.0  |    1374.5
                          f32 B=384, M=197, H=1, K=88    |      4430.6    |     4340.1  |    2247.7
                          f16 B=384, M=197, H=1, K=80    |      1805.6    |     1732.4  |    1282.2
                          f32 B=384, M=197, H=1, K=80    |      4073.2    |     3974.2  |    2163.4
                          f16 B=384, M=197, H=1, K=64    |      1173.4    |     1134.9  |    1044.1
                          f32 B=384, M=197, H=1, K=64    |      2756.2    |     2689.9  |    1741.7
                          f16 B=1024, M=197, H=1, K=88   |      4907.4    |     4707.0  |    3724.7
                          f32 B=1024, M=197, H=1, K=88   |     10873.8    |    10546.7  |    6061.7
                          f16 B=1024, M=197, H=1, K=80   |      4716.0    |     4523.2  |    3330.1
                          f32 B=1024, M=197, H=1, K=80   |      9953.5    |     9609.2  |    5719.5
                          f16 B=1024, M=197, H=1, K=64   |      2877.5    |     2799.2  |    2675.3
                          f32 B=1024, M=197, H=1, K=64   |      6789.0    |     6586.8  |    4507.3
                          f16 B=512, M=197, H=1, K=80    |      2476.6    |     2380.2  |    1684.0
                          f32 B=512, M=197, H=1, K=80    |      5416.2    |     5267.1  |    2874.5
                          f16 B=32, M=197, H=16, K=80    |      2488.7    |     2393.9  |    1800.3
                          f32 B=32, M=197, H=16, K=80    |      5528.6    |     5392.4  |    3029.9
                          f16 B=32, M=197, H=16, K=64    |      1601.5    |     1558.5  |    1450.9
                          f32 B=32, M=197, H=16, K=64    |      3751.1    |     3636.3  |    2410.5
                          f16 B=32, M=197, H=16, K=128   |      2886.1    |     2782.4  |    2211.6
                          f32 B=32, M=197, H=16, K=128   |      6796.0    |     6643.3  |    4061.8
                          f16 B=256, M=197, H=1, K=88    |      1412.7    |     1357.9  |     947.7
                          f32 B=256, M=197, H=1, K=88    |      2938.8    |     2884.8  |    1533.7
                          f16 B=16, M=197, H=16, K=88    |      1395.3    |     1346.8  |     970.9
                          f32 B=16, M=197, H=16, K=88    |      2854.5    |     2801.2  |    1629.0
                          f16 B=16, M=197, H=16, K=64    |       791.4    |      766.2  |     931.7
                          f32 B=16, M=197, H=16, K=64    |      1887.5    |     1838.3  |    1287.4
                          f16 B=16, M=197, H=16, K=128   |      1565.2    |     1513.8  |    1135.0
                          f32 B=16, M=197, H=16, K=128   |      3469.7    |     3407.8  |    2034.9
                          f16 B=1, M=4096, H=160, K=128  |    173529.7    |   169073.2  |          
                          f32 B=1, M=4096, H=160, K=128  |    542426.7    |   550508.2  |          
                          f16 B=2, M=4096, H=160, K=128  |    345451.9    |   340149.9  |          
                          f32 B=2, M=4096, H=160, K=128  |   1093717.2    |  1102674.8  |          
                          f16 B=1, M=8192, H=160, K=128  |    693881.2    |   681002.9  |          
                          f32 B=1, M=8192, H=160, K=128  |   2180832.4    |  2200639.4  |          
                          f16 B=2, M=8192, H=160, K=128  |   1389384.8    |  1364914.7  |          
                          f16 B=1024, M=82, H=8, K=64    |      9417.7    |     9059.6  |    5802.6
                          f32 B=1024, M=82, H=8, K=64    |     14703.5    |    14694.7  |   11037.4
                          f16 B=150, M=256, H=16, K=64   |      5767.6    |     5693.1  |    7563.8
                          f32 B=150, M=256, H=16, K=64   |     16551.3    |    16696.3  |   16305.4
                          f16 B=64, M=256, H=12, K=64    |      1888.7    |     1852.1  |    2386.4
                          f32 B=64, M=256, H=12, K=64    |      5414.0    |     5462.4  |    4969.2
                          f16 B=1, M=4096, H=16, K=40    |     48364.8    |    47164.3  |    8362.1
                          f32 B=1, M=4096, H=16, K=40    |    113126.6    |   113058.1  |   19476.1
                          f16 B=1, M=16384, H=16, K=40   |    772504.0    |   759023.3  |          
                          f32 B=1, M=16384, H=16, K=40   |   1807712.6    |  1804493.4  |          
                          f16 B=16, M=128, H=16, K=16    |       422.2    |      476.6  |     712.1
                          f32 B=16, M=128, H=16, K=16    |       483.7    |      619.0  |     651.7
                          f16 B=16, M=128, H=16, K=32    |       411.0    |      445.6  |     776.2
                          f32 B=16, M=128, H=16, K=32    |       478.2    |      555.9  |     662.4
                          f16 B=16, M=128, H=16, K=64    |       411.6    |      517.8  |     680.9
                          f32 B=16, M=128, H=16, K=64    |       566.2    |      601.7  |     736.3
                          f16 B=16, M=128, H=16, K=128   |       417.1    |      451.1  |     686.1
                          f32 B=16, M=128, H=16, K=128   |      1105.3    |     1105.7  |    1007.0
                          f16 B=16, M=128, H=16, K=256   |      1054.3    |     1049.9  |     888.0
                          f32 B=16, M=128, H=16, K=256   |      2190.7    |     2192.3  |    1855.9
                          f16 B=16, M=512, H=16, K=16    |      1760.9    |     1731.3  |    1896.6
                          f32 B=16, M=512, H=16, K=16    |      4463.1    |     4476.5  |    4249.5
                          f16 B=16, M=512, H=16, K=32    |      1989.2    |     1948.7  |    2095.1
                          f32 B=16, M=512, H=16, K=32    |      5601.3    |     5679.8  |    4600.4
                          f16 B=16, M=512, H=16, K=64    |      2479.0    |     2448.1  |    2577.1
                          f32 B=16, M=512, H=16, K=64    |      7603.8    |     7617.6  |    5491.4
                          f16 B=16, M=512, H=16, K=128   |      5047.9    |     4891.9  |    3380.6
                          f32 B=16, M=512, H=16, K=128   |     14991.3    |    15084.0  |    8860.2
                          f16 B=16, M=512, H=16, K=256   |     13036.1    |    12952.6  |    5381.5
                          f32 B=16, M=512, H=16, K=256   |     29689.4    |    29870.0  |   16766.3
                          f16 B=16, M=1024, H=16, K=16   |      6913.3    |     6817.0  |    6986.0
                          f32 B=16, M=1024, H=16, K=16   |     18030.7    |    18132.2  |   16098.9
                          f16 B=16, M=1024, H=16, K=32   |      7679.2    |     7568.5  |    7399.8
                          f32 B=16, M=1024, H=16, K=32   |     21985.2    |    22038.8  |   17093.0
                          f16 B=16, M=1024, H=16, K=64   |      9503.7    |     9320.6  |    8623.2
                          f32 B=16, M=1024, H=16, K=64   |     29601.6    |    29998.6  |   20238.1
                          f16 B=16, M=1024, H=16, K=128  |     19456.8    |    18972.4  |   10503.1
                          f32 B=16, M=1024, H=16, K=128  |     58312.3    |    58953.5  |   33141.1
                          f16 B=16, M=1024, H=16, K=256  |     49957.8    |    49804.3  |   17122.0
                          f32 B=16, M=1024, H=16, K=256  |    116490.0    |   116887.9  |   60004.3
                          f16 B=64, M=128, H=16, K=16    |       445.7    |      509.3  |     673.2
                          f32 B=64, M=128, H=16, K=16    |      1037.8    |     1029.9  |    1234.0
                          f16 B=64, M=128, H=16, K=32    |       537.8    |      546.7  |     813.7
                          f32 B=64, M=128, H=16, K=32    |      1408.0    |     1408.3  |    1533.5
                          f16 B=64, M=128, H=16, K=64    |       764.5    |      745.2  |    1186.2
                          f32 B=64, M=128, H=16, K=64    |      2017.9    |     2019.2  |    2154.9
                          f16 B=64, M=128, H=16, K=128   |      1470.0    |     1417.3  |    1916.9
                          f32 B=64, M=128, H=16, K=128   |      3964.0    |     3950.5  |    3779.3
                          f16 B=64, M=128, H=16, K=256   |      3819.3    |     3808.4  |    3450.8
                          f32 B=64, M=128, H=16, K=256   |      7980.9    |     7983.2  |    7252.1
                          f16 B=64, M=512, H=16, K=16    |      6283.7    |     6187.6  |    7461.6
                          f32 B=64, M=512, H=16, K=16    |     16196.5    |    16328.7  |   16558.3
                          f16 B=64, M=512, H=16, K=32    |      7137.1    |     7026.3  |    8314.3
                          f32 B=64, M=512, H=16, K=32    |     20467.5    |    20583.0  |   18328.0
                          f16 B=64, M=512, H=16, K=64    |      9254.8    |     9087.4  |   10425.2
                          f32 B=64, M=512, H=16, K=64    |     27413.4    |    27696.9  |   22791.7
                          f16 B=64, M=512, H=16, K=128   |     18015.8    |    17574.7  |   14673.6
                          f32 B=64, M=512, H=16, K=128   |     54501.0    |    54678.0  |   39872.5
                          f16 B=64, M=512, H=16, K=256   |     47562.4    |    47507.7  |   26896.4
                          f32 B=64, M=512, H=16, K=256   |    109645.9    |   109608.7  |   75908.0
                          f16 B=64, M=1024, H=16, K=16   |     24843.0    |    24447.9  |   28512.3
                          f32 B=64, M=1024, H=16, K=16   |     65052.1    |    65064.6  |          
                          f16 B=64, M=1024, H=16, K=32   |     27623.2    |    27254.5  |   30504.4
                          f32 B=64, M=1024, H=16, K=32   |     79865.6    |    80142.4  |          
                          f16 B=64, M=1024, H=16, K=64   |     35179.7    |    34677.9  |   37021.6
                          f32 B=64, M=1024, H=16, K=64   |    108293.9    |   108919.4  |          
                          f16 B=64, M=1024, H=16, K=128  |     70080.5    |    68389.8  |   49203.3
                          f32 B=64, M=1024, H=16, K=128  |    213234.4    |   214535.3  |          
                          f16 B=64, M=1024, H=16, K=256  |    183460.3    |   183195.8  |          
                          f32 B=64, M=1024, H=16, K=256  |    423675.7    |   425804.3  |          

Times are in microseconds (us).

[----- attention backward (attn_bias=<class 'xformers.ops.fmha.common.LowerTriangularMask'>) -----]
                                                         |  pr587_d1b0fa  |     main    |  vanilla 
1 threads: ----------------------------------------------------------------------------------------
  (Quadro_GP100)          f16 B=384, M=197, H=1, K=88    |      4264.6    |     4252.9  |    3568.0
                          f32 B=384, M=197, H=1, K=88    |      6613.5    |     6516.4  |    4266.8
                          f16 B=384, M=197, H=1, K=80    |      4040.2    |     4024.4  |    3422.3
                          f32 B=384, M=197, H=1, K=80    |      6284.2    |     6216.6  |    4078.8
                          f16 B=384, M=197, H=1, K=64    |      2441.7    |     2367.5  |    2914.9
                          f32 B=384, M=197, H=1, K=64    |      4439.2    |     4350.7  |    3435.6
                          f16 B=1024, M=197, H=1, K=88   |     10588.4    |    10541.5  |    9757.7
                          f32 B=1024, M=197, H=1, K=88   |     17800.8    |    17610.3  |   12024.4
                          f16 B=1024, M=197, H=1, K=80   |     10032.8    |     9913.7  |    9288.2
                          f32 B=1024, M=197, H=1, K=80   |     16952.6    |    16804.2  |   11179.8
                          f16 B=1024, M=197, H=1, K=64   |      6065.6    |     5802.8  |    7663.7
                          f32 B=1024, M=197, H=1, K=64   |     12004.7    |    11778.8  |    9444.2
                          f16 B=512, M=197, H=1, K=80    |      5083.8    |     5037.0  |    4611.7
                          f32 B=512, M=197, H=1, K=80    |      8797.7    |     8749.3  |    5465.3
                          f16 B=32, M=197, H=16, K=80    |      5156.5    |     5118.0  |    4819.7
                          f32 B=32, M=197, H=16, K=80    |      8773.3    |     8713.1  |    5732.0
                          f16 B=32, M=197, H=16, K=64    |      3089.8    |     2979.9  |    4031.1
                          f32 B=32, M=197, H=16, K=64    |      6213.4    |     6085.4  |    4790.2
                          f16 B=32, M=197, H=16, K=128   |      6068.0    |     6053.1  |    5955.1
                          f32 B=32, M=197, H=16, K=128   |     10743.0    |    10682.1  |    7341.6
                          f16 B=256, M=197, H=1, K=88    |      3038.4    |     3074.8  |    2440.7
                          f32 B=256, M=197, H=1, K=88    |      4617.1    |     4561.4  |    2860.9
                          f16 B=16, M=197, H=16, K=88    |      2990.5    |     3082.1  |    2523.8
                          f32 B=16, M=197, H=16, K=88    |      4572.3    |     4517.3  |    3017.4
                          f16 B=16, M=197, H=16, K=64    |      1782.1    |     1774.8  |    2029.5
                          f32 B=16, M=197, H=16, K=64    |      3128.9    |     3061.0  |    2429.9
                          f16 B=16, M=197, H=16, K=128   |      3495.7    |     3489.2  |    3004.0
                          f32 B=16, M=197, H=16, K=128   |      5415.1    |     5322.8  |    3613.6
                          f16 B=1, M=4096, H=160, K=128  |    535384.8    |   533256.0  |          
                          f32 B=1, M=4096, H=160, K=128  |    642492.0    |   644655.4  |          
                          f16 B=2, M=4096, H=160, K=128  |    866599.3    |   868892.7  |          
                          f32 B=2, M=4096, H=160, K=128  |   1285173.2    |  1281898.4  |          
                          f16 B=1, M=8192, H=160, K=128  |   2100152.9    |  2085529.2  |          
                          f32 B=1, M=8192, H=160, K=128  |   2558146.5    |  2548207.5  |          
                          f16 B=2, M=8192, H=160, K=128  |   3421179.4    |  3437383.9  |          
                          f16 B=1024, M=82, H=8, K=64    |     19812.3    |    20984.0  |   18067.6
                          f32 B=1024, M=82, H=8, K=64    |     40503.9    |    37605.3  |   22806.3
                          f16 B=150, M=256, H=16, K=64   |     15389.6    |    15327.7  |   24392.8
                          f32 B=150, M=256, H=16, K=64   |     24777.0    |    24784.3  |   31835.0
                          f16 B=64, M=256, H=12, K=64    |      4938.7    |     4922.7  |    7678.4
                          f32 B=64, M=256, H=12, K=64    |      8036.5    |     8068.8  |    9808.9
                          f16 B=1, M=4096, H=16, K=40    |     70662.0    |    69262.6  |   29178.9
                          f32 B=1, M=4096, H=16, K=40    |     73890.4    |    73372.5  |   37290.0
                          f16 B=1, M=16384, H=16, K=40   |   1108681.7    |  1082724.4  |          
                          f32 B=1, M=16384, H=16, K=40   |   1168963.1    |  1156356.8  |          
                          f16 B=16, M=128, H=16, K=16    |       387.4    |      403.8  |     573.1
                          f32 B=16, M=128, H=16, K=16    |       515.1    |      514.4  |     693.1
                          f16 B=16, M=128, H=16, K=32    |       458.5    |      454.5  |     670.3
                          f32 B=16, M=128, H=16, K=32    |       647.3    |      642.2  |     821.0
                          f16 B=16, M=128, H=16, K=64    |       617.7    |      613.9  |     885.4
                          f32 B=16, M=128, H=16, K=64    |       919.2    |      922.2  |    1080.5
                          f16 B=16, M=128, H=16, K=128   |      1241.0    |     1239.4  |    1329.3
                          f32 B=16, M=128, H=16, K=128   |      1770.4    |     1777.7  |    1662.2
                          f16 B=16, M=128, H=16, K=256   |      3375.3    |     3354.0  |    2500.6
                          f32 B=16, M=128, H=16, K=256   |      3634.5    |     3651.8  |    3291.1
                          f16 B=16, M=512, H=16, K=16    |      4478.1    |     4427.2  |    6857.0
                          f32 B=16, M=512, H=16, K=16    |      5532.7    |     5531.2  |    8419.4
                          f16 B=16, M=512, H=16, K=32    |      5225.1    |     5193.8  |    7481.4
                          f32 B=16, M=512, H=16, K=32    |      6641.7    |     6608.2  |    9166.2
                          f16 B=16, M=512, H=16, K=64    |      6619.1    |     6536.6  |    8855.1
                          f32 B=16, M=512, H=16, K=64    |      9451.8    |     9423.3  |   10849.5
                          f16 B=16, M=512, H=16, K=128   |     13950.4    |    13962.6  |   12345.1
                          f32 B=16, M=512, H=16, K=128   |     18661.3    |    18679.2  |   15003.2
                          f16 B=16, M=512, H=16, K=256   |     31377.2    |    31425.8  |   23147.7
                          f32 B=16, M=512, H=16, K=256   |     37663.3    |    37686.5  |   26873.0
                          f16 B=16, M=1024, H=16, K=16   |     17060.8    |    16928.6  |   26395.9
                          f32 B=16, M=1024, H=16, K=16   |     20645.6    |    20647.1  |   32762.1
                          f16 B=16, M=1024, H=16, K=32   |     19545.5    |    19584.6  |   28100.2
                          f32 B=16, M=1024, H=16, K=32   |     24112.3    |    24153.9  |   35231.2
                          f16 B=16, M=1024, H=16, K=64   |     25042.2    |    24358.8  |   31949.6
                          f32 B=16, M=1024, H=16, K=64   |     34145.3    |    34135.4  |   39247.4
                          f16 B=16, M=1024, H=16, K=128  |     52546.4    |    52553.6  |   42857.2
                          f32 B=16, M=1024, H=16, K=128  |     69014.6    |    68490.7  |   50818.2
                          f16 B=16, M=1024, H=16, K=256  |    113898.4    |   113179.1  |   79246.2
                          f32 B=16, M=1024, H=16, K=256  |    138852.1    |   138958.7  |   90470.1
                          f16 B=64, M=128, H=16, K=16    |      1317.8    |     1313.4  |    2093.5
                          f32 B=64, M=128, H=16, K=16    |      1911.1    |     1912.5  |    2551.8
                          f16 B=64, M=128, H=16, K=32    |      1604.8    |     1605.9  |    2479.4
                          f32 B=64, M=128, H=16, K=32    |      2412.4    |     2424.6  |    3055.7
                          f16 B=64, M=128, H=16, K=64    |      2129.2    |     2135.5  |    3306.4
                          f32 B=64, M=128, H=16, K=64    |      3514.7    |     3512.2  |    4184.7
                          f16 B=64, M=128, H=16, K=128   |      4342.5    |     4349.4  |    5250.9
                          f32 B=64, M=128, H=16, K=128   |      6751.9    |     6734.9  |    6860.6
                          f16 B=64, M=128, H=16, K=256   |     11425.9    |    11412.6  |   10225.4
                          f32 B=64, M=128, H=16, K=256   |     13681.5    |    13715.3  |   13386.5
                          f16 B=64, M=512, H=16, K=16    |     15392.7    |    15298.3  |   27164.6
                          f32 B=64, M=512, H=16, K=16    |     20888.0    |    20818.0  |   33373.3
                          f16 B=64, M=512, H=16, K=32    |     18209.2    |    18168.1  |   29831.6
                          f32 B=64, M=512, H=16, K=32    |     25172.4    |    25124.1  |   37340.7
                          f16 B=64, M=512, H=16, K=64    |     23142.3    |    22989.5  |   35792.1
                          f32 B=64, M=512, H=16, K=64    |     36070.6    |    36008.3  |   43156.8
                          f16 B=64, M=512, H=16, K=128   |     48709.7    |    48567.6  |   50699.7
                          f32 B=64, M=512, H=16, K=128   |     71450.2    |    70832.3  |   60779.7
                          f16 B=64, M=512, H=16, K=256   |    109410.7    |   107837.9  |   97016.9
                          f32 B=64, M=512, H=16, K=256   |    144675.3    |   144816.6  |  109729.1
                          f16 B=64, M=1024, H=16, K=16   |     58305.9    |    57449.7  |  105449.7
                          f32 B=64, M=1024, H=16, K=16   |     78111.4    |    77196.8  |          
                          f16 B=64, M=1024, H=16, K=32   |     67724.8    |    67469.2  |  113569.9
                          f32 B=64, M=1024, H=16, K=32   |     92004.9    |    92452.9  |          
                          f16 B=64, M=1024, H=16, K=64   |     85769.7    |    85985.5  |  129934.5
                          f32 B=64, M=1024, H=16, K=64   |    131474.1    |   131842.4  |          
                          f16 B=64, M=1024, H=16, K=128  |    180858.4    |   180353.2  |  176085.7
                          f32 B=64, M=1024, H=16, K=128  |    260681.5    |   261345.7  |          
                          f16 B=64, M=1024, H=16, K=256  |    383548.2    |   380736.0  |          
                          f32 B=64, M=1024, H=16, K=256  |    530545.9    |   530228.5  |          
  (Tesla_V100_SXM2_16GB)  f16 B=384, M=197, H=1, K=88    |      1556.4    |     1502.5  |    1373.5
                          f32 B=384, M=197, H=1, K=88    |      2981.9    |     2885.5  |    2234.0
                          f16 B=384, M=197, H=1, K=80    |      1492.9    |     1440.5  |    1282.0
                          f32 B=384, M=197, H=1, K=80    |      2717.2    |     2606.4  |    2148.4
                          f16 B=384, M=197, H=1, K=64    |       851.5    |      824.4  |    1044.6
                          f32 B=384, M=197, H=1, K=64    |      1956.2    |     1888.7  |    1737.6
                          f16 B=1024, M=197, H=1, K=88   |      4060.5    |     3916.4  |    3731.4
                          f32 B=1024, M=197, H=1, K=88   |      7379.3    |     7123.7  |    6025.3
                          f16 B=1024, M=197, H=1, K=80   |      3899.0    |     3751.2  |    3329.3
                          f32 B=1024, M=197, H=1, K=80   |      6679.8    |     6440.6  |    5673.2
                          f16 B=1024, M=197, H=1, K=64   |      2102.5    |     2033.2  |    2674.8
                          f32 B=1024, M=197, H=1, K=64   |      4837.4    |     4637.9  |    4491.2
                          f16 B=512, M=197, H=1, K=80    |      2051.4    |     1980.8  |    1678.7
                          f32 B=512, M=197, H=1, K=80    |      3592.4    |     3457.6  |    2856.9
                          f16 B=32, M=197, H=16, K=80    |      2040.5    |     1972.6  |    1799.2
                          f32 B=32, M=197, H=16, K=80    |      3649.2    |     3514.2  |    3015.5
                          f16 B=32, M=197, H=16, K=64    |      1164.4    |     1119.0  |    1450.0
                          f32 B=32, M=197, H=16, K=64    |      2655.3    |     2554.1  |    2415.0
                          f16 B=32, M=197, H=16, K=128   |      2367.7    |     2290.4  |    2213.7
                          f32 B=32, M=197, H=16, K=128   |      4664.7    |     4547.6  |    4023.2
                          f16 B=256, M=197, H=1, K=88    |      1174.4    |     1138.3  |     941.5
                          f32 B=256, M=197, H=1, K=88    |      1979.8    |     1926.5  |    1521.0
                          f16 B=16, M=197, H=16, K=88    |      1156.8    |     1116.3  |     970.6
                          f32 B=16, M=197, H=16, K=88    |      1894.1    |     1862.4  |    1606.0
                          f16 B=16, M=197, H=16, K=64    |       577.9    |      557.9  |     795.7
                          f32 B=16, M=197, H=16, K=64    |      1330.0    |     1275.8  |    1285.0
                          f16 B=16, M=197, H=16, K=128   |      1295.2    |     1245.8  |    1136.1
                          f32 B=16, M=197, H=16, K=128   |      2329.5    |     2292.3  |    2016.9
                          f16 B=1, M=4096, H=160, K=128  |     89527.4    |    87327.9  |          
                          f32 B=1, M=4096, H=160, K=128  |    277992.4    |   281487.9  |          
                          f16 B=2, M=4096, H=160, K=128  |    179184.1    |   176796.0  |          
                          f32 B=2, M=4096, H=160, K=128  |    560186.5    |   564008.1  |          
                          f16 B=1, M=8192, H=160, K=128  |    353299.5    |   347318.8  |          
                          f32 B=1, M=8192, H=160, K=128  |   1101264.4    |  1111934.5  |          
                          f16 B=2, M=8192, H=160, K=128  |    708640.0    |   696927.1  |          
                          f16 B=1024, M=82, H=8, K=64    |      8177.0    |     7899.4  |    5814.6
                          f32 B=1024, M=82, H=8, K=64    |     12591.5    |    12697.1  |   11005.3
                          f16 B=150, M=256, H=16, K=64   |      4030.2    |     3963.8  |    7590.3
                          f32 B=150, M=256, H=16, K=64   |     11025.0    |    11184.3  |   16357.7
                          f16 B=64, M=256, H=12, K=64    |      1317.7    |     1293.5  |    2385.3
                          f32 B=64, M=256, H=12, K=64    |      3617.1    |     3633.0  |    4970.4
                          f16 B=1, M=4096, H=16, K=40    |     24783.9    |    24253.3  |    8364.6
                          f32 B=1, M=4096, H=16, K=40    |     57289.3    |    57122.2  |   19517.0
                          f16 B=1, M=16384, H=16, K=40   |    392351.4    |   386768.6  |          
                          f32 B=1, M=16384, H=16, K=40   |    909342.2    |   909207.0  |          
                          f16 B=16, M=128, H=16, K=16    |       414.0    |      500.4  |     633.7
                          f32 B=16, M=128, H=16, K=16    |       473.3    |      546.9  |     610.3
                          f16 B=16, M=128, H=16, K=32    |       408.1    |      575.3  |     670.2
                          f32 B=16, M=128, H=16, K=32    |       479.5    |      519.1  |     618.9
                          f16 B=16, M=128, H=16, K=64    |       407.3    |      461.2  |     648.9
                          f32 B=16, M=128, H=16, K=64    |       478.4    |      575.0  |     615.2
                          f16 B=16, M=128, H=16, K=128   |       415.7    |      515.3  |     690.1
                          f32 B=16, M=128, H=16, K=128   |       875.7    |      875.9  |    1006.7
                          f16 B=16, M=128, H=16, K=256   |      1054.0    |     1052.4  |     888.7
                          f32 B=16, M=128, H=16, K=256   |      1734.7    |     1740.9  |    1854.9
                          f16 B=16, M=512, H=16, K=16    |      1036.2    |     1015.1  |    1918.9
                          f32 B=16, M=512, H=16, K=16    |      2532.6    |     2540.9  |    4288.5
                          f16 B=16, M=512, H=16, K=32    |      1187.1    |     1158.3  |    2128.9
                          f32 B=16, M=512, H=16, K=32    |      3226.6    |     3260.8  |    4634.9
                          f16 B=16, M=512, H=16, K=64    |      1513.8    |     1490.7  |    2560.9
                          f32 B=16, M=512, H=16, K=64    |      4426.3    |     4449.8  |    5479.5
                          f16 B=16, M=512, H=16, K=128   |      3317.0    |     3212.9  |    3377.7
                          f32 B=16, M=512, H=16, K=128   |      8795.9    |     8759.7  |    8724.9
                          f16 B=16, M=512, H=16, K=256   |      8523.2    |     8505.5  |    5348.4
                          f32 B=16, M=512, H=16, K=256   |     17430.3    |    17494.2  |   16621.6
                          f16 B=16, M=1024, H=16, K=16   |      3774.8    |     3717.3  |    7286.8
                          f32 B=16, M=1024, H=16, K=16   |      9734.4    |     9676.0  |   16131.3
                          f16 B=16, M=1024, H=16, K=32   |      4230.2    |     4170.7  |    7662.0
                          f32 B=16, M=1024, H=16, K=32   |     11929.4    |    12001.4  |   17151.3
                          f16 B=16, M=1024, H=16, K=64   |      5290.4    |     5203.6  |    8637.0
                          f32 B=16, M=1024, H=16, K=64   |     16046.1    |    16305.8  |   19853.1
                          f16 B=16, M=1024, H=16, K=128  |     11369.1    |    11030.4  |   10478.4
                          f32 B=16, M=1024, H=16, K=128  |     31667.7    |    32050.9  |   32589.2
                          f16 B=16, M=1024, H=16, K=256  |     28937.3    |    28891.7  |   16874.2
                          f32 B=16, M=1024, H=16, K=256  |     62718.2    |    63284.2  |   58763.0
                          f16 B=64, M=128, H=16, K=16    |       414.5    |      508.8  |     651.3
                          f32 B=64, M=128, H=16, K=16    |       799.0    |      795.8  |    1240.6
                          f16 B=64, M=128, H=16, K=32    |       432.5    |      469.8  |     814.0
                          f32 B=64, M=128, H=16, K=32    |      1117.5    |     1115.5  |    1530.4
                          f16 B=64, M=128, H=16, K=64    |       635.9    |      623.9  |    1185.2
                          f32 B=64, M=128, H=16, K=64    |      1615.4    |     1612.9  |    2154.2
                          f16 B=64, M=128, H=16, K=128   |      1470.9    |     1419.7  |    1918.1
                          f32 B=64, M=128, H=16, K=128   |      3190.9    |     3171.7  |    3761.3
                          f16 B=64, M=128, H=16, K=256   |      3823.7    |     3810.6  |    3445.0
                          f32 B=64, M=128, H=16, K=256   |      6424.0    |     6416.3  |    7248.7
                          f16 B=64, M=512, H=16, K=16    |      3671.5    |     3596.6  |    7506.6
                          f32 B=64, M=512, H=16, K=16    |      9245.7    |     9262.7  |   16742.1
                          f16 B=64, M=512, H=16, K=32    |      4266.7    |     4189.1  |    8365.1
                          f32 B=64, M=512, H=16, K=32    |     11905.8    |    11914.3  |   18455.8
                          f16 B=64, M=512, H=16, K=64    |      5614.8    |     5510.7  |   10294.1
                          f32 B=64, M=512, H=16, K=64    |     16201.4    |    16312.6  |   22721.7
                          f16 B=64, M=512, H=16, K=128   |     11895.4    |    11544.8  |   14667.4
                          f32 B=64, M=512, H=16, K=128   |     31811.5    |    31957.9  |   39608.1
                          f16 B=64, M=512, H=16, K=256   |     31259.0    |    31179.8  |   26641.1
                          f32 B=64, M=512, H=16, K=256   |     63396.4    |    63495.1  |   74597.9
                          f16 B=64, M=1024, H=16, K=16   |     13576.3    |    13245.5  |   29173.0
                          f32 B=64, M=1024, H=16, K=16   |     34830.8    |    34907.6  |          
                          f16 B=64, M=1024, H=16, K=32   |     15310.6    |    15063.4  |   31256.2
                          f32 B=64, M=1024, H=16, K=32   |     43348.1    |    43456.3  |          
                          f16 B=64, M=1024, H=16, K=64   |     19603.4    |    19322.4  |   37280.5
                          f32 B=64, M=1024, H=16, K=64   |     58298.7    |    58997.7  |          
                          f16 B=64, M=1024, H=16, K=128  |     40772.0    |    39771.6  |   49110.0
                          f32 B=64, M=1024, H=16, K=128  |    116342.5    |   116715.2  |          
                          f16 B=64, M=1024, H=16, K=256  |    106418.3    |   106294.7  |          
                          f32 B=64, M=1024, H=16, K=256  |    230815.1    |   231929.0  |          

Times are in microseconds (us).
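For reference, the bias+dropout path benchmarked above is exercised roughly like this. A minimal sketch only - the shapes and the `[B, H, M, M]` bias layout are assumptions, not the exact benchmark setup:

```python
import torch
import xformers.ops as xops

B, M, H, K = 16, 512, 16, 64
q = torch.randn(B, M, H, K, device="cuda", dtype=torch.half)
k = torch.randn(B, M, H, K, device="cuda", dtype=torch.half)
v = torch.randn(B, M, H, K, device="cuda", dtype=torch.half)
# tensor attention bias with one value per (batch, head, query, key);
# requires_grad=True so the backward pass also produces a bias gradient
bias = torch.randn(B, H, M, M, device="cuda", dtype=torch.half, requires_grad=True)

out = xops.memory_efficient_attention(q, k, v, attn_bias=bias, p=0.1)  # p: dropout prob
out.mean().backward()
print(bias.grad.shape)  # torch.Size([16, 16, 512, 512])
```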

@jfc4050
Contributor Author

jfc4050 commented Jan 3, 2023

edit: are you talking about these forward measurements?

fwd/no-bias
                          f16 B=1, M=16384, H=16, K=40   |     37370.7    |    29646.1  |       
fwd/causal   
                          f16 B=1, M=16384, H=16, K=40   |     19529.5    |    15403.6  |          

@danthe3rd
Contributor

> perf is similar before and after

I was mentioning this for V100: 30ms before, and 37ms after

f16 B=1, M=16384, H=16, K=40   |     37370.7    |    29646.1  |

@jfc4050
Contributor Author

jfc4050 commented Jan 3, 2023

hm, not sure. i tried to repro on T4 but saw no major gap. i don't have a V100 to test with, but i'll go through the code and see if i can make any guesses

pr

====== {'shape': (1, 16384, 16, 40), 'num_threads': 1, 'dropout_p': 0.0, 'attn_bias_cfg': (<class 'NoneType'>, False), 'dtype': torch.float16} ======
optimized: memory used: 120.0 MB
Skipped (OOM)
optimized: memory used: inf MB
[-------------------- attention (attn_bias=<class 'NoneType'>) -------------------]
                                                                            |
1 threads: ------------------------------------------------------------------------
      f16 B=1, M=16384, H=16, K=40, p=0.0,  BiasT=NoneType, BiasGrad=False  |  88.1

Times are in milliseconds (ms).

====== {'shape': (1, 16384, 16, 40), 'num_threads': 1, 'dropout_p': 0.0, 'attn_bias_cfg': (<class 'xformers.ops.fmha.common.LowerTriangularMask'>, False), 'dtype': torch.float16} ======
optimized: memory used: 632.0 MB
Skipped (OOM)
optimized: memory used: inf MB
[------- attention (attn_bias=<class 'xformers.ops.fmha.common.LowerTriangularMask'>) -------]
                                                                                       |
1 threads: -----------------------------------------------------------------------------------
      f16 B=1, M=16384, H=16, K=40, p=0.0,  BiasT=LowerTriangularMask, BiasGrad=False  |  44.4

Times are in milliseconds (ms).

main

====== {'shape': (1, 16384, 16, 40), 'num_threads': 1, 'attn_bias_type': <class 'NoneType'>, 'dtype': torch.float16} ======
optimized: memory used: 120.0 MB
Skipped (OOM)
optimized: memory used: inf MB
[ attention (attn_bias=<class 'NoneType'>) ]
                                    |
1 threads: --------------------------------
      f16 B=1, M=16384, H=16, K=40  |  87.9

Times are in milliseconds (ms).

====== {'shape': (1, 16384, 16, 40), 'num_threads': 1, 'attn_bias_type': <class 'xformers.ops.fmha.common.LowerTriangularMask'>, 'dtype': torch.float16} ======
optimized: memory used: 632.0 MB
Skipped (OOM)
optimized: memory used: inf MB
[ attention (attn_bias=<class 'xformers.ops.fmha.common.LowerTriangularMask'>) ]
                                    |
1 threads: --------------------------------
      f16 B=1, M=16384, H=16, K=40  |  44.6

Times are in milliseconds (ms).
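a minimal standalone sketch for timing one of these shapes (assumes a CUDA device and that xformers is importable - not the full benchmark harness):

```python
import torch
from torch.utils import benchmark
import xformers.ops as xops

B, M, H, K = 1, 16384, 16, 40
q, k, v = (torch.randn(B, M, H, K, device="cuda", dtype=torch.half) for _ in range(3))

t = benchmark.Timer(
    stmt="xops.memory_efficient_attention(q, k, v)",
    globals={"xops": xops, "q": q, "k": k, "v": v},
)
print(t.blocked_autorange(min_run_time=1.0))
```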

@danthe3rd
Contributor

Okay, that's fine - let's not worry too much about it; we can create an issue to document the regression, and that should be good enough.
We plan to release 0.0.16 first (hopefully this week) and merge this PR after.

@danthe3rd
Contributor

Heads-up @jfc4050 - we have some conflicts, mostly because of this: ac5fd49
Shouldn't be too hard to fix; basically we want kernels to report a reason why they are not supported, rather than just returning a bool
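
A minimal sketch of the kind of interface change being described - hypothetical names, not the actual xformers API:

```python
from typing import List

class Kernel:
    """hypothetical stand-in for a kernel dispatch entry, not the real xformers type"""
    SUPPORTED_DTYPES = {"f16", "f32"}
    MAX_K = 128

    def not_supported_reasons(self, dtype: str, k: int) -> List[str]:
        # report *why* the kernel can't run, instead of a bare bool
        reasons = []
        if dtype not in self.SUPPORTED_DTYPES:
            reasons.append(f"unsupported dtype: {dtype}")
        if k > self.MAX_K:
            reasons.append(f"embedding size K={k} exceeds max {self.MAX_K}")
        return reasons  # empty list means the kernel is usable

print(Kernel().not_supported_reasons(dtype="bf16", k=256))
# ['unsupported dtype: bf16', 'embedding size K=256 exceeds max 128']
```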

danthe3rd pushed a commit that referenced this pull request Jan 13, 2023
See also #640

Adds support for combinations of different sorts of biases:
- Causal
- Bias (coming with #587)
- Block-diagonal (used for different seqlen per batch element)

We need to rename "LowerTriangularMask" because when added with a block-diagonal mask it's no longer causal:

```
# A (block-diagonal)
0 0 0 * *
0 0 0 * *
* * * 0 0
* * * 0 0
# B (lower triangular)
0 * * * *
0 0 * * *
0 0 0 * *
0 0 0 0 *
# A + B
0 * * * *
0 0 * * *
* * * * *
* * * 0 *
# A + causal (what most ppl want)
0 * * * *
0 0 * * *
* * * 0 *
* * * 0 0
```
danthe3rd pushed a commit that referenced this pull request Jan 13, 2023
See also #64

Adds support for combinations of different sorts of biases:
- Causal
- Bias (coming with #587)
- Block-diagonal (used for different seqlen per batch element)

We need to rename "LowerTriangularMask" because when added to a block-diagonal mask it's no longer causal:

```
# A (block-diagonal)
0 0 0 * *
0 0 0 * *
* * * 0 0
* * * 0 0
# B (lower triangular)
0 * * * *
0 0 * * *
0 0 0 * *
0 0 0 0 *
# A + B
0 * * * *
0 0 * * *
* * * * *
* * * 0 *
# A + causal (what most people want)
0 * * * *
0 0 * * *
* * * 0 *
* * * 0 0
```
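
To make the distinction above concrete, here is a plain-tensor PyTorch sketch of the same diagrams ("0" = attend, `-inf` = masked); the helper names are made up for this illustration, not xformers API:

```python
# Shapes follow the example: 4 queries vs. 5 keys, split into blocks
# of (2 queries, 3 keys) and (2 queries, 2 keys).
import torch

NEG = float("-inf")

def block_diagonal(blocks):
    """Additive mask letting each query block see only its own key block."""
    m = torch.full((sum(q for q, _ in blocks), sum(k for _, k in blocks)), NEG)
    qi = ki = 0
    for q, k in blocks:
        m[qi:qi + q, ki:ki + k] = 0.0
        qi, ki = qi + q, ki + k
    return m

def lower_triangular(n_q, n_k):
    """Global causality: query i sees keys 0..i, ignoring block boundaries."""
    m = torch.zeros(n_q, n_k)
    m[torch.arange(n_k) > torch.arange(n_q).unsqueeze(-1)] = NEG
    return m

def per_block_causal(blocks):
    """Causality applied inside each block: what most people actually want."""
    m = block_diagonal(blocks)
    qi = ki = 0
    for q, k in blocks:
        m[qi:qi + q, ki:ki + k] += lower_triangular(q, k)
        qi, ki = qi + q, ki + k
    return m

blocks = [(2, 3), (2, 2)]
A = block_diagonal(blocks)
B = lower_triangular(4, 5)
print(A + B)                    # row 2 ends up fully masked (-inf everywhere)
print(per_block_causal(blocks)) # causal within each block, as in the diagram
```

Summing `A + B` masks out entire rows for the second batch element (its queries come after keys that global lower-triangularity forbids), which is exactly why the combined mask needs per-block causality rather than a global `LowerTriangularMask`.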
@danthe3rd danthe3rd mentioned this pull request Jan 13, 2023
danthe3rd pushed a commit that referenced this pull request Jan 13, 2023
danthe3rd pushed a commit that referenced this pull request Jan 16, 2023
danthe3rd pushed a commit that referenced this pull request Jan 17, 2023
@danthe3rd danthe3rd merged commit 814314d into facebookresearch:main Jan 18, 2023
facebook-github-bot pushed a commit that referenced this pull request Jan 19, 2023
- FW: Drop support for dropout if pytorch is not installed
- Use rng_state in Context to store seed/offset for dropout
- Add test to ensure we can't combine flash+cutlass's dropouts

ghstack-source-id: c5e05a1994b9c20fc27b071c3bfefbb4174987a2
Pull Request resolved: https://github.com/fairinternal/xformers/pull/434

__original_commit__ = fairinternal/xformers@408fefe5506c92b9c58444620c45bf5159b7fb39
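
For context on the rng_state item above, here is a schematic of the usual seed/offset pattern in plain Python (illustrative only; the actual kernels do this with a Philox counter inside CUDA): the forward pass records the RNG state used for dropout so the backward pass can replay the identical mask instead of materializing and storing it.

```python
# Schematic of the rng_state pattern: store a tiny seed in the context,
# regenerate the dropout mask in backward instead of saving an M*N mask.
import torch

class ReplayedDropout(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, p):
        seed = int(torch.randint(0, 2**31, (1,)).item())
        ctx.rng_state, ctx.p = seed, p           # tiny, vs. a full mask tensor
        return x * ReplayedDropout._mask(x, seed, p) / (1.0 - p)

    @staticmethod
    def backward(ctx, grad_out):
        # Replaying the same seed yields bit-identical dropout decisions.
        mask = ReplayedDropout._mask(grad_out, ctx.rng_state, ctx.p)
        return grad_out * mask / (1.0 - ctx.p), None

    @staticmethod
    def _mask(ref, seed, p):
        g = torch.Generator(device=ref.device).manual_seed(seed)
        keep = torch.rand(ref.shape, generator=g, device=ref.device) >= p
        return keep.to(ref.dtype)

x = torch.randn(4, 4, requires_grad=True)
ReplayedDropout.apply(x, 0.5).sum().backward()
```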
facebook-github-bot pushed a commit that referenced this pull request Jan 19, 2023

ghstack-source-id: 44740f71132fa76226fd4c559cc3f09732ff139b
Pull Request resolved: https://github.com/fairinternal/xformers/pull/435

__original_commit__ = fairinternal/xformers@be55fcd21c5dd621831245c5995e1c6fb49d9b77