diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py
index a89a13f70830..05e1efb5078c 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py
@@ -314,9 +314,9 @@ def encode_prompt(
             adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
             adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
 
-        if prompt is not None and isinstance(prompt, str):
-            batch_size = 1
-        elif prompt is not None and isinstance(prompt, list):
+        prompt = [prompt] if isinstance(prompt, str) else prompt
+
+        if prompt is not None:
             batch_size = len(prompt)
         else:
             batch_size = prompt_embeds.shape[0]
@@ -329,6 +329,8 @@ def encode_prompt(
 
         if prompt_embeds is None:
             prompt_2 = prompt_2 or prompt
+            prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
+
             # textual inversion: procecss multi-vector tokens if necessary
             prompt_embeds_list = []
             prompts = [prompt, prompt_2]
@@ -378,14 +380,18 @@ def encode_prompt(
             negative_prompt = negative_prompt or ""
             negative_prompt_2 = negative_prompt_2 or negative_prompt
 
+            # normalize str to list
+            negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+            negative_prompt_2 = (
+                batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
+            )
+
             uncond_tokens: List[str]
             if prompt is not None and type(prompt) is not type(negative_prompt):
                 raise TypeError(
                     f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                     f" {type(prompt)}."
                 )
-            elif isinstance(negative_prompt, str):
-                uncond_tokens = [negative_prompt, negative_prompt_2]
             elif batch_size != len(negative_prompt):
                 raise ValueError(
                     f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
index cbb78e509b84..cbac63182f71 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
@@ -287,9 +287,9 @@ def encode_prompt(
             adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
             adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
 
-        if prompt is not None and isinstance(prompt, str):
-            batch_size = 1
-        elif prompt is not None and isinstance(prompt, list):
+        prompt = [prompt] if isinstance(prompt, str) else prompt
+
+        if prompt is not None:
             batch_size = len(prompt)
         else:
             batch_size = prompt_embeds.shape[0]
@@ -302,6 +302,8 @@ def encode_prompt(
 
         if prompt_embeds is None:
             prompt_2 = prompt_2 or prompt
+            prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
+
             # textual inversion: procecss multi-vector tokens if necessary
             prompt_embeds_list = []
             prompts = [prompt, prompt_2]
@@ -351,14 +353,18 @@ def encode_prompt(
             negative_prompt = negative_prompt or ""
             negative_prompt_2 = negative_prompt_2 or negative_prompt
 
+            # normalize str to list
+            negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+            negative_prompt_2 = (
+                batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
+            )
+
             uncond_tokens: List[str]
             if prompt is not None and type(prompt) is not type(negative_prompt):
                 raise TypeError(
                     f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                     f" {type(prompt)}."
                 )
-            elif isinstance(negative_prompt, str):
-                uncond_tokens = [negative_prompt, negative_prompt_2]
             elif batch_size != len(negative_prompt):
                 raise ValueError(
                     f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py
index 6fe3d0c641e5..729a0b594973 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py
@@ -325,9 +325,9 @@ def encode_prompt(
             adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
             adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
 
-        if prompt is not None and isinstance(prompt, str):
-            batch_size = 1
-        elif prompt is not None and isinstance(prompt, list):
+        prompt = [prompt] if isinstance(prompt, str) else prompt
+
+        if prompt is not None:
             batch_size = len(prompt)
         else:
             batch_size = prompt_embeds.shape[0]
@@ -340,6 +340,8 @@ def encode_prompt(
 
         if prompt_embeds is None:
             prompt_2 = prompt_2 or prompt
+            prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
+
             # textual inversion: procecss multi-vector tokens if necessary
             prompt_embeds_list = []
             prompts = [prompt, prompt_2]
@@ -389,14 +391,18 @@ def encode_prompt(
             negative_prompt = negative_prompt or ""
             negative_prompt_2 = negative_prompt_2 or negative_prompt
 
+            # normalize str to list
+            negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+            negative_prompt_2 = (
+                batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
+            )
+
             uncond_tokens: List[str]
             if prompt is not None and type(prompt) is not type(negative_prompt):
                 raise TypeError(
                     f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                     f" {type(prompt)}."
                 )
-            elif isinstance(negative_prompt, str):
-                uncond_tokens = [negative_prompt, negative_prompt_2]
             elif batch_size != len(negative_prompt):
                 raise ValueError(
                     f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
index 25e95e0b3454..f7ea36f5f726 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
@@ -263,9 +263,9 @@ def encode_prompt(
             adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
             adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
 
-        if prompt is not None and isinstance(prompt, str):
-            batch_size = 1
-        elif prompt is not None and isinstance(prompt, list):
+        prompt = [prompt] if isinstance(prompt, str) else prompt
+
+        if prompt is not None:
             batch_size = len(prompt)
         else:
             batch_size = prompt_embeds.shape[0]
@@ -278,6 +278,8 @@ def encode_prompt(
 
         if prompt_embeds is None:
             prompt_2 = prompt_2 or prompt
+            prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
+
             # textual inversion: procecss multi-vector tokens if necessary
             prompt_embeds_list = []
             prompts = [prompt, prompt_2]
@@ -327,14 +329,18 @@ def encode_prompt(
             negative_prompt = negative_prompt or ""
             negative_prompt_2 = negative_prompt_2 or negative_prompt
 
+            # normalize str to list
+            negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+            negative_prompt_2 = (
+                batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
+            )
+
             uncond_tokens: List[str]
             if prompt is not None and type(prompt) is not type(negative_prompt):
                 raise TypeError(
                     f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                     f" {type(prompt)}."
                 )
-            elif isinstance(negative_prompt, str):
-                uncond_tokens = [negative_prompt, negative_prompt_2]
             elif batch_size != len(negative_prompt):
                 raise ValueError(
                     f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py
index 86f337fa2d51..0d0ac9bf8173 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py
@@ -270,9 +270,9 @@ def encode_prompt(
             adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
             adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
 
-        if prompt is not None and isinstance(prompt, str):
-            batch_size = 1
-        elif prompt is not None and isinstance(prompt, list):
+        prompt = [prompt] if isinstance(prompt, str) else prompt
+
+        if prompt is not None:
             batch_size = len(prompt)
         else:
             batch_size = prompt_embeds.shape[0]
@@ -285,6 +285,8 @@ def encode_prompt(
 
         if prompt_embeds is None:
             prompt_2 = prompt_2 or prompt
+            prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
+
             # textual inversion: procecss multi-vector tokens if necessary
             prompt_embeds_list = []
             prompts = [prompt, prompt_2]
@@ -334,14 +336,18 @@ def encode_prompt(
             negative_prompt = negative_prompt or ""
             negative_prompt_2 = negative_prompt_2 or negative_prompt
 
+            # normalize str to list
+            negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+            negative_prompt_2 = (
+                batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
+            )
+
             uncond_tokens: List[str]
             if prompt is not None and type(prompt) is not type(negative_prompt):
                 raise TypeError(
                     f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                     f" {type(prompt)}."
                 )
-            elif isinstance(negative_prompt, str):
-                uncond_tokens = [negative_prompt, negative_prompt_2]
             elif batch_size != len(negative_prompt):
                 raise ValueError(
                     f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py
index 9363d7b2a3d3..ae8d974a69f6 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py
@@ -419,9 +419,9 @@ def encode_prompt(
             adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
             adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
 
-        if prompt is not None and isinstance(prompt, str):
-            batch_size = 1
-        elif prompt is not None and isinstance(prompt, list):
+        prompt = [prompt] if isinstance(prompt, str) else prompt
+
+        if prompt is not None:
             batch_size = len(prompt)
         else:
             batch_size = prompt_embeds.shape[0]
@@ -434,6 +434,8 @@ def encode_prompt(
 
         if prompt_embeds is None:
             prompt_2 = prompt_2 or prompt
+            prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
+
             # textual inversion: procecss multi-vector tokens if necessary
             prompt_embeds_list = []
             prompts = [prompt, prompt_2]
@@ -483,14 +485,18 @@ def encode_prompt(
             negative_prompt = negative_prompt or ""
             negative_prompt_2 = negative_prompt_2 or negative_prompt
 
+            # normalize str to list
+            negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+            negative_prompt_2 = (
+                batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
+            )
+
             uncond_tokens: List[str]
             if prompt is not None and type(prompt) is not type(negative_prompt):
                 raise TypeError(
                     f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                     f" {type(prompt)}."
                 )
-            elif isinstance(negative_prompt, str):
-                uncond_tokens = [negative_prompt, negative_prompt_2]
             elif batch_size != len(negative_prompt):
                 raise ValueError(
                     f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
diff --git a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py
index ca876440166e..5827ca1e53d8 100644
--- a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py
+++ b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py
@@ -287,9 +287,9 @@ def encode_prompt(
             adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
             adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
 
-        if prompt is not None and isinstance(prompt, str):
-            batch_size = 1
-        elif prompt is not None and isinstance(prompt, list):
+        prompt = [prompt] if isinstance(prompt, str) else prompt
+
+        if prompt is not None:
             batch_size = len(prompt)
         else:
             batch_size = prompt_embeds.shape[0]
@@ -302,6 +302,8 @@ def encode_prompt(
 
         if prompt_embeds is None:
             prompt_2 = prompt_2 or prompt
+            prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
+
             # textual inversion: procecss multi-vector tokens if necessary
             prompt_embeds_list = []
             prompts = [prompt, prompt_2]
@@ -351,14 +353,18 @@ def encode_prompt(
             negative_prompt = negative_prompt or ""
             negative_prompt_2 = negative_prompt_2 or negative_prompt
 
+            # normalize str to list
+            negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+            negative_prompt_2 = (
+                batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
+            )
+
             uncond_tokens: List[str]
             if prompt is not None and type(prompt) is not type(negative_prompt):
                 raise TypeError(
                     f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                     f" {type(prompt)}."
                 )
-            elif isinstance(negative_prompt, str):
-                uncond_tokens = [negative_prompt, negative_prompt_2]
             elif batch_size != len(negative_prompt):
                 raise ValueError(
                     f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py
index dad52238f73a..65c7526e3aa2 100644
--- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py
+++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py
@@ -261,6 +261,42 @@ def test_stable_diffusion_xl_offloads(self):
         assert np.abs(image_slices[0] - image_slices[1]).max() < 1e-3
         assert np.abs(image_slices[0] - image_slices[2]).max() < 1e-3
 
+    def test_stable_diffusion_xl_img2img_prompt_embeds_only(self):
+        components = self.get_dummy_components()
+        sd_pipe = StableDiffusionXLPipeline(**components)
+        sd_pipe = sd_pipe.to(torch_device)
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        # forward without prompt embeds
+        generator_device = "cpu"
+        inputs = self.get_dummy_inputs(generator_device)
+        inputs["prompt"] = 3 * [inputs["prompt"]]
+
+        output = sd_pipe(**inputs)
+        image_slice_1 = output.images[0, -3:, -3:, -1]
+
+        # forward with prompt embeds
+        generator_device = "cpu"
+        inputs = self.get_dummy_inputs(generator_device)
+        prompt = 3 * [inputs.pop("prompt")]
+
+        (
+            prompt_embeds,
+            _,
+            pooled_prompt_embeds,
+            _,
+        ) = sd_pipe.encode_prompt(prompt)
+
+        output = sd_pipe(
+            **inputs,
+            prompt_embeds=prompt_embeds,
+            pooled_prompt_embeds=pooled_prompt_embeds,
+        )
+        image_slice_2 = output.images[0, -3:, -3:, -1]
+
+        # make sure that it's equal
+        assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1e-4
+
     def test_stable_diffusion_two_xl_mixture_of_denoiser(self):
         components = self.get_dummy_components()
         pipe_1 = StableDiffusionXLPipeline(**components).to(torch_device)
diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_img2img.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_img2img.py
index a3890317aea2..ba7d3e8be30f 100644
--- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_img2img.py
+++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_img2img.py
@@ -559,6 +559,42 @@ def test_stable_diffusion_xl_img2img_negative_prompt_embeds(self):
         # make sure that it's equal
         assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1e-4
 
+    def test_stable_diffusion_xl_img2img_prompt_embeds_only(self):
+        components = self.get_dummy_components()
+        sd_pipe = StableDiffusionXLImg2ImgPipeline(**components)
+        sd_pipe = sd_pipe.to(torch_device)
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        # forward without prompt embeds
+        generator_device = "cpu"
+        inputs = self.get_dummy_inputs(generator_device)
+        inputs["prompt"] = 3 * [inputs["prompt"]]
+
+        output = sd_pipe(**inputs)
+        image_slice_1 = output.images[0, -3:, -3:, -1]
+
+        # forward with prompt embeds
+        generator_device = "cpu"
+        inputs = self.get_dummy_inputs(generator_device)
+        prompt = 3 * [inputs.pop("prompt")]
+
+        (
+            prompt_embeds,
+            _,
+            pooled_prompt_embeds,
+            _,
+        ) = sd_pipe.encode_prompt(prompt)
+
+        output = sd_pipe(
+            **inputs,
+            prompt_embeds=prompt_embeds,
+            pooled_prompt_embeds=pooled_prompt_embeds,
+        )
+        image_slice_2 = output.images[0, -3:, -3:, -1]
+
+        # make sure that it's equal
+        assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1e-4
+
     def test_attention_slicing_forward_pass(self):
         super().test_attention_slicing_forward_pass(expected_max_diff=3e-3)
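
Taken together, the pipeline hunks above normalize every prompt argument to a list before the batch-size and type checks run: `prompt` and `prompt_2` become one-element lists, and a plain-string `negative_prompt`/`negative_prompt_2` is broadcast to `batch_size` copies instead of being handled by the removed `isinstance(negative_prompt, str)` branch. This lets a batched list `prompt` be combined with a single negative-prompt string (which previously tripped the `TypeError` check, since list vs. str types differed), and lets `encode_prompt` output be fed back through `prompt_embeds`/`pooled_prompt_embeds` alone, which is what the new tests exercise. A minimal usage sketch of both call patterns follows; the checkpoint name and prompt strings are illustrative, not taken from this diff:

    import torch
    from diffusers import StableDiffusionXLPipeline

    # Illustrative checkpoint; any SDXL checkpoint served by this pipeline works.
    pipe = StableDiffusionXLPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
    ).to("cuda")

    prompt = 3 * ["an astronaut riding a horse"]

    # Pattern 1: batched list prompt plus a single negative-prompt string.
    # The new "normalize str to list" block broadcasts the string to
    # batch_size entries, so the type and length checks now pass.
    images = pipe(prompt=prompt, negative_prompt="low quality, blurry").images

    # Pattern 2: precomputed embeddings only, mirroring the added
    # test_stable_diffusion_xl_img2img_prompt_embeds_only tests.
    # encode_prompt returns (prompt_embeds, negative_prompt_embeds,
    # pooled_prompt_embeds, negative_pooled_prompt_embeds).
    (
        prompt_embeds,
        _,
        pooled_prompt_embeds,
        _,
    ) = pipe.encode_prompt(prompt)

    images = pipe(
        prompt_embeds=prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
    ).images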