diff --git a/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py b/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py index 8a8d5b65e31a..a05fb9001c0e 100644 --- a/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +++ b/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py @@ -289,7 +289,9 @@ def __call__( guidance_scale: float = 0.0, negative_prompt: Optional[Union[str, List[str]]] = None, prompt_embeds: Optional[torch.FloatTensor] = None, + prompt_embeds_pooled: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds_pooled: Optional[torch.FloatTensor] = None, num_images_per_prompt: int = 1, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.FloatTensor] = None, @@ -321,10 +323,17 @@ def __call__( prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. + prompt_embeds_pooled (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. + negative_prompt_embeds_pooled (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds_pooled will be generated from `negative_prompt` input + argument. 
num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): @@ -378,7 +387,7 @@ def __call__( # 2. Encode caption if prompt_embeds is None and negative_prompt_embeds is None: - prompt_embeds, _, negative_prompt_embeds, _ = self.encode_prompt( + _, prompt_embeds_pooled, _, negative_prompt_embeds_pooled = self.encode_prompt( prompt=prompt, device=device, batch_size=batch_size, @@ -386,10 +395,16 @@ def __call__( do_classifier_free_guidance=self.do_classifier_free_guidance, negative_prompt=negative_prompt, prompt_embeds=prompt_embeds, + prompt_embeds_pooled=prompt_embeds_pooled, negative_prompt_embeds=negative_prompt_embeds, + negative_prompt_embeds_pooled=negative_prompt_embeds_pooled, ) + + # The pooled embeds from the prior are pooled again before being passed to the decoder prompt_embeds_pooled = ( - torch.cat([prompt_embeds, negative_prompt_embeds]) if self.do_classifier_free_guidance else prompt_embeds + torch.cat([prompt_embeds_pooled, negative_prompt_embeds_pooled]) + if self.do_classifier_free_guidance + else prompt_embeds_pooled ) effnet = ( torch.cat([image_embeddings, torch.zeros_like(image_embeddings)]) diff --git a/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py b/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py index cd3592b49ac0..ca7cc72407d6 100644 --- a/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +++ b/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py @@ -155,14 +155,14 @@ def __call__( height: int = 512, width: int = 512, prior_num_inference_steps: int = 60, - prior_timesteps: Optional[List[float]] = None, prior_guidance_scale: float = 4.0, num_inference_steps: int = 12, - decoder_timesteps: Optional[List[float]] = None, decoder_guidance_scale: float = 0.0, negative_prompt: Optional[Union[str, List[str]]] = None, prompt_embeds: 
Optional[torch.FloatTensor] = None, + prompt_embeds_pooled: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds_pooled: Optional[torch.FloatTensor] = None, num_images_per_prompt: int = 1, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.FloatTensor] = None, @@ -187,10 +187,17 @@ def __call__( prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings for the prior. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. + prompt_embeds_pooled (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings for the prior. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings for the prior. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. + negative_prompt_embeds_pooled (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings for the prior. Can be used to easily tweak text inputs, *e.g.* + prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` + input argument. num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. height (`int`, *optional*, defaults to 512): @@ -253,7 +260,6 @@ def __call__( [`~pipelines.ImagePipelineOutput`] or `tuple` [`~pipelines.ImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images. 
""" - prior_outputs = self.prior_pipe( prompt=prompt if prompt_embeds is None else None, images=images, @@ -263,7 +269,9 @@ def __call__( guidance_scale=prior_guidance_scale, negative_prompt=negative_prompt if negative_prompt_embeds is None else None, prompt_embeds=prompt_embeds, + prompt_embeds_pooled=prompt_embeds_pooled, negative_prompt_embeds=negative_prompt_embeds, + negative_prompt_embeds_pooled=negative_prompt_embeds_pooled, num_images_per_prompt=num_images_per_prompt, generator=generator, latents=latents, @@ -274,7 +282,9 @@ def __call__( ) image_embeddings = prior_outputs.image_embeddings prompt_embeds = prior_outputs.get("prompt_embeds", None) + prompt_embeds_pooled = prior_outputs.get("prompt_embeds_pooled", None) negative_prompt_embeds = prior_outputs.get("negative_prompt_embeds", None) + negative_prompt_embeds_pooled = prior_outputs.get("negative_prompt_embeds_pooled", None) outputs = self.decoder_pipe( image_embeddings=image_embeddings, @@ -283,7 +293,9 @@ def __call__( guidance_scale=decoder_guidance_scale, negative_prompt=negative_prompt if negative_prompt_embeds is None else None, prompt_embeds=prompt_embeds, + prompt_embeds_pooled=prompt_embeds_pooled, negative_prompt_embeds=negative_prompt_embeds, + negative_prompt_embeds_pooled=negative_prompt_embeds_pooled, generator=generator, output_type=output_type, return_dict=return_dict, diff --git a/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py b/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py index 8df0d2a991b5..24ccc4b882e9 100644 --- a/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +++ b/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py @@ -64,7 +64,9 @@ class StableCascadePriorPipelineOutput(BaseOutput): image_embeddings: Union[torch.FloatTensor, np.ndarray] prompt_embeds: Union[torch.FloatTensor, np.ndarray] + prompt_embeds_pooled: Union[torch.FloatTensor, np.ndarray] negative_prompt_embeds: 
Union[torch.FloatTensor, np.ndarray] + negative_prompt_embeds_pooled: Union[torch.FloatTensor, np.ndarray] class StableCascadePriorPipeline(DiffusionPipeline): @@ -305,6 +307,16 @@ def check_inputs( f" {negative_prompt_embeds.shape}." ) + if prompt_embeds is not None and prompt_embeds_pooled is None: + raise ValueError( + "If `prompt_embeds` are provided, `prompt_embeds_pooled` must also be provided. Make sure to generate `prompt_embeds_pooled` from the same text encoder that was used to generate `prompt_embeds`" + ) + + if negative_prompt_embeds is not None and negative_prompt_embeds_pooled is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_prompt_embeds_pooled` must also be provided. Make sure to generate `negative_prompt_embeds_pooled` from the same text encoder that was used to generate `negative_prompt_embeds`" + ) + if prompt_embeds_pooled is not None and negative_prompt_embeds_pooled is not None: if prompt_embeds_pooled.shape != negative_prompt_embeds_pooled.shape: raise ValueError( @@ -339,7 +351,7 @@ def do_classifier_free_guidance(self): def num_timesteps(self): return self._num_timesteps - def get_t_condioning(self, t, alphas_cumprod): + def get_timestep_ratio_conditioning(self, t, alphas_cumprod): s = torch.tensor([0.003]) clamp_range = [0, 1] min_var = torch.cos(s / (1 + s) * torch.pi * 0.5) ** 2 @@ -558,7 +570,7 @@ def __call__( for i, t in enumerate(self.progress_bar(timesteps)): if not isinstance(self.scheduler, DDPMWuerstchenScheduler): if len(alphas_cumprod) > 0: - timestep_ratio = self.get_t_condioning(t.long().cpu(), alphas_cumprod) + timestep_ratio = self.get_timestep_ratio_conditioning(t.long().cpu(), alphas_cumprod) timestep_ratio = timestep_ratio.expand(latents.size(0)).to(dtype).to(device) else: timestep_ratio = t.float().div(self.scheduler.timesteps[-1]).expand(latents.size(0)).to(dtype) @@ -609,6 +621,18 @@ def __call__( ) # float() as bfloat16-> numpy doesnt work if not return_dict: - return (latents, prompt_embeds, 
negative_prompt_embeds) + return ( + latents, + prompt_embeds, + prompt_embeds_pooled, + negative_prompt_embeds, + negative_prompt_embeds_pooled, + ) - return StableCascadePriorPipelineOutput(latents, prompt_embeds, negative_prompt_embeds) + return StableCascadePriorPipelineOutput( + image_embeddings=latents, + prompt_embeds=prompt_embeds, + prompt_embeds_pooled=prompt_embeds_pooled, + negative_prompt_embeds=negative_prompt_embeds, + negative_prompt_embeds_pooled=negative_prompt_embeds_pooled, + ) diff --git a/tests/pipelines/stable_cascade/test_stable_cascade_combined.py b/tests/pipelines/stable_cascade/test_stable_cascade_combined.py index e717c7733b0a..4a9a123fcda1 100644 --- a/tests/pipelines/stable_cascade/test_stable_cascade_combined.py +++ b/tests/pipelines/stable_cascade/test_stable_cascade_combined.py @@ -241,6 +241,39 @@ def test_float16_inference(self): def test_callback_inputs(self): super().test_callback_inputs() - # def test_callback_cfg(self): - # pass - # pass + def test_stable_cascade_combined_prompt_embeds(self): + device = "cpu" + components = self.get_dummy_components() + + pipe = StableCascadeCombinedPipeline(**components) + pipe.set_progress_bar_config(disable=None) + + prompt = "A photograph of a shiba inu, wearing a hat" + ( + prompt_embeds, + prompt_embeds_pooled, + negative_prompt_embeds, + negative_prompt_embeds_pooled, + ) = pipe.prior_pipe.encode_prompt(device, 1, 1, False, prompt=prompt) + generator = torch.Generator(device=device) + + output_prompt = pipe( + prompt=prompt, + num_inference_steps=1, + prior_num_inference_steps=1, + output_type="np", + generator=generator.manual_seed(0), + ) + output_prompt_embeds = pipe( + prompt=None, + prompt_embeds=prompt_embeds, + prompt_embeds_pooled=prompt_embeds_pooled, + negative_prompt_embeds=negative_prompt_embeds, + negative_prompt_embeds_pooled=negative_prompt_embeds_pooled, + num_inference_steps=1, + prior_num_inference_steps=1, + output_type="np", + generator=generator.manual_seed(0), + ) 
+ + assert np.abs(output_prompt.images - output_prompt_embeds.images).max() < 1e-5 diff --git a/tests/pipelines/stable_cascade/test_stable_cascade_decoder.py b/tests/pipelines/stable_cascade/test_stable_cascade_decoder.py index 7656744b49c7..57722045784e 100644 --- a/tests/pipelines/stable_cascade/test_stable_cascade_decoder.py +++ b/tests/pipelines/stable_cascade/test_stable_cascade_decoder.py @@ -207,6 +207,45 @@ def test_attention_slicing_forward_pass(self): def test_float16_inference(self): super().test_float16_inference() + def test_stable_cascade_decoder_prompt_embeds(self): + device = "cpu" + components = self.get_dummy_components() + + pipe = StableCascadeDecoderPipeline(**components) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + image_embeddings = inputs["image_embeddings"] + prompt = "A photograph of a shiba inu, wearing a hat" + ( + prompt_embeds, + prompt_embeds_pooled, + negative_prompt_embeds, + negative_prompt_embeds_pooled, + ) = pipe.encode_prompt(device, 1, 1, False, prompt=prompt) + generator = torch.Generator(device=device) + + decoder_output_prompt = pipe( + image_embeddings=image_embeddings, + prompt=prompt, + num_inference_steps=1, + output_type="np", + generator=generator.manual_seed(0), + ) + decoder_output_prompt_embeds = pipe( + image_embeddings=image_embeddings, + prompt=None, + prompt_embeds=prompt_embeds, + prompt_embeds_pooled=prompt_embeds_pooled, + negative_prompt_embeds=negative_prompt_embeds, + negative_prompt_embeds_pooled=negative_prompt_embeds_pooled, + num_inference_steps=1, + output_type="np", + generator=generator.manual_seed(0), + ) + + assert np.abs(decoder_output_prompt.images - decoder_output_prompt_embeds.images).max() < 1e-5 + @slow @require_torch_gpu diff --git a/tests/pipelines/stable_cascade/test_stable_cascade_prior.py b/tests/pipelines/stable_cascade/test_stable_cascade_prior.py index c0ee8cc75963..54ce5a7a72db 100644 --- 
a/tests/pipelines/stable_cascade/test_stable_cascade_prior.py +++ b/tests/pipelines/stable_cascade/test_stable_cascade_prior.py @@ -273,6 +273,41 @@ def test_inference_with_prior_lora(self): self.assertTrue(image_embed.shape == lora_image_embed.shape) + def test_stable_cascade_prior_prompt_embeds(self): + device = "cpu" + components = self.get_dummy_components() + + pipe = self.pipeline_class(**components) + pipe.set_progress_bar_config(disable=None) + + prompt = "A photograph of a shiba inu, wearing a hat" + ( + prompt_embeds, + prompt_embeds_pooled, + negative_prompt_embeds, + negative_prompt_embeds_pooled, + ) = pipe.encode_prompt(device, 1, 1, False, prompt=prompt) + generator = torch.Generator(device=device) + + output_prompt = pipe( + prompt=prompt, + num_inference_steps=1, + output_type="np", + generator=generator.manual_seed(0), + ) + output_prompt_embeds = pipe( + prompt=None, + prompt_embeds=prompt_embeds, + prompt_embeds_pooled=prompt_embeds_pooled, + negative_prompt_embeds=negative_prompt_embeds, + negative_prompt_embeds_pooled=negative_prompt_embeds_pooled, + num_inference_steps=1, + output_type="np", + generator=generator.manual_seed(0), + ) + + assert np.abs(output_prompt.image_embeddings - output_prompt_embeds.image_embeddings).max() < 1e-5 + @slow @require_torch_gpu