[SD-XL] Passing prompt_embeds/negative_prompt_embeds requires also passing pooled_prompt_embeds/negative_pooled_prompt_embeds #4043

Closed
m-a-r-v-i-n opened this issue Jul 11, 2023 · 9 comments
Labels
bug Something isn't working

Comments

@m-a-r-v-i-n

Describe the bug

When calling the StableDiffusionXLPipeline and passing the prompts as embeddings via the prompt_embeds/negative_prompt_embeds parameters, an error is raised stating that the pooled_prompt_embeds/negative_pooled_prompt_embeds parameters must also be passed.
There seems to be no documentation of what these parameters should contain (i.e. how the embeddings should be pooled). The documentation of the encode_prompt function states the following:

pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
                If not provided, pooled text embeddings will be generated from `prompt` input argument. 

negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
                input argument.

However, the check_inputs function raises an error when the pooled embeddings are not passed, and I couldn't find any code that generates them automatically:

if prompt_embeds is not None and pooled_prompt_embeds is None:
    raise ValueError(
        "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
    )

if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
    raise ValueError(
        "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
    )

Reproduction

import torch
from diffusers import DiffusionPipeline
from compel import Compel

pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-0.9", torch_dtype=torch.float16, use_safetensors=True, variant="fp16"
)
pipe.to("cuda")
compel = Compel(tokenizer=pipe.tokenizer, text_encoder=pipe.text_encoder, truncate_long_prompts=False)

prompt = "Testing - this is just some dummy text to emulate a given prompt input which exceeds the token limit of CLIP, which is a fixed 77. For this reason I am trying to use the Compel library to circumvent the prompt truncation."  # very long prompt
negative_prompt = ""  # a negative prompt is required, even if empty

conditioning = compel.build_conditioning_tensor(prompt)
negative_conditioning = compel.build_conditioning_tensor(negative_prompt)
[conditioning, negative_conditioning] = compel.pad_conditioning_tensors_to_same_length([conditioning, negative_conditioning])

unrefined_img = pipe(prompt_embeds=conditioning, negative_prompt_embeds=negative_conditioning, output_type="latent").images

Logs

No response

System Info

  • diffusers version: 0.18.1
  • Platform: Linux-5.15.0-69-generic-x86_64-with-glibc2.31
  • Python version: 3.10.12
  • PyTorch version (GPU?): 2.0.1+cu118 (True)
  • Huggingface_hub version: 0.16.4
  • Transformers version: 4.30.2
  • Accelerate version: 0.20.3
  • xFormers version: not installed
  • Using GPU in script?: Yes (2xGPUs)

Who can help?

@patrickvonplaten
@sayakpaul
@williamberman

@m-a-r-v-i-n added the bug (Something isn't working) label Jul 11, 2023
@sayakpaul
Member

Thanks for reporting.

The encode_prompt() utility shows how to compute the pooled embeddings.

Also, could you please try to reproduce the issue without introducing an external dependency? That would be helpful.

Cc: @patrickvonplaten

@m-a-r-v-i-n
Author

m-a-r-v-i-n commented Jul 12, 2023

Thanks for your reply @sayakpaul.

Re the external dependency, I just added it here because it easily generates the prompt embeddings without lengthier code (which I thought was beyond the scope of the issue at hand). In fact, you can pass a dummy tensor instead of the prompt embeddings and the pooled embeds will still be requested.
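
For example, a minimal sketch without Compel (untested; the dummy tensors exist only to trigger the input validation, and the 2048 hidden size assumes SD-XL's concatenated text-encoder hidden states):

import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-0.9", torch_dtype=torch.float16, variant="fp16"
).to("cuda")

# Dummy embeddings with the expected (batch, sequence, hidden) shape.
dummy_embeds = torch.randn(1, 77, 2048, dtype=torch.float16, device="cuda")
dummy_negative = torch.randn(1, 77, 2048, dtype=torch.float16, device="cuda")

# check_inputs raises: "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. ..."
image = pipe(prompt_embeds=dummy_embeds, negative_prompt_embeds=dummy_negative).images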

Re the encode_prompt() function - you mean that I can call this function directly and pass it the prompt embeddings to generate the pooled embeddings? If so, wouldn't it make sense for check_inputs() to do that automatically when the pooled embeddings are not passed?

@sayakpaul
Member

If so, wouldn't it make sense for check_inputs() to do that automatically when the pooled embeddings are not passed?

I think we would want to raise an error here rather than silently doing that. This allows the user to take complete control of what they pass as the embeddings.

@sayakpaul
Member

Hi @m-a-r-v-i-n, I prepared a Colab Notebook here that shows how to precompute the embeddings and use them in the pipeline call.

This is what I did:

  1. Loaded the pipeline:
from diffusers import DiffusionPipeline
import torch

ckpt_id = "stabilityai/stable-diffusion-xl-base-0.9"
dtype = torch.float16
device = "cuda"

pipeline = DiffusionPipeline.from_pretrained(
    ckpt_id, torch_dtype=dtype, safety_checker=None
).to(device)
  2. Computed the embeddings with the encode_prompt() method:
prompt = "a lion posing in the style of a cat"
num_images_per_prompt = 1

(
    prompt_embeds,
    negative_prompt_embeds,
    pooled_prompt_embeds,
    negative_pooled_prompt_embeds,
) = pipeline.encode_prompt(prompt, device, num_images_per_prompt=num_images_per_prompt)
  3. And then used them in the pipeline call:
image = pipeline(
    prompt_embeds=prompt_embeds,
    pooled_prompt_embeds=pooled_prompt_embeds,
    negative_prompt_embeds=negative_prompt_embeds,
    negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
    num_inference_steps=20,
).images[0]
image

I hope this helps.

@m-a-r-v-i-n
Author

This works, thanks a lot @sayakpaul

@sayakpaul
Member

Closing the issue then :-)

@Rmond

Rmond commented Aug 4, 2023

I think this issue has not been resolved. Compel is meant to work around the 77-token limit, and if I use encode_prompt it also throws an error:
The following part of your input was truncated because CLIP can only handle sequences up to 77 tokens:
like #2136

@sayakpaul
Member

It's not an error, it's a warning.

@sayakpaul reopened this Aug 4, 2023
@Rmond

Rmond commented Aug 4, 2023

It's not an error, it's a warning.
All right, for the program it is just a warning and it keeps working, but the result is not what I wanted. So how can I avoid this problem? Thanks!
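
One possible workaround, sketched from Compel's documented SD-XL support (assumes a Compel release with SD-XL support, e.g. 2.x; not verified here): let Compel use both SD-XL tokenizers/text encoders and also return the pooled embeddings, so that all four tensors expected by check_inputs can be passed to the pipeline.

# Sketch only: assumes a Compel release with SD-XL support.
import torch
from diffusers import DiffusionPipeline
from compel import Compel, ReturnedEmbeddingsType

pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-0.9", torch_dtype=torch.float16, variant="fp16"
).to("cuda")

# SD-XL has two text encoders; the pooled output comes from the second one.
compel = Compel(
    tokenizer=[pipe.tokenizer, pipe.tokenizer_2],
    text_encoder=[pipe.text_encoder, pipe.text_encoder_2],
    returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
    requires_pooled=[False, True],
    truncate_long_prompts=False,
)

conditioning, pooled = compel("a very long prompt that exceeds the 77-token limit ...")
negative_conditioning, negative_pooled = compel("")
[conditioning, negative_conditioning] = compel.pad_conditioning_tensors_to_same_length(
    [conditioning, negative_conditioning]
)

image = pipe(
    prompt_embeds=conditioning,
    pooled_prompt_embeds=pooled,
    negative_prompt_embeds=negative_conditioning,
    negative_pooled_prompt_embeds=negative_pooled,
).images[0]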
