diff --git a/examples/community/stable_diffusion_xl_reference.py b/examples/community/stable_diffusion_xl_reference.py index 4c7efa4b5f7a..7282582baf91 100644 --- a/examples/community/stable_diffusion_xl_reference.py +++ b/examples/community/stable_diffusion_xl_reference.py @@ -507,7 +507,7 @@ def hack_CrossAttnDownBlock2D_forward( return hidden_states, output_states - def hacked_DownBlock2D_forward(self, hidden_states, temb=None, **kwargs): + def hacked_DownBlock2D_forward(self, hidden_states, temb=None, *args, **kwargs): eps = 1e-6 output_states = () @@ -686,8 +686,17 @@ def hacked_UpBlock2D_forward( # 10. Prepare added time ids & embeddings add_text_embeds = pooled_prompt_embeds + if self.text_encoder_2 is None: + text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) + else: + text_encoder_projection_dim = self.text_encoder_2.config.projection_dim + add_time_ids = self._get_add_time_ids( - original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype + original_size, + crops_coords_top_left, + target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, ) if do_classifier_free_guidance: