From f664e9b7d99b1899bd9689d1c13de5a834a9f48d Mon Sep 17 00:00:00 2001 From: Fabio Rigano Date: Sat, 24 Feb 2024 09:49:36 +0100 Subject: [PATCH 1/2] Fix IPAdapterAttnProcessor --- src/diffusers/models/attention_processor.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 1c008264ba33..dd0599f0fc67 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2186,6 +2186,8 @@ def __call__( for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip( ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks ): + current_ip_hidden_states = torch.squeeze(current_ip_hidden_states, 1) + ip_key = to_k_ip(current_ip_hidden_states) ip_value = to_v_ip(current_ip_hidden_states) From 615c23e3456aebec0831aee5602e34759b1efb63 Mon Sep 17 00:00:00 2001 From: Fabio Rigano Date: Sat, 24 Feb 2024 11:19:23 +0100 Subject: [PATCH 2/2] Fix batch_to_head_dim and revert reshape --- src/diffusers/models/attention_processor.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index dd0599f0fc67..62d764c5edbe 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -559,12 +559,16 @@ def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Ten `torch.Tensor`: The reshaped tensor. """ head_size = self.heads - batch_size, seq_len, dim = tensor.shape - tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) + if tensor.ndim == 3: + batch_size, seq_len, dim = tensor.shape + extra_dim = 1 + else: + batch_size, extra_dim, seq_len, dim = tensor.shape + tensor = tensor.reshape(batch_size, seq_len * extra_dim, head_size, dim // head_size) tensor = tensor.permute(0, 2, 1, 3) if out_dim == 3: - tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size) + tensor = tensor.reshape(batch_size * head_size, seq_len * extra_dim, dim // head_size) return tensor @@ -2186,8 +2190,6 @@ def __call__( for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip( ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks ): - current_ip_hidden_states = torch.squeeze(current_ip_hidden_states, 1) - ip_key = to_k_ip(current_ip_hidden_states) ip_value = to_v_ip(current_ip_hidden_states)