[Core] Fix/pipeline without text encoders for SDXL #5301
Conversation
|
conversely, pipeline fails when using text embeds, and text_encoder (and _2) are unavailable |
Could you be a bit more specific here? Happy to try to fix it here. Edit: Ah I know what you mean. |
|
The documentation is not available anymore as the PR was closed or merged. |
|
Huh!

from diffusers import StableDiffusionXLPipeline
import torch

prompt = "hey"

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    unet=None,
    torch_dtype=torch.float16
)
pipe = pipe.to("cuda")

(
    prompt_embeds,
    negative_prompt_embeds,
    pooled_prompt_embeds,
    negative_pooled_prompt_embeds
) = pipe.encode_prompt(prompt)

del pipe

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    text_encoder=None,
    text_encoder_2=None,
    tokenizer=None,
    tokenizer_2=None,
    torch_dtype=torch.float16
)
pipe = pipe.to("cuda")

call_args = dict(
    prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds,
    negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, num_images_per_prompt=1,
    num_inference_steps=2
)
_ = pipe(**call_args)

Does not work:

/usr/local/lib/python3.10/dist-packages/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py in __call__(self, prompt, prompt_2, height, width, num_inference_steps, denoising_end, guidance_scale, negative_prompt, negative_prompt_2, num_images_per_prompt, eta, generator, latents, prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds, output_type, return_dict, callback, callback_steps, cross_attention_kwargs, guidance_rescale, original_size, crops_coords_top_left, target_size, negative_original_size, negative_crops_coords_top_left, negative_target_size, clip_skip)
851 # 7. Prepare added time ids & embeddings
852 add_text_embeds = pooled_prompt_embeds
--> 853 add_time_ids = self._get_add_time_ids(
854 original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype
855 )
/usr/local/lib/python3.10/dist-packages/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py in _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, dtype)
544
545 passed_add_embed_dim = (
--> 546 self.unet.config.addition_time_embed_dim * len(add_time_ids) + self.text_encoder_2.config.projection_dim
547 )
548 expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
AttributeError: 'NoneType' object has no attribute 'config'

Looking into it. |
|
@patrickvonplaten let me know if the proposed solution works for you. If so, I will go ahead and propagate the changes along with adding a test. |
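For context, the failure above happens because _get_add_time_ids reads self.text_encoder_2.config.projection_dim, which no longer exists once the text encoders are dropped. A minimal sketch of the kind of fallback the fix needs is below; the helper name is hypothetical, and the actual diff may instead thread the value through _get_add_time_ids directly:

def resolve_text_encoder_projection_dim(text_encoder_2, pooled_prompt_embeds):
    # Hypothetical helper, not the exact PR diff: when text_encoder_2 has been
    # dropped from the pipeline, recover the projection dim from the width of
    # the pooled embeddings the caller supplied.
    if text_encoder_2 is not None:
        return text_encoder_2.config.projection_dim
    return int(pooled_prompt_embeds.shape[-1])

With something along these lines, the projection-dim lookup no longer dereferences a None text encoder.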
|
What is this needed for? I don't think one should be allowed to create a StableDiffusionPipeline instance without a unet; this doesn't make much sense. If one wants to just encode the inputs, we should probably follow @williamberman's suggestion here. Overall, though, I'm curious to know when a pipeline without a unet would be needed? |
|
Much like how we allow the IF pipeline to be loaded with unet=None. Additionally, the example provided in #5301 (comment) should already help explain the use case. |
|
In SimpleTuner, we have to keep the text encoder "loaded" just so that the checks in the SDXL pipeline do not fail during validation, despite passing in the negative/positive embeds/conditionings. This would acceptably solve the problem there, instead of consuming system memory to keep the model loaded. I'm sure there's another way I could nuke the weights from orbit, but this seems like it's in alignment with Kandinsky and DeepFloyd and other pipelines that we/I support. |
|
We don't allow loading IF without a unet. We allow loading IF without a text encoder, as the text encoder is in the optional components list here:
We should not allow SDXL to be loaded without a unet. The unet is the heart of SDXL and it makes no sense to load a pipeline without it (the pipeline could then not be used at all). We can allow loading SDXL without text encoders (if we believe that this is not an edge case), which would mean we should do something along the lines of the sketch below:
Also cc @williamberman here |
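As a rough illustration of what putting the text encoders on the optional-components list means (a sketch only; the class below is a stand-in, not the actual SDXL pipeline definition):

from diffusers import DiffusionPipeline

class MySDXLLikePipeline(DiffusionPipeline):
    # Modules listed here are allowed to be None when the pipeline is loaded
    # with from_pretrained() or saved with save_pretrained().
    _optional_components = ["tokenizer", "tokenizer_2", "text_encoder", "text_encoder_2"]

    def __init__(self, vae, unet, scheduler, text_encoder=None, text_encoder_2=None, tokenizer=None, tokenizer_2=None):
        super().__init__()
        self.register_modules(
            vae=vae,
            unet=unet,
            scheduler=scheduler,
            text_encoder=text_encoder,
            text_encoder_2=text_encoder_2,
            tokenizer=tokenizer,
            tokenizer_2=tokenizer_2,
        )

The unet, vae, and scheduler stay required, which matches the point above that a pipeline without a unet cannot be used at all.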
From the blog post:

from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "DeepFloyd/IF-I-XL-v1.0",
    text_encoder=text_encoder,  # pass the previously instantiated 8bit text encoder
    unet=None,
    device_map="auto"
)

I don't follow your reasoning or the design you are proposing much, in particular the "edge case" part. If I am missing something, please elaborate a bit more on the APIs you're envisioning. Why shouldn't a user be allowed to precompute the text embeddings (both pooled and non-pooled) with the text encoders (without loading the UNet) and then reuse them when calling the SDXL pipeline without loading the text encoders? This saves memory, if I am not mistaken. The flow I have in mind for the users is exactly the one I showed in #5301 (comment). Are you essentially saying we should instead do the following?

from diffusers import StableDiffusionXLPipeline
import torch

prompt = "hey"

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16
)
pipe = pipe.to("cuda")

(
    prompt_embeds,
    negative_prompt_embeds,
    pooled_prompt_embeds,
    negative_pooled_prompt_embeds
) = pipe.encode_prompt_class_method(prompt)

del pipe.text_encoder, pipe.text_encoder_2, pipe.tokenizer, pipe.tokenizer_2

call_args = dict(
    prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds,
    negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, num_images_per_prompt=1,
    num_inference_steps=2
)
_ = pipe(**call_args)

I am okay with it, but this would require another library-wide refactoring / addition of a class-based encode_prompt(). |
|
@patrickvonplaten I thought about it a bit and here's the trade-off (developer-experience-wise) I am okay with. Note that the use case we're trying to target here is to allow people to load a pipeline without the text encoders for memory savings. I have updated the PR title and the description accordingly. So, the flow now becomes:

1. Load the text encoders and tokenizers on their own and compute the prompt embeddings.
2. Delete the text encoders and tokenizers to free memory.
3. Load the SDXL pipeline without the text encoders/tokenizers and call it with the precomputed embeddings.
In code, it looks something like this. First, load up the text encoders along with their tokenizers:

import torch
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer

pipe_id = "stabilityai/stable-diffusion-xl-base-1.0"
torch_dtype = torch.float16

# load the text encoders and tokenizers
text_encoder = CLIPTextModel.from_pretrained(pipe_id, subfolder="text_encoder", torch_dtype=torch_dtype).to("cuda")
tokenizer = CLIPTokenizer.from_pretrained(pipe_id, subfolder="tokenizer")
text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(pipe_id, subfolder="text_encoder_2", torch_dtype=torch_dtype).to("cuda")
tokenizer_2 = CLIPTokenizer.from_pretrained(pipe_id, subfolder="tokenizer_2")

Then repurpose the pipeline's encode_prompt() logic as a standalone function:

def encode_prompt(tokenizers, text_encoders, prompt: str, negative_prompt: str = None, num_images_per_prompt: int = 1):
    device = text_encoders[0].device

    if isinstance(prompt, str):
        prompt = [prompt]
    batch_size = len(prompt)

    prompt_embeds_list = []
    for tokenizer, text_encoder in zip(tokenizers, text_encoders):
        text_inputs = tokenizer(
            prompt,
            padding="max_length",
            max_length=tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids

        prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
        pooled_prompt_embeds = prompt_embeds[0]
        prompt_embeds = prompt_embeds.hidden_states[-2]
        prompt_embeds_list.append(prompt_embeds)

    prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)

    if negative_prompt is None:
        negative_prompt_embeds = torch.zeros_like(prompt_embeds)
        negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
    else:
        negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt

        negative_prompt_embeds_list = []
        for tokenizer, text_encoder in zip(tokenizers, text_encoders):
            uncond_input = tokenizer(
                negative_prompt,
                padding="max_length",
                max_length=tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )

            negative_prompt_embeds = text_encoder(uncond_input.input_ids.to(device), output_hidden_states=True)
            negative_pooled_prompt_embeds = negative_prompt_embeds[0]
            negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
            negative_prompt_embeds_list.append(negative_prompt_embeds)

        negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)

    bs_embed, seq_len, _ = prompt_embeds.shape
    # duplicate text embeddings for each generation per prompt, using mps friendly method
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

    # for classifier-free guidance
    # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
    seq_len = negative_prompt_embeds.shape[1]
    negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
    negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

    pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
        bs_embed * num_images_per_prompt, -1
    )
    # for classifier-free guidance
    negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
        bs_embed * num_images_per_prompt, -1
    )
    return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds

tokenizers = [tokenizer, tokenizer_2]
text_encoders = [text_encoder, text_encoder_2]
prompt = "hey"  # example prompt to encode

(
    prompt_embeds,
    negative_prompt_embeds,
    pooled_prompt_embeds,
    negative_pooled_prompt_embeds
) = encode_prompt(tokenizers, text_encoders, prompt)

Delete the text encoders:

del text_encoder, text_encoder_2, tokenizer, tokenizer_2

Then do the pipeline call:

from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    pipe_id,
    text_encoder=None,
    text_encoder_2=None,
    tokenizer=None,
    tokenizer_2=None,
    torch_dtype=torch.float16,
)
pipe = pipe.to("cuda")

_ = pipe(
    prompt=None,
    prompt_embeds=prompt_embeds,
    negative_prompt_embeds=negative_prompt_embeds,
    pooled_prompt_embeds=pooled_prompt_embeds,
    negative_pooled_prompt_embeds=negative_pooled_prompt_embeds
)

This experience isn't too bad IMO, as we're still able to leverage parts of the library as per our needs. This also doesn't require us to maintain a separate class-based encode_prompt(). WDYT? |
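One practical note on the del step above: dropping the Python references alone does not immediately hand GPU memory back; the usual follow-up (standard PyTorch calls, unrelated to this PR) is:

import gc

import torch

# Run right after the del statement: collect the now-unreferenced encoder
# modules and release PyTorch's cached CUDA blocks so the UNet and VAE of the
# second pipeline have room on the GPU.
gc.collect()
torch.cuda.empty_cache()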
|
Ok for me to allow loading the unet without the text encoders! Let's make sure though that in this case:
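For illustration, one guard this likely implies is a clear error when a raw prompt is passed to a pipeline that was loaded without its text encoders; a hedged sketch (the function name and error wording are hypothetical, not the PR's actual check_inputs changes):

def validate_embeds_only_call(prompt, prompt_embeds, pooled_prompt_embeds):
    # Hypothetical validation for a pipeline loaded with text_encoder=None:
    # the caller must supply precomputed embeddings instead of a raw prompt.
    if prompt is not None:
        raise ValueError(
            "The pipeline was loaded without text encoders/tokenizers, so the prompt cannot be "
            "encoded. Pass prompt_embeds and pooled_prompt_embeds computed beforehand instead."
        )
    if prompt_embeds is None or pooled_prompt_embeds is None:
        raise ValueError(
            "Without text encoders, both prompt_embeds and pooled_prompt_embeds must be provided."
        )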
|
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
PipelineLatentTesterMixin,
PipelineKarrasSchedulerTesterMixin,
PipelineTesterMixin,
SDXLOptionalComponentsTesterMixin,
New tester mixin class to consolidate the testing of the optional components in SDXL and its derivative pipelines.
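Roughly, the behaviour such a mixin consolidates looks like this (an illustrative sketch against the full checkpoint; the real SDXLOptionalComponentsTesterMixin in the diffusers test suite uses tiny dummy components and runs more thorough checks):

import tempfile

import torch
from diffusers import StableDiffusionXLPipeline

# A pipeline loaded without its text encoders/tokenizers should save and reload
# cleanly, with those optional components still set to None afterwards.
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    text_encoder=None,
    text_encoder_2=None,
    tokenizer=None,
    tokenizer_2=None,
    torch_dtype=torch.float16,
)

with tempfile.TemporaryDirectory() as tmpdir:
    pipe.save_pretrained(tmpdir, safe_serialization=True)
    reloaded = StableDiffusionXLPipeline.from_pretrained(tmpdir, torch_dtype=torch.float16)

assert reloaded.text_encoder is None and reloaded.tokenizer_2 is None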
patrickvonplaten left a comment
Left some nits, but apart from this, the PR looks good!
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
|
wow, thank you so much @sayakpaul this was a complicated one. i was going to solve it last month, but the amount of earth-moving you've done to make it happen is exactly why i couldn't get it there. good work! |
|
@patrickvonplaten I had to do 38e16f8. With the nits you suggested, the pipeline couldn't be instantiated properly. |
import torch.nn as nn
import torch.nn.functional as F

from diffusers.utils import deprecate
Had to be imported. @patrickvonplaten just as an FYI.
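For reference, the helper is typically invoked like this (the names and version below are placeholders, not taken from this PR):

from diffusers.utils import deprecate

# Standard usage: (deprecated argument name, version in which it will be removed, message).
deprecate(
    "some_old_kwarg",
    "1.0.0",
    "Passing some_old_kwarg is deprecated and will be removed; use new_kwarg instead.",
)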
* fix: sdxl pipeline when unet is not available.
* fix moe
* account for text
* ifx more
* don't make unet optional.
* Apply suggestions from code review Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
* split conditionals.
* add optional components to sdxl pipeline
* propagate changes to the rest of the pipelines.
* add: test
* add to all
* fix: rest of the pipelines.
* use pipeline_class variable
* separate pipeline mixin
* use safe_serialization
* fix: test
* access actual output.
* add: optional test to adapter and ip2p sdxl pipeline tests/
* add optional test to controlnet sdxl.
* fix tests
* fix ip2p tests
* fix more
* fifx more.
* use np output type.
* fix for StableDiffusionXLMultiControlNetPipelineFastTests.
* fix: SDXLOptionalComponentsTesterMixin
* Apply suggestions from code review Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
* fix tests
* Empty-Commit
* revert previous
* quality
* fix: test

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This PR fixes the SDXL pipeline so that users can:

- precompute the prompt embeddings with the text encoders and tokenizers loaded on their own, and
- load and run the pipeline with text_encoder=None, text_encoder_2=None, tokenizer=None, and tokenizer_2=None, passing the precomputed embeddings instead, to save memory.
Please follow #5301 (comment) for a full-fledged example.