Skip to content

How to handle SDXL long prompt #4716

@elcolie

Description

@elcolie

Describe the bug

I am unable to use prompt embeddings (`prompt_embeds`) with SDXL to handle a prompt that is longer than 77 tokens.

Reproduction

import errno
import itertools
import os.path
import random
import string
import time
import typing as typ

import torch
from diffusers import StableDiffusionXLPipeline
from tqdm import tqdm

import bb
from web_sdxl import seed_everything

# Runs before any prompt shuffling or image generation below. seed_everything is
# a helper from the project-local web_sdxl module (not shown here); presumably it
# seeds the random number generators with the fixed value 42 so runs are
# repeatable - TODO confirm which RNGs it actually covers.
seed_everything(42)


def generate_random_string(length):
    """Return a string of ``length`` randomly chosen ASCII letters.

    Characters are drawn one at a time with ``random.choice`` from
    ``string.ascii_letters`` (upper- and lower-case), so the result depends on
    the state of the global ``random`` generator.

    :param length: number of characters to produce
    :return: the generated string (empty when ``length`` is 0)
    """
    alphabet = string.ascii_letters
    picked = [random.choice(alphabet) for _ in range(length)]
    return "".join(picked)


def get_pipeline_embeds(pipeline, prompt, negative_prompt, device):
    """ Get pipeline embeds for prompts longer than the tokenizer's max length.

    Both prompts are tokenized without truncation, the shorter one is padded
    to the token length of the longer one, and the ids are then fed to
    ``pipeline.text_encoder`` in windows of ``tokenizer.model_max_length``
    tokens. The per-window outputs are concatenated along dimension 1.

    Only ``pipeline.tokenizer`` / ``pipeline.text_encoder`` are used and only
    the first element of each encoder output is kept, so no pooled embedding
    is produced here.

    :param pipeline: object exposing ``tokenizer`` (callable, with a
        ``model_max_length`` attribute) and ``text_encoder`` (callable)
    :param prompt: positive prompt text
    :param negative_prompt: negative prompt text
    :param device: device the token ids are moved to before encoding
    :return: tuple ``(prompt_embeds, negative_prompt_embeds)``; both are built
        from the same number of tokens
    """
    tokenizer = pipeline.tokenizer
    max_length = tokenizer.model_max_length

    # Measure the real token lengths. Counting space-separated words (as a
    # cheap proxy) can pick the wrong "longer" prompt, which leaves the two id
    # tensors with different lengths and the returned embeddings mismatched.
    input_ids = tokenizer(prompt, return_tensors="pt", truncation=False).input_ids
    negative_ids = tokenizer(negative_prompt, return_tensors="pt", truncation=False).input_ids
    shape_max_length = max(input_ids.shape[-1], negative_ids.shape[-1])

    # Re-tokenize only the shorter prompt, padded up to the longer one.
    if input_ids.shape[-1] < shape_max_length:
        input_ids = tokenizer(prompt, return_tensors="pt", truncation=False, padding="max_length",
                              max_length=shape_max_length).input_ids
    elif negative_ids.shape[-1] < shape_max_length:
        negative_ids = tokenizer(negative_prompt, return_tensors="pt", truncation=False,
                                 padding="max_length", max_length=shape_max_length).input_ids

    input_ids = input_ids.to(device)
    negative_ids = negative_ids.to(device)

    concat_embeds = []
    neg_embeds = []
    # Encode window by window; the last window may be shorter than max_length.
    for start in range(0, shape_max_length, max_length):
        stop = start + max_length
        concat_embeds.append(pipeline.text_encoder(input_ids[:, start:stop])[0])
        neg_embeds.append(pipeline.text_encoder(negative_ids[:, start:stop])[0])

    return torch.cat(concat_embeds, dim=1), torch.cat(neg_embeds, dim=1)


# --- Batch generation script -------------------------------------------------
# Builds every combination of the prompt fragments below, shuffles them, and
# renders three 768x1024 images per combination into `out_dir`, skipping
# combinations whose first image already exists on disk.

model_path = "fine_tuned_models/sdxl-sarit"
device = "mps" if torch.backends.mps.is_available() else "cpu"
out_dir: str = "gluta40"

age_prompts: typ.List[str] = [
    "young asian girl",
    "a photograph of an angel with sly expression, wearing a see-thru short roman style dress, beautiful asian mixed european woman face, beautiful eyes, black hair, looking down, hyper realistic and detailed, 16k",
]
hand_prompts: typ.List[str] = [
    "left hand holding a gluta40 jar one hand, right hand is behind her back",
    "right hand holding a gluta40 jar one hand, left hand is behind her back",
]
face_angle_prompts: typ.List[str] = [
    "straight face",
]
hair_prompts: typ.List[str] = [
    "black long tied hair",
    "black long hair",
]
background_prompts: typ.List[str] = [
    "no background, hold both hands, bad hands",
]
negative_prompt: str = "disfigured, disproportionate, bad anatomy, bad proportions, ugly, out of frame, mangled, asymmetric, cross-eyed, depressed, immature, stuffed animal, out of focus, high depth of field, cloned face, cloned head, age spot, skin blemishes, collapsed eyeshadow, asymmetric ears, imperfect eyes, unnatural, conjoined, missing limb, missing arm, missing leg, poorly drawn face, poorly drawn feet, poorly drawn hands, floating limb, disconnected limb, extra limb, malformed limbs, malformed hands, poorly rendered face, poor facial details, poorly rendered hands, double face, unbalanced body, unnatural body, lacking body, long body, cripple, cartoon, 3D, weird colors, unnatural skin tone, unnatural skin, stiff face, fused hand, skewed eyes, surreal, cropped head, group of people, too many fingers, bad hands, six fingers"
combined_list = list(itertools.product(age_prompts, hand_prompts, face_angle_prompts, hair_prompts, background_prompts))
random.shuffle(combined_list)

# Create the output directory once, up front; exist_ok removes the
# check-then-create sequence that used to run on every iteration.
os.makedirs(out_dir, exist_ok=True)

# The pipeline is loaded lazily and only once. Reloading the weights from disk
# for every prompt combination is pure overhead, and nothing is loaded at all
# when every output already exists.
pipe = None

for item in tqdm(combined_list, total=len(combined_list)):
    prompt: str = ", ".join(item)
    print(prompt)
    out_filename: str = f"{out_dir}/{prompt.replace(' ', '_')}"
    if os.path.exists(f"{out_filename}_0.png"):
        continue  # already rendered on a previous run

    if pipe is None:
        pipe = StableDiffusionXLPipeline.from_pretrained(model_path, safety_checker=None,
                                                         requires_safety_checker=False)
        pipe.to(device)

    prompt_embeds, negative_prompt_embeds = get_pipeline_embeds(pipe, prompt, negative_prompt, device)
    # NOTE(review): with diffusers 0.20.0 this call raises
    # "ValueError: If `prompt_embeds` are provided, `pooled_prompt_embeds` also
    # have to be passed." get_pipeline_embeds only returns the two token-level
    # tensors, so SDXL still needs pooled embeddings generated from the matching
    # text encoder before this call can succeed.
    images = pipe(prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds,
                  num_images_per_prompt=3, width=768,
                  height=1024).images

    stem: str = out_filename
    for idx, image in enumerate(images):
        try:
            image.save(f"{stem}_{idx}.png")
        except OSError as exc:
            # Only a too-long file name is recoverable. The errno constant is
            # used because the numeric value is platform dependent (36 on
            # Linux, 63 on macOS), so a hard-coded 36 never matches on macOS.
            if exc.errno != errno.ENAMETOOLONG:
                raise
            # Fall back to a truncated name plus a random suffix and actually
            # retry the save; remaining images of this prompt reuse the stem.
            short_filename: str = f"{out_dir}/{prompt.replace(' ', '_')[:100]}{generate_random_string(6)}"
            stem = short_filename
            image.save(f"{stem}_{idx}.png")

# Audible "done" signal: three beeps, one second apart (bb is a project-local
# module not shown here; the arguments are presumably frequency in Hz and
# duration in seconds - TODO confirm).
bb.play_beep(440, 0.5)
time.sleep(1)
bb.play_beep(440, 0.5)
time.sleep(1)
bb.play_beep(440, 0.5)

Logs

Traceback (most recent call last):
  File "/Users/sarit/study/try_openai/try_fine_tune/gluta40.py", line 106, in <module>
    images = pipe(prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds,
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/sarit/anaconda3/envs/try_openai/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/sarit/anaconda3/envs/try_openai/lib/python3.11/site-packages/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py", line 700, in __call__
    self.check_inputs(
  File "/Users/sarit/anaconda3/envs/try_openai/lib/python3.11/site-packages/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py", line 492, in check_inputs
    raise ValueError(
ValueError: If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`.

System Info

OSX 13.4.1 (c) (22F770820d)
python 3.11.4
absl-py==1.4.0
accelerate @ git+https://github.com/huggingface/accelerate@d087be01566477d99b660526adb7da4ec31abf1d
aiofiles==23.1.0
aiohttp==3.8.4
aiosignal==1.3.1
altair==5.0.0
ansiwrap==0.8.4
antlr4-python3-runtime==4.9.3
anyio==3.6.2
appnope==0.1.3
argon2-cffi==21.3.0
argon2-cffi-bindings==21.2.0
arrow==1.2.3
asttokens==2.2.1
async-generator==1.10
async-lru==2.0.4
async-timeout==4.0.2
attrs==23.1.0
autopep8==2.0.2
Babel==2.12.1
backcall==0.2.0
backoff==2.2.1
beautifulsoup4==4.12.2
bitsandbytes==0.41.1
black==23.7.0
bleach==6.0.0
bson==0.5.10
build==0.10.0
cachetools==5.3.1
certifi==2023.5.7
cffi==1.15.1
chardet==5.2.0
charset-normalizer==3.1.0
click==8.1.3
cloudpickle==2.2.1
cmake==3.26.4
colorama==0.4.6
coloredlogs==15.0.1
comm==0.1.3
commonmark==0.9.1
contourpy==1.0.7
controlnet-aux==0.0.5
cycler==0.11.0
datasets==2.12.0
debuglater==1.4.4
debugpy==1.6.7
decorator==5.1.1
defusedxml==0.7.1
detectron2 @ git+https://github.com/facebookresearch/detectron2.git@7d2e68dbe452fc422268d40ac185ea2609affca8
diffusers==0.20.0
dill==0.3.6
einops==0.6.1
entrypoints==0.4
evaluate==0.4.0
exceptiongroup==1.1.1
executing==1.2.0
fastapi==0.95.2
fastjsonschema==2.17.1
ffmpeg-python==0.2.0
ffmpy==0.3.0
filelock==3.12.0
flatbuffers==23.5.26
fonttools==4.39.4
fqdn==1.5.1
frozenlist==1.3.3
fsspec==2023.5.0
fvcore==0.1.5.post20221221
google-auth==2.22.0
google-auth-oauthlib==1.0.0
gradio==3.32.0
gradio_client==0.2.5
grpcio==1.57.0
h11==0.14.0
h5py==3.9.0
httpcore==0.17.2
httpx==0.24.1
huggingface-hub==0.14.1
humanfriendly==10.0
humanize==4.6.0
hydra-core==1.3.2
idna==3.4
imageio==2.31.0
importlib-metadata==6.6.0
invisible-watermark==0.2.0
iopath==0.1.9
ipdb==0.13.13
ipykernel==6.23.1
ipython==8.13.2
isoduration==20.11.0
jedi==0.18.2
Jinja2==3.1.2
jprq==2.1.0
json5==0.9.14
jsonpointer==2.4
jsonschema==4.19.0
jsonschema-specifications==2023.7.1
jupyter-events==0.7.0
jupyter-lsp==2.2.0
jupyter_client==8.2.0
jupyter_core==5.3.0
jupyter_server==2.7.0
jupyter_server_terminals==0.4.4
jupyterlab==4.0.4
jupyterlab-pygments==0.2.2
jupyterlab_server==2.24.0
jupytext==1.14.5
kiwisolver==1.4.4
lazy_loader==0.2
linkify-it-py==2.0.2
llvmlite==0.40.1
Markdown==3.4.4
markdown-it-py==2.2.0
MarkupSafe==2.1.2
matplotlib==3.7.1
matplotlib-inline==0.1.6
mdit-py-plugins==0.3.3
mdurl==0.1.2
mediapipe==0.10.1
mistune==2.0.5
monotonic==1.6
mpmath==1.3.0
multidict==6.0.4
multiprocess==0.70.14
mypy-extensions==1.0.0
nbclient==0.7.4
nbconvert==7.4.0
nbformat==5.8.0
nest-asyncio==1.5.6
networkx==3.1
notebook==7.0.2
notebook_shim==0.2.3
numba==0.57.1
numpy==1.24.3
oauthlib==3.2.2
omegaconf==2.3.0
onnxruntime==1.15.1
openai==0.27.7
opencv-contrib-python==4.7.0.72
opencv-python==4.7.0.72
orjson==3.8.13
outcome==1.2.0
overrides==7.4.0
packaging==23.1
pandas==2.0.1
pandocfilters==1.5.0
papermill==2.4.0
parso==0.8.3
pathspec==0.11.2
pexpect==4.8.0
pickleshare==0.7.5
Pillow==9.5.0
pip-tools==6.13.0
platformdirs==3.5.1
ploomber==0.22.3
ploomber-core==0.2.10
ploomber-engine==0.0.28
ploomber-scaffold==0.3.1
pooch==1.7.0
portalocker==2.7.0
posthog==3.0.1
prometheus-client==0.17.1
prompt-toolkit==3.0.38
protobuf==3.20.3
psutil==5.9.5
ptyprocess==0.7.0
pure-eval==0.2.2
py-cpuinfo==9.0.0
pyarrow==12.0.0
pyasn1==0.5.0
pyasn1-modules==0.3.0
pycocotools==2.0.6
pycodestyle==2.10.0
pycparser==2.21
pydantic==1.10.7
pydub==0.25.1
pyflakes==3.0.1
Pygments==2.15.1
PyMatting==1.1.8
pyparsing==3.0.9
pypdf==3.9.0
pyproject_hooks==1.0.0
pyre-extensions==0.0.29
pyrsistent==0.19.3
PySocks==1.7.1
python-dateutil==2.8.2
python-dotenv==1.0.0
python-json-logger==2.0.7
python-multipart==0.0.6
pytz==2023.3
PyWavelets==1.4.1
PyYAML==6.0
pyzmq==25.0.2
referencing==0.30.2
regex==2023.5.5
requests==2.30.0
requests-oauthlib==1.3.1
responses==0.18.0
rfc3339-validator==0.1.4
rfc3986-validator==0.1.1
rich==10.14.0
rpds-py==0.9.2
rsa==4.9
safetensors==0.3.1
scikit-image==0.21.0
scipy==1.10.1
seaborn==0.12.2
selenium==4.10.0
semantic-version==2.10.0
Send2Trash==1.8.2
simple-photo-gallery @ file:///Users/sarit/study/simple-photo-gallery
six==1.16.0
sniffio==1.3.0
sortedcontainers==2.4.0
sounddevice==0.4.6
soupsieve==2.4.1
SQLAlchemy==2.0.15
sqlparse==0.4.4
stack-data==0.6.2
starlette==0.27.0
super-image==0.1.7
sympy==1.12
tabulate==0.9.0
tenacity==8.2.2
tensorboard==2.14.0
tensorboard-data-server==0.7.1
tensorboardX==2.6.1
termcolor==2.3.0
terminado==0.17.1
textwrap3==0.9.2
tifffile==2023.4.12
tiktoken==0.4.0
timm==0.9.2
tinycss2==1.2.1
tokenizers==0.13.3
toml==0.10.2
tomli==2.0.1
toolz==0.12.0
torch==2.0.1
torchvision==0.15.2
tornado==6.3.2
tqdm==4.65.0
traitlets==5.9.0
transformers==4.31.0
trio==0.22.0
trio-websocket==0.10.3
triton-pre-mlir @ git+https://github.com/vchiley/triton.git@2dd3b957698a39bbca615c02a447a98482c144a3#subdirectory=python
typing-inspect==0.9.0
typing_extensions==4.5.0
tzdata==2023.3
uc-micro-py==1.0.2
ultralytics==8.0.158
uri-template==1.3.0
urllib3==1.26.16
uvicorn==0.22.0
wcwidth==0.2.6
webcolors==1.13
webencodings==0.5.1
websocket-client==1.6.1
websockets==11.0.3
Werkzeug==2.3.7
wsproto==1.2.0
xformers==0.0.20
xxhash==3.2.0
yacs==0.1.8
yarl==1.9.2
zipp==3.15.0

Who can help?

@patrickvonplaten @sayakpaul @williamberman

Metadata

Metadata

Assignees

No one assigned

    Labels

    bug: Something isn't working

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions