[MPS] CrossAttention: (partial) fix for NaN attention_scores on PyTorch nightly #2643
Thanks a lot for the PR @Birch-san!
This looks good to me - @pcuenca could you also take a quick look?
Hi @Birch-san! Your code looks good and your explanations are great, but I'm unable to reproduce the problem. I've tried with the nightly you mentioned and also with the one from today, running inference on various models. Do you happen to have a reproducible code snippet I can try? Regarding the nasty workaround to access the tensors, I'm not opposed to using it inside a check that makes sure it only runs for … Thanks a lot for your help!
aye, that'd be fine, but I don't know which versions are affected. I tried to produce a minimal repro, but the minimal version ran without trouble. something odd's going on. I'm on Python 3.11, PyTorch …
I'm also on macOS 13.3 beta and tested that same daily (and today's), but I tested on Python 3.9. I'll try Python 3.11, but it looks like it's not going to be easy to reproduce given your comment :)
Nice! @Birch-san could you move the changes into `attention_processor.py` given this recent refactor?
I also encountered the same problem, but I am on Windows 10.
Can you please clarify what's needed to push this PR forward?
@nipunjindal it just needs the changes to move to here:
are you experiencing this problem too?
Yes, I do, and I did apply the patch to …
Yes, we renamed the class to …
…pass in: create a smaller bias and broadcast it. this helps to prevent NaN result from baddbmm() on MPS on PyTorch 2.1.0.dev20230310.
@nipunjindal I've now rebased; thanks
Does this problem also appear on torch 2.0.0 or torch 2.0.1? Torch nightly is often not in the best state.
cc @pcuenca
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
diffusers 0.16.1, diffusers 0.20.0, stable-diffusion-2-1, dreambooth, fp16, sd_scripts v0.6.5 reproduce the issue.
Some history

when `attention_mask` is `None`, we set `baddbmm`'s `beta = 0`. in other words, `baddbmm_input` is an unused tensor that we supply just because the API requires it. `beta = 0` means the tensor isn't even read.

this is why, when I originally introduced it, I knew there was no need to even initialize the memory (hence `empty` instead of `zeros`).

The bug

on the MPS backend this had been fine in some PyTorch versions, but on the recent nightly I tried (2.1.0.dev20230310), allocating such a large bias tensor (`40, 4096, 4096`) causes `baddbmm()` to return NaN attention_scores (and consequently black images), even though the tensor is unused!

so, I've changed it to do a smaller allocation, and `expand()` the tensor to the expected size.

at this allocation size (one element), I'm unconcerned about whether initializing the memory costs anything, so I've also changed it to use zeroes instead of uninitialized memory.
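The change above can be sketched roughly as follows. This is a minimal illustration, not the exact diffusers code: the tensor shapes and variable names here are hypothetical, chosen small so it runs anywhere.

```python
import torch

batch, q_len, k_len, dim = 2, 4, 4, 8
query = torch.randn(batch, q_len, dim)
key = torch.randn(batch, k_len, dim)
scale = dim ** -0.5

# before: a large uninitialized bias tensor, supplied only because
# baddbmm() requires an `input` argument; with beta=0 its values
# should never be read
old_bias = torch.empty(batch, q_len, k_len, dtype=query.dtype, device=query.device)

# after: allocate a single zeroed element and expand() it to the
# expected shape; expand() returns a broadcast view, so no large
# buffer is ever allocated
new_bias = torch.zeros(1, 1, 1, dtype=query.dtype, device=query.device).expand(
    batch, q_len, k_len
)

attention_scores = torch.baddbmm(
    new_bias, query, key.transpose(-1, -2), beta=0, alpha=scale
)
```

With `beta=0` the result is just `alpha * (query @ key^T)` either way; the only difference is that the broadcast view never materializes the large bias buffer that triggered the NaNs on MPS.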
a `(1, 1, 1)` tensor would work too (`baddbmm` supports broadcasting), but I think the explicit shape is helpful documentation, and doesn't cost much with `expand()`. it might also help model conversion (such as coremltools) with expressing the `baddbmm` in terms of unfused operations (GEMM * scale + bias).

=====
the reason I call this a partial fix for black images (in this PyTorch nightly, on MPS) is that I found one more fix was necessary, in `CrossAttention#get_attention_scores` / `AttnProcessor2_0#__call__`:

but that half of the fix is a bit harder to justify upstreaming (since it's not free, and doesn't make sense). better to report that half as a bug to PyTorch, but I haven't succeeded in making a minimal repro.
=====

there's also a third problem causing black images on MPS / `pytorch@2.1.0.dev20230310`, for larger tensor sizes: pytorch/pytorch#96602