From ff022b79250fe53888dc16ebbd6b9ededf41b295 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Sat, 16 Sep 2023 13:29:34 +0000 Subject: [PATCH 1/5] [SDXL] Make sure multi batch prompt embeds works --- .../controlnet/pipeline_controlnet_inpaint_sd_xl.py | 8 +++++--- .../pipelines/controlnet/pipeline_controlnet_sd_xl.py | 8 +++++--- .../controlnet/pipeline_controlnet_sd_xl_img2img.py | 8 +++++--- .../stable_diffusion_xl/pipeline_stable_diffusion_xl.py | 8 +++++--- .../pipeline_stable_diffusion_xl_img2img.py | 8 +++++--- .../pipeline_stable_diffusion_xl_inpaint.py | 8 +++++--- .../t2i_adapter/pipeline_stable_diffusion_xl_adapter.py | 8 +++++--- 7 files changed, 35 insertions(+), 21 deletions(-) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py index cb4ec2f25bd3..bc7d36ea8bac 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py @@ -376,7 +376,11 @@ def encode_prompt( negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) elif do_classifier_free_guidance and negative_prompt_embeds is None: negative_prompt = negative_prompt or "" - negative_prompt_2 = negative_prompt_2 or negative_prompt + negative_prompt_2 = negative_prompt_2 or "" + + # normalize str to list + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_2 = batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 uncond_tokens: List[str] if prompt is not None and type(prompt) is not type(negative_prompt): @@ -384,8 +388,6 @@ def encode_prompt( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) - elif isinstance(negative_prompt, str): - uncond_tokens = [negative_prompt, negative_prompt_2] elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py index cbb78e509b84..56b5eaf26ded 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py @@ -349,7 +349,11 @@ def encode_prompt( negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) elif do_classifier_free_guidance and negative_prompt_embeds is None: negative_prompt = negative_prompt or "" - negative_prompt_2 = negative_prompt_2 or negative_prompt + negative_prompt_2 = negative_prompt_2 or "" + + # normalize str to list + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_2 = batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 uncond_tokens: List[str] if prompt is not None and type(prompt) is not type(negative_prompt): @@ -357,8 +361,6 @@ def encode_prompt( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) - elif isinstance(negative_prompt, str): - uncond_tokens = [negative_prompt, negative_prompt_2] elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py index 6fe3d0c641e5..722bf38bd0e4 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py @@ -387,7 +387,11 @@ def encode_prompt( negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) elif do_classifier_free_guidance and negative_prompt_embeds is None: negative_prompt = negative_prompt or "" - negative_prompt_2 = negative_prompt_2 or negative_prompt + negative_prompt_2 = negative_prompt_2 or "" + + # normalize str to list + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_2 = batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 uncond_tokens: List[str] if prompt is not None and type(prompt) is not type(negative_prompt): @@ -395,8 +399,6 @@ def encode_prompt( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) - elif isinstance(negative_prompt, str): - uncond_tokens = [negative_prompt, negative_prompt_2] elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py index 40119a6087d2..61912180d913 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -325,7 +325,11 @@ def encode_prompt( negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) elif do_classifier_free_guidance and negative_prompt_embeds is None: negative_prompt = negative_prompt or "" - negative_prompt_2 = negative_prompt_2 or negative_prompt + negative_prompt_2 = negative_prompt_2 or "" + + # normalize str to list + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_2 = batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 uncond_tokens: List[str] if prompt is not None and type(prompt) is not type(negative_prompt): @@ -333,8 +337,6 @@ def encode_prompt( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) - elif isinstance(negative_prompt, str): - uncond_tokens = [negative_prompt, negative_prompt_2] elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py index 162fc828fff4..65bf04d0f6bc 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py @@ -332,7 +332,11 @@ def encode_prompt( negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) elif do_classifier_free_guidance and negative_prompt_embeds is None: negative_prompt = negative_prompt or "" - negative_prompt_2 = negative_prompt_2 or negative_prompt + negative_prompt_2 = negative_prompt_2 or "" + + # normalize str to list + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_2 = batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 uncond_tokens: List[str] if prompt is not None and type(prompt) is not type(negative_prompt): @@ -340,8 +344,6 @@ def encode_prompt( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) - elif isinstance(negative_prompt, str): - uncond_tokens = [negative_prompt, negative_prompt_2] elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py index 25753859c210..387eec0b93b0 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py @@ -481,7 +481,11 @@ def encode_prompt( negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) elif do_classifier_free_guidance and negative_prompt_embeds is None: negative_prompt = negative_prompt or "" - negative_prompt_2 = negative_prompt_2 or negative_prompt + negative_prompt_2 = negative_prompt_2 or "" + + # normalize str to list + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_2 = batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 uncond_tokens: List[str] if prompt is not None and type(prompt) is not type(negative_prompt): @@ -489,8 +493,6 @@ def encode_prompt( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) - elif isinstance(negative_prompt, str): - uncond_tokens = [negative_prompt, negative_prompt_2] elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" diff --git a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py index 6019d93fe02d..90ae55daaa95 100644 --- a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +++ b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py @@ -349,7 +349,11 @@ def encode_prompt( negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) elif do_classifier_free_guidance and negative_prompt_embeds is None: negative_prompt = negative_prompt or "" - negative_prompt_2 = negative_prompt_2 or negative_prompt + negative_prompt_2 = negative_prompt_2 or "" + + # normalize str to list + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_2 = batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 uncond_tokens: List[str] if prompt is not None and type(prompt) is not type(negative_prompt): @@ -357,8 +361,6 @@ def encode_prompt( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) - elif isinstance(negative_prompt, str): - uncond_tokens = [negative_prompt, negative_prompt_2] elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" From 295e65050ef424764a5fa7086ff3d12bd5eebcbc Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Sat, 16 Sep 2023 13:36:18 +0000 Subject: [PATCH 2/5] [SDXL] Make sure multi batch prompt embeds works --- .../pipeline_controlnet_inpaint_sd_xl.py | 4 +- .../controlnet/pipeline_controlnet_sd_xl.py | 4 +- .../pipeline_controlnet_sd_xl_img2img.py | 4 +- .../pipeline_stable_diffusion_xl.py | 4 +- .../pipeline_stable_diffusion_xl_img2img.py | 4 +- .../pipeline_stable_diffusion_xl_inpaint.py | 4 +- .../pipeline_stable_diffusion_xl_adapter.py | 4 +- .../test_stable_diffusion_xl.py | 37 +++++++++++++++++++ .../test_stable_diffusion_xl_img2img.py | 37 +++++++++++++++++++ 9 files changed, 95 insertions(+), 7 deletions(-) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py index bc7d36ea8bac..a44884dc0b2c 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py @@ -380,7 +380,9 @@ def encode_prompt( # normalize str to list negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt - negative_prompt_2 = batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 + negative_prompt_2 = ( + batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 + ) uncond_tokens: List[str] if prompt is not None and type(prompt) is not type(negative_prompt): diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py index 56b5eaf26ded..6d73ef6ebea8 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py @@ -353,7 +353,9 @@ def encode_prompt( # normalize str to list negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt - negative_prompt_2 = batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 + negative_prompt_2 = ( + batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 + ) uncond_tokens: List[str] if prompt is not None and type(prompt) is not type(negative_prompt): diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py index 722bf38bd0e4..f20391f2d7f8 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py @@ -391,7 +391,9 @@ def encode_prompt( # normalize str to list negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt - negative_prompt_2 = batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 + negative_prompt_2 = ( + batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 + ) uncond_tokens: List[str] if prompt is not None and type(prompt) is not type(negative_prompt): diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py index 61912180d913..884835b5f300 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -329,7 +329,9 @@ def encode_prompt( # normalize str to list negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt - negative_prompt_2 = batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 + negative_prompt_2 = ( + batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 + ) uncond_tokens: List[str] if prompt is not None and type(prompt) is not type(negative_prompt): diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py index 65bf04d0f6bc..698b02cb7238 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py @@ -336,7 +336,9 @@ def encode_prompt( # normalize str to list negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt - negative_prompt_2 = batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 + negative_prompt_2 = ( + batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 + ) uncond_tokens: List[str] if prompt is not None and type(prompt) is not type(negative_prompt): diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py index 387eec0b93b0..867a67b33a17 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py @@ -485,7 +485,9 @@ def encode_prompt( # normalize str to list negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt - negative_prompt_2 = batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 + negative_prompt_2 = ( + batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 + ) uncond_tokens: List[str] if prompt is not None and type(prompt) is not type(negative_prompt): diff --git a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py index 90ae55daaa95..a3789d0aeced 100644 --- a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +++ b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py @@ -353,7 +353,9 @@ def encode_prompt( # normalize str to list negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt - negative_prompt_2 = batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 + negative_prompt_2 = ( + batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 + ) uncond_tokens: List[str] if prompt is not None and type(prompt) is not type(negative_prompt): diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py index dad52238f73a..37ad33d15b88 100644 --- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py +++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py @@ -261,6 +261,43 @@ def test_stable_diffusion_xl_offloads(self): assert np.abs(image_slices[0] - image_slices[1]).max() < 1e-3 assert np.abs(image_slices[0] - image_slices[2]).max() < 1e-3 + def test_stable_diffusion_xl_img2img_prompt_embeds_only(self): + components = self.get_dummy_components() + sd_pipe = StableDiffusionXLPipeline(**components) + sd_pipe = sd_pipe.to(torch_device) + sd_pipe = sd_pipe.to(torch_device) + sd_pipe.set_progress_bar_config(disable=None) + + # forward without prompt embeds + generator_device = "cpu" + inputs = self.get_dummy_inputs(generator_device) + inputs["prompt"] = 3 * [inputs["prompt"]] + + output = sd_pipe(**inputs) + image_slice_1 = output.images[0, -3:, -3:, -1] + + # forward with prompt embeds + generator_device = "cpu" + inputs = self.get_dummy_inputs(generator_device) + prompt = 3 * [inputs.pop("prompt")] + + ( + prompt_embeds, + _, + pooled_prompt_embeds, + _, + ) = sd_pipe.encode_prompt(prompt) + + output = sd_pipe( + **inputs, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + ) + image_slice_2 = output.images[0, -3:, -3:, -1] + + # make sure that it's equal + assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1e-4 + def test_stable_diffusion_two_xl_mixture_of_denoiser(self): components = self.get_dummy_components() pipe_1 = StableDiffusionXLPipeline(**components).to(torch_device) diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_img2img.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_img2img.py index a3890317aea2..494db861aa71 100644 --- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_img2img.py +++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_img2img.py @@ -559,6 +559,43 @@ def test_stable_diffusion_xl_img2img_negative_prompt_embeds(self): # make sure that it's equal assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1e-4 + def test_stable_diffusion_xl_img2img_prompt_embeds_only(self): + components = self.get_dummy_components() + sd_pipe = StableDiffusionXLImg2ImgPipeline(**components) + sd_pipe = sd_pipe.to(torch_device) + sd_pipe = sd_pipe.to(torch_device) + sd_pipe.set_progress_bar_config(disable=None) + + # forward without prompt embeds + generator_device = "cpu" + inputs = self.get_dummy_inputs(generator_device) + inputs["prompt"] = 3 * [inputs["prompt"]] + + output = sd_pipe(**inputs) + image_slice_1 = output.images[0, -3:, -3:, -1] + + # forward with prompt embeds + generator_device = "cpu" + inputs = self.get_dummy_inputs(generator_device) + prompt = 3 * [inputs.pop("prompt")] + + ( + prompt_embeds, + _, + pooled_prompt_embeds, + _, + ) = sd_pipe.encode_prompt(prompt) + + output = sd_pipe( + **inputs, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + ) + image_slice_2 = output.images[0, -3:, -3:, -1] + + # make sure that it's equal + assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1e-4 + def test_attention_slicing_forward_pass(self): super().test_attention_slicing_forward_pass(expected_max_diff=3e-3) From 80243dd5333e2c7c44c929c5123a34d23ff2a666 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Sat, 16 Sep 2023 20:32:31 +0200 Subject: [PATCH 3/5] improve more --- .../controlnet/pipeline_controlnet_inpaint_sd_xl.py | 8 +++++--- .../pipelines/controlnet/pipeline_controlnet_sd_xl.py | 8 +++++--- .../controlnet/pipeline_controlnet_sd_xl_img2img.py | 8 +++++--- .../stable_diffusion_xl/pipeline_stable_diffusion_xl.py | 8 +++++--- .../pipeline_stable_diffusion_xl_img2img.py | 8 +++++--- .../pipeline_stable_diffusion_xl_inpaint.py | 8 +++++--- .../t2i_adapter/pipeline_stable_diffusion_xl_adapter.py | 8 +++++--- 7 files changed, 35 insertions(+), 21 deletions(-) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py index a44884dc0b2c..e30a972f9b94 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py @@ -314,9 +314,9 @@ def encode_prompt( adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) - if prompt is not None and isinstance(prompt, str): - batch_size = 1 - elif prompt is not None and isinstance(prompt, list): + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt is not None: batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] @@ -329,6 +329,8 @@ def encode_prompt( if prompt_embeds is None: prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + # textual inversion: procecss multi-vector tokens if necessary prompt_embeds_list = [] prompts = [prompt, prompt_2] diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py index 6d73ef6ebea8..b30853082d65 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py @@ -287,9 +287,9 @@ def encode_prompt( adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) - if prompt is not None and isinstance(prompt, str): - batch_size = 1 - elif prompt is not None and isinstance(prompt, list): + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt is not None: batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] @@ -302,6 +302,8 @@ def encode_prompt( if prompt_embeds is None: prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + # textual inversion: procecss multi-vector tokens if necessary prompt_embeds_list = [] prompts = [prompt, prompt_2] diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py index f20391f2d7f8..c897d8796813 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py @@ -325,9 +325,9 @@ def encode_prompt( adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) - if prompt is not None and isinstance(prompt, str): - batch_size = 1 - elif prompt is not None and isinstance(prompt, list): + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt is not None: batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] @@ -340,6 +340,8 @@ def encode_prompt( if prompt_embeds is None: prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + # textual inversion: procecss multi-vector tokens if necessary prompt_embeds_list = [] prompts = [prompt, prompt_2] diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py index 884835b5f300..ade816b4b251 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -263,9 +263,9 @@ def encode_prompt( adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) - if prompt is not None and isinstance(prompt, str): - batch_size = 1 - elif prompt is not None and isinstance(prompt, list): + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt is not None: batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] @@ -278,6 +278,8 @@ def encode_prompt( if prompt_embeds is None: prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + # textual inversion: procecss multi-vector tokens if necessary prompt_embeds_list = [] prompts = [prompt, prompt_2] diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py index 698b02cb7238..5f714780c2c9 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py @@ -270,9 +270,9 @@ def encode_prompt( adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) - if prompt is not None and isinstance(prompt, str): - batch_size = 1 - elif prompt is not None and isinstance(prompt, list): + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt is not None: batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] @@ -285,6 +285,8 @@ def encode_prompt( if prompt_embeds is None: prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + # textual inversion: procecss multi-vector tokens if necessary prompt_embeds_list = [] prompts = [prompt, prompt_2] diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py index 867a67b33a17..1dd432e12aa8 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py @@ -419,9 +419,9 @@ def encode_prompt( adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) - if prompt is not None and isinstance(prompt, str): - batch_size = 1 - elif prompt is not None and isinstance(prompt, list): + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt is not None: batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] @@ -434,6 +434,8 @@ def encode_prompt( if prompt_embeds is None: prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + # textual inversion: procecss multi-vector tokens if necessary prompt_embeds_list = [] prompts = [prompt, prompt_2] diff --git a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py index a3789d0aeced..de028fce540d 100644 --- a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +++ b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py @@ -287,9 +287,9 @@ def encode_prompt( adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) - if prompt is not None and isinstance(prompt, str): - batch_size = 1 - elif prompt is not None and isinstance(prompt, list): + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt is not None: batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] @@ -302,6 +302,8 @@ def encode_prompt( if prompt_embeds is None: prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + # textual inversion: procecss multi-vector tokens if necessary prompt_embeds_list = [] prompts = [prompt, prompt_2] From da5f3e665b33e2874c7123a7cc1b37abb4433068 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Sat, 16 Sep 2023 20:39:50 +0200 Subject: [PATCH 4/5] improve more --- .../pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py | 2 +- src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py | 2 +- .../pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py | 2 +- .../stable_diffusion_xl/pipeline_stable_diffusion_xl.py | 2 +- .../stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py | 2 +- .../stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py | 2 +- .../t2i_adapter/pipeline_stable_diffusion_xl_adapter.py | 2 +- 7 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py index e30a972f9b94..56549a6faf2c 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py @@ -378,7 +378,7 @@ def encode_prompt( negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) elif do_classifier_free_guidance and negative_prompt_embeds is None: negative_prompt = negative_prompt or "" - negative_prompt_2 = negative_prompt_2 or "" + negative_prompt_2 = negative_prompt_2 or negative_prompt # normalize str to list negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py index b30853082d65..cbac63182f71 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py @@ -351,7 +351,7 @@ def encode_prompt( negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) elif do_classifier_free_guidance and negative_prompt_embeds is None: negative_prompt = negative_prompt or "" - negative_prompt_2 = negative_prompt_2 or "" + negative_prompt_2 = negative_prompt_2 or negative_prompt # normalize str to list negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py index c897d8796813..729a0b594973 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py @@ -389,7 +389,7 @@ def encode_prompt( negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) elif do_classifier_free_guidance and negative_prompt_embeds is None: negative_prompt = negative_prompt or "" - negative_prompt_2 = negative_prompt_2 or "" + negative_prompt_2 = negative_prompt_2 or negative_prompt # normalize str to list negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py index ade816b4b251..6f97bfc46136 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -327,7 +327,7 @@ def encode_prompt( negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) elif do_classifier_free_guidance and negative_prompt_embeds is None: negative_prompt = negative_prompt or "" - negative_prompt_2 = negative_prompt_2 or "" + negative_prompt_2 = negative_prompt_2 or negative_prompt # normalize str to list negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py index 5f714780c2c9..d9ebcec5a5f7 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py @@ -334,7 +334,7 @@ def encode_prompt( negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) elif do_classifier_free_guidance and negative_prompt_embeds is None: negative_prompt = negative_prompt or "" - negative_prompt_2 = negative_prompt_2 or "" + negative_prompt_2 = negative_prompt_2 or negative_prompt # normalize str to list negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py index 1dd432e12aa8..364830354b71 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py @@ -483,7 +483,7 @@ def encode_prompt( negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) elif do_classifier_free_guidance and negative_prompt_embeds is None: negative_prompt = negative_prompt or "" - negative_prompt_2 = negative_prompt_2 or "" + negative_prompt_2 = negative_prompt_2 or negative_prompt # normalize str to list negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt diff --git a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py index de028fce540d..ffd421f3ea41 100644 --- a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +++ b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py @@ -351,7 +351,7 @@ def encode_prompt( negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) elif do_classifier_free_guidance and negative_prompt_embeds is None: negative_prompt = negative_prompt or "" - negative_prompt_2 = negative_prompt_2 or "" + negative_prompt_2 = negative_prompt_2 or negative_prompt # normalize str to list negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt From 8e6fb5ea7641a283263abd0ff97b56c19c26271f Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 18 Sep 2023 14:47:46 +0200 Subject: [PATCH 5/5] Apply suggestions from code review --- tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py | 1 - .../stable_diffusion_xl/test_stable_diffusion_xl_img2img.py | 1 - 2 files changed, 2 deletions(-) diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py index 37ad33d15b88..65c7526e3aa2 100644 --- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py +++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py @@ -265,7 +265,6 @@ def test_stable_diffusion_xl_img2img_prompt_embeds_only(self): components = self.get_dummy_components() sd_pipe = StableDiffusionXLPipeline(**components) sd_pipe = sd_pipe.to(torch_device) - sd_pipe = sd_pipe.to(torch_device) sd_pipe.set_progress_bar_config(disable=None) # forward without prompt embeds diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_img2img.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_img2img.py index 494db861aa71..ba7d3e8be30f 100644 --- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_img2img.py +++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_img2img.py @@ -563,7 +563,6 @@ def test_stable_diffusion_xl_img2img_prompt_embeds_only(self): components = self.get_dummy_components() sd_pipe = StableDiffusionXLImg2ImgPipeline(**components) sd_pipe = sd_pipe.to(torch_device) - sd_pipe = sd_pipe.to(torch_device) sd_pipe.set_progress_bar_config(disable=None) # forward without prompt embeds