Add clip skip for diffusion pipeline #3212
Comments
Hey @NormXU, you could also just do the following, no?

```python
# we skip one layer of the encoder
text_encoder = CLIPTextModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="text_encoder", num_hidden_layers=11, torch_dtype=torch.float16)
controlnet = ControlNetModel.from_pretrained(checkpoint, torch_dtype=torch.float16)
```

By loading the text encoder with only 11 layers you are skipping the final layer. |
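For completeness, a minimal sketch of how the truncated encoder from the snippet above could be passed into a standard pipeline; the `StableDiffusionPipeline` usage here is my assumption, since the comment itself only shows the encoder and ControlNet loading:

```python
import torch
from transformers import CLIPTextModel
from diffusers import StableDiffusionPipeline

# Load the CLIP text encoder with only 11 of its 12 layers, i.e. skip the final layer.
text_encoder = CLIPTextModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    subfolder="text_encoder",
    num_hidden_layers=11,
    torch_dtype=torch.float16,
)

# Pass the truncated encoder to the pipeline so prompts are encoded with it.
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    text_encoder=text_encoder,
    torch_dtype=torch.float16,
)
```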
Hey @patrickvonplaten, thank you for your quick reply. Setting num_hidden_layers while initializing the text encoder is a good choice. However, I think it will be easier to use if we can set `clip_skip` directly in the pipeline call:

```python
@torch.no_grad()
def __call__(
    self,
    ...
    clip_skip: Optional[int] = 1,
):
```

Otherwise, we need to initialize a text encoder every time we want to change the value of `clip_skip`. |
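If the proposed argument were adopted, usage might look like the following; this is purely hypothetical at this point in the discussion, and the exact semantics of the `clip_skip` value are whatever the pipeline defines:

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# The same text encoder is reused; only the per-call skip value changes,
# so there is no need to re-initialize the encoder between generations.
image_default = pipe("a castle on a hill", clip_skip=1).images[0]
image_skipped = pipe("a castle on a hill", clip_skip=2).images[0]
```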
@NormXU, can you explain a bit more when people would want to change `clip_skip`? |
@patrickvonplaten I usually change the `clip_skip` value. But you are right, in my use cases, it seems that it is not necessary to introduce a new parameter into the runtime. 😂

```python
CLIPTextModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="text_encoder", num_hidden_layers=11, torch_dtype=torch.float16)
```

is good and convenient enough to use. |
clip_skip 2 is generally used with models like Anything. Using clip_skip 2 has shown positive results during image generation compared to clip_skip 1 for anime-related models. However, most users tend to use either clip_skip 1 or 2 to ideate, and using `num_hidden_layers=11` should suffice. |
Any guidance on how to use this when loading models from .safetensors files? The `clip_model=clip_model` argument is being ignored. |
Where do you see a `clip_model` argument? |
Oh yeah, that seems to have worked. I've got a follow-up question on this, though. I'm running into an error when combining this with compel; I've disabled prompt truncation, but it still fails on this line: |
Hello. Given the recent introduction of new functionality, our existing approach may no longer be viable. Consider the implementation of `LoraLoaderMixin`:

```python
class LoraLoaderMixin:
    ...
    def load_lora_weights(...):
        ...
```

Should there be a mismatch between the keys of LoRA weights and the ...

```
File "/opt/conda/envs/backend-new/lib/python3.7/site-packages/diffusers/loaders.py", line 882, in load_lora_weights
    lora_scale=self.lora_scale,
File "/opt/conda/envs/backend-new/lib/python3.7/site-packages/diffusers/loaders.py", line 1148, in load_lora_into_text_encoder
    f"failed to load text encoder state dict, unexpected keys: {load_state_dict_results.unexpected_keys}"
ValueError: failed to load text encoder state dict, unexpected keys: ['text_model.encoder.layers.11.mlp.fc1.lora_linear_layer.down.weight', ...
```

Given this scenario, I propose two potential solutions:
|
How can I define `CLIPTextModel`? |
Hmm, is clip_skip really that important as a feature? I still haven't seen a use case where a model produces better results at inference with a clip_skip different from the one it was trained with. Would love to see some concrete examples. Also cc @sayakpaul and @yiyixuxu FYI |
Almost every SD1.5 anime model is based off the NAI model, which was trained for a huge number of steps at CLIP Skip 2. If you run them at CLIP Skip 1 (which AIUI is a bit of a misnomer since it doesn't skip any layers, but whatever), you'll often get weird glitches in fine details or strange fractals in the background; it's been the source of quite some annoyance with some merges/finetunes. A couple of not particularly great examples:

The second one is a better example; at CLIP skip 1 the model failed to pick up on the |
Is this specific to anime models only? How does this approach generalize to other kinds of models? The solution shown in the OP seems like the best approach, as we cannot really change things on the `transformers` end.

If we can see some diverse and varied examples where using CLIP Skip has been truly crucial, I think the original solution proposed in the OP could be directly incorporated. |
Do you have any reproducible code @neggles ? |
Sorry for the slow reply, have a lot of irons in the fire 😅
It's most effective for anime models, since they were largely trained using the penultimate layer states, but the same approach works just fine on base SD1.5 (albeit with rather mixed results) as well as on models like OpenJourney v4:
(This was generated using A1111, but the Diffusers outputs should be similar) It's also worth noting that SD2.1 was trained entirely using CLIP penultimate layer states (on an unfortunately undertrained text encoder, but that's beside the point).
My implementation of the solution OP proposed is used in my AnimateDiff fork - seems to work fine. It's almost a straight copy, but I tried to stick a little closer to the existing Diffusers code style. I just instantiate the pipeline by loading the ...

tl;dr: yeah, it's primarily an anime model thing, but it just comes down to how the model was trained, and an awful lot of models - even realism-focused ones - are based off anime models / trained with CLIP penultimate layer states. It's really just another tweakable knob 🤷 |
@neggles what about the other inference time parameters?
etc.? |
Very interested in this - I'm currently struggling to get this to work with single-file checkpoints. |
While this seems to work, the problem with this approach is that one has to guess how many layers there are, and as far as my knowledge of Python goes (not that much), I couldn't find a proper way to do so. |
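One way around the guessing, assuming the checkpoint follows the standard diffusers layout with a `text_encoder` subfolder: read the layer count from the encoder's config before reloading it with fewer layers. A sketch (the model ID and the webui-style `clip_skip` convention are assumptions):

```python
from transformers import CLIPTextConfig, CLIPTextModel

model_id = "runwayml/stable-diffusion-v1-5"  # example checkpoint

# The total number of transformer layers is stored on the config (12 for SD1.x CLIP).
config = CLIPTextConfig.from_pretrained(model_id, subfolder="text_encoder")
total_layers = config.num_hidden_layers

clip_skip = 2  # webui-style: 2 means "use the penultimate layer", i.e. drop one layer
text_encoder = CLIPTextModel.from_pretrained(
    model_id,
    subfolder="text_encoder",
    num_hidden_layers=total_layers - (clip_skip - 1),
)
```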
@sayakpaul see last line of code block - |
This feature does seem to be requested quite a bit now - design-wise we have three options:

```python
def set_clip_output_layer(output_layer_idx: int):
    pass
```

which would allow the user to do:

```python
pipe = DiffusionPipeline.from_pretrained("...")
pipe.set_clip_output_layer(...)
```

=> This would be a relatively simple PR where we only have to add the function to one SD pipeline and can then copy it to all other pipelines. Wdyt? @neggles @sayakpaul? |
@patrickvonplaten I appreciate the approach with the setter method, but I have a couple of suggestions that might further optimize this process.

I suggest we target ...

Secondly, it's crucial to ensure the robustness of the ... |
Good points! Note however that the |
@patrickvonplaten I think this feature is fun for play and valuable for research.

For playing purposes: @neggles has demonstrated how `clip_skip` affects the generated images.

For research purposes: SDXL has incorporated this trick into its architecture. It looks like they also noticed the problem with the text encoder and solved it by concatenating two penultimate text encoder outputs along the channel axis. This is also a 'clip_skip' trick. SDXL paper link

Notably, recent VLMs (vision-language models), such as LLaVA and BLIVA, also use this trick to align penultimate image features with the LLM, which they claim can give better results.
-- from BLIVA paper |
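For illustration, a rough sketch of the idea described above (penultimate-layer outputs of two text encoders concatenated along the channel axis); the function and variable names are placeholders rather than SDXL's actual implementation:

```python
import torch

def encode_with_two_text_encoders(text_encoder_1, text_encoder_2, input_ids_1, input_ids_2):
    # Take the penultimate hidden state from each encoder ("clip skip"-style early exit) ...
    h1 = text_encoder_1(input_ids_1, output_hidden_states=True).hidden_states[-2]
    h2 = text_encoder_2(input_ids_2, output_hidden_states=True).hidden_states[-2]
    # ... and concatenate them along the channel (feature) dimension.
    return torch.cat([h1, h2], dim=-1)
```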
Certainly, while |
This seems like a reasonable approach to me; it keeps from having to fiddle with Transformers directly and is straightforward to use/implement. It would be nicer to have it as an argument for ...

Come to think of it, this could also be quite handy for training use cases (wrapping the pipeline in a trainer class is a common way to approach that), depending on how it's implemented.

I had a whole thing I was going to drop in here summarizing the argument in favour from this end, but based on #4834 it seems like I don't need to bother 😆 |
I met a similar issue and my solution is something like this:

```python
from safetensors.torch import load_file

ckpt_path = '/path/to/ckpt.safetensors'
state_dict_lora = load_file(ckpt_path, device='cpu')

clip_skip = 2  # number of top text-encoder layers treated as skipped (assumed; not shown in the original snippet)

# Drop LoRA keys that target the skipped (topmost) text-encoder layers,
# so load_lora_weights no longer complains about unexpected keys.
new_state_dict_lora = {}
for k_, v_ in state_dict_lora.items():
    invalid_key = any(f'text_model_encoder_layers_{11 - layer_idx_}_' in k_ for layer_idx_ in range(clip_skip))
    if not invalid_key:
        new_state_dict_lora[k_] = v_

# pipe: an already-loaded diffusers pipeline
pipe.load_lora_weights(new_state_dict_lora)
```

It works well for the method |
@NormXU you can already do this like this:
|
The PR has been merged: #3870, please try it out. |
Example with clipskip 1 vs 2. |
Introduction

clip skip is a trick to feed the early-stopped features encoded by `CLIPTextModel` into the cross-attention. If `clip_skip = 2`, it means that we want to use the features from the layer before the last of the CLIP text encoder to guide our image generation. Our current diffusion pipeline can be regarded as `clip_skip = 1`, which means that we just use the features from the last layer of the CLIP text encoder.

Here is a brief introduction to clip skip: webui-wiki and related discussion link.
A large majority of models need `clip_skip=2` to reach a more aesthetic generation. I think adding this feature can give people more choices to optimize their generation.

Implementation
Adding clip_skip into diffusers is both simple and difficult. The main idea of clip_skip is simple; however, since our text encoder is imported from transformers, it is not easy to hack the `CLIPTextModel` in diffusers. To do so, we need to overwrite `CLIPTextModel` and `CLIPTextTransformer`. Here is my implementation:

We can then use the overwritten `clip_text_encoder` in any `_encode_prompt` function of the diffusers pipelines, for example in `pipeline_stable_diffusion.py`.
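The original code snippets from this issue are not preserved above. As a rough illustration of the same idea, and swapping the author's subclassing approach for `output_hidden_states` (which avoids touching the transformers classes), a prompt-encoding helper might look like this; the function name and the re-application of the final layer norm are my assumptions, not the author's code:

```python
import torch

def encode_prompt_with_clip_skip(pipe, prompt, clip_skip=1):
    # Tokenize the prompt with the pipeline's own tokenizer.
    text_inputs = pipe.tokenizer(
        prompt,
        padding="max_length",
        max_length=pipe.tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    input_ids = text_inputs.input_ids.to(pipe.text_encoder.device)

    with torch.no_grad():
        out = pipe.text_encoder(input_ids, output_hidden_states=True)

    if clip_skip <= 1:
        # Default behaviour: features from the last layer (already layer-normed).
        return out.last_hidden_state

    # hidden_states[-1] is the last layer's output, hidden_states[-2] the penultimate, etc.
    hidden = out.hidden_states[-clip_skip]
    # Re-apply the final layer norm, as CLIPTextTransformer normally does for the last layer.
    return pipe.text_encoder.text_model.final_layer_norm(hidden)
```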
Implementing `CLIPTextTransformer` and `CLIPTextModel` to support clip_skip cleanly and nicely is difficult for me. I'd like to leave this issue to the diffusers team.