diff --git a/src/diffuser_tools/text2img.py b/src/diffuser_tools/text2img.py index ef13490..1fec1f0 100644 --- a/src/diffuser_tools/text2img.py +++ b/src/diffuser_tools/text2img.py @@ -156,57 +156,43 @@ def get_prompt_embeddings( """Prompt embeddings to overcome CLIP 77 token limit. https://github.com/huggingface/diffusers/issues/2136 """ + count_prompt = len(self.prompt.split(split_character)) + count_negative_prompt = len(self.negative_prompt.split(split_character)) max_length = self.pipe.tokenizer.model_max_length - # Simple method of checking if the prompt is longer than the negative - # prompt - split the input strings using `split_character`. - count_prompt = len(self.prompt.split(split_character)) - count_negative_prompt = len(self.negative_prompt.split(split_character)) + input_ids = self.pipe.tokenizer( + self.prompt, return_tensors = "pt", truncation = False + ).input_ids.to(self.device) + negative_ids = self.pipe.tokenizer( + self.negative_prompt, return_tensors = "pt", truncation = False + ).input_ids.to(self.device) - # If prompt is longer than negative prompt. - if count_prompt >= count_negative_prompt: - input_ids = self.pipe.tokenizer( - self.prompt, return_tensors = "pt", truncation = False, - ).input_ids.to(self.device) + if input_ids.shape[-1] >= negative_ids.shape[-1]: shape_max_length = input_ids.shape[-1] negative_ids = self.pipe.tokenizer( - self.negative_prompt, - truncation = False, - padding = "max_length", - max_length = shape_max_length, - return_tensors = "pt", + self.negative_prompt, return_tensors = "pt", truncation = False, + padding = "max_length", max_length = shape_max_length ).input_ids.to(self.device) - # If negative prompt is longer than prompt. else: - negative_ids = self.pipe.tokenizer( - self.negative_prompt, return_tensors = "pt", truncation = False, - ).input_ids.to(self.device) shape_max_length = negative_ids.shape[-1] input_ids = self.pipe.tokenizer( - self.prompt, - return_tensors = "pt", - truncation = False, - padding = "max_length", - max_length = shape_max_length, + self.prompt, return_tensors = "pt", truncation = False, + padding = "max_length", max_length = shape_max_length ).input_ids.to(self.device) - # Concatenate the individual prompt embeddings. concat_embeds = [] neg_embeds = [] for i in range(0, shape_max_length, max_length): - concat_embeds.append( - self.pipe.text_encoder(input_ids[:, i: i + max_length])[0] - ) - neg_embeds.append( - self.pipe.text_encoder(negative_ids[:, i: i + max_length])[0] - ) + concat_embeds.append(self.pipe.text_encoder(input_ids[:, i: i + max_length])[0]) + neg_embeds.append(self.pipe.text_encoder(negative_ids[:, i: i + max_length])[0]) self.prompt_embeddings = torch.cat(concat_embeds, dim = 1) self.negative_prompt_embeddings = torch.cat(neg_embeds, dim = 1) if return_embeddings is True: - return torch.cat(concat_embeds, dim = 1), torch.cat(neg_embeds, dim = 1) + return self.prompt_embeddings, self.negative_prompt_embeddings + # %% def run_pipe( @@ -258,4 +244,4 @@ def run_pipe( if verbose is True: sys.stdout.write("{:.2f}s.\n".format(time_elapsed)); - return imgs[0] \ No newline at end of file + return imgs[0]