From 6489c6db60251b171bd7f125a0c40d2895d8b239 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=D0=98=D0=B2=D0=B0=D0=BD=20=D0=92=D1=83=D0=BB=D0=BA=D0=B0?=
 =?UTF-8?q?=D0=BD=D0=B8=D0=BD?= <33706936+natsunoyuki@users.noreply.github.com>
Date: Sat, 9 Mar 2024 12:02:41 +0900
Subject: [PATCH] Fixed a bug in the prompt embeddings.

The original prompt embedding code caused the pipeline to crash when the
negative prompt was longer than the prompt. Implemented the fix suggested
in https://github.com/huggingface/diffusers/issues/2136.
---
 src/diffuser_tools/text2img.py | 50 ++++++++++++----------------------
 1 file changed, 18 insertions(+), 32 deletions(-)

diff --git a/src/diffuser_tools/text2img.py b/src/diffuser_tools/text2img.py
index ef13490..1fec1f0 100644
--- a/src/diffuser_tools/text2img.py
+++ b/src/diffuser_tools/text2img.py
@@ -156,57 +156,43 @@ def get_prompt_embeddings(
         """Prompt embeddings to overcome CLIP 77 token limit.
         https://github.com/huggingface/diffusers/issues/2136
         """
+        count_prompt = len(self.prompt.split(split_character))
+        count_negative_prompt = len(self.negative_prompt.split(split_character))
         max_length = self.pipe.tokenizer.model_max_length
 
-        # Simple method of checking if the prompt is longer than the negative
-        # prompt - split the input strings using `split_character`.
-        count_prompt = len(self.prompt.split(split_character))
-        count_negative_prompt = len(self.negative_prompt.split(split_character))
+        input_ids = self.pipe.tokenizer(
+            self.prompt, return_tensors = "pt", truncation = False
+        ).input_ids.to(self.device)
+        negative_ids = self.pipe.tokenizer(
+            self.negative_prompt, return_tensors = "pt", truncation = False
+        ).input_ids.to(self.device)
 
-        # If prompt is longer than negative prompt.
-        if count_prompt >= count_negative_prompt:
-            input_ids = self.pipe.tokenizer(
-                self.prompt, return_tensors = "pt", truncation = False,
-            ).input_ids.to(self.device)
+        if input_ids.shape[-1] >= negative_ids.shape[-1]:
             shape_max_length = input_ids.shape[-1]
             negative_ids = self.pipe.tokenizer(
-                self.negative_prompt,
-                truncation = False,
-                padding = "max_length",
-                max_length = shape_max_length,
-                return_tensors = "pt",
+                self.negative_prompt, return_tensors = "pt", truncation = False,
+                padding = "max_length", max_length = shape_max_length
             ).input_ids.to(self.device)
 
-        # If negative prompt is longer than prompt.
         else:
-            negative_ids = self.pipe.tokenizer(
-                self.negative_prompt, return_tensors = "pt", truncation = False,
-            ).input_ids.to(self.device)
             shape_max_length = negative_ids.shape[-1]
             input_ids = self.pipe.tokenizer(
-                self.prompt,
-                return_tensors = "pt",
-                truncation = False,
-                padding = "max_length",
-                max_length = shape_max_length,
+                self.prompt, return_tensors = "pt", truncation = False,
+                padding = "max_length", max_length = shape_max_length
             ).input_ids.to(self.device)
 
-        # Concatenate the individual prompt embeddings.
         concat_embeds = []
         neg_embeds = []
         for i in range(0, shape_max_length, max_length):
-            concat_embeds.append(
-                self.pipe.text_encoder(input_ids[:, i: i + max_length])[0]
-            )
-            neg_embeds.append(
-                self.pipe.text_encoder(negative_ids[:, i: i + max_length])[0]
-            )
+            concat_embeds.append(self.pipe.text_encoder(input_ids[:, i: i + max_length])[0])
+            neg_embeds.append(self.pipe.text_encoder(negative_ids[:, i: i + max_length])[0])
 
         self.prompt_embeddings = torch.cat(concat_embeds, dim = 1)
         self.negative_prompt_embeddings = torch.cat(neg_embeds, dim = 1)
 
         if return_embeddings is True:
-            return torch.cat(concat_embeds, dim = 1), torch.cat(neg_embeds, dim = 1)
+            return self.prompt_embeddings, self.negative_prompt_embeddings
+
 
     # %%
     def run_pipe(
@@ -258,4 +244,4 @@ def run_pipe(
         if verbose is True:
             sys.stdout.write("{:.2f}s.\n".format(time_elapsed));
 
-        return imgs[0]
\ No newline at end of file
+        return imgs[0]
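
For reference, here is a minimal standalone sketch of the patched logic and
of how the returned embeddings are consumed. It is an illustration under
assumptions, not part of the patch: the free function stands in for the
class method above, and the checkpoint ID and example prompts are
placeholders. It relies on the `prompt_embeds` / `negative_prompt_embeds`
arguments that diffusers pipelines accept in releases current as of this
patch date.

import torch
from diffusers import StableDiffusionPipeline

def get_prompt_embeddings(pipe, prompt, negative_prompt, device):
    # CLIP's context window: 77 tokens for Stable Diffusion text encoders.
    max_length = pipe.tokenizer.model_max_length

    # Tokenize both prompts without truncation so no tokens are lost.
    input_ids = pipe.tokenizer(
        prompt, return_tensors = "pt", truncation = False
    ).input_ids.to(device)
    negative_ids = pipe.tokenizer(
        negative_prompt, return_tensors = "pt", truncation = False
    ).input_ids.to(device)

    # Compare real token counts and pad the shorter sequence to the longer
    # one's length, so both encode to the same number of chunks.
    if input_ids.shape[-1] >= negative_ids.shape[-1]:
        shape_max_length = input_ids.shape[-1]
        negative_ids = pipe.tokenizer(
            negative_prompt, return_tensors = "pt", truncation = False,
            padding = "max_length", max_length = shape_max_length,
        ).input_ids.to(device)
    else:
        shape_max_length = negative_ids.shape[-1]
        input_ids = pipe.tokenizer(
            prompt, return_tensors = "pt", truncation = False,
            padding = "max_length", max_length = shape_max_length,
        ).input_ids.to(device)

    # Encode in chunks of at most max_length tokens and concatenate along
    # the sequence dimension to get past the 77 token limit.
    concat_embeds, neg_embeds = [], []
    for i in range(0, shape_max_length, max_length):
        concat_embeds.append(pipe.text_encoder(input_ids[:, i: i + max_length])[0])
        neg_embeds.append(pipe.text_encoder(negative_ids[:, i: i + max_length])[0])
    return torch.cat(concat_embeds, dim = 1), torch.cat(neg_embeds, dim = 1)

# Hypothetical usage; checkpoint and prompts are placeholders.
device = "cuda" if torch.cuda.is_available() else "cpu"
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5").to(device)
prompt_embeds, negative_prompt_embeds = get_prompt_embeddings(
    pipe,
    "a detailed photograph of a red fox in a snowy forest",
    "blurry, low quality, deformed, bad anatomy",
    device,
)
image = pipe(
    prompt_embeds = prompt_embeds,
    negative_prompt_embeds = negative_prompt_embeds,
).images[0]

The token level comparison is the actual fix here: word counts and token
counts diverge because CLIP's BPE tokenizer can split one word into several
tokens. With the old word count check, a negative prompt with fewer words
but more tokens was "padded" to a length shorter than itself (which the
tokenizer ignores when truncation is off), so the chunked prompt and
negative prompt embeddings came out with different sequence lengths and the
pipeline crashed when combining them for classifier-free guidance.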