diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index d190acb1fa1c..bc6ca1efbd9b 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -116,6 +116,7 @@ def __call__( width: int = 512, num_inference_steps: int = 50, guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, eta: float = 0.0, generator: Optional[torch.Generator] = None, latents: Optional[torch.FloatTensor] = None, @@ -144,6 +145,9 @@ def __call__( Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. @@ -217,9 +221,32 @@ def __call__( do_classifier_free_guidance = guidance_scale > 1.0 # get unconditional embeddings for classifier free guidance if do_classifier_free_guidance: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + "`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + " {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] * batch_size + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + max_length = text_input_ids.shape[-1] uncond_input = self.tokenizer( - [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt" + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", ) uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0] diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index c8f02b5896d6..f0846573d34a 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -128,6 +128,7 @@ def __call__( strength: float = 0.8, num_inference_steps: Optional[int] = 50, guidance_scale: Optional[float] = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, eta: Optional[float] = 0.0, generator: Optional[torch.Generator] = None, output_type: Optional[str] = "pil", @@ -160,6 +161,9 @@ def __call__( Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. @@ -258,9 +262,28 @@ def __call__( do_classifier_free_guidance = guidance_scale > 1.0 # get unconditional embeddings for classifier free guidance if do_classifier_free_guidance: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + "`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + " {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] * batch_size + elif batch_size != len(negative_prompt): + raise ValueError("The length of `negative_prompt` should be equal to batch_size.") + else: + uncond_tokens = negative_prompt + max_length = text_input_ids.shape[-1] uncond_input = self.tokenizer( - [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt" + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", ) uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0] diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index 21490d975730..107c4de07197 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -144,6 +144,7 @@ def __call__( strength: float = 0.8, num_inference_steps: Optional[int] = 50, guidance_scale: Optional[float] = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, eta: Optional[float] = 0.0, generator: Optional[torch.Generator] = None, output_type: Optional[str] = "pil", @@ -180,6 +181,9 @@ def __call__( Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. @@ -292,9 +296,32 @@ def __call__( do_classifier_free_guidance = guidance_scale > 1.0 # get unconditional embeddings for classifier free guidance if do_classifier_free_guidance: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + "`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + " {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] * batch_size + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + max_length = text_input_ids.shape[-1] uncond_input = self.tokenizer( - [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt" + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", ) uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0] diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py index 48e0e147c48d..0bb0ce440d90 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py @@ -52,6 +52,7 @@ def __call__( width: Optional[int] = 512, num_inference_steps: Optional[int] = 50, guidance_scale: Optional[float] = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, eta: Optional[float] = 0.0, latents: Optional[np.ndarray] = None, output_type: Optional[str] = "pil", @@ -102,9 +103,32 @@ def __call__( do_classifier_free_guidance = guidance_scale > 1.0 # get unconditional embeddings for classifier free guidance if do_classifier_free_guidance: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + "`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + " {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] * batch_size + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + max_length = text_input_ids.shape[-1] uncond_input = self.tokenizer( - [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="np" + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="np", ) uncond_embeddings = self.text_encoder(input_ids=uncond_input.input_ids.astype(np.int32))[0]