
Overcoming the 77 token limit in diffusers #2136

Closed
jslegers opened this issue Jan 27, 2023 · 25 comments
Labels
stale Issues that haven't received updates

Comments

@jslegers

jslegers commented Jan 27, 2023

Description of the problem

CLIP has a 77 token limit, which is much too small for many prompts.

Several GUIs have found a way to overcome this limit, but not the diffusers library.

The solution I'd like

I would like diffusers to be able to run longer prompts and overcome the 77 token limit of CLIP for any model, much like AUTOMATIC1111/stable-diffusion-webui already does.

Alternatives I've considered

  • I tried reverse-engineering the prompt interpretation logic from one of the other GUIs out there (not sure which one), but I couldn't find the code responsible.

  • I tried running BAAI/AltDiffusion in diffusers, which uses AltCLIP instead of CLIP. Since AltCLIP has a max_position_embeddings value of 514 for its text encoder instead of 77, I had hoped I could just replace the text encoder and tokenizer of my models with those of BAAI/AltDiffusion to overcome the 77 token limit, but I couldn't get BAAI/AltDiffusion to work in diffusers.

Additional context

This is how AUTOMATIC1111 overcomes the token limit, according to their documentation:

Typing past standard 75 tokens that Stable Diffusion usually accepts increases prompt size limit from 75 to 150. Typing past that increases prompt size further. This is done by breaking the prompt into chunks of 75 tokens, processing each independently using CLIP's Transformers neural network, and then concatenating the result before feeding into the next component of stable diffusion, the Unet.

For example, a prompt with 120 tokens would be separated into two chunks: first with 75 tokens, second with 45. Both would be padded to 75 tokens and extended with start/end tokens to 77. After passing those two chunks through CLIP, we'll have two tensors with shape of (1, 77, 768). Concatenating those results in (1, 154, 768) tensor that is then passed to Unet without issue.
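
For illustration only (this is not AUTOMATIC1111's actual code, and it skips the per-chunk padding and start/end token handling described above), the shape bookkeeping of that chunk-and-concatenate scheme looks roughly like this, using the CLIP text encoder that Stable Diffusion v1 uses:

import torch
from transformers import CLIPTextModel, CLIPTokenizer

# Illustrative sketch of the chunk-and-concatenate idea described above.
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

prompt = "a photo of an astronaut riding a horse on mars, " * 15  # well past 77 tokens
input_ids = tokenizer(prompt, truncation=False, return_tensors="pt").input_ids  # shape (1, N), N > 77

chunk_size = tokenizer.model_max_length  # 77 for CLIP
chunks = [input_ids[:, i:i + chunk_size] for i in range(0, input_ids.shape[-1], chunk_size)]

# Encode each chunk independently, then concatenate along the sequence dimension.
with torch.no_grad():
    embeds = torch.cat([text_encoder(chunk)[0] for chunk in chunks], dim=1)

print(embeds.shape)  # (1, N, 768); e.g. two padded 77-token chunks would give (1, 154, 768)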

@apolinario
Contributor

Hey @jslegers, the Long Prompt Weighting Stable Diffusion community pipeline gets rid of the 77 token limit. You can check it out here
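
For reference, community pipelines like that one are loaded through the custom_pipeline argument of DiffusionPipeline.from_pretrained. A minimal, hedged sketch (the model id is just an example, and max_embeddings_multiples is assumed to be this pipeline's knob for allowing longer prompts):

import torch
from diffusers import DiffusionPipeline

# Sketch: load the Long Prompt Weighting community pipeline via custom_pipeline.
pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    custom_pipeline="lpw_stable_diffusion",
    torch_dtype=torch.float16,
).to("cuda")

long_prompt = "a photo of an astronaut riding a horse on mars, highly detailed, " * 10  # > 77 tokens
image = pipe(
    prompt=long_prompt,
    negative_prompt="low quality, blurry",
    max_embeddings_multiples=3,  # assumed knob of this pipeline: allow up to 3 x 77 tokens
).images[0]
image.save("lpw_long_prompt.png")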

@jslegers
Author

@apolinario :

I have the same question/remark I made at #2135.

Most people aren't going to figure out on their own that there is a dedicated pipeline for getting rid of the 77 token limit. I sure wasn't able to find this info until you provided me a link... and I'm a dev with more than a decade of experience.

It's also not exactly user-friendly to have a dedicated pipeline for what's a pretty important feature that almost every Stable Diffusion user is likely to want (since it doesn't take much to surpass 77 tokens).

So why not just bake support for 77+ tokens into StableDiffusionPipeline?

@patrickvonplaten
Contributor

Hey @jslegers,

It's true that our documentation is currently lagging behind a bit. Would you be interested in contributing a doc page about long prompting?

Also note that I would suggest just using StableDiffusionPipeline and passing the prompt_embeds manually; e.g., the following code snippet works:

from diffusers import StableDiffusionPipeline
import torch

# 1. load model
model_id = "runwayml/stable-diffusion-v1-5"
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
pipe = pipe.to("cuda")

# 2. Forward embeddings and negative embeddings through text encoder
prompt = 25 * "a photo of an astronaut riding a horse on mars"
max_length = pipe.tokenizer.model_max_length

input_ids = pipe.tokenizer(prompt, return_tensors="pt").input_ids
input_ids = input_ids.to("cuda")

negative_ids = pipe.tokenizer("", truncation=False, padding="max_length", max_length=input_ids.shape[-1], return_tensors="pt").input_ids
negative_ids = negative_ids.to("cuda")

concat_embeds = []
neg_embeds = []
for i in range(0, input_ids.shape[-1], max_length):
    concat_embeds.append(pipe.text_encoder(input_ids[:, i: i + max_length])[0])
    neg_embeds.append(pipe.text_encoder(negative_ids[:, i: i + max_length])[0])

prompt_embeds = torch.cat(concat_embeds, dim=1)
negative_prompt_embeds = torch.cat(neg_embeds, dim=1)

# 3. Forward
image = pipe(prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds).images[0]
image.save("astronaut_rides_horse.png")

Could you try out whether this fits your use case? Would you be interested in adding a doc page about long-prompting maybe under: https://github.com/huggingface/diffusers/tree/main/docs/source/en/using-diffusers

@jslegers
Author

Also note that I would suggest to just use StableDiffusionPipeline and pass the prompt_embeds manually, e.g. the following code snippet works:

[...]

Could you try out whether this fits your use case?

It's an interesting approach and definitely more in line with what I'm looking for...

I'll need to try this on my demos and test scripts before I can comment on it further, but it looks promising as an approach for at least personal use...

I'd still argue this is a bit convoluted for something that Stable Diffusion should support out of the box, but I guess that's something RunwayML and StabilityAI should fix (by replacing CLIP with an alternative that supports more tokens) and not something the diffusers library is responsible for.

Would you be interested in adding a doc page about long-prompting maybe under: https://github.com/huggingface/diffusers/tree/main/docs/source/en/using-diffusers

I'll take that into consideration, on the condition that I'm allowed to post that same content on my own blog(s) as well.

I was planning to do some tutorials on how to use Stable Diffusion anyway, so I might as well make some of that content official documentation.

@patrickvonplaten
Contributor

Feel free to use any of the content of diffusers in whatever way you like :-) It's MIT licensed

@jslegers
Author

jslegers commented Feb 1, 2023

@patrickvonplaten

Feel free to use any of the content of diffusers in whatever way you like :-) It's MIT licensed

Good to know...

I wasn't sure that license applied to documentation as well.

I'm not a lawyer, and I prefer to make as few assumptions as possible when it comes to legal matters...

@github-actions

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions github-actions bot added the stale Issues that haven't received updates label Feb 26, 2023
@github-actions github-actions bot closed this as completed Mar 6, 2023
@romanfurman6

[quotes @patrickvonplaten's long-prompt snippet above in full]

Hey, with this example I'm getting the following error: "Token indices sequence length is longer than the specified maximum sequence length for this model (X > 77). Running this sequence through the model will result in indexing errors"
Is that okay?

@patrickvonplaten
Contributor

Hey @romanfurman6,

Could you please open a new issue for it? :-)

@andrevanzuydam

[quotes @patrickvonplaten's long-prompt snippet above in full]

I'm testing your code sample, as I haven't been able to get the custom pipeline lpw_stable_diffusion to work on all of the computers I'm testing on. I'm happy to document my findings. Thanks for the code.

@andrevanzuydam

andrevanzuydam commented Apr 19, 2023

Just in case someone comes across this issue and wants a solution, I built something that correctly handles both prompts, even when they have varying lengths:

from diffusers import StableDiffusionPipeline
import torch
import random

# 1. load model
model_id = "runwayml/stable-diffusion-v1-5"
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
pipe = pipe.to("cuda")


pipe.enable_sequential_cpu_offload() # my graphics card VRAM is very low


def get_pipeline_embeds(pipeline, prompt, negative_prompt, device):
    """ Get pipeline embeds for prompts bigger than the maxlength of the pipe
    :param pipeline:
    :param prompt:
    :param negative_prompt:
    :param device:
    :return:
    """
    max_length = pipeline.tokenizer.model_max_length

    # rough length estimate: count whitespace-separated words (punctuation makes this inexact)
    count_prompt = len(prompt.split(" "))
    count_negative_prompt = len(negative_prompt.split(" "))

    # create the tensor based on which prompt is longer
    if count_prompt >= count_negative_prompt:
        input_ids = pipeline.tokenizer(prompt, return_tensors="pt", truncation=False).input_ids.to(device)
        shape_max_length = input_ids.shape[-1]
        negative_ids = pipeline.tokenizer(negative_prompt, truncation=False, padding="max_length",
                                          max_length=shape_max_length, return_tensors="pt").input_ids.to(device)

    else:
        negative_ids = pipeline.tokenizer(negative_prompt, return_tensors="pt", truncation=False).input_ids.to(device)
        shape_max_length = negative_ids.shape[-1]
        input_ids = pipeline.tokenizer(prompt, return_tensors="pt", truncation=False, padding="max_length",
                                       max_length=shape_max_length).input_ids.to(device)

    concat_embeds = []
    neg_embeds = []
    for i in range(0, shape_max_length, max_length):
        concat_embeds.append(pipeline.text_encoder(input_ids[:, i: i + max_length])[0])
        neg_embeds.append(pipeline.text_encoder(negative_ids[:, i: i + max_length])[0])

    return torch.cat(concat_embeds, dim=1), torch.cat(neg_embeds, dim=1)


prompt = (22 + random.randint(1, 10)) * "a photo of an astronaut riding a horse on mars"
negative_prompt = (22 + random.randint(1, 10)) * "some negative texts"

print("Our inputs ", prompt, negative_prompt, len(prompt.split(" ")), len(negative_prompt.split(" ")))

prompt_embeds, negative_prompt_embeds = get_pipeline_embeds(pipe, prompt, negative_prompt, "cuda")

image = pipe(prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds).images[0]

image.save("done.png")

@andrevanzuydam
Copy link

Just an FYI: the code above does weird things with commas and special characters. I'm not sure if I need to sanitize the prompts with a regex; it probably comes down to word prioritization, etc.
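
One possible contributor, for anyone hitting the same thing: the helper above estimates length with prompt.split(" "), while CLIP's tokenizer generally emits additional tokens for commas and other punctuation, so the word count can undercount the real token length. A quick illustrative check (the tokenizer id is the one SD v1 uses):

from transformers import CLIPTokenizer

# Illustrative check: punctuation pushes the real token count past the naive word count.
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

text = "a red, old, rusty car, parked on mars"
token_count = tokenizer(text, truncation=False, return_tensors="pt").input_ids.shape[-1]
word_count = len(text.split(" "))
print(token_count, word_count)  # the token count exceeds the whitespace word count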

@djj0s3

djj0s3 commented Jun 2, 2023

Bumping this thread back up. The Long Prompt Weighting Stable Diffusion pipeline is great, but it doesn't mix and match well (correct me if I'm wrong) with other default pipelines, like ControlNet for example. I believe the spirit of Diffusers is to be like Legos for working with diffusion models, but relying on a community pipeline for this workaround breaks that pattern a bit. I'd love some help on a standard way to add the best parts of the expanded long prompt weighting pipeline without having to use that pipeline exclusively.

@patrickvonplaten
Contributor

Hey @djj0s3,

Yes, good point. Can you check whether you can solve the same use case you had by using: https://huggingface.co/docs/diffusers/main/en/using-diffusers/weighted_prompts

@djj0s3

djj0s3 commented Jun 6, 2023

Thanks! That worked after I ran into some "I'm an idiot" errors. For anyone else who lands on here, read the Compel docs carefully - particularly this bit if you're using long prompts, or you will run into mismatched tensor issues and be very sad.

compel = Compel(..., truncate_long_prompts=False)
prompt = "a cat playing with a ball++ in the forest, amazing, exquisite, stunning, masterpiece, skilled, powerful, incredible, amazing, trending on gregstation, greg, greggy, greggs greggson, greggy mcgregface, ..." # very long prompt
conditioning = compel.build_conditioning_tensor(prompt)
negative_prompt = "" # it's necessary to create an empty prompt - it can also be very long, if you want
negative_conditioning = compel.build_conditioning_tensor(negative_prompt)
[conditioning, negative_conditioning] = compel.pad_conditioning_tensors_to_same_length([conditioning, negative_conditioning])
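
For anyone who wants a self-contained version, here's a hedged sketch that wires those same Compel calls into StableDiffusionPipeline (the model id and prompts are just examples):

import torch
from compel import Compel
from diffusers import StableDiffusionPipeline

# Sketch: long prompts via Compel, using the same calls shown above.
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

compel = Compel(tokenizer=pipe.tokenizer, text_encoder=pipe.text_encoder,
                truncate_long_prompts=False)

prompt = "a cat playing with a ball++ in the forest, " + "amazing, exquisite, stunning, " * 15
negative_prompt = ""  # an empty negative prompt still needs its own conditioning tensor

conditioning = compel.build_conditioning_tensor(prompt)
negative_conditioning = compel.build_conditioning_tensor(negative_prompt)
[conditioning, negative_conditioning] = compel.pad_conditioning_tensors_to_same_length(
    [conditioning, negative_conditioning]
)

image = pipe(prompt_embeds=conditioning,
             negative_prompt_embeds=negative_conditioning).images[0]
image.save("compel_long_prompt.png")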

@DevonPeroutky

DevonPeroutky commented Jun 28, 2023

@patrickvonplaten Potentially stupid question, but if you directly pass the [neg]/prompt_embeddings into the pipeline, does that mean there's no longer an attention mask being used?

If so, could this cause issues with padding tokens (necessary to make the prompt and negative_prompt the same length), as they would not be ignored?

Thank you, and your team, for all the hard work btw.

@hckj588ku

[quotes @patrickvonplaten's long-prompt snippet above in full]

How can I use this with StableDiffusionXLPipeline?

@ManjuVajra

ManjuVajra commented Jul 25, 2023

This is still an issue with diffusers' use of CLIP in general. There's no feedback on whether that code snippet currently works; I'll test it myself. Either way, it is still an issue.

@patrickvonplaten
Contributor

I don't think this is an issue. If you want to overcome the 77 token limit, I highly recommend using the compel library: https://github.com/damian0815/compel#compel

@Atlas3DSS

[quotes @djj0s3's Compel snippet above in full]

My brother in Christ, I thank you for this! I was struggling and you saved me ;D

@o5faruk

o5faruk commented Nov 28, 2023

self.txt2img_pipe.load_textual_inversion(
        EMBEDDING_PATHS, token=EMBEDDING_TOKENS, local_files_only=True
)

textual_inversion_manager = DiffusersTextualInversionManager(self.txt2img_pipe)


self.compel_proc = Compel(
    tokenizer=self.txt2img_pipe.tokenizer,
    text_encoder=self.txt2img_pipe.text_encoder,
    textual_inversion_manager=textual_inversion_manager,
    truncate_long_prompts=False,
)
if prompt:
    conditioning = self.compel_proc.build_conditioning_tensor(prompt)
    if not negative_prompt:
        negative_prompt = ""  # it's necessary to create an empty prompt - it can also be very long, if you want
    negative_conditioning = self.compel_proc.build_conditioning_tensor(
        negative_prompt
    )
    [
        prompt_embeds,
        negative_prompt_embeds,
    ] = self.compel_proc.pad_conditioning_tensors_to_same_length(
        [conditioning, negative_conditioning]
    )
    ...
    output = pipe(
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        guidance_scale=guidance_scale,
        generator=generator,
        num_inference_steps=num_inference_steps,
        **extra_kwargs,
    )

I'm having weird issues. All the relevant code is shown above; however, the negative_prompt messes up my image results, almost as if the negatives are getting mixed up with the positives.
Also, this happens only if the prompt and negative prompt lengths exceed 77 tokens.
extra_kwargs does not contain prompt or negative_prompt, so only the embeds are passed into the pipeline. The pipeline in this case is ControlNet text-to-image.

Is it possible that the negatives get mixed into the positives in the pad_conditioning_tensors_to_same_length function?

This is my image with the long negative prompt: [image]

And this is the same seed, same prompt, no negative prompt: [image]

@lusp75

lusp75 commented Jan 13, 2024

get_pipeline_embeds

Which file should I modify? A .py file? Is it a file I need to add? Thank you!

@andrevanzuydam

get_pipeline_embeds

Which file should I modify? A .py file? Is it a file I need to add? Thank you!

Hi @lusp75, if you look at my example way above, I just defined a function and used it in my own code; it's dangerous to hack maintained libraries.

@HamenderSingh

[quotes @andrevanzuydam's get_pipeline_embeds snippet above in full]

There is a bug that causes an error when the prompt is longer than the negative prompt.
I've fixed it below.

def get_pipeline_embeds(pipeline, prompt, negative_prompt, device):
    """ Get pipeline embeds for prompts bigger than the maxlength of the pipe
    :param pipeline:
    :param prompt:
    :param negative_prompt:
    :param device:
    :return:
    """
    max_length = pipeline.tokenizer.model_max_length


    # tokenize both prompts without truncation to get their actual token lengths
    input_ids = pipeline.tokenizer(prompt, return_tensors="pt", truncation=False).input_ids.to(device)
    negative_ids = pipeline.tokenizer(negative_prompt, return_tensors="pt", truncation=False).input_ids.to(device)

    # create the tensor based on which prompt is longer
    if input_ids.shape[-1] >= negative_ids.shape[-1]:
        shape_max_length = input_ids.shape[-1]
        negative_ids = pipeline.tokenizer(negative_prompt, truncation=False, padding="max_length",
                                          max_length=shape_max_length, return_tensors="pt").input_ids.to(device)

    else:
        shape_max_length = negative_ids.shape[-1]
        input_ids = pipeline.tokenizer(prompt, return_tensors="pt", truncation=False, padding="max_length",
                                       max_length=shape_max_length).input_ids.to(device)

    concat_embeds = []
    neg_embeds = []
    for i in range(0, shape_max_length, max_length):
        concat_embeds.append(pipeline.text_encoder(input_ids[:, i: i + max_length])[0])
        neg_embeds.append(pipeline.text_encoder(negative_ids[:, i: i + max_length])[0])

    return torch.cat(concat_embeds, dim=1), torch.cat(neg_embeds, dim=1)
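
A quick usage sketch of the fixed helper (this assumes pipe and get_pipeline_embeds are already defined as in the snippets above; the prompts are only examples):

# Usage sketch for the fixed get_pipeline_embeds above (assumes pipe is already loaded as in
# the earlier snippet; prompts are only illustrative).
prompt = "a photo of an astronaut riding a horse on mars, highly detailed, 8k, sharp focus, " * 8
negative_prompt = "blurry, low quality, deformed"

prompt_embeds, negative_prompt_embeds = get_pipeline_embeds(pipe, prompt, negative_prompt, "cuda")
print(prompt_embeds.shape, negative_prompt_embeds.shape)  # both share the same sequence length

image = pipe(prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds).images[0]
image.save("long_prompt_fixed.png")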

natsunoyuki added a commit to natsunoyuki/diffuser_tools that referenced this issue Mar 9, 2024
The original prompt embeddings will cause the pipeline to crash in the case the negative prompt was longer than the prompt. Implemented a fix suggested in huggingface/diffusers#2136.
@zhentingqi

Hi, so is there any other solution to this problem: how can we prompt diffusion models with more than 77 tokens?

I see the following code snippet, but it seems that it just splits the text into chunks and encodes them one by one, instead of encoding the entire text sequence?

[quotes @andrevanzuydam's get_pipeline_embeds snippet above in full]

By the way, I hope someone could also help me with the same token limit problem for CLIP models in general. Is there any long-context image-text model, or do I have to fine-tune a long-context CLIP-like model on my own? Thanks!
