[MPS] CrossAttention: (partial) fix for NaN attention_scores on PyTorch nightly #2643

Open

Birch-san wants to merge 1 commit into main

Conversation

Birch-san
Contributor

@Birch-san Birch-san commented Mar 11, 2023

Some history

When attention_mask is None, we set baddbmm's beta = 0.

Per the baddbmm docs: if beta is 0, then input will be ignored, and nan and inf in it will not be propagated.

In other words, baddbmm_input is an unused tensor that we supply only because the API requires it; with beta = 0 its values are never read.
That's why, when I originally introduced it, I saw no need to even initialize the memory (hence empty instead of zeros).
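
For reference, here is a tiny self-contained illustration of that beta = 0 behaviour (not part of the PR, just a demonstration of the documented semantics):

    import torch

    # with beta=0, baddbmm only uses `input` for its shape/dtype/device;
    # its values (even NaN) are ignored and must not propagate into the result
    bias = torch.full((2, 3, 3), float("nan"))
    query = torch.randn(2, 3, 4)
    key_t = torch.randn(2, 4, 3)
    out = torch.baddbmm(bias, query, key_t, beta=0, alpha=1.0)
    assert not out.isnan().any()  # NaN in `bias` does not leak through when beta == 0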

The bug

On the MPS backend this has been fine in some PyTorch versions, but on the recent nightly I tried (2.1.0.dev20230310), allocating such a large bias tensor (40, 4096, 4096) causes baddbmm() to return NaN attention_scores (and consequently black images), even though the tensor is unused!

So I've changed it to do a smaller allocation and expand the tensor to the expected size.
At this allocation size (one element) I'm unconcerned about the cost of initializing the memory, so I've also switched to zeros instead of uninitialized memory.

Passing the un-expanded (1, 1, 1) tensor would work too (baddbmm supports broadcasting), but I think the explicit shape is helpful documentation, and expand() costs very little. It might also help model conversion (such as coremltools) express the baddbmm in terms of unfused operations (GEMM * scale + bias).
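
Concretely, the change amounts to something like the following (a sketch of the intent rather than the literal diff; query and key are the tensors passed to get_attention_scores):

    # before: an uninitialized full-size tensor, never read because beta == 0
    # baddbmm_input = torch.empty(
    #     query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device
    # )

    # after: allocate a single zeroed element and broadcast it to the expected bias shape
    batch_x_heads, query_tokens, _ = query.shape
    _, key_tokens, _ = key.shape
    attention_bias = torch.zeros(1, 1, 1, dtype=query.dtype, device=query.device).expand(
        batch_x_heads, query_tokens, key_tokens  # expand() documents the shape without allocating extra memory
    )
    beta = 0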

=====

The reason I call this a partial fix for black images (on this PyTorch nightly, on MPS) is that I found one more fix was necessary:

CrossAttention#get_attention_scores:

  attention_scores = torch.baddbmm(
      attention_bias,
      query,
      key.transpose(-1, -2),
      beta=beta,
      alpha=self.scale,
  )
+ # accessing the tensor seems to help prevent black images on MPS on PyTorch nightly
+ assert not attention_scores.isnan().any().item()

AttnProcessor2_0#__call__:

  # the output of sdp = (batch, num_heads, seq_len, head_dim)
  hidden_states = F.scaled_dot_product_attention(
      query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
  )
+ # accessing the tensor seems to help prevent black images on MPS on PyTorch nightly
+ assert not hidden_states.isnan().any().item()

But that half of the fix is harder to justify upstreaming (it's not free, and it makes no sense that it should be needed). It's better to report that half as a bug to PyTorch, but I haven't yet succeeded in making a minimal repro.

=====

There's also a third problem causing black images on MPS/pytorch@2.1.0.dev20230310, for larger tensor sizes:
pytorch/pytorch#96602

@Birch-san Birch-san changed the title [MPS] CrossAttentionProcessor: (partial) fix for NaN attention_scores on PyTorch nightly [MPS] CrossAttention: (partial) fix for NaN attention_scores on PyTorch nightly Mar 11, 2023
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

Contributor

@patrickvonplaten patrickvonplaten left a comment


Thanks a lot for the PR @Birch-san!

This looks good to me - @pcuenca could you also take a quick look?

@pcuenca
Member

pcuenca commented Mar 13, 2023

Hi @Birch-san! Your code looks good and your explanations are great, but I'm unable to reproduce the problem. I've tried with the nightly you mentioned and also with the one from today, running inference on various models. Do you happen to have a reproducible code snippet I can try?

Regarding the nasty workaround of accessing the tensors, I'm not opposed to using it inside a check that ensures it only runs for mps, and maybe only for certain versions of PyTorch. As you know, we've done similar things in the past. That way users can enjoy diffusers while we report upstream and the problem gets solved.
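
Purely for illustration, such a guard might look roughly like this (the version condition is a placeholder, since we don't yet know which releases are affected):

    import torch

    # hypothetical guard: only pay for the workaround on MPS, and only on the
    # PyTorch builds suspected to be affected (the version check is a placeholder)
    needs_nan_workaround = (
        attention_scores.device.type == "mps"
        and torch.__version__.startswith("2.1.0.dev")
    )
    if needs_nan_workaround:
        # accessing the tensor seems to help prevent black images on MPS on PyTorch nightly
        assert not attention_scores.isnan().any().item()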

Thanks a lot for your help!

@Birch-san
Contributor Author

@pcuenca

Regarding the nasty workaround to access the tensors, I'm not opposed to use it inside a check that makes sure it only runs for mps, and maybe only for certain versions of PyTorch

Aye, that'd be fine, but I don't know which versions are affected. I tried to produce a minimal repro, but the minimal version ran without trouble; something odd's going on.

I'm on Python 3.11, PyTorch 2.1.0.dev20230310, macOS 13.3 beta. They're all new, so any one of those could be responsible for Very Weird Stuff.

@pcuenca
Member

pcuenca commented Mar 13, 2023

I'm also on macOS 13.3 beta and tested that same nightly (and today's), but on Python 3.9. I'll try Python 3.11, but given your comment it looks like it's not going to be easy to reproduce :)

Contributor

@williamberman williamberman left a comment


Nice! @Birch-san could you move the changes into attention_processor.py given this recent refactor?

e828232

@ZXStudio

I also encountered the same problem, but I am on Windows 10.

@nipunjindal
Contributor

Can you please clarify what's needed to push this PR forward?

@Birch-san
Contributor Author

@nipunjindal It just needs the changes moved to here:

baddbmm_input = torch.empty(

Are you experiencing this problem too?

@nipunjindal
Contributor

nipunjindal commented Apr 26, 2023

Yes, I do, and I applied the patch to attention_processor. I can help if you need a hand merging main into this branch.
Please let me know.

@patrickvonplaten
Contributor

Yes, we moved the class into attention_processor.py, so we just have to transfer the changes to that file! Sorry about the filename change :-/

…pass in: create a smaller bias and broadcast it. this helps to prevent NaN result from baddbmm() on MPS on PyTorch 2.1.0.dev20230310.
@Birch-san
Contributor Author

@nipunjindal I've now rebased; thanks

@patrickvonplaten
Contributor

Does this problem also appear on torch 2.0.0 or torch 2.0.1? Torch nightly is often not in the best state.

@patrickvonplaten
Contributor

cc @pcuenca

@github-actions

github-actions bot commented Jun 4, 2023

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions github-actions bot added the stale Issues that haven't received updates label Jun 4, 2023
@pcuenca pcuenca added wip and removed stale Issues that haven't received updates labels Jun 4, 2023
@wuchangping

wuchangping commented Sep 7, 2023

diffusers 0.16.1 and 0.20.0, stable-diffusion-2-1, DreamBooth, fp16, and sd_scripts v0.6.5 reproduce the issue.

    def get_attention_scores(self, query, key, attention_mask=None):
        dtype = query.dtype
        if self.upcast_attention:
            query = query.float()
            key = key.float()

        print("0 get_attention_scores:", attention_mask)
        if attention_mask is None:
            # baddbmm_input = torch.zeros(
            #     query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device
            # )
            batch_x_heads, query_tokens, _ = query.shape
            _, key_tokens, _ = key.shape
            # expanding dims isn't strictly necessary (baddbmm supports broadcasting bias),
            # but documents the expected shape without allocating any additional memory
            attention_bias = torch.zeros(1, 1, 1, dtype=query.dtype, device=query.device).expand(
                batch_x_heads, query_tokens, key_tokens
            )
            beta = 0
            print("0-1 attention_bias:", attention_bias)
        else:
            # baddbmm_input = attention_mask
            attention_bias = attention_mask
            beta = 1
            print("0-2 attention_bias:", attention_bias)

        print("0-3 query:", query)
        print("0-4 key:", key)
        print("0-5 beta:", beta)
        print("0-6 self.scale:", self.scale)
        attention_scores = torch.baddbmm(
            attention_bias,
            query,
            key.transpose(-1, -2),
            beta=beta,
            alpha=self.scale,
        )
        print("1 attention_scores:", attention_scores)

        if self.upcast_softmax:
            attention_scores = attention_scores.float()

        print("2 attention_scores:", attention_scores)
        attention_probs = attention_scores.softmax(dim=-1)
        print("3 attention_probs:", attention_probs)
        print("3 attention_probs dtype:", dtype)
        attention_probs = attention_probs.to(dtype)
        print("4 attention_probs:", attention_probs)

        return attention_probs

logs (tensor dumps truncated for brevity; attention_bias is all zeros, query is already all NaN, and attention_scores comes out all NaN; note the device is 'xla:0'):

0 get_attention_scores: None
0-1 attention_bias:
tensor([[[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...]])
..........................
0-3 query: tensor([[[nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         ...]], device='xla:0')
0-5 beta: 0
0-6 self.scale: 0.125
1 attention_scores: tensor([[[nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         ...]])