Merged
27 commits
063e085
Add attention masking to attn processors
fabiorigano Feb 4, 2024
d753eec
Move latent image masking
fabiorigano Feb 5, 2024
60336ba
Remove redundant code
fabiorigano Feb 5, 2024
f6451d3
Fix removed line
fabiorigano Feb 5, 2024
bf4eb1d
Add padding
fabiorigano Feb 7, 2024
37419f1
Apply suggestions from code review
fabiorigano Feb 7, 2024
a180e25
Add IPAdapterMaskProcessing
fabiorigano Feb 7, 2024
c6fddae
Fix return types
fabiorigano Feb 7, 2024
708e0eb
Update image_processor
fabiorigano Feb 8, 2024
bbfeb67
Add test
fabiorigano Feb 8, 2024
e11bb7b
Merge branch 'main' into ipadaptermasks
fabiorigano Feb 8, 2024
2af0426
Apply suggestions from code review
fabiorigano Feb 9, 2024
8f6247d
Fix names
fabiorigano Feb 9, 2024
534b9d9
Fix style
fabiorigano Feb 9, 2024
cb929ff
Update src/diffusers/image_processor.py
fabiorigano Feb 9, 2024
4763c82
Fix init + docstring
fabiorigano Feb 9, 2024
21efcd1
Merge branch 'main' of https://github.com/huggingface/diffusers into …
fabiorigano Feb 10, 2024
4115a86
Remove unnecessary parameters
fabiorigano Feb 10, 2024
b1b9900
Update test
fabiorigano Feb 10, 2024
bfe55a7
Merge branch 'main' into ipadaptermasks
sayakpaul Feb 11, 2024
850c909
Merge branch 'main' into ipadaptermasks
sayakpaul Feb 12, 2024
9197ca1
Merge branch 'main' into ipadaptermasks
sayakpaul Feb 15, 2024
ec923bd
Add test for one mask + bugfix
fabiorigano Feb 16, 2024
4bc2621
Add docs
fabiorigano Feb 16, 2024
3315e81
Docs: update link
fabiorigano Feb 16, 2024
c407388
Update tensor conversion
fabiorigano Feb 17, 2024
7037644
Merge branch 'main' into ipadaptermasks
sayakpaul Feb 18, 2024
80 changes: 80 additions & 0 deletions docs/source/en/using-diffusers/ip_adapter.md
@@ -468,3 +468,83 @@ image
<div class="flex justify-center">
    <img src="https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ipa-controlnet-out.png" />
</div>

### IP-Adapter masking

Binary masks can be used to specify which portion of the output image should be assigned to an IP-Adapter.
For each input IP-Adapter image, a binary mask and an IP-Adapter must be provided.

Before passing the masks to the pipeline, it's essential to preprocess them using [`IPAdapterMaskProcessor.preprocess()`].

> [!TIP]
> For optimal results, provide the output height and width to [`IPAdapterMaskProcessor.preprocess()`]. This ensures that masks with differing aspect ratios are appropriately stretched. If the input masks already match the aspect ratio of the generated image, the height and width can be omitted.

Here is an example with two masks:

```py
from diffusers.image_processor import IPAdapterMaskProcessor
from diffusers.utils import load_image

mask1 = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ip_mask_mask1.png")
mask2 = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ip_mask_mask2.png")

output_height = 1024
output_width = 1024

processor = IPAdapterMaskProcessor()
masks = processor.preprocess([mask1, mask2], height=output_height, width=output_width)
```
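
[`IPAdapterMaskProcessor.preprocess()`] returns all masks stacked into a single tensor, one single-channel mask per IP-Adapter. A quick sanity check (the shape assumes the `height` and `width` used above):

```py
# Float tensor of shape [num_masks, 1, height, width], ready to be passed
# through cross_attention_kwargs={"ip_adapter_masks": masks} further below.
print(masks.shape)  # torch.Size([2, 1, 1024, 1024])
```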

<div class="flex flex-row gap-4">
<div class="flex-1">
<img class="rounded-xl" src="https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ip_mask_mask1.png"/>
<figcaption class="mt-2 text-center text-sm text-gray-500">mask one</figcaption>
</div>
<div class="flex-1">
<img class="rounded-xl" src="https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ip_mask_mask2.png"/>
<figcaption class="mt-2 text-center text-sm text-gray-500">mask two</figcaption>
</div>
</div>

If you have more than one IP-Adapter image, load them into a list of lists, where each inner list contains the images assigned to a single IP-Adapter.

```py
face_image1 = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ip_mask_girl1.png")
face_image2 = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ip_mask_girl2.png")

ip_images = [[face_image1], [face_image2]]
```

<div class="flex flex-row gap-4">
<div class="flex-1">
<img class="rounded-xl" src="https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ip_mask_girl1.png"/>
<figcaption class="mt-2 text-center text-sm text-gray-500">ip adapter image one</figcaption>
</div>
<div class="flex-1">
<img class="rounded-xl" src="https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ip_mask_girl2.png"/>
<figcaption class="mt-2 text-center text-sm text-gray-500">ip adapter image two</figcaption>
</div>
</div>

Pass the preprocessed masks to the pipeline in `cross_attention_kwargs` as shown below:

```py
pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name=["ip-adapter-plus-face_sdxl_vit-h.safetensors"] * 2)
pipeline.set_ip_adapter_scale([0.7] * 2)
generator = torch.Generator(device="cpu").manual_seed(0)
num_images = 1

image = pipeline(
    prompt="2 girls",
    ip_adapter_image=ip_images,
    negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality",
    num_inference_steps=20,
    num_images_per_prompt=num_images,
    generator=generator,
    cross_attention_kwargs={"ip_adapter_masks": masks},
).images[0]
```
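
Internally, the masks travel through `cross_attention_kwargs` to the IP-Adapter attention processors, where each mask is downsampled to the attention resolution with `IPAdapterMaskProcessor.downsample()` (see the `attention_processor.py` diff below) and multiplied into the corresponding adapter's hidden states.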

<div class="flex justify-center">
    <img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/ip_adapter_attention_mask_result_seed_0.png" />
<figcaption class="mt-2 text-center text-sm text-gray-500">output image</figcaption>
</div>
106 changes: 106 additions & 0 deletions src/diffusers/image_processor.py
@@ -12,12 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import warnings
from typing import List, Optional, Tuple, Union

import numpy as np
import PIL.Image
import torch
import torch.nn.functional as F
from PIL import Image, ImageFilter, ImageOps

from .configuration_utils import ConfigMixin, register_to_config
@@ -882,3 +884,107 @@ def preprocess(
depth = self.binarize(depth)

return rgb, depth


class IPAdapterMaskProcessor(VaeImageProcessor):
"""
Image processor for IP Adapter image masks.

Args:
do_resize (`bool`, *optional*, defaults to `True`):
Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`.
vae_scale_factor (`int`, *optional*, defaults to `8`):
VAE scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this factor.
resample (`str`, *optional*, defaults to `lanczos`):
Resampling filter to use when resizing the image.
do_normalize (`bool`, *optional*, defaults to `False`):
Whether to normalize the image to [-1,1].
do_binarize (`bool`, *optional*, defaults to `True`):
Whether to binarize the image to 0/1.
do_convert_grayscale (`bool`, *optional*, defaults to `True`):
Whether to convert the images to grayscale format.

"""

config_name = CONFIG_NAME

@register_to_config
def __init__(
self,
do_resize: bool = True,
vae_scale_factor: int = 8,
resample: str = "lanczos",
do_normalize: bool = False,
do_binarize: bool = True,
do_convert_grayscale: bool = True,
):
super().__init__(
do_resize=do_resize,
vae_scale_factor=vae_scale_factor,
resample=resample,
do_normalize=do_normalize,
do_binarize=do_binarize,
do_convert_grayscale=do_convert_grayscale,
)

@staticmethod
def downsample(mask: torch.FloatTensor, batch_size: int, num_queries: int, value_embed_dim: int):
"""
Downsamples the provided mask tensor to match the expected dimensions for scaled dot-product attention.
If the aspect ratio of the mask does not match the aspect ratio of the output image, a warning is issued.

Args:
mask (`torch.FloatTensor`):
The input mask tensor generated with `IPAdapterMaskProcessor.preprocess()`.
batch_size (`int`):
The batch size.
num_queries (`int`):
The number of queries.
value_embed_dim (`int`):
The dimensionality of the value embeddings.

Returns:
`torch.FloatTensor`:
The downsampled mask tensor.

"""
o_h = mask.shape[1]
o_w = mask.shape[2]
ratio = o_w / o_h
mask_h = int(math.sqrt(num_queries / ratio))
mask_h = int(mask_h) + int((num_queries % int(mask_h)) != 0)
mask_w = num_queries // mask_h

mask_downsample = F.interpolate(mask.unsqueeze(0), size=(mask_h, mask_w), mode="bicubic").squeeze(0)

# Repeat batch_size times
if mask_downsample.shape[0] < batch_size:
mask_downsample = mask_downsample.repeat(batch_size, 1, 1)

mask_downsample = mask_downsample.view(mask_downsample.shape[0], -1)

downsampled_area = mask_h * mask_w
# If the output image and the mask do not have the same aspect ratio, tensor shapes will not match
# Pad tensor if downsampled_mask.shape[1] is smaller than num_queries
if downsampled_area < num_queries:
warnings.warn(
"The aspect ratio of the mask does not match the aspect ratio of the output image. "
"Please update your masks or adjust the output size for optimal performance.",
UserWarning,
)
mask_downsample = F.pad(mask_downsample, (0, num_queries - mask_downsample.shape[1]), value=0.0)
# Discard last embeddings if downsampled_mask.shape[1] is bigger than num_queries
if downsampled_area > num_queries:
warnings.warn(
"The aspect ratio of the mask does not match the aspect ratio of the output image. "
"Please update your masks or adjust the output size for optimal performance.",
UserWarning,
)
mask_downsample = mask_downsample[:, :num_queries]

# Repeat last dimension to match SDPA output shape
mask_downsample = mask_downsample.view(mask_downsample.shape[0], mask_downsample.shape[1], 1).repeat(
1, 1, value_embed_dim
)

return mask_downsample
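
To illustrate the arithmetic above with concrete numbers (a minimal sketch, not part of the diff; the shapes are hypothetical): a square 1024x1024 mask, downsampled for an attention layer with 4096 queries (a 64x64 map) and value dimension 1280, comes back as `[batch_size, num_queries, value_embed_dim]`:

```py
import torch

from diffusers.image_processor import IPAdapterMaskProcessor

# One mask of shape [1, H, W], as sliced from the preprocessed [num_masks, 1, H, W] tensor.
mask = torch.ones(1, 1024, 1024)

# ratio = 1.0, so mask_h = mask_w = sqrt(4096) = 64; no padding or truncation is needed.
out = IPAdapterMaskProcessor.downsample(mask, batch_size=2, num_queries=4096, value_embed_dim=1280)
print(out.shape)  # torch.Size([2, 4096, 1280])
```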
79 changes: 63 additions & 16 deletions src/diffusers/models/attention_processor.py
@@ -19,6 +19,7 @@
import torch.nn.functional as F
from torch import nn

from ..image_processor import IPAdapterMaskProcessor
from ..utils import USE_PEFT_BACKEND, deprecate, logging
from ..utils.import_utils import is_xformers_available
from ..utils.torch_utils import maybe_allow_in_graph
@@ -2135,12 +2136,13 @@ def __init__(self, hidden_size, cross_attention_dim=None, num_tokens=(4,), scale

def __call__(
self,
attn,
hidden_states,
encoder_hidden_states=None,
attention_mask=None,
temb=None,
scale=1.0,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
temb: Optional[torch.FloatTensor] = None,
scale: float = 1.0,
ip_adapter_masks: Optional[torch.FloatTensor] = None,
):
residual = hidden_states

@@ -2195,9 +2197,22 @@ def __call__(
hidden_states = torch.bmm(attention_probs, value)
hidden_states = attn.batch_to_head_dim(hidden_states)

if ip_adapter_masks is not None:
if not isinstance(ip_adapter_masks, torch.Tensor) or ip_adapter_masks.ndim != 4:
raise ValueError(
" ip_adapter_mask should be a tensor with shape [num_ip_adapter, 1, height, width]."
" Please use `IPAdapterMaskProcessor` to preprocess your mask"
)
if len(ip_adapter_masks) != len(self.scale):
raise ValueError(
f"Number of ip_adapter_masks ({len(ip_adapter_masks)}) must match number of IP-Adapters ({len(self.scale)})"
)
Review conversation on this check:

Collaborator:
From what I understand, it only works when we pass 1 image / 1 mask / 1 IP-Adapter? If so, let's check the number of images here and throw an error if multiple images are passed:

    if ip_hidden_states[0].shape[1] > 1:
        raise ValueError("....")

Member:
Why do you think that? If you perform that check, you will remove all the instant LoRA functionality.

Collaborator:
It's only when the mask is not None though - you can still use multiple images without a mask. And it's only based on the understanding that we can use only one image/one mask/one IP-Adapter when we use a mask, no?

Collaborator:
If it works with multiple images, for sure we don't need this!

Member (@asomoza, Feb 7, 2024):
It works with multiple images, I tested it, so the only check should be that the number of masks matches the number of IP-Adapters.

Member:
Are we covering these cases in the tests?
else:
ip_adapter_masks = [None] * len(self.scale)

# for ip-adapter
for current_ip_hidden_states, scale, to_k_ip, to_v_ip in zip(
ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip
for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip(
ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks
):
ip_key = to_k_ip(current_ip_hidden_states)
ip_value = to_v_ip(current_ip_hidden_states)
@@ -2209,6 +2224,15 @@ def __call__(
current_ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
current_ip_hidden_states = attn.batch_to_head_dim(current_ip_hidden_states)

if mask is not None:
mask_downsample = IPAdapterMaskProcessor.downsample(
mask, batch_size, current_ip_hidden_states.shape[1], current_ip_hidden_states.shape[2]
)

mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device)

current_ip_hidden_states = current_ip_hidden_states * mask_downsample

hidden_states = hidden_states + scale * current_ip_hidden_states

# linear proj
@@ -2272,12 +2296,13 @@ def __init__(self, hidden_size, cross_attention_dim=None, num_tokens=(4,), scale

def __call__(
self,
attn,
hidden_states,
encoder_hidden_states=None,
attention_mask=None,
temb=None,
scale=1.0,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
temb: Optional[torch.FloatTensor] = None,
scale: float = 1.0,
ip_adapter_masks: Optional[torch.FloatTensor] = None,
):
residual = hidden_states

@@ -2346,9 +2371,22 @@ def __call__(
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)

if ip_adapter_masks is not None:
if not isinstance(ip_adapter_masks, torch.Tensor) or ip_adapter_masks.ndim != 4:
raise ValueError(
" ip_adapter_mask should be a tensor with shape [num_ip_adapter, 1, height, width]."
" Please use `IPAdapterMaskProcessor` to preprocess your mask"
)
if len(ip_adapter_masks) != len(self.scale):
raise ValueError(
f"Number of ip_adapter_masks ({len(ip_adapter_masks)}) must match number of IP-Adapters ({len(self.scale)})"
)
else:
ip_adapter_masks = [None] * len(self.scale)

# for ip-adapter
for current_ip_hidden_states, scale, to_k_ip, to_v_ip in zip(
ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip
for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip(
ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks
):
ip_key = to_k_ip(current_ip_hidden_states)
ip_value = to_v_ip(current_ip_hidden_states)
@@ -2367,6 +2405,15 @@
)
current_ip_hidden_states = current_ip_hidden_states.to(query.dtype)

if mask is not None:
mask_downsample = IPAdapterMaskProcessor.downsample(
mask, batch_size, current_ip_hidden_states.shape[1], current_ip_hidden_states.shape[2]
)

mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device)

current_ip_hidden_states = current_ip_hidden_states * mask_downsample

hidden_states = hidden_states + scale * current_ip_hidden_states

# linear proj