Skip to content

Commit db5194a

Browse files
HenryQUQsayakpaulyiyixuxu
authored
Fix Compatibility Issues in stable_diffusion_xl_reference.py (#6251)
* Fix Compatibility Issues in stable_diffusion_xl_reference.py --------- Co-authored-by: Sayak Paul <spsayakpaul@gmail.com> Co-authored-by: YiYi Xu <yixu310@gmail.com>
1 parent e6c9c25 commit db5194a

File tree

1 file changed

+11
-2
lines changed

1 file changed

+11
-2
lines changed

examples/community/stable_diffusion_xl_reference.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -507,7 +507,7 @@ def hack_CrossAttnDownBlock2D_forward(
507507

508508
return hidden_states, output_states
509509

510-
def hacked_DownBlock2D_forward(self, hidden_states, temb=None, **kwargs):
510+
def hacked_DownBlock2D_forward(self, hidden_states, temb=None, *args, **kwargs):
511511
eps = 1e-6
512512

513513
output_states = ()
@@ -686,8 +686,17 @@ def hacked_UpBlock2D_forward(
686686

687687
# 10. Prepare added time ids & embeddings
688688
add_text_embeds = pooled_prompt_embeds
689+
if self.text_encoder_2 is None:
690+
text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
691+
else:
692+
text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
693+
689694
add_time_ids = self._get_add_time_ids(
690-
original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype
695+
original_size,
696+
crops_coords_top_left,
697+
target_size,
698+
dtype=prompt_embeds.dtype,
699+
text_encoder_projection_dim=text_encoder_projection_dim,
691700
)
692701

693702
if do_classifier_free_guidance:

0 commit comments

Comments
 (0)