Refactor cross attention and allow mechanism to tweak cross attention function #1639
Conversation
Very common processor classes such as prompt-to-prompt could also be natively added to the library.
Looks promising! I think separating out xformers and sliced attention like this helps readers as well as any code-tracers, as it no longer has those if-branches in the forward method.

I'll leave it to damian to comment on whether the API is sufficient for the types of applications he has in mind. But the thing I'd be wondering about is: what if I want to do my application-specific thing and also take advantage of the library's efficient implementations of attention? Maybe that question is too vague without a concrete example, because I know not all "application-specific things" are necessarily going to be compatible with every attention implementation.

An illustrative example might be making visualizations of the attention maps, as we see in the Prompt-to-Prompt paper. The visualization is not something that wants to control attention; it still wants to use the most efficient implementation available.

Disclosure: attention map visualizations are literally a feature InvokeAI wants to implement (as does https://github.com/JoaoLages/diffusers-interpret), but I think the interpretability they provide is of general interest. I'm not just trying to trick you into writing InvokeAI-specific features as examples. 😉
This is good! However, as keturn says:
It would indeed be best if the CustomCrossAttentionProcessor could optionally call parts of the default CrossAttentionProcessor's functionality -- which would be exposed in a modular way -- to get the benefits of whatever clever optimisations HuggingFace put in there. Of course we can always just copy/paste the existing code, but this leaves us downstream users with the burden of maintaining compatibility as Diffusers code changes upstream.
@patrickvonplaten one thing we are looking at using is LoRA, which trains "the residual" of the weights to apparently produce dreambooth-quality training in 3-4MB of shipped data.

I can't claim to understand the math well enough to know exactly what that means; I just thought I should flag it as one use case we are looking into.
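As a rough illustration only (a minimal sketch of the general LoRA idea, not the linked implementation; the name LoRALinear and the defaults are placeholders): the original weight stays frozen and only a low-rank residual B @ A is trained, which is why the shipped data can be so small.

import torch.nn as nn

class LoRALinear(nn.Module):
    # Wrap a frozen nn.Linear and add a trainable low-rank residual (B @ A) to its output.
    def __init__(self, base: nn.Linear, rank: int = 4, scale: float = 1.0):
        super().__init__()
        self.base = base
        for p in self.base.parameters():
            p.requires_grad_(False)  # original weights stay frozen
        self.lora_down = nn.Linear(base.in_features, rank, bias=False)   # A
        self.lora_up = nn.Linear(rank, base.out_features, bias=False)    # B
        nn.init.zeros_(self.lora_up.weight)  # residual starts at zero, so behaviour is initially unchanged
        self.scale = scale

    def forward(self, x):
        return self.base(x) + self.scale * self.lora_up(self.lora_down(x))

Only the small lora_down/lora_up weights would need to be saved and shipped.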
@anton-l @pcuenca @williamberman could you give this a review?
Looks good! I think it's very compelling to model the existing xFormers and sliced attention optimizations as just instances of the new "cross-attention processor" class. I also think it would still be useful for some downstream applications to take advantage of them when possible. I've been thinking about how to achieve that, but unfortunately I currently don't see a clear path forward that is not overly complicated. I'd start with this proposed solution to support a whole new type of applications, and then maybe we can later expand to "observer-only" callbacks that can be added independently of processors and can coexist with them. This way we could also support visualizers or loggers as an additional family of components.

In terms of the API, I like it in general and it looks clear to me. I wonder if it'd make sense to do the following instead of passing all the arguments to the processor's __call__:
key = xattn.to_k(original_text_embeddings)
value = xattn.to_v(original_text_embeddings)
# etc
attention_probs.append(xattn.get_attention_scores(query, key))

This requires minimal changes to the current class, users don't have to look into another file for those implementations, and processors have information about other details they might potentially need from the attention module.
The design wraps the existing attention variants very nicely! This way the number of set_*_attention() functions won't grow, and we won't have too much branching in attn.forward().

The issue of mutability could be solved by overriding __setattr__ to catch modifications to Parameters, but I'm not sure what the advanced use cases are.
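A rough sketch of that idea (illustrative only, not code from this PR):

import torch

class ImmutableProcMixin:
    # Refuse to (re)bind nn.Parameters on a processor instance that may be shared across layers.
    def __setattr__(self, name, value):
        if isinstance(value, torch.nn.Parameter):
            raise AttributeError(f"cannot assign Parameter '{name}' to an immutable processor")
        super().__setattr__(name, value)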
src/diffusers/models/attention.py
Outdated
def set_cross_attn_proc(self, attn_proc: CrossAttentionProcMixin):
    if not isinstance(attn_proc, CrossAttentionProcMixin):
        subclass = attn_proc.__bases__ if hasattr(attn_proc, "__bases__") else None
        raise ValueError(f"`attn_proc` should be a subclass of {CrossAttentionProc}, but is of type {type(attn_proc)} and a subclass of {subclass}.")
    self.attn_proc = attn_proc
Bit of a Python anti-pattern here to manually check that the base classes include the mixin, no? Since the transformer class doesn't use any of the mixin internals, maybe just change the type signature to Callable and remove the manual type check?
Love this! The mutability point is interesting and does seem like it might run into some trouble down the road, but given that none of the existing use cases rely on mutable processors, I would be ok with just moving forward with this as is.
Hi, this change would be super useful! I have a question though: is it possible to tweak the attention map as the Prompt-to-Prompt paper suggests while also using memory-efficient attention at the same time? From what I understand, prompt-to-prompt works by multiplying the attention map with the modified token weights, but memory-efficient attention changes the formula and doesn't explicitly calculate the attention map in the same way. Am I understanding it wrong? Is it possible to do both things together?
That's a very good point! LoRA looks indeed very promising, and it seems to adapt all the linear layers of the CrossAttention module (see here: https://github.com/cloneofsimo/lora/blob/26787a09bff4ebcb08f0ad4e848b67bce4389a7a/lora_diffusion/lora.py#L177), so maybe we should allow plugging in the whole forward function.
The problem with memory-efficient attention here is that the whole attention operation is done by xformers' highly optimized attention function, which doesn't expose internals such as the QK^T maps.
src/diffusers/models/attention.py
Outdated
tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
return tensor

self.attn_proc = CrossAttentionProc(self.heads, self.upcast_attention)
We could also add this directly in the BasicTransformerBlock so that the user has to go one less level deep.
src/diffusers/models/attention.py
Outdated
self.attn_proc = attn_proc

def forward(self, hidden_states, context=None, cross_attention_kwargs=None):
    # attn
The whole forward function should probably become the attn_proc, so that LoRA can be nicely supported as well.
class CrossAttentionProc(CrossAttentionProcMixin):

    def __call__(self, hidden_states, query_proj, key_proj, value_proj, context=None):
Pass all weights - note that in order to support LoRA we need to add two more layers (proj and dropout).
def __call__(self, hidden_states, query_proj, key_proj, value_proj, context=None):
    raise NotImplementedError("Make sure this method is overwritten in the subclass.")
Consider @abc.abstractmethod for template methods that must be overridden.
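For example, roughly (a sketch, not the exact code in this PR):

import abc

class CrossAttentionProcMixin(abc.ABC):
    @abc.abstractmethod
    def __call__(self, hidden_states, query_proj, key_proj, value_proj, context=None):
        ...  # instantiating a subclass that does not override __call__ raises TypeError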
@@ -391,6 +391,63 @@ def check_slicable_dim_attr(module: torch.nn.Module):
    for module in model.children():
        check_slicable_dim_attr(module)
def test_special_attn_proc(self): |
example of how to use the class
I'm pretty late to this discussion, but there is an implementation of Prompt-to-Prompt which supports using xformers: https://github.com/cccntu/efficient-prompt-to-prompt. Another +1 to @patrickvonplaten for the good PR and design.
Thanks @bonlime, yes, I believe that using this technique we can make a simple Processor which takes a tuple for the
Just noticed one small backwards-incompatible API change here, I think: before this change, enabling xformers would override slicing, so you could safely do both in either order and xformers would always be used. Now, if you enable xformers and then slicing, slicing will take precedence, so the order of enabling becomes important.
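For illustration (the model id is just an example):

from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
pipe.enable_xformers_memory_efficient_attention()
pipe.enable_attention_slicing()  # with this change, this call now replaces the xformers processor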
Good catch @hafriedlander! Yeah, slicing was ignored if xformers was enabled.
I agree with @patil-suraj. We could maybe clarify the behaviour in the docstrings?
Another issue with regard to sliced attention is that it is, I believe, not currently possible to set a different sliced attention processor in a pipeline, since that requires the
Hi all, it's great that this PR landed in 0.12, so that we can experiment with various attention modification techniques. What is exactly the API for this feature, though? The discussion started with a proposal, but I am not sure what the final API looks like. Here there is an example of implementing Attend-and-Excite as an attention processor, but some points are still a little obscure to me, for instance:
In short, even a few lines on how these classes are expected to be defined/used would go a long way!
Hi @andreaferretti,
In the sample I've provided for the Attend-and-Excite paper, we need to set a processor only on specific CrossAttention layers. I also use the attention store there.
The attention store is an accumulator of the probabilities at specific cross-attention layers with specific resolutions after a forward pass on the UNet. The attention probabilities then get optimized, depending on the values of each other. As far as I understand, it is not possible to do this optimization without such an accumulator/optimization control. The diffusers library enables tweaking and changing the functionality of attention and accessing the values, yet I'm not sure if we can achieve the optimization without this additional store API.
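For instance, a minimal sketch of such an accumulator (illustrative only; StoreAttnProcessor is a placeholder name, not the actual Attend-and-Excite code) just wraps the default computation and records the probabilities:

import torch

class StoreAttnProcessor:
    def __init__(self, store):
        self.store = store  # e.g. a shared list collecting maps from the layers we care about

    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None):
        batch_size, sequence_length, _ = hidden_states.shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
        query = attn.head_to_batch_dim(attn.to_q(hidden_states))
        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
        key = attn.head_to_batch_dim(attn.to_k(encoder_hidden_states))
        value = attn.head_to_batch_dim(attn.to_v(encoder_hidden_states))
        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        self.store.append(attention_probs)  # accumulate the maps for use after the forward pass
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)
        hidden_states = attn.to_out[0](hidden_states)  # linear proj
        hidden_states = attn.to_out[1](hidden_states)  # dropout
        return hidden_states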
I agree, it would be very useful!
@andreaferretti I've successfully used the AttnProcessor API in InvokeAI - see for example SlicedSwapCrossAttnProcesser, which gets used like this.
Hi @evinpinar, thank you for your prompt response! So, one thing I gather from your example is that

processor = AttnEasyProc(5.0)
model.set_attn_processor(processor)

which I can only assume will call the same processor on every layer. Are there any other overloads of set_attn_processor?

Another peculiarity of the API that I gather from the above example is that one can pass extra kwargs as a dictionary, like

model(**inputs_dict, cross_attention_kwargs={"number": 123}).sample

and said kwargs will apparently be passed when calling the processor; in fact its signature is

def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, number=None):

The last piece of the puzzle I am not sure about is to what extent one needs to reproduce the "normal" attention mechanism in processors. That is, even the simple processor in the example has the usual attention computation:

def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None):
    batch_size, sequence_length, _ = hidden_states.shape
    attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
    query = attn.to_q(hidden_states)
    encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
    key = attn.to_k(encoder_hidden_states)
    value = attn.to_v(encoder_hidden_states)
    query = attn.head_to_batch_dim(query)
    key = attn.head_to_batch_dim(key)
    value = attn.head_to_batch_dim(value)
    attention_probs = attn.get_attention_scores(query, key, attention_mask)
    hidden_states = torch.bmm(attention_probs, value)
    hidden_states = attn.batch_to_head_dim(hidden_states)
    # linear proj
    hidden_states = attn.to_out[0](hidden_states)
    # dropout
    hidden_states = attn.to_out[1](hidden_states)
    return hidden_states

So I assume that the computation taking place in these processors will replace the default attention computation, instead of, say, augmenting it in some way. In other words, every processor will have to copy this first and then modify the flow to achieve whatever it needs to do, instead of getting the already computed attention maps and just having to possibly modify them. I think this is the case, but I am just asking if there are gross misunderstandings.

Excuse me if my questions seem naive, but this monkey patching of attention maps is already delicate enough, and without some documentation on the exact API it is hard to get started.
Hi @damian0815, thank you, it is very useful to have another example to learn from!
Hey @andreaferretti, those are really good questions. It would be awesome if you could open an issue. We definitely want to improve the documentation for this. cc @patrickvonplaten
Hi andreaferretti, the way I understood and use the processors aligns with your explanation. But I'm sure there can be many exciting ways to use the processors.
We are thinking about how to best support methods that tweak the cross attention computation, such as hypernetworks (where linear layers that map k -> k' and v -> v' are trained), prompt-to-prompt, and other customized cross attention mechanisms.
Supporting such features poses a challenge in that we need to allow the user to "hack" into the cross attention module, which is "buried" inside the UNet.
We make two assumptions here:
Therefore, a nice API that is both somewhat clean and flexible is to just let the user write "CrossAttentionProcessor" classes that are by default weight-less and take

(query_weight, key_weight, value_weight)

as input, as you can see in this PR. I also took this new design as an opportunity to refactor the cross attention layer a bit and to make xformers, sliced attention and normal attention different "processor" classes.
Now, let's imagine one would like to support prompt-to-prompt. In this case one should be able to do the following.
Note this is pseudo code:
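Roughly along these lines (an illustrative sketch only; Prompt2PromptProc, the token_weights argument, and the reweighting step are placeholders, and multi-head reshaping is omitted):

import torch

class Prompt2PromptProc:
    def __init__(self, token_weights):
        self.token_weights = token_weights  # per-token factors for the cross attention map

    def __call__(self, hidden_states, query_proj, key_proj, value_proj, context=None):
        context = context if context is not None else hidden_states
        query = query_proj(hidden_states)
        key = key_proj(context)
        value = value_proj(context)
        scale = query.shape[-1] ** -0.5
        attention_probs = (torch.bmm(query, key.transpose(-1, -2)) * scale).softmax(dim=-1)
        attention_probs = attention_probs * self.token_weights  # the prompt-to-prompt reweighting tweak
        return torch.bmm(attention_probs, value)

Such an instance could then be attached to the relevant cross attention layers via set_cross_attn_proc.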
I think this design is flexible enough to allow for more complicated use cases and can also be used for training hyper-networks!
Another important point here is mutability. If I pass the same class instance to multiple layers, the class is mutable, which might be desired. However, it might also be desirable to pass immutable or different proc classes to different cross attention layers. In this case, we could simply allow setting a list of processor objects instead of just one object.