diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py index e21b41ccac6d..1af1b78ff381 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py @@ -317,10 +317,50 @@ def decode_latents(self, latents): image = image.cpu().permute(0, 2, 3, 1).float().numpy() return image - def check_inputs(self, prompt, image, noise_level, callback_steps): - if not isinstance(prompt, str) and not isinstance(prompt, list): + def check_inputs( + self, + prompt, + image, + noise_level, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + if ( not isinstance(image, torch.Tensor) and not isinstance(image, PIL.Image.Image) @@ -480,13 +520,27 @@ def __call__( """ # 1. Check inputs - self.check_inputs(prompt, image, noise_level, callback_steps) + self.check_inputs( + prompt, + image, + noise_level, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + ) if image is None: raise ValueError("`image` input cannot be undefined.") # 2. Define call parameters - batch_size = 1 if isinstance(prompt, str) else len(prompt) + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + device = self._execution_device # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`