diff --git a/docs/source/en/using-diffusers/ip_adapter.md b/docs/source/en/using-diffusers/ip_adapter.md
index 0ad599c819dc..0df1e0e7a0ad 100644
--- a/docs/source/en/using-diffusers/ip_adapter.md
+++ b/docs/source/en/using-diffusers/ip_adapter.md
@@ -468,3 +468,83 @@ image
+
+### IP-Adapter masking
+
+Binary masks specify which portion of the output image should be assigned to an IP-Adapter.
+For each input IP-Adapter image, you must provide a binary mask and an IP-Adapter.
+
+Before passing the masks to the pipeline, it's essential to preprocess them using [`IPAdapterMaskProcessor.preprocess()`].
+
+> [!TIP]
+> For optimal results, provide the output height and width to [`IPAdapterMaskProcessor.preprocess()`]. This ensures that masks with differing aspect ratios are appropriately stretched. If the input masks already match the aspect ratio of the generated image, specifying height and width can be omitted.
+
+Here is an example with two masks:
+
+```py
+from diffusers.image_processor import IPAdapterMaskProcessor
+from diffusers.utils import load_image
+
+mask1 = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ip_mask_mask1.png")
+mask2 = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ip_mask_mask2.png")
+
+output_height = 1024
+output_width = 1024
+
+processor = IPAdapterMaskProcessor()
+masks = processor.preprocess([mask1, mask2], height=output_height, width=output_width)
+```
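+
+The processor stacks the masks into a single 4D tensor with one single-channel mask per IP-Adapter, which is the layout the attention processors expect (`[num_ip_adapter, 1, height, width]`). A quick sanity check on the snippet above:
+
+```py
+print(masks.shape)  # expected: torch.Size([2, 1, 1024, 1024])
+```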
+
+<!-- image: mask one -->
+<!-- image: mask two -->
+
+If you have more than one IP-Adapter image, load them into a list, ensuring each image is assigned to a different IP-Adapter.
+
+```py
+face_image1 = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ip_mask_girl1.png")
+face_image2 = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ip_mask_girl2.png")
+
+ip_images = [[face_image1], [face_image2]]
+```
+
+<!-- image: ip adapter image one -->
+<!-- image: ip adapter image two -->
+
+Pass preprocessed masks to the pipeline using `cross_attention_kwargs` as shown below:
+
+```py
+pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name=["ip-adapter-plus-face_sdxl_vit-h.safetensors"] * 2)
+pipeline.set_ip_adapter_scale([0.7] * 2)
+generator = torch.Generator(device="cpu").manual_seed(0)
+num_images = 1
+
+image = pipeline(
+    prompt="2 girls",
+    ip_adapter_image=ip_images,
+    negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality",
+    num_inference_steps=20,
+    num_images_per_prompt=num_images,
+    generator=generator,
+    cross_attention_kwargs={"ip_adapter_masks": masks},
+).images[0]
+```
+
+<!-- image: output image -->
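+
+Internally, each mask is downsampled by `IPAdapterMaskProcessor.downsample()` so that it lines up with the query tokens of each cross-attention layer; when the mask's aspect ratio does not match the output image, the downsampled mask is zero-padded or truncated and a warning is emitted. The following is a minimal sketch of that resizing arithmetic, using a hypothetical `num_queries=4096` purely for illustration:
+
+```py
+import math
+
+num_queries = 4096     # number of query tokens in one attention layer (illustrative value)
+o_h, o_w = 1024, 1024  # height and width of the preprocessed mask
+ratio = o_w / o_h
+
+mask_h = int(math.sqrt(num_queries / ratio))
+mask_h = mask_h + int((num_queries % mask_h) != 0)
+mask_w = num_queries // mask_h
+
+print(mask_h, mask_w)  # 64 64 -> the mask is interpolated to 64x64 before it scales the IP-Adapter attention output
+```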
diff --git a/src/diffusers/image_processor.py b/src/diffusers/image_processor.py
index f3a5cd3fb914..f6ccfda9fcb8 100644
--- a/src/diffusers/image_processor.py
+++ b/src/diffusers/image_processor.py
@@ -12,12 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import math
import warnings
from typing import List, Optional, Tuple, Union
import numpy as np
import PIL.Image
import torch
+import torch.nn.functional as F
from PIL import Image, ImageFilter, ImageOps
from .configuration_utils import ConfigMixin, register_to_config
@@ -882,3 +884,107 @@ def preprocess(
depth = self.binarize(depth)
return rgb, depth
+
+
+class IPAdapterMaskProcessor(VaeImageProcessor):
+    """
+    Image processor for IP Adapter image masks.
+
+    Args:
+        do_resize (`bool`, *optional*, defaults to `True`):
+            Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`.
+        vae_scale_factor (`int`, *optional*, defaults to `8`):
+            VAE scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this factor.
+        resample (`str`, *optional*, defaults to `lanczos`):
+            Resampling filter to use when resizing the image.
+        do_normalize (`bool`, *optional*, defaults to `False`):
+            Whether to normalize the image to [-1,1].
+        do_binarize (`bool`, *optional*, defaults to `True`):
+            Whether to binarize the image to 0/1.
+        do_convert_grayscale (`bool`, *optional*, defaults to `True`):
+            Whether to convert the images to grayscale format.
+
+    """
+
+    config_name = CONFIG_NAME
+
+    @register_to_config
+    def __init__(
+        self,
+        do_resize: bool = True,
+        vae_scale_factor: int = 8,
+        resample: str = "lanczos",
+        do_normalize: bool = False,
+        do_binarize: bool = True,
+        do_convert_grayscale: bool = True,
+    ):
+        super().__init__(
+            do_resize=do_resize,
+            vae_scale_factor=vae_scale_factor,
+            resample=resample,
+            do_normalize=do_normalize,
+            do_binarize=do_binarize,
+            do_convert_grayscale=do_convert_grayscale,
+        )
+
+    @staticmethod
+    def downsample(mask: torch.FloatTensor, batch_size: int, num_queries: int, value_embed_dim: int):
+        """
+        Downsamples the provided mask tensor to match the expected dimensions for scaled dot-product attention.
+        If the aspect ratio of the mask does not match the aspect ratio of the output image, a warning is issued.
+
+        Args:
+            mask (`torch.FloatTensor`):
+                The input mask tensor generated with `IPAdapterMaskProcessor.preprocess()`.
+            batch_size (`int`):
+                The batch size.
+            num_queries (`int`):
+                The number of queries.
+            value_embed_dim (`int`):
+                The dimensionality of the value embeddings.
+
+        Returns:
+            `torch.FloatTensor`:
+                The downsampled mask tensor.
+
+        """
+        o_h = mask.shape[1]
+        o_w = mask.shape[2]
+        ratio = o_w / o_h
+        mask_h = int(math.sqrt(num_queries / ratio))
+        mask_h = int(mask_h) + int((num_queries % int(mask_h)) != 0)
+        mask_w = num_queries // mask_h
+
+        mask_downsample = F.interpolate(mask.unsqueeze(0), size=(mask_h, mask_w), mode="bicubic").squeeze(0)
+
+        # Repeat batch_size times
+        if mask_downsample.shape[0] < batch_size:
+            mask_downsample = mask_downsample.repeat(batch_size, 1, 1)
+
+        mask_downsample = mask_downsample.view(mask_downsample.shape[0], -1)
+
+        downsampled_area = mask_h * mask_w
+        # If the output image and the mask do not have the same aspect ratio, tensor shapes will not match
+        # Pad tensor if downsampled_mask.shape[1] is smaller than num_queries
+        if downsampled_area < num_queries:
+            warnings.warn(
+                "The aspect ratio of the mask does not match the aspect ratio of the output image. "
+                "Please update your masks or adjust the output size for optimal performance.",
+                UserWarning,
+            )
+            mask_downsample = F.pad(mask_downsample, (0, num_queries - mask_downsample.shape[1]), value=0.0)
+        # Discard last embeddings if downsampled_mask.shape[1] is bigger than num_queries
+        if downsampled_area > num_queries:
+            warnings.warn(
+                "The aspect ratio of the mask does not match the aspect ratio of the output image. "
+                "Please update your masks or adjust the output size for optimal performance.",
+                UserWarning,
+            )
+            mask_downsample = mask_downsample[:, :num_queries]
+
+        # Repeat last dimension to match SDPA output shape
+        mask_downsample = mask_downsample.view(mask_downsample.shape[0], mask_downsample.shape[1], 1).repeat(
+            1, 1, value_embed_dim
+        )
+
+        return mask_downsample
diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py
index d501213956bd..dccdf5bcc31f 100644
--- a/src/diffusers/models/attention_processor.py
+++ b/src/diffusers/models/attention_processor.py
@@ -19,6 +19,7 @@
import torch.nn.functional as F
from torch import nn
+from ..image_processor import IPAdapterMaskProcessor
from ..utils import USE_PEFT_BACKEND, deprecate, logging
from ..utils.import_utils import is_xformers_available
from ..utils.torch_utils import maybe_allow_in_graph
@@ -2135,12 +2136,13 @@ def __init__(self, hidden_size, cross_attention_dim=None, num_tokens=(4,), scale
    def __call__(
        self,
-        attn,
-        hidden_states,
-        encoder_hidden_states=None,
-        attention_mask=None,
-        temb=None,
-        scale=1.0,
+        attn: Attention,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        temb: Optional[torch.FloatTensor] = None,
+        scale: float = 1.0,
+        ip_adapter_masks: Optional[torch.FloatTensor] = None,
    ):
        residual = hidden_states
@@ -2195,9 +2197,22 @@ def __call__(
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

+        if ip_adapter_masks is not None:
+            if not isinstance(ip_adapter_masks, torch.Tensor) or ip_adapter_masks.ndim != 4:
+                raise ValueError(
+                    " ip_adapter_mask should be a tensor with shape [num_ip_adapter, 1, height, width]."
+                    " Please use `IPAdapterMaskProcessor` to preprocess your mask"
+                )
+            if len(ip_adapter_masks) != len(self.scale):
+                raise ValueError(
+                    f"Number of ip_adapter_masks ({len(ip_adapter_masks)}) must match number of IP-Adapters ({len(self.scale)})"
+                )
+        else:
+            ip_adapter_masks = [None] * len(self.scale)
+
        # for ip-adapter
-        for current_ip_hidden_states, scale, to_k_ip, to_v_ip in zip(
-            ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip
+        for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip(
+            ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks
        ):
            ip_key = to_k_ip(current_ip_hidden_states)
            ip_value = to_v_ip(current_ip_hidden_states)
@@ -2209,6 +2224,15 @@ def __call__(
            current_ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
            current_ip_hidden_states = attn.batch_to_head_dim(current_ip_hidden_states)

+            if mask is not None:
+                mask_downsample = IPAdapterMaskProcessor.downsample(
+                    mask, batch_size, current_ip_hidden_states.shape[1], current_ip_hidden_states.shape[2]
+                )
+
+                mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device)
+
+                current_ip_hidden_states = current_ip_hidden_states * mask_downsample
+
            hidden_states = hidden_states + scale * current_ip_hidden_states

        # linear proj
@@ -2272,12 +2296,13 @@ def __init__(self, hidden_size, cross_attention_dim=None, num_tokens=(4,), scale
    def __call__(
        self,
-        attn,
-        hidden_states,
-        encoder_hidden_states=None,
-        attention_mask=None,
-        temb=None,
-        scale=1.0,
+        attn: Attention,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        temb: Optional[torch.FloatTensor] = None,
+        scale: float = 1.0,
+        ip_adapter_masks: Optional[torch.FloatTensor] = None,
    ):
        residual = hidden_states
@@ -2346,9 +2371,22 @@ def __call__(
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

+        if ip_adapter_masks is not None:
+            if not isinstance(ip_adapter_masks, torch.Tensor) or ip_adapter_masks.ndim != 4:
+                raise ValueError(
+                    " ip_adapter_mask should be a tensor with shape [num_ip_adapter, 1, height, width]."
+                    " Please use `IPAdapterMaskProcessor` to preprocess your mask"
+                )
+            if len(ip_adapter_masks) != len(self.scale):
+                raise ValueError(
+                    f"Number of ip_adapter_masks ({len(ip_adapter_masks)}) must match number of IP-Adapters ({len(self.scale)})"
+                )
+        else:
+            ip_adapter_masks = [None] * len(self.scale)
+
        # for ip-adapter
-        for current_ip_hidden_states, scale, to_k_ip, to_v_ip in zip(
-            ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip
+        for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip(
+            ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks
        ):
            ip_key = to_k_ip(current_ip_hidden_states)
            ip_value = to_v_ip(current_ip_hidden_states)
@@ -2367,6 +2405,15 @@ def __call__(
            )
            current_ip_hidden_states = current_ip_hidden_states.to(query.dtype)

+            if mask is not None:
+                mask_downsample = IPAdapterMaskProcessor.downsample(
+                    mask, batch_size, current_ip_hidden_states.shape[1], current_ip_hidden_states.shape[2]
+                )
+
+                mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device)
+
+                current_ip_hidden_states = current_ip_hidden_states * mask_downsample
+
            hidden_states = hidden_states + scale * current_ip_hidden_states

        # linear proj
diff --git a/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py b/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py
index 11066253c518..6289ee887d13 100644
--- a/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py
+++ b/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py
@@ -31,6 +31,7 @@
    StableDiffusionXLInpaintPipeline,
    StableDiffusionXLPipeline,
)
+from diffusers.image_processor import IPAdapterMaskProcessor
from diffusers.models.attention_processor import AttnProcessor, AttnProcessor2_0
from diffusers.utils import load_image
from diffusers.utils.testing_utils import (
@@ -64,7 +65,7 @@ def get_image_processor(self, repo_id):
        image_processor = CLIPImageProcessor.from_pretrained(repo_id)
        return image_processor

-    def get_dummy_inputs(self, for_image_to_image=False, for_inpainting=False, for_sdxl=False):
+    def get_dummy_inputs(self, for_image_to_image=False, for_inpainting=False, for_sdxl=False, for_masks=False):
        image = load_image(
            "https://user-images.githubusercontent.com/24734142/266492875-2d50d223-8475-44f0-a7c6-08b51cb53572.png"
        )
@@ -101,6 +102,22 @@ def get_dummy_inputs(self, for_image_to_image=False, for_inpainting=False, for_s
            input_kwargs.update({"image": image, "mask_image": mask, "ip_adapter_image": ip_image})

+        elif for_masks:
+            face_image1 = load_image(
+                "https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ip_mask_girl1.png"
+            )
+            face_image2 = load_image(
+                "https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ip_mask_girl2.png"
+            )
+            mask1 = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ip_mask_mask1.png")
+            mask2 = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ip_mask_mask2.png")
+            input_kwargs.update(
+                {
+                    "ip_adapter_image": [[face_image1], [face_image2]],
+                    "cross_attention_kwargs": {"ip_adapter_masks": [mask1, mask2]},
+                }
+            )
+
        return input_kwargs
@@ -465,3 +482,58 @@ def test_inpainting_sdxl(self):
        max_diff = numpy_cosine_similarity_distance(image_slice, expected_slice)
        assert max_diff < 5e-4
+
+    def test_ip_adapter_single_mask(self):
+        image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="models/image_encoder")
+        pipeline = StableDiffusionXLPipeline.from_pretrained(
+            "stabilityai/stable-diffusion-xl-base-1.0",
+            image_encoder=image_encoder,
+            torch_dtype=self.dtype,
+        )
+        pipeline.to(torch_device)
+        pipeline.load_ip_adapter(
+            "h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter-plus-face_sdxl_vit-h.safetensors"
+        )
+        pipeline.set_ip_adapter_scale(0.7)
+
+        inputs = self.get_dummy_inputs(for_masks=True)
+        mask = inputs["cross_attention_kwargs"]["ip_adapter_masks"][0]
+        processor = IPAdapterMaskProcessor()
+        mask = processor.preprocess(mask)
+        inputs["cross_attention_kwargs"]["ip_adapter_masks"] = mask
+        inputs["ip_adapter_image"] = inputs["ip_adapter_image"][0]
+        images = pipeline(**inputs).images
+        image_slice = images[0, :3, :3, -1].flatten()
+        expected_slice = np.array(
+            [0.7307304, 0.73450166, 0.73731124, 0.7377061, 0.7318013, 0.73720926, 0.74746597, 0.7409929, 0.74074936]
+        )
+
+        max_diff = numpy_cosine_similarity_distance(image_slice, expected_slice)
+        assert max_diff < 5e-4
+
+    def test_ip_adapter_multiple_masks(self):
+        image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="models/image_encoder")
+        pipeline = StableDiffusionXLPipeline.from_pretrained(
+            "stabilityai/stable-diffusion-xl-base-1.0",
+            image_encoder=image_encoder,
+            torch_dtype=self.dtype,
+        )
+        pipeline.to(torch_device)
+        pipeline.load_ip_adapter(
+            "h94/IP-Adapter", subfolder="sdxl_models", weight_name=["ip-adapter-plus-face_sdxl_vit-h.safetensors"] * 2
+        )
+        pipeline.set_ip_adapter_scale([0.7] * 2)
+
+        inputs = self.get_dummy_inputs(for_masks=True)
+        masks = inputs["cross_attention_kwargs"]["ip_adapter_masks"]
+        processor = IPAdapterMaskProcessor()
+        masks = processor.preprocess(masks)
+        inputs["cross_attention_kwargs"]["ip_adapter_masks"] = masks
+        images = pipeline(**inputs).images
+        image_slice = images[0, :3, :3, -1].flatten()
+        expected_slice = np.array(
+            [0.79474676, 0.7977683, 0.8013954, 0.7988008, 0.7970615, 0.8029355, 0.80614823, 0.8050743, 0.80627424]
+        )
+
+        max_diff = numpy_cosine_similarity_distance(image_slice, expected_slice)
+        assert max_diff < 5e-4