Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions src/diffusers/models/attention_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,7 +400,6 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a
residual = hidden_states
hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
batch_size, sequence_length, _ = hidden_states.shape
encoder_hidden_states = encoder_hidden_states.transpose(1, 2)

attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

Expand Down Expand Up @@ -627,7 +626,6 @@ def __init__(self, slice_size):
def __call__(self, attn: "Attention", hidden_states, encoder_hidden_states=None, attention_mask=None):
residual = hidden_states
hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
encoder_hidden_states = encoder_hidden_states.transpose(1, 2)

batch_size, sequence_length, _ = hidden_states.shape

Expand Down
4 changes: 2 additions & 2 deletions src/diffusers/pipelines/unclip/text_proj.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,10 +77,10 @@ def forward(self, *, image_embeddings, prompt_embeds, text_encoder_hidden_states
# extra tokens of context that are concatenated to the sequence of outputs from the GLIDE text encoder"
clip_extra_context_tokens = self.clip_extra_context_tokens_proj(image_embeddings)
clip_extra_context_tokens = clip_extra_context_tokens.reshape(batch_size, -1, self.clip_extra_context_tokens)
clip_extra_context_tokens = clip_extra_context_tokens.permute(0, 2, 1)

text_encoder_hidden_states = self.encoder_hidden_states_proj(text_encoder_hidden_states)
text_encoder_hidden_states = self.text_encoder_hidden_states_norm(text_encoder_hidden_states)
text_encoder_hidden_states = text_encoder_hidden_states.permute(0, 2, 1)
text_encoder_hidden_states = torch.cat([clip_extra_context_tokens, text_encoder_hidden_states], dim=2)
text_encoder_hidden_states = torch.cat([clip_extra_context_tokens, text_encoder_hidden_states], dim=1)

return text_encoder_hidden_states, additive_clip_time_embeddings
1 change: 1 addition & 0 deletions tests/pipelines/unclip/test_unclip_image_variation.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ class UnCLIPImageVariationPipelineFastTests(PipelineTesterMixin, unittest.TestCa
"decoder_num_inference_steps",
"super_res_num_inference_steps",
]
test_xformers_attention = False

@property
def text_embedder_hidden_size(self):
Expand Down