
tensor masks failing with memory_efficient_attention #683

Closed
Infrared1029 opened this issue Mar 6, 2023 · 11 comments

@Infrared1029

❓ Questions and Help

I'm playing around with xformers and this is probably a noobish question, but why is this code snippet failing?

import torch
import xformers.ops as xops

q, k, v = torch.randn(32, 10, 16), torch.randn(32, 10, 16), torch.randn(32, 10, 16)
q, k, v = q.cuda(), k.cuda(), v.cuda()
mask = torch.zeros(32, 10, 10).bool()
xops.memory_efficient_attention(q, k, v, attn_bias=mask)

Error: `RuntimeError: Expected attn_bias.stride(1) == 0 to be true, but got false.`
@danthe3rd
Contributor

Hi @Infrared1029
Can you try with inputs in the BMHK format? (that is: [batch, seqlen, num_heads, embed_per_head])
Also, I'm not sure we support binary masking; I would recommend first trying with a float tensor, with -inf for the values you want to discard.

Taking a step back, adding a torch.Tensor bias will impact the performance of the kernel. If you want a standard mask (like causal masking, for instance), you can use one of the attention bias classes in xformers; there is an example for causal masking below.
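
For instance, causal masking would look roughly like this (a minimal sketch, assuming f16 inputs in BMHK format and that LowerTriangularMask is exposed under xformers.ops in your version):

import torch
import xformers.ops as xops

q = torch.randn(32, 16, 1, 16, device="cuda", dtype=torch.float16)
k = torch.randn(32, 16, 1, 16, device="cuda", dtype=torch.float16)
v = torch.randn(32, 16, 1, 16, device="cuda", dtype=torch.float16)

# No dense mask tensor is materialized, so the kernel can use its fast causal path.
out = xops.memory_efficient_attention(q, k, v, attn_bias=xops.LowerTriangularMask())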

@Infrared1029
Author

Hi @danthe3rd
That's what I'm getting:

q, k, v = torch.randn(32, 10, 1, 16), torch.randn(32, 10, 1, 16), torch.randn(32, 10, 1, 16)
q, k, v = q.cuda(), k.cuda(), v.cuda()
mask = torch.zeros((32, 10, 10)).cuda()
xops.memory_efficient_attention(q, k, v, attn_bias=mask)

Error: RuntimeError: Expected attn_bias.stride(1) == 0 to be true, but got false.

Also, thanks for the attention bias classes suggestion; I'm aware of those, I'm just testing more custom cases.

@danthe3rd
Contributor

The most recent development version should give you more information on what's going on:

NotImplementedError: No operator found for `memory_efficient_attention_forward` with inputs:
     query       : shape=(32, 10, 1, 16) (torch.float32)
     key         : shape=(32, 10, 1, 16) (torch.float32)
     value       : shape=(32, 10, 1, 16) (torch.float32)
     attn_bias   : <class 'torch.Tensor'>
     p           : 0.0
`flshattF` is not supported because:
    dtype=torch.float32 (supported: {torch.bfloat16, torch.float16})
    attn_bias type is <class 'torch.Tensor'>
`tritonflashattF` is not supported because:
    dtype=torch.float32 (supported: {torch.bfloat16, torch.float16})
    attn_bias type is <class 'torch.Tensor'>
`cutlassF` is not supported because:
    attn_bias.shape[-1] % 4 != 0
`smallkF` is not supported because:
    bias with non-zero stride not supported

attn_bias.shape[-1] % 4 != 0

I believe the same script should work if your sequence length is divisible by 4 (f32) or 8 (f16/bf16).
Try using 16 for seqlen instead of 10. This limitation exists to keep the memory alignment needed for performance.

In theory, this could work with seqlen 10 if your attn_bias is padded, but this is not implemented in xFormers at the moment... Is your sequence length going to be 10?

cc @jfc4050

@Infrared1029
Author

hmm, same issue still

q, k, v = torch.randn(32, 16, 1, 16), torch.randn(32, 16, 1, 16), torch.randn(32, 16, 1, 16)
q, k, v = q.cuda(), k.cuda(), v.cuda()
mask = torch.zeros((32, 16, 16)).cuda()
xops.memory_efficient_attention(q, k, v, attn_bias=mask)

yields the same error: RuntimeError: Expected attn_bias.stride(1) == 0 to be true, but got false.

For reference, this is the output of `python -m xformers.info`:

xFormers 0.0.16
memory_efficient_attention.cutlassF:               available
memory_efficient_attention.cutlassB:               available
memory_efficient_attention.flshattF:               available
memory_efficient_attention.flshattB:               available
memory_efficient_attention.smallkF:                available
memory_efficient_attention.smallkB:                available
memory_efficient_attention.tritonflashattF:        available
memory_efficient_attention.tritonflashattB:        available
swiglu.fused.p.cpp:                                available
is_triton_available:                               True
is_functorch_available:                            False
pytorch.version:                                   1.13.1+cu116
pytorch.cuda:                                      available
gpu.compute_capability:                            7.0
gpu.name:                                          Tesla V100-PCIE-16GB
build.info:                                        available
build.cuda_version:                                1107
build.python_version:                              3.8.16
build.torch_version:                               1.13.1+cu117
build.env.TORCH_CUDA_ARCH_LIST:                    5.0+PTX 6.0 6.1 7.0 7.5 8.0 8.6
build.env.XFORMERS_BUILD_TYPE:                     Release
build.env.XFORMERS_ENABLE_DEBUG_ASSERTIONS:        None
build.env.NVCC_FLAGS:                              None
build.env.XFORMERS_PACKAGE_FROM:                   wheel-v0.0.16
source.privacy:                                    open source

@Infrared1029
Author

Infrared1029 commented Mar 7, 2023

I believe the same script should work if your sequence length is divisible by 4 (f32) or 8 (f16/bf16). Try using 16 for seqlen instead of 10.

Actually, after trying out the dev version (`pip install --pre -U xformers`), it does run with the suggested sequence length of 16; it fails with xFormers 0.0.16 as mentioned above, though.

@danthe3rd
Contributor

Oh right, the bias support was merged after we cut the branch for 0.0.16.
This should already allow you to experiment; we can add proper support for (seqlen % 8) != 0 later if you need it.

@Infrared1029
Author

Thanks a lot @danthe3rd. I was writing a custom scaled_dot_product_attention that uses xFormers if it finds an installation and otherwise falls back to a plain, non-efficient implementation (sketched below). So just to make sure: xFormers does NOT support torch.Tensor in general for attn_bias pre-0.0.17, right?
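
For context, the wrapper is roughly the following (just a sketch of my own helper, not an xFormers API; the fallback branch assumes an additive float bias with -inf at the positions to discard):

import math
import torch

try:
    import xformers.ops as xops
    _HAS_XFORMERS = True
except ImportError:
    _HAS_XFORMERS = False

def scaled_dot_product_attention(q, k, v, attn_bias=None):
    # q, k, v: [batch, seqlen, num_heads, head_dim] (BMHK)
    if _HAS_XFORMERS:
        return xops.memory_efficient_attention(q, k, v, attn_bias=attn_bias)
    # Fallback: materialize the full attention matrix.
    q_, k_, v_ = (t.transpose(1, 2) for t in (q, k, v))       # -> [B, H, M, K]
    scores = q_ @ k_.transpose(-2, -1) / math.sqrt(q_.shape[-1])
    if attn_bias is not None:
        scores = scores + attn_bias                            # additive bias, -inf to mask
    return (torch.softmax(scores, dim=-1) @ v_).transpose(1, 2)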

@danthe3rd
Contributor

So just to make sure: xFormers does NOT support torch.Tensor in general for attn_bias pre-0.0.17, right?

Pre-0.0.17: supported only under very strict conditions, by a deprecated kernel.
A torch.Tensor bias is only supported for limited cases on the "smallK" kernel, which is not efficient in general and only works with K<=32 and f32. Furthermore, it only supports a 1d bias repeated to be 2d (hence the attn_bias.stride(1) == 0 error you had initially; see the small illustration below).

0.0.17+: Attention bias is supported only if seqlen is a multiple of 4 or 8.
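
To make the "1d bias repeated to be 2d" condition concrete, here is a toy illustration (only meant to show the stride check; the other smallK constraints still apply):

import torch

bias_1d = torch.zeros(32, 1, 10, device="cuda")   # one bias row per batch
bias = bias_1d.expand(32, 10, 10)                 # broadcast along the query dimension
print(bias.stride(1))                             # 0 -> satisfies the old stride check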

@Infrared1029
Author

Awesome, thanks a lot @danthe3rd for the help, and thanks for the quick replies; saved me lots of time :)

@raman-r-4978

0.0.17+: Attention bias is supported only if seqlen is a multiple of 4 or 8

Do we have support for a seqlen that is not a multiple of 4 or 8?

I am working on a speech recognition problem, where seqlen is generally very high and might not always be divisible by 4 or 8.

@danthe3rd
Contributor

This has been added in v0.0.18, but with a twist. If your sequence length is M, you need to create your attention bias with shape [num_batches, num_heads, M, 8 * ((M + 7) // 8)] (basically rounding up to the next multiple of 8), and then take a slice of it with the size you want (e.g. attn_bias[:, :, :, :M]).
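
Roughly like this (a minimal sketch assuming f16 inputs in BMHK format; adjust shapes and dtypes to your model):

import torch
import xformers.ops as xops

B, H, M, K = 32, 1, 10, 16              # M = 10 is not a multiple of 8
M_pad = 8 * ((M + 7) // 8)              # round up to the next multiple of 8 -> 16

q = torch.randn(B, M, H, K, device="cuda", dtype=torch.float16)
k = torch.randn(B, M, H, K, device="cuda", dtype=torch.float16)
v = torch.randn(B, M, H, K, device="cuda", dtype=torch.float16)

# Allocate the bias at the padded width, then slice it back to M so the
# underlying storage keeps the aligned strides the kernel expects.
attn_bias = torch.zeros(B, H, M, M_pad, device="cuda", dtype=torch.float16)
attn_bias = attn_bias[:, :, :, :M]

out = xops.memory_efficient_attention(q, k, v, attn_bias=attn_bias)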
