
Refactor cross attention and allow mechanism to tweak cross attention function #1639

Merged
patrickvonplaten merged 23 commits into main from allow_special_cross_attention on Dec 20, 2022

Conversation

@patrickvonplaten (Contributor) commented Dec 9, 2022

We are thinking about how to best support methods that tweak the cross attention computation, such as hyper networks (where linear layers that map k-> k' and v-> v' are trained), prompt-to-prompt, and other customized cross attention mechanisms.

Supporting such features poses a challenge in that we need to allow the user to "hack" into the cross attention module which is "buried" inside the unet.

We make two assumptions here:

  1. The cross attention layer is always the connection between the conditioning (text encoder or image encoder) and the unet, so we can limit the scope of what users are allowed to hack to this connection. Therefore, we can expect any such cross attention method to stay within an API that takes (hidden_states, context_embeddings) and returns hidden_states again:
     cross_attention: (hidden_states, context_embeddings) -> hidden_states
  2. We also assume that such hacking at inference time only makes sense if all previously trained weights stay the same and if all previously trained weights are used. This means that we don't allow overwriting existing weights and instead just give the user access to the existing weights:
     cross_attention_fn: (hidden_states, query_weight, key_weight, value_weight, context_embeddings) -> hidden_states

Therefore, a nice API that is both reasonably clean and flexible is to let the user write "CrossAttentionProcessor" classes that are by default weight-less and receive (query_weight, key_weight, value_weight) as inputs, as you can see in this PR.
I also used this new design to refactor the cross attention layer a bit and to turn xformers, sliced attention, and normal attention into different "processor" classes.
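As a rough, illustrative sketch (not the exact code in this PR; the class and helper names here are hypothetical), a default weight-less processor could look roughly like this:

import math
import torch

class DefaultCrossAttnProc:
    # Illustrative default processor: stateless, receives the projection layers
    # from the attention module and computes plain multi-head attention.
    def __init__(self, head_size, upcast_attention=False):
        self.head_size = head_size
        self.upcast_attention = upcast_attention

    def head_to_batch_dim(self, tensor):
        batch_size, seq_len, dim = tensor.shape
        tensor = tensor.reshape(batch_size, seq_len, self.head_size, dim // self.head_size)
        return tensor.permute(0, 2, 1, 3).reshape(batch_size * self.head_size, seq_len, dim // self.head_size)

    def batch_to_head_dim(self, tensor):
        batch_heads, seq_len, dim = tensor.shape
        batch_size = batch_heads // self.head_size
        tensor = tensor.reshape(batch_size, self.head_size, seq_len, dim)
        return tensor.permute(0, 2, 1, 3).reshape(batch_size, seq_len, dim * self.head_size)

    def __call__(self, hidden_states, query_proj, key_proj, value_proj, context=None):
        context = context if context is not None else hidden_states
        query = self.head_to_batch_dim(query_proj(hidden_states))
        key = self.head_to_batch_dim(key_proj(context))
        value = self.head_to_batch_dim(value_proj(context))

        # standard attention: softmax(Q K^T / sqrt(d)) V
        scale = 1 / math.sqrt(query.shape[-1])
        attention_probs = (torch.bmm(query, key.transpose(-1, -2)) * scale).softmax(dim=-1)
        hidden_states = torch.bmm(attention_probs, value)
        return self.batch_to_head_dim(hidden_states)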

Now, let's imagine one would like to support prompt-to-prompt. In this case, one should be able to do the following:

Note this is pseudo code:

import torch

from diffusers import CrossAttentionProcMixin

unet = ...  # load unet


class P2PCrossAttentionProc(CrossAttentionProcMixin):

    def __init__(self, head_size, upcast_attention, attn_maps_reweight):
        super().__init__(head_size=head_size, upcast_attention=upcast_attention)
        self.attn_maps_reweight = attn_maps_reweight

    def __call__(self, hidden_states, query_proj, key_proj, value_proj, encoder_hidden_states, modified_text_embeddings):
        batch_size, sequence_length, _ = hidden_states.shape
        query = query_proj(hidden_states)
        query = self.head_to_batch_dim(query, self.head_size)

        attention_probs = []
        original_text_embeddings = encoder_hidden_states
        for context in [original_text_embeddings, modified_text_embeddings]:
            key = key_proj(context)
            value = value_proj(context)

            key = self.head_to_batch_dim(key, self.head_size)
            value = self.head_to_batch_dim(value, self.head_size)

            attention_probs.append(self.get_attention_scores(query, key))

        # blend the attention maps of the original and the modified prompt
        merged_probs = (1 - self.attn_maps_reweight) * attention_probs[0] + self.attn_maps_reweight * attention_probs[1]
        hidden_states = torch.bmm(merged_probs, value)
        hidden_states = self.batch_to_head_dim(hidden_states)
        return hidden_states


proc = P2PCrossAttentionProc(unet.config.head_size, unet.config.upcast_attention, 0.6)

unet.set_cross_attention_processor(proc)

out = unet(
    sample=sample,
    t=t,
    encoder_hidden_states=orig_text_embeddings,
    cross_attention_kwargs={"modified_text_embeddings": modified_text_embeddings},
)

I think this design is flexible enough to allow for more complicated use cases and can also be used for training hyper-networks!

Another important point here is mutability. If I pass the same class instance to multiple layers, the class is mutable, which might be desired. However, it might also be desirable to pass immutable or different processor classes to different cross attention layers. In this case, we could simply allow setting a list of processor objects instead of just one object.
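For illustration, a per-layer variant could look like the following sketch; letting set_cross_attention_processor accept a mapping from layer name to processor is an assumption here, not something implemented in this PR:

# one shared (mutable) processor instance used by every cross attention layer ...
shared_proc = P2PCrossAttentionProc(unet.config.head_size, unet.config.upcast_attention, 0.6)
unet.set_cross_attention_processor(shared_proc)

# ... or independent processor instances, one per cross attention layer (assumed dict-based overload)
per_layer_procs = {
    name: P2PCrossAttentionProc(unet.config.head_size, unet.config.upcast_attention, 0.6)
    for name, module in unet.named_modules()
    if module.__class__.__name__ == "CrossAttention"
}
unet.set_cross_attention_processor(per_layer_procs)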

@patrickvonplaten patrickvonplaten changed the title first proposal Refactor cross attention and allow mechanism to tweak cross attention function Dec 9, 2022
@patrickvonplaten (Contributor, Author):

Very common processor classes, such as prompt-to-prompt, could also be natively added to diffusers in a new pipeline that defines a whole attention processor class. Prompt-to-prompt would be a good example.

@keturn (Contributor) commented Dec 9, 2022

Looks promising! I think separating out xformers and sliced attention like this helps readers as well as any code-tracers, as it no longer has those if-branches in the forward method.

I'll leave it to damian to comment on whether the API is sufficient for the types of applications he has in mind. But the thing I'd be wondering about is: What if I want to do my application-specific thing and also take advantage of the library's efficient implementations of attention?

Maybe that question is too vague without a concrete example, because I know not all "application-specific things" are necessarily going to be compatible with every attention implementation.

An illustrative example might be making visualizations of the attention maps, as we see in the Prompt-to-Prompt paper:
[image: squirrel attention maps]
or in StructureDiffusion:
[image: StructureDiffusion attention map]

The visualization is not something that wants to control attention. It still wants to use the most efficient implementation available.

Disclosure: attention map visualizations are literally a feature InvokeAI wants to implement (as does https://github.com/JoaoLages/diffusers-interpret), but I think the interpretability they provide is of general interest. I'm not just trying to trick you into writing InvokeAI-specific features as examples. 😉

@damian0815 (Contributor):

Therefore, a nice API that is both reasonably clean and flexible is to let the user write "CrossAttentionProcessor" classes that are by default weight-less and receive (query_weight, key_weight, value_weight) as inputs, as you can see in this PR.

This is good! However, as keturn says:

The visualization is not something that wants to control attention. It still wants to use the most efficient implementation available.

It would indeed be best if a custom CrossAttentionProcessor could optionally call parts of the default CrossAttentionProcessor's functionality (exposed in a modular way) to get the benefits of whatever clever optimisations Hugging Face put in there. Of course, we can always just copy/paste the existing code, but that leaves us downstream users with the burden of maintaining compatibility as the diffusers code changes upstream.

@damian0815 (Contributor) commented Dec 10, 2022

if all previously trained weights stay the same and if all previously trained weights are used

@patrickvonplaten one thing we are looking at using is LoRA, which trains "the residual" of the weights to apparently produce dreambooth-quality training in 3-4MB of shipped data:

... not all of the parameters need tuning: they found that often, Q, K, V, O (i.e., attention layer) of the transformer model is enough to tune. (This is also the reason why the end result is so small). This repo will follow the same idea.

from https://github.com/cloneofsimo/lora/blob/master/scripts/run_inference.ipynb, which monkey-patches the default Diffusers stable diffusion pipeline pipe:
[image: screenshot of the LoRA monkey-patching code from the notebook]

I can't claim to understand the math well enough to know exactly what that means; I just thought I should flag it as one use case we are looking into.
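For reference, the basic idea seems to be roughly this (a minimal, illustrative sketch, not the code from that repo): the pretrained projection stays frozen and a trainable low-rank residual is added on top of it.

import torch.nn as nn

class LoRALinear(nn.Module):
    # base_linear stays frozen; only the low-rank residual (down/up) is trained.
    def __init__(self, base_linear: nn.Linear, rank: int = 4, scale: float = 1.0):
        super().__init__()
        self.base = base_linear.requires_grad_(False)
        self.down = nn.Linear(base_linear.in_features, rank, bias=False)
        self.up = nn.Linear(rank, base_linear.out_features, bias=False)
        nn.init.zeros_(self.up.weight)  # start as an exact no-op
        self.scale = scale

    def forward(self, x):
        return self.base(x) + self.scale * self.up(self.down(x))

# e.g. wrap the attention projections: attn.to_q = LoRALinear(attn.to_q), and likewise for to_k, to_v, to_out[0]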

@patrickvonplaten (Contributor, Author):

@anton-l @pcuenca @williamberman could you give this a review?

@pcuenca (Member) commented Dec 11, 2022

Looks good! I think it's very compelling to model the existing xFormers and sliced attention optimizations as just instances of the new "cross-attention processor" class. I also think it would still be useful for some downstream applications to take advantage of them when possible. I've been thinking about how to achieve that, but unfortunately I currently don't see a clear path forward that is not overly complicated.

I'd start with this proposed solution to support a whole new type of applications, and then maybe we can later expand to "observer-only" callbacks that can be added independently of processors and can coexist with them. This way we could also support visualizers or loggers as an additional family of components.

In terms of the API, I like it in general and it looks clear to me. I wonder if it'd make sense to do the following instead of passing all the arguments to __call__:

  • Keep batch_to_head_dim, head_to_batch_dim in CrossAttention.
  • Implement get_attention_scores in CrossAttention too.
  • Pass the CrossAttention instance to __call__ as something like xattn, so the processor can use it like:
key = xattn.to_k(original_text_embeddings)
value = xattn.to_v(original_text_embeddings)
# etc
attention_probs.append(xattn.get_attention_scores(query, key))

This requires minimal changes to the current class, users don't have to look into another file for those implementations, and processors have information about other details they might potentially need from CrossAttention. The mixin is just doing stuff the CrossAttention module used to do.
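To make this concrete, here is a rough sketch of an "observer" processor under this suggestion (hypothetical names; it skips the attention mask and the output projection for brevity). It records the attention maps for later visualization while otherwise keeping the standard computation, using the CrossAttention instance passed in as xattn:

import torch

class AttentionMapRecorder:
    def __init__(self):
        self.saved_maps = []

    def __call__(self, xattn, hidden_states, encoder_hidden_states=None):
        context = encoder_hidden_states if encoder_hidden_states is not None else hidden_states

        query = xattn.head_to_batch_dim(xattn.to_q(hidden_states))
        key = xattn.head_to_batch_dim(xattn.to_k(context))
        value = xattn.head_to_batch_dim(xattn.to_v(context))

        attention_probs = xattn.get_attention_scores(query, key)
        self.saved_maps.append(attention_probs.detach().cpu())  # keep for visualization

        hidden_states = torch.bmm(attention_probs, value)
        return xattn.batch_to_head_dim(hidden_states)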

@anton-l (Member) left a comment:

The design wraps the existing attention variants very nicely! This way the number of set_*_attention() functions won't grow, and we won't have too much branching in attn.forward().
The issue of mutability could be solved by overriding __setattr__ to catch modifications to Parameters, but I'm not sure what the advanced use cases are.

Comment on lines 564 to 568
def set_cross_attn_proc(self, attn_proc: CrossAttentionProcMixin):
    if not isinstance(attn_proc, CrossAttentionProcMixin):
        subclass = attn_proc.__bases__ if hasattr(attn_proc, "__bases__") else None
        raise ValueError(f"`attn_proc` should be a subclass of {CrossAttentionProc}, but is of type {type(attn_proc)} and a subclass of {subclass}.")
    self.attn_proc = attn_proc

Review comment (Contributor):

Bit of a Python anti-pattern here to manually check that the base classes include the mixin, no?

Since the transformer class doesn't use any of the mixin internals, maybe just change the type signature to Callable and remove the manual type check?
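Something along these lines (a sketch of the suggestion):

from typing import Callable

def set_cross_attn_proc(self, attn_proc: Callable):
    self.attn_proc = attn_proc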

@williamberman (Contributor):

Love this

The mutability point is interesting and does seem like it might run into some trouble down the road, but given that none of the existing use cases rely on mutable processors, I would be OK with moving forward with this as is.

@DavidePaglieri:

Hi, this change would be super useful!

I have a question though: is it possible to tweak the attention map as the Prompt-to-Prompt paper suggests while also using memory-efficient attention at the same time?

From what I understand, prompt-to-prompt works by multiplying the attention map with the modified token weights, but memory-efficient attention changes the formula and doesn't explicitly calculate the attention map in the same way. Am I understanding it wrong? Is it possible to do both things together?

@patrickvonplaten (Contributor, Author):

if all previously trained weights stay the same and if all previously trained weights are used

@patrickvonplaten one thing we are looking at using is LoRA, which trains "the residual" of the weights to apparently produce dreambooth-quality training in 3-4MB of shipped data:

... not all of the parameters need tuning: they found that often, Q, K, V, O (i.e., attention layer) of the transformer model is enough to tune. (This is also the reason why the end result is so small). This repo will follow the same idea.

from https://github.com/cloneofsimo/lora/blob/master/scripts/run_inference.ipynb, which monkey-patches the default Diffusers stable diffusion pipeline pipe: [image: screenshot of the monkey-patching code]

I can't claim to understand the math well enough to know exactly what that means; I just thought I should flag it as one use case we are looking into.

That's a very good point! LoRA indeed looks very promising, and it seems to adapt all linear layers of the CrossAttention module (see here: https://github.com/cloneofsimo/lora/blob/26787a09bff4ebcb08f0ad4e848b67bce4389a7a/lora_diffusion/lora.py#L177), so maybe we should allow plugging in the whole CrossAttention class right away?

@patrickvonplaten (Contributor, Author):

Hi, this change would be super useful!

I have a question though: is it possible to tweak the attention map as the Prompt-to-Prompt paper suggests while also using memory-efficient attention at the same time?

From what I understand, prompt-to-prompt works by multiplying the attention map with the modified token weights, but memory-efficient attention changes the formula and doesn't explicitly calculate the attention map in the same way. Am I understanding it wrong? Is it possible to do both things together?

The problem with memory-efficient attention here is that the whole attention operation is done by xformers' highly optimized attention function, which doesn't expose internals such as the QK^T maps.
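To illustrate the difference (a sketch; it assumes xformers is installed and the tensors are already split per head and flattened to (batch * heads, seq, head_dim)):

import torch
import xformers.ops

def explicit_attention(query, key, value):
    scale = query.shape[-1] ** -0.5
    # the attention map below is exactly what prompt-to-prompt wants to inspect/edit
    attention_probs = (torch.bmm(query, key.transpose(-1, -2)) * scale).softmax(dim=-1)
    return torch.bmm(attention_probs, value)

def fused_attention(query, key, value):
    # single fused kernel; the intermediate QK^T map is never materialized or exposed
    return xformers.ops.memory_efficient_attention(query, key, value)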

tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
return tensor
self.attn_proc = CrossAttentionProc(self.heads, self.upcast_attention)

@patrickvonplaten (Contributor, Author) commented on the lines above:

We could also add this directly in the BasicTransformerBlock so that the user has to go one less level deep.

self.attn_proc = attn_proc

def forward(self, hidden_states, context=None, cross_attention_kwargs=None):
# attn

@patrickvonplaten (Contributor, Author) commented on the lines above:

The whole forward function should probably become the attn_proc so that LoRA can be nicely supported as well.


class CrossAttentionProc(CrossAttentionProcMixin):

    def __call__(self, hidden_states, query_proj, key_proj, value_proj, context=None):

@patrickvonplaten (Contributor, Author) commented on the lines above:

Pass all weights. Note that in order to support LoRA we need to add two more layers (proj and dropout).

Comment on lines 32 to 33
def __call__(self, hidden_states, query_proj, key_proj, value_proj, context=None):
    raise NotImplementedError("Make sure this method is overwritten in the subclass.")

Review comment (Contributor):

Consider @abc.abstractmethod for template methods that must be overridden.
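For example (a sketch of the suggestion):

from abc import ABC, abstractmethod

class CrossAttentionProcMixin(ABC):
    @abstractmethod
    def __call__(self, hidden_states, query_proj, key_proj, value_proj, context=None):
        ...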

@patrickvonplaten patrickvonplaten merged commit 4125756 into main Dec 20, 2022
@@ -391,6 +391,63 @@ def check_slicable_dim_attr(module: torch.nn.Module):
for module in model.children():
    check_slicable_dim_attr(module)

def test_special_attn_proc(self):

@patrickvonplaten (Contributor, Author) commented on the line above:

Example of how to use the class.

@patil-suraj patil-suraj deleted the allow_special_cross_attention branch December 20, 2022 21:48
sliard pushed a commit to sliard/diffusers that referenced this pull request on Dec 21, 2022
@bonlime (Contributor) commented Dec 23, 2022

I'm pretty late to this discussion, but there is an implementation of Prompt-to-Prompt that supports using xformers: https://github.com/cccntu/efficient-prompt-to-prompt
tl;dr: it can be achieved by passing both prompts simultaneously as key/value in CrossAttention. Previously this required applying a patch to diffusers, but with the new design it can be done by writing a custom Processor class, which is better.
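A very rough sketch of that idea as I understand it (assuming the final processor signature with attn, an extra edited_encoder_hidden_states passed via cross_attention_kwargs, and xformers installed):

import torch
import xformers.ops

class DualPromptAttnProcessor:
    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None,
                 edited_encoder_hidden_states=None):
        context = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
        if edited_encoder_hidden_states is not None:
            # both prompts contribute keys/values in a single fused attention call
            context = torch.cat([context, edited_encoder_hidden_states], dim=1)

        query = attn.head_to_batch_dim(attn.to_q(hidden_states)).contiguous()
        key = attn.head_to_batch_dim(attn.to_k(context)).contiguous()
        value = attn.head_to_batch_dim(attn.to_v(context)).contiguous()

        hidden_states = xformers.ops.memory_efficient_attention(query, key, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        hidden_states = attn.to_out[0](hidden_states)  # output projection
        hidden_states = attn.to_out[1](hidden_states)  # dropout
        return hidden_states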

Another +1 to @patrickvonplaten for a good PR and design.

@patil-suraj (Contributor):

Thanks a lot for sharing this @bonlime, cc @kashif

@kashif (Contributor) commented Dec 23, 2022

Thanks @bonlime, yes, I believe that using this technique we can write a simple Processor which takes a tuple for the encoder_hidden_states and passes it to the appropriate projections... let me try that out.

@hafriedlander (Contributor):

Just noticed one small backwards-incompatible API change here, I think: before this change, enabling xformers would override slicing, so you could safely enable both in either order and xformers would always be used.

Now, if you enable xformers and then slicing, slicing takes precedence, so the order in which they are enabled becomes important.
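For concreteness, the two orders (using the pipeline-level helpers):

pipe.enable_attention_slicing()
pipe.enable_xformers_memory_efficient_attention()
# xformers was enabled last -> xformers is used (same outcome as before)

pipe.enable_xformers_memory_efficient_attention()
pipe.enable_attention_slicing()
# before this PR: xformers still won; now the sliced-attention processor set last takes precedence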

@patil-suraj (Contributor):

Good catch @hafriedlander! Yeah, slicing was ignored if xformers was enabled before this change, but I think silently ignoring it is not good. Maybe we could log/warn when the attention method is being overridden. That said, IMO the current API is good: if you enable sliced attention, then sliced attention will be used rather than silently ignored.
cc @patrickvonplaten @pcuenca @anton-l

@pcuenca (Member) commented Dec 26, 2022

I agree with @patil-suraj. We could maybe clarify the behaviour in the docstrings?

@kashif (Contributor) commented Dec 26, 2022

Another issue with regard to sliced attention: I believe it is currently not possible to set a different sliced-attention processor in a pipeline, since that requires the slice_size, and the set_attention_slice helpers that set the slice size default to installing the standard processors...

@andreaferretti:

Hi all, it's great that this PR landed in 0.12, so that we can experiment with various attention modification techniques. What exactly is the API for this feature, though? The discussion started with a proposal, but I am not sure what the final API looks like.

Here is an example of implementing Attend-and-Excite as an attention processor, but some points are still a little obscure to me, for instance:

  • What exactly is the API for these controllers? How do they get registered? I see that @evinpinar defines their own register_attention_control function.
  • The example here uses model.set_attn_processor instead.
  • What about the attention store that @evinpinar is using? Again, it seems that they are implementing their own; is it something provided by diffusers?

In short, even a few lines on how these classes are expected to be defined/used would go a long way!

@evinpinar (Contributor) commented Feb 7, 2023

Hi @andreaferretti,

Here is an example of implementing Attend-and-Excite as an attention processor, but some points are still a little obscure to me, for instance:

* What exactly is the API for these controllers? How do they get registered? I see that @evinpinar defines their own `register_attention_control` function

* the example [here](https://github.com/huggingface/diffusers/pull/1639/files/9d5e5ca9b18b55d864d931a4b6199c99065c5cd7#diff-44dfc935910f5504cfa2bb02e5a4313cfdc061a6131b47f93b31b7d41422fd25) uses `model.set_attn_processor` instead

In the sample I've provided for the Attend-and-Excite paper, we need to set a processor only on specific CrossAttention layers. I also use model.set_attn_processor within register_attention_control; see here.

* What about the attention store that @evinpinar is using? Again, it seems that they are implementing their own; is it something provided by diffusers?

The attention store is an accumulator of the attention probabilities at specific cross-attention layers with specific resolutions after a forward pass through the UNet. The attention probabilities then get optimized, depending on each other's values. As far as I understand, it is not possible to do this optimization without such an accumulator/optimization control. The diffusers library enables tweaking the attention functionality and accessing the values, but I'm not sure we can achieve the optimization without this additional store API.
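A rough sketch of such a store-backed processor (illustrative names, not the exact Attend-and-Excite code):

import torch

class AttnProcessorWithStore:
    def __init__(self, store: list, store_key: str):
        self.store = store          # shared list that accumulates maps across layers
        self.store_key = store_key  # identifies which layer the maps came from

    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None):
        batch_size, sequence_length, _ = hidden_states.shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)

        query = attn.head_to_batch_dim(attn.to_q(hidden_states))
        context = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
        key = attn.head_to_batch_dim(attn.to_k(context))
        value = attn.head_to_batch_dim(attn.to_v(context))

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        self.store.append((self.store_key, attention_probs))  # accumulate for later optimization

        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)
        hidden_states = attn.to_out[0](hidden_states)  # linear projection
        hidden_states = attn.to_out[1](hidden_states)  # dropout
        return hidden_states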

In short, even a few lines on how these classes are expected to be defined/used would go a long way!

I agree, it would be very useful!

@damian0815 (Contributor) commented Feb 7, 2023

@andreaferretti I've successfully used the AttnProcessor API in InvokeAI; see, for example, SlicedSwapCrossAttnProcesser, which gets used like this.

@andreaferretti commented Feb 7, 2023

Hi @evinpinar, thank you for your prompt response!

So, one thing I gather from your example is that set_attn_processor accepts a dictionary mapping layer names to processors, and will use the corresponding processor on that specific layer, right? The other example here just calls

processor = AttnEasyProc(5.0)
model.set_attn_processor(processor)

which I can only assume will call the same processor on every layer. Are there any other overloads of set_attn_processor? Anyway, these two should be enough!
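So, as I understand it, the two forms would be something like this (the model.attn_processors property returning the current name-to-processor mapping is my assumption):

proc = AttnEasyProc(5.0)
model.set_attn_processor(proc)  # the same instance is used for every attention layer

model.set_attn_processor({
    name: AttnEasyProc(5.0) for name in model.attn_processors.keys()  # assumed property
})  # one independent processor per layer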

Another peculiarity of the API that I gather from the above example is that one can pass extra kwargs as a dictionary, like

model(**inputs_dict, cross_attention_kwargs={"number": 123}).sample

and said kwargs will apparently be passed when calling the processor; in fact, the signature of __call__ there is

def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, number=None):

The last piece of the puzzle I am not sure about is to what extent one needs to reproduce the "normal" attention mechanism in processors. That is, even the simple processor in the example contains the usual attention computation:

def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None):
    batch_size, sequence_length, _ = hidden_states.shape
    attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)

    query = attn.to_q(hidden_states)

    encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
    key = attn.to_k(encoder_hidden_states)
    value = attn.to_v(encoder_hidden_states)

    query = attn.head_to_batch_dim(query)
    key = attn.head_to_batch_dim(key)
    value = attn.head_to_batch_dim(value)

    attention_probs = attn.get_attention_scores(query, key, attention_mask)
    hidden_states = torch.bmm(attention_probs, value)
    hidden_states = attn.batch_to_head_dim(hidden_states)

    # linear proj
    hidden_states = attn.to_out[0](hidden_states)
    # dropout
    hidden_states = attn.to_out[1](hidden_states)

    return hidden_states

So I assume that the computation taking place in these processors will replace the default attention computation, instead of, say, augment it in some way. In other words, every processor will have to copy this first and then modify the flow to achieve whatever it needs to do, instead of getting the already computed attention maps and just having to possibly modify them.

I think this is the case, but I am just asking if there are gross misunderstandings.

Excuse me if my questions seem naive, but this monkey patching of attention maps is already delicate enough, and without some documentation on the exact API it is hard to get started.

@andreaferretti:

Hi @damian0815, thank you, it is very useful to have another example to learn from!

@patil-suraj (Contributor):

Hey @andreaferretti, those are really good questions. It would be awesome if you could open an issue. We definitely want to improve the documentation for this. cc @patrickvonplaten

@evinpinar (Contributor):

Hi @andreaferretti, the way I understand and use the processors aligns with your explanation, but I'm sure there can be many other exciting ways to use them.

yoonseokjin pushed a commit to yoonseokjin/diffusers that referenced this pull request on Dec 25, 2023