From 063e08569d43bdf7792d26ab8cfd42d25859727b Mon Sep 17 00:00:00 2001 From: Fabio Rigano Date: Sun, 4 Feb 2024 18:31:29 +0100 Subject: [PATCH 01/21] Add attention masking to attn processors --- src/diffusers/models/attention_processor.py | 74 ++++++++++++++++++++- 1 file changed, 73 insertions(+), 1 deletion(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 908946119dc2..415dbbe88321 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -720,6 +720,7 @@ def __call__( attention_mask: Optional[torch.FloatTensor] = None, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0, + masks=None, ) -> torch.Tensor: residual = hidden_states @@ -1196,6 +1197,7 @@ def __call__( attention_mask: Optional[torch.FloatTensor] = None, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0, + masks=None, ) -> torch.FloatTensor: residual = hidden_states if attn.spatial_norm is not None: @@ -2131,6 +2133,7 @@ def __call__( attention_mask=None, temb=None, scale=1.0, + masks=None, ): residual = hidden_states @@ -2197,7 +2200,41 @@ def __call__( ip_attention_probs = attn.get_attention_scores(query, ip_key, None) current_ip_hidden_states = torch.bmm(ip_attention_probs, ip_value) - current_ip_hidden_states = attn.batch_to_head_dim(current_ip_hidden_states) + current_ip_hidden_states = current_ip_hidden_states.to(query.dtype) + + if masks is not None: + if not isinstance(masks, list): + masks = [masks] + seq_len = current_ip_hidden_states.shape[1] + o_h = masks[0].shape[1] + o_w = masks[0].shape[2] + ratio = o_w / o_h + mask_h = int(torch.sqrt(torch.tensor(seq_len / ratio))) + mask_h = int(mask_h) + int((seq_len % int(mask_h)) != 0) + mask_w = seq_len // mask_h + mask_downsample = [] + for mask in masks: + if len(mask.shape) == 2: + mask = mask.unsqueeze(0) + mask_downsample.append( + F.interpolate( + torch.tensor(mask, dtype=torch.float32).unsqueeze(0), size=(mask_h, mask_w), mode="bicubic" + ).squeeze(0) + ) + mask_downsample = torch.cat(mask_downsample, dim=0) + + if mask_downsample.shape[0] < batch_size: + mask_downsample = mask_downsample.repeat(batch_size // len(masks), 1, 1) + if mask_downsample.shape[0] > batch_size: + mask_downsample = mask_downsample[:batch_size, :, :] + + mask_downsample = mask_downsample.view(mask_downsample.shape[0], -1, 1).repeat( + 1, 1, attn.heads * current_ip_hidden_states.shape[-1] + ) + + mask_downsample = mask_downsample.to(query.dtype).to(current_ip_hidden_states.device) + + current_ip_hidden_states = current_ip_hidden_states * mask_downsample hidden_states = hidden_states + scale * current_ip_hidden_states @@ -2268,6 +2305,7 @@ def __call__( attention_mask=None, temb=None, scale=1.0, + masks=None, ): residual = hidden_states @@ -2357,6 +2395,40 @@ def __call__( ) current_ip_hidden_states = current_ip_hidden_states.to(query.dtype) + if masks is not None: + if not isinstance(masks, list): + masks = [masks] + seq_len = current_ip_hidden_states.shape[1] + o_h = masks[0].shape[1] + o_w = masks[0].shape[2] + ratio = o_w / o_h + mask_h = int(torch.sqrt(torch.tensor(seq_len / ratio))) + mask_h = int(mask_h) + int((seq_len % int(mask_h)) != 0) + mask_w = seq_len // mask_h + mask_downsample = [] + for mask in masks: + if len(mask.shape) == 2: + mask = mask.unsqueeze(0) + mask_downsample.append( + F.interpolate( + torch.tensor(mask, dtype=torch.float32).unsqueeze(0), size=(mask_h, mask_w), mode="bicubic" + ).squeeze(0) + ) + mask_downsample = 
torch.cat(mask_downsample, dim=0) + + if mask_downsample.shape[0] < batch_size: + mask_downsample = mask_downsample.repeat(batch_size // len(masks), 1, 1) + if mask_downsample.shape[0] > batch_size: + mask_downsample = mask_downsample[:batch_size, :, :] + + mask_downsample = mask_downsample.view(mask_downsample.shape[0], -1, 1).repeat( + 1, 1, attn.heads * head_dim + ) + + mask_downsample = mask_downsample.to(query.dtype).to(current_ip_hidden_states.device) + + current_ip_hidden_states = current_ip_hidden_states * mask_downsample + hidden_states = hidden_states + scale * current_ip_hidden_states # linear proj From d753eec6b16a49970c2cecb15d0877ee5f2747d9 Mon Sep 17 00:00:00 2001 From: Fabio Rigano Date: Mon, 5 Feb 2024 12:06:02 +0100 Subject: [PATCH 02/21] Move latent image masking --- src/diffusers/models/attention_processor.py | 161 ++++++++++++-------- 1 file changed, 100 insertions(+), 61 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 415dbbe88321..7fb65af03ca2 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2188,23 +2188,27 @@ def __call__( hidden_states = torch.bmm(attention_probs, value) hidden_states = attn.batch_to_head_dim(hidden_states) - # for ip-adapter - for current_ip_hidden_states, scale, to_k_ip, to_v_ip in zip( - ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip - ): - ip_key = to_k_ip(current_ip_hidden_states) - ip_value = to_v_ip(current_ip_hidden_states) + if masks is not None: + if not isinstance(masks, list): + masks = [masks] + if len(masks) != len(ip_hidden_states): + raise ValueError( + f"Number of masks ({len(masks)}) must match number of IP-Adapters ({len(self.scale)})" + ) + # for ip-adapter + for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip( + ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, masks + ): + ip_key = to_k_ip(current_ip_hidden_states) + ip_value = to_v_ip(current_ip_hidden_states) - ip_key = attn.head_to_batch_dim(ip_key) - ip_value = attn.head_to_batch_dim(ip_value) + ip_key = attn.head_to_batch_dim(ip_key) + ip_value = attn.head_to_batch_dim(ip_value) - ip_attention_probs = attn.get_attention_scores(query, ip_key, None) - current_ip_hidden_states = torch.bmm(ip_attention_probs, ip_value) - current_ip_hidden_states = current_ip_hidden_states.to(query.dtype) + ip_attention_probs = attn.get_attention_scores(query, ip_key, None) + current_ip_hidden_states = torch.bmm(ip_attention_probs, ip_value) + current_ip_hidden_states = current_ip_hidden_states.to(query.dtype) - if masks is not None: - if not isinstance(masks, list): - masks = [masks] seq_len = current_ip_hidden_states.shape[1] o_h = masks[0].shape[1] o_w = masks[0].shape[2] @@ -2212,31 +2216,43 @@ def __call__( mask_h = int(torch.sqrt(torch.tensor(seq_len / ratio))) mask_h = int(mask_h) + int((seq_len % int(mask_h)) != 0) mask_w = seq_len // mask_h - mask_downsample = [] - for mask in masks: - if len(mask.shape) == 2: - mask = mask.unsqueeze(0) - mask_downsample.append( - F.interpolate( - torch.tensor(mask, dtype=torch.float32).unsqueeze(0), size=(mask_h, mask_w), mode="bicubic" - ).squeeze(0) - ) - mask_downsample = torch.cat(mask_downsample, dim=0) + + if len(mask.shape) == 2: + mask = mask.unsqueeze(0) + mask_downsample = F.interpolate( + torch.tensor(mask, dtype=torch.float32).unsqueeze(0), size=(mask_h, mask_w), mode="bicubic" + ).squeeze(0) if mask_downsample.shape[0] < batch_size: - mask_downsample = 
mask_downsample.repeat(batch_size // len(masks), 1, 1) + mask_downsample = mask_downsample.repeat(batch_size, 1, 1) if mask_downsample.shape[0] > batch_size: mask_downsample = mask_downsample[:batch_size, :, :] mask_downsample = mask_downsample.view(mask_downsample.shape[0], -1, 1).repeat( - 1, 1, attn.heads * current_ip_hidden_states.shape[-1] + 1, 1, current_ip_hidden_states.shape[-1] ) mask_downsample = mask_downsample.to(query.dtype).to(current_ip_hidden_states.device) current_ip_hidden_states = current_ip_hidden_states * mask_downsample - hidden_states = hidden_states + scale * current_ip_hidden_states + hidden_states = hidden_states + scale * current_ip_hidden_states + else: + # for ip-adapter + for current_ip_hidden_states, scale, to_k_ip, to_v_ip in zip( + ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip + ): + ip_key = to_k_ip(current_ip_hidden_states) + ip_value = to_v_ip(current_ip_hidden_states) + + ip_key = attn.head_to_batch_dim(ip_key) + ip_value = attn.head_to_batch_dim(ip_value) + + ip_attention_probs = attn.get_attention_scores(query, ip_key, None) + current_ip_hidden_states = torch.bmm(ip_attention_probs, ip_value) + current_ip_hidden_states = current_ip_hidden_states.to(query.dtype) + + hidden_states = hidden_states + scale * current_ip_hidden_states # linear proj hidden_states = attn.to_out[0](hidden_states) @@ -2374,30 +2390,34 @@ def __call__( hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) hidden_states = hidden_states.to(query.dtype) - # for ip-adapter - for current_ip_hidden_states, scale, to_k_ip, to_v_ip in zip( - ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip - ): - ip_key = to_k_ip(current_ip_hidden_states) - ip_value = to_v_ip(current_ip_hidden_states) - - ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - # the output of sdp = (batch, num_heads, seq_len, head_dim) - # TODO: add support for attn.scale when we move to Torch 2.1 - current_ip_hidden_states = F.scaled_dot_product_attention( - query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False - ) + if masks is not None: + if not isinstance(masks, list): + masks = [masks] + if len(masks) != len(ip_hidden_states): + raise ValueError( + f"Number of masks ({len(masks)}) must match number of IP-Adapters ({len(self.scale)})" + ) + # for ip-adapter + for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip( + ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, masks + ): + ip_key = to_k_ip(current_ip_hidden_states) + ip_value = to_v_ip(current_ip_hidden_states) + + ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + current_ip_hidden_states = F.scaled_dot_product_attention( + query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False + ) - current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape( - batch_size, -1, attn.heads * head_dim - ) - current_ip_hidden_states = current_ip_hidden_states.to(query.dtype) + current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape( + batch_size, -1, attn.heads * head_dim + ) + current_ip_hidden_states = current_ip_hidden_states.to(query.dtype) - if masks is not None: - if not isinstance(masks, list): 
- masks = [masks] seq_len = current_ip_hidden_states.shape[1] o_h = masks[0].shape[1] o_w = masks[0].shape[2] @@ -2405,19 +2425,15 @@ def __call__( mask_h = int(torch.sqrt(torch.tensor(seq_len / ratio))) mask_h = int(mask_h) + int((seq_len % int(mask_h)) != 0) mask_w = seq_len // mask_h - mask_downsample = [] - for mask in masks: - if len(mask.shape) == 2: - mask = mask.unsqueeze(0) - mask_downsample.append( - F.interpolate( - torch.tensor(mask, dtype=torch.float32).unsqueeze(0), size=(mask_h, mask_w), mode="bicubic" - ).squeeze(0) - ) - mask_downsample = torch.cat(mask_downsample, dim=0) + + if len(mask.shape) == 2: + mask = mask.unsqueeze(0) + mask_downsample = F.interpolate( + torch.tensor(mask, dtype=torch.float32).unsqueeze(0), size=(mask_h, mask_w), mode="bicubic" + ).squeeze(0) if mask_downsample.shape[0] < batch_size: - mask_downsample = mask_downsample.repeat(batch_size // len(masks), 1, 1) + mask_downsample = mask_downsample.repeat(batch_size, 1, 1) if mask_downsample.shape[0] > batch_size: mask_downsample = mask_downsample[:batch_size, :, :] @@ -2429,7 +2445,30 @@ def __call__( current_ip_hidden_states = current_ip_hidden_states * mask_downsample - hidden_states = hidden_states + scale * current_ip_hidden_states + hidden_states = hidden_states + scale * current_ip_hidden_states + else: + # for ip-adapter + for current_ip_hidden_states, scale, to_k_ip, to_v_ip in zip( + ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip + ): + ip_key = to_k_ip(current_ip_hidden_states) + ip_value = to_v_ip(current_ip_hidden_states) + + ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + current_ip_hidden_states = F.scaled_dot_product_attention( + query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False + ) + + current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape( + batch_size, -1, attn.heads * head_dim + ) + current_ip_hidden_states = current_ip_hidden_states.to(query.dtype) + + hidden_states = hidden_states + scale * current_ip_hidden_states # linear proj hidden_states = attn.to_out[0](hidden_states) From 60336baa26bbeabef7a38ae34047d3dfa8412282 Mon Sep 17 00:00:00 2001 From: Fabio Rigano Date: Mon, 5 Feb 2024 18:31:27 +0100 Subject: [PATCH 03/21] Remove redundant code --- src/diffusers/models/attention_processor.py | 111 +++++++------------- 1 file changed, 40 insertions(+), 71 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 7fb65af03ca2..0b4ac2a40e4d 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2195,20 +2195,24 @@ def __call__( raise ValueError( f"Number of masks ({len(masks)}) must match number of IP-Adapters ({len(self.scale)})" ) - # for ip-adapter - for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip( - ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, masks - ): - ip_key = to_k_ip(current_ip_hidden_states) - ip_value = to_v_ip(current_ip_hidden_states) + else: + masks = [None] * len(ip_hidden_states) + + # for ip-adapter + for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip( + ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, masks + ): + ip_key = to_k_ip(current_ip_hidden_states) + ip_value = to_v_ip(current_ip_hidden_states) - 
ip_key = attn.head_to_batch_dim(ip_key) - ip_value = attn.head_to_batch_dim(ip_value) + ip_key = attn.head_to_batch_dim(ip_key) + ip_value = attn.head_to_batch_dim(ip_value) - ip_attention_probs = attn.get_attention_scores(query, ip_key, None) - current_ip_hidden_states = torch.bmm(ip_attention_probs, ip_value) - current_ip_hidden_states = current_ip_hidden_states.to(query.dtype) + ip_attention_probs = attn.get_attention_scores(query, ip_key, None) + current_ip_hidden_states = torch.bmm(ip_attention_probs, ip_value) + current_ip_hidden_states = current_ip_hidden_states.to(query.dtype) + if mask is not None: seq_len = current_ip_hidden_states.shape[1] o_h = masks[0].shape[1] o_w = masks[0].shape[2] @@ -2236,23 +2240,7 @@ def __call__( current_ip_hidden_states = current_ip_hidden_states * mask_downsample - hidden_states = hidden_states + scale * current_ip_hidden_states - else: - # for ip-adapter - for current_ip_hidden_states, scale, to_k_ip, to_v_ip in zip( - ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip - ): - ip_key = to_k_ip(current_ip_hidden_states) - ip_value = to_v_ip(current_ip_hidden_states) - - ip_key = attn.head_to_batch_dim(ip_key) - ip_value = attn.head_to_batch_dim(ip_value) - - ip_attention_probs = attn.get_attention_scores(query, ip_key, None) - current_ip_hidden_states = torch.bmm(ip_attention_probs, ip_value) - current_ip_hidden_states = current_ip_hidden_states.to(query.dtype) - - hidden_states = hidden_states + scale * current_ip_hidden_states + hidden_states = hidden_states + scale * current_ip_hidden_states # linear proj hidden_states = attn.to_out[0](hidden_states) @@ -2397,27 +2385,31 @@ def __call__( raise ValueError( f"Number of masks ({len(masks)}) must match number of IP-Adapters ({len(self.scale)})" ) - # for ip-adapter - for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip( - ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, masks - ): - ip_key = to_k_ip(current_ip_hidden_states) - ip_value = to_v_ip(current_ip_hidden_states) - - ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - # the output of sdp = (batch, num_heads, seq_len, head_dim) - # TODO: add support for attn.scale when we move to Torch 2.1 - current_ip_hidden_states = F.scaled_dot_product_attention( - query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False - ) + else: + masks = [None] * len(ip_hidden_states) - current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape( - batch_size, -1, attn.heads * head_dim - ) - current_ip_hidden_states = current_ip_hidden_states.to(query.dtype) + # for ip-adapter + for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip( + ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, masks + ): + ip_key = to_k_ip(current_ip_hidden_states) + ip_value = to_v_ip(current_ip_hidden_states) + ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + current_ip_hidden_states = F.scaled_dot_product_attention( + query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False + ) + + current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape( + batch_size, -1, attn.heads * head_dim + ) + current_ip_hidden_states = 
current_ip_hidden_states.to(query.dtype) + + if mask is not None: seq_len = current_ip_hidden_states.shape[1] o_h = masks[0].shape[1] o_w = masks[0].shape[2] @@ -2445,30 +2437,7 @@ def __call__( current_ip_hidden_states = current_ip_hidden_states * mask_downsample - hidden_states = hidden_states + scale * current_ip_hidden_states - else: - # for ip-adapter - for current_ip_hidden_states, scale, to_k_ip, to_v_ip in zip( - ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip - ): - ip_key = to_k_ip(current_ip_hidden_states) - ip_value = to_v_ip(current_ip_hidden_states) - - ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - # the output of sdp = (batch, num_heads, seq_len, head_dim) - # TODO: add support for attn.scale when we move to Torch 2.1 - current_ip_hidden_states = F.scaled_dot_product_attention( - query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False - ) - - current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape( - batch_size, -1, attn.heads * head_dim - ) - current_ip_hidden_states = current_ip_hidden_states.to(query.dtype) - - hidden_states = hidden_states + scale * current_ip_hidden_states + hidden_states = hidden_states + scale * current_ip_hidden_states # linear proj hidden_states = attn.to_out[0](hidden_states) From f6451d3e3ec0ccf6b918a83abf74b877f16fb26e Mon Sep 17 00:00:00 2001 From: Fabio Rigano Date: Mon, 5 Feb 2024 18:42:10 +0100 Subject: [PATCH 04/21] Fix removed line --- src/diffusers/models/attention_processor.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 0b4ac2a40e4d..5551b8933ea2 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2210,6 +2210,7 @@ def __call__( ip_attention_probs = attn.get_attention_scores(query, ip_key, None) current_ip_hidden_states = torch.bmm(ip_attention_probs, ip_value) + current_ip_hidden_states = attn.batch_to_head_dim(current_ip_hidden_states) current_ip_hidden_states = current_ip_hidden_states.to(query.dtype) if mask is not None: From bf4eb1dda8f18b102c1bce3f5ac80aa63d316c2e Mon Sep 17 00:00:00 2001 From: Fabio Rigano Date: Wed, 7 Feb 2024 09:16:02 +0100 Subject: [PATCH 05/21] Add padding --- src/diffusers/models/attention_processor.py | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 5551b8933ea2..6e84c4e478a2 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2233,7 +2233,14 @@ def __call__( if mask_downsample.shape[0] > batch_size: mask_downsample = mask_downsample[:batch_size, :, :] - mask_downsample = mask_downsample.view(mask_downsample.shape[0], -1, 1).repeat( + mask_downsample = mask_downsample.view(mask_downsample.shape[0], -1) + + if mask_h * mask_w < seq_len: + mask_downsample = F.pad(mask_downsample, (0, seq_len-mask_downsample.shape[1]), value=0.0) + if mask_h * mask_w > seq_len: + mask_downsample = mask_downsample[:, :seq_len] + + mask_downsample = mask_downsample.view(mask_downsample.shape[0], mask_downsample.shape[1], 1).repeat( 1, 1, current_ip_hidden_states.shape[-1] ) @@ -2430,8 +2437,15 @@ def __call__( if mask_downsample.shape[0] > batch_size: mask_downsample = mask_downsample[:batch_size, :, :] - mask_downsample = 
mask_downsample.view(mask_downsample.shape[0], -1, 1).repeat( - 1, 1, attn.heads * head_dim + mask_downsample = mask_downsample.view(mask_downsample.shape[0], -1) + + if mask_h * mask_w < seq_len: + mask_downsample = F.pad(mask_downsample, (0, seq_len-mask_downsample.shape[1]), value=0.0) + if mask_h * mask_w > seq_len: + mask_downsample = mask_downsample[:, :seq_len] + + mask_downsample = mask_downsample.view(mask_downsample.shape[0], mask_downsample.shape[1], 1).repeat( + 1, 1, current_ip_hidden_states.shape[-1] ) mask_downsample = mask_downsample.to(query.dtype).to(current_ip_hidden_states.device) From 37419f13d70facac58194b5ec99217e64fd9dbc8 Mon Sep 17 00:00:00 2001 From: Fabio Rigano Date: Wed, 7 Feb 2024 21:01:31 +0100 Subject: [PATCH 06/21] Apply suggestions from code review Co-authored-by: YiYi Xu --- src/diffusers/models/attention_processor.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 6e84c4e478a2..0324cee085ca 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -720,7 +720,7 @@ def __call__( attention_mask: Optional[torch.FloatTensor] = None, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0, - masks=None, + ip_adapter_masks=None, ) -> torch.Tensor: residual = hidden_states @@ -1197,7 +1197,7 @@ def __call__( attention_mask: Optional[torch.FloatTensor] = None, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0, - masks=None, + ip_adapter_masks=None, ) -> torch.FloatTensor: residual = hidden_states if attn.spatial_norm is not None: @@ -2133,7 +2133,7 @@ def __call__( attention_mask=None, temb=None, scale=1.0, - masks=None, + ip_adapter_masks=None, ): residual = hidden_states @@ -2189,8 +2189,8 @@ def __call__( hidden_states = attn.batch_to_head_dim(hidden_states) if masks is not None: - if not isinstance(masks, list): - masks = [masks] + if not isinstance(masks,np.ndarray) or mask.ndim != 4: + raise ValueError(" ip_adapter_mask should be a numpy array with shape num_ip_adapter, 1, height, width. Please use `IPAdapterMaskProcessor` to preprocess your mask") if len(masks) != len(ip_hidden_states): raise ValueError( f"Number of masks ({len(masks)}) must match number of IP-Adapters ({len(self.scale)})" From a180e258f3b43f17c2c199faeede0a3f91895a39 Mon Sep 17 00:00:00 2001 From: Fabio Rigano Date: Wed, 7 Feb 2024 22:13:38 +0100 Subject: [PATCH 07/21] Add IPAdapterMaskProcessing - Move downsampling code to downsample method - Add process method that internally calls preprocess --- src/diffusers/image_processor.py | 51 +++++++++++ src/diffusers/models/attention_processor.py | 97 +++++---------------- 2 files changed, 74 insertions(+), 74 deletions(-) diff --git a/src/diffusers/image_processor.py b/src/diffusers/image_processor.py index 843052c1adf3..21cd176fd997 100644 --- a/src/diffusers/image_processor.py +++ b/src/diffusers/image_processor.py @@ -18,6 +18,7 @@ import numpy as np import PIL.Image import torch +import torch.nn.functional as F from PIL import Image, ImageFilter, ImageOps from .configuration_utils import ConfigMixin, register_to_config @@ -882,3 +883,53 @@ def preprocess( depth = self.binarize(depth) return rgb, depth + + +class IPAdapterMaskProcessor(VaeImageProcessor): + """ + Image processor for IP Adapter image masks. 
+ + """ + def __init__(self): + super().__init__(do_normalize=False, do_binarize=True, do_convert_grayscale=True) + + def process(self, images: List[PIL.Image.Image]) -> np.ndarray: + """ + Convert a list of PIL.Image.Image images to a np.ndarray + """ + images = self.preprocess(images) + return images + + @staticmethod + def downsample(mask: np.ndarray, batch_size: int, seq_length: int, value_embed_dim: int): + """ + Downsample a mask to target seq_length + """ + o_h = mask.shape[1] + o_w = mask.shape[2] + ratio = o_w / o_h + mask_h = int(torch.sqrt(torch.tensor(seq_length / ratio))) + mask_h = int(mask_h) + int((seq_length % int(mask_h)) != 0) + mask_w = seq_length // mask_h + + mask_downsample = F.interpolate( + torch.tensor(mask, dtype=torch.float32).clone().detach().unsqueeze(0), size=(mask_h, mask_w), mode="bicubic" + ).squeeze(0) + + if mask_downsample.shape[0] < batch_size: + mask_downsample = mask_downsample.repeat(batch_size, 1, 1) + if mask_downsample.shape[0] > batch_size: + mask_downsample = mask_downsample[:batch_size, :, :] + + mask_downsample = mask_downsample.view(mask_downsample.shape[0], -1) + + if mask_h * mask_w < seq_length: + mask_downsample = F.pad(mask_downsample, (0, seq_length-mask_downsample.shape[1]), value=0.0) + if mask_h * mask_w > seq_length: + mask_downsample = mask_downsample[:, :seq_length] + + mask_downsample = mask_downsample.view(mask_downsample.shape[0], mask_downsample.shape[1], 1).repeat( + 1, 1, value_embed_dim + ) + + return mask_downsample diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 0324cee085ca..f4c88ba2f74c 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -22,6 +22,7 @@ from ..utils.import_utils import is_xformers_available from ..utils.torch_utils import maybe_allow_in_graph from .lora import LoRACompatibleLinear, LoRALinearLayer +from ..image_processor import IPAdapterMaskProcessor logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -2188,19 +2189,20 @@ def __call__( hidden_states = torch.bmm(attention_probs, value) hidden_states = attn.batch_to_head_dim(hidden_states) - if masks is not None: - if not isinstance(masks,np.ndarray) or mask.ndim != 4: - raise ValueError(" ip_adapter_mask should be a numpy array with shape num_ip_adapter, 1, height, width. Please use `IPAdapterMaskProcessor` to preprocess your mask") - if len(masks) != len(ip_hidden_states): + if ip_adapter_masks is not None: + if not isinstance(ip_adapter_masks, torch.Tensor) or ip_adapter_masks.ndim != 4: + raise ValueError(" ip_adapter_mask should be a tensor with shape [num_ip_adapter, 1, height, width]." 
+ " Please use `IPAdapterMaskProcessor` to preprocess your mask") + if len(ip_adapter_masks) != len(ip_hidden_states): raise ValueError( - f"Number of masks ({len(masks)}) must match number of IP-Adapters ({len(self.scale)})" + f"Number of ip_adapter_masks ({len(ip_adapter_masks)}) must match number of IP-Adapters ({len(self.scale)})" ) else: - masks = [None] * len(ip_hidden_states) + ip_adapter_masks = [None] * len(ip_hidden_states) # for ip-adapter for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip( - ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, masks + ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks ): ip_key = to_k_ip(current_ip_hidden_states) ip_value = to_v_ip(current_ip_hidden_states) @@ -2214,36 +2216,9 @@ def __call__( current_ip_hidden_states = current_ip_hidden_states.to(query.dtype) if mask is not None: - seq_len = current_ip_hidden_states.shape[1] - o_h = masks[0].shape[1] - o_w = masks[0].shape[2] - ratio = o_w / o_h - mask_h = int(torch.sqrt(torch.tensor(seq_len / ratio))) - mask_h = int(mask_h) + int((seq_len % int(mask_h)) != 0) - mask_w = seq_len // mask_h - - if len(mask.shape) == 2: - mask = mask.unsqueeze(0) - mask_downsample = F.interpolate( - torch.tensor(mask, dtype=torch.float32).unsqueeze(0), size=(mask_h, mask_w), mode="bicubic" - ).squeeze(0) - - if mask_downsample.shape[0] < batch_size: - mask_downsample = mask_downsample.repeat(batch_size, 1, 1) - if mask_downsample.shape[0] > batch_size: - mask_downsample = mask_downsample[:batch_size, :, :] - - mask_downsample = mask_downsample.view(mask_downsample.shape[0], -1) - - if mask_h * mask_w < seq_len: - mask_downsample = F.pad(mask_downsample, (0, seq_len-mask_downsample.shape[1]), value=0.0) - if mask_h * mask_w > seq_len: - mask_downsample = mask_downsample[:, :seq_len] - - mask_downsample = mask_downsample.view(mask_downsample.shape[0], mask_downsample.shape[1], 1).repeat( - 1, 1, current_ip_hidden_states.shape[-1] - ) - + mask_downsample = IPAdapterMaskProcessor.downsample(mask, batch_size, current_ip_hidden_states.shape[1], + current_ip_hidden_states.shape[2]) + mask_downsample = mask_downsample.to(query.dtype).to(current_ip_hidden_states.device) current_ip_hidden_states = current_ip_hidden_states * mask_downsample @@ -2317,7 +2292,7 @@ def __call__( attention_mask=None, temb=None, scale=1.0, - masks=None, + ip_adapter_masks=None, ): residual = hidden_states @@ -2386,19 +2361,20 @@ def __call__( hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) hidden_states = hidden_states.to(query.dtype) - if masks is not None: - if not isinstance(masks, list): - masks = [masks] - if len(masks) != len(ip_hidden_states): + if ip_adapter_masks is not None: + if not isinstance(ip_adapter_masks, torch.Tensor) or ip_adapter_masks.ndim != 4: + raise ValueError(" ip_adapter_mask should be a tensor with shape [num_ip_adapter, 1, height, width]." 
+ " Please use `IPAdapterMaskProcessor` to preprocess your mask") + if len(ip_adapter_masks) != len(ip_hidden_states): raise ValueError( - f"Number of masks ({len(masks)}) must match number of IP-Adapters ({len(self.scale)})" + f"Number of ip_adapter_masks ({len(ip_adapter_masks)}) must match number of IP-Adapters ({len(self.scale)})" ) else: - masks = [None] * len(ip_hidden_states) + ip_adapter_masks = [None] * len(ip_hidden_states) # for ip-adapter for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip( - ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, masks + ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks ): ip_key = to_k_ip(current_ip_hidden_states) ip_value = to_v_ip(current_ip_hidden_states) @@ -2418,35 +2394,8 @@ def __call__( current_ip_hidden_states = current_ip_hidden_states.to(query.dtype) if mask is not None: - seq_len = current_ip_hidden_states.shape[1] - o_h = masks[0].shape[1] - o_w = masks[0].shape[2] - ratio = o_w / o_h - mask_h = int(torch.sqrt(torch.tensor(seq_len / ratio))) - mask_h = int(mask_h) + int((seq_len % int(mask_h)) != 0) - mask_w = seq_len // mask_h - - if len(mask.shape) == 2: - mask = mask.unsqueeze(0) - mask_downsample = F.interpolate( - torch.tensor(mask, dtype=torch.float32).unsqueeze(0), size=(mask_h, mask_w), mode="bicubic" - ).squeeze(0) - - if mask_downsample.shape[0] < batch_size: - mask_downsample = mask_downsample.repeat(batch_size, 1, 1) - if mask_downsample.shape[0] > batch_size: - mask_downsample = mask_downsample[:batch_size, :, :] - - mask_downsample = mask_downsample.view(mask_downsample.shape[0], -1) - - if mask_h * mask_w < seq_len: - mask_downsample = F.pad(mask_downsample, (0, seq_len-mask_downsample.shape[1]), value=0.0) - if mask_h * mask_w > seq_len: - mask_downsample = mask_downsample[:, :seq_len] - - mask_downsample = mask_downsample.view(mask_downsample.shape[0], mask_downsample.shape[1], 1).repeat( - 1, 1, current_ip_hidden_states.shape[-1] - ) + mask_downsample = IPAdapterMaskProcessor.downsample(mask, batch_size, current_ip_hidden_states.shape[1], + current_ip_hidden_states.shape[2]) mask_downsample = mask_downsample.to(query.dtype).to(current_ip_hidden_states.device) From c6fddaed9604a6555147687efa5e67fc51a1bafa Mon Sep 17 00:00:00 2001 From: Fabio Rigano Date: Wed, 7 Feb 2024 22:23:29 +0100 Subject: [PATCH 08/21] Fix return types --- src/diffusers/image_processor.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/diffusers/image_processor.py b/src/diffusers/image_processor.py index 21cd176fd997..0fd0f0fab5be 100644 --- a/src/diffusers/image_processor.py +++ b/src/diffusers/image_processor.py @@ -893,15 +893,15 @@ class IPAdapterMaskProcessor(VaeImageProcessor): def __init__(self): super().__init__(do_normalize=False, do_binarize=True, do_convert_grayscale=True) - def process(self, images: List[PIL.Image.Image]) -> np.ndarray: + def process(self, images: List[PIL.Image.Image]) -> torch.FloatTensor: """ - Convert a list of PIL.Image.Image images to a np.ndarray + Convert a list of PIL.Image.Image images to a torch.FloatTensor """ images = self.preprocess(images) return images @staticmethod - def downsample(mask: np.ndarray, batch_size: int, seq_length: int, value_embed_dim: int): + def downsample(mask: torch.FloatTensor, batch_size: int, seq_length: int, value_embed_dim: int): """ Downsample a mask to target seq_length """ @@ -912,9 +912,7 @@ def downsample(mask: np.ndarray, batch_size: int, seq_length: int, value_embed_d mask_h = int(mask_h) + 
int((seq_length % int(mask_h)) != 0) mask_w = seq_length // mask_h - mask_downsample = F.interpolate( - torch.tensor(mask, dtype=torch.float32).clone().detach().unsqueeze(0), size=(mask_h, mask_w), mode="bicubic" - ).squeeze(0) + mask_downsample = F.interpolate(mask.unsqueeze(0), size=(mask_h, mask_w), mode="bicubic").squeeze(0) if mask_downsample.shape[0] < batch_size: mask_downsample = mask_downsample.repeat(batch_size, 1, 1) From 708e0ebba36e47c661324efb3d34463e9af3c7f6 Mon Sep 17 00:00:00 2001 From: Fabio Rigano Date: Thu, 8 Feb 2024 21:42:51 +0100 Subject: [PATCH 09/21] Update image_processor --- src/diffusers/image_processor.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/src/diffusers/image_processor.py b/src/diffusers/image_processor.py index 0fd0f0fab5be..23461f6f7288 100644 --- a/src/diffusers/image_processor.py +++ b/src/diffusers/image_processor.py @@ -893,10 +893,12 @@ class IPAdapterMaskProcessor(VaeImageProcessor): def __init__(self): super().__init__(do_normalize=False, do_binarize=True, do_convert_grayscale=True) - def process(self, images: List[PIL.Image.Image]) -> torch.FloatTensor: + def process(self, images: Union[List[PIL.Image.Image], PIL.Image.Image]) -> torch.FloatTensor: """ - Convert a list of PIL.Image.Image images to a torch.FloatTensor + Convert a PIL.Image.Image or a list of PIL.Image.Image to a torch.FloatTensor """ + if not isinstance(images, list): + images = [images] images = self.preprocess(images) return images @@ -914,18 +916,21 @@ def downsample(mask: torch.FloatTensor, batch_size: int, seq_length: int, value_ mask_downsample = F.interpolate(mask.unsqueeze(0), size=(mask_h, mask_w), mode="bicubic").squeeze(0) + # Repeat mask until batch_size if mask_downsample.shape[0] < batch_size: mask_downsample = mask_downsample.repeat(batch_size, 1, 1) - if mask_downsample.shape[0] > batch_size: - mask_downsample = mask_downsample[:batch_size, :, :] mask_downsample = mask_downsample.view(mask_downsample.shape[0], -1) + # If the output image and the mask do not have the same aspect ratio, tensor shapes will not match + # Pad tensor if downsampled_mask.shape[1] is smaller than seq_length if mask_h * mask_w < seq_length: mask_downsample = F.pad(mask_downsample, (0, seq_length-mask_downsample.shape[1]), value=0.0) + # Discard last embeddings if downsampled_mask.shape[1] is bigger than seq_length if mask_h * mask_w > seq_length: mask_downsample = mask_downsample[:, :seq_length] + # Repeat last dimension to match SDPA output shape mask_downsample = mask_downsample.view(mask_downsample.shape[0], mask_downsample.shape[1], 1).repeat( 1, 1, value_embed_dim ) From bbfeb676124debc32ae706274d0429618bcbebcf Mon Sep 17 00:00:00 2001 From: Fabio Rigano Date: Thu, 8 Feb 2024 22:30:55 +0100 Subject: [PATCH 10/21] Add test --- .../test_ip_adapter_stable_diffusion.py | 34 ++++++++++++++++++- 1 file changed, 33 insertions(+), 1 deletion(-) diff --git a/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py b/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py index 710dea3c2da7..04340b6aa7a8 100644 --- a/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py +++ b/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py @@ -39,6 +39,7 @@ slow, torch_device, ) +from diffusers.image_processor import IPAdapterMaskProcessor enable_full_determinism() @@ -62,7 +63,7 @@ def get_image_processor(self, repo_id): image_processor = CLIPImageProcessor.from_pretrained(repo_id) return image_processor - def 
get_dummy_inputs(self, for_image_to_image=False, for_inpainting=False, for_sdxl=False): + def get_dummy_inputs(self, for_image_to_image=False, for_inpainting=False, for_sdxl=False, for_masks=False): image = load_image( "https://user-images.githubusercontent.com/24734142/266492875-2d50d223-8475-44f0-a7c6-08b51cb53572.png" ) @@ -99,6 +100,14 @@ def get_dummy_inputs(self, for_image_to_image=False, for_inpainting=False, for_s input_kwargs.update({"image": image, "mask_image": mask, "ip_adapter_image": ip_image}) + elif for_masks: + face_image1 = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ip_mask_girl1.png") + face_image2 = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ip_mask_girl2.png") + mask1 = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ip_mask_mask1.png") + mask2 = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ip_mask_mask2.png") + input_kwargs.update({"ip_adapter_image": [[face_image1], [face_image2]], + "cross_attention_kwargs": {"ip_adapter_masks": [mask1, mask2]}}) + return input_kwargs @@ -458,3 +467,26 @@ def test_inpainting_sdxl(self): expected_slice = np.array([0.1398, 0.1476, 0.1407, 0.1442, 0.1470, 0.1480, 0.1449, 0.1481, 0.1494]) assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4) + + def test_masks(self): + image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="models/image_encoder") + pipeline = StableDiffusionXLPipeline.from_pretrained( + "stabilityai/stable-diffusion-xl-base-1.0", image_encoder=image_encoder, safety_checker=None, torch_dtype=self.dtype + ) + pipeline.to(torch_device) + pipeline.load_ip_adapter( + "h94/IP-Adapter", subfolder="sdxl_models", weight_name=["ip-adapter-plus-face_sdxl_vit-h.safetensors"] * 2 + ) + pipeline.set_ip_adapter_scale([0.7] * 2) + + inputs = self.get_dummy_inputs(for_masks=True) + masks = inputs["cross_attention_kwargs"]["ip_adapter_masks"] + processor = IPAdapterMaskProcessor() + masks = processor.process(masks) + inputs["cross_attention_kwargs"]["ip_adapter_masks"] = masks + images = pipeline(**inputs).images + image_slice = images[0, :3, :3, -1].flatten() + expected_slice = np.array( + [0.79571414, 0.7987394, 0.80234784, 0.79982674, 0.798162, 0.80397135, 0.8073128, 0.8062345, 0.8074084] + ) + assert np.allclose(image_slice, expected_slice, atol=1e-3) From 2af042686a4f0dbb4a90398b97959336aefbce1d Mon Sep 17 00:00:00 2001 From: Fabio Rigano Date: Fri, 9 Feb 2024 18:51:55 +0100 Subject: [PATCH 11/21] Apply suggestions from code review Co-authored-by: Sayak Paul --- src/diffusers/models/attention_processor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index f7e557d3ed2d..ce73e97eda08 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2219,7 +2219,7 @@ def __call__( mask_downsample = IPAdapterMaskProcessor.downsample(mask, batch_size, current_ip_hidden_states.shape[1], current_ip_hidden_states.shape[2]) - mask_downsample = mask_downsample.to(query.dtype).to(current_ip_hidden_states.device) + mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device) current_ip_hidden_states = current_ip_hidden_states * mask_downsample From 8f6247d35aee8428bf2f9ae218717f97b0fc6ce8 Mon Sep 17 00:00:00 2001 From: Fabio Rigano Date: Fri, 9 Feb 2024 19:25:17 +0100 Subject: [PATCH 12/21] Fix names --- 
src/diffusers/image_processor.py | 59 +++++++++++++-------- src/diffusers/models/attention_processor.py | 37 +++++++------ 2 files changed, 56 insertions(+), 40 deletions(-) diff --git a/src/diffusers/image_processor.py b/src/diffusers/image_processor.py index 2d07de049e0e..d02b95430dbd 100644 --- a/src/diffusers/image_processor.py +++ b/src/diffusers/image_processor.py @@ -15,6 +15,7 @@ import warnings from typing import List, Optional, Tuple, Union +import math import numpy as np import PIL.Image import torch @@ -893,42 +894,58 @@ class IPAdapterMaskProcessor(VaeImageProcessor): def __init__(self): super().__init__(do_normalize=False, do_binarize=True, do_convert_grayscale=True) - def process(self, images: Union[List[PIL.Image.Image], PIL.Image.Image]) -> torch.FloatTensor: - """ - Convert a PIL.Image.Image or a list of PIL.Image.Image to a torch.FloatTensor - """ - if not isinstance(images, list): - images = [images] - images = self.preprocess(images) - return images - @staticmethod - def downsample(mask: torch.FloatTensor, batch_size: int, seq_length: int, value_embed_dim: int): + def downsample(mask: torch.FloatTensor, batch_size: int, num_queries: int, value_embed_dim: int): """ - Downsample a mask to target seq_length + Downsamples the provided mask tensor to match the expected dimensions for scaled dot-product attention. + If the aspect ratio of the mask does not match the aspect ratio of the output image, a warning is issued. + + Args: + mask (`torch.FloatTensor`): + The input mask tensor generated with `IPAdapterMaskProcessor.preprocess()`. + batch_size (`int`): + The batch size. + num_queries (`int`): + The number of queries. + value_embed_dim (`int`): + The dimensionality of the value embeddings. + + Returns: + `torch.FloatTensor`: + The downsampled mask tensor. + """ o_h = mask.shape[1] o_w = mask.shape[2] ratio = o_w / o_h - mask_h = int(torch.sqrt(torch.tensor(seq_length / ratio))) - mask_h = int(mask_h) + int((seq_length % int(mask_h)) != 0) - mask_w = seq_length // mask_h + mask_h = int(math.sqrt(num_queries / ratio)) + mask_h = int(mask_h) + int((num_queries % int(mask_h)) != 0) + mask_w = num_queries // mask_h mask_downsample = F.interpolate(mask.unsqueeze(0), size=(mask_h, mask_w), mode="bicubic").squeeze(0) - # Repeat mask until batch_size + # Repeat batch_size times if mask_downsample.shape[0] < batch_size: mask_downsample = mask_downsample.repeat(batch_size, 1, 1) mask_downsample = mask_downsample.view(mask_downsample.shape[0], -1) + downsampled_area = mask_h * mask_w # If the output image and the mask do not have the same aspect ratio, tensor shapes will not match - # Pad tensor if downsampled_mask.shape[1] is smaller than seq_length - if mask_h * mask_w < seq_length: - mask_downsample = F.pad(mask_downsample, (0, seq_length-mask_downsample.shape[1]), value=0.0) - # Discard last embeddings if downsampled_mask.shape[1] is bigger than seq_length - if mask_h * mask_w > seq_length: - mask_downsample = mask_downsample[:, :seq_length] + # Pad tensor if downsampled_mask.shape[1] is smaller than num_queries + if downsampled_area < num_queries: + warnings.warn( + "The aspect ratio of the mask does not match the aspect ratio of the output image. 
" + "Please update your masks or adjust the output size for optimal performance.", UserWarning, + ) + mask_downsample = F.pad(mask_downsample, (0, num_queries-mask_downsample.shape[1]), value=0.0) + # Discard last embeddings if downsampled_mask.shape[1] is bigger than num_queries + if downsampled_area > num_queries: + warnings.warn( + "The aspect ratio of the mask does not match the aspect ratio of the output image. " + "Please update your masks or adjust the output size for optimal performance.", UserWarning, + ) + mask_downsample = mask_downsample[:, :num_queries] # Repeat last dimension to match SDPA output shape mask_downsample = mask_downsample.view(mask_downsample.shape[0], mask_downsample.shape[1], 1).repeat( diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index ce73e97eda08..0f39637501c9 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -721,7 +721,7 @@ def __call__( attention_mask: Optional[torch.FloatTensor] = None, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0, - ip_adapter_masks=None, + ip_adapter_masks: Optional[torch.FloatTensor] = None, ) -> torch.Tensor: residual = hidden_states @@ -1198,7 +1198,7 @@ def __call__( attention_mask: Optional[torch.FloatTensor] = None, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0, - ip_adapter_masks=None, + ip_adapter_masks: Optional[torch.FloatTensor] = None, ) -> torch.FloatTensor: residual = hidden_states if attn.spatial_norm is not None: @@ -2128,13 +2128,13 @@ def __init__(self, hidden_size, cross_attention_dim=None, num_tokens=(4,), scale def __call__( self, - attn, - hidden_states, - encoder_hidden_states=None, - attention_mask=None, - temb=None, - scale=1.0, - ip_adapter_masks=None, + attn: Attention, + hidden_states: torch.FloatTensor, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + temb: Optional[torch.FloatTensor] = None, + scale: float = 1.0, + ip_adapter_masks: Optional[torch.FloatTensor] = None, ): residual = hidden_states @@ -2193,7 +2193,7 @@ def __call__( if not isinstance(ip_adapter_masks, torch.Tensor) or ip_adapter_masks.ndim != 4: raise ValueError(" ip_adapter_mask should be a tensor with shape [num_ip_adapter, 1, height, width]." 
" Please use `IPAdapterMaskProcessor` to preprocess your mask") - if len(ip_adapter_masks) != len(ip_hidden_states): + if len(ip_adapter_masks) != len(self.scale): raise ValueError( f"Number of ip_adapter_masks ({len(ip_adapter_masks)}) must match number of IP-Adapters ({len(self.scale)})" ) @@ -2213,7 +2213,6 @@ def __call__( ip_attention_probs = attn.get_attention_scores(query, ip_key, None) current_ip_hidden_states = torch.bmm(ip_attention_probs, ip_value) current_ip_hidden_states = attn.batch_to_head_dim(current_ip_hidden_states) - current_ip_hidden_states = current_ip_hidden_states.to(query.dtype) if mask is not None: mask_downsample = IPAdapterMaskProcessor.downsample(mask, batch_size, current_ip_hidden_states.shape[1], @@ -2286,13 +2285,13 @@ def __init__(self, hidden_size, cross_attention_dim=None, num_tokens=(4,), scale def __call__( self, - attn, - hidden_states, - encoder_hidden_states=None, - attention_mask=None, - temb=None, - scale=1.0, - ip_adapter_masks=None, + attn: Attention, + hidden_states: torch.FloatTensor, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + temb: Optional[torch.FloatTensor] = None, + scale: float = 1.0, + ip_adapter_masks: Optional[torch.FloatTensor] = None, ): residual = hidden_states @@ -2365,7 +2364,7 @@ def __call__( if not isinstance(ip_adapter_masks, torch.Tensor) or ip_adapter_masks.ndim != 4: raise ValueError(" ip_adapter_mask should be a tensor with shape [num_ip_adapter, 1, height, width]." " Please use `IPAdapterMaskProcessor` to preprocess your mask") - if len(ip_adapter_masks) != len(ip_hidden_states): + if len(ip_adapter_masks) != len(self.scale): raise ValueError( f"Number of ip_adapter_masks ({len(ip_adapter_masks)}) must match number of IP-Adapters ({len(self.scale)})" ) From 534b9d91d0459eb7009dc3d33ee682bfd6d58ae8 Mon Sep 17 00:00:00 2001 From: Fabio Rigano Date: Fri, 9 Feb 2024 19:30:01 +0100 Subject: [PATCH 13/21] Fix style --- src/diffusers/image_processor.py | 23 +++++++++------- src/diffusers/models/attention_processor.py | 26 ++++++++++++------- .../test_ip_adapter_stable_diffusion.py | 23 +++++++++++----- 3 files changed, 46 insertions(+), 26 deletions(-) diff --git a/src/diffusers/image_processor.py b/src/diffusers/image_processor.py index d02b95430dbd..2caf1e73fc85 100644 --- a/src/diffusers/image_processor.py +++ b/src/diffusers/image_processor.py @@ -12,10 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +import math import warnings from typing import List, Optional, Tuple, Union -import math import numpy as np import PIL.Image import torch @@ -891,6 +891,7 @@ class IPAdapterMaskProcessor(VaeImageProcessor): Image processor for IP Adapter image masks. """ + def __init__(self): super().__init__(do_normalize=False, do_binarize=True, do_convert_grayscale=True) @@ -901,19 +902,19 @@ def downsample(mask: torch.FloatTensor, batch_size: int, num_queries: int, value If the aspect ratio of the mask does not match the aspect ratio of the output image, a warning is issued. Args: - mask (`torch.FloatTensor`): + mask (`torch.FloatTensor`): The input mask tensor generated with `IPAdapterMaskProcessor.preprocess()`. - batch_size (`int`): + batch_size (`int`): The batch size. - num_queries (`int`): + num_queries (`int`): The number of queries. - value_embed_dim (`int`): + value_embed_dim (`int`): The dimensionality of the value embeddings. 
Returns: - `torch.FloatTensor`: + `torch.FloatTensor`: The downsampled mask tensor. - + """ o_h = mask.shape[1] o_w = mask.shape[2] @@ -936,14 +937,16 @@ def downsample(mask: torch.FloatTensor, batch_size: int, num_queries: int, value if downsampled_area < num_queries: warnings.warn( "The aspect ratio of the mask does not match the aspect ratio of the output image. " - "Please update your masks or adjust the output size for optimal performance.", UserWarning, + "Please update your masks or adjust the output size for optimal performance.", + UserWarning, ) - mask_downsample = F.pad(mask_downsample, (0, num_queries-mask_downsample.shape[1]), value=0.0) + mask_downsample = F.pad(mask_downsample, (0, num_queries - mask_downsample.shape[1]), value=0.0) # Discard last embeddings if downsampled_mask.shape[1] is bigger than num_queries if downsampled_area > num_queries: warnings.warn( "The aspect ratio of the mask does not match the aspect ratio of the output image. " - "Please update your masks or adjust the output size for optimal performance.", UserWarning, + "Please update your masks or adjust the output size for optimal performance.", + UserWarning, ) mask_downsample = mask_downsample[:, :num_queries] diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 0f39637501c9..b1b70d191672 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -18,11 +18,11 @@ import torch.nn.functional as F from torch import nn +from ..image_processor import IPAdapterMaskProcessor from ..utils import USE_PEFT_BACKEND, deprecate, logging from ..utils.import_utils import is_xformers_available from ..utils.torch_utils import maybe_allow_in_graph from .lora import LoRACompatibleLinear, LoRALinearLayer -from ..image_processor import IPAdapterMaskProcessor logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -2191,8 +2191,10 @@ def __call__( if ip_adapter_masks is not None: if not isinstance(ip_adapter_masks, torch.Tensor) or ip_adapter_masks.ndim != 4: - raise ValueError(" ip_adapter_mask should be a tensor with shape [num_ip_adapter, 1, height, width]." - " Please use `IPAdapterMaskProcessor` to preprocess your mask") + raise ValueError( + " ip_adapter_mask should be a tensor with shape [num_ip_adapter, 1, height, width]." + " Please use `IPAdapterMaskProcessor` to preprocess your mask" + ) if len(ip_adapter_masks) != len(self.scale): raise ValueError( f"Number of ip_adapter_masks ({len(ip_adapter_masks)}) must match number of IP-Adapters ({len(self.scale)})" @@ -2215,9 +2217,10 @@ def __call__( current_ip_hidden_states = attn.batch_to_head_dim(current_ip_hidden_states) if mask is not None: - mask_downsample = IPAdapterMaskProcessor.downsample(mask, batch_size, current_ip_hidden_states.shape[1], - current_ip_hidden_states.shape[2]) - + mask_downsample = IPAdapterMaskProcessor.downsample( + mask, batch_size, current_ip_hidden_states.shape[1], current_ip_hidden_states.shape[2] + ) + mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device) current_ip_hidden_states = current_ip_hidden_states * mask_downsample @@ -2362,8 +2365,10 @@ def __call__( if ip_adapter_masks is not None: if not isinstance(ip_adapter_masks, torch.Tensor) or ip_adapter_masks.ndim != 4: - raise ValueError(" ip_adapter_mask should be a tensor with shape [num_ip_adapter, 1, height, width]." 
- " Please use `IPAdapterMaskProcessor` to preprocess your mask") + raise ValueError( + " ip_adapter_mask should be a tensor with shape [num_ip_adapter, 1, height, width]." + " Please use `IPAdapterMaskProcessor` to preprocess your mask" + ) if len(ip_adapter_masks) != len(self.scale): raise ValueError( f"Number of ip_adapter_masks ({len(ip_adapter_masks)}) must match number of IP-Adapters ({len(self.scale)})" @@ -2393,8 +2398,9 @@ def __call__( current_ip_hidden_states = current_ip_hidden_states.to(query.dtype) if mask is not None: - mask_downsample = IPAdapterMaskProcessor.downsample(mask, batch_size, current_ip_hidden_states.shape[1], - current_ip_hidden_states.shape[2]) + mask_downsample = IPAdapterMaskProcessor.downsample( + mask, batch_size, current_ip_hidden_states.shape[1], current_ip_hidden_states.shape[2] + ) mask_downsample = mask_downsample.to(query.dtype).to(current_ip_hidden_states.device) diff --git a/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py b/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py index aaaa00a8e64d..cd7b285a0f7b 100644 --- a/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py +++ b/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py @@ -31,6 +31,7 @@ StableDiffusionXLInpaintPipeline, StableDiffusionXLPipeline, ) +from diffusers.image_processor import IPAdapterMaskProcessor from diffusers.models.attention_processor import AttnProcessor, AttnProcessor2_0 from diffusers.utils import load_image from diffusers.utils.testing_utils import ( @@ -40,7 +41,6 @@ slow, torch_device, ) -from diffusers.image_processor import IPAdapterMaskProcessor enable_full_determinism() @@ -102,12 +102,20 @@ def get_dummy_inputs(self, for_image_to_image=False, for_inpainting=False, for_s input_kwargs.update({"image": image, "mask_image": mask, "ip_adapter_image": ip_image}) elif for_masks: - face_image1 = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ip_mask_girl1.png") - face_image2 = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ip_mask_girl2.png") + face_image1 = load_image( + "https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ip_mask_girl1.png" + ) + face_image2 = load_image( + "https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ip_mask_girl2.png" + ) mask1 = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ip_mask_mask1.png") mask2 = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ip_mask_mask2.png") - input_kwargs.update({"ip_adapter_image": [[face_image1], [face_image2]], - "cross_attention_kwargs": {"ip_adapter_masks": [mask1, mask2]}}) + input_kwargs.update( + { + "ip_adapter_image": [[face_image1], [face_image2]], + "cross_attention_kwargs": {"ip_adapter_masks": [mask1, mask2]}, + } + ) return input_kwargs @@ -476,7 +484,10 @@ def test_inpainting_sdxl(self): def test_masks(self): image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="models/image_encoder") pipeline = StableDiffusionXLPipeline.from_pretrained( - "stabilityai/stable-diffusion-xl-base-1.0", image_encoder=image_encoder, safety_checker=None, torch_dtype=self.dtype + "stabilityai/stable-diffusion-xl-base-1.0", + image_encoder=image_encoder, + safety_checker=None, + torch_dtype=self.dtype, ) pipeline.to(torch_device) pipeline.load_ip_adapter( From cb929ff781d34a19fb1d38a69f375393942331b4 Mon Sep 17 00:00:00 2001 From: Fabio Rigano Date: Fri, 9 Feb 2024 20:13:58 +0100 Subject: [PATCH 
14/21] Update src/diffusers/image_processor.py Co-authored-by: YiYi Xu --- src/diffusers/image_processor.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/diffusers/image_processor.py b/src/diffusers/image_processor.py index 2caf1e73fc85..9ded05d33921 100644 --- a/src/diffusers/image_processor.py +++ b/src/diffusers/image_processor.py @@ -893,6 +893,12 @@ class IPAdapterMaskProcessor(VaeImageProcessor): """ def __init__(self): + do_resize: bool = True, + vae_scale_factor: int = 8, + resample: str = "lanczos", + do_normalize: bool = False, + do_binarize: bool = True, + do_convert_grayscale: bool = True) super().__init__(do_normalize=False, do_binarize=True, do_convert_grayscale=True) @staticmethod From 4763c82c700f8dafc164a10fe3bb43e842498ade Mon Sep 17 00:00:00 2001 From: Fabio Rigano Date: Fri, 9 Feb 2024 20:29:49 +0100 Subject: [PATCH 15/21] Fix init + docstring --- src/diffusers/image_processor.py | 32 +++++++++++++++++++++++++++++--- 1 file changed, 29 insertions(+), 3 deletions(-) diff --git a/src/diffusers/image_processor.py b/src/diffusers/image_processor.py index 9ded05d33921..f6ccfda9fcb8 100644 --- a/src/diffusers/image_processor.py +++ b/src/diffusers/image_processor.py @@ -890,16 +890,42 @@ class IPAdapterMaskProcessor(VaeImageProcessor): """ Image processor for IP Adapter image masks. + Args: + do_resize (`bool`, *optional*, defaults to `True`): + Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`. + vae_scale_factor (`int`, *optional*, defaults to `8`): + VAE scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this factor. + resample (`str`, *optional*, defaults to `lanczos`): + Resampling filter to use when resizing the image. + do_normalize (`bool`, *optional*, defaults to `False`): + Whether to normalize the image to [-1,1]. + do_binarize (`bool`, *optional*, defaults to `True`): + Whether to binarize the image to 0/1. + do_convert_grayscale (`bool`, *optional*, defaults to be `True`): + Whether to convert the images to grayscale format. 
+ """ - def __init__(self): + config_name = CONFIG_NAME + + @register_to_config + def __init__( + self, do_resize: bool = True, vae_scale_factor: int = 8, resample: str = "lanczos", do_normalize: bool = False, do_binarize: bool = True, - do_convert_grayscale: bool = True) - super().__init__(do_normalize=False, do_binarize=True, do_convert_grayscale=True) + do_convert_grayscale: bool = True, + ): + super().__init__( + do_resize=do_resize, + vae_scale_factor=vae_scale_factor, + resample=resample, + do_normalize=do_normalize, + do_binarize=do_binarize, + do_convert_grayscale=do_convert_grayscale, + ) @staticmethod def downsample(mask: torch.FloatTensor, batch_size: int, num_queries: int, value_embed_dim: int): From 4115a86273ad68dbc35e2b37605be7b6940432dd Mon Sep 17 00:00:00 2001 From: Fabio Rigano Date: Sat, 10 Feb 2024 13:22:40 +0100 Subject: [PATCH 16/21] Remove unnecessary parameters --- src/diffusers/models/attention_processor.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 1dc7c78cc0b4..7604f7765e87 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -731,7 +731,6 @@ def __call__( attention_mask: Optional[torch.FloatTensor] = None, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0, - ip_adapter_masks: Optional[torch.FloatTensor] = None, ) -> torch.Tensor: residual = hidden_states @@ -1208,7 +1207,6 @@ def __call__( attention_mask: Optional[torch.FloatTensor] = None, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0, - ip_adapter_masks: Optional[torch.FloatTensor] = None, ) -> torch.FloatTensor: residual = hidden_states if attn.spatial_norm is not None: From b1b990088aafcacd9bc514252b2d4d895bf27b8a Mon Sep 17 00:00:00 2001 From: Fabio Rigano Date: Sat, 10 Feb 2024 19:53:35 +0100 Subject: [PATCH 17/21] Update test --- .../ip_adapters/test_ip_adapter_stable_diffusion.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py b/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py index cd7b285a0f7b..ca40d15a9112 100644 --- a/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py +++ b/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py @@ -498,11 +498,13 @@ def test_masks(self): inputs = self.get_dummy_inputs(for_masks=True) masks = inputs["cross_attention_kwargs"]["ip_adapter_masks"] processor = IPAdapterMaskProcessor() - masks = processor.process(masks) + masks = processor.preprocess(masks) inputs["cross_attention_kwargs"]["ip_adapter_masks"] = masks images = pipeline(**inputs).images image_slice = images[0, :3, :3, -1].flatten() expected_slice = np.array( - [0.79571414, 0.7987394, 0.80234784, 0.79982674, 0.798162, 0.80397135, 0.8073128, 0.8062345, 0.8074084] + [0.79474676, 0.7977683, 0.8013954, 0.7988008, 0.7970615, 0.8029355, 0.80614823, 0.8050743, 0.80627424] ) - assert np.allclose(image_slice, expected_slice, atol=1e-3) + + max_diff = numpy_cosine_similarity_distance(image_slice, expected_slice) + assert max_diff < 5e-4 From ec923bd2c3a6ec61b3b99823c31eda69b5b2cefa Mon Sep 17 00:00:00 2001 From: Fabio Rigano Date: Fri, 16 Feb 2024 09:42:30 +0100 Subject: [PATCH 18/21] Add test for one mask + bugfix --- src/diffusers/models/attention_processor.py | 4 +-- .../test_ip_adapter_stable_diffusion.py | 31 +++++++++++++++++-- 2 files changed, 31 insertions(+), 4 deletions(-) diff --git 
a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 7604f7765e87..dad8630cd1ee 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2208,7 +2208,7 @@ def __call__( f"Number of ip_adapter_masks ({len(ip_adapter_masks)}) must match number of IP-Adapters ({len(self.scale)})" ) else: - ip_adapter_masks = [None] * len(ip_hidden_states) + ip_adapter_masks = [None] * len(self.scale) # for ip-adapter for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip( @@ -2382,7 +2382,7 @@ def __call__( f"Number of ip_adapter_masks ({len(ip_adapter_masks)}) must match number of IP-Adapters ({len(self.scale)})" ) else: - ip_adapter_masks = [None] * len(ip_hidden_states) + ip_adapter_masks = [None] * len(self.scale) # for ip-adapter for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip( diff --git a/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py b/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py index cb292aa0c80e..e27ae82fc0e8 100644 --- a/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py +++ b/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py @@ -483,12 +483,39 @@ def test_inpainting_sdxl(self): max_diff = numpy_cosine_similarity_distance(image_slice, expected_slice) assert max_diff < 5e-4 - def test_masks(self): + def test_one_mask(self): + image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="models/image_encoder") + pipeline = StableDiffusionXLPipeline.from_pretrained( + "stabilityai/stable-diffusion-xl-base-1.0", + image_encoder=image_encoder, + torch_dtype=self.dtype, + ) + pipeline.to(torch_device) + pipeline.load_ip_adapter( + "h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter-plus-face_sdxl_vit-h.safetensors" + ) + pipeline.set_ip_adapter_scale(0.7) + + inputs = self.get_dummy_inputs(for_masks=True) + mask = inputs["cross_attention_kwargs"]["ip_adapter_masks"][0] + processor = IPAdapterMaskProcessor() + mask = processor.preprocess(mask) + inputs["cross_attention_kwargs"]["ip_adapter_masks"] = mask + inputs["ip_adapter_image"] = inputs["ip_adapter_image"][0] + images = pipeline(**inputs).images + image_slice = images[0, :3, :3, -1].flatten() + expected_slice = np.array( + [0.7307304, 0.73450166, 0.73731124, 0.7377061, 0.7318013, 0.73720926, 0.74746597, 0.7409929, 0.74074936] + ) + + max_diff = numpy_cosine_similarity_distance(image_slice, expected_slice) + assert max_diff < 5e-4 + + def test_multi_mask(self): image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="models/image_encoder") pipeline = StableDiffusionXLPipeline.from_pretrained( "stabilityai/stable-diffusion-xl-base-1.0", image_encoder=image_encoder, - safety_checker=None, torch_dtype=self.dtype, ) pipeline.to(torch_device) From 4bc2621a694a64a6b313181d2e76f7389e95fcde Mon Sep 17 00:00:00 2001 From: Fabio Rigano Date: Fri, 16 Feb 2024 10:48:39 +0100 Subject: [PATCH 19/21] Add docs --- docs/source/en/using-diffusers/ip_adapter.md | 80 +++++++++++++++++++ .../test_ip_adapter_stable_diffusion.py | 4 +- 2 files changed, 82 insertions(+), 2 deletions(-) diff --git a/docs/source/en/using-diffusers/ip_adapter.md b/docs/source/en/using-diffusers/ip_adapter.md index b37ef15fc6af..7b29e6014c89 100644 --- a/docs/source/en/using-diffusers/ip_adapter.md +++ b/docs/source/en/using-diffusers/ip_adapter.md @@ -465,3 +465,83 @@ image
   
+
+### IP-Adapter masking
+
+Binary masks can be used to specify which portion of the output image should be assigned to an IP-Adapter.
+For each input IP-Adapter image, a binary mask and an IP-Adapter must be provided.
+
+Before passing the masks to the pipeline, it's essential to preprocess them using [`IPAdapterMaskProcessor.preprocess()`].
+
+> [!TIP]
+> For optimal results, provide the output height and width to [`IPAdapterMaskProcessor.preprocess()`]. This ensures that masks with differing aspect ratios are appropriately stretched. If the input masks already match the aspect ratio of the generated image, specifying height and width can be omitted.
+
+Here is an example with two masks:
+
+```py
+from diffusers.image_processor import IPAdapterMaskProcessor
+
+mask1 = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ip_mask_mask1.png")
+mask2 = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ip_mask_mask2.png")
+
+output_height = 1024
+output_width = 1024
+
+processor = IPAdapterMaskProcessor()
+masks = processor.preprocess([mask1, mask2], height=output_height, width=output_width)
+```
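+
+As a quick sanity check (a minimal sketch, assuming the 1024x1024 output used above), the processor stacks the preprocessed masks into a single tensor with the `[num_ip_adapter, 1, height, width]` layout expected by the IP-Adapter attention processors:
+
+```py
+# One entry along the first dimension per IP-Adapter mask.
+print(masks.shape)  # expected: torch.Size([2, 1, 1024, 1024])
+```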
+
+<!-- figure: "mask one" (ip_mask_mask1.png) and "mask two" (ip_mask_mask2.png) -->
+
+If you have more than one IP-Adapter image, load them into a list, ensuring each image is assigned to a different IP-Adapter.
+
+```py
+face_image1 = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ip_mask_girl1.png")
+face_image2 = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ip_mask_girl2.png")
+
+ip_images = [[face_image1], [face_image2]]
+```
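+
+The nesting matters: each inner list is routed to one IP-Adapter, and the number of image groups has to line up with the number of preprocessed masks and loaded IP-Adapters, otherwise the attention processor raises a `ValueError`. A minimal check, reusing the names defined above:
+
+```py
+# Two image groups and two masks, matching the two IP-Adapters loaded below.
+assert len(ip_images) == masks.shape[0] == 2
+```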
+
+<!-- figure: "ip adapter image one" (ip_mask_girl1.png) and "ip adapter image two" (ip_mask_girl2.png) -->
+
+Pass the preprocessed masks to the pipeline using `cross_attention_kwargs` as shown below:
+
+```py
+pipeline.load_ip_adapter(
+    "h94/IP-Adapter", subfolder="sdxl_models", weight_name=["ip-adapter-plus-face_sdxl_vit-h.safetensors"] * 2
+)
+pipeline.set_ip_adapter_scale([0.7] * 2)
+generator = torch.Generator(device="cpu").manual_seed(0)
+num_images = 1
+
+image = pipeline(
+    prompt="2 girls",
+    ip_adapter_image=ip_images,
+    negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality",
+    num_inference_steps=20,
+    num_images_per_prompt=num_images,
+    generator=generator,
+    cross_attention_kwargs={"ip_adapter_masks": masks},
+).images[0]
+```
+
+<!-- figure: output image -->
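+
+The same mechanism also works with a single IP-Adapter. The following is a minimal sketch mirroring the repository test for the single-mask case; the prompt is illustrative and the example assumes a freshly loaded pipeline:
+
+```py
+pipeline.load_ip_adapter(
+    "h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter-plus-face_sdxl_vit-h.safetensors"
+)
+pipeline.set_ip_adapter_scale(0.7)
+
+# One mask, preprocessed on its own, and one group of reference images.
+mask = processor.preprocess(mask1, height=output_height, width=output_width)
+
+image = pipeline(
+    prompt="a girl",
+    ip_adapter_image=[face_image1],
+    negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality",
+    num_inference_steps=20,
+    generator=generator,
+    cross_attention_kwargs={"ip_adapter_masks": mask},
+).images[0]
+```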
diff --git a/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py b/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py index e27ae82fc0e8..6289ee887d13 100644 --- a/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py +++ b/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py @@ -483,7 +483,7 @@ def test_inpainting_sdxl(self): max_diff = numpy_cosine_similarity_distance(image_slice, expected_slice) assert max_diff < 5e-4 - def test_one_mask(self): + def test_ip_adapter_single_mask(self): image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="models/image_encoder") pipeline = StableDiffusionXLPipeline.from_pretrained( "stabilityai/stable-diffusion-xl-base-1.0", @@ -511,7 +511,7 @@ def test_one_mask(self): max_diff = numpy_cosine_similarity_distance(image_slice, expected_slice) assert max_diff < 5e-4 - def test_multi_mask(self): + def test_ip_adapter_multiple_masks(self): image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="models/image_encoder") pipeline = StableDiffusionXLPipeline.from_pretrained( "stabilityai/stable-diffusion-xl-base-1.0", From 3315e81a58e7764bf9ef502ec31580102fb39afc Mon Sep 17 00:00:00 2001 From: Fabio Rigano Date: Fri, 16 Feb 2024 12:56:26 +0100 Subject: [PATCH 20/21] Docs: update link --- docs/source/en/using-diffusers/ip_adapter.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/en/using-diffusers/ip_adapter.md b/docs/source/en/using-diffusers/ip_adapter.md index 7b29e6014c89..865f343df560 100644 --- a/docs/source/en/using-diffusers/ip_adapter.md +++ b/docs/source/en/using-diffusers/ip_adapter.md @@ -542,6 +542,6 @@ image = pipeline( ```
-<!-- output image link -->
+<!-- updated output image link -->
     output image
From c407388e4248e88862e54fa6206f66a19c425190 Mon Sep 17 00:00:00 2001 From: Fabio Rigano Date: Sat, 17 Feb 2024 08:45:31 +0100 Subject: [PATCH 21/21] Update tensor conversion --- src/diffusers/models/attention_processor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index dad8630cd1ee..dccdf5bcc31f 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2410,7 +2410,7 @@ def __call__( mask, batch_size, current_ip_hidden_states.shape[1], current_ip_hidden_states.shape[2] ) - mask_downsample = mask_downsample.to(query.dtype).to(current_ip_hidden_states.device) + mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device) current_ip_hidden_states = current_ip_hidden_states * mask_downsample