Does attention masking actually work? #1890

Closed
Birch-san opened this issue Jan 3, 2023 · 13 comments · Fixed by #2634
Labels
stale Issues that haven't received updates

Comments

@Birch-san
Contributor

Birch-san commented Jan 3, 2023

I tried passing an attention_mask to a Stable Diffusion UNet, but it doesn't actually get passed down as deep as CrossAttention#forward.

I tried fixing it to pass the param down, but it blows up on a tensor size mismatch, because self-attention and cross-attention have different masking requirements.

I made my own implementation of cross-attention masking a few weeks ago (before the refactor) but never upstreamed it, mainly because I wasn't sure whether I'd done it correctly (I reused the lucidrains implementation that CompVis used):
cbb4c02
EDIT: here is the implementation rebased, to show how it would fit in with the existing attention masking and the refactored attention:
Birch-san@e3a93e9

I explicitly named the parameter as a cross attention mask, because a self-attention mask has entirely different requirements.

In terms of wider API design, I wonder whether it should be an attention map (i.e. so you can use it to increase/decrease attention scores for certain token embeddings). But for now I'm mostly interested in the mask aspect, because waifu-diffusion makes use of "multiple CLIP embeddings stitched together", so attention masking is useful to avoid attending to padding token embeddings, which would be biased towards conveying the high-level semantics of the final CLIP segment only.
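A minimal sketch of such a cross-attention mask for stitched prompts. It assumes each prompt is split into fixed 77-token CLIP segments that are encoded separately and concatenated along the token axis, that every prompt in the batch has the same number of segments, and that the number of real (non-PAD) tokens in each segment is known, for example from the tokenizer's attention_mask. The helper stitched_padding_mask is hypothetical, not part of diffusers. It works from counts rather than comparing token ids to a pad id, because some CLIP tokenizers reuse the end-of-text token for padding, and an id comparison would then hide the EOS token as well.

```python
import torch


def stitched_padding_mask(segment_lengths, tokens_per_segment=77):
    """Build a keep-mask for prompts made of several concatenated CLIP segments.

    segment_lengths: one list per batch item, giving the number of real tokens
    (assumed to include BOS/EOS) in each segment.
    Returns a bool tensor of shape (batch, n_segments * tokens_per_segment)
    where True marks token embeddings that may be attended to.
    """
    rows = []
    for lengths in segment_lengths:
        segments = []
        for n_real in lengths:
            seg = torch.zeros(tokens_per_segment, dtype=torch.bool)
            seg[:n_real] = True  # real tokens come first, PAD tokens after
            segments.append(seg)
        rows.append(torch.cat(segments))
    return torch.stack(rows)  # requires equal segment counts per prompt


# Two prompts, each stitched from 3 segments of 77 tokens.
mask = stitched_padding_mask([[77, 77, 12], [40, 5, 3]])
print(mask.shape)       # torch.Size([2, 231])
print(mask.sum(dim=1))  # tensor([166, 48])
```

Without such a mask, every PAD position in every segment stays visible to cross-attention; with it, only the real tokens of each segment are attended to.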

@patrickvonplaten

Contributor

That's a good point! Also cc @williamberman and @patil-suraj here.

Stable Diffusion never used an attention mask during training; however, some other models that make use of UNet2DCondition use one (such as UnCLIP). With more and more people fine-tuning Stable Diffusion, we could actually allow attention_mask to be used in Stable Diffusion as well.

@patil-suraj @williamberman wdyt?

@Birch-san
Contributor Author

I see two masking use-cases:

  • self-attention mask (if you're training on a batch of images with mixed aspect ratios: you can tell it not to attend to padding pixels)
  • cross-attention mask (don't attend to PAD token embeddings in CLIP text condition)
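To make the difference between these two masks concrete, here is a small sketch with illustrative sizes (an 8x8 latent level, 8 heads) and a hypothetical to_bias helper. In the (batch * heads, queries, keys) score layout, a self-attention mask has one key per pixel, while a cross-attention mask has one key per CLIP token, so a single mask cannot serve both kinds of layer.

```python
import torch

batch, heads = 2, 8
n_pixels = 8 * 8   # self-attention: one key per latent pixel (illustrative 8x8 level)
n_tokens = 77      # cross-attention: one key per CLIP token embedding

# Per-head attention scores, laid out as (batch * heads, queries, keys).
self_scores = torch.randn(batch * heads, n_pixels, n_pixels)
cross_scores = torch.randn(batch * heads, n_pixels, n_tokens)

# Self-attention keep-mask: image 1 has its last two latent rows as padding.
pixel_keep = torch.ones(batch, n_pixels, dtype=torch.bool)
pixel_keep[1, -16:] = False
# Cross-attention keep-mask: only the first 10 tokens are real, the rest are PAD.
token_keep = torch.zeros(batch, n_tokens, dtype=torch.bool)
token_keep[:, :10] = True


def to_bias(keep, heads):
    """(batch, keys) bool keep-mask -> (batch * heads, 1, keys) additive bias."""
    bias = torch.zeros(keep.shape).masked_fill(~keep, float("-inf"))
    return bias.repeat_interleave(heads, dim=0).unsqueeze(1)


self_scores = self_scores + to_bias(pixel_keep, heads)    # keys = 64: fits
cross_scores = cross_scores + to_bias(token_keep, heads)  # keys = 77: fits

# Also passing the token mask to the self-attention layer cannot work.
try:
    self_scores + to_bias(token_keep, heads)
    mismatch_raised = False
except RuntimeError:
    mismatch_raised = True
```

repeat_interleave is used so that the per-head copies line up with a batch-major (batch * heads) layout; a different head layout would need a different expansion.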

waifu-diffusion already wants to use cross-attention masks for fine-tuning Stable Diffusion (in fact, it will begin a big training run in a few days).

But what should the API be? attention_mask could refer to either self-attention or cross-attention, but it would be catastrophic to pass it to both.
Is cross_attention_kwargs a useful piece in this puzzle? Are we supposed to do something like cross_attention_kwargs = { 'attention_mask': my_cool_mask }?

Side note: why is the class named CrossAttention? It's super confusing, given that it implements self-attention too. Would MultiHeadAttention be a better name?
Side note 2: is there a reason why you didn't use PyTorch's native torch.nn.MultiheadAttention? You can swap diffusers' CrossAttention class for the native implementation pretty easily.
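For reference, PyTorch's native layer already supports cross-attention masking through key_padding_mask. The sketch below uses SD 1.x-like sizes for illustration only and does not show how to map pretrained CrossAttention weights (to_q, to_k, to_v, to_out) into the native module's parameters. Note that key_padding_mask uses the opposite convention to a keep-mask: True means "ignore this key".

```python
import torch
from torch import nn

torch.manual_seed(0)
batch, n_pixels, n_tokens = 2, 16, 77
query_dim, context_dim, heads = 320, 768, 8  # illustrative, SD 1.x-like sizes

# Cross-attention: queries come from image features, keys/values from text.
mha = nn.MultiheadAttention(
    embed_dim=query_dim, num_heads=heads,
    kdim=context_dim, vdim=context_dim, batch_first=True,
)

hidden_states = torch.randn(batch, n_pixels, query_dim)
encoder_hidden_states = torch.randn(batch, n_tokens, context_dim)

token_keep = torch.zeros(batch, n_tokens, dtype=torch.bool)
token_keep[0, :12] = True   # prompt 0 has 12 real tokens
token_keep[1, :30] = True   # prompt 1 has 30 real tokens
key_padding_mask = ~token_keep  # True = padding, to be ignored

out, weights = mha(
    hidden_states, encoder_hidden_states, encoder_hidden_states,
    key_padding_mask=key_padding_mask,
)
print(out.shape)      # torch.Size([2, 16, 320])
print(weights.shape)  # torch.Size([2, 16, 77]), averaged over heads by default
```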

@patrickvonplaten
Contributor

Hey @Birch-san,

Thanks for your thoughtful answer here. I see the need to pass customized attention_masks. For this, IMO the user should set up a customized attention processor class, as was merged in #1639. Now, since we need different attention processors for the self- and cross-attention layers, we need to leverage some new functionality that will be added as part of this PR (hope to get it merged this week).

As you can see in the comment, it allows setting processors per layer depending on the weight name, which should make it possible to set attention processors only on the self-attention or only on the cross-attention layers. Does this make sense?
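A sketch of choosing processors by layer name. It assumes the diffusers naming convention in which a transformer block's self-attention layer is called attn1 and its cross-attention layer attn2, so processor names look like "down_blocks.0.attentions.0.transformer_blocks.0.attn2.processor". The function split_processors is hypothetical, and plain strings stand in for real processor objects.

```python
def split_processors(processor_names, self_attn_proc, cross_attn_proc):
    """Map each attention processor name to a processor, choosing by layer name.

    Assumes `attn1` = self-attention and `attn2` = cross-attention, as in
    diffusers' transformer blocks.
    """
    mapping = {}
    for name in processor_names:
        if ".attn2." in name:
            mapping[name] = cross_attn_proc   # e.g. a mask-aware processor
        else:
            mapping[name] = self_attn_proc    # e.g. the default processor
    return mapping


names = [
    "down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor",
    "down_blocks.0.attentions.0.transformer_blocks.0.attn2.processor",
    "mid_block.attentions.0.transformer_blocks.0.attn2.processor",
]
procs = split_processors(names, "default", "masked")
```

Assuming a diffusers version whose UNet exposes attn_processors (a dict keyed by such names) and whose set_attn_processor accepts a dict, the result could be passed as unet.set_attn_processor(split_processors(unet.attn_processors.keys(), ...)).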

@patrickvonplaten
Contributor

Regarding the naming, yes, I think you're right here; we should give it a better name. Would you like to open a new, separate issue for this, as it would be quite actionable? :-)

@Lime-Cakes
Contributor

On the topic of masking when training with padded pixels: convolutions don't support masks. Would the padded values be zero and passed into the convs, then masked at attention? Would the model still function correctly when the convs might learn the letterbox edges while attention ignores them?
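A small sketch of the concern, using one 3x3 conv with a bias on a single-channel map whose two right-most columns are zero letterbox padding. The re-masking step at the end is one possible mitigation, stated here as an assumption; it is not something diffusers does.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# A 6x6 single-channel feature map whose two right-most columns are padding.
valid = torch.ones(1, 1, 6, 6)
valid[..., 4:] = 0
x = torch.randn(1, 1, 6, 6) * valid

weight = torch.randn(1, 1, 3, 3)
bias = torch.tensor([0.5])
y = F.conv2d(x, weight, bias, padding=1)

# The conv writes values into the padded columns: column 5 only sees zeros,
# so it equals the bias, and column 4 also mixes in real pixels from column 3.
# Conversely, column 3 (real content) is computed partly from the padding,
# which is how a conv can learn the letterbox edge.
print(y[0, 0, :, 5])  # all 0.5

# Possible mitigation (assumption): multiply by the validity mask after each
# conv so the padded region is reset to zero before the next layer.
y_masked = y * valid
```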

@patrickvonplaten
Contributor

Good point @Lime-Cakes!

Actually, to me the attention mask only really makes sense for cross-attention as well.

@Lime-Cakes
Contributor

Do you have any update on the results of training with the cross-attention mask? Does it work?

@bonlime
Contributor

bonlime commented Feb 7, 2023

I want to +1 the importance of having cross-attention masks. For SD, having too many padding tokens affects image generation (this is probably one of the reasons people found longer prompts to work better: they avoid too much attention being drawn to the PAD tokens). I don't see why padding or not padding should change the output image even a tiny bit, so proper support for masks is required.
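A toy, single-head sketch of why a proper mask makes padding irrelevant: with masking, appending PAD keys adds exactly zero attention weight, so the output matches the unpadded computation; without masking, the PAD keys take a share of the softmax and change the result. Random vectors stand in for real queries and token embeddings; attend is a hypothetical helper.

```python
import torch

torch.manual_seed(0)
d = 16
q = torch.randn(4, d)                     # 4 queries (e.g. pixels)
real_k, real_v = torch.randn(5, d), torch.randn(5, d)        # 5 real tokens
pad_k = torch.randn(1, d).expand(20, d)   # 20 copies of one PAD embedding
pad_v = torch.randn(1, d).expand(20, d)


def attend(q, k, v, keep=None):
    """Scaled dot-product attention; `keep` is a bool mask over the keys."""
    scores = q @ k.T / d ** 0.5
    if keep is not None:
        scores = scores.masked_fill(~keep, float("-inf"))
    return torch.softmax(scores, dim=-1) @ v


k_padded = torch.cat([real_k, pad_k])
v_padded = torch.cat([real_v, pad_v])
keep = torch.cat([torch.ones(5, dtype=torch.bool), torch.zeros(20, dtype=torch.bool)])

short = attend(q, real_k, real_v)
padded_masked = attend(q, k_padded, v_padded, keep)
padded_unmasked = attend(q, k_padded, v_padded)

print(torch.allclose(short, padded_masked))    # True: padding has no effect
print(torch.allclose(short, padded_unmasked))  # False: PAD keys take attention
```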

Update: it seems that proper masks can currently be supported by writing a custom CrossAttnProcessor and passing cross_attention_kwargs = { 'attention_mask_': my_cool_mask }. Using the default name attention_mask raises errors.
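One plausible reason the name attention_mask fails: if the enclosing block already passes attention_mask=... explicitly and also unpacks **cross_attention_kwargs into the same call, an attention_mask key inside cross_attention_kwargs becomes a duplicate keyword argument. The classes and functions below are hypothetical toys modelled on that calling pattern, not diffusers code.

```python
class ToyAttention:
    """Hypothetical stand-in for an attention layer that forwards extra keyword
    arguments (cross_attention_kwargs) straight to its processor."""

    def __init__(self, processor):
        self.processor = processor

    def __call__(self, hidden_states, attention_mask=None, **cross_attention_kwargs):
        return self.processor(self, hidden_states, attention_mask=attention_mask,
                              **cross_attention_kwargs)


def toy_processor(attn, hidden_states, attention_mask=None, attention_mask_=None):
    # Return whatever arrived under the custom key, to show it got through.
    return attention_mask_


def toy_transformer_block(layer, hidden_states, cross_attention_kwargs):
    # Passes its own attention_mask explicitly AND unpacks the user's kwargs.
    return layer(hidden_states, attention_mask=None, **cross_attention_kwargs)


layer = ToyAttention(toy_processor)

# A distinct key name ("attention_mask_") reaches the processor untouched.
received = toy_transformer_block(layer, "h", {"attention_mask_": "my_cool_mask"})

# Reusing the name "attention_mask" duplicates a keyword argument.
try:
    toy_transformer_block(layer, "h", {"attention_mask": "my_cool_mask"})
    collided = False
except TypeError:
    collided = True
```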

@patil-suraj
Contributor

It indeed makes sense to pass an attention mask to cross-attention. But the Stable Diffusion model was trained without masks, and to keep the implementation 1:1 with the original implementation we don't pass an attention mask. It does make sense to allow passing masks optionally, though.

@Lime-Cakes
Contributor

Having the option to use it would be great! It should be off by default, to produce the same training/inference results as the original implementation.

A lot of people are experimenting with fine-tuning on more tokens, which means more padding is used. A few test runs suggest that feeding too many padded tokens into the UNet during training lowers quality. Personally, I think masking at the UNet's cross-attention should be a solution to this issue.

@bonlime
Contributor

bonlime commented Feb 8, 2023

I ended up using the following implementation, which works nicely. Maybe this could make it upstream?
Currently this can only be implemented for the default CrossAttention, because xformers does not support passing a custom attention bias (but I think they are working on it).

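import torch
from diffusers.models.cross_attention import CrossAttention  # diffusers ~0.12 module path


class FixedCrossAttnProcessor:
    """Copy-paste from HF diffusers, but with support for passing `encoder_attention_mask`,
    which avoids giving attention to padded tokens.

    `encoder_attention_mask` is expected as a (batch, 77) tensor with 1 for real tokens
    and 0 for padding (e.g. the tokenizer's attention_mask). It is only applied in
    cross-attention layers, i.e. when `encoder_hidden_states` is given."""

    def __call__(
        self,
        attn: CrossAttention,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        encoder_attention_mask=None,
    ):
        batch_size, sequence_length, _ = hidden_states.shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
        query = attn.to_q(hidden_states)

        if encoder_attention_mask is not None and encoder_hidden_states is not None:
            # get_attention_scores adds the mask to the scores (via baddbmm), so turn the
            # 1/0 keep-mask into an additive bias: 0 to keep, very negative to drop.
            bias = (1 - encoder_attention_mask.to(query.dtype)) * torch.finfo(query.dtype).min
            # (B, 77) -> (B, query_len, 77), e.g. query_len = 4096 for a 64x64 latent
            attention_mask = bias.unsqueeze(1).repeat(1, hidden_states.size(1), 1)
            # one copy per head, matching the (B * heads) layout of head_to_batch_dim
            attention_mask = attention_mask.repeat_interleave(attn.heads, dim=0)

        if encoder_hidden_states is None:
            # self-attention: keys/values come from the image features themselves
            encoder_hidden_states = hidden_states
        elif attn.cross_attention_norm:
            encoder_hidden_states = attn.norm_cross(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        # split heads into the batch dimension: (B, S, heads * d) -> (B * heads, S, d)
        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        return hidden_states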

Then install it on the UNet and call it like this:

unet.set_attn_processor(FixedCrossAttnProcessor())
noise_pred = unet(sample, timestep, encoder_hidden_states, cross_attention_kwargs=dict(encoder_attention_mask=mask))
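A standalone sketch of the transformation a (batch, 77) padding mask goes through before it reaches the attention scores, assuming (as in diffusers' get_attention_scores at the time) that the mask is added to the scores as a bias via torch.baddbmm. The helper keep_mask_to_bias is hypothetical; query_len is kept small here so the example stays light.

```python
import torch


def keep_mask_to_bias(keep_mask, heads, query_len, dtype=torch.float32):
    """Convert a (batch, key_len) keep-mask (1 = real token, 0 = PAD) into a
    (batch * heads, query_len, key_len) additive attention bias."""
    keep = keep_mask.to(torch.bool)
    bias = torch.zeros(keep.shape, dtype=dtype)
    bias = bias.masked_fill(~keep, torch.finfo(dtype).min)  # effectively "minus infinity"
    bias = bias.unsqueeze(1).repeat(1, query_len, 1)          # (B, Q, K)
    return bias.repeat_interleave(heads, dim=0)               # (B * heads, Q, K)


mask = torch.tensor([[1] * 5 + [0] * 72, [1] * 77])  # e.g. a tokenizer's attention_mask
bias = keep_mask_to_bias(mask, heads=8, query_len=64)
print(bias.shape)  # torch.Size([16, 64, 77])
```

Using torch.finfo(dtype).min instead of -inf avoids NaNs from softmax if a row ever ends up fully masked.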

@github-actions

github-actions bot commented Mar 4, 2023

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@lijuncheng16

@Birch-san
In our experiment, we found that using the attention mask results in inferior generation compared to the unmasked version, and here's our potential explanation: https://github.com/audiojourney/audiojourney.github.io/blob/main/neurIPS_2023_appendix_v1.3.pdf

6 participants