Describe the bug
I am unable to use prompt embeddings (prompt_embeds) to handle a prompt that is longer than 77 tokens. Passing the pre-computed embeddings to StableDiffusionXLPipeline fails with the ValueError shown in the logs below.
Reproduction
import itertools
import os.path
import random
import string
import time
import typing as typ
import torch
from diffusers import StableDiffusionXLPipeline
from tqdm import tqdm
import bb
from web_sdxl import seed_everything
seed_everything(42)
def generate_random_string(length):
    letters = string.ascii_letters
    result = ''.join(random.choice(letters) for _ in range(length))
    return result
def get_pipeline_embeds(pipeline, prompt, negative_prompt, device):
    """Get pipeline embeds for prompts longer than the max length of the pipe.

    :param pipeline:
    :param prompt:
    :param negative_prompt:
    :param device:
    :return:
    """
    max_length = pipeline.tokenizer.model_max_length
    # simple way to determine length of tokens
    count_prompt = len(prompt.split(" "))
    count_negative_prompt = len(negative_prompt.split(" "))
    # create the tensor based on which prompt is longer
    if count_prompt >= count_negative_prompt:
        input_ids = pipeline.tokenizer(prompt, return_tensors="pt", truncation=False).input_ids.to(device)
        shape_max_length = input_ids.shape[-1]
        negative_ids = pipeline.tokenizer(negative_prompt, truncation=False, padding="max_length",
                                          max_length=shape_max_length, return_tensors="pt").input_ids.to(device)
    else:
        negative_ids = pipeline.tokenizer(negative_prompt, return_tensors="pt", truncation=False).input_ids.to(device)
        shape_max_length = negative_ids.shape[-1]
        input_ids = pipeline.tokenizer(prompt, return_tensors="pt", truncation=False, padding="max_length",
                                       max_length=shape_max_length).input_ids.to(device)

    concat_embeds = []
    neg_embeds = []
    for i in range(0, shape_max_length, max_length):
        concat_embeds.append(pipeline.text_encoder(input_ids[:, i: i + max_length])[0])
        neg_embeds.append(pipeline.text_encoder(negative_ids[:, i: i + max_length])[0])

    return torch.cat(concat_embeds, dim=1), torch.cat(neg_embeds, dim=1)
model_path = "fine_tuned_models/sdxl-sarit"
device = "mps" if torch.backends.mps.is_available() else "cpu"
out_dir: str = "gluta40"
age_prompts: typ.List[str] = [
    "young asian girl",
    "a photograph of an angel with sly expression, wearing a see-thru short roman style dress, beautiful asian mixed european woman face, beautiful eyes, black hair, looking down, hyper realistic and detailed, 16k",
]
hand_prompts: typ.List[str] = [
    "left hand holding a gluta40 jar one hand, right hand is behind her back",
    "right hand holding a gluta40 jar one hand, left hand is behind her back",
]
face_angle_prompts: typ.List[str] = [
    "straight face",
]
hair_prompts: typ.List[str] = [
    "black long tied hair",
    "black long hair",
]
background_prompts: typ.List[str] = [
    "no background, hold both hands, bad hands",
]
negative_prompt: str = "disfigured, disproportionate, bad anatomy, bad proportions, ugly, out of frame, mangled, asymmetric, cross-eyed, depressed, immature, stuffed animal, out of focus, high depth of field, cloned face, cloned head, age spot, skin blemishes, collapsed eyeshadow, asymmetric ears, imperfect eyes, unnatural, conjoined, missing limb, missing arm, missing leg, poorly drawn face, poorly drawn feet, poorly drawn hands, floating limb, disconnected limb, extra limb, malformed limbs, malformed hands, poorly rendered face, poor facial details, poorly rendered hands, double face, unbalanced body, unnatural body, lacking body, long body, cripple, cartoon, 3D, weird colors, unnatural skin tone, unnatural skin, stiff face, fused hand, skewed eyes, surreal, cropped head, group of people, too many fingers, bad hands, six fingers"
combined_list = list(itertools.product(age_prompts, hand_prompts, face_angle_prompts, hair_prompts, background_prompts))
random.shuffle(combined_list)
for item in tqdm(combined_list, total=len(combined_list)):
    age, hand, face_angle, hair, background = item
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)
    prompt: str = ", ".join(item)
    print(prompt)
    out_filename: str = f"{out_dir}/{prompt.replace(' ', '_')}"
    if not os.path.exists(f"{out_filename}_0.png"):
        try:
            pipe = StableDiffusionXLPipeline.from_pretrained(model_path, safety_checker=None,
                                                             requires_safety_checker=False)
            pipe.to(device)
            prompt_embeds, negative_prompt_embeds = get_pipeline_embeds(pipe, prompt, negative_prompt, device)
            images = pipe(prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds,
                          num_images_per_prompt=3, width=768,
                          height=1024).images
            for idx, image in enumerate(images):
                image.save(f"{out_filename}_{idx}.png")
        except OSError as exc:
            if exc.errno == 36:
                short_filename: str = f"{out_dir}/{prompt.replace(' ', '_')[:100]}{generate_random_string(6)}"
            else:
                raise

bb.play_beep(440, 0.5)
time.sleep(1)
bb.play_beep(440, 0.5)
time.sleep(1)
bb.play_beep(440, 0.5)
Logs
Traceback (most recent call last):
File "/Users/sarit/study/try_openai/try_fine_tune/gluta40.py", line 106, in <module>
images = pipe(prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/sarit/anaconda3/envs/try_openai/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/Users/sarit/anaconda3/envs/try_openai/lib/python3.11/site-packages/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py", line 700, in __call__
self.check_inputs(
File "/Users/sarit/anaconda3/envs/try_openai/lib/python3.11/site-packages/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py", line 492, in check_inputs
raise ValueError(
ValueError: If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`.
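If I read the error correctly, SDXL also wants pooled_prompt_embeds (and negative_pooled_prompt_embeds), which come from the second text encoder. Below is a rough, untested sketch of how I imagine the helper would have to change for SDXL; it assumes the pipeline exposes tokenizer_2 and text_encoder_2 and that the pooled embedding is the projected output of text_encoder_2. The function name and details are only illustrative, not a verified fix.

import torch

def get_sdxl_embeds(pipeline, prompt, device):
    # Hypothetical helper (untested): chunk a long prompt for both SDXL text
    # encoders and also return a pooled embedding from the second encoder.
    max_length = pipeline.tokenizer.model_max_length
    ids_1 = pipeline.tokenizer(prompt, return_tensors="pt", truncation=False).input_ids.to(device)
    # Pad the second tokenizer's ids to the same length as the first; I am
    # assuming both CLIP tokenizers produce the same number of tokens here.
    ids_2 = pipeline.tokenizer_2(prompt, return_tensors="pt", truncation=False, padding="max_length",
                                 max_length=ids_1.shape[-1]).input_ids.to(device)
    chunks, pooled = [], None
    for i in range(0, ids_1.shape[-1], max_length):
        out_1 = pipeline.text_encoder(ids_1[:, i: i + max_length], output_hidden_states=True)
        out_2 = pipeline.text_encoder_2(ids_2[:, i: i + max_length], output_hidden_states=True)
        # Concatenate the penultimate hidden states of both encoders along the feature dim.
        chunks.append(torch.cat([out_1.hidden_states[-2], out_2.hidden_states[-2]], dim=-1))
        if pooled is None:
            pooled = out_2[0]  # projected (pooled) embedding from the first chunk
    return torch.cat(chunks, dim=1), pooled

The idea would be to pass the two return values as prompt_embeds= and pooled_prompt_embeds= (and do the same for the negative prompt), but I have not verified that this matches what encode_prompt does internally.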
System Info
macOS 13.4.1 (c) (22F770820d)
Python 3.11.4
absl-py==1.4.0
accelerate @ git+https://github.com/huggingface/accelerate@d087be01566477d99b660526adb7da4ec31abf1d
aiofiles==23.1.0
aiohttp==3.8.4
aiosignal==1.3.1
altair==5.0.0
ansiwrap==0.8.4
antlr4-python3-runtime==4.9.3
anyio==3.6.2
appnope==0.1.3
argon2-cffi==21.3.0
argon2-cffi-bindings==21.2.0
arrow==1.2.3
asttokens==2.2.1
async-generator==1.10
async-lru==2.0.4
async-timeout==4.0.2
attrs==23.1.0
autopep8==2.0.2
Babel==2.12.1
backcall==0.2.0
backoff==2.2.1
beautifulsoup4==4.12.2
bitsandbytes==0.41.1
black==23.7.0
bleach==6.0.0
bson==0.5.10
build==0.10.0
cachetools==5.3.1
certifi==2023.5.7
cffi==1.15.1
chardet==5.2.0
charset-normalizer==3.1.0
click==8.1.3
cloudpickle==2.2.1
cmake==3.26.4
colorama==0.4.6
coloredlogs==15.0.1
comm==0.1.3
commonmark==0.9.1
contourpy==1.0.7
controlnet-aux==0.0.5
cycler==0.11.0
datasets==2.12.0
debuglater==1.4.4
debugpy==1.6.7
decorator==5.1.1
defusedxml==0.7.1
detectron2 @ git+https://github.com/facebookresearch/detectron2.git@7d2e68dbe452fc422268d40ac185ea2609affca8
diffusers==0.20.0
dill==0.3.6
einops==0.6.1
entrypoints==0.4
evaluate==0.4.0
exceptiongroup==1.1.1
executing==1.2.0
fastapi==0.95.2
fastjsonschema==2.17.1
ffmpeg-python==0.2.0
ffmpy==0.3.0
filelock==3.12.0
flatbuffers==23.5.26
fonttools==4.39.4
fqdn==1.5.1
frozenlist==1.3.3
fsspec==2023.5.0
fvcore==0.1.5.post20221221
google-auth==2.22.0
google-auth-oauthlib==1.0.0
gradio==3.32.0
gradio_client==0.2.5
grpcio==1.57.0
h11==0.14.0
h5py==3.9.0
httpcore==0.17.2
httpx==0.24.1
huggingface-hub==0.14.1
humanfriendly==10.0
humanize==4.6.0
hydra-core==1.3.2
idna==3.4
imageio==2.31.0
importlib-metadata==6.6.0
invisible-watermark==0.2.0
iopath==0.1.9
ipdb==0.13.13
ipykernel==6.23.1
ipython==8.13.2
isoduration==20.11.0
jedi==0.18.2
Jinja2==3.1.2
jprq==2.1.0
json5==0.9.14
jsonpointer==2.4
jsonschema==4.19.0
jsonschema-specifications==2023.7.1
jupyter-events==0.7.0
jupyter-lsp==2.2.0
jupyter_client==8.2.0
jupyter_core==5.3.0
jupyter_server==2.7.0
jupyter_server_terminals==0.4.4
jupyterlab==4.0.4
jupyterlab-pygments==0.2.2
jupyterlab_server==2.24.0
jupytext==1.14.5
kiwisolver==1.4.4
lazy_loader==0.2
linkify-it-py==2.0.2
llvmlite==0.40.1
Markdown==3.4.4
markdown-it-py==2.2.0
MarkupSafe==2.1.2
matplotlib==3.7.1
matplotlib-inline==0.1.6
mdit-py-plugins==0.3.3
mdurl==0.1.2
mediapipe==0.10.1
mistune==2.0.5
monotonic==1.6
mpmath==1.3.0
multidict==6.0.4
multiprocess==0.70.14
mypy-extensions==1.0.0
nbclient==0.7.4
nbconvert==7.4.0
nbformat==5.8.0
nest-asyncio==1.5.6
networkx==3.1
notebook==7.0.2
notebook_shim==0.2.3
numba==0.57.1
numpy==1.24.3
oauthlib==3.2.2
omegaconf==2.3.0
onnxruntime==1.15.1
openai==0.27.7
opencv-contrib-python==4.7.0.72
opencv-python==4.7.0.72
orjson==3.8.13
outcome==1.2.0
overrides==7.4.0
packaging==23.1
pandas==2.0.1
pandocfilters==1.5.0
papermill==2.4.0
parso==0.8.3
pathspec==0.11.2
pexpect==4.8.0
pickleshare==0.7.5
Pillow==9.5.0
pip-tools==6.13.0
platformdirs==3.5.1
ploomber==0.22.3
ploomber-core==0.2.10
ploomber-engine==0.0.28
ploomber-scaffold==0.3.1
pooch==1.7.0
portalocker==2.7.0
posthog==3.0.1
prometheus-client==0.17.1
prompt-toolkit==3.0.38
protobuf==3.20.3
psutil==5.9.5
ptyprocess==0.7.0
pure-eval==0.2.2
py-cpuinfo==9.0.0
pyarrow==12.0.0
pyasn1==0.5.0
pyasn1-modules==0.3.0
pycocotools==2.0.6
pycodestyle==2.10.0
pycparser==2.21
pydantic==1.10.7
pydub==0.25.1
pyflakes==3.0.1
Pygments==2.15.1
PyMatting==1.1.8
pyparsing==3.0.9
pypdf==3.9.0
pyproject_hooks==1.0.0
pyre-extensions==0.0.29
pyrsistent==0.19.3
PySocks==1.7.1
python-dateutil==2.8.2
python-dotenv==1.0.0
python-json-logger==2.0.7
python-multipart==0.0.6
pytz==2023.3
PyWavelets==1.4.1
PyYAML==6.0
pyzmq==25.0.2
referencing==0.30.2
regex==2023.5.5
requests==2.30.0
requests-oauthlib==1.3.1
responses==0.18.0
rfc3339-validator==0.1.4
rfc3986-validator==0.1.1
rich==10.14.0
rpds-py==0.9.2
rsa==4.9
safetensors==0.3.1
scikit-image==0.21.0
scipy==1.10.1
seaborn==0.12.2
selenium==4.10.0
semantic-version==2.10.0
Send2Trash==1.8.2
simple-photo-gallery @ file:///Users/sarit/study/simple-photo-gallery
six==1.16.0
sniffio==1.3.0
sortedcontainers==2.4.0
sounddevice==0.4.6
soupsieve==2.4.1
SQLAlchemy==2.0.15
sqlparse==0.4.4
stack-data==0.6.2
starlette==0.27.0
super-image==0.1.7
sympy==1.12
tabulate==0.9.0
tenacity==8.2.2
tensorboard==2.14.0
tensorboard-data-server==0.7.1
tensorboardX==2.6.1
termcolor==2.3.0
terminado==0.17.1
textwrap3==0.9.2
tifffile==2023.4.12
tiktoken==0.4.0
timm==0.9.2
tinycss2==1.2.1
tokenizers==0.13.3
toml==0.10.2
tomli==2.0.1
toolz==0.12.0
torch==2.0.1
torchvision==0.15.2
tornado==6.3.2
tqdm==4.65.0
traitlets==5.9.0
transformers==4.31.0
trio==0.22.0
trio-websocket==0.10.3
triton-pre-mlir @ git+https://github.com/vchiley/triton.git@2dd3b957698a39bbca615c02a447a98482c144a3#subdirectory=python
typing-inspect==0.9.0
typing_extensions==4.5.0
tzdata==2023.3
uc-micro-py==1.0.2
ultralytics==8.0.158
uri-template==1.3.0
urllib3==1.26.16
uvicorn==0.22.0
wcwidth==0.2.6
webcolors==1.13
webencodings==0.5.1
websocket-client==1.6.1
websockets==11.0.3
Werkzeug==2.3.7
wsproto==1.2.0
xformers==0.0.20
xxhash==3.2.0
yacs==0.1.8
yarl==1.9.2
zipp==3.15.0