
[0.0.18] memory_efficient_attention NaNs when seqlen>32768 #719

Closed
comfyanonymous opened this issue Apr 5, 2023 · 6 comments
Labels
bug Something isn't working

Comments

@comfyanonymous

🐛 Bug

To Reproduce

Steps to reproduce the behavior:

import torch
import xformers
import xformers.ops

# seqlen 33728 > 32768 triggers the bug (batch 1, embedding dim 512)
q = torch.zeros([1, 33728, 512]).cuda()
k = torch.zeros([1, 33728, 512]).cuda()
v = torch.zeros([1, 33728, 512]).cuda()
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)
print(out)
print(torch.isnan(out).any())  # True on 0.0.18; should be False

Expected behavior

out should not contain NaN.

Environment

This test was done on the free Google Colab tier with a T4 GPU, using the 0.0.18 package from pip.

Additional context

It works fine on 0.0.17 but fails on 0.0.18.

People have been reporting that my ComfyUI returns black images during the VAE decoding phase when the resolution is above a certain threshold, and I have narrowed it down to this issue.
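
A possible interim workaround, not part of the original report and only a sketch: fall back to plain PyTorch attention for long sequences, computed over query chunks to keep memory bounded. The chunked_attention helper below is hypothetical, not xformers code.

import math
import torch

def chunked_attention(q, k, v, chunk=8192):
    # Plain-PyTorch softmax(q @ k^T / sqrt(d)) @ v, computed over query
    # chunks so the [chunk, seqlen] score matrix stays within GPU memory.
    # Hypothetical fallback while the bug is unpatched; slower and less
    # memory-efficient than xformers' kernel.
    scale = 1.0 / math.sqrt(q.shape[-1])
    outs = []
    for i in range(0, q.shape[1], chunk):
        q_blk = q[:, i:i + chunk]                        # [B, chunk, D]
        scores = q_blk @ k.transpose(-2, -1) * scale     # [B, chunk, S]
        outs.append(torch.softmax(scores, dim=-1) @ v)   # [B, chunk, D]
    return torch.cat(outs, dim=1)

With chunk=8192 and seqlen≈33728 in fp32, each score block is roughly 1 GB, so smaller chunks may be needed on a T4.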

@leppie

leppie commented Apr 5, 2023

I hit this issue too yesterday after upgrading to 0.0.18.

@danthe3rd danthe3rd pinned this issue Apr 5, 2023
@danthe3rd danthe3rd added the bug Something isn't working label Apr 5, 2023
@danthe3rd danthe3rd changed the title On 0.0.18 xformers.ops.memory_efficient_attention returns NaN on certain input shapes [0.0.18] xformers.ops.memory_efficient_attention returns NaN on certain input shapes Apr 5, 2023
@danthe3rd
Contributor

danthe3rd commented Apr 5, 2023

Hi,
Thanks for the minimal reproduction example. I can repro on an A100 - I'll have a look.
Also pinning this issue, as it's quite important.

@danthe3rd
Contributor

It seems to happen due to a cast to int16 at some point in the code, so it happens when the sequence length is larger than 32768 (the int16 maximum is 32767). Updating the title accordingly.
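
As an illustration of why 32768 is the boundary (this is not the actual kernel code, just a demonstration of int16 overflow):

import torch

# 32768 does not fit in int16 (max 32767), so the cast wraps around
idx = torch.tensor(32768, dtype=torch.int32)
print(idx.to(torch.int16))  # tensor(-32768, dtype=torch.int16)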

@danthe3rd danthe3rd changed the title [0.0.18] xformers.ops.memory_efficient_attention returns NaN on certain input shapes [0.0.18] xformers.ops.memory_efficient_attention returns NaN when seqlen>32768 Apr 5, 2023
@danthe3rd danthe3rd changed the title [0.0.18] xformers.ops.memory_efficient_attention returns NaN when seqlen>32768 [0.0.18] memory_efficient_attention NaNs when seqlen>32768 Apr 6, 2023
@danthe3rd
Contributor

I have a tentative fix - hopefully we can land it soon and release 0.0.19 early next week to address this.

@danthe3rd
Contributor

danthe3rd commented Apr 14, 2023

It should be fixed as of 68dce69, and will be included in the next release (0.0.19). In the meantime, you can also use a development build >=0.0.19.dev516.
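
To confirm which build is installed (a simple check, assuming the fix is present from 0.0.19.dev516 as stated above):

import xformers

# Should report 0.0.19.dev516 or newer (or 0.0.19 once released)
print(xformers.__version__)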

@patrickvonplaten

Thanks a mille for the fix @danthe3rd - very helpful thread here!
