Fixed a bug in the prompt embeddings.
The original prompt embedding code caused the pipeline to crash whenever the negative prompt was longer than the prompt. Implemented the fix suggested in huggingface/diffusers#2136.
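For context, a minimal self-contained sketch of the failure mode (all numbers and strings below are hypothetical, not taken from the repository): the old code compared prompts by their comma-separated chunk counts, but the tokenizer's actual token counts can disagree, and because the tokenizer is called with truncation = False, padding never shortens the longer negative prompt, leaving the two embedding tensors with different sequence lengths.

    # Hypothetical illustration of the bug: word-level counts say the
    # prompt is longer, but the tokenizer disagrees.
    prompt_chunks = len("a, b, c".split(","))    # 3 chunks
    negative_chunks = len("x, y".split(","))     # 2 chunks
    assert prompt_chunks >= negative_chunks      # old "prompt is longer" branch taken

    prompt_token_count = 40      # hypothetical tokenized lengths
    negative_token_count = 95    # negative prompt is longer in tokens

    # The old code pads the negative ids up to the prompt's token length,
    # but with truncation = False nothing is ever cut, so the negative ids
    # keep their natural length of 95.
    shape_max_length = prompt_token_count
    padded_negative_length = max(negative_token_count, shape_max_length)

    # The sequence dimensions now disagree (95 vs 40), and the pipeline
    # crashes when the two embeddings are used together.
    assert padded_negative_length != shape_max_length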
natsunoyuki committed Mar 9, 2024
1 parent 8347d0f commit 6489c6d
Showing 1 changed file with 18 additions and 32 deletions.
50 changes: 18 additions & 32 deletions src/diffuser_tools/text2img.py
@@ -156,57 +156,43 @@ def get_prompt_embeddings(
         """Prompt embeddings to overcome CLIP 77 token limit.
         https://github.com/huggingface/diffusers/issues/2136
         """
-        count_prompt = len(self.prompt.split(split_character))
-        count_negative_prompt = len(self.negative_prompt.split(split_character))
-
         max_length = self.pipe.tokenizer.model_max_length
 
+        # Simple method of checking if the prompt is longer than the negative
+        # prompt - split the input strings using `split_character`.
+        count_prompt = len(self.prompt.split(split_character))
+        count_negative_prompt = len(self.negative_prompt.split(split_character))
+        input_ids = self.pipe.tokenizer(
+            self.prompt, return_tensors = "pt", truncation = False
+        ).input_ids.to(self.device)
+        negative_ids = self.pipe.tokenizer(
+            self.negative_prompt, return_tensors = "pt", truncation = False
+        ).input_ids.to(self.device)
+
         # If prompt is longer than negative prompt.
-        if count_prompt >= count_negative_prompt:
-            input_ids = self.pipe.tokenizer(
-                self.prompt, return_tensors = "pt", truncation = False,
-            ).input_ids.to(self.device)
+        if input_ids.shape[-1] >= negative_ids.shape[-1]:
             shape_max_length = input_ids.shape[-1]
             negative_ids = self.pipe.tokenizer(
-                self.negative_prompt,
-                truncation = False,
-                padding = "max_length",
-                max_length = shape_max_length,
-                return_tensors = "pt",
+                self.negative_prompt, return_tensors = "pt", truncation = False,
+                padding = "max_length", max_length = shape_max_length
             ).input_ids.to(self.device)
         # If negative prompt is longer than prompt.
         else:
-            negative_ids = self.pipe.tokenizer(
-                self.negative_prompt, return_tensors = "pt", truncation = False,
-            ).input_ids.to(self.device)
             shape_max_length = negative_ids.shape[-1]
             input_ids = self.pipe.tokenizer(
-                self.prompt,
-                return_tensors = "pt",
-                truncation = False,
-                padding = "max_length",
-                max_length = shape_max_length,
+                self.prompt, return_tensors = "pt", truncation = False,
+                padding = "max_length", max_length = shape_max_length
             ).input_ids.to(self.device)
 
+        # Concatenate the individual prompt embeddings.
         concat_embeds = []
         neg_embeds = []
         for i in range(0, shape_max_length, max_length):
-            concat_embeds.append(
-                self.pipe.text_encoder(input_ids[:, i: i + max_length])[0]
-            )
-            neg_embeds.append(
-                self.pipe.text_encoder(negative_ids[:, i: i + max_length])[0]
-            )
+            concat_embeds.append(self.pipe.text_encoder(input_ids[:, i: i + max_length])[0])
+            neg_embeds.append(self.pipe.text_encoder(negative_ids[:, i: i + max_length])[0])
+
+        self.prompt_embeddings = torch.cat(concat_embeds, dim = 1)
+        self.negative_prompt_embeddings = torch.cat(neg_embeds, dim = 1)
 
         if return_embeddings is True:
-            return torch.cat(concat_embeds, dim = 1), torch.cat(neg_embeds, dim = 1)
+            return self.prompt_embeddings, self.negative_prompt_embeddings


# %%
def run_pipe(
@@ -258,4 +244,4 @@ def run_pipe(
         if verbose is True:
             sys.stdout.write("{:.2f}s.\n".format(time_elapsed));
 
-        return imgs[0]
\ No newline at end of file
+        return imgs[0]
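For reference, the chunked-encoding technique in the first hunk can be sketched against a bare diffusers pipeline (the model id, prompts, and helper names below are illustrative assumptions, not this repository's API); the resulting tensors are passed to the pipeline via prompt_embeds and negative_prompt_embeds in place of the raw prompt strings, which is how the 77-token CLIP limit is bypassed.

    import torch
    from diffusers import StableDiffusionPipeline

    pipe = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5", torch_dtype = torch.float16
    ).to("cuda")

    def tokenize(text, pad_to = None):
        # Tokenize without truncation so the text may exceed 77 tokens,
        # optionally padding to a shared length.
        kwargs = {"return_tensors": "pt", "truncation": False}
        if pad_to is not None:
            kwargs.update(padding = "max_length", max_length = pad_to)
        return pipe.tokenizer(text, **kwargs).input_ids.to("cuda")

    prompt, negative = "a very long prompt ...", "a very long negative prompt ..."

    # Compare actual token lengths (the point of this commit's fix), then
    # pad the shorter prompt to the longer one's length.
    shape_max_length = max(tokenize(prompt).shape[-1], tokenize(negative).shape[-1])
    input_ids = tokenize(prompt, pad_to = shape_max_length)
    negative_ids = tokenize(negative, pad_to = shape_max_length)

    # Encode in windows of the CLIP limit (77 tokens) and concatenate
    # along the sequence axis, as in the diff above.
    max_length = pipe.tokenizer.model_max_length
    def encode(ids):
        return torch.cat(
            [pipe.text_encoder(ids[:, i: i + max_length])[0]
             for i in range(0, shape_max_length, max_length)],
            dim = 1,
        )

    image = pipe(
        prompt_embeds = encode(input_ids),
        negative_prompt_embeds = encode(negative_ids),
    ).images[0]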
