From 671a91f157d83b099d256447eae55fbcdebdfeae Mon Sep 17 00:00:00 2001 From: njindal Date: Thu, 30 Mar 2023 13:41:10 +0530 Subject: [PATCH 1/3] [2884]: Fix cross_attention_kwargs in StableDiffusionImg2ImgPipeline --- .../pipeline_stable_diffusion_img2img.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index ba124ffecbee..a7929e13a5f9 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -13,7 +13,7 @@ # limitations under the License. import inspect -from typing import Callable, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Union import numpy as np import PIL @@ -577,6 +577,7 @@ def __call__( return_dict: bool = True, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, ): r""" Function invoked when calling the pipeline for generation. @@ -634,6 +635,10 @@ def __call__( callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). Examples: Returns: @@ -695,7 +700,12 @@ def __call__( latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # predict the noise residual - noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds).sample + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + ).sample # perform guidance if do_classifier_free_guidance: From b5f353b65a0c5f86d5f609e96c1535b73e93cee4 Mon Sep 17 00:00:00 2001 From: njindal Date: Thu, 30 Mar 2023 13:45:35 +0530 Subject: [PATCH 2/3] [Build Fix] --- .../alt_diffusion/pipeline_alt_diffusion_img2img.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py index 23f4886f06c1..6f4b516ab10d 100644 --- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py +++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py @@ -569,6 +569,7 @@ def __call__( return_dict: bool = True, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, ): r""" Function invoked when calling the pipeline for generation. @@ -626,6 +627,10 @@ def __call__( callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). Examples: Returns: @@ -687,7 +692,12 @@ def __call__( latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # predict the noise residual - noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds).sample + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + ).sample # perform guidance if do_classifier_free_guidance: From 84a73f914a2986c8e91b9d512eb64ede22ed941e Mon Sep 17 00:00:00 2001 From: njindal Date: Thu, 30 Mar 2023 13:48:50 +0530 Subject: [PATCH 3/3] [Build Fix] --- .../pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py index 6f4b516ab10d..f7c78d7af576 100644 --- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py +++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py @@ -13,7 +13,7 @@ # limitations under the License. import inspect -from typing import Callable, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Union import numpy as np import PIL